@@ -69,7 +69,7 @@ def test_forward_pass(nn_linear_regressor):
6969
7070
7171def test_training_step (nn_linear_regressor ):
72- trainer = L .Trainer (max_epochs = 1 , enable_checkpointing = False )
72+ trainer = L .Trainer (max_epochs = 1 , enable_checkpointing = False , accelerator = "cpu" )
7373 train_loader = data_loader
7474 trainer .fit (nn_linear_regressor , train_loader )
7575 batch_x , batch_y = next (iter (train_loader ))
@@ -83,7 +83,7 @@ def test_training_step(nn_linear_regressor):
8383
8484
8585def test_validation_step (nn_linear_regressor ):
86- trainer = L .Trainer (max_epochs = 1 , enable_checkpointing = False )
86+ trainer = L .Trainer (max_epochs = 1 , enable_checkpointing = False , accelerator = "cpu" )
8787 val_loader = data_loader
8888 trainer .validate (nn_linear_regressor , val_loader )
8989
@@ -98,7 +98,7 @@ def test_validation_step(nn_linear_regressor):
9898
9999
100100def test_testing_step (nn_linear_regressor ):
101- trainer = L .Trainer (max_epochs = 1 , enable_checkpointing = False )
101+ trainer = L .Trainer (max_epochs = 1 , enable_checkpointing = False , accelerator = "cpu" )
102102 test_loader = data_loader
103103 trainer .test (nn_linear_regressor , test_loader )
104104
0 commit comments