# Copyright 2025 qBraid
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
"""
Functions for drawing quantum circuits.
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from openqasm3 import ast
from pyqasm.expressions import Qasm3ExprEvaluator
from pyqasm.maps.gates import (
ONE_QUBIT_OP_MAP,
ONE_QUBIT_ROTATION_MAP,
REV_CTRL_GATE_MAP,
TWO_QUBIT_OP_MAP,
)
if TYPE_CHECKING:
import matplotlib.pyplot as plt
from pyqasm.modules.base import QasmModule
# Constants
DEFAULT_GATE_COLOR = "#d4b6e8"
HADAMARD_GATE_COLOR = "#f0a6a6"
FIG_MAX_WIDTH = 12
GATE_BOX_WIDTH = 0.6
GATE_BOX_HEIGHT = 0.6
GATE_SPACING = 0.2
LINE_SPACING = 0.6
TEXT_MARGIN = 0.6
FRAME_PADDING = 0.2
BOX_STYLE = "round,pad=0.02,rounding_size=0.05"
Declaration = (
ast.CalibrationGrammarDeclaration
| ast.ClassicalDeclaration
| ast.ConstantDeclaration
| ast.ExternDeclaration
| ast.IODeclaration
| ast.QubitDeclaration
)
QuantumStatement = (
ast.QuantumGate | ast.QuantumMeasurementStatement | ast.QuantumBarrier | ast.QuantumReset
)
QubitIdentifier = ast.Identifier | ast.IndexedIdentifier
[docs]
def draw(
program: str | QasmModule,
output: Literal["mpl"] = "mpl",
idle_wires: bool = True,
**kwargs: Any,
) -> None:
"""Draw the quantum circuit.
Args:
module (QasmModule): The quantum module to draw
output (str): The output format. Defaults to "mpl".
idle_wires (bool): Whether to show idle wires. Defaults to True.
Returns:
None: The drawing is displayed or saved to a file.
"""
if isinstance(program, str):
from pyqasm.entrypoint import loads
program = loads(program)
if output == "mpl":
_ = mpl_draw(program, idle_wires=idle_wires, external_draw=False, **kwargs)
import matplotlib.pyplot as plt
if not plt.isinteractive():
plt.show()
else:
raise ValueError(f"Unsupported output format: {output}")
def mpl_draw( # pylint: disable=too-many-locals
program: str | QasmModule,
idle_wires: bool = True,
filename: str | Path | None = None,
external_draw: bool = True,
) -> plt.Figure:
"""Internal matplotlib drawing implementation."""
if isinstance(program, str):
from pyqasm.entrypoint import loads
program = loads(program)
try:
# pylint: disable-next=unused-import
import matplotlib.pyplot as plt
plt.ioff()
except ImportError as e:
raise ImportError(
"matplotlib needs to be installed prior to running pyqasm.mpl_draw(). "
"You can install matplotlib with:\n'pip install matplotlib'"
) from e
program.unroll()
program.remove_includes()
line_nums, sizes = _compute_line_nums(program)
global_phase = 0
statements: list[ast.Statement | ast.Pragma] = []
for s in program._statements:
if isinstance(s, ast.QuantumPhase):
global_phase += Qasm3ExprEvaluator.evaluate_expression(s.argument)[0]
else:
statements.append(s)
# Compute moments
moments, depths = _compute_moments(statements, line_nums)
if not idle_wires:
# remove all lines that are not used
ks = sorted(line_nums.keys(), key=lambda k: line_nums[k])
ks = [k for k in ks if depths[k] > 0]
line_nums = {k: i for i, k in enumerate(ks)}
fig = _mpl_draw(program, moments, line_nums, sizes, global_phase)
if filename is not None:
plt.savefig(filename, bbox_inches="tight", dpi=300)
if external_draw:
plt.close(fig)
return fig
def _compute_line_nums(
module: QasmModule,
) -> tuple[dict[tuple[str, int], int], dict[tuple[str, int], int]]:
"""Compute line number and register size lookup table for the circuit."""
line_nums = {}
sizes = {}
line_num = -1
max_depth = 0
# Classical registers condensed to single line
for k in module._classical_registers:
line_num += 1
line_nums[(k, -1)] = line_num
sizes[(k, -1)] = module._classical_registers[k]
# Calculate qubit lines and depths
for qubit_reg in module._qubit_registers:
size = module._qubit_registers[qubit_reg]
line_num += size
for i in range(size):
line_nums[(qubit_reg, i)] = line_num
depth = module._qubit_depths[(qubit_reg, i)]._total_ops()
max_depth = max(max_depth, depth)
line_num -= 1
line_num += size
return line_nums, sizes
def _get_line_keys(
all_keys: list[tuple[str, int]], line_nums: dict[tuple[str, int], int]
) -> list[tuple[str, int]]:
if not all_keys:
return []
# Get line numbers for the given keys and find min/max
key_line_nums = [line_nums[key] for key in all_keys]
start_key, end_key = min(key_line_nums), max(key_line_nums)
# Return all keys within the range
return [key for key, line_num in line_nums.items() if start_key <= line_num <= end_key]
# pylint: disable-next=too-many-branches,too-many-locals
def _compute_moments(
statements: list[ast.Statement | ast.Pragma], line_nums: dict[tuple[str, int], int]
) -> tuple[list[list[QuantumStatement]], dict[tuple[str, int], int]]:
depths = {}
for k in line_nums:
depths[k] = -1
moments: list[list[QuantumStatement]] = []
# Find and remove final measurements (measurements at the end of the statements list)
final_measurements: list[ast.QuantumMeasurementStatement] = []
seen_keys = set[tuple[str, int]]()
for stmt in reversed(statements):
if not isinstance(stmt, ast.QuantumMeasurementStatement):
break
control_key = _identifier_to_key(stmt.measure.qubit)
if control_key not in seen_keys:
seen_keys.add(control_key)
final_measurements.append(stmt)
# Remove the final measurements from the end of statements list
statements = statements[: -len(final_measurements)] if final_measurements else statements
for statement in statements:
if isinstance(statement, Declaration):
continue
if not isinstance(statement, QuantumStatement):
raise ValueError(f"Unsupported statement: {statement}")
if isinstance(statement, ast.QuantumGate):
qubits = [_identifier_to_key(q) for q in statement.qubits]
# Get line keys for multi-qubit gates, otherwise use qubits directly
target_keys = _get_line_keys(qubits, line_nums) if len(qubits) > 1 else qubits
# Calculate new depth and update all affected keys
depth = 1 + max(depths[key] for key in target_keys)
for key in target_keys:
depths[key] = depth
elif isinstance(statement, ast.QuantumMeasurementStatement):
keys = [_identifier_to_key(statement.measure.qubit)]
if statement.target:
target_key = _identifier_to_key(statement.target)[0], -1
keys.append(target_key)
line_keys = _get_line_keys(keys, line_nums)
# Calculate new depth and update all affected keys
depth = 1 + max(depths[key] for key in line_keys)
for key in line_keys:
depths[key] = depth
elif isinstance(statement, ast.QuantumBarrier):
qubits = []
for expr in statement.qubits:
# https://github.com/openqasm/openqasm/issues/461
if not isinstance(expr, QubitIdentifier):
raise ValueError(
f"Unsupported qubit type '{type(expr).__name__}' in "
f"'{type(statement).__name__}' statement. "
f"Expected a qubit of type {QubitIdentifier}."
)
qubits.append(_identifier_to_key(expr))
depth = 1 + max(depths[q] for q in qubits)
for q in qubits:
depths[q] = depth
elif isinstance(statement, ast.QuantumReset):
qubit_key = _identifier_to_key(statement.qubits)
depth = 1 + depths[qubit_key]
depths[qubit_key] = depth
if depth >= len(moments):
moments.append([])
moments[depth].append(statement)
depth = max(depths.values())
for measurement in final_measurements:
depth += 1
if depth >= len(moments):
moments.append([])
moments[depth].append(measurement)
return moments, depths
def _identifier_to_key(identifier: ast.Identifier | ast.IndexedIdentifier) -> tuple[str, int]:
if isinstance(identifier, ast.Identifier):
return identifier.name, -1
indices = identifier.indices
if len(indices) >= 1 and isinstance(indices[0], list) and len(indices[0]) >= 1:
return (
identifier.name.name,
Qasm3ExprEvaluator.evaluate_expression(indices[0][0])[0],
)
raise ValueError(f"Unsupported identifier: {identifier}")
def _compute_sections(
moments: list[list[QuantumStatement]],
) -> tuple[list[list[list[QuantumStatement]]], float]:
sections: list[list[list[QuantumStatement]]] = [[]]
width = TEXT_MARGIN
for moment in moments:
w = _mpl_get_moment_width(moment)
if width + w < FIG_MAX_WIDTH:
width += w
else:
width = TEXT_MARGIN
width = w
sections.append([])
sections[-1].append(moment)
if len(sections) > 1:
width = FIG_MAX_WIDTH
return sections, width
def _mpl_draw(
module: QasmModule,
moments: list[list[QuantumStatement]],
line_nums: dict[tuple[str, int], int],
sizes: dict[tuple[str, int], int],
global_phase: float,
):
sections, width = _compute_sections(moments)
n_lines = max(line_nums.values()) + 1
fig, axs = _mpl_setup_figure(sections, width, n_lines)
for sidx, ms in enumerate(sections):
ax = axs[sidx]
_mpl_draw_section(module, ms, line_nums, sizes, ax, global_phase)
return fig
def _mpl_setup_figure(
sections: list[list[list[QuantumStatement]]], width: float, n_lines: int
) -> tuple[plt.Figure, list[plt.Axes]]:
import matplotlib.pyplot as plt
import numpy as np
fig_ax_tuple: tuple[plt.Figure, list[plt.Axes] | plt.Axes] = plt.subplots(
len(sections),
1,
sharex=True,
figsize=(width, len(sections) * (n_lines * GATE_BOX_HEIGHT + LINE_SPACING * (n_lines - 1))),
)
fig, axs = fig_ax_tuple
axs = (
axs.flatten().tolist()
if isinstance(axs, np.ndarray)
else [axs] if isinstance(axs, plt.Axes) else axs
)
for ax in axs:
ax.set_ylim(
-GATE_BOX_HEIGHT / 2 - FRAME_PADDING / 2,
n_lines * GATE_BOX_HEIGHT
+ LINE_SPACING * (n_lines - 1)
- GATE_BOX_HEIGHT / 2
+ FRAME_PADDING / 2,
)
ax.set_xlim(-FRAME_PADDING / 2, width)
ax.axis("off")
return fig, axs
# pylint: disable-next=too-many-arguments
def _mpl_draw_section(
module: QasmModule,
moments: list[list[QuantumStatement]],
line_nums: dict[tuple[str, int], int],
sizes: dict[tuple[str, int], int],
ax: plt.Axes,
global_phase: float,
):
x = 0.0
if global_phase != 0:
_mpl_draw_global_phase(global_phase, ax, x)
for k in module._qubit_registers.keys():
for i in range(module._qubit_registers[k]):
if (k, i) in line_nums:
line_num = line_nums[(k, i)]
_mpl_draw_qubit_label((k, i), line_num, ax, x)
for k in module._classical_registers.keys():
_mpl_draw_creg_label(k, line_nums[(k, -1)], ax, x)
x += TEXT_MARGIN
x0 = x
for i, moment in enumerate(moments):
dx = _mpl_get_moment_width(moment)
_mpl_draw_lines(dx, line_nums, sizes, ax, x, start=i == 0)
x += dx
x = x0
for moment in moments:
dx = _mpl_get_moment_width(moment)
for statement in moment:
_mpl_draw_statement(statement, line_nums, ax, x)
x += dx
def _mpl_line_to_y(line_num: int) -> float:
return line_num * (GATE_BOX_HEIGHT + LINE_SPACING)
def _mpl_draw_global_phase(global_phase: float, ax: plt.Axes, x: float):
ax.text(x, -0.75, f"Global Phase: {global_phase:.3f}", ha="left", va="center")
def _mpl_draw_qubit_label(qubit: tuple[str, int], line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{qubit[0]}[{qubit[1]}]", ha="right", va="center")
def _mpl_draw_creg_label(creg: str, line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{creg[0]}", ha="right", va="center")
# pylint: disable-next=too-many-arguments
def _mpl_draw_lines(
width,
line_nums: dict[tuple[str, int], int],
sizes: dict[tuple[str, int], int],
ax: plt.Axes,
x: float,
start=True,
):
for k in line_nums.keys():
y = _mpl_line_to_y(line_nums[k])
if k[1] == -1:
gap = GATE_BOX_HEIGHT / 15
ax.hlines(
xmin=x - width / 2,
xmax=x + width / 2,
y=y + gap / 2,
color="black",
linestyle="-",
zorder=-10,
)
ax.hlines(
xmin=x - width / 2,
xmax=x + width / 2,
y=y - gap / 2,
color="black",
linestyle="-",
zorder=-10,
)
if start:
ax.plot(
[x - width / 2 + gap, x - width / 2 + 2 * gap],
[y - 2 * gap, y + 2 * gap],
color="black",
zorder=-10,
)
ax.text(x - width / 2 + 3 * gap, y + 3 * gap, f"{sizes[k]}", fontsize=8)
else:
ax.hlines(
xmin=x - width / 2,
xmax=x + width / 2,
y=y,
color="black",
linestyle="-",
zorder=-10,
)
def _mpl_get_moment_width(moment: list[QuantumStatement]) -> float:
return max(_mpl_get_statement_width(s) for s in moment)
def _mpl_get_statement_width(_: QuantumStatement) -> float:
return GATE_BOX_WIDTH + GATE_SPACING
def _mpl_draw_statement(
statement: QuantumStatement, line_nums: dict[tuple[str, int], int], ax: plt.Axes, x: float
):
if isinstance(statement, ast.QuantumGate):
args = [Qasm3ExprEvaluator.evaluate_expression(arg)[0] for arg in statement.arguments]
lines = [line_nums[_identifier_to_key(q)] for q in statement.qubits]
_mpl_draw_gate(statement, args, lines, ax, x)
elif isinstance(statement, ast.QuantumMeasurementStatement):
qubit_key = _identifier_to_key(statement.measure.qubit)
if statement.target is None:
_mpl_draw_measurement(line_nums[qubit_key], -1, -1, ax, x)
return
name, idx = _identifier_to_key(statement.target)
_mpl_draw_measurement(line_nums[qubit_key], line_nums[(name, -1)], idx, ax, x)
elif isinstance(statement, ast.QuantumBarrier):
lines = []
for q in statement.qubits:
# https://github.com/openqasm/openqasm/issues/461
if not isinstance(q, QubitIdentifier):
raise ValueError(
f"Unsupported qubit type '{type(q).__name__}' in "
f"'{type(statement).__name__}' statement. "
f"Expected a qubit of type {QubitIdentifier}."
)
lines.append(line_nums[_identifier_to_key(q)])
_mpl_draw_barrier(lines, ax, x)
elif isinstance(statement, ast.QuantumReset):
_mpl_draw_reset(line_nums[_identifier_to_key(statement.qubits)], ax, x)
else:
raise NotImplementedError(f"Unsupported statement: {statement}")
def _mpl_draw_gate(
gate: ast.QuantumGate, args: list[Any], lines: list[int], ax: plt.Axes, x: float
):
name = gate.name.name
if name in REV_CTRL_GATE_MAP:
i = 0
while name in REV_CTRL_GATE_MAP:
name = REV_CTRL_GATE_MAP[name]
_draw_mpl_control(lines[i], lines[-1], ax, x)
i += 1
lines = lines[i:]
gate.name.name = name
if name in ONE_QUBIT_OP_MAP or name in ONE_QUBIT_ROTATION_MAP:
_draw_mpl_one_qubit_gate(gate, args, lines[0], ax, x)
elif name in TWO_QUBIT_OP_MAP:
if name == "swap":
_draw_mpl_swap(lines[0], lines[1], ax, x)
else:
raise NotImplementedError(f"Unsupported gate: {name}")
else:
raise NotImplementedError(f"Unsupported gate: {name}")
def _draw_mpl_one_qubit_gate(
gate: ast.QuantumGate, args: list[Any], line: int, ax: plt.Axes, x: float
):
from matplotlib.patches import FancyBboxPatch
color = DEFAULT_GATE_COLOR
if gate.name.name == "h":
color = HADAMARD_GATE_COLOR
text = gate.name.name.upper()
y = _mpl_line_to_y(line)
rect = FancyBboxPatch(
(x - GATE_BOX_WIDTH / 2, y - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor=color,
edgecolor="none",
boxstyle=BOX_STYLE,
)
ax.add_patch(rect)
if len(args) > 0:
args_text = f"{', '.join([f'{a:.3f}' if isinstance(a, float) else str(a) for a in args])}"
ax.text(x, y + GATE_BOX_HEIGHT / 8, text, ha="center", va="center", fontsize=12)
ax.text(x, y - GATE_BOX_HEIGHT / 4, args_text, ha="center", va="center", fontsize=8)
else:
ax.text(x, y, text, ha="center", va="center", fontsize=12)
def _draw_mpl_control(ctrl_line: int, target_line: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(ctrl_line)
y2 = _mpl_line_to_y(target_line)
ax.vlines(x=x, ymin=min(y1, y2), ymax=max(y1, y2), color="black", linestyle="-", zorder=-1)
ax.plot(x, y1, "ko", markersize=8, markerfacecolor="black")
def _draw_mpl_swap(line1: int, line2: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(line1)
y2 = _mpl_line_to_y(line2)
ax.vlines(x=x, ymin=min(y1, y2), ymax=max(y1, y2), color="black", linestyle="-")
ax.plot(x, y1, "x", markersize=8, color="black")
ax.plot(x, y2, "x", markersize=8, color="black")
def _mpl_draw_measurement(qbit_line: int, cbit_line: int, idx: int, ax: plt.Axes, x: float):
from matplotlib.patches import FancyBboxPatch
y1 = _mpl_line_to_y(qbit_line)
color = "#A0A0A0"
gap = GATE_BOX_WIDTH / 3
rect = FancyBboxPatch(
(x - GATE_BOX_WIDTH / 2, y1 - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor=color,
edgecolor="none",
boxstyle=BOX_STYLE,
)
ax.add_patch(rect)
ax.text(x, y1, "M", ha="center", va="center")
if cbit_line >= 0 and idx >= 0:
y2 = _mpl_line_to_y(cbit_line)
ax.vlines(
x=x - gap / 10,
ymin=min(y1, y2) + gap,
ymax=max(y1, y2),
color=color,
linestyle="-",
zorder=-1,
)
ax.vlines(
x=x + gap / 10,
ymin=min(y1, y2) + gap,
ymax=max(y1, y2),
color=color,
linestyle="-",
zorder=-1,
)
ax.plot(x, y2 + gap, "v", markersize=12, color=color)
ax.text(x + gap, y2 + gap, str(idx), color=color, ha="left", va="bottom", fontsize=8)
def _mpl_draw_barrier(lines: list[int], ax: plt.Axes, x: float):
import matplotlib.pyplot as plt
for line in lines:
y = _mpl_line_to_y(line)
ax.vlines(
x=x,
ymin=y - GATE_BOX_HEIGHT / 2 - LINE_SPACING / 2,
ymax=y + GATE_BOX_HEIGHT / 2 + LINE_SPACING / 2,
color="black",
linestyle="--",
)
rect = plt.Rectangle(
(x - GATE_BOX_WIDTH / 4, y - GATE_BOX_HEIGHT / 2 - LINE_SPACING / 2),
GATE_BOX_WIDTH / 2,
GATE_BOX_HEIGHT + LINE_SPACING,
facecolor="lightgray",
edgecolor="none",
alpha=0.5,
zorder=-1,
)
ax.add_patch(rect)
def _mpl_draw_reset(line: int, ax: plt.Axes, x: float):
from matplotlib.patches import FancyBboxPatch
y = _mpl_line_to_y(line)
rect = FancyBboxPatch(
(x - GATE_BOX_WIDTH / 2, y - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor="lightgray",
edgecolor="none",
boxstyle=BOX_STYLE,
)
ax.add_patch(rect)
ax.text(x, y, "∣0⟩", ha="center", va="center", fontsize=12)