1use parking_lot::RwLock;
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33use std::time::{Duration, Instant};
34use tokio::sync::mpsc;
35
36use super::{AgentType, ClaudeFlowAgent, ClaudeFlowTask};
37use crate::error::{Result, RuvLLMError};
38
/// Claude model tier selector; maps to concrete Anthropic model ids and pricing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ClaudeModel {
    /// Fastest and cheapest tier.
    Haiku,
    /// Mid-tier model; the crate-wide default.
    Sonnet,
    /// Highest-capability, highest-cost tier.
    Opus,
}
53
54impl ClaudeModel {
55 pub fn name(&self) -> &'static str {
57 match self {
58 Self::Haiku => "haiku",
59 Self::Sonnet => "sonnet",
60 Self::Opus => "opus",
61 }
62 }
63
64 pub fn model_id(&self) -> &'static str {
66 match self {
67 Self::Haiku => "claude-3-5-haiku-20241022",
68 Self::Sonnet => "claude-sonnet-4-20250514",
69 Self::Opus => "claude-opus-4-20250514",
70 }
71 }
72
73 pub fn input_cost_per_1k(&self) -> f64 {
75 match self {
76 Self::Haiku => 0.00025,
77 Self::Sonnet => 0.003,
78 Self::Opus => 0.015,
79 }
80 }
81
82 pub fn output_cost_per_1k(&self) -> f64 {
84 match self {
85 Self::Haiku => 0.00125,
86 Self::Sonnet => 0.015,
87 Self::Opus => 0.075,
88 }
89 }
90
91 pub fn typical_ttft_ms(&self) -> u64 {
93 match self {
94 Self::Haiku => 200,
95 Self::Sonnet => 500,
96 Self::Opus => 1500,
97 }
98 }
99
100 pub fn max_context_tokens(&self) -> usize {
102 match self {
103 Self::Haiku => 200_000,
104 Self::Sonnet => 200_000,
105 Self::Opus => 200_000,
106 }
107 }
108}
109
impl Default for ClaudeModel {
    fn default() -> Self {
        // Sonnet balances capability and cost, making it the sensible default tier.
        Self::Sonnet
    }
}
115
/// Role of a conversation message; serialized in lowercase ("user", "assistant", "system").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Message authored by the end user.
    User,
    /// Message produced by the model.
    Assistant,
    /// System-level instruction message.
    System,
}
127
/// One unit of message content; serialized with a snake_case `type` tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// A tool invocation requested by the model.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result returned for a prior tool invocation.
    ToolResult {
        tool_use_id: String,
        content: String,
    },
}
146
/// A single conversation message: a role plus one or more content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: MessageRole,
    /// Ordered content blocks composing the message body.
    pub content: Vec<ContentBlock>,
}
155
156impl Message {
157 pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
159 Self {
160 role,
161 content: vec![ContentBlock::Text { text: text.into() }],
162 }
163 }
164
165 pub fn user(text: impl Into<String>) -> Self {
167 Self::text(MessageRole::User, text)
168 }
169
170 pub fn assistant(text: impl Into<String>) -> Self {
172 Self::text(MessageRole::Assistant, text)
173 }
174
175 pub fn estimate_tokens(&self) -> usize {
177 self.content
178 .iter()
179 .map(|block| {
180 match block {
181 ContentBlock::Text { text } => text.len() / 4, ContentBlock::ToolUse { input, .. } => {
183 input.to_string().len() / 4 + 50 }
185 ContentBlock::ToolResult { content, .. } => content.len() / 4 + 20,
186 }
187 })
188 .sum()
189 }
190}
191
/// Request payload for the Anthropic Messages API (serialize-only).
#[derive(Debug, Clone, Serialize)]
pub struct ClaudeRequest {
    /// Concrete model id (see [`ClaudeModel::model_id`]).
    pub model: String,
    /// Conversation history to send.
    pub messages: Vec<Message>,
    /// Upper bound on tokens the model may generate.
    pub max_tokens: usize,
    /// Optional system prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Sampling temperature; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Whether to stream the response; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
211
/// Response payload from the Anthropic Messages API (deserialize-only).
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeResponse {
    /// Server-assigned response id.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Generated content blocks.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped, if reported by the server.
    pub stop_reason: Option<String>,
    /// Token accounting for the request/response pair.
    pub usage: UsageStats,
}
226
/// Token counts for one request/response exchange.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct UsageStats {
    /// Tokens consumed by the prompt.
    pub input_tokens: usize,
    /// Tokens generated by the model.
    pub output_tokens: usize,
}
235
236impl UsageStats {
237 pub fn calculate_cost(&self, model: ClaudeModel) -> f64 {
239 let input_cost = (self.input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
240 let output_cost = (self.output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
241 input_cost + output_cost
242 }
243}
244
/// A single streamed token with timing and optional quality metadata.
#[derive(Debug, Clone)]
pub struct StreamToken {
    /// Token text fragment.
    pub text: String,
    /// Zero-based position of the token in the stream.
    pub index: usize,
    /// Milliseconds since the stream started when this token was processed.
    pub latency_ms: u64,
    /// Optional quality score attached by the producer.
    pub quality_score: Option<f32>,
}
261
/// Events emitted over a response stream's channel.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Stream opened for the given request/model.
    Start {
        request_id: String,
        model: ClaudeModel,
    },
    /// A new token arrived.
    Token(StreamToken),
    /// A full content block finished assembling.
    ContentBlockComplete { index: usize, content: ContentBlock },
    /// Stream finished successfully with final accounting.
    Complete {
        usage: UsageStats,
        stop_reason: String,
        total_latency_ms: u64,
    },
    /// Stream failed; `is_retryable` hints whether a retry may succeed.
    Error { message: String, is_retryable: bool },
}
283
/// Tracks per-token quality scores and decides whether a stream should proceed.
#[derive(Debug, Clone)]
pub struct QualityMonitor {
    /// Minimum acceptable average quality score.
    pub min_quality: f32,
    /// Number of recorded tokens between quality checks.
    pub check_interval: usize,
    /// All scores recorded so far.
    recorded: Vec<f32>,
    /// Tokens recorded since the last `reset_check`.
    since_last_check: usize,
}

impl QualityMonitor {
    /// Creates a monitor with the given quality floor and check cadence.
    pub fn new(min_quality: f32, check_interval: usize) -> Self {
        Self {
            min_quality,
            check_interval,
            recorded: Vec::new(),
            since_last_check: 0,
        }
    }

    /// Records one token's quality score.
    pub fn record(&mut self, score: f32) {
        self.recorded.push(score);
        self.since_last_check += 1;
    }

    /// True while the running average meets the floor (vacuously true with no data).
    pub fn should_continue(&self) -> bool {
        if self.recorded.is_empty() {
            true
        } else {
            let total: f32 = self.recorded.iter().sum();
            total / self.recorded.len() as f32 >= self.min_quality
        }
    }

    /// True once enough tokens have accumulated to warrant a quality check.
    pub fn should_check(&self) -> bool {
        self.since_last_check >= self.check_interval
    }

    /// Resets the per-interval token counter after a check has been performed.
    pub fn reset_check(&mut self) {
        self.since_last_check = 0;
    }

    /// Mean of all recorded scores; 1.0 when nothing has been recorded.
    pub fn average_quality(&self) -> f32 {
        match self.recorded.len() {
            0 => 1.0,
            n => self.recorded.iter().sum::<f32>() / n as f32,
        }
    }
}
342
343pub struct ResponseStreamer {
345 pub request_id: String,
347 pub model: ClaudeModel,
349 start_time: Instant,
351 token_count: usize,
353 quality_monitor: QualityMonitor,
355 sender: mpsc::Sender<StreamEvent>,
357 accumulated_text: String,
359 is_complete: bool,
361}
362
363impl ResponseStreamer {
364 pub fn new(request_id: String, model: ClaudeModel, sender: mpsc::Sender<StreamEvent>) -> Self {
366 Self {
367 request_id: request_id.clone(),
368 model,
369 start_time: Instant::now(),
370 token_count: 0,
371 quality_monitor: QualityMonitor::new(0.6, 20),
372 sender,
373 accumulated_text: String::new(),
374 is_complete: false,
375 }
376 }
377
378 pub async fn process_token(&mut self, text: String, quality_score: Option<f32>) -> Result<()> {
380 if self.is_complete {
381 return Err(RuvLLMError::InvalidOperation(
382 "Stream already complete".to_string(),
383 ));
384 }
385
386 let token = StreamToken {
387 text: text.clone(),
388 index: self.token_count,
389 latency_ms: self.start_time.elapsed().as_millis() as u64,
390 quality_score,
391 };
392
393 if let Some(score) = quality_score {
395 self.quality_monitor.record(score);
396 }
397
398 self.accumulated_text.push_str(&text);
400 self.token_count += 1;
401
402 self.sender
404 .send(StreamEvent::Token(token))
405 .await
406 .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send token: {}", e)))?;
407
408 Ok(())
409 }
410
411 pub async fn complete(&mut self, usage: UsageStats, stop_reason: String) -> Result<()> {
413 self.is_complete = true;
414
415 self.sender
416 .send(StreamEvent::Complete {
417 usage,
418 stop_reason,
419 total_latency_ms: self.start_time.elapsed().as_millis() as u64,
420 })
421 .await
422 .map_err(|e| {
423 RuvLLMError::InvalidOperation(format!("Failed to send complete: {}", e))
424 })?;
425
426 Ok(())
427 }
428
429 pub fn stats(&self) -> StreamStats {
431 let elapsed = self.start_time.elapsed();
432 StreamStats {
433 token_count: self.token_count,
434 elapsed_ms: elapsed.as_millis() as u64,
435 tokens_per_second: if elapsed.as_secs_f64() > 0.0 {
436 self.token_count as f64 / elapsed.as_secs_f64()
437 } else {
438 0.0
439 },
440 average_quality: self.quality_monitor.average_quality(),
441 is_complete: self.is_complete,
442 }
443 }
444
445 pub fn accumulated_text(&self) -> &str {
447 &self.accumulated_text
448 }
449
450 pub fn quality_acceptable(&self) -> bool {
452 self.quality_monitor.should_continue()
453 }
454}
455
/// Point-in-time snapshot of a stream's progress.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Tokens processed so far.
    pub token_count: usize,
    /// Elapsed wall-clock time in milliseconds.
    pub elapsed_ms: u64,
    /// Average throughput; 0.0 when no time has elapsed.
    pub tokens_per_second: f64,
    /// Mean quality score (1.0 when no scores were recorded).
    pub average_quality: f32,
    /// True once `complete` has been called on the streamer.
    pub is_complete: bool,
}
470
471#[derive(Debug, Clone)]
477pub struct ContextWindow {
478 messages: Vec<Message>,
480 system_prompt: Option<String>,
482 max_tokens: usize,
484 current_tokens: usize,
486 compression_threshold: f32,
488}
489
490impl ContextWindow {
491 pub fn new(max_tokens: usize) -> Self {
493 Self {
494 messages: Vec::new(),
495 system_prompt: None,
496 max_tokens,
497 current_tokens: 0,
498 compression_threshold: 0.8,
499 }
500 }
501
502 pub fn set_system(&mut self, prompt: impl Into<String>) {
504 let prompt = prompt.into();
505 self.current_tokens -= self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
506 self.current_tokens += prompt.len() / 4;
507 self.system_prompt = Some(prompt);
508 }
509
510 pub fn add_message(&mut self, message: Message) {
512 let tokens = message.estimate_tokens();
513 self.current_tokens += tokens;
514 self.messages.push(message);
515
516 if self.needs_compression() {
518 self.compress();
519 }
520 }
521
522 pub fn needs_compression(&self) -> bool {
524 self.current_tokens as f32 > self.max_tokens as f32 * self.compression_threshold
525 }
526
527 pub fn utilization(&self) -> f32 {
529 self.current_tokens as f32 / self.max_tokens as f32
530 }
531
532 pub fn compress(&mut self) {
534 if self.messages.len() <= 4 {
536 return;
537 }
538
539 let target_tokens = (self.max_tokens as f32 * 0.6) as usize;
540
541 let keep_first = 1;
543 let mut keep_last = 3;
544
545 while self.current_tokens > target_tokens && keep_last > 1 {
546 let to_remove = self.messages.len() - keep_first - keep_last;
547 if to_remove > 0 {
548 let removed: Vec<_> = self.messages.drain(keep_first..keep_first + 1).collect();
550 for msg in removed {
551 self.current_tokens -= msg.estimate_tokens();
552 }
553 } else {
554 keep_last -= 1;
555 }
556 }
557 }
558
559 pub fn expand_for_task(&mut self, task_complexity: f32, model: ClaudeModel) {
561 let base_max = model.max_context_tokens();
563 let expansion_factor = 0.5 + (task_complexity * 0.5); self.max_tokens = (base_max as f32 * expansion_factor) as usize;
565 }
566
567 pub fn get_messages(&self) -> &[Message] {
569 &self.messages
570 }
571
572 pub fn get_system(&self) -> Option<&str> {
574 self.system_prompt.as_deref()
575 }
576
577 pub fn token_count(&self) -> usize {
579 self.current_tokens
580 }
581
582 pub fn remaining_capacity(&self) -> usize {
584 self.max_tokens.saturating_sub(self.current_tokens)
585 }
586
587 pub fn clear(&mut self) {
589 self.messages.clear();
590 self.current_tokens = self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
591 }
592}
593
594pub struct ContextManager {
596 windows: HashMap<String, ContextWindow>,
598 default_max_tokens: usize,
600}
601
602impl ContextManager {
603 pub fn new(default_max_tokens: usize) -> Self {
605 Self {
606 windows: HashMap::new(),
607 default_max_tokens,
608 }
609 }
610
611 pub fn get_window(&mut self, agent_id: &str) -> &mut ContextWindow {
613 if !self.windows.contains_key(agent_id) {
614 self.windows.insert(
615 agent_id.to_string(),
616 ContextWindow::new(self.default_max_tokens),
617 );
618 }
619 self.windows.get_mut(agent_id).unwrap()
620 }
621
622 pub fn remove_window(&mut self, agent_id: &str) {
624 self.windows.remove(agent_id);
625 }
626
627 pub fn total_tokens(&self) -> usize {
629 self.windows.values().map(|w| w.token_count()).sum()
630 }
631
632 pub fn window_count(&self) -> usize {
634 self.windows.len()
635 }
636}
637
/// Lifecycle states of a coordinated agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentState {
    /// Spawned but not yet started.
    Idle,
    /// Actively executing work.
    Running,
    /// Waiting on a dependency or resource.
    Blocked,
    /// Finished successfully.
    Completed,
    /// Terminated with an error.
    Failed,
}
656
657#[derive(Debug, Clone)]
659pub struct AgentContext {
660 pub agent_id: String,
662 pub agent_type: AgentType,
664 pub model: ClaudeModel,
666 pub state: AgentState,
668 pub context_tokens: usize,
670 pub total_tokens_used: usize,
672 pub total_cost: f64,
674 pub started_at: Option<Instant>,
676 pub completed_at: Option<Instant>,
678 pub error: Option<String>,
680}
681
682impl AgentContext {
683 pub fn new(agent_id: String, agent_type: AgentType, model: ClaudeModel) -> Self {
685 Self {
686 agent_id,
687 agent_type,
688 model,
689 state: AgentState::Idle,
690 context_tokens: 0,
691 total_tokens_used: 0,
692 total_cost: 0.0,
693 started_at: None,
694 completed_at: None,
695 error: None,
696 }
697 }
698
699 pub fn start(&mut self) {
701 self.state = AgentState::Running;
702 self.started_at = Some(Instant::now());
703 }
704
705 pub fn block(&mut self) {
707 self.state = AgentState::Blocked;
708 }
709
710 pub fn complete(&mut self, usage: &UsageStats) {
712 self.state = AgentState::Completed;
713 self.completed_at = Some(Instant::now());
714 self.total_tokens_used += usage.input_tokens + usage.output_tokens;
715 self.total_cost += usage.calculate_cost(self.model);
716 }
717
718 pub fn fail(&mut self, error: String) {
720 self.state = AgentState::Failed;
721 self.completed_at = Some(Instant::now());
722 self.error = Some(error);
723 }
724
725 pub fn duration(&self) -> Option<Duration> {
727 match (self.started_at, self.completed_at) {
728 (Some(start), Some(end)) => Some(end.duration_since(start)),
729 (Some(start), None) => Some(start.elapsed()),
730 _ => None,
731 }
732 }
733}
734
/// One unit of work inside a workflow, with explicit dependencies.
#[derive(Debug, Clone)]
pub struct WorkflowStep {
    /// Unique id of the step within the workflow.
    pub step_id: String,
    /// Kind of agent that should run the step.
    pub agent_type: AgentType,
    /// Human-readable task description.
    pub task: String,
    /// Ids of steps that must complete before this one runs.
    pub dependencies: Vec<String>,
    /// Model override for this step; falls back to the coordinator default.
    pub required_model: Option<ClaudeModel>,
    /// Maximum retry attempts for the step.
    pub max_retries: u32,
}
751
/// Aggregate outcome of a workflow run.
#[derive(Debug, Clone)]
pub struct WorkflowResult {
    /// Id of the executed workflow.
    pub workflow_id: String,
    /// Per-step outcomes keyed by step id.
    pub step_results: HashMap<String, StepResult>,
    /// Wall-clock duration of the whole workflow.
    pub total_duration: Duration,
    /// Sum of tokens used across all steps.
    pub total_tokens: usize,
    /// Sum of dollar costs across all steps.
    pub total_cost: f64,
    /// True when every step succeeded.
    pub success: bool,
    /// Workflow-level error message, if any.
    pub error: Option<String>,
}

/// Outcome of a single workflow step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Id of the step this result belongs to.
    pub step_id: String,
    /// Id of the agent that executed the step.
    pub agent_id: String,
    /// Model the step ran with.
    pub model: ClaudeModel,
    /// Agent response text, when available.
    pub response: Option<String>,
    /// Wall-clock duration of the step.
    pub duration: Duration,
    /// Tokens consumed by the step.
    pub tokens_used: usize,
    /// Dollar cost of the step.
    pub cost: f64,
    /// Whether the step succeeded.
    pub success: bool,
    /// Step-level error message, if any.
    pub error: Option<String>,
}
793
/// Orchestrates agent lifecycles and dependency-ordered workflow execution.
pub struct AgentCoordinator {
    /// Live agent contexts keyed by agent id (shared, lock-protected).
    agents: Arc<RwLock<HashMap<String, AgentContext>>>,
    /// Per-agent context windows (shared, lock-protected).
    context_manager: Arc<RwLock<ContextManager>>,
    /// Model used when a step doesn't request one explicitly.
    default_model: ClaudeModel,
    /// Hard cap on simultaneously registered agents.
    max_concurrent: usize,
    /// Count of workflows run through this coordinator.
    workflows_executed: u64,
    /// Accumulated dollar cost across all workflows.
    total_cost: f64,
}
809
810impl AgentCoordinator {
811 pub fn new(default_model: ClaudeModel, max_concurrent: usize) -> Self {
813 Self {
814 agents: Arc::new(RwLock::new(HashMap::new())),
815 context_manager: Arc::new(RwLock::new(ContextManager::new(100_000))),
816 default_model,
817 max_concurrent,
818 workflows_executed: 0,
819 total_cost: 0.0,
820 }
821 }
822
823 pub fn spawn_agent(&self, agent_id: String, agent_type: AgentType) -> Result<()> {
825 let mut agents = self.agents.write();
826
827 if agents.len() >= self.max_concurrent {
828 return Err(RuvLLMError::OutOfMemory(format!(
829 "Maximum concurrent agents ({}) reached",
830 self.max_concurrent
831 )));
832 }
833
834 if agents.contains_key(&agent_id) {
835 return Err(RuvLLMError::InvalidOperation(format!(
836 "Agent {} already exists",
837 agent_id
838 )));
839 }
840
841 let context = AgentContext::new(agent_id.clone(), agent_type, self.default_model);
842 agents.insert(agent_id, context);
843
844 Ok(())
845 }
846
847 pub fn get_agent(&self, agent_id: &str) -> Option<AgentContext> {
849 self.agents.read().get(agent_id).cloned()
850 }
851
852 pub fn update_agent<F>(&self, agent_id: &str, f: F) -> Result<()>
854 where
855 F: FnOnce(&mut AgentContext),
856 {
857 let mut agents = self.agents.write();
858 let agent = agents
859 .get_mut(agent_id)
860 .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
861 f(agent);
862 Ok(())
863 }
864
865 pub fn terminate_agent(&self, agent_id: &str) -> Result<()> {
867 let mut agents = self.agents.write();
868 agents
869 .remove(agent_id)
870 .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
871
872 self.context_manager.write().remove_window(agent_id);
874
875 Ok(())
876 }
877
878 pub fn active_agent_count(&self) -> usize {
880 self.agents
881 .read()
882 .values()
883 .filter(|a| a.state == AgentState::Running)
884 .count()
885 }
886
887 pub fn total_agent_count(&self) -> usize {
889 self.agents.read().len()
890 }
891
892 pub async fn execute_workflow(
894 &mut self,
895 workflow_id: String,
896 steps: Vec<WorkflowStep>,
897 ) -> Result<WorkflowResult> {
898 let start_time = Instant::now();
899 let mut step_results: HashMap<String, StepResult> = HashMap::new();
900 let mut completed_steps: std::collections::HashSet<String> =
901 std::collections::HashSet::new();
902
903 let mut pending_steps: Vec<&WorkflowStep> = steps.iter().collect();
905
906 while !pending_steps.is_empty() {
907 let ready_steps: Vec<_> = pending_steps
909 .iter()
910 .filter(|step| {
911 step.dependencies
912 .iter()
913 .all(|dep| completed_steps.contains(dep))
914 })
915 .cloned()
916 .collect();
917
918 if ready_steps.is_empty() && !pending_steps.is_empty() {
919 return Err(RuvLLMError::InvalidOperation(
920 "Workflow has circular dependencies".to_string(),
921 ));
922 }
923
924 for step in ready_steps {
926 let agent_id = format!("{}-{}", workflow_id, step.step_id);
927 let model = step.required_model.unwrap_or(self.default_model);
928
929 self.spawn_agent(agent_id.clone(), step.agent_type)?;
931 self.update_agent(&agent_id, |a| a.start())?;
932
933 let step_start = Instant::now();
935
936 let result = StepResult {
938 step_id: step.step_id.clone(),
939 agent_id: agent_id.clone(),
940 model,
941 response: Some(format!("Completed: {}", step.task)),
942 duration: step_start.elapsed(),
943 tokens_used: 500, cost: 0.001, success: true,
946 error: None,
947 };
948
949 self.update_agent(&agent_id, |a| {
950 let usage = UsageStats {
951 input_tokens: 250,
952 output_tokens: 250,
953 };
954 a.complete(&usage);
955 })?;
956
957 step_results.insert(step.step_id.clone(), result);
958 completed_steps.insert(step.step_id.clone());
959
960 self.terminate_agent(&agent_id)?;
962 }
963
964 pending_steps.retain(|step| !completed_steps.contains(&step.step_id));
966 }
967
968 let total_tokens: usize = step_results.values().map(|r| r.tokens_used).sum();
970 let total_cost: f64 = step_results.values().map(|r| r.cost).sum();
971
972 self.workflows_executed += 1;
973 self.total_cost += total_cost;
974
975 Ok(WorkflowResult {
976 workflow_id,
977 step_results,
978 total_duration: start_time.elapsed(),
979 total_tokens,
980 total_cost,
981 success: true,
982 error: None,
983 })
984 }
985
986 pub fn stats(&self) -> CoordinatorStats {
988 let agents = self.agents.read();
989 let active_count = agents
990 .values()
991 .filter(|a| a.state == AgentState::Running)
992 .count();
993 let total_tokens: usize = agents.values().map(|a| a.total_tokens_used).sum();
994
995 CoordinatorStats {
996 total_agents: agents.len(),
997 active_agents: active_count,
998 blocked_agents: agents
999 .values()
1000 .filter(|a| a.state == AgentState::Blocked)
1001 .count(),
1002 completed_agents: agents
1003 .values()
1004 .filter(|a| a.state == AgentState::Completed)
1005 .count(),
1006 failed_agents: agents
1007 .values()
1008 .filter(|a| a.state == AgentState::Failed)
1009 .count(),
1010 workflows_executed: self.workflows_executed,
1011 total_tokens_used: total_tokens,
1012 total_cost: self.total_cost,
1013 }
1014 }
1015}
1016
/// Aggregated coordinator metrics, produced by `AgentCoordinator::stats`.
#[derive(Debug, Clone)]
pub struct CoordinatorStats {
    /// Registered agents, regardless of state.
    pub total_agents: usize,
    /// Agents in the `Running` state.
    pub active_agents: usize,
    /// Agents in the `Blocked` state.
    pub blocked_agents: usize,
    /// Agents in the `Completed` state.
    pub completed_agents: usize,
    /// Agents in the `Failed` state.
    pub failed_agents: usize,
    /// Workflows executed so far.
    pub workflows_executed: u64,
    /// Total tokens consumed across all registered agents.
    pub total_tokens_used: usize,
    /// Total dollar cost accumulated by the coordinator.
    pub total_cost: f64,
}
1037
1038pub struct CostEstimator {
1044 usage_by_model: HashMap<ClaudeModel, UsageStats>,
1046}
1047
1048impl CostEstimator {
1049 pub fn new() -> Self {
1051 Self {
1052 usage_by_model: HashMap::new(),
1053 }
1054 }
1055
1056 pub fn estimate_request_cost(
1058 &self,
1059 model: ClaudeModel,
1060 input_tokens: usize,
1061 expected_output_tokens: usize,
1062 ) -> f64 {
1063 let input_cost = (input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
1064 let output_cost = (expected_output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
1065 input_cost + output_cost
1066 }
1067
1068 pub fn record_usage(&mut self, model: ClaudeModel, usage: &UsageStats) {
1070 let entry = self
1071 .usage_by_model
1072 .entry(model)
1073 .or_insert(UsageStats::default());
1074 entry.input_tokens += usage.input_tokens;
1075 entry.output_tokens += usage.output_tokens;
1076 }
1077
1078 pub fn total_cost(&self) -> f64 {
1080 self.usage_by_model
1081 .iter()
1082 .map(|(model, usage)| usage.calculate_cost(*model))
1083 .sum()
1084 }
1085
1086 pub fn cost_breakdown(&self) -> HashMap<ClaudeModel, f64> {
1088 self.usage_by_model
1089 .iter()
1090 .map(|(model, usage)| (*model, usage.calculate_cost(*model)))
1091 .collect()
1092 }
1093
1094 pub fn usage_by_model(&self) -> &HashMap<ClaudeModel, UsageStats> {
1096 &self.usage_by_model
1097 }
1098}
1099
1100impl Default for CostEstimator {
1101 fn default() -> Self {
1102 Self::new()
1103 }
1104}
1105
/// Keeps a bounded window of latency samples per model.
pub struct LatencyTracker {
    /// Samples per model, oldest first.
    samples: HashMap<ClaudeModel, Vec<LatencySample>>,
    /// Per-model retention cap; oldest samples are evicted beyond it.
    max_samples: usize,
}

/// One observed request latency measurement.
#[derive(Debug, Clone)]
pub struct LatencySample {
    /// Time to first token, in milliseconds.
    pub ttft_ms: u64,
    /// Total request duration, in milliseconds.
    pub total_ms: u64,
    /// Prompt tokens for the request.
    pub input_tokens: usize,
    /// Generated tokens for the request.
    pub output_tokens: usize,
    /// When the sample was recorded.
    pub timestamp: Instant,
}
1132
1133impl LatencyTracker {
1134 pub fn new(max_samples: usize) -> Self {
1136 Self {
1137 samples: HashMap::new(),
1138 max_samples,
1139 }
1140 }
1141
1142 pub fn record(&mut self, model: ClaudeModel, sample: LatencySample) {
1144 let samples = self.samples.entry(model).or_insert_with(Vec::new);
1145 samples.push(sample);
1146
1147 if samples.len() > self.max_samples {
1149 samples.remove(0);
1150 }
1151 }
1152
1153 pub fn average_ttft(&self, model: ClaudeModel) -> Option<f64> {
1155 self.samples.get(&model).map(|samples| {
1156 if samples.is_empty() {
1157 return 0.0;
1158 }
1159 let sum: u64 = samples.iter().map(|s| s.ttft_ms).sum();
1160 sum as f64 / samples.len() as f64
1161 })
1162 }
1163
1164 pub fn p95_ttft(&self, model: ClaudeModel) -> Option<u64> {
1166 self.samples.get(&model).and_then(|samples| {
1167 if samples.is_empty() {
1168 return None;
1169 }
1170 let mut ttfts: Vec<u64> = samples.iter().map(|s| s.ttft_ms).collect();
1171 ttfts.sort();
1172 let idx = (ttfts.len() as f64 * 0.95) as usize;
1173 ttfts.get(idx.min(ttfts.len() - 1)).copied()
1174 })
1175 }
1176
1177 pub fn average_tokens_per_second(&self, model: ClaudeModel) -> Option<f64> {
1179 self.samples.get(&model).map(|samples| {
1180 if samples.is_empty() {
1181 return 0.0;
1182 }
1183 let total_tokens: usize = samples.iter().map(|s| s.output_tokens).sum();
1184 let total_time_ms: u64 = samples.iter().map(|s| s.total_ms - s.ttft_ms).sum();
1185 if total_time_ms == 0 {
1186 return 0.0;
1187 }
1188 total_tokens as f64 / (total_time_ms as f64 / 1000.0)
1189 })
1190 }
1191
1192 pub fn get_stats(&self, model: ClaudeModel) -> Option<LatencyStats> {
1194 self.samples.get(&model).map(|samples| LatencyStats {
1195 sample_count: samples.len(),
1196 avg_ttft_ms: self.average_ttft(model).unwrap_or(0.0),
1197 p95_ttft_ms: self.p95_ttft(model).unwrap_or(0),
1198 avg_tokens_per_second: self.average_tokens_per_second(model).unwrap_or(0.0),
1199 })
1200 }
1201}
1202
/// Aggregated latency metrics for one model.
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Number of samples the aggregates were computed over.
    pub sample_count: usize,
    /// Mean time-to-first-token, in milliseconds.
    pub avg_ttft_ms: f64,
    /// 95th-percentile time-to-first-token, in milliseconds.
    pub p95_ttft_ms: u64,
    /// Mean generation throughput (tokens/sec, post-TTFT).
    pub avg_tokens_per_second: f64,
}
1215
#[cfg(test)]
mod tests {
    use super::*;

    // Pricing must be strictly ordered Haiku < Sonnet < Opus for a fixed usage.
    #[test]
    fn test_claude_model_costs() {
        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        let haiku_cost = usage.calculate_cost(ClaudeModel::Haiku);
        let sonnet_cost = usage.calculate_cost(ClaudeModel::Sonnet);
        let opus_cost = usage.calculate_cost(ClaudeModel::Opus);

        assert!(haiku_cost < sonnet_cost);
        assert!(sonnet_cost < opus_cost);
    }

    // Adding many messages must never push the window past its budget.
    #[test]
    fn test_context_window_compression() {
        let mut window = ContextWindow::new(1000);

        for i in 0..20 {
            window.add_message(Message::user(format!(
                "Message {} with some content to add tokens",
                i
            )));
        }

        assert!(window.token_count() <= 1000);
    }

    // The ~4-chars-per-token heuristic should land in a sane range.
    #[test]
    fn test_message_token_estimation() {
        let msg = Message::user("Hello, this is a test message with some content.");
        let tokens = msg.estimate_tokens();
        assert!(tokens > 0);
        assert!(tokens < 100);
    }

    // Monitors above the quality floor continue; those below it stop.
    #[test]
    fn test_quality_monitor() {
        let mut monitor = QualityMonitor::new(0.6, 10);

        for _ in 0..5 {
            monitor.record(0.8);
        }
        assert!(monitor.should_continue());

        let mut bad_monitor = QualityMonitor::new(0.6, 10);
        for _ in 0..5 {
            bad_monitor.record(0.3);
        }
        assert!(!bad_monitor.should_continue());
    }

    // Spawn/update/terminate bookkeeping on the coordinator.
    #[test]
    fn test_agent_coordinator() {
        let coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 10);

        coordinator
            .spawn_agent("agent-1".to_string(), AgentType::Coder)
            .unwrap();
        coordinator
            .spawn_agent("agent-2".to_string(), AgentType::Researcher)
            .unwrap();

        assert_eq!(coordinator.total_agent_count(), 2);

        coordinator.update_agent("agent-1", |a| a.start()).unwrap();
        assert_eq!(coordinator.active_agent_count(), 1);

        coordinator.terminate_agent("agent-1").unwrap();
        assert_eq!(coordinator.total_agent_count(), 1);
    }

    // Recorded usage must yield a positive total and a per-model breakdown.
    #[test]
    fn test_cost_estimator() {
        let mut estimator = CostEstimator::new();

        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        estimator.record_usage(ClaudeModel::Sonnet, &usage);
        estimator.record_usage(ClaudeModel::Haiku, &usage);

        let total = estimator.total_cost();
        assert!(total > 0.0);

        let breakdown = estimator.cost_breakdown();
        assert!(breakdown.contains_key(&ClaudeModel::Sonnet));
        assert!(breakdown.contains_key(&ClaudeModel::Haiku));
    }

    // Aggregated latency stats over ten synthetic samples.
    #[test]
    fn test_latency_tracker() {
        let mut tracker = LatencyTracker::new(100);

        for i in 0..10 {
            tracker.record(
                ClaudeModel::Sonnet,
                LatencySample {
                    ttft_ms: 400 + i * 10,
                    total_ms: 1000 + i * 100,
                    input_tokens: 500,
                    output_tokens: 200,
                    timestamp: Instant::now(),
                },
            );
        }

        let stats = tracker.get_stats(ClaudeModel::Sonnet).unwrap();
        assert_eq!(stats.sample_count, 10);
        assert!(stats.avg_ttft_ms > 400.0);
        assert!(stats.avg_tokens_per_second > 0.0);
    }
}