Skip to main content

synaptic_bedrock/
chat_model.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use aws_sdk_bedrockruntime::types::{
5    self as bedrock_types, ContentBlock, ConversationRole, InferenceConfiguration,
6    SystemContentBlock, ToolConfiguration, ToolInputSchema, ToolResultBlock,
7    ToolResultContentBlock, ToolSpecification, ToolUseBlock,
8};
9use aws_smithy_types::Document as SmithyDocument;
10use serde_json::Value;
11use synaptic_core::{
12    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
13    TokenUsage, ToolCall, ToolCallChunk, ToolChoice,
14};
15
16// ---------------------------------------------------------------------------
17// Configuration
18// ---------------------------------------------------------------------------
19
/// Configuration for the AWS Bedrock chat model.
///
/// Start from [`BedrockConfig::new`] and chain the `with_*` methods to fill in
/// optional values; every optional field starts out as `None`.
#[derive(Debug, Clone)]
pub struct BedrockConfig {
    /// The model identifier (e.g., `"anthropic.claude-3-5-sonnet-20241022-v2:0"`).
    pub model_id: String,
    /// AWS region override. When `None`, region resolution is left to the code that
    /// builds the AWS client (see `BedrockChatModel::new`).
    pub region: Option<String>,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<i32>,
    /// Sampling temperature. The accepted range is model-specific (commonly 0.0 - 1.0).
    pub temperature: Option<f32>,
    /// Nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// Stop sequences.
    pub stop: Option<Vec<String>>,
}

impl BedrockConfig {
    /// Create a configuration for `model_id` with no region override and no
    /// inference settings.
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id: String = model_id.into();
        Self {
            model_id,
            region: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
        }
    }

    /// Return this configuration with the AWS region override set.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Return this configuration with the maximum number of output tokens set.
    pub fn with_max_tokens(self, max_tokens: i32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Return this configuration with the sampling temperature set.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Return this configuration with the nucleus sampling parameter set.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Return this configuration with the stop sequences replaced by `stop`.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
80
81// ---------------------------------------------------------------------------
82// BedrockChatModel
83// ---------------------------------------------------------------------------
84
85/// A [`ChatModel`] implementation backed by AWS Bedrock's Converse API.
86///
87/// Supports both synchronous and streaming responses, tool calling,
88/// and all Bedrock-supported foundation models.
89pub struct BedrockChatModel {
90    config: BedrockConfig,
91    client: aws_sdk_bedrockruntime::Client,
92}
93
94impl BedrockChatModel {
95    /// Create a new `BedrockChatModel` by loading AWS configuration from the
96    /// environment. Respects `AWS_REGION`, `AWS_ACCESS_KEY_ID`,
97    /// `AWS_SECRET_ACCESS_KEY`, and other standard AWS SDK environment variables.
98    pub async fn new(config: BedrockConfig) -> Self {
99        let mut aws_config_loader = aws_config::from_env();
100
101        if let Some(ref region) = config.region {
102            aws_config_loader =
103                aws_config_loader.region(aws_config::Region::new(region.clone()));
104        }
105
106        let aws_config = aws_config_loader.load().await;
107        let client = aws_sdk_bedrockruntime::Client::new(&aws_config);
108
109        Self { config, client }
110    }
111
112    /// Create a new `BedrockChatModel` with a pre-existing AWS SDK client.
113    pub fn from_client(config: BedrockConfig, client: aws_sdk_bedrockruntime::Client) -> Self {
114        Self { config, client }
115    }
116
117    /// Build the inference configuration from our config.
118    fn build_inference_config(&self) -> Option<InferenceConfiguration> {
119        let has_any = self.config.max_tokens.is_some()
120            || self.config.temperature.is_some()
121            || self.config.top_p.is_some()
122            || self.config.stop.is_some();
123
124        if !has_any {
125            return None;
126        }
127
128        let mut builder = InferenceConfiguration::builder();
129
130        if let Some(max_tokens) = self.config.max_tokens {
131            builder = builder.max_tokens(max_tokens);
132        }
133        if let Some(temperature) = self.config.temperature {
134            builder = builder.temperature(temperature);
135        }
136        if let Some(top_p) = self.config.top_p {
137            builder = builder.top_p(top_p);
138        }
139        if let Some(ref stop) = self.config.stop {
140            for s in stop {
141                builder = builder.stop_sequences(s.clone());
142            }
143        }
144
145        Some(builder.build())
146    }
147
148    /// Build the tool configuration from a ChatRequest.
149    fn build_tool_config(
150        &self,
151        request: &ChatRequest,
152    ) -> Option<ToolConfiguration> {
153        if request.tools.is_empty() {
154            return None;
155        }
156
157        let tools: Vec<bedrock_types::Tool> = request
158            .tools
159            .iter()
160            .map(|td| {
161                let spec = ToolSpecification::builder()
162                    .name(&td.name)
163                    .description(&td.description)
164                    .input_schema(ToolInputSchema::Json(json_value_to_document(&td.parameters)))
165                    .build()
166                    .expect("tool specification build should not fail");
167
168                bedrock_types::Tool::ToolSpec(spec)
169            })
170            .collect();
171
172        let mut builder = ToolConfiguration::builder();
173        for tool in tools {
174            builder = builder.tools(tool);
175        }
176
177        if let Some(ref choice) = request.tool_choice {
178            let bedrock_choice = match choice {
179                ToolChoice::Auto => {
180                    bedrock_types::ToolChoice::Auto(
181                        bedrock_types::AutoToolChoice::builder().build(),
182                    )
183                }
184                ToolChoice::Required => {
185                    bedrock_types::ToolChoice::Any(
186                        bedrock_types::AnyToolChoice::builder().build(),
187                    )
188                }
189                ToolChoice::None => {
190                    // Bedrock does not have a "none" tool choice. We omit tools instead,
191                    // but since we already built them, just default to Auto.
192                    bedrock_types::ToolChoice::Auto(
193                        bedrock_types::AutoToolChoice::builder().build(),
194                    )
195                }
196                ToolChoice::Specific(name) => {
197                    bedrock_types::ToolChoice::Tool(
198                        bedrock_types::SpecificToolChoice::builder()
199                            .name(name)
200                            .build()
201                            .expect("specific tool choice build should not fail"),
202                    )
203                }
204            };
205            builder = builder.tool_choice(bedrock_choice);
206        }
207
208        Some(builder.build().expect("tool configuration build should not fail"))
209    }
210}
211
// `ChatModel` implementation: both entry points build the same request (system
// prompts, messages, inference config, tool config) against the Converse and
// ConverseStream APIs respectively.
#[async_trait]
impl ChatModel for BedrockChatModel {
    /// Send one non-streaming request through the Bedrock Converse API.
    ///
    /// # Errors
    ///
    /// Returns `SynapticError::Model` if the Converse call fails. If the response
    /// carries no message output, an empty AI message is returned instead of an error.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        let (system_blocks, messages) = convert_messages(&request.messages);

        let mut converse = self
            .client
            .converse()
            .model_id(&self.config.model_id);

        // Add system prompts.
        for block in system_blocks {
            converse = converse.system(block);
        }

        // Add messages.
        for msg in messages {
            converse = converse.messages(msg);
        }

        // Add inference config (omitted when no sampling option is configured).
        if let Some(inference_config) = self.build_inference_config() {
            converse = converse.inference_config(inference_config);
        }

        // Add tool config (omitted when the request declares no tools).
        if let Some(tool_config) = self.build_tool_config(&request) {
            converse = converse.tool_config(tool_config);
        }

        let output = converse
            .send()
            .await
            .map_err(|e| SynapticError::Model(format!("Bedrock Converse API error: {e}")))?;

        // Parse usage. Counts are converted with `as u32`; this assumes the SDK
        // never reports negative token counts.
        let usage = output.usage().map(|u| TokenUsage {
            input_tokens: u.input_tokens() as u32,
            output_tokens: u.output_tokens() as u32,
            total_tokens: u.total_tokens() as u32,
            input_details: None,
            output_details: None,
        });

        // Parse the output message; any non-message output degrades to an empty AI reply.
        let message = match output.output() {
            Some(bedrock_types::ConverseOutput::Message(msg)) => {
                parse_bedrock_message(msg)
            }
            _ => Message::ai(""),
        };

        Ok(ChatResponse { message, usage })
    }

    /// Stream a response through the Bedrock ConverseStream API.
    ///
    /// Yields, per stream event:
    /// - text deltas as `AIMessageChunk::content`;
    /// - a tool-use start as a `ToolCallChunk` carrying id and name;
    /// - tool input deltas as `ToolCallChunk`s carrying the raw argument fragment;
    /// - on content-block stop after a tool use, one complete `ToolCall` whose
    ///   arguments are the accumulated input parsed as JSON;
    /// - metadata usage as `AIMessageChunk::usage`.
    ///
    /// A failed `send()` or a stream receive error is yielded as
    /// `SynapticError::Model` and ends the stream.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let (system_blocks, messages) = convert_messages(&request.messages);

            let mut converse_stream = self
                .client
                .converse_stream()
                .model_id(&self.config.model_id);

            for block in system_blocks {
                converse_stream = converse_stream.system(block);
            }

            for msg in messages {
                converse_stream = converse_stream.messages(msg);
            }

            if let Some(inference_config) = self.build_inference_config() {
                converse_stream = converse_stream.inference_config(inference_config);
            }

            if let Some(tool_config) = self.build_tool_config(&request) {
                converse_stream = converse_stream.tool_config(tool_config);
            }

            let output = match converse_stream.send().await {
                Ok(o) => o,
                Err(e) => {
                    yield Err(SynapticError::Model(format!(
                        "Bedrock ConverseStream API error: {e}"
                    )));
                    return;
                }
            };

            let mut stream = output.stream;

            // Track current tool use blocks being built during streaming.
            // Only one tool block is tracked at a time; this assumes content blocks
            // are not interleaved in the stream — TODO confirm.
            let mut current_tool_id: Option<String> = None;
            let mut current_tool_name: Option<String> = None;
            let mut current_tool_input: String = String::new();

            loop {
                match stream.recv().await {
                    Ok(Some(event)) => {
                        match event {
                            bedrock_types::ConverseStreamOutput::ContentBlockStart(start_event) => {
                                if let Some(bedrock_types::ContentBlockStart::ToolUse(tool_start)) = start_event.start() {
                                    current_tool_id = Some(tool_start.tool_use_id().to_string());
                                    current_tool_name = Some(tool_start.name().to_string());
                                    current_tool_input.clear();

                                    yield Ok(AIMessageChunk {
                                        tool_call_chunks: vec![ToolCallChunk {
                                            id: Some(tool_start.tool_use_id().to_string()),
                                            name: Some(tool_start.name().to_string()),
                                            arguments: None,
                                            index: Some(start_event.content_block_index() as usize),
                                        }],
                                        ..Default::default()
                                    });
                                }
                            }
                            bedrock_types::ConverseStreamOutput::ContentBlockDelta(delta_event) => {
                                if let Some(delta) = delta_event.delta() {
                                    match delta {
                                        bedrock_types::ContentBlockDelta::Text(text) => {
                                            yield Ok(AIMessageChunk {
                                                content: text.to_string(),
                                                ..Default::default()
                                            });
                                        }
                                        bedrock_types::ContentBlockDelta::ToolUse(tool_delta) => {
                                            // Keep the full input so ContentBlockStop can parse it once.
                                            let input_fragment = tool_delta.input();
                                            current_tool_input.push_str(input_fragment);

                                            // NOTE(review): every argument chunk repeats id and name.
                                            // If synaptic_core merges chunks by concatenating these
                                            // fields, they would be duplicated — confirm merge semantics.
                                            yield Ok(AIMessageChunk {
                                                tool_call_chunks: vec![ToolCallChunk {
                                                    id: current_tool_id.clone(),
                                                    name: current_tool_name.clone(),
                                                    arguments: Some(input_fragment.to_string()),
                                                    index: Some(delta_event.content_block_index() as usize),
                                                }],
                                                ..Default::default()
                                            });
                                        }
                                        _ => { /* ignore other delta types */ }
                                    }
                                }
                            }
                            bedrock_types::ConverseStreamOutput::ContentBlockStop(_) => {
                                // If we were accumulating a tool call, emit the complete ToolCall.
                                // NOTE(review): this is emitted in addition to the streamed
                                // tool_call_chunks; confirm consumers do not count the call twice.
                                if let (Some(id), Some(name)) = (current_tool_id.take(), current_tool_name.take()) {
                                    // Input that is empty or not valid JSON silently becomes `{}`.
                                    let arguments: Value = serde_json::from_str(&current_tool_input)
                                        .unwrap_or(Value::Object(Default::default()));
                                    current_tool_input.clear();

                                    yield Ok(AIMessageChunk {
                                        tool_calls: vec![ToolCall {
                                            id,
                                            name,
                                            arguments,
                                        }],
                                        ..Default::default()
                                    });
                                }
                            }
                            bedrock_types::ConverseStreamOutput::Metadata(meta) => {
                                // Same `as u32` conversion as `chat`; assumes non-negative counts.
                                if let Some(u) = meta.usage() {
                                    yield Ok(AIMessageChunk {
                                        usage: Some(TokenUsage {
                                            input_tokens: u.input_tokens() as u32,
                                            output_tokens: u.output_tokens() as u32,
                                            total_tokens: u.total_tokens() as u32,
                                            input_details: None,
                                            output_details: None,
                                        }),
                                        ..Default::default()
                                    });
                                }
                            }
                            _ => { /* MessageStart, MessageStop, Unknown — skip */ }
                        }
                    }
                    // End of the event stream.
                    Ok(None) => break,
                    Err(e) => {
                        yield Err(SynapticError::Model(format!(
                            "Bedrock stream error: {e}"
                        )));
                        break;
                    }
                }
            }
        })
    }
}
403
404// ---------------------------------------------------------------------------
405// Message conversion helpers
406// ---------------------------------------------------------------------------
407
408/// Convert Synaptic messages into Bedrock system blocks and conversation messages.
409///
410/// System messages are extracted into `SystemContentBlock` entries.
411/// Human, AI, and Tool messages are mapped to Bedrock `Message` types.
412fn convert_messages(
413    messages: &[Message],
414) -> (Vec<SystemContentBlock>, Vec<bedrock_types::Message>) {
415    let mut system_blocks = Vec::new();
416    let mut bedrock_messages: Vec<bedrock_types::Message> = Vec::new();
417
418    for msg in messages {
419        match msg {
420            Message::System { content, .. } => {
421                system_blocks.push(SystemContentBlock::Text(content.clone()));
422            }
423            Message::Human { content, .. } => {
424                let bedrock_msg = bedrock_types::Message::builder()
425                    .role(ConversationRole::User)
426                    .content(ContentBlock::Text(content.clone()))
427                    .build()
428                    .expect("message build should not fail");
429                bedrock_messages.push(bedrock_msg);
430            }
431            Message::AI {
432                content,
433                tool_calls,
434                ..
435            } => {
436                let mut blocks: Vec<ContentBlock> = Vec::new();
437
438                if !content.is_empty() {
439                    blocks.push(ContentBlock::Text(content.clone()));
440                }
441
442                for tc in tool_calls {
443                    let tool_use = ToolUseBlock::builder()
444                        .tool_use_id(&tc.id)
445                        .name(&tc.name)
446                        .input(json_value_to_document(&tc.arguments))
447                        .build()
448                        .expect("tool use block build should not fail");
449                    blocks.push(ContentBlock::ToolUse(tool_use));
450                }
451
452                // Bedrock requires at least one content block.
453                if blocks.is_empty() {
454                    blocks.push(ContentBlock::Text(String::new()));
455                }
456
457                let bedrock_msg = bedrock_types::Message::builder()
458                    .role(ConversationRole::Assistant)
459                    .set_content(Some(blocks))
460                    .build()
461                    .expect("message build should not fail");
462                bedrock_messages.push(bedrock_msg);
463            }
464            Message::Tool {
465                content,
466                tool_call_id,
467                ..
468            } => {
469                let tool_result = ToolResultBlock::builder()
470                    .tool_use_id(tool_call_id)
471                    .content(ToolResultContentBlock::Text(content.clone()))
472                    .build()
473                    .expect("tool result block build should not fail");
474
475                let bedrock_msg = bedrock_types::Message::builder()
476                    .role(ConversationRole::User)
477                    .content(ContentBlock::ToolResult(tool_result))
478                    .build()
479                    .expect("message build should not fail");
480                bedrock_messages.push(bedrock_msg);
481            }
482            Message::Chat { content, .. } => {
483                // Map custom roles to user by default.
484                let bedrock_msg = bedrock_types::Message::builder()
485                    .role(ConversationRole::User)
486                    .content(ContentBlock::Text(content.clone()))
487                    .build()
488                    .expect("message build should not fail");
489                bedrock_messages.push(bedrock_msg);
490            }
491            Message::Remove { .. } => { /* Skip remove messages */ }
492        }
493    }
494
495    (system_blocks, bedrock_messages)
496}
497
498/// Parse a Bedrock response message into a Synaptic `Message`.
499fn parse_bedrock_message(msg: &bedrock_types::Message) -> Message {
500    let mut text_parts: Vec<String> = Vec::new();
501    let mut tool_calls: Vec<ToolCall> = Vec::new();
502
503    for block in msg.content() {
504        match block {
505            ContentBlock::Text(text) => {
506                text_parts.push(text.clone());
507            }
508            ContentBlock::ToolUse(tool_use) => {
509                tool_calls.push(ToolCall {
510                    id: tool_use.tool_use_id().to_string(),
511                    name: tool_use.name().to_string(),
512                    arguments: document_to_json_value(tool_use.input()),
513                });
514            }
515            _ => { /* Ignore other content block types for now */ }
516        }
517    }
518
519    let content = text_parts.join("");
520
521    if tool_calls.is_empty() {
522        Message::ai(content)
523    } else {
524        Message::ai_with_tool_calls(content, tool_calls)
525    }
526}
527
528// ---------------------------------------------------------------------------
529// Document <-> serde_json::Value conversion
530// ---------------------------------------------------------------------------
531
532/// Convert a `serde_json::Value` to an `aws_smithy_types::Document`.
533pub(crate) fn json_value_to_document(value: &Value) -> SmithyDocument {
534    match value {
535        Value::Null => SmithyDocument::Null,
536        Value::Bool(b) => SmithyDocument::Bool(*b),
537        Value::Number(n) => {
538            if let Some(i) = n.as_i64() {
539                SmithyDocument::Number(aws_smithy_types::Number::NegInt(i))
540            } else if let Some(u) = n.as_u64() {
541                SmithyDocument::Number(aws_smithy_types::Number::PosInt(u))
542            } else if let Some(f) = n.as_f64() {
543                SmithyDocument::Number(aws_smithy_types::Number::Float(f))
544            } else {
545                SmithyDocument::Null
546            }
547        }
548        Value::String(s) => SmithyDocument::String(s.clone()),
549        Value::Array(arr) => {
550            SmithyDocument::Array(arr.iter().map(json_value_to_document).collect())
551        }
552        Value::Object(obj) => {
553            let map: HashMap<String, SmithyDocument> = obj
554                .iter()
555                .map(|(k, v)| (k.clone(), json_value_to_document(v)))
556                .collect();
557            SmithyDocument::Object(map)
558        }
559    }
560}
561
562/// Convert an `aws_smithy_types::Document` to a `serde_json::Value`.
563pub(crate) fn document_to_json_value(doc: &SmithyDocument) -> Value {
564    match doc {
565        SmithyDocument::Null => Value::Null,
566        SmithyDocument::Bool(b) => Value::Bool(*b),
567        SmithyDocument::Number(n) => match *n {
568            aws_smithy_types::Number::PosInt(u) => {
569                serde_json::json!(u)
570            }
571            aws_smithy_types::Number::NegInt(i) => {
572                serde_json::json!(i)
573            }
574            aws_smithy_types::Number::Float(f) => {
575                serde_json::json!(f)
576            }
577        },
578        SmithyDocument::String(s) => Value::String(s.clone()),
579        SmithyDocument::Array(arr) => {
580            Value::Array(arr.iter().map(document_to_json_value).collect())
581        }
582        SmithyDocument::Object(obj) => {
583            let map: serde_json::Map<String, Value> = obj
584                .iter()
585                .map(|(k, v)| (k.clone(), document_to_json_value(v)))
586                .collect();
587            Value::Object(map)
588        }
589    }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn json_value_to_document_round_trip() {
        // A JSON-schema-shaped value exercises nested objects, arrays, and strings.
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        let round_tripped = document_to_json_value(&json_value_to_document(&schema));
        assert_eq!(schema, round_tripped);
    }

    #[test]
    fn json_value_to_document_primitives() {
        let null_doc = json_value_to_document(&Value::Null);
        let bool_doc = json_value_to_document(&Value::Bool(true));
        let string_doc = json_value_to_document(&serde_json::json!("hello"));

        assert!(matches!(null_doc, SmithyDocument::Null));
        assert!(matches!(bool_doc, SmithyDocument::Bool(true)));
        assert!(matches!(string_doc, SmithyDocument::String(_)));
    }

    #[test]
    fn convert_system_messages() {
        let (system_blocks, bedrock_messages) = convert_messages(&[
            Message::system("You are a helpful assistant."),
            Message::human("Hello!"),
        ]);

        assert_eq!(system_blocks.len(), 1);
        assert_eq!(bedrock_messages.len(), 1);
    }

    #[test]
    fn convert_tool_messages() {
        let weather_call = ToolCall {
            id: "tc_1".to_string(),
            name: "get_weather".to_string(),
            arguments: serde_json::json!({"city": "NYC"}),
        };
        let history = vec![
            Message::human("What is the weather?"),
            Message::ai_with_tool_calls("", vec![weather_call]),
            Message::tool("Sunny, 72F", "tc_1"),
        ];

        let (system_blocks, bedrock_messages) = convert_messages(&history);
        assert!(system_blocks.is_empty());

        // Expected turns: user question -> assistant tool use -> user tool result.
        let roles: Vec<ConversationRole> = bedrock_messages
            .iter()
            .map(|m| m.role().clone())
            .collect();
        assert_eq!(
            roles,
            vec![
                ConversationRole::User,
                ConversationRole::Assistant,
                ConversationRole::User,
            ]
        );
    }

    #[test]
    fn convert_remove_messages_are_skipped() {
        let (_, bedrock_messages) = convert_messages(&[
            Message::human("Hi"),
            Message::remove("some-id"),
            Message::ai("Hello!"),
        ]);

        assert_eq!(bedrock_messages.len(), 2);
    }

    #[test]
    fn parse_text_only_message() {
        let reply = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Hello world".to_string()))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&reply);

        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Hello world");
        assert!(parsed.tool_calls().is_empty());
    }

    #[test]
    fn parse_message_with_tool_use() {
        let expected_args = serde_json::json!({"expr": "1+1"});
        let tool_use = ToolUseBlock::builder()
            .tool_use_id("tc_1")
            .name("calculator")
            .input(json_value_to_document(&expected_args))
            .build()
            .unwrap();

        let reply = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Let me calculate.".to_string()))
            .content(ContentBlock::ToolUse(tool_use))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&reply);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Let me calculate.");

        let calls = parsed.tool_calls();
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, "tc_1");
        assert_eq!(call.name, "calculator");
        assert_eq!(call.arguments, expected_args);
    }
}