// saorsa_ai/message.rs
//! Message and content types for LLM conversations.
2
3use serde::{Deserialize, Serialize};
4
5/// The role of a message participant.
6#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    /// User message.
10    User,
11    /// Assistant (model) message.
12    Assistant,
13}
14
15/// A conversation message.
16#[derive(Clone, Debug, Serialize, Deserialize)]
17pub struct Message {
18    /// The role of the message sender.
19    pub role: Role,
20    /// The content blocks.
21    pub content: Vec<ContentBlock>,
22}
23
24impl Message {
25    /// Create a user message with text content.
26    pub fn user(text: impl Into<String>) -> Self {
27        Self {
28            role: Role::User,
29            content: vec![ContentBlock::Text { text: text.into() }],
30        }
31    }
32
33    /// Create an assistant message with text content.
34    pub fn assistant(text: impl Into<String>) -> Self {
35        Self {
36            role: Role::Assistant,
37            content: vec![ContentBlock::Text { text: text.into() }],
38        }
39    }
40
41    /// Create a message with tool result content.
42    pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
43        Self {
44            role: Role::User,
45            content: vec![ContentBlock::ToolResult {
46                tool_use_id: tool_use_id.into(),
47                content: content.into(),
48            }],
49        }
50    }
51}
52
53/// A block of content within a message.
54#[derive(Clone, Debug, Serialize, Deserialize)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum ContentBlock {
57    /// Plain text content.
58    Text {
59        /// The text.
60        text: String,
61    },
62    /// A tool use request from the assistant.
63    ToolUse {
64        /// Unique ID for this tool use.
65        id: String,
66        /// Tool name.
67        name: String,
68        /// Tool input (JSON object).
69        input: serde_json::Value,
70    },
71    /// A tool result from the user.
72    ToolResult {
73        /// The ID of the tool_use this is responding to.
74        tool_use_id: String,
75        /// The result content.
76        content: String,
77    },
78}
79
80/// Definition of a tool the model can use.
81#[derive(Clone, Debug, Serialize, Deserialize)]
82pub struct ToolDefinition {
83    /// The tool name.
84    pub name: String,
85    /// Description of what the tool does.
86    pub description: String,
87    /// JSON Schema for the tool's input parameters.
88    pub input_schema: serde_json::Value,
89}
90
91impl ToolDefinition {
92    /// Create a new tool definition.
93    pub fn new(
94        name: impl Into<String>,
95        description: impl Into<String>,
96        input_schema: serde_json::Value,
97    ) -> Self {
98        Self {
99            name: name.into(),
100            description: description.into(),
101            input_schema,
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `value` into a `serde_json::Value`, failing the test on error.
    ///
    /// Comparing `Value`s checks the exact wire shape while ignoring object
    /// key order.
    fn to_json<T: serde::Serialize>(value: &T) -> serde_json::Value {
        match serde_json::to_value(value) {
            Ok(json) => json,
            Err(err) => panic!("serialization failed: {err}"),
        }
    }

    #[test]
    fn user_message_construction() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "Hello"),
            _ => panic!("Expected Text content block"),
        }
    }

    #[test]
    fn assistant_message_construction() {
        let msg = Message::assistant("Hi there");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "Hi there"),
            _ => panic!("Expected Text content block"),
        }
    }

    #[test]
    fn role_serializes_lowercase() {
        assert_eq!(to_json(&Role::User), serde_json::json!("user"));
        assert_eq!(to_json(&Role::Assistant), serde_json::json!("assistant"));
    }

    #[test]
    fn message_serialization_shape() {
        assert_eq!(
            to_json(&Message::user("test")),
            serde_json::json!({
                "role": "user",
                "content": [{"type": "text", "text": "test"}]
            })
        );
    }

    #[test]
    fn message_serialization_roundtrip() {
        let msg = Message::user("test");
        let json = match serde_json::to_string(&msg) {
            Ok(json) => json,
            Err(err) => panic!("serialization failed: {err}"),
        };
        let parsed: Message = match serde_json::from_str(&json) {
            Ok(parsed) => parsed,
            Err(err) => panic!("deserialization failed: {err}"),
        };
        assert_eq!(parsed.role, Role::User);
        match parsed.content.as_slice() {
            [ContentBlock::Text { text }] => assert_eq!(text, "test"),
            other => panic!("Expected a single Text content block, got {other:?}"),
        }
        // The re-serialized form must match the original wire format exactly.
        assert_eq!(to_json(&parsed), to_json(&msg));
    }

    #[test]
    fn tool_use_serialization() {
        let block = ContentBlock::ToolUse {
            id: "tool_1".into(),
            name: "bash".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(
            to_json(&block),
            serde_json::json!({
                "type": "tool_use",
                "id": "tool_1",
                "name": "bash",
                "input": {"command": "ls"}
            })
        );
    }

    #[test]
    fn tool_use_deserialization() {
        let raw = r#"{"type":"tool_use","id":"tool_1","name":"bash","input":{"command":"ls"}}"#;
        match serde_json::from_str::<ContentBlock>(raw) {
            Ok(ContentBlock::ToolUse { id, name, input }) => {
                assert_eq!(id, "tool_1");
                assert_eq!(name, "bash");
                assert_eq!(input, serde_json::json!({"command": "ls"}));
            }
            Ok(other) => panic!("Expected ToolUse content block, got {other:?}"),
            Err(err) => panic!("deserialization failed: {err}"),
        }
    }

    #[test]
    fn tool_result_message() {
        let msg = Message::tool_result("tool_1", "file.txt");
        assert_eq!(msg.role, Role::User);
        match &msg.content[0] {
            ContentBlock::ToolResult {
                tool_use_id,
                content,
            } => {
                assert_eq!(tool_use_id, "tool_1");
                assert_eq!(content, "file.txt");
            }
            _ => panic!("Expected ToolResult content block"),
        }
        assert_eq!(
            to_json(&msg),
            serde_json::json!({
                "role": "user",
                "content": [{
                    "type": "tool_result",
                    "tool_use_id": "tool_1",
                    "content": "file.txt"
                }]
            })
        );
    }

    #[test]
    fn tool_definition_creation() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "path": {"type": "string"}
            },
            "required": ["path"]
        });
        let tool = ToolDefinition::new("read_file", "Read a file from disk", schema.clone());
        assert_eq!(tool.name, "read_file");
        assert_eq!(tool.description, "Read a file from disk");
        assert_eq!(tool.input_schema, schema);
    }
}