Source code for cuqi.experimental.algebra._ast

"""
CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables.

The AST records the operations applied to variables, allowing delayed evaluation:
the operations are only carried out when needed, by traversing the tree with the __call__ method.

For example, the following code

    x = VariableNode('x')
    y = VariableNode('y')
    z = 2*x + 3*y

will create the following AST:

z = AddNode(
        MultiplyNode(
            ValueNode(2),
            VariableNode('x')
        ),
        MultiplyNode(
            ValueNode(3),
            VariableNode('y')
        )
    )

which can be evaluated by calling the __call__ method:

    z(x=1, y=2) # returns 8

"""

from abc import ABC, abstractmethod

def convert_to_node(x):
    """Converts any non-Node object to a ValueNode object.

    Parameters
    ----------
    x : object
        Either a Node, which is returned unchanged, or any other Python
        object, which is wrapped in a ValueNode.

    Returns
    -------
    Node
        ``x`` itself if it already is a Node, otherwise ``ValueNode(x)``.
    """
    # A named ``def`` (instead of a lambda bound to a name, PEP 8 E731) gives
    # the function a real ``__name__`` for tracebacks and documentation tools.
    return x if isinstance(x, Node) else ValueNode(x)

# ====== Base classes for the nodes ======


class Node(ABC):
    """Abstract base class of every node in the algebra abstract syntax tree.

    The tree is built through Python's operator overloading: applying an
    arithmetic, indexing or matrix-multiplication operator to a Node returns
    a new node that records the operation instead of computing it. Operands
    that are not Nodes are wrapped via ``convert_to_node``. Every concrete
    subclass must implement ``__call__`` (evaluation), ``condition`` and
    ``__repr__``.
    """

    @abstractmethod
    def __call__(self, **kwargs):
        """Evaluate the sub-tree rooted at this node.

        The sub-tree is traversed and the recorded operations are applied to
        the variable values given as keyword arguments.
        """
        pass

    @abstractmethod
    def condition(self, **kwargs):
        """Condition the tree on the given variable values.

        Any VariableNode whose name is a key of ``kwargs`` is replaced by a
        ValueNode holding the corresponding value.
        """
        pass

    @abstractmethod
    def __repr__(self):
        """Return the string form of the node, used for printing the AST."""
        pass

    def get_variables(self, variables=None):
        """Collect the names of all variables in the sub-tree rooted here.

        Parameters
        ----------
        variables : set, optional
            Set the names are added to. A new set is created when omitted.

        Returns
        -------
        set
            The set of variable names (the caller's set, if one was given).
        """
        found = set() if variables is None else variables
        if isinstance(self, VariableNode):
            found.add(self.name)
        # Unary nodes store their operand in ``child`` and binary nodes in
        # ``left``/``right``; leaf nodes have none of these attributes.
        for attr_name in ("child", "left", "right"):
            if hasattr(self, attr_name):
                getattr(self, attr_name).get_variables(found)
        return found

    # ---- Operator overloads: each one records the operation as a new node ----

    def __add__(self, other):
        return AddNode(self, convert_to_node(other))

    def __radd__(self, other):
        return AddNode(convert_to_node(other), self)

    def __sub__(self, other):
        return SubtractNode(self, convert_to_node(other))

    def __rsub__(self, other):
        return SubtractNode(convert_to_node(other), self)

    def __mul__(self, other):
        return MultiplyNode(self, convert_to_node(other))

    def __rmul__(self, other):
        return MultiplyNode(convert_to_node(other), self)

    def __truediv__(self, other):
        return DivideNode(self, convert_to_node(other))

    def __rtruediv__(self, other):
        return DivideNode(convert_to_node(other), self)

    def __pow__(self, other):
        return PowerNode(self, convert_to_node(other))

    def __rpow__(self, other):
        return PowerNode(convert_to_node(other), self)

    def __neg__(self):
        return NegateNode(self)

    def __abs__(self):
        return AbsNode(self)

    def __getitem__(self, i):
        return GetItemNode(self, convert_to_node(i))

    def __matmul__(self, other):
        return MatMulNode(self, convert_to_node(other))

    def __rmatmul__(self, other):
        return MatMulNode(convert_to_node(other), self)
class UnaryNode(Node, ABC):
    """Abstract base for nodes that apply an operation to a single operand.

    Parameters
    ----------
    child : Node
        The operand the unary operation acts on.
    """

    def __init__(self, child: Node):
        self.child = child

    def condition(self, **kwargs):
        # Rebuild a node of the same concrete type around the conditioned operand.
        conditioned_child = self.child.condition(**kwargs)
        return self.__class__(conditioned_child)


class BinaryNode(Node, ABC):
    """Abstract base for nodes that combine two operands.

    Subclasses provide ``op_symbol``, the symbol printed between the two
    operands by ``__repr__``.

    Parameters
    ----------
    left : Node
        Left operand of the binary operation.
    right : Node
        Right operand of the binary operation.
    """

    @property
    @abstractmethod
    def op_symbol(self):
        """Symbol used to represent the operation in ``__repr__``."""
        pass

    def __init__(self, left: Node, right: Node):
        self.left = left
        self.right = right

    def condition(self, **kwargs):
        # Condition both operands (left first), then rebuild the same node type.
        new_left = self.left.condition(**kwargs)
        new_right = self.right.condition(**kwargs)
        return self.__class__(new_left, new_right)

    def __repr__(self):
        return f"{self.left} {self.op_symbol} {self.right}"


def _parenthesize_binary(operand):
    """Return ``str(operand)``, wrapped in parentheses if it is a BinaryNode."""
    text = str(operand)
    return f"({text})" if isinstance(operand, BinaryNode) else text


class BinaryNodeWithParenthesis(BinaryNode, ABC):
    """Binary node whose operands are parenthesized when they are themselves
    binary nodes, printed with spaces around the operator symbol."""

    def __repr__(self):
        return " ".join(
            (
                _parenthesize_binary(self.left),
                self.op_symbol,
                _parenthesize_binary(self.right),
            )
        )


class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC):
    """Binary node whose operands are parenthesized when they are themselves
    binary nodes, printed without spaces around the operator symbol."""

    def __repr__(self):
        return (
            _parenthesize_binary(self.left)
            + self.op_symbol
            + _parenthesize_binary(self.right)
        )


# ====== Specific implementations of the "leaf" nodes ======
class VariableNode(Node):
    """Leaf node standing for a named variable such as ``"x"`` or ``"y"``.

    Parameters
    ----------
    name : str
        The variable's name. It is what ``__repr__`` prints and the key that
        is looked up in the keyword arguments when the tree is evaluated or
        conditioned.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, **kwargs):
        """Return the value supplied for this variable in ``kwargs``.

        Raises
        ------
        KeyError
            If ``kwargs`` contains no entry for this variable.
        """
        key = self.name
        if key not in kwargs:
            raise KeyError(
                f"Variable '{self.name}' not found in the given input parameters. Unable to evaluate the expression."
            )
        return kwargs[key]

    def condition(self, **kwargs):
        """Return a ValueNode with the supplied value if this variable is a
        key of ``kwargs``; otherwise return this node unchanged."""
        if self.name not in kwargs:
            return self
        return ValueNode(kwargs[self.name])

    def __repr__(self):
        return self.name
class ValueNode(Node):
    """Node that represents a constant value.

    The value can be any Python object that is not a Node.

    Parameters
    ----------
    value : object
        The Python object represented by this node.
    """

    def __init__(self, value):
        self.value = value

    def __call__(self, **kwargs):
        """Return the stored value; ``kwargs`` are ignored."""
        return self.value

    def condition(self, **kwargs):
        # A constant contains no variables, so conditioning leaves it unchanged.
        return self

    def __repr__(self):
        return str(self.value)


# ====== Specific implementations of the "internal" nodes ======


class AddNode(BinaryNode):
    """Node that represents the addition operation."""

    @property
    def op_symbol(self):
        return "+"

    def __call__(self, **kwargs):
        return self.left(**kwargs) + self.right(**kwargs)


class SubtractNode(BinaryNode):
    """Node that represents the subtraction operation."""

    @property
    def op_symbol(self):
        return "-"

    def __call__(self, **kwargs):
        return self.left(**kwargs) - self.right(**kwargs)

    def __repr__(self):
        # Subtraction is not associative: without parentheses x - (y + z)
        # would print as "x - y + z" and x - (y - z) as "x - y - z", which
        # read as different expressions than the tree evaluates.
        if isinstance(self.right, (AddNode, SubtractNode)):
            right = f"({self.right})"
        else:
            right = str(self.right)
        return f"{self.left} {self.op_symbol} {right}"


class MultiplyNode(BinaryNodeWithParenthesis):
    """Node that represents the multiplication operation."""

    @property
    def op_symbol(self):
        return "*"

    def __call__(self, **kwargs):
        return self.left(**kwargs) * self.right(**kwargs)


class DivideNode(BinaryNodeWithParenthesis):
    """Node that represents the division operation."""

    @property
    def op_symbol(self):
        return "/"

    def __call__(self, **kwargs):
        return self.left(**kwargs) / self.right(**kwargs)


class PowerNode(BinaryNodeWithParenthesisNoSpace):
    """Node that represents the power operation."""

    @property
    def op_symbol(self):
        return "^"

    def __call__(self, **kwargs):
        return self.left(**kwargs) ** self.right(**kwargs)

    def __repr__(self):
        left = str(self.left)
        # Unary minus binds less tightly than exponentiation, so a base that
        # prints with a leading "-" (a NegateNode or a negative constant)
        # must be parenthesized: (-x)^2 is not the same as -x^2.
        if isinstance(self.left, BinaryNode) or left.startswith("-"):
            left = f"({left})"
        if isinstance(self.right, BinaryNode):
            right = f"({self.right})"
        else:
            right = str(self.right)
        return f"{left}{self.op_symbol}{right}"


class GetItemNode(BinaryNode):
    """Node that represents the get item operation.

    Here the left node is the object and the right node is the index.
    """

    def __call__(self, **kwargs):
        return self.left(**kwargs)[self.right(**kwargs)]

    def __repr__(self):
        left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
        return f"{left}[{self.right}]"

    @property
    def op_symbol(self):
        # Not used: ``__repr__`` is overridden to print ``obj[index]``.
        pass


class NegateNode(UnaryNode):
    """Node that represents the arithmetic negation operation."""

    def __call__(self, **kwargs):
        return -self.child(**kwargs)

    def __repr__(self):
        child = (
            f"({self.child})"
            if isinstance(self.child, (BinaryNode, UnaryNode))
            else str(self.child)
        )
        return f"-{child}"


class AbsNode(UnaryNode):
    """Node that represents the absolute value operation."""

    def __call__(self, **kwargs):
        return abs(self.child(**kwargs))

    def __repr__(self):
        return f"abs({self.child})"


class MatMulNode(BinaryNodeWithParenthesis):
    """Node that represents the matrix multiplication operation."""

    @property
    def op_symbol(self):
        return "@"

    def __call__(self, **kwargs):
        return self.left(**kwargs) @ self.right(**kwargs)