From 839dfbeb7c765af235a5c809c6f9ea60d31423df Mon Sep 17 00:00:00 2001 From: genisis0x Date: Fri, 29 May 2026 17:24:08 +0530 Subject: [PATCH] fix(model): correct default-metric check in HIST and IGMTF metric_fn in pytorch_hist.py and pytorch_igmtf.py compared self.metric (a string, default "") against the tuple ("", "loss") with ==, which is never true, so the default/"loss" metric fell through to raise ValueError("unknown metric"). Every other torch model (pytorch_lstm, pytorch_alstm, pytorch_gru, ...) uses `if self.metric in ("", "loss")`. Align HIST and IGMTF to the same membership test so the default metric returns the negative loss as intended. Closes #2163. --- qlib/contrib/model/pytorch_hist.py | 2 +- qlib/contrib/model/pytorch_igmtf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 779cde9c859..50ef64ec432 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -170,7 +170,7 @@ def metric_fn(self, pred, label): vy = y - torch.mean(y) return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2))) - if self.metric == ("", "loss"): + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 0bddc5a0f5f..1e8be1c8f3f 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -163,7 +163,7 @@ def metric_fn(self, pred, label): vy = y - torch.mean(y) return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2))) - if self.metric == ("", "loss"): + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric)