11import logging
22import numpy as np
33from numpy .random import default_rng
4- from numpy import array
54from spotPython .light .traintest import train_model
6- from spotPython .hyperparameters .values import (
7- assign_values ,
8- generate_one_config_from_var_dict ,
9- )
5+ from spotPython .hyperparameters .values import assign_values , generate_one_config_from_var_dict , get_var_name
106
117logger = logging .getLogger (__name__ )
128py_handler = logging .FileHandler (f"{ __name__ } .log" , mode = "w" )
@@ -32,37 +28,25 @@ class HyperLight:
3228 Examples:
3329 >>> hyper_light = HyperLight(seed=126, log_level=50)
3430 >>> print(hyper_light.seed)
35- 126
31+ 126
3632 """
3733
3834 def __init__ (self , seed : int = 126 , log_level : int = 50 ) -> None :
3935 self .seed = seed
4036 self .rng = default_rng (seed = self .seed )
41- self .fun_control = {
42- "seed" : None ,
43- "data" : None ,
44- "step" : 10_000 ,
45- "horizon" : None ,
46- "grace_period" : None ,
47- "metric_river" : None ,
48- "metric_sklearn" : None ,
49- "weights" : array ([1 , 0 , 0 ]),
50- "weight_coeff" : 0.0 ,
51- "log_level" : log_level ,
52- "var_name" : [],
53- "var_type" : [],
54- }
55- self .log_level = self .fun_control ["log_level" ]
56- logger .setLevel (self .log_level )
57- logger .info (f"Starting the logger at level { self .log_level } for module { __name__ } :" )
58-
59- def check_X_shape (self , X : np .ndarray ) -> np .ndarray :
37+ self .log_level = log_level
38+ logger .setLevel (log_level )
39+ logger .info (f"Starting the logger at level { log_level } for module { __name__ } :" )
40+
41+ def check_X_shape (self , X : np .ndarray , fun_control : dict ) -> np .ndarray :
6042 """
6143 Checks the shape of the input array X and raises an exception if it is not valid.
6244
6345 Args:
6446 X (np.ndarray):
6547 input array.
48+ fun_control (dict):
49+ dictionary containing control parameters for the hyperparameter tuning.
6650
6751 Returns:
6852 np.ndarray:
@@ -73,17 +57,31 @@ def check_X_shape(self, X: np.ndarray) -> np.ndarray:
7357 if the shape of the input array is not valid.
7458
7559 Examples:
76- >>> hyper_light = HyperLight(seed=126, log_level=50)
77- >>> X = np.array([[1, 2], [3, 4]])
78- >>> hyper_light.check_X_shape(X)
79- array([[1, 2],
80- [3, 4]])
60+ >>> import numpy as np
61+ from spotPython.utils.init import fun_control_init
62+ from spotPython.light.netlightregression import NetLightRegression
63+ from spotPython.hyperdict.light_hyper_dict import LightHyperDict
64+ from spotPython.hyperparameters.values import add_core_model_to_fun_control
65+ from spotPython.fun.hyperlight import HyperLight
66+ from spotPython.hyperparameters.values import get_var_name
67+ fun_control = fun_control_init()
68+ add_core_model_to_fun_control(core_model=NetLightRegression,
69+ fun_control=fun_control,
70+ hyper_dict=LightHyperDict)
71+ hyper_light = HyperLight(seed=126, log_level=50)
72+ n_hyperparams = len(get_var_name(fun_control))
73+ # generate a random np.array X with shape (2, n_hyperparams)
74+ X = np.random.rand(2, n_hyperparams)
75+ X == hyper_light.check_X_shape(X, fun_control)
76+ array([[ True, True, True, True, True, True, True, True, True],
77+ [ True, True, True, True, True, True, True, True, True]])
78+
8179 """
8280 try :
8381 X .shape [1 ]
8482 except ValueError :
8583 X = np .array ([X ])
86- if X .shape [1 ] != len (self . fun_control [ "var_name" ] ):
84+ if X .shape [1 ] != len (get_var_name ( fun_control ) ):
8785 raise Exception ("Invalid shape of input array X." )
8886 return X
8987
@@ -102,30 +100,51 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
102100 array containing the evaluation results.
103101
104102 Examples:
105- >>> hyper_light = HyperLight(seed=126, log_level=50)
106- X = np.array([[1, 2], [3, 4]])
107- fun_control = {"weights": np.array([1, 0, 0])}
103+ >>> from spotPython.utils.init import fun_control_init
104+ from spotPython.light.netlightregression import NetLightRegression
105+ from spotPython.hyperdict.light_hyper_dict import LightHyperDict
106+ from spotPython.hyperparameters.values import
107+ (add_core_model_to_fun_control,
108+ get_default_hyperparameters_as_array)
109+ from spotPython.fun.hyperlight import HyperLight
110+ from spotPython.data.diabetes import Diabetes
111+ from spotPython.hyperparameters.values import set_data_set
112+ import numpy as np
113+ fun_control = fun_control_init(
114+ _L_in=10,
115+ _L_out=1,)
116+
117+ dataset = Diabetes()
118+ set_data_set(fun_control=fun_control,
119+ data_set=dataset)
120+
121+ add_core_model_to_fun_control(core_model=NetLightRegression,
122+ fun_control=fun_control,
123+ hyper_dict=LightHyperDict)
124+ hyper_light = HyperLight(seed=126, log_level=50)
125+ X = get_default_hyperparameters_as_array(fun_control)
126+ # combine X and X to a np.array with shape (2, n_hyperparams)
127+ # so that two values are returned
128+ X = np.vstack((X, X))
108129 hyper_light.fun(X, fun_control)
109- array([nan, nan ])
130+ array([27462.84179688, 20990.08007812 ])
110131 """
111132 z_res = np .array ([], dtype = float )
112- if fun_control is not None :
113- self .fun_control .update (fun_control )
114- self .check_X_shape (X )
115- var_dict = assign_values (X , self .fun_control ["var_name" ])
133+ self .check_X_shape (X = X , fun_control = fun_control )
134+ var_dict = assign_values (X , get_var_name (fun_control ))
116135 # type information and transformations are considered in generate_one_config_from_var_dict:
117- for config in generate_one_config_from_var_dict (var_dict , self . fun_control ):
136+ for config in generate_one_config_from_var_dict (var_dict , fun_control ):
118137 logger .debug (f"\n config: { config } " )
119138 # extract parameters like epochs, batch_size, lr, etc. from config
120139 # config_id = generate_config_id(config)
121140 try :
122- print ("fun: Calling train_model" )
123- df_eval = train_model (config , self . fun_control )
124- print ("fun: train_model returned" )
141+ logger . debug ("fun: Calling train_model" )
142+ df_eval = train_model (config , fun_control )
143+ logger . debug ("fun: train_model returned" )
125144 except Exception as err :
126145 logger .error (f"Error in fun(). Call to train_model failed. { err = } , { type (err )= } " )
127146 logger .error ("Setting df_eval to np.nan" )
128147 df_eval = np .nan
129- z_val = self . fun_control ["weights" ] * df_eval
148+ z_val = fun_control ["weights" ] * df_eval
130149 z_res = np .append (z_res , z_val )
131150 return z_res
0 commit comments