Skip to content

Commit 9c7fe8e

Browse files
0.14.32
1 parent 97f108f commit 9c7fe8e

4 files changed

Lines changed: 186 additions & 1 deletion

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3951,6 +3951,81 @@
39513951
"print(next(iter(data_module.train_dataloader()))[0].numpy())\n"
39523952
]
39533953
},
3954+
{
3955+
"cell_type": "code",
3956+
"execution_count": null,
3957+
"metadata": {},
3958+
"outputs": [],
3959+
"source": []
3960+
},
3961+
{
3962+
"cell_type": "code",
3963+
"execution_count": 1,
3964+
"metadata": {},
3965+
"outputs": [
3966+
{
3967+
"name": "stderr",
3968+
"output_type": "stream",
3969+
"text": [
3970+
"Seed set to 123\n",
3971+
"Seed set to 123\n"
3972+
]
3973+
},
3974+
{
3975+
"name": "stdout",
3976+
"output_type": "stream",
3977+
"text": [
3978+
"spotPython tuning: 0.0 [########--] 80.00% \n",
3979+
"spotPython tuning: 0.0 [#########-] 86.67% \n",
3980+
"spotPython tuning: 0.0 [#########-] 93.33% \n",
3981+
"spotPython tuning: 0.0 [##########] 100.00% Done...\n",
3982+
"\n",
3983+
"S.X: [[ 0.00000000e+00 0.00000000e+00]\n",
3984+
" [ 0.00000000e+00 1.00000000e+00]\n",
3985+
" [ 1.00000000e+00 0.00000000e+00]\n",
3986+
" [ 1.00000000e+00 1.00000000e+00]\n",
3987+
" [-9.09243389e-01 -1.58234577e-01]\n",
3988+
" [-2.05817107e-01 -4.81249089e-01]\n",
3989+
" [ 9.49741171e-01 -9.46312716e-01]\n",
3990+
" [-1.20955714e-01 6.38358863e-02]\n",
3991+
" [-6.62787018e-01 1.74316373e-01]\n",
3992+
" [ 2.82008441e-01 9.30010114e-01]\n",
3993+
" [ 4.78788115e-01 6.53210582e-01]\n",
3994+
" [ 2.64764215e-04 4.00803185e-03]\n",
3995+
" [-1.66363820e-05 4.65001027e-03]\n",
3996+
" [-2.60995680e-04 5.46114194e-03]\n",
3997+
" [ 3.74504308e-03 1.86731890e-02]]\n",
3998+
"S.y: [0.00000000e+00 1.00000000e+00 1.00000000e+00 2.00000000e+00\n",
3999+
" 8.51761723e-01 2.73961367e-01 1.79751605e+00 1.87053051e-02\n",
4000+
" 4.69672829e-01 9.44447573e-01 6.55922124e-01 1.61344194e-05\n",
4001+
" 2.16228723e-05 2.98921900e-05 3.62713334e-04]\n"
4002+
]
4003+
}
4004+
],
4005+
"source": [
4006+
"import numpy as np\n",
4007+
"from spotPython.fun.objectivefunctions import analytical\n",
4008+
"from spotPython.spot import spot\n",
4009+
"from spotPython.utils.init import (\n",
4010+
" fun_control_init, optimizer_control_init, surrogate_control_init, design_control_init\n",
4011+
" )\n",
4012+
"# number of initial points:\n",
4013+
"ni = 7\n",
4014+
"# start point X_0\n",
4015+
"X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
4016+
"fun = analytical().fun_sphere\n",
4017+
"fun_control = fun_control_init(\n",
4018+
" lower = np.array([-1, -1]),\n",
4019+
" upper = np.array([1, 1]))\n",
4020+
"design_control=design_control_init(init_size=ni)\n",
4021+
"S = spot.Spot(fun=fun,\n",
4022+
" fun_control=fun_control,\n",
4023+
" design_control=design_control,)\n",
4024+
"S.run(X_start=X_start)\n",
4025+
"print(f\"S.X: {S.X}\")\n",
4026+
"print(f\"S.y: {S.y}\")"
4027+
]
4028+
},
39544029
{
39554030
"cell_type": "code",
39564031
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.31"
10+
version = "0.14.32"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,75 @@ def write_db_dict(self) -> None:
744744
print("No results to write.")
745745

746746
def run(self, X_start=None) -> Spot:
747+
"""
748+
Run the surrogate based optimization.
749+
The optimization process is controlled by the following steps:
750+
1. Initialize design
751+
2. Update stats
752+
3. Fit surrogate
753+
4. Update design
754+
5. Update stats
755+
6. Update writer
756+
7. Fit surrogate
757+
8. Show progress if needed
758+
759+
Args:
760+
self (object): Spot object
761+
X_start (numpy.ndarray, optional): initial design. Defaults to None.
762+
763+
Returns:
764+
(object): Spot object
765+
766+
Examples:
767+
>>> import numpy as np
768+
from spotPython.fun.objectivefunctions import analytical
769+
from spotPython.spot import spot
770+
from spotPython.utils.init import (
771+
fun_control_init, optimizer_control_init, surrogate_control_init, design_control_init
772+
)
773+
# number of initial points:
774+
ni = 7
775+
# start point X_0
776+
X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
777+
fun = analytical().fun_sphere
778+
fun_control = fun_control_init(
779+
lower = np.array([-1, -1]),
780+
upper = np.array([1, 1]))
781+
design_control=design_control_init(init_size=ni)
782+
S = spot.Spot(fun=fun,
783+
fun_control=fun_control,
784+
design_control=design_control,)
785+
S.run(X_start=X_start)
786+
print(f"S.X: {S.X}")
787+
print(f"S.y: {S.y}")
788+
Seed set to 123
789+
Seed set to 123
790+
spotPython tuning: 0.0 [########--] 80.00%
791+
spotPython tuning: 0.0 [#########-] 86.67%
792+
spotPython tuning: 0.0 [#########-] 93.33%
793+
spotPython tuning: 0.0 [##########] 100.00% Done...
794+
795+
S.X: [[ 0.00000000e+00 0.00000000e+00]
796+
[ 0.00000000e+00 1.00000000e+00]
797+
[ 1.00000000e+00 0.00000000e+00]
798+
[ 1.00000000e+00 1.00000000e+00]
799+
[-9.09243389e-01 -1.58234577e-01]
800+
[-2.05817107e-01 -4.81249089e-01]
801+
[ 9.49741171e-01 -9.46312716e-01]
802+
[-1.20955714e-01 6.38358863e-02]
803+
[-6.62787018e-01 1.74316373e-01]
804+
[ 2.82008441e-01 9.30010114e-01]
805+
[ 4.78788115e-01 6.53210582e-01]
806+
[ 2.64764215e-04 4.00803185e-03]
807+
[-1.66363820e-05 4.65001027e-03]
808+
[-2.60995680e-04 5.46114194e-03]
809+
[ 3.74504308e-03 1.86731890e-02]]
810+
S.y: [0.00000000e+00 1.00000000e+00 1.00000000e+00 2.00000000e+00
811+
8.51761723e-01 2.73961367e-01 1.79751605e+00 1.87053051e-02
812+
4.69672829e-01 9.44447573e-01 6.55922124e-01 1.61344194e-05
813+
2.16228723e-05 2.98921900e-05 3.62713334e-04]
814+
815+
"""
747816
self.initialize_design(X_start)
748817
# New: self.update_stats() moved here:
749818
# changed in 0.5.9:

test/test_spot_run.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import pytest
3+
from spotPython.fun.objectivefunctions import analytical
4+
from spotPython.spot import spot
5+
from spotPython.utils.init import fun_control_init, design_control_init
6+
7+
8+
@pytest.fixture
9+
def setup_spot():
10+
"""
11+
PyTest Fixture for initializing Spot with given parameters.
12+
"""
13+
ni = 7 # number of initial points
14+
ne = 20 # number of evaluations
15+
X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
16+
fun = analytical().fun_sphere
17+
fun_control = fun_control_init(lower=np.array([-1, -1]), upper=np.array([1, 1]), fun_evals=ne)
18+
design_control = design_control_init(init_size=ni)
19+
20+
S = spot.Spot(
21+
fun=fun,
22+
fun_control=fun_control,
23+
design_control=design_control,
24+
)
25+
return S, X_start
26+
27+
28+
def test_spot_run_shapes(setup_spot):
29+
"""
30+
Test the shapes of S.X and S.y after running the Spot.run method.
31+
"""
32+
S, X_start = setup_spot
33+
ne = S.fun_control["fun_evals"]
34+
exp_X_shape = (ne, 2)
35+
exp_y_shape = (ne,)
36+
37+
S.run(X_start=X_start)
38+
39+
# Assert shapes
40+
assert S.X.shape == exp_X_shape, f"Optimized S.X shape {S.X.shape} does not match expected shape {exp_X_shape}."
41+
assert S.y.shape == exp_y_shape, f"Optimized S.y shape {S.y.shape} does not match expected shape {exp_y_shape}."

0 commit comments

Comments
 (0)