ruvector_mincut/snn/synapse.rs

//! # Synaptic Connections with STDP Learning
//!
//! Implements spike-timing dependent plasticity (STDP) for synaptic weight updates.
//!
//! ## STDP Learning Rule
//!
//! ```text
//! ΔW = A+ * exp(-Δt/τ+)  if Δt > 0 (pre before post → LTP)
//! ΔW = -A- * exp(Δt/τ-)  if Δt ≤ 0 (post before pre → LTD)
//! ```
//!
//! Where Δt = t_post - t_pre
//!
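//! For example, with the default `STDPConfig` values (A+ = 0.01, A- = 0.012, τ+ = τ- = 20 ms):
//!
//! ```text
//! Δt = +5 ms → ΔW = +0.01  * exp(-5/20) ≈ +0.0078  (LTP)
//! Δt = -5 ms → ΔW = -0.012 * exp(-5/20) ≈ -0.0093  (LTD)
//! ```
//!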
//! ## Integration with MinCut
//!
//! Synaptic weights directly map to graph edge weights:
//! - Strong synapse → strong edge → less likely to be cut
//! - STDP learning → edge weight evolution → dynamic mincut
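//!
//! A minimal sketch of that loop (illustrative only; the `graph` value and the
//! `neuron_to_vertex` mapping are assumed to be provided by the surrounding code):
//!
//! ```ignore
//! let mut synapses = SynapseMatrix::new(16, 16);
//! synapses.add_synapse(0, 1, 0.5);
//! // ... run the network; on_pre_spike / on_post_spike apply STDP ...
//! synapses.sync_to_graph(&mut graph, neuron_to_vertex);
//! // ... recompute the mincut over the updated edge weights ...
//! ```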

use super::{SimTime, Spike};
use crate::graph::{DynamicGraph, VertexId, Weight};
use std::collections::HashMap;

/// Configuration for STDP learning
#[derive(Debug, Clone)]
pub struct STDPConfig {
    /// LTP amplitude (potentiation)
    pub a_plus: f64,
    /// LTD amplitude (depression)
    pub a_minus: f64,
    /// LTP time constant (ms)
    pub tau_plus: f64,
    /// LTD time constant (ms)
    pub tau_minus: f64,
    /// Minimum weight
    pub w_min: f64,
    /// Maximum weight
    pub w_max: f64,
    /// Learning rate
    pub learning_rate: f64,
    /// Eligibility trace time constant (ms)
    pub tau_eligibility: f64,
}

impl Default for STDPConfig {
    fn default() -> Self {
        Self {
            a_plus: 0.01,
            a_minus: 0.012,
            tau_plus: 20.0,
            tau_minus: 20.0,
            w_min: 0.0,
            w_max: 1.0,
            learning_rate: 1.0,
            tau_eligibility: 1000.0,
        }
    }
}

/// A single synapse between two neurons
#[derive(Debug, Clone)]
pub struct Synapse {
    /// Pre-synaptic neuron ID
    pub pre: usize,
    /// Post-synaptic neuron ID
    pub post: usize,
    /// Synaptic weight
    pub weight: f64,
    /// Transmission delay (ms)
    pub delay: f64,
    /// Eligibility trace for reward-modulated STDP
    pub eligibility: f64,
    /// Last update time
    pub last_update: SimTime,
}

impl Synapse {
    /// Create a new synapse with the default 1 ms transmission delay
    pub fn new(pre: usize, post: usize, weight: f64) -> Self {
        Self {
            pre,
            post,
            weight,
            delay: 1.0,
            eligibility: 0.0,
            last_update: 0.0,
        }
    }

    /// Create synapse with delay
    pub fn with_delay(pre: usize, post: usize, weight: f64, delay: f64) -> Self {
        Self {
            pre,
            post,
            weight,
            delay,
            eligibility: 0.0,
            last_update: 0.0,
        }
    }

    /// Apply the STDP rule for a pre/post spike pair, updating the weight and
    /// eligibility trace; returns the weight change before clamping
    pub fn stdp_update(&mut self, t_pre: SimTime, t_post: SimTime, config: &STDPConfig) -> f64 {
        let dt = t_post - t_pre;

        let dw = if dt > 0.0 {
            // Pre before post → LTP
            config.a_plus * (-dt / config.tau_plus).exp()
        } else {
            // Post before pre → LTD
            -config.a_minus * (dt / config.tau_minus).exp()
        };

        // Apply learning rate and clip
        let delta = config.learning_rate * dw;
        self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);

        // Update eligibility trace
        self.eligibility += dw;

        delta
    }

    /// Decay eligibility trace
    pub fn decay_eligibility(&mut self, dt: f64, tau: f64) {
        self.eligibility *= (-dt / tau).exp();
    }

    /// Apply reward-modulated update (R-STDP)
    pub fn reward_modulated_update(&mut self, reward: f64, config: &STDPConfig) {
        let delta = reward * self.eligibility * config.learning_rate;
        self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
        // Attenuate the eligibility trace after the reward has been consumed
        self.eligibility *= 0.5;
    }
}

/// Matrix of synaptic connections
#[derive(Debug, Clone)]
pub struct SynapseMatrix {
    /// Number of pre-synaptic neurons
    pub n_pre: usize,
    /// Number of post-synaptic neurons
    pub n_post: usize,
    /// Synapses indexed by (pre, post)
    synapses: HashMap<(usize, usize), Synapse>,
    /// STDP configuration
    pub config: STDPConfig,
    /// Track last spike times for pre-synaptic neurons
    pre_spike_times: Vec<SimTime>,
    /// Track last spike times for post-synaptic neurons
    post_spike_times: Vec<SimTime>,
}

impl SynapseMatrix {
    /// Create a new synapse matrix
    pub fn new(n_pre: usize, n_post: usize) -> Self {
        Self {
            n_pre,
            n_post,
            synapses: HashMap::new(),
            config: STDPConfig::default(),
            pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
            post_spike_times: vec![f64::NEG_INFINITY; n_post],
        }
    }

    /// Create with custom STDP config
    pub fn with_config(n_pre: usize, n_post: usize, config: STDPConfig) -> Self {
        Self {
            n_pre,
            n_post,
            synapses: HashMap::new(),
            config,
            pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
            post_spike_times: vec![f64::NEG_INFINITY; n_post],
        }
    }

    /// Add a synapse (out-of-range indices are silently ignored)
    pub fn add_synapse(&mut self, pre: usize, post: usize, weight: f64) {
        if pre < self.n_pre && post < self.n_post {
            self.synapses
                .insert((pre, post), Synapse::new(pre, post, weight));
        }
    }

    /// Get synapse if it exists
    pub fn get_synapse(&self, pre: usize, post: usize) -> Option<&Synapse> {
        self.synapses.get(&(pre, post))
    }

    /// Get mutable synapse if it exists
    pub fn get_synapse_mut(&mut self, pre: usize, post: usize) -> Option<&mut Synapse> {
        self.synapses.get_mut(&(pre, post))
    }

    /// Get the weight of a synapse (0.0 if it doesn't exist)
    pub fn weight(&self, pre: usize, post: usize) -> f64 {
        self.get_synapse(pre, post).map(|s| s.weight).unwrap_or(0.0)
    }

    /// Compute weighted sum for all post-synaptic neurons given pre-synaptic activations
    ///
    /// This is optimized to iterate only over existing synapses, avoiding O(n²) lookups.
    /// `pre_activations[i]` is the activation of pre-synaptic neuron `i`.
    /// Returns vector of weighted sums for each post-synaptic neuron.
    #[inline]
    pub fn compute_weighted_sums(&self, pre_activations: &[f64]) -> Vec<f64> {
        let mut sums = vec![0.0; self.n_post];

        // Iterate only over existing synapses (sparse operation)
        for (&(pre, post), synapse) in &self.synapses {
            if pre < pre_activations.len() {
                sums[post] += synapse.weight * pre_activations[pre];
            }
        }

        sums
    }

    /// Compute weighted sum for a single post-synaptic neuron
    #[inline]
    pub fn weighted_sum_for_post(&self, post: usize, pre_activations: &[f64]) -> f64 {
        let mut sum = 0.0;
        for pre in 0..self.n_pre.min(pre_activations.len()) {
            if let Some(synapse) = self.synapses.get(&(pre, post)) {
                sum += synapse.weight * pre_activations[pre];
            }
        }
        sum
    }

    /// Set the weight of a synapse (creating it if it doesn't exist)
    pub fn set_weight(&mut self, pre: usize, post: usize, weight: f64) {
        if let Some(synapse) = self.get_synapse_mut(pre, post) {
            synapse.weight = weight;
        } else {
            self.add_synapse(pre, post, weight);
        }
    }

    /// Record a pre-synaptic spike and perform STDP updates
    pub fn on_pre_spike(&mut self, pre: usize, time: SimTime) {
        if pre >= self.n_pre {
            return;
        }

        self.pre_spike_times[pre] = time;

        // LTD: pre spike after recent post spikes
        for post in 0..self.n_post {
            if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
                let t_post = self.post_spike_times[post];
                if t_post > f64::NEG_INFINITY {
                    synapse.stdp_update(time, t_post, &self.config);
                }
            }
        }
    }

    /// Record a post-synaptic spike and perform STDP updates
    pub fn on_post_spike(&mut self, post: usize, time: SimTime) {
        if post >= self.n_post {
            return;
        }

        self.post_spike_times[post] = time;

        // LTP: post spike after recent pre spikes
        for pre in 0..self.n_pre {
            if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
                let t_pre = self.pre_spike_times[pre];
                if t_pre > f64::NEG_INFINITY {
                    synapse.stdp_update(t_pre, time, &self.config);
                }
            }
        }
    }

    /// Process multiple spikes with STDP
    pub fn process_spikes(&mut self, spikes: &[Spike]) {
        for spike in spikes {
            // Neuron IDs are assumed to index both populations directly: a spike
            // is treated as pre-synaptic when its ID is a valid pre index and as
            // post-synaptic when it is a valid post index.
            if spike.neuron_id < self.n_pre {
                self.on_pre_spike(spike.neuron_id, spike.time);
            }
            if spike.neuron_id < self.n_post {
                self.on_post_spike(spike.neuron_id, spike.time);
            }
        }
    }

    /// Decay all eligibility traces
    pub fn decay_eligibility(&mut self, dt: f64) {
        for synapse in self.synapses.values_mut() {
            synapse.decay_eligibility(dt, self.config.tau_eligibility);
        }
    }

    /// Apply reward-modulated learning to all synapses
    pub fn apply_reward(&mut self, reward: f64) {
        for synapse in self.synapses.values_mut() {
            synapse.reward_modulated_update(reward, &self.config);
        }
    }

    /// Get all synapses as an iterator
    pub fn iter(&self) -> impl Iterator<Item = (&(usize, usize), &Synapse)> {
        self.synapses.iter()
    }

    /// Get number of synapses
    pub fn num_synapses(&self) -> usize {
        self.synapses.len()
    }

    /// Compute total synaptic input to a post-synaptic neuron
    pub fn input_to(&self, post: usize, pre_activities: &[f64]) -> f64 {
        let mut total = 0.0;
        for pre in 0..self.n_pre.min(pre_activities.len()) {
            total += self.weight(pre, post) * pre_activities[pre];
        }
        total
    }

    /// Create dense weight matrix
    pub fn to_dense(&self) -> Vec<Vec<f64>> {
        let mut matrix = vec![vec![0.0; self.n_post]; self.n_pre];
        for ((pre, post), synapse) in &self.synapses {
            matrix[*pre][*post] = synapse.weight;
        }
        matrix
    }

    /// Initialize from dense matrix
    pub fn from_dense(matrix: &[Vec<f64>]) -> Self {
        let n_pre = matrix.len();
        let n_post = matrix.first().map(|r| r.len()).unwrap_or(0);

        let mut sm = Self::new(n_pre, n_post);

        for (pre, row) in matrix.iter().enumerate() {
            for (post, &weight) in row.iter().enumerate() {
                if weight != 0.0 {
                    sm.add_synapse(pre, post, weight);
                }
            }
        }

        sm
    }

    /// Synchronize synaptic weights into a `DynamicGraph`.
    /// Neurons are mapped to vertices via `neuron_to_vertex`; only edges that
    /// already exist in the graph are updated.
    pub fn sync_to_graph<F>(&self, graph: &mut DynamicGraph, neuron_to_vertex: F)
    where
        F: Fn(usize) -> VertexId,
    {
        for ((pre, post), synapse) in &self.synapses {
            let u = neuron_to_vertex(*pre);
            let v = neuron_to_vertex(*post);

            if graph.has_edge(u, v) {
                let _ = graph.update_edge_weight(u, v, synapse.weight);
            }
        }
    }

    /// Load weights from a `DynamicGraph`, mapping vertices to neurons via `vertex_to_neuron`
    pub fn sync_from_graph<F>(&mut self, graph: &DynamicGraph, vertex_to_neuron: F)
    where
        F: Fn(VertexId) -> usize,
    {
        for edge in graph.edges() {
            let pre = vertex_to_neuron(edge.source);
            let post = vertex_to_neuron(edge.target);

            if pre < self.n_pre && post < self.n_post {
                self.set_weight(pre, post, edge.weight);
            }
        }
    }

    /// Get high-correlation pairs (synapses with weight at or above the threshold)
    pub fn high_correlation_pairs(&self, threshold: f64) -> Vec<(usize, usize)> {
        self.synapses
            .iter()
            .filter(|(_, s)| s.weight >= threshold)
            .map(|((pre, post), _)| (*pre, *post))
            .collect()
    }
}

/// Asymmetric STDP for causal relationship encoding
#[derive(Debug, Clone)]
pub struct AsymmetricSTDP {
    /// Forward (causal) time constant
    pub tau_forward: f64,
    /// Backward time constant
    pub tau_backward: f64,
    /// Forward amplitude (typically larger for causality)
    pub a_forward: f64,
    /// Backward amplitude
    pub a_backward: f64,
}

impl Default for AsymmetricSTDP {
    fn default() -> Self {
        Self {
            tau_forward: 15.0,
            tau_backward: 30.0, // Longer backward window
            a_forward: 0.015,   // Stronger forward (causal)
            a_backward: 0.008,  // Weaker backward
        }
    }
}

impl AsymmetricSTDP {
    /// Compute the weight change for causal relationship encoding.
    /// Positive Δt (pre before post) is weighted more heavily.
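    ///
    /// With the default parameters (a_forward = 0.015, tau_forward = 15 ms,
    /// a_backward = 0.008, tau_backward = 30 ms) this gives, for example:
    ///
    /// ```text
    /// Δt = +5 ms → ΔW ≈ +0.0107  (causal)
    /// Δt = -5 ms → ΔW ≈ -0.0068  (anti-causal)
    /// ```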
    pub fn compute_dw(&self, dt: f64) -> f64 {
        if dt > 0.0 {
            // Pre before post → causal relationship
            self.a_forward * (-dt / self.tau_forward).exp()
        } else {
            // Post before pre → anti-causal
            -self.a_backward * (dt / self.tau_backward).exp()
        }
    }

    /// Update the weight matrix after `neuron_id` spikes at `time` (causal discovery)
    pub fn update_weights(&self, matrix: &mut SynapseMatrix, neuron_id: usize, time: SimTime) {
        let w_min = matrix.config.w_min;
        let w_max = matrix.config.w_max;
        let n_pre = matrix.n_pre;
        let n_post = matrix.n_post;

        // Collect pre-spike times first to avoid borrow conflicts
        let pre_times: Vec<_> = (0..n_pre)
            .map(|pre| {
                matrix
                    .pre_spike_times
                    .get(pre)
                    .copied()
                    .unwrap_or(f64::NEG_INFINITY)
            })
            .collect();

        // This neuron just spiked - update all synapses involving it (incoming)
        for pre in 0..n_pre {
            let t_pre = pre_times[pre];
            if t_pre > f64::NEG_INFINITY {
                let dt = time - t_pre;
                let dw = self.compute_dw(dt);
                if let Some(synapse) = matrix.get_synapse_mut(pre, neuron_id) {
                    synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
                }
            }
        }

        // Collect post-spike times
        let post_times: Vec<_> = (0..n_post)
            .map(|post| {
                matrix
                    .post_spike_times
                    .get(post)
                    .copied()
                    .unwrap_or(f64::NEG_INFINITY)
            })
            .collect();

        for post in 0..n_post {
            let t_post = post_times[post];
            if t_post > f64::NEG_INFINITY {
                let dt = t_post - time; // Δt = t_post - t_pre for the outgoing synapse (this neuron is pre)
                let dw = self.compute_dw(dt);
                if let Some(synapse) = matrix.get_synapse_mut(neuron_id, post) {
                    synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synapse_creation() {
        let synapse = Synapse::new(0, 1, 0.5);
        assert_eq!(synapse.pre, 0);
        assert_eq!(synapse.post, 1);
        assert_eq!(synapse.weight, 0.5);
    }
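
    // Added sketch: overrides selected STDP parameters while keeping the rest
    // of the defaults, then checks they are carried into the matrix.
    #[test]
    fn test_with_config() {
        let config = STDPConfig {
            a_minus: 0.02,
            ..STDPConfig::default()
        };
        let matrix = SynapseMatrix::with_config(4, 4, config);
        assert!((matrix.config.a_minus - 0.02).abs() < 1e-12);
        assert!((matrix.config.a_plus - 0.01).abs() < 1e-12);
    }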

    #[test]
    fn test_stdp_ltp() {
        let mut synapse = Synapse::new(0, 1, 0.5);
        let config = STDPConfig::default();

        // Pre before post → LTP
        let dw = synapse.stdp_update(10.0, 15.0, &config);
        assert!(dw > 0.0);
        assert!(synapse.weight > 0.5);
    }

    #[test]
    fn test_stdp_ltd() {
        let mut synapse = Synapse::new(0, 1, 0.5);
        let config = STDPConfig::default();

        // Post before pre → LTD
        let dw = synapse.stdp_update(15.0, 10.0, &config);
        assert!(dw < 0.0);
        assert!(synapse.weight < 0.5);
    }
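
    // Added sketch: exercises the eligibility-trace / reward path (R-STDP) at
    // the single-synapse level, using only the defaults defined above.
    #[test]
    fn test_reward_modulated_update() {
        let mut synapse = Synapse::new(0, 1, 0.5);
        let config = STDPConfig::default();

        // An LTP pairing leaves a positive eligibility trace.
        synapse.stdp_update(10.0, 15.0, &config);
        assert!(synapse.eligibility > 0.0);

        // A positive reward converts the trace into a further weight increase.
        let w_before = synapse.weight;
        synapse.reward_modulated_update(1.0, &config);
        assert!(synapse.weight > w_before);

        // The trace decays exponentially between rewards.
        let e_before = synapse.eligibility;
        synapse.decay_eligibility(100.0, config.tau_eligibility);
        assert!(synapse.eligibility < e_before);
    }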

    #[test]
    fn test_synapse_matrix() {
        let mut matrix = SynapseMatrix::new(10, 10);
        matrix.add_synapse(0, 1, 0.5);
        matrix.add_synapse(1, 2, 0.3);

        assert_eq!(matrix.num_synapses(), 2);
        assert!((matrix.weight(0, 1) - 0.5).abs() < 0.001);
        assert!((matrix.weight(1, 2) - 0.3).abs() < 0.001);
        assert_eq!(matrix.weight(2, 3), 0.0);
    }
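
    // Added sketch: checks that the sparse accumulation in
    // `compute_weighted_sums` matches the per-neuron `weighted_sum_for_post`.
    #[test]
    fn test_weighted_sums() {
        let mut matrix = SynapseMatrix::new(3, 2);
        matrix.add_synapse(0, 0, 0.5);
        matrix.add_synapse(1, 0, 0.25);
        matrix.add_synapse(2, 1, 1.0);

        let activations = [1.0, 2.0, 3.0];
        let sums = matrix.compute_weighted_sums(&activations);
        assert!((sums[0] - 1.0).abs() < 1e-9); // 0.5 * 1.0 + 0.25 * 2.0
        assert!((sums[1] - 3.0).abs() < 1e-9); // 1.0 * 3.0

        let single = matrix.weighted_sum_for_post(0, &activations);
        assert!((single - sums[0]).abs() < 1e-9);
    }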

    #[test]
    fn test_spike_processing() {
        let mut matrix = SynapseMatrix::new(5, 5);

        // Fully connected
        for i in 0..5 {
            for j in 0..5 {
                if i != j {
                    matrix.add_synapse(i, j, 0.5);
                }
            }
        }

        // Pre spike then post spike → LTP
        matrix.on_pre_spike(0, 10.0);
        matrix.on_post_spike(1, 15.0);

        // Should have strengthened 0→1 connection
        assert!(matrix.weight(0, 1) > 0.5);
    }
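
    // Added sketch: the matrix-level reward and eligibility-decay wrappers,
    // driven by the same kind of LTP pairing as above.
    #[test]
    fn test_matrix_reward_and_decay() {
        let mut matrix = SynapseMatrix::new(2, 2);
        matrix.add_synapse(0, 1, 0.5);

        // LTP pairing builds a positive eligibility trace on synapse (0, 1).
        matrix.on_pre_spike(0, 10.0);
        matrix.on_post_spike(1, 15.0);

        let w_before = matrix.weight(0, 1);
        matrix.apply_reward(1.0);
        assert!(matrix.weight(0, 1) > w_before);

        let e_before = matrix.get_synapse(0, 1).unwrap().eligibility;
        matrix.decay_eligibility(100.0);
        assert!(matrix.get_synapse(0, 1).unwrap().eligibility < e_before);
    }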

    #[test]
    fn test_asymmetric_stdp() {
        let stdp = AsymmetricSTDP::default();

        // Causal (dt > 0) should have larger effect
        let dw_causal = stdp.compute_dw(5.0);
        let dw_anticausal = stdp.compute_dw(-5.0);

        assert!(dw_causal > 0.0);
        assert!(dw_anticausal < 0.0);
        assert!(dw_causal.abs() > dw_anticausal.abs());
    }
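
    // Added sketch: a causal pairing (0 spikes, then 1 spikes) strengthens the
    // 0 → 1 synapse through `AsymmetricSTDP::update_weights`.
    #[test]
    fn test_asymmetric_update_weights() {
        let stdp = AsymmetricSTDP::default();
        let mut matrix = SynapseMatrix::new(2, 2);
        matrix.add_synapse(0, 1, 0.5);

        // Record the pre-synaptic spike, then apply the asymmetric rule when
        // neuron 1 fires 5 ms later.
        matrix.on_pre_spike(0, 10.0);
        stdp.update_weights(&mut matrix, 1, 15.0);

        assert!(matrix.weight(0, 1) > 0.5);
    }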

    #[test]
    fn test_dense_conversion() {
        let mut matrix = SynapseMatrix::new(3, 3);
        matrix.add_synapse(0, 1, 0.5);
        matrix.add_synapse(1, 2, 0.7);

        let dense = matrix.to_dense();
        assert_eq!(dense.len(), 3);
        assert!((dense[0][1] - 0.5).abs() < 0.001);
        assert!((dense[1][2] - 0.7).abs() < 0.001);

        let recovered = SynapseMatrix::from_dense(&dense);
        assert_eq!(recovered.num_synapses(), 2);
    }
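
    // Added sketch: `high_correlation_pairs` uses an inclusive (>=) threshold.
    #[test]
    fn test_high_correlation_pairs() {
        let mut matrix = SynapseMatrix::new(4, 4);
        matrix.add_synapse(0, 1, 0.9);
        matrix.add_synapse(1, 2, 0.2);
        matrix.add_synapse(2, 3, 0.8);

        let mut pairs = matrix.high_correlation_pairs(0.8);
        pairs.sort();
        assert_eq!(pairs, vec![(0, 1), (2, 3)]);
    }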
}