1use super::metrics::TrainingMetrics;
23use crate::engine::SonaEngine;
24use crate::time_compat::SystemTime;
25use crate::types::{LearnedPattern, SonaConfig};
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct AgentExport {
32 pub agent_id: String,
34 pub trajectories: Vec<TrajectoryExport>,
36 pub stats: AgentExportStats,
38 pub session_duration_ms: u64,
40 pub timestamp: u64,
42}
43
44#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct TrajectoryExport {
47 pub embedding: Vec<f32>,
49 pub quality: f32,
51 pub route: Option<String>,
53 pub context: Vec<String>,
55 pub timestamp: u64,
57}
58
59#[derive(Clone, Debug, Default, Serialize, Deserialize)]
61pub struct AgentExportStats {
62 pub total_trajectories: usize,
64 pub avg_quality: f32,
66 pub patterns_learned: usize,
68}
69
70pub struct EphemeralAgent {
74 agent_id: String,
76 engine: SonaEngine,
78 trajectories: Vec<TrajectoryExport>,
80 start_time: u64,
82 quality_samples: Vec<f32>,
84}
85
86impl EphemeralAgent {
87 pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
89 let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
90
91 Self {
92 agent_id: agent_id.into(),
93 engine: SonaEngine::with_config(config),
94 trajectories: Vec::new(),
95 start_time: now,
96 quality_samples: Vec::new(),
97 }
98 }
99
100 pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
102 Self::new(
103 agent_id,
104 SonaConfig {
105 hidden_dim,
106 embedding_dim: hidden_dim,
107 micro_lora_rank: 2,
108 base_lora_rank: 8,
109 micro_lora_lr: 0.002,
110 trajectory_capacity: 500, pattern_clusters: 25,
112 ..Default::default()
113 },
114 )
115 }
116
117 pub fn agent_id(&self) -> &str {
119 &self.agent_id
120 }
121
122 pub fn engine(&self) -> &SonaEngine {
124 &self.engine
125 }
126
127 pub fn engine_mut(&mut self) -> &mut SonaEngine {
129 &mut self.engine
130 }
131
132 pub fn process_trajectory(
134 &mut self,
135 embedding: Vec<f32>,
136 activations: Vec<f32>,
137 quality: f32,
138 route: Option<String>,
139 context: Vec<String>,
140 ) {
141 let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
142
143 let mut builder = self.engine.begin_trajectory(embedding.clone());
145 if let Some(ref r) = route {
146 builder.set_model_route(r);
147 }
148 for ctx in &context {
149 builder.add_context(ctx);
150 }
151 builder.add_step(activations, vec![], quality);
152 self.engine.end_trajectory(builder, quality);
153
154 self.trajectories.push(TrajectoryExport {
156 embedding,
157 quality,
158 route,
159 context,
160 timestamp: now,
161 });
162
163 self.quality_samples.push(quality);
164 }
165
166 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
168 self.engine.apply_micro_lora(input, output);
169 }
170
171 pub fn trajectory_count(&self) -> usize {
173 self.trajectories.len()
174 }
175
176 pub fn avg_quality(&self) -> f32 {
178 if self.quality_samples.is_empty() {
179 0.0
180 } else {
181 self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
182 }
183 }
184
185 pub fn force_learn(&self) -> String {
187 self.engine.force_learn()
188 }
189
190 pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
192 self.process_trajectory(embedding.clone(), embedding, quality, None, vec![]);
193 }
194
195 pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
197 self.process_trajectory(
198 embedding.clone(),
199 embedding,
200 quality,
201 Some(route.to_string()),
202 vec![],
203 );
204 }
205
206 pub fn average_quality(&self) -> f32 {
208 self.avg_quality()
209 }
210
211 pub fn uptime_seconds(&self) -> u64 {
213 let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
214 (now - self.start_time) / 1000
215 }
216
217 pub fn stats(&self) -> AgentExportStats {
219 let engine_stats = self.engine.stats();
220 AgentExportStats {
221 total_trajectories: self.trajectories.len(),
222 avg_quality: self.avg_quality(),
223 patterns_learned: engine_stats.patterns_stored,
224 }
225 }
226
227 pub fn clear(&mut self) {
229 self.trajectories.clear();
230 self.quality_samples.clear();
231 }
232
233 pub fn get_patterns(&self) -> Vec<LearnedPattern> {
235 self.engine.find_patterns(&[], 0)
236 }
237
238 pub fn export_state(&self) -> AgentExport {
242 let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
243
244 self.engine.force_learn();
246
247 let stats = self.engine.stats();
248
249 AgentExport {
250 agent_id: self.agent_id.clone(),
251 trajectories: self.trajectories.clone(),
252 stats: AgentExportStats {
253 total_trajectories: self.trajectories.len(),
254 avg_quality: self.avg_quality(),
255 patterns_learned: stats.patterns_stored,
256 },
257 session_duration_ms: now - self.start_time,
258 timestamp: now,
259 }
260 }
261}
262
263#[derive(Clone, Debug, Serialize, Deserialize)]
265pub struct AgentContribution {
266 pub trajectory_count: usize,
268 pub avg_quality: f32,
270 pub timestamp: u64,
272 pub session_duration_ms: u64,
274}
275
276pub struct FederatedCoordinator {
280 coordinator_id: String,
282 master_engine: SonaEngine,
284 contributions: HashMap<String, AgentContribution>,
286 quality_threshold: f32,
288 total_trajectories: usize,
290 consolidation_interval: usize,
292 metrics: TrainingMetrics,
294}
295
296impl FederatedCoordinator {
297 pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
299 let id = coordinator_id.into();
300 Self {
301 coordinator_id: id.clone(),
302 master_engine: SonaEngine::with_config(config),
303 contributions: HashMap::new(),
304 quality_threshold: 0.4,
305 total_trajectories: 0,
306 consolidation_interval: 50,
307 metrics: TrainingMetrics::new(&id),
308 }
309 }
310
311 pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
313 Self::new(
314 coordinator_id,
315 SonaConfig {
316 hidden_dim,
317 embedding_dim: hidden_dim,
318 micro_lora_rank: 2,
319 base_lora_rank: 16, trajectory_capacity: 50000, pattern_clusters: 200,
322 ewc_lambda: 2000.0, ..Default::default()
324 },
325 )
326 }
327
328 pub fn coordinator_id(&self) -> &str {
330 &self.coordinator_id
331 }
332
333 pub fn set_quality_threshold(&mut self, threshold: f32) {
335 self.quality_threshold = threshold;
336 }
337
338 pub fn set_consolidation_interval(&mut self, interval: usize) {
340 self.consolidation_interval = interval;
341 }
342
343 pub fn master_engine(&self) -> &SonaEngine {
345 &self.master_engine
346 }
347
348 pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
350 let mut accepted = 0;
351 let mut rejected = 0;
352
353 for traj in &export.trajectories {
355 if traj.quality >= self.quality_threshold {
356 let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
357 if let Some(ref route) = traj.route {
358 builder.set_model_route(route);
359 }
360 for ctx in &traj.context {
361 builder.add_context(ctx);
362 }
363 self.master_engine.end_trajectory(builder, traj.quality);
364
365 self.metrics.add_quality_sample(traj.quality);
366 accepted += 1;
367 } else {
368 rejected += 1;
369 }
370 }
371
372 self.total_trajectories += accepted;
373
374 let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
376
377 self.contributions.insert(
378 export.agent_id.clone(),
379 AgentContribution {
380 trajectory_count: export.trajectories.len(),
381 avg_quality: export.stats.avg_quality,
382 timestamp: now,
383 session_duration_ms: export.session_duration_ms,
384 },
385 );
386
387 let consolidated = if self.should_consolidate() {
389 self.master_engine.force_learn();
390 true
391 } else {
392 false
393 };
394
395 AggregationResult {
396 agent_id: export.agent_id,
397 trajectories_accepted: accepted,
398 trajectories_rejected: rejected,
399 consolidated,
400 total_agents: self.contributions.len(),
401 total_trajectories: self.total_trajectories,
402 }
403 }
404
405 fn should_consolidate(&self) -> bool {
407 self.contributions.len() % self.consolidation_interval == 0
408 }
409
410 pub fn force_consolidate(&self) -> String {
412 self.master_engine.force_learn()
413 }
414
415 pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
419 self.master_engine
422 .find_patterns(&[], 0)
423 .into_iter()
424 .take(k)
425 .collect()
426 }
427
428 pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
430 self.master_engine.find_patterns(&[], 0)
431 }
432
433 pub fn stats(&self) -> CoordinatorStats {
435 let engine_stats = self.master_engine.stats();
436
437 CoordinatorStats {
438 coordinator_id: self.coordinator_id.clone(),
439 total_agents: self.contributions.len(),
440 total_trajectories: self.total_trajectories,
441 patterns_learned: engine_stats.patterns_stored,
442 avg_quality: self.metrics.avg_quality(),
443 quality_threshold: self.quality_threshold,
444 }
445 }
446
447 pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
449 &self.contributions
450 }
451
452 pub fn metrics(&self) -> &TrainingMetrics {
454 &self.metrics
455 }
456
457 pub fn agent_count(&self) -> usize {
459 self.contributions.len()
460 }
461
462 pub fn total_trajectories(&self) -> usize {
464 self.total_trajectories
465 }
466
467 pub fn find_patterns(&self, query: &[f32], k: usize) -> Vec<LearnedPattern> {
469 self.master_engine.find_patterns(query, k)
470 }
471
472 pub fn apply_lora(&self, input: &[f32]) -> Vec<f32> {
474 let mut output = vec![0.0; input.len()];
475 self.master_engine.apply_micro_lora(input, &mut output);
476 output
477 }
478
479 pub fn consolidate(&self) -> String {
481 self.force_consolidate()
482 }
483
484 pub fn clear(&mut self) {
486 self.contributions.clear();
487 self.total_trajectories = 0;
488 }
489}
490
491#[derive(Clone, Debug, Serialize, Deserialize)]
493pub struct AggregationResult {
494 pub agent_id: String,
496 pub trajectories_accepted: usize,
498 pub trajectories_rejected: usize,
500 pub consolidated: bool,
502 pub total_agents: usize,
504 pub total_trajectories: usize,
506}
507
508#[derive(Clone, Debug, Serialize, Deserialize)]
510pub struct CoordinatorStats {
511 pub coordinator_id: String,
513 pub total_agents: usize,
515 pub total_trajectories: usize,
517 pub patterns_learned: usize,
519 pub avg_quality: f32,
521 pub quality_threshold: f32,
523}
524
525impl std::fmt::Display for CoordinatorStats {
526 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527 write!(
528 f,
529 "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
530 self.coordinator_id,
531 self.total_agents,
532 self.total_trajectories,
533 self.patterns_learned,
534 self.avg_quality
535 )
536 }
537}
538
539#[derive(Clone, Debug, Default, Serialize, Deserialize)]
541pub enum FederatedTopology {
542 #[default]
544 Star,
545 Hierarchical {
547 regions: usize,
549 },
550 PeerToPeer,
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimension used for every embedding and activation vector below.
    const DIM: usize = 256;

    /// Builds a trajectory export whose embedding is `fill` repeated `DIM`
    /// times, with no context and a zero timestamp.
    fn trajectory(fill: f32, quality: f32, route: Option<&str>) -> TrajectoryExport {
        TrajectoryExport {
            embedding: vec![fill; DIM],
            quality,
            route: route.map(String::from),
            context: vec![],
            timestamp: 0,
        }
    }

    /// A new agent reports its id and starts with no trajectories.
    #[test]
    fn test_ephemeral_agent_creation() {
        let agent = EphemeralAgent::default_federated("agent-1", DIM);

        assert_eq!(agent.agent_id(), "agent-1");
        assert_eq!(agent.trajectory_count(), 0);
    }

    /// Processing one trajectory is reflected in the count and the average.
    #[test]
    fn test_trajectory_collection() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        let embedding = vec![0.1; DIM];
        let activations = vec![0.5; DIM];
        agent.process_trajectory(
            embedding,
            activations,
            0.8,
            Some("code".into()),
            vec!["file:main.rs".into()],
        );

        assert_eq!(agent.trajectory_count(), 1);
        assert!((agent.avg_quality() - 0.8).abs() < 0.01);
    }

    /// An export carries the agent id, all trajectories and their average.
    #[test]
    fn test_agent_export() {
        let mut agent = EphemeralAgent::default_federated("agent-1", DIM);

        for i in 0..5 {
            let step = i as f32;
            agent.process_trajectory(
                vec![step * 0.1; DIM],
                vec![0.5; DIM],
                0.7 + step * 0.05,
                None,
                vec![],
            );
        }

        let export = agent.export_state();
        assert_eq!(export.agent_id, "agent-1");
        assert_eq!(export.trajectories.len(), 5);
        assert!(export.stats.avg_quality > 0.7);
    }

    /// A new coordinator reports its id and empty totals.
    #[test]
    fn test_coordinator_creation() {
        let coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        assert_eq!(coord.coordinator_id(), "coord-1");

        let stats = coord.stats();
        assert_eq!(stats.total_agents, 0);
        assert_eq!(stats.total_trajectories, 0);
    }

    /// Trajectories below the threshold are rejected, the rest accepted.
    #[test]
    fn test_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        coord.set_quality_threshold(0.5);

        let above_threshold = trajectory(0.1, 0.8, Some("code"));
        let below_threshold = trajectory(0.2, 0.3, None);

        let export = AgentExport {
            agent_id: "agent-1".into(),
            trajectories: vec![above_threshold, below_threshold],
            stats: AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
            session_duration_ms: 1000,
            timestamp: 0,
        };

        let result = coord.aggregate(export);
        assert_eq!(result.trajectories_accepted, 1);
        assert_eq!(result.trajectories_rejected, 1);
        assert_eq!(result.total_agents, 1);
    }

    /// With an interval of 2, the second distinct agent triggers
    /// consolidation; totals cover all three agents.
    #[test]
    fn test_multi_agent_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", DIM);
        coord.set_consolidation_interval(2);

        for i in 0..3 {
            let export = AgentExport {
                agent_id: format!("agent-{}", i),
                trajectories: vec![trajectory(i as f32 * 0.1, 0.8, None)],
                stats: AgentExportStats::default(),
                session_duration_ms: 1000,
                timestamp: 0,
            };

            let result = coord.aggregate(export);
            if i == 1 {
                assert!(result.consolidated);
            }
        }

        let stats = coord.stats();
        assert_eq!(stats.total_agents, 3);
        assert_eq!(stats.total_trajectories, 3);
    }
}