1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8
9use super::parts::{BuiltinToolCallPart, FilePart, TextPart, ThinkingPart, ToolCallPart};
10use crate::usage::RequestUsage;
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
14pub struct ModelResponse {
15 pub parts: Vec<ModelResponsePart>,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub model_name: Option<String>,
20 pub timestamp: DateTime<Utc>,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub finish_reason: Option<FinishReason>,
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub usage: Option<RequestUsage>,
28 #[serde(skip_serializing_if = "Option::is_none")]
30 pub vendor_id: Option<String>,
31 #[serde(skip_serializing_if = "Option::is_none")]
33 pub vendor_details: Option<serde_json::Value>,
34 #[serde(default = "default_response_kind")]
36 pub kind: String,
37}
38
/// Supplies the fallback value for `ModelResponse::kind`, which is the
/// literal `"response"`.
fn default_response_kind() -> String {
    String::from("response")
}
42
43impl ModelResponse {
44 #[must_use]
46 pub fn new() -> Self {
47 Self {
48 parts: Vec::new(),
49 model_name: None,
50 timestamp: Utc::now(),
51 finish_reason: None,
52 usage: None,
53 vendor_id: None,
54 vendor_details: None,
55 kind: "response".to_string(),
56 }
57 }
58
59 #[must_use]
61 pub fn with_parts(parts: Vec<ModelResponsePart>) -> Self {
62 Self {
63 parts,
64 ..Self::new()
65 }
66 }
67
68 #[must_use]
70 pub fn text(content: impl Into<String>) -> Self {
71 Self::with_parts(vec![ModelResponsePart::Text(TextPart::new(content))])
72 }
73
74 pub fn add_part(&mut self, part: ModelResponsePart) {
76 self.parts.push(part);
77 }
78
79 #[must_use]
81 pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
82 self.model_name = Some(name.into());
83 self
84 }
85
86 #[must_use]
88 pub fn with_finish_reason(mut self, reason: FinishReason) -> Self {
89 self.finish_reason = Some(reason);
90 self
91 }
92
93 #[must_use]
95 pub fn with_usage(mut self, usage: RequestUsage) -> Self {
96 self.usage = Some(usage);
97 self
98 }
99
100 #[must_use]
102 pub fn with_vendor_id(mut self, id: impl Into<String>) -> Self {
103 self.vendor_id = Some(id.into());
104 self
105 }
106
107 #[must_use]
109 pub fn with_vendor_details(mut self, details: serde_json::Value) -> Self {
110 self.vendor_details = Some(details);
111 self
112 }
113
114 pub fn text_parts(&self) -> impl Iterator<Item = &TextPart> {
116 self.parts.iter().filter_map(|p| match p {
117 ModelResponsePart::Text(t) => Some(t),
118 _ => None,
119 })
120 }
121
122 pub fn tool_call_parts(&self) -> impl Iterator<Item = &ToolCallPart> {
124 self.parts.iter().filter_map(|p| match p {
125 ModelResponsePart::ToolCall(t) => Some(t),
126 _ => None,
127 })
128 }
129
130 pub fn thinking_parts(&self) -> impl Iterator<Item = &ThinkingPart> {
132 self.parts.iter().filter_map(|p| match p {
133 ModelResponsePart::Thinking(t) => Some(t),
134 _ => None,
135 })
136 }
137
138 pub fn file_parts(&self) -> impl Iterator<Item = &FilePart> {
140 self.parts.iter().filter_map(|p| match p {
141 ModelResponsePart::File(f) => Some(f),
142 _ => None,
143 })
144 }
145
146 #[deprecated(note = "Use text_parts() iterator instead")]
148 pub fn text_parts_vec(&self) -> Vec<&TextPart> {
149 self.text_parts().collect()
150 }
151
152 #[deprecated(note = "Use tool_call_parts() iterator instead")]
154 pub fn tool_call_parts_vec(&self) -> Vec<&ToolCallPart> {
155 self.tool_call_parts().collect()
156 }
157
158 #[deprecated(note = "Use thinking_parts() iterator instead")]
160 pub fn thinking_parts_vec(&self) -> Vec<&ThinkingPart> {
161 self.thinking_parts().collect()
162 }
163
164 #[deprecated(note = "Use file_parts() iterator instead")]
166 pub fn file_parts_vec(&self) -> Vec<&FilePart> {
167 self.file_parts().collect()
168 }
169
170 #[must_use]
172 pub fn has_files(&self) -> bool {
173 self.parts
174 .iter()
175 .any(|p| matches!(p, ModelResponsePart::File(_)))
176 }
177
178 pub fn builtin_tool_call_parts(&self) -> impl Iterator<Item = &BuiltinToolCallPart> {
180 self.parts.iter().filter_map(|p| match p {
181 ModelResponsePart::BuiltinToolCall(b) => Some(b),
182 _ => None,
183 })
184 }
185
186 #[deprecated(note = "Use builtin_tool_call_parts() iterator instead")]
188 pub fn builtin_tool_call_parts_vec(&self) -> Vec<&BuiltinToolCallPart> {
189 self.builtin_tool_call_parts().collect()
190 }
191
192 #[must_use]
194 pub fn has_builtin_tool_calls(&self) -> bool {
195 self.parts
196 .iter()
197 .any(|p| matches!(p, ModelResponsePart::BuiltinToolCall(_)))
198 }
199
200 #[must_use]
202 pub fn text_content(&self) -> String {
203 self.text_parts()
204 .map(|p| p.content.as_str())
205 .collect::<Vec<_>>()
206 .join("")
207 }
208
209 #[must_use]
211 pub fn has_tool_calls(&self) -> bool {
212 self.parts
213 .iter()
214 .any(|p| matches!(p, ModelResponsePart::ToolCall(_)))
215 }
216
217 #[must_use]
219 pub fn is_empty(&self) -> bool {
220 self.parts.is_empty()
221 }
222
223 #[must_use]
225 pub fn len(&self) -> usize {
226 self.parts.len()
227 }
228}
229
230impl Default for ModelResponse {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236impl FromIterator<ModelResponsePart> for ModelResponse {
237 fn from_iter<T: IntoIterator<Item = ModelResponsePart>>(iter: T) -> Self {
238 Self::with_parts(iter.into_iter().collect())
239 }
240}
241
242#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244#[serde(tag = "part_kind", rename_all = "kebab-case")]
245pub enum ModelResponsePart {
246 Text(TextPart),
248 ToolCall(ToolCallPart),
250 Thinking(ThinkingPart),
252 File(FilePart),
254 BuiltinToolCall(BuiltinToolCallPart),
256}
257
258impl ModelResponsePart {
259 #[must_use]
261 pub fn text(content: impl Into<String>) -> Self {
262 Self::Text(TextPart::new(content))
263 }
264
265 #[must_use]
267 pub fn tool_call(
268 tool_name: impl Into<String>,
269 args: impl Into<super::parts::ToolCallArgs>,
270 ) -> Self {
271 Self::ToolCall(ToolCallPart::new(tool_name, args))
272 }
273
274 #[must_use]
276 pub fn thinking(content: impl Into<String>) -> Self {
277 Self::Thinking(ThinkingPart::new(content))
278 }
279
280 #[must_use]
282 pub fn file(data: Vec<u8>, media_type: impl Into<String>) -> Self {
283 Self::File(FilePart::from_bytes(data, media_type))
284 }
285
286 #[must_use]
288 pub fn builtin_tool_call(
289 tool_name: impl Into<String>,
290 args: impl Into<super::parts::ToolCallArgs>,
291 ) -> Self {
292 Self::BuiltinToolCall(BuiltinToolCallPart::new(tool_name, args))
293 }
294
295 #[must_use]
297 pub fn part_kind(&self) -> &'static str {
298 match self {
299 Self::Text(_) => TextPart::PART_KIND,
300 Self::ToolCall(_) => ToolCallPart::PART_KIND,
301 Self::Thinking(_) => ThinkingPart::PART_KIND,
302 Self::File(_) => FilePart::PART_KIND,
303 Self::BuiltinToolCall(_) => BuiltinToolCallPart::PART_KIND,
304 }
305 }
306
307 #[must_use]
309 pub fn is_text(&self) -> bool {
310 matches!(self, Self::Text(_))
311 }
312
313 #[must_use]
315 pub fn is_tool_call(&self) -> bool {
316 matches!(self, Self::ToolCall(_))
317 }
318
319 #[must_use]
321 pub fn is_thinking(&self) -> bool {
322 matches!(self, Self::Thinking(_))
323 }
324
325 #[must_use]
327 pub fn is_file(&self) -> bool {
328 matches!(self, Self::File(_))
329 }
330
331 #[must_use]
333 pub fn is_builtin_tool_call(&self) -> bool {
334 matches!(self, Self::BuiltinToolCall(_))
335 }
336}
337
338impl From<TextPart> for ModelResponsePart {
339 fn from(p: TextPart) -> Self {
340 Self::Text(p)
341 }
342}
343
344impl From<ToolCallPart> for ModelResponsePart {
345 fn from(p: ToolCallPart) -> Self {
346 Self::ToolCall(p)
347 }
348}
349
350impl From<ThinkingPart> for ModelResponsePart {
351 fn from(p: ThinkingPart) -> Self {
352 Self::Thinking(p)
353 }
354}
355
356impl From<FilePart> for ModelResponsePart {
357 fn from(p: FilePart) -> Self {
358 Self::File(p)
359 }
360}
361
362impl From<BuiltinToolCallPart> for ModelResponsePart {
363 fn from(p: BuiltinToolCallPart) -> Self {
364 Self::BuiltinToolCall(p)
365 }
366}
367
368#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
370#[serde(rename_all = "snake_case")]
371pub enum FinishReason {
372 Stop,
374 Length,
376 ContentFilter,
378 ToolCall,
380 Error,
382 EndTurn,
384 StopSequence,
386}
387
388impl FinishReason {
389 #[must_use]
391 pub fn is_complete(&self) -> bool {
392 matches!(self, Self::Stop | Self::EndTurn | Self::StopSequence)
393 }
394
395 #[must_use]
397 pub fn is_truncated(&self) -> bool {
398 matches!(self, Self::Length)
399 }
400
401 #[must_use]
403 pub fn is_tool_call(&self) -> bool {
404 matches!(self, Self::ToolCall)
405 }
406}
407
408impl std::fmt::Display for FinishReason {
409 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 match self {
411 Self::Stop => write!(f, "stop"),
412 Self::Length => write!(f, "length"),
413 Self::ContentFilter => write!(f, "content_filter"),
414 Self::ToolCall => write!(f, "tool_call"),
415 Self::Error => write!(f, "error"),
416 Self::EndTurn => write!(f, "end_turn"),
417 Self::StopSequence => write!(f, "stop_sequence"),
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed response has no parts and no tool calls.
    #[test]
    fn test_model_response_new() {
        let empty = ModelResponse::new();
        assert!(empty.is_empty());
        assert!(!empty.has_tool_calls());
    }

    /// `ModelResponse::text` yields one part whose text is the input.
    #[test]
    fn test_model_response_text() {
        let greeting = ModelResponse::text("Hello, world!");
        assert_eq!(greeting.len(), 1);
        assert_eq!(greeting.text_content(), "Hello, world!");
    }

    /// A response mixing text and one tool call reports exactly that call.
    #[test]
    fn test_model_response_with_tool_calls() {
        let parts = vec![
            ModelResponsePart::text("Let me check the weather."),
            ModelResponsePart::tool_call("get_weather", serde_json::json!({"city": "NYC"})),
        ];
        let with_call = ModelResponse::with_parts(parts);
        assert!(with_call.has_tool_calls());
        assert_eq!(with_call.tool_call_parts().count(), 1);
    }

    /// Each predicate accepts the variant it is named after.
    #[test]
    fn test_finish_reason() {
        assert!(FinishReason::Stop.is_complete());
        assert!(FinishReason::Length.is_truncated());
        assert!(FinishReason::ToolCall.is_tool_call());
    }

    /// Serializing to JSON and back keeps the part count and model name.
    #[test]
    fn test_serde_roundtrip() {
        let parts = vec![
            ModelResponsePart::text("Hello"),
            ModelResponsePart::thinking("Thinking..."),
        ];
        let original = ModelResponse::with_parts(parts)
            .with_model_name("gpt-4")
            .with_finish_reason(FinishReason::Stop);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ModelResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.len(), decoded.len());
        assert_eq!(original.model_name, decoded.model_name);
    }
}