Skip to main content

sklears_kernel_approximation/
time_series_kernels.rs

//! Time series kernel approximations
//!
//! This module provides kernel approximation methods designed specifically for
//! time series data, including Dynamic Time Warping (DTW), autoregressive,
//! spectral, and Global Alignment (GAK) kernels, along with other
//! time-series-specific kernel methods.
6
7use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
8use scirs2_core::random::essentials::Normal as RandNormal;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::RngExt;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::error::Result;
13
/// Kind of time series kernel, together with the parameters it needs.
#[derive(Clone, Debug, PartialEq)]
pub enum TimeSeriesKernelType {
    /// Dynamic Time Warping (DTW) kernel.
    DTW {
        /// Optional window size for the DTW alignment (`None` means no window).
        window_size: Option<usize>,
        /// Penalty applied to insertions/deletions.
        penalty: f64,
    },
    /// Kernel built on fitted autoregressive (AR) models.
    Autoregressive {
        /// Order (number of lags) of the AR model.
        order: usize,
        /// Regularization strength used when fitting the AR model.
        lambda: f64,
    },
    /// Spectral kernel based on a Fourier representation of each series.
    Spectral {
        /// Number of frequency components to keep.
        n_frequencies: usize,
        /// `true` keeps only magnitudes; `false` keeps real and imaginary parts.
        magnitude_only: bool,
    },
    /// Global Alignment Kernel (GAK).
    GlobalAlignment {
        /// Bandwidth of the Gaussian local kernel.
        sigma: f64,
        /// Whether the triangular GAK variant is requested.
        triangular: bool,
    },
    /// Time Warp Edit Distance (TWED) kernel.
    TimeWarpEdit {
        /// Penalty for elastic transformation.
        nu: f64,
        /// Penalty for time warping.
        lambda: f64,
    },
    /// Kernel over sliding-window subsequences.
    Subsequence {
        /// Length of each subsequence.
        subsequence_length: usize,
        /// Step between consecutive window positions.
        step_size: usize,
    },
    /// Shapelet-based kernel.
    Shapelet {
        /// How many shapelets to extract.
        n_shapelets: usize,
        /// Shortest admissible shapelet length.
        min_length: usize,
        /// Longest admissible shapelet length.
        max_length: usize,
    },
}
65
66/// DTW distance computation configuration
67#[derive(Clone, Debug)]
68/// DTWConfig
69pub struct DTWConfig {
70    /// Window constraint for DTW alignment
71    pub window_type: DTWWindowType,
72    /// Distance metric for individual points
73    pub distance_metric: DTWDistanceMetric,
74    /// Step pattern for DTW
75    pub step_pattern: DTWStepPattern,
76    /// Whether to normalize by path length
77    pub normalize: bool,
78}
79
/// Window constraints that restrict which DTW alignment cells are allowed.
#[derive(Clone, Debug, PartialEq)]
pub enum DTWWindowType {
    /// Unconstrained alignment.
    None,
    /// Sakoe-Chiba band of the given half-width.
    SakoeChiba { window_size: usize },
    /// Itakura parallelogram.
    Itakura,
    /// Explicit list of allowed `(i, j)` alignment cells.
    Custom { window_func: Vec<(usize, usize)> },
}
93
/// Point-wise distance used inside the DTW recurrence.
#[derive(Clone, Debug, PartialEq)]
pub enum DTWDistanceMetric {
    /// Euclidean distance (computed as the squared difference per point).
    Euclidean,
    /// Manhattan (absolute difference) distance.
    Manhattan,
    /// Cosine distance.
    Cosine,
    /// Placeholder for a user-defined distance; no custom function can be
    /// supplied through this variant.
    Custom,
}
107
/// Step patterns describing which moves a DTW warping path may take.
#[derive(Clone, Debug, PartialEq)]
pub enum DTWStepPattern {
    /// Symmetric step pattern.
    Symmetric,
    /// Asymmetric step pattern.
    Asymmetric,
    /// User-supplied steps; each tuple presumably holds `(di, dj, weight)` —
    /// confirm against the consumer of this pattern.
    Custom { steps: Vec<(i32, i32, f64)> },
}
119
120impl Default for DTWConfig {
121    fn default() -> Self {
122        Self {
123            window_type: DTWWindowType::None,
124            distance_metric: DTWDistanceMetric::Euclidean,
125            step_pattern: DTWStepPattern::Symmetric,
126            normalize: true,
127        }
128    }
129}
130
131/// Time series kernel configuration
132#[derive(Clone, Debug)]
133/// TimeSeriesKernelConfig
134pub struct TimeSeriesKernelConfig {
135    /// Type of time series kernel
136    pub kernel_type: TimeSeriesKernelType,
137    /// Number of random features for approximation
138    pub n_components: usize,
139    /// Random state for reproducibility
140    pub random_state: Option<u64>,
141    /// DTW-specific configuration
142    pub dtw_config: Option<DTWConfig>,
143    /// Whether to normalize time series
144    pub normalize_series: bool,
145    /// Number of parallel workers
146    pub n_workers: usize,
147}
148
149impl Default for TimeSeriesKernelConfig {
150    fn default() -> Self {
151        Self {
152            kernel_type: TimeSeriesKernelType::DTW {
153                window_size: None,
154                penalty: 0.0,
155            },
156            n_components: 100,
157            random_state: None,
158            dtw_config: Some(DTWConfig::default()),
159            normalize_series: true,
160            n_workers: num_cpus::get(),
161        }
162    }
163}
164
165/// Dynamic Time Warping kernel approximation
166pub struct DTWKernelApproximation {
167    config: TimeSeriesKernelConfig,
168    reference_series: Option<Array2<f64>>,
169    random_indices: Option<Vec<usize>>,
170    dtw_distances: Option<Array2<f64>>,
171    kernel_bandwidth: f64,
172}
173
174impl DTWKernelApproximation {
175    /// Create a new DTW kernel approximation
176    pub fn new(n_components: usize) -> Self {
177        Self {
178            config: TimeSeriesKernelConfig {
179                n_components,
180                kernel_type: TimeSeriesKernelType::DTW {
181                    window_size: None,
182                    penalty: 0.0,
183                },
184                ..Default::default()
185            },
186            reference_series: None,
187            random_indices: None,
188            dtw_distances: None,
189            kernel_bandwidth: 1.0,
190        }
191    }
192
193    /// Set kernel bandwidth
194    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
195        self.kernel_bandwidth = bandwidth;
196        self
197    }
198
199    /// Set DTW window size
200    pub fn window_size(mut self, new_window_size: Option<usize>) -> Self {
201        if let TimeSeriesKernelType::DTW {
202            ref mut window_size,
203            ..
204        } = self.config.kernel_type
205        {
206            *window_size = new_window_size;
207        }
208        self
209    }
210
211    /// Set configuration
212    pub fn with_config(mut self, config: TimeSeriesKernelConfig) -> Self {
213        self.config = config;
214        self
215    }
216
217    /// Fit the DTW kernel approximation
218    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
219        let (n_series, n_timepoints, n_features) = time_series.dim();
220
221        // Select random reference series for approximation
222        let mut rng = if let Some(seed) = self.config.random_state {
223            RealStdRng::seed_from_u64(seed)
224        } else {
225            RealStdRng::from_seed(thread_rng().random())
226        };
227
228        let n_references = std::cmp::min(self.config.n_components, n_series);
229        let mut indices: Vec<usize> = (0..n_series).collect();
230        indices.sort_by_key(|_| rng.random::<u32>());
231        indices.truncate(n_references);
232        self.random_indices = Some(indices.clone());
233
234        // Extract reference series
235        let mut reference_series = Array2::zeros((n_references, n_timepoints * n_features));
236        for (i, &idx) in indices.iter().enumerate() {
237            let series = time_series.slice(s![idx, .., ..]);
238            let flattened = series
239                .into_shape((n_timepoints * n_features,))
240                .expect("operation should succeed");
241            reference_series.row_mut(i).assign(&flattened);
242        }
243        self.reference_series = Some(reference_series);
244
245        Ok(())
246    }
247
248    /// Transform time series using DTW kernel features
249    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
250        let reference_series = self.reference_series.as_ref().ok_or("Model not fitted")?;
251        let (n_series, n_timepoints, n_features) = time_series.dim();
252        let n_references = reference_series.nrows();
253
254        let mut features = Array2::zeros((n_series, n_references));
255
256        // Compute DTW distances to reference series
257        for i in 0..n_series {
258            let series = time_series.slice(s![i, .., ..]);
259            let series_flat = series
260                .into_shape((n_timepoints * n_features,))
261                .expect("operation should succeed");
262
263            for j in 0..n_references {
264                let reference = reference_series.row(j);
265                let distance = self.compute_dtw_distance(&series_flat, &reference)?;
266
267                // Convert distance to kernel value using RBF kernel
268                let kernel_value = (-distance / (2.0 * self.kernel_bandwidth.powi(2))).exp();
269                features[[i, j]] = kernel_value;
270            }
271        }
272
273        Ok(features)
274    }
275
276    /// Compute DTW distance between two time series
277    fn compute_dtw_distance(
278        &self,
279        series1: &ArrayView1<f64>,
280        series2: &ArrayView1<f64>,
281    ) -> Result<f64> {
282        let n1 = series1.len();
283        let n2 = series2.len();
284
285        // Initialize DTW matrix
286        let mut dtw_matrix = Array2::from_elem((n1 + 1, n2 + 1), f64::INFINITY);
287        dtw_matrix[[0, 0]] = 0.0;
288
289        // Apply window constraint if specified
290        let window_constraint = self.get_window_constraint(n1, n2);
291
292        for i in 1..=n1 {
293            for j in 1..=n2 {
294                if self.is_within_window(i - 1, j - 1, &window_constraint) {
295                    let cost = self.compute_point_distance(series1[i - 1], series2[j - 1]);
296
297                    let candidates = vec![
298                        dtw_matrix[[i - 1, j]] + cost,     // Insertion
299                        dtw_matrix[[i, j - 1]] + cost,     // Deletion
300                        dtw_matrix[[i - 1, j - 1]] + cost, // Match
301                    ];
302
303                    dtw_matrix[[i, j]] = candidates
304                        .into_iter()
305                        .min_by(|a, b| a.partial_cmp(b).expect("operation should succeed"))
306                        .expect("operation should succeed");
307                }
308            }
309        }
310
311        let distance = dtw_matrix[[n1, n2]];
312
313        // Normalize by path length if requested
314        if self
315            .config
316            .dtw_config
317            .as_ref()
318            .map_or(true, |cfg| cfg.normalize)
319        {
320            Ok(distance / (n1 + n2) as f64)
321        } else {
322            Ok(distance)
323        }
324    }
325
326    /// Get window constraint for DTW
327    fn get_window_constraint(&self, n1: usize, n2: usize) -> Option<Vec<(usize, usize)>> {
328        if let Some(dtw_config) = &self.config.dtw_config {
329            match &dtw_config.window_type {
330                DTWWindowType::SakoeChiba { window_size } => {
331                    let mut constraints = Vec::new();
332                    for i in 0..n1 {
333                        let j_start = (i as i32 - *window_size as i32).max(0) as usize;
334                        let j_end = (i + window_size).min(n2 - 1);
335                        for j in j_start..=j_end {
336                            constraints.push((i, j));
337                        }
338                    }
339                    Some(constraints)
340                }
341                DTWWindowType::Custom { window_func } => Some(window_func.clone()),
342                _ => None,
343            }
344        } else {
345            None
346        }
347    }
348
349    /// Check if point is within window constraint
350    fn is_within_window(
351        &self,
352        i: usize,
353        j: usize,
354        window_constraint: &Option<Vec<(usize, usize)>>,
355    ) -> bool {
356        match window_constraint {
357            Some(constraints) => constraints.contains(&(i, j)),
358            None => true,
359        }
360    }
361
362    /// Compute distance between two points
363    fn compute_point_distance(&self, x1: f64, x2: f64) -> f64 {
364        let metric = self
365            .config
366            .dtw_config
367            .as_ref()
368            .map(|cfg| &cfg.distance_metric)
369            .unwrap_or(&DTWDistanceMetric::Euclidean);
370
371        match metric {
372            DTWDistanceMetric::Euclidean => (x1 - x2).powi(2),
373            DTWDistanceMetric::Manhattan => (x1 - x2).abs(),
374            DTWDistanceMetric::Cosine => 1.0 - (x1 * x2) / ((x1.powi(2) + x2.powi(2)).sqrt()),
375            DTWDistanceMetric::Custom => (x1 - x2).powi(2), // Default to Euclidean
376        }
377    }
378}
379
380/// Autoregressive kernel approximation
381pub struct AutoregressiveKernelApproximation {
382    config: TimeSeriesKernelConfig,
383    ar_coefficients: Option<Array2<f64>>,
384    reference_models: Option<Vec<Array1<f64>>>,
385    random_features: Option<Array2<f64>>,
386}
387
388impl AutoregressiveKernelApproximation {
389    /// Create a new autoregressive kernel approximation
390    pub fn new(n_components: usize, order: usize) -> Self {
391        Self {
392            config: TimeSeriesKernelConfig {
393                n_components,
394                kernel_type: TimeSeriesKernelType::Autoregressive { order, lambda: 0.1 },
395                ..Default::default()
396            },
397            ar_coefficients: None,
398            reference_models: None,
399            random_features: None,
400        }
401    }
402
403    /// Set regularization parameter
404    pub fn lambda(mut self, new_lambda: f64) -> Self {
405        if let TimeSeriesKernelType::Autoregressive { ref mut lambda, .. } = self.config.kernel_type
406        {
407            *lambda = new_lambda;
408        }
409        self
410    }
411
412    /// Fit the autoregressive kernel approximation
413    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
414        let (n_series, _n_timepoints, n_features) = time_series.dim();
415
416        if let TimeSeriesKernelType::Autoregressive { order, lambda } = &self.config.kernel_type {
417            // Fit AR models to each time series
418            let mut ar_coefficients = Array2::zeros((n_series, order * n_features));
419
420            for i in 0..n_series {
421                let series = time_series.slice(s![i, .., ..]);
422                let coeffs = self.fit_ar_model(&series, *order, *lambda)?;
423                ar_coefficients.row_mut(i).assign(&coeffs);
424            }
425
426            self.ar_coefficients = Some(ar_coefficients);
427
428            // Generate random features based on AR coefficients
429            self.generate_random_features()?;
430        }
431
432        Ok(())
433    }
434
435    /// Transform time series using AR kernel features
436    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
437        let _ar_coefficients = self.ar_coefficients.as_ref().ok_or("Model not fitted")?;
438        let random_features = self
439            .random_features
440            .as_ref()
441            .ok_or("Random features not generated")?;
442
443        let (n_series, _n_timepoints, _n_features) = time_series.dim();
444        let n_components = self.config.n_components;
445
446        let mut features = Array2::zeros((n_series, n_components));
447
448        if let TimeSeriesKernelType::Autoregressive { order, lambda } = &self.config.kernel_type {
449            for i in 0..n_series {
450                let series = time_series.slice(s![i, .., ..]);
451                let coeffs = self.fit_ar_model(&series, *order, *lambda)?;
452
453                // Compute random features
454                for j in 0..n_components {
455                    let random_proj = coeffs.dot(&random_features.row(j));
456                    features[[i, j]] = random_proj.cos();
457                }
458            }
459        }
460
461        Ok(features)
462    }
463
464    /// Fit AR model to a single time series
465    fn fit_ar_model(
466        &self,
467        series: &ArrayView2<f64>,
468        order: usize,
469        lambda: f64,
470    ) -> Result<Array1<f64>> {
471        let (n_timepoints, n_features) = series.dim();
472
473        if n_timepoints <= order {
474            return Err("Time series too short for specified AR order".into());
475        }
476
477        // Create design matrix X and target vector y
478        let n_samples = n_timepoints - order;
479        let mut x_matrix = Array2::zeros((n_samples, order * n_features));
480        let mut y_vector = Array2::zeros((n_samples, n_features));
481
482        for t in order..n_timepoints {
483            let sample_idx = t - order;
484
485            // Fill design matrix with lagged values
486            for lag in 1..=order {
487                let lag_idx = t - lag;
488                for feat in 0..n_features {
489                    x_matrix[[sample_idx, (lag - 1) * n_features + feat]] = series[[lag_idx, feat]];
490                }
491            }
492
493            // Fill target vector
494            for feat in 0..n_features {
495                y_vector[[sample_idx, feat]] = series[[t, feat]];
496            }
497        }
498
499        // Solve least squares with regularization: (X^T X + λI)β = X^T y
500        let xtx = x_matrix.t().dot(&x_matrix);
501        let xtx_reg = xtx + Array2::<f64>::eye(order * n_features) * lambda;
502        let xty = x_matrix.t().dot(&y_vector);
503
504        // Simplified solution (in practice, use proper linear algebra)
505        let coeffs = self.solve_linear_system(&xtx_reg, &xty)?;
506
507        Ok(coeffs)
508    }
509
510    /// Solve linear system (simplified implementation)
511    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array2<f64>) -> Result<Array1<f64>> {
512        // This is a simplified implementation - in practice use proper solvers
513        let n = a.nrows();
514        let n_features = b.ncols();
515        let mut solution = Array1::zeros(n);
516
517        // Use diagonal approximation for simplicity - average across features
518        for i in 0..n {
519            if a[[i, i]].abs() > 1e-12 {
520                let avg_target =
521                    (0..n_features).map(|j| b[[i, j]]).sum::<f64>() / n_features as f64;
522                solution[i] = avg_target / a[[i, i]];
523            }
524        }
525
526        Ok(solution)
527    }
528
529    /// Generate random features for AR kernel approximation
530    fn generate_random_features(&mut self) -> Result<()> {
531        if let TimeSeriesKernelType::Autoregressive { .. } = &self.config.kernel_type {
532            let ar_coefficients = self
533                .ar_coefficients
534                .as_ref()
535                .expect("operation should succeed");
536            let (_, n_ar_features) = ar_coefficients.dim();
537
538            let mut rng = if let Some(seed) = self.config.random_state {
539                RealStdRng::seed_from_u64(seed)
540            } else {
541                RealStdRng::from_seed(thread_rng().random())
542            };
543
544            let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
545            let random_features =
546                Array2::from_shape_fn((self.config.n_components, n_ar_features), |_| {
547                    rng.sample(normal)
548                });
549
550            self.random_features = Some(random_features);
551        }
552
553        Ok(())
554    }
555}
556
557/// Spectral kernel approximation for time series
558pub struct SpectralKernelApproximation {
559    config: TimeSeriesKernelConfig,
560    frequency_features: Option<Array2<f64>>,
561    reference_spectra: Option<Array2<f64>>,
562}
563
564impl SpectralKernelApproximation {
565    /// Create a new spectral kernel approximation
566    pub fn new(n_components: usize, n_frequencies: usize) -> Self {
567        Self {
568            config: TimeSeriesKernelConfig {
569                n_components,
570                kernel_type: TimeSeriesKernelType::Spectral {
571                    n_frequencies,
572                    magnitude_only: true,
573                },
574                ..Default::default()
575            },
576            frequency_features: None,
577            reference_spectra: None,
578        }
579    }
580
581    /// Set whether to use magnitude only
582    pub fn magnitude_only(mut self, new_magnitude_only: bool) -> Self {
583        if let TimeSeriesKernelType::Spectral {
584            ref mut magnitude_only,
585            ..
586        } = self.config.kernel_type
587        {
588            *magnitude_only = new_magnitude_only;
589        }
590        self
591    }
592
593    /// Fit the spectral kernel approximation
594    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
595        let (n_series, _n_timepoints, n_features) = time_series.dim();
596
597        if let TimeSeriesKernelType::Spectral {
598            n_frequencies,
599            magnitude_only,
600        } = &self.config.kernel_type
601        {
602            // Compute frequency domain representation
603            let mut spectra = Array2::zeros((n_series, *n_frequencies * n_features));
604
605            for i in 0..n_series {
606                let series = time_series.slice(s![i, .., ..]);
607                let spectrum =
608                    self.compute_frequency_features(&series, *n_frequencies, *magnitude_only)?;
609                spectra.row_mut(i).assign(&spectrum);
610            }
611
612            self.reference_spectra = Some(spectra);
613
614            // Generate random projection features
615            self.generate_spectral_features(*n_frequencies * n_features)?;
616        }
617
618        Ok(())
619    }
620
621    /// Transform time series using spectral features
622    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
623        let frequency_features = self.frequency_features.as_ref().ok_or("Model not fitted")?;
624
625        let (n_series, _n_timepoints, _n_features) = time_series.dim();
626        let n_components = self.config.n_components;
627
628        let mut features = Array2::zeros((n_series, n_components));
629
630        if let TimeSeriesKernelType::Spectral {
631            n_frequencies,
632            magnitude_only,
633        } = &self.config.kernel_type
634        {
635            for i in 0..n_series {
636                let series = time_series.slice(s![i, .., ..]);
637                let spectrum =
638                    self.compute_frequency_features(&series, *n_frequencies, *magnitude_only)?;
639
640                // Apply random projection
641                for j in 0..n_components {
642                    let projection = spectrum.dot(&frequency_features.row(j));
643                    features[[i, j]] = projection.cos();
644                }
645            }
646        }
647
648        Ok(features)
649    }
650
651    /// Compute frequency domain features using FFT
652    fn compute_frequency_features(
653        &self,
654        series: &ArrayView2<f64>,
655        n_frequencies: usize,
656        magnitude_only: bool,
657    ) -> Result<Array1<f64>> {
658        let (n_timepoints, n_features) = series.dim();
659        let mut features = Vec::new();
660
661        for feat in 0..n_features {
662            let signal = series.column(feat);
663
664            // Simple discrete Fourier transform approximation
665            for k in 0..n_frequencies {
666                let freq = 2.0 * std::f64::consts::PI * k as f64 / n_timepoints as f64;
667
668                let mut real_part = 0.0;
669                let mut imag_part = 0.0;
670
671                for t in 0..n_timepoints {
672                    let angle = freq * t as f64;
673                    real_part += signal[t] * angle.cos();
674                    imag_part += signal[t] * angle.sin();
675                }
676
677                if magnitude_only {
678                    features.push((real_part.powi(2) + imag_part.powi(2)).sqrt());
679                } else {
680                    features.push(real_part);
681                    features.push(imag_part);
682                }
683            }
684        }
685
686        Ok(Array1::from(features))
687    }
688
689    /// Generate random spectral features
690    fn generate_spectral_features(&mut self, n_spectrum_features: usize) -> Result<()> {
691        let mut rng = if let Some(seed) = self.config.random_state {
692            RealStdRng::seed_from_u64(seed)
693        } else {
694            RealStdRng::from_seed(thread_rng().random())
695        };
696
697        let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
698        let frequency_features =
699            Array2::from_shape_fn((self.config.n_components, n_spectrum_features), |_| {
700                rng.sample(normal)
701            });
702
703        self.frequency_features = Some(frequency_features);
704        Ok(())
705    }
706}
707
708/// Global Alignment Kernel (GAK) approximation
709pub struct GlobalAlignmentKernelApproximation {
710    config: TimeSeriesKernelConfig,
711    reference_series: Option<Array2<f64>>,
712    sigma: f64,
713}
714
715impl GlobalAlignmentKernelApproximation {
716    /// Create a new GAK approximation
717    pub fn new(n_components: usize, sigma: f64) -> Self {
718        Self {
719            config: TimeSeriesKernelConfig {
720                n_components,
721                kernel_type: TimeSeriesKernelType::GlobalAlignment {
722                    sigma,
723                    triangular: false,
724                },
725                ..Default::default()
726            },
727            reference_series: None,
728            sigma,
729        }
730    }
731
732    /// Fit the GAK approximation
733    pub fn fit(&mut self, time_series: &Array3<f64>) -> Result<()> {
734        let (n_series, n_timepoints, n_features) = time_series.dim();
735
736        // Select random reference series
737        let mut rng = if let Some(seed) = self.config.random_state {
738            RealStdRng::seed_from_u64(seed)
739        } else {
740            RealStdRng::from_seed(thread_rng().random())
741        };
742
743        let n_references = std::cmp::min(self.config.n_components, n_series);
744        let mut indices: Vec<usize> = (0..n_series).collect();
745        indices.sort_by_key(|_| rng.random::<u32>());
746        indices.truncate(n_references);
747
748        let mut reference_series = Array2::zeros((n_references, n_timepoints * n_features));
749        for (i, &idx) in indices.iter().enumerate() {
750            let series = time_series.slice(s![idx, .., ..]);
751            let flattened = series
752                .into_shape((n_timepoints * n_features,))
753                .expect("operation should succeed");
754            reference_series.row_mut(i).assign(&flattened);
755        }
756
757        self.reference_series = Some(reference_series);
758        Ok(())
759    }
760
761    /// Transform using GAK features
762    pub fn transform(&self, time_series: &Array3<f64>) -> Result<Array2<f64>> {
763        let reference_series = self.reference_series.as_ref().ok_or("Model not fitted")?;
764
765        let (n_series, n_timepoints, n_features) = time_series.dim();
766        let n_references = reference_series.nrows();
767
768        let mut features = Array2::zeros((n_series, n_references));
769
770        for i in 0..n_series {
771            let series = time_series.slice(s![i, .., ..]);
772            let series_flat = series
773                .into_shape((n_timepoints * n_features,))
774                .expect("operation should succeed");
775
776            for j in 0..n_references {
777                let reference = reference_series.row(j);
778                let gak_value = self.compute_gak(&series_flat, &reference)?;
779                features[[i, j]] = gak_value;
780            }
781        }
782
783        Ok(features)
784    }
785
786    /// Compute Global Alignment Kernel
787    fn compute_gak(&self, series1: &ArrayView1<f64>, series2: &ArrayView1<f64>) -> Result<f64> {
788        let n1 = series1.len();
789        let n2 = series2.len();
790
791        // Initialize GAK matrix with exponential of negative squared distances
792        let mut gak_matrix = Array2::zeros((n1 + 1, n2 + 1));
793
794        for i in 1..=n1 {
795            for j in 1..=n2 {
796                let dist_sq = (series1[i - 1] - series2[j - 1]).powi(2);
797                let kernel_val = (-dist_sq / (2.0 * self.sigma.powi(2))).exp();
798
799                // GAK recurrence relation
800                let max_alignment = vec![
801                    gak_matrix[[i - 1, j]] * kernel_val,
802                    gak_matrix[[i, j - 1]] * kernel_val,
803                    gak_matrix[[i - 1, j - 1]] * kernel_val,
804                ]
805                .into_iter()
806                .max_by(|a: &f64, b: &f64| a.partial_cmp(b).expect("operation should succeed"))
807                .unwrap_or(0.0);
808
809                gak_matrix[[i, j]] = max_alignment;
810            }
811        }
812
813        Ok(gak_matrix[[n1, n2]])
814    }
815}
816
817#[allow(non_snake_case)]
818#[cfg(test)]
819mod tests {
820    use super::*;
821    use scirs2_core::ndarray::Array3;
822
823    fn create_test_time_series() -> Array3<f64> {
824        // Create simple test time series data
825        let mut ts = Array3::zeros((5, 10, 2));
826
827        for i in 0..5 {
828            for t in 0..10 {
829                ts[[i, t, 0]] = (t as f64 + i as f64).sin();
830                ts[[i, t, 1]] = (t as f64 + i as f64).cos();
831            }
832        }
833
834        ts
835    }
836
837    #[test]
838    fn test_dtw_kernel_approximation() {
839        let time_series = create_test_time_series();
840
841        let mut dtw_kernel = DTWKernelApproximation::new(3)
842            .bandwidth(1.0)
843            .window_size(Some(2));
844
845        dtw_kernel
846            .fit(&time_series)
847            .expect("operation should succeed");
848        let features = dtw_kernel
849            .transform(&time_series)
850            .expect("operation should succeed");
851
852        assert_eq!(features.shape(), &[5, 3]);
853
854        // Features should be positive (RBF kernel values)
855        assert!(features.iter().all(|&x| x >= 0.0 && x <= 1.0));
856    }
857
858    #[test]
859    fn test_autoregressive_kernel_approximation() {
860        let time_series = create_test_time_series();
861
862        let mut ar_kernel = AutoregressiveKernelApproximation::new(4, 2).lambda(0.1);
863
864        ar_kernel
865            .fit(&time_series)
866            .expect("operation should succeed");
867        let features = ar_kernel
868            .transform(&time_series)
869            .expect("operation should succeed");
870
871        assert_eq!(features.shape(), &[5, 4]);
872    }
873
874    #[test]
875    fn test_spectral_kernel_approximation() {
876        let time_series = create_test_time_series();
877
878        let mut spectral_kernel = SpectralKernelApproximation::new(6, 5).magnitude_only(true);
879
880        spectral_kernel
881            .fit(&time_series)
882            .expect("operation should succeed");
883        let features = spectral_kernel
884            .transform(&time_series)
885            .expect("operation should succeed");
886
887        assert_eq!(features.shape(), &[5, 6]);
888    }
889
890    #[test]
891    fn test_global_alignment_kernel() {
892        let time_series = create_test_time_series();
893
894        let mut gak = GlobalAlignmentKernelApproximation::new(3, 1.0);
895
896        gak.fit(&time_series).expect("operation should succeed");
897        let features = gak
898            .transform(&time_series)
899            .expect("operation should succeed");
900
901        assert_eq!(features.shape(), &[5, 3]);
902        assert!(features.iter().all(|&x| x >= 0.0));
903    }
904
905    #[test]
906    fn test_dtw_distance_computation() {
907        let series1 = Array1::from(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
908        let series2 = Array1::from(vec![1.0, 3.0, 2.0, 1.0]);
909
910        let dtw_kernel = DTWKernelApproximation::new(1);
911        let distance = dtw_kernel
912            .compute_dtw_distance(&series1.view(), &series2.view())
913            .expect("operation should succeed");
914
915        assert!(distance >= 0.0);
916        assert!(distance.is_finite());
917    }
918
919    #[test]
920    fn test_ar_model_fitting() {
921        let time_series = create_test_time_series();
922        let series = time_series.slice(s![0, .., ..]);
923
924        let ar_kernel = AutoregressiveKernelApproximation::new(10, 2);
925        let coeffs = ar_kernel
926            .fit_ar_model(&series, 2, 0.1)
927            .expect("operation should succeed");
928
929        assert_eq!(coeffs.len(), 4); // 2 lags * 2 features
930    }
931
932    #[test]
933    fn test_frequency_features() {
934        let time_series = create_test_time_series();
935        let series = time_series.slice(s![0, .., ..]);
936
937        let spectral_kernel = SpectralKernelApproximation::new(10, 5);
938        let features = spectral_kernel
939            .compute_frequency_features(&series, 5, true)
940            .expect("operation should succeed");
941
942        assert_eq!(features.len(), 10); // 5 frequencies * 2 features
943        assert!(features.iter().all(|&x| x >= 0.0)); // Magnitudes are non-negative
944    }
945
946    #[test]
947    fn test_time_series_kernel_config() {
948        let config = TimeSeriesKernelConfig::default();
949
950        assert_eq!(config.n_components, 100);
951        assert!(matches!(
952            config.kernel_type,
953            TimeSeriesKernelType::DTW { .. }
954        ));
955        assert!(config.normalize_series);
956        assert!(config.dtw_config.is_some());
957    }
958
959    #[test]
960    fn test_dtw_window_constraints() {
961        let dtw_kernel = DTWKernelApproximation::new(1);
962        let constraints = dtw_kernel.get_window_constraint(5, 5);
963
964        assert!(constraints.is_none()); // No window constraint by default
965
966        let window_constraint = Some(vec![(0, 0), (1, 1), (2, 2)]);
967        assert!(dtw_kernel.is_within_window(1, 1, &window_constraint));
968        assert!(!dtw_kernel.is_within_window(0, 2, &window_constraint));
969    }
970
971    #[test]
972    fn test_reproducibility_with_random_state() {
973        let time_series = create_test_time_series();
974
975        let config1 = TimeSeriesKernelConfig {
976            random_state: Some(42),
977            ..Default::default()
978        };
979        let config2 = TimeSeriesKernelConfig {
980            random_state: Some(42),
981            ..Default::default()
982        };
983
984        let mut dtw1 = DTWKernelApproximation::new(3).with_config(config1);
985        let mut dtw2 = DTWKernelApproximation::new(3).with_config(config2);
986
987        dtw1.fit(&time_series).expect("operation should succeed");
988        dtw2.fit(&time_series).expect("operation should succeed");
989
990        let features1 = dtw1
991            .transform(&time_series)
992            .expect("operation should succeed");
993        let features2 = dtw2
994            .transform(&time_series)
995            .expect("operation should succeed");
996
997        // Should be approximately equal due to same random state
998        for i in 0..features1.len() {
999            assert!(
1000                (features1.as_slice().expect("operation should succeed")[i]
1001                    - features2.as_slice().expect("operation should succeed")[i])
1002                    .abs()
1003                    < 1e-10
1004            );
1005        }
1006    }
1007}