// strands_agents/types/streaming.rs
//! Streaming types for model responses.
2
3use serde::{Deserialize, Serialize};
4
5use super::content::Role;
6
7/// Reason why model generation stopped.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum StopReason {
11    EndTurn,
12    ToolUse,
13    MaxTokens,
14    StopSequence,
15    ContentFiltered,
16    GuardrailIntervention,
17    Interrupt,
18}
19
20impl Default for StopReason {
21    fn default() -> Self { Self::EndTurn }
22}
23
24impl StopReason {
25    /// Returns the string representation of the stop reason.
26    pub fn as_str(&self) -> &'static str {
27        match self {
28            StopReason::EndTurn => "end_turn",
29            StopReason::ToolUse => "tool_use",
30            StopReason::MaxTokens => "max_tokens",
31            StopReason::StopSequence => "stop_sequence",
32            StopReason::ContentFiltered => "content_filtered",
33            StopReason::GuardrailIntervention => "guardrail_intervention",
34            StopReason::Interrupt => "interrupt",
35        }
36    }
37}
38
39impl std::fmt::Display for StopReason {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.as_str())
42    }
43}
44
45/// Token usage statistics.
46#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
47#[serde(rename_all = "camelCase")]
48pub struct Usage {
49    pub input_tokens: u32,
50    pub output_tokens: u32,
51    pub total_tokens: u32,
52
53    #[serde(default)]
54    pub cache_read_input_tokens: u32,
55
56    #[serde(default)]
57    pub cache_write_input_tokens: u32,
58}
59
60impl Usage {
61    pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
62        Self {
63            input_tokens,
64            output_tokens,
65            total_tokens: input_tokens + output_tokens,
66            cache_read_input_tokens: 0,
67            cache_write_input_tokens: 0,
68        }
69    }
70
71    pub fn add(&mut self, other: &Usage) {
72        self.input_tokens += other.input_tokens;
73        self.output_tokens += other.output_tokens;
74        self.total_tokens += other.total_tokens;
75        self.cache_read_input_tokens += other.cache_read_input_tokens;
76        self.cache_write_input_tokens += other.cache_write_input_tokens;
77    }
78}
79
80/// Performance metrics for a model call.
81#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
82#[serde(rename_all = "camelCase")]
83pub struct Metrics {
84    pub latency_ms: u64,
85
86    #[serde(default)]
87    pub time_to_first_byte_ms: u64,
88}
89
90/// Event indicating message generation has started.
91#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
92#[serde(rename_all = "camelCase")]
93pub struct MessageStartEvent {
94    pub role: Role,
95}
96
97/// Tool use information at content block start.
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99#[serde(rename_all = "camelCase")]
100pub struct ContentBlockStartToolUse {
101    pub name: String,
102    pub tool_use_id: String,
103}
104
105/// Content block start data.
106#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
107#[serde(rename_all = "camelCase")]
108pub struct ContentBlockStart {
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub tool_use: Option<ContentBlockStartToolUse>,
111}
112
113/// Event indicating a content block has started.
114#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
115#[serde(rename_all = "camelCase")]
116pub struct ContentBlockStartEvent {
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub content_block_index: Option<u32>,
119
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub start: Option<ContentBlockStart>,
122}
123
124/// Tool use delta within a content block.
125#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
126#[serde(rename_all = "camelCase")]
127pub struct ContentBlockDeltaToolUse {
128    pub input: String,
129}
130
131/// Reasoning content delta.
132#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
133#[serde(rename_all = "camelCase")]
134pub struct ReasoningContentBlockDelta {
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub text: Option<String>,
137
138    #[serde(skip_serializing_if = "Option::is_none")]
139    pub signature: Option<String>,
140
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub redacted_content: Option<Vec<u8>>,
143}
144
145/// Citation delta in streaming response.
146#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
147#[serde(rename_all = "camelCase")]
148pub struct CitationsDelta {
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub location: Option<serde_json::Value>,
151
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub source_content: Option<Vec<CitationSourceContentDelta>>,
154
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub title: Option<String>,
157}
158
159/// Source content delta for citations.
160#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
161#[serde(rename_all = "camelCase")]
162pub struct CitationSourceContentDelta {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub text: Option<String>,
165}
166
167/// Incremental content block update.
168#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
169#[serde(rename_all = "camelCase")]
170pub struct ContentBlockDelta {
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub text: Option<String>,
173
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub tool_use: Option<ContentBlockDeltaToolUse>,
176
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub reasoning_content: Option<ReasoningContentBlockDelta>,
179
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub citation: Option<CitationsDelta>,
182}
183
184/// Event containing a content block delta.
185#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
186#[serde(rename_all = "camelCase")]
187pub struct ContentBlockDeltaEvent {
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub content_block_index: Option<u32>,
190
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub delta: Option<ContentBlockDelta>,
193}
194
195/// Event indicating a content block has stopped.
196#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
197#[serde(rename_all = "camelCase")]
198pub struct ContentBlockStopEvent {
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub content_block_index: Option<u32>,
201}
202
203/// Event indicating message generation has stopped.
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
205#[serde(rename_all = "camelCase")]
206pub struct MessageStopEvent {
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub stop_reason: Option<StopReason>,
209
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub additional_model_response_fields: Option<serde_json::Value>,
212}
213
214/// Event containing usage and metrics metadata.
215#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
216#[serde(rename_all = "camelCase")]
217pub struct MetadataEvent {
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub usage: Option<Usage>,
220
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub metrics: Option<Metrics>,
223
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub trace: Option<serde_json::Value>,
226}
227
228/// An exception event from the model.
229#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
230#[serde(rename_all = "camelCase")]
231pub struct ExceptionEvent {
232    pub message: String,
233}
234
235/// A stream error event from the model.
236#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
237#[serde(rename_all = "camelCase")]
238pub struct ModelStreamErrorEvent {
239    pub message: String,
240    pub original_message: String,
241    pub original_status_code: i32,
242}
243
244/// Event for content redaction.
245#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
246#[serde(rename_all = "camelCase")]
247pub struct RedactContentEvent {
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub redact_user_content_message: Option<String>,
250
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub redact_assistant_content_message: Option<String>,
253}
254
255/// A streaming event from the model.
256#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
257#[serde(rename_all = "camelCase")]
258pub struct StreamEvent {
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub message_start: Option<MessageStartEvent>,
261
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub content_block_start: Option<ContentBlockStartEvent>,
264
265    #[serde(skip_serializing_if = "Option::is_none")]
266    pub content_block_delta: Option<ContentBlockDeltaEvent>,
267
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub content_block_stop: Option<ContentBlockStopEvent>,
270
271    #[serde(skip_serializing_if = "Option::is_none")]
272    pub message_stop: Option<MessageStopEvent>,
273
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub metadata: Option<MetadataEvent>,
276
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub redact_content: Option<RedactContentEvent>,
279
280    #[serde(skip_serializing_if = "Option::is_none")]
281    pub internal_server_exception: Option<ExceptionEvent>,
282
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub model_stream_error_exception: Option<ModelStreamErrorEvent>,
285
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub throttling_exception: Option<ExceptionEvent>,
288
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub validation_exception: Option<ExceptionEvent>,
291
292    #[serde(skip_serializing_if = "Option::is_none")]
293    pub service_unavailable_exception: Option<ExceptionEvent>,
294}
295
296impl StreamEvent {
297    pub fn message_start(role: Role) -> Self {
298        Self { message_start: Some(MessageStartEvent { role }), ..Default::default() }
299    }
300
301    pub fn content_block_start(index: u32, start: Option<ContentBlockStart>) -> Self {
302        Self {
303            content_block_start: Some(ContentBlockStartEvent {
304                content_block_index: Some(index),
305                start,
306            }),
307            ..Default::default()
308        }
309    }
310
311    pub fn content_block_delta(index: u32, delta: ContentBlockDelta) -> Self {
312        Self {
313            content_block_delta: Some(ContentBlockDeltaEvent {
314                content_block_index: Some(index),
315                delta: Some(delta),
316            }),
317            ..Default::default()
318        }
319    }
320
321    pub fn text_delta(index: u32, text: impl Into<String>) -> Self {
322        Self::content_block_delta(index, ContentBlockDelta { text: Some(text.into()), ..Default::default() })
323    }
324
325    pub fn tool_use_delta(index: u32, input: impl Into<String>) -> Self {
326        Self::content_block_delta(index, ContentBlockDelta {
327            tool_use: Some(ContentBlockDeltaToolUse { input: input.into() }),
328            ..Default::default()
329        })
330    }
331
332    pub fn tool_use_start(index: u32, name: impl Into<String>, tool_use_id: impl Into<String>) -> Self {
333        Self {
334            content_block_start: Some(ContentBlockStartEvent {
335                content_block_index: Some(index),
336                start: Some(ContentBlockStart {
337                    tool_use: Some(ContentBlockStartToolUse {
338                        name: name.into(),
339                        tool_use_id: tool_use_id.into(),
340                    }),
341                }),
342            }),
343            ..Default::default()
344        }
345    }
346
347    pub fn reasoning_delta(index: u32, text: impl Into<String>) -> Self {
348        Self::content_block_delta(index, ContentBlockDelta {
349            reasoning_content: Some(ReasoningContentBlockDelta {
350                text: Some(text.into()),
351                ..Default::default()
352            }),
353            ..Default::default()
354        })
355    }
356
357    pub fn content_block_stop(index: u32) -> Self {
358        Self {
359            content_block_stop: Some(ContentBlockStopEvent { content_block_index: Some(index) }),
360            ..Default::default()
361        }
362    }
363
364    pub fn message_stop(stop_reason: StopReason) -> Self {
365        Self {
366            message_stop: Some(MessageStopEvent { stop_reason: Some(stop_reason), additional_model_response_fields: None }),
367            ..Default::default()
368        }
369    }
370
371    pub fn metadata(usage: Usage, metrics: Metrics) -> Self {
372        Self {
373            metadata: Some(MetadataEvent { usage: Some(usage), metrics: Some(metrics), trace: None }),
374            ..Default::default()
375        }
376    }
377
378    pub fn is_text_delta(&self) -> bool {
379        self.content_block_delta.as_ref().and_then(|e| e.delta.as_ref()).map(|d| d.text.is_some()).unwrap_or(false)
380    }
381
382    pub fn as_text_delta(&self) -> Option<&str> {
383        self.content_block_delta.as_ref().and_then(|e| e.delta.as_ref()).and_then(|d| d.text.as_deref())
384    }
385
386    pub fn is_message_stop(&self) -> bool { self.message_stop.is_some() }
387    pub fn stop_reason(&self) -> Option<StopReason> { self.message_stop.as_ref().and_then(|e| e.stop_reason) }
388
389    pub fn is_error(&self) -> bool {
390        self.internal_server_exception.is_some()
391            || self.model_stream_error_exception.is_some()
392            || self.throttling_exception.is_some()
393            || self.validation_exception.is_some()
394            || self.service_unavailable_exception.is_some()
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// All `StopReason` variants, used by the consistency checks below.
    const ALL_STOP_REASONS: [StopReason; 7] = [
        StopReason::EndTurn,
        StopReason::ToolUse,
        StopReason::MaxTokens,
        StopReason::StopSequence,
        StopReason::ContentFiltered,
        StopReason::GuardrailIntervention,
        StopReason::Interrupt,
    ];

    #[test]
    fn test_usage_add() {
        let mut usage1 = Usage::new(100, 50);
        let usage2 = Usage::new(200, 100);
        usage1.add(&usage2);
        assert_eq!(usage1.input_tokens, 300);
        assert_eq!(usage1.output_tokens, 150);
        assert_eq!(usage1.total_tokens, 450);
    }

    #[test]
    fn test_usage_deserialize_defaults_cache_fields() {
        let usage: Usage =
            serde_json::from_str(r#"{"inputTokens":1,"outputTokens":2,"totalTokens":3}"#)
                .unwrap();
        assert_eq!(usage, Usage::new(1, 2));
        assert_eq!(usage.cache_read_input_tokens, 0);
        assert_eq!(usage.cache_write_input_tokens, 0);
    }

    #[test]
    fn test_stop_reason_serialization() {
        assert_eq!(serde_json::to_string(&StopReason::EndTurn).unwrap(), "\"end_turn\"");
        assert_eq!(serde_json::to_string(&StopReason::ToolUse).unwrap(), "\"tool_use\"");
    }

    #[test]
    fn test_stop_reason_default_is_end_turn() {
        assert_eq!(StopReason::default(), StopReason::EndTurn);
    }

    #[test]
    fn test_stop_reason_as_str_matches_serde_and_display() {
        for &reason in ALL_STOP_REASONS.iter() {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, format!("\"{}\"", reason.as_str()));
            assert_eq!(reason.to_string(), reason.as_str());
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(back, reason);
        }
    }

    #[test]
    fn test_stream_event_text_delta() {
        let event = StreamEvent::text_delta(0, "Hello");
        assert!(event.is_text_delta());
        assert_eq!(event.as_text_delta(), Some("Hello"));
    }

    #[test]
    fn test_reasoning_delta_is_not_text_delta() {
        let event = StreamEvent::reasoning_delta(0, "thinking");
        assert!(!event.is_text_delta());
        assert_eq!(event.as_text_delta(), None);
    }

    #[test]
    fn test_stream_event_serialization() {
        let event = StreamEvent::text_delta(0, "hi");
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("contentBlockDelta"));
        assert_eq!(json, r#"{"contentBlockDelta":{"contentBlockIndex":0,"delta":{"text":"hi"}}}"#);
        let back: StreamEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back, event);
    }

    #[test]
    fn test_tool_use_start_serialization() {
        let event = StreamEvent::tool_use_start(1, "calc", "t1");
        let json = serde_json::to_string(&event).unwrap();
        assert_eq!(
            json,
            r#"{"contentBlockStart":{"contentBlockIndex":1,"start":{"toolUse":{"name":"calc","toolUseId":"t1"}}}}"#
        );
    }

    #[test]
    fn test_message_stop_accessors() {
        let event = StreamEvent::message_stop(StopReason::ToolUse);
        assert!(event.is_message_stop());
        assert_eq!(event.stop_reason(), Some(StopReason::ToolUse));
        assert!(!event.is_text_delta());
        assert!(!event.is_error());
        assert_eq!(
            serde_json::to_string(&event).unwrap(),
            r#"{"messageStop":{"stopReason":"tool_use"}}"#
        );
    }

    #[test]
    fn test_is_error() {
        let event = StreamEvent {
            throttling_exception: Some(ExceptionEvent { message: "slow down".into() }),
            ..Default::default()
        };
        assert!(event.is_error());
        assert!(!StreamEvent::text_delta(0, "ok").is_error());
    }
}