1use super::{
12 neuron::{LIFNeuron, NeuronConfig, NeuronPopulation, SpikeTrain},
13 synapse::{Synapse, SynapseMatrix, STDPConfig},
14 SimTime, Spike, Vector,
15};
16use crate::graph::DynamicGraph;
17use rayon::prelude::*;
18use std::collections::VecDeque;
19
20#[derive(Debug, Clone)]
22pub struct LayerConfig {
23 pub size: usize,
25 pub neuron_config: NeuronConfig,
27 pub recurrent: bool,
29}
30
31impl LayerConfig {
32 pub fn new(size: usize) -> Self {
34 Self {
35 size,
36 neuron_config: NeuronConfig::default(),
37 recurrent: false,
38 }
39 }
40
41 pub fn with_recurrence(mut self) -> Self {
43 self.recurrent = true;
44 self
45 }
46
47 pub fn with_neuron_config(mut self, config: NeuronConfig) -> Self {
49 self.neuron_config = config;
50 self
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct NetworkConfig {
57 pub layers: Vec<LayerConfig>,
59 pub stdp_config: STDPConfig,
61 pub dt: f64,
63 pub winner_take_all: bool,
65 pub wta_strength: f64,
67}
68
69impl Default for NetworkConfig {
70 fn default() -> Self {
71 Self {
72 layers: vec![
73 LayerConfig::new(100), LayerConfig::new(50), LayerConfig::new(10), ],
77 stdp_config: STDPConfig::default(),
78 dt: 1.0,
79 winner_take_all: false,
80 wta_strength: 0.8,
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub struct SpikingNetwork {
88 pub config: NetworkConfig,
90 layers: Vec<NeuronPopulation>,
92 feedforward_weights: Vec<SynapseMatrix>,
94 recurrent_weights: Vec<Option<SynapseMatrix>>,
96 time: SimTime,
98 spike_buffer: VecDeque<(Spike, usize, SimTime)>, global_inhibition: f64,
102}
103
104impl SpikingNetwork {
105 pub fn new(config: NetworkConfig) -> Self {
107 let mut layers = Vec::new();
108 let mut feedforward_weights = Vec::new();
109 let mut recurrent_weights = Vec::new();
110
111 for (i, layer_config) in config.layers.iter().enumerate() {
112 let population = NeuronPopulation::with_config(
114 layer_config.size,
115 layer_config.neuron_config.clone(),
116 );
117 layers.push(population);
118
119 if i + 1 < config.layers.len() {
121 let next_size = config.layers[i + 1].size;
122 let mut weights = SynapseMatrix::with_config(
123 layer_config.size,
124 next_size,
125 config.stdp_config.clone(),
126 );
127
128 for pre in 0..layer_config.size {
130 for post in 0..next_size {
131 let weight = rand_weight();
132 weights.add_synapse(pre, post, weight);
133 }
134 }
135
136 feedforward_weights.push(weights);
137 }
138
139 if layer_config.recurrent {
141 let mut weights = SynapseMatrix::with_config(
142 layer_config.size,
143 layer_config.size,
144 config.stdp_config.clone(),
145 );
146
147 for pre in 0..layer_config.size {
149 for post in 0..layer_config.size {
150 if pre != post && rand_bool(0.1) {
151 weights.add_synapse(pre, post, rand_weight() * 0.5);
152 }
153 }
154 }
155
156 recurrent_weights.push(Some(weights));
157 } else {
158 recurrent_weights.push(None);
159 }
160 }
161
162 Self {
163 config,
164 layers,
165 feedforward_weights,
166 recurrent_weights,
167 time: 0.0,
168 spike_buffer: VecDeque::new(),
169 global_inhibition: 0.0,
170 }
171 }
172
173 pub fn from_graph(graph: &DynamicGraph, config: NetworkConfig) -> Self {
175 let n = graph.num_vertices();
176
177 let mut network_config = config.clone();
179 network_config.layers = vec![LayerConfig::new(n).with_recurrence()];
180
181 let mut network = Self::new(network_config);
182
183 if let Some(ref mut recurrent) = network.recurrent_weights[0] {
185 let vertices: Vec<_> = graph.vertices();
186 let vertex_to_idx: std::collections::HashMap<_, _> = vertices
187 .iter()
188 .enumerate()
189 .map(|(i, &v)| (v, i))
190 .collect();
191
192 for edge in graph.edges() {
193 if let (Some(&pre), Some(&post)) = (
194 vertex_to_idx.get(&edge.source),
195 vertex_to_idx.get(&edge.target),
196 ) {
197 recurrent.set_weight(pre, post, edge.weight);
198 recurrent.set_weight(post, pre, edge.weight); }
200 }
201 }
202
203 network
204 }
205
206 pub fn reset(&mut self) {
208 self.time = 0.0;
209 self.spike_buffer.clear();
210 self.global_inhibition = 0.0;
211
212 for layer in &mut self.layers {
213 layer.reset();
214 }
215 }
216
217 pub fn num_layers(&self) -> usize {
219 self.layers.len()
220 }
221
222 pub fn layer_size(&self, layer: usize) -> usize {
224 self.layers.get(layer).map(|l| l.size()).unwrap_or(0)
225 }
226
227 pub fn current_time(&self) -> SimTime {
229 self.time
230 }
231
232 pub fn inject_current(&mut self, currents: &[f64]) {
234 if !self.layers.is_empty() {
235 let input_layer = &mut self.layers[0];
236 let n = currents.len().min(input_layer.size());
237
238 for (i, neuron) in input_layer.neurons.iter_mut().take(n).enumerate() {
239 neuron.set_membrane_potential(
240 neuron.membrane_potential() + currents[i] * 0.1
241 );
242 }
243 }
244 }
245
246 pub fn step(&mut self) -> Vec<Spike> {
249 let dt = self.config.dt;
250 self.time += dt;
251
252 let mut all_spikes: Vec<Vec<Spike>> = Vec::new();
254
255 for layer_idx in 0..self.layers.len() {
257 let mut currents = vec![0.0; self.layers[layer_idx].size()];
259
260 if layer_idx > 0 {
262 let weights = &self.feedforward_weights[layer_idx - 1];
263 let pre_activations: Vec<f64> = self.layers[layer_idx - 1]
265 .neurons
266 .iter()
267 .map(|n| n.membrane_potential().max(0.0))
268 .collect();
269 let ff_currents = weights.compute_weighted_sums(&pre_activations);
271 for (j, &c) in ff_currents.iter().enumerate() {
272 currents[j] += c;
273 }
274 }
275
276 if let Some(ref weights) = self.recurrent_weights[layer_idx] {
278 let activations: Vec<f64> = self.layers[layer_idx]
280 .neurons
281 .iter()
282 .map(|n| n.membrane_potential().max(0.0))
283 .collect();
284 let rec_currents = weights.compute_weighted_sums(&activations);
286 for (j, &c) in rec_currents.iter().enumerate() {
287 currents[j] += c;
288 }
289 }
290
291 if self.config.winner_take_all && layer_idx == self.layers.len() - 1 {
293 let max_v = self.layers[layer_idx]
294 .neurons
295 .iter()
296 .map(|n| n.membrane_potential())
297 .fold(f64::NEG_INFINITY, f64::max);
298
299 for (i, neuron) in self.layers[layer_idx].neurons.iter().enumerate() {
300 if neuron.membrane_potential() < max_v {
301 currents[i] -= self.config.wta_strength * self.global_inhibition;
302 }
303 }
304 }
305
306 let spikes = self.layers[layer_idx].step(¤ts, dt);
308 all_spikes.push(spikes.clone());
309
310 if !spikes.is_empty() {
312 self.global_inhibition = (self.global_inhibition + 0.1).min(1.0);
313 } else {
314 self.global_inhibition *= 0.95;
315 }
316
317 if layer_idx > 0 {
319 for spike in &spikes {
320 self.feedforward_weights[layer_idx - 1].on_post_spike(spike.neuron_id, self.time);
321 }
322 }
323
324 if layer_idx + 1 < self.layers.len() {
325 for spike in &spikes {
326 self.feedforward_weights[layer_idx].on_pre_spike(spike.neuron_id, self.time);
327 }
328 }
329 }
330
331 all_spikes.last().cloned().unwrap_or_default()
333 }
334
335 pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
337 for _ in 0..max_steps {
338 let spikes = self.step();
339 if !spikes.is_empty() {
340 return spikes;
341 }
342 }
343 Vec::new()
344 }
345
346 pub fn layer_rate(&self, layer: usize, window: f64) -> f64 {
348 self.layers
349 .get(layer)
350 .map(|l| l.population_rate(window))
351 .unwrap_or(0.0)
352 }
353
354 pub fn global_synchrony(&self) -> f64 {
356 let mut total_sync = 0.0;
357 let mut count = 0;
358
359 for layer in &self.layers {
360 total_sync += layer.synchrony(10.0);
361 count += 1;
362 }
363
364 if count > 0 {
365 total_sync / count as f64
366 } else {
367 0.0
368 }
369 }
370
371 pub fn synchrony_matrix(&self) -> Vec<Vec<f64>> {
373 let layer = &self.layers[0];
375 let n = layer.size();
376 let mut matrix = vec![vec![0.0; n]; n];
377
378 for i in 0..n {
379 for j in (i + 1)..n {
380 let corr = layer.spike_trains[i].cross_correlation(
381 &layer.spike_trains[j],
382 50.0,
383 5.0,
384 );
385 let sync = corr.iter().sum::<f64>() / corr.len() as f64;
386 matrix[i][j] = sync;
387 matrix[j][i] = sync;
388 }
389 matrix[i][i] = 1.0;
390 }
391
392 matrix
393 }
394
395 pub fn get_output(&self) -> Vec<f64> {
397 self.layers
398 .last()
399 .map(|l| l.neurons.iter().map(|n| n.membrane_potential()).collect())
400 .unwrap_or_default()
401 }
402
403 pub fn apply_reward(&mut self, reward: f64) {
405 for weights in &mut self.feedforward_weights {
406 weights.apply_reward(reward);
407 }
408 for weights in &mut self.recurrent_weights {
409 if let Some(w) = weights {
410 w.apply_reward(reward);
411 }
412 }
413 }
414
415 pub fn low_activity_regions(&self) -> Vec<usize> {
417 let mut low_activity = Vec::new();
418 let threshold = 0.001;
419
420 for (layer_idx, layer) in self.layers.iter().enumerate() {
421 for (neuron_idx, train) in layer.spike_trains.iter().enumerate() {
422 if train.spike_rate(100.0) < threshold {
423 low_activity.push(layer_idx * 1000 + neuron_idx);
424 }
425 }
426 }
427
428 low_activity
429 }
430
431 pub fn sync_to_graph(&self, graph: &mut DynamicGraph) {
433 if let Some(ref recurrent) = self.recurrent_weights.first().and_then(|r| r.as_ref()) {
434 let vertices: Vec<_> = graph.vertices();
435
436 for ((pre, post), synapse) in recurrent.iter() {
437 if *pre < vertices.len() && *post < vertices.len() {
438 let u = vertices[*pre];
439 let v = vertices[*post];
440 if graph.has_edge(u, v) {
441 let _ = graph.update_edge_weight(u, v, synapse.weight);
442 }
443 }
444 }
445 }
446 }
447}
448
use std::sync::atomic::{AtomicU64, Ordering};

/// Shared state of the file-local pseudo-random generator (a 64-bit LCG).
static RNG_STATE: AtomicU64 = AtomicU64::new(0x853c49e6748fea9b);

/// Advances the shared LCG state by one step and returns the new state.
///
/// The state transition is applied with an atomic compare-exchange retry, so every call
/// consumes exactly one step of the sequence even when called concurrently.
fn rand_u64() -> u64 {
    const MULTIPLIER: u64 = 0x5851f42d4c957f2d;
    const INCREMENT: u64 = 0x14057b7ef767814f;
    let advance = |state: u64| state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);

    let previous = RNG_STATE
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
            Some(advance(state))
        })
        .expect("update closure always returns Some");
    advance(previous)
}

/// Sample in `[0, 1]` computed as `rand_u64() / u64::MAX` in `f64`.
/// The upper bound 1.0 is reachable through rounding.
fn rand_unit() -> f64 {
    (rand_u64() as f64) / (u64::MAX as f64)
}

/// Initial synaptic weight in `[0.25, 0.75]`.
fn rand_weight() -> f64 {
    0.25 + rand_unit() * 0.5
}

/// Returns `true` with probability roughly `p`.
fn rand_bool(p: f64) -> bool {
    rand_unit() < p
}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration yields the 100 → 50 → 10 topology.
    #[test]
    fn test_network_creation() {
        let config = NetworkConfig::default();
        let network = SpikingNetwork::new(config);

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.layer_size(0), 100);
        assert_eq!(network.layer_size(1), 50);
        assert_eq!(network.layer_size(2), 10);
    }

    /// Stepping after an input pulse runs cleanly and advances the clock by `dt` per step.
    #[test]
    fn test_network_step() {
        let config = NetworkConfig::default();
        let mut network = SpikingNetwork::new(config);

        let currents = vec![5.0; 100];
        network.inject_current(&currents);

        let steps = 100;
        for _ in 0..steps {
            network.step();
        }

        let expected_time = steps as f64 * network.config.dt;
        assert!(network.current_time() > 0.0);
        assert!((network.current_time() - expected_time).abs() < 1e-9);
    }

    /// A graph-derived network has a single layer with one neuron per vertex.
    #[test]
    fn test_graph_network() {
        use crate::graph::DynamicGraph;

        let graph = DynamicGraph::new();
        graph.insert_edge(0, 1, 1.0).unwrap();
        graph.insert_edge(1, 2, 1.0).unwrap();
        graph.insert_edge(2, 0, 1.0).unwrap();

        let config = NetworkConfig::default();
        let network = SpikingNetwork::from_graph(&graph, config);

        assert_eq!(network.num_layers(), 1);
        assert_eq!(network.layer_size(0), 3);
    }

    /// The synchrony matrix is square, sized to the input layer, with a unit diagonal.
    #[test]
    fn test_synchrony_matrix() {
        let mut config = NetworkConfig::default();
        config.layers = vec![LayerConfig::new(5)];

        let mut network = SpikingNetwork::new(config);

        let currents = vec![2.0; 5];
        for _ in 0..50 {
            network.inject_current(&currents);
            network.step();
        }

        let sync = network.synchrony_matrix();
        assert_eq!(sync.len(), 5);
        assert_eq!(sync[0].len(), 5);

        for (i, row) in sync.iter().enumerate() {
            assert!((row[i] - 1.0).abs() < 0.001);
        }
    }
}