1use crate::engine::SonaEngine;
6use crate::types::SonaConfig;
7use super::templates::{TrainingTemplate, AgentType};
8use super::metrics::TrainingMetrics;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use serde::{Deserialize, Serialize};
12
13#[derive(Clone, Debug)]
15pub struct AgentHandle {
16 pub id: String,
18 pub agent_type: AgentType,
20 pub created_at: u64,
22}
23
24pub struct ManagedAgent {
26 pub handle: AgentHandle,
28 pub engine: SonaEngine,
30 pub metrics: TrainingMetrics,
32 pub purpose: String,
34 pub training_count: u64,
36 pub tags: Vec<String>,
38}
39
40impl ManagedAgent {
41 pub fn new(
43 id: impl Into<String>,
44 agent_type: AgentType,
45 config: SonaConfig,
46 purpose: impl Into<String>,
47 ) -> Self {
48 let now = std::time::SystemTime::now()
49 .duration_since(std::time::UNIX_EPOCH)
50 .unwrap_or_default()
51 .as_secs();
52
53 let id = id.into();
54 Self {
55 handle: AgentHandle {
56 id: id.clone(),
57 agent_type,
58 created_at: now,
59 },
60 engine: SonaEngine::with_config(config),
61 metrics: TrainingMetrics::new(&id),
62 purpose: purpose.into(),
63 training_count: 0,
64 tags: Vec::new(),
65 }
66 }
67
68 pub fn stats(&self) -> AgentStats {
70 AgentStats {
71 id: self.handle.id.clone(),
72 agent_type: self.handle.agent_type.clone(),
73 training_count: self.training_count,
74 patterns_learned: self.metrics.patterns_learned,
75 avg_quality: self.metrics.avg_quality(),
76 total_examples: self.metrics.total_examples,
77 }
78 }
79}
80
81#[derive(Clone, Debug, Serialize, Deserialize)]
83pub struct AgentStats {
84 pub id: String,
86 pub agent_type: AgentType,
88 pub training_count: u64,
90 pub patterns_learned: usize,
92 pub avg_quality: f32,
94 pub total_examples: usize,
96}
97
98pub struct AgentFactory {
100 base_config: SonaConfig,
102 agents: HashMap<String, ManagedAgent>,
104 default_hidden_dim: usize,
106}
107
108impl AgentFactory {
109 pub fn new(base_config: SonaConfig) -> Self {
111 let default_hidden_dim = base_config.hidden_dim;
112 Self {
113 base_config,
114 agents: HashMap::new(),
115 default_hidden_dim,
116 }
117 }
118
119 pub fn default() -> Self {
121 Self::new(SonaConfig::default())
122 }
123
124 pub fn with_hidden_dim(hidden_dim: usize) -> Self {
126 let mut config = SonaConfig::default();
127 config.hidden_dim = hidden_dim;
128 config.embedding_dim = hidden_dim;
129 Self::new(config)
130 }
131
132 pub fn create_from_template(&mut self, name: impl Into<String>, template: &TrainingTemplate) -> &ManagedAgent {
134 let name = name.into();
135 let agent = ManagedAgent::new(
136 name.clone(),
137 template.agent_type.clone(),
138 template.sona_config.clone(),
139 &template.name,
140 );
141 self.agents.insert(name.clone(), agent);
142 self.agents.get(&name).unwrap()
143 }
144
145 pub fn create_agent(
147 &mut self,
148 name: impl Into<String>,
149 agent_type: AgentType,
150 purpose: impl Into<String>,
151 ) -> &ManagedAgent {
152 let name = name.into();
153 let config = self.config_for_agent_type(&agent_type);
154 let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
155 agent.tags.push("custom".into());
156 self.agents.insert(name.clone(), agent);
157 self.agents.get(&name).unwrap()
158 }
159
160 pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
162 let template = TrainingTemplate::code_agent()
163 .with_hidden_dim(self.default_hidden_dim);
164 self.create_from_template(name, &template)
165 }
166
167 pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
169 let template = TrainingTemplate::chat_agent()
170 .with_hidden_dim(self.default_hidden_dim);
171 self.create_from_template(name, &template)
172 }
173
174 pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
176 let template = TrainingTemplate::rag_agent()
177 .with_hidden_dim(self.default_hidden_dim);
178 self.create_from_template(name, &template)
179 }
180
181 pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
183 let template = TrainingTemplate::task_planner()
184 .with_hidden_dim(self.default_hidden_dim);
185 self.create_from_template(name, &template)
186 }
187
188 pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
190 let template = TrainingTemplate::reasoning_agent()
191 .with_hidden_dim(self.default_hidden_dim);
192 self.create_from_template(name, &template)
193 }
194
195 pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
197 let template = TrainingTemplate::codebase_helper()
198 .with_hidden_dim(self.default_hidden_dim);
199 self.create_from_template(name, &template)
200 }
201
202 pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
204 self.agents.get(name)
205 }
206
207 pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
209 self.agents.get_mut(name)
210 }
211
212 pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
214 self.agents.remove(name)
215 }
216
217 pub fn list_agents(&self) -> Vec<AgentStats> {
219 self.agents.values().map(|a| a.stats()).collect()
220 }
221
222 pub fn agent_count(&self) -> usize {
224 self.agents.len()
225 }
226
227 pub fn train_agent<E>(&mut self, name: &str, examples: impl Iterator<Item = E>) -> Result<usize, String>
229 where
230 E: TrainingExample,
231 {
232 let agent = self.agents.get_mut(name)
233 .ok_or_else(|| format!("Agent '{}' not found", name))?;
234
235 let mut count = 0;
236 for example in examples {
237 let mut builder = agent.engine.begin_trajectory(example.embedding());
239
240 if let Some(route) = example.route() {
242 builder.set_model_route(&route);
243 }
244
245 for ctx in example.context() {
247 builder.add_context(&ctx);
248 }
249
250 builder.add_step(
252 example.activations(),
253 example.attention(),
254 example.reward(),
255 );
256
257 agent.engine.end_trajectory(builder, example.quality());
259
260 count += 1;
261 agent.metrics.total_examples += 1;
262 agent.metrics.add_quality_sample(example.quality());
263 }
264
265 agent.engine.force_learn();
267 agent.training_count += 1;
268 agent.metrics.training_sessions += 1;
269
270 Ok(count)
271 }
272
273 fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
275 let mut config = self.base_config.clone();
276
277 match agent_type {
278 AgentType::CodeAgent | AgentType::CodebaseHelper => {
279 config.base_lora_rank = 16;
280 config.pattern_clusters = 200;
281 config.quality_threshold = 0.2;
282 }
283 AgentType::ChatAgent => {
284 config.base_lora_rank = 8;
285 config.pattern_clusters = 50;
286 config.quality_threshold = 0.4;
287 }
288 AgentType::RagAgent => {
289 config.pattern_clusters = 200;
290 config.trajectory_capacity = 10000;
291 }
292 AgentType::TaskPlanner => {
293 config.base_lora_rank = 16;
294 config.ewc_lambda = 2000.0;
295 }
296 AgentType::ReasoningAgent => {
297 config.base_lora_rank = 16;
298 config.ewc_lambda = 3000.0;
299 config.pattern_clusters = 150;
300 }
301 AgentType::DomainExpert => {
302 config.quality_threshold = 0.1;
303 config.trajectory_capacity = 20000;
304 }
305 AgentType::DataAnalyst => {
306 config.base_lora_rank = 8;
307 config.pattern_clusters = 100;
308 }
309 AgentType::CreativeWriter => {
310 config.base_lora_rank = 8;
311 config.pattern_clusters = 50;
312 config.quality_threshold = 0.5;
313 }
314 _ => {}
315 }
316
317 config
318 }
319}
320
/// A single training sample that `AgentFactory::train_agent` can feed to an agent.
///
/// Implementors must provide `embedding` and `quality`. Every other method has a
/// default built from those two (or a constant).
pub trait TrainingExample {
    /// Embedding that starts the sample's trajectory.
    fn embedding(&self) -> Vec<f32>;

    /// Activations recorded for the trajectory step; defaults to `embedding()`.
    fn activations(&self) -> Vec<f32> {
        self.embedding()
    }

    /// Attention weights for the trajectory step; defaults to 64 equal weights of 1/64.
    fn attention(&self) -> Vec<f32> {
        let width: usize = 64;
        std::iter::repeat(1.0 / width as f32).take(width).collect()
    }

    /// Reward recorded for the trajectory step; defaults to `quality()`.
    fn reward(&self) -> f32 {
        self.quality()
    }

    /// Quality score used to close the trajectory and update metrics.
    fn quality(&self) -> f32;

    /// Optional model route attached to the trajectory; defaults to none.
    fn route(&self) -> Option<String> {
        Option::None
    }

    /// Context strings attached to the trajectory; defaults to none.
    fn context(&self) -> Vec<String> {
        vec![]
    }
}
354
/// Plain-data `TrainingExample` carrying an embedding, a quality score and optional
/// route/context metadata.
///
/// Derives `PartialEq` so examples can be compared directly, e.g. in tests or when
/// deduplicating a dataset.
#[derive(Clone, Debug, PartialEq)]
pub struct SimpleExample {
    /// Embedding fed to the trajectory (also used as activations by default).
    pub embedding: Vec<f32>,
    /// Quality score (also used as the step reward by default).
    pub quality: f32,
    /// Optional model route label.
    pub route: Option<String>,
    /// Context strings attached to the trajectory.
    pub context: Vec<String>,
}
367
368impl SimpleExample {
369 pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
371 Self {
372 embedding,
373 quality,
374 route: None,
375 context: Vec::new(),
376 }
377 }
378
379 pub fn with_route(mut self, route: impl Into<String>) -> Self {
381 self.route = Some(route.into());
382 self
383 }
384
385 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
387 self.context.push(ctx.into());
388 self
389 }
390}
391
392impl TrainingExample for SimpleExample {
393 fn embedding(&self) -> Vec<f32> {
394 self.embedding.clone()
395 }
396
397 fn quality(&self) -> f32 {
398 self.quality
399 }
400
401 fn route(&self) -> Option<String> {
402 self.route.clone()
403 }
404
405 fn context(&self) -> Vec<String> {
406 self.context.clone()
407 }
408}
409
410pub struct SharedAgentFactory {
412 inner: Arc<RwLock<AgentFactory>>,
413}
414
415impl SharedAgentFactory {
416 pub fn new(config: SonaConfig) -> Self {
418 Self {
419 inner: Arc::new(RwLock::new(AgentFactory::new(config))),
420 }
421 }
422
423 pub fn read(&self) -> std::sync::RwLockReadGuard<AgentFactory> {
425 self.inner.read().unwrap()
426 }
427
428 pub fn write(&self) -> std::sync::RwLockWriteGuard<AgentFactory> {
430 self.inner.write().unwrap()
431 }
432
433 pub fn clone_arc(&self) -> Self {
435 Self {
436 inner: Arc::clone(&self.inner),
437 }
438 }
439}
440
441impl Clone for SharedAgentFactory {
442 fn clone(&self) -> Self {
443 self.clone_arc()
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creation() {
        let factory = AgentFactory::default();
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_create_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_code_agent("code-1");
        factory.create_chat_agent("chat-1");
        factory.create_rag_agent("rag-1");

        assert_eq!(factory.agent_count(), 3);
        assert!(factory.get_agent("code-1").is_some());
        assert!(factory.get_agent("unknown").is_none());
    }

    #[test]
    fn test_agent_from_template() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        factory.create_from_template("reasoner", &template);

        let agent = factory.get_agent("reasoner").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ReasoningAgent);
    }

    #[test]
    fn test_train_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let examples = vec![
            SimpleExample::new(vec![0.1; 256], 0.8).with_route("greeting"),
            SimpleExample::new(vec![0.2; 256], 0.9).with_route("question"),
            SimpleExample::new(vec![0.3; 256], 0.7).with_route("farewell"),
        ];

        let count = factory.train_agent("bot", examples.into_iter()).unwrap();
        assert_eq!(count, 3);

        let agent = factory.get_agent("bot").unwrap();
        assert_eq!(agent.training_count, 1);
        assert_eq!(agent.metrics.total_examples, 3);
    }

    #[test]
    fn test_train_unknown_agent_errors() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        let err = factory
            .train_agent("ghost", Vec::<SimpleExample>::new().into_iter())
            .unwrap_err();

        assert_eq!(err, "Agent 'ghost' not found");
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_create_agent_tags_custom() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        let agent = factory.create_agent("analyst", AgentType::DataAnalyst, "analytics");

        assert_eq!(agent.handle.id, "analyst");
        assert_eq!(agent.purpose, "analytics");
        assert_eq!(agent.tags, vec!["custom".to_string()]);
        assert_eq!(agent.training_count, 0);
    }

    #[test]
    fn test_create_agent_replaces_existing_name() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_agent("dup", AgentType::ChatAgent, "first");
        factory.create_agent("dup", AgentType::CodeAgent, "second");

        assert_eq!(factory.agent_count(), 1);
        let agent = factory.get_agent("dup").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::CodeAgent);
        assert_eq!(agent.purpose, "second");
    }

    #[test]
    fn test_remove_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");

        assert!(factory.remove_agent("code").is_some());
        assert!(factory.remove_agent("code").is_none());
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_list_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");
        factory.create_chat_agent("chat");

        let agents = factory.list_agents();
        assert_eq!(agents.len(), 2);
    }

    #[test]
    fn test_shared_factory_clones_share_state() {
        let mut config = SonaConfig::default();
        config.hidden_dim = 256;
        config.embedding_dim = 256;

        let shared = SharedAgentFactory::new(config);
        let other = shared.clone();

        other.write().create_code_agent("code");

        assert_eq!(shared.read().agent_count(), 1);
        assert!(shared.read().get_agent("code").is_some());
    }
}