// steer_core/api/claude/client.rs

1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5use strum_macros::Display;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, warn};
8
9use crate::api::{CompletionResponse, Model, Provider, error::ApiError};
10use crate::app::conversation::{
11    AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
12};
13use crate::auth::{
14    AuthFlowWrapper, AuthStorage, DynAuthenticationFlow, InteractiveAuth,
15    anthropic::{AnthropicOAuth, AnthropicOAuthFlow, refresh_if_needed},
16};
17use steer_tools::ToolSchema;
/// Anthropic Messages API endpoint that every completion request is POSTed to.
const API_URL: &str = "https://api.anthropic.com/v1/messages";
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Display)]
21pub enum ClaudeMessageRole {
22    #[serde(rename = "user")]
23    #[strum(serialize = "user")]
24    User,
25    #[serde(rename = "assistant")]
26    #[strum(serialize = "assistant")]
27    Assistant,
28    #[serde(rename = "tool")]
29    #[strum(serialize = "tool")]
30    Tool,
31}
32
33/// Represents a message to be sent to the Claude API
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub struct ClaudeMessage {
36    pub role: ClaudeMessageRole,
37    #[serde(flatten)]
38    pub content: ClaudeMessageContent,
39    #[serde(skip_serializing)]
40    pub id: Option<String>,
41}
42
43/// Content types for Claude API messages
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
45#[serde(untagged)]
46pub enum ClaudeMessageContent {
47    /// Simple text content
48    Text { content: String },
49    /// Structured content for tool results or other special content
50    StructuredContent { content: ClaudeStructuredContent },
51}
52
53/// Represents structured content blocks for messages
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55#[serde(transparent)]
56pub struct ClaudeStructuredContent(pub Vec<ClaudeContentBlock>);
57
58#[derive(Clone)]
59pub enum AuthMethod {
60    ApiKey(String),
61    OAuth(Arc<dyn AuthStorage>),
62}
63
64#[derive(Clone)]
65pub struct AnthropicClient {
66    http_client: reqwest::Client,
67    auth: AuthMethod,
68}
69
70#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
71#[serde(rename_all = "lowercase")]
72enum ThinkingType {
73    Enabled,
74}
75
76impl Default for ThinkingType {
77    fn default() -> Self {
78        Self::Enabled
79    }
80}
81
82#[derive(Debug, Serialize, Deserialize, Clone)]
83struct Thinking {
84    #[serde(rename = "type", default)]
85    thinking_type: ThinkingType,
86    budget_tokens: u32,
87}
88
89#[derive(Debug, Serialize, Clone)]
90struct SystemContentBlock {
91    #[serde(rename = "type")]
92    content_type: String,
93    text: String,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    cache_control: Option<CacheControl>,
96}
97
98#[derive(Debug, Serialize, Clone)]
99#[serde(untagged)]
100enum System {
101    // Structured system prompt represented as a list of content blocks
102    Content(Vec<SystemContentBlock>),
103}
104
105#[derive(Debug, Serialize)]
106struct CompletionRequest {
107    model: String,
108    messages: Vec<ClaudeMessage>,
109    max_tokens: usize,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    system: Option<System>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    tools: Option<Vec<ToolSchema>>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    temperature: Option<f32>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    top_p: Option<f32>,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    top_k: Option<usize>,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    stream: Option<bool>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    thinking: Option<Thinking>,
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127struct ClaudeCompletionResponse {
128    id: String,
129    content: Vec<ClaudeContentBlock>,
130    model: String,
131    role: String,
132    #[serde(default)]
133    stop_reason: Option<String>,
134    #[serde(default)]
135    stop_sequence: Option<String>,
136    #[serde(default)]
137    usage: ClaudeUsage,
138    // Allow other fields for API flexibility
139    #[serde(flatten)]
140    extra: std::collections::HashMap<String, serde_json::Value>,
141}
142
/// Serde default for `CacheControl::cache_type` when the `type` key is missing.
fn default_cache_type() -> String {
    String::from("ephemeral")
}
146
147#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
148pub struct CacheControl {
149    #[serde(rename = "type", default = "default_cache_type")]
150    cache_type: String,
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
154#[serde(tag = "type")]
155pub enum ClaudeContentBlock {
156    #[serde(rename = "text")]
157    Text {
158        text: String,
159        #[serde(skip_serializing_if = "Option::is_none")]
160        cache_control: Option<CacheControl>,
161        #[serde(flatten)]
162        extra: std::collections::HashMap<String, serde_json::Value>,
163    },
164    #[serde(rename = "tool_use")]
165    ToolUse {
166        id: String,
167        name: String,
168        input: serde_json::Value,
169        #[serde(skip_serializing_if = "Option::is_none")]
170        cache_control: Option<CacheControl>,
171        #[serde(flatten)]
172        extra: std::collections::HashMap<String, serde_json::Value>,
173    },
174    #[serde(rename = "tool_result")]
175    ToolResult {
176        tool_use_id: String,
177        content: Vec<ClaudeContentBlock>,
178        #[serde(skip_serializing_if = "Option::is_none")]
179        cache_control: Option<CacheControl>,
180        #[serde(skip_serializing_if = "Option::is_none")]
181        is_error: Option<bool>,
182        #[serde(flatten)]
183        extra: std::collections::HashMap<String, serde_json::Value>,
184    },
185    #[serde(rename = "thinking")]
186    Thinking {
187        thinking: String,
188        signature: String,
189        #[serde(skip_serializing_if = "Option::is_none")]
190        cache_control: Option<CacheControl>,
191        #[serde(flatten)]
192        extra: std::collections::HashMap<String, serde_json::Value>,
193    },
194    #[serde(rename = "redacted_thinking")]
195    RedactedThinking {
196        data: String,
197        #[serde(skip_serializing_if = "Option::is_none")]
198        cache_control: Option<CacheControl>,
199        #[serde(flatten)]
200        extra: std::collections::HashMap<String, serde_json::Value>,
201    },
202    #[serde(other)]
203    Unknown,
204}
205
206#[derive(Debug, Serialize, Deserialize, Default, Clone)]
207struct ClaudeUsage {
208    input_tokens: usize,
209    output_tokens: usize,
210    #[serde(skip_serializing_if = "Option::is_none")]
211    cache_creation_input_tokens: Option<usize>,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    cache_read_input_tokens: Option<usize>,
214}
215
216impl AnthropicClient {
217    pub fn new(api_key: &str) -> Self {
218        Self::with_api_key(api_key)
219    }
220
221    pub fn with_api_key(api_key: &str) -> Self {
222        let mut headers = header::HeaderMap::new();
223        headers.insert("x-api-key", header::HeaderValue::from_str(api_key).unwrap());
224        headers.insert(
225            "anthropic-version",
226            header::HeaderValue::from_static("2023-06-01"),
227        );
228        headers.insert(
229            header::CONTENT_TYPE,
230            header::HeaderValue::from_static("application/json"),
231        );
232
233        let client = reqwest::Client::builder()
234            .default_headers(headers)
235            .build()
236            .expect("Failed to build HTTP client");
237
238        Self {
239            http_client: client,
240            auth: AuthMethod::ApiKey(api_key.to_string()),
241        }
242    }
243
244    pub fn with_oauth(storage: Arc<dyn AuthStorage>) -> Self {
245        // For OAuth, we don't set default headers since they're dynamic
246        let mut headers = header::HeaderMap::new();
247        headers.insert(
248            "anthropic-version",
249            header::HeaderValue::from_static("2023-06-01"),
250        );
251        headers.insert(
252            header::CONTENT_TYPE,
253            header::HeaderValue::from_static("application/json"),
254        );
255
256        let client = reqwest::Client::builder()
257            .default_headers(headers)
258            .build()
259            .expect("Failed to build HTTP client");
260
261        Self {
262            http_client: client,
263            auth: AuthMethod::OAuth(storage),
264        }
265    }
266
267    async fn get_auth_headers(&self) -> Result<Vec<(String, String)>, ApiError> {
268        match &self.auth {
269            AuthMethod::ApiKey(key) => Ok(vec![("x-api-key".to_string(), key.clone())]),
270            AuthMethod::OAuth(storage) => {
271                let oauth_client = AnthropicOAuth::new();
272                let tokens = refresh_if_needed(storage, &oauth_client)
273                    .await
274                    .map_err(|e| ApiError::AuthError(e.to_string()))?;
275                Ok(crate::auth::anthropic::get_oauth_headers(
276                    &tokens.access_token,
277                ))
278            }
279        }
280    }
281}
282
283// Conversion functions start
284fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<ClaudeMessage>, ApiError> {
285    let claude_messages: Result<Vec<ClaudeMessage>, ApiError> =
286        messages.into_iter().map(convert_single_message).collect();
287
288    // Filter out any User messages that have empty content after removing app commands
289    claude_messages.map(|messages| {
290        messages
291            .into_iter()
292            .filter(|msg| {
293                match &msg.content {
294                    ClaudeMessageContent::Text { content } => !content.trim().is_empty(),
295                    _ => true, // Keep all non-text messages
296                }
297            })
298            .collect()
299    })
300}
301
302fn convert_single_message(msg: AppMessage) -> Result<ClaudeMessage, ApiError> {
303    match &msg.data {
304        crate::app::conversation::MessageData::User { content, .. } => {
305            // Convert UserContent to Claude format
306            let combined_text = content
307                .iter()
308                .filter_map(|user_content| match user_content {
309                    UserContent::Text { text } => Some(text.clone()),
310                    UserContent::CommandExecution {
311                        command,
312                        stdout,
313                        stderr,
314                        exit_code,
315                    } => Some(UserContent::format_command_execution_as_xml(
316                        command, stdout, stderr, *exit_code,
317                    )),
318                    UserContent::AppCommand { .. } => {
319                        // Don't send app commands to the model - they're for local execution only
320                        None
321                    }
322                })
323                .collect::<Vec<_>>()
324                .join("\n");
325
326            Ok(ClaudeMessage {
327                role: ClaudeMessageRole::User,
328                content: ClaudeMessageContent::Text {
329                    content: combined_text,
330                },
331                id: Some(msg.id.clone()),
332            })
333        }
334        crate::app::conversation::MessageData::Assistant { content, .. } => {
335            // Convert AssistantContent to Claude blocks
336            let claude_blocks: Vec<ClaudeContentBlock> = content
337                .iter()
338                .filter_map(|assistant_content| match assistant_content {
339                    AssistantContent::Text { text } => {
340                        if text.trim().is_empty() {
341                            None
342                        } else {
343                            Some(ClaudeContentBlock::Text {
344                                text: text.clone(),
345                                cache_control: None,
346                                extra: Default::default(),
347                            })
348                        }
349                    }
350                    AssistantContent::ToolCall { tool_call } => Some(ClaudeContentBlock::ToolUse {
351                        id: tool_call.id.clone(),
352                        name: tool_call.name.clone(),
353                        input: tool_call.parameters.clone(),
354                        cache_control: None,
355                        extra: Default::default(),
356                    }),
357                    AssistantContent::Thought { thought } => {
358                        match thought {
359                            ThoughtContent::Signed { text, signature } => {
360                                Some(ClaudeContentBlock::Thinking {
361                                    thinking: text.clone(),
362                                    signature: signature.clone(),
363                                    cache_control: None,
364                                    extra: Default::default(),
365                                })
366                            }
367                            ThoughtContent::Redacted { data } => {
368                                Some(ClaudeContentBlock::RedactedThinking {
369                                    data: data.clone(),
370                                    cache_control: None,
371                                    extra: Default::default(),
372                                })
373                            }
374                            ThoughtContent::Simple { text } => {
375                                // Claude doesn't have a simple thought type, convert to text
376                                Some(ClaudeContentBlock::Text {
377                                    text: format!("[Thought: {text}]"),
378                                    cache_control: None,
379                                    extra: Default::default(),
380                                })
381                            }
382                        }
383                    }
384                })
385                .collect();
386
387            if !claude_blocks.is_empty() {
388                let claude_content = if claude_blocks.len() == 1 {
389                    if let Some(ClaudeContentBlock::Text { text, .. }) = claude_blocks.first() {
390                        ClaudeMessageContent::Text {
391                            content: text.clone(),
392                        }
393                    } else {
394                        ClaudeMessageContent::StructuredContent {
395                            content: ClaudeStructuredContent(claude_blocks),
396                        }
397                    }
398                } else {
399                    ClaudeMessageContent::StructuredContent {
400                        content: ClaudeStructuredContent(claude_blocks),
401                    }
402                };
403
404                Ok(ClaudeMessage {
405                    role: ClaudeMessageRole::Assistant,
406                    content: claude_content,
407                    id: Some(msg.id.clone()),
408                })
409            } else {
410                debug!("No content blocks found: {:?}", content);
411                Err(ApiError::InvalidRequest {
412                    provider: "anthropic".to_string(),
413                    details: format!(
414                        "Assistant message ID {} resulted in no valid content blocks",
415                        msg.id
416                    ),
417                })
418            }
419        }
420        crate::app::conversation::MessageData::Tool {
421            tool_use_id,
422            result,
423            ..
424        } => {
425            // Convert ToolResult to Claude format
426            // Claude expects tool results as User messages
427            let (result_text, is_error) = match result {
428                ToolResult::Error(e) => (e.to_string(), Some(true)),
429                _ => {
430                    // For all other variants, use llm_format
431                    let text = result.llm_format();
432                    let text = if text.trim().is_empty() {
433                        "(No output)".to_string()
434                    } else {
435                        text
436                    };
437                    (text, None)
438                }
439            };
440
441            let claude_blocks = vec![ClaudeContentBlock::ToolResult {
442                tool_use_id: tool_use_id.clone(),
443                content: vec![ClaudeContentBlock::Text {
444                    text: result_text,
445                    cache_control: None,
446                    extra: Default::default(),
447                }],
448                is_error,
449                cache_control: None,
450                extra: Default::default(),
451            }];
452
453            Ok(ClaudeMessage {
454                role: ClaudeMessageRole::User, // Tool results are sent as User messages in Claude
455                content: ClaudeMessageContent::StructuredContent {
456                    content: ClaudeStructuredContent(claude_blocks),
457                },
458                id: Some(msg.id.clone()),
459            })
460        }
461    }
462}
463// Conversion functions end
464
465// Convert Claude's content blocks to our provider-agnostic format
466fn convert_claude_content(claude_blocks: Vec<ClaudeContentBlock>) -> Vec<AssistantContent> {
467    claude_blocks
468        .into_iter()
469        .filter_map(|block| match block {
470            ClaudeContentBlock::Text { text, .. } => Some(AssistantContent::Text { text }),
471            ClaudeContentBlock::ToolUse {
472                id, name, input, ..
473            } => Some(AssistantContent::ToolCall {
474                tool_call: steer_tools::ToolCall {
475                    id,
476                    name,
477                    parameters: input,
478                },
479            }),
480            ClaudeContentBlock::ToolResult { .. } => {
481                warn!("Unexpected ToolResult block received in Claude response content");
482                None
483            }
484            ClaudeContentBlock::Thinking {
485                thinking,
486                signature,
487                ..
488            } => Some(AssistantContent::Thought {
489                thought: ThoughtContent::Signed {
490                    text: thinking,
491                    signature,
492                },
493            }),
494            ClaudeContentBlock::RedactedThinking { data, .. } => Some(AssistantContent::Thought {
495                thought: ThoughtContent::Redacted { data },
496            }),
497            ClaudeContentBlock::Unknown => {
498                warn!("Unknown content block received in Claude response content");
499                None
500            }
501        })
502        .collect()
503}
504
505#[async_trait]
506impl Provider for AnthropicClient {
507    fn name(&self) -> &'static str {
508        "anthropic"
509    }
510
511    async fn complete(
512        &self,
513        model: Model,
514        messages: Vec<AppMessage>,
515        system: Option<String>,
516        tools: Option<Vec<ToolSchema>>,
517        token: CancellationToken,
518    ) -> Result<CompletionResponse, ApiError> {
519        let mut claude_messages = convert_messages(messages)?;
520
521        if claude_messages.is_empty() {
522            return Err(ApiError::InvalidRequest {
523                provider: self.name().to_string(),
524                details: "No messages provided".to_string(),
525            });
526        }
527
528        let last_message = claude_messages.last_mut().unwrap();
529        let cache_setting = Some(CacheControl {
530            cache_type: "ephemeral".to_string(),
531        });
532
533        let system_content = match (system, &self.auth) {
534            (Some(sys), AuthMethod::ApiKey(_)) => Some(System::Content(vec![SystemContentBlock {
535                content_type: "text".to_string(),
536                text: sys,
537                cache_control: cache_setting.clone(),
538            }])),
539            (Some(sys), AuthMethod::OAuth(_)) => Some(System::Content(vec![
540                SystemContentBlock {
541                    content_type: "text".to_string(),
542                    text: "You are Claude Code, Anthropic's official CLI for Claude.".to_string(),
543                    cache_control: cache_setting.clone(),
544                },
545                SystemContentBlock {
546                    content_type: "text".to_string(),
547                    text: sys,
548                    cache_control: cache_setting.clone(),
549                },
550            ])),
551            (None, AuthMethod::ApiKey(_)) => None,
552            (None, AuthMethod::OAuth(_)) => Some(System::Content(vec![SystemContentBlock {
553                content_type: "text".to_string(),
554                text: "You are Claude Code, Anthropic's official CLI for Claude.".to_string(),
555                cache_control: cache_setting.clone(),
556            }])),
557        };
558
559        match &mut last_message.content {
560            ClaudeMessageContent::StructuredContent { content } => {
561                for block in content.0.iter_mut() {
562                    if let ClaudeContentBlock::ToolResult { cache_control, .. } = block {
563                        *cache_control = cache_setting.clone();
564                    }
565                }
566            }
567            ClaudeMessageContent::Text { content } => {
568                let text_content = content.clone();
569                last_message.content = ClaudeMessageContent::StructuredContent {
570                    content: ClaudeStructuredContent(vec![ClaudeContentBlock::Text {
571                        text: text_content,
572                        cache_control: cache_setting,
573                        extra: Default::default(),
574                    }]),
575                };
576            }
577        }
578
579        let request = if model.supports_thinking() {
580            let thinking = Some(Thinking {
581                thinking_type: ThinkingType::Enabled,
582                budget_tokens: 4000,
583            });
584            CompletionRequest {
585                model: model.as_ref().to_string(),
586                messages: claude_messages,
587                max_tokens: 32_000,
588                system: system_content.clone(),
589                tools,
590                temperature: Some(1.0),
591                top_p: None,
592                top_k: None,
593                stream: None,
594                thinking,
595            }
596        } else {
597            CompletionRequest {
598                model: model.as_ref().to_string(),
599                messages: claude_messages,
600                max_tokens: 8000,
601                system: system_content,
602                tools,
603                temperature: Some(0.7),
604                top_p: None,
605                top_k: None,
606                stream: None,
607                thinking: None,
608            }
609        };
610
611        let auth_headers = self.get_auth_headers().await?;
612        let mut request_builder = self.http_client.post(API_URL).json(&request);
613
614        // Add dynamic auth headers
615        for (name, value) in auth_headers {
616            request_builder = request_builder.header(&name, &value);
617        }
618
619        if let (
620            Model::ClaudeSonnet4_20250514 | Model::ClaudeOpus4_20250514,
621            AuthMethod::ApiKey(_),
622        ) = (&model, &self.auth)
623        {
624            request_builder =
625                request_builder.header("anthropic-beta", "interleaved-thinking-2025-05-14");
626        }
627
628        let response = tokio::select! {
629            biased;
630            _ = token.cancelled() => {
631                debug!(target: "claude::complete", "Cancellation token triggered before sending request.");
632                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
633            }
634            res = request_builder.send() => {
635                res?
636            }
637        };
638
639        if token.is_cancelled() {
640            debug!(target: "claude::complete", "Cancellation token triggered after sending request, before status check.");
641            return Err(ApiError::Cancelled {
642                provider: self.name().to_string(),
643            });
644        }
645
646        let status = response.status();
647        if !status.is_success() {
648            let error_text = tokio::select! {
649                biased;
650                _ = token.cancelled() => {
651                    debug!(target: "claude::complete", "Cancellation token triggered while reading error response body.");
652                    return Err(ApiError::Cancelled{ provider: self.name().to_string()});
653                }
654                text_res = response.text() => {
655                    text_res?
656                }
657            };
658            return Err(match status.as_u16() {
659                401 => ApiError::AuthenticationFailed {
660                    provider: self.name().to_string(),
661                    details: error_text,
662                },
663                403 => ApiError::AuthenticationFailed {
664                    provider: self.name().to_string(),
665                    details: error_text,
666                },
667                429 => ApiError::RateLimited {
668                    provider: self.name().to_string(),
669                    details: error_text,
670                },
671                400..=499 => ApiError::InvalidRequest {
672                    provider: self.name().to_string(),
673                    details: error_text,
674                },
675                500..=599 => ApiError::ServerError {
676                    provider: self.name().to_string(),
677                    status_code: status.as_u16(),
678                    details: error_text,
679                },
680                _ => ApiError::Unknown {
681                    provider: self.name().to_string(),
682                    details: error_text,
683                },
684            });
685        }
686
687        let response_text = tokio::select! {
688            biased;
689            _ = token.cancelled() => {
690                debug!(target: "claude::complete", "Cancellation token triggered while reading successful response body.");
691                return Err(ApiError::Cancelled { provider: self.name().to_string() });
692            }
693            text_res = response.text() => {
694                text_res?
695            }
696        };
697
698        let claude_completion: ClaudeCompletionResponse = serde_json::from_str(&response_text)
699            .map_err(|e| ApiError::ResponseParsingError {
700                provider: self.name().to_string(),
701                details: format!("Error: {e}, Body: {response_text}"),
702            })?;
703        let completion = CompletionResponse {
704            content: convert_claude_content(claude_completion.content),
705        };
706
707        Ok(completion)
708    }
709}
710
711impl InteractiveAuth for AnthropicClient {
712    fn create_auth_flow(
713        &self,
714        storage: Arc<dyn AuthStorage>,
715    ) -> Option<Box<dyn DynAuthenticationFlow>> {
716        Some(Box::new(AuthFlowWrapper::new(AnthropicOAuthFlow::new(
717            storage,
718        ))))
719    }
720}