ruvector_mincut/snn/
optimizer.rs

//! # Layer 6: Neural Optimizer as Reinforcement Learning on Graphs
//!
//! Implements neural network-based graph optimization using policy gradients.
//!
//! ## Architecture
//!
//! - **Policy SNN**: Outputs graph modification actions via spike rates
//! - **Value Network**: Estimates mincut improvement
//! - **Experience Replay**: Prioritized sampling for stable learning
//!
//! ## Key Features
//!
//! - STDP-based policy gradient for spike-driven learning
//! - TD learning via spike-timing for value estimation
//! - Subpolynomial search exploiting learned graph structure
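//!
//! ## Example
//!
//! A minimal usage sketch (marked `ignore`, not a doctest; the crate paths are
//! assumed from this file's location, and `DynamicGraph` is assumed to expose
//! `new` and `insert_edge` as used in the tests at the bottom of this file):
//!
//! ```ignore
//! use ruvector_mincut::graph::DynamicGraph;
//! use ruvector_mincut::snn::optimizer::{NeuralGraphOptimizer, OptimizerConfig};
//!
//! // Build a small ring graph.
//! let graph = DynamicGraph::new();
//! for i in 0..10 {
//!     graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
//! }
//!
//! // Run a few RL-driven optimization steps and inspect the rewards.
//! let mut optimizer = NeuralGraphOptimizer::new(graph, OptimizerConfig::default());
//! for result in optimizer.optimize(5) {
//!     println!("{:?} -> reward {:.3}", result.action, result.reward);
//! }
//! ```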

use super::{
    neuron::{NeuronConfig, NeuronPopulation},
    synapse::{STDPConfig, SynapseMatrix},
    SimTime, Spike,
};
use crate::graph::{DynamicGraph, VertexId, Weight};
use std::collections::VecDeque;

/// Configuration for neural graph optimizer
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Number of input features
    pub input_size: usize,
    /// Hidden layer size
    pub hidden_size: usize,
    /// Number of possible actions
    pub num_actions: usize,
    /// Learning rate
    pub learning_rate: f64,
    /// Discount factor (gamma)
    pub gamma: f64,
    /// Reward weight for search efficiency
    pub search_weight: f64,
    /// Experience replay buffer size
    pub replay_buffer_size: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Time step
    pub dt: f64,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            input_size: 10,
            hidden_size: 32,
            num_actions: 5,
            learning_rate: 0.01,
            gamma: 0.99,
            search_weight: 0.1,
            replay_buffer_size: 10000,
            batch_size: 32,
            dt: 1.0,
        }
    }
}

/// Graph modification action
#[derive(Debug, Clone, PartialEq)]
pub enum GraphAction {
    /// Add an edge between vertices
    AddEdge(VertexId, VertexId, Weight),
    /// Remove an edge
    RemoveEdge(VertexId, VertexId),
    /// Strengthen an edge (increase weight)
    Strengthen(VertexId, VertexId, f64),
    /// Weaken an edge (decrease weight)
    Weaken(VertexId, VertexId, f64),
    /// No action
    NoOp,
}

impl GraphAction {
    /// Get action index
    pub fn to_index(&self) -> usize {
        match self {
            GraphAction::AddEdge(..) => 0,
            GraphAction::RemoveEdge(..) => 1,
            GraphAction::Strengthen(..) => 2,
            GraphAction::Weaken(..) => 3,
            GraphAction::NoOp => 4,
        }
    }
}

/// Result of an optimization step
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Action taken
    pub action: GraphAction,
    /// Reward received
    pub reward: f64,
    /// New mincut value
    pub new_mincut: f64,
    /// Search efficiency proxy used in the reward (higher is better)
    pub search_latency: f64,
}

/// Experience for replay buffer
#[derive(Debug, Clone)]
struct Experience {
    /// State features
    state: Vec<f64>,
    /// Action taken
    action_idx: usize,
    /// Reward received
    reward: f64,
    /// Next state features
    next_state: Vec<f64>,
    /// Is terminal state
    done: bool,
    /// TD error for prioritization
    td_error: f64,
}

/// Prioritized experience replay buffer
struct PrioritizedReplayBuffer {
    /// Stored experiences
    buffer: VecDeque<Experience>,
    /// Maximum capacity
    capacity: usize,
}

impl PrioritizedReplayBuffer {
    fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    fn push(&mut self, exp: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(exp);
    }

    fn sample(&self, batch_size: usize) -> Vec<&Experience> {
        // Prioritized by TD error (simplified: sort the whole buffer and take the highest-|TD error| samples)
        let mut sorted: Vec<_> = self.buffer.iter().collect();
        sorted.sort_by(|a, b| {
            b.td_error
                .abs()
                .partial_cmp(&a.td_error.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        sorted.into_iter().take(batch_size).collect()
    }

    fn len(&self) -> usize {
        self.buffer.len()
    }
}

/// Simple value network for state value estimation
#[derive(Debug, Clone)]
pub struct ValueNetwork {
    /// Weights from input to hidden
    w_hidden: Vec<Vec<f64>>,
    /// Hidden biases
    b_hidden: Vec<f64>,
    /// Weights from hidden to output
    w_output: Vec<f64>,
    /// Output bias
    b_output: f64,
    /// Last estimate (for TD error)
    last_estimate: f64,
}

impl ValueNetwork {
    /// Create a new value network
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // Initialize with small random weights (Xavier initialization)
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
        let w_hidden: Vec<Vec<f64>> = (0..hidden_size)
            .map(|_| (0..input_size).map(|_| rand_small() * scale).collect())
            .collect();

        let b_hidden = vec![0.0; hidden_size];

        let output_scale = (1.0 / hidden_size as f64).sqrt();
        let w_output: Vec<f64> = (0..hidden_size)
            .map(|_| rand_small() * output_scale)
            .collect();
        let b_output = 0.0;

        Self {
            w_hidden,
            b_hidden,
            w_output,
            b_output,
            last_estimate: 0.0,
        }
    }

    /// Estimate value of a state
    pub fn estimate(&mut self, state: &[f64]) -> f64 {
        // Hidden layer
        let mut hidden = vec![0.0; self.w_hidden.len()];
        for (j, weights) in self.w_hidden.iter().enumerate() {
            let mut sum = self.b_hidden[j];
            for (i, &w) in weights.iter().enumerate() {
                if i < state.len() {
                    sum += w * state[i];
                }
            }
            hidden[j] = relu(sum);
        }

        // Output layer
        let mut output = self.b_output;
        for (j, &w) in self.w_output.iter().enumerate() {
            output += w * hidden[j];
        }

        self.last_estimate = output;
        output
    }

    /// Get previous estimate
    pub fn estimate_previous(&self) -> f64 {
        self.last_estimate
    }

    /// Update weights with TD error using proper backpropagation
    ///
    /// Implements gradient descent with:
    /// - Forward pass to compute activations
    /// - Backward pass to compute ∂V/∂w
    /// - Weight update: w += lr * td_error * ∂V/∂w
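    ///
    /// A minimal sketch of how a caller might drive this update (illustrative
    /// only, not a doctest):
    ///
    /// ```ignore
    /// let mut value_net = ValueNetwork::new(4, 8);
    /// let state = vec![0.2, 0.4, 0.6, 0.8];
    /// let target = 1.0; // e.g. reward + gamma * next_value
    /// for _ in 0..100 {
    ///     let td_error = target - value_net.estimate(&state);
    ///     value_net.update(&state, td_error, 0.01);
    /// }
    /// ```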
    pub fn update(&mut self, state: &[f64], td_error: f64, lr: f64) {
        let hidden_size = self.w_hidden.len();

        // Forward pass: compute hidden pre-activations and activations
        let mut hidden_pre = vec![0.0; hidden_size]; // Before ReLU
        let mut hidden_post = vec![0.0; hidden_size]; // After ReLU

        for (j, weights) in self.w_hidden.iter().enumerate() {
            let mut sum = self.b_hidden[j];
            for (i, &w) in weights.iter().enumerate() {
                if i < state.len() {
                    sum += w * state[i];
                }
            }
            hidden_pre[j] = sum;
            hidden_post[j] = relu(sum);
        }

        // Backward pass: compute the hidden-layer deltas *before* touching the
        // output weights, so the gradients use the pre-update values.
        // ∂V/∂hidden_post[j] = w_output[j]
        // ∂hidden_post/∂hidden_pre = relu'(hidden_pre) = 1 if hidden_pre > 0 else 0
        let hidden_deltas: Vec<f64> = (0..hidden_size)
            .map(|j| {
                let relu_grad = if hidden_pre[j] > 0.0 { 1.0 } else { 0.0 };
                td_error * self.w_output[j] * relu_grad
            })
            .collect();

        // Update output weights: ∂V/∂w_output[j] = hidden_post[j]
        for (j, w) in self.w_output.iter_mut().enumerate() {
            *w += lr * td_error * hidden_post[j];
        }
        self.b_output += lr * td_error;

        // Update hidden weights: ∂V/∂w_hidden[j][i] = hidden_deltas[j] * state[i]
        for (j, weights) in self.w_hidden.iter_mut().enumerate() {
            let delta = hidden_deltas[j];
            for (i, w) in weights.iter_mut().enumerate() {
                if i < state.len() {
                    *w += lr * delta * state[i];
                }
            }
            self.b_hidden[j] += lr * delta;
        }
    }
}

/// Policy SNN for action selection
pub struct PolicySNN {
    /// Input layer
    input_layer: NeuronPopulation,
    /// Hidden recurrent layer
    hidden_layer: NeuronPopulation,
    /// Output layer (one neuron per action)
    output_layer: NeuronPopulation,
    /// Input → Hidden weights
    w_ih: SynapseMatrix,
    /// Hidden → Output weights
    w_ho: SynapseMatrix,
    /// Reward-modulated STDP configuration
    stdp_config: STDPConfig,
    /// Configuration
    config: OptimizerConfig,
}

impl PolicySNN {
    /// Create a new policy SNN
    pub fn new(config: OptimizerConfig) -> Self {
        let input_config = NeuronConfig {
            tau_membrane: 10.0,
            threshold: 0.8,
            ..NeuronConfig::default()
        };

        let hidden_config = NeuronConfig {
            tau_membrane: 20.0,
            threshold: 1.0,
            ..NeuronConfig::default()
        };

        let output_config = NeuronConfig {
            tau_membrane: 15.0,
            threshold: 0.6,
            ..NeuronConfig::default()
        };

        let input_layer = NeuronPopulation::with_config(config.input_size, input_config);
        let hidden_layer = NeuronPopulation::with_config(config.hidden_size, hidden_config);
        let output_layer = NeuronPopulation::with_config(config.num_actions, output_config);

        // Initialize weights
        let mut w_ih = SynapseMatrix::new(config.input_size, config.hidden_size);
        let mut w_ho = SynapseMatrix::new(config.hidden_size, config.num_actions);

        // Random initialization
        for i in 0..config.input_size {
            for j in 0..config.hidden_size {
                w_ih.add_synapse(i, j, rand_small() + 0.3);
            }
        }

        for i in 0..config.hidden_size {
            for j in 0..config.num_actions {
                w_ho.add_synapse(i, j, rand_small() + 0.3);
            }
        }

        Self {
            input_layer,
            hidden_layer,
            output_layer,
            w_ih,
            w_ho,
            stdp_config: STDPConfig::default(),
            config,
        }
    }

    /// Inject state as current to input layer
    pub fn inject(&mut self, state: &[f64]) {
        for (i, neuron) in self.input_layer.neurons.iter_mut().enumerate() {
            if i < state.len() {
                neuron.set_membrane_potential(state[i]);
            }
        }
    }

    /// Run until decision (output spike)
    pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
        for step in 0..max_steps {
            let time = step as f64 * self.config.dt;

            // Compute hidden currents from input
            let mut hidden_currents = vec![0.0; self.config.hidden_size];
            for j in 0..self.config.hidden_size {
                for i in 0..self.config.input_size {
                    hidden_currents[j] += self.w_ih.weight(i, j)
                        * self.input_layer.neurons[i].membrane_potential().max(0.0);
                }
            }

            // Update hidden layer
            let hidden_spikes = self.hidden_layer.step(&hidden_currents, self.config.dt);

            // Compute output currents from hidden
            let mut output_currents = vec![0.0; self.config.num_actions];
            for j in 0..self.config.num_actions {
                for i in 0..self.config.hidden_size {
                    output_currents[j] += self.w_ho.weight(i, j)
                        * self.hidden_layer.neurons[i].membrane_potential().max(0.0);
                }
            }

            // Update output layer
            let output_spikes = self.output_layer.step(&output_currents, self.config.dt);

            // STDP updates
            for spike in &hidden_spikes {
                self.w_ih.on_post_spike(spike.neuron_id, time);
            }
            for spike in &output_spikes {
                self.w_ho.on_post_spike(spike.neuron_id, time);
            }

            // Return if we have output spikes
            if !output_spikes.is_empty() {
                return output_spikes;
            }
        }

        Vec::new()
    }

    /// Apply reward-modulated STDP
    pub fn apply_reward_modulated_stdp(&mut self, td_error: f64) {
        self.w_ih.apply_reward(td_error);
        self.w_ho.apply_reward(td_error);
    }

    /// Get regions with low activity (for search skip)
    pub fn low_activity_regions(&self) -> Vec<usize> {
        self.hidden_layer
            .spike_trains
            .iter()
            .enumerate()
            .filter(|(_, t)| t.spike_rate(100.0) < 0.001)
            .map(|(i, _)| i)
            .collect()
    }

    /// Reset SNN state
    pub fn reset(&mut self) {
        self.input_layer.reset();
        self.hidden_layer.reset();
        self.output_layer.reset();
    }
}

/// Neural Graph Optimizer combining policy and value networks
pub struct NeuralGraphOptimizer {
    /// Policy network: SNN that outputs graph modification actions
    policy_snn: PolicySNN,
    /// Value network: estimates mincut improvement
    value_network: ValueNetwork,
    /// Experience replay buffer
    replay_buffer: PrioritizedReplayBuffer,
    /// Current graph state
    graph: DynamicGraph,
    /// Configuration
    config: OptimizerConfig,
    /// Current simulation time
    time: SimTime,
    /// Previous mincut for reward computation
    prev_mincut: f64,
    /// Previous state for experience storage
    prev_state: Vec<f64>,
    /// Search statistics
    search_latencies: VecDeque<f64>,
}

impl NeuralGraphOptimizer {
    /// Create a new neural graph optimizer
    pub fn new(graph: DynamicGraph, config: OptimizerConfig) -> Self {
        let prev_state = extract_features(&graph, config.input_size);
        let prev_mincut = estimate_mincut(&graph);

        Self {
            policy_snn: PolicySNN::new(config.clone()),
            value_network: ValueNetwork::new(config.input_size, config.hidden_size),
            replay_buffer: PrioritizedReplayBuffer::new(config.replay_buffer_size),
            graph,
            config,
            time: 0.0,
            prev_mincut,
            prev_state,
            search_latencies: VecDeque::with_capacity(100),
        }
    }

    /// Run one optimization step
    pub fn optimize_step(&mut self) -> OptimizationResult {
        // 1. Encode current state as spike pattern
        let state = extract_features(&self.graph, self.config.input_size);

        // 2. Policy SNN outputs action distribution via spike rates
        self.policy_snn.inject(&state);
        let action_spikes = self.policy_snn.run_until_decision(50);
        let action = self.decode_action(&action_spikes);

        // 3. Execute action on graph
        let old_mincut = estimate_mincut(&self.graph);
        self.apply_action(&action);
        let new_mincut = estimate_mincut(&self.graph);

        // 4. Compute reward: mincut improvement + search efficiency
        let mincut_reward = if old_mincut > 0.0 {
            (new_mincut - old_mincut) / old_mincut
        } else {
            0.0
        };

        let search_reward = self.measure_search_efficiency();
        let reward = mincut_reward + self.config.search_weight * search_reward;

        // 5. TD learning update via spike-timing
        let new_state = extract_features(&self.graph, self.config.input_size);
        let current_value = self.value_network.estimate(&state);
        let next_value = self.value_network.estimate(&new_state);

        let td_error = reward + self.config.gamma * next_value - current_value;

        // 6. STDP-based policy gradient
        self.policy_snn.apply_reward_modulated_stdp(td_error);

        // 7. Update value network
        self.value_network
            .update(&state, td_error, self.config.learning_rate);

        // 8. Store experience
        let exp = Experience {
            state: self.prev_state.clone(),
            action_idx: action.to_index(),
            reward,
            next_state: new_state.clone(),
            done: false,
            td_error,
        };
        self.replay_buffer.push(exp);

        // Update state
        self.prev_state = new_state;
        self.prev_mincut = new_mincut;
        self.time += self.config.dt;

        OptimizationResult {
            action,
            reward,
            new_mincut,
            search_latency: search_reward,
        }
    }

    /// Decode action from spikes
    fn decode_action(&self, spikes: &[Spike]) -> GraphAction {
        if spikes.is_empty() {
            return GraphAction::NoOp;
        }

        // Use first spike's neuron as action
        let action_idx = spikes[0].neuron_id;

        // Pick candidate vertices derived from the spiking neuron index
        let vertices: Vec<_> = self.graph.vertices();

        if vertices.len() < 2 {
            return GraphAction::NoOp;
        }

        let v1 = vertices[action_idx % vertices.len()];
        let v2 = vertices[(action_idx + 1) % vertices.len()];

        match action_idx % 5 {
            0 => {
                if !self.graph.has_edge(v1, v2) {
                    GraphAction::AddEdge(v1, v2, 1.0)
                } else {
                    GraphAction::NoOp
                }
            }
            1 => {
                if self.graph.has_edge(v1, v2) {
                    GraphAction::RemoveEdge(v1, v2)
                } else {
                    GraphAction::NoOp
                }
            }
            2 => GraphAction::Strengthen(v1, v2, 0.1),
            3 => GraphAction::Weaken(v1, v2, 0.1),
            _ => GraphAction::NoOp,
        }
    }

    /// Apply action to graph
    fn apply_action(&mut self, action: &GraphAction) {
        match action {
            GraphAction::AddEdge(u, v, w) => {
                if !self.graph.has_edge(*u, *v) {
                    let _ = self.graph.insert_edge(*u, *v, *w);
                }
            }
            GraphAction::RemoveEdge(u, v) => {
                let _ = self.graph.delete_edge(*u, *v);
            }
            GraphAction::Strengthen(u, v, delta) => {
                if let Some(edge) = self.graph.get_edge(*u, *v) {
                    let _ = self.graph.update_edge_weight(*u, *v, edge.weight + delta);
                }
            }
            GraphAction::Weaken(u, v, delta) => {
                if let Some(edge) = self.graph.get_edge(*u, *v) {
                    let new_weight = (edge.weight - delta).max(0.01);
                    let _ = self.graph.update_edge_weight(*u, *v, new_weight);
                }
            }
            GraphAction::NoOp => {}
        }
    }

    /// Measure search efficiency
    fn measure_search_efficiency(&mut self) -> f64 {
        // Simplified: based on graph connectivity
        let n = self.graph.num_vertices() as f64;
        let m = self.graph.num_edges() as f64;

        if n < 2.0 {
            return 0.0;
        }

        // Higher connectivity relative to vertices = better search
        let efficiency = m / (n * (n - 1.0) / 2.0);

        self.search_latencies.push_back(efficiency);
        if self.search_latencies.len() > 100 {
            self.search_latencies.pop_front();
        }

        efficiency
    }

    /// Get learned skip regions for subpolynomial search
    pub fn search_skip_regions(&self) -> Vec<usize> {
        self.policy_snn.low_activity_regions()
    }

    /// Search with learned structure
    ///
    /// The query vector is not yet used by this simplified degree-based ranking.
    pub fn search(&self, _query: &[f64], k: usize) -> Vec<VertexId> {
        // Use skip regions to guide search
        let skip_regions = self.search_skip_regions();

        // Simple nearest neighbor in graph space
        let vertices: Vec<_> = self.graph.vertices();

        let mut scores: Vec<(VertexId, f64)> = vertices
            .iter()
            .enumerate()
            .filter(|(i, _)| !skip_regions.contains(i))
            .map(|(_, &v)| {
                // Score based on degree (proxy for centrality)
                let score = self.graph.degree(v) as f64;
                (v, score)
            })
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        scores.into_iter().take(k).map(|(v, _)| v).collect()
    }

    /// Get underlying graph
    pub fn graph(&self) -> &DynamicGraph {
        &self.graph
    }

    /// Get mutable graph
    pub fn graph_mut(&mut self) -> &mut DynamicGraph {
        &mut self.graph
    }

    /// Run multiple optimization steps
    pub fn optimize(&mut self, steps: usize) -> Vec<OptimizationResult> {
        (0..steps).map(|_| self.optimize_step()).collect()
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.policy_snn.reset();
        self.prev_mincut = estimate_mincut(&self.graph);
        self.prev_state = extract_features(&self.graph, self.config.input_size);
        self.time = 0.0;
    }
}

/// Extract features from graph
fn extract_features(graph: &DynamicGraph, num_features: usize) -> Vec<f64> {
    let n = graph.num_vertices() as f64;
    let m = graph.num_edges() as f64;

    let mut features = vec![0.0; num_features];

    if num_features > 0 {
        features[0] = n / 1000.0; // Normalized vertex count
    }
    if num_features > 1 {
        features[1] = m / 5000.0; // Normalized edge count
    }
    if num_features > 2 {
        features[2] = if n > 1.0 {
            m / (n * (n - 1.0) / 2.0)
        } else {
            0.0
        }; // Density
    }
    if num_features > 3 {
        // Average degree
        let avg_deg: f64 = graph
            .vertices()
            .iter()
            .map(|&v| graph.degree(v) as f64)
            .sum::<f64>()
            / n.max(1.0);
        features[3] = avg_deg / 10.0;
    }
    if num_features > 4 {
        features[4] = estimate_mincut(graph) / m.max(1.0); // Normalized mincut
    }

    // Fill rest with zeros or derived features
    for i in 5..num_features {
        features[i] = features[i % 5] * 0.1;
    }

    features
}

/// Estimate mincut (simplified: minimum vertex degree as a cheap proxy)
fn estimate_mincut(graph: &DynamicGraph) -> f64 {
    if graph.num_vertices() == 0 {
        return 0.0;
    }

    graph
        .vertices()
        .iter()
        .map(|&v| graph.degree(v) as f64)
        .fold(f64::INFINITY, f64::min)
}

// Thread-safe PRNG helpers using atomic CAS
use std::sync::atomic::{AtomicU64, Ordering};
static OPTIMIZER_RNG: AtomicU64 = AtomicU64::new(0xdeadbeef12345678);

fn rand_small() -> f64 {
    // Use a compare_exchange loop to advance the LCG state atomically
    let state = loop {
        let current = OPTIMIZER_RNG.load(Ordering::Relaxed);
        let next = current.wrapping_mul(0x5851f42d4c957f2d).wrapping_add(1);
        match OPTIMIZER_RNG.compare_exchange_weak(
            current,
            next,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => break next,
            Err(_) => continue,
        }
    };
    // Map the raw state to a small symmetric range around zero: [-0.2, 0.2]
    (state as f64) / (u64::MAX as f64) * 0.4 - 0.2
}

fn relu(x: f64) -> f64 {
    x.max(0.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_value_network() {
        let mut network = ValueNetwork::new(5, 10);

        let state = vec![0.5, 0.3, 0.8, 0.2, 0.9];
        let value = network.estimate(&state);

        assert!(value.is_finite());
    }
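
    // Added sketch: repeated TD updates toward a fixed target should shrink the
    // error; this exercises only `estimate` and `update` as defined above.
    #[test]
    fn test_value_network_update() {
        let mut network = ValueNetwork::new(4, 8);
        let state = vec![0.2, 0.4, 0.6, 0.8];
        let target = 1.0;

        let initial_error = (target - network.estimate(&state)).abs();
        for _ in 0..200 {
            let estimate = network.estimate(&state);
            let td_error = target - estimate;
            network.update(&state, td_error, 0.01);
        }
        let final_error = (target - network.estimate(&state)).abs();

        assert!(final_error.is_finite());
        assert!(final_error < initial_error);
    }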

    #[test]
    fn test_policy_snn() {
        let config = OptimizerConfig::default();
        let num_actions = config.num_actions;
        let mut policy = PolicySNN::new(config);

        let state = vec![1.0; 10];
        policy.inject(&state);

        let spikes = policy.run_until_decision(100);
        // May or may not spike, but a single step yields at most one spike per output neuron
        assert!(spikes.len() <= num_actions);
    }
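
    // Added sketch: the replay buffer evicts the oldest experience at capacity and
    // prioritized sampling returns the largest-|TD error| experiences first.
    #[test]
    fn test_replay_buffer() {
        let mut buffer = PrioritizedReplayBuffer::new(2);
        for &td_error in &[0.1, 0.9, 0.5] {
            buffer.push(Experience {
                state: vec![0.0],
                action_idx: 0,
                reward: 0.0,
                next_state: vec![0.0],
                done: false,
                td_error,
            });
        }

        // Capacity is 2, so the oldest experience (td_error = 0.1) was evicted.
        assert_eq!(buffer.len(), 2);

        // Highest |TD error| comes back first.
        let sampled = buffer.sample(1);
        assert_eq!(sampled.len(), 1);
        assert!((sampled[0].td_error - 0.9).abs() < 1e-12);
    }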

    #[test]
    fn test_neural_optimizer() {
        let graph = DynamicGraph::new();
        for i in 0..10 {
            graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let result = optimizer.optimize_step();

        assert!(result.new_mincut.is_finite());
    }
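
    // Added sketch: the degree-based mincut proxy on a unit-weight ring should be
    // a small, positive, finite value (every vertex has the same degree).
    #[test]
    fn test_estimate_mincut_ring() {
        let graph = DynamicGraph::new();
        for i in 0..10 {
            graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
        }

        let estimate = estimate_mincut(&graph);
        assert!(estimate.is_finite());
        assert!(estimate > 0.0);
    }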

    #[test]
    fn test_optimize_multiple() {
        let graph = DynamicGraph::new();
        for i in 0..5 {
            for j in (i + 1)..5 {
                graph.insert_edge(i, j, 1.0).unwrap();
            }
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.optimize(10);
        assert_eq!(results.len(), 10);
    }
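
    // Added sketch: feature extraction should always return the requested number of
    // finite features, independent of graph size.
    #[test]
    fn test_extract_features_length() {
        let graph = DynamicGraph::new();
        for i in 0..5 {
            graph.insert_edge(i, (i + 1) % 5, 1.0).unwrap();
        }

        let features = extract_features(&graph, 8);
        assert_eq!(features.len(), 8);
        assert!(features.iter().all(|f| f.is_finite()));
    }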

    #[test]
    fn test_search() {
        let graph = DynamicGraph::new();
        for i in 0..20 {
            graph.insert_edge(i, (i + 1) % 20, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.search(&[0.5; 10], 5);
        assert!(results.len() <= 5);
    }
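
    // Added sketch: the action-to-index mapping should match the arms of
    // `GraphAction::to_index`.
    #[test]
    fn test_graph_action_index() {
        assert_eq!(GraphAction::AddEdge(0, 1, 1.0).to_index(), 0);
        assert_eq!(GraphAction::RemoveEdge(0, 1).to_index(), 1);
        assert_eq!(GraphAction::NoOp.to_index(), 4);
    }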
}