sklears_kernel_approximation/
time_series_kernels.rs

1//! Time series kernel approximations
2//!
3//! This module provides kernel approximation methods specifically designed for
4//! time series data, including Dynamic Time Warping (DTW), autoregressive kernels,
5//! spectral kernels, and other time-series specific kernel methods.
6
7use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
8use scirs2_core::random::essentials::Normal as RandNormal;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::error::Result;
13
/// Kind of time series kernel to approximate, together with its parameters.
#[derive(Clone, Debug, PartialEq)]
pub enum TimeSeriesKernelType {
    /// Dynamic Time Warping (DTW) kernel.
    DTW {
        /// Window size for the DTW alignment; `None` means no window.
        window_size: Option<usize>,
        /// Cost charged for insertions/deletions.
        penalty: f64,
    },
    /// Kernel built on fitted autoregressive (AR) models.
    Autoregressive {
        /// Number of lags in the AR model.
        order: usize,
        /// Ridge regularization strength used when fitting the AR model.
        lambda: f64,
    },
    /// Kernel on Fourier (frequency-domain) representations.
    Spectral {
        /// How many frequency components to keep.
        n_frequencies: usize,
        /// `true` keeps only magnitudes; `false` keeps real and imaginary parts.
        magnitude_only: bool,
    },
    /// Global Alignment Kernel (GAK).
    GlobalAlignment {
        /// Bandwidth of the local kernel between aligned points.
        sigma: f64,
        /// Presumably selects the triangular GAK variant — NOTE(review): meaning
        /// inferred from the name only; confirm before relying on it.
        triangular: bool,
    },
    /// Time Warp Edit Distance (TWED).
    TimeWarpEdit {
        /// Stiffness penalty for elastic transformation.
        nu: f64,
        /// Penalty applied to time warping.
        lambda: f64,
    },
    /// Kernel over sliding-window subsequences.
    Subsequence {
        /// Length of each extracted subsequence.
        subsequence_length: usize,
        /// Stride of the sliding window.
        step_size: usize,
    },
    /// Kernel on distances to extracted shapelets.
    Shapelet {
        /// How many shapelets to extract.
        n_shapelets: usize,
        /// Shortest admissible shapelet length.
        min_length: usize,
        /// Longest admissible shapelet length.
        max_length: usize,
    },
}
65
66/// DTW distance computation configuration
67#[derive(Clone, Debug)]
68/// DTWConfig
69pub struct DTWConfig {
70    /// Window constraint for DTW alignment
71    pub window_type: DTWWindowType,
72    /// Distance metric for individual points
73    pub distance_metric: DTWDistanceMetric,
74    /// Step pattern for DTW
75    pub step_pattern: DTWStepPattern,
76    /// Whether to normalize by path length
77    pub normalize: bool,
78}
79
/// Window constraints available for DTW alignment.
#[derive(Clone, Debug, PartialEq)]
pub enum DTWWindowType {
    /// Every cell of the alignment matrix is allowed.
    None,
    /// Sakoe-Chiba band: only cells with `|i - j| <= window_size` are allowed.
    SakoeChiba { window_size: usize },
    /// Itakura parallelogram.
    /// NOTE(review): `DTWKernelApproximation` builds no parallelogram for this
    /// variant, so it currently imposes no restriction.
    Itakura,
    /// Explicit list of allowed `(i, j)` cells (0-based indices into the two series).
    Custom { window_func: Vec<(usize, usize)> },
}
93
/// Point-wise cost functions usable inside DTW.
#[derive(Clone, Debug, PartialEq)]
pub enum DTWDistanceMetric {
    /// Squared difference `(x1 - x2)^2`.
    Euclidean,
    /// Absolute difference `|x1 - x2|`.
    Manhattan,
    /// Cosine distance between the two points.
    Cosine,
    /// Placeholder for a user-defined cost; currently behaves like `Euclidean`.
    Custom,
}
107
/// Step patterns describing which predecessor cells a DTW step may come from.
#[derive(Clone, Debug, PartialEq)]
pub enum DTWStepPattern {
    /// Symmetric pattern (insertion, deletion and match).
    Symmetric,
    /// Asymmetric pattern.
    Asymmetric,
    /// User-supplied steps; presumably `(di, dj, weight)` triples — TODO confirm.
    Custom { steps: Vec<(i32, i32, f64)> },
}
119
120impl Default for DTWConfig {
121    fn default() -> Self {
122        Self {
123            window_type: DTWWindowType::None,
124            distance_metric: DTWDistanceMetric::Euclidean,
125            step_pattern: DTWStepPattern::Symmetric,
126            normalize: true,
127        }
128    }
129}
130
131/// Time series kernel configuration
132#[derive(Clone, Debug)]
133/// TimeSeriesKernelConfig
134pub struct TimeSeriesKernelConfig {
135    /// Type of time series kernel
136    pub kernel_type: TimeSeriesKernelType,
137    /// Number of random features for approximation
138    pub n_components: usize,
139    /// Random state for reproducibility
140    pub random_state: Option<u64>,
141    /// DTW-specific configuration
142    pub dtw_config: Option<DTWConfig>,
143    /// Whether to normalize time series
144    pub normalize_series: bool,
145    /// Number of parallel workers
146    pub n_workers: usize,
147}
148
149impl Default for TimeSeriesKernelConfig {
150    fn default() -> Self {
151        Self {
152            kernel_type: TimeSeriesKernelType::DTW {
153                window_size: None,
154                penalty: 0.0,
155            },
156            n_components: 100,
157            random_state: None,
158            dtw_config: Some(DTWConfig::default()),
159            normalize_series: true,
160            n_workers: num_cpus::get(),
161        }
162    }
163}
164
165/// Dynamic Time Warping kernel approximation
166pub struct DTWKernelApproximation {
167    config: TimeSeriesKernelConfig,
168    reference_series: Option<Array2<f64>>,
169    random_indices: Option<Vec<usize>>,
170    dtw_distances: Option<Array2<f64>>,
171    kernel_bandwidth: f64,
172}
173
174impl DTWKernelApproximation {
175    /// Create a new DTW kernel approximation
176    pub fn new(n_components: usize) -> Self {
177        Self {
178            config: TimeSeriesKernelConfig {
179                n_components,
180                kernel_type: TimeSeriesKernelType::DTW {
181                    window_size: None,
182                    penalty: 0.0,
183                },
184                ..Default::default()
185            },
186            reference_series: None,
187            random_indices: None,
188            dtw_distances: None,
189            kernel_bandwidth: 1.0,
190        }
191    }
192
193    /// Set kernel bandwidth
194    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
195        self.kernel_bandwidth = bandwidth;
196        self
197    }
198
199    /// Set DTW window size
200    pub fn window_size(mut self, new_window_size: Option<usize>) -> Self {
201        if let TimeSeriesKernelType::DTW {
202            ref mut window_size,
203            ..
204        } = self.config.kernel_type
205        {
206            *window_size = new_window_size;
207        }
208        self
209    }
210
211    /// Set configuration
212    pub fn with_config(mut self, config: TimeSeriesKernelConfig) -> Self {
213        self.config = config;
214        self
215    }
216
217    /// Fit the DTW kernel approximation
218    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
219        let (n_series, n_timepoints, n_features) = time_series.dim();
220
221        // Select random reference series for approximation
222        let mut rng = if let Some(seed) = self.config.random_state {
223            RealStdRng::seed_from_u64(seed)
224        } else {
225            RealStdRng::from_seed(thread_rng().gen())
226        };
227
228        let n_references = std::cmp::min(self.config.n_components, n_series);
229        let mut indices: Vec<usize> = (0..n_series).collect();
230        indices.sort_by_key(|_| rng.gen::<u32>());
231        indices.truncate(n_references);
232        self.random_indices = Some(indices.clone());
233
234        // Extract reference series
235        let mut reference_series = Array2::zeros((n_references, n_timepoints * n_features));
236        for (i, &idx) in indices.iter().enumerate() {
237            let series = time_series.slice(s![idx, .., ..]);
238            let flattened = series.into_shape((n_timepoints * n_features,)).unwrap();
239            reference_series.row_mut(i).assign(&flattened);
240        }
241        self.reference_series = Some(reference_series);
242
243        Ok(())
244    }
245
246    /// Transform time series using DTW kernel features
247    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
248        let reference_series = self.reference_series.as_ref().ok_or("Model not fitted")?;
249        let (n_series, n_timepoints, n_features) = time_series.dim();
250        let n_references = reference_series.nrows();
251
252        let mut features = Array2::zeros((n_series, n_references));
253
254        // Compute DTW distances to reference series
255        for i in 0..n_series {
256            let series = time_series.slice(s![i, .., ..]);
257            let series_flat = series.into_shape((n_timepoints * n_features,)).unwrap();
258
259            for j in 0..n_references {
260                let reference = reference_series.row(j);
261                let distance = self.compute_dtw_distance(&series_flat, &reference)?;
262
263                // Convert distance to kernel value using RBF kernel
264                let kernel_value = (-distance / (2.0 * self.kernel_bandwidth.powi(2))).exp();
265                features[[i, j]] = kernel_value;
266            }
267        }
268
269        Ok(features)
270    }
271
272    /// Compute DTW distance between two time series
273    fn compute_dtw_distance(
274        &self,
275        series1: &ArrayView1<f64>,
276        series2: &ArrayView1<f64>,
277    ) -> Result<f64> {
278        let n1 = series1.len();
279        let n2 = series2.len();
280
281        // Initialize DTW matrix
282        let mut dtw_matrix = Array2::from_elem((n1 + 1, n2 + 1), f64::INFINITY);
283        dtw_matrix[[0, 0]] = 0.0;
284
285        // Apply window constraint if specified
286        let window_constraint = self.get_window_constraint(n1, n2);
287
288        for i in 1..=n1 {
289            for j in 1..=n2 {
290                if self.is_within_window(i - 1, j - 1, &window_constraint) {
291                    let cost = self.compute_point_distance(series1[i - 1], series2[j - 1]);
292
293                    let candidates = vec![
294                        dtw_matrix[[i - 1, j]] + cost,     // Insertion
295                        dtw_matrix[[i, j - 1]] + cost,     // Deletion
296                        dtw_matrix[[i - 1, j - 1]] + cost, // Match
297                    ];
298
299                    dtw_matrix[[i, j]] = candidates
300                        .into_iter()
301                        .min_by(|a, b| a.partial_cmp(b).unwrap())
302                        .unwrap();
303                }
304            }
305        }
306
307        let distance = dtw_matrix[[n1, n2]];
308
309        // Normalize by path length if requested
310        if self
311            .config
312            .dtw_config
313            .as_ref()
314            .map_or(true, |cfg| cfg.normalize)
315        {
316            Ok(distance / (n1 + n2) as f64)
317        } else {
318            Ok(distance)
319        }
320    }
321
322    /// Get window constraint for DTW
323    fn get_window_constraint(&self, n1: usize, n2: usize) -> Option<Vec<(usize, usize)>> {
324        if let Some(dtw_config) = &self.config.dtw_config {
325            match &dtw_config.window_type {
326                DTWWindowType::SakoeChiba { window_size } => {
327                    let mut constraints = Vec::new();
328                    for i in 0..n1 {
329                        let j_start = (i as i32 - *window_size as i32).max(0) as usize;
330                        let j_end = (i + window_size).min(n2 - 1);
331                        for j in j_start..=j_end {
332                            constraints.push((i, j));
333                        }
334                    }
335                    Some(constraints)
336                }
337                DTWWindowType::Custom { window_func } => Some(window_func.clone()),
338                _ => None,
339            }
340        } else {
341            None
342        }
343    }
344
345    /// Check if point is within window constraint
346    fn is_within_window(
347        &self,
348        i: usize,
349        j: usize,
350        window_constraint: &Option<Vec<(usize, usize)>>,
351    ) -> bool {
352        match window_constraint {
353            Some(constraints) => constraints.contains(&(i, j)),
354            None => true,
355        }
356    }
357
358    /// Compute distance between two points
359    fn compute_point_distance(&self, x1: f64, x2: f64) -> f64 {
360        let metric = self
361            .config
362            .dtw_config
363            .as_ref()
364            .map(|cfg| &cfg.distance_metric)
365            .unwrap_or(&DTWDistanceMetric::Euclidean);
366
367        match metric {
368            DTWDistanceMetric::Euclidean => (x1 - x2).powi(2),
369            DTWDistanceMetric::Manhattan => (x1 - x2).abs(),
370            DTWDistanceMetric::Cosine => 1.0 - (x1 * x2) / ((x1.powi(2) + x2.powi(2)).sqrt()),
371            DTWDistanceMetric::Custom => (x1 - x2).powi(2), // Default to Euclidean
372        }
373    }
374}
375
376/// Autoregressive kernel approximation
377pub struct AutoregressiveKernelApproximation {
378    config: TimeSeriesKernelConfig,
379    ar_coefficients: Option<Array2<f64>>,
380    reference_models: Option<Vec<Array1<f64>>>,
381    random_features: Option<Array2<f64>>,
382}
383
384impl AutoregressiveKernelApproximation {
385    /// Create a new autoregressive kernel approximation
386    pub fn new(n_components: usize, order: usize) -> Self {
387        Self {
388            config: TimeSeriesKernelConfig {
389                n_components,
390                kernel_type: TimeSeriesKernelType::Autoregressive { order, lambda: 0.1 },
391                ..Default::default()
392            },
393            ar_coefficients: None,
394            reference_models: None,
395            random_features: None,
396        }
397    }
398
399    /// Set regularization parameter
400    pub fn lambda(mut self, new_lambda: f64) -> Self {
401        if let TimeSeriesKernelType::Autoregressive { ref mut lambda, .. } = self.config.kernel_type
402        {
403            *lambda = new_lambda;
404        }
405        self
406    }
407
408    /// Fit the autoregressive kernel approximation
409    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
410        let (n_series, _n_timepoints, n_features) = time_series.dim();
411
412        if let TimeSeriesKernelType::Autoregressive { order, lambda } = &self.config.kernel_type {
413            // Fit AR models to each time series
414            let mut ar_coefficients = Array2::zeros((n_series, order * n_features));
415
416            for i in 0..n_series {
417                let series = time_series.slice(s![i, .., ..]);
418                let coeffs = self.fit_ar_model(&series, *order, *lambda)?;
419                ar_coefficients.row_mut(i).assign(&coeffs);
420            }
421
422            self.ar_coefficients = Some(ar_coefficients);
423
424            // Generate random features based on AR coefficients
425            self.generate_random_features()?;
426        }
427
428        Ok(())
429    }
430
431    /// Transform time series using AR kernel features
432    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
433        let _ar_coefficients = self.ar_coefficients.as_ref().ok_or("Model not fitted")?;
434        let random_features = self
435            .random_features
436            .as_ref()
437            .ok_or("Random features not generated")?;
438
439        let (n_series, _n_timepoints, _n_features) = time_series.dim();
440        let n_components = self.config.n_components;
441
442        let mut features = Array2::zeros((n_series, n_components));
443
444        if let TimeSeriesKernelType::Autoregressive { order, lambda } = &self.config.kernel_type {
445            for i in 0..n_series {
446                let series = time_series.slice(s![i, .., ..]);
447                let coeffs = self.fit_ar_model(&series, *order, *lambda)?;
448
449                // Compute random features
450                for j in 0..n_components {
451                    let random_proj = coeffs.dot(&random_features.row(j));
452                    features[[i, j]] = random_proj.cos();
453                }
454            }
455        }
456
457        Ok(features)
458    }
459
460    /// Fit AR model to a single time series
461    fn fit_ar_model(
462        &self,
463        series: &ArrayView2<f64>,
464        order: usize,
465        lambda: f64,
466    ) -> Result<Array1<f64>> {
467        let (n_timepoints, n_features) = series.dim();
468
469        if n_timepoints <= order {
470            return Err("Time series too short for specified AR order".into());
471        }
472
473        // Create design matrix X and target vector y
474        let n_samples = n_timepoints - order;
475        let mut x_matrix = Array2::zeros((n_samples, order * n_features));
476        let mut y_vector = Array2::zeros((n_samples, n_features));
477
478        for t in order..n_timepoints {
479            let sample_idx = t - order;
480
481            // Fill design matrix with lagged values
482            for lag in 1..=order {
483                let lag_idx = t - lag;
484                for feat in 0..n_features {
485                    x_matrix[[sample_idx, (lag - 1) * n_features + feat]] = series[[lag_idx, feat]];
486                }
487            }
488
489            // Fill target vector
490            for feat in 0..n_features {
491                y_vector[[sample_idx, feat]] = series[[t, feat]];
492            }
493        }
494
495        // Solve least squares with regularization: (X^T X + λI)β = X^T y
496        let xtx = x_matrix.t().dot(&x_matrix);
497        let xtx_reg = xtx + Array2::<f64>::eye(order * n_features) * lambda;
498        let xty = x_matrix.t().dot(&y_vector);
499
500        // Simplified solution (in practice, use proper linear algebra)
501        let coeffs = self.solve_linear_system(&xtx_reg, &xty)?;
502
503        Ok(coeffs)
504    }
505
506    /// Solve linear system (simplified implementation)
507    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array2<f64>) -> Result<Array1<f64>> {
508        // This is a simplified implementation - in practice use proper solvers
509        let n = a.nrows();
510        let n_features = b.ncols();
511        let mut solution = Array1::zeros(n);
512
513        // Use diagonal approximation for simplicity - average across features
514        for i in 0..n {
515            if a[[i, i]].abs() > 1e-12 {
516                let avg_target =
517                    (0..n_features).map(|j| b[[i, j]]).sum::<f64>() / n_features as f64;
518                solution[i] = avg_target / a[[i, i]];
519            }
520        }
521
522        Ok(solution)
523    }
524
525    /// Generate random features for AR kernel approximation
526    fn generate_random_features(&mut self) -> Result<()> {
527        if let TimeSeriesKernelType::Autoregressive { .. } = &self.config.kernel_type {
528            let ar_coefficients = self.ar_coefficients.as_ref().unwrap();
529            let (_, n_ar_features) = ar_coefficients.dim();
530
531            let mut rng = if let Some(seed) = self.config.random_state {
532                RealStdRng::seed_from_u64(seed)
533            } else {
534                RealStdRng::from_seed(thread_rng().gen())
535            };
536
537            let normal = RandNormal::new(0.0, 1.0).unwrap();
538            let random_features =
539                Array2::from_shape_fn((self.config.n_components, n_ar_features), |_| {
540                    rng.sample(normal)
541                });
542
543            self.random_features = Some(random_features);
544        }
545
546        Ok(())
547    }
548}
549
550/// Spectral kernel approximation for time series
551pub struct SpectralKernelApproximation {
552    config: TimeSeriesKernelConfig,
553    frequency_features: Option<Array2<f64>>,
554    reference_spectra: Option<Array2<f64>>,
555}
556
557impl SpectralKernelApproximation {
558    /// Create a new spectral kernel approximation
559    pub fn new(n_components: usize, n_frequencies: usize) -> Self {
560        Self {
561            config: TimeSeriesKernelConfig {
562                n_components,
563                kernel_type: TimeSeriesKernelType::Spectral {
564                    n_frequencies,
565                    magnitude_only: true,
566                },
567                ..Default::default()
568            },
569            frequency_features: None,
570            reference_spectra: None,
571        }
572    }
573
574    /// Set whether to use magnitude only
575    pub fn magnitude_only(mut self, new_magnitude_only: bool) -> Self {
576        if let TimeSeriesKernelType::Spectral {
577            ref mut magnitude_only,
578            ..
579        } = self.config.kernel_type
580        {
581            *magnitude_only = new_magnitude_only;
582        }
583        self
584    }
585
586    /// Fit the spectral kernel approximation
587    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
588        let (n_series, _n_timepoints, n_features) = time_series.dim();
589
590        if let TimeSeriesKernelType::Spectral {
591            n_frequencies,
592            magnitude_only,
593        } = &self.config.kernel_type
594        {
595            // Compute frequency domain representation
596            let mut spectra = Array2::zeros((n_series, *n_frequencies * n_features));
597
598            for i in 0..n_series {
599                let series = time_series.slice(s![i, .., ..]);
600                let spectrum =
601                    self.compute_frequency_features(&series, *n_frequencies, *magnitude_only)?;
602                spectra.row_mut(i).assign(&spectrum);
603            }
604
605            self.reference_spectra = Some(spectra);
606
607            // Generate random projection features
608            self.generate_spectral_features(*n_frequencies * n_features)?;
609        }
610
611        Ok(())
612    }
613
614    /// Transform time series using spectral features
615    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
616        let frequency_features = self.frequency_features.as_ref().ok_or("Model not fitted")?;
617
618        let (n_series, _n_timepoints, _n_features) = time_series.dim();
619        let n_components = self.config.n_components;
620
621        let mut features = Array2::zeros((n_series, n_components));
622
623        if let TimeSeriesKernelType::Spectral {
624            n_frequencies,
625            magnitude_only,
626        } = &self.config.kernel_type
627        {
628            for i in 0..n_series {
629                let series = time_series.slice(s![i, .., ..]);
630                let spectrum =
631                    self.compute_frequency_features(&series, *n_frequencies, *magnitude_only)?;
632
633                // Apply random projection
634                for j in 0..n_components {
635                    let projection = spectrum.dot(&frequency_features.row(j));
636                    features[[i, j]] = projection.cos();
637                }
638            }
639        }
640
641        Ok(features)
642    }
643
644    /// Compute frequency domain features using FFT
645    fn compute_frequency_features(
646        &self,
647        series: &ArrayView2<f64>,
648        n_frequencies: usize,
649        magnitude_only: bool,
650    ) -> Result<Array1<f64>> {
651        let (n_timepoints, n_features) = series.dim();
652        let mut features = Vec::new();
653
654        for feat in 0..n_features {
655            let signal = series.column(feat);
656
657            // Simple discrete Fourier transform approximation
658            for k in 0..n_frequencies {
659                let freq = 2.0 * std::f64::consts::PI * k as f64 / n_timepoints as f64;
660
661                let mut real_part = 0.0;
662                let mut imag_part = 0.0;
663
664                for t in 0..n_timepoints {
665                    let angle = freq * t as f64;
666                    real_part += signal[t] * angle.cos();
667                    imag_part += signal[t] * angle.sin();
668                }
669
670                if magnitude_only {
671                    features.push((real_part.powi(2) + imag_part.powi(2)).sqrt());
672                } else {
673                    features.push(real_part);
674                    features.push(imag_part);
675                }
676            }
677        }
678
679        Ok(Array1::from(features))
680    }
681
682    /// Generate random spectral features
683    fn generate_spectral_features(&mut self, n_spectrum_features: usize) -> Result<()> {
684        let mut rng = if let Some(seed) = self.config.random_state {
685            RealStdRng::seed_from_u64(seed)
686        } else {
687            RealStdRng::from_seed(thread_rng().gen())
688        };
689
690        let normal = RandNormal::new(0.0, 1.0).unwrap();
691        let frequency_features =
692            Array2::from_shape_fn((self.config.n_components, n_spectrum_features), |_| {
693                rng.sample(normal)
694            });
695
696        self.frequency_features = Some(frequency_features);
697        Ok(())
698    }
699}
700
701/// Global Alignment Kernel (GAK) approximation
702pub struct GlobalAlignmentKernelApproximation {
703    config: TimeSeriesKernelConfig,
704    reference_series: Option<Array2<f64>>,
705    sigma: f64,
706}
707
708impl GlobalAlignmentKernelApproximation {
709    /// Create a new GAK approximation
710    pub fn new(n_components: usize, sigma: f64) -> Self {
711        Self {
712            config: TimeSeriesKernelConfig {
713                n_components,
714                kernel_type: TimeSeriesKernelType::GlobalAlignment {
715                    sigma,
716                    triangular: false,
717                },
718                ..Default::default()
719            },
720            reference_series: None,
721            sigma,
722        }
723    }
724
725    /// Fit the GAK approximation
726    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
727        let (n_series, n_timepoints, n_features) = time_series.dim();
728
729        // Select random reference series
730        let mut rng = if let Some(seed) = self.config.random_state {
731            RealStdRng::seed_from_u64(seed)
732        } else {
733            RealStdRng::from_seed(thread_rng().gen())
734        };
735
736        let n_references = std::cmp::min(self.config.n_components, n_series);
737        let mut indices: Vec<usize> = (0..n_series).collect();
738        indices.sort_by_key(|_| rng.gen::<u32>());
739        indices.truncate(n_references);
740
741        let mut reference_series = Array2::zeros((n_references, n_timepoints * n_features));
742        for (i, &idx) in indices.iter().enumerate() {
743            let series = time_series.slice(s![idx, .., ..]);
744            let flattened = series.into_shape((n_timepoints * n_features,)).unwrap();
745            reference_series.row_mut(i).assign(&flattened);
746        }
747
748        self.reference_series = Some(reference_series);
749        Ok(())
750    }
751
752    /// Transform using GAK features
753    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
754        let reference_series = self.reference_series.as_ref().ok_or("Model not fitted")?;
755
756        let (n_series, n_timepoints, n_features) = time_series.dim();
757        let n_references = reference_series.nrows();
758
759        let mut features = Array2::zeros((n_series, n_references));
760
761        for i in 0..n_series {
762            let series = time_series.slice(s![i, .., ..]);
763            let series_flat = series.into_shape((n_timepoints * n_features,)).unwrap();
764
765            for j in 0..n_references {
766                let reference = reference_series.row(j);
767                let gak_value = self.compute_gak(&series_flat, &reference)?;
768                features[[i, j]] = gak_value;
769            }
770        }
771
772        Ok(features)
773    }
774
775    /// Compute Global Alignment Kernel
776    fn compute_gak(&self, series1: &ArrayView1<f64>, series2: &ArrayView1<f64>) -> Result<f64> {
777        let n1 = series1.len();
778        let n2 = series2.len();
779
780        // Initialize GAK matrix with exponential of negative squared distances
781        let mut gak_matrix = Array2::zeros((n1 + 1, n2 + 1));
782
783        for i in 1..=n1 {
784            for j in 1..=n2 {
785                let dist_sq = (series1[i - 1] - series2[j - 1]).powi(2);
786                let kernel_val = (-dist_sq / (2.0 * self.sigma.powi(2))).exp();
787
788                // GAK recurrence relation
789                let max_alignment = vec![
790                    gak_matrix[[i - 1, j]] * kernel_val,
791                    gak_matrix[[i, j - 1]] * kernel_val,
792                    gak_matrix[[i - 1, j - 1]] * kernel_val,
793                ]
794                .into_iter()
795                .max_by(|a: &f64, b: &f64| a.partial_cmp(b).unwrap())
796                .unwrap_or(0.0);
797
798                gak_matrix[[i, j]] = max_alignment;
799            }
800        }
801
802        Ok(gak_matrix[[n1, n2]])
803    }
804}
805
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Five bivariate series of ten steps: channel 0 holds `sin(t + i)` and
    /// channel 1 holds `cos(t + i)` for series `i` at time `t`.
    fn create_test_time_series() -> Array3<f64> {
        Array3::from_shape_fn((5, 10, 2), |(i, t, channel)| {
            let phase = t as f64 + i as f64;
            if channel == 0 {
                phase.sin()
            } else {
                phase.cos()
            }
        })
    }

    #[test]
    fn test_dtw_kernel_approximation() {
        let data = create_test_time_series();
        let mut model = DTWKernelApproximation::new(3)
            .bandwidth(1.0)
            .window_size(Some(2));

        model.fit(&data).unwrap();
        let out = model.transform(&data).unwrap();

        assert_eq!(out.shape(), &[5, 3]);
        // RBF-mapped distances are confined to [0, 1].
        assert!(out.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn test_autoregressive_kernel_approximation() {
        let data = create_test_time_series();
        let mut model = AutoregressiveKernelApproximation::new(4, 2).lambda(0.1);

        model.fit(&data).unwrap();
        let out = model.transform(&data).unwrap();

        assert_eq!(out.shape(), &[5, 4]);
    }

    #[test]
    fn test_spectral_kernel_approximation() {
        let data = create_test_time_series();
        let mut model = SpectralKernelApproximation::new(6, 5).magnitude_only(true);

        model.fit(&data).unwrap();
        let out = model.transform(&data).unwrap();

        assert_eq!(out.shape(), &[5, 6]);
    }

    #[test]
    fn test_global_alignment_kernel() {
        let data = create_test_time_series();
        let mut model = GlobalAlignmentKernelApproximation::new(3, 1.0);

        model.fit(&data).unwrap();
        let out = model.transform(&data).unwrap();

        assert_eq!(out.shape(), &[5, 3]);
        assert!(out.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_dtw_distance_computation() {
        let first = Array1::from(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let second = Array1::from(vec![1.0, 3.0, 2.0, 1.0]);

        let model = DTWKernelApproximation::new(1);
        let distance = model
            .compute_dtw_distance(&first.view(), &second.view())
            .unwrap();

        assert!(distance >= 0.0);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_ar_model_fitting() {
        let data = create_test_time_series();
        let first_series = data.slice(s![0, .., ..]);

        let model = AutoregressiveKernelApproximation::new(10, 2);
        let coeffs = model.fit_ar_model(&first_series, 2, 0.1).unwrap();

        // One coefficient per (lag, channel): 2 lags x 2 channels.
        assert_eq!(coeffs.len(), 4);
    }

    #[test]
    fn test_frequency_features() {
        let data = create_test_time_series();
        let first_series = data.slice(s![0, .., ..]);

        let model = SpectralKernelApproximation::new(10, 5);
        let spectrum = model
            .compute_frequency_features(&first_series, 5, true)
            .unwrap();

        // 5 frequencies x 2 channels, all magnitudes.
        assert_eq!(spectrum.len(), 10);
        assert!(spectrum.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_time_series_kernel_config() {
        let defaults = TimeSeriesKernelConfig::default();

        assert_eq!(defaults.n_components, 100);
        assert!(matches!(
            defaults.kernel_type,
            TimeSeriesKernelType::DTW { .. }
        ));
        assert!(defaults.normalize_series);
        assert!(defaults.dtw_config.is_some());
    }

    #[test]
    fn test_dtw_window_constraints() {
        let model = DTWKernelApproximation::new(1);

        // A freshly built model has no window at all.
        assert!(model.get_window_constraint(5, 5).is_none());

        let diagonal = Some(vec![(0, 0), (1, 1), (2, 2)]);
        assert!(model.is_within_window(1, 1, &diagonal));
        assert!(!model.is_within_window(0, 2, &diagonal));
    }

    #[test]
    fn test_reproducibility_with_random_state() {
        let data = create_test_time_series();
        let seeded = || TimeSeriesKernelConfig {
            random_state: Some(42),
            ..Default::default()
        };

        let mut first = DTWKernelApproximation::new(3).with_config(seeded());
        let mut second = DTWKernelApproximation::new(3).with_config(seeded());

        first.fit(&data).unwrap();
        second.fit(&data).unwrap();

        let out_first = first.transform(&data).unwrap();
        let out_second = second.transform(&data).unwrap();

        // Identical seeds select identical references, hence identical features.
        assert_eq!(out_first.shape(), out_second.shape());
        for (a, b) in out_first.iter().zip(out_second.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }
}