5 changes: 5 additions & 0 deletions climada/engine/test/test_impact_forecast.py
@@ -19,10 +19,14 @@
Tests for Impact Forecast.
"""

import datetime as dt
from pathlib import Path

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
import xarray as xr
from scipy.sparse import csr_matrix

from climada.engine import Impact, ImpactForecast
@@ -58,6 +62,7 @@ def impact_forecast(impact, lead_time, member):


class TestImpactForecastInit:

def assert_impact_kwargs(self, impact: Impact, **kwargs):
for key, value in kwargs.items():
attr = getattr(impact, key)
127 changes: 126 additions & 1 deletion climada/hazard/forecast.py
@@ -20,9 +20,14 @@
"""

import logging
import pathlib
from typing import Any, Dict, List, Optional

import numpy as np
import scipy.sparse as sparse
import xarray as xr

from climada.hazard.xarray import HazardXarrayReader

from ..util.checker import size
from ..util.forecast import Forecast
@@ -86,9 +91,9 @@ def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray):
event_id=hazard.event_id,
frequency=hazard.frequency,
frequency_unit=hazard.frequency_unit,
event_name=hazard.event_name,
date=hazard.date,
orig=hazard.orig,
intensity=hazard.intensity,
fraction=hazard.fraction,
)
@@ -282,6 +287,126 @@ def select(
reset_frequency=reset_frequency,
)

@classmethod
def from_xarray_raster(
cls,
data: xr.Dataset | pathlib.Path | str,
hazard_type: str,
intensity_unit: str,
*,
intensity: Optional[str] = None,
coordinate_vars: Optional[Dict[str, str]] = None,
crs: str = "EPSG:4326",
open_dataset_kws: dict[str, Any] | None = None,
):
"""Read forecast hazard data from an xarray Dataset

This extends the parent :py:meth:`~climada.hazard.base.Hazard.from_xarray_raster`
to handle forecast dimensions (lead_time and member). For forecast data, the
"event" dimension is constructed from the Cartesian product of lead_time and
member dimensions, so you don't need to specify an "event" coordinate.

Parameters
----------
data : xarray.Dataset or Path or str
The filepath to read the data from or the already opened dataset
hazard_type : str
The type identifier of the hazard
intensity_unit : str
The physical units of the intensity
intensity : str, optional
Identifier of the DataArray containing the hazard intensity data.
If not given, the first data variable in the dataset is used.
coordinate_vars : dict(str, str), optional
Mapping from default coordinate names to coordinate names in the data.
For HazardForecast, should include:
- ``"lead_time"``: name of the lead time coordinate (required)
- ``"member"``: name of the ensemble member coordinate (required)
- ``"longitude"``: name of longitude coordinate (default: "longitude")
- ``"latitude"``: name of latitude coordinate (default: "latitude")

Note: The "event" coordinate is automatically constructed from lead_time
and member, so it should not be specified.
crs : str, optional
Coordinate reference system identifier. Defaults to "EPSG:4326"
open_dataset_kws : dict, optional
Keyword arguments passed to xarray.open_dataset if data is a file path

Returns
-------
HazardForecast
A forecast hazard object with lead_time and member attributes populated

See Also
--------
:py:meth:`climada.hazard.base.Hazard.from_xarray_raster`
Parent method documentation for standard hazard loading
"""

# Open dataset if needed
if isinstance(data, (pathlib.Path, str)):
open_dataset_kws = open_dataset_kws or {}
open_dataset_kws = {"chunks": "auto"} | open_dataset_kws
dset = xr.open_dataset(data, **open_dataset_kws)
else:
dset = data

if intensity is None:
data_var_names = list(dset.data_vars.keys())
if len(data_var_names) == 0:
raise ValueError("Dataset has no data variables")
intensity = data_var_names[0]
LOGGER.info(
"No intensity variable specified. "
"Assuming intensity variable is '%s'",
intensity,
)

# Extract forecast coordinates
coordinate_vars = coordinate_vars or {}
for key in ["lead_time", "member"]:
if key not in coordinate_vars:
raise ValueError(
f"coordinate_vars must include '{key}' key. "
f"Available coordinates: {list(dset.coords.keys())}"
)
leadtime_var = coordinate_vars["lead_time"]
member_var = coordinate_vars["member"]

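# Add a dummy 2-D "event" coordinate over (lead_time, member); its values are
# unused, but its dimensions let the reader build one event per
# (lead_time, member) combination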
dset = dset.assign_coords(
event=(
(leadtime_var, member_var),
np.zeros((len(dset[leadtime_var]), len(dset[member_var]))),
)
)

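# Drop any singleton dimensions (e.g. a length-one reference time) before stacking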
dset_squeezed = dset.squeeze()

# Prepare coordinate_vars for parent call
parent_coord_vars = {
k: v for k, v in coordinate_vars.items() if k not in ["member", "lead_time"]
}
parent_coord_vars["event"] = "event"

reader = HazardXarrayReader(
data=dset_squeezed,
coordinate_vars=parent_coord_vars,
intensity=intensity,
crs=crs,
)

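# Merge the generic hazard kwargs from the reader with the forecast-specific attributes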
kwargs = reader.get_hazard_kwargs() | {
"haz_type": hazard_type,
"units": intensity_unit,
"lead_time": reader.data_stacked[leadtime_var].to_numpy(),
"member": reader.data_stacked[member_var].to_numpy(),
}

# Generate event names from lead_time and member,
# e.g. "lt_3h_m_3" for a 3 h lead time of member 3
kwargs["event_name"] = [
f"lt_{lt / np.timedelta64(1, 'h'):.0f}h_m_{m}"
for lt, m in zip(kwargs["lead_time"], kwargs["member"])
]
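# Calendar dates are undefined for forecast events, so set them all to 0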
kwargs["date"] = np.zeros_like(kwargs["date"], dtype=int)

# Convert to HazardForecast with forecast attributes
return cls(**Hazard._check_and_cast_attrs(kwargs))

def _quantile(self, q: float, event_name: str | None = None):
"""
Reduce the impact matrix and at_event of a HazardForecast to the quantile value.
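A minimal usage sketch of the new classmethod, assuming a forecast NetCDF file whose coordinates match the test fixture below; the file name "precip_forecast.nc" is hypothetical:

from climada.hazard.forecast import HazardForecast

haz_fc = HazardForecast.from_xarray_raster(
    "precip_forecast.nc",  # file path, or an already opened xarray.Dataset
    hazard_type="PR",
    intensity_unit="mm/h",
    coordinate_vars={
        "longitude": "lon",
        "latitude": "lat",
        "lead_time": "lead_time",  # required forecast coordinate
        "member": "eps",           # required ensemble member coordinate
    },
)
# One event per (lead_time, member) pair; event names look like "lt_3h_m_3"
# (3 h lead time, member 3) and all dates are set to 0.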
163 changes: 163 additions & 0 deletions climada/hazard/test/test_forecast.py
@@ -19,16 +19,29 @@
Tests for Hazard Forecast.
"""

import datetime as dt
from pathlib import Path

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
import xarray as xr
from packaging.version import Version
from scipy.sparse import csr_matrix

from climada.hazard.base import Hazard
from climada.hazard.centroids.centr import Centroids
from climada.hazard.forecast import HazardForecast
from climada.hazard.test.test_base import hazard_kwargs

# See https://docs.xarray.dev/en/stable/whats-new.html#id80
xarray_leadtime = pytest.mark.skipif(
(Version(xr.__version__) < Version("2025.07.0"))
and (Version(xr.__version__) >= Version("2025.04.0")),
reason="xarray timedelta bug",
)


@pytest.fixture
def haz_kwargs():
@@ -128,6 +141,156 @@ def test_type_fail(self, haz_fc, hazard):
Hazard.concat([haz_fc, hazard])


class TestXarrayReader:

@pytest.fixture()
def forecast_netcdf_file(self, tmp_path_factory):
"""Create a NetCDF file with forecast data structure"""
tmpdir = tmp_path_factory.mktemp("forecast_data")
netcdf_path = tmpdir / "forecast_data.nc"

crs = "EPSG:4326"

n_eps = 5
n_lead_time = 4
n_lat = 3
n_lon = 4

eps = np.array([3, 8, 13, 16, 20])
ref_time = np.array([dt.datetime(2025, 12, 8, 6, 0, 0)], dtype="datetime64[ns]")
lead_time_vals = pd.timedelta_range(
"3h", periods=n_lead_time, freq="2h"
).to_numpy()
lon = np.array([10.0, 10.5, 11.0, 11.5])
lat = np.array([45.0, 45.5, 46.0])

valid_time = ref_time[0] + lead_time_vals

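# Random intensity values with shape (eps, ref_time, lead_time, lat, lon)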
np.random.seed(42)
intensity = np.random.rand(n_eps, 1, n_lead_time, n_lat, n_lon) * 10

# Create xarray Dataset
dset = xr.Dataset(
{
"__xarray_dataarray_variable__": (
["eps", "ref_time", "lead_time", "lat", "lon"],
intensity,
),
},
coords={
"eps": eps,
"ref_time": ref_time,
"lead_time": lead_time_vals,
"lon": lon,
"lat": lat,
"valid_time": (["lead_time"], valid_time),
},
)
dset.to_netcdf(netcdf_path)

return {
"path": netcdf_path,
"n_eps": n_eps,
"n_lead_time": n_lead_time,
"n_lat": n_lat,
"n_lon": n_lon,
"eps": eps,
"lead_time": lead_time_vals,
"lon": lon,
"lat": lat,
"crs": crs,
}

@xarray_leadtime
def test_from_xarray_raster_basic(self, forecast_netcdf_file):
"""Test basic loading of forecast hazard from xarray"""
haz_fc = HazardForecast.from_xarray_raster(
forecast_netcdf_file["path"],
hazard_type="PR",
intensity_unit="mm/h",
coordinate_vars={
"longitude": "lon",
"latitude": "lat",
"lead_time": "lead_time",
"member": "eps",
},
)

# Check that it's a HazardForecast instance
assert isinstance(haz_fc, HazardForecast)

# Check dimensions - after stacking, we should have n_eps * n_lead_time events
expected_n_events = (
forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"]
)
assert len(haz_fc.event_id) == expected_n_events
assert len(haz_fc.lead_time) == expected_n_events
assert len(haz_fc.member) == expected_n_events

# Check that lead_time and member are correctly extracted
npt.assert_array_equal(np.unique(haz_fc.member), forecast_netcdf_file["eps"])

# Check intensity shape (events x centroids)
expected_n_centroids = (
forecast_netcdf_file["n_lat"] * forecast_netcdf_file["n_lon"]
)
assert haz_fc.intensity.shape == (expected_n_events, expected_n_centroids)

# Check centroids
assert len(haz_fc.centroids.lat) == expected_n_centroids
assert len(haz_fc.centroids.lon) == expected_n_centroids

@xarray_leadtime
def test_from_xarray_raster_event_names(self, forecast_netcdf_file):
"""Test that event names are auto-generated from lead_time and member"""
haz_fc = HazardForecast.from_xarray_raster(
forecast_netcdf_file["path"],
hazard_type="PR",
intensity_unit="mm/h",
coordinate_vars={
"longitude": "lon",
"latitude": "lat",
"lead_time": "lead_time",
"member": "eps",
},
crs=forecast_netcdf_file["crs"],
)

# Check that event names are generated with lead_time in hours
expected_n_events = (
forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"]
)
assert len(haz_fc.event_name) == expected_n_events

event_names_expected = [
f"lt_{lt / np.timedelta64(1, 'h'):.0f}h_m_{mm}"
for lt, mm in zip(haz_fc.lead_time, haz_fc.member)
]
npt.assert_array_equal(haz_fc.event_name, event_names_expected)

@xarray_leadtime
def test_from_xarray_raster_dates(self, forecast_netcdf_file):
"""Test that dates are set to 0 for forecast events"""
haz_fc = HazardForecast.from_xarray_raster(
forecast_netcdf_file["path"],
hazard_type="PR",
intensity_unit="mm/h",
coordinate_vars={
"longitude": "lon",
"latitude": "lat",
"lead_time": "lead_time",
"member": "eps",
},
crs=forecast_netcdf_file["crs"],
)

# Check that all dates are 0 (undefined for forecast)
expected_n_events = (
forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"]
)
npt.assert_array_equal(haz_fc.date, np.zeros(expected_n_events, dtype=int))


class TestSelect:

@pytest.mark.parametrize(