@@ -35,7 +35,7 @@ def fun_control_init(
3535 db_dict_name = None ,
3636 design = None ,
3737 device = None ,
38- devices = 1 ,
38+ devices = "auto" ,
3939 enable_progress_bar = False ,
4040 EXPERIMENT_NAME = None ,
4141 eval = None ,
@@ -55,9 +55,11 @@ def fun_control_init(
5555 n_samples = None ,
5656 n_total = None ,
5757 num_workers = 0 ,
58+ num_nodes = 1 ,
5859 ocba_delta = 0 ,
5960 oml_grace_period = None ,
6061 optimizer = None ,
62+ precision = "32" ,
6163 prep_model = None ,
6264 prep_model_name = None ,
6365 progress_file = None ,
@@ -70,6 +72,7 @@ def fun_control_init(
7072 show_progress = True ,
7173 shuffle = None ,
7274 sigma = 0.0 ,
75+ strategy = "auto" ,
7376 surrogate = None ,
7477 target_column = None ,
7578 target_type = None ,
@@ -181,6 +184,8 @@ def fun_control_init(
181184 The number of samples in the dataset. Default is None.
182185 n_total (int):
183186 The total number of samples in the dataset. Default is None.
187+ num_nodes (int):
188+ The number of GPU nodes to use for the training/validation/testing. Default is 1.
184189 num_workers (int):
185190 The number of workers to use for the data loading. Default is 0.
186191 ocba_delta (int):
@@ -190,6 +195,8 @@ def fun_control_init(
190195 The grace period for the OML algorithm. Default is None.
191196 optimizer (object):
192197 The optimizer object used for the search on surrogate. Default is None.
198+ precision (str):
199+ The precision of the data. Default is "32". Can be e.g., "16-mixed" or "16-true".
193200 PREFIX (str):
194201 The prefix of the experiment name. If the PREFIX is not None, a spotWriter
195202 that us an instance of a SummaryWriter(), is created. Default is "00".
@@ -221,6 +228,8 @@ def fun_control_init(
221228 Whether the data were shuffled or not. Default is None.
222229 surrogate (object):
223230 The surrogate model object. Default is None.
231+ strategy (str):
232+ The strategy to use. Default is "auto".
224233 target_column (str):
225234 The name of the target column. Default is None.
226235 target_type (str):
@@ -393,11 +402,13 @@ def fun_control_init(
393402 "n_points" : n_points ,
394403 "n_samples" : n_samples ,
395404 "n_total" : n_total ,
405+ "num_nodes" : num_nodes ,
396406 "num_workers" : num_workers ,
397407 "ocba_delta" : ocba_delta ,
398408 "oml_grace_period" : oml_grace_period ,
399409 "optimizer" : optimizer ,
400410 "path" : None ,
411+ "precision" : precision ,
401412 "prep_model" : prep_model ,
402413 "prep_model_name" : prep_model_name ,
403414 "progress_file" : progress_file ,
@@ -413,6 +424,7 @@ def fun_control_init(
413424 "shuffle" : shuffle ,
414425 "sigma" : sigma ,
415426 "spot_tensorboard_path" : spot_tensorboard_path ,
427+ "strategy" : strategy ,
416428 "target_column" : target_column ,
417429 "target_type" : target_type ,
418430 "task" : task ,
0 commit comments