1use crate::types::SonaConfig;
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum AgentType {
11 CodeAgent,
13 ChatAgent,
15 RagAgent,
17 TaskPlanner,
19 DomainExpert,
21 CodebaseHelper,
23 DataAnalyst,
25 CreativeWriter,
27 ReasoningAgent,
29 MultiModal,
31 Custom(String),
33}
34
35impl std::fmt::Display for AgentType {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 AgentType::CodeAgent => write!(f, "code-agent"),
39 AgentType::ChatAgent => write!(f, "chat-agent"),
40 AgentType::RagAgent => write!(f, "rag-agent"),
41 AgentType::TaskPlanner => write!(f, "task-planner"),
42 AgentType::DomainExpert => write!(f, "domain-expert"),
43 AgentType::CodebaseHelper => write!(f, "codebase-helper"),
44 AgentType::DataAnalyst => write!(f, "data-analyst"),
45 AgentType::CreativeWriter => write!(f, "creative-writer"),
46 AgentType::ReasoningAgent => write!(f, "reasoning-agent"),
47 AgentType::MultiModal => write!(f, "multi-modal"),
48 AgentType::Custom(name) => write!(f, "custom-{}", name),
49 }
50 }
51}
52
53#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
55pub enum TaskDomain {
56 SoftwareDevelopment,
58 CustomerSupport,
60 Healthcare,
62 Finance,
64 Legal,
66 Education,
68 Research,
70 Marketing,
72 General,
74 Custom(String),
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
80pub enum TrainingMethod {
81 Supervised {
83 batch_size: usize,
85 epochs: usize,
87 },
88 RLHF {
90 reward_weight: f32,
92 kl_penalty: f32,
94 },
95 DPO {
97 beta: f32,
99 ref_weight: f32,
101 },
102 Online {
104 lr_decay: f32,
106 window_size: usize,
108 },
109 FewShot {
111 k_shot: usize,
113 meta_lr: f32,
115 },
116}
117
118impl Default for TrainingMethod {
119 fn default() -> Self {
120 TrainingMethod::Online {
121 lr_decay: 0.999,
122 window_size: 1000,
123 }
124 }
125}
126
127#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct VerticalConfig {
130 pub domain: TaskDomain,
132 pub vocab_boost: usize,
134 pub quality_metrics: Vec<String>,
136 pub compliance_level: ComplianceLevel,
138}
139
140#[derive(Clone, Debug, Default, Serialize, Deserialize)]
142pub enum ComplianceLevel {
143 #[default]
144 None,
145 Basic,
147 Hipaa,
149 Soc2,
151 Gdpr,
153 Custom(String),
155}
156
157#[derive(Clone, Debug, Serialize, Deserialize)]
159pub enum TemplatePreset {
160 Minimal,
162 Balanced,
164 Production,
166 MaxQuality,
168 Edge,
170 Research,
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TrainingTemplate {
177 pub name: String,
179 pub agent_type: AgentType,
181 pub sona_config: SonaConfig,
183 pub training_method: TrainingMethod,
185 pub vertical: Option<VerticalConfig>,
187 pub expected_data_size: DataSizeHint,
189 pub memory_budget_mb: usize,
191 pub target_latency_us: u64,
193 pub continuous_learning: bool,
195 pub auto_export: bool,
197 pub tags: Vec<String>,
199}
200
201#[derive(Clone, Debug, Default, Serialize, Deserialize)]
203pub enum DataSizeHint {
204 Tiny,
206 Small,
208 #[default]
210 Medium,
211 Large,
213 Massive,
215}
216
217impl TrainingTemplate {
218 pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
220 Self {
221 name: name.into(),
222 agent_type,
223 sona_config: SonaConfig::default(),
224 training_method: TrainingMethod::default(),
225 vertical: None,
226 expected_data_size: DataSizeHint::default(),
227 memory_budget_mb: 100,
228 target_latency_us: 1000,
229 continuous_learning: true,
230 auto_export: false,
231 tags: Vec::new(),
232 }
233 }
234
235 pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
237 let mut template = Self::new(format!("{:?}-{}", preset, agent_type), agent_type.clone());
238
239 match preset {
240 TemplatePreset::Minimal => {
241 template.sona_config = SonaConfig::edge_deployment();
242 template.memory_budget_mb = 10;
243 template.expected_data_size = DataSizeHint::Tiny;
244 }
245 TemplatePreset::Balanced => {
246 template.sona_config = SonaConfig::default();
247 template.memory_budget_mb = 100;
248 }
249 TemplatePreset::Production => {
250 template.sona_config = SonaConfig::max_throughput();
251 template.memory_budget_mb = 200;
252 template.auto_export = true;
253 }
254 TemplatePreset::MaxQuality => {
255 template.sona_config = SonaConfig::max_quality();
256 template.memory_budget_mb = 500;
257 template.expected_data_size = DataSizeHint::Large;
258 }
259 TemplatePreset::Edge => {
260 template.sona_config = SonaConfig::edge_deployment();
261 template.memory_budget_mb = 5;
262 template.target_latency_us = 500;
263 }
264 TemplatePreset::Research => {
265 template.sona_config = SonaConfig::max_quality();
266 template.sona_config.trajectory_capacity = 50000;
267 template.memory_budget_mb = 1000;
268 template.expected_data_size = DataSizeHint::Massive;
269 }
270 }
271
272 template.apply_agent_optimizations();
274 template
275 }
276
277 pub fn code_agent() -> Self {
287 let mut template = Self::new("code-agent", AgentType::CodeAgent);
288 template.sona_config.base_lora_rank = 16; template.sona_config.pattern_clusters = 200; template.sona_config.trajectory_capacity = 10000;
291 template.sona_config.quality_threshold = 0.2; template.training_method = TrainingMethod::Online {
293 lr_decay: 0.9995,
294 window_size: 5000,
295 };
296 template.tags = vec!["code".into(), "development".into(), "completion".into()];
297 template
298 }
299
300 pub fn chat_agent() -> Self {
306 let mut template = Self::new("chat-agent", AgentType::ChatAgent);
307 template.sona_config.base_lora_rank = 8;
308 template.sona_config.pattern_clusters = 50;
309 template.sona_config.quality_threshold = 0.4;
310 template.target_latency_us = 500; template.training_method = TrainingMethod::RLHF {
312 reward_weight: 0.5,
313 kl_penalty: 0.1,
314 };
315 template.tags = vec!["chat".into(), "conversation".into(), "support".into()];
316 template
317 }
318
319 pub fn rag_agent() -> Self {
325 let mut template = Self::new("rag-agent", AgentType::RagAgent);
326 template.sona_config.pattern_clusters = 200; template.sona_config.trajectory_capacity = 10000;
328 template.sona_config.embedding_dim = 512; template.sona_config.hidden_dim = 512;
330 template.training_method = TrainingMethod::Supervised {
331 batch_size: 32,
332 epochs: 10,
333 };
334 template.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
335 template
336 }
337
338 pub fn task_planner() -> Self {
344 let mut template = Self::new("task-planner", AgentType::TaskPlanner);
345 template.sona_config.base_lora_rank = 16;
346 template.sona_config.ewc_lambda = 2000.0; template.sona_config.pattern_clusters = 100;
348 template.training_method = TrainingMethod::DPO {
349 beta: 0.1,
350 ref_weight: 0.5,
351 };
352 template.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
353 template
354 }
355
356 pub fn domain_expert(domain: TaskDomain) -> Self {
362 let domain_name = format!("{:?}", domain).to_lowercase();
363 let mut template = Self::new(
364 format!("domain-expert-{}", domain_name),
365 AgentType::DomainExpert,
366 );
367 template.sona_config.quality_threshold = 0.1; template.sona_config.trajectory_capacity = 20000;
369 template.sona_config.base_lora_rank = 16;
370 template.vertical = Some(VerticalConfig {
371 domain: domain.clone(),
372 vocab_boost: 10000,
373 quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
374 compliance_level: match domain {
375 TaskDomain::Healthcare => ComplianceLevel::Hipaa,
376 TaskDomain::Finance => ComplianceLevel::Soc2,
377 TaskDomain::Legal => ComplianceLevel::Basic,
378 _ => ComplianceLevel::None,
379 },
380 });
381 template.tags = vec!["domain".into(), "expert".into(), domain_name];
382 template
383 }
384
385 pub fn codebase_helper() -> Self {
391 let mut template = Self::new("codebase-helper", AgentType::CodebaseHelper);
392 template.sona_config.pattern_clusters = 200;
393 template.sona_config.trajectory_capacity = 10000;
394 template.sona_config.quality_threshold = 0.2;
395 template.sona_config.base_lora_rank = 16;
396 template.expected_data_size = DataSizeHint::Large;
397 template.training_method = TrainingMethod::Online {
398 lr_decay: 0.999,
399 window_size: 10000,
400 };
401 template.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
402 template
403 }
404
405 pub fn data_analyst() -> Self {
410 let mut template = Self::new("data-analyst", AgentType::DataAnalyst);
411 template.sona_config.base_lora_rank = 8;
412 template.sona_config.pattern_clusters = 100;
413 template.vertical = Some(VerticalConfig {
414 domain: TaskDomain::Research,
415 vocab_boost: 5000,
416 quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
417 compliance_level: ComplianceLevel::None,
418 });
419 template.tags = vec!["data".into(), "analysis".into(), "insights".into()];
420 template
421 }
422
423 pub fn creative_writer() -> Self {
428 let mut template = Self::new("creative-writer", AgentType::CreativeWriter);
429 template.sona_config.base_lora_rank = 8;
430 template.sona_config.pattern_clusters = 50; template.sona_config.quality_threshold = 0.5; template.training_method = TrainingMethod::RLHF {
433 reward_weight: 0.7,
434 kl_penalty: 0.05, };
436 template.vertical = Some(VerticalConfig {
437 domain: TaskDomain::Marketing,
438 vocab_boost: 0,
439 quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
440 compliance_level: ComplianceLevel::None,
441 });
442 template.tags = vec!["creative".into(), "writing".into(), "content".into()];
443 template
444 }
445
446 pub fn reasoning_agent() -> Self {
451 let mut template = Self::new("reasoning-agent", AgentType::ReasoningAgent);
452 template.sona_config.base_lora_rank = 16;
453 template.sona_config.ewc_lambda = 3000.0; template.sona_config.pattern_clusters = 150;
455 template.sona_config.quality_threshold = 0.3;
456 template.training_method = TrainingMethod::DPO {
457 beta: 0.15,
458 ref_weight: 0.4,
459 };
460 template.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
461 template
462 }
463
464 pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
470 self.sona_config = config;
471 self
472 }
473
474 pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
476 self.training_method = method;
477 self
478 }
479
480 pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
482 self.vertical = Some(vertical);
483 self
484 }
485
486 pub fn with_memory_budget(mut self, mb: usize) -> Self {
488 self.memory_budget_mb = mb;
489 self
490 }
491
492 pub fn with_target_latency(mut self, us: u64) -> Self {
494 self.target_latency_us = us;
495 self
496 }
497
498 pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
500 self.continuous_learning = enabled;
501 self
502 }
503
504 pub fn with_auto_export(mut self, enabled: bool) -> Self {
506 self.auto_export = enabled;
507 self
508 }
509
510 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
512 self.tags = tags;
513 self
514 }
515
516 pub fn with_hidden_dim(mut self, dim: usize) -> Self {
518 self.sona_config.hidden_dim = dim;
519 self.sona_config.embedding_dim = dim;
520 self
521 }
522
523 pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
525 self.sona_config.micro_lora_rank = micro.min(2); self.sona_config.base_lora_rank = base;
527 self
528 }
529
530 fn apply_agent_optimizations(&mut self) {
536 match &self.agent_type {
537 AgentType::CodeAgent | AgentType::CodebaseHelper => {
538 self.sona_config.pattern_clusters = 200;
539 self.sona_config.base_lora_rank = 16;
540 }
541 AgentType::ChatAgent => {
542 self.sona_config.pattern_clusters = 50;
543 self.target_latency_us = 500;
544 }
545 AgentType::RagAgent => {
546 self.sona_config.pattern_clusters = 200;
547 self.sona_config.trajectory_capacity = 10000;
548 }
549 AgentType::ReasoningAgent => {
550 self.sona_config.ewc_lambda = 3000.0;
551 self.sona_config.base_lora_rank = 16;
552 }
553 AgentType::DomainExpert => {
554 self.sona_config.quality_threshold = 0.1;
555 }
556 _ => {}
557 }
558 }
559
560 pub fn validate(&self) -> Result<(), String> {
562 if self.sona_config.micro_lora_rank > 2 {
563 return Err("MicroLoRA rank must be 1 or 2".into());
564 }
565 if self.sona_config.hidden_dim == 0 {
566 return Err("Hidden dimension must be > 0".into());
567 }
568 if self.memory_budget_mb < 1 {
569 return Err("Memory budget must be >= 1 MB".into());
570 }
571 Ok(())
572 }
573
574 pub fn estimated_memory_mb(&self) -> usize {
576 let config = &self.sona_config;
577
578 let engine_mb = 5;
580
581 let lora_bytes =
583 config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
584 let lora_mb = lora_bytes / (1024 * 1024);
585
586 let traj_mb = (config.trajectory_capacity * 800) / (1024 * 1024);
588
589 let pattern_mb = (config.pattern_clusters * config.embedding_dim * 4) / (1024 * 1024);
591
592 engine_mb + lora_mb + traj_mb + pattern_mb + 1
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// The code-agent constructor sets its agent type and code-specific tuning.
    #[test]
    fn test_template_creation() {
        let code = TrainingTemplate::code_agent();
        let cfg = &code.sona_config;

        assert_eq!(code.agent_type, AgentType::CodeAgent);
        assert_eq!(cfg.base_lora_rank, 16);
        assert_eq!(cfg.pattern_clusters, 200);
    }

    /// Presets set their headline values.
    #[test]
    fn test_preset_templates() {
        let build = |preset| TrainingTemplate::from_preset(preset, AgentType::ChatAgent);

        assert!(build(TemplatePreset::Production).auto_export);
        assert_eq!(build(TemplatePreset::Edge).memory_budget_mb, 5);
    }

    /// Healthcare experts get a vertical with HIPAA compliance.
    #[test]
    fn test_domain_expert() {
        let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);
        let vertical = medical
            .vertical
            .as_ref()
            .expect("healthcare expert should carry a vertical config");

        assert!(matches!(vertical.compliance_level, ComplianceLevel::Hipaa));
    }

    /// Builder methods write through to the SONA config.
    #[test]
    fn test_builder_pattern() {
        let custom = TrainingTemplate::new("custom", AgentType::Custom("test".into()))
            .with_hidden_dim(512)
            .with_lora_ranks(2, 16)
            .with_memory_budget(200)
            .with_continuous_learning(true);
        let cfg = &custom.sona_config;

        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.micro_lora_rank, 2);
        assert_eq!(cfg.base_lora_rank, 16);
    }

    /// A valid template passes; an out-of-range MicroLoRA rank fails.
    #[test]
    fn test_validation() {
        let mut code = TrainingTemplate::code_agent();
        assert!(code.validate().is_ok());

        code.sona_config.micro_lora_rank = 5;
        assert!(code.validate().is_err());
    }

    /// The estimate is positive and within twice the declared budget.
    #[test]
    fn test_memory_estimation() {
        let code = TrainingTemplate::code_agent();
        let estimate = code.estimated_memory_mb();

        assert!(estimate > 0);
        assert!(estimate < 2 * code.memory_budget_mb);
    }
}