1use scirs2_core::ndarray::ArrayStatCompat;
8use scirs2_core::ndarray::{ArrayView1, ArrayView2};
9use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive, Zero};
10use std::fmt::{Debug, Write};
11
12use crate::error::{NdimageError, NdimageResult};
13use crate::visualization::types::{PlotConfig, ReportFormat};
14use statrs::statistics::Statistics;
15
16#[allow(dead_code)]
51pub fn create_image_montage<T>(
52 images: &[ArrayView2<T>],
53 grid_cols: usize,
54 config: &PlotConfig,
55) -> NdimageResult<String>
56where
57 T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
58{
59 if images.is_empty() {
60 return Err(NdimageError::InvalidInput("No images provided".into()));
61 }
62
63 if grid_cols == 0 {
64 return Err(NdimageError::InvalidInput(
65 "Grid columns must be positive".into(),
66 ));
67 }
68
69 let mut plot = String::new();
70 let grid_rows = (images.len() + grid_cols - 1) / grid_cols;
71
72 let mut global_min = T::infinity();
74 let mut global_max = T::neg_infinity();
75
76 for image in images {
77 let min_val = image.iter().cloned().fold(T::infinity(), T::min);
78 let max_val = image.iter().cloned().fold(T::neg_infinity(), T::max);
79 global_min = global_min.min(min_val);
80 global_max = global_max.max(max_val);
81 }
82
83 if global_max <= global_min {
84 return Err(NdimageError::InvalidInput(
85 "All image values are the same".into(),
86 ));
87 }
88
89 match config.format {
90 ReportFormat::Html => {
91 writeln!(&mut plot, "<div class='image-montage'>")?;
92 writeln!(&mut plot, "<h3>{}</h3>", config.title)?;
93 writeln!(&mut plot, "<div class='montage-grid' style='display: grid; grid-template-columns: repeat({}, 1fr); gap: 10px;'>", grid_cols)?;
94
95 for (idx, image) in images.iter().enumerate() {
96 let (height, width) = image.dim();
97 writeln!(&mut plot, "<div class='montage-cell'>")?;
98 writeln!(&mut plot, "<h4>Image {}</h4>", idx + 1)?;
99 writeln!(
100 &mut plot,
101 "<div class='image-data' data-width='{}' data-height='{}'>",
102 width, height
103 )?;
104
105 writeln!(&mut plot, "<p>{}×{} array</p>", height, width)?;
107 writeln!(
108 &mut plot,
109 "<p>Range: [{:.3}, {:.3}]</p>",
110 image
111 .iter()
112 .cloned()
113 .fold(T::infinity(), T::min)
114 .to_f64()
115 .unwrap_or(0.0),
116 image
117 .iter()
118 .cloned()
119 .fold(T::neg_infinity(), T::max)
120 .to_f64()
121 .unwrap_or(0.0)
122 )?;
123
124 writeln!(&mut plot, "</div>")?;
125 writeln!(&mut plot, "</div>")?;
126 }
127
128 writeln!(&mut plot, "</div>")?;
129 writeln!(&mut plot, "<div class='montage-info'>")?;
130 writeln!(
131 &mut plot,
132 "<p>Global range: [{:.3}, {:.3}]</p>",
133 global_min.to_f64().unwrap_or(0.0),
134 global_max.to_f64().unwrap_or(0.0)
135 )?;
136 writeln!(
137 &mut plot,
138 "<p>Grid: {} rows × {} columns</p>",
139 grid_rows, grid_cols
140 )?;
141 writeln!(&mut plot, "</div>")?;
142 writeln!(&mut plot, "</div>")?;
143 }
144 ReportFormat::Markdown => {
145 writeln!(&mut plot, "## {} (Image Montage)", config.title)?;
146 writeln!(&mut plot)?;
147 writeln!(
148 &mut plot,
149 "Grid layout: {} rows × {} columns",
150 grid_rows, grid_cols
151 )?;
152 writeln!(
153 &mut plot,
154 "Global value range: [{:.3}, {:.3}]",
155 global_min.to_f64().unwrap_or(0.0),
156 global_max.to_f64().unwrap_or(0.0)
157 )?;
158 writeln!(&mut plot)?;
159
160 for (idx, image) in images.iter().enumerate() {
161 let (height, width) = image.dim();
162 let min_val = image.iter().cloned().fold(T::infinity(), T::min);
163 let max_val = image.iter().cloned().fold(T::neg_infinity(), T::max);
164
165 writeln!(&mut plot, "### Image {}", idx + 1)?;
166 writeln!(&mut plot, "- Dimensions: {}×{}", height, width)?;
167 writeln!(
168 &mut plot,
169 "- Value range: [{:.3}, {:.3}]",
170 min_val.to_f64().unwrap_or(0.0),
171 max_val.to_f64().unwrap_or(0.0)
172 )?;
173 writeln!(&mut plot)?;
174 }
175 }
176 ReportFormat::Text => {
177 writeln!(&mut plot, "{} (Image Montage)", config.title)?;
178 writeln!(&mut plot, "{}", "=".repeat(config.title.len() + 16))?;
179 writeln!(&mut plot)?;
180 writeln!(
181 &mut plot,
182 "Grid layout: {} rows × {} columns",
183 grid_rows, grid_cols
184 )?;
185 writeln!(
186 &mut plot,
187 "Global value range: [{:.3}, {:.3}]",
188 global_min.to_f64().unwrap_or(0.0),
189 global_max.to_f64().unwrap_or(0.0)
190 )?;
191 writeln!(&mut plot)?;
192
193 for (idx, image) in images.iter().enumerate() {
194 let (height, width) = image.dim();
195 let min_val = image.iter().cloned().fold(T::infinity(), T::min);
196 let max_val = image.iter().cloned().fold(T::neg_infinity(), T::max);
197
198 writeln!(
199 &mut plot,
200 "Image {}: {}×{}, range [{:.3}, {:.3}]",
201 idx + 1,
202 height,
203 width,
204 min_val.to_f64().unwrap_or(0.0),
205 max_val.to_f64().unwrap_or(0.0)
206 )?;
207 }
208 }
209 }
210
211 Ok(plot)
212}
213
214#[allow(dead_code)]
251pub fn plot_statistical_comparison<T>(
252 datasets: &[(&str, ArrayView1<T>)],
253 config: &PlotConfig,
254) -> NdimageResult<String>
255where
256 T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
257{
258 if datasets.is_empty() {
259 return Err(NdimageError::InvalidInput("No datasets provided".into()));
260 }
261
262 let mut plot = String::new();
263
264 let mut stats = Vec::new();
266 for (name, data) in datasets {
267 if data.is_empty() {
268 continue;
269 }
270
271 let mean = data.mean_or(T::zero());
272 let min_val = data.iter().cloned().fold(T::infinity(), T::min);
273 let max_val = data.iter().cloned().fold(T::neg_infinity(), T::max);
274 let variance = data
275 .mapv(|x| (x - mean) * (x - mean))
276 .mean()
277 .unwrap_or(T::zero());
278 let std_dev = variance.sqrt();
279
280 stats.push((name, mean, std_dev, min_val, max_val, data.len()));
281 }
282
283 match config.format {
284 ReportFormat::Html => {
285 writeln!(&mut plot, "<div class='statistical-comparison'>")?;
286 writeln!(&mut plot, "<h3>{}</h3>", config.title)?;
287 writeln!(&mut plot, "<table class='stats-table'>")?;
288 writeln!(&mut plot, "<tr><th>Dataset</th><th>Count</th><th>Mean</th><th>Std Dev</th><th>Min</th><th>Max</th></tr>")?;
289
290 for (name, mean, std_dev, min_val, max_val, count) in &stats {
291 writeln!(
292 &mut plot,
293 "<tr><td>{}</td><td>{}</td><td>{:.4}</td><td>{:.4}</td><td>{:.4}</td><td>{:.4}</td></tr>",
294 name, count,
295 mean.to_f64().unwrap_or(0.0),
296 std_dev.to_f64().unwrap_or(0.0),
297 min_val.to_f64().unwrap_or(0.0),
298 max_val.to_f64().unwrap_or(0.0)
299 )?;
300 }
301
302 writeln!(&mut plot, "</table>")?;
303 writeln!(&mut plot, "</div>")?;
304 }
305 ReportFormat::Markdown => {
306 writeln!(&mut plot, "## {} (Statistical Comparison)", config.title)?;
307 writeln!(&mut plot)?;
308 writeln!(
309 &mut plot,
310 "| Dataset | Count | Mean | Std Dev | Min | Max |"
311 )?;
312 writeln!(
313 &mut plot,
314 "|---------|-------|------|---------|-----|-----|"
315 )?;
316
317 for (name, mean, std_dev, min_val, max_val, count) in &stats {
318 writeln!(
319 &mut plot,
320 "| {} | {} | {:.4} | {:.4} | {:.4} | {:.4} |",
321 name,
322 count,
323 mean.to_f64().unwrap_or(0.0),
324 std_dev.to_f64().unwrap_or(0.0),
325 min_val.to_f64().unwrap_or(0.0),
326 max_val.to_f64().unwrap_or(0.0)
327 )?;
328 }
329 writeln!(&mut plot)?;
330 }
331 ReportFormat::Text => {
332 writeln!(&mut plot, "{} (Statistical Comparison)", config.title)?;
333 writeln!(&mut plot, "{}", "=".repeat(config.title.len() + 25))?;
334 writeln!(&mut plot)?;
335 writeln!(
336 &mut plot,
337 "{:<15} {:>8} {:>10} {:>10} {:>10} {:>10}",
338 "Dataset", "Count", "Mean", "Std Dev", "Min", "Max"
339 )?;
340 writeln!(&mut plot, "{}", "-".repeat(75))?;
341
342 for (name, mean, std_dev, min_val, max_val, count) in &stats {
343 writeln!(
344 &mut plot,
345 "{:<15} {:>8} {:>10.4} {:>10.4} {:>10.4} {:>10.4}",
346 name,
347 count,
348 mean.to_f64().unwrap_or(0.0),
349 std_dev.to_f64().unwrap_or(0.0),
350 min_val.to_f64().unwrap_or(0.0),
351 max_val.to_f64().unwrap_or(0.0)
352 )?;
353 }
354 }
355 }
356
357 Ok(plot)
358}
359
360pub fn calculate_dataset_statistics<T>(data: &ArrayView1<T>) -> (T, T, T, T, usize)
373where
374 T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
375{
376 if data.is_empty() {
377 return (T::zero(), T::zero(), T::zero(), T::zero(), 0);
378 }
379
380 let mean = data.mean_or(T::zero());
381 let min_val = data.iter().cloned().fold(T::infinity(), T::min);
382 let max_val = data.iter().cloned().fold(T::neg_infinity(), T::max);
383 let variance = data
384 .mapv(|x| (x - mean) * (x - mean))
385 .mean()
386 .unwrap_or(T::zero());
387 let std_dev = variance.sqrt();
388
389 (mean, std_dev, min_val, max_val, data.len())
390}
391
392#[allow(dead_code)]
406pub fn plot_correlation_matrix<T>(
407 datasets: &[(&str, ArrayView1<T>)],
408 config: &PlotConfig,
409) -> NdimageResult<String>
410where
411 T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
412{
413 if datasets.len() < 2 {
414 return Err(NdimageError::InvalidInput(
415 "Need at least 2 datasets for correlation".into(),
416 ));
417 }
418
419 let mut plot = String::new();
420 let n = datasets.len();
421
422 let mut correlations = vec![vec![0.0; n]; n];
424
425 for i in 0..n {
426 for j in 0..n {
427 if i == j {
428 correlations[i][j] = 1.0;
429 } else {
430 let corr = calculate_correlation(&datasets[i].1, &datasets[j].1);
431 correlations[i][j] = corr;
432 }
433 }
434 }
435
436 match config.format {
437 ReportFormat::Html => {
438 writeln!(&mut plot, "<div class='correlation-matrix'>")?;
439 writeln!(&mut plot, "<h3>{}</h3>", config.title)?;
440 writeln!(&mut plot, "<table class='correlation-table'>")?;
441
442 write!(&mut plot, "<tr><th></th>")?;
444 for (name, _) in datasets {
445 write!(&mut plot, "<th>{}</th>", name)?;
446 }
447 writeln!(&mut plot, "</tr>")?;
448
449 for i in 0..n {
451 write!(&mut plot, "<tr><th>{}</th>", datasets[i].0)?;
452 for j in 0..n {
453 let corr = correlations[i][j];
454 let color_class = if corr.abs() > 0.7 {
455 "strong-corr"
456 } else {
457 "weak-corr"
458 };
459 write!(&mut plot, "<td class='{}'>{:.3}</td>", color_class, corr)?;
460 }
461 writeln!(&mut plot, "</tr>")?;
462 }
463
464 writeln!(&mut plot, "</table>")?;
465 writeln!(&mut plot, "</div>")?;
466 }
467 ReportFormat::Markdown => {
468 writeln!(&mut plot, "## {} (Correlation Matrix)", config.title)?;
469 writeln!(&mut plot)?;
470
471 write!(&mut plot, "|")?;
473 for (name, _) in datasets {
474 write!(&mut plot, " {} |", name)?;
475 }
476 writeln!(&mut plot)?;
477
478 write!(&mut plot, "|")?;
480 for _ in 0..n {
481 write!(&mut plot, "------|")?;
482 }
483 writeln!(&mut plot)?;
484
485 for i in 0..n {
487 write!(&mut plot, "| **{}** |", datasets[i].0)?;
488 for j in 0..n {
489 write!(&mut plot, " {:.3} |", correlations[i][j])?;
490 }
491 writeln!(&mut plot)?;
492 }
493 writeln!(&mut plot)?;
494 }
495 ReportFormat::Text => {
496 writeln!(&mut plot, "{} (Correlation Matrix)", config.title)?;
497 writeln!(&mut plot, "{}", "=".repeat(config.title.len() + 20))?;
498 writeln!(&mut plot)?;
499
500 write!(&mut plot, "{:>12}", "")?;
502 for (name, _) in datasets {
503 write!(&mut plot, " {:>8}", &name[..name.len().min(8)])?;
504 }
505 writeln!(&mut plot)?;
506
507 for i in 0..n {
509 write!(
510 &mut plot,
511 "{:>12}",
512 &datasets[i].0[..datasets[i].0.len().min(12)]
513 )?;
514 for j in 0..n {
515 write!(&mut plot, " {:>8.3}", correlations[i][j])?;
516 }
517 writeln!(&mut plot)?;
518 }
519 }
520 }
521
522 Ok(plot)
523}
524
525fn calculate_correlation<T>(data1: &ArrayView1<T>, data2: &ArrayView1<T>) -> f64
527where
528 T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
529{
530 if data1.len() != data2.len() || data1.len() < 2 {
531 return 0.0;
532 }
533
534 let mean1 = data1.mean_or(T::zero()).to_f64().unwrap_or(0.0);
535 let mean2 = data2.mean_or(T::zero()).to_f64().unwrap_or(0.0);
536
537 let mut sum_xy = 0.0;
538 let mut sum_x2 = 0.0;
539 let mut sum_y2 = 0.0;
540
541 for i in 0..data1.len() {
542 let x = data1[i].to_f64().unwrap_or(0.0) - mean1;
543 let y = data2[i].to_f64().unwrap_or(0.0) - mean2;
544
545 sum_xy += x * y;
546 sum_x2 += x * x;
547 sum_y2 += y * y;
548 }
549
550 let denominator = (sum_x2 * sum_y2).sqrt();
551 if denominator == 0.0 {
552 0.0
553 } else {
554 sum_xy / denominator
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Two perfectly correlated series, `1..=5` and its double, shared by
    /// several tests.
    fn ramp_and_double() -> (Array1<f64>, Array1<f64>) {
        let ramp = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let doubled = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);
        (ramp, doubled)
    }

    #[test]
    fn test_create_image_montage() {
        let zeros = Array2::zeros((5, 5));
        let ones = Array2::ones((5, 5));
        let twos = Array2::from_elem((5, 5), 2.0);
        let images = vec![zeros.view(), ones.view(), twos.view()];

        let config = PlotConfig::new()
            .with_format(ReportFormat::Text)
            .with_title("Test Montage");

        let result = create_image_montage(&images, 2, &config);
        assert!(result.is_ok());
        let montage = result.expect("Operation failed");

        let expected = [
            "Test Montage",
            "Grid layout: 2 rows × 2 columns",
            "Image 1: 5×5",
            "Image 2: 5×5",
            "Image 3: 5×5",
        ];
        for needle in expected {
            assert!(montage.contains(needle));
        }
    }

    #[test]
    fn test_plot_statistical_comparison() {
        let (ramp, doubled) = ramp_and_double();
        let datasets = vec![("Dataset A", ramp.view()), ("Dataset B", doubled.view())];

        let config = PlotConfig::new()
            .with_format(ReportFormat::Markdown)
            .with_title("Statistical Test");

        let result = plot_statistical_comparison(&datasets, &config);
        assert!(result.is_ok());
        let comparison = result.expect("Operation failed");

        let expected = [
            "Statistical Test",
            "Dataset A",
            "Dataset B",
            "| Dataset | Count | Mean",
        ];
        for needle in expected {
            assert!(comparison.contains(needle));
        }
    }

    #[test]
    fn test_calculate_dataset_statistics() {
        let (ramp, _) = ramp_and_double();
        let (mean, std_dev, min_val, max_val, count) = calculate_dataset_statistics(&ramp.view());

        assert!((mean - 3.0).abs() < 1e-6);
        assert_eq!(min_val, 1.0);
        assert_eq!(max_val, 5.0);
        assert_eq!(count, 5);
        assert!(std_dev > 0.0);
    }

    #[test]
    fn test_calculate_correlation() {
        let (ramp, doubled) = ramp_and_double();
        let corr = calculate_correlation(&ramp.view(), &doubled.view());
        assert!((corr - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_plot_correlation_matrix() {
        let (ramp, doubled) = ramp_and_double();
        let reversed = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);

        let datasets = vec![
            ("Data A", ramp.view()),
            ("Data B", doubled.view()),
            ("Data C", reversed.view()),
        ];

        let config = PlotConfig::new()
            .with_format(ReportFormat::Text)
            .with_title("Correlation Test");

        let result = plot_correlation_matrix(&datasets, &config);
        assert!(result.is_ok());
        let matrix = result.expect("Operation failed");

        for needle in ["Correlation Test", "Data A", "Data B", "Data C"] {
            assert!(matrix.contains(needle));
        }
    }

    #[test]
    fn test_empty_image_montage() {
        let images: Vec<scirs2_core::ndarray::ArrayView2<f64>> = vec![];
        let config = PlotConfig::new();

        let result = create_image_montage(&images, 2, &config);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("No images provided"));
    }

    #[test]
    fn test_zero_grid_cols() {
        let img = Array2::<f64>::zeros((5, 5));
        let images = vec![img.view()];
        let config = PlotConfig::new();

        let result = create_image_montage(&images, 0, &config);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Grid columns must be positive"));
    }
}