1use crate::error::{ClusterError, ClusterResult};
21use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
22use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
23use std::sync::Arc;
24use torsh_tensor::Tensor;
25
/// Tuning knobs for memory-conscious clustering routines.
#[derive(Debug, Clone)]
pub struct MemoryEfficientConfig {
    /// Number of samples processed per chunk.
    pub chunk_size: usize,
    /// Whether chunks may be processed in parallel.
    pub parallel: bool,
    /// Optional soft cap on memory usage in megabytes (`None` = no limit).
    pub memory_limit_mb: Option<usize>,
}
36
37impl Default for MemoryEfficientConfig {
38 fn default() -> Self {
39 Self {
40 chunk_size: 1000,
41 parallel: true,
42 memory_limit_mb: None,
43 }
44 }
45}
46
/// Processes a tensor in fixed-size row chunks to bound peak memory use.
pub struct ChunkedDataProcessor {
    // Rows per chunk; fed to `step_by`, which panics on 0 — callers pass > 0.
    chunk_size: usize,
    // When true, `process_parallel` fans chunks out across threads.
    parallel: bool,
}
77
78impl ChunkedDataProcessor {
79 pub fn new(chunk_size: usize) -> Self {
81 Self {
82 chunk_size,
83 parallel: true,
84 }
85 }
86
87 pub fn parallel(mut self, parallel: bool) -> Self {
89 self.parallel = parallel;
90 self
91 }
92
93 pub fn process<F>(&self, data: &Tensor, mut f: F) -> ClusterResult<()>
97 where
98 F: FnMut(ArrayView2<f32>) -> ClusterResult<()>,
99 {
100 let shape = data.shape();
101 let n_samples = shape.dims()[0];
102 let n_features = shape.dims()[1];
103
104 let data_vec = data.to_vec()?;
106 let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
107 .map_err(|e| ClusterError::InvalidInput(format!("Shape error: {}", e)))?;
108
109 for start_idx in (0..n_samples).step_by(self.chunk_size) {
111 let end_idx = (start_idx + self.chunk_size).min(n_samples);
112 let chunk = data_array.slice(s![start_idx..end_idx, ..]);
113 f(chunk)?;
114 }
115
116 Ok(())
117 }
118
119 pub fn process_parallel<F, R>(&self, data: &Tensor, f: F) -> ClusterResult<Vec<R>>
124 where
125 F: Fn(ArrayView2<f32>) -> ClusterResult<R> + Send + Sync,
126 R: Send,
127 {
128 let shape = data.shape();
129 let n_samples = shape.dims()[0];
130 let n_features = shape.dims()[1];
131
132 let data_vec = data.to_vec()?;
134 let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
135 .map_err(|e| ClusterError::InvalidInput(format!("Shape error: {}", e)))?;
136
137 let data_arc = Arc::new(data_array);
139
140 let chunks: Vec<(usize, usize)> = (0..n_samples)
142 .step_by(self.chunk_size)
143 .map(|start| {
144 let end = (start + self.chunk_size).min(n_samples);
145 (start, end)
146 })
147 .collect();
148
149 if !self.parallel || chunks.len() <= 1 {
150 let results: Result<Vec<R>, ClusterError> = chunks
152 .iter()
153 .map(|(start, end)| {
154 let chunk = data_arc.slice(s![*start..*end, ..]);
155 f(chunk)
156 })
157 .collect();
158 return results;
159 }
160
161 let results: Result<Vec<R>, ClusterError> = chunks
163 .into_par_iter()
164 .map(|(start, end)| {
165 let chunk = data_arc.slice(s![start..end, ..]);
166 f(chunk)
167 })
168 .collect();
169
170 results
171 }
172
173 pub fn optimal_chunk_size(
175 n_samples: usize,
176 n_features: usize,
177 available_memory_mb: usize,
178 ) -> usize {
179 let bytes_per_sample = n_features * std::mem::size_of::<f32>();
181 let available_bytes = available_memory_mb * 1024 * 1024;
182
183 let safe_bytes = (available_bytes as f64 * 0.8) as usize;
185
186 let chunk_size = safe_bytes / bytes_per_sample;
188
189 chunk_size.max(10).min(n_samples)
191 }
192}
193
/// Maintains running cluster means that can be updated one batch at a time
/// without revisiting earlier samples.
pub struct IncrementalCentroidUpdater {
    // Running mean of each cluster, shape (n_clusters, n_features).
    centroids: Array2<f64>,
    // Samples folded into each centroid; `initialize` seeds these to 1.
    counts: Array1<usize>,
    // Total samples seen, including the initial centroids.
    n_samples: usize,
}
218
219impl IncrementalCentroidUpdater {
220 pub fn new(n_clusters: usize, n_features: usize) -> Self {
222 Self {
223 centroids: Array2::zeros((n_clusters, n_features)),
224 counts: Array1::zeros(n_clusters),
225 n_samples: 0,
226 }
227 }
228
229 pub fn initialize(&mut self, initial_centroids: ArrayView2<f64>) -> ClusterResult<()> {
231 let (n_clusters, n_features) = initial_centroids.dim();
232
233 if (n_clusters, n_features) != self.centroids.dim() {
234 return Err(ClusterError::InvalidInput(format!(
235 "Expected {} clusters and {} features, got {} and {}",
236 self.centroids.nrows(),
237 self.centroids.ncols(),
238 n_clusters,
239 n_features
240 )));
241 }
242
243 self.centroids.assign(&initial_centroids);
244 self.counts.fill(1); self.n_samples = n_clusters;
246
247 Ok(())
248 }
249
250 pub fn update_batch(
254 &mut self,
255 samples: ArrayView2<f64>,
256 labels: &[usize],
257 ) -> ClusterResult<()> {
258 if samples.nrows() != labels.len() {
259 return Err(ClusterError::InvalidInput(format!(
260 "Sample count {} doesn't match label count {}",
261 samples.nrows(),
262 labels.len()
263 )));
264 }
265
266 for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
268 if label >= self.centroids.nrows() {
269 return Err(ClusterError::InvalidInput(format!(
270 "Label {} exceeds number of clusters {}",
271 label,
272 self.centroids.nrows()
273 )));
274 }
275
276 let count = self.counts[label];
277 let mut centroid = self.centroids.row_mut(label);
278
279 for (i, &value) in sample.iter().enumerate() {
281 centroid[i] += (value - centroid[i]) / (count + 1) as f64;
282 }
283
284 self.counts[label] += 1;
285 }
286
287 self.n_samples += samples.nrows();
288
289 Ok(())
290 }
291
292 pub fn centroids(&self) -> ArrayView2<'_, f64> {
294 self.centroids.view()
295 }
296
297 pub fn counts(&self) -> &Array1<usize> {
299 &self.counts
300 }
301
302 pub fn n_samples(&self) -> usize {
304 self.n_samples
305 }
306}
307
/// Estimates the peak working-set size, in megabytes, of a clustering run.
///
/// Accounts for the f32 data matrix, f64 centroids, the usize label vector,
/// and the f32 sample-to-cluster distance matrix.
pub fn estimate_memory_usage(n_samples: usize, n_features: usize, n_clusters: usize) -> f64 {
    const F32: usize = std::mem::size_of::<f32>();
    const F64: usize = std::mem::size_of::<f64>();
    const USIZE: usize = std::mem::size_of::<usize>();

    let total_bytes = n_samples * n_features * F32 // input data
        + n_clusters * n_features * F64            // centroids
        + n_samples * USIZE                        // labels
        + n_samples * n_clusters * F32;            // distance matrix

    total_bytes as f64 / (1024.0 * 1024.0)
}
328
329pub fn suggest_clustering_strategy(
331 n_samples: usize,
332 n_features: usize,
333 available_memory_mb: usize,
334) -> String {
335 let estimated_mb = estimate_memory_usage(n_samples, n_features, 10); if estimated_mb < available_memory_mb as f64 * 0.5 {
338 format!(
339 "Standard clustering (estimated {:.2} MB, available {} MB)",
340 estimated_mb, available_memory_mb
341 )
342 } else if estimated_mb < available_memory_mb as f64 * 0.8 {
343 format!(
344 "Use parallel processing with caution (estimated {:.2} MB, available {} MB)",
345 estimated_mb, available_memory_mb
346 )
347 } else {
348 let chunk_size =
349 ChunkedDataProcessor::optimal_chunk_size(n_samples, n_features, available_memory_mb);
350 format!(
351 "Use chunked processing with chunk_size={} (estimated {:.2} MB exceeds available {} MB)",
352 chunk_size, estimated_mb, available_memory_mb
353 )
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // 10 rows with chunk_size 3 => chunks of 3, 3, 3, 1 (four callbacks).
    #[test]
    fn test_chunked_processor_basic() -> Result<(), Box<dyn std::error::Error>> {
        let data = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[10, 10])?;

        let processor = ChunkedDataProcessor::new(3);

        let mut chunk_count = 0;
        processor.process(&data, |chunk| {
            chunk_count += 1;
            // No chunk may exceed the configured size.
            assert!(chunk.nrows() <= 3);
            Ok(())
        })?;

        assert_eq!(chunk_count, 4);
        Ok(())
    }

    // Parallel processing must still cover every row exactly once.
    #[test]
    fn test_chunked_processor_parallel() -> Result<(), Box<dyn std::error::Error>> {
        let data = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[10, 10])?;

        let processor = ChunkedDataProcessor::new(3).parallel(true);

        let results = processor.process_parallel(&data, |chunk| Ok(chunk.nrows()))?;

        assert_eq!(results.len(), 4);
        assert_eq!(results.iter().sum::<usize>(), 10);

        Ok(())
    }

    #[test]
    fn test_optimal_chunk_size() {
        let chunk_size = ChunkedDataProcessor::optimal_chunk_size(1000, 100, 100);

        // The suggestion must be positive and never exceed the sample count.
        assert!(chunk_size > 0);
        assert!(chunk_size <= 1000);
    }

    #[test]
    fn test_incremental_centroid_updater() -> Result<(), Box<dyn std::error::Error>> {
        let mut updater = IncrementalCentroidUpdater::new(2, 3);

        let initial = Array2::from_shape_vec((2, 3), vec![0.0, 0.0, 0.0, 5.0, 5.0, 5.0])?;
        updater.initialize(initial.view())?;

        let samples = Array2::from_shape_vec((2, 3), vec![1.0, 1.0, 1.0, 6.0, 6.0, 6.0])?;
        let labels = vec![0, 1];
        updater.update_batch(samples.view(), &labels)?;

        // Each seed counts as one observation, so each mean is (seed + sample) / 2.
        let centroids = updater.centroids();
        assert_relative_eq!(centroids[[0, 0]], 0.5, epsilon = 1e-6);
        assert_relative_eq!(centroids[[1, 0]], 5.5, epsilon = 1e-6);

        // 2 seeds + 2 batch samples.
        assert_eq!(updater.n_samples(), 4);
        Ok(())
    }

    #[test]
    fn test_memory_estimation() {
        let memory_mb = estimate_memory_usage(1000, 100, 10);

        // 1000x100 f32 data dominates: ~0.43 MB on a 64-bit target.
        assert!(memory_mb > 0.4);
        assert!(memory_mb < 0.5);
    }

    #[test]
    fn test_suggest_clustering_strategy() {
        // Small data, ample memory -> standard path.
        let strategy = suggest_clustering_strategy(100, 10, 100);
        assert!(strategy.contains("Standard"));

        // Huge data, tiny budget -> chunked path.
        let strategy = suggest_clustering_strategy(1_000_000, 100, 10);
        assert!(strategy.contains("chunked"));
    }
}