Skip to main content

ruvector_sona/training/
factory.rs

1//! Agent Factory for SONA
2//!
3//! Create and manage multiple specialized agents.
4
5use super::metrics::TrainingMetrics;
6use super::templates::{AgentType, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::SystemTime;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14/// Handle to a managed agent
15#[derive(Clone, Debug)]
16pub struct AgentHandle {
17    /// Agent identifier
18    pub id: String,
19    /// Agent type
20    pub agent_type: AgentType,
21    /// Creation timestamp
22    pub created_at: u64,
23}
24
25/// Managed agent with engine and metadata
26pub struct ManagedAgent {
27    /// Agent handle
28    pub handle: AgentHandle,
29    /// SONA engine
30    pub engine: SonaEngine,
31    /// Training metrics
32    pub metrics: TrainingMetrics,
33    /// Purpose/description
34    pub purpose: String,
35    /// Training count
36    pub training_count: u64,
37    /// Tags for organization
38    pub tags: Vec<String>,
39}
40
41impl ManagedAgent {
42    /// Create a new managed agent
43    pub fn new(
44        id: impl Into<String>,
45        agent_type: AgentType,
46        config: SonaConfig,
47        purpose: impl Into<String>,
48    ) -> Self {
49        let now = SystemTime::now().duration_since_epoch().as_secs();
50
51        let id = id.into();
52        Self {
53            handle: AgentHandle {
54                id: id.clone(),
55                agent_type,
56                created_at: now,
57            },
58            engine: SonaEngine::with_config(config),
59            metrics: TrainingMetrics::new(&id),
60            purpose: purpose.into(),
61            training_count: 0,
62            tags: Vec::new(),
63        }
64    }
65
66    /// Get agent stats
67    pub fn stats(&self) -> AgentStats {
68        AgentStats {
69            id: self.handle.id.clone(),
70            agent_type: self.handle.agent_type.clone(),
71            training_count: self.training_count,
72            patterns_learned: self.metrics.patterns_learned,
73            avg_quality: self.metrics.avg_quality(),
74            total_examples: self.metrics.total_examples,
75        }
76    }
77}
78
79/// Agent statistics
80#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct AgentStats {
82    /// Agent ID
83    pub id: String,
84    /// Agent type
85    pub agent_type: AgentType,
86    /// Number of training sessions
87    pub training_count: u64,
88    /// Patterns learned
89    pub patterns_learned: usize,
90    /// Average quality score
91    pub avg_quality: f32,
92    /// Total examples processed
93    pub total_examples: usize,
94}
95
96/// Factory for creating and managing agents
97pub struct AgentFactory {
98    /// Base configuration for all agents
99    base_config: SonaConfig,
100    /// Managed agents
101    agents: HashMap<String, ManagedAgent>,
102    /// Default hidden dimension
103    default_hidden_dim: usize,
104}
105
106impl AgentFactory {
107    /// Create a new agent factory
108    pub fn new(base_config: SonaConfig) -> Self {
109        let default_hidden_dim = base_config.hidden_dim;
110        Self {
111            base_config,
112            agents: HashMap::new(),
113            default_hidden_dim,
114        }
115    }
116
117    /// Create factory with default configuration
118    pub fn default() -> Self {
119        Self::new(SonaConfig::default())
120    }
121
122    /// Create factory with specific hidden dimension
123    pub fn with_hidden_dim(hidden_dim: usize) -> Self {
124        let mut config = SonaConfig::default();
125        config.hidden_dim = hidden_dim;
126        config.embedding_dim = hidden_dim;
127        Self::new(config)
128    }
129
130    /// Create an agent from a template
131    pub fn create_from_template(
132        &mut self,
133        name: impl Into<String>,
134        template: &TrainingTemplate,
135    ) -> &ManagedAgent {
136        let name = name.into();
137        let agent = ManagedAgent::new(
138            name.clone(),
139            template.agent_type.clone(),
140            template.sona_config.clone(),
141            &template.name,
142        );
143        self.agents.insert(name.clone(), agent);
144        self.agents.get(&name).unwrap()
145    }
146
147    /// Create an agent with custom configuration
148    pub fn create_agent(
149        &mut self,
150        name: impl Into<String>,
151        agent_type: AgentType,
152        purpose: impl Into<String>,
153    ) -> &ManagedAgent {
154        let name = name.into();
155        let config = self.config_for_agent_type(&agent_type);
156        let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
157        agent.tags.push("custom".into());
158        self.agents.insert(name.clone(), agent);
159        self.agents.get(&name).unwrap()
160    }
161
162    /// Create a code agent
163    pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
164        let template = TrainingTemplate::code_agent().with_hidden_dim(self.default_hidden_dim);
165        self.create_from_template(name, &template)
166    }
167
168    /// Create a chat agent
169    pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
170        let template = TrainingTemplate::chat_agent().with_hidden_dim(self.default_hidden_dim);
171        self.create_from_template(name, &template)
172    }
173
174    /// Create a RAG agent
175    pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
176        let template = TrainingTemplate::rag_agent().with_hidden_dim(self.default_hidden_dim);
177        self.create_from_template(name, &template)
178    }
179
180    /// Create a task planner agent
181    pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
182        let template = TrainingTemplate::task_planner().with_hidden_dim(self.default_hidden_dim);
183        self.create_from_template(name, &template)
184    }
185
186    /// Create a reasoning agent
187    pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
188        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(self.default_hidden_dim);
189        self.create_from_template(name, &template)
190    }
191
192    /// Create a codebase helper agent
193    pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
194        let template = TrainingTemplate::codebase_helper().with_hidden_dim(self.default_hidden_dim);
195        self.create_from_template(name, &template)
196    }
197
198    /// Get an agent by name
199    pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
200        self.agents.get(name)
201    }
202
203    /// Get a mutable agent by name
204    pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
205        self.agents.get_mut(name)
206    }
207
208    /// Remove an agent
209    pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
210        self.agents.remove(name)
211    }
212
213    /// List all agents
214    pub fn list_agents(&self) -> Vec<AgentStats> {
215        self.agents.values().map(|a| a.stats()).collect()
216    }
217
218    /// Get agent count
219    pub fn agent_count(&self) -> usize {
220        self.agents.len()
221    }
222
223    /// Train an agent with examples
224    pub fn train_agent<E>(
225        &mut self,
226        name: &str,
227        examples: impl Iterator<Item = E>,
228    ) -> Result<usize, String>
229    where
230        E: TrainingExample,
231    {
232        let agent = self
233            .agents
234            .get_mut(name)
235            .ok_or_else(|| format!("Agent '{}' not found", name))?;
236
237        let mut count = 0;
238        for example in examples {
239            // Use builder-based trajectory API
240            let mut builder = agent.engine.begin_trajectory(example.embedding());
241
242            // Set route if available
243            if let Some(route) = example.route() {
244                builder.set_model_route(&route);
245            }
246
247            // Add context if available
248            for ctx in example.context() {
249                builder.add_context(&ctx);
250            }
251
252            // Add step with activations
253            builder.add_step(example.activations(), example.attention(), example.reward());
254
255            // End trajectory with quality
256            agent.engine.end_trajectory(builder, example.quality());
257
258            count += 1;
259            agent.metrics.total_examples += 1;
260            agent.metrics.add_quality_sample(example.quality());
261        }
262
263        // Force learning after batch
264        agent.engine.force_learn();
265        agent.training_count += 1;
266        agent.metrics.training_sessions += 1;
267
268        Ok(count)
269    }
270
271    /// Get configuration for agent type
272    fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
273        let mut config = self.base_config.clone();
274
275        match agent_type {
276            AgentType::CodeAgent | AgentType::CodebaseHelper => {
277                config.base_lora_rank = 16;
278                config.pattern_clusters = 200;
279                config.quality_threshold = 0.2;
280            }
281            AgentType::ChatAgent => {
282                config.base_lora_rank = 8;
283                config.pattern_clusters = 50;
284                config.quality_threshold = 0.4;
285            }
286            AgentType::RagAgent => {
287                config.pattern_clusters = 200;
288                config.trajectory_capacity = 10000;
289            }
290            AgentType::TaskPlanner => {
291                config.base_lora_rank = 16;
292                config.ewc_lambda = 2000.0;
293            }
294            AgentType::ReasoningAgent => {
295                config.base_lora_rank = 16;
296                config.ewc_lambda = 3000.0;
297                config.pattern_clusters = 150;
298            }
299            AgentType::DomainExpert => {
300                config.quality_threshold = 0.1;
301                config.trajectory_capacity = 20000;
302            }
303            AgentType::DataAnalyst => {
304                config.base_lora_rank = 8;
305                config.pattern_clusters = 100;
306            }
307            AgentType::CreativeWriter => {
308                config.base_lora_rank = 8;
309                config.pattern_clusters = 50;
310                config.quality_threshold = 0.5;
311            }
312            _ => {}
313        }
314
315        config
316    }
317}
318
/// A single unit of training data fed to an agent's engine.
///
/// Only `embedding` and `quality` must be implemented. Every other accessor
/// has a default derived from those two or a fixed fallback.
pub trait TrainingExample {
    /// Embedding vector that starts the trajectory.
    fn embedding(&self) -> Vec<f32>;

    /// Quality score for this example.
    fn quality(&self) -> f32;

    /// Activations recorded for the step; defaults to the embedding.
    fn activations(&self) -> Vec<f32> {
        self.embedding()
    }

    /// Attention weights for the step; defaults to 64 uniform weights of 1/64.
    fn attention(&self) -> Vec<f32> {
        const UNIFORM_WIDTH: usize = 64;
        vec![1.0 / UNIFORM_WIDTH as f32; UNIFORM_WIDTH]
    }

    /// Reward for the step; defaults to the quality score.
    fn reward(&self) -> f32 {
        self.quality()
    }

    /// Optional model route; absent by default.
    fn route(&self) -> Option<String> {
        None
    }

    /// Context identifiers attached to the trajectory; empty by default.
    fn context(&self) -> Vec<String> {
        vec![]
    }
}
352
353/// Simple training example implementation
354#[derive(Clone, Debug)]
355pub struct SimpleExample {
356    /// Embedding vector
357    pub embedding: Vec<f32>,
358    /// Quality score
359    pub quality: f32,
360    /// Optional route
361    pub route: Option<String>,
362    /// Context IDs
363    pub context: Vec<String>,
364}
365
366impl SimpleExample {
367    /// Create a new simple example
368    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
369        Self {
370            embedding,
371            quality,
372            route: None,
373            context: Vec::new(),
374        }
375    }
376
377    /// Set route
378    pub fn with_route(mut self, route: impl Into<String>) -> Self {
379        self.route = Some(route.into());
380        self
381    }
382
383    /// Add context
384    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
385        self.context.push(ctx.into());
386        self
387    }
388}
389
390impl TrainingExample for SimpleExample {
391    fn embedding(&self) -> Vec<f32> {
392        self.embedding.clone()
393    }
394
395    fn quality(&self) -> f32 {
396        self.quality
397    }
398
399    fn route(&self) -> Option<String> {
400        self.route.clone()
401    }
402
403    fn context(&self) -> Vec<String> {
404        self.context.clone()
405    }
406}
407
408/// Thread-safe agent factory wrapper
409pub struct SharedAgentFactory {
410    inner: Arc<RwLock<AgentFactory>>,
411}
412
413impl SharedAgentFactory {
414    /// Create a new shared factory
415    pub fn new(config: SonaConfig) -> Self {
416        Self {
417            inner: Arc::new(RwLock::new(AgentFactory::new(config))),
418        }
419    }
420
421    /// Get read access to factory
422    pub fn read(&self) -> std::sync::RwLockReadGuard<AgentFactory> {
423        self.inner.read().unwrap()
424    }
425
426    /// Get write access to factory
427    pub fn write(&self) -> std::sync::RwLockWriteGuard<AgentFactory> {
428        self.inner.write().unwrap()
429    }
430
431    /// Clone the Arc
432    pub fn clone_arc(&self) -> Self {
433        Self {
434            inner: Arc::clone(&self.inner),
435        }
436    }
437}
438
439impl Clone for SharedAgentFactory {
440    fn clone(&self) -> Self {
441        self.clone_arc()
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creation() {
        let factory = AgentFactory::default();
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_create_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_code_agent("code-1");
        factory.create_chat_agent("chat-1");
        factory.create_rag_agent("rag-1");

        assert_eq!(factory.agent_count(), 3);
        assert!(factory.get_agent("code-1").is_some());
        assert!(factory.get_agent("unknown").is_none());
    }

    #[test]
    fn test_agent_from_template() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        factory.create_from_template("reasoner", &template);

        let agent = factory.get_agent("reasoner").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ReasoningAgent);
    }

    #[test]
    fn test_train_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let examples = vec![
            SimpleExample::new(vec![0.1; 256], 0.8).with_route("greeting"),
            SimpleExample::new(vec![0.2; 256], 0.9).with_route("question"),
            SimpleExample::new(vec![0.3; 256], 0.7).with_route("farewell"),
        ];

        let count = factory.train_agent("bot", examples.into_iter()).unwrap();
        assert_eq!(count, 3);

        let agent = factory.get_agent("bot").unwrap();
        assert_eq!(agent.training_count, 1);
        assert_eq!(agent.metrics.total_examples, 3);
    }

    #[test]
    fn test_train_unknown_agent_errors() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let err = factory
            .train_agent("ghost", Vec::<SimpleExample>::new().into_iter())
            .unwrap_err();
        assert_eq!(err, "Agent 'ghost' not found");
    }

    #[test]
    fn test_create_custom_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let agent = factory.create_agent("helper", AgentType::ChatAgent, "support");

        assert_eq!(agent.tags, vec!["custom".to_string()]);
        assert_eq!(agent.purpose, "support");
        assert_eq!(agent.handle.agent_type, AgentType::ChatAgent);
        assert_eq!(agent.training_count, 0);
    }

    #[test]
    fn test_remove_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");

        assert!(factory.remove_agent("code").is_some());
        assert_eq!(factory.agent_count(), 0);
        assert!(factory.remove_agent("code").is_none());
    }

    #[test]
    fn test_recreating_agent_replaces_it() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");
        factory
            .train_agent(
                "bot",
                vec![SimpleExample::new(vec![0.1; 256], 0.8)].into_iter(),
            )
            .unwrap();

        // Re-creating under the same name replaces the trained agent.
        factory.create_chat_agent("bot");
        assert_eq!(factory.agent_count(), 1);
        assert_eq!(factory.get_agent("bot").unwrap().training_count, 0);
    }

    #[test]
    fn test_list_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");
        factory.create_chat_agent("chat");

        let agents = factory.list_agents();
        assert_eq!(agents.len(), 2);
    }
}