diff --git a/changelog.md b/changelog.md index cff962b4..15e3eacc 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ TBD ============== +Features +-------- +* "Eager" completions for the `source` command, limited to `*.sql` files. + + Internal -------- * Remove `align_decimals` preprocessor, which had no effect. diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py index 2ef3c166..19368050 100644 --- a/mycli/packages/filepaths.py +++ b/mycli/packages/filepaths.py @@ -19,7 +19,11 @@ def list_path(root_dir: str) -> list[str]: res = [] if os.path.isdir(root_dir): for name in os.listdir(root_dir): - res.append(name) + if os.path.isdir(name): + res.append(f'{name}/') + # if .sql is too restrictive it can be made configurable with some effort + elif name.lower().endswith('.sql'): + res.append(name) return res @@ -69,7 +73,16 @@ def suggest_path(root_dir: str) -> list[str]: """ if not root_dir: - return [os.path.abspath(os.sep), "~", os.curdir, os.pardir] + return [ + os.path.abspath(os.sep), + "~", + os.curdir, + os.pardir, + *list_path(os.curdir), + ] + + if root_dir[0] not in ('/', '~') and root_dir[0:1] != './': + return list_path(os.curdir) if "~" in root_dir: root_dir = os.path.expanduser(root_dir) diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 008e2f46..b9c7b9fc 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -1,5 +1,6 @@ # type: ignore +import os.path from unittest.mock import patch from prompt_toolkit.completion import Completion @@ -589,3 +590,26 @@ def test_create_table_like_completion(completer, complete_event): 'time_zone_leap_second', 'time_zone_transition_type', ] + + +def test_source_eager_completion(completer, complete_event): + text = "source sc" + position = len(text) + script_filename = 'script_for_test_suite.sql' + f = open(script_filename, 'w') + f.close() + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + success = True + error = 'unknown' + try: + assert [x.text for x in result] == [ + 'screenshots/', + script_filename, + ] + except AssertionError as e: + success = False + error = e + if os.path.exists(script_filename): + os.remove(script_filename) + if not success: + raise AssertionError(error)