Skip to main content

ruvector_sona/training/
templates.rs

1//! Training Templates for SONA
2//!
3//! Pre-configured training setups optimized for different use cases.
4
5use crate::types::SonaConfig;
6use serde::{Deserialize, Serialize};
7
8/// Agent specialization types
9#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum AgentType {
11    /// Code generation and assistance
12    CodeAgent,
13    /// General chat and conversation
14    ChatAgent,
15    /// Document retrieval and Q&A
16    RagAgent,
17    /// Task decomposition and planning
18    TaskPlanner,
19    /// Domain-specific expert
20    DomainExpert,
21    /// Codebase-aware assistant
22    CodebaseHelper,
23    /// Data analysis and insights
24    DataAnalyst,
25    /// Creative writing and content
26    CreativeWriter,
27    /// Reasoning and logic
28    ReasoningAgent,
29    /// Multi-modal understanding
30    MultiModal,
31    /// Custom agent type
32    Custom(String),
33}
34
35impl std::fmt::Display for AgentType {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            AgentType::CodeAgent => write!(f, "code-agent"),
39            AgentType::ChatAgent => write!(f, "chat-agent"),
40            AgentType::RagAgent => write!(f, "rag-agent"),
41            AgentType::TaskPlanner => write!(f, "task-planner"),
42            AgentType::DomainExpert => write!(f, "domain-expert"),
43            AgentType::CodebaseHelper => write!(f, "codebase-helper"),
44            AgentType::DataAnalyst => write!(f, "data-analyst"),
45            AgentType::CreativeWriter => write!(f, "creative-writer"),
46            AgentType::ReasoningAgent => write!(f, "reasoning-agent"),
47            AgentType::MultiModal => write!(f, "multi-modal"),
48            AgentType::Custom(name) => write!(f, "custom-{}", name),
49        }
50    }
51}
52
53/// Task domain for training focus
54#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
55pub enum TaskDomain {
56    /// Software development
57    SoftwareDevelopment,
58    /// Customer support
59    CustomerSupport,
60    /// Healthcare
61    Healthcare,
62    /// Finance
63    Finance,
64    /// Legal
65    Legal,
66    /// Education
67    Education,
68    /// Research
69    Research,
70    /// Marketing
71    Marketing,
72    /// General purpose
73    General,
74    /// Custom domain
75    Custom(String),
76}
77
78/// Training method configuration
79#[derive(Clone, Debug, Serialize, Deserialize)]
80pub enum TrainingMethod {
81    /// Standard supervised learning
82    Supervised {
83        /// Batch size for training
84        batch_size: usize,
85        /// Number of epochs
86        epochs: usize,
87    },
88    /// Reinforcement learning from feedback
89    RLHF {
90        /// Reward model weight
91        reward_weight: f32,
92        /// KL divergence penalty
93        kl_penalty: f32,
94    },
95    /// Direct preference optimization
96    DPO {
97        /// Beta parameter for DPO
98        beta: f32,
99        /// Reference model weight
100        ref_weight: f32,
101    },
102    /// Continuous online learning
103    Online {
104        /// Learning rate decay
105        lr_decay: f32,
106        /// Window size for recent examples
107        window_size: usize,
108    },
109    /// Few-shot adaptation
110    FewShot {
111        /// Number of examples per class
112        k_shot: usize,
113        /// Meta-learning rate
114        meta_lr: f32,
115    },
116}
117
118impl Default for TrainingMethod {
119    fn default() -> Self {
120        TrainingMethod::Online {
121            lr_decay: 0.999,
122            window_size: 1000,
123        }
124    }
125}
126
127/// Vertical-specific configuration
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct VerticalConfig {
130    /// Domain focus
131    pub domain: TaskDomain,
132    /// Specialized vocabulary size
133    pub vocab_boost: usize,
134    /// Domain-specific quality metrics
135    pub quality_metrics: Vec<String>,
136    /// Compliance requirements
137    pub compliance_level: ComplianceLevel,
138}
139
140/// Compliance level for regulated industries
141#[derive(Clone, Debug, Default, Serialize, Deserialize)]
142pub enum ComplianceLevel {
143    #[default]
144    None,
145    /// Basic audit logging
146    Basic,
147    /// HIPAA compliance
148    Hipaa,
149    /// SOC2 compliance
150    Soc2,
151    /// GDPR compliance
152    Gdpr,
153    /// Custom compliance
154    Custom(String),
155}
156
157/// Template preset for quick configuration
158#[derive(Clone, Debug, Serialize, Deserialize)]
159pub enum TemplatePreset {
160    /// Minimal configuration for testing
161    Minimal,
162    /// Balanced for general use
163    Balanced,
164    /// High performance for production
165    Production,
166    /// Maximum quality regardless of speed
167    MaxQuality,
168    /// Edge deployment (<5MB)
169    Edge,
170    /// Research and experimentation
171    Research,
172}
173
174/// Training template with full configuration
175#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TrainingTemplate {
177    /// Template name
178    pub name: String,
179    /// Agent type
180    pub agent_type: AgentType,
181    /// SONA configuration
182    pub sona_config: SonaConfig,
183    /// Training method
184    pub training_method: TrainingMethod,
185    /// Vertical configuration
186    pub vertical: Option<VerticalConfig>,
187    /// Expected training data size
188    pub expected_data_size: DataSizeHint,
189    /// Memory budget in MB
190    pub memory_budget_mb: usize,
191    /// Target latency in microseconds
192    pub target_latency_us: u64,
193    /// Enable continuous learning
194    pub continuous_learning: bool,
195    /// Auto-export trained adapters
196    pub auto_export: bool,
197    /// Tags for organization
198    pub tags: Vec<String>,
199}
200
201/// Hint about training data size
202#[derive(Clone, Debug, Serialize, Deserialize)]
203pub enum DataSizeHint {
204    /// <100 examples (few-shot)
205    Tiny,
206    /// 100-1000 examples
207    Small,
208    /// 1000-10000 examples
209    Medium,
210    /// 10000-100000 examples
211    Large,
212    /// >100000 examples
213    Massive,
214}
215
216impl Default for DataSizeHint {
217    fn default() -> Self {
218        DataSizeHint::Medium
219    }
220}
221
222impl TrainingTemplate {
223    /// Create a new training template
224    pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
225        Self {
226            name: name.into(),
227            agent_type,
228            sona_config: SonaConfig::default(),
229            training_method: TrainingMethod::default(),
230            vertical: None,
231            expected_data_size: DataSizeHint::default(),
232            memory_budget_mb: 100,
233            target_latency_us: 1000,
234            continuous_learning: true,
235            auto_export: false,
236            tags: Vec::new(),
237        }
238    }
239
240    /// Create from preset
241    pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
242        let mut template = Self::new(format!("{:?}-{}", preset, agent_type), agent_type.clone());
243
244        match preset {
245            TemplatePreset::Minimal => {
246                template.sona_config = SonaConfig::edge_deployment();
247                template.memory_budget_mb = 10;
248                template.expected_data_size = DataSizeHint::Tiny;
249            }
250            TemplatePreset::Balanced => {
251                template.sona_config = SonaConfig::default();
252                template.memory_budget_mb = 100;
253            }
254            TemplatePreset::Production => {
255                template.sona_config = SonaConfig::max_throughput();
256                template.memory_budget_mb = 200;
257                template.auto_export = true;
258            }
259            TemplatePreset::MaxQuality => {
260                template.sona_config = SonaConfig::max_quality();
261                template.memory_budget_mb = 500;
262                template.expected_data_size = DataSizeHint::Large;
263            }
264            TemplatePreset::Edge => {
265                template.sona_config = SonaConfig::edge_deployment();
266                template.memory_budget_mb = 5;
267                template.target_latency_us = 500;
268            }
269            TemplatePreset::Research => {
270                template.sona_config = SonaConfig::max_quality();
271                template.sona_config.trajectory_capacity = 50000;
272                template.memory_budget_mb = 1000;
273                template.expected_data_size = DataSizeHint::Massive;
274            }
275        }
276
277        // Apply agent-specific optimizations
278        template.apply_agent_optimizations();
279        template
280    }
281
282    //------------------------------------------------------------------
283    // Pre-built Templates for Common Use Cases
284    //------------------------------------------------------------------
285
286    /// Code agent template - optimized for code generation
287    ///
288    /// **Best for**: Code completion, bug fixes, refactoring
289    /// **Config**: baseLoraRank=16, clusters=200, capacity=10000
290    /// **Training data**: Code completions, fixes, reviews
291    pub fn code_agent() -> Self {
292        let mut template = Self::new("code-agent", AgentType::CodeAgent);
293        template.sona_config.base_lora_rank = 16; // Deeper for code patterns
294        template.sona_config.pattern_clusters = 200; // Many code patterns
295        template.sona_config.trajectory_capacity = 10000;
296        template.sona_config.quality_threshold = 0.2; // Learn from most examples
297        template.training_method = TrainingMethod::Online {
298            lr_decay: 0.9995,
299            window_size: 5000,
300        };
301        template.tags = vec!["code".into(), "development".into(), "completion".into()];
302        template
303    }
304
305    /// Chat agent template - optimized for conversational AI
306    ///
307    /// **Best for**: Customer support, general chat, assistants
308    /// **Config**: baseLoraRank=8, clusters=50, fast response
309    /// **Training data**: Conversation histories, feedback
310    pub fn chat_agent() -> Self {
311        let mut template = Self::new("chat-agent", AgentType::ChatAgent);
312        template.sona_config.base_lora_rank = 8;
313        template.sona_config.pattern_clusters = 50;
314        template.sona_config.quality_threshold = 0.4;
315        template.target_latency_us = 500; // Fast responses
316        template.training_method = TrainingMethod::RLHF {
317            reward_weight: 0.5,
318            kl_penalty: 0.1,
319        };
320        template.tags = vec!["chat".into(), "conversation".into(), "support".into()];
321        template
322    }
323
324    /// RAG agent template - optimized for retrieval-augmented generation
325    ///
326    /// **Best for**: Document Q&A, knowledge bases, search
327    /// **Config**: clusters=200, capacity=10000, high pattern storage
328    /// **Training data**: Document chunks, Q&A pairs
329    pub fn rag_agent() -> Self {
330        let mut template = Self::new("rag-agent", AgentType::RagAgent);
331        template.sona_config.pattern_clusters = 200; // Many document patterns
332        template.sona_config.trajectory_capacity = 10000;
333        template.sona_config.embedding_dim = 512; // Larger embeddings for retrieval
334        template.sona_config.hidden_dim = 512;
335        template.training_method = TrainingMethod::Supervised {
336            batch_size: 32,
337            epochs: 10,
338        };
339        template.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
340        template
341    }
342
343    /// Task planner template - optimized for task decomposition
344    ///
345    /// **Best for**: Project planning, task breakdown, scheduling
346    /// **Config**: baseLoraRank=16, ewcLambda=2000, multi-task
347    /// **Training data**: Task decompositions, planning examples
348    pub fn task_planner() -> Self {
349        let mut template = Self::new("task-planner", AgentType::TaskPlanner);
350        template.sona_config.base_lora_rank = 16;
351        template.sona_config.ewc_lambda = 2000.0; // Important for multi-task
352        template.sona_config.pattern_clusters = 100;
353        template.training_method = TrainingMethod::DPO {
354            beta: 0.1,
355            ref_weight: 0.5,
356        };
357        template.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
358        template
359    }
360
361    /// Domain expert template - optimized for specialized knowledge
362    ///
363    /// **Best for**: Legal, medical, financial expertise
364    /// **Config**: qualityThreshold=0.1, high capacity, compliance
365    /// **Training data**: Domain-specific Q&A, expert responses
366    pub fn domain_expert(domain: TaskDomain) -> Self {
367        let domain_name = format!("{:?}", domain).to_lowercase();
368        let mut template = Self::new(
369            format!("domain-expert-{}", domain_name),
370            AgentType::DomainExpert,
371        );
372        template.sona_config.quality_threshold = 0.1; // Learn from all domain examples
373        template.sona_config.trajectory_capacity = 20000;
374        template.sona_config.base_lora_rank = 16;
375        template.vertical = Some(VerticalConfig {
376            domain: domain.clone(),
377            vocab_boost: 10000,
378            quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
379            compliance_level: match domain {
380                TaskDomain::Healthcare => ComplianceLevel::Hipaa,
381                TaskDomain::Finance => ComplianceLevel::Soc2,
382                TaskDomain::Legal => ComplianceLevel::Basic,
383                _ => ComplianceLevel::None,
384            },
385        });
386        template.tags = vec!["domain".into(), "expert".into(), domain_name];
387        template
388    }
389
390    /// Codebase helper template - learns your specific codebase
391    ///
392    /// **Best for**: Repository-specific assistance, code navigation
393    /// **Config**: clusters=200, capacity=10000, high pattern storage
394    /// **Training data**: Your repo's code, documentation
395    pub fn codebase_helper() -> Self {
396        let mut template = Self::new("codebase-helper", AgentType::CodebaseHelper);
397        template.sona_config.pattern_clusters = 200;
398        template.sona_config.trajectory_capacity = 10000;
399        template.sona_config.quality_threshold = 0.2;
400        template.sona_config.base_lora_rank = 16;
401        template.expected_data_size = DataSizeHint::Large;
402        template.training_method = TrainingMethod::Online {
403            lr_decay: 0.999,
404            window_size: 10000,
405        };
406        template.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
407        template
408    }
409
410    /// Data analyst template - optimized for data insights
411    ///
412    /// **Best for**: Data analysis, visualization, statistics
413    /// **Config**: baseLoraRank=8, clusters=100, reasoning focus
414    pub fn data_analyst() -> Self {
415        let mut template = Self::new("data-analyst", AgentType::DataAnalyst);
416        template.sona_config.base_lora_rank = 8;
417        template.sona_config.pattern_clusters = 100;
418        template.vertical = Some(VerticalConfig {
419            domain: TaskDomain::Research,
420            vocab_boost: 5000,
421            quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
422            compliance_level: ComplianceLevel::None,
423        });
424        template.tags = vec!["data".into(), "analysis".into(), "insights".into()];
425        template
426    }
427
428    /// Creative writer template - optimized for content generation
429    ///
430    /// **Best for**: Marketing copy, blog posts, creative writing
431    /// **Config**: High diversity, quality focus
432    pub fn creative_writer() -> Self {
433        let mut template = Self::new("creative-writer", AgentType::CreativeWriter);
434        template.sona_config.base_lora_rank = 8;
435        template.sona_config.pattern_clusters = 50; // Fewer clusters for diversity
436        template.sona_config.quality_threshold = 0.5; // Only learn from high quality
437        template.training_method = TrainingMethod::RLHF {
438            reward_weight: 0.7,
439            kl_penalty: 0.05, // Less constraint for creativity
440        };
441        template.vertical = Some(VerticalConfig {
442            domain: TaskDomain::Marketing,
443            vocab_boost: 0,
444            quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
445            compliance_level: ComplianceLevel::None,
446        });
447        template.tags = vec!["creative".into(), "writing".into(), "content".into()];
448        template
449    }
450
451    /// Reasoning agent template - optimized for logical reasoning
452    ///
453    /// **Best for**: Math, logic, chain-of-thought reasoning
454    /// **Config**: High rank, strong EWC, accuracy focus
455    pub fn reasoning_agent() -> Self {
456        let mut template = Self::new("reasoning-agent", AgentType::ReasoningAgent);
457        template.sona_config.base_lora_rank = 16;
458        template.sona_config.ewc_lambda = 3000.0; // Strong protection
459        template.sona_config.pattern_clusters = 150;
460        template.sona_config.quality_threshold = 0.3;
461        template.training_method = TrainingMethod::DPO {
462            beta: 0.15,
463            ref_weight: 0.4,
464        };
465        template.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
466        template
467    }
468
469    //------------------------------------------------------------------
470    // Builder Methods
471    //------------------------------------------------------------------
472
473    /// Set SONA configuration
474    pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
475        self.sona_config = config;
476        self
477    }
478
479    /// Set training method
480    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
481        self.training_method = method;
482        self
483    }
484
485    /// Set vertical configuration
486    pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
487        self.vertical = Some(vertical);
488        self
489    }
490
491    /// Set memory budget
492    pub fn with_memory_budget(mut self, mb: usize) -> Self {
493        self.memory_budget_mb = mb;
494        self
495    }
496
497    /// Set target latency
498    pub fn with_target_latency(mut self, us: u64) -> Self {
499        self.target_latency_us = us;
500        self
501    }
502
503    /// Enable continuous learning
504    pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
505        self.continuous_learning = enabled;
506        self
507    }
508
509    /// Enable auto-export
510    pub fn with_auto_export(mut self, enabled: bool) -> Self {
511        self.auto_export = enabled;
512        self
513    }
514
515    /// Add tags
516    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
517        self.tags = tags;
518        self
519    }
520
521    /// Set hidden dimension
522    pub fn with_hidden_dim(mut self, dim: usize) -> Self {
523        self.sona_config.hidden_dim = dim;
524        self.sona_config.embedding_dim = dim;
525        self
526    }
527
528    /// Set LoRA ranks
529    pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
530        self.sona_config.micro_lora_rank = micro.min(2); // MicroLoRA max rank is 2
531        self.sona_config.base_lora_rank = base;
532        self
533    }
534
535    //------------------------------------------------------------------
536    // Internal Methods
537    //------------------------------------------------------------------
538
539    /// Apply agent-specific optimizations
540    fn apply_agent_optimizations(&mut self) {
541        match &self.agent_type {
542            AgentType::CodeAgent | AgentType::CodebaseHelper => {
543                self.sona_config.pattern_clusters = 200;
544                self.sona_config.base_lora_rank = 16;
545            }
546            AgentType::ChatAgent => {
547                self.sona_config.pattern_clusters = 50;
548                self.target_latency_us = 500;
549            }
550            AgentType::RagAgent => {
551                self.sona_config.pattern_clusters = 200;
552                self.sona_config.trajectory_capacity = 10000;
553            }
554            AgentType::ReasoningAgent => {
555                self.sona_config.ewc_lambda = 3000.0;
556                self.sona_config.base_lora_rank = 16;
557            }
558            AgentType::DomainExpert => {
559                self.sona_config.quality_threshold = 0.1;
560            }
561            _ => {}
562        }
563    }
564
565    /// Validate template configuration
566    pub fn validate(&self) -> Result<(), String> {
567        if self.sona_config.micro_lora_rank > 2 {
568            return Err("MicroLoRA rank must be 1 or 2".into());
569        }
570        if self.sona_config.hidden_dim == 0 {
571            return Err("Hidden dimension must be > 0".into());
572        }
573        if self.memory_budget_mb < 1 {
574            return Err("Memory budget must be >= 1 MB".into());
575        }
576        Ok(())
577    }
578
579    /// Get estimated memory usage in MB
580    pub fn estimated_memory_mb(&self) -> usize {
581        let config = &self.sona_config;
582
583        // Base engine memory
584        let engine_mb = 5;
585
586        // LoRA weights: hidden_dim * rank * 2 (A and B matrices) * 4 bytes * 2 (micro + base)
587        let lora_bytes =
588            config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
589        let lora_mb = lora_bytes / (1024 * 1024);
590
591        // Trajectory buffer: capacity * ~800 bytes per trajectory
592        let traj_mb = (config.trajectory_capacity * 800) / (1024 * 1024);
593
594        // Pattern storage: clusters * embedding_dim * 4 bytes
595        let pattern_mb = (config.pattern_clusters * config.embedding_dim * 4) / (1024 * 1024);
596
597        engine_mb + lora_mb + traj_mb + pattern_mb + 1
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// The code-agent template carries its agent type and tuned sizes.
    #[test]
    fn test_template_creation() {
        let TrainingTemplate {
            agent_type,
            sona_config,
            ..
        } = TrainingTemplate::code_agent();

        assert_eq!(agent_type, AgentType::CodeAgent);
        assert_eq!(sona_config.base_lora_rank, 16);
        assert_eq!(sona_config.pattern_clusters, 200);
    }

    /// Presets set their distinguishing fields.
    #[test]
    fn test_preset_templates() {
        let agent = AgentType::ChatAgent;

        let production = TrainingTemplate::from_preset(TemplatePreset::Production, agent.clone());
        assert!(production.auto_export);

        let edge = TrainingTemplate::from_preset(TemplatePreset::Edge, agent);
        assert_eq!(edge.memory_budget_mb, 5);
    }

    /// A healthcare expert gets a vertical with HIPAA compliance.
    #[test]
    fn test_domain_expert() {
        let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);
        assert!(medical.vertical.is_some());

        let compliance = medical.vertical.map(|v| v.compliance_level);
        assert!(matches!(compliance, Some(ComplianceLevel::Hipaa)));
    }

    /// Builder methods write through to the SONA configuration.
    #[test]
    fn test_builder_pattern() {
        let agent = AgentType::Custom("test".into());
        let template = TrainingTemplate::new("custom", agent)
            .with_hidden_dim(512)
            .with_lora_ranks(2, 16)
            .with_memory_budget(200)
            .with_continuous_learning(true);

        let cfg = &template.sona_config;
        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.micro_lora_rank, 2);
        assert_eq!(cfg.base_lora_rank, 16);
    }

    /// Validation accepts a stock template and rejects an oversized MicroLoRA rank.
    #[test]
    fn test_validation() {
        let valid = TrainingTemplate::code_agent();
        assert!(valid.validate().is_ok());

        let mut invalid = valid;
        invalid.sona_config.micro_lora_rank = 5;
        assert!(invalid.validate().is_err());
    }

    /// The memory estimate is positive and below twice the budget.
    #[test]
    fn test_memory_estimation() {
        let template = TrainingTemplate::code_agent();
        let estimate = template.estimated_memory_mb();
        let ceiling = template.memory_budget_mb * 2;

        assert!(estimate > 0);
        assert!(estimate < ceiling);
    }
}