1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use serdes_ai_core::messages::{TextPart, ThinkingPart, ToolCallArgs, ToolCallPart};
10use serdes_ai_core::{FinishReason, ModelResponse, ModelResponsePart, RequestUsage};
11
/// A response part that is still being assembled from stream deltas.
#[derive(Debug, Clone)]
enum PartialPart {
    /// Plain text output; deltas are appended to `content`.
    Text { content: String },
    /// A tool invocation being streamed in pieces.
    ToolCall {
        /// Tool name; may arrive in any delta (or never).
        name: Option<String>,
        /// Raw, not-yet-parsed argument text, built by appending deltas.
        args: String,
        /// Provider-assigned call id, if one was sent.
        id: Option<String>,
    },
    /// Reasoning output from the model.
    Thinking {
        /// Accumulated reasoning text.
        content: String,
        /// Optional signature; the most recent one received wins.
        signature: Option<String>,
    },
}
29
30impl PartialPart {
31 fn text() -> Self {
33 Self::Text {
34 content: String::new(),
35 }
36 }
37
38 fn tool_call() -> Self {
40 Self::ToolCall {
41 name: None,
42 args: String::new(),
43 id: None,
44 }
45 }
46
47 fn thinking() -> Self {
49 Self::Thinking {
50 content: String::new(),
51 signature: None,
52 }
53 }
54
55 fn has_content(&self) -> bool {
57 match self {
58 Self::Text { content } => !content.is_empty(),
59 Self::ToolCall { name, args, .. } => name.is_some() || !args.is_empty(),
60 Self::Thinking { content, .. } => !content.is_empty(),
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct PartialResponse {
68 parts: Vec<PartialPart>,
69 model_name: Option<String>,
70 usage: Option<RequestUsage>,
71 finish_reason: Option<FinishReason>,
72 timestamp: DateTime<Utc>,
73 vendor_id: Option<String>,
74}
75
76impl Default for PartialResponse {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl PartialResponse {
83 #[must_use]
85 pub fn new() -> Self {
86 Self {
87 parts: Vec::new(),
88 model_name: None,
89 usage: None,
90 finish_reason: None,
91 timestamp: Utc::now(),
92 vendor_id: None,
93 }
94 }
95
96 fn ensure_parts(&mut self, n: usize, default_fn: impl Fn() -> PartialPart) {
98 while self.parts.len() <= n {
99 self.parts.push(default_fn());
100 }
101 }
102
103 pub fn apply_text_delta(&mut self, index: usize, content: &str) {
105 self.ensure_parts(index, PartialPart::text);
106
107 if !matches!(self.parts[index], PartialPart::Text { .. }) {
109 self.parts[index] = PartialPart::text();
110 }
111
112 if let PartialPart::Text {
113 content: existing, ..
114 } = &mut self.parts[index]
115 {
116 existing.push_str(content);
117 }
118 }
119
120 pub fn apply_tool_delta(
122 &mut self,
123 index: usize,
124 name: Option<&str>,
125 args_delta: Option<&str>,
126 id: Option<&str>,
127 ) {
128 self.ensure_parts(index, PartialPart::tool_call);
129
130 if !matches!(self.parts[index], PartialPart::ToolCall { .. }) {
132 self.parts[index] = PartialPart::tool_call();
133 }
134
135 if let PartialPart::ToolCall {
136 name: existing_name,
137 args,
138 id: existing_id,
139 } = &mut self.parts[index]
140 {
141 if let Some(n) = name {
142 *existing_name = Some(n.to_string());
143 }
144 if let Some(a) = args_delta {
145 args.push_str(a);
146 }
147 if let Some(i) = id {
148 *existing_id = Some(i.to_string());
149 }
150 }
151 }
152
153 pub fn apply_thinking_delta(&mut self, index: usize, content: &str, signature: Option<&str>) {
155 self.ensure_parts(index, PartialPart::thinking);
156
157 if !matches!(self.parts[index], PartialPart::Thinking { .. }) {
159 self.parts[index] = PartialPart::thinking();
160 }
161
162 if let PartialPart::Thinking {
163 content: existing,
164 signature: existing_sig,
165 } = &mut self.parts[index]
166 {
167 existing.push_str(content);
168 if let Some(s) = signature {
169 *existing_sig = Some(s.to_string());
170 }
171 }
172 }
173
174 pub fn set_model_name(&mut self, name: impl Into<String>) {
176 self.model_name = Some(name.into());
177 }
178
179 pub fn set_usage(&mut self, usage: RequestUsage) {
181 self.usage = Some(usage);
182 }
183
184 pub fn set_finish_reason(&mut self, reason: FinishReason) {
186 self.finish_reason = Some(reason);
187 }
188
189 pub fn set_vendor_id(&mut self, id: impl Into<String>) {
191 self.vendor_id = Some(id.into());
192 }
193
194 #[must_use]
196 pub fn text_content(&self) -> String {
197 self.parts
198 .iter()
199 .filter_map(|p| match p {
200 PartialPart::Text { content } => Some(content.as_str()),
201 _ => None,
202 })
203 .collect::<Vec<_>>()
204 .join("")
205 }
206
207 #[must_use]
209 pub fn num_parts(&self) -> usize {
210 self.parts.len()
211 }
212
213 #[must_use]
215 pub fn is_empty(&self) -> bool {
216 self.parts.iter().all(|p| !p.has_content())
217 }
218
219 #[must_use]
221 pub fn finalize(self) -> ModelResponse {
222 let parts = self
223 .parts
224 .into_iter()
225 .filter(|p| p.has_content())
226 .filter_map(|p| match p {
227 PartialPart::Text { content } => {
228 Some(ModelResponsePart::Text(TextPart::new(content)))
229 }
230 PartialPart::ToolCall {
231 name: Some(name),
232 args,
233 id,
234 } => {
235 let parsed_args: JsonValue =
236 serde_json::from_str(&args).unwrap_or(JsonValue::Null);
237 let mut tc = ToolCallPart::new(name, ToolCallArgs::Json(parsed_args));
238 if let Some(id) = id {
239 tc.tool_call_id = Some(id);
240 }
241 Some(ModelResponsePart::ToolCall(tc))
242 }
243 PartialPart::Thinking { content, signature } => {
244 let mut thinking = ThinkingPart::new(content);
245 thinking.signature = signature;
246 Some(ModelResponsePart::Thinking(thinking))
247 }
248 _ => None,
249 })
250 .collect();
251
252 ModelResponse {
253 parts,
254 model_name: self.model_name,
255 timestamp: self.timestamp,
256 finish_reason: self.finish_reason,
257 usage: self.usage,
258 vendor_id: self.vendor_id,
259 vendor_details: None,
260 kind: "response".to_string(),
261 }
262 }
263
264 #[must_use]
266 pub fn as_response(&self) -> ModelResponse {
267 self.clone().finalize()
268 }
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
273#[serde(tag = "type", rename_all = "snake_case")]
274pub enum ResponseDelta {
275 Text {
277 index: usize,
279 content: String,
281 },
282 ToolCall {
284 index: usize,
286 name: Option<String>,
288 args: Option<String>,
290 id: Option<String>,
292 },
293 Thinking {
295 index: usize,
297 content: String,
299 signature: Option<String>,
301 },
302 Finish {
304 reason: FinishReason,
306 },
307 Usage {
309 usage: RequestUsage,
311 },
312}
313
314impl PartialResponse {
315 pub fn apply_delta(&mut self, delta: &ResponseDelta) {
317 match delta {
318 ResponseDelta::Text { index, content } => {
319 self.apply_text_delta(*index, content);
320 }
321 ResponseDelta::ToolCall {
322 index,
323 name,
324 args,
325 id,
326 } => {
327 self.apply_tool_delta(*index, name.as_deref(), args.as_deref(), id.as_deref());
328 }
329 ResponseDelta::Thinking {
330 index,
331 content,
332 signature,
333 } => {
334 self.apply_thinking_delta(*index, content, signature.as_deref());
335 }
336 ResponseDelta::Finish { reason } => {
337 self.set_finish_reason(*reason);
338 }
339 ResponseDelta::Usage { usage } => {
340 self.set_usage(usage.clone());
341 }
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    //! Unit tests for delta accumulation, finalization and the delta wire format.

    use super::*;

    #[test]
    fn test_new_partial_response() {
        let pr = PartialResponse::new();
        assert!(pr.is_empty());
        assert_eq!(pr.num_parts(), 0);
    }

    #[test]
    fn test_text_accumulation() {
        let mut pr = PartialResponse::new();
        pr.apply_text_delta(0, "Hello, ");
        pr.apply_text_delta(0, "world!");

        assert_eq!(pr.text_content(), "Hello, world!");
        assert!(!pr.is_empty());
    }

    #[test]
    fn test_tool_call_accumulation() {
        let mut pr = PartialResponse::new();
        pr.apply_tool_delta(0, Some("search"), None, Some("call-1"));
        pr.apply_tool_delta(0, None, Some("{\"query\": "), None);
        pr.apply_tool_delta(0, None, Some("\"rust\"}"), None);

        let response = pr.finalize();
        assert_eq!(response.parts.len(), 1);

        if let ModelResponsePart::ToolCall(tc) = &response.parts[0] {
            assert_eq!(tc.tool_name, "search");
            assert_eq!(tc.tool_call_id, Some("call-1".to_string()));
        } else {
            panic!("Expected tool call part");
        }
    }

    #[test]
    fn test_thinking_accumulation() {
        let mut pr = PartialResponse::new();
        pr.apply_thinking_delta(0, "Let me think...", None);
        pr.apply_thinking_delta(0, " I need to", None);
        pr.apply_thinking_delta(0, " consider options.", Some("sig-123"));

        let response = pr.finalize();
        assert_eq!(response.parts.len(), 1);

        if let ModelResponsePart::Thinking(t) = &response.parts[0] {
            assert_eq!(t.content, "Let me think... I need to consider options.");
            assert_eq!(t.signature, Some("sig-123".to_string()));
        } else {
            panic!("Expected thinking part");
        }
    }

    #[test]
    fn test_multiple_parts() {
        let mut pr = PartialResponse::new();
        pr.apply_text_delta(0, "Hello");
        pr.apply_tool_delta(1, Some("search"), Some("{}"), None);
        pr.apply_text_delta(2, "World");

        let response = pr.finalize();
        assert_eq!(response.parts.len(), 3);
    }

    #[test]
    fn test_apply_delta() {
        let mut pr = PartialResponse::new();

        pr.apply_delta(&ResponseDelta::Text {
            index: 0,
            content: "Hello".to_string(),
        });

        pr.apply_delta(&ResponseDelta::Finish {
            reason: FinishReason::Stop,
        });

        let response = pr.finalize();
        assert_eq!(response.text_content(), "Hello");
        assert_eq!(response.finish_reason, Some(FinishReason::Stop));
    }

    #[test]
    fn test_as_response_clones() {
        let mut pr = PartialResponse::new();
        pr.apply_text_delta(0, "Test");

        let snap1 = pr.as_response();
        pr.apply_text_delta(0, " more");
        let snap2 = pr.as_response();

        assert_eq!(snap1.text_content(), "Test");
        assert_eq!(snap2.text_content(), "Test more");
    }

    #[test]
    fn test_sparse_indices_leave_empty_placeholders() {
        let mut pr = PartialResponse::new();
        pr.apply_text_delta(2, "late");

        // Slots 0 and 1 exist as empty placeholders...
        assert_eq!(pr.num_parts(), 3);
        assert!(!pr.is_empty());
        assert_eq!(pr.text_content(), "late");

        // ...and are dropped from the final response.
        let response = pr.finalize();
        assert_eq!(response.parts.len(), 1);
    }

    #[test]
    fn test_empty_deltas_count_as_empty() {
        let mut pr = PartialResponse::new();
        pr.apply_text_delta(0, "");
        // A signature alone does not make a thinking part non-empty.
        pr.apply_thinking_delta(1, "", Some("sig"));

        assert_eq!(pr.num_parts(), 2);
        assert!(pr.is_empty());
        assert!(pr.finalize().parts.is_empty());
    }

    #[test]
    fn test_kind_change_replaces_part() {
        let mut pr = PartialResponse::new();
        pr.apply_text_delta(0, "Hello");
        pr.apply_thinking_delta(0, "hmm", None);

        assert_eq!(pr.text_content(), "");
        let response = pr.finalize();
        assert_eq!(response.parts.len(), 1);
        assert!(matches!(response.parts[0], ModelResponsePart::Thinking(_)));
    }

    #[test]
    fn test_tool_name_may_arrive_after_args() {
        let mut pr = PartialResponse::new();
        pr.apply_tool_delta(0, None, Some("{}"), Some("call-9"));
        pr.apply_tool_delta(0, Some("lookup"), None, None);

        let response = pr.finalize();
        assert_eq!(response.parts.len(), 1);

        if let ModelResponsePart::ToolCall(tc) = &response.parts[0] {
            assert_eq!(tc.tool_name, "lookup");
            assert_eq!(tc.tool_call_id.as_deref(), Some("call-9"));
        } else {
            panic!("Expected tool call part");
        }
    }

    #[test]
    fn test_response_delta_wire_format() {
        let delta = ResponseDelta::ToolCall {
            index: 1,
            name: Some("search".to_string()),
            args: None,
            id: None,
        };

        let value = serde_json::to_value(&delta).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "tool_call",
                "index": 1,
                "name": "search",
                "args": null,
                "id": null
            })
        );
    }

    #[test]
    fn test_response_delta_missing_optionals_deserialize_as_none() {
        let delta: ResponseDelta =
            serde_json::from_str(r#"{"type":"tool_call","index":0,"name":"search"}"#).unwrap();

        if let ResponseDelta::ToolCall {
            index,
            name,
            args,
            id,
        } = delta
        {
            assert_eq!(index, 0);
            assert_eq!(name.as_deref(), Some("search"));
            assert!(args.is_none());
            assert!(id.is_none());
        } else {
            panic!("Expected tool call delta");
        }
    }

    #[test]
    fn test_apply_deltas_from_json() {
        let mut pr = PartialResponse::new();
        for raw in [
            r#"{"type":"text","index":0,"content":"Hel"}"#,
            r#"{"type":"text","index":0,"content":"lo"}"#,
        ] {
            let delta: ResponseDelta = serde_json::from_str(raw).unwrap();
            pr.apply_delta(&delta);
        }
        assert_eq!(pr.text_content(), "Hello");
    }
}