# Copyright (c) 2024, qBraid Development Team
# All rights reserved.
"""
Module for making requests to the qBraid API.
"""
import configparser
import logging
import os
import warnings
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import urljoin
import requests
from requests.adapters import HTTPAdapter
from urllib3.exceptions import InsecureRequestWarning
from ._compat import __version__
from .config import DEFAULT_CONFIG_SECTION, DEFAULT_ENDPOINT_URL, load_config
from .config import save_config as save_user_config
from .config import update_config_option
from .exceptions import AuthError, ConfigError, RequestsApiError, UserNotFoundError
from .registry import client_registry, discover_services
from .retry import STATUS_FORCELIST, PostForcelistRetry
if TYPE_CHECKING:
    # Only needed for static type annotations; never imported at runtime.
    import qbraid_core

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class Session(requests.Session):
    """Custom session with handling of request urls and authentication.

    This is a child class of :py:class:`requests.Session`. It handles
    authentication with custom headers, and retries on specific 5xx errors.
    Relative request urls are resolved against ``base_url``, and failed
    requests are re-raised as :py:class:`RequestsApiError` with the values of
    ``auth_headers`` masked out of the error message.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *args,
        base_url: Optional[str] = None,
        headers: Optional[dict[str, Any]] = None,
        auth_headers: Optional[dict[str, Any]] = None,
        retries_total: int = 2,
        retries_connect: int = 1,
        backoff_factor: float = 0.5,
        **kwargs,
    ):
        """
        Initialize custom session with default base_url and auth_headers.

        Args:
            base_url (optional, str): Base URL to prepend to all relative request urls.
            headers (optional, dict): Dictionary of headers to include in all requests.
            auth_headers (optional, dict): Dictionary of authorization headers to include in all
                requests. Values will be masked in error messages. They are applied after
                ``headers`` and therefore win over same-named entries there.
            retries_total (int): Number of total retries for the requests. Default 2.
            retries_connect (int): Number of connect retries for the requests. Default 1.
            backoff_factor (float): Backoff factor (seconds) between retry attempts. Default 0.5.
        """
        super().__init__(*args, **kwargs)
        self.base_url = base_url

        # Copy the caller's dict so later changes to the session do not leak into it.
        self.auth_headers = dict(auth_headers) if auth_headers else {}

        if headers:
            self.headers.update(headers)
        self.headers.update(self.auth_headers)
        # Always overrides any User-Agent supplied via ``headers``.
        self.headers["User-Agent"] = self._user_agent()

        self._initialize_retry(retries_total, retries_connect, backoff_factor)

    @property
    def base_url(self) -> Optional[str]:
        """Return the base URL."""
        return self._base_url

    @base_url.setter
    def base_url(self, value: Optional[str]) -> None:
        """Set the base URL."""
        self._base_url = value

    def _user_agent(self) -> str:
        """Return the user agent string."""
        return f"QbraidCore/{__version__}"

    def add_user_agent(self, user_agent: str) -> None:
        """Updates the User-Agent header with additional information.

        The text is appended (space-separated) only if it is not already a
        substring of the current User-Agent header.

        Args:
            user_agent (str): Additional user agent information to append.
        """
        if user_agent not in self.headers["User-Agent"]:
            self.headers["User-Agent"] = f"{self.headers['User-Agent']} {user_agent}"

    def _initialize_retry(
        self, retries_total: int, retries_connect: int, backoff_factor: float
    ) -> None:
        """Set the session retry policy.

        Mounts one retrying :py:class:`HTTPAdapter` for both ``http://`` and ``https://``.

        Args:
            retries_total (int): Number of total retries for the requests.
            retries_connect (int): Number of connect retries for the requests.
            backoff_factor (float): Backoff factor between retry attempts.
        """
        retry = PostForcelistRetry(
            total=retries_total,
            connect=retries_connect,
            backoff_factor=backoff_factor,
            status_forcelist=STATUS_FORCELIST,
        )
        retry_adapter = HTTPAdapter(max_retries=retry)
        self.mount("http://", retry_adapter)
        self.mount("https://", retry_adapter)

    def request(self, method: str, url: str, *args, **kwargs) -> requests.Response:
        """Construct, prepare, and send a ``Request``.

        Override the request method to prepend base_url to the URL and include additional headers.

        Args:
            method (str): HTTP method (e.g., 'get', 'post').
            url (str): URL for the request. Prepend base_url if url is a relative URL.
            *args: Additional positional arguments for :py:meth:`requests.Session.request`.
            **kwargs: Additional keyword arguments for :py:meth:`requests.Session.request`.

        Returns:
            Response object.

        Raises:
            RequestsApiError: If the request failed or returned an HTTP error status.
        """
        # Only relative urls are resolved against base_url; absolute http(s) urls pass through.
        if self.base_url and not url.startswith(("http://", "https://")):
            url = urljoin(self.base_url.rstrip("/") + "/", url.lstrip("/"))

        try:
            with warnings.catch_warnings():
                # InsecureRequestWarning is what urllib3 emits for unverified HTTPS requests.
                warnings.simplefilter("ignore", InsecureRequestWarning)
                response = super().request(method, url, *args, **kwargs)
            response.raise_for_status()
        except requests.RequestException as err:
            # Wrap requests exceptions for compatibility.
            raise RequestsApiError(self._format_error_message(err)) from err

        return response

    def _format_error_message(self, err: requests.RequestException) -> str:
        """Build the RequestsApiError message for ``err``, with auth header values masked.

        If the server responded, its JSON ``error.message`` / ``error.code`` payload is
        appended when present; otherwise the raw response body is appended.
        """
        message = str(err)
        response = err.response
        if response is not None:
            try:
                error_json = response.json()["error"]
                detail = f". {error_json['message']}, Error code: {error_json['code']}."
            except Exception:  # pylint: disable=broad-except
                # The response did not contain the expected JSON error payload.
                detail = f". {response.text}"
            message += detail
            # Read with .get() outside the try, so a missing trace header cannot push a
            # well-formed JSON error into the fallback branch (duplicating the detail).
            trace_id = response.headers.get("uber-trace-id")
            if trace_id:
                logger.debug("Response uber-trace-id: %s", trace_id)
        return self._mask_auth_values(message)

    def _mask_auth_values(self, message: str) -> str:
        """Replace every non-empty auth header value occurring in ``message`` with ``...``."""
        for value in self.auth_headers.values():
            if not value:
                # str.replace("", ...) would insert the mask between every character.
                continue
            secret = value.decode("utf-8", "replace") if isinstance(value, bytes) else str(value)
            message = message.replace(secret, "...")
        return message
class QbraidSession(Session):  # pylint: disable=too-many-instance-attributes
    """Custom session with handling of request urls and authentication.

    This is a child class of :py:class:`qbraid_core.sessions.Session`.
    It handles qbraid authentication with custom headers. When not passed
    explicitly, the API key, refresh token, and user email are read from the
    qBraid config file and then from the ``QBRAID_API_KEY``, ``REFRESH``, and
    ``JUPYTERHUB_USER`` environment variables respectively. The API url falls
    back to the config file and then to ``DEFAULT_ENDPOINT_URL``.

    TLS certificate verification keeps the :py:class:`requests.Session`
    default (enabled).
    """

    def __init__(self, *args, api_key: Optional[str] = None, **kwargs) -> None:
        """Initialize custom session with default base_url and auth_headers.

        Args:
            api_key (optional, str): Authenticated qBraid API key.

        Keyword Args:
            user_email (optional, str): qBraid user email.
            refresh_token (optional, str): qBraid refresh token.
            pool (optional, str): Value of the ``domain`` header when ``headers`` does not
                already define one. Default ``"qbraid"``.
            headers, auth_headers: Forwarded to :py:class:`Session`. The dicts are copied,
                so the caller's objects are never modified.
        """
        self._api_key = None
        self._user_email = None
        self._refresh_token = None
        # The property setters fall back to the config file / environment for missing values.
        self.api_key = api_key
        self.user_email = kwargs.pop("user_email", None)
        self.refresh_token = kwargs.pop("refresh_token", None)

        # Always consume ``pool``: left in kwargs it would be forwarded to
        # requests.Session.__init__, which takes no keyword arguments.
        pool = kwargs.pop("pool", "qbraid")

        caller_headers = kwargs.get("headers") or {}
        headers = dict(caller_headers)
        if "domain" not in caller_headers:
            headers["domain"] = pool
        kwargs["headers"] = headers

        auth_headers = dict(kwargs.get("auth_headers") or {})
        auth_headers.update(
            {name: value for name, value in self._credential_values().items() if value}
        )
        kwargs["auth_headers"] = auth_headers

        # ``verify`` is deliberately not assigned here: requests.Session.__init__ (run by
        # super()) sets it to True, so an assignment before this call would have no effect.
        super().__init__(*args, **kwargs)

    @Session.base_url.setter
    def base_url(self, value: Optional[str]) -> None:
        """Set the qbraid api url.

        Falls back to the ``url`` config option, then ``DEFAULT_ENDPOINT_URL``, and
        always stores the url with exactly one trailing slash.
        """
        url = value or self.get_config("url")
        value = url or DEFAULT_ENDPOINT_URL
        value = value.rstrip("/") + "/"
        super(QbraidSession, QbraidSession).base_url.fset(self, value)

    @property
    def api_key(self) -> Optional[str]:
        """Return the api key."""
        return self._api_key

    @api_key.setter
    def api_key(self, value: Optional[str]) -> None:
        """Set the api key, falling back to config ``api-key`` then ``QBRAID_API_KEY``."""
        api_key = value or self.get_config("api-key")
        self._api_key = api_key or os.getenv("QBRAID_API_KEY")

    @property
    def user_email(self) -> Optional[str]:
        """Return the session user email."""
        return self._user_email

    @user_email.setter
    def user_email(self, value: Optional[str]) -> None:
        """Set the session user email, falling back to config ``email`` then ``JUPYTERHUB_USER``."""
        user_email = value or self.get_config("email")
        self._user_email = user_email or os.getenv("JUPYTERHUB_USER")

    @property
    def refresh_token(self) -> Optional[str]:
        """Return the session refresh token."""
        return self._refresh_token

    @refresh_token.setter
    def refresh_token(self, value: Optional[str]) -> None:
        """Set the session refresh token, falling back to config ``refresh-token`` then ``REFRESH``."""
        refresh_token = value or self.get_config("refresh-token")
        self._refresh_token = refresh_token or os.getenv("REFRESH")

    def _credential_values(self) -> dict[str, Optional[str]]:
        """Map each credential header name to the session's current value (possibly None)."""
        return {
            "api-key": self._api_key,
            "refresh-token": self._refresh_token,
            "email": self._user_email,
        }

    def _sync_auth_headers(self) -> None:
        """Make ``auth_headers`` and the request headers match the current credentials."""
        for name, value in self._credential_values().items():
            if value:
                self.auth_headers[name] = value
                self.headers[name] = value
            else:
                # A credential cleared by the caller must not keep being sent.
                self.auth_headers.pop(name, None)
                self.headers.pop(name, None)

    def get_config(self, config_name: str) -> Optional[str]:
        """Returns the config value of specified config.

        Args:
            config_name: The name of the config option in ``DEFAULT_CONFIG_SECTION``.

        Returns:
            The option value, or None if the config cannot be loaded or lacks the option.
        """
        try:
            config = load_config()
        except ConfigError:
            return None

        section = DEFAULT_CONFIG_SECTION
        if section in config.sections() and config_name in config[section]:
            return config[section][config_name]
        return None

    def get_user(self) -> dict[str, Any]:
        """Get user metadata from the ``/identity`` endpoint.

        Returns:
            Dictionary containing user metadata.

        Raises:
            UserNotFoundError: If the request fails, the response is not valid JSON,
                or the user metadata is empty.
        """
        try:
            metadata = self.get("/identity").json()
        except (RequestsApiError, ValueError) as err:
            # ValueError covers the JSON decode errors raised by Response.json().
            raise UserNotFoundError from err
        if not metadata:
            raise UserNotFoundError("User metadata invalid or not found.")
        return metadata

    @staticmethod
    def _load_or_create_config(overwrite: bool) -> configparser.ConfigParser:
        """Return the config to write into, guaranteed to contain ``DEFAULT_CONFIG_SECTION``.

        A fresh config is used when ``overwrite`` is True or the existing file cannot be
        loaded; otherwise the existing config is returned so other options are preserved.
        """
        config = configparser.ConfigParser()
        if not overwrite:
            try:
                config = load_config()
            except ConfigError:
                pass  # unreadable/missing config: start from the empty parser above
        if DEFAULT_CONFIG_SECTION not in config.sections():
            config.add_section(DEFAULT_CONFIG_SECTION)
        return config

    def save_config(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        verify: bool = True,
        overwrite: bool = False,
        **kwargs,
    ) -> None:
        """Create qbraidrc file. In qBraid Lab, qbraidrc is automatically present in filesystem.

        The session's credentials and request headers are updated to the saved values, so
        the verification request authenticates with them. The config file is written
        before verification, so a mismatch is raised after saving.

        Args:
            api_key (optional, str): API key to save. Defaults to the session's API key.
            base_url (optional, str): API url to save. Defaults to the session's base url.
            verify (bool): If True, query ``/identity`` after saving to check the credentials.
            overwrite (bool): If True, start from an empty config instead of the existing file.
            **kwargs: Optional ``refresh_token`` and ``user_email`` overrides.

        Raises:
            UserNotFoundError: If user metadata is invalid or not found.
            AuthError: If there is a credential mismatch.
            ConfigError: If there is an error saving the config.
        """
        self._api_key = api_key or self.api_key
        self._refresh_token = kwargs.get("refresh_token", self.refresh_token)
        self._user_email = kwargs.get("user_email", self.user_email)

        if base_url:
            value = base_url.rstrip("/") + "/"
            super(QbraidSession, QbraidSession).base_url.fset(self, value)

        # Without this, get_user() below would still send the previous credentials.
        self._sync_auth_headers()

        section = DEFAULT_CONFIG_SECTION
        config = self._load_or_create_config(overwrite)

        # Set or update configurations
        options = {
            "email": self._user_email,
            "api-key": self._api_key,
            "refresh-token": self._refresh_token,
            "url": self.base_url,
        }
        for option, value in options.items():
            config = update_config_option(config, section, option, value)

        save_user_config(config)

        if verify:
            res_email = self.get_user().get("email")
            if self._user_email and self._user_email != res_email:
                raise AuthError(
                    f"Credential mismatch: Session initialized for '{self._user_email}', "
                    f"but API key corresponds to '{res_email}'."
                )

    def get_available_services(self) -> list[str]:
        """
        Get a list of available services that can be loaded as low-level
        clients via :py:meth:`Session.client`.

        Returns:
            List: List of service names discovered in this package's ``services`` directory.
        """
        services_path = os.path.join(os.path.dirname(__file__), "services")
        return list(discover_services(services_path))

    def client(
        self, service_name: str, api_key: Optional[str] = None, **kwargs
    ) -> "qbraid_core.QbraidClient":
        """Return a client for the specified service.

        Args:
            service_name (str): Name of the service.
            api_key (optional, str): API key for the client service. When given, the client
                is built from the key instead of reusing this session.

        Returns:
            qbraid_core.QbraidClient: Client for the specified service.

        Raises:
            ValueError: If no client is registered under ``service_name``.
        """
        if not client_registry:
            # Discovery populates the registry as a side effect.
            self.get_available_services()

        client_class = client_registry.get(service_name)
        if not client_class:
            raise ValueError(f"Service '{service_name}' not registered")

        session = None if api_key else self
        return client_class(session=session, api_key=api_key, **kwargs)