scirs2_optimize/streaming/
online_gradient_descent.rs

//! Advanced Online Gradient Descent with Distributed Consensus
//!
//! This module implements online gradient descent algorithms with:
//! - Byzantine fault-tolerant consensus protocols
//! - Federated averaging with consensus mechanisms
//! - Asynchronous distributed parameter updates
//! - Peer-to-peer optimization networks
//! - Adaptive consensus thresholds
//! - Fault-tolerant streaming optimization
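//!
//! # Example (sketch)
//!
//! A minimal, illustrative sketch using this module's convenience constructors
//! and the parent module's `StreamingDataPoint`; the values are placeholders:
//!
//! ```ignore
//! // Node 0 of a 3-node network, fitting 2 features online.
//! let mut optimizer = distributed_online_linear_regression(0, 2, 3, None);
//! optimizer.setup_network_topology(&[(0, 1, 1.0, 0.1), (1, 2, 1.0, 0.1)]);
//!
//! let point = StreamingDataPoint::new(Array1::from(vec![1.0, 2.0]), 3.0);
//! optimizer.update(&point)?;
//! let consensus_params = optimizer.parameters();
//! ```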

use super::{
    utils, StreamingConfig, StreamingDataPoint, StreamingObjective, StreamingOptimizer,
    StreamingStats,
};
use crate::error::OptimizeError;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};

type Result<T> = std::result::Result<T, OptimizeError>;

/// Advanced Distributed Consensus Node
#[derive(Debug, Clone)]
pub struct DistributedConsensusNode {
    /// Unique node identifier
    pub node_id: usize,
    /// Current parameter estimates
    pub local_parameters: Array1<f64>,
    /// Consensus parameters from distributed voting
    pub consensus_parameters: Array1<f64>,
    /// Trust scores for other nodes
    pub trust_scores: HashMap<usize, f64>,
    /// Byzantine fault detection state
    pub byzantine_detector: ByzantineFaultDetector,
    /// Peer communication history
    pub peer_history: HashMap<usize, VecDeque<ConsensusMessage>>,
    /// Local gradient accumulator
    pub gradient_accumulator: Array1<f64>,
    /// Consensus voting state
    pub voting_state: ConsensusVotingState,
    /// Network topology knowledge
    pub network_topology: NetworkTopology,
}

/// Byzantine fault detector for identifying malicious nodes
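///
/// # Example (sketch)
///
/// An illustrative detection flow, mirroring the unit test below; the
/// threshold and parameter values are placeholders:
///
/// ```ignore
/// let mut detector = ByzantineFaultDetector::new(1.0);
/// let honest = Array1::from(vec![1.0, 2.0]);
/// let outlier = Array1::from(vec![10.0, 20.0]);
/// // Repeated large deviations raise suspicion and erode reputation.
/// for _ in 0..10 {
///     detector.detect_byzantine_behavior(2, &outlier.view(), &honest.view(), Instant::now());
/// }
/// assert!(detector.is_byzantine_suspected(2, Instant::now()));
/// ```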
#[derive(Debug, Clone)]
pub struct ByzantineFaultDetector {
    /// Reputation scores for nodes
    pub reputation_scores: HashMap<usize, f64>,
    /// Suspicion counters
    pub suspicion_counters: HashMap<usize, usize>,
    /// Recent parameter deviations
    pub deviation_history: HashMap<usize, VecDeque<f64>>,
    /// Fault threshold
    pub fault_threshold: f64,
    /// Recovery period for suspected nodes
    pub recovery_period: Duration,
    /// Last fault detection time
    pub last_detection_times: HashMap<usize, Instant>,
}

impl ByzantineFaultDetector {
    pub fn new(fault_threshold: f64) -> Self {
        Self {
            reputation_scores: HashMap::new(),
            suspicion_counters: HashMap::new(),
            deviation_history: HashMap::new(),
            fault_threshold,
            recovery_period: Duration::from_secs(300), // 5 minutes recovery
            last_detection_times: HashMap::new(),
        }
    }

    /// Detect Byzantine behavior from parameter proposals
    pub fn detect_byzantine_behavior(
        &mut self,
        node_id: usize,
        proposed_params: &ArrayView1<f64>,
        consensus_params: &ArrayView1<f64>,
        current_time: Instant,
    ) -> bool {
        // Compute mean absolute parameter deviation
        let deviation = (proposed_params - consensus_params).mapv(|x| x.abs()).sum()
            / proposed_params.len() as f64;

        // Update deviation history
        let history = self
            .deviation_history
            .entry(node_id)
            .or_insert_with(|| VecDeque::with_capacity(100));
        history.push_back(deviation);
        if history.len() > 100 {
            history.pop_front();
        }

        // Check if deviation exceeds threshold
        if deviation > self.fault_threshold {
            let suspicion = self.suspicion_counters.entry(node_id).or_insert(0);
            *suspicion += 1;

            // Update reputation (decrease for suspicious behavior)
            let reputation = self.reputation_scores.entry(node_id).or_insert(1.0);
            *reputation *= 0.85;

            // Mark as Byzantine if suspicion is high
            if *suspicion > 5 && *reputation < 0.3 {
                self.last_detection_times.insert(node_id, current_time);
                return true;
            }
        } else {
            // Good behavior: increase reputation
            let reputation = self.reputation_scores.entry(node_id).or_insert(1.0);
            *reputation = (*reputation + 0.01).min(1.0);

            // Decrease suspicion
            if let Some(suspicion) = self.suspicion_counters.get_mut(&node_id) {
                *suspicion = suspicion.saturating_sub(1);
            }
        }

        false
    }

    /// Check if a node is currently suspected of Byzantine behavior
    pub fn is_byzantine_suspected(&self, node_id: usize, current_time: Instant) -> bool {
        if let Some(&last_detection) = self.last_detection_times.get(&node_id) {
            if current_time.duration_since(last_detection) < self.recovery_period {
                return true;
            }
        }
        false
    }

    /// Get trust weight for a node based on reputation
    pub fn get_trust_weight(&self, node_id: usize) -> f64 {
        self.reputation_scores.get(&node_id).copied().unwrap_or(1.0)
    }
}

/// Consensus voting state for distributed decision making
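///
/// # Example (sketch)
///
/// A weighted-voting round, mirroring the unit test below; a threshold of 2.0
/// means two unit-weight votes reach quorum:
///
/// ```ignore
/// let mut voting = ConsensusVotingState::new(2.0);
/// voting.start_round();
/// voting.add_proposal(1, Array1::from(vec![1.0, 2.0]));
/// voting.vote(1, 1, 1.0); // node 1 backs proposal 1 with weight 1.0
/// voting.vote(2, 1, 1.0); // node 2 agrees; total weight reaches the threshold
/// assert!(voting.check_consensus().is_some());
/// ```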
#[derive(Debug, Clone)]
pub struct ConsensusVotingState {
    /// Current round number
    pub round: usize,
    /// Parameter proposals from nodes
    pub proposals: HashMap<usize, Array1<f64>>,
    /// Votes for each proposal
    pub votes: HashMap<usize, Vec<usize>>, // proposal_id -> list of voting nodes
    /// Voting weights based on trust
    pub voting_weights: HashMap<usize, f64>,
    /// Minimum total vote weight required for consensus
    pub consensus_threshold: f64,
    /// Timeout for voting rounds
    pub round_timeout: Duration,
    /// Round start time
    pub round_start: Option<Instant>,
}

impl ConsensusVotingState {
    pub fn new(consensus_threshold: f64) -> Self {
        Self {
            round: 0,
            proposals: HashMap::new(),
            votes: HashMap::new(),
            voting_weights: HashMap::new(),
            consensus_threshold,
            round_timeout: Duration::from_millis(100),
            round_start: None,
        }
    }

    /// Start a new consensus round
    pub fn start_round(&mut self) {
        self.round += 1;
        self.proposals.clear();
        self.votes.clear();
        self.round_start = Some(Instant::now());
    }

    /// Add a parameter proposal
    pub fn add_proposal(&mut self, node_id: usize, parameters: Array1<f64>) {
        self.proposals.insert(node_id, parameters);
    }

    /// Cast a vote for a proposal
    pub fn vote(&mut self, voter_id: usize, proposal_id: usize, weight: f64) {
        self.voting_weights.insert(voter_id, weight);
        self.votes.entry(proposal_id).or_default().push(voter_id);
    }

    /// Check if consensus has been reached
    pub fn check_consensus(&self) -> Option<(usize, Array1<f64>)> {
        let mut best_proposal = None;
        let mut best_weight = 0.0;

        for (&proposal_id, voters) in &self.votes {
            let total_weight: f64 = voters
                .iter()
                .map(|&voter| self.voting_weights.get(&voter).copied().unwrap_or(1.0))
                .sum();

            if total_weight > best_weight && total_weight >= self.consensus_threshold {
                best_weight = total_weight;
                if let Some(params) = self.proposals.get(&proposal_id) {
                    best_proposal = Some((proposal_id, params.clone()));
                }
            }
        }

        best_proposal
    }

    /// Check if the current round has timed out
    pub fn is_timeout(&self) -> bool {
        if let Some(start) = self.round_start {
            start.elapsed() > self.round_timeout
        } else {
            false
        }
    }
}

/// Network topology representation
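///
/// # Example (sketch)
///
/// Building a small line topology, mirroring the unit test below:
///
/// ```ignore
/// let mut topology = NetworkTopology::new(3);
/// topology.add_connection(0, 1, 1.0, 0.1); // weight 1.0, delay 0.1
/// topology.add_connection(1, 2, 1.0, 0.1);
/// assert_eq!(topology.get_neighbors(1), vec![0, 2]);
/// let shortest = topology.compute_shortest_paths(); // Floyd-Warshall
/// ```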
#[derive(Debug, Clone)]
pub struct NetworkTopology {
    /// Adjacency matrix for node connections
    pub adjacency_matrix: Array2<f64>,
    /// Communication delays between nodes
    pub delay_matrix: Array2<f64>,
    /// Bandwidth limits between nodes
    pub bandwidth_matrix: Array2<f64>,
    /// Active connections
    pub active_connections: HashMap<usize, Vec<usize>>,
    /// Network reliability scores
    pub reliability_scores: HashMap<usize, f64>,
}

impl NetworkTopology {
    pub fn new(num_nodes: usize) -> Self {
        Self {
            adjacency_matrix: Array2::zeros((num_nodes, num_nodes)),
            delay_matrix: Array2::zeros((num_nodes, num_nodes)),
            bandwidth_matrix: Array2::from_elem((num_nodes, num_nodes), 1.0),
            active_connections: HashMap::new(),
            reliability_scores: HashMap::new(),
        }
    }

    /// Add a bidirectional connection between nodes
    pub fn add_connection(&mut self, node1: usize, node2: usize, weight: f64, delay: f64) {
        if node1 < self.adjacency_matrix.nrows() && node2 < self.adjacency_matrix.ncols() {
            self.adjacency_matrix[[node1, node2]] = weight;
            self.adjacency_matrix[[node2, node1]] = weight;
            self.delay_matrix[[node1, node2]] = delay;
            self.delay_matrix[[node2, node1]] = delay;

            self.active_connections
                .entry(node1)
                .or_default()
                .push(node2);
            self.active_connections
                .entry(node2)
                .or_default()
                .push(node1);
        }
    }

    /// Get neighbors of a node
    pub fn get_neighbors(&self, node_id: usize) -> Vec<usize> {
        self.active_connections
            .get(&node_id)
            .cloned()
            .unwrap_or_default()
    }

    /// Compute all-pairs shortest-path weights using Floyd-Warshall
    pub fn compute_shortest_paths(&self) -> Array2<f64> {
        let n = self.adjacency_matrix.nrows();
        let mut dist = self.adjacency_matrix.clone();

        // Initialize: missing edges (weight 0.0 off the diagonal) are unreachable
        for i in 0..n {
            for j in 0..n {
                if i != j && dist[[i, j]] == 0.0 {
                    dist[[i, j]] = f64::INFINITY;
                }
            }
        }

        // Floyd-Warshall relaxation
        for k in 0..n {
            for i in 0..n {
                for j in 0..n {
                    if dist[[i, k]] + dist[[k, j]] < dist[[i, j]] {
                        dist[[i, j]] = dist[[i, k]] + dist[[k, j]];
                    }
                }
            }
        }

        dist
    }
}

/// Messages for consensus communication
#[derive(Debug, Clone)]
pub enum ConsensusMessage {
    /// Parameter proposal
    Proposal {
        round: usize,
        node_id: usize,
        parameters: Array1<f64>,
        timestamp: Instant,
    },
    /// Vote for a proposal
    Vote {
        round: usize,
        voter_id: usize,
        proposal_id: usize,
        weight: f64,
        timestamp: Instant,
    },
    /// Consensus result announcement
    ConsensusResult {
        round: usize,
        winning_proposal: usize,
        parameters: Array1<f64>,
        timestamp: Instant,
    },
    /// Heartbeat for liveness detection
    Heartbeat { node_id: usize, timestamp: Instant },
    /// Byzantine fault detection alert
    ByzantineAlert {
        suspected_node: usize,
        reporter_node: usize,
        evidence: ByzantineEvidence,
        timestamp: Instant,
    },
}

/// Evidence for Byzantine behavior
#[derive(Debug, Clone)]
pub struct ByzantineEvidence {
    pub deviation_magnitude: f64,
    pub frequency_count: usize,
    pub reputation_score: f64,
}

/// Advanced Distributed Online Gradient Descent
#[derive(Debug, Clone)]
pub struct AdvancedDistributedOnlineGD<T: StreamingObjective> {
    /// Local consensus node
    pub consensus_node: DistributedConsensusNode,
    /// Objective function
    pub objective: T,
    /// Configuration
    pub config: StreamingConfig,
    /// Statistics
    pub stats: StreamingStats,
    /// Distributed statistics
    pub distributed_stats: DistributedOptimizationStats,
    /// Learning rate adaptation state
    pub gradient_sum_sq: Array1<f64>,
    /// Momentum state
    pub momentum: Array1<f64>,
    /// Federated averaging state
    pub federated_state: FederatedAveragingState,
    /// Asynchronous update queue
    pub async_update_queue: VecDeque<DelayedUpdate>,
    /// Communication buffer
    pub message_buffer: VecDeque<ConsensusMessage>,
    /// Network synchronization state
    pub sync_state: NetworkSynchronizationState,
}

/// Statistics for distributed optimization
#[derive(Debug, Clone)]
pub struct DistributedOptimizationStats {
    /// Total consensus rounds
    pub consensus_rounds: usize,
    /// Successful consensus rate
    pub consensus_success_rate: f64,
    /// Average consensus time
    pub avg_consensus_time: Duration,
    /// Byzantine faults detected
    pub byzantine_faults_detected: usize,
    /// Network partition events
    pub network_partitions: usize,
    /// Communication overhead
    pub communication_overhead: f64,
    /// Convergence rate relative to a centralized baseline
    pub relative_convergence_rate: f64,
}

impl Default for DistributedOptimizationStats {
    fn default() -> Self {
        Self {
            consensus_rounds: 0,
            consensus_success_rate: 1.0,
            avg_consensus_time: Duration::from_millis(50),
            byzantine_faults_detected: 0,
            network_partitions: 0,
            communication_overhead: 0.1,
            relative_convergence_rate: 1.0,
        }
    }
}

/// Federated averaging state
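///
/// # Example (sketch)
///
/// Data-weighted gradient averaging, mirroring the unit test below:
///
/// ```ignore
/// let mut fed = FederatedAveragingState::new();
/// fed.add_peer_gradient(1, Array1::from(vec![1.0, 2.0]), 10);
/// fed.add_peer_gradient(2, Array1::from(vec![3.0, 4.0]), 20);
/// // Peer 2 contributed twice the data, so its gradient carries twice the weight.
/// let avg = fed.compute_federated_gradient(Instant::now()).unwrap();
/// ```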
#[derive(Debug, Clone)]
pub struct FederatedAveragingState {
    /// Accumulated gradients from peers
    pub peer_gradients: HashMap<usize, Array1<f64>>,
    /// Weights for federated averaging
    pub peer_weights: HashMap<usize, f64>,
    /// Data counts from peers
    pub peer_data_counts: HashMap<usize, usize>,
    /// Last update timestamps
    pub last_updates: HashMap<usize, Instant>,
    /// Federated round number
    pub federated_round: usize,
    /// Staleness tolerance
    pub staleness_tolerance: Duration,
}

impl Default for FederatedAveragingState {
    fn default() -> Self {
        Self::new()
    }
}

impl FederatedAveragingState {
    pub fn new() -> Self {
        Self {
            peer_gradients: HashMap::new(),
            peer_weights: HashMap::new(),
            peer_data_counts: HashMap::new(),
            last_updates: HashMap::new(),
            federated_round: 0,
            staleness_tolerance: Duration::from_secs(10),
        }
    }

    /// Add gradient from a peer node
    pub fn add_peer_gradient(&mut self, peer_id: usize, gradient: Array1<f64>, data_count: usize) {
        self.peer_gradients.insert(peer_id, gradient);
        self.peer_data_counts.insert(peer_id, data_count);
        self.last_updates.insert(peer_id, Instant::now());

        // Recompute all weights from data counts (more data = higher weight),
        // so earlier peers do not keep weights based on an outdated total
        let total_data: usize = self.peer_data_counts.values().sum();
        if total_data > 0 {
            for (&id, &count) in &self.peer_data_counts {
                self.peer_weights.insert(id, count as f64 / total_data as f64);
            }
        }
    }

    /// Compute federated average gradient
    pub fn compute_federated_gradient(&self, current_time: Instant) -> Option<Array1<f64>> {
        if self.peer_gradients.is_empty() {
            return None;
        }

        let mut weighted_sum = None;
        let mut total_weight = 0.0;

        for (&peer_id, gradient) in &self.peer_gradients {
            // Skip gradients older than the staleness tolerance
            if let Some(&last_update) = self.last_updates.get(&peer_id) {
                if current_time.duration_since(last_update) > self.staleness_tolerance {
                    continue;
                }
            }

            let weight = self.peer_weights.get(&peer_id).copied().unwrap_or(1.0);

            if let Some(ref mut sum) = weighted_sum {
                *sum = &*sum + &(weight * gradient);
            } else {
                weighted_sum = Some(weight * gradient);
            }

            total_weight += weight;
        }

        if let Some(sum) = weighted_sum {
            if total_weight > 0.0 {
                Some(sum / total_weight)
            } else {
                Some(sum)
            }
        } else {
            None
        }
    }
}

/// Delayed update for asynchronous processing
#[derive(Debug, Clone)]
pub struct DelayedUpdate {
    pub source_node: usize,
    pub parameters: Array1<f64>,
    pub timestamp: Instant,
    pub apply_at: Instant,
}

/// Network synchronization state
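///
/// # Example (sketch)
///
/// Mapping a peer's timestamp into local time, mirroring the unit test below:
///
/// ```ignore
/// let mut sync = NetworkSynchronizationState::new();
/// // Peer 1's clock runs 100 ms ahead of ours.
/// sync.update_clock_offset(1, Duration::from_millis(100));
/// let local_view = sync.get_synchronized_time(1); // now minus the offset
/// ```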
#[derive(Debug, Clone)]
pub struct NetworkSynchronizationState {
    /// Clock offsets with other nodes
    pub clock_offsets: HashMap<usize, Duration>,
    /// Synchronization accuracy
    pub sync_accuracy: Duration,
    /// Last synchronization time
    pub last_sync: Instant,
    /// Synchronization period
    pub sync_period: Duration,
}

impl Default for NetworkSynchronizationState {
    fn default() -> Self {
        Self::new()
    }
}

impl NetworkSynchronizationState {
    pub fn new() -> Self {
        Self {
            clock_offsets: HashMap::new(),
            sync_accuracy: Duration::from_millis(10),
            last_sync: Instant::now(),
            sync_period: Duration::from_secs(60),
        }
    }

    /// Check if synchronization is needed
    pub fn needs_sync(&self) -> bool {
        self.last_sync.elapsed() > self.sync_period
    }

    /// Update clock offset for a node
    pub fn update_clock_offset(&mut self, node_id: usize, offset: Duration) {
        self.clock_offsets.insert(node_id, offset);
    }

    /// Get synchronized timestamp for a node (local time minus its clock offset)
    pub fn get_synchronized_time(&self, node_id: usize) -> Instant {
        let now = Instant::now();
        if let Some(&offset) = self.clock_offsets.get(&node_id) {
            now - offset
        } else {
            now
        }
    }
}

impl<T: StreamingObjective + Clone> AdvancedDistributedOnlineGD<T> {
    /// Create a new advanced distributed online gradient descent optimizer
    pub fn new(
        node_id: usize,
        initial_parameters: Array1<f64>,
        objective: T,
        config: StreamingConfig,
        num_nodes: usize,
    ) -> Self {
        let n_params = initial_parameters.len();

        let consensus_node = DistributedConsensusNode {
            node_id,
            local_parameters: initial_parameters.clone(),
            consensus_parameters: initial_parameters.clone(),
            trust_scores: HashMap::new(),
            byzantine_detector: ByzantineFaultDetector::new(1.0),
            peer_history: HashMap::new(),
            gradient_accumulator: Array1::zeros(n_params),
            voting_state: ConsensusVotingState::new(num_nodes as f64 * 0.67), // 2/3 majority
            network_topology: NetworkTopology::new(num_nodes),
        };

        Self {
            consensus_node,
            objective,
            config,
            stats: StreamingStats::default(),
            distributed_stats: DistributedOptimizationStats::default(),
            gradient_sum_sq: Array1::zeros(n_params),
            momentum: Array1::zeros(n_params),
            federated_state: FederatedAveragingState::new(),
            async_update_queue: VecDeque::new(),
            message_buffer: VecDeque::new(),
            sync_state: NetworkSynchronizationState::new(),
        }
    }

    /// Initialize network topology with peers
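    ///
    /// Each tuple is `(node1, node2, link_weight, delay)`; the values below
    /// are illustrative:
    ///
    /// ```ignore
    /// optimizer.setup_network_topology(&[(0, 1, 1.0, 0.1), (1, 2, 1.0, 0.1)]);
    /// ```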
    pub fn setup_network_topology(&mut self, peer_connections: &[(usize, usize, f64, f64)]) {
        for &(node1, node2, weight, delay) in peer_connections {
            self.consensus_node
                .network_topology
                .add_connection(node1, node2, weight, delay);
        }
    }

    /// Process consensus messages from peers
    pub fn process_consensus_messages(&mut self) -> Result<()> {
        let current_time = Instant::now();

        while let Some(message) = self.message_buffer.pop_front() {
            match message {
                ConsensusMessage::Proposal {
                    round,
                    node_id,
                    parameters,
                    timestamp: _,
                } => {
                    if round == self.consensus_node.voting_state.round {
                        // Check for Byzantine behavior before accepting the proposal
                        let is_byzantine = self
                            .consensus_node
                            .byzantine_detector
                            .detect_byzantine_behavior(
                                node_id,
                                &parameters.view(),
                                &self.consensus_node.consensus_parameters.view(),
                                current_time,
                            );

                        if !is_byzantine {
                            // Auto-vote based on similarity to local parameters
                            let similarity = self.compute_parameter_similarity(
                                &self.consensus_node.local_parameters.view(),
                                &parameters.view(),
                            );

                            let trust_weight = self
                                .consensus_node
                                .byzantine_detector
                                .get_trust_weight(node_id);
                            let vote_weight = similarity * trust_weight;

                            self.consensus_node
                                .voting_state
                                .add_proposal(node_id, parameters);

                            if vote_weight > 0.5 {
                                self.consensus_node.voting_state.vote(
                                    self.consensus_node.node_id,
                                    node_id,
                                    vote_weight,
                                );
                            }
                        }
                    }
                }
                ConsensusMessage::Vote {
                    round,
                    voter_id,
                    proposal_id,
                    weight,
                    timestamp: _,
                } => {
                    if round == self.consensus_node.voting_state.round {
                        self.consensus_node
                            .voting_state
                            .vote(voter_id, proposal_id, weight);
                    }
                }
                ConsensusMessage::ConsensusResult {
                    round: _,
                    winning_proposal: _,
                    parameters,
                    timestamp: _,
                } => {
                    // Apply consensus parameters
                    self.apply_consensus_parameters(parameters)?;
                }
                ConsensusMessage::Heartbeat {
                    node_id,
                    timestamp: _,
                } => {
                    // Update node liveness
                    self.consensus_node
                        .network_topology
                        .reliability_scores
                        .insert(node_id, 1.0);
                }
                ConsensusMessage::ByzantineAlert {
                    suspected_node,
                    reporter_node: _,
                    evidence,
                    timestamp: _,
                } => {
                    // Process Byzantine fault alert
                    self.handle_byzantine_alert(suspected_node, evidence);
                }
            }
        }

        Ok(())
    }

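    /// Similarity score in `(0, 1]`: `exp(-‖p1 - p2‖₂ / max(‖p1‖₂, 1e-12))`,
    /// so identical vectors score 1.0 and distant vectors decay toward 0.0.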
    fn compute_parameter_similarity(
        &self,
        params1: &ArrayView1<f64>,
        params2: &ArrayView1<f64>,
    ) -> f64 {
        let diff = params1 - params2;
        let norm = diff.mapv(|x| x * x).sum().sqrt();
        let scale = params1.mapv(|x| x * x).sum().sqrt().max(1e-12);
        (-norm / scale).exp()
    }

    fn apply_consensus_parameters(&mut self, parameters: Array1<f64>) -> Result<()> {
        // Blend consensus parameters with local parameters
        let blend_factor = 0.7; // Weight for consensus vs local
        self.consensus_node.consensus_parameters = &(blend_factor * &parameters)
            + &((1.0 - blend_factor) * &self.consensus_node.local_parameters);

        self.distributed_stats.consensus_rounds += 1;
        Ok(())
    }

    fn handle_byzantine_alert(&mut self, suspected_node: usize, evidence: ByzantineEvidence) {
        // Reduce trust in the suspected node
        let current_trust = self
            .consensus_node
            .trust_scores
            .get(&suspected_node)
            .copied()
            .unwrap_or(1.0);
        let new_trust = current_trust * (1.0 - evidence.deviation_magnitude * 0.1);
        self.consensus_node
            .trust_scores
            .insert(suspected_node, new_trust.max(0.0));

        if new_trust < 0.1 {
            self.distributed_stats.byzantine_faults_detected += 1;
        }
    }

    /// Run consensus protocol
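    ///
    /// One round: propose the local parameters, broadcast (simulated here via
    /// the local message buffer), process incoming proposals and votes, then
    /// return the winning parameters, or `None` on timeout or missing quorum.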
    pub fn run_consensus_protocol(&mut self) -> Result<Option<Array1<f64>>> {
        // Start new consensus round
        self.consensus_node.voting_state.start_round();

        // Propose local parameters
        let proposal_message = ConsensusMessage::Proposal {
            round: self.consensus_node.voting_state.round,
            node_id: self.consensus_node.node_id,
            parameters: self.consensus_node.local_parameters.clone(),
            timestamp: Instant::now(),
        };

        // Add proposal to voting state
        self.consensus_node.voting_state.add_proposal(
            self.consensus_node.node_id,
            self.consensus_node.local_parameters.clone(),
        );

        // Simulate message broadcasting (a real implementation would send to peers)
        self.message_buffer.push_back(proposal_message);

        // Process messages
        self.process_consensus_messages()?;

        // Check for consensus; track the success rate as an EWMA of outcomes
        if let Some((_winning_id, consensus_params)) =
            self.consensus_node.voting_state.check_consensus()
        {
            self.distributed_stats.consensus_success_rate =
                0.95 * self.distributed_stats.consensus_success_rate + 0.05 * 1.0;

            Ok(Some(consensus_params))
        } else if self.consensus_node.voting_state.is_timeout() {
            self.distributed_stats.consensus_success_rate =
                0.95 * self.distributed_stats.consensus_success_rate + 0.05 * 0.0;

            Ok(None)
        } else {
            Ok(None)
        }
    }

    /// Update with federated averaging
    pub fn federated_update(&mut self, gradient: &ArrayView1<f64>) -> Result<()> {
        // Add local gradient to federated state
        self.federated_state.add_peer_gradient(
            self.consensus_node.node_id,
            gradient.to_owned(),
            1, // Local data count
        );

        // Compute federated average if enough peers
        let current_time = Instant::now();
        if let Some(fed_gradient) = self
            .federated_state
            .compute_federated_gradient(current_time)
        {
            // Apply federated gradient update
            self.apply_gradient_update(&fed_gradient.view())?;

            self.federated_state.federated_round += 1;
        }

        Ok(())
    }

    fn apply_gradient_update(&mut self, gradient: &ArrayView1<f64>) -> Result<()> {
        let lr = if self.config.adaptive_lr {
            // Distributed adaptive learning rate: shrink with the gradient norm,
            // scale by how reliably consensus has been reached
            let local_grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
            let consensus_factor = self.distributed_stats.consensus_success_rate;
            self.config.learning_rate * consensus_factor * (1.0 / (1.0 + local_grad_norm * 0.1))
        } else {
            self.config.learning_rate
        };

        // Update local parameters
        self.consensus_node.local_parameters =
            &self.consensus_node.local_parameters - &(lr * gradient);

        Ok(())
    }

    /// Process asynchronous updates
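    ///
    /// Due updates are blended in with a staleness discount of
    /// `exp(-0.1 * age_in_seconds)`, so older updates contribute less.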
    pub fn process_async_updates(&mut self) -> Result<()> {
        let current_time = Instant::now();

        while let Some(update) = self.async_update_queue.front() {
            if current_time >= update.apply_at {
                let update = self.async_update_queue.pop_front().unwrap();

                // Apply delayed parameter update with staleness compensation
                let staleness = current_time.duration_since(update.timestamp).as_secs_f64();
                let staleness_factor = (-staleness * 0.1).exp(); // Exponential decay

                let weighted_update = &update.parameters * staleness_factor;
                self.consensus_node.local_parameters =
                    &(0.9 * &self.consensus_node.local_parameters) + &(0.1 * &weighted_update);
            } else {
                break; // Queue is assumed ordered by apply_at time
            }
        }

        Ok(())
    }
}

impl<T: StreamingObjective + Clone> StreamingOptimizer for AdvancedDistributedOnlineGD<T> {
    fn update(&mut self, data_point: &StreamingDataPoint) -> Result<()> {
        let start_time = Instant::now();

        // Compute local gradient
        let gradient = self
            .objective
            .gradient(&self.consensus_node.local_parameters.view(), data_point);

        // Accumulate gradient for consensus
        self.consensus_node.gradient_accumulator =
            &self.consensus_node.gradient_accumulator + &gradient;

        // Run the consensus protocol every 10 data points (including the first)
        if self.stats.points_processed.is_multiple_of(10) {
            if let Some(consensus_params) = self.run_consensus_protocol()? {
                self.apply_consensus_parameters(consensus_params)?;
            }
        }

        // Federated averaging update
        self.federated_update(&gradient.view())?;

        // Process asynchronous updates
        self.process_async_updates()?;

        // Regular streaming update
        let loss = self
            .objective
            .evaluate(&self.consensus_node.local_parameters.view(), data_point);

        // Update statistics
        self.stats.points_processed += 1;
        self.stats.updates_performed += 1;
        self.stats.current_loss = loss;
        self.stats.average_loss = utils::ewma_update(self.stats.average_loss, loss, 0.01);

        // Convergence check: mean absolute gap between local and consensus parameters
        let param_change = (&self.consensus_node.local_parameters
            - &self.consensus_node.consensus_parameters)
            .mapv(|x| x.abs())
            .sum()
            / self.consensus_node.local_parameters.len() as f64;

        self.stats.converged = param_change < self.config.tolerance;
        self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;

        Ok(())
    }

    fn parameters(&self) -> &Array1<f64> {
        &self.consensus_node.consensus_parameters
    }

    fn stats(&self) -> &StreamingStats {
        &self.stats
    }

    fn reset(&mut self) {
        self.consensus_node.local_parameters.fill(0.0);
        self.consensus_node.consensus_parameters.fill(0.0);
        self.consensus_node.gradient_accumulator.fill(0.0);
        self.gradient_sum_sq.fill(0.0);
        self.momentum.fill(0.0);
        self.stats = StreamingStats::default();
        self.distributed_stats = DistributedOptimizationStats::default();
        self.federated_state = FederatedAveragingState::new();
        self.async_update_queue.clear();
        self.message_buffer.clear();
    }
}

/// Convenience function for distributed linear regression
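///
/// # Example (sketch)
///
/// ```ignore
/// // Node 1 of a 5-node network, 8 features; `None` falls back to
/// // `StreamingConfig::default()`.
/// let optimizer = distributed_online_linear_regression(1, 8, 5, None);
/// ```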
#[allow(dead_code)]
pub fn distributed_online_linear_regression(
    node_id: usize,
    n_features: usize,
    num_nodes: usize,
    config: Option<StreamingConfig>,
) -> AdvancedDistributedOnlineGD<super::LinearRegressionObjective> {
    let config = config.unwrap_or_default();
    let initial_params = Array1::zeros(n_features);
    let objective = super::LinearRegressionObjective;

    AdvancedDistributedOnlineGD::new(node_id, initial_params, objective, config, num_nodes)
}

/// Convenience function for distributed logistic regression
#[allow(dead_code)]
pub fn distributed_online_logistic_regression(
    node_id: usize,
    n_features: usize,
    num_nodes: usize,
    config: Option<StreamingConfig>,
) -> AdvancedDistributedOnlineGD<super::LogisticRegressionObjective> {
    let config = config.unwrap_or_default();
    let initial_params = Array1::zeros(n_features);
    let objective = super::LogisticRegressionObjective;

    AdvancedDistributedOnlineGD::new(node_id, initial_params, objective, config, num_nodes)
}

/// Legacy convenience functions for backward compatibility
#[allow(dead_code)]
pub fn online_linear_regression(
    n_features: usize,
    config: Option<StreamingConfig>,
) -> AdvancedDistributedOnlineGD<super::LinearRegressionObjective> {
    distributed_online_linear_regression(0, n_features, 1, config)
}

#[allow(dead_code)]
pub fn online_logistic_regression(
    n_features: usize,
    config: Option<StreamingConfig>,
) -> AdvancedDistributedOnlineGD<super::LogisticRegressionObjective> {
    distributed_online_logistic_regression(0, n_features, 1, config)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::StreamingDataPoint;

    #[test]
    fn test_distributed_optimizer_creation() {
        let optimizer = distributed_online_linear_regression(0, 2, 3, None);
        assert_eq!(optimizer.consensus_node.node_id, 0);
        assert_eq!(optimizer.consensus_node.local_parameters.len(), 2);
    }

    #[test]
    fn test_byzantine_fault_detector() {
        let mut detector = ByzantineFaultDetector::new(1.0);
        let good_params = Array1::from(vec![1.0, 2.0]);
        let bad_params = Array1::from(vec![10.0, 20.0]); // Large deviation
        let current_time = Instant::now();

        // Good behavior should not trigger detection
        assert!(!detector.detect_byzantine_behavior(
            1,
            &good_params.view(),
            &good_params.view(),
            current_time
        ));

        // Bad behavior should trigger detection after multiple occurrences
        for _ in 0..10 {
            detector.detect_byzantine_behavior(
                2,
                &bad_params.view(),
                &good_params.view(),
                current_time,
            );
        }

        assert!(detector.is_byzantine_suspected(2, current_time));
    }

    #[test]
    fn test_consensus_voting() {
        let mut voting_state = ConsensusVotingState::new(2.0); // Need total vote weight of 2.0
        voting_state.start_round();

        let params1 = Array1::from(vec![1.0, 2.0]);
        let params2 = Array1::from(vec![1.1, 2.1]);

        voting_state.add_proposal(1, params1);
        voting_state.add_proposal(2, params2);

        voting_state.vote(1, 1, 1.0);
        voting_state.vote(2, 1, 1.0);

        let consensus = voting_state.check_consensus();
        assert!(consensus.is_some());

        let (winner_id, _winning_params) = consensus.unwrap();
        assert_eq!(winner_id, 1);
    }

    #[test]
    fn test_federated_averaging() {
        let mut federated_state = FederatedAveragingState::new();

        let grad1 = Array1::from(vec![1.0, 2.0]);
        let grad2 = Array1::from(vec![3.0, 4.0]);

        federated_state.add_peer_gradient(1, grad1, 10);
        federated_state.add_peer_gradient(2, grad2, 20);

        let avg_grad = federated_state
            .compute_federated_gradient(Instant::now())
            .unwrap();

        // The weighted average should be finite and positive
        assert!(avg_grad[0].is_finite() && avg_grad[0] > 0.0);
        assert!(avg_grad[1].is_finite() && avg_grad[1] > 0.0);
        // Values should lie between the input gradients
        assert!(avg_grad[0] >= 1.0 && avg_grad[0] <= 3.0);
        assert!(avg_grad[1] >= 2.0 && avg_grad[1] <= 4.0);
    }

    #[test]
    fn test_network_topology() {
        let mut topology = NetworkTopology::new(3);
        topology.add_connection(0, 1, 1.0, 0.1);
        topology.add_connection(1, 2, 1.0, 0.1);

        let neighbors_0 = topology.get_neighbors(0);
        let neighbors_1 = topology.get_neighbors(1);

        assert_eq!(neighbors_0, vec![1]);
        assert_eq!(neighbors_1, vec![0, 2]);
    }

    #[test]
    fn test_distributed_optimization_update() {
        let mut optimizer = distributed_online_linear_regression(0, 2, 1, None);

        let features = Array1::from(vec![1.0, 2.0]);
        let target = 3.0;
        let point = StreamingDataPoint::new(features, target);

        // Update should not fail
        assert!(optimizer.update(&point).is_ok());
        assert_eq!(optimizer.stats().points_processed, 1);
    }

    #[test]
    fn test_network_synchronization() {
        let mut sync_state = NetworkSynchronizationState::new();

        let offset = Duration::from_millis(100);
        sync_state.update_clock_offset(1, offset);

        let sync_time = sync_state.get_synchronized_time(1);
        let now = Instant::now();

        // Synchronized time should be earlier by at least the offset amount
        assert!(now.duration_since(sync_time) >= offset);
    }

    #[test]
    fn test_parameter_similarity() {
        let optimizer = distributed_online_linear_regression(0, 2, 1, None);

        let params1 = Array1::from(vec![1.0, 2.0]);
        let params2 = Array1::from(vec![1.0, 2.0]); // Identical
        let params3 = Array1::from(vec![10.0, 20.0]); // Very different

        let similarity_identical =
            optimizer.compute_parameter_similarity(&params1.view(), &params2.view());
        let similarity_different =
            optimizer.compute_parameter_similarity(&params1.view(), &params3.view());

        assert!(similarity_identical > 0.9);
        assert!(similarity_different < 0.1);
    }
}