use crate::models::error::{AgentError, BadRequestErrorMessage};
use crate::models::llm::{
    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
    LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage, LLMTool, PromptTokensDetails,
};
use futures_util::StreamExt;
use itertools::Itertools;
use reqwest::Response;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
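/// Models exposed through this client. The serde rename strings are the dated
/// model IDs the Anthropic Messages API expects on the wire.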
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum AnthropicModel {
    #[serde(rename = "claude-haiku-4-5-20251001")]
    Claude45Haiku,
    #[serde(rename = "claude-sonnet-4-5-20250929")]
    Claude45Sonnet,
    #[serde(rename = "claude-opus-4-5-20251101")]
    Claude45Opus,
}

impl std::fmt::Display for AnthropicModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AnthropicModel::Claude45Haiku => write!(f, "claude-haiku-4-5-20251001"),
            AnthropicModel::Claude45Sonnet => write!(f, "claude-sonnet-4-5-20250929"),
            AnthropicModel::Claude45Opus => write!(f, "claude-opus-4-5-20251101"),
        }
    }
}

impl AnthropicModel {
    /// Parses a model ID string by round-tripping through the serde renames above.
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| format!("Unknown Anthropic model: {s}"))
    }

    pub const DEFAULT_SMART_MODEL: AnthropicModel = AnthropicModel::Claude45Opus;

    pub const DEFAULT_ECO_MODEL: AnthropicModel = AnthropicModel::Claude45Haiku;

    pub const DEFAULT_RECOVERY_MODEL: AnthropicModel = AnthropicModel::Claude45Haiku;

    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}
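// Example (hypothetical call site): resolve a configured model name, falling
// back to the eco model when the string is unrecognized.
//
//     let model = AnthropicModel::from_string(&cfg.model_name)
//         .unwrap_or(AnthropicModel::DEFAULT_ECO_MODEL);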

/// Request parameters for a single Messages API call. Note that `grammar` and
/// `thinking` are carried here but are not yet forwarded in the payloads
/// built by `chat` / `chat_stream`.
#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicInput {
    pub model: AnthropicModel,
    pub messages: Vec<LLMMessage>,
    pub grammar: Option<String>,
    pub max_tokens: u32,
    pub stop_sequences: Option<Vec<String>>,
    pub tools: Option<Vec<LLMTool>>,
    pub thinking: ThinkingInput,
}

/// Extended-thinking settings, mirroring the Messages API `thinking` parameter.
#[derive(Serialize, Deserialize, Debug)]
pub struct ThinkingInput {
    pub r#type: ThinkingType,
    pub budget_tokens: u32,
}

impl Default for ThinkingInput {
    fn default() -> Self {
        Self {
            r#type: ThinkingType::default(),
            budget_tokens: 1024,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingType {
    Enabled,
    #[default]
    Disabled,
}
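// Via the derives above, `ThinkingInput` serializes as, e.g., the default
// `{"type":"disabled","budget_tokens":1024}`, matching the shape of the
// Messages API `thinking` parameter.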

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicOutputUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default)]
    pub cache_read_input_tokens: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicOutput {
    pub id: String,
    pub r#type: String,
    pub role: String,
    pub content: LLMMessageContent,
    pub model: String,
    pub stop_reason: String,
    pub usage: AnthropicOutputUsage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicErrorOutput {
    pub r#type: String,
    pub error: AnthropicError,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicError {
    pub message: String,
    pub r#type: String,
}

impl From<AnthropicOutput> for LLMCompletionResponse {
    fn from(val: AnthropicOutput) -> Self {
        let choices = vec![LLMChoice {
            finish_reason: Some(val.stop_reason.clone()),
            index: 0,
            message: LLMMessage {
                role: val.role.clone(),
                content: val.content,
            },
        }];

        LLMCompletionResponse {
            id: val.id,
            model: val.model,
            object: val.r#type,
            choices,
            created: chrono::Utc::now().timestamp_millis() as u64,
            usage: Some(val.usage.into()),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicStreamEvent {
    #[serde(rename = "type")]
    pub event: String,
    #[serde(flatten)]
    pub data: AnthropicStreamEventData,
}

impl From<AnthropicOutputUsage> for LLMTokenUsage {
    fn from(usage: AnthropicOutputUsage) -> Self {
        // Cache reads and writes are still input tokens, so fold them into
        // the prompt total; the cache read/write split is preserved in the
        // details below.
        let input_tokens = usage.input_tokens
            + usage.cache_creation_input_tokens.unwrap_or(0)
            + usage.cache_read_input_tokens.unwrap_or(0);
        let output_tokens = usage.output_tokens;
        Self {
            completion_tokens: output_tokens,
            prompt_tokens: input_tokens,
            total_tokens: input_tokens + output_tokens,
            prompt_tokens_details: Some(PromptTokensDetails {
                input_tokens: Some(input_tokens),
                output_tokens: Some(output_tokens),
                cache_read_input_tokens: usage.cache_read_input_tokens,
                cache_write_input_tokens: usage.cache_creation_input_tokens,
            }),
        }
    }
}
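// Worked example: input_tokens = 12, cache_read_input_tokens = 900, and
// output_tokens = 50 yield prompt_tokens = 912 and total_tokens = 962.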

#[derive(Serialize, Deserialize, Debug)]
pub struct AnthropicStreamOutput {
    pub id: String,
    pub r#type: String,
    pub role: String,
    pub content: LLMMessageContent,
    pub model: String,
    pub stop_reason: Option<String>,
    pub usage: AnthropicOutputUsage,
}
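// Server-sent events from the Messages API deserialize into the variants
// below via the internally tagged `type` field, e.g.
//
//     data: {"type":"content_block_delta","index":0,
//            "delta":{"type":"text_delta","text":"Hello"}}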
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AnthropicStreamEventData {
    MessageStart {
        message: AnthropicStreamOutput,
    },
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    ContentBlockStop {
        index: usize,
    },
    MessageDelta {
        delta: MessageDelta,
        usage: Option<AnthropicOutputUsage>,
    },
    MessageStop {},
    Ping {},
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "thinking_delta")]
    ThinkingDelta { thinking: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
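// Tool arguments stream as `input_json_delta` fragments that only form valid
// JSON once concatenated, e.g. `{"ci` + `ty":"Par` + `is"}`; `process_stream`
// buffers them per content block and parses the result at `content_block_stop`.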

#[derive(Serialize, Deserialize, Debug)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}

#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct AnthropicConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}
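// Example (hypothetical): both fields are optional; the endpoint falls back
// to https://api.anthropic.com/v1/messages when unset.
//
//     let config = AnthropicConfig {
//         api_endpoint: None,
//         api_key: std::env::var("ANTHROPIC_API_KEY").ok(),
//     };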

pub struct Anthropic {}

impl Anthropic {
    /// Sends a non-streaming Messages API request and maps the response into
    /// the provider-agnostic `LLMCompletionResponse`.
    pub async fn chat(
        config: &AnthropicConfig,
        input: AnthropicInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        // Anthropic takes the system prompt as a top-level field, so split it
        // out of the message list.
        let mut payload = json!({
            "model": input.model.to_string(),
            "system": input.messages.iter().find(|mess| mess.role == "system").map(|mess| mess.content.clone()),
            "messages": input.messages.into_iter().filter(|message| message.role != "system").collect::<Vec<LLMMessage>>(),
            "max_tokens": input.max_tokens,
            "temperature": 0,
            "stream": false,
        });

        if let Some(tools) = input.tools {
            payload["tools"] = json!(tools);
        }

        if let Some(stop_sequences) = input.stop_sequences {
            payload["stop_sequences"] = json!(stop_sequences);
        }

        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.anthropic.com/v1/messages");
        let api_key = config.api_key.as_deref().unwrap_or("");

        let response = client
            .post(api_endpoint)
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                let error_message = format!("Anthropic API request error: {e}");
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    error_message,
                )));
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            let error_body = match response.text().await {
                Ok(body) => body,
                Err(_) => "Unable to read error response".to_string(),
            };

            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!(
                    "Anthropic API returned error status: {}, body: {}",
                    status, error_body
                ),
            )));
        }

        match response.json::<Value>().await {
            Ok(json) => {
                // Keep a pretty-printed copy so deserialization failures can
                // report the payload that broke.
                let pretty_json = serde_json::to_string_pretty(&json).unwrap_or_default();
                match serde_json::from_value::<AnthropicOutput>(json) {
                    Ok(json_response) => Ok(json_response.into()),
                    Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        format!(
                            "Error deserializing JSON: {:?}\nOriginal JSON: {}",
                            e, pretty_json
                        ),
                    ))),
                }
            }
            Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!("Failed to decode Anthropic JSON response: {:?}", e),
            ))),
        }
    }
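    // Example (hypothetical): a one-shot, non-streaming completion with the
    // default eco model, given a prepared `messages` vector.
    //
    //     let input = AnthropicInput {
    //         model: AnthropicModel::DEFAULT_ECO_MODEL,
    //         messages,
    //         grammar: None,
    //         max_tokens: 1024,
    //         stop_sequences: None,
    //         tools: None,
    //         thinking: ThinkingInput::default(),
    //     };
    //     let response = Anthropic::chat(&config, input).await?;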

    /// Sends a streaming Messages API request. Each parsed delta is forwarded
    /// on `stream_channel_tx` while the full response is accumulated and
    /// returned once the stream ends.
    pub async fn chat_stream(
        config: &AnthropicConfig,
        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
        input: AnthropicInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        // The system prompt is sent as a cacheable top-level block.
        let mut payload = json!({
            "model": input.model.to_string(),
            "system": input.messages.iter().find(|mess| mess.role == "system").map(|mess| json!([
                {
                    "type": "text",
                    "text": mess.content.clone(),
                    "cache_control": {"type": "ephemeral", "ttl": "5m"}
                }
            ])),
            "messages": input.messages.into_iter().filter(|message| message.role != "system").collect::<Vec<LLMMessage>>(),
            "max_tokens": input.max_tokens,
            "temperature": 0,
            "stream": true,
        });

        if let Some(tools) = input.tools {
            // Mark only the final tool as a cache breakpoint; the prefix up
            // to and including it is then cached as a unit. Comparing by
            // index (rather than equality with the last element) avoids
            // tagging duplicate tool definitions twice.
            let last_index = tools.len().saturating_sub(1);
            payload["tools"] = json!(
                tools
                    .iter()
                    .enumerate()
                    .map(|(i, tool)| {
                        let mut tool_json = json!(tool);
                        if i == last_index {
                            tool_json["cache_control"] =
                                json!({"type": "ephemeral", "ttl": "1h"});
                        }
                        tool_json
                    })
                    .collect::<Vec<serde_json::Value>>()
            );
        }

        if let Some(stop_sequences) = input.stop_sequences {
            payload["stop_sequences"] = json!(stop_sequences);
        }

        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.anthropic.com/v1/messages");

        let api_key = config.api_key.as_deref().unwrap_or("");

        let response = client
            .post(api_endpoint)
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header(
                "anthropic-beta",
                "extended-cache-ttl-2025-04-11,context-1m-2025-08-07",
            )
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    e.to_string(),
                )));
            }
        };

        if !response.status().is_success() {
            let error_body = match response.json::<AnthropicErrorOutput>().await {
                Ok(body) => body,
                Err(_) => AnthropicErrorOutput {
                    r#type: "error".to_string(),
                    error: AnthropicError {
                        message: "Unable to read error response".to_string(),
                        r#type: "error".to_string(),
                    },
                },
            };

            match error_body.error.r#type.as_str() {
                "invalid_request_error" => {
                    return Err(AgentError::BadRequest(
                        BadRequestErrorMessage::InvalidAgentInput(error_body.error.message),
                    ));
                }
                _ => {
                    return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        error_body.error.message,
                    )));
                }
            }
        }

        let completion_response =
            process_stream(response, input.model.to_string(), stream_channel_tx).await?;

        Ok(completion_response)
    }
}
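// Example wiring for `chat_stream` (hypothetical call site): consume deltas
// as they arrive while awaiting the final accumulated response.
//
//     let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerationDelta>(64);
//     tokio::spawn(async move {
//         while let Some(delta) = rx.recv().await {
//             // render the delta incrementally
//         }
//     });
//     let response = Anthropic::chat_stream(&config, tx, input).await?;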

/// Consumes an SSE response body from the Messages API, forwarding each delta
/// over `stream_channel_tx` and folding everything into a single
/// `LLMCompletionResponse`.
pub async fn process_stream(
    response: Response,
    model: String,
    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
) -> Result<LLMCompletionResponse, AgentError> {
    let mut completion_response = LLMCompletionResponse {
        id: "".to_string(),
        model: model.clone(),
        object: "chat.completion".to_string(),
        choices: vec![],
        created: chrono::Utc::now().timestamp_millis() as u64,
        usage: None,
    };

    // Anthropic streams a single choice; seed index 0 up front.
    let mut choices: HashMap<usize, LLMChoice> = HashMap::from([(
        0,
        LLMChoice {
            finish_reason: None,
            index: 0,
            message: LLMMessage {
                role: "assistant".to_string(),
                content: LLMMessageContent::List(vec![]),
            },
        },
    )]);
    // Content blocks, addressed by the `index` carried on each event.
    let mut contents: Vec<LLMMessageTypedContent> = vec![];
    let mut stream = response.bytes_stream();
    let mut unparsed_data = String::new();
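    // SSE events can be split across (or share) HTTP chunks, so buffer raw
    // text and cut complete events on the blank-line delimiter below.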
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            let error_message = format!("Failed to read stream chunk from Anthropic API: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        let text = std::str::from_utf8(&chunk).map_err(|e| {
            let error_message = format!("Failed to parse UTF-8 from Anthropic response: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        unparsed_data.push_str(text);

        while let Some(event_end) = unparsed_data.find("\n\n") {
            let event_str = unparsed_data[..event_end].to_string();
            unparsed_data = unparsed_data[event_end + 2..].to_string();

            if !event_str.starts_with("event: ") {
                continue;
            }

            // Extract the JSON payload after the "data: " prefix; skip
            // events that carry none.
            let Some(data_start) = event_str.find("data: ") else {
                continue;
            };
            let json_str = &event_str[data_start + 6..];
            if json_str == "[DONE]" {
                continue;
            }

            match serde_json::from_str::<AnthropicStreamEventData>(json_str) {
                Ok(data) => {
                    match data {
                        AnthropicStreamEventData::MessageStart { message } => {
                            completion_response.id = message.id;
                            completion_response.model = message.model;
                            completion_response.object = message.r#type;
                            completion_response.usage = Some(message.usage.into());
                        }
                        AnthropicStreamEventData::ContentBlockStart {
                            content_block,
                            index,
                        } => match content_block {
                            ContentBlock::Text { text } => {
                                stream_channel_tx
                                    .send(GenerationDelta::Content {
                                        content: text.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                contents.push(LLMMessageTypedContent::Text { text });
                            }
                            ContentBlock::Thinking { thinking } => {
                                stream_channel_tx
                                    .send(GenerationDelta::Thinking {
                                        thinking: thinking.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                // Thinking is accumulated as plain text content.
                                contents.push(LLMMessageTypedContent::Text { text: thinking });
                            }
                            ContentBlock::ToolUse { id, name, input: _ } => {
                                stream_channel_tx
                                    .send(GenerationDelta::ToolUse {
                                        tool_use: GenerationDeltaToolUse {
                                            id: Some(id.clone()),
                                            name: Some(name.clone()),
                                            input: Some(String::new()),
                                            index,
                                        },
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                // Arguments arrive later as input_json_delta
                                // fragments; start from an empty JSON buffer.
                                contents.push(LLMMessageTypedContent::ToolCall {
                                    id,
                                    name,
                                    args: serde_json::Value::String(String::new()),
                                });
                            }
                        },
                        AnthropicStreamEventData::ContentBlockDelta { delta, index } => {
                            if let Some(content) = contents.get_mut(index) {
                                match delta {
                                    ContentDelta::TextDelta { text } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::Content {
                                                content: text.clone(),
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let LLMMessageTypedContent::Text { text: existing } =
                                            content
                                        {
                                            existing.push_str(&text);
                                        }
                                    }
                                    ContentDelta::ThinkingDelta { thinking } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::Thinking {
                                                thinking: thinking.clone(),
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let LLMMessageTypedContent::Text { text } = content {
                                            text.push_str(&thinking);
                                        }
                                    }
                                    ContentDelta::InputJsonDelta { partial_json } => {
                                        stream_channel_tx
                                            .send(GenerationDelta::ToolUse {
                                                tool_use: GenerationDeltaToolUse {
                                                    id: None,
                                                    name: None,
                                                    input: Some(partial_json.clone()),
                                                    index,
                                                },
                                            })
                                            .await
                                            .map_err(|e| {
                                                AgentError::BadRequest(
                                                    BadRequestErrorMessage::ApiError(e.to_string()),
                                                )
                                            })?;
                                        if let LLMMessageTypedContent::ToolCall {
                                            args: serde_json::Value::String(accumulated_json),
                                            ..
                                        } = content
                                        {
                                            accumulated_json.push_str(&partial_json);
                                        }
                                    }
                                }
                            }
                        }
                        AnthropicStreamEventData::ContentBlockStop { index } => {
                            // The block is complete: parse the accumulated
                            // tool-argument JSON, keeping the raw string if it
                            // does not parse.
                            if let Some(LLMMessageTypedContent::ToolCall { args, .. }) =
                                contents.get_mut(index)
                                && let serde_json::Value::String(json_str) = args
                            {
                                *args = serde_json::from_str(json_str).unwrap_or_else(|_| {
                                    serde_json::Value::String(json_str.clone())
                                });
                            }
                        }
                        AnthropicStreamEventData::MessageDelta { delta, usage } => {
                            if let Some(stop_reason) = delta.stop_reason {
                                for choice in choices.values_mut() {
                                    if choice.finish_reason.is_none() {
                                        choice.finish_reason = Some(stop_reason.clone());
                                    }
                                }
                            }
                            if let Some(usage) = usage {
                                // Reuse the shared conversion so cache tokens
                                // are folded into the prompt total here too.
                                let usage: LLMTokenUsage = usage.into();

                                stream_channel_tx
                                    .send(GenerationDelta::Usage {
                                        usage: usage.clone(),
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                                completion_response.usage = Some(usage);
                            }
                        }

                        _ => {}
                    }
                }
                Err(_) => {
                    // Ignore events that fail to parse (unknown or malformed
                    // payloads) rather than aborting the stream.
                }
            }
        }
    }

    if let Some(choice) = choices.get_mut(&0) {
        choice.message.content = LLMMessageContent::List(contents);
    }

    completion_response.choices = choices
        .into_iter()
        .sorted_by_key(|(index, _)| *index)
        .map(|(_, choice)| choice)
        .collect();

    Ok(completion_response)
}