|
1 | | -from typing import Any, Union, Optional |
| 1 | +from typing import Any, Union, Optional, Literal |
2 | 2 | import ast, re |
3 | 3 | import functools as ft |
4 | 4 | from datasets import load_dataset, Dataset as HGDataset |
| 5 | +import torch |
5 | 6 | from transformers import ( |
6 | 7 | AutoModelForSeq2SeqLM, |
7 | 8 | T5ForConditionalGeneration, |
@@ -79,15 +80,24 @@ class T5RelationExtractor(PipelineStep): |
    # Hugging Face Hub checkpoint used when the caller does not supply a
    # model (or model name) to __init__.
    DEFAULT_MODEL = "compnet-renard/t5-small-literary-relation-extraction"
80 | 81 |
|
81 | 82 | def __init__( |
82 | | - self, model: Optional[Union[PreTrainedModel, str]] = None, batch_size: int = 1 |
| 83 | + self, |
| 84 | + model: Optional[Union[PreTrainedModel, str]] = None, |
| 85 | + batch_size: int = 1, |
| 86 | + device: Literal["cpu", "cuda", "auto"] = "auto", |
83 | 87 | ): |
84 | 88 | self.model = T5RelationExtractor.DEFAULT_MODEL if model is None else model |
85 | 89 | self.hg_pipeline = None |
86 | 90 | self.batch_size = batch_size |
| 91 | + if device == "auto": |
| 92 | + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 93 | + else: |
| 94 | + self.device = torch.device(device) |
87 | 95 |
|
88 | 96 | def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs): |
89 | 97 | super()._pipeline_init_(lang, progress_reporter, **kwargs) |
90 | | - self.hg_pipeline = hg_pipeline("text2text-generation", model=self.model) |
| 98 | + self.hg_pipeline = hg_pipeline( |
| 99 | + "text2text-generation", model=self.model, device=self.device |
| 100 | + ) |
91 | 101 |
|
92 | 102 | def __call__( |
93 | 103 | self, sentences: list[list[str]], characters: list[Character], **kwargs |
|
0 commit comments