From cc80c7325d04c22f662baa2a526c27b7ab7a17e8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 11:51:13 +0200 Subject: [PATCH 1/4] perform implicit model registration --- chebai/models/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index cb254570..2c7f30cf 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -88,10 +88,10 @@ def __init_subclass__(cls, **kwargs): Args: **kwargs: Additional keyword arguments. """ - if cls.NAME in _MODEL_REGISTRY: - raise ValueError(f"Model {cls.NAME} does already exist") + if cls.__name__ in _MODEL_REGISTRY: + raise ValueError(f"Model {cls.__name__} does already exist") else: - _MODEL_REGISTRY[cls.NAME] = cls + _MODEL_REGISTRY[cls.__name__] = cls def _get_prediction_and_labels( self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor From b33015f8b45785c205a55eb99b767a04cd9b3a80 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 11:55:10 +0200 Subject: [PATCH 2/4] remove explicit NAME class var --- chebai/models/base.py | 4 ---- chebai/models/chemberta.py | 4 ---- chebai/models/chemyk.py | 2 -- chebai/models/electra.py | 13 +------------ chebai/models/ffn.py | 4 +--- chebai/models/lstm.py | 2 -- 6 files changed, 2 insertions(+), 27 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 2c7f30cf..1ca8d7ec 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -26,12 +26,8 @@ class ChebaiBaseNet(LightningModule, ABC): optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None. **kwargs: Additional keyword arguments. - Attributes: - NAME (str): The name of the model. 
""" - NAME = None - def __init__( self, criterion: torch.nn.Module = None, diff --git a/chebai/models/chemberta.py b/chebai/models/chemberta.py index b601542a..3fb6dcd8 100644 --- a/chebai/models/chemberta.py +++ b/chebai/models/chemberta.py @@ -20,8 +20,6 @@ class ChembertaPre(ChebaiBaseNet): - NAME = "ChembertaPre" - def __init__(self, p=0.2, **kwargs): super().__init__(**kwargs) self._p = p @@ -47,8 +45,6 @@ def forward(self, data): class Chemberta(ChebaiBaseNet): - NAME = "Chemberta" - def __init__(self, **kwargs): # Remove this property in order to prevent it from being stored as a # hyper parameter diff --git a/chebai/models/chemyk.py b/chebai/models/chemyk.py index 13bbea7c..8aa05ff3 100644 --- a/chebai/models/chemyk.py +++ b/chebai/models/chemyk.py @@ -15,8 +15,6 @@ class ChemYK(ChebaiBaseNet): - NAME = "ChemYK" - def __init__(self, in_d, out_d, num_classes, **kwargs): super().__init__(num_classes, **kwargs) d_internal = in_d diff --git a/chebai/models/electra.py b/chebai/models/electra.py index dc6c719b..dd3b38a0 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -31,7 +31,6 @@ class ElectraPre(ChebaiBaseNet): **kwargs: Additional keyword arguments (passed to parent class). Attributes: - NAME (str): Name of the ElectraPre model. generator_config (ElectraConfig): Configuration for the generator model. generator (ElectraForMaskedLM): Generator model for masked language modeling. discriminator_config (ElectraConfig): Configuration for the discriminator model. @@ -39,8 +38,6 @@ class ElectraPre(ChebaiBaseNet): replace_p (float): Probability of replacing tokens during training. """ - NAME = "ElectraPre" - def __init__(self, config: Dict[str, Any] = None, **kwargs: Any): super().__init__(config=config, **kwargs) self.generator_config = ElectraConfig(**config["generator"]) @@ -174,12 +171,8 @@ class Electra(ChebaiBaseNet): load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None. 
**kwargs: Additional keyword arguments. - Attributes: - NAME (str): Name of the Electra model. """ - NAME = "Electra" - def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]: """ Process a batch of data. @@ -328,7 +321,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: inp = self.electra.embeddings.forward(data["features"].int()) except RuntimeError as e: print(f"RuntimeError at forward: {e}") - print(f'data[features]: {data["features"]}') + print(f"data[features]: {data['features']}") raise e inp = self.word_dropout(inp) electra = self.electra(inputs_embeds=inp, **kwargs) @@ -340,8 +333,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: class ElectraLegacy(ChebaiBaseNet): - NAME = "ElectraLeg" - def __init__(self, **kwargs): super().__init__(**kwargs) self.config = ElectraConfig(**kwargs["config"], output_attentions=True) @@ -374,8 +365,6 @@ def forward(self, data): class ConeElectra(ChebaiBaseNet): - NAME = "ConeElectra" - def _process_batch(self, batch, batch_idx): mask = pad_sequence( [torch.ones(l + 1, device=self.device) for l in batch.lens], diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index c9c6f912..70641615 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -9,15 +9,13 @@ class FFN(ChebaiBaseNet): # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139 - NAME = "FFN" - def __init__( self, input_size: int, hidden_layers: List[int] = [ 1024, ], - **kwargs + **kwargs, ): super().__init__(**kwargs) diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index c706d6a5..f3431a71 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -10,8 +10,6 @@ class ChemLSTM(ChebaiBaseNet): - NAME = "LSTM" - def __init__(self, in_d, out_d, num_classes, **kwargs): super().__init__(num_classes, **kwargs) self.lstm = nn.LSTM(in_d, out_d, batch_first=True) From 10825fd929cd8564d9d10c8e2f419e64c04026a8 
Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 11:55:36 +0200 Subject: [PATCH 3/4] remove redundant file --- chebai/models/external/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 chebai/models/external/__init__.py diff --git a/chebai/models/external/__init__.py b/chebai/models/external/__init__.py deleted file mode 100644 index e69de29b..00000000 From 6fd95916c1c3d7d051375c75b2fe922cfc8638ca Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 11:58:58 +0200 Subject: [PATCH 4/4] Update recursive.py --- chebai/models/recursive.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/chebai/models/recursive.py b/chebai/models/recursive.py index fb408039..854e4326 100644 --- a/chebai/models/recursive.py +++ b/chebai/models/recursive.py @@ -11,8 +11,6 @@ class Recursive(ChebaiBaseNet): - NAME = "REC" - def __init__(self, in_d, out_d, num_classes, **kwargs): super().__init__(num_classes, **kwargs) mem_len = in_d