Skip to main content

ruvector_attention/unified_report/
report.rs

1//! Unified Geometry Report Builder
2
3use super::metrics::{MetricType, MetricValue};
4use crate::info_bottleneck::KLDivergence;
5use crate::pde_attention::GraphLaplacian;
6use crate::topology::WindowCoherence;
7use serde::{Deserialize, Serialize};
8
/// Report configuration
///
/// Tunable parameters read by [`ReportBuilder`] while it assembles a
/// [`GeometryReport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReportConfig {
    /// Number of OT projections
    ///
    /// One random unit vector is generated per projection when the OT
    /// distance is computed.
    pub ot_projections: usize,
    /// k for k-NN coherence (forwarded to `WindowCoherence::compute`)
    pub knn_k: usize,
    /// Sigma for diffusion (forwarded to `GraphLaplacian::from_keys`)
    pub diffusion_sigma: f32,
    /// Whether to compute H0 persistence (expensive)
    ///
    /// When `false`, the report's `h0_death_sum` is `None` and no
    /// H0 persistence metric is added.
    pub compute_persistence: bool,
    /// Random seed
    ///
    /// Initial state of the XorShift generator that produces the projection
    /// directions.
    pub seed: u64,
}
23
24impl Default for ReportConfig {
25    fn default() -> Self {
26        Self {
27            ot_projections: 8,
28            knn_k: 8,
29            diffusion_sigma: 1.0,
30            compute_persistence: false,
31            seed: 42,
32        }
33    }
34}
35
/// Unified geometry report
///
/// Result of [`ReportBuilder::build`]: the individual measurements, the same
/// measurements wrapped as [`MetricValue`]s, and the summary derived from them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeometryReport {
    /// OT sliced Wasserstein mean distance
    ///
    /// Mean over keys of the average absolute difference between the projected
    /// query and the projected key.
    pub ot_mean_distance: f32,
    /// Topology coherence score (the `score` returned by `WindowCoherence::compute`)
    pub topology_coherence: f32,
    /// H0 persistence death sum (if computed)
    ///
    /// `None` unless `ReportConfig::compute_persistence` is enabled.
    pub h0_death_sum: Option<f32>,
    /// Information bottleneck KL
    ///
    /// `0.0` when the IB mean or log-variance was not supplied to the builder.
    pub ib_kl: f32,
    /// Diffusion energy (`x^T L x`, with `x` the query-key dot products)
    pub diffusion_energy: f32,
    /// Attention entropy
    ///
    /// Entropy of the supplied attention weights, or `ln(n)` for `n` keys when
    /// no weights were supplied.
    pub attention_entropy: f32,
    /// All metrics with thresholds
    pub metrics: Vec<MetricValue>,
    /// Overall health score (0-1)
    ///
    /// Fraction of `metrics` flagged as healthy; `1.0` when there are none.
    pub health_score: f32,
    /// Recommended action
    pub recommendation: AttentionRecommendation,
}
58
/// Recommended action based on report
///
/// The builder checks the conditions noted on each variant in this order:
/// critical metric, low coherence, low entropy ratio, high entropy ratio,
/// warning metric; the first match wins.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionRecommendation {
    /// Full attention, normal operation
    ///
    /// Chosen when none of the other conditions apply; also used by the empty report.
    Stable,
    /// Reduce attention width
    ///
    /// Chosen when the coherence score is below 0.3.
    Cautious,
    /// Retrieval only, no updates
    ///
    /// Chosen when any metric is critical.
    Freeze,
    /// Increase temperature
    ///
    /// Chosen when entropy is below 20% of the maximum entropy.
    IncreaseTemperature,
    /// Decrease temperature
    ///
    /// Chosen when entropy is above 90% of the maximum entropy.
    DecreaseTemperature,
    /// Add regularization
    ///
    /// Chosen when any metric is in the warning state.
    AddRegularization,
}
75
/// Report builder
///
/// Holds a [`ReportConfig`] and turns a query plus a set of keys into a
/// [`GeometryReport`].
pub struct ReportBuilder {
    // Parameters consulted by every measurement the builder performs.
    config: ReportConfig,
}
80
81impl ReportBuilder {
82    /// Create new report builder
83    pub fn new(config: ReportConfig) -> Self {
84        Self { config }
85    }
86
87    /// Build report from query and keys
88    pub fn build(
89        &self,
90        query: &[f32],
91        keys: &[&[f32]],
92        attention_weights: Option<&[f32]>,
93        ib_mean: Option<&[f32]>,
94        ib_log_var: Option<&[f32]>,
95    ) -> GeometryReport {
96        let n = keys.len();
97        if n == 0 {
98            return GeometryReport::empty();
99        }
100
101        let _dim = keys[0].len();
102
103        // 1. OT distance (simplified sliced Wasserstein)
104        let ot_mean = self.compute_ot_distance(query, keys);
105
106        // 2. Topology coherence
107        let coherence = self.compute_coherence(keys);
108
109        // 3. H0 persistence (optional)
110        let h0_sum = if self.config.compute_persistence {
111            Some(self.compute_h0_persistence(keys))
112        } else {
113            None
114        };
115
116        // 4. IB KL
117        let ib_kl = match (ib_mean, ib_log_var) {
118            (Some(m), Some(v)) => KLDivergence::gaussian_to_unit_arrays(m, v),
119            _ => 0.0,
120        };
121
122        // 5. Diffusion energy
123        let diffusion_energy = self.compute_diffusion_energy(query, keys);
124
125        // 6. Attention entropy
126        let entropy = match attention_weights {
127            Some(w) => self.compute_entropy(w),
128            None => (n as f32).ln(), // Max entropy
129        };
130
131        // Build metrics
132        let mut metrics = vec![
133            MetricValue::new(MetricType::OTDistance, ot_mean, 0.0, 10.0, 5.0, 8.0),
134            MetricValue::new(MetricType::TopologyCoherence, coherence, 0.0, 1.0, 0.3, 0.1),
135            MetricValue::new(MetricType::IBKL, ib_kl, 0.0, 100.0, 50.0, 80.0),
136            MetricValue::new(
137                MetricType::DiffusionEnergy,
138                diffusion_energy,
139                0.0,
140                100.0,
141                50.0,
142                80.0,
143            ),
144            MetricValue::new(
145                MetricType::AttentionEntropy,
146                entropy,
147                0.0,
148                (n as f32).ln().max(1.0),
149                0.5,
150                0.2,
151            ),
152        ];
153
154        if let Some(h0) = h0_sum {
155            metrics.push(MetricValue::new(
156                MetricType::H0Persistence,
157                h0,
158                0.0,
159                100.0,
160                50.0,
161                80.0,
162            ));
163        }
164
165        // Compute health score
166        let health_score = self.compute_health_score(&metrics);
167
168        // Determine recommendation
169        let recommendation = self.determine_recommendation(&metrics, coherence, entropy, n);
170
171        GeometryReport {
172            ot_mean_distance: ot_mean,
173            topology_coherence: coherence,
174            h0_death_sum: h0_sum,
175            ib_kl,
176            diffusion_energy,
177            attention_entropy: entropy,
178            metrics,
179            health_score,
180            recommendation,
181        }
182    }
183
184    /// Simplified sliced Wasserstein distance
185    fn compute_ot_distance(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
186        let dim = query.len();
187        let n = keys.len();
188        if n == 0 {
189            return 0.0;
190        }
191
192        // Generate random projections
193        let mut rng_state = self.config.seed;
194        let projections: Vec<Vec<f32>> = (0..self.config.ot_projections)
195            .map(|_| self.random_unit_vector(dim, &mut rng_state))
196            .collect();
197
198        // Project query
199        let q_projs: Vec<f32> = projections.iter().map(|p| Self::dot(query, p)).collect();
200
201        // Mean absolute distance over keys
202        let mut total = 0.0f32;
203        for key in keys {
204            let mut dist = 0.0f32;
205            for (i, proj) in projections.iter().enumerate() {
206                let k_proj = Self::dot(key, proj);
207                dist += (q_projs[i] - k_proj).abs();
208            }
209            total += dist / self.config.ot_projections as f32;
210        }
211
212        total / n as f32
213    }
214
215    /// Compute coherence using WindowCoherence
216    fn compute_coherence(&self, keys: &[&[f32]]) -> f32 {
217        use crate::topology::CoherenceMetric;
218
219        let coherence = WindowCoherence::compute(
220            keys,
221            self.config.knn_k,
222            &[
223                CoherenceMetric::BoundaryMass,
224                CoherenceMetric::SimilarityVariance,
225            ],
226        );
227
228        coherence.score
229    }
230
231    /// Compute H0 persistence (expensive)
232    fn compute_h0_persistence(&self, keys: &[&[f32]]) -> f32 {
233        let n = keys.len();
234        if n <= 1 {
235            return 0.0;
236        }
237
238        // Build distance matrix
239        let mut edges: Vec<(f32, usize, usize)> = Vec::new();
240        for i in 0..n {
241            for j in (i + 1)..n {
242                let dist = Self::l2_distance(keys[i], keys[j]);
243                edges.push((dist, i, j));
244            }
245        }
246
247        edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
248
249        // Union-Find for Kruskal's algorithm
250        let mut parent: Vec<usize> = (0..n).collect();
251        let mut rank = vec![0u8; n];
252        let mut deaths = Vec::new();
253
254        fn find(parent: &mut [usize], x: usize) -> usize {
255            if parent[x] != x {
256                parent[x] = find(parent, parent[x]);
257            }
258            parent[x]
259        }
260
261        fn union(parent: &mut [usize], rank: &mut [u8], a: usize, b: usize) -> bool {
262            let mut ra = find(parent, a);
263            let mut rb = find(parent, b);
264            if ra == rb {
265                return false;
266            }
267            if rank[ra] < rank[rb] {
268                std::mem::swap(&mut ra, &mut rb);
269            }
270            parent[rb] = ra;
271            if rank[ra] == rank[rb] {
272                rank[ra] += 1;
273            }
274            true
275        }
276
277        for (w, i, j) in edges {
278            if union(&mut parent, &mut rank, i, j) {
279                deaths.push(w);
280                if deaths.len() == n - 1 {
281                    break;
282                }
283            }
284        }
285
286        // Remove last (infinite lifetime component)
287        if !deaths.is_empty() {
288            deaths.pop();
289        }
290
291        deaths.iter().sum()
292    }
293
294    /// Compute diffusion energy
295    fn compute_diffusion_energy(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
296        use crate::pde_attention::LaplacianType;
297
298        let n = keys.len();
299        if n == 0 {
300            return 0.0;
301        }
302
303        // Initial logits
304        let x: Vec<f32> = keys.iter().map(|k| Self::dot(query, k)).collect();
305
306        // Build Laplacian
307        let lap = GraphLaplacian::from_keys(
308            keys,
309            self.config.diffusion_sigma,
310            LaplacianType::Unnormalized,
311        );
312
313        // Energy = x^T L x
314        let lx = lap.apply(&x);
315        Self::dot(&x, &lx)
316    }
317
318    /// Compute entropy
319    fn compute_entropy(&self, weights: &[f32]) -> f32 {
320        let eps = 1e-10;
321        let mut entropy = 0.0f32;
322
323        for &w in weights {
324            if w > eps {
325                entropy -= w * w.ln();
326            }
327        }
328
329        entropy.max(0.0)
330    }
331
332    /// Compute overall health score
333    fn compute_health_score(&self, metrics: &[MetricValue]) -> f32 {
334        if metrics.is_empty() {
335            return 1.0;
336        }
337
338        let healthy_count = metrics.iter().filter(|m| m.is_healthy).count();
339        healthy_count as f32 / metrics.len() as f32
340    }
341
342    /// Determine recommendation
343    fn determine_recommendation(
344        &self,
345        metrics: &[MetricValue],
346        coherence: f32,
347        entropy: f32,
348        n: usize,
349    ) -> AttentionRecommendation {
350        let max_entropy = (n as f32).ln().max(1.0);
351        let entropy_ratio = entropy / max_entropy;
352
353        // Check for critical conditions
354        let has_critical = metrics.iter().any(|m| m.is_critical());
355        if has_critical {
356            return AttentionRecommendation::Freeze;
357        }
358
359        // Low coherence = cautious mode
360        if coherence < 0.3 {
361            return AttentionRecommendation::Cautious;
362        }
363
364        // Very low entropy = temperature too low
365        if entropy_ratio < 0.2 {
366            return AttentionRecommendation::IncreaseTemperature;
367        }
368
369        // Very high entropy = temperature too high
370        if entropy_ratio > 0.9 {
371            return AttentionRecommendation::DecreaseTemperature;
372        }
373
374        // Check for warnings
375        let has_warning = metrics.iter().any(|m| m.is_warning());
376        if has_warning {
377            return AttentionRecommendation::AddRegularization;
378        }
379
380        AttentionRecommendation::Stable
381    }
382
383    /// Generate random unit vector
384    fn random_unit_vector(&self, dim: usize, state: &mut u64) -> Vec<f32> {
385        let mut v = vec![0.0f32; dim];
386        for i in 0..dim {
387            // XorShift
388            *state ^= *state << 13;
389            *state ^= *state >> 7;
390            *state ^= *state << 17;
391            let u = (*state & 0x00FF_FFFF) as f32 / 16_777_216.0;
392            v[i] = u * 2.0 - 1.0;
393        }
394
395        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
396        if norm > 0.0 {
397            for x in v.iter_mut() {
398                *x /= norm;
399            }
400        }
401
402        v
403    }
404
405    /// Dot product
406    #[inline]
407    fn dot(a: &[f32], b: &[f32]) -> f32 {
408        a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
409    }
410
411    /// L2 distance
412    #[inline]
413    fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
414        a.iter()
415            .zip(b.iter())
416            .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
417            .sum::<f32>()
418            .sqrt()
419    }
420}
421
422impl GeometryReport {
423    /// Create empty report
424    pub fn empty() -> Self {
425        Self {
426            ot_mean_distance: 0.0,
427            topology_coherence: 1.0,
428            h0_death_sum: None,
429            ib_kl: 0.0,
430            diffusion_energy: 0.0,
431            attention_entropy: 0.0,
432            metrics: vec![],
433            health_score: 1.0,
434            recommendation: AttentionRecommendation::Stable,
435        }
436    }
437
438    /// Check if attention is healthy
439    pub fn is_healthy(&self) -> bool {
440        self.health_score > 0.7
441    }
442
443    /// Get all warning metrics
444    pub fn warnings(&self) -> Vec<&MetricValue> {
445        self.metrics.iter().filter(|m| m.is_warning()).collect()
446    }
447
448    /// Get all critical metrics
449    pub fn criticals(&self) -> Vec<&MetricValue> {
450        self.metrics.iter().filter(|m| m.is_critical()).collect()
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow each owned key as a slice, the shape `ReportBuilder::build` takes.
    fn as_slices(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    /// A default-configured build keeps coherence and health inside [0, 1].
    #[test]
    fn test_report_builder() {
        let query = vec![1.0f32; 16];
        let mut keys: Vec<Vec<f32>> = Vec::with_capacity(10);
        for i in 0..10 {
            keys.push(vec![i as f32 * 0.1; 16]);
        }
        let keys_refs = as_slices(&keys);

        let report =
            ReportBuilder::new(ReportConfig::default()).build(&query, &keys_refs, None, None, None);

        assert!((0.0..=1.0).contains(&report.topology_coherence));
        assert!((0.0..=1.0).contains(&report.health_score));
    }

    /// The empty report is healthy and recommends normal operation.
    #[test]
    fn test_empty_report() {
        let report = GeometryReport::empty();
        assert!(report.is_healthy());
        assert_eq!(report.recommendation, AttentionRecommendation::Stable);
    }

    /// A non-degenerate weight distribution yields strictly positive entropy.
    #[test]
    fn test_with_attention_weights() {
        let query = vec![1.0f32; 8];
        let keys: Vec<Vec<f32>> = vec![vec![1.0; 8], vec![0.9; 8], vec![0.1; 8]];
        let keys_refs = as_slices(&keys);
        let weights: Vec<f32> = vec![0.6, 0.3, 0.1];

        let builder = ReportBuilder::new(ReportConfig::default());
        let report = builder.build(&query, &keys_refs, Some(weights.as_slice()), None, None);

        assert!(report.attention_entropy > 0.0);
    }
}