Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions refex/python/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,11 @@

from refex.python import error_strings
from refex.python import matcher
from refex.python import matchers
from refex.python import semiliteral_eval
# Actually collect all the matchers into the matchers module, so they can be
# enumerated.
import refex.python.matchers.ast_matchers # pylint: disable=unused-import
import refex.python.matchers.base_matchers # pylint: disable=unused-import
import refex.python.matchers.lexical_matchers # pylint: disable=unused-import
import refex.python.matchers.syntax_matchers # pylint: disable=unused-import
from refex.python.matchers import ast_matchers
from refex.python.matchers import base_matchers
from refex.python.matchers import lexical_matchers
from refex.python.matchers import syntax_matchers


def _sorted_attributes(o):
Expand All @@ -50,21 +47,32 @@ def _sorted_attributes(o):
yield a, getattr(o, a)


def _get_matcher_map():
"""Get the mapping from dotted-names to matcher constructors/callables."""
mapping = {}
for module_name, module in _sorted_attributes(matchers):
for global_variable, value in _sorted_attributes(module):
if not isinstance(value, type):
continue
if not matcher.is_safe_to_eval(value):
continue
mapping[global_variable] = value
mapping['%s.%s' % (module_name, global_variable)] = value
return mapping
# TODO: remove overwrite param


_ALL_MATCHERS = _get_matcher_map()
def add_module(module, overwrite=False):
"""Adds a non-builtin matcher module to be available for compile_matcher."""
for global_variable, value in _sorted_attributes(module):
if not isinstance(value, type):
continue
if not matcher.is_safe_to_eval(value):
continue

is_mutated = False
module_name = module.__name__.rsplit('.', 1)[-1]
for name in global_variable, f'{module_name}.{global_variable}':
if overwrite or name not in _ALL_MATCHERS:
_ALL_MATCHERS[name] = value
is_mutated = True
if not is_mutated:
raise ValueError(f'Could not add matcher: f{value!r}')


_ALL_MATCHERS = {}
add_module(ast_matchers, overwrite=True)
add_module(base_matchers, overwrite=True)
add_module(lexical_matchers, overwrite=True)
add_module(syntax_matchers, overwrite=True)


def compile_matcher(user_input:str) -> matcher.Matcher:
Expand Down