1use crate::{ZoeyError, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tracing::{debug, info, warn};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct FunctionDefinition {
12 pub name: String,
14
15 pub description: String,
17
18 pub parameters: serde_json::Value,
20
21 #[serde(skip_serializing_if = "Option::is_none")]
23 pub required: Option<bool>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct FunctionCall {
29 pub name: String,
31
32 pub arguments: serde_json::Value,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct FunctionResult {
39 pub name: String,
41
42 pub result: serde_json::Value,
44
45 pub success: bool,
47
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub error: Option<String>,
51}
52
53pub type FunctionHandler = Arc<
55 dyn Fn(
56 serde_json::Value,
57 )
58 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send>>
59 + Send
60 + Sync,
61>;
62
63pub struct FunctionRegistry {
65 functions: HashMap<String, (FunctionDefinition, FunctionHandler)>,
66}
67
68impl FunctionRegistry {
69 pub fn new() -> Self {
71 Self {
72 functions: HashMap::new(),
73 }
74 }
75
76 pub fn register(&mut self, definition: FunctionDefinition, handler: FunctionHandler) {
78 info!("Registering function: {}", definition.name);
79 debug!("Function description: {}", definition.description);
80 self.functions
81 .insert(definition.name.clone(), (definition, handler));
82 }
83
84 pub fn validate_definition(definition: &FunctionDefinition) -> Result<()> {
86 if definition.name.is_empty() {
87 return Err(ZoeyError::validation("Function name cannot be empty"));
88 }
89
90 if definition.name.contains(char::is_whitespace) {
91 return Err(ZoeyError::validation(
92 "Function name cannot contain whitespace",
93 ));
94 }
95
96 if definition.description.is_empty() {
97 return Err(ZoeyError::validation(
98 "Function description cannot be empty",
99 ));
100 }
101
102 if !definition.parameters.is_object() {
104 return Err(ZoeyError::validation(
105 "Function parameters must be a JSON object",
106 ));
107 }
108
109 Ok(())
110 }
111
112 pub fn get_definition(&self, name: &str) -> Option<&FunctionDefinition> {
114 self.functions.get(name).map(|(def, _)| def)
115 }
116
117 pub fn get_all_definitions(&self) -> Vec<FunctionDefinition> {
119 self.functions
120 .values()
121 .map(|(def, _)| def.clone())
122 .collect()
123 }
124
125 pub async fn execute(&self, call: FunctionCall) -> FunctionResult {
127 info!("Executing function: {}", call.name);
128 debug!("Function arguments: {}", call.arguments);
129
130 match self.functions.get(&call.name) {
131 Some((_def, handler)) => match handler(call.arguments.clone()).await {
132 Ok(result) => {
133 info!("Function {} executed successfully", call.name);
134 debug!("Result: {}", result);
135 FunctionResult {
136 name: call.name,
137 result,
138 success: true,
139 error: None,
140 }
141 }
142 Err(e) => {
143 warn!("Function {} failed: {}", call.name, e);
144 FunctionResult {
145 name: call.name,
146 result: serde_json::Value::Null,
147 success: false,
148 error: Some(e.to_string()),
149 }
150 }
151 },
152 None => {
153 warn!("Function '{}' not found in registry", call.name);
154 FunctionResult {
155 name: call.name.clone(),
156 result: serde_json::Value::Null,
157 success: false,
158 error: Some(format!("Function '{}' not found", call.name)),
159 }
160 }
161 }
162 }
163
164 pub fn has_function(&self, name: &str) -> bool {
166 self.functions.contains_key(name)
167 }
168
169 pub fn len(&self) -> usize {
171 self.functions.len()
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.functions.is_empty()
177 }
178}
179
180impl Default for FunctionRegistry {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
186pub fn create_function_definition(
188 name: impl Into<String>,
189 description: impl Into<String>,
190 parameters: serde_json::Value,
191) -> FunctionDefinition {
192 FunctionDefinition {
193 name: name.into(),
194 description: description.into(),
195 parameters,
196 required: None,
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_function_registry() {
        let mut registry = FunctionRegistry::new();

        let def = FunctionDefinition {
            name: "get_weather".to_string(),
            description: "Get current weather".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"}
                }
            }),
            required: Some(true),
        };

        let handler: FunctionHandler = Arc::new(|_args| {
            Box::pin(async move {
                Ok(serde_json::json!({
                    "temperature": 72,
                    "condition": "sunny"
                }))
            })
        });

        registry.register(def, handler);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.has_function("get_weather"));
        assert!(!registry.has_function("get_forecast"));
        assert_eq!(
            registry
                .get_definition("get_weather")
                .map(|d| d.description.as_str()),
            Some("Get current weather")
        );
        assert_eq!(registry.get_all_definitions().len(), 1);
    }

    #[test]
    fn test_validate_definition() {
        let valid = create_function_definition(
            "get_weather",
            "Get current weather",
            serde_json::json!({"type": "object"}),
        );
        assert!(FunctionRegistry::validate_definition(&valid).is_ok());

        let mut empty_name = valid.clone();
        empty_name.name = String::new();
        assert!(FunctionRegistry::validate_definition(&empty_name).is_err());

        let mut spaced_name = valid.clone();
        spaced_name.name = "get weather".to_string();
        assert!(FunctionRegistry::validate_definition(&spaced_name).is_err());

        let mut empty_description = valid.clone();
        empty_description.description = String::new();
        assert!(FunctionRegistry::validate_definition(&empty_description).is_err());

        let mut array_parameters = valid;
        array_parameters.parameters = serde_json::json!([]);
        assert!(FunctionRegistry::validate_definition(&array_parameters).is_err());
    }

    #[tokio::test]
    async fn test_function_execution() {
        let mut registry = FunctionRegistry::new();

        let def = create_function_definition(
            "add_numbers",
            "Add two numbers",
            serde_json::json!({"type": "object"}),
        );

        let handler: FunctionHandler = Arc::new(|args| {
            Box::pin(async move {
                let a = args["a"].as_i64().unwrap_or(0);
                let b = args["b"].as_i64().unwrap_or(0);
                Ok(serde_json::json!(a + b))
            })
        });

        registry.register(def, handler);

        let call = FunctionCall {
            name: "add_numbers".to_string(),
            arguments: serde_json::json!({"a": 5, "b": 3}),
        };

        let result = registry.execute(call).await;

        assert!(result.success);
        assert_eq!(result.name, "add_numbers");
        assert_eq!(result.result, serde_json::json!(8));
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_function_handler_error() {
        let mut registry = FunctionRegistry::new();

        let def = create_function_definition(
            "always_fails",
            "Always returns an error",
            serde_json::json!({"type": "object"}),
        );

        let handler: FunctionHandler = Arc::new(|_args| {
            Box::pin(async move {
                Err::<serde_json::Value, ZoeyError>(ZoeyError::validation("boom"))
            })
        });

        registry.register(def, handler);

        let call = FunctionCall {
            name: "always_fails".to_string(),
            arguments: serde_json::json!({}),
        };

        let result = registry.execute(call).await;

        assert!(!result.success);
        assert_eq!(result.name, "always_fails");
        assert_eq!(result.result, serde_json::Value::Null);
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn test_function_not_found() {
        let registry = FunctionRegistry::new();

        let call = FunctionCall {
            name: "nonexistent".to_string(),
            arguments: serde_json::json!({}),
        };

        let result = registry.execute(call).await;

        assert!(!result.success);
        assert_eq!(result.name, "nonexistent");
        assert_eq!(result.result, serde_json::Value::Null);
        assert_eq!(
            result.error.as_deref(),
            Some("Function 'nonexistent' not found")
        );
    }
}