1use super::{SimTime, Spike};
21use crate::graph::{DynamicGraph, VertexId, Weight};
22use std::collections::HashMap;
23
/// Parameters of the pair-based STDP rule applied by `Synapse::stdp_update`
/// and the reward-modulated update.
///
/// Time constants are expressed in the same unit as the spike times passed to
/// the update functions.
#[derive(Clone, Debug)]
pub struct STDPConfig {
    /// Potentiation amplitude, used when the post spike strictly follows the pre spike.
    pub a_plus: f64,
    /// Depression amplitude, used for every other pairing (including simultaneous spikes).
    pub a_minus: f64,
    /// Exponential decay constant of the potentiation side of the window.
    pub tau_plus: f64,
    /// Exponential decay constant of the depression side of the window.
    pub tau_minus: f64,
    /// Lower clamp bound for synaptic weights.
    pub w_min: f64,
    /// Upper clamp bound for synaptic weights.
    pub w_max: f64,
    /// Global multiplier applied to every STDP and reward-modulated weight change.
    pub learning_rate: f64,
    /// Decay constant used when a `SynapseMatrix` decays eligibility traces.
    pub tau_eligibility: f64,
}

impl Default for STDPConfig {
    /// Slightly depression-dominated defaults (`a_minus > a_plus`) with equal
    /// 20-unit windows, weights bounded to `[0, 1]`, and a unit learning rate.
    fn default() -> Self {
        let (tau_plus, tau_minus) = (20.0, 20.0);
        let (w_min, w_max) = (0.0, 1.0);
        Self {
            a_plus: 0.01,
            a_minus: 0.012,
            tau_plus,
            tau_minus,
            w_min,
            w_max,
            learning_rate: 1.0,
            tau_eligibility: 1000.0,
        }
    }
}
59
60#[derive(Debug, Clone)]
62pub struct Synapse {
63 pub pre: usize,
65 pub post: usize,
67 pub weight: f64,
69 pub delay: f64,
71 pub eligibility: f64,
73 pub last_update: SimTime,
75}
76
77impl Synapse {
78 pub fn new(pre: usize, post: usize, weight: f64) -> Self {
80 Self {
81 pre,
82 post,
83 weight,
84 delay: 1.0,
85 eligibility: 0.0,
86 last_update: 0.0,
87 }
88 }
89
90 pub fn with_delay(pre: usize, post: usize, weight: f64, delay: f64) -> Self {
92 Self {
93 pre,
94 post,
95 weight,
96 delay,
97 eligibility: 0.0,
98 last_update: 0.0,
99 }
100 }
101
102 pub fn stdp_update(
104 &mut self,
105 t_pre: SimTime,
106 t_post: SimTime,
107 config: &STDPConfig,
108 ) -> f64 {
109 let dt = t_post - t_pre;
110
111 let dw = if dt > 0.0 {
112 config.a_plus * (-dt / config.tau_plus).exp()
114 } else {
115 -config.a_minus * (dt / config.tau_minus).exp()
117 };
118
119 let delta = config.learning_rate * dw;
121 self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
122
123 self.eligibility += dw;
125
126 delta
127 }
128
129 pub fn decay_eligibility(&mut self, dt: f64, tau: f64) {
131 self.eligibility *= (-dt / tau).exp();
132 }
133
134 pub fn reward_modulated_update(&mut self, reward: f64, config: &STDPConfig) {
136 let delta = reward * self.eligibility * config.learning_rate;
137 self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
138 self.eligibility *= 0.5;
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct SynapseMatrix {
146 pub n_pre: usize,
148 pub n_post: usize,
150 synapses: HashMap<(usize, usize), Synapse>,
152 pub config: STDPConfig,
154 pre_spike_times: Vec<SimTime>,
156 post_spike_times: Vec<SimTime>,
158}
159
160impl SynapseMatrix {
161 pub fn new(n_pre: usize, n_post: usize) -> Self {
163 Self {
164 n_pre,
165 n_post,
166 synapses: HashMap::new(),
167 config: STDPConfig::default(),
168 pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
169 post_spike_times: vec![f64::NEG_INFINITY; n_post],
170 }
171 }
172
173 pub fn with_config(n_pre: usize, n_post: usize, config: STDPConfig) -> Self {
175 Self {
176 n_pre,
177 n_post,
178 synapses: HashMap::new(),
179 config,
180 pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
181 post_spike_times: vec![f64::NEG_INFINITY; n_post],
182 }
183 }
184
185 pub fn add_synapse(&mut self, pre: usize, post: usize, weight: f64) {
187 if pre < self.n_pre && post < self.n_post {
188 self.synapses.insert((pre, post), Synapse::new(pre, post, weight));
189 }
190 }
191
192 pub fn get_synapse(&self, pre: usize, post: usize) -> Option<&Synapse> {
194 self.synapses.get(&(pre, post))
195 }
196
197 pub fn get_synapse_mut(&mut self, pre: usize, post: usize) -> Option<&mut Synapse> {
199 self.synapses.get_mut(&(pre, post))
200 }
201
202 pub fn weight(&self, pre: usize, post: usize) -> f64 {
204 self.get_synapse(pre, post).map(|s| s.weight).unwrap_or(0.0)
205 }
206
207 #[inline]
213 pub fn compute_weighted_sums(&self, pre_activations: &[f64]) -> Vec<f64> {
214 let mut sums = vec![0.0; self.n_post];
215
216 for (&(pre, post), synapse) in &self.synapses {
218 if pre < pre_activations.len() {
219 sums[post] += synapse.weight * pre_activations[pre];
220 }
221 }
222
223 sums
224 }
225
226 #[inline]
228 pub fn weighted_sum_for_post(&self, post: usize, pre_activations: &[f64]) -> f64 {
229 let mut sum = 0.0;
230 for pre in 0..self.n_pre.min(pre_activations.len()) {
231 if let Some(synapse) = self.synapses.get(&(pre, post)) {
232 sum += synapse.weight * pre_activations[pre];
233 }
234 }
235 sum
236 }
237
238 pub fn set_weight(&mut self, pre: usize, post: usize, weight: f64) {
240 if let Some(synapse) = self.get_synapse_mut(pre, post) {
241 synapse.weight = weight;
242 } else {
243 self.add_synapse(pre, post, weight);
244 }
245 }
246
247 pub fn on_pre_spike(&mut self, pre: usize, time: SimTime) {
249 if pre >= self.n_pre {
250 return;
251 }
252
253 self.pre_spike_times[pre] = time;
254
255 for post in 0..self.n_post {
257 if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
258 let t_post = self.post_spike_times[post];
259 if t_post > f64::NEG_INFINITY {
260 synapse.stdp_update(time, t_post, &self.config);
261 }
262 }
263 }
264 }
265
266 pub fn on_post_spike(&mut self, post: usize, time: SimTime) {
268 if post >= self.n_post {
269 return;
270 }
271
272 self.post_spike_times[post] = time;
273
274 for pre in 0..self.n_pre {
276 if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
277 let t_pre = self.pre_spike_times[pre];
278 if t_pre > f64::NEG_INFINITY {
279 synapse.stdp_update(t_pre, time, &self.config);
280 }
281 }
282 }
283 }
284
285 pub fn process_spikes(&mut self, spikes: &[Spike]) {
287 for spike in spikes {
288 if spike.neuron_id < self.n_pre {
291 self.on_pre_spike(spike.neuron_id, spike.time);
292 }
293 if spike.neuron_id < self.n_post {
294 self.on_post_spike(spike.neuron_id, spike.time);
295 }
296 }
297 }
298
299 pub fn decay_eligibility(&mut self, dt: f64) {
301 for synapse in self.synapses.values_mut() {
302 synapse.decay_eligibility(dt, self.config.tau_eligibility);
303 }
304 }
305
306 pub fn apply_reward(&mut self, reward: f64) {
308 for synapse in self.synapses.values_mut() {
309 synapse.reward_modulated_update(reward, &self.config);
310 }
311 }
312
313 pub fn iter(&self) -> impl Iterator<Item = (&(usize, usize), &Synapse)> {
315 self.synapses.iter()
316 }
317
318 pub fn num_synapses(&self) -> usize {
320 self.synapses.len()
321 }
322
323 pub fn input_to(&self, post: usize, pre_activities: &[f64]) -> f64 {
325 let mut total = 0.0;
326 for pre in 0..self.n_pre.min(pre_activities.len()) {
327 total += self.weight(pre, post) * pre_activities[pre];
328 }
329 total
330 }
331
332 pub fn to_dense(&self) -> Vec<Vec<f64>> {
334 let mut matrix = vec![vec![0.0; self.n_post]; self.n_pre];
335 for ((pre, post), synapse) in &self.synapses {
336 matrix[*pre][*post] = synapse.weight;
337 }
338 matrix
339 }
340
341 pub fn from_dense(matrix: &[Vec<f64>]) -> Self {
343 let n_pre = matrix.len();
344 let n_post = matrix.first().map(|r| r.len()).unwrap_or(0);
345
346 let mut sm = Self::new(n_pre, n_post);
347
348 for (pre, row) in matrix.iter().enumerate() {
349 for (post, &weight) in row.iter().enumerate() {
350 if weight != 0.0 {
351 sm.add_synapse(pre, post, weight);
352 }
353 }
354 }
355
356 sm
357 }
358
359 pub fn sync_to_graph<F>(&self, graph: &mut DynamicGraph, neuron_to_vertex: F)
362 where
363 F: Fn(usize) -> VertexId,
364 {
365 for ((pre, post), synapse) in &self.synapses {
366 let u = neuron_to_vertex(*pre);
367 let v = neuron_to_vertex(*post);
368
369 if graph.has_edge(u, v) {
370 let _ = graph.update_edge_weight(u, v, synapse.weight);
371 }
372 }
373 }
374
375 pub fn sync_from_graph<F>(&mut self, graph: &DynamicGraph, vertex_to_neuron: F)
377 where
378 F: Fn(VertexId) -> usize,
379 {
380 for edge in graph.edges() {
381 let pre = vertex_to_neuron(edge.source);
382 let post = vertex_to_neuron(edge.target);
383
384 if pre < self.n_pre && post < self.n_post {
385 self.set_weight(pre, post, edge.weight);
386 }
387 }
388 }
389
390 pub fn high_correlation_pairs(&self, threshold: f64) -> Vec<(usize, usize)> {
392 self.synapses
393 .iter()
394 .filter(|(_, s)| s.weight >= threshold)
395 .map(|((pre, post), _)| (*pre, *post))
396 .collect()
397 }
398}
399
/// Asymmetric STDP window with separate amplitudes and decay constants for
/// causal (`dt > 0`) and anti-causal (`dt <= 0`) spike pairings.
#[derive(Clone, Debug)]
pub struct AsymmetricSTDP {
    /// Decay constant of the potentiation (causal) side of the window.
    pub tau_forward: f64,
    /// Decay constant of the depression (anti-causal) side of the window.
    pub tau_backward: f64,
    /// Potentiation amplitude for causal pairings.
    pub a_forward: f64,
    /// Depression amplitude for anti-causal pairings.
    pub a_backward: f64,
}

impl Default for AsymmetricSTDP {
    /// A narrow, strong potentiation side (`tau = 15`, `a = 0.015`) paired
    /// with a wider, weaker depression side (`tau = 30`, `a = 0.008`).
    fn default() -> Self {
        AsymmetricSTDP {
            a_forward: 0.015,
            a_backward: 0.008,
            tau_forward: 15.0,
            tau_backward: 30.0,
        }
    }
}
423
424impl AsymmetricSTDP {
425 pub fn compute_dw(&self, dt: f64) -> f64 {
428 if dt > 0.0 {
429 self.a_forward * (-dt / self.tau_forward).exp()
431 } else {
432 -self.a_backward * (dt / self.tau_backward).exp()
434 }
435 }
436
437 pub fn update_weights(
439 &self,
440 matrix: &mut SynapseMatrix,
441 neuron_id: usize,
442 time: SimTime,
443 ) {
444 let w_min = matrix.config.w_min;
445 let w_max = matrix.config.w_max;
446 let n_pre = matrix.n_pre;
447 let n_post = matrix.n_post;
448
449 let pre_times: Vec<_> = (0..n_pre)
451 .map(|pre| matrix.pre_spike_times.get(pre).copied().unwrap_or(f64::NEG_INFINITY))
452 .collect();
453
454 for pre in 0..n_pre {
456 let t_pre = pre_times[pre];
457 if t_pre > f64::NEG_INFINITY {
458 let dt = time - t_pre;
459 let dw = self.compute_dw(dt);
460 if let Some(synapse) = matrix.get_synapse_mut(pre, neuron_id) {
461 synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
462 }
463 }
464 }
465
466 let post_times: Vec<_> = (0..n_post)
468 .map(|post| matrix.post_spike_times.get(post).copied().unwrap_or(f64::NEG_INFINITY))
469 .collect();
470
471 for post in 0..n_post {
472 let t_post = post_times[post];
473 if t_post > f64::NEG_INFINITY {
474 let dt = t_post - time; let dw = self.compute_dw(dt);
476 if let Some(synapse) = matrix.get_synapse_mut(neuron_id, post) {
477 synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
478 }
479 }
480 }
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing stored weights against literals.
    const EPS: f64 = 0.001;

    /// True when `actual` is within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_synapse_creation() {
        let s = Synapse::new(0, 1, 0.5);
        assert_eq!((s.pre, s.post), (0, 1));
        assert_eq!(s.weight, 0.5);
    }

    #[test]
    fn test_stdp_ltp() {
        // Pre at t=10 precedes post at t=15 (dt = +5): potentiation.
        let cfg = STDPConfig::default();
        let mut s = Synapse::new(0, 1, 0.5);

        let change = s.stdp_update(10.0, 15.0, &cfg);
        assert!(change > 0.0);
        assert!(s.weight > 0.5);
    }

    #[test]
    fn test_stdp_ltd() {
        // Post at t=10 precedes pre at t=15 (dt = -5): depression.
        let cfg = STDPConfig::default();
        let mut s = Synapse::new(0, 1, 0.5);

        let change = s.stdp_update(15.0, 10.0, &cfg);
        assert!(change < 0.0);
        assert!(s.weight < 0.5);
    }

    #[test]
    fn test_synapse_matrix() {
        let mut m = SynapseMatrix::new(10, 10);
        for &(pre, post, w) in &[(0, 1, 0.5), (1, 2, 0.3)] {
            m.add_synapse(pre, post, w);
        }

        assert_eq!(m.num_synapses(), 2);
        assert!(close(m.weight(0, 1), 0.5));
        assert!(close(m.weight(1, 2), 0.3));
        assert_eq!(m.weight(2, 3), 0.0);
    }

    #[test]
    fn test_spike_processing() {
        // Fully connected 5-neuron network without self-connections.
        let mut m = SynapseMatrix::new(5, 5);
        for i in 0..5 {
            for j in (0..5).filter(|&j| j != i) {
                m.add_synapse(i, j, 0.5);
            }
        }

        // Pre fires at t=10 and post at t=15, so 0 -> 1 must strengthen.
        m.on_pre_spike(0, 10.0);
        m.on_post_spike(1, 15.0);

        assert!(m.weight(0, 1) > 0.5);
    }

    #[test]
    fn test_asymmetric_stdp() {
        let stdp = AsymmetricSTDP::default();
        let (causal, anticausal) = (stdp.compute_dw(5.0), stdp.compute_dw(-5.0));

        assert!(causal > 0.0);
        assert!(anticausal < 0.0);
        assert!(causal.abs() > anticausal.abs());
    }

    #[test]
    fn test_dense_conversion() {
        let mut m = SynapseMatrix::new(3, 3);
        m.add_synapse(0, 1, 0.5);
        m.add_synapse(1, 2, 0.7);

        let dense = m.to_dense();
        assert_eq!(dense.len(), 3);
        assert!(close(dense[0][1], 0.5));
        assert!(close(dense[1][2], 0.7));

        assert_eq!(SynapseMatrix::from_dense(&dense).num_synapses(), 2);
    }
}