Skip to main content

ruvector_sona/training/
federated.rs

1//! Federated Learning for SONA
2//!
3//! Enable distributed learning across ephemeral agents that share
4//! trajectories with a central coordinator.
5//!
6//! ## Architecture
7//!
8//! ```text
9//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
10//! │  Agent A    │     │  Agent B    │     │  Agent C    │
11//! │ (ephemeral) │     │ (ephemeral) │     │ (ephemeral) │
12//! └──────┬──────┘     └──────┬──────┘     └──────┬──────┘
13//!        │                   │                   │
14//!        │    export()       │    export()       │    export()
15//!        ▼                   ▼                   ▼
16//!   ┌────────────────────────────────────────────────┐
17//!   │            Federated Coordinator               │
18//!   │         (persistent, large capacity)           │
19//!   └────────────────────────────────────────────────┘
20//! ```
21
22use super::metrics::TrainingMetrics;
23use crate::engine::SonaEngine;
24use crate::time_compat::SystemTime;
25use crate::types::{LearnedPattern, SonaConfig};
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28
29/// Exported state from an ephemeral agent
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct AgentExport {
32    /// Agent identifier
33    pub agent_id: String,
34    /// Exported trajectories (embedding, quality pairs)
35    pub trajectories: Vec<TrajectoryExport>,
36    /// Agent statistics
37    pub stats: AgentExportStats,
38    /// Session duration in milliseconds
39    pub session_duration_ms: u64,
40    /// Export timestamp
41    pub timestamp: u64,
42}
43
44/// Single trajectory export
45#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct TrajectoryExport {
47    /// Query embedding
48    pub embedding: Vec<f32>,
49    /// Quality score
50    pub quality: f32,
51    /// Model route (if any)
52    pub route: Option<String>,
53    /// Context identifiers
54    pub context: Vec<String>,
55    /// Timestamp
56    pub timestamp: u64,
57}
58
59/// Agent export statistics
60#[derive(Clone, Debug, Default, Serialize, Deserialize)]
61pub struct AgentExportStats {
62    /// Total trajectories processed
63    pub total_trajectories: usize,
64    /// Average quality
65    pub avg_quality: f32,
66    /// Patterns learned locally
67    pub patterns_learned: usize,
68}
69
70/// Ephemeral agent for federated learning
71///
72/// Collects trajectories during its session and exports state before termination.
73pub struct EphemeralAgent {
74    /// Agent identifier
75    agent_id: String,
76    /// SONA engine
77    engine: SonaEngine,
78    /// Collected trajectories
79    trajectories: Vec<TrajectoryExport>,
80    /// Session start time
81    start_time: u64,
82    /// Quality samples
83    quality_samples: Vec<f32>,
84}
85
86impl EphemeralAgent {
87    /// Create a new ephemeral agent
88    pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
89        let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
90
91        Self {
92            agent_id: agent_id.into(),
93            engine: SonaEngine::with_config(config),
94            trajectories: Vec::new(),
95            start_time: now,
96            quality_samples: Vec::new(),
97        }
98    }
99
100    /// Create with default config for federated learning
101    pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
102        Self::new(
103            agent_id,
104            SonaConfig {
105                hidden_dim,
106                embedding_dim: hidden_dim,
107                micro_lora_rank: 2,
108                base_lora_rank: 8,
109                micro_lora_lr: 0.002,
110                trajectory_capacity: 500, // Small buffer per agent
111                pattern_clusters: 25,
112                ..Default::default()
113            },
114        )
115    }
116
117    /// Get agent ID
118    pub fn agent_id(&self) -> &str {
119        &self.agent_id
120    }
121
122    /// Get engine reference
123    pub fn engine(&self) -> &SonaEngine {
124        &self.engine
125    }
126
127    /// Get mutable engine reference
128    pub fn engine_mut(&mut self) -> &mut SonaEngine {
129        &mut self.engine
130    }
131
132    /// Process a task and record trajectory
133    pub fn process_trajectory(
134        &mut self,
135        embedding: Vec<f32>,
136        activations: Vec<f32>,
137        quality: f32,
138        route: Option<String>,
139        context: Vec<String>,
140    ) {
141        let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
142
143        // Record in SONA engine
144        let mut builder = self.engine.begin_trajectory(embedding.clone());
145        if let Some(ref r) = route {
146            builder.set_model_route(r);
147        }
148        for ctx in &context {
149            builder.add_context(ctx);
150        }
151        builder.add_step(activations, vec![], quality);
152        self.engine.end_trajectory(builder, quality);
153
154        // Store for export
155        self.trajectories.push(TrajectoryExport {
156            embedding,
157            quality,
158            route,
159            context,
160            timestamp: now,
161        });
162
163        self.quality_samples.push(quality);
164    }
165
166    /// Apply micro-LoRA to hidden states
167    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
168        self.engine.apply_micro_lora(input, output);
169    }
170
171    /// Get number of collected trajectories
172    pub fn trajectory_count(&self) -> usize {
173        self.trajectories.len()
174    }
175
176    /// Get average quality
177    pub fn avg_quality(&self) -> f32 {
178        if self.quality_samples.is_empty() {
179            0.0
180        } else {
181            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
182        }
183    }
184
185    /// Force local learning
186    pub fn force_learn(&self) -> String {
187        self.engine.force_learn()
188    }
189
190    /// Simple process task method
191    pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
192        self.process_trajectory(embedding.clone(), embedding, quality, None, vec![]);
193    }
194
195    /// Process task with route information
196    pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
197        self.process_trajectory(
198            embedding.clone(),
199            embedding,
200            quality,
201            Some(route.to_string()),
202            vec![],
203        );
204    }
205
206    /// Get average quality (alias for avg_quality)
207    pub fn average_quality(&self) -> f32 {
208        self.avg_quality()
209    }
210
211    /// Get uptime in seconds
212    pub fn uptime_seconds(&self) -> u64 {
213        let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
214        (now - self.start_time) / 1000
215    }
216
217    /// Get agent stats
218    pub fn stats(&self) -> AgentExportStats {
219        let engine_stats = self.engine.stats();
220        AgentExportStats {
221            total_trajectories: self.trajectories.len(),
222            avg_quality: self.avg_quality(),
223            patterns_learned: engine_stats.patterns_stored,
224        }
225    }
226
227    /// Clear trajectories (after export)
228    pub fn clear(&mut self) {
229        self.trajectories.clear();
230        self.quality_samples.clear();
231    }
232
233    /// Get learned patterns from agent
234    pub fn get_patterns(&self) -> Vec<LearnedPattern> {
235        self.engine.find_patterns(&[], 0)
236    }
237
238    /// Export agent state for federation
239    ///
240    /// Call this before terminating the agent.
241    pub fn export_state(&self) -> AgentExport {
242        let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
243
244        // Force learning before export
245        self.engine.force_learn();
246
247        let stats = self.engine.stats();
248
249        AgentExport {
250            agent_id: self.agent_id.clone(),
251            trajectories: self.trajectories.clone(),
252            stats: AgentExportStats {
253                total_trajectories: self.trajectories.len(),
254                avg_quality: self.avg_quality(),
255                patterns_learned: stats.patterns_stored,
256            },
257            session_duration_ms: now - self.start_time,
258            timestamp: now,
259        }
260    }
261}
262
263/// Agent contribution record
264#[derive(Clone, Debug, Serialize, Deserialize)]
265pub struct AgentContribution {
266    /// Number of trajectories contributed
267    pub trajectory_count: usize,
268    /// Average quality of contributions
269    pub avg_quality: f32,
270    /// Contribution timestamp
271    pub timestamp: u64,
272    /// Session duration
273    pub session_duration_ms: u64,
274}
275
276/// Federated learning coordinator
277///
278/// Aggregates learning from multiple ephemeral agents.
279pub struct FederatedCoordinator {
280    /// Coordinator identifier
281    coordinator_id: String,
282    /// Master SONA engine for aggregation
283    master_engine: SonaEngine,
284    /// Agent contributions
285    contributions: HashMap<String, AgentContribution>,
286    /// Quality threshold for accepting trajectories
287    quality_threshold: f32,
288    /// Total trajectories aggregated
289    total_trajectories: usize,
290    /// Consolidation interval (number of agents)
291    consolidation_interval: usize,
292    /// Metrics
293    metrics: TrainingMetrics,
294}
295
296impl FederatedCoordinator {
297    /// Create a new federated coordinator
298    pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
299        let id = coordinator_id.into();
300        Self {
301            coordinator_id: id.clone(),
302            master_engine: SonaEngine::with_config(config),
303            contributions: HashMap::new(),
304            quality_threshold: 0.4,
305            total_trajectories: 0,
306            consolidation_interval: 50,
307            metrics: TrainingMetrics::new(&id),
308        }
309    }
310
311    /// Create with default config for coordination
312    pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
313        Self::new(
314            coordinator_id,
315            SonaConfig {
316                hidden_dim,
317                embedding_dim: hidden_dim,
318                micro_lora_rank: 2,
319                base_lora_rank: 16,         // Deeper for aggregation
320                trajectory_capacity: 50000, // Large central buffer
321                pattern_clusters: 200,
322                ewc_lambda: 2000.0, // Strong regularization
323                ..Default::default()
324            },
325        )
326    }
327
328    /// Get coordinator ID
329    pub fn coordinator_id(&self) -> &str {
330        &self.coordinator_id
331    }
332
333    /// Set quality threshold for accepting trajectories
334    pub fn set_quality_threshold(&mut self, threshold: f32) {
335        self.quality_threshold = threshold;
336    }
337
338    /// Set consolidation interval
339    pub fn set_consolidation_interval(&mut self, interval: usize) {
340        self.consolidation_interval = interval;
341    }
342
343    /// Get master engine reference
344    pub fn master_engine(&self) -> &SonaEngine {
345        &self.master_engine
346    }
347
348    /// Aggregate agent export into coordinator
349    pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
350        let mut accepted = 0;
351        let mut rejected = 0;
352
353        // Replay trajectories into master engine
354        for traj in &export.trajectories {
355            if traj.quality >= self.quality_threshold {
356                let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
357                if let Some(ref route) = traj.route {
358                    builder.set_model_route(route);
359                }
360                for ctx in &traj.context {
361                    builder.add_context(ctx);
362                }
363                self.master_engine.end_trajectory(builder, traj.quality);
364
365                self.metrics.add_quality_sample(traj.quality);
366                accepted += 1;
367            } else {
368                rejected += 1;
369            }
370        }
371
372        self.total_trajectories += accepted;
373
374        // Record contribution
375        let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
376
377        self.contributions.insert(
378            export.agent_id.clone(),
379            AgentContribution {
380                trajectory_count: export.trajectories.len(),
381                avg_quality: export.stats.avg_quality,
382                timestamp: now,
383                session_duration_ms: export.session_duration_ms,
384            },
385        );
386
387        // Auto-consolidate if needed
388        let consolidated = if self.should_consolidate() {
389            self.master_engine.force_learn();
390            true
391        } else {
392            false
393        };
394
395        AggregationResult {
396            agent_id: export.agent_id,
397            trajectories_accepted: accepted,
398            trajectories_rejected: rejected,
399            consolidated,
400            total_agents: self.contributions.len(),
401            total_trajectories: self.total_trajectories,
402        }
403    }
404
405    /// Check if consolidation is needed
406    fn should_consolidate(&self) -> bool {
407        self.contributions.len() % self.consolidation_interval == 0
408    }
409
410    /// Force consolidation
411    pub fn force_consolidate(&self) -> String {
412        self.master_engine.force_learn()
413    }
414
415    /// Get initial state for new agents
416    ///
417    /// Returns learned patterns that new agents can use for warm start.
418    pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
419        // Find patterns similar to a general query (empty or average)
420        // Since we don't have a specific query, get all patterns
421        self.master_engine
422            .find_patterns(&[], 0)
423            .into_iter()
424            .take(k)
425            .collect()
426    }
427
428    /// Get all learned patterns
429    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
430        self.master_engine.find_patterns(&[], 0)
431    }
432
433    /// Get coordinator statistics
434    pub fn stats(&self) -> CoordinatorStats {
435        let engine_stats = self.master_engine.stats();
436
437        CoordinatorStats {
438            coordinator_id: self.coordinator_id.clone(),
439            total_agents: self.contributions.len(),
440            total_trajectories: self.total_trajectories,
441            patterns_learned: engine_stats.patterns_stored,
442            avg_quality: self.metrics.avg_quality(),
443            quality_threshold: self.quality_threshold,
444        }
445    }
446
447    /// Get contribution history
448    pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
449        &self.contributions
450    }
451
452    /// Get metrics
453    pub fn metrics(&self) -> &TrainingMetrics {
454        &self.metrics
455    }
456
457    /// Get total number of contributing agents
458    pub fn agent_count(&self) -> usize {
459        self.contributions.len()
460    }
461
462    /// Get total trajectories aggregated
463    pub fn total_trajectories(&self) -> usize {
464        self.total_trajectories
465    }
466
467    /// Find similar patterns
468    pub fn find_patterns(&self, query: &[f32], k: usize) -> Vec<LearnedPattern> {
469        self.master_engine.find_patterns(query, k)
470    }
471
472    /// Apply coordinator's LoRA to input
473    pub fn apply_lora(&self, input: &[f32]) -> Vec<f32> {
474        let mut output = vec![0.0; input.len()];
475        self.master_engine.apply_micro_lora(input, &mut output);
476        output
477    }
478
479    /// Consolidate learning (alias for force_consolidate)
480    pub fn consolidate(&self) -> String {
481        self.force_consolidate()
482    }
483
484    /// Clear all contributions
485    pub fn clear(&mut self) {
486        self.contributions.clear();
487        self.total_trajectories = 0;
488    }
489}
490
491/// Result of aggregating an agent export
492#[derive(Clone, Debug, Serialize, Deserialize)]
493pub struct AggregationResult {
494    /// Agent ID that was aggregated
495    pub agent_id: String,
496    /// Number of trajectories accepted
497    pub trajectories_accepted: usize,
498    /// Number of trajectories rejected (below quality threshold)
499    pub trajectories_rejected: usize,
500    /// Whether consolidation was triggered
501    pub consolidated: bool,
502    /// Total number of contributing agents
503    pub total_agents: usize,
504    /// Total trajectories in coordinator
505    pub total_trajectories: usize,
506}
507
508/// Coordinator statistics
509#[derive(Clone, Debug, Serialize, Deserialize)]
510pub struct CoordinatorStats {
511    /// Coordinator identifier
512    pub coordinator_id: String,
513    /// Number of contributing agents
514    pub total_agents: usize,
515    /// Total trajectories aggregated
516    pub total_trajectories: usize,
517    /// Patterns learned
518    pub patterns_learned: usize,
519    /// Average quality across all contributions
520    pub avg_quality: f32,
521    /// Quality threshold
522    pub quality_threshold: f32,
523}
524
525impl std::fmt::Display for CoordinatorStats {
526    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        write!(
528            f,
529            "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
530            self.coordinator_id,
531            self.total_agents,
532            self.total_trajectories,
533            self.patterns_learned,
534            self.avg_quality
535        )
536    }
537}
538
539/// Federated learning topology
540#[derive(Clone, Debug, Serialize, Deserialize)]
541pub enum FederatedTopology {
542    /// Agents → Central Coordinator (simple, single aggregation point)
543    Star,
544    /// Agents → Regional → Global (multi-datacenter)
545    Hierarchical {
546        /// Number of regional coordinators
547        regions: usize,
548    },
549    /// Agents share directly (edge deployment)
550    PeerToPeer,
551}
552
553impl Default for FederatedTopology {
554    fn default() -> Self {
555        FederatedTopology::Star
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Hidden/embedding width used by every fixture in this module.
    const DIM: usize = 256;

    /// Builds a context-free trajectory export with a constant-valued
    /// embedding and a zero timestamp.
    fn trajectory(fill: f32, quality: f32, route: Option<&str>) -> TrajectoryExport {
        TrajectoryExport {
            embedding: vec![fill; DIM],
            quality,
            route: route.map(String::from),
            context: Vec::new(),
            timestamp: 0,
        }
    }

    #[test]
    fn test_ephemeral_agent_creation() {
        let agent = EphemeralAgent::default_federated("agent-1", DIM);

        assert_eq!(agent.agent_id(), "agent-1");
        assert_eq!(agent.trajectory_count(), 0);
    }

    #[test]
    fn test_trajectory_collection() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        let embedding = vec![0.1; DIM];
        let activations = vec![0.5; DIM];
        let route = Some(String::from("code"));
        let context = vec![String::from("file:main.rs")];
        agent.process_trajectory(embedding, activations, 0.8, route, context);

        assert_eq!(agent.trajectory_count(), 1);
        assert!((agent.avg_quality() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_agent_export() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        // Five trajectories with rising embedding values and rising quality.
        for step in 0..5 {
            let offset = step as f32;
            agent.process_trajectory(
                vec![offset * 0.1; DIM],
                vec![0.5; DIM],
                0.7 + offset * 0.05,
                None,
                Vec::new(),
            );
        }

        let export = agent.export_state();
        assert_eq!(export.agent_id, "agent-1");
        assert_eq!(export.trajectories.len(), 5);
        assert!(export.stats.avg_quality > 0.7);
    }

    #[test]
    fn test_coordinator_creation() {
        let coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        assert_eq!(coord.coordinator_id(), "coord-1");

        let snapshot = coord.stats();
        assert_eq!(snapshot.total_agents, 0);
        assert_eq!(snapshot.total_trajectories, 0);
    }

    #[test]
    fn test_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        coord.set_quality_threshold(0.5);

        // One trajectory above the threshold and one below it.
        let above = trajectory(0.1, 0.8, Some("code"));
        let below = trajectory(0.2, 0.3, None);

        let export = AgentExport {
            agent_id: String::from("agent-1"),
            trajectories: vec![above, below],
            stats: AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
            session_duration_ms: 1000,
            timestamp: 0,
        };

        let outcome = coord.aggregate(export);
        assert_eq!(outcome.trajectories_accepted, 1);
        assert_eq!(outcome.trajectories_rejected, 1);
        assert_eq!(outcome.total_agents, 1);
    }

    #[test]
    fn test_multi_agent_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        // Consolidate every 2 agents.
        coord.set_consolidation_interval(2);

        for index in 0..3 {
            let export = AgentExport {
                agent_id: format!("agent-{}", index),
                trajectories: vec![trajectory(index as f32 * 0.1, 0.8, None)],
                stats: AgentExportStats::default(),
                session_duration_ms: 1000,
                timestamp: 0,
            };

            let outcome = coord.aggregate(export);
            // The second agent brings the agent count to a multiple of the interval.
            if index == 1 {
                assert!(outcome.consolidated);
            }
        }

        let snapshot = coord.stats();
        assert_eq!(snapshot.total_agents, 3);
        assert_eq!(snapshot.total_trajectories, 3);
    }
}