1use super::{SimTime, Spike};
21use rayon::prelude::*;
22use std::collections::VecDeque;
23
/// Population size at which [`NeuronPopulation::step`] switches from a sequential
/// loop to a rayon parallel iterator (presumably because smaller populations do not
/// amortise the thread-scheduling overhead — TODO confirm with a benchmark).
const PARALLEL_THRESHOLD: usize = 2_000;
27
/// Static parameters of a [`LIFNeuron`].
///
/// All time constants and durations are expressed in the same unit as the `dt`
/// passed to `LIFNeuron::step` (the defaults look like milliseconds — TODO confirm).
#[derive(Debug, Clone)]
pub struct NeuronConfig {
    /// Membrane time constant `tau` in `tau * dv/dt = -(v - v_rest) + R * I`.
    pub tau_membrane: f64,
    /// Potential the membrane relaxes towards in the absence of input.
    pub v_rest: f64,
    /// Potential the membrane is set to immediately after a spike.
    pub v_reset: f64,
    /// Base firing threshold; the adaptive threshold relaxes back to this value.
    pub threshold: f64,
    /// Length of the refractory period that follows each spike.
    pub t_refrac: f64,
    /// Membrane resistance `R`, scaling the injected current.
    pub resistance: f64,
    /// Amount added to the adaptive threshold every time the neuron fires.
    pub threshold_adapt: f64,
    /// Time constant of the adaptive threshold's decay back to `threshold`.
    pub tau_threshold: f64,
    /// Enables homeostatic threshold regulation.
    ///
    /// NOTE(review): in the visible `LIFNeuron::step` this flag does not change the
    /// neuron's dynamics — confirm the intended behaviour before relying on it.
    pub homeostatic: bool,
    /// Target value for the spike-rate estimate, intended for homeostatic regulation.
    pub target_rate: f64,
    /// Time constant of the exponential moving estimate of the firing rate.
    pub tau_homeostatic: f64,
}

impl Default for NeuronConfig {
    /// Membrane constant 20, unit threshold with 0.1 adaptation (decaying with
    /// constant 100), refractory period 2, and homeostasis enabled with target
    /// rate 0.01 and time constant 1000.
    fn default() -> Self {
        NeuronConfig {
            // Membrane integration.
            tau_membrane: 20.0,
            resistance: 1.0,
            v_rest: 0.0,
            v_reset: 0.0,
            t_refrac: 2.0,
            // Threshold and its spike-triggered adaptation.
            threshold: 1.0,
            threshold_adapt: 0.1,
            tau_threshold: 100.0,
            // Firing-rate tracking / homeostasis.
            homeostatic: true,
            target_rate: 0.01,
            tau_homeostatic: 1000.0,
        }
    }
}
72
/// Mutable dynamic state of a [`LIFNeuron`], updated on every simulation step.
#[derive(Debug, Clone)]
pub struct NeuronState {
    /// Membrane potential.
    pub v: f64,
    /// Current (adaptive) firing threshold.
    pub threshold: f64,
    /// Refractory time still to elapse; the neuron integrates no input while this is > 0.
    pub refrac_remaining: f64,
    /// Time of the most recent spike, or negative infinity if it has never fired.
    pub last_spike_time: f64,
    /// Exponential moving estimate of the neuron's firing rate.
    pub spike_rate: f64,
}

impl Default for NeuronState {
    /// A silent neuron at rest: zero potential, unit threshold, not refractory,
    /// no spike history and a zero rate estimate.
    fn default() -> Self {
        let never_fired = f64::NEG_INFINITY;
        NeuronState {
            spike_rate: 0.0,
            last_spike_time: never_fired,
            refrac_remaining: 0.0,
            threshold: 1.0,
            v: 0.0,
        }
    }
}
99
100#[derive(Debug, Clone)]
102pub struct LIFNeuron {
103 pub id: usize,
105 pub config: NeuronConfig,
107 pub state: NeuronState,
109}
110
111impl LIFNeuron {
112 pub fn new(id: usize) -> Self {
114 Self {
115 id,
116 config: NeuronConfig::default(),
117 state: NeuronState::default(),
118 }
119 }
120
121 pub fn with_config(id: usize, config: NeuronConfig) -> Self {
123 let mut state = NeuronState::default();
124 state.threshold = config.threshold;
125 Self { id, config, state }
126 }
127
128 pub fn reset(&mut self) {
130 self.state = NeuronState {
131 threshold: self.config.threshold,
132 ..NeuronState::default()
133 };
134 }
135
136 pub fn step(&mut self, current: f64, dt: f64, time: SimTime) -> bool {
139 if self.state.refrac_remaining > 0.0 {
141 self.state.refrac_remaining -= dt;
142 return false;
143 }
144
145 let dv = (-self.state.v + self.config.v_rest + self.config.resistance * current)
147 / self.config.tau_membrane
148 * dt;
149 self.state.v += dv;
150
151 if self.state.threshold > self.config.threshold {
153 let d_thresh =
154 -(self.state.threshold - self.config.threshold) / self.config.tau_threshold * dt;
155 self.state.threshold += d_thresh;
156 }
157
158 if self.state.v >= self.state.threshold {
160 self.state.v = self.config.v_reset;
162 self.state.refrac_remaining = self.config.t_refrac;
163 self.state.last_spike_time = time;
164
165 self.state.threshold += self.config.threshold_adapt;
167
168 let alpha = (dt / self.config.tau_homeostatic).min(1.0);
171 self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
172
173 return true;
174 }
175
176 self.state.spike_rate *= 1.0 - dt / self.config.tau_homeostatic;
178
179 if self.config.homeostatic {
181 let rate_error = self.state.spike_rate - self.config.target_rate;
182 let d_base_thresh = rate_error * dt / self.config.tau_homeostatic;
183 }
186
187 false
188 }
189
190 pub fn inject_spike(&mut self, time: SimTime) {
192 self.state.last_spike_time = time;
193 let alpha = (1.0 / self.config.tau_homeostatic).min(1.0);
195 self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
196 }
197
198 pub fn time_since_spike(&self, current_time: SimTime) -> f64 {
200 current_time - self.state.last_spike_time
201 }
202
203 pub fn is_refractory(&self) -> bool {
205 self.state.refrac_remaining > 0.0
206 }
207
208 pub fn membrane_potential(&self) -> f64 {
210 self.state.v
211 }
212
213 pub fn set_membrane_potential(&mut self, v: f64) {
215 self.state.v = v;
216 }
217
218 pub fn threshold(&self) -> f64 {
220 self.state.threshold
221 }
222}
223
224#[derive(Debug, Clone)]
226pub struct SpikeTrain {
227 pub neuron_id: usize,
229 pub spike_times: Vec<SimTime>,
231 pub max_window: f64,
233}
234
235impl SpikeTrain {
236 pub fn new(neuron_id: usize) -> Self {
238 Self {
239 neuron_id,
240 spike_times: Vec::new(),
241 max_window: 1000.0, }
243 }
244
245 pub fn with_window(neuron_id: usize, max_window: f64) -> Self {
247 Self {
248 neuron_id,
249 spike_times: Vec::new(),
250 max_window,
251 }
252 }
253
254 pub fn record_spike(&mut self, time: SimTime) {
256 self.spike_times.push(time);
257
258 let cutoff = time - self.max_window;
260 self.spike_times.retain(|&t| t >= cutoff);
261 }
262
263 pub fn clear(&mut self) {
265 self.spike_times.clear();
266 }
267
268 pub fn count(&self) -> usize {
270 self.spike_times.len()
271 }
272
273 pub fn spike_rate(&self, window: f64) -> f64 {
275 if self.spike_times.is_empty() {
276 return 0.0;
277 }
278
279 let latest = self.spike_times.last().copied().unwrap_or(0.0);
280 let count = self
281 .spike_times
282 .iter()
283 .filter(|&&t| t >= latest - window)
284 .count();
285
286 count as f64 / window
287 }
288
289 pub fn mean_isi(&self) -> Option<f64> {
291 if self.spike_times.len() < 2 {
292 return None;
293 }
294
295 let mut total_isi = 0.0;
296 for i in 1..self.spike_times.len() {
297 total_isi += self.spike_times[i] - self.spike_times[i - 1];
298 }
299
300 Some(total_isi / (self.spike_times.len() - 1) as f64)
301 }
302
303 pub fn cv_isi(&self) -> Option<f64> {
305 let mean = self.mean_isi()?;
306 if mean == 0.0 {
307 return None;
308 }
309
310 let mut variance = 0.0;
311 for i in 1..self.spike_times.len() {
312 let isi = self.spike_times[i] - self.spike_times[i - 1];
313 variance += (isi - mean).powi(2);
314 }
315 variance /= (self.spike_times.len() - 1) as f64;
316
317 Some(variance.sqrt() / mean)
318 }
319
320 pub fn to_pattern(&self, start: SimTime, bin_size: f64, num_bins: usize) -> Vec<bool> {
324 let mut pattern = vec![false; num_bins];
325
326 if bin_size <= 0.0 || num_bins == 0 {
328 return pattern;
329 }
330
331 let end_time = start + bin_size * num_bins as f64;
332
333 for &spike_time in &self.spike_times {
334 if spike_time >= start && spike_time < end_time {
335 let offset = spike_time - start;
337 let bin_f64 = offset / bin_size;
338
339 if bin_f64 >= 0.0 && bin_f64 < num_bins as f64 {
341 let bin = bin_f64 as usize;
342 if bin < num_bins {
343 pattern[bin] = true;
344 }
345 }
346 }
347 }
348
349 pattern
350 }
351
352 #[inline]
354 fn is_sorted(times: &[f64]) -> bool {
355 times.windows(2).all(|w| w[0] <= w[1])
356 }
357
358 pub fn cross_correlation(&self, other: &SpikeTrain, max_lag: f64, bin_size: f64) -> Vec<f64> {
364 if bin_size <= 0.0 || max_lag <= 0.0 {
366 return vec![0.0];
367 }
368
369 let num_bins_f64 = 2.0 * max_lag / bin_size + 1.0;
371 let num_bins = if num_bins_f64 > 0.0 && num_bins_f64 < usize::MAX as f64 {
372 (num_bins_f64 as usize).min(100_000) } else {
374 return vec![0.0];
375 };
376
377 let mut correlation = vec![0.0; num_bins];
378
379 if self.spike_times.is_empty() || other.spike_times.is_empty() {
381 return correlation;
382 }
383
384 let t1_owned: Vec<f64>;
386 let t2_owned: Vec<f64>;
387
388 let t1: &[f64] = if Self::is_sorted(&self.spike_times) {
389 &self.spike_times
390 } else {
391 t1_owned = {
392 let mut v = self.spike_times.clone();
393 v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
394 v
395 };
396 &t1_owned
397 };
398
399 let t2: &[f64] = if Self::is_sorted(&other.spike_times) {
400 &other.spike_times
401 } else {
402 t2_owned = {
403 let mut v = other.spike_times.clone();
404 v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405 v
406 };
407 &t2_owned
408 };
409
410 let first_lower = t1[0] - max_lag;
412 let mut window_start = t2.partition_point(|&x| x < first_lower);
413
414 for &t1_spike in t1 {
415 let lower_bound = t1_spike - max_lag;
416 let upper_bound = t1_spike + max_lag;
417
418 while window_start < t2.len() && t2[window_start] < lower_bound {
420 window_start += 1;
421 }
422
423 let mut j = window_start;
425 while j < t2.len() && t2[j] <= upper_bound {
426 let lag = t1_spike - t2[j];
427
428 let bin = ((lag + max_lag) / bin_size) as usize;
430 if bin < num_bins {
431 correlation[bin] += 1.0;
432 }
433 j += 1;
434 }
435 }
436
437 let norm = ((self.count() * other.count()) as f64).sqrt();
439 if norm > 0.0 {
440 let inv_norm = 1.0 / norm;
441 for c in &mut correlation {
442 *c *= inv_norm;
443 }
444 }
445
446 correlation
447 }
448}
449
450#[derive(Debug, Clone)]
452pub struct NeuronPopulation {
453 pub neurons: Vec<LIFNeuron>,
455 pub spike_trains: Vec<SpikeTrain>,
457 pub time: SimTime,
459}
460
461impl NeuronPopulation {
462 pub fn new(n: usize) -> Self {
464 let neurons: Vec<_> = (0..n).map(|i| LIFNeuron::new(i)).collect();
465 let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
466
467 Self {
468 neurons,
469 spike_trains,
470 time: 0.0,
471 }
472 }
473
474 pub fn with_config(n: usize, config: NeuronConfig) -> Self {
476 let neurons: Vec<_> = (0..n)
477 .map(|i| LIFNeuron::with_config(i, config.clone()))
478 .collect();
479 let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
480
481 Self {
482 neurons,
483 spike_trains,
484 time: 0.0,
485 }
486 }
487
488 pub fn size(&self) -> usize {
490 self.neurons.len()
491 }
492
493 pub fn step(&mut self, currents: &[f64], dt: f64) -> Vec<Spike> {
497 self.time += dt;
498 let time = self.time;
499
500 if self.neurons.len() >= PARALLEL_THRESHOLD {
501 let spike_flags: Vec<bool> = self
503 .neurons
504 .par_iter_mut()
505 .enumerate()
506 .map(|(i, neuron)| {
507 let current = currents.get(i).copied().unwrap_or(0.0);
508 neuron.step(current, dt, time)
509 })
510 .collect();
511
512 let mut spikes = Vec::new();
514 for (i, &spiked) in spike_flags.iter().enumerate() {
515 if spiked {
516 spikes.push(Spike { neuron_id: i, time });
517 self.spike_trains[i].record_spike(time);
518 }
519 }
520 spikes
521 } else {
522 let mut spikes = Vec::new();
524 for (i, neuron) in self.neurons.iter_mut().enumerate() {
525 let current = currents.get(i).copied().unwrap_or(0.0);
526 if neuron.step(current, dt, time) {
527 spikes.push(Spike { neuron_id: i, time });
528 self.spike_trains[i].record_spike(time);
529 }
530 }
531 spikes
532 }
533 }
534
535 pub fn reset(&mut self) {
537 self.time = 0.0;
538 for neuron in &mut self.neurons {
539 neuron.reset();
540 }
541 for train in &mut self.spike_trains {
542 train.clear();
543 }
544 }
545
546 pub fn population_rate(&self, window: f64) -> f64 {
548 let total: f64 = self.spike_trains.iter().map(|t| t.spike_rate(window)).sum();
549 total / self.neurons.len() as f64
550 }
551
552 pub fn synchrony(&self, window: f64) -> f64 {
554 let mut all_spikes = Vec::new();
556 let cutoff = self.time - window;
557
558 for train in &self.spike_trains {
559 for &t in &train.spike_times {
560 if t >= cutoff {
561 all_spikes.push(Spike {
562 neuron_id: train.neuron_id,
563 time: t,
564 });
565 }
566 }
567 }
568
569 super::compute_synchrony(&all_spikes, window / 10.0)
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lif_neuron_creation() {
        let neuron = LIFNeuron::new(0);
        assert_eq!(neuron.id, 0);
        assert_eq!(neuron.state.v, 0.0);
    }

    #[test]
    fn test_lif_neuron_spike() {
        let mut neuron = LIFNeuron::new(0);

        // A constant drive of 2.0 pulls v towards 2.0, above the default threshold of 1.0.
        let spiked = (0..100).any(|i| neuron.step(2.0, 1.0, i as f64));

        assert!(spiked);
        assert!(neuron.is_refractory());
    }

    #[test]
    fn test_spike_train() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(10.0);
        train.record_spike(20.0);
        train.record_spike(30.0);

        assert_eq!(train.count(), 3);

        let mean_isi = train.mean_isi().unwrap();
        assert!((mean_isi - 10.0).abs() < 0.001);
    }

    #[test]
    fn test_spike_train_cv_isi_regular() {
        let mut train = SpikeTrain::new(0);
        for t in [10.0, 20.0, 30.0, 40.0] {
            train.record_spike(t);
        }
        // Perfectly regular firing has zero ISI variability.
        assert_eq!(train.cv_isi(), Some(0.0));
    }

    #[test]
    fn test_neuron_population() {
        let mut pop = NeuronPopulation::new(100);

        let currents = vec![1.5; 100];

        let mut total_spikes = 0;
        for _ in 0..100 {
            let spikes = pop.step(&currents, 1.0);
            total_spikes += spikes.len();
        }

        assert!(total_spikes > 0);
    }

    #[test]
    fn test_spike_train_pattern() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(1.0);
        train.record_spike(3.0);
        train.record_spike(7.0);

        let pattern = train.to_pattern(0.0, 1.0, 10);
        assert_eq!(pattern.len(), 10);
        // Spike at t lands in bin floor(t) for start 0 and unit bins.
        assert!(pattern[1]);
        assert!(pattern[3]);
        assert!(pattern[7]);
        // Bins without spikes stay empty.
        assert!(!pattern[0]);
    }

    #[test]
    fn test_cross_correlation_zero_lag() {
        let mut a = SpikeTrain::new(0);
        a.record_spike(10.0);
        let mut b = SpikeTrain::new(1);
        b.record_spike(10.0);

        let corr = a.cross_correlation(&b, 5.0, 1.0);
        assert_eq!(corr.len(), 11);
        // Coincident spikes land in the middle (zero-lag) bin.
        assert_eq!(corr[5], 1.0);
        assert_eq!(corr.iter().sum::<f64>(), 1.0);
    }
}