1use crate::engine::SonaEngine;
23use crate::types::{SonaConfig, LearnedPattern};
24use super::metrics::TrainingMetrics;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::time::{SystemTime, UNIX_EPOCH};
28
/// Snapshot of an [`EphemeralAgent`]'s session, as produced by
/// `EphemeralAgent::export_state` and consumed by `FederatedCoordinator::aggregate`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentExport {
    /// Identifier of the exporting agent; the coordinator keys its contributions by it.
    pub agent_id: String,
    /// Every trajectory the agent recorded, in processing order.
    pub trajectories: Vec<TrajectoryExport>,
    /// Summary statistics computed at export time.
    pub stats: AgentExportStats,
    /// Wall-clock milliseconds between agent creation and export.
    pub session_duration_ms: u64,
    /// Export time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}
43
/// One trajectory recorded by an agent, in the form the coordinator replays.
///
/// The `activations` passed to `EphemeralAgent::process_trajectory` are not included.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryExport {
    /// Embedding the trajectory was started with.
    pub embedding: Vec<f32>,
    /// Quality score; the coordinator drops trajectories below its quality threshold.
    pub quality: f32,
    /// Optional model route, replayed via `set_model_route`.
    pub route: Option<String>,
    /// Context strings, replayed via `add_context`.
    pub context: Vec<String>,
    /// Milliseconds since the Unix epoch when the agent processed the trajectory.
    pub timestamp: u64,
}
58
/// Aggregate statistics attached to an [`AgentExport`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AgentExportStats {
    /// Number of trajectories in the export.
    pub total_trajectories: usize,
    /// Mean trajectory quality (0.0 when the agent processed none).
    pub avg_quality: f32,
    /// `patterns_stored` from the agent engine's stats, read after a forced learn.
    pub patterns_learned: usize,
}
69
/// Short-lived learning agent that owns a local [`SonaEngine`] and keeps a copy of every
/// trajectory it processes so the session can be exported to a coordinator.
pub struct EphemeralAgent {
    /// Identifier reported in exports.
    agent_id: String,
    /// Local engine that learns from processed trajectories.
    engine: SonaEngine,
    /// Exportable history, one entry per `process_trajectory` call.
    trajectories: Vec<TrajectoryExport>,
    /// Creation time in milliseconds since the Unix epoch.
    start_time: u64,
    /// Quality of every processed trajectory; used by `avg_quality`.
    quality_samples: Vec<f32>,
}
85
86impl EphemeralAgent {
87 pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
89 let now = SystemTime::now()
90 .duration_since(UNIX_EPOCH)
91 .unwrap_or_default()
92 .as_millis() as u64;
93
94 Self {
95 agent_id: agent_id.into(),
96 engine: SonaEngine::with_config(config),
97 trajectories: Vec::new(),
98 start_time: now,
99 quality_samples: Vec::new(),
100 }
101 }
102
103 pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
105 Self::new(agent_id, SonaConfig {
106 hidden_dim,
107 embedding_dim: hidden_dim,
108 micro_lora_rank: 2,
109 base_lora_rank: 8,
110 micro_lora_lr: 0.002,
111 trajectory_capacity: 500, pattern_clusters: 25,
113 ..Default::default()
114 })
115 }
116
117 pub fn agent_id(&self) -> &str {
119 &self.agent_id
120 }
121
122 pub fn engine(&self) -> &SonaEngine {
124 &self.engine
125 }
126
127 pub fn engine_mut(&mut self) -> &mut SonaEngine {
129 &mut self.engine
130 }
131
132 pub fn process_trajectory(
134 &mut self,
135 embedding: Vec<f32>,
136 activations: Vec<f32>,
137 quality: f32,
138 route: Option<String>,
139 context: Vec<String>,
140 ) {
141 let now = SystemTime::now()
142 .duration_since(UNIX_EPOCH)
143 .unwrap_or_default()
144 .as_millis() as u64;
145
146 let mut builder = self.engine.begin_trajectory(embedding.clone());
148 if let Some(ref r) = route {
149 builder.set_model_route(r);
150 }
151 for ctx in &context {
152 builder.add_context(ctx);
153 }
154 builder.add_step(activations, vec![], quality);
155 self.engine.end_trajectory(builder, quality);
156
157 self.trajectories.push(TrajectoryExport {
159 embedding,
160 quality,
161 route,
162 context,
163 timestamp: now,
164 });
165
166 self.quality_samples.push(quality);
167 }
168
169 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
171 self.engine.apply_micro_lora(input, output);
172 }
173
174 pub fn trajectory_count(&self) -> usize {
176 self.trajectories.len()
177 }
178
179 pub fn avg_quality(&self) -> f32 {
181 if self.quality_samples.is_empty() {
182 0.0
183 } else {
184 self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
185 }
186 }
187
188 pub fn force_learn(&self) {
190 self.engine.force_learn();
191 }
192
193 pub fn export_state(&self) -> AgentExport {
197 let now = SystemTime::now()
198 .duration_since(UNIX_EPOCH)
199 .unwrap_or_default()
200 .as_millis() as u64;
201
202 self.engine.force_learn();
204
205 let stats = self.engine.stats();
206
207 AgentExport {
208 agent_id: self.agent_id.clone(),
209 trajectories: self.trajectories.clone(),
210 stats: AgentExportStats {
211 total_trajectories: self.trajectories.len(),
212 avg_quality: self.avg_quality(),
213 patterns_learned: stats.patterns_stored,
214 },
215 session_duration_ms: now - self.start_time,
216 timestamp: now,
217 }
218 }
219}
220
/// Coordinator-side record of an agent's most recent export.
///
/// A later export from the same agent replaces the earlier record.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentContribution {
    /// Trajectories in the export, counting both accepted and rejected ones.
    pub trajectory_count: usize,
    /// The agent's self-reported `stats.avg_quality`; the coordinator does not recompute it.
    pub avg_quality: f32,
    /// Coordinator time of aggregation, in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Session duration reported by the agent.
    pub session_duration_ms: u64,
}
233
/// Central aggregator that replays sufficiently good agent trajectories into a master
/// [`SonaEngine`] and periodically triggers learning on it.
pub struct FederatedCoordinator {
    /// Identifier; also passed to `TrainingMetrics::new`.
    coordinator_id: String,
    /// Engine that receives every accepted trajectory.
    master_engine: SonaEngine,
    /// Latest contribution per agent id.
    contributions: HashMap<String, AgentContribution>,
    /// Trajectories need `quality >= quality_threshold` to be accepted (0.4 from `new`).
    quality_threshold: f32,
    /// Running count of accepted trajectories.
    total_trajectories: usize,
    /// How often, in distinct contributing agents, `aggregate` runs `force_learn`
    /// (50 from `new`).
    consolidation_interval: usize,
    /// Receives one quality sample per accepted trajectory.
    metrics: TrainingMetrics,
}
253
254impl FederatedCoordinator {
255 pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
257 let id = coordinator_id.into();
258 Self {
259 coordinator_id: id.clone(),
260 master_engine: SonaEngine::with_config(config),
261 contributions: HashMap::new(),
262 quality_threshold: 0.4,
263 total_trajectories: 0,
264 consolidation_interval: 50,
265 metrics: TrainingMetrics::new(&id),
266 }
267 }
268
269 pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
271 Self::new(coordinator_id, SonaConfig {
272 hidden_dim,
273 embedding_dim: hidden_dim,
274 micro_lora_rank: 2,
275 base_lora_rank: 16, trajectory_capacity: 50000, pattern_clusters: 200,
278 ewc_lambda: 2000.0, ..Default::default()
280 })
281 }
282
283 pub fn coordinator_id(&self) -> &str {
285 &self.coordinator_id
286 }
287
288 pub fn set_quality_threshold(&mut self, threshold: f32) {
290 self.quality_threshold = threshold;
291 }
292
293 pub fn set_consolidation_interval(&mut self, interval: usize) {
295 self.consolidation_interval = interval;
296 }
297
298 pub fn master_engine(&self) -> &SonaEngine {
300 &self.master_engine
301 }
302
303 pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
305 let mut accepted = 0;
306 let mut rejected = 0;
307
308 for traj in &export.trajectories {
310 if traj.quality >= self.quality_threshold {
311 let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
312 if let Some(ref route) = traj.route {
313 builder.set_model_route(route);
314 }
315 for ctx in &traj.context {
316 builder.add_context(ctx);
317 }
318 self.master_engine.end_trajectory(builder, traj.quality);
319
320 self.metrics.add_quality_sample(traj.quality);
321 accepted += 1;
322 } else {
323 rejected += 1;
324 }
325 }
326
327 self.total_trajectories += accepted;
328
329 let now = SystemTime::now()
331 .duration_since(UNIX_EPOCH)
332 .unwrap_or_default()
333 .as_millis() as u64;
334
335 self.contributions.insert(export.agent_id.clone(), AgentContribution {
336 trajectory_count: export.trajectories.len(),
337 avg_quality: export.stats.avg_quality,
338 timestamp: now,
339 session_duration_ms: export.session_duration_ms,
340 });
341
342 let consolidated = if self.should_consolidate() {
344 self.master_engine.force_learn();
345 true
346 } else {
347 false
348 };
349
350 AggregationResult {
351 agent_id: export.agent_id,
352 trajectories_accepted: accepted,
353 trajectories_rejected: rejected,
354 consolidated,
355 total_agents: self.contributions.len(),
356 total_trajectories: self.total_trajectories,
357 }
358 }
359
360 fn should_consolidate(&self) -> bool {
362 self.contributions.len() % self.consolidation_interval == 0
363 }
364
365 pub fn force_consolidate(&self) -> String {
367 self.master_engine.force_learn()
368 }
369
370 pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
374 self.master_engine.find_patterns(&[], 0)
377 .into_iter()
378 .take(k)
379 .collect()
380 }
381
382 pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
384 self.master_engine.find_patterns(&[], 0)
385 }
386
387 pub fn stats(&self) -> CoordinatorStats {
389 let engine_stats = self.master_engine.stats();
390
391 CoordinatorStats {
392 coordinator_id: self.coordinator_id.clone(),
393 total_agents: self.contributions.len(),
394 total_trajectories: self.total_trajectories,
395 patterns_learned: engine_stats.patterns_stored,
396 avg_quality: self.metrics.avg_quality(),
397 quality_threshold: self.quality_threshold,
398 }
399 }
400
401 pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
403 &self.contributions
404 }
405
406 pub fn metrics(&self) -> &TrainingMetrics {
408 &self.metrics
409 }
410}
411
412#[derive(Clone, Debug, Serialize, Deserialize)]
414pub struct AggregationResult {
415 pub agent_id: String,
417 pub trajectories_accepted: usize,
419 pub trajectories_rejected: usize,
421 pub consolidated: bool,
423 pub total_agents: usize,
425 pub total_trajectories: usize,
427}
428
/// Snapshot of a coordinator's state, as returned by `FederatedCoordinator::stats`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CoordinatorStats {
    /// Identifier the coordinator was created with.
    pub coordinator_id: String,
    /// Number of distinct agents with a recorded contribution.
    pub total_agents: usize,
    /// Accepted trajectories across all aggregations.
    pub total_trajectories: usize,
    /// `patterns_stored` reported by the master engine's stats.
    pub patterns_learned: usize,
    /// `TrainingMetrics::avg_quality()`, which receives one sample per accepted trajectory.
    pub avg_quality: f32,
    /// Minimum quality a trajectory needs to be accepted.
    pub quality_threshold: f32,
}
445
446impl std::fmt::Display for CoordinatorStats {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 write!(
449 f,
450 "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
451 self.coordinator_id,
452 self.total_agents,
453 self.total_trajectories,
454 self.patterns_learned,
455 self.avg_quality
456 )
457 }
458}
459
460#[derive(Clone, Debug, Serialize, Deserialize)]
462pub enum FederatedTopology {
463 Star,
465 Hierarchical {
467 regions: usize,
469 },
470 PeerToPeer,
472}
473
474impl Default for FederatedTopology {
475 fn default() -> Self {
476 FederatedTopology::Star
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 256-dim exported trajectory whose embedding is filled with `fill`.
    fn trajectory(fill: f32, quality: f32, route: Option<&str>) -> TrajectoryExport {
        TrajectoryExport {
            embedding: vec![fill; 256],
            quality,
            route: route.map(String::from),
            context: Vec::new(),
            timestamp: 0,
        }
    }

    /// Wraps trajectories in an export with a fixed 1000 ms session and zero timestamp.
    fn export_of(
        agent_id: String,
        trajectories: Vec<TrajectoryExport>,
        stats: AgentExportStats,
    ) -> AgentExport {
        AgentExport {
            agent_id,
            trajectories,
            stats,
            session_duration_ms: 1000,
            timestamp: 0,
        }
    }

    #[test]
    fn test_ephemeral_agent_creation() {
        let fresh = EphemeralAgent::default_federated("agent-1", 256);
        assert_eq!(fresh.agent_id(), "agent-1");
        assert_eq!(fresh.trajectory_count(), 0);
    }

    #[test]
    fn test_trajectory_collection() {
        let mut worker = EphemeralAgent::default_federated("agent-1", 256);

        let embedding = vec![0.1; 256];
        let activations = vec![0.5; 256];
        let context = vec!["file:main.rs".to_string()];
        worker.process_trajectory(embedding, activations, 0.8, Some("code".into()), context);

        assert_eq!(worker.trajectory_count(), 1);
        assert!((worker.avg_quality() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_agent_export() {
        let mut worker = EphemeralAgent::default_federated("agent-1", 256);

        for step in 0..5 {
            let scale = step as f32;
            worker.process_trajectory(
                vec![scale * 0.1; 256],
                vec![0.5; 256],
                0.7 + scale * 0.05,
                None,
                Vec::new(),
            );
        }

        let snapshot = worker.export_state();
        assert_eq!(snapshot.agent_id, "agent-1");
        assert_eq!(snapshot.trajectories.len(), 5);
        assert!(snapshot.stats.avg_quality > 0.7);
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = FederatedCoordinator::default_coordinator("coord-1", 256);
        assert_eq!(coordinator.coordinator_id(), "coord-1");

        let summary = coordinator.stats();
        assert_eq!(summary.total_agents, 0);
        assert_eq!(summary.total_trajectories, 0);
    }

    #[test]
    fn test_aggregation() {
        let mut coordinator = FederatedCoordinator::default_coordinator("coord-1", 256);
        coordinator.set_quality_threshold(0.5);

        // One trajectory above the 0.5 threshold and one below it.
        let payload = export_of(
            "agent-1".into(),
            vec![
                trajectory(0.1, 0.8, Some("code")),
                trajectory(0.2, 0.3, None),
            ],
            AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
        );

        let outcome = coordinator.aggregate(payload);
        assert_eq!(outcome.trajectories_accepted, 1);
        assert_eq!(outcome.trajectories_rejected, 1);
        assert_eq!(outcome.total_agents, 1);
    }

    #[test]
    fn test_multi_agent_aggregation() {
        let mut coordinator = FederatedCoordinator::default_coordinator("coord-1", 256);
        coordinator.set_consolidation_interval(2);

        for idx in 0..3 {
            let payload = export_of(
                format!("agent-{}", idx),
                vec![trajectory(idx as f32 * 0.1, 0.8, None)],
                AgentExportStats::default(),
            );

            let outcome = coordinator.aggregate(payload);
            // The second distinct agent hits the interval of 2.
            if idx == 1 {
                assert!(outcome.consolidated);
            }
        }

        let summary = coordinator.stats();
        assert_eq!(summary.total_agents, 3);
        assert_eq!(summary.total_trajectories, 3);
    }
}