Skip to main content

scirs2_spatial/neuromorphic/algorithms/
spiking_clustering.rs

1//! Spiking Neural Network Clustering
2//!
3//! This module implements clustering algorithms based on spiking neural networks (SNNs).
4//! These algorithms use spike-timing dependent plasticity (STDP) and competitive learning
5//! to discover patterns in spatial data through biologically-inspired neural dynamics.
6
7use crate::error::{SpatialError, SpatialResult};
8use scirs2_core::ndarray::{Array1, ArrayView2};
9use scirs2_core::random::{Rng, RngExt};
10use std::collections::HashMap;
11
12// Import core neuromorphic components
13use super::super::core::{SpikeEvent, SpikingNeuron, Synapse};
14
15/// Spiking neural network clusterer
16///
17/// This clusterer uses a network of spiking neurons with STDP learning to perform
18/// unsupervised clustering of spatial data. Input points are encoded as spike trains
19/// and presented to the network, which learns to respond selectively to different
20/// input patterns through competitive dynamics.
21///
22/// # Features
23/// - Rate coding for spatial data encoding
24/// - STDP learning for adaptive weights
25/// - Lateral inhibition for competitive dynamics
26/// - Configurable network architecture
27/// - Spike timing analysis
28///
29/// # Example
30/// ```rust
31/// use scirs2_core::ndarray::Array2;
32/// use scirs2_spatial::neuromorphic::algorithms::SpikingNeuralClusterer;
33///
34/// let points = Array2::from_shape_vec((4, 2), vec![
35///     0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0
36/// ]).expect("Operation failed");
37///
38/// let mut clusterer = SpikingNeuralClusterer::new(2)
39///     .with_spike_threshold(0.8)
40///     .with_stdp_learning(true)
41///     .with_lateral_inhibition(true);
42///
43/// let (assignments, spike_events) = clusterer.fit(&points.view()).expect("Operation failed");
44/// println!("Cluster assignments: {:?}", assignments);
45/// ```
46#[derive(Debug, Clone)]
47pub struct SpikingNeuralClusterer {
48    /// Network of spiking neurons
49    neurons: Vec<SpikingNeuron>,
50    /// Synaptic connections
51    synapses: Vec<Synapse>,
52    /// Number of clusters (output neurons)
53    num_clusters: usize,
54    /// Spike threshold
55    spike_threshold: f64,
56    /// Enable STDP learning
57    stdp_learning: bool,
58    /// Enable lateral inhibition
59    lateral_inhibition: bool,
60    /// Simulation time step
61    dt: f64,
62    /// Current simulation time
63    current_time: f64,
64    /// Spike history
65    spike_history: Vec<SpikeEvent>,
66    /// Number of training epochs
67    max_epochs: usize,
68    /// Simulation duration per data point
69    simulation_duration: f64,
70}
71
72impl SpikingNeuralClusterer {
73    /// Create new spiking neural clusterer
74    ///
75    /// # Arguments
76    /// * `num_clusters` - Number of clusters to discover
77    ///
78    /// # Returns
79    /// A new `SpikingNeuralClusterer` with default parameters
80    pub fn new(num_clusters: usize) -> Self {
81        Self {
82            neurons: Vec::new(),
83            synapses: Vec::new(),
84            num_clusters,
85            spike_threshold: 1.0,
86            stdp_learning: true,
87            lateral_inhibition: true,
88            dt: 0.1,
89            current_time: 0.0,
90            spike_history: Vec::new(),
91            max_epochs: 100,
92            simulation_duration: 10.0,
93        }
94    }
95
96    /// Configure spike threshold
97    ///
98    /// # Arguments
99    /// * `threshold` - Spike threshold for neurons
100    pub fn with_spike_threshold(mut self, threshold: f64) -> Self {
101        self.spike_threshold = threshold;
102        self
103    }
104
105    /// Enable/disable STDP learning
106    ///
107    /// # Arguments
108    /// * `enabled` - Whether to enable STDP learning
109    pub fn with_stdp_learning(mut self, enabled: bool) -> Self {
110        self.stdp_learning = enabled;
111        self
112    }
113
114    /// Enable/disable lateral inhibition
115    ///
116    /// # Arguments
117    /// * `enabled` - Whether to enable lateral inhibition
118    pub fn with_lateral_inhibition(mut self, enabled: bool) -> Self {
119        self.lateral_inhibition = enabled;
120        self
121    }
122
123    /// Configure training parameters
124    ///
125    /// # Arguments
126    /// * `max_epochs` - Maximum number of training epochs
127    /// * `simulation_duration` - Duration to simulate per data point
128    pub fn with_training_params(mut self, max_epochs: usize, simulation_duration: f64) -> Self {
129        self.max_epochs = max_epochs;
130        self.simulation_duration = simulation_duration;
131        self
132    }
133
134    /// Configure simulation time step
135    ///
136    /// # Arguments
137    /// * `dt` - Time step for simulation
138    pub fn with_time_step(mut self, dt: f64) -> Self {
139        self.dt = dt;
140        self
141    }
142
143    /// Fit clustering to spatial data
144    ///
145    /// Trains the spiking neural network on the provided spatial data using
146    /// STDP learning and competitive dynamics to discover cluster structure.
147    ///
148    /// # Arguments
149    /// * `points` - Input points to cluster (n_points × n_dims)
150    ///
151    /// # Returns
152    /// Tuple of (cluster assignments, spike events) where assignments
153    /// maps each point to its cluster and spike_events contains the
154    /// complete spike timing history.
155    pub fn fit(
156        &mut self,
157        points: &ArrayView2<'_, f64>,
158    ) -> SpatialResult<(Array1<usize>, Vec<SpikeEvent>)> {
159        let (n_points, n_dims) = points.dim();
160
161        if n_points == 0 || n_dims == 0 {
162            return Err(SpatialError::InvalidInput(
163                "Input data cannot be empty".to_string(),
164            ));
165        }
166
167        // Initialize neural network
168        self.initialize_network(n_dims)?;
169
170        // Present data points as spike trains
171        let mut assignments = Array1::zeros(n_points);
172
173        for epoch in 0..self.max_epochs {
174            self.current_time = epoch as f64 * 100.0;
175
176            for (point_idx, point) in points.outer_iter().enumerate() {
177                // Encode spatial point as spike train
178                let spike_train = self.encode_point_as_spikes(&point.to_owned())?;
179
180                // Process spike train through network
181                let winning_neuron = self.process_spike_train(&spike_train)?;
182                assignments[point_idx] = winning_neuron;
183
184                // Apply learning if enabled
185                if self.stdp_learning {
186                    self.apply_stdp_learning(&spike_train)?;
187                }
188            }
189
190            // Apply lateral inhibition
191            if self.lateral_inhibition {
192                self.apply_lateral_inhibition()?;
193            }
194        }
195
196        Ok((assignments, self.spike_history.clone()))
197    }
198
199    /// Initialize spiking neural network
200    ///
201    /// Creates the network topology with input neurons, output neurons,
202    /// and synaptic connections between them.
203    fn initialize_network(&mut self, input_dims: usize) -> SpatialResult<()> {
204        self.neurons.clear();
205        self.synapses.clear();
206        self.spike_history.clear();
207
208        // Create input neurons (one per dimension)
209        for i in 0..input_dims {
210            let position = vec![i as f64];
211            let mut neuron = SpikingNeuron::new(position);
212            neuron.set_threshold(self.spike_threshold);
213            self.neurons.push(neuron);
214        }
215
216        // Create output neurons (cluster centers)
217        let mut rng = scirs2_core::random::rng();
218        for _i in 0..self.num_clusters {
219            let position = (0..input_dims)
220                .map(|_| rng.random_range(0.0..1.0))
221                .collect();
222            let mut neuron = SpikingNeuron::new(position);
223            neuron.set_threshold(self.spike_threshold);
224            self.neurons.push(neuron);
225        }
226
227        // Create synaptic connections (input to output)
228        for i in 0..input_dims {
229            for j in 0..self.num_clusters {
230                let output_idx = input_dims + j;
231                let weight = rng.random_range(0.0..0.5);
232                let synapse = Synapse::new(i, output_idx, weight);
233                self.synapses.push(synapse);
234            }
235        }
236
237        // Create lateral inhibitory connections between output neurons
238        if self.lateral_inhibition {
239            for i in 0..self.num_clusters {
240                for j in 0..self.num_clusters {
241                    if i != j {
242                        let neuron_i = input_dims + i;
243                        let neuron_j = input_dims + j;
244                        let synapse = Synapse::new(neuron_i, neuron_j, -0.5);
245                        self.synapses.push(synapse);
246                    }
247                }
248            }
249        }
250
251        Ok(())
252    }
253
254    /// Encode spatial point as spike train
255    ///
256    /// Converts a spatial data point into a spike train using rate coding,
257    /// where the firing rate of each input neuron is proportional to the
258    /// corresponding coordinate value.
259    fn encode_point_as_spikes(&self, point: &Array1<f64>) -> SpatialResult<Vec<SpikeEvent>> {
260        let mut spike_train = Vec::new();
261
262        // Rate coding: spike frequency proportional to coordinate value
263        for (dim, &coord) in point.iter().enumerate() {
264            // Normalize coordinate to [0, 1] and scale to spike rate
265            let normalized_coord = (coord + 10.0) / 20.0; // Assume data in [-10, 10]
266            let spike_rate = normalized_coord.clamp(0.0, 1.0) * 50.0; // Max 50 Hz
267
268            // Generate Poisson spike train
269            let num_spikes = (spike_rate * 1.0) as usize; // 1 second duration
270            for spike_idx in 0..num_spikes {
271                let timestamp =
272                    self.current_time + (spike_idx as f64) * (1.0 / spike_rate.max(1.0));
273                let spike = SpikeEvent::new(dim, timestamp, 1.0, point.to_vec());
274                spike_train.push(spike);
275            }
276        }
277
278        // Sort spikes by timestamp
279        spike_train.sort_by(|a, b| {
280            a.timestamp()
281                .partial_cmp(&b.timestamp())
282                .expect("Operation failed")
283        });
284
285        Ok(spike_train)
286    }
287
288    /// Process spike train through network
289    ///
290    /// Simulates the network dynamics when presented with a spike train,
291    /// determining which output neuron responds most strongly.
292    fn process_spike_train(&mut self, spike_train: &[SpikeEvent]) -> SpatialResult<usize> {
293        let input_dims = self.neurons.len() - self.num_clusters;
294        let mut neuron_spike_counts = vec![0; self.num_clusters];
295
296        // Simulate network for duration of spike train
297        let mut t = self.current_time;
298        let mut spike_idx = 0;
299
300        while t < self.current_time + self.simulation_duration {
301            // Apply input spikes
302            let mut input_currents = vec![0.0; self.neurons.len()];
303
304            while spike_idx < spike_train.len() && spike_train[spike_idx].timestamp() <= t {
305                let spike = &spike_train[spike_idx];
306                if spike.neuron_id() < input_dims {
307                    input_currents[spike.neuron_id()] += spike.amplitude();
308                }
309                spike_idx += 1;
310            }
311
312            // Calculate synaptic currents
313            for synapse in &self.synapses {
314                if synapse.pre_neuron() < self.neurons.len()
315                    && synapse.post_neuron() < self.neurons.len()
316                {
317                    let pre_current = input_currents[synapse.pre_neuron()];
318                    let synaptic_current = synapse.synaptic_current(pre_current);
319                    input_currents[synapse.post_neuron()] += synaptic_current;
320                }
321            }
322
323            // Update neurons and check for spikes
324            for (neuron_idx, neuron) in self.neurons.iter_mut().enumerate() {
325                let spiked = neuron.update(self.dt, input_currents[neuron_idx]);
326
327                if spiked && neuron_idx >= input_dims {
328                    let cluster_idx = neuron_idx - input_dims;
329                    neuron_spike_counts[cluster_idx] += 1;
330
331                    // Record spike event
332                    let spike_event =
333                        SpikeEvent::new(neuron_idx, t, 1.0, neuron.position().to_vec());
334                    self.spike_history.push(spike_event);
335                }
336            }
337
338            t += self.dt;
339        }
340
341        // Find winning neuron (cluster with most spikes)
342        let winning_cluster = neuron_spike_counts
343            .iter()
344            .enumerate()
345            .max_by(|(_, a), (_, b)| a.cmp(b))
346            .map(|(idx, _)| idx)
347            .unwrap_or(0);
348
349        Ok(winning_cluster)
350    }
351
352    /// Apply STDP learning to synapses
353    ///
354    /// Updates synaptic weights based on the relative timing of pre- and
355    /// post-synaptic spikes using the STDP learning rule.
356    fn apply_stdp_learning(&mut self, spike_train: &[SpikeEvent]) -> SpatialResult<()> {
357        // Create spike timing map
358        let mut spike_times: HashMap<usize, Vec<f64>> = HashMap::new();
359        for spike in spike_train {
360            spike_times
361                .entry(spike.neuron_id())
362                .or_default()
363                .push(spike.timestamp());
364        }
365
366        // Add output neuron spikes from history
367        for spike in &self.spike_history {
368            spike_times
369                .entry(spike.neuron_id())
370                .or_default()
371                .push(spike.timestamp());
372        }
373
374        // Update synaptic weights using STDP
375        let empty_spikes = Vec::new();
376        for synapse in &mut self.synapses {
377            let pre_spikes = spike_times
378                .get(&synapse.pre_neuron())
379                .unwrap_or(&empty_spikes);
380            let post_spikes = spike_times
381                .get(&synapse.post_neuron())
382                .unwrap_or(&empty_spikes);
383
384            // Check for coincident spikes
385            for &pre_time in pre_spikes {
386                for &post_time in post_spikes {
387                    let dt = post_time - pre_time;
388                    if dt.abs() < 50.0 {
389                        // Within STDP window
390                        let current_weight = synapse.weight();
391                        if dt > 0.0 {
392                            // Potentiation
393                            let delta_w = synapse.stdp_rate() * (-dt / synapse.stdp_tau()).exp();
394                            synapse.set_weight(current_weight + delta_w);
395                        } else {
396                            // Depression
397                            let delta_w = synapse.stdp_rate() * (dt / synapse.stdp_tau()).exp();
398                            synapse.set_weight(current_weight - delta_w);
399                        }
400                    }
401                }
402            }
403        }
404
405        Ok(())
406    }
407
408    /// Apply lateral inhibition between output neurons
409    ///
410    /// Strengthens inhibitory connections between neurons based on their
411    /// relative activity levels to promote competition.
412    fn apply_lateral_inhibition(&mut self) -> SpatialResult<()> {
413        let input_dims = self.neurons.len() - self.num_clusters;
414
415        // Strengthen inhibitory connections between active neurons
416        for i in 0..self.num_clusters {
417            for j in 0..self.num_clusters {
418                if i != j {
419                    let neuron_i_idx = input_dims + i;
420                    let neuron_j_idx = input_dims + j;
421
422                    // Find inhibitory synapse
423                    for synapse in &mut self.synapses {
424                        if synapse.pre_neuron() == neuron_i_idx
425                            && synapse.post_neuron() == neuron_j_idx
426                        {
427                            // Strengthen inhibition based on activity
428                            let activity_i = self.neurons[neuron_i_idx].membrane_potential();
429                            let activity_j = self.neurons[neuron_j_idx].membrane_potential();
430
431                            if activity_i > activity_j {
432                                let current_weight = synapse.weight();
433                                synapse.set_weight(current_weight - 0.01); // Strengthen inhibition
434                            }
435                        }
436                    }
437                }
438            }
439        }
440
441        Ok(())
442    }
443
444    /// Get number of clusters
445    pub fn num_clusters(&self) -> usize {
446        self.num_clusters
447    }
448
449    /// Get spike threshold
450    pub fn spike_threshold(&self) -> f64 {
451        self.spike_threshold
452    }
453
454    /// Check if STDP learning is enabled
455    pub fn is_stdp_enabled(&self) -> bool {
456        self.stdp_learning
457    }
458
459    /// Check if lateral inhibition is enabled
460    pub fn is_lateral_inhibition_enabled(&self) -> bool {
461        self.lateral_inhibition
462    }
463
464    /// Get current spike history
465    pub fn spike_history(&self) -> &[SpikeEvent] {
466        &self.spike_history
467    }
468
469    /// Get network statistics
470    pub fn network_stats(&self) -> NetworkStats {
471        NetworkStats {
472            num_neurons: self.neurons.len(),
473            num_synapses: self.synapses.len(),
474            num_spikes: self.spike_history.len(),
475            average_weight: if self.synapses.is_empty() {
476                0.0
477            } else {
478                self.synapses.iter().map(|s| s.weight()).sum::<f64>() / self.synapses.len() as f64
479            },
480        }
481    }
482
483    /// Reset the network to initial state
484    pub fn reset(&mut self) {
485        for neuron in &mut self.neurons {
486            neuron.reset();
487        }
488        for synapse in &mut self.synapses {
489            synapse.reset_spike_history();
490        }
491        self.spike_history.clear();
492        self.current_time = 0.0;
493    }
494}
495
/// Snapshot of a spiking network's size and activity, for analysis.
///
/// The default value describes an empty network: no neurons, no synapses, no
/// recorded spikes and an average weight of `0.0`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct NetworkStats {
    /// Total number of neurons (input and output)
    pub num_neurons: usize,
    /// Total number of synapses
    pub num_synapses: usize,
    /// Total number of spikes recorded
    pub num_spikes: usize,
    /// Mean weight over all synapses; `0.0` when there are none
    pub average_weight: f64,
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A freshly constructed clusterer exposes the documented defaults.
    #[test]
    fn test_spiking_clusterer_creation() {
        let fresh = SpikingNeuralClusterer::new(3);

        assert_eq!(fresh.num_clusters(), 3);
        assert_eq!(fresh.spike_threshold(), 1.0);
        assert!(fresh.is_stdp_enabled());
        assert!(fresh.is_lateral_inhibition_enabled());
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_clusterer_configuration() {
        let configured = SpikingNeuralClusterer::new(2)
            .with_spike_threshold(0.8)
            .with_stdp_learning(false)
            .with_lateral_inhibition(false)
            .with_training_params(50, 5.0);

        assert_eq!(configured.spike_threshold(), 0.8);
        assert!(!configured.is_stdp_enabled());
        assert!(!configured.is_lateral_inhibition_enabled());
        assert_eq!(configured.max_epochs, 50);
        assert_eq!(configured.simulation_duration, 5.0);
    }

    /// Fitting four 2-D points yields one assignment per point and some spikes.
    #[test]
    fn test_simple_clustering() {
        let corners = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let points = Array2::from_shape_vec((4, 2), corners).expect("Operation failed");

        // Few epochs and a short simulation keep the test fast.
        let mut clusterer = SpikingNeuralClusterer::new(2).with_training_params(5, 1.0);

        let outcome = clusterer.fit(&points.view());
        assert!(outcome.is_ok());

        let (assignments, spike_events) = outcome.expect("Operation failed");
        assert_eq!(assignments.len(), 4);

        // Should have recorded some spike events
        assert!(!spike_events.is_empty());
    }

    /// An input without rows is rejected.
    #[test]
    fn test_empty_input() {
        let no_points = Array2::zeros((0, 2));
        let mut clusterer = SpikingNeuralClusterer::new(2);

        assert!(clusterer.fit(&no_points.view()).is_err());
    }

    /// The built network has the expected number of neurons and synapses.
    #[test]
    fn test_network_initialization() {
        let mut clusterer = SpikingNeuralClusterer::new(2);
        clusterer.initialize_network(3).expect("Operation failed");

        let stats = clusterer.network_stats();
        assert_eq!(stats.num_neurons, 5); // 3 input + 2 output

        // input_dims * num_clusters feed-forward synapses ...
        let feed_forward = 3 * 2;
        // ... plus num_clusters * (num_clusters - 1) lateral inhibitory synapses.
        let lateral = 2 * (2 - 1);
        assert_eq!(stats.num_synapses, feed_forward + lateral);
    }

    /// Encoding produces a non-empty spike train ordered by timestamp.
    #[test]
    fn test_spike_encoding() {
        let clusterer = SpikingNeuralClusterer::new(2);
        let point = Array1::from_vec(vec![1.0, -1.0]);

        let spike_train = clusterer
            .encode_point_as_spikes(&point)
            .expect("Operation failed");

        // Should generate spikes for each dimension
        assert!(!spike_train.is_empty());

        // Spikes should be sorted by timestamp
        for pair in spike_train.windows(2) {
            assert!(pair[0].timestamp() <= pair[1].timestamp());
        }
    }

    /// `reset` clears the spike history and rewinds the clock.
    #[test]
    fn test_network_reset() {
        let mut clusterer = SpikingNeuralClusterer::new(2);
        clusterer.initialize_network(2).expect("Operation failed");

        // Fake some activity.
        let recorded = SpikeEvent::new(0, 1.0, 1.0, vec![0.0, 0.0]);
        clusterer.spike_history.push(recorded);
        clusterer.current_time = 100.0;

        clusterer.reset();

        assert!(clusterer.spike_history().is_empty());
        assert_eq!(clusterer.current_time, 0.0);
    }

    /// Statistics of a freshly built, idle network.
    #[test]
    fn test_network_stats() {
        let mut clusterer = SpikingNeuralClusterer::new(2);
        clusterer.initialize_network(3).expect("Operation failed");

        let stats = clusterer.network_stats();

        assert_eq!(stats.num_neurons, 5);
        assert!(stats.num_synapses > 0);
        assert_eq!(stats.num_spikes, 0); // No activity yet
        // Inhibitory synapses carry negative weights, so only finiteness is checked.
        assert!(stats.average_weight.is_finite());
    }
}