Skip to content

Commit 4d46a8d

Browse files
0.14.21
1 parent 27a4f8e commit 4d46a8d

5 files changed

Lines changed: 139 additions & 167 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 46 additions & 163 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.20"
10+
version = "0.14.21"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/friedman.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import math
2+
import random
3+
4+
5+
class FriedmanDriftDataset:
6+
"""Friedman Drift Dataset."""
7+
8+
def __init__(self, n_samples=100, change_point1=50, change_point2=75, seed=None, constant=False):
9+
"""Constructor for the Friedman Drift Dataset.
10+
11+
Args:
12+
n_samples (int): The number of samples to generate.
13+
change_point1 (int): The index of the first change point.
14+
change_point2 (int): The index of the second change point.
15+
seed (int): The seed for the random number generator.
16+
constant (bool): If True, only the first feature is set to 1 and all others are set to 0.
17+
18+
Returns:
19+
None
20+
21+
Examples:
22+
>>> from spotPython.data.friedman import FriedmanDriftDataset
23+
data_generator = FriedmanDriftDataset(n_samples=100,
24+
seed=42, change_point1=50, change_point2=75, constant=False)
25+
data = [data for data in data_generator]
26+
indices = [i for _, _, i in data]
27+
values = {f"x{i}": [] for i in range(5)}
28+
values["y"] = []
29+
for x, y, _ in data:
30+
for i in range(5):
31+
values[f"x{i}"].append(x[i])
32+
values["y"].append(y)
33+
plt.figure(figsize=(10, 6))
34+
for label, series in values.items():
35+
plt.plot(indices, series, label=label)
36+
plt.xlabel('Index')
37+
plt.ylabel('Value')
38+
plt.title('')
39+
plt.axvline(x=50, color='k', linestyle='--', label='Drift Point 1')
40+
plt.axvline(x=75, color='r', linestyle='--', label='Drift Point 2')
41+
plt.legend()
42+
plt.grid(True)
43+
plt.show()
44+
"""
45+
self.n_samples = n_samples
46+
self._change_point1 = change_point1
47+
self._change_point2 = change_point2
48+
self.seed = seed
49+
self.index = 0
50+
self.rng = random.Random(self.seed)
51+
self.constant = constant
52+
53+
def __iter__(self):
54+
return self
55+
56+
def __next__(self):
57+
if self.index >= self.n_samples: # Specifying end of generation
58+
raise StopIteration
59+
if self.constant:
60+
# x[0] is set to 1, all others to 0
61+
x = {0: 1}
62+
x.update({i: 0 for i in range(1, 10)}) # All x[i] are 0 for i > 0
63+
else:
64+
x = {i: self.rng.uniform(a=0, b=1) for i in range(10)}
65+
y = self._global_recurring_abrupt_gen(x, self.index) + self.rng.gauss(mu=0, sigma=1)
66+
result = (x, y, self.index)
67+
self.index += 1
68+
return result
69+
70+
def _global_recurring_abrupt_gen(self, x, index):
71+
if index < self._change_point1 or index >= self._change_point2:
72+
return 10 * math.sin(math.pi * x[0] * x[1]) + 20 * (x[2] - 0.5) ** 2 + 10 * x[3] + 5 * x[4]
73+
else:
74+
return 10 * math.sin(math.pi * x[3] * x[5]) + 20 * (x[1] - 0.5) ** 2 + 10 * x[0] + 5 * x[2]
75+
76+
def __len__(self) -> int:
77+
"""
78+
Returns the length of the dataset.
79+
80+
Returns:
81+
int: The length of the dataset.
82+
83+
84+
"""
85+
return self.n_samples

src/spotPython/hyperparameters/values.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1511,4 +1511,4 @@ def get_metric_sklearn(metric_name):
15111511
sklearn.metrics: The metric from the metric name.
15121512
"""
15131513
metric_sklearn = getattr(sklearn.metrics, metric_name)
1514-
return metric_sklearn
1514+
return metric_sklearn

src/spotPython/utils/init.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
import datetime
77
from dateutil.tz import tzlocal
88
from torch.utils.tensorboard import SummaryWriter
9-
from spotPython.hyperparameters.values import (add_core_model_to_fun_control,
10-
get_core_model_from_name, get_metric_sklearn, get_prep_model)
9+
from spotPython.hyperparameters.values import (
10+
add_core_model_to_fun_control,
11+
get_core_model_from_name,
12+
get_metric_sklearn,
13+
get_prep_model,
14+
)
1115

1216

1317
def fun_control_init(

0 commit comments

Comments
 (0)