1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use aws_sdk_bedrockruntime::types::{
5 self as bedrock_types, ContentBlock, ConversationRole, InferenceConfiguration,
6 SystemContentBlock, ToolConfiguration, ToolInputSchema, ToolResultBlock,
7 ToolResultContentBlock, ToolSpecification, ToolUseBlock,
8};
9use aws_smithy_types::Document as SmithyDocument;
10use serde_json::Value;
11use synaptic_core::{
12 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
13 TokenUsage, ToolCall, ToolCallChunk, ToolChoice,
14};
15
/// Settings for a Bedrock-backed chat model.
///
/// Only `model_id` is required; every other option starts out unset and can be
/// filled in with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct BedrockConfig {
    /// Identifier of the Bedrock model to invoke.
    pub model_id: String,
    /// Optional AWS region; when set, `BedrockChatModel::new` uses it to override
    /// the region taken from the AWS environment.
    pub region: Option<String>,
    /// Optional cap on the number of generated tokens.
    pub max_tokens: Option<i32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling (`top_p`) value.
    pub top_p: Option<f32>,
    /// Optional list of stop sequences.
    pub stop: Option<Vec<String>>,
}

impl BedrockConfig {
    /// Creates a configuration for `model_id` with every optional setting unset.
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id = model_id.into();
        Self {
            model_id,
            region: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
        }
    }

    /// Returns the configuration with `region` set, replacing any previous value.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Returns the configuration with `max_tokens` set, replacing any previous value.
    pub fn with_max_tokens(self, max_tokens: i32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the configuration with `temperature` set, replacing any previous value.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns the configuration with `top_p` set, replacing any previous value.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Returns the configuration with the stop-sequence list set, replacing any
    /// previous list.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
80
81pub struct BedrockChatModel {
90 config: BedrockConfig,
91 client: aws_sdk_bedrockruntime::Client,
92}
93
94impl BedrockChatModel {
95 pub async fn new(config: BedrockConfig) -> Self {
99 let mut aws_config_loader = aws_config::from_env();
100
101 if let Some(ref region) = config.region {
102 aws_config_loader = aws_config_loader.region(aws_config::Region::new(region.clone()));
103 }
104
105 let aws_config = aws_config_loader.load().await;
106 let client = aws_sdk_bedrockruntime::Client::new(&aws_config);
107
108 Self { config, client }
109 }
110
111 pub fn from_client(config: BedrockConfig, client: aws_sdk_bedrockruntime::Client) -> Self {
113 Self { config, client }
114 }
115
116 fn build_inference_config(&self) -> Option<InferenceConfiguration> {
118 let has_any = self.config.max_tokens.is_some()
119 || self.config.temperature.is_some()
120 || self.config.top_p.is_some()
121 || self.config.stop.is_some();
122
123 if !has_any {
124 return None;
125 }
126
127 let mut builder = InferenceConfiguration::builder();
128
129 if let Some(max_tokens) = self.config.max_tokens {
130 builder = builder.max_tokens(max_tokens);
131 }
132 if let Some(temperature) = self.config.temperature {
133 builder = builder.temperature(temperature);
134 }
135 if let Some(top_p) = self.config.top_p {
136 builder = builder.top_p(top_p);
137 }
138 if let Some(ref stop) = self.config.stop {
139 for s in stop {
140 builder = builder.stop_sequences(s.clone());
141 }
142 }
143
144 Some(builder.build())
145 }
146
147 fn build_tool_config(&self, request: &ChatRequest) -> Option<ToolConfiguration> {
149 if request.tools.is_empty() {
150 return None;
151 }
152
153 let tools: Vec<bedrock_types::Tool> = request
154 .tools
155 .iter()
156 .map(|td| {
157 let spec = ToolSpecification::builder()
158 .name(&td.name)
159 .description(&td.description)
160 .input_schema(ToolInputSchema::Json(json_value_to_document(
161 &td.parameters,
162 )))
163 .build()
164 .expect("tool specification build should not fail");
165
166 bedrock_types::Tool::ToolSpec(spec)
167 })
168 .collect();
169
170 let mut builder = ToolConfiguration::builder();
171 for tool in tools {
172 builder = builder.tools(tool);
173 }
174
175 if let Some(ref choice) = request.tool_choice {
176 let bedrock_choice = match choice {
177 ToolChoice::Auto => bedrock_types::ToolChoice::Auto(
178 bedrock_types::AutoToolChoice::builder().build(),
179 ),
180 ToolChoice::Required => {
181 bedrock_types::ToolChoice::Any(bedrock_types::AnyToolChoice::builder().build())
182 }
183 ToolChoice::None => {
184 bedrock_types::ToolChoice::Auto(
187 bedrock_types::AutoToolChoice::builder().build(),
188 )
189 }
190 ToolChoice::Specific(name) => bedrock_types::ToolChoice::Tool(
191 bedrock_types::SpecificToolChoice::builder()
192 .name(name)
193 .build()
194 .expect("specific tool choice build should not fail"),
195 ),
196 };
197 builder = builder.tool_choice(bedrock_choice);
198 }
199
200 Some(
201 builder
202 .build()
203 .expect("tool configuration build should not fail"),
204 )
205 }
206}
207
208#[async_trait]
209impl ChatModel for BedrockChatModel {
210 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
211 let (system_blocks, messages) = convert_messages(&request.messages);
212
213 let mut converse = self.client.converse().model_id(&self.config.model_id);
214
215 for block in system_blocks {
217 converse = converse.system(block);
218 }
219
220 for msg in messages {
222 converse = converse.messages(msg);
223 }
224
225 if let Some(inference_config) = self.build_inference_config() {
227 converse = converse.inference_config(inference_config);
228 }
229
230 if let Some(tool_config) = self.build_tool_config(&request) {
232 converse = converse.tool_config(tool_config);
233 }
234
235 let output = converse
236 .send()
237 .await
238 .map_err(|e| SynapticError::Model(format!("Bedrock Converse API error: {e}")))?;
239
240 let usage = output.usage().map(|u| TokenUsage {
242 input_tokens: u.input_tokens() as u32,
243 output_tokens: u.output_tokens() as u32,
244 total_tokens: u.total_tokens() as u32,
245 input_details: None,
246 output_details: None,
247 });
248
249 let message = match output.output() {
251 Some(bedrock_types::ConverseOutput::Message(msg)) => parse_bedrock_message(msg),
252 _ => Message::ai(""),
253 };
254
255 Ok(ChatResponse { message, usage })
256 }
257
258 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
259 Box::pin(async_stream::stream! {
260 let (system_blocks, messages) = convert_messages(&request.messages);
261
262 let mut converse_stream = self
263 .client
264 .converse_stream()
265 .model_id(&self.config.model_id);
266
267 for block in system_blocks {
268 converse_stream = converse_stream.system(block);
269 }
270
271 for msg in messages {
272 converse_stream = converse_stream.messages(msg);
273 }
274
275 if let Some(inference_config) = self.build_inference_config() {
276 converse_stream = converse_stream.inference_config(inference_config);
277 }
278
279 if let Some(tool_config) = self.build_tool_config(&request) {
280 converse_stream = converse_stream.tool_config(tool_config);
281 }
282
283 let output = match converse_stream.send().await {
284 Ok(o) => o,
285 Err(e) => {
286 yield Err(SynapticError::Model(format!(
287 "Bedrock ConverseStream API error: {e}"
288 )));
289 return;
290 }
291 };
292
293 let mut stream = output.stream;
294
295 let mut current_tool_id: Option<String> = None;
297 let mut current_tool_name: Option<String> = None;
298 let mut current_tool_input: String = String::new();
299
300 loop {
301 match stream.recv().await {
302 Ok(Some(event)) => {
303 match event {
304 bedrock_types::ConverseStreamOutput::ContentBlockStart(start_event) => {
305 if let Some(bedrock_types::ContentBlockStart::ToolUse(tool_start)) = start_event.start() {
306 current_tool_id = Some(tool_start.tool_use_id().to_string());
307 current_tool_name = Some(tool_start.name().to_string());
308 current_tool_input.clear();
309
310 yield Ok(AIMessageChunk {
311 tool_call_chunks: vec![ToolCallChunk {
312 id: Some(tool_start.tool_use_id().to_string()),
313 name: Some(tool_start.name().to_string()),
314 arguments: None,
315 index: Some(start_event.content_block_index() as usize),
316 }],
317 ..Default::default()
318 });
319 }
320 }
321 bedrock_types::ConverseStreamOutput::ContentBlockDelta(delta_event) => {
322 if let Some(delta) = delta_event.delta() {
323 match delta {
324 bedrock_types::ContentBlockDelta::Text(text) => {
325 yield Ok(AIMessageChunk {
326 content: text.to_string(),
327 ..Default::default()
328 });
329 }
330 bedrock_types::ContentBlockDelta::ToolUse(tool_delta) => {
331 let input_fragment = tool_delta.input();
332 current_tool_input.push_str(input_fragment);
333
334 yield Ok(AIMessageChunk {
335 tool_call_chunks: vec![ToolCallChunk {
336 id: current_tool_id.clone(),
337 name: current_tool_name.clone(),
338 arguments: Some(input_fragment.to_string()),
339 index: Some(delta_event.content_block_index() as usize),
340 }],
341 ..Default::default()
342 });
343 }
344 _ => { }
345 }
346 }
347 }
348 bedrock_types::ConverseStreamOutput::ContentBlockStop(_) => {
349 if let (Some(id), Some(name)) = (current_tool_id.take(), current_tool_name.take()) {
351 let arguments: Value = serde_json::from_str(¤t_tool_input)
352 .unwrap_or(Value::Object(Default::default()));
353 current_tool_input.clear();
354
355 yield Ok(AIMessageChunk {
356 tool_calls: vec![ToolCall {
357 id,
358 name,
359 arguments,
360 }],
361 ..Default::default()
362 });
363 }
364 }
365 bedrock_types::ConverseStreamOutput::Metadata(meta) => {
366 if let Some(u) = meta.usage() {
367 yield Ok(AIMessageChunk {
368 usage: Some(TokenUsage {
369 input_tokens: u.input_tokens() as u32,
370 output_tokens: u.output_tokens() as u32,
371 total_tokens: u.total_tokens() as u32,
372 input_details: None,
373 output_details: None,
374 }),
375 ..Default::default()
376 });
377 }
378 }
379 _ => { }
380 }
381 }
382 Ok(None) => break,
383 Err(e) => {
384 yield Err(SynapticError::Model(format!(
385 "Bedrock stream error: {e}"
386 )));
387 break;
388 }
389 }
390 }
391 })
392 }
393}
394
395fn convert_messages(
404 messages: &[Message],
405) -> (Vec<SystemContentBlock>, Vec<bedrock_types::Message>) {
406 let mut system_blocks = Vec::new();
407 let mut bedrock_messages: Vec<bedrock_types::Message> = Vec::new();
408
409 for msg in messages {
410 match msg {
411 Message::System { content, .. } => {
412 system_blocks.push(SystemContentBlock::Text(content.clone()));
413 }
414 Message::Human { content, .. } => {
415 let bedrock_msg = bedrock_types::Message::builder()
416 .role(ConversationRole::User)
417 .content(ContentBlock::Text(content.clone()))
418 .build()
419 .expect("message build should not fail");
420 bedrock_messages.push(bedrock_msg);
421 }
422 Message::AI {
423 content,
424 tool_calls,
425 ..
426 } => {
427 let mut blocks: Vec<ContentBlock> = Vec::new();
428
429 if !content.is_empty() {
430 blocks.push(ContentBlock::Text(content.clone()));
431 }
432
433 for tc in tool_calls {
434 let tool_use = ToolUseBlock::builder()
435 .tool_use_id(&tc.id)
436 .name(&tc.name)
437 .input(json_value_to_document(&tc.arguments))
438 .build()
439 .expect("tool use block build should not fail");
440 blocks.push(ContentBlock::ToolUse(tool_use));
441 }
442
443 if blocks.is_empty() {
445 blocks.push(ContentBlock::Text(String::new()));
446 }
447
448 let bedrock_msg = bedrock_types::Message::builder()
449 .role(ConversationRole::Assistant)
450 .set_content(Some(blocks))
451 .build()
452 .expect("message build should not fail");
453 bedrock_messages.push(bedrock_msg);
454 }
455 Message::Tool {
456 content,
457 tool_call_id,
458 ..
459 } => {
460 let tool_result = ToolResultBlock::builder()
461 .tool_use_id(tool_call_id)
462 .content(ToolResultContentBlock::Text(content.clone()))
463 .build()
464 .expect("tool result block build should not fail");
465
466 let bedrock_msg = bedrock_types::Message::builder()
467 .role(ConversationRole::User)
468 .content(ContentBlock::ToolResult(tool_result))
469 .build()
470 .expect("message build should not fail");
471 bedrock_messages.push(bedrock_msg);
472 }
473 Message::Chat { content, .. } => {
474 let bedrock_msg = bedrock_types::Message::builder()
476 .role(ConversationRole::User)
477 .content(ContentBlock::Text(content.clone()))
478 .build()
479 .expect("message build should not fail");
480 bedrock_messages.push(bedrock_msg);
481 }
482 Message::Remove { .. } => { }
483 }
484 }
485
486 (system_blocks, bedrock_messages)
487}
488
489fn parse_bedrock_message(msg: &bedrock_types::Message) -> Message {
491 let mut text_parts: Vec<String> = Vec::new();
492 let mut tool_calls: Vec<ToolCall> = Vec::new();
493
494 for block in msg.content() {
495 match block {
496 ContentBlock::Text(text) => {
497 text_parts.push(text.clone());
498 }
499 ContentBlock::ToolUse(tool_use) => {
500 tool_calls.push(ToolCall {
501 id: tool_use.tool_use_id().to_string(),
502 name: tool_use.name().to_string(),
503 arguments: document_to_json_value(tool_use.input()),
504 });
505 }
506 _ => { }
507 }
508 }
509
510 let content = text_parts.join("");
511
512 if tool_calls.is_empty() {
513 Message::ai(content)
514 } else {
515 Message::ai_with_tool_calls(content, tool_calls)
516 }
517}
518
519pub(crate) fn json_value_to_document(value: &Value) -> SmithyDocument {
525 match value {
526 Value::Null => SmithyDocument::Null,
527 Value::Bool(b) => SmithyDocument::Bool(*b),
528 Value::Number(n) => {
529 if let Some(i) = n.as_i64() {
530 SmithyDocument::Number(aws_smithy_types::Number::NegInt(i))
531 } else if let Some(u) = n.as_u64() {
532 SmithyDocument::Number(aws_smithy_types::Number::PosInt(u))
533 } else if let Some(f) = n.as_f64() {
534 SmithyDocument::Number(aws_smithy_types::Number::Float(f))
535 } else {
536 SmithyDocument::Null
537 }
538 }
539 Value::String(s) => SmithyDocument::String(s.clone()),
540 Value::Array(arr) => {
541 SmithyDocument::Array(arr.iter().map(json_value_to_document).collect())
542 }
543 Value::Object(obj) => {
544 let map: HashMap<String, SmithyDocument> = obj
545 .iter()
546 .map(|(k, v)| (k.clone(), json_value_to_document(v)))
547 .collect();
548 SmithyDocument::Object(map)
549 }
550 }
551}
552
553pub(crate) fn document_to_json_value(doc: &SmithyDocument) -> Value {
555 match doc {
556 SmithyDocument::Null => Value::Null,
557 SmithyDocument::Bool(b) => Value::Bool(*b),
558 SmithyDocument::Number(n) => match *n {
559 aws_smithy_types::Number::PosInt(u) => {
560 serde_json::json!(u)
561 }
562 aws_smithy_types::Number::NegInt(i) => {
563 serde_json::json!(i)
564 }
565 aws_smithy_types::Number::Float(f) => {
566 serde_json::json!(f)
567 }
568 },
569 SmithyDocument::String(s) => Value::String(s.clone()),
570 SmithyDocument::Array(arr) => {
571 Value::Array(arr.iter().map(document_to_json_value).collect())
572 }
573 SmithyDocument::Object(obj) => {
574 let map: serde_json::Map<String, Value> = obj
575 .iter()
576 .map(|(k, v)| (k.clone(), document_to_json_value(v)))
577 .collect();
578 Value::Object(map)
579 }
580 }
581}
582
// Unit tests for the pure conversion helpers: JSON <-> Smithy document,
// Synaptic -> Bedrock message conversion, and Bedrock -> Synaptic parsing.
// None of these tests call AWS.
#[cfg(test)]
mod tests {
    use super::*;

    // A nested JSON schema (objects, strings, arrays) must survive a
    // JSON -> Document -> JSON round trip unchanged.
    #[test]
    fn json_value_to_document_round_trip() {
        let original = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        let doc = json_value_to_document(&original);
        let back = document_to_json_value(&doc);
        assert_eq!(original, back);
    }

    // Scalar JSON values map to the matching Smithy document variants.
    #[test]
    fn json_value_to_document_primitives() {
        assert!(matches!(
            json_value_to_document(&Value::Null),
            SmithyDocument::Null
        ));
        assert!(matches!(
            json_value_to_document(&Value::Bool(true)),
            SmithyDocument::Bool(true)
        ));
        assert!(matches!(
            json_value_to_document(&serde_json::json!("hello")),
            SmithyDocument::String(_)
        ));
    }

    // System messages are pulled out into system blocks and do not appear
    // among the Bedrock conversation messages.
    #[test]
    fn convert_system_messages() {
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::human("Hello!"),
        ];

        let (system_blocks, bedrock_messages) = convert_messages(&messages);
        assert_eq!(system_blocks.len(), 1);
        assert_eq!(bedrock_messages.len(), 1);
    }

    // A human -> AI (tool call) -> tool result exchange produces one Bedrock
    // message per step with user / assistant / user roles.
    #[test]
    fn convert_tool_messages() {
        let messages = vec![
            Message::human("What is the weather?"),
            Message::ai_with_tool_calls(
                "",
                vec![ToolCall {
                    id: "tc_1".to_string(),
                    name: "get_weather".to_string(),
                    arguments: serde_json::json!({"city": "NYC"}),
                }],
            ),
            Message::tool("Sunny, 72F", "tc_1"),
        ];

        let (system_blocks, bedrock_messages) = convert_messages(&messages);
        assert!(system_blocks.is_empty());
        assert_eq!(bedrock_messages.len(), 3);

        // The human question.
        assert_eq!(*bedrock_messages[0].role(), ConversationRole::User);
        // The assistant's tool-use request.
        assert_eq!(*bedrock_messages[1].role(), ConversationRole::Assistant);
        // The tool result is sent back on the user side.
        assert_eq!(*bedrock_messages[2].role(), ConversationRole::User);
    }

    // `Remove` markers are dropped and produce no Bedrock message.
    #[test]
    fn convert_remove_messages_are_skipped() {
        let messages = vec![
            Message::human("Hi"),
            Message::remove("some-id"),
            Message::ai("Hello!"),
        ];

        let (_, bedrock_messages) = convert_messages(&messages);
        assert_eq!(bedrock_messages.len(), 2);
    }

    // A Bedrock message with a single text block parses to a plain AI message.
    #[test]
    fn parse_text_only_message() {
        let msg = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Hello world".to_string()))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&msg);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Hello world");
        assert!(parsed.tool_calls().is_empty());
    }

    // Mixed text + tool-use content keeps the text and converts the tool
    // use (id, name, JSON arguments) into a ToolCall.
    #[test]
    fn parse_message_with_tool_use() {
        let tool_use = ToolUseBlock::builder()
            .tool_use_id("tc_1")
            .name("calculator")
            .input(json_value_to_document(&serde_json::json!({"expr": "1+1"})))
            .build()
            .unwrap();

        let msg = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Let me calculate.".to_string()))
            .content(ContentBlock::ToolUse(tool_use))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&msg);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Let me calculate.");
        assert_eq!(parsed.tool_calls().len(), 1);
        assert_eq!(parsed.tool_calls()[0].id, "tc_1");
        assert_eq!(parsed.tool_calls()[0].name, "calculator");
        assert_eq!(
            parsed.tool_calls()[0].arguments,
            serde_json::json!({"expr": "1+1"})
        );
    }
}