1use super::metrics::{MetricType, MetricValue};
4use crate::info_bottleneck::KLDivergence;
5use crate::pde_attention::GraphLaplacian;
6use crate::topology::WindowCoherence;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ReportConfig {
12 pub ot_projections: usize,
14 pub knn_k: usize,
16 pub diffusion_sigma: f32,
18 pub compute_persistence: bool,
20 pub seed: u64,
22}
23
24impl Default for ReportConfig {
25 fn default() -> Self {
26 Self {
27 ot_projections: 8,
28 knn_k: 8,
29 diffusion_sigma: 1.0,
30 compute_persistence: false,
31 seed: 42,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GeometryReport {
39 pub ot_mean_distance: f32,
41 pub topology_coherence: f32,
43 pub h0_death_sum: Option<f32>,
45 pub ib_kl: f32,
47 pub diffusion_energy: f32,
49 pub attention_entropy: f32,
51 pub metrics: Vec<MetricValue>,
53 pub health_score: f32,
55 pub recommendation: AttentionRecommendation,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum AttentionRecommendation {
62 Stable,
64 Cautious,
66 Freeze,
68 IncreaseTemperature,
70 DecreaseTemperature,
72 AddRegularization,
74}
75
76pub struct ReportBuilder {
78 config: ReportConfig,
79}
80
81impl ReportBuilder {
82 pub fn new(config: ReportConfig) -> Self {
84 Self { config }
85 }
86
87 pub fn build(
89 &self,
90 query: &[f32],
91 keys: &[&[f32]],
92 attention_weights: Option<&[f32]>,
93 ib_mean: Option<&[f32]>,
94 ib_log_var: Option<&[f32]>,
95 ) -> GeometryReport {
96 let n = keys.len();
97 if n == 0 {
98 return GeometryReport::empty();
99 }
100
101 let _dim = keys[0].len();
102
103 let ot_mean = self.compute_ot_distance(query, keys);
105
106 let coherence = self.compute_coherence(keys);
108
109 let h0_sum = if self.config.compute_persistence {
111 Some(self.compute_h0_persistence(keys))
112 } else {
113 None
114 };
115
116 let ib_kl = match (ib_mean, ib_log_var) {
118 (Some(m), Some(v)) => KLDivergence::gaussian_to_unit_arrays(m, v),
119 _ => 0.0,
120 };
121
122 let diffusion_energy = self.compute_diffusion_energy(query, keys);
124
125 let entropy = match attention_weights {
127 Some(w) => self.compute_entropy(w),
128 None => (n as f32).ln(), };
130
131 let mut metrics = vec![
133 MetricValue::new(MetricType::OTDistance, ot_mean, 0.0, 10.0, 5.0, 8.0),
134 MetricValue::new(MetricType::TopologyCoherence, coherence, 0.0, 1.0, 0.3, 0.1),
135 MetricValue::new(MetricType::IBKL, ib_kl, 0.0, 100.0, 50.0, 80.0),
136 MetricValue::new(
137 MetricType::DiffusionEnergy,
138 diffusion_energy,
139 0.0,
140 100.0,
141 50.0,
142 80.0,
143 ),
144 MetricValue::new(
145 MetricType::AttentionEntropy,
146 entropy,
147 0.0,
148 (n as f32).ln().max(1.0),
149 0.5,
150 0.2,
151 ),
152 ];
153
154 if let Some(h0) = h0_sum {
155 metrics.push(MetricValue::new(
156 MetricType::H0Persistence,
157 h0,
158 0.0,
159 100.0,
160 50.0,
161 80.0,
162 ));
163 }
164
165 let health_score = self.compute_health_score(&metrics);
167
168 let recommendation = self.determine_recommendation(&metrics, coherence, entropy, n);
170
171 GeometryReport {
172 ot_mean_distance: ot_mean,
173 topology_coherence: coherence,
174 h0_death_sum: h0_sum,
175 ib_kl,
176 diffusion_energy,
177 attention_entropy: entropy,
178 metrics,
179 health_score,
180 recommendation,
181 }
182 }
183
184 fn compute_ot_distance(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
186 let dim = query.len();
187 let n = keys.len();
188 if n == 0 {
189 return 0.0;
190 }
191
192 let mut rng_state = self.config.seed;
194 let projections: Vec<Vec<f32>> = (0..self.config.ot_projections)
195 .map(|_| self.random_unit_vector(dim, &mut rng_state))
196 .collect();
197
198 let q_projs: Vec<f32> = projections.iter().map(|p| Self::dot(query, p)).collect();
200
201 let mut total = 0.0f32;
203 for key in keys {
204 let mut dist = 0.0f32;
205 for (i, proj) in projections.iter().enumerate() {
206 let k_proj = Self::dot(key, proj);
207 dist += (q_projs[i] - k_proj).abs();
208 }
209 total += dist / self.config.ot_projections as f32;
210 }
211
212 total / n as f32
213 }
214
215 fn compute_coherence(&self, keys: &[&[f32]]) -> f32 {
217 use crate::topology::CoherenceMetric;
218
219 let coherence = WindowCoherence::compute(
220 keys,
221 self.config.knn_k,
222 &[
223 CoherenceMetric::BoundaryMass,
224 CoherenceMetric::SimilarityVariance,
225 ],
226 );
227
228 coherence.score
229 }
230
231 fn compute_h0_persistence(&self, keys: &[&[f32]]) -> f32 {
233 let n = keys.len();
234 if n <= 1 {
235 return 0.0;
236 }
237
238 let mut edges: Vec<(f32, usize, usize)> = Vec::new();
240 for i in 0..n {
241 for j in (i + 1)..n {
242 let dist = Self::l2_distance(keys[i], keys[j]);
243 edges.push((dist, i, j));
244 }
245 }
246
247 edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
248
249 let mut parent: Vec<usize> = (0..n).collect();
251 let mut rank = vec![0u8; n];
252 let mut deaths = Vec::new();
253
254 fn find(parent: &mut [usize], x: usize) -> usize {
255 if parent[x] != x {
256 parent[x] = find(parent, parent[x]);
257 }
258 parent[x]
259 }
260
261 fn union(parent: &mut [usize], rank: &mut [u8], a: usize, b: usize) -> bool {
262 let mut ra = find(parent, a);
263 let mut rb = find(parent, b);
264 if ra == rb {
265 return false;
266 }
267 if rank[ra] < rank[rb] {
268 std::mem::swap(&mut ra, &mut rb);
269 }
270 parent[rb] = ra;
271 if rank[ra] == rank[rb] {
272 rank[ra] += 1;
273 }
274 true
275 }
276
277 for (w, i, j) in edges {
278 if union(&mut parent, &mut rank, i, j) {
279 deaths.push(w);
280 if deaths.len() == n - 1 {
281 break;
282 }
283 }
284 }
285
286 if !deaths.is_empty() {
288 deaths.pop();
289 }
290
291 deaths.iter().sum()
292 }
293
294 fn compute_diffusion_energy(&self, query: &[f32], keys: &[&[f32]]) -> f32 {
296 use crate::pde_attention::LaplacianType;
297
298 let n = keys.len();
299 if n == 0 {
300 return 0.0;
301 }
302
303 let x: Vec<f32> = keys.iter().map(|k| Self::dot(query, k)).collect();
305
306 let lap = GraphLaplacian::from_keys(
308 keys,
309 self.config.diffusion_sigma,
310 LaplacianType::Unnormalized,
311 );
312
313 let lx = lap.apply(&x);
315 Self::dot(&x, &lx)
316 }
317
318 fn compute_entropy(&self, weights: &[f32]) -> f32 {
320 let eps = 1e-10;
321 let mut entropy = 0.0f32;
322
323 for &w in weights {
324 if w > eps {
325 entropy -= w * w.ln();
326 }
327 }
328
329 entropy.max(0.0)
330 }
331
332 fn compute_health_score(&self, metrics: &[MetricValue]) -> f32 {
334 if metrics.is_empty() {
335 return 1.0;
336 }
337
338 let healthy_count = metrics.iter().filter(|m| m.is_healthy).count();
339 healthy_count as f32 / metrics.len() as f32
340 }
341
342 fn determine_recommendation(
344 &self,
345 metrics: &[MetricValue],
346 coherence: f32,
347 entropy: f32,
348 n: usize,
349 ) -> AttentionRecommendation {
350 let max_entropy = (n as f32).ln().max(1.0);
351 let entropy_ratio = entropy / max_entropy;
352
353 let has_critical = metrics.iter().any(|m| m.is_critical());
355 if has_critical {
356 return AttentionRecommendation::Freeze;
357 }
358
359 if coherence < 0.3 {
361 return AttentionRecommendation::Cautious;
362 }
363
364 if entropy_ratio < 0.2 {
366 return AttentionRecommendation::IncreaseTemperature;
367 }
368
369 if entropy_ratio > 0.9 {
371 return AttentionRecommendation::DecreaseTemperature;
372 }
373
374 let has_warning = metrics.iter().any(|m| m.is_warning());
376 if has_warning {
377 return AttentionRecommendation::AddRegularization;
378 }
379
380 AttentionRecommendation::Stable
381 }
382
383 fn random_unit_vector(&self, dim: usize, state: &mut u64) -> Vec<f32> {
385 let mut v = vec![0.0f32; dim];
386 for i in 0..dim {
387 *state ^= *state << 13;
389 *state ^= *state >> 7;
390 *state ^= *state << 17;
391 let u = (*state & 0x00FF_FFFF) as f32 / 16_777_216.0;
392 v[i] = u * 2.0 - 1.0;
393 }
394
395 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
396 if norm > 0.0 {
397 for x in v.iter_mut() {
398 *x /= norm;
399 }
400 }
401
402 v
403 }
404
405 #[inline]
407 fn dot(a: &[f32], b: &[f32]) -> f32 {
408 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
409 }
410
411 #[inline]
413 fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
414 a.iter()
415 .zip(b.iter())
416 .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
417 .sum::<f32>()
418 .sqrt()
419 }
420}
421
422impl GeometryReport {
423 pub fn empty() -> Self {
425 Self {
426 ot_mean_distance: 0.0,
427 topology_coherence: 1.0,
428 h0_death_sum: None,
429 ib_kl: 0.0,
430 diffusion_energy: 0.0,
431 attention_entropy: 0.0,
432 metrics: vec![],
433 health_score: 1.0,
434 recommendation: AttentionRecommendation::Stable,
435 }
436 }
437
438 pub fn is_healthy(&self) -> bool {
440 self.health_score > 0.7
441 }
442
443 pub fn warnings(&self) -> Vec<&MetricValue> {
445 self.metrics.iter().filter(|m| m.is_warning()).collect()
446 }
447
448 pub fn criticals(&self) -> Vec<&MetricValue> {
450 self.metrics.iter().filter(|m| m.is_critical()).collect()
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned row as a slice, the shape `ReportBuilder::build` expects.
    fn borrow_rows(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_report_builder() {
        let builder = ReportBuilder::new(ReportConfig::default());
        let query = [1.0f32; 16];
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 16]).collect();

        let report = builder.build(&query, &borrow_rows(&keys), None, None, None);

        assert!((0.0..=1.0).contains(&report.topology_coherence));
        assert!((0.0..=1.0).contains(&report.health_score));
    }

    #[test]
    fn test_empty_report() {
        let empty = GeometryReport::empty();
        assert!(empty.is_healthy());
        assert_eq!(empty.recommendation, AttentionRecommendation::Stable);
    }

    #[test]
    fn test_with_attention_weights() {
        let builder = ReportBuilder::new(ReportConfig::default());
        let query = [1.0f32; 8];
        let keys = vec![vec![1.0f32; 8], vec![0.9; 8], vec![0.1; 8]];
        let weights = [0.6f32, 0.3, 0.1];

        let report = builder.build(
            &query,
            &borrow_rows(&keys),
            Some(&weights[..]),
            None,
            None,
        );

        assert!(report.attention_entropy > 0.0);
    }
}