# Copyright (C) 2024 qBraid
#
# This file is part of qbraid-qir
#
# Qbraid-qir is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for qbraid-qir, as per Section 15 of the GPL v3.
# pylint: disable=too-many-instance-attributes,too-many-lines,too-many-branches
"""
Module defining Qasm3 Visitor.
"""
import logging
from typing import Any, Union
import openqasm3.ast as qasm3_ast
import pyqir
import pyqir._native
import pyqir.rt
from openqasm3.ast import UnaryOperator
from .elements import QasmQIRModule
from .exceptions import raise_qasm3_error
from .maps import map_qasm_op_to_pyqir_callable
logger = logging.getLogger(__name__)
class QasmQIRVisitor:
    """A visitor for converting OpenQASM 3 programs to QIR.

    This class is designed to traverse and interact with statements in an OpenQASM program.
    Statements are expected to be unrolled already, i.e. every qubit / bit operand is a
    single-index reference such as ``q[0]``.

    Args:
        initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True.
        record_output (bool): If True, output of the circuit will be recorded. Defaults to True.
        external_gates (list[str]): List of custom gates that should not be unrolled.
            Instead, these gates are marked for external linkage, as
            qir-functions with the name "__quantum__qis__<GateName>__body"
    """

    def __init__(
        self,
        initialize_runtime: bool = True,
        record_output: bool = True,
        external_gates: list[str] | None = None,
    ):
        # Declared for type checkers only; both are assigned in visit_qasm3_module().
        self._llvm_module: pyqir.Module
        self._builder: pyqir.Builder
        self._entry_point: str = ""
        # "<register>_<index>" -> flat qubit / result index used in the QIR program.
        self._qubit_labels: dict[str, int] = {}
        self._clbit_labels: dict[str, int] = {}
        # Register name -> declared register size.
        self._global_qreg_size_map: dict[str, int] = {}
        self._global_creg_size_map: dict[str, int] = {}
        self._custom_gates: dict[str, qasm3_ast.QuantumGateDefinition] = {}
        # Qubits named by barrier statements that have not been emitted yet.
        self._barrier_qubits: set[pyqir.Constant] = set()
        self._initialize_runtime: bool = initialize_runtime
        self._record_output: bool = record_output
        if external_gates is None:
            external_gates = []
        # External gate name -> declared QIR function (None until the gate is first used).
        self._external_gates_map: dict[str, pyqir.Function | None] = {
            external_gate: None for external_gate in external_gates
        }

    def visit_qasm3_module(self, module: QasmQIRModule) -> None:
        """Visit a Qasm3 module.

        Creates the QIR entry point, positions the builder at the end of a fresh
        "entry" basic block and, if requested, emits the runtime initialization call.

        Args:
            module (QasmQIRModule): The module to visit.

        Returns:
            None
        """
        qasm3_module = module.qasm_program
        logger.debug("Visiting Qasm3 module '%s' (%d)", module.name, qasm3_module.num_qubits)
        self._llvm_module = module.llvm_module
        context = self._llvm_module.context
        entry = pyqir.entry_point(
            self._llvm_module, module.name, qasm3_module.num_qubits, qasm3_module.num_clbits
        )
        self._entry_point = entry.name
        self._builder = pyqir.Builder(context)
        self._builder.insert_at_end(pyqir.BasicBlock(context, "entry", entry))

        if self._initialize_runtime is True:
            i8p = pyqir.PointerType(pyqir.IntType(context, 8))
            nullptr = pyqir.Constant.null(i8p)
            pyqir.rt.initialize(self._builder, nullptr)

    @property
    def entry_point(self) -> str:
        """Name of the QIR entry-point function (empty until a module has been visited)."""
        return self._entry_point

    def finalize(self) -> None:
        """Flush any pending barrier and emit the final ``ret void`` instruction."""
        self._check_and_apply_barrier()  # to check if we have an incomplete barrier at program end
        self._builder.ret(None)

    def record_output(self, module: QasmQIRModule) -> None:
        """Emit one ``result_record_output`` call per result, unless output recording is off.

        Args:
            module (QasmQIRModule): The module whose program size determines the result count.

        Returns:
            None
        """
        if self._record_output is False:
            return

        i8p = pyqir.PointerType(pyqir.IntType(self._llvm_module.context, 8))

        # NOTE(review): the loop runs over num_qubits although the entry point is created
        # with num_clbits results -- confirm this is intended when the two counts differ.
        for i in range(module.qasm_program.num_qubits):
            result_ref = pyqir.result(self._llvm_module.context, i)
            pyqir.rt.result_record_output(self._builder, result_ref, pyqir.Constant.null(i8p))

    def _visit_register(
        self, register: Union[qasm3_ast.QubitDeclaration, qasm3_ast.ClassicalDeclaration]
    ) -> None:
        """Visit a register statement.

        Records the register size and assigns each of its bits the next free flat index.
        A declaration without an explicit size is treated as a register of size 1.

        Args:
            register (QubitDeclaration|ClassicalDeclaration): The register name and size.

        Returns:
            None
        """
        logger.debug("Visiting register '%s'", str(register))
        is_qubit = isinstance(register, qasm3_ast.QubitDeclaration)

        if is_qubit:
            register_size = (
                1 if register.size is None else register.size.value  # type: ignore[union-attr]
            )
        else:
            register_size = (
                1
                if register.type.size is None  # type: ignore[union-attr]
                else register.type.size.value  # type: ignore[union-attr]
            )
        register_name = (
            register.qubit.name  # type: ignore[union-attr]
            if is_qubit
            else register.identifier.name  # type: ignore[union-attr]
        )

        size_map = self._global_qreg_size_map if is_qubit else self._global_creg_size_map
        label_map = self._qubit_labels if is_qubit else self._clbit_labels

        # New bits are numbered after every bit of the same kind declared so far.
        current_size = len(label_map)
        size_map[register_name] = register_size
        for i in range(register_size):
            label_map[f"{register_name}_{i}"] = current_size + i

        logger.debug("Added labels for register '%s'", str(register))

    def _get_op_bits(self, operation: Any, qubits: bool = True) -> list[pyqir.Constant]:
        """Get the quantum / classical bits for the operation.

        Args:
            operation (Any): The operation to get bits for. For a measurement statement the
                measured qubit (``qubits=True``) or the target bit (``qubits=False``) is
                used; for any other operation its ``qubits`` attribute is used.
            qubits (bool): Whether the bits are quantum bits or classical bits. Defaults to True.

        Returns:
            list[pyqir.Constant]: The bits for the operation.
        """
        if isinstance(operation, qasm3_ast.QuantumMeasurementStatement):
            assert operation.target is not None
            bit_list = [operation.measure.qubit] if qubits else [operation.target]
        else:
            bit_list = (
                operation.qubits if isinstance(operation.qubits, list) else [operation.qubits]
            )

        label_map = self._qubit_labels if qubits else self._clbit_labels
        make_reference = pyqir.qubit if qubits else pyqir.result

        qir_bits = []
        for bit in bit_list:
            # as we have unrolled qasm3, we can assume that the bit is an IndexedIdentifier
            assert isinstance(bit, qasm3_ast.IndexedIdentifier)
            assert isinstance(bit.indices, list) and len(bit.indices) == 1
            assert isinstance(bit.indices[0], list) and len(bit.indices[0]) == 1
            assert isinstance(bit.indices[0][0], qasm3_ast.IntegerLiteral)

            reg_name = bit.name.name
            bit_id = bit.indices[0][0].value
            flat_id = label_map[f"{reg_name}_{bit_id}"]
            qir_bits.append(make_reference(self._llvm_module.context, flat_id))

        return qir_bits

    def _visit_measurement(self, statement: qasm3_ast.QuantumMeasurementStatement) -> None:
        """Visit a measurement statement element.

        Args:
            statement (qasm3_ast.QuantumMeasurementStatement): The measurement statement to visit.

        Returns:
            None
        """
        logger.debug("Visiting measurement statement '%s'", str(statement))

        source = statement.measure.qubit
        target = statement.target
        assert source and target

        source_ids = self._get_op_bits(statement, qubits=True)
        target_ids = self._get_op_bits(statement, qubits=False)

        for src_id, tgt_id in zip(source_ids, target_ids):
            pyqir._native.mz(self._builder, src_id, tgt_id)  # type: ignore[arg-type]

    def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> None:
        """Visit a reset statement element.

        Args:
            statement (qasm3_ast.QuantumReset): The reset statement to visit.

        Returns:
            None
        """
        logger.debug("Visiting reset statement '%s'", str(statement))
        qubit_ids = self._get_op_bits(statement, True)

        for qid in qubit_ids:
            # qid is of type Constant which is inherited from Value, so we ignore the type error
            pyqir._native.reset(self._builder, qid)  # type: ignore[arg-type]

    def _barrier_applicable(self) -> bool:
        """Check if the barrier operation is applicable.

        The barrier is applicable once the number of collected barrier qubits equals the
        total number of declared qubits.

        Returns:
            bool: Whether the barrier operation is applicable.
        """
        total_qubit_count = sum(self._global_qreg_size_map.values())
        return len(self._barrier_qubits) == total_qubit_count

    def _check_and_apply_barrier(self) -> None:
        """Apply the pending barrier, if there is one.

        Does nothing when no barrier qubits are pending. A pending barrier that does not
        cover every declared qubit is reported through ``raise_qasm3_error`` with
        ``err_type=NotImplementedError``.

        Returns:
            None
        """
        if not self._barrier_qubits:
            return

        if self._barrier_applicable():
            pyqir._native.barrier(self._builder)
            self._barrier_qubits.clear()
        else:
            raise_qasm3_error(
                "Barrier operation on a qubit subset is not supported in pyqir",
                err_type=NotImplementedError,
            )

    # pylint: disable=unused-argument
    def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> None:
        """Visit a barrier statement element.

        The barrier's qubits are accumulated across consecutive barrier statements, and a
        single barrier is emitted as soon as they cover every declared qubit.

        Args:
            barrier (qasm3_ast.QuantumBarrier): The barrier statement to visit.

        Returns:
            None
        """
        barrier_qubit = self._get_op_bits(barrier, qubits=True)
        self._barrier_qubits.update(barrier_qubit)

        # try to apply barrier in case all qubits are covered here itself
        if self._barrier_applicable():
            pyqir._native.barrier(self._builder)
            self._barrier_qubits.clear()

    def _get_op_parameters(self, operation: qasm3_ast.QuantumGate) -> list[float]:
        """Get the parameters for the operation.

        Args:
            operation (qasm3_ast.QuantumGate): The operation to get parameters for.

        Returns:
            list[float]: The parameters for the operation.
        """
        param_list = []
        for param in operation.arguments:
            # arguments are expected to be literals, which carry their value directly
            assert hasattr(param, "value")
            param_list.append(param.value)

        return param_list

    def _visit_basic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
        """Visit a gate operation element.

        The gate is applied to consecutive groups of qubits, each group holding as many
        qubits as the mapped pyqir callable expects.

        Args:
            operation (qasm3_ast.QuantumGate): The gate operation to visit.

        Returns:
            None
        """
        logger.debug("Visiting basic gate operation '%s'", str(operation))
        op_name: str = operation.name.name
        op_qubits = self._get_op_bits(operation)
        qir_func, op_qubit_count = map_qasm_op_to_pyqir_callable(op_name)

        # an empty list means a non-parametric gate and unpacks to no extra arguments
        op_parameters = self._get_op_parameters(operation) if operation.arguments else []

        for i in range(0, len(op_qubits), op_qubit_count):
            # we apply the gate on the qubit subset linearly
            qubit_subset = op_qubits[i : i + op_qubit_count]
            qir_func(self._builder, *op_parameters, *qubit_subset)

    def _visit_external_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
        """Visit an external gate operation element.

        On first use of a gate, an externally linked QIR function named
        ``__quantum__qis__<GateName>__body`` is declared, taking one double per gate
        argument followed by one qubit per operand. Every use then emits a call to it.
        Gate modifiers are reported through ``raise_qasm3_error`` with
        ``err_type=NotImplementedError``.

        Args:
            operation (qasm3_ast.QuantumGate): The gate operation to visit.

        Returns:
            None
        """
        logger.debug("Visiting external gate operation '%s'", str(operation))
        op_name: str = operation.name.name
        op_qubits = self._get_op_bits(operation)
        op_qubit_count = len(op_qubits)

        if len(operation.modifiers) > 0:
            raise_qasm3_error(
                "Modifiers on externally linked gates are not supported in pyqir",
                err_type=NotImplementedError,
            )

        context = self._llvm_module.context
        qir_function = self._external_gates_map[op_name]
        if qir_function is None:
            # First time seeing this external gate -> define new function
            qir_function_arguments = [pyqir.Type.double(context)] * len(operation.arguments)
            qir_function_arguments += [pyqir.qubit_type(context)] * op_qubit_count
            qir_function = pyqir.Function(
                pyqir.FunctionType(pyqir.Type.void(context), qir_function_arguments),
                pyqir.Linkage.EXTERNAL,
                f"__quantum__qis__{op_name}__body",
                self._llvm_module,
            )
            self._external_gates_map[op_name] = qir_function

        # an empty list means a non-parametric gate; the call then receives only the qubits
        op_parameters = self._get_op_parameters(operation) if operation.arguments else []
        self._builder.call(qir_function, [*op_parameters, *op_qubits])

    def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
        """Visit a gate operation element.

        Gates registered as external are emitted as calls to externally linked functions;
        every other gate is treated as a basic gate.

        Args:
            operation (qasm3_ast.QuantumGate): The gate operation to visit.

        Returns:
            None
        """
        if operation.name.name in self._external_gates_map:
            self._visit_external_gate_operation(operation)
        else:
            self._visit_basic_gate_operation(operation)

    def _get_branch_params(self, condition: Any) -> tuple[str, int, bool]:
        """Get the branch parameters from the branching condition.

        Recognized conditions are a unary expression on an indexed bit (``!c[0]``), a
        comparison of an indexed bit with a boolean literal (``c[0] == true``,
        ``c[0] != false``) and a bare indexed bit (``c[0]``).

        Args:
            condition (Any): The condition to analyze

        Returns:
            tuple[str, int, bool]: (register name, register id, positive branch).
                ``("", -1, True)`` is returned for any other kind of condition.
        """

        def validate_index_expression(expression):
            assert isinstance(expression, qasm3_ast.IndexExpression)
            assert isinstance(expression.collection, qasm3_ast.Identifier)
            assert isinstance(expression.index, list) and len(expression.index) == 1
            assert isinstance(expression.index[0], qasm3_ast.IntegerLiteral)

        if isinstance(condition, qasm3_ast.UnaryExpression):
            validate_index_expression(condition.expression)
            return (
                condition.expression.collection.name,  # type: ignore
                condition.expression.index[0].value,  # type: ignore
                condition.op != UnaryOperator["!"],
            )

        if isinstance(condition, qasm3_ast.BinaryExpression):
            assert isinstance(
                condition.rhs, qasm3_ast.BooleanLiteral
            ), "Invalid branching condition"
            validate_index_expression(condition.lhs)
            positive_branch = condition.rhs.value
            # "c[0] != b" takes the if-block exactly when "c[0] == b" would not
            if condition.op == qasm3_ast.BinaryOperator["!="]:
                positive_branch = not positive_branch
            return (
                condition.lhs.collection.name,  # type: ignore
                condition.lhs.index[0].value,  # type: ignore
                positive_branch,
            )

        if isinstance(condition, qasm3_ast.IndexExpression):
            assert isinstance(condition.index, list) and len(condition.index) == 1
            return (condition.collection.name, condition.index[0].value, True)  # type: ignore

        # default case
        return "", -1, True

    def _visit_branching_statement(self, statement: qasm3_ast.BranchingStatement) -> None:
        """Visit a branching statement element.

        The branch is emitted as an ``if_result`` on the classical bit named by the
        condition. For a negated condition the if- and else-blocks are swapped. A
        condition that does not name a declared classical bit is reported through
        ``raise_qasm3_error``.

        Args:
            statement (qasm3_ast.BranchingStatement): The branching statement to visit.

        Returns:
            None
        """
        condition = statement.condition
        if_block = statement.if_block
        else_block = statement.else_block

        reg_name, reg_id, positive_branch = self._get_branch_params(condition)

        bit_label = f"{reg_name}_{reg_id}"
        if bit_label not in self._clbit_labels:
            # covers both unrecognized condition shapes and undeclared classical bits
            raise_qasm3_error(
                "Unsupported branching condition: "
                f"no classical bit '{reg_name}[{reg_id}]' is declared",
                span=statement.span,
            )

        if not positive_branch:
            if_block, else_block = else_block, if_block

        def _visit_statement_block(block):
            for stmt in block:
                self.visit_statement(stmt)

        pyqir._native.if_result(
            self._builder,
            pyqir.result(self._llvm_module.context, self._clbit_labels[bit_label]),
            zero=lambda: _visit_statement_block(else_block),
            one=lambda: _visit_statement_block(if_block),
        )

    def visit_statement(self, statement: qasm3_ast.Statement) -> None:
        """Visit a statement element.

        Dispatches on the exact type of the statement. Before any statement other than a
        barrier is handled, a pending barrier is flushed. Statement types without a
        handler are reported through ``raise_qasm3_error``.

        Args:
            statement (qasm3_ast.Statement): The statement to visit.

        Returns:
            None
        """
        logger.debug("Visiting statement '%s'", str(statement))

        visit_map = {
            qasm3_ast.Include: lambda x: None,  # No operation
            qasm3_ast.QubitDeclaration: self._visit_register,
            qasm3_ast.ClassicalDeclaration: self._visit_register,
            qasm3_ast.QuantumMeasurementStatement: self._visit_measurement,
            qasm3_ast.QuantumReset: self._visit_reset,
            qasm3_ast.QuantumBarrier: self._visit_barrier,
            qasm3_ast.QuantumGate: self._visit_generic_gate_operation,
            qasm3_ast.BranchingStatement: self._visit_branching_statement,
            qasm3_ast.QuantumPhase: lambda x: None,  # No operation
        }

        visitor_function = visit_map.get(type(statement))

        # consecutive barrier statements accumulate; anything else closes the barrier
        if not isinstance(statement, qasm3_ast.QuantumBarrier):
            self._check_and_apply_barrier()

        if visitor_function:
            visitor_function(statement)  # type: ignore[operator]
        else:
            raise_qasm3_error(
                f"Unsupported statement of type {type(statement)}", span=statement.span
            )

    def ir(self) -> str:
        """Return the string form of the LLVM module built so far."""
        return str(self._llvm_module)

    def bitcode(self) -> bytes:
        """Return the ``bitcode`` attribute of the LLVM module built so far."""
        return self._llvm_module.bitcode