# mypy: disable-error-code=has-type
"""
Copyright 2022 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
SPDX-License-Identifier: Apache-2.0
"""
import re
from typing import Any
import libcst as cst
from IPython.core.getipython import get_ipython
from IPython.core.interactiveshell import ExecutionInfo
from libcst._exceptions import ParserSyntaxError
from .api import ApiHandler
from .registry import TelemetryRegister
from .utils import REGISTRIES
# Shared handler used by every pass below to send telemetry and failure requests.
api_handler = ApiHandler()
# One TelemetryRegister per registry name in REGISTRIES.
registries = {name: TelemetryRegister(name) for name in REGISTRIES}
# Lines IPython handles itself rather than as Python: line/cell magics (`%`,
# `%%`), shell escapes (`!`, `!!`) and help requests (a leading or trailing
# `?`/`??`). Compiled once rather than on every call.
_IPYTHON_MAGIC_RE = re.compile(r"^\s*[%!?]{1,2}|^.*\?{1,2}$")


def strip_magic(code: str) -> str:
    """
    Remove IPython-specific lines (magics, shell escapes, help requests).

    Matching lines are dropped whole so that what remains is plain Python.
    Any line ending in `?` is dropped too, including eg. a comment `# why?`.

    NOTE(review): a magic line that is the only statement of an indented block
    leaves that block empty, so the result may still fail to parse.

    Parameters
    ----------
    code : str
        The raw cell source.

    Returns
    -------
    str
        The remaining lines joined with newlines (no trailing newline).
    """
    return "\n".join(
        line for line in code.splitlines() if not _IPYTHON_MAGIC_RE.match(line)
    )
def capture_registered_calls(info: ExecutionInfo) -> None:
    """
    Parse `info.raw_cell` with libcst and report registered calls found in it.

    IPython magic lines are stripped first. If the remaining code cannot be
    parsed (ParserSyntaxError / IndentationError), it is posted to the
    `intake/failed-telemetry` endpoint instead of raising, and no calls are
    reported. Otherwise the parsed module is handed to `_run_tree`.

    Parameters
    ----------
    info : IPython.core.interactiveshell.ExecutionInfo
        An object containing information about the code being executed.

    Returns
    -------
    None
    """
    raw_code: str | None = info.raw_cell
    if raw_code is None:
        return None
    code = strip_magic(raw_code)
    if not code.strip():
        # Only magics / whitespace: there are no calls to look for.
        return None
    try:
        tree = cst.parse_module(code)
    except (ParserSyntaxError, IndentationError):
        api_handler.send_failure_api_request(
            "intake/failed-telemetry", code, "intake/failed-telemetry"
        )
        return None
    _run_tree(tree)
    return None
def _run_tree(tree: cst.Module) -> None:  # pragma: no cover
    """
    Simplify call chains in `tree`, then report every registered call in it.

    `ChainSimplifier` runs first (reporting the calls it collapses), then
    `CallListener` reports the remaining registered calls and attribute
    accesses. Names are resolved against the IPython user namespace. Any
    exception raised by either pass is swallowed and the code is posted to
    the `intake/failed-telemetry` endpoint instead.
    """
    shell = get_ipython()
    if shell is None:
        # Not inside an IPython session: no user namespace to resolve names in.
        return None
    user_namespace: dict[str, Any] = shell.user_ns
    try:
        reducer = ChainSimplifier(user_namespace, REGISTRIES, api_handler)
        reduced_tree = tree.visit(reducer)
        visitor = CallListener(user_namespace, REGISTRIES, api_handler)
        # The wrapper supplies the ParentNodeProvider metadata CallListener needs.
        wrapper = cst.MetadataWrapper(reduced_tree)
        wrapper.visit(visitor)
        # Debugging aid only: merge both passes' caught calls (not returned).
        visitor._caught_calls |= reducer._caught_calls
    except Exception:
        # Catch all exceptions to avoid breaking the execution
        # of the code being run. Then post the raw code to the `failed-telemetry` endpoint
        api_handler.send_failure_api_request(
            "intake/failed-telemetry", tree.code, "intake/failed-telemetry"
        )
    return None
[docs]
class CallListener(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
def __init__(
self,
user_namespace: dict[str, Any],
registries: dict[str, set[str]],
api_handler: ApiHandler,
):
self.user_namespace = user_namespace
self.registries = registries
self._caught_calls: set[str] = set() # Mostly for debugging
self.api_handler = api_handler
[docs]
def visit_Attribute(self, node: cst.Attribute) -> None:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
full_name = self._get_full_name(node)
match full_name, parent:
case str(), cst.Call():
return None
case str(), _:
self._process_api_call(full_name, [], {})
return None
[docs]
def visit_Call(self, node: cst.Call) -> None:
"""
Visit a call node, process it if it's a registered call
"""
match node:
case cst.Call(
func=cst.Name(
value=full_name,
)
):
args, kwargs = extract_call_args_kwargs(node, self.user_namespace)
self._process_api_call(full_name, args, kwargs)
case cst.Call(
func=cst.Attribute(
value=cst.Name(value=base_name),
attr=cst.Name(
value=attr_name,
),
)
):
args, kwargs = extract_call_args_kwargs(node, self.user_namespace)
full_name = f"{base_name}.{attr_name}"
self._process_api_call(full_name, args, kwargs)
case cst.Call(func=cst.Attribute() as attr_node):
if full_name := self._get_full_name(attr_node):
# If we have a full name, we can process the call
args, kwargs = extract_call_args_kwargs(node, self.user_namespace)
self._process_api_call(full_name, args, kwargs)
def _process_api_call(
self, func_name: str, args: list[Any], kwargs: dict[str, Any]
) -> None:
"""Process an API call for a matched function name."""
for registry, registered_funcs in self.registries.items():
if func_name in registered_funcs:
self.api_handler.send_api_request(
registry,
func_name,
args,
kwargs,
)
self._caught_calls |= {func_name}
def _get_full_name(self, node: cst.CSTNode) -> str:
"""Recursively get the full name of a function or method call."""
return _get_full_name(node)
[docs]
class ChainSimplifier(cst.CSTTransformer):
"""
Transform chained calls by removing intermediate method calls
Example: ds.search(...).search(...).to_dataset_dict()
becomes: ds.to_dataset_dict()
"""
def __init__(
self,
user_namespace: dict[str, Any],
registries: dict[str, set[str]],
api_handler: ApiHandler,
):
self.user_namespace = user_namespace
self.registries = registries
self._caught_calls: set[str] = set() # Mostly for debugging
self.api_handler = api_handler
self._inferred_types: dict[str, str] = {}
def _resolve_type(self, instance_name: str) -> str:
"""
Resolve the type of an instance by its name.
If the instance is a module, return its name.
"""
instance = self.user_namespace.get(instance_name)
if instance is None:
return self._inferred_types.get(instance_name, "type")
type_name = type(instance).__name__
if type_name == "module":
type_name = getattr(instance, "__name__", instance_name)
return type_name
[docs]
def leave_Assign(
self, original_node: cst.Assign, updated_node: cst.Assign
) -> cst.Assign:
"""
When we leave an assignment node, if the value is a call to a registered
function, we infer the type of the variable being assigned to. We also
handle the case of assigning a variable to another variable, so we can
track type information through simple variable assignments. This allows
us to resolve the type of variables that are assigned from API calls, and
use that type information to simplify chained calls.
"""
match updated_node:
case cst.Assign(
targets=[cst.AssignTarget(target=cst.Name(value=var_name))],
value=cst.Name(value=type_name),
):
self._inferred_types[var_name] = type_name
case _:
pass
return updated_node
[docs]
def leave_Attribute(
self, original_node: cst.Attribute, updated_node: cst.Attribute
) -> cst.Attribute:
"""
When we leave an attribute node, if it's parent is a cst.Name (ie. the
root of a chain of attribute accesses), we replace the value of the
attribute with the type name of the instance.
"""
match updated_node:
case cst.Attribute(
value=cst.Name(
value=instance_name,
),
attr=cst.Name(value=_),
) if (type_name := self._resolve_type(instance_name)) not in [None, "type"]:
return updated_node.with_changes(value=cst.Name(type_name))
case cst.Attribute(
value=cst.Call(
func=cst.Name(
value=_maybe_class_name,
),
)
) if (
type(self.user_namespace.get(_maybe_class_name, None)) is type
):
return updated_node.with_changes(value=cst.Name(_maybe_class_name))
case _:
return updated_node
[docs]
def leave_Subscript(
self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Call | cst.Name:
"""
When we leave a subscript node, replace eg. `instance[key]` with `ClassName.__getitem__(key)`.
This means there is no need for a `CallListener.visit_Subscript` method.
"""
match updated_node:
case cst.Subscript( # Something like MyClass()['key']
value=cst.Call(func=cst.Name(value=type_name)),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=cst.SimpleString(value=args))
)
],
) if (
type(self.user_namespace.get(type_name, None)) is type
):
return self._process_subscript_call(type_name, updated_node, args)
case cst.Subscript( # String index, eg. instance['key']
value=cst.Name(value=instance_name),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=cst.SimpleString(value=args))
)
],
) if (type_name := self._resolve_type(instance_name)) is not None:
return self._process_subscript_call(type_name, updated_node, args)
case cst.Subscript( # Integer index
value=cst.Name(value=instance_name),
slice=[
cst.SubscriptElement(slice=cst.Index(value=cst.Integer(value=args)))
],
) if (type_name := self._resolve_type(instance_name)) is not None:
return cst.Call(
func=cst.Attribute(
value=cst.Name(type_name),
attr=cst.Name("__getitem__"),
),
args=[
cst.Arg(value=cst.Integer(value=args)),
],
)
case cst.Subscript( # Variable index
value=cst.Name(value=instance_name),
slice=[
cst.SubscriptElement(slice=cst.Index(value=cst.Name(value=args)))
],
) if (type_name := self._resolve_type(instance_name)) is not None:
res_args: int | str | object = self.user_namespace.get(args, args)
if isinstance(res_args, int):
mval: cst.BaseExpression = cst.Integer(value=f"{res_args}")
else:
mval = cst.SimpleString(value=f"'{res_args}'")
return cst.Call(
func=cst.Attribute(
value=cst.Name(type_name),
attr=cst.Name("__getitem__"),
),
args=[
cst.Arg(
value=mval
), # TODO: so we can put the right value in here
],
)
# Explicitly handle the case of `intake.cat.access_nri['something']
case cst.Subscript(
value=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="intake"),
attr=cst.Name(value="cat"),
),
attr=cst.Name(
value="access_nri",
),
),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=cst.SimpleString(value=arg)),
),
],
):
self._process_api_call("intake.cat.access_nri", [], {})
return self._process_subscript_call("DfFileCatalog", updated_node, arg)
case cst.Subscript(
value=cst.Attribute(
value=cst.Attribute(
value=cst.Name(
value="intake",
),
attr=cst.Name(
value="cat",
),
),
attr=cst.Name(
value="access_nri",
),
),
slice=[
cst.SubscriptElement(
slice=cst.Index(
value=cst.Name(
value=arg,
),
),
),
],
) if (argval := self.user_namespace.get(arg, None)) is not None:
# Differs from above as we need to wrap arg in extra quotes to rewrite
# it as a simple string in the rewritten code.
self._process_api_call("intake.cat.access_nri", [], {})
return self._process_subscript_call(
"DfFileCatalog", updated_node, f"'{argval}'"
)
case _: # pragma: no cover
raise AssertionError(
"Subscript node does not match expected pattern. "
"This should not happen, please report this as a bug."
) # pragma: no cover
def _process_subscript_call(
self, type_name: str, updated_node: cst.Subscript, arg: str
) -> cst.Name:
_node = cst.Call(
func=cst.Attribute(
value=cst.Name(type_name),
attr=cst.Name("__getitem__"),
),
args=[
cst.Arg(value=cst.SimpleString(value=arg)),
],
)
full_name = f"{type_name}.__getitem__"
_args, _ = extract_call_args_kwargs(_node, self.user_namespace)
self._process_api_call(full_name, _args, {})
temp_module = cst.Module(
body=[cst.SimpleStatementLine(body=[cst.Expr(value=updated_node)])]
)
code = temp_module.code
try:
result_type = type(eval(code, globals(), self.user_namespace)).__name__
except Exception: # pragma: no cover
result_type = type_name # pragma: no cover
return cst.Name(value=result_type)
[docs]
def leave_Call(
self, original_node: cst.Call, updated_node: cst.Call
) -> cst.Call | cst.Name:
# Use matcher to identify the pattern: any_method(search_call(...))
match updated_node:
case cst.Call(
func=cst.Name(
value=func_name,
)
) if (instance := self.user_namespace.get(func_name, None)) is not None:
func_name = (
instance.__name__
) # Dealias if we've renamed it something else
return updated_node.with_changes(func=cst.Name(func_name))
case cst.Call(
func=cst.Attribute(
value=cst.Name(value=base_name),
attr=cst.Name(
value=attr_name,
),
)
): # TODO: check that we return self here or don't do anything
args, kwargs = extract_call_args_kwargs(
updated_node, self.user_namespace
)
full_name = f"{base_name}.{attr_name}"
self._process_api_call(full_name, args, kwargs)
# Then pop that attribute access out of the chain
return cst.Name(
value=base_name,
)
case _:
pass
return updated_node
def _process_api_call(
self, func_name: str, args: list[Any], kwargs: dict[str, Any]
) -> None:
"""Process an API call for a matched function name."""
for registry, registered_funcs in self.registries.items():
if func_name in registered_funcs:
self.api_handler.send_api_request(
registry,
func_name,
args,
kwargs,
)
self._caught_calls |= {func_name}
def _get_full_name(node: cst.CSTNode) -> str:
"""Recursively get the full name of a function or method call."""
match node:
case cst.Attribute(
value=base_name,
attr=cst.Name(value=attr_name),
):
# If the node is an attribute, we need to repeat to get the full name
return f"{_get_full_name(base_name)}.{attr_name}"
case cst.Name(value=name):
# If the node is a name, we return the name
assert isinstance(name, str), "Name node should have a string value"
return name
case _: # pragma: no cover
raise AssertionError(
"Node does not match expected pattern. "
"This should not happen, please report this as a bug."
) # pragma: no cover