# Copyright (C) 2024 qBraid
#
# This file is part of qbraid-qir
#
# Qbraid-qir is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for qbraid-qir, as per Section 15 of the GPL v3.
"""
Module defining CirqVisitor.
"""
import logging
from abc import ABCMeta, abstractmethod
import cirq
import pyqir
import pyqir._native
import pyqir.rt
from pyqir import BasicBlock, Builder, Constant, IntType, PointerType
from .elements import CirqModule
from .opsets import map_cirq_op_to_pyqir_callable
logger = logging.getLogger(__name__)
class CircuitElementVisitor(metaclass=ABCMeta):
@abstractmethod
def visit_register(self, qids):
pass
@abstractmethod
def visit_operation(self, operation):
pass
[docs]
class BasicCirqVisitor(CircuitElementVisitor):
"""A visitor for basic cirq.Circuit elements.
This class is designed to traverse and interact with elements in a quantum circuit.
Args:
initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True.
record_output (bool): If True, output of the circuit will be recorded. Defaults to True.
"""
[docs]
def __init__(self, initialize_runtime: bool = True, record_output: bool = True):
self._module: pyqir.Module
self._builder: pyqir.Builder
self._entry_point: str
self._qubit_labels: dict[cirq.Qid, int] = {}
self._measured_qubits: dict = {}
self._initialize_runtime = initialize_runtime
self._record_output = record_output
def visit_cirq_module(self, module: CirqModule) -> None:
logger.debug("Visiting Cirq module '%s' (%d)", module.name, module.num_qubits)
self._module = module.module
context = self._module.context
entry = pyqir.entry_point(self._module, module.name, module.num_qubits, module.num_clbits)
self._entry_point = entry.name
self._builder = Builder(context)
self._builder.insert_at_end(BasicBlock(context, "entry", entry))
if self._initialize_runtime is True:
i8p = PointerType(IntType(context, 8))
nullptr = Constant.null(i8p)
pyqir.rt.initialize(self._builder, nullptr)
@property
def entry_point(self) -> str:
return self._entry_point
def finalize(self) -> None:
self._builder.ret(None)
def record_output(self, module: CirqModule) -> None:
if self._record_output is False:
return
i8p = PointerType(IntType(self._module.context, 8))
for i in range(module.num_qubits):
result_ref = pyqir.result(self._module.context, i)
pyqir.rt.result_record_output(self._builder, result_ref, Constant.null(i8p))
def visit_register(self, qids: list[cirq.Qid]) -> None:
logger.debug("Visiting qids '%s'", str(qids))
if not all(isinstance(x, cirq.Qid) for x in qids):
raise TypeError("All elements in the list must be of type cirq.Qid.")
self._qubit_labels.update({bit: n + len(self._qubit_labels) for n, bit in enumerate(qids)})
logger.debug("Added labels for qubits %s", str(qids))
def visit_operation(self, operation: cirq.Operation) -> None:
qlabels = [self._qubit_labels[bit] for bit in operation.qubits]
qubits = [pyqir.qubit(self._module.context, n) for n in qlabels]
results = [pyqir.result(self._module.context, n) for n in qlabels]
def handle_measurement(pyqir_func):
logger.debug("Visiting measurement operation '%s'", str(operation))
for qubit, result in zip(qubits, results):
self._measured_qubits[pyqir.qubit_id(qubit)] = True
pyqir_func(self._builder, qubit, result)
# dealing with conditional gates
if isinstance(operation, cirq.ClassicallyControlledOperation):
op_conds = operation._conditions # list of measurement keys
conditions = [
pyqir.result(self._module.context, int(op_conds[i].keys[0].name))
for i in range(len(op_conds))
]
regular_op = operation.without_classical_controls()
temp_pyqir_func, op_str = map_cirq_op_to_pyqir_callable(regular_op)
# pylint: disable=unnecessary-lambda-assignment
if op_str in ["Rx", "Ry", "Rz"]:
pyqir_func = lambda: temp_pyqir_func(
self._builder,
operation._sub_operation.gate._rads, # type: ignore[union-attr]
*qubits,
)
else:
pyqir_func = lambda: temp_pyqir_func(self._builder, *qubits)
def _branch(conds, pyqir_func):
if len(conds) == 0:
temp_id, _ = map_cirq_op_to_pyqir_callable(cirq.I)
passable_identity = lambda: temp_id(self._builder, *qubits)
return passable_identity
return pyqir._native.if_result(
self._builder,
conds[0],
zero=_branch(conds[1:], pyqir_func),
one=pyqir_func,
)
_branch(conditions, pyqir_func)
else:
pyqir_func, op_str = map_cirq_op_to_pyqir_callable(operation)
if op_str.startswith("measure"):
handle_measurement(pyqir_func)
elif op_str in ["Rx", "Ry", "Rz"]:
pyqir_func(self._builder, operation.gate._rads, *qubits) # type: ignore[union-attr]
else:
pyqir_func(self._builder, *qubits)
def ir(self) -> str:
return str(self._module)