Skip to main content

ruvector_consciousness/
simd.rs

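//! SIMD-accelerated operations for consciousness computation.
//!
//! Provides KL-divergence, entropy, and dense matrix operations that sit on
//! the hot paths of Φ computation. Only the dense matrix-vector product has
//! an AVX2 implementation; the KL-divergence and entropy kernels are scalar.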
5
6// ---------------------------------------------------------------------------
7// KL Divergence (the core operation in Φ computation)
8// ---------------------------------------------------------------------------
9
10/// Compute KL divergence D_KL(P || Q) = Σ p_i * ln(p_i / q_i).
11///
12/// Dispatches to AVX2 when available, falls back to scalar.
13pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
14    assert_eq!(p.len(), q.len(), "KL divergence: mismatched lengths");
15
16    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
17    {
18        if is_x86_feature_detected!("avx2") {
19            return kl_divergence_scalar(p, q); // AVX2 log is complex; use scalar with prefetch
20        }
21    }
22
23    kl_divergence_scalar(p, q)
24}
25
/// Scalar KL divergence `Σ p_i * ln(p_i / q_i)`, accumulated in index order.
///
/// A term contributes only when both `p_i` and `q_i` are strictly greater
/// than `1e-15`; every other term is skipped, which keeps zeros (and NaNs)
/// out of the logarithm.
///
/// # Panics
///
/// Panics if `q` is shorter than `p`. Entries of `q` beyond `p.len()` are
/// ignored.
pub fn kl_divergence_scalar(p: &[f64], q: &[f64]) -> f64 {
    const THRESHOLD: f64 = 1e-15;

    // Only the first `p.len()` entries of `q` take part; a shorter `q` panics.
    let q = &q[..p.len()];

    p.iter().zip(q).fold(0.0f64, |acc, (&pi, &qi)| {
        if pi > THRESHOLD && qi > THRESHOLD {
            acc + pi * (pi / qi).ln()
        } else {
            acc
        }
    })
}
38
/// Earth Mover's Distance (EMD) approximation for distribution comparison.
/// Used in IIT 4.0 for comparing cause-effect structures.
///
/// The value is the sum, over every index, of the absolute running total of
/// `p[k] - q[k]` up to and including that index.
///
/// # Panics
///
/// Panics if `p` and `q` have different lengths.
pub fn emd_l1(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len());

    // Carry (running difference, accumulated distance) through a single pass.
    let (_, total) = p
        .iter()
        .zip(q)
        .fold((0.0f64, 0.0f64), |(running, total), (&pi, &qi)| {
            let running = running + (pi - qi);
            (running, total + running.abs())
        });
    total
}
51
52// ---------------------------------------------------------------------------
53// Entropy
54// ---------------------------------------------------------------------------
55
56/// Shannon entropy H(P) = -Σ p_i * ln(p_i).
57pub fn entropy(p: &[f64]) -> f64 {
58    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
59    {
60        if is_x86_feature_detected!("avx2") {
61            return entropy_scalar(p);
62        }
63    }
64    entropy_scalar(p)
65}
66
/// Scalar Shannon entropy `-Σ p_i * ln(p_i)`, accumulated in index order.
///
/// Entries that are not strictly greater than `1e-15` are skipped, so zeros
/// never reach the logarithm.
pub fn entropy_scalar(p: &[f64]) -> f64 {
    const THRESHOLD: f64 = 1e-15;

    p.iter().fold(0.0f64, |h, &pi| {
        if pi > THRESHOLD {
            h - pi * pi.ln()
        } else {
            h
        }
    })
}
76
77// ---------------------------------------------------------------------------
78// SIMD matrix-vector multiply (dense, f64)
79// ---------------------------------------------------------------------------
80
81/// Dense matrix-vector multiply y = A * x (row-major A).
82/// Used for TPM operations in Φ computation.
83pub fn dense_matvec(a: &[f64], x: &[f64], y: &mut [f64], n: usize) {
84    assert_eq!(a.len(), n * n);
85    assert_eq!(x.len(), n);
86    assert_eq!(y.len(), n);
87
88    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
89    {
90        if is_x86_feature_detected!("avx2") {
91            unsafe {
92                dense_matvec_avx2(a, x, y, n);
93            }
94            return;
95        }
96    }
97
98    dense_matvec_scalar(a, x, y, n);
99}
100
/// Scalar kernel for `y = A * x` with row-major `a` and rows of length `n`.
///
/// Each output element is the dot product of one row of `a` with the first
/// `n` entries of `x`, accumulated left to right.
fn dense_matvec_scalar(a: &[f64], x: &[f64], y: &mut [f64], n: usize) {
    for row in 0..n {
        let start = row * n;
        let coeffs = &a[start..start + n];
        let dot = coeffs
            .iter()
            .zip(&x[..n])
            .fold(0.0f64, |acc, (&aij, &xj)| acc + aij * xj);
        y[row] = dot;
    }
}
111
/// AVX2 kernel for `y = A * x` with row-major `a` and rows of length `n`.
///
/// Every row is processed four columns at a time into a 4-lane accumulator;
/// the lanes are then reduced to one value and the remaining `n % 4` columns
/// are added one by one.
///
/// # Safety
///
/// The caller must ensure that the CPU supports AVX2 and that `a` holds at
/// least `n * n` elements while `x` and `y` each hold at least `n` elements.
/// No bounds checks are performed.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn dense_matvec_avx2(a: &[f64], x: &[f64], y: &mut [f64], n: usize) {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    // Number of leading columns covered by full 4-wide loads.
    let vector_cols = (n / LANES) * LANES;
    let x_ptr = x.as_ptr();

    for row in 0..n {
        let row_ptr = a.as_ptr().add(row * n);
        let mut lanes = _mm256_setzero_pd();

        for col in (0..vector_cols).step_by(LANES) {
            // SAFETY: `col + 3 < vector_cols <= n`, so the loads stay inside
            // row `row` of `a` (which ends before `a.len()` by the caller's
            // contract) and inside the first `n` entries of `x`.
            let a_vec = _mm256_loadu_pd(row_ptr.add(col));
            let x_vec = _mm256_loadu_pd(x_ptr.add(col));
            lanes = _mm256_add_pd(lanes, _mm256_mul_pd(a_vec, x_vec));
        }

        let mut dot = horizontal_sum_f64x4(lanes);

        // Scalar tail for the columns that do not fill a whole vector.
        for col in vector_cols..n {
            dot += *row_ptr.add(col) * *x_ptr.add(col);
        }

        *y.get_unchecked_mut(row) = dot;
    }
}
141
/// Reduce the four `f64` lanes of `v` to a single sum.
///
/// The upper and lower 128-bit halves are added lane-wise first, then the two
/// lanes of that intermediate result are added together.
///
/// # Safety
///
/// The caller must ensure that the CPU supports AVX2.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 {
    use std::arch::x86_64::*;

    // pair = [v0 + v2, v1 + v3]
    let pair = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
    // Copy the upper lane of `pair` into both lanes so it lines up with lane 0.
    let upper = _mm_unpackhi_pd(pair, pair);
    _mm_cvtsd_f64(_mm_add_sd(pair, upper))
}
153
154// ---------------------------------------------------------------------------
155// Conditional distribution extraction
156// ---------------------------------------------------------------------------
157
/// Borrow row `state` of a row-major TPM with rows of length `n`, i.e. the
/// conditional distribution P(future | state).
///
/// # Panics
///
/// Panics if the requested row extends past the end of `tpm`.
#[inline]
pub fn conditional_distribution(tpm: &[f64], n: usize, state: usize) -> &[f64] {
    let (start, end) = (state * n, (state + 1) * n);
    &tpm[start..end]
}
163
/// Compute the marginal distribution of a row-major `n`×`n` TPM by averaging
/// its rows: entry `j` of the result is the mean of column `j`.
///
/// # Panics
///
/// Panics if `tpm` holds fewer than `n * n` elements.
pub fn marginal_distribution(tpm: &[f64], n: usize) -> Vec<f64> {
    let mut column_sums = vec![0.0; n];

    // Add each row into the per-column totals, rows in ascending order.
    for row in 0..n {
        let start = row * n;
        let probs = &tpm[start..start + n];
        for (total, &p) in column_sums.iter_mut().zip(probs) {
            *total += p;
        }
    }

    let scale = 1.0 / n as f64;
    column_sums.iter_mut().for_each(|total| *total *= scale);
    column_sums
}
178
179// ---------------------------------------------------------------------------
180// Shared pairwise MI computation (used by all spectral engines)
181// ---------------------------------------------------------------------------
182
/// Pairwise mutual information between elements `i` and `j` given marginals.
///
/// MI(i,j) = p(i,j) · ln(p(i,j) / (p(i)·p(j)))
/// where p(i,j) = (1/n) Σ_s TPM[s,i]·TPM[s,j], and p(i), p(j) are read from
/// `marginal`.
///
/// `tpm` is a row-major matrix with rows of length `n`; only its first
/// `n * n` entries are read. All three probabilities are floored at `1e-15`
/// before the logarithm, and the result is clamped to be non-negative.
///
/// # Panics
///
/// Panics if `i >= n` or `j >= n`, if `tpm` holds fewer than `n * n`
/// entries, or if `marginal` has no entry at index `i` or `j`.
#[inline]
pub fn pairwise_mi(tpm: &[f64], n: usize, i: usize, j: usize, marginal: &[f64]) -> f64 {
    const FLOOR: f64 = 1e-15;

    // Validate up front so that bad arguments panic instead of reading
    // outside `tpm`.
    assert!(i < n && j < n, "pairwise_mi: element index out of range");
    let cells = n
        .checked_mul(n)
        .expect("pairwise_mi: n * n overflows usize");
    assert!(
        tpm.len() >= cells,
        "pairwise_mi: TPM has fewer than n * n entries"
    );

    let pi = marginal[i].max(FLOOR);
    let pj = marginal[j].max(FLOOR);

    // Walk the matrix row by row and pick columns `i` and `j` from each row.
    // `i < n` implies `n >= 1`, so the chunk size is never zero.
    let joint_sum = tpm[..cells]
        .chunks_exact(n)
        .fold(0.0f64, |acc, row| acc + row[i] * row[j]);
    let pij = (joint_sum / n as f64).max(FLOOR);

    (pij * (pij / (pi * pj)).ln()).max(0.0)
}
202
203/// Build full pairwise MI matrix (symmetric, zero diagonal).
204/// Returns flat n×n row-major matrix.
205pub fn build_mi_matrix(tpm: &[f64], n: usize) -> Vec<f64> {
206    let marginal = marginal_distribution(tpm, n);
207    let mut mi = vec![0.0f64; n * n];
208    for i in 0..n {
209        for j in (i + 1)..n {
210            let val = pairwise_mi(tpm, n, i, j, &marginal);
211            mi[i * n + j] = val;
212            mi[j * n + i] = val;
213        }
214    }
215    mi
216}
217
218/// Build MI edge list (i, j, weight) with threshold pruning.
219pub fn build_mi_edges(tpm: &[f64], n: usize, threshold: f64) -> (Vec<(usize, usize, f64)>, Vec<f64>) {
220    let marginal = marginal_distribution(tpm, n);
221    let mut edges = Vec::new();
222    for i in 0..n {
223        for j in (i + 1)..n {
224            let mi = pairwise_mi(tpm, n, i, j, &marginal);
225            if mi > threshold {
226                edges.push((i, j, mi));
227            }
228        }
229    }
230    (edges, marginal)
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// D_KL(P || P) must vanish for any distribution.
    #[test]
    fn kl_divergence_identical() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        assert!((kl_divergence(&p, &p)).abs() < 1e-12);
    }

    /// A uniform distribution over four outcomes has entropy ln(4).
    #[test]
    fn entropy_uniform() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let h = entropy(&p);
        let expected = (4.0f64).ln();
        assert!((h - expected).abs() < 1e-10);
    }

    /// Small 2×2 product; with n < 4 only the scalar tail of the AVX2 kernel
    /// (or the scalar kernel) is involved.
    #[test]
    fn dense_matvec_correctness() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let x = vec![1.0, 1.0];
        let mut y = vec![0.0; 2];
        dense_matvec(&a, &x, &mut y, 2);
        assert!((y[0] - 3.0).abs() < 1e-10);
        assert!((y[1] - 7.0).abs() < 1e-10);
    }

    /// 5×5 product: one full 4-wide chunk plus a one-element tail per row, so
    /// the vector loop and the tail are both exercised when AVX2 is in use.
    /// All values are small integers, so the result is exact regardless of
    /// summation order.
    #[test]
    fn dense_matvec_vector_chunk_and_tail() {
        let a: Vec<f64> = (0..25).map(|v| v as f64).collect();
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut y = vec![0.0; 5];
        dense_matvec(&a, &x, &mut y, 5);
        assert_eq!(y, vec![40.0, 115.0, 190.0, 265.0, 340.0]);
    }

    /// The distance between a distribution and itself is zero.
    #[test]
    fn emd_identical() {
        let p = vec![0.5, 0.3, 0.2];
        assert!((emd_l1(&p, &p)).abs() < 1e-12);
    }

    /// Averaging the rows of the 2×2 identity gives the uniform marginal.
    #[test]
    fn marginal_identity() {
        let tpm = vec![1.0, 0.0, 0.0, 1.0];
        let m = marginal_distribution(&tpm, 2);
        assert!((m[0] - 0.5).abs() < 1e-10);
        assert!((m[1] - 0.5).abs() < 1e-10);
    }
}