Skip to content

Commit 06b9460

Browse files
committed
fix batch_size bug
1 parent fd541ee commit 06b9460

7 files changed

Lines changed: 21 additions & 12 deletions

File tree

pina/_src/data/creator.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,14 @@ def _compute_batch_sizes(self, datasets):
7979
"""
8080
batch_sizes = {}
8181
if self.batching_mode == "common_batch_size":
82-
82+
8383
if self.batch_size is None:
84-
batch_size = max(dataset.length for dataset in datasets.values())
84+
batch_size = max(
85+
dataset.length for dataset in datasets.values()
86+
)
8587
else:
8688
batch_size = self.batch_size
87-
89+
8890
for name in datasets.keys():
8991
batch_sizes[name] = min(batch_size, len(datasets[name]))
9092
return batch_sizes
@@ -169,9 +171,12 @@ def __call__(self, datasets):
169171
dataloaders = {}
170172
if self.batching_mode == "common_batch_size":
171173
max_len = max(len(dataset) for dataset in datasets.values())
172-
174+
print(batch_sizes)
173175
for name, dataset in datasets.items():
174-
if self.batching_mode == "common_batch_size" and dataset.length != batch_sizes[name]:
176+
if (
177+
self.batching_mode == "common_batch_size"
178+
and dataset.length != batch_sizes[name]
179+
):
175180
dataset.max_len = max_len
176181
dataloaders[name] = self.conditions[name].create_dataloader(
177182
dataset=dataset,

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
import pytest
44

5+
56
@pytest.fixture
67
def clean_tmp_dir(tmp_path):
78
path = Path(tmp_path)
@@ -13,4 +14,4 @@ def clean_tmp_dir(tmp_path):
1314
yield path
1415

1516
if path.exists():
16-
shutil.rmtree(path)
17+
shutil.rmtree(path)

tests/test_problem.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def test_variables_correct_order_sampling():
4848
poisson_problem.input_variables
4949
)
5050

51+
5152
def test_add_points():
5253
poisson_problem = Poisson()
5354
poisson_problem.discretise_domain(1, "random", domains=["D"])
@@ -90,4 +91,4 @@ def test_wrong_custom_sampling_logic(mode):
9091

9192
# Necessary cleanup
9293
if "new" in poisson_problem.domains:
93-
del poisson_problem.domains["new"]
94+
del poisson_problem.domains["new"]

tests/test_solver/test_competitive_pinn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,4 @@ def test_train_load_restore(clean_tmp_dir, problem):
149149
)
150150
torch.testing.assert_close(
151151
new_solver.forward(test_pts), solver.forward(test_pts)
152-
)
152+
)

tests/test_solver/test_pinn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def test_train_load_restore(clean_tmp_dir, problem):
116116
)
117117
trainer.train()
118118
import os
119+
119120
print(os.listdir(f"{dir}/lightning_logs/version_0/checkpoints/"))
120121

121122
# restore
@@ -125,7 +126,6 @@ def test_train_load_restore(clean_tmp_dir, problem):
125126
+ "epoch=4-step=5.ckpt"
126127
)
127128

128-
129129
# loading
130130
new_solver = PINN.load_from_checkpoint(
131131
f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt",
@@ -138,4 +138,4 @@ def test_train_load_restore(clean_tmp_dir, problem):
138138
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
139139
torch.testing.assert_close(
140140
new_solver.forward(test_pts), solver.forward(test_pts)
141-
)
141+
)

tests/test_solver/test_rba_pinn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,4 @@ def test_train_load_restore(clean_tmp_dir, problem):
158158
)
159159
torch.testing.assert_close(
160160
new_solver.forward(test_pts), solver.forward(test_pts)
161-
)
161+
)

tests/test_solver/test_supervised_solver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def test_solver_test_graph(batch_size, use_lt):
216216
from pathlib import Path
217217
import pytest
218218

219+
219220
@pytest.fixture
220221
def clean_tmp_dir():
221222
path = Path("tests/test_solver/tmp/")
@@ -229,6 +230,7 @@ def clean_tmp_dir():
229230
if path.exists():
230231
shutil.rmtree(path)
231232

233+
232234
def test_train_load_restore(clean_tmp_dir):
233235
dir = clean_tmp_dir
234236
problem = LabelTensorProblem()
@@ -264,4 +266,4 @@ def test_train_load_restore(clean_tmp_dir):
264266
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
265267
torch.testing.assert_close(
266268
new_solver.forward(test_pts), solver.forward(test_pts)
267-
)
269+
)

0 commit comments

Comments
 (0)