Source code for redex.combinator._branch

"""The branch combinator."""

from typing import List
from functools import reduce
from redex import util
from redex import function as fn
from redex.function import Fn, FnIter
from redex.combinator._serial import serial, Serial
from redex.combinator._parallel import parallel
from redex.combinator._select import select


[docs]def branch(*children: FnIter) -> Serial: """Creates a branch combinator. The combinator combines multiple branches of given functions and operate on copy of inputs. Each branch includes a sequence of functions applied serially. >>> import operator as op >>> from redex import combinator as cb >>> branch = cb.branch(cb.serial(op.add, op.add), op.add) >>> branch(1, 2, 3) == (1 + 2 + 3, 1 + 2) True Args: children: a sequence of functions. Returns: a combinator. """ flat_children = util.flatten(children) indices = _estimate_branch_indices(flat_children) return serial( select(indices=indices), parallel(*flat_children), )
def _estimate_branch_indices(children: List[Fn]) -> List[int]: def count(acc: List[int], child: Fn) -> List[int]: signature = fn.infer_signature(child) return acc + list(range(0, signature.n_in)) initializer: List[int] = [] return reduce(count, children, initializer)