ruvllm/claude_flow/
claude_integration.rs

//! Claude API Integration for Agent Communication
//!
//! Provides full Claude API compatibility for multi-agent coordination,
//! including streaming response handling, context window management,
//! and workflow orchestration.
//!
//! ## Key Features
//!
//! - **Full Claude API Compatibility**: Messages, streaming, tool use
//! - **Streaming Response Handling**: Real-time token generation with quality monitoring
//! - **Context Window Management**: Dynamic compression/expansion based on task complexity
//! - **Multi-Agent Coordination**: Workflow orchestration with dependency resolution
//!
//! ## Architecture
//!
//! ```text
//! +-------------------+     +-------------------+
//! | AgentCoordinator  |---->| ClaudeClient      |
//! | (workflow mgmt)   |     | (API interface)   |
//! +--------+----------+     +--------+----------+
//!          |                         |
//!          v                         v
//! +--------+----------+     +--------+----------+
//! | ResponseStreamer  |<----| ContextManager    |
//! | (token handling)  |     | (window mgmt)     |
//! +-------------------+     +-------------------+
//! ```
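//!
//! ## Example
//!
//! A minimal sketch of running a two-step workflow through the coordinator.
//! Step execution below is the mock built into `execute_workflow`; real usage
//! would go through the Claude API. Crate paths are assumed from this
//! module's location.
//!
//! ```ignore
//! use ruvllm::claude_flow::AgentType;
//! use ruvllm::claude_flow::claude_integration::{AgentCoordinator, ClaudeModel, WorkflowStep};
//!
//! async fn demo() -> ruvllm::error::Result<()> {
//!     let mut coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 8);
//!     let steps = vec![
//!         WorkflowStep {
//!             step_id: "research".into(),
//!             agent_type: AgentType::Researcher,
//!             task: "Survey prior art".into(),
//!             dependencies: vec![],
//!             required_model: Some(ClaudeModel::Haiku),
//!             max_retries: 2,
//!         },
//!         WorkflowStep {
//!             step_id: "code".into(),
//!             agent_type: AgentType::Coder,
//!             task: "Implement the fix".into(),
//!             dependencies: vec!["research".into()],
//!             required_model: None,
//!             max_retries: 2,
//!         },
//!     ];
//!     let result = coordinator.execute_workflow("wf-1".into(), steps).await?;
//!     assert!(result.success);
//!     Ok(())
//! }
//! ```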

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use super::AgentType;
use crate::error::{Result, RuvLLMError};

// ============================================================================
// Claude API Types
// ============================================================================

/// Claude model variants for intelligent routing
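///
/// # Example
///
/// A small sketch of inspecting the routing metadata:
///
/// ```ignore
/// let model = ClaudeModel::default(); // Sonnet
/// assert_eq!(model.model_id(), "claude-sonnet-4-20250514");
/// // Haiku is the cheapest and fastest tier:
/// assert!(ClaudeModel::Haiku.input_cost_per_1k() < model.input_cost_per_1k());
/// assert!(ClaudeModel::Haiku.typical_ttft_ms() < ClaudeModel::Opus.typical_ttft_ms());
/// ```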
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ClaudeModel {
    /// Fast, cost-effective for simple tasks
    Haiku,
    /// Balanced performance and capability
    Sonnet,
    /// Most capable for complex reasoning
    Opus,
}

impl ClaudeModel {
    /// Get short name for the model
    pub fn name(&self) -> &'static str {
        match self {
            Self::Haiku => "haiku",
            Self::Sonnet => "sonnet",
            Self::Opus => "opus",
        }
    }

    /// Get model identifier string
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::Haiku => "claude-3-5-haiku-20241022",
            Self::Sonnet => "claude-sonnet-4-20250514",
            Self::Opus => "claude-opus-4-20250514",
        }
    }

    /// Get cost per 1K input tokens (USD)
    pub fn input_cost_per_1k(&self) -> f64 {
        match self {
            Self::Haiku => 0.0008,
            Self::Sonnet => 0.003,
            Self::Opus => 0.015,
        }
    }

    /// Get cost per 1K output tokens (USD)
    pub fn output_cost_per_1k(&self) -> f64 {
        match self {
            Self::Haiku => 0.004,
            Self::Sonnet => 0.015,
            Self::Opus => 0.075,
        }
    }

    /// Get typical latency for first token (ms)
    pub fn typical_ttft_ms(&self) -> u64 {
        match self {
            Self::Haiku => 200,
            Self::Sonnet => 500,
            Self::Opus => 1500,
        }
    }

    /// Get maximum context window size
    pub fn max_context_tokens(&self) -> usize {
        match self {
            Self::Haiku => 200_000,
            Self::Sonnet => 200_000,
            Self::Opus => 200_000,
        }
    }
}

impl Default for ClaudeModel {
    fn default() -> Self {
        Self::Sonnet
    }
}

/// Message role in conversation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// User message
    User,
    /// Assistant response
    Assistant,
    /// System instructions
    System,
}

/// Content block types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Text content
    Text { text: String },
    /// Tool use request
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Tool result
    ToolResult {
        tool_use_id: String,
        content: String,
    },
}

/// Message in conversation
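///
/// # Example
///
/// A minimal sketch of building messages and using the char-based token
/// heuristic:
///
/// ```ignore
/// let msg = Message::user("Summarize the design doc");
/// assert_eq!(msg.role, MessageRole::User);
/// assert!(msg.estimate_tokens() > 0); // ~4 chars per token
/// ```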
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Message role
    pub role: MessageRole,
    /// Message content blocks
    pub content: Vec<ContentBlock>,
}

impl Message {
    /// Create a simple text message
    pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
        Self {
            role,
            content: vec![ContentBlock::Text { text: text.into() }],
        }
    }

    /// Create a user message
    pub fn user(text: impl Into<String>) -> Self {
        Self::text(MessageRole::User, text)
    }

    /// Create an assistant message
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::text(MessageRole::Assistant, text)
    }

    /// Estimate token count for this message
    pub fn estimate_tokens(&self) -> usize {
        self.content.iter().map(|block| {
            match block {
                ContentBlock::Text { text } => text.len() / 4, // ~4 chars per token
                ContentBlock::ToolUse { input, .. } => {
                    input.to_string().len() / 4 + 50 // overhead for tool structure
                }
                ContentBlock::ToolResult { content, .. } => content.len() / 4 + 20,
            }
        }).sum()
    }
}

/// Request to Claude API
#[derive(Debug, Clone, Serialize)]
pub struct ClaudeRequest {
    /// Model to use
    pub model: String,
    /// Conversation messages
    pub messages: Vec<Message>,
    /// Maximum tokens to generate
    pub max_tokens: usize,
    /// System prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Temperature for sampling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Enable streaming
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

/// Response from Claude API
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeResponse {
    /// Response ID
    pub id: String,
    /// Model used
    pub model: String,
    /// Content blocks
    pub content: Vec<ContentBlock>,
    /// Stop reason
    pub stop_reason: Option<String>,
    /// Usage statistics
    pub usage: UsageStats,
}

/// Token usage statistics
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct UsageStats {
    /// Input tokens used
    pub input_tokens: usize,
    /// Output tokens generated
    pub output_tokens: usize,
}

impl UsageStats {
    /// Calculate cost for given model
    pub fn calculate_cost(&self, model: ClaudeModel) -> f64 {
        let input_cost = (self.input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
        let output_cost = (self.output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
        input_cost + output_cost
    }
}

// ============================================================================
// Streaming Types
// ============================================================================

/// Streaming token with metadata
#[derive(Debug, Clone)]
pub struct StreamToken {
    /// Token text
    pub text: String,
    /// Token index in sequence
    pub index: usize,
    /// Cumulative latency from stream start
    pub latency_ms: u64,
    /// Quality score (0.0 - 1.0) if available
    pub quality_score: Option<f32>,
}

/// Stream event types
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Stream started
    Start {
        request_id: String,
        model: ClaudeModel,
    },
    /// Token generated
    Token(StreamToken),
    /// Content block completed
    ContentBlockComplete {
        index: usize,
        content: ContentBlock,
    },
    /// Stream completed
    Complete {
        usage: UsageStats,
        stop_reason: String,
        total_latency_ms: u64,
    },
    /// Error occurred
    Error {
        message: String,
        is_retryable: bool,
    },
}

/// Quality monitoring for streaming responses
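///
/// # Example
///
/// A sketch of the intended streaming loop: record a score per token and
/// periodically decide whether to keep going:
///
/// ```ignore
/// let mut monitor = QualityMonitor::new(0.6, 20);
/// monitor.record(0.9);
/// if monitor.should_check() {
///     if !monitor.should_continue() {
///         // abort or re-route the stream
///     }
///     monitor.reset_check();
/// }
/// ```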
#[derive(Debug, Clone)]
pub struct QualityMonitor {
    /// Minimum acceptable quality score
    pub min_quality: f32,
    /// Check interval (tokens)
    pub check_interval: usize,
    /// Accumulated quality scores
    scores: Vec<f32>,
    /// Tokens since last check
    tokens_since_check: usize,
}

impl QualityMonitor {
    /// Create new quality monitor
    pub fn new(min_quality: f32, check_interval: usize) -> Self {
        Self {
            min_quality,
            check_interval,
            scores: Vec::new(),
            tokens_since_check: 0,
        }
    }

    /// Record a quality observation
    pub fn record(&mut self, score: f32) {
        self.scores.push(score);
        self.tokens_since_check += 1;
    }

    /// Check if quality is acceptable
    pub fn should_continue(&self) -> bool {
        if self.scores.is_empty() {
            return true;
        }
        let avg = self.scores.iter().sum::<f32>() / self.scores.len() as f32;
        avg >= self.min_quality
    }

    /// Check if it's time to evaluate quality
    pub fn should_check(&self) -> bool {
        self.tokens_since_check >= self.check_interval
    }

    /// Reset check counter
    pub fn reset_check(&mut self) {
        self.tokens_since_check = 0;
    }

    /// Get average quality score
    pub fn average_quality(&self) -> f32 {
        if self.scores.is_empty() {
            1.0
        } else {
            self.scores.iter().sum::<f32>() / self.scores.len() as f32
        }
    }
}

/// Response streamer for real-time token handling
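///
/// # Example
///
/// A sketch of the intended wiring: tokens are pushed in from the transport
/// layer and re-emitted as [`StreamEvent`]s on an mpsc channel:
///
/// ```ignore
/// async fn demo() -> ruvllm::error::Result<()> {
///     let (tx, mut rx) = tokio::sync::mpsc::channel(64);
///     let mut streamer = ResponseStreamer::new("req-1".into(), ClaudeModel::Sonnet, tx);
///     streamer.process_token("Hello".into(), Some(0.9)).await?;
///     let usage = UsageStats { input_tokens: 12, output_tokens: 1 };
///     streamer.complete(usage, "end_turn".into()).await?;
///     drop(streamer); // close the channel so the receiver loop ends
///     while let Some(event) = rx.recv().await {
///         // forward events to the caller
///     }
///     Ok(())
/// }
/// ```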
pub struct ResponseStreamer {
    /// Request ID
    pub request_id: String,
    /// Model being used
    pub model: ClaudeModel,
    /// Stream start time
    start_time: Instant,
    /// Token count
    token_count: usize,
    /// Quality monitor
    quality_monitor: QualityMonitor,
    /// Event sender
    sender: mpsc::Sender<StreamEvent>,
    /// Accumulated text
    accumulated_text: String,
    /// Is stream complete
    is_complete: bool,
}

impl ResponseStreamer {
    /// Create new response streamer
    pub fn new(
        request_id: String,
        model: ClaudeModel,
        sender: mpsc::Sender<StreamEvent>,
    ) -> Self {
        Self {
            request_id,
            model,
            start_time: Instant::now(),
            token_count: 0,
            quality_monitor: QualityMonitor::new(0.6, 20),
            sender,
            accumulated_text: String::new(),
            is_complete: false,
        }
    }

    /// Process incoming token
    pub async fn process_token(&mut self, text: String, quality_score: Option<f32>) -> Result<()> {
        if self.is_complete {
            return Err(RuvLLMError::InvalidOperation("Stream already complete".to_string()));
        }

        let token = StreamToken {
            text: text.clone(),
            index: self.token_count,
            latency_ms: self.start_time.elapsed().as_millis() as u64,
            quality_score,
        };

        // Update quality monitor
        if let Some(score) = quality_score {
            self.quality_monitor.record(score);
        }

        // Accumulate text
        self.accumulated_text.push_str(&text);
        self.token_count += 1;

        // Send token event
        self.sender
            .send(StreamEvent::Token(token))
            .await
            .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send token: {}", e)))?;

        Ok(())
    }

    /// Complete the stream
    pub async fn complete(&mut self, usage: UsageStats, stop_reason: String) -> Result<()> {
        self.is_complete = true;

        self.sender
            .send(StreamEvent::Complete {
                usage,
                stop_reason,
                total_latency_ms: self.start_time.elapsed().as_millis() as u64,
            })
            .await
            .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send complete: {}", e)))?;

        Ok(())
    }

    /// Get current statistics
    pub fn stats(&self) -> StreamStats {
        let elapsed = self.start_time.elapsed();
        StreamStats {
            token_count: self.token_count,
            elapsed_ms: elapsed.as_millis() as u64,
            tokens_per_second: if elapsed.as_secs_f64() > 0.0 {
                self.token_count as f64 / elapsed.as_secs_f64()
            } else {
                0.0
            },
            average_quality: self.quality_monitor.average_quality(),
            is_complete: self.is_complete,
        }
    }

    /// Get accumulated text
    pub fn accumulated_text(&self) -> &str {
        &self.accumulated_text
    }

    /// Check if quality is acceptable
    pub fn quality_acceptable(&self) -> bool {
        self.quality_monitor.should_continue()
    }
}

/// Stream statistics
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Total tokens processed
    pub token_count: usize,
    /// Elapsed time in milliseconds
    pub elapsed_ms: u64,
    /// Tokens per second
    pub tokens_per_second: f64,
    /// Average quality score
    pub average_quality: f32,
    /// Is stream complete
    pub is_complete: bool,
}

// ============================================================================
// Context Window Management
// ============================================================================

/// Context window state
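///
/// # Example
///
/// A sketch of typical use; `add_message` compresses automatically once the
/// 80% threshold is crossed:
///
/// ```ignore
/// let mut window = ContextWindow::new(8_000);
/// window.set_system("You are a code-review agent.");
/// window.add_message(Message::user("Review this diff..."));
/// println!("utilization: {:.0}%", window.utilization() * 100.0);
/// println!("remaining: {} tokens", window.remaining_capacity());
/// ```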
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Current messages
    messages: Vec<Message>,
    /// System prompt
    system_prompt: Option<String>,
    /// Maximum tokens for context
    max_tokens: usize,
    /// Current estimated token count
    current_tokens: usize,
    /// Compression threshold (0.0 - 1.0)
    compression_threshold: f32,
}

impl ContextWindow {
    /// Create new context window
    pub fn new(max_tokens: usize) -> Self {
        Self {
            messages: Vec::new(),
            system_prompt: None,
            max_tokens,
            current_tokens: 0,
            compression_threshold: 0.8,
        }
    }

    /// Set system prompt
    pub fn set_system(&mut self, prompt: impl Into<String>) {
        let prompt = prompt.into();
        // Swap the old prompt's token estimate for the new one
        // (saturating so a stale estimate can never underflow).
        self.current_tokens = self
            .current_tokens
            .saturating_sub(self.system_prompt.as_ref().map_or(0, |p| p.len() / 4));
        self.current_tokens += prompt.len() / 4;
        self.system_prompt = Some(prompt);
    }

    /// Add message to context
    pub fn add_message(&mut self, message: Message) {
        let tokens = message.estimate_tokens();
        self.current_tokens += tokens;
        self.messages.push(message);

        // Check if compression needed
        if self.needs_compression() {
            self.compress();
        }
    }

    /// Check if context needs compression
    pub fn needs_compression(&self) -> bool {
        self.current_tokens as f32 > self.max_tokens as f32 * self.compression_threshold
    }

    /// Get utilization ratio
    pub fn utilization(&self) -> f32 {
        self.current_tokens as f32 / self.max_tokens as f32
    }

    /// Compress context to fit within limits
    pub fn compress(&mut self) {
        // Strategy: keep the first message and the most recent messages,
        // dropping the oldest middle messages until under the target budget.
        if self.messages.len() <= 4 {
            return;
        }

        let target_tokens = (self.max_tokens as f32 * 0.6) as usize;

        // Keep the first message and the last N messages
        let keep_first = 1;
        let mut keep_last = 3;

        while self.current_tokens > target_tokens && keep_last > 1 {
            let to_remove = self.messages.len() - keep_first - keep_last;
            if to_remove > 0 {
                // Remove the oldest middle message
                let removed = self.messages.remove(keep_first);
                self.current_tokens = self
                    .current_tokens
                    .saturating_sub(removed.estimate_tokens());
            } else {
                // No middle messages left; shrink the protected tail
                keep_last -= 1;
            }
        }
    }

    /// Expand context window for complex task
    pub fn expand_for_task(&mut self, task_complexity: f32, model: ClaudeModel) {
        // Higher complexity = larger context window needed
        let base_max = model.max_context_tokens();
        let expansion_factor = 0.5 + (task_complexity * 0.5); // 0.5 to 1.0
        self.max_tokens = (base_max as f32 * expansion_factor) as usize;
    }

    /// Get messages for request
    pub fn get_messages(&self) -> &[Message] {
        &self.messages
    }

    /// Get system prompt
    pub fn get_system(&self) -> Option<&str> {
        self.system_prompt.as_deref()
    }

    /// Get current token estimate
    pub fn token_count(&self) -> usize {
        self.current_tokens
    }

    /// Get remaining capacity
    pub fn remaining_capacity(&self) -> usize {
        self.max_tokens.saturating_sub(self.current_tokens)
    }

    /// Clear context
    pub fn clear(&mut self) {
        self.messages.clear();
        self.current_tokens = self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
    }
}

/// Context manager for dynamic window management
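///
/// # Example
///
/// A sketch of per-agent window bookkeeping:
///
/// ```ignore
/// let mut manager = ContextManager::new(100_000);
/// manager.get_window("agent-1").add_message(Message::user("hi"));
/// assert_eq!(manager.window_count(), 1);
/// manager.remove_window("agent-1");
/// assert_eq!(manager.total_tokens(), 0);
/// ```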
pub struct ContextManager {
    /// Windows by agent ID
    windows: HashMap<String, ContextWindow>,
    /// Default max tokens
    default_max_tokens: usize,
}

impl ContextManager {
    /// Create new context manager
    pub fn new(default_max_tokens: usize) -> Self {
        Self {
            windows: HashMap::new(),
            default_max_tokens,
        }
    }

    /// Get or create context window for agent
    pub fn get_window(&mut self, agent_id: &str) -> &mut ContextWindow {
        let default_max_tokens = self.default_max_tokens;
        self.windows
            .entry(agent_id.to_string())
            .or_insert_with(|| ContextWindow::new(default_max_tokens))
    }

    /// Remove context window
    pub fn remove_window(&mut self, agent_id: &str) {
        self.windows.remove(agent_id);
    }

    /// Get total token usage across all windows
    pub fn total_tokens(&self) -> usize {
        self.windows.values().map(|w| w.token_count()).sum()
    }

    /// Get window count
    pub fn window_count(&self) -> usize {
        self.windows.len()
    }
}

// ============================================================================
// Multi-Agent Coordination
// ============================================================================

/// Agent state in workflow
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentState {
    /// Agent is idle
    Idle,
    /// Agent is executing task
    Running,
    /// Agent is waiting for dependencies
    Blocked,
    /// Agent completed successfully
    Completed,
    /// Agent failed
    Failed,
}

/// Agent execution context
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Agent identifier
    pub agent_id: String,
    /// Agent type
    pub agent_type: AgentType,
    /// Assigned model
    pub model: ClaudeModel,
    /// Current state
    pub state: AgentState,
    /// Context window size (tokens)
    pub context_tokens: usize,
    /// Total tokens used
    pub total_tokens_used: usize,
    /// Total cost incurred
    pub total_cost: f64,
    /// Task start time
    pub started_at: Option<Instant>,
    /// Task completion time
    pub completed_at: Option<Instant>,
    /// Error message if failed
    pub error: Option<String>,
}

impl AgentContext {
    /// Create new agent context
    pub fn new(agent_id: String, agent_type: AgentType, model: ClaudeModel) -> Self {
        Self {
            agent_id,
            agent_type,
            model,
            state: AgentState::Idle,
            context_tokens: 0,
            total_tokens_used: 0,
            total_cost: 0.0,
            started_at: None,
            completed_at: None,
            error: None,
        }
    }

    /// Start execution
    pub fn start(&mut self) {
        self.state = AgentState::Running;
        self.started_at = Some(Instant::now());
    }

    /// Mark as blocked
    pub fn block(&mut self) {
        self.state = AgentState::Blocked;
    }

    /// Complete execution
    pub fn complete(&mut self, usage: &UsageStats) {
        self.state = AgentState::Completed;
        self.completed_at = Some(Instant::now());
        self.total_tokens_used += usage.input_tokens + usage.output_tokens;
        self.total_cost += usage.calculate_cost(self.model);
    }

    /// Fail execution
    pub fn fail(&mut self, error: String) {
        self.state = AgentState::Failed;
        self.completed_at = Some(Instant::now());
        self.error = Some(error);
    }

    /// Get execution duration
    pub fn duration(&self) -> Option<Duration> {
        match (self.started_at, self.completed_at) {
            (Some(start), Some(end)) => Some(end.duration_since(start)),
            (Some(start), None) => Some(start.elapsed()),
            _ => None,
        }
    }
}

/// Workflow step definition
#[derive(Debug, Clone)]
pub struct WorkflowStep {
    /// Step identifier
    pub step_id: String,
    /// Agent type to execute step
    pub agent_type: AgentType,
    /// Task description
    pub task: String,
    /// Dependencies (step IDs that must complete first)
    pub dependencies: Vec<String>,
    /// Required model (or None for auto-selection)
    pub required_model: Option<ClaudeModel>,
    /// Maximum retries
    pub max_retries: u32,
}

/// Workflow execution result
#[derive(Debug, Clone)]
pub struct WorkflowResult {
    /// Workflow identifier
    pub workflow_id: String,
    /// Step results
    pub step_results: HashMap<String, StepResult>,
    /// Total execution time
    pub total_duration: Duration,
    /// Total tokens used
    pub total_tokens: usize,
    /// Total cost
    pub total_cost: f64,
    /// Success status
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}

/// Individual step result
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Step identifier
    pub step_id: String,
    /// Agent that executed step
    pub agent_id: String,
    /// Model used
    pub model: ClaudeModel,
    /// Response content
    pub response: Option<String>,
    /// Execution duration
    pub duration: Duration,
    /// Tokens used
    pub tokens_used: usize,
    /// Cost incurred
    pub cost: f64,
    /// Success status
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}

/// Multi-agent coordinator
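///
/// # Example
///
/// A sketch of the agent lifecycle outside of workflows: spawn up to
/// `max_concurrent` agents, mutate them through closures, then tear them
/// down:
///
/// ```ignore
/// let coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 4);
/// coordinator.spawn_agent("coder-1".into(), AgentType::Coder)?;
/// coordinator.update_agent("coder-1", |a| a.start())?;
/// assert_eq!(coordinator.active_agent_count(), 1);
/// coordinator.terminate_agent("coder-1")?;
/// ```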
pub struct AgentCoordinator {
    /// Agent contexts
    agents: Arc<RwLock<HashMap<String, AgentContext>>>,
    /// Context manager
    context_manager: Arc<RwLock<ContextManager>>,
    /// Default model for agents
    default_model: ClaudeModel,
    /// Maximum concurrent agents
    max_concurrent: usize,
    /// Total workflows executed
    workflows_executed: u64,
    /// Total cost incurred
    total_cost: f64,
}

impl AgentCoordinator {
    /// Create new agent coordinator
    pub fn new(default_model: ClaudeModel, max_concurrent: usize) -> Self {
        Self {
            agents: Arc::new(RwLock::new(HashMap::new())),
            context_manager: Arc::new(RwLock::new(ContextManager::new(100_000))),
            default_model,
            max_concurrent,
            workflows_executed: 0,
            total_cost: 0.0,
        }
    }

    /// Spawn a new agent
    pub fn spawn_agent(&self, agent_id: String, agent_type: AgentType) -> Result<()> {
        let mut agents = self.agents.write();

        if agents.len() >= self.max_concurrent {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Maximum concurrent agents ({}) reached",
                self.max_concurrent
            )));
        }

        if agents.contains_key(&agent_id) {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Agent {} already exists",
                agent_id
            )));
        }

        let context = AgentContext::new(agent_id.clone(), agent_type, self.default_model);
        agents.insert(agent_id, context);

        Ok(())
    }

    /// Get agent context
    pub fn get_agent(&self, agent_id: &str) -> Option<AgentContext> {
        self.agents.read().get(agent_id).cloned()
    }

    /// Update agent state
    pub fn update_agent<F>(&self, agent_id: &str, f: F) -> Result<()>
    where
        F: FnOnce(&mut AgentContext),
    {
        let mut agents = self.agents.write();
        let agent = agents
            .get_mut(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
        f(agent);
        Ok(())
    }

    /// Terminate agent
    pub fn terminate_agent(&self, agent_id: &str) -> Result<()> {
        let mut agents = self.agents.write();
        agents
            .remove(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;

        // Clean up context window
        self.context_manager.write().remove_window(agent_id);

        Ok(())
    }

    /// Get active agent count
    pub fn active_agent_count(&self) -> usize {
        self.agents
            .read()
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count()
    }

    /// Get total agent count
    pub fn total_agent_count(&self) -> usize {
        self.agents.read().len()
    }

    /// Execute workflow with dependency resolution
    pub async fn execute_workflow(
        &mut self,
        workflow_id: String,
        steps: Vec<WorkflowStep>,
    ) -> Result<WorkflowResult> {
        let start_time = Instant::now();
        let mut step_results: HashMap<String, StepResult> = HashMap::new();
        let mut completed_steps: std::collections::HashSet<String> = std::collections::HashSet::new();

        // Build dependency graph
        let mut pending_steps: Vec<&WorkflowStep> = steps.iter().collect();

        while !pending_steps.is_empty() {
            // Find steps with satisfied dependencies
            let ready_steps: Vec<_> = pending_steps
                .iter()
                .filter(|step| {
                    step.dependencies
                        .iter()
                        .all(|dep| completed_steps.contains(dep))
                })
                .cloned()
                .collect();

            // Nothing is ready but steps remain: the dependency graph is
            // unsatisfiable (a cycle, or a dependency on a missing step).
            if ready_steps.is_empty() {
                return Err(RuvLLMError::InvalidOperation(
                    "Workflow has circular or unsatisfiable dependencies".to_string(),
                ));
            }

            // Execute ready steps (sequentially here; a production version
            // would dispatch them concurrently)
            for step in ready_steps {
                let agent_id = format!("{}-{}", workflow_id, step.step_id);
                let model = step.required_model.unwrap_or(self.default_model);

                // Spawn agent for step
                self.spawn_agent(agent_id.clone(), step.agent_type)?;
                self.update_agent(&agent_id, |a| a.start())?;

                // Simulate execution (in production, would call Claude API)
                let step_start = Instant::now();

                // Create mock result
                let result = StepResult {
                    step_id: step.step_id.clone(),
                    agent_id: agent_id.clone(),
                    model,
                    response: Some(format!("Completed: {}", step.task)),
                    duration: step_start.elapsed(),
                    tokens_used: 500, // Mock value
                    cost: 0.001, // Mock value
                    success: true,
                    error: None,
                };

                self.update_agent(&agent_id, |a| {
                    let usage = UsageStats {
                        input_tokens: 250,
                        output_tokens: 250,
                    };
                    a.complete(&usage);
                })?;

                step_results.insert(step.step_id.clone(), result);
                completed_steps.insert(step.step_id.clone());

                // Clean up agent
                self.terminate_agent(&agent_id)?;
            }

            // Remove completed steps from pending
            pending_steps.retain(|step| !completed_steps.contains(&step.step_id));
        }

        // Calculate totals
        let total_tokens: usize = step_results.values().map(|r| r.tokens_used).sum();
        let total_cost: f64 = step_results.values().map(|r| r.cost).sum();

        self.workflows_executed += 1;
        self.total_cost += total_cost;

        Ok(WorkflowResult {
            workflow_id,
            step_results,
            total_duration: start_time.elapsed(),
            total_tokens,
            total_cost,
            success: true,
            error: None,
        })
    }

    /// Get coordinator statistics
    pub fn stats(&self) -> CoordinatorStats {
        let agents = self.agents.read();
        let active_count = agents
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count();
        let total_tokens: usize = agents.values().map(|a| a.total_tokens_used).sum();

        CoordinatorStats {
            total_agents: agents.len(),
            active_agents: active_count,
            blocked_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Blocked)
                .count(),
            completed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Completed)
                .count(),
            failed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Failed)
                .count(),
            workflows_executed: self.workflows_executed,
            total_tokens_used: total_tokens,
            total_cost: self.total_cost,
        }
    }
}

/// Coordinator statistics
#[derive(Debug, Clone)]
pub struct CoordinatorStats {
    /// Total agents created
    pub total_agents: usize,
    /// Currently active agents
    pub active_agents: usize,
    /// Blocked agents
    pub blocked_agents: usize,
    /// Completed agents
    pub completed_agents: usize,
    /// Failed agents
    pub failed_agents: usize,
    /// Total workflows executed
    pub workflows_executed: u64,
    /// Total tokens used
    pub total_tokens_used: usize,
    /// Total cost incurred
    pub total_cost: f64,
}

// ============================================================================
// Cost Estimation
// ============================================================================

/// Cost estimator for Claude API usage
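///
/// # Example
///
/// A sketch of the estimate-then-record flow:
///
/// ```ignore
/// let mut estimator = CostEstimator::new();
/// // Before sending: projected cost for the request.
/// let projected = estimator.estimate_request_cost(ClaudeModel::Sonnet, 2_000, 800);
/// // After the response: record what was actually used.
/// let usage = UsageStats { input_tokens: 2_100, output_tokens: 640 };
/// estimator.record_usage(ClaudeModel::Sonnet, &usage);
/// assert!(estimator.total_cost() > 0.0);
/// ```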
pub struct CostEstimator {
    /// Usage by model
    usage_by_model: HashMap<ClaudeModel, UsageStats>,
}

impl CostEstimator {
    /// Create new cost estimator
    pub fn new() -> Self {
        Self {
            usage_by_model: HashMap::new(),
        }
    }

    /// Estimate cost for a request
    pub fn estimate_request_cost(
        &self,
        model: ClaudeModel,
        input_tokens: usize,
        expected_output_tokens: usize,
    ) -> f64 {
        let input_cost = (input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
        let output_cost = (expected_output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
        input_cost + output_cost
    }

    /// Record actual usage
    pub fn record_usage(&mut self, model: ClaudeModel, usage: &UsageStats) {
        let entry = self.usage_by_model.entry(model).or_default();
        entry.input_tokens += usage.input_tokens;
        entry.output_tokens += usage.output_tokens;
    }

    /// Get total cost to date
    pub fn total_cost(&self) -> f64 {
        self.usage_by_model
            .iter()
            .map(|(model, usage)| usage.calculate_cost(*model))
            .sum()
    }

    /// Get cost breakdown by model
    pub fn cost_breakdown(&self) -> HashMap<ClaudeModel, f64> {
        self.usage_by_model
            .iter()
            .map(|(model, usage)| (*model, usage.calculate_cost(*model)))
            .collect()
    }

    /// Get total usage by model
    pub fn usage_by_model(&self) -> &HashMap<ClaudeModel, UsageStats> {
        &self.usage_by_model
    }
}

impl Default for CostEstimator {
    fn default() -> Self {
        Self::new()
    }
}

// ============================================================================
// Latency Tracking
// ============================================================================

/// Latency tracker for performance monitoring
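///
/// # Example
///
/// A sketch of recording samples and reading back percentiles:
///
/// ```ignore
/// let mut tracker = LatencyTracker::new(1_000);
/// tracker.record(ClaudeModel::Haiku, LatencySample {
///     ttft_ms: 180,
///     total_ms: 950,
///     input_tokens: 400,
///     output_tokens: 150,
///     timestamp: std::time::Instant::now(),
/// });
/// if let Some(p95) = tracker.p95_ttft(ClaudeModel::Haiku) {
///     println!("p95 TTFT: {p95} ms");
/// }
/// ```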
pub struct LatencyTracker {
    /// Samples by model
    samples: HashMap<ClaudeModel, Vec<LatencySample>>,
    /// Maximum samples to keep per model
    max_samples: usize,
}

/// Single latency sample
#[derive(Debug, Clone)]
pub struct LatencySample {
    /// Time to first token (ms)
    pub ttft_ms: u64,
    /// Total response time (ms)
    pub total_ms: u64,
    /// Input tokens
    pub input_tokens: usize,
    /// Output tokens
    pub output_tokens: usize,
    /// Timestamp
    pub timestamp: Instant,
}

impl LatencyTracker {
    /// Create new latency tracker
    pub fn new(max_samples: usize) -> Self {
        Self {
            samples: HashMap::new(),
            max_samples,
        }
    }

    /// Record latency sample
    pub fn record(&mut self, model: ClaudeModel, sample: LatencySample) {
        let samples = self.samples.entry(model).or_default();
        samples.push(sample);

        // Trim old samples
        if samples.len() > self.max_samples {
            samples.remove(0);
        }
    }

    /// Get average TTFT for model
    pub fn average_ttft(&self, model: ClaudeModel) -> Option<f64> {
        self.samples.get(&model).map(|samples| {
            if samples.is_empty() {
                return 0.0;
            }
            let sum: u64 = samples.iter().map(|s| s.ttft_ms).sum();
            sum as f64 / samples.len() as f64
        })
    }

    /// Get p95 TTFT for model
    pub fn p95_ttft(&self, model: ClaudeModel) -> Option<u64> {
        self.samples.get(&model).and_then(|samples| {
            if samples.is_empty() {
                return None;
            }
            let mut ttfts: Vec<u64> = samples.iter().map(|s| s.ttft_ms).collect();
            ttfts.sort_unstable();
            let idx = (ttfts.len() as f64 * 0.95) as usize;
            ttfts.get(idx.min(ttfts.len() - 1)).copied()
        })
    }

    /// Get average tokens per second for model
    pub fn average_tokens_per_second(&self, model: ClaudeModel) -> Option<f64> {
        self.samples.get(&model).map(|samples| {
            if samples.is_empty() {
                return 0.0;
            }
            let total_tokens: usize = samples.iter().map(|s| s.output_tokens).sum();
            let total_time_ms: u64 = samples.iter().map(|s| s.total_ms - s.ttft_ms).sum();
            if total_time_ms == 0 {
                return 0.0;
            }
            total_tokens as f64 / (total_time_ms as f64 / 1000.0)
        })
    }

    /// Get statistics for model
    pub fn get_stats(&self, model: ClaudeModel) -> Option<LatencyStats> {
        self.samples.get(&model).map(|samples| LatencyStats {
            sample_count: samples.len(),
            avg_ttft_ms: self.average_ttft(model).unwrap_or(0.0),
            p95_ttft_ms: self.p95_ttft(model).unwrap_or(0),
            avg_tokens_per_second: self.average_tokens_per_second(model).unwrap_or(0.0),
        })
    }
}

/// Latency statistics
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Number of samples
    pub sample_count: usize,
    /// Average time to first token
    pub avg_ttft_ms: f64,
    /// P95 time to first token
    pub p95_ttft_ms: u64,
    /// Average tokens per second
    pub avg_tokens_per_second: f64,
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_claude_model_costs() {
        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        let haiku_cost = usage.calculate_cost(ClaudeModel::Haiku);
        let sonnet_cost = usage.calculate_cost(ClaudeModel::Sonnet);
        let opus_cost = usage.calculate_cost(ClaudeModel::Opus);

        assert!(haiku_cost < sonnet_cost);
        assert!(sonnet_cost < opus_cost);
    }

    #[test]
    fn test_context_window_compression() {
        // Use a small window so that adding many messages actually crosses
        // the compression threshold (the messages below are ~10 tokens each).
        let mut window = ContextWindow::new(200);

        // Add many messages
        for i in 0..20 {
            window.add_message(Message::user(format!("Message {} with some content to add tokens", i)));
        }

        // Window should have compressed
        assert!(window.token_count() <= 200);
        assert!(window.get_messages().len() < 20);
    }

    #[test]
    fn test_message_token_estimation() {
        let msg = Message::user("Hello, this is a test message with some content.");
        let tokens = msg.estimate_tokens();
        assert!(tokens > 0);
        assert!(tokens < 100); // Should be reasonable estimate
    }

    #[test]
    fn test_quality_monitor() {
        let mut monitor = QualityMonitor::new(0.6, 10);

        // Add good quality scores
        for _ in 0..5 {
            monitor.record(0.8);
        }
        assert!(monitor.should_continue());

        // Add bad quality scores
        let mut bad_monitor = QualityMonitor::new(0.6, 10);
        for _ in 0..5 {
            bad_monitor.record(0.3);
        }
        assert!(!bad_monitor.should_continue());
    }

    #[test]
    fn test_agent_coordinator() {
        let coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 10);

        coordinator.spawn_agent("agent-1".to_string(), AgentType::Coder).unwrap();
        coordinator.spawn_agent("agent-2".to_string(), AgentType::Researcher).unwrap();

        assert_eq!(coordinator.total_agent_count(), 2);

        coordinator.update_agent("agent-1", |a| a.start()).unwrap();
        assert_eq!(coordinator.active_agent_count(), 1);

        coordinator.terminate_agent("agent-1").unwrap();
        assert_eq!(coordinator.total_agent_count(), 1);
    }

    #[test]
    fn test_cost_estimator() {
        let mut estimator = CostEstimator::new();

        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        estimator.record_usage(ClaudeModel::Sonnet, &usage);
        estimator.record_usage(ClaudeModel::Haiku, &usage);

        let total = estimator.total_cost();
        assert!(total > 0.0);

        let breakdown = estimator.cost_breakdown();
        assert!(breakdown.contains_key(&ClaudeModel::Sonnet));
        assert!(breakdown.contains_key(&ClaudeModel::Haiku));
    }

    #[test]
    fn test_latency_tracker() {
        let mut tracker = LatencyTracker::new(100);

        for i in 0..10 {
            tracker.record(
                ClaudeModel::Sonnet,
                LatencySample {
                    ttft_ms: 400 + i * 10,
                    total_ms: 1000 + i * 100,
                    input_tokens: 500,
                    output_tokens: 200,
                    timestamp: Instant::now(),
                },
            );
        }

        let stats = tracker.get_stats(ClaudeModel::Sonnet).unwrap();
        assert_eq!(stats.sample_count, 10);
        assert!(stats.avg_ttft_ms > 400.0);
        assert!(stats.avg_tokens_per_second > 0.0);
    }
}
1335}