// ruvllm/claude_flow/claude_integration.rs
//! Claude API Integration for Agent Communication
//!
//! Provides full Claude API compatibility for multi-agent coordination,
//! including streaming response handling, context window management,
//! and workflow orchestration.
//!
//! ## Key Features
//!
//! - **Full Claude API Compatibility**: Messages, streaming, tool use
//! - **Streaming Response Handling**: Real-time token generation with quality monitoring
//! - **Context Window Management**: Dynamic compression/expansion based on task complexity
//! - **Multi-Agent Coordination**: Workflow orchestration with dependency resolution
//!
//! ## Architecture
//!
//! ```text
//! +-------------------+     +-------------------+
//! | AgentCoordinator  |---->| ClaudeClient      |
//! | (workflow mgmt)   |     | (API interface)   |
//! +--------+----------+     +--------+----------+
//!          |                         |
//!          v                         v
//! +--------+----------+     +--------+----------+
//! | ResponseStreamer  |<----| ContextManager    |
//! | (token handling)  |     | (window mgmt)     |
//! +-------------------+     +-------------------+
//! ```
29use parking_lot::RwLock;
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33use std::time::{Duration, Instant};
34use tokio::sync::mpsc;
35
36use super::{AgentType, ClaudeFlowAgent, ClaudeFlowTask};
37use crate::error::{Result, RuvLLMError};
38
39// ============================================================================
40// Claude API Types
41// ============================================================================
42
43/// Claude model variants for intelligent routing
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
45pub enum ClaudeModel {
46    /// Fast, cost-effective for simple tasks
47    Haiku,
48    /// Balanced performance and capability
49    Sonnet,
50    /// Most capable for complex reasoning
51    Opus,
52}
53
54impl ClaudeModel {
55    /// Get short name for the model
56    pub fn name(&self) -> &'static str {
57        match self {
58            Self::Haiku => "haiku",
59            Self::Sonnet => "sonnet",
60            Self::Opus => "opus",
61        }
62    }
63
64    /// Get model identifier string
65    pub fn model_id(&self) -> &'static str {
66        match self {
67            Self::Haiku => "claude-3-5-haiku-20241022",
68            Self::Sonnet => "claude-sonnet-4-20250514",
69            Self::Opus => "claude-opus-4-20250514",
70        }
71    }
72
73    /// Get cost per 1K input tokens (USD)
74    pub fn input_cost_per_1k(&self) -> f64 {
75        match self {
76            Self::Haiku => 0.00025,
77            Self::Sonnet => 0.003,
78            Self::Opus => 0.015,
79        }
80    }
81
82    /// Get cost per 1K output tokens (USD)
83    pub fn output_cost_per_1k(&self) -> f64 {
84        match self {
85            Self::Haiku => 0.00125,
86            Self::Sonnet => 0.015,
87            Self::Opus => 0.075,
88        }
89    }
90
91    /// Get typical latency for first token (ms)
92    pub fn typical_ttft_ms(&self) -> u64 {
93        match self {
94            Self::Haiku => 200,
95            Self::Sonnet => 500,
96            Self::Opus => 1500,
97        }
98    }
99
100    /// Get maximum context window size
101    pub fn max_context_tokens(&self) -> usize {
102        match self {
103            Self::Haiku => 200_000,
104            Self::Sonnet => 200_000,
105            Self::Opus => 200_000,
106        }
107    }
108}
109
110impl Default for ClaudeModel {
111    fn default() -> Self {
112        Self::Sonnet
113    }
114}
115
/// Message role in conversation.
///
/// Serialized in lowercase ("user", "assistant", "system") to match the
/// Claude Messages API wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// End-user input
    User,
    /// Model-generated response
    Assistant,
    /// System instructions
    System,
}
127
/// Content block types.
///
/// Serialized with an internal snake_case `type` discriminator
/// (e.g. `{"type": "tool_use", ...}`) to match the Claude Messages API.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content
    Text { text: String },
    /// A model request to invoke a tool
    ToolUse {
        // Unique id correlating this request with its ToolResult.
        id: String,
        // Name of the tool to invoke.
        name: String,
        // Tool arguments as free-form JSON.
        input: serde_json::Value,
    },
    /// Result returned for a prior tool invocation
    ToolResult {
        // Id of the ToolUse block this result answers.
        tool_use_id: String,
        // Result payload as text.
        content: String,
    },
}
146
/// A single turn in the conversation: a role plus ordered content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message
    pub role: MessageRole,
    /// Ordered content blocks (text, tool use, tool results)
    pub content: Vec<ContentBlock>,
}
155
156impl Message {
157    /// Create a simple text message
158    pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
159        Self {
160            role,
161            content: vec![ContentBlock::Text { text: text.into() }],
162        }
163    }
164
165    /// Create a user message
166    pub fn user(text: impl Into<String>) -> Self {
167        Self::text(MessageRole::User, text)
168    }
169
170    /// Create an assistant message
171    pub fn assistant(text: impl Into<String>) -> Self {
172        Self::text(MessageRole::Assistant, text)
173    }
174
175    /// Estimate token count for this message
176    pub fn estimate_tokens(&self) -> usize {
177        self.content
178            .iter()
179            .map(|block| {
180                match block {
181                    ContentBlock::Text { text } => text.len() / 4, // ~4 chars per token
182                    ContentBlock::ToolUse { input, .. } => {
183                        input.to_string().len() / 4 + 50 // overhead for tool structure
184                    }
185                    ContentBlock::ToolResult { content, .. } => content.len() / 4 + 20,
186                }
187            })
188            .sum()
189    }
190}
191
/// Request to the Claude Messages API.
///
/// Serialize-only: optional fields are omitted from the JSON body when unset.
#[derive(Debug, Clone, Serialize)]
pub struct ClaudeRequest {
    /// Full model identifier (see `ClaudeModel::model_id`)
    pub model: String,
    /// Conversation messages, oldest first
    pub messages: Vec<Message>,
    /// Hard cap on tokens to generate
    pub max_tokens: usize,
    /// System prompt (sent as a top-level field, not as a message)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Sampling temperature
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Request streaming delivery when `Some(true)`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
211
/// Response from the Claude Messages API (non-streaming).
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeResponse {
    /// Server-assigned response ID
    pub id: String,
    /// Model that produced the response
    pub model: String,
    /// Generated content blocks
    pub content: Vec<ContentBlock>,
    /// Why generation stopped, if reported
    pub stop_reason: Option<String>,
    /// Token usage statistics
    pub usage: UsageStats,
}
226
/// Token usage statistics as reported by the API.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct UsageStats {
    /// Input (prompt) tokens consumed
    pub input_tokens: usize,
    /// Output (completion) tokens generated
    pub output_tokens: usize,
}
235
236impl UsageStats {
237    /// Calculate cost for given model
238    pub fn calculate_cost(&self, model: ClaudeModel) -> f64 {
239        let input_cost = (self.input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
240        let output_cost = (self.output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
241        input_cost + output_cost
242    }
243}
244
245// ============================================================================
246// Streaming Types
247// ============================================================================
248
/// A single streamed token plus timing and quality metadata.
#[derive(Debug, Clone)]
pub struct StreamToken {
    /// Token text
    pub text: String,
    /// Zero-based position in the generated sequence
    pub index: usize,
    /// Milliseconds elapsed since the stream started
    pub latency_ms: u64,
    /// Quality score (0.0 - 1.0) if available
    pub quality_score: Option<f32>,
}
261
/// Events emitted over the streaming channel, from `Start` through
/// `Complete` (or `Error`).
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Stream started
    Start {
        // Request this stream belongs to.
        request_id: String,
        // Model serving the stream.
        model: ClaudeModel,
    },
    /// A token was generated
    Token(StreamToken),
    /// A content block finished assembling
    ContentBlockComplete { index: usize, content: ContentBlock },
    /// Stream finished successfully
    Complete {
        // Final token usage totals.
        usage: UsageStats,
        // Why generation stopped.
        stop_reason: String,
        // Total wall-clock duration of the stream in milliseconds.
        total_latency_ms: u64,
    },
    /// The stream failed
    Error { message: String, is_retryable: bool },
}
283
/// Quality monitoring for streaming responses.
///
/// Tracks per-token quality observations and decides whether generation
/// should continue based on the running average of all recorded scores.
#[derive(Debug, Clone)]
pub struct QualityMonitor {
    /// Minimum acceptable quality score
    pub min_quality: f32,
    /// Check interval (tokens)
    pub check_interval: usize,
    /// Accumulated quality scores
    scores: Vec<f32>,
    /// Tokens since last check
    tokens_since_check: usize,
}

impl QualityMonitor {
    /// Create a monitor with the given quality floor and check cadence.
    pub fn new(min_quality: f32, check_interval: usize) -> Self {
        Self {
            min_quality,
            check_interval,
            scores: Vec::new(),
            tokens_since_check: 0,
        }
    }

    /// Record one quality observation (also counts toward the check interval).
    pub fn record(&mut self, score: f32) {
        self.scores.push(score);
        self.tokens_since_check += 1;
    }

    /// Whether generation should keep going; no data counts as acceptable.
    pub fn should_continue(&self) -> bool {
        self.scores.is_empty() || self.mean() >= self.min_quality
    }

    /// Whether enough observations have elapsed to warrant a quality check.
    pub fn should_check(&self) -> bool {
        self.tokens_since_check >= self.check_interval
    }

    /// Restart the interval counter after performing a check.
    pub fn reset_check(&mut self) {
        self.tokens_since_check = 0;
    }

    /// Running average quality; defaults to 1.0 with no observations.
    pub fn average_quality(&self) -> f32 {
        if self.scores.is_empty() {
            1.0
        } else {
            self.mean()
        }
    }

    /// Arithmetic mean of recorded scores; callers guarantee non-empty.
    fn mean(&self) -> f32 {
        self.scores.iter().sum::<f32>() / self.scores.len() as f32
    }
}
342
/// Response streamer for real-time token handling.
///
/// Owns the sending half of the event channel; consumers receive
/// `StreamEvent`s on the paired receiver.
pub struct ResponseStreamer {
    /// Request ID
    pub request_id: String,
    /// Model being used
    pub model: ClaudeModel,
    /// Instant the streamer was created; all latencies are relative to this
    start_time: Instant,
    /// Number of tokens processed so far
    token_count: usize,
    /// Tracks per-token quality scores
    quality_monitor: QualityMonitor,
    /// Channel on which stream events are emitted
    sender: mpsc::Sender<StreamEvent>,
    /// Full response text accumulated from all tokens
    accumulated_text: String,
    /// Set once the stream has completed; further tokens are rejected
    is_complete: bool,
}
362
363impl ResponseStreamer {
364    /// Create new response streamer
365    pub fn new(request_id: String, model: ClaudeModel, sender: mpsc::Sender<StreamEvent>) -> Self {
366        Self {
367            request_id: request_id.clone(),
368            model,
369            start_time: Instant::now(),
370            token_count: 0,
371            quality_monitor: QualityMonitor::new(0.6, 20),
372            sender,
373            accumulated_text: String::new(),
374            is_complete: false,
375        }
376    }
377
378    /// Process incoming token
379    pub async fn process_token(&mut self, text: String, quality_score: Option<f32>) -> Result<()> {
380        if self.is_complete {
381            return Err(RuvLLMError::InvalidOperation(
382                "Stream already complete".to_string(),
383            ));
384        }
385
386        let token = StreamToken {
387            text: text.clone(),
388            index: self.token_count,
389            latency_ms: self.start_time.elapsed().as_millis() as u64,
390            quality_score,
391        };
392
393        // Update quality monitor
394        if let Some(score) = quality_score {
395            self.quality_monitor.record(score);
396        }
397
398        // Accumulate text
399        self.accumulated_text.push_str(&text);
400        self.token_count += 1;
401
402        // Send token event
403        self.sender
404            .send(StreamEvent::Token(token))
405            .await
406            .map_err(|e| RuvLLMError::InvalidOperation(format!("Failed to send token: {}", e)))?;
407
408        Ok(())
409    }
410
411    /// Complete the stream
412    pub async fn complete(&mut self, usage: UsageStats, stop_reason: String) -> Result<()> {
413        self.is_complete = true;
414
415        self.sender
416            .send(StreamEvent::Complete {
417                usage,
418                stop_reason,
419                total_latency_ms: self.start_time.elapsed().as_millis() as u64,
420            })
421            .await
422            .map_err(|e| {
423                RuvLLMError::InvalidOperation(format!("Failed to send complete: {}", e))
424            })?;
425
426        Ok(())
427    }
428
429    /// Get current statistics
430    pub fn stats(&self) -> StreamStats {
431        let elapsed = self.start_time.elapsed();
432        StreamStats {
433            token_count: self.token_count,
434            elapsed_ms: elapsed.as_millis() as u64,
435            tokens_per_second: if elapsed.as_secs_f64() > 0.0 {
436                self.token_count as f64 / elapsed.as_secs_f64()
437            } else {
438                0.0
439            },
440            average_quality: self.quality_monitor.average_quality(),
441            is_complete: self.is_complete,
442        }
443    }
444
445    /// Get accumulated text
446    pub fn accumulated_text(&self) -> &str {
447        &self.accumulated_text
448    }
449
450    /// Check if quality is acceptable
451    pub fn quality_acceptable(&self) -> bool {
452        self.quality_monitor.should_continue()
453    }
454}
455
/// Point-in-time statistics for a stream.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Total tokens processed
    pub token_count: usize,
    /// Elapsed time in milliseconds
    pub elapsed_ms: u64,
    /// Throughput; 0.0 when no measurable time has elapsed
    pub tokens_per_second: f64,
    /// Average quality score (1.0 when no scores were recorded)
    pub average_quality: f32,
    /// Is stream complete
    pub is_complete: bool,
}
470
471// ============================================================================
472// Context Window Management
473// ============================================================================
474
/// Context window state for a single conversation.
///
/// Token counts are estimates (~4 bytes per token heuristic via
/// `Message::estimate_tokens`), not exact tokenizer output.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Messages currently in the window, oldest first
    messages: Vec<Message>,
    /// Optional system prompt, counted against the budget
    system_prompt: Option<String>,
    /// Token budget for this window
    max_tokens: usize,
    /// Estimated tokens currently used (messages + system prompt)
    current_tokens: usize,
    /// Fraction of the budget (0.0 - 1.0) at which compression kicks in
    compression_threshold: f32,
}
489
490impl ContextWindow {
491    /// Create new context window
492    pub fn new(max_tokens: usize) -> Self {
493        Self {
494            messages: Vec::new(),
495            system_prompt: None,
496            max_tokens,
497            current_tokens: 0,
498            compression_threshold: 0.8,
499        }
500    }
501
502    /// Set system prompt
503    pub fn set_system(&mut self, prompt: impl Into<String>) {
504        let prompt = prompt.into();
505        self.current_tokens -= self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
506        self.current_tokens += prompt.len() / 4;
507        self.system_prompt = Some(prompt);
508    }
509
510    /// Add message to context
511    pub fn add_message(&mut self, message: Message) {
512        let tokens = message.estimate_tokens();
513        self.current_tokens += tokens;
514        self.messages.push(message);
515
516        // Check if compression needed
517        if self.needs_compression() {
518            self.compress();
519        }
520    }
521
522    /// Check if context needs compression
523    pub fn needs_compression(&self) -> bool {
524        self.current_tokens as f32 > self.max_tokens as f32 * self.compression_threshold
525    }
526
527    /// Get utilization ratio
528    pub fn utilization(&self) -> f32 {
529        self.current_tokens as f32 / self.max_tokens as f32
530    }
531
532    /// Compress context to fit within limits
533    pub fn compress(&mut self) {
534        // Strategy: Keep system, first user message, and recent messages
535        if self.messages.len() <= 4 {
536            return;
537        }
538
539        let target_tokens = (self.max_tokens as f32 * 0.6) as usize;
540
541        // Keep first and last N messages
542        let keep_first = 1;
543        let mut keep_last = 3;
544
545        while self.current_tokens > target_tokens && keep_last > 1 {
546            let to_remove = self.messages.len() - keep_first - keep_last;
547            if to_remove > 0 {
548                // Remove middle messages
549                let removed: Vec<_> = self.messages.drain(keep_first..keep_first + 1).collect();
550                for msg in removed {
551                    self.current_tokens -= msg.estimate_tokens();
552                }
553            } else {
554                keep_last -= 1;
555            }
556        }
557    }
558
559    /// Expand context window for complex task
560    pub fn expand_for_task(&mut self, task_complexity: f32, model: ClaudeModel) {
561        // Higher complexity = larger context window needed
562        let base_max = model.max_context_tokens();
563        let expansion_factor = 0.5 + (task_complexity * 0.5); // 0.5 to 1.0
564        self.max_tokens = (base_max as f32 * expansion_factor) as usize;
565    }
566
567    /// Get messages for request
568    pub fn get_messages(&self) -> &[Message] {
569        &self.messages
570    }
571
572    /// Get system prompt
573    pub fn get_system(&self) -> Option<&str> {
574        self.system_prompt.as_deref()
575    }
576
577    /// Get current token estimate
578    pub fn token_count(&self) -> usize {
579        self.current_tokens
580    }
581
582    /// Get remaining capacity
583    pub fn remaining_capacity(&self) -> usize {
584        self.max_tokens.saturating_sub(self.current_tokens)
585    }
586
587    /// Clear context
588    pub fn clear(&mut self) {
589        self.messages.clear();
590        self.current_tokens = self.system_prompt.as_ref().map_or(0, |p| p.len() / 4);
591    }
592}
593
/// Context manager for dynamic window management.
///
/// Owns one `ContextWindow` per agent, created lazily on first access.
pub struct ContextManager {
    /// Windows keyed by agent ID
    windows: HashMap<String, ContextWindow>,
    /// Token budget given to newly created windows
    default_max_tokens: usize,
}
601
602impl ContextManager {
603    /// Create new context manager
604    pub fn new(default_max_tokens: usize) -> Self {
605        Self {
606            windows: HashMap::new(),
607            default_max_tokens,
608        }
609    }
610
611    /// Get or create context window for agent
612    pub fn get_window(&mut self, agent_id: &str) -> &mut ContextWindow {
613        if !self.windows.contains_key(agent_id) {
614            self.windows.insert(
615                agent_id.to_string(),
616                ContextWindow::new(self.default_max_tokens),
617            );
618        }
619        self.windows.get_mut(agent_id).unwrap()
620    }
621
622    /// Remove context window
623    pub fn remove_window(&mut self, agent_id: &str) {
624        self.windows.remove(agent_id);
625    }
626
627    /// Get total token usage across all windows
628    pub fn total_tokens(&self) -> usize {
629        self.windows.values().map(|w| w.token_count()).sum()
630    }
631
632    /// Get window count
633    pub fn window_count(&self) -> usize {
634        self.windows.len()
635    }
636}
637
638// ============================================================================
639// Multi-Agent Coordination
640// ============================================================================
641
/// Agent lifecycle state within a workflow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentState {
    /// Spawned but not yet started
    Idle,
    /// Currently executing a task
    Running,
    /// Waiting on unmet dependencies
    Blocked,
    /// Finished successfully
    Completed,
    /// Finished with an error (see `AgentContext::error`)
    Failed,
}
656
/// Per-agent execution bookkeeping: state, timing, and token/cost totals.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Agent identifier
    pub agent_id: String,
    /// Agent type
    pub agent_type: AgentType,
    /// Model this agent's requests are routed to
    pub model: ClaudeModel,
    /// Current lifecycle state
    pub state: AgentState,
    /// Estimated tokens held in the agent's context window
    pub context_tokens: usize,
    /// Cumulative tokens (input + output) across all requests
    pub total_tokens_used: usize,
    /// Cumulative USD cost across all requests
    pub total_cost: f64,
    /// When execution started (None if never started)
    pub started_at: Option<Instant>,
    /// When execution finished, successfully or not
    pub completed_at: Option<Instant>,
    /// Error message if the agent failed
    pub error: Option<String>,
}
681
682impl AgentContext {
683    /// Create new agent context
684    pub fn new(agent_id: String, agent_type: AgentType, model: ClaudeModel) -> Self {
685        Self {
686            agent_id,
687            agent_type,
688            model,
689            state: AgentState::Idle,
690            context_tokens: 0,
691            total_tokens_used: 0,
692            total_cost: 0.0,
693            started_at: None,
694            completed_at: None,
695            error: None,
696        }
697    }
698
699    /// Start execution
700    pub fn start(&mut self) {
701        self.state = AgentState::Running;
702        self.started_at = Some(Instant::now());
703    }
704
705    /// Mark as blocked
706    pub fn block(&mut self) {
707        self.state = AgentState::Blocked;
708    }
709
710    /// Complete execution
711    pub fn complete(&mut self, usage: &UsageStats) {
712        self.state = AgentState::Completed;
713        self.completed_at = Some(Instant::now());
714        self.total_tokens_used += usage.input_tokens + usage.output_tokens;
715        self.total_cost += usage.calculate_cost(self.model);
716    }
717
718    /// Fail execution
719    pub fn fail(&mut self, error: String) {
720        self.state = AgentState::Failed;
721        self.completed_at = Some(Instant::now());
722        self.error = Some(error);
723    }
724
725    /// Get execution duration
726    pub fn duration(&self) -> Option<Duration> {
727        match (self.started_at, self.completed_at) {
728            (Some(start), Some(end)) => Some(end.duration_since(start)),
729            (Some(start), None) => Some(start.elapsed()),
730            _ => None,
731        }
732    }
733}
734
/// One unit of work in a workflow, with ordering constraints.
#[derive(Debug, Clone)]
pub struct WorkflowStep {
    /// Unique step identifier within the workflow
    pub step_id: String,
    /// Agent type to execute step
    pub agent_type: AgentType,
    /// Task description
    pub task: String,
    /// Step IDs that must complete before this step may run
    pub dependencies: Vec<String>,
    /// Pinned model, or None for the coordinator's default
    pub required_model: Option<ClaudeModel>,
    /// Maximum retries
    /// NOTE(review): not yet consulted by `execute_workflow`
    pub max_retries: u32,
}
751
/// Aggregate result of an entire workflow run.
#[derive(Debug, Clone)]
pub struct WorkflowResult {
    /// Workflow identifier
    pub workflow_id: String,
    /// Per-step results, keyed by step ID
    pub step_results: HashMap<String, StepResult>,
    /// Total wall-clock execution time
    pub total_duration: Duration,
    /// Total tokens used across all steps
    pub total_tokens: usize,
    /// Total cost (USD) across all steps
    pub total_cost: f64,
    /// Success status
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}
770
/// Result of a single workflow step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Step identifier
    pub step_id: String,
    /// Agent that executed the step
    pub agent_id: String,
    /// Model used
    pub model: ClaudeModel,
    /// Response content, if any
    pub response: Option<String>,
    /// Wall-clock execution duration
    pub duration: Duration,
    /// Tokens used
    pub tokens_used: usize,
    /// Cost incurred (USD)
    pub cost: f64,
    /// Success status
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}
793
/// Multi-agent coordinator: spawns agents, tracks their contexts, and runs
/// dependency-ordered workflows.
pub struct AgentCoordinator {
    /// Agent contexts keyed by agent ID, shared behind a RwLock
    agents: Arc<RwLock<HashMap<String, AgentContext>>>,
    /// Per-agent context windows
    context_manager: Arc<RwLock<ContextManager>>,
    /// Model assigned to agents unless a step pins one
    default_model: ClaudeModel,
    /// Hard cap on simultaneously existing agents
    max_concurrent: usize,
    /// Total workflows executed
    workflows_executed: u64,
    /// Cumulative cost (USD) across all workflows
    total_cost: f64,
}
809
impl AgentCoordinator {
    /// Create a coordinator with the given default model and agent cap.
    pub fn new(default_model: ClaudeModel, max_concurrent: usize) -> Self {
        Self {
            agents: Arc::new(RwLock::new(HashMap::new())),
            // Each agent gets a default 100K-token context window.
            context_manager: Arc::new(RwLock::new(ContextManager::new(100_000))),
            default_model,
            max_concurrent,
            workflows_executed: 0,
            total_cost: 0.0,
        }
    }

    /// Spawn a new agent with the coordinator's default model.
    ///
    /// # Errors
    /// - `OutOfMemory` if `max_concurrent` agents already exist
    ///   (NOTE(review): this variant is reused for a capacity limit; a
    ///   dedicated variant may be clearer).
    /// - `InvalidOperation` if an agent with this ID already exists.
    pub fn spawn_agent(&self, agent_id: String, agent_type: AgentType) -> Result<()> {
        let mut agents = self.agents.write();

        if agents.len() >= self.max_concurrent {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Maximum concurrent agents ({}) reached",
                self.max_concurrent
            )));
        }

        if agents.contains_key(&agent_id) {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Agent {} already exists",
                agent_id
            )));
        }

        let context = AgentContext::new(agent_id.clone(), agent_type, self.default_model);
        agents.insert(agent_id, context);

        Ok(())
    }

    /// Get a snapshot (clone) of an agent's context, if it exists.
    pub fn get_agent(&self, agent_id: &str) -> Option<AgentContext> {
        self.agents.read().get(agent_id).cloned()
    }

    /// Update an agent's context via a closure run under the write lock.
    ///
    /// # Errors
    /// Returns `NotFound` if no agent has this ID.
    pub fn update_agent<F>(&self, agent_id: &str, f: F) -> Result<()>
    where
        F: FnOnce(&mut AgentContext),
    {
        let mut agents = self.agents.write();
        let agent = agents
            .get_mut(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;
        f(agent);
        Ok(())
    }

    /// Remove an agent and its context window.
    ///
    /// # Errors
    /// Returns `NotFound` if no agent has this ID.
    pub fn terminate_agent(&self, agent_id: &str) -> Result<()> {
        let mut agents = self.agents.write();
        agents
            .remove(agent_id)
            .ok_or_else(|| RuvLLMError::NotFound(format!("Agent {} not found", agent_id)))?;

        // Clean up context window
        self.context_manager.write().remove_window(agent_id);

        Ok(())
    }

    /// Number of agents currently in the `Running` state.
    pub fn active_agent_count(&self) -> usize {
        self.agents
            .read()
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count()
    }

    /// Number of agents currently tracked, in any state.
    pub fn total_agent_count(&self) -> usize {
        self.agents.read().len()
    }

    /// Execute a workflow with dependency resolution.
    ///
    /// Repeatedly runs every step whose dependencies are all complete until
    /// no steps remain; a round where nothing is ready means the dependency
    /// graph cannot make progress (circular or unknown dependency IDs).
    ///
    /// NOTE(review): step execution is currently simulated with mock results
    /// (fixed 500 tokens / $0.001 per step) rather than real API calls;
    /// ready steps run sequentially despite being batched; and
    /// `WorkflowStep::max_retries` is not consulted.
    ///
    /// # Errors
    /// Returns `InvalidOperation` when the graph is stuck, and propagates
    /// agent spawn/update/terminate failures.
    pub async fn execute_workflow(
        &mut self,
        workflow_id: String,
        steps: Vec<WorkflowStep>,
    ) -> Result<WorkflowResult> {
        let start_time = Instant::now();
        let mut step_results: HashMap<String, StepResult> = HashMap::new();
        let mut completed_steps: std::collections::HashSet<String> =
            std::collections::HashSet::new();

        // All steps start pending; completion is tracked by step ID.
        let mut pending_steps: Vec<&WorkflowStep> = steps.iter().collect();

        while !pending_steps.is_empty() {
            // Find steps whose dependencies are all satisfied.
            let ready_steps: Vec<_> = pending_steps
                .iter()
                .filter(|step| {
                    step.dependencies
                        .iter()
                        .all(|dep| completed_steps.contains(dep))
                })
                .cloned()
                .collect();

            if ready_steps.is_empty() && !pending_steps.is_empty() {
                return Err(RuvLLMError::InvalidOperation(
                    "Workflow has circular dependencies".to_string(),
                ));
            }

            // Execute ready steps one at a time (sequentially, despite the
            // batch; agent IDs are namespaced by workflow).
            for step in ready_steps {
                let agent_id = format!("{}-{}", workflow_id, step.step_id);
                let model = step.required_model.unwrap_or(self.default_model);

                // Spawn a short-lived agent dedicated to this step.
                self.spawn_agent(agent_id.clone(), step.agent_type)?;
                self.update_agent(&agent_id, |a| a.start())?;

                // Simulate execution (in production, would call Claude API)
                let step_start = Instant::now();

                // Create mock result
                let result = StepResult {
                    step_id: step.step_id.clone(),
                    agent_id: agent_id.clone(),
                    model,
                    response: Some(format!("Completed: {}", step.task)),
                    duration: step_start.elapsed(),
                    tokens_used: 500, // Mock value
                    cost: 0.001,      // Mock value
                    success: true,
                    error: None,
                };

                // Record mock usage on the agent before retiring it.
                self.update_agent(&agent_id, |a| {
                    let usage = UsageStats {
                        input_tokens: 250,
                        output_tokens: 250,
                    };
                    a.complete(&usage);
                })?;

                step_results.insert(step.step_id.clone(), result);
                completed_steps.insert(step.step_id.clone());

                // Clean up agent (also frees its context window).
                self.terminate_agent(&agent_id)?;
            }

            // Remove completed steps from pending
            pending_steps.retain(|step| !completed_steps.contains(&step.step_id));
        }

        // Aggregate per-step token and cost totals.
        let total_tokens: usize = step_results.values().map(|r| r.tokens_used).sum();
        let total_cost: f64 = step_results.values().map(|r| r.cost).sum();

        self.workflows_executed += 1;
        self.total_cost += total_cost;

        Ok(WorkflowResult {
            workflow_id,
            step_results,
            total_duration: start_time.elapsed(),
            total_tokens,
            total_cost,
            // Mock execution never fails a step, so success is unconditional.
            success: true,
            error: None,
        })
    }

    /// Snapshot of per-state agent counts and cumulative totals.
    pub fn stats(&self) -> CoordinatorStats {
        let agents = self.agents.read();
        let active_count = agents
            .values()
            .filter(|a| a.state == AgentState::Running)
            .count();
        let total_tokens: usize = agents.values().map(|a| a.total_tokens_used).sum();

        CoordinatorStats {
            total_agents: agents.len(),
            active_agents: active_count,
            blocked_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Blocked)
                .count(),
            completed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Completed)
                .count(),
            failed_agents: agents
                .values()
                .filter(|a| a.state == AgentState::Failed)
                .count(),
            workflows_executed: self.workflows_executed,
            total_tokens_used: total_tokens,
            total_cost: self.total_cost,
        }
    }
}
1016
/// Coordinator statistics snapshot.
///
/// All counters reflect the agent map at the moment the snapshot was taken;
/// terminated agents are removed from the map and therefore not counted.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CoordinatorStats {
    /// Agents currently tracked (terminated agents are excluded)
    pub total_agents: usize,
    /// Currently active (running) agents
    pub active_agents: usize,
    /// Blocked agents
    pub blocked_agents: usize,
    /// Completed agents
    pub completed_agents: usize,
    /// Failed agents
    pub failed_agents: usize,
    /// Total workflows executed by the coordinator
    pub workflows_executed: u64,
    /// Total tokens used across tracked agents
    pub total_tokens_used: usize,
    /// Total cost incurred (dollars)
    pub total_cost: f64,
}
1037
1038// ============================================================================
1039// Cost Estimation
1040// ============================================================================
1041
/// Cost estimator for Claude API usage
///
/// Accumulates input/output token totals per model so that total and
/// per-model dollar costs can be derived from each model's pricing.
pub struct CostEstimator {
    /// Cumulative usage keyed by model
    usage_by_model: HashMap<ClaudeModel, UsageStats>,
}
1047
1048impl CostEstimator {
1049    /// Create new cost estimator
1050    pub fn new() -> Self {
1051        Self {
1052            usage_by_model: HashMap::new(),
1053        }
1054    }
1055
1056    /// Estimate cost for a request
1057    pub fn estimate_request_cost(
1058        &self,
1059        model: ClaudeModel,
1060        input_tokens: usize,
1061        expected_output_tokens: usize,
1062    ) -> f64 {
1063        let input_cost = (input_tokens as f64 / 1000.0) * model.input_cost_per_1k();
1064        let output_cost = (expected_output_tokens as f64 / 1000.0) * model.output_cost_per_1k();
1065        input_cost + output_cost
1066    }
1067
1068    /// Record actual usage
1069    pub fn record_usage(&mut self, model: ClaudeModel, usage: &UsageStats) {
1070        let entry = self
1071            .usage_by_model
1072            .entry(model)
1073            .or_insert(UsageStats::default());
1074        entry.input_tokens += usage.input_tokens;
1075        entry.output_tokens += usage.output_tokens;
1076    }
1077
1078    /// Get total cost to date
1079    pub fn total_cost(&self) -> f64 {
1080        self.usage_by_model
1081            .iter()
1082            .map(|(model, usage)| usage.calculate_cost(*model))
1083            .sum()
1084    }
1085
1086    /// Get cost breakdown by model
1087    pub fn cost_breakdown(&self) -> HashMap<ClaudeModel, f64> {
1088        self.usage_by_model
1089            .iter()
1090            .map(|(model, usage)| (*model, usage.calculate_cost(*model)))
1091            .collect()
1092    }
1093
1094    /// Get total usage by model
1095    pub fn usage_by_model(&self) -> &HashMap<ClaudeModel, UsageStats> {
1096        &self.usage_by_model
1097    }
1098}
1099
impl Default for CostEstimator {
    /// Equivalent to [`CostEstimator::new`]: an estimator with no usage.
    fn default() -> Self {
        Self::new()
    }
}
1105
1106// ============================================================================
1107// Latency Tracking
1108// ============================================================================
1109
/// Latency tracker for performance monitoring
///
/// Keeps a bounded sliding window of the most recent latency samples per
/// model; the oldest sample is evicted once `max_samples` is exceeded.
pub struct LatencyTracker {
    /// Samples by model, oldest first (insertion order)
    samples: HashMap<ClaudeModel, Vec<LatencySample>>,
    /// Maximum samples to keep per model before the oldest is dropped
    max_samples: usize,
}
1117
/// Single latency sample for one completed response
#[derive(Debug, Clone)]
pub struct LatencySample {
    /// Time to first token (ms)
    pub ttft_ms: u64,
    /// Total response time (ms), inclusive of the time to first token
    /// (generation time is derived as `total_ms - ttft_ms`)
    pub total_ms: u64,
    /// Input tokens in the request
    pub input_tokens: usize,
    /// Output tokens generated
    pub output_tokens: usize,
    /// When the sample was taken
    pub timestamp: Instant,
}
1132
1133impl LatencyTracker {
1134    /// Create new latency tracker
1135    pub fn new(max_samples: usize) -> Self {
1136        Self {
1137            samples: HashMap::new(),
1138            max_samples,
1139        }
1140    }
1141
1142    /// Record latency sample
1143    pub fn record(&mut self, model: ClaudeModel, sample: LatencySample) {
1144        let samples = self.samples.entry(model).or_insert_with(Vec::new);
1145        samples.push(sample);
1146
1147        // Trim old samples
1148        if samples.len() > self.max_samples {
1149            samples.remove(0);
1150        }
1151    }
1152
1153    /// Get average TTFT for model
1154    pub fn average_ttft(&self, model: ClaudeModel) -> Option<f64> {
1155        self.samples.get(&model).map(|samples| {
1156            if samples.is_empty() {
1157                return 0.0;
1158            }
1159            let sum: u64 = samples.iter().map(|s| s.ttft_ms).sum();
1160            sum as f64 / samples.len() as f64
1161        })
1162    }
1163
1164    /// Get p95 TTFT for model
1165    pub fn p95_ttft(&self, model: ClaudeModel) -> Option<u64> {
1166        self.samples.get(&model).and_then(|samples| {
1167            if samples.is_empty() {
1168                return None;
1169            }
1170            let mut ttfts: Vec<u64> = samples.iter().map(|s| s.ttft_ms).collect();
1171            ttfts.sort();
1172            let idx = (ttfts.len() as f64 * 0.95) as usize;
1173            ttfts.get(idx.min(ttfts.len() - 1)).copied()
1174        })
1175    }
1176
1177    /// Get average tokens per second for model
1178    pub fn average_tokens_per_second(&self, model: ClaudeModel) -> Option<f64> {
1179        self.samples.get(&model).map(|samples| {
1180            if samples.is_empty() {
1181                return 0.0;
1182            }
1183            let total_tokens: usize = samples.iter().map(|s| s.output_tokens).sum();
1184            let total_time_ms: u64 = samples.iter().map(|s| s.total_ms - s.ttft_ms).sum();
1185            if total_time_ms == 0 {
1186                return 0.0;
1187            }
1188            total_tokens as f64 / (total_time_ms as f64 / 1000.0)
1189        })
1190    }
1191
1192    /// Get statistics for model
1193    pub fn get_stats(&self, model: ClaudeModel) -> Option<LatencyStats> {
1194        self.samples.get(&model).map(|samples| LatencyStats {
1195            sample_count: samples.len(),
1196            avg_ttft_ms: self.average_ttft(model).unwrap_or(0.0),
1197            p95_ttft_ms: self.p95_ttft(model).unwrap_or(0),
1198            avg_tokens_per_second: self.average_tokens_per_second(model).unwrap_or(0.0),
1199        })
1200    }
1201}
1202
/// Latency statistics aggregated over a model's recorded samples
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Number of samples the statistics were computed from
    pub sample_count: usize,
    /// Average time to first token (ms)
    pub avg_ttft_ms: f64,
    /// P95 time to first token (ms)
    pub p95_ttft_ms: u64,
    /// Average output tokens per second during generation
    pub avg_tokens_per_second: f64,
}
1215
1216// ============================================================================
1217// Tests
1218// ============================================================================
1219
#[cfg(test)]
mod tests {
    use super::*;

    // Pricing must be strictly ordered Haiku < Sonnet < Opus for the same
    // token usage.
    #[test]
    fn test_claude_model_costs() {
        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        let haiku_cost = usage.calculate_cost(ClaudeModel::Haiku);
        let sonnet_cost = usage.calculate_cost(ClaudeModel::Sonnet);
        let opus_cost = usage.calculate_cost(ClaudeModel::Opus);

        assert!(haiku_cost < sonnet_cost);
        assert!(sonnet_cost < opus_cost);
    }

    // Adding more messages than the token budget allows must trigger
    // compression so the window never exceeds its limit.
    #[test]
    fn test_context_window_compression() {
        let mut window = ContextWindow::new(1000);

        // Add many messages
        for i in 0..20 {
            window.add_message(Message::user(format!(
                "Message {} with some content to add tokens",
                i
            )));
        }

        // Window should have compressed
        assert!(window.token_count() <= 1000);
    }

    // Token estimation should be positive and roughly proportional to the
    // short input (a sanity bound, not an exact count).
    #[test]
    fn test_message_token_estimation() {
        let msg = Message::user("Hello, this is a test message with some content.");
        let tokens = msg.estimate_tokens();
        assert!(tokens > 0);
        assert!(tokens < 100); // Should be reasonable estimate
    }

    // Scores above the 0.6 threshold allow streaming to continue; scores
    // below it stop the stream.
    #[test]
    fn test_quality_monitor() {
        let mut monitor = QualityMonitor::new(0.6, 10);

        // Add good quality scores
        for _ in 0..5 {
            monitor.record(0.8);
        }
        assert!(monitor.should_continue());

        // Add bad quality scores
        let mut bad_monitor = QualityMonitor::new(0.6, 10);
        for _ in 0..5 {
            bad_monitor.record(0.3);
        }
        assert!(!bad_monitor.should_continue());
    }

    // Spawn/start/terminate lifecycle: terminating an agent removes it
    // from the coordinator's counts entirely.
    #[test]
    fn test_agent_coordinator() {
        let coordinator = AgentCoordinator::new(ClaudeModel::Sonnet, 10);

        coordinator
            .spawn_agent("agent-1".to_string(), AgentType::Coder)
            .unwrap();
        coordinator
            .spawn_agent("agent-2".to_string(), AgentType::Researcher)
            .unwrap();

        assert_eq!(coordinator.total_agent_count(), 2);

        coordinator.update_agent("agent-1", |a| a.start()).unwrap();
        assert_eq!(coordinator.active_agent_count(), 1);

        coordinator.terminate_agent("agent-1").unwrap();
        assert_eq!(coordinator.total_agent_count(), 1);
    }

    // Recorded usage must yield a positive total and a breakdown entry per
    // model that was recorded.
    #[test]
    fn test_cost_estimator() {
        let mut estimator = CostEstimator::new();

        let usage = UsageStats {
            input_tokens: 1000,
            output_tokens: 500,
        };

        estimator.record_usage(ClaudeModel::Sonnet, &usage);
        estimator.record_usage(ClaudeModel::Haiku, &usage);

        let total = estimator.total_cost();
        assert!(total > 0.0);

        let breakdown = estimator.cost_breakdown();
        assert!(breakdown.contains_key(&ClaudeModel::Sonnet));
        assert!(breakdown.contains_key(&ClaudeModel::Haiku));
    }

    // TTFT samples 400..490ms average to 445ms (> 400), and generation
    // time is nonzero, so tokens/sec must be positive.
    #[test]
    fn test_latency_tracker() {
        let mut tracker = LatencyTracker::new(100);

        for i in 0..10 {
            tracker.record(
                ClaudeModel::Sonnet,
                LatencySample {
                    ttft_ms: 400 + i * 10,
                    total_ms: 1000 + i * 100,
                    input_tokens: 500,
                    output_tokens: 200,
                    timestamp: Instant::now(),
                },
            );
        }

        let stats = tracker.get_stats(ClaudeModel::Sonnet).unwrap();
        assert_eq!(stats.sample_count, 10);
        assert!(stats.avg_ttft_ms > 400.0);
        assert!(stats.avg_tokens_per_second > 0.0);
    }
}