Skip to content

Commit edb8139

Browse files
0.14.45
1 parent 805503b commit edb8139

5 files changed

Lines changed: 110 additions & 190 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 30 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -4288,7 +4288,7 @@
42884288
},
42894289
{
42904290
"cell_type": "code",
4291-
"execution_count": 12,
4291+
"execution_count": null,
42924292
"metadata": {},
42934293
"outputs": [],
42944294
"source": [
@@ -4339,17 +4339,9 @@
43394339
},
43404340
{
43414341
"cell_type": "code",
4342-
"execution_count": 13,
4342+
"execution_count": null,
43434343
"metadata": {},
4344-
"outputs": [
4345-
{
4346-
"name": "stdout",
4347-
"output_type": "stream",
4348-
"text": [
4349-
"Model Name: HoeffdingTreeRegressor, Model Instance: <class 'river.tree.hoeffding_tree_regressor.HoeffdingTreeRegressor'>\n"
4350-
]
4351-
}
4352-
],
4344+
"outputs": [],
43534345
"source": [
43544346
"\n",
43554347
"# Example of usage\n",
@@ -4359,186 +4351,19 @@
43594351
},
43604352
{
43614353
"cell_type": "code",
4362-
"execution_count": 14,
4354+
"execution_count": null,
43634355
"metadata": {},
4364-
"outputs": [
4365-
{
4366-
"name": "stdout",
4367-
"output_type": "stream",
4368-
"text": [
4369-
"module_name: light\n",
4370-
"submodule_name: regression\n",
4371-
"model_name: NNLinearRegressor\n",
4372-
"Model Name: NNLinearRegressor, Model Instance: <class 'spotPython.light.regression.nn_linear_regressor.NNLinearRegressor'>\n"
4373-
]
4374-
}
4375-
],
4356+
"outputs": [],
43764357
"source": [
43774358
"model_name, model_instance = get_core_model_from_name(\"light.regression.NNLinearRegressor\")\n",
43784359
"print(f\"Model Name: {model_name}, Model Instance: {model_instance}\")"
43794360
]
43804361
},
43814362
{
43824363
"cell_type": "code",
4383-
"execution_count": 15,
4364+
"execution_count": null,
43844365
"metadata": {},
4385-
"outputs": [
4386-
{
4387-
"name": "stderr",
4388-
"output_type": "stream",
4389-
"text": [
4390-
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`.\n",
4391-
"GPU available: True (mps), used: True\n",
4392-
"TPU available: False, using: 0 TPU cores\n",
4393-
"IPU available: False, using: 0 IPUs\n"
4394-
]
4395-
},
4396-
{
4397-
"name": "stderr",
4398-
"output_type": "stream",
4399-
"text": [
4400-
"HPU available: False, using: 0 HPUs\n",
4401-
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n",
4402-
"\n",
4403-
" | Name | Type | Params | In sizes | Out sizes\n",
4404-
"-------------------------------------------------------------\n",
4405-
"0 | layers | Sequential | 15.9 K | [8, 10] | [8, 1] \n",
4406-
"-------------------------------------------------------------\n",
4407-
"15.9 K Trainable params\n",
4408-
"0 Non-trainable params\n",
4409-
"15.9 K Total params\n",
4410-
"0.064 Total estimated model params size (MB)\n"
4411-
]
4412-
},
4413-
{
4414-
"name": "stdout",
4415-
"output_type": "stream",
4416-
"text": [
4417-
"torch.Size([8, 10])\n",
4418-
"torch.Size([8])\n"
4419-
]
4420-
},
4421-
{
4422-
"name": "stderr",
4423-
"output_type": "stream",
4424-
"text": [
4425-
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
4426-
]
4427-
},
4428-
{
4429-
"data": {
4430-
"application/vnd.jupyter.widget-view+json": {
4431-
"model_id": "f2089c81a3034f8181ae924de01692ca",
4432-
"version_major": 2,
4433-
"version_minor": 0
4434-
},
4435-
"text/plain": [
4436-
"Training: | | 0/? [00:00<?, ?it/s]"
4437-
]
4438-
},
4439-
"metadata": {},
4440-
"output_type": "display_data"
4441-
},
4442-
{
4443-
"name": "stderr",
4444-
"output_type": "stream",
4445-
"text": [
4446-
"`Trainer.fit` stopped: `max_epochs=2` reached.\n",
4447-
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
4448-
]
4449-
},
4450-
{
4451-
"data": {
4452-
"application/vnd.jupyter.widget-view+json": {
4453-
"model_id": "459a96c4bed440dfafbc3d40c0e7a8d0",
4454-
"version_major": 2,
4455-
"version_minor": 0
4456-
},
4457-
"text/plain": [
4458-
"Validation: | | 0/? [00:00<?, ?it/s]"
4459-
]
4460-
},
4461-
"metadata": {},
4462-
"output_type": "display_data"
4463-
},
4464-
{
4465-
"data": {
4466-
"text/html": [
4467-
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
4468-
"┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
4469-
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
4470-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 29042.5703125 </span>│\n",
4471-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 29042.5703125 </span>│\n",
4472-
"└───────────────────────────┴───────────────────────────┘\n",
4473-
"</pre>\n"
4474-
],
4475-
"text/plain": [
4476-
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
4477-
"\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
4478-
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
4479-
"\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 29042.5703125 \u001b[0m\u001b[35m \u001b[0m│\n",
4480-
"\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 29042.5703125 \u001b[0m\u001b[35m \u001b[0m│\n",
4481-
"└───────────────────────────┴───────────────────────────┘\n"
4482-
]
4483-
},
4484-
"metadata": {},
4485-
"output_type": "display_data"
4486-
},
4487-
{
4488-
"name": "stderr",
4489-
"output_type": "stream",
4490-
"text": [
4491-
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
4492-
]
4493-
},
4494-
{
4495-
"data": {
4496-
"application/vnd.jupyter.widget-view+json": {
4497-
"model_id": "9a7e5a3ceb724b9e87b9f23341122ce4",
4498-
"version_major": 2,
4499-
"version_minor": 0
4500-
},
4501-
"text/plain": [
4502-
"Testing: | | 0/? [00:00<?, ?it/s]"
4503-
]
4504-
},
4505-
"metadata": {},
4506-
"output_type": "display_data"
4507-
},
4508-
{
4509-
"data": {
4510-
"text/html": [
4511-
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
4512-
"┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
4513-
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
4514-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 29042.5703125 </span>│\n",
4515-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 29042.5703125 </span>│\n",
4516-
"└───────────────────────────┴───────────────────────────┘\n",
4517-
"</pre>\n"
4518-
],
4519-
"text/plain": [
4520-
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
4521-
"\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
4522-
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
4523-
"\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 29042.5703125 \u001b[0m\u001b[35m \u001b[0m│\n",
4524-
"\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 29042.5703125 \u001b[0m\u001b[35m \u001b[0m│\n",
4525-
"└───────────────────────────┴───────────────────────────┘\n"
4526-
]
4527-
},
4528-
"metadata": {},
4529-
"output_type": "display_data"
4530-
},
4531-
{
4532-
"data": {
4533-
"text/plain": [
4534-
"[{'val_loss': 29042.5703125, 'hp_metric': 29042.5703125}]"
4535-
]
4536-
},
4537-
"execution_count": 15,
4538-
"metadata": {},
4539-
"output_type": "execute_result"
4540-
}
4541-
],
4366+
"outputs": [],
45424367
"source": [
45434368
"from torch.utils.data import DataLoader\n",
45444369
"from spotPython.data.diabetes import Diabetes\n",
@@ -4572,6 +4397,29 @@
45724397
"trainer.test(net_light_base, test_loader)"
45734398
]
45744399
},
4400+
{
4401+
"cell_type": "code",
4402+
"execution_count": 1,
4403+
"metadata": {},
4404+
"outputs": [
4405+
{
4406+
"ename": "NameError",
4407+
"evalue": "name 'MockDataSet' is not defined",
4408+
"output_type": "error",
4409+
"traceback": [
4410+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
4411+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
4412+
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mspotPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01minit\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_feature_names\n\u001b[0;32m----> 2\u001b[0m fun_control \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata_set\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43mMockDataSet\u001b[49m(names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeature1\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeature2\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeature3\u001b[39m\u001b[38;5;124m\"\u001b[39m])}\n\u001b[1;32m 3\u001b[0m get_feature_names(fun_control)\n",
4413+
"\u001b[0;31mNameError\u001b[0m: name 'MockDataSet' is not defined"
4414+
]
4415+
}
4416+
],
4417+
"source": [
4418+
"from spotPython.utils.init import get_feature_names\n",
4419+
"fun_control = {\"data_set\": MockDataSet(names=[\"feature1\", \"feature2\", \"feature3\"])}\n",
4420+
"get_feature_names(fun_control)\n"
4421+
]
4422+
},
45754423
{
45764424
"cell_type": "code",
45774425
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.44"
10+
version = "0.14.45"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/utils/file.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,24 +181,38 @@ def load_core_model_from_file(coremodel, dirname="userModel"):
181181
return core_model
182182

183183

184-
def get_experiment_from_PREFIX(PREFIX) -> tuple:
184+
def get_experiment_from_PREFIX(PREFIX, return_dict=True) -> dict:
185185
"""
186186
Setup the experiment based on the PREFIX provided and return the relevant configuration
187187
and control objects.
188188
189189
Args:
190-
PREFIX (str): The prefix for the experiment filename.
190+
PREFIX (str):
191+
The prefix for the experiment filename.
192+
return_dict (bool, optional):
193+
Whether to return the configuration and control objects as a dictionary.
194+
If False, a tuple is returned:
195+
"(config, fun_control, design_control, surrogate_control, optimizer_control)."
196+
Defaults to True.
191197
192198
Returns:
193-
tuple:
194-
A tuple containing config, spot_tuner, fun_control, design_control, surrogate_control,
195-
and optimizer_control.
199+
dict: Dictionary containing the configuration and control objects.
196200
197201
Example:
198-
>>> config, _, _, _, _, _ = get_experiment_from_PREFIX("100")
202+
>>> from spotPython.utils.file import get_experiment_from_PREFIX
203+
>>> config = get_experiment_from_PREFIX("100")["config"]
199204
200205
"""
201206
experiment_name = get_experiment_filename(PREFIX)
202207
spot_tuner, fun_control, design_control, surrogate_control, optimizer_control = load_experiment(experiment_name)
203208
config = get_tuned_architecture(spot_tuner, fun_control)
204-
return config, spot_tuner, fun_control, design_control, surrogate_control, optimizer_control
209+
if return_dict:
210+
return {
211+
"config": config,
212+
"fun_control": fun_control,
213+
"design_control": design_control,
214+
"surrogate_control": surrogate_control,
215+
"optimizer_control": optimizer_control,
216+
}
217+
else:
218+
return config, fun_control, design_control, surrogate_control, optimizer_control

src/spotPython/utils/init.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import List, Dict, Any
23
import lightning as L
34
from scipy.optimize import differential_evolution
45
import numpy as np
@@ -683,3 +684,29 @@ def get_tensorboard_path(fun_control):
683684
tensorboard_path (str): The path to the folder where the tensorboard files are saved.
684685
"""
685686
return fun_control["TENSORBOARD_PATH"]
687+
688+
689+
def get_feature_names(fun_control: Dict[str, Any]) -> List[str]:
690+
"""
691+
Get the feature names from the fun_control dictionary.
692+
693+
Args:
694+
fun_control (dict): The function control dictionary. Must contain a "data_set" key.
695+
696+
Returns:
697+
List[str]: List of feature names.
698+
699+
Raises:
700+
ValueError: If "data_set" is not in fun_control.
701+
ValueError: If "data_set" is None.
702+
703+
Examples:
704+
>>> from spotPython.utils.init import get_feature_names
705+
get_feature_names(fun_control)
706+
"""
707+
data_set = fun_control.get("data_set")
708+
709+
if data_set is None:
710+
raise ValueError("'data_set' key not found or is None in 'fun_control'")
711+
712+
return data_set.names

test/test_get_feature_names.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from spotPython.utils.init import get_feature_names # Replace 'your_module_name' with the actual module name
3+
4+
5+
class MockDataSet:
6+
def __init__(self, names):
7+
self.names = names
8+
9+
10+
def test_get_feature_names_success():
11+
fun_control = {"data_set": MockDataSet(names=["feature1", "feature2", "feature3"])}
12+
feature_names = get_feature_names(fun_control)
13+
assert feature_names == ["feature1", "feature2", "feature3"]
14+
15+
16+
def test_get_feature_names_missing_data_set_key():
17+
fun_control = {}
18+
with pytest.raises(ValueError, match="'data_set' key not found or is None in 'fun_control'"):
19+
get_feature_names(fun_control)
20+
21+
22+
def test_get_feature_names_data_set_none():
23+
fun_control = {"data_set": None}
24+
with pytest.raises(ValueError, match="'data_set' key not found or is None in 'fun_control'"):
25+
get_feature_names(fun_control)
26+
27+
28+
def test_get_feature_names_empty_names():
29+
fun_control = {"data_set": MockDataSet(names=[])}
30+
feature_names = get_feature_names(fun_control)
31+
assert feature_names == []

0 commit comments

Comments
 (0)