scirs2_datasets/streaming/
transforms.rs1use crate::error::DatasetsError;
10use crate::streaming::iterator::{NewStreamingIterator, StreamingDataChunk};
11use scirs2_core::ndarray::{Array1, Array2, Axis};
12
/// Boxed, thread-safe predicate over a slice of feature values; stored by [`Filter`].
type RowPredicate = Box<dyn Fn(&[f64]) -> bool + Sync + Send>;
15
16type FeatureMapFn = Box<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>;
18
/// A chunk-level transformation applied to streamed data.
///
/// Implementors take ownership of a [`StreamingDataChunk`] and return the
/// transformed chunk, or a [`DatasetsError`] if the chunk cannot be processed.
/// The `Send + Sync` supertraits allow boxed transforms to be stored in a
/// [`TransformPipeline`] and shared between threads.
pub trait Transform: Send + Sync {
    /// Transforms `chunk`, consuming it and returning the result.
    fn apply(&self, chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError>;
}
33
/// Per-feature standardisation transform.
///
/// Holds one mean and one standard deviation per feature column, estimated by
/// [`Normalize::fit`] (in-memory) or [`Normalize::fit_from_chunks`] (streaming).
/// Applying it replaces each value `x` in column `j` with
/// `(x - mean[j]) / std[j]`; columns whose std is not strictly positive are
/// left unchanged.
#[derive(Debug, Clone)]
pub struct Normalize {
    /// Per-feature means, indexed by column.
    mean: Vec<f64>,
    /// Per-feature sample standard deviations (ddof = 1), indexed by column.
    std: Vec<f64>,
}
47
48impl Normalize {
49 pub fn fit(data: &Array2<f64>) -> Self {
51 let mean_arr = data
52 .mean_axis(Axis(0))
53 .unwrap_or_else(|| Array1::zeros(data.ncols()));
54 let std_arr = data.std_axis(Axis(0), 1.0);
56 Self {
57 mean: mean_arr.to_vec(),
58 std: std_arr.to_vec(),
59 }
60 }
61
62 pub fn fit_from_chunks(iter: &mut NewStreamingIterator) -> Result<Self, DatasetsError> {
67 let nf = iter.n_features();
68 if nf == 0 {
69 return Ok(Self {
70 mean: vec![],
71 std: vec![],
72 });
73 }
74
75 let mut count = 0usize;
76 let mut mean = vec![0.0f64; nf];
77 let mut m2 = vec![0.0f64; nf]; for chunk_res in iter.by_ref() {
80 let chunk = chunk_res?;
81 for row in chunk.features.rows() {
82 count += 1;
83 for (j, &val) in row.iter().enumerate() {
84 let delta = val - mean[j];
85 mean[j] += delta / count as f64;
86 let delta2 = val - mean[j];
87 m2[j] += delta * delta2;
88 }
89 }
90 }
91
92 iter.reset();
93
94 let std_dev: Vec<f64> = m2
95 .into_iter()
96 .map(|s| {
97 if count > 1 {
98 (s / (count - 1) as f64).sqrt()
99 } else {
100 0.0
101 }
102 })
103 .collect();
104
105 Ok(Self { mean, std: std_dev })
106 }
107
108 pub fn mean(&self) -> &[f64] {
110 &self.mean
111 }
112
113 pub fn std(&self) -> &[f64] {
115 &self.std
116 }
117}
118
119impl Transform for Normalize {
120 fn apply(&self, mut chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError> {
121 let nf = chunk.features.ncols();
122 if nf != self.mean.len() {
123 return Err(DatasetsError::InvalidFormat(format!(
124 "Normalize: chunk has {nf} features, but was fitted on {}",
125 self.mean.len()
126 )));
127 }
128 for mut row in chunk.features.rows_mut() {
129 for (j, val) in row.iter_mut().enumerate() {
130 let s = self.std[j];
131 if s > 0.0 {
132 *val = (*val - self.mean[j]) / s;
133 }
134 }
135 }
136 Ok(chunk)
137 }
138}
139
/// Row-filtering transform: keeps only the rows of a chunk for which the
/// wrapped predicate returns `true`, carrying matching labels along.
pub struct Filter {
    /// Predicate called with each row's feature values.
    condition: RowPredicate,
}
148
149impl Filter {
150 pub fn new(f: impl Fn(&[f64]) -> bool + Send + Sync + 'static) -> Self {
152 Self {
153 condition: Box::new(f),
154 }
155 }
156}
157
158impl Transform for Filter {
159 fn apply(&self, chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError> {
160 let nf = chunk.features.ncols();
161 let n_rows = chunk.features.nrows();
162
163 let mut keep_feat: Vec<f64> = Vec::new();
164 let mut keep_labels: Vec<f64> = Vec::new();
165 let mut kept = 0usize;
166
167 for i in 0..n_rows {
168 let row: Vec<f64> = chunk.features.row(i).to_vec();
169 if (self.condition)(&row) {
170 keep_feat.extend_from_slice(&row);
171 if let Some(ref lbls) = chunk.labels {
172 keep_labels.push(if i < lbls.len() { lbls[i] } else { 0.0 });
173 }
174 kept += 1;
175 }
176 }
177
178 let features = if kept == 0 {
179 Array2::zeros((0, nf.max(1)))
180 } else {
181 Array2::from_shape_vec((kept, nf), keep_feat)
182 .map_err(|e| DatasetsError::ComputationError(format!("Filter shape: {e}")))?
183 };
184
185 let labels = if chunk.labels.is_some() {
186 Some(keep_labels)
187 } else {
188 None
189 };
190
191 Ok(StreamingDataChunk {
192 features,
193 labels,
194 chunk_id: chunk.chunk_id,
195 })
196 }
197}
198
/// Row-mapping transform: replaces every feature row with the output of a
/// user-supplied function. The output width may differ from the input width,
/// but it must be the same for every row of a chunk.
pub struct MapFeatures {
    /// Function applied to each row (as an owned `Array1`).
    transform: FeatureMapFn,
}
211
212impl MapFeatures {
213 pub fn new(f: impl Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static) -> Self {
215 Self {
216 transform: Box::new(f),
217 }
218 }
219}
220
221impl Transform for MapFeatures {
222 fn apply(&self, chunk: StreamingDataChunk) -> Result<StreamingDataChunk, DatasetsError> {
223 let n_rows = chunk.features.nrows();
224 if n_rows == 0 {
225 return Ok(chunk);
226 }
227
228 let first_row = chunk.features.row(0).to_owned();
230 let first_out = (self.transform)(&first_row);
231 let out_nf = first_out.len();
232
233 let mut out_flat: Vec<f64> = Vec::with_capacity(n_rows * out_nf);
234 out_flat.extend(first_out.iter().copied());
235
236 for i in 1..n_rows {
237 let row = chunk.features.row(i).to_owned();
238 let out = (self.transform)(&row);
239 if out.len() != out_nf {
240 return Err(DatasetsError::InvalidFormat(format!(
241 "MapFeatures: row {i} produced {} features, expected {out_nf}",
242 out.len()
243 )));
244 }
245 out_flat.extend(out.iter().copied());
246 }
247
248 let features = Array2::from_shape_vec((n_rows, out_nf), out_flat)
249 .map_err(|e| DatasetsError::ComputationError(format!("MapFeatures shape: {e}")))?;
250
251 Ok(StreamingDataChunk {
252 features,
253 labels: chunk.labels,
254 chunk_id: chunk.chunk_id,
255 })
256 }
257}
258
/// An ordered sequence of [`Transform`] stages applied to each chunk in turn.
pub struct TransformPipeline {
    /// Stages in insertion order; the output of one is the input of the next.
    transforms: Vec<Box<dyn Transform>>,
}
268
269impl TransformPipeline {
270 pub fn new() -> Self {
272 Self {
273 transforms: Vec::new(),
274 }
275 }
276
277 #[allow(clippy::should_implement_trait)]
279 pub fn add(mut self, t: impl Transform + 'static) -> Self {
280 self.transforms.push(Box::new(t));
281 self
282 }
283
284 pub fn apply_chunk(
286 &self,
287 chunk: StreamingDataChunk,
288 ) -> Result<StreamingDataChunk, DatasetsError> {
289 let mut current = chunk;
290 for transform in &self.transforms {
291 current = transform.apply(current)?;
292 }
293 Ok(current)
294 }
295
296 pub fn len(&self) -> usize {
298 self.transforms.len()
299 }
300
301 pub fn is_empty(&self) -> bool {
303 self.transforms.is_empty()
304 }
305}
306
307impl Default for TransformPipeline {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::iterator::{DataSource, NewStreamingIterator, StreamingIteratorConfig};
    use scirs2_core::ndarray::Array2;

    /// Builds an unlabeled chunk (`chunk_id` 0) from row vectors. An empty input
    /// yields a 0 x 1 feature matrix.
    fn make_chunk(data: Vec<Vec<f64>>) -> StreamingDataChunk {
        let n = data.len();
        let f = if n == 0 { 1 } else { data[0].len() };
        let flat: Vec<f64> = data.into_iter().flatten().collect();
        StreamingDataChunk {
            features: Array2::from_shape_vec((n, f), flat).expect("shape"),
            labels: None,
            chunk_id: 0,
        }
    }

    #[test]
    fn test_normalize_transform() {
        let data = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
        let arr =
            Array2::from_shape_vec((3, 2), data.iter().flatten().copied().collect::<Vec<_>>())
                .unwrap();
        let norm = Normalize::fit(&arr);

        let chunk = make_chunk(data);
        let out = norm.apply(chunk).expect("normalize");

        let col0_mean: f64 = out.features.column(0).mean().unwrap_or(0.0);
        let col1_mean: f64 = out.features.column(1).mean().unwrap_or(0.0);
        assert!(col0_mean.abs() < 1e-10, "col0 mean {col0_mean}");
        assert!(col1_mean.abs() < 1e-10, "col1 mean {col1_mean}");

        let col0_std = out.features.column(0).std(1.0);
        assert!((col0_std - 1.0).abs() < 1e-10, "col0 std {col0_std}");
    }

    #[test]
    fn test_normalize_constant_column_passthrough() {
        // Column 1 is constant, so its fitted std is 0.0 and apply must leave it unchanged.
        let arr = Array2::from_shape_vec((3, 2), vec![1.0, 5.0, 2.0, 5.0, 3.0, 5.0])
            .expect("shape");
        let norm = Normalize::fit(&arr);
        assert_eq!(norm.std()[1], 0.0);

        let out = norm
            .apply(make_chunk(vec![vec![1.0, 5.0], vec![3.0, 5.0]]))
            .expect("normalize");
        assert!(out.features.column(1).iter().all(|&v| v == 5.0));
    }

    #[test]
    fn test_filter_transform() {
        let data = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let chunk = make_chunk(data);
        let filter = Filter::new(|row| row[0] > 2.0);
        let out = filter.apply(chunk).expect("filter");
        assert_eq!(out.n_rows(), 3);
        assert!(out.features.column(0).iter().all(|&v| v > 2.0));
    }

    #[test]
    fn test_filter_keeps_labels_aligned() {
        let mut chunk = make_chunk(vec![vec![1.0], vec![5.0], vec![2.0], vec![7.0]]);
        chunk.labels = Some(vec![10.0, 50.0, 20.0, 70.0]);
        let out = Filter::new(|row| row[0] > 3.0)
            .apply(chunk)
            .expect("filter");
        assert_eq!(out.features.column(0).to_vec(), vec![5.0, 7.0]);
        assert_eq!(out.labels, Some(vec![50.0, 70.0]));
    }

    #[test]
    fn test_filter_all_removed() {
        let data = vec![vec![1.0], vec![2.0], vec![3.0]];
        let chunk = make_chunk(data);
        let filter = Filter::new(|row| row[0] > 100.0);
        let out = filter.apply(chunk).expect("filter");
        assert_eq!(out.n_rows(), 0);
    }

    #[test]
    fn test_map_features_double() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let chunk = make_chunk(data);
        let map = MapFeatures::new(|row| row.mapv(|x| x * 2.0));
        let out = map.apply(chunk).expect("map");
        assert_eq!(out.features[[0, 0]], 2.0);
        assert_eq!(out.features[[0, 1]], 4.0);
        assert_eq!(out.features[[1, 0]], 6.0);
    }

    #[test]
    fn test_transform_pipeline() {
        let rows: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let arr =
            Array2::from_shape_vec((10, 2), rows.iter().flatten().copied().collect::<Vec<_>>())
                .unwrap();
        let norm = Normalize::fit(&arr);
        // Capture the fitted statistics before the transform is moved into the pipeline.
        let (mean0, std0) = (norm.mean()[0], norm.std()[0]);

        let pipeline = TransformPipeline::new()
            .add(norm)
            .add(Filter::new(|row| row[0] >= -0.5))
            .add(MapFeatures::new(|row| row.mapv(|x| x * 2.0)));

        assert_eq!(pipeline.len(), 3);

        let chunk = make_chunk(rows);
        let out = pipeline.apply_chunk(chunk).expect("pipeline");
        // Column 0 is 0..10 (mean 4.5, sample std ~3.028), so the normalised value is
        // >= -0.5 exactly for x >= 3: seven surviving rows.
        assert_eq!(out.n_rows(), 7);
        assert_eq!(out.features.ncols(), 2);
        let expected_first = (3.0 - mean0) / std0 * 2.0;
        assert!((out.features[[0, 0]] - expected_first).abs() < 1e-12);
    }

    #[test]
    fn test_normalize_fit_from_chunks() {
        let rows: Vec<Vec<f64>> = (0..30_usize)
            .map(|i| vec![(i % 10) as f64, ((i % 5) * 2) as f64])
            .collect();
        let config = StreamingIteratorConfig {
            chunk_size: 10,
            ..Default::default()
        };
        let mut iter =
            NewStreamingIterator::new(DataSource::InMemory(rows.clone()), config).expect("iter");
        let norm = Normalize::fit_from_chunks(&mut iter).expect("fit");

        let expected_mean0: f64 = rows.iter().map(|r| r[0]).sum::<f64>() / rows.len() as f64;
        assert!((norm.mean()[0] - expected_mean0).abs() < 1e-10);
        assert!(norm.std()[0] > 0.0);
        assert!(norm.std()[1] > 0.0);

        // The streaming fit must agree with the in-memory fit on the same data.
        let flat: Vec<f64> = rows.iter().flatten().copied().collect();
        let batch = Normalize::fit(&Array2::from_shape_vec((rows.len(), 2), flat).expect("shape"));
        for j in 0..2 {
            assert!((norm.mean()[j] - batch.mean()[j]).abs() < 1e-10);
            assert!((norm.std()[j] - batch.std()[j]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_pipeline_empty_chunk() {
        let chunk = StreamingDataChunk {
            features: Array2::zeros((0, 3)),
            labels: None,
            chunk_id: 0,
        };
        let map = MapFeatures::new(|row| row.mapv(|x| x + 1.0));
        let out = map.apply(chunk).expect("map empty");
        assert_eq!(out.n_rows(), 0);
    }
}