1use crate::engine::SonaEngine;
23use crate::types::{SonaConfig, LearnedPattern};
24use super::metrics::TrainingMetrics;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::time::{SystemTime, UNIX_EPOCH};
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct AgentExport {
32 pub agent_id: String,
34 pub trajectories: Vec<TrajectoryExport>,
36 pub stats: AgentExportStats,
38 pub session_duration_ms: u64,
40 pub timestamp: u64,
42}
43
44#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct TrajectoryExport {
47 pub embedding: Vec<f32>,
49 pub quality: f32,
51 pub route: Option<String>,
53 pub context: Vec<String>,
55 pub timestamp: u64,
57}
58
59#[derive(Clone, Debug, Default, Serialize, Deserialize)]
61pub struct AgentExportStats {
62 pub total_trajectories: usize,
64 pub avg_quality: f32,
66 pub patterns_learned: usize,
68}
69
70pub struct EphemeralAgent {
74 agent_id: String,
76 engine: SonaEngine,
78 trajectories: Vec<TrajectoryExport>,
80 start_time: u64,
82 quality_samples: Vec<f32>,
84}
85
86impl EphemeralAgent {
87 pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
89 let now = SystemTime::now()
90 .duration_since(UNIX_EPOCH)
91 .unwrap_or_default()
92 .as_millis() as u64;
93
94 Self {
95 agent_id: agent_id.into(),
96 engine: SonaEngine::with_config(config),
97 trajectories: Vec::new(),
98 start_time: now,
99 quality_samples: Vec::new(),
100 }
101 }
102
103 pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
105 Self::new(agent_id, SonaConfig {
106 hidden_dim,
107 embedding_dim: hidden_dim,
108 micro_lora_rank: 2,
109 base_lora_rank: 8,
110 micro_lora_lr: 0.002,
111 trajectory_capacity: 500, pattern_clusters: 25,
113 ..Default::default()
114 })
115 }
116
117 pub fn agent_id(&self) -> &str {
119 &self.agent_id
120 }
121
122 pub fn engine(&self) -> &SonaEngine {
124 &self.engine
125 }
126
127 pub fn engine_mut(&mut self) -> &mut SonaEngine {
129 &mut self.engine
130 }
131
132 pub fn process_trajectory(
134 &mut self,
135 embedding: Vec<f32>,
136 activations: Vec<f32>,
137 quality: f32,
138 route: Option<String>,
139 context: Vec<String>,
140 ) {
141 let now = SystemTime::now()
142 .duration_since(UNIX_EPOCH)
143 .unwrap_or_default()
144 .as_millis() as u64;
145
146 let mut builder = self.engine.begin_trajectory(embedding.clone());
148 if let Some(ref r) = route {
149 builder.set_model_route(r);
150 }
151 for ctx in &context {
152 builder.add_context(ctx);
153 }
154 builder.add_step(activations, vec![], quality);
155 self.engine.end_trajectory(builder, quality);
156
157 self.trajectories.push(TrajectoryExport {
159 embedding,
160 quality,
161 route,
162 context,
163 timestamp: now,
164 });
165
166 self.quality_samples.push(quality);
167 }
168
169 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
171 self.engine.apply_micro_lora(input, output);
172 }
173
174 pub fn trajectory_count(&self) -> usize {
176 self.trajectories.len()
177 }
178
179 pub fn avg_quality(&self) -> f32 {
181 if self.quality_samples.is_empty() {
182 0.0
183 } else {
184 self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
185 }
186 }
187
188 pub fn force_learn(&self) -> String {
190 self.engine.force_learn()
191 }
192
193 pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
195 self.process_trajectory(embedding.clone(), embedding, quality, None, vec![]);
196 }
197
198 pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
200 self.process_trajectory(embedding.clone(), embedding, quality, Some(route.to_string()), vec![]);
201 }
202
203 pub fn average_quality(&self) -> f32 {
205 self.avg_quality()
206 }
207
208 pub fn uptime_seconds(&self) -> u64 {
210 let now = SystemTime::now()
211 .duration_since(UNIX_EPOCH)
212 .unwrap_or_default()
213 .as_millis() as u64;
214 (now - self.start_time) / 1000
215 }
216
217 pub fn stats(&self) -> AgentExportStats {
219 let engine_stats = self.engine.stats();
220 AgentExportStats {
221 total_trajectories: self.trajectories.len(),
222 avg_quality: self.avg_quality(),
223 patterns_learned: engine_stats.patterns_stored,
224 }
225 }
226
227 pub fn clear(&mut self) {
229 self.trajectories.clear();
230 self.quality_samples.clear();
231 }
232
233 pub fn get_patterns(&self) -> Vec<LearnedPattern> {
235 self.engine.find_patterns(&[], 0)
236 }
237
238 pub fn export_state(&self) -> AgentExport {
242 let now = SystemTime::now()
243 .duration_since(UNIX_EPOCH)
244 .unwrap_or_default()
245 .as_millis() as u64;
246
247 self.engine.force_learn();
249
250 let stats = self.engine.stats();
251
252 AgentExport {
253 agent_id: self.agent_id.clone(),
254 trajectories: self.trajectories.clone(),
255 stats: AgentExportStats {
256 total_trajectories: self.trajectories.len(),
257 avg_quality: self.avg_quality(),
258 patterns_learned: stats.patterns_stored,
259 },
260 session_duration_ms: now - self.start_time,
261 timestamp: now,
262 }
263 }
264}
265
266#[derive(Clone, Debug, Serialize, Deserialize)]
268pub struct AgentContribution {
269 pub trajectory_count: usize,
271 pub avg_quality: f32,
273 pub timestamp: u64,
275 pub session_duration_ms: u64,
277}
278
279pub struct FederatedCoordinator {
283 coordinator_id: String,
285 master_engine: SonaEngine,
287 contributions: HashMap<String, AgentContribution>,
289 quality_threshold: f32,
291 total_trajectories: usize,
293 consolidation_interval: usize,
295 metrics: TrainingMetrics,
297}
298
299impl FederatedCoordinator {
300 pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
302 let id = coordinator_id.into();
303 Self {
304 coordinator_id: id.clone(),
305 master_engine: SonaEngine::with_config(config),
306 contributions: HashMap::new(),
307 quality_threshold: 0.4,
308 total_trajectories: 0,
309 consolidation_interval: 50,
310 metrics: TrainingMetrics::new(&id),
311 }
312 }
313
314 pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
316 Self::new(coordinator_id, SonaConfig {
317 hidden_dim,
318 embedding_dim: hidden_dim,
319 micro_lora_rank: 2,
320 base_lora_rank: 16, trajectory_capacity: 50000, pattern_clusters: 200,
323 ewc_lambda: 2000.0, ..Default::default()
325 })
326 }
327
328 pub fn coordinator_id(&self) -> &str {
330 &self.coordinator_id
331 }
332
333 pub fn set_quality_threshold(&mut self, threshold: f32) {
335 self.quality_threshold = threshold;
336 }
337
338 pub fn set_consolidation_interval(&mut self, interval: usize) {
340 self.consolidation_interval = interval;
341 }
342
343 pub fn master_engine(&self) -> &SonaEngine {
345 &self.master_engine
346 }
347
348 pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
350 let mut accepted = 0;
351 let mut rejected = 0;
352
353 for traj in &export.trajectories {
355 if traj.quality >= self.quality_threshold {
356 let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
357 if let Some(ref route) = traj.route {
358 builder.set_model_route(route);
359 }
360 for ctx in &traj.context {
361 builder.add_context(ctx);
362 }
363 self.master_engine.end_trajectory(builder, traj.quality);
364
365 self.metrics.add_quality_sample(traj.quality);
366 accepted += 1;
367 } else {
368 rejected += 1;
369 }
370 }
371
372 self.total_trajectories += accepted;
373
374 let now = SystemTime::now()
376 .duration_since(UNIX_EPOCH)
377 .unwrap_or_default()
378 .as_millis() as u64;
379
380 self.contributions.insert(export.agent_id.clone(), AgentContribution {
381 trajectory_count: export.trajectories.len(),
382 avg_quality: export.stats.avg_quality,
383 timestamp: now,
384 session_duration_ms: export.session_duration_ms,
385 });
386
387 let consolidated = if self.should_consolidate() {
389 self.master_engine.force_learn();
390 true
391 } else {
392 false
393 };
394
395 AggregationResult {
396 agent_id: export.agent_id,
397 trajectories_accepted: accepted,
398 trajectories_rejected: rejected,
399 consolidated,
400 total_agents: self.contributions.len(),
401 total_trajectories: self.total_trajectories,
402 }
403 }
404
405 fn should_consolidate(&self) -> bool {
407 self.contributions.len() % self.consolidation_interval == 0
408 }
409
410 pub fn force_consolidate(&self) -> String {
412 self.master_engine.force_learn()
413 }
414
415 pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
419 self.master_engine.find_patterns(&[], 0)
422 .into_iter()
423 .take(k)
424 .collect()
425 }
426
427 pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
429 self.master_engine.find_patterns(&[], 0)
430 }
431
432 pub fn stats(&self) -> CoordinatorStats {
434 let engine_stats = self.master_engine.stats();
435
436 CoordinatorStats {
437 coordinator_id: self.coordinator_id.clone(),
438 total_agents: self.contributions.len(),
439 total_trajectories: self.total_trajectories,
440 patterns_learned: engine_stats.patterns_stored,
441 avg_quality: self.metrics.avg_quality(),
442 quality_threshold: self.quality_threshold,
443 }
444 }
445
446 pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
448 &self.contributions
449 }
450
451 pub fn metrics(&self) -> &TrainingMetrics {
453 &self.metrics
454 }
455
456 pub fn agent_count(&self) -> usize {
458 self.contributions.len()
459 }
460
461 pub fn total_trajectories(&self) -> usize {
463 self.total_trajectories
464 }
465
466 pub fn find_patterns(&self, query: &[f32], k: usize) -> Vec<LearnedPattern> {
468 self.master_engine.find_patterns(query, k)
469 }
470
471 pub fn apply_lora(&self, input: &[f32]) -> Vec<f32> {
473 let mut output = vec![0.0; input.len()];
474 self.master_engine.apply_micro_lora(input, &mut output);
475 output
476 }
477
478 pub fn consolidate(&self) -> String {
480 self.force_consolidate()
481 }
482
483 pub fn clear(&mut self) {
485 self.contributions.clear();
486 self.total_trajectories = 0;
487 }
488}
489
490#[derive(Clone, Debug, Serialize, Deserialize)]
492pub struct AggregationResult {
493 pub agent_id: String,
495 pub trajectories_accepted: usize,
497 pub trajectories_rejected: usize,
499 pub consolidated: bool,
501 pub total_agents: usize,
503 pub total_trajectories: usize,
505}
506
507#[derive(Clone, Debug, Serialize, Deserialize)]
509pub struct CoordinatorStats {
510 pub coordinator_id: String,
512 pub total_agents: usize,
514 pub total_trajectories: usize,
516 pub patterns_learned: usize,
518 pub avg_quality: f32,
520 pub quality_threshold: f32,
522}
523
524impl std::fmt::Display for CoordinatorStats {
525 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 write!(
527 f,
528 "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
529 self.coordinator_id,
530 self.total_agents,
531 self.total_trajectories,
532 self.patterns_learned,
533 self.avg_quality
534 )
535 }
536}
537
538#[derive(Clone, Debug, Serialize, Deserialize)]
540pub enum FederatedTopology {
541 Star,
543 Hierarchical {
545 regions: usize,
547 },
548 PeerToPeer,
550}
551
552impl Default for FederatedTopology {
553 fn default() -> Self {
554 FederatedTopology::Star
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context-free trajectory whose 256-wide embedding is filled with `fill`.
    fn sample_trajectory(fill: f32, quality: f32, route: Option<&str>) -> TrajectoryExport {
        TrajectoryExport {
            embedding: vec![fill; 256],
            quality,
            route: route.map(String::from),
            context: Vec::new(),
            timestamp: 0,
        }
    }

    #[test]
    fn test_ephemeral_agent_creation() {
        let fresh = EphemeralAgent::default_federated("agent-1", 256);
        assert_eq!(fresh.agent_id(), "agent-1");
        assert_eq!(fresh.trajectory_count(), 0);
    }

    #[test]
    fn test_trajectory_collection() {
        let mut worker = EphemeralAgent::default_federated("agent-1", 256);

        worker.process_trajectory(
            vec![0.1; 256],
            vec![0.5; 256],
            0.8,
            Some(String::from("code")),
            vec![String::from("file:main.rs")],
        );

        assert_eq!(worker.trajectory_count(), 1);
        assert!((worker.avg_quality() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_agent_export() {
        let mut worker = EphemeralAgent::default_federated("agent-1", 256);

        for step in 0..5 {
            let step = step as f32;
            worker.process_trajectory(
                vec![step * 0.1; 256],
                vec![0.5; 256],
                0.7 + step * 0.05,
                None,
                Vec::new(),
            );
        }

        let snapshot = worker.export_state();
        assert_eq!(snapshot.agent_id, "agent-1");
        assert_eq!(snapshot.trajectories.len(), 5);
        assert!(snapshot.stats.avg_quality > 0.7);
    }

    #[test]
    fn test_coordinator_creation() {
        let hub = FederatedCoordinator::default_coordinator("coord-1", 256);
        assert_eq!(hub.coordinator_id(), "coord-1");

        let summary = hub.stats();
        assert_eq!(summary.total_agents, 0);
        assert_eq!(summary.total_trajectories, 0);
    }

    #[test]
    fn test_aggregation() {
        let mut hub = FederatedCoordinator::default_coordinator("coord-1", 256);
        hub.set_quality_threshold(0.5);

        // One trajectory above the 0.5 threshold, one below it.
        let payload = AgentExport {
            agent_id: String::from("agent-1"),
            trajectories: vec![
                sample_trajectory(0.1, 0.8, Some("code")),
                sample_trajectory(0.2, 0.3, None),
            ],
            stats: AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
            session_duration_ms: 1000,
            timestamp: 0,
        };

        let outcome = hub.aggregate(payload);
        assert_eq!(outcome.trajectories_accepted, 1);
        assert_eq!(outcome.trajectories_rejected, 1);
        assert_eq!(outcome.total_agents, 1);
    }

    #[test]
    fn test_multi_agent_aggregation() {
        let mut hub = FederatedCoordinator::default_coordinator("coord-1", 256);
        hub.set_consolidation_interval(2);

        for idx in 0..3 {
            let payload = AgentExport {
                agent_id: format!("agent-{}", idx),
                trajectories: vec![sample_trajectory(idx as f32 * 0.1, 0.8, None)],
                stats: AgentExportStats::default(),
                session_duration_ms: 1000,
                timestamp: 0,
            };

            let outcome = hub.aggregate(payload);
            // The second distinct agent hits the interval of 2.
            if idx == 1 {
                assert!(outcome.consolidated);
            }
        }

        let summary = hub.stats();
        assert_eq!(summary.total_agents, 3);
        assert_eq!(summary.total_trajectories, 3);
    }
}