diff --git a/changelog.md b/changelog.md index 15e3eacc..da4af9a7 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,11 @@ Features * "Eager" completions for the `source` command, limited to `*.sql` files. +Bug Fixes +-------- +* Refactor completions for special commands, with minor casing fixes. + + Internal -------- * Remove `align_decimals` preprocessor, which had no effect. diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 1efd55d0..989ecd93 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -5,6 +5,7 @@ from sqlparse.sql import Comparison, Identifier, Token, Where from mycli.packages.parseutils import extract_tables, find_prev_keyword, last_word +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.special.main import parse_special_command sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] @@ -126,8 +127,12 @@ def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any] # Be careful here because trivial whitespace is parsed as a statement, # but the statement won't have a first token tok1 = statement.token_first() - if tok1 and (tok1.value == "source" or tok1.value.startswith("\\")): + # lenient because \. will parse as two tokens + if tok1 and tok1.value.startswith('\\'): return suggest_special(text_before_cursor) + elif tok1: + if tok1.value.lower() in SPECIAL_COMMANDS: + return suggest_special(text_before_cursor) last_token = statement and statement.token_prev(len(statement.tokens))[1] or "" @@ -146,9 +151,15 @@ def suggest_special(text: str) -> list[dict[str, Any]]: if cmd in ("\\u", "\\r"): return [{"type": "database"}] + if cmd.lower() in ('use', 'connect'): + return [{'type': 'database'}] + if cmd in (r'\T', r'\Tr'): return [{"type": "table_format"}] + if cmd.lower() in ('tableformat', 'redirectformat'): + return [{"type": "table_format"}] + if cmd in ["\\f", "\\fs", "\\fd"]: return [{"type": "favoritequery"}] @@ -158,7 +169,7 @@ def suggest_special(text: str) -> list[dict[str, Any]]: {"type": "view", "schema": []}, {"type": "schema"}, ] - elif cmd in ["\\.", "source"]: + elif cmd.lower() in ["\\.", "source"]: return [{"type": "file_name"}] if cmd in ["\\llm", "\\ai"]: return [{"type": "llm"}] @@ -350,12 +361,11 @@ def suggest_based_on_last_token( suggest.append({"type": "table", "schema": parent}) return suggest - elif token_v in ("use", "database", "template", "connect"): + elif token_v in ("database", "template"): # "\c ", "DROP DATABASE ", # "CREATE DATABASE WITH TEMPLATE " return [{"type": "database"}] - elif token_v in ("tableformat", "redirectformat"): - return [{"type": "table_format"}] + elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]: original_text = text_before_cursor prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py index 71d4692d..0528d05a 100644 --- a/test/test_completion_engine.py +++ b/test/test_completion_engine.py @@ -2,6 +2,7 @@ import pytest +from mycli.packages import special from mycli.packages.completion_engine import suggest_type @@ -538,6 +539,13 @@ def test_specials_included_for_initial_completion(initial_text): assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}]) +@pytest.mark.parametrize('initial_text', ['REDIRECT']) +def test_specials_included_with_caps(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) + + def test_specials_not_included_after_initial_token(): suggestions = suggest_type("create table foo (dt d", "create table foo (dt d") @@ -593,6 +601,8 @@ def test_after_as(expression): ], ) def test_source_is_file(expression): + # "source" has to be registered by hand because that usually happens inside MyCLI in mycli/main.py + special.register_special_command(..., 'source', '\\. filename', 'Execute commands from file.', aliases=['\\.']) suggestions = suggest_type(expression, expression) assert suggestions == [{"type": "file_name"}]