steer_core/api/claude/client.rs

1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5use strum_macros::Display;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, warn};
8
9use crate::api::{CompletionResponse, Provider, error::ApiError};
10use crate::app::conversation::{
11    AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
12};
13use crate::auth::{
14    AuthFlowWrapper, AuthStorage, DynAuthenticationFlow, InteractiveAuth,
15    anthropic::{AnthropicOAuth, AnthropicOAuthFlow, refresh_if_needed},
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
/// Endpoint of the Anthropic Messages API; every completion request is POSTed here.
const API_URL: &str = "https://api.anthropic.com/v1/messages";
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Display)]
23pub enum ClaudeMessageRole {
24    #[serde(rename = "user")]
25    #[strum(serialize = "user")]
26    User,
27    #[serde(rename = "assistant")]
28    #[strum(serialize = "assistant")]
29    Assistant,
30    #[serde(rename = "tool")]
31    #[strum(serialize = "tool")]
32    Tool,
33}
34
35/// Represents a message to be sent to the Claude API
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub struct ClaudeMessage {
38    pub role: ClaudeMessageRole,
39    #[serde(flatten)]
40    pub content: ClaudeMessageContent,
41    #[serde(skip_serializing)]
42    pub id: Option<String>,
43}
44
45/// Content types for Claude API messages
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
47#[serde(untagged)]
48pub enum ClaudeMessageContent {
49    /// Simple text content
50    Text { content: String },
51    /// Structured content for tool results or other special content
52    StructuredContent { content: ClaudeStructuredContent },
53}
54
55/// Represents structured content blocks for messages
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57#[serde(transparent)]
58pub struct ClaudeStructuredContent(pub Vec<ClaudeContentBlock>);
59
60#[derive(Clone)]
61pub enum AuthMethod {
62    ApiKey(String),
63    OAuth(Arc<dyn AuthStorage>),
64}
65
66#[derive(Clone)]
67pub struct AnthropicClient {
68    http_client: reqwest::Client,
69    auth: AuthMethod,
70}
71
72#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
73#[serde(rename_all = "lowercase")]
74enum ThinkingType {
75    Enabled,
76}
77
78impl Default for ThinkingType {
79    fn default() -> Self {
80        Self::Enabled
81    }
82}
83
84#[derive(Debug, Serialize, Deserialize, Clone)]
85struct Thinking {
86    #[serde(rename = "type", default)]
87    thinking_type: ThinkingType,
88    budget_tokens: u32,
89}
90
91#[derive(Debug, Serialize, Clone)]
92struct SystemContentBlock {
93    #[serde(rename = "type")]
94    content_type: String,
95    text: String,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    cache_control: Option<CacheControl>,
98}
99
100#[derive(Debug, Serialize, Clone)]
101#[serde(untagged)]
102enum System {
103    // Structured system prompt represented as a list of content blocks
104    Content(Vec<SystemContentBlock>),
105}
106
107#[derive(Debug, Serialize)]
108struct CompletionRequest {
109    model: String,
110    messages: Vec<ClaudeMessage>,
111    max_tokens: usize,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    system: Option<System>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    tools: Option<Vec<ToolSchema>>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    temperature: Option<f32>,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    top_p: Option<f32>,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    top_k: Option<usize>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    stream: Option<bool>,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    thinking: Option<Thinking>,
126}
127
128#[derive(Debug, Serialize, Deserialize, Clone)]
129struct ClaudeCompletionResponse {
130    id: String,
131    content: Vec<ClaudeContentBlock>,
132    model: String,
133    role: String,
134    #[serde(default)]
135    stop_reason: Option<String>,
136    #[serde(default)]
137    stop_sequence: Option<String>,
138    #[serde(default)]
139    usage: ClaudeUsage,
140    // Allow other fields for API flexibility
141    #[serde(flatten)]
142    extra: std::collections::HashMap<String, serde_json::Value>,
143}
144
/// Serde default for [`CacheControl`]'s `type` field: the ephemeral cache type.
fn default_cache_type() -> String {
    String::from("ephemeral")
}
148
149#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
150pub struct CacheControl {
151    #[serde(rename = "type", default = "default_cache_type")]
152    cache_type: String,
153}
154
155#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
156#[serde(tag = "type")]
157pub enum ClaudeContentBlock {
158    #[serde(rename = "text")]
159    Text {
160        text: String,
161        #[serde(skip_serializing_if = "Option::is_none")]
162        cache_control: Option<CacheControl>,
163        #[serde(flatten)]
164        extra: std::collections::HashMap<String, serde_json::Value>,
165    },
166    #[serde(rename = "tool_use")]
167    ToolUse {
168        id: String,
169        name: String,
170        input: serde_json::Value,
171        #[serde(skip_serializing_if = "Option::is_none")]
172        cache_control: Option<CacheControl>,
173        #[serde(flatten)]
174        extra: std::collections::HashMap<String, serde_json::Value>,
175    },
176    #[serde(rename = "tool_result")]
177    ToolResult {
178        tool_use_id: String,
179        content: Vec<ClaudeContentBlock>,
180        #[serde(skip_serializing_if = "Option::is_none")]
181        cache_control: Option<CacheControl>,
182        #[serde(skip_serializing_if = "Option::is_none")]
183        is_error: Option<bool>,
184        #[serde(flatten)]
185        extra: std::collections::HashMap<String, serde_json::Value>,
186    },
187    #[serde(rename = "thinking")]
188    Thinking {
189        thinking: String,
190        signature: String,
191        #[serde(skip_serializing_if = "Option::is_none")]
192        cache_control: Option<CacheControl>,
193        #[serde(flatten)]
194        extra: std::collections::HashMap<String, serde_json::Value>,
195    },
196    #[serde(rename = "redacted_thinking")]
197    RedactedThinking {
198        data: String,
199        #[serde(skip_serializing_if = "Option::is_none")]
200        cache_control: Option<CacheControl>,
201        #[serde(flatten)]
202        extra: std::collections::HashMap<String, serde_json::Value>,
203    },
204    #[serde(other)]
205    Unknown,
206}
207
208#[derive(Debug, Serialize, Deserialize, Default, Clone)]
209struct ClaudeUsage {
210    input_tokens: usize,
211    output_tokens: usize,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    cache_creation_input_tokens: Option<usize>,
214    #[serde(skip_serializing_if = "Option::is_none")]
215    cache_read_input_tokens: Option<usize>,
216}
217
218impl AnthropicClient {
219    pub fn new(api_key: &str) -> Self {
220        Self::with_api_key(api_key)
221    }
222
223    pub fn with_api_key(api_key: &str) -> Self {
224        let mut headers = header::HeaderMap::new();
225        headers.insert("x-api-key", header::HeaderValue::from_str(api_key).unwrap());
226        headers.insert(
227            "anthropic-version",
228            header::HeaderValue::from_static("2023-06-01"),
229        );
230        headers.insert(
231            header::CONTENT_TYPE,
232            header::HeaderValue::from_static("application/json"),
233        );
234
235        let client = reqwest::Client::builder()
236            .default_headers(headers)
237            .build()
238            .expect("Failed to build HTTP client");
239
240        Self {
241            http_client: client,
242            auth: AuthMethod::ApiKey(api_key.to_string()),
243        }
244    }
245
246    pub fn with_oauth(storage: Arc<dyn AuthStorage>) -> Self {
247        // For OAuth, we don't set default headers since they're dynamic
248        let mut headers = header::HeaderMap::new();
249        headers.insert(
250            "anthropic-version",
251            header::HeaderValue::from_static("2023-06-01"),
252        );
253        headers.insert(
254            header::CONTENT_TYPE,
255            header::HeaderValue::from_static("application/json"),
256        );
257
258        let client = reqwest::Client::builder()
259            .default_headers(headers)
260            .build()
261            .expect("Failed to build HTTP client");
262
263        Self {
264            http_client: client,
265            auth: AuthMethod::OAuth(storage),
266        }
267    }
268
269    async fn get_auth_headers(&self) -> Result<Vec<(String, String)>, ApiError> {
270        match &self.auth {
271            AuthMethod::ApiKey(key) => Ok(vec![("x-api-key".to_string(), key.clone())]),
272            AuthMethod::OAuth(storage) => {
273                let oauth_client = AnthropicOAuth::new();
274                let tokens = refresh_if_needed(storage, &oauth_client)
275                    .await
276                    .map_err(|e| ApiError::AuthError(e.to_string()))?;
277                Ok(crate::auth::anthropic::get_oauth_headers(
278                    &tokens.access_token,
279                ))
280            }
281        }
282    }
283}
284
285// Conversion functions start
286fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<ClaudeMessage>, ApiError> {
287    let claude_messages: Result<Vec<ClaudeMessage>, ApiError> =
288        messages.into_iter().map(convert_single_message).collect();
289
290    // Filter out any User messages that have empty content after removing app commands
291    claude_messages.map(|messages| {
292        messages
293            .into_iter()
294            .filter(|msg| {
295                match &msg.content {
296                    ClaudeMessageContent::Text { content } => !content.trim().is_empty(),
297                    _ => true, // Keep all non-text messages
298                }
299            })
300            .collect()
301    })
302}
303
304fn convert_single_message(msg: AppMessage) -> Result<ClaudeMessage, ApiError> {
305    match &msg.data {
306        crate::app::conversation::MessageData::User { content, .. } => {
307            // Convert UserContent to Claude format
308            let combined_text = content
309                .iter()
310                .filter_map(|user_content| match user_content {
311                    UserContent::Text { text } => Some(text.clone()),
312                    UserContent::CommandExecution {
313                        command,
314                        stdout,
315                        stderr,
316                        exit_code,
317                    } => Some(UserContent::format_command_execution_as_xml(
318                        command, stdout, stderr, *exit_code,
319                    )),
320                    UserContent::AppCommand { .. } => {
321                        // Don't send app commands to the model - they're for local execution only
322                        None
323                    }
324                })
325                .collect::<Vec<_>>()
326                .join("\n");
327
328            Ok(ClaudeMessage {
329                role: ClaudeMessageRole::User,
330                content: ClaudeMessageContent::Text {
331                    content: combined_text,
332                },
333                id: Some(msg.id.clone()),
334            })
335        }
336        crate::app::conversation::MessageData::Assistant { content, .. } => {
337            // Convert AssistantContent to Claude blocks
338            let claude_blocks: Vec<ClaudeContentBlock> = content
339                .iter()
340                .filter_map(|assistant_content| match assistant_content {
341                    AssistantContent::Text { text } => {
342                        if text.trim().is_empty() {
343                            None
344                        } else {
345                            Some(ClaudeContentBlock::Text {
346                                text: text.clone(),
347                                cache_control: None,
348                                extra: Default::default(),
349                            })
350                        }
351                    }
352                    AssistantContent::ToolCall { tool_call } => Some(ClaudeContentBlock::ToolUse {
353                        id: tool_call.id.clone(),
354                        name: tool_call.name.clone(),
355                        input: tool_call.parameters.clone(),
356                        cache_control: None,
357                        extra: Default::default(),
358                    }),
359                    AssistantContent::Thought { thought } => {
360                        match thought {
361                            ThoughtContent::Signed { text, signature } => {
362                                Some(ClaudeContentBlock::Thinking {
363                                    thinking: text.clone(),
364                                    signature: signature.clone(),
365                                    cache_control: None,
366                                    extra: Default::default(),
367                                })
368                            }
369                            ThoughtContent::Redacted { data } => {
370                                Some(ClaudeContentBlock::RedactedThinking {
371                                    data: data.clone(),
372                                    cache_control: None,
373                                    extra: Default::default(),
374                                })
375                            }
376                            ThoughtContent::Simple { text } => {
377                                // Claude doesn't have a simple thought type, convert to text
378                                Some(ClaudeContentBlock::Text {
379                                    text: format!("[Thought: {text}]"),
380                                    cache_control: None,
381                                    extra: Default::default(),
382                                })
383                            }
384                        }
385                    }
386                })
387                .collect();
388
389            if !claude_blocks.is_empty() {
390                let claude_content = if claude_blocks.len() == 1 {
391                    if let Some(ClaudeContentBlock::Text { text, .. }) = claude_blocks.first() {
392                        ClaudeMessageContent::Text {
393                            content: text.clone(),
394                        }
395                    } else {
396                        ClaudeMessageContent::StructuredContent {
397                            content: ClaudeStructuredContent(claude_blocks),
398                        }
399                    }
400                } else {
401                    ClaudeMessageContent::StructuredContent {
402                        content: ClaudeStructuredContent(claude_blocks),
403                    }
404                };
405
406                Ok(ClaudeMessage {
407                    role: ClaudeMessageRole::Assistant,
408                    content: claude_content,
409                    id: Some(msg.id.clone()),
410                })
411            } else {
412                debug!("No content blocks found: {:?}", content);
413                Err(ApiError::InvalidRequest {
414                    provider: "anthropic".to_string(),
415                    details: format!(
416                        "Assistant message ID {} resulted in no valid content blocks",
417                        msg.id
418                    ),
419                })
420            }
421        }
422        crate::app::conversation::MessageData::Tool {
423            tool_use_id,
424            result,
425            ..
426        } => {
427            // Convert ToolResult to Claude format
428            // Claude expects tool results as User messages
429            let (result_text, is_error) = match result {
430                ToolResult::Error(e) => (e.to_string(), Some(true)),
431                _ => {
432                    // For all other variants, use llm_format
433                    let text = result.llm_format();
434                    let text = if text.trim().is_empty() {
435                        "(No output)".to_string()
436                    } else {
437                        text
438                    };
439                    (text, None)
440                }
441            };
442
443            let claude_blocks = vec![ClaudeContentBlock::ToolResult {
444                tool_use_id: tool_use_id.clone(),
445                content: vec![ClaudeContentBlock::Text {
446                    text: result_text,
447                    cache_control: None,
448                    extra: Default::default(),
449                }],
450                is_error,
451                cache_control: None,
452                extra: Default::default(),
453            }];
454
455            Ok(ClaudeMessage {
456                role: ClaudeMessageRole::User, // Tool results are sent as User messages in Claude
457                content: ClaudeMessageContent::StructuredContent {
458                    content: ClaudeStructuredContent(claude_blocks),
459                },
460                id: Some(msg.id.clone()),
461            })
462        }
463    }
464}
465// Conversion functions end
466
467// Convert Claude's content blocks to our provider-agnostic format
468fn convert_claude_content(claude_blocks: Vec<ClaudeContentBlock>) -> Vec<AssistantContent> {
469    claude_blocks
470        .into_iter()
471        .filter_map(|block| match block {
472            ClaudeContentBlock::Text { text, .. } => Some(AssistantContent::Text { text }),
473            ClaudeContentBlock::ToolUse {
474                id, name, input, ..
475            } => Some(AssistantContent::ToolCall {
476                tool_call: steer_tools::ToolCall {
477                    id,
478                    name,
479                    parameters: input,
480                },
481            }),
482            ClaudeContentBlock::ToolResult { .. } => {
483                warn!("Unexpected ToolResult block received in Claude response content");
484                None
485            }
486            ClaudeContentBlock::Thinking {
487                thinking,
488                signature,
489                ..
490            } => Some(AssistantContent::Thought {
491                thought: ThoughtContent::Signed {
492                    text: thinking,
493                    signature,
494                },
495            }),
496            ClaudeContentBlock::RedactedThinking { data, .. } => Some(AssistantContent::Thought {
497                thought: ThoughtContent::Redacted { data },
498            }),
499            ClaudeContentBlock::Unknown => {
500                warn!("Unknown content block received in Claude response content");
501                None
502            }
503        })
504        .collect()
505}
506
507#[async_trait]
508impl Provider for AnthropicClient {
509    fn name(&self) -> &'static str {
510        "anthropic"
511    }
512
513    async fn complete(
514        &self,
515        model_id: &ModelId,
516        messages: Vec<AppMessage>,
517        system: Option<String>,
518        tools: Option<Vec<ToolSchema>>,
519        call_options: Option<ModelParameters>,
520        token: CancellationToken,
521    ) -> Result<CompletionResponse, ApiError> {
522        let mut claude_messages = convert_messages(messages)?;
523
524        if claude_messages.is_empty() {
525            return Err(ApiError::InvalidRequest {
526                provider: self.name().to_string(),
527                details: "No messages provided".to_string(),
528            });
529        }
530
531        let last_message = claude_messages.last_mut().unwrap();
532        let cache_setting = Some(CacheControl {
533            cache_type: "ephemeral".to_string(),
534        });
535
536        let system_content = match (system, &self.auth) {
537            (Some(sys), AuthMethod::ApiKey(_)) => Some(System::Content(vec![SystemContentBlock {
538                content_type: "text".to_string(),
539                text: sys,
540                cache_control: cache_setting.clone(),
541            }])),
542            (Some(sys), AuthMethod::OAuth(_)) => Some(System::Content(vec![
543                SystemContentBlock {
544                    content_type: "text".to_string(),
545                    text: "You are Claude Code, Anthropic's official CLI for Claude.".to_string(),
546                    cache_control: cache_setting.clone(),
547                },
548                SystemContentBlock {
549                    content_type: "text".to_string(),
550                    text: sys,
551                    cache_control: cache_setting.clone(),
552                },
553            ])),
554            (None, AuthMethod::ApiKey(_)) => None,
555            (None, AuthMethod::OAuth(_)) => Some(System::Content(vec![SystemContentBlock {
556                content_type: "text".to_string(),
557                text: "You are Claude Code, Anthropic's official CLI for Claude.".to_string(),
558                cache_control: cache_setting.clone(),
559            }])),
560        };
561
562        match &mut last_message.content {
563            ClaudeMessageContent::StructuredContent { content } => {
564                for block in content.0.iter_mut() {
565                    if let ClaudeContentBlock::ToolResult { cache_control, .. } = block {
566                        *cache_control = cache_setting.clone();
567                    }
568                }
569            }
570            ClaudeMessageContent::Text { content } => {
571                let text_content = content.clone();
572                last_message.content = ClaudeMessageContent::StructuredContent {
573                    content: ClaudeStructuredContent(vec![ClaudeContentBlock::Text {
574                        text: text_content,
575                        cache_control: cache_setting,
576                        extra: Default::default(),
577                    }]),
578                };
579            }
580        }
581
582        // Extract model-specific logic using ModelId
583        let supports_thinking = call_options
584            .as_ref()
585            .and_then(|opts| opts.thinking_config.as_ref())
586            .map(|tc| tc.enabled)
587            .unwrap_or(false);
588
589        let request = if supports_thinking {
590            // Use catalog/call options to configure thinking budget when provided
591            let budget = call_options
592                .as_ref()
593                .and_then(|o| o.thinking_config)
594                .and_then(|tc| tc.budget_tokens)
595                .unwrap_or(4000);
596            let thinking = Some(Thinking {
597                thinking_type: ThinkingType::Enabled,
598                budget_tokens: budget,
599            });
600            CompletionRequest {
601                model: model_id.1.clone(), // Use the model ID string
602                messages: claude_messages,
603                max_tokens: call_options
604                    .as_ref()
605                    .and_then(|o| o.max_tokens)
606                    .map(|v| v as usize)
607                    .unwrap_or(32_000),
608                system: system_content.clone(),
609                tools,
610                temperature: call_options
611                    .as_ref()
612                    .and_then(|o| o.temperature)
613                    .or(Some(1.0)),
614                top_p: call_options.as_ref().and_then(|o| o.top_p),
615                top_k: None,
616                stream: None,
617                thinking,
618            }
619        } else {
620            CompletionRequest {
621                model: model_id.1.clone(), // Use the model ID string
622                messages: claude_messages,
623                max_tokens: call_options
624                    .as_ref()
625                    .and_then(|o| o.max_tokens)
626                    .map(|v| v as usize)
627                    .unwrap_or(8000),
628                system: system_content,
629                tools,
630                temperature: call_options
631                    .as_ref()
632                    .and_then(|o| o.temperature)
633                    .or(Some(0.7)),
634                top_p: call_options.as_ref().and_then(|o| o.top_p),
635                top_k: None,
636                stream: None,
637                thinking: None,
638            }
639        };
640
641        let auth_headers = self.get_auth_headers().await?;
642        let mut request_builder = self.http_client.post(API_URL).json(&request);
643
644        // Add dynamic auth headers
645        for (name, value) in auth_headers {
646            request_builder = request_builder.header(&name, &value);
647        }
648
649        // Check for thinking beta header based on model ID
650        if supports_thinking && matches!(&self.auth, AuthMethod::ApiKey(_)) {
651            request_builder =
652                request_builder.header("anthropic-beta", "interleaved-thinking-2025-05-14");
653        }
654
655        let response = tokio::select! {
656            biased;
657            _ = token.cancelled() => {
658                debug!(target: "claude::complete", "Cancellation token triggered before sending request.");
659                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
660            }
661            res = request_builder.send() => {
662                res?
663            }
664        };
665
666        if token.is_cancelled() {
667            debug!(target: "claude::complete", "Cancellation token triggered after sending request, before status check.");
668            return Err(ApiError::Cancelled {
669                provider: self.name().to_string(),
670            });
671        }
672
673        let status = response.status();
674        if !status.is_success() {
675            let error_text = tokio::select! {
676                biased;
677                _ = token.cancelled() => {
678                    debug!(target: "claude::complete", "Cancellation token triggered while reading error response body.");
679                    return Err(ApiError::Cancelled{ provider: self.name().to_string()});
680                }
681                text_res = response.text() => {
682                    text_res?
683                }
684            };
685            return Err(match status.as_u16() {
686                401 => ApiError::AuthenticationFailed {
687                    provider: self.name().to_string(),
688                    details: error_text,
689                },
690                403 => ApiError::AuthenticationFailed {
691                    provider: self.name().to_string(),
692                    details: error_text,
693                },
694                429 => ApiError::RateLimited {
695                    provider: self.name().to_string(),
696                    details: error_text,
697                },
698                400..=499 => ApiError::InvalidRequest {
699                    provider: self.name().to_string(),
700                    details: error_text,
701                },
702                500..=599 => ApiError::ServerError {
703                    provider: self.name().to_string(),
704                    status_code: status.as_u16(),
705                    details: error_text,
706                },
707                _ => ApiError::Unknown {
708                    provider: self.name().to_string(),
709                    details: error_text,
710                },
711            });
712        }
713
714        let response_text = tokio::select! {
715            biased;
716            _ = token.cancelled() => {
717                debug!(target: "claude::complete", "Cancellation token triggered while reading successful response body.");
718                return Err(ApiError::Cancelled { provider: self.name().to_string() });
719            }
720            text_res = response.text() => {
721                text_res?
722            }
723        };
724
725        let claude_completion: ClaudeCompletionResponse = serde_json::from_str(&response_text)
726            .map_err(|e| ApiError::ResponseParsingError {
727                provider: self.name().to_string(),
728                details: format!("Error: {e}, Body: {response_text}"),
729            })?;
730        let completion = CompletionResponse {
731            content: convert_claude_content(claude_completion.content),
732        };
733
734        Ok(completion)
735    }
736}
737
738impl InteractiveAuth for AnthropicClient {
739    fn create_auth_flow(
740        &self,
741        storage: Arc<dyn AuthStorage>,
742    ) -> Option<Box<dyn DynAuthenticationFlow>> {
743        Some(Box::new(AuthFlowWrapper::new(AnthropicOAuthFlow::new(
744            storage,
745        ))))
746    }
747}