1use serde::{Deserialize, Serialize};
9use tensor_compress::tensor_train::{tt_decompose, tt_reconstruct, TTConfig};
10
11use crate::access_tensor::AccessTensor;
12
/// Tuning knobs for [`analyze_temporal_patterns`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalAnalysisConfig {
    /// Tensor-train settings used for seasonal compression. When `None`,
    /// `analyze_temporal_patterns` substitutes `max_rank = 4`,
    /// `tolerance = 1e-4` and an empty shape. An empty shape, or one whose
    /// product does not match an entity's vector length, is replaced by an
    /// automatically chosen two-factor shape.
    pub tt_config: Option<TTConfig>,
    /// Number of most-recent time buckets forming the "recent" window in
    /// drift detection; all earlier buckets form the historical baseline.
    /// Drift detection produces nothing when the tensor has fewer than
    /// `2 * drift_window` buckets.
    pub drift_window: usize,
    /// An entity is flagged as drifting when its drift score is strictly
    /// greater than this value.
    pub drift_threshold: f64,
    /// Minimum total access count (the entity vector's sum, truncated to an
    /// integer) required for an entity to enter seasonal analysis.
    pub min_accesses: u64,
}
25
26impl Default for TemporalAnalysisConfig {
27 fn default() -> Self {
28 Self {
29 tt_config: None,
30 drift_window: 24,
31 drift_threshold: 0.3,
32 min_accesses: 5,
33 }
34 }
35}
36
/// Seasonal summary for one entity, derived from a tensor-train (TT)
/// compression of that entity's access vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeasonalPattern {
    /// Entity identifier, as yielded by `AccessTensor::entities`.
    pub entity: String,
    /// Full-length signal rebuilt by `tt_reconstruct` from the decomposition.
    /// Despite the name this is the reconstructed vector, not the TT cores.
    pub compressed_pattern: Vec<f32>,
    /// Lag (in time buckets) with the strongest autocorrelation of the
    /// per-bucket aggregate, as returned by `find_dominant_period`; 0 when no
    /// period could be determined.
    pub dominant_period: usize,
    /// Total number of values stored in the TT cores divided by the original
    /// vector length; smaller means stronger compression.
    pub compression_ratio: f32,
    /// Relative L2 reconstruction error `||x - x_hat|| / ||x||`; 0 when the
    /// original vector's norm is at or below `f32::EPSILON`.
    pub reconstruction_error: f32,
}
51
/// Drift verdict for one entity, comparing its most recent time window with
/// its historical baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetection {
    /// Entity identifier.
    pub entity: String,
    /// The larger of (a) the cosine distance between the historical and
    /// recent per-secret mean-rate vectors and (b) the relative change in
    /// their summed rates. Larger values mean more change.
    pub drift_score: f64,
    /// `true` when `drift_score` is strictly greater than the configured
    /// threshold.
    pub is_drifting: bool,
    /// Secrets whose recent access volume differed markedly from the
    /// historical baseline (the exact criterion lives in `detect_drift`).
    pub changed_secrets: Vec<String>,
}
64
/// Combined output of [`analyze_temporal_patterns`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalAnalysisReport {
    /// One entry per entity that passed the seasonal filters and was
    /// successfully decomposed.
    pub seasonal_patterns: Vec<SeasonalPattern>,
    /// One entry per entity evaluated for drift.
    pub drift_detections: Vec<DriftDetection>,
    /// Number of entities covered by this report (see
    /// `analyze_temporal_patterns` for how it is derived).
    pub total_entities_analyzed: usize,
    /// Arithmetic mean of `compression_ratio` over `seasonal_patterns`, or
    /// 0.0 when there are none.
    pub mean_compression_ratio: f32,
}
77
78pub fn analyze_temporal_patterns(
80 tensor: &AccessTensor,
81 config: TemporalAnalysisConfig,
82) -> TemporalAnalysisReport {
83 let tt_config = config.tt_config.clone().unwrap_or(TTConfig {
84 shape: vec![],
85 max_rank: 4,
86 tolerance: 1e-4,
87 });
88
89 let seasonal = extract_seasonal_patterns(tensor, &tt_config, config.min_accesses);
90 let drift = detect_drift(tensor, config.drift_window, config.drift_threshold);
91
92 let mean_compression = if seasonal.is_empty() {
93 0.0
94 } else {
95 #[allow(clippy::cast_precision_loss)] let count = seasonal.len() as f32;
97 seasonal.iter().map(|s| s.compression_ratio).sum::<f32>() / count
98 };
99
100 let total = seasonal.len().max(drift.len());
101
102 TemporalAnalysisReport {
103 seasonal_patterns: seasonal,
104 drift_detections: drift,
105 total_entities_analyzed: total,
106 mean_compression_ratio: mean_compression,
107 }
108}
109
/// Splits `n` into the most balanced two-factor shape `[f, n / f]` with
/// `2 <= f <= n / f`, for use as a tensor-train shape.
///
/// Returns `None` when `n < 4` or when `n` is prime, i.e. when no such
/// factorization exists.
fn factorize_for_tt(n: usize) -> Option<Vec<usize>> {
    if n < 4 {
        return None;
    }
    // Integer square root avoids the usize -> f64 -> usize round-trip, which
    // loses precision for very large `n`. Any divisor `f <= isqrt(n)` has a
    // cofactor `n / f >= f >= 2`, so no extra bounds check is needed.
    (2..=n.isqrt())
        .rev()
        .find(|&f| n.is_multiple_of(f))
        .map(|f| vec![f, n / f])
}
127
128fn extract_seasonal_patterns(
129 tensor: &AccessTensor,
130 tt_config: &TTConfig,
131 min_accesses: u64,
132) -> Vec<SeasonalPattern> {
133 let mut patterns = Vec::new();
134
135 for entity in tensor.entities() {
136 let vec = tensor.entity_vector(&entity);
137 if vec.is_empty() {
138 continue;
139 }
140
141 let total: f32 = vec.iter().sum();
142 #[allow(clippy::cast_sign_loss)]
143 let total_u64 = total as u64;
144 if total_u64 < min_accesses {
145 continue;
146 }
147
148 let len = vec.len();
150 let shape = if tt_config.shape.is_empty() {
151 match factorize_for_tt(len) {
152 Some(s) => s,
153 None => continue,
154 }
155 } else if tt_config.shape.iter().product::<usize>() == len {
156 tt_config.shape.clone()
157 } else {
158 match factorize_for_tt(len) {
159 Some(s) => s,
160 None => continue,
161 }
162 };
163
164 let config = TTConfig {
165 shape,
166 max_rank: tt_config.max_rank,
167 tolerance: tt_config.tolerance,
168 };
169
170 let Ok(tt_vec) = tt_decompose(&vec, &config) else {
171 continue;
172 };
173
174 let reconstructed = tt_reconstruct(&tt_vec);
175
176 let orig_norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt();
178 let error_norm: f32 = vec
179 .iter()
180 .zip(reconstructed.iter())
181 .map(|(a, b)| (a - b).powi(2))
182 .sum::<f32>()
183 .sqrt();
184 let reconstruction_error = if orig_norm > f32::EPSILON {
185 error_norm / orig_norm
186 } else {
187 0.0
188 };
189
190 let compressed_size: usize = tt_vec.cores.iter().map(|c| c.data.len()).sum();
192 #[allow(clippy::cast_precision_loss)] let compression_ratio = compressed_size as f32 / len as f32;
194
195 let (_, _, n_buckets) = tensor.dimensions();
197 let bucket_pattern = aggregate_entity_buckets(&vec, n_buckets);
198 let dominant_period = find_dominant_period(&bucket_pattern);
199
200 patterns.push(SeasonalPattern {
201 entity,
202 compressed_pattern: reconstructed,
203 dominant_period,
204 compression_ratio,
205 reconstruction_error,
206 });
207 }
208
209 patterns
210}
211
/// Sums an entity vector into one total per time bucket.
///
/// `entity_vec` is read as consecutive runs of `n_buckets` values (element
/// `s * n_buckets + b` belongs to run `s`, bucket `b`). Trailing values that do
/// not fill a complete run are ignored. Returns an empty vector when
/// `n_buckets` is 0.
fn aggregate_entity_buckets(entity_vec: &[f32], n_buckets: usize) -> Vec<f32> {
    if n_buckets == 0 {
        return Vec::new();
    }

    entity_vec
        .chunks_exact(n_buckets)
        .fold(vec![0.0_f32; n_buckets], |mut totals, run| {
            totals
                .iter_mut()
                .zip(run)
                .for_each(|(acc, &value)| *acc += value);
            totals
        })
}
226
227fn detect_drift(tensor: &AccessTensor, window: usize, threshold: f64) -> Vec<DriftDetection> {
228 let mut detections = Vec::new();
229 let (_, _, n_buckets) = tensor.dimensions();
230 if n_buckets < window * 2 {
231 return detections;
232 }
233
234 let historical_end = n_buckets - window;
235
236 for entity in tensor.entities() {
237 let vec = tensor.entity_vector(&entity);
238 if vec.is_empty() {
239 continue;
240 }
241
242 let total: f32 = vec.iter().sum();
243 if total < 1.0 {
244 continue;
245 }
246
247 let secrets = tensor.secrets();
249 let mut changed = Vec::new();
250 let mut hist_means = Vec::new();
251 let mut recent_means = Vec::new();
252 let hist_len = historical_end as f32;
253 let recent_len = window as f32;
254
255 for secret in &secrets {
256 let ts = tensor.time_series(&entity, secret);
257 if ts.len() < n_buckets {
258 continue;
259 }
260
261 let hist = &ts[..historical_end];
262 let recent = &ts[historical_end..];
263
264 let hist_mean = hist.iter().sum::<f32>() / hist_len.max(1.0);
265 let recent_mean = recent.iter().sum::<f32>() / recent_len.max(1.0);
266 hist_means.push(hist_mean);
267 recent_means.push(recent_mean);
268
269 let hist_sum: f32 = hist.iter().sum();
271 let recent_sum: f32 = recent.iter().sum();
272 if (hist_sum - recent_sum).abs() > hist_sum.max(1.0) * 0.5 {
273 changed.push(secret.clone());
274 }
275 }
276
277 let cos_dist = cosine_distance(&hist_means, &recent_means);
279 let hist_norm: f64 = hist_means.iter().map(|x| f64::from(*x)).sum();
280 let recent_norm: f64 = recent_means.iter().map(|x| f64::from(*x)).sum();
281 let denom = hist_norm.max(recent_norm).max(f64::EPSILON);
282 let magnitude_shift = (recent_norm - hist_norm).abs() / denom;
283 let drift_score = cos_dist.max(magnitude_shift);
284 let is_drifting = drift_score > threshold;
285
286 detections.push(DriftDetection {
287 entity,
288 drift_score,
289 is_drifting,
290 changed_secrets: changed,
291 });
292 }
293
294 detections
295}
296
/// Estimates the dominant period of `time_series` as the lag with the highest
/// normalized autocorrelation.
///
/// Lags from 2 up to `len / 2` are scanned. The score at each lag is the sum
/// of products of mean-centred samples `lag` apart, divided by the total sum
/// of squared deviations. On ties the smallest lag wins.
///
/// Returns 0 when the series has fewer than 4 samples or its squared-deviation
/// sum is below `f32::EPSILON`. For finite input it otherwise always returns a
/// lag in `2..=len / 2`, even if every score is negative.
pub fn find_dominant_period(time_series: &[f32]) -> usize {
    let len = time_series.len();
    if len < 4 {
        return 0;
    }

    #[allow(clippy::cast_precision_loss)]
    let mean: f32 = time_series.iter().sum::<f32>() / len as f32;
    let deviations: Vec<f32> = time_series.iter().map(|x| x - mean).collect();
    let energy: f32 = deviations.iter().map(|d| d * d).sum();

    if energy < f32::EPSILON {
        return 0;
    }

    let mut best = (0_usize, f32::NEG_INFINITY);
    for lag in 2..=len / 2 {
        let score = deviations
            .iter()
            .zip(&deviations[lag..])
            .map(|(a, b)| a * b)
            .sum::<f32>()
            / energy;
        if score > best.1 {
            best = (lag, score);
        }
    }

    best.0
}
333
/// Cosine distance `1 - (a . b) / (|a| * |b|)` between two equal-length
/// vectors, accumulated in `f64`. The result lies in roughly `[0, 2]`.
///
/// Returns 1.0 (as if the vectors were orthogonal) when the lengths differ,
/// the inputs are empty, or either norm is below `f64::EPSILON`.
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 1.0;
    }

    // Dot product and both squared norms gathered in a single pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();

    if norm_a < f64::EPSILON || norm_b < f64::EPSILON {
        return 1.0;
    }

    1.0 - dot / (norm_a * norm_b)
}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty tensor yields an empty report.
    #[test]
    fn test_seasonal_empty() {
        let config = TemporalAnalysisConfig::default();
        let report = analyze_temporal_patterns(&empty_tensor(), config);
        assert!(report.seasonal_patterns.is_empty());
        assert_eq!(report.total_entities_analyzed, 0);
    }

    /// A clean periodic signal should reconstruct with bounded error whenever
    /// the TT decomposition succeeds.
    #[test]
    fn test_seasonal_periodic_signal() {
        let period = 6;
        let n_buckets = 24;
        let mut data = vec![0.0_f32; n_buckets];
        for (i, d) in data.iter_mut().enumerate() {
            *d = ((i % period) as f32 * std::f32::consts::PI / period as f32).sin() + 1.0;
        }

        let tensor = make_single_entity_tensor("user:alice", "secret1", &data);
        let config = TemporalAnalysisConfig {
            min_accesses: 1,
            ..TemporalAnalysisConfig::default()
        };
        let report = analyze_temporal_patterns(&tensor, config);
        // The decomposition is allowed to fail (the entity is then skipped);
        // only judge reconstruction quality when a pattern was produced.
        if let Some(pattern) = report.seasonal_patterns.first() {
            assert!(
                pattern.reconstruction_error < 1.0,
                "Periodic signal should compress well"
            );
        }
    }

    /// A single aperiodic entity is counted at most once.
    #[test]
    fn test_seasonal_random_high_error() {
        let n_buckets = 12;
        let data: Vec<f32> = (0..n_buckets).map(|i| ((i * 7 + 3) % 11) as f32).collect();

        let tensor = make_single_entity_tensor("user:alice", "secret1", &data);
        let config = TemporalAnalysisConfig {
            min_accesses: 1,
            ..TemporalAnalysisConfig::default()
        };
        let report = analyze_temporal_patterns(&tensor, config);
        assert!(report.total_entities_analyzed <= 1);
    }

    /// A constant access rate must not be reported as drifting.
    #[test]
    fn test_drift_stable_entity() {
        let n_buckets = 48;
        let data = vec![1.0_f32; n_buckets];
        let tensor = make_single_entity_tensor("user:alice", "s1", &data);
        let detections = detect_drift(&tensor, 12, 0.3);
        assert!(!detections.is_empty());
        for d in &detections {
            assert!(!d.is_drifting, "Uniform pattern should not drift");
        }
    }

    /// A tenfold jump in the last 12 buckets yields a positive drift score.
    #[test]
    fn test_drift_changed_entity() {
        let n_buckets = 48;
        let mut data = vec![0.0_f32; n_buckets];
        for d in data.iter_mut().take(36) {
            *d = 1.0;
        }
        for d in data.iter_mut().skip(36) {
            *d = 10.0;
        }

        let tensor = make_single_entity_tensor("user:alice", "s1", &data);
        let detections = detect_drift(&tensor, 12, 0.01);
        assert!(!detections.is_empty());
        let alice = detections.iter().find(|d| d.entity == "user:alice");
        assert!(alice.is_some());
        if let Some(d) = alice {
            assert!(d.drift_score > 0.0);
        }
    }

    /// Identical historical and recent rates score (near) zero. `is_drifting`
    /// uses a strict `>` comparison against the threshold.
    #[test]
    fn test_drift_threshold_boundary() {
        let n_buckets = 48;
        let data = vec![1.0_f32; n_buckets];
        let tensor = make_single_entity_tensor("user:alice", "s1", &data);

        let det_strict = detect_drift(&tensor, 12, 0.0);
        assert!(!det_strict.is_empty());
        for d in &det_strict {
            assert!(d.drift_score.abs() < 1e-6, "got {}", d.drift_score);
            assert_eq!(d.is_drifting, d.drift_score > 0.0);
        }

        let det_lax = detect_drift(&tensor, 12, 2.0);
        for d in &det_lax {
            assert!(!d.is_drifting);
        }
    }

    /// A period-6 sine should be detected at (or near) lag 6.
    #[test]
    fn test_dominant_period() {
        let period = 6;
        let n = 48;
        let signal: Vec<f32> = (0..n)
            .map(|i| ((i % period) as f32 * std::f32::consts::PI * 2.0 / period as f32).sin())
            .collect();

        let result = find_dominant_period(&signal);
        assert!(
            (4..=8).contains(&result),
            "Expected period near 6, got {result}"
        );
    }

    /// Entities below `min_accesses` are excluded from seasonal analysis.
    #[test]
    fn test_temporal_min_accesses_filter() {
        let n_buckets = 12;
        let mut data = vec![0.0_f32; n_buckets];
        data[0] = 1.0;

        let tensor = make_single_entity_tensor("user:alice", "s1", &data);
        let config = TemporalAnalysisConfig {
            min_accesses: 10,
            ..TemporalAnalysisConfig::default()
        };
        let report = analyze_temporal_patterns(&tensor, config);
        assert!(
            report.seasonal_patterns.is_empty(),
            "Entity with 1 access should be filtered"
        );
    }

    /// Builds a tensor with no entities, secrets or data.
    fn empty_tensor() -> AccessTensor {
        AccessTensor {
            entity_index: std::collections::HashMap::new(),
            secret_index: std::collections::HashMap::new(),
            data: Vec::new(),
            dimensions: (0, 0, 0),
            config: crate::access_tensor::AccessTensorConfig::default(),
        }
    }

    /// Builds a 1 x 1 x `data.len()` tensor holding one entity/secret series.
    fn make_single_entity_tensor(entity: &str, secret: &str, data: &[f32]) -> AccessTensor {
        let n_buckets = data.len();
        let mut entity_index = std::collections::HashMap::new();
        entity_index.insert(entity.to_string(), 0);
        let mut secret_index = std::collections::HashMap::new();
        secret_index.insert(secret.to_string(), 0);

        AccessTensor {
            entity_index,
            secret_index,
            data: data.to_vec(),
            dimensions: (1, 1, n_buckets),
            config: crate::access_tensor::AccessTensorConfig {
                num_buckets: n_buckets,
                ..crate::access_tensor::AccessTensorConfig::default()
            },
        }
    }
}