diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 730d3e04..60904063 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -3,6 +3,7 @@ Release Notes .. Upcoming Version +* Allow constant values in objective cost function. Refactored objective setting. * Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()`` * Add simplify method to LinearExpression to combine duplicate terms * Add convenience function to create LinearExpression from constant diff --git a/linopy/expressions.py b/linopy/expressions.py index 10e243de..a9b65455 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -726,8 +726,8 @@ def const(self, value: DataArray) -> None: self._data = assign_multiindex_safe(self.data, const=value) @property - def has_constant(self) -> DataArray: - return self.const.any() + def has_constant(self) -> bool: + return bool(self.const.any()) # create a dummy for a mask, which can be implemented later @property @@ -1097,6 +1097,9 @@ def empty(self) -> EmptyDeprecationWrapper: """ return EmptyDeprecationWrapper(not self.size) + def drop_constant(self: GenericExpression) -> GenericExpression: + return self - self.const # type: ignore + def densify_terms(self: GenericExpression) -> GenericExpression: """ Move all non-zero term entries to the front and cut off all-zero diff --git a/linopy/model.py b/linopy/model.py index 81c069ab..7ac7d58c 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -10,9 +10,10 @@ import os import re from collections.abc import Callable, Mapping, Sequence +from functools import wraps from pathlib import Path from tempfile import NamedTemporaryFile, gettempdir -from typing import Any, Literal, overload +from typing import Any, Literal, ParamSpec, TypeVar, overload import numpy as np import pandas as pd @@ -77,6 +78,58 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") + + +class ConstantObjectiveError(Exception): ... + + +def strip_and_replace_constant_objective(func: Callable[P, R]) -> Callable[P, R]: + """ + Decorates a Model instance method. + + If the model objective contains a constant term, this decorator will: + - Remove the constant term from the model objective + - Call the decorated method + - Add the constant term back to the model objective + """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + assert args, "Expected at least one argument (self)" + self = args[0] + assert isinstance(self, Model), ( + f"First argument must be a Model instance, got {type(self)}" + ) + model: Model = self + if not model.objective.has_constant: + # Continue as normal if there is no constant term + return func(*args, **kwargs) + + # The objective contains a constant term + # Modify the model objective to drop the constant term + constant = float(model.objective.expression.const.values) + model.objective.expression = model.objective.expression.drop_constant() + args = (model, *args[1:]) # type: ignore + + try: + result = func(*args, **kwargs) + except Exception as e: + # Even if there is an exception, make sure the model returns to its original state + model.objective.expression = model.objective.expression + constant + raise e + + # Re-add the constant term to return the model objective to the original expression + model.objective.expression = model.objective.expression + constant + if model.objective.value is not None: + model.objective.set_value(model.objective.value + constant) + + return result + + return wrapper + + class Model: """ Linear optimization model. @@ -212,23 +265,20 @@ def objective(self) -> Objective: @objective.setter def objective( - self, obj: Objective | LinearExpression | QuadraticExpression + self, obj: Objective | Variable | LinearExpression | QuadraticExpression ) -> Objective: - if not isinstance(obj, Objective): - obj = Objective(obj, self) - - self._objective = obj + self.add_objective(expr=obj, overwrite=True, allow_constant=False) return self._objective @property - def sense(self) -> str: + def sense(self) -> Literal["min", "max"]: """ Sense of the objective function. """ return self.objective.sense @sense.setter - def sense(self, value: str) -> None: + def sense(self, value: Literal["min", "max"]) -> None: self.objective.sense = value @property @@ -727,28 +777,54 @@ def add_constraints( self.constraints.add(constraint) return constraint - def add_objective( + @overload + def add_objective( # Set objective as Objective object + self, + expr: Objective, + sense: None = None, + overwrite: bool = False, + allow_constant: bool = False, + ) -> None: ... + + @overload + def add_objective( # Set objective as expression-like with sense self, expr: Variable | LinearExpression | QuadraticExpression | Sequence[tuple[ConstantLike, VariableLike]], + sense: Literal["min", "max"] | None = None, overwrite: bool = False, - sense: str = "min", + allow_constant: bool = False, + ) -> None: ... + + def add_objective( + self, + expr: Variable + | LinearExpression + | QuadraticExpression + | Sequence[tuple[ConstantLike, VariableLike]] + | Objective, + sense: Literal["min", "max"] | None = None, + overwrite: bool = False, + allow_constant: bool = False, ) -> None: """ Add an objective function to the model. Parameters ---------- - expr : linopy.LinearExpression, linopy.QuadraticExpression + expr : linopy.Variable, linopy.LinearExpression, linopy.QuadraticExpression, Objective Expression describing the objective function. + sense: "min" or "max", the sense to optimize for. Defaults to min. Cannot be set if passing Objective directly overwrite : False, optional Whether to overwrite the existing objective. The default is False. + allow_constant: bool, optional + If True, the objective is allowed to contain a constant term. The default is False Returns ------- - linopy.LinearExpression + linopy.LinearExpression, linopy.QuadraticExpression The objective function assigned to the model. """ if not overwrite: @@ -756,10 +832,26 @@ def add_objective( "Objective already defined." " Set `overwrite` to True to force overwriting." ) - if isinstance(expr, Variable): - expr = 1 * expr - self.objective.expression = expr - self.objective.sense = sense + + if isinstance(expr, Objective): + assert sense is None, "Cannot set sense if objective object is passed" + objective = expr + assert objective.model == self + elif isinstance(expr, Variable | LinearExpression | QuadraticExpression): + sense = sense or "min" + objective = Objective(expression=expr, model=self, sense=sense) + else: + sense = sense or "min" + objective = Objective( + expression=self.linexpr(*expr), model=self, sense=sense + ) + + if not allow_constant and objective.expression.has_constant: + raise ConstantObjectiveError( + "Objective contains constant term. Either remove constants from the expression with expr.drop_constants() or use model.add_objective(..., allow_constant=True).", + ) + + self._objective = objective def remove_variables(self, name: str) -> None: """ @@ -1107,6 +1199,7 @@ def get_problem_file( ) as f: return Path(f.name) + @strip_and_replace_constant_objective def solve( self, solver_name: str | None = None, diff --git a/linopy/objective.py b/linopy/objective.py index b1449270..87fd76ef 100644 --- a/linopy/objective.py +++ b/linopy/objective.py @@ -9,7 +9,7 @@ import functools from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import numpy as np import polars as pl @@ -24,6 +24,7 @@ from linopy import expressions from linopy.types import ConstantLike +from linopy.variables import Variable if TYPE_CHECKING: from linopy.expressions import LinearExpression, QuadraticExpression @@ -64,13 +65,19 @@ class Objective: def __init__( self, - expression: expressions.LinearExpression | expressions.QuadraticExpression, + expression: Variable + | expressions.LinearExpression + | expressions.QuadraticExpression, model: Model, - sense: str = "min", + sense: Literal["min", "max"] = "min", ) -> None: self._model: Model = model self._value: float | None = None + if isinstance(expression, Variable): + expression = 1 * expression + + assert sense in ["min", "max"] self.sense: str = sense self.expression: ( expressions.LinearExpression | expressions.QuadraticExpression @@ -189,11 +196,15 @@ def expression( if len(expr.coord_dims): expr = expr.sum() - if (expr.const != 0.0) and not np.isnan(expr.const): - raise ValueError("Constant values in objective function not supported.") - self._expression = expr + @property + def has_constant(self) -> bool: + """ + Returns whether the objective has a constant term. + """ + return self.expression.has_constant + @property def model(self) -> Model: """ @@ -202,14 +213,14 @@ def model(self) -> Model: return self._model @property - def sense(self) -> str: + def sense(self) -> Literal["min", "max"]: """ Returns the sense of the objective. """ return self._sense @sense.setter - def sense(self, sense: str) -> None: + def sense(self, sense: Literal["min", "max"]) -> None: """ Sets the sense of the objective. """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index a75ace3f..4969e964 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1230,6 +1230,21 @@ def test_cumsum(m: Model, multiple: float) -> None: cumsum.nterm == 2 +def test_drop_constant(x: Variable) -> None: + """Test that constants are removed""" + expr_a = 2 * x + expr_b = expr_a + [1, 2] + expr_c = expr_b + float("nan") + for expr in [expr_a, expr_b, expr_c]: + expr = 2 * x + 10 + expr_2 = expr.drop_constant() + + assert all(expr_2.const.values == 0.0), ( + f"Expected constant 0.0, got {expr_2.const.values}" + ) + assert not expr_2.has_constant + + def test_simplify_basic(x: Variable) -> None: """Test basic simplification with duplicate terms.""" expr = 2 * x + 3 * x + 1 * x diff --git a/test/test_model.py b/test/test_model.py index c363fe4c..494118ad 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -9,10 +9,12 @@ from tempfile import gettempdir import numpy as np +import pandas as pd import pytest import xarray as xr from linopy import EQUAL, Model +from linopy.model import ConstantObjectiveError from linopy.testing import assert_model_equal target_shape: tuple[int, int] = (10, 10) @@ -67,7 +69,7 @@ def test_objective() -> None: y = m.add_variables(lower, upper, name="y") obj1 = (10 * x + 5 * y).sum() - m.add_objective(obj1) + m.add_objective(obj1, allow_constant=True) assert m.objective.vars.size == 200 # test overwriting @@ -82,8 +84,8 @@ def test_objective() -> None: assert m.objectiverange.min() == 2 assert m.objectiverange.max() == 2 - # test objective with constant which is not supported - with pytest.raises(ValueError): + # test setting constant term in objective with explicitly allowing it + with pytest.raises(ConstantObjectiveError): m.objective = m.objective + 3 @@ -163,3 +165,20 @@ def test_assert_model_equal() -> None: m.add_objective(obj) assert_model_equal(m, m) + + +def test_constant_not_allowed_in_objective_unless_specified_explicitly() -> None: + model = Model() + days = pd.Index(["Mon", "Tue", "Wed", "Thu", "Fri"], name="day") + x = model.add_variables(name="x", coords=[days]) + non_linear = x + 1 + + with pytest.raises(ConstantObjectiveError): + model.add_objective(expr=non_linear, overwrite=True, allow_constant=False) + with pytest.raises(ConstantObjectiveError): + model.add_objective(expr=non_linear, overwrite=True) + + with pytest.raises(ConstantObjectiveError): + model.objective = non_linear + + model.add_objective(expr=non_linear, overwrite=True, allow_constant=True) diff --git a/test/test_objective.py b/test/test_objective.py index d869175a..f886b9c4 100644 --- a/test/test_objective.py +++ b/test/test_objective.py @@ -69,7 +69,7 @@ def test_set_sense_via_model( def test_sense_setter_error(linear_objective: Objective) -> None: with pytest.raises(ValueError): - linear_objective.sense = "not min or max" + linear_objective.sense = "not min or max" # type: ignore def test_variables_inherited_properties(linear_objective: Objective) -> None: @@ -187,10 +187,3 @@ def test_repr(linear_objective: Objective, quadratic_objective: Objective) -> No assert "Linear" in linear_objective.__repr__() assert "Quadratic" in quadratic_objective.__repr__() - - -def test_objective_constant() -> None: - m = Model() - linear_expr = LinearExpression(None, m) + 1 - with pytest.raises(ValueError): - m.objective = Objective(linear_expr, m) diff --git a/test/test_optimization.py b/test/test_optimization.py index 12399a4e..846b292a 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -947,7 +947,7 @@ def test_model_resolve( # add another constraint after solve model.add_constraints(model.variables.y >= 3) - status, condition = model.solve( + status, _ = model.solve( solver, io_api=io_api, explicit_coordinate_names=explicit_coordinate_names ) assert status == "ok" @@ -955,6 +955,46 @@ def test_model_resolve( assert np.isclose(model.objective.value or 0, 5.25) +def test_constant_feasible(model: Model) -> None: + objective = model.objective.expression + 1 + model.add_objective(expr=objective, overwrite=True, allow_constant=True) + + status, _ = model.solve(solver_name="highs") + assert status == "ok" + # x = -0.1, y = 1.7 + assert model.objective.value == 4.3 + assert model.objective.expression.const == 1 + assert model.objective.expression.solution == 4.3 + + +def test_constant_infeasible(model: Model) -> None: + objective = model.objective.expression + 1 + model.add_objective(expr=objective, overwrite=True, allow_constant=True) + model.add_constraints([(1, "x")], "<=", 0) + model.add_constraints([(1, "y")], "<=", 0) + + _, condition = model.solve(solver_name="highs") + + assert condition == "infeasible" + # Even though the problem was not solved, the constant term should still be accessible + assert model.objective.expression.const == 1 + + +def test_constant_error(model: Model) -> None: + objective = model.objective.expression + 1 + model.add_objective(expr=objective, overwrite=True, allow_constant=True) + model.add_constraints([(1, "x")], "<=", 0) + model.add_constraints([(1, "y")], "<=", 0) + + try: + _ = model.solve(solver_name="apples") + except AssertionError: + pass + + # Even if something goes wrong, the model objective should return to the correct state + assert model.objective.expression.const == 1 + + @pytest.mark.parametrize( "solver,io_api,explicit_coordinate_names", [p for p in params if "direct" not in p] ) diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index fc1bb25f..cb203b15 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -312,6 +312,20 @@ def test_quadratic_expression_constant_to_polars() -> None: assert all(arr.to_numpy() == df["const"].to_numpy()) +def test_drop_constant(x: Variable) -> None: + """Test that constants are removed""" + expr_a = 2 * x * x + expr_b = expr_a + 1 + for expr in [expr_a, expr_b]: + expr = 2 * x + 10 + expr_2 = expr.drop_constant() + + assert all(expr_2.const.values == 0.0), ( + f"Expected constant 0.0, got {expr_2.const.values}" + ) + assert not expr_2.has_constant + + def test_quadratic_expression_to_matrix(model: Model, x: Variable, y: Variable) -> None: expr: QuadraticExpression = x * y + x + 5 # type: ignore