Skip to main content

ruvector_sona/training/
templates.rs

1//! Training Templates for SONA
2//!
3//! Pre-configured training setups optimized for different use cases.
4
5use crate::types::SonaConfig;
6use serde::{Deserialize, Serialize};
7
8/// Agent specialization types
9#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum AgentType {
11    /// Code generation and assistance
12    CodeAgent,
13    /// General chat and conversation
14    ChatAgent,
15    /// Document retrieval and Q&A
16    RagAgent,
17    /// Task decomposition and planning
18    TaskPlanner,
19    /// Domain-specific expert
20    DomainExpert,
21    /// Codebase-aware assistant
22    CodebaseHelper,
23    /// Data analysis and insights
24    DataAnalyst,
25    /// Creative writing and content
26    CreativeWriter,
27    /// Reasoning and logic
28    ReasoningAgent,
29    /// Multi-modal understanding
30    MultiModal,
31    /// Custom agent type
32    Custom(String),
33}
34
35impl std::fmt::Display for AgentType {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            AgentType::CodeAgent => write!(f, "code-agent"),
39            AgentType::ChatAgent => write!(f, "chat-agent"),
40            AgentType::RagAgent => write!(f, "rag-agent"),
41            AgentType::TaskPlanner => write!(f, "task-planner"),
42            AgentType::DomainExpert => write!(f, "domain-expert"),
43            AgentType::CodebaseHelper => write!(f, "codebase-helper"),
44            AgentType::DataAnalyst => write!(f, "data-analyst"),
45            AgentType::CreativeWriter => write!(f, "creative-writer"),
46            AgentType::ReasoningAgent => write!(f, "reasoning-agent"),
47            AgentType::MultiModal => write!(f, "multi-modal"),
48            AgentType::Custom(name) => write!(f, "custom-{}", name),
49        }
50    }
51}
52
53/// Task domain for training focus
54#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
55pub enum TaskDomain {
56    /// Software development
57    SoftwareDevelopment,
58    /// Customer support
59    CustomerSupport,
60    /// Healthcare
61    Healthcare,
62    /// Finance
63    Finance,
64    /// Legal
65    Legal,
66    /// Education
67    Education,
68    /// Research
69    Research,
70    /// Marketing
71    Marketing,
72    /// General purpose
73    General,
74    /// Custom domain
75    Custom(String),
76}
77
78/// Training method configuration
79#[derive(Clone, Debug, Serialize, Deserialize)]
80pub enum TrainingMethod {
81    /// Standard supervised learning
82    Supervised {
83        /// Batch size for training
84        batch_size: usize,
85        /// Number of epochs
86        epochs: usize,
87    },
88    /// Reinforcement learning from feedback
89    RLHF {
90        /// Reward model weight
91        reward_weight: f32,
92        /// KL divergence penalty
93        kl_penalty: f32,
94    },
95    /// Direct preference optimization
96    DPO {
97        /// Beta parameter for DPO
98        beta: f32,
99        /// Reference model weight
100        ref_weight: f32,
101    },
102    /// Continuous online learning
103    Online {
104        /// Learning rate decay
105        lr_decay: f32,
106        /// Window size for recent examples
107        window_size: usize,
108    },
109    /// Few-shot adaptation
110    FewShot {
111        /// Number of examples per class
112        k_shot: usize,
113        /// Meta-learning rate
114        meta_lr: f32,
115    },
116}
117
118impl Default for TrainingMethod {
119    fn default() -> Self {
120        TrainingMethod::Online {
121            lr_decay: 0.999,
122            window_size: 1000,
123        }
124    }
125}
126
127/// Vertical-specific configuration
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct VerticalConfig {
130    /// Domain focus
131    pub domain: TaskDomain,
132    /// Specialized vocabulary size
133    pub vocab_boost: usize,
134    /// Domain-specific quality metrics
135    pub quality_metrics: Vec<String>,
136    /// Compliance requirements
137    pub compliance_level: ComplianceLevel,
138}
139
140/// Compliance level for regulated industries
141#[derive(Clone, Debug, Default, Serialize, Deserialize)]
142pub enum ComplianceLevel {
143    #[default]
144    None,
145    /// Basic audit logging
146    Basic,
147    /// HIPAA compliance
148    Hipaa,
149    /// SOC2 compliance
150    Soc2,
151    /// GDPR compliance
152    Gdpr,
153    /// Custom compliance
154    Custom(String),
155}
156
157/// Template preset for quick configuration
158#[derive(Clone, Debug, Serialize, Deserialize)]
159pub enum TemplatePreset {
160    /// Minimal configuration for testing
161    Minimal,
162    /// Balanced for general use
163    Balanced,
164    /// High performance for production
165    Production,
166    /// Maximum quality regardless of speed
167    MaxQuality,
168    /// Edge deployment (<5MB)
169    Edge,
170    /// Research and experimentation
171    Research,
172}
173
174/// Training template with full configuration
175#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TrainingTemplate {
177    /// Template name
178    pub name: String,
179    /// Agent type
180    pub agent_type: AgentType,
181    /// SONA configuration
182    pub sona_config: SonaConfig,
183    /// Training method
184    pub training_method: TrainingMethod,
185    /// Vertical configuration
186    pub vertical: Option<VerticalConfig>,
187    /// Expected training data size
188    pub expected_data_size: DataSizeHint,
189    /// Memory budget in MB
190    pub memory_budget_mb: usize,
191    /// Target latency in microseconds
192    pub target_latency_us: u64,
193    /// Enable continuous learning
194    pub continuous_learning: bool,
195    /// Auto-export trained adapters
196    pub auto_export: bool,
197    /// Tags for organization
198    pub tags: Vec<String>,
199}
200
201/// Hint about training data size
202#[derive(Clone, Debug, Default, Serialize, Deserialize)]
203pub enum DataSizeHint {
204    /// <100 examples (few-shot)
205    Tiny,
206    /// 100-1000 examples
207    Small,
208    /// 1000-10000 examples
209    #[default]
210    Medium,
211    /// 10000-100000 examples
212    Large,
213    /// >100000 examples
214    Massive,
215}
216
217impl TrainingTemplate {
218    /// Create a new training template
219    pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
220        Self {
221            name: name.into(),
222            agent_type,
223            sona_config: SonaConfig::default(),
224            training_method: TrainingMethod::default(),
225            vertical: None,
226            expected_data_size: DataSizeHint::default(),
227            memory_budget_mb: 100,
228            target_latency_us: 1000,
229            continuous_learning: true,
230            auto_export: false,
231            tags: Vec::new(),
232        }
233    }
234
235    /// Create from preset
236    pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
237        let mut template = Self::new(format!("{:?}-{}", preset, agent_type), agent_type.clone());
238
239        match preset {
240            TemplatePreset::Minimal => {
241                template.sona_config = SonaConfig::edge_deployment();
242                template.memory_budget_mb = 10;
243                template.expected_data_size = DataSizeHint::Tiny;
244            }
245            TemplatePreset::Balanced => {
246                template.sona_config = SonaConfig::default();
247                template.memory_budget_mb = 100;
248            }
249            TemplatePreset::Production => {
250                template.sona_config = SonaConfig::max_throughput();
251                template.memory_budget_mb = 200;
252                template.auto_export = true;
253            }
254            TemplatePreset::MaxQuality => {
255                template.sona_config = SonaConfig::max_quality();
256                template.memory_budget_mb = 500;
257                template.expected_data_size = DataSizeHint::Large;
258            }
259            TemplatePreset::Edge => {
260                template.sona_config = SonaConfig::edge_deployment();
261                template.memory_budget_mb = 5;
262                template.target_latency_us = 500;
263            }
264            TemplatePreset::Research => {
265                template.sona_config = SonaConfig::max_quality();
266                template.sona_config.trajectory_capacity = 50000;
267                template.memory_budget_mb = 1000;
268                template.expected_data_size = DataSizeHint::Massive;
269            }
270        }
271
272        // Apply agent-specific optimizations
273        template.apply_agent_optimizations();
274        template
275    }
276
277    //------------------------------------------------------------------
278    // Pre-built Templates for Common Use Cases
279    //------------------------------------------------------------------
280
281    /// Code agent template - optimized for code generation
282    ///
283    /// **Best for**: Code completion, bug fixes, refactoring
284    /// **Config**: baseLoraRank=16, clusters=200, capacity=10000
285    /// **Training data**: Code completions, fixes, reviews
286    pub fn code_agent() -> Self {
287        let mut template = Self::new("code-agent", AgentType::CodeAgent);
288        template.sona_config.base_lora_rank = 16; // Deeper for code patterns
289        template.sona_config.pattern_clusters = 200; // Many code patterns
290        template.sona_config.trajectory_capacity = 10000;
291        template.sona_config.quality_threshold = 0.2; // Learn from most examples
292        template.training_method = TrainingMethod::Online {
293            lr_decay: 0.9995,
294            window_size: 5000,
295        };
296        template.tags = vec!["code".into(), "development".into(), "completion".into()];
297        template
298    }
299
300    /// Chat agent template - optimized for conversational AI
301    ///
302    /// **Best for**: Customer support, general chat, assistants
303    /// **Config**: baseLoraRank=8, clusters=50, fast response
304    /// **Training data**: Conversation histories, feedback
305    pub fn chat_agent() -> Self {
306        let mut template = Self::new("chat-agent", AgentType::ChatAgent);
307        template.sona_config.base_lora_rank = 8;
308        template.sona_config.pattern_clusters = 50;
309        template.sona_config.quality_threshold = 0.4;
310        template.target_latency_us = 500; // Fast responses
311        template.training_method = TrainingMethod::RLHF {
312            reward_weight: 0.5,
313            kl_penalty: 0.1,
314        };
315        template.tags = vec!["chat".into(), "conversation".into(), "support".into()];
316        template
317    }
318
319    /// RAG agent template - optimized for retrieval-augmented generation
320    ///
321    /// **Best for**: Document Q&A, knowledge bases, search
322    /// **Config**: clusters=200, capacity=10000, high pattern storage
323    /// **Training data**: Document chunks, Q&A pairs
324    pub fn rag_agent() -> Self {
325        let mut template = Self::new("rag-agent", AgentType::RagAgent);
326        template.sona_config.pattern_clusters = 200; // Many document patterns
327        template.sona_config.trajectory_capacity = 10000;
328        template.sona_config.embedding_dim = 512; // Larger embeddings for retrieval
329        template.sona_config.hidden_dim = 512;
330        template.training_method = TrainingMethod::Supervised {
331            batch_size: 32,
332            epochs: 10,
333        };
334        template.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
335        template
336    }
337
338    /// Task planner template - optimized for task decomposition
339    ///
340    /// **Best for**: Project planning, task breakdown, scheduling
341    /// **Config**: baseLoraRank=16, ewcLambda=2000, multi-task
342    /// **Training data**: Task decompositions, planning examples
343    pub fn task_planner() -> Self {
344        let mut template = Self::new("task-planner", AgentType::TaskPlanner);
345        template.sona_config.base_lora_rank = 16;
346        template.sona_config.ewc_lambda = 2000.0; // Important for multi-task
347        template.sona_config.pattern_clusters = 100;
348        template.training_method = TrainingMethod::DPO {
349            beta: 0.1,
350            ref_weight: 0.5,
351        };
352        template.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
353        template
354    }
355
356    /// Domain expert template - optimized for specialized knowledge
357    ///
358    /// **Best for**: Legal, medical, financial expertise
359    /// **Config**: qualityThreshold=0.1, high capacity, compliance
360    /// **Training data**: Domain-specific Q&A, expert responses
361    pub fn domain_expert(domain: TaskDomain) -> Self {
362        let domain_name = format!("{:?}", domain).to_lowercase();
363        let mut template = Self::new(
364            format!("domain-expert-{}", domain_name),
365            AgentType::DomainExpert,
366        );
367        template.sona_config.quality_threshold = 0.1; // Learn from all domain examples
368        template.sona_config.trajectory_capacity = 20000;
369        template.sona_config.base_lora_rank = 16;
370        template.vertical = Some(VerticalConfig {
371            domain: domain.clone(),
372            vocab_boost: 10000,
373            quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
374            compliance_level: match domain {
375                TaskDomain::Healthcare => ComplianceLevel::Hipaa,
376                TaskDomain::Finance => ComplianceLevel::Soc2,
377                TaskDomain::Legal => ComplianceLevel::Basic,
378                _ => ComplianceLevel::None,
379            },
380        });
381        template.tags = vec!["domain".into(), "expert".into(), domain_name];
382        template
383    }
384
385    /// Codebase helper template - learns your specific codebase
386    ///
387    /// **Best for**: Repository-specific assistance, code navigation
388    /// **Config**: clusters=200, capacity=10000, high pattern storage
389    /// **Training data**: Your repo's code, documentation
390    pub fn codebase_helper() -> Self {
391        let mut template = Self::new("codebase-helper", AgentType::CodebaseHelper);
392        template.sona_config.pattern_clusters = 200;
393        template.sona_config.trajectory_capacity = 10000;
394        template.sona_config.quality_threshold = 0.2;
395        template.sona_config.base_lora_rank = 16;
396        template.expected_data_size = DataSizeHint::Large;
397        template.training_method = TrainingMethod::Online {
398            lr_decay: 0.999,
399            window_size: 10000,
400        };
401        template.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
402        template
403    }
404
405    /// Data analyst template - optimized for data insights
406    ///
407    /// **Best for**: Data analysis, visualization, statistics
408    /// **Config**: baseLoraRank=8, clusters=100, reasoning focus
409    pub fn data_analyst() -> Self {
410        let mut template = Self::new("data-analyst", AgentType::DataAnalyst);
411        template.sona_config.base_lora_rank = 8;
412        template.sona_config.pattern_clusters = 100;
413        template.vertical = Some(VerticalConfig {
414            domain: TaskDomain::Research,
415            vocab_boost: 5000,
416            quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
417            compliance_level: ComplianceLevel::None,
418        });
419        template.tags = vec!["data".into(), "analysis".into(), "insights".into()];
420        template
421    }
422
423    /// Creative writer template - optimized for content generation
424    ///
425    /// **Best for**: Marketing copy, blog posts, creative writing
426    /// **Config**: High diversity, quality focus
427    pub fn creative_writer() -> Self {
428        let mut template = Self::new("creative-writer", AgentType::CreativeWriter);
429        template.sona_config.base_lora_rank = 8;
430        template.sona_config.pattern_clusters = 50; // Fewer clusters for diversity
431        template.sona_config.quality_threshold = 0.5; // Only learn from high quality
432        template.training_method = TrainingMethod::RLHF {
433            reward_weight: 0.7,
434            kl_penalty: 0.05, // Less constraint for creativity
435        };
436        template.vertical = Some(VerticalConfig {
437            domain: TaskDomain::Marketing,
438            vocab_boost: 0,
439            quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
440            compliance_level: ComplianceLevel::None,
441        });
442        template.tags = vec!["creative".into(), "writing".into(), "content".into()];
443        template
444    }
445
446    /// Reasoning agent template - optimized for logical reasoning
447    ///
448    /// **Best for**: Math, logic, chain-of-thought reasoning
449    /// **Config**: High rank, strong EWC, accuracy focus
450    pub fn reasoning_agent() -> Self {
451        let mut template = Self::new("reasoning-agent", AgentType::ReasoningAgent);
452        template.sona_config.base_lora_rank = 16;
453        template.sona_config.ewc_lambda = 3000.0; // Strong protection
454        template.sona_config.pattern_clusters = 150;
455        template.sona_config.quality_threshold = 0.3;
456        template.training_method = TrainingMethod::DPO {
457            beta: 0.15,
458            ref_weight: 0.4,
459        };
460        template.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
461        template
462    }
463
464    //------------------------------------------------------------------
465    // Builder Methods
466    //------------------------------------------------------------------
467
468    /// Set SONA configuration
469    pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
470        self.sona_config = config;
471        self
472    }
473
474    /// Set training method
475    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
476        self.training_method = method;
477        self
478    }
479
480    /// Set vertical configuration
481    pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
482        self.vertical = Some(vertical);
483        self
484    }
485
486    /// Set memory budget
487    pub fn with_memory_budget(mut self, mb: usize) -> Self {
488        self.memory_budget_mb = mb;
489        self
490    }
491
492    /// Set target latency
493    pub fn with_target_latency(mut self, us: u64) -> Self {
494        self.target_latency_us = us;
495        self
496    }
497
498    /// Enable continuous learning
499    pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
500        self.continuous_learning = enabled;
501        self
502    }
503
504    /// Enable auto-export
505    pub fn with_auto_export(mut self, enabled: bool) -> Self {
506        self.auto_export = enabled;
507        self
508    }
509
510    /// Add tags
511    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
512        self.tags = tags;
513        self
514    }
515
516    /// Set hidden dimension
517    pub fn with_hidden_dim(mut self, dim: usize) -> Self {
518        self.sona_config.hidden_dim = dim;
519        self.sona_config.embedding_dim = dim;
520        self
521    }
522
523    /// Set LoRA ranks
524    pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
525        self.sona_config.micro_lora_rank = micro.min(2); // MicroLoRA max rank is 2
526        self.sona_config.base_lora_rank = base;
527        self
528    }
529
530    //------------------------------------------------------------------
531    // Internal Methods
532    //------------------------------------------------------------------
533
534    /// Apply agent-specific optimizations
535    fn apply_agent_optimizations(&mut self) {
536        match &self.agent_type {
537            AgentType::CodeAgent | AgentType::CodebaseHelper => {
538                self.sona_config.pattern_clusters = 200;
539                self.sona_config.base_lora_rank = 16;
540            }
541            AgentType::ChatAgent => {
542                self.sona_config.pattern_clusters = 50;
543                self.target_latency_us = 500;
544            }
545            AgentType::RagAgent => {
546                self.sona_config.pattern_clusters = 200;
547                self.sona_config.trajectory_capacity = 10000;
548            }
549            AgentType::ReasoningAgent => {
550                self.sona_config.ewc_lambda = 3000.0;
551                self.sona_config.base_lora_rank = 16;
552            }
553            AgentType::DomainExpert => {
554                self.sona_config.quality_threshold = 0.1;
555            }
556            _ => {}
557        }
558    }
559
560    /// Validate template configuration
561    pub fn validate(&self) -> Result<(), String> {
562        if self.sona_config.micro_lora_rank > 2 {
563            return Err("MicroLoRA rank must be 1 or 2".into());
564        }
565        if self.sona_config.hidden_dim == 0 {
566            return Err("Hidden dimension must be > 0".into());
567        }
568        if self.memory_budget_mb < 1 {
569            return Err("Memory budget must be >= 1 MB".into());
570        }
571        Ok(())
572    }
573
574    /// Get estimated memory usage in MB
575    pub fn estimated_memory_mb(&self) -> usize {
576        let config = &self.sona_config;
577
578        // Base engine memory
579        let engine_mb = 5;
580
581        // LoRA weights: hidden_dim * rank * 2 (A and B matrices) * 4 bytes * 2 (micro + base)
582        let lora_bytes =
583            config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
584        let lora_mb = lora_bytes / (1024 * 1024);
585
586        // Trajectory buffer: capacity * ~800 bytes per trajectory
587        let traj_mb = (config.trajectory_capacity * 800) / (1024 * 1024);
588
589        // Pattern storage: clusters * embedding_dim * 4 bytes
590        let pattern_mb = (config.pattern_clusters * config.embedding_dim * 4) / (1024 * 1024);
591
592        engine_mb + lora_mb + traj_mb + pattern_mb + 1
593    }
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// The code-agent template carries its documented rank and cluster count.
    #[test]
    fn test_template_creation() {
        let tpl = TrainingTemplate::code_agent();
        let cfg = &tpl.sona_config;
        assert_eq!(tpl.agent_type, AgentType::CodeAgent);
        assert_eq!(cfg.base_lora_rank, 16);
        assert_eq!(cfg.pattern_clusters, 200);
    }

    /// Presets set their characteristic fields.
    #[test]
    fn test_preset_templates() {
        let agent = AgentType::ChatAgent;

        let production = TrainingTemplate::from_preset(TemplatePreset::Production, agent.clone());
        assert!(production.auto_export);

        let edge = TrainingTemplate::from_preset(TemplatePreset::Edge, agent);
        assert_eq!(edge.memory_budget_mb, 5);
    }

    /// A healthcare expert gets a vertical with HIPAA compliance.
    #[test]
    fn test_domain_expert() {
        let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);
        match &medical.vertical {
            Some(vertical) => {
                assert!(matches!(vertical.compliance_level, ComplianceLevel::Hipaa));
            }
            None => panic!("domain expert template should carry a vertical"),
        }
    }

    /// Builder calls chain and land in the SONA configuration.
    #[test]
    fn test_builder_pattern() {
        let base = TrainingTemplate::new("custom", AgentType::Custom("test".into()));
        let tpl = base
            .with_hidden_dim(512)
            .with_lora_ranks(2, 16)
            .with_memory_budget(200)
            .with_continuous_learning(true);

        let cfg = &tpl.sona_config;
        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.micro_lora_rank, 2);
        assert_eq!(cfg.base_lora_rank, 16);
    }

    /// Validation accepts the stock template and rejects an oversized
    /// MicroLoRA rank.
    #[test]
    fn test_validation() {
        let mut tpl = TrainingTemplate::code_agent();
        assert!(tpl.validate().is_ok());

        tpl.sona_config.micro_lora_rank = 5;
        assert!(tpl.validate().is_err());
    }

    /// The estimate is positive and below twice the memory budget.
    #[test]
    fn test_memory_estimation() {
        let tpl = TrainingTemplate::code_agent();
        let estimate = tpl.estimated_memory_mb();
        assert!(estimate > 0);
        assert!(estimate < tpl.memory_budget_mb * 2);
    }
}
656}