Skip to main content

tiny_loop/
types.rs

1use schemars::{JsonSchema, generate::SchemaSettings};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4
5/// LLM message with role-specific fields
6#[derive(Serialize, Deserialize, Clone, Debug)]
7#[serde(tag = "role", rename_all = "lowercase")]
8pub enum Message {
9    /// System message with instructions
10    System {
11        /// Message content
12        content: String,
13    },
14    /// User message with input
15    User {
16        /// Message content
17        content: String,
18    },
19    /// Assistant message with response and optional tool calls
20    Assistant {
21        /// Message content
22        content: String,
23        /// Tool calls requested by the assistant
24        #[serde(skip_serializing_if = "Option::is_none")]
25        tool_calls: Option<Vec<ToolCall>>,
26    },
27    /// Tool execution result
28    Tool {
29        /// Tool execution result content
30        content: String,
31        /// ID of the tool call this responds to
32        tool_call_id: String,
33    },
34    /// Custom role with arbitrary fields
35    #[serde(untagged)]
36    Custom {
37        /// Custom role name
38        role: String,
39        /// Custom message body
40        #[serde(flatten)]
41        body: Value,
42    },
43}
44
45/// Tool call from LLM
46#[derive(Serialize, Deserialize, Clone, Debug)]
47pub struct ToolCall {
48    /// Unique identifier for this tool call
49    pub id: String,
50    /// Type of the call (typically "function")
51    #[serde(rename = "type")]
52    pub call_type: String,
53    /// Function call details
54    pub function: FunctionCall,
55}
56
57/// Function call details
58#[derive(Serialize, Deserialize, Clone, Debug)]
59pub struct FunctionCall {
60    /// Function name to call
61    pub name: String,
62    /// JSON-encoded function arguments
63    pub arguments: String,
64}
65
66/// Tool definition for LLM
67#[derive(Serialize, Clone, Debug)]
68pub struct ToolDefinition {
69    /// Type of the tool (typically "function")
70    #[serde(rename = "type")]
71    pub tool_type: String,
72    /// Function definition
73    pub function: ToolFunction,
74}
75
76/// Tool function definition
77#[derive(Serialize, Clone, Debug)]
78pub struct ToolFunction {
79    /// Function name
80    pub name: String,
81    /// Function description
82    pub description: String,
83    /// JSON schema for function parameters
84    pub parameters: Parameters,
85}
86
87/// JSON schema parameters with metadata stripped
88#[derive(Serialize, Clone, Debug)]
89pub struct Parameters(Map<String, Value>);
90
91impl Parameters {
92    /// Create Parameters from a Json object (map)
93    pub fn from_object(mut obj: Map<String, Value>) -> Self {
94        // Remove `$schema` and `title` fields from JSON schema
95        obj.remove("$schema");
96        obj.remove("title");
97
98        Self(obj)
99    }
100
101    /// Create Parameters from a JsonSchema
102    pub fn from_schema(schema: schemars::Schema) -> Self {
103        let obj = schema.to_value().as_object().unwrap().clone();
104        Self::from_object(obj)
105    }
106
107    /// Create Parameters from a type implementing JsonSchema
108    pub fn from_type<T: JsonSchema>() -> Self {
109        let settings = SchemaSettings::default().with(|s| {
110            s.inline_subschemas = true;
111        });
112        let generator = settings.into_generator();
113        let schema = generator.into_root_schema_for::<T>();
114        Self::from_schema(schema)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `msg` to a JSON string and parses that string back.
    fn roundtrip(msg: &Message) -> Message {
        let encoded = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// A system message survives a JSON roundtrip with its content intact.
    #[test]
    fn test_system_roundtrip() {
        let decoded = roundtrip(&Message::System {
            content: "test".into(),
        });
        assert!(matches!(decoded, Message::System { content } if content == "test"));
    }

    /// A user message survives a JSON roundtrip with its content intact.
    #[test]
    fn test_user_roundtrip() {
        let decoded = roundtrip(&Message::User {
            content: "test".into(),
        });
        assert!(matches!(decoded, Message::User { content } if content == "test"));
    }

    /// Without tool calls, the `tool_calls` key is absent from the JSON and
    /// parses back as `None`.
    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let original = Message::Assistant {
            content: "test".into(),
            tool_calls: None,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(!encoded.contains("tool_calls"));

        let decoded: Message = serde_json::from_str(&encoded).unwrap();
        assert!(
            matches!(decoded, Message::Assistant { content, tool_calls: None } if content == "test")
        );
    }

    /// An assistant message keeps its single tool call across a roundtrip.
    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let call = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "fn".into(),
                arguments: "{}".into(),
            },
        };
        let decoded = roundtrip(&Message::Assistant {
            content: "test".into(),
            tool_calls: Some(vec![call]),
        });
        assert!(
            matches!(decoded, Message::Assistant { tool_calls: Some(calls), .. } if calls.len() == 1)
        );
    }

    /// A tool result keeps both its content and its call id.
    #[test]
    fn test_tool_roundtrip() {
        let decoded = roundtrip(&Message::Tool {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        });
        assert!(matches!(
            decoded,
            Message::Tool { content, tool_call_id }
                if content == "result" && tool_call_id == "call_123"
        ));
    }

    /// A custom-role message parses back as `Custom` with the same role.
    #[test]
    fn test_custom_roundtrip() {
        let body = serde_json::json!({"content": "test", "extra": "field"});
        let decoded = roundtrip(&Message::Custom {
            role: "custom".into(),
            body,
        });
        assert!(matches!(decoded, Message::Custom { role, .. } if role == "custom"));
    }

    /// A standalone tool call keeps its id and function name.
    #[test]
    fn test_tool_call_roundtrip() {
        let function = FunctionCall {
            name: "test".into(),
            arguments: r#"{"key":"value"}"#.into(),
        };
        let call = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function,
        };
        let encoded = serde_json::to_string(&call).unwrap();
        let decoded: ToolCall = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "call_1");
        assert_eq!(decoded.function.name, "test");
    }

    /// A tool definition serializes with the renamed `type` key and its name.
    #[test]
    fn test_tool_definition_serialization() {
        let function = ToolFunction {
            name: "test".into(),
            description: "desc".into(),
            parameters: Parameters::from_type::<String>(),
        };
        let definition = ToolDefinition {
            tool_type: "function".into(),
            function,
        };
        let encoded = serde_json::to_string(&definition).unwrap();
        for fragment in &[r#""type":"function"#, r#""name":"test"#] {
            assert!(encoded.contains(*fragment));
        }
    }
}