ruvector_sona/training/
federated.rs

1//! Federated Learning for SONA
2//!
3//! Enable distributed learning across ephemeral agents that share
4//! trajectories with a central coordinator.
5//!
6//! ## Architecture
7//!
8//! ```text
9//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
10//! │  Agent A    │     │  Agent B    │     │  Agent C    │
11//! │ (ephemeral) │     │ (ephemeral) │     │ (ephemeral) │
12//! └──────┬──────┘     └──────┬──────┘     └──────┬──────┘
13//!        │                   │                   │
14//!        │    export()       │    export()       │    export()
15//!        ▼                   ▼                   ▼
16//!   ┌────────────────────────────────────────────────┐
17//!   │            Federated Coordinator               │
18//!   │         (persistent, large capacity)           │
19//!   └────────────────────────────────────────────────┘
20//! ```
21
22use crate::engine::SonaEngine;
23use crate::types::{SonaConfig, LearnedPattern};
24use super::metrics::TrainingMetrics;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::time::{SystemTime, UNIX_EPOCH};
28
29/// Exported state from an ephemeral agent
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct AgentExport {
32    /// Agent identifier
33    pub agent_id: String,
34    /// Exported trajectories (embedding, quality pairs)
35    pub trajectories: Vec<TrajectoryExport>,
36    /// Agent statistics
37    pub stats: AgentExportStats,
38    /// Session duration in milliseconds
39    pub session_duration_ms: u64,
40    /// Export timestamp
41    pub timestamp: u64,
42}
43
44/// Single trajectory export
45#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct TrajectoryExport {
47    /// Query embedding
48    pub embedding: Vec<f32>,
49    /// Quality score
50    pub quality: f32,
51    /// Model route (if any)
52    pub route: Option<String>,
53    /// Context identifiers
54    pub context: Vec<String>,
55    /// Timestamp
56    pub timestamp: u64,
57}
58
59/// Agent export statistics
60#[derive(Clone, Debug, Default, Serialize, Deserialize)]
61pub struct AgentExportStats {
62    /// Total trajectories processed
63    pub total_trajectories: usize,
64    /// Average quality
65    pub avg_quality: f32,
66    /// Patterns learned locally
67    pub patterns_learned: usize,
68}
69
70/// Ephemeral agent for federated learning
71///
72/// Collects trajectories during its session and exports state before termination.
73pub struct EphemeralAgent {
74    /// Agent identifier
75    agent_id: String,
76    /// SONA engine
77    engine: SonaEngine,
78    /// Collected trajectories
79    trajectories: Vec<TrajectoryExport>,
80    /// Session start time
81    start_time: u64,
82    /// Quality samples
83    quality_samples: Vec<f32>,
84}
85
86impl EphemeralAgent {
87    /// Create a new ephemeral agent
88    pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
89        let now = SystemTime::now()
90            .duration_since(UNIX_EPOCH)
91            .unwrap_or_default()
92            .as_millis() as u64;
93
94        Self {
95            agent_id: agent_id.into(),
96            engine: SonaEngine::with_config(config),
97            trajectories: Vec::new(),
98            start_time: now,
99            quality_samples: Vec::new(),
100        }
101    }
102
103    /// Create with default config for federated learning
104    pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
105        Self::new(agent_id, SonaConfig {
106            hidden_dim,
107            embedding_dim: hidden_dim,
108            micro_lora_rank: 2,
109            base_lora_rank: 8,
110            micro_lora_lr: 0.002,
111            trajectory_capacity: 500,  // Small buffer per agent
112            pattern_clusters: 25,
113            ..Default::default()
114        })
115    }
116
117    /// Get agent ID
118    pub fn agent_id(&self) -> &str {
119        &self.agent_id
120    }
121
122    /// Get engine reference
123    pub fn engine(&self) -> &SonaEngine {
124        &self.engine
125    }
126
127    /// Get mutable engine reference
128    pub fn engine_mut(&mut self) -> &mut SonaEngine {
129        &mut self.engine
130    }
131
132    /// Process a task and record trajectory
133    pub fn process_trajectory(
134        &mut self,
135        embedding: Vec<f32>,
136        activations: Vec<f32>,
137        quality: f32,
138        route: Option<String>,
139        context: Vec<String>,
140    ) {
141        let now = SystemTime::now()
142            .duration_since(UNIX_EPOCH)
143            .unwrap_or_default()
144            .as_millis() as u64;
145
146        // Record in SONA engine
147        let mut builder = self.engine.begin_trajectory(embedding.clone());
148        if let Some(ref r) = route {
149            builder.set_model_route(r);
150        }
151        for ctx in &context {
152            builder.add_context(ctx);
153        }
154        builder.add_step(activations, vec![], quality);
155        self.engine.end_trajectory(builder, quality);
156
157        // Store for export
158        self.trajectories.push(TrajectoryExport {
159            embedding,
160            quality,
161            route,
162            context,
163            timestamp: now,
164        });
165
166        self.quality_samples.push(quality);
167    }
168
169    /// Apply micro-LoRA to hidden states
170    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
171        self.engine.apply_micro_lora(input, output);
172    }
173
174    /// Get number of collected trajectories
175    pub fn trajectory_count(&self) -> usize {
176        self.trajectories.len()
177    }
178
179    /// Get average quality
180    pub fn avg_quality(&self) -> f32 {
181        if self.quality_samples.is_empty() {
182            0.0
183        } else {
184            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
185        }
186    }
187
188    /// Force local learning
189    pub fn force_learn(&self) {
190        self.engine.force_learn();
191    }
192
193    /// Export agent state for federation
194    ///
195    /// Call this before terminating the agent.
196    pub fn export_state(&self) -> AgentExport {
197        let now = SystemTime::now()
198            .duration_since(UNIX_EPOCH)
199            .unwrap_or_default()
200            .as_millis() as u64;
201
202        // Force learning before export
203        self.engine.force_learn();
204
205        let stats = self.engine.stats();
206
207        AgentExport {
208            agent_id: self.agent_id.clone(),
209            trajectories: self.trajectories.clone(),
210            stats: AgentExportStats {
211                total_trajectories: self.trajectories.len(),
212                avg_quality: self.avg_quality(),
213                patterns_learned: stats.patterns_stored,
214            },
215            session_duration_ms: now - self.start_time,
216            timestamp: now,
217        }
218    }
219}
220
221/// Agent contribution record
222#[derive(Clone, Debug, Serialize, Deserialize)]
223pub struct AgentContribution {
224    /// Number of trajectories contributed
225    pub trajectory_count: usize,
226    /// Average quality of contributions
227    pub avg_quality: f32,
228    /// Contribution timestamp
229    pub timestamp: u64,
230    /// Session duration
231    pub session_duration_ms: u64,
232}
233
234/// Federated learning coordinator
235///
236/// Aggregates learning from multiple ephemeral agents.
237pub struct FederatedCoordinator {
238    /// Coordinator identifier
239    coordinator_id: String,
240    /// Master SONA engine for aggregation
241    master_engine: SonaEngine,
242    /// Agent contributions
243    contributions: HashMap<String, AgentContribution>,
244    /// Quality threshold for accepting trajectories
245    quality_threshold: f32,
246    /// Total trajectories aggregated
247    total_trajectories: usize,
248    /// Consolidation interval (number of agents)
249    consolidation_interval: usize,
250    /// Metrics
251    metrics: TrainingMetrics,
252}
253
254impl FederatedCoordinator {
255    /// Create a new federated coordinator
256    pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
257        let id = coordinator_id.into();
258        Self {
259            coordinator_id: id.clone(),
260            master_engine: SonaEngine::with_config(config),
261            contributions: HashMap::new(),
262            quality_threshold: 0.4,
263            total_trajectories: 0,
264            consolidation_interval: 50,
265            metrics: TrainingMetrics::new(&id),
266        }
267    }
268
269    /// Create with default config for coordination
270    pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
271        Self::new(coordinator_id, SonaConfig {
272            hidden_dim,
273            embedding_dim: hidden_dim,
274            micro_lora_rank: 2,
275            base_lora_rank: 16,          // Deeper for aggregation
276            trajectory_capacity: 50000,   // Large central buffer
277            pattern_clusters: 200,
278            ewc_lambda: 2000.0,          // Strong regularization
279            ..Default::default()
280        })
281    }
282
283    /// Get coordinator ID
284    pub fn coordinator_id(&self) -> &str {
285        &self.coordinator_id
286    }
287
288    /// Set quality threshold for accepting trajectories
289    pub fn set_quality_threshold(&mut self, threshold: f32) {
290        self.quality_threshold = threshold;
291    }
292
293    /// Set consolidation interval
294    pub fn set_consolidation_interval(&mut self, interval: usize) {
295        self.consolidation_interval = interval;
296    }
297
298    /// Get master engine reference
299    pub fn master_engine(&self) -> &SonaEngine {
300        &self.master_engine
301    }
302
303    /// Aggregate agent export into coordinator
304    pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
305        let mut accepted = 0;
306        let mut rejected = 0;
307
308        // Replay trajectories into master engine
309        for traj in &export.trajectories {
310            if traj.quality >= self.quality_threshold {
311                let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
312                if let Some(ref route) = traj.route {
313                    builder.set_model_route(route);
314                }
315                for ctx in &traj.context {
316                    builder.add_context(ctx);
317                }
318                self.master_engine.end_trajectory(builder, traj.quality);
319
320                self.metrics.add_quality_sample(traj.quality);
321                accepted += 1;
322            } else {
323                rejected += 1;
324            }
325        }
326
327        self.total_trajectories += accepted;
328
329        // Record contribution
330        let now = SystemTime::now()
331            .duration_since(UNIX_EPOCH)
332            .unwrap_or_default()
333            .as_millis() as u64;
334
335        self.contributions.insert(export.agent_id.clone(), AgentContribution {
336            trajectory_count: export.trajectories.len(),
337            avg_quality: export.stats.avg_quality,
338            timestamp: now,
339            session_duration_ms: export.session_duration_ms,
340        });
341
342        // Auto-consolidate if needed
343        let consolidated = if self.should_consolidate() {
344            self.master_engine.force_learn();
345            true
346        } else {
347            false
348        };
349
350        AggregationResult {
351            agent_id: export.agent_id,
352            trajectories_accepted: accepted,
353            trajectories_rejected: rejected,
354            consolidated,
355            total_agents: self.contributions.len(),
356            total_trajectories: self.total_trajectories,
357        }
358    }
359
360    /// Check if consolidation is needed
361    fn should_consolidate(&self) -> bool {
362        self.contributions.len() % self.consolidation_interval == 0
363    }
364
365    /// Force consolidation
366    pub fn force_consolidate(&self) -> String {
367        self.master_engine.force_learn()
368    }
369
370    /// Get initial state for new agents
371    ///
372    /// Returns learned patterns that new agents can use for warm start.
373    pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
374        // Find patterns similar to a general query (empty or average)
375        // Since we don't have a specific query, get all patterns
376        self.master_engine.find_patterns(&[], 0)
377            .into_iter()
378            .take(k)
379            .collect()
380    }
381
382    /// Get all learned patterns
383    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
384        self.master_engine.find_patterns(&[], 0)
385    }
386
387    /// Get coordinator statistics
388    pub fn stats(&self) -> CoordinatorStats {
389        let engine_stats = self.master_engine.stats();
390
391        CoordinatorStats {
392            coordinator_id: self.coordinator_id.clone(),
393            total_agents: self.contributions.len(),
394            total_trajectories: self.total_trajectories,
395            patterns_learned: engine_stats.patterns_stored,
396            avg_quality: self.metrics.avg_quality(),
397            quality_threshold: self.quality_threshold,
398        }
399    }
400
401    /// Get contribution history
402    pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
403        &self.contributions
404    }
405
406    /// Get metrics
407    pub fn metrics(&self) -> &TrainingMetrics {
408        &self.metrics
409    }
410}
411
412/// Result of aggregating an agent export
413#[derive(Clone, Debug, Serialize, Deserialize)]
414pub struct AggregationResult {
415    /// Agent ID that was aggregated
416    pub agent_id: String,
417    /// Number of trajectories accepted
418    pub trajectories_accepted: usize,
419    /// Number of trajectories rejected (below quality threshold)
420    pub trajectories_rejected: usize,
421    /// Whether consolidation was triggered
422    pub consolidated: bool,
423    /// Total number of contributing agents
424    pub total_agents: usize,
425    /// Total trajectories in coordinator
426    pub total_trajectories: usize,
427}
428
429/// Coordinator statistics
430#[derive(Clone, Debug, Serialize, Deserialize)]
431pub struct CoordinatorStats {
432    /// Coordinator identifier
433    pub coordinator_id: String,
434    /// Number of contributing agents
435    pub total_agents: usize,
436    /// Total trajectories aggregated
437    pub total_trajectories: usize,
438    /// Patterns learned
439    pub patterns_learned: usize,
440    /// Average quality across all contributions
441    pub avg_quality: f32,
442    /// Quality threshold
443    pub quality_threshold: f32,
444}
445
446impl std::fmt::Display for CoordinatorStats {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        write!(
449            f,
450            "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
451            self.coordinator_id,
452            self.total_agents,
453            self.total_trajectories,
454            self.patterns_learned,
455            self.avg_quality
456        )
457    }
458}
459
460/// Federated learning topology
461#[derive(Clone, Debug, Serialize, Deserialize)]
462pub enum FederatedTopology {
463    /// Agents → Central Coordinator (simple, single aggregation point)
464    Star,
465    /// Agents → Regional → Global (multi-datacenter)
466    Hierarchical {
467        /// Number of regional coordinators
468        regions: usize,
469    },
470    /// Agents share directly (edge deployment)
471    PeerToPeer,
472}
473
474impl Default for FederatedTopology {
475    fn default() -> Self {
476        FederatedTopology::Star
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Hidden/embedding width used by every fixture.
    const DIM: usize = 256;

    /// Builds a trajectory whose embedding is `fill` repeated `DIM` times,
    /// with no context and a zero timestamp.
    fn trajectory(fill: f32, quality: f32, route: Option<&str>) -> TrajectoryExport {
        TrajectoryExport {
            embedding: vec![fill; DIM],
            quality,
            route: route.map(String::from),
            context: Vec::new(),
            timestamp: 0,
        }
    }

    /// Wraps trajectories in an export with a 1000 ms session and a zero
    /// timestamp.
    fn export_for(
        agent_id: impl Into<String>,
        trajectories: Vec<TrajectoryExport>,
        stats: AgentExportStats,
    ) -> AgentExport {
        AgentExport {
            agent_id: agent_id.into(),
            trajectories,
            stats,
            session_duration_ms: 1000,
            timestamp: 0,
        }
    }

    #[test]
    fn test_ephemeral_agent_creation() {
        let agent = EphemeralAgent::default_federated("agent-1", DIM);

        assert_eq!(agent.agent_id(), "agent-1");
        assert_eq!(agent.trajectory_count(), 0);
    }

    #[test]
    fn test_trajectory_collection() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        let embedding = vec![0.1; DIM];
        let activations = vec![0.5; DIM];
        agent.process_trajectory(
            embedding,
            activations,
            0.8,
            Some("code".into()),
            vec!["file:main.rs".into()],
        );

        assert_eq!(agent.trajectory_count(), 1);
        assert!((agent.avg_quality() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_agent_export() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        // Five trajectories with qualities 0.70, 0.75, ..., 0.90.
        for i in 0..5 {
            let step = i as f32;
            agent.process_trajectory(
                vec![step * 0.1; DIM],
                vec![0.5; DIM],
                0.7 + step * 0.05,
                None,
                vec![],
            );
        }

        let export = agent.export_state();

        assert_eq!(export.agent_id, "agent-1");
        assert_eq!(export.trajectories.len(), 5);
        assert!(export.stats.avg_quality > 0.7);
    }

    #[test]
    fn test_coordinator_creation() {
        let coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        assert_eq!(coord.coordinator_id(), "coord-1");

        let snapshot = coord.stats();
        assert_eq!(snapshot.total_agents, 0);
        assert_eq!(snapshot.total_trajectories, 0);
    }

    #[test]
    fn test_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        coord.set_quality_threshold(0.5);

        // One trajectory above the threshold, one below it.
        let export = export_for(
            "agent-1",
            vec![
                trajectory(0.1, 0.8, Some("code")),
                trajectory(0.2, 0.3, None),
            ],
            AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
        );

        let outcome = coord.aggregate(export);

        assert_eq!(outcome.trajectories_accepted, 1);
        assert_eq!(outcome.trajectories_rejected, 1);
        assert_eq!(outcome.total_agents, 1);
    }

    #[test]
    fn test_multi_agent_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        coord.set_consolidation_interval(2); // Consolidate every 2 agents

        for i in 0..3 {
            let export = export_for(
                format!("agent-{}", i),
                vec![trajectory(i as f32 * 0.1, 0.8, None)],
                AgentExportStats::default(),
            );

            let outcome = coord.aggregate(export);

            // The second agent brings the agent count to the interval.
            if i == 1 {
                assert!(outcome.consolidated);
            }
        }

        let snapshot = coord.stats();
        assert_eq!(snapshot.total_agents, 3);
        assert_eq!(snapshot.total_trajectories, 3);
    }
}