1use std::collections::HashMap;
4use std::future::Future;
5use std::hash::BuildHasher;
6
7use futures_util::future::BoxFuture;
8use rust_genai_types::content::{FunctionCall, FunctionResponse, Part};
9use rust_genai_types::models::GenerateContentConfig;
10use rust_genai_types::tool::{FunctionDeclaration, Tool};
11use serde_json::Value;
12
13use crate::error::{Error, Result};
14
/// Number of automatic function-calling round trips allowed when the request
/// config does not set `maximum_remote_calls` (see `max_remote_calls`).
pub const DEFAULT_MAX_REMOTE_CALLS: usize = 10;
17
18pub trait CallableTool: Send {
20 fn tool(&mut self) -> BoxFuture<'_, Result<Tool>>;
21 fn call_tool(&mut self, function_calls: Vec<FunctionCall>) -> BoxFuture<'_, Result<Vec<Part>>>;
22}
23
24pub type ToolHandler =
26 Box<dyn Fn(Value) -> BoxFuture<'static, Result<Value>> + Send + Sync + 'static>;
27
28#[derive(Default)]
30pub struct InlineCallableTool {
31 tool: Tool,
32 handlers: HashMap<String, ToolHandler>,
33}
34
35impl InlineCallableTool {
36 #[must_use]
38 pub fn from_declarations(declarations: Vec<FunctionDeclaration>) -> Self {
39 Self {
40 tool: Tool {
41 function_declarations: Some(declarations),
42 ..Tool::default()
43 },
44 handlers: HashMap::new(),
45 }
46 }
47
48 pub fn register_handler<F, Fut>(&mut self, name: impl Into<String>, handler: F)
50 where
51 F: Fn(Value) -> Fut + Send + Sync + 'static,
52 Fut: Future<Output = Result<Value>> + Send + 'static,
53 {
54 let key = name.into();
55 self.handlers.insert(
56 key,
57 Box::new(move |value| {
58 let fut = handler(value);
59 Box::pin(fut)
60 }),
61 );
62 }
63
64 #[must_use]
66 pub fn with_handler<F, Fut>(mut self, name: impl Into<String>, handler: F) -> Self
67 where
68 F: Fn(Value) -> Fut + Send + Sync + 'static,
69 Fut: Future<Output = Result<Value>> + Send + 'static,
70 {
71 self.register_handler(name, handler);
72 self
73 }
74}
75
76impl CallableTool for InlineCallableTool {
77 fn tool(&mut self) -> BoxFuture<'_, Result<Tool>> {
78 Box::pin(async move { Ok(self.tool.clone()) })
79 }
80
81 fn call_tool(&mut self, function_calls: Vec<FunctionCall>) -> BoxFuture<'_, Result<Vec<Part>>> {
82 Box::pin(async move {
83 let mut parts = Vec::new();
84 for call in function_calls {
85 let Some(name) = call.name.as_ref() else {
86 continue;
87 };
88 let Some(handler) = self.handlers.get(name) else {
89 continue;
90 };
91 let args = call.args.clone().unwrap_or(Value::Null);
92 let response_value = handler(args).await?;
93 let function_response = FunctionResponse {
94 will_continue: None,
95 scheduling: None,
96 parts: None,
97 id: call.id.clone(),
98 name: Some(name.clone()),
99 response: Some(response_value),
100 };
101 parts.push(Part::function_response(function_response));
102 }
103 Ok(parts)
104 })
105 }
106}
107
108pub async fn resolve_callable_tools(
113 callable_tools: &mut [Box<dyn CallableTool>],
114) -> Result<CallableToolInfo> {
115 let mut tools = Vec::new();
116 let mut function_map: HashMap<String, usize> = HashMap::new();
117
118 for (index, tool) in callable_tools.iter_mut().enumerate() {
119 let declaration_tool = tool.tool().await?;
120 if let Some(declarations) = &declaration_tool.function_declarations {
121 for declaration in declarations {
122 if function_map.contains_key(&declaration.name) {
123 return Err(Error::InvalidConfig {
124 message: format!("Duplicate tool declaration name: {}", declaration.name),
125 });
126 }
127 function_map.insert(declaration.name.clone(), index);
128 }
129 }
130 tools.push(declaration_tool);
131 }
132
133 Ok(CallableToolInfo {
134 tools,
135 function_map,
136 })
137}
138
139pub async fn call_callable_tools<S: BuildHasher + Sync>(
144 callable_tools: &mut [Box<dyn CallableTool>],
145 function_map: &HashMap<String, usize, S>,
146 function_calls: &[FunctionCall],
147) -> Result<Vec<Part>> {
148 let mut grouped: HashMap<usize, Vec<FunctionCall>> = HashMap::new();
149 for call in function_calls {
150 let name = call.name.as_ref().ok_or_else(|| Error::InvalidConfig {
151 message: "Function call name was not returned by the model.".into(),
152 })?;
153 let index = function_map.get(name).ok_or_else(|| Error::InvalidConfig {
154 message: format!(
155 "Automatic function calling was requested, but not all the tools the model used implement the CallableTool interface. Missing tool: {name}."
156 ),
157 })?;
158 grouped.entry(*index).or_default().push(call.clone());
159 }
160
161 let mut parts = Vec::new();
162 for (index, calls) in grouped {
163 let response_parts = callable_tools[index].call_tool(calls).await?;
164 parts.extend(response_parts);
165 }
166 Ok(parts)
167}
168
169pub struct CallableToolInfo<S = std::collections::hash_map::RandomState> {
171 pub tools: Vec<Tool>,
172 pub function_map: HashMap<String, usize, S>,
173}
174
175#[must_use]
177pub fn should_disable_afc(config: &GenerateContentConfig, has_callable_tools: bool) -> bool {
178 if !has_callable_tools {
179 return true;
180 }
181 if config
182 .automatic_function_calling
183 .as_ref()
184 .and_then(|cfg| cfg.disable)
185 .unwrap_or(false)
186 {
187 return true;
188 }
189 if let Some(max_calls) = config
190 .automatic_function_calling
191 .as_ref()
192 .and_then(|cfg| cfg.maximum_remote_calls)
193 {
194 if max_calls <= 0 {
195 return true;
196 }
197 }
198 false
199}
200
201#[must_use]
203pub fn max_remote_calls(config: &GenerateContentConfig) -> usize {
204 config
205 .automatic_function_calling
206 .as_ref()
207 .and_then(|cfg| cfg.maximum_remote_calls)
208 .and_then(|value| usize::try_from(value).ok())
209 .unwrap_or(DEFAULT_MAX_REMOTE_CALLS)
210}
211
212#[must_use]
214pub fn should_append_history(config: &GenerateContentConfig) -> bool {
215 !config
216 .automatic_function_calling
217 .as_ref()
218 .and_then(|cfg| cfg.ignore_call_history)
219 .unwrap_or(false)
220}
221
222pub fn validate_afc_tools<S: BuildHasher>(
227 _callable_function_map: &HashMap<String, usize, S>,
228 tools: Option<&[Tool]>,
229) -> Result<()> {
230 let Some(tools) = tools else {
231 return Ok(());
232 };
233
234 for tool in tools {
235 if let Some(declarations) = &tool.function_declarations {
236 if !declarations.is_empty() {
237 return Err(Error::InvalidConfig {
238 message: "Incompatible tools found. Automatic function calling does not support mixing CallableTools with basic function declarations.".into(),
239 });
240 }
241 }
242 }
243 Ok(())
244}
245
246pub fn validate_afc_config(config: &GenerateContentConfig) -> Result<()> {
251 if config
252 .tool_config
253 .as_ref()
254 .and_then(|cfg| cfg.function_calling_config.as_ref())
255 .and_then(|cfg| cfg.stream_function_call_arguments)
256 .unwrap_or(false)
257 && !config
258 .automatic_function_calling
259 .as_ref()
260 .and_then(|cfg| cfg.disable)
261 .unwrap_or(false)
262 {
263 return Err(Error::InvalidConfig {
264 message: "stream_function_call_arguments is not compatible with automatic function calling. Disable AFC or disable stream_function_call_arguments.".into(),
265 });
266 }
267 Ok(())
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::models::AutomaticFunctionCallingConfig;
    use rust_genai_types::tool::{FunctionDeclaration, Tool};
    use serde_json::json;

    /// Builds a declaration that carries only a name.
    fn named_declaration(name: &str) -> FunctionDeclaration {
        FunctionDeclaration {
            name: name.to_string(),
            description: None,
            parameters: None,
            parameters_json_schema: None,
            response: None,
            response_json_schema: None,
            behavior: None,
        }
    }

    /// Builds a function call with no streaming metadata.
    fn function_call(id: Option<&str>, name: Option<&str>, args: Option<Value>) -> FunctionCall {
        FunctionCall {
            id: id.map(Into::into),
            name: name.map(Into::into),
            args,
            partial_args: None,
            will_continue: None,
        }
    }

    /// Wraps an AFC section into an otherwise default request config.
    fn config_with_afc(afc: AutomaticFunctionCallingConfig) -> GenerateContentConfig {
        GenerateContentConfig {
            automatic_function_calling: Some(afc),
            ..Default::default()
        }
    }

    #[test]
    fn test_should_disable_afc_when_max_calls_zero() {
        let config = config_with_afc(AutomaticFunctionCallingConfig {
            maximum_remote_calls: Some(0),
            ..Default::default()
        });
        assert!(should_disable_afc(&config, true));
    }

    #[test]
    fn test_should_append_history_respects_ignore_flag() {
        let config = config_with_afc(AutomaticFunctionCallingConfig {
            ignore_call_history: Some(true),
            ..Default::default()
        });
        assert!(!should_append_history(&config));
    }

    #[test]
    fn test_validate_afc_tools_rejects_plain_declarations() {
        let plain = Tool {
            function_declarations: Some(vec![named_declaration("test_fn")]),
            ..Default::default()
        };
        let outcome = validate_afc_tools(&HashMap::new(), Some(&[plain]));
        assert!(matches!(outcome, Err(Error::InvalidConfig { .. })));
    }

    #[tokio::test]
    async fn test_inline_callable_tool_roundtrip() {
        let mut sum_tool = InlineCallableTool::from_declarations(vec![named_declaration("sum")]);
        sum_tool.register_handler("sum", |args| async move {
            let lhs = args["a"].as_i64().unwrap_or(0);
            let rhs = args["b"].as_i64().unwrap_or(0);
            Ok(json!({ "result": lhs + rhs }))
        });

        let mut tools: Vec<Box<dyn CallableTool>> = vec![Box::new(sum_tool)];
        let info = resolve_callable_tools(&mut tools).await.unwrap();
        assert!(info.function_map.contains_key("sum"));

        let calls = [function_call(
            Some("call-1"),
            Some("sum"),
            Some(json!({"a": 1, "b": 2})),
        )];
        let parts = call_callable_tools(&mut tools, &info.function_map, &calls)
            .await
            .unwrap();
        assert_eq!(parts.len(), 1);
    }

    #[tokio::test]
    async fn test_call_callable_tools_rejects_missing_name() {
        let mut tools: Vec<Box<dyn CallableTool>> = Vec::new();
        let calls = [function_call(None, None, None)];
        let outcome = call_callable_tools(&mut tools, &HashMap::new(), &calls).await;
        assert!(matches!(outcome, Err(Error::InvalidConfig { .. })));
    }

    #[tokio::test]
    async fn test_call_callable_tools_rejects_unknown_tool() {
        let calls = [function_call(Some("call-1"), Some("missing"), None)];
        let outcome = call_callable_tools(&mut [], &HashMap::new(), &calls).await;
        assert!(matches!(outcome, Err(Error::InvalidConfig { .. })));
    }
}