diff --git a/CHANGELOG.md b/CHANGELOG.md index fb4b3a96f..53064fa3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,10 @@ - Used `getIndex()` instead of `ptr()` for sorting nonlinear expression terms to avoid nondeterministic behavior ### Changed - Speed up `constant * Expr` via C-level API +- Speed up `Term.__eq__` via the C-level API ### Removed - Removed outdated warning about Make build system incompatibility +- Removed `Term.ptrtuple` to optimize `Term` memory usage ## 6.1.0 - 2026.01.31 ### Added diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 59c7c7e4d..de6b775dd 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -104,13 +104,11 @@ cdef class Term: '''This is a monomial term''' cdef readonly tuple vartuple - cdef readonly tuple ptrtuple cdef Py_ssize_t hashval def __init__(self, *vartuple: Variable): self.vartuple = tuple(sorted(vartuple, key=lambda v: v.getIndex())) - self.ptrtuple = tuple(v.ptr() for v in self.vartuple) - self.hashval = hash(self.ptrtuple) + self.hashval = hash(tuple(v.ptr() for v in self.vartuple)) def __getitem__(self, idx): return self.vartuple[idx] @@ -118,8 +116,25 @@ cdef class Term: def __hash__(self) -> Py_ssize_t: return self.hashval - def __eq__(self, other: Term): - return self.ptrtuple == other.ptrtuple + def __eq__(self, other) -> bool: + if other is self: + return True + if Py_TYPE(other) is not Term: + return False + + cdef int n = len(self) + cdef Term _other = other + if n != len(_other) or self.hashval != _other.hashval: + return False + + cdef int i + cdef Variable var1, var2 + for i in range(n): + var1 = PyTuple_GET_ITEM(self.vartuple, i) + var2 = PyTuple_GET_ITEM(_other.vartuple, i) + if var1.ptr() != var2.ptr(): + return False + return True def __len__(self): return len(self.vartuple) @@ -156,8 +171,7 @@ cdef class Term: cdef Term res = Term.__new__(Term) res.vartuple = tuple(vartuple) - res.ptrtuple = tuple(v.ptr() for v in res.vartuple) - res.hashval = hash(res.ptrtuple) + res.hashval = hash(tuple(v.ptr() for v in res.vartuple)) return res def __repr__(self): diff --git a/src/pyscipopt/scip.pxi b/src/pyscipopt/scip.pxi index 871638ebf..db409f643 100644 --- a/src/pyscipopt/scip.pxi +++ b/src/pyscipopt/scip.pxi @@ -1565,7 +1565,6 @@ cdef class Variable(Expr): return cname.decode('utf-8') def ptr(self): - """ """ return (self.scip_var) def __repr__(self): diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 92cc58540..ddc14ca70 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -2202,7 +2202,6 @@ class SumExpr(GenExpr): @disjoint_base class Term: - ptrtuple: Incomplete vartuple: Incomplete def __init__(self, *vartuple: Incomplete) -> None: ... def __mul__(self, other: Term) -> Term: ... diff --git a/tests/test_expr.py b/tests/test_expr.py index 855f47cff..61435bb85 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -2,7 +2,7 @@ import pytest -from pyscipopt import Model, sqrt, log, exp, sin, cos +from pyscipopt import Model, sqrt, log, exp, sin, cos, quickprod from pyscipopt.scip import Expr, GenExpr, ExprCons, CONST @@ -249,3 +249,24 @@ def test_abs_abs_expr(): # should print abs(x) not abs(abs(x)) assert str(abs(abs(x))) == str(abs(x)) + + +def test_term_eq(): + m = Model() + + x = m.addMatrixVar(1000) + y = m.addVar() + z = m.addVar() + + e1 = quickprod(x.flat) + e2 = quickprod(x.flat) + t1 = next(iter(e1)) + t2 = next(iter(e2)) + t3 = next(iter(e1 * y)) + t4 = next(iter(e2 * z)) + + assert t1 == t1 # same term + assert t1 == t2 # same term + assert t3 != t4 # same length, but different term + assert t1 != t3 # different length + assert t1 != "not a term" # different type