zoey_core/
function_calling.rs

1//! Function calling support for LLMs
2
3use crate::{ZoeyError, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tracing::{debug, info, warn};
8
9/// Function definition for LLM function calling
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct FunctionDefinition {
12    /// Function name
13    pub name: String,
14
15    /// Function description
16    pub description: String,
17
18    /// Parameters schema (JSON Schema)
19    pub parameters: serde_json::Value,
20
21    /// Whether this function is required
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub required: Option<bool>,
24}
25
26/// Function call from LLM
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct FunctionCall {
29    /// Function name to call
30    pub name: String,
31
32    /// Arguments (JSON object)
33    pub arguments: serde_json::Value,
34}
35
36/// Function execution result
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct FunctionResult {
39    /// Function name that was called
40    pub name: String,
41
42    /// Result value
43    pub result: serde_json::Value,
44
45    /// Whether the call succeeded
46    pub success: bool,
47
48    /// Error message if failed
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub error: Option<String>,
51}
52
53/// Function handler type
54pub type FunctionHandler = Arc<
55    dyn Fn(
56            serde_json::Value,
57        )
58            -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send>>
59        + Send
60        + Sync,
61>;
62
63/// Function registry for managing callable functions
64pub struct FunctionRegistry {
65    functions: HashMap<String, (FunctionDefinition, FunctionHandler)>,
66}
67
68impl FunctionRegistry {
69    /// Create a new function registry
70    pub fn new() -> Self {
71        Self {
72            functions: HashMap::new(),
73        }
74    }
75
76    /// Register a function
77    pub fn register(&mut self, definition: FunctionDefinition, handler: FunctionHandler) {
78        info!("Registering function: {}", definition.name);
79        debug!("Function description: {}", definition.description);
80        self.functions
81            .insert(definition.name.clone(), (definition, handler));
82    }
83
84    /// Validate a function definition
85    pub fn validate_definition(definition: &FunctionDefinition) -> Result<()> {
86        if definition.name.is_empty() {
87            return Err(ZoeyError::validation("Function name cannot be empty"));
88        }
89
90        if definition.name.contains(char::is_whitespace) {
91            return Err(ZoeyError::validation(
92                "Function name cannot contain whitespace",
93            ));
94        }
95
96        if definition.description.is_empty() {
97            return Err(ZoeyError::validation(
98                "Function description cannot be empty",
99            ));
100        }
101
102        // Validate parameters is valid JSON
103        if !definition.parameters.is_object() {
104            return Err(ZoeyError::validation(
105                "Function parameters must be a JSON object",
106            ));
107        }
108
109        Ok(())
110    }
111
112    /// Get function definition
113    pub fn get_definition(&self, name: &str) -> Option<&FunctionDefinition> {
114        self.functions.get(name).map(|(def, _)| def)
115    }
116
117    /// Get all function definitions
118    pub fn get_all_definitions(&self) -> Vec<FunctionDefinition> {
119        self.functions
120            .values()
121            .map(|(def, _)| def.clone())
122            .collect()
123    }
124
125    /// Execute a function call
126    pub async fn execute(&self, call: FunctionCall) -> FunctionResult {
127        info!("Executing function: {}", call.name);
128        debug!("Function arguments: {}", call.arguments);
129
130        match self.functions.get(&call.name) {
131            Some((_def, handler)) => match handler(call.arguments.clone()).await {
132                Ok(result) => {
133                    info!("Function {} executed successfully", call.name);
134                    debug!("Result: {}", result);
135                    FunctionResult {
136                        name: call.name,
137                        result,
138                        success: true,
139                        error: None,
140                    }
141                }
142                Err(e) => {
143                    warn!("Function {} failed: {}", call.name, e);
144                    FunctionResult {
145                        name: call.name,
146                        result: serde_json::Value::Null,
147                        success: false,
148                        error: Some(e.to_string()),
149                    }
150                }
151            },
152            None => {
153                warn!("Function '{}' not found in registry", call.name);
154                FunctionResult {
155                    name: call.name.clone(),
156                    result: serde_json::Value::Null,
157                    success: false,
158                    error: Some(format!("Function '{}' not found", call.name)),
159                }
160            }
161        }
162    }
163
164    /// Check if function exists
165    pub fn has_function(&self, name: &str) -> bool {
166        self.functions.contains_key(name)
167    }
168
169    /// Get number of registered functions
170    pub fn len(&self) -> usize {
171        self.functions.len()
172    }
173
174    /// Check if empty
175    pub fn is_empty(&self) -> bool {
176        self.functions.is_empty()
177    }
178}
179
180impl Default for FunctionRegistry {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186/// Helper to create a function definition
187pub fn create_function_definition(
188    name: impl Into<String>,
189    description: impl Into<String>,
190    parameters: serde_json::Value,
191) -> FunctionDefinition {
192    FunctionDefinition {
193        name: name.into(),
194        description: description.into(),
195        parameters,
196        required: None,
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered function is counted by `len` and found by `has_function`.
    #[tokio::test]
    async fn test_function_registry() {
        let weather_def = FunctionDefinition {
            name: "get_weather".to_string(),
            description: "Get current weather".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"}
                }
            }),
            required: Some(true),
        };

        let weather_handler: FunctionHandler = Arc::new(|_args| {
            Box::pin(async move {
                Ok(serde_json::json!({
                    "temperature": 72,
                    "condition": "sunny"
                }))
            })
        });

        let mut registry = FunctionRegistry::new();
        registry.register(weather_def, weather_handler);

        assert_eq!(registry.len(), 1);
        assert!(registry.has_function("get_weather"));
    }

    /// Executing a registered function passes the arguments to its handler
    /// and returns the handler's value as a successful result.
    #[tokio::test]
    async fn test_function_execution() {
        let add_def = create_function_definition(
            "add_numbers",
            "Add two numbers",
            serde_json::json!({"type": "object"}),
        );

        let add_handler: FunctionHandler = Arc::new(|args| {
            Box::pin(async move {
                let lhs = args["a"].as_i64().unwrap_or(0);
                let rhs = args["b"].as_i64().unwrap_or(0);
                Ok(serde_json::json!(lhs + rhs))
            })
        });

        let mut registry = FunctionRegistry::new();
        registry.register(add_def, add_handler);

        let request = FunctionCall {
            name: "add_numbers".to_string(),
            arguments: serde_json::json!({"a": 5, "b": 3}),
        };
        let outcome = registry.execute(request).await;

        assert!(outcome.success);
        assert_eq!(outcome.result, serde_json::json!(8));
    }

    /// Calling a name that was never registered yields a failed result
    /// that carries an error message.
    #[tokio::test]
    async fn test_function_not_found() {
        let registry = FunctionRegistry::new();

        let request = FunctionCall {
            name: "nonexistent".to_string(),
            arguments: serde_json::json!({}),
        };
        let outcome = registry.execute(request).await;

        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }
}