ruvector_mincut/snn/optimizer.rs

//! # Layer 6: Neural Optimizer as Reinforcement Learning on Graphs
//!
//! Implements neural network-based graph optimization using policy gradients.
//!
//! ## Architecture
//!
//! - **Policy SNN**: Outputs graph modification actions via spike rates
//! - **Value Network**: Estimates mincut improvement
//! - **Experience Replay**: Prioritized sampling for stable learning
//!
//! ## Key Features
//!
//! - STDP-based policy gradient for spike-driven learning
//! - TD learning via spike-timing for value estimation
//! - Subpolynomial search exploiting learned graph structure
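//!
//! ## Example
//!
//! A minimal usage sketch (not a doctest): the `build_graph` helper is assumed
//! here for illustration only; any populated `DynamicGraph` works.
//!
//! ```ignore
//! let graph = build_graph(); // hypothetical helper returning a DynamicGraph
//! let mut optimizer = NeuralGraphOptimizer::new(graph, OptimizerConfig::default());
//! for _ in 0..100 {
//!     // Pick an action, apply it to the graph, and learn from the TD error.
//!     let step = optimizer.optimize_step();
//!     println!("action={:?} reward={:.3} mincut={:.3}", step.action, step.reward, step.new_mincut);
//! }
//! ```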

use super::{
    neuron::{LIFNeuron, NeuronConfig, NeuronPopulation},
    synapse::{Synapse, SynapseMatrix, STDPConfig},
    network::{SpikingNetwork, NetworkConfig, LayerConfig},
    SimTime, Spike,
};
use crate::graph::{DynamicGraph, VertexId, EdgeId, Weight};
use std::collections::VecDeque;

/// Configuration for neural graph optimizer
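///
/// All fields have defaults; override individual ones with struct-update
/// syntax, e.g. `OptimizerConfig { hidden_size: 64, ..Default::default() }`.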
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Number of input features
    pub input_size: usize,
    /// Hidden layer size
    pub hidden_size: usize,
    /// Number of possible actions
    pub num_actions: usize,
    /// Learning rate
    pub learning_rate: f64,
    /// Discount factor (gamma)
    pub gamma: f64,
    /// Reward weight for search efficiency
    pub search_weight: f64,
    /// Experience replay buffer size
    pub replay_buffer_size: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Time step
    pub dt: f64,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            input_size: 10,
            hidden_size: 32,
            num_actions: 5,
            learning_rate: 0.01,
            gamma: 0.99,
            search_weight: 0.1,
            replay_buffer_size: 10000,
            batch_size: 32,
            dt: 1.0,
        }
    }
}

/// Graph modification action
#[derive(Debug, Clone, PartialEq)]
pub enum GraphAction {
    /// Add an edge between vertices
    AddEdge(VertexId, VertexId, Weight),
    /// Remove an edge
    RemoveEdge(VertexId, VertexId),
    /// Strengthen an edge (increase weight)
    Strengthen(VertexId, VertexId, f64),
    /// Weaken an edge (decrease weight)
    Weaken(VertexId, VertexId, f64),
    /// No action
    NoOp,
}

impl GraphAction {
    /// Get action index
    pub fn to_index(&self) -> usize {
        match self {
            GraphAction::AddEdge(..) => 0,
            GraphAction::RemoveEdge(..) => 1,
            GraphAction::Strengthen(..) => 2,
            GraphAction::Weaken(..) => 3,
            GraphAction::NoOp => 4,
        }
    }
}

/// Result of an optimization step
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Action taken
    pub action: GraphAction,
    /// Reward received
    pub reward: f64,
    /// New mincut value
    pub new_mincut: f64,
    /// Search efficiency proxy reported for this step (see `measure_search_efficiency`)
    pub search_latency: f64,
}

/// Experience for replay buffer
#[derive(Debug, Clone)]
struct Experience {
    /// State features
    state: Vec<f64>,
    /// Action taken
    action_idx: usize,
    /// Reward received
    reward: f64,
    /// Next state features
    next_state: Vec<f64>,
    /// Is terminal state
    done: bool,
    /// TD error for prioritization
    td_error: f64,
}

/// Prioritized experience replay buffer
struct PrioritizedReplayBuffer {
    /// Stored experiences
    buffer: VecDeque<Experience>,
    /// Maximum capacity
    capacity: usize,
}

impl PrioritizedReplayBuffer {
    fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    fn push(&mut self, exp: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(exp);
    }

    fn sample(&self, batch_size: usize) -> Vec<&Experience> {
        // Prioritized by TD error (simplified: sort all experiences and take those with the largest |TD error|)
        let mut sorted: Vec<_> = self.buffer.iter().collect();
        sorted.sort_by(|a, b| {
            b.td_error.abs().partial_cmp(&a.td_error.abs()).unwrap_or(std::cmp::Ordering::Equal)
        });

        sorted.into_iter().take(batch_size).collect()
    }

    fn len(&self) -> usize {
        self.buffer.len()
    }
}

/// Simple value network for state value estimation
#[derive(Debug, Clone)]
pub struct ValueNetwork {
    /// Weights from input to hidden
    w_hidden: Vec<Vec<f64>>,
    /// Hidden biases
    b_hidden: Vec<f64>,
    /// Weights from hidden to output
    w_output: Vec<f64>,
    /// Output bias
    b_output: f64,
    /// Last estimate (for TD error)
    last_estimate: f64,
}

impl ValueNetwork {
    /// Create a new value network
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // Initialize with small random weights (Xavier initialization)
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
        let w_hidden: Vec<Vec<f64>> = (0..hidden_size)
            .map(|_| (0..input_size).map(|_| rand_small() * scale).collect())
            .collect();

        let b_hidden = vec![0.0; hidden_size];

        let output_scale = (1.0 / hidden_size as f64).sqrt();
        let w_output: Vec<f64> = (0..hidden_size).map(|_| rand_small() * output_scale).collect();
        let b_output = 0.0;

        Self {
            w_hidden,
            b_hidden,
            w_output,
            b_output,
            last_estimate: 0.0,
        }
    }

    /// Estimate value of a state
    pub fn estimate(&mut self, state: &[f64]) -> f64 {
        // Hidden layer
        let mut hidden = vec![0.0; self.w_hidden.len()];
        for (j, weights) in self.w_hidden.iter().enumerate() {
            let mut sum = self.b_hidden[j];
            for (i, &w) in weights.iter().enumerate() {
                if i < state.len() {
                    sum += w * state[i];
                }
            }
            hidden[j] = relu(sum);
        }

        // Output layer
        let mut output = self.b_output;
        for (j, &w) in self.w_output.iter().enumerate() {
            output += w * hidden[j];
        }

        self.last_estimate = output;
        output
    }

    /// Get previous estimate
    pub fn estimate_previous(&self) -> f64 {
        self.last_estimate
    }

    /// Update weights with TD error using proper backpropagation
    ///
    /// Implements gradient descent with:
    /// - Forward pass to compute activations
    /// - Backward pass to compute ∂V/∂w
    /// - Weight update: w += lr * td_error * ∂V/∂w
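    ///
    /// # Example
    ///
    /// A minimal TD(0) sketch (not a doctest); the reward, discount, and
    /// learning-rate values below are illustrative placeholders:
    ///
    /// ```ignore
    /// let mut value_net = ValueNetwork::new(4, 8);
    /// let state = vec![0.1, 0.2, 0.3, 0.4];
    /// let next_state = vec![0.2, 0.2, 0.3, 0.4];
    /// let (reward, gamma, lr) = (0.5, 0.99, 0.01);
    /// let current = value_net.estimate(&state);
    /// let next = value_net.estimate(&next_state);
    /// let td_error = reward + gamma * next - current;
    /// value_net.update(&state, td_error, lr); // nudges V(state) toward reward + gamma * V(next_state)
    /// ```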
    pub fn update(&mut self, state: &[f64], td_error: f64, lr: f64) {
        let hidden_size = self.w_hidden.len();

        // Forward pass: compute hidden activations and pre-activations
        let mut hidden_pre = vec![0.0; hidden_size]; // Before ReLU
        let mut hidden_post = vec![0.0; hidden_size]; // After ReLU

        for (j, weights) in self.w_hidden.iter().enumerate() {
            let mut sum = self.b_hidden[j];
            for (i, &w) in weights.iter().enumerate() {
                if i < state.len() {
                    sum += w * state[i];
                }
            }
            hidden_pre[j] = sum;
            hidden_post[j] = relu(sum);
        }

        // Backward pass: compute gradients
        // Minimizing L = 0.5 * td_error² with td_error = target - V gives the
        // semi-gradient update Δw = lr * td_error * ∂V/∂w
        // (for the output layer, ∂V/∂w_output[j] = hidden_post[j])

        // Update output weights: ∂V/∂w_output[j] = hidden_post[j]
        for (j, w) in self.w_output.iter_mut().enumerate() {
            *w += lr * td_error * hidden_post[j];
        }
        self.b_output += lr * td_error;

        // Backpropagate to hidden layer
        // ∂V/∂hidden_post[j] = w_output[j]
        // ∂hidden_post/∂hidden_pre = relu'(hidden_pre) = 1 if hidden_pre > 0 else 0
        // ∂V/∂w_hidden[j][i] = ∂V/∂hidden_post[j] * relu'(hidden_pre[j]) * state[i]

        for (j, weights) in self.w_hidden.iter_mut().enumerate() {
            // ReLU derivative: 1 if pre-activation > 0, else 0
            let relu_grad = if hidden_pre[j] > 0.0 { 1.0 } else { 0.0 };
            let delta = td_error * self.w_output[j] * relu_grad;

            for (i, w) in weights.iter_mut().enumerate() {
                if i < state.len() {
                    *w += lr * delta * state[i];
                }
            }
            self.b_hidden[j] += lr * delta;
        }
    }
}

/// Policy SNN for action selection
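///
/// # Example
///
/// The action-selection flow used by `NeuralGraphOptimizer`, shown as a
/// standalone sketch (not a doctest; the all-ones state is arbitrary):
///
/// ```ignore
/// let mut policy = PolicySNN::new(OptimizerConfig::default());
/// policy.inject(&[1.0; 10]);                  // encode state into the input layer
/// let spikes = policy.run_until_decision(50); // first output spikes select the action
/// policy.apply_reward_modulated_stdp(0.2);    // reinforce with a (positive) TD error
/// ```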
pub struct PolicySNN {
    /// Input layer
    input_layer: NeuronPopulation,
    /// Hidden recurrent layer
    hidden_layer: NeuronPopulation,
    /// Output layer (one neuron per action)
    output_layer: NeuronPopulation,
    /// Input → Hidden weights
    w_ih: SynapseMatrix,
    /// Hidden → Output weights
    w_ho: SynapseMatrix,
    /// Reward-modulated STDP configuration
    stdp_config: STDPConfig,
    /// Configuration
    config: OptimizerConfig,
}

impl PolicySNN {
    /// Create a new policy SNN
    pub fn new(config: OptimizerConfig) -> Self {
        let input_config = NeuronConfig {
            tau_membrane: 10.0,
            threshold: 0.8,
            ..NeuronConfig::default()
        };

        let hidden_config = NeuronConfig {
            tau_membrane: 20.0,
            threshold: 1.0,
            ..NeuronConfig::default()
        };

        let output_config = NeuronConfig {
            tau_membrane: 15.0,
            threshold: 0.6,
            ..NeuronConfig::default()
        };

        let input_layer = NeuronPopulation::with_config(config.input_size, input_config);
        let hidden_layer = NeuronPopulation::with_config(config.hidden_size, hidden_config);
        let output_layer = NeuronPopulation::with_config(config.num_actions, output_config);

        // Initialize weights
        let mut w_ih = SynapseMatrix::new(config.input_size, config.hidden_size);
        let mut w_ho = SynapseMatrix::new(config.hidden_size, config.num_actions);

        // Random initialization
        for i in 0..config.input_size {
            for j in 0..config.hidden_size {
                w_ih.add_synapse(i, j, rand_small() + 0.3);
            }
        }

        for i in 0..config.hidden_size {
            for j in 0..config.num_actions {
                w_ho.add_synapse(i, j, rand_small() + 0.3);
            }
        }

        Self {
            input_layer,
            hidden_layer,
            output_layer,
            w_ih,
            w_ho,
            stdp_config: STDPConfig::default(),
            config,
        }
    }

    /// Inject state by setting the input-layer membrane potentials
    pub fn inject(&mut self, state: &[f64]) {
        for (i, neuron) in self.input_layer.neurons.iter_mut().enumerate() {
            if i < state.len() {
                neuron.set_membrane_potential(state[i]);
            }
        }
    }

    /// Run until decision (output spike)
    pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
        for step in 0..max_steps {
            let time = step as f64 * self.config.dt;

            // Compute hidden currents from input
            let mut hidden_currents = vec![0.0; self.config.hidden_size];
            for j in 0..self.config.hidden_size {
                for i in 0..self.config.input_size {
                    hidden_currents[j] += self.w_ih.weight(i, j) *
                        self.input_layer.neurons[i].membrane_potential().max(0.0);
                }
            }

            // Update hidden layer
            let hidden_spikes = self.hidden_layer.step(&hidden_currents, self.config.dt);

            // Compute output currents from hidden
            let mut output_currents = vec![0.0; self.config.num_actions];
            for j in 0..self.config.num_actions {
                for i in 0..self.config.hidden_size {
                    output_currents[j] += self.w_ho.weight(i, j) *
                        self.hidden_layer.neurons[i].membrane_potential().max(0.0);
                }
            }

            // Update output layer
            let output_spikes = self.output_layer.step(&output_currents, self.config.dt);

            // STDP updates
            for spike in &hidden_spikes {
                self.w_ih.on_post_spike(spike.neuron_id, time);
            }
            for spike in &output_spikes {
                self.w_ho.on_post_spike(spike.neuron_id, time);
            }

            // Return if we have output spikes
            if !output_spikes.is_empty() {
                return output_spikes;
            }
        }

        Vec::new()
    }

    /// Apply reward-modulated STDP
    pub fn apply_reward_modulated_stdp(&mut self, td_error: f64) {
        self.w_ih.apply_reward(td_error);
        self.w_ho.apply_reward(td_error);
    }

    /// Get regions with low activity (for search skip)
    pub fn low_activity_regions(&self) -> Vec<usize> {
        self.hidden_layer.spike_trains
            .iter()
            .enumerate()
            .filter(|(_, t)| t.spike_rate(100.0) < 0.001)
            .map(|(i, _)| i)
            .collect()
    }

    /// Reset SNN state
    pub fn reset(&mut self) {
        self.input_layer.reset();
        self.hidden_layer.reset();
        self.output_layer.reset();
    }
}

/// Neural Graph Optimizer combining policy and value networks
pub struct NeuralGraphOptimizer {
    /// Policy network: SNN that outputs graph modification actions
    policy_snn: PolicySNN,
    /// Value network: estimates mincut improvement
    value_network: ValueNetwork,
    /// Experience replay buffer
    replay_buffer: PrioritizedReplayBuffer,
    /// Current graph state
    graph: DynamicGraph,
    /// Configuration
    config: OptimizerConfig,
    /// Current simulation time
    time: SimTime,
    /// Previous mincut for reward computation
    prev_mincut: f64,
    /// Previous state for experience storage
    prev_state: Vec<f64>,
    /// Search statistics
    search_latencies: VecDeque<f64>,
}

impl NeuralGraphOptimizer {
    /// Create a new neural graph optimizer
    pub fn new(graph: DynamicGraph, config: OptimizerConfig) -> Self {
        let prev_state = extract_features(&graph, config.input_size);
        let prev_mincut = estimate_mincut(&graph);

        Self {
            policy_snn: PolicySNN::new(config.clone()),
            value_network: ValueNetwork::new(config.input_size, config.hidden_size),
            replay_buffer: PrioritizedReplayBuffer::new(config.replay_buffer_size),
            graph,
            config,
            time: 0.0,
            prev_mincut,
            prev_state,
            search_latencies: VecDeque::with_capacity(100),
        }
    }

    /// Run one optimization step
    pub fn optimize_step(&mut self) -> OptimizationResult {
        // 1. Encode current state as spike pattern
        let state = extract_features(&self.graph, self.config.input_size);

        // 2. Policy SNN outputs action distribution via spike rates
        self.policy_snn.inject(&state);
        let action_spikes = self.policy_snn.run_until_decision(50);
        let action = self.decode_action(&action_spikes);

        // 3. Execute action on graph
        let old_mincut = estimate_mincut(&self.graph);
        self.apply_action(&action);
        let new_mincut = estimate_mincut(&self.graph);

        // 4. Compute reward: mincut improvement + search efficiency
        let mincut_reward = if old_mincut > 0.0 {
            (new_mincut - old_mincut) / old_mincut
        } else {
            0.0
        };

        let search_reward = self.measure_search_efficiency();
        let reward = mincut_reward + self.config.search_weight * search_reward;

        // 5. TD learning update via spike-timing
        let new_state = extract_features(&self.graph, self.config.input_size);
        let current_value = self.value_network.estimate(&state);
        let next_value = self.value_network.estimate(&new_state);

        let td_error = reward + self.config.gamma * next_value - current_value;

        // 6. STDP-based policy gradient
        self.policy_snn.apply_reward_modulated_stdp(td_error);

        // 7. Update value network
        self.value_network.update(&state, td_error, self.config.learning_rate);

        // 8. Store experience
        let exp = Experience {
            state: self.prev_state.clone(),
            action_idx: action.to_index(),
            reward,
            next_state: new_state.clone(),
            done: false,
            td_error,
        };
        self.replay_buffer.push(exp);

        // Update state
        self.prev_state = new_state;
        self.prev_mincut = new_mincut;
        self.time += self.config.dt;

        OptimizationResult {
            action,
            reward,
            new_mincut,
            search_latency: search_reward,
        }
    }

    /// Decode action from spikes
    fn decode_action(&self, spikes: &[Spike]) -> GraphAction {
        if spikes.is_empty() {
            return GraphAction::NoOp;
        }

        // Use first spike's neuron as action
        let action_idx = spikes[0].neuron_id;

        // Pick candidate vertices deterministically from the spiking neuron's index
        let vertices: Vec<_> = self.graph.vertices();

        if vertices.len() < 2 {
            return GraphAction::NoOp;
        }

        let v1 = vertices[action_idx % vertices.len()];
        let v2 = vertices[(action_idx + 1) % vertices.len()];

        match action_idx % 5 {
            0 => {
                if !self.graph.has_edge(v1, v2) {
                    GraphAction::AddEdge(v1, v2, 1.0)
                } else {
                    GraphAction::NoOp
                }
            }
            1 => {
                if self.graph.has_edge(v1, v2) {
                    GraphAction::RemoveEdge(v1, v2)
                } else {
                    GraphAction::NoOp
                }
            }
            2 => GraphAction::Strengthen(v1, v2, 0.1),
            3 => GraphAction::Weaken(v1, v2, 0.1),
            _ => GraphAction::NoOp,
        }
    }

    /// Apply action to graph
    fn apply_action(&mut self, action: &GraphAction) {
        match action {
            GraphAction::AddEdge(u, v, w) => {
                if !self.graph.has_edge(*u, *v) {
                    let _ = self.graph.insert_edge(*u, *v, *w);
                }
            }
            GraphAction::RemoveEdge(u, v) => {
                let _ = self.graph.delete_edge(*u, *v);
            }
            GraphAction::Strengthen(u, v, delta) => {
                if let Some(edge) = self.graph.get_edge(*u, *v) {
                    let _ = self.graph.update_edge_weight(*u, *v, edge.weight + delta);
                }
            }
            GraphAction::Weaken(u, v, delta) => {
                if let Some(edge) = self.graph.get_edge(*u, *v) {
                    let new_weight = (edge.weight - delta).max(0.01);
                    let _ = self.graph.update_edge_weight(*u, *v, new_weight);
                }
            }
            GraphAction::NoOp => {}
        }
    }

    /// Measure search efficiency
    fn measure_search_efficiency(&mut self) -> f64 {
        // Simplified: based on graph connectivity
        let n = self.graph.num_vertices() as f64;
        let m = self.graph.num_edges() as f64;

        if n < 2.0 {
            return 0.0;
        }

        // Higher connectivity relative to vertices = better search
        let efficiency = m / (n * (n - 1.0) / 2.0);

        self.search_latencies.push_back(efficiency);
        if self.search_latencies.len() > 100 {
            self.search_latencies.pop_front();
        }

        efficiency
    }

    /// Get learned skip regions for subpolynomial search
    pub fn search_skip_regions(&self) -> Vec<usize> {
        self.policy_snn.low_activity_regions()
    }

    /// Search with learned structure
    ///
    /// Note: the query vector is not yet used; ranking is purely degree-based.
    pub fn search(&self, _query: &[f64], k: usize) -> Vec<VertexId> {
        // Use skip regions to guide search
        let skip_regions = self.search_skip_regions();

        // Simple nearest neighbor in graph space
        let vertices: Vec<_> = self.graph.vertices();

        let mut scores: Vec<(VertexId, f64)> = vertices.iter()
            .enumerate()
            .filter(|(i, _)| !skip_regions.contains(i))
            .map(|(_, &v)| {
                // Score based on degree (proxy for centrality)
                let score = self.graph.degree(v) as f64;
                (v, score)
            })
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        scores.into_iter().take(k).map(|(v, _)| v).collect()
    }

    /// Get underlying graph
    pub fn graph(&self) -> &DynamicGraph {
        &self.graph
    }

    /// Get mutable graph
    pub fn graph_mut(&mut self) -> &mut DynamicGraph {
        &mut self.graph
    }

    /// Run multiple optimization steps
    pub fn optimize(&mut self, steps: usize) -> Vec<OptimizationResult> {
        (0..steps).map(|_| self.optimize_step()).collect()
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.policy_snn.reset();
        self.prev_mincut = estimate_mincut(&self.graph);
        self.prev_state = extract_features(&self.graph, self.config.input_size);
        self.time = 0.0;
    }
}

/// Extract features from graph
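///
/// Feature layout when `num_features >= 5`: normalized vertex count, normalized
/// edge count, density, scaled average degree, normalized mincut estimate;
/// any remaining slots are filled with scaled copies of the first five.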
fn extract_features(graph: &DynamicGraph, num_features: usize) -> Vec<f64> {
    let n = graph.num_vertices() as f64;
    let m = graph.num_edges() as f64;

    let mut features = vec![0.0; num_features];

    if num_features > 0 {
        features[0] = n / 1000.0;  // Normalized vertex count
    }
    if num_features > 1 {
        features[1] = m / 5000.0;  // Normalized edge count
    }
    if num_features > 2 {
        features[2] = if n > 1.0 { m / (n * (n - 1.0) / 2.0) } else { 0.0 };  // Density
    }
    if num_features > 3 {
        // Average degree
        let avg_deg: f64 = graph.vertices().iter()
            .map(|&v| graph.degree(v) as f64)
            .sum::<f64>() / n.max(1.0);
        features[3] = avg_deg / 10.0;
    }
    if num_features > 4 {
        features[4] = estimate_mincut(graph) / m.max(1.0);  // Normalized mincut
    }

    // Fill rest with zeros or derived features
    for i in 5..num_features {
        features[i] = features[i % 5] * 0.1;
    }

    features
}

/// Estimate mincut (simplified)
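///
/// Uses the minimum vertex degree as a cheap proxy; for unit edge weights this
/// upper-bounds the true edge connectivity, so it tracks rather than equals the
/// exact mincut.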
fn estimate_mincut(graph: &DynamicGraph) -> f64 {
    if graph.num_vertices() == 0 {
        return 0.0;
    }

    graph.vertices()
        .iter()
        .map(|&v| graph.degree(v) as f64)
        .fold(f64::INFINITY, f64::min)
}

// Thread-safe PRNG helpers using atomic CAS
use std::sync::atomic::{AtomicU64, Ordering};
static OPTIMIZER_RNG: AtomicU64 = AtomicU64::new(0xdeadbeef12345678);

fn rand_small() -> f64 {
    // Use compare_exchange loop to ensure atomicity
    let state = loop {
        let current = OPTIMIZER_RNG.load(Ordering::Relaxed);
        let next = current.wrapping_mul(0x5851f42d4c957f2d).wrapping_add(1);
        match OPTIMIZER_RNG.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => break next,
            Err(_) => continue,
        }
    };
    (state as f64) / (u64::MAX as f64) * 0.4 - 0.2
}

fn relu(x: f64) -> f64 {
    x.max(0.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_value_network() {
        let mut network = ValueNetwork::new(5, 10);

        let state = vec![0.5, 0.3, 0.8, 0.2, 0.9];
        let value = network.estimate(&state);

        assert!(value.is_finite());
    }

    #[test]
    fn test_policy_snn() {
        let config = OptimizerConfig::default();
        let mut policy = PolicySNN::new(config.clone());

        let state = vec![1.0; 10];
        policy.inject(&state);

        let spikes = policy.run_until_decision(100);
        // May or may not spike; any spikes must come from the output layer
        assert!(spikes.iter().all(|s| s.neuron_id < config.num_actions));
    }

    #[test]
    fn test_neural_optimizer() {
        let graph = DynamicGraph::new();
        for i in 0..10 {
            graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let result = optimizer.optimize_step();

        assert!(result.new_mincut.is_finite());
    }

    #[test]
    fn test_optimize_multiple() {
        let graph = DynamicGraph::new();
        for i in 0..5 {
            for j in (i + 1)..5 {
                graph.insert_edge(i, j, 1.0).unwrap();
            }
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.optimize(10);
        assert_eq!(results.len(), 10);
    }

    #[test]
    fn test_search() {
        let graph = DynamicGraph::new();
        for i in 0..20 {
            graph.insert_edge(i, (i + 1) % 20, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.search(&[0.5; 10], 5);
        assert!(results.len() <= 5);
    }
}