1use super::metrics::{MetricType, MetricValue};
4use crate::topology::WindowCoherence;
5use crate::info_bottleneck::KLDivergence;
6use crate::pde_attention::GraphLaplacian;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ReportConfig {
12 pub ot_projections: usize,
14 pub knn_k: usize,
16 pub diffusion_sigma: f32,
18 pub compute_persistence: bool,
20 pub seed: u64,
22}
23
24impl Default for ReportConfig {
25 fn default() -> Self {
26 Self {
27 ot_projections: 8,
28 knn_k: 8,
29 diffusion_sigma: 1.0,
30 compute_persistence: false,
31 seed: 42,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GeometryReport {
39 pub ot_mean_distance: f32,
41 pub topology_coherence: f32,
43 pub h0_death_sum: Option<f32>,
45 pub ib_kl: f32,
47 pub diffusion_energy: f32,
49 pub attention_entropy: f32,
51 pub metrics: Vec<MetricValue>,
53 pub health_score: f32,
55 pub recommendation: AttentionRecommendation,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum AttentionRecommendation {
62 Stable,
64 Cautious,
66 Freeze,
68 IncreaseTemperature,
70 DecreaseTemperature,
72 AddRegularization,
74}
75
76pub struct ReportBuilder {
78 config: ReportConfig,
79}
80
81impl ReportBuilder {
82 pub fn new(config: ReportConfig) -> Self {
84 Self { config }
85 }
86
87 pub fn build(
89 &self,
90 query: &[f32],
91 keys: &[&[f32]],
92 attention_weights: Option<&[f32]>,
93 ib_mean: Option<&[f32]>,
94 ib_log_var: Option<&[f32]>,
95 ) -> GeometryReport {
96 let n = keys.len();
97 if n == 0 {
98 return GeometryReport::empty();
99 }
100
101 let _dim = keys[0].len();
102
103 let ot_mean = self.compute_ot_distance(query, keys);
105
106 let coherence = self.compute_coherence(keys);
108
109 let h0_sum = if self.config.compute_persistence {
111 Some(self.compute_h0_persistence(keys))
112 } else {
113 None
114 };
115
116 let ib_kl = match (ib_mean, ib_log_var) {
118 (Some(m), Some(v)) => KLDivergence::gaussian_to_unit_arrays(m, v),
119 _ => 0.0,
120 };
121
122 let diffusion_energy = self.compute_diffusion_energy(query, keys);
124
125 let entropy = match attention_weights {
127 Some(w) => self.compute_entropy(w),
128 None => (n as f32).ln(), };
130
131 let mut metrics = vec![
133 MetricValue::new(MetricType::OTDistance, ot_mean, 0.0, 10.0, 5.0, 8.0),
134 MetricValue::new(MetricType::TopologyCoherence, coherence, 0.0, 1.0, 0.3, 0.1),
135 MetricValue::new(MetricType::IBKL, ib_kl, 0.0, 100.0, 50.0, 80.0),
136 MetricValue::new(MetricType::DiffusionEnergy, diffusion_energy, 0.0, 100.0, 50.0, 80.0),
137 MetricValue::new(MetricType::AttentionEntropy, entropy, 0.0, (n as f32).ln().max(1.0), 0.5, 0.2),
138 ];
139
140 if let Some(h0) = h0_sum {
141 metrics.push(MetricValue::new(MetricType::H0Persistence, h0, 0.0, 100.0, 50.0, 80.0));
142 }
143
144 let health_score = self.compute_health_score(&metrics);
146
147 let recommendation = self.determine_recommendation(&metrics, coherence, entropy, n);
149
150 GeometryReport {
151 ot_mean_distance: ot_mean,
152 topology_coherence: coherence,
153 h0_death_sum: h0_sum,
154 ib_kl,
155 diffusion_energy,
156 attention_entropy: entropy,
157 metrics,
158 health_score,
159 recommendation,
160 }
161 }
162
163 fn compute_ot_distance(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
165 let dim = query.len();
166 let n = keys.len();
167 if n == 0 {
168 return 0.0;
169 }
170
171 let mut rng_state = self.config.seed;
173 let projections: Vec<Vec<f32>> = (0..self.config.ot_projections)
174 .map(|_| self.random_unit_vector(dim, &mut rng_state))
175 .collect();
176
177 let q_projs: Vec<f32> = projections.iter()
179 .map(|p| Self::dot(query, p))
180 .collect();
181
182 let mut total = 0.0f32;
184 for key in keys {
185 let mut dist = 0.0f32;
186 for (i, proj) in projections.iter().enumerate() {
187 let k_proj = Self::dot(key, proj);
188 dist += (q_projs[i] - k_proj).abs();
189 }
190 total += dist / self.config.ot_projections as f32;
191 }
192
193 total / n as f32
194 }
195
196 fn compute_coherence(&self, keys: &[&[f32]]) -> f32 {
198 use crate::topology::CoherenceMetric;
199
200 let coherence = WindowCoherence::compute(
201 keys,
202 self.config.knn_k,
203 &[CoherenceMetric::BoundaryMass, CoherenceMetric::SimilarityVariance],
204 );
205
206 coherence.score
207 }
208
209 fn compute_h0_persistence(&self, keys: &[&[f32]]) -> f32 {
211 let n = keys.len();
212 if n <= 1 {
213 return 0.0;
214 }
215
216 let mut edges: Vec<(f32, usize, usize)> = Vec::new();
218 for i in 0..n {
219 for j in (i + 1)..n {
220 let dist = Self::l2_distance(keys[i], keys[j]);
221 edges.push((dist, i, j));
222 }
223 }
224
225 edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
226
227 let mut parent: Vec<usize> = (0..n).collect();
229 let mut rank = vec![0u8; n];
230 let mut deaths = Vec::new();
231
232 fn find(parent: &mut [usize], x: usize) -> usize {
233 if parent[x] != x {
234 parent[x] = find(parent, parent[x]);
235 }
236 parent[x]
237 }
238
239 fn union(parent: &mut [usize], rank: &mut [u8], a: usize, b: usize) -> bool {
240 let mut ra = find(parent, a);
241 let mut rb = find(parent, b);
242 if ra == rb {
243 return false;
244 }
245 if rank[ra] < rank[rb] {
246 std::mem::swap(&mut ra, &mut rb);
247 }
248 parent[rb] = ra;
249 if rank[ra] == rank[rb] {
250 rank[ra] += 1;
251 }
252 true
253 }
254
255 for (w, i, j) in edges {
256 if union(&mut parent, &mut rank, i, j) {
257 deaths.push(w);
258 if deaths.len() == n - 1 {
259 break;
260 }
261 }
262 }
263
264 if !deaths.is_empty() {
266 deaths.pop();
267 }
268
269 deaths.iter().sum()
270 }
271
272 fn compute_diffusion_energy(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
274 use crate::pde_attention::LaplacianType;
275
276 let n = keys.len();
277 if n == 0 {
278 return 0.0;
279 }
280
281 let x: Vec<f32> = keys.iter()
283 .map(|k| Self::dot(query, k))
284 .collect();
285
286 let lap = GraphLaplacian::from_keys(keys, self.config.diffusion_sigma, LaplacianType::Unnormalized);
288
289 let lx = lap.apply(&x);
291 Self::dot(&x, &lx)
292 }
293
294 fn compute_entropy(&self, weights: &[f32]) -> f32 {
296 let eps = 1e-10;
297 let mut entropy = 0.0f32;
298
299 for &w in weights {
300 if w > eps {
301 entropy -= w * w.ln();
302 }
303 }
304
305 entropy.max(0.0)
306 }
307
308 fn compute_health_score(&self, metrics: &[MetricValue]) -> f32 {
310 if metrics.is_empty() {
311 return 1.0;
312 }
313
314 let healthy_count = metrics.iter().filter(|m| m.is_healthy).count();
315 healthy_count as f32 / metrics.len() as f32
316 }
317
318 fn determine_recommendation(
320 &self,
321 metrics: &[MetricValue],
322 coherence: f32,
323 entropy: f32,
324 n: usize,
325 ) -> AttentionRecommendation {
326 let max_entropy = (n as f32).ln().max(1.0);
327 let entropy_ratio = entropy / max_entropy;
328
329 let has_critical = metrics.iter().any(|m| m.is_critical());
331 if has_critical {
332 return AttentionRecommendation::Freeze;
333 }
334
335 if coherence < 0.3 {
337 return AttentionRecommendation::Cautious;
338 }
339
340 if entropy_ratio < 0.2 {
342 return AttentionRecommendation::IncreaseTemperature;
343 }
344
345 if entropy_ratio > 0.9 {
347 return AttentionRecommendation::DecreaseTemperature;
348 }
349
350 let has_warning = metrics.iter().any(|m| m.is_warning());
352 if has_warning {
353 return AttentionRecommendation::AddRegularization;
354 }
355
356 AttentionRecommendation::Stable
357 }
358
359 fn random_unit_vector(&self, dim: usize, state: &mut u64) -> Vec<f32> {
361 let mut v = vec![0.0f32; dim];
362 for i in 0..dim {
363 *state ^= *state << 13;
365 *state ^= *state >> 7;
366 *state ^= *state << 17;
367 let u = (*state & 0x00FF_FFFF) as f32 / 16_777_216.0;
368 v[i] = u * 2.0 - 1.0;
369 }
370
371 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
372 if norm > 0.0 {
373 for x in v.iter_mut() {
374 *x /= norm;
375 }
376 }
377
378 v
379 }
380
/// Dot product of `a` and `b`; elements beyond the shorter slice are ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(x, y)| x * y);
    products.sum()
}
386
/// Euclidean distance between `a` and `b`; elements beyond the shorter slice
/// are ignored.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared.sqrt()
}
396}
397
398impl GeometryReport {
399 pub fn empty() -> Self {
401 Self {
402 ot_mean_distance: 0.0,
403 topology_coherence: 1.0,
404 h0_death_sum: None,
405 ib_kl: 0.0,
406 diffusion_energy: 0.0,
407 attention_entropy: 0.0,
408 metrics: vec![],
409 health_score: 1.0,
410 recommendation: AttentionRecommendation::Stable,
411 }
412 }
413
414 pub fn is_healthy(&self) -> bool {
416 self.health_score > 0.7
417 }
418
419 pub fn warnings(&self) -> Vec<&MetricValue> {
421 self.metrics.iter().filter(|m| m.is_warning()).collect()
422 }
423
424 pub fn criticals(&self) -> Vec<&MetricValue> {
426 self.metrics.iter().filter(|m| m.is_critical()).collect()
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned key as a slice, the shape `build` expects.
    fn as_refs(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_report_builder() {
        let builder = ReportBuilder::new(ReportConfig::default());

        let query = vec![1.0f32; 16];
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 16]).collect();
        let keys_refs = as_refs(&keys);

        let report = builder.build(&query, &keys_refs, None, None, None);

        assert!(report.topology_coherence >= 0.0);
        assert!(report.topology_coherence <= 1.0);
        assert!(report.health_score >= 0.0);
        assert!(report.health_score <= 1.0);

        // Persistence is off by default and no IB parameters were supplied.
        assert!(report.h0_death_sum.is_none());
        assert_eq!(report.ib_kl, 0.0);
        assert_eq!(report.metrics.len(), 5);

        // Without weights the entropy falls back to ln(number of keys).
        assert!((report.attention_entropy - 10f32.ln()).abs() < 1e-6);
    }

    #[test]
    fn test_empty_report() {
        let report = GeometryReport::empty();
        assert!(report.is_healthy());
        assert_eq!(report.recommendation, AttentionRecommendation::Stable);
        assert!(report.warnings().is_empty());
        assert!(report.criticals().is_empty());
    }

    #[test]
    fn test_build_without_keys_returns_empty_report() {
        let builder = ReportBuilder::new(ReportConfig::default());

        let query = vec![1.0f32; 4];
        let keys: Vec<&[f32]> = Vec::new();

        let report = builder.build(&query, &keys, None, None, None);

        assert!(report.metrics.is_empty());
        assert!(report.h0_death_sum.is_none());
        assert_eq!(report.attention_entropy, 0.0);
        assert_eq!(report.health_score, 1.0);
        assert_eq!(report.recommendation, AttentionRecommendation::Stable);
    }

    #[test]
    fn test_with_attention_weights() {
        let builder = ReportBuilder::new(ReportConfig::default());

        let query = vec![1.0f32; 8];
        let keys: Vec<Vec<f32>> = vec![vec![1.0; 8], vec![0.9; 8], vec![0.1; 8]];
        let keys_refs = as_refs(&keys);
        let weights = vec![0.6f32, 0.3, 0.1];

        let report = builder.build(&query, &keys_refs, Some(weights.as_slice()), None, None);

        // -(0.6 ln 0.6 + 0.3 ln 0.3 + 0.1 ln 0.1) = 0.8979457
        assert!((report.attention_entropy - 0.897_945_7).abs() < 1e-4);
    }
}