# Copyright (c) 2024, qBraid Development Team
# All rights reserved.
"""
Module for making requests to the qBraid API.
"""
import configparser
import logging
import os
import warnings
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urljoin
import requests
from requests.adapters import HTTPAdapter
from urllib3.exceptions import InsecureRequestWarning
from ._compat import __version__
from .config import DEFAULT_CONFIG_SECTION, DEFAULT_ENDPOINT_URL, load_config
from .config import save_config as save_user_config
from .config import update_config_option
from .exceptions import AuthError, ConfigError, RequestsApiError, UserNotFoundError
from .registry import client_registry, discover_services
from .retry import STATUS_FORCELIST, PostForcelistRetry
if TYPE_CHECKING:
import qbraid_core
logger = logging.getLogger(__name__)
[docs]
class Session(requests.Session):
"""Custom session with handling of request urls and authentication.
This is a child class of :py:class:`requests.Session`. It handles
authentication with custom headers,and retries on specific 5xx errors.
"""
[docs]
def __init__( # pylint: disable=too-many-arguments
self,
*args,
base_url: Optional[str] = None,
headers: Optional[dict[str, Any]] = None,
auth_headers: Optional[dict[str, Any]] = None,
**kwargs,
):
"""
Initialize custom session with default base_url and auth_headers.
Args:
base_url (optional, str): Base URL to prepend to all requests.
headers (optional, dict): Dictionary of headers to include in all requests.
auth_headers (optional, dict): Dictionary of authorization headers to include in all
requests. Values will be masked in error messages.
"""
super().__init__(*args, **kwargs)
self.base_url = base_url
self.auth_headers = {}
if auth_headers:
self.auth_headers.update(auth_headers)
if headers:
self.headers.update(headers)
self.headers.update(self.auth_headers)
self.headers["User-Agent"] = self._user_agent()
self._raise_for_status = True
@property
def base_url(self) -> Optional[str]:
"""Return the base URL."""
return self._base_url
@base_url.setter
def base_url(self, value: Optional[str]) -> None:
"""Set the base URL."""
self._base_url = value
def _user_agent(self) -> str:
"""Return the user agent string."""
return f"QbraidCore/{__version__}"
def add_user_agent(self, user_agent: str) -> None:
"""Updates the User-Agent header with additional information.
Args:
user_agent (str): Additional user agent information to append.
"""
if user_agent not in self.headers["User-Agent"]:
self.headers["User-Agent"] = f"{self.headers['User-Agent']} {user_agent}"
def initialize_retry(
self,
total: int = 2,
connect: int = 1,
backoff_factor: float = 0.5,
status_forcelist: Union[list[int], set[int], tuple[int, ...]] = STATUS_FORCELIST,
**kwargs,
) -> None:
"""Set the session retry policy.
Args:
total (int): Number of total retries for the requests. Default 2.
connect (int): Number of connect retries for the requests. Default 1.
backoff_factor (float): Backoff factor between retry attempts. Default 0.5.
status_forcelist (Union[list[int], set[int], tuple[int, ...]]): List of status
codes to force a retry on.
"""
# Raising an exception on status code is handled by the request method.
raise_on_status = kwargs.pop("raise_on_status", True)
if not isinstance(raise_on_status, bool):
raise ValueError("raise_on_status must be a boolean.")
self._raise_for_status = raise_on_status
retry = PostForcelistRetry(
total=total,
connect=connect,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist,
raise_on_status=False,
**kwargs,
)
retry_adapter = HTTPAdapter(max_retries=retry)
self.mount("http://", retry_adapter)
self.mount("https://", retry_adapter)
@staticmethod
def _get_error_message_from_json(error_json: dict[str, Any]) -> Optional[str]:
"""Extracts the error message from the JSON response."""
if not error_json or not isinstance(error_json, dict):
return None
msg = error_json.get("message")
if not msg:
error_data = error_json.get("error")
if isinstance(error_data, dict):
msg = error_data.get("message")
elif isinstance(error_data, str):
msg = error_data
return msg
@staticmethod
def _mask_sensitive_data(message: str, auth_headers: dict[str, str]) -> str:
"""Replaces sensitive data in the message with a placeholder."""
for _, value in auth_headers.items():
message = message.replace(value, "...")
return message
def request(self, method: str, url: str, *args, **kwargs) -> requests.Response:
"""Construct, prepare, and send a ``Request``.
Override the request method to prepend base_url to the URL and include additional headers.
Args:
method (str): HTTP method (e.g., 'get', 'post').
url (str): URL for the request. Prepend base_url if url is a relative URL.
**kwargs: Additional arguments for the request
Returns:
Response object.
Raises:
RequestsApiError: If the request failed.
"""
if self.base_url and not url.startswith(("http://", "https://")):
base_url = self.base_url.rstrip("/") + "/"
url = url.lstrip("/")
url = urljoin(base_url, url)
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", InsecureRequestWarning)
response = super().request(method, url, *args, **kwargs)
if self._raise_for_status:
response.raise_for_status()
except requests.RequestException as err:
message = None
if err.response is not None:
try:
error_json: dict[str, Any] = err.response.json()
message = self._get_error_message_from_json(error_json)
except ValueError:
message = err.response.text
message = message or str(err)
message = message if message.endswith(".") else message + "."
message = self._mask_sensitive_data(message, self.auth_headers)
raise RequestsApiError(message) from err
return response
[docs]
class QbraidSession(Session): # pylint: disable=too-many-instance-attributes
"""Custom session with handling of request urls and authentication.
This is a child class of :py:class:`qbraid_core.sessions.Session`.
It handles qbraid authentication with custom headers and has SSL
verification disabled for compatibility with qBraid Lab.
"""
[docs]
def __init__(self, *args, api_key: Optional[str] = None, **kwargs) -> None:
"""Initialize custom session with default base_url and auth_headers.
Args:
api_key (optional, str): Authenticated qBraid API key.
"""
self._api_key = None
self._user_email = None
self._refresh_token = None
self.api_key = api_key
self.user_email = kwargs.pop("user_email", None)
self.refresh_token = kwargs.pop("refresh_token", None)
self.verify = False
if "headers" not in kwargs:
kwargs["headers"] = {}
if "domain" not in kwargs["headers"]:
kwargs["headers"]["domain"] = kwargs.pop("pool", "qbraid")
if "auth_headers" not in kwargs:
kwargs["auth_headers"] = {}
if self.api_key:
kwargs["auth_headers"]["api-key"] = self.api_key
if self.refresh_token:
kwargs["auth_headers"]["refresh-token"] = self.refresh_token
if self.user_email:
kwargs["auth_headers"]["email"] = self.user_email
super().__init__(*args, **kwargs)
self.initialize_retry()
@property
def base_url(self) -> Optional[str]:
"""Return the base URL."""
return super().base_url
@base_url.setter
def base_url(self, value: Optional[str]) -> None:
"""Set the qbraid api url."""
url = value or self.get_config("url")
value = url or DEFAULT_ENDPOINT_URL
value = value.rstrip("/") + "/"
self._base_url = value
@property
def api_key(self) -> Optional[str]:
"""Return the api key."""
return self._api_key
@api_key.setter
def api_key(self, value: Optional[str]) -> None:
"""Set the api key."""
api_key = value or self.get_config("api-key")
self._api_key = api_key or os.getenv("QBRAID_API_KEY")
@property
def user_email(self) -> Optional[str]:
"""Return the session user email."""
return self._user_email
@user_email.setter
def user_email(self, value: Optional[str]) -> None:
"""Set the session user email."""
user_email = value or self.get_config("email")
self._user_email = user_email or os.getenv("JUPYTERHUB_USER")
@property
def refresh_token(self) -> Optional[str]:
"""Return the session refresh token."""
return self._refresh_token
@refresh_token.setter
def refresh_token(self, value: Optional[str]) -> None:
"""Set the session refresh token."""
refresh_token = value or self.get_config("refresh-token")
self._refresh_token = refresh_token or os.getenv("REFRESH")
def get_config(self, config_name: str) -> Optional[str]:
"""Returns the config value of specified config.
Args:
config_name: The name of the config
"""
try:
config = load_config()
except ConfigError:
return None
section = DEFAULT_CONFIG_SECTION
if section in config.sections():
if config_name in config[section]:
return config[section][config_name]
return None
def get_user(self) -> dict[str, Any]:
"""Get user metadata.
Returns:
Dictionary containing user metadata.
Raises:
UserNotFoundError: If user metadata is invalid or not found.
"""
try:
metadata = self.get("/identity").json()
except RequestsApiError as err:
raise UserNotFoundError(str(err)) from err
if not metadata:
raise UserNotFoundError("User metadata invalid or not found.")
return metadata
def save_config(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
verify: bool = True,
overwrite: bool = False,
**kwargs,
) -> None:
"""Create qbraidrc file. In qBraid Lab, qbraidrc is automatically present in filesystem.
Raises:
UserNotFoundError: If user metadata is invalid or not found.
AuthError: If there is a credential mismatch.
ConfigError: If there is an error saving the config.
"""
self._api_key = api_key or self.api_key
self._refresh_token = kwargs.get("refresh_token", self.refresh_token)
self._user_email = kwargs.get("user_email", self.user_email)
if base_url:
value = base_url.rstrip("/") + "/"
self._base_url = value
config = configparser.ConfigParser()
if overwrite:
# Starting with a clean config if overwrite is True
section = DEFAULT_CONFIG_SECTION
config.add_section(section)
else:
# Load existing config if overwrite is False
try:
config = load_config()
except ConfigError:
config.add_section(DEFAULT_CONFIG_SECTION)
section = DEFAULT_CONFIG_SECTION
if section not in config.sections():
config.add_section(section)
# Set or update configurations
options = {
"email": self._user_email,
"api-key": self._api_key,
"refresh-token": kwargs.get("refresh_token", self._refresh_token),
"url": self.base_url,
}
for option, value in options.items():
config = update_config_option(config, section, option, value)
save_user_config(config)
if verify:
res_json = self.get_user()
res_email = res_json.get("email")
if self._user_email and self._user_email != res_email:
raise AuthError(
f"Credential mismatch: Session initialized for '{self._user_email}', "
f"but API key corresponds to '{res_email}'."
)
def get_available_services(self) -> list[str]:
"""
Get a list of available services that can be loaded as low-level
clients via :py:meth:`Session.client`.
Returns:
List: List of service names.
"""
services_path = os.path.join(os.path.dirname(__file__), "services")
return list(discover_services(services_path))
def client(
self, service_name: str, api_key: Optional[str] = None, **kwargs
) -> "qbraid_core.QbraidClient":
"""Return a client for the specified service.
Args:
service_name (str): Name of the service.
api_key (optional, str): API key for the client service.
Returns:
qbraid_core.QbraidClient: Client for the specified service.
"""
if len(client_registry) == 0:
self.get_available_services()
client_class = client_registry.get(service_name)
if not client_class:
raise ValueError(f"Service '{service_name}' not registered")
session = None if api_key else self
return client_class(session=session, api_key=api_key, **kwargs)