diff --git a/changelog.md b/changelog.md index 3f928cde..2c433c35 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,7 @@ Features -------- * Right-align numeric columns, and make the behavior configurable. * Add completions for stored procedures. +* Escape database completions. * Offer completions on `CREATE TABLE ... LIKE`. diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 177a5018..1b6c0e06 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -814,7 +814,7 @@ def extend_special_commands(self, special_commands: list[str]) -> None: self.special_commands.extend(special_commands) def extend_database_names(self, databases: list[str]) -> None: - self.databases.extend(databases) + self.databases.extend([self.escape_name(db) for db in databases]) def extend_keywords(self, keywords: list[str], replace: bool = False) -> None: if replace: diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 6cd857b9..008e2f46 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -33,8 +33,12 @@ def completer(): tables.append((table,)) columns.extend([(table, col) for col in cols]) + databases = ["test", "test 2"] + + for db in databases: + comp.extend_schemata(db) + comp.extend_database_names(databases) comp.set_dbname("test") - comp.extend_schemata("test") comp.extend_relations(tables, kind="tables") comp.extend_columns(columns, kind="tables") comp.extend_enum_values([("orders", "status", ["pending", "shipped"])]) @@ -50,6 +54,16 @@ def complete_event(): return Mock() +def test_use_database_completion(completer, complete_event): + text = "USE " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + def test_special_name_completion(completer, complete_event): text = "\\d" position = len("\\d") @@ -101,6 +115,8 @@ def test_table_completion(completer, complete_event): Completion(text="time_zone_name", start_position=0), Completion(text="time_zone_transition", start_position=0), Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), ] @@ -400,6 +416,8 @@ def test_table_names_after_from(completer, complete_event): Completion(text="time_zone_name", start_position=0), Completion(text="time_zone_transition", start_position=0), Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), ] @@ -474,6 +492,8 @@ def test_grant_on_suggets_tables_and_schemata(completer, complete_event): position = len(text) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == [ + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), Completion(text='users', start_position=0), Completion(text='orders', start_position=0), Completion(text='`select`', start_position=0),