ruvector_sona/training/
templates.rs

//! Training Templates for SONA
//!
//! Pre-configured training setups optimized for different use cases.
4
5use crate::types::SonaConfig;
6use serde::{Deserialize, Serialize};
7
8/// Agent specialization types
9#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum AgentType {
11    /// Code generation and assistance
12    CodeAgent,
13    /// General chat and conversation
14    ChatAgent,
15    /// Document retrieval and Q&A
16    RagAgent,
17    /// Task decomposition and planning
18    TaskPlanner,
19    /// Domain-specific expert
20    DomainExpert,
21    /// Codebase-aware assistant
22    CodebaseHelper,
23    /// Data analysis and insights
24    DataAnalyst,
25    /// Creative writing and content
26    CreativeWriter,
27    /// Reasoning and logic
28    ReasoningAgent,
29    /// Multi-modal understanding
30    MultiModal,
31    /// Custom agent type
32    Custom(String),
33}
34
35impl std::fmt::Display for AgentType {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            AgentType::CodeAgent => write!(f, "code-agent"),
39            AgentType::ChatAgent => write!(f, "chat-agent"),
40            AgentType::RagAgent => write!(f, "rag-agent"),
41            AgentType::TaskPlanner => write!(f, "task-planner"),
42            AgentType::DomainExpert => write!(f, "domain-expert"),
43            AgentType::CodebaseHelper => write!(f, "codebase-helper"),
44            AgentType::DataAnalyst => write!(f, "data-analyst"),
45            AgentType::CreativeWriter => write!(f, "creative-writer"),
46            AgentType::ReasoningAgent => write!(f, "reasoning-agent"),
47            AgentType::MultiModal => write!(f, "multi-modal"),
48            AgentType::Custom(name) => write!(f, "custom-{}", name),
49        }
50    }
51}
52
53/// Task domain for training focus
54#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
55pub enum TaskDomain {
56    /// Software development
57    SoftwareDevelopment,
58    /// Customer support
59    CustomerSupport,
60    /// Healthcare
61    Healthcare,
62    /// Finance
63    Finance,
64    /// Legal
65    Legal,
66    /// Education
67    Education,
68    /// Research
69    Research,
70    /// Marketing
71    Marketing,
72    /// General purpose
73    General,
74    /// Custom domain
75    Custom(String),
76}
77
78/// Training method configuration
79#[derive(Clone, Debug, Serialize, Deserialize)]
80pub enum TrainingMethod {
81    /// Standard supervised learning
82    Supervised {
83        /// Batch size for training
84        batch_size: usize,
85        /// Number of epochs
86        epochs: usize,
87    },
88    /// Reinforcement learning from feedback
89    RLHF {
90        /// Reward model weight
91        reward_weight: f32,
92        /// KL divergence penalty
93        kl_penalty: f32,
94    },
95    /// Direct preference optimization
96    DPO {
97        /// Beta parameter for DPO
98        beta: f32,
99        /// Reference model weight
100        ref_weight: f32,
101    },
102    /// Continuous online learning
103    Online {
104        /// Learning rate decay
105        lr_decay: f32,
106        /// Window size for recent examples
107        window_size: usize,
108    },
109    /// Few-shot adaptation
110    FewShot {
111        /// Number of examples per class
112        k_shot: usize,
113        /// Meta-learning rate
114        meta_lr: f32,
115    },
116}
117
118impl Default for TrainingMethod {
119    fn default() -> Self {
120        TrainingMethod::Online {
121            lr_decay: 0.999,
122            window_size: 1000,
123        }
124    }
125}
126
127/// Vertical-specific configuration
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct VerticalConfig {
130    /// Domain focus
131    pub domain: TaskDomain,
132    /// Specialized vocabulary size
133    pub vocab_boost: usize,
134    /// Domain-specific quality metrics
135    pub quality_metrics: Vec<String>,
136    /// Compliance requirements
137    pub compliance_level: ComplianceLevel,
138}
139
140/// Compliance level for regulated industries
141#[derive(Clone, Debug, Default, Serialize, Deserialize)]
142pub enum ComplianceLevel {
143    #[default]
144    None,
145    /// Basic audit logging
146    Basic,
147    /// HIPAA compliance
148    Hipaa,
149    /// SOC2 compliance
150    Soc2,
151    /// GDPR compliance
152    Gdpr,
153    /// Custom compliance
154    Custom(String),
155}
156
157/// Template preset for quick configuration
158#[derive(Clone, Debug, Serialize, Deserialize)]
159pub enum TemplatePreset {
160    /// Minimal configuration for testing
161    Minimal,
162    /// Balanced for general use
163    Balanced,
164    /// High performance for production
165    Production,
166    /// Maximum quality regardless of speed
167    MaxQuality,
168    /// Edge deployment (<5MB)
169    Edge,
170    /// Research and experimentation
171    Research,
172}
173
174/// Training template with full configuration
175#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TrainingTemplate {
177    /// Template name
178    pub name: String,
179    /// Agent type
180    pub agent_type: AgentType,
181    /// SONA configuration
182    pub sona_config: SonaConfig,
183    /// Training method
184    pub training_method: TrainingMethod,
185    /// Vertical configuration
186    pub vertical: Option<VerticalConfig>,
187    /// Expected training data size
188    pub expected_data_size: DataSizeHint,
189    /// Memory budget in MB
190    pub memory_budget_mb: usize,
191    /// Target latency in microseconds
192    pub target_latency_us: u64,
193    /// Enable continuous learning
194    pub continuous_learning: bool,
195    /// Auto-export trained adapters
196    pub auto_export: bool,
197    /// Tags for organization
198    pub tags: Vec<String>,
199}
200
201/// Hint about training data size
202#[derive(Clone, Debug, Serialize, Deserialize)]
203pub enum DataSizeHint {
204    /// <100 examples (few-shot)
205    Tiny,
206    /// 100-1000 examples
207    Small,
208    /// 1000-10000 examples
209    Medium,
210    /// 10000-100000 examples
211    Large,
212    /// >100000 examples
213    Massive,
214}
215
216impl Default for DataSizeHint {
217    fn default() -> Self {
218        DataSizeHint::Medium
219    }
220}
221
222impl TrainingTemplate {
223    /// Create a new training template
224    pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
225        Self {
226            name: name.into(),
227            agent_type,
228            sona_config: SonaConfig::default(),
229            training_method: TrainingMethod::default(),
230            vertical: None,
231            expected_data_size: DataSizeHint::default(),
232            memory_budget_mb: 100,
233            target_latency_us: 1000,
234            continuous_learning: true,
235            auto_export: false,
236            tags: Vec::new(),
237        }
238    }
239
240    /// Create from preset
241    pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
242        let mut template = Self::new(format!("{:?}-{}", preset, agent_type), agent_type.clone());
243
244        match preset {
245            TemplatePreset::Minimal => {
246                template.sona_config = SonaConfig::edge_deployment();
247                template.memory_budget_mb = 10;
248                template.expected_data_size = DataSizeHint::Tiny;
249            }
250            TemplatePreset::Balanced => {
251                template.sona_config = SonaConfig::default();
252                template.memory_budget_mb = 100;
253            }
254            TemplatePreset::Production => {
255                template.sona_config = SonaConfig::max_throughput();
256                template.memory_budget_mb = 200;
257                template.auto_export = true;
258            }
259            TemplatePreset::MaxQuality => {
260                template.sona_config = SonaConfig::max_quality();
261                template.memory_budget_mb = 500;
262                template.expected_data_size = DataSizeHint::Large;
263            }
264            TemplatePreset::Edge => {
265                template.sona_config = SonaConfig::edge_deployment();
266                template.memory_budget_mb = 5;
267                template.target_latency_us = 500;
268            }
269            TemplatePreset::Research => {
270                template.sona_config = SonaConfig::max_quality();
271                template.sona_config.trajectory_capacity = 50000;
272                template.memory_budget_mb = 1000;
273                template.expected_data_size = DataSizeHint::Massive;
274            }
275        }
276
277        // Apply agent-specific optimizations
278        template.apply_agent_optimizations();
279        template
280    }
281
282    //------------------------------------------------------------------
283    // Pre-built Templates for Common Use Cases
284    //------------------------------------------------------------------
285
286    /// Code agent template - optimized for code generation
287    ///
288    /// **Best for**: Code completion, bug fixes, refactoring
289    /// **Config**: baseLoraRank=16, clusters=200, capacity=10000
290    /// **Training data**: Code completions, fixes, reviews
291    pub fn code_agent() -> Self {
292        let mut template = Self::new("code-agent", AgentType::CodeAgent);
293        template.sona_config.base_lora_rank = 16;  // Deeper for code patterns
294        template.sona_config.pattern_clusters = 200;  // Many code patterns
295        template.sona_config.trajectory_capacity = 10000;
296        template.sona_config.quality_threshold = 0.2;  // Learn from most examples
297        template.training_method = TrainingMethod::Online {
298            lr_decay: 0.9995,
299            window_size: 5000,
300        };
301        template.tags = vec!["code".into(), "development".into(), "completion".into()];
302        template
303    }
304
305    /// Chat agent template - optimized for conversational AI
306    ///
307    /// **Best for**: Customer support, general chat, assistants
308    /// **Config**: baseLoraRank=8, clusters=50, fast response
309    /// **Training data**: Conversation histories, feedback
310    pub fn chat_agent() -> Self {
311        let mut template = Self::new("chat-agent", AgentType::ChatAgent);
312        template.sona_config.base_lora_rank = 8;
313        template.sona_config.pattern_clusters = 50;
314        template.sona_config.quality_threshold = 0.4;
315        template.target_latency_us = 500;  // Fast responses
316        template.training_method = TrainingMethod::RLHF {
317            reward_weight: 0.5,
318            kl_penalty: 0.1,
319        };
320        template.tags = vec!["chat".into(), "conversation".into(), "support".into()];
321        template
322    }
323
324    /// RAG agent template - optimized for retrieval-augmented generation
325    ///
326    /// **Best for**: Document Q&A, knowledge bases, search
327    /// **Config**: clusters=200, capacity=10000, high pattern storage
328    /// **Training data**: Document chunks, Q&A pairs
329    pub fn rag_agent() -> Self {
330        let mut template = Self::new("rag-agent", AgentType::RagAgent);
331        template.sona_config.pattern_clusters = 200;  // Many document patterns
332        template.sona_config.trajectory_capacity = 10000;
333        template.sona_config.embedding_dim = 512;  // Larger embeddings for retrieval
334        template.sona_config.hidden_dim = 512;
335        template.training_method = TrainingMethod::Supervised {
336            batch_size: 32,
337            epochs: 10,
338        };
339        template.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
340        template
341    }
342
343    /// Task planner template - optimized for task decomposition
344    ///
345    /// **Best for**: Project planning, task breakdown, scheduling
346    /// **Config**: baseLoraRank=16, ewcLambda=2000, multi-task
347    /// **Training data**: Task decompositions, planning examples
348    pub fn task_planner() -> Self {
349        let mut template = Self::new("task-planner", AgentType::TaskPlanner);
350        template.sona_config.base_lora_rank = 16;
351        template.sona_config.ewc_lambda = 2000.0;  // Important for multi-task
352        template.sona_config.pattern_clusters = 100;
353        template.training_method = TrainingMethod::DPO {
354            beta: 0.1,
355            ref_weight: 0.5,
356        };
357        template.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
358        template
359    }
360
361    /// Domain expert template - optimized for specialized knowledge
362    ///
363    /// **Best for**: Legal, medical, financial expertise
364    /// **Config**: qualityThreshold=0.1, high capacity, compliance
365    /// **Training data**: Domain-specific Q&A, expert responses
366    pub fn domain_expert(domain: TaskDomain) -> Self {
367        let domain_name = format!("{:?}", domain).to_lowercase();
368        let mut template = Self::new(format!("domain-expert-{}", domain_name), AgentType::DomainExpert);
369        template.sona_config.quality_threshold = 0.1;  // Learn from all domain examples
370        template.sona_config.trajectory_capacity = 20000;
371        template.sona_config.base_lora_rank = 16;
372        template.vertical = Some(VerticalConfig {
373            domain: domain.clone(),
374            vocab_boost: 10000,
375            quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
376            compliance_level: match domain {
377                TaskDomain::Healthcare => ComplianceLevel::Hipaa,
378                TaskDomain::Finance => ComplianceLevel::Soc2,
379                TaskDomain::Legal => ComplianceLevel::Basic,
380                _ => ComplianceLevel::None,
381            },
382        });
383        template.tags = vec!["domain".into(), "expert".into(), domain_name];
384        template
385    }
386
387    /// Codebase helper template - learns your specific codebase
388    ///
389    /// **Best for**: Repository-specific assistance, code navigation
390    /// **Config**: clusters=200, capacity=10000, high pattern storage
391    /// **Training data**: Your repo's code, documentation
392    pub fn codebase_helper() -> Self {
393        let mut template = Self::new("codebase-helper", AgentType::CodebaseHelper);
394        template.sona_config.pattern_clusters = 200;
395        template.sona_config.trajectory_capacity = 10000;
396        template.sona_config.quality_threshold = 0.2;
397        template.sona_config.base_lora_rank = 16;
398        template.expected_data_size = DataSizeHint::Large;
399        template.training_method = TrainingMethod::Online {
400            lr_decay: 0.999,
401            window_size: 10000,
402        };
403        template.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
404        template
405    }
406
407    /// Data analyst template - optimized for data insights
408    ///
409    /// **Best for**: Data analysis, visualization, statistics
410    /// **Config**: baseLoraRank=8, clusters=100, reasoning focus
411    pub fn data_analyst() -> Self {
412        let mut template = Self::new("data-analyst", AgentType::DataAnalyst);
413        template.sona_config.base_lora_rank = 8;
414        template.sona_config.pattern_clusters = 100;
415        template.vertical = Some(VerticalConfig {
416            domain: TaskDomain::Research,
417            vocab_boost: 5000,
418            quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
419            compliance_level: ComplianceLevel::None,
420        });
421        template.tags = vec!["data".into(), "analysis".into(), "insights".into()];
422        template
423    }
424
425    /// Creative writer template - optimized for content generation
426    ///
427    /// **Best for**: Marketing copy, blog posts, creative writing
428    /// **Config**: High diversity, quality focus
429    pub fn creative_writer() -> Self {
430        let mut template = Self::new("creative-writer", AgentType::CreativeWriter);
431        template.sona_config.base_lora_rank = 8;
432        template.sona_config.pattern_clusters = 50;  // Fewer clusters for diversity
433        template.sona_config.quality_threshold = 0.5;  // Only learn from high quality
434        template.training_method = TrainingMethod::RLHF {
435            reward_weight: 0.7,
436            kl_penalty: 0.05,  // Less constraint for creativity
437        };
438        template.vertical = Some(VerticalConfig {
439            domain: TaskDomain::Marketing,
440            vocab_boost: 0,
441            quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
442            compliance_level: ComplianceLevel::None,
443        });
444        template.tags = vec!["creative".into(), "writing".into(), "content".into()];
445        template
446    }
447
448    /// Reasoning agent template - optimized for logical reasoning
449    ///
450    /// **Best for**: Math, logic, chain-of-thought reasoning
451    /// **Config**: High rank, strong EWC, accuracy focus
452    pub fn reasoning_agent() -> Self {
453        let mut template = Self::new("reasoning-agent", AgentType::ReasoningAgent);
454        template.sona_config.base_lora_rank = 16;
455        template.sona_config.ewc_lambda = 3000.0;  // Strong protection
456        template.sona_config.pattern_clusters = 150;
457        template.sona_config.quality_threshold = 0.3;
458        template.training_method = TrainingMethod::DPO {
459            beta: 0.15,
460            ref_weight: 0.4,
461        };
462        template.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
463        template
464    }
465
466    //------------------------------------------------------------------
467    // Builder Methods
468    //------------------------------------------------------------------
469
470    /// Set SONA configuration
471    pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
472        self.sona_config = config;
473        self
474    }
475
476    /// Set training method
477    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
478        self.training_method = method;
479        self
480    }
481
482    /// Set vertical configuration
483    pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
484        self.vertical = Some(vertical);
485        self
486    }
487
488    /// Set memory budget
489    pub fn with_memory_budget(mut self, mb: usize) -> Self {
490        self.memory_budget_mb = mb;
491        self
492    }
493
494    /// Set target latency
495    pub fn with_target_latency(mut self, us: u64) -> Self {
496        self.target_latency_us = us;
497        self
498    }
499
500    /// Enable continuous learning
501    pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
502        self.continuous_learning = enabled;
503        self
504    }
505
506    /// Enable auto-export
507    pub fn with_auto_export(mut self, enabled: bool) -> Self {
508        self.auto_export = enabled;
509        self
510    }
511
512    /// Add tags
513    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
514        self.tags = tags;
515        self
516    }
517
518    /// Set hidden dimension
519    pub fn with_hidden_dim(mut self, dim: usize) -> Self {
520        self.sona_config.hidden_dim = dim;
521        self.sona_config.embedding_dim = dim;
522        self
523    }
524
525    /// Set LoRA ranks
526    pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
527        self.sona_config.micro_lora_rank = micro.min(2);  // MicroLoRA max rank is 2
528        self.sona_config.base_lora_rank = base;
529        self
530    }
531
532    //------------------------------------------------------------------
533    // Internal Methods
534    //------------------------------------------------------------------
535
536    /// Apply agent-specific optimizations
537    fn apply_agent_optimizations(&mut self) {
538        match &self.agent_type {
539            AgentType::CodeAgent | AgentType::CodebaseHelper => {
540                self.sona_config.pattern_clusters = 200;
541                self.sona_config.base_lora_rank = 16;
542            }
543            AgentType::ChatAgent => {
544                self.sona_config.pattern_clusters = 50;
545                self.target_latency_us = 500;
546            }
547            AgentType::RagAgent => {
548                self.sona_config.pattern_clusters = 200;
549                self.sona_config.trajectory_capacity = 10000;
550            }
551            AgentType::ReasoningAgent => {
552                self.sona_config.ewc_lambda = 3000.0;
553                self.sona_config.base_lora_rank = 16;
554            }
555            AgentType::DomainExpert => {
556                self.sona_config.quality_threshold = 0.1;
557            }
558            _ => {}
559        }
560    }
561
562    /// Validate template configuration
563    pub fn validate(&self) -> Result<(), String> {
564        if self.sona_config.micro_lora_rank > 2 {
565            return Err("MicroLoRA rank must be 1 or 2".into());
566        }
567        if self.sona_config.hidden_dim == 0 {
568            return Err("Hidden dimension must be > 0".into());
569        }
570        if self.memory_budget_mb < 1 {
571            return Err("Memory budget must be >= 1 MB".into());
572        }
573        Ok(())
574    }
575
576    /// Get estimated memory usage in MB
577    pub fn estimated_memory_mb(&self) -> usize {
578        let config = &self.sona_config;
579
580        // Base engine memory
581        let engine_mb = 5;
582
583        // LoRA weights: hidden_dim * rank * 2 (A and B matrices) * 4 bytes * 2 (micro + base)
584        let lora_bytes = config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
585        let lora_mb = lora_bytes / (1024 * 1024);
586
587        // Trajectory buffer: capacity * ~800 bytes per trajectory
588        let traj_mb = (config.trajectory_capacity * 800) / (1024 * 1024);
589
590        // Pattern storage: clusters * embedding_dim * 4 bytes
591        let pattern_mb = (config.pattern_clusters * config.embedding_dim * 4) / (1024 * 1024);
592
593        engine_mb + lora_mb + traj_mb + pattern_mb + 1
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// The code-agent template carries its documented rank and cluster count.
    #[test]
    fn test_template_creation() {
        let tpl = TrainingTemplate::code_agent();
        let cfg = &tpl.sona_config;

        assert_eq!(tpl.agent_type, AgentType::CodeAgent);
        assert_eq!(cfg.base_lora_rank, 16);
        assert_eq!(cfg.pattern_clusters, 200);
    }

    /// Presets apply their preset-specific settings.
    #[test]
    fn test_preset_templates() {
        let from = |preset| TrainingTemplate::from_preset(preset, AgentType::ChatAgent);

        assert!(from(TemplatePreset::Production).auto_export);
        assert_eq!(from(TemplatePreset::Edge).memory_budget_mb, 5);
    }

    /// A healthcare domain expert gets a HIPAA vertical.
    #[test]
    fn test_domain_expert() {
        let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);
        let vertical = medical.vertical.as_ref();

        assert!(vertical.is_some());
        assert!(matches!(
            vertical.map(|v| &v.compliance_level),
            Some(ComplianceLevel::Hipaa)
        ));
    }

    /// Builder methods chain and write through to the SONA configuration.
    #[test]
    fn test_builder_pattern() {
        let agent = AgentType::Custom("test".into());
        let tpl = TrainingTemplate::new("custom", agent)
            .with_hidden_dim(512)
            .with_lora_ranks(2, 16)
            .with_memory_budget(200)
            .with_continuous_learning(true);
        let cfg = &tpl.sona_config;

        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.micro_lora_rank, 2);
        assert_eq!(cfg.base_lora_rank, 16);
    }

    /// Validation accepts a stock template and rejects an oversized micro rank.
    #[test]
    fn test_validation() {
        let mut tpl = TrainingTemplate::code_agent();
        assert!(tpl.validate().is_ok());

        tpl.sona_config.micro_lora_rank = 5;
        assert!(tpl.validate().is_err());
    }

    /// The memory estimate is positive and under twice the budget.
    #[test]
    fn test_memory_estimation() {
        let tpl = TrainingTemplate::code_agent();
        let estimate = tpl.estimated_memory_mb();

        assert!(estimate > 0);
        assert!(estimate < tpl.memory_budget_mb * 2);
    }
}