@@ -34,10 +34,12 @@ class PKLDataset(Dataset):
3434 The directory where the pkl file is located.
3535 feature_type (torch.dtype):
3636 The data type of the features.
37+ Defaults to torch.float.
3738 target_column (str):
3839 The name of the target column.
3940 target_type (torch.dtype):
4041 The data type of the targets.
42+ Defaults to torch.float.
4143 train (bool):
4244 Whether the dataset is for training or not.
4345 rmNA (bool):
@@ -73,7 +75,7 @@ class PKLDataset(Dataset):
7375 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7476 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
7577 [1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
76- 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 , 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,
78+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 , 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,
7779 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
7880 [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
7981 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
@@ -104,9 +106,11 @@ def __init__(
104106 directory : None = None ,
105107 feature_type : torch .dtype = torch .float ,
106108 target_column : str = "y" ,
107- target_type : torch .dtype = torch .long ,
109+ target_type : torch .dtype = torch .float ,
108110 train : bool = True ,
109111 rmNA = True ,
112+ oe = OrdinalEncoder (),
113+ le = LabelEncoder (),
110114 ** desc ,
111115 ) -> None :
112116 super ().__init__ ()
@@ -117,16 +121,15 @@ def __init__(
117121 self .target_column = target_column
118122 self .train = train
119123 self .rmNA = rmNA
124+ self .oe = oe
125+ self .le = le
120126 self .data , self .targets = self ._load_data ()
121127
@property
def path(self):
    """Full path of the pkl file.

    Returns:
        pathlib.Path:
            ``directory / filename`` when a (truthy) directory was given,
            otherwise ``filename`` located next to this module.
    """
    # Pick the base directory first, then append the file name once.
    if self.directory:
        base_dir = pathlib.Path(self.directory)
    else:
        base_dir = pathlib.Path(__file__).parent
    return base_dir.joinpath(self.filename)
130133
131134 @property
132135 def _repr_content (self ):
@@ -135,26 +138,37 @@ def _repr_content(self):
135138 return content
136139
def _load_data(self) -> tuple:
    """Load the pickled DataFrame and convert it to feature/target tensors.

    Non-numerical feature columns are encoded with ``self.oe`` and a
    non-numerical target column is encoded with ``self.le``; numerical
    columns are passed through unchanged.

    Returns:
        tuple:
            ``(feature_tensor, target_tensor)`` with dtypes
            ``self.feature_type`` and ``self.target_type``.

    Raises:
        ValueError:
            If ``self.target_type`` and ``self.feature_type`` differ.
    """
    # ensure that self.target_type and self.feature_type are the same torch types
    if self.target_type != self.feature_type:
        raise ValueError("target_type and feature_type must be the same torch type")

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(self.path, "rb") as f:
        df = pd.read_pickle(f)

    # rm rows with NA
    if self.rmNA:
        df = df.dropna()

    # Split DataFrame into feature and target parts
    feature_df = df.drop(columns=[self.target_column])

    # Apply OrdinalEncoder to non-numerical feature columns only
    non_numerical_columns = feature_df.select_dtypes(exclude=["number"]).columns.tolist()
    if non_numerical_columns:
        feature_df[non_numerical_columns] = self.oe.fit_transform(feature_df[non_numerical_columns])

    target = df[self.target_column]
    if pd.api.types.is_numeric_dtype(target):
        # Convert the Series to a NumPy array: handing a Series to
        # torch.tensor makes torch index it by *label*, which fails once
        # dropna() has removed rows and the index is no longer 0..n-1.
        target_array = target.to_numpy()
    else:
        # LabelEncoder already returns a NumPy array.
        target_array = self.le.fit_transform(target)

    feature_tensor = torch.tensor(feature_df.to_numpy(), dtype=self.feature_type)
    target_tensor = torch.tensor(target_array, dtype=self.target_type)

    return feature_tensor, target_tensor
160174
@@ -214,3 +228,20 @@ def extra_repr(self) -> str:
214228 print(dataset)
215229 """
216230 return "filename={}, directory={}" .format (self .filename , self .directory )
231+
def __ncols__(self) -> int:
    """
    Returns the number of feature columns, i.e. the size of the second
    dimension of the feature tensor ``self.data``.

    Returns:
        int: The number of columns in the dataset.

    Examples:
        >>> from spotPython.data.pkldataset import PKLDataset
        >>> import torch
        >>> # feature_type and target_type must match (checked in _load_data)
        >>> dataset = PKLDataset(target_column='prognosis',
        ...                      feature_type=torch.long,
        ...                      target_type=torch.long)
        >>> print(dataset.__ncols__())
        64
    """
    return self.data.shape[1]
0 commit comments