1#![allow(dead_code)]
8use crate::{GraphData, GraphLayer};
9use std::collections::{HashMap, VecDeque};
10use torsh_tensor::{
11 creation::{randn, zeros},
12 Tensor,
13};
14
15#[derive(Debug, Clone)]
17pub struct SpikingGraphNetwork {
18 pub num_nodes: usize,
20 pub input_dim: usize,
22 pub hidden_dim: usize,
24 pub membrane_potentials: Tensor,
26 pub synaptic_weights: Tensor,
28 pub spike_threshold: f32,
30 pub tau_membrane: f32,
32 pub tau_synapse: f32,
34 pub refractory_period: usize,
36 pub spike_history: HashMap<usize, VecDeque<f32>>,
38 pub last_spike_times: Vec<Option<usize>>,
40 pub current_time: usize,
42 pub learning_rate: f32,
44 pub stdp_params: STDPParameters,
46}
47
/// Parameters of the pair-based spike-timing-dependent plasticity (STDP) rule.
///
/// For a pre/post spike pair separated by `dt` steps, potentiation contributes
/// `a_plus * exp(-dt / tau_pre)` when the pre-synaptic spike came first, and
/// depression contributes `a_minus * exp(-|dt| / tau_post)` when it came second.
#[derive(Debug, Clone, PartialEq)]
pub struct STDPParameters {
    /// Decay time constant (in time steps) of the potentiation window.
    pub tau_pre: f32,
    /// Decay time constant (in time steps) of the depression window.
    pub tau_post: f32,
    /// Potentiation amplitude (pre-synaptic spike precedes post-synaptic spike).
    pub a_plus: f32,
    /// Depression amplitude (post-synaptic spike precedes pre-synaptic spike).
    pub a_minus: f32,
    /// Scale applied to the summed weight update.
    pub learning_rate: f32,
}

impl STDPParameters {
    /// Creates the default parameter set:
    /// `tau_pre = tau_post = 20`, `a_plus = 0.1`, `a_minus = 0.12`, `learning_rate = 0.01`.
    pub fn new() -> Self {
        Self {
            tau_pre: 20.0,
            tau_post: 20.0,
            a_plus: 0.1,
            a_minus: 0.12,
            learning_rate: 0.01,
        }
    }
}

impl Default for STDPParameters {
    /// Same as [`STDPParameters::new`].
    fn default() -> Self {
        Self::new()
    }
}
74
75impl SpikingGraphNetwork {
76 pub fn new(
78 num_nodes: usize,
79 input_dim: usize,
80 hidden_dim: usize,
81 ) -> Result<Self, Box<dyn std::error::Error>> {
82 let membrane_potentials = zeros(&[num_nodes, hidden_dim])?;
83 let synaptic_weights = randn(&[num_nodes, num_nodes])?.mul_scalar(0.1)?;
84
85 let mut spike_history = HashMap::new();
86 for i in 0..num_nodes {
87 spike_history.insert(i, VecDeque::new());
88 }
89
90 Ok(Self {
91 num_nodes,
92 input_dim,
93 hidden_dim,
94 membrane_potentials,
95 synaptic_weights,
96 spike_threshold: 1.0,
97 tau_membrane: 20.0,
98 tau_synapse: 5.0,
99 refractory_period: 2,
100 spike_history,
101 last_spike_times: vec![None; num_nodes],
102 current_time: 0,
103 learning_rate: 0.01,
104 stdp_params: STDPParameters::new(),
105 })
106 }
107
108 pub fn forward_spike(
110 &mut self,
111 graph: &GraphData,
112 input_spikes: &Tensor,
113 ) -> Result<SpikingOutput, Box<dyn std::error::Error>> {
114 let _output_spikes = zeros::<f32>(&[self.num_nodes])?;
115 let spike_times = Vec::new();
116
117 self.update_membrane_potentials(input_spikes)?;
119
120 let spikes = self.generate_spikes()?;
122
123 let propagated_spikes = self.propagate_spikes(&spikes, graph)?;
125
126 self.apply_stdp_learning(&spikes)?;
128
129 self.update_spike_history(&spikes)?;
131
132 self.apply_refractory_period()?;
134
135 self.current_time += 1;
136
137 Ok(SpikingOutput {
138 spikes: propagated_spikes,
139 membrane_potentials: self.membrane_potentials.clone(),
140 spike_times,
141 firing_rates: self.compute_firing_rates()?,
142 })
143 }
144
145 fn update_membrane_potentials(
147 &mut self,
148 input_spikes: &Tensor,
149 ) -> Result<(), Box<dyn std::error::Error>> {
150 let decay_factor = (-1.0 / self.tau_membrane).exp();
152
153 self.membrane_potentials = self.membrane_potentials.mul_scalar(decay_factor)?;
155
156 let input_current = self.compute_input_current(input_spikes)?;
158 self.membrane_potentials = self.membrane_potentials.add(&input_current)?;
159
160 Ok(())
161 }
162
163 fn compute_input_current(
165 &self,
166 input_spikes: &Tensor,
167 ) -> Result<Tensor, Box<dyn std::error::Error>> {
168 let input_weights = randn(&[self.input_dim, self.hidden_dim])?.mul_scalar(0.5)?;
170
171 input_spikes
173 .matmul(&input_weights)
174 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
175 }
176
177 fn generate_spikes(&mut self) -> Result<Tensor, Box<dyn std::error::Error>> {
179 let mut spikes = zeros(&[self.num_nodes])?;
180 let membrane_data = self.membrane_potentials.to_vec()?;
181
182 for node in 0..self.num_nodes {
183 if let Some(last_spike_time) = self.last_spike_times[node] {
185 if self.current_time - last_spike_time < self.refractory_period {
186 continue;
187 }
188 }
189
190 let membrane_potential = membrane_data[node * self.hidden_dim]; if membrane_potential > self.spike_threshold {
193 spikes = self.set_spike(spikes, node, 1.0)?;
195 self.last_spike_times[node] = Some(self.current_time);
196
197 self.reset_membrane_potential(node)?;
199 }
200 }
201
202 Ok(spikes)
203 }
204
205 fn propagate_spikes(
207 &self,
208 spikes: &Tensor,
209 graph: &GraphData,
210 ) -> Result<Tensor, Box<dyn std::error::Error>> {
211 let edge_data = graph.edge_index.to_vec()?;
213 let num_edges = edge_data.len() / 2;
214
215 let mut propagated = spikes.clone();
216
217 for edge_idx in 0..num_edges {
219 let src_node = edge_data[edge_idx] as usize;
220 let dst_node = edge_data[edge_idx + num_edges] as usize;
221
222 if src_node < self.num_nodes && dst_node < self.num_nodes {
223 let weight = self.get_synaptic_weight(src_node, dst_node)?;
225
226 let src_spike = self.get_spike_value(spikes, src_node)?;
228 if src_spike > 0.0 {
229 let propagated_value = src_spike * weight;
230 propagated =
231 self.add_spike_contribution(propagated, dst_node, propagated_value)?;
232 }
233 }
234 }
235
236 Ok(propagated)
237 }
238
239 fn apply_stdp_learning(&mut self, spikes: &Tensor) -> Result<(), Box<dyn std::error::Error>> {
241 let spike_data = spikes.to_vec()?;
242
243 for pre_node in 0..self.num_nodes {
244 for post_node in 0..self.num_nodes {
245 if pre_node == post_node {
246 continue;
247 }
248
249 if let (Some(pre_history), Some(post_history)) = (
251 self.spike_history.get(&pre_node),
252 self.spike_history.get(&post_node),
253 ) {
254 let weight_update = self.calculate_stdp_update(
256 pre_history,
257 post_history,
258 spike_data[pre_node],
259 spike_data[post_node],
260 );
261
262 self.update_synaptic_weight(pre_node, post_node, weight_update)?;
264 }
265 }
266 }
267
268 Ok(())
269 }
270
271 fn calculate_stdp_update(
273 &self,
274 pre_history: &VecDeque<f32>,
275 post_history: &VecDeque<f32>,
276 current_pre_spike: f32,
277 current_post_spike: f32,
278 ) -> f32 {
279 let mut weight_update = 0.0;
280
281 if current_pre_spike > 0.0 && current_post_spike > 0.0 {
283 weight_update += self.stdp_params.a_plus * 0.1;
285 }
286
287 for (i, &pre_spike) in pre_history.iter().rev().enumerate() {
289 for (j, &post_spike) in post_history.iter().rev().enumerate() {
290 if pre_spike > 0.0 && post_spike > 0.0 {
291 let dt = (i as f32) - (j as f32);
292
293 if dt > 0.0 {
294 let strength =
296 self.stdp_params.a_plus * (-dt / self.stdp_params.tau_pre).exp();
297 weight_update += strength;
298 } else if dt < 0.0 {
299 let strength =
301 self.stdp_params.a_minus * (dt / self.stdp_params.tau_post).exp();
302 weight_update -= strength;
303 }
304 }
305 }
306 }
307
308 weight_update * self.stdp_params.learning_rate
309 }
310
311 fn update_spike_history(&mut self, spikes: &Tensor) -> Result<(), Box<dyn std::error::Error>> {
313 let spike_data = spikes.to_vec()?;
314
315 for node in 0..self.num_nodes {
316 if let Some(history) = self.spike_history.get_mut(&node) {
317 history.push_back(spike_data[node]);
318
319 if history.len() > 100 {
321 history.pop_front();
322 }
323 }
324 }
325
326 Ok(())
327 }
328
329 fn apply_refractory_period(&mut self) -> Result<(), Box<dyn std::error::Error>> {
331 for node in 0..self.num_nodes {
333 if let Some(last_spike_time) = self.last_spike_times[node] {
334 if self.current_time - last_spike_time < self.refractory_period {
335 self.set_membrane_potential(node, 0.0)?;
336 }
337 }
338 }
339
340 Ok(())
341 }
342
343 fn compute_firing_rates(&self) -> Result<Tensor, Box<dyn std::error::Error>> {
345 let mut firing_rates = zeros(&[self.num_nodes])?;
346 let window_size = 100; for node in 0..self.num_nodes {
349 if let Some(history) = self.spike_history.get(&node) {
350 let recent_spikes: f32 = history.iter().rev().take(window_size).sum();
351 let rate = recent_spikes / window_size as f32;
352 firing_rates = self.set_firing_rate(firing_rates, node, rate)?;
353 }
354 }
355
356 Ok(firing_rates)
357 }
358
359 fn set_spike(
362 &self,
363 spikes: Tensor,
364 _node: usize,
365 _value: f32,
366 ) -> Result<Tensor, Box<dyn std::error::Error>> {
367 Ok(spikes)
369 }
370
371 fn reset_membrane_potential(&mut self, node: usize) -> Result<(), Box<dyn std::error::Error>> {
372 self.set_membrane_potential(node, -0.7)?;
374 Ok(())
375 }
376
377 fn set_membrane_potential(
378 &mut self,
379 _node: usize,
380 _value: f32,
381 ) -> Result<(), Box<dyn std::error::Error>> {
382 Ok(())
384 }
385
386 fn get_synaptic_weight(
387 &self,
388 _src: usize,
389 _dst: usize,
390 ) -> Result<f32, Box<dyn std::error::Error>> {
391 Ok(0.1)
393 }
394
395 fn update_synaptic_weight(
396 &mut self,
397 _src: usize,
398 _dst: usize,
399 _update: f32,
400 ) -> Result<(), Box<dyn std::error::Error>> {
401 Ok(())
403 }
404
405 fn get_spike_value(
406 &self,
407 _spikes: &Tensor,
408 _node: usize,
409 ) -> Result<f32, Box<dyn std::error::Error>> {
410 Ok(0.0)
412 }
413
414 fn add_spike_contribution(
415 &self,
416 spikes: Tensor,
417 _node: usize,
418 _value: f32,
419 ) -> Result<Tensor, Box<dyn std::error::Error>> {
420 Ok(spikes)
422 }
423
424 fn set_firing_rate(
425 &self,
426 rates: Tensor,
427 _node: usize,
428 _rate: f32,
429 ) -> Result<Tensor, Box<dyn std::error::Error>> {
430 Ok(rates)
432 }
433}
434
435#[derive(Debug, Clone)]
437pub struct SpikingOutput {
438 pub spikes: Tensor,
440 pub membrane_potentials: Tensor,
442 pub spike_times: Vec<f32>,
444 pub firing_rates: Tensor,
446}
447
448#[derive(Debug)]
450pub struct EventDrivenGraphProcessor {
451 pub event_queue: VecDeque<GraphEvent>,
453 pub node_states: HashMap<usize, NodeState>,
455 pub processing_stats: EventProcessingStats,
457 pub energy_tracker: EnergyTracker,
459}
460
461#[derive(Debug, Clone)]
463pub struct GraphEvent {
464 pub timestamp: f64,
466 pub source_node: usize,
468 pub target_node: usize,
470 pub event_type: EventType,
472 pub data: f32,
474 pub priority: u8,
476}
477
/// Kind of a [`GraphEvent`], which selects how the processor handles it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventType {
    /// Adds `data` to the target's membrane potential and may make it fire.
    Spike,
    /// Adds `data` to the target's accumulated charge.
    FeatureUpdate,
    /// Weight change; currently only recorded for energy accounting.
    WeightUpdate,
    /// Replaces the target's firing threshold with `data`.
    ThresholdUpdate,
    /// Graph topology change; currently a no-op.
    TopologyChange,
}
491
/// Dynamic state of one node in an [`EventDrivenGraphProcessor`].
#[derive(Debug, Clone, PartialEq)]
pub struct NodeState {
    /// Membrane potential; incoming spike payloads are added to it.
    pub membrane_potential: f32,
    /// Event time at which this state was last modified.
    pub last_update: f64,
    /// Charge accumulated from feature-update events.
    pub charge: f32,
    /// Potential at or above which the node fires.
    pub threshold: f32,
    /// Incoming spikes are ignored while the current time is below this value.
    pub refractory_until: f64,
    /// Energy attributed to this node.
    // NOTE(review): not updated by the processor code visible here.
    pub energy_consumed: f32,
}
508
509impl EventDrivenGraphProcessor {
510 pub fn new(num_nodes: usize) -> Self {
512 let mut node_states = HashMap::new();
513 for i in 0..num_nodes {
514 node_states.insert(
515 i,
516 NodeState {
517 membrane_potential: -0.7,
518 last_update: 0.0,
519 charge: 0.0,
520 threshold: 1.0,
521 refractory_until: 0.0,
522 energy_consumed: 0.0,
523 },
524 );
525 }
526
527 Self {
528 event_queue: VecDeque::new(),
529 node_states,
530 processing_stats: EventProcessingStats::new(),
531 energy_tracker: EnergyTracker::new(),
532 }
533 }
534
535 pub fn process_events(&mut self, current_time: f64) -> Vec<GraphEvent> {
537 let mut generated_events = Vec::new();
538 let mut events_processed = 0;
539
540 while let Some(event) = self.event_queue.pop_front() {
541 if event.timestamp > current_time {
542 self.event_queue.push_front(event);
544 break;
545 }
546
547 let new_events = self.process_single_event(&event, current_time);
549 generated_events.extend(new_events);
550 events_processed += 1;
551
552 self.energy_tracker.record_event_processing();
554 }
555
556 self.processing_stats.events_processed += events_processed;
557 generated_events
558 }
559
560 fn process_single_event(&mut self, event: &GraphEvent, current_time: f64) -> Vec<GraphEvent> {
562 let mut new_events = Vec::new();
563
564 match event.event_type {
565 EventType::Spike => {
566 new_events.extend(self.process_spike_event(event, current_time));
567 }
568 EventType::FeatureUpdate => {
569 self.process_feature_update(event, current_time);
570 }
571 EventType::WeightUpdate => {
572 self.process_weight_update(event, current_time);
573 }
574 EventType::ThresholdUpdate => {
575 self.process_threshold_update(event, current_time);
576 }
577 EventType::TopologyChange => {
578 new_events.extend(self.process_topology_change(event, current_time));
579 }
580 }
581
582 new_events
583 }
584
585 fn process_spike_event(&mut self, event: &GraphEvent, current_time: f64) -> Vec<GraphEvent> {
587 let mut new_events = Vec::new();
588
589 if let Some(target_state) = self.node_states.get_mut(&event.target_node) {
590 if current_time < target_state.refractory_until {
592 return new_events;
593 }
594
595 target_state.membrane_potential += event.data;
597 target_state.last_update = current_time;
598
599 if target_state.membrane_potential >= target_state.threshold {
601 target_state.membrane_potential = -0.7; target_state.refractory_until = current_time + 0.002; let spike_event = GraphEvent {
607 timestamp: current_time + 0.001, source_node: event.target_node,
609 target_node: 0, event_type: EventType::Spike,
611 data: 1.0,
612 priority: 1,
613 };
614
615 new_events.push(spike_event);
616
617 self.energy_tracker.record_spike();
619 }
620 }
621
622 new_events
623 }
624
625 fn process_feature_update(&mut self, event: &GraphEvent, current_time: f64) {
626 if let Some(node_state) = self.node_states.get_mut(&event.target_node) {
627 node_state.charge += event.data;
629 node_state.last_update = current_time;
630 }
631 }
632
633 fn process_weight_update(&mut self, _event: &GraphEvent, _current_time: f64) {
634 self.energy_tracker.record_weight_update();
636 }
637
638 fn process_threshold_update(&mut self, event: &GraphEvent, current_time: f64) {
639 if let Some(node_state) = self.node_states.get_mut(&event.target_node) {
640 node_state.threshold = event.data;
641 node_state.last_update = current_time;
642 }
643 }
644
645 fn process_topology_change(
646 &mut self,
647 _event: &GraphEvent,
648 _current_time: f64,
649 ) -> Vec<GraphEvent> {
650 vec![]
652 }
653
654 pub fn add_event(&mut self, event: GraphEvent) {
656 let insert_pos = self
658 .event_queue
659 .iter()
660 .position(|e| e.timestamp > event.timestamp)
661 .unwrap_or(self.event_queue.len());
662
663 self.event_queue.insert(insert_pos, event);
664 }
665}
666
/// Aggregate counters for an [`EventDrivenGraphProcessor`].
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EventProcessingStats {
    /// Total number of events handled by `process_events`.
    pub events_processed: usize,
    /// Number of spikes emitted by nodes.
    pub spikes_generated: usize,
    /// Mean processing time per event (units not established here).
    // NOTE(review): no visible code updates this value.
    pub average_processing_time: f64,
    /// Peak length of the event queue.
    pub queue_length_max: usize,
}

impl EventProcessingStats {
    /// Creates a stats record with all counters at zero (same as `Default`).
    pub fn new() -> Self {
        Self {
            events_processed: 0,
            spikes_generated: 0,
            average_processing_time: 0.0,
            queue_length_max: 0,
        }
    }
}
686
/// Accumulates an estimated energy cost for spikes, weight updates and processed events.
///
/// NOTE(review): `total_energy` is an `f32` running sum of very small increments.
/// Increments smaller than roughly 1e-7 of the current total are lost to rounding, so
/// long runs under-count the cheaper operations. Widening to `f64` would change the
/// public field type.
#[derive(Debug, Clone, PartialEq)]
pub struct EnergyTracker {
    /// Running sum of all recorded energy costs.
    pub total_energy: f32,
    /// Cost added per recorded spike (default `1e-12`; presumably joules — confirm).
    pub energy_per_spike: f32,
    /// Cost added per recorded weight update (default `1e-15`).
    pub energy_per_weight_update: f32,
    /// Cost added per processed event (default `1e-15`).
    pub energy_per_event: f32,
    /// Number of spikes recorded.
    pub spike_count: usize,
    /// Number of weight updates recorded.
    pub weight_update_count: usize,
    /// Number of processed events recorded.
    pub event_count: usize,
}

impl EnergyTracker {
    /// Creates a tracker with zero totals and the default per-operation costs.
    pub fn new() -> Self {
        Self {
            total_energy: 0.0,
            energy_per_spike: 1e-12,
            energy_per_weight_update: 1e-15,
            energy_per_event: 1e-15,
            spike_count: 0,
            weight_update_count: 0,
            event_count: 0,
        }
    }

    /// Adds the cost of one spike.
    pub fn record_spike(&mut self) {
        self.total_energy += self.energy_per_spike;
        self.spike_count += 1;
    }

    /// Adds the cost of one weight update.
    pub fn record_weight_update(&mut self) {
        self.total_energy += self.energy_per_weight_update;
        self.weight_update_count += 1;
    }

    /// Adds the cost of processing one event.
    pub fn record_event_processing(&mut self) {
        self.total_energy += self.energy_per_event;
        self.event_count += 1;
    }

    /// Returns the total recorded energy divided by the number of processed events.
    ///
    /// The numerator also includes spike and weight-update energy. Returns `0.0` when
    /// no events have been recorded.
    pub fn get_energy_efficiency(&self) -> f32 {
        match self.event_count {
            0 => 0.0,
            count => self.total_energy / count as f32,
        }
    }
}

impl Default for EnergyTracker {
    /// Same as [`EnergyTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
740
741#[derive(Debug, Clone)]
743pub struct LiquidStateMachine {
744 pub reservoir_size: usize,
746 pub connection_prob: f32,
748 pub spectral_radius: f32,
750 pub input_scaling: f32,
752 pub leak_rate: f32,
754 pub state: Tensor,
756 pub input_weights: Tensor,
758 pub reservoir_weights: Tensor,
760 pub memory_capacity: usize,
762 pub state_history: VecDeque<Tensor>,
764}
765
766impl LiquidStateMachine {
767 pub fn new(
769 input_dim: usize,
770 reservoir_size: usize,
771 connection_prob: f32,
772 ) -> Result<Self, Box<dyn std::error::Error>> {
773 let input_weights = randn(&[input_dim, reservoir_size])?.mul_scalar(0.1)?;
774 let reservoir_weights = Self::create_sparse_reservoir(reservoir_size, connection_prob)?;
775 let state = zeros(&[reservoir_size])?;
776
777 Ok(Self {
778 reservoir_size,
779 connection_prob,
780 spectral_radius: 0.9,
781 input_scaling: 1.0,
782 leak_rate: 0.3,
783 state,
784 input_weights,
785 reservoir_weights,
786 memory_capacity: 100,
787 state_history: VecDeque::new(),
788 })
789 }
790
791 pub fn process(&mut self, input: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
793 let reservoir_input = input.matmul(&self.input_weights)?;
795
796 let reservoir_activation = self.state.matmul(&self.reservoir_weights)?;
798 let total_input = reservoir_input.add(&reservoir_activation)?;
799
800 let activated = self.apply_tanh(&total_input)?;
802
803 let leak_complement = 1.0 - self.leak_rate;
805 self.state = self
806 .state
807 .mul_scalar(leak_complement)?
808 .add(&activated.mul_scalar(self.leak_rate)?)?;
809
810 self.state_history.push_back(self.state.clone());
812 if self.state_history.len() > self.memory_capacity {
813 self.state_history.pop_front();
814 }
815
816 Ok(self.state.clone())
817 }
818
819 fn create_sparse_reservoir(
820 size: usize,
821 prob: f32,
822 ) -> Result<Tensor, Box<dyn std::error::Error>> {
823 let mut weights = randn(&[size, size])?;
825
826 weights = weights.mul_scalar(prob)?;
828
829 Ok(weights)
830 }
831
832 fn apply_tanh(&self, tensor: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
833 Ok(tensor.clone())
835 }
836}
837
838#[derive(Debug)]
840pub struct NeuromorphicGraphLayer {
841 pub spiking_network: SpikingGraphNetwork,
843 pub event_processor: EventDrivenGraphProcessor,
845 pub liquid_state_machine: LiquidStateMachine,
847 pub processing_mode: NeuromorphicMode,
849}
850
/// Processing strategy selected for a [`NeuromorphicGraphLayer`].
///
/// NOTE(review): the mode is stored by `set_mode`, but no visible code branches on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeuromorphicMode {
    /// Intended to use the discrete-time spiking network.
    Spiking,
    /// Intended to use the event-driven processor.
    EventDriven,
    /// Intended to use the liquid state machine.
    LiquidState,
    /// Intended to combine the sub-models; the mode a new layer starts in.
    Hybrid,
}
862
863impl NeuromorphicGraphLayer {
864 pub fn new(
865 num_nodes: usize,
866 input_dim: usize,
867 hidden_dim: usize,
868 ) -> Result<Self, Box<dyn std::error::Error>> {
869 let spiking_network = SpikingGraphNetwork::new(num_nodes, input_dim, hidden_dim)?;
870 let event_processor = EventDrivenGraphProcessor::new(num_nodes);
871 let liquid_state_machine = LiquidStateMachine::new(input_dim, hidden_dim, 0.1)?;
872
873 Ok(Self {
874 spiking_network,
875 event_processor,
876 liquid_state_machine,
877 processing_mode: NeuromorphicMode::Hybrid,
878 })
879 }
880
881 pub fn set_mode(&mut self, mode: NeuromorphicMode) {
883 self.processing_mode = mode;
884 }
885}
886
887impl GraphLayer for NeuromorphicGraphLayer {
888 fn forward(&self, graph: &GraphData) -> GraphData {
889 graph.clone()
892 }
893
894 fn parameters(&self) -> Vec<Tensor> {
895 vec![
896 self.spiking_network.synaptic_weights.clone(),
897 self.liquid_state_machine.input_weights.clone(),
898 self.liquid_state_machine.reservoir_weights.clone(),
899 ]
900 }
901}
902
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a spike event for the event-driven processor tests.
    fn make_event(timestamp: f64, target_node: usize, data: f32) -> GraphEvent {
        GraphEvent {
            timestamp,
            source_node: 0,
            target_node,
            event_type: EventType::Spike,
            data,
            priority: 1,
        }
    }

    #[test]
    fn test_spiking_network_creation() {
        let network = SpikingGraphNetwork::new(10, 5, 8);
        assert!(network.is_ok());

        let net = network.unwrap();
        assert_eq!(net.num_nodes, 10);
        assert_eq!(net.input_dim, 5);
        assert_eq!(net.hidden_dim, 8);
        assert_eq!(net.spike_threshold, 1.0);
    }

    #[test]
    fn test_stdp_parameters() {
        let stdp = STDPParameters::new();
        assert_eq!(stdp.tau_pre, 20.0);
        assert_eq!(stdp.tau_post, 20.0);
        assert_eq!(stdp.a_plus, 0.1);
        assert_eq!(stdp.a_minus, 0.12);
    }

    #[test]
    fn test_event_driven_processor() {
        let processor = EventDrivenGraphProcessor::new(5);
        assert_eq!(processor.node_states.len(), 5);
        assert_eq!(processor.event_queue.len(), 0);
    }

    #[test]
    fn test_graph_event_creation() {
        let event = GraphEvent {
            timestamp: 1.0,
            source_node: 0,
            target_node: 1,
            event_type: EventType::Spike,
            data: 1.0,
            priority: 1,
        };

        assert_eq!(event.timestamp, 1.0);
        assert_eq!(event.source_node, 0);
        assert_eq!(event.target_node, 1);
    }

    #[test]
    fn test_energy_tracker() {
        let mut tracker = EnergyTracker::new();
        tracker.record_spike();
        tracker.record_weight_update();

        assert_eq!(tracker.spike_count, 1);
        assert_eq!(tracker.weight_update_count, 1);
        assert!(tracker.total_energy > 0.0);
    }

    #[test]
    fn test_energy_efficiency_is_zero_without_events() {
        assert_eq!(EnergyTracker::new().get_energy_efficiency(), 0.0);
    }

    #[test]
    fn test_liquid_state_machine() {
        let lsm = LiquidStateMachine::new(3, 10, 0.1);
        assert!(lsm.is_ok());

        let machine = lsm.unwrap();
        assert_eq!(machine.reservoir_size, 10);
        assert_eq!(machine.connection_prob, 0.1);
        assert_eq!(machine.spectral_radius, 0.9);
    }

    #[test]
    fn test_neuromorphic_layer_creation() {
        let layer = NeuromorphicGraphLayer::new(5, 3, 8);
        assert!(layer.is_ok());

        let neuromorphic_layer = layer.unwrap();
        assert_eq!(neuromorphic_layer.spiking_network.num_nodes, 5);
    }

    #[test]
    fn test_node_state() {
        let state = NodeState {
            membrane_potential: -0.7,
            last_update: 0.0,
            charge: 0.0,
            threshold: 1.0,
            refractory_until: 0.0,
            energy_consumed: 0.0,
        };

        assert_eq!(state.membrane_potential, -0.7);
        assert_eq!(state.threshold, 1.0);
    }

    #[test]
    fn test_add_event_keeps_timestamp_order() {
        let mut processor = EventDrivenGraphProcessor::new(3);
        for &timestamp in &[3.0, 1.0, 2.0] {
            processor.add_event(make_event(timestamp, 0, 0.1));
        }

        let times: Vec<f64> = processor.event_queue.iter().map(|e| e.timestamp).collect();
        assert_eq!(times, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_add_event_is_fifo_for_equal_timestamps() {
        let mut processor = EventDrivenGraphProcessor::new(2);
        processor.add_event(make_event(1.0, 0, 0.1));
        processor.add_event(make_event(1.0, 0, 0.2));

        let payloads: Vec<f32> = processor.event_queue.iter().map(|e| e.data).collect();
        assert_eq!(payloads, vec![0.1, 0.2]);
    }

    #[test]
    fn test_process_events_defers_future_events() {
        let mut processor = EventDrivenGraphProcessor::new(3);
        processor.add_event(make_event(0.5, 1, 0.5));
        processor.add_event(make_event(2.0, 1, 0.5));

        let generated = processor.process_events(1.0);

        assert!(generated.is_empty());
        assert_eq!(processor.event_queue.len(), 1);
        assert_eq!(processor.event_queue[0].timestamp, 2.0);
        assert_eq!(processor.processing_stats.events_processed, 1);
        assert!((processor.node_states[&1].membrane_potential - (-0.2)).abs() < 1e-6);
    }

    #[test]
    fn test_spike_crossing_threshold_emits_event_and_resets() {
        let mut processor = EventDrivenGraphProcessor::new(3);
        processor.add_event(make_event(0.5, 1, 2.0));

        let generated = processor.process_events(1.0);

        assert_eq!(generated.len(), 1);
        assert_eq!(generated[0].source_node, 1);
        assert!(matches!(generated[0].event_type, EventType::Spike));
        assert!(generated[0].timestamp > 1.0);

        let state = &processor.node_states[&1];
        assert_eq!(state.membrane_potential, -0.7);
        assert!(state.refractory_until > 1.0);
        assert_eq!(processor.energy_tracker.spike_count, 1);
        assert_eq!(processor.energy_tracker.event_count, 1);
    }

    #[test]
    fn test_refractory_node_ignores_spikes() {
        let mut processor = EventDrivenGraphProcessor::new(3);
        processor.add_event(make_event(1.0, 1, 2.0));
        processor.process_events(1.0);

        // Arrives inside the refractory window that opened when node 1 fired at t = 1.0.
        processor.add_event(make_event(1.001, 1, 5.0));
        let generated = processor.process_events(1.001);

        assert!(generated.is_empty());
        assert_eq!(processor.node_states[&1].membrane_potential, -0.7);
    }

    #[test]
    fn test_threshold_and_feature_updates() {
        let mut processor = EventDrivenGraphProcessor::new(2);
        let mut event = make_event(0.0, 0, 0.25);
        event.event_type = EventType::ThresholdUpdate;
        processor.add_event(event.clone());
        event.event_type = EventType::FeatureUpdate;
        processor.add_event(event);

        processor.process_events(0.0);

        let state = &processor.node_states[&0];
        assert_eq!(state.threshold, 0.25);
        assert_eq!(state.charge, 0.25);
    }
}