Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Features
---------
* Respond to `-h` alone with the helpdoc.
* Allow `--hostname` as an alias for `--host`.
* Suggest tables with foreign key relationships for JOIN and ON (#975)
* Deprecate `$DSN` environment variable in favor of `$MYSQL_DSN`.


Expand Down
5 changes: 5 additions & 0 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_columns(table_columns_dbresult, kind="tables")


@refresher("foreign_keys")
def refresh_foreign_keys(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_foreign_keys(executor.foreign_keys())


@refresher("enum_values")
def refresh_enum_values(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_enum_values(executor.enum_values())
Expand Down
8 changes: 6 additions & 2 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,14 @@ def suggest_based_on_last_token(
or (token_v == "like" and re.match(r'^\s*create\s+table\s', full_text, re.IGNORECASE))
):
schema = (identifier and identifier.get_parent_name()) or []
is_join = token_v.endswith("join")

# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = [{"type": "table", "schema": schema}]
table_suggestion: dict[str, Any] = {"type": "table", "schema": schema}
if is_join:
table_suggestion["join"] = True
suggest = [table_suggestion]

if not schema:
# Suggest schemas
Expand Down Expand Up @@ -516,7 +520,7 @@ def suggest_based_on_last_token(
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = [alias or table for (schema, table, alias) in tables]
suggest = [{"type": "alias", "aliases": aliases}]
suggest = [{"type": "fk_join", "tables": tables}, {"type": "alias", "aliases": aliases}]

# The lists of 'aliases' could be empty if we're trying to complete
# a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
Expand Down
96 changes: 89 additions & 7 deletions mycli/sqlcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from mycli.packages.completion_engine import is_inside_quotes, suggest_type
from mycli.packages.filepaths import complete_path, parse_path, suggest_path
from mycli.packages.parseutils import extract_columns_from_select, last_word
from mycli.packages.parseutils import extract_columns_from_select, extract_tables, last_word
from mycli.packages.special import llm
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
Expand Down Expand Up @@ -1052,6 +1052,51 @@ def extend_enum_values(self, enum_data: Iterable[tuple[str, str, list[str]]]) ->
table_meta = metadata[self.dbname].setdefault(relname_escaped, {})
table_meta[column_escaped] = values

def extend_foreign_keys(self, fk_data: Iterable[tuple[str, str, str, str]]) -> None:
"""Extend FK metadata.

:param fk_data: iterable of (table_name, column_name, referenced_table_name, referenced_column_name)
"""
metadata = self.dbmetadata["foreign_keys"]
schema_meta = metadata.setdefault(self.dbname, {})
schema_meta.setdefault("tables", {})
schema_meta.setdefault("relations", [])
for table, col, ref_table, ref_col in fk_data:
table = self.escape_name(table)
col = self.escape_name(col)
ref_table = self.escape_name(ref_table)
ref_col = self.escape_name(ref_col)
schema_meta["tables"].setdefault(table, set()).add(ref_table)
schema_meta["tables"].setdefault(ref_table, set()).add(table)
schema_meta["relations"].append((table, col, ref_table, ref_col))

def _fk_join_conditions(self, tables: list[tuple[str | None, str, str]]) -> list[str]:
"""Return FK-based join condition strings for the tables currently in the query.

For each FK relation where both the FK table and the referenced table appear in
*tables*, yields a string like ``alias1.col = alias2.ref_col`` (using the alias
when one exists, otherwise the table name).
"""
schema_meta = self.dbmetadata["foreign_keys"].get(self.dbname, {})
relations = schema_meta.get("relations", [])

# Map escaped table name -> alias (or table name when no alias).
# Skip tables from a different schema; we only have FK metadata for the current db.
alias_map: dict[str, str] = {}
for tbl_schema, tbl, alias in tables:
if tbl_schema and tbl_schema != self.dbname:
continue
escaped = self.escape_name(tbl)
alias_map[escaped] = alias or tbl

conditions: list[str] = []
for fk_table, fk_col, ref_table, ref_col in relations:
lhs = alias_map.get(fk_table)
rhs = alias_map.get(ref_table)
if lhs and rhs:
conditions.append(f"{lhs}.{fk_col} = {rhs}.{ref_col}")
return conditions

def extend_functions(self, func_data: list[str] | Generator[tuple[str, str]], builtin: bool = False) -> None:
# if 'builtin' is set this is extending the list of builtin functions
if builtin:
Expand Down Expand Up @@ -1124,6 +1169,7 @@ def reset_completions(self) -> None:
"functions": {},
"procedures": {},
"enum_values": {},
"foreign_keys": {},
}
self.all_completions = set(self.keywords + self.functions)

Expand Down Expand Up @@ -1366,12 +1412,39 @@ def get_completions(
tables = self.populate_schema_objects(suggestion["schema"], "tables", columns)
else:
tables = self.populate_schema_objects(suggestion["schema"], "tables")
tables_m = self.find_matches(
word_before_cursor,
tables,
text_before_cursor=document.text_before_cursor,
)
completions.extend([(*x, rank) for x in tables_m])

if suggestion.get("join"):
# For JOINs, suggest FK-related tables first (lower rank = higher priority)
current_tables = extract_tables(document.text)
fk_map = self.dbmetadata["foreign_keys"].get(self.dbname, {}).get("tables", {})
fk_related: set[str] = set()
for tbl_schema, tbl, _alias in current_tables:
# Skip cross-schema tables; FK metadata is only for the current db
if tbl_schema and tbl_schema != self.dbname:
continue
escaped = self.escape_name(tbl)
fk_related.update(fk_map.get(escaped, set()))
fk_tables = [t for t in tables if t in fk_related]
other_tables = [t for t in tables if t not in fk_related]
fk_tables_m = self.find_matches(
word_before_cursor,
fk_tables,
text_before_cursor=document.text_before_cursor,
)
other_tables_m = self.find_matches(
word_before_cursor,
other_tables,
text_before_cursor=document.text_before_cursor,
)
completions.extend([(*x, rank) for x in fk_tables_m])
completions.extend([(*x, rank + 1) for x in other_tables_m])
else:
tables_m = self.find_matches(
word_before_cursor,
tables,
text_before_cursor=document.text_before_cursor,
)
completions.extend([(*x, rank) for x in tables_m])

elif suggestion["type"] == "view":
views = self.populate_schema_objects(suggestion["schema"], "views")
Expand All @@ -1382,6 +1455,15 @@ def get_completions(
)
completions.extend([(*x, rank) for x in views_m])

elif suggestion["type"] == "fk_join":
fk_conditions = self._fk_join_conditions(suggestion["tables"])
fk_conditions_m = self.find_matches(
word_before_cursor,
fk_conditions,
text_before_cursor=document.text_before_cursor,
)
completions.extend([(*x, rank) for x in fk_conditions_m])

elif suggestion["type"] == "alias":
aliases = suggestion["aliases"]
aliases_m = self.find_matches(
Expand Down
15 changes: 15 additions & 0 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ class SQLExecute:
where table_schema = %s and data_type = 'enum'
order by table_name,ordinal_position"""

foreign_keys_query = """SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
FROM information_schema.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL"""

now_query = """SELECT NOW()"""

@staticmethod
Expand Down Expand Up @@ -440,6 +444,17 @@ def enum_values(self) -> Generator[tuple[str, str, list[str]], None, None]:
if values:
yield (table_name, column_name, values)

def foreign_keys(self) -> Generator[tuple[str, str, str, str], None, None]:
"""Yields (table_name, column_name, referenced_table_name, referenced_column_name) tuples"""
assert isinstance(self.conn, Connection)
with self.conn.cursor() as cur:
_logger.debug("Foreign Keys Query. sql: %r", self.foreign_keys_query)
try:
cur.execute(self.foreign_keys_query, (self.dbname,))
yield from cur
except Exception as e:
_logger.error('No foreign key completions due to %r', e)

def databases(self) -> list[str]:
assert isinstance(self.conn, Connection)
with self.conn.cursor() as cur:
Expand Down
45 changes: 37 additions & 8 deletions test/pytests/test_completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def test_select_suggests_cols_and_funcs():
"DESCRIBE ",
"DESC ",
"EXPLAIN ",
"SELECT * FROM foo JOIN ",
],
)
def test_expression_suggests_tables_views_and_schemas(expression):
Expand All @@ -179,6 +178,16 @@ def test_expression_suggests_tables_views_and_schemas(expression):
])


def test_join_expression_suggests_tables_views_and_schemas():
expression = "SELECT * FROM foo JOIN "
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{"type": "table", "schema": [], "join": True},
{"type": "view", "schema": []},
{"type": "database"},
])


@pytest.mark.parametrize(
"expression",
[
Expand All @@ -189,7 +198,6 @@ def test_expression_suggests_tables_views_and_schemas(expression):
"DESCRIBE sch.",
"DESC sch.",
"EXPLAIN sch.",
"SELECT * FROM foo JOIN sch.",
],
)
def test_expression_suggests_qualified_tables_views_and_schemas(expression):
Expand All @@ -200,6 +208,15 @@ def test_expression_suggests_qualified_tables_views_and_schemas(expression):
])


def test_join_expression_suggests_qualified_tables_views_and_schemas():
expression = "SELECT * FROM foo JOIN sch."
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{"type": "table", "schema": "sch", "join": True},
{"type": "view", "schema": "sch"},
])


def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
assert sorted_dicts(suggestions) == sorted_dicts([
Expand Down Expand Up @@ -395,7 +412,7 @@ def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
suggestion = suggest_type(text, text)
assert sorted_dicts(suggestion) == sorted_dicts([
{"type": "database"},
{"type": "table", "schema": []},
{"type": "table", "schema": [], "join": True},
{"type": "view", "schema": []},
])

Expand Down Expand Up @@ -445,7 +462,10 @@ def test_join_alias_dot_suggests_cols2(sql):
)
def test_on_suggests_aliases(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
assert suggestions == [
{"type": "fk_join", "tables": [(None, "abc", "a"), (None, "bcd", "b")]},
{"type": "alias", "aliases": ["a", "b"]},
]


@pytest.mark.parametrize(
Expand All @@ -457,7 +477,10 @@ def test_on_suggests_aliases(sql):
)
def test_on_suggests_tables(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
assert suggestions == [
{"type": "fk_join", "tables": [(None, "abc", None), (None, "bcd", None)]},
{"type": "alias", "aliases": ["abc", "bcd"]},
]


@pytest.mark.parametrize(
Expand All @@ -469,7 +492,10 @@ def test_on_suggests_tables(sql):
)
def test_on_suggests_aliases_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
assert suggestions == [
{"type": "fk_join", "tables": [(None, "abc", "a"), (None, "bcd", "b")]},
{"type": "alias", "aliases": ["a", "b"]},
]


@pytest.mark.parametrize(
Expand All @@ -481,7 +507,10 @@ def test_on_suggests_aliases_right_side(sql):
)
def test_on_suggests_tables_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
assert suggestions == [
{"type": "fk_join", "tables": [(None, "abc", None), (None, "bcd", None)]},
{"type": "alias", "aliases": ["abc", "bcd"]},
]


@pytest.mark.parametrize("col_list", ["", "col1, "])
Expand Down Expand Up @@ -610,7 +639,7 @@ def test_cross_join():
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{"type": "database"},
{"type": "table", "schema": []},
{"type": "table", "schema": [], "join": True},
{"type": "view", "schema": []},
])

Expand Down
1 change: 1 addition & 0 deletions test/pytests/test_completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_ctor(refresher):
"databases",
"schemata",
"tables",
"foreign_keys",
"enum_values",
"users",
"functions",
Expand Down
Loading
Loading