chebai/models/base.py (10 changes: 3 additions & 7 deletions)

@@ -26,12 +26,8 @@ class ChebaiBaseNet(LightningModule, ABC):
         optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None.
         **kwargs: Additional keyword arguments.

-    Attributes:
-        NAME (str): The name of the model.
     """

-    NAME = None
-
     def __init__(
         self,
         criterion: torch.nn.Module = None,
@@ -88,10 +84,10 @@ def __init_subclass__(cls, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
-        if cls.NAME in _MODEL_REGISTRY:
-            raise ValueError(f"Model {cls.NAME} does already exist")
+        if cls.__name__ in _MODEL_REGISTRY:
+            raise ValueError(f"Model {cls.__name__} does already exist")
         else:
-            _MODEL_REGISTRY[cls.NAME] = cls
+            _MODEL_REGISTRY[cls.__name__] = cls

     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
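The substantive change in base.py is that model registration is now keyed on cls.__name__ rather than on a hand-maintained NAME class attribute, which is why every "NAME = ..." line in the files below disappears. A minimal sketch of the pattern, with toy BaseNet/ToyFFN classes standing in for the real ones (only _MODEL_REGISTRY and the __init_subclass__ hook mirror the diff):

from typing import Dict

# Registry keyed by cls.__name__, so subclasses need no explicit NAME attribute.
_MODEL_REGISTRY: Dict[str, type] = {}


class BaseNet:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Registration happens automatically when a subclass is defined.
        if cls.__name__ in _MODEL_REGISTRY:
            raise ValueError(f"Model {cls.__name__} already exists")
        _MODEL_REGISTRY[cls.__name__] = cls


class ToyFFN(BaseNet):  # hypothetical subclass, used only for illustration
    pass


print(_MODEL_REGISTRY)  # {'ToyFFN': <class '__main__.ToyFFN'>}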
chebai/models/chemberta.py (4 changes: 0 additions & 4 deletions)

@@ -20,8 +20,6 @@


 class ChembertaPre(ChebaiBaseNet):
-    NAME = "ChembertaPre"
-
     def __init__(self, p=0.2, **kwargs):
         super().__init__(**kwargs)
         self._p = p
@@ -47,8 +45,6 @@ def forward(self, data):


 class Chemberta(ChebaiBaseNet):
-    NAME = "Chemberta"
-
     def __init__(self, **kwargs):
         # Remove this property in order to prevent it from being stored as a
         # hyper parameter
chebai/models/chemyk.py (2 changes: 0 additions & 2 deletions)

@@ -15,8 +15,6 @@


 class ChemYK(ChebaiBaseNet):
-    NAME = "ChemYK"
-
     def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         d_internal = in_d
chebai/models/electra.py (13 changes: 1 addition & 12 deletions)

@@ -31,16 +31,13 @@ class ElectraPre(ChebaiBaseNet):
         **kwargs: Additional keyword arguments (passed to parent class).

     Attributes:
-        NAME (str): Name of the ElectraPre model.
         generator_config (ElectraConfig): Configuration for the generator model.
         generator (ElectraForMaskedLM): Generator model for masked language modeling.
         discriminator_config (ElectraConfig): Configuration for the discriminator model.
         discriminator (ElectraForPreTraining): Discriminator model for pre-training.
         replace_p (float): Probability of replacing tokens during training.
     """

-    NAME = "ElectraPre"
-
     def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
         super().__init__(config=config, **kwargs)
         self.generator_config = ElectraConfig(**config["generator"])
@@ -174,12 +171,8 @@ class Electra(ChebaiBaseNet):
         load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
         **kwargs: Additional keyword arguments.

-    Attributes:
-        NAME (str): Name of the Electra model.
     """

-    NAME = "Electra"
-
     def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
         """
         Process a batch of data.
@@ -328,7 +321,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
             inp = self.electra.embeddings.forward(data["features"].int())
         except RuntimeError as e:
             print(f"RuntimeError at forward: {e}")
-            print(f'data[features]: {data["features"]}')
+            print(f"data[features]: {data['features']}")
             raise e
         inp = self.word_dropout(inp)
         electra = self.electra(inputs_embeds=inp, **kwargs)
@@ -340,8 +333,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:


 class ElectraLegacy(ChebaiBaseNet):
-    NAME = "ElectraLeg"
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
@@ -374,8 +365,6 @@ def forward(self, data):


 class ConeElectra(ChebaiBaseNet):
-    NAME = "ConeElectra"
-
     def _process_batch(self, batch, batch_idx):
         mask = pad_sequence(
             [torch.ones(l + 1, device=self.device) for l in batch.lens],
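Aside from the NAME removals, the only edit in electra.py rewrites a debug print. For context: before Python 3.12 an f-string cannot reuse its own quote character inside the braces, so both variants must mix quote types; the new line simply standardizes on double quotes outside, single quotes inside. A tiny illustration with a made-up dictionary:

data = {"features": [1, 2, 3]}
print(f'data[features]: {data["features"]}')  # old style: single quotes outside
print(f"data[features]: {data['features']}")  # new style: double quotes outside
# Reusing the outer quote inside the braces is a SyntaxError before Python 3.12.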
chebai/models/external/__init__.py (empty file removed)
chebai/models/ffn.py (4 changes: 1 addition & 3 deletions)

@@ -9,15 +9,13 @@
 class FFN(ChebaiBaseNet):
     # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139

-    NAME = "FFN"
-
     def __init__(
         self,
         input_size: int,
         hidden_layers: List[int] = [
             1024,
         ],
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)

chebai/models/lstm.py (2 changes: 0 additions & 2 deletions)

@@ -10,8 +10,6 @@


 class ChemLSTM(ChebaiBaseNet):
-    NAME = "LSTM"
-
     def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         self.lstm = nn.LSTM(in_d, out_d, batch_first=True)
chebai/models/recursive.py (2 changes: 0 additions & 2 deletions)

@@ -11,8 +11,6 @@


 class Recursive(ChebaiBaseNet):
-    NAME = "REC"
-
     def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         mem_len = in_d