Source code for qbraid.runtime.braket.provider

# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module for configuring provider credentials and authentication.

"""
import json
import os
from typing import TYPE_CHECKING, Optional

import boto3
from boto3.session import Session
from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.aws import AwsDevice, AwsSession
from braket.circuits import Circuit
from qbraid_core.services.quantum import quantum_lib_proxy_state
from qbraid_core.services.quantum.proxy_braket import aws_configure

from qbraid.exceptions import QbraidError
from qbraid.programs import ProgramSpec
from qbraid.runtime import DeviceActionType, DeviceType, QuantumProvider, TargetProfile

from .device import BraketDevice

if TYPE_CHECKING:
    import braket.aws

    import qbraid.runtime.braket


[docs] class BraketProvider(QuantumProvider): """ This class is responsible for managing the interactions and authentications with the AWS services. Attributes: aws_access_key_id (str): AWS access key ID for authenticating with AWS services. aws_secret_access_key (str): AWS secret access key for authenticating with AWS services. """ REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2")
[docs] def __init__( self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None ): """ Initializes the QbraidProvider object with optional AWS credentials. Args: aws_access_key_id (str, optional): AWS access key ID. Defaults to None. aws_secret_access_key (str, optional): AWS secret access token. Defaults to None. """ self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID") self.aws_secret_access_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
def save_config( self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, **kwargs, ): """Save the current configuration.""" aws_access_key_id = aws_access_key_id or self.aws_access_key_id aws_secret_access_key = aws_secret_access_key or self.aws_secret_access_key aws_configure( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, **kwargs, ) @staticmethod def _get_default_region() -> str: """Returns the default AWS region.""" default_region = "us-east-1" try: session = Session() return session.region_name or default_region except Exception: # pylint: disable=broad-exception-caught return default_region @staticmethod def _get_s3_default_bucket() -> str: """Return name of the default bucket to use in the AWS Session""" try: aws_session = AwsSession() return aws_session.default_bucket() except Exception: # pylint: disable=broad-exception-caught return "amazon-braket-qbraid-jobs" def _get_aws_session(self, region_name: Optional[str] = None) -> "braket.aws.AwsSession": """Returns the AwsSession provider.""" region_name = region_name or self._get_default_region() default_bucket = self._get_s3_default_bucket() boto_session = Session( region_name=region_name, aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, ) braket_client = boto_session.client("braket", region_name=region_name) return AwsSession( boto_session=boto_session, braket_client=braket_client, default_bucket=default_bucket ) def _build_runtime_profile( self, device: "braket.aws.AwsDevice", program_spec: Optional[ProgramSpec] = None, **kwargs ) -> TargetProfile: """Returns the runtime profile for the device.""" metadata = device.aws_session.get_device(device.arn) device_type = DeviceType(metadata.get("deviceType")) provider_name = metadata.get("providerName") capabilities = json.loads(metadata.get("deviceCapabilities")) num_qubits = capabilities.get("paradigm", {}).get("qubitCount") action = capabilities.get("action", {}) if action.get("braket.ir.openqasm.program") is not None: action_type = DeviceActionType.OPENQASM program_spec = program_spec or ProgramSpec(Circuit) elif action.get("braket.ir.ahs.program") is not None: action_type = DeviceActionType.AHS program_spec = program_spec or ProgramSpec( AnalogHamiltonianSimulation, alias="braket_ahs" ) else: raise QbraidError( f"TargetProfile cannot be created for device '{device.arn}' as it does not " f"support 'braket.ir.openqasm.program' program types. Please verify device " f"capabilities or select a different, compatible device." ) return TargetProfile( device_type=device_type, num_qubits=num_qubits, action_type=action_type, program_spec=program_spec, provider_name=provider_name, device_id=device.arn, **kwargs, ) def get_devices( self, aws_session: "Optional[braket.aws.AwsSession]" = None, statuses: Optional[list[str]] = None, **kwargs, ) -> list["qbraid.runtime.braket.BraketDevice"]: """Return a list of backends matching the specified filtering.""" aws_session = self._get_aws_session() if aws_session is None else aws_session statuses = ["ONLINE", "OFFLINE"] if statuses is None else statuses aws_devices = AwsDevice.get_devices(aws_session=aws_session, statuses=statuses, **kwargs) return [ BraketDevice(profile=self._build_runtime_profile(device), session=device.aws_session) for device in aws_devices ] def get_device(self, device_id: str) -> "qbraid.runtime.braket.BraketDevice": """Returns the AWS device.""" try: region_name = device_id.split(":")[3] except IndexError as err: raise ValueError( f"Device ARN is not a valid format: {device_id}. For valid Braket ARNs, " "see 'https://docs.aws.amazon.com/braket/latest/developerguide/braket-devices.html'" ) from err region_name = region_name or self._get_default_region() aws_session = self._get_aws_session(region_name=region_name) device = AwsDevice(arn=device_id, aws_session=aws_session) profile = self._build_runtime_profile(device) return BraketDevice(profile=profile, session=device.aws_session) @staticmethod def _fetch_resources(region_names: list[str], key: str, values: list[str]) -> list[str]: """Fetch resources from AWS.""" tasks = [] for region_name in region_names: client = boto3.client("resourcegroupstaggingapi", region_name=region_name) response = client.get_resources( TagFilters=[ { "Key": key, "Values": values, } ], ) matches = [t["ResourceARN"] for t in response["ResourceTagMappingList"]] tasks += matches return tasks def get_tasks_by_tag( self, key: str, values: Optional[list[str]] = None, region_names: Optional[list[str]] = None ) -> list[str]: """ Retrieves a list of quantum task ARNs that match the specified tag keys or key/value pairs. Args: key (str): The tag key to match. values (Optional[list[str]]): A list of tag values to match against the provided key. If None, tasks with the specified key, regardless of its value, are matched. region_names (Optional[list[str]]): A list of region names to search. If None, all regions in `self.REGIONS` are searched. Returns: list[str]: A list of ARNs for quantum tasks that match the given tag criteria. Raises: QbraidError: If the function is called within a qBraid quantum job environment where AWS S3 requests are not supported. """ try: jobs_state = quantum_lib_proxy_state("braket") jobs_enabled = jobs_state.get("enabled", False) except ValueError: jobs_enabled = False if jobs_enabled: raise QbraidError("AWS S3 requests not supported by qBraid quantum jobs.") region_names = ( region_names if region_names is not None and len(region_names) > 0 else list(self.REGIONS) ) values = values if values is not None else [] return self._fetch_resources(region_names, key, values)