1use super::{
12 utils, StreamingConfig, StreamingDataPoint, StreamingObjective, StreamingOptimizer,
13 StreamingStats,
14};
15use crate::error::OptimizeError;
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
17use std::collections::{HashMap, VecDeque};
24use std::time::{Duration, Instant};
25
/// Module-local result alias whose error type is [`OptimizeError`].
type Result<T> = std::result::Result<T, OptimizeError>;
27
/// Consensus-related state held by one node of the distributed optimizer.
#[derive(Debug, Clone)]
pub struct DistributedConsensusNode {
    /// Identifier of this node; used both as its proposal id and as its voter id.
    pub node_id: usize,
    /// Parameters moved by this node's own gradient steps and asynchronous updates.
    pub local_parameters: Array1<f64>,
    /// Parameters obtained by blending winning consensus proposals with the local parameters.
    pub consensus_parameters: Array1<f64>,
    /// Trust score per peer; reduced when Byzantine alerts about that peer are received.
    pub trust_scores: HashMap<usize, f64>,
    /// Detector that flags peers whose proposals deviate persistently from the consensus.
    pub byzantine_detector: ByzantineFaultDetector,
    /// Message history per peer. NOTE(review): never populated in this file — confirm intended use.
    pub peer_history: HashMap<usize, VecDeque<ConsensusMessage>>,
    /// Running sum of every gradient this node has computed.
    pub gradient_accumulator: Array1<f64>,
    /// Proposals and votes of the current consensus round.
    pub voting_state: ConsensusVotingState,
    /// Connectivity, delay and bandwidth information for the node network.
    pub network_topology: NetworkTopology,
}
50
51#[derive(Debug, Clone)]
53pub struct ByzantineFaultDetector {
54 pub reputation_scores: HashMap<usize, f64>,
56 pub suspicion_counters: HashMap<usize, usize>,
58 pub deviation_history: HashMap<usize, VecDeque<f64>>,
60 pub fault_threshold: f64,
62 pub recovery_period: Duration,
64 pub last_detection_times: HashMap<usize, Instant>,
66}
67
68impl ByzantineFaultDetector {
69 pub fn new(_faultthreshold: f64) -> Self {
70 Self {
71 reputation_scores: HashMap::new(),
72 suspicion_counters: HashMap::new(),
73 deviation_history: HashMap::new(),
74 fault_threshold: _faultthreshold,
75 recovery_period: Duration::from_secs(300), last_detection_times: HashMap::new(),
77 }
78 }
79
80 pub fn detect_byzantine_behavior(
82 &mut self,
83 node_id: usize,
84 proposed_params: &ArrayView1<f64>,
85 consensus_params: &ArrayView1<f64>,
86 current_time: Instant,
87 ) -> bool {
88 let deviation = (proposed_params - consensus_params).mapv(|x| x.abs()).sum()
90 / proposed_params.len() as f64;
91
92 let history = self
94 .deviation_history
95 .entry(node_id)
96 .or_insert_with(|| VecDeque::with_capacity(100));
97 history.push_back(deviation);
98 if history.len() > 100 {
99 history.pop_front();
100 }
101
102 if deviation > self.fault_threshold {
104 let suspicion = self.suspicion_counters.entry(node_id).or_insert(0);
105 *suspicion += 1;
106
107 let reputation = self.reputation_scores.entry(node_id).or_insert(1.0);
109 *reputation *= 0.85;
110
111 if *suspicion > 5 && *reputation < 0.3 {
113 self.last_detection_times.insert(node_id, current_time);
114 return true;
115 }
116 } else {
117 let reputation = self.reputation_scores.entry(node_id).or_insert(1.0);
119 *reputation = (*reputation + 0.01).min(1.0);
120
121 if let Some(suspicion) = self.suspicion_counters.get_mut(&node_id) {
123 *suspicion = suspicion.saturating_sub(1);
124 }
125 }
126
127 false
128 }
129
130 pub fn is_byzantine_suspected(&self, node_id: usize, currenttime: Instant) -> bool {
132 if let Some(&last_detection) = self.last_detection_times.get(&node_id) {
133 if currenttime.duration_since(last_detection) < self.recovery_period {
134 return true;
135 }
136 }
137 false
138 }
139
140 pub fn get_trust_weight(&self, nodeid: usize) -> f64 {
142 self.reputation_scores.get(&nodeid).copied().unwrap_or(1.0)
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct ConsensusVotingState {
149 pub round: usize,
151 pub proposals: HashMap<usize, Array1<f64>>,
153 pub votes: HashMap<usize, Vec<usize>>, pub voting_weights: HashMap<usize, f64>,
157 pub consensus_threshold: f64,
159 pub round_timeout: Duration,
161 pub round_start: Option<Instant>,
163}
164
165impl ConsensusVotingState {
166 pub fn new(_consensusthreshold: f64) -> Self {
167 Self {
168 round: 0,
169 proposals: HashMap::new(),
170 votes: HashMap::new(),
171 voting_weights: HashMap::new(),
172 consensus_threshold: _consensusthreshold,
173 round_timeout: Duration::from_millis(100),
174 round_start: None,
175 }
176 }
177
178 pub fn start_round(&mut self) {
180 self.round += 1;
181 self.proposals.clear();
182 self.votes.clear();
183 self.round_start = Some(Instant::now());
184 }
185
186 pub fn add_proposal(&mut self, nodeid: usize, parameters: Array1<f64>) {
188 self.proposals.insert(nodeid, parameters);
189 }
190
191 pub fn vote(&mut self, voter_id: usize, proposalid: usize, weight: f64) {
193 self.voting_weights.insert(voter_id, weight);
194 self.votes.entry(proposalid).or_default().push(voter_id);
195 }
196
197 pub fn check_consensus(&self) -> Option<(usize, Array1<f64>)> {
199 let mut best_proposal = None;
200 let mut best_weight = 0.0;
201
202 for (&proposal_id, voters) in &self.votes {
203 let total_weight: f64 = voters
204 .iter()
205 .map(|&voter| self.voting_weights.get(&voter).copied().unwrap_or(1.0))
206 .sum();
207
208 if total_weight > best_weight && total_weight >= self.consensus_threshold {
209 best_weight = total_weight;
210 if let Some(params) = self.proposals.get(&proposal_id) {
211 best_proposal = Some((proposal_id, params.clone()));
212 }
213 }
214 }
215
216 best_proposal
217 }
218
219 pub fn is_timeout(&self) -> bool {
221 if let Some(start) = self.round_start {
222 start.elapsed() > self.round_timeout
223 } else {
224 false
225 }
226 }
227}
228
229#[derive(Debug, Clone)]
231pub struct NetworkTopology {
232 pub adjacency_matrix: Array2<f64>,
234 pub delay_matrix: Array2<f64>,
236 pub bandwidth_matrix: Array2<f64>,
238 pub active_connections: HashMap<usize, Vec<usize>>,
240 pub reliability_scores: HashMap<usize, f64>,
242}
243
244impl NetworkTopology {
245 pub fn new(_numnodes: usize) -> Self {
246 Self {
247 adjacency_matrix: Array2::zeros((_numnodes, _numnodes)),
248 delay_matrix: Array2::zeros((_numnodes, _numnodes)),
249 bandwidth_matrix: Array2::from_elem((_numnodes, _numnodes), 1.0),
250 active_connections: HashMap::new(),
251 reliability_scores: HashMap::new(),
252 }
253 }
254
255 pub fn add_connection(&mut self, node1: usize, node2: usize, weight: f64, delay: f64) {
257 if node1 < self.adjacency_matrix.nrows() && node2 < self.adjacency_matrix.ncols() {
258 self.adjacency_matrix[[node1, node2]] = weight;
259 self.adjacency_matrix[[node2, node1]] = weight;
260 self.delay_matrix[[node1, node2]] = delay;
261 self.delay_matrix[[node2, node1]] = delay;
262
263 self.active_connections
264 .entry(node1)
265 .or_default()
266 .push(node2);
267 self.active_connections
268 .entry(node2)
269 .or_default()
270 .push(node1);
271 }
272 }
273
274 pub fn get_neighbors(&self, nodeid: usize) -> Vec<usize> {
276 self.active_connections
277 .get(&nodeid)
278 .cloned()
279 .unwrap_or_default()
280 }
281
282 pub fn compute_shortest_paths(&self) -> Array2<f64> {
284 let n = self.adjacency_matrix.nrows();
285 let mut dist = self.adjacency_matrix.clone();
286
287 for i in 0..n {
289 for j in 0..n {
290 if i != j && dist[[i, j]] == 0.0 {
291 dist[[i, j]] = f64::INFINITY;
292 }
293 }
294 }
295
296 for k in 0..n {
298 for i in 0..n {
299 for j in 0..n {
300 if dist[[i, k]] + dist[[k, j]] < dist[[i, j]] {
301 dist[[i, j]] = dist[[i, k]] + dist[[k, j]];
302 }
303 }
304 }
305 }
306
307 dist
308 }
309}
310
/// Messages exchanged between nodes during distributed consensus.
#[derive(Debug, Clone)]
pub enum ConsensusMessage {
    /// Parameters proposed by `node_id` for consensus round `round`.
    Proposal {
        round: usize,
        node_id: usize,
        parameters: Array1<f64>,
        timestamp: Instant,
    },
    /// A weighted vote by `voter_id` for the proposal made by node `proposal_id` in `round`.
    Vote {
        round: usize,
        voter_id: usize,
        proposal_id: usize,
        weight: f64,
        timestamp: Instant,
    },
    /// The parameters of the proposal that won a consensus round.
    ConsensusResult {
        round: usize,
        winning_proposal: usize,
        parameters: Array1<f64>,
        timestamp: Instant,
    },
    /// Heartbeat from `node_id`; on receipt the optimizer sets that node's reliability to 1.0.
    Heartbeat { node_id: usize, timestamp: Instant },
    /// A report from `reporter_node` that `suspected_node` may be Byzantine.
    ByzantineAlert {
        suspected_node: usize,
        reporter_node: usize,
        evidence: ByzantineEvidence,
        timestamp: Instant,
    },
}
346
/// Evidence attached to a [`ConsensusMessage::ByzantineAlert`].
#[derive(Debug, Clone)]
pub struct ByzantineEvidence {
    /// Reported deviation of the suspect; larger values reduce the receiver's trust in it more.
    pub deviation_magnitude: f64,
    /// Presumably how often the deviation was observed. NOTE(review): not read in this file.
    pub frequency_count: usize,
    /// Presumably the reporter's reputation score for the suspect. NOTE(review): not read in this file.
    pub reputation_score: f64,
}
354
/// Online gradient-descent optimizer that coordinates with peer nodes.
///
/// Coordination uses weighted-vote consensus, weighted federated gradient averaging, and
/// delayed (asynchronous) parameter updates.
#[derive(Debug, Clone)]
pub struct AdvancedAdvancedDistributedOnlineGD<T: StreamingObjective> {
    /// Local/consensus parameters, voting state, Byzantine detector and topology of this node.
    pub consensus_node: DistributedConsensusNode,
    /// Objective providing per-point gradients and losses.
    pub objective: T,
    /// Streaming configuration (learning rate, adaptive-LR flag, convergence tolerance, ...).
    pub config: StreamingConfig,
    /// Generic streaming statistics reported through `StreamingOptimizer::stats`.
    pub stats: StreamingStats,
    /// Statistics specific to distributed operation.
    pub distributed_stats: DistributedOptimizationStats,
    /// Per-parameter sum of squared gradients. NOTE(review): only allocated and reset in this file.
    pub gradient_sum_sq: Array1<f64>,
    /// Momentum buffer. NOTE(review): only allocated and reset in this file.
    pub momentum: Array1<f64>,
    /// State for weighted federated averaging of gradients.
    pub federated_state: FederatedAveragingState,
    /// Delayed parameter updates waiting for their `apply_at` time, processed front to back.
    pub async_update_queue: VecDeque<DelayedUpdate>,
    /// Incoming consensus messages, drained by `process_consensus_messages`.
    pub message_buffer: VecDeque<ConsensusMessage>,
    /// Clock-offset bookkeeping for peers. NOTE(review): not consulted by the update path here.
    pub sync_state: NetworkSynchronizationState,
}
381
/// Statistics describing how the distributed parts of the optimizer are behaving.
#[derive(Debug, Clone)]
pub struct DistributedOptimizationStats {
    /// Number of consensus results applied to this node's consensus parameters.
    pub consensus_rounds: usize,
    /// Exponential moving average of consensus-round success (starts at 1.0).
    pub consensus_success_rate: f64,
    /// Average time taken by a consensus round (initialised to 50 ms).
    pub avg_consensus_time: Duration,
    /// Number of Byzantine faults recorded.
    pub byzantine_faults_detected: usize,
    /// Number of network partitions observed. NOTE(review): not updated in this file.
    pub network_partitions: usize,
    /// Estimated communication overhead (initialised to 0.1). NOTE(review): not updated here.
    pub communication_overhead: f64,
    /// Convergence rate relative to a baseline (initialised to 1.0). NOTE(review): not updated here.
    pub relative_convergence_rate: f64,
}

impl Default for DistributedOptimizationStats {
    /// Starts with no rounds or faults, a perfect success rate and a 50 ms average round time.
    fn default() -> Self {
        let avg_consensus_time = Duration::from_millis(50);
        DistributedOptimizationStats {
            relative_convergence_rate: 1.0,
            communication_overhead: 0.1,
            network_partitions: 0,
            byzantine_faults_detected: 0,
            avg_consensus_time,
            consensus_success_rate: 1.0,
            consensus_rounds: 0,
        }
    }
}
414
415#[derive(Debug, Clone)]
417pub struct FederatedAveragingState {
418 pub peer_gradients: HashMap<usize, Array1<f64>>,
420 pub peer_weights: HashMap<usize, f64>,
422 pub peer_data_counts: HashMap<usize, usize>,
424 pub last_updates: HashMap<usize, Instant>,
426 pub federated_round: usize,
428 pub staleness_tolerance: Duration,
430}
431
432impl Default for FederatedAveragingState {
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
438impl FederatedAveragingState {
439 pub fn new() -> Self {
440 Self {
441 peer_gradients: HashMap::new(),
442 peer_weights: HashMap::new(),
443 peer_data_counts: HashMap::new(),
444 last_updates: HashMap::new(),
445 federated_round: 0,
446 staleness_tolerance: Duration::from_secs(10),
447 }
448 }
449
450 pub fn add_peer_gradient(&mut self, peer_id: usize, gradient: Array1<f64>, datacount: usize) {
452 self.peer_gradients.insert(peer_id, gradient);
453 self.peer_data_counts.insert(peer_id, datacount);
454 self.last_updates.insert(peer_id, Instant::now());
455
456 let total_data: usize = self.peer_data_counts.values().sum();
458 if total_data > 0 {
459 let weight = datacount as f64 / total_data as f64;
460 self.peer_weights.insert(peer_id, weight);
461 }
462 }
463
464 pub fn compute_federated_gradient(&self, currenttime: Instant) -> Option<Array1<f64>> {
466 if self.peer_gradients.is_empty() {
467 return None;
468 }
469
470 let mut weighted_sum = None;
471 let mut total_weight = 0.0;
472
473 for (&peer_id, gradient) in &self.peer_gradients {
474 if let Some(&last_update) = self.last_updates.get(&peer_id) {
476 if currenttime.duration_since(last_update) > self.staleness_tolerance {
477 continue; }
479 }
480
481 let weight = self.peer_weights.get(&peer_id).copied().unwrap_or(1.0);
482
483 if let Some(ref mut sum) = weighted_sum {
484 *sum = &*sum + &(weight * gradient);
485 } else {
486 weighted_sum = Some(weight * gradient);
487 }
488
489 total_weight += weight;
490 }
491
492 if let Some(sum) = weighted_sum {
493 if total_weight > 0.0 {
494 Some(sum / total_weight)
495 } else {
496 Some(sum)
497 }
498 } else {
499 None
500 }
501 }
502}
503
/// Parameters from another node, to be mixed into the local model no earlier than `apply_at`.
#[derive(Debug, Clone)]
pub struct DelayedUpdate {
    /// Presumably the node that produced the update. NOTE(review): not read in this file.
    pub source_node: usize,
    /// Parameter vector to blend into the local parameters.
    pub parameters: Array1<f64>,
    /// Creation time; the update's age determines how strongly it is down-weighted.
    pub timestamp: Instant,
    /// Earliest time at which the update may be applied.
    pub apply_at: Instant,
}
512
/// Per-node clock offsets used to translate the local clock into a node's reference time.
#[derive(Debug, Clone)]
pub struct NetworkSynchronizationState {
    /// Offset subtracted from the local clock, per node id.
    pub clock_offsets: HashMap<usize, Duration>,
    /// Target synchronisation accuracy (10 ms by default). NOTE(review): not consulted here.
    pub sync_accuracy: Duration,
    /// When synchronisation last happened (construction time by default).
    pub last_sync: Instant,
    /// How often synchronisation should be repeated (60 s by default).
    pub sync_period: Duration,
}

impl Default for NetworkSynchronizationState {
    fn default() -> Self {
        Self::new()
    }
}

impl NetworkSynchronizationState {
    /// Creates a state with no offsets, marked as synchronised now.
    pub fn new() -> Self {
        Self {
            clock_offsets: HashMap::new(),
            sync_accuracy: Duration::from_millis(10),
            last_sync: Instant::now(),
            sync_period: Duration::from_secs(60),
        }
    }

    /// Returns `true` once more than `sync_period` has passed since `last_sync`.
    pub fn needs_sync(&self) -> bool {
        self.last_sync.elapsed() > self.sync_period
    }

    /// Records (or replaces) the clock offset of `node_id`.
    pub fn update_clock_offset(&mut self, node_id: usize, offset: Duration) {
        self.clock_offsets.insert(node_id, offset);
    }

    /// Current time shifted back by `node_id`'s offset; the plain current time if it has none.
    ///
    /// `Instant - Duration` panics when the result cannot be represented. This happens, for
    /// example, with an offset larger than the platform clock's range. In that case the
    /// unshifted current time is returned instead.
    pub fn get_synchronized_time(&self, node_id: usize) -> Instant {
        let now = Instant::now();
        match self.clock_offsets.get(&node_id) {
            Some(&offset) => now.checked_sub(offset).unwrap_or(now),
            None => now,
        }
    }
}
562
563impl<T: StreamingObjective + Clone> AdvancedAdvancedDistributedOnlineGD<T> {
564 pub fn new(
566 node_id: usize,
567 initial_parameters: Array1<f64>,
568 objective: T,
569 config: StreamingConfig,
570 num_nodes: usize,
571 ) -> Self {
572 let n_params = initial_parameters.len();
573
574 let consensus_node = DistributedConsensusNode {
575 node_id,
576 local_parameters: initial_parameters.clone(),
577 consensus_parameters: initial_parameters.clone(),
578 trust_scores: HashMap::new(),
579 byzantine_detector: ByzantineFaultDetector::new(1.0),
580 peer_history: HashMap::new(),
581 gradient_accumulator: Array1::zeros(n_params),
582 voting_state: ConsensusVotingState::new(num_nodes as f64 * 0.67), network_topology: NetworkTopology::new(num_nodes),
584 };
585
586 Self {
587 consensus_node,
588 objective,
589 config,
590 stats: StreamingStats::default(),
591 distributed_stats: DistributedOptimizationStats::default(),
592 gradient_sum_sq: Array1::zeros(n_params),
593 momentum: Array1::zeros(n_params),
594 federated_state: FederatedAveragingState::new(),
595 async_update_queue: VecDeque::new(),
596 message_buffer: VecDeque::new(),
597 sync_state: NetworkSynchronizationState::new(),
598 }
599 }
600
601 pub fn setup_network_topology(&mut self, peerconnections: &[(usize, usize, f64, f64)]) {
603 for &(node1, node2, weight, delay) in peerconnections {
604 self.consensus_node
605 .network_topology
606 .add_connection(node1, node2, weight, delay);
607 }
608 }
609
610 pub fn process_consensus_messages(&mut self) -> Result<()> {
612 let current_time = Instant::now();
613
614 while let Some(message) = self.message_buffer.pop_front() {
615 match message {
616 ConsensusMessage::Proposal {
617 round,
618 node_id,
619 parameters,
620 timestamp: _,
621 } => {
622 if round == self.consensus_node.voting_state.round {
623 let is_byzantine = self
625 .consensus_node
626 .byzantine_detector
627 .detect_byzantine_behavior(
628 node_id,
629 ¶meters.view(),
630 &self.consensus_node.consensus_parameters.view(),
631 current_time,
632 );
633
634 if !is_byzantine {
635 self.consensus_node
636 .voting_state
637 .add_proposal(node_id, parameters);
638
639 let similarity = self.compute_parameter_similarity(
641 &self.consensus_node.local_parameters.view(),
642 &self.consensus_node.voting_state.proposals[&node_id].view(),
643 );
644
645 let trust_weight = self
646 .consensus_node
647 .byzantine_detector
648 .get_trust_weight(node_id);
649 let vote_weight = similarity * trust_weight;
650
651 if vote_weight > 0.5 {
652 self.consensus_node.voting_state.vote(
653 self.consensus_node.node_id,
654 node_id,
655 vote_weight,
656 );
657 }
658 }
659 }
660 }
661 ConsensusMessage::Vote {
662 round,
663 voter_id,
664 proposal_id,
665 weight,
666 timestamp: _,
667 } => {
668 if round == self.consensus_node.voting_state.round {
669 self.consensus_node
670 .voting_state
671 .vote(voter_id, proposal_id, weight);
672 }
673 }
674 ConsensusMessage::ConsensusResult {
675 round: _,
676 winning_proposal: _,
677 parameters,
678 timestamp: _,
679 } => {
680 self.apply_consensus_parameters(parameters)?;
682 }
683 ConsensusMessage::Heartbeat {
684 node_id,
685 timestamp: _,
686 } => {
687 self.consensus_node
689 .network_topology
690 .reliability_scores
691 .insert(node_id, 1.0);
692 }
693 ConsensusMessage::ByzantineAlert {
694 suspected_node,
695 reporter_node: _,
696 evidence,
697 timestamp: _,
698 } => {
699 self.handle_byzantine_alert(suspected_node, evidence);
701 }
702 }
703 }
704
705 Ok(())
706 }
707
708 fn compute_parameter_similarity(
709 &self,
710 params1: &ArrayView1<f64>,
711 params2: &ArrayView1<f64>,
712 ) -> f64 {
713 let diff = params1 - params2;
714 let norm = diff.mapv(|x| x * x).sum().sqrt();
715 let scale = params1.mapv(|x| x * x).sum().sqrt().max(1e-12);
716 (-norm / scale).exp()
717 }
718
719 fn apply_consensus_parameters(&mut self, parameters: Array1<f64>) -> Result<()> {
720 let blend_factor = 0.7; self.consensus_node.consensus_parameters = &(blend_factor * ¶meters)
723 + &((1.0 - blend_factor) * &self.consensus_node.local_parameters);
724
725 self.distributed_stats.consensus_rounds += 1;
726 Ok(())
727 }
728
729 fn handle_byzantine_alert(&mut self, suspectednode: usize, evidence: ByzantineEvidence) {
730 let current_trust = self
732 .consensus_node
733 .trust_scores
734 .get(&suspectednode)
735 .copied()
736 .unwrap_or(1.0);
737 let new_trust = current_trust * (1.0 - evidence.deviation_magnitude * 0.1);
738 self.consensus_node
739 .trust_scores
740 .insert(suspectednode, new_trust.max(0.0));
741
742 if new_trust < 0.1 {
743 self.distributed_stats.byzantine_faults_detected += 1;
744 }
745 }
746
747 pub fn run_consensus_protocol(&mut self) -> Result<Option<Array1<f64>>> {
749 self.consensus_node.voting_state.start_round();
751
752 let proposal_message = ConsensusMessage::Proposal {
754 round: self.consensus_node.voting_state.round,
755 node_id: self.consensus_node.node_id,
756 parameters: self.consensus_node.local_parameters.clone(),
757 timestamp: Instant::now(),
758 };
759
760 self.consensus_node.voting_state.add_proposal(
762 self.consensus_node.node_id,
763 self.consensus_node.local_parameters.clone(),
764 );
765
766 self.message_buffer.push_back(proposal_message);
768
769 self.process_consensus_messages()?;
771
772 if let Some((_winning_id, consensus_params)) =
774 self.consensus_node.voting_state.check_consensus()
775 {
776 self.distributed_stats.consensus_success_rate =
777 0.95 * self.distributed_stats.consensus_success_rate + 0.05 * 1.0;
778
779 Ok(Some(consensus_params))
780 } else if self.consensus_node.voting_state.is_timeout() {
781 self.distributed_stats.consensus_success_rate =
782 0.95 * self.distributed_stats.consensus_success_rate + 0.05 * 0.0;
783
784 Ok(None)
785 } else {
786 Ok(None)
787 }
788 }
789
790 pub fn federated_update(&mut self, gradient: &ArrayView1<f64>) -> Result<()> {
792 self.federated_state.add_peer_gradient(
794 self.consensus_node.node_id,
795 gradient.to_owned(),
796 1, );
798
799 let current_time = Instant::now();
801 if let Some(fed_gradient) = self
802 .federated_state
803 .compute_federated_gradient(current_time)
804 {
805 self.apply_gradient_update(&fed_gradient.view())?;
807
808 self.federated_state.federated_round += 1;
809 }
810
811 Ok(())
812 }
813
814 fn apply_gradient_update(&mut self, gradient: &ArrayView1<f64>) -> Result<()> {
815 let lr = if self.config.adaptive_lr {
816 let local_grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
818 let consensus_factor = self.distributed_stats.consensus_success_rate;
819 self.config.learning_rate * consensus_factor * (1.0 / (1.0 + local_grad_norm * 0.1))
820 } else {
821 self.config.learning_rate
822 };
823
824 self.consensus_node.local_parameters =
826 &self.consensus_node.local_parameters - &(lr * gradient);
827
828 Ok(())
829 }
830
831 pub fn process_async_updates(&mut self) -> Result<()> {
833 let current_time = Instant::now();
834
835 while let Some(update) = self.async_update_queue.front() {
836 if current_time >= update.apply_at {
837 let update = self.async_update_queue.pop_front().unwrap();
838
839 let staleness = current_time.duration_since(update.timestamp).as_secs_f64();
841 let staleness_factor = (-staleness * 0.1).exp(); let weighted_update = &update.parameters * staleness_factor;
844 self.consensus_node.local_parameters =
845 &(0.9 * &self.consensus_node.local_parameters) + &(0.1 * &weighted_update);
846 } else {
847 break; }
849 }
850
851 Ok(())
852 }
853}
854
855impl<T: StreamingObjective + Clone> StreamingOptimizer for AdvancedAdvancedDistributedOnlineGD<T> {
856 fn update(&mut self, datapoint: &StreamingDataPoint) -> Result<()> {
857 let start_time = Instant::now();
858
859 let gradient = self
861 .objective
862 .gradient(&self.consensus_node.local_parameters.view(), datapoint);
863
864 self.consensus_node.gradient_accumulator =
866 &self.consensus_node.gradient_accumulator + &gradient;
867
868 if self.stats.points_processed.is_multiple_of(10) {
870 if let Some(consensus_params) = self.run_consensus_protocol()? {
871 self.apply_consensus_parameters(consensus_params)?;
872 }
873 }
874
875 self.federated_update(&gradient.view())?;
877
878 self.process_async_updates()?;
880
881 let loss = self
883 .objective
884 .evaluate(&self.consensus_node.local_parameters.view(), datapoint);
885
886 self.stats.points_processed += 1;
888 self.stats.updates_performed += 1;
889 self.stats.current_loss = loss;
890 self.stats.average_loss = utils::ewma_update(self.stats.average_loss, loss, 0.01);
891
892 let param_change = (&self.consensus_node.local_parameters
894 - &self.consensus_node.consensus_parameters)
895 .mapv(|x| x.abs())
896 .sum()
897 / self.consensus_node.local_parameters.len() as f64;
898
899 self.stats.converged = param_change < self.config.tolerance;
900 self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
901
902 Ok(())
903 }
904
905 fn parameters(&self) -> &Array1<f64> {
906 &self.consensus_node.consensus_parameters
907 }
908
909 fn stats(&self) -> &StreamingStats {
910 &self.stats
911 }
912
913 fn reset(&mut self) {
914 self.consensus_node.local_parameters.fill(0.0);
915 self.consensus_node.consensus_parameters.fill(0.0);
916 self.consensus_node.gradient_accumulator.fill(0.0);
917 self.gradient_sum_sq.fill(0.0);
918 self.momentum.fill(0.0);
919 self.stats = StreamingStats::default();
920 self.distributed_stats = DistributedOptimizationStats::default();
921 self.federated_state = FederatedAveragingState::new();
922 self.async_update_queue.clear();
923 self.message_buffer.clear();
924 }
925}
926
927#[allow(dead_code)]
929pub fn distributed_online_linear_regression(
930 node_id: usize,
931 n_features: usize,
932 num_nodes: usize,
933 config: Option<StreamingConfig>,
934) -> AdvancedAdvancedDistributedOnlineGD<super::LinearRegressionObjective> {
935 let config = config.unwrap_or_default();
936 let initial_params = Array1::zeros(n_features);
937 let objective = super::LinearRegressionObjective;
938
939 AdvancedAdvancedDistributedOnlineGD::new(node_id, initial_params, objective, config, num_nodes)
940}
941
942#[allow(dead_code)]
944pub fn distributed_online_logistic_regression(
945 node_id: usize,
946 n_features: usize,
947 num_nodes: usize,
948 config: Option<StreamingConfig>,
949) -> AdvancedAdvancedDistributedOnlineGD<super::LogisticRegressionObjective> {
950 let config = config.unwrap_or_default();
951 let initial_params = Array1::zeros(n_features);
952 let objective = super::LogisticRegressionObjective;
953
954 AdvancedAdvancedDistributedOnlineGD::new(node_id, initial_params, objective, config, num_nodes)
955}
956
957#[allow(dead_code)]
959pub fn online_linear_regression(
960 n_features: usize,
961 config: Option<StreamingConfig>,
962) -> AdvancedAdvancedDistributedOnlineGD<super::LinearRegressionObjective> {
963 distributed_online_linear_regression(0, n_features, 1, config)
964}
965
966#[allow(dead_code)]
967pub fn online_logistic_regression(
968 n_features: usize,
969 config: Option<StreamingConfig>,
970) -> AdvancedAdvancedDistributedOnlineGD<super::LogisticRegressionObjective> {
971 distributed_online_logistic_regression(0, n_features, 1, config)
972}
973
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::StreamingDataPoint;

    #[test]
    fn test_distributed_optimizer_creation() {
        let optimizer = distributed_online_linear_regression(0, 2, 3, None);
        assert_eq!(optimizer.consensus_node.node_id, 0);
        assert_eq!(optimizer.consensus_node.local_parameters.len(), 2);
    }

    #[test]
    fn test_byzantine_fault_detector() {
        let mut detector = ByzantineFaultDetector::new(1.0);
        let good_params = Array1::from(vec![1.0, 2.0]);
        let bad_params = Array1::from(vec![10.0, 20.0]);
        let current_time = Instant::now();

        assert!(!detector.detect_byzantine_behavior(
            1,
            &good_params.view(),
            &good_params.view(),
            current_time
        ));

        // Repeated large deviations must eventually flag node 2.
        for _ in 0..10 {
            detector.detect_byzantine_behavior(
                2,
                &bad_params.view(),
                &good_params.view(),
                current_time,
            );
        }

        assert!(detector.is_byzantine_suspected(2, current_time));
    }

    #[test]
    fn test_consensus_voting() {
        let mut voting_state = ConsensusVotingState::new(2.0);
        voting_state.start_round();

        let params1 = Array1::from(vec![1.0, 2.0]);
        let params2 = Array1::from(vec![1.1, 2.1]);

        voting_state.add_proposal(1, params1);
        voting_state.add_proposal(2, params2);

        voting_state.vote(1, 1, 1.0);
        voting_state.vote(2, 1, 1.0);

        let consensus = voting_state.check_consensus();
        assert!(consensus.is_some());

        let (winner_id, _winning_params) = consensus.unwrap();
        assert_eq!(winner_id, 1);
    }

    #[test]
    fn test_federated_averaging() {
        let mut federated_state = FederatedAveragingState::new();

        let grad1 = Array1::from(vec![1.0, 2.0]);
        let grad2 = Array1::from(vec![3.0, 4.0]);

        federated_state.add_peer_gradient(1, grad1, 10);
        federated_state.add_peer_gradient(2, grad2, 20);

        let avg_grad = federated_state
            .compute_federated_gradient(Instant::now())
            .unwrap();

        // The weighted average must be finite and lie between the two peer gradients.
        assert!(avg_grad[0].is_finite() && avg_grad[0] > 0.0);
        assert!(avg_grad[1].is_finite() && avg_grad[1] > 0.0);
        assert!(avg_grad[0] >= 1.0 && avg_grad[0] <= 3.0);
        assert!(avg_grad[1] >= 2.0 && avg_grad[1] <= 4.0);
    }

    #[test]
    fn test_network_topology() {
        let mut topology = NetworkTopology::new(3);
        topology.add_connection(0, 1, 1.0, 0.1);
        topology.add_connection(1, 2, 1.0, 0.1);

        let neighbors_0 = topology.get_neighbors(0);
        let neighbors_1 = topology.get_neighbors(1);

        assert_eq!(neighbors_0, vec![1]);
        assert_eq!(neighbors_1, vec![0, 2]);
    }

    #[test]
    fn test_distributed_optimization_update() {
        let mut optimizer = distributed_online_linear_regression(0, 2, 1, None);

        let features = Array1::from(vec![1.0, 2.0]);
        let target = 3.0;
        let point = StreamingDataPoint::new(features, target);

        assert!(optimizer.update(&point).is_ok());
        assert_eq!(optimizer.stats().points_processed, 1);
    }

    #[test]
    fn test_network_synchronization() {
        let mut sync_state = NetworkSynchronizationState::new();

        let offset = Duration::from_millis(100);
        sync_state.update_clock_offset(1, offset);

        let sync_time = sync_state.get_synchronized_time(1);
        let now = Instant::now();

        assert!(now.duration_since(sync_time) >= offset);
    }

    #[test]
    fn test_parameter_similarity() {
        let optimizer = distributed_online_linear_regression(0, 2, 1, None);

        let params1 = Array1::from(vec![1.0, 2.0]);
        let params2 = Array1::from(vec![1.0, 2.0]);
        let params3 = Array1::from(vec![10.0, 20.0]);

        let similarity_identical =
            optimizer.compute_parameter_similarity(&params1.view(), &params2.view());
        let similarity_different =
            optimizer.compute_parameter_similarity(&params1.view(), &params3.view());

        assert!(similarity_identical > 0.9);
        assert!(similarity_different < 0.1);
    }
}