Skip to main content

xai_rust/models/
response.rs

1//! Response types for the API.
2
3use serde::{Deserialize, Serialize};
4
5use super::message::Role;
6use super::tool::ToolCall;
7use super::usage::{ServerSideToolUsage, Usage};
8
9/// Response from the Responses API.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Response {
12    /// Unique identifier for this response.
13    pub id: String,
14    /// The model used for generation.
15    pub model: String,
16    /// Output items from the response.
17    pub output: Vec<OutputItem>,
18    /// Token usage statistics.
19    #[serde(default)]
20    pub usage: Usage,
21    /// Citations from search tools.
22    #[serde(default)]
23    pub citations: Option<Vec<String>>,
24    /// Inline citations in the response text.
25    #[serde(default)]
26    pub inline_citations: Option<Vec<InlineCitation>>,
27    /// Server-side tool usage statistics.
28    #[serde(default)]
29    pub server_side_tool_usage: Option<ServerSideToolUsage>,
30    /// Tool calls made during generation.
31    #[serde(default)]
32    pub tool_calls: Option<Vec<ToolCall>>,
33    /// System fingerprint for tracking.
34    #[serde(default)]
35    pub system_fingerprint: Option<String>,
36}
37
38impl Response {
39    /// Get the text content from the first message output.
40    pub fn output_text(&self) -> Option<String> {
41        self.output.iter().find_map(|item| {
42            if let OutputItem::Message { content, .. } = item {
43                content.iter().find_map(|c| {
44                    if let TextContent::Text { text } = c {
45                        Some(text.clone())
46                    } else {
47                        None
48                    }
49                })
50            } else {
51                None
52            }
53        })
54    }
55
56    /// Get all text content from the response.
57    pub fn all_text(&self) -> String {
58        self.output
59            .iter()
60            .filter_map(|item| {
61                if let OutputItem::Message { content, .. } = item {
62                    Some(
63                        content
64                            .iter()
65                            .filter_map(|c| {
66                                if let TextContent::Text { text } = c {
67                                    Some(text.as_str())
68                                } else {
69                                    None
70                                }
71                            })
72                            .collect::<Vec<_>>()
73                            .join(""),
74                    )
75                } else {
76                    None
77                }
78            })
79            .collect::<Vec<_>>()
80            .join("")
81    }
82}
83
84/// Output item in a response.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(tag = "type", rename_all = "snake_case")]
87pub enum OutputItem {
88    /// A message from the assistant.
89    Message {
90        /// The role (always assistant for output).
91        role: Role,
92        /// The content of the message.
93        content: Vec<TextContent>,
94    },
95    /// A tool call.
96    #[serde(rename = "function_call")]
97    FunctionCall {
98        /// The function call details.
99        #[serde(flatten)]
100        call: ToolCall,
101    },
102    /// Code interpreter call.
103    CodeInterpreterCall {
104        /// The code interpreter call ID.
105        id: String,
106        /// The code that was executed.
107        code: Option<String>,
108        /// The outputs from code execution.
109        outputs: Option<Vec<CodeInterpreterOutput>>,
110    },
111    /// Web search call.
112    WebSearchCall {
113        /// The search call ID.
114        id: String,
115        /// Search results.
116        results: Option<Vec<SearchResult>>,
117    },
118    /// X search call.
119    XSearchCall {
120        /// The search call ID.
121        id: String,
122        /// Search results.
123        results: Option<Vec<SearchResult>>,
124    },
125}
126
127/// Text content in a message.
128#[derive(Debug, Clone, Serialize, Deserialize)]
129#[serde(tag = "type", rename_all = "snake_case")]
130pub enum TextContent {
131    /// Plain text.
132    Text {
133        /// The text content.
134        text: String,
135    },
136    /// Refusal message.
137    Refusal {
138        /// The refusal reason.
139        refusal: String,
140    },
141}
142
143/// Code interpreter output.
144#[derive(Debug, Clone, Serialize, Deserialize)]
145#[serde(tag = "type", rename_all = "snake_case")]
146pub enum CodeInterpreterOutput {
147    /// Log output.
148    Logs {
149        /// The log content.
150        logs: String,
151    },
152    /// Image output.
153    Image {
154        /// The image data (base64).
155        image: String,
156    },
157}
158
159/// Search result from web/X search.
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct SearchResult {
162    /// The title of the result.
163    pub title: Option<String>,
164    /// The URL of the result.
165    pub url: Option<String>,
166    /// A snippet from the result.
167    pub snippet: Option<String>,
168}
169
170/// Inline citation in response text.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct InlineCitation {
173    /// Citation ID (e.g., "1", "2").
174    pub id: String,
175    /// Start position in the text.
176    #[serde(default)]
177    pub start_index: Option<u32>,
178    /// End position in the text.
179    #[serde(default)]
180    pub end_index: Option<u32>,
181    /// Web citation details.
182    #[serde(default)]
183    pub web_citation: Option<WebCitation>,
184    /// X citation details.
185    #[serde(default)]
186    pub x_citation: Option<XCitation>,
187}
188
189/// Web citation details.
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct WebCitation {
192    /// The URL of the source.
193    pub url: String,
194    /// The title of the source.
195    #[serde(default)]
196    pub title: Option<String>,
197}
198
199/// X/Twitter citation details.
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct XCitation {
202    /// The URL of the X post.
203    pub url: String,
204    /// The author's handle.
205    #[serde(default)]
206    pub author_handle: Option<String>,
207}
208
209/// Response format specification.
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct ResponseFormat {
212    /// The type of response format.
213    #[serde(rename = "type")]
214    pub format_type: ResponseFormatType,
215    /// JSON schema for structured output.
216    #[serde(skip_serializing_if = "Option::is_none")]
217    pub json_schema: Option<JsonSchema>,
218}
219
220impl ResponseFormat {
221    /// Create a text response format.
222    pub fn text() -> Self {
223        Self {
224            format_type: ResponseFormatType::Text,
225            json_schema: None,
226        }
227    }
228
229    /// Create a JSON object response format.
230    pub fn json_object() -> Self {
231        Self {
232            format_type: ResponseFormatType::JsonObject,
233            json_schema: None,
234        }
235    }
236
237    /// Create a JSON schema response format.
238    pub fn json_schema(name: impl Into<String>, schema: serde_json::Value) -> Self {
239        Self {
240            format_type: ResponseFormatType::JsonSchema,
241            json_schema: Some(JsonSchema {
242                name: name.into(),
243                schema,
244                strict: Some(true),
245            }),
246        }
247    }
248}
249
250/// Response format type.
251#[derive(Debug, Clone, Serialize, Deserialize)]
252#[serde(rename_all = "snake_case")]
253pub enum ResponseFormatType {
254    /// Plain text response.
255    Text,
256    /// JSON object response.
257    JsonObject,
258    /// Structured JSON with schema.
259    JsonSchema,
260}
261
262/// JSON schema specification for structured output.
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct JsonSchema {
265    /// Name of the schema.
266    pub name: String,
267    /// The JSON schema definition.
268    pub schema: serde_json::Value,
269    /// Whether to strictly enforce the schema.
270    #[serde(skip_serializing_if = "Option::is_none")]
271    pub strict: Option<bool>,
272}
273
274/// A chunk from a streaming response.
275#[derive(Debug, Clone)]
276pub struct StreamChunk {
277    /// The delta text content.
278    pub delta: Option<String>,
279    /// Reasoning content delta (for reasoning models).
280    pub reasoning_delta: Option<String>,
281    /// Tool calls in this chunk.
282    pub tool_calls: Vec<ToolCall>,
283    /// Whether this is the final chunk.
284    pub done: bool,
285    /// The full response (only present in the final chunk).
286    pub response: Option<Response>,
287}
288
289impl StreamChunk {
290    /// Get the delta text.
291    pub fn delta(&self) -> &str {
292        self.delta.as_deref().unwrap_or("")
293    }
294
295    /// Check if there is content in this chunk.
296    pub fn has_content(&self) -> bool {
297        self.delta.is_some() || self.reasoning_delta.is_some() || !self.tool_calls.is_empty()
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── Shared fixtures ─────────────────────────────────────────────────

    /// Builds a `Response` around `output` with placeholder id/model and every
    /// optional field unset, so each test spells out only what it checks.
    fn response_with(output: Vec<OutputItem>) -> Response {
        Response {
            id: "r".to_string(),
            model: "m".to_string(),
            output,
            usage: Usage::default(),
            citations: None,
            inline_citations: None,
            server_side_tool_usage: None,
            tool_calls: None,
            system_fingerprint: None,
        }
    }

    /// An assistant message holding a single plain-text part.
    fn assistant_text(text: &str) -> OutputItem {
        OutputItem::Message {
            role: Role::Assistant,
            content: vec![TextContent::Text {
                text: text.to_string(),
            }],
        }
    }

    /// A non-final stream chunk with no payload at all.
    fn empty_chunk() -> StreamChunk {
        StreamChunk {
            delta: None,
            reasoning_delta: None,
            tool_calls: vec![],
            done: false,
            response: None,
        }
    }

    // ── OutputItem serde roundtrips ─────────────────────────────────────

    #[test]
    fn output_item_message_roundtrip() {
        let item = assistant_text("Hello!");
        let json_val = serde_json::to_value(&item).unwrap();
        assert_eq!(json_val["type"], "message");
        assert_eq!(json_val["role"], "assistant");
        assert_eq!(json_val["content"][0]["type"], "text");
        assert_eq!(json_val["content"][0]["text"], "Hello!");

        let back: OutputItem = serde_json::from_value(json_val).unwrap();
        if let OutputItem::Message { role, content } = back {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 1);
        } else {
            panic!("Expected Message variant");
        }
    }

    #[test]
    fn output_item_function_call_roundtrip() {
        // With call_type = None, ToolCall writes no "type" key of its own, so
        // the enum tag survives the flatten.
        let item = OutputItem::FunctionCall {
            call: ToolCall {
                id: "call_1".to_string(),
                call_type: None,
                function: Some(crate::models::tool::FunctionCall {
                    name: "test_fn".to_string(),
                    arguments: "{}".to_string(),
                }),
            },
        };
        let json_val = serde_json::to_value(&item).unwrap();
        assert_eq!(json_val["type"], "function_call");
        assert_eq!(json_val["id"], "call_1");
        assert_eq!(json_val["function"]["name"], "test_fn");

        let back: OutputItem = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, OutputItem::FunctionCall { .. }));
    }

    #[test]
    fn output_item_function_call_deserialize_from_api() {
        // Typical API response shape for a function call output item.
        let api_json = json!({
            "type": "function_call",
            "id": "call_abc",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"NYC\"}"
            }
        });
        let item: OutputItem = serde_json::from_value(api_json).unwrap();
        if let OutputItem::FunctionCall { call } = item {
            assert_eq!(call.id, "call_abc");
            assert_eq!(call.function.unwrap().name, "get_weather");
        } else {
            panic!("Expected FunctionCall variant");
        }
    }

    #[test]
    fn output_item_code_interpreter_call_roundtrip() {
        let item = OutputItem::CodeInterpreterCall {
            id: "ci_1".to_string(),
            code: Some("print('hi')".to_string()),
            outputs: Some(vec![CodeInterpreterOutput::Logs {
                logs: "hi".to_string(),
            }]),
        };
        let json_val = serde_json::to_value(&item).unwrap();
        assert_eq!(json_val["type"], "code_interpreter_call");
        assert_eq!(json_val["code"], "print('hi')");

        let back: OutputItem = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, OutputItem::CodeInterpreterCall { .. }));
    }

    #[test]
    fn output_item_web_search_call_roundtrip() {
        let item = OutputItem::WebSearchCall {
            id: "ws_1".to_string(),
            results: Some(vec![SearchResult {
                title: Some("Example".to_string()),
                url: Some("https://example.com".to_string()),
                snippet: Some("A snippet".to_string()),
            }]),
        };
        let json_val = serde_json::to_value(&item).unwrap();
        assert_eq!(json_val["type"], "web_search_call");

        let back: OutputItem = serde_json::from_value(json_val).unwrap();
        if let OutputItem::WebSearchCall { results, .. } = back {
            assert_eq!(results.unwrap()[0].title.as_deref(), Some("Example"));
        } else {
            panic!("Expected WebSearchCall variant");
        }
    }

    #[test]
    fn output_item_x_search_call_roundtrip() {
        let item = OutputItem::XSearchCall {
            id: "xs_1".to_string(),
            results: None,
        };
        let json_val = serde_json::to_value(&item).unwrap();
        assert_eq!(json_val["type"], "x_search_call");

        let back: OutputItem = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, OutputItem::XSearchCall { .. }));
    }

    // ── TextContent serde roundtrips ───────────────────────────────────

    #[test]
    fn text_content_text_roundtrip() {
        let tc = TextContent::Text {
            text: "Some text".to_string(),
        };
        let json_val = serde_json::to_value(&tc).unwrap();
        assert_eq!(json_val["type"], "text");
        assert_eq!(json_val["text"], "Some text");

        let back: TextContent = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, TextContent::Text { .. }));
    }

    #[test]
    fn text_content_refusal_roundtrip() {
        let tc = TextContent::Refusal {
            refusal: "I cannot help with that".to_string(),
        };
        let json_val = serde_json::to_value(&tc).unwrap();
        assert_eq!(json_val["type"], "refusal");
        assert_eq!(json_val["refusal"], "I cannot help with that");

        let back: TextContent = serde_json::from_value(json_val).unwrap();
        if let TextContent::Refusal { refusal } = back {
            assert_eq!(refusal, "I cannot help with that");
        } else {
            panic!("Expected Refusal variant");
        }
    }

    // ── Citation deserialization ───────────────────────────────────────

    #[test]
    fn inline_citation_deserialize_with_web_source() {
        let citation: InlineCitation = serde_json::from_value(json!({
            "id": "1",
            "start_index": 0,
            "end_index": 5,
            "web_citation": {"url": "https://example.com"}
        }))
        .unwrap();
        assert_eq!(citation.id, "1");
        assert_eq!(citation.start_index, Some(0));
        assert_eq!(citation.end_index, Some(5));
        assert!(citation.x_citation.is_none());

        let web = citation.web_citation.unwrap();
        assert_eq!(web.url, "https://example.com");
        assert!(web.title.is_none());
    }

    // ── ResponseFormat serde roundtrips ─────────────────────────────────

    #[test]
    fn response_format_text_roundtrip() {
        let rf = ResponseFormat::text();
        let json_val = serde_json::to_value(&rf).unwrap();
        assert_eq!(json_val["type"], "text");
        assert!(json_val.get("json_schema").is_none());

        let back: ResponseFormat = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back.format_type, ResponseFormatType::Text));
    }

    #[test]
    fn response_format_json_object_roundtrip() {
        let rf = ResponseFormat::json_object();
        let json_val = serde_json::to_value(&rf).unwrap();
        assert_eq!(json_val["type"], "json_object");

        let back: ResponseFormat = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back.format_type, ResponseFormatType::JsonObject));
    }

    #[test]
    fn response_format_json_schema_roundtrip() {
        let rf = ResponseFormat::json_schema(
            "my_schema",
            json!({"type": "object", "properties": {"x": {"type": "integer"}}}),
        );
        let json_val = serde_json::to_value(&rf).unwrap();
        assert_eq!(json_val["type"], "json_schema");
        assert_eq!(json_val["json_schema"]["name"], "my_schema");
        assert_eq!(json_val["json_schema"]["strict"], true);

        let back: ResponseFormat = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back.format_type, ResponseFormatType::JsonSchema));
        let schema = back.json_schema.unwrap();
        assert_eq!(schema.name, "my_schema");
        assert_eq!(schema.strict, Some(true));
    }

    // ── ResponseFormatType serde roundtrip ──────────────────────────────

    #[test]
    fn response_format_type_roundtrip_all() {
        for (variant, expected_str) in [
            (ResponseFormatType::Text, "text"),
            (ResponseFormatType::JsonObject, "json_object"),
            (ResponseFormatType::JsonSchema, "json_schema"),
        ] {
            let json_val = serde_json::to_value(&variant).unwrap();
            assert_eq!(json_val, json!(expected_str));

            let back: ResponseFormatType = serde_json::from_value(json_val).unwrap();
            // Compared via the serialized form so this test does not rely on
            // ResponseFormatType implementing PartialEq.
            let back_str = serde_json::to_value(&back).unwrap();
            assert_eq!(back_str, json!(expected_str));
        }
    }

    // ── Response serde roundtrip and text accessors ─────────────────────

    #[test]
    fn response_minimal_roundtrip() {
        let resp = Response {
            id: "resp_1".to_string(),
            model: "grok-4".to_string(),
            ..response_with(vec![assistant_text("Hi")])
        };

        let json_val = serde_json::to_value(&resp).unwrap();
        assert_eq!(json_val["id"], "resp_1");
        assert_eq!(json_val["model"], "grok-4");

        let back: Response = serde_json::from_value(json_val).unwrap();
        assert_eq!(back.id, "resp_1");
        assert_eq!(back.model, "grok-4");
        assert_eq!(back.output_text().unwrap(), "Hi");
    }

    #[test]
    fn response_output_text_returns_first_text() {
        let resp = response_with(vec![OutputItem::Message {
            role: Role::Assistant,
            content: vec![
                TextContent::Refusal {
                    refusal: "no".to_string(),
                },
                TextContent::Text {
                    text: "yes".to_string(),
                },
            ],
        }]);
        // output_text returns the first TextContent::Text, skipping refusals.
        assert_eq!(resp.output_text(), Some("yes".to_string()));
    }

    #[test]
    fn response_output_text_skips_items_without_text() {
        let resp = response_with(vec![
            OutputItem::WebSearchCall {
                id: "ws".to_string(),
                results: None,
            },
            OutputItem::Message {
                role: Role::Assistant,
                content: vec![TextContent::Refusal {
                    refusal: "no".to_string(),
                }],
            },
            assistant_text("later"),
        ]);
        // Neither the search call nor the refusal-only message stops the scan.
        assert_eq!(resp.output_text(), Some("later".to_string()));
        assert_eq!(resp.all_text(), "later");
    }

    #[test]
    fn response_text_accessors_on_empty_output() {
        let resp = response_with(vec![]);
        assert_eq!(resp.output_text(), None);
        assert_eq!(resp.all_text(), "");
    }

    #[test]
    fn response_all_text_joins() {
        let resp = response_with(vec![assistant_text("Hello "), assistant_text("World")]);
        assert_eq!(resp.all_text(), "Hello World");
    }

    #[test]
    fn response_deserialize_from_api_like_json() {
        let json_val = json!({
            "id": "resp_abc",
            "model": "grok-4",
            "output": [{
                "type": "message",
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": "The answer is 42."
                }]
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 20,
                "total_tokens": 30
            },
            "system_fingerprint": "fp_abc123"
        });

        let resp: Response = serde_json::from_value(json_val).unwrap();
        assert_eq!(resp.id, "resp_abc");
        assert_eq!(resp.usage.prompt_tokens, 10);
        assert_eq!(resp.usage.completion_tokens, 20);
        assert_eq!(resp.system_fingerprint.as_deref(), Some("fp_abc123"));
        assert_eq!(resp.output_text().unwrap(), "The answer is 42.");
    }

    #[test]
    fn response_deserialize_without_optional_fields() {
        // Only id, model and output are required; the rest must default.
        let resp: Response = serde_json::from_value(json!({
            "id": "resp_min",
            "model": "grok-4",
            "output": []
        }))
        .unwrap();
        assert_eq!(resp.id, "resp_min");
        assert!(resp.output.is_empty());
        assert!(resp.citations.is_none());
        assert!(resp.inline_citations.is_none());
        assert!(resp.server_side_tool_usage.is_none());
        assert!(resp.tool_calls.is_none());
        assert!(resp.system_fingerprint.is_none());
    }

    // ── StreamChunk ────────────────────────────────────────────────────

    #[test]
    fn stream_chunk_delta_returns_empty_when_none() {
        let chunk = empty_chunk();
        assert_eq!(chunk.delta(), "");
        assert!(!chunk.has_content());
    }

    #[test]
    fn stream_chunk_has_content_with_delta() {
        let chunk = StreamChunk {
            delta: Some("hello".to_string()),
            ..empty_chunk()
        };
        assert!(chunk.has_content());
        assert_eq!(chunk.delta(), "hello");
    }

    #[test]
    fn stream_chunk_has_content_with_reasoning() {
        let chunk = StreamChunk {
            reasoning_delta: Some("thinking...".to_string()),
            ..empty_chunk()
        };
        assert!(chunk.has_content());
    }

    #[test]
    fn stream_chunk_has_content_with_tool_calls() {
        let chunk = StreamChunk {
            tool_calls: vec![ToolCall {
                id: "c1".to_string(),
                call_type: None,
                function: None,
            }],
            ..empty_chunk()
        };
        assert!(chunk.has_content());
    }
}