ruvector_math/tropical/
neural_analysis.rs1use super::polynomial::TropicalPolynomial;
20
/// A fully connected ReLU network stored layer by layer, together with the
/// parameters needed to analyse its piecewise-linear structure.
#[derive(Clone, Debug)]
pub struct TropicalNeuralAnalysis {
    /// Layer widths, input dimension first (`architecture[0]`).
    architecture: Vec<usize>,
    /// `weights[l][j]` is the incoming weight row of neuron `j` in layer `l + 1`;
    /// entry `i` of that row multiplies activation `i` of the previous layer.
    weights: Vec<Vec<Vec<f64>>>,
    /// `biases[l][j]` is the bias added to neuron `j` of layer `l + 1`.
    biases: Vec<Vec<f64>>,
}
31
32impl TropicalNeuralAnalysis {
33 pub fn new(
35 architecture: Vec<usize>,
36 weights: Vec<Vec<Vec<f64>>>,
37 biases: Vec<Vec<f64>>,
38 ) -> Self {
39 Self { architecture, weights, biases }
40 }
41
42 pub fn random(architecture: Vec<usize>, seed: u64) -> Self {
44 use std::collections::hash_map::DefaultHasher;
45 use std::hash::{Hash, Hasher};
46
47 let mut weights = Vec::new();
48 let mut biases = Vec::new();
49
50 let mut s = seed;
51 for i in 1..architecture.len() {
52 let input_size = architecture[i - 1];
53 let output_size = architecture[i];
54
55 let mut layer_weights = Vec::new();
56 for _ in 0..output_size {
57 let mut neuron_weights = Vec::new();
58 for _ in 0..input_size {
59 s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
61 let w = ((s >> 33) as f64 / (1u64 << 31) as f64) - 1.0;
62 neuron_weights.push(w);
63 }
64 layer_weights.push(neuron_weights);
65 }
66 weights.push(layer_weights);
67
68 let mut layer_biases = Vec::new();
69 for _ in 0..output_size {
70 s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
71 let b = ((s >> 33) as f64 / (1u64 << 31) as f64) - 1.0;
72 layer_biases.push(b * 0.1);
73 }
74 biases.push(layer_biases);
75 }
76
77 Self { architecture, weights, biases }
78 }
79
80 pub fn forward(&self, input: &[f64]) -> Vec<f64> {
82 let mut x = input.to_vec();
83
84 for layer in 0..self.weights.len() {
85 let mut y = Vec::with_capacity(self.weights[layer].len());
86
87 for (neuron_weights, &bias) in self.weights[layer].iter().zip(self.biases[layer].iter()) {
88 let linear: f64 = neuron_weights.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum();
89 let z = linear + bias;
90 y.push(z.max(0.0));
92 }
93
94 x = y;
95 }
96
97 x
98 }
99
100 pub fn linear_region_upper_bound(&self) -> u128 {
107 if self.architecture.len() < 2 {
108 return 1;
109 }
110
111 let n0 = self.architecture[0] as u128;
112 let mut bound: u128 = 1;
113
114 for i in 1..self.architecture.len() - 1 {
115 let ni = self.architecture[i] as u128;
116
117 let k_max = n0.min(ni);
119 let mut layer_sum: u128 = 0;
120
121 for j in 0..=k_max {
122 layer_sum = layer_sum.saturating_add(binomial(ni, j));
123 }
124
125 bound = bound.saturating_mul(layer_sum);
126 }
127
128 bound
129 }
130
131 pub fn estimate_linear_regions(&self, num_samples: usize, seed: u64) -> usize {
135 use std::collections::HashSet;
136
137 let mut activation_patterns = HashSet::new();
138 let input_dim = self.architecture[0];
139
140 let mut s = seed;
141 for _ in 0..num_samples {
142 let mut input = Vec::with_capacity(input_dim);
144 for _ in 0..input_dim {
145 s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
146 let x = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
147 input.push(x);
148 }
149
150 let pattern = self.get_activation_pattern(&input);
152 activation_patterns.insert(pattern);
153 }
154
155 activation_patterns.len()
156 }
157
158 fn get_activation_pattern(&self, input: &[f64]) -> Vec<bool> {
160 let mut x = input.to_vec();
161 let mut pattern = Vec::new();
162
163 for layer in 0..self.weights.len() {
164 let mut y = Vec::with_capacity(self.weights[layer].len());
165
166 for (neuron_weights, &bias) in self.weights[layer].iter().zip(self.biases[layer].iter()) {
167 let linear: f64 = neuron_weights.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum();
168 let z = linear + bias;
169 pattern.push(z > 0.0);
170 y.push(z.max(0.0));
171 }
172
173 x = y;
174 }
175
176 pattern
177 }
178
179 pub fn as_tropical_polynomial_1d(&self) -> Option<TropicalPolynomial> {
182 if self.architecture[0] != 1 || self.architecture[self.architecture.len() - 1] != 1 {
183 return None;
184 }
185
186 let breakpoints = self.find_breakpoints_1d(-10.0, 10.0, 1000);
188
189 if breakpoints.is_empty() {
190 return None;
191 }
192
193 let mut terms = Vec::new();
196 for (i, &x) in breakpoints.iter().enumerate() {
197 let y = self.forward(&[x])[0];
198 terms.push((y - (i as f64) * x, i as i32));
199 }
200
201 Some(TropicalPolynomial::from_monomials(
202 terms.into_iter().map(|(c, e)| super::polynomial::TropicalMonomial::new(c, e)).collect()
203 ))
204 }
205
206 fn find_breakpoints_1d(&self, x_min: f64, x_max: f64, num_samples: usize) -> Vec<f64> {
208 let mut breakpoints = vec![x_min];
209 let dx = (x_max - x_min) / num_samples as f64;
210
211 let mut prev_pattern = self.get_activation_pattern(&[x_min]);
212
213 for i in 1..=num_samples {
214 let x = x_min + i as f64 * dx;
215 let pattern = self.get_activation_pattern(&[x]);
216
217 if pattern != prev_pattern {
218 let breakpoint = self.binary_search_breakpoint(x - dx, x, &prev_pattern);
220 breakpoints.push(breakpoint);
221 prev_pattern = pattern;
222 }
223 }
224
225 breakpoints.push(x_max);
226 breakpoints
227 }
228
229 fn binary_search_breakpoint(&self, mut lo: f64, mut hi: f64, lo_pattern: &[bool]) -> f64 {
231 for _ in 0..50 {
232 let mid = (lo + hi) / 2.0;
233 let mid_pattern = self.get_activation_pattern(&[mid]);
234
235 if mid_pattern == *lo_pattern {
236 lo = mid;
237 } else {
238 hi = mid;
239 }
240 }
241
242 (lo + hi) / 2.0
243 }
244
245 pub fn decision_boundary_complexity(&self, num_samples: usize, seed: u64) -> f64 {
247 let input_dim = self.architecture[0];
251 let mut total_changes = 0;
252 let mut s = seed;
253
254 for _ in 0..num_samples {
255 let mut direction = Vec::with_capacity(input_dim);
257 for _ in 0..input_dim {
258 s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
259 let d = ((s >> 33) as f64 / (1u64 << 31) as f64) * 2.0 - 1.0;
260 direction.push(d);
261 }
262
263 let norm: f64 = direction.iter().map(|x| x * x).sum::<f64>().sqrt();
265 for d in direction.iter_mut() {
266 *d /= norm.max(1e-10);
267 }
268
269 let mut prev_sign = None;
271 for t in -100..=100 {
272 let t = t as f64 * 0.1;
273 let input: Vec<f64> = direction.iter().map(|d| t * d).collect();
274 let output = self.forward(&input);
275
276 if !output.is_empty() {
277 let sign = output[0] > 0.0;
278 if let Some(prev) = prev_sign {
279 if prev != sign {
280 total_changes += 1;
281 }
282 }
283 prev_sign = Some(sign);
284 }
285 }
286 }
287
288 total_changes as f64 / num_samples as f64
289 }
290}
291
/// Counts regions of hyperplane arrangements in a space of fixed dimension.
#[derive(Clone, Debug)]
pub struct LinearRegionCounter {
    /// Dimension of the ambient space containing the hyperplanes.
    input_dim: usize,
}
298
299impl LinearRegionCounter {
300 pub fn new(input_dim: usize) -> Self {
302 Self { input_dim }
303 }
304
305 pub fn hyperplane_arrangement_max(&self, num_hyperplanes: usize) -> u128 {
308 let n = self.input_dim as u128;
310 let k = num_hyperplanes as u128;
311
312 let mut total: u128 = 0;
313 for i in 0..=n.min(k) {
314 total = total.saturating_add(binomial(k, i));
315 }
316
317 total
318 }
319
320 pub fn zaslavsky_formula(&self, num_hyperplanes: usize) -> u128 {
324 self.hyperplane_arrangement_max(num_hyperplanes)
325 }
326}
327
/// Binomial coefficient `C(n, k)`.
///
/// Returns 0 when `k > n` and saturates at `u128::MAX` when the exact value
/// does not fit in a `u128`.
///
/// Each step computes `C(n, i + 1) = C(n, i) * (n - i) / (i + 1)` exactly.
/// The gcd is cancelled before multiplying, so an overflow is reported only
/// when the true coefficient overflows, never for an intermediate product.
fn binomial(n: u128, k: u128) -> u128 {
    /// Greatest common divisor (Euclid).
    fn gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            let t = a % b;
            a = b;
            b = t;
        }
        a
    }

    if k > n {
        return 0;
    }

    // C(n, k) == C(n, n - k); iterate over the smaller side.
    let k = k.min(n - k);
    let mut result: u128 = 1;

    for i in 0..k {
        // `result` is C(n, i) here. `result * (n - i)` is divisible by `i + 1`.
        // After removing g = gcd(result, i + 1), the reduced denominator is
        // coprime to `result / g`, so it must divide `n - i`.
        let g = gcd(result, i + 1);
        let reduced = result / g;
        let factor = (n - i) / ((i + 1) / g);

        match reduced.checked_mul(factor) {
            Some(next) => result = next,
            // C(n, j) is non-decreasing for j <= n / 2, so once C(n, i + 1)
            // overflows, C(n, k) does too.
            None => return u128::MAX,
        }
    }

    result
}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-built 2-3-1 network.
    ///
    /// The hidden units compute `x0`, `x1` and `x0 + x1 - 1` (each followed by a
    /// ReLU); the output sums them.
    fn sample_network() -> TropicalNeuralAnalysis {
        TropicalNeuralAnalysis::new(
            vec![2, 3, 1],
            vec![
                vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
                vec![vec![1.0, 1.0, 1.0]],
            ],
            vec![vec![0.0, 0.0, -1.0], vec![0.0]],
        )
    }

    #[test]
    fn test_relu_forward() {
        let analysis = sample_network();

        let output = analysis.forward(&[1.0, 1.0]);
        assert!(output[0] > 0.0);
        // Hidden activations are [1, 1, 1], so the output is exactly 3.
        assert_eq!(output, vec![3.0]);

        // Every pre-activation is negative here, so the ReLUs clamp to zero.
        assert_eq!(analysis.forward(&[-1.0, -1.0]), vec![0.0]);
    }

    #[test]
    fn test_linear_region_bound() {
        let analysis = TropicalNeuralAnalysis::random(vec![2, 4, 4, 1], 42);
        let bound = analysis.linear_region_upper_bound();

        assert!(bound > 0);
        // Each width-4 hidden layer with n0 = 2 contributes C(4,0) + C(4,1) + C(4,2) = 11.
        assert_eq!(bound, 11 * 11);

        // Without hidden layers the bound degenerates to a single region.
        assert_eq!(
            TropicalNeuralAnalysis::random(vec![3, 1], 7).linear_region_upper_bound(),
            1
        );
    }

    #[test]
    fn test_estimate_regions() {
        let analysis = TropicalNeuralAnalysis::random(vec![2, 4, 1], 42);
        let estimate = analysis.estimate_linear_regions(1000, 123);

        assert!(estimate >= 1);
        // At most one new pattern can be observed per sample.
        assert!(estimate <= 1000);
    }

    #[test]
    fn test_binomial() {
        assert_eq!(binomial(5, 2), 10);
        assert_eq!(binomial(10, 0), 1);
        assert_eq!(binomial(10, 10), 1);
        assert_eq!(binomial(6, 3), 20);
        assert_eq!(binomial(3, 5), 0);
    }

    #[test]
    fn test_hyperplane_max() {
        let counter = LinearRegionCounter::new(2);

        assert_eq!(counter.hyperplane_arrangement_max(3), 7);
        assert_eq!(counter.hyperplane_arrangement_max(0), 1);
        assert_eq!(counter.hyperplane_arrangement_max(1), 2);
        assert_eq!(counter.hyperplane_arrangement_max(2), 4);
        assert_eq!(counter.zaslavsky_formula(3), 7);
    }
}