ruvector_sona/training/
federated.rs

//! Federated Learning for SONA
//!
//! Enables distributed learning across ephemeral agents that export
//! their trajectories to a central coordinator.
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
//! │  Agent A    │     │  Agent B    │     │  Agent C    │
//! │ (ephemeral) │     │ (ephemeral) │     │ (ephemeral) │
//! └──────┬──────┘     └──────┬──────┘     └──────┬──────┘
//!        │                   │                   │
//!        │    export()       │    export()       │    export()
//!        ▼                   ▼                   ▼
//!   ┌────────────────────────────────────────────────┐
//!   │            Federated Coordinator               │
//!   │         (persistent, large capacity)           │
//!   └────────────────────────────────────────────────┘
//! ```
21
22use crate::engine::SonaEngine;
23use crate::types::{SonaConfig, LearnedPattern};
24use super::metrics::TrainingMetrics;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::time::{SystemTime, UNIX_EPOCH};
28
29/// Exported state from an ephemeral agent
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct AgentExport {
32    /// Agent identifier
33    pub agent_id: String,
34    /// Exported trajectories (embedding, quality pairs)
35    pub trajectories: Vec<TrajectoryExport>,
36    /// Agent statistics
37    pub stats: AgentExportStats,
38    /// Session duration in milliseconds
39    pub session_duration_ms: u64,
40    /// Export timestamp
41    pub timestamp: u64,
42}
43
44/// Single trajectory export
45#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct TrajectoryExport {
47    /// Query embedding
48    pub embedding: Vec<f32>,
49    /// Quality score
50    pub quality: f32,
51    /// Model route (if any)
52    pub route: Option<String>,
53    /// Context identifiers
54    pub context: Vec<String>,
55    /// Timestamp
56    pub timestamp: u64,
57}
58
59/// Agent export statistics
60#[derive(Clone, Debug, Default, Serialize, Deserialize)]
61pub struct AgentExportStats {
62    /// Total trajectories processed
63    pub total_trajectories: usize,
64    /// Average quality
65    pub avg_quality: f32,
66    /// Patterns learned locally
67    pub patterns_learned: usize,
68}
69
70/// Ephemeral agent for federated learning
71///
72/// Collects trajectories during its session and exports state before termination.
73pub struct EphemeralAgent {
74    /// Agent identifier
75    agent_id: String,
76    /// SONA engine
77    engine: SonaEngine,
78    /// Collected trajectories
79    trajectories: Vec<TrajectoryExport>,
80    /// Session start time
81    start_time: u64,
82    /// Quality samples
83    quality_samples: Vec<f32>,
84}
85
86impl EphemeralAgent {
87    /// Create a new ephemeral agent
88    pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
89        let now = SystemTime::now()
90            .duration_since(UNIX_EPOCH)
91            .unwrap_or_default()
92            .as_millis() as u64;
93
94        Self {
95            agent_id: agent_id.into(),
96            engine: SonaEngine::with_config(config),
97            trajectories: Vec::new(),
98            start_time: now,
99            quality_samples: Vec::new(),
100        }
101    }
102
103    /// Create with default config for federated learning
104    pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
105        Self::new(agent_id, SonaConfig {
106            hidden_dim,
107            embedding_dim: hidden_dim,
108            micro_lora_rank: 2,
109            base_lora_rank: 8,
110            micro_lora_lr: 0.002,
111            trajectory_capacity: 500,  // Small buffer per agent
112            pattern_clusters: 25,
113            ..Default::default()
114        })
115    }
116
117    /// Get agent ID
118    pub fn agent_id(&self) -> &str {
119        &self.agent_id
120    }
121
122    /// Get engine reference
123    pub fn engine(&self) -> &SonaEngine {
124        &self.engine
125    }
126
127    /// Get mutable engine reference
128    pub fn engine_mut(&mut self) -> &mut SonaEngine {
129        &mut self.engine
130    }
131
132    /// Process a task and record trajectory
133    pub fn process_trajectory(
134        &mut self,
135        embedding: Vec<f32>,
136        activations: Vec<f32>,
137        quality: f32,
138        route: Option<String>,
139        context: Vec<String>,
140    ) {
141        let now = SystemTime::now()
142            .duration_since(UNIX_EPOCH)
143            .unwrap_or_default()
144            .as_millis() as u64;
145
146        // Record in SONA engine
147        let mut builder = self.engine.begin_trajectory(embedding.clone());
148        if let Some(ref r) = route {
149            builder.set_model_route(r);
150        }
151        for ctx in &context {
152            builder.add_context(ctx);
153        }
154        builder.add_step(activations, vec![], quality);
155        self.engine.end_trajectory(builder, quality);
156
157        // Store for export
158        self.trajectories.push(TrajectoryExport {
159            embedding,
160            quality,
161            route,
162            context,
163            timestamp: now,
164        });
165
166        self.quality_samples.push(quality);
167    }
168
169    /// Apply micro-LoRA to hidden states
170    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
171        self.engine.apply_micro_lora(input, output);
172    }
173
174    /// Get number of collected trajectories
175    pub fn trajectory_count(&self) -> usize {
176        self.trajectories.len()
177    }
178
179    /// Get average quality
180    pub fn avg_quality(&self) -> f32 {
181        if self.quality_samples.is_empty() {
182            0.0
183        } else {
184            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
185        }
186    }
187
188    /// Force local learning
189    pub fn force_learn(&self) -> String {
190        self.engine.force_learn()
191    }
192
193    /// Simple process task method
194    pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
195        self.process_trajectory(embedding.clone(), embedding, quality, None, vec![]);
196    }
197
198    /// Process task with route information
199    pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
200        self.process_trajectory(embedding.clone(), embedding, quality, Some(route.to_string()), vec![]);
201    }
202
203    /// Get average quality (alias for avg_quality)
204    pub fn average_quality(&self) -> f32 {
205        self.avg_quality()
206    }
207
208    /// Get uptime in seconds
209    pub fn uptime_seconds(&self) -> u64 {
210        let now = SystemTime::now()
211            .duration_since(UNIX_EPOCH)
212            .unwrap_or_default()
213            .as_millis() as u64;
214        (now - self.start_time) / 1000
215    }
216
217    /// Get agent stats
218    pub fn stats(&self) -> AgentExportStats {
219        let engine_stats = self.engine.stats();
220        AgentExportStats {
221            total_trajectories: self.trajectories.len(),
222            avg_quality: self.avg_quality(),
223            patterns_learned: engine_stats.patterns_stored,
224        }
225    }
226
227    /// Clear trajectories (after export)
228    pub fn clear(&mut self) {
229        self.trajectories.clear();
230        self.quality_samples.clear();
231    }
232
233    /// Get learned patterns from agent
234    pub fn get_patterns(&self) -> Vec<LearnedPattern> {
235        self.engine.find_patterns(&[], 0)
236    }
237
238    /// Export agent state for federation
239    ///
240    /// Call this before terminating the agent.
241    pub fn export_state(&self) -> AgentExport {
242        let now = SystemTime::now()
243            .duration_since(UNIX_EPOCH)
244            .unwrap_or_default()
245            .as_millis() as u64;
246
247        // Force learning before export
248        self.engine.force_learn();
249
250        let stats = self.engine.stats();
251
252        AgentExport {
253            agent_id: self.agent_id.clone(),
254            trajectories: self.trajectories.clone(),
255            stats: AgentExportStats {
256                total_trajectories: self.trajectories.len(),
257                avg_quality: self.avg_quality(),
258                patterns_learned: stats.patterns_stored,
259            },
260            session_duration_ms: now - self.start_time,
261            timestamp: now,
262        }
263    }
264}
265
266/// Agent contribution record
267#[derive(Clone, Debug, Serialize, Deserialize)]
268pub struct AgentContribution {
269    /// Number of trajectories contributed
270    pub trajectory_count: usize,
271    /// Average quality of contributions
272    pub avg_quality: f32,
273    /// Contribution timestamp
274    pub timestamp: u64,
275    /// Session duration
276    pub session_duration_ms: u64,
277}
278
279/// Federated learning coordinator
280///
281/// Aggregates learning from multiple ephemeral agents.
282pub struct FederatedCoordinator {
283    /// Coordinator identifier
284    coordinator_id: String,
285    /// Master SONA engine for aggregation
286    master_engine: SonaEngine,
287    /// Agent contributions
288    contributions: HashMap<String, AgentContribution>,
289    /// Quality threshold for accepting trajectories
290    quality_threshold: f32,
291    /// Total trajectories aggregated
292    total_trajectories: usize,
293    /// Consolidation interval (number of agents)
294    consolidation_interval: usize,
295    /// Metrics
296    metrics: TrainingMetrics,
297}
298
299impl FederatedCoordinator {
300    /// Create a new federated coordinator
301    pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
302        let id = coordinator_id.into();
303        Self {
304            coordinator_id: id.clone(),
305            master_engine: SonaEngine::with_config(config),
306            contributions: HashMap::new(),
307            quality_threshold: 0.4,
308            total_trajectories: 0,
309            consolidation_interval: 50,
310            metrics: TrainingMetrics::new(&id),
311        }
312    }
313
314    /// Create with default config for coordination
315    pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
316        Self::new(coordinator_id, SonaConfig {
317            hidden_dim,
318            embedding_dim: hidden_dim,
319            micro_lora_rank: 2,
320            base_lora_rank: 16,          // Deeper for aggregation
321            trajectory_capacity: 50000,   // Large central buffer
322            pattern_clusters: 200,
323            ewc_lambda: 2000.0,          // Strong regularization
324            ..Default::default()
325        })
326    }
327
328    /// Get coordinator ID
329    pub fn coordinator_id(&self) -> &str {
330        &self.coordinator_id
331    }
332
333    /// Set quality threshold for accepting trajectories
334    pub fn set_quality_threshold(&mut self, threshold: f32) {
335        self.quality_threshold = threshold;
336    }
337
338    /// Set consolidation interval
339    pub fn set_consolidation_interval(&mut self, interval: usize) {
340        self.consolidation_interval = interval;
341    }
342
343    /// Get master engine reference
344    pub fn master_engine(&self) -> &SonaEngine {
345        &self.master_engine
346    }
347
348    /// Aggregate agent export into coordinator
349    pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
350        let mut accepted = 0;
351        let mut rejected = 0;
352
353        // Replay trajectories into master engine
354        for traj in &export.trajectories {
355            if traj.quality >= self.quality_threshold {
356                let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
357                if let Some(ref route) = traj.route {
358                    builder.set_model_route(route);
359                }
360                for ctx in &traj.context {
361                    builder.add_context(ctx);
362                }
363                self.master_engine.end_trajectory(builder, traj.quality);
364
365                self.metrics.add_quality_sample(traj.quality);
366                accepted += 1;
367            } else {
368                rejected += 1;
369            }
370        }
371
372        self.total_trajectories += accepted;
373
374        // Record contribution
375        let now = SystemTime::now()
376            .duration_since(UNIX_EPOCH)
377            .unwrap_or_default()
378            .as_millis() as u64;
379
380        self.contributions.insert(export.agent_id.clone(), AgentContribution {
381            trajectory_count: export.trajectories.len(),
382            avg_quality: export.stats.avg_quality,
383            timestamp: now,
384            session_duration_ms: export.session_duration_ms,
385        });
386
387        // Auto-consolidate if needed
388        let consolidated = if self.should_consolidate() {
389            self.master_engine.force_learn();
390            true
391        } else {
392            false
393        };
394
395        AggregationResult {
396            agent_id: export.agent_id,
397            trajectories_accepted: accepted,
398            trajectories_rejected: rejected,
399            consolidated,
400            total_agents: self.contributions.len(),
401            total_trajectories: self.total_trajectories,
402        }
403    }
404
405    /// Check if consolidation is needed
406    fn should_consolidate(&self) -> bool {
407        self.contributions.len() % self.consolidation_interval == 0
408    }
409
410    /// Force consolidation
411    pub fn force_consolidate(&self) -> String {
412        self.master_engine.force_learn()
413    }
414
415    /// Get initial state for new agents
416    ///
417    /// Returns learned patterns that new agents can use for warm start.
418    pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
419        // Find patterns similar to a general query (empty or average)
420        // Since we don't have a specific query, get all patterns
421        self.master_engine.find_patterns(&[], 0)
422            .into_iter()
423            .take(k)
424            .collect()
425    }
426
427    /// Get all learned patterns
428    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
429        self.master_engine.find_patterns(&[], 0)
430    }
431
432    /// Get coordinator statistics
433    pub fn stats(&self) -> CoordinatorStats {
434        let engine_stats = self.master_engine.stats();
435
436        CoordinatorStats {
437            coordinator_id: self.coordinator_id.clone(),
438            total_agents: self.contributions.len(),
439            total_trajectories: self.total_trajectories,
440            patterns_learned: engine_stats.patterns_stored,
441            avg_quality: self.metrics.avg_quality(),
442            quality_threshold: self.quality_threshold,
443        }
444    }
445
446    /// Get contribution history
447    pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
448        &self.contributions
449    }
450
451    /// Get metrics
452    pub fn metrics(&self) -> &TrainingMetrics {
453        &self.metrics
454    }
455
456    /// Get total number of contributing agents
457    pub fn agent_count(&self) -> usize {
458        self.contributions.len()
459    }
460
461    /// Get total trajectories aggregated
462    pub fn total_trajectories(&self) -> usize {
463        self.total_trajectories
464    }
465
466    /// Find similar patterns
467    pub fn find_patterns(&self, query: &[f32], k: usize) -> Vec<LearnedPattern> {
468        self.master_engine.find_patterns(query, k)
469    }
470
471    /// Apply coordinator's LoRA to input
472    pub fn apply_lora(&self, input: &[f32]) -> Vec<f32> {
473        let mut output = vec![0.0; input.len()];
474        self.master_engine.apply_micro_lora(input, &mut output);
475        output
476    }
477
478    /// Consolidate learning (alias for force_consolidate)
479    pub fn consolidate(&self) -> String {
480        self.force_consolidate()
481    }
482
483    /// Clear all contributions
484    pub fn clear(&mut self) {
485        self.contributions.clear();
486        self.total_trajectories = 0;
487    }
488}
489
490/// Result of aggregating an agent export
491#[derive(Clone, Debug, Serialize, Deserialize)]
492pub struct AggregationResult {
493    /// Agent ID that was aggregated
494    pub agent_id: String,
495    /// Number of trajectories accepted
496    pub trajectories_accepted: usize,
497    /// Number of trajectories rejected (below quality threshold)
498    pub trajectories_rejected: usize,
499    /// Whether consolidation was triggered
500    pub consolidated: bool,
501    /// Total number of contributing agents
502    pub total_agents: usize,
503    /// Total trajectories in coordinator
504    pub total_trajectories: usize,
505}
506
507/// Coordinator statistics
508#[derive(Clone, Debug, Serialize, Deserialize)]
509pub struct CoordinatorStats {
510    /// Coordinator identifier
511    pub coordinator_id: String,
512    /// Number of contributing agents
513    pub total_agents: usize,
514    /// Total trajectories aggregated
515    pub total_trajectories: usize,
516    /// Patterns learned
517    pub patterns_learned: usize,
518    /// Average quality across all contributions
519    pub avg_quality: f32,
520    /// Quality threshold
521    pub quality_threshold: f32,
522}
523
524impl std::fmt::Display for CoordinatorStats {
525    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        write!(
527            f,
528            "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
529            self.coordinator_id,
530            self.total_agents,
531            self.total_trajectories,
532            self.patterns_learned,
533            self.avg_quality
534        )
535    }
536}
537
538/// Federated learning topology
539#[derive(Clone, Debug, Serialize, Deserialize)]
540pub enum FederatedTopology {
541    /// Agents → Central Coordinator (simple, single aggregation point)
542    Star,
543    /// Agents → Regional → Global (multi-datacenter)
544    Hierarchical {
545        /// Number of regional coordinators
546        regions: usize,
547    },
548    /// Agents share directly (edge deployment)
549    PeerToPeer,
550}
551
552impl Default for FederatedTopology {
553    fn default() -> Self {
554        FederatedTopology::Star
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding width shared by every test.
    const DIM: usize = 256;

    /// Builds a context-free trajectory whose embedding repeats `fill`.
    fn trajectory(fill: f32, quality: f32, route: Option<&str>) -> TrajectoryExport {
        TrajectoryExport {
            embedding: vec![fill; DIM],
            quality,
            route: route.map(String::from),
            context: Vec::new(),
            timestamp: 0,
        }
    }

    /// A fresh agent reports its identifier and holds no trajectories.
    #[test]
    fn test_ephemeral_agent_creation() {
        let agent = EphemeralAgent::default_federated("agent-1", DIM);

        assert_eq!(agent.agent_id(), "agent-1");
        assert_eq!(agent.trajectory_count(), 0);
    }

    /// Processing one trajectory stores it and updates the average quality.
    #[test]
    fn test_trajectory_collection() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        let embedding = vec![0.1; DIM];
        let activations = vec![0.5; DIM];
        let route = Some(String::from("code"));
        let context = vec![String::from("file:main.rs")];
        agent.process_trajectory(embedding, activations, 0.8, route, context);

        assert_eq!(agent.trajectory_count(), 1);
        assert!((agent.avg_quality() - 0.8).abs() < 0.01);
    }

    /// An export carries the agent id, every trajectory and the statistics.
    #[test]
    fn test_agent_export() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        for i in 0..5 {
            let step = i as f32;
            agent.process_trajectory(
                vec![step * 0.1; DIM],
                vec![0.5; DIM],
                0.7 + step * 0.05,
                None,
                Vec::new(),
            );
        }

        let export = agent.export_state();
        assert_eq!(export.agent_id, "agent-1");
        assert_eq!(export.trajectories.len(), 5);
        assert!(export.stats.avg_quality > 0.7);
    }

    /// A fresh coordinator reports its identifier and empty statistics.
    #[test]
    fn test_coordinator_creation() {
        let coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        assert_eq!(coord.coordinator_id(), "coord-1");

        let stats = coord.stats();
        assert_eq!(stats.total_agents, 0);
        assert_eq!(stats.total_trajectories, 0);
    }

    /// Trajectories below the quality threshold are rejected.
    #[test]
    fn test_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        coord.set_quality_threshold(0.5);

        let export = AgentExport {
            agent_id: String::from("agent-1"),
            trajectories: vec![
                trajectory(0.1, 0.8, Some("code")),
                // Below threshold
                trajectory(0.2, 0.3, None),
            ],
            stats: AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
            session_duration_ms: 1000,
            timestamp: 0,
        };

        let result = coord.aggregate(export);
        assert_eq!(result.trajectories_accepted, 1);
        assert_eq!(result.trajectories_rejected, 1);
        assert_eq!(result.total_agents, 1);
    }

    /// Several agents accumulate, and the second one triggers consolidation.
    #[test]
    fn test_multi_agent_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        // Consolidate every 2 agents
        coord.set_consolidation_interval(2);

        for i in 0..3 {
            let export = AgentExport {
                agent_id: format!("agent-{}", i),
                trajectories: vec![trajectory(i as f32 * 0.1, 0.8, None)],
                stats: AgentExportStats::default(),
                session_duration_ms: 1000,
                timestamp: 0,
            };

            let result = coord.aggregate(export);
            // Second agent should trigger consolidation
            if i == 1 {
                assert!(result.consolidated);
            }
        }

        let stats = coord.stats();
        assert_eq!(stats.total_agents, 3);
        assert_eq!(stats.total_trajectories, 3);
    }
}
685        assert_eq!(stats.total_trajectories, 3);
686    }
687}