# Copyright 2025 qBraid
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the functionality to convert a PyQIR module into a squin kernel.
"""
from __future__ import annotations
import os
from plistlib import InvalidFileException
from typing import TYPE_CHECKING, Any, Callable, TypedDict
import pyqir
from bloqade import qubit
from bloqade.squin import kernel
from kirin import ir, lowering, types
from kirin.dialects import func, ilist, py
from kirin.rewrite import CFGCompactify, Walk
from qbraid_qir._pyqir_compat import pointer_id
from .exceptions import InvalidSquinInput
from .maps import PYQIR_TO_SQUIN_GATES_MAP, QIR_TO_SQUIN_UNSUPPORTED_STATEMENTS_MAP
if TYPE_CHECKING:
from typing import Unpack
class LoadKwargs(TypedDict, total=False):
    """Type definition for keyword arguments to the load function.

    ``total=False`` makes every key optional; ``load`` supplies its own default for
    each key it reads from ``**kwargs``.
    """

    # Name of the generated kernel/method; ``load`` defaults it to "main".
    kernel_name: str
    # Dialect group used for lowering; ``load`` defaults it to ``squin.kernel``.
    dialects: ir.DialectGroup
    # When True, the kernel takes the qubit register as a single ilist argument
    # instead of allocating its own qubits.
    register_as_argument: bool
    # Indices of qubits whose measurement results should be returned.
    # NOTE(review): support for this option in ``load`` is incomplete — confirm
    # current behavior before relying on it.
    return_measurements: list[int] | None
    # Name of the register argument; only used when ``register_as_argument`` is True.
    register_argument_name: str
    # Global variables forwarded to the lowering frame.
    globals: dict[str, Any] | None
    # File name used for error reporting.
    file: str | None
    # Line and column offsets used for error reporting.
    lineno_offset: int
    col_offset: int
    # Whether to run CFG compaction on the lowered region.
    compactify: bool
# pylint: disable=too-many-locals, too-many-statements


def _load_pyqir_module(module: str | pyqir.Module, name: str) -> pyqir.Module:
    """Normalize the ``module`` argument of ``load`` into a ``pyqir.Module``.

    Args:
        module: A ``pyqir.Module`` (returned unchanged), a path to an existing ``.ll``
            (LLVM IR text) or ``.bc`` (LLVM bitcode) file, or a string of QIR IR text.
        name: Module name handed to PyQIR when parsing IR text or bitcode.

    Returns:
        pyqir.Module: The parsed (or original) module.

    Raises:
        InvalidSquinInput: If ``module`` is neither ``str`` nor ``pyqir.Module``, or if a
            string that is not an existing path cannot be parsed as QIR IR text.
        InvalidFileException: If ``module`` is an existing path whose extension is not
            ``.ll`` or ``.bc``.
    """
    if not isinstance(module, (str, pyqir.Module)):
        raise InvalidSquinInput(f"Invalid input {type(module)}, expected 'str | pyqir.Module'")

    if isinstance(module, pyqir.Module):
        return module

    if not os.path.exists(module):
        # Not a file on disk: interpret the string itself as QIR IR text.
        try:
            return pyqir.Module.from_ir(pyqir.Context(), module, name=name)
        except Exception as exc:  # pylint: disable=broad-exception-caught
            raise InvalidSquinInput(
                f"Invalid input {type(module)}, String must be a valid QIR IR text."
            ) from exc

    _, ext = os.path.splitext(module)
    if ext.lower() == ".ll":
        # LLVM IR in textual form.
        with open(module, "r", encoding="utf-8") as f:
            ir_text = f.read()
        return pyqir.Module.from_ir(pyqir.Context(), ir_text, name=name)
    if ext.lower() == ".bc":
        # LLVM bitcode.
        with open(module, "rb") as f:
            bitcode_bytes = f.read()
        return pyqir.Module.from_bitcode(pyqir.Context(), bitcode_bytes, name=name)
    raise InvalidFileException(f"Expected file extension .ll or .bc but got {ext!r}")


def load(
    module: str | pyqir.Module,
    **kwargs: Unpack[LoadKwargs],
):
    """Converts a PyQIR module into a squin kernel.

    Args:
        module (str | pyqir.Module): PyQIR code or path to the .ll or .bc file,
            or a PyQIR Module object.

    Keyword Args:
        kernel_name (str): The name of the kernel to load. Defaults to "main".
        dialects (ir.DialectGroup): The dialects to use. Defaults to `squin.kernel`.
        register_as_argument (bool): Determine whether the resulting kernel function should accept
            a single `ilist.IList[Qubit, Any]` argument that is a list of qubits used within the
            function. This allows you to compose kernel functions generated from circuits.
            Defaults to `False`.
        return_measurements (list[int] | None): Which measured qubit results to return.
            Not implemented yet: only `None` (the default) or an empty list is accepted.
        register_argument_name (str): The name of the argument that represents the qubit register.
            Only used when `register_as_argument=True`. Defaults to "q".
        globals (dict[str, Any] | None): The global variables to use. Defaults to None.
        file (str | None): The file name for error reporting. Defaults to None.
        lineno_offset (int): The line number offset for error reporting. Defaults to 0.
        col_offset (int): The column number offset for error reporting. Defaults to 0.
        compactify (bool): Whether to compactify the output. Defaults to True.

    Returns:
        ir.Method: The kernel method, after `kernel.run_pass(..., typeinfer=True)`.
            The kernel currently always returns `None`.

    Raises:
        InvalidSquinInput: On unexpected keyword arguments, a non-empty
            `return_measurements`, an invalid `module` type, or a string that is neither
            an existing file nor valid QIR IR text (and from the visitor while lowering).
        InvalidFileException: If `module` is an existing file without a .ll/.bc extension.
    """
    kernel_name: str = kwargs.pop("kernel_name", "main")
    dialects: ir.DialectGroup = kwargs.pop("dialects", kernel)
    register_as_argument: bool = kwargs.pop("register_as_argument", False)
    return_measurements: list[int] | None = kwargs.pop("return_measurements", None)
    register_argument_name: str = kwargs.pop("register_argument_name", "q")
    globals: dict[str, Any] | None = kwargs.pop(  # pylint: disable=redefined-builtin
        "globals", None
    )
    file: str | None = kwargs.pop("file", None)
    lineno_offset: int = kwargs.pop("lineno_offset", 0)
    col_offset: int = kwargs.pop("col_offset", 0)
    compactify: bool = kwargs.pop("compactify", True)

    if kwargs:
        unexpected = ", ".join(f"'{k}'" for k in kwargs)
        raise InvalidSquinInput(f"load() got unexpected keyword argument(s): {unexpected}")

    # `return_measurements` is part of the documented keyword set, so the default
    # (None) and an empty list must be accepted; actually returning measurement
    # results is not implemented yet, so fail loudly instead of silently ignoring it.
    if return_measurements:
        raise InvalidSquinInput(
            "load() does not support 'return_measurements' yet; "
            "pass None (the default) or an empty list"
        )

    module = _load_pyqir_module(module, kernel_name)

    target = SquinVisitor(  # pylint: disable=unexpected-keyword-arg
        dialects=dialects, module=module
    )
    body = target.run(
        module,
        file=file,
        globals=globals,
        lineno_offset=lineno_offset,
        col_offset=col_offset,
        compactify=compactify,
        register_as_argument=register_as_argument,
        register_argument_name=register_argument_name,
    )

    # TODO: support `return_measurements`. Planned approach: validate each requested
    # qubit index against `target.num_qubits`, require that the visitor recorded a
    # measurement for it, and return the single result directly or several results
    # via `py.tuple.New` (which needs a tuple, not a list). Until then: return None.
    entry_block = body.blocks[0]
    return_value = func.ConstantNone()
    entry_block.stmts.append(return_value)
    return_node = func.Return(value_or_stmt=return_value)
    entry_block.stmts.append(return_node)  # pylint: disable=no-member

    self_arg_name = kernel_name + "_self"
    arg_names = [self_arg_name]
    if register_as_argument:
        args = (target.qreg.type,)
        arg_names.append(register_argument_name)
        # Include argument name in slots so it appears in the printed signature
        slots = (register_argument_name,)
    else:
        args = ()  # type: ignore
        slots = ()  # type: ignore

    signature = func.Signature(args, return_node.value.type)
    # The method's own `self` argument goes first, ahead of the optional qubit
    # register argument the visitor added to the entry block.
    entry_block.args.insert_from(
        0,
        types.Generic(ir.Method, types.Tuple.where(signature.inputs), signature.output),
        self_arg_name,
    )

    # pylint: disable-next=unexpected-keyword-arg
    code = func.Function(
        sym_name=kernel_name,
        signature=signature,
        slots=slots,
        body=body,
    )
    mt = ir.Method(
        sym_name=kernel_name,
        arg_names=arg_names,
        dialects=dialects,
        code=code,
    )

    # Bind outside of any `assert`: asserts are stripped under `python -O`, which
    # previously left `run_pass` unbound.
    # NOTE(review): this always uses `kernel.run_pass`, even when a custom `dialects`
    # group is supplied — confirm that is intended.
    run_pass = kernel.run_pass
    if run_pass is None:
        raise RuntimeError("squin kernel dialect group has no run_pass configured")
    run_pass(mt, typeinfer=True)
    return mt
class SquinVisitor(lowering.LoweringABC[pyqir.Module]):
    """Lower the entry point of a PyQIR module into a squin kernel body.

    Dispatch is driven by ``visit_map``, keyed on the exact Python type of a PyQIR node.
    ``call`` instructions become ``func.Invoke`` statements of the mapped squin gates.
    Constant operands become either qubit SSA values or ``py.Constant`` statements.
    Node types without an entry in ``visit_map`` are skipped by ``visit_node``.
    """

    def __init__(self, dialects: ir.DialectGroup, module: pyqir.Module):
        """Initialize the SquinVisitor.

        Args:
            dialects: The dialects to use for lowering.
            module: The PyQIR module to convert.
        """
        super().__init__(dialects=dialects)
        self.module = module
        # First function of the module accepted by `pyqir.is_entry_point`; set by `run`.
        self.entry_point: pyqir.Function | None = None
        # Qubit register block argument; only set by `run` when `register_as_argument`.
        self.qreg: ir.SSAValue = None  # type: ignore
        # Number of qubits required by the entry point; set by `run`.
        self.num_qubits: int | None = None
        # Static qubit index -> SSA value holding that qubit in the lowered region.
        self.qubit_ssa_map: dict[int, ir.SSAValue] = {}
        # Dispatch table used by `visit_node` (exact-type lookup, no subclass matching).
        self.visit_map: dict[
            type, Callable[[lowering.State[pyqir.Module], Any], lowering.Result]
        ] = {
            pyqir.Call: self.visit_call,
            pyqir.BasicBlock: self.visit_basic_block,
            pyqir.Constant: self.visit_constant,
            pyqir.FloatConstant: self.visit_constant,
            pyqir.IntConstant: self.visit_constant,
        }

    # Abstract/Required Methods
    def lower_literal(self, state: lowering.State[pyqir.Module], value) -> ir.SSAValue:
        """Literal lowering is not part of PyQIR input; always raises ``BuildError``."""
        raise lowering.BuildError("Literals not supported in pyqir module")

    def lower_global(
        self, state: lowering.State[pyqir.Module], node: pyqir.Module
    ) -> lowering.LoweringABC.Result:
        """Global lowering is not part of PyQIR input; always raises ``BuildError``."""
        raise lowering.BuildError("Globals not supported in pyqir module")

    # pylint: disable-next=too-many-arguments
    def run(
        self,
        module: pyqir.Module,
        *,
        globals: dict[str, Any] | None = None,  # pylint: disable=redefined-builtin
        file: str | None = None,
        lineno_offset: int = 0,
        col_offset: int = 0,
        compactify: bool = True,
        register_as_argument: bool = False,
        register_argument_name: str = "q",
    ) -> ir.Region:
        """Run the visitor on a PyQIR module.

        Args:
            module: Module whose entry point is lowered.
            globals: Globals forwarded to the lowering frame.
            file: File name for error reporting.
            lineno_offset: Line number offset for error reporting.
            col_offset: Column offset for error reporting.
            compactify: Whether to apply ``CFGCompactify`` to the lowered region.
            register_as_argument: If True, the qubits are taken from a single
                ``ilist`` block argument named ``register_argument_name``. Otherwise,
                each qubit is allocated with ``qubit.new``.
            register_argument_name: Name of the register block argument.

        Returns:
            ir.Region: The lowered region. No return statement is emitted here.

        Raises:
            InvalidSquinInput: If the module has no entry point, the entry point
                requires fewer than one qubit, or an instruction cannot be lowered.
        """
        # Reset per-run state so a repeated run never resolves a qubit index to an
        # SSA value belonging to a previously lowered region.
        self.qreg = None  # type: ignore
        self.qubit_ssa_map = {}

        state = lowering.State(
            self,
            file=file,
            lineno_offset=lineno_offset,
            col_offset=col_offset,
        )
        with state.frame([module], globals=globals, finalize_next=False) as frame:
            self.entry_point = next(filter(pyqir.is_entry_point, module.functions), None)
            if self.entry_point is None:
                raise InvalidSquinInput("No entry point found in pyqir module")

            self.num_qubits = pyqir.required_num_qubits(self.entry_point)
            if self.num_qubits is None or self.num_qubits < 1:
                raise InvalidSquinInput(
                    f"Invalid number of qubits {self.num_qubits}, must be greater than 0"
                )

            if register_as_argument:
                frame.curr_block.args.append_from(
                    ilist.IListType[qubit.QubitType, types.Literal(self.num_qubits)],
                    name=register_argument_name,
                )
                self.qreg = frame.curr_block.args[0]
                # Extract individual qubit SSA values from the register and store in map
                for qid in range(self.num_qubits):
                    index_ssa = frame.push(py.Constant(qid)).result
                    qbit_getitem = frame.push(
                        py.GetItem(self.qreg, index_ssa)  # pylint: disable=too-many-function-args
                    )
                    self.qubit_ssa_map[qid] = qbit_getitem.result
            else:
                # Create individual qubits using qubit.new() and store SSA values directly in map
                for qid in range(self.num_qubits):
                    squin_qubit = frame.push(
                        func.Invoke(  # pylint: disable=unexpected-keyword-arg, too-many-function-args
                            (), callee=qubit.new
                        )
                    )
                    self.qubit_ssa_map[qid] = squin_qubit.result

            self.visit(state)

            if compactify:
                Walk(CFGCompactify()).rewrite(frame.curr_region)

            region = frame.curr_region

        return region

    def visit(self, state: lowering.State[pyqir.Module]) -> lowering.Result:
        """Visit every basic block of the entry point, in order.

        Args:
            state (lowering.State[pyqir.Module]): The state of the visitor.

        Returns:
            lowering.Result: Always ``None``; statements are pushed onto the current frame.
        """
        # All blocks are lowered into the same frame.
        # NOTE(review): branch structure between blocks is not reproduced — confirm
        # inputs with multiple blocks are straight-line.
        assert isinstance(self.entry_point.basic_blocks, list)  # type: ignore
        for block in self.entry_point.basic_blocks:  # type: ignore
            self.visit_node(state, block)

    def visit_node(self, state: lowering.State[pyqir.Module], node: Any) -> lowering.Result | None:
        """Dispatch ``node`` to its visitor in ``visit_map``.

        Args:
            state (lowering.State[pyqir.Module]): The state of the visitor.
            node (Any): The node to visit.

        Returns:
            lowering.Result | None: The visitor's result, or ``None`` when no visitor is
            registered for ``type(node)``.
        """
        visitor_function = self.visit_map.get(type(node))
        if visitor_function:
            return visitor_function(state, node)
        return None

    def visit_basic_block(
        self, state: lowering.State[pyqir.Module], block: pyqir.BasicBlock
    ) -> lowering.Result:
        """Visit each instruction of a PyQIR basic block.

        Args:
            state (lowering.State[pyqir.Module]): The state of the visitor.
            block (pyqir.BasicBlock): The basic block to visit.

        Raises:
            InvalidSquinInput: If the block has no instructions.
        """
        if len(block.instructions) < 1:
            raise InvalidSquinInput("No instructions found in basic block")
        for instruction in block.instructions:
            self.visit_node(state, instruction)

    def visit_call(
        self, state: lowering.State[pyqir.Module], call: pyqir.Call
    ) -> lowering.Result | None:
        """Visit a PyQIR Call instruction.

        Args:
            state (lowering.State[pyqir.Module]): The state of the visitor.
            call (pyqir.Call): The call instruction to visit.

        Returns:
            lowering.Result | None: The pushed gate statement, or ``None`` when the
            callee is listed in ``QIR_TO_SQUIN_UNSUPPORTED_STATEMENTS_MAP`` (skipped).

        Raises:
            InvalidSquinInput: If the callee is in neither mapping.
        """
        gate_name = call.callee.name
        args = call.args
        if gate_name in QIR_TO_SQUIN_UNSUPPORTED_STATEMENTS_MAP:
            return None
        if gate_name not in PYQIR_TO_SQUIN_GATES_MAP:
            raise InvalidSquinInput(f"Unsupported gate: {gate_name}")
        assert isinstance(args, list)
        return self.visit_gate(state, gate_name, args)

    def visit_gate(
        self,
        state: lowering.State[pyqir.Module],
        gate_name: str,
        args: list[pyqir.Value],
    ) -> lowering.Result:
        """Lower a mapped QIR gate call to a ``func.Invoke`` of the squin gate.

        Args:
            state (lowering.State[pyqir.Module]): The state of the visitor.
            gate_name (str): The QIR callee name; must be a key of
                ``PYQIR_TO_SQUIN_GATES_MAP``.
            args (list[pyqir.Value]): The call operands, lowered in order.

        Returns:
            lowering.Result: The pushed ``func.Invoke`` statement.

        Raises:
            InvalidSquinInput: If an operand has no registered visitor, or (from
                ``visit_constant``) is an unsupported constant.
        """
        squin_gate = PYQIR_TO_SQUIN_GATES_MAP[gate_name]
        inputs: list[ir.SSAValue] = []
        for arg in args:
            lowered = self.visit_node(state, arg)
            if lowered is None:
                # `visit_node` yields None for operand types without a visitor; report
                # it here rather than forwarding None into `func.Invoke`.
                raise InvalidSquinInput(f"Unsupported argument {arg} for gate {gate_name}")
            inputs.append(lowered)
        return state.current_frame.push(
            func.Invoke(  # pylint: disable=unexpected-keyword-arg, too-many-function-args
                tuple(inputs), callee=squin_gate
            )
        )

    def visit_constant(
        self, state: lowering.State[pyqir.Module], value: pyqir.Value
    ) -> ir.SSAValue:
        """Lower a constant operand.

        A constant whose ``pointer_id`` is a key of ``qubit_ssa_map`` resolves to that
        qubit's SSA value. Integer and double constants are pushed as ``py.Constant``.

        Args:
            state (lowering.State[pyqir.Module]): The state of the visitor.
            value (pyqir.Value): The value to visit.

        Returns:
            ir.SSAValue: The SSA value of the constant.

        Raises:
            InvalidSquinInput: For any other constant.
        """
        qubit_id = pointer_id(value)
        if qubit_id is not None and qubit_id in self.qubit_ssa_map:
            return self.qubit_ssa_map[qubit_id]

        supported_classical_const = (
            isinstance(value, (pyqir.FloatConstant, pyqir.IntConstant))
            or value.type.is_double
            or isinstance(value.type, pyqir.IntType)
        )
        if supported_classical_const:
            return state.current_frame.push(py.Constant(value=value.value)).result  # type: ignore
        raise InvalidSquinInput(f"Unsupported constant value: {value}")