ruvector_mincut/snn/synapse.rs

//! # Synaptic Connections with STDP Learning
//!
//! Implements spike-timing dependent plasticity (STDP) for synaptic weight updates.
//!
//! ## STDP Learning Rule
//!
//! ```text
//! ΔW = A+ * exp(-Δt/τ+)   if Δt > 0 (pre before post → LTP)
//! ΔW = -A- * exp(Δt/τ-)   if Δt < 0 (post before pre → LTD)
//! ```
//!
//! where Δt = t_post - t_pre.
//!
//! ## Integration with MinCut
//!
//! Synaptic weights map directly onto graph edge weights:
//! - Strong synapse → heavy edge → less likely to be crossed by the minimum cut
//! - STDP learning → evolving edge weights → a dynamic minimum cut
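//!
//! ## Example (illustrative)
//!
//! A minimal usage sketch of the spike → weight-update path. The crate path
//! `ruvector_mincut::snn::synapse` is assumed here, so the block is marked
//! `ignore` rather than run as a doctest.
//!
//! ```ignore
//! use ruvector_mincut::snn::synapse::SynapseMatrix; // path assumed
//!
//! let mut m = SynapseMatrix::new(2, 2);
//! m.add_synapse(0, 1, 0.5);
//!
//! // Pre spike at t = 10 ms, post spike at t = 15 ms → Δt > 0 → LTP on 0 → 1.
//! m.on_pre_spike(0, 10.0);
//! m.on_post_spike(1, 15.0);
//! assert!(m.weight(0, 1) > 0.5);
//! ```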

use super::{SimTime, Spike};
use crate::graph::{DynamicGraph, VertexId, Weight};
use std::collections::HashMap;

/// Configuration for STDP learning
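///
/// # Example (illustrative)
///
/// A sketch of overriding selected parameters while keeping the defaults for the
/// rest; the values shown are arbitrary, and the crate path is assumed (hence
/// the `ignore` block):
///
/// ```ignore
/// use ruvector_mincut::snn::synapse::STDPConfig; // path assumed
///
/// // Bias learning toward depression, otherwise keep the defaults.
/// let config = STDPConfig { a_minus: 0.02, ..STDPConfig::default() };
/// assert!(config.a_minus > config.a_plus);
/// ```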
#[derive(Debug, Clone)]
pub struct STDPConfig {
    /// LTP amplitude (potentiation)
    pub a_plus: f64,
    /// LTD amplitude (depression)
    pub a_minus: f64,
    /// LTP time constant (ms)
    pub tau_plus: f64,
    /// LTD time constant (ms)
    pub tau_minus: f64,
    /// Minimum weight
    pub w_min: f64,
    /// Maximum weight
    pub w_max: f64,
    /// Learning rate
    pub learning_rate: f64,
    /// Eligibility trace time constant
    pub tau_eligibility: f64,
}

impl Default for STDPConfig {
    fn default() -> Self {
        Self {
            a_plus: 0.01,
            a_minus: 0.012,
            tau_plus: 20.0,
            tau_minus: 20.0,
            w_min: 0.0,
            w_max: 1.0,
            learning_rate: 1.0,
            tau_eligibility: 1000.0,
        }
    }
}

/// A single synapse between two neurons
#[derive(Debug, Clone)]
pub struct Synapse {
    /// Pre-synaptic neuron ID
    pub pre: usize,
    /// Post-synaptic neuron ID
    pub post: usize,
    /// Synaptic weight
    pub weight: f64,
    /// Transmission delay (ms)
    pub delay: f64,
    /// Eligibility trace for reward-modulated STDP
    pub eligibility: f64,
    /// Last update time
    pub last_update: SimTime,
}

impl Synapse {
    /// Create a new synapse
    pub fn new(pre: usize, post: usize, weight: f64) -> Self {
        Self {
            pre,
            post,
            weight,
            delay: 1.0,
            eligibility: 0.0,
            last_update: 0.0,
        }
    }

    /// Create synapse with delay
    pub fn with_delay(pre: usize, post: usize, weight: f64, delay: f64) -> Self {
        Self {
            pre,
            post,
            weight,
            delay,
            eligibility: 0.0,
            last_update: 0.0,
        }
    }

    /// Compute STDP weight change
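    ///
    /// # Example (illustrative)
    ///
    /// A minimal sketch of the sign convention, using only types from this module
    /// (marked `ignore` because the crate path is assumed):
    ///
    /// ```ignore
    /// let mut syn = Synapse::new(0, 1, 0.5);
    /// let config = STDPConfig::default();
    ///
    /// // Pre at 10 ms, post at 15 ms → Δt = +5 ms → potentiation.
    /// assert!(syn.stdp_update(10.0, 15.0, &config) > 0.0);
    ///
    /// // Post at 10 ms, pre at 15 ms → Δt = -5 ms → depression.
    /// assert!(syn.stdp_update(15.0, 10.0, &config) < 0.0);
    /// ```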
    pub fn stdp_update(
        &mut self,
        t_pre: SimTime,
        t_post: SimTime,
        config: &STDPConfig,
    ) -> f64 {
        let dt = t_post - t_pre;

        let dw = if dt > 0.0 {
            // Pre before post → LTP
            config.a_plus * (-dt / config.tau_plus).exp()
        } else {
            // Post before pre → LTD
            -config.a_minus * (dt / config.tau_minus).exp()
        };

        // Apply learning rate and clip
        let delta = config.learning_rate * dw;
        self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);

        // Update eligibility trace
        self.eligibility += dw;

        delta
    }

    /// Decay eligibility trace
    pub fn decay_eligibility(&mut self, dt: f64, tau: f64) {
        self.eligibility *= (-dt / tau).exp();
    }

    /// Apply reward-modulated update (R-STDP)
    pub fn reward_modulated_update(&mut self, reward: f64, config: &STDPConfig) {
        let delta = reward * self.eligibility * config.learning_rate;
        self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
        // Reset eligibility after reward
        self.eligibility *= 0.5;
    }
}

/// Matrix of synaptic connections
#[derive(Debug, Clone)]
pub struct SynapseMatrix {
    /// Number of pre-synaptic neurons
    pub n_pre: usize,
    /// Number of post-synaptic neurons
    pub n_post: usize,
    /// Synapses indexed by (pre, post)
    synapses: HashMap<(usize, usize), Synapse>,
    /// STDP configuration
    pub config: STDPConfig,
    /// Track last spike times for pre-synaptic neurons
    pre_spike_times: Vec<SimTime>,
    /// Track last spike times for post-synaptic neurons
    post_spike_times: Vec<SimTime>,
}

impl SynapseMatrix {
    /// Create a new synapse matrix
    pub fn new(n_pre: usize, n_post: usize) -> Self {
        Self {
            n_pre,
            n_post,
            synapses: HashMap::new(),
            config: STDPConfig::default(),
            pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
            post_spike_times: vec![f64::NEG_INFINITY; n_post],
        }
    }

    /// Create with custom STDP config
    pub fn with_config(n_pre: usize, n_post: usize, config: STDPConfig) -> Self {
        Self {
            n_pre,
            n_post,
            synapses: HashMap::new(),
            config,
            pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
            post_spike_times: vec![f64::NEG_INFINITY; n_post],
        }
    }

    /// Add a synapse
    pub fn add_synapse(&mut self, pre: usize, post: usize, weight: f64) {
        if pre < self.n_pre && post < self.n_post {
            self.synapses.insert((pre, post), Synapse::new(pre, post, weight));
        }
    }

    /// Get synapse if it exists
    pub fn get_synapse(&self, pre: usize, post: usize) -> Option<&Synapse> {
        self.synapses.get(&(pre, post))
    }

    /// Get mutable synapse if it exists
    pub fn get_synapse_mut(&mut self, pre: usize, post: usize) -> Option<&mut Synapse> {
        self.synapses.get_mut(&(pre, post))
    }

    /// Get weight of a synapse (0 if it doesn't exist)
    pub fn weight(&self, pre: usize, post: usize) -> f64 {
        self.get_synapse(pre, post).map(|s| s.weight).unwrap_or(0.0)
    }

    /// Compute weighted sums for all post-synaptic neurons given pre-synaptic activations.
    ///
    /// This is optimized to iterate only over existing synapses, avoiding O(n²) lookups.
    /// `pre_activations[i]` is the activation of pre-synaptic neuron `i`.
    /// Returns a vector of weighted sums, one per post-synaptic neuron.
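    ///
    /// # Example (illustrative)
    ///
    /// A minimal sketch using types from this module (marked `ignore` because the
    /// crate path is assumed):
    ///
    /// ```ignore
    /// let mut m = SynapseMatrix::new(2, 2);
    /// m.add_synapse(0, 0, 0.5);
    /// m.add_synapse(1, 0, 0.25);
    ///
    /// // Post-neuron 0 receives 1.0 * 0.5 + 2.0 * 0.25 = 1.0.
    /// let sums = m.compute_weighted_sums(&[1.0, 2.0]);
    /// assert!((sums[0] - 1.0).abs() < 1e-9);
    /// ```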
    #[inline]
    pub fn compute_weighted_sums(&self, pre_activations: &[f64]) -> Vec<f64> {
        let mut sums = vec![0.0; self.n_post];

        // Iterate only over existing synapses (sparse operation)
        for (&(pre, post), synapse) in &self.synapses {
            if pre < pre_activations.len() {
                sums[post] += synapse.weight * pre_activations[pre];
            }
        }

        sums
    }

    /// Compute weighted sum for a single post-synaptic neuron
    #[inline]
    pub fn weighted_sum_for_post(&self, post: usize, pre_activations: &[f64]) -> f64 {
        let mut sum = 0.0;
        for pre in 0..self.n_pre.min(pre_activations.len()) {
            if let Some(synapse) = self.synapses.get(&(pre, post)) {
                sum += synapse.weight * pre_activations[pre];
            }
        }
        sum
    }

    /// Set weight of a synapse (creates it if it doesn't exist)
    pub fn set_weight(&mut self, pre: usize, post: usize, weight: f64) {
        if let Some(synapse) = self.get_synapse_mut(pre, post) {
            synapse.weight = weight;
        } else {
            self.add_synapse(pre, post, weight);
        }
    }

    /// Record a pre-synaptic spike and perform STDP updates
    pub fn on_pre_spike(&mut self, pre: usize, time: SimTime) {
        if pre >= self.n_pre {
            return;
        }

        self.pre_spike_times[pre] = time;

        // LTD: pre spike after recent post spikes
        for post in 0..self.n_post {
            if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
                let t_post = self.post_spike_times[post];
                if t_post > f64::NEG_INFINITY {
                    synapse.stdp_update(time, t_post, &self.config);
                }
            }
        }
    }

    /// Record a post-synaptic spike and perform STDP updates
    pub fn on_post_spike(&mut self, post: usize, time: SimTime) {
        if post >= self.n_post {
            return;
        }

        self.post_spike_times[post] = time;

        // LTP: post spike after recent pre spikes
        for pre in 0..self.n_pre {
            if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
                let t_pre = self.pre_spike_times[pre];
                if t_pre > f64::NEG_INFINITY {
                    synapse.stdp_update(t_pre, time, &self.config);
                }
            }
        }
    }

    /// Process multiple spikes with STDP
    pub fn process_spikes(&mut self, spikes: &[Spike]) {
        for spike in spikes {
            // Neuron IDs map directly onto synapse indices: each spike is treated
            // as a pre-synaptic event and as a post-synaptic event whenever its ID
            // is in range for the respective side.
            if spike.neuron_id < self.n_pre {
                self.on_pre_spike(spike.neuron_id, spike.time);
            }
            if spike.neuron_id < self.n_post {
                self.on_post_spike(spike.neuron_id, spike.time);
            }
        }
    }

    /// Decay all eligibility traces
    pub fn decay_eligibility(&mut self, dt: f64) {
        for synapse in self.synapses.values_mut() {
            synapse.decay_eligibility(dt, self.config.tau_eligibility);
        }
    }

    /// Apply reward-modulated learning to all synapses
    pub fn apply_reward(&mut self, reward: f64) {
        for synapse in self.synapses.values_mut() {
            synapse.reward_modulated_update(reward, &self.config);
        }
    }

    /// Get all synapses as an iterator
    pub fn iter(&self) -> impl Iterator<Item = (&(usize, usize), &Synapse)> {
        self.synapses.iter()
    }

    /// Get number of synapses
    pub fn num_synapses(&self) -> usize {
        self.synapses.len()
    }

    /// Compute total synaptic input to a post-synaptic neuron
    pub fn input_to(&self, post: usize, pre_activities: &[f64]) -> f64 {
        let mut total = 0.0;
        for pre in 0..self.n_pre.min(pre_activities.len()) {
            total += self.weight(pre, post) * pre_activities[pre];
        }
        total
    }

    /// Create dense weight matrix
    pub fn to_dense(&self) -> Vec<Vec<f64>> {
        let mut matrix = vec![vec![0.0; self.n_post]; self.n_pre];
        for ((pre, post), synapse) in &self.synapses {
            matrix[*pre][*post] = synapse.weight;
        }
        matrix
    }

    /// Initialize from dense matrix
    pub fn from_dense(matrix: &[Vec<f64>]) -> Self {
        let n_pre = matrix.len();
        let n_post = matrix.first().map(|r| r.len()).unwrap_or(0);

        let mut sm = Self::new(n_pre, n_post);

        for (pre, row) in matrix.iter().enumerate() {
            for (post, &weight) in row.iter().enumerate() {
                if weight != 0.0 {
                    sm.add_synapse(pre, post, weight);
                }
            }
        }

        sm
    }

    /// Synchronize weights with a DynamicGraph.
    /// Maps each neuron to a vertex via the provided `neuron_to_vertex` function.
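    ///
    /// # Example (illustrative)
    ///
    /// A sketch of the weight hand-off, assuming an identity neuron↔vertex mapping.
    /// `DynamicGraph::new` and `add_edge` are assumed constructors used for
    /// illustration only and are not confirmed by this module, hence the `ignore`
    /// block.
    ///
    /// ```ignore
    /// let mut graph = DynamicGraph::new();   // assumed constructor
    /// graph.add_edge(0, 1, 0.5);             // assumed edge-insertion API
    ///
    /// let mut matrix = SynapseMatrix::new(2, 2);
    /// matrix.add_synapse(0, 1, 0.9);
    ///
    /// // Push the learned synaptic weight onto the existing graph edge.
    /// matrix.sync_to_graph(&mut graph, |n| n as VertexId); // assumes VertexId is an integer alias
    /// ```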
    pub fn sync_to_graph<F>(&self, graph: &mut DynamicGraph, neuron_to_vertex: F)
    where
        F: Fn(usize) -> VertexId,
    {
        for ((pre, post), synapse) in &self.synapses {
            let u = neuron_to_vertex(*pre);
            let v = neuron_to_vertex(*post);

            if graph.has_edge(u, v) {
                let _ = graph.update_edge_weight(u, v, synapse.weight);
            }
        }
    }

    /// Load weights from a DynamicGraph
    pub fn sync_from_graph<F>(&mut self, graph: &DynamicGraph, vertex_to_neuron: F)
    where
        F: Fn(VertexId) -> usize,
    {
        for edge in graph.edges() {
            let pre = vertex_to_neuron(edge.source);
            let post = vertex_to_neuron(edge.target);

            if pre < self.n_pre && post < self.n_post {
                self.set_weight(pre, post, edge.weight);
            }
        }
    }

    /// Get high-correlation pairs (synapses with weight at or above the threshold)
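    ///
    /// # Example (illustrative)
    ///
    /// A small sketch using types from this module (marked `ignore` because the
    /// crate path is assumed):
    ///
    /// ```ignore
    /// let mut m = SynapseMatrix::new(3, 3);
    /// m.add_synapse(0, 1, 0.9);
    /// m.add_synapse(1, 2, 0.2);
    ///
    /// // Only the strong 0 → 1 synapse clears the threshold.
    /// assert_eq!(m.high_correlation_pairs(0.5), vec![(0, 1)]);
    /// ```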
    pub fn high_correlation_pairs(&self, threshold: f64) -> Vec<(usize, usize)> {
        self.synapses
            .iter()
            .filter(|(_, s)| s.weight >= threshold)
            .map(|((pre, post), _)| (*pre, *post))
            .collect()
    }
}

/// Asymmetric STDP for causal relationship encoding
#[derive(Debug, Clone)]
pub struct AsymmetricSTDP {
    /// Forward (causal) time constant
    pub tau_forward: f64,
    /// Backward time constant
    pub tau_backward: f64,
    /// Forward amplitude (typically larger for causality)
    pub a_forward: f64,
    /// Backward amplitude
    pub a_backward: f64,
}

impl Default for AsymmetricSTDP {
    fn default() -> Self {
        Self {
            tau_forward: 15.0,
            tau_backward: 30.0,  // Longer backward window
            a_forward: 0.015,   // Stronger forward (causal)
            a_backward: 0.008,  // Weaker backward
        }
    }
}

impl AsymmetricSTDP {
    /// Compute weight change for causal relationship encoding.
    /// Positive Δt (pre → post) is weighted more heavily than negative Δt.
    pub fn compute_dw(&self, dt: f64) -> f64 {
        if dt > 0.0 {
            // Pre before post → causal relationship
            self.a_forward * (-dt / self.tau_forward).exp()
        } else {
            // Post before pre → anti-causal
            -self.a_backward * (dt / self.tau_backward).exp()
        }
    }

    /// Update weight matrix for causal discovery
    pub fn update_weights(
        &self,
        matrix: &mut SynapseMatrix,
        neuron_id: usize,
        time: SimTime,
    ) {
        let w_min = matrix.config.w_min;
        let w_max = matrix.config.w_max;
        let n_pre = matrix.n_pre;
        let n_post = matrix.n_post;

        // Collect pre-spike times first to avoid borrow conflicts
        let pre_times: Vec<_> = (0..n_pre)
            .map(|pre| matrix.pre_spike_times.get(pre).copied().unwrap_or(f64::NEG_INFINITY))
            .collect();

        // This neuron just spiked as the post-synaptic side: update its incoming synapses.
        for pre in 0..n_pre {
            let t_pre = pre_times[pre];
            if t_pre > f64::NEG_INFINITY {
                let dt = time - t_pre;
                let dw = self.compute_dw(dt);
                if let Some(synapse) = matrix.get_synapse_mut(pre, neuron_id) {
                    synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
                }
            }
        }

        // Collect post-spike times
        let post_times: Vec<_> = (0..n_post)
            .map(|post| matrix.post_spike_times.get(post).copied().unwrap_or(f64::NEG_INFINITY))
            .collect();

        for post in 0..n_post {
            let t_post = post_times[post];
            if t_post > f64::NEG_INFINITY {
                let dt = t_post - time;  // Δt = t_post - t_pre, with this neuron as the pre-synaptic side
                let dw = self.compute_dw(dt);
                if let Some(synapse) = matrix.get_synapse_mut(neuron_id, post) {
                    synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synapse_creation() {
        let synapse = Synapse::new(0, 1, 0.5);
        assert_eq!(synapse.pre, 0);
        assert_eq!(synapse.post, 1);
        assert_eq!(synapse.weight, 0.5);
    }

    #[test]
    fn test_stdp_ltp() {
        let mut synapse = Synapse::new(0, 1, 0.5);
        let config = STDPConfig::default();

        // Pre before post → LTP
        let dw = synapse.stdp_update(10.0, 15.0, &config);
        assert!(dw > 0.0);
        assert!(synapse.weight > 0.5);
    }

    #[test]
    fn test_stdp_ltd() {
        let mut synapse = Synapse::new(0, 1, 0.5);
        let config = STDPConfig::default();

        // Post before pre → LTD
        let dw = synapse.stdp_update(15.0, 10.0, &config);
        assert!(dw < 0.0);
        assert!(synapse.weight < 0.5);
    }
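
    // Illustrative sketch of the R-STDP flow: an STDP pairing builds an eligibility
    // trace, the trace decays over time, and a later reward converts what remains
    // into an additional weight change.
    #[test]
    fn test_reward_modulated_flow() {
        let mut synapse = Synapse::new(0, 1, 0.5);
        let config = STDPConfig::default();

        // LTP pairing leaves a positive eligibility trace.
        synapse.stdp_update(10.0, 15.0, &config);
        assert!(synapse.eligibility > 0.0);

        // The trace decays exponentially with tau_eligibility.
        let trace_before = synapse.eligibility;
        synapse.decay_eligibility(500.0, config.tau_eligibility);
        assert!(synapse.eligibility < trace_before);

        // A positive reward converts the remaining trace into potentiation.
        let weight_before = synapse.weight;
        synapse.reward_modulated_update(1.0, &config);
        assert!(synapse.weight > weight_before);
    }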

    #[test]
    fn test_synapse_matrix() {
        let mut matrix = SynapseMatrix::new(10, 10);
        matrix.add_synapse(0, 1, 0.5);
        matrix.add_synapse(1, 2, 0.3);

        assert_eq!(matrix.num_synapses(), 2);
        assert!((matrix.weight(0, 1) - 0.5).abs() < 0.001);
        assert!((matrix.weight(1, 2) - 0.3).abs() < 0.001);
        assert_eq!(matrix.weight(2, 3), 0.0);
    }

    #[test]
    fn test_spike_processing() {
        let mut matrix = SynapseMatrix::new(5, 5);

        // Fully connected
        for i in 0..5 {
            for j in 0..5 {
                if i != j {
                    matrix.add_synapse(i, j, 0.5);
                }
            }
        }

        // Pre spike then post spike → LTP
        matrix.on_pre_spike(0, 10.0);
        matrix.on_post_spike(1, 15.0);

        // Should have strengthened 0→1 connection
        assert!(matrix.weight(0, 1) > 0.5);
    }
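
    // Illustrative counterpart to the LTP test above: a post spike followed by a
    // pre spike (Δt < 0) should weaken the corresponding synapse via the matrix's
    // on_pre_spike LTD path.
    #[test]
    fn test_matrix_ltd() {
        let mut matrix = SynapseMatrix::new(2, 2);
        matrix.add_synapse(0, 1, 0.5);

        // Post spike first, then pre spike → LTD on 0 → 1.
        matrix.on_post_spike(1, 10.0);
        matrix.on_pre_spike(0, 15.0);

        assert!(matrix.weight(0, 1) < 0.5);
    }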

    #[test]
    fn test_asymmetric_stdp() {
        let stdp = AsymmetricSTDP::default();

        // Causal (dt > 0) should have larger effect
        let dw_causal = stdp.compute_dw(5.0);
        let dw_anticausal = stdp.compute_dw(-5.0);

        assert!(dw_causal > 0.0);
        assert!(dw_anticausal < 0.0);
        assert!(dw_causal.abs() > dw_anticausal.abs());
    }
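
    // Illustrative sketch of AsymmetricSTDP::update_weights: after a pre-synaptic
    // spike is recorded on the matrix, a later spike of the post-synaptic neuron
    // should potentiate the incoming (causal) synapse.
    #[test]
    fn test_asymmetric_update_weights() {
        let stdp = AsymmetricSTDP::default();
        let mut matrix = SynapseMatrix::new(2, 2);
        matrix.add_synapse(0, 1, 0.5);

        // Neuron 0 (pre) spikes at t = 10 ms.
        matrix.on_pre_spike(0, 10.0);

        // Neuron 1 spikes at t = 15 ms; its incoming synapse 0 → 1 is strengthened.
        stdp.update_weights(&mut matrix, 1, 15.0);
        assert!(matrix.weight(0, 1) > 0.5);
    }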

    #[test]
    fn test_dense_conversion() {
        let mut matrix = SynapseMatrix::new(3, 3);
        matrix.add_synapse(0, 1, 0.5);
        matrix.add_synapse(1, 2, 0.7);

        let dense = matrix.to_dense();
        assert_eq!(dense.len(), 3);
        assert!((dense[0][1] - 0.5).abs() < 0.001);
        assert!((dense[1][2] - 0.7).abs() < 0.001);

        let recovered = SynapseMatrix::from_dense(&dense);
        assert_eq!(recovered.num_synapses(), 2);
    }
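
    // Illustrative sketch of matrix-level reward modulation: a causal pairing
    // builds eligibility on a synapse, traces decay globally, and apply_reward
    // turns the remaining eligibility into further potentiation.
    #[test]
    fn test_matrix_reward_modulation() {
        let mut matrix = SynapseMatrix::new(2, 2);
        matrix.add_synapse(0, 1, 0.5);

        // Causal pairing 0 → 1 builds positive eligibility and strengthens the synapse.
        matrix.on_pre_spike(0, 10.0);
        matrix.on_post_spike(1, 15.0);
        let w_after_stdp = matrix.weight(0, 1);

        // Traces decay between pairing and reward delivery.
        matrix.decay_eligibility(100.0);

        // A positive global reward further potentiates the eligible synapse.
        matrix.apply_reward(1.0);
        assert!(matrix.weight(0, 1) > w_after_stdp);
    }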
}