Source code for atomscale.core.client

from __future__ import annotations

import itertools
import os
import platform
import sys
from collections.abc import Callable  # type: ignore[ruleName]
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from importlib.metadata import PackageNotFoundError, version
from typing import Any, Literal
from urllib.parse import urljoin

from requests import Session
from requests.adapters import HTTPAdapter
from rich.progress import Progress
from urllib3.util.retry import Retry

try:
    __version__ = version("atomscale")
except PackageNotFoundError:
    __version__ = "0.0.0"


[docs] class BaseClient: """Base API client implementation"""
[docs] def __init__( self, api_key: str, endpoint: str, ): """ Args: api_key (str | None): API key. endpoint (str): Root API endpoint. """ self.api_key = api_key self.endpoint = endpoint self._session = None
@property def session(self) -> Session: """Session under which HTTP requests are issued""" if not self._session: self._session = self._create_session(self.api_key) return self._session # type: ignore[return-value] def _get( self, sub_url: str, params: dict[str, Any] | None = None, deserialize: bool = True, base_override: str | None = None, ) -> list[dict[Any, Any]] | dict[Any, Any] | bytes | None: """Method for issuing a GET request Args: sub_url (str): API sub-url to use. params (dict[str, Any] | None): Params to pass in the GET request. Defaults to None. deserialize (bool): Whether to JSON deserialize the response data or return raw bytes. Defaults to True. base_overrise (str): Base URL to use instead of the default ADS API root URL. Raises: ClientError: If the response code returned is not within the range of 200-400. Returns: (list[dict] | dict | bytes | None): Deserialized JSON data or raw bytes. Returns None if response is a 404. """ base_url = base_override or self.endpoint response = self.session.get( url=urljoin(base_url, sub_url), verify=True, params=params ) if not response.ok: if response.status_code == 404: return None raise ClientError( f"Problem retrieving data from {sub_url} with parameters {params}. HTTP Error {response.status_code}: {response.text}" ) if len(response.content) == 0: return None return response.json() if deserialize else response.content def _post_or_put( self, method: Literal["POST", "PUT"], sub_url: str, params: dict[str, Any] | None = None, body: dict[str, Any] | bytes | None = None, headers: dict[str, str] | None = None, deserialize: bool = True, base_override: str | None = None, return_headers: bool = False, ) -> list[dict[Any, Any]] | dict[Any, Any] | bytes | None: """Method for issuing a POST or PUT request Args: method (Literal["POST", "PUT"]): Method to use sub_url (str): API sub-url to use. params (dict[str, Any] | None): Params to pass in the GET request. Defaults to None. body (dict[str, Any] | bytes): Body data to send in the POST request. headers (dict[str, str] | None): Optional headers to include in the request. deserialize (bool): Whether to JSON deserialize the response data or return raw bytes. Defaults to True. base_overrise (str): Base URL to use instead of the default ADS API root URL. return_headers (bool): Whether to return the headers from the response instead of the content. Defaults to False. Raises: ClientError: If the response code returned is not within the range of 200-400. Returns: (list[dict] | dict | bytes | None): Deserialized JSON data or raw bytes. Returns None if response is a 404. """ base_url = base_override or self.endpoint method_func = self.session.put if method == "PUT" else self.session.post # decide whether to use data= (bytes/streams) or json= if body is None: data_params: dict[str, Any] = {} elif isinstance(body, bytes | bytearray): data_params = {"data": body} elif hasattr(body, "read"): # any file-like / RawIOBase data_params = {"data": body} else: # everything else (dict, list, etc.) goes through JSON data_params = {"json": body} response = method_func( url=urljoin(base_url, sub_url), verify=True, params=params, headers=headers or {}, **data_params, # type: ignore # noqa: PGH003 ) if not response.ok: if response.status_code == 404: return None raise ClientError( f"Problem sending data to {sub_url}. HTTP Error {response.status_code}: {response.text}" ) if return_headers: return_data: dict[Any, Any] = response.headers # type: ignore # noqa: PGH003 else: return_data = response.json() if deserialize else response.content # type: ignore #noqa: PGH003 return return_data def _multi_thread( self, func: Callable[..., Any], kwargs_list: list[dict[str, Any]], progress_bar: Progress | None = None, progress_description: str | None = None, progress_kwargs: dict | None = None, transient: bool = False, ) -> list[Any]: """Handles running a function concurrently with a ThreadPoolExecutor Arguments: func (Callable): Function to run concurrently kwargs_list (list): List of keyword argument inputs for the function progress_bar (Progress | None): Progress bar to show. Defaults to None. progress_description (str | None): Progress bar description. progress_kwargs (dict | None): Additional kwargs to pass to the progress task. transient (bool): Whether the progress bar is transient. Defaults to False, Returns: (list[Any]): List of results from passed function in the order of parameters passed """ return_dict = {} total_count = len(kwargs_list) kwargs_gen = iter(kwargs_list) if progress_bar is not None: progress_kwargs = progress_kwargs or {"pad": ""} task = progress_bar.add_task( progress_description or "", total=total_count, **progress_kwargs ) ind = 0 num_parallel = min(os.cpu_count() or 8, 8) with ThreadPoolExecutor(max_workers=num_parallel) as executor: # Get list of initial futures defined by max number of parallel requests futures = set() for kwargs in itertools.islice(kwargs_gen, num_parallel): future = executor.submit( func, **kwargs, ) future.ind = ind # type: ignore # noqa: PGH003 futures.add(future) ind += 1 while futures: # Wait for at least one future to complete and process finished finished, futures = wait(futures, return_when=FIRST_COMPLETED) for future in finished: data = future.result() if progress_bar is not None: progress_bar.update(task, advance=1, refresh=True) # type: ignore # noqa: PGH003 return_dict[future.ind] = data # type: ignore # noqa: PGH003 # Populate more futures to replace finished for kwargs in itertools.islice(kwargs_gen, len(finished)): new_future = executor.submit( func, **kwargs, ) new_future.ind = ind # type: ignore # noqa: PGH003 futures.add(new_future) ind += 1 if progress_bar is not None and transient: progress_bar.remove_task(task) # type: ignore # noqa: PGH003 return [t[1] for t in sorted(return_dict.items())] @staticmethod def _create_session(api_key: str): """Create a requests session Args: api_key (str): API key to include in the header. Returns: (Session): Requests Session object """ session = Session() session.headers = {"X-API-KEY": api_key} # User agent information atomscale_info = "atomscale/" + __version__ python_info = f"Python/{sys.version.split()[0]}" platform_info = f"{platform.system()}/{platform.release()}" session.headers["user-agent"] = ( f"{atomscale_info} ({python_info} {platform_info})" ) # TODO: Add retry setting to configuration somewhere max_retry_num = 3 retry = Retry( total=max_retry_num, read=max_retry_num, connect=max_retry_num, respect_retry_after_header=True, status_forcelist=[429, 504, 502], ) adapter = HTTPAdapter(max_retries=retry) session.mount("http://", adapter) session.mount("https://", adapter) return session
[docs] class ClientError(Exception): """Generic error thrown by the Atomic Data Sciences API client"""