1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::{ClusteringError, Result};
14
15#[allow(dead_code)]
43pub fn dtw_distance<F>(
44 series1: ArrayView1<F>,
45 series2: ArrayView1<F>,
46 window: Option<usize>,
47) -> Result<F>
48where
49 F: Float + FromPrimitive + Debug + 'static,
50{
51 let n = series1.len();
52 let m = series2.len();
53
54 if n == 0 || m == 0 {
55 return Err(ClusteringError::InvalidInput(
56 "Time series cannot be empty".to_string(),
57 ));
58 }
59
60 let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
62 dtw[[0, 0]] = F::zero();
63
64 let effective_window = window.unwrap_or(m.max(n));
66
67 for i in 1..=n {
68 let start_j = if effective_window < i {
69 i - effective_window
70 } else {
71 1
72 };
73 let end_j = (i + effective_window).min(m + 1);
74
75 for j in start_j..end_j {
76 if j <= m {
77 let cost = (series1[i - 1] - series2[j - 1]).abs();
78
79 let candidates = [
80 dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]], ];
84
85 let min_prev = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
86 dtw[[i, j]] = cost + min_prev;
87 }
88 }
89 }
90
91 Ok(dtw[[n, m]])
92}
93
94#[allow(dead_code)]
109pub fn dtw_distance_custom<F, D>(
110 series1: ArrayView1<F>,
111 series2: ArrayView1<F>,
112 local_distance: D,
113 window: Option<usize>,
114) -> Result<F>
115where
116 F: Float + FromPrimitive + Debug + 'static,
117 D: Fn(F, F) -> F,
118{
119 let n = series1.len();
120 let m = series2.len();
121
122 if n == 0 || m == 0 {
123 return Err(ClusteringError::InvalidInput(
124 "Time series cannot be empty".to_string(),
125 ));
126 }
127
128 let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
129 dtw[[0, 0]] = F::zero();
130
131 let effective_window = window.unwrap_or(m.max(n));
132
133 for i in 1..=n {
134 let start_j = if effective_window < i {
135 i - effective_window
136 } else {
137 1
138 };
139 let end_j = (i + effective_window).min(m + 1);
140
141 for j in start_j..end_j {
142 if j <= m {
143 let cost = local_distance(series1[i - 1], series2[j - 1]);
144
145 let candidates = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]];
146
147 let min_prev = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
148 dtw[[i, j]] = cost + min_prev;
149 }
150 }
151 }
152
153 Ok(dtw[[n, m]])
154}
155
156#[allow(dead_code)]
172pub fn soft_dtw_distance<F>(series1: ArrayView1<F>, series2: ArrayView1<F>, gamma: F) -> Result<F>
173where
174 F: Float + FromPrimitive + Debug + 'static,
175{
176 let n = series1.len();
177 let m = series2.len();
178
179 if n == 0 || m == 0 {
180 return Err(ClusteringError::InvalidInput(
181 "Time series cannot be empty".to_string(),
182 ));
183 }
184
185 if gamma <= F::zero() {
186 return Err(ClusteringError::InvalidInput(
187 "Gamma must be positive".to_string(),
188 ));
189 }
190
191 let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
192 dtw[[0, 0]] = F::zero();
193
194 for i in 1..=n {
195 for j in 1..=m {
196 let cost = (series1[i - 1] - series2[j - 1]).powi(2);
197
198 let candidates = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]];
199
200 let min_val = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
203 let sum_exp = candidates
204 .iter()
205 .map(|&x| (-(x - min_val) / gamma).exp())
206 .fold(F::zero(), |acc, x| acc + x);
207
208 let soft_min = min_val - gamma * sum_exp.ln();
209 dtw[[i, j]] = cost + soft_min;
210 }
211 }
212
213 Ok(dtw[[n, m]])
214}
215
216#[allow(dead_code)]
233pub fn dtw_k_medoids<F>(
234 time_series: ArrayView2<F>,
235 k: usize,
236 max_iterations: usize,
237 window: Option<usize>,
238) -> Result<(Array1<usize>, Array1<usize>)>
239where
240 F: Float + FromPrimitive + Debug + 'static,
241{
242 let n_series = time_series.nrows();
243
244 if k > n_series {
245 return Err(ClusteringError::InvalidInput(
246 "Number of clusters cannot exceed number of time _series".to_string(),
247 ));
248 }
249
250 if n_series == 0 {
251 return Err(ClusteringError::InvalidInput(
252 "No time _series provided".to_string(),
253 ));
254 }
255
256 let mut medoids: Array1<usize> = Array1::from_iter(0..k);
258 let mut assignments = Array1::zeros(n_series);
259
260 for _iteration in 0..max_iterations {
261 let mut changed = false;
262
263 for i in 0..n_series {
265 let mut min_distance = F::infinity();
266 let mut best_cluster = 0;
267
268 for (cluster_id, &medoid_idx) in medoids.iter().enumerate() {
269 let distance =
270 dtw_distance(time_series.row(i), time_series.row(medoid_idx), window)?;
271
272 if distance < min_distance {
273 min_distance = distance;
274 best_cluster = cluster_id;
275 }
276 }
277
278 if assignments[i] != best_cluster {
279 assignments[i] = best_cluster;
280 changed = true;
281 }
282 }
283
284 for cluster_id in 0..k {
286 let cluster_members: Vec<usize> = assignments
287 .iter()
288 .enumerate()
289 .filter(|(_, &assignment)| assignment == cluster_id)
290 .map(|(idx, _)| idx)
291 .collect();
292
293 if !cluster_members.is_empty() {
294 let mut best_medoid = medoids[cluster_id];
295 let mut min_total_distance = F::infinity();
296
297 for &candidate in &cluster_members {
299 let mut total_distance = F::zero();
300
301 for &member in &cluster_members {
302 if candidate != member {
303 let distance = dtw_distance(
304 time_series.row(candidate),
305 time_series.row(member),
306 window,
307 )?;
308 total_distance = total_distance + distance;
309 }
310 }
311
312 if total_distance < min_total_distance {
313 min_total_distance = total_distance;
314 best_medoid = candidate;
315 }
316 }
317
318 if medoids[cluster_id] != best_medoid {
319 medoids[cluster_id] = best_medoid;
320 changed = true;
321 }
322 }
323 }
324
325 if !changed {
326 break;
327 }
328 }
329
330 Ok((medoids, assignments))
331}
332
333#[allow(dead_code)]
347pub fn dtw_hierarchical_clustering<F>(
348 time_series: ArrayView2<F>,
349 window: Option<usize>,
350) -> Result<Array2<F>>
351where
352 F: Float + FromPrimitive + Debug + 'static,
353{
354 let n_series = time_series.nrows();
355
356 if n_series < 2 {
357 return Err(ClusteringError::InvalidInput(
358 "Need at least 2 time _series for clustering".to_string(),
359 ));
360 }
361
362 let mut distances = Array2::zeros((n_series, n_series));
364 for i in 0..n_series {
365 for j in (i + 1)..n_series {
366 let distance = dtw_distance(time_series.row(i), time_series.row(j), window)?;
367 distances[[i, j]] = distance;
368 distances[[j, i]] = distance;
369 }
370 }
371
372 let mut clusters: Vec<Vec<usize>> = (0..n_series).map(|i| vec![i]).collect();
374 let mut linkage = Vec::new();
375 let mut cluster_id = n_series;
376
377 while clusters.len() > 1 {
378 let mut min_distance = F::infinity();
380 let mut merge_i = 0;
381 let mut merge_j = 1;
382
383 for i in 0..clusters.len() {
384 for j in (i + 1)..clusters.len() {
385 let mut max_dist = F::zero();
387 for &point_i in &clusters[i] {
388 for &point_j in &clusters[j] {
389 max_dist = max_dist.max(distances[[point_i, point_j]]);
390 }
391 }
392
393 if max_dist < min_distance {
394 min_distance = max_dist;
395 merge_i = i;
396 merge_j = j;
397 }
398 }
399 }
400
401 let cluster_i_size = clusters[merge_i].len();
403 let cluster_j_size = clusters[merge_j].len();
404
405 linkage.push([
406 F::from(if merge_i < n_series {
407 merge_i
408 } else {
409 n_series + merge_i
410 })
411 .unwrap(),
412 F::from(if merge_j < n_series {
413 merge_j
414 } else {
415 n_series + merge_j
416 })
417 .unwrap(),
418 min_distance,
419 F::from(cluster_i_size + cluster_j_size).unwrap(),
420 ]);
421
422 let mut new_cluster = clusters[merge_i].clone();
424 new_cluster.extend(&clusters[merge_j]);
425
426 let (first, second) = if merge_i > merge_j {
428 (merge_i, merge_j)
429 } else {
430 (merge_j, merge_i)
431 };
432
433 clusters.remove(first);
434 clusters.remove(second);
435 clusters.push(new_cluster);
436
437 cluster_id += 1;
438 }
439
440 let linkage_array =
442 Array2::from_shape_vec((linkage.len(), 4), linkage.into_iter().flatten().collect())
443 .map_err(|_| {
444 ClusteringError::ComputationError("Failed to create linkage matrix".to_string())
445 })?;
446
447 Ok(linkage_array)
448}
449
450#[allow(dead_code)]
466pub fn dtw_k_means<F>(
467 time_series: ArrayView2<F>,
468 k: usize,
469 max_iterations: usize,
470 tolerance: F,
471) -> Result<(Array2<F>, Array1<usize>)>
472where
473 F: Float + FromPrimitive + Debug + 'static,
474{
475 let n_series = time_series.nrows();
476 let series_length = time_series.ncols();
477
478 if k > n_series {
479 return Err(ClusteringError::InvalidInput(
480 "Number of clusters cannot exceed number of time _series".to_string(),
481 ));
482 }
483
484 let mut centers = Array2::zeros((k, series_length));
486 for i in 0..k {
487 centers.row_mut(i).assign(&time_series.row(i));
488 }
489
490 let mut assignments = Array1::zeros(n_series);
491
492 for _iteration in 0..max_iterations {
493 let mut changed = false;
494
495 for i in 0..n_series {
497 let mut min_distance = F::infinity();
498 let mut best_cluster = 0;
499
500 for j in 0..k {
501 let distance = dtw_distance(time_series.row(i), centers.row(j), None)?;
502
503 if distance < min_distance {
504 min_distance = distance;
505 best_cluster = j;
506 }
507 }
508
509 if assignments[i] != best_cluster {
510 assignments[i] = best_cluster;
511 changed = true;
512 }
513 }
514
515 if !changed {
516 break;
517 }
518
519 let mut center_changed = false;
521 for cluster_id in 0..k {
522 let cluster_members: Vec<usize> = assignments
523 .iter()
524 .enumerate()
525 .filter(|(_, &assignment)| assignment == cluster_id)
526 .map(|(idx, _)| idx)
527 .collect();
528
529 if !cluster_members.is_empty() {
530 let new_center = dtw_barycenter_averaging(
531 &time_series.select(Axis(0), &cluster_members),
532 10,
533 tolerance,
534 )?;
535
536 let center_distance =
537 dtw_distance(centers.row(cluster_id), new_center.view(), None)?;
538
539 if center_distance > tolerance {
540 center_changed = true;
541 }
542
543 centers.row_mut(cluster_id).assign(&new_center);
544 }
545 }
546
547 if !center_changed {
548 break;
549 }
550 }
551
552 Ok((centers, assignments))
553}
554
555#[allow(dead_code)]
570pub fn dtw_barycenter_averaging<F>(
571 time_series: &Array2<F>,
572 max_iterations: usize,
573 tolerance: F,
574) -> Result<Array1<F>>
575where
576 F: Float + FromPrimitive + Debug + 'static,
577{
578 let n_series = time_series.nrows();
579 let series_length = time_series.ncols();
580
581 if n_series == 0 {
582 return Err(ClusteringError::InvalidInput(
583 "No time _series provided".to_string(),
584 ));
585 }
586
587 if n_series == 1 {
588 return Ok(time_series.row(0).to_owned());
589 }
590
591 let mut barycenter = time_series.mean_axis(Axis(0)).unwrap();
593
594 for _iteration in 0..max_iterations {
595 let mut new_barycenter = Array1::zeros(series_length);
596 let mut weights = Array1::zeros(series_length);
597
598 for i in 0..n_series {
600 let (aligned_series, alignment_weights) =
601 dtw_align_series(time_series.row(i), barycenter.view())?;
602
603 new_barycenter = new_barycenter + aligned_series;
604 weights = weights + alignment_weights;
605 }
606
607 for i in 0..series_length {
609 if weights[i] > F::zero() {
610 new_barycenter[i] = new_barycenter[i] / weights[i];
611 }
612 }
613
614 let change = dtw_distance(barycenter.view(), new_barycenter.view(), None)?;
616 if change < tolerance {
617 break;
618 }
619
620 barycenter = new_barycenter;
621 }
622
623 Ok(barycenter)
624}
625
626#[allow(dead_code)]
628fn dtw_align_series<F>(
629 series: ArrayView1<F>,
630 reference: ArrayView1<F>,
631) -> Result<(Array1<F>, Array1<F>)>
632where
633 F: Float + FromPrimitive + Debug + 'static,
634{
635 let n = series.len();
636 let m = reference.len();
637
638 let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
640 dtw[[0, 0]] = F::zero();
641
642 for i in 1..=n {
643 for j in 1..=m {
644 let cost = (series[i - 1] - reference[j - 1]).abs();
645 let min_prev = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]]
646 .iter()
647 .fold(F::infinity(), |acc, &x| acc.min(x));
648
649 dtw[[i, j]] = cost + min_prev;
650 }
651 }
652
653 let mut i = n;
655 let mut j = m;
656 let mut aligned_series = Array1::zeros(m);
657 let mut weights = Array1::zeros(m);
658
659 while i > 0 && j > 0 {
660 aligned_series[j - 1] = aligned_series[j - 1] + series[i - 1];
662 weights[j - 1] = weights[j - 1] + F::one();
663
664 let candidates = [
666 (dtw[[i - 1, j - 1]], (i - 1, j - 1)), (dtw[[i - 1, j]], (i - 1, j)), (dtw[[i, j - 1]], (i, j - 1)), ];
670
671 let (_, (next_i, next_j)) = candidates
672 .iter()
673 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
674 .unwrap();
675
676 i = *next_i;
677 j = *next_j;
678 }
679
680 Ok((aligned_series, weights))
681}
682
683#[derive(Debug, Clone, Serialize, Deserialize)]
685pub struct TimeSeriesClusteringConfig {
686 pub algorithm: TimeSeriesAlgorithm,
688 pub n_clusters: usize,
690 pub max_iterations: usize,
692 pub tolerance: f64,
694 pub dtw_window: Option<usize>,
696 pub soft_dtw_gamma: Option<f64>,
698}
699
700#[derive(Debug, Clone, Serialize, Deserialize)]
702pub enum TimeSeriesAlgorithm {
703 DTWKMedoids,
705 DTWKMeans,
707 DTWHierarchical,
709}
710
711impl Default for TimeSeriesClusteringConfig {
712 fn default() -> Self {
713 Self {
714 algorithm: TimeSeriesAlgorithm::DTWKMedoids,
715 n_clusters: 3,
716 max_iterations: 100,
717 tolerance: 1e-4,
718 dtw_window: None,
719 soft_dtw_gamma: None,
720 }
721 }
722}
723
724#[allow(dead_code)]
735pub fn time_series_clustering<F>(
736 time_series: ArrayView2<F>,
737 config: &TimeSeriesClusteringConfig,
738) -> Result<Array1<usize>>
739where
740 F: Float + FromPrimitive + Debug + 'static,
741{
742 match config.algorithm {
743 TimeSeriesAlgorithm::DTWKMedoids => {
744 let (_, assignments) = dtw_k_medoids(
745 time_series,
746 config.n_clusters,
747 config.max_iterations,
748 config.dtw_window,
749 )?;
750 Ok(assignments)
751 }
752 TimeSeriesAlgorithm::DTWKMeans => {
753 let tolerance = F::from(config.tolerance).unwrap();
754 let (_, assignments) = dtw_k_means(
755 time_series,
756 config.n_clusters,
757 config.max_iterations,
758 tolerance,
759 )?;
760 Ok(assignments)
761 }
762 TimeSeriesAlgorithm::DTWHierarchical => {
763 let _linkage = dtw_hierarchical_clustering(time_series, config.dtw_window)?;
766
767 let n_series = time_series.nrows();
770 let mut assignments = Array1::from_iter(0..n_series);
771
772 for i in 0..n_series {
775 assignments[i] = i % config.n_clusters;
776 }
777
778 Ok(assignments)
779 }
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Four series of length 5 forming two tight groups: rows 0-1 and rows 2-3.
    fn two_group_series() -> Array2<f64> {
        let rows = vec![
            1.0, 2.0, 3.0, 2.0, 1.0, //
            1.1, 2.1, 3.1, 2.1, 1.1, //
            5.0, 6.0, 7.0, 6.0, 5.0, //
            5.1, 6.1, 7.1, 6.1, 5.1,
        ];
        Array2::from_shape_vec((4, 5), rows).unwrap()
    }

    #[test]
    fn test_dtw_distance() {
        let short = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let long = Array1::from_vec(vec![1.0, 2.0, 2.0, 3.0, 2.0, 1.0]);

        let d = dtw_distance(short.view(), long.view(), None).unwrap();
        assert!(d >= 0.0);
    }

    #[test]
    fn test_dtw_identical_series() {
        let s = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        assert_eq!(dtw_distance(s.view(), s.view(), None).unwrap(), 0.0);
    }

    #[test]
    fn test_dtw_k_medoids() {
        let data = two_group_series();

        let (medoids, labels) = dtw_k_medoids(data.view(), 2, 10, None).unwrap();

        assert_eq!(medoids.len(), 2);
        assert_eq!(labels.len(), 4);

        // Each tight pair shares a label; the two pairs differ.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_soft_dtw_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![1.0, 2.5, 3.0]);

        let d = soft_dtw_distance(a.view(), b.view(), 0.1).unwrap();
        assert!(d >= 0.0);
    }

    #[test]
    fn test_dtw_barycenter_averaging() {
        let data = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 3.0, 2.0, 1.1, 2.1, 3.1, 2.1, 0.9, 1.9, 2.9, 1.9],
        )
        .unwrap();

        let center = dtw_barycenter_averaging(&data, 10, 1e-3).unwrap();
        assert_eq!(center.len(), 4);

        // The inputs are tiny shifts of one shape, so DBA stays near the mean.
        let mean = data.mean_axis(Axis(0)).unwrap();
        for (got, want) in center.iter().zip(mean.iter()) {
            assert!((got - want).abs() < 0.5);
        }
    }

    #[test]
    fn test_time_series_clustering_config() {
        let cfg = TimeSeriesClusteringConfig::default();
        assert_eq!(cfg.n_clusters, 3);
        assert_eq!(cfg.max_iterations, 100);

        let data = two_group_series();
        let labels = time_series_clustering(data.view(), &cfg).unwrap();
        assert_eq!(labels.len(), 4);
    }
}