diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py index c8846273..b61ca80c 100644 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -85,8 +85,10 @@ def test_get_test_split_valid(self) -> None: """ Test splitting the dataset into train and test sets and verify balance and non-overlap. """ - self.dataset.train_split = 0.5 + # self.dataset.train_split = 0.5 # Test size will be 0.25 * 16 = 4 + self.dataset.test_split = 0.25 + self.dataset.validation_split = 0.25 train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) # Assert the correct number of rows in train and test sets @@ -149,7 +151,9 @@ def test_get_train_val_splits_given_test(self) -> None: Test splitting the dataset into train and validation sets and verify balance and non-overlap. """ self.dataset.use_inner_cross_validation = False - self.dataset.train_split = 0.5 + # self.dataset.train_split = 0.5 + self.dataset.test_split = 0.25 + self.dataset.validation_split = 0.25 df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) train_df, val_df = self.dataset.get_train_val_splits_given_test( df_train_main, test_df, seed=42 @@ -220,7 +224,9 @@ def test_get_test_split_stratification(self) -> None: """ Test that the split into train and test sets maintains the stratification of labels. """ - self.dataset.train_split = 0.5 + # self.dataset.train_split = 0.5 + self.dataset.test_split = 0.25 + self.dataset.validation_split = 0.25 train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) number_of_labels = len(self.data_df["labels"][0]) @@ -288,7 +294,10 @@ def test_get_train_val_splits_given_test_stratification(self) -> None: Test that the split into train and validation sets maintains the stratification of labels. """ self.dataset.use_inner_cross_validation = False - self.dataset.train_split = 0.5 + # self.dataset.train_split = 0.5 + self.dataset.test_split = 0.25 + self.dataset.validation_split = 0.25 + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) train_df, val_df = self.dataset.get_train_val_splits_given_test( df_train_main, test_df, seed=42