1use std::collections::HashMap;
30use std::sync::Arc;
31use std::time::{Duration, Instant};
32use parking_lot::RwLock;
33use serde::{Deserialize, Serialize};
34use tokio::sync::mpsc;
35
36use super::{AgentType, ClaudeFlowAgent, ClaudeFlowTask};
37use crate::error::{Result, RuvLLMError};
38
/// Claude model tiers, ordered by increasing capability, latency, and cost.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ClaudeModel {
    /// Fastest and cheapest tier.
    Haiku,
    /// Balanced middle tier (the default, see `Default` impl).
    Sonnet,
    /// Most capable and most expensive tier.
    Opus,
}
53
54impl ClaudeModel {
55 pub fn name(&self) -> &'static str {
57 match self {
58 Self::Haiku => "haiku",
59 Self::Sonnet => "sonnet",
60 Self::Opus => "opus",
61 }
62 }
63
64 pub fn model_id(&self) -> &'static str {
66 match self {
67 Self::Haiku => "claude-3-5-haiku-20241022",
68 Self::Sonnet => "claude-sonnet-4-20250514",
69 Self::Opus => "claude-opus-4-20250514",
70 }
71 }
72
73 pub fn input_cost_per_1k(&self) -> f64 {
75 match self {
76 Self::Haiku => 0.00025,
77 Self::Sonnet => 0.003,
78 Self::Opus => 0.015,
79 }
80 }
81
82 pub fn output_cost_per_1k(&self) -> f64 {
84 match self {
85 Self::Haiku => 0.00125,
86 Self::Sonnet => 0.015,
87 Self::Opus => 0.075,
88 }
89 }
90
91 pub fn typical_ttft_ms(&self) -> u64 {
93 match self {
94 Self::Haiku => 200,
95 Self::Sonnet => 500,
96 Self::Opus => 1500,
97 }
98 }
99
100 pub fn max_context_tokens(&self) -> usize {
102 match self {
103 Self::Haiku => 200_000,
104 Self::Sonnet => 200_000,
105 Self::Opus => 200_000,
106 }
107 }
108}
109
/// Sonnet is the default: the balanced cost/latency middle tier.
impl Default for ClaudeModel {
    fn default() -> Self {
        Self::Sonnet
    }
}
115
/// Speaker role of a message, serialized lowercase for the API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// System-level instructions.
    System,
}
127
/// One unit of message content, tagged by `type` in the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// A tool invocation requested by the model.
    ToolUse {
        // Unique id correlating this call with its `ToolResult`.
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result returned for an earlier `ToolUse`.
    ToolResult {
        tool_use_id: String,
        content: String,
    },
}
146
/// A single conversation message: a role plus one or more content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: MessageRole,
    /// Ordered content blocks (text, tool calls, tool results).
    pub content: Vec<ContentBlock>,
}
155
156impl Message {
157 pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
159 Self {
160 role,
161 content: vec![ContentBlock::Text { text: text.into() }],
162 }
163 }
164
165 pub fn user(text: impl Into<String>) -> Self {
167 Self::text(MessageRole::User, text)
168 }
169
170 pub fn assistant(text: impl Into<String>) -> Self {
172 Self::text(MessageRole::Assistant, text)
173 }
174
175 pub fn estimate_tokens(&self) -> usize {
177 self.content.iter().map(|block| {
178 match block {
179 ContentBlock::Text { text } => text.len() / 4, ContentBlock::ToolUse { input, .. } => {
181 input.to_string().len() / 4 + 50 }
183 ContentBlock::ToolResult { content, .. } => content.len() / 4 + 20,
184 }
185 }).sum()
186 }
187}
188
/// Request payload for the Messages API.
#[derive(Debug, Clone, Serialize)]
pub struct ClaudeRequest {
    /// Versioned model id (see `ClaudeModel::model_id`).
    pub model: String,
    /// Conversation history, oldest first.
    pub messages: Vec<Message>,
    /// Hard cap on generated output tokens.
    pub max_tokens: usize,
    /// Optional system prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Request streamed delivery; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
208
/// Response payload returned by the Messages API.
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeResponse {
    /// Server-assigned response id.
    pub id: String,
    /// Model that actually served the request.
    pub model: String,
    /// Generated content blocks.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped (e.g. end_turn, max_tokens), if reported.
    pub stop_reason: Option<String>,
    /// Token accounting for this exchange.
    pub usage: UsageStats,
}
223
/// Token counts for a single request/response exchange.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct UsageStats {
    /// Tokens consumed by the prompt.
    pub input_tokens: usize,
    /// Tokens generated in the response.
    pub output_tokens: usize,
}
232
233impl UsageStats {
234 pub fn calculate_cost(&self, model: ClaudeModel) -> f64 {
236 let input_cost = (self.input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
237 let output_cost = (self.output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
238 input_cost + output_cost
239 }
240}
241
/// One token emitted during streaming, with timing and optional quality.
#[derive(Debug, Clone)]
pub struct StreamToken {
    /// The token's text fragment.
    pub text: String,
    /// Zero-based position within the stream.
    pub index: usize,
    /// Milliseconds since the stream started.
    pub latency_ms: u64,
    /// Optional per-token quality score, if the scorer supplied one.
    pub quality_score: Option<f32>,
}
258
/// Events published over the streaming channel, in lifecycle order.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Stream opened for a request.
    Start {
        request_id: String,
        model: ClaudeModel,
    },
    /// One generated token.
    Token(StreamToken),
    /// A content block finished assembling.
    ContentBlockComplete {
        index: usize,
        content: ContentBlock,
    },
    /// Stream finished successfully.
    Complete {
        usage: UsageStats,
        stop_reason: String,
        total_latency_ms: u64,
    },
    /// Stream failed; `is_retryable` hints whether a retry may succeed.
    Error {
        message: String,
        is_retryable: bool,
    },
}
286
/// Tracks per-token quality scores against a minimum-average threshold.
#[derive(Debug, Clone)]
pub struct QualityMonitor {
    /// Minimum acceptable average score.
    pub min_quality: f32,
    /// Tokens between threshold checks.
    pub check_interval: usize,
    // All recorded scores (running average over the full history).
    scores: Vec<f32>,
    // Tokens recorded since the last `reset_check`.
    tokens_since_check: usize,
}
299
300impl QualityMonitor {
301 pub fn new(min_quality: f32, check_interval: usize) -> Self {
303 Self {
304 min_quality,
305 check_interval,
306 scores: Vec::new(),
307 tokens_since_check: 0,
308 }
309 }
310
311 pub fn record(&mut self, score: f32) {
313 self.scores.push(score);
314 self.tokens_since_check += 1;
315 }
316
317 pub fn should_continue(&self) -> bool {
319 if self.scores.is_empty() {
320 return true;
321 }
322 let avg = self.scores.iter().sum::<f32>() / self.scores.len() as f32;
323 avg >= self.min_quality
324 }
325
326 pub fn should_check(&self) -> bool {
328 self.tokens_since_check >= self.check_interval
329 }
330
331 pub fn reset_check(&mut self) {
333 self.tokens_since_check = 0;
334 }
335
336 pub fn average_quality(&self) -> f32 {
338 if self.scores.is_empty() {
339 1.0
340 } else {
341 self.scores.iter().sum::<f32>() / self.scores.len() as f32
342 }
343 }
344}
345
/// Per-request streaming pipeline: forwards tokens over a channel while
/// accumulating text, timing, and quality statistics.
pub struct ResponseStreamer {
    /// Id of the request being streamed.
    pub request_id: String,
    /// Model serving the request.
    pub model: ClaudeModel,
    // Clock started at construction; all latencies are relative to this.
    start_time: Instant,
    // Tokens processed so far.
    token_count: usize,
    // Quality tracking (threshold 0.6, checked every 20 tokens — see `new`).
    quality_monitor: QualityMonitor,
    // Downstream consumer of stream events.
    sender: mpsc::Sender<StreamEvent>,
    // Full text assembled from all tokens so far.
    accumulated_text: String,
    // Set once `complete` has been called; further tokens are rejected.
    is_complete: bool,
}
365
366impl ResponseStreamer {
367 pub fn new(
369 request_id: String,
370 model: ClaudeModel,
371 sender: mpsc::Sender<StreamEvent>,
372 ) -> Self {
373 Self {
374 request_id: request_id.clone(),
375 model,
376 start_time: Instant::now(),
377 token_count: 0,
378 quality_monitor: QualityMonitor::new(0.6, 20),
379 sender,
380 accumulated_text: String::new(),
381 is_complete: false,
382 }
383 }
384
385 pub async fn process_token(&mut self, text: String, quality_score: Option<f32>) -> Result<()> {
387 if self.is_complete {
388 return Err(RuvLLMError::InvalidOperation("Stream already complete".to_string()));
389 }
390
391 let token = StreamToken {
392 text: text.clone(),
393 index: self.token_count,
394 latency_ms: self.start_time.elapsed().as_millis() as u64,
395 quality_score,
396 };
397
398 if let Some(score) = quality_score {
400 self.quality_monitor.record(score);
401 }
402
403 self.accumulated_text.push_str(&text);
405 self.token_count += 1;
406
407 self.sender
409 .send(StreamEvent::Token(token))
410 .await
411 .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send token: {}", e)))?;
412
413 Ok(())
414 }
415
416 pub async fn complete(&mut self, usage: UsageStats, stop_reason: String) -> Result<()> {
418 self.is_complete = true;
419
420 self.sender
421 .send(StreamEvent::Complete {
422 usage,
423 stop_reason,
424 total_latency_ms: self.start_time.elapsed().as_millis() as u64,
425 })
426 .await
427 .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send complete: {}", e)))?;
428
429 Ok(())
430 }
431
432 pub fn stats(&self) -> StreamStats {
434 let elapsed = self.start_time.elapsed();
435 StreamStats {
436 token_count: self.token_count,
437 elapsed_ms: elapsed.as_millis() as u64,
438 tokens_per_second: if elapsed.as_secs_f64() > 0.0 {
439 self.token_count as f64 / elapsed.as_secs_f64()
440 } else {
441 0.0
442 },
443 average_quality: self.quality_monitor.average_quality(),
444 is_complete: self.is_complete,
445 }
446 }
447
448 pub fn accumulated_text(&self) -> &str {
450 &self.accumulated_text
451 }
452
453 pub fn quality_acceptable(&self) -> bool {
455 self.quality_monitor.should_continue()
456 }
457}
458
/// Point-in-time streaming statistics (see `ResponseStreamer::stats`).
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Tokens processed so far.
    pub token_count: usize,
    /// Milliseconds since the stream began.
    pub elapsed_ms: u64,
    /// Throughput; 0.0 if no time has elapsed.
    pub tokens_per_second: f64,
    /// Average recorded quality (1.0 when no scores recorded).
    pub average_quality: f32,
    /// Whether the stream has been completed.
    pub is_complete: bool,
}
473
/// A bounded conversation context that self-compresses when it nears capacity.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    // Conversation history, oldest first.
    messages: Vec<Message>,
    // Optional system prompt, tracked separately from `messages`.
    system_prompt: Option<String>,
    // Token capacity of this window.
    max_tokens: usize,
    // Running token estimate for messages + system prompt.
    current_tokens: usize,
    // Fraction of `max_tokens` at which compression kicks in (0.8 by default).
    compression_threshold: f32,
}
492
493impl ContextWindow {
494 pub fn new(max_tokens: usize) -> Self {
496 Self {
497 messages: Vec::new(),
498 system_prompt: None,
499 max_tokens,
500 current_tokens: 0,
501 compression_threshold: 0.8,
502 }
503 }
504
505 pub fn set_system(&mut self, prompt: impl Into<String>) {
507 let prompt = prompt.into();
508 self.current_tokens -= self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
509 self.current_tokens += prompt.len() / 4;
510 self.system_prompt = Some(prompt);
511 }
512
513 pub fn add_message(&mut self, message: Message) {
515 let tokens = message.estimate_tokens();
516 self.current_tokens += tokens;
517 self.messages.push(message);
518
519 if self.needs_compression() {
521 self.compress();
522 }
523 }
524
525 pub fn needs_compression(&self) -> bool {
527 self.current_tokens as f32 > self.max_tokens as f32 * self.compression_threshold
528 }
529
530 pub fn utilization(&self) -> f32 {
532 self.current_tokens as f32 / self.max_tokens as f32
533 }
534
535 pub fn compress(&mut self) {
537 if self.messages.len() <= 4 {
539 return;
540 }
541
542 let target_tokens = (self.max_tokens as f32 * 0.6) as usize;
543
544 let keep_first = 1;
546 let mut keep_last = 3;
547
548 while self.current_tokens > target_tokens && keep_last > 1 {
549 let to_remove = self.messages.len() - keep_first - keep_last;
550 if to_remove > 0 {
551 let removed: Vec<_> = self.messages.drain(keep_first..keep_first + 1).collect();
553 for msg in removed {
554 self.current_tokens -= msg.estimate_tokens();
555 }
556 } else {
557 keep_last -= 1;
558 }
559 }
560 }
561
562 pub fn expand_for_task(&mut self, task_complexity: f32, model: ClaudeModel) {
564 let base_max = model.max_context_tokens();
566 let expansion_factor = 0.5 + (task_complexity * 0.5); self.max_tokens = (base_max as f32 * expansion_factor) as usize;
568 }
569
570 pub fn get_messages(&self) -> &[Message] {
572 &self.messages
573 }
574
575 pub fn get_system(&self) -> Option<&str> {
577 self.system_prompt.as_deref()
578 }
579
580 pub fn token_count(&self) -> usize {
582 self.current_tokens
583 }
584
585 pub fn remaining_capacity(&self) -> usize {
587 self.max_tokens.saturating_sub(self.current_tokens)
588 }
589
590 pub fn clear(&mut self) {
592 self.messages.clear();
593 self.current_tokens = self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
594 }
595}
596
/// Owns one `ContextWindow` per agent, created lazily on first access.
pub struct ContextManager {
    // Agent id -> that agent's context window.
    windows: HashMap<String, ContextWindow>,
    // Capacity used for newly created windows.
    default_max_tokens: usize,
}
604
605impl ContextManager {
606 pub fn new(default_max_tokens: usize) -> Self {
608 Self {
609 windows: HashMap::new(),
610 default_max_tokens,
611 }
612 }
613
614 pub fn get_window(&mut self, agent_id: &str) -> &mut ContextWindow {
616 if !self.windows.contains_key(agent_id) {
617 self.windows.insert(
618 agent_id.to_string(),
619 ContextWindow::new(self.default_max_tokens),
620 );
621 }
622 self.windows.get_mut(agent_id).unwrap()
623 }
624
625 pub fn remove_window(&mut self, agent_id: &str) {
627 self.windows.remove(agent_id);
628 }
629
630 pub fn total_tokens(&self) -> usize {
632 self.windows.values().map(|w| w.token_count()).sum()
633 }
634
635 pub fn window_count(&self) -> usize {
637 self.windows.len()
638 }
639}
640
/// Lifecycle states of a coordinated agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentState {
    /// Spawned but not yet started.
    Idle,
    /// Actively executing a task.
    Running,
    /// Waiting on something external (e.g. a dependency).
    Blocked,
    /// Finished successfully.
    Completed,
    /// Finished with an error (see `AgentContext::error`).
    Failed,
}
659
/// Bookkeeping record for one agent: identity, state, usage, and timing.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Unique agent identifier.
    pub agent_id: String,
    /// Role/specialization of the agent.
    pub agent_type: AgentType,
    /// Model this agent runs on (used for cost accounting).
    pub model: ClaudeModel,
    /// Current lifecycle state.
    pub state: AgentState,
    /// Tokens currently held in the agent's context window.
    pub context_tokens: usize,
    /// Cumulative input + output tokens across completed work.
    pub total_tokens_used: usize,
    /// Cumulative dollar cost across completed work.
    pub total_cost: f64,
    /// When `start` was called, if ever.
    pub started_at: Option<Instant>,
    /// When the agent completed or failed, if it has.
    pub completed_at: Option<Instant>,
    /// Failure message, set by `fail`.
    pub error: Option<String>,
}
684
685impl AgentContext {
686 pub fn new(agent_id: String, agent_type: AgentType, model: ClaudeModel) -> Self {
688 Self {
689 agent_id,
690 agent_type,
691 model,
692 state: AgentState::Idle,
693 context_tokens: 0,
694 total_tokens_used: 0,
695 total_cost: 0.0,
696 started_at: None,
697 completed_at: None,
698 error: None,
699 }
700 }
701
702 pub fn start(&mut self) {
704 self.state = AgentState::Running;
705 self.started_at = Some(Instant::now());
706 }
707
708 pub fn block(&mut self) {
710 self.state = AgentState::Blocked;
711 }
712
713 pub fn complete(&mut self, usage: &UsageStats) {
715 self.state = AgentState::Completed;
716 self.completed_at = Some(Instant::now());
717 self.total_tokens_used += usage.input_tokens + usage.output_tokens;
718 self.total_cost += usage.calculate_cost(self.model);
719 }
720
721 pub fn fail(&mut self, error: String) {
723 self.state = AgentState::Failed;
724 self.completed_at = Some(Instant::now());
725 self.error = Some(error);
726 }
727
728 pub fn duration(&self) -> Option<Duration> {
730 match (self.started_at, self.completed_at) {
731 (Some(start), Some(end)) => Some(end.duration_since(start)),
732 (Some(start), None) => Some(start.elapsed()),
733 _ => None,
734 }
735 }
736}
737
/// One unit of work in a workflow, with dependency ordering.
#[derive(Debug, Clone)]
pub struct WorkflowStep {
    /// Unique step identifier within the workflow.
    pub step_id: String,
    /// Kind of agent that should run this step.
    pub agent_type: AgentType,
    /// Task description handed to the agent.
    pub task: String,
    /// Step ids that must complete before this step may run.
    pub dependencies: Vec<String>,
    /// Model override for this step; coordinator default when `None`.
    pub required_model: Option<ClaudeModel>,
    /// Maximum retry attempts for this step.
    pub max_retries: u32,
}
754
/// Aggregate outcome of a workflow execution.
#[derive(Debug, Clone)]
pub struct WorkflowResult {
    /// Workflow identifier.
    pub workflow_id: String,
    /// Per-step outcomes keyed by step id.
    pub step_results: HashMap<String, StepResult>,
    /// Wall-clock time for the whole workflow.
    pub total_duration: Duration,
    /// Sum of tokens used across all steps.
    pub total_tokens: usize,
    /// Sum of dollar cost across all steps.
    pub total_cost: f64,
    /// Whether every step succeeded.
    pub success: bool,
    /// Workflow-level failure message, if any.
    pub error: Option<String>,
}
773
/// Outcome of a single workflow step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Id of the step this result belongs to.
    pub step_id: String,
    /// Id of the agent that executed it.
    pub agent_id: String,
    /// Model used for the step.
    pub model: ClaudeModel,
    /// Agent response text, if the step produced one.
    pub response: Option<String>,
    /// Wall-clock time spent on the step.
    pub duration: Duration,
    /// Tokens consumed by the step.
    pub tokens_used: usize,
    /// Dollar cost of the step.
    pub cost: f64,
    /// Whether the step succeeded.
    pub success: bool,
    /// Step-level failure message, if any.
    pub error: Option<String>,
}
796
/// Orchestrates a bounded pool of agents and runs dependency-ordered workflows.
pub struct AgentCoordinator {
    // Live agent contexts, keyed by agent id; shared behind a RwLock.
    agents: Arc<RwLock<HashMap<String, AgentContext>>>,
    // Per-agent context windows, released when an agent terminates.
    context_manager: Arc<RwLock<ContextManager>>,
    // Model assigned to agents that don't request a specific one.
    default_model: ClaudeModel,
    // Hard cap on simultaneously registered agents.
    max_concurrent: usize,
    // Lifetime count of workflows run through `execute_workflow`.
    workflows_executed: u64,
    // Lifetime dollar cost accumulated across workflows.
    total_cost: f64,
}
812
impl AgentCoordinator {
    /// New coordinator; per-agent context windows default to 100k tokens.
    pub fn new(default_model: ClaudeModel, max_concurrent: usize) -> Self {
        Self {
            agents: Arc::new(RwLock::new(HashMap::new())),
            context_manager: Arc::new(RwLock::new(ContextManager::new(100_000))),
            default_model,
            max_concurrent,
            workflows_executed: 0,
            total_cost: 0.0,
        }
    }

    /// Register a new idle agent using the coordinator's default model.
    ///
    /// Errors when the concurrency cap is reached or the id is already taken.
    pub fn spawn_agent(&self, agent_id: String, agent_type: AgentType) -> Result<()> {
        let mut agents = self.agents.write();

        // NOTE(review): `OutOfMemory` is reused here for "too many agents" —
        // confirm callers match on this variant rather than a dedicated one.
        if agents.len() >= self.max_concurrent {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Maximum concurrent agents ({}) reached",
                self.max_concurrent
            )));
        }

        if agents.contains_key(&agent_id) {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Agent {} already exists",
                agent_id
            )));
        }

        let context = AgentContext::new(agent_id.clone(), agent_type, self.default_model);
        agents.insert(agent_id, context);

        Ok(())
    }

    /// Clone of an agent's context, if it exists.
    pub fn get_agent(&self, agent_id: &str) -> Option<AgentContext> {
        self.agents.read().get(agent_id).cloned()
    }

    /// Mutate an agent's context in place under the write lock.
    ///
    /// The closure must not block: it runs while the agents lock is held.
    pub fn update_agent<F>(&self, agent_id: &str, f: F) -> Result<()>
    where
        F: FnOnce(&mut AgentContext),
    {
        let mut agents = self.agents.write();
        let agent = agents
            .get_mut(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
        f(agent);
        Ok(())
    }

    /// Remove an agent and release its context window.
    pub fn terminate_agent(&self, agent_id: &str) -> Result<()> {
        let mut agents = self.agents.write();
        agents
            .remove(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;

        // Taken after the agents lock; no caller takes these in the opposite
        // order within this file, so ordering is consistent here.
        self.context_manager.write().remove_window(agent_id);

        Ok(())
    }

    /// Number of agents currently in the `Running` state.
    pub fn active_agent_count(&self) -> usize {
        self.agents
            .read()
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count()
    }

    /// Total registered agents, in any state.
    pub fn total_agent_count(&self) -> usize {
        self.agents.read().len()
    }

    /// Run workflow steps in dependency order: on each pass, every step whose
    /// dependencies are all completed is executed; errors out if no step is
    /// ready while some remain (circular dependencies).
    ///
    /// NOTE(review): step "execution" here is simulated — a canned response
    /// with fixed token (500) and cost (0.001) figures; no model call is
    /// made and `max_retries`/`max_concurrent` are not exercised. Confirm
    /// this is an intentional placeholder.
    pub async fn execute_workflow(
        &mut self,
        workflow_id: String,
        steps: Vec<WorkflowStep>,
    ) -> Result<WorkflowResult> {
        let start_time = Instant::now();
        let mut step_results: HashMap<String, StepResult> = HashMap::new();
        let mut completed_steps: std::collections::HashSet<String> = std::collections::HashSet::new();

        let mut pending_steps: Vec<&WorkflowStep> = steps.iter().collect();

        while !pending_steps.is_empty() {
            // Steps whose dependencies are all satisfied this pass.
            let ready_steps: Vec<_> = pending_steps
                .iter()
                .filter(|step| {
                    step.dependencies
                        .iter()
                        .all(|dep| completed_steps.contains(dep))
                })
                .cloned()
                .collect();

            // No progress possible but work remains: a dependency cycle
            // (or a dependency on a step id that doesn't exist).
            if ready_steps.is_empty() && !pending_steps.is_empty() {
                return Err(RuvLLMError::InvalidOperation(
                    "Workflow has circular dependencies".to_string(),
                ));
            }

            for step in ready_steps {
                // One short-lived agent per step, namespaced by workflow id.
                let agent_id = format!("{}-{}", workflow_id, step.step_id);
                let model = step.required_model.unwrap_or(self.default_model);

                self.spawn_agent(agent_id.clone(), step.agent_type)?;
                self.update_agent(&agent_id, |a| a.start())?;

                let step_start = Instant::now();

                // Simulated result (placeholder token/cost figures).
                let result = StepResult {
                    step_id: step.step_id.clone(),
                    agent_id: agent_id.clone(),
                    model,
                    response: Some(format!("Completed: {}", step.task)),
                    duration: step_start.elapsed(),
                    tokens_used: 500,
                    cost: 0.001,
                    success: true,
                    error: None,
                };

                // Fold the simulated usage into the agent before teardown.
                self.update_agent(&agent_id, |a| {
                    let usage = UsageStats {
                        input_tokens: 250,
                        output_tokens: 250,
                    };
                    a.complete(&usage);
                })?;

                step_results.insert(step.step_id.clone(), result);
                completed_steps.insert(step.step_id.clone());

                self.terminate_agent(&agent_id)?;
            }

            pending_steps.retain(|step| !completed_steps.contains(&step.step_id));
        }

        let total_tokens: usize = step_results.values().map(|r| r.tokens_used).sum();
        let total_cost: f64 = step_results.values().map(|r| r.cost).sum();

        self.workflows_executed += 1;
        self.total_cost += total_cost;

        Ok(WorkflowResult {
            workflow_id,
            step_results,
            total_duration: start_time.elapsed(),
            total_tokens,
            total_cost,
            success: true,
            error: None,
        })
    }

    /// Aggregate per-state counts and usage across all live agents.
    /// Note: counts only agents still registered; terminated agents are gone.
    pub fn stats(&self) -> CoordinatorStats {
        let agents = self.agents.read();
        let active_count = agents
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count();
        let total_tokens: usize = agents.values().map(|a| a.total_tokens_used).sum();

        CoordinatorStats {
            total_agents: agents.len(),
            active_agents: active_count,
            blocked_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Blocked)
                .count(),
            completed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Completed)
                .count(),
            failed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Failed)
                .count(),
            workflows_executed: self.workflows_executed,
            total_tokens_used: total_tokens,
            total_cost: self.total_cost,
        }
    }
}
1018
/// Snapshot of coordinator-wide agent counts and usage.
#[derive(Debug, Clone)]
pub struct CoordinatorStats {
    /// All registered agents, in any state.
    pub total_agents: usize,
    /// Agents currently `Running`.
    pub active_agents: usize,
    /// Agents currently `Blocked`.
    pub blocked_agents: usize,
    /// Agents in the `Completed` state.
    pub completed_agents: usize,
    /// Agents in the `Failed` state.
    pub failed_agents: usize,
    /// Lifetime count of executed workflows.
    pub workflows_executed: u64,
    /// Tokens used across all currently registered agents.
    pub total_tokens_used: usize,
    /// Lifetime dollar cost across workflows.
    pub total_cost: f64,
}
1039
/// Accumulates token usage per model and converts it to dollar cost.
pub struct CostEstimator {
    // Cumulative usage keyed by model tier.
    usage_by_model: HashMap<ClaudeModel, UsageStats>,
}
1049
1050impl CostEstimator {
1051 pub fn new() -> Self {
1053 Self {
1054 usage_by_model: HashMap::new(),
1055 }
1056 }
1057
1058 pub fn estimate_request_cost(
1060 &self,
1061 model: ClaudeModel,
1062 input_tokens: usize,
1063 expected_output_tokens: usize,
1064 ) -> f64 {
1065 let input_cost = (input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
1066 let output_cost = (expected_output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
1067 input_cost + output_cost
1068 }
1069
1070 pub fn record_usage(&mut self, model: ClaudeModel, usage: &UsageStats) {
1072 let entry = self.usage_by_model.entry(model).or_insert(UsageStats::default());
1073 entry.input_tokens += usage.input_tokens;
1074 entry.output_tokens += usage.output_tokens;
1075 }
1076
1077 pub fn total_cost(&self) -> f64 {
1079 self.usage_by_model
1080 .iter()
1081 .map(|(model, usage)| usage.calculate_cost(*model))
1082 .sum()
1083 }
1084
1085 pub fn cost_breakdown(&self) -> HashMap<ClaudeModel, f64> {
1087 self.usage_by_model
1088 .iter()
1089 .map(|(model, usage)| (*model, usage.calculate_cost(*model)))
1090 .collect()
1091 }
1092
1093 pub fn usage_by_model(&self) -> &HashMap<ClaudeModel, UsageStats> {
1095 &self.usage_by_model
1096 }
1097}
1098
/// Default is an empty estimator, same as `new`.
impl Default for CostEstimator {
    fn default() -> Self {
        Self::new()
    }
}
1104
/// Sliding-window latency samples per model, bounded at `max_samples` each.
pub struct LatencyTracker {
    // Per-model sample windows, oldest first.
    samples: HashMap<ClaudeModel, Vec<LatencySample>>,
    // Cap per model; oldest samples are evicted beyond this.
    max_samples: usize,
}
1116
/// Timing measurement for one request.
#[derive(Debug, Clone)]
pub struct LatencySample {
    /// Time to first token, milliseconds.
    pub ttft_ms: u64,
    /// Total request duration, milliseconds (expected >= `ttft_ms`).
    pub total_ms: u64,
    /// Prompt tokens for the request.
    pub input_tokens: usize,
    /// Generated tokens for the request.
    pub output_tokens: usize,
    /// When the sample was recorded.
    pub timestamp: Instant,
}
1131
1132impl LatencyTracker {
1133 pub fn new(max_samples: usize) -> Self {
1135 Self {
1136 samples: HashMap::new(),
1137 max_samples,
1138 }
1139 }
1140
1141 pub fn record(&mut self, model: ClaudeModel, sample: LatencySample) {
1143 let samples = self.samples.entry(model).or_insert_with(Vec::new);
1144 samples.push(sample);
1145
1146 if samples.len() > self.max_samples {
1148 samples.remove(0);
1149 }
1150 }
1151
1152 pub fn average_ttft(&self, model: ClaudeModel) -> Option<f64> {
1154 self.samples.get(&model).map(|samples| {
1155 if samples.is_empty() {
1156 return 0.0;
1157 }
1158 let sum: u64 = samples.iter().map(|s| s.ttft_ms).sum();
1159 sum as f64 / samples.len() as f64
1160 })
1161 }
1162
1163 pub fn p95_ttft(&self, model: ClaudeModel) -> Option<u64> {
1165 self.samples.get(&model).and_then(|samples| {
1166 if samples.is_empty() {
1167 return None;
1168 }
1169 let mut ttfts: Vec<u64> = samples.iter().map(|s| s.ttft_ms).collect();
1170 ttfts.sort();
1171 let idx = (ttfts.len() as f64 * 0.95) as usize;
1172 ttfts.get(idx.min(ttfts.len() - 1)).copied()
1173 })
1174 }
1175
1176 pub fn average_tokens_per_second(&self, model: ClaudeModel) -> Option<f64> {
1178 self.samples.get(&model).map(|samples| {
1179 if samples.is_empty() {
1180 return 0.0;
1181 }
1182 let total_tokens: usize = samples.iter().map(|s| s.output_tokens).sum();
1183 let total_time_ms: u64 = samples.iter().map(|s| s.total_ms - s.ttft_ms).sum();
1184 if total_time_ms == 0 {
1185 return 0.0;
1186 }
1187 total_tokens as f64 / (total_time_ms as f64 / 1000.0)
1188 })
1189 }
1190
1191 pub fn get_stats(&self, model: ClaudeModel) -> Option<LatencyStats> {
1193 self.samples.get(&model).map(|samples| LatencyStats {
1194 sample_count: samples.len(),
1195 avg_ttft_ms: self.average_ttft(model).unwrap_or(0.0),
1196 p95_ttft_ms: self.p95_ttft(model).unwrap_or(0),
1197 avg_tokens_per_second: self.average_tokens_per_second(model).unwrap_or(0.0),
1198 })
1199 }
1200}
1201
/// Summary latency statistics for one model (see `LatencyTracker::get_stats`).
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Number of samples in the window.
    pub sample_count: usize,
    /// Mean time-to-first-token, milliseconds.
    pub avg_ttft_ms: f64,
    /// 95th-percentile time-to-first-token, milliseconds.
    pub p95_ttft_ms: u64,
    /// Mean post-TTFT throughput, tokens per second.
    pub avg_tokens_per_second: f64,
}
1214
#[cfg(test)]
mod tests {
    use super::*;

    /// Cost ordering must follow the tier ordering: Haiku < Sonnet < Opus.
    #[test]
    fn test_claude_model_costs() {
        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        let haiku_cost = usage.calculate_cost(ClaudeModel::Haiku);
        let sonnet_cost = usage.calculate_cost(ClaudeModel::Sonnet);
        let opus_cost = usage.calculate_cost(ClaudeModel::Opus);

        assert!(haiku_cost < sonnet_cost);
        assert!(sonnet_cost < opus_cost);
    }

    /// Filling a small window must never leave its token estimate over capacity.
    #[test]
    fn test_context_window_compression() {
        let mut window = ContextWindow::new(1000);

        for i in 0..20 {
            window.add_message(Message::user(format!("Message {} with some content to add tokens", i)));
        }

        assert!(window.token_count() <= 1000);
    }

    /// The ~4-chars-per-token heuristic should give a small nonzero estimate
    /// for a short message.
    #[test]
    fn test_message_token_estimation() {
        let msg = Message::user("Hello, this is a test message with some content.");
        let tokens = msg.estimate_tokens();
        assert!(tokens > 0);
        assert!(tokens < 100);
    }

    /// Scores above the 0.6 threshold keep the stream alive; scores below stop it.
    #[test]
    fn test_quality_monitor() {
        let mut monitor = QualityMonitor::new(0.6, 10);

        for _ in 0..5 {
            monitor.record(0.8);
        }
        assert!(monitor.should_continue());

        let mut bad_monitor = QualityMonitor::new(0.6, 10);
        for _ in 0..5 {
            bad_monitor.record(0.3);
        }
        assert!(!bad_monitor.should_continue());
    }

    /// Spawn / start / terminate lifecycle keeps the coordinator's counts right.
    #[test]
    fn test_agent_coordinator() {
        let coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 10);

        coordinator.spawn_agent("agent-1".to_string(), AgentType::Coder).unwrap();
        coordinator.spawn_agent("agent-2".to_string(), AgentType::Researcher).unwrap();

        assert_eq!(coordinator.total_agent_count(), 2);

        coordinator.update_agent("agent-1", |a| a.start()).unwrap();
        assert_eq!(coordinator.active_agent_count(), 1);

        coordinator.terminate_agent("agent-1").unwrap();
        assert_eq!(coordinator.total_agent_count(), 1);
    }

    /// Recording usage for two models yields a positive total and a
    /// per-model breakdown containing both.
    #[test]
    fn test_cost_estimator() {
        let mut estimator = CostEstimator::new();

        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        estimator.record_usage(ClaudeModel::Sonnet, &usage);
        estimator.record_usage(ClaudeModel::Haiku, &usage);

        let total = estimator.total_cost();
        assert!(total > 0.0);

        let breakdown = estimator.cost_breakdown();
        assert!(breakdown.contains_key(&ClaudeModel::Sonnet));
        assert!(breakdown.contains_key(&ClaudeModel::Haiku));
    }

    /// Ten samples with increasing latency: count is exact, averages sane.
    #[test]
    fn test_latency_tracker() {
        let mut tracker = LatencyTracker::new(100);

        for i in 0..10 {
            tracker.record(
                ClaudeModel::Sonnet,
                LatencySample {
                    ttft_ms: 400 + i * 10,
                    total_ms: 1000 + i * 100,
                    input_tokens: 500,
                    output_tokens: 200,
                    timestamp: Instant::now(),
                },
            );
        }

        let stats = tracker.get_stats(ClaudeModel::Sonnet).unwrap();
        assert_eq!(stats.sample_count, 10);
        assert!(stats.avg_ttft_ms > 400.0);
        assert!(stats.avg_tokens_per_second > 0.0);
    }
}