Skip to main content

serdes_ai_core/messages/
request.rs

1//! Request message types for model interactions.
2//!
3//! This module defines the message types that are sent TO the model,
4//! including system prompts, user prompts, tool returns, and retry prompts.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8
9use super::content::UserContent;
10use super::parts::BuiltinToolReturnPart;
11use super::tool_return::ToolReturnContent;
12
/// A complete model request containing multiple parts.
///
/// The type derives serde's `Serialize`/`Deserialize`. When `kind` is absent
/// from the data being deserialized, it is filled in by `default_request_kind`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelRequest {
    /// The request parts.
    pub parts: Vec<ModelRequestPart>,
    /// Kind identifier; the constructors in this file set it to `"request"`.
    #[serde(default = "default_request_kind")]
    pub kind: String,
}
22
/// Supplies the `kind` value used when a serialized request omits the field.
fn default_request_kind() -> String {
    String::from("request")
}
26
27impl ModelRequest {
28    /// Create a new empty request.
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            parts: Vec::new(),
33            kind: "request".to_string(),
34        }
35    }
36
37    /// Create a request with the given parts.
38    #[must_use]
39    pub fn with_parts(parts: Vec<ModelRequestPart>) -> Self {
40        Self {
41            parts,
42            kind: "request".to_string(),
43        }
44    }
45
46    /// Add a part.
47    pub fn add_part(&mut self, part: ModelRequestPart) {
48        self.parts.push(part);
49    }
50
51    /// Add a system prompt.
52    pub fn add_system_prompt(&mut self, content: impl Into<String>) {
53        self.parts
54            .push(ModelRequestPart::SystemPrompt(SystemPromptPart::new(
55                content,
56            )));
57    }
58
59    /// Add a user prompt.
60    pub fn add_user_prompt(&mut self, content: impl Into<UserContent>) {
61        self.parts
62            .push(ModelRequestPart::UserPrompt(UserPromptPart::new(content)));
63    }
64
65    /// Get all system prompts.
66    pub fn system_prompts(&self) -> impl Iterator<Item = &SystemPromptPart> {
67        self.parts.iter().filter_map(|p| match p {
68            ModelRequestPart::SystemPrompt(s) => Some(s),
69            _ => None,
70        })
71    }
72
73    /// Get all user prompts.
74    pub fn user_prompts(&self) -> impl Iterator<Item = &UserPromptPart> {
75        self.parts.iter().filter_map(|p| match p {
76            ModelRequestPart::UserPrompt(u) => Some(u),
77            _ => None,
78        })
79    }
80
81    /// Get all tool returns.
82    pub fn tool_returns(&self) -> impl Iterator<Item = &ToolReturnPart> {
83        self.parts.iter().filter_map(|p| match p {
84            ModelRequestPart::ToolReturn(t) => Some(t),
85            _ => None,
86        })
87    }
88
89    /// Get all builtin tool returns.
90    pub fn builtin_tool_returns(&self) -> impl Iterator<Item = &BuiltinToolReturnPart> {
91        self.parts.iter().filter_map(|p| match p {
92            ModelRequestPart::BuiltinToolReturn(b) => Some(b),
93            _ => None,
94        })
95    }
96
97    /// Get all system prompts as a vector.
98    #[deprecated(note = "Use system_prompts() iterator instead")]
99    pub fn system_prompts_vec(&self) -> Vec<&SystemPromptPart> {
100        self.system_prompts().collect()
101    }
102
103    /// Get all user prompts as a vector.
104    #[deprecated(note = "Use user_prompts() iterator instead")]
105    pub fn user_prompts_vec(&self) -> Vec<&UserPromptPart> {
106        self.user_prompts().collect()
107    }
108
109    /// Get all tool returns as a vector.
110    #[deprecated(note = "Use tool_returns() iterator instead")]
111    pub fn tool_returns_vec(&self) -> Vec<&ToolReturnPart> {
112        self.tool_returns().collect()
113    }
114
115    /// Get all builtin tool returns as a vector.
116    #[deprecated(note = "Use builtin_tool_returns() iterator instead")]
117    pub fn builtin_tool_returns_vec(&self) -> Vec<&BuiltinToolReturnPart> {
118        self.builtin_tool_returns().collect()
119    }
120
121    /// Add a builtin tool return.
122    pub fn add_builtin_tool_return(&mut self, part: BuiltinToolReturnPart) {
123        self.parts.push(ModelRequestPart::BuiltinToolReturn(part));
124    }
125
126    /// Check if the request is empty.
127    #[must_use]
128    pub fn is_empty(&self) -> bool {
129        self.parts.is_empty()
130    }
131
132    /// Get the number of parts.
133    #[must_use]
134    pub fn len(&self) -> usize {
135        self.parts.len()
136    }
137}
138
139impl Default for ModelRequest {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl FromIterator<ModelRequestPart> for ModelRequest {
146    fn from_iter<T: IntoIterator<Item = ModelRequestPart>>(iter: T) -> Self {
147        Self::with_parts(iter.into_iter().collect())
148    }
149}
150
/// Individual parts of a model request.
///
/// The serde attributes select an internally tagged representation: the
/// variant name, converted to kebab-case, is stored under the `part_kind` key
/// alongside the fields of the wrapped part.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "part_kind", rename_all = "kebab-case")]
pub enum ModelRequestPart {
    /// System prompt.
    SystemPrompt(SystemPromptPart),
    /// User prompt.
    UserPrompt(UserPromptPart),
    /// Tool return.
    ToolReturn(ToolReturnPart),
    /// Retry prompt.
    RetryPrompt(RetryPromptPart),
    /// Builtin tool return (web search results, code execution output, etc.).
    BuiltinToolReturn(BuiltinToolReturnPart),
    /// Model response (for multi-turn conversations).
    /// This represents the assistant's previous response, which MUST be included
    /// when sending tool results to ensure proper user/assistant message alternation.
    /// The response is stored behind a `Box`.
    ModelResponse(Box<super::response::ModelResponse>),
}
170
171impl ModelRequestPart {
172    /// Get the timestamp of this part.
173    #[must_use]
174    pub fn timestamp(&self) -> DateTime<Utc> {
175        match self {
176            Self::SystemPrompt(p) => p.timestamp,
177            Self::UserPrompt(p) => p.timestamp,
178            Self::ToolReturn(p) => p.timestamp,
179            Self::RetryPrompt(p) => p.timestamp,
180            Self::BuiltinToolReturn(p) => p.timestamp,
181            Self::ModelResponse(r) => r.timestamp,
182        }
183    }
184
185    /// Get the part kind string.
186    #[must_use]
187    pub fn part_kind(&self) -> &'static str {
188        match self {
189            Self::SystemPrompt(_) => SystemPromptPart::PART_KIND,
190            Self::UserPrompt(_) => UserPromptPart::PART_KIND,
191            Self::ToolReturn(_) => ToolReturnPart::PART_KIND,
192            Self::RetryPrompt(_) => RetryPromptPart::PART_KIND,
193            Self::BuiltinToolReturn(_) => BuiltinToolReturnPart::PART_KIND,
194            Self::ModelResponse(_) => "model-response",
195        }
196    }
197
198    /// Check if this is a builtin tool return.
199    #[must_use]
200    pub fn is_builtin_tool_return(&self) -> bool {
201        matches!(self, Self::BuiltinToolReturn(_))
202    }
203
204    /// Check if this is a model response.
205    #[must_use]
206    pub fn is_model_response(&self) -> bool {
207        matches!(self, Self::ModelResponse(_))
208    }
209}
210
/// System prompt part.
///
/// Derives serde's `Serialize`/`Deserialize` in addition to `Eq`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SystemPromptPart {
    /// The system prompt content.
    pub content: String,
    /// When this part was created.
    pub timestamp: DateTime<Utc>,
    /// Reference to a dynamic prompt source.
    /// Left out of the serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dynamic_ref: Option<String>,
}
222
223impl SystemPromptPart {
224    /// Part kind identifier.
225    pub const PART_KIND: &'static str = "system-prompt";
226
227    /// Create a new system prompt part.
228    #[must_use]
229    pub fn new(content: impl Into<String>) -> Self {
230        Self {
231            content: content.into(),
232            timestamp: Utc::now(),
233            dynamic_ref: None,
234        }
235    }
236
237    /// Get the part kind.
238    #[must_use]
239    pub fn part_kind(&self) -> &'static str {
240        Self::PART_KIND
241    }
242
243    /// Set the dynamic reference.
244    #[must_use]
245    pub fn with_dynamic_ref(mut self, ref_name: impl Into<String>) -> Self {
246        self.dynamic_ref = Some(ref_name.into());
247        self
248    }
249
250    /// Set the timestamp.
251    #[must_use]
252    pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
253        self.timestamp = timestamp;
254        self
255    }
256}
257
258impl From<String> for SystemPromptPart {
259    fn from(s: String) -> Self {
260        Self::new(s)
261    }
262}
263
264impl From<&str> for SystemPromptPart {
265    fn from(s: &str) -> Self {
266        Self::new(s)
267    }
268}
269
/// User prompt part.
///
/// Pairs user-supplied content with the time the part was created.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UserPromptPart {
    /// The user prompt content.
    pub content: UserContent,
    /// When this part was created.
    pub timestamp: DateTime<Utc>,
}
278
279impl UserPromptPart {
280    /// Part kind identifier.
281    pub const PART_KIND: &'static str = "user-prompt";
282
283    /// Create a new user prompt part.
284    #[must_use]
285    pub fn new(content: impl Into<UserContent>) -> Self {
286        Self {
287            content: content.into(),
288            timestamp: Utc::now(),
289        }
290    }
291
292    /// Get the part kind.
293    #[must_use]
294    pub fn part_kind(&self) -> &'static str {
295        Self::PART_KIND
296    }
297
298    /// Set the timestamp.
299    #[must_use]
300    pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
301        self.timestamp = timestamp;
302        self
303    }
304
305    /// Get content as text if it's text content.
306    #[must_use]
307    pub fn as_text(&self) -> Option<&str> {
308        self.content.as_text()
309    }
310}
311
312impl From<String> for UserPromptPart {
313    fn from(s: String) -> Self {
314        Self::new(UserContent::text(s))
315    }
316}
317
318impl From<&str> for UserPromptPart {
319    fn from(s: &str) -> Self {
320        Self::new(UserContent::text(s))
321    }
322}
323
/// Tool return part.
///
/// Carries the content a tool produced, the tool's name, and optionally the
/// ID of the call being answered.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolReturnPart {
    /// Name of the tool.
    pub tool_name: String,
    /// The return content.
    pub content: ToolReturnContent,
    /// Tool call ID this is responding to.
    /// Left out of the serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// When this part was created.
    pub timestamp: DateTime<Utc>,
}
337
338impl ToolReturnPart {
339    /// Part kind identifier.
340    pub const PART_KIND: &'static str = "tool-return";
341
342    /// Create a new tool return part.
343    #[must_use]
344    pub fn new(tool_name: impl Into<String>, content: impl Into<ToolReturnContent>) -> Self {
345        Self {
346            tool_name: tool_name.into(),
347            content: content.into(),
348            tool_call_id: None,
349            timestamp: Utc::now(),
350        }
351    }
352
353    /// Get the part kind.
354    #[must_use]
355    pub fn part_kind(&self) -> &'static str {
356        Self::PART_KIND
357    }
358
359    /// Set the tool call ID.
360    #[must_use]
361    pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
362        self.tool_call_id = Some(id.into());
363        self
364    }
365
366    /// Set the timestamp.
367    #[must_use]
368    pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
369        self.timestamp = timestamp;
370        self
371    }
372
373    /// Create a success return.
374    #[must_use]
375    pub fn success(tool_name: impl Into<String>, content: impl Into<String>) -> Self {
376        Self::new(tool_name, ToolReturnContent::text(content))
377    }
378
379    /// Create an error return.
380    #[must_use]
381    pub fn error(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
382        Self::new(tool_name, ToolReturnContent::error(message))
383    }
384}
385
/// Retry content - either text or structured error info.
///
/// Marked `#[serde(untagged)]`, so no variant tag is written: `Text` is
/// represented by the string itself and `Structured` by its fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RetryContent {
    /// Plain text retry message.
    Text(String),
    /// Structured retry info.
    Structured {
        /// The error message.
        message: String,
        /// Optional validation errors.
        /// Left out of the serialized output when it is `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        errors: Option<Vec<String>>,
    },
}
401
402impl RetryContent {
403    /// Create text retry content.
404    #[must_use]
405    pub fn text(s: impl Into<String>) -> Self {
406        Self::Text(s.into())
407    }
408
409    /// Create structured retry content.
410    #[must_use]
411    pub fn structured(message: impl Into<String>, errors: Option<Vec<String>>) -> Self {
412        Self::Structured {
413            message: message.into(),
414            errors,
415        }
416    }
417
418    /// Get the message.
419    #[must_use]
420    pub fn message(&self) -> &str {
421        match self {
422            Self::Text(s) => s,
423            Self::Structured { message, .. } => message,
424        }
425    }
426}
427
428impl Default for RetryContent {
429    fn default() -> Self {
430        Self::Text(String::new())
431    }
432}
433
434impl From<String> for RetryContent {
435    fn from(s: String) -> Self {
436        Self::Text(s)
437    }
438}
439
440impl From<&str> for RetryContent {
441    fn from(s: &str) -> Self {
442        Self::Text(s.to_string())
443    }
444}
445
/// Retry prompt part.
///
/// Holds the retry content plus optional identification of the tool and tool
/// call the retry refers to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RetryPromptPart {
    /// The retry content.
    pub content: RetryContent,
    /// Tool name if this is a tool retry.
    /// Left out of the serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Tool call ID if this is a tool retry.
    /// Left out of the serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// When this part was created.
    pub timestamp: DateTime<Utc>,
}
460
461impl RetryPromptPart {
462    /// Part kind identifier.
463    pub const PART_KIND: &'static str = "retry-prompt";
464
465    /// Create a new retry prompt part.
466    #[must_use]
467    pub fn new(content: impl Into<RetryContent>) -> Self {
468        Self {
469            content: content.into(),
470            tool_name: None,
471            tool_call_id: None,
472            timestamp: Utc::now(),
473        }
474    }
475
476    /// Get the part kind.
477    #[must_use]
478    pub fn part_kind(&self) -> &'static str {
479        Self::PART_KIND
480    }
481
482    /// Set the tool name.
483    #[must_use]
484    pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
485        self.tool_name = Some(name.into());
486        self
487    }
488
489    /// Set the tool call ID.
490    #[must_use]
491    pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
492        self.tool_call_id = Some(id.into());
493        self
494    }
495
496    /// Set the timestamp.
497    #[must_use]
498    pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
499        self.timestamp = timestamp;
500        self
501    }
502
503    /// Create a tool retry.
504    #[must_use]
505    pub fn tool_retry(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
506        Self::new(message.into()).with_tool_name(tool_name)
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh request is empty; adding prompts grows it and the typed
    /// iterators see exactly the parts of their own variant.
    #[test]
    fn test_model_request_new() {
        let mut req = ModelRequest::new();
        assert!(req.is_empty());

        req.add_system_prompt("You are a helpful assistant.");
        req.add_user_prompt("Hello!");

        assert!(!req.is_empty());
        assert_eq!(req.len(), 2);
        assert_eq!(req.system_prompts().count(), 1);
        assert_eq!(req.user_prompts().count(), 1);
        assert_eq!(req.tool_returns().count(), 0);
    }

    /// The builder stores the dynamic reference and the kind string is fixed.
    #[test]
    fn test_system_prompt_part() {
        let part = SystemPromptPart::new("Be helpful").with_dynamic_ref("main_prompt");
        assert_eq!(part.content, "Be helpful");
        assert_eq!(part.dynamic_ref, Some("main_prompt".to_string()));
        assert_eq!(part.part_kind(), "system-prompt");
    }

    /// `success` keeps the tool name and the builder stores the call ID.
    #[test]
    fn test_tool_return_part() {
        let part =
            ToolReturnPart::success("get_weather", "72°F, sunny").with_tool_call_id("call_123");
        assert_eq!(part.tool_name, "get_weather");
        assert_eq!(part.tool_call_id, Some("call_123".to_string()));
        assert_eq!(part.part_kind(), "tool-return");
    }

    /// `tool_retry` records the tool name and message; the builder-supplied
    /// call ID must be stored as well.
    #[test]
    fn test_retry_prompt_part() {
        let part = RetryPromptPart::tool_retry("my_tool", "Invalid JSON").with_tool_call_id("id1");
        assert_eq!(part.tool_name, Some("my_tool".to_string()));
        assert_eq!(part.tool_call_id, Some("id1".to_string()));
        assert_eq!(part.content.message(), "Invalid JSON");
        assert_eq!(part.part_kind(), "retry-prompt");
    }

    /// A request survives a JSON round trip with the same number of parts.
    #[test]
    fn test_serde_roundtrip() {
        let req = ModelRequest::with_parts(vec![
            ModelRequestPart::SystemPrompt(SystemPromptPart::new("System")),
            ModelRequestPart::UserPrompt(UserPromptPart::new("User")),
        ]);
        let json = serde_json::to_string(&req).unwrap();
        let parsed: ModelRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req.len(), parsed.len());
    }

    /// A builtin tool return added to a request is visible through the
    /// dedicated iterator with its name and call ID intact.
    #[test]
    fn test_builtin_tool_return() {
        use crate::messages::parts::{BuiltinToolReturnContent, WebSearchResult, WebSearchResults};

        let results = WebSearchResults::new(
            "rust programming",
            vec![WebSearchResult::new("Rust", "https://rust-lang.org")],
        );
        let content = BuiltinToolReturnContent::web_search(results);
        let part = BuiltinToolReturnPart::new("web_search", content, "call_123");

        let mut req = ModelRequest::new();
        req.add_builtin_tool_return(part);

        assert_eq!(req.len(), 1);
        assert_eq!(req.builtin_tool_returns().count(), 1);

        let returns: Vec<_> = req.builtin_tool_returns().collect();
        assert_eq!(returns[0].tool_name, "web_search");
        assert_eq!(returns[0].tool_call_id, "call_123");
    }

    /// The variant predicates and kind string agree for a builtin tool return.
    #[test]
    fn test_model_request_part_is_builtin_tool_return() {
        use crate::messages::parts::{BuiltinToolReturnContent, CodeExecutionResult};

        let result = CodeExecutionResult::new("print(1)").with_stdout("1\n");
        let content = BuiltinToolReturnContent::code_execution(result);
        let part = BuiltinToolReturnPart::new("code_execution", content, "call_456");
        let request_part = ModelRequestPart::BuiltinToolReturn(part);

        assert!(request_part.is_builtin_tool_return());
        // The predicates are mutually exclusive across variants.
        assert!(!request_part.is_model_response());
        assert_eq!(request_part.part_kind(), "builtin-tool-return");
    }

    /// A request containing a builtin tool return survives a JSON round trip.
    #[test]
    fn test_serde_roundtrip_with_builtin_tool_return() {
        use crate::messages::parts::{
            BuiltinToolReturnContent, FileSearchResult, FileSearchResults,
        };

        let results = FileSearchResults::new(
            "main function",
            vec![FileSearchResult::new("main.rs", "fn main() {}")],
        );
        let content = BuiltinToolReturnContent::file_search(results);
        let part = BuiltinToolReturnPart::new("file_search", content, "call_789");

        let req = ModelRequest::with_parts(vec![
            ModelRequestPart::UserPrompt(UserPromptPart::new("Search files")),
            ModelRequestPart::BuiltinToolReturn(part),
        ]);

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ModelRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(req.len(), parsed.len());
        assert_eq!(parsed.builtin_tool_returns().count(), 1);
    }
}