1use super::{
18 network::{LayerConfig, NetworkConfig, SpikingNetwork},
19 neuron::{LIFNeuron, NeuronConfig, NeuronPopulation},
20 synapse::{STDPConfig, Synapse, SynapseMatrix},
21 SimTime, Spike,
22};
23use crate::graph::{DynamicGraph, EdgeId, VertexId, Weight};
24use std::collections::VecDeque;
25
/// Hyper-parameters shared by the policy network, the value network and the
/// optimisation loop.
#[derive(Clone, Debug)]
pub struct OptimizerConfig {
    /// Length of the graph feature vector; also the number of input neurons
    /// of the policy network and the input width of the value network.
    pub input_size: usize,
    /// Number of hidden units in both the policy and the value network.
    pub hidden_size: usize,
    /// Number of output neurons of the policy network.
    pub num_actions: usize,
    /// Step size used when updating the value network.
    pub learning_rate: f64,
    /// Discount applied to the next-state value in the TD error.
    pub gamma: f64,
    /// Weight of the search-efficiency term in the reward.
    pub search_weight: f64,
    /// Capacity of the experience replay buffer.
    pub replay_buffer_size: usize,
    /// Replay batch size.
    ///
    /// NOTE(review): not read anywhere in the code visible in this file;
    /// confirm where it is meant to be consumed.
    pub batch_size: usize,
    /// Simulation time step passed to the neuron populations and used to
    /// advance the optimizer clock.
    pub dt: f64,
}

impl Default for OptimizerConfig {
    /// Returns the baseline configuration: a 10-32-5 network, learning rate
    /// `0.01`, discount `0.99`, search weight `0.1`, a 10 000-entry replay
    /// buffer, batch size 32 and a unit time step.
    fn default() -> Self {
        OptimizerConfig {
            // Network shape.
            input_size: 10,
            hidden_size: 32,
            num_actions: 5,
            // Learning signal.
            learning_rate: 0.01,
            gamma: 0.99,
            search_weight: 0.1,
            // Experience replay.
            replay_buffer_size: 10000,
            batch_size: 32,
            // Simulation.
            dt: 1.0,
        }
    }
}
64
65#[derive(Debug, Clone, PartialEq)]
67pub enum GraphAction {
68 AddEdge(VertexId, VertexId, Weight),
70 RemoveEdge(VertexId, VertexId),
72 Strengthen(VertexId, VertexId, f64),
74 Weaken(VertexId, VertexId, f64),
76 NoOp,
78}
79
80impl GraphAction {
81 pub fn to_index(&self) -> usize {
83 match self {
84 GraphAction::AddEdge(..) => 0,
85 GraphAction::RemoveEdge(..) => 1,
86 GraphAction::Strengthen(..) => 2,
87 GraphAction::Weaken(..) => 3,
88 GraphAction::NoOp => 4,
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
95pub struct OptimizationResult {
96 pub action: GraphAction,
98 pub reward: f64,
100 pub new_mincut: f64,
102 pub search_latency: f64,
104}
105
/// One recorded transition, as stored in the replay buffer.
#[derive(Clone, Debug)]
struct Experience {
    /// Feature vector before the action.
    state: Vec<f64>,
    /// Index of the action kind that was taken.
    action_idx: usize,
    /// Reward received for the transition.
    reward: f64,
    /// Feature vector after the action.
    next_state: Vec<f64>,
    /// Whether the transition ended an episode.
    done: bool,
    /// Temporal-difference error of the transition; its magnitude is the
    /// replay priority.
    td_error: f64,
}
122
123struct PrioritizedReplayBuffer {
125 buffer: VecDeque<Experience>,
127 capacity: usize,
129}
130
131impl PrioritizedReplayBuffer {
132 fn new(capacity: usize) -> Self {
133 Self {
134 buffer: VecDeque::with_capacity(capacity),
135 capacity,
136 }
137 }
138
139 fn push(&mut self, exp: Experience) {
140 if self.buffer.len() >= self.capacity {
141 self.buffer.pop_front();
142 }
143 self.buffer.push_back(exp);
144 }
145
146 fn sample(&self, batch_size: usize) -> Vec<&Experience> {
147 let mut sorted: Vec<_> = self.buffer.iter().collect();
149 sorted.sort_by(|a, b| {
150 b.td_error
151 .abs()
152 .partial_cmp(&a.td_error.abs())
153 .unwrap_or(std::cmp::Ordering::Equal)
154 });
155
156 sorted.into_iter().take(batch_size).collect()
157 }
158
159 fn len(&self) -> usize {
160 self.buffer.len()
161 }
162}
163
164#[derive(Debug, Clone)]
166pub struct ValueNetwork {
167 w_hidden: Vec<Vec<f64>>,
169 b_hidden: Vec<f64>,
171 w_output: Vec<f64>,
173 b_output: f64,
175 last_estimate: f64,
177}
178
179impl ValueNetwork {
180 pub fn new(input_size: usize, hidden_size: usize) -> Self {
182 let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
184 let w_hidden: Vec<Vec<f64>> = (0..hidden_size)
185 .map(|_| (0..input_size).map(|_| rand_small() * scale).collect())
186 .collect();
187
188 let b_hidden = vec![0.0; hidden_size];
189
190 let output_scale = (1.0 / hidden_size as f64).sqrt();
191 let w_output: Vec<f64> = (0..hidden_size)
192 .map(|_| rand_small() * output_scale)
193 .collect();
194 let b_output = 0.0;
195
196 Self {
197 w_hidden,
198 b_hidden,
199 w_output,
200 b_output,
201 last_estimate: 0.0,
202 }
203 }
204
205 pub fn estimate(&mut self, state: &[f64]) -> f64 {
207 let mut hidden = vec![0.0; self.w_hidden.len()];
209 for (j, weights) in self.w_hidden.iter().enumerate() {
210 let mut sum = self.b_hidden[j];
211 for (i, &w) in weights.iter().enumerate() {
212 if i < state.len() {
213 sum += w * state[i];
214 }
215 }
216 hidden[j] = relu(sum);
217 }
218
219 let mut output = self.b_output;
221 for (j, &w) in self.w_output.iter().enumerate() {
222 output += w * hidden[j];
223 }
224
225 self.last_estimate = output;
226 output
227 }
228
229 pub fn estimate_previous(&self) -> f64 {
231 self.last_estimate
232 }
233
234 pub fn update(&mut self, state: &[f64], td_error: f64, lr: f64) {
241 let hidden_size = self.w_hidden.len();
242 let input_size = if self.w_hidden.is_empty() {
243 0
244 } else {
245 self.w_hidden[0].len()
246 };
247
248 let mut hidden_pre = vec![0.0; hidden_size]; let mut hidden_post = vec![0.0; hidden_size]; for (j, weights) in self.w_hidden.iter().enumerate() {
253 let mut sum = self.b_hidden[j];
254 for (i, &w) in weights.iter().enumerate() {
255 if i < state.len() {
256 sum += w * state[i];
257 }
258 }
259 hidden_pre[j] = sum;
260 hidden_post[j] = relu(sum);
261 }
262
263 for (j, w) in self.w_output.iter_mut().enumerate() {
269 *w += lr * td_error * hidden_post[j];
270 }
271 self.b_output += lr * td_error;
272
273 for (j, weights) in self.w_hidden.iter_mut().enumerate() {
279 let relu_grad = if hidden_pre[j] > 0.0 { 1.0 } else { 0.0 };
281 let delta = td_error * self.w_output[j] * relu_grad;
282
283 for (i, w) in weights.iter_mut().enumerate() {
284 if i < state.len() {
285 *w += lr * delta * state[i];
286 }
287 }
288 self.b_hidden[j] += lr * delta;
289 }
290 }
291}
292
293pub struct PolicySNN {
295 input_layer: NeuronPopulation,
297 hidden_layer: NeuronPopulation,
299 output_layer: NeuronPopulation,
301 w_ih: SynapseMatrix,
303 w_ho: SynapseMatrix,
305 stdp_config: STDPConfig,
307 config: OptimizerConfig,
309}
310
311impl PolicySNN {
312 pub fn new(config: OptimizerConfig) -> Self {
314 let input_config = NeuronConfig {
315 tau_membrane: 10.0,
316 threshold: 0.8,
317 ..NeuronConfig::default()
318 };
319
320 let hidden_config = NeuronConfig {
321 tau_membrane: 20.0,
322 threshold: 1.0,
323 ..NeuronConfig::default()
324 };
325
326 let output_config = NeuronConfig {
327 tau_membrane: 15.0,
328 threshold: 0.6,
329 ..NeuronConfig::default()
330 };
331
332 let input_layer = NeuronPopulation::with_config(config.input_size, input_config);
333 let hidden_layer = NeuronPopulation::with_config(config.hidden_size, hidden_config);
334 let output_layer = NeuronPopulation::with_config(config.num_actions, output_config);
335
336 let mut w_ih = SynapseMatrix::new(config.input_size, config.hidden_size);
338 let mut w_ho = SynapseMatrix::new(config.hidden_size, config.num_actions);
339
340 for i in 0..config.input_size {
342 for j in 0..config.hidden_size {
343 w_ih.add_synapse(i, j, rand_small() + 0.3);
344 }
345 }
346
347 for i in 0..config.hidden_size {
348 for j in 0..config.num_actions {
349 w_ho.add_synapse(i, j, rand_small() + 0.3);
350 }
351 }
352
353 Self {
354 input_layer,
355 hidden_layer,
356 output_layer,
357 w_ih,
358 w_ho,
359 stdp_config: STDPConfig::default(),
360 config,
361 }
362 }
363
364 pub fn inject(&mut self, state: &[f64]) {
366 for (i, neuron) in self.input_layer.neurons.iter_mut().enumerate() {
367 if i < state.len() {
368 neuron.set_membrane_potential(state[i]);
369 }
370 }
371 }
372
373 pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
375 for step in 0..max_steps {
376 let time = step as f64 * self.config.dt;
377
378 let mut hidden_currents = vec![0.0; self.config.hidden_size];
380 for j in 0..self.config.hidden_size {
381 for i in 0..self.config.input_size {
382 hidden_currents[j] += self.w_ih.weight(i, j)
383 * self.input_layer.neurons[i].membrane_potential().max(0.0);
384 }
385 }
386
387 let hidden_spikes = self.hidden_layer.step(&hidden_currents, self.config.dt);
389
390 let mut output_currents = vec![0.0; self.config.num_actions];
392 for j in 0..self.config.num_actions {
393 for i in 0..self.config.hidden_size {
394 output_currents[j] += self.w_ho.weight(i, j)
395 * self.hidden_layer.neurons[i].membrane_potential().max(0.0);
396 }
397 }
398
399 let output_spikes = self.output_layer.step(&output_currents, self.config.dt);
401
402 for spike in &hidden_spikes {
404 self.w_ih.on_post_spike(spike.neuron_id, time);
405 }
406 for spike in &output_spikes {
407 self.w_ho.on_post_spike(spike.neuron_id, time);
408 }
409
410 if !output_spikes.is_empty() {
412 return output_spikes;
413 }
414 }
415
416 Vec::new()
417 }
418
419 pub fn apply_reward_modulated_stdp(&mut self, td_error: f64) {
421 self.w_ih.apply_reward(td_error);
422 self.w_ho.apply_reward(td_error);
423 }
424
425 pub fn low_activity_regions(&self) -> Vec<usize> {
427 self.hidden_layer
428 .spike_trains
429 .iter()
430 .enumerate()
431 .filter(|(_, t)| t.spike_rate(100.0) < 0.001)
432 .map(|(i, _)| i)
433 .collect()
434 }
435
436 pub fn reset(&mut self) {
438 self.input_layer.reset();
439 self.hidden_layer.reset();
440 self.output_layer.reset();
441 }
442}
443
444pub struct NeuralGraphOptimizer {
446 policy_snn: PolicySNN,
448 value_network: ValueNetwork,
450 replay_buffer: PrioritizedReplayBuffer,
452 graph: DynamicGraph,
454 config: OptimizerConfig,
456 time: SimTime,
458 prev_mincut: f64,
460 prev_state: Vec<f64>,
462 search_latencies: VecDeque<f64>,
464}
465
466impl NeuralGraphOptimizer {
467 pub fn new(graph: DynamicGraph, config: OptimizerConfig) -> Self {
469 let prev_state = extract_features(&graph, config.input_size);
470 let prev_mincut = estimate_mincut(&graph);
471
472 Self {
473 policy_snn: PolicySNN::new(config.clone()),
474 value_network: ValueNetwork::new(config.input_size, config.hidden_size),
475 replay_buffer: PrioritizedReplayBuffer::new(config.replay_buffer_size),
476 graph,
477 config,
478 time: 0.0,
479 prev_mincut,
480 prev_state,
481 search_latencies: VecDeque::with_capacity(100),
482 }
483 }
484
485 pub fn optimize_step(&mut self) -> OptimizationResult {
487 let state = extract_features(&self.graph, self.config.input_size);
489
490 self.policy_snn.inject(&state);
492 let action_spikes = self.policy_snn.run_until_decision(50);
493 let action = self.decode_action(&action_spikes);
494
495 let old_mincut = estimate_mincut(&self.graph);
497 self.apply_action(&action);
498 let new_mincut = estimate_mincut(&self.graph);
499
500 let mincut_reward = if old_mincut > 0.0 {
502 (new_mincut - old_mincut) / old_mincut
503 } else {
504 0.0
505 };
506
507 let search_reward = self.measure_search_efficiency();
508 let reward = mincut_reward + self.config.search_weight * search_reward;
509
510 let new_state = extract_features(&self.graph, self.config.input_size);
512 let current_value = self.value_network.estimate(&state);
513 let next_value = self.value_network.estimate(&new_state);
514
515 let td_error = reward + self.config.gamma * next_value - current_value;
516
517 self.policy_snn.apply_reward_modulated_stdp(td_error);
519
520 self.value_network
522 .update(&state, td_error, self.config.learning_rate);
523
524 let exp = Experience {
526 state: self.prev_state.clone(),
527 action_idx: action.to_index(),
528 reward,
529 next_state: new_state.clone(),
530 done: false,
531 td_error,
532 };
533 self.replay_buffer.push(exp);
534
535 self.prev_state = new_state;
537 self.prev_mincut = new_mincut;
538 self.time += self.config.dt;
539
540 OptimizationResult {
541 action,
542 reward,
543 new_mincut,
544 search_latency: search_reward,
545 }
546 }
547
548 fn decode_action(&self, spikes: &[Spike]) -> GraphAction {
550 if spikes.is_empty() {
551 return GraphAction::NoOp;
552 }
553
554 let action_idx = spikes[0].neuron_id;
556
557 let vertices: Vec<_> = self.graph.vertices();
559
560 if vertices.len() < 2 {
561 return GraphAction::NoOp;
562 }
563
564 let v1 = vertices[action_idx % vertices.len()];
565 let v2 = vertices[(action_idx + 1) % vertices.len()];
566
567 match action_idx % 5 {
568 0 => {
569 if !self.graph.has_edge(v1, v2) {
570 GraphAction::AddEdge(v1, v2, 1.0)
571 } else {
572 GraphAction::NoOp
573 }
574 }
575 1 => {
576 if self.graph.has_edge(v1, v2) {
577 GraphAction::RemoveEdge(v1, v2)
578 } else {
579 GraphAction::NoOp
580 }
581 }
582 2 => GraphAction::Strengthen(v1, v2, 0.1),
583 3 => GraphAction::Weaken(v1, v2, 0.1),
584 _ => GraphAction::NoOp,
585 }
586 }
587
588 fn apply_action(&mut self, action: &GraphAction) {
590 match action {
591 GraphAction::AddEdge(u, v, w) => {
592 if !self.graph.has_edge(*u, *v) {
593 let _ = self.graph.insert_edge(*u, *v, *w);
594 }
595 }
596 GraphAction::RemoveEdge(u, v) => {
597 let _ = self.graph.delete_edge(*u, *v);
598 }
599 GraphAction::Strengthen(u, v, delta) => {
600 if let Some(edge) = self.graph.get_edge(*u, *v) {
601 let _ = self.graph.update_edge_weight(*u, *v, edge.weight + delta);
602 }
603 }
604 GraphAction::Weaken(u, v, delta) => {
605 if let Some(edge) = self.graph.get_edge(*u, *v) {
606 let new_weight = (edge.weight - delta).max(0.01);
607 let _ = self.graph.update_edge_weight(*u, *v, new_weight);
608 }
609 }
610 GraphAction::NoOp => {}
611 }
612 }
613
614 fn measure_search_efficiency(&mut self) -> f64 {
616 let n = self.graph.num_vertices() as f64;
618 let m = self.graph.num_edges() as f64;
619
620 if n < 2.0 {
621 return 0.0;
622 }
623
624 let efficiency = m / (n * (n - 1.0) / 2.0);
626
627 self.search_latencies.push_back(efficiency);
628 if self.search_latencies.len() > 100 {
629 self.search_latencies.pop_front();
630 }
631
632 efficiency
633 }
634
635 pub fn search_skip_regions(&self) -> Vec<usize> {
637 self.policy_snn.low_activity_regions()
638 }
639
640 pub fn search(&self, query: &[f64], k: usize) -> Vec<VertexId> {
642 let skip_regions = self.search_skip_regions();
644
645 let vertices: Vec<_> = self.graph.vertices();
647
648 let mut scores: Vec<(VertexId, f64)> = vertices
649 .iter()
650 .enumerate()
651 .filter(|(i, _)| !skip_regions.contains(i))
652 .map(|(i, &v)| {
653 let score = self.graph.degree(v) as f64;
655 (v, score)
656 })
657 .collect();
658
659 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
660
661 scores.into_iter().take(k).map(|(v, _)| v).collect()
662 }
663
664 pub fn graph(&self) -> &DynamicGraph {
666 &self.graph
667 }
668
669 pub fn graph_mut(&mut self) -> &mut DynamicGraph {
671 &mut self.graph
672 }
673
674 pub fn optimize(&mut self, steps: usize) -> Vec<OptimizationResult> {
676 (0..steps).map(|_| self.optimize_step()).collect()
677 }
678
679 pub fn reset(&mut self) {
681 self.policy_snn.reset();
682 self.prev_mincut = estimate_mincut(&self.graph);
683 self.prev_state = extract_features(&self.graph, self.config.input_size);
684 self.time = 0.0;
685 }
686}
687
688fn extract_features(graph: &DynamicGraph, num_features: usize) -> Vec<f64> {
690 let n = graph.num_vertices() as f64;
691 let m = graph.num_edges() as f64;
692
693 let mut features = vec![0.0; num_features];
694
695 if num_features > 0 {
696 features[0] = n / 1000.0; }
698 if num_features > 1 {
699 features[1] = m / 5000.0; }
701 if num_features > 2 {
702 features[2] = if n > 1.0 {
703 m / (n * (n - 1.0) / 2.0)
704 } else {
705 0.0
706 }; }
708 if num_features > 3 {
709 let avg_deg: f64 = graph
711 .vertices()
712 .iter()
713 .map(|&v| graph.degree(v) as f64)
714 .sum::<f64>()
715 / n.max(1.0);
716 features[3] = avg_deg / 10.0;
717 }
718 if num_features > 4 {
719 features[4] = estimate_mincut(graph) / m.max(1.0); }
721
722 for i in 5..num_features {
724 features[i] = features[i % 5] * 0.1;
725 }
726
727 features
728}
729
730fn estimate_mincut(graph: &DynamicGraph) -> f64 {
732 if graph.num_vertices() == 0 {
733 return 0.0;
734 }
735
736 graph
737 .vertices()
738 .iter()
739 .map(|&v| graph.degree(v) as f64)
740 .fold(f64::INFINITY, f64::min)
741}
742
743use std::sync::atomic::{AtomicU64, Ordering};
/// Shared state of the deterministic generator behind [`rand_small`].
static OPTIMIZER_RNG: AtomicU64 = AtomicU64::new(0xdead_beef_1234_5678);

/// Returns the next pseudo-random value in `[-0.2, 0.2]`.
///
/// The generator is a 64-bit linear congruential sequence with a fixed seed,
/// so the values produced by a process are reproducible. The state is
/// advanced atomically; concurrent callers each consume a distinct step.
fn rand_small() -> f64 {
    /// One step of the congruential recurrence.
    fn advance(state: u64) -> u64 {
        state.wrapping_mul(0x5851_f42d_4c95_7f2d).wrapping_add(1)
    }

    // The closure always yields a new state, so the update cannot fail; both
    // arms carry the value that was stored before this call.
    let previous = OPTIMIZER_RNG
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(advance(s)))
        .unwrap_or_else(|unchanged| unchanged);

    let state = advance(previous);
    (state as f64) / (u64::MAX as f64) * 0.4 - 0.2
}
764
/// Rectified linear unit: `x` when it is positive, otherwise zero.
fn relu(x: f64) -> f64 {
    f64::max(x, 0.0)
}
768
#[cfg(test)]
mod tests {
    use super::*;

    /// The value network produces a finite estimate and remembers it.
    #[test]
    fn test_value_network() {
        let mut network = ValueNetwork::new(5, 10);

        let state = vec![0.5, 0.3, 0.8, 0.2, 0.9];
        let value = network.estimate(&state);

        assert!(value.is_finite());
        assert_eq!(network.estimate_previous(), value);
    }

    /// The policy network can be driven without panicking.
    #[test]
    fn test_policy_snn() {
        let config = OptimizerConfig::default();
        let mut policy = PolicySNN::new(config);

        let state = vec![1.0; 10];
        policy.inject(&state);

        // With a zero-step budget the simulation loop never runs, so no
        // decision can be produced.
        assert!(policy.run_until_decision(0).is_empty());

        // Smoke test: whether any output neuron fires depends on the random
        // initial weights, so only completion is checked here.
        let _spikes = policy.run_until_decision(100);
    }

    /// A single optimisation step on a 10-vertex ring yields finite numbers.
    #[test]
    fn test_neural_optimizer() {
        let graph = DynamicGraph::new();
        for i in 0..10 {
            graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let result = optimizer.optimize_step();

        assert!(result.new_mincut.is_finite());
    }

    /// `optimize` returns exactly one result per requested step.
    #[test]
    fn test_optimize_multiple() {
        let graph = DynamicGraph::new();
        for i in 0..5 {
            for j in (i + 1)..5 {
                graph.insert_edge(i, j, 1.0).unwrap();
            }
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.optimize(10);
        assert_eq!(results.len(), 10);
        assert!(results.iter().all(|r| r.new_mincut.is_finite()));
    }

    /// `search` never returns more than `k` vertices.
    #[test]
    fn test_search() {
        let graph = DynamicGraph::new();
        for i in 0..20 {
            graph.insert_edge(i, (i + 1) % 20, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.search(&[0.5; 10], 5);
        assert!(results.len() <= 5);
    }
}