Skip to main content

surtgis_core/
streaming.rs

//! Streaming processor for window-based raster algorithms.
//!
//! Processes rasters strip-by-strip with bounded memory, so algorithms that
//! only need a local neighborhood window can run on arbitrarily large DEMs.
5
6use std::path::Path;
7
8use ndarray::Array2;
9
10use crate::error::Result;
11use crate::io::strip_reader::StripReader;
12use crate::io::strip_writer::{StripWriterConfig, write_geotiff_streaming};
13
14/// Trait for algorithms that operate on a moving window.
15///
16/// Implementations define the kernel radius and how to process a chunk
17/// of rows. The [`StripProcessor`] handles I/O and buffer management.
18pub trait WindowAlgorithm: Send + Sync {
19    /// Kernel radius (1 for 3x3, 10 for 21x21).
20    fn kernel_radius(&self) -> usize;
21
22    /// Process a chunk of input rows, producing output rows.
23    ///
24    /// `input` has `(chunk_rows + top_pad + bottom_pad)` rows x cols columns,
25    /// where the padding is the halo clamped to image bounds.
26    /// `output` has `chunk_rows` rows x cols columns.
27    /// The algorithm should write results for the center rows only.
28    fn process_chunk(
29        &self,
30        input: &Array2<f64>,
31        output: &mut Array2<f64>,
32        nodata: Option<f64>,
33        cell_size_x: f64,
34        cell_size_y: f64,
35    );
36}
37
/// Streaming processor that reads and writes GeoTIFF strip-by-strip.
///
/// Memory per chunk is roughly `(chunk_rows + 2 * radius) * cols * 8` bytes
/// for the halo-padded input window, plus `chunk_rows * cols * 8` bytes for the
/// output. While a chunk is being assembled, the raw rows read from disk may
/// also be held; they are at most the size of the padded window.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StripProcessor {
    /// Number of output rows per chunk (default: 256).
    pub chunk_rows: usize,
}

impl Default for StripProcessor {
    /// A processor that emits 256 output rows per chunk.
    fn default() -> Self {
        Self { chunk_rows: 256 }
    }
}
47
48impl StripProcessor {
49    /// Create a new `StripProcessor` with the given chunk size.
50    pub fn new(chunk_rows: usize) -> Self {
51        Self { chunk_rows }
52    }
53
54    /// Process an entire raster file using streaming I/O.
55    ///
56    /// Returns `(rows, cols)` of the processed raster.
57    pub fn process<A: WindowAlgorithm>(
58        &self,
59        input_path: &Path,
60        output_path: &Path,
61        algorithm: &A,
62        compress: bool,
63    ) -> Result<(usize, usize)> {
64        let reader = StripReader::open(input_path)?;
65        let rows = reader.rows();
66        let cols = reader.cols();
67        let radius = algorithm.kernel_radius();
68
69        let nodata = reader.nodata();
70
71        let config = StripWriterConfig {
72            rows,
73            cols,
74            transform: reader.transform().clone(),
75            crs: reader.crs().cloned(),
76            nodata: nodata.or(Some(f64::NAN)),
77            compress,
78            rows_per_strip: self.chunk_rows as u32,
79        };
80        let cell_size_x = reader.transform().pixel_width.abs();
81        let cell_size_y = reader.transform().pixel_height.abs();
82
83        // Use RefCell to share the reader with the sequential write callback.
84        let reader_cell = std::cell::RefCell::new(reader);
85
86        let mut current_out_row = 0usize;
87
88        write_geotiff_streaming(output_path, &config, |_strip_idx, out_strip_rows| {
89            let mut reader = reader_cell.borrow_mut();
90
91            // Determine output row range
92            let out_start = current_out_row;
93            let out_end = (current_out_row + out_strip_rows).min(rows);
94            let actual_out_rows = out_end - out_start;
95
96            // Input range with halo (clamped to image bounds)
97            let in_start = out_start.saturating_sub(radius);
98            let in_end = (out_end + radius).min(rows);
99
100            // Read the available input rows
101            let raw_input = reader.read_rows(in_start, in_end - in_start)?;
102
103            // Build a padded input buffer that always has exactly
104            // (actual_out_rows + 2*radius) rows with the output centered.
105            // Rows outside the image are filled with NaN.
106            let padded_rows = actual_out_rows + 2 * radius;
107            let mut input = Array2::<f64>::from_elem((padded_rows, cols), f64::NAN);
108
109            // How many halo rows are missing at the top?
110            let top_pad = radius.saturating_sub(out_start);
111            // Copy raw data into the padded buffer at the right offset
112            let copy_rows = raw_input.nrows().min(padded_rows - top_pad);
113            input
114                .slice_mut(ndarray::s![top_pad..top_pad + copy_rows, ..])
115                .assign(&raw_input.slice(ndarray::s![..copy_rows, ..]));
116
117            // Prepare output buffer
118            let mut output = Array2::<f64>::from_elem((actual_out_rows, cols), f64::NAN);
119
120            // Process: input[radius..radius+actual_out_rows] are the center rows
121            algorithm.process_chunk(&input, &mut output, nodata, cell_size_x, cell_size_y);
122
123            current_out_row = out_end;
124            Ok(output)
125        })?;
126
127        Ok((rows, cols))
128    }
129}