1use super::{SimTime, Spike};
21use rayon::prelude::*;
22use std::collections::VecDeque;
23
/// Population size at or above which [`NeuronPopulation::step`] advances its
/// neurons with a parallel iterator instead of a sequential loop.
const PARALLEL_THRESHOLD: usize = 2000;
27
/// Tunable parameters of a leaky integrate-and-fire neuron.
///
/// All time-like fields are expressed in the same unit as the `dt` passed to
/// `LIFNeuron::step`; the code does not fix a particular unit.
#[derive(Debug, Clone)]
pub struct NeuronConfig {
    /// Membrane time constant; divides the membrane drive in each update.
    pub tau_membrane: f64,
    /// Resting potential the membrane relaxes toward when there is no input.
    pub v_rest: f64,
    /// Potential the membrane is set to immediately after a spike.
    pub v_reset: f64,
    /// Baseline firing threshold; the adaptive threshold decays back to it.
    pub threshold: f64,
    /// Refractory duration loaded into the state after each spike.
    pub t_refrac: f64,
    /// Gain applied to the input current in the membrane update.
    pub resistance: f64,
    /// Amount added to the adaptive threshold every time the neuron spikes.
    pub threshold_adapt: f64,
    /// Time constant with which the adaptive threshold decays to baseline.
    pub tau_threshold: f64,
    /// Whether homeostatic regulation toward `target_rate` is requested.
    pub homeostatic: bool,
    /// Firing rate the homeostatic mechanism is meant to aim for.
    pub target_rate: f64,
    /// Time constant of the running spike-rate estimate.
    pub tau_homeostatic: f64,
}

impl Default for NeuronConfig {
    /// Builds the stock parameter set used by `LIFNeuron::new`.
    fn default() -> Self {
        NeuronConfig {
            // Membrane dynamics.
            tau_membrane: 20.0,
            resistance: 1.0,
            v_rest: 0.0,
            v_reset: 0.0,
            t_refrac: 2.0,

            // Adaptive threshold.
            threshold: 1.0,
            threshold_adapt: 0.1,
            tau_threshold: 100.0,

            // Homeostasis / rate tracking.
            homeostatic: true,
            target_rate: 0.01,
            tau_homeostatic: 1000.0,
        }
    }
}
72
/// Dynamic variables of a single neuron, advanced by `LIFNeuron::step`.
#[derive(Debug, Clone)]
pub struct NeuronState {
    /// Current membrane potential.
    pub v: f64,
    /// Current (adaptive) firing threshold that `v` is compared against.
    pub threshold: f64,
    /// Refractory time still to elapse; the neuron is refractory while `> 0`.
    pub refrac_remaining: f64,
    /// Time of the most recent spike, or negative infinity if none occurred.
    pub last_spike_time: f64,
    /// Exponentially smoothed estimate of the firing rate.
    pub spike_rate: f64,
}

impl Default for NeuronState {
    /// Returns the state of a neuron that is at rest and has never fired.
    fn default() -> Self {
        NeuronState {
            v: 0.0,
            threshold: 1.0,
            // Nothing has happened yet: no pending refractory time, no spike
            // on record, and an empty rate estimate.
            refrac_remaining: 0.0,
            spike_rate: 0.0,
            last_spike_time: f64::NEG_INFINITY,
        }
    }
}
99
100#[derive(Debug, Clone)]
102pub struct LIFNeuron {
103 pub id: usize,
105 pub config: NeuronConfig,
107 pub state: NeuronState,
109}
110
111impl LIFNeuron {
112 pub fn new(id: usize) -> Self {
114 Self {
115 id,
116 config: NeuronConfig::default(),
117 state: NeuronState::default(),
118 }
119 }
120
121 pub fn with_config(id: usize, config: NeuronConfig) -> Self {
123 let mut state = NeuronState::default();
124 state.threshold = config.threshold;
125 Self { id, config, state }
126 }
127
128 pub fn reset(&mut self) {
130 self.state = NeuronState {
131 threshold: self.config.threshold,
132 ..NeuronState::default()
133 };
134 }
135
136 pub fn step(&mut self, current: f64, dt: f64, time: SimTime) -> bool {
139 if self.state.refrac_remaining > 0.0 {
141 self.state.refrac_remaining -= dt;
142 return false;
143 }
144
145 let dv = (-self.state.v + self.config.v_rest + self.config.resistance * current)
147 / self.config.tau_membrane * dt;
148 self.state.v += dv;
149
150 if self.state.threshold > self.config.threshold {
152 let d_thresh = -(self.state.threshold - self.config.threshold)
153 / self.config.tau_threshold * dt;
154 self.state.threshold += d_thresh;
155 }
156
157 if self.state.v >= self.state.threshold {
159 self.state.v = self.config.v_reset;
161 self.state.refrac_remaining = self.config.t_refrac;
162 self.state.last_spike_time = time;
163
164 self.state.threshold += self.config.threshold_adapt;
166
167 let alpha = (dt / self.config.tau_homeostatic).min(1.0);
170 self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
171
172 return true;
173 }
174
175 self.state.spike_rate *= 1.0 - dt / self.config.tau_homeostatic;
177
178 if self.config.homeostatic {
180 let rate_error = self.state.spike_rate - self.config.target_rate;
181 let d_base_thresh = rate_error * dt / self.config.tau_homeostatic;
182 }
185
186 false
187 }
188
189 pub fn inject_spike(&mut self, time: SimTime) {
191 self.state.last_spike_time = time;
192 let alpha = (1.0 / self.config.tau_homeostatic).min(1.0);
194 self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
195 }
196
197 pub fn time_since_spike(&self, current_time: SimTime) -> f64 {
199 current_time - self.state.last_spike_time
200 }
201
202 pub fn is_refractory(&self) -> bool {
204 self.state.refrac_remaining > 0.0
205 }
206
207 pub fn membrane_potential(&self) -> f64 {
209 self.state.v
210 }
211
212 pub fn set_membrane_potential(&mut self, v: f64) {
214 self.state.v = v;
215 }
216
217 pub fn threshold(&self) -> f64 {
219 self.state.threshold
220 }
221}
222
223#[derive(Debug, Clone)]
225pub struct SpikeTrain {
226 pub neuron_id: usize,
228 pub spike_times: Vec<SimTime>,
230 pub max_window: f64,
232}
233
234impl SpikeTrain {
235 pub fn new(neuron_id: usize) -> Self {
237 Self {
238 neuron_id,
239 spike_times: Vec::new(),
240 max_window: 1000.0, }
242 }
243
244 pub fn with_window(neuron_id: usize, max_window: f64) -> Self {
246 Self {
247 neuron_id,
248 spike_times: Vec::new(),
249 max_window,
250 }
251 }
252
253 pub fn record_spike(&mut self, time: SimTime) {
255 self.spike_times.push(time);
256
257 let cutoff = time - self.max_window;
259 self.spike_times.retain(|&t| t >= cutoff);
260 }
261
262 pub fn clear(&mut self) {
264 self.spike_times.clear();
265 }
266
267 pub fn count(&self) -> usize {
269 self.spike_times.len()
270 }
271
272 pub fn spike_rate(&self, window: f64) -> f64 {
274 if self.spike_times.is_empty() {
275 return 0.0;
276 }
277
278 let latest = self.spike_times.last().copied().unwrap_or(0.0);
279 let count = self.spike_times.iter()
280 .filter(|&&t| t >= latest - window)
281 .count();
282
283 count as f64 / window
284 }
285
286 pub fn mean_isi(&self) -> Option<f64> {
288 if self.spike_times.len() < 2 {
289 return None;
290 }
291
292 let mut total_isi = 0.0;
293 for i in 1..self.spike_times.len() {
294 total_isi += self.spike_times[i] - self.spike_times[i - 1];
295 }
296
297 Some(total_isi / (self.spike_times.len() - 1) as f64)
298 }
299
300 pub fn cv_isi(&self) -> Option<f64> {
302 let mean = self.mean_isi()?;
303 if mean == 0.0 {
304 return None;
305 }
306
307 let mut variance = 0.0;
308 for i in 1..self.spike_times.len() {
309 let isi = self.spike_times[i] - self.spike_times[i - 1];
310 variance += (isi - mean).powi(2);
311 }
312 variance /= (self.spike_times.len() - 1) as f64;
313
314 Some(variance.sqrt() / mean)
315 }
316
317 pub fn to_pattern(&self, start: SimTime, bin_size: f64, num_bins: usize) -> Vec<bool> {
321 let mut pattern = vec![false; num_bins];
322
323 if bin_size <= 0.0 || num_bins == 0 {
325 return pattern;
326 }
327
328 let end_time = start + bin_size * num_bins as f64;
329
330 for &spike_time in &self.spike_times {
331 if spike_time >= start && spike_time < end_time {
332 let offset = spike_time - start;
334 let bin_f64 = offset / bin_size;
335
336 if bin_f64 >= 0.0 && bin_f64 < num_bins as f64 {
338 let bin = bin_f64 as usize;
339 if bin < num_bins {
340 pattern[bin] = true;
341 }
342 }
343 }
344 }
345
346 pattern
347 }
348
349 #[inline]
351 fn is_sorted(times: &[f64]) -> bool {
352 times.windows(2).all(|w| w[0] <= w[1])
353 }
354
355 pub fn cross_correlation(&self, other: &SpikeTrain, max_lag: f64, bin_size: f64) -> Vec<f64> {
361 if bin_size <= 0.0 || max_lag <= 0.0 {
363 return vec![0.0];
364 }
365
366 let num_bins_f64 = 2.0 * max_lag / bin_size + 1.0;
368 let num_bins = if num_bins_f64 > 0.0 && num_bins_f64 < usize::MAX as f64 {
369 (num_bins_f64 as usize).min(100_000) } else {
371 return vec![0.0];
372 };
373
374 let mut correlation = vec![0.0; num_bins];
375
376 if self.spike_times.is_empty() || other.spike_times.is_empty() {
378 return correlation;
379 }
380
381 let t1_owned: Vec<f64>;
383 let t2_owned: Vec<f64>;
384
385 let t1: &[f64] = if Self::is_sorted(&self.spike_times) {
386 &self.spike_times
387 } else {
388 t1_owned = {
389 let mut v = self.spike_times.clone();
390 v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
391 v
392 };
393 &t1_owned
394 };
395
396 let t2: &[f64] = if Self::is_sorted(&other.spike_times) {
397 &other.spike_times
398 } else {
399 t2_owned = {
400 let mut v = other.spike_times.clone();
401 v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
402 v
403 };
404 &t2_owned
405 };
406
407 let first_lower = t1[0] - max_lag;
409 let mut window_start = t2.partition_point(|&x| x < first_lower);
410
411 for &t1_spike in t1 {
412 let lower_bound = t1_spike - max_lag;
413 let upper_bound = t1_spike + max_lag;
414
415 while window_start < t2.len() && t2[window_start] < lower_bound {
417 window_start += 1;
418 }
419
420 let mut j = window_start;
422 while j < t2.len() && t2[j] <= upper_bound {
423 let lag = t1_spike - t2[j];
424
425 let bin = ((lag + max_lag) / bin_size) as usize;
427 if bin < num_bins {
428 correlation[bin] += 1.0;
429 }
430 j += 1;
431 }
432 }
433
434 let norm = ((self.count() * other.count()) as f64).sqrt();
436 if norm > 0.0 {
437 let inv_norm = 1.0 / norm;
438 for c in &mut correlation {
439 *c *= inv_norm;
440 }
441 }
442
443 correlation
444 }
445}
446
447#[derive(Debug, Clone)]
449pub struct NeuronPopulation {
450 pub neurons: Vec<LIFNeuron>,
452 pub spike_trains: Vec<SpikeTrain>,
454 pub time: SimTime,
456}
457
458impl NeuronPopulation {
459 pub fn new(n: usize) -> Self {
461 let neurons: Vec<_> = (0..n).map(|i| LIFNeuron::new(i)).collect();
462 let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
463
464 Self {
465 neurons,
466 spike_trains,
467 time: 0.0,
468 }
469 }
470
471 pub fn with_config(n: usize, config: NeuronConfig) -> Self {
473 let neurons: Vec<_> = (0..n)
474 .map(|i| LIFNeuron::with_config(i, config.clone()))
475 .collect();
476 let spike_trains: Vec<_> = (0..n).map(|i| SpikeTrain::new(i)).collect();
477
478 Self {
479 neurons,
480 spike_trains,
481 time: 0.0,
482 }
483 }
484
485 pub fn size(&self) -> usize {
487 self.neurons.len()
488 }
489
490 pub fn step(&mut self, currents: &[f64], dt: f64) -> Vec<Spike> {
494 self.time += dt;
495 let time = self.time;
496
497 if self.neurons.len() >= PARALLEL_THRESHOLD {
498 let spike_flags: Vec<bool> = self.neurons
500 .par_iter_mut()
501 .enumerate()
502 .map(|(i, neuron)| {
503 let current = currents.get(i).copied().unwrap_or(0.0);
504 neuron.step(current, dt, time)
505 })
506 .collect();
507
508 let mut spikes = Vec::new();
510 for (i, &spiked) in spike_flags.iter().enumerate() {
511 if spiked {
512 spikes.push(Spike { neuron_id: i, time });
513 self.spike_trains[i].record_spike(time);
514 }
515 }
516 spikes
517 } else {
518 let mut spikes = Vec::new();
520 for (i, neuron) in self.neurons.iter_mut().enumerate() {
521 let current = currents.get(i).copied().unwrap_or(0.0);
522 if neuron.step(current, dt, time) {
523 spikes.push(Spike { neuron_id: i, time });
524 self.spike_trains[i].record_spike(time);
525 }
526 }
527 spikes
528 }
529 }
530
531 pub fn reset(&mut self) {
533 self.time = 0.0;
534 for neuron in &mut self.neurons {
535 neuron.reset();
536 }
537 for train in &mut self.spike_trains {
538 train.clear();
539 }
540 }
541
542 pub fn population_rate(&self, window: f64) -> f64 {
544 let total: f64 = self.spike_trains.iter()
545 .map(|t| t.spike_rate(window))
546 .sum();
547 total / self.neurons.len() as f64
548 }
549
550 pub fn synchrony(&self, window: f64) -> f64 {
552 let mut all_spikes = Vec::new();
554 let cutoff = self.time - window;
555
556 for train in &self.spike_trains {
557 for &t in &train.spike_times {
558 if t >= cutoff {
559 all_spikes.push(Spike { neuron_id: train.neuron_id, time: t });
560 }
561 }
562 }
563
564 super::compute_synchrony(&all_spikes, window / 10.0)
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created neuron keeps its id and starts with zero potential.
    #[test]
    fn test_lif_neuron_creation() {
        let neuron = LIFNeuron::new(0);
        assert_eq!(neuron.id, 0);
        assert_eq!(neuron.state.v, 0.0);
    }

    /// A constant supra-threshold current makes the neuron fire and then
    /// enter its refractory period.
    #[test]
    fn test_lif_neuron_spike() {
        let mut neuron = LIFNeuron::new(0);

        let mut spiked = false;
        for i in 0..100 {
            if neuron.step(2.0, 1.0, i as f64) {
                spiked = true;
                break;
            }
        }

        assert!(spiked);
        assert!(neuron.is_refractory());
    }

    /// Recorded spikes are counted and yield the expected mean interval.
    #[test]
    fn test_spike_train() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(10.0);
        train.record_spike(20.0);
        train.record_spike(30.0);

        assert_eq!(train.count(), 3);

        let mean_isi = train.mean_isi().unwrap();
        assert!((mean_isi - 10.0).abs() < 0.001);
    }

    /// A uniformly driven population produces at least one spike.
    #[test]
    fn test_neuron_population() {
        let mut pop = NeuronPopulation::new(100);

        let currents = vec![1.5; 100];

        let mut total_spikes = 0;
        for _ in 0..100 {
            let spikes = pop.step(&currents, 1.0);
            total_spikes += spikes.len();
        }

        assert!(total_spikes > 0);
    }

    /// Spikes land in the bins matching their times; other bins stay empty.
    #[test]
    fn test_spike_train_pattern() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(1.0);
        train.record_spike(3.0);
        train.record_spike(7.0);

        let pattern = train.to_pattern(0.0, 1.0, 10);
        assert_eq!(pattern.len(), 10);
        assert!(pattern[1]);
        assert!(pattern[3]);
        assert!(pattern[7]);
        assert!(!pattern[0]);
    }
}