// ruvector_mincut/snn/neuron.rs

//! # Leaky Integrate-and-Fire Neuron Model
//!
//! Implements LIF neurons with adaptive thresholds and refractory periods.
//!
//! ## Membrane Dynamics
//!
//! ```text
//! τ_m * dV/dt = -(V - V_rest) + R * I(t)
//! ```
//!
//! When V >= θ: emit a spike, set V → V_reset, raise θ by the adaptation
//! increment, and enter the refractory period.
//!
//! ## Features
//!
//! - Exponential leak with configurable time constant
//! - Adaptive threshold (increases after each spike, decays back to baseline)
//! - Absolute refractory period
//! - Spike-rate tracking for homeostatic plasticity (the rate is recorded, but
//!   the threshold is not yet adjusted homeostatically)

20use super::{SimTime, Spike};
21use rayon::prelude::*;
22use std::collections::VecDeque;
23
/// Population size at or above which `NeuronPopulation::step` switches to a
/// rayon-parallel update.
///
/// A single `LIFNeuron::step` is only a handful of floating-point operations,
/// so for smaller populations the cost of distributing work across threads
/// outweighs the gain.
const PARALLEL_THRESHOLD: usize = 2_000;

/// Tunable parameters of a leaky integrate-and-fire neuron.
///
/// Times are in milliseconds and potentials in millivolts, as noted per field.
#[derive(Debug, Clone)]
pub struct NeuronConfig {
    /// Membrane time constant τ_m (ms): how quickly the potential leaks toward rest.
    pub tau_membrane: f64,
    /// Resting potential V_rest (mV) the membrane relaxes to without input.
    pub v_rest: f64,
    /// Potential (mV) the membrane is set to immediately after a spike.
    pub v_reset: f64,
    /// Baseline firing threshold (mV), before any spike-triggered adaptation.
    pub threshold: f64,
    /// Length of the absolute refractory period (ms) that follows a spike.
    pub t_refrac: f64,
    /// Membrane resistance R (MΩ) that scales input current into voltage drive.
    pub resistance: f64,
    /// Amount added to the threshold every time the neuron fires.
    pub threshold_adapt: f64,
    /// Time constant (ms) with which an adapted threshold relaxes back to baseline.
    pub tau_threshold: f64,
    /// Whether homeostatic plasticity is enabled.
    pub homeostatic: bool,
    /// Desired firing rate (spikes/ms) for homeostasis.
    pub target_rate: f64,
    /// Time constant (ms) of the smoothed spike-rate estimate used for homeostasis.
    pub tau_homeostatic: f64,
}

impl Default for NeuronConfig {
    /// Default parameter set: τ_m = 20 ms, threshold 1.0, 2 ms refractory
    /// period, homeostasis enabled with a 0.01 spikes/ms target.
    fn default() -> Self {
        NeuronConfig {
            // Membrane
            tau_membrane: 20.0,
            resistance: 1.0,
            v_rest: 0.0,
            v_reset: 0.0,
            // Firing and refractoriness
            threshold: 1.0,
            t_refrac: 2.0,
            // Threshold adaptation
            threshold_adapt: 0.1,
            tau_threshold: 100.0,
            // Homeostasis
            homeostatic: true,
            target_rate: 0.01,
            tau_homeostatic: 1000.0,
        }
    }
}

/// Mutable, per-step state of a single LIF neuron.
#[derive(Debug, Clone)]
pub struct NeuronState {
    /// Membrane potential (mV).
    pub v: f64,
    /// Threshold currently in effect, including any spike-triggered adaptation.
    pub threshold: f64,
    /// Refractory time still to elapse (ms); the neuron counts as refractory while positive.
    pub refrac_remaining: f64,
    /// Time of the most recent spike, or `f64::NEG_INFINITY` if it has never fired.
    pub last_spike_time: f64,
    /// Exponential moving average of spiking activity, kept for homeostasis.
    pub spike_rate: f64,
}

impl Default for NeuronState {
    /// A quiescent neuron: at 0.0 potential, threshold 1.0, not refractory,
    /// never fired, zero smoothed rate.
    fn default() -> Self {
        let never_fired = f64::NEG_INFINITY;
        NeuronState {
            v: 0.0,
            threshold: 1.0,
            refrac_remaining: 0.0,
            spike_rate: 0.0,
            last_spike_time: never_fired,
        }
    }
}

100/// Leaky Integrate-and-Fire neuron
101#[derive(Debug, Clone)]
102pub struct LIFNeuron {
103    /// Neuron ID
104    pub id: usize,
105    /// Configuration parameters
106    pub config: NeuronConfig,
107    /// Current state
108    pub state: NeuronState,
109}
110
111impl LIFNeuron {
112    /// Create a new LIF neuron with given ID
113    pub fn new(id: usize) -> Self {
114        Self {
115            id,
116            config: NeuronConfig::default(),
117            state: NeuronState::default(),
118        }
119    }
120
121    /// Create a new LIF neuron with custom configuration
122    pub fn with_config(id: usize, config: NeuronConfig) -> Self {
123        let mut state = NeuronState::default();
124        state.threshold = config.threshold;
125        Self { id, config, state }
126    }
127
128    /// Reset neuron state to initial conditions
129    pub fn reset(&mut self) {
130        self.state = NeuronState {
131            threshold: self.config.threshold,
132            ..NeuronState::default()
133        };
134    }
135
136    /// Integrate input current for one timestep
137    /// Returns true if neuron spiked
138    pub fn step(&mut self, current: f64, dt: f64, time: SimTime) -> bool {
139        // Handle refractory period
140        if self.state.refrac_remaining > 0.0 {
141            self.state.refrac_remaining -= dt;
142            return false;
143        }
144
145        // Membrane dynamics: τ dV/dt = -(V - V_rest) + R*I
146        let dv = (-self.state.v + self.config.v_rest + self.config.resistance * current)
147                 / self.config.tau_membrane * dt;
148        self.state.v += dv;
149
150        // Threshold adaptation decay
151        if self.state.threshold > self.config.threshold {
152            let d_thresh = -(self.state.threshold - self.config.threshold)
153                           / self.config.tau_threshold * dt;
154            self.state.threshold += d_thresh;
155        }
156
157        // Check for spike
158        if self.state.v >= self.state.threshold {
159            // Fire!
160            self.state.v = self.config.v_reset;
161            self.state.refrac_remaining = self.config.t_refrac;
162            self.state.last_spike_time = time;
163
164            // Threshold adaptation
165            self.state.threshold += self.config.threshold_adapt;
166
167            // Update running spike rate for homeostasis using proper exponential
168            // decay based on tau_homeostatic: rate += (1 - rate) * dt / tau
169            let alpha = (dt / self.config.tau_homeostatic).min(1.0);
170            self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
171
172            return true;
173        }
174
175        // Update spike rate (decay toward 0)
176        self.state.spike_rate *= 1.0 - dt / self.config.tau_homeostatic;
177
178        // Homeostatic plasticity: adjust threshold based on firing rate
179        if self.config.homeostatic {
180            let rate_error = self.state.spike_rate - self.config.target_rate;
181            let d_base_thresh = rate_error * dt / self.config.tau_homeostatic;
182            // Only apply to base threshold, not adapted part
183            // This is a simplification - full implementation would track separately
184        }
185
186        false
187    }
188
189    /// Inject a direct spike (for input neurons)
190    pub fn inject_spike(&mut self, time: SimTime) {
191        self.state.last_spike_time = time;
192        // Use same homeostatic update as regular spikes
193        let alpha = (1.0 / self.config.tau_homeostatic).min(1.0);
194        self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
195    }
196
197    /// Get time since last spike
198    pub fn time_since_spike(&self, current_time: SimTime) -> f64 {
199        current_time - self.state.last_spike_time
200    }
201
202    /// Check if neuron is in refractory period
203    pub fn is_refractory(&self) -> bool {
204        self.state.refrac_remaining > 0.0
205    }
206
207    /// Get membrane potential
208    pub fn membrane_potential(&self) -> f64 {
209        self.state.v
210    }
211
212    /// Set membrane potential directly
213    pub fn set_membrane_potential(&mut self, v: f64) {
214        self.state.v = v;
215    }
216
217    /// Get current threshold
218    pub fn threshold(&self) -> f64 {
219        self.state.threshold
220    }
221}
222
223/// A collection of spikes over time for one neuron
224#[derive(Debug, Clone)]
225pub struct SpikeTrain {
226    /// Neuron ID
227    pub neuron_id: usize,
228    /// Spike times (sorted)
229    pub spike_times: Vec<SimTime>,
230    /// Maximum time window to keep
231    pub max_window: f64,
232}
233
234impl SpikeTrain {
235    /// Create a new empty spike train
236    pub fn new(neuron_id: usize) -> Self {
237        Self {
238            neuron_id,
239            spike_times: Vec::new(),
240            max_window: 1000.0, // 1 second default
241        }
242    }
243
244    /// Create with custom window size
245    pub fn with_window(neuron_id: usize, max_window: f64) -> Self {
246        Self {
247            neuron_id,
248            spike_times: Vec::new(),
249            max_window,
250        }
251    }
252
253    /// Record a spike at given time
254    pub fn record_spike(&mut self, time: SimTime) {
255        self.spike_times.push(time);
256
257        // Prune old spikes
258        let cutoff = time - self.max_window;
259        self.spike_times.retain(|&t| t >= cutoff);
260    }
261
262    /// Clear all recorded spikes
263    pub fn clear(&mut self) {
264        self.spike_times.clear();
265    }
266
267    /// Get number of spikes in the train
268    pub fn count(&self) -> usize {
269        self.spike_times.len()
270    }
271
272    /// Compute instantaneous spike rate (spikes/ms)
273    pub fn spike_rate(&self, window: f64) -> f64 {
274        if self.spike_times.is_empty() {
275            return 0.0;
276        }
277
278        let latest = self.spike_times.last().copied().unwrap_or(0.0);
279        let count = self.spike_times.iter()
280            .filter(|&&t| t >= latest - window)
281            .count();
282
283        count as f64 / window
284    }
285
286    /// Compute inter-spike interval statistics
287    pub fn mean_isi(&self) -> Option<f64> {
288        if self.spike_times.len() < 2 {
289            return None;
290        }
291
292        let mut total_isi = 0.0;
293        for i in 1..self.spike_times.len() {
294            total_isi += self.spike_times[i] - self.spike_times[i - 1];
295        }
296
297        Some(total_isi / (self.spike_times.len() - 1) as f64)
298    }
299
300    /// Get coefficient of variation of ISI
301    pub fn cv_isi(&self) -> Option<f64> {
302        let mean = self.mean_isi()?;
303        if mean == 0.0 {
304            return None;
305        }
306
307        let mut variance = 0.0;
308        for i in 1..self.spike_times.len() {
309            let isi = self.spike_times[i] - self.spike_times[i - 1];
310            variance += (isi - mean).powi(2);
311        }
312        variance /= (self.spike_times.len() - 1) as f64;
313
314        Some(variance.sqrt() / mean)
315    }
316
317    /// Convert spike train to binary pattern (temporal encoding)
318    ///
319    /// Safely handles potential overflow in bin calculation.
320    pub fn to_pattern(&self, start: SimTime, bin_size: f64, num_bins: usize) -> Vec<bool> {
321        let mut pattern = vec![false; num_bins];
322
323        // Guard against zero/negative bin_size
324        if bin_size <= 0.0 || num_bins == 0 {
325            return pattern;
326        }
327
328        let end_time = start + bin_size * num_bins as f64;
329
330        for &spike_time in &self.spike_times {
331            if spike_time >= start && spike_time < end_time {
332                // Safe bin calculation with overflow protection
333                let offset = spike_time - start;
334                let bin_f64 = offset / bin_size;
335
336                // Check for overflow before casting
337                if bin_f64 >= 0.0 && bin_f64 < num_bins as f64 {
338                    let bin = bin_f64 as usize;
339                    if bin < num_bins {
340                        pattern[bin] = true;
341                    }
342                }
343            }
344        }
345
346        pattern
347    }
348
349    /// Check if spike times are sorted (for optimization)
350    #[inline]
351    fn is_sorted(times: &[f64]) -> bool {
352        times.windows(2).all(|w| w[0] <= w[1])
353    }
354
355    /// Compute cross-correlation with another spike train
356    ///
357    /// Uses O(n log n) sliding window algorithm instead of O(n²) pairwise comparison.
358    /// Optimized to skip sorting when spike trains are already sorted (typical case).
359    /// Uses binary search for initial window position.
360    pub fn cross_correlation(&self, other: &SpikeTrain, max_lag: f64, bin_size: f64) -> Vec<f64> {
361        // Guard against invalid parameters
362        if bin_size <= 0.0 || max_lag <= 0.0 {
363            return vec![0.0];
364        }
365
366        // Safe num_bins calculation with overflow protection
367        let num_bins_f64 = 2.0 * max_lag / bin_size + 1.0;
368        let num_bins = if num_bins_f64 > 0.0 && num_bins_f64 < usize::MAX as f64 {
369            (num_bins_f64 as usize).min(100_000) // Cap at 100K bins to prevent DoS
370        } else {
371            return vec![0.0];
372        };
373
374        let mut correlation = vec![0.0; num_bins];
375
376        // Empty train optimization
377        if self.spike_times.is_empty() || other.spike_times.is_empty() {
378            return correlation;
379        }
380
381        // Avoid cloning and sorting if already sorted (typical case for spike trains)
382        let t1_owned: Vec<f64>;
383        let t2_owned: Vec<f64>;
384
385        let t1: &[f64] = if Self::is_sorted(&self.spike_times) {
386            &self.spike_times
387        } else {
388            t1_owned = {
389                let mut v = self.spike_times.clone();
390                v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
391                v
392            };
393            &t1_owned
394        };
395
396        let t2: &[f64] = if Self::is_sorted(&other.spike_times) {
397            &other.spike_times
398        } else {
399            t2_owned = {
400                let mut v = other.spike_times.clone();
401                v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
402                v
403            };
404            &t2_owned
405        };
406
407        // Use binary search for first spike's window start
408        let first_lower = t1[0] - max_lag;
409        let mut window_start = t2.partition_point(|&x| x < first_lower);
410
411        for &t1_spike in t1 {
412            let lower_bound = t1_spike - max_lag;
413            let upper_bound = t1_spike + max_lag;
414
415            // Advance window_start past spikes too early
416            while window_start < t2.len() && t2[window_start] < lower_bound {
417                window_start += 1;
418            }
419
420            // Count spikes within window
421            let mut j = window_start;
422            while j < t2.len() && t2[j] <= upper_bound {
423                let lag = t1_spike - t2[j];
424
425                // Safe bin calculation (inlined for performance)
426                let bin = ((lag + max_lag) / bin_size) as usize;
427                if bin < num_bins {
428                    correlation[bin] += 1.0;
429                }
430                j += 1;
431            }
432        }
433
434        // Normalize by geometric mean of spike counts
435        let norm = ((self.count() * other.count()) as f64).sqrt();
436        if norm > 0.0 {
437            let inv_norm = 1.0 / norm;
438            for c in &mut correlation {
439                *c *= inv_norm;
440            }
441        }
442
443        correlation
444    }
445}
446
447/// Population of LIF neurons
448#[derive(Debug, Clone)]
449pub struct NeuronPopulation {
450    /// All neurons in the population
451    pub neurons: Vec<LIFNeuron>,
452    /// Spike trains for each neuron
453    pub spike_trains: Vec<SpikeTrain>,
454    /// Current simulation time
455    pub time: SimTime,
456}
457
458impl NeuronPopulation {
459    /// Create a new population with n neurons
460    pub fn new(n: usize) -> Self {
461        let neurons: Vec<_> = (0..n).map(|i| LIFNeuron::new(i)).collect();
462        let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
463
464        Self {
465            neurons,
466            spike_trains,
467            time: 0.0,
468        }
469    }
470
471    /// Create population with custom configuration
472    pub fn with_config(n: usize, config: NeuronConfig) -> Self {
473        let neurons: Vec<_> = (0..n)
474            .map(|i| LIFNeuron::with_config(i, config.clone()))
475            .collect();
476        let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
477
478        Self {
479            neurons,
480            spike_trains,
481            time: 0.0,
482        }
483    }
484
485    /// Get number of neurons
486    pub fn size(&self) -> usize {
487        self.neurons.len()
488    }
489
490    /// Step all neurons with given currents
491    ///
492    /// Uses parallel processing for large populations (>200 neurons).
493    pub fn step(&mut self, currents: &[f64], dt: f64) -> Vec<Spike> {
494        self.time += dt;
495        let time = self.time;
496
497        if self.neurons.len() >= PARALLEL_THRESHOLD {
498            // Parallel path: compute neuron updates in parallel
499            let spike_flags: Vec<bool> = self.neurons
500                .par_iter_mut()
501                .enumerate()
502                .map(|(i, neuron)| {
503                    let current = currents.get(i).copied().unwrap_or(0.0);
504                    neuron.step(current, dt, time)
505                })
506                .collect();
507
508            // Sequential: collect spikes and record to trains
509            let mut spikes = Vec::new();
510            for (i, &spiked) in spike_flags.iter().enumerate() {
511                if spiked {
512                    spikes.push(Spike { neuron_id: i, time });
513                    self.spike_trains[i].record_spike(time);
514                }
515            }
516            spikes
517        } else {
518            // Sequential path for small populations (avoid parallel overhead)
519            let mut spikes = Vec::new();
520            for (i, neuron) in self.neurons.iter_mut().enumerate() {
521                let current = currents.get(i).copied().unwrap_or(0.0);
522                if neuron.step(current, dt, time) {
523                    spikes.push(Spike { neuron_id: i, time });
524                    self.spike_trains[i].record_spike(time);
525                }
526            }
527            spikes
528        }
529    }
530
531    /// Reset all neurons
532    pub fn reset(&mut self) {
533        self.time = 0.0;
534        for neuron in &mut self.neurons {
535            neuron.reset();
536        }
537        for train in &mut self.spike_trains {
538            train.clear();
539        }
540    }
541
542    /// Get population spike rate
543    pub fn population_rate(&self, window: f64) -> f64 {
544        let total: f64 = self.spike_trains.iter()
545            .map(|t| t.spike_rate(window))
546            .sum();
547        total / self.neurons.len() as f64
548    }
549
550    /// Compute population synchrony
551    pub fn synchrony(&self, window: f64) -> f64 {
552        // Collect recent spikes
553        let mut all_spikes = Vec::new();
554        let cutoff = self.time - window;
555
556        for train in &self.spike_trains {
557            for &t in &train.spike_times {
558                if t >= cutoff {
559                    all_spikes.push(Spike { neuron_id: train.neuron_id, time: t });
560                }
561            }
562        }
563
564        super::compute_synchrony(&all_spikes, window / 10.0)
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh neuron carries its id and starts at zero potential.
    #[test]
    fn test_lif_neuron_creation() {
        let fresh = LIFNeuron::new(0);
        assert_eq!(fresh.id, 0);
        assert_eq!(fresh.state.v, 0.0);
    }

    /// Supra-threshold drive eventually produces a spike, after which the
    /// neuron is refractory.
    #[test]
    fn test_lif_neuron_spike() {
        let mut neuron = LIFNeuron::new(0);
        let fired = (0..100).any(|tick| neuron.step(2.0, 1.0, tick as f64));
        assert!(fired);
        assert!(neuron.is_refractory());
    }

    /// Evenly spaced spikes give the expected count and mean interval.
    #[test]
    fn test_spike_train() {
        let mut train = SpikeTrain::new(0);
        for &t in &[10.0, 20.0, 30.0] {
            train.record_spike(t);
        }

        assert_eq!(train.count(), 3);

        let isi = train.mean_isi().unwrap();
        assert!((isi - 10.0).abs() < 0.001);
    }

    /// A uniformly driven population produces spikes within 100 steps.
    #[test]
    fn test_neuron_population() {
        const N: usize = 100;
        let mut pop = NeuronPopulation::new(N);
        let drive = vec![1.5; N];

        let total_spikes: usize = (0..100).map(|_| pop.step(&drive, 1.0).len()).sum();

        assert!(total_spikes > 0);
    }

    /// Spikes land in the bins that contain their times.
    #[test]
    fn test_spike_train_pattern() {
        let mut train = SpikeTrain::new(0);
        for &t in &[1.0, 3.0, 7.0] {
            train.record_spike(t);
        }

        let pattern = train.to_pattern(0.0, 1.0, 10);
        assert_eq!(pattern.len(), 10);
        for &bin in &[1usize, 3, 7] {
            assert!(pattern[bin]);
        }
        assert!(!pattern[0]);
    }
}