Skip to content

Follow-Up to Python Visitor

I got an email from a reader in Spain that read, in part:

I did a test and got an AttributeError. Maybe you are interested in debugging it.

The email referred to {% post_link python-visitor-pattern-helper %} which I wrote in January of last year. I pulled down the code, gave it a whirl, and found that I had a bug in it.

Wrong code ⇒ bad code.

I wanted to debug it.

Here is the corrected code.

python
# visit.py

import functools
import inspect

__all__ = ['on', 'when']

def on(param_name):
  """
  Declare a generic function that dispatches on the class of one argument.

  :param param_name: name of the decorated function's parameter whose
    runtime class selects which ``when``-registered handler runs.
  :returns: a decorator that replaces the function with a Dispatcher.
  """
  def decorator(fn):
    # Called directly from the decorating scope, so the call depth seen
    # by Dispatcher.__init__ is the same as before.
    return Dispatcher(param_name, fn)
  return decorator


def when(param_type):
  """
  Register the decorated function as the handler for ``param_type``.

  Must follow an ``@on(...)`` declaration of a function with the same
  name in the same scope (typically a class body): the dispatcher is
  looked up by the function's name in the local namespace of the scope
  that applies the decorator.

  :param param_type: class whose instances are routed to the decorated
    function.
  :returns: a decorator producing a wrapper that delegates every call to
    the shared dispatcher. The wrapper exposes the dispatcher as
    ``.dispatcher`` so that later ``@when`` declarations can find it.
  :raises KeyError: if no function of that name is bound in the
    decorating scope.
  """
  def f(fn):
    # f_back is the scope applying the decorator (e.g. the class body).
    frame = inspect.currentframe().f_back
    # __name__ works on Python 2 and 3; func_name exists only on Python 2.
    dispatcher = frame.f_locals[fn.__name__]
    # The name is bound either to the Dispatcher itself (first @when after
    # @on) or to the wrapper returned by an earlier @when.
    if not isinstance(dispatcher, Dispatcher):
      dispatcher = dispatcher.dispatcher
    dispatcher.add_target(param_type, fn)

    # Keep the handler's name and docstring on the wrapper instead of 'ff'.
    @functools.wraps(fn)
    def ff(*args, **kw):
      return dispatcher(*args, **kw)
    ff.dispatcher = dispatcher
    return ff
  return f


class Dispatcher(object):
  """
  Callable that routes each call to a handler chosen by the class of one
  argument.

  Handlers are registered per class with ``add_target``.
  """

  def __init__(self, param_name, fn):
    """
    :param param_name: name of the parameter of ``fn`` whose value selects
      the handler.
    :param fn: the generic function; only its signature is used, to find
      the positional index of ``param_name``.
    :raises ValueError: if ``fn`` has no positional parameter named
      ``param_name``.
    """
    # inspect.getargspec was removed in Python 3.11; use getfullargspec
    # where it exists and fall back to getargspec on Python 2.
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    self.param_index = getspec(fn).args.index(param_name)
    self.param_name = param_name
    self.targets = {}

  def __call__(self, *args, **kw):
    """
    Invoke the handler(s) for the class of the dispatch argument.

    If a handler is registered for that exact class, its return value is
    returned. Otherwise every handler registered for a superclass of it is
    called, and their results are returned as a list (empty when nothing
    matches).

    :raises TypeError: if the dispatch argument is passed neither
      positionally nor by keyword.
    """
    if self.param_index < len(args):
      subject = args[self.param_index]
    elif self.param_name in kw:
      # The dispatch argument was supplied by keyword.
      subject = kw[self.param_name]
    else:
      raise TypeError('missing dispatch argument %r' % (self.param_name,))
    # __class__ rather than type(): for Python 2 old-style instances,
    # type() returns the generic `instance` type instead of the class.
    typ = subject.__class__
    handler = self.targets.get(typ)
    if handler is not None:
      return handler(*args, **kw)
    # Iterating the dict directly works on Python 2 and 3 (iterkeys is 2-only).
    return [self.targets[k](*args, **kw)
            for k in self.targets if issubclass(typ, k)]

  def add_target(self, typ, target):
    """Register ``target`` as the handler for instances of exactly ``typ``."""
    self.targets[typ] = target

The example code also had a name conflict. Here's the fix for that.

python
# ast.py

import visit as v
import sys

class BaseNode:
  """Base class for syntax-tree nodes; provides the visitor handshake."""

  def accept(self, visitor):
    """
    Hand this node to ``visitor.visit`` and return whatever it returns.

    Returning the result lets value-computing visitors be driven through
    ``accept``; callers that ignore the return value are unaffected.
    """
    return visitor.visit(self)


class Literal(BaseNode):
  """Leaf node wrapping a constant; the constant is kept in ``value``."""

  def __init__(self, val):
    # Note: the constructor argument is `val` but the attribute is `value`.
    self.value = val


class VariableNode(BaseNode):
  """Leaf node referring to a variable by its ``name``."""

  def __init__(self, name):
    self.name = name


class AssignmentExpression(BaseNode):
  """Node for ``left = right``; ``children`` holds ``[left, right]`` in order."""

  def __init__(self, left, right):
    pair = [left, right]
    self.children = pair


class AbstractSyntaxTreeVisitor(object):
  """
  Write an expression tree to stdout by dispatching on node class.

  Every ``visit`` below shares the dispatcher created by ``@v.on('node')``;
  the handler is selected by the class of the ``node`` argument.
  """

  @v.on('node')
  def visit(self, node):
    """
    This is the generic method that initializes the
    dynamic dispatcher.
    """

  @v.when(BaseNode)
  def visit(self, node):
    """
    Runs for nodes whose class is exactly BaseNode. It is also one of the
    handlers invoked for BaseNode subclasses that have no handler of
    their own.
    """
    # sys.stdout.write (like the other handlers) instead of the Python 2
    # print statement, which is a syntax error on Python 3.
    sys.stdout.write("Unrecognized node: %s\n" % (node,))

  @v.when(AssignmentExpression)
  def visit(self, node):
    """ Matches nodes of type AssignmentExpression. """
    node.children[0].accept(self)
    sys.stdout.write('=')
    node.children[1].accept(self)

  @v.when(VariableNode)
  def visit(self, node):
    """ Matches nodes that contain variables. """
    sys.stdout.write(str(node.name))

  @v.when(Literal)
  def visit(self, node):
    """ Matches nodes that contain literal values. """
    sys.stdout.write(str(node.value))

With an example like this, we can see that it now works correctly.

python
# NOTE(review): this relies on the local ast.py shadowing the standard-library
# `ast` module. If the stdlib module is already imported (Python 3's `inspect`
# imports it), this may pick up the wrong module; renaming the file avoids that.
import ast

# `var`/`lit` rather than `v`/`l`: `v` is the alias used for the visit module
# in ast.py, and `l` is easily misread as `1`.
var = ast.VariableNode('x')
lit = ast.Literal(5)
n = ast.AssignmentExpression(var, lit)
visitor = ast.AbstractSyntaxTreeVisitor()
visitor.visit(n)
# print('') emits the trailing newline on both Python 2 and Python 3.
print('')

And that's it.

Get the new visitor code from visit.py.

Released under CC BY-NC-ND 4.0