1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8
9use super::content::UserContent;
10use super::parts::BuiltinToolReturnPart;
11use super::tool_return::ToolReturnContent;
12
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct ModelRequest {
16 pub parts: Vec<ModelRequestPart>,
18 #[serde(default = "default_request_kind")]
20 pub kind: String,
21}
22
23fn default_request_kind() -> String {
24 "request".to_string()
25}
26
27impl ModelRequest {
28 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 parts: Vec::new(),
33 kind: "request".to_string(),
34 }
35 }
36
37 #[must_use]
39 pub fn with_parts(parts: Vec<ModelRequestPart>) -> Self {
40 Self {
41 parts,
42 kind: "request".to_string(),
43 }
44 }
45
46 pub fn add_part(&mut self, part: ModelRequestPart) {
48 self.parts.push(part);
49 }
50
51 pub fn add_system_prompt(&mut self, content: impl Into<String>) {
53 self.parts
54 .push(ModelRequestPart::SystemPrompt(SystemPromptPart::new(
55 content,
56 )));
57 }
58
59 pub fn add_user_prompt(&mut self, content: impl Into<UserContent>) {
61 self.parts
62 .push(ModelRequestPart::UserPrompt(UserPromptPart::new(content)));
63 }
64
65 pub fn system_prompts(&self) -> impl Iterator<Item = &SystemPromptPart> {
67 self.parts.iter().filter_map(|p| match p {
68 ModelRequestPart::SystemPrompt(s) => Some(s),
69 _ => None,
70 })
71 }
72
73 pub fn user_prompts(&self) -> impl Iterator<Item = &UserPromptPart> {
75 self.parts.iter().filter_map(|p| match p {
76 ModelRequestPart::UserPrompt(u) => Some(u),
77 _ => None,
78 })
79 }
80
81 pub fn tool_returns(&self) -> impl Iterator<Item = &ToolReturnPart> {
83 self.parts.iter().filter_map(|p| match p {
84 ModelRequestPart::ToolReturn(t) => Some(t),
85 _ => None,
86 })
87 }
88
89 pub fn builtin_tool_returns(&self) -> impl Iterator<Item = &BuiltinToolReturnPart> {
91 self.parts.iter().filter_map(|p| match p {
92 ModelRequestPart::BuiltinToolReturn(b) => Some(b),
93 _ => None,
94 })
95 }
96
97 #[deprecated(note = "Use system_prompts() iterator instead")]
99 pub fn system_prompts_vec(&self) -> Vec<&SystemPromptPart> {
100 self.system_prompts().collect()
101 }
102
103 #[deprecated(note = "Use user_prompts() iterator instead")]
105 pub fn user_prompts_vec(&self) -> Vec<&UserPromptPart> {
106 self.user_prompts().collect()
107 }
108
109 #[deprecated(note = "Use tool_returns() iterator instead")]
111 pub fn tool_returns_vec(&self) -> Vec<&ToolReturnPart> {
112 self.tool_returns().collect()
113 }
114
115 #[deprecated(note = "Use builtin_tool_returns() iterator instead")]
117 pub fn builtin_tool_returns_vec(&self) -> Vec<&BuiltinToolReturnPart> {
118 self.builtin_tool_returns().collect()
119 }
120
121 pub fn add_builtin_tool_return(&mut self, part: BuiltinToolReturnPart) {
123 self.parts.push(ModelRequestPart::BuiltinToolReturn(part));
124 }
125
126 #[must_use]
128 pub fn is_empty(&self) -> bool {
129 self.parts.is_empty()
130 }
131
132 #[must_use]
134 pub fn len(&self) -> usize {
135 self.parts.len()
136 }
137}
138
139impl Default for ModelRequest {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl FromIterator<ModelRequestPart> for ModelRequest {
146 fn from_iter<T: IntoIterator<Item = ModelRequestPart>>(iter: T) -> Self {
147 Self::with_parts(iter.into_iter().collect())
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
153#[serde(tag = "part_kind", rename_all = "kebab-case")]
154pub enum ModelRequestPart {
155 SystemPrompt(SystemPromptPart),
157 UserPrompt(UserPromptPart),
159 ToolReturn(ToolReturnPart),
161 RetryPrompt(RetryPromptPart),
163 BuiltinToolReturn(BuiltinToolReturnPart),
165 ModelResponse(Box<super::response::ModelResponse>),
169}
170
171impl ModelRequestPart {
172 #[must_use]
174 pub fn timestamp(&self) -> DateTime<Utc> {
175 match self {
176 Self::SystemPrompt(p) => p.timestamp,
177 Self::UserPrompt(p) => p.timestamp,
178 Self::ToolReturn(p) => p.timestamp,
179 Self::RetryPrompt(p) => p.timestamp,
180 Self::BuiltinToolReturn(p) => p.timestamp,
181 Self::ModelResponse(r) => r.timestamp,
182 }
183 }
184
185 #[must_use]
187 pub fn part_kind(&self) -> &'static str {
188 match self {
189 Self::SystemPrompt(_) => SystemPromptPart::PART_KIND,
190 Self::UserPrompt(_) => UserPromptPart::PART_KIND,
191 Self::ToolReturn(_) => ToolReturnPart::PART_KIND,
192 Self::RetryPrompt(_) => RetryPromptPart::PART_KIND,
193 Self::BuiltinToolReturn(_) => BuiltinToolReturnPart::PART_KIND,
194 Self::ModelResponse(_) => "model-response",
195 }
196 }
197
198 #[must_use]
200 pub fn is_builtin_tool_return(&self) -> bool {
201 matches!(self, Self::BuiltinToolReturn(_))
202 }
203
204 #[must_use]
206 pub fn is_model_response(&self) -> bool {
207 matches!(self, Self::ModelResponse(_))
208 }
209}
210
211#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
213pub struct SystemPromptPart {
214 pub content: String,
216 pub timestamp: DateTime<Utc>,
218 #[serde(skip_serializing_if = "Option::is_none")]
220 pub dynamic_ref: Option<String>,
221}
222
223impl SystemPromptPart {
224 pub const PART_KIND: &'static str = "system-prompt";
226
227 #[must_use]
229 pub fn new(content: impl Into<String>) -> Self {
230 Self {
231 content: content.into(),
232 timestamp: Utc::now(),
233 dynamic_ref: None,
234 }
235 }
236
237 #[must_use]
239 pub fn part_kind(&self) -> &'static str {
240 Self::PART_KIND
241 }
242
243 #[must_use]
245 pub fn with_dynamic_ref(mut self, ref_name: impl Into<String>) -> Self {
246 self.dynamic_ref = Some(ref_name.into());
247 self
248 }
249
250 #[must_use]
252 pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
253 self.timestamp = timestamp;
254 self
255 }
256}
257
258impl From<String> for SystemPromptPart {
259 fn from(s: String) -> Self {
260 Self::new(s)
261 }
262}
263
264impl From<&str> for SystemPromptPart {
265 fn from(s: &str) -> Self {
266 Self::new(s)
267 }
268}
269
270#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
272pub struct UserPromptPart {
273 pub content: UserContent,
275 pub timestamp: DateTime<Utc>,
277}
278
279impl UserPromptPart {
280 pub const PART_KIND: &'static str = "user-prompt";
282
283 #[must_use]
285 pub fn new(content: impl Into<UserContent>) -> Self {
286 Self {
287 content: content.into(),
288 timestamp: Utc::now(),
289 }
290 }
291
292 #[must_use]
294 pub fn part_kind(&self) -> &'static str {
295 Self::PART_KIND
296 }
297
298 #[must_use]
300 pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
301 self.timestamp = timestamp;
302 self
303 }
304
305 #[must_use]
307 pub fn as_text(&self) -> Option<&str> {
308 self.content.as_text()
309 }
310}
311
312impl From<String> for UserPromptPart {
313 fn from(s: String) -> Self {
314 Self::new(UserContent::text(s))
315 }
316}
317
318impl From<&str> for UserPromptPart {
319 fn from(s: &str) -> Self {
320 Self::new(UserContent::text(s))
321 }
322}
323
324#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
326pub struct ToolReturnPart {
327 pub tool_name: String,
329 pub content: ToolReturnContent,
331 #[serde(skip_serializing_if = "Option::is_none")]
333 pub tool_call_id: Option<String>,
334 pub timestamp: DateTime<Utc>,
336}
337
338impl ToolReturnPart {
339 pub const PART_KIND: &'static str = "tool-return";
341
342 #[must_use]
344 pub fn new(tool_name: impl Into<String>, content: impl Into<ToolReturnContent>) -> Self {
345 Self {
346 tool_name: tool_name.into(),
347 content: content.into(),
348 tool_call_id: None,
349 timestamp: Utc::now(),
350 }
351 }
352
353 #[must_use]
355 pub fn part_kind(&self) -> &'static str {
356 Self::PART_KIND
357 }
358
359 #[must_use]
361 pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
362 self.tool_call_id = Some(id.into());
363 self
364 }
365
366 #[must_use]
368 pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
369 self.timestamp = timestamp;
370 self
371 }
372
373 #[must_use]
375 pub fn success(tool_name: impl Into<String>, content: impl Into<String>) -> Self {
376 Self::new(tool_name, ToolReturnContent::text(content))
377 }
378
379 #[must_use]
381 pub fn error(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
382 Self::new(tool_name, ToolReturnContent::error(message))
383 }
384}
385
386#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
388#[serde(untagged)]
389pub enum RetryContent {
390 Text(String),
392 Structured {
394 message: String,
396 #[serde(skip_serializing_if = "Option::is_none")]
398 errors: Option<Vec<String>>,
399 },
400}
401
402impl RetryContent {
403 #[must_use]
405 pub fn text(s: impl Into<String>) -> Self {
406 Self::Text(s.into())
407 }
408
409 #[must_use]
411 pub fn structured(message: impl Into<String>, errors: Option<Vec<String>>) -> Self {
412 Self::Structured {
413 message: message.into(),
414 errors,
415 }
416 }
417
418 #[must_use]
420 pub fn message(&self) -> &str {
421 match self {
422 Self::Text(s) => s,
423 Self::Structured { message, .. } => message,
424 }
425 }
426}
427
428impl Default for RetryContent {
429 fn default() -> Self {
430 Self::Text(String::new())
431 }
432}
433
434impl From<String> for RetryContent {
435 fn from(s: String) -> Self {
436 Self::Text(s)
437 }
438}
439
440impl From<&str> for RetryContent {
441 fn from(s: &str) -> Self {
442 Self::Text(s.to_string())
443 }
444}
445
446#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
448pub struct RetryPromptPart {
449 pub content: RetryContent,
451 #[serde(skip_serializing_if = "Option::is_none")]
453 pub tool_name: Option<String>,
454 #[serde(skip_serializing_if = "Option::is_none")]
456 pub tool_call_id: Option<String>,
457 pub timestamp: DateTime<Utc>,
459}
460
461impl RetryPromptPart {
462 pub const PART_KIND: &'static str = "retry-prompt";
464
465 #[must_use]
467 pub fn new(content: impl Into<RetryContent>) -> Self {
468 Self {
469 content: content.into(),
470 tool_name: None,
471 tool_call_id: None,
472 timestamp: Utc::now(),
473 }
474 }
475
476 #[must_use]
478 pub fn part_kind(&self) -> &'static str {
479 Self::PART_KIND
480 }
481
482 #[must_use]
484 pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
485 self.tool_name = Some(name.into());
486 self
487 }
488
489 #[must_use]
491 pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
492 self.tool_call_id = Some(id.into());
493 self
494 }
495
496 #[must_use]
498 pub fn with_timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
499 self.timestamp = timestamp;
500 self
501 }
502
503 #[must_use]
505 pub fn tool_retry(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
506 Self::new(message.into()).with_tool_name(tool_name)
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_request_new() {
        let mut req = ModelRequest::new();
        assert!(req.is_empty());
        assert_eq!(req.kind, "request");

        req.add_system_prompt("You are a helpful assistant.");
        req.add_user_prompt("Hello!");

        assert_eq!(req.len(), 2);
        assert_eq!(req.system_prompts().count(), 1);
        assert_eq!(req.user_prompts().count(), 1);
        // Parts keep insertion order.
        assert_eq!(req.parts[0].part_kind(), "system-prompt");
        assert_eq!(req.parts[1].part_kind(), "user-prompt");
    }

    #[test]
    fn test_default_and_from_iterator() {
        assert_eq!(ModelRequest::default(), ModelRequest::new());

        let req: ModelRequest = vec![ModelRequestPart::SystemPrompt(SystemPromptPart::new("a"))]
            .into_iter()
            .collect();
        assert_eq!(req.len(), 1);
        assert_eq!(req.kind, "request");
    }

    #[test]
    fn test_system_prompt_part() {
        let part = SystemPromptPart::new("Be helpful").with_dynamic_ref("main_prompt");
        assert_eq!(part.content, "Be helpful");
        assert_eq!(part.dynamic_ref, Some("main_prompt".to_string()));
        assert_eq!(part.part_kind(), "system-prompt");
    }

    #[test]
    fn test_tool_return_part() {
        let part =
            ToolReturnPart::success("get_weather", "72°F, sunny").with_tool_call_id("call_123");
        assert_eq!(part.tool_name, "get_weather");
        assert_eq!(part.tool_call_id, Some("call_123".to_string()));
        assert_eq!(part.part_kind(), "tool-return");
    }

    #[test]
    fn test_retry_prompt_part() {
        let part = RetryPromptPart::tool_retry("my_tool", "Invalid JSON").with_tool_call_id("id1");
        assert_eq!(part.tool_name, Some("my_tool".to_string()));
        assert_eq!(part.tool_call_id.as_deref(), Some("id1"));
        assert_eq!(part.content.message(), "Invalid JSON");
        assert_eq!(part.part_kind(), "retry-prompt");
    }

    #[test]
    fn test_retry_content_untagged_serde() {
        // A bare JSON string deserializes as `Text`.
        let text: RetryContent = serde_json::from_str(r#""try again""#).unwrap();
        assert_eq!(text, RetryContent::text("try again"));

        // An object deserializes as `Structured`.
        let structured: RetryContent =
            serde_json::from_str(r#"{"message": "bad args", "errors": ["x is required"]}"#)
                .unwrap();
        assert_eq!(structured.message(), "bad args");
        assert_eq!(
            structured,
            RetryContent::structured("bad args", Some(vec!["x is required".to_string()]))
        );

        // `errors: None` is skipped on output.
        let json = serde_json::to_string(&RetryContent::structured("m", None)).unwrap();
        assert_eq!(json, r#"{"message":"m"}"#);
    }

    #[test]
    fn test_part_kind_matches_serde_tag() {
        let parts = vec![
            ModelRequestPart::SystemPrompt(SystemPromptPart::new("s")),
            ModelRequestPart::UserPrompt(UserPromptPart::new("u")),
            ModelRequestPart::ToolReturn(ToolReturnPart::success("tool", "ok")),
            ModelRequestPart::RetryPrompt(RetryPromptPart::new("again")),
        ];
        for part in &parts {
            let value = serde_json::to_value(part).unwrap();
            assert_eq!(value["part_kind"], part.part_kind());
        }
    }

    #[test]
    fn test_missing_kind_defaults_to_request() {
        let parsed: ModelRequest = serde_json::from_str(r#"{"parts": []}"#).unwrap();
        assert!(parsed.is_empty());
        assert_eq!(parsed.kind, "request");
        assert_eq!(parsed, ModelRequest::new());
    }

    #[test]
    fn test_serde_roundtrip() {
        let req = ModelRequest::with_parts(vec![
            ModelRequestPart::SystemPrompt(SystemPromptPart::new("System")),
            ModelRequestPart::UserPrompt(UserPromptPart::new("User")),
        ]);
        let json = serde_json::to_string(&req).unwrap();

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["kind"], "request");
        assert_eq!(value["parts"][0]["part_kind"], "system-prompt");
        assert_eq!(value["parts"][1]["part_kind"], "user-prompt");

        let parsed: ModelRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req.len(), parsed.len());
        assert_eq!(parsed.kind, req.kind);
        assert_eq!(parsed.system_prompts().next().unwrap().content, "System");
        assert_eq!(parsed.user_prompts().count(), 1);
    }

    #[test]
    fn test_builtin_tool_return() {
        use crate::messages::parts::{BuiltinToolReturnContent, WebSearchResult, WebSearchResults};

        let results = WebSearchResults::new(
            "rust programming",
            vec![WebSearchResult::new("Rust", "https://rust-lang.org")],
        );
        let content = BuiltinToolReturnContent::web_search(results);
        let part = BuiltinToolReturnPart::new("web_search", content, "call_123");

        let mut req = ModelRequest::new();
        req.add_builtin_tool_return(part);

        assert_eq!(req.len(), 1);
        assert_eq!(req.builtin_tool_returns().count(), 1);

        let returns: Vec<_> = req.builtin_tool_returns().collect();
        assert_eq!(returns[0].tool_name, "web_search");
        assert_eq!(returns[0].tool_call_id, "call_123");
    }

    #[test]
    fn test_model_request_part_is_builtin_tool_return() {
        use crate::messages::parts::{BuiltinToolReturnContent, CodeExecutionResult};

        let result = CodeExecutionResult::new("print(1)").with_stdout("1\n");
        let content = BuiltinToolReturnContent::code_execution(result);
        let part = BuiltinToolReturnPart::new("code_execution", content, "call_456");
        let request_part = ModelRequestPart::BuiltinToolReturn(part);

        assert!(request_part.is_builtin_tool_return());
        assert!(!request_part.is_model_response());
        assert_eq!(request_part.part_kind(), "builtin-tool-return");
    }

    #[test]
    fn test_serde_roundtrip_with_builtin_tool_return() {
        use crate::messages::parts::{
            BuiltinToolReturnContent, FileSearchResult, FileSearchResults,
        };

        let results = FileSearchResults::new(
            "main function",
            vec![FileSearchResult::new("main.rs", "fn main() {}")],
        );
        let content = BuiltinToolReturnContent::file_search(results);
        let part = BuiltinToolReturnPart::new("file_search", content, "call_789");

        let req = ModelRequest::with_parts(vec![
            ModelRequestPart::UserPrompt(UserPromptPart::new("Search files")),
            ModelRequestPart::BuiltinToolReturn(part),
        ]);

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ModelRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(req.len(), parsed.len());
        assert_eq!(parsed.builtin_tool_returns().count(), 1);
        assert_eq!(parsed.parts[1].part_kind(), "builtin-tool-return");
    }
}