1use async_trait::async_trait;
4use aws_config::BehaviorVersion;
5use aws_sdk_bedrockruntime::types::{
6 ContentBlock as BedrockContentBlock, ConversationRole, InferenceConfiguration, Message as BedrockMessage,
7 SystemContentBlock as BedrockSystemContentBlock, Tool, ToolConfiguration, ToolInputSchema, ToolSpecification,
8};
9use aws_sdk_bedrockruntime::Client;
10use aws_smithy_types::Document;
11
12use super::{Model, ModelConfig, StreamEventStream};
13use crate::types::{
14 content::{ContentBlock, Message, Role, SystemContentBlock},
15 errors::StrandsError,
16 streaming::{
17 ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
18 ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
19 MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
20 },
21 tools::{ToolChoice, ToolSpec},
22};
23
/// Model identifier used when a [`BedrockModel`] is built via `Default`.
const DEFAULT_MODEL_ID: &str = "us.anthropic.claude-sonnet-4-20250514-v1:0";
25
26#[derive(Debug, Clone)]
28pub struct BedrockModel {
29 config: ModelConfig,
30 region: Option<String>,
31}
32
33impl BedrockModel {
34 pub fn new(model_id: impl Into<String>) -> Self {
35 Self {
36 config: ModelConfig::new(model_id),
37 region: None,
38 }
39 }
40
41 pub fn with_config(config: ModelConfig) -> Self {
42 Self { config, region: None }
43 }
44
45 pub fn with_region(mut self, region: impl Into<String>) -> Self {
46 self.region = Some(region.into());
47 self
48 }
49
50 fn format_messages(&self, messages: &[Message]) -> Vec<BedrockMessage> {
51 messages
52 .iter()
53 .map(|msg| {
54 let role = match msg.role {
55 Role::User => ConversationRole::User,
56 Role::Assistant => ConversationRole::Assistant,
57 };
58
59 let content_blocks: Vec<BedrockContentBlock> = msg
60 .content
61 .iter()
62 .filter_map(|block| self.format_content_block(block))
63 .collect();
64
65 BedrockMessage::builder()
66 .role(role)
67 .set_content(Some(content_blocks))
68 .build()
69 .expect("valid message")
70 })
71 .collect()
72 }
73
74 fn json_to_document(value: &serde_json::Value) -> Document {
75 match value {
76 serde_json::Value::Null => Document::Null,
77 serde_json::Value::Bool(b) => Document::Bool(*b),
78 serde_json::Value::Number(n) => {
79 if let Some(i) = n.as_i64() {
80 Document::Number(aws_smithy_types::Number::NegInt(i))
81 } else if let Some(f) = n.as_f64() {
82 Document::Number(aws_smithy_types::Number::Float(f))
83 } else {
84 Document::Null
85 }
86 }
87 serde_json::Value::String(s) => Document::String(s.clone()),
88 serde_json::Value::Array(arr) => {
89 Document::Array(arr.iter().map(Self::json_to_document).collect())
90 }
91 serde_json::Value::Object(obj) => {
92 Document::Object(obj.iter().map(|(k, v)| (k.clone(), Self::json_to_document(v))).collect())
93 }
94 }
95 }
96
97 fn format_content_block(&self, block: &ContentBlock) -> Option<BedrockContentBlock> {
98 if let Some(ref text) = block.text {
99 return Some(BedrockContentBlock::Text(text.clone()));
100 }
101
102 if let Some(ref tool_use) = block.tool_use {
103 let input_doc = Self::json_to_document(&tool_use.input);
104
105 return Some(BedrockContentBlock::ToolUse(
106 aws_sdk_bedrockruntime::types::ToolUseBlock::builder()
107 .tool_use_id(&tool_use.tool_use_id)
108 .name(&tool_use.name)
109 .input(input_doc)
110 .build()
111 .expect("valid tool use"),
112 ));
113 }
114
115 if let Some(ref tool_result) = block.tool_result {
116 let content: Vec<aws_sdk_bedrockruntime::types::ToolResultContentBlock> = tool_result
117 .content
118 .iter()
119 .filter_map(|c| {
120 if let Some(ref text) = c.text {
121 Some(aws_sdk_bedrockruntime::types::ToolResultContentBlock::Text(text.clone()))
122 } else if let Some(ref json_val) = c.json {
123 Some(aws_sdk_bedrockruntime::types::ToolResultContentBlock::Text(
124 serde_json::to_string(json_val).unwrap_or_default(),
125 ))
126 } else {
127 None
128 }
129 })
130 .collect();
131
132 let status = match tool_result.status {
133 crate::types::tools::ToolResultStatus::Success => {
134 aws_sdk_bedrockruntime::types::ToolResultStatus::Success
135 }
136 crate::types::tools::ToolResultStatus::Error => {
137 aws_sdk_bedrockruntime::types::ToolResultStatus::Error
138 }
139 };
140
141 return Some(BedrockContentBlock::ToolResult(
142 aws_sdk_bedrockruntime::types::ToolResultBlock::builder()
143 .tool_use_id(&tool_result.tool_use_id)
144 .set_content(Some(content))
145 .status(status)
146 .build()
147 .expect("valid tool result"),
148 ));
149 }
150
151 None
152 }
153
154 fn format_tool_specs(&self, tool_specs: &[ToolSpec]) -> Vec<Tool> {
155 tool_specs
156 .iter()
157 .map(|spec| {
158 let input_schema_doc = Self::json_to_document(&spec.input_schema.json);
159
160 Tool::ToolSpec(
161 ToolSpecification::builder()
162 .name(&spec.name)
163 .description(&spec.description)
164 .input_schema(ToolInputSchema::Json(input_schema_doc))
165 .build()
166 .expect("valid tool spec"),
167 )
168 })
169 .collect()
170 }
171
172 fn format_system_prompt(&self, system_prompt: Option<&str>) -> Option<Vec<BedrockSystemContentBlock>> {
173 system_prompt.map(|s| vec![BedrockSystemContentBlock::Text(s.to_string())])
174 }
175
176 fn map_stop_reason(reason: &aws_sdk_bedrockruntime::types::StopReason) -> StopReason {
177 match reason {
178 aws_sdk_bedrockruntime::types::StopReason::EndTurn => StopReason::EndTurn,
179 aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
180 aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
181 aws_sdk_bedrockruntime::types::StopReason::StopSequence => StopReason::StopSequence,
182 aws_sdk_bedrockruntime::types::StopReason::ContentFiltered => StopReason::ContentFiltered,
183 aws_sdk_bedrockruntime::types::StopReason::GuardrailIntervened => StopReason::GuardrailIntervention,
184 _ => StopReason::EndTurn,
185 }
186 }
187}
188
189impl Default for BedrockModel {
190 fn default() -> Self {
191 Self::new(DEFAULT_MODEL_ID)
192 }
193}
194
195#[async_trait]
196impl Model for BedrockModel {
197 fn config(&self) -> &ModelConfig {
198 &self.config
199 }
200
201 fn update_config(&mut self, config: ModelConfig) {
202 self.config = config;
203 }
204
205 fn stream<'a>(
206 &'a self,
207 messages: &'a [Message],
208 tool_specs: Option<&'a [ToolSpec]>,
209 system_prompt: Option<&'a str>,
210 _tool_choice: Option<ToolChoice>,
211 _system_prompt_content: Option<&'a [SystemContentBlock]>,
212 ) -> StreamEventStream<'a> {
213 let model_id = self.config.model_id.clone();
214 let max_tokens = self.config.max_tokens.unwrap_or(4096);
215 let temperature = self.config.temperature;
216 let top_p = self.config.top_p;
217 let stop_sequences = self.config.stop_sequences.clone();
218
219 let formatted_messages = self.format_messages(messages);
220 let formatted_tools = tool_specs.map(|specs| self.format_tool_specs(specs));
221 let formatted_system = self.format_system_prompt(system_prompt);
222 let region = self.region.clone();
223
224 Box::pin(async_stream::stream! {
225 let mut config_loader = aws_config::defaults(BehaviorVersion::latest());
226 if let Some(ref r) = region {
227 config_loader = config_loader.region(aws_config::Region::new(r.clone()));
228 }
229 let sdk_config = config_loader.load().await;
230 let client = Client::new(&sdk_config);
231
232 let mut inference_config = InferenceConfiguration::builder().max_tokens(max_tokens as i32);
233
234 if let Some(temp) = temperature {
235 inference_config = inference_config.temperature(temp);
236 }
237 if let Some(p) = top_p {
238 inference_config = inference_config.top_p(p);
239 }
240 if let Some(ref seqs) = stop_sequences {
241 inference_config = inference_config.set_stop_sequences(Some(seqs.clone()));
242 }
243
244 let mut request = client
245 .converse_stream()
246 .model_id(&model_id)
247 .set_messages(Some(formatted_messages))
248 .inference_config(inference_config.build());
249
250 if let Some(system) = formatted_system {
251 request = request.set_system(Some(system));
252 }
253
254 if let Some(tools) = formatted_tools {
255 request = request.tool_config(
256 ToolConfiguration::builder()
257 .set_tools(Some(tools))
258 .build()
259 .expect("valid tool config"),
260 );
261 }
262
263 let response = match request.send().await {
264 Ok(resp) => resp,
265 Err(e) => {
266 let err_msg = e.to_string();
267 if err_msg.contains("ThrottlingException") || err_msg.contains("throttlingException") {
268 yield Err(StrandsError::ModelThrottled { message: err_msg });
269 } else if err_msg.contains("Input is too long") || err_msg.contains("context limit") {
270 yield Err(StrandsError::ContextWindowOverflow { message: err_msg });
271 } else {
272 yield Err(StrandsError::model_error(err_msg));
273 }
274 return;
275 }
276 };
277
278 let mut stream = response.stream;
279 let mut has_tool_use = false;
280
281 while let Ok(Some(event)) = stream.recv().await {
282 match event {
283 aws_sdk_bedrockruntime::types::ConverseStreamOutput::MessageStart(msg) => {
284 let _role = msg.role;
285 yield Ok(StreamEvent {
286 message_start: Some(MessageStartEvent { role: Role::Assistant }),
287 ..Default::default()
288 });
289 }
290
291 aws_sdk_bedrockruntime::types::ConverseStreamOutput::ContentBlockStart(start) => {
292 let content_block_index = start.content_block_index as u32;
293 let block_start = if let Some(ref s) = start.start {
294 match s {
295 aws_sdk_bedrockruntime::types::ContentBlockStart::ToolUse(tu) => {
296 has_tool_use = true;
297 Some(ContentBlockStart {
298 tool_use: Some(ContentBlockStartToolUse {
299 name: tu.name.clone(),
300 tool_use_id: tu.tool_use_id.clone(),
301 }),
302 })
303 }
304 _ => None,
305 }
306 } else {
307 None
308 };
309
310 yield Ok(StreamEvent {
311 content_block_start: Some(ContentBlockStartEvent {
312 content_block_index: Some(content_block_index),
313 start: block_start,
314 }),
315 ..Default::default()
316 });
317 }
318
319 aws_sdk_bedrockruntime::types::ConverseStreamOutput::ContentBlockDelta(delta) => {
320 if let Some(ref d) = delta.delta {
321 let block_delta = match d {
322 aws_sdk_bedrockruntime::types::ContentBlockDelta::Text(text) => {
323 ContentBlockDelta {
324 text: Some(text.clone()),
325 ..Default::default()
326 }
327 }
328 aws_sdk_bedrockruntime::types::ContentBlockDelta::ToolUse(tu) => {
329 ContentBlockDelta {
330 tool_use: Some(ContentBlockDeltaToolUse {
331 input: tu.input.clone(),
332 }),
333 ..Default::default()
334 }
335 }
336 _ => ContentBlockDelta::default(),
337 };
338
339 yield Ok(StreamEvent {
340 content_block_delta: Some(ContentBlockDeltaEvent {
341 content_block_index: Some(delta.content_block_index as u32),
342 delta: Some(block_delta),
343 }),
344 ..Default::default()
345 });
346 }
347 }
348
349 aws_sdk_bedrockruntime::types::ConverseStreamOutput::ContentBlockStop(stop) => {
350 yield Ok(StreamEvent {
351 content_block_stop: Some(ContentBlockStopEvent {
352 content_block_index: Some(stop.content_block_index as u32),
353 }),
354 ..Default::default()
355 });
356 }
357
358 aws_sdk_bedrockruntime::types::ConverseStreamOutput::MessageStop(stop) => {
359 let mut stop_reason = Self::map_stop_reason(&stop.stop_reason);
360
361 if has_tool_use && stop_reason == StopReason::EndTurn {
362 stop_reason = StopReason::ToolUse;
363 }
364
365 yield Ok(StreamEvent {
366 message_stop: Some(MessageStopEvent {
367 stop_reason: Some(stop_reason),
368 additional_model_response_fields: None,
369 }),
370 ..Default::default()
371 });
372 }
373
374 aws_sdk_bedrockruntime::types::ConverseStreamOutput::Metadata(meta) => {
375 let usage = meta.usage.map(|u| Usage {
376 input_tokens: u.input_tokens as u32,
377 output_tokens: u.output_tokens as u32,
378 total_tokens: (u.input_tokens + u.output_tokens) as u32,
379 cache_read_input_tokens: 0,
380 cache_write_input_tokens: 0,
381 });
382
383 let metrics = meta.metrics.map(|m| Metrics {
384 latency_ms: m.latency_ms as u64,
385 time_to_first_byte_ms: 0,
386 });
387
388 yield Ok(StreamEvent {
389 metadata: Some(MetadataEvent {
390 usage,
391 metrics,
392 trace: None,
393 }),
394 ..Default::default()
395 });
396 }
397
398 _ => {}
399 }
400 }
401 })
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// The id handed to `new` is stored verbatim in the config.
    #[test]
    fn test_bedrock_model_creation() {
        let expected_id = "anthropic.claude-3-sonnet-20240229-v1:0";
        let model = BedrockModel::new(expected_id);
        assert_eq!(model.config().model_id, expected_id);
    }

    /// The default model targets a Claude model id.
    #[test]
    fn test_bedrock_model_default() {
        let default_model = BedrockModel::default();
        let model_id = &default_model.config().model_id;
        assert!(model_id.contains("claude"));
    }

    /// `with_region` records the region override on the model.
    #[test]
    fn test_bedrock_with_region() {
        let configured = BedrockModel::default().with_region("us-east-1");
        assert_eq!(configured.region.as_deref(), Some("us-east-1"));
    }

    /// A JSON object converts to a document object that keeps every key.
    #[test]
    fn test_json_to_document() {
        let input = serde_json::json!({"key": "value", "num": 42});
        if let Document::Object(fields) = BedrockModel::json_to_document(&input) {
            for expected_key in ["key", "num"] {
                assert!(fields.contains_key(expected_key));
            }
        } else {
            panic!("expected object");
        }
    }
}