"""The stack is used by combinators to pass data between functions."""
from typing import Any, Callable, Optional
from functools import wraps
import logging
from redex import util
from redex import function as fn
from redex.function import Fn, Signature
# TODO: enable type check after: mypy > 0.910.
# https://github.com/python/mypy/issues/9980
Stack = tuple[Any, ...] # type:ignore
"""The stack."""
StackMethod = Callable[[Any, Stack], Stack]
"""The method from stack state to stack state."""
[docs]def constrained_call(
func: Fn,
stack: Stack,
signature: Optional[Signature] = None,
) -> Stack:
"""Applies the function with arguments taken from the stack.
Takes `n_in` arguments from the stack, reshapes them to match
function's input shape `in_shape`, calls the function, then
pushes ouputs back onto the stack.
Args:
func: a function to call.
stack: arguments available for the call.
signature: optional signature of the function. If not set,
it will be inferred.
Returns:
function outputs and rest of the stack.
Raises:
ValueError: if a number of arguments on the stack less
than required for function call.
>>> import operator as op
>>> from redex.stack import constrained_call
>>> constrained_call(func=op.add, stack=(1, 2, 0, 0))
(3, 0, 0)
"""
if signature is None:
signature = fn.infer_signature(func)
n_in, in_shape = signature.n_in, signature.in_shape
stack_size = len(stack)
logging.debug(
"constrained_call :: %s stack_size=%s signature=%s",
fn.infer_name(func).ljust(20),
stack_size,
signature,
)
verify_stack_size(func, stack, signature)
inputs = util.reshape_tuples(stack[:n_in], in_shape)
outputs = tuple(util.flatten_tuples(util.expand_to_tuple(func(*inputs))))
return outputs + stack[n_in:]
[docs]def verify_stack_size(
func: Fn,
stack: Stack,
signature: Optional[Signature] = None,
) -> int:
"""Verifies that the stack contains required number of arguments for function call.
Args:
func: a function.
stack: arguments available for the call.
signature: optional signature of the function. If not set,
it will be inferred.
Returns:
stack size.
Raises:
ValueError: if a number of arguments on the stack less
than required for function call.
"""
if signature is None:
signature = fn.infer_signature(func)
n_in = signature.n_in
stack_size = len(stack)
if stack_size < n_in:
raise ValueError(
f"The `{fn.infer_name(func)}` takes {n_in} "
f"positional arguments but {stack_size} were given."
)
return stack_size
[docs]def stackmethod(method: Fn) -> StackMethod:
"""Wraps a any method to a stackmethod.
The stackmethods expect an entire stack as a single argument
and output its modified version.
Args:
method: a method to wrap.
Returns:
a stackmethod.
::
from redex.stack import stackmethod, Stack
class Add:
@stackmethod
def __call__(self, stack: Stack) -> Stack:
a, b, *rest = stack
return (a + b, *rest)
Add()(1, 2) # -> 3
"""
@wraps(method)
def inner(self: Any, *inputs: Any) -> Any:
return util.squeeze_tuple(method(self, util.expand_to_tuple(inputs)))
return inner