1use super::{
18 neuron::{LIFNeuron, NeuronConfig, NeuronPopulation},
19 synapse::{Synapse, SynapseMatrix, STDPConfig},
20 network::{SpikingNetwork, NetworkConfig, LayerConfig},
21 SimTime, Spike,
22};
23use crate::graph::{DynamicGraph, VertexId, EdgeId, Weight};
24use std::collections::VecDeque;
25
/// Tunable parameters shared by the policy SNN, the value network and the
/// optimizer loop.
#[derive(Clone, Debug)]
pub struct OptimizerConfig {
    /// Length of the graph feature vector; also the number of policy input
    /// neurons and the input width of the value network.
    pub input_size: usize,
    /// Number of hidden units (policy SNN hidden layer and value network).
    pub hidden_size: usize,
    /// Number of neurons in the policy output layer.
    pub num_actions: usize,
    /// Step size handed to the value-network update.
    pub learning_rate: f64,
    /// Discount applied to the next-state value when forming the TD error.
    pub gamma: f64,
    /// Weight of the search-efficiency term in the reward.
    pub search_weight: f64,
    /// Capacity of the experience replay buffer.
    pub replay_buffer_size: usize,
    /// Replay batch size. NOTE(review): not read anywhere in this file.
    pub batch_size: usize,
    /// Simulation time step.
    pub dt: f64,
}

impl Default for OptimizerConfig {
    /// Baseline configuration: a 10-32-5 network with a replay buffer of
    /// 10 000 experiences and unit time step.
    fn default() -> Self {
        // Network shape.
        let (input_size, hidden_size, num_actions) = (10, 32, 5);
        // Learning hyper-parameters.
        let (learning_rate, gamma, search_weight) = (0.01, 0.99, 0.1);
        // Replay settings.
        let (replay_buffer_size, batch_size) = (10000, 32);

        OptimizerConfig {
            input_size,
            hidden_size,
            num_actions,
            learning_rate,
            gamma,
            search_weight,
            replay_buffer_size,
            batch_size,
            dt: 1.0,
        }
    }
}
64
65#[derive(Debug, Clone, PartialEq)]
67pub enum GraphAction {
68 AddEdge(VertexId, VertexId, Weight),
70 RemoveEdge(VertexId, VertexId),
72 Strengthen(VertexId, VertexId, f64),
74 Weaken(VertexId, VertexId, f64),
76 NoOp,
78}
79
80impl GraphAction {
81 pub fn to_index(&self) -> usize {
83 match self {
84 GraphAction::AddEdge(..) => 0,
85 GraphAction::RemoveEdge(..) => 1,
86 GraphAction::Strengthen(..) => 2,
87 GraphAction::Weaken(..) => 3,
88 GraphAction::NoOp => 4,
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
95pub struct OptimizationResult {
96 pub action: GraphAction,
98 pub reward: f64,
100 pub new_mincut: f64,
102 pub search_latency: f64,
104}
105
/// One recorded transition together with the TD error observed for it.
#[derive(Debug, Clone)]
struct Experience {
    /// Feature vector before the action.
    state: Vec<f64>,
    /// Index of the action taken (see `GraphAction::to_index`).
    action_idx: usize,
    /// Reward received for the transition.
    reward: f64,
    /// Feature vector after the action.
    next_state: Vec<f64>,
    /// Whether the transition ended an episode.
    done: bool,
    /// Temporal-difference error; its magnitude is the replay priority.
    td_error: f64,
}

impl Experience {
    /// Replay priority: `|td_error|`, with NaN mapped to the lowest possible
    /// priority so that comparisons form a total order.
    fn priority(&self) -> f64 {
        let magnitude = self.td_error.abs();
        if magnitude.is_nan() {
            f64::NEG_INFINITY
        } else {
            magnitude
        }
    }
}

/// Bounded FIFO of experiences that can be sampled by TD-error magnitude.
struct PrioritizedReplayBuffer {
    /// Stored experiences, oldest at the front.
    buffer: VecDeque<Experience>,
    /// Maximum number of experiences retained.
    capacity: usize,
}

impl PrioritizedReplayBuffer {
    /// Creates an empty buffer that retains at most `capacity` experiences.
    fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Appends `exp`, evicting the oldest entry when the buffer is full.
    ///
    /// A buffer created with capacity 0 stores nothing.
    fn push(&mut self, exp: Experience) {
        if self.capacity == 0 {
            return;
        }
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(exp);
    }

    /// Returns up to `batch_size` experiences ordered by descending
    /// `|td_error|`.
    ///
    /// Ties keep insertion order (the sort is stable). Experiences whose
    /// TD error is NaN sort last instead of corrupting the ordering.
    fn sample(&self, batch_size: usize) -> Vec<&Experience> {
        let mut ranked: Vec<&Experience> = self.buffer.iter().collect();
        // `total_cmp` gives a genuine total order, which `sort_by` requires.
        ranked.sort_by(|a, b| b.priority().total_cmp(&a.priority()));
        ranked.truncate(batch_size);
        ranked
    }

    /// Number of experiences currently stored.
    fn len(&self) -> usize {
        self.buffer.len()
    }
}
160
161#[derive(Debug, Clone)]
163pub struct ValueNetwork {
164 w_hidden: Vec<Vec<f64>>,
166 b_hidden: Vec<f64>,
168 w_output: Vec<f64>,
170 b_output: f64,
172 last_estimate: f64,
174}
175
176impl ValueNetwork {
177 pub fn new(input_size: usize, hidden_size: usize) -> Self {
179 let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
181 let w_hidden: Vec<Vec<f64>> = (0..hidden_size)
182 .map(|_| (0..input_size).map(|_| rand_small() * scale).collect())
183 .collect();
184
185 let b_hidden = vec![0.0; hidden_size];
186
187 let output_scale = (1.0 / hidden_size as f64).sqrt();
188 let w_output: Vec<f64> = (0..hidden_size).map(|_| rand_small() * output_scale).collect();
189 let b_output = 0.0;
190
191 Self {
192 w_hidden,
193 b_hidden,
194 w_output,
195 b_output,
196 last_estimate: 0.0,
197 }
198 }
199
200 pub fn estimate(&mut self, state: &[f64]) -> f64 {
202 let mut hidden = vec![0.0; self.w_hidden.len()];
204 for (j, weights) in self.w_hidden.iter().enumerate() {
205 let mut sum = self.b_hidden[j];
206 for (i, &w) in weights.iter().enumerate() {
207 if i < state.len() {
208 sum += w * state[i];
209 }
210 }
211 hidden[j] = relu(sum);
212 }
213
214 let mut output = self.b_output;
216 for (j, &w) in self.w_output.iter().enumerate() {
217 output += w * hidden[j];
218 }
219
220 self.last_estimate = output;
221 output
222 }
223
224 pub fn estimate_previous(&self) -> f64 {
226 self.last_estimate
227 }
228
229 pub fn update(&mut self, state: &[f64], td_error: f64, lr: f64) {
236 let hidden_size = self.w_hidden.len();
237 let input_size = if self.w_hidden.is_empty() { 0 } else { self.w_hidden[0].len() };
238
239 let mut hidden_pre = vec![0.0; hidden_size]; let mut hidden_post = vec![0.0; hidden_size]; for (j, weights) in self.w_hidden.iter().enumerate() {
244 let mut sum = self.b_hidden[j];
245 for (i, &w) in weights.iter().enumerate() {
246 if i < state.len() {
247 sum += w * state[i];
248 }
249 }
250 hidden_pre[j] = sum;
251 hidden_post[j] = relu(sum);
252 }
253
254 for (j, w) in self.w_output.iter_mut().enumerate() {
260 *w += lr * td_error * hidden_post[j];
261 }
262 self.b_output += lr * td_error;
263
264 for (j, weights) in self.w_hidden.iter_mut().enumerate() {
270 let relu_grad = if hidden_pre[j] > 0.0 { 1.0 } else { 0.0 };
272 let delta = td_error * self.w_output[j] * relu_grad;
273
274 for (i, w) in weights.iter_mut().enumerate() {
275 if i < state.len() {
276 *w += lr * delta * state[i];
277 }
278 }
279 self.b_hidden[j] += lr * delta;
280 }
281 }
282}
283
284pub struct PolicySNN {
286 input_layer: NeuronPopulation,
288 hidden_layer: NeuronPopulation,
290 output_layer: NeuronPopulation,
292 w_ih: SynapseMatrix,
294 w_ho: SynapseMatrix,
296 stdp_config: STDPConfig,
298 config: OptimizerConfig,
300}
301
302impl PolicySNN {
303 pub fn new(config: OptimizerConfig) -> Self {
305 let input_config = NeuronConfig {
306 tau_membrane: 10.0,
307 threshold: 0.8,
308 ..NeuronConfig::default()
309 };
310
311 let hidden_config = NeuronConfig {
312 tau_membrane: 20.0,
313 threshold: 1.0,
314 ..NeuronConfig::default()
315 };
316
317 let output_config = NeuronConfig {
318 tau_membrane: 15.0,
319 threshold: 0.6,
320 ..NeuronConfig::default()
321 };
322
323 let input_layer = NeuronPopulation::with_config(config.input_size, input_config);
324 let hidden_layer = NeuronPopulation::with_config(config.hidden_size, hidden_config);
325 let output_layer = NeuronPopulation::with_config(config.num_actions, output_config);
326
327 let mut w_ih = SynapseMatrix::new(config.input_size, config.hidden_size);
329 let mut w_ho = SynapseMatrix::new(config.hidden_size, config.num_actions);
330
331 for i in 0..config.input_size {
333 for j in 0..config.hidden_size {
334 w_ih.add_synapse(i, j, rand_small() + 0.3);
335 }
336 }
337
338 for i in 0..config.hidden_size {
339 for j in 0..config.num_actions {
340 w_ho.add_synapse(i, j, rand_small() + 0.3);
341 }
342 }
343
344 Self {
345 input_layer,
346 hidden_layer,
347 output_layer,
348 w_ih,
349 w_ho,
350 stdp_config: STDPConfig::default(),
351 config,
352 }
353 }
354
355 pub fn inject(&mut self, state: &[f64]) {
357 for (i, neuron) in self.input_layer.neurons.iter_mut().enumerate() {
358 if i < state.len() {
359 neuron.set_membrane_potential(state[i]);
360 }
361 }
362 }
363
364 pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
366 for step in 0..max_steps {
367 let time = step as f64 * self.config.dt;
368
369 let mut hidden_currents = vec![0.0; self.config.hidden_size];
371 for j in 0..self.config.hidden_size {
372 for i in 0..self.config.input_size {
373 hidden_currents[j] += self.w_ih.weight(i, j) *
374 self.input_layer.neurons[i].membrane_potential().max(0.0);
375 }
376 }
377
378 let hidden_spikes = self.hidden_layer.step(&hidden_currents, self.config.dt);
380
381 let mut output_currents = vec![0.0; self.config.num_actions];
383 for j in 0..self.config.num_actions {
384 for i in 0..self.config.hidden_size {
385 output_currents[j] += self.w_ho.weight(i, j) *
386 self.hidden_layer.neurons[i].membrane_potential().max(0.0);
387 }
388 }
389
390 let output_spikes = self.output_layer.step(&output_currents, self.config.dt);
392
393 for spike in &hidden_spikes {
395 self.w_ih.on_post_spike(spike.neuron_id, time);
396 }
397 for spike in &output_spikes {
398 self.w_ho.on_post_spike(spike.neuron_id, time);
399 }
400
401 if !output_spikes.is_empty() {
403 return output_spikes;
404 }
405 }
406
407 Vec::new()
408 }
409
410 pub fn apply_reward_modulated_stdp(&mut self, td_error: f64) {
412 self.w_ih.apply_reward(td_error);
413 self.w_ho.apply_reward(td_error);
414 }
415
416 pub fn low_activity_regions(&self) -> Vec<usize> {
418 self.hidden_layer.spike_trains
419 .iter()
420 .enumerate()
421 .filter(|(_, t)| t.spike_rate(100.0) < 0.001)
422 .map(|(i, _)| i)
423 .collect()
424 }
425
426 pub fn reset(&mut self) {
428 self.input_layer.reset();
429 self.hidden_layer.reset();
430 self.output_layer.reset();
431 }
432}
433
434pub struct NeuralGraphOptimizer {
436 policy_snn: PolicySNN,
438 value_network: ValueNetwork,
440 replay_buffer: PrioritizedReplayBuffer,
442 graph: DynamicGraph,
444 config: OptimizerConfig,
446 time: SimTime,
448 prev_mincut: f64,
450 prev_state: Vec<f64>,
452 search_latencies: VecDeque<f64>,
454}
455
456impl NeuralGraphOptimizer {
457 pub fn new(graph: DynamicGraph, config: OptimizerConfig) -> Self {
459 let prev_state = extract_features(&graph, config.input_size);
460 let prev_mincut = estimate_mincut(&graph);
461
462 Self {
463 policy_snn: PolicySNN::new(config.clone()),
464 value_network: ValueNetwork::new(config.input_size, config.hidden_size),
465 replay_buffer: PrioritizedReplayBuffer::new(config.replay_buffer_size),
466 graph,
467 config,
468 time: 0.0,
469 prev_mincut,
470 prev_state,
471 search_latencies: VecDeque::with_capacity(100),
472 }
473 }
474
475 pub fn optimize_step(&mut self) -> OptimizationResult {
477 let state = extract_features(&self.graph, self.config.input_size);
479
480 self.policy_snn.inject(&state);
482 let action_spikes = self.policy_snn.run_until_decision(50);
483 let action = self.decode_action(&action_spikes);
484
485 let old_mincut = estimate_mincut(&self.graph);
487 self.apply_action(&action);
488 let new_mincut = estimate_mincut(&self.graph);
489
490 let mincut_reward = if old_mincut > 0.0 {
492 (new_mincut - old_mincut) / old_mincut
493 } else {
494 0.0
495 };
496
497 let search_reward = self.measure_search_efficiency();
498 let reward = mincut_reward + self.config.search_weight * search_reward;
499
500 let new_state = extract_features(&self.graph, self.config.input_size);
502 let current_value = self.value_network.estimate(&state);
503 let next_value = self.value_network.estimate(&new_state);
504
505 let td_error = reward + self.config.gamma * next_value - current_value;
506
507 self.policy_snn.apply_reward_modulated_stdp(td_error);
509
510 self.value_network.update(&state, td_error, self.config.learning_rate);
512
513 let exp = Experience {
515 state: self.prev_state.clone(),
516 action_idx: action.to_index(),
517 reward,
518 next_state: new_state.clone(),
519 done: false,
520 td_error,
521 };
522 self.replay_buffer.push(exp);
523
524 self.prev_state = new_state;
526 self.prev_mincut = new_mincut;
527 self.time += self.config.dt;
528
529 OptimizationResult {
530 action,
531 reward,
532 new_mincut,
533 search_latency: search_reward,
534 }
535 }
536
537 fn decode_action(&self, spikes: &[Spike]) -> GraphAction {
539 if spikes.is_empty() {
540 return GraphAction::NoOp;
541 }
542
543 let action_idx = spikes[0].neuron_id;
545
546 let vertices: Vec<_> = self.graph.vertices();
548
549 if vertices.len() < 2 {
550 return GraphAction::NoOp;
551 }
552
553 let v1 = vertices[action_idx % vertices.len()];
554 let v2 = vertices[(action_idx + 1) % vertices.len()];
555
556 match action_idx % 5 {
557 0 => {
558 if !self.graph.has_edge(v1, v2) {
559 GraphAction::AddEdge(v1, v2, 1.0)
560 } else {
561 GraphAction::NoOp
562 }
563 }
564 1 => {
565 if self.graph.has_edge(v1, v2) {
566 GraphAction::RemoveEdge(v1, v2)
567 } else {
568 GraphAction::NoOp
569 }
570 }
571 2 => GraphAction::Strengthen(v1, v2, 0.1),
572 3 => GraphAction::Weaken(v1, v2, 0.1),
573 _ => GraphAction::NoOp,
574 }
575 }
576
577 fn apply_action(&mut self, action: &GraphAction) {
579 match action {
580 GraphAction::AddEdge(u, v, w) => {
581 if !self.graph.has_edge(*u, *v) {
582 let _ = self.graph.insert_edge(*u, *v, *w);
583 }
584 }
585 GraphAction::RemoveEdge(u, v) => {
586 let _ = self.graph.delete_edge(*u, *v);
587 }
588 GraphAction::Strengthen(u, v, delta) => {
589 if let Some(edge) = self.graph.get_edge(*u, *v) {
590 let _ = self.graph.update_edge_weight(*u, *v, edge.weight + delta);
591 }
592 }
593 GraphAction::Weaken(u, v, delta) => {
594 if let Some(edge) = self.graph.get_edge(*u, *v) {
595 let new_weight = (edge.weight - delta).max(0.01);
596 let _ = self.graph.update_edge_weight(*u, *v, new_weight);
597 }
598 }
599 GraphAction::NoOp => {}
600 }
601 }
602
603 fn measure_search_efficiency(&mut self) -> f64 {
605 let n = self.graph.num_vertices() as f64;
607 let m = self.graph.num_edges() as f64;
608
609 if n < 2.0 {
610 return 0.0;
611 }
612
613 let efficiency = m / (n * (n - 1.0) / 2.0);
615
616 self.search_latencies.push_back(efficiency);
617 if self.search_latencies.len() > 100 {
618 self.search_latencies.pop_front();
619 }
620
621 efficiency
622 }
623
624 pub fn search_skip_regions(&self) -> Vec<usize> {
626 self.policy_snn.low_activity_regions()
627 }
628
629 pub fn search(&self, query: &[f64], k: usize) -> Vec<VertexId> {
631 let skip_regions = self.search_skip_regions();
633
634 let vertices: Vec<_> = self.graph.vertices();
636
637 let mut scores: Vec<(VertexId, f64)> = vertices.iter()
638 .enumerate()
639 .filter(|(i, _)| !skip_regions.contains(i))
640 .map(|(i, &v)| {
641 let score = self.graph.degree(v) as f64;
643 (v, score)
644 })
645 .collect();
646
647 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
648
649 scores.into_iter().take(k).map(|(v, _)| v).collect()
650 }
651
652 pub fn graph(&self) -> &DynamicGraph {
654 &self.graph
655 }
656
657 pub fn graph_mut(&mut self) -> &mut DynamicGraph {
659 &mut self.graph
660 }
661
662 pub fn optimize(&mut self, steps: usize) -> Vec<OptimizationResult> {
664 (0..steps).map(|_| self.optimize_step()).collect()
665 }
666
667 pub fn reset(&mut self) {
669 self.policy_snn.reset();
670 self.prev_mincut = estimate_mincut(&self.graph);
671 self.prev_state = extract_features(&self.graph, self.config.input_size);
672 self.time = 0.0;
673 }
674}
675
676fn extract_features(graph: &DynamicGraph, num_features: usize) -> Vec<f64> {
678 let n = graph.num_vertices() as f64;
679 let m = graph.num_edges() as f64;
680
681 let mut features = vec![0.0; num_features];
682
683 if num_features > 0 {
684 features[0] = n / 1000.0; }
686 if num_features > 1 {
687 features[1] = m / 5000.0; }
689 if num_features > 2 {
690 features[2] = if n > 1.0 { m / (n * (n - 1.0) / 2.0) } else { 0.0 }; }
692 if num_features > 3 {
693 let avg_deg: f64 = graph.vertices().iter()
695 .map(|&v| graph.degree(v) as f64)
696 .sum::<f64>() / n.max(1.0);
697 features[3] = avg_deg / 10.0;
698 }
699 if num_features > 4 {
700 features[4] = estimate_mincut(graph) / m.max(1.0); }
702
703 for i in 5..num_features {
705 features[i] = features[i % 5] * 0.1;
706 }
707
708 features
709}
710
711fn estimate_mincut(graph: &DynamicGraph) -> f64 {
713 if graph.num_vertices() == 0 {
714 return 0.0;
715 }
716
717 graph.vertices()
718 .iter()
719 .map(|&v| graph.degree(v) as f64)
720 .fold(f64::INFINITY, f64::min)
721}
722
use std::sync::atomic::{AtomicU64, Ordering};

/// State of the module-wide linear congruential generator.
static OPTIMIZER_RNG: AtomicU64 = AtomicU64::new(0xdeadbeef12345678);

/// Advances the shared generator by one step and maps the new state to a
/// value in `[-0.2, 0.2]`.
///
/// The sequence is fixed by the hard-coded seed, so results are reproducible
/// for a given call order.
fn rand_small() -> f64 {
    const MULTIPLIER: u64 = 0x5851f42d4c957f2d;
    let advance = |s: u64| s.wrapping_mul(MULTIPLIER).wrapping_add(1);

    // `fetch_update` retries until the swap succeeds and hands back the
    // state it replaced; the closure never declines, so both arms carry it.
    let previous = match OPTIMIZER_RNG.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
        Some(advance(s))
    }) {
        Ok(prev) | Err(prev) => prev,
    };

    let state = advance(previous);
    let unit = (state as f64) / (u64::MAX as f64);
    unit * 0.4 - 0.2
}
739
/// Rectified linear unit: the larger of `x` and zero.
fn relu(x: f64) -> f64 {
    f64::max(x, 0.0)
}
743
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly initialised value network yields a finite estimate.
    #[test]
    fn test_value_network() {
        let mut network = ValueNetwork::new(5, 10);

        let state = vec![0.5, 0.3, 0.8, 0.2, 0.9];
        let value = network.estimate(&state);

        assert!(value.is_finite());
    }

    /// Smoke test: injecting a state and running the policy completes
    /// without panicking, whether or not an output neuron fires.
    #[test]
    fn test_policy_snn() {
        let config = OptimizerConfig::default();
        let mut policy = PolicySNN::new(config);

        let state = vec![1.0; 10];
        policy.inject(&state);

        // The spike count is deliberately not asserted: a `usize >= 0`
        // check is always true and would verify nothing.
        let _spikes = policy.run_until_decision(100);
    }

    /// One optimization step on a 10-vertex ring produces a finite min-cut
    /// estimate.
    #[test]
    fn test_neural_optimizer() {
        let graph = DynamicGraph::new();
        for i in 0..10 {
            graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let result = optimizer.optimize_step();

        assert!(result.new_mincut.is_finite());
    }

    /// `optimize(n)` returns exactly one result per step.
    #[test]
    fn test_optimize_multiple() {
        let graph = DynamicGraph::new();
        for i in 0..5 {
            for j in (i + 1)..5 {
                graph.insert_edge(i, j, 1.0).unwrap();
            }
        }

        let config = OptimizerConfig::default();
        let mut optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.optimize(10);
        assert_eq!(results.len(), 10);
    }

    /// `search` never returns more than `k` vertices.
    #[test]
    fn test_search() {
        let graph = DynamicGraph::new();
        for i in 0..20 {
            graph.insert_edge(i, (i + 1) % 20, 1.0).unwrap();
        }

        let config = OptimizerConfig::default();
        let optimizer = NeuralGraphOptimizer::new(graph, config);

        let results = optimizer.search(&[0.5; 10], 5);
        assert!(results.len() <= 5);
    }
}