Skip to main content

scirs2_datasets/streaming/
transforms.rs

//! Lazy dataset transformations for streaming pipelines.
//!
//! Provides a composable [`Transform`] trait and concrete implementations
//! (`Normalize`, `Filter`, `MapFeatures`) that can be chained into a
//! [`TransformPipeline`].  All transforms operate on [`StreamingDataChunk`]
//! values produced by [`NewStreamingIterator`], so data is processed lazily,
//! one chunk at a time, without materialising the whole dataset.  (`Normalize`
//! rescales a chunk in place; `Filter` and `MapFeatures` build a new chunk.)
8
9use crate::error::DatasetsError;
10use crate::streaming::iterator::{NewStreamingIterator, StreamingDataChunk};
11use scirs2_core::ndarray::{Array1, Array2, Axis};
12
/// Type alias for a boxed row-level predicate used by [`Filter`].
///
/// Boxed so `Filter` is not generic over the closure type; `Send + Sync` is
/// required because [`Transform`] itself is `Send + Sync`.
type RowPredicate = Box<dyn Fn(&[f64]) -> bool + Send + Sync>;

/// Type alias for a boxed feature-mapping function used by [`MapFeatures`].
///
/// Boxed so `MapFeatures` is not generic over the closure type; `Send + Sync`
/// is required because [`Transform`] itself is `Send + Sync`.
type FeatureMapFn = Box<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>;
18
19// ---------------------------------------------------------------------------
20// Transform trait
21// ---------------------------------------------------------------------------
22
/// A stateless (or internally-mutable) operation on a [`StreamingDataChunk`].
///
/// Implementing types must be `Send + Sync` so that pipelines can be safely
/// moved to, and shared between, threads.
pub trait Transform: Send + Sync {
    /// Apply the transformation to `chunk`, returning a (potentially new)
    /// chunk.  Implementations may mutate in place and return the same chunk,
    /// or allocate a new one.
    ///
    /// # Errors
    ///
    /// Returns a [`DatasetsError`] when the chunk cannot be transformed (for
    /// example, a feature-count mismatch).  [`TransformPipeline::apply_chunk`]
    /// stops at the first such error.
    fn apply(&self, chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError>;
}
33
34// ---------------------------------------------------------------------------
35// Normalize
36// ---------------------------------------------------------------------------
37
/// Per-feature z-score normalisation: `x ← (x − mean) / std`.
///
/// Features whose standard deviation is not strictly positive (zero, or NaN)
/// are left unchanged, i.e. the column remains as-is rather than becoming
/// NaN.
///
/// Build one with [`Normalize::fit`] or [`Normalize::fit_from_chunks`], then
/// apply it through the [`Transform`] trait.
#[derive(Debug, Clone)]
pub struct Normalize {
    /// Per-feature mean, indexed by column.
    mean: Vec<f64>,
    /// Per-feature sample standard deviation (ddof = 1), indexed by column.
    std: Vec<f64>,
}
47
48impl Normalize {
49    /// Fit from a single `Array2<f64>` (all rows visible at once).
50    pub fn fit(data: &Array2<f64>) -> Self {
51        let mean_arr = data
52            .mean_axis(Axis(0))
53            .unwrap_or_else(|| Array1::zeros(data.ncols()));
54        // Use sample standard deviation (ddof=1) to match sklearn/PyTorch convention
55        let std_arr = data.std_axis(Axis(0), 1.0);
56        Self {
57            mean: mean_arr.to_vec(),
58            std: std_arr.to_vec(),
59        }
60    }
61
62    /// Incremental fit over all chunks produced by `iter`.
63    ///
64    /// Uses Welford's online algorithm to compute mean and variance in a
65    /// single pass, consuming (then resetting) the iterator.
66    pub fn fit_from_chunks(iter: &mut NewStreamingIterator) -> Result<Self, DatasetsError> {
67        let nf = iter.n_features();
68        if nf == 0 {
69            return Ok(Self {
70                mean: vec![],
71                std: vec![],
72            });
73        }
74
75        let mut count = 0usize;
76        let mut mean = vec![0.0f64; nf];
77        let mut m2 = vec![0.0f64; nf]; // sum of squared deviations
78
79        for chunk_res in iter.by_ref() {
80            let chunk = chunk_res?;
81            for row in chunk.features.rows() {
82                count += 1;
83                for (j, &val) in row.iter().enumerate() {
84                    let delta = val - mean[j];
85                    mean[j] += delta / count as f64;
86                    let delta2 = val - mean[j];
87                    m2[j] += delta * delta2;
88                }
89            }
90        }
91
92        iter.reset();
93
94        let std_dev: Vec<f64> = m2
95            .into_iter()
96            .map(|s| {
97                if count > 1 {
98                    (s / (count - 1) as f64).sqrt()
99                } else {
100                    0.0
101                }
102            })
103            .collect();
104
105        Ok(Self { mean, std: std_dev })
106    }
107
108    /// Access fitted means (one per feature).
109    pub fn mean(&self) -> &[f64] {
110        &self.mean
111    }
112
113    /// Access fitted standard deviations (one per feature).
114    pub fn std(&self) -> &[f64] {
115        &self.std
116    }
117}
118
119impl Transform for Normalize {
120    fn apply(&self, mut chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError> {
121        let nf = chunk.features.ncols();
122        if nf != self.mean.len() {
123            return Err(DatasetsError::InvalidFormat(format!(
124                "Normalize: chunk has {nf} features, but was fitted on {}",
125                self.mean.len()
126            )));
127        }
128        for mut row in chunk.features.rows_mut() {
129            for (j, val) in row.iter_mut().enumerate() {
130                let s = self.std[j];
131                if s > 0.0 {
132                    *val = (*val - self.mean[j]) / s;
133                }
134            }
135        }
136        Ok(chunk)
137    }
138}
139
140// ---------------------------------------------------------------------------
141// Filter
142// ---------------------------------------------------------------------------
143
144/// Row-level filter: keeps only rows for which `condition(&row) == true`.
145pub struct Filter {
146    condition: RowPredicate,
147}
148
149impl Filter {
150    /// Create a filter from an arbitrary predicate on a row's feature slice.
151    pub fn new(f: impl Fn(&[f64]) -> bool + Send + Sync + 'static) -> Self {
152        Self {
153            condition: Box::new(f),
154        }
155    }
156}
157
158impl Transform for Filter {
159    fn apply(&self, chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError> {
160        let nf = chunk.features.ncols();
161        let n_rows = chunk.features.nrows();
162
163        let mut keep_feat: Vec<f64> = Vec::new();
164        let mut keep_labels: Vec<f64> = Vec::new();
165        let mut kept = 0usize;
166
167        for i in 0..n_rows {
168            let row: Vec<f64> = chunk.features.row(i).to_vec();
169            if (self.condition)(&row) {
170                keep_feat.extend_from_slice(&row);
171                if let Some(ref lbls) = chunk.labels {
172                    keep_labels.push(if i < lbls.len() { lbls[i] } else { 0.0 });
173                }
174                kept += 1;
175            }
176        }
177
178        let features = if kept == 0 {
179            Array2::zeros((0, nf.max(1)))
180        } else {
181            Array2::from_shape_vec((kept, nf), keep_feat)
182                .map_err(|e| DatasetsError::ComputationError(format!("Filter shape: {e}")))?
183        };
184
185        let labels = if chunk.labels.is_some() {
186            Some(keep_labels)
187        } else {
188            None
189        };
190
191        Ok(StreamingDataChunk {
192            features,
193            labels,
194            chunk_id: chunk.chunk_id,
195        })
196    }
197}
198
199// ---------------------------------------------------------------------------
200// MapFeatures
201// ---------------------------------------------------------------------------
202
203/// Row-level feature mapping: applies a function `Array1<f64> → Array1<f64>`
204/// to every row independently.
205///
206/// The output dimensionality may differ from the input; all rows must produce
207/// the same output length.
208pub struct MapFeatures {
209    transform: FeatureMapFn,
210}
211
212impl MapFeatures {
213    /// Create a feature map from an arbitrary function.
214    pub fn new(f: impl Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static) -> Self {
215        Self {
216            transform: Box::new(f),
217        }
218    }
219}
220
221impl Transform for MapFeatures {
222    fn apply(&self, chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError> {
223        let n_rows = chunk.features.nrows();
224        if n_rows == 0 {
225            return Ok(chunk);
226        }
227
228        // Apply transform to the first row to discover output dimensionality
229        let first_row = chunk.features.row(0).to_owned();
230        let first_out = (self.transform)(&first_row);
231        let out_nf = first_out.len();
232
233        let mut out_flat: Vec<f64> = Vec::with_capacity(n_rows * out_nf);
234        out_flat.extend(first_out.iter().copied());
235
236        for i in 1..n_rows {
237            let row = chunk.features.row(i).to_owned();
238            let out = (self.transform)(&row);
239            if out.len() != out_nf {
240                return Err(DatasetsError::InvalidFormat(format!(
241                    "MapFeatures: row {i} produced {} features, expected {out_nf}",
242                    out.len()
243                )));
244            }
245            out_flat.extend(out.iter().copied());
246        }
247
248        let features = Array2::from_shape_vec((n_rows, out_nf), out_flat)
249            .map_err(|e| DatasetsError::ComputationError(format!("MapFeatures shape: {e}")))?;
250
251        Ok(StreamingDataChunk {
252            features,
253            labels: chunk.labels,
254            chunk_id: chunk.chunk_id,
255        })
256    }
257}
258
259// ---------------------------------------------------------------------------
260// TransformPipeline
261// ---------------------------------------------------------------------------
262
263/// An ordered sequence of [`Transform`] steps applied in the order they were
264/// added.
265pub struct TransformPipeline {
266    transforms: Vec<Box<dyn Transform>>,
267}
268
269impl TransformPipeline {
270    /// Create an empty pipeline.
271    pub fn new() -> Self {
272        Self {
273            transforms: Vec::new(),
274        }
275    }
276
277    /// Append a transform step and return `self` (builder pattern).
278    #[allow(clippy::should_implement_trait)]
279    pub fn add(mut self, t: impl Transform + 'static) -> Self {
280        self.transforms.push(Box::new(t));
281        self
282    }
283
284    /// Apply all transforms in order to `chunk`.
285    pub fn apply_chunk(
286        &self,
287        chunk: StreamingDataChunk,
288    ) -> Result<StreamingDataChunk, DatasetsError> {
289        let mut current = chunk;
290        for transform in &self.transforms {
291            current = transform.apply(current)?;
292        }
293        Ok(current)
294    }
295
296    /// Number of transforms in this pipeline.
297    pub fn len(&self) -> usize {
298        self.transforms.len()
299    }
300
301    /// Returns `true` if no transforms have been added.
302    pub fn is_empty(&self) -> bool {
303        self.transforms.is_empty()
304    }
305}
306
307impl Default for TransformPipeline {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
313// ---------------------------------------------------------------------------
314// Tests
315// ---------------------------------------------------------------------------
316
// Unit tests for the individual transforms and for pipeline composition.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::iterator::{DataSource, NewStreamingIterator, StreamingIteratorConfig};
    use scirs2_core::ndarray::Array2;

    /// Build an unlabelled chunk with `chunk_id` 0 from row vectors.  An empty
    /// input yields a `0 × 1` feature matrix.
    fn make_chunk(data: Vec<Vec<f64>>) -> StreamingDataChunk {
        let n = data.len();
        let f = if n == 0 { 1 } else { data[0].len() };
        let flat: Vec<f64> = data.into_iter().flatten().collect();
        StreamingDataChunk {
            features: Array2::from_shape_vec((n, f), flat).expect("shape"),
            labels: None,
            chunk_id: 0,
        }
    }

    #[test]
    fn test_normalize_transform() {
        // Build data with known mean/std
        let data = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
        let arr =
            Array2::from_shape_vec((3, 2), data.iter().flatten().copied().collect::<Vec<_>>())
                .unwrap();
        let norm = Normalize::fit(&arr);

        let chunk = make_chunk(data);
        let out = norm.apply(chunk).expect("normalize");

        // After normalisation the column means should be ≈ 0 and stds ≈ 1
        let col0_mean: f64 = out.features.column(0).mean().unwrap_or(0.0);
        let col1_mean: f64 = out.features.column(1).mean().unwrap_or(0.0);
        assert!(col0_mean.abs() < 1e-10, "col0 mean {col0_mean}");
        assert!(col1_mean.abs() < 1e-10, "col1 mean {col1_mean}");

        // ddof = 1 here matches the ddof used by `Normalize::fit`.
        let col0_std = out.features.column(0).std(1.0);
        assert!((col0_std - 1.0).abs() < 1e-10, "col0 std {col0_std}");
    }

    #[test]
    fn test_filter_transform() {
        let data = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let chunk = make_chunk(data);
        // Keep rows where first feature > 2
        let filter = Filter::new(|row| row[0] > 2.0);
        let out = filter.apply(chunk).expect("filter");
        assert_eq!(out.n_rows(), 3);
        assert!(out.features.column(0).iter().all(|&v| v > 2.0));
    }

    // A predicate that rejects everything must yield an empty (0-row) chunk
    // rather than an error.
    #[test]
    fn test_filter_all_removed() {
        let data = vec![vec![1.0], vec![2.0], vec![3.0]];
        let chunk = make_chunk(data);
        let filter = Filter::new(|row| row[0] > 100.0);
        let out = filter.apply(chunk).expect("filter");
        assert_eq!(out.n_rows(), 0);
    }

    // Element-wise doubling keeps the shape and scales each value.
    #[test]
    fn test_map_features_double() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let chunk = make_chunk(data);
        let map = MapFeatures::new(|row| row.mapv(|x| x * 2.0));
        let out = map.apply(chunk).expect("map");
        assert_eq!(out.features[[0, 0]], 2.0);
        assert_eq!(out.features[[0, 1]], 4.0);
        assert_eq!(out.features[[1, 0]], 6.0);
    }

    #[test]
    fn test_transform_pipeline() {
        let rows: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let arr =
            Array2::from_shape_vec((10, 2), rows.iter().flatten().copied().collect::<Vec<_>>())
                .unwrap();
        let norm = Normalize::fit(&arr);

        // Pipeline: normalise → keep rows whose normalised col0 is >= -0.5 →
        // double values
        let pipeline = TransformPipeline::new()
            .add(norm)
            .add(Filter::new(|row| row[0] >= -0.5))
            .add(MapFeatures::new(|row| row.mapv(|x| x * 2.0)));

        assert_eq!(pipeline.len(), 3);

        let chunk = make_chunk(rows);
        let out = pipeline.apply_chunk(chunk).expect("pipeline");
        // After normalisation + filter, some rows should remain
        assert!(out.n_rows() > 0);
    }

    // 30 rows with chunk_size 10, so the Welford fit presumably spans several
    // chunks; the streamed mean must match the directly computed one.
    #[test]
    fn test_normalize_fit_from_chunks() {
        let rows: Vec<Vec<f64>> = (0..30_usize)
            .map(|i| vec![(i % 10) as f64, ((i % 5) * 2) as f64])
            .collect();
        let config = StreamingIteratorConfig {
            chunk_size: 10,
            ..Default::default()
        };
        let mut iter =
            NewStreamingIterator::new(DataSource::InMemory(rows.clone()), config).expect("iter");
        let norm = Normalize::fit_from_chunks(&mut iter).expect("fit");

        // Check mean is correct (should match the data's column means)
        let expected_mean0: f64 = rows.iter().map(|r| r[0]).sum::<f64>() / rows.len() as f64;
        assert!((norm.mean()[0] - expected_mean0).abs() < 1e-10);
        // std should be positive
        assert!(norm.std()[0] > 0.0);
        assert!(norm.std()[1] > 0.0);
    }

    // MapFeatures must pass a 0-row chunk through without calling the mapper.
    #[test]
    fn test_pipeline_empty_chunk() {
        let chunk = StreamingDataChunk {
            features: Array2::zeros((0, 3)),
            labels: None,
            chunk_id: 0,
        };
        let map = MapFeatures::new(|row| row.mapv(|x| x + 1.0));
        let out = map.apply(chunk).expect("map empty");
        assert_eq!(out.n_rows(), 0);
    }
}