Skip to main content

rustic_ai/
messages.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use crate::usage::RequestUsage;
5
6#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct ImageUrl {
8    pub url: String,
9    pub media_type: Option<String>,
10}
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct VideoUrl {
14    pub url: String,
15    pub media_type: Option<String>,
16}
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
19pub struct AudioUrl {
20    pub url: String,
21    pub media_type: Option<String>,
22}
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub struct DocumentUrl {
26    pub url: String,
27    pub media_type: Option<String>,
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct BinaryContent {
32    pub data: Vec<u8>,
33    pub media_type: String,
34}
35
36#[derive(Clone, Debug, Serialize, Deserialize)]
37pub enum UserContent {
38    Text(String),
39    Image(ImageUrl),
40    Video(VideoUrl),
41    Audio(AudioUrl),
42    Document(DocumentUrl),
43    Binary(BinaryContent),
44}
45
46#[derive(Clone, Debug, Serialize, Deserialize)]
47pub struct SystemPromptPart {
48    pub content: String,
49}
50
51#[derive(Clone, Debug, Serialize, Deserialize)]
52pub struct UserPromptPart {
53    pub content: Vec<UserContent>,
54}
55
56#[derive(Clone, Debug, Serialize, Deserialize)]
57pub struct ToolReturnPart {
58    pub tool_name: String,
59    pub tool_call_id: String,
60    pub content: Value,
61}
62
63#[derive(Clone, Debug, Serialize, Deserialize)]
64pub struct RetryPromptPart {
65    pub content: String,
66    pub tool_name: Option<String>,
67    pub tool_call_id: Option<String>,
68}
69
70#[derive(Clone, Debug, Serialize, Deserialize)]
71pub enum ModelRequestPart {
72    SystemPrompt(SystemPromptPart),
73    UserPrompt(UserPromptPart),
74    ToolReturn(ToolReturnPart),
75    RetryPrompt(RetryPromptPart),
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
79pub struct ModelRequest {
80    pub parts: Vec<ModelRequestPart>,
81    pub instructions: Option<String>,
82}
83
84impl ModelRequest {
85    pub fn user_text_prompt(prompt: impl Into<String>) -> Self {
86        Self {
87            parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
88                content: vec![UserContent::Text(prompt.into())],
89            })],
90            instructions: None,
91        }
92    }
93}
94
95#[derive(Clone, Debug, Serialize, Deserialize)]
96pub struct TextPart {
97    pub content: String,
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
101pub struct ToolCallPart {
102    pub id: String,
103    pub name: String,
104    pub arguments: Value,
105}
106
107#[derive(Clone, Debug, Serialize, Deserialize)]
108pub struct ProviderItemPart {
109    pub provider: String,
110    pub payload: Value,
111}
112
113#[derive(Clone, Debug, Serialize, Deserialize)]
114pub enum ModelResponsePart {
115    Text(TextPart),
116    ToolCall(ToolCallPart),
117    ProviderItem(ProviderItemPart),
118}
119
120#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct ModelResponse {
122    pub parts: Vec<ModelResponsePart>,
123    pub usage: Option<RequestUsage>,
124    pub model_name: Option<String>,
125    pub finish_reason: Option<String>,
126}
127
128impl ModelResponse {
129    pub fn text(&self) -> Option<String> {
130        let mut texts = Vec::new();
131        for part in &self.parts {
132            if let ModelResponsePart::Text(text) = part {
133                texts.push(text.content.clone());
134            }
135        }
136        if texts.is_empty() {
137            None
138        } else {
139            Some(texts.join("\n\n"))
140        }
141    }
142
143    pub fn tool_calls(&self) -> Vec<ToolCallPart> {
144        self.parts
145            .iter()
146            .filter_map(|part| match part {
147                ModelResponsePart::ToolCall(call) => Some(call.clone()),
148                _ => None,
149            })
150            .collect()
151    }
152}
153
154#[derive(Clone, Debug, Serialize, Deserialize)]
155pub enum ModelMessage {
156    Request(ModelRequest),
157    Response(ModelResponse),
158}