
Commit da1e9ea

add T5RelationExtractor GPU support
1 parent: ecd793a

1 file changed (+13, -3)

renard/pipeline/relation_extraction.py

Lines changed: 13 additions & 3 deletions
@@ -1,7 +1,8 @@
-from typing import Any, Union, Optional
+from typing import Any, Union, Optional, Literal
 import ast, re
 import functools as ft
 from datasets import load_dataset, Dataset as HGDataset
+import torch
 from transformers import (
     AutoModelForSeq2SeqLM,
     T5ForConditionalGeneration,
@@ -79,15 +80,24 @@ class T5RelationExtractor(PipelineStep):
     DEFAULT_MODEL = "compnet-renard/t5-small-literary-relation-extraction"
 
     def __init__(
-        self, model: Optional[Union[PreTrainedModel, str]] = None, batch_size: int = 1
+        self,
+        model: Optional[Union[PreTrainedModel, str]] = None,
+        batch_size: int = 1,
+        device: Literal["cpu", "cuda", "auto"] = "auto",
     ):
         self.model = T5RelationExtractor.DEFAULT_MODEL if model is None else model
         self.hg_pipeline = None
         self.batch_size = batch_size
+        if device == "auto":
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = torch.device(device)
 
     def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
         super()._pipeline_init_(lang, progress_reporter, **kwargs)
-        self.hg_pipeline = hg_pipeline("text2text-generation", model=self.model)
+        self.hg_pipeline = hg_pipeline(
+            "text2text-generation", model=self.model, device=self.device
+        )
 
     def __call__(
         self, sentences: list[list[str]], characters: list[Character], **kwargs
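
For illustration, a minimal usage sketch of the new device argument, based only on the constructor signature shown in this diff; the import path follows the changed file, and how the step is then wired into a full Renard pipeline is not shown by this commit.

    from renard.pipeline.relation_extraction import T5RelationExtractor

    # Explicitly request the GPU; "auto" (the default) picks CUDA when it is
    # available and falls back to CPU otherwise, per the __init__ logic above.
    extractor = T5RelationExtractor(batch_size=4, device="cuda")

    # Leaving device="auto" keeps the previous CPU behaviour on machines
    # without a CUDA device.
    auto_extractor = T5RelationExtractor()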
