Skip to content

Commit ab9940a

Browse files
0.14.42
New nn base model
1 parent 06a24ec commit ab9940a

7 files changed

Lines changed: 819 additions & 53 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 297 additions & 43 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.41"
10+
version = "0.14.42"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/hyperparameters/values.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22
import copy
33
import json
4-
import river
54
import river.preprocessing
65
from sklearn.pipeline import make_pipeline
76
from river import compose
@@ -10,9 +9,12 @@
109
from spotPython.utils.transform import transform_hyper_parameter_values
1110

1211
# Important, do not delete the following imports, they are needed for the function add_core_model_to_fun_control
12+
import river
1313
from river import forest, tree, linear_model, rules
1414
from river import preprocessing
1515
import sklearn.metrics
16+
import spotPython
17+
from spotPython.light import regression
1618

1719

1820
def generate_one_config_from_var_dict(
@@ -1698,20 +1700,47 @@ def set_factor_hyperparameter_values(fun_control, key, levels):
16981700
fun_control["core_model_hyper_dict"][key].update({"upper": len(levels) - 1})
16991701

17001702

1701-
def get_core_model_from_name(core_model_name) -> object:
1703+
def get_core_model_from_name(core_model_name: str) -> tuple:
17021704
"""
1703-
Returns the river core model name and instance from a core model name.
1705+
Returns the river or spotPython core model name and instance from a core model name.
17041706
17051707
Args:
1706-
core_model_name (str): The name of the core model.
1708+
core_model_name (str): The full name of the core model in the format 'module.Model'.
17071709
17081710
Returns:
1709-
(str, object): The core model name and instance.
1711+
(str, object): A tuple containing the core model name and an instance of the core model.
1712+
1713+
Examples:
1714+
>>> from spotPython.hyperparameters.values import get_core_model_from_name
1715+
model_name, model_instance = get_core_model_from_name('tree.HoeffdingTreeRegressor')
1716+
print(f"Model Name: {model_name}, Model Instance: {model_instance}")
1717+
Model Name: HoeffdingTreeRegressor, Model Instance: <class 'river.tree.hoeffding_tree_regressor.HoeffdingTreeRegressor'>
1718+
>>> model_name, model_instance = get_core_model_from_name("light.regression.NNLinearRegressor")
1719+
print(f"Model Name: {model_name}, Model Instance: {model_instance}")
1720+
Model Name: NNLinearRegressor, Model Instance: <class 'spotPython.light.regression.nn_linear_regressor.NNLinearRegressor'>
17101721
"""
1711-
core_model_module = core_model_name.split(".")[0]
1712-
coremodel = core_model_name.split(".")[1]
1713-
core_model_instance = getattr(getattr(river, core_model_module), coremodel)
1714-
return coremodel, core_model_instance
1722+
# Split the model name into its components
1723+
name_parts = core_model_name.split(".")
1724+
if len(name_parts) < 2:
1725+
raise ValueError(f"Invalid core model name: {core_model_name}. Expected format: 'module.ModelName'.")
1726+
module_name = name_parts[0]
1727+
model_name = name_parts[1]
1728+
try:
1729+
# Try to get the model from the river library
1730+
core_model_instance = getattr(getattr(river, module_name), model_name)
1731+
return model_name, core_model_instance
1732+
except AttributeError:
1733+
try:
1734+
# Try to get the model from the spotPython library
1735+
submodule_name = name_parts[1]
1736+
model_name = name_parts[2] if len(name_parts) == 3 else model_name
1737+
print(f"module_name: {module_name}")
1738+
print(f"submodule_name: {submodule_name}")
1739+
print(f"model_name: {model_name}")
1740+
core_model_instance = getattr(getattr(getattr(spotPython, module_name), submodule_name), model_name)
1741+
return model_name, core_model_instance
1742+
except AttributeError:
1743+
raise ValueError(f"Model '{core_model_name}' not found in either 'river' or 'spotPython' libraries.")
17151744

17161745

17171746
def get_prep_model(prepmodel_name) -> object:
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
This module implements pytorch lightning neural networks for handling regression tasks.
3+
4+
"""
5+
6+
from .nn_linear_regressor import NNLinearRegressor
7+
from .netlightregression import NetLightRegression
8+
9+
__all__ = ["NNLinearRegressor", "NetLightRegression"]

0 commit comments

Comments
 (0)