//! AWS Bedrock implementation of the `AiProvider` trait (tycode_core/ai/bedrock.rs).

1use std::collections::HashSet;
2use std::pin::Pin;
3
4use tokio_stream::Stream;
5
6use base64::Engine;
7
8use aws_sdk_bedrockruntime::{
9    operation::converse::{builders::ConverseFluentBuilder, ConverseError},
10    operation::converse_stream::{builders::ConverseStreamFluentBuilder, ConverseStreamError},
11    types::ConverseStreamOutput as BedrockStreamEvent,
12    types::{
13        CachePointBlock, ContentBlock as BedrockContentBlock, ImageBlock, ImageFormat, ImageSource,
14        Message as BedrockMessage, ReasoningContentBlock, ReasoningTextBlock, SystemContentBlock,
15        Tool, ToolConfiguration, ToolInputSchema, ToolResultBlock, ToolResultContentBlock,
16        ToolSpecification, ToolUseBlock,
17    },
18    Client as BedrockClient,
19};
20use aws_smithy_types::Blob;
21use serde_json::json;
22
23use crate::ai::{error::AiError, provider::AiProvider, types::*};
24use crate::ai::{
25    json::{from_doc, to_doc},
26    model::Model,
27};
28
29#[derive(Clone)]
30pub struct BedrockProvider {
31    client: BedrockClient,
32}
33
34impl BedrockProvider {
35    pub fn new(client: BedrockClient) -> Self {
36        Self { client }
37    }
38
39    fn get_bedrock_model_id(&self, model: &Model) -> Result<String, AiError> {
40        let model_id = match model {
41            Model::ClaudeSonnet45 => "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
42            Model::ClaudeHaiku45 => "us.anthropic.claude-haiku-4-5-20251001-v1:0",
43            Model::ClaudeOpus46 => "global.anthropic.claude-opus-4-6-v1",
44            Model::ClaudeOpus45 => "global.anthropic.claude-opus-4-5-20251101-v1:0",
45            Model::GptOss120b => "openai.gpt-oss-120b-1:0",
46            _ => {
47                return Err(AiError::Terminal(anyhow::anyhow!(
48                    "Model {} is not supported in bedrock",
49                    model.name()
50                )))
51            }
52        };
53        Ok(model_id.to_string())
54    }
55
56    fn convert_to_bedrock_messages(
57        &self,
58        messages: &[Message],
59        model: Model,
60    ) -> Result<Vec<BedrockMessage>, AiError> {
61        let mut bedrock_messages = Vec::new();
62
63        for (msg_index, msg) in messages.iter().enumerate() {
64            let role = match msg.role {
65                MessageRole::User => aws_sdk_bedrockruntime::types::ConversationRole::User,
66                MessageRole::Assistant => {
67                    aws_sdk_bedrockruntime::types::ConversationRole::Assistant
68                }
69            };
70
71            let mut content_blocks = Vec::new();
72            for block in msg.content.blocks() {
73                match block {
74                    ContentBlock::Text(text) => {
75                        if !text.trim().is_empty() {
76                            content_blocks.push(BedrockContentBlock::Text(text.trim().to_string()));
77                        }
78                    }
79                    ContentBlock::ReasoningContent(reasoning) => {
80                        let reasoning_content = if let Some(blob) = &reasoning.blob {
81                            ReasoningContentBlock::RedactedContent(Blob::new(blob.clone()))
82                        } else {
83                            let mut text_block_builder =
84                                ReasoningTextBlock::builder().text(&reasoning.text);
85
86                            if let Some(signature) = &reasoning.signature {
87                                text_block_builder = text_block_builder.signature(signature);
88                            }
89
90                            let text_block = text_block_builder.build().map_err(|e| {
91                                AiError::Terminal(anyhow::anyhow!(
92                                    "Failed to build reasoning text block: {:?}",
93                                    e
94                                ))
95                            })?;
96
97                            ReasoningContentBlock::ReasoningText(text_block)
98                        };
99
100                        content_blocks
101                            .push(BedrockContentBlock::ReasoningContent(reasoning_content));
102                    }
103                    ContentBlock::ToolUse(tool_use) => {
104                        let args = if tool_use.arguments.is_null() {
105                            tracing::warn!(
106                                tool_name = %tool_use.name,
107                                "Null tool arguments in conversation history, substituting empty object"
108                            );
109                            serde_json::Value::Object(Default::default())
110                        } else {
111                            tool_use.arguments.clone()
112                        };
113                        let tool_use_block = ToolUseBlock::builder()
114                            .tool_use_id(&tool_use.id)
115                            .name(&tool_use.name)
116                            .input(to_doc(args))
117                            .build()
118                            .map_err(|e| {
119                                AiError::Terminal(anyhow::anyhow!(
120                                    "Failed to build tool use block: {:?}",
121                                    e
122                                ))
123                            })?;
124                        content_blocks.push(BedrockContentBlock::ToolUse(tool_use_block));
125                    }
126                    ContentBlock::ToolResult(tool_result) => {
127                        let tool_result_block = ToolResultBlock::builder()
128                            .tool_use_id(&tool_result.tool_use_id)
129                            .content(ToolResultContentBlock::Text(tool_result.content.clone()))
130                            .build()
131                            .map_err(|e| {
132                                AiError::Terminal(anyhow::anyhow!(
133                                    "Failed to build tool result block: {:?}",
134                                    e
135                                ))
136                            })?;
137                        content_blocks.push(BedrockContentBlock::ToolResult(tool_result_block));
138                    }
139                    ContentBlock::Image(image) => {
140                        content_blocks.push(BedrockContentBlock::Image(build_bedrock_image_block(
141                            image,
142                        )?));
143                    }
144                }
145            }
146
147            if content_blocks.is_empty() {
148                content_blocks.push(BedrockContentBlock::Text("...".to_string()));
149            }
150
151            // Reorder: reasoning blocks first for deterministic ordering
152            // and cache point compatibility (cache point cannot follow reasoning)
153            let (reasoning, non_reasoning): (Vec<_>, Vec<_>) = content_blocks
154                .into_iter()
155                .partition(|b| matches!(b, BedrockContentBlock::ReasoningContent(_)));
156            content_blocks = reasoning;
157            content_blocks.extend(non_reasoning);
158
159            let last_is_reasoning = content_blocks
160                .last()
161                .is_some_and(|b| matches!(b, BedrockContentBlock::ReasoningContent(_)));
162            if model.supports_prompt_caching()
163                && messages.len() >= 2
164                && msg_index == messages.len() - 2
165                && !last_is_reasoning
166            {
167                content_blocks.push(BedrockContentBlock::CachePoint(Self::build_cache_point()?));
168            }
169
170            bedrock_messages.push(
171                BedrockMessage::builder()
172                    .role(role)
173                    .set_content(Some(content_blocks))
174                    .build()
175                    .map_err(|e| {
176                        AiError::Terminal(anyhow::anyhow!("Failed to build message: {:?}", e))
177                    })?,
178            );
179        }
180
181        Ok(bedrock_messages)
182    }
183}
184
185fn map_image_format(media_type: &str) -> Result<ImageFormat, AiError> {
186    match media_type {
187        "image/png" => Ok(ImageFormat::Png),
188        "image/jpeg" => Ok(ImageFormat::Jpeg),
189        "image/gif" => Ok(ImageFormat::Gif),
190        "image/webp" => Ok(ImageFormat::Webp),
191        other => Err(AiError::Terminal(anyhow::anyhow!(
192            "Unsupported image format: {other}"
193        ))),
194    }
195}
196
197fn build_bedrock_image_block(image: &ImageData) -> Result<ImageBlock, AiError> {
198    let bytes = base64::engine::general_purpose::STANDARD
199        .decode(&image.data)
200        .map_err(|e| AiError::Terminal(anyhow::anyhow!("Failed to decode image base64: {e:?}")))?;
201
202    let format = map_image_format(&image.media_type)?;
203
204    ImageBlock::builder()
205        .format(format)
206        .source(ImageSource::Bytes(Blob::new(bytes)))
207        .build()
208        .map_err(|e| AiError::Terminal(anyhow::anyhow!("Failed to build image block: {e:?}")))
209}
210
211impl BedrockProvider {
212    fn extract_content_blocks(&self, message: BedrockMessage) -> Content {
213        let mut content_blocks = Vec::new();
214
215        tracing::debug!("Processing {} content blocks", message.content().len());
216
217        for (i, content) in message.content().iter().enumerate() {
218            tracing::debug!("Content block {}: {:?}", i, content);
219
220            match content {
221                BedrockContentBlock::Text(text) => {
222                    tracing::debug!("Text block: {}", text);
223                    content_blocks.push(ContentBlock::Text(text.clone()));
224                }
225                BedrockContentBlock::ReasoningContent(block) => {
226                    let reasoning_data = if block.is_reasoning_text() {
227                        let block = block.as_reasoning_text().unwrap();
228                        ReasoningData {
229                            text: block.text.clone(),
230                            signature: block.signature.clone(),
231                            blob: None,
232                            raw_json: None,
233                        }
234                    } else {
235                        let block = block.as_redacted_content().unwrap();
236                        ReasoningData {
237                            text: "** Redacted reasoning content **".to_string(),
238                            signature: None,
239                            blob: Some(block.clone().into_inner()),
240                            raw_json: None,
241                        }
242                    };
243                    content_blocks.push(ContentBlock::ReasoningContent(reasoning_data));
244                }
245                BedrockContentBlock::ToolUse(tool_use) => {
246                    let tool_use_data = ToolUseData {
247                        id: tool_use.tool_use_id().to_string(),
248                        name: tool_use.name().to_string(),
249                        arguments: from_doc(tool_use.input().clone()),
250                    };
251                    content_blocks.push(ContentBlock::ToolUse(tool_use_data));
252                }
253                _ => (),
254            }
255        }
256
257        Content::from(content_blocks)
258    }
259
260    fn build_cache_point() -> Result<CachePointBlock, AiError> {
261        CachePointBlock::builder()
262            .r#type(aws_sdk_bedrockruntime::types::CachePointType::Default)
263            .build()
264            .map_err(|e| {
265                AiError::Terminal(anyhow::anyhow!(
266                    "Failed to build cache point block: {:?}",
267                    e
268                ))
269            })
270    }
271
272    fn effective_reasoning_budget_tokens(model: &ModelSettings) -> Option<u32> {
273        let requested_budget = model.reasoning_budget.get_max_tokens()?;
274
275        let Some(max_tokens) = model.max_tokens else {
276            return Some(requested_budget);
277        };
278
279        // Bedrock requires max_tokens > thinking.budget_tokens.
280        if max_tokens <= 1 {
281            tracing::warn!(
282                max_tokens,
283                requested_budget,
284                "Skipping reasoning budget because max_tokens is too low"
285            );
286            return None;
287        }
288
289        let capped_budget = max_tokens.saturating_sub(1);
290        if requested_budget > capped_budget {
291            tracing::warn!(
292                requested_budget,
293                max_tokens,
294                capped_budget,
295                "Capping reasoning budget so it remains below max_tokens"
296            );
297            Some(capped_budget)
298        } else {
299            Some(requested_budget)
300        }
301    }
302
303    fn apply_additional_model_fields(
304        &self,
305        model: &ModelSettings,
306        request: ConverseFluentBuilder,
307    ) -> ConverseFluentBuilder {
308        let mut additional_fields = serde_json::Map::new();
309
310        match model.model {
311            Model::ClaudeOpus46 => {
312                if let Some(effort) = model.reasoning_budget.get_effort_level() {
313                    tracing::info!("Enabling adaptive reasoning with effort '{effort}'");
314                    additional_fields.insert("thinking".to_string(), json!({"type": "adaptive"}));
315                    additional_fields
316                        .insert("output_config".to_string(), json!({"effort": effort}));
317                }
318            }
319            Model::ClaudeOpus45 | Model::ClaudeSonnet45 => {
320                if let Some(reasoning_budget) = Self::effective_reasoning_budget_tokens(model) {
321                    tracing::info!("Enabling reasoning with budget {} tokens", reasoning_budget);
322                    additional_fields.insert(
323                        "thinking".to_string(),
324                        json!({
325                            "type": "enabled",
326                            "budget_tokens": reasoning_budget
327                        }),
328                    );
329                }
330            }
331            _ => {}
332        }
333
334        if matches!(model.model, Model::ClaudeSonnet45) {
335            tracing::info!("Enabling 1M context beta for Claude Sonnet 4.5");
336            additional_fields.insert(
337                "anthropic_beta".to_string(),
338                json!(["context-1m-2025-08-07"]),
339            );
340        }
341
342        if additional_fields.is_empty() {
343            return request;
344        }
345
346        let additional_params = serde_json::Value::Object(additional_fields);
347        tracing::debug!("Additional model request fields: {:?}", additional_params);
348        request.additional_model_request_fields(to_doc(additional_params))
349    }
350
351    fn apply_additional_model_fields_stream(
352        &self,
353        model: &ModelSettings,
354        request: ConverseStreamFluentBuilder,
355    ) -> ConverseStreamFluentBuilder {
356        let mut additional_fields = serde_json::Map::new();
357
358        match model.model {
359            Model::ClaudeOpus46 => {
360                if let Some(effort) = model.reasoning_budget.get_effort_level() {
361                    tracing::info!("Enabling adaptive reasoning with effort '{effort}'");
362                    additional_fields.insert("thinking".to_string(), json!({"type": "adaptive"}));
363                    additional_fields
364                        .insert("output_config".to_string(), json!({"effort": effort}));
365                }
366            }
367            Model::ClaudeOpus45 | Model::ClaudeSonnet45 => {
368                if let Some(reasoning_budget) = Self::effective_reasoning_budget_tokens(model) {
369                    tracing::info!("Enabling reasoning with budget {} tokens", reasoning_budget);
370                    additional_fields.insert(
371                        "thinking".to_string(),
372                        json!({
373                            "type": "enabled",
374                            "budget_tokens": reasoning_budget
375                        }),
376                    );
377                }
378            }
379            _ => {}
380        }
381
382        if matches!(model.model, Model::ClaudeSonnet45) {
383            tracing::info!("Enabling 1M context beta for Claude Sonnet 4.5");
384            additional_fields.insert(
385                "anthropic_beta".to_string(),
386                json!(["context-1m-2025-08-07"]),
387            );
388        }
389
390        if additional_fields.is_empty() {
391            return request;
392        }
393
394        let additional_params = serde_json::Value::Object(additional_fields);
395        tracing::debug!("Additional model request fields: {:?}", additional_params);
396        request.additional_model_request_fields(to_doc(additional_params))
397    }
398}
399
400struct BedrockStreamAccumulator {
401    content_blocks: Vec<ContentBlock>,
402    pending_text: String,
403    pending_reasoning: String,
404    pending_tool_id: String,
405    pending_tool_name: String,
406    pending_tool_input: String,
407    in_text_block: bool,
408    in_reasoning_block: bool,
409    in_tool_block: bool,
410    pending_reasoning_signature: Option<String>,
411    usage: TokenUsage,
412    stop_reason: StopReason,
413}
414
415impl BedrockStreamAccumulator {
416    fn new() -> Self {
417        Self {
418            content_blocks: Vec::new(),
419            pending_text: String::new(),
420            pending_reasoning: String::new(),
421            pending_tool_id: String::new(),
422            pending_tool_name: String::new(),
423            pending_tool_input: String::new(),
424            in_text_block: false,
425            in_reasoning_block: false,
426            in_tool_block: false,
427            pending_reasoning_signature: None,
428            usage: TokenUsage::empty(),
429            stop_reason: StopReason::EndTurn,
430        }
431    }
432
433    fn process_event(&mut self, event: BedrockStreamEvent) -> Vec<StreamEvent> {
434        match event {
435            BedrockStreamEvent::ContentBlockStart(start) => self.handle_block_start(start),
436            BedrockStreamEvent::ContentBlockDelta(delta) => self.handle_block_delta(delta),
437            BedrockStreamEvent::ContentBlockStop(_) => self.handle_block_stop(),
438            BedrockStreamEvent::MessageStop(stop) => {
439                self.handle_message_stop(stop);
440                vec![]
441            }
442            BedrockStreamEvent::Metadata(metadata) => {
443                self.handle_metadata(metadata);
444                vec![]
445            }
446            _ => vec![],
447        }
448    }
449
450    fn handle_block_start(
451        &mut self,
452        start: aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
453    ) -> Vec<StreamEvent> {
454        let content_start = match start.start() {
455            Some(s) => s,
456            None => return vec![StreamEvent::ContentBlockStart],
457        };
458
459        if content_start.is_tool_use() {
460            let tool_use = content_start.as_tool_use().unwrap();
461            self.in_tool_block = true;
462            self.pending_tool_id = tool_use.tool_use_id().to_string();
463            self.pending_tool_name = tool_use.name().to_string();
464            self.pending_tool_input.clear();
465        }
466
467        vec![StreamEvent::ContentBlockStart]
468    }
469
470    fn handle_block_delta(
471        &mut self,
472        delta_event: aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
473    ) -> Vec<StreamEvent> {
474        let delta = match delta_event.delta() {
475            Some(d) => d,
476            None => return vec![],
477        };
478
479        if let Ok(text) = delta.as_text() {
480            self.in_text_block = true;
481            self.pending_text.push_str(text);
482            return vec![StreamEvent::TextDelta {
483                text: text.to_string(),
484            }];
485        }
486
487        if let Ok(reasoning) = delta.as_reasoning_content() {
488            if let Ok(text) = reasoning.as_text() {
489                self.pending_reasoning.push_str(text);
490                self.in_reasoning_block = true;
491                return vec![StreamEvent::ReasoningDelta {
492                    text: text.to_string(),
493                }];
494            }
495            if let Ok(sig) = reasoning.as_signature() {
496                self.pending_reasoning_signature = Some(sig.to_string());
497            }
498        }
499
500        if let Ok(tool_delta) = delta.as_tool_use() {
501            self.pending_tool_input.push_str(tool_delta.input());
502        }
503
504        vec![]
505    }
506
507    fn handle_block_stop(&mut self) -> Vec<StreamEvent> {
508        if self.in_tool_block {
509            self.finalize_tool_block();
510        } else if self.in_reasoning_block {
511            self.finalize_reasoning_block();
512        } else if self.in_text_block {
513            self.finalize_text_block();
514        }
515        vec![StreamEvent::ContentBlockStop]
516    }
517
518    fn finalize_tool_block(&mut self) {
519        let arguments = if self.pending_tool_input.trim().is_empty() {
520            tracing::warn!(
521                tool_name = %self.pending_tool_name,
522                tool_id = %self.pending_tool_id,
523                "Streamed tool use block had no input deltas, defaulting to empty object"
524            );
525            serde_json::Value::Object(Default::default())
526        } else {
527            serde_json::from_str(&self.pending_tool_input).unwrap_or_else(|e| {
528                tracing::warn!(
529                    tool_name = %self.pending_tool_name,
530                    input = %self.pending_tool_input,
531                    error = ?e,
532                    "Failed to parse streamed tool input as JSON"
533                );
534                serde_json::Value::Object(Default::default())
535            })
536        };
537        self.content_blocks.push(ContentBlock::ToolUse(ToolUseData {
538            id: std::mem::take(&mut self.pending_tool_id),
539            name: std::mem::take(&mut self.pending_tool_name),
540            arguments,
541        }));
542        self.pending_tool_input.clear();
543        self.in_tool_block = false;
544    }
545
546    fn finalize_reasoning_block(&mut self) {
547        if !self.pending_reasoning.trim().is_empty() {
548            self.content_blocks
549                .push(ContentBlock::ReasoningContent(ReasoningData {
550                    text: std::mem::take(&mut self.pending_reasoning),
551                    signature: self.pending_reasoning_signature.take(),
552                    blob: None,
553                    raw_json: None,
554                }));
555        }
556        self.in_reasoning_block = false;
557    }
558
559    fn finalize_text_block(&mut self) {
560        if !self.pending_text.trim().is_empty() {
561            self.content_blocks.push(ContentBlock::Text(
562                std::mem::take(&mut self.pending_text).trim().to_string(),
563            ));
564        }
565        self.in_text_block = false;
566    }
567
568    fn handle_message_stop(&mut self, stop: aws_sdk_bedrockruntime::types::MessageStopEvent) {
569        self.stop_reason = match stop.stop_reason() {
570            aws_sdk_bedrockruntime::types::StopReason::EndTurn => StopReason::EndTurn,
571            aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
572            aws_sdk_bedrockruntime::types::StopReason::StopSequence => {
573                StopReason::StopSequence("unknown".to_string())
574            }
575            aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
576            _ => StopReason::EndTurn,
577        };
578    }
579
580    fn handle_metadata(
581        &mut self,
582        metadata: aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent,
583    ) {
584        let Some(u) = metadata.usage() else { return };
585        self.usage = TokenUsage {
586            input_tokens: u.input_tokens() as u32,
587            output_tokens: u.output_tokens() as u32,
588            total_tokens: (u.input_tokens() + u.output_tokens()) as u32,
589            cached_prompt_tokens: u.cache_read_input_tokens().map(|v| v as u32),
590            cache_creation_input_tokens: u.cache_write_input_tokens().map(|v| v as u32),
591            reasoning_tokens: None,
592        };
593    }
594
595    fn into_response(self) -> ConversationResponse {
596        ConversationResponse {
597            content: Content::from(self.content_blocks),
598            usage: self.usage,
599            stop_reason: self.stop_reason,
600        }
601    }
602}
603
604#[async_trait::async_trait]
605impl AiProvider for BedrockProvider {
606    fn name(&self) -> &'static str {
607        "AWS Bedrock"
608    }
609
610    fn supported_models(&self) -> HashSet<Model> {
611        HashSet::from([
612            Model::ClaudeOpus46,
613            Model::ClaudeOpus45,
614            Model::ClaudeSonnet45,
615            Model::ClaudeHaiku45,
616            Model::GptOss120b,
617        ])
618    }
619
620    async fn converse(
621        &self,
622        request: ConversationRequest,
623    ) -> Result<ConversationResponse, AiError> {
624        let model_id = self.get_bedrock_model_id(&request.model.model)?;
625        let bedrock_messages =
626            self.convert_to_bedrock_messages(&request.messages, request.model.model)?;
627
628        tracing::debug!(?model_id, "Using Bedrock Converse API");
629
630        let mut converse_request = self
631            .client
632            .converse()
633            .model_id(&model_id)
634            .system(SystemContentBlock::Text(request.system_prompt));
635
636        if request.model.model.supports_prompt_caching() {
637            converse_request =
638                converse_request.system(SystemContentBlock::CachePoint(Self::build_cache_point()?));
639        }
640
641        converse_request = converse_request.set_messages(Some(bedrock_messages));
642
643        let mut inference_config_builder =
644            aws_sdk_bedrockruntime::types::InferenceConfiguration::builder();
645
646        if let Some(max_tokens) = request.model.max_tokens {
647            inference_config_builder = inference_config_builder.max_tokens(max_tokens as i32);
648        }
649
650        if let Some(temperature) = request.model.temperature {
651            inference_config_builder = inference_config_builder.temperature(temperature);
652        }
653
654        if let Some(top_p) = request.model.top_p {
655            inference_config_builder = inference_config_builder.top_p(top_p);
656        }
657
658        if !request.stop_sequences.is_empty() {
659            inference_config_builder =
660                inference_config_builder.set_stop_sequences(Some(request.stop_sequences));
661        }
662
663        converse_request = converse_request.inference_config(inference_config_builder.build());
664        converse_request = self.apply_additional_model_fields(&request.model, converse_request);
665
666        if !request.tools.is_empty() {
667            let bedrock_tools: Vec<Tool> = request
668                .tools
669                .iter()
670                .map(|tool| {
671                    Tool::ToolSpec(
672                        ToolSpecification::builder()
673                            .name(&tool.name)
674                            .description(&tool.description)
675                            .input_schema(ToolInputSchema::Json(to_doc(tool.input_schema.clone())))
676                            .build()
677                            .expect("Failed to build tool spec"),
678                    )
679                })
680                .collect();
681
682            let mut tool_config_builder =
683                ToolConfiguration::builder().set_tools(Some(bedrock_tools));
684
685            if request.model.model.supports_prompt_caching() {
686                tool_config_builder =
687                    tool_config_builder.tools(Tool::CachePoint(Self::build_cache_point()?));
688            }
689
690            let tool_config = tool_config_builder
691                .build()
692                .expect("Failed to build tool config");
693            converse_request = converse_request.tool_config(tool_config);
694        }
695
696        tracing::debug!(?converse_request, "Sending bedrock request");
697        let response = converse_request.send().await.map_err(|e| {
698            tracing::warn!(?e, "Bedrock converse failed");
699
700            let e = e.into_service_error();
701            match e {
702                ConverseError::ThrottlingException(e) => AiError::Retryable(anyhow::anyhow!(e)),
703                ConverseError::ServiceUnavailableException(e) => {
704                    AiError::Retryable(anyhow::anyhow!(e))
705                }
706                ConverseError::InternalServerException(e) => AiError::Retryable(anyhow::anyhow!(e)),
707                ConverseError::ModelTimeoutException(e) => AiError::Retryable(anyhow::anyhow!(e)),
708
709                ConverseError::ResourceNotFoundException(e) => {
710                    AiError::Terminal(anyhow::anyhow!(e))
711                }
712                ConverseError::AccessDeniedException(e) => AiError::Terminal(anyhow::anyhow!(e)),
713                ConverseError::ModelErrorException(e) => AiError::Terminal(anyhow::anyhow!(e)),
714                ConverseError::ModelNotReadyException(e) => AiError::Terminal(anyhow::anyhow!(e)),
715                ConverseError::ValidationException(e) => {
716                    let error_message = format!("{}", e).to_lowercase();
717                    let is_input_too_long = ["too long"]
718                        .iter()
719                        .any(|keyword| error_message.contains(keyword));
720
721                    if is_input_too_long {
722                        AiError::InputTooLong(anyhow::anyhow!(e))
723                    } else {
724                        AiError::Terminal(anyhow::anyhow!(e))
725                    }
726                }
727                _ => AiError::Terminal(anyhow::anyhow!("Unknown error from bedrock: {e:?}")),
728            }
729        })?;
730
731        tracing::debug!("Full response: {:?}", response);
732
733        let usage = if let Some(usage) = response.usage {
734            TokenUsage {
735                input_tokens: usage.input_tokens() as u32,
736                output_tokens: usage.output_tokens() as u32,
737                total_tokens: (usage.input_tokens() + usage.output_tokens()) as u32,
738                cached_prompt_tokens: usage.cache_read_input_tokens().map(|v| v as u32),
739                cache_creation_input_tokens: usage.cache_write_input_tokens().map(|v| v as u32),
740                reasoning_tokens: None,
741            }
742        } else {
743            TokenUsage::empty()
744        };
745
746        let stop_reason = match response.stop_reason {
747            aws_sdk_bedrockruntime::types::StopReason::EndTurn => StopReason::EndTurn,
748            aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
749            aws_sdk_bedrockruntime::types::StopReason::StopSequence => {
750                StopReason::StopSequence("unknown".to_string())
751            }
752            aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
753            _ => StopReason::EndTurn,
754        };
755
756        let message = response
757            .output
758            .ok_or_else(|| AiError::Terminal(anyhow::anyhow!("No output in response")))?
759            .as_message()
760            .map_err(|_| AiError::Terminal(anyhow::anyhow!("Output is not a message")))?
761            .clone();
762
763        tracing::debug!("Message content blocks: {:?}", message.content());
764
765        let content = self.extract_content_blocks(message.clone());
766
767        Ok(ConversationResponse {
768            content,
769            usage,
770            stop_reason,
771        })
772    }
773
774    async fn converse_stream(
775        &self,
776        request: ConversationRequest,
777    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AiError>> + Send>>, AiError> {
778        let model_id = self.get_bedrock_model_id(&request.model.model)?;
779        let bedrock_messages =
780            self.convert_to_bedrock_messages(&request.messages, request.model.model)?;
781
782        tracing::debug!(?model_id, "Using Bedrock Converse Stream API");
783
784        let mut stream_request = self
785            .client
786            .converse_stream()
787            .model_id(&model_id)
788            .system(SystemContentBlock::Text(request.system_prompt));
789
790        if request.model.model.supports_prompt_caching() {
791            stream_request =
792                stream_request.system(SystemContentBlock::CachePoint(Self::build_cache_point()?));
793        }
794
795        stream_request = stream_request.set_messages(Some(bedrock_messages));
796
797        let mut inference_config_builder =
798            aws_sdk_bedrockruntime::types::InferenceConfiguration::builder();
799
800        if let Some(max_tokens) = request.model.max_tokens {
801            inference_config_builder = inference_config_builder.max_tokens(max_tokens as i32);
802        }
803
804        if let Some(temperature) = request.model.temperature {
805            inference_config_builder = inference_config_builder.temperature(temperature);
806        }
807
808        if let Some(top_p) = request.model.top_p {
809            inference_config_builder = inference_config_builder.top_p(top_p);
810        }
811
812        if !request.stop_sequences.is_empty() {
813            inference_config_builder =
814                inference_config_builder.set_stop_sequences(Some(request.stop_sequences));
815        }
816
817        stream_request = stream_request.inference_config(inference_config_builder.build());
818        stream_request = self.apply_additional_model_fields_stream(&request.model, stream_request);
819
820        if !request.tools.is_empty() {
821            let bedrock_tools: Vec<Tool> = request
822                .tools
823                .iter()
824                .map(|tool| {
825                    Tool::ToolSpec(
826                        ToolSpecification::builder()
827                            .name(&tool.name)
828                            .description(&tool.description)
829                            .input_schema(ToolInputSchema::Json(to_doc(tool.input_schema.clone())))
830                            .build()
831                            .expect("Failed to build tool spec"),
832                    )
833                })
834                .collect();
835
836            let mut tool_config_builder =
837                ToolConfiguration::builder().set_tools(Some(bedrock_tools));
838
839            if request.model.model.supports_prompt_caching() {
840                tool_config_builder =
841                    tool_config_builder.tools(Tool::CachePoint(Self::build_cache_point()?));
842            }
843
844            let tool_config = tool_config_builder
845                .build()
846                .expect("Failed to build tool config");
847            stream_request = stream_request.tool_config(tool_config);
848        }
849
850        let response = stream_request.send().await.map_err(|e| {
851            tracing::warn!(?e, "Bedrock converse_stream failed");
852            let e = e.into_service_error();
853            match e {
854                ConverseStreamError::ThrottlingException(e) => {
855                    AiError::Retryable(anyhow::anyhow!(e))
856                }
857                ConverseStreamError::ServiceUnavailableException(e) => {
858                    AiError::Retryable(anyhow::anyhow!(e))
859                }
860                ConverseStreamError::InternalServerException(e) => {
861                    AiError::Retryable(anyhow::anyhow!(e))
862                }
863                ConverseStreamError::ModelTimeoutException(e) => {
864                    AiError::Retryable(anyhow::anyhow!(e))
865                }
866                ConverseStreamError::ResourceNotFoundException(e) => {
867                    AiError::Terminal(anyhow::anyhow!(e))
868                }
869                ConverseStreamError::AccessDeniedException(e) => {
870                    AiError::Terminal(anyhow::anyhow!(e))
871                }
872                ConverseStreamError::ModelErrorException(e) => {
873                    AiError::Terminal(anyhow::anyhow!(e))
874                }
875                ConverseStreamError::ModelNotReadyException(e) => {
876                    AiError::Terminal(anyhow::anyhow!(e))
877                }
878                ConverseStreamError::ValidationException(e) => {
879                    let error_message = format!("{}", e).to_lowercase();
880                    if error_message.contains("too long") {
881                        AiError::InputTooLong(anyhow::anyhow!(e))
882                    } else {
883                        AiError::Terminal(anyhow::anyhow!(e))
884                    }
885                }
886                _ => AiError::Terminal(anyhow::anyhow!("Unknown error from bedrock stream: {e:?}")),
887            }
888        })?;
889
890        let mut event_stream = response.stream;
891
892        let stream = async_stream::stream! {
893            let mut state = BedrockStreamAccumulator::new();
894
895            loop {
896                let recv_result = event_stream.recv().await;
897                let Ok(maybe_event) = recv_result else {
898                    tracing::warn!("Error in bedrock stream");
899                    yield Err(AiError::Retryable(anyhow::anyhow!("Bedrock stream error")));
900                    return;
901                };
902                let Some(event) = maybe_event else { break };
903                for stream_event in state.process_event(event) {
904                    yield Ok(stream_event);
905                }
906            }
907
908            yield Ok(StreamEvent::MessageComplete { response: state.into_response() });
909        };
910
911        Ok(Box::pin(stream))
912    }
913
914    fn get_cost(&self, model: &Model) -> Cost {
915        match model {
916            Model::ClaudeSonnet45 => Cost::new(3.0, 15.0, 3.75, 0.3),
917            Model::ClaudeHaiku45 => Cost::new(1.0, 5.0, 1.25, 0.1),
918            Model::ClaudeOpus46 => Cost::new(5.0, 25.0, 6.25, 0.5),
919            Model::ClaudeOpus45 => Cost::new(5.0, 25.0, 6.25, 0.5),
920            Model::GptOss120b => Cost::new(0.15, 0.6, 0.0, 0.0),
921            _ => Cost::new(0.0, 0.0, 0.0, 0.0),
922        }
923    }
924}
925
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::tests::{
        test_hello_world, test_multiple_tool_calls, test_reasoning_conversation,
        test_reasoning_with_tools, test_tool_usage,
    };

    /// Builds a `BedrockProvider` against us-west-2 using the default AWS
    /// credential chain. All tests below are `#[ignore]`d because they need
    /// real credentials.
    async fn create_bedrock_provider() -> anyhow::Result<BedrockProvider> {
        let bedrock_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
            .region(aws_config::Region::new("us-west-2"))
            .load()
            .await;
        let bedrock_client = aws_sdk_bedrockruntime::Client::new(&bedrock_config);
        Ok(BedrockProvider::new(bedrock_client))
    }

    /// Creates the provider or fails the current test with a descriptive
    /// panic. Extracted to remove the identical match duplicated in every
    /// test below.
    async fn provider_or_panic() -> BedrockProvider {
        match create_bedrock_provider().await {
            Ok(provider) => provider,
            Err(e) => {
                tracing::error!(?e, "Failed to create Bedrock provider");
                panic!("Failed to create Bedrock provider: {e:?}");
            }
        }
    }

    #[tokio::test]
    #[ignore = "requires AWS credentials"]
    async fn test_bedrock_hello_world() {
        let provider = provider_or_panic().await;
        if let Err(e) = test_hello_world(provider).await {
            tracing::error!(?e, "Bedrock hello world test failed");
            panic!("Bedrock hello world test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires AWS credentials"]
    async fn test_bedrock_reasoning_conversation() {
        let provider = provider_or_panic().await;
        if let Err(e) = test_reasoning_conversation(provider).await {
            tracing::error!(?e, "Bedrock reasoning conversation test failed");
            panic!("Bedrock reasoning conversation test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires AWS credentials"]
    async fn test_bedrock_tool_usage() {
        let provider = provider_or_panic().await;
        if let Err(e) = test_tool_usage(provider).await {
            tracing::error!(?e, "Bedrock tool usage test failed");
            panic!("Bedrock tool usage test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires AWS credentials"]
    async fn test_bedrock_reasoning_with_tools() {
        let provider = provider_or_panic().await;
        if let Err(e) = test_reasoning_with_tools(provider).await {
            tracing::error!(?e, "Bedrock reasoning with tools test failed");
            panic!("Bedrock reasoning with tools test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires AWS credentials"]
    async fn test_bedrock_multiple_tool_calls() {
        let provider = provider_or_panic().await;
        if let Err(e) = test_multiple_tool_calls(provider).await {
            // Fixed copy-pasted message that previously named the
            // reasoning-with-tools test.
            tracing::error!(?e, "Bedrock multiple tool calls test failed");
            panic!("Bedrock multiple tool calls test failed: {e:?}");
        }
    }
}