Skip to main content

tiny_loop/
types.rs

1use schemars::{JsonSchema, schema_for};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4
5/// LLM message with role-specific fields
6#[derive(Serialize, Deserialize, Clone, Debug)]
7#[serde(tag = "role", rename_all = "lowercase")]
8pub enum Message {
9    /// System message with instructions
10    System {
11        /// Message content
12        content: String,
13    },
14    /// User message with input
15    User {
16        /// Message content
17        content: String,
18    },
19    /// Assistant message with response and optional tool calls
20    Assistant {
21        /// Message content
22        content: String,
23        /// Tool calls requested by the assistant
24        #[serde(skip_serializing_if = "Option::is_none")]
25        tool_calls: Option<Vec<ToolCall>>,
26    },
27    /// Tool execution result
28    Tool {
29        /// Tool execution result content
30        content: String,
31        /// ID of the tool call this responds to
32        tool_call_id: String,
33    },
34    /// Custom role with arbitrary fields
35    #[serde(untagged)]
36    Custom {
37        /// Custom role name
38        role: String,
39        /// Custom message body
40        #[serde(flatten)]
41        body: Value,
42    },
43}
44
45/// Tool call from LLM
46#[derive(Serialize, Deserialize, Clone, Debug)]
47pub struct ToolCall {
48    /// Unique identifier for this tool call
49    pub id: String,
50    /// Type of the call (typically "function")
51    #[serde(rename = "type")]
52    pub call_type: String,
53    /// Function call details
54    pub function: FunctionCall,
55}
56
57/// Function call details
58#[derive(Serialize, Deserialize, Clone, Debug)]
59pub struct FunctionCall {
60    /// Function name to call
61    pub name: String,
62    /// JSON-encoded function arguments
63    pub arguments: String,
64}
65
66/// Tool definition for LLM
67#[derive(Serialize, Clone)]
68pub struct ToolDefinition {
69    /// Type of the tool (typically "function")
70    #[serde(rename = "type")]
71    pub tool_type: String,
72    /// Function definition
73    pub function: ToolFunction,
74}
75
76/// Tool function definition
77#[derive(Serialize, Clone)]
78pub struct ToolFunction {
79    /// Function name
80    pub name: String,
81    /// Function description
82    pub description: String,
83    /// JSON schema for function parameters
84    pub parameters: Parameters,
85}
86
87/// JSON schema parameters with metadata stripped
88#[derive(Serialize, Clone, Debug)]
89pub struct Parameters(Map<String, Value>);
90
91impl Parameters {
92    /// Create Parameters from a Json object (map)
93    pub fn from_object(mut obj: Map<String, Value>) -> Self {
94        // Remove `$schema` and `title` fields from JSON schema
95        obj.remove("$schema");
96        obj.remove("title");
97
98        Self(obj)
99    }
100
101    /// Create Parameters from a JsonSchema
102    pub fn from_schema(schema: schemars::Schema) -> Self {
103        let obj = schema.to_value().as_object().unwrap().clone();
104        Self::from_object(obj)
105    }
106
107    /// Create Parameters from a type implementing JsonSchema
108    pub fn from_type<T: JsonSchema>() -> Self {
109        Self::from_schema(schema_for!(T))
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System {
            content: "test".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::System { content } if content == "test"));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User {
            content: "test".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::User { content } if content == "test"));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_calls: None` must be omitted entirely, not serialized as null.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant { content, tool_calls: None } if content == "test")
        );
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant { tool_calls: Some(calls), .. } if calls.len() == 1)
        );
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::Tool { content, tool_call_id }
            if content == "result" && tool_call_id == "call_123"));
    }

    #[test]
    fn test_custom_roundtrip() {
        let msg = Message::Custom {
            role: "custom".into(),
            body: serde_json::json!({"content": "test", "extra": "field"}),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        // The flattened body must survive the roundtrip alongside the role.
        assert!(matches!(parsed, Message::Custom { role, body }
            if role == "custom" && body["extra"] == "field"));
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.function.name, "test");
    }

    #[test]
    fn test_tool_definition_serialization() {
        let td = ToolDefinition {
            tool_type: "function".into(),
            function: ToolFunction {
                name: "test".into(),
                description: "desc".into(),
                parameters: Parameters::from_schema(schema_for!(String)),
            },
        };
        let json = serde_json::to_string(&td).unwrap();
        assert!(json.contains(r#""type":"function"#));
        assert!(json.contains(r#""name":"test"#));
        // Schema metadata must not leak into the tool definition.
        assert!(!json.contains("$schema"));
    }

    #[test]
    fn test_parameters_from_object_strips_metadata() {
        let mut obj = Map::new();
        obj.insert(
            "$schema".into(),
            Value::from("https://json-schema.org/draft/2020-12/schema"),
        );
        obj.insert("title".into(), Value::from("Args"));
        obj.insert("type".into(), Value::from("object"));

        let params = Parameters::from_object(obj);
        assert!(!params.0.contains_key("$schema"));
        assert!(!params.0.contains_key("title"));
        assert_eq!(params.0.get("type"), Some(&Value::from("object")));
        assert_eq!(params.0.len(), 1);
    }

    #[test]
    fn test_parameters_from_type_strips_metadata() {
        let params = Parameters::from_type::<String>();
        assert!(!params.0.contains_key("$schema"));
        assert!(!params.0.contains_key("title"));
        assert_eq!(params.0.get("type"), Some(&Value::from("string")));
    }
}