// ruvector_math/tropical/neural_analysis.rs

//! Tropical Neural Network Analysis
//!
//! Neural networks with ReLU activations are piecewise linear functions,
//! which can be analyzed using tropical geometry.
//!
//! ## Key Insight
//!
//! ReLU(x) = max(0, x) = 0 ⊕ x in tropical arithmetic
//!
//! A ReLU network is a composition of affine maps and tropical additions,
//! making it a tropical rational function.
//!
//! ## Applications
//!
//! - Count linear regions of a neural network
//! - Analyze decision boundaries
//! - Bound network complexity

use super::polynomial::TropicalPolynomial;

/// Analyzes ReLU neural networks through the lens of tropical geometry.
///
/// Holds a fully-connected feed-forward network (architecture, weights,
/// biases); the analysis methods treat it as a piecewise linear — i.e.
/// tropical rational — function.
#[derive(Debug, Clone)]
pub struct TropicalNeuralAnalysis {
    /// Network architecture: [input_dim, hidden1, hidden2, ..., output_dim]
    architecture: Vec<usize>,
    /// Weights: weights[l] is a (layer_size, prev_layer_size) matrix
    weights: Vec<Vec<Vec<f64>>>,
    /// Biases: biases[l] is a vector of length layer_size
    biases: Vec<Vec<f64>>,
}

impl TropicalNeuralAnalysis {
33    /// Create analyzer for a ReLU network
34    pub fn new(
35        architecture: Vec<usize>,
36        weights: Vec<Vec<Vec<f64>>>,
37        biases: Vec<Vec<f64>>,
38    ) -> Self {
39        Self {
40            architecture,
41            weights,
42            biases,
43        }
44    }
45
46    /// Create a random network for testing
47    pub fn random(architecture: Vec<usize>, seed: u64) -> Self {
48        use std::collections::hash_map::DefaultHasher;
49        use std::hash::{Hash, Hasher};
50
51        let mut weights = Vec::new();
52        let mut biases = Vec::new();
53
54        let mut s = seed;
55        for i in 1..architecture.len() {
56            let input_size = architecture[i - 1];
57            let output_size = architecture[i];
58
59            let mut layer_weights = Vec::new();
60            for _ in 0..output_size {
61                let mut neuron_weights = Vec::new();
62                for _ in 0..input_size {
63                    // Simple PRNG
64                    s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
65                    let w = ((s >> 33) as f64 / (1u64 << 31) as f64) - 1.0;
66                    neuron_weights.push(w);
67                }
68                layer_weights.push(neuron_weights);
69            }
70            weights.push(layer_weights);
71
72            let mut layer_biases = Vec::new();
73            for _ in 0..output_size {
74                s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
75                let b = ((s >> 33) as f64 / (1u64 << 31) as f64) - 1.0;
76                layer_biases.push(b * 0.1);
77            }
78            biases.push(layer_biases);
79        }
80
81        Self {
82            architecture,
83            weights,
84            biases,
85        }
86    }
87
88    /// Forward pass of the ReLU network
89    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
90        let mut x = input.to_vec();
91
92        for layer in 0..self.weights.len() {
93            let mut y = Vec::with_capacity(self.weights[layer].len());
94
95            for (neuron_weights, &bias) in self.weights[layer].iter().zip(self.biases[layer].iter())
96            {
97                let linear: f64 = neuron_weights
98                    .iter()
99                    .zip(x.iter())
100                    .map(|(w, xi)| w * xi)
101                    .sum();
102                let z = linear + bias;
103                // ReLU = max(0, z) = tropical addition
104                y.push(z.max(0.0));
105            }
106
107            x = y;
108        }
109
110        x
111    }
112
113    /// Upper bound on number of linear regions
114    ///
115    /// For a network with widths n_0, n_1, ..., n_L where n_0 is input dimension:
116    /// Upper bound = prod_{i=1}^{L-1} sum_{j=0}^{min(n_0, n_i)} C(n_i, j)
117    ///
118    /// This follows from tropical geometry considerations.
119    pub fn linear_region_upper_bound(&self) -> u128 {
120        if self.architecture.len() < 2 {
121            return 1;
122        }
123
124        let n0 = self.architecture[0] as u128;
125        let mut bound: u128 = 1;
126
127        for i in 1..self.architecture.len() - 1 {
128            let ni = self.architecture[i] as u128;
129
130            // Sum of binomial coefficients C(ni, j) for j = 0 to min(n0, ni)
131            let k_max = n0.min(ni);
132            let mut layer_sum: u128 = 0;
133
134            for j in 0..=k_max {
135                layer_sum = layer_sum.saturating_add(binomial(ni, j));
136            }
137
138            bound = bound.saturating_mul(layer_sum);
139        }
140
141        bound
142    }
143
144    /// Estimate actual linear regions by sampling
145    ///
146    /// Samples random points and counts how many distinct activation patterns occur.
147    pub fn estimate_linear_regions(&self, num_samples: usize, seed: u64) -> usize {
148        use std::collections::HashSet;
149
150        let mut activation_patterns = HashSet::new();
151        let input_dim = self.architecture[0];
152
153        let mut s = seed;
154        for _ in 0..num_samples {
155            // Generate random input
156            let mut input = Vec::with_capacity(input_dim);
157            for _ in 0..input_dim {
158                s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
159                let x = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
160                input.push(x);
161            }
162
163            // Track activation pattern
164            let pattern = self.get_activation_pattern(&input);
165            activation_patterns.insert(pattern);
166        }
167
168        activation_patterns.len()
169    }
170
171    /// Get activation pattern (which neurons are active) for an input
172    fn get_activation_pattern(&self, input: &[f64]) -> Vec<bool> {
173        let mut x = input.to_vec();
174        let mut pattern = Vec::new();
175
176        for layer in 0..self.weights.len() {
177            let mut y = Vec::with_capacity(self.weights[layer].len());
178
179            for (neuron_weights, &bias) in self.weights[layer].iter().zip(self.biases[layer].iter())
180            {
181                let linear: f64 = neuron_weights
182                    .iter()
183                    .zip(x.iter())
184                    .map(|(w, xi)| w * xi)
185                    .sum();
186                let z = linear + bias;
187                pattern.push(z > 0.0);
188                y.push(z.max(0.0));
189            }
190
191            x = y;
192        }
193
194        pattern
195    }
196
197    /// Compute the tropical polynomial representation for 1D input
198    /// Returns the piecewise linear function f(x)
199    pub fn as_tropical_polynomial_1d(&self) -> Option<TropicalPolynomial> {
200        if self.architecture[0] != 1 || self.architecture[self.architecture.len() - 1] != 1 {
201            return None;
202        }
203
204        // For 1D input, we can enumerate the breakpoints
205        let breakpoints = self.find_breakpoints_1d(-10.0, 10.0, 1000);
206
207        if breakpoints.is_empty() {
208            return None;
209        }
210
211        // Build tropical polynomial from breakpoints
212        // Each breakpoint corresponds to a change in slope
213        let mut terms = Vec::new();
214        for (i, &x) in breakpoints.iter().enumerate() {
215            let y = self.forward(&[x])[0];
216            terms.push((y - (i as f64) * x, i as i32));
217        }
218
219        Some(TropicalPolynomial::from_monomials(
220            terms
221                .into_iter()
222                .map(|(c, e)| super::polynomial::TropicalMonomial::new(c, e))
223                .collect(),
224        ))
225    }
226
227    /// Find breakpoints of the 1D piecewise linear function
228    fn find_breakpoints_1d(&self, x_min: f64, x_max: f64, num_samples: usize) -> Vec<f64> {
229        let mut breakpoints = vec![x_min];
230        let dx = (x_max - x_min) / num_samples as f64;
231
232        let mut prev_pattern = self.get_activation_pattern(&[x_min]);
233
234        for i in 1..=num_samples {
235            let x = x_min + i as f64 * dx;
236            let pattern = self.get_activation_pattern(&[x]);
237
238            if pattern != prev_pattern {
239                // Breakpoint somewhere between previous x and current x
240                let breakpoint = self.binary_search_breakpoint(x - dx, x, &prev_pattern);
241                breakpoints.push(breakpoint);
242                prev_pattern = pattern;
243            }
244        }
245
246        breakpoints.push(x_max);
247        breakpoints
248    }
249
250    /// Binary search for exact breakpoint location
251    fn binary_search_breakpoint(&self, mut lo: f64, mut hi: f64, lo_pattern: &[bool]) -> f64 {
252        for _ in 0..50 {
253            let mid = (lo + hi) / 2.0;
254            let mid_pattern = self.get_activation_pattern(&[mid]);
255
256            if mid_pattern == *lo_pattern {
257                lo = mid;
258            } else {
259                hi = mid;
260            }
261        }
262
263        (lo + hi) / 2.0
264    }
265
266    /// Compute decision boundary complexity for binary classification
267    pub fn decision_boundary_complexity(&self, num_samples: usize, seed: u64) -> f64 {
268        // For a binary classifier, count sign changes in output
269        // along random rays through the input space
270
271        let input_dim = self.architecture[0];
272        let mut total_changes = 0;
273        let mut s = seed;
274
275        for _ in 0..num_samples {
276            // Random direction
277            let mut direction = Vec::with_capacity(input_dim);
278            for _ in 0..input_dim {
279                s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
280                let d = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
281                direction.push(d);
282            }
283
284            // Normalize
285            let norm: f64 = direction.iter().map(|x| x * x).sum::<f64>().sqrt();
286            for d in direction.iter_mut() {
287                *d /= norm.max(1e-10);
288            }
289
290            // Count sign changes along ray
291            let mut prev_sign = None;
292            for t in -100..=100 {
293                let t = t as f64 * 0.1;
294                let input: Vec<f64> = direction.iter().map(|d| t * d).collect();
295                let output = self.forward(&input);
296
297                if !output.is_empty() {
298                    let sign = output[0] > 0.0;
299                    if let Some(prev) = prev_sign {
300                        if prev != sign {
301                            total_changes += 1;
302                        }
303                    }
304                    prev_sign = Some(sign);
305                }
306            }
307        }
308
309        total_changes as f64 / num_samples as f64
310    }
}

/// Counts linear regions of piecewise linear functions via
/// hyperplane-arrangement combinatorics.
#[derive(Debug, Clone)]
pub struct LinearRegionCounter {
    /// Dimension of the ambient input space R^n.
    input_dim: usize,
}

320impl LinearRegionCounter {
321    /// Create counter for given input dimension
322    pub fn new(input_dim: usize) -> Self {
323        Self { input_dim }
324    }
325
326    /// Theoretical maximum for n-dimensional input with k hyperplanes
327    /// This is the central zone counting problem
328    pub fn hyperplane_arrangement_max(&self, num_hyperplanes: usize) -> u128 {
329        // Maximum regions = sum_{i=0}^{n} C(k, i)
330        let n = self.input_dim as u128;
331        let k = num_hyperplanes as u128;
332
333        let mut total: u128 = 0;
334        for i in 0..=n.min(k) {
335            total = total.saturating_add(binomial(k, i));
336        }
337
338        total
339    }
340
341    /// Zaslavsky's theorem: count regions of hyperplane arrangement
342    /// For a general position arrangement of k hyperplanes in R^n:
343    /// regions = sum_{i=0}^n C(k, i)
344    pub fn zaslavsky_formula(&self, num_hyperplanes: usize) -> u128 {
345        self.hyperplane_arrangement_max(num_hyperplanes)
346    }
347}
348
/// Compute the binomial coefficient C(n, k) = n! / (k! * (n - k)!).
///
/// Uses the multiplicative formula with the symmetry C(n, k) = C(n, n-k);
/// the running product `result * (n - i)` is always divisible by `i + 1`,
/// so each division is exact. Returns 0 when `k > n` and `u128::MAX` on
/// overflow. (The previous `saturating_mul` followed by division returned
/// a wrong, non-saturated value once the product overflowed.)
fn binomial(n: u128, k: u128) -> u128 {
    if k > n {
        return 0;
    }

    let k = k.min(n - k); // Symmetry keeps the loop short.

    let mut result: u128 = 1;
    for i in 0..k {
        match result.checked_mul(n - i) {
            Some(prod) => result = prod / (i + 1),
            // Overflow: report saturation instead of a corrupted quotient.
            None => return u128::MAX,
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_forward() {
        // Hand-built 2 -> 3 -> 1 network; the third hidden neuron
        // carries bias -1, the rest are unbiased.
        let net = TropicalNeuralAnalysis::new(
            vec![2, 3, 1],
            vec![
                vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
                vec![vec![1.0, 1.0, 1.0]],
            ],
            vec![vec![0.0, 0.0, -1.0], vec![0.0]],
        );

        // Hidden activations at (1, 1) are ReLU(1), ReLU(1), ReLU(1);
        // their sum is positive.
        assert!(net.forward(&[1.0, 1.0])[0] > 0.0);
    }

    #[test]
    fn test_linear_region_bound() {
        // Network 2 -> 4 -> 4 -> 1: each hidden layer contributes
        // C(4,0) + C(4,1) + C(4,2) = 11, so the bound is 11^2 = 121.
        let net = TropicalNeuralAnalysis::random(vec![2, 4, 4, 1], 42);
        assert!(net.linear_region_upper_bound() > 0);
    }

    #[test]
    fn test_estimate_regions() {
        // Sampling must discover at least one activation pattern.
        let net = TropicalNeuralAnalysis::random(vec![2, 4, 1], 42);
        assert!(net.estimate_linear_regions(1000, 123) >= 1);
    }

    #[test]
    fn test_binomial() {
        assert_eq!(binomial(5, 2), 10);
        assert_eq!(binomial(10, 0), 1);
        assert_eq!(binomial(10, 10), 1);
        assert_eq!(binomial(6, 3), 20);
    }

    #[test]
    fn test_hyperplane_max() {
        // Three lines in general position split the plane into
        // 1 + 3 + 3 = 7 regions.
        let counter = LinearRegionCounter::new(2);
        assert_eq!(counter.hyperplane_arrangement_max(3), 7);
    }
}