Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Features
* "Eager" completions for the `source` command, limited to `*.sql` files.


Bug Fixes
--------
* Refactor completions for special commands, with minor casing fixes.


Internal
--------
* Remove `align_decimals` preprocessor, which had no effect.
Expand Down
20 changes: 15 additions & 5 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlparse.sql import Comparison, Identifier, Token, Where

from mycli.packages.parseutils import extract_tables, find_prev_keyword, last_word
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.packages.special.main import parse_special_command

sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment]
Expand Down Expand Up @@ -126,8 +127,12 @@ def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]
# Be careful here because trivial whitespace is parsed as a statement,
# but the statement won't have a first token
tok1 = statement.token_first()
if tok1 and (tok1.value == "source" or tok1.value.startswith("\\")):
# lenient because \. will parse as two tokens
if tok1 and tok1.value.startswith('\\'):
return suggest_special(text_before_cursor)
elif tok1:
if tok1.value.lower() in SPECIAL_COMMANDS:
return suggest_special(text_before_cursor)

last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""

Expand All @@ -146,9 +151,15 @@ def suggest_special(text: str) -> list[dict[str, Any]]:
if cmd in ("\\u", "\\r"):
return [{"type": "database"}]

if cmd.lower() in ('use', 'connect'):
return [{'type': 'database'}]

if cmd in (r'\T', r'\Tr'):
return [{"type": "table_format"}]

if cmd.lower() in ('tableformat', 'redirectformat'):
return [{"type": "table_format"}]

if cmd in ["\\f", "\\fs", "\\fd"]:
return [{"type": "favoritequery"}]

Expand All @@ -158,7 +169,7 @@ def suggest_special(text: str) -> list[dict[str, Any]]:
{"type": "view", "schema": []},
{"type": "schema"},
]
elif cmd in ["\\.", "source"]:
elif cmd.lower() in ["\\.", "source"]:
return [{"type": "file_name"}]
if cmd in ["\\llm", "\\ai"]:
return [{"type": "llm"}]
Expand Down Expand Up @@ -350,12 +361,11 @@ def suggest_based_on_last_token(
suggest.append({"type": "table", "schema": parent})
return suggest

elif token_v in ("use", "database", "template", "connect"):
elif token_v in ("database", "template"):
# "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return [{"type": "database"}]
elif token_v in ("tableformat", "redirectformat"):
return [{"type": "table_format"}]

elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
original_text = text_before_cursor
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
Expand Down
10 changes: 10 additions & 0 deletions test/test_completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from mycli.packages import special
from mycli.packages.completion_engine import suggest_type


Expand Down Expand Up @@ -538,6 +539,13 @@ def test_specials_included_for_initial_completion(initial_text):
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}])


@pytest.mark.parametrize('initial_text', ['REDIRECT'])
def test_specials_included_with_caps(initial_text):
suggestions = suggest_type(initial_text, initial_text)

assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])


def test_specials_not_included_after_initial_token():
suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")

Expand Down Expand Up @@ -593,6 +601,8 @@ def test_after_as(expression):
],
)
def test_source_is_file(expression):
# "source" has to be registered by hand because that usually happens inside MyCLI in mycli/main.py
special.register_special_command(..., 'source', '\\. filename', 'Execute commands from file.', aliases=['\\.'])
suggestions = suggest_type(expression, expression)
assert suggestions == [{"type": "file_name"}]

Expand Down