1use super::metrics::TrainingMetrics;
6use super::templates::{AgentType, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::SystemTime;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14#[derive(Clone, Debug)]
16pub struct AgentHandle {
17 pub id: String,
19 pub agent_type: AgentType,
21 pub created_at: u64,
23}
24
25pub struct ManagedAgent {
27 pub handle: AgentHandle,
29 pub engine: SonaEngine,
31 pub metrics: TrainingMetrics,
33 pub purpose: String,
35 pub training_count: u64,
37 pub tags: Vec<String>,
39}
40
41impl ManagedAgent {
42 pub fn new(
44 id: impl Into<String>,
45 agent_type: AgentType,
46 config: SonaConfig,
47 purpose: impl Into<String>,
48 ) -> Self {
49 let now = SystemTime::now().duration_since_epoch().as_secs();
50
51 let id = id.into();
52 Self {
53 handle: AgentHandle {
54 id: id.clone(),
55 agent_type,
56 created_at: now,
57 },
58 engine: SonaEngine::with_config(config),
59 metrics: TrainingMetrics::new(&id),
60 purpose: purpose.into(),
61 training_count: 0,
62 tags: Vec::new(),
63 }
64 }
65
66 pub fn stats(&self) -> AgentStats {
68 AgentStats {
69 id: self.handle.id.clone(),
70 agent_type: self.handle.agent_type.clone(),
71 training_count: self.training_count,
72 patterns_learned: self.metrics.patterns_learned,
73 avg_quality: self.metrics.avg_quality(),
74 total_examples: self.metrics.total_examples,
75 }
76 }
77}
78
79#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct AgentStats {
82 pub id: String,
84 pub agent_type: AgentType,
86 pub training_count: u64,
88 pub patterns_learned: usize,
90 pub avg_quality: f32,
92 pub total_examples: usize,
94}
95
96pub struct AgentFactory {
98 base_config: SonaConfig,
100 agents: HashMap<String, ManagedAgent>,
102 default_hidden_dim: usize,
104}
105
106impl Default for AgentFactory {
107 fn default() -> Self {
108 Self::new(SonaConfig::default())
109 }
110}
111
112impl AgentFactory {
113 pub fn new(base_config: SonaConfig) -> Self {
115 let default_hidden_dim = base_config.hidden_dim;
116 Self {
117 base_config,
118 agents: HashMap::new(),
119 default_hidden_dim,
120 }
121 }
122
123 pub fn with_hidden_dim(hidden_dim: usize) -> Self {
125 let config = SonaConfig {
126 hidden_dim,
127 embedding_dim: hidden_dim,
128 ..SonaConfig::default()
129 };
130 Self::new(config)
131 }
132
133 pub fn create_from_template(
135 &mut self,
136 name: impl Into<String>,
137 template: &TrainingTemplate,
138 ) -> &ManagedAgent {
139 let name = name.into();
140 let agent = ManagedAgent::new(
141 name.clone(),
142 template.agent_type.clone(),
143 template.sona_config.clone(),
144 &template.name,
145 );
146 self.agents.insert(name.clone(), agent);
147 self.agents.get(&name).unwrap()
148 }
149
150 pub fn create_agent(
152 &mut self,
153 name: impl Into<String>,
154 agent_type: AgentType,
155 purpose: impl Into<String>,
156 ) -> &ManagedAgent {
157 let name = name.into();
158 let config = self.config_for_agent_type(&agent_type);
159 let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
160 agent.tags.push("custom".into());
161 self.agents.insert(name.clone(), agent);
162 self.agents.get(&name).unwrap()
163 }
164
165 pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
167 let template = TrainingTemplate::code_agent().with_hidden_dim(self.default_hidden_dim);
168 self.create_from_template(name, &template)
169 }
170
171 pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
173 let template = TrainingTemplate::chat_agent().with_hidden_dim(self.default_hidden_dim);
174 self.create_from_template(name, &template)
175 }
176
177 pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
179 let template = TrainingTemplate::rag_agent().with_hidden_dim(self.default_hidden_dim);
180 self.create_from_template(name, &template)
181 }
182
183 pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
185 let template = TrainingTemplate::task_planner().with_hidden_dim(self.default_hidden_dim);
186 self.create_from_template(name, &template)
187 }
188
189 pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
191 let template = TrainingTemplate::reasoning_agent().with_hidden_dim(self.default_hidden_dim);
192 self.create_from_template(name, &template)
193 }
194
195 pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
197 let template = TrainingTemplate::codebase_helper().with_hidden_dim(self.default_hidden_dim);
198 self.create_from_template(name, &template)
199 }
200
201 pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
203 self.agents.get(name)
204 }
205
206 pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
208 self.agents.get_mut(name)
209 }
210
211 pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
213 self.agents.remove(name)
214 }
215
216 pub fn list_agents(&self) -> Vec<AgentStats> {
218 self.agents.values().map(|a| a.stats()).collect()
219 }
220
221 pub fn agent_count(&self) -> usize {
223 self.agents.len()
224 }
225
226 pub fn train_agent<E>(
228 &mut self,
229 name: &str,
230 examples: impl Iterator<Item = E>,
231 ) -> Result<usize, String>
232 where
233 E: TrainingExample,
234 {
235 let agent = self
236 .agents
237 .get_mut(name)
238 .ok_or_else(|| format!("Agent '{}' not found", name))?;
239
240 let mut count = 0;
241 for example in examples {
242 let mut builder = agent.engine.begin_trajectory(example.embedding());
244
245 if let Some(route) = example.route() {
247 builder.set_model_route(&route);
248 }
249
250 for ctx in example.context() {
252 builder.add_context(&ctx);
253 }
254
255 builder.add_step(example.activations(), example.attention(), example.reward());
257
258 agent.engine.end_trajectory(builder, example.quality());
260
261 count += 1;
262 agent.metrics.total_examples += 1;
263 agent.metrics.add_quality_sample(example.quality());
264 }
265
266 agent.engine.force_learn();
268 agent.training_count += 1;
269 agent.metrics.training_sessions += 1;
270
271 Ok(count)
272 }
273
274 fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
276 let mut config = self.base_config.clone();
277
278 match agent_type {
279 AgentType::CodeAgent | AgentType::CodebaseHelper => {
280 config.base_lora_rank = 16;
281 config.pattern_clusters = 200;
282 config.quality_threshold = 0.2;
283 }
284 AgentType::ChatAgent => {
285 config.base_lora_rank = 8;
286 config.pattern_clusters = 50;
287 config.quality_threshold = 0.4;
288 }
289 AgentType::RagAgent => {
290 config.pattern_clusters = 200;
291 config.trajectory_capacity = 10000;
292 }
293 AgentType::TaskPlanner => {
294 config.base_lora_rank = 16;
295 config.ewc_lambda = 2000.0;
296 }
297 AgentType::ReasoningAgent => {
298 config.base_lora_rank = 16;
299 config.ewc_lambda = 3000.0;
300 config.pattern_clusters = 150;
301 }
302 AgentType::DomainExpert => {
303 config.quality_threshold = 0.1;
304 config.trajectory_capacity = 20000;
305 }
306 AgentType::DataAnalyst => {
307 config.base_lora_rank = 8;
308 config.pattern_clusters = 100;
309 }
310 AgentType::CreativeWriter => {
311 config.base_lora_rank = 8;
312 config.pattern_clusters = 50;
313 config.quality_threshold = 0.5;
314 }
315 _ => {}
316 }
317
318 config
319 }
320}
321
/// A single sample that can be fed to `AgentFactory::train_agent`.
///
/// Implementors must provide [`embedding`](TrainingExample::embedding) and
/// [`quality`](TrainingExample::quality); every other method has a default.
pub trait TrainingExample {
    /// Embedding vector that starts the example's trajectory.
    fn embedding(&self) -> Vec<f32>;

    /// Quality score of the example.
    fn quality(&self) -> f32;

    /// Activations recorded for the example's step; defaults to the embedding.
    fn activations(&self) -> Vec<f32> {
        Self::embedding(self)
    }

    /// Attention weights recorded for the example's step; defaults to 64
    /// equal weights of `1/64` each.
    fn attention(&self) -> Vec<f32> {
        const SLOTS: usize = 64;
        const WEIGHT: f32 = 1.0 / 64.0;
        vec![WEIGHT; SLOTS]
    }

    /// Reward recorded for the example's step; defaults to the quality score.
    fn reward(&self) -> f32 {
        Self::quality(self)
    }

    /// Optional model route for the example; defaults to no route.
    fn route(&self) -> Option<String> {
        Option::None
    }

    /// Context identifiers attached to the example; defaults to none.
    fn context(&self) -> Vec<String> {
        vec![]
    }
}
355
/// Plain-data training example holding an embedding, a quality score and
/// optional routing/context information.
#[derive(Debug, Clone)]
pub struct SimpleExample {
    /// Embedding vector of the example.
    pub embedding: Vec<f32>,
    /// Quality score of the example.
    pub quality: f32,
    /// Optional model route associated with the example.
    pub route: Option<String>,
    /// Context identifiers attached to the example.
    pub context: Vec<String>,
}
368
369impl SimpleExample {
370 pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
372 Self {
373 embedding,
374 quality,
375 route: None,
376 context: Vec::new(),
377 }
378 }
379
380 pub fn with_route(mut self, route: impl Into<String>) -> Self {
382 self.route = Some(route.into());
383 self
384 }
385
386 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
388 self.context.push(ctx.into());
389 self
390 }
391}
392
393impl TrainingExample for SimpleExample {
394 fn embedding(&self) -> Vec<f32> {
395 self.embedding.clone()
396 }
397
398 fn quality(&self) -> f32 {
399 self.quality
400 }
401
402 fn route(&self) -> Option<String> {
403 self.route.clone()
404 }
405
406 fn context(&self) -> Vec<String> {
407 self.context.clone()
408 }
409}
410
411pub struct SharedAgentFactory {
413 inner: Arc<RwLock<AgentFactory>>,
414}
415
416impl SharedAgentFactory {
417 pub fn new(config: SonaConfig) -> Self {
419 Self {
420 inner: Arc::new(RwLock::new(AgentFactory::new(config))),
421 }
422 }
423
424 pub fn read(&self) -> std::sync::RwLockReadGuard<'_, AgentFactory> {
426 self.inner.read().unwrap()
427 }
428
429 pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, AgentFactory> {
431 self.inner.write().unwrap()
432 }
433
434 pub fn clone_arc(&self) -> Self {
436 Self {
437 inner: Arc::clone(&self.inner),
438 }
439 }
440}
441
442impl Clone for SharedAgentFactory {
443 fn clone(&self) -> Self {
444 self.clone_arc()
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// A default factory starts without any agents.
    #[test]
    fn test_factory_creation() {
        let empty = AgentFactory::default();
        assert_eq!(empty.agent_count(), 0);
    }

    /// Template-based constructors register agents under the given names.
    #[test]
    fn test_create_agents() {
        let mut registry = AgentFactory::with_hidden_dim(256);

        registry.create_code_agent("code-1");
        registry.create_chat_agent("chat-1");
        registry.create_rag_agent("rag-1");

        assert_eq!(registry.agent_count(), 3);
        assert!(registry.get_agent("code-1").is_some());
        assert!(registry.get_agent("unknown").is_none());
    }

    /// An agent created from a template inherits the template's agent type.
    #[test]
    fn test_agent_from_template() {
        let mut registry = AgentFactory::with_hidden_dim(256);
        let reasoning = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        registry.create_from_template("reasoner", &reasoning);

        let kind = &registry.get_agent("reasoner").unwrap().handle.agent_type;
        assert_eq!(*kind, AgentType::ReasoningAgent);
    }

    /// Training consumes every example and records exactly one session.
    #[test]
    fn test_train_agent() {
        let mut registry = AgentFactory::with_hidden_dim(256);
        registry.create_chat_agent("bot");

        // (embedding fill value, quality, route)
        let table = [
            (0.1, 0.8, "greeting"),
            (0.2, 0.9, "question"),
            (0.3, 0.7, "farewell"),
        ];
        let samples = table
            .iter()
            .map(|&(fill, quality, route)| {
                SimpleExample::new(vec![fill; 256], quality).with_route(route)
            });

        let consumed = registry.train_agent("bot", samples).unwrap();
        assert_eq!(consumed, 3);

        let bot = registry.get_agent("bot").unwrap();
        assert_eq!(bot.training_count, 1);
        assert_eq!(bot.metrics.total_examples, 3);
    }

    /// Listing yields one stats entry per registered agent.
    #[test]
    fn test_list_agents() {
        let mut registry = AgentFactory::with_hidden_dim(256);
        registry.create_code_agent("code");
        registry.create_chat_agent("chat");

        let listed = registry.list_agents();
        assert_eq!(listed.len(), 2);
    }
}