1use super::metrics::TrainingMetrics;
6use super::templates::{AgentType, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::SystemTime;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14#[derive(Clone, Debug)]
16pub struct AgentHandle {
17 pub id: String,
19 pub agent_type: AgentType,
21 pub created_at: u64,
23}
24
25pub struct ManagedAgent {
27 pub handle: AgentHandle,
29 pub engine: SonaEngine,
31 pub metrics: TrainingMetrics,
33 pub purpose: String,
35 pub training_count: u64,
37 pub tags: Vec<String>,
39}
40
41impl ManagedAgent {
42 pub fn new(
44 id: impl Into<String>,
45 agent_type: AgentType,
46 config: SonaConfig,
47 purpose: impl Into<String>,
48 ) -> Self {
49 let now = SystemTime::now().duration_since_epoch().as_secs();
50
51 let id = id.into();
52 Self {
53 handle: AgentHandle {
54 id: id.clone(),
55 agent_type,
56 created_at: now,
57 },
58 engine: SonaEngine::with_config(config),
59 metrics: TrainingMetrics::new(&id),
60 purpose: purpose.into(),
61 training_count: 0,
62 tags: Vec::new(),
63 }
64 }
65
66 pub fn stats(&self) -> AgentStats {
68 AgentStats {
69 id: self.handle.id.clone(),
70 agent_type: self.handle.agent_type.clone(),
71 training_count: self.training_count,
72 patterns_learned: self.metrics.patterns_learned,
73 avg_quality: self.metrics.avg_quality(),
74 total_examples: self.metrics.total_examples,
75 }
76 }
77}
78
79#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct AgentStats {
82 pub id: String,
84 pub agent_type: AgentType,
86 pub training_count: u64,
88 pub patterns_learned: usize,
90 pub avg_quality: f32,
92 pub total_examples: usize,
94}
95
96pub struct AgentFactory {
98 base_config: SonaConfig,
100 agents: HashMap<String, ManagedAgent>,
102 default_hidden_dim: usize,
104}
105
106impl AgentFactory {
107 pub fn new(base_config: SonaConfig) -> Self {
109 let default_hidden_dim = base_config.hidden_dim;
110 Self {
111 base_config,
112 agents: HashMap::new(),
113 default_hidden_dim,
114 }
115 }
116
117 pub fn default() -> Self {
119 Self::new(SonaConfig::default())
120 }
121
122 pub fn with_hidden_dim(hidden_dim: usize) -> Self {
124 let mut config = SonaConfig::default();
125 config.hidden_dim = hidden_dim;
126 config.embedding_dim = hidden_dim;
127 Self::new(config)
128 }
129
130 pub fn create_from_template(
132 &mut self,
133 name: impl Into<String>,
134 template: &TrainingTemplate,
135 ) -> &ManagedAgent {
136 let name = name.into();
137 let agent = ManagedAgent::new(
138 name.clone(),
139 template.agent_type.clone(),
140 template.sona_config.clone(),
141 &template.name,
142 );
143 self.agents.insert(name.clone(), agent);
144 self.agents.get(&name).unwrap()
145 }
146
147 pub fn create_agent(
149 &mut self,
150 name: impl Into<String>,
151 agent_type: AgentType,
152 purpose: impl Into<String>,
153 ) -> &ManagedAgent {
154 let name = name.into();
155 let config = self.config_for_agent_type(&agent_type);
156 let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
157 agent.tags.push("custom".into());
158 self.agents.insert(name.clone(), agent);
159 self.agents.get(&name).unwrap()
160 }
161
162 pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
164 let template = TrainingTemplate::code_agent().with_hidden_dim(self.default_hidden_dim);
165 self.create_from_template(name, &template)
166 }
167
168 pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
170 let template = TrainingTemplate::chat_agent().with_hidden_dim(self.default_hidden_dim);
171 self.create_from_template(name, &template)
172 }
173
174 pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
176 let template = TrainingTemplate::rag_agent().with_hidden_dim(self.default_hidden_dim);
177 self.create_from_template(name, &template)
178 }
179
180 pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
182 let template = TrainingTemplate::task_planner().with_hidden_dim(self.default_hidden_dim);
183 self.create_from_template(name, &template)
184 }
185
186 pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
188 let template = TrainingTemplate::reasoning_agent().with_hidden_dim(self.default_hidden_dim);
189 self.create_from_template(name, &template)
190 }
191
192 pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
194 let template = TrainingTemplate::codebase_helper().with_hidden_dim(self.default_hidden_dim);
195 self.create_from_template(name, &template)
196 }
197
198 pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
200 self.agents.get(name)
201 }
202
203 pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
205 self.agents.get_mut(name)
206 }
207
208 pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
210 self.agents.remove(name)
211 }
212
213 pub fn list_agents(&self) -> Vec<AgentStats> {
215 self.agents.values().map(|a| a.stats()).collect()
216 }
217
218 pub fn agent_count(&self) -> usize {
220 self.agents.len()
221 }
222
223 pub fn train_agent<E>(
225 &mut self,
226 name: &str,
227 examples: impl Iterator<Item = E>,
228 ) -> Result<usize, String>
229 where
230 E: TrainingExample,
231 {
232 let agent = self
233 .agents
234 .get_mut(name)
235 .ok_or_else(|| format!("Agent '{}' not found", name))?;
236
237 let mut count = 0;
238 for example in examples {
239 let mut builder = agent.engine.begin_trajectory(example.embedding());
241
242 if let Some(route) = example.route() {
244 builder.set_model_route(&route);
245 }
246
247 for ctx in example.context() {
249 builder.add_context(&ctx);
250 }
251
252 builder.add_step(example.activations(), example.attention(), example.reward());
254
255 agent.engine.end_trajectory(builder, example.quality());
257
258 count += 1;
259 agent.metrics.total_examples += 1;
260 agent.metrics.add_quality_sample(example.quality());
261 }
262
263 agent.engine.force_learn();
265 agent.training_count += 1;
266 agent.metrics.training_sessions += 1;
267
268 Ok(count)
269 }
270
271 fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
273 let mut config = self.base_config.clone();
274
275 match agent_type {
276 AgentType::CodeAgent | AgentType::CodebaseHelper => {
277 config.base_lora_rank = 16;
278 config.pattern_clusters = 200;
279 config.quality_threshold = 0.2;
280 }
281 AgentType::ChatAgent => {
282 config.base_lora_rank = 8;
283 config.pattern_clusters = 50;
284 config.quality_threshold = 0.4;
285 }
286 AgentType::RagAgent => {
287 config.pattern_clusters = 200;
288 config.trajectory_capacity = 10000;
289 }
290 AgentType::TaskPlanner => {
291 config.base_lora_rank = 16;
292 config.ewc_lambda = 2000.0;
293 }
294 AgentType::ReasoningAgent => {
295 config.base_lora_rank = 16;
296 config.ewc_lambda = 3000.0;
297 config.pattern_clusters = 150;
298 }
299 AgentType::DomainExpert => {
300 config.quality_threshold = 0.1;
301 config.trajectory_capacity = 20000;
302 }
303 AgentType::DataAnalyst => {
304 config.base_lora_rank = 8;
305 config.pattern_clusters = 100;
306 }
307 AgentType::CreativeWriter => {
308 config.base_lora_rank = 8;
309 config.pattern_clusters = 50;
310 config.quality_threshold = 0.5;
311 }
312 _ => {}
313 }
314
315 config
316 }
317}
318
/// A single example that `AgentFactory::train_agent` turns into one trajectory.
///
/// Only `embedding` and `quality` are required. Every other method has a default
/// derived from those two, or an empty value.
pub trait TrainingExample {
    /// Input embedding used to open the trajectory.
    fn embedding(&self) -> Vec<f32>;

    /// Quality score used to close the trajectory and update the agent's metrics.
    fn quality(&self) -> f32;

    /// Activations recorded for the trajectory step; defaults to the embedding.
    fn activations(&self) -> Vec<f32> {
        self.embedding()
    }

    /// Attention weights recorded for the step; defaults to 64 uniform weights of 1/64.
    fn attention(&self) -> Vec<f32> {
        const WIDTH: usize = 64;
        let weight = 1.0 / WIDTH as f32;
        vec![weight; WIDTH]
    }

    /// Reward recorded for the step; defaults to `quality()`.
    fn reward(&self) -> f32 {
        self.quality()
    }

    /// Optional model route attached to the trajectory; none by default.
    fn route(&self) -> Option<String> {
        None
    }

    /// Extra context strings attached to the trajectory; empty by default.
    fn context(&self) -> Vec<String> {
        vec![]
    }
}
352
/// Plain-data training example: an embedding, a quality score and optional
/// route/context metadata, built with `new` plus the `with_*` helpers.
#[derive(Debug, Clone)]
pub struct SimpleExample {
    /// Input embedding.
    pub embedding: Vec<f32>,
    /// Quality score of the example.
    pub quality: f32,
    /// Optional model route.
    pub route: Option<String>,
    /// Context strings, in the order they were added.
    pub context: Vec<String>,
}

impl SimpleExample {
    /// Creates an example with no route and no context.
    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
        let route = None;
        let context = Vec::new();
        Self {
            embedding,
            quality,
            route,
            context,
        }
    }

    /// Sets the route, replacing any previously set one.
    pub fn with_route(self, route: impl Into<String>) -> Self {
        Self {
            route: Some(route.into()),
            ..self
        }
    }

    /// Appends one context string after any already present.
    pub fn with_context(self, ctx: impl Into<String>) -> Self {
        let mut context = self.context;
        context.push(ctx.into());
        Self { context, ..self }
    }
}
389
390impl TrainingExample for SimpleExample {
391 fn embedding(&self) -> Vec<f32> {
392 self.embedding.clone()
393 }
394
395 fn quality(&self) -> f32 {
396 self.quality
397 }
398
399 fn route(&self) -> Option<String> {
400 self.route.clone()
401 }
402
403 fn context(&self) -> Vec<String> {
404 self.context.clone()
405 }
406}
407
408pub struct SharedAgentFactory {
410 inner: Arc<RwLock<AgentFactory>>,
411}
412
413impl SharedAgentFactory {
414 pub fn new(config: SonaConfig) -> Self {
416 Self {
417 inner: Arc::new(RwLock::new(AgentFactory::new(config))),
418 }
419 }
420
421 pub fn read(&self) -> std::sync::RwLockReadGuard<AgentFactory> {
423 self.inner.read().unwrap()
424 }
425
426 pub fn write(&self) -> std::sync::RwLockWriteGuard<AgentFactory> {
428 self.inner.write().unwrap()
429 }
430
431 pub fn clone_arc(&self) -> Self {
433 Self {
434 inner: Arc::clone(&self.inner),
435 }
436 }
437}
438
439impl Clone for SharedAgentFactory {
440 fn clone(&self) -> Self {
441 self.clone_arc()
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creation() {
        let factory = AgentFactory::default();
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_create_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_code_agent("code-1");
        factory.create_chat_agent("chat-1");
        factory.create_rag_agent("rag-1");

        assert_eq!(factory.agent_count(), 3);
        assert!(factory.get_agent("code-1").is_some());
        assert!(factory.get_agent("unknown").is_none());
    }

    #[test]
    fn test_agent_from_template() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        factory.create_from_template("reasoner", &template);

        let agent = factory.get_agent("reasoner").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ReasoningAgent);
    }

    #[test]
    fn test_train_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let examples = vec![
            SimpleExample::new(vec![0.1; 256], 0.8).with_route("greeting"),
            SimpleExample::new(vec![0.2; 256], 0.9).with_route("question"),
            SimpleExample::new(vec![0.3; 256], 0.7).with_route("farewell"),
        ];

        let count = factory.train_agent("bot", examples.into_iter()).unwrap();
        assert_eq!(count, 3);

        let agent = factory.get_agent("bot").unwrap();
        assert_eq!(agent.training_count, 1);
        assert_eq!(agent.metrics.total_examples, 3);
    }

    #[test]
    fn test_train_unknown_agent_errors() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let result = factory.train_agent("ghost", Vec::<SimpleExample>::new().into_iter());
        assert_eq!(result, Err("Agent 'ghost' not found".to_string()));
    }

    #[test]
    fn test_create_custom_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let agent = factory.create_agent("helper", AgentType::ChatAgent, "answer questions");

        assert_eq!(agent.handle.id, "helper");
        assert_eq!(agent.handle.agent_type, AgentType::ChatAgent);
        assert_eq!(agent.purpose, "answer questions");
        assert_eq!(agent.tags, vec!["custom".to_string()]);
        assert_eq!(agent.training_count, 0);
    }

    #[test]
    fn test_recreate_replaces_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_agent("dup", AgentType::ChatAgent, "first");
        factory.create_agent("dup", AgentType::ChatAgent, "second");

        assert_eq!(factory.agent_count(), 1);
        assert_eq!(factory.get_agent("dup").unwrap().purpose, "second");
    }

    #[test]
    fn test_remove_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let removed = factory.remove_agent("bot").expect("agent was just created");
        assert_eq!(removed.handle.id, "bot");
        assert_eq!(factory.agent_count(), 0);
        assert!(factory.remove_agent("bot").is_none());
    }

    #[test]
    fn test_list_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");
        factory.create_chat_agent("chat");

        let agents = factory.list_agents();
        assert_eq!(agents.len(), 2);
    }
}