Skip to content

Commit 991334a

Browse files
authored
Merge pull request #1753 from dbcli/RW/progress-checkpoint-by-statement
Make `--progress` and `--checkpoint` strictly by-statement
2 parents 09f05cb + a0c50b8 commit 991334a

5 files changed

Lines changed: 144 additions & 9 deletions

File tree

changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ Upcoming (TBD)
44
Features
55
---------
66
* Continue to expand TIPS.
7+
* Make `--progress` and `--checkpoint` strictly by statement.
8+
9+
10+
Internal
11+
---------
12+
* Add an `AGENTS.md`.
713

814

915
Internal

mycli/main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,7 +2190,7 @@ class CliArgs:
21902190

21912191
@click.command()
21922192
@clickdc.adddc('cli_args', CliArgs)
2193-
@click.version_option(__version__, '--version', '-V', help='Output mycli\'s version.')
2193+
@click.version_option(__version__, '--version', '-V', help="Output mycli's version.")
21942194
def click_entrypoint(
21952195
cli_args: CliArgs,
21962196
) -> None:
@@ -2658,7 +2658,7 @@ def get_password_from_file(password_file: str | None) -> str | None:
26582658
cli_args.port,
26592659
)
26602660

2661-
# --execute argument
2661+
# --execute argument
26622662
if cli_args.execute:
26632663
if not sys.stdin.isatty():
26642664
click.secho('Ignoring STDIN since --execute was also given.', err=True, fg='red')
@@ -2742,6 +2742,7 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None:
27422742
goal_statements += 1
27432743
batch_count_h.close()
27442744
batch_h = click.open_file(cli_args.batch)
2745+
batch_gen = statements_from_filehandle(batch_h)
27452746
except (OSError, FileNotFoundError):
27462747
click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red')
27472748
sys.exit(1)
@@ -2762,9 +2763,9 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None:
27622763
]
27632764
err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True)
27642765
with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb:
2765-
for pb_counter in pb(range(goal_statements)):
2766-
statement, _untrusted_counter = next(statements_from_filehandle(batch_h))
2767-
dispatch_batch_statements(statement, pb_counter)
2766+
for _pb_counter in pb(range(goal_statements)):
2767+
statement, statement_counter = next(batch_gen)
2768+
dispatch_batch_statements(statement, statement_counter)
27682769
except (ValueError, StopIteration) as e:
27692770
click.secho(str(e), err=True, fg='red')
27702771
sys.exit(1)

mycli/packages/batch_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import IO, Generator
22

33
import sqlglot
4+
import sqlparse
45

56
MAX_MULTILINE_BATCH_STATEMENT = 5000
67

@@ -20,11 +21,16 @@ def statements_from_filehandle(file_h: IO) -> Generator[tuple[str, int], None, N
2021
continue
2122
# we don't yet handle changing the delimiter within the batch input
2223
if tokens[-1].text == ';':
23-
yield (statements, batch_counter)
24-
batch_counter += 1
24+
# The advantage of sqlparse for splitting is that it preserves the input.
25+
# https://github.com/tobymao/sqlglot/issues/2587#issuecomment-1823109501
26+
for statement in sqlparse.split(statements):
27+
yield (statement, batch_counter)
28+
batch_counter += 1
2529
statements = ''
2630
line_counter = 0
2731
except sqlglot.errors.TokenError:
2832
continue
2933
if statements:
30-
yield (statements, batch_counter)
34+
for statement in sqlparse.split(statements):
35+
yield (statement, batch_counter)
36+
batch_counter += 1

test/pytests/test_batch_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# type: ignore
2+
3+
from io import StringIO
4+
5+
import pytest
6+
7+
import mycli.packages.batch_utils
8+
from mycli.packages.batch_utils import statements_from_filehandle
9+
10+
11+
def collect_statements(sql: str) -> list[tuple[str, int]]:
12+
return list(statements_from_filehandle(StringIO(sql)))
13+
14+
15+
def test_statements_from_filehandle_splits_on_statements() -> None:
16+
statements = collect_statements('select 1;\nselect\n 2;\nselect 3; select 4;\n')
17+
18+
assert statements == [
19+
('select 1;', 0),
20+
('select\n 2;', 1),
21+
('select 3;', 2),
22+
('select 4;', 3),
23+
]
24+
25+
26+
def test_statements_from_filehandle_yields_trailing_statement_without_newline_01() -> None:
27+
statements = collect_statements('select 1;\nselect 2;')
28+
29+
assert statements == [
30+
('select 1;', 0),
31+
('select 2;', 1),
32+
]
33+
34+
35+
def test_statements_from_filehandle_yields_trailing_statement_without_newline_02() -> None:
36+
statements = collect_statements('select 1;\nselect 2')
37+
38+
assert statements == [
39+
('select 1;', 0),
40+
('select 2', 1),
41+
]
42+
43+
44+
def test_statements_from_filehandle_yields_trailing_statement_without_newline_03() -> None:
45+
statements = collect_statements('select 1\nwhere 1 == 1;')
46+
47+
assert statements == [('select 1\nwhere 1 == 1;', 0)]
48+
49+
50+
def test_statements_from_filehandle_rejects_overlong_statement(monkeypatch) -> None:
51+
monkeypatch.setattr(mycli.packages.batch_utils, 'MAX_MULTILINE_BATCH_STATEMENT', 2)
52+
53+
with pytest.raises(ValueError, match='Saw single input statement greater than 2 lines'):
54+
list(statements_from_filehandle(StringIO('select 1,\n2\nwhere 1 = 1;')))

test/pytests/test_main.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2139,6 +2139,25 @@ def test_batch_file(monkeypatch):
21392139
os.remove(batch_file.name)
21402140

21412141

2142+
def test_batch_file_no_progress_multiple_statements_per_line(monkeypatch):
2143+
mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch)
2144+
runner = CliRunner()
2145+
2146+
with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file:
2147+
batch_file.write('select 2; select 3;\nselect 4;\n')
2148+
batch_file.flush()
2149+
2150+
try:
2151+
result = runner.invoke(
2152+
mycli_main.click_entrypoint,
2153+
args=['--batch', batch_file.name],
2154+
)
2155+
assert result.exit_code == 0
2156+
assert MockMyCli.ran_queries == ['select 2;', 'select 3;', 'select 4;']
2157+
finally:
2158+
os.remove(batch_file.name)
2159+
2160+
21422161
def test_batch_file_with_progress(monkeypatch):
21432162
mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch)
21442163
runner = CliRunner()
@@ -2182,7 +2201,56 @@ def __call__(self, iterable):
21822201
args=['--batch', batch_file.name, '--progress'],
21832202
)
21842203
assert result.exit_code == 0
2185-
assert MockMyCli.ran_queries == ['select 2;\n', 'select 2;\n', 'select 2;\n']
2204+
assert MockMyCli.ran_queries == ['select 2;', 'select 2;', 'select 2;']
2205+
assert DummyProgressBar.calls == [[0, 1, 2]]
2206+
finally:
2207+
os.remove(batch_file.name)
2208+
2209+
2210+
def test_batch_file_with_progress_multiple_statements_per_line(monkeypatch):
2211+
mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch)
2212+
runner = CliRunner()
2213+
2214+
class DummyProgressBar:
2215+
calls = []
2216+
2217+
def __init__(self, *args, **kwargs):
2218+
pass
2219+
2220+
def __enter__(self):
2221+
return self
2222+
2223+
def __exit__(self, exc_type, exc, tb):
2224+
return False
2225+
2226+
def __call__(self, iterable):
2227+
values = list(iterable)
2228+
DummyProgressBar.calls.append(values)
2229+
return values
2230+
2231+
monkeypatch.setattr(mycli_main, 'ProgressBar', DummyProgressBar)
2232+
monkeypatch.setattr(mycli_main.prompt_toolkit.output, 'create_output', lambda **kwargs: object())
2233+
monkeypatch.setattr(
2234+
mycli_main,
2235+
'sys',
2236+
SimpleNamespace(
2237+
stdin=SimpleNamespace(isatty=lambda: False),
2238+
stderr=SimpleNamespace(isatty=lambda: True),
2239+
exit=sys.exit,
2240+
),
2241+
)
2242+
2243+
with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file:
2244+
batch_file.write('select 2; select 3;\nselect 4;\n')
2245+
batch_file.flush()
2246+
2247+
try:
2248+
result = runner.invoke(
2249+
mycli_main.click_entrypoint,
2250+
args=['--batch', batch_file.name, '--progress'],
2251+
)
2252+
assert result.exit_code == 0
2253+
assert MockMyCli.ran_queries == ['select 2;', 'select 3;', 'select 4;']
21862254
assert DummyProgressBar.calls == [[0, 1, 2]]
21872255
finally:
21882256
os.remove(batch_file.name)

0 commit comments

Comments
 (0)