Skip to main content

tensor_vault/
temporal_analysis.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Temporal analysis of access patterns via TT decomposition and drift detection.
//!
//! Extracts seasonal patterns from the access tensor using Tensor Train
//! decomposition, and detects behavioral drift by comparing recent access
//! patterns against historical baselines.
7
8use serde::{Deserialize, Serialize};
9use tensor_compress::tensor_train::{tt_decompose, tt_reconstruct, TTConfig};
10
11use crate::access_tensor::AccessTensor;
12
/// Configuration for temporal pattern analysis.
///
/// Consumed by [`analyze_temporal_patterns`]. The `Default` implementation
/// supplies the default values quoted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalAnalysisConfig {
    /// TT decomposition config. If `None`, uses default settings: an empty
    /// shape (so a shape is derived from each entity vector's length),
    /// `max_rank = 4` and `tolerance = 1e-4`.
    pub tt_config: Option<TTConfig>,
    /// Number of most recent buckets that form the "recent" window for drift
    /// comparison (default: 24). Drift detection produces no results when the
    /// tensor has fewer than twice this many buckets.
    pub drift_window: usize,
    /// Drift score above which an entity is flagged as drifting (default: 0.3).
    /// The score combines cosine distance with a relative magnitude shift, and
    /// the comparison is strict (`score > threshold`).
    pub drift_threshold: f64,
    /// Minimum total accesses for an entity to be included in seasonal
    /// pattern extraction (default: 5).
    pub min_accesses: u64,
}
25
26impl Default for TemporalAnalysisConfig {
27    fn default() -> Self {
28        Self {
29            tt_config: None,
30            drift_window: 24,
31            drift_threshold: 0.3,
32            min_accesses: 5,
33        }
34    }
35}
36
/// A seasonal pattern extracted via TT decomposition.
///
/// One value is produced per entity that passes the minimum-access filter and
/// whose access vector could be decomposed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeasonalPattern {
    /// Entity identifier.
    pub entity: String,
    /// Compressed (reconstructed) pattern: the entity's access vector rebuilt
    /// from its TT cores via `tt_reconstruct`.
    pub compressed_pattern: Vec<f32>,
    /// Dominant periodicity in buckets (from autocorrelation). `0` means no
    /// period could be determined (too few buckets or a constant series).
    pub dominant_period: usize,
    /// Ratio of compressed to original size: the total number of values held
    /// in the TT cores divided by the length of the original vector.
    pub compression_ratio: f32,
    /// Reconstruction error (L2 norm of difference / L2 norm of original).
    /// Reported as `0.0` when the original vector's norm is effectively zero.
    pub reconstruction_error: f32,
}
51
/// Drift detection result for a single entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetection {
    /// Entity identifier.
    pub entity: String,
    /// Drift score: the larger of (a) the cosine distance between the
    /// per-secret mean access rates of the historical and recent windows and
    /// (b) the relative change in the summed mean rates.
    pub drift_score: f64,
    /// Whether the drift score strictly exceeds the threshold.
    pub is_drifting: bool,
    /// Secrets flagged as individually changed: their access volume in the
    /// recent window differs from the historical volume by more than half of
    /// the latter.
    pub changed_secrets: Vec<String>,
}
64
/// Combined temporal analysis report.
///
/// Returned by [`analyze_temporal_patterns`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalAnalysisReport {
    /// Seasonal patterns per entity.
    pub seasonal_patterns: Vec<SeasonalPattern>,
    /// Drift detections per entity.
    pub drift_detections: Vec<DriftDetection>,
    /// Total entities analyzed: the larger of the number of seasonal patterns
    /// and the number of drift detections.
    pub total_entities_analyzed: usize,
    /// Mean compression ratio across all seasonal patterns; `0.0` when there
    /// are none.
    pub mean_compression_ratio: f32,
}
77
78/// Run full temporal analysis on an access tensor.
79pub fn analyze_temporal_patterns(
80    tensor: &AccessTensor,
81    config: TemporalAnalysisConfig,
82) -> TemporalAnalysisReport {
83    let tt_config = config.tt_config.clone().unwrap_or(TTConfig {
84        shape: vec![],
85        max_rank: 4,
86        tolerance: 1e-4,
87    });
88
89    let seasonal = extract_seasonal_patterns(tensor, &tt_config, config.min_accesses);
90    let drift = detect_drift(tensor, config.drift_window, config.drift_threshold);
91
92    let mean_compression = if seasonal.is_empty() {
93        0.0
94    } else {
95        #[allow(clippy::cast_precision_loss)] // pattern count never exceeds 2^23
96        let count = seasonal.len() as f32;
97        seasonal.iter().map(|s| s.compression_ratio).sum::<f32>() / count
98    };
99
100    let total = seasonal.len().max(drift.len());
101
102    TemporalAnalysisReport {
103        seasonal_patterns: seasonal,
104        drift_detections: drift,
105        total_entities_analyzed: total,
106        mean_compression_ratio: mean_compression,
107    }
108}
109
/// Split `n` into two factors suitable as a 2-D shape for TT decomposition.
///
/// Returns `Some(vec![f, n / f])` where `f` is the largest divisor of `n` in
/// `2..=floor(sqrt(n))`, i.e. the most balanced two-factor split. Returns
/// `None` when `n < 4` or when no such divisor exists (for example primes).
fn factorize_for_tt(n: usize) -> Option<Vec<usize>> {
    if n < 4 {
        return None;
    }

    #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)] // safe: n is small
    let root = (n as f64).sqrt() as usize;

    // Walk downwards from the square root so the first hit is the most
    // balanced split; the cofactor is then always at least as large as it.
    (2..=root)
        .rev()
        .find(|&divisor| n.is_multiple_of(divisor))
        .map(|divisor| vec![divisor, n / divisor])
}
127
128fn extract_seasonal_patterns(
129    tensor: &AccessTensor,
130    tt_config: &TTConfig,
131    min_accesses: u64,
132) -> Vec<SeasonalPattern> {
133    let mut patterns = Vec::new();
134
135    for entity in tensor.entities() {
136        let vec = tensor.entity_vector(&entity);
137        if vec.is_empty() {
138            continue;
139        }
140
141        let total: f32 = vec.iter().sum();
142        #[allow(clippy::cast_sign_loss)]
143        let total_u64 = total as u64;
144        if total_u64 < min_accesses {
145            continue;
146        }
147
148        // Try TT decomposition
149        let len = vec.len();
150        let shape = if tt_config.shape.is_empty() {
151            match factorize_for_tt(len) {
152                Some(s) => s,
153                None => continue,
154            }
155        } else if tt_config.shape.iter().product::<usize>() == len {
156            tt_config.shape.clone()
157        } else {
158            match factorize_for_tt(len) {
159                Some(s) => s,
160                None => continue,
161            }
162        };
163
164        let config = TTConfig {
165            shape,
166            max_rank: tt_config.max_rank,
167            tolerance: tt_config.tolerance,
168        };
169
170        let Ok(tt_vec) = tt_decompose(&vec, &config) else {
171            continue;
172        };
173
174        let reconstructed = tt_reconstruct(&tt_vec);
175
176        // Reconstruction error (relative L2)
177        let orig_norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt();
178        let error_norm: f32 = vec
179            .iter()
180            .zip(reconstructed.iter())
181            .map(|(a, b)| (a - b).powi(2))
182            .sum::<f32>()
183            .sqrt();
184        let reconstruction_error = if orig_norm > f32::EPSILON {
185            error_norm / orig_norm
186        } else {
187            0.0
188        };
189
190        // Compression ratio: compressed storage / original
191        let compressed_size: usize = tt_vec.cores.iter().map(|c| c.data.len()).sum();
192        #[allow(clippy::cast_precision_loss)] // tensor sizes never exceed 2^23
193        let compression_ratio = compressed_size as f32 / len as f32;
194
195        // Find dominant period from the entity's time-bucket pattern
196        let (_, _, n_buckets) = tensor.dimensions();
197        let bucket_pattern = aggregate_entity_buckets(&vec, n_buckets);
198        let dominant_period = find_dominant_period(&bucket_pattern);
199
200        patterns.push(SeasonalPattern {
201            entity,
202            compressed_pattern: reconstructed,
203            dominant_period,
204            compression_ratio,
205            reconstruction_error,
206        });
207    }
208
209    patterns
210}
211
/// Collapse a flattened per-entity vector into one total per time bucket.
///
/// `entity_vec` is read as consecutive rows of `n_buckets` values each; the
/// result holds the column-wise sums of those rows. Trailing values that do
/// not fill a complete row are ignored. Returns an empty vector when
/// `n_buckets` is zero, and all zeros when there is no complete row.
fn aggregate_entity_buckets(entity_vec: &[f32], n_buckets: usize) -> Vec<f32> {
    let mut totals = vec![0.0_f32; n_buckets];
    if n_buckets == 0 {
        // `chunks_exact` rejects a zero chunk size; nothing to aggregate anyway.
        return totals;
    }

    for row in entity_vec.chunks_exact(n_buckets) {
        for (slot, value) in totals.iter_mut().zip(row) {
            *slot += *value;
        }
    }

    totals
}
226
227fn detect_drift(tensor: &AccessTensor, window: usize, threshold: f64) -> Vec<DriftDetection> {
228    let mut detections = Vec::new();
229    let (_, _, n_buckets) = tensor.dimensions();
230    if n_buckets < window * 2 {
231        return detections;
232    }
233
234    let historical_end = n_buckets - window;
235
236    for entity in tensor.entities() {
237        let vec = tensor.entity_vector(&entity);
238        if vec.is_empty() {
239            continue;
240        }
241
242        let total: f32 = vec.iter().sum();
243        if total < 1.0 {
244            continue;
245        }
246
247        // Build per-secret mean-rate vectors for historical vs recent
248        let secrets = tensor.secrets();
249        let mut changed = Vec::new();
250        let mut hist_means = Vec::new();
251        let mut recent_means = Vec::new();
252        let hist_len = historical_end as f32;
253        let recent_len = window as f32;
254
255        for secret in &secrets {
256            let ts = tensor.time_series(&entity, secret);
257            if ts.len() < n_buckets {
258                continue;
259            }
260
261            let hist = &ts[..historical_end];
262            let recent = &ts[historical_end..];
263
264            let hist_mean = hist.iter().sum::<f32>() / hist_len.max(1.0);
265            let recent_mean = recent.iter().sum::<f32>() / recent_len.max(1.0);
266            hist_means.push(hist_mean);
267            recent_means.push(recent_mean);
268
269            // Per-secret drift
270            let hist_sum: f32 = hist.iter().sum();
271            let recent_sum: f32 = recent.iter().sum();
272            if (hist_sum - recent_sum).abs() > hist_sum.max(1.0) * 0.5 {
273                changed.push(secret.clone());
274            }
275        }
276
277        // Combined drift: cosine distance (directional) + magnitude shift
278        let cos_dist = cosine_distance(&hist_means, &recent_means);
279        let hist_norm: f64 = hist_means.iter().map(|x| f64::from(*x)).sum();
280        let recent_norm: f64 = recent_means.iter().map(|x| f64::from(*x)).sum();
281        let denom = hist_norm.max(recent_norm).max(f64::EPSILON);
282        let magnitude_shift = (recent_norm - hist_norm).abs() / denom;
283        let drift_score = cos_dist.max(magnitude_shift);
284        let is_drifting = drift_score > threshold;
285
286        detections.push(DriftDetection {
287            entity,
288            drift_score,
289            is_drifting,
290            changed_secrets: changed,
291        });
292    }
293
294    detections
295}
296
/// Find the dominant period of a series via autocorrelation.
///
/// The series is mean-centered and, for every lag from 2 up to half the
/// series length, the lagged product sum is divided by the total squared
/// deviation. The lag with the highest value is returned; on ties the
/// smallest such lag wins.
///
/// Returns `0` when the series has fewer than four samples or is (almost)
/// constant, i.e. when no period can be determined.
pub fn find_dominant_period(time_series: &[f32]) -> usize {
    let len = time_series.len();
    if len < 4 {
        return 0;
    }

    #[allow(clippy::cast_precision_loss)] // series length never exceeds 2^23
    let mean = time_series.iter().sum::<f32>() / len as f32;
    let deviations: Vec<f32> = time_series.iter().map(|sample| sample - mean).collect();
    let energy: f32 = deviations.iter().map(|d| d * d).sum();

    if energy < f32::EPSILON {
        return 0;
    }

    // Scan lags 2..=len/2, keeping the first lag that attains the maximum.
    let (period, _) = (2..=len / 2).fold((0_usize, f32::NEG_INFINITY), |best, lag| {
        let lagged_sum = deviations
            .iter()
            .zip(&deviations[lag..])
            .fold(0.0_f32, |acc, (lhs, rhs)| acc + lhs * rhs);
        let score = lagged_sum / energy;
        if score > best.1 {
            (lag, score)
        } else {
            best
        }
    });

    period
}
333
/// Cosine distance between two vectors: `1 - cos(a, b)`.
///
/// The result always lies in `[0.0, 2.0]`: the cosine similarity is clamped
/// to `[-1.0, 1.0]` before subtraction, because rounding in the norm product
/// can otherwise push it marginally outside that range and yield a slightly
/// negative "distance" for identical vectors.
///
/// Returns `1.0` when the distance is undefined: the slices differ in length,
/// are empty, or either one has (near-)zero norm.
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 1.0;
    }

    // Single pass accumulating the dot product and both squared norms in f64.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );
    let norm_a = sq_norm_a.sqrt();
    let norm_b = sq_norm_b.sqrt();

    if norm_a < f64::EPSILON || norm_b < f64::EPSILON {
        return 1.0;
    }

    let similarity = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    1.0 - similarity
}
362
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// An empty tensor yields an empty report with zeroed aggregates.
    #[test]
    fn test_seasonal_empty() {
        let report =
            analyze_temporal_patterns(&empty_tensor(), TemporalAnalysisConfig::default());
        assert!(report.seasonal_patterns.is_empty());
        assert!(report.drift_detections.is_empty());
        assert_eq!(report.total_entities_analyzed, 0);
        assert!(report.mean_compression_ratio.abs() < f32::EPSILON);
    }

    /// A signal with clear periodicity should have low reconstruction error.
    #[test]
    fn test_seasonal_periodic_signal() {
        let period: usize = 6;
        let n_buckets: usize = 24; // 24 = 4 * 6, factorable
        let data: Vec<f32> = (0..n_buckets)
            .map(|i| ((i % period) as f32 * std::f32::consts::PI / period as f32).sin() + 1.0)
            .collect();

        let tensor = make_single_entity_tensor("user:alice", "secret1", &data);
        let config = TemporalAnalysisConfig {
            min_accesses: 1,
            ..TemporalAnalysisConfig::default()
        };
        let report = analyze_temporal_patterns(&tensor, config);

        // Only one entity exists, so at most one pattern can be reported.
        assert!(report.seasonal_patterns.len() <= 1);
        // Whether a pattern is produced depends on the TT backend accepting
        // the shape; when it is, it must belong to alice and compress well.
        if let Some(pattern) = report.seasonal_patterns.first() {
            assert_eq!(pattern.entity, "user:alice");
            assert!(
                pattern.reconstruction_error < 1.0,
                "Periodic signal should compress well"
            );
        }
    }

    /// Irregular data may or may not compress, but the analysis must succeed.
    #[test]
    fn test_seasonal_random_high_error() {
        let n_buckets = 12; // 12 = 3 * 4, factorable
        let data: Vec<f32> = (0..n_buckets).map(|i| ((i * 7 + 3) % 11) as f32).collect();

        let tensor = make_single_entity_tensor("user:alice", "secret1", &data);
        let config = TemporalAnalysisConfig {
            min_accesses: 1,
            ..TemporalAnalysisConfig::default()
        };
        let report = analyze_temporal_patterns(&tensor, config);
        assert!(report.total_entities_analyzed <= 1);
    }

    /// A uniform access pattern must be reported and must not be flagged.
    #[test]
    fn test_drift_stable_entity() {
        let n_buckets = 48;
        let data = vec![1.0_f32; n_buckets];
        let tensor = make_single_entity_tensor("user:alice", "s1", &data);
        let detections = detect_drift(&tensor, 12, 0.3);
        // Guard against the loop below passing vacuously.
        assert!(
            !detections.is_empty(),
            "An active entity should produce a detection"
        );
        for d in &detections {
            assert!(!d.is_drifting, "Uniform pattern should not drift");
        }
    }

    /// A tenfold jump in the recent window must be flagged at a low threshold.
    #[test]
    fn test_drift_changed_entity() {
        let n_buckets = 48;
        // First 36 buckets: low access; last 12 buckets: high access.
        let mut data = vec![1.0_f32; n_buckets];
        data[36..].fill(10.0);

        let tensor = make_single_entity_tensor("user:alice", "s1", &data);
        let detections = detect_drift(&tensor, 12, 0.01);
        let alice = detections
            .iter()
            .find(|d| d.entity == "user:alice")
            .expect("detect_drift should report user:alice");
        assert!(alice.drift_score > 0.0);
        assert!(
            alice.is_drifting,
            "A big change should exceed a 0.01 threshold"
        );
    }

    /// The threshold only affects the flag, never the score itself.
    #[test]
    fn test_drift_threshold_boundary() {
        let n_buckets = 48;
        let data = vec![1.0_f32; n_buckets];
        let tensor = make_single_entity_tensor("user:alice", "s1", &data);

        let det_strict = detect_drift(&tensor, 12, 0.0);
        let det_lax = detect_drift(&tensor, 12, 2.0);

        // With threshold 2.0, nothing is flagged.
        for d in &det_lax {
            assert!(!d.is_drifting);
        }

        // Whether the strict run flags anything is left open (it hinges on
        // numerical precision), but both runs must agree on the scores.
        assert_eq!(det_strict.len(), det_lax.len());
        for (strict, lax) in det_strict.iter().zip(&det_lax) {
            assert_eq!(strict.entity, lax.entity);
            assert!((strict.drift_score - lax.drift_score).abs() < f64::EPSILON);
        }
    }

    /// Autocorrelation should recover the period of a sampled sine.
    #[test]
    fn test_dominant_period() {
        let period: usize = 6;
        let n: usize = 48;
        let signal: Vec<f32> = (0..n)
            .map(|i| ((i % period) as f32 * std::f32::consts::PI * 2.0 / period as f32).sin())
            .collect();

        let result = find_dominant_period(&signal);
        assert!(
            (4..=8).contains(&result),
            "Expected period near 6, got {result}"
        );
    }

    /// Entities below the access minimum are excluded from seasonal analysis.
    #[test]
    fn test_temporal_min_accesses_filter() {
        let n_buckets = 12;
        // Very few accesses.
        let mut data = vec![0.0_f32; n_buckets];
        data[0] = 1.0;

        let tensor = make_single_entity_tensor("user:alice", "s1", &data);
        let config = TemporalAnalysisConfig {
            min_accesses: 10, // require at least 10
            ..TemporalAnalysisConfig::default()
        };
        let report = analyze_temporal_patterns(&tensor, config);
        assert!(
            report.seasonal_patterns.is_empty(),
            "Entity with 1 access should be filtered"
        );
    }

    // ===== Test helpers =====

    /// Builds a tensor with no entities, no secrets and no buckets.
    fn empty_tensor() -> AccessTensor {
        AccessTensor {
            entity_index: HashMap::new(),
            secret_index: HashMap::new(),
            data: Vec::new(),
            dimensions: (0, 0, 0),
            config: crate::access_tensor::AccessTensorConfig::default(),
        }
    }

    /// Builds a 1 x 1 x `data.len()` tensor holding a single entity/secret
    /// pair whose per-bucket values are `data`.
    fn make_single_entity_tensor(entity: &str, secret: &str, data: &[f32]) -> AccessTensor {
        let n_buckets = data.len();
        let mut entity_index = HashMap::new();
        entity_index.insert(entity.to_string(), 0);
        let mut secret_index = HashMap::new();
        secret_index.insert(secret.to_string(), 0);

        AccessTensor {
            entity_index,
            secret_index,
            data: data.to_vec(),
            dimensions: (1, 1, n_buckets),
            config: crate::access_tensor::AccessTensorConfig {
                num_buckets: n_buckets,
                ..crate::access_tensor::AccessTensorConfig::default()
            },
        }
    }
}