Skip to main content

ruvector_sona/training/
factory.rs

1//! Agent Factory for SONA
2//!
3//! Create and manage multiple specialized agents.
4
5use super::metrics::TrainingMetrics;
6use super::templates::{AgentType, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::SystemTime;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14/// Handle to a managed agent
15#[derive(Clone, Debug)]
16pub struct AgentHandle {
17    /// Agent identifier
18    pub id: String,
19    /// Agent type
20    pub agent_type: AgentType,
21    /// Creation timestamp
22    pub created_at: u64,
23}
24
25/// Managed agent with engine and metadata
26pub struct ManagedAgent {
27    /// Agent handle
28    pub handle: AgentHandle,
29    /// SONA engine
30    pub engine: SonaEngine,
31    /// Training metrics
32    pub metrics: TrainingMetrics,
33    /// Purpose/description
34    pub purpose: String,
35    /// Training count
36    pub training_count: u64,
37    /// Tags for organization
38    pub tags: Vec<String>,
39}
40
41impl ManagedAgent {
42    /// Create a new managed agent
43    pub fn new(
44        id: impl Into<String>,
45        agent_type: AgentType,
46        config: SonaConfig,
47        purpose: impl Into<String>,
48    ) -> Self {
49        let now = SystemTime::now().duration_since_epoch().as_secs();
50
51        let id = id.into();
52        Self {
53            handle: AgentHandle {
54                id: id.clone(),
55                agent_type,
56                created_at: now,
57            },
58            engine: SonaEngine::with_config(config),
59            metrics: TrainingMetrics::new(&id),
60            purpose: purpose.into(),
61            training_count: 0,
62            tags: Vec::new(),
63        }
64    }
65
66    /// Get agent stats
67    pub fn stats(&self) -> AgentStats {
68        AgentStats {
69            id: self.handle.id.clone(),
70            agent_type: self.handle.agent_type.clone(),
71            training_count: self.training_count,
72            patterns_learned: self.metrics.patterns_learned,
73            avg_quality: self.metrics.avg_quality(),
74            total_examples: self.metrics.total_examples,
75        }
76    }
77}
78
79/// Agent statistics
80#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct AgentStats {
82    /// Agent ID
83    pub id: String,
84    /// Agent type
85    pub agent_type: AgentType,
86    /// Number of training sessions
87    pub training_count: u64,
88    /// Patterns learned
89    pub patterns_learned: usize,
90    /// Average quality score
91    pub avg_quality: f32,
92    /// Total examples processed
93    pub total_examples: usize,
94}
95
96/// Factory for creating and managing agents
97pub struct AgentFactory {
98    /// Base configuration for all agents
99    base_config: SonaConfig,
100    /// Managed agents
101    agents: HashMap<String, ManagedAgent>,
102    /// Default hidden dimension
103    default_hidden_dim: usize,
104}
105
106impl Default for AgentFactory {
107    fn default() -> Self {
108        Self::new(SonaConfig::default())
109    }
110}
111
112impl AgentFactory {
113    /// Create a new agent factory
114    pub fn new(base_config: SonaConfig) -> Self {
115        let default_hidden_dim = base_config.hidden_dim;
116        Self {
117            base_config,
118            agents: HashMap::new(),
119            default_hidden_dim,
120        }
121    }
122
123    /// Create factory with specific hidden dimension
124    pub fn with_hidden_dim(hidden_dim: usize) -> Self {
125        let config = SonaConfig {
126            hidden_dim,
127            embedding_dim: hidden_dim,
128            ..SonaConfig::default()
129        };
130        Self::new(config)
131    }
132
133    /// Create an agent from a template
134    pub fn create_from_template(
135        &mut self,
136        name: impl Into<String>,
137        template: &TrainingTemplate,
138    ) -> &ManagedAgent {
139        let name = name.into();
140        let agent = ManagedAgent::new(
141            name.clone(),
142            template.agent_type.clone(),
143            template.sona_config.clone(),
144            &template.name,
145        );
146        self.agents.insert(name.clone(), agent);
147        self.agents.get(&name).unwrap()
148    }
149
150    /// Create an agent with custom configuration
151    pub fn create_agent(
152        &mut self,
153        name: impl Into<String>,
154        agent_type: AgentType,
155        purpose: impl Into<String>,
156    ) -> &ManagedAgent {
157        let name = name.into();
158        let config = self.config_for_agent_type(&agent_type);
159        let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
160        agent.tags.push("custom".into());
161        self.agents.insert(name.clone(), agent);
162        self.agents.get(&name).unwrap()
163    }
164
165    /// Create a code agent
166    pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
167        let template = TrainingTemplate::code_agent().with_hidden_dim(self.default_hidden_dim);
168        self.create_from_template(name, &template)
169    }
170
171    /// Create a chat agent
172    pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
173        let template = TrainingTemplate::chat_agent().with_hidden_dim(self.default_hidden_dim);
174        self.create_from_template(name, &template)
175    }
176
177    /// Create a RAG agent
178    pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
179        let template = TrainingTemplate::rag_agent().with_hidden_dim(self.default_hidden_dim);
180        self.create_from_template(name, &template)
181    }
182
183    /// Create a task planner agent
184    pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
185        let template = TrainingTemplate::task_planner().with_hidden_dim(self.default_hidden_dim);
186        self.create_from_template(name, &template)
187    }
188
189    /// Create a reasoning agent
190    pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
191        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(self.default_hidden_dim);
192        self.create_from_template(name, &template)
193    }
194
195    /// Create a codebase helper agent
196    pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
197        let template = TrainingTemplate::codebase_helper().with_hidden_dim(self.default_hidden_dim);
198        self.create_from_template(name, &template)
199    }
200
201    /// Get an agent by name
202    pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
203        self.agents.get(name)
204    }
205
206    /// Get a mutable agent by name
207    pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
208        self.agents.get_mut(name)
209    }
210
211    /// Remove an agent
212    pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
213        self.agents.remove(name)
214    }
215
216    /// List all agents
217    pub fn list_agents(&self) -> Vec<AgentStats> {
218        self.agents.values().map(|a| a.stats()).collect()
219    }
220
221    /// Get agent count
222    pub fn agent_count(&self) -> usize {
223        self.agents.len()
224    }
225
226    /// Train an agent with examples
227    pub fn train_agent<E>(
228        &mut self,
229        name: &str,
230        examples: impl Iterator<Item = E>,
231    ) -> Result<usize, String>
232    where
233        E: TrainingExample,
234    {
235        let agent = self
236            .agents
237            .get_mut(name)
238            .ok_or_else(|| format!("Agent '{}' not found", name))?;
239
240        let mut count = 0;
241        for example in examples {
242            // Use builder-based trajectory API
243            let mut builder = agent.engine.begin_trajectory(example.embedding());
244
245            // Set route if available
246            if let Some(route) = example.route() {
247                builder.set_model_route(&route);
248            }
249
250            // Add context if available
251            for ctx in example.context() {
252                builder.add_context(&ctx);
253            }
254
255            // Add step with activations
256            builder.add_step(example.activations(), example.attention(), example.reward());
257
258            // End trajectory with quality
259            agent.engine.end_trajectory(builder, example.quality());
260
261            count += 1;
262            agent.metrics.total_examples += 1;
263            agent.metrics.add_quality_sample(example.quality());
264        }
265
266        // Force learning after batch
267        agent.engine.force_learn();
268        agent.training_count += 1;
269        agent.metrics.training_sessions += 1;
270
271        Ok(count)
272    }
273
274    /// Get configuration for agent type
275    fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
276        let mut config = self.base_config.clone();
277
278        match agent_type {
279            AgentType::CodeAgent | AgentType::CodebaseHelper => {
280                config.base_lora_rank = 16;
281                config.pattern_clusters = 200;
282                config.quality_threshold = 0.2;
283            }
284            AgentType::ChatAgent => {
285                config.base_lora_rank = 8;
286                config.pattern_clusters = 50;
287                config.quality_threshold = 0.4;
288            }
289            AgentType::RagAgent => {
290                config.pattern_clusters = 200;
291                config.trajectory_capacity = 10000;
292            }
293            AgentType::TaskPlanner => {
294                config.base_lora_rank = 16;
295                config.ewc_lambda = 2000.0;
296            }
297            AgentType::ReasoningAgent => {
298                config.base_lora_rank = 16;
299                config.ewc_lambda = 3000.0;
300                config.pattern_clusters = 150;
301            }
302            AgentType::DomainExpert => {
303                config.quality_threshold = 0.1;
304                config.trajectory_capacity = 20000;
305            }
306            AgentType::DataAnalyst => {
307                config.base_lora_rank = 8;
308                config.pattern_clusters = 100;
309            }
310            AgentType::CreativeWriter => {
311                config.base_lora_rank = 8;
312                config.pattern_clusters = 50;
313                config.quality_threshold = 0.5;
314            }
315            _ => {}
316        }
317
318        config
319    }
320}
321
/// A single example that can be fed to an agent's SONA engine.
///
/// Only `embedding` and `quality` are required; every other method has a
/// default derived from those two (or an empty/uniform value).
pub trait TrainingExample {
    /// Embedding vector that starts the trajectory.
    fn embedding(&self) -> Vec<f32>;

    /// Activations for the trajectory step; defaults to the embedding itself.
    fn activations(&self) -> Vec<f32> {
        Self::embedding(self)
    }

    /// Attention weights for the trajectory step; defaults to 64 equal
    /// weights of `1/64`.
    fn attention(&self) -> Vec<f32> {
        const SLOTS: usize = 64;
        std::iter::repeat(1.0 / SLOTS as f32).take(SLOTS).collect()
    }

    /// Reward signal for the step; defaults to the quality score.
    fn reward(&self) -> f32 {
        Self::quality(self)
    }

    /// Quality score that closes the trajectory.
    fn quality(&self) -> f32;

    /// Optional model route; none by default.
    fn route(&self) -> Option<String> {
        Option::None
    }

    /// Context identifiers attached to the trajectory; empty by default.
    fn context(&self) -> Vec<String> {
        vec![]
    }
}
355
/// Plain-data [`TrainingExample`] carrying its values directly in fields.
#[derive(Clone, Debug)]
pub struct SimpleExample {
    /// Embedding (also used as the activations).
    pub embedding: Vec<f32>,
    /// Quality score (also used as the reward).
    pub quality: f32,
    /// Model route, if any.
    pub route: Option<String>,
    /// Context identifiers attached to the trajectory.
    pub context: Vec<String>,
}
368
369impl SimpleExample {
370    /// Create a new simple example
371    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
372        Self {
373            embedding,
374            quality,
375            route: None,
376            context: Vec::new(),
377        }
378    }
379
380    /// Set route
381    pub fn with_route(mut self, route: impl Into<String>) -> Self {
382        self.route = Some(route.into());
383        self
384    }
385
386    /// Add context
387    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
388        self.context.push(ctx.into());
389        self
390    }
391}
392
393impl TrainingExample for SimpleExample {
394    fn embedding(&self) -> Vec<f32> {
395        self.embedding.clone()
396    }
397
398    fn quality(&self) -> f32 {
399        self.quality
400    }
401
402    fn route(&self) -> Option<String> {
403        self.route.clone()
404    }
405
406    fn context(&self) -> Vec<String> {
407        self.context.clone()
408    }
409}
410
411/// Thread-safe agent factory wrapper
412pub struct SharedAgentFactory {
413    inner: Arc<RwLock<AgentFactory>>,
414}
415
416impl SharedAgentFactory {
417    /// Create a new shared factory
418    pub fn new(config: SonaConfig) -> Self {
419        Self {
420            inner: Arc::new(RwLock::new(AgentFactory::new(config))),
421        }
422    }
423
424    /// Get read access to factory
425    pub fn read(&self) -> std::sync::RwLockReadGuard<'_, AgentFactory> {
426        self.inner.read().unwrap()
427    }
428
429    /// Get write access to factory
430    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, AgentFactory> {
431        self.inner.write().unwrap()
432    }
433
434    /// Clone the Arc
435    pub fn clone_arc(&self) -> Self {
436        Self {
437            inner: Arc::clone(&self.inner),
438        }
439    }
440}
441
442impl Clone for SharedAgentFactory {
443    fn clone(&self) -> Self {
444        self.clone_arc()
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creation() {
        let factory = AgentFactory::default();
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_create_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_code_agent("code-1");
        factory.create_chat_agent("chat-1");
        factory.create_rag_agent("rag-1");

        assert_eq!(factory.agent_count(), 3);
        assert!(factory.get_agent("code-1").is_some());
        assert!(factory.get_agent("unknown").is_none());
    }

    #[test]
    fn test_agent_from_template() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        factory.create_from_template("reasoner", &template);

        let agent = factory.get_agent("reasoner").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ReasoningAgent);
    }

    // `create_agent` records the purpose and tags the agent as "custom".
    #[test]
    fn test_create_custom_agent_is_tagged() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let agent = factory.create_agent("helper", AgentType::ChatAgent, "answer questions");

        assert_eq!(agent.handle.agent_type, AgentType::ChatAgent);
        assert_eq!(agent.purpose, "answer questions");
        assert_eq!(agent.tags, vec!["custom".to_string()]);
        assert_eq!(agent.training_count, 0);
    }

    // Re-creating an agent under an existing name replaces the old one.
    #[test]
    fn test_recreate_replaces_existing_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("dup");
        factory.create_chat_agent("dup");

        assert_eq!(factory.agent_count(), 1);
        let agent = factory.get_agent("dup").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ChatAgent);
    }

    #[test]
    fn test_remove_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_rag_agent("rag");

        assert!(factory.remove_agent("rag").is_some());
        assert_eq!(factory.agent_count(), 0);
        assert!(factory.remove_agent("rag").is_none());
    }

    #[test]
    fn test_train_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let examples = vec![
            SimpleExample::new(vec![0.1; 256], 0.8).with_route("greeting"),
            SimpleExample::new(vec![0.2; 256], 0.9).with_route("question"),
            SimpleExample::new(vec![0.3; 256], 0.7).with_route("farewell"),
        ];

        let count = factory.train_agent("bot", examples.into_iter()).unwrap();
        assert_eq!(count, 3);

        let agent = factory.get_agent("bot").unwrap();
        assert_eq!(agent.training_count, 1);
        assert_eq!(agent.metrics.total_examples, 3);
    }

    // Training an unregistered agent is reported as an error, not a panic.
    #[test]
    fn test_train_unknown_agent_errors() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let result = factory.train_agent("ghost", Vec::<SimpleExample>::new().into_iter());

        assert_eq!(result, Err("Agent 'ghost' not found".to_string()));
        assert_eq!(factory.agent_count(), 0);
    }

    #[test]
    fn test_list_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");
        factory.create_chat_agent("chat");

        let agents = factory.list_agents();
        assert_eq!(agents.len(), 2);
    }
}