ruvector_math/tropical/neural_analysis.rs

//! Tropical Neural Network Analysis
//!
//! Neural networks with ReLU activations are piecewise linear functions,
//! which can be analyzed using tropical geometry.
//!
//! ## Key Insight
//!
//! ReLU(x) = max(0, x) = 0 ⊕ x in tropical (max-plus) arithmetic.
//!
//! A ReLU network is a composition of affine maps and tropical additions,
//! making it a tropical rational function.
//!
//! ## Applications
//!
//! - Count linear regions of a neural network
//! - Analyze decision boundaries
//! - Bound network complexity
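//!
//! ## Example
//!
//! A minimal usage sketch (the module path is assumed from the file location
//! above, so the block is marked `ignore` rather than run as a doctest):
//!
//! ```ignore
//! use ruvector_math::tropical::neural_analysis::TropicalNeuralAnalysis;
//!
//! let net = TropicalNeuralAnalysis::random(vec![2, 4, 1], 42);
//! let output = net.forward(&[0.5, -0.5]);
//! let bound = net.linear_region_upper_bound();
//! let sampled = net.estimate_linear_regions(1_000, 7);
//! println!("output {output:?}, regions <= {bound}, sampled patterns {sampled}");
//! ```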

use super::polynomial::TropicalPolynomial;

/// Analyzes ReLU neural networks using tropical geometry
#[derive(Debug, Clone)]
pub struct TropicalNeuralAnalysis {
    /// Network architecture: [input_dim, hidden1, hidden2, ..., output_dim]
    architecture: Vec<usize>,
    /// Weights: weights[l] is a (layer_size, prev_layer_size) matrix
    weights: Vec<Vec<Vec<f64>>>,
    /// Biases: biases[l] is a vector of length layer_size
    biases: Vec<Vec<f64>>,
}

impl TropicalNeuralAnalysis {
    /// Create analyzer for a ReLU network
    pub fn new(
        architecture: Vec<usize>,
        weights: Vec<Vec<Vec<f64>>>,
        biases: Vec<Vec<f64>>,
    ) -> Self {
        Self { architecture, weights, biases }
    }

    /// Create a random network for testing
    pub fn random(architecture: Vec<usize>, seed: u64) -> Self {
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        let mut s = seed;
        for i in 1..architecture.len() {
            let input_size = architecture[i - 1];
            let output_size = architecture[i];

            let mut layer_weights = Vec::new();
            for _ in 0..output_size {
                let mut neuron_weights = Vec::new();
                for _ in 0..input_size {
                    // Simple 64-bit LCG; map the top bits to [-1, 1)
                    s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
                    let w = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
                    neuron_weights.push(w);
                }
                layer_weights.push(neuron_weights);
            }
            weights.push(layer_weights);

            let mut layer_biases = Vec::new();
            for _ in 0..output_size {
                s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
                let b = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
                layer_biases.push(b * 0.1);
            }
            biases.push(layer_biases);
        }

        Self { architecture, weights, biases }
    }

    /// Forward pass of the ReLU network
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut x = input.to_vec();

        for layer in 0..self.weights.len() {
            let mut y = Vec::with_capacity(self.weights[layer].len());

            for (neuron_weights, &bias) in self.weights[layer].iter().zip(self.biases[layer].iter()) {
                let linear: f64 = neuron_weights.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum();
                let z = linear + bias;
                // ReLU = max(0, z) = tropical addition
                y.push(z.max(0.0));
            }

            x = y;
        }

        x
    }

    /// Upper bound on the number of linear regions
    ///
    /// For a network with widths n_0, n_1, ..., n_L, where n_0 is the input
    /// dimension:
    ///
    /// upper bound = prod_{i=1}^{L-1} sum_{j=0}^{min(n_0, n_i)} C(n_i, j)
    ///
    /// This is the hyperplane-arrangement region count (see
    /// `LinearRegionCounter::zaslavsky_formula` below) applied layer by
    /// layer, and follows from tropical geometry considerations.
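    ///
    /// For example, architecture [2, 4, 4, 1] has two hidden layers of width
    /// 4 with n_0 = 2, giving (C(4,0) + C(4,1) + C(4,2))^2 = (1 + 4 + 6)^2
    /// = 121.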
    pub fn linear_region_upper_bound(&self) -> u128 {
        if self.architecture.len() < 2 {
            return 1;
        }

        let n0 = self.architecture[0] as u128;
        let mut bound: u128 = 1;

        for i in 1..self.architecture.len() - 1 {
            let ni = self.architecture[i] as u128;

            // Sum of binomial coefficients C(ni, j) for j = 0 to min(n0, ni)
            let k_max = n0.min(ni);
            let mut layer_sum: u128 = 0;

            for j in 0..=k_max {
                layer_sum = layer_sum.saturating_add(binomial(ni, j));
            }

            bound = bound.saturating_mul(layer_sum);
        }

        bound
    }

    /// Estimate actual linear regions by sampling
    ///
    /// Samples random points in [-1, 1)^d and counts how many distinct
    /// activation patterns occur.
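    ///
    /// Each linear region realizes a single activation pattern (though one
    /// pattern may recur across several disconnected regions), so the
    /// sampled count never exceeds the true number of regions.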
    pub fn estimate_linear_regions(&self, num_samples: usize, seed: u64) -> usize {
        use std::collections::HashSet;

        let mut activation_patterns = HashSet::new();
        let input_dim = self.architecture[0];

        let mut s = seed;
        for _ in 0..num_samples {
            // Generate random input
            let mut input = Vec::with_capacity(input_dim);
            for _ in 0..input_dim {
                s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
                let x = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
                input.push(x);
            }

            // Track activation pattern
            let pattern = self.get_activation_pattern(&input);
            activation_patterns.insert(pattern);
        }

        activation_patterns.len()
    }

    /// Get activation pattern (which neurons are active) for an input
    fn get_activation_pattern(&self, input: &[f64]) -> Vec<bool> {
        let mut x = input.to_vec();
        let mut pattern = Vec::new();

        for layer in 0..self.weights.len() {
            let mut y = Vec::with_capacity(self.weights[layer].len());

            for (neuron_weights, &bias) in self.weights[layer].iter().zip(self.biases[layer].iter()) {
                let linear: f64 = neuron_weights.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum();
                let z = linear + bias;
                pattern.push(z > 0.0);
                y.push(z.max(0.0));
            }

            x = y;
        }

        pattern
    }

    /// Compute a tropical polynomial representation for 1D input
    /// Returns the piecewise linear function f(x)
    pub fn as_tropical_polynomial_1d(&self) -> Option<TropicalPolynomial> {
        if self.architecture[0] != 1 || self.architecture[self.architecture.len() - 1] != 1 {
            return None;
        }

        // For 1D input, we can enumerate the breakpoints
        let breakpoints = self.find_breakpoints_1d(-10.0, 10.0, 1000);

        if breakpoints.is_empty() {
            return None;
        }

        // Build a tropical polynomial from the breakpoints: each breakpoint
        // corresponds to a change in slope. Using the breakpoint index as the
        // exponent assigns slope i to the i-th segment, so the result is
        // exact only when the function is convex with slopes 0, 1, 2, ...;
        // otherwise it is an approximation.
        let mut terms = Vec::new();
        for (i, &x) in breakpoints.iter().enumerate() {
            let y = self.forward(&[x])[0];
            terms.push((y - (i as f64) * x, i as i32));
        }

        Some(TropicalPolynomial::from_monomials(
            terms.into_iter().map(|(c, e)| super::polynomial::TropicalMonomial::new(c, e)).collect()
        ))
    }

    /// Find breakpoints of the 1D piecewise linear function
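    ///
    /// Breakpoints closer together than (x_max - x_min) / num_samples may be
    /// missed, since detection relies on the activation pattern changing
    /// between consecutive samples. The endpoints x_min and x_max are
    /// included as sentinels.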
    fn find_breakpoints_1d(&self, x_min: f64, x_max: f64, num_samples: usize) -> Vec<f64> {
        let mut breakpoints = vec![x_min];
        let dx = (x_max - x_min) / num_samples as f64;

        let mut prev_pattern = self.get_activation_pattern(&[x_min]);

        for i in 1..=num_samples {
            let x = x_min + i as f64 * dx;
            let pattern = self.get_activation_pattern(&[x]);

            if pattern != prev_pattern {
                // Breakpoint somewhere between previous x and current x
                let breakpoint = self.binary_search_breakpoint(x - dx, x, &prev_pattern);
                breakpoints.push(breakpoint);
                prev_pattern = pattern;
            }
        }

        breakpoints.push(x_max);
        breakpoints
    }

    /// Binary search for exact breakpoint location
    fn binary_search_breakpoint(&self, mut lo: f64, mut hi: f64, lo_pattern: &[bool]) -> f64 {
        // 50 bisections shrink the bracket far below f64 resolution here
        for _ in 0..50 {
            let mid = (lo + hi) / 2.0;
            let mid_pattern = self.get_activation_pattern(&[mid]);

            if mid_pattern == *lo_pattern {
                lo = mid;
            } else {
                hi = mid;
            }
        }

        (lo + hi) / 2.0
    }

    /// Compute decision boundary complexity for binary classification
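    ///
    /// Returns the average number of boundary crossings per sampled ray. A
    /// single linear boundary can cross each ray at most once, so averages
    /// well above 1.0 indicate a folded, more complex boundary.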
    pub fn decision_boundary_complexity(&self, num_samples: usize, seed: u64) -> f64 {
        // For a binary classifier, count sign changes in output along random
        // rays through the input space. The final layer is ReLU-activated,
        // so a "sign change" here means crossing between the clamped
        // (output == 0) and active (output > 0) pieces, which tracks the
        // sign of the pre-activation.

        let input_dim = self.architecture[0];
        let mut total_changes = 0;
        let mut s = seed;

        for _ in 0..num_samples {
            // Random direction
            let mut direction = Vec::with_capacity(input_dim);
            for _ in 0..input_dim {
                s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
                let d = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
                direction.push(d);
            }

            // Normalize
            let norm: f64 = direction.iter().map(|x| x * x).sum::<f64>().sqrt();
            for d in direction.iter_mut() {
                *d /= norm.max(1e-10);
            }

            // Count sign changes along ray
            let mut prev_sign = None;
            for t in -100..=100 {
                let t = t as f64 * 0.1;
                let input: Vec<f64> = direction.iter().map(|d| t * d).collect();
                let output = self.forward(&input);

                if !output.is_empty() {
                    let sign = output[0] > 0.0;
                    if let Some(prev) = prev_sign {
                        if prev != sign {
                            total_changes += 1;
                        }
                    }
                    prev_sign = Some(sign);
                }
            }
        }

        total_changes as f64 / num_samples as f64
    }
}

/// Counter for linear regions of piecewise linear functions
#[derive(Debug, Clone)]
pub struct LinearRegionCounter {
    /// Dimension of input space
    input_dim: usize,
}

impl LinearRegionCounter {
    /// Create counter for given input dimension
    pub fn new(input_dim: usize) -> Self {
        Self { input_dim }
    }

    /// Theoretical maximum number of regions for n-dimensional input cut by
    /// k hyperplanes; this is the classical region-counting problem for
    /// hyperplane arrangements.
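    ///
    /// For example, k = 3 lines in general position in R^2 create
    /// C(3,0) + C(3,1) + C(3,2) = 1 + 3 + 3 = 7 regions.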
    pub fn hyperplane_arrangement_max(&self, num_hyperplanes: usize) -> u128 {
        // Maximum regions = sum_{i=0}^{n} C(k, i)
        let n = self.input_dim as u128;
        let k = num_hyperplanes as u128;

        let mut total: u128 = 0;
        for i in 0..=n.min(k) {
            total = total.saturating_add(binomial(k, i));
        }

        total
    }

    /// Zaslavsky's theorem: count regions of a hyperplane arrangement
    /// For a general-position arrangement of k hyperplanes in R^n:
    /// regions = sum_{i=0}^{n} C(k, i)
    pub fn zaslavsky_formula(&self, num_hyperplanes: usize) -> u128 {
        self.hyperplane_arrangement_max(num_hyperplanes)
    }
}

/// Compute binomial coefficient C(n, k) = n! / (k! * (n-k)!)
fn binomial(n: u128, k: u128) -> u128 {
    if k > n {
        return 0;
    }

    let k = k.min(n - k); // Use symmetry: C(n, k) = C(n, n - k)

    // Multiplicative form: after step i the running product equals
    // C(n, i + 1), so the division by (i + 1) is always exact.
    let mut result: u128 = 1;
    for i in 0..k {
        result = result.saturating_mul(n - i) / (i + 1);
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_forward() {
        let analysis = TropicalNeuralAnalysis::new(
            vec![2, 3, 1],
            vec![
                vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
                vec![vec![1.0, 1.0, 1.0]],
            ],
            vec![
                vec![0.0, 0.0, -1.0],
                vec![0.0],
            ],
        );

        // Hidden activations: [1, 1, 1]; output: 1 + 1 + 1 = 3
        let output = analysis.forward(&[1.0, 1.0]);
        assert!((output[0] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_linear_region_bound() {
        // Network: 2 -> 4 -> 4 -> 1
        let analysis = TropicalNeuralAnalysis::random(vec![2, 4, 4, 1], 42);
        let bound = analysis.linear_region_upper_bound();

        // For 2D input with two hidden layers of width 4:
        // upper bound = (C(4,0) + C(4,1) + C(4,2))^2 = (1 + 4 + 6)^2 = 121
        assert_eq!(bound, 121);
    }

    #[test]
    fn test_estimate_regions() {
        let analysis = TropicalNeuralAnalysis::random(vec![2, 4, 1], 42);
        let estimate = analysis.estimate_linear_regions(1000, 123);

        // Should find at least one region (typically several)
        assert!(estimate >= 1);
    }
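
    // Added check (a sketch): within a single activation region a ReLU
    // network is affine, so forward() should commute with midpoints. The
    // probe points are arbitrary; the assertion is only exercised when both
    // fall in the same region.
    #[test]
    fn test_locally_linear_within_region() {
        let net = TropicalNeuralAnalysis::random(vec![2, 4, 1], 42);
        let a = [0.10, 0.20];
        let b = [0.11, 0.21];

        if net.get_activation_pattern(&a) == net.get_activation_pattern(&b) {
            let mid = [(a[0] + b[0]) / 2.0, (a[1] + b[1]) / 2.0];
            let expected = (net.forward(&a)[0] + net.forward(&b)[0]) / 2.0;
            assert!((net.forward(&mid)[0] - expected).abs() < 1e-9);
        }
    }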

    #[test]
    fn test_binomial() {
        assert_eq!(binomial(5, 2), 10);
        assert_eq!(binomial(10, 0), 1);
        assert_eq!(binomial(10, 10), 1);
        assert_eq!(binomial(6, 3), 20);
    }

    #[test]
    fn test_hyperplane_max() {
        let counter = LinearRegionCounter::new(2);

        // 3 lines in R^2 can create at most 1 + 3 + 3 = 7 regions
        assert_eq!(counter.hyperplane_arrangement_max(3), 7);
    }
}