Skip to content
Open
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Release Notes

.. Upcoming Version

* Allow constant values in objective cost function. Refactored objective setting.
* Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()``
* Add simplify method to LinearExpression to combine duplicate terms
* Add convenience function to create LinearExpression from constant
Expand Down
7 changes: 5 additions & 2 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,8 @@ def const(self, value: DataArray) -> None:
self._data = assign_multiindex_safe(self.data, const=value)

@property
def has_constant(self) -> DataArray:
return self.const.any()
def has_constant(self) -> bool:
return bool(self.const.any())

# create a dummy for a mask, which can be implemented later
@property
Expand Down Expand Up @@ -1097,6 +1097,9 @@ def empty(self) -> EmptyDeprecationWrapper:
"""
return EmptyDeprecationWrapper(not self.size)

def drop_constant(self: GenericExpression) -> GenericExpression:
return self - self.const # type: ignore

def densify_terms(self: GenericExpression) -> GenericExpression:
"""
Move all non-zero term entries to the front and cut off all-zero
Expand Down
126 changes: 110 additions & 16 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import os
import re
from collections.abc import Callable, Mapping, Sequence
from functools import wraps
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import Any, Literal, overload
from typing import Any, Literal, ParamSpec, TypeVar, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -77,6 +78,59 @@
logger = logging.getLogger(__name__)


P = ParamSpec("P")
R = TypeVar("R")


class ConstantObjectiveError(Exception): ...


def strip_and_replace_constant_objective(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorates a Model instance method.

If the model objective contains a constant term, this decorator will:
- Remove the constant term from the model objective
- Call the decorated method
- Add the constant term back to the model objective
"""

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
assert args, "Expected at least one argument (self)"
self = args[0]
assert isinstance(self, Model), (
f"First argument must be a Model instance, got {type(self)}"
)
model = self
if not self.objective.has_constant:
# Continue as normal if there is no constant term
return func(*args, **kwargs)

# The objective contains a constant term
# Modify the model objective to drop the constant term
model = self
constant = float(self.objective.expression.const.values)
model.objective.expression = self.objective.expression.drop_constant()
args = (model, *args[1:]) # type: ignore

try:
result = func(*args, **kwargs)
except Exception as e:
# Even if there is an exception, make sure the model returns to it's original state
model.objective.expression = model.objective.expression + constant
raise e

# Re-add the constant term to return the model objective to the original expression
model.objective.expression = model.objective.expression + constant
if model.objective.value is not None:
model.objective.set_value(model.objective.value + constant)

return result

return wrapper


class Model:
"""
Linear optimization model.
Expand Down Expand Up @@ -212,23 +266,20 @@ def objective(self) -> Objective:

@objective.setter
def objective(
self, obj: Objective | LinearExpression | QuadraticExpression
self, obj: Objective | Variable | LinearExpression | QuadraticExpression
) -> Objective:
if not isinstance(obj, Objective):
obj = Objective(obj, self)

self._objective = obj
self.add_objective(expr=obj, overwrite=True, allow_constant=False)
return self._objective

@property
def sense(self) -> str:
def sense(self) -> Literal["min", "max"]:
"""
Sense of the objective function.
"""
return self.objective.sense

@sense.setter
def sense(self, value: str) -> None:
def sense(self, value: Literal["min", "max"]) -> None:
self.objective.sense = value

@property
Expand Down Expand Up @@ -727,39 +778,81 @@ def add_constraints(
self.constraints.add(constraint)
return constraint

def add_objective(
@overload
def add_objective( # Set objective as Objective object
self,
expr: Objective,
sense: None = None,
overwrite: bool = False,
allow_constant: bool = False,
) -> None: ...

@overload
def add_objective( # Set objective as expression-like with sense
self,
expr: Variable
| LinearExpression
| QuadraticExpression
| Sequence[tuple[ConstantLike, VariableLike]],
sense: Literal["min", "max"] | None = None,
overwrite: bool = False,
allow_constant: bool = False,
) -> None: ...

def add_objective(
self,
expr: Variable
| LinearExpression
| QuadraticExpression
| Sequence[tuple[ConstantLike, VariableLike]]
| Objective,
sense: Literal["min", "max"] | None = None,
overwrite: bool = False,
sense: str = "min",
allow_constant: bool = False,
) -> None:
"""
Add an objective function to the model.

Parameters
----------
expr : linopy.LinearExpression, linopy.QuadraticExpression
expr : linopy.Variable, linopy.LinearExpression, linopy.QuadraticExpression, Objective
Expression describing the objective function.
sense: "min" or "max", the sense to optimize for. Defaults to min. Cannot be set if passing Objective directly
overwrite : False, optional
Whether to overwrite the existing objective. The default is False.
allow_constant: bool, optional
If True, the objective is allowed to contain a constant term. The default is False

Returns
-------
linopy.LinearExpression
linopy.LinearExpression, linopy.QuadraticExpression
The objective function assigned to the model.
"""
if not overwrite:
assert self.objective.expression.empty, (
"Objective already defined."
" Set `overwrite` to True to force overwriting."
)
if isinstance(expr, Variable):
expr = 1 * expr
self.objective.expression = expr
self.objective.sense = sense

if isinstance(expr, Objective):
assert sense is None, "Cannot set sense if objective object is passed"
objective = expr
assert objective.model == self
elif isinstance(expr, Variable | LinearExpression | QuadraticExpression):
sense = sense or "min"
objective = Objective(expression=expr, model=self, sense=sense)
else:
sense = sense or "min"
objective = Objective(
expression=self.linexpr(*expr), model=self, sense=sense
)

if not allow_constant and objective.expression.has_constant:
raise ConstantObjectiveError(
"Objective contains constant term. Either remove constants from the expression with expr.drop_constants() or use model.add_objective(..., allow_constant=True).",
)

self._objective = objective

def remove_variables(self, name: str) -> None:
"""
Expand Down Expand Up @@ -1107,6 +1200,7 @@ def get_problem_file(
) as f:
return Path(f.name)

@strip_and_replace_constant_objective
def solve(
self,
solver_name: str | None = None,
Expand Down
27 changes: 19 additions & 8 deletions linopy/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import functools
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import polars as pl
Expand All @@ -24,6 +24,7 @@

from linopy import expressions
from linopy.types import ConstantLike
from linopy.variables import Variable

if TYPE_CHECKING:
from linopy.expressions import LinearExpression, QuadraticExpression
Expand Down Expand Up @@ -64,13 +65,19 @@ class Objective:

def __init__(
self,
expression: expressions.LinearExpression | expressions.QuadraticExpression,
expression: Variable
| expressions.LinearExpression
| expressions.QuadraticExpression,
model: Model,
sense: str = "min",
sense: Literal["min", "max"] = "min",
) -> None:
self._model: Model = model
self._value: float | None = None

if isinstance(expression, Variable):
expression = 1 * expression

assert sense in ["min", "max"]
self.sense: str = sense
self.expression: (
expressions.LinearExpression | expressions.QuadraticExpression
Expand Down Expand Up @@ -189,11 +196,15 @@ def expression(
if len(expr.coord_dims):
expr = expr.sum()

if (expr.const != 0.0) and not np.isnan(expr.const):
raise ValueError("Constant values in objective function not supported.")

self._expression = expr

@property
def has_constant(self) -> bool:
"""
Returns whether the objective has a constant term.
"""
return self.expression.has_constant

@property
def model(self) -> Model:
"""
Expand All @@ -202,14 +213,14 @@ def model(self) -> Model:
return self._model

@property
def sense(self) -> str:
def sense(self) -> Literal["min", "max"]:
"""
Returns the sense of the objective.
"""
return self._sense

@sense.setter
def sense(self, sense: str) -> None:
def sense(self, sense: Literal["min", "max"]) -> None:
"""
Sets the sense of the objective.
"""
Expand Down
15 changes: 15 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,21 @@ def test_cumsum(m: Model, multiple: float) -> None:
cumsum.nterm == 2


def test_drop_constant(x: Variable) -> None:
"""Test that constants are removed"""
expr_a = 2 * x
expr_b = expr_a + [1, 2]
expr_c = expr_b + float("nan")
for expr in [expr_a, expr_b, expr_c]:
expr = 2 * x + 10
expr_2 = expr.drop_constant()

assert all(expr_2.const.values == 0.0), (
f"Expected constant 0.0, got {expr_2.const.values}"
)
assert not expr_2.has_constant


def test_simplify_basic(x: Variable) -> None:
"""Test basic simplification with duplicate terms."""
expr = 2 * x + 3 * x + 1 * x
Expand Down
25 changes: 22 additions & 3 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from tempfile import gettempdir

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from linopy import EQUAL, Model
from linopy.model import ConstantObjectiveError
from linopy.testing import assert_model_equal

target_shape: tuple[int, int] = (10, 10)
Expand Down Expand Up @@ -67,7 +69,7 @@ def test_objective() -> None:
y = m.add_variables(lower, upper, name="y")

obj1 = (10 * x + 5 * y).sum()
m.add_objective(obj1)
m.add_objective(obj1, allow_constant=True)
assert m.objective.vars.size == 200

# test overwriting
Expand All @@ -82,8 +84,8 @@ def test_objective() -> None:
assert m.objectiverange.min() == 2
assert m.objectiverange.max() == 2

# test objective with constant which is not supported
with pytest.raises(ValueError):
# test setting constant term in objective with explicitly allowing it
with pytest.raises(ConstantObjectiveError):
m.objective = m.objective + 3


Expand Down Expand Up @@ -163,3 +165,20 @@ def test_assert_model_equal() -> None:
m.add_objective(obj)

assert_model_equal(m, m)


def test_constant_not_allowed_in_objective_unless_specified_explicitly() -> None:
model = Model()
days = pd.Index(["Mon", "Tue", "Wed", "Thu", "Fri"], name="day")
x = model.add_variables(name="x", coords=[days])
non_linear = x + 1

with pytest.raises(ConstantObjectiveError):
model.add_objective(expr=non_linear, overwrite=True, allow_constant=False)
with pytest.raises(ConstantObjectiveError):
model.add_objective(expr=non_linear, overwrite=True)

with pytest.raises(ConstantObjectiveError):
model.objective = non_linear

model.add_objective(expr=non_linear, overwrite=True, allow_constant=True)
9 changes: 1 addition & 8 deletions test/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_set_sense_via_model(

def test_sense_setter_error(linear_objective: Objective) -> None:
with pytest.raises(ValueError):
linear_objective.sense = "not min or max"
linear_objective.sense = "not min or max" # type: ignore


def test_variables_inherited_properties(linear_objective: Objective) -> None:
Expand Down Expand Up @@ -187,10 +187,3 @@ def test_repr(linear_objective: Objective, quadratic_objective: Objective) -> No

assert "Linear" in linear_objective.__repr__()
assert "Quadratic" in quadratic_objective.__repr__()


def test_objective_constant() -> None:
m = Model()
linear_expr = LinearExpression(None, m) + 1
with pytest.raises(ValueError):
m.objective = Objective(linear_expr, m)
Loading