1use super::types::completion_request::AwsCompletionRequest;
2use super::{completion::CompletionModel, types::errors::AwsSdkConverseStreamError};
3use async_stream::stream;
4use aws_sdk_bedrockruntime::types as aws_bedrock;
5use rig::completion::GetTokenUsage;
6use rig::streaming::StreamingCompletionResponse;
7use rig::{
8 completion::CompletionError,
9 streaming::{RawStreamingChoice, RawStreamingToolCall},
10};
11use serde::{Deserialize, Serialize};
12
13#[derive(Clone, Deserialize, Serialize)]
14pub struct BedrockStreamingResponse {
15 pub usage: Option<BedrockUsage>,
16}
17
18#[derive(Clone, Deserialize, Serialize)]
19pub struct BedrockUsage {
20 pub input_tokens: i32,
21 pub output_tokens: i32,
22 pub total_tokens: i32,
23}
24
25impl GetTokenUsage for BedrockStreamingResponse {
26 fn token_usage(&self) -> Option<rig::completion::Usage> {
27 self.usage.as_ref().map(|u| rig::completion::Usage {
28 input_tokens: u.input_tokens as u64,
29 output_tokens: u.output_tokens as u64,
30 total_tokens: u.total_tokens as u64,
31 })
32 }
33}
34
/// Accumulates one tool-use content block while it is being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Identifier Bedrock assigned to the tool-use block (`tool_use_id`).
    id: String,
    /// Name of the tool the model asked to invoke.
    name: String,
    /// Raw JSON argument text, concatenated from the streamed deltas.
    input_json: String,
}
41
/// Collects the pieces of one reasoning content block until that block stops.
#[derive(Default)]
struct ReasoningState {
    /// Signature attached to the reasoning block, if one was received.
    signature: Option<String>,
    /// Reasoning text concatenated from the streamed text deltas.
    content: String,
}
47
48impl CompletionModel {
49 pub(crate) async fn stream(
50 &self,
51 completion_request: rig::completion::CompletionRequest,
52 ) -> Result<StreamingCompletionResponse<BedrockStreamingResponse>, CompletionError> {
53 let request = AwsCompletionRequest(completion_request);
54
55 let mut converse_builder = self
56 .client
57 .get_inner()
58 .await
59 .converse_stream()
60 .model_id(self.model.as_str());
61
62 let tool_config = request.tools_config()?;
63 let prompt_with_history = request.messages()?;
64 converse_builder = converse_builder
65 .set_additional_model_request_fields(request.additional_params())
66 .set_inference_config(request.inference_config())
67 .set_tool_config(tool_config)
68 .set_system(request.system_prompt())
69 .set_messages(Some(prompt_with_history));
70
71 let response = converse_builder.send().await.map_err(|sdk_error| {
72 Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
73 })?;
74
75 let stream = Box::pin(stream! {
76 let mut current_tool_call: Option<ToolCallState> = None;
77 let mut current_reasoning: Option<ReasoningState> = None;
78 let mut stream = response.stream;
79 while let Ok(Some(output)) = stream.recv().await {
80 match output {
81 aws_bedrock::ConverseStreamOutput::ContentBlockDelta(event) => {
82 let delta = event.delta.ok_or(CompletionError::ProviderError("The delta for a content block is missing".into()))?;
83 match delta {
84 aws_bedrock::ContentBlockDelta::Text(text) => {
85 if current_tool_call.is_none() {
86 yield Ok(RawStreamingChoice::Message(text))
87 }
88 },
89 aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
90 if let Some(ref mut tool_call) = current_tool_call {
91 let delta = tool.input().to_string();
92 tool_call.input_json.push_str(&delta);
93
94 yield Ok(RawStreamingChoice::ToolCallDelta {
96 id: tool_call.id.clone(),
97 delta,
98 });
99 }
100 },
101 aws_bedrock::ContentBlockDelta::ReasoningContent(reasoning) => {
102 match reasoning {
103 aws_bedrock::ReasoningContentBlockDelta::Text(text) => {
104 if current_reasoning.is_none() {
105 current_reasoning = Some(ReasoningState::default());
106 }
107
108 if let Some(ref mut state) = current_reasoning {
109 state.content.push_str(text.as_str());
110 }
111
112 if !text.is_empty() {
113 yield Ok(RawStreamingChoice::ReasoningDelta {
114 reasoning: text.clone(),
115 id: None,
116 })
117 }
118 },
119 aws_bedrock::ReasoningContentBlockDelta::Signature(signature) => {
120 if current_reasoning.is_none() {
121 current_reasoning = Some(ReasoningState::default());
122 }
123
124 if let Some(ref mut state) = current_reasoning {
125 state.signature = Some(signature.clone());
126 }
127 },
128 _ => {}
129 }
130 },
131 _ => {}
132 }
133 },
134 aws_bedrock::ConverseStreamOutput::ContentBlockStart(event) => {
135 match event.start.ok_or(CompletionError::ProviderError("ContentBlockStart has no data".into()))? {
136 aws_bedrock::ContentBlockStart::ToolUse(tool_use) => {
137 current_tool_call = Some(ToolCallState {
138 name: tool_use.name,
139 id: tool_use.tool_use_id,
140 input_json: String::new(),
141 });
142 },
143 _ => yield Err(CompletionError::ProviderError("Stream is empty".into()))
144 }
145 },
146 aws_bedrock::ConverseStreamOutput::ContentBlockStop(_event) => {
147 if let Some(reasoning_state) = current_reasoning.take()
148 && !reasoning_state.content.is_empty() {
149 yield Ok(RawStreamingChoice::Reasoning {
150 reasoning: reasoning_state.content,
151 id: None,
152 signature: reasoning_state.signature,
153 })
154 }
155 },
156 aws_bedrock::ConverseStreamOutput::MessageStop(message_stop_event) => {
157 match message_stop_event.stop_reason {
158 aws_bedrock::StopReason::ToolUse => {
159 if let Some(tool_call) = current_tool_call.take() {
160 let tool_input = if tool_call.input_json.is_empty() {
162 serde_json::json!({})
163 } else {
164 serde_json::from_str(tool_call.input_json.as_str())?
165 };
166 yield Ok(RawStreamingChoice::ToolCall(RawStreamingToolCall::new(tool_call.id, tool_call.name, tool_input)));
167 } else {
168 yield Err(CompletionError::ProviderError("Failed to call tool".into()))
169 }
170 }
171 aws_bedrock::StopReason::MaxTokens => {
172 yield Err(CompletionError::ProviderError("Exceeded max tokens".into()))
173 }
174 _ => {}
175 }
176 },
177 aws_bedrock::ConverseStreamOutput::Metadata(metadata_event) => {
178 if let Some(usage) = metadata_event.usage {
180 yield Ok(RawStreamingChoice::FinalResponse(BedrockStreamingResponse {
181 usage: Some(BedrockUsage {
182 input_tokens: usage.input_tokens,
183 output_tokens: usage.output_tokens,
184 total_tokens: usage.total_tokens,
185 }),
186 }));
187 }
188 },
189 _ => {}
190 }
191 }
192 });
193
194 Ok(StreamingCompletionResponse::stream(stream))
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BedrockUsage` from its three counters.
    fn usage(input_tokens: i32, output_tokens: i32, total_tokens: i32) -> BedrockUsage {
        BedrockUsage {
            input_tokens,
            output_tokens,
            total_tokens,
        }
    }

    /// Returns the counters of a rig `Usage` as a tuple for compact comparisons.
    fn rig_counts(rig_usage: &rig::completion::Usage) -> (u64, u64, u64) {
        (
            rig_usage.input_tokens,
            rig_usage.output_tokens,
            rig_usage.total_tokens,
        )
    }

    /// Builds a `ToolCallState` with an empty argument buffer.
    fn tool_call(name: &str, id: &str) -> ToolCallState {
        ToolCallState {
            name: name.to_string(),
            id: id.to_string(),
            input_json: String::new(),
        }
    }

    #[test]
    fn test_bedrock_usage_creation() {
        let counts = usage(100, 50, 150);

        assert_eq!(
            (counts.input_tokens, counts.output_tokens, counts.total_tokens),
            (100, 50, 150)
        );
    }

    #[test]
    fn test_bedrock_streaming_response_with_usage() {
        let response = BedrockStreamingResponse {
            usage: Some(usage(200, 75, 275)),
        };

        let converted = response.token_usage();
        assert!(converted.is_some());
        assert_eq!(rig_counts(&converted.unwrap()), (200, 75, 275));
    }

    #[test]
    fn test_bedrock_streaming_response_without_usage() {
        assert!(
            BedrockStreamingResponse { usage: None }
                .token_usage()
                .is_none()
        );
    }

    #[test]
    fn test_get_token_usage_trait() {
        let response = BedrockStreamingResponse {
            usage: Some(usage(448, 68, 516)),
        };

        let converted =
            GetTokenUsage::token_usage(&response).expect("Usage should be present");
        assert_eq!(rig_counts(&converted), (448, 68, 516));
    }

    #[test]
    fn test_bedrock_usage_serde() {
        let original = usage(100, 50, 150);

        let json = serde_json::to_string(&original).expect("Should serialize");
        for fragment in [
            "\"input_tokens\":100",
            "\"output_tokens\":50",
            "\"total_tokens\":150",
        ] {
            assert!(json.contains(fragment));
        }

        let round_trip: BedrockUsage = serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(
            (
                round_trip.input_tokens,
                round_trip.output_tokens,
                round_trip.total_tokens
            ),
            (
                original.input_tokens,
                original.output_tokens,
                original.total_tokens
            )
        );
    }

    #[test]
    fn test_bedrock_streaming_response_serde() {
        let response = BedrockStreamingResponse {
            usage: Some(usage(200, 75, 275)),
        };

        let json = serde_json::to_string(&response).expect("Should serialize");
        assert!(json.contains("\"input_tokens\":200"));

        let decoded: BedrockStreamingResponse =
            serde_json::from_str(&json).expect("Should deserialize");
        assert!(matches!(
            decoded.usage,
            Some(BedrockUsage {
                input_tokens: 200,
                output_tokens: 75,
                total_tokens: 275,
            })
        ));
    }

    #[test]
    fn test_reasoning_state_default() {
        let ReasoningState { content, signature } = ReasoningState::default();

        assert!(content.is_empty());
        assert!(signature.is_none());
    }

    #[test]
    fn test_reasoning_state_accumulate_content() {
        let mut state = ReasoningState::default();
        for chunk in ["First chunk", " Second chunk", " Third chunk"] {
            state.content.push_str(chunk);
        }

        assert_eq!(state.content, "First chunk Second chunk Third chunk");
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature() {
        let state = ReasoningState {
            content: "Reasoning content".to_string(),
            signature: Some("test_signature_456".to_string()),
        };

        assert_eq!(state.content, "Reasoning content");
        assert_eq!(state.signature.as_deref(), Some("test_signature_456"));
    }

    #[test]
    fn test_reasoning_state_empty_content() {
        let state = ReasoningState {
            signature: Some("signature_only".to_string()),
            ..Default::default()
        };

        assert!(state.content.is_empty());
        assert!(state.signature.is_some());
    }

    #[test]
    fn test_tool_call_state_default() {
        let ToolCallState {
            name,
            id,
            input_json,
        } = ToolCallState::default();

        assert!([name, id, input_json].iter().all(String::is_empty));
    }

    #[test]
    fn test_tool_call_state_accumulate_json() {
        let mut state = tool_call("my_tool", "tool_123");
        state.input_json.push_str(r#"{"arg1":"#);
        state.input_json.push_str(r#""value1""#);
        state.input_json.push('}');

        assert_eq!(
            (state.name.as_str(), state.id.as_str()),
            ("my_tool", "tool_123")
        );
        assert_eq!(state.input_json, r#"{"arg1":"value1"}"#);
    }

    #[test]
    fn test_tool_call_state_empty_accumulation() {
        let state = tool_call("test_tool", "tool_abc");

        assert_eq!(state.name, "test_tool");
        assert_eq!(state.id, "tool_abc");
        assert!(state.input_json.is_empty());
    }

    #[test]
    fn test_tool_call_state_single_chunk() {
        let mut state = tool_call("get_weather", "call_123");
        state.input_json.push_str(r#"{"location":"Paris"}"#);

        assert_eq!(state.input_json, r#"{"location":"Paris"}"#);
    }

    #[test]
    fn test_tool_call_state_multiple_small_chunks() {
        let mut state = tool_call("search", "call_xyz");
        state
            .input_json
            .extend(["{", "\"q", "uery", "\":", "\"R", "ust", "\"}"]);

        assert_eq!(state.input_json, r#"{"query":"Rust"}"#);
    }

    #[test]
    fn test_tool_call_state_complex_json_accumulation() {
        let mut state = tool_call("analyze_data", "call_456");
        let pieces = [
            r#"{"data":{"#,
            r#""values":[1,2,3],"#,
            r#""metadata":{"source":"api"}"#,
            "}}",
        ];
        pieces
            .iter()
            .for_each(|piece| state.input_json.push_str(piece));

        assert_eq!(
            state.input_json,
            r#"{"data":{"values":[1,2,3],"metadata":{"source":"api"}}}"#
        );

        let parsed: serde_json::Value =
            serde_json::from_str(&state.input_json).expect("Should parse as valid JSON");
        assert!(parsed.is_object());
    }

    #[test]
    fn test_reasoning_state_accumulation() {
        let mut state = ReasoningState::default();
        state
            .content
            .extend(["First, ", "I need to ", "analyze the problem."]);

        assert_eq!(state.content, "First, I need to analyze the problem.");
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature_accumulation() {
        let mut state = ReasoningState {
            content: "Reasoning content here".to_string(),
            signature: Some("sig_part1".to_string()),
        };

        if let Some(signature) = state.signature.as_mut() {
            signature.push_str("_part2");
        }

        assert_eq!(state.content, "Reasoning content here");
        assert_eq!(state.signature.as_deref(), Some("sig_part1_part2"));
    }
}