From 26706703c5da9b8dcb2c491a6d6660a6afc0cb7d Mon Sep 17 00:00:00 2001 From: Devin Jeanpierre Date: Fri, 24 Sep 2021 13:53:49 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 398807521 --- refex/python/evaluate.py | 48 +++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/refex/python/evaluate.py b/refex/python/evaluate.py index e449101..d65a178 100644 --- a/refex/python/evaluate.py +++ b/refex/python/evaluate.py @@ -34,14 +34,11 @@ from refex.python import error_strings from refex.python import matcher -from refex.python import matchers from refex.python import semiliteral_eval -# Actually collect all the matchers into the matchers module, so they can be -# enumerated. -import refex.python.matchers.ast_matchers # pylint: disable=unused-import -import refex.python.matchers.base_matchers # pylint: disable=unused-import -import refex.python.matchers.lexical_matchers # pylint: disable=unused-import -import refex.python.matchers.syntax_matchers # pylint: disable=unused-import +from refex.python.matchers import ast_matchers +from refex.python.matchers import base_matchers +from refex.python.matchers import lexical_matchers +from refex.python.matchers import syntax_matchers def _sorted_attributes(o): @@ -50,21 +47,32 @@ def _sorted_attributes(o): yield a, getattr(o, a) -def _get_matcher_map(): - """Get the mapping from dotted-names to matcher constructors/callables.""" - mapping = {} - for module_name, module in _sorted_attributes(matchers): - for global_variable, value in _sorted_attributes(module): - if not isinstance(value, type): - continue - if not matcher.is_safe_to_eval(value): - continue - mapping[global_variable] = value - mapping['%s.%s' % (module_name, global_variable)] = value - return mapping +# TODO: remove overwrite param -_ALL_MATCHERS = _get_matcher_map() +def add_module(module, overwrite=False): + """Adds a non-builtin matcher module to be available for compile_matcher.""" + for global_variable, value in _sorted_attributes(module): + if not isinstance(value, type): + continue + if not matcher.is_safe_to_eval(value): + continue + + is_mutated = False + module_name = module.__name__.rsplit('.', 1)[-1] + for name in global_variable, f'{module_name}.{global_variable}': + if overwrite or name not in _ALL_MATCHERS: + _ALL_MATCHERS[name] = value + is_mutated = True + if not is_mutated: + raise ValueError(f'Could not add matcher: f{value!r}') + + +_ALL_MATCHERS = {} +add_module(ast_matchers, overwrite=True) +add_module(base_matchers, overwrite=True) +add_module(lexical_matchers, overwrite=True) +add_module(syntax_matchers, overwrite=True) def compile_matcher(user_input:str) -> matcher.Matcher: