ruvector_mincut/snn/
neuron.rs

1//! # Leaky Integrate-and-Fire Neuron Model
2//!
3//! Implements LIF neurons with adaptive thresholds and refractory periods.
4//!
5//! ## Membrane Dynamics
6//!
7//! ```text
8//! τ_m * dV/dt = -(V - V_rest) + R * I(t)
9//! ```
10//!
11//! When V >= θ: emit spike, V → V_reset, enter refractory period
12//!
13//! ## Features
14//!
15//! - Exponential leak with configurable time constant
16//! - Adaptive threshold (increases after spike, decays back)
17//! - Absolute refractory period
//! - Running spike-rate estimate intended for homeostatic plasticity (the rate is tracked, but no homeostatic threshold adjustment is applied yet)
19
20use super::{SimTime, Spike};
21use rayon::prelude::*;
22use std::collections::VecDeque;
23
/// Minimum population size at which neuron updates are dispatched in parallel.
///
/// Deliberately high: a single `neuron.step()` is very cheap, so for smaller
/// populations the parallel scheduling overhead outweighs any gain.
const PARALLEL_THRESHOLD: usize = 2_000;
27
/// Tunable parameters of a leaky integrate-and-fire (LIF) neuron.
///
/// Time-like fields are expressed in milliseconds and potentials in millivolts.
#[derive(Debug, Clone)]
pub struct NeuronConfig {
    /// Time constant of the membrane leak (ms).
    pub tau_membrane: f64,
    /// Potential the membrane relaxes towards without input (mV).
    pub v_rest: f64,
    /// Potential the membrane is set to right after a spike (mV).
    pub v_reset: f64,
    /// Baseline firing threshold (mV); the adapted threshold decays back to it.
    pub threshold: f64,
    /// Length of the absolute refractory period that follows a spike (ms).
    pub t_refrac: f64,
    /// Membrane resistance scaling the input current (MΩ).
    pub resistance: f64,
    /// Amount added to the threshold every time the neuron fires.
    pub threshold_adapt: f64,
    /// Time constant with which the adapted threshold decays (ms).
    pub tau_threshold: f64,
    /// Whether homeostatic plasticity is requested.
    ///
    /// Note: `LIFNeuron::step` as written does not yet derive any threshold
    /// adjustment from this flag.
    pub homeostatic: bool,
    /// Desired firing rate for homeostasis (spikes/ms).
    pub target_rate: f64,
    /// Time constant of the running spike-rate estimate (ms).
    pub tau_homeostatic: f64,
}

impl Default for NeuronConfig {
    /// Returns a normalised parameter set: rest/reset at 0, threshold at 1.
    fn default() -> Self {
        NeuronConfig {
            // Membrane dynamics
            tau_membrane: 20.0,
            resistance: 1.0,
            v_rest: 0.0,
            v_reset: 0.0,
            t_refrac: 2.0,
            // Threshold and its spike-triggered adaptation
            threshold: 1.0,
            threshold_adapt: 0.1,
            tau_threshold: 100.0,
            // Homeostasis
            homeostatic: true,
            target_rate: 0.01,
            tau_homeostatic: 1000.0,
        }
    }
}
72
/// Mutable, per-neuron quantities that evolve during simulation.
#[derive(Debug, Clone)]
pub struct NeuronState {
    /// Present membrane potential (mV).
    pub v: f64,
    /// Present firing threshold, including any spike-triggered adaptation.
    pub threshold: f64,
    /// How much of the refractory period is still left (ms).
    pub refrac_remaining: f64,
    /// Time of the most recent spike; negative infinity until the first spike.
    pub last_spike_time: f64,
    /// Exponentially smoothed spike-rate estimate used for homeostasis.
    pub spike_rate: f64,
}

impl Default for NeuronState {
    /// Returns the quiescent state: zero potential, unit threshold, not
    /// refractory, and no spike recorded so far.
    fn default() -> Self {
        NeuronState {
            v: 0.0,
            threshold: 1.0,
            refrac_remaining: 0.0,
            spike_rate: 0.0,
            last_spike_time: f64::NEG_INFINITY,
        }
    }
}
99
100/// Leaky Integrate-and-Fire neuron
101#[derive(Debug, Clone)]
102pub struct LIFNeuron {
103    /// Neuron ID
104    pub id: usize,
105    /// Configuration parameters
106    pub config: NeuronConfig,
107    /// Current state
108    pub state: NeuronState,
109}
110
111impl LIFNeuron {
112    /// Create a new LIF neuron with given ID
113    pub fn new(id: usize) -> Self {
114        Self {
115            id,
116            config: NeuronConfig::default(),
117            state: NeuronState::default(),
118        }
119    }
120
121    /// Create a new LIF neuron with custom configuration
122    pub fn with_config(id: usize, config: NeuronConfig) -> Self {
123        let mut state = NeuronState::default();
124        state.threshold = config.threshold;
125        Self { id, config, state }
126    }
127
128    /// Reset neuron state to initial conditions
129    pub fn reset(&mut self) {
130        self.state = NeuronState {
131            threshold: self.config.threshold,
132            ..NeuronState::default()
133        };
134    }
135
136    /// Integrate input current for one timestep
137    /// Returns true if neuron spiked
138    pub fn step(&mut self, current: f64, dt: f64, time: SimTime) -> bool {
139        // Handle refractory period
140        if self.state.refrac_remaining > 0.0 {
141            self.state.refrac_remaining -= dt;
142            return false;
143        }
144
145        // Membrane dynamics: τ dV/dt = -(V - V_rest) + R*I
146        let dv = (-self.state.v + self.config.v_rest + self.config.resistance * current)
147            / self.config.tau_membrane
148            * dt;
149        self.state.v += dv;
150
151        // Threshold adaptation decay
152        if self.state.threshold > self.config.threshold {
153            let d_thresh =
154                -(self.state.threshold - self.config.threshold) / self.config.tau_threshold * dt;
155            self.state.threshold += d_thresh;
156        }
157
158        // Check for spike
159        if self.state.v >= self.state.threshold {
160            // Fire!
161            self.state.v = self.config.v_reset;
162            self.state.refrac_remaining = self.config.t_refrac;
163            self.state.last_spike_time = time;
164
165            // Threshold adaptation
166            self.state.threshold += self.config.threshold_adapt;
167
168            // Update running spike rate for homeostasis using proper exponential
169            // decay based on tau_homeostatic: rate += (1 - rate) * dt / tau
170            let alpha = (dt / self.config.tau_homeostatic).min(1.0);
171            self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
172
173            return true;
174        }
175
176        // Update spike rate (decay toward 0)
177        self.state.spike_rate *= 1.0 - dt / self.config.tau_homeostatic;
178
179        // Homeostatic plasticity: adjust threshold based on firing rate
180        if self.config.homeostatic {
181            let rate_error = self.state.spike_rate - self.config.target_rate;
182            let d_base_thresh = rate_error * dt / self.config.tau_homeostatic;
183            // Only apply to base threshold, not adapted part
184            // This is a simplification - full implementation would track separately
185        }
186
187        false
188    }
189
190    /// Inject a direct spike (for input neurons)
191    pub fn inject_spike(&mut self, time: SimTime) {
192        self.state.last_spike_time = time;
193        // Use same homeostatic update as regular spikes
194        let alpha = (1.0 / self.config.tau_homeostatic).min(1.0);
195        self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
196    }
197
198    /// Get time since last spike
199    pub fn time_since_spike(&self, current_time: SimTime) -> f64 {
200        current_time - self.state.last_spike_time
201    }
202
203    /// Check if neuron is in refractory period
204    pub fn is_refractory(&self) -> bool {
205        self.state.refrac_remaining > 0.0
206    }
207
208    /// Get membrane potential
209    pub fn membrane_potential(&self) -> f64 {
210        self.state.v
211    }
212
213    /// Set membrane potential directly
214    pub fn set_membrane_potential(&mut self, v: f64) {
215        self.state.v = v;
216    }
217
218    /// Get current threshold
219    pub fn threshold(&self) -> f64 {
220        self.state.threshold
221    }
222}
223
224/// A collection of spikes over time for one neuron
225#[derive(Debug, Clone)]
226pub struct SpikeTrain {
227    /// Neuron ID
228    pub neuron_id: usize,
229    /// Spike times (sorted)
230    pub spike_times: Vec<SimTime>,
231    /// Maximum time window to keep
232    pub max_window: f64,
233}
234
235impl SpikeTrain {
236    /// Create a new empty spike train
237    pub fn new(neuron_id: usize) -> Self {
238        Self {
239            neuron_id,
240            spike_times: Vec::new(),
241            max_window: 1000.0, // 1 second default
242        }
243    }
244
245    /// Create with custom window size
246    pub fn with_window(neuron_id: usize, max_window: f64) -> Self {
247        Self {
248            neuron_id,
249            spike_times: Vec::new(),
250            max_window,
251        }
252    }
253
254    /// Record a spike at given time
255    pub fn record_spike(&mut self, time: SimTime) {
256        self.spike_times.push(time);
257
258        // Prune old spikes
259        let cutoff = time - self.max_window;
260        self.spike_times.retain(|&t| t >= cutoff);
261    }
262
263    /// Clear all recorded spikes
264    pub fn clear(&mut self) {
265        self.spike_times.clear();
266    }
267
268    /// Get number of spikes in the train
269    pub fn count(&self) -> usize {
270        self.spike_times.len()
271    }
272
273    /// Compute instantaneous spike rate (spikes/ms)
274    pub fn spike_rate(&self, window: f64) -> f64 {
275        if self.spike_times.is_empty() {
276            return 0.0;
277        }
278
279        let latest = self.spike_times.last().copied().unwrap_or(0.0);
280        let count = self
281            .spike_times
282            .iter()
283            .filter(|&&t| t >= latest - window)
284            .count();
285
286        count as f64 / window
287    }
288
289    /// Compute inter-spike interval statistics
290    pub fn mean_isi(&self) -> Option<f64> {
291        if self.spike_times.len() < 2 {
292            return None;
293        }
294
295        let mut total_isi = 0.0;
296        for i in 1..self.spike_times.len() {
297            total_isi += self.spike_times[i] - self.spike_times[i - 1];
298        }
299
300        Some(total_isi / (self.spike_times.len() - 1) as f64)
301    }
302
303    /// Get coefficient of variation of ISI
304    pub fn cv_isi(&self) -> Option<f64> {
305        let mean = self.mean_isi()?;
306        if mean == 0.0 {
307            return None;
308        }
309
310        let mut variance = 0.0;
311        for i in 1..self.spike_times.len() {
312            let isi = self.spike_times[i] - self.spike_times[i - 1];
313            variance += (isi - mean).powi(2);
314        }
315        variance /= (self.spike_times.len() - 1) as f64;
316
317        Some(variance.sqrt() / mean)
318    }
319
320    /// Convert spike train to binary pattern (temporal encoding)
321    ///
322    /// Safely handles potential overflow in bin calculation.
323    pub fn to_pattern(&self, start: SimTime, bin_size: f64, num_bins: usize) -> Vec<bool> {
324        let mut pattern = vec![false; num_bins];
325
326        // Guard against zero/negative bin_size
327        if bin_size <= 0.0 || num_bins == 0 {
328            return pattern;
329        }
330
331        let end_time = start + bin_size * num_bins as f64;
332
333        for &spike_time in &self.spike_times {
334            if spike_time >= start && spike_time < end_time {
335                // Safe bin calculation with overflow protection
336                let offset = spike_time - start;
337                let bin_f64 = offset / bin_size;
338
339                // Check for overflow before casting
340                if bin_f64 >= 0.0 && bin_f64 < num_bins as f64 {
341                    let bin = bin_f64 as usize;
342                    if bin < num_bins {
343                        pattern[bin] = true;
344                    }
345                }
346            }
347        }
348
349        pattern
350    }
351
352    /// Check if spike times are sorted (for optimization)
353    #[inline]
354    fn is_sorted(times: &[f64]) -> bool {
355        times.windows(2).all(|w| w[0] <= w[1])
356    }
357
358    /// Compute cross-correlation with another spike train
359    ///
360    /// Uses O(n log n) sliding window algorithm instead of O(n²) pairwise comparison.
361    /// Optimized to skip sorting when spike trains are already sorted (typical case).
362    /// Uses binary search for initial window position.
363    pub fn cross_correlation(&self, other: &SpikeTrain, max_lag: f64, bin_size: f64) -> Vec<f64> {
364        // Guard against invalid parameters
365        if bin_size <= 0.0 || max_lag <= 0.0 {
366            return vec![0.0];
367        }
368
369        // Safe num_bins calculation with overflow protection
370        let num_bins_f64 = 2.0 * max_lag / bin_size + 1.0;
371        let num_bins = if num_bins_f64 > 0.0 && num_bins_f64 < usize::MAX as f64 {
372            (num_bins_f64 as usize).min(100_000) // Cap at 100K bins to prevent DoS
373        } else {
374            return vec![0.0];
375        };
376
377        let mut correlation = vec![0.0; num_bins];
378
379        // Empty train optimization
380        if self.spike_times.is_empty() || other.spike_times.is_empty() {
381            return correlation;
382        }
383
384        // Avoid cloning and sorting if already sorted (typical case for spike trains)
385        let t1_owned: Vec<f64>;
386        let t2_owned: Vec<f64>;
387
388        let t1: &[f64] = if Self::is_sorted(&self.spike_times) {
389            &self.spike_times
390        } else {
391            t1_owned = {
392                let mut v = self.spike_times.clone();
393                v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
394                v
395            };
396            &t1_owned
397        };
398
399        let t2: &[f64] = if Self::is_sorted(&other.spike_times) {
400            &other.spike_times
401        } else {
402            t2_owned = {
403                let mut v = other.spike_times.clone();
404                v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405                v
406            };
407            &t2_owned
408        };
409
410        // Use binary search for first spike's window start
411        let first_lower = t1[0] - max_lag;
412        let mut window_start = t2.partition_point(|&x| x < first_lower);
413
414        for &t1_spike in t1 {
415            let lower_bound = t1_spike - max_lag;
416            let upper_bound = t1_spike + max_lag;
417
418            // Advance window_start past spikes too early
419            while window_start < t2.len() && t2[window_start] < lower_bound {
420                window_start += 1;
421            }
422
423            // Count spikes within window
424            let mut j = window_start;
425            while j < t2.len() && t2[j] <= upper_bound {
426                let lag = t1_spike - t2[j];
427
428                // Safe bin calculation (inlined for performance)
429                let bin = ((lag + max_lag) / bin_size) as usize;
430                if bin < num_bins {
431                    correlation[bin] += 1.0;
432                }
433                j += 1;
434            }
435        }
436
437        // Normalize by geometric mean of spike counts
438        let norm = ((self.count() * other.count()) as f64).sqrt();
439        if norm > 0.0 {
440            let inv_norm = 1.0 / norm;
441            for c in &mut correlation {
442                *c *= inv_norm;
443            }
444        }
445
446        correlation
447    }
448}
449
450/// Population of LIF neurons
451#[derive(Debug, Clone)]
452pub struct NeuronPopulation {
453    /// All neurons in the population
454    pub neurons: Vec<LIFNeuron>,
455    /// Spike trains for each neuron
456    pub spike_trains: Vec<SpikeTrain>,
457    /// Current simulation time
458    pub time: SimTime,
459}
460
461impl NeuronPopulation {
462    /// Create a new population with n neurons
463    pub fn new(n: usize) -> Self {
464        let neurons: Vec<_> = (0..n).map(|i| LIFNeuron::new(i)).collect();
465        let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
466
467        Self {
468            neurons,
469            spike_trains,
470            time: 0.0,
471        }
472    }
473
474    /// Create population with custom configuration
475    pub fn with_config(n: usize, config: NeuronConfig) -> Self {
476        let neurons: Vec<_> = (0..n)
477            .map(|i| LIFNeuron::with_config(i, config.clone()))
478            .collect();
479        let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
480
481        Self {
482            neurons,
483            spike_trains,
484            time: 0.0,
485        }
486    }
487
488    /// Get number of neurons
489    pub fn size(&self) -> usize {
490        self.neurons.len()
491    }
492
493    /// Step all neurons with given currents
494    ///
495    /// Uses parallel processing for large populations (>200 neurons).
496    pub fn step(&mut self, currents: &[f64], dt: f64) -> Vec<Spike> {
497        self.time += dt;
498        let time = self.time;
499
500        if self.neurons.len() >= PARALLEL_THRESHOLD {
501            // Parallel path: compute neuron updates in parallel
502            let spike_flags: Vec<bool> = self
503                .neurons
504                .par_iter_mut()
505                .enumerate()
506                .map(|(i, neuron)| {
507                    let current = currents.get(i).copied().unwrap_or(0.0);
508                    neuron.step(current, dt, time)
509                })
510                .collect();
511
512            // Sequential: collect spikes and record to trains
513            let mut spikes = Vec::new();
514            for (i, &spiked) in spike_flags.iter().enumerate() {
515                if spiked {
516                    spikes.push(Spike { neuron_id: i, time });
517                    self.spike_trains[i].record_spike(time);
518                }
519            }
520            spikes
521        } else {
522            // Sequential path for small populations (avoid parallel overhead)
523            let mut spikes = Vec::new();
524            for (i, neuron) in self.neurons.iter_mut().enumerate() {
525                let current = currents.get(i).copied().unwrap_or(0.0);
526                if neuron.step(current, dt, time) {
527                    spikes.push(Spike { neuron_id: i, time });
528                    self.spike_trains[i].record_spike(time);
529                }
530            }
531            spikes
532        }
533    }
534
535    /// Reset all neurons
536    pub fn reset(&mut self) {
537        self.time = 0.0;
538        for neuron in &mut self.neurons {
539            neuron.reset();
540        }
541        for train in &mut self.spike_trains {
542            train.clear();
543        }
544    }
545
546    /// Get population spike rate
547    pub fn population_rate(&self, window: f64) -> f64 {
548        let total: f64 = self.spike_trains.iter().map(|t| t.spike_rate(window)).sum();
549        total / self.neurons.len() as f64
550    }
551
552    /// Compute population synchrony
553    pub fn synchrony(&self, window: f64) -> f64 {
554        // Collect recent spikes
555        let mut all_spikes = Vec::new();
556        let cutoff = self.time - window;
557
558        for train in &self.spike_trains {
559            for &t in &train.spike_times {
560                if t >= cutoff {
561                    all_spikes.push(Spike {
562                        neuron_id: train.neuron_id,
563                        time: t,
564                    });
565                }
566            }
567        }
568
569        super::compute_synchrony(&all_spikes, window / 10.0)
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh neuron keeps its ID and starts at zero membrane potential.
    #[test]
    fn test_lif_neuron_creation() {
        let fresh = LIFNeuron::new(0);
        assert_eq!(fresh.id, 0);
        assert_eq!(fresh.state.v, 0.0);
    }

    /// A strong constant current must make the neuron fire within 100 steps,
    /// after which it is refractory.
    #[test]
    fn test_lif_neuron_spike() {
        let mut neuron = LIFNeuron::new(0);

        // `any` stops at the first step that reports a spike.
        let spiked = (0..100).any(|tick| neuron.step(2.0, 1.0, tick as f64));

        assert!(spiked);
        assert!(neuron.is_refractory());
    }

    /// Recorded spikes are counted and yield the expected mean interval.
    #[test]
    fn test_spike_train() {
        let mut train = SpikeTrain::new(0);
        for t in [10.0, 20.0, 30.0] {
            train.record_spike(t);
        }

        assert_eq!(train.count(), 3);

        let mean_isi = train.mean_isi().unwrap();
        assert!((mean_isi - 10.0).abs() < 0.001);
    }

    /// A uniformly driven population produces at least one spike in 100 steps.
    #[test]
    fn test_neuron_population() {
        let mut pop = NeuronPopulation::new(100);
        let currents = vec![1.5; 100];

        let total_spikes: usize = (0..100).map(|_| pop.step(&currents, 1.0).len()).sum();

        // Should have some spikes after 100ms with current of 1.5
        assert!(total_spikes > 0);
    }

    /// Spikes land in the bins that contain their timestamps.
    #[test]
    fn test_spike_train_pattern() {
        let mut train = SpikeTrain::new(0);
        for t in [1.0, 3.0, 7.0] {
            train.record_spike(t);
        }

        let pattern = train.to_pattern(0.0, 1.0, 10);
        assert_eq!(pattern.len(), 10);
        for bin in [1, 3, 7] {
            assert!(pattern[bin]); // Spike at t = bin
        }
        assert!(!pattern[0]); // No spike at t=0
    }
}