// stakpak_shared/models/integrations/anthropic.rs

use crate::models::error::{AgentError, BadRequestErrorMessage};
use crate::models::llm::{
    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
    LLMMessageContent, LLMMessageTypedContent, LLMTool,
};
use crate::models::llm::{LLMTokenUsage, PromptTokensDetails};
use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
use futures_util::StreamExt;
use itertools::Itertools;
use reqwest::Response;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;

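/// Anthropic model identifiers supported by this integration, serialized as
/// the versioned model strings expected by the Messages API.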
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum AnthropicModel {
    #[serde(rename = "claude-haiku-4-5-20251001")]
    Claude45Haiku,
    #[serde(rename = "claude-sonnet-4-5-20250929")]
    Claude45Sonnet,
    #[serde(rename = "claude-opus-4-5-20251101")]
    Claude45Opus,
}

impl std::fmt::Display for AnthropicModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AnthropicModel::Claude45Haiku => write!(f, "claude-haiku-4-5-20251001"),
            AnthropicModel::Claude45Sonnet => write!(f, "claude-sonnet-4-5-20250929"),
            AnthropicModel::Claude45Opus => write!(f, "claude-opus-4-5-20251101"),
        }
    }
}

impl AnthropicModel {
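    /// Parse a model identifier string (e.g. "claude-sonnet-4-5-20250929")
    /// back into an `AnthropicModel`, reusing the serde rename attributes.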
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize Anthropic model".to_string())
    }

    /// Default smart model for Anthropic
    pub const DEFAULT_SMART_MODEL: AnthropicModel = AnthropicModel::Claude45Opus;

    /// Default eco model for Anthropic
    pub const DEFAULT_ECO_MODEL: AnthropicModel = AnthropicModel::Claude45Haiku;

    /// Default recovery model for Anthropic
    pub const DEFAULT_RECOVERY_MODEL: AnthropicModel = AnthropicModel::Claude45Haiku;

    /// Get the default smart model as a string
    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    /// Get the default eco model as a string
    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    /// Get the default recovery model as a string
    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}

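/// Context-window limits and pricing per model family. Prices are USD per
/// million tokens; Sonnet pricing steps up once a request crosses 200K
/// input tokens, while Haiku and Opus use a single flat tier.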
impl ContextAware for AnthropicModel {
    fn context_info(&self) -> ModelContextInfo {
        let model_name = self.to_string();

        if model_name.starts_with("claude-haiku") {
            return ModelContextInfo {
                max_tokens: 200_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 1.0,
                    output_cost_per_million: 5.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("claude-sonnet") {
            return ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![
                    ContextPricingTier {
                        label: "<200K tokens".to_string(),
                        input_cost_per_million: 3.0,
                        output_cost_per_million: 15.0,
                        upper_bound: Some(200_000),
                    },
                    ContextPricingTier {
                        label: ">200K tokens".to_string(),
                        input_cost_per_million: 6.0,
                        output_cost_per_million: 22.5,
                        upper_bound: None,
                    },
                ],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("claude-opus") {
            return ModelContextInfo {
                max_tokens: 200_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 5.0,
                    output_cost_per_million: 25.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        panic!("Unknown model: {}", model_name);
    }

    fn model_name(&self) -> String {
        match self {
            AnthropicModel::Claude45Sonnet => "Claude Sonnet 4.5".to_string(),
            AnthropicModel::Claude45Haiku => "Claude Haiku 4.5".to_string(),
            AnthropicModel::Claude45Opus => "Claude Opus 4.5".to_string(),
        }
    }
}

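/// Request parameters for a Messages API call. `messages` uses the
/// provider-agnostic `LLMMessage` shape; any message with role "system" is
/// lifted into the top-level `system` field when the payload is built.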
#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicInput {
    pub model: AnthropicModel,
    pub messages: Vec<LLMMessage>,
    pub grammar: Option<String>,
    pub max_tokens: u32,
    pub stop_sequences: Option<Vec<String>>,
    pub tools: Option<Vec<LLMTool>>,
    pub thinking: ThinkingInput,
}

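/// Extended-thinking configuration for the Messages API.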
#[derive(Serialize, Deserialize, Debug)]
pub struct ThinkingInput {
    pub r#type: ThinkingType,
    /// Must be ≥1024 and less than `max_tokens`
    pub budget_tokens: u32,
}

impl Default for ThinkingInput {
    fn default() -> Self {
        Self {
            r#type: ThinkingType::default(),
            budget_tokens: 1024,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingType {
    Enabled,
    #[default]
    Disabled,
}

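/// Token accounting returned by the API. The cache fields are only present
/// when prompt caching is involved, so they default to `None`.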
#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicOutputUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default)]
    pub cache_read_input_tokens: Option<u32>,
}

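/// A complete (non-streaming) Messages API response.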
#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicOutput {
    pub id: String,
    pub r#type: String,
    pub role: String,
    pub content: LLMMessageContent,
    pub model: String,
    pub stop_reason: String,
    pub usage: AnthropicOutputUsage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicErrorOutput {
    pub r#type: String,
    pub error: AnthropicError,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicError {
    pub message: String,
    pub r#type: String,
}

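/// Convert an Anthropic response into the provider-agnostic
/// `LLMCompletionResponse`; the single assistant message becomes choice
/// index 0.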
impl From<AnthropicOutput> for LLMCompletionResponse {
    fn from(val: AnthropicOutput) -> Self {
        let choices = vec![LLMChoice {
            finish_reason: Some(val.stop_reason.clone()),
            index: 0,
            message: LLMMessage {
                role: val.role.clone(),
                content: val.content,
            },
        }];

        LLMCompletionResponse {
            id: val.id,
            model: val.model,
            object: val.r#type,
            choices,
            created: chrono::Utc::now().timestamp_millis() as u64,
            usage: Some(val.usage.into()),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicStreamEvent {
    #[serde(rename = "type")]
    pub event: String,
    #[serde(flatten)]
    pub data: AnthropicStreamEventData,
}

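/// Fold cached-token counts into the prompt total so downstream cost and
/// context accounting sees every input token the request consumed.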
impl From<AnthropicOutputUsage> for LLMTokenUsage {
    fn from(usage: AnthropicOutputUsage) -> Self {
        let input_tokens = usage.input_tokens
            + usage.cache_creation_input_tokens.unwrap_or(0)
            + usage.cache_read_input_tokens.unwrap_or(0);
        let output_tokens = usage.output_tokens;
        Self {
            completion_tokens: output_tokens,
            prompt_tokens: input_tokens,
            total_tokens: input_tokens + output_tokens,
            prompt_tokens_details: Some(PromptTokensDetails {
                input_tokens: Some(input_tokens),
                output_tokens: Some(output_tokens),
                cache_read_input_tokens: usage.cache_read_input_tokens,
                cache_write_input_tokens: usage.cache_creation_input_tokens,
            }),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicStreamOutput {
    pub id: String,
    pub r#type: String,
    pub role: String,
    pub content: LLMMessageContent,
    pub model: String,
    pub stop_reason: Option<String>,
    pub usage: AnthropicOutputUsage,
}

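/// Server-sent event payloads for the streaming Messages API, tagged by the
/// `type` field: message_start, content_block_start, content_block_delta,
/// content_block_stop, message_delta, message_stop, and ping.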
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AnthropicStreamEventData {
    MessageStart {
        message: AnthropicStreamOutput,
    },
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    ContentBlockStop {
        index: usize,
    },
    MessageDelta {
        delta: MessageDelta,
        usage: Option<AnthropicOutputUsage>,
    },
    MessageStop {},
    Ping {},
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "thinking_delta")]
    ThinkingDelta { thinking: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}

#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct AnthropicConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}

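/// Stateless client for the Anthropic Messages API.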
pub struct Anthropic {}

impl Anthropic {
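    /// Send a non-streaming chat request and map the response into an
    /// `LLMCompletionResponse`.
    ///
    /// A minimal usage sketch (the API key value is a placeholder):
    /// ```ignore
    /// let config = AnthropicConfig {
    ///     api_endpoint: None, // defaults to https://api.anthropic.com/v1/messages
    ///     api_key: Some("<your-api-key>".to_string()),
    /// };
    /// let input = AnthropicInput {
    ///     model: AnthropicModel::DEFAULT_ECO_MODEL,
    ///     messages: vec![LLMMessage {
    ///         role: "user".to_string(),
    ///         content: LLMMessageContent::List(vec![LLMMessageTypedContent::Text {
    ///             text: "Hello!".to_string(),
    ///         }]),
    ///     }],
    ///     grammar: None,
    ///     max_tokens: 1024,
    ///     stop_sequences: None,
    ///     tools: None,
    ///     thinking: ThinkingInput::default(),
    /// };
    /// let response = Anthropic::chat(&config, input).await?;
    /// ```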
    pub async fn chat(
        config: &AnthropicConfig,
        input: AnthropicInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let mut payload = json!({
            "model": input.model.to_string(),
            "system": input.messages.iter().find(|mess| mess.role == "system").map(|mess| mess.content.clone()),
            "messages": input.messages.into_iter().filter(|message| message.role != "system").collect::<Vec<LLMMessage>>(),
            "max_tokens": input.max_tokens,
            "temperature": 0,
            "stream": false,
        });

        if let Some(tools) = input.tools {
            payload["tools"] = json!(tools);
        }

        if let Some(stop_sequences) = input.stop_sequences {
            payload["stop_sequences"] = json!(stop_sequences);
        }

        // Set up retries with exponential backoff
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.anthropic.com/v1/messages");
        let api_key = config.api_key.as_deref().unwrap_or("");

        // Send the POST request
        let response = client
            .post(api_endpoint)
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                let error_message = format!("Anthropic API request error: {e}");
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    error_message,
                )));
            }
        };

        // Check for HTTP status errors and extract error details if present
        if !response.status().is_success() {
            let status = response.status();
            let error_body = match response.text().await {
                Ok(body) => body,
                Err(_) => "Unable to read error response".to_string(),
            };

            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!(
                    "Anthropic API returned error status: {}, body: {}",
                    status, error_body
                ),
            )));
        }

        match response.json::<Value>().await {
            Ok(json) => {
                // Check if the response contains an error field
                if let Some(error_obj) = json.get("error") {
                    let error_message = if let Some(message) =
                        error_obj.get("message").and_then(|m| m.as_str())
                    {
                        message.to_string()
                    } else if let Some(error_type) = error_obj.get("type").and_then(|t| t.as_str())
                    {
                        format!("API error: {}", error_type)
                    } else {
                        serde_json::to_string(error_obj)
                            .unwrap_or_else(|_| "Unknown API error".to_string())
                    };
                    return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        error_message,
                    )));
                }

                // Keep a pretty-printed copy so the original response can be
                // included in the error if deserialization fails below
                let pretty_json = serde_json::to_string_pretty(&json).unwrap_or_default();
                match serde_json::from_value::<AnthropicOutput>(json) {
                    Ok(json_response) => Ok(json_response.into()),
                    Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        format!(
                            "Error deserializing JSON: {:?}\nOriginal JSON: {}",
                            e, pretty_json
                        ),
                    ))),
                }
            }
            Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!("Failed to decode Anthropic JSON response: {:?}", e),
            ))),
        }
    }

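    /// Send a streaming chat request and forward deltas to
    /// `stream_channel_tx` as they arrive, returning the assembled
    /// `LLMCompletionResponse` once the stream ends. The system prompt and
    /// the last tool definition are marked with `cache_control` so Anthropic
    /// can reuse them across calls via prompt caching.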
    pub async fn chat_stream(
        config: &AnthropicConfig,
        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
        input: AnthropicInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let mut payload = json!({
            "model": input.model.to_string(),
            "system": input.messages.iter().find(|mess| mess.role == "system").map(|mess| json!([
                {
                    "type": "text",
                    "text": mess.content.clone(),
                    "cache_control": {"type": "ephemeral", "ttl": "5m"}
                }
            ])),
            "messages": input.messages.into_iter().filter(|message| message.role != "system").collect::<Vec<LLMMessage>>(),
            "max_tokens": input.max_tokens,
            "temperature": 0,
            "stream": true,
        });

        if let Some(tools) = input.tools {
            payload["tools"] = json!(
                tools
                    .iter()
                    .map(|tool| {
                        let mut tool_json = json!(tool);
                        // Mark only the last tool definition as cacheable
                        if let Some(last_tool) = tools.last()
                            && tool == last_tool
                        {
                            tool_json["cache_control"] = json!({"type": "ephemeral", "ttl": "1h"});
                        }
                        tool_json
                    })
                    .collect::<Vec<serde_json::Value>>()
            );
        }

        if let Some(stop_sequences) = input.stop_sequences {
            payload["stop_sequences"] = json!(stop_sequences);
        }

        // Set up retries with exponential backoff
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.anthropic.com/v1/messages");

        let api_key = config.api_key.as_deref().unwrap_or("");

        // Send the POST request
        let response = client
            .post(api_endpoint)
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header(
                "anthropic-beta",
                "extended-cache-ttl-2025-04-11,context-1m-2025-08-07",
            )
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    e.to_string(),
                )));
            }
        };

        if !response.status().is_success() {
            let error_body = match response.json::<AnthropicErrorOutput>().await {
                Ok(body) => body,
                Err(_) => AnthropicErrorOutput {
                    r#type: "error".to_string(),
                    error: AnthropicError {
                        message: "Unable to read error response".to_string(),
                        r#type: "error".to_string(),
                    },
                },
            };

            match error_body.error.r#type.as_str() {
                "invalid_request_error" => {
                    return Err(AgentError::BadRequest(
                        BadRequestErrorMessage::InvalidAgentInput(error_body.error.message),
                    ));
                }
                _ => {
                    return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        error_body.error.message,
                    )));
                }
            }
        }

        let completion_response =
            process_stream(response, input.model.to_string(), stream_channel_tx).await?;

        Ok(completion_response)
    }
}

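/// Consume the SSE body of a streaming Messages response, forwarding each
/// delta over `stream_channel_tx` while accumulating the content blocks into
/// a final `LLMCompletionResponse`.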
pub async fn process_stream(
    response: Response,
    model: String,
    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
) -> Result<LLMCompletionResponse, AgentError> {
    let mut completion_response = LLMCompletionResponse {
        id: "".to_string(),
        model: model.clone(),
        object: "chat.completion".to_string(),
        choices: vec![],
        created: chrono::Utc::now().timestamp_millis() as u64,
        usage: None,
    };

    let mut choices: HashMap<usize, LLMChoice> = HashMap::from([(
        0,
        LLMChoice {
            finish_reason: None,
            index: 0,
            message: LLMMessage {
                role: "assistant".to_string(),
                content: LLMMessageContent::List(vec![]),
            },
        },
    )]);
    let mut contents: Vec<LLMMessageTypedContent> = vec![];
    let mut stream = response.bytes_stream();
    let mut unparsed_data = String::new();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            let error_message = format!("Failed to read stream chunk from Anthropic API: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        let text = std::str::from_utf8(&chunk).map_err(|e| {
            let error_message = format!("Failed to parse UTF-8 from Anthropic response: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        unparsed_data.push_str(text);

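        // SSE events are separated by a blank line; a network chunk may end
        // mid-event, so buffer and only parse complete events.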
        while let Some(event_end) = unparsed_data.find("\n\n") {
            let event_str = unparsed_data[..event_end].to_string();
            unparsed_data = unparsed_data[event_end + 2..].to_string();

            if !event_str.starts_with("event: ") {
                continue;
            }

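            // Extract the JSON payload that follows the "data: " prefix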
            let json_str = &event_str[event_str.find("data: ").map(|i| i + 6).unwrap_or(6)..];
            if json_str == "[DONE]" {
                continue;
            }

            match serde_json::from_str::<AnthropicStreamEventData>(json_str) {
                Ok(data) => {
                    match data {
                        AnthropicStreamEventData::MessageStart { message } => {
                            completion_response.id = message.id;
                            completion_response.model = message.model;
                            completion_response.object = message.r#type;
                            completion_response.usage = Some(message.usage.into());
                        }
                        AnthropicStreamEventData::ContentBlockStart {
                            content_block,
                            index,
                        } => match content_block {
                            ContentBlock::Text { text } => {
                                stream_channel_tx
                                    .send(GenerationDelta::Content {
                                        // if this will be rendered as markdown, we need to escape the < and >
                                        content: text.clone(), //.replace("<", "\\<").replace(">", "\\>"),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                contents.push(LLMMessageTypedContent::Text { text: text.clone() });
                            }
                            ContentBlock::Thinking { thinking } => {
                                stream_channel_tx
                                    .send(GenerationDelta::Thinking {
                                        thinking: thinking.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                // Thinking blocks are accumulated as plain text content
                                contents.push(LLMMessageTypedContent::Text {
                                    text: thinking.clone(),
                                });
                            }
                            ContentBlock::ToolUse { id, name, input: _ } => {
                                stream_channel_tx
                                    .send(GenerationDelta::ToolUse {
                                        tool_use: GenerationDeltaToolUse {
                                            id: Some(id.clone()),
                                            name: Some(name.clone()),
                                            input: Some(String::new()),
                                            index,
                                        },
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                // Initialize with an empty string; the arguments
                                // accumulate via InputJsonDelta events
                                contents.push(LLMMessageTypedContent::ToolCall {
                                    id: id.clone(),
                                    name: name.clone(),
                                    args: serde_json::Value::String(String::new()),
                                });
                            }
                        },
                        AnthropicStreamEventData::ContentBlockDelta { delta, index } => {
                            if let Some(content) = contents.get_mut(index) {
                                match delta {
                                    ContentDelta::TextDelta { text } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::Content {
                                                // if this will be rendered as markdown, we need to escape the < and >
                                                content: text.clone(), //.replace("<", "\\<").replace(">", "\\>"),
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        let delta_text = text.clone();
                                        if let LLMMessageTypedContent::Text { text } = content {
                                            text.push_str(&delta_text);
                                        }
                                    }
                                    ContentDelta::ThinkingDelta { thinking } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::Thinking {
                                                thinking: thinking.clone(),
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let LLMMessageTypedContent::Text { text } = content {
                                            text.push_str(&thinking);
                                        }
                                    }
                                    ContentDelta::InputJsonDelta { partial_json } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::ToolUse {
                                                tool_use: GenerationDeltaToolUse {
                                                    id: None,
                                                    name: None,
                                                    input: Some(partial_json.clone()),
                                                    index,
                                                },
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let Some(LLMMessageTypedContent::ToolCall {
                                            args: serde_json::Value::String(accumulated_json),
                                            ..
                                        }) = contents.get_mut(index)
                                        {
                                            accumulated_json.push_str(&partial_json);
                                        }
                                    }
                                }
                            }
                        }
                        AnthropicStreamEventData::ContentBlockStop { index } => {
                            if let Some(LLMMessageTypedContent::ToolCall { args, .. }) =
                                contents.get_mut(index)
                                && let serde_json::Value::String(json_str) = args
                            {
                                // Try to parse the accumulated JSON string
                                *args = serde_json::from_str(json_str).unwrap_or_else(|_| {
                                    // If parsing fails, keep it as a string
                                    serde_json::Value::String(json_str.clone())
                                });
                            }
                        }
                        AnthropicStreamEventData::MessageDelta { delta, usage } => {
                            if let Some(stop_reason) = delta.stop_reason {
                                for choice in choices.values_mut() {
                                    if choice.finish_reason.is_none() {
                                        choice.finish_reason = Some(stop_reason.clone());
                                    }
                                }
                            }
                            if let Some(usage) = usage {
                                let usage = LLMTokenUsage {
                                    prompt_tokens: usage.input_tokens,
                                    completion_tokens: usage.output_tokens,
                                    total_tokens: usage.input_tokens
                                        + usage.cache_creation_input_tokens.unwrap_or(0)
                                        + usage.cache_read_input_tokens.unwrap_or(0)
                                        + usage.output_tokens,
                                    prompt_tokens_details: Some(PromptTokensDetails {
                                        input_tokens: Some(usage.input_tokens),
                                        output_tokens: Some(usage.output_tokens),
                                        cache_read_input_tokens: usage.cache_read_input_tokens,
                                        cache_write_input_tokens: usage.cache_creation_input_tokens,
                                    }),
                                };

                                stream_channel_tx
                                    .send(GenerationDelta::Usage {
                                        usage: usage.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                completion_response.usage = Some(usage);
                            }
                        }

                        _ => {}
                    }
                }
                Err(_) => {
                    // Skip events we can't parse rather than failing the
                    // entire stream
                }
            }
        }
    }

    if let Some(choice) = choices.get_mut(&0) {
        choice.message.content = LLMMessageContent::List(contents);
    }

    completion_response.choices = choices
        .into_iter()
        .sorted_by(|(index, _), (other_index, _)| index.cmp(other_index))
        .map(|(_, choice)| choice)
        .collect();

    Ok(completion_response)
}