1use crate::arena::PhiArena;
17use crate::error::ConsciousnessError;
18use crate::phi::{partition_information_loss_pub, validate_tpm};
19use crate::simd::emd_l1;
20use crate::traits::PhiEngine;
21use crate::types::{Bipartition, ComputeBudget, PhiAlgorithm, PhiResult, TransitionMatrix};
22
23use std::time::Instant;
24
/// Iterator over bipartitions of `n` elements, visited in Gray-code order.
///
/// Each step derives a Gray code from an internal binary counter and turns it
/// into a partition mask whose bit 0 is always set, so a partition and its
/// complement are never both produced.
#[derive(Debug, Clone)]
pub struct GrayCodePartitionIter {
    /// Gray code produced by the most recent step (`0` before the first step).
    current: u64,
    /// Binary counter whose Gray code drives the next step.
    counter: u64,
    /// Exclusive upper bound for `counter`.
    max: u64,
    /// Number of elements being partitioned.
    n: usize,
}
43
44impl GrayCodePartitionIter {
45 pub fn new(n: usize) -> Self {
50 assert!((2..=63).contains(&n));
51 Self {
52 current: 0,
53 counter: 1, max: 1u64 << (n - 1),
55 n,
56 }
57 }
58
59 #[inline]
61 pub fn changed_bit(prev_gray: u64, curr_gray: u64) -> u32 {
62 (prev_gray ^ curr_gray).trailing_zeros()
63 }
64}
65
66impl Iterator for GrayCodePartitionIter {
67 type Item = (Bipartition, u32); fn next(&mut self) -> Option<Self::Item> {
70 if self.counter >= self.max {
71 return None;
72 }
73
74 let prev_gray = self.current;
75 let gray = self.counter ^ (self.counter >> 1);
77 self.current = gray;
78 self.counter += 1;
79
80 let mask = 1u64 | (gray << 1);
82
83 let full = (1u64 << self.n) - 1;
85 if mask == 0 || mask == full {
86 return self.next(); }
88
89 let changed = if prev_gray == 0 {
90 0 } else {
92 Self::changed_bit(prev_gray, gray) + 1 };
94
95 Some((Bipartition { mask, n: self.n }, changed))
96 }
97
98 fn size_hint(&self) -> (usize, Option<usize>) {
99 let remaining = (self.max - self.counter) as usize;
100 (remaining, Some(remaining))
101 }
102}
103
/// Picks one representative mask for a bipartition and its complement.
///
/// The representative is whichever side has fewer set bits; when both sides
/// have the same size, the numerically smaller mask wins. `mask` is expected
/// to use only the low `n` bits.
fn canonical_partition(mask: u64, n: usize) -> u64 {
    let ones = mask.count_ones();
    let zeros = n as u32 - ones;

    // Already the smaller side: nothing to flip.
    if ones < zeros {
        return mask;
    }

    let flipped = ((1u64 << n) - 1) & !mask;
    if ones > zeros {
        flipped
    } else {
        // Equal sizes: break the tie by numeric value.
        mask.min(flipped)
    }
}
130
/// Scores how evenly a mask splits `n` elements.
///
/// Returns `1.0` when exactly half of the elements are in the mask and falls
/// linearly to `0.0` when the mask is empty or full.
#[inline]
fn balance_score(mask: u64, n: usize) -> f64 {
    let midpoint = n as f64 / 2.0;
    let members = f64::from(mask.count_ones());
    let deviation = (members - midpoint).abs();
    1.0 - (deviation / midpoint)
}
143
/// Configuration for the GeoMIP minimum-information-partition search.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeoMipPhiEngine {
    /// When `true`, a partition is skipped if its canonical form (the
    /// representative shared with its complement) has already been handled.
    pub prune_automorphisms: bool,
    /// Maximum number of partitions to evaluate; `0` disables this limit.
    pub max_evaluations: u64,
}
161
162impl GeoMipPhiEngine {
163 pub fn new(prune_automorphisms: bool, max_evaluations: u64) -> Self {
164 Self {
165 prune_automorphisms,
166 max_evaluations,
167 }
168 }
169}
170
171impl Default for GeoMipPhiEngine {
172 fn default() -> Self {
173 Self {
174 prune_automorphisms: true,
175 max_evaluations: 0,
176 }
177 }
178}
179
180impl PhiEngine for GeoMipPhiEngine {
181 fn compute_phi(
182 &self,
183 tpm: &TransitionMatrix,
184 state: Option<usize>,
185 budget: &ComputeBudget,
186 ) -> Result<PhiResult, ConsciousnessError> {
187 validate_tpm(tpm)?;
188 let n = tpm.n;
189
190 if n > 25 {
191 return Err(ConsciousnessError::SystemTooLarge { n, max: 25 });
192 }
193
194 let state_idx = state.unwrap_or(0);
195 let start = Instant::now();
196 let arena = PhiArena::with_capacity(n * n * 16);
197
198 let total_partitions = (1u64 << n) - 2;
199 let mut min_phi = f64::MAX;
200 let mut best_partition = Bipartition { mask: 1, n };
201 let mut evaluated = 0u64;
202 let mut convergence = Vec::new();
203
204 let mut balanced_partitions: Vec<Bipartition> = Vec::new();
206 let half = n / 2;
207 for mask in 1..((1u64 << n) - 1) {
208 let popcount = mask.count_ones() as usize;
209 if (popcount == half || popcount == half + 1)
210 && (!self.prune_automorphisms
211 || canonical_partition(mask, n) == mask)
212 {
213 balanced_partitions.push(Bipartition { mask, n });
214 }
215 }
216
217 balanced_partitions
219 .sort_by(|a, b| balance_score(b.mask, n).partial_cmp(&balance_score(a.mask, n)).unwrap());
220
221 for partition in &balanced_partitions {
222 if self.max_evaluations > 0 && evaluated >= self.max_evaluations {
223 break;
224 }
225 if budget.max_partitions > 0 && evaluated >= budget.max_partitions {
226 break;
227 }
228 if start.elapsed() > budget.max_time {
229 break;
230 }
231
232 let loss = partition_information_loss_pub(tpm, state_idx, partition, &arena);
233 arena.reset();
234
235 if loss < min_phi {
236 min_phi = loss;
237 best_partition = partition.clone();
238 }
239
240 if min_phi < 1e-12 {
242 evaluated += 1;
243 break;
244 }
245
246 evaluated += 1;
247 if evaluated % 500 == 0 {
248 convergence.push(min_phi);
249 }
250 }
251
252 if min_phi > 1e-12 {
254 let mut seen = std::collections::HashSet::new();
255 for bp in &balanced_partitions {
256 seen.insert(bp.mask);
257 }
258
259 for (partition, _changed_bit) in GrayCodePartitionIter::new(n) {
260 if self.max_evaluations > 0 && evaluated >= self.max_evaluations {
261 break;
262 }
263 if budget.max_partitions > 0 && evaluated >= budget.max_partitions {
264 break;
265 }
266 if start.elapsed() > budget.max_time {
267 break;
268 }
269
270 if seen.contains(&partition.mask) {
272 continue;
273 }
274
275 if self.prune_automorphisms {
277 let canon = canonical_partition(partition.mask, n);
278 if canon != partition.mask && seen.contains(&canon) {
279 continue;
280 }
281 seen.insert(partition.mask);
282 }
283
284 let loss = partition_information_loss_pub(tpm, state_idx, &partition, &arena);
285 arena.reset();
286
287 if loss < min_phi {
288 min_phi = loss;
289 best_partition = partition;
290 }
291
292 if min_phi < 1e-12 {
293 evaluated += 1;
294 break;
295 }
296
297 evaluated += 1;
298 if evaluated % 500 == 0 {
299 convergence.push(min_phi);
300 }
301 }
302 }
303
304 convergence.push(min_phi);
305
306 Ok(PhiResult {
307 phi: if min_phi == f64::MAX { 0.0 } else { min_phi },
308 mip: best_partition,
309 partitions_evaluated: evaluated,
310 total_partitions,
311 algorithm: PhiAlgorithm::GeoMIP,
312 elapsed: start.elapsed(),
313 convergence,
314 })
315 }
316
317 fn algorithm(&self) -> PhiAlgorithm {
318 PhiAlgorithm::GeoMIP
319 }
320
321 fn estimate_cost(&self, n: usize) -> u64 {
322 ((1u64 << n) - 2) / 2
324 }
325}
326
327pub fn partition_information_loss_emd(
338 tpm: &TransitionMatrix,
339 state: usize,
340 partition: &Bipartition,
341 arena: &PhiArena,
342) -> f64 {
343 let n = tpm.n;
344 let set_a = partition.set_a();
345 let set_b = partition.set_b();
346
347 let whole_dist = &tpm.data[state * n..(state + 1) * n];
348
349 let tpm_a = tpm.marginalize(&set_a);
350 let tpm_b = tpm.marginalize(&set_b);
351
352 let state_a = map_state_to_subsystem_local(state, &set_a);
353 let state_b = map_state_to_subsystem_local(state, &set_b);
354
355 let dist_a = &tpm_a.data[state_a * tpm_a.n..(state_a + 1) * tpm_a.n];
356 let dist_b = &tpm_b.data[state_b * tpm_b.n..(state_b + 1) * tpm_b.n];
357
358 let product = arena.alloc_slice::<f64>(n);
359 compute_product_local(dist_a, &set_a, dist_b, &set_b, product, n);
360
361 let sum: f64 = product.iter().sum();
362 if sum > 1e-15 {
363 let inv_sum = 1.0 / sum;
364 for p in product.iter_mut() {
365 *p *= inv_sum;
366 }
367 }
368
369 let loss = emd_l1(whole_dist, product).max(0.0);
370 arena.reset();
371 loss
372}
373
/// Projects `state` onto the elements listed in `indices`.
///
/// Bit `idx` of `state` becomes bit `k` of the result, where `k` is the
/// position of `idx` within `indices`. The packed value is then reduced
/// modulo `indices.len()` (modulo 1 for an empty list).
///
/// NOTE(review): the modulo presumably keeps the result inside the row range
/// of the marginal matrix — confirm against `TransitionMatrix::marginalize`.
fn map_state_to_subsystem_local(state: usize, indices: &[usize]) -> usize {
    let packed = indices
        .iter()
        .enumerate()
        .fold(0usize, |acc, (bit, &idx)| {
            if state & (1 << idx) != 0 {
                acc | (1 << bit)
            } else {
                acc
            }
        });
    let modulus = indices.len().max(1);
    packed % modulus
}
383
/// Fills `output[0..n]` with the product of two marginal distributions.
///
/// For each global state, its bits at the positions in `set_a` / `set_b` are
/// packed into a local index. The entry is `dist_a[local_a] * dist_b[local_b]`,
/// where a factor counts as `0.0` when its local index is not smaller than the
/// size of the corresponding set.
fn compute_product_local(
    dist_a: &[f64],
    set_a: &[usize],
    dist_b: &[f64],
    set_b: &[usize],
    output: &mut [f64],
    n: usize,
) {
    /// Packs the bits of `global` selected by `members` into a local index.
    fn project(global: usize, members: &[usize]) -> usize {
        let mut local = 0usize;
        for (bit, &idx) in members.iter().enumerate() {
            if global & (1 << idx) != 0 {
                local |= 1 << bit;
            }
        }
        local
    }

    /// Looks up `dist[local]`, treating indices at or beyond `limit` as zero mass.
    fn factor(dist: &[f64], local: usize, limit: usize) -> f64 {
        if local < limit {
            dist[local]
        } else {
            0.0
        }
    }

    let size_a = set_a.len();
    let size_b = set_b.len();

    for global_state in 0..n {
        let prob_a = factor(dist_a, project(global_state, set_a), size_a);
        let prob_b = factor(dist_b, project(global_state, set_b), size_b);
        output[global_state] = prob_a * prob_b;
    }
}
413
// Unit tests for the Gray-code iterator, partition canonicalisation, the
// GeoMIP engine and the EMD-based loss.
#[cfg(test)]
mod tests {
    use super::*;

    /// 4x4 fixture named after an AND gate: rows 0-2 share one distribution
    /// and row 3 puts all mass on column 3.
    fn and_gate_tpm() -> TransitionMatrix {
        #[rustfmt::skip]
        let data = vec![
            0.5, 0.25, 0.25, 0.0,
            0.5, 0.25, 0.25, 0.0,
            0.5, 0.25, 0.25, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        TransitionMatrix::new(4, data)
    }

    /// 4x4 block-diagonal fixture: indices {0, 1} and {2, 3} never mix.
    fn disconnected_tpm() -> TransitionMatrix {
        #[rustfmt::skip]
        let data = vec![
            0.5, 0.5, 0.0, 0.0,
            0.5, 0.5, 0.0, 0.0,
            0.0, 0.0, 0.5, 0.5,
            0.0, 0.0, 0.5, 0.5,
        ];
        TransitionMatrix::new(4, data)
    }

    /// For n = 4 the counter runs over 1..8 (seven steps) and the step whose
    /// mask is the full set is dropped, leaving six items.
    #[test]
    fn gray_code_iter_count() {
        let count = GrayCodePartitionIter::new(4).count();
        assert_eq!(count, 6);
    }

    /// Consecutive yielded masks differ in at most two bits: normally one,
    /// two across the skipped full-set step.
    #[test]
    fn gray_code_consecutive_differ_by_one() {
        let partitions: Vec<(Bipartition, u32)> = GrayCodePartitionIter::new(5).collect();
        for i in 1..partitions.len() {
            let diff = partitions[i].0.mask ^ partitions[i - 1].0.mask;
            assert!(
                diff.count_ones() <= 2,
                "Gray code partitions at {i} differ by {} bits",
                diff.count_ones()
            );
        }
    }

    /// A mask and its complement must map to the same canonical mask.
    #[test]
    fn canonical_partition_symmetric() {
        let c1 = canonical_partition(0b0011, 4);
        let c2 = canonical_partition(0b1100, 4);
        assert_eq!(c1, c2);
    }

    /// The block-diagonal fixture should yield Φ of (almost) zero.
    #[test]
    fn geomip_disconnected_is_zero() {
        let tpm = disconnected_tpm();
        let budget = ComputeBudget::exact();
        let engine = GeoMipPhiEngine::default();
        let result = engine.compute_phi(&tpm, Some(0), &budget).unwrap();
        assert!(
            result.phi < 1e-6,
            "disconnected should have Φ ≈ 0, got {}",
            result.phi
        );
    }

    /// Smoke test: the engine runs on the AND-gate fixture and Φ is non-negative.
    #[test]
    fn geomip_and_gate() {
        let tpm = and_gate_tpm();
        let budget = ComputeBudget::exact();
        let engine = GeoMipPhiEngine::default();
        let result = engine.compute_phi(&tpm, Some(3), &budget).unwrap();
        assert!(result.phi >= 0.0);
    }

    /// GeoMIP must not evaluate more partitions than the exact engine.
    #[test]
    fn geomip_fewer_evaluations_than_exact() {
        let tpm = and_gate_tpm();
        let budget = ComputeBudget::exact();

        let exact_result =
            crate::phi::ExactPhiEngine.compute_phi(&tpm, Some(0), &budget).unwrap();
        let geomip_result =
            GeoMipPhiEngine::default().compute_phi(&tpm, Some(0), &budget).unwrap();

        assert!(
            geomip_result.partitions_evaluated <= exact_result.partitions_evaluated,
            "GeoMIP evaluated {} vs exact {}",
            geomip_result.partitions_evaluated,
            exact_result.partitions_evaluated
        );
    }

    /// The EMD-based loss is clamped and must never be negative.
    #[test]
    fn emd_loss_nonnegative() {
        let tpm = and_gate_tpm();
        let partition = Bipartition { mask: 0b0011, n: 4 };
        let arena = PhiArena::with_capacity(1024);
        let loss = partition_information_loss_emd(&tpm, 0, &partition, &arena);
        assert!(loss >= 0.0, "EMD loss should be ≥ 0, got {loss}");
    }

    /// Splitting the block-diagonal fixture along its blocks should lose
    /// (almost) no information.
    #[test]
    fn emd_loss_disconnected_zero() {
        let tpm = disconnected_tpm();
        let partition = Bipartition {
            mask: 0b0011,
            n: 4,
        };
        let arena = PhiArena::with_capacity(1024);
        let loss = partition_information_loss_emd(&tpm, 0, &partition, &arena);
        assert!(loss < 1e-6, "disconnected EMD loss should be ≈ 0, got {loss}");
    }
}