ruvector_attention/unified_report/
report.rs

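//! Unified geometry report builder.
//!
//! Combines an OT-style projection distance, topology coherence, optional H0
//! persistence, information-bottleneck KL, diffusion energy and attention
//! entropy into a single [`GeometryReport`] carrying a health score and an
//! [`AttentionRecommendation`].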
2
3use super::metrics::{MetricType, MetricValue};
4use crate::topology::WindowCoherence;
5use crate::info_bottleneck::KLDivergence;
6use crate::pde_attention::GraphLaplacian;
7use serde::{Deserialize, Serialize};
8
9/// Report configuration
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ReportConfig {
12    /// Number of OT projections
13    pub ot_projections: usize,
14    /// k for k-NN coherence
15    pub knn_k: usize,
16    /// Sigma for diffusion
17    pub diffusion_sigma: f32,
18    /// Whether to compute H0 persistence (expensive)
19    pub compute_persistence: bool,
20    /// Random seed
21    pub seed: u64,
22}
23
24impl Default for ReportConfig {
25    fn default() -> Self {
26        Self {
27            ot_projections: 8,
28            knn_k: 8,
29            diffusion_sigma: 1.0,
30            compute_persistence: false,
31            seed: 42,
32        }
33    }
34}
35
36/// Unified geometry report
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GeometryReport {
39    /// OT sliced Wasserstein mean distance
40    pub ot_mean_distance: f32,
41    /// Topology coherence score
42    pub topology_coherence: f32,
43    /// H0 persistence death sum (if computed)
44    pub h0_death_sum: Option<f32>,
45    /// Information bottleneck KL
46    pub ib_kl: f32,
47    /// Diffusion energy
48    pub diffusion_energy: f32,
49    /// Attention entropy
50    pub attention_entropy: f32,
51    /// All metrics with thresholds
52    pub metrics: Vec<MetricValue>,
53    /// Overall health score (0-1)
54    pub health_score: f32,
55    /// Recommended action
56    pub recommendation: AttentionRecommendation,
57}
58
59/// Recommended action based on report
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum AttentionRecommendation {
62    /// Full attention, normal operation
63    Stable,
64    /// Reduce attention width
65    Cautious,
66    /// Retrieval only, no updates
67    Freeze,
68    /// Increase temperature
69    IncreaseTemperature,
70    /// Decrease temperature
71    DecreaseTemperature,
72    /// Add regularization
73    AddRegularization,
74}
75
76/// Report builder
77pub struct ReportBuilder {
78    config: ReportConfig,
79}
80
81impl ReportBuilder {
82    /// Create new report builder
83    pub fn new(config: ReportConfig) -> Self {
84        Self { config }
85    }
86
87    /// Build report from query and keys
88    pub fn build(
89        &self,
90        query: &[f32],
91        keys: &[&[f32]],
92        attention_weights: Option<&[f32]>,
93        ib_mean: Option<&[f32]>,
94        ib_log_var: Option<&[f32]>,
95    ) -> GeometryReport {
96        let n = keys.len();
97        if n == 0 {
98            return GeometryReport::empty();
99        }
100
101        let _dim = keys[0].len();
102
103        // 1. OT distance (simplified sliced Wasserstein)
104        let ot_mean = self.compute_ot_distance(query, keys);
105
106        // 2. Topology coherence
107        let coherence = self.compute_coherence(keys);
108
109        // 3. H0 persistence (optional)
110        let h0_sum = if self.config.compute_persistence {
111            Some(self.compute_h0_persistence(keys))
112        } else {
113            None
114        };
115
116        // 4. IB KL
117        let ib_kl = match (ib_mean, ib_log_var) {
118            (Some(m), Some(v)) => KLDivergence::gaussian_to_unit_arrays(m, v),
119            _ => 0.0,
120        };
121
122        // 5. Diffusion energy
123        let diffusion_energy = self.compute_diffusion_energy(query, keys);
124
125        // 6. Attention entropy
126        let entropy = match attention_weights {
127            Some(w) => self.compute_entropy(w),
128            None => (n as f32).ln(), // Max entropy
129        };
130
131        // Build metrics
132        let mut metrics = vec![
133            MetricValue::new(MetricType::OTDistance, ot_mean, 0.0, 10.0, 5.0, 8.0),
134            MetricValue::new(MetricType::TopologyCoherence, coherence, 0.0, 1.0, 0.3, 0.1),
135            MetricValue::new(MetricType::IBKL, ib_kl, 0.0, 100.0, 50.0, 80.0),
136            MetricValue::new(MetricType::DiffusionEnergy, diffusion_energy, 0.0, 100.0, 50.0, 80.0),
137            MetricValue::new(MetricType::AttentionEntropy, entropy, 0.0, (n as f32).ln().max(1.0), 0.5, 0.2),
138        ];
139
140        if let Some(h0) = h0_sum {
141            metrics.push(MetricValue::new(MetricType::H0Persistence, h0, 0.0, 100.0, 50.0, 80.0));
142        }
143
144        // Compute health score
145        let health_score = self.compute_health_score(&metrics);
146
147        // Determine recommendation
148        let recommendation = self.determine_recommendation(&metrics, coherence, entropy, n);
149
150        GeometryReport {
151            ot_mean_distance: ot_mean,
152            topology_coherence: coherence,
153            h0_death_sum: h0_sum,
154            ib_kl,
155            diffusion_energy,
156            attention_entropy: entropy,
157            metrics,
158            health_score,
159            recommendation,
160        }
161    }
162
163    /// Simplified sliced Wasserstein distance
164    fn compute_ot_distance(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
165        let dim = query.len();
166        let n = keys.len();
167        if n == 0 {
168            return 0.0;
169        }
170
171        // Generate random projections
172        let mut rng_state = self.config.seed;
173        let projections: Vec<Vec<f32>> = (0..self.config.ot_projections)
174            .map(|_| self.random_unit_vector(dim, &mut rng_state))
175            .collect();
176
177        // Project query
178        let q_projs: Vec<f32> = projections.iter()
179            .map(|p| Self::dot(query, p))
180            .collect();
181
182        // Mean absolute distance over keys
183        let mut total = 0.0f32;
184        for key in keys {
185            let mut dist = 0.0f32;
186            for (i, proj) in projections.iter().enumerate() {
187                let k_proj = Self::dot(key, proj);
188                dist += (q_projs[i] - k_proj).abs();
189            }
190            total += dist / self.config.ot_projections as f32;
191        }
192
193        total / n as f32
194    }
195
196    /// Compute coherence using WindowCoherence
197    fn compute_coherence(&self, keys: &[&[f32]]) -> f32 {
198        use crate::topology::CoherenceMetric;
199
200        let coherence = WindowCoherence::compute(
201            keys,
202            self.config.knn_k,
203            &[CoherenceMetric::BoundaryMass, CoherenceMetric::SimilarityVariance],
204        );
205
206        coherence.score
207    }
208
209    /// Compute H0 persistence (expensive)
210    fn compute_h0_persistence(&self, keys: &[&[f32]]) -> f32 {
211        let n = keys.len();
212        if n <= 1 {
213            return 0.0;
214        }
215
216        // Build distance matrix
217        let mut edges: Vec<(f32, usize, usize)> = Vec::new();
218        for i in 0..n {
219            for j in (i + 1)..n {
220                let dist = Self::l2_distance(keys[i], keys[j]);
221                edges.push((dist, i, j));
222            }
223        }
224
225        edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
226
227        // Union-Find for Kruskal's algorithm
228        let mut parent: Vec<usize> = (0..n).collect();
229        let mut rank = vec![0u8; n];
230        let mut deaths = Vec::new();
231
232        fn find(parent: &mut [usize], x: usize) -> usize {
233            if parent[x] != x {
234                parent[x] = find(parent, parent[x]);
235            }
236            parent[x]
237        }
238
239        fn union(parent: &mut [usize], rank: &mut [u8], a: usize, b: usize) -> bool {
240            let mut ra = find(parent, a);
241            let mut rb = find(parent, b);
242            if ra == rb {
243                return false;
244            }
245            if rank[ra] < rank[rb] {
246                std::mem::swap(&mut ra, &mut rb);
247            }
248            parent[rb] = ra;
249            if rank[ra] == rank[rb] {
250                rank[ra] += 1;
251            }
252            true
253        }
254
255        for (w, i, j) in edges {
256            if union(&mut parent, &mut rank, i, j) {
257                deaths.push(w);
258                if deaths.len() == n - 1 {
259                    break;
260                }
261            }
262        }
263
264        // Remove last (infinite lifetime component)
265        if !deaths.is_empty() {
266            deaths.pop();
267        }
268
269        deaths.iter().sum()
270    }
271
272    /// Compute diffusion energy
273    fn compute_diffusion_energy(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
274        use crate::pde_attention::LaplacianType;
275
276        let n = keys.len();
277        if n == 0 {
278            return 0.0;
279        }
280
281        // Initial logits
282        let x: Vec<f32> = keys.iter()
283            .map(|k| Self::dot(query, k))
284            .collect();
285
286        // Build Laplacian
287        let lap = GraphLaplacian::from_keys(keys, self.config.diffusion_sigma, LaplacianType::Unnormalized);
288
289        // Energy = x^T L x
290        let lx = lap.apply(&x);
291        Self::dot(&x, &lx)
292    }
293
294    /// Compute entropy
295    fn compute_entropy(&self, weights: &[f32]) -> f32 {
296        let eps = 1e-10;
297        let mut entropy = 0.0f32;
298
299        for &w in weights {
300            if w > eps {
301                entropy -= w * w.ln();
302            }
303        }
304
305        entropy.max(0.0)
306    }
307
308    /// Compute overall health score
309    fn compute_health_score(&self, metrics: &[MetricValue]) -> f32 {
310        if metrics.is_empty() {
311            return 1.0;
312        }
313
314        let healthy_count = metrics.iter().filter(|m| m.is_healthy).count();
315        healthy_count as f32 / metrics.len() as f32
316    }
317
318    /// Determine recommendation
319    fn determine_recommendation(
320        &self,
321        metrics: &[MetricValue],
322        coherence: f32,
323        entropy: f32,
324        n: usize,
325    ) -> AttentionRecommendation {
326        let max_entropy = (n as f32).ln().max(1.0);
327        let entropy_ratio = entropy / max_entropy;
328
329        // Check for critical conditions
330        let has_critical = metrics.iter().any(|m| m.is_critical());
331        if has_critical {
332            return AttentionRecommendation::Freeze;
333        }
334
335        // Low coherence = cautious mode
336        if coherence < 0.3 {
337            return AttentionRecommendation::Cautious;
338        }
339
340        // Very low entropy = temperature too low
341        if entropy_ratio < 0.2 {
342            return AttentionRecommendation::IncreaseTemperature;
343        }
344
345        // Very high entropy = temperature too high
346        if entropy_ratio > 0.9 {
347            return AttentionRecommendation::DecreaseTemperature;
348        }
349
350        // Check for warnings
351        let has_warning = metrics.iter().any(|m| m.is_warning());
352        if has_warning {
353            return AttentionRecommendation::AddRegularization;
354        }
355
356        AttentionRecommendation::Stable
357    }
358
359    /// Generate random unit vector
360    fn random_unit_vector(&self, dim: usize, state: &mut u64) -> Vec<f32> {
361        let mut v = vec![0.0f32; dim];
362        for i in 0..dim {
363            // XorShift
364            *state ^= *state << 13;
365            *state ^= *state >> 7;
366            *state ^= *state << 17;
367            let u = (*state & 0x00FF_FFFF) as f32 / 16_777_216.0;
368            v[i] = u * 2.0 - 1.0;
369        }
370
371        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
372        if norm > 0.0 {
373            for x in v.iter_mut() {
374                *x /= norm;
375            }
376        }
377
378        v
379    }
380
381    /// Dot product
382    #[inline]
383    fn dot(a: &[f32], b: &[f32]) -> f32 {
384        a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
385    }
386
387    /// L2 distance
388    #[inline]
389    fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
390        a.iter()
391            .zip(b.iter())
392            .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
393            .sum::<f32>()
394            .sqrt()
395    }
396}
397
398impl GeometryReport {
399    /// Create empty report
400    pub fn empty() -> Self {
401        Self {
402            ot_mean_distance: 0.0,
403            topology_coherence: 1.0,
404            h0_death_sum: None,
405            ib_kl: 0.0,
406            diffusion_energy: 0.0,
407            attention_entropy: 0.0,
408            metrics: vec![],
409            health_score: 1.0,
410            recommendation: AttentionRecommendation::Stable,
411        }
412    }
413
414    /// Check if attention is healthy
415    pub fn is_healthy(&self) -> bool {
416        self.health_score > 0.7
417    }
418
419    /// Get all warning metrics
420    pub fn warnings(&self) -> Vec<&MetricValue> {
421        self.metrics.iter().filter(|m| m.is_warning()).collect()
422    }
423
424    /// Get all critical metrics
425    pub fn criticals(&self) -> Vec<&MetricValue> {
426        self.metrics.iter().filter(|m| m.is_critical()).collect()
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Coherence and health score of a default-config report stay in [0, 1].
    #[test]
    fn test_report_builder() {
        let query = vec![1.0f32; 16];
        let mut keys: Vec<Vec<f32>> = Vec::new();
        for i in 0..10 {
            keys.push(vec![i as f32 * 0.1; 16]);
        }
        let keys_refs: Vec<&[f32]> = keys.iter().map(Vec::as_slice).collect();

        let report = ReportBuilder::new(ReportConfig::default())
            .build(&query, &keys_refs, None, None, None);

        assert!(report.topology_coherence >= 0.0);
        assert!(report.topology_coherence <= 1.0);
        assert!(report.health_score >= 0.0);
        assert!(report.health_score <= 1.0);
    }

    /// The empty report is healthy and recommends `Stable`.
    #[test]
    fn test_empty_report() {
        let empty = GeometryReport::empty();
        assert!(empty.is_healthy());
        assert_eq!(empty.recommendation, AttentionRecommendation::Stable);
    }

    /// Explicit, non-degenerate attention weights give positive entropy.
    #[test]
    fn test_with_attention_weights() {
        let query = vec![1.0f32; 8];
        let keys: Vec<Vec<f32>> = [1.0f32, 0.9, 0.1]
            .iter()
            .map(|&value| vec![value; 8])
            .collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(Vec::as_slice).collect();
        let weights: Vec<f32> = vec![0.6, 0.3, 0.1];

        let builder = ReportBuilder::new(ReportConfig::default());
        let report = builder.build(&query, &keys_refs, Some(weights.as_slice()), None, None);

        assert!(report.attention_entropy > 0.0);
    }
}