diff --git a/pyproject.toml b/pyproject.toml index a80fbf0..565362b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rasters" -version = "1.15.1" +version = "1.16.0" description = "raster processing toolkit" readme = "README.md" authors = [ diff --git a/rasters/mosaic.py b/rasters/mosaic.py index 0f75f83..f1d45fe 100644 --- a/rasters/mosaic.py +++ b/rasters/mosaic.py @@ -6,11 +6,12 @@ if TYPE_CHECKING: from .raster import Raster + from .multi_raster import MultiRaster from .raster_geometry import RasterGeometry def mosaic( - images: Iterator[Union[Raster, str]], - geometry: RasterGeometry, + images: Iterator[Union[Raster, MultiRaster, str]], + geometry: RasterGeometry = None, resampling: str = "nearest") -> Raster: """ Creates a mosaic from a sequence of Raster images. @@ -30,12 +31,25 @@ def mosaic( """ from .where import where # Assuming 'where' is a function in the same package from .raster import Raster # Import Raster here to avoid circular dependency + from .multi_raster import MultiRaster # Import MultiRaster here to avoid circular dependency - mosaic = Raster(np.full(geometry.shape, np.nan), geometry=geometry) # Initialize with NaN values + if len(images) == 0: + raise ValueError("No images provided for mosaicking.") + + if len(images[0].shape == 2): + mosaic = Raster(np.full(geometry.shape, np.nan), geometry=geometry) + elif len(images[0].shape == 3): + mosaic = MultiRaster(np.full(geometry.shape, np.nan), geometry=geometry) + else: + raise ValueError("Unsupported image shape for mosaicking.") + dtype = None nodata = None metadata = None + if geometry is None: + geometry = images[0].geometry + for image in images: if isinstance(image, str): image = Raster.open(image) # Open the image if it's a file path @@ -50,6 +64,12 @@ def mosaic( mosaic = where(np.isnan(mosaic), image.to_geometry(geometry, resampling=resampling), mosaic) mosaic = mosaic.astype(dtype) # Set the data type of the mosaic - mosaic = Raster(mosaic, geometry=geometry, nodata=nodata, metadata=metadata) # Create the final Raster + + if len(mosaic.shape) == 2: + mosaic = Raster(mosaic, geometry=geometry, nodata=nodata, metadata=metadata) # Create the final Raster + elif len(mosaic.shape) == 3: + mosaic = MultiRaster(mosaic, geometry=geometry, nodata=nodata, metadata=metadata) # Create the final MultiRaster + else: + raise ValueError("Unsupported mosaic shape after processing.") return mosaic