1use crate::types::SonaConfig;
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum AgentType {
11 CodeAgent,
13 ChatAgent,
15 RagAgent,
17 TaskPlanner,
19 DomainExpert,
21 CodebaseHelper,
23 DataAnalyst,
25 CreativeWriter,
27 ReasoningAgent,
29 MultiModal,
31 Custom(String),
33}
34
35impl std::fmt::Display for AgentType {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 AgentType::CodeAgent => write!(f, "code-agent"),
39 AgentType::ChatAgent => write!(f, "chat-agent"),
40 AgentType::RagAgent => write!(f, "rag-agent"),
41 AgentType::TaskPlanner => write!(f, "task-planner"),
42 AgentType::DomainExpert => write!(f, "domain-expert"),
43 AgentType::CodebaseHelper => write!(f, "codebase-helper"),
44 AgentType::DataAnalyst => write!(f, "data-analyst"),
45 AgentType::CreativeWriter => write!(f, "creative-writer"),
46 AgentType::ReasoningAgent => write!(f, "reasoning-agent"),
47 AgentType::MultiModal => write!(f, "multi-modal"),
48 AgentType::Custom(name) => write!(f, "custom-{}", name),
49 }
50 }
51}
52
53#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
55pub enum TaskDomain {
56 SoftwareDevelopment,
58 CustomerSupport,
60 Healthcare,
62 Finance,
64 Legal,
66 Education,
68 Research,
70 Marketing,
72 General,
74 Custom(String),
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
80pub enum TrainingMethod {
81 Supervised {
83 batch_size: usize,
85 epochs: usize,
87 },
88 RLHF {
90 reward_weight: f32,
92 kl_penalty: f32,
94 },
95 DPO {
97 beta: f32,
99 ref_weight: f32,
101 },
102 Online {
104 lr_decay: f32,
106 window_size: usize,
108 },
109 FewShot {
111 k_shot: usize,
113 meta_lr: f32,
115 },
116}
117
118impl Default for TrainingMethod {
119 fn default() -> Self {
120 TrainingMethod::Online {
121 lr_decay: 0.999,
122 window_size: 1000,
123 }
124 }
125}
126
127#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct VerticalConfig {
130 pub domain: TaskDomain,
132 pub vocab_boost: usize,
134 pub quality_metrics: Vec<String>,
136 pub compliance_level: ComplianceLevel,
138}
139
140#[derive(Clone, Debug, Default, Serialize, Deserialize)]
142pub enum ComplianceLevel {
143 #[default]
144 None,
145 Basic,
147 Hipaa,
149 Soc2,
151 Gdpr,
153 Custom(String),
155}
156
157#[derive(Clone, Debug, Serialize, Deserialize)]
159pub enum TemplatePreset {
160 Minimal,
162 Balanced,
164 Production,
166 MaxQuality,
168 Edge,
170 Research,
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TrainingTemplate {
177 pub name: String,
179 pub agent_type: AgentType,
181 pub sona_config: SonaConfig,
183 pub training_method: TrainingMethod,
185 pub vertical: Option<VerticalConfig>,
187 pub expected_data_size: DataSizeHint,
189 pub memory_budget_mb: usize,
191 pub target_latency_us: u64,
193 pub continuous_learning: bool,
195 pub auto_export: bool,
197 pub tags: Vec<String>,
199}
200
201#[derive(Clone, Debug, Serialize, Deserialize)]
203pub enum DataSizeHint {
204 Tiny,
206 Small,
208 Medium,
210 Large,
212 Massive,
214}
215
216impl Default for DataSizeHint {
217 fn default() -> Self {
218 DataSizeHint::Medium
219 }
220}
221
222impl TrainingTemplate {
223 pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
225 Self {
226 name: name.into(),
227 agent_type,
228 sona_config: SonaConfig::default(),
229 training_method: TrainingMethod::default(),
230 vertical: None,
231 expected_data_size: DataSizeHint::default(),
232 memory_budget_mb: 100,
233 target_latency_us: 1000,
234 continuous_learning: true,
235 auto_export: false,
236 tags: Vec::new(),
237 }
238 }
239
240 pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
242 let mut template = Self::new(format!("{:?}-{}", preset, agent_type), agent_type.clone());
243
244 match preset {
245 TemplatePreset::Minimal => {
246 template.sona_config = SonaConfig::edge_deployment();
247 template.memory_budget_mb = 10;
248 template.expected_data_size = DataSizeHint::Tiny;
249 }
250 TemplatePreset::Balanced => {
251 template.sona_config = SonaConfig::default();
252 template.memory_budget_mb = 100;
253 }
254 TemplatePreset::Production => {
255 template.sona_config = SonaConfig::max_throughput();
256 template.memory_budget_mb = 200;
257 template.auto_export = true;
258 }
259 TemplatePreset::MaxQuality => {
260 template.sona_config = SonaConfig::max_quality();
261 template.memory_budget_mb = 500;
262 template.expected_data_size = DataSizeHint::Large;
263 }
264 TemplatePreset::Edge => {
265 template.sona_config = SonaConfig::edge_deployment();
266 template.memory_budget_mb = 5;
267 template.target_latency_us = 500;
268 }
269 TemplatePreset::Research => {
270 template.sona_config = SonaConfig::max_quality();
271 template.sona_config.trajectory_capacity = 50000;
272 template.memory_budget_mb = 1000;
273 template.expected_data_size = DataSizeHint::Massive;
274 }
275 }
276
277 template.apply_agent_optimizations();
279 template
280 }
281
282 pub fn code_agent() -> Self {
292 let mut template = Self::new("code-agent", AgentType::CodeAgent);
293 template.sona_config.base_lora_rank = 16; template.sona_config.pattern_clusters = 200; template.sona_config.trajectory_capacity = 10000;
296 template.sona_config.quality_threshold = 0.2; template.training_method = TrainingMethod::Online {
298 lr_decay: 0.9995,
299 window_size: 5000,
300 };
301 template.tags = vec!["code".into(), "development".into(), "completion".into()];
302 template
303 }
304
305 pub fn chat_agent() -> Self {
311 let mut template = Self::new("chat-agent", AgentType::ChatAgent);
312 template.sona_config.base_lora_rank = 8;
313 template.sona_config.pattern_clusters = 50;
314 template.sona_config.quality_threshold = 0.4;
315 template.target_latency_us = 500; template.training_method = TrainingMethod::RLHF {
317 reward_weight: 0.5,
318 kl_penalty: 0.1,
319 };
320 template.tags = vec!["chat".into(), "conversation".into(), "support".into()];
321 template
322 }
323
324 pub fn rag_agent() -> Self {
330 let mut template = Self::new("rag-agent", AgentType::RagAgent);
331 template.sona_config.pattern_clusters = 200; template.sona_config.trajectory_capacity = 10000;
333 template.sona_config.embedding_dim = 512; template.sona_config.hidden_dim = 512;
335 template.training_method = TrainingMethod::Supervised {
336 batch_size: 32,
337 epochs: 10,
338 };
339 template.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
340 template
341 }
342
343 pub fn task_planner() -> Self {
349 let mut template = Self::new("task-planner", AgentType::TaskPlanner);
350 template.sona_config.base_lora_rank = 16;
351 template.sona_config.ewc_lambda = 2000.0; template.sona_config.pattern_clusters = 100;
353 template.training_method = TrainingMethod::DPO {
354 beta: 0.1,
355 ref_weight: 0.5,
356 };
357 template.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
358 template
359 }
360
361 pub fn domain_expert(domain: TaskDomain) -> Self {
367 let domain_name = format!("{:?}", domain).to_lowercase();
368 let mut template = Self::new(format!("domain-expert-{}", domain_name), AgentType::DomainExpert);
369 template.sona_config.quality_threshold = 0.1; template.sona_config.trajectory_capacity = 20000;
371 template.sona_config.base_lora_rank = 16;
372 template.vertical = Some(VerticalConfig {
373 domain: domain.clone(),
374 vocab_boost: 10000,
375 quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
376 compliance_level: match domain {
377 TaskDomain::Healthcare => ComplianceLevel::Hipaa,
378 TaskDomain::Finance => ComplianceLevel::Soc2,
379 TaskDomain::Legal => ComplianceLevel::Basic,
380 _ => ComplianceLevel::None,
381 },
382 });
383 template.tags = vec!["domain".into(), "expert".into(), domain_name];
384 template
385 }
386
387 pub fn codebase_helper() -> Self {
393 let mut template = Self::new("codebase-helper", AgentType::CodebaseHelper);
394 template.sona_config.pattern_clusters = 200;
395 template.sona_config.trajectory_capacity = 10000;
396 template.sona_config.quality_threshold = 0.2;
397 template.sona_config.base_lora_rank = 16;
398 template.expected_data_size = DataSizeHint::Large;
399 template.training_method = TrainingMethod::Online {
400 lr_decay: 0.999,
401 window_size: 10000,
402 };
403 template.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
404 template
405 }
406
407 pub fn data_analyst() -> Self {
412 let mut template = Self::new("data-analyst", AgentType::DataAnalyst);
413 template.sona_config.base_lora_rank = 8;
414 template.sona_config.pattern_clusters = 100;
415 template.vertical = Some(VerticalConfig {
416 domain: TaskDomain::Research,
417 vocab_boost: 5000,
418 quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
419 compliance_level: ComplianceLevel::None,
420 });
421 template.tags = vec!["data".into(), "analysis".into(), "insights".into()];
422 template
423 }
424
425 pub fn creative_writer() -> Self {
430 let mut template = Self::new("creative-writer", AgentType::CreativeWriter);
431 template.sona_config.base_lora_rank = 8;
432 template.sona_config.pattern_clusters = 50; template.sona_config.quality_threshold = 0.5; template.training_method = TrainingMethod::RLHF {
435 reward_weight: 0.7,
436 kl_penalty: 0.05, };
438 template.vertical = Some(VerticalConfig {
439 domain: TaskDomain::Marketing,
440 vocab_boost: 0,
441 quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
442 compliance_level: ComplianceLevel::None,
443 });
444 template.tags = vec!["creative".into(), "writing".into(), "content".into()];
445 template
446 }
447
448 pub fn reasoning_agent() -> Self {
453 let mut template = Self::new("reasoning-agent", AgentType::ReasoningAgent);
454 template.sona_config.base_lora_rank = 16;
455 template.sona_config.ewc_lambda = 3000.0; template.sona_config.pattern_clusters = 150;
457 template.sona_config.quality_threshold = 0.3;
458 template.training_method = TrainingMethod::DPO {
459 beta: 0.15,
460 ref_weight: 0.4,
461 };
462 template.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
463 template
464 }
465
466 pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
472 self.sona_config = config;
473 self
474 }
475
476 pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
478 self.training_method = method;
479 self
480 }
481
482 pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
484 self.vertical = Some(vertical);
485 self
486 }
487
488 pub fn with_memory_budget(mut self, mb: usize) -> Self {
490 self.memory_budget_mb = mb;
491 self
492 }
493
494 pub fn with_target_latency(mut self, us: u64) -> Self {
496 self.target_latency_us = us;
497 self
498 }
499
500 pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
502 self.continuous_learning = enabled;
503 self
504 }
505
506 pub fn with_auto_export(mut self, enabled: bool) -> Self {
508 self.auto_export = enabled;
509 self
510 }
511
512 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
514 self.tags = tags;
515 self
516 }
517
518 pub fn with_hidden_dim(mut self, dim: usize) -> Self {
520 self.sona_config.hidden_dim = dim;
521 self.sona_config.embedding_dim = dim;
522 self
523 }
524
525 pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
527 self.sona_config.micro_lora_rank = micro.min(2); self.sona_config.base_lora_rank = base;
529 self
530 }
531
532 fn apply_agent_optimizations(&mut self) {
538 match &self.agent_type {
539 AgentType::CodeAgent | AgentType::CodebaseHelper => {
540 self.sona_config.pattern_clusters = 200;
541 self.sona_config.base_lora_rank = 16;
542 }
543 AgentType::ChatAgent => {
544 self.sona_config.pattern_clusters = 50;
545 self.target_latency_us = 500;
546 }
547 AgentType::RagAgent => {
548 self.sona_config.pattern_clusters = 200;
549 self.sona_config.trajectory_capacity = 10000;
550 }
551 AgentType::ReasoningAgent => {
552 self.sona_config.ewc_lambda = 3000.0;
553 self.sona_config.base_lora_rank = 16;
554 }
555 AgentType::DomainExpert => {
556 self.sona_config.quality_threshold = 0.1;
557 }
558 _ => {}
559 }
560 }
561
562 pub fn validate(&self) -> Result<(), String> {
564 if self.sona_config.micro_lora_rank > 2 {
565 return Err("MicroLoRA rank must be 1 or 2".into());
566 }
567 if self.sona_config.hidden_dim == 0 {
568 return Err("Hidden dimension must be > 0".into());
569 }
570 if self.memory_budget_mb < 1 {
571 return Err("Memory budget must be >= 1 MB".into());
572 }
573 Ok(())
574 }
575
576 pub fn estimated_memory_mb(&self) -> usize {
578 let config = &self.sona_config;
579
580 let engine_mb = 5;
582
583 let lora_bytes = config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
585 let lora_mb = lora_bytes / (1024 * 1024);
586
587 let traj_mb = (config.trajectory_capacity * 800) / (1024 * 1024);
589
590 let pattern_mb = (config.pattern_clusters * config.embedding_dim * 4) / (1024 * 1024);
592
593 engine_mb + lora_mb + traj_mb + pattern_mb + 1
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// The code-agent constructor sets its agent type, rank and cluster count.
    #[test]
    fn test_template_creation() {
        let code = TrainingTemplate::code_agent();
        let cfg = &code.sona_config;

        assert_eq!(code.agent_type, AgentType::CodeAgent);
        assert_eq!(cfg.base_lora_rank, 16);
        assert_eq!(cfg.pattern_clusters, 200);
    }

    /// Presets carry their preset-specific settings.
    #[test]
    fn test_preset_templates() {
        let from = |preset| TrainingTemplate::from_preset(preset, AgentType::ChatAgent);

        assert!(from(TemplatePreset::Production).auto_export);
        assert_eq!(from(TemplatePreset::Edge).memory_budget_mb, 5);
    }

    /// A healthcare domain expert gets a HIPAA vertical.
    #[test]
    fn test_domain_expert() {
        let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);

        let vertical = medical
            .vertical
            .as_ref()
            .expect("domain expert templates carry a vertical");
        assert!(matches!(vertical.compliance_level, ComplianceLevel::Hipaa));
    }

    /// Builder methods chain and land in the expected config fields.
    #[test]
    fn test_builder_pattern() {
        let built = TrainingTemplate::new("custom", AgentType::Custom("test".into()))
            .with_hidden_dim(512)
            .with_lora_ranks(2, 16)
            .with_memory_budget(200)
            .with_continuous_learning(true);

        let cfg = &built.sona_config;
        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.micro_lora_rank, 2);
        assert_eq!(cfg.base_lora_rank, 16);
    }

    /// A stock template validates; an oversized micro rank does not.
    #[test]
    fn test_validation() {
        let valid = TrainingTemplate::code_agent();
        assert!(valid.validate().is_ok());

        let mut invalid = valid;
        invalid.sona_config.micro_lora_rank = 5;
        assert!(invalid.validate().is_err());
    }

    /// The estimate is positive and under twice the template's own budget.
    #[test]
    fn test_memory_estimation() {
        let code = TrainingTemplate::code_agent();
        let estimate = code.estimated_memory_mb();
        let ceiling = code.memory_budget_mb * 2;

        assert!(estimate > 0);
        assert!(estimate < ceiling);
    }
}