1use std::collections::HashSet;
2use std::pin::Pin;
3
4use tokio_stream::Stream;
5
6use base64::Engine;
7
8use aws_sdk_bedrockruntime::{
9 operation::converse::{builders::ConverseFluentBuilder, ConverseError},
10 operation::converse_stream::{builders::ConverseStreamFluentBuilder, ConverseStreamError},
11 types::ConverseStreamOutput as BedrockStreamEvent,
12 types::{
13 CachePointBlock, ContentBlock as BedrockContentBlock, ImageBlock, ImageFormat, ImageSource,
14 Message as BedrockMessage, ReasoningContentBlock, ReasoningTextBlock, SystemContentBlock,
15 Tool, ToolConfiguration, ToolInputSchema, ToolResultBlock, ToolResultContentBlock,
16 ToolSpecification, ToolUseBlock,
17 },
18 Client as BedrockClient,
19};
20use aws_smithy_types::Blob;
21use serde_json::json;
22
23use crate::ai::{error::AiError, provider::AiProvider, types::*};
24use crate::ai::{
25 json::{from_doc, to_doc},
26 model::Model,
27};
28
/// `AiProvider` implementation backed by the AWS Bedrock Converse API.
#[derive(Clone)]
pub struct BedrockProvider {
    // Bedrock runtime client used for all Converse/ConverseStream calls.
    client: BedrockClient,
}
33
impl BedrockProvider {
    /// Creates a provider wrapping an already-configured Bedrock runtime client.
    pub fn new(client: BedrockClient) -> Self {
        Self { client }
    }

    /// Resolves the Bedrock model identifier for an internal `Model`.
    ///
    /// Returns `AiError::Terminal` for models this provider does not serve.
    fn get_bedrock_model_id(&self, model: &Model) -> Result<String, AiError> {
        let model_id = match model {
            Model::ClaudeSonnet45 => "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
            Model::ClaudeHaiku45 => "us.anthropic.claude-haiku-4-5-20251001-v1:0",
            // NOTE(review): this id has no ":0" revision suffix unlike the
            // other Anthropic ids — confirm against the Bedrock model catalog.
            Model::ClaudeOpus46 => "global.anthropic.claude-opus-4-6-v1",
            Model::ClaudeOpus45 => "global.anthropic.claude-opus-4-5-20251101-v1:0",
            Model::GptOss120b => "openai.gpt-oss-120b-1:0",
            _ => {
                return Err(AiError::Terminal(anyhow::anyhow!(
                    "Model {} is not supported in bedrock",
                    model.name()
                )))
            }
        };
        Ok(model_id.to_string())
    }

    /// Converts internal conversation history into Bedrock `Message`s.
    ///
    /// Per message this:
    /// - drops whitespace-only text blocks and trims the rest;
    /// - rebuilds reasoning, tool-use, tool-result, and image blocks in their
    ///   Bedrock representations;
    /// - substitutes a "..." placeholder when a message would otherwise have
    ///   no content blocks;
    /// - moves reasoning blocks to the front of the message;
    /// - when the model supports prompt caching, appends a cache point to the
    ///   second-to-last message (unless it ends with reasoning content).
    fn convert_to_bedrock_messages(
        &self,
        messages: &[Message],
        model: Model,
    ) -> Result<Vec<BedrockMessage>, AiError> {
        let mut bedrock_messages = Vec::new();

        for (msg_index, msg) in messages.iter().enumerate() {
            let role = match msg.role {
                MessageRole::User => aws_sdk_bedrockruntime::types::ConversationRole::User,
                MessageRole::Assistant => {
                    aws_sdk_bedrockruntime::types::ConversationRole::Assistant
                }
            };

            let mut content_blocks = Vec::new();
            for block in msg.content.blocks() {
                match block {
                    ContentBlock::Text(text) => {
                        // Skip empty text; only trimmed, non-empty text is sent.
                        if !text.trim().is_empty() {
                            content_blocks.push(BedrockContentBlock::Text(text.trim().to_string()));
                        }
                    }
                    ContentBlock::ReasoningContent(reasoning) => {
                        // A stored blob means the reasoning was redacted by the
                        // model; replay it verbatim as redacted content.
                        let reasoning_content = if let Some(blob) = &reasoning.blob {
                            ReasoningContentBlock::RedactedContent(Blob::new(blob.clone()))
                        } else {
                            let mut text_block_builder =
                                ReasoningTextBlock::builder().text(&reasoning.text);

                            if let Some(signature) = &reasoning.signature {
                                text_block_builder = text_block_builder.signature(signature);
                            }

                            let text_block = text_block_builder.build().map_err(|e| {
                                AiError::Terminal(anyhow::anyhow!(
                                    "Failed to build reasoning text block: {:?}",
                                    e
                                ))
                            })?;

                            ReasoningContentBlock::ReasoningText(text_block)
                        };

                        content_blocks
                            .push(BedrockContentBlock::ReasoningContent(reasoning_content));
                    }
                    ContentBlock::ToolUse(tool_use) => {
                        // Bedrock requires a document input; substitute {} for null.
                        let args = if tool_use.arguments.is_null() {
                            tracing::warn!(
                                tool_name = %tool_use.name,
                                "Null tool arguments in conversation history, substituting empty object"
                            );
                            serde_json::Value::Object(Default::default())
                        } else {
                            tool_use.arguments.clone()
                        };
                        let tool_use_block = ToolUseBlock::builder()
                            .tool_use_id(&tool_use.id)
                            .name(&tool_use.name)
                            .input(to_doc(args))
                            .build()
                            .map_err(|e| {
                                AiError::Terminal(anyhow::anyhow!(
                                    "Failed to build tool use block: {:?}",
                                    e
                                ))
                            })?;
                        content_blocks.push(BedrockContentBlock::ToolUse(tool_use_block));
                    }
                    ContentBlock::ToolResult(tool_result) => {
                        let tool_result_block = ToolResultBlock::builder()
                            .tool_use_id(&tool_result.tool_use_id)
                            .content(ToolResultContentBlock::Text(tool_result.content.clone()))
                            .build()
                            .map_err(|e| {
                                AiError::Terminal(anyhow::anyhow!(
                                    "Failed to build tool result block: {:?}",
                                    e
                                ))
                            })?;
                        content_blocks.push(BedrockContentBlock::ToolResult(tool_result_block));
                    }
                    ContentBlock::Image(image) => {
                        content_blocks.push(BedrockContentBlock::Image(build_bedrock_image_block(
                            image,
                        )?));
                    }
                }
            }

            // Messages must have at least one content block; send a placeholder.
            if content_blocks.is_empty() {
                content_blocks.push(BedrockContentBlock::Text("...".to_string()));
            }

            // Reorder so reasoning blocks come first (the Anthropic models on
            // Bedrock appear to require reasoning before other content — TODO
            // confirm against the Bedrock docs). Relative order within each
            // group is preserved by partition.
            let (reasoning, non_reasoning): (Vec<_>, Vec<_>) = content_blocks
                .into_iter()
                .partition(|b| matches!(b, BedrockContentBlock::ReasoningContent(_)));
            content_blocks = reasoning;
            content_blocks.extend(non_reasoning);

            // Place a cache point on the second-to-last message so the stable
            // conversation prefix gets cached. Skipped when the message ends
            // with reasoning content (presumably the API rejects a cache point
            // there — confirm).
            let last_is_reasoning = content_blocks
                .last()
                .is_some_and(|b| matches!(b, BedrockContentBlock::ReasoningContent(_)));
            if model.supports_prompt_caching()
                && messages.len() >= 2
                && msg_index == messages.len() - 2
                && !last_is_reasoning
            {
                content_blocks.push(BedrockContentBlock::CachePoint(Self::build_cache_point()?));
            }

            bedrock_messages.push(
                BedrockMessage::builder()
                    .role(role)
                    .set_content(Some(content_blocks))
                    .build()
                    .map_err(|e| {
                        AiError::Terminal(anyhow::anyhow!("Failed to build message: {:?}", e))
                    })?,
            );
        }

        Ok(bedrock_messages)
    }
}
184
185fn map_image_format(media_type: &str) -> Result<ImageFormat, AiError> {
186 match media_type {
187 "image/png" => Ok(ImageFormat::Png),
188 "image/jpeg" => Ok(ImageFormat::Jpeg),
189 "image/gif" => Ok(ImageFormat::Gif),
190 "image/webp" => Ok(ImageFormat::Webp),
191 other => Err(AiError::Terminal(anyhow::anyhow!(
192 "Unsupported image format: {other}"
193 ))),
194 }
195}
196
197fn build_bedrock_image_block(image: &ImageData) -> Result<ImageBlock, AiError> {
198 let bytes = base64::engine::general_purpose::STANDARD
199 .decode(&image.data)
200 .map_err(|e| AiError::Terminal(anyhow::anyhow!("Failed to decode image base64: {e:?}")))?;
201
202 let format = map_image_format(&image.media_type)?;
203
204 ImageBlock::builder()
205 .format(format)
206 .source(ImageSource::Bytes(Blob::new(bytes)))
207 .build()
208 .map_err(|e| AiError::Terminal(anyhow::anyhow!("Failed to build image block: {e:?}")))
209}
210
211impl BedrockProvider {
212 fn extract_content_blocks(&self, message: BedrockMessage) -> Content {
213 let mut content_blocks = Vec::new();
214
215 tracing::debug!("Processing {} content blocks", message.content().len());
216
217 for (i, content) in message.content().iter().enumerate() {
218 tracing::debug!("Content block {}: {:?}", i, content);
219
220 match content {
221 BedrockContentBlock::Text(text) => {
222 tracing::debug!("Text block: {}", text);
223 content_blocks.push(ContentBlock::Text(text.clone()));
224 }
225 BedrockContentBlock::ReasoningContent(block) => {
226 let reasoning_data = if block.is_reasoning_text() {
227 let block = block.as_reasoning_text().unwrap();
228 ReasoningData {
229 text: block.text.clone(),
230 signature: block.signature.clone(),
231 blob: None,
232 raw_json: None,
233 }
234 } else {
235 let block = block.as_redacted_content().unwrap();
236 ReasoningData {
237 text: "** Redacted reasoning content **".to_string(),
238 signature: None,
239 blob: Some(block.clone().into_inner()),
240 raw_json: None,
241 }
242 };
243 content_blocks.push(ContentBlock::ReasoningContent(reasoning_data));
244 }
245 BedrockContentBlock::ToolUse(tool_use) => {
246 let tool_use_data = ToolUseData {
247 id: tool_use.tool_use_id().to_string(),
248 name: tool_use.name().to_string(),
249 arguments: from_doc(tool_use.input().clone()),
250 };
251 content_blocks.push(ContentBlock::ToolUse(tool_use_data));
252 }
253 _ => (),
254 }
255 }
256
257 Content::from(content_blocks)
258 }
259
260 fn build_cache_point() -> Result<CachePointBlock, AiError> {
261 CachePointBlock::builder()
262 .r#type(aws_sdk_bedrockruntime::types::CachePointType::Default)
263 .build()
264 .map_err(|e| {
265 AiError::Terminal(anyhow::anyhow!(
266 "Failed to build cache point block: {:?}",
267 e
268 ))
269 })
270 }
271
272 fn effective_reasoning_budget_tokens(model: &ModelSettings) -> Option<u32> {
273 let requested_budget = model.reasoning_budget.get_max_tokens()?;
274
275 let Some(max_tokens) = model.max_tokens else {
276 return Some(requested_budget);
277 };
278
279 if max_tokens <= 1 {
281 tracing::warn!(
282 max_tokens,
283 requested_budget,
284 "Skipping reasoning budget because max_tokens is too low"
285 );
286 return None;
287 }
288
289 let capped_budget = max_tokens.saturating_sub(1);
290 if requested_budget > capped_budget {
291 tracing::warn!(
292 requested_budget,
293 max_tokens,
294 capped_budget,
295 "Capping reasoning budget so it remains below max_tokens"
296 );
297 Some(capped_budget)
298 } else {
299 Some(requested_budget)
300 }
301 }
302
303 fn apply_additional_model_fields(
304 &self,
305 model: &ModelSettings,
306 request: ConverseFluentBuilder,
307 ) -> ConverseFluentBuilder {
308 let mut additional_fields = serde_json::Map::new();
309
310 match model.model {
311 Model::ClaudeOpus46 => {
312 if let Some(effort) = model.reasoning_budget.get_effort_level() {
313 tracing::info!("Enabling adaptive reasoning with effort '{effort}'");
314 additional_fields.insert("thinking".to_string(), json!({"type": "adaptive"}));
315 additional_fields
316 .insert("output_config".to_string(), json!({"effort": effort}));
317 }
318 }
319 Model::ClaudeOpus45 | Model::ClaudeSonnet45 => {
320 if let Some(reasoning_budget) = Self::effective_reasoning_budget_tokens(model) {
321 tracing::info!("Enabling reasoning with budget {} tokens", reasoning_budget);
322 additional_fields.insert(
323 "thinking".to_string(),
324 json!({
325 "type": "enabled",
326 "budget_tokens": reasoning_budget
327 }),
328 );
329 }
330 }
331 _ => {}
332 }
333
334 if matches!(model.model, Model::ClaudeSonnet45) {
335 tracing::info!("Enabling 1M context beta for Claude Sonnet 4.5");
336 additional_fields.insert(
337 "anthropic_beta".to_string(),
338 json!(["context-1m-2025-08-07"]),
339 );
340 }
341
342 if additional_fields.is_empty() {
343 return request;
344 }
345
346 let additional_params = serde_json::Value::Object(additional_fields);
347 tracing::debug!("Additional model request fields: {:?}", additional_params);
348 request.additional_model_request_fields(to_doc(additional_params))
349 }
350
351 fn apply_additional_model_fields_stream(
352 &self,
353 model: &ModelSettings,
354 request: ConverseStreamFluentBuilder,
355 ) -> ConverseStreamFluentBuilder {
356 let mut additional_fields = serde_json::Map::new();
357
358 match model.model {
359 Model::ClaudeOpus46 => {
360 if let Some(effort) = model.reasoning_budget.get_effort_level() {
361 tracing::info!("Enabling adaptive reasoning with effort '{effort}'");
362 additional_fields.insert("thinking".to_string(), json!({"type": "adaptive"}));
363 additional_fields
364 .insert("output_config".to_string(), json!({"effort": effort}));
365 }
366 }
367 Model::ClaudeOpus45 | Model::ClaudeSonnet45 => {
368 if let Some(reasoning_budget) = Self::effective_reasoning_budget_tokens(model) {
369 tracing::info!("Enabling reasoning with budget {} tokens", reasoning_budget);
370 additional_fields.insert(
371 "thinking".to_string(),
372 json!({
373 "type": "enabled",
374 "budget_tokens": reasoning_budget
375 }),
376 );
377 }
378 }
379 _ => {}
380 }
381
382 if matches!(model.model, Model::ClaudeSonnet45) {
383 tracing::info!("Enabling 1M context beta for Claude Sonnet 4.5");
384 additional_fields.insert(
385 "anthropic_beta".to_string(),
386 json!(["context-1m-2025-08-07"]),
387 );
388 }
389
390 if additional_fields.is_empty() {
391 return request;
392 }
393
394 let additional_params = serde_json::Value::Object(additional_fields);
395 tracing::debug!("Additional model request fields: {:?}", additional_params);
396 request.additional_model_request_fields(to_doc(additional_params))
397 }
398}
399
/// Folds Bedrock ConverseStream events into a final `ConversationResponse`
/// while emitting incremental `StreamEvent`s along the way.
struct BedrockStreamAccumulator {
    // Finalized content blocks, in arrival order.
    content_blocks: Vec<ContentBlock>,
    // Text accumulated for the currently open text block.
    pending_text: String,
    // Reasoning text accumulated for the currently open reasoning block.
    pending_reasoning: String,
    // Tool-use id/name captured at block start; input accumulates as raw JSON
    // fragments until the block stops.
    pending_tool_id: String,
    pending_tool_name: String,
    pending_tool_input: String,
    // Which kind of block is currently open; checked in priority order
    // (tool > reasoning > text) when a block stops.
    in_text_block: bool,
    in_reasoning_block: bool,
    in_tool_block: bool,
    // Signature delta attached to the open reasoning block, if any.
    pending_reasoning_signature: Option<String>,
    // Token usage from the trailing Metadata event (zero until it arrives).
    usage: TokenUsage,
    // Stop reason from MessageStop; defaults to EndTurn.
    stop_reason: StopReason,
}
414
impl BedrockStreamAccumulator {
    /// Creates an empty accumulator; usage is zeroed and the stop reason
    /// defaults to `EndTurn` until the stream reports otherwise.
    fn new() -> Self {
        Self {
            content_blocks: Vec::new(),
            pending_text: String::new(),
            pending_reasoning: String::new(),
            pending_tool_id: String::new(),
            pending_tool_name: String::new(),
            pending_tool_input: String::new(),
            in_text_block: false,
            in_reasoning_block: false,
            in_tool_block: false,
            pending_reasoning_signature: None,
            usage: TokenUsage::empty(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Translates one raw Bedrock stream event into zero or more internal
    /// `StreamEvent`s, updating accumulated state along the way.
    fn process_event(&mut self, event: BedrockStreamEvent) -> Vec<StreamEvent> {
        match event {
            BedrockStreamEvent::ContentBlockStart(start) => self.handle_block_start(start),
            BedrockStreamEvent::ContentBlockDelta(delta) => self.handle_block_delta(delta),
            BedrockStreamEvent::ContentBlockStop(_) => self.handle_block_stop(),
            BedrockStreamEvent::MessageStop(stop) => {
                self.handle_message_stop(stop);
                vec![]
            }
            BedrockStreamEvent::Metadata(metadata) => {
                self.handle_metadata(metadata);
                vec![]
            }
            // MessageStart and any future event kinds carry nothing we track.
            _ => vec![],
        }
    }

    /// Records the start of a content block. Only tool-use starts carry a
    /// payload (id + name); text and reasoning blocks are detected lazily
    /// from their first delta instead.
    fn handle_block_start(
        &mut self,
        start: aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
    ) -> Vec<StreamEvent> {
        let content_start = match start.start() {
            Some(s) => s,
            None => return vec![StreamEvent::ContentBlockStart],
        };

        if content_start.is_tool_use() {
            // Safe: guarded by is_tool_use() above.
            let tool_use = content_start.as_tool_use().unwrap();
            self.in_tool_block = true;
            self.pending_tool_id = tool_use.tool_use_id().to_string();
            self.pending_tool_name = tool_use.name().to_string();
            self.pending_tool_input.clear();
        }

        vec![StreamEvent::ContentBlockStart]
    }

    /// Accumulates a delta into the pending text / reasoning / tool buffers.
    /// Text and reasoning deltas are forwarded to the caller immediately;
    /// tool-input and signature deltas are buffered silently.
    fn handle_block_delta(
        &mut self,
        delta_event: aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
    ) -> Vec<StreamEvent> {
        let delta = match delta_event.delta() {
            Some(d) => d,
            None => return vec![],
        };

        if let Ok(text) = delta.as_text() {
            self.in_text_block = true;
            self.pending_text.push_str(text);
            return vec![StreamEvent::TextDelta {
                text: text.to_string(),
            }];
        }

        if let Ok(reasoning) = delta.as_reasoning_content() {
            if let Ok(text) = reasoning.as_text() {
                self.pending_reasoning.push_str(text);
                self.in_reasoning_block = true;
                return vec![StreamEvent::ReasoningDelta {
                    text: text.to_string(),
                }];
            }
            // Signature arrives as its own delta within the reasoning block.
            if let Ok(sig) = reasoning.as_signature() {
                self.pending_reasoning_signature = Some(sig.to_string());
            }
        }

        if let Ok(tool_delta) = delta.as_tool_use() {
            self.pending_tool_input.push_str(tool_delta.input());
        }

        vec![]
    }

    /// Finalizes whichever block kind is currently open. Priority order
    /// (tool > reasoning > text) matters only if multiple flags were somehow
    /// set; normally exactly one is.
    fn handle_block_stop(&mut self) -> Vec<StreamEvent> {
        if self.in_tool_block {
            self.finalize_tool_block();
        } else if self.in_reasoning_block {
            self.finalize_reasoning_block();
        } else if self.in_text_block {
            self.finalize_text_block();
        }
        vec![StreamEvent::ContentBlockStop]
    }

    /// Parses the buffered tool-input JSON and pushes a ToolUse block.
    /// Missing or malformed input degrades to `{}` with a warning rather
    /// than failing the stream.
    fn finalize_tool_block(&mut self) {
        let arguments = if self.pending_tool_input.trim().is_empty() {
            tracing::warn!(
                tool_name = %self.pending_tool_name,
                tool_id = %self.pending_tool_id,
                "Streamed tool use block had no input deltas, defaulting to empty object"
            );
            serde_json::Value::Object(Default::default())
        } else {
            serde_json::from_str(&self.pending_tool_input).unwrap_or_else(|e| {
                tracing::warn!(
                    tool_name = %self.pending_tool_name,
                    input = %self.pending_tool_input,
                    error = ?e,
                    "Failed to parse streamed tool input as JSON"
                );
                serde_json::Value::Object(Default::default())
            })
        };
        self.content_blocks.push(ContentBlock::ToolUse(ToolUseData {
            id: std::mem::take(&mut self.pending_tool_id),
            name: std::mem::take(&mut self.pending_tool_name),
            arguments,
        }));
        self.pending_tool_input.clear();
        self.in_tool_block = false;
    }

    /// Pushes the buffered reasoning text (plus any signature) as a block;
    /// whitespace-only reasoning is discarded.
    fn finalize_reasoning_block(&mut self) {
        if !self.pending_reasoning.trim().is_empty() {
            self.content_blocks
                .push(ContentBlock::ReasoningContent(ReasoningData {
                    text: std::mem::take(&mut self.pending_reasoning),
                    signature: self.pending_reasoning_signature.take(),
                    blob: None,
                    raw_json: None,
                }));
        }
        self.in_reasoning_block = false;
    }

    /// Pushes the buffered text (trimmed) as a block; whitespace-only text
    /// is discarded.
    fn finalize_text_block(&mut self) {
        if !self.pending_text.trim().is_empty() {
            self.content_blocks.push(ContentBlock::Text(
                std::mem::take(&mut self.pending_text).trim().to_string(),
            ));
        }
        self.in_text_block = false;
    }

    /// Maps the Bedrock stop reason to the internal enum. Bedrock does not
    /// report which stop sequence fired, hence the "unknown" placeholder.
    fn handle_message_stop(&mut self, stop: aws_sdk_bedrockruntime::types::MessageStopEvent) {
        self.stop_reason = match stop.stop_reason() {
            aws_sdk_bedrockruntime::types::StopReason::EndTurn => StopReason::EndTurn,
            aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
            aws_sdk_bedrockruntime::types::StopReason::StopSequence => {
                StopReason::StopSequence("unknown".to_string())
            }
            aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
            _ => StopReason::EndTurn,
        };
    }

    /// Captures token usage from the trailing Metadata event, if present.
    fn handle_metadata(
        &mut self,
        metadata: aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent,
    ) {
        let Some(u) = metadata.usage() else { return };
        // `as u32` casts: token counts from the API are non-negative i32s.
        self.usage = TokenUsage {
            input_tokens: u.input_tokens() as u32,
            output_tokens: u.output_tokens() as u32,
            total_tokens: (u.input_tokens() + u.output_tokens()) as u32,
            cached_prompt_tokens: u.cache_read_input_tokens().map(|v| v as u32),
            cache_creation_input_tokens: u.cache_write_input_tokens().map(|v| v as u32),
            reasoning_tokens: None,
        };
    }

    /// Consumes the accumulator into the final response for the
    /// `MessageComplete` event.
    fn into_response(self) -> ConversationResponse {
        ConversationResponse {
            content: Content::from(self.content_blocks),
            usage: self.usage,
            stop_reason: self.stop_reason,
        }
    }
}
603
#[async_trait::async_trait]
impl AiProvider for BedrockProvider {
    /// Human-readable provider name used in logs and diagnostics.
    fn name(&self) -> &'static str {
        "AWS Bedrock"
    }

    /// Models this provider serves; kept in sync with `get_bedrock_model_id`
    /// and `get_cost`.
    fn supported_models(&self) -> HashSet<Model> {
        HashSet::from([
            Model::ClaudeOpus46,
            Model::ClaudeOpus45,
            Model::ClaudeSonnet45,
            Model::ClaudeHaiku45,
            Model::GptOss120b,
        ])
    }

    /// Sends a non-streaming Converse request and maps the response, token
    /// usage, and stop reason into internal types.
    ///
    /// Error mapping: throttling / unavailability / internal / model-timeout
    /// errors become `AiError::Retryable`; validation errors whose message
    /// contains "too long" become `AiError::InputTooLong`; everything else is
    /// `AiError::Terminal`.
    async fn converse(
        &self,
        request: ConversationRequest,
    ) -> Result<ConversationResponse, AiError> {
        let model_id = self.get_bedrock_model_id(&request.model.model)?;
        let bedrock_messages =
            self.convert_to_bedrock_messages(&request.messages, request.model.model)?;

        tracing::debug!(?model_id, "Using Bedrock Converse API");

        let mut converse_request = self
            .client
            .converse()
            .model_id(&model_id)
            .system(SystemContentBlock::Text(request.system_prompt));

        // Cache point after the system prompt caches the stable prefix.
        if request.model.model.supports_prompt_caching() {
            converse_request =
                converse_request.system(SystemContentBlock::CachePoint(Self::build_cache_point()?));
        }

        converse_request = converse_request.set_messages(Some(bedrock_messages));

        let mut inference_config_builder =
            aws_sdk_bedrockruntime::types::InferenceConfiguration::builder();

        if let Some(max_tokens) = request.model.max_tokens {
            inference_config_builder = inference_config_builder.max_tokens(max_tokens as i32);
        }

        if let Some(temperature) = request.model.temperature {
            inference_config_builder = inference_config_builder.temperature(temperature);
        }

        if let Some(top_p) = request.model.top_p {
            inference_config_builder = inference_config_builder.top_p(top_p);
        }

        if !request.stop_sequences.is_empty() {
            inference_config_builder =
                inference_config_builder.set_stop_sequences(Some(request.stop_sequences));
        }

        converse_request = converse_request.inference_config(inference_config_builder.build());
        converse_request = self.apply_additional_model_fields(&request.model, converse_request);

        if !request.tools.is_empty() {
            let bedrock_tools: Vec<Tool> = request
                .tools
                .iter()
                .map(|tool| {
                    Tool::ToolSpec(
                        ToolSpecification::builder()
                            .name(&tool.name)
                            .description(&tool.description)
                            .input_schema(ToolInputSchema::Json(to_doc(tool.input_schema.clone())))
                            .build()
                            // NOTE(review): panics on builder failure;
                            // consider mapping to AiError::Terminal instead.
                            .expect("Failed to build tool spec"),
                    )
                })
                .collect();

            let mut tool_config_builder =
                ToolConfiguration::builder().set_tools(Some(bedrock_tools));

            // Extend the cached prefix past the tool definitions as well.
            if request.model.model.supports_prompt_caching() {
                tool_config_builder =
                    tool_config_builder.tools(Tool::CachePoint(Self::build_cache_point()?));
            }

            let tool_config = tool_config_builder
                .build()
                // NOTE(review): same panic concern as the tool spec above.
                .expect("Failed to build tool config");
            converse_request = converse_request.tool_config(tool_config);
        }

        tracing::debug!(?converse_request, "Sending bedrock request");
        let response = converse_request.send().await.map_err(|e| {
            tracing::warn!(?e, "Bedrock converse failed");

            let e = e.into_service_error();
            match e {
                // Transient conditions worth retrying.
                ConverseError::ThrottlingException(e) => AiError::Retryable(anyhow::anyhow!(e)),
                ConverseError::ServiceUnavailableException(e) => {
                    AiError::Retryable(anyhow::anyhow!(e))
                }
                ConverseError::InternalServerException(e) => AiError::Retryable(anyhow::anyhow!(e)),
                ConverseError::ModelTimeoutException(e) => AiError::Retryable(anyhow::anyhow!(e)),

                ConverseError::ResourceNotFoundException(e) => {
                    AiError::Terminal(anyhow::anyhow!(e))
                }
                ConverseError::AccessDeniedException(e) => AiError::Terminal(anyhow::anyhow!(e)),
                ConverseError::ModelErrorException(e) => AiError::Terminal(anyhow::anyhow!(e)),
                ConverseError::ModelNotReadyException(e) => AiError::Terminal(anyhow::anyhow!(e)),
                ConverseError::ValidationException(e) => {
                    // Heuristic: context overflow surfaces as a validation
                    // error whose message contains "too long".
                    let error_message = format!("{}", e).to_lowercase();
                    let is_input_too_long = ["too long"]
                        .iter()
                        .any(|keyword| error_message.contains(keyword));

                    if is_input_too_long {
                        AiError::InputTooLong(anyhow::anyhow!(e))
                    } else {
                        AiError::Terminal(anyhow::anyhow!(e))
                    }
                }
                _ => AiError::Terminal(anyhow::anyhow!("Unknown error from bedrock: {e:?}")),
            }
        })?;

        tracing::debug!("Full response: {:?}", response);

        // `as u32` casts: token counts from the API are non-negative i32s.
        let usage = if let Some(usage) = response.usage {
            TokenUsage {
                input_tokens: usage.input_tokens() as u32,
                output_tokens: usage.output_tokens() as u32,
                total_tokens: (usage.input_tokens() + usage.output_tokens()) as u32,
                cached_prompt_tokens: usage.cache_read_input_tokens().map(|v| v as u32),
                cache_creation_input_tokens: usage.cache_write_input_tokens().map(|v| v as u32),
                reasoning_tokens: None,
            }
        } else {
            TokenUsage::empty()
        };

        // Bedrock does not report which stop sequence fired.
        let stop_reason = match response.stop_reason {
            aws_sdk_bedrockruntime::types::StopReason::EndTurn => StopReason::EndTurn,
            aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
            aws_sdk_bedrockruntime::types::StopReason::StopSequence => {
                StopReason::StopSequence("unknown".to_string())
            }
            aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
            _ => StopReason::EndTurn,
        };

        let message = response
            .output
            .ok_or_else(|| AiError::Terminal(anyhow::anyhow!("No output in response")))?
            .as_message()
            .map_err(|_| AiError::Terminal(anyhow::anyhow!("Output is not a message")))?
            .clone();

        tracing::debug!("Message content blocks: {:?}", message.content());

        // NOTE(review): this clone is redundant — `message` is not used
        // again after this call.
        let content = self.extract_content_blocks(message.clone());

        Ok(ConversationResponse {
            content,
            usage,
            stop_reason,
        })
    }

    /// Sends a streaming ConverseStream request, yielding incremental
    /// `StreamEvent`s and a final `MessageComplete` carrying the accumulated
    /// response. Request construction and error mapping mirror `converse`.
    async fn converse_stream(
        &self,
        request: ConversationRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AiError>> + Send>>, AiError> {
        let model_id = self.get_bedrock_model_id(&request.model.model)?;
        let bedrock_messages =
            self.convert_to_bedrock_messages(&request.messages, request.model.model)?;

        tracing::debug!(?model_id, "Using Bedrock Converse Stream API");

        let mut stream_request = self
            .client
            .converse_stream()
            .model_id(&model_id)
            .system(SystemContentBlock::Text(request.system_prompt));

        // Cache point after the system prompt caches the stable prefix.
        if request.model.model.supports_prompt_caching() {
            stream_request =
                stream_request.system(SystemContentBlock::CachePoint(Self::build_cache_point()?));
        }

        stream_request = stream_request.set_messages(Some(bedrock_messages));

        let mut inference_config_builder =
            aws_sdk_bedrockruntime::types::InferenceConfiguration::builder();

        if let Some(max_tokens) = request.model.max_tokens {
            inference_config_builder = inference_config_builder.max_tokens(max_tokens as i32);
        }

        if let Some(temperature) = request.model.temperature {
            inference_config_builder = inference_config_builder.temperature(temperature);
        }

        if let Some(top_p) = request.model.top_p {
            inference_config_builder = inference_config_builder.top_p(top_p);
        }

        if !request.stop_sequences.is_empty() {
            inference_config_builder =
                inference_config_builder.set_stop_sequences(Some(request.stop_sequences));
        }

        stream_request = stream_request.inference_config(inference_config_builder.build());
        stream_request = self.apply_additional_model_fields_stream(&request.model, stream_request);

        if !request.tools.is_empty() {
            let bedrock_tools: Vec<Tool> = request
                .tools
                .iter()
                .map(|tool| {
                    Tool::ToolSpec(
                        ToolSpecification::builder()
                            .name(&tool.name)
                            .description(&tool.description)
                            .input_schema(ToolInputSchema::Json(to_doc(tool.input_schema.clone())))
                            .build()
                            // NOTE(review): panics on builder failure;
                            // consider mapping to AiError::Terminal instead.
                            .expect("Failed to build tool spec"),
                    )
                })
                .collect();

            let mut tool_config_builder =
                ToolConfiguration::builder().set_tools(Some(bedrock_tools));

            // Extend the cached prefix past the tool definitions as well.
            if request.model.model.supports_prompt_caching() {
                tool_config_builder =
                    tool_config_builder.tools(Tool::CachePoint(Self::build_cache_point()?));
            }

            let tool_config = tool_config_builder
                .build()
                .expect("Failed to build tool config");
            stream_request = stream_request.tool_config(tool_config);
        }

        let response = stream_request.send().await.map_err(|e| {
            tracing::warn!(?e, "Bedrock converse_stream failed");
            let e = e.into_service_error();
            match e {
                // Transient conditions worth retrying.
                ConverseStreamError::ThrottlingException(e) => {
                    AiError::Retryable(anyhow::anyhow!(e))
                }
                ConverseStreamError::ServiceUnavailableException(e) => {
                    AiError::Retryable(anyhow::anyhow!(e))
                }
                ConverseStreamError::InternalServerException(e) => {
                    AiError::Retryable(anyhow::anyhow!(e))
                }
                ConverseStreamError::ModelTimeoutException(e) => {
                    AiError::Retryable(anyhow::anyhow!(e))
                }
                ConverseStreamError::ResourceNotFoundException(e) => {
                    AiError::Terminal(anyhow::anyhow!(e))
                }
                ConverseStreamError::AccessDeniedException(e) => {
                    AiError::Terminal(anyhow::anyhow!(e))
                }
                ConverseStreamError::ModelErrorException(e) => {
                    AiError::Terminal(anyhow::anyhow!(e))
                }
                ConverseStreamError::ModelNotReadyException(e) => {
                    AiError::Terminal(anyhow::anyhow!(e))
                }
                ConverseStreamError::ValidationException(e) => {
                    // Same "too long" heuristic as the non-streaming path.
                    let error_message = format!("{}", e).to_lowercase();
                    if error_message.contains("too long") {
                        AiError::InputTooLong(anyhow::anyhow!(e))
                    } else {
                        AiError::Terminal(anyhow::anyhow!(e))
                    }
                }
                _ => AiError::Terminal(anyhow::anyhow!("Unknown error from bedrock stream: {e:?}")),
            }
        })?;

        let mut event_stream = response.stream;

        // Pump the event stream through the accumulator; a receive error
        // aborts the stream as retryable, a None ends it normally.
        let stream = async_stream::stream! {
            let mut state = BedrockStreamAccumulator::new();

            loop {
                let recv_result = event_stream.recv().await;
                let Ok(maybe_event) = recv_result else {
                    tracing::warn!("Error in bedrock stream");
                    yield Err(AiError::Retryable(anyhow::anyhow!("Bedrock stream error")));
                    return;
                };
                let Some(event) = maybe_event else { break };
                for stream_event in state.process_event(event) {
                    yield Ok(stream_event);
                }
            }

            yield Ok(StreamEvent::MessageComplete { response: state.into_response() });
        };

        Ok(Box::pin(stream))
    }

    /// Per-model pricing passed to `Cost::new`.
    // NOTE(review): the meaning of the four Cost::new arguments is assumed
    // from the values (input / output / cache-write / cache-read per million
    // tokens) — confirm against the Cost definition.
    fn get_cost(&self, model: &Model) -> Cost {
        match model {
            Model::ClaudeSonnet45 => Cost::new(3.0, 15.0, 3.75, 0.3),
            Model::ClaudeHaiku45 => Cost::new(1.0, 5.0, 1.25, 0.1),
            Model::ClaudeOpus46 => Cost::new(5.0, 25.0, 6.25, 0.5),
            Model::ClaudeOpus45 => Cost::new(5.0, 25.0, 6.25, 0.5),
            Model::GptOss120b => Cost::new(0.15, 0.6, 0.0, 0.0),
            _ => Cost::new(0.0, 0.0, 0.0, 0.0),
        }
    }
}
925
926#[cfg(test)]
927mod tests {
928 use super::*;
929 use crate::ai::tests::{
930 test_hello_world, test_multiple_tool_calls, test_reasoning_conversation,
931 test_reasoning_with_tools, test_tool_usage,
932 };
933
934 async fn create_bedrock_provider() -> anyhow::Result<BedrockProvider> {
935 let bedrock_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
936 .region(aws_config::Region::new("us-west-2"))
937 .load()
938 .await;
939 let bedrock_client = aws_sdk_bedrockruntime::Client::new(&bedrock_config);
940 Ok(BedrockProvider::new(bedrock_client))
941 }
942
943 #[tokio::test]
944 #[ignore = "requires AWS credentials"]
945 async fn test_bedrock_hello_world() {
946 let provider = match create_bedrock_provider().await {
947 Ok(provider) => provider,
948 Err(e) => {
949 tracing::error!(?e, "Failed to create Bedrock provider");
950 panic!("Failed to create Bedrock provider: {e:?}");
951 }
952 };
953
954 if let Err(e) = test_hello_world(provider).await {
955 tracing::error!(?e, "Bedrock hello world test failed");
956 panic!("Bedrock hello world test failed: {e:?}");
957 }
958 }
959
960 #[tokio::test]
961 #[ignore = "requires AWS credentials"]
962 async fn test_bedrock_reasoning_conversation() {
963 let provider = match create_bedrock_provider().await {
964 Ok(provider) => provider,
965 Err(e) => {
966 tracing::error!(?e, "Failed to create Bedrock provider");
967 panic!("Failed to create Bedrock provider: {e:?}");
968 }
969 };
970
971 if let Err(e) = test_reasoning_conversation(provider).await {
972 tracing::error!(?e, "Bedrock reasoning conversation test failed");
973 panic!("Bedrock reasoning conversation test failed: {e:?}");
974 }
975 }
976
977 #[tokio::test]
978 #[ignore = "requires AWS credentials"]
979 async fn test_bedrock_tool_usage() {
980 let provider = match create_bedrock_provider().await {
981 Ok(provider) => provider,
982 Err(e) => {
983 tracing::error!(?e, "Failed to create Bedrock provider");
984 panic!("Failed to create Bedrock provider: {e:?}");
985 }
986 };
987
988 if let Err(e) = test_tool_usage(provider).await {
989 tracing::error!(?e, "Bedrock tool usage test failed");
990 panic!("Bedrock tool usage test failed: {e:?}");
991 }
992 }
993
994 #[tokio::test]
995 #[ignore = "requires AWS credentials"]
996 async fn test_bedrock_reasoning_with_tools() {
997 let provider = match create_bedrock_provider().await {
998 Ok(provider) => provider,
999 Err(e) => {
1000 tracing::error!(?e, "Failed to create Bedrock provider");
1001 panic!("Failed to create Bedrock provider: {e:?}");
1002 }
1003 };
1004
1005 if let Err(e) = test_reasoning_with_tools(provider).await {
1006 tracing::error!(?e, "Bedrock reasoning with tools test failed");
1007 panic!("Bedrock reasoning with tools test failed: {e:?}");
1008 }
1009 }
1010
1011 #[tokio::test]
1012 #[ignore = "requires AWS credentials"]
1013 async fn test_bedrock_multiple_tool_calls() {
1014 let provider = match create_bedrock_provider().await {
1015 Ok(provider) => provider,
1016 Err(e) => {
1017 tracing::error!(?e, "Failed to create Bedrock provider");
1018 panic!("Failed to create Bedrock provider: {e:?}");
1019 }
1020 };
1021
1022 if let Err(e) = test_multiple_tool_calls(provider).await {
1023 tracing::error!(?e, "Bedrock reasoning with tools test failed");
1024 panic!("Bedrock reasoning with tools test failed: {e:?}");
1025 }
1026 }
1027}