1717
1818from __future__ import annotations
1919
20- from typing import Dict , List , Literal , Optional , Union
20+ from typing import Dict , List , Literal , Optional
2121
2222import bigframes_vendored .sklearn .ensemble ._forest
2323import bigframes_vendored .xgboost .sklearn
@@ -142,8 +142,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
142142
143143 def _fit (
144144 self ,
145- X : Union [ bpd . DataFrame , bpd . Series ] ,
146- y : Union [ bpd . DataFrame , bpd . Series ] ,
145+ X : utils . ArrayType ,
146+ y : utils . ArrayType ,
147147 transforms : Optional [List [str ]] = None ,
148148 ) -> XGBRegressor :
149149 X , y = utils .convert_to_dataframe (X , y )
@@ -158,24 +158,24 @@ def _fit(
158158
159159 def predict (
160160 self ,
161- X : Union [ bpd . DataFrame , bpd . Series ] ,
161+ X : utils . ArrayType ,
162162 ) -> bpd .DataFrame :
163163 if not self ._bqml_model :
164164 raise RuntimeError ("A model must be fitted before predict" )
165- (X ,) = utils .convert_to_dataframe (X )
165+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
166166
167167 return self ._bqml_model .predict (X )
168168
169169 def score (
170170 self ,
171- X : Union [ bpd . DataFrame , bpd . Series ] ,
172- y : Union [ bpd . DataFrame , bpd . Series ] ,
171+ X : utils . ArrayType ,
172+ y : utils . ArrayType ,
173173 ):
174- X , y = utils .convert_to_dataframe (X , y )
175-
176174 if not self ._bqml_model :
177175 raise RuntimeError ("A model must be fitted before score" )
178176
177+ X , y = utils .convert_to_dataframe (X , y , session = self ._bqml_model .session )
178+
179179 input_data = (
180180 X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
181181 )
@@ -291,8 +291,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
291291
292292 def _fit (
293293 self ,
294- X : Union [ bpd . DataFrame , bpd . Series ] ,
295- y : Union [ bpd . DataFrame , bpd . Series ] ,
294+ X : utils . ArrayType ,
295+ y : utils . ArrayType ,
296296 transforms : Optional [List [str ]] = None ,
297297 ) -> XGBClassifier :
298298 X , y = utils .convert_to_dataframe (X , y )
@@ -305,22 +305,22 @@ def _fit(
305305 )
306306 return self
307307
308- def predict (self , X : Union [ bpd . DataFrame , bpd . Series ] ) -> bpd .DataFrame :
308+ def predict (self , X : utils . ArrayType ) -> bpd .DataFrame :
309309 if not self ._bqml_model :
310310 raise RuntimeError ("A model must be fitted before predict" )
311- (X ,) = utils .convert_to_dataframe (X )
311+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
312312
313313 return self ._bqml_model .predict (X )
314314
315315 def score (
316316 self ,
317- X : Union [ bpd . DataFrame , bpd . Series ] ,
318- y : Union [ bpd . DataFrame , bpd . Series ] ,
317+ X : utils . ArrayType ,
318+ y : utils . ArrayType ,
319319 ):
320320 if not self ._bqml_model :
321321 raise RuntimeError ("A model must be fitted before score" )
322322
323- X , y = utils .convert_to_dataframe (X , y )
323+ X , y = utils .convert_to_dataframe (X , y , session = self . _bqml_model . session )
324324
325325 input_data = (
326326 X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
@@ -427,8 +427,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
427427
428428 def _fit (
429429 self ,
430- X : Union [ bpd . DataFrame , bpd . Series ] ,
431- y : Union [ bpd . DataFrame , bpd . Series ] ,
430+ X : utils . ArrayType ,
431+ y : utils . ArrayType ,
432432 transforms : Optional [List [str ]] = None ,
433433 ) -> RandomForestRegressor :
434434 X , y = utils .convert_to_dataframe (X , y )
@@ -443,18 +443,18 @@ def _fit(
443443
444444 def predict (
445445 self ,
446- X : Union [ bpd . DataFrame , bpd . Series ] ,
446+ X : utils . ArrayType ,
447447 ) -> bpd .DataFrame :
448448 if not self ._bqml_model :
449449 raise RuntimeError ("A model must be fitted before predict" )
450- (X ,) = utils .convert_to_dataframe (X )
450+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
451451
452452 return self ._bqml_model .predict (X )
453453
454454 def score (
455455 self ,
456- X : Union [ bpd . DataFrame , bpd . Series ] ,
457- y : Union [ bpd . DataFrame , bpd . Series ] ,
456+ X : utils . ArrayType ,
457+ y : utils . ArrayType ,
458458 ):
459459 """Calculate evaluation metrics of the model.
460460
@@ -476,7 +476,7 @@ def score(
476476 if not self ._bqml_model :
477477 raise RuntimeError ("A model must be fitted before score" )
478478
479- X , y = utils .convert_to_dataframe (X , y )
479+ X , y = utils .convert_to_dataframe (X , y , session = self . _bqml_model . session )
480480
481481 input_data = (
482482 X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
@@ -583,8 +583,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
583583
584584 def _fit (
585585 self ,
586- X : Union [ bpd . DataFrame , bpd . Series ] ,
587- y : Union [ bpd . DataFrame , bpd . Series ] ,
586+ X : utils . ArrayType ,
587+ y : utils . ArrayType ,
588588 transforms : Optional [List [str ]] = None ,
589589 ) -> RandomForestClassifier :
590590 X , y = utils .convert_to_dataframe (X , y )
@@ -599,18 +599,18 @@ def _fit(
599599
600600 def predict (
601601 self ,
602- X : Union [ bpd . DataFrame , bpd . Series ] ,
602+ X : utils . ArrayType ,
603603 ) -> bpd .DataFrame :
604604 if not self ._bqml_model :
605605 raise RuntimeError ("A model must be fitted before predict" )
606- (X ,) = utils .convert_to_dataframe (X )
606+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
607607
608608 return self ._bqml_model .predict (X )
609609
610610 def score (
611611 self ,
612- X : Union [ bpd . DataFrame , bpd . Series ] ,
613- y : Union [ bpd . DataFrame , bpd . Series ] ,
612+ X : utils . ArrayType ,
613+ y : utils . ArrayType ,
614614 ):
615615 """Calculate evaluation metrics of the model.
616616
@@ -632,7 +632,7 @@ def score(
632632 if not self ._bqml_model :
633633 raise RuntimeError ("A model must be fitted before score" )
634634
635- X , y = utils .convert_to_dataframe (X , y )
635+ X , y = utils .convert_to_dataframe (X , y , session = self . _bqml_model . session )
636636
637637 input_data = (
638638 X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
0 commit comments