1use super::{SimTime, Spike};
21use crate::graph::{DynamicGraph, VertexId, Weight};
22use std::collections::HashMap;
23
/// Parameters of the pair-based STDP rule and of reward-modulated learning.
#[derive(Debug, Clone)]
pub struct STDPConfig {
    /// Peak potentiation for causal pairs (pre spike strictly before post spike).
    pub a_plus: f64,
    /// Peak depression magnitude for anti-causal pairs (post before, or coincident with, pre).
    pub a_minus: f64,
    /// Decay time constant of the potentiation window, in spike-time units.
    pub tau_plus: f64,
    /// Decay time constant of the depression window, in spike-time units.
    pub tau_minus: f64,
    /// Lower bound that weights are clamped to after every update.
    pub w_min: f64,
    /// Upper bound that weights are clamped to after every update.
    pub w_max: f64,
    /// Multiplier applied to STDP and reward-driven weight changes.
    pub learning_rate: f64,
    /// Exponential decay time constant of the per-synapse eligibility trace.
    pub tau_eligibility: f64,
}

impl Default for STDPConfig {
    /// Depression slightly stronger than potentiation (`0.012` vs `0.01`), equal 20-unit
    /// windows on both sides, weights bounded to `[0, 1]`, unit learning rate and a
    /// 1000-unit eligibility time constant.
    fn default() -> Self {
        let symmetric_window = 20.0;
        STDPConfig {
            tau_eligibility: 1000.0,
            learning_rate: 1.0,
            w_max: 1.0,
            w_min: 0.0,
            tau_minus: symmetric_window,
            tau_plus: symmetric_window,
            a_minus: 0.012,
            a_plus: 0.01,
        }
    }
}
59
60#[derive(Debug, Clone)]
62pub struct Synapse {
63 pub pre: usize,
65 pub post: usize,
67 pub weight: f64,
69 pub delay: f64,
71 pub eligibility: f64,
73 pub last_update: SimTime,
75}
76
77impl Synapse {
78 pub fn new(pre: usize, post: usize, weight: f64) -> Self {
80 Self {
81 pre,
82 post,
83 weight,
84 delay: 1.0,
85 eligibility: 0.0,
86 last_update: 0.0,
87 }
88 }
89
90 pub fn with_delay(pre: usize, post: usize, weight: f64, delay: f64) -> Self {
92 Self {
93 pre,
94 post,
95 weight,
96 delay,
97 eligibility: 0.0,
98 last_update: 0.0,
99 }
100 }
101
102 pub fn stdp_update(&mut self, t_pre: SimTime, t_post: SimTime, config: &STDPConfig) -> f64 {
104 let dt = t_post - t_pre;
105
106 let dw = if dt > 0.0 {
107 config.a_plus * (-dt / config.tau_plus).exp()
109 } else {
110 -config.a_minus * (dt / config.tau_minus).exp()
112 };
113
114 let delta = config.learning_rate * dw;
116 self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
117
118 self.eligibility += dw;
120
121 delta
122 }
123
124 pub fn decay_eligibility(&mut self, dt: f64, tau: f64) {
126 self.eligibility *= (-dt / tau).exp();
127 }
128
129 pub fn reward_modulated_update(&mut self, reward: f64, config: &STDPConfig) {
131 let delta = reward * self.eligibility * config.learning_rate;
132 self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
133 self.eligibility *= 0.5;
135 }
136}
137
138#[derive(Debug, Clone)]
140pub struct SynapseMatrix {
141 pub n_pre: usize,
143 pub n_post: usize,
145 synapses: HashMap<(usize, usize), Synapse>,
147 pub config: STDPConfig,
149 pre_spike_times: Vec<SimTime>,
151 post_spike_times: Vec<SimTime>,
153}
154
155impl SynapseMatrix {
156 pub fn new(n_pre: usize, n_post: usize) -> Self {
158 Self {
159 n_pre,
160 n_post,
161 synapses: HashMap::new(),
162 config: STDPConfig::default(),
163 pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
164 post_spike_times: vec![f64::NEG_INFINITY; n_post],
165 }
166 }
167
168 pub fn with_config(n_pre: usize, n_post: usize, config: STDPConfig) -> Self {
170 Self {
171 n_pre,
172 n_post,
173 synapses: HashMap::new(),
174 config,
175 pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
176 post_spike_times: vec![f64::NEG_INFINITY; n_post],
177 }
178 }
179
180 pub fn add_synapse(&mut self, pre: usize, post: usize, weight: f64) {
182 if pre < self.n_pre && post < self.n_post {
183 self.synapses
184 .insert((pre, post), Synapse::new(pre, post, weight));
185 }
186 }
187
188 pub fn get_synapse(&self, pre: usize, post: usize) -> Option<&Synapse> {
190 self.synapses.get(&(pre, post))
191 }
192
193 pub fn get_synapse_mut(&mut self, pre: usize, post: usize) -> Option<&mut Synapse> {
195 self.synapses.get_mut(&(pre, post))
196 }
197
198 pub fn weight(&self, pre: usize, post: usize) -> f64 {
200 self.get_synapse(pre, post).map(|s| s.weight).unwrap_or(0.0)
201 }
202
203 #[inline]
209 pub fn compute_weighted_sums(&self, pre_activations: &[f64]) -> Vec<f64> {
210 let mut sums = vec![0.0; self.n_post];
211
212 for (&(pre, post), synapse) in &self.synapses {
214 if pre < pre_activations.len() {
215 sums[post] += synapse.weight * pre_activations[pre];
216 }
217 }
218
219 sums
220 }
221
222 #[inline]
224 pub fn weighted_sum_for_post(&self, post: usize, pre_activations: &[f64]) -> f64 {
225 let mut sum = 0.0;
226 for pre in 0..self.n_pre.min(pre_activations.len()) {
227 if let Some(synapse) = self.synapses.get(&(pre, post)) {
228 sum += synapse.weight * pre_activations[pre];
229 }
230 }
231 sum
232 }
233
234 pub fn set_weight(&mut self, pre: usize, post: usize, weight: f64) {
236 if let Some(synapse) = self.get_synapse_mut(pre, post) {
237 synapse.weight = weight;
238 } else {
239 self.add_synapse(pre, post, weight);
240 }
241 }
242
243 pub fn on_pre_spike(&mut self, pre: usize, time: SimTime) {
245 if pre >= self.n_pre {
246 return;
247 }
248
249 self.pre_spike_times[pre] = time;
250
251 for post in 0..self.n_post {
253 if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
254 let t_post = self.post_spike_times[post];
255 if t_post > f64::NEG_INFINITY {
256 synapse.stdp_update(time, t_post, &self.config);
257 }
258 }
259 }
260 }
261
262 pub fn on_post_spike(&mut self, post: usize, time: SimTime) {
264 if post >= self.n_post {
265 return;
266 }
267
268 self.post_spike_times[post] = time;
269
270 for pre in 0..self.n_pre {
272 if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
273 let t_pre = self.pre_spike_times[pre];
274 if t_pre > f64::NEG_INFINITY {
275 synapse.stdp_update(t_pre, time, &self.config);
276 }
277 }
278 }
279 }
280
281 pub fn process_spikes(&mut self, spikes: &[Spike]) {
283 for spike in spikes {
284 if spike.neuron_id < self.n_pre {
287 self.on_pre_spike(spike.neuron_id, spike.time);
288 }
289 if spike.neuron_id < self.n_post {
290 self.on_post_spike(spike.neuron_id, spike.time);
291 }
292 }
293 }
294
295 pub fn decay_eligibility(&mut self, dt: f64) {
297 for synapse in self.synapses.values_mut() {
298 synapse.decay_eligibility(dt, self.config.tau_eligibility);
299 }
300 }
301
302 pub fn apply_reward(&mut self, reward: f64) {
304 for synapse in self.synapses.values_mut() {
305 synapse.reward_modulated_update(reward, &self.config);
306 }
307 }
308
309 pub fn iter(&self) -> impl Iterator<Item = (&(usize, usize), &Synapse)> {
311 self.synapses.iter()
312 }
313
314 pub fn num_synapses(&self) -> usize {
316 self.synapses.len()
317 }
318
319 pub fn input_to(&self, post: usize, pre_activities: &[f64]) -> f64 {
321 let mut total = 0.0;
322 for pre in 0..self.n_pre.min(pre_activities.len()) {
323 total += self.weight(pre, post) * pre_activities[pre];
324 }
325 total
326 }
327
328 pub fn to_dense(&self) -> Vec<Vec<f64>> {
330 let mut matrix = vec![vec![0.0; self.n_post]; self.n_pre];
331 for ((pre, post), synapse) in &self.synapses {
332 matrix[*pre][*post] = synapse.weight;
333 }
334 matrix
335 }
336
337 pub fn from_dense(matrix: &[Vec<f64>]) -> Self {
339 let n_pre = matrix.len();
340 let n_post = matrix.first().map(|r| r.len()).unwrap_or(0);
341
342 let mut sm = Self::new(n_pre, n_post);
343
344 for (pre, row) in matrix.iter().enumerate() {
345 for (post, &weight) in row.iter().enumerate() {
346 if weight != 0.0 {
347 sm.add_synapse(pre, post, weight);
348 }
349 }
350 }
351
352 sm
353 }
354
355 pub fn sync_to_graph<F>(&self, graph: &mut DynamicGraph, neuron_to_vertex: F)
358 where
359 F: Fn(usize) -> VertexId,
360 {
361 for ((pre, post), synapse) in &self.synapses {
362 let u = neuron_to_vertex(*pre);
363 let v = neuron_to_vertex(*post);
364
365 if graph.has_edge(u, v) {
366 let _ = graph.update_edge_weight(u, v, synapse.weight);
367 }
368 }
369 }
370
371 pub fn sync_from_graph<F>(&mut self, graph: &DynamicGraph, vertex_to_neuron: F)
373 where
374 F: Fn(VertexId) -> usize,
375 {
376 for edge in graph.edges() {
377 let pre = vertex_to_neuron(edge.source);
378 let post = vertex_to_neuron(edge.target);
379
380 if pre < self.n_pre && post < self.n_post {
381 self.set_weight(pre, post, edge.weight);
382 }
383 }
384 }
385
386 pub fn high_correlation_pairs(&self, threshold: f64) -> Vec<(usize, usize)> {
388 self.synapses
389 .iter()
390 .filter(|(_, s)| s.weight >= threshold)
391 .map(|((pre, post), _)| (*pre, *post))
392 .collect()
393 }
394}
395
/// Asymmetric STDP window with separate causal (forward) and anti-causal (backward)
/// exponential kernels.
#[derive(Debug, Clone)]
pub struct AsymmetricSTDP {
    /// Decay time constant of the causal kernel (`dt > 0`).
    pub tau_forward: f64,
    /// Decay time constant of the anti-causal kernel (`dt <= 0`).
    pub tau_backward: f64,
    /// Peak potentiation for causal pairs.
    pub a_forward: f64,
    /// Peak depression magnitude for anti-causal pairs.
    pub a_backward: f64,
}

impl Default for AsymmetricSTDP {
    /// The causal kernel is short and strong (τ = 15, A = 0.015). The anti-causal kernel
    /// is long and weak (τ = 30, A = 0.008).
    fn default() -> Self {
        AsymmetricSTDP {
            a_backward: 0.008,
            a_forward: 0.015,
            tau_backward: 30.0,
            tau_forward: 15.0,
        }
    }
}
419
420impl AsymmetricSTDP {
421 pub fn compute_dw(&self, dt: f64) -> f64 {
424 if dt > 0.0 {
425 self.a_forward * (-dt / self.tau_forward).exp()
427 } else {
428 -self.a_backward * (dt / self.tau_backward).exp()
430 }
431 }
432
433 pub fn update_weights(&self, matrix: &mut SynapseMatrix, neuron_id: usize, time: SimTime) {
435 let w_min = matrix.config.w_min;
436 let w_max = matrix.config.w_max;
437 let n_pre = matrix.n_pre;
438 let n_post = matrix.n_post;
439
440 let pre_times: Vec<_> = (0..n_pre)
442 .map(|pre| {
443 matrix
444 .pre_spike_times
445 .get(pre)
446 .copied()
447 .unwrap_or(f64::NEG_INFINITY)
448 })
449 .collect();
450
451 for pre in 0..n_pre {
453 let t_pre = pre_times[pre];
454 if t_pre > f64::NEG_INFINITY {
455 let dt = time - t_pre;
456 let dw = self.compute_dw(dt);
457 if let Some(synapse) = matrix.get_synapse_mut(pre, neuron_id) {
458 synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
459 }
460 }
461 }
462
463 let post_times: Vec<_> = (0..n_post)
465 .map(|post| {
466 matrix
467 .post_spike_times
468 .get(post)
469 .copied()
470 .unwrap_or(f64::NEG_INFINITY)
471 })
472 .collect();
473
474 for post in 0..n_post {
475 let t_post = post_times[post];
476 if t_post > f64::NEG_INFINITY {
477 let dt = t_post - time; let dw = self.compute_dw(dt);
479 if let Some(synapse) = matrix.get_synapse_mut(neuron_id, post) {
480 synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
481 }
482 }
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance when comparing stored weights with the values that were written.
    const TOL: f64 = 0.001;

    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_synapse_creation() {
        let syn = Synapse::new(0, 1, 0.5);
        assert_eq!((syn.pre, syn.post), (0, 1));
        assert_eq!(syn.weight, 0.5);
    }

    #[test]
    fn test_stdp_ltp() {
        // Pre fires before post (dt = +5): potentiation.
        let cfg = STDPConfig::default();
        let mut syn = Synapse::new(0, 1, 0.5);
        let change = syn.stdp_update(10.0, 15.0, &cfg);
        assert!(change > 0.0);
        assert!(syn.weight > 0.5);
    }

    #[test]
    fn test_stdp_ltd() {
        // Post fires before pre (dt = -5): depression.
        let cfg = STDPConfig::default();
        let mut syn = Synapse::new(0, 1, 0.5);
        let change = syn.stdp_update(15.0, 10.0, &cfg);
        assert!(change < 0.0);
        assert!(syn.weight < 0.5);
    }

    #[test]
    fn test_synapse_matrix() {
        let mut m = SynapseMatrix::new(10, 10);
        for &(pre, post, w) in &[(0, 1, 0.5), (1, 2, 0.3)] {
            m.add_synapse(pre, post, w);
        }

        assert_eq!(m.num_synapses(), 2);
        assert!(approx(m.weight(0, 1), 0.5));
        assert!(approx(m.weight(1, 2), 0.3));
        // A missing synapse reads as zero.
        assert_eq!(m.weight(2, 3), 0.0);
    }

    #[test]
    fn test_spike_processing() {
        let mut m = SynapseMatrix::new(5, 5);

        // Fully connected, without self-connections.
        (0..5usize)
            .flat_map(|pre| (0..5usize).map(move |post| (pre, post)))
            .filter(|&(pre, post)| pre != post)
            .for_each(|(pre, post)| m.add_synapse(pre, post, 0.5));

        // Causal pairing on 0 -> 1: pre at t = 10, post at t = 15.
        m.on_pre_spike(0, 10.0);
        m.on_post_spike(1, 15.0);

        assert!(m.weight(0, 1) > 0.5);
    }

    #[test]
    fn test_asymmetric_stdp() {
        let rule = AsymmetricSTDP::default();

        let forward = rule.compute_dw(5.0);
        let backward = rule.compute_dw(-5.0);

        assert!(forward > 0.0);
        assert!(backward < 0.0);
        // At equal |dt| the default forward kernel outweighs the backward one.
        assert!(forward.abs() > backward.abs());
    }

    #[test]
    fn test_dense_conversion() {
        let mut m = SynapseMatrix::new(3, 3);
        m.add_synapse(0, 1, 0.5);
        m.add_synapse(1, 2, 0.7);

        let dense = m.to_dense();
        assert_eq!(dense.len(), 3);
        assert!(approx(dense[0][1], 0.5));
        assert!(approx(dense[1][2], 0.7));

        // Round trip: only the two non-zero entries become synapses.
        assert_eq!(SynapseMatrix::from_dense(&dense).num_synapses(), 2);
    }
}