Skip to main content

synaptic_bedrock/
chat_model.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use aws_sdk_bedrockruntime::types::{
5    self as bedrock_types, ContentBlock, ConversationRole, InferenceConfiguration,
6    SystemContentBlock, ToolConfiguration, ToolInputSchema, ToolResultBlock,
7    ToolResultContentBlock, ToolSpecification, ToolUseBlock,
8};
9use aws_smithy_types::Document as SmithyDocument;
10use serde_json::Value;
11use synaptic_core::{
12    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
13    TokenUsage, ToolCall, ToolCallChunk, ToolChoice,
14};
15
16// ---------------------------------------------------------------------------
17// Configuration
18// ---------------------------------------------------------------------------
19
/// Settings for [`BedrockChatModel`]: which Bedrock model to call and how to sample from it.
///
/// Start with [`BedrockConfig::new`] and chain the `with_*` setters. Every
/// optional field begins as `None`.
#[derive(Debug, Clone)]
pub struct BedrockConfig {
    /// Bedrock model identifier, e.g. `"anthropic.claude-3-5-sonnet-20241022-v2:0"`.
    pub model_id: String,
    /// Explicit AWS region. When `None`, [`BedrockChatModel::new`] resolves the
    /// region from the ambient AWS configuration (for example the `AWS_REGION`
    /// environment variable); see that constructor's documentation.
    pub region: Option<String>,
    /// Upper bound on the number of generated tokens.
    pub max_tokens: Option<i32>,
    /// Sampling temperature (0.0 - 1.0).
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling cutoff.
    pub top_p: Option<f32>,
    /// Sequences that stop generation when the model emits them.
    pub stop: Option<Vec<String>>,
}

impl BedrockConfig {
    /// Start a configuration for `model_id` with every optional setting unset.
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id = model_id.into();
        Self {
            model_id,
            region: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
        }
    }

    /// Return a copy of this configuration pinned to `region`.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Return a copy of this configuration with an output-token limit.
    pub fn with_max_tokens(self, max_tokens: i32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Return a copy of this configuration with the given sampling temperature.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Return a copy of this configuration with the given top-p cutoff.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Return a copy of this configuration with the given stop sequences.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
80
81// ---------------------------------------------------------------------------
82// BedrockChatModel
83// ---------------------------------------------------------------------------
84
85/// A [`ChatModel`] implementation backed by AWS Bedrock's Converse API.
86///
87/// Supports both synchronous and streaming responses, tool calling,
88/// and all Bedrock-supported foundation models.
89pub struct BedrockChatModel {
90    config: BedrockConfig,
91    client: aws_sdk_bedrockruntime::Client,
92}
93
94impl BedrockChatModel {
95    /// Create a new `BedrockChatModel` by loading AWS configuration from the
96    /// environment. Respects `AWS_REGION`, `AWS_ACCESS_KEY_ID`,
97    /// `AWS_SECRET_ACCESS_KEY`, and other standard AWS SDK environment variables.
98    pub async fn new(config: BedrockConfig) -> Self {
99        let mut aws_config_loader = aws_config::from_env();
100
101        if let Some(ref region) = config.region {
102            aws_config_loader = aws_config_loader.region(aws_config::Region::new(region.clone()));
103        }
104
105        let aws_config = aws_config_loader.load().await;
106        let client = aws_sdk_bedrockruntime::Client::new(&aws_config);
107
108        Self { config, client }
109    }
110
111    /// Create a new `BedrockChatModel` with a pre-existing AWS SDK client.
112    pub fn from_client(config: BedrockConfig, client: aws_sdk_bedrockruntime::Client) -> Self {
113        Self { config, client }
114    }
115
116    /// Build the inference configuration from our config.
117    fn build_inference_config(&self) -> Option<InferenceConfiguration> {
118        let has_any = self.config.max_tokens.is_some()
119            || self.config.temperature.is_some()
120            || self.config.top_p.is_some()
121            || self.config.stop.is_some();
122
123        if !has_any {
124            return None;
125        }
126
127        let mut builder = InferenceConfiguration::builder();
128
129        if let Some(max_tokens) = self.config.max_tokens {
130            builder = builder.max_tokens(max_tokens);
131        }
132        if let Some(temperature) = self.config.temperature {
133            builder = builder.temperature(temperature);
134        }
135        if let Some(top_p) = self.config.top_p {
136            builder = builder.top_p(top_p);
137        }
138        if let Some(ref stop) = self.config.stop {
139            for s in stop {
140                builder = builder.stop_sequences(s.clone());
141            }
142        }
143
144        Some(builder.build())
145    }
146
147    /// Build the tool configuration from a ChatRequest.
148    fn build_tool_config(&self, request: &ChatRequest) -> Option<ToolConfiguration> {
149        if request.tools.is_empty() {
150            return None;
151        }
152
153        let tools: Vec<bedrock_types::Tool> = request
154            .tools
155            .iter()
156            .map(|td| {
157                let spec = ToolSpecification::builder()
158                    .name(&td.name)
159                    .description(&td.description)
160                    .input_schema(ToolInputSchema::Json(json_value_to_document(
161                        &td.parameters,
162                    )))
163                    .build()
164                    .expect("tool specification build should not fail");
165
166                bedrock_types::Tool::ToolSpec(spec)
167            })
168            .collect();
169
170        let mut builder = ToolConfiguration::builder();
171        for tool in tools {
172            builder = builder.tools(tool);
173        }
174
175        if let Some(ref choice) = request.tool_choice {
176            let bedrock_choice = match choice {
177                ToolChoice::Auto => bedrock_types::ToolChoice::Auto(
178                    bedrock_types::AutoToolChoice::builder().build(),
179                ),
180                ToolChoice::Required => {
181                    bedrock_types::ToolChoice::Any(bedrock_types::AnyToolChoice::builder().build())
182                }
183                ToolChoice::None => {
184                    // Bedrock does not have a "none" tool choice. We omit tools instead,
185                    // but since we already built them, just default to Auto.
186                    bedrock_types::ToolChoice::Auto(
187                        bedrock_types::AutoToolChoice::builder().build(),
188                    )
189                }
190                ToolChoice::Specific(name) => bedrock_types::ToolChoice::Tool(
191                    bedrock_types::SpecificToolChoice::builder()
192                        .name(name)
193                        .build()
194                        .expect("specific tool choice build should not fail"),
195                ),
196            };
197            builder = builder.tool_choice(bedrock_choice);
198        }
199
200        Some(
201            builder
202                .build()
203                .expect("tool configuration build should not fail"),
204        )
205    }
206}
207
208#[async_trait]
209impl ChatModel for BedrockChatModel {
210    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
211        let (system_blocks, messages) = convert_messages(&request.messages);
212
213        let mut converse = self.client.converse().model_id(&self.config.model_id);
214
215        // Add system prompts.
216        for block in system_blocks {
217            converse = converse.system(block);
218        }
219
220        // Add messages.
221        for msg in messages {
222            converse = converse.messages(msg);
223        }
224
225        // Add inference config.
226        if let Some(inference_config) = self.build_inference_config() {
227            converse = converse.inference_config(inference_config);
228        }
229
230        // Add tool config.
231        if let Some(tool_config) = self.build_tool_config(&request) {
232            converse = converse.tool_config(tool_config);
233        }
234
235        let output = converse
236            .send()
237            .await
238            .map_err(|e| SynapticError::Model(format!("Bedrock Converse API error: {e}")))?;
239
240        // Parse usage.
241        let usage = output.usage().map(|u| TokenUsage {
242            input_tokens: u.input_tokens() as u32,
243            output_tokens: u.output_tokens() as u32,
244            total_tokens: u.total_tokens() as u32,
245            input_details: None,
246            output_details: None,
247        });
248
249        // Parse the output message.
250        let message = match output.output() {
251            Some(bedrock_types::ConverseOutput::Message(msg)) => parse_bedrock_message(msg),
252            _ => Message::ai(""),
253        };
254
255        Ok(ChatResponse { message, usage })
256    }
257
258    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
259        Box::pin(async_stream::stream! {
260            let (system_blocks, messages) = convert_messages(&request.messages);
261
262            let mut converse_stream = self
263                .client
264                .converse_stream()
265                .model_id(&self.config.model_id);
266
267            for block in system_blocks {
268                converse_stream = converse_stream.system(block);
269            }
270
271            for msg in messages {
272                converse_stream = converse_stream.messages(msg);
273            }
274
275            if let Some(inference_config) = self.build_inference_config() {
276                converse_stream = converse_stream.inference_config(inference_config);
277            }
278
279            if let Some(tool_config) = self.build_tool_config(&request) {
280                converse_stream = converse_stream.tool_config(tool_config);
281            }
282
283            let output = match converse_stream.send().await {
284                Ok(o) => o,
285                Err(e) => {
286                    yield Err(SynapticError::Model(format!(
287                        "Bedrock ConverseStream API error: {e}"
288                    )));
289                    return;
290                }
291            };
292
293            let mut stream = output.stream;
294
295            // Track current tool use blocks being built during streaming.
296            let mut current_tool_id: Option<String> = None;
297            let mut current_tool_name: Option<String> = None;
298            let mut current_tool_input: String = String::new();
299
300            loop {
301                match stream.recv().await {
302                    Ok(Some(event)) => {
303                        match event {
304                            bedrock_types::ConverseStreamOutput::ContentBlockStart(start_event) => {
305                                if let Some(bedrock_types::ContentBlockStart::ToolUse(tool_start)) = start_event.start() {
306                                    current_tool_id = Some(tool_start.tool_use_id().to_string());
307                                    current_tool_name = Some(tool_start.name().to_string());
308                                    current_tool_input.clear();
309
310                                    yield Ok(AIMessageChunk {
311                                        tool_call_chunks: vec![ToolCallChunk {
312                                            id: Some(tool_start.tool_use_id().to_string()),
313                                            name: Some(tool_start.name().to_string()),
314                                            arguments: None,
315                                            index: Some(start_event.content_block_index() as usize),
316                                        }],
317                                        ..Default::default()
318                                    });
319                                }
320                            }
321                            bedrock_types::ConverseStreamOutput::ContentBlockDelta(delta_event) => {
322                                if let Some(delta) = delta_event.delta() {
323                                    match delta {
324                                        bedrock_types::ContentBlockDelta::Text(text) => {
325                                            yield Ok(AIMessageChunk {
326                                                content: text.to_string(),
327                                                ..Default::default()
328                                            });
329                                        }
330                                        bedrock_types::ContentBlockDelta::ToolUse(tool_delta) => {
331                                            let input_fragment = tool_delta.input();
332                                            current_tool_input.push_str(input_fragment);
333
334                                            yield Ok(AIMessageChunk {
335                                                tool_call_chunks: vec![ToolCallChunk {
336                                                    id: current_tool_id.clone(),
337                                                    name: current_tool_name.clone(),
338                                                    arguments: Some(input_fragment.to_string()),
339                                                    index: Some(delta_event.content_block_index() as usize),
340                                                }],
341                                                ..Default::default()
342                                            });
343                                        }
344                                        _ => { /* ignore other delta types */ }
345                                    }
346                                }
347                            }
348                            bedrock_types::ConverseStreamOutput::ContentBlockStop(_) => {
349                                // If we were accumulating a tool call, emit the complete ToolCall.
350                                if let (Some(id), Some(name)) = (current_tool_id.take(), current_tool_name.take()) {
351                                    let arguments: Value = serde_json::from_str(&current_tool_input)
352                                        .unwrap_or(Value::Object(Default::default()));
353                                    current_tool_input.clear();
354
355                                    yield Ok(AIMessageChunk {
356                                        tool_calls: vec![ToolCall {
357                                            id,
358                                            name,
359                                            arguments,
360                                        }],
361                                        ..Default::default()
362                                    });
363                                }
364                            }
365                            bedrock_types::ConverseStreamOutput::Metadata(meta) => {
366                                if let Some(u) = meta.usage() {
367                                    yield Ok(AIMessageChunk {
368                                        usage: Some(TokenUsage {
369                                            input_tokens: u.input_tokens() as u32,
370                                            output_tokens: u.output_tokens() as u32,
371                                            total_tokens: u.total_tokens() as u32,
372                                            input_details: None,
373                                            output_details: None,
374                                        }),
375                                        ..Default::default()
376                                    });
377                                }
378                            }
379                            _ => { /* MessageStart, MessageStop, Unknown — skip */ }
380                        }
381                    }
382                    Ok(None) => break,
383                    Err(e) => {
384                        yield Err(SynapticError::Model(format!(
385                            "Bedrock stream error: {e}"
386                        )));
387                        break;
388                    }
389                }
390            }
391        })
392    }
393}
394
395// ---------------------------------------------------------------------------
396// Message conversion helpers
397// ---------------------------------------------------------------------------
398
399/// Convert Synaptic messages into Bedrock system blocks and conversation messages.
400///
401/// System messages are extracted into `SystemContentBlock` entries.
402/// Human, AI, and Tool messages are mapped to Bedrock `Message` types.
403fn convert_messages(
404    messages: &[Message],
405) -> (Vec<SystemContentBlock>, Vec<bedrock_types::Message>) {
406    let mut system_blocks = Vec::new();
407    let mut bedrock_messages: Vec<bedrock_types::Message> = Vec::new();
408
409    for msg in messages {
410        match msg {
411            Message::System { content, .. } => {
412                system_blocks.push(SystemContentBlock::Text(content.clone()));
413            }
414            Message::Human { content, .. } => {
415                let bedrock_msg = bedrock_types::Message::builder()
416                    .role(ConversationRole::User)
417                    .content(ContentBlock::Text(content.clone()))
418                    .build()
419                    .expect("message build should not fail");
420                bedrock_messages.push(bedrock_msg);
421            }
422            Message::AI {
423                content,
424                tool_calls,
425                ..
426            } => {
427                let mut blocks: Vec<ContentBlock> = Vec::new();
428
429                if !content.is_empty() {
430                    blocks.push(ContentBlock::Text(content.clone()));
431                }
432
433                for tc in tool_calls {
434                    let tool_use = ToolUseBlock::builder()
435                        .tool_use_id(&tc.id)
436                        .name(&tc.name)
437                        .input(json_value_to_document(&tc.arguments))
438                        .build()
439                        .expect("tool use block build should not fail");
440                    blocks.push(ContentBlock::ToolUse(tool_use));
441                }
442
443                // Bedrock requires at least one content block.
444                if blocks.is_empty() {
445                    blocks.push(ContentBlock::Text(String::new()));
446                }
447
448                let bedrock_msg = bedrock_types::Message::builder()
449                    .role(ConversationRole::Assistant)
450                    .set_content(Some(blocks))
451                    .build()
452                    .expect("message build should not fail");
453                bedrock_messages.push(bedrock_msg);
454            }
455            Message::Tool {
456                content,
457                tool_call_id,
458                ..
459            } => {
460                let tool_result = ToolResultBlock::builder()
461                    .tool_use_id(tool_call_id)
462                    .content(ToolResultContentBlock::Text(content.clone()))
463                    .build()
464                    .expect("tool result block build should not fail");
465
466                let bedrock_msg = bedrock_types::Message::builder()
467                    .role(ConversationRole::User)
468                    .content(ContentBlock::ToolResult(tool_result))
469                    .build()
470                    .expect("message build should not fail");
471                bedrock_messages.push(bedrock_msg);
472            }
473            Message::Chat { content, .. } => {
474                // Map custom roles to user by default.
475                let bedrock_msg = bedrock_types::Message::builder()
476                    .role(ConversationRole::User)
477                    .content(ContentBlock::Text(content.clone()))
478                    .build()
479                    .expect("message build should not fail");
480                bedrock_messages.push(bedrock_msg);
481            }
482            Message::Remove { .. } => { /* Skip remove messages */ }
483        }
484    }
485
486    (system_blocks, bedrock_messages)
487}
488
489/// Parse a Bedrock response message into a Synaptic `Message`.
490fn parse_bedrock_message(msg: &bedrock_types::Message) -> Message {
491    let mut text_parts: Vec<String> = Vec::new();
492    let mut tool_calls: Vec<ToolCall> = Vec::new();
493
494    for block in msg.content() {
495        match block {
496            ContentBlock::Text(text) => {
497                text_parts.push(text.clone());
498            }
499            ContentBlock::ToolUse(tool_use) => {
500                tool_calls.push(ToolCall {
501                    id: tool_use.tool_use_id().to_string(),
502                    name: tool_use.name().to_string(),
503                    arguments: document_to_json_value(tool_use.input()),
504                });
505            }
506            _ => { /* Ignore other content block types for now */ }
507        }
508    }
509
510    let content = text_parts.join("");
511
512    if tool_calls.is_empty() {
513        Message::ai(content)
514    } else {
515        Message::ai_with_tool_calls(content, tool_calls)
516    }
517}
518
519// ---------------------------------------------------------------------------
520// Document <-> serde_json::Value conversion
521// ---------------------------------------------------------------------------
522
523/// Convert a `serde_json::Value` to an `aws_smithy_types::Document`.
524pub(crate) fn json_value_to_document(value: &Value) -> SmithyDocument {
525    match value {
526        Value::Null => SmithyDocument::Null,
527        Value::Bool(b) => SmithyDocument::Bool(*b),
528        Value::Number(n) => {
529            if let Some(i) = n.as_i64() {
530                SmithyDocument::Number(aws_smithy_types::Number::NegInt(i))
531            } else if let Some(u) = n.as_u64() {
532                SmithyDocument::Number(aws_smithy_types::Number::PosInt(u))
533            } else if let Some(f) = n.as_f64() {
534                SmithyDocument::Number(aws_smithy_types::Number::Float(f))
535            } else {
536                SmithyDocument::Null
537            }
538        }
539        Value::String(s) => SmithyDocument::String(s.clone()),
540        Value::Array(arr) => {
541            SmithyDocument::Array(arr.iter().map(json_value_to_document).collect())
542        }
543        Value::Object(obj) => {
544            let map: HashMap<String, SmithyDocument> = obj
545                .iter()
546                .map(|(k, v)| (k.clone(), json_value_to_document(v)))
547                .collect();
548            SmithyDocument::Object(map)
549        }
550    }
551}
552
553/// Convert an `aws_smithy_types::Document` to a `serde_json::Value`.
554pub(crate) fn document_to_json_value(doc: &SmithyDocument) -> Value {
555    match doc {
556        SmithyDocument::Null => Value::Null,
557        SmithyDocument::Bool(b) => Value::Bool(*b),
558        SmithyDocument::Number(n) => match *n {
559            aws_smithy_types::Number::PosInt(u) => {
560                serde_json::json!(u)
561            }
562            aws_smithy_types::Number::NegInt(i) => {
563                serde_json::json!(i)
564            }
565            aws_smithy_types::Number::Float(f) => {
566                serde_json::json!(f)
567            }
568        },
569        SmithyDocument::String(s) => Value::String(s.clone()),
570        SmithyDocument::Array(arr) => {
571            Value::Array(arr.iter().map(document_to_json_value).collect())
572        }
573        SmithyDocument::Object(obj) => {
574            let map: serde_json::Map<String, Value> = obj
575                .iter()
576                .map(|(k, v)| (k.clone(), document_to_json_value(v)))
577                .collect();
578            Value::Object(map)
579        }
580    }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    /// A nested JSON Schema survives a trip through `Document` unchanged.
    #[test]
    fn json_value_to_document_round_trip() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        let round_tripped = document_to_json_value(&json_value_to_document(&schema));
        assert_eq!(round_tripped, schema);
    }

    /// Scalars map onto the matching `Document` variants.
    #[test]
    fn json_value_to_document_primitives() {
        let null_doc = json_value_to_document(&Value::Null);
        assert!(matches!(null_doc, SmithyDocument::Null));

        let bool_doc = json_value_to_document(&Value::Bool(true));
        assert!(matches!(bool_doc, SmithyDocument::Bool(true)));

        let string_doc = json_value_to_document(&Value::String("hello".to_owned()));
        assert!(matches!(string_doc, SmithyDocument::String(_)));
    }

    /// System prompts are split out of the conversation.
    #[test]
    fn convert_system_messages() {
        let (system, conversation) = convert_messages(&[
            Message::system("You are a helpful assistant."),
            Message::human("Hello!"),
        ]);

        assert_eq!(system.len(), 1);
        assert_eq!(conversation.len(), 1);
    }

    /// A tool round-trip maps to user → assistant → user.
    #[test]
    fn convert_tool_messages() {
        let call = ToolCall {
            id: "tc_1".to_string(),
            name: "get_weather".to_string(),
            arguments: serde_json::json!({"city": "NYC"}),
        };
        let history = vec![
            Message::human("What is the weather?"),
            Message::ai_with_tool_calls("", vec![call]),
            Message::tool("Sunny, 72F", "tc_1"),
        ];

        let (system, conversation) = convert_messages(&history);
        assert!(system.is_empty());

        // Question, tool use, then the tool result sent back as a user turn.
        let roles: Vec<&ConversationRole> = conversation.iter().map(|m| m.role()).collect();
        assert_eq!(
            roles,
            [
                &ConversationRole::User,
                &ConversationRole::Assistant,
                &ConversationRole::User,
            ]
        );
    }

    /// `Message::Remove` markers never reach Bedrock.
    #[test]
    fn convert_remove_messages_are_skipped() {
        let (_, conversation) = convert_messages(&[
            Message::human("Hi"),
            Message::remove("some-id"),
            Message::ai("Hello!"),
        ]);

        assert_eq!(conversation.len(), 2);
    }

    /// A single text block becomes a plain AI message.
    #[test]
    fn parse_text_only_message() {
        let reply = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Hello world".to_string()))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&reply);

        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Hello world");
        assert!(parsed.tool_calls().is_empty());
    }

    /// Text plus a tool-use block yields content and one tool call.
    #[test]
    fn parse_message_with_tool_use() {
        let expected_args = serde_json::json!({"expr": "1+1"});
        let tool_use = ToolUseBlock::builder()
            .tool_use_id("tc_1")
            .name("calculator")
            .input(json_value_to_document(&expected_args))
            .build()
            .unwrap();

        let reply = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Let me calculate.".to_string()))
            .content(ContentBlock::ToolUse(tool_use))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&reply);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Let me calculate.");

        let calls = parsed.tool_calls();
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, "tc_1");
        assert_eq!(call.name, "calculator");
        assert_eq!(call.arguments, expected_args);
    }
}