# Copyright 2025 qBraid
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module defining BraketDeviceWrapper Class
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union
from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.aws import AwsDevice
from braket.circuits import Circuit
from braket.circuits.measure import Measure
from qbraid.programs import NATIVE_REGISTRY, QPROGRAM_REGISTRY, ExperimentType, load_program
from qbraid.runtime.device import QuantumDevice
from qbraid.runtime.enums import DeviceStatus
from qbraid.runtime.exceptions import DeviceProgramTypeMismatchError
from qbraid.transpiler import transpile
from .availability import next_available_time
from .job import BraketQuantumTask
if TYPE_CHECKING:
import braket.aws
import qbraid.runtime
import qbraid.runtime.aws
import qbraid.transpiler
[docs]
class BraketDevice(QuantumDevice):
    """Wrapper class for Amazon Braket ``Device`` objects."""

    def __init__(
        self,
        profile: qbraid.runtime.TargetProfile,
        session: Optional[braket.aws.AwsSession] = None,
    ):
        """Create a BraketDevice.

        Args:
            profile: Target profile for the device; forwarded to the base class.
            session: Optional AWS session passed to ``AwsDevice`` as ``aws_session``.
        """
        super().__init__(profile=profile)
        # ``self.id`` (not defined in this class) is used as the Braket device ARN.
        self._device = AwsDevice(arn=self.id, aws_session=session)
        self._provider_name = self.profile.get("provider_name")

    @property
    def name(self) -> str:
        """Return the name of this Device, as reported by the wrapped ``AwsDevice``."""
        return self._device.name

    def __str__(self):
        """String representation of the BraketDevice object."""
        return f"{self.__class__.__name__}('{self._provider_name} {self.name}')"

    def status(self) -> qbraid.runtime.DeviceStatus:
        """Return the status of this Device.

        Returns:
            ``DeviceStatus.ONLINE`` when the Braket status is ``"ONLINE"`` and the device
            reports ``is_available``; ``DeviceStatus.UNAVAILABLE`` when it is ``"ONLINE"``
            but not available; ``DeviceStatus.RETIRED`` when the Braket status is
            ``"RETIRED"``; ``DeviceStatus.OFFLINE`` for any other status.
        """
        # Read the status once so that both comparisons below see the same value.
        braket_status = self._device.status

        if braket_status == "ONLINE":
            if self._device.is_available:
                return DeviceStatus.ONLINE
            return DeviceStatus.UNAVAILABLE

        if braket_status == "RETIRED":
            return DeviceStatus.RETIRED

        return DeviceStatus.OFFLINE

    def availability_window(self) -> tuple[bool, str, str]:
        """Provides device availability status. Indicates current availability,
        time remaining (hours, minutes, seconds) until next availability or
        unavailability, and future UTC datetime of next change in availability status.

        Returns:
            tuple[bool, str, Optional[str]]: Current device availability, hr/min/sec until
                availability switch, future UTC datetime of availability switch

        """
        # NOTE(review): the return annotation says the last element is ``str`` while the
        # documented return type says ``Optional[str]`` -- confirm against
        # ``next_available_time`` and align the two.
        return next_available_time(self._device)

    def queue_depth(self) -> int:
        """Return the number of jobs in the queue for the device.

        Returns:
            The sum of the "Normal" and "Priority" quantum task queue entries. A value
            reported as a string with a leading ``">"`` has that character removed
            before it is converted to ``int``.
        """
        queue_depth_info = self._device.queue_depth()
        total_queued = 0

        for queue_type in ("Normal", "Priority"):
            num_queued = queue_depth_info.quantum_tasks[queue_type]

            # Strip the leading ">" marker so the remainder can be parsed as an integer.
            if isinstance(num_queued, str) and num_queued.startswith(">"):
                num_queued = num_queued[1:]

            total_queued += int(num_queued)

        return total_queued

    def transform(
        self, run_input: Union[Circuit, AnalogHamiltonianSimulation]
    ) -> Union[Circuit, AnalogHamiltonianSimulation]:
        """Transpile a circuit for the device.

        Args:
            run_input: The Braket circuit or analog Hamiltonian simulation to transform.

        Returns:
            The transformed program.

        Raises:
            DeviceProgramTypeMismatchError: If the profile's experiment type is
                ``GATE_MODEL`` and ``run_input`` is not a ``Circuit``, or the experiment
                type is ``AHS`` and ``run_input`` is not an ``AnalogHamiltonianSimulation``.
        """
        program = run_input
        provider = (self.profile.provider_name or "").upper()
        experiment_type = self.profile.experiment_type

        if experiment_type == ExperimentType.GATE_MODEL and not isinstance(program, Circuit):
            raise DeviceProgramTypeMismatchError(program, str(Circuit), experiment_type)

        if experiment_type == ExperimentType.AHS and not isinstance(
            program, AnalogHamiltonianSimulation
        ):
            raise DeviceProgramTypeMismatchError(
                program, str(AnalogHamiltonianSimulation), experiment_type
            )

        # Use Tket to transform circuits for IonQ to support a bigger gateset. The
        # transformation only applies when no measurement is present and the circuit
        # only uses contiguous qubits.
        has_contiguous_qubits = False
        includes_measurement = False

        if isinstance(run_input, Circuit):
            includes_measurement = any(
                isinstance(instruction.operator, Measure) for instruction in run_input.instructions
            )
            # Guard against circuits with no qubits: ``max()`` of an empty collection
            # raises ValueError. Such a circuit is not routed through Tket.
            qubit_count = run_input.qubit_count
            has_contiguous_qubits = qubit_count > 0 and max(run_input.qubits) + 1 == qubit_count

        if (
            provider == "IONQ"
            and isinstance(run_input, Circuit)
            and not includes_measurement
            and has_contiguous_qubits
        ):
            graph = self.scheme.conversion_graph
            # Only take the pytket round-trip when a direct pytket -> braket edge exists,
            # both registry entries match their native counterparts, and the target
            # spec alias is "braket".
            if (
                graph is not None
                and graph.has_edge("pytket", "braket")
                and QPROGRAM_REGISTRY["pytket"] == NATIVE_REGISTRY["pytket"]
                and QPROGRAM_REGISTRY["braket"] == NATIVE_REGISTRY["braket"]
                and self._target_spec.alias == "braket"
            ):
                tk_circuit = transpile(program, "pytket", max_path_depth=1, conversion_graph=graph)
                tk_program = load_program(tk_circuit)
                tk_program.transform(self)
                tk_transformed = tk_program.program
                braket_transformed = transpile(
                    tk_transformed, "braket", max_path_depth=1, conversion_graph=graph
                )
                program = braket_transformed

        qprogram = load_program(program)
        qprogram.transform(self)
        program = qprogram.program

        return program

    def submit(
        self,
        run_input: Union[
            Circuit, AnalogHamiltonianSimulation, list[Circuit], list[AnalogHamiltonianSimulation]
        ],
        *args,
        **kwargs,
    ) -> Union[BraketQuantumTask, list[BraketQuantumTask]]:
        """Run a quantum task specification on this quantum device. Task must represent a
        quantum circuit, annealing problems not supported.

        If any program carries a ``partial_measurement_qubits`` attribute, every program is
        submitted individually via ``AwsDevice.run`` and programs with that attribute get a
        ``partial_measurement_qubits`` tag. Otherwise all programs are submitted together
        via ``AwsDevice.run_batch``.

        Args:
            run_input: Specification of a task (or list of tasks) to run on device.
            *args: Positional arguments forwarded to the underlying Braket run call.
            **kwargs: Keyword arguments forwarded to the underlying Braket run call.

        Keyword Args:
            shots (int): The number of times to run the task on the device.
            tags (dict[str, str]): Optional task tags. On the per-program submission path
                these are merged with the generated ``partial_measurement_qubits`` tag
                (the generated tag wins on a key collision).

        Returns:
            The job like object for the run: a single ``BraketQuantumTask`` when
            ``run_input`` is not a list, otherwise a list of ``BraketQuantumTask``.

        """
        is_single_input = not isinstance(run_input, list)
        programs = [run_input] if is_single_input else run_input

        if any(hasattr(program, "partial_measurement_qubits") for program in programs):
            # ``tags`` is supplied explicitly below; pull any caller-provided value out of
            # the forwarded kwargs so that it is merged rather than passed twice.
            run_kwargs = dict(kwargs)
            caller_tags = run_kwargs.pop("tags", None) or {}

            # Extract partial measurement qubit information and add as tags for job tracking
            tasks = []
            for program in programs:
                tags: dict[str, str] = dict(caller_tags)
                if hasattr(program, "partial_measurement_qubits"):
                    partial_measurement_qubits: list[int] = program.partial_measurement_qubits
                    # Convert qubit indices to a string format for tagging (e.g., "0/2/3")
                    tags["partial_measurement_qubits"] = "/".join(
                        str(qubit) for qubit in partial_measurement_qubits
                    )

                task = self._device.run(program, *args, tags=tags, **run_kwargs)
                tasks.append(BraketQuantumTask(task.id, task=task, device=self._device))
        else:
            aws_quantum_task_batch = self._device.run_batch(programs, *args, **kwargs)
            tasks = [
                BraketQuantumTask(task.id, task=task, device=self._device)
                for task in aws_quantum_task_batch.tasks
            ]

        if is_single_input:
            return tasks[0]
        return tasks