Source code for qbraid_qir.qasm3.visitor

# Copyright 2025 qBraid
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=too-many-instance-attributes,too-many-lines,too-many-branches

"""
Module defining Profile-based Qasm3 Visitor for QIR generation.

This refactored approach uses Profile objects to handle different QIR profiles
without code duplication, based on JSON profile specifications.

"""
import logging
from typing import Any, Callable, List, Optional, Union

import openqasm3.ast as qasm3_ast
import pyqir
import pyqir._native
import pyqir.rt
from openqasm3.ast import UnaryOperator
from pyqir import qis

from ..profile import QIRVisitor
from ..profiles import Profile, ProfileRegistry
from .elements import QasmQIRModule
from .exceptions import raise_qasm3_error
from .maps import PYQIR_ONE_QUBIT_ROTATION_MAP, map_qasm_op_to_pyqir_callable

logger = logging.getLogger(__name__)


[docs] class QasmQIRVisitor(QIRVisitor): """A profile-aware visitor for converting OpenQASM 3 programs to QIR. This class is designed to traverse and interact with statements in an OpenQASM program. It uses Profile objects to handle different QIR profile requirements. Args: profile_name (str): Name of the QIR profile to use. Defaults to "Base". initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True. record_output (bool): If True, output of the circuit will be recorded. Defaults to True. external_gates (list[str]): List of custom gates that should not be unrolled. emit_barrier_calls (bool): If True, barrier calls will be emitted. Defaults to True. """ # pylint: disable=too-many-arguments
[docs] def __init__( self, profile_name: str = "Base", initialize_runtime: bool = True, record_output: bool = True, external_gates: Optional[list[str]] = None, emit_barrier_calls: bool = True, ): # Call parent class constructor super().__init__() # Get the profile self._profile = ProfileRegistry.get_profile(profile_name) self._llvm_module: pyqir.Module self._builder: pyqir.Builder self._entry_point: str = "" self._qubit_labels: dict[str, int] = {} self._clbit_labels: dict[str, int] = {} self._global_qreg_size_map: dict[str, int] = {} self._global_creg_size_map: dict[str, int] = {} self._custom_gates: dict[str, qasm3_ast.QuantumGateDefinition] = {} self._barrier_qubits: set[pyqir.Constant] = set() # Configuration self._initialize_runtime: bool = initialize_runtime self._record_output: bool = record_output self._emit_barrier_calls: bool = emit_barrier_calls # Profile-specific attributes if self._profile.should_track_qubit_measurement(): self._measured_qubits: dict[int, bool] = {} # External gates if external_gates is None: external_gates = [] self._external_gates_map: dict[str, Optional[pyqir.Function]] = { external_gate: None for external_gate in external_gates }
@property def profile(self) -> Profile: """Get the current profile.""" return self._profile @property def entry_point(self) -> str: return self._entry_point def visit_qasm3_module(self, module: QasmQIRModule) -> None: """ Visit a Qasm3 module. Args: module (Qasm3Module): The module to visit. Returns: None """ qasm3_module = module.qasm_program logger.debug( "Visiting Qasm3 module '%s' (%d) with profile '%s'", module.name, qasm3_module.num_qubits, self._profile.name, ) self._llvm_module = module.llvm_module context = self._llvm_module.context # Set qir_profiles based on the profile being used qir_profiles = "adaptive" if self._profile.name == "AdaptiveExecution" else "base" entry = pyqir.entry_point( self._llvm_module, module.name, qasm3_module.num_qubits, qasm3_module.num_clbits, qir_profiles=qir_profiles, ) self._entry_point = entry.name self._builder = pyqir.Builder(context) self._builder.insert_at_end(pyqir.BasicBlock(context, "entry", entry)) if self._initialize_runtime: i8p = pyqir.PointerType(pyqir.IntType(context, 8)) nullptr = pyqir.Constant.null(i8p) pyqir.rt.initialize(self._builder, nullptr) def finalize(self) -> None: self._check_and_apply_barrier() # to check if we have an incomplete barrier at program end self._builder.ret(None) def record_output(self, module: QasmQIRModule) -> None: """Record output using profile-specific method.""" self._profile.record_output_method(self, module) def _visit_register( self, register: Union[qasm3_ast.QubitDeclaration, qasm3_ast.ClassicalDeclaration] ) -> None: """Visit a register statement. Args: register (QubitDeclaration|ClassicalDeclaration): The register name and size. Returns: None """ logger.debug("Visiting register '%s'", str(register)) is_qubit = isinstance(register, qasm3_ast.QubitDeclaration) current_size = len(self._qubit_labels) if is_qubit else len(self._clbit_labels) if is_qubit: register_size = ( 1 if register.size is None else register.size.value # type: ignore[union-attr] ) else: register_size = ( 1 if register.type.size is None # type: ignore[union-attr] else register.type.size.value # type: ignore[union-attr] ) register_name = ( register.qubit.name # type: ignore[union-attr] if is_qubit else register.identifier.name # type: ignore[union-attr] ) size_map = self._global_qreg_size_map if is_qubit else self._global_creg_size_map label_map = self._qubit_labels if is_qubit else self._clbit_labels for i in range(register_size): size_map[f"{register_name}"] = register_size label_map[f"{register_name}_{i}"] = current_size + i logger.debug("Added labels for register '%s'", str(register)) def _get_op_bits(self, operation: Any, qubits: bool = True) -> list[pyqir.Constant]: """Get the quantum / classical bits for the operation. Args: operation (Any): The operation to get qubits for. reg_size_map (dict): The size map of the registers in scope. qubits (bool): Whether the bits are quantum bits or classical bits. Defaults to True. Returns: Unionlist[pyqir.Constant] : The bits for the operation. """ qir_bits = [] bit_list = [] if isinstance(operation, qasm3_ast.QuantumMeasurementStatement): assert operation.target is not None bit_list = [operation.measure.qubit] if qubits else [operation.target] else: bit_list = ( operation.qubits if isinstance(operation.qubits, list) else [operation.qubits] ) for bit in bit_list: # as we have unrolled qasm3, we can assume that the bit is an IndexedIdentifier assert isinstance(bit, qasm3_ast.IndexedIdentifier) reg_name = bit.name.name assert isinstance(bit.indices, list) and len(bit.indices) == 1 assert isinstance(bit.indices[0], list) and len(bit.indices[0]) == 1 assert isinstance(bit.indices[0][0], qasm3_ast.IntegerLiteral) bit_id = bit.indices[0][0].value bit_ids = [bit_id] label_map = self._qubit_labels if qubits else self._clbit_labels reg_ids = [label_map[f"{reg_name}_{bit_id}"] for bit_id in bit_ids] qir_bits.extend( [ ( pyqir.qubit(self._llvm_module.context, bit_id) if qubits else pyqir.result(self._llvm_module.context, bit_id) ) for bit_id in reg_ids ] ) return qir_bits # pylint: disable=unused-argument def _check_qubit_use_after_measurement(self, qubit_ids: List[pyqir.Constant]) -> None: """ Check qubit use after measurement based on profile capabilities. Args: qubit_ids (List[pyqir.Constant]): The qubit ids to check. Returns: None """ if not self._profile.allow_qubit_use_after_measurement(): # For profiles that don't allow it, we could add validation here for qubit_id in qubit_ids: qubit_id_result = pyqir.qubit_id(qubit_id) if qubit_id_result is not None and self._measured_qubits.get( qubit_id_result, False ): raise_qasm3_error( f"Base Profile violation: Cannot use qubit {qubit_id_result} after measurement" # pylint: disable=line-too-long ) def _visit_measurement(self, statement: qasm3_ast.QuantumMeasurementStatement) -> None: """Visit a measurement statement element. Args: statement (qasm3_ast.QuantumMeasurementStatement): The measurement statement to visit. Returns: None """ logger.debug("Visiting measurement statement '%s'", str(statement)) source = statement.measure.qubit target = statement.target assert source and target source_ids = self._get_op_bits(statement, qubits=True) target_ids = self._get_op_bits(statement, qubits=False) measurement_func = self._profile.get_measurement_function() for src_id, tgt_id in zip(source_ids, target_ids): # Track measurement if profile supports it if self._profile.should_track_qubit_measurement(): qubit_id_result = pyqir.qubit_id(src_id) if qubit_id_result is not None: self._measured_qubits[qubit_id_result] = True measurement_func(self._builder, src_id, tgt_id) def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> None: """Visit a reset statement element. Args: statement (qasm3_ast.QuantumReset): The reset statement to visit. Returns: None """ logger.debug("Visiting reset statement '%s'", str(statement)) qubit_ids = self._get_op_bits(statement, True) reset_func = self._profile.get_reset_function() for qid in qubit_ids: # Clear measurement tracking if profile supports it if self._profile.should_track_qubit_measurement(): qubit_id_result = pyqir.qubit_id(qid) if qubit_id_result is not None: self._measured_qubits[qubit_id_result] = False reset_func(self._builder, qid) def _barrier_applicable(self) -> bool: """Check if the barrier operation is applicable. Args: None Returns: bool: Whether the barrier operation is applicable. """ if self._profile.restrictions.subset_barriers_allowed: return True total_qubit_count = sum(self._global_qreg_size_map.values()) return len(self._barrier_qubits) == total_qubit_count def _check_and_apply_barrier(self) -> None: """Apply the barrier operation. Returns: None """ if len(self._barrier_qubits) == 0: return if self._barrier_applicable(): if self._emit_barrier_calls: barrier_func = self._profile.get_barrier_function() barrier_func(self._builder) self._barrier_qubits.clear() else: if self._emit_barrier_calls: raise_qasm3_error( "Barrier operation on a qubit subset is not supported in pyqir", err_type=NotImplementedError, ) # pylint: disable=unused-argument def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> None: """Visit a barrier statement element. Args: statement (qasm3_ast.QuantumBarrier): The barrier statement to visit. Returns: None """ barrier_qubit = self._get_op_bits(barrier, qubits=True) self._barrier_qubits.update(barrier_qubit) # try to apply barrier in case all qubits are covered here itself if self._barrier_applicable(): if self._emit_barrier_calls: barrier_func = self._profile.get_barrier_function() barrier_func(self._builder) self._barrier_qubits.clear() def _get_op_parameters(self, operation: qasm3_ast.QuantumGate) -> list[float]: """Get the parameters for the operation. Args: operation (qasm3_ast.QuantumGate): The operation to get parameters for. Returns: list[float]: The parameters for the operation. """ param_list = [] for param in operation.arguments: assert hasattr(param, "value") param_value = param.value param_list.append(param_value) return param_list def _visit_basic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None: """Visit a gate operation element. Args: operation (qasm3_ast.QuantumGate): The gate operation to visit. Returns: None Raises: Qasm3ConversionError: If the number of qubits is invalid. """ logger.debug("Visiting basic gate operation '%s'", str(operation)) op_name: str = operation.name.name op_qubits = self._get_op_bits(operation) # Profile-aware qubit usage check self._check_qubit_use_after_measurement(op_qubits) # Use existing gate mapping logic but with profile awareness gate_map = { "h": qis.h if self._profile.restrictions.prefer_qis_over_native else qis.h, "x": qis.x, "y": qis.y, "z": qis.z, "s": qis.s, "sdg": qis.s_adj, "t": qis.t, "tdg": qis.t_adj, "cx": qis.cx, "cnot": qis.cx, "cz": qis.cz, "ccx": qis.ccx, "swap": qis.swap, "rx": qis.rx, "ry": qis.ry, "rz": qis.rz, } if op_name in gate_map: qir_func = gate_map[op_name] if op_name in ["rx", "ry", "rz"]: op_parameters = self._get_op_parameters(operation) if op_parameters: qir_func(self._builder, *op_parameters, *op_qubits) # type: ignore else: raise_qasm3_error(f"Parametric gate {op_name} requires parameters") else: qir_func(self._builder, *op_qubits) # type: ignore elif op_name == "id": # Identity gate implementation qubit = op_qubits[0] qis.x(self._builder, qubit) qis.x(self._builder, qubit) else: # Use the mapping system try: qir_func, expected_qubit_count = map_qasm_op_to_pyqir_callable(op_name) if len(op_qubits) != expected_qubit_count: raise_qasm3_error( f"Gate {op_name} expects {expected_qubit_count} qubits,got {len(op_qubits)}" ) op_parameters = self._get_op_parameters(operation) is_parametric = op_name in PYQIR_ONE_QUBIT_ROTATION_MAP or op_name in [ "xx", "xy", "yy", "zz", "pswap", "cp", "cphaseshift", "cp00", "cphaseshift00", "cp01", "cphaseshift01", "cp10", "cphaseshift10", "ms", "prx", ] if is_parametric: if not op_parameters: raise_qasm3_error(f"Parametric gate {op_name} requires parameters") qir_func(self._builder, *op_parameters, *op_qubits) else: if op_parameters: raise_qasm3_error( f"Non-parametric gate {op_name} should not have parameters" ) qir_func(self._builder, *op_qubits) except ValueError as conversion_error: if "Unsupported / undeclared QASM operation" in str(conversion_error): raise_qasm3_error(f"Unsupported gate operation: {op_name}") else: raise_qasm3_error(f"Error mapping gate {op_name}: {conversion_error}") except (TypeError, Exception) as e: # pylint: disable=broad-exception-caught raise_qasm3_error(f"Error executing gate {op_name}: {e}") def _visit_external_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None: """Visit an external gate operation element. Args: operation (qasm3_ast.QuantumGate): The gate operation to visit. Returns: None Raises: Qasm3ConversionError: If the number of qubits is invalid. """ logger.debug("Visiting external gate operation '%s'", str(operation)) op_name: str = operation.name.name op_qubits = self._get_op_bits(operation) op_qubit_count = len(op_qubits) self._check_qubit_use_after_measurement(op_qubits) if len(operation.modifiers) > 0: raise_qasm3_error( "Modifiers on externally linked gates are not supported in pyqir", err_type=NotImplementedError, ) context = self._llvm_module.context qir_function = self._external_gates_map[op_name] if qir_function is None: # First time seeing this external gate -> define new function qir_function_arguments = [pyqir.Type.double(context)] * len(operation.arguments) qir_function_arguments += [pyqir.qubit_type(context)] * op_qubit_count qir_function = pyqir.Function( pyqir.FunctionType(pyqir.Type.void(context), qir_function_arguments), pyqir.Linkage.EXTERNAL, f"__quantum__qis__{op_name}__body", self._llvm_module, ) self._external_gates_map[op_name] = qir_function op_parameters = None if len(operation.arguments) > 0: # parametric gate op_parameters = self._get_op_parameters(operation) op_parameters = list(map(float, op_parameters)) if op_parameters is not None: self._builder.call(qir_function, [*op_parameters, *op_qubits]) else: self._builder.call(qir_function, op_qubits) def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None: """Visit a gate operation element. Args: operation (qasm3_ast.QuantumGate): The gate operation to visit. Returns: None """ if operation.name.name in self._external_gates_map: self._visit_external_gate_operation(operation) else: self._visit_basic_gate_operation(operation) def _get_branch_params(self, condition: Any) -> tuple[str, int, bool]: """ Get the branch parameters from the branching condition Args: condition (Any): The condition to analyze Returns: tuple[str, int, bool]: (register name, register id, positive branch) """ def validate_index_expression(expression): assert isinstance(expression, qasm3_ast.IndexExpression) assert isinstance(expression.collection, qasm3_ast.Identifier) assert isinstance(expression.index, list) and len(expression.index) == 1 assert isinstance(expression.index[0], qasm3_ast.IntegerLiteral) if isinstance(condition, qasm3_ast.UnaryExpression): validate_index_expression(condition.expression) return ( condition.expression.collection.name, # type: ignore condition.expression.index[0].value, # type: ignore not condition.op == UnaryOperator["!"], ) if isinstance(condition, qasm3_ast.BinaryExpression): assert isinstance( condition.rhs, qasm3_ast.BooleanLiteral ), "Invalid branching condition" validate_index_expression(condition.lhs) return ( condition.lhs.collection.name, # type: ignore condition.lhs.index[0].value, # type: ignore condition.rhs.value, ) if isinstance(condition, qasm3_ast.IndexExpression): assert isinstance(condition.index, list) and len(condition.index) == 1 return (condition.collection.name, condition.index[0].value, True) # type: ignore # default case return "", -1, True def _visit_branching_statement(self, statement: qasm3_ast.BranchingStatement) -> None: """Visit a branching statement element. Args: statement (qasm3_ast.BranchingStatement): The branching statement to visit. Returns: None """ logger.debug("Visiting branching statement with profile '%s'", self._profile.name) # Check if profile supports conditional execution if not self._profile.capabilities.conditional_execution: raise_qasm3_error( f"Profile '{self._profile.name}' does not support conditional execution", err_type=NotImplementedError, ) condition = statement.condition if_block = statement.if_block else_block = statement.else_block reg_name, reg_id, positive_branch = self._get_branch_params(condition) if not positive_branch: if_block, else_block = else_block, if_block def _visit_statement_block(block): if block: for stmt in block: self.visit_statement(stmt) # Use profile-specific conditional function conditional_func = self._profile.get_conditional_function() zero_callback: Callable[[], None] = lambda: _visit_statement_block(else_block) one_callback: Callable[[], None] = lambda: _visit_statement_block(if_block) conditional_func( self._builder, pyqir.result(self._llvm_module.context, self._clbit_labels[f"{reg_name}_{reg_id}"]), zero=zero_callback, one=one_callback, ) def visit_statement(self, statement: qasm3_ast.Statement) -> None: """Visit a statement element. Args: statement (qasm3_ast.Statement): The statement to visit. Returns: None """ logger.debug("Visiting statement '%s'", str(statement)) visit_map = { qasm3_ast.Include: lambda x: None, # No operation qasm3_ast.QubitDeclaration: self._visit_register, qasm3_ast.ClassicalDeclaration: self._visit_register, qasm3_ast.QuantumMeasurementStatement: self._visit_measurement, qasm3_ast.QuantumReset: self._visit_reset, qasm3_ast.QuantumBarrier: self._visit_barrier, qasm3_ast.QuantumGate: self._visit_generic_gate_operation, qasm3_ast.BranchingStatement: self._visit_branching_statement, qasm3_ast.QuantumPhase: lambda x: None, # No operation } visitor_function = visit_map.get(type(statement)) if not isinstance(statement, qasm3_ast.QuantumBarrier): self._check_and_apply_barrier() if visitor_function: visitor_function(statement) # type: ignore[operator] else: raise_qasm3_error( f"Unsupported statement of type {type(statement)}", span=statement.span ) def ir(self) -> str: return str(self._llvm_module) def bitcode(self) -> bytes: return self._llvm_module.bitcode