//! Layered spiking neural network with STDP plasticity and optional
//! graph-coupled recurrent connectivity.

use super::{
    neuron::{NeuronConfig, NeuronPopulation},
    synapse::{STDPConfig, SynapseMatrix},
    SimTime, Spike,
};
use crate::graph::DynamicGraph;
use std::collections::VecDeque;

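/// Configuration for a single layer of spiking neurons.
///
/// A minimal builder sketch (marked `ignore` since crate paths are omitted
/// here):
///
/// ```ignore
/// let layer = LayerConfig::new(64)
///     .with_recurrence()
///     .with_neuron_config(NeuronConfig::default());
/// ```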
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Number of neurons in the layer.
    pub size: usize,
    /// Parameters applied to every neuron in the layer.
    pub neuron_config: NeuronConfig,
    /// Whether the layer has sparse intra-layer (recurrent) synapses.
    pub recurrent: bool,
}

impl LayerConfig {
    pub fn new(size: usize) -> Self {
        Self {
            size,
            neuron_config: NeuronConfig::default(),
            recurrent: false,
        }
    }

    /// Enables sparse recurrent connections within the layer.
    pub fn with_recurrence(mut self) -> Self {
        self.recurrent = true;
        self
    }

    /// Overrides the per-neuron parameters for the layer.
    pub fn with_neuron_config(mut self, config: NeuronConfig) -> Self {
        self.neuron_config = config;
        self
    }
}

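/// Network-wide settings: layer sizes, STDP parameters, the integration
/// timestep, and optional winner-take-all competition in the output layer.
///
/// An illustrative override of the defaults (values are arbitrary; marked
/// `ignore` since crate paths are omitted):
///
/// ```ignore
/// let config = NetworkConfig {
///     layers: vec![LayerConfig::new(100), LayerConfig::new(10)],
///     winner_take_all: true,
///     ..NetworkConfig::default()
/// };
/// ```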
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Per-layer configurations, input first, output last.
    pub layers: Vec<LayerConfig>,
    /// Spike-timing-dependent plasticity parameters shared by all synapses.
    pub stdp_config: STDPConfig,
    /// Integration timestep (in the same units as `SimTime`).
    pub dt: f64,
    /// Whether the output layer competes via winner-take-all inhibition.
    pub winner_take_all: bool,
    /// Strength of the winner-take-all inhibition.
    pub wta_strength: f64,
}

impl Default for NetworkConfig {
    fn default() -> Self {
        Self {
            layers: vec![
                LayerConfig::new(100),
                LayerConfig::new(50),
                LayerConfig::new(10),
            ],
            stdp_config: STDPConfig::default(),
            dt: 1.0,
            winner_take_all: false,
            wta_strength: 0.8,
        }
    }
}

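/// A layered spiking neural network with STDP-plastic feedforward synapses
/// and optional sparse recurrent connectivity within layers.
///
/// A minimal simulation loop sketch (marked `ignore` since crate paths are
/// omitted):
///
/// ```ignore
/// let mut net = SpikingNetwork::new(NetworkConfig::default());
/// let stimulus = vec![5.0; net.layer_size(0)];
/// for _ in 0..100 {
///     net.inject_current(&stimulus);
///     let output_spikes = net.step();
/// }
/// ```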
#[derive(Debug, Clone)]
pub struct SpikingNetwork {
    pub config: NetworkConfig,
    layers: Vec<NeuronPopulation>,
    /// One synapse matrix per pair of consecutive layers.
    feedforward_weights: Vec<SynapseMatrix>,
    /// Optional intra-layer synapses, one slot per layer.
    recurrent_weights: Vec<Option<SynapseMatrix>>,
    time: SimTime,
    /// Reserved for delayed spike delivery; not yet consumed in this module.
    spike_buffer: VecDeque<(Spike, usize, SimTime)>,
    /// Decaying population-level inhibition used by winner-take-all.
    global_inhibition: f64,
}

impl SpikingNetwork {
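    /// Builds the network described by `config`: dense all-to-all feedforward
    /// weights between consecutive layers (uniform random in [0.25, 0.75])
    /// and, for layers marked recurrent, sparse intra-layer synapses
    /// (connection probability 0.1, half-strength weights, no
    /// self-connections).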
    pub fn new(config: NetworkConfig) -> Self {
        let mut layers = Vec::new();
        let mut feedforward_weights = Vec::new();
        let mut recurrent_weights = Vec::new();

        for (i, layer_config) in config.layers.iter().enumerate() {
            let population = NeuronPopulation::with_config(
                layer_config.size,
                layer_config.neuron_config.clone(),
            );
            layers.push(population);

            // Dense feedforward weights to the next layer, if any.
            if i + 1 < config.layers.len() {
                let next_size = config.layers[i + 1].size;
                let mut weights = SynapseMatrix::with_config(
                    layer_config.size,
                    next_size,
                    config.stdp_config.clone(),
                );

                for pre in 0..layer_config.size {
                    for post in 0..next_size {
                        let weight = rand_weight();
                        weights.add_synapse(pre, post, weight);
                    }
                }

                feedforward_weights.push(weights);
            }

            // Sparse recurrent weights within the layer, if requested.
            if layer_config.recurrent {
                let mut weights = SynapseMatrix::with_config(
                    layer_config.size,
                    layer_config.size,
                    config.stdp_config.clone(),
                );

                for pre in 0..layer_config.size {
                    for post in 0..layer_config.size {
                        if pre != post && rand_bool(0.1) {
                            weights.add_synapse(pre, post, rand_weight() * 0.5);
                        }
                    }
                }

                recurrent_weights.push(Some(weights));
            } else {
                recurrent_weights.push(None);
            }
        }

        Self {
            config,
            layers,
            feedforward_weights,
            recurrent_weights,
            time: 0.0,
            spike_buffer: VecDeque::new(),
            global_inhibition: 0.0,
        }
    }

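    /// Builds a single recurrent layer that mirrors `graph`: one neuron per
    /// vertex and a symmetric pair of synapses per edge, with synaptic
    /// weights copied from the edge weights.
    ///
    /// ```ignore
    /// let net = SpikingNetwork::from_graph(&graph, NetworkConfig::default());
    /// assert_eq!(net.layer_size(0), graph.num_vertices());
    /// ```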
    pub fn from_graph(graph: &DynamicGraph, config: NetworkConfig) -> Self {
        let n = graph.num_vertices();

        let mut network_config = config.clone();
        network_config.layers = vec![LayerConfig::new(n).with_recurrence()];

        let mut network = Self::new(network_config);

        if let Some(ref mut recurrent) = network.recurrent_weights[0] {
            let vertices: Vec<_> = graph.vertices();
            let vertex_to_idx: std::collections::HashMap<_, _> =
                vertices.iter().enumerate().map(|(i, &v)| (v, i)).collect();

            for edge in graph.edges() {
                if let (Some(&pre), Some(&post)) = (
                    vertex_to_idx.get(&edge.source),
                    vertex_to_idx.get(&edge.target),
                ) {
                    // Mirror each edge so the coupling is symmetric.
                    recurrent.set_weight(pre, post, edge.weight);
                    recurrent.set_weight(post, pre, edge.weight);
                }
            }
        }

        network
    }

    pub fn reset(&mut self) {
        self.time = 0.0;
        self.spike_buffer.clear();
        self.global_inhibition = 0.0;

        for layer in &mut self.layers {
            layer.reset();
        }
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Size of layer `layer`, or 0 if the index is out of range.
    pub fn layer_size(&self, layer: usize) -> usize {
        self.layers.get(layer).map(|l| l.size()).unwrap_or(0)
    }

    pub fn current_time(&self) -> SimTime {
        self.time
    }

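    /// Drives the input (first) layer with external currents: each neuron's
    /// membrane potential is nudged by `0.1 * current`. Entries beyond the
    /// input layer's size are ignored.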
    pub fn inject_current(&mut self, currents: &[f64]) {
        if !self.layers.is_empty() {
            let input_layer = &mut self.layers[0];
            let n = currents.len().min(input_layer.size());

            for (i, neuron) in input_layer.neurons.iter_mut().take(n).enumerate() {
                neuron.set_membrane_potential(neuron.membrane_potential() + currents[i] * 0.1);
            }
        }
    }

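    /// Advances the simulation by one timestep `dt` and returns the spikes
    /// emitted by the output (last) layer.
    ///
    /// Per layer, the input current is the weighted sum of rectified
    /// presynaptic membrane potentials (feedforward first, then recurrent),
    /// optionally shaped by winner-take-all inhibition in the output layer;
    /// each emitted spike then updates the STDP traces of the adjacent
    /// feedforward matrices.
    ///
    /// ```ignore
    /// net.inject_current(&stimulus);
    /// let output_spikes = net.step();
    /// ```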
    pub fn step(&mut self) -> Vec<Spike> {
        let dt = self.config.dt;
        self.time += dt;

        let mut all_spikes: Vec<Vec<Spike>> = Vec::new();

        for layer_idx in 0..self.layers.len() {
            let mut currents = vec![0.0; self.layers[layer_idx].size()];

            // Feedforward drive: weighted sum of rectified presynaptic potentials.
            if layer_idx > 0 {
                let weights = &self.feedforward_weights[layer_idx - 1];
                let pre_activations: Vec<f64> = self.layers[layer_idx - 1]
                    .neurons
                    .iter()
                    .map(|n| n.membrane_potential().max(0.0))
                    .collect();
                let ff_currents = weights.compute_weighted_sums(&pre_activations);
                for (j, &c) in ff_currents.iter().enumerate() {
                    currents[j] += c;
                }
            }

            // Recurrent drive from within the same layer, if configured.
            if let Some(ref weights) = self.recurrent_weights[layer_idx] {
                let activations: Vec<f64> = self.layers[layer_idx]
                    .neurons
                    .iter()
                    .map(|n| n.membrane_potential().max(0.0))
                    .collect();
                let rec_currents = weights.compute_weighted_sums(&activations);
                for (j, &c) in rec_currents.iter().enumerate() {
                    currents[j] += c;
                }
            }

            // Winner-take-all: inhibit every output neuron except the current leader.
            if self.config.winner_take_all && layer_idx == self.layers.len() - 1 {
                let max_v = self.layers[layer_idx]
                    .neurons
                    .iter()
                    .map(|n| n.membrane_potential())
                    .fold(f64::NEG_INFINITY, f64::max);

                for (i, neuron) in self.layers[layer_idx].neurons.iter().enumerate() {
                    if neuron.membrane_potential() < max_v {
                        currents[i] -= self.config.wta_strength * self.global_inhibition;
                    }
                }
            }

            let spikes = self.layers[layer_idx].step(&currents, dt);
            all_spikes.push(spikes.clone());

            // Global inhibition rises with activity and decays in silence.
            if !spikes.is_empty() {
                self.global_inhibition = (self.global_inhibition + 0.1).min(1.0);
            } else {
                self.global_inhibition *= 0.95;
            }

            // STDP traces are updated for the feedforward matrices only: each
            // spike is a postsynaptic event for the incoming matrix and a
            // presynaptic event for the outgoing one.
            if layer_idx > 0 {
                for spike in &spikes {
                    self.feedforward_weights[layer_idx - 1]
                        .on_post_spike(spike.neuron_id, self.time);
                }
            }

            if layer_idx + 1 < self.layers.len() {
                for spike in &spikes {
                    self.feedforward_weights[layer_idx].on_pre_spike(spike.neuron_id, self.time);
                }
            }
        }

        // Only the output layer's spikes are reported to the caller.
        all_spikes.last().cloned().unwrap_or_default()
    }

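    /// Steps the network until the output layer emits at least one spike and
    /// returns those spikes, or an empty vector if `max_steps` elapse in
    /// silence.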
    pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
        for _ in 0..max_steps {
            let spikes = self.step();
            if !spikes.is_empty() {
                return spikes;
            }
        }
        Vec::new()
    }

    /// Population firing rate of `layer` over the trailing `window`.
    pub fn layer_rate(&self, layer: usize, window: f64) -> f64 {
        self.layers
            .get(layer)
            .map(|l| l.population_rate(window))
            .unwrap_or(0.0)
    }

    /// Mean within-layer synchrony (10-unit window) across all layers.
    pub fn global_synchrony(&self) -> f64 {
        if self.layers.is_empty() {
            return 0.0;
        }
        let total: f64 = self.layers.iter().map(|l| l.synchrony(10.0)).sum();
        total / self.layers.len() as f64
    }

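    /// Pairwise spike-train synchrony for the first layer: the mean of
    /// `cross_correlation(.., 50.0, 5.0)` for each neuron pair, arranged as a
    /// symmetric matrix with a unit diagonal.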
    pub fn synchrony_matrix(&self) -> Vec<Vec<f64>> {
        // Guard against an empty network instead of panicking on index 0.
        let Some(layer) = self.layers.first() else {
            return Vec::new();
        };
        let n = layer.size();
        let mut matrix = vec![vec![0.0; n]; n];

        for i in 0..n {
            for j in (i + 1)..n {
                let corr =
                    layer.spike_trains[i].cross_correlation(&layer.spike_trains[j], 50.0, 5.0);
                let sync = corr.iter().sum::<f64>() / corr.len() as f64;
                matrix[i][j] = sync;
                matrix[j][i] = sync;
            }
            matrix[i][i] = 1.0;
        }

        matrix
    }

    /// Membrane potentials of the output (last) layer.
    pub fn get_output(&self) -> Vec<f64> {
        self.layers
            .last()
            .map(|l| l.neurons.iter().map(|n| n.membrane_potential()).collect())
            .unwrap_or_default()
    }

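    /// Broadcasts a scalar reward signal to every synapse matrix
    /// (feedforward and recurrent) for reward-modulated plasticity.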
    pub fn apply_reward(&mut self, reward: f64) {
        for weights in &mut self.feedforward_weights {
            weights.apply_reward(reward);
        }
        for weights in self.recurrent_weights.iter_mut().flatten() {
            weights.apply_reward(reward);
        }
    }

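    /// Returns encoded indices (`layer * 1000 + neuron`) of neurons whose
    /// `spike_rate(100.0)` falls below 0.001. Note that the encoding assumes
    /// layers of fewer than 1000 neurons.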
    pub fn low_activity_regions(&self) -> Vec<usize> {
        let mut low_activity = Vec::new();
        let threshold = 0.001;

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            for (neuron_idx, train) in layer.spike_trains.iter().enumerate() {
                if train.spike_rate(100.0) < threshold {
                    low_activity.push(layer_idx * 1000 + neuron_idx);
                }
            }
        }

        low_activity
    }

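    /// Writes the learned recurrent weights of the first layer back onto the
    /// corresponding edges of `graph`; the write-back counterpart of
    /// [`Self::from_graph`].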
    pub fn sync_to_graph(&self, graph: &mut DynamicGraph) {
        if let Some(recurrent) = self.recurrent_weights.first().and_then(|r| r.as_ref()) {
            let vertices: Vec<_> = graph.vertices();

            for ((pre, post), synapse) in recurrent.iter() {
                if *pre < vertices.len() && *post < vertices.len() {
                    let u = vertices[*pre];
                    let v = vertices[*post];
                    // Only update edges that already exist in the graph.
                    if graph.has_edge(u, v) {
                        let _ = graph.update_edge_weight(u, v, synapse.weight);
                    }
                }
            }
        }
    }
}

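// A minimal lock-free PRNG: a 64-bit linear congruential generator whose
// state is advanced with compare-and-swap, so concurrent callers never tear
// the shared state. The multiplier and increment are the well-known MMIX
// constants; this is adequate for weight initialization but is neither
// statistically strong nor cryptographic.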
use std::sync::atomic::{AtomicU64, Ordering};

static RNG_STATE: AtomicU64 = AtomicU64::new(0x853c49e6748fea9b);

fn rand_u64() -> u64 {
    loop {
        let current = RNG_STATE.load(Ordering::Relaxed);
        let next = current
            .wrapping_mul(0x5851f42d4c957f2d)
            .wrapping_add(0x14057b7ef767814f);
        match RNG_STATE.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => return next,
            Err(_) => continue,
        }
    }
}

/// Uniform random weight in [0.25, 0.75].
fn rand_weight() -> f64 {
    (rand_u64() as f64) / (u64::MAX as f64) * 0.5 + 0.25
}

/// Bernoulli trial with success probability `p`.
fn rand_bool(p: f64) -> bool {
    (rand_u64() as f64) / (u64::MAX as f64) < p
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_network_creation() {
        let config = NetworkConfig::default();
        let network = SpikingNetwork::new(config);

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.layer_size(0), 100);
        assert_eq!(network.layer_size(1), 50);
        assert_eq!(network.layer_size(2), 10);
    }

    #[test]
    fn test_network_step() {
        let config = NetworkConfig::default();
        let mut network = SpikingNetwork::new(config);

        let currents = vec![5.0; 100];
        network.inject_current(&currents);

        let mut total_spikes = 0;
        for _ in 0..100 {
            let spikes = network.step();
            total_spikes += spikes.len();
        }

        // Spike counts depend on the neuron parameters, so we only assert
        // that simulated time advanced.
        let _ = total_spikes;
        assert!(network.current_time() > 0.0);
    }

    #[test]
    fn test_graph_network() {
        use crate::graph::DynamicGraph;

        let mut graph = DynamicGraph::new();
        graph.insert_edge(0, 1, 1.0).unwrap();
        graph.insert_edge(1, 2, 1.0).unwrap();
        graph.insert_edge(2, 0, 1.0).unwrap();

        let config = NetworkConfig::default();
        let network = SpikingNetwork::from_graph(&graph, config);

        assert_eq!(network.num_layers(), 1);
        assert_eq!(network.layer_size(0), 3);
    }

    #[test]
    fn test_synchrony_matrix() {
        let mut config = NetworkConfig::default();
        config.layers = vec![LayerConfig::new(5)];

        let mut network = SpikingNetwork::new(config);

        let currents = vec![2.0; 5];
        for _ in 0..50 {
            network.inject_current(&currents);
            network.step();
        }

        let sync = network.synchrony_matrix();
        assert_eq!(sync.len(), 5);
        assert_eq!(sync[0].len(), 5);

        for i in 0..5 {
            assert!((sync[i][i] - 1.0).abs() < 0.001);
        }
    }
}