Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
while not result.ready():
try:
progress.update(queue.get(timeout=1))
except:
except Exception:
pass

# remaining items
Expand Down Expand Up @@ -819,7 +819,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
while not result.ready():
try:
progress.update(queue.get(timeout=1))
except:
except Exception:
pass

# remaining items
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/datasets/mimic3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
**kwargs,
) -> None:
"""
Initializes the MIMIC4Dataset with the specified parameters.
Initializes the MIMIC3Dataset with the specified parameters.

Args:
root (str): The root directory where the dataset is stored.
Expand Down
6 changes: 3 additions & 3 deletions pyhealth/models/adacare.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def __init__(

def forward(
self,
x: torch.tensor,
mask: Optional[torch.tensor] = None,
) -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward propagation.

Args:
Expand Down
26 changes: 13 additions & 13 deletions pyhealth/models/gamenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def reset_parameters(self):
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

def forward(self, input: torch.tensor, adj: torch.tensor) -> torch.tensor:
def forward(self, input: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
"""
Args:
input: input feature tensor of shape [num_nodes, in_features].
Expand Down Expand Up @@ -74,7 +74,7 @@ class GCN(nn.Module):
dropout: dropout rate. Default is 0.5.
"""

def __init__(self, adj: torch.tensor, hidden_size: int, dropout: float = 0.5):
def __init__(self, adj: torch.Tensor, hidden_size: int, dropout: float = 0.5):
super(GCN, self).__init__()
self.emb_dim = hidden_size
self.dropout = dropout
Expand All @@ -89,7 +89,7 @@ def __init__(self, adj: torch.tensor, hidden_size: int, dropout: float = 0.5):
self.dropout_layer = nn.Dropout(p=dropout)
self.gcn2 = GCNLayer(hidden_size, hidden_size)

def normalize(self, mx: torch.tensor) -> torch.tensor:
def normalize(self, mx: torch.Tensor) -> torch.Tensor:
"""Normalizes the matrix row-wise."""
rowsum = mx.sum(1)
r_inv = torch.pow(rowsum, -1).flatten()
Expand All @@ -98,7 +98,7 @@ def normalize(self, mx: torch.tensor) -> torch.tensor:
mx = torch.mm(r_mat_inv, mx)
return mx

def forward(self) -> torch.tensor:
def forward(self) -> torch.Tensor:
"""Forward propagation.

Returns:
Expand Down Expand Up @@ -144,8 +144,8 @@ class GAMENetLayer(nn.Module):
def __init__(
self,
hidden_size: int,
ehr_adj: torch.tensor,
ddi_adj: torch.tensor,
ehr_adj: torch.Tensor,
ddi_adj: torch.Tensor,
dropout: float = 0.5,
):
super(GAMENetLayer, self).__init__()
Expand All @@ -163,11 +163,11 @@ def __init__(

def forward(
self,
queries: torch.tensor,
prev_drugs: torch.tensor,
curr_drugs: torch.tensor,
mask: Optional[torch.tensor] = None,
) -> Tuple[torch.tensor, torch.tensor]:
queries: torch.Tensor,
prev_drugs: torch.Tensor,
curr_drugs: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward propagation.

Args:
Expand Down Expand Up @@ -355,7 +355,7 @@ def __init__(
# save ddi adj
np.save(os.path.join(CACHE_PATH, "ddi_adj.npy"), ddi_adj.numpy())

def generate_ehr_adj(self) -> torch.tensor:
def generate_ehr_adj(self) -> torch.Tensor:
"""Generates the EHR graph adjacency matrix."""
label_vocab = self.dataset.output_processors[self.label_key].label_vocab
label_size = len(label_vocab)
Expand All @@ -376,7 +376,7 @@ def generate_ehr_adj(self) -> torch.tensor:
ehr_adj[med2, med1] = 1
return ehr_adj

def generate_ddi_adj(self) -> torch.tensor:
def generate_ddi_adj(self) -> torch.Tensor:
"""Generates the DDI graph adjacency matrix."""
atc = ATC()
ddi = atc.get_ddi(gamenet_ddi=True)
Expand Down
8 changes: 4 additions & 4 deletions pyhealth/models/grasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ def grasp_encoder(self, input, static=None, mask=None):

def forward(
self,
x: torch.tensor,
static: Optional[torch.tensor] = None,
mask: Optional[torch.tensor] = None,
) -> torch.tensor:
x: torch.Tensor,
static: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward propagation.

Args:
Expand Down
20 changes: 10 additions & 10 deletions pyhealth/models/micron.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,21 @@ def __init__(

@staticmethod
def compute_reconstruction_loss(
logits: torch.tensor, logits_residual: torch.tensor, mask: torch.tensor
) -> torch.tensor:
logits: torch.Tensor, logits_residual: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Compute reconstruction loss between predicted and actual medication changes.

The reconstruction loss measures how well the model captures medication changes
between consecutive visits by comparing the predicted changes (through residual
connections) with actual changes in prescriptions.

Args:
logits (torch.tensor): Raw logits for medication predictions across all visits.
logits_residual (torch.tensor): Residual logits representing predicted changes.
mask (torch.tensor): Boolean mask indicating valid visits.
logits (torch.Tensor): Raw logits for medication predictions across all visits.
logits_residual (torch.Tensor): Residual logits representing predicted changes.
mask (torch.Tensor): Boolean mask indicating valid visits.

Returns:
torch.tensor: Mean squared reconstruction loss value.
torch.Tensor: Mean squared reconstruction loss value.
"""
rec_loss = torch.mean(
torch.square(
Expand All @@ -88,10 +88,10 @@ def compute_reconstruction_loss(

def forward(
self,
patient_emb: torch.tensor,
drugs: torch.tensor,
mask: Optional[torch.tensor] = None,
) -> Tuple[torch.tensor, torch.tensor]:
patient_emb: torch.Tensor,
drugs: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward propagation.

Args:
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/molerec.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def forward(
substructure_mask: torch.Tensor,
substructure_graph: Union[StaticParaDict, Dict[str, Union[int, torch.Tensor]]],
molecule_graph: Union[StaticParaDict, Dict[str, Union[int, torch.Tensor]]],
mask: Optional[torch.tensor] = None,
mask: Optional[torch.Tensor] = None,
drug_indexes: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward propagation.
Expand Down
6 changes: 3 additions & 3 deletions pyhealth/models/retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def compute_beta(self, rx, lengths):

def forward(
self,
x: torch.tensor,
mask: Optional[torch.tensor] = None,
) -> Tuple[torch.tensor, torch.tensor]:
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward propagation.

Args:
Expand Down
18 changes: 9 additions & 9 deletions pyhealth/models/safedrug.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def reset_parameters(self):
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

def forward(self, input: torch.tensor, mask: torch.tensor) -> torch.tensor:
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""
Args:
input: input feature tensor of shape [batch size, ..., input_size].
Expand Down Expand Up @@ -139,8 +139,8 @@ class SafeDrugLayer(nn.Module):
ddi_adj: an adjacency tensor of shape [num_drugs, num_drugs].
num_fingerprints: total number of different fingerprints.
molecule_set: a list of molecule tuples (A, B, C) of length num_molecules.
- A <torch.tensor>: fingerprints of atoms in the molecule
- B <torch.tensor>: adjacency matrix of the molecule
- A <torch.Tensor>: fingerprints of atoms in the molecule
- B <torch.Tensor>: adjacency matrix of the molecule
- C <int>: molecular_size
average_projection: a tensor of shape [num_drugs, num_molecules] representing
the average projection for aggregating multiple molecules of the
Expand Down Expand Up @@ -257,10 +257,10 @@ def calculate_loss(

def forward(
self,
patient_emb: torch.tensor,
drugs: torch.tensor,
mask: Optional[torch.tensor] = None,
) -> Tuple[torch.tensor, torch.tensor]:
patient_emb: torch.Tensor,
drugs: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward propagation.

Args:
Expand Down Expand Up @@ -433,7 +433,7 @@ def __init__(
ddi_adj = self.generate_ddi_adj()
np.save(os.path.join(CACHE_PATH, "ddi_adj.npy"), ddi_adj.numpy())

def generate_ddi_adj(self) -> torch.tensor:
def generate_ddi_adj(self) -> torch.Tensor:
"""Generates the DDI graph adjacency matrix."""
atc = ATC()
ddi = atc.get_ddi(gamenet_ddi=True)
Expand Down Expand Up @@ -472,7 +472,7 @@ def generate_smiles_list(self) -> List[List[str]]:
all_smiles_list[index] += smiles_list
return all_smiles_list

def generate_mask_H(self) -> torch.tensor:
def generate_mask_H(self) -> torch.Tensor:
"""Generates the molecular segmentation mask H."""
all_substructures_list = [[] for _ in range(self.label_size)]
for index, smiles_list in enumerate(self.all_smiles_list):
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch


def batch_to_multihot(label: List[List[int]], num_labels: int) -> torch.tensor:
def batch_to_multihot(label: List[List[int]], num_labels: int) -> torch.Tensor:
"""Converts label to multihot format.

Args:
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def save_ckpt(self, ckpt_path: str) -> None:
return

def load_ckpt(self, ckpt_path: str) -> None:
"""Saves the model checkpoint."""
"""Loads the model checkpoint."""
state_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
self.model.load_state_dict(state_dict)
return
Expand Down
Loading