"""
Copyright 2022 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
SPDX-License-Identifier: Apache-2.0
"""
import asyncio
import getpass
import logging
import multiprocessing
import platform
import re
import sys
import uuid
import warnings
from functools import wraps
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Iterable, Type
import httpx
import pydantic
import yaml
from .utils import ENDPOINTS
# Only let httpx log records of WARNING severity or above through.
logging.getLogger("httpx").setLevel(logging.WARNING)

# ``Self`` is imported from the standard ``typing`` module on Python 3.11+ and
# from ``typing_extensions`` on older interpreters.
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

# Load the YAML configuration that sits next to this module. The encoding is
# given explicitly so parsing does not depend on the platform's locale default.
with open(Path(__file__).parent / "config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# When True, ``ApiHandler.server_url`` emits a UserWarning for URLs that do not
# look like an ACCESS-NRI reporting API URL.
NRI_USER = True
class ProductionToggle:
    """
    Process-wide switch recording whether telemetry runs in production.

    The class is a singleton: every ``ProductionToggle()`` call hands back the
    same object, so the flag can be flipped in one place and read anywhere.

    Exposed functionality:

    - ``production`` (bool): True when running in production (the default).
      Assigning to it also points ``ApiHandler`` at the production or staging
      server URL.
    - ``debug`` (Callable): Decorator factory. In production the decorated
      function runs with warnings silenced and any ``Exception`` swallowed
      (``None`` is returned instead); in staging it is called directly so
      failures surface.
    """

    _production = True
    _instance = None
    PRODUCTION_URL = "https://reporting.access-nri-store.cloud.edu.au/api/"
    STAGING_URL = "https://reporting-dev.access-nri-store.cloud.edu.au/api/"

    def __new__(cls: Type[Self]) -> Self:
        # Reuse the shared instance when it already exists.
        if cls._instance:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every ProductionToggle() call; mark the shared
        # instance only the first time.
        if not hasattr(self, "initialized"):
            self.initialized = True

    @property
    def production(self) -> bool:
        """Whether the code is currently flagged as running in production."""
        return self._production

    @production.setter
    def production(self, prod: bool) -> None:
        """
        Set the production status and repoint ``ApiHandler`` accordingly.

        Raises
        ------
        TypeError
            If ``prod`` is not a ``bool``.
        """
        if not isinstance(prod, bool):
            raise TypeError("Production status must be a boolean")
        # The handler URL is switched first; the flag is stored afterwards.
        target_url = self.PRODUCTION_URL if prod else self.STAGING_URL
        ApiHandler().server_url = target_url
        self._production = prod

    def debug(self) -> Callable[..., Any]:
        """
        Build a decorator that shields callers from telemetry failures.

        Notes
        -----
        The production check happens inside the wrapper, at call time. Doing
        it when the decorator is applied would freeze the decision at import
        time and later changes to ``production`` would have no effect.
        """

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                # Staging: call straight through so errors and warnings show.
                if not self.production:
                    return func(*args, **kwargs)
                # Production: silence warnings and swallow any Exception.
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        return func(*args, **kwargs)
                except Exception:
                    return None

            return wrapper

        return decorator

    def __repr__(self) -> str:
        return f"ProductionToggle(production={self._production})"

    def __str__(self) -> str:
        return repr(self)
# Module-level handle on the ProductionToggle singleton; ApiHandler's request
# methods are decorated with ``TOGGLE.debug()``.
TOGGLE = ProductionToggle()
class ApiHandler:
    """
    Singleton that builds telemetry records and hands them off for sending.

    A class is used (rather than plain functions) so that per-service state -
    extra fields, headers and fields to drop - persists across every telemetry
    call made through the handler.

    To configure request timeouts and the multiprocessing context manually,
    set the ``_request_timeout`` and ``_mproc_override`` class attributes as
    desired (``request_timeout`` is also available as a validated property).
    """

    _instance = None
    # Per-service state is stored on the class, keyed by the names in ENDPOINTS.
    endpoints = dict(ENDPOINTS.items())
    headers: dict[str, dict[str, str]] = {service: {} for service in ENDPOINTS}
    _extra_fields: dict[str, dict[str, Any]] = {ep_name: {} for ep_name in ENDPOINTS}
    _pop_fields: dict[str, list[str]] = {}
    _request_timeout = None
    _mproc_override = None

    def __new__(cls: Type[Self], *args: Any, **kwargs: Any) -> Self:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        server_url: str = "https://reporting.access-nri-store.cloud.edu.au",
    ) -> None:
        # __init__ runs on every ApiHandler() call; only the first call may
        # set the server URL, later calls leave the existing state untouched.
        if hasattr(self, "_initialized"):
            return None
        self._initialized = True
        self._server_url = server_url

    @property
    def extra_fields(self) -> dict[str, Any]:
        """Mapping of service name to the extra fields merged into its records."""
        return self._extra_fields

    @pydantic.validate_call
    def add_extra_fields(self, service_name: str, fields: dict[str, Any]) -> None:
        """
        Set the extra fields merged into the telemetry data of a service.

        Only works for services that already have an endpoint defined. The
        given mapping replaces any extra fields previously stored for the
        service; it is not merged with them.

        Raises
        ------
        KeyError
            If ``service_name`` has no endpoint defined.
        """
        if service_name not in self.endpoints:
            raise KeyError(f"Endpoint for '{service_name}' not found")
        self._extra_fields[service_name] = fields
        return None

    @pydantic.validate_call
    def set_headers(
        self, service_names: str | Iterable[str] | None, headers: dict[str, str]
    ) -> None:
        """
        Set the request headers for a given service or services.

        If service_names is None, the headers are set for all services. The
        given mapping replaces the headers previously stored for each service.

        Raises
        ------
        KeyError
            If a named service has no endpoint defined. Services listed before
            the offending name have already been updated at that point.
        """
        if isinstance(service_names, str):
            service_names = [service_names]
        for service_name in service_names or self.endpoints:
            if service_name not in self.endpoints:
                raise KeyError(f"Endpoint for '{service_name}' not found")
            self.headers[service_name] = headers
        return None

    @pydantic.validate_call
    def clear_headers(self, service_names: str | Iterable[str] | None = None) -> None:
        """
        Clear the headers for a given service or services, if specified.

        If service_names is None, the headers will be cleared for all services.

        Raises
        ------
        KeyError
            If a named service has no endpoint defined.
        """
        if isinstance(service_names, str):
            service_names = [service_names]
        for service_name in service_names or self.endpoints:
            if service_name not in self.endpoints:
                raise KeyError(f"Endpoint for '{service_name}' not found")
            self.headers[service_name] = {}
        return None

    @property
    def server_url(self) -> str:
        """Base URL of the telemetry API."""
        return self._server_url

    @server_url.setter
    def server_url(self, url: str) -> None:
        """
        Set the server URL for the telemetry API.

        While ``NRI_USER`` is true, a ``UserWarning`` is emitted when the URL
        is not one of the ACCESS-NRI reporting hosts, and another when it does
        not end in ``api`` or ``api/``. The URL is stored in either case.
        """
        if NRI_USER and (
            "https://reporting-dev.access-nri-store.cloud.edu.au/" not in url
            and "https://reporting.access-nri-store.cloud.edu.au/" not in url
        ):
            warnings.warn(
                "Server URL not an ACCESS-NRI Reporting API URL",
                stacklevel=2,
                category=UserWarning,
            )
        if NRI_USER and not url.lower().endswith(("api", "api/")):
            warnings.warn(
                "Server URL does not end with 'api' or 'api/' - this is likely an error",
                stacklevel=2,
                category=UserWarning,
            )
        self._server_url = url
        return None

    @property
    def pop_fields(self) -> dict[str, list[str]]:
        """Mapping of service name to the record fields dropped before sending."""
        return self._pop_fields

    @property
    def request_timeout(self) -> float | None:
        """Timeout handed to ``send_in_loop``; None means no explicit timeout."""
        return self._request_timeout

    @request_timeout.setter
    def request_timeout(self, timeout: float | None) -> None:
        """
        Set the request timeout for the telemetry API.

        Raises
        ------
        TypeError
            If ``timeout`` is neither None nor a number.
        ValueError
            If ``timeout`` is zero or negative.
        """
        if timeout is None:
            self._request_timeout = None
            return None
        if not isinstance(timeout, (int, float)):
            raise TypeError("Timeout must be a number")
        if timeout <= 0:
            raise ValueError("Timeout must be a positive number")
        self._request_timeout = timeout
        return None

    @pydantic.validate_call
    def remove_fields(self, service: str, fields: str | Iterable[str]) -> None:
        """
        Set the fields to remove from the telemetry data for a given service.

        Useful for excluding default fields that are not needed for a
        particular telemetry call: eg, removing Session tracking if a CLI is
        being used.

        Note: This does not use a set union, so you must specify all fields
        you want to remove in one call.

        # TODO: Maybe make this easier to use?
        """
        if isinstance(fields, str):
            fields = [fields]
        self._pop_fields[service] = list(fields)

    @TOGGLE.debug()
    def send_api_request(
        self,
        service_name: str,
        function_name: str,
        args: list[Any] | tuple[Any, ...],
        kwargs: dict[str, str | Any],
    ) -> None:
        """
        Send an API request with telemetry data.

        Parameters
        ----------
        service_name : str
            The name of the service to send the telemetry data to.
        function_name : str
            The name of the function being tracked.
        args : list
            The list of positional arguments passed to the function.
        kwargs : dict
            The dictionary of keyword arguments passed to the function.

        Returns
        -------
        None

        Warnings
        --------
        RuntimeWarning
            If the request fails. Because of the ``TOGGLE.debug()`` decorator,
            warnings and exceptions are suppressed while in production.
        """
        telemetry_data = self._create_telemetry_record(
            service_name, function_name, args, kwargs
        )
        service_endpoint = self._get_endpoints(service_name)
        telemetry_headers = self.headers.get(service_name, {})
        endpoint = _format_endpoint(self.server_url, service_endpoint)
        send_in_loop(
            endpoint,
            telemetry_data,
            telemetry_headers,
            self._request_timeout,
            self._mproc_override,
        )
        return None

    @TOGGLE.debug()
    def send_failure_api_request(
        self,
        service_name: str,
        code: str,
        endpoint: str,
    ) -> None:
        """
        Send an API request carrying raw code rather than a parsed call.

        Intended for instances where telemetry has failed parsing and we just
        want to dump the raw code so it can be interrogated later.

        Parameters
        ----------
        service_name : str
            The name of the service to send the telemetry data to.
        code : str
            The code that failed to parse, or otherwise caused an error.
        endpoint : str
            The endpoint to send the telemetry data to. This is separate from
            the service name because we don't parse a registry of functions
            for this use case, so we have to do things slightly differently.
            TODO: I think I built in a side mechanism to do this - find it.

        Returns
        -------
        None

        Warnings
        --------
        RuntimeWarning
            If the request fails. Because of the ``TOGGLE.debug()`` decorator,
            warnings and exceptions are suppressed while in production.
        """
        telemetry_data = self._create_failure_record(service_name, code)
        # Get headers for the service, defaulting to the headers for the
        # intake_catalog. Needed because we don't have a service registered
        # for this use yet.
        telemetry_headers = self.headers.get(
            service_name, self.headers.get("intake_catalog", {})
        )
        endpoint = _format_endpoint(self.server_url, endpoint)
        send_in_loop(
            endpoint,
            telemetry_data,
            telemetry_headers,
            self._request_timeout,
            self._mproc_override,
        )
        return None

    def _get_endpoints(self, service_name: str) -> str:
        """
        Get the endpoint for a given service name.

        Raises
        ------
        KeyError
            If no endpoint is registered under ``service_name``.
        """
        try:
            endpoint = self.endpoints[service_name]
        except KeyError as e:
            raise KeyError(
                f"Endpoint for '{service_name}' not found in {self.endpoints}"
            ) from e
        return endpoint

    @staticmethod
    def _session_id() -> str:
        """
        Return the session ID as a plain string.

        ``SessionID.__get__`` is only triggered automatically when a SessionID
        instance is looked up as a class attribute. A bare ``SessionID()``
        expression yields the object itself, which cannot be JSON-encoded, so
        the accessor is invoked explicitly here.
        """
        return SessionID().__get__(None, SessionID)

    def _finalise_record(
        self, service_name: str, record: dict[str, Any]
    ) -> dict[str, Any]:
        """Drop the fields configured for removal, then cache and return the record."""
        for field in self.pop_fields.get(service_name, []):
            # Tolerate fields this record never had (failure records, for
            # example, carry no "name" or "function"); a KeyError here would
            # abort the whole telemetry call.
            record.pop(field, None)
        self._last_record = record
        return record

    def _create_telemetry_record(
        self,
        service_name: str,
        function_name: str,
        args: list[Any] | tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Create and return a telemetry record, cache it as an instance attribute.

        The record holds the current user name, the tracked function with its
        arguments, the session ID string and the service's extra fields (which
        override the defaults on key clashes), minus any fields registered via
        ``remove_fields``. It is cached as ``self._last_record``.

        Notes
        -----
        SessionID is a lazily evaluated singleton, so the same session ID is
        reused for every record rather than generated per call.
        """
        telemetry_data = {
            "name": getpass.getuser(),
            "function": function_name,
            "args": args,
            "kwargs": kwargs,
            "session_id": self._session_id(),
            **self.extra_fields.get(service_name, {}),
        }
        return self._finalise_record(service_name, telemetry_data)

    def _create_failure_record(
        self,
        service_name: str,
        code: str,
    ) -> dict[str, Any]:
        """
        Create and return a failure record, cache it as an instance attribute.

        The record holds the raw code, the session ID string and the service's
        extra fields, minus any fields registered via ``remove_fields``. It is
        cached as ``self._last_record``.

        Notes
        -----
        SessionID is a lazily evaluated singleton, so the same session ID is
        reused for every record rather than generated per call.
        """
        telemetry_data = {
            "code": code,
            "session_id": self._session_id(),
            **self.extra_fields.get(service_name, {}),
        }
        return self._finalise_record(service_name, telemetry_data)
[docs]
class SessionID:
"""
Singleton class to store and generate a unique session ID.
This class ensures that only one instance of the session ID exists. The session
ID is generated the first time it is accessed and is represented as a string,
using the UUID4 algorithm.
"""
_instance = None
def __new__(cls: type[Self]) -> Self:
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if hasattr(self, "initialized"):
return None
self.initialized = True
def __get__(self, obj: object, objtype: type | None = None) -> str:
if not hasattr(self, "value"):
self.value = SessionID.create_session_id()
return self.value
[docs]
@staticmethod
def create_session_id() -> str:
"""
Generate a unique session ID.
"""
return str(uuid.uuid4())
async def send_telemetry(
    endpoint: str,
    data: dict[str, Any],
    headers: dict[str, str],
    warn: bool | None = None,
) -> None:
    """
    Asynchronously POST telemetry data as JSON to the specified endpoint.

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    data : dict
        The telemetry data to send.
    headers : dict
        The headers to send the telemetry data with. They are layered on top
        of a default ``Content-Type: application/json`` header.
    warn : bool, optional
        If True, a warning will be raised if the request fails, and a status
        line announcing the request is printed. If False, neither happens. If
        None, warn defaults to ``not ProductionToggle().production``.

    Returns
    -------
    None

    Warnings
    --------
    RuntimeWarning
        If the request fails and ``warn`` is in effect.
    """
    verbose = (not ProductionToggle().production) if warn is None else warn
    request_headers = {"Content-Type": "application/json", **headers}

    async with httpx.AsyncClient() as client:
        if verbose:
            print(f"Posting telemetry to {endpoint}")
        try:
            response = await client.post(
                endpoint, json=data, headers=request_headers
            )
            response.raise_for_status()
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            # Transport failures and error status codes are reported only
            # when verbose; otherwise they are dropped silently.
            if verbose:
                warnings.warn(
                    f"Request failed: {e}", category=RuntimeWarning, stacklevel=2
                )
    return None
# Strong references to in-flight telemetry tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it has finished.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()


def send_in_loop(
    endpoint: str,
    telemetry_data: dict[str, Any],
    telemetry_headers: dict[str, str] | None = None,
    timeout: float | None = None,
    mproc_override: str | None = None,
) -> None:
    """
    Wraps the send_telemetry function in an event loop.

    - If an event loop is already running, sends the telemetry data as a
      background task on that loop. ``timeout`` and ``mproc_override`` are
      not used on this path.
    - If no event loop is running, hands the data to ``_run_in_proc``, which
      sends it from a separate process.

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    telemetry_data : dict
        The telemetry data to send.
    telemetry_headers : dict, optional
        The headers to send the telemetry data with. None is treated as no
        extra headers.
    timeout : float, optional
        The maximum time to wait for the separate process to finish. If None
        (or 0), it defaults to 60 seconds.
    mproc_override : str, optional
        The multiprocessing context to use. If None, the context will be set
        to "fork" on Linux systems and "spawn" on Windows/ MacOS systems. If a
        context is specified, it will be used regardless of the system.

    Returns
    -------
    None
    """
    timeout = timeout or 60
    telemetry_headers = telemetry_headers or {}
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: send from a separate process.
        _run_in_proc(
            endpoint, telemetry_data, telemetry_headers, timeout, mproc_override
        )
    else:
        task = loop.create_task(
            send_telemetry(endpoint, telemetry_data, telemetry_headers)
        )
        # Hold on to the task until it completes, then let it go.
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_TASKS.discard)
    return None
def _run_event_loop(
    endpoint: str,
    telemetry_data: dict[str, Any],
    telemetry_headers: dict[str, str] | None = None,
    warn: bool | None = None,
) -> None:
    """
    Handles the creation and running of an event loop for sending telemetry data.

    This function is intended to be run in a separate process, and will:

    - Create a new event loop and install it as the current loop
    - Run the loop until the telemetry data is sent
    - Detach and close the loop again, whether or not sending succeeded

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    telemetry_data : dict
        The telemetry data to send.
    telemetry_headers : dict, optional
        The headers to send the telemetry data with.
    warn : bool, optional
        If True, a warning will be raised if the request fails. If False, no
        warning will be raised. If None, warn will default the value of
        ``not ProductionToggle().production``. It will also enable some status
        info about the request being sent.

    Returns
    -------
    None

    Notes
    -----
    ``warn`` can be supplied by the caller so that the decision does not
    depend on the ``ProductionToggle`` state visible inside the process that
    runs this function.
    """
    if warn is None:
        warn = not ProductionToggle().production
    telemetry_headers = telemetry_headers or {}

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            send_telemetry(endpoint, telemetry_data, telemetry_headers, warn)
        )
    finally:
        # Release the loop's resources and do not leave a closed loop
        # installed as the current one.
        asyncio.set_event_loop(None)
        loop.close()
def _run_in_proc(
    endpoint: str,
    telemetry_data: dict[str, Any],
    telemetry_headers: dict[str, str] | None,
    timeout: float = 60,
    mproc_override: str | None = None,
) -> None:
    """
    Handles the creation and running of a separate process for sending telemetry data.

    This function will:

    - Create a new process and run the _run_event_loop function in that process
    - Wait for the process to finish
    - If the process takes longer than the specified timeout, terminate the
      process and raise a warning

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    telemetry_data : dict
        The telemetry data to send.
    telemetry_headers : dict, optional
        The headers to send the telemetry data with.
    timeout : float
        The maximum time to wait for the process to finish.
    mproc_override : str, optional
        The multiprocessing context to use. If None, the context will be set
        to "fork" on Linux systems and "spawn" on all other systems. If a
        context is specified, it will be used regardless of the system.

    Returns
    -------
    None

    Warnings
    --------
    RuntimeWarning
        If the process had to be terminated because it exceeded ``timeout``.
    """
    telemetry_headers = telemetry_headers or {}
    if not mproc_override:
        ctx_type = "fork" if platform.system().lower() == "linux" else "spawn"
    else:
        ctx_type = mproc_override

    # Resolve the production flag here, in the calling process, and hand it to
    # the child explicitly. A spawned child re-imports this module and would
    # otherwise only ever see the toggle's default value.
    warn = not ProductionToggle().production

    # Mypy gets upset below because it doesn't know we wont use "fork" on Windows
    proc = multiprocessing.get_context(ctx_type).Process(  # type: ignore
        target=_run_event_loop,
        args=(endpoint, telemetry_data, telemetry_headers, warn),
    )
    proc.start()
    proc.join(timeout)
    if proc.is_alive():
        proc.terminate()
        warnings.warn(
            f"Telemetry data not sent within {timeout} seconds",
            category=RuntimeWarning,
            stacklevel=2,
        )
    return None
def _format_endpoint(server_url: str, endpoint: str) -> str:
"""
Concatenates the server URL and endpoint, ensuring that there is only one
slash between them.
"""
endpoint = str(PurePosixPath(server_url) / endpoint.lstrip("/"))
return re.sub(r"^(https?:/)(.*?)(?<!/)\/?$", r"\1/\2/", endpoint)