Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Upcoming Version

* Fix docs (pick highs solver)
* Add the `sphinx-copybutton` to the documentation
* Speed up LP file writing by 2-2.7x on large models through Polars streaming engine, join-based constraint assembly, and reduced per-constraint overhead

Version 0.6.1
--------------
Expand Down
19 changes: 19 additions & 0 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,25 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
return df


def maybe_group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
    """
    Group terms only when duplicate (labels, vars) pairs exist.

    Skips the costly ``group_by`` when every term already references a
    distinct variable (e.g. ``x - y`` carries ``_term=2`` but no
    duplicate keys). On the fast path the columns are reordered so the
    result matches what ``group_terms_polars`` would produce.
    """
    var_columns = [col for col in df.columns if col.startswith("vars")]
    key_columns = [col for col in ["labels", *var_columns] if col in df.columns]
    distinct_keys = df.select(pl.struct(key_columns).n_unique()).item()
    if distinct_keys < df.height:
        # Duplicates present: fall back to the full aggregation.
        return group_terms_polars(df)
    # Fast path: rows untouched, column order matching group_terms_polars
    # (group-by keys, then coeffs, then everything else).
    trailing = [
        col for col in df.columns if col not in key_columns and col != "coeffs"
    ]
    return df.select([*key_columns, "coeffs", *trailing])


def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
"""
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
Expand Down
46 changes: 31 additions & 15 deletions linopy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@
generate_indices_for_printout,
get_dims_with_index_levels,
get_label_position,
group_terms_polars,
has_optimized_model,
infer_schema_polars,
iterate_slices,
maybe_group_terms_polars,
maybe_replace_signs,
print_coord,
print_single_constraint,
Expand Down Expand Up @@ -622,21 +621,38 @@ def to_polars(self) -> pl.DataFrame:
long = to_polars(ds[keys])

long = filter_nulls_polars(long)
long = group_terms_polars(long)
if ds.sizes.get("_term", 1) > 1:
long = maybe_group_terms_polars(long)
check_has_nulls_polars(long, name=f"{self.type} {self.name}")

short_ds = ds[[k for k in ds if "_term" not in ds[k].dims]]
schema = infer_schema_polars(short_ds)
schema["sign"] = pl.Enum(["=", "<=", ">="])
short = to_polars(short_ds, schema=schema)
short = filter_nulls_polars(short)
check_has_nulls_polars(short, name=f"{self.type} {self.name}")

df = pl.concat([short, long], how="diagonal_relaxed").sort(["labels", "rhs"])
# delete subsequent non-null rhs (happens is all vars per label are -1)
is_non_null = df["rhs"].is_not_null()
prev_non_is_null = is_non_null.shift(1).fill_null(False)
df = df.filter(is_non_null & ~prev_non_is_null | ~is_non_null)
# Build short DataFrame (labels, rhs, sign) without xarray broadcast.
# Apply labels mask directly instead of filter_nulls_polars.
labels_flat = ds["labels"].values.reshape(-1)
mask = labels_flat != -1
labels_masked = labels_flat[mask]
rhs_flat = np.broadcast_to(ds["rhs"].values, ds["labels"].shape).reshape(-1)

sign_values = ds["sign"].values
sign_flat = np.broadcast_to(sign_values, ds["labels"].shape).reshape(-1)
all_same_sign = len(sign_flat) > 0 and (
sign_flat[0] == sign_flat[-1] and (sign_flat[0] == sign_flat).all()
)

short_data: dict = {
"labels": labels_masked,
"rhs": rhs_flat[mask],
}
if all_same_sign:
short = pl.DataFrame(short_data).with_columns(
pl.lit(sign_flat[0]).cast(pl.Enum(["=", "<=", ">="])).alias("sign")
)
else:
short_data["sign"] = pl.Series(
"sign", sign_flat[mask], dtype=pl.Enum(["=", "<=", ">="])
)
short = pl.DataFrame(short_data)

df = long.join(short, on="labels", how="inner")
return df[["labels", "coeffs", "vars", "sign", "rhs"]]

# Wrapped function which would convert variable to dataarray
Expand Down
3 changes: 2 additions & 1 deletion linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
has_optimized_model,
is_constant,
iterate_slices,
maybe_group_terms_polars,
print_coord,
print_single_expression,
to_dataframe,
Expand Down Expand Up @@ -1463,7 +1464,7 @@ def to_polars(self) -> pl.DataFrame:

df = to_polars(self.data)
df = filter_nulls_polars(df)
df = group_terms_polars(df)
df = maybe_group_terms_polars(df)
check_has_nulls_polars(df, name=self.type)
return df

Expand Down
93 changes: 30 additions & 63 deletions linopy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ def clean_name(name: str) -> str:
coord_sanitizer = str.maketrans("[,]", "(,)", " ")


def _format_and_write(
    df: pl.DataFrame, columns: list[pl.Expr], f: BufferedWriter
) -> None:
    """
    Format columns via concat_str and write to file.

    Uses Polars streaming engine for better memory efficiency.
    """
    # Concatenate all column expressions into a single string column,
    # dropping nulls so optional fields simply vanish from the output.
    joined = pl.concat_str(columns, ignore_nulls=True)
    formatted = df.lazy().select(joined).collect(engine="streaming")
    formatted.write_csv(
        f,
        separator=" ",
        null_value="",
        quote_style="never",
        include_header=False,
    )


def signed_number(expr: pl.Expr) -> tuple[pl.Expr, pl.Expr]:
"""
Return polars expressions for a signed number string, handling -0.0 correctly.
Expand Down Expand Up @@ -155,10 +170,7 @@ def objective_write_linear_terms(
*signed_number(pl.col("coeffs")),
*print_variable(pl.col("vars")),
]
df = df.select(pl.concat_str(cols, ignore_nulls=True))
df.write_csv(
f, separator=" ", null_value="", quote_style="never", include_header=False
)
_format_and_write(df, cols, f)


def objective_write_quadratic_terms(
Expand All @@ -171,10 +183,7 @@ def objective_write_quadratic_terms(
*print_variable(pl.col("vars2")),
]
f.write(b"+ [\n")
df = df.select(pl.concat_str(cols, ignore_nulls=True))
df.write_csv(
f, separator=" ", null_value="", quote_style="never", include_header=False
)
_format_and_write(df, cols, f)
f.write(b"] / 2\n")


Expand Down Expand Up @@ -254,11 +263,7 @@ def bounds_to_file(
*signed_number(pl.col("upper")),
]

kwargs: Any = dict(
separator=" ", null_value="", quote_style="never", include_header=False
)
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
formatted.write_csv(f, **kwargs)
_format_and_write(df, columns, f)


def binaries_to_file(
Expand Down Expand Up @@ -296,11 +301,7 @@ def binaries_to_file(
*print_variable(pl.col("labels")),
]

kwargs: Any = dict(
separator=" ", null_value="", quote_style="never", include_header=False
)
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
formatted.write_csv(f, **kwargs)
_format_and_write(df, columns, f)


def integers_to_file(
Expand Down Expand Up @@ -339,11 +340,7 @@ def integers_to_file(
*print_variable(pl.col("labels")),
]

kwargs: Any = dict(
separator=" ", null_value="", quote_style="never", include_header=False
)
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
formatted.write_csv(f, **kwargs)
_format_and_write(df, columns, f)


def sos_to_file(
Expand Down Expand Up @@ -399,11 +396,7 @@ def sos_to_file(
pl.col("var_weights"),
]

kwargs: Any = dict(
separator=" ", null_value="", quote_style="never", include_header=False
)
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
formatted.write_csv(f, **kwargs)
_format_and_write(df, columns, f)


def constraints_to_file(
Expand Down Expand Up @@ -440,58 +433,32 @@ def constraints_to_file(
if df.height == 0:
continue

# Ensure each constraint has both coefficient and RHS terms
analysis = df.group_by("labels").agg(
[
pl.col("coeffs").is_not_null().sum().alias("coeff_rows"),
pl.col("sign").is_not_null().sum().alias("rhs_rows"),
]
)

valid = analysis.filter(
(pl.col("coeff_rows") > 0) & (pl.col("rhs_rows") > 0)
)

if valid.height == 0:
continue

# Keep only constraints that have both parts
df = df.join(valid.select("labels"), on="labels", how="inner")

# Sort by labels and mark first/last occurrences
df = df.sort("labels").with_columns(
[
pl.when(pl.col("labels").is_first_distinct())
.then(pl.col("labels"))
.otherwise(pl.lit(None))
.alias("labels_first"),
pl.col("labels").is_first_distinct().alias("is_first_in_group"),
(pl.col("labels") != pl.col("labels").shift(-1))
.fill_null(True)
.alias("is_last_in_group"),
]
)

row_labels = print_constraint(pl.col("labels_first"))
row_labels = print_constraint(pl.col("labels"))
col_labels = print_variable(pl.col("vars"))
columns = [
pl.when(pl.col("labels_first").is_not_null()).then(row_labels[0]),
pl.when(pl.col("labels_first").is_not_null()).then(row_labels[1]),
pl.when(pl.col("labels_first").is_not_null())
.then(pl.lit(":\n"))
.alias(":"),
pl.when(pl.col("is_first_in_group")).then(row_labels[0]),
pl.when(pl.col("is_first_in_group")).then(row_labels[1]),
pl.when(pl.col("is_first_in_group")).then(pl.lit(":\n")).alias(":"),
*signed_number(pl.col("coeffs")),
pl.when(pl.col("vars").is_not_null()).then(col_labels[0]),
pl.when(pl.col("vars").is_not_null()).then(col_labels[1]),
col_labels[0],
col_labels[1],
pl.when(pl.col("is_last_in_group")).then(pl.lit("\n")),
pl.when(pl.col("is_last_in_group")).then(pl.col("sign")),
pl.when(pl.col("is_last_in_group")).then(pl.lit(" ")),
pl.when(pl.col("is_last_in_group")).then(pl.col("rhs").cast(pl.String)),
]

kwargs: Any = dict(
separator=" ", null_value="", quote_style="never", include_header=False
)
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
formatted.write_csv(f, **kwargs)
_format_and_write(df, columns, f)

# in the future, we could use lazy dataframes when they support appending
# to existing files
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"numexpr",
"xarray>=2024.2.0",
"dask>=0.18.0",
"polars",
"polars>=1.31",
"tqdm",
"deprecation",
"packaging",
Expand Down
18 changes: 18 additions & 0 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_dims_with_index_levels,
is_constant,
iterate_slices,
maybe_group_terms_polars,
)
from linopy.testing import assert_linequal, assert_varequal

Expand Down Expand Up @@ -737,3 +738,20 @@ def test_is_constant() -> None:
]
for cv in constant_values:
assert is_constant(cv)


def test_maybe_group_terms_polars_no_duplicates() -> None:
    """Fast path: distinct (labels, vars) pairs skip group_by."""
    frame = pl.DataFrame(
        {"labels": [0, 0], "vars": [1, 2], "coeffs": [3.0, 4.0]}
    )
    result = maybe_group_terms_polars(frame)
    # No aggregation happened: row count and coefficients are untouched.
    assert result.shape == (2, 3)
    assert result.columns == ["labels", "vars", "coeffs"]
    assert result.get_column("coeffs").to_list() == [3.0, 4.0]


def test_maybe_group_terms_polars_with_duplicates() -> None:
    """Slow path: duplicate (labels, vars) pairs trigger group_by."""
    frame = pl.DataFrame(
        {"labels": [0, 0], "vars": [1, 1], "coeffs": [3.0, 4.0]}
    )
    result = maybe_group_terms_polars(frame)
    # Duplicate keys collapse into one row with summed coefficients.
    assert result.shape == (1, 3)
    assert result.get_column("coeffs").to_list() == [7.0]
14 changes: 14 additions & 0 deletions test/test_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,20 @@ def test_constraint_to_polars(c: linopy.constraints.Constraint) -> None:
assert isinstance(c.to_polars(), pl.DataFrame)


def test_constraint_to_polars_mixed_signs(m: Model, x: linopy.Variable) -> None:
    """Test to_polars when a constraint has mixed sign values across dims."""
    # Create a constraint, then manually patch the sign to have mixed values
    m.add_constraints(x >= 0, name="mixed")
    con = m.constraints["mixed"]
    # Alternate "<=" and ">=" along the first dimension.
    size = con.data.sizes["first"]
    alternating = np.where(np.arange(size) % 2 == 0, "<=", ">=")
    con.data["sign"] = xr.DataArray(alternating, dims=con.data["sign"].dims)
    df = con.to_polars()
    assert isinstance(df, pl.DataFrame)
    assert set(df["sign"].to_list()) == {"<=", ">="}


def test_constraint_assignment_with_anonymous_constraints(
m: Model, x: linopy.Variable, y: linopy.Variable
) -> None:
Expand Down
37 changes: 37 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,40 @@ def test_to_file_lp_with_negative_zero_coefficients(tmp_path: Path) -> None:

# Verify Gurobi can read it without errors
gurobipy.read(str(fn))


def test_to_file_lp_same_sign_constraints(tmp_path: Path) -> None:
    """Test LP writing when all constraints have the same sign operator."""
    model = Model()
    idx = np.arange(5)
    x = model.add_variables(coords=[idx], name="x")
    # Both constraint containers use the <= operator exclusively.
    model.add_constraints(x <= 10, name="upper")
    model.add_constraints(x <= 20, name="upper2")
    model.add_objective(x.sum())

    target = tmp_path / "same_sign.lp"
    model.to_file(target)
    text = target.read_text()
    assert "s.t." in text
    assert "<=" in text


def test_to_file_lp_mixed_sign_constraints(tmp_path: Path) -> None:
    """Test LP writing when constraints have different sign operators."""
    model = Model()
    idx = np.arange(5)
    x = model.add_variables(coords=[idx], name="x")
    # Mix <=, >= and == constraints within the same model.
    model.add_constraints(x <= 10, name="upper")
    model.add_constraints(x >= 1, name="lower")
    model.add_constraints(2 * x == 8, name="eq")
    model.add_objective(x.sum())

    target = tmp_path / "mixed_sign.lp"
    model.to_file(target)
    text = target.read_text()
    assert "s.t." in text
    assert "<=" in text
    assert ">=" in text
    assert "=" in text