// ruvector_sona/training/factory.rs

//! Agent factory for SONA.
//!
//! Creates, trains, and manages multiple specialized agents, each backed by
//! its own SONA engine.
4
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use serde::{Deserialize, Serialize};

use super::metrics::TrainingMetrics;
use super::templates::{AgentType, TrainingTemplate};
use crate::engine::SonaEngine;
use crate::types::SonaConfig;
12
13/// Handle to a managed agent
14#[derive(Clone, Debug)]
15pub struct AgentHandle {
16    /// Agent identifier
17    pub id: String,
18    /// Agent type
19    pub agent_type: AgentType,
20    /// Creation timestamp
21    pub created_at: u64,
22}
23
24/// Managed agent with engine and metadata
25pub struct ManagedAgent {
26    /// Agent handle
27    pub handle: AgentHandle,
28    /// SONA engine
29    pub engine: SonaEngine,
30    /// Training metrics
31    pub metrics: TrainingMetrics,
32    /// Purpose/description
33    pub purpose: String,
34    /// Training count
35    pub training_count: u64,
36    /// Tags for organization
37    pub tags: Vec<String>,
38}
39
40impl ManagedAgent {
41    /// Create a new managed agent
42    pub fn new(
43        id: impl Into<String>,
44        agent_type: AgentType,
45        config: SonaConfig,
46        purpose: impl Into<String>,
47    ) -> Self {
48        let now = std::time::SystemTime::now()
49            .duration_since(std::time::UNIX_EPOCH)
50            .unwrap_or_default()
51            .as_secs();
52
53        let id = id.into();
54        Self {
55            handle: AgentHandle {
56                id: id.clone(),
57                agent_type,
58                created_at: now,
59            },
60            engine: SonaEngine::with_config(config),
61            metrics: TrainingMetrics::new(&id),
62            purpose: purpose.into(),
63            training_count: 0,
64            tags: Vec::new(),
65        }
66    }
67
68    /// Get agent stats
69    pub fn stats(&self) -> AgentStats {
70        AgentStats {
71            id: self.handle.id.clone(),
72            agent_type: self.handle.agent_type.clone(),
73            training_count: self.training_count,
74            patterns_learned: self.metrics.patterns_learned,
75            avg_quality: self.metrics.avg_quality(),
76            total_examples: self.metrics.total_examples,
77        }
78    }
79}
80
81/// Agent statistics
82#[derive(Clone, Debug, Serialize, Deserialize)]
83pub struct AgentStats {
84    /// Agent ID
85    pub id: String,
86    /// Agent type
87    pub agent_type: AgentType,
88    /// Number of training sessions
89    pub training_count: u64,
90    /// Patterns learned
91    pub patterns_learned: usize,
92    /// Average quality score
93    pub avg_quality: f32,
94    /// Total examples processed
95    pub total_examples: usize,
96}
97
98/// Factory for creating and managing agents
99pub struct AgentFactory {
100    /// Base configuration for all agents
101    base_config: SonaConfig,
102    /// Managed agents
103    agents: HashMap<String, ManagedAgent>,
104    /// Default hidden dimension
105    default_hidden_dim: usize,
106}
107
108impl AgentFactory {
109    /// Create a new agent factory
110    pub fn new(base_config: SonaConfig) -> Self {
111        let default_hidden_dim = base_config.hidden_dim;
112        Self {
113            base_config,
114            agents: HashMap::new(),
115            default_hidden_dim,
116        }
117    }
118
119    /// Create factory with default configuration
120    pub fn default() -> Self {
121        Self::new(SonaConfig::default())
122    }
123
124    /// Create factory with specific hidden dimension
125    pub fn with_hidden_dim(hidden_dim: usize) -> Self {
126        let mut config = SonaConfig::default();
127        config.hidden_dim = hidden_dim;
128        config.embedding_dim = hidden_dim;
129        Self::new(config)
130    }
131
132    /// Create an agent from a template
133    pub fn create_from_template(&mut self, name: impl Into<String>, template: &TrainingTemplate) -> &ManagedAgent {
134        let name = name.into();
135        let agent = ManagedAgent::new(
136            name.clone(),
137            template.agent_type.clone(),
138            template.sona_config.clone(),
139            &template.name,
140        );
141        self.agents.insert(name.clone(), agent);
142        self.agents.get(&name).unwrap()
143    }
144
145    /// Create an agent with custom configuration
146    pub fn create_agent(
147        &mut self,
148        name: impl Into<String>,
149        agent_type: AgentType,
150        purpose: impl Into<String>,
151    ) -> &ManagedAgent {
152        let name = name.into();
153        let config = self.config_for_agent_type(&agent_type);
154        let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
155        agent.tags.push("custom".into());
156        self.agents.insert(name.clone(), agent);
157        self.agents.get(&name).unwrap()
158    }
159
160    /// Create a code agent
161    pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
162        let template = TrainingTemplate::code_agent()
163            .with_hidden_dim(self.default_hidden_dim);
164        self.create_from_template(name, &template)
165    }
166
167    /// Create a chat agent
168    pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
169        let template = TrainingTemplate::chat_agent()
170            .with_hidden_dim(self.default_hidden_dim);
171        self.create_from_template(name, &template)
172    }
173
174    /// Create a RAG agent
175    pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
176        let template = TrainingTemplate::rag_agent()
177            .with_hidden_dim(self.default_hidden_dim);
178        self.create_from_template(name, &template)
179    }
180
181    /// Create a task planner agent
182    pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
183        let template = TrainingTemplate::task_planner()
184            .with_hidden_dim(self.default_hidden_dim);
185        self.create_from_template(name, &template)
186    }
187
188    /// Create a reasoning agent
189    pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
190        let template = TrainingTemplate::reasoning_agent()
191            .with_hidden_dim(self.default_hidden_dim);
192        self.create_from_template(name, &template)
193    }
194
195    /// Create a codebase helper agent
196    pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
197        let template = TrainingTemplate::codebase_helper()
198            .with_hidden_dim(self.default_hidden_dim);
199        self.create_from_template(name, &template)
200    }
201
202    /// Get an agent by name
203    pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
204        self.agents.get(name)
205    }
206
207    /// Get a mutable agent by name
208    pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
209        self.agents.get_mut(name)
210    }
211
212    /// Remove an agent
213    pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
214        self.agents.remove(name)
215    }
216
217    /// List all agents
218    pub fn list_agents(&self) -> Vec<AgentStats> {
219        self.agents.values().map(|a| a.stats()).collect()
220    }
221
222    /// Get agent count
223    pub fn agent_count(&self) -> usize {
224        self.agents.len()
225    }
226
227    /// Train an agent with examples
228    pub fn train_agent<E>(&mut self, name: &str, examples: impl Iterator<Item = E>) -> Result<usize, String>
229    where
230        E: TrainingExample,
231    {
232        let agent = self.agents.get_mut(name)
233            .ok_or_else(|| format!("Agent '{}' not found", name))?;
234
235        let mut count = 0;
236        for example in examples {
237            // Use builder-based trajectory API
238            let mut builder = agent.engine.begin_trajectory(example.embedding());
239
240            // Set route if available
241            if let Some(route) = example.route() {
242                builder.set_model_route(&route);
243            }
244
245            // Add context if available
246            for ctx in example.context() {
247                builder.add_context(&ctx);
248            }
249
250            // Add step with activations
251            builder.add_step(
252                example.activations(),
253                example.attention(),
254                example.reward(),
255            );
256
257            // End trajectory with quality
258            agent.engine.end_trajectory(builder, example.quality());
259
260            count += 1;
261            agent.metrics.total_examples += 1;
262            agent.metrics.add_quality_sample(example.quality());
263        }
264
265        // Force learning after batch
266        agent.engine.force_learn();
267        agent.training_count += 1;
268        agent.metrics.training_sessions += 1;
269
270        Ok(count)
271    }
272
273    /// Get configuration for agent type
274    fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
275        let mut config = self.base_config.clone();
276
277        match agent_type {
278            AgentType::CodeAgent | AgentType::CodebaseHelper => {
279                config.base_lora_rank = 16;
280                config.pattern_clusters = 200;
281                config.quality_threshold = 0.2;
282            }
283            AgentType::ChatAgent => {
284                config.base_lora_rank = 8;
285                config.pattern_clusters = 50;
286                config.quality_threshold = 0.4;
287            }
288            AgentType::RagAgent => {
289                config.pattern_clusters = 200;
290                config.trajectory_capacity = 10000;
291            }
292            AgentType::TaskPlanner => {
293                config.base_lora_rank = 16;
294                config.ewc_lambda = 2000.0;
295            }
296            AgentType::ReasoningAgent => {
297                config.base_lora_rank = 16;
298                config.ewc_lambda = 3000.0;
299                config.pattern_clusters = 150;
300            }
301            AgentType::DomainExpert => {
302                config.quality_threshold = 0.1;
303                config.trajectory_capacity = 20000;
304            }
305            AgentType::DataAnalyst => {
306                config.base_lora_rank = 8;
307                config.pattern_clusters = 100;
308            }
309            AgentType::CreativeWriter => {
310                config.base_lora_rank = 8;
311                config.pattern_clusters = 50;
312                config.quality_threshold = 0.5;
313            }
314            _ => {}
315        }
316
317        config
318    }
319}
320
/// A single training example that can be turned into a one-step trajectory.
///
/// Only [`embedding`](Self::embedding) and [`quality`](Self::quality) are
/// required; every other method has a default derived from those two or a
/// constant.
pub trait TrainingExample {
    /// Embedding vector that starts the trajectory.
    fn embedding(&self) -> Vec<f32>;

    /// Step activations; by default a copy of the embedding.
    fn activations(&self) -> Vec<f32> {
        Self::embedding(self)
    }

    /// Step attention weights; by default 64 equal weights of 1/64 each.
    fn attention(&self) -> Vec<f32> {
        const LEN: usize = 64;
        vec![1.0 / LEN as f32; LEN]
    }

    /// Step reward; by default the same value as the quality score.
    fn reward(&self) -> f32 {
        Self::quality(self)
    }

    /// Quality score that closes the trajectory.
    fn quality(&self) -> f32;

    /// Optional model route; none by default.
    fn route(&self) -> Option<String> {
        Option::None
    }

    /// Context identifiers attached to the trajectory; empty by default.
    fn context(&self) -> Vec<String> {
        vec![]
    }
}
354
/// Plain-data [`TrainingExample`] carrying an embedding, a quality score, and
/// optional routing/context metadata.
#[derive(Clone, Debug)]
pub struct SimpleExample {
    /// Embedding vector (also used as activations by default).
    pub embedding: Vec<f32>,
    /// Quality score (also used as reward by default).
    pub quality: f32,
    /// Model route, if any.
    pub route: Option<String>,
    /// Context identifiers, in insertion order.
    pub context: Vec<String>,
}
367
368impl SimpleExample {
369    /// Create a new simple example
370    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
371        Self {
372            embedding,
373            quality,
374            route: None,
375            context: Vec::new(),
376        }
377    }
378
379    /// Set route
380    pub fn with_route(mut self, route: impl Into<String>) -> Self {
381        self.route = Some(route.into());
382        self
383    }
384
385    /// Add context
386    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
387        self.context.push(ctx.into());
388        self
389    }
390}
391
392impl TrainingExample for SimpleExample {
393    fn embedding(&self) -> Vec<f32> {
394        self.embedding.clone()
395    }
396
397    fn quality(&self) -> f32 {
398        self.quality
399    }
400
401    fn route(&self) -> Option<String> {
402        self.route.clone()
403    }
404
405    fn context(&self) -> Vec<String> {
406        self.context.clone()
407    }
408}
409
410/// Thread-safe agent factory wrapper
411pub struct SharedAgentFactory {
412    inner: Arc<RwLock<AgentFactory>>,
413}
414
415impl SharedAgentFactory {
416    /// Create a new shared factory
417    pub fn new(config: SonaConfig) -> Self {
418        Self {
419            inner: Arc::new(RwLock::new(AgentFactory::new(config))),
420        }
421    }
422
423    /// Get read access to factory
424    pub fn read(&self) -> std::sync::RwLockReadGuard<AgentFactory> {
425        self.inner.read().unwrap()
426    }
427
428    /// Get write access to factory
429    pub fn write(&self) -> std::sync::RwLockWriteGuard<AgentFactory> {
430        self.inner.write().unwrap()
431    }
432
433    /// Clone the Arc
434    pub fn clone_arc(&self) -> Self {
435        Self {
436            inner: Arc::clone(&self.inner),
437        }
438    }
439}
440
441impl Clone for SharedAgentFactory {
442    fn clone(&self) -> Self {
443        self.clone_arc()
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Base config equivalent to what `AgentFactory::with_hidden_dim(256)` uses.
    fn config_256() -> SonaConfig {
        let mut config = SonaConfig::default();
        config.hidden_dim = 256;
        config.embedding_dim = 256;
        config
    }

    #[test]
    fn test_factory_creation() {
        let factory = AgentFactory::default();
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_create_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_code_agent("code-1");
        factory.create_chat_agent("chat-1");
        factory.create_rag_agent("rag-1");

        assert_eq!(factory.agent_count(), 3);
        assert!(factory.get_agent("code-1").is_some());
        assert!(factory.get_agent("unknown").is_none());
    }

    #[test]
    fn test_agent_from_template() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        factory.create_from_template("reasoner", &template);

        let agent = factory.get_agent("reasoner").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ReasoningAgent);
    }

    #[test]
    fn test_train_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let examples = vec![
            SimpleExample::new(vec![0.1; 256], 0.8).with_route("greeting"),
            SimpleExample::new(vec![0.2; 256], 0.9).with_route("question"),
            SimpleExample::new(vec![0.3; 256], 0.7).with_route("farewell"),
        ];

        let count = factory.train_agent("bot", examples.into_iter()).unwrap();
        assert_eq!(count, 3);

        let agent = factory.get_agent("bot").unwrap();
        assert_eq!(agent.training_count, 1);
        assert_eq!(agent.metrics.total_examples, 3);
    }

    #[test]
    fn test_list_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");
        factory.create_chat_agent("chat");

        let agents = factory.list_agents();
        assert_eq!(agents.len(), 2);
    }

    #[test]
    fn test_train_unknown_agent_fails() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let result = factory.train_agent("ghost", Vec::<SimpleExample>::new().into_iter());
        assert_eq!(result, Err("Agent 'ghost' not found".to_string()));
    }

    #[test]
    fn test_create_agent_replaces_existing_name() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_agent("dup", AgentType::ChatAgent, "first");
        factory.create_agent("dup", AgentType::CodeAgent, "second");

        assert_eq!(factory.agent_count(), 1);
        let agent = factory.get_agent("dup").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::CodeAgent);
        assert_eq!(agent.purpose, "second");
        assert_eq!(agent.tags, vec!["custom".to_string()]);
    }

    #[test]
    fn test_remove_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_rag_agent("rag");

        assert!(factory.remove_agent("rag").is_some());
        assert!(factory.remove_agent("rag").is_none());
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_shared_factory_clones_share_state() {
        let shared = SharedAgentFactory::new(config_256());
        let other = shared.clone();

        shared.write().create_chat_agent("bot");

        assert_eq!(other.read().agent_count(), 1);
        assert!(other.read().get_agent("bot").is_some());
    }
}