# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.
"""
Module for providing transformations to ensure OpenQASM 3 compatibility
across various other quantum software frameworks.
"""
import math
import re
from functools import reduce
from typing import Union
from openqasm3 import dumps, parse
from openqasm3.ast import Include, Program, QuantumGate, QuantumMeasurementStatement
from openqasm3.parser import QASM3ParsingError
from qbraid._logging import logger
# OpenQASM 3 source for the ``iswap`` gate, written in terms of s, h and cx.
_ISWAP_GATE_DEFINITION = """
gate iswap _gate_q_0, _gate_q_1 {
s _gate_q_0;
s _gate_q_1;
h _gate_q_0;
cx _gate_q_0, _gate_q_1;
cx _gate_q_1, _gate_q_0;
h _gate_q_1;
}"""

# OpenQASM 3 source for the ``sxdg`` gate, written in terms of s and h.
_SXDG_GATE_DEFINITION = """
gate sxdg _gate_q_0 {
s _gate_q_0;
h _gate_q_0;
s _gate_q_0;
}"""

# Maps a gate name to the OpenQASM 3 ``gate`` definition text that can be
# spliced into a program which uses that gate.
GATE_DEFINITIONS = {
    "iswap": _ISWAP_GATE_DEFINITION,
    "sxdg": _SXDG_GATE_DEFINITION,
}
[docs]
def insert_gate_def(qasm3_str: str, gate_name: str, force_insert: bool = False) -> str:
    """Add a gate definition to an OpenQASM 3 string.

    The definition is placed directly after the last header line (a line
    starting with ``OPENQASM`` or ``include``), so that any gates used inside
    the definition body and supplied by an included file are already
    declared at the point of definition. If the program has no header
    lines, the definition is inserted at the very top.

    Args:
        qasm3_str (str): QASM 3.0 string.
        gate_name (str): Name of the gate to insert.
        force_insert (bool): If True, the gate definition will be added
            even if the gate is never referenced. Defaults to False.

    Returns:
        str: QASM 3.0 string with gate definition.

    Raises:
        ValueError: If the gate definition is not found.
    """
    defn = GATE_DEFINITIONS.get(gate_name)
    if defn is None:
        raise ValueError(
            f"Gate {gate_name} definition not found. "
            f"Available gate definitions include: {set(GATE_DEFINITIONS.keys())}"
        )

    # Nothing to do when the gate name does not appear anywhere in the program.
    if not force_insert and gate_name not in qasm3_str:
        return qasm3_str

    lines = qasm3_str.splitlines()

    # Keep scanning past the first header line: stopping at ``OPENQASM`` would
    # put the definition *before* a following ``include`` statement.
    insert_index = 0
    for i, line in enumerate(lines):
        if line.lstrip().startswith(("OPENQASM", "include")):
            insert_index = i + 1

    lines.insert(insert_index, defn.strip())

    return "\n".join(lines)
[docs]
def replace_gate_name(
    qasm: str, old_gate_name: str, new_gate_name: str, force_replace: bool = False
) -> str:
    """
    Replace occurrences of a specified gate name in a QASM program string with
    a new gate name, while optionally enforcing the replacement even if the new
    gate name isn't in the predefined gate map.

    For gates in the predefined map, only whole gate names are replaced: the
    old name must not be preceded by another identifier character, and must be
    followed by ``(`` (parameterized gates) or a single space (all others).
    This prevents, for example, renaming ``p`` from also rewriting the tail
    of ``cp``.

    Args:
        qasm (str): The QASM program as a string.
        old_gate_name (str): The original gate name to replace.
        new_gate_name (str): The new gate name to use in replacement.
        force_replace (bool): If True, force the replacement even if the
            new gate name isn't in the gate map.

    Returns:
        str: The modified QASM program with the gate names replaced.
    """
    # Define pairs of interchangeable gates
    gate_pairs = [
        ("cnot", "cx"),
        ("si", "sdg"),
        ("ti", "tdg"),
        ("v", "sx"),
        ("vi", "sxdg"),
        ("p", "phaseshift"),
        ("cp", "cphaseshift"),
    ]

    # Create a mapping from each gate to its alternate form (both directions)
    gate_map = {old: new for pair in gate_pairs for old, new in (pair, pair[::-1])}

    # Parameterized gates are immediately followed by "(", others by a space.
    parameterized_gates = {"p", "cp", "phaseshift", "cphaseshift"}
    suffix = "(" if old_gate_name in parameterized_gates else " "

    # Replace based on gate map and force_replace flag
    if old_gate_name in gate_map and (gate_map[old_gate_name] == new_gate_name or force_replace):
        # The lookbehind rejects matches inside a longer identifier; the
        # lookahead requires the suffix without consuming it.
        whole_name = rf"(?<!\w){re.escape(old_gate_name)}(?={re.escape(suffix)})"
        # A callable replacement keeps any backslashes in the new name literal.
        return re.sub(whole_name, lambda _match: new_gate_name, qasm)

    # Names outside the gate map are only touched when explicitly forced; this
    # path stays a plain textual substitution.
    if force_replace:
        return qasm.replace(old_gate_name, new_gate_name)

    return qasm
[docs]
def add_stdgates_include(qasm_str: str) -> str:
    """Add 'include "stdgates.inc";' to the QASM string if it is missing.

    The include is inserted on the line after the first line containing
    ``OPENQASM``. If the program has no such line, the include is placed at
    the top so that it is still added rather than silently skipped.

    Args:
        qasm_str (str): The QASM program as a string.

    Returns:
        str: The QASM program containing the stdgates include statement.
    """
    include_statement = 'include "stdgates.inc";'
    if include_statement in qasm_str:
        return qasm_str

    lines = qasm_str.splitlines()

    # Index of the first version header line, or None when there is none.
    header_index = next((i for i, line in enumerate(lines) if "OPENQASM" in line), None)
    insert_index = 0 if header_index is None else header_index + 1
    lines.insert(insert_index, include_statement)

    return "\n".join(lines)
[docs]
def remove_stdgates_include(qasm: str) -> str:
    """Strip every 'include "stdgates.inc";' statement out of the QASM string.

    Only the statement text itself is removed; surrounding whitespace and
    newlines are left as they were.
    """
    stdgates_statement = 'include "stdgates.inc";'
    without_include = qasm.replace(stdgates_statement, "")
    return without_include
def _evaluate_expression(match):
    """Helper function for simplifying arithmetic expressions within parentheses.

    Args:
        match (re.Match): Match whose first group is the text between the
            parentheses.

    Returns:
        str: The evaluated value wrapped in parentheses, or the original
        matched text when the expression cannot be evaluated.
    """
    expr = match.group(1)
    try:
        # NOTE(review): eval is only reached through the pattern used by
        # simplify_arithmetic_expressions, which limits the text to digits,
        # spaces and + - * / . characters. It still permits "**", so very large
        # powers can be expensive to compute.
        simplified_value = eval(expr)  # pylint: disable=eval-used
        return f"({simplified_value})"
    except (SyntaxError, ArithmeticError, ValueError):
        # Not a valid expression, division by zero, float overflow, or an
        # integer too large to render: leave the original text untouched.
        return match.group(0)


def simplify_arithmetic_expressions(qasm_str: str) -> str:
    """Simplifies arithmetic expressions within parentheses in a QASM string.

    Every parenthesized group made up solely of digits, spaces and the
    characters ``+ - * / .`` is evaluated and replaced by its value (still in
    parentheses). Groups that cannot be evaluated are left unchanged.

    Args:
        qasm_str (str): The QASM program as a string.

    Returns:
        str: The QASM program with simple arithmetic expressions evaluated.
    """
    pattern = r"\(([0-9+\-*/. ]+)\)"
    return re.sub(pattern, _evaluate_expression, qasm_str)
[docs]
def convert_qasm_pi_to_decimal(qasm: str) -> str:
    """Convert all instances of 'pi' in the QASM string to their decimal value.

    A match is ``pi`` optionally combined with a number on either side through
    one of ``* / + -`` (for example ``pi/2`` or ``2*pi``); the whole matched
    expression is evaluated and replaced by the result. Matches that cannot be
    evaluated (syntax error, division by zero, overflow) are left unchanged.

    Names of gates defined in the program that contain ``pi`` are collected
    first, and a match is skipped when the text around it contains one of
    those names.

    Args:
        qasm (str): The QASM program as a string.

    Returns:
        str: The QASM program with 'pi' expressions replaced by decimals.
    """
    pattern = r"(?<![a-zA-Z])(\d*\.?\d*\s*[*/+-]\s*)?pi(\s*[*/+-]\s*\d*\.?\d*)?(?![a-zA-Z])"

    # Collect custom gate names containing "pi"; a parse failure is logged
    # and the conversion proceeds without them.
    gate_defs = set()
    try:
        program = parse(qasm)

        for statement in program.statements:
            if isinstance(statement, QuantumGate):
                name = statement.name.name
                if "pi" in name:
                    gate_defs.add(name)
    except QASM3ParsingError as err:
        logger.debug("Failed to parse QASM program for pi conversion: %s", err)

    def replace_with_decimal(match: re.Match) -> str:
        expr: str = match.group()
        start = match.start()
        end = match.end()

        # Look at a window reaching len(gate_def) characters before the match
        # to see whether this "pi" sits inside a custom gate name.
        for gate_def in gate_defs:
            if gate_def in qasm[max(0, start - len(gate_def)) : end]:
                return expr  # pragma: no cover

        expr_with_pi_as_decimal = expr.replace("pi", str(math.pi))
        try:
            # NOTE(review): the pattern restricts the text to "pi", digits,
            # ".", whitespace and one arithmetic operator on each side.
            value = eval(expr_with_pi_as_decimal)  # pylint: disable=eval-used
        except (SyntaxError, ArithmeticError):
            # e.g. a dangling operator or "pi/0": keep the original text.
            return expr
        return str(value)

    return re.sub(pattern, replace_with_decimal, qasm)
def has_redundant_parentheses(qasm_str: str) -> bool:
    """Report whether any gate parameter in the QASM string is double-wrapped.

    Two shapes are recognized: a number enclosed in two pairs of parentheses,
    e.g. ``rx((0.5))``, and a negated parenthesized number, e.g. ``rx(-(0.5))``.

    Args:
        qasm_str (str): QASM text to inspect.

    Returns:
        bool: True if either shape occurs anywhere in the string.
    """
    redundant_forms = (
        r"\w+\(\(\s*[-+]?\d+(\.\d*)?\s*\)\)",
        r"\w+\(-\(\d*\.?\d+\)\)",
    )
    return any(re.search(form, qasm_str) is not None for form in redundant_forms)
def remove_spaces_in_parentheses(expression: str) -> str:
    """Removes all spaces inside parentheses in an expression.

    Each parenthesized group (from a ``(`` to the nearest following ``)`` on
    the same line) is rewritten in a single pass, so cleaning one group can
    never prevent a later group from being found and cleaned.

    Args:
        expression (str): Text that may contain parenthesized groups.

    Returns:
        str: The text with spaces removed from every parenthesized group.
    """
    return re.sub(r"\(.*?\)", lambda group: group.group(0).replace(" ", ""), expression)
def simplify_parentheses_in_qasm(qasm_str: str) -> str:
    """Strip needless parentheses from around numbers in a QASM string.

    The program is handled line by line. Only lines flagged by
    ``has_redundant_parentheses`` are rewritten; in those lines every
    parenthesized (optionally signed) number is replaced by the bare number
    with its spaces removed. Lines are re-joined with ``"\\n"``.

    Args:
        qasm_str (str): The QASM program as a string.

    Returns:
        str: The QASM program with redundant parentheses removed.
    """
    wrapped_number = r"\(\s*([-+]?\s*\d+(\.\d*)?)\s*\)"

    def _unwrap(found):
        # Keep only the number itself, dropping any space after the sign.
        return found.group(1).replace(" ", "")

    return "\n".join(
        re.sub(wrapped_number, _unwrap, row) if has_redundant_parentheses(row) else row
        for row in qasm_str.splitlines()
    )
def compose(*functions):
    """Chain single-argument callables so they run in the order given.

    ``compose(f, g)(x)`` evaluates ``g(f(x))``. Called with no functions, the
    returned callable hands its argument back unchanged.

    Args:
        *functions: Callables that each take one argument.

    Returns:
        Callable: A function applying every given callable from left to right.
    """

    def pipeline(value):
        for step in functions:
            value = step(value)
        return value

    return pipeline
[docs]
def normalize_qasm_gate_params(qasm: str) -> str:
    """Normalize gate parameters in the QASM string by running three passes in order.

    The passes are: converting 'pi' to decimal, simplifying arithmetic
    expressions in parentheses, and simplifying redundant parentheses.

    Args:
        qasm (str): The QASM program as a string.

    Returns:
        str: The QASM program after all three passes.
    """
    passes = (
        convert_qasm_pi_to_decimal,
        simplify_arithmetic_expressions,
        simplify_parentheses_in_qasm,
    )
    normalized = qasm
    for apply_pass in passes:
        normalized = apply_pass(normalized)
    return normalized
def declarations_to_qasm2(qasm: str) -> str:
    """Rewrite sized QASM 3.0 ``qubit``/``bit`` declarations as QASM 2.0 ``qreg``/``creg``.

    ``qubit[n] name;`` becomes ``qreg name[n];`` and ``bit[n] name;`` becomes
    ``creg name[n];``. Declarations without an explicit size are not matched
    and stay as they are.

    Args:
        qasm (str): The QASM program as a string.

    Returns:
        str: The QASM program with the declarations rewritten.
    """
    # Insertion order matters: "qubit" must be rewritten before "bit".
    legacy_keywords = {"qubit": "qreg", "bit": "creg"}

    converted = qasm
    for declaration_type, replacement_type in legacy_keywords.items():
        sized_declaration = re.compile(rf"{declaration_type}\[(\d+)\]\s+(\w+);")
        converted = sized_declaration.sub(rf"{replacement_type} \2[\1];", converted)
    return converted
def remove_qasm_barriers(qasm_str: str) -> str:
    """Returns a copy of the input QASM with all barriers removed.

    The text is split into (statement, comment) pieces, where a statement runs
    up to a ``;``, ``{`` or ``}`` and a comment is an optional ``//`` line
    directly after it. Pieces whose statement begins with ``barrier`` are
    dropped, and the remaining pieces are concatenated.

    Args:
        qasm_str: QASM to remove barriers from.

    Returns:
        str: The QASM text without barrier statements.
    """
    quoted_re = r"(?:\"[^\"]*?\")"
    # Statement regex: "(?:[^;{}\"]*?" + quoted_re + "?)*[;{}]"
    statement_re = r"((?:[^;{}\"]*?" + quoted_re + r"?)*[;{}])?"
    comment_re = r"(\n?//[^\n]*(?:\n|$))?"
    barrier_re = re.compile(r"^\s*barrier(?:(?:\s+)|(?:;))")

    pieces = re.findall(statement_re + comment_re, qasm_str)
    return "".join(
        statement + comment
        for statement, comment in pieces
        if barrier_re.match(statement) is None
    )
[docs]
def remove_measurements(program: Union[Program, str]) -> str:
    """Remove all measurement operations from the program.

    Args:
        program (Union[Program, str]): Parsed program, or QASM source that is
            parsed first.

    Returns:
        str: The serialized program without its top-level measurement
        statements. For version 2.0 programs, sized ``qubit``/``bit``
        declarations in the output are converted back to ``qreg``/``creg``.
    """
    program = parse(program) if isinstance(program, str) else program
    statements = [
        statement
        for statement in program.statements
        if not isinstance(statement, QuantumMeasurementStatement)
    ]
    program_out = Program(statements=statements, version=program.version)
    program_str = dumps(program_out)
    # ``version`` is optional on a Program and can be None (e.g. no version
    # header), in which case float() would raise a TypeError.
    if program.version is not None and float(program.version) == 2.0:
        program_str = declarations_to_qasm2(program_str)
    return program_str
[docs]
def remove_include_statements(program: Union[Program, str]) -> str:
    """Remove all include statements from the program.

    Args:
        program (Union[Program, str]): Parsed program, or QASM source that is
            parsed first.

    Returns:
        str: The serialized program without its top-level include statements.
        For version 2.0 programs, sized ``qubit``/``bit`` declarations in the
        output are converted back to ``qreg``/``creg``.
    """
    program = parse(program) if isinstance(program, str) else program
    statements = [
        statement for statement in program.statements if not isinstance(statement, Include)
    ]
    program_out = Program(statements=statements, version=program.version)
    program_str = dumps(program_out)
    # ``version`` is optional on a Program and can be None (e.g. no version
    # header), in which case float() would raise a TypeError.
    if program.version is not None and float(program.version) == 2.0:
        program_str = declarations_to_qasm2(program_str)
    return program_str