diff --git a/changelog.md b/changelog.md index 1e333ee4..ca9206be 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ TBD Features -------- * Right-align numeric columns, and make the behavior configurable. +* Add completions for stored procedures. Bug Fixes diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index 1b8ffb07..9be14553 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -155,6 +155,11 @@ def refresh_functions(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_functions(completer.tidb_functions, builtin=True) +@refresher("procedures") +def refresh_procedures(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_procedures(executor.procedures()) + + @refresher("special_commands") def refresh_special(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_special_commands(list(COMMANDS.keys())) diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 87c09790..67f0132d 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -254,6 +254,8 @@ def suggest_based_on_last_token( # We're probably in a function argument list return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v in ("call"): + return [{"type": "procedure", "schema": []}] elif token_v in ("set", "order by", "distinct"): return [{"type": "column", "tables": extract_tables(full_text)}] elif token_v == "as": diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 3d20ffeb..177a5018 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -924,6 +924,14 @@ def extend_functions(self, func_data: list[str] | Generator[tuple[str, str]], bu metadata[self.dbname][func[0]] = None self.all_completions.add(func[0]) + def extend_procedures(self, procedure_data: Generator[tuple[str, str]]) -> None: + metadata = self.dbmetadata["procedures"] + if self.dbname not in metadata: + metadata[self.dbname] = {} + + for elt in procedure_data: + metadata[self.dbname][elt[0]] = None + def set_dbname(self, dbname: str | None) -> None: self.dbname = dbname or '' @@ -932,7 +940,13 @@ def reset_completions(self) -> None: self.users: list[str] = [] self.show_items: list[Completion] = [] self.dbname = "" - self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}, "enum_values": {}} + self.dbmetadata: dict[str, Any] = { + "tables": {}, + "views": {}, + "functions": {}, + "procedures": {}, + "enum_values": {}, + } self.all_completions = set(self.keywords + self.functions) @staticmethod @@ -1093,6 +1107,11 @@ def get_completions( ) completions.extend(predefined_funcs) + elif suggestion["type"] == "procedure": + procs = self.populate_schema_objects(suggestion["schema"], "procedures") + procs_m = self.find_matches(word_before_cursor, procs) + completions.extend(procs_m) + elif suggestion["type"] == "table": tables = self.populate_schema_objects(suggestion["schema"], "tables") tables_m = self.find_matches(word_before_cursor, tables) diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index a25978a1..dcdf3ae7 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -99,6 +99,9 @@ class SQLExecute: functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + procedures_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="PROCEDURE" AND ROUTINE_SCHEMA = "%s"''' + table_columns_query = """select TABLE_NAME, COLUMN_NAME from information_schema.columns where table_schema = '%s' order by table_name,ordinal_position""" @@ -452,6 +455,16 @@ def functions(self) -> Generator[tuple[str, str], None, None]: for row in cur: yield row + def procedures(self) -> Generator[tuple[str, str], None, None]: + """Yields tuples of (procedure_name, )""" + + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Procedures Query. sql: %r", self.procedures_query) + cur.execute(self.procedures_query % self.dbname) + for row in cur: + yield row + def show_candidates(self) -> Generator[tuple, None, None]: assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py index b94db2ce..03583d4b 100644 --- a/test/test_completion_refresher.py +++ b/test/test_completion_refresher.py @@ -29,6 +29,7 @@ def test_ctor(refresher): "enum_values", "users", "functions", + "procedures", "special_commands", "show_commands", "keywords",