//! Message, tool-call, and tool-definition types for `tiny_loop` (`tiny_loop/types.rs`).

use schemars::{JsonSchema, generate::SchemaSettings};
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{Map, Value};

5/// LLM message with role-specific fields
6#[derive(Serialize, Deserialize, Clone, Debug)]
7#[serde(tag = "role", rename_all = "lowercase")]
8pub enum Message {
9    /// System message with instructions
10    System {
11        /// Message content
12        content: String,
13    },
14    /// User message with input
15    User {
16        /// Message content
17        content: String,
18    },
19    /// Assistant message with response and optional tool calls
20    Assistant {
21        /// Message content
22        content: String,
23        /// Tool calls requested by the assistant
24        #[serde(skip_serializing_if = "Option::is_none")]
25        tool_calls: Option<Vec<ToolCall>>,
26    },
27    /// Tool execution result
28    Tool {
29        /// Tool execution result content
30        content: String,
31        /// ID of the tool call this responds to
32        tool_call_id: String,
33    },
34    /// Custom role with arbitrary fields
35    #[serde(untagged)]
36    Custom {
37        /// Custom role name
38        role: String,
39        /// Custom message body
40        #[serde(flatten)]
41        body: Value,
42    },
43}
44
45/// Tool call from LLM
46#[derive(Serialize, Deserialize, Clone, Debug)]
47pub struct ToolCall {
48    /// Unique identifier for this tool call
49    pub id: String,
50    /// Type of the call (typically "function")
51    #[serde(rename = "type")]
52    pub call_type: String,
53    /// Function call details
54    pub function: FunctionCall,
55}
56
57/// Function call details
58#[derive(Serialize, Deserialize, Clone, Debug)]
59pub struct FunctionCall {
60    /// Function name to call
61    pub name: String,
62    /// JSON-encoded function arguments
63    pub arguments: String,
64}
65
66/// Tool definition for LLM
67#[derive(Serialize, Clone, Debug)]
68pub struct ToolDefinition {
69    /// Type of the tool (typically "function")
70    #[serde(rename = "type")]
71    pub tool_type: String,
72    /// Function definition
73    pub function: ToolFunction,
74}
75
76/// Tool function definition
77#[derive(Serialize, Clone, Debug)]
78pub struct ToolFunction {
79    /// Function name
80    pub name: String,
81    /// Function description
82    pub description: String,
83    /// JSON schema for function parameters
84    pub parameters: Parameters,
85}
86
87/// JSON schema parameters with metadata stripped
88#[derive(Serialize, Clone, Debug)]
89pub struct Parameters(Map<String, Value>);
90
91impl Parameters {
92    /// Create Parameters from a Json object (map)
93    pub fn from_object(mut obj: Map<String, Value>) -> Self {
94        // Remove `$schema`, `title`, and `description` fields from JSON schema
95        obj.remove("$schema");
96        obj.remove("title");
97        obj.remove("description");
98
99        Self(obj)
100    }
101
102    /// Create Parameters from a JsonSchema
103    pub fn from_schema(schema: schemars::Schema) -> Self {
104        let obj = schema.to_value().as_object().unwrap().clone();
105        Self::from_object(obj)
106    }
107
108    /// Create Parameters from a type implementing JsonSchema
109    pub fn from_type<T: JsonSchema>() -> Self {
110        let settings = SchemaSettings::default().with(|s| {
111            s.inline_subschemas = true;
112        });
113        let generator = settings.into_generator();
114        let schema = generator.into_root_schema_for::<T>();
115        Self::from_schema(schema)
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `msg` to a JSON string and parse it back into a [`Message`].
    fn roundtrip(msg: &Message) -> Message {
        let json = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System {
            content: "test".into(),
        };
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({"role": "system", "content": "test"})
        );
        assert!(matches!(roundtrip(&msg), Message::System { content } if content == "test"));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User {
            content: "test".into(),
        };
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({"role": "user", "content": "test"})
        );
        assert!(matches!(roundtrip(&msg), Message::User { content } if content == "test"));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `None` tool calls are omitted entirely rather than written as `null`.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant { content, tool_calls: None } if content == "test")
        );
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        };
        match roundtrip(&msg) {
            Message::Assistant {
                content,
                tool_calls: Some(calls),
            } => {
                assert_eq!(content, "test");
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].id, "call_1");
                assert_eq!(calls[0].call_type, "function");
                assert_eq!(calls[0].function.name, "fn");
                assert_eq!(calls[0].function.arguments, "{}");
            }
            other => panic!("expected an assistant message with tool calls, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        };
        assert!(matches!(roundtrip(&msg), Message::Tool { content, tool_call_id }
            if content == "result" && tool_call_id == "call_123"));
    }

    #[test]
    fn test_custom_roundtrip() {
        let body = serde_json::json!({"content": "test", "extra": "field"});
        let msg = Message::Custom {
            role: "custom".into(),
            body: body.clone(),
        };
        // The body's keys are flattened next to `role` on the wire.
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({"role": "custom", "content": "test", "extra": "field"})
        );
        match roundtrip(&msg) {
            Message::Custom {
                role,
                body: parsed_body,
            } => {
                assert_eq!(role, "custom");
                assert_eq!(parsed_body, body);
            }
            other => panic!("expected a custom message, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let json = serde_json::to_value(&tc).unwrap();
        // `call_type` is renamed to `type` on the wire.
        assert_eq!(json["type"], "function");
        let parsed: ToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, r#"{"key":"value"}"#);
    }

    #[test]
    fn test_tool_definition_serialization() {
        let td = ToolDefinition {
            tool_type: "function".into(),
            function: ToolFunction {
                name: "test".into(),
                description: "desc".into(),
                parameters: Parameters::from_type::<String>(),
            },
        };
        let json = serde_json::to_value(&td).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "test");
        assert_eq!(json["function"]["description"], "desc");
        let params = &json["function"]["parameters"];
        assert_eq!(params["type"], "string");
        // Root-schema metadata must not leak into the tool definition.
        assert!(params.get("$schema").is_none());
        assert!(params.get("title").is_none());
    }

    #[test]
    fn test_parameters_from_object_strips_top_level_metadata() {
        let schema = serde_json::json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "title": "Args",
            "description": "top-level",
            "type": "object",
            "properties": {"x": {"type": "integer", "description": "kept"}}
        });
        let obj = match schema {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {:?}", other),
        };
        let params = Parameters::from_object(obj);
        assert!(!params.0.contains_key("$schema"));
        assert!(!params.0.contains_key("title"));
        assert!(!params.0.contains_key("description"));
        assert_eq!(params.0["type"], "object");
        // Only the top level is stripped; nested descriptions survive.
        assert_eq!(params.0["properties"]["x"]["description"], "kept");
    }
}