From 0a9bc666442a7010c42cedc7707773636bbbb8d2 Mon Sep 17 00:00:00 2001 From: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com> Date: Sat, 13 Jun 2026 01:03:25 +0200 Subject: [PATCH] feat(tasks)!: remove the tasks subpackage; relocate ENTSO-E CLI to spotforecast2.entsoe_cli The last tasks module, task_entsoe, is dissolved: - entsoe_data_loader / entsoe_test_data_loader now live in spotforecast2_safe.data.entsoe_loader (safe 22.3.0); the pin floor rises to >=22.3.0. - entsoe_lgbm_factory was body-identical to spotforecast2_safe.multitask.factories.default_lgbm_forecaster_factory and is dropped in favour of the safe original. - The CLI (main, entsoe_xgb_factory, pipeline plumbing) moves to spotforecast2.entsoe_cli; the spotforecast2-entsoe console script is unchanged in name and behaviour. - bart26k-lecture migrated its 8 import sites ahead of this change (commit 1df90d3 there); the team4 consumer-contract gate passes 4/4 against the updated qmd. - docs/tasks/entsoe.qmd loses two sections that documented the legacy ForecasterRecursiveLGBM/XGB wrapper API removed back in ADR-002; they are rewritten around the current factory + MultiTask pipeline. BREAKING CHANGE: the spotforecast2.tasks subpackage is removed. Import the data loaders from spotforecast2_safe.data.entsoe_loader, the LightGBM factory from spotforecast2_safe.multitask.factories, and the CLI module as spotforecast2.entsoe_cli. The spotforecast2-entsoe console script is unaffected. Co-Authored-By: Claude Fable 5 --- _quarto.yml | 12 +- docs/multitask/entsoe.qmd | 4 +- docs/tasks/entsoe.qmd | 133 ++++++------ pyproject.toml | 10 +- .../{tasks/task_entsoe.py => entsoe_cli.py} | 192 ++---------------- src/spotforecast2/tasks/__init__.py | 2 - ..._cadence.py => test_entsoe_cli_cadence.py} | 30 +-- ..._predict.py => test_entsoe_cli_predict.py} | 17 +- ...tsoe_train.py => test_entsoe_cli_train.py} | 21 +- tests/test_entsoe_pipeline.py | 8 +- uv.lock | 10 +- 11 files changed, 126 insertions(+), 313 deletions(-) rename src/spotforecast2/{tasks/task_entsoe.py => entsoe_cli.py} (57%) delete mode 100644 src/spotforecast2/tasks/__init__.py rename tests/{test_task_entsoe_cadence.py => test_entsoe_cli_cadence.py} (70%) rename tests/{test_tasks_entsoe_predict.py => test_entsoe_cli_predict.py} (76%) rename tests/{test_tasks_entsoe_train.py => test_entsoe_cli_train.py} (78%) diff --git a/_quarto.yml b/_quarto.yml index 69b46340..5af11cd4 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -130,10 +130,10 @@ website: contents: - text: "autocorrelation" file: docs/reference/stats.autocorrelation.qmd - - section: "Tasks" + - section: "ENTSO-E CLI" contents: - - text: "task_entsoe" - file: docs/reference/tasks.task_entsoe.qmd + - text: "entsoe_cli" + file: docs/reference/entsoe_cli.qmd - section: "Processing Guides" contents: @@ -251,10 +251,10 @@ quartodoc: contents: - stats.autocorrelation - - title: "Tasks" - desc: "ENTSO-E forecasting task and CLI." + - title: "ENTSO-E CLI" + desc: "ENTSO-E download / train / predict command-line pipeline." contents: - - tasks.task_entsoe + - entsoe_cli - title: "Warnings" desc: "Warning-style configuration for spotforecast2." diff --git a/docs/multitask/entsoe.qmd b/docs/multitask/entsoe.qmd index 396c3d50..91a26c43 100644 --- a/docs/multitask/entsoe.qmd +++ b/docs/multitask/entsoe.qmd @@ -30,9 +30,11 @@ import pandas as pd warnings.filterwarnings("ignore") from spotforecast2_safe.configurator import ConfigEntsoe +from spotforecast2_safe.multitask.factories import ( + default_lgbm_forecaster_factory as entsoe_lgbm_factory, +) from spotforecast2.multitask.multi import MultiTask -from spotforecast2.tasks.task_entsoe import entsoe_lgbm_factory CACHE_HOME = tempfile.mkdtemp() ``` diff --git a/docs/tasks/entsoe.qmd b/docs/tasks/entsoe.qmd index 3e90372a..9b7bde4a 100644 --- a/docs/tasks/entsoe.qmd +++ b/docs/tasks/entsoe.qmd @@ -275,107 +275,94 @@ print(ts_clean.values) # [1.0, 2.0, 3.0, 4.0, 5.0] ## Forecaster Models +The pipeline builds its forecasters through *factory* functions that take a +`ConfigEntsoe` and return a ready-to-fit +`spotforecast2_safe.forecaster.recursive.ForecasterRecursive`. + ### LightGBM Forecaster -Create a LightGBM-based recursive forecaster: +The stock LightGBM factory lives in the safe package: -```python -from spotforecast2.tasks.task_entsoe import ForecasterRecursiveLGBM, config +```{python} +from spotforecast2_safe.configurator import ConfigEntsoe +from spotforecast2_safe.multitask.factories import default_lgbm_forecaster_factory -model = ForecasterRecursiveLGBM(iteration=1) +config = ConfigEntsoe() +forecaster = default_lgbm_forecaster_factory(config) -print(model.name) # 'lgbm' -print(model.random_state) # 314159 (from config) -print(len(model.preprocessor.periods)) # 5 (from config) +print(type(forecaster).__name__) +print("lags:", len(forecaster.lags)) ``` ### XGBoost Forecaster -Create an XGBoost-based recursive forecaster: - -```python -from spotforecast2.tasks.task_entsoe import ForecasterRecursiveXGB, config - -model = ForecasterRecursiveXGB(iteration=1, lags=24) +The XGBoost variant ships with the CLI module (xgboost is not a +safe-package dependency): -print(model.name) # 'xgb' -``` - -### Custom Configuration Forecaster - -Override default configuration values: - -```python -from spotforecast2.tasks.task_entsoe import ForecasterRecursiveLGBM -from spotforecast2_safe.data import Period +```{python} +from spotforecast2_safe.configurator import ConfigEntsoe -custom_periods = [ - Period(name='hourly', n_periods=24, column='hour', input_range=(1, 24)), -] +from spotforecast2.entsoe_cli import entsoe_xgb_factory -model = ForecasterRecursiveLGBM( - iteration=1, - lags=48, - periods=custom_periods, - country_code='FR', - random_state=42 -) +config = ConfigEntsoe() +forecaster = entsoe_xgb_factory(config) -print(len(model.preprocessor.periods)) # 1 -print(model.preprocessor.country_code) # 'FR' +print(type(forecaster.estimator).__name__) ``` ---- +Both factories honour `config.random_state`, `config.lags_consider`, and +`config.window_size`; supply your own factory through +`config.forecaster_factory` to customise further (see the Multitask +tutorial). ## Using the Python API (Notebooks & Quarto) ### Full Prediction Pipeline -For users working in Jupyter Notebooks or Quarto, the entire ENTSO-E pipeline can be executed using the Python API. This approach is highly recommended for safety-critical research as it allows for precise control over time windows and hyperparameters. +For users working in Jupyter Notebooks or Quarto, the entire ENTSO-E pipeline +can be executed through `MultiTask` — the same path the CLI's `train` and +`predict` subcommands take. This approach is recommended for research as it +gives precise control over time windows and hyperparameters. ```python -import pandas as pd +import logging import os + +from spotforecast2_safe.configurator import ConfigEntsoe +from spotforecast2_safe.data.entsoe_loader import entsoe_data_loader +from spotforecast2_safe.data.fetch_data import get_cache_home from spotforecast2_safe.downloader.entsoe import download_new_data -from spotforecast2_safe.manager.trainer import handle_training as handle_training_safe -from spotforecast2_safe.manager.predictor import get_model_prediction as get_model_prediction_safe -from spotforecast2.plots.plotter import make_plot -from spotforecast2.tasks.task_entsoe import ForecasterRecursiveLGBM - -# 1. Setup Time Windows (Last 3 years until last month) and country: -country_code = "ES" -now = pd.Timestamp.now(tz='UTC').floor('D') -current_month_start = now.replace(day=1) -last_month_start = (current_month_start - pd.Timedelta(days=1)).replace(day=1) - -# 2. Download Data (Optional, requires ENTSOE_API_KEY) +from spotforecast2_safe.multitask.factories import default_lgbm_forecaster_factory + +from spotforecast2.multitask import MultiTask + +# 1. Download data (optional, requires ENTSOE_API_KEY) api_key = os.environ.get("ENTSOE_API_KEY") if api_key: - download_new_data(api_key=api_key, start="202301010000", country_code=country_code) - -# 3. Configure and Train -# Explicit parameters override global configuration for reproducibility -model_class = ForecasterRecursiveLGBM -model_name = "lgbm_advanced" - -handle_training_safe( - model_class=model_class, - model_name=model_name, - train_size=pd.Timedelta(days=3 * 365), - end_dev=last_month_start.strftime("%Y-%m-%d %H:%M%z"), - country_code=country_code -) + download_new_data(api_key=api_key, start="202301010000") -# 4. Generate Predictions for the forecast horizon -# The predictor will automatically load the model trained above -predictions = get_model_prediction_safe( - model_name=model_name, - predict_size=24 * 31 +# 2. Wire the loader and factory into the config +config = ConfigEntsoe() +config.targets = ["Actual Load"] +config.agg_weights = [1.0] +config.bounds = [(-1e9, 1e9)] +config.data_loader = entsoe_data_loader +config.forecaster_factory = default_lgbm_forecaster_factory +config.data_frame_name = "entsoe-lgbm" + +# 3. Run the five-step pipeline (task="defaults" trains; "predict" reuses +# the saved model) +mt = MultiTask( + config, + task="defaults", + cache_home=get_cache_home(config.cache_home), + log_level=logging.ERROR, ) - -# 5. Visualize Results -if predictions: - make_plot(predictions) +mt.prepare_data() +mt.detect_outliers() +mt.impute() +mt.build_exogenous_features() +mt.run(show=True) ``` --- diff --git a/pyproject.toml b/pyproject.toml index ad25cd56..72c4ef2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,16 +34,16 @@ dependencies = [ # 22.0.0 made warm_start_lags the seed lag list itself (default # DEFAULT_WARM_START_LAGS, None disables) — consumed by SpotOptimStrategy. # 21.2.0 added max_time_spotoptim, forwarded as SpotOptim's max_time. - "spotforecast2-safe>=22.0.0,<23", + "spotforecast2-safe>=22.3.0,<23", # spotoptim 1.0 is sequential-only and lean: torch/tensorboard moved to its # ``[torch]`` extra. sf2 forwards tensorboard_* kwargs into SpotOptim, so we # pin the extra to keep the TensorBoard tuning dashboards working (they were # always available via spotoptim's old hard torch dependency). "spotoptim[torch]>=1.0.0,<2", "tqdm>=4.67.2", - # Directly imported by spotforecast2.tasks.task_entsoe (and the xgb - # forecaster model). Previously satisfied transitively; declared explicitly - # now that spotoptim 1.0 no longer pulls it in. + # Directly imported by spotforecast2.entsoe_cli (and the xgb forecaster + # model). Previously satisfied transitively; declared explicitly now that + # spotoptim 1.0 no longer pulls it in. "xgboost>=3.2.0", ] @@ -74,7 +74,7 @@ dev = [ ] [project.scripts] -spotforecast2-entsoe = "spotforecast2.tasks.task_entsoe:main" +spotforecast2-entsoe = "spotforecast2.entsoe_cli:main" [tool.uv] # Accept pre-releases ONLY for dependencies whose specifier carries an explicit diff --git a/src/spotforecast2/tasks/task_entsoe.py b/src/spotforecast2/entsoe_cli.py similarity index 57% rename from src/spotforecast2/tasks/task_entsoe.py rename to src/spotforecast2/entsoe_cli.py index d989a3d2..fb938fa3 100644 --- a/src/spotforecast2/tasks/task_entsoe.py +++ b/src/spotforecast2/entsoe_cli.py @@ -5,7 +5,10 @@ Drives the pipeline directly through `spotforecast2.multitask.multi.MultiTask` with `ConfigEntsoe` plugged in via the `data_loader` and `forecaster_factory` -hooks introduced in ADR-001. The CLI exposes four subcommands: +hooks introduced in ADR-001. The data loaders live in +`spotforecast2_safe.data.entsoe_loader`; the LightGBM factory is the stock +`spotforecast2_safe.multitask.factories.default_lgbm_forecaster_factory`. +The CLI exposes four subcommands: - ``download`` — fetch raw data from the ENTSO-E Transparency Platform. - ``merge`` — concatenate raw CSVs into the interim merged file. @@ -30,12 +33,13 @@ from typing import Any, Optional import pandas as pd -from lightgbm import LGBMRegressor from spotforecast2_safe.configurator import ConfigEntsoe -from spotforecast2_safe.data.fetch_data import get_cache_home, get_data_home +from spotforecast2_safe.data.entsoe_loader import entsoe_data_loader +from spotforecast2_safe.data.fetch_data import get_cache_home from spotforecast2_safe.downloader.entsoe import download_new_data, merge_build_manual from spotforecast2_safe.forecaster.recursive import ForecasterRecursive from spotforecast2_safe.manager.trainer import should_retrain +from spotforecast2_safe.multitask.factories import default_lgbm_forecaster_factory from spotforecast2_safe.preprocessing import RollingFeatures from xgboost import XGBRegressor @@ -57,176 +61,6 @@ _DEFAULT_AGG_WEIGHTS = [1.0] -def entsoe_data_loader(config: ConfigEntsoe) -> pd.DataFrame: - """Read the merged interim ENTSO-E CSV that ``config.data_filename`` points at. - - Args: - config: A `ConfigEntsoe` with ``data_filename`` set. Relative paths - are resolved against `spotforecast2_safe.data.fetch_data.get_data_home`. - - Returns: - DataFrame indexed by the ENTSO-E timestamp column (``Time (UTC)``) - with the load columns as data columns. - - Raises: - FileNotFoundError: If the merged CSV does not exist. Run - ``spotforecast2-entsoe download`` and ``merge`` first. - - Examples: - ```{python} - import os - import tempfile - - import pandas as pd - from spotforecast2_safe.configurator import ConfigEntsoe - - from spotforecast2.tasks.task_entsoe import entsoe_data_loader - - # Build a tiny synthetic interim CSV in a temp directory. - tmp = tempfile.mkdtemp() - csv_path = os.path.join(tmp, "energy_load.csv") - idx = pd.date_range( - "2025-01-01", periods=48, freq="h", tz="UTC", name="Time (UTC)" - ) - pd.DataFrame({"Actual Load": range(48)}, index=idx).to_csv(csv_path) - - # Absolute path bypasses get_data_home; loader returns the full frame. - config = ConfigEntsoe() - config.data_filename = csv_path - df = entsoe_data_loader(config) - - print(df.shape) - assert df.shape == (48, 1) - assert df.index.name == "Time (UTC)" - ``` - """ - path = Path(config.data_filename) - if not path.is_absolute(): - path = get_data_home() / path - if not path.exists(): - raise FileNotFoundError( - f"ENTSO-E merged CSV not found at {path}. Run " - "`spotforecast2-entsoe download` and `merge` first." - ) - return pd.read_csv(path, index_col=0, parse_dates=True) - - -def entsoe_test_data_loader(config: ConfigEntsoe) -> pd.DataFrame: - """Return the merged ENTSO-E CSV sliced to the forecast horizon. - - The slice spans ``(end_train, end_train + predict_size * 1 h]`` so that - ``build_prediction_package``'s ``test_actual = ts.reindex(future_pred.index)`` - matches the hourly forecast row-for-row. ``end_train`` is taken from - ``config.end_train_default`` (treated as the *inclusive* last training - timestamp, the same convention the forecaster uses), and the step is - assumed to be 1 h after the pipeline's hourly resampling. - - For the live ENTSO-E exemplar with ``end_train_default = D-2 23:00 UTC`` - and ``predict_size = 24``, this returns the rows for - ``[D-1 00:00, D 00:00)`` — i.e., ``y_{-1}``. For backtests at an arbitrary - ``end_train_default``, it returns the post-cutoff window the model is - actually predicting, rather than always "yesterday in wall-clock UTC". - - Args: - config: A `ConfigEntsoe` with ``data_filename``, ``end_train_default``, - and ``predict_size`` set; the merged interim CSV must already - contain data covering the forecast horizon (run - ``spotforecast2-entsoe download`` first). - - Returns: - DataFrame indexed by ``Time (UTC)`` with the rows the forecast will be - scored against. - - Examples: - ```{python} - import os - import tempfile - - import pandas as pd - from spotforecast2_safe.configurator import ConfigEntsoe - - from spotforecast2.tasks.task_entsoe import entsoe_test_data_loader - - # Synthetic interim CSV spanning the forecast window. - tmp = tempfile.mkdtemp() - csv_path = os.path.join(tmp, "energy_load.csv") - idx = pd.date_range( - "2025-12-29 00:00", periods=120, freq="h", tz="UTC", name="Time (UTC)" - ) - pd.DataFrame({"Actual Load": range(120)}, index=idx).to_csv(csv_path) - - config = ConfigEntsoe() - config.data_filename = csv_path - config.end_train_default = "2025-12-31 00:00+00:00" - config.predict_size = 24 - - test_df = entsoe_test_data_loader(config) - - # The slice covers exactly predict_size hourly steps after end_train. - print(test_df.shape) - assert test_df.shape == (24, 1) - assert test_df.index[0] == pd.Timestamp("2025-12-31 01:00", tz="UTC") - ``` - """ - df = entsoe_data_loader(config) - end_train = pd.Timestamp(config.end_train_default) - if end_train.tzinfo is None: - end_train = end_train.tz_localize("UTC") - step = pd.Timedelta(hours=1) # post-resample assumption - start = end_train + step # first forecast step - end = start + config.predict_size * step # exclusive upper bound - if df.index.tz is None: - start = start.tz_localize(None) - end = end.tz_localize(None) - return df.loc[(df.index >= start) & (df.index < end)] - - -def entsoe_lgbm_factory( - config: ConfigEntsoe, - *, - weight_func: Optional[Any] = None, - target: Optional[str] = None, -) -> ForecasterRecursive: - """LightGBM ForecasterRecursive for the ENTSO-E pipeline. - - Identical to ``spotforecast2.multitask.factories.default_lgbm_forecaster_factory``; - kept as a named helper so the CLI's intent ("use LightGBM here") is - visible at the configuration site. - - Args: - config: Any object exposing ``random_state``, ``lags_consider``, and - ``window_size`` (typically `ConfigEntsoe`). - weight_func: Per-sample weight function from the imputation step. - target: Ignored; accepted for factory-signature compatibility. - - Examples: - ```{python} - from spotforecast2_safe.configurator import ConfigEntsoe - from spotforecast2_safe.forecaster.recursive import ForecasterRecursive - - from spotforecast2.tasks.task_entsoe import entsoe_lgbm_factory - - config = ConfigEntsoe() - forecaster = entsoe_lgbm_factory(config, weight_func=None, target="Actual Load") - - print(type(forecaster).__name__) - assert isinstance(forecaster, ForecasterRecursive) - # The lags array is derived from lags_consider[-1] = 23. - assert len(forecaster.lags) == config.lags_consider[-1] - print("lags:", forecaster.lags) - ``` - """ - del target - return ForecasterRecursive( - estimator=LGBMRegressor(random_state=config.random_state, verbose=-1), - lags=config.lags_consider[-1], - window_features=RollingFeatures( - stats=["mean"], window_sizes=config.window_size - ), - weight_func=weight_func, - ) - - def entsoe_xgb_factory( config: ConfigEntsoe, *, @@ -235,9 +69,9 @@ def entsoe_xgb_factory( ) -> ForecasterRecursive: """XGBoost ForecasterRecursive for the ENTSO-E pipeline. - Mirrors `entsoe_lgbm_factory()` but uses an `XGBRegressor` estimator. - Kept as a named helper so the XGBoost variant is explicit at the - configuration site. + Mirrors `spotforecast2_safe.multitask.factories.default_lgbm_forecaster_factory` + but uses an `XGBRegressor` estimator. Lives here rather than in the safe + package because xgboost is not a safe-package dependency. Args: config: Any object exposing ``random_state``, ``lags_consider``, and @@ -251,7 +85,7 @@ def entsoe_xgb_factory( from spotforecast2_safe.forecaster.recursive import ForecasterRecursive from xgboost import XGBRegressor - from spotforecast2.tasks.task_entsoe import entsoe_xgb_factory + from spotforecast2.entsoe_cli import entsoe_xgb_factory config = ConfigEntsoe() forecaster = entsoe_xgb_factory(config, weight_func=None, target="Actual Load") @@ -273,7 +107,7 @@ def entsoe_xgb_factory( ) -_FACTORY_BY_MODEL = {"lgbm": entsoe_lgbm_factory, "xgb": entsoe_xgb_factory} +_FACTORY_BY_MODEL = {"lgbm": default_lgbm_forecaster_factory, "xgb": entsoe_xgb_factory} def _build_entsoe_config(model: str) -> ConfigEntsoe: @@ -372,7 +206,7 @@ def main() -> None: ```{python} import sys - from spotforecast2.tasks.task_entsoe import main + from spotforecast2.entsoe_cli import main # With no subcommand, main() prints the usage summary and returns # without error — useful for verifying the CLI is wired correctly. diff --git a/src/spotforecast2/tasks/__init__.py b/src/spotforecast2/tasks/__init__.py deleted file mode 100644 index 2dbbaaf9..00000000 --- a/src/spotforecast2/tasks/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-FileCopyrightText: 2026 bartzbeielstein -# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/tests/test_task_entsoe_cadence.py b/tests/test_entsoe_cli_cadence.py similarity index 70% rename from tests/test_task_entsoe_cadence.py rename to tests/test_entsoe_cli_cadence.py index dfaa27b5..9629ec86 100644 --- a/tests/test_task_entsoe_cadence.py +++ b/tests/test_entsoe_cli_cadence.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: 2026 bartzbeielstein # SPDX-License-Identifier: AGPL-3.0-or-later -"""Tests for the 7-day retraining cadence gate in task_entsoe.train.""" +"""Tests for the 7-day retraining cadence gate in entsoe_cli.train.""" import sys from unittest.mock import patch import pandas as pd -import spotforecast2.tasks.task_entsoe as task_entsoe +import spotforecast2.entsoe_cli as entsoe_cli def _make_model_file(tmp_path, project: str, age_days: float): @@ -28,13 +28,13 @@ def _make_model_file(tmp_path, project: str, age_days: float): def test_train_subcommand_skips_when_recent_model_exists(tmp_path, monkeypatch): """A model trained 1 day ago should NOT trigger a retrain.""" monkeypatch.setattr( - task_entsoe, "get_cache_home", lambda *_args, **_kwargs: tmp_path + entsoe_cli, "get_cache_home", lambda *_args, **_kwargs: tmp_path ) _make_model_file(tmp_path, "entsoe-lgbm", age_days=1) - with patch.object(task_entsoe, "_run_entsoe_pipeline") as mock_pipeline: + with patch.object(entsoe_cli, "_run_entsoe_pipeline") as mock_pipeline: monkeypatch.setattr(sys, "argv", ["spotforecast2-entsoe", "train", "lgbm"]) - task_entsoe.main() + entsoe_cli.main() mock_pipeline.assert_not_called() @@ -42,15 +42,15 @@ def test_train_subcommand_skips_when_recent_model_exists(tmp_path, monkeypatch): def test_train_subcommand_retrains_with_force(tmp_path, monkeypatch): """``--force`` bypasses the cadence gate.""" monkeypatch.setattr( - task_entsoe, "get_cache_home", lambda *_args, **_kwargs: tmp_path + entsoe_cli, "get_cache_home", lambda *_args, **_kwargs: tmp_path ) _make_model_file(tmp_path, "entsoe-lgbm", age_days=1) - with patch.object(task_entsoe, "_run_entsoe_pipeline") as mock_pipeline: + with patch.object(entsoe_cli, "_run_entsoe_pipeline") as mock_pipeline: monkeypatch.setattr( sys, "argv", ["spotforecast2-entsoe", "train", "lgbm", "--force"] ) - task_entsoe.main() + entsoe_cli.main() mock_pipeline.assert_called_once() @@ -58,12 +58,12 @@ def test_train_subcommand_retrains_with_force(tmp_path, monkeypatch): def test_train_subcommand_retrains_when_no_previous_model(tmp_path, monkeypatch): """No saved model → retrain immediately.""" monkeypatch.setattr( - task_entsoe, "get_cache_home", lambda *_args, **_kwargs: tmp_path + entsoe_cli, "get_cache_home", lambda *_args, **_kwargs: tmp_path ) - with patch.object(task_entsoe, "_run_entsoe_pipeline") as mock_pipeline: + with patch.object(entsoe_cli, "_run_entsoe_pipeline") as mock_pipeline: monkeypatch.setattr(sys, "argv", ["spotforecast2-entsoe", "train", "lgbm"]) - task_entsoe.main() + entsoe_cli.main() mock_pipeline.assert_called_once() @@ -71,13 +71,13 @@ def test_train_subcommand_retrains_when_no_previous_model(tmp_path, monkeypatch) def test_train_subcommand_retrains_when_model_is_old(tmp_path, monkeypatch): """A model older than retrain_max_age triggers a retrain.""" monkeypatch.setattr( - task_entsoe, "get_cache_home", lambda *_args, **_kwargs: tmp_path + entsoe_cli, "get_cache_home", lambda *_args, **_kwargs: tmp_path ) _make_model_file(tmp_path, "entsoe-lgbm", age_days=10) - with patch.object(task_entsoe, "_run_entsoe_pipeline") as mock_pipeline: + with patch.object(entsoe_cli, "_run_entsoe_pipeline") as mock_pipeline: monkeypatch.setattr(sys, "argv", ["spotforecast2-entsoe", "train", "lgbm"]) - task_entsoe.main() + entsoe_cli.main() mock_pipeline.assert_called_once() @@ -87,4 +87,4 @@ def test_latest_saved_model_timestamp_returns_none_when_dir_missing(tmp_path): config = ConfigEntsoe() config.cache_home = tmp_path - assert task_entsoe._latest_saved_model_timestamp(config, "no-such") is None + assert entsoe_cli._latest_saved_model_timestamp(config, "no-such") is None diff --git a/tests/test_tasks_entsoe_predict.py b/tests/test_entsoe_cli_predict.py similarity index 76% rename from tests/test_tasks_entsoe_predict.py rename to tests/test_entsoe_cli_predict.py index fed6268a..1e17a4d5 100644 --- a/tests/test_tasks_entsoe_predict.py +++ b/tests/test_entsoe_cli_predict.py @@ -6,13 +6,10 @@ from unittest.mock import patch from spotforecast2_safe.configurator import ConfigEntsoe +from spotforecast2_safe.data.entsoe_loader import entsoe_data_loader +from spotforecast2_safe.multitask.factories import default_lgbm_forecaster_factory -from spotforecast2.tasks.task_entsoe import ( - entsoe_data_loader, - entsoe_lgbm_factory, - entsoe_xgb_factory, - main, -) +from spotforecast2.entsoe_cli import entsoe_xgb_factory, main def _predict_call(model: str): @@ -21,7 +18,7 @@ def _predict_call(model: str): Returns ``(positional_args, keyword_args)`` of the single ``_run_entsoe_pipeline`` call. """ - with patch("spotforecast2.tasks.task_entsoe._run_entsoe_pipeline") as mock_pipeline: + with patch("spotforecast2.entsoe_cli._run_entsoe_pipeline") as mock_pipeline: with patch("sys.argv", ["spotforecast2-entsoe", "predict", model]): main() assert mock_pipeline.call_count == 1 @@ -38,7 +35,7 @@ def test_predict_lgbm_dispatches_to_pipeline_with_predict_task(): assert config.agg_weights == [1.0] assert config.bounds == [(-1e9, 1e9)] assert config.data_loader is entsoe_data_loader - assert config.forecaster_factory is entsoe_lgbm_factory + assert config.forecaster_factory is default_lgbm_forecaster_factory assert config.index_name == "Time (UTC)" @@ -51,10 +48,10 @@ def test_predict_xgb_uses_xgb_project_and_factory(): def test_predict_default_model_is_lgbm(): - with patch("spotforecast2.tasks.task_entsoe._run_entsoe_pipeline") as mock_pipeline: + with patch("spotforecast2.entsoe_cli._run_entsoe_pipeline") as mock_pipeline: with patch("sys.argv", ["spotforecast2-entsoe", "predict"]): main() args = mock_pipeline.call_args.args kwargs = mock_pipeline.call_args.kwargs assert kwargs["project_name"] == "entsoe-lgbm" - assert args[0].forecaster_factory is entsoe_lgbm_factory + assert args[0].forecaster_factory is default_lgbm_forecaster_factory diff --git a/tests/test_tasks_entsoe_train.py b/tests/test_entsoe_cli_train.py similarity index 78% rename from tests/test_tasks_entsoe_train.py rename to tests/test_entsoe_cli_train.py index a5b1640b..a394d441 100644 --- a/tests/test_tasks_entsoe_train.py +++ b/tests/test_entsoe_cli_train.py @@ -3,7 +3,7 @@ """Tests for the ``spotforecast2-entsoe train`` CLI subcommand. -Mocks ``spotforecast2.tasks.task_entsoe._run_entsoe_pipeline`` so the tests +Mocks ``spotforecast2.entsoe_cli._run_entsoe_pipeline`` so the tests run offline and without needing the merged ENTSO-E CSV. The CLI invokes ``_run_entsoe_pipeline(config, task="defaults", project_name=..., show=...)`` — the config carries every pipeline parameter (targets, agg_weights, bounds, @@ -13,13 +13,10 @@ from unittest.mock import patch from spotforecast2_safe.configurator import ConfigEntsoe +from spotforecast2_safe.data.entsoe_loader import entsoe_data_loader +from spotforecast2_safe.multitask.factories import default_lgbm_forecaster_factory -from spotforecast2.tasks.task_entsoe import ( - entsoe_data_loader, - entsoe_lgbm_factory, - entsoe_xgb_factory, - main, -) +from spotforecast2.entsoe_cli import entsoe_xgb_factory, main def _train_call(model: str): @@ -29,7 +26,7 @@ def _train_call(model: str): ``_run_entsoe_pipeline`` call. ``--force`` bypasses the cadence gate so the dispatch test does not depend on the state of the user's model cache. """ - with patch("spotforecast2.tasks.task_entsoe._run_entsoe_pipeline") as mock_pipeline: + with patch("spotforecast2.entsoe_cli._run_entsoe_pipeline") as mock_pipeline: with patch("sys.argv", ["spotforecast2-entsoe", "train", model, "--force"]): main() assert mock_pipeline.call_count == 1 @@ -46,7 +43,7 @@ def test_train_lgbm_dispatches_to_pipeline_with_defaults_task(): assert config.agg_weights == [1.0] assert config.bounds == [(-1e9, 1e9)] assert config.data_loader is entsoe_data_loader - assert config.forecaster_factory is entsoe_lgbm_factory + assert config.forecaster_factory is default_lgbm_forecaster_factory assert config.index_name == "Time (UTC)" assert kwargs["show"] is False @@ -61,17 +58,17 @@ def test_train_xgb_dispatches_to_pipeline_with_xgb_factory(): def test_train_default_model_is_lgbm(): """Omitting the positional ``model`` arg falls back to LightGBM.""" - with patch("spotforecast2.tasks.task_entsoe._run_entsoe_pipeline") as mock_pipeline: + with patch("spotforecast2.entsoe_cli._run_entsoe_pipeline") as mock_pipeline: with patch("sys.argv", ["spotforecast2-entsoe", "train", "--force"]): main() args = mock_pipeline.call_args.args kwargs = mock_pipeline.call_args.kwargs assert kwargs["project_name"] == "entsoe-lgbm" - assert args[0].forecaster_factory is entsoe_lgbm_factory + assert args[0].forecaster_factory is default_lgbm_forecaster_factory def test_train_show_flag_forwards_to_pipeline(): - with patch("spotforecast2.tasks.task_entsoe._run_entsoe_pipeline") as mock_pipeline: + with patch("spotforecast2.entsoe_cli._run_entsoe_pipeline") as mock_pipeline: with patch( "sys.argv", ["spotforecast2-entsoe", "train", "lgbm", "--show", "--force"], diff --git a/tests/test_entsoe_pipeline.py b/tests/test_entsoe_pipeline.py index 560fefe9..3cff8350 100644 --- a/tests/test_entsoe_pipeline.py +++ b/tests/test_entsoe_pipeline.py @@ -15,12 +15,10 @@ import pandas as pd import pytest from spotforecast2_safe.configurator import ConfigEntsoe +from spotforecast2_safe.multitask.factories import default_lgbm_forecaster_factory +from spotforecast2.entsoe_cli import _run_entsoe_pipeline from spotforecast2.multitask.multi import MultiTask -from spotforecast2.tasks.task_entsoe import ( - _run_entsoe_pipeline, - entsoe_lgbm_factory, -) def _synthetic_entsoe_df(n_days: int = 30) -> pd.DataFrame: @@ -68,7 +66,7 @@ def entsoe_config(tmp_path: Path) -> ConfigEntsoe: predict_size=12, ) cfg.data_loader = _make_loader(df) - cfg.forecaster_factory = entsoe_lgbm_factory + cfg.forecaster_factory = default_lgbm_forecaster_factory cfg.cache_home = str(tmp_path) return cfg diff --git a/uv.lock b/uv.lock index c91a1374..c3e6c040 100644 --- a/uv.lock +++ b/uv.lock @@ -3491,7 +3491,7 @@ wheels = [ [[package]] name = "spotforecast2" -version = "8.1.1" +version = "9.0.0" source = { editable = "." } dependencies = [ { name = "astral" }, @@ -3571,7 +3571,7 @@ requires-dist = [ { name = "safety", marker = "extra == 'dev'", specifier = ">=3.0.0" }, { name = "scikit-learn", specifier = ">=1.8.0" }, { name = "shap", specifier = ">=0.49.1" }, - { name = "spotforecast2-safe", specifier = ">=22.0.0,<23" }, + { name = "spotforecast2-safe", specifier = ">=22.3.0,<23" }, { name = "spotoptim", extras = ["torch"], specifier = ">=1.0.0,<2" }, { name = "tqdm", specifier = ">=4.67.2" }, { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.29" }, @@ -3590,7 +3590,7 @@ dev = [ [[package]] name = "spotforecast2-safe" -version = "22.2.0" +version = "22.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astral" }, @@ -3607,9 +3607,9 @@ dependencies = [ { name = "statsmodels" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5d/fd/685a4d9797d467ec646c3ffc75b8d6327ff770a7540b47c5a0300f23aac5/spotforecast2_safe-22.2.0.tar.gz", hash = "sha256:fd458aea0a6421cc8229cdbc2314b9a0863508771c286ac2537c8d3221eb1362", size = 20660329, upload-time = "2026-06-12T11:53:21.696Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/04/4f366dad3f78214f6d7564c883e77f67e8c734c807ac49f17c8cd1528d45/spotforecast2_safe-22.3.0.tar.gz", hash = "sha256:71ddb0215489f9042b050dda50e4eb8c0afea598ccb5c7796216073dff436e27", size = 20661753, upload-time = "2026-06-12T22:49:46.389Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/1b/24262db44be056e62680e4ef28e8423bd977abeac68b42c15d045f6941e3/spotforecast2_safe-22.2.0-py3-none-any.whl", hash = "sha256:1b79b7d132024103da23ccfb4e8057583e3d04906574e633c78f97e163301da9", size = 20729694, upload-time = "2026-06-12T11:53:19.078Z" }, + { url = "https://files.pythonhosted.org/packages/1d/10/249f41e69a3b8d9803b40b721487180cea24e4d0a152fe267e1387ab0ec9/spotforecast2_safe-22.3.0-py3-none-any.whl", hash = "sha256:0791cb055e36a54deb5124bdcd8e2545ffa1ae200a89563966ac99d43891665c", size = 20731775, upload-time = "2026-06-12T22:49:43.719Z" }, ] [[package]]