# Source code for qbraid.visualization.plot_conversions

# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.
"""
Module for plotting qBraid transpiler quantum program conversion graphs.

"""
from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Optional, Union

import rustworkx as rx
from qbraid_core._import import LazyLoader
from rustworkx.visualization import mpl_draw

from qbraid.programs.experiment import ExperimentType
from qbraid.programs.registry import is_registered_alias_native

if TYPE_CHECKING:
    import matplotlib.pyplot

    import qbraid.transpiler

# Module-level lazy imports. LazyLoader (from qbraid_core) presumably defers importing the
# named module until it is first used -- confirm against qbraid_core._import. The type
# annotations refer to the TYPE_CHECKING-only imports above; with
# ``from __future__ import annotations`` in effect they are stored as strings and never
# evaluated at runtime.
plt: matplotlib.pyplot = LazyLoader("plt", globals(), "matplotlib.pyplot")

# NOTE(review): qbraid.transpiler is probably loaded lazily to avoid a circular import
# between the transpiler and this visualization module -- verify.
transpiler: qbraid.transpiler = LazyLoader("transpiler", globals(), "qbraid.transpiler")


def plot_conversion_graph(  # pylint: disable=too-many-arguments
    graph: qbraid.transpiler.ConversionGraph,
    title: Optional[str] = "qBraid Quantum Program Conversion Graph",
    legend: bool = False,
    seed: Optional[int] = None,
    node_size: int = 1200,
    min_target_margin: int = 18,
    show: bool = True,
    save_path: Optional[str] = None,
    colors: Optional[dict[str, str]] = None,
    edge_labels: bool = False,
    experiment_type: Optional[Union[ExperimentType, list[ExperimentType]]] = None,
    **kwargs,
) -> None:
    """
    Plot the conversion graph using matplotlib.

    The graph is displayed using node and edge color conventions, with options for a title,
    legend, and figure saving.

    Args:
        graph (ConversionGraph): The directed conversion graph to be plotted.
        title (str, optional): Title of the plot. Defaults to
            'qBraid Quantum Program Conversion Graph'.
        legend (bool): If True, display a legend on the graph. Defaults to False.
        seed (int, optional): Seed for the node layout algorithm. Useful for consistent
            positioning. Defaults to None.
        node_size (int): Size of the nodes. Defaults to 1200.
        min_target_margin (int): Minimum target margin for edges. Defaults to 18.
        show (bool): If True, display the figure. Defaults to True.
        save_path (str, optional): Path to save the figure. If None, the figure is not saved.
            The figure is saved before it is shown, so the written image is the plotted
            figure rather than a blank one. Defaults to None.
        colors (dict[str, str], optional): Dictionary for node and edge colors. Recognized
            keys are 'qbraid_node', 'external_node', 'qbraid_edge', 'external_edge' and
            'extras_edge'. Any omitted key falls back to its default color. Defaults to
            None, meaning all default colors are used.
        edge_labels (bool): If True, display edge weights (rounded to 2 decimals) as labels.
            Defaults to False.
        experiment_type (Union[ExperimentType, list[ExperimentType]], optional): Filter the
            graph by experiment type. A tuple or set of types is accepted as well.
            Defaults to None, meaning all experiment types are included.
        **kwargs: Extra keyword arguments forwarded to ``rustworkx.spring_layout``. If ``k``
            is not supplied, a default is computed from the number of nodes.

    Returns:
        None

    Raises:
        ValueError: If ``experiment_type`` is given and no node in the graph matches it.
    """
    default_colors = {
        "qbraid_node": "lightblue",
        "external_node": "lightgray",
        "qbraid_edge": "gray",
        "external_edge": "blue",
        "extras_edge": "red",
    }
    # Merge over the defaults so a partial mapping (e.g. one with only the four node/edge
    # keys, missing 'extras_edge') cannot raise KeyError below. The caller's dict is not
    # mutated.
    colors = {**default_colors, **(colors or {})}

    if experiment_type:
        node_experiment_types = graph.get_node_experiment_types()
        if isinstance(experiment_type, (list, tuple, set, frozenset)):
            exp_types = list(experiment_type)
        else:
            exp_types = [experiment_type]
        nodes = [n for n in graph.nodes() if node_experiment_types[n] in exp_types]
        if not nodes:
            exp_type_names = ", ".join(exp_type.name for exp_type in exp_types)
            raise ValueError(
                f"No program type nodes found with experiment type(s) '{exp_type_names}'. "
                "Use ConversionGraph.get_node_experiment_types() to inspect all experiment "
                "type mappings in this graph."
            )
        # Rebuild the graph restricted to the matching nodes, reusing the original settings.
        graph = transpiler.ConversionGraph(
            conversions=graph.conversions(),
            require_native=graph.require_native,
            include_isolated=graph._include_isolated,
            edge_bias=graph.edge_bias,
            nodes=nodes,
        )

    # One color per node, in graph.nodes() order (the order mpl_draw uses).
    ncolors = [
        colors["qbraid_node"] if is_registered_alias_native(node) else colors["external_node"]
        for node in graph.nodes()
    ]

    # One color per edge, in graph.edge_list() order, so that edge_color lines up
    # index-for-index with the edges mpl_draw renders. An edge with no matching Conversion
    # still gets a color instead of being skipped, because skipping would shift every later
    # color onto the wrong edge.
    conversion_by_pair = {
        (conversion.source, conversion.target): conversion for conversion in graph.conversions()
    }
    ecolors = []
    for source_id, target_id in graph.edge_list():
        pair = (graph.get_node_data(source_id), graph.get_node_data(target_id))
        conversion = conversion_by_pair.get(pair)
        if graph.get_edge_data(source_id, target_id)["native"]:
            ecolors.append(colors["qbraid_edge"])
        elif conversion is not None and len(conversion._extras) > 0:
            ecolors.append(colors["extras_edge"])
        else:
            ecolors.append(colors["external_edge"])

    rustworkx_version = rx.__version__  # pylint: disable=no-member
    if len(set(ecolors)) > 1 and rustworkx_version in ("0.15.0", "0.15.1"):
        warnings.warn(
            "Detected multiple edge colors, which may not display correctly "
            "due to a known bug in rustworkx versions 0.15.0 and 0.15.1 "
            "(see: https://github.com/Qiskit/rustworkx/issues/1308). "
            "To avoid this issue, please upgrade to rustworkx>0.15.1.",
            UserWarning,
        )

    # Compute the default spring constant only when needed, and guard the empty graph,
    # which previously raised ZeroDivisionError even if the caller passed ``k``.
    # NOTE(review): since 1 / sqrt(n) <= 1 for n >= 1, this default is always 3; perhaps
    # ``min`` was intended -- behavior is kept as-is.
    if "k" in kwargs:
        k = kwargs.pop("k")
    else:
        num_nodes = len(graph.nodes())
        k = max(1 / math.sqrt(num_nodes), 3) if num_nodes else 3
    pos = rx.spring_layout(graph, seed=seed, k=k, **kwargs)  # good seeds: 123, 134

    # **kwargs were consumed by the layout; drawing gets its own, separate options.
    draw_kwargs = {}
    if edge_labels:
        draw_kwargs["edge_labels"] = lambda edge: round(edge["weight"], 2)

    mpl_draw(
        graph,
        pos,
        node_color=ncolors,
        edge_color=ecolors,
        node_size=node_size,
        with_labels=True,
        labels=str,
        min_target_margin=min_target_margin,
        **draw_kwargs,
    )

    if title:
        plt.title(title)
    plt.axis("off")

    if legend:
        # (label, marker, color, linestyle): node entries use a marker, edge entries a line.
        legend_info = [
            ("qBraid - Node", "o", colors["qbraid_node"], None),
            ("External - Node", "o", colors["external_node"], None),
            ("qBraid - Edge", None, colors["qbraid_edge"], "-"),
            ("Extras - Edge", None, colors["extras_edge"], "-"),
            ("External - Edge", None, colors["external_edge"], "-"),
        ]
        legend_elements = [
            plt.Line2D(
                [0],
                [0],
                marker=marker,
                color="w" if marker else color,
                label=label,
                markersize=10 if marker else None,
                markerfacecolor=color if marker else None,
                linestyle=linestyle,
                linewidth=2 if linestyle else None,
            )
            for label, marker, color, linestyle in legend_info
        ]
        plt.legend(handles=legend_elements, loc="best")

    # Save BEFORE showing: a blocking plt.show() closes the figure when it returns, so a
    # later plt.savefig would write a new, empty figure.
    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()