sklears_kernel_approximation/
time_series_kernels.rs

1//! Time series kernel approximations
2//!
3//! This module provides kernel approximation methods specifically designed for
4//! time series data, including Dynamic Time Warping (DTW), autoregressive kernels,
5//! spectral kernels, and other time-series specific kernel methods.
6
7use rayon::prelude::*;
8use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
9use scirs2_core::random::essentials::Normal as RandNormal;
10use scirs2_core::random::rngs::StdRng as RealStdRng;
11use scirs2_core::random::Distribution;
12use scirs2_core::random::{thread_rng, Rng, SeedableRng};
13use sklears_core::error::Result;
14
/// Selects which time-series kernel is approximated, together with the
/// hyper-parameters that belong to that kernel.
#[derive(Debug, Clone, PartialEq)]
pub enum TimeSeriesKernelType {
    /// Dynamic Time Warping (DTW) kernel.
    DTW {
        /// Optional window size used for the DTW alignment.
        window_size: Option<usize>,
        /// Penalty applied to insertions and deletions.
        penalty: f64,
    },
    /// Kernel based on autoregressive (AR) models.
    Autoregressive {
        /// Order of the AR model, i.e. the number of lags.
        order: usize,
        /// Regularization strength.
        lambda: f64,
    },
    /// Spectral kernel built on a Fourier transform of the series.
    Spectral {
        /// How many frequency components are used.
        n_frequencies: usize,
        /// `true` keeps magnitudes only; `false` also includes phase information.
        magnitude_only: bool,
    },
    /// Global Alignment Kernel (GAK).
    GlobalAlignment {
        /// Kernel parameter `sigma`.
        sigma: f64,
        /// Flag named `triangular`; its effect is not visible in this enum.
        triangular: bool,
    },
    /// Time Warp Edit Distance (TWED) kernel.
    TimeWarpEdit {
        /// Penalty for elastic transformation.
        nu: f64,
        /// Penalty for time warping.
        lambda: f64,
    },
    /// Kernel over sliding-window subsequences.
    Subsequence {
        /// Length of each subsequence.
        subsequence_length: usize,
        /// Step between consecutive windows.
        step_size: usize,
    },
    /// Shapelet-based kernel.
    Shapelet {
        /// Number of shapelets to extract.
        n_shapelets: usize,
        /// Smallest allowed shapelet length.
        min_length: usize,
        /// Largest allowed shapelet length.
        max_length: usize,
    },
}
66
67/// DTW distance computation configuration
68#[derive(Clone, Debug)]
69/// DTWConfig
70pub struct DTWConfig {
71    /// Window constraint for DTW alignment
72    pub window_type: DTWWindowType,
73    /// Distance metric for individual points
74    pub distance_metric: DTWDistanceMetric,
75    /// Step pattern for DTW
76    pub step_pattern: DTWStepPattern,
77    /// Whether to normalize by path length
78    pub normalize: bool,
79}
80
/// Kinds of window constraint that can restrict a DTW alignment.
#[derive(Debug, Clone, PartialEq)]
pub enum DTWWindowType {
    /// Unconstrained alignment.
    None,
    /// Sakoe-Chiba band of the given width.
    SakoeChiba {
        /// Width of the band.
        window_size: usize,
    },
    /// Itakura parallelogram.
    Itakura,
    /// User-supplied window, given as a list of index pairs.
    Custom {
        /// The index pairs that make up the window.
        window_func: Vec<(usize, usize)>,
    },
}
94
/// Point-wise distance metrics available to DTW.
#[derive(Debug, Clone, PartialEq)]
pub enum DTWDistanceMetric {
    /// Euclidean distance.
    Euclidean,
    /// Manhattan (absolute difference) distance.
    Manhattan,
    /// Cosine distance.
    Cosine,
    /// Placeholder for a custom distance function.
    Custom,
}
108
/// Step patterns a DTW warping path may follow.
#[derive(Debug, Clone, PartialEq)]
pub enum DTWStepPattern {
    /// Symmetric step pattern.
    Symmetric,
    /// Asymmetric step pattern.
    Asymmetric,
    /// User-defined step pattern.
    Custom {
        /// The steps, each a triple of two signed integers and a float.
        steps: Vec<(i32, i32, f64)>,
    },
}
120
121impl Default for DTWConfig {
122    fn default() -> Self {
123        Self {
124            window_type: DTWWindowType::None,
125            distance_metric: DTWDistanceMetric::Euclidean,
126            step_pattern: DTWStepPattern::Symmetric,
127            normalize: true,
128        }
129    }
130}
131
132/// Time series kernel configuration
133#[derive(Clone, Debug)]
134/// TimeSeriesKernelConfig
135pub struct TimeSeriesKernelConfig {
136    /// Type of time series kernel
137    pub kernel_type: TimeSeriesKernelType,
138    /// Number of random features for approximation
139    pub n_components: usize,
140    /// Random state for reproducibility
141    pub random_state: Option<u64>,
142    /// DTW-specific configuration
143    pub dtw_config: Option<DTWConfig>,
144    /// Whether to normalize time series
145    pub normalize_series: bool,
146    /// Number of parallel workers
147    pub n_workers: usize,
148}
149
150impl Default for TimeSeriesKernelConfig {
151    fn default() -> Self {
152        Self {
153            kernel_type: TimeSeriesKernelType::DTW {
154                window_size: None,
155                penalty: 0.0,
156            },
157            n_components: 100,
158            random_state: None,
159            dtw_config: Some(DTWConfig::default()),
160            normalize_series: true,
161            n_workers: num_cpus::get(),
162        }
163    }
164}
165
166/// Dynamic Time Warping kernel approximation
167pub struct DTWKernelApproximation {
168    config: TimeSeriesKernelConfig,
169    reference_series: Option<Array2<f64>>,
170    random_indices: Option<Vec<usize>>,
171    dtw_distances: Option<Array2<f64>>,
172    kernel_bandwidth: f64,
173}
174
175impl DTWKernelApproximation {
176    /// Create a new DTW kernel approximation
177    pub fn new(n_components: usize) -> Self {
178        Self {
179            config: TimeSeriesKernelConfig {
180                n_components,
181                kernel_type: TimeSeriesKernelType::DTW {
182                    window_size: None,
183                    penalty: 0.0,
184                },
185                ..Default::default()
186            },
187            reference_series: None,
188            random_indices: None,
189            dtw_distances: None,
190            kernel_bandwidth: 1.0,
191        }
192    }
193
194    /// Set kernel bandwidth
195    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
196        self.kernel_bandwidth = bandwidth;
197        self
198    }
199
200    /// Set DTW window size
201    pub fn window_size(mut self, new_window_size: Option<usize>) -> Self {
202        if let TimeSeriesKernelType::DTW {
203            ref mut window_size,
204            ..
205        } = self.config.kernel_type
206        {
207            *window_size = new_window_size;
208        }
209        self
210    }
211
212    /// Set configuration
213    pub fn with_config(mut self, config: TimeSeriesKernelConfig) -> Self {
214        self.config = config;
215        self
216    }
217
218    /// Fit the DTW kernel approximation
219    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
220        let (n_series, n_timepoints, n_features) = time_series.dim();
221
222        // Select random reference series for approximation
223        let mut rng = if let Some(seed) = self.config.random_state {
224            RealStdRng::seed_from_u64(seed)
225        } else {
226            RealStdRng::from_seed(thread_rng().gen())
227        };
228
229        let n_references = std::cmp::min(self.config.n_components, n_series);
230        let mut indices: Vec<usize> = (0..n_series).collect();
231        indices.sort_by_key(|_| rng.gen::<u32>());
232        indices.truncate(n_references);
233        self.random_indices = Some(indices.clone());
234
235        // Extract reference series
236        let mut reference_series = Array2::zeros((n_references, n_timepoints * n_features));
237        for (i, &idx) in indices.iter().enumerate() {
238            let series = time_series.slice(s![idx, .., ..]);
239            let flattened = series.into_shape((n_timepoints * n_features,)).unwrap();
240            reference_series.row_mut(i).assign(&flattened);
241        }
242        self.reference_series = Some(reference_series);
243
244        Ok(())
245    }
246
247    /// Transform time series using DTW kernel features
248    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
249        let reference_series = self.reference_series.as_ref().ok_or("Model not fitted")?;
250        let (n_series, n_timepoints, n_features) = time_series.dim();
251        let n_references = reference_series.nrows();
252
253        let mut features = Array2::zeros((n_series, n_references));
254
255        // Compute DTW distances to reference series
256        for i in 0..n_series {
257            let series = time_series.slice(s![i, .., ..]);
258            let series_flat = series.into_shape((n_timepoints * n_features,)).unwrap();
259
260            for j in 0..n_references {
261                let reference = reference_series.row(j);
262                let distance = self.compute_dtw_distance(&series_flat, &reference)?;
263
264                // Convert distance to kernel value using RBF kernel
265                let kernel_value = (-distance / (2.0 * self.kernel_bandwidth.powi(2))).exp();
266                features[[i, j]] = kernel_value;
267            }
268        }
269
270        Ok(features)
271    }
272
273    /// Compute DTW distance between two time series
274    fn compute_dtw_distance(
275        &self,
276        series1: &ArrayView1<f64>,
277        series2: &ArrayView1<f64>,
278    ) -> Result<f64> {
279        let n1 = series1.len();
280        let n2 = series2.len();
281
282        // Initialize DTW matrix
283        let mut dtw_matrix = Array2::from_elem((n1 + 1, n2 + 1), f64::INFINITY);
284        dtw_matrix[[0, 0]] = 0.0;
285
286        // Apply window constraint if specified
287        let window_constraint = self.get_window_constraint(n1, n2);
288
289        for i in 1..=n1 {
290            for j in 1..=n2 {
291                if self.is_within_window(i - 1, j - 1, &window_constraint) {
292                    let cost = self.compute_point_distance(series1[i - 1], series2[j - 1]);
293
294                    let candidates = vec![
295                        dtw_matrix[[i - 1, j]] + cost,     // Insertion
296                        dtw_matrix[[i, j - 1]] + cost,     // Deletion
297                        dtw_matrix[[i - 1, j - 1]] + cost, // Match
298                    ];
299
300                    dtw_matrix[[i, j]] = candidates
301                        .into_iter()
302                        .min_by(|a, b| a.partial_cmp(b).unwrap())
303                        .unwrap();
304                }
305            }
306        }
307
308        let distance = dtw_matrix[[n1, n2]];
309
310        // Normalize by path length if requested
311        if self
312            .config
313            .dtw_config
314            .as_ref()
315            .map_or(true, |cfg| cfg.normalize)
316        {
317            Ok(distance / (n1 + n2) as f64)
318        } else {
319            Ok(distance)
320        }
321    }
322
323    /// Get window constraint for DTW
324    fn get_window_constraint(&self, n1: usize, n2: usize) -> Option<Vec<(usize, usize)>> {
325        if let Some(dtw_config) = &self.config.dtw_config {
326            match &dtw_config.window_type {
327                DTWWindowType::SakoeChiba { window_size } => {
328                    let mut constraints = Vec::new();
329                    for i in 0..n1 {
330                        let j_start = (i as i32 - *window_size as i32).max(0) as usize;
331                        let j_end = (i + window_size).min(n2 - 1);
332                        for j in j_start..=j_end {
333                            constraints.push((i, j));
334                        }
335                    }
336                    Some(constraints)
337                }
338                DTWWindowType::Custom { window_func } => Some(window_func.clone()),
339                _ => None,
340            }
341        } else {
342            None
343        }
344    }
345
346    /// Check if point is within window constraint
347    fn is_within_window(
348        &self,
349        i: usize,
350        j: usize,
351        window_constraint: &Option<Vec<(usize, usize)>>,
352    ) -> bool {
353        match window_constraint {
354            Some(constraints) => constraints.contains(&(i, j)),
355            None => true,
356        }
357    }
358
359    /// Compute distance between two points
360    fn compute_point_distance(&self, x1: f64, x2: f64) -> f64 {
361        let metric = self
362            .config
363            .dtw_config
364            .as_ref()
365            .map(|cfg| &cfg.distance_metric)
366            .unwrap_or(&DTWDistanceMetric::Euclidean);
367
368        match metric {
369            DTWDistanceMetric::Euclidean => (x1 - x2).powi(2),
370            DTWDistanceMetric::Manhattan => (x1 - x2).abs(),
371            DTWDistanceMetric::Cosine => 1.0 - (x1 * x2) / ((x1.powi(2) + x2.powi(2)).sqrt()),
372            DTWDistanceMetric::Custom => (x1 - x2).powi(2), // Default to Euclidean
373        }
374    }
375}
376
377/// Autoregressive kernel approximation
378pub struct AutoregressiveKernelApproximation {
379    config: TimeSeriesKernelConfig,
380    ar_coefficients: Option<Array2<f64>>,
381    reference_models: Option<Vec<Array1<f64>>>,
382    random_features: Option<Array2<f64>>,
383}
384
385impl AutoregressiveKernelApproximation {
386    /// Create a new autoregressive kernel approximation
387    pub fn new(n_components: usize, order: usize) -> Self {
388        Self {
389            config: TimeSeriesKernelConfig {
390                n_components,
391                kernel_type: TimeSeriesKernelType::Autoregressive { order, lambda: 0.1 },
392                ..Default::default()
393            },
394            ar_coefficients: None,
395            reference_models: None,
396            random_features: None,
397        }
398    }
399
400    /// Set regularization parameter
401    pub fn lambda(mut self, new_lambda: f64) -> Self {
402        if let TimeSeriesKernelType::Autoregressive { ref mut lambda, .. } = self.config.kernel_type
403        {
404            *lambda = new_lambda;
405        }
406        self
407    }
408
409    /// Fit the autoregressive kernel approximation
410    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
411        let (n_series, n_timepoints, n_features) = time_series.dim();
412
413        if let TimeSeriesKernelType::Autoregressive { order, lambda } = &self.config.kernel_type {
414            // Fit AR models to each time series
415            let mut ar_coefficients = Array2::zeros((n_series, order * n_features));
416
417            for i in 0..n_series {
418                let series = time_series.slice(s![i, .., ..]);
419                let coeffs = self.fit_ar_model(&series, *order, *lambda)?;
420                ar_coefficients.row_mut(i).assign(&coeffs);
421            }
422
423            self.ar_coefficients = Some(ar_coefficients);
424
425            // Generate random features based on AR coefficients
426            self.generate_random_features()?;
427        }
428
429        Ok(())
430    }
431
432    /// Transform time series using AR kernel features
433    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
434        let ar_coefficients = self.ar_coefficients.as_ref().ok_or("Model not fitted")?;
435        let random_features = self
436            .random_features
437            .as_ref()
438            .ok_or("Random features not generated")?;
439
440        let (n_series, n_timepoints, n_features) = time_series.dim();
441        let n_components = self.config.n_components;
442
443        let mut features = Array2::zeros((n_series, n_components));
444
445        if let TimeSeriesKernelType::Autoregressive { order, lambda } = &self.config.kernel_type {
446            for i in 0..n_series {
447                let series = time_series.slice(s![i, .., ..]);
448                let coeffs = self.fit_ar_model(&series, *order, *lambda)?;
449
450                // Compute random features
451                for j in 0..n_components {
452                    let random_proj = coeffs.dot(&random_features.row(j));
453                    features[[i, j]] = random_proj.cos();
454                }
455            }
456        }
457
458        Ok(features)
459    }
460
461    /// Fit AR model to a single time series
462    fn fit_ar_model(
463        &self,
464        series: &ArrayView2<f64>,
465        order: usize,
466        lambda: f64,
467    ) -> Result<Array1<f64>> {
468        let (n_timepoints, n_features) = series.dim();
469
470        if n_timepoints <= order {
471            return Err("Time series too short for specified AR order".into());
472        }
473
474        // Create design matrix X and target vector y
475        let n_samples = n_timepoints - order;
476        let mut x_matrix = Array2::zeros((n_samples, order * n_features));
477        let mut y_vector = Array2::zeros((n_samples, n_features));
478
479        for t in order..n_timepoints {
480            let sample_idx = t - order;
481
482            // Fill design matrix with lagged values
483            for lag in 1..=order {
484                let lag_idx = t - lag;
485                for feat in 0..n_features {
486                    x_matrix[[sample_idx, (lag - 1) * n_features + feat]] = series[[lag_idx, feat]];
487                }
488            }
489
490            // Fill target vector
491            for feat in 0..n_features {
492                y_vector[[sample_idx, feat]] = series[[t, feat]];
493            }
494        }
495
496        // Solve least squares with regularization: (X^T X + λI)β = X^T y
497        let xtx = x_matrix.t().dot(&x_matrix);
498        let xtx_reg = xtx + Array2::<f64>::eye(order * n_features) * lambda;
499        let xty = x_matrix.t().dot(&y_vector);
500
501        // Simplified solution (in practice, use proper linear algebra)
502        let coeffs = self.solve_linear_system(&xtx_reg, &xty)?;
503
504        Ok(coeffs)
505    }
506
507    /// Solve linear system (simplified implementation)
508    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array2<f64>) -> Result<Array1<f64>> {
509        // This is a simplified implementation - in practice use proper solvers
510        let n = a.nrows();
511        let n_features = b.ncols();
512        let mut solution = Array1::zeros(n);
513
514        // Use diagonal approximation for simplicity - average across features
515        for i in 0..n {
516            if a[[i, i]].abs() > 1e-12 {
517                let avg_target =
518                    (0..n_features).map(|j| b[[i, j]]).sum::<f64>() / n_features as f64;
519                solution[i] = avg_target / a[[i, i]];
520            }
521        }
522
523        Ok(solution)
524    }
525
526    /// Generate random features for AR kernel approximation
527    fn generate_random_features(&mut self) -> Result<()> {
528        if let TimeSeriesKernelType::Autoregressive { .. } = &self.config.kernel_type {
529            let ar_coefficients = self.ar_coefficients.as_ref().unwrap();
530            let (_, n_ar_features) = ar_coefficients.dim();
531
532            let mut rng = if let Some(seed) = self.config.random_state {
533                RealStdRng::seed_from_u64(seed)
534            } else {
535                RealStdRng::from_seed(thread_rng().gen())
536            };
537
538            let normal = RandNormal::new(0.0, 1.0).unwrap();
539            let random_features =
540                Array2::from_shape_fn((self.config.n_components, n_ar_features), |_| {
541                    rng.sample(normal)
542                });
543
544            self.random_features = Some(random_features);
545        }
546
547        Ok(())
548    }
549}
550
551/// Spectral kernel approximation for time series
552pub struct SpectralKernelApproximation {
553    config: TimeSeriesKernelConfig,
554    frequency_features: Option<Array2<f64>>,
555    reference_spectra: Option<Array2<f64>>,
556}
557
558impl SpectralKernelApproximation {
559    /// Create a new spectral kernel approximation
560    pub fn new(n_components: usize, n_frequencies: usize) -> Self {
561        Self {
562            config: TimeSeriesKernelConfig {
563                n_components,
564                kernel_type: TimeSeriesKernelType::Spectral {
565                    n_frequencies,
566                    magnitude_only: true,
567                },
568                ..Default::default()
569            },
570            frequency_features: None,
571            reference_spectra: None,
572        }
573    }
574
575    /// Set whether to use magnitude only
576    pub fn magnitude_only(mut self, new_magnitude_only: bool) -> Self {
577        if let TimeSeriesKernelType::Spectral {
578            ref mut magnitude_only,
579            ..
580        } = self.config.kernel_type
581        {
582            *magnitude_only = new_magnitude_only;
583        }
584        self
585    }
586
587    /// Fit the spectral kernel approximation
588    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
589        let (n_series, n_timepoints, n_features) = time_series.dim();
590
591        if let TimeSeriesKernelType::Spectral {
592            n_frequencies,
593            magnitude_only,
594        } = &self.config.kernel_type
595        {
596            // Compute frequency domain representation
597            let mut spectra = Array2::zeros((n_series, *n_frequencies * n_features));
598
599            for i in 0..n_series {
600                let series = time_series.slice(s![i, .., ..]);
601                let spectrum =
602                    self.compute_frequency_features(&series, *n_frequencies, *magnitude_only)?;
603                spectra.row_mut(i).assign(&spectrum);
604            }
605
606            self.reference_spectra = Some(spectra);
607
608            // Generate random projection features
609            self.generate_spectral_features(*n_frequencies * n_features)?;
610        }
611
612        Ok(())
613    }
614
615    /// Transform time series using spectral features
616    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
617        let frequency_features = self.frequency_features.as_ref().ok_or("Model not fitted")?;
618
619        let (n_series, n_timepoints, n_features) = time_series.dim();
620        let n_components = self.config.n_components;
621
622        let mut features = Array2::zeros((n_series, n_components));
623
624        if let TimeSeriesKernelType::Spectral {
625            n_frequencies,
626            magnitude_only,
627        } = &self.config.kernel_type
628        {
629            for i in 0..n_series {
630                let series = time_series.slice(s![i, .., ..]);
631                let spectrum =
632                    self.compute_frequency_features(&series, *n_frequencies, *magnitude_only)?;
633
634                // Apply random projection
635                for j in 0..n_components {
636                    let projection = spectrum.dot(&frequency_features.row(j));
637                    features[[i, j]] = projection.cos();
638                }
639            }
640        }
641
642        Ok(features)
643    }
644
645    /// Compute frequency domain features using FFT
646    fn compute_frequency_features(
647        &self,
648        series: &ArrayView2<f64>,
649        n_frequencies: usize,
650        magnitude_only: bool,
651    ) -> Result<Array1<f64>> {
652        let (n_timepoints, n_features) = series.dim();
653        let mut features = Vec::new();
654
655        for feat in 0..n_features {
656            let signal = series.column(feat);
657
658            // Simple discrete Fourier transform approximation
659            for k in 0..n_frequencies {
660                let freq = 2.0 * std::f64::consts::PI * k as f64 / n_timepoints as f64;
661
662                let mut real_part = 0.0;
663                let mut imag_part = 0.0;
664
665                for t in 0..n_timepoints {
666                    let angle = freq * t as f64;
667                    real_part += signal[t] * angle.cos();
668                    imag_part += signal[t] * angle.sin();
669                }
670
671                if magnitude_only {
672                    features.push((real_part.powi(2) + imag_part.powi(2)).sqrt());
673                } else {
674                    features.push(real_part);
675                    features.push(imag_part);
676                }
677            }
678        }
679
680        Ok(Array1::from(features))
681    }
682
683    /// Generate random spectral features
684    fn generate_spectral_features(&mut self, n_spectrum_features: usize) -> Result<()> {
685        let mut rng = if let Some(seed) = self.config.random_state {
686            RealStdRng::seed_from_u64(seed)
687        } else {
688            RealStdRng::from_seed(thread_rng().gen())
689        };
690
691        let normal = RandNormal::new(0.0, 1.0).unwrap();
692        let frequency_features =
693            Array2::from_shape_fn((self.config.n_components, n_spectrum_features), |_| {
694                rng.sample(normal)
695            });
696
697        self.frequency_features = Some(frequency_features);
698        Ok(())
699    }
700}
701
702/// Global Alignment Kernel (GAK) approximation
703pub struct GlobalAlignmentKernelApproximation {
704    config: TimeSeriesKernelConfig,
705    reference_series: Option<Array2<f64>>,
706    sigma: f64,
707}
708
709impl GlobalAlignmentKernelApproximation {
710    /// Create a new GAK approximation
711    pub fn new(n_components: usize, sigma: f64) -> Self {
712        Self {
713            config: TimeSeriesKernelConfig {
714                n_components,
715                kernel_type: TimeSeriesKernelType::GlobalAlignment {
716                    sigma,
717                    triangular: false,
718                },
719                ..Default::default()
720            },
721            reference_series: None,
722            sigma,
723        }
724    }
725
726    /// Fit the GAK approximation
727    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
728        let (n_series, n_timepoints, n_features) = time_series.dim();
729
730        // Select random reference series
731        let mut rng = if let Some(seed) = self.config.random_state {
732            RealStdRng::seed_from_u64(seed)
733        } else {
734            RealStdRng::from_seed(thread_rng().gen())
735        };
736
737        let n_references = std::cmp::min(self.config.n_components, n_series);
738        let mut indices: Vec<usize> = (0..n_series).collect();
739        indices.sort_by_key(|_| rng.gen::<u32>());
740        indices.truncate(n_references);
741
742        let mut reference_series = Array2::zeros((n_references, n_timepoints * n_features));
743        for (i, &idx) in indices.iter().enumerate() {
744            let series = time_series.slice(s![idx, .., ..]);
745            let flattened = series.into_shape((n_timepoints * n_features,)).unwrap();
746            reference_series.row_mut(i).assign(&flattened);
747        }
748
749        self.reference_series = Some(reference_series);
750        Ok(())
751    }
752
753    /// Transform using GAK features
754    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
755        let reference_series = self.reference_series.as_ref().ok_or("Model not fitted")?;
756
757        let (n_series, n_timepoints, n_features) = time_series.dim();
758        let n_references = reference_series.nrows();
759
760        let mut features = Array2::zeros((n_series, n_references));
761
762        for i in 0..n_series {
763            let series = time_series.slice(s![i, .., ..]);
764            let series_flat = series.into_shape((n_timepoints * n_features,)).unwrap();
765
766            for j in 0..n_references {
767                let reference = reference_series.row(j);
768                let gak_value = self.compute_gak(&series_flat, &reference)?;
769                features[[i, j]] = gak_value;
770            }
771        }
772
773        Ok(features)
774    }
775
776    /// Compute Global Alignment Kernel
777    fn compute_gak(&self, series1: &ArrayView1<f64>, series2: &ArrayView1<f64>) -> Result<f64> {
778        let n1 = series1.len();
779        let n2 = series2.len();
780
781        // Initialize GAK matrix with exponential of negative squared distances
782        let mut gak_matrix = Array2::zeros((n1 + 1, n2 + 1));
783
784        for i in 1..=n1 {
785            for j in 1..=n2 {
786                let dist_sq = (series1[i - 1] - series2[j - 1]).powi(2);
787                let kernel_val = (-dist_sq / (2.0 * self.sigma.powi(2))).exp();
788
789                // GAK recurrence relation
790                let max_alignment = vec![
791                    gak_matrix[[i - 1, j]] * kernel_val,
792                    gak_matrix[[i, j - 1]] * kernel_val,
793                    gak_matrix[[i - 1, j - 1]] * kernel_val,
794                ]
795                .into_iter()
796                .max_by(|a: &f64, b: &f64| a.partial_cmp(b).unwrap())
797                .unwrap_or(0.0);
798
799                gak_matrix[[i, j]] = max_alignment;
800            }
801        }
802
803        Ok(gak_matrix[[n1, n2]])
804    }
805}
806
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builds five bivariate series of ten samples each: channel 0 holds
    /// `sin(t + i)` and channel 1 holds `cos(t + i)`.
    fn create_test_time_series() -> Array3<f64> {
        Array3::from_shape_fn((5, 10, 2), |(i, t, channel)| {
            let phase = t as f64 + i as f64;
            if channel == 0 {
                phase.sin()
            } else {
                phase.cos()
            }
        })
    }

    /// DTW features have one column per reference and lie in `[0, 1]`.
    #[test]
    fn test_dtw_kernel_approximation() {
        let data = create_test_time_series();

        let mut model = DTWKernelApproximation::new(3)
            .bandwidth(1.0)
            .window_size(Some(2));
        model.fit(&data).unwrap();

        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[5, 3]);
        // RBF kernel values are bounded by 0 and 1.
        assert!(features.iter().all(|value| (0.0..=1.0).contains(value)));
    }

    /// AR features have one column per component.
    #[test]
    fn test_autoregressive_kernel_approximation() {
        let data = create_test_time_series();

        let mut model = AutoregressiveKernelApproximation::new(4, 2).lambda(0.1);
        model.fit(&data).unwrap();

        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[5, 4]);
    }

    /// Spectral features have one column per component.
    #[test]
    fn test_spectral_kernel_approximation() {
        let data = create_test_time_series();

        let mut model = SpectralKernelApproximation::new(6, 5).magnitude_only(true);
        model.fit(&data).unwrap();

        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[5, 6]);
    }

    /// GAK features have one column per reference and are non-negative.
    #[test]
    fn test_global_alignment_kernel() {
        let data = create_test_time_series();

        let mut model = GlobalAlignmentKernelApproximation::new(3, 1.0);
        model.fit(&data).unwrap();

        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[5, 3]);
        assert!(features.iter().all(|&value| value >= 0.0));
    }

    /// The DTW distance of two short series of different length is finite
    /// and non-negative.
    #[test]
    fn test_dtw_distance_computation() {
        let first = Array1::from(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let second = Array1::from(vec![1.0, 3.0, 2.0, 1.0]);

        let model = DTWKernelApproximation::new(1);
        let distance = model
            .compute_dtw_distance(&first.view(), &second.view())
            .unwrap();

        assert!(distance >= 0.0);
        assert!(distance.is_finite());
    }

    /// An AR(2) fit on two features yields four coefficients.
    #[test]
    fn test_ar_model_fitting() {
        let data = create_test_time_series();
        let first_series = data.slice(s![0, .., ..]);

        let model = AutoregressiveKernelApproximation::new(10, 2);
        let coeffs = model.fit_ar_model(&first_series, 2, 0.1).unwrap();

        // 2 lags * 2 features
        assert_eq!(coeffs.len(), 4);
    }

    /// Magnitude spectra have one non-negative value per frequency and feature.
    #[test]
    fn test_frequency_features() {
        let data = create_test_time_series();
        let first_series = data.slice(s![0, .., ..]);

        let model = SpectralKernelApproximation::new(10, 5);
        let spectrum = model
            .compute_frequency_features(&first_series, 5, true)
            .unwrap();

        // 5 frequencies * 2 features
        assert_eq!(spectrum.len(), 10);
        // Magnitudes are non-negative
        assert!(spectrum.iter().all(|&value| value >= 0.0));
    }

    /// The default configuration selects a DTW kernel with 100 components.
    #[test]
    fn test_time_series_kernel_config() {
        let config = TimeSeriesKernelConfig::default();

        assert_eq!(config.n_components, 100);
        assert!(matches!(
            config.kernel_type,
            TimeSeriesKernelType::DTW { .. }
        ));
        assert!(config.normalize_series);
        assert!(config.dtw_config.is_some());
    }

    /// No window is configured by default, and explicit windows are honoured
    /// by the membership check.
    #[test]
    fn test_dtw_window_constraints() {
        let model = DTWKernelApproximation::new(1);

        // No window constraint by default
        assert!(model.get_window_constraint(5, 5).is_none());

        let diagonal = Some(vec![(0, 0), (1, 1), (2, 2)]);
        assert!(model.is_within_window(1, 1, &diagonal));
        assert!(!model.is_within_window(0, 2, &diagonal));
    }

    /// Two models built with the same seed produce the same features.
    #[test]
    fn test_reproducibility_with_random_state() {
        let data = create_test_time_series();

        let seeded_config = || TimeSeriesKernelConfig {
            random_state: Some(42),
            ..Default::default()
        };

        let mut first = DTWKernelApproximation::new(3).with_config(seeded_config());
        let mut second = DTWKernelApproximation::new(3).with_config(seeded_config());

        first.fit(&data).unwrap();
        second.fit(&data).unwrap();

        let features_first = first.transform(&data).unwrap();
        let features_second = second.transform(&data).unwrap();

        // Should be approximately equal due to same random state
        let values_first = features_first.as_slice().unwrap();
        let values_second = features_second.as_slice().unwrap();
        for i in 0..features_first.len() {
            assert!((values_first[i] - values_second[i]).abs() < 1e-10);
        }
    }
}