stakpak_shared/models/integrations/anthropic.rs

use crate::models::error::{AgentError, BadRequestErrorMessage};
use crate::models::llm::{
    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
    LLMMessageContent, LLMMessageTypedContent, LLMTool,
};
use crate::models::llm::{LLMTokenUsage, PromptTokensDetails};
use futures_util::StreamExt;
use itertools::Itertools;
use reqwest::Response;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum AnthropicModel {
    #[serde(rename = "claude-haiku-4-5-20251001")]
    Claude45Haiku,
    #[serde(rename = "claude-sonnet-4-5-20250929")]
    Claude45Sonnet,
    #[serde(rename = "claude-opus-4-5-20251101")]
    Claude45Opus,
}

impl std::fmt::Display for AnthropicModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AnthropicModel::Claude45Haiku => write!(f, "claude-haiku-4-5-20251001"),
            AnthropicModel::Claude45Sonnet => write!(f, "claude-sonnet-4-5-20250929"),
            AnthropicModel::Claude45Opus => write!(f, "claude-opus-4-5-20251101"),
        }
    }
}

impl AnthropicModel {
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize Anthropic model".to_string())
    }

    /// Default smart model for Anthropic
    pub const DEFAULT_SMART_MODEL: AnthropicModel = AnthropicModel::Claude45Opus;

    /// Default eco model for Anthropic
    pub const DEFAULT_ECO_MODEL: AnthropicModel = AnthropicModel::Claude45Haiku;

    /// Default recovery model for Anthropic
    pub const DEFAULT_RECOVERY_MODEL: AnthropicModel = AnthropicModel::Claude45Haiku;

    /// Get default smart model as string
    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    /// Get default eco model as string
    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    /// Get default recovery model as string
    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}
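
// A hedged round-trip sketch (hypothetical test module, not part of the
// original file): `from_string` accepts exactly the serde rename strings
// above, and `Display` produces them back.
#[cfg(test)]
mod model_id_tests {
    use super::*;

    #[test]
    fn round_trips_model_ids() {
        let model = AnthropicModel::from_string("claude-sonnet-4-5-20250929")
            .expect("known model id");
        assert_eq!(model, AnthropicModel::Claude45Sonnet);
        assert_eq!(model.to_string(), "claude-sonnet-4-5-20250929");
        // Unknown ids are rejected rather than silently defaulted.
        assert!(AnthropicModel::from_string("not-a-model").is_err());
    }
}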

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicInput {
    pub model: AnthropicModel,
    pub messages: Vec<LLMMessage>,
    pub grammar: Option<String>,
    pub max_tokens: u32,
    pub stop_sequences: Option<Vec<String>>,
    pub tools: Option<Vec<LLMTool>>,
    pub thinking: ThinkingInput,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ThinkingInput {
    pub r#type: ThinkingType,
    // Must be ≥1024 and less than max_tokens
    pub budget_tokens: u32,
}

impl Default for ThinkingInput {
    fn default() -> Self {
        Self {
            r#type: ThinkingType::default(),
            budget_tokens: 1024,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingType {
    Enabled,
    #[default]
    Disabled,
}
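
// A hedged serde sketch (hypothetical test module): shows the wire shape of
// `ThinkingInput`; per the comment above, the API expects `budget_tokens` to
// be at least 1024 and below `max_tokens`.
#[cfg(test)]
mod thinking_input_tests {
    use super::*;

    #[test]
    fn serializes_with_lowercase_type_tag() {
        let thinking = ThinkingInput {
            r#type: ThinkingType::Enabled,
            budget_tokens: 2048,
        };
        let value = serde_json::to_value(&thinking).expect("serializable");
        assert_eq!(
            value,
            serde_json::json!({"type": "enabled", "budget_tokens": 2048})
        );
    }
}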

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicOutputUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default)]
    pub cache_read_input_tokens: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicOutput {
    pub id: String,
    pub r#type: String,
    pub role: String,
    pub content: LLMMessageContent,
    pub model: String,
    pub stop_reason: String,
    pub usage: AnthropicOutputUsage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicErrorOutput {
    pub r#type: String,
    pub error: AnthropicError,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicError {
    pub message: String,
    pub r#type: String,
}

impl From<AnthropicOutput> for LLMCompletionResponse {
    fn from(val: AnthropicOutput) -> Self {
        let choices = vec![LLMChoice {
            finish_reason: Some(val.stop_reason.clone()),
            index: 0,
            message: LLMMessage {
                role: val.role.clone(),
                content: val.content,
            },
        }];

        LLMCompletionResponse {
            id: val.id,
            model: val.model,
            object: val.r#type,
            choices,
            created: chrono::Utc::now().timestamp_millis() as u64,
            usage: Some(val.usage.into()),
        }
    }
}
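
// A hedged conversion sketch (hypothetical test module): an `AnthropicOutput`
// maps onto a single-choice `LLMCompletionResponse` with the stop reason
// carried over as the finish reason.
#[cfg(test)]
mod output_conversion_tests {
    use super::*;

    #[test]
    fn maps_output_to_single_choice() {
        let output = AnthropicOutput {
            id: "msg_123".to_string(),
            r#type: "message".to_string(),
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::Text {
                text: "Hello".to_string(),
            }]),
            model: "claude-haiku-4-5-20251001".to_string(),
            stop_reason: "end_turn".to_string(),
            usage: AnthropicOutputUsage {
                input_tokens: 10,
                output_tokens: 5,
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
        };

        let response: LLMCompletionResponse = output.into();
        assert_eq!(response.id, "msg_123");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].finish_reason.as_deref(),
            Some("end_turn")
        );
        assert_eq!(response.choices[0].message.role, "assistant");
    }
}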

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicStreamEvent {
    #[serde(rename = "type")]
    pub event: String,
    #[serde(flatten)]
    pub data: AnthropicStreamEventData,
}

impl From<AnthropicOutputUsage> for LLMTokenUsage {
    fn from(usage: AnthropicOutputUsage) -> Self {
        let input_tokens = usage.input_tokens
            + usage.cache_creation_input_tokens.unwrap_or(0)
            + usage.cache_read_input_tokens.unwrap_or(0);
        let output_tokens = usage.output_tokens;
        Self {
            completion_tokens: output_tokens,
            prompt_tokens: input_tokens,
            total_tokens: input_tokens + output_tokens,
            prompt_tokens_details: Some(PromptTokensDetails {
                input_tokens: Some(input_tokens),
                output_tokens: Some(output_tokens),
                cache_read_input_tokens: usage.cache_read_input_tokens,
                cache_write_input_tokens: usage.cache_creation_input_tokens,
            }),
        }
    }
}
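
// A hedged usage-accounting sketch (hypothetical test module): cache creation
// and cache read tokens are folded into `prompt_tokens`, so `total_tokens`
// covers everything billed as input plus the output.
#[cfg(test)]
mod usage_conversion_tests {
    use super::*;

    #[test]
    fn folds_cache_tokens_into_prompt_tokens() {
        let usage: LLMTokenUsage = AnthropicOutputUsage {
            input_tokens: 100,
            output_tokens: 40,
            cache_creation_input_tokens: Some(25),
            cache_read_input_tokens: Some(75),
        }
        .into();

        assert_eq!(usage.prompt_tokens, 200); // 100 + 25 + 75
        assert_eq!(usage.completion_tokens, 40);
        assert_eq!(usage.total_tokens, 240);
    }
}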

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicStreamOutput {
    pub id: String,
    pub r#type: String,
    pub role: String,
    pub content: LLMMessageContent,
    pub model: String,
    pub stop_reason: Option<String>,
    pub usage: AnthropicOutputUsage,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AnthropicStreamEventData {
    MessageStart {
        message: AnthropicStreamOutput,
    },
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    ContentBlockStop {
        index: usize,
    },
    MessageDelta {
        delta: MessageDelta,
        usage: Option<AnthropicOutputUsage>,
    },
    MessageStop {},
    Ping {},
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "thinking_delta")]
    ThinkingDelta { thinking: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
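
// A hedged parsing sketch (hypothetical test module): the `data:` payload of a
// `content_block_delta` SSE event deserializes straight into the internally
// tagged enums above, which is exactly what `process_stream` relies on.
#[cfg(test)]
mod stream_event_tests {
    use super::*;

    #[test]
    fn parses_text_delta_event() {
        let json =
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;
        let event: AnthropicStreamEventData =
            serde_json::from_str(json).expect("valid stream event");
        match event {
            AnthropicStreamEventData::ContentBlockDelta {
                index,
                delta: ContentDelta::TextDelta { text },
            } => {
                assert_eq!(index, 0);
                assert_eq!(text, "Hi");
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }
}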

#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct AnthropicConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}

pub struct Anthropic {}

impl Anthropic {
    pub async fn chat(
        config: &AnthropicConfig,
        input: AnthropicInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let mut payload = json!({
            "model": input.model.to_string(),
            "system": input.messages.iter().find(|mess| mess.role == "system").map(|mess| mess.content.clone()),
            "messages": input.messages.into_iter().filter(|message| message.role != "system").collect::<Vec<LLMMessage>>(),
            "max_tokens": input.max_tokens,
            "temperature": 0,
            "stream": false,
        });

        if let Some(tools) = input.tools {
            payload["tools"] = json!(tools);
        }

        if let Some(stop_sequences) = input.stop_sequences {
            payload["stop_sequences"] = json!(stop_sequences);
        }

        // Setup retry with exponential backoff
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.anthropic.com/v1/messages");
        let api_key = config.api_key.as_deref().unwrap_or("");

        // Send the POST request
        let response = client
            .post(api_endpoint)
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                let error_message = format!("Anthropic API request error: {e}");
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    error_message,
                )));
            }
        };

        // Check for HTTP status errors and extract error details if present
        if !response.status().is_success() {
            let status = response.status();
            let error_body = match response.text().await {
                Ok(body) => body,
                Err(_) => "Unable to read error response".to_string(),
            };

            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!(
                    "Anthropic API returned error status: {}, body: {}",
                    status, error_body
                ),
            )));
        }

        match response.json::<Value>().await {
            Ok(json) => {
                // Keep a pretty-printed copy so the original response can be
                // included in the error if deserialization fails below.
                let pretty_json = serde_json::to_string_pretty(&json).unwrap_or_default();
                match serde_json::from_value::<AnthropicOutput>(json) {
                    Ok(json_response) => Ok(json_response.into()),
                    Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        format!(
                            "Error deserializing JSON: {:?}\nOriginal JSON: {}",
                            e, pretty_json
                        ),
                    ))),
                }
            }
            Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!("Failed to decode Anthropic JSON response: {:?}", e),
            ))),
        }
    }

    pub async fn chat_stream(
        config: &AnthropicConfig,
        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
        input: AnthropicInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let mut payload = json!({
            "model": input.model.to_string(),
            "system": input.messages.iter().find(|mess| mess.role == "system").map(|mess| json!([
                {
                    "type": "text",
                    "text": mess.content.clone(),
                    "cache_control": {"type": "ephemeral", "ttl": "5m"}
                }
            ])),
            "messages": input.messages.into_iter().filter(|message| message.role != "system").collect::<Vec<LLMMessage>>(),
            "max_tokens": input.max_tokens,
            "temperature": 0,
            "stream": true,
        });

        if let Some(tools) = input.tools {
            payload["tools"] = json!(
                tools
                    .iter()
                    .map(|tool| {
                        let mut tool_json = json!(tool);
                        // Mark the last tool as a cache breakpoint so the whole
                        // tool list is cached with a 1-hour TTL.
                        if let Some(last_tool) = tools.last()
                            && tool == last_tool
                        {
                            tool_json["cache_control"] = json!({"type": "ephemeral", "ttl": "1h"});
                        }
                        tool_json
                    })
                    .collect::<Vec<serde_json::Value>>()
            );
        }

        if let Some(stop_sequences) = input.stop_sequences {
            payload["stop_sequences"] = json!(stop_sequences);
        }

        // Setup retry with exponential backoff
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.anthropic.com/v1/messages");

        let api_key = config.api_key.as_deref().unwrap_or("");

        // Send the POST request
        let response = client
            .post(api_endpoint)
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header(
                "anthropic-beta",
                "extended-cache-ttl-2025-04-11,context-1m-2025-08-07",
            )
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    e.to_string(),
                )));
            }
        };

        if !response.status().is_success() {
            let error_body = match response.json::<AnthropicErrorOutput>().await {
                Ok(body) => body,
                Err(_) => AnthropicErrorOutput {
                    r#type: "error".to_string(),
                    error: AnthropicError {
                        message: "Unable to read error response".to_string(),
                        r#type: "error".to_string(),
                    },
                },
            };

            match error_body.error.r#type.as_str() {
                "invalid_request_error" => {
                    return Err(AgentError::BadRequest(
                        BadRequestErrorMessage::InvalidAgentInput(error_body.error.message),
                    ));
                }
                _ => {
                    return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        error_body.error.message,
                    )));
                }
            }
        }

        let completion_response =
            process_stream(response, input.model.to_string(), stream_channel_tx).await?;

        Ok(completion_response)
    }
}
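
// A hedged usage sketch (hypothetical helper, not part of the original module):
// shows how a caller might drive `Anthropic::chat`. Assumes a tokio runtime and
// an `ANTHROPIC_API_KEY` environment variable; the message shapes are taken
// from this file's own usage.
#[allow(dead_code)]
async fn chat_usage_sketch() -> Result<(), AgentError> {
    let config = AnthropicConfig {
        // `None` falls back to https://api.anthropic.com/v1/messages
        api_endpoint: None,
        api_key: std::env::var("ANTHROPIC_API_KEY").ok(),
    };
    let input = AnthropicInput {
        model: AnthropicModel::DEFAULT_ECO_MODEL,
        messages: vec![LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::Text {
                text: "Say hello".to_string(),
            }]),
        }],
        grammar: None,
        max_tokens: 256,
        stop_sequences: None,
        tools: None,
        thinking: ThinkingInput::default(),
    };

    let response = Anthropic::chat(&config, input).await?;
    println!("{:?}", response.choices.first().map(|c| &c.message.content));
    Ok(())
}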

pub async fn process_stream(
    response: Response,
    model: String,
    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
) -> Result<LLMCompletionResponse, AgentError> {
    let mut completion_response = LLMCompletionResponse {
        id: "".to_string(),
        model: model.clone(),
        object: "chat.completion".to_string(),
        choices: vec![],
        created: chrono::Utc::now().timestamp_millis() as u64,
        usage: None,
    };

    let mut choices: HashMap<usize, LLMChoice> = HashMap::from([(
        0,
        LLMChoice {
            finish_reason: None,
            index: 0,
            message: LLMMessage {
                role: "assistant".to_string(),
                content: LLMMessageContent::List(vec![]),
            },
        },
    )]);
    let mut contents: Vec<LLMMessageTypedContent> = vec![];
    let mut stream = response.bytes_stream();
    let mut unparsed_data = String::new();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            let error_message = format!("Failed to read stream chunk from Anthropic API: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        let text = std::str::from_utf8(&chunk).map_err(|e| {
            let error_message = format!("Failed to parse UTF-8 from Anthropic response: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        unparsed_data.push_str(text);

        // SSE events are separated by a blank line; process each complete event
        // and keep any trailing partial event in the buffer.
        while let Some(event_end) = unparsed_data.find("\n\n") {
            let event_str = unparsed_data[..event_end].to_string();
            unparsed_data = unparsed_data[event_end + 2..].to_string();

            if !event_str.starts_with("event: ") {
                continue;
            }

            // Strip everything up to and including the "data: " prefix
            // ("data: " is 6 bytes); fall back to offset 6 if the prefix is missing.
            let json_str = &event_str[event_str.find("data: ").map(|i| i + 6).unwrap_or(6)..];
            if json_str == "[DONE]" {
                continue;
            }

            match serde_json::from_str::<AnthropicStreamEventData>(json_str) {
                Ok(data) => {
                    match data {
                        AnthropicStreamEventData::MessageStart { message } => {
                            completion_response.id = message.id;
                            completion_response.model = message.model;
                            completion_response.object = message.r#type;
                            completion_response.usage = Some(message.usage.into());
                        }
                        AnthropicStreamEventData::ContentBlockStart {
                            content_block,
                            index,
                        } => match content_block {
                            ContentBlock::Text { text } => {
                                stream_channel_tx
                                    .send(GenerationDelta::Content {
                                        // if this will be rendered as markdown, we need to escape the < and >
                                        content: text.clone(), //.replace("<", "\\<").replace(">", "\\>"),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                contents.push(LLMMessageTypedContent::Text { text: text.clone() });
                            }
                            ContentBlock::Thinking { thinking } => {
                                stream_channel_tx
                                    .send(GenerationDelta::Thinking {
                                        thinking: thinking.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                contents.push(LLMMessageTypedContent::Text {
                                    text: thinking.clone(),
                                });
                            }
                            ContentBlock::ToolUse { id, name, input: _ } => {
                                stream_channel_tx
                                    .send(GenerationDelta::ToolUse {
                                        tool_use: GenerationDeltaToolUse {
                                            id: Some(id.clone()),
                                            name: Some(name.clone()),
                                            input: Some(String::new()),
                                            index,
                                        },
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                // Initialize with empty string since we'll accumulate via InputJsonDelta events
                                contents.push(LLMMessageTypedContent::ToolCall {
                                    id: id.clone(),
                                    name: name.clone(),
                                    args: serde_json::Value::String(String::new()),
                                });
                            }
                        },
                        AnthropicStreamEventData::ContentBlockDelta { delta, index } => {
                            if let Some(content) = contents.get_mut(index) {
                                match delta {
                                    ContentDelta::TextDelta { text } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::Content {
                                                // if this will be rendered as markdown, we need to escape the < and >
                                                content: text.clone(), //.replace("<", "\\<").replace(">", "\\>"),
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        let delta_text = text.clone();
                                        if let LLMMessageTypedContent::Text { text } = content {
                                            text.push_str(&delta_text);
                                        }
                                    }
                                    ContentDelta::ThinkingDelta { thinking } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::Thinking {
                                                thinking: thinking.clone(),
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let LLMMessageTypedContent::Text { text } = content {
                                            text.push_str(&thinking);
                                        }
                                    }
                                    ContentDelta::InputJsonDelta { partial_json } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::ToolUse {
                                                tool_use: GenerationDeltaToolUse {
                                                    id: None,
                                                    name: None,
                                                    input: Some(partial_json.clone()),
                                                    index,
                                                },
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let Some(LLMMessageTypedContent::ToolCall {
                                            args: serde_json::Value::String(accumulated_json),
                                            ..
                                        }) = contents.get_mut(index)
                                        {
                                            accumulated_json.push_str(&partial_json);
                                        }
                                    }
                                }
                            }
                        }
                        AnthropicStreamEventData::ContentBlockStop { index } => {
                            if let Some(LLMMessageTypedContent::ToolCall { args, .. }) =
                                contents.get_mut(index)
                                && let serde_json::Value::String(json_str) = args
                            {
                                // Try to parse the accumulated JSON string
                                *args = serde_json::from_str(json_str).unwrap_or_else(|_| {
                                    // If parsing fails, keep as string
                                    serde_json::Value::String(json_str.clone())
                                });
                            }
                        }
                        AnthropicStreamEventData::MessageDelta { delta, usage } => {
                            if let Some(stop_reason) = delta.stop_reason {
                                for choice in choices.values_mut() {
                                    if choice.finish_reason.is_none() {
                                        choice.finish_reason = Some(stop_reason.clone());
                                    }
                                }
                            }
                            if let Some(usage) = usage {
                                let usage = LLMTokenUsage {
                                    prompt_tokens: usage.input_tokens,
                                    completion_tokens: usage.output_tokens,
                                    total_tokens: usage.input_tokens + usage.output_tokens,
                                    prompt_tokens_details: Some(PromptTokensDetails {
                                        input_tokens: Some(usage.input_tokens),
                                        output_tokens: Some(usage.output_tokens),
                                        cache_read_input_tokens: usage.cache_read_input_tokens,
                                        cache_write_input_tokens: usage.cache_creation_input_tokens,
                                    }),
                                };

                                stream_channel_tx
                                    .send(GenerationDelta::Usage {
                                        usage: usage.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                completion_response.usage = Some(usage);
                            }
                        }

                        _ => {}
                    }
                }
                Err(_) => {
                    // Don't fail the entire stream over one unparseable event;
                    // skip it and keep reading.
                }
            }
        }
    }

    if let Some(choice) = choices.get_mut(&0) {
        choice.message.content = LLMMessageContent::List(contents);
    }

    completion_response.choices = choices
        .into_iter()
        .sorted_by(|(index, _), (other_index, _)| index.cmp(other_index))
        .map(|(_, choice)| choice)
        .collect();

    Ok(completion_response)
}
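
// A hedged streaming sketch (hypothetical helper, not part of the original
// module): wires `chat_stream` to an mpsc channel and drains text deltas
// concurrently so the sender never blocks on a full channel. Assumes a tokio
// runtime; the capacity of 64 is arbitrary.
#[allow(dead_code)]
async fn stream_usage_sketch(
    config: &AnthropicConfig,
    input: AnthropicInput,
) -> Result<LLMCompletionResponse, AgentError> {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerationDelta>(64);

    // Consume deltas as they arrive; the receiver ends once `chat_stream`
    // returns and drops its sender.
    let printer = tokio::spawn(async move {
        while let Some(delta) = rx.recv().await {
            if let GenerationDelta::Content { content } = delta {
                print!("{content}");
            }
        }
    });

    let response = Anthropic::chat_stream(config, tx, input).await;
    let _ = printer.await;
    response
}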