diff --git a/aeronet/dataset/raster/collection.py b/aeronet/dataset/raster/collection.py index 4d07160..dd44b09 100644 --- a/aeronet/dataset/raster/collection.py +++ b/aeronet/dataset/raster/collection.py @@ -158,10 +158,14 @@ def resample(self, dst_res, directory=None, interpolation='nearest'): r_bands.append(r_band) return BandCollection(r_bands) - def generate_samples(self, height, width): + def generate_samples(self, height, width, return_sample_coord=False): for x in range(0, self.width, width): for y in range(0, self.height, height): - yield self.sample(y, x, height, width) + if return_sample_coord: + yield ((y, x, height, width), + self.sample(y, x, height, width)) + else: + yield self.sample(y, x, height, width) def numpy(self): return self.sample(0, 0, self.height, self.width).numpy()