From 129dac248ee3f667eb69e72decf15a24091cb548 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Tue, 31 Mar 2026 06:58:04 -0400 Subject: [PATCH] add more tests for parseutils.py covering SQL parsing, though there is some non-SQL as well. Added tests for * extract_from_part() * extract_table_identifiers() * find_prev_keyword() * get_last_select() * is_subselect() * is_valid_connection_scheme() * last_word() * query_is_single_table_update() Also: * adding the tests inspired catching a possible IndexError in query_is_single_table_update(). * removed an unused __main__ section from parseutils.py * accidentally changed some quoting in the tests * test_extract_columns_from_select() seemed to incorrectly catch an exception, which was removed. If tests can raise, that is something that should itself be tested, not covered up. --- mycli/packages/parseutils.py | 24 +-- test/pytests/test_parseutils.py | 289 ++++++++++++++++++++++++++------ 2 files changed, 246 insertions(+), 67 deletions(-) diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 7a2b341f..53b96823 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -379,13 +379,18 @@ def query_is_single_table_update(query: str) -> bool: if not parsed: return False statement = parsed[0] - return ( - statement[0].value.lower() == 'update' - and statement[1].is_whitespace - and ',' not in statement[2].value # multiple tables - and statement[3].is_whitespace - and statement[4].value.lower() == 'set' - ) + try: + retval = bool( + statement[0].value.lower() == 'update' + and statement[1].is_whitespace + and ',' not in statement[2].value # multiple tables + and statement[3].is_whitespace + and statement[4].value.lower() == 'set' + ) + except IndexError: + retval = False + + return retval def is_destructive(keywords: list[str], queries: str) -> bool: @@ -428,8 +433,3 @@ def normalize_db_name(db: str) -> str: if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: result = keywords[0].normalized == "DROP" return result - - -if __name__ == "__main__": - sql = "select * from (select t. from tabl t" - print(extract_tables(sql)) diff --git a/test/pytests/test_parseutils.py b/test/pytests/test_parseutils.py index 13d79b0b..9e9d2ae9 100644 --- a/test/pytests/test_parseutils.py +++ b/test/pytests/test_parseutils.py @@ -1,85 +1,101 @@ # type: ignore import pytest +import sqlparse from mycli.packages.parseutils import ( extract_columns_from_select, + extract_from_part, + extract_table_identifiers, extract_tables, extract_tables_from_complete_statements, + find_prev_keyword, + get_last_select, is_destructive, is_dropping_database, + is_subselect, + is_valid_connection_scheme, + last_word, queries_start_with, query_has_where_clause, + query_is_single_table_update, query_starts_with, ) def test_extract_columns_from_select(): - try: - columns = extract_columns_from_select("SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT FROM INFORMATION_SCHEMA.COLUMNS") - except Exception: - columns = [] - assert columns == ["COLUMN_NAME", "DATA_TYPE", "IS_NULLABLE", "COLUMN_DEFAULT"] + columns = extract_columns_from_select('SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT FROM INFORMATION_SCHEMA.COLUMNS') + assert columns == ['COLUMN_NAME', 'DATA_TYPE', 'IS_NULLABLE', 'COLUMN_DEFAULT'] + + +def test_extract_columns_from_select_empty(): + columns = extract_columns_from_select('') + assert columns == [] + + +def test_extract_columns_from_select_update(): + columns = extract_columns_from_select('UPDATE table SET value = 1 WHERE id = 1') + assert columns == [] def test_empty_string(): - tables = extract_tables("") + tables = extract_tables('') assert tables == [] def test_simple_select_single_table(): - tables = extract_tables("select * from abc") - assert tables == [(None, "abc", None)] + tables = extract_tables('select * from abc') + assert tables == [(None, 'abc', None)] def test_simple_select_single_table_schema_qualified(): - tables = extract_tables("select * from abc.def") - assert tables == [("abc", "def", None)] + tables = extract_tables('select * from abc.def') + assert tables == [('abc', 'def', None)] def test_simple_select_multiple_tables(): - tables = extract_tables("select * from abc, def") - assert sorted(tables) == [(None, "abc", None), (None, "def", None)] + tables = extract_tables('select * from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] def test_simple_select_multiple_tables_schema_qualified(): - tables = extract_tables("select * from abc.def, ghi.jkl") - assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)] + tables = extract_tables('select * from abc.def, ghi.jkl') + assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] def test_simple_select_with_cols_single_table(): - tables = extract_tables("select a,b from abc") - assert tables == [(None, "abc", None)] + tables = extract_tables('select a,b from abc') + assert tables == [(None, 'abc', None)] def test_simple_select_with_cols_single_table_schema_qualified(): - tables = extract_tables("select a,b from abc.def") - assert tables == [("abc", "def", None)] + tables = extract_tables('select a,b from abc.def') + assert tables == [('abc', 'def', None)] def test_simple_select_with_cols_multiple_tables(): - tables = extract_tables("select a,b from abc, def") - assert sorted(tables) == [(None, "abc", None), (None, "def", None)] + tables = extract_tables('select a,b from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] def test_simple_select_with_cols_multiple_tables_with_schema(): - tables = extract_tables("select a,b from abc.def, def.ghi") - assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)] + tables = extract_tables('select a,b from abc.def, def.ghi') + assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] def test_select_with_hanging_comma_single_table(): - tables = extract_tables("select a, from abc") - assert tables == [(None, "abc", None)] + tables = extract_tables('select a, from abc') + assert tables == [(None, 'abc', None)] def test_select_with_hanging_comma_multiple_tables(): - tables = extract_tables("select a, from abc, def") - assert sorted(tables) == [(None, "abc", None), (None, "def", None)] + tables = extract_tables('select a, from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] def test_select_with_hanging_period_multiple_tables(): - tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") - assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")] + tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') + assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] def test_simple_insert_single_table(): @@ -87,73 +103,236 @@ def test_simple_insert_single_table(): # sqlparse mistakenly assigns an alias to the table # assert tables == [(None, 'abc', None)] - assert tables == [(None, "abc", "abc")] + assert tables == [(None, 'abc', 'abc')] def test_simple_insert_single_table_schema_qualified(): tables = extract_tables('insert into abc.def (id, name) values (1, "def")') - assert tables == [("abc", "def", None)] + assert tables == [('abc', 'def', None)] def test_simple_update_table(): - tables = extract_tables("update abc set id = 1") - assert tables == [(None, "abc", None)] + tables = extract_tables('update abc set id = 1') + assert tables == [(None, 'abc', None)] def test_simple_update_table_with_schema(): - tables = extract_tables("update abc.def set id = 1") - assert tables == [("abc", "def", None)] + tables = extract_tables('update abc.def set id = 1') + assert tables == [('abc', 'def', None)] def test_join_table(): - tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num") - assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")] + tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') + assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] def test_join_table_schema_qualified(): - tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") - assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")] + tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') + assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] def test_join_as_table(): - tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") - assert tables == [(None, "my_table", "m")] + tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] def test_extract_tables_from_complete_statements(): - tables = extract_tables_from_complete_statements("SELECT * FROM my_table AS m WHERE m.a > 5") - assert tables == [(None, "my_table", "m")] + tables = extract_tables_from_complete_statements('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] def test_extract_tables_from_complete_statements_cte(): - tables = extract_tables_from_complete_statements("WITH my_cte (id, num) AS ( SELECT id, COUNT(1) FROM my_table GROUP BY id ) SELECT *") - assert tables == [(None, "my_table", None)] + tables = extract_tables_from_complete_statements('WITH my_cte (id, num) AS ( SELECT id, COUNT(1) FROM my_table GROUP BY id ) SELECT *') + assert tables == [(None, 'my_table', None)] # this would confuse plain extract_tables() per #1122 def test_extract_tables_from_multiple_complete_statements(): tables = extract_tables_from_complete_statements(r'\T sql-insert; SELECT * FROM my_table AS m WHERE m.a > 5') - assert tables == [(None, "my_table", "m")] + assert tables == [(None, 'my_table', 'm')] def test_query_starts_with(): - query = "USE test;" - assert query_starts_with(query, ("use",)) is True + query = 'USE test;' + assert query_starts_with(query, ('use',)) is True - query = "DROP DATABASE test;" - assert query_starts_with(query, ("use",)) is False + query = 'DROP DATABASE test;' + assert query_starts_with(query, ('use',)) is False def test_query_starts_with_comment(): - query = "# comment\nUSE test;" - assert query_starts_with(query, ("use",)) is True + query = '# comment\nUSE test;' + assert query_starts_with(query, ('use',)) is True def test_queries_start_with(): - sql = "# comment\nshow databases;use foo;" - assert queries_start_with(sql, ["show", "select"]) is True - assert queries_start_with(sql, ["use", "drop"]) is True - assert queries_start_with(sql, ["delete", "update"]) is False + sql = '# comment\nshow databases;use foo;' + assert queries_start_with(sql, ['show', 'select']) is True + assert queries_start_with(sql, ['use', 'drop']) is True + assert queries_start_with(sql, ['delete', 'update']) is False + + +@pytest.mark.parametrize( + ('text', 'is_valid', 'invalid_scheme'), + [ + ('localhost', False, None), + ('mysql://user@localhost/db', True, None), + ('mysqlx://user@localhost/db', True, None), + ('tcp://localhost:3306', True, None), + ('socket:///tmp/mysql.sock', True, None), + ('ssh://user@example.com', True, None), + ('postgres://user@localhost/db', False, 'postgres'), + ('http://example.com', False, 'http'), + ], +) +def test_is_valid_connection_scheme(text, is_valid, invalid_scheme): + assert is_valid_connection_scheme(text) == (is_valid, invalid_scheme) + + +@pytest.mark.parametrize( + ('text', 'include', 'expected'), + [ + ('abc', 'alphanum_underscore', 'abc'), + (' abc', 'alphanum_underscore', 'abc'), + ('', 'alphanum_underscore', ''), + (' ', 'alphanum_underscore', ''), + ('abc ', 'alphanum_underscore', ''), + ('abc def', 'alphanum_underscore', 'def'), + ('abc def ', 'alphanum_underscore', ''), + ('abc def;', 'alphanum_underscore', ''), + ('bac $def', 'alphanum_underscore', 'def'), + ('bac $def', 'most_punctuations', '$def'), + (r'bac \def', 'most_punctuations', r'\def'), + (r'bac \def;', 'most_punctuations', r'\def;'), + ('bac::def', 'most_punctuations', 'def'), + ('abc:def', 'many_punctuations', 'def'), + ('abc.def', 'all_punctuations', 'abc.def'), + ], +) +def test_last_word(text, include, expected): + assert last_word(text, include=include) == expected + + +def test_is_subselect_returns_false_for_non_group_token(): + token = sqlparse.parse('foo')[0].tokens[0] + assert is_subselect(token) is False + + +def test_is_subselect_returns_false_for_group_without_dml(): + token = sqlparse.parse('(foo)')[0].tokens[0] + assert is_subselect(token) is False + + +def test_is_subselect_returns_true_for_group_with_select(): + token = sqlparse.parse('(select 1)')[0].tokens[0] + assert is_subselect(token) is True + + +def test_get_last_select_returns_empty_token_list_without_select(): + parsed = sqlparse.parse('update t set x = 1')[0] + assert list(get_last_select(parsed).flatten()) == [] + + +def test_get_last_select_returns_single_select_statement(): + parsed = sqlparse.parse('select c1')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c1' + + +def test_get_last_select_returns_single_select_statement_with_from(): + parsed = sqlparse.parse('select c1 from')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c1 from' + + +def test_get_last_select_returns_last_top_level_select(): + parsed = sqlparse.parse('select c1 union select c2')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c2' + + +def test_get_last_select_keeps_outer_select_for_nested_subselect(): + parsed = sqlparse.parse('select c1 from (select c2')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c2' + + +def token_values(tokens): + return [token.value for token in tokens if not getattr(token, 'is_whitespace', False)] + + +# todo: coverage of stop_at_punctuation parameter +def test_extract_from_part_returns_identifier_after_from(): + parsed = sqlparse.parse('select * from abc')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['abc'] + + +def test_extract_from_part_returns_identifier_list(): + parsed = sqlparse.parse('select * from abc, def')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['abc, def'] + + +def test_extract_from_part_handles_multiple_joins_and_skips_on_clause(): + parsed = sqlparse.parse('select * from abc join def on abc.id = def.id join ghi')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['abc', 'join', 'def', 'ghi'] + + +def test_extract_table_identifiers_handles_identifier_list(): + parsed = sqlparse.parse('select * from abc a, def d')[0] + token_stream = extract_from_part(parsed) + assert list(extract_table_identifiers(token_stream)) == [ + (None, 'abc', 'a'), + (None, 'def', 'd'), + ] + + +def test_extract_table_identifiers_handles_schema_qualified_identifier(): + parsed = sqlparse.parse('select * from abc.def x')[0] + token_stream = extract_from_part(parsed) + assert list(extract_table_identifiers(token_stream)) == [('abc', 'def', 'x')] + + +def test_extract_table_identifiers_handles_function_tokens(): + parsed = sqlparse.parse('select * from my_func()')[0] + token_stream = extract_from_part(parsed) + assert list(extract_table_identifiers(token_stream)) == [(None, 'my_func', 'my_func')] + + +@pytest.mark.parametrize( + ('sql', 'expected_keyword', 'expected_text'), + [ + ('', None, ''), + ('foo', None, ''), + ('select * from foo where bar = 1', 'where', 'select * from foo where'), + ('select * from foo where a = 1 and b = 2', 'where', 'select * from foo where'), + ('select * from foo where a between 1 and 2', 'where', 'select * from foo where'), + ('select count(', '(', 'select count('), + ], +) +def test_find_prev_keyword(sql, expected_keyword, expected_text): + token, text = find_prev_keyword(sql) + assert (token.value if token else None) == expected_keyword + assert text == expected_text + + +@pytest.mark.parametrize( + ('sql', 'is_single_table'), + [ + ('update test set x = 1', True), + ('update test t set x = 1', True), + ('update /* inline comment */ test set x = 1', True), + ('select 1', False), + ('', False), + ('update', False), + ('update test, foo set x = 1', False), + ('update test join foo on test.id = foo.id set test.x = 1', False), + ], +) +def test_query_is_single_table_update(sql, is_single_table): + assert query_is_single_table_update(sql) is is_single_table def test_is_destructive():