Skip to main content

serdes_ai_streaming/
partial_response.rs

//! Partial response accumulation.
//!
//! This module provides types for accumulating streaming deltas into
//! complete responses.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use serdes_ai_core::messages::{TextPart, ThinkingPart, ToolCallArgs, ToolCallPart};
10use serdes_ai_core::{FinishReason, ModelResponse, ModelResponsePart, RequestUsage};
11
/// One response part while it is still being streamed in.
///
/// Every variant starts out empty and grows as deltas are applied to it.
#[derive(Debug, Clone)]
enum PartialPart {
    /// Plain text; each text delta is appended to `content`.
    Text { content: String },
    /// A tool invocation being assembled from fragments.
    ToolCall {
        /// Tool name, once a delta has supplied one.
        name: Option<String>,
        /// Raw argument text, concatenated from fragments; only parsed as JSON
        /// when the response is finalized.
        args: String,
        /// Tool call id, once a delta has supplied one.
        id: Option<String>,
    },
    /// Model reasoning text.
    Thinking {
        /// Reasoning text accumulated so far.
        content: String,
        /// Signature attached to the reasoning, if one has been received.
        signature: Option<String>,
    },
}
29
30impl PartialPart {
31    /// Create new text part.
32    fn text() -> Self {
33        Self::Text {
34            content: String::new(),
35        }
36    }
37
38    /// Create new tool call part.
39    fn tool_call() -> Self {
40        Self::ToolCall {
41            name: None,
42            args: String::new(),
43            id: None,
44        }
45    }
46
47    /// Create new thinking part.
48    fn thinking() -> Self {
49        Self::Thinking {
50            content: String::new(),
51            signature: None,
52        }
53    }
54
55    /// Check if this part has any content.
56    fn has_content(&self) -> bool {
57        match self {
58            Self::Text { content } => !content.is_empty(),
59            Self::ToolCall { name, args, .. } => name.is_some() || !args.is_empty(),
60            Self::Thinking { content, .. } => !content.is_empty(),
61        }
62    }
63}
64
65/// Accumulates streaming deltas into a complete response.
66#[derive(Debug, Clone)]
67pub struct PartialResponse {
68    parts: Vec<PartialPart>,
69    model_name: Option<String>,
70    usage: Option<RequestUsage>,
71    finish_reason: Option<FinishReason>,
72    timestamp: DateTime<Utc>,
73    vendor_id: Option<String>,
74}
75
76impl Default for PartialResponse {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl PartialResponse {
83    /// Create a new partial response.
84    #[must_use]
85    pub fn new() -> Self {
86        Self {
87            parts: Vec::new(),
88            model_name: None,
89            usage: None,
90            finish_reason: None,
91            timestamp: Utc::now(),
92            vendor_id: None,
93        }
94    }
95
96    /// Ensure we have at least `n` parts, expanding with default type.
97    fn ensure_parts(&mut self, n: usize, default_fn: impl Fn() -> PartialPart) {
98        while self.parts.len() <= n {
99            self.parts.push(default_fn());
100        }
101    }
102
103    /// Apply a text delta.
104    pub fn apply_text_delta(&mut self, index: usize, content: &str) {
105        self.ensure_parts(index, PartialPart::text);
106
107        // Ensure it's a text part
108        if !matches!(self.parts[index], PartialPart::Text { .. }) {
109            self.parts[index] = PartialPart::text();
110        }
111
112        if let PartialPart::Text {
113            content: existing, ..
114        } = &mut self.parts[index]
115        {
116            existing.push_str(content);
117        }
118    }
119
120    /// Apply a tool call delta.
121    pub fn apply_tool_delta(
122        &mut self,
123        index: usize,
124        name: Option<&str>,
125        args_delta: Option<&str>,
126        id: Option<&str>,
127    ) {
128        self.ensure_parts(index, PartialPart::tool_call);
129
130        // Ensure it's a tool call part
131        if !matches!(self.parts[index], PartialPart::ToolCall { .. }) {
132            self.parts[index] = PartialPart::tool_call();
133        }
134
135        if let PartialPart::ToolCall {
136            name: existing_name,
137            args,
138            id: existing_id,
139        } = &mut self.parts[index]
140        {
141            if let Some(n) = name {
142                *existing_name = Some(n.to_string());
143            }
144            if let Some(a) = args_delta {
145                args.push_str(a);
146            }
147            if let Some(i) = id {
148                *existing_id = Some(i.to_string());
149            }
150        }
151    }
152
153    /// Apply a thinking delta.
154    pub fn apply_thinking_delta(&mut self, index: usize, content: &str, signature: Option<&str>) {
155        self.ensure_parts(index, PartialPart::thinking);
156
157        // Ensure it's a thinking part
158        if !matches!(self.parts[index], PartialPart::Thinking { .. }) {
159            self.parts[index] = PartialPart::thinking();
160        }
161
162        if let PartialPart::Thinking {
163            content: existing,
164            signature: existing_sig,
165        } = &mut self.parts[index]
166        {
167            existing.push_str(content);
168            if let Some(s) = signature {
169                *existing_sig = Some(s.to_string());
170            }
171        }
172    }
173
174    /// Set the model name.
175    pub fn set_model_name(&mut self, name: impl Into<String>) {
176        self.model_name = Some(name.into());
177    }
178
179    /// Set the usage.
180    pub fn set_usage(&mut self, usage: RequestUsage) {
181        self.usage = Some(usage);
182    }
183
184    /// Set the finish reason.
185    pub fn set_finish_reason(&mut self, reason: FinishReason) {
186        self.finish_reason = Some(reason);
187    }
188
189    /// Set the vendor ID.
190    pub fn set_vendor_id(&mut self, id: impl Into<String>) {
191        self.vendor_id = Some(id.into());
192    }
193
194    /// Get accumulated text content.
195    #[must_use]
196    pub fn text_content(&self) -> String {
197        self.parts
198            .iter()
199            .filter_map(|p| match p {
200                PartialPart::Text { content } => Some(content.as_str()),
201                _ => None,
202            })
203            .collect::<Vec<_>>()
204            .join("")
205    }
206
207    /// Get the number of parts.
208    #[must_use]
209    pub fn num_parts(&self) -> usize {
210        self.parts.len()
211    }
212
213    /// Check if the response is empty.
214    #[must_use]
215    pub fn is_empty(&self) -> bool {
216        self.parts.iter().all(|p| !p.has_content())
217    }
218
219    /// Finalize into a complete response (consumes self).
220    #[must_use]
221    pub fn finalize(self) -> ModelResponse {
222        let parts = self
223            .parts
224            .into_iter()
225            .filter(|p| p.has_content())
226            .filter_map(|p| match p {
227                PartialPart::Text { content } => {
228                    Some(ModelResponsePart::Text(TextPart::new(content)))
229                }
230                PartialPart::ToolCall {
231                    name: Some(name),
232                    args,
233                    id,
234                } => {
235                    let parsed_args: JsonValue =
236                        serde_json::from_str(&args).unwrap_or(JsonValue::Null);
237                    let mut tc = ToolCallPart::new(name, ToolCallArgs::Json(parsed_args));
238                    if let Some(id) = id {
239                        tc.tool_call_id = Some(id);
240                    }
241                    Some(ModelResponsePart::ToolCall(tc))
242                }
243                PartialPart::Thinking { content, signature } => {
244                    let mut thinking = ThinkingPart::new(content);
245                    thinking.signature = signature;
246                    Some(ModelResponsePart::Thinking(thinking))
247                }
248                _ => None,
249            })
250            .collect();
251
252        ModelResponse {
253            parts,
254            model_name: self.model_name,
255            timestamp: self.timestamp,
256            finish_reason: self.finish_reason,
257            usage: self.usage,
258            vendor_id: self.vendor_id,
259            vendor_details: None,
260            kind: "response".to_string(),
261        }
262    }
263
264    /// Get a snapshot as a response (clones data).
265    #[must_use]
266    pub fn as_response(&self) -> ModelResponse {
267        self.clone().finalize()
268    }
269}
270
271/// Delta types for applying to partial response.
272#[derive(Debug, Clone, Serialize, Deserialize)]
273#[serde(tag = "type", rename_all = "snake_case")]
274pub enum ResponseDelta {
275    /// Text content delta.
276    Text {
277        /// Part index.
278        index: usize,
279        /// Text content.
280        content: String,
281    },
282    /// Tool call delta.
283    ToolCall {
284        /// Part index.
285        index: usize,
286        /// Tool name (first delta only).
287        name: Option<String>,
288        /// Arguments delta.
289        args: Option<String>,
290        /// Tool call ID (first delta only).
291        id: Option<String>,
292    },
293    /// Thinking delta.
294    Thinking {
295        /// Part index.
296        index: usize,
297        /// Thinking content.
298        content: String,
299        /// Signature (last delta only).
300        signature: Option<String>,
301    },
302    /// Finish signal.
303    Finish {
304        /// Finish reason.
305        reason: FinishReason,
306    },
307    /// Usage update.
308    Usage {
309        /// Current usage.
310        usage: RequestUsage,
311    },
312}
313
314impl PartialResponse {
315    /// Apply a delta to the response.
316    pub fn apply_delta(&mut self, delta: &ResponseDelta) {
317        match delta {
318            ResponseDelta::Text { index, content } => {
319                self.apply_text_delta(*index, content);
320            }
321            ResponseDelta::ToolCall {
322                index,
323                name,
324                args,
325                id,
326            } => {
327                self.apply_tool_delta(*index, name.as_deref(), args.as_deref(), id.as_deref());
328            }
329            ResponseDelta::Thinking {
330                index,
331                content,
332                signature,
333            } => {
334                self.apply_thinking_delta(*index, content, signature.as_deref());
335            }
336            ResponseDelta::Finish { reason } => {
337                self.set_finish_reason(*reason);
338            }
339            ResponseDelta::Usage { usage } => {
340                self.set_usage(usage.clone());
341            }
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_partial_response() {
        let fresh = PartialResponse::new();
        assert_eq!(fresh.num_parts(), 0);
        assert!(fresh.is_empty());
    }

    #[test]
    fn test_text_accumulation() {
        let mut acc = PartialResponse::new();
        for chunk in ["Hello, ", "world!"].iter() {
            acc.apply_text_delta(0, chunk);
        }

        assert_eq!(acc.text_content(), "Hello, world!");
        assert!(!acc.is_empty());
    }

    #[test]
    fn test_tool_call_accumulation() {
        let mut acc = PartialResponse::new();
        acc.apply_tool_delta(0, Some("search"), None, Some("call-1"));
        for fragment in ["{\"query\": ", "\"rust\"}"].iter() {
            acc.apply_tool_delta(0, None, Some(*fragment), None);
        }

        let response = acc.finalize();
        assert_eq!(response.parts.len(), 1);

        match &response.parts[0] {
            ModelResponsePart::ToolCall(tc) => {
                assert_eq!(tc.tool_name, "search");
                assert_eq!(tc.tool_call_id, Some("call-1".to_string()));
            }
            _ => panic!("Expected tool call part"),
        }
    }

    #[test]
    fn test_thinking_accumulation() {
        let mut acc = PartialResponse::new();
        let chunks = [
            ("Let me think...", None),
            (" I need to", None),
            (" consider options.", Some("sig-123")),
        ];
        for (text, sig) in chunks.iter() {
            acc.apply_thinking_delta(0, text, *sig);
        }

        let response = acc.finalize();
        assert_eq!(response.parts.len(), 1);

        match &response.parts[0] {
            ModelResponsePart::Thinking(t) => {
                assert_eq!(t.content, "Let me think... I need to consider options.");
                assert_eq!(t.signature, Some("sig-123".to_string()));
            }
            _ => panic!("Expected thinking part"),
        }
    }

    #[test]
    fn test_multiple_parts() {
        let mut acc = PartialResponse::new();
        acc.apply_text_delta(0, "Hello");
        acc.apply_tool_delta(1, Some("search"), Some("{}"), None);
        acc.apply_text_delta(2, "World");

        assert_eq!(acc.finalize().parts.len(), 3);
    }

    #[test]
    fn test_apply_delta() {
        let mut acc = PartialResponse::new();
        let deltas = [
            ResponseDelta::Text {
                index: 0,
                content: "Hello".to_string(),
            },
            ResponseDelta::Finish {
                reason: FinishReason::Stop,
            },
        ];
        for delta in deltas.iter() {
            acc.apply_delta(delta);
        }

        let response = acc.finalize();
        assert_eq!(response.text_content(), "Hello");
        assert_eq!(response.finish_reason, Some(FinishReason::Stop));
    }

    #[test]
    fn test_as_response_clones() {
        let mut acc = PartialResponse::new();
        acc.apply_text_delta(0, "Test");
        let before = acc.as_response();

        acc.apply_text_delta(0, " more");
        let after = acc.as_response();

        assert_eq!(before.text_content(), "Test");
        assert_eq!(after.text_content(), "Test more");
    }
}