SoFunction
Updated on 2024-11-15

Python Implementation of the Visitor Pattern Explained

Suppose you want to implement an object that holds several kinds of data structures, for example a tree of nodes for arithmetic expressions, with separate node types for unary operators, binary operators, and numbers:

class Node:
    pass


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator):
    pass


class Sub(BinaryOperator):
    pass


class Mul(BinaryOperator):
    pass


class Div(BinaryOperator):
    pass


class Negative(UnaryOperator):
    pass


class Number(Node):
    def __init__(self, value):
        self.value = value

Building an expression then requires calls like this:

# Represents the expression 2 - (2 + 2) * 2 / 1 = 2 - 8.0 = -6.0
t1 = Add(Number(2), Number(2))
t2 = Mul(t1, Number(2))
t3 = Div(t2, Number(1))
t4 = Sub(Number(2), t3)

Or, equivalently, as a single nested call:

t5 = Sub(Number(2), Div(Mul(Add(Number(2), Number(2)), Number(2)), Number(1)))
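For contrast, here is a hypothetical evaluate() function written without the visitor pattern (the function name is an illustration, not from the original, and Negative is omitted for brevity). Every new operation on the tree would need another function like this with its own chain of isinstance checks, which is the kind of code the visitor pattern organizes:

```python
# Minimal node classes with the same shape as above
class Node:
    pass


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator): pass
class Sub(BinaryOperator): pass
class Mul(BinaryOperator): pass
class Div(BinaryOperator): pass


class Number(Node):
    def __init__(self, value):
        self.value = value


def evaluate(node):
    """Hypothetical naive evaluator: one isinstance branch per node class."""
    if isinstance(node, Number):
        return node.value
    if isinstance(node, Add):
        return evaluate(node.left) + evaluate(node.right)
    if isinstance(node, Sub):
        return evaluate(node.left) - evaluate(node.right)
    if isinstance(node, Mul):
        return evaluate(node.left) * evaluate(node.right)
    if isinstance(node, Div):
        return evaluate(node.left) / evaluate(node.right)
    raise TypeError(f"unsupported node: {type(node).__name__}")


t5 = Sub(Number(2), Div(Mul(Add(Number(2), Number(2)), Number(2)), Number(1)))
print(evaluate(t5))  # -6.0
```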

Either way you have to chain many class calls by hand, which is lengthy and hard to read and write, and there is still no general, simple way to walk the tree and compute something from it. The visitor pattern can be used to achieve exactly this.

The visitor pattern lets you operate on elements without changing the structure of the object those elements belong to, which keeps the calling code simple. A common analogy is a taxi company: when a passenger calls for a cab, the company receives the visitor (the passenger) and dispatches a cab to pick them up.
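For comparison, the textbook (Gang of Four) form of the pattern uses double dispatch: each element class has an accept() method that calls back the visitor method for its own type. The implementation in this article replaces accept() with a name-based method lookup, so the sketch below is only a reference point; the classes Evaluator and Printer are hypothetical.

```python
class Number:
    def __init__(self, value):
        self.value = value

    def accept(self, visitor):
        # Double dispatch: the element selects the visitor method for its own type
        return visitor.visit_number(self)


class Add:
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def accept(self, visitor):
        return visitor.visit_add(self)


class Evaluator:
    """One operation: compute the numeric value of the tree."""

    def visit_number(self, node):
        return node.value

    def visit_add(self, node):
        return node.left.accept(self) + node.right.accept(self)


class Printer:
    """A second operation, added without touching Number or Add."""

    def visit_number(self, node):
        return str(node.value)

    def visit_add(self, node):
        return f"({node.left.accept(self)} + {node.right.accept(self)})"


tree = Add(Number(1), Add(Number(2), Number(3)))
print(tree.accept(Evaluator()))  # 6
print(tree.accept(Printer()))    # (1 + (2 + 3))
```

Adding a new operation means writing a new visitor class; the element classes stay unchanged.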

First define a base visitor class, NodeVisitor, that provides the basic access entry point. Every concrete visitor inherits from this class and reaches each of its operations through the class's visit() method.

# Base class of the visitor node
class NodeVisitor:
    def visit(self, node):
        if not isinstance(node, Node):  # Non-Node values are returned as-is; handle other cases as your situation requires
            return node
        self.method_name = "visit_" + type(node).__name__.lower()  # type(node) can also be replaced with node.__class__ (as long as node.__class__ is not tampered with)
        meth = getattr(self, self.method_name, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        raise RuntimeError(f"No {self.method_name} method")


# Classes corresponding to (one) visitor
class Visitor(NodeVisitor):
    """
    Each method name must be "visit_" followed by the lowercased name of the Node subclass defined above
    """

    def visit_add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def visit_sub(self, node):
        return self.visit(node.left) - self.visit(node.right)

    def visit_mul(self, node):
        return self.visit(node.left) * self.visit(node.right)

    def visit_div(self, node):
        return self.visit(node.left) / self.visit(node.right)

    def visit_negative(self, node):  # If class Negative is renamed to class Neg, rename this method to visit_neg
        return -self.visit(node.operand)

    def visit_number(self, node):
        return node.value
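Putting the pieces together, here is a usage sketch that evaluates the expression built earlier. It repeats the class definitions in condensed form so it runs on its own, and it assumes the attribute names left, right, operand and value that the visit methods read. One small deviation from the listing above: it keeps the method name in a local variable instead of on self.

```python
class Node:
    pass


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator): pass
class Sub(BinaryOperator): pass
class Mul(BinaryOperator): pass
class Div(BinaryOperator): pass
class Negative(UnaryOperator): pass


class Number(Node):
    def __init__(self, value):
        self.value = value


class NodeVisitor:
    def visit(self, node):
        if not isinstance(node, Node):
            return node
        method_name = "visit_" + type(node).__name__.lower()
        meth = getattr(self, method_name, None)
        if meth is None:
            return self.generic_visit(node)
        return meth(node)

    def generic_visit(self, node):
        raise RuntimeError(f"No visit_{type(node).__name__.lower()} method")


class Visitor(NodeVisitor):
    def visit_add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def visit_sub(self, node):
        return self.visit(node.left) - self.visit(node.right)

    def visit_mul(self, node):
        return self.visit(node.left) * self.visit(node.right)

    def visit_div(self, node):
        return self.visit(node.left) / self.visit(node.right)

    def visit_negative(self, node):
        return -self.visit(node.operand)

    def visit_number(self, node):
        return node.value


t4 = Sub(Number(2), Div(Mul(Add(Number(2), Number(2)), Number(2)), Number(1)))
v = Visitor()
print(v.visit(t4))                   # -6.0
print(v.visit(Negative(Number(3))))  # -3
```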

Here, meth = getattr(self, self.method_name, None) looks up an object method by its string name. The name (visit_add, visit_sub, visit_mul, ...) is built dynamically from the name of each Node class (Add, Sub, Mul, ...), which keeps the access entry point simple. When no matching method is found, generic_visit() runs and raises a RuntimeError to signal that the node could not be visited.
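The lookup can be seen in isolation in this stripped-down sketch; the class names Dispatcher and Pow are hypothetical. Passing self.generic_visit as getattr's default argument is an equivalent shortcut for the `if meth is None` check used above.

```python
class Add:
    pass


class Pow:  # Hypothetical node type with no matching visit_ method
    pass


class Dispatcher:
    def visit(self, node):
        # Build the method name from the class name, then look it up by string
        name = "visit_" + type(node).__name__.lower()
        return getattr(self, name, self.generic_visit)(node)

    def visit_add(self, node):
        return "handled add"

    def generic_visit(self, node):
        raise RuntimeError(f"No visit_{type(node).__name__.lower()} method")


d = Dispatcher()
print(d.visit(Add()))  # handled add
try:
    d.visit(Pow())
except RuntimeError as exc:
    print(exc)  # No visit_pow method
```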

To add an operation, such as taking an absolute value, it is enough to define class Abs(UnaryOperator): pass and add a visit_abs(self, node) method to class Visitor. No other changes are needed, and the structure of the stored data stays the same.
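A minimal sketch of that extension, trimmed to the classes it needs (Abs and visit_abs follow the naming described above; the rest is condensed from the earlier listing):

```python
class Node:
    pass


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class Number(Node):
    def __init__(self, value):
        self.value = value


class Negative(UnaryOperator):
    pass


# New node type: nothing else in the hierarchy changes
class Abs(UnaryOperator):
    pass


class NodeVisitor:
    def visit(self, node):
        if not isinstance(node, Node):
            return node
        name = "visit_" + type(node).__name__.lower()
        meth = getattr(self, name, None)
        if meth is None:
            raise RuntimeError(f"No {name} method")
        return meth(node)


class Visitor(NodeVisitor):
    def visit_negative(self, node):
        return -self.visit(node.operand)

    def visit_number(self, node):
        return node.value

    # New operation: "abs" matches the class name Abs, lowercased
    def visit_abs(self, node):
        return abs(self.visit(node.operand))


print(Visitor().visit(Abs(Negative(Number(5)))))  # 5
```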

Note that visit() calls a visit_xxx() method, and visit_xxx() may call visit() again, so evaluating a tree is essentially a recursive traversal. With large amounts of data this can become slow, and if the tree is nested too deeply the recursion exceeds Python's recursion limit and fails with a RecursionError. The following implementation of the visitor pattern uses a stack and generators to eliminate the recursion:
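To see the limit in practice, here is a sketch with a hypothetical RecursiveVisitor and a hypothetical build_chain() helper that builds a left-leaning chain of Add nodes like the test later in this article. A shallow chain evaluates fine; a chain ten times deeper than the current recursion limit fails.

```python
import sys


class Node:
    pass


class Add(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Number(Node):
    def __init__(self, value):
        self.value = value


class RecursiveVisitor:
    def visit(self, node):
        return getattr(self, "visit_" + type(node).__name__.lower())(node)

    def visit_add(self, node):
        # Each level of the tree costs two Python frames: visit() and visit_add()
        return self.visit(node.left) + self.visit(node.right)

    def visit_number(self, node):
        return node.value


def build_chain(depth):
    # Builds ((0 + 1) + 2) + ... + (depth - 1) iteratively, without recursion
    tree = Number(0)
    for n in range(1, depth):
        tree = Add(tree, Number(n))
    return tree


v = RecursiveVisitor()
print(v.visit(build_chain(100)))  # 4950

hit_limit = False
try:
    v.visit(build_chain(10 * sys.getrecursionlimit()))
except RecursionError:
    hit_limit = True
print(hit_limit)  # True
```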

import types


class Node:
    pass


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class Add(BinaryOperator):
    pass


class Sub(BinaryOperator):
    pass


class Mul(BinaryOperator):
    pass


class Div(BinaryOperator):
    pass


class Negative(UnaryOperator):
    pass


class Number(Node):
    def __init__(self, value):  # Differs from UnaryOperator only by naming.
        self.value = value


class NodeVisitor:
    def visit(self, node):
        # Use stack+generator to replace the original recursive way of writing visit()
        stack = [node]
        last_result = None  # Performing an operation ultimately returns a value
        while stack:
            last = stack[-1]
            try:
                if isinstance(last, Node):
                    # Replace the node with what its visit_xxx() returns: a generator or a plain value
                    stack.append(self._visit(stack.pop()))
                elif isinstance(last, types.GeneratorType):  # A generator returned by a visit_xxx() method such as visit_add()
                    # If it's a generator, don't pop it off; keep sending until StopIteration
                    # last_result is sent back into the generator (e.g. 2 is received as visit_add()'s left value)
                    stack.append(last.send(last_result))
                    last_result = None
                else:  # The result of a calculation is a plain value
                    last_result = stack.pop()
            except StopIteration:  # The generator has finished
                stack.pop()
        return last_result

    def _visit(self, node):
        self.method_name = "visit_" + type(node).__name__.lower()
        method = getattr(self, self.method_name, None)
        if method is None:
            return self.generic_visit(node)
        return method(node)

    def generic_visit(self, node):
        raise RuntimeError(f"No {self.method_name} method")


class Visitor(NodeVisitor):
    def visit_add(self, node):
        yield (yield node.left) + (yield node.right)  # node.left and node.right may themselves be Node objects

    def visit_sub(self, node):
        yield (yield node.left) - (yield node.right)

    def visit_mul(self, node):
        yield (yield node.left) * (yield node.right)

    def visit_div(self, node):
        yield (yield node.left) / (yield node.right)

    def visit_negative(self, node):
        yield -(yield node.operand)

    def visit_number(self, node):
        return node.value
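The double-yield expression in visit_add() is the core of the trick. This sketch drives a hypothetical add_gen() generator of the same shape by hand, playing the role of NodeVisitor.visit(): the inner yields hand each child to the driver and receive its value back, and the outer yield reports the result. The driver's first send(None) is equivalent to next().

```python
def add_gen(left, right):
    # Same shape as visit_add(): (yield left) and (yield right) pass children
    # out to the driver and evaluate to the values the driver sends back
    yield (yield left) + (yield right)


gen = add_gen("left child", "right child")
print(next(gen))    # left child  (the driver must evaluate this node)
print(gen.send(2))  # right child (2 is the value of the left child)
print(gen.send(3))  # 5           (the sum of both children)
try:
    gen.send(None)  # Resuming after the final yield ends the generator
except StopIteration:
    print("done")
```

In NodeVisitor.visit(), that final send() delivers the result value, and the StopIteration it raises is what pops the finished generator off the stack while last_result still holds the value.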

Now test whether a very deep tree still raises an exception for exceeding the recursion limit:

def test_time_cost():
    import time
    s = time.perf_counter()
    a = Number(0)
    for n in range(1, 100000):
        a = Add(a, Number(n))
    v = Visitor()
    print(v.visit(a))
    print(f"time cost:{time.perf_counter() - s}")

The output is correct, and no RecursionError is raised:

4999950000
time cost:0.9547078

Finally, here is an alternative that also seems to work: each node computes its value eagerly in its constructor, so the tree never has to be traversed afterwards:

class Node:
    pass


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        # Evaluate eagerly: both children already carry a computed .value
        self.value = self.left.value + self.right.value


class Sub(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = self.left.value - self.right.value


class Mul(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = self.left.value * self.right.value


class Div(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = self.left.value / self.right.value


class Negative(UnaryOperator):
    def __init__(self, operand):
        super().__init__(operand)
        self.value = -self.operand.value


class Number(Node):
    def __init__(self, value):
        self.value = value

Run the test:

def test_time_cost():
    import time
    s = time.perf_counter()
    a = Number(0)
    for n in range(1, 100000):
        a = Add(a, Number(n))
    print(a.value)
    print(time.perf_counter() - s)

Output:

4999950000
0.2506986
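Applied to the expression from the beginning of the article, the eager version produces the same result with no traversal at all. The sketch below condenses the classes above (Negative and UnaryOperator omitted for brevity). The trade-off, as we read it, is that the value is fixed at construction time: the tree supports exactly one built-in operation, and adding another (say, printing the expression) means editing every node class again, which is what the visitor pattern avoids.

```python
class Node:
    pass


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = left.value + right.value  # Computed once, at construction


class Sub(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = left.value - right.value


class Mul(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = left.value * right.value


class Div(BinaryOperator):
    def __init__(self, left, right):
        super().__init__(left, right)
        self.value = left.value / right.value


class Number(Node):
    def __init__(self, value):
        self.value = value


t5 = Sub(Number(2), Div(Mul(Add(Number(2), Number(2)), Number(2)), Number(1)))
print(t5.value)  # -6.0
```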
