@@ -170,34 +170,15 @@ def __init__(
170170 # self.dropout1 = nn.Dropout(dropout[0])
171171 # self.dropout2 = nn.Dropout(dropout[1])
172172 # self.dropout3 = nn.Dropout(dropout[2])
173- # # TODO: use different dropout for different layers
173+ # # TODO: use enhanced dropout management for different layers
174174 self .dropout1 = nn .Dropout (self .hparams .dropout_prob )
175175 self .dropout2 = nn .Dropout (self .hparams .dropout_prob // 10.0 )
176176 self .dropout3 = nn .Dropout (self .hparams .dropout_prob // 100.0 )
177177
178- activation_fct = nn .ReLU ()
179- self .activation_fct = activation_fct
180- # self.activation_fct = self.hparams.act_fn
181-
182- # ###########################################
183- # old:
184- # if self.hparams.l1 < 4:
185- # raise ValueError("l1 must be at least 4")
186- # hidden_sizes = [self.hparams.l1, self.hparams.l1 // 2, self.hparams.l1 // 2, self.hparams.l1 // 4]
187- # Create the network based on the specified hidden sizes
188- # layers = []
189- # layer_sizes = [self._L_in] + hidden_sizes
190- # layer_size_last = layer_sizes[0]
191- # for layer_size in layer_sizes[1:]:
192- # layers += [
193- # nn.Linear(layer_size_last, layer_size),
194- # self.hparams.act_fn,
195- # nn.Dropout(self.hparams.dropout_prob),
196- # ]
197- # layer_size_last = layer_size
198- # layers += [nn.Linear(layer_sizes[-1], self._L_out)]
199- # # nn.Sequential summarizes a list of modules into a single module, applying them in sequence
200- # self.layers = nn.Sequential(*layers)
178+ # TODO: Enable different activation functions
179+ # activation_fct = nn.ReLU()
180+ # self.activation_fct = activation_fct
181+ self .activation_fct = self .hparams .act_fn
201182
202183 def forward (self , x : torch .Tensor ) -> torch .Tensor :
203184 """
@@ -227,22 +208,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
227208 # print(f"output_layer: {x.shape}")
228209 return x
229210
230- # old:
231- # x = self.layers(x)
232- # # check if the number of columns in x is 1, otherwise throw an error
233- # try:
234- # assert x.shape[1] == 1
235- # except AssertionError:
236- # print(f"forward x.shape: {x.shape}")
237- # raise AssertionError("Number of columns in x is not 1.")
238- # return x
239-
240- def training_step (self , batch : tuple ) -> torch .Tensor :
211+ def training_step (self , batch : tuple , prog_bar : bool = False ) -> torch .Tensor :
241212 """
242213 Performs a single training step.
243214
244215 Args:
245216 batch (tuple): A tuple containing a batch of input data and labels.
217+ prog_bar (bool, optional): Whether to display the progress bar. Defaults to False.
246218
247219 Returns:
248220 torch.Tensor: A tensor containing the loss for this batch.
@@ -251,26 +223,14 @@ def training_step(self, batch: tuple) -> torch.Tensor:
251223 x , y = batch
252224 # reshape the tensor y to be a column vector (len(y) rows and 1 column)
253225 y = y .view (len (y ), 1 )
254- # check if the number of rows in x is equal to the number of rows in y, otherwise throw an error
255- try :
256- assert x .shape [0 ] == y .shape [0 ]
257- except AssertionError :
258- print (f"training_step x.shape: { x .shape } " )
259- print (f"training_step y.shape: { y .shape } " )
260- raise AssertionError ("Number of rows in x and y must be equal" )
226+ # Note: the number of rows in x is equal to the number of rows in y
261227 y_hat = self (x )
262- # check if the number of rows in y_hat is equal to the number of rows in y, otherwise throw an error
263- try :
264- assert y_hat .shape [0 ] == y .shape [0 ]
265- except AssertionError :
266- print (f"training_step y_hat.shape: { y_hat .shape } " )
267- print (f"training_step y.shape: { y .shape } " )
268- raise AssertionError ("Number of rows in y_hat and y must be equal" )
269- val_loss = F .mse_loss (y_hat , y )
228+ # Note: the number of rows in y_hat is equal to the number of rows in y
229+ train_loss = F .mse_loss (y_hat , y )
270230 # mae_loss = F.l1_loss(y_hat, y)
271231 # self.log("train_loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)
272232 # self.log("train_mae_loss", mae_loss, on_step=True, on_epoch=True, prog_bar=True)
273- return val_loss
233+ return train_loss
274234
275235 def validation_step (self , batch : tuple , batch_idx : int , prog_bar : bool = False ) -> torch .Tensor :
276236 """
0 commit comments