scirs2_cluster/
time_series.rs

1//! Time series clustering algorithms with specialized distance metrics
2//!
3//! This module provides clustering algorithms specifically designed for time series data,
4//! including dynamic time warping (DTW) distance and other temporal similarity measures.
5//! These algorithms can handle time series of different lengths and temporal alignments.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::{ClusteringError, Result};
14
15/// Dynamic Time Warping (DTW) distance between two time series
16///
17/// DTW finds the optimal alignment between two time series by minimizing
18/// the cumulative distance between aligned points. It can handle series
19/// of different lengths and temporal distortions.
20///
21/// # Arguments
22///
23/// * `series1` - First time series
24/// * `series2` - Second time series
25/// * `window` - Sakoe-Chiba band constraint (None for no constraint)
26///
27/// # Returns
28///
29/// DTW distance between the two series
30///
31/// # Example
32///
33/// ```
34/// use scirs2_core::ndarray::Array1;
35/// use scirs2_cluster::time_series::dtw_distance;
36///
37/// let series1 = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
38/// let series2 = Array1::from_vec(vec![1.0, 2.0, 2.0, 3.0, 2.0, 1.0]);
39///
40/// let distance = dtw_distance(series1.view(), series2.view(), None).unwrap();
41/// ```
42#[allow(dead_code)]
43pub fn dtw_distance<F>(
44    series1: ArrayView1<F>,
45    series2: ArrayView1<F>,
46    window: Option<usize>,
47) -> Result<F>
48where
49    F: Float + FromPrimitive + Debug + 'static,
50{
51    let n = series1.len();
52    let m = series2.len();
53
54    if n == 0 || m == 0 {
55        return Err(ClusteringError::InvalidInput(
56            "Time series cannot be empty".to_string(),
57        ));
58    }
59
60    // Initialize DTW matrix with infinity
61    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
62    dtw[[0, 0]] = F::zero();
63
64    // Apply Sakoe-Chiba band constraint if specified
65    let effective_window = window.unwrap_or(m.max(n));
66
67    for i in 1..=n {
68        let start_j = if effective_window < i {
69            i - effective_window
70        } else {
71            1
72        };
73        let end_j = (i + effective_window).min(m + 1);
74
75        for j in start_j..end_j {
76            if j <= m {
77                let cost = (series1[i - 1] - series2[j - 1]).abs();
78
79                let candidates = [
80                    dtw[[i - 1, j]],     // Insertion
81                    dtw[[i, j - 1]],     // Deletion
82                    dtw[[i - 1, j - 1]], // Match
83                ];
84
85                let min_prev = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
86                dtw[[i, j]] = cost + min_prev;
87            }
88        }
89    }
90
91    Ok(dtw[[n, m]])
92}
93
94/// DTW distance with custom local distance function
95///
96/// Allows using custom distance functions for comparing individual time points.
97///
98/// # Arguments
99///
100/// * `series1` - First time series
101/// * `series2` - Second time series
102/// * `local_distance` - Function to compute distance between individual points
103/// * `window` - Sakoe-Chiba band constraint
104///
105/// # Returns
106///
107/// DTW distance using the custom local distance function
108#[allow(dead_code)]
109pub fn dtw_distance_custom<F, D>(
110    series1: ArrayView1<F>,
111    series2: ArrayView1<F>,
112    local_distance: D,
113    window: Option<usize>,
114) -> Result<F>
115where
116    F: Float + FromPrimitive + Debug + 'static,
117    D: Fn(F, F) -> F,
118{
119    let n = series1.len();
120    let m = series2.len();
121
122    if n == 0 || m == 0 {
123        return Err(ClusteringError::InvalidInput(
124            "Time series cannot be empty".to_string(),
125        ));
126    }
127
128    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
129    dtw[[0, 0]] = F::zero();
130
131    let effective_window = window.unwrap_or(m.max(n));
132
133    for i in 1..=n {
134        let start_j = if effective_window < i {
135            i - effective_window
136        } else {
137            1
138        };
139        let end_j = (i + effective_window).min(m + 1);
140
141        for j in start_j..end_j {
142            if j <= m {
143                let cost = local_distance(series1[i - 1], series2[j - 1]);
144
145                let candidates = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]];
146
147                let min_prev = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
148                dtw[[i, j]] = cost + min_prev;
149            }
150        }
151    }
152
153    Ok(dtw[[n, m]])
154}
155
156/// Soft DTW distance for differentiable time series clustering
157///
158/// Soft DTW is a differentiable version of DTW that uses a soft minimum
159/// operation instead of hard minimum, making it suitable for gradient-based
160/// optimization.
161///
162/// # Arguments
163///
164/// * `series1` - First time series
165/// * `series2` - Second time series
166/// * `gamma` - Smoothing parameter (smaller values approach standard DTW)
167///
168/// # Returns
169///
170/// Soft DTW distance
171#[allow(dead_code)]
172pub fn soft_dtw_distance<F>(series1: ArrayView1<F>, series2: ArrayView1<F>, gamma: F) -> Result<F>
173where
174    F: Float + FromPrimitive + Debug + 'static,
175{
176    let n = series1.len();
177    let m = series2.len();
178
179    if n == 0 || m == 0 {
180        return Err(ClusteringError::InvalidInput(
181            "Time series cannot be empty".to_string(),
182        ));
183    }
184
185    if gamma <= F::zero() {
186        return Err(ClusteringError::InvalidInput(
187            "Gamma must be positive".to_string(),
188        ));
189    }
190
191    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
192    dtw[[0, 0]] = F::zero();
193
194    for i in 1..=n {
195        for j in 1..=m {
196            let cost = (series1[i - 1] - series2[j - 1]).powi(2);
197
198            let candidates = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]];
199
200            // Soft minimum: -gamma * log(sum(exp(-x/gamma)))
201            // For numerical stability, use the log-sum-exp trick
202            let min_val = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
203            let sum_exp = candidates
204                .iter()
205                .map(|&x| (-(x - min_val) / gamma).exp())
206                .fold(F::zero(), |acc, x| acc + x);
207
208            let soft_min = min_val - gamma * sum_exp.ln();
209            dtw[[i, j]] = cost + soft_min;
210        }
211    }
212
213    Ok(dtw[[n, m]])
214}
215
216/// Time series clustering using k-medoids with DTW distance
217///
218/// Performs k-medoids clustering on time series data using DTW as the
219/// distance metric. This is more robust than k-means for time series
220/// as it uses actual time series as cluster centers.
221///
222/// # Arguments
223///
224/// * `time_series` - Matrix where each row is a time series
225/// * `k` - Number of clusters
226/// * `max_iterations` - Maximum number of iterations
227/// * `window` - DTW constraint window
228///
229/// # Returns
230///
231/// Tuple of (medoid_indices, cluster_assignments)
232#[allow(dead_code)]
233pub fn dtw_k_medoids<F>(
234    time_series: ArrayView2<F>,
235    k: usize,
236    max_iterations: usize,
237    window: Option<usize>,
238) -> Result<(Array1<usize>, Array1<usize>)>
239where
240    F: Float + FromPrimitive + Debug + 'static,
241{
242    let n_series = time_series.nrows();
243
244    if k > n_series {
245        return Err(ClusteringError::InvalidInput(
246            "Number of clusters cannot exceed number of time _series".to_string(),
247        ));
248    }
249
250    if n_series == 0 {
251        return Err(ClusteringError::InvalidInput(
252            "No time _series provided".to_string(),
253        ));
254    }
255
256    // Initialize medoids randomly (for deterministic results, use first k series)
257    let mut medoids: Array1<usize> = Array1::from_iter(0..k);
258    let mut assignments = Array1::zeros(n_series);
259
260    for _iteration in 0..max_iterations {
261        let mut changed = false;
262
263        // Assign each time _series to nearest medoid
264        for i in 0..n_series {
265            let mut min_distance = F::infinity();
266            let mut best_cluster = 0;
267
268            for (cluster_id, &medoid_idx) in medoids.iter().enumerate() {
269                let distance =
270                    dtw_distance(time_series.row(i), time_series.row(medoid_idx), window)?;
271
272                if distance < min_distance {
273                    min_distance = distance;
274                    best_cluster = cluster_id;
275                }
276            }
277
278            if assignments[i] != best_cluster {
279                assignments[i] = best_cluster;
280                changed = true;
281            }
282        }
283
284        // Update medoids
285        for cluster_id in 0..k {
286            let cluster_members: Vec<usize> = assignments
287                .iter()
288                .enumerate()
289                .filter(|(_, &assignment)| assignment == cluster_id)
290                .map(|(idx, _)| idx)
291                .collect();
292
293            if !cluster_members.is_empty() {
294                let mut best_medoid = medoids[cluster_id];
295                let mut min_total_distance = F::infinity();
296
297                // Try each member as potential medoid
298                for &candidate in &cluster_members {
299                    let mut total_distance = F::zero();
300
301                    for &member in &cluster_members {
302                        if candidate != member {
303                            let distance = dtw_distance(
304                                time_series.row(candidate),
305                                time_series.row(member),
306                                window,
307                            )?;
308                            total_distance = total_distance + distance;
309                        }
310                    }
311
312                    if total_distance < min_total_distance {
313                        min_total_distance = total_distance;
314                        best_medoid = candidate;
315                    }
316                }
317
318                if medoids[cluster_id] != best_medoid {
319                    medoids[cluster_id] = best_medoid;
320                    changed = true;
321                }
322            }
323        }
324
325        if !changed {
326            break;
327        }
328    }
329
330    Ok((medoids, assignments))
331}
332
333/// Hierarchical clustering for time series using DTW distance
334///
335/// Performs agglomerative hierarchical clustering using DTW distance
336/// with complete linkage.
337///
338/// # Arguments
339///
340/// * `time_series` - Matrix where each row is a time series
341/// * `window` - DTW constraint window
342///
343/// # Returns
344///
345/// Linkage matrix in the format compatible with scipy.cluster.hierarchy
346#[allow(dead_code)]
347pub fn dtw_hierarchical_clustering<F>(
348    time_series: ArrayView2<F>,
349    window: Option<usize>,
350) -> Result<Array2<F>>
351where
352    F: Float + FromPrimitive + Debug + 'static,
353{
354    let n_series = time_series.nrows();
355
356    if n_series < 2 {
357        return Err(ClusteringError::InvalidInput(
358            "Need at least 2 time _series for clustering".to_string(),
359        ));
360    }
361
362    // Compute distance matrix
363    let mut distances = Array2::zeros((n_series, n_series));
364    for i in 0..n_series {
365        for j in (i + 1)..n_series {
366            let distance = dtw_distance(time_series.row(i), time_series.row(j), window)?;
367            distances[[i, j]] = distance;
368            distances[[j, i]] = distance;
369        }
370    }
371
372    // Initialize clusters (each point is its own cluster initially)
373    let mut clusters: Vec<Vec<usize>> = (0..n_series).map(|i| vec![i]).collect();
374    let mut linkage = Vec::new();
375    let mut cluster_id = n_series;
376
377    while clusters.len() > 1 {
378        // Find closest pair of clusters
379        let mut min_distance = F::infinity();
380        let mut merge_i = 0;
381        let mut merge_j = 1;
382
383        for i in 0..clusters.len() {
384            for j in (i + 1)..clusters.len() {
385                // Calculate complete linkage distance
386                let mut max_dist = F::zero();
387                for &point_i in &clusters[i] {
388                    for &point_j in &clusters[j] {
389                        max_dist = max_dist.max(distances[[point_i, point_j]]);
390                    }
391                }
392
393                if max_dist < min_distance {
394                    min_distance = max_dist;
395                    merge_i = i;
396                    merge_j = j;
397                }
398            }
399        }
400
401        // Record the merge
402        let cluster_i_size = clusters[merge_i].len();
403        let cluster_j_size = clusters[merge_j].len();
404
405        linkage.push([
406            F::from(if merge_i < n_series {
407                merge_i
408            } else {
409                n_series + merge_i
410            })
411            .unwrap(),
412            F::from(if merge_j < n_series {
413                merge_j
414            } else {
415                n_series + merge_j
416            })
417            .unwrap(),
418            min_distance,
419            F::from(cluster_i_size + cluster_j_size).unwrap(),
420        ]);
421
422        // Merge clusters
423        let mut new_cluster = clusters[merge_i].clone();
424        new_cluster.extend(&clusters[merge_j]);
425
426        // Remove old clusters (remove higher index first)
427        let (first, second) = if merge_i > merge_j {
428            (merge_i, merge_j)
429        } else {
430            (merge_j, merge_i)
431        };
432
433        clusters.remove(first);
434        clusters.remove(second);
435        clusters.push(new_cluster);
436
437        cluster_id += 1;
438    }
439
440    // Convert to ndarray
441    let linkage_array =
442        Array2::from_shape_vec((linkage.len(), 4), linkage.into_iter().flatten().collect())
443            .map_err(|_| {
444                ClusteringError::ComputationError("Failed to create linkage matrix".to_string())
445            })?;
446
447    Ok(linkage_array)
448}
449
450/// Time series k-means clustering with DTW barycenter averaging
451///
452/// Performs k-means clustering where cluster centers are computed as
453/// DTW barycenters (average time series under DTW alignment).
454///
455/// # Arguments
456///
457/// * `time_series` - Matrix where each row is a time series
458/// * `k` - Number of clusters
459/// * `max_iterations` - Maximum number of iterations
460/// * `tolerance` - Convergence tolerance
461///
462/// # Returns
463///
464/// Tuple of (cluster_centers, cluster_assignments)
465#[allow(dead_code)]
466pub fn dtw_k_means<F>(
467    time_series: ArrayView2<F>,
468    k: usize,
469    max_iterations: usize,
470    tolerance: F,
471) -> Result<(Array2<F>, Array1<usize>)>
472where
473    F: Float + FromPrimitive + Debug + 'static,
474{
475    let n_series = time_series.nrows();
476    let series_length = time_series.ncols();
477
478    if k > n_series {
479        return Err(ClusteringError::InvalidInput(
480            "Number of clusters cannot exceed number of time _series".to_string(),
481        ));
482    }
483
484    // Initialize centers with first k time _series
485    let mut centers = Array2::zeros((k, series_length));
486    for i in 0..k {
487        centers.row_mut(i).assign(&time_series.row(i));
488    }
489
490    let mut assignments = Array1::zeros(n_series);
491
492    for _iteration in 0..max_iterations {
493        let mut changed = false;
494
495        // Assign each time _series to nearest center
496        for i in 0..n_series {
497            let mut min_distance = F::infinity();
498            let mut best_cluster = 0;
499
500            for j in 0..k {
501                let distance = dtw_distance(time_series.row(i), centers.row(j), None)?;
502
503                if distance < min_distance {
504                    min_distance = distance;
505                    best_cluster = j;
506                }
507            }
508
509            if assignments[i] != best_cluster {
510                assignments[i] = best_cluster;
511                changed = true;
512            }
513        }
514
515        if !changed {
516            break;
517        }
518
519        // Update centers using DTW barycenter averaging
520        let mut center_changed = false;
521        for cluster_id in 0..k {
522            let cluster_members: Vec<usize> = assignments
523                .iter()
524                .enumerate()
525                .filter(|(_, &assignment)| assignment == cluster_id)
526                .map(|(idx, _)| idx)
527                .collect();
528
529            if !cluster_members.is_empty() {
530                let new_center = dtw_barycenter_averaging(
531                    &time_series.select(Axis(0), &cluster_members),
532                    10,
533                    tolerance,
534                )?;
535
536                let center_distance =
537                    dtw_distance(centers.row(cluster_id), new_center.view(), None)?;
538
539                if center_distance > tolerance {
540                    center_changed = true;
541                }
542
543                centers.row_mut(cluster_id).assign(&new_center);
544            }
545        }
546
547        if !center_changed {
548            break;
549        }
550    }
551
552    Ok((centers, assignments))
553}
554
555/// Compute DTW barycenter (average time series) using iterative refinement
556///
557/// The DTW barycenter is the time series that minimizes the sum of squared
558/// DTW distances to all input time series.
559///
560/// # Arguments
561///
562/// * `time_series` - Collection of time series to average
563/// * `max_iterations` - Maximum number of refinement iterations
564/// * `tolerance` - Convergence tolerance
565///
566/// # Returns
567///
568/// Barycenter time series
569#[allow(dead_code)]
570pub fn dtw_barycenter_averaging<F>(
571    time_series: &Array2<F>,
572    max_iterations: usize,
573    tolerance: F,
574) -> Result<Array1<F>>
575where
576    F: Float + FromPrimitive + Debug + 'static,
577{
578    let n_series = time_series.nrows();
579    let series_length = time_series.ncols();
580
581    if n_series == 0 {
582        return Err(ClusteringError::InvalidInput(
583            "No time _series provided".to_string(),
584        ));
585    }
586
587    if n_series == 1 {
588        return Ok(time_series.row(0).to_owned());
589    }
590
591    // Initialize barycenter as the mean of all _series
592    let mut barycenter = time_series.mean_axis(Axis(0)).unwrap();
593
594    for _iteration in 0..max_iterations {
595        let mut new_barycenter = Array1::zeros(series_length);
596        let mut weights = Array1::zeros(series_length);
597
598        // For each time series, find optimal alignment with current barycenter
599        for i in 0..n_series {
600            let (aligned_series, alignment_weights) =
601                dtw_align_series(time_series.row(i), barycenter.view())?;
602
603            new_barycenter = new_barycenter + aligned_series;
604            weights = weights + alignment_weights;
605        }
606
607        // Normalize by weights
608        for i in 0..series_length {
609            if weights[i] > F::zero() {
610                new_barycenter[i] = new_barycenter[i] / weights[i];
611            }
612        }
613
614        // Check convergence
615        let change = dtw_distance(barycenter.view(), new_barycenter.view(), None)?;
616        if change < tolerance {
617            break;
618        }
619
620        barycenter = new_barycenter;
621    }
622
623    Ok(barycenter)
624}
625
626/// Align a time series with a reference using DTW and return weighted average
627#[allow(dead_code)]
628fn dtw_align_series<F>(
629    series: ArrayView1<F>,
630    reference: ArrayView1<F>,
631) -> Result<(Array1<F>, Array1<F>)>
632where
633    F: Float + FromPrimitive + Debug + 'static,
634{
635    let n = series.len();
636    let m = reference.len();
637
638    // Compute DTW matrix
639    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
640    dtw[[0, 0]] = F::zero();
641
642    for i in 1..=n {
643        for j in 1..=m {
644            let cost = (series[i - 1] - reference[j - 1]).abs();
645            let min_prev = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]]
646                .iter()
647                .fold(F::infinity(), |acc, &x| acc.min(x));
648
649            dtw[[i, j]] = cost + min_prev;
650        }
651    }
652
653    // Backtrack to find optimal path
654    let mut i = n;
655    let mut j = m;
656    let mut aligned_series = Array1::zeros(m);
657    let mut weights = Array1::zeros(m);
658
659    while i > 0 && j > 0 {
660        // Add current series value to aligned position
661        aligned_series[j - 1] = aligned_series[j - 1] + series[i - 1];
662        weights[j - 1] = weights[j - 1] + F::one();
663
664        // Find which direction we came from
665        let candidates = [
666            (dtw[[i - 1, j - 1]], (i - 1, j - 1)), // diagonal
667            (dtw[[i - 1, j]], (i - 1, j)),         // up
668            (dtw[[i, j - 1]], (i, j - 1)),         // left
669        ];
670
671        let (_, (next_i, next_j)) = candidates
672            .iter()
673            .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
674            .unwrap();
675
676        i = *next_i;
677        j = *next_j;
678    }
679
680    Ok((aligned_series, weights))
681}
682
/// Configuration for time series clustering algorithms
///
/// Consumed by `time_series_clustering`, which dispatches on `algorithm`.
/// Not every field is read by every algorithm; see the notes on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeSeriesClusteringConfig {
    /// Algorithm to use for clustering
    pub algorithm: TimeSeriesAlgorithm,
    /// Number of clusters
    pub n_clusters: usize,
    /// Maximum number of iterations (passed to the k-medoids and k-means paths)
    pub max_iterations: usize,
    /// Convergence tolerance (converted to the series' float type and passed
    /// to the k-means path only)
    pub tolerance: f64,
    /// DTW constraint window size (`None` means unconstrained). Passed to the
    /// k-medoids and hierarchical paths; `dtw_k_means` takes no window argument.
    pub dtw_window: Option<usize>,
    /// Soft DTW gamma parameter.
    /// NOTE(review): not read by any code in this file; presumably reserved
    /// for a soft-DTW based algorithm — confirm.
    pub soft_dtw_gamma: Option<f64>,
}
699
700/// Available time series clustering algorithms
701#[derive(Debug, Clone, Serialize, Deserialize)]
702pub enum TimeSeriesAlgorithm {
703    /// K-medoids with DTW distance
704    DTWKMedoids,
705    /// K-means with DTW barycenter averaging
706    DTWKMeans,
707    /// Hierarchical clustering with DTW distance
708    DTWHierarchical,
709}
710
711impl Default for TimeSeriesClusteringConfig {
712    fn default() -> Self {
713        Self {
714            algorithm: TimeSeriesAlgorithm::DTWKMedoids,
715            n_clusters: 3,
716            max_iterations: 100,
717            tolerance: 1e-4,
718            dtw_window: None,
719            soft_dtw_gamma: None,
720        }
721    }
722}
723
724/// Perform time series clustering using the specified configuration
725///
726/// # Arguments
727///
728/// * `time_series` - Matrix where each row is a time series
729/// * `config` - Clustering configuration
730///
731/// # Returns
732///
733/// Cluster assignments for each time series
734#[allow(dead_code)]
735pub fn time_series_clustering<F>(
736    time_series: ArrayView2<F>,
737    config: &TimeSeriesClusteringConfig,
738) -> Result<Array1<usize>>
739where
740    F: Float + FromPrimitive + Debug + 'static,
741{
742    match config.algorithm {
743        TimeSeriesAlgorithm::DTWKMedoids => {
744            let (_, assignments) = dtw_k_medoids(
745                time_series,
746                config.n_clusters,
747                config.max_iterations,
748                config.dtw_window,
749            )?;
750            Ok(assignments)
751        }
752        TimeSeriesAlgorithm::DTWKMeans => {
753            let tolerance = F::from(config.tolerance).unwrap();
754            let (_, assignments) = dtw_k_means(
755                time_series,
756                config.n_clusters,
757                config.max_iterations,
758                tolerance,
759            )?;
760            Ok(assignments)
761        }
762        TimeSeriesAlgorithm::DTWHierarchical => {
763            // For hierarchical clustering, we need to cut the dendrogram
764            // This is a simplified version that returns the first n_clusters
765            let _linkage = dtw_hierarchical_clustering(time_series, config.dtw_window)?;
766
767            // Simple flat clustering: assign based on first few merges
768            // In practice, you'd want to implement proper dendrogram cutting
769            let n_series = time_series.nrows();
770            let mut assignments = Array1::from_iter(0..n_series);
771
772            // This is a simplified assignment - a proper implementation would
773            // cut the dendrogram at the appropriate level
774            for i in 0..n_series {
775                assignments[i] = i % config.n_clusters;
776            }
777
778            Ok(assignments)
779        }
780    }
781}
782
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Four series of length 5: rows 0-1 hover around 1..3, rows 2-3 around 5..7.
    fn two_group_fixture() -> Array2<f64> {
        let values = vec![
            1.0, 2.0, 3.0, 2.0, 1.0, // row 0
            1.1, 2.1, 3.1, 2.1, 1.1, // row 1
            5.0, 6.0, 7.0, 6.0, 5.0, // row 2
            5.1, 6.1, 7.1, 6.1, 5.1, // row 3
        ];
        Array2::from_shape_vec((4, 5), values).unwrap()
    }

    // DTW between series of different lengths is non-negative.
    #[test]
    fn test_dtw_distance() {
        let shorter = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let longer = Array1::from_vec(vec![1.0, 2.0, 2.0, 3.0, 2.0, 1.0]);

        let distance: f64 = dtw_distance(shorter.view(), longer.view(), None).unwrap();
        assert!(distance >= 0.0);
    }

    // A series is at distance zero from itself.
    #[test]
    fn test_dtw_identical_series() {
        let series = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let self_distance: f64 = dtw_distance(series.view(), series.view(), None).unwrap();
        assert_eq!(self_distance, 0.0);
    }

    // Two well-separated groups end up in two different clusters.
    #[test]
    fn test_dtw_k_medoids() {
        let data = two_group_fixture();

        let (medoids, labels) = dtw_k_medoids(data.view(), 2, 10, None).unwrap();

        assert_eq!(medoids.len(), 2);
        assert_eq!(labels.len(), 4);

        // First two series should be in one cluster, last two in another
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    // Soft DTW on two close series with a small gamma stays non-negative.
    #[test]
    fn test_soft_dtw_distance() {
        let left = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let right = Array1::from_vec(vec![1.0, 2.5, 3.0]);

        let distance: f64 = soft_dtw_distance(left.view(), right.view(), 0.1).unwrap();
        assert!(distance >= 0.0);
    }

    // The barycenter of three nearly identical series stays near their mean.
    #[test]
    fn test_dtw_barycenter_averaging() {
        let values = vec![1.0, 2.0, 3.0, 2.0, 1.1, 2.1, 3.1, 2.1, 0.9, 1.9, 2.9, 1.9];
        let data: Array2<f64> = Array2::from_shape_vec((3, 4), values).unwrap();

        let barycenter = dtw_barycenter_averaging(&data, 10, 1e-3).unwrap();
        assert_eq!(barycenter.len(), 4);

        // Barycenter should be close to the mean
        let mean_series = data.mean_axis(Axis(0)).unwrap();
        for (&center, &mean) in barycenter.iter().zip(mean_series.iter()) {
            assert!((center - mean).abs() < 0.5);
        }
    }

    // The default configuration is usable end to end.
    #[test]
    fn test_time_series_clustering_config() {
        let config = TimeSeriesClusteringConfig::default();
        assert_eq!(config.n_clusters, 3);
        assert_eq!(config.max_iterations, 100);

        let data = two_group_fixture();

        let labels = time_series_clustering(data.view(), &config).unwrap();
        assert_eq!(labels.len(), 4);
    }
}