scirs2_spatial/neuromorphic/algorithms/
spiking_clustering.rs1use crate::error::{SpatialError, SpatialResult};
8use scirs2_core::ndarray::{Array1, ArrayView2};
9use scirs2_core::random::{Rng, RngExt};
10use std::collections::HashMap;
11
12use super::super::core::{SpikeEvent, SpikingNeuron, Synapse};
14
15#[derive(Debug, Clone)]
47pub struct SpikingNeuralClusterer {
48 neurons: Vec<SpikingNeuron>,
50 synapses: Vec<Synapse>,
52 num_clusters: usize,
54 spike_threshold: f64,
56 stdp_learning: bool,
58 lateral_inhibition: bool,
60 dt: f64,
62 current_time: f64,
64 spike_history: Vec<SpikeEvent>,
66 max_epochs: usize,
68 simulation_duration: f64,
70}
71
72impl SpikingNeuralClusterer {
73 pub fn new(num_clusters: usize) -> Self {
81 Self {
82 neurons: Vec::new(),
83 synapses: Vec::new(),
84 num_clusters,
85 spike_threshold: 1.0,
86 stdp_learning: true,
87 lateral_inhibition: true,
88 dt: 0.1,
89 current_time: 0.0,
90 spike_history: Vec::new(),
91 max_epochs: 100,
92 simulation_duration: 10.0,
93 }
94 }
95
96 pub fn with_spike_threshold(mut self, threshold: f64) -> Self {
101 self.spike_threshold = threshold;
102 self
103 }
104
105 pub fn with_stdp_learning(mut self, enabled: bool) -> Self {
110 self.stdp_learning = enabled;
111 self
112 }
113
114 pub fn with_lateral_inhibition(mut self, enabled: bool) -> Self {
119 self.lateral_inhibition = enabled;
120 self
121 }
122
123 pub fn with_training_params(mut self, max_epochs: usize, simulation_duration: f64) -> Self {
129 self.max_epochs = max_epochs;
130 self.simulation_duration = simulation_duration;
131 self
132 }
133
134 pub fn with_time_step(mut self, dt: f64) -> Self {
139 self.dt = dt;
140 self
141 }
142
143 pub fn fit(
156 &mut self,
157 points: &ArrayView2<'_, f64>,
158 ) -> SpatialResult<(Array1<usize>, Vec<SpikeEvent>)> {
159 let (n_points, n_dims) = points.dim();
160
161 if n_points == 0 || n_dims == 0 {
162 return Err(SpatialError::InvalidInput(
163 "Input data cannot be empty".to_string(),
164 ));
165 }
166
167 self.initialize_network(n_dims)?;
169
170 let mut assignments = Array1::zeros(n_points);
172
173 for epoch in 0..self.max_epochs {
174 self.current_time = epoch as f64 * 100.0;
175
176 for (point_idx, point) in points.outer_iter().enumerate() {
177 let spike_train = self.encode_point_as_spikes(&point.to_owned())?;
179
180 let winning_neuron = self.process_spike_train(&spike_train)?;
182 assignments[point_idx] = winning_neuron;
183
184 if self.stdp_learning {
186 self.apply_stdp_learning(&spike_train)?;
187 }
188 }
189
190 if self.lateral_inhibition {
192 self.apply_lateral_inhibition()?;
193 }
194 }
195
196 Ok((assignments, self.spike_history.clone()))
197 }
198
199 fn initialize_network(&mut self, input_dims: usize) -> SpatialResult<()> {
204 self.neurons.clear();
205 self.synapses.clear();
206 self.spike_history.clear();
207
208 for i in 0..input_dims {
210 let position = vec![i as f64];
211 let mut neuron = SpikingNeuron::new(position);
212 neuron.set_threshold(self.spike_threshold);
213 self.neurons.push(neuron);
214 }
215
216 let mut rng = scirs2_core::random::rng();
218 for _i in 0..self.num_clusters {
219 let position = (0..input_dims)
220 .map(|_| rng.random_range(0.0..1.0))
221 .collect();
222 let mut neuron = SpikingNeuron::new(position);
223 neuron.set_threshold(self.spike_threshold);
224 self.neurons.push(neuron);
225 }
226
227 for i in 0..input_dims {
229 for j in 0..self.num_clusters {
230 let output_idx = input_dims + j;
231 let weight = rng.random_range(0.0..0.5);
232 let synapse = Synapse::new(i, output_idx, weight);
233 self.synapses.push(synapse);
234 }
235 }
236
237 if self.lateral_inhibition {
239 for i in 0..self.num_clusters {
240 for j in 0..self.num_clusters {
241 if i != j {
242 let neuron_i = input_dims + i;
243 let neuron_j = input_dims + j;
244 let synapse = Synapse::new(neuron_i, neuron_j, -0.5);
245 self.synapses.push(synapse);
246 }
247 }
248 }
249 }
250
251 Ok(())
252 }
253
254 fn encode_point_as_spikes(&self, point: &Array1<f64>) -> SpatialResult<Vec<SpikeEvent>> {
260 let mut spike_train = Vec::new();
261
262 for (dim, &coord) in point.iter().enumerate() {
264 let normalized_coord = (coord + 10.0) / 20.0; let spike_rate = normalized_coord.clamp(0.0, 1.0) * 50.0; let num_spikes = (spike_rate * 1.0) as usize; for spike_idx in 0..num_spikes {
271 let timestamp =
272 self.current_time + (spike_idx as f64) * (1.0 / spike_rate.max(1.0));
273 let spike = SpikeEvent::new(dim, timestamp, 1.0, point.to_vec());
274 spike_train.push(spike);
275 }
276 }
277
278 spike_train.sort_by(|a, b| {
280 a.timestamp()
281 .partial_cmp(&b.timestamp())
282 .expect("Operation failed")
283 });
284
285 Ok(spike_train)
286 }
287
288 fn process_spike_train(&mut self, spike_train: &[SpikeEvent]) -> SpatialResult<usize> {
293 let input_dims = self.neurons.len() - self.num_clusters;
294 let mut neuron_spike_counts = vec![0; self.num_clusters];
295
296 let mut t = self.current_time;
298 let mut spike_idx = 0;
299
300 while t < self.current_time + self.simulation_duration {
301 let mut input_currents = vec![0.0; self.neurons.len()];
303
304 while spike_idx < spike_train.len() && spike_train[spike_idx].timestamp() <= t {
305 let spike = &spike_train[spike_idx];
306 if spike.neuron_id() < input_dims {
307 input_currents[spike.neuron_id()] += spike.amplitude();
308 }
309 spike_idx += 1;
310 }
311
312 for synapse in &self.synapses {
314 if synapse.pre_neuron() < self.neurons.len()
315 && synapse.post_neuron() < self.neurons.len()
316 {
317 let pre_current = input_currents[synapse.pre_neuron()];
318 let synaptic_current = synapse.synaptic_current(pre_current);
319 input_currents[synapse.post_neuron()] += synaptic_current;
320 }
321 }
322
323 for (neuron_idx, neuron) in self.neurons.iter_mut().enumerate() {
325 let spiked = neuron.update(self.dt, input_currents[neuron_idx]);
326
327 if spiked && neuron_idx >= input_dims {
328 let cluster_idx = neuron_idx - input_dims;
329 neuron_spike_counts[cluster_idx] += 1;
330
331 let spike_event =
333 SpikeEvent::new(neuron_idx, t, 1.0, neuron.position().to_vec());
334 self.spike_history.push(spike_event);
335 }
336 }
337
338 t += self.dt;
339 }
340
341 let winning_cluster = neuron_spike_counts
343 .iter()
344 .enumerate()
345 .max_by(|(_, a), (_, b)| a.cmp(b))
346 .map(|(idx, _)| idx)
347 .unwrap_or(0);
348
349 Ok(winning_cluster)
350 }
351
352 fn apply_stdp_learning(&mut self, spike_train: &[SpikeEvent]) -> SpatialResult<()> {
357 let mut spike_times: HashMap<usize, Vec<f64>> = HashMap::new();
359 for spike in spike_train {
360 spike_times
361 .entry(spike.neuron_id())
362 .or_default()
363 .push(spike.timestamp());
364 }
365
366 for spike in &self.spike_history {
368 spike_times
369 .entry(spike.neuron_id())
370 .or_default()
371 .push(spike.timestamp());
372 }
373
374 let empty_spikes = Vec::new();
376 for synapse in &mut self.synapses {
377 let pre_spikes = spike_times
378 .get(&synapse.pre_neuron())
379 .unwrap_or(&empty_spikes);
380 let post_spikes = spike_times
381 .get(&synapse.post_neuron())
382 .unwrap_or(&empty_spikes);
383
384 for &pre_time in pre_spikes {
386 for &post_time in post_spikes {
387 let dt = post_time - pre_time;
388 if dt.abs() < 50.0 {
389 let current_weight = synapse.weight();
391 if dt > 0.0 {
392 let delta_w = synapse.stdp_rate() * (-dt / synapse.stdp_tau()).exp();
394 synapse.set_weight(current_weight + delta_w);
395 } else {
396 let delta_w = synapse.stdp_rate() * (dt / synapse.stdp_tau()).exp();
398 synapse.set_weight(current_weight - delta_w);
399 }
400 }
401 }
402 }
403 }
404
405 Ok(())
406 }
407
408 fn apply_lateral_inhibition(&mut self) -> SpatialResult<()> {
413 let input_dims = self.neurons.len() - self.num_clusters;
414
415 for i in 0..self.num_clusters {
417 for j in 0..self.num_clusters {
418 if i != j {
419 let neuron_i_idx = input_dims + i;
420 let neuron_j_idx = input_dims + j;
421
422 for synapse in &mut self.synapses {
424 if synapse.pre_neuron() == neuron_i_idx
425 && synapse.post_neuron() == neuron_j_idx
426 {
427 let activity_i = self.neurons[neuron_i_idx].membrane_potential();
429 let activity_j = self.neurons[neuron_j_idx].membrane_potential();
430
431 if activity_i > activity_j {
432 let current_weight = synapse.weight();
433 synapse.set_weight(current_weight - 0.01); }
435 }
436 }
437 }
438 }
439 }
440
441 Ok(())
442 }
443
444 pub fn num_clusters(&self) -> usize {
446 self.num_clusters
447 }
448
449 pub fn spike_threshold(&self) -> f64 {
451 self.spike_threshold
452 }
453
454 pub fn is_stdp_enabled(&self) -> bool {
456 self.stdp_learning
457 }
458
459 pub fn is_lateral_inhibition_enabled(&self) -> bool {
461 self.lateral_inhibition
462 }
463
464 pub fn spike_history(&self) -> &[SpikeEvent] {
466 &self.spike_history
467 }
468
469 pub fn network_stats(&self) -> NetworkStats {
471 NetworkStats {
472 num_neurons: self.neurons.len(),
473 num_synapses: self.synapses.len(),
474 num_spikes: self.spike_history.len(),
475 average_weight: if self.synapses.is_empty() {
476 0.0
477 } else {
478 self.synapses.iter().map(|s| s.weight()).sum::<f64>() / self.synapses.len() as f64
479 },
480 }
481 }
482
483 pub fn reset(&mut self) {
485 for neuron in &mut self.neurons {
486 neuron.reset();
487 }
488 for synapse in &mut self.synapses {
489 synapse.reset_spike_history();
490 }
491 self.spike_history.clear();
492 self.current_time = 0.0;
493 }
494}
495
/// Snapshot of a spiking clusterer's network, as returned by
/// `SpikingNeuralClusterer::network_stats`.
///
/// All fields are plain numbers, so the type is `Copy`, comparable with `==`
/// and has an all-zero `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct NetworkStats {
    /// Total number of neurons (input and output).
    pub num_neurons: usize,
    /// Total number of synapses (feed-forward and lateral).
    pub num_synapses: usize,
    /// Number of output spikes currently held in the spike history.
    pub num_spikes: usize,
    /// Mean synaptic weight, or `0.0` when the network has no synapses.
    pub average_weight: f64,
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Corners of the unit square, one point per row.
    fn unit_square() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Operation failed")
    }

    #[test]
    fn test_spiking_clusterer_creation() {
        let c = SpikingNeuralClusterer::new(3);

        assert_eq!(c.num_clusters(), 3);
        assert_eq!(c.spike_threshold(), 1.0);
        assert!(c.is_stdp_enabled());
        assert!(c.is_lateral_inhibition_enabled());
    }

    #[test]
    fn test_clusterer_configuration() {
        let c = SpikingNeuralClusterer::new(2)
            .with_spike_threshold(0.8)
            .with_stdp_learning(false)
            .with_lateral_inhibition(false)
            .with_training_params(50, 5.0);

        assert_eq!(c.spike_threshold(), 0.8);
        assert!(!c.is_stdp_enabled());
        assert!(!c.is_lateral_inhibition_enabled());
        assert_eq!(c.max_epochs, 50);
        assert_eq!(c.simulation_duration, 5.0);
    }

    #[test]
    fn test_simple_clustering() {
        let data = unit_square();
        let mut c = SpikingNeuralClusterer::new(2).with_training_params(5, 1.0);

        let outcome = c.fit(&data.view());
        assert!(outcome.is_ok());

        let (labels, spikes) = outcome.expect("Operation failed");
        assert_eq!(labels.len(), 4);
        assert!(!spikes.is_empty());
    }

    #[test]
    fn test_empty_input() {
        let data: Array2<f64> = Array2::zeros((0, 2));
        let mut c = SpikingNeuralClusterer::new(2);

        assert!(c.fit(&data.view()).is_err());
    }

    #[test]
    fn test_network_initialization() {
        let (inputs, clusters) = (3, 2);
        let mut c = SpikingNeuralClusterer::new(clusters);
        c.initialize_network(inputs).expect("Operation failed");

        let stats = c.network_stats();
        let feed_forward = inputs * clusters;
        let lateral = clusters * (clusters - 1);

        assert_eq!(stats.num_neurons, inputs + clusters);
        assert_eq!(stats.num_synapses, feed_forward + lateral);
    }

    #[test]
    fn test_spike_encoding() {
        let c = SpikingNeuralClusterer::new(2);
        let point = Array1::from_vec(vec![1.0, -1.0]);

        let train = c.encode_point_as_spikes(&point).expect("Operation failed");

        assert!(!train.is_empty());
        assert!(train
            .windows(2)
            .all(|pair| pair[0].timestamp() <= pair[1].timestamp()));
    }

    #[test]
    fn test_network_reset() {
        let mut c = SpikingNeuralClusterer::new(2);
        c.initialize_network(2).expect("Operation failed");
        c.spike_history
            .push(SpikeEvent::new(0, 1.0, 1.0, vec![0.0, 0.0]));
        c.current_time = 100.0;

        c.reset();

        assert!(c.spike_history().is_empty());
        assert_eq!(c.current_time, 0.0);
    }

    #[test]
    fn test_network_stats() {
        let mut c = SpikingNeuralClusterer::new(2);
        c.initialize_network(3).expect("Operation failed");

        let stats = c.network_stats();

        assert_eq!(stats.num_neurons, 5);
        assert!(stats.num_synapses > 0);
        assert_eq!(stats.num_spikes, 0);
        assert!(stats.average_weight.is_finite());
    }
}