Skip to main content

rust_genai/
afc.rs

1//! Automatic Function Calling (AFC) helpers.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::hash::BuildHasher;
6
7use futures_util::future::BoxFuture;
8use rust_genai_types::content::{FunctionCall, FunctionResponse, Part};
9use rust_genai_types::models::GenerateContentConfig;
10use rust_genai_types::tool::{FunctionDeclaration, Tool};
11use serde_json::Value;
12
13use crate::error::{Error, Result};
14
/// Fallback cap on the number of automatic remote function calls, used by
/// `max_remote_calls` when the config does not provide a usable limit.
pub const DEFAULT_MAX_REMOTE_CALLS: usize = 10;
17
18/// 可调用工具接口。
19pub trait CallableTool: Send {
20    fn tool(&mut self) -> BoxFuture<'_, Result<Tool>>;
21    fn call_tool(&mut self, function_calls: Vec<FunctionCall>) -> BoxFuture<'_, Result<Vec<Part>>>;
22}
23
24/// Inline callable tool handler 类型。
25pub type ToolHandler =
26    Box<dyn Fn(Value) -> BoxFuture<'static, Result<Value>> + Send + Sync + 'static>;
27
28/// 以函数声明 + handler 组合的可调用工具。
29#[derive(Default)]
30pub struct InlineCallableTool {
31    tool: Tool,
32    handlers: HashMap<String, ToolHandler>,
33}
34
35impl InlineCallableTool {
36    /// 通过 `FunctionDeclaration` 列表创建工具。
37    #[must_use]
38    pub fn from_declarations(declarations: Vec<FunctionDeclaration>) -> Self {
39        Self {
40            tool: Tool {
41                function_declarations: Some(declarations),
42                ..Tool::default()
43            },
44            handlers: HashMap::new(),
45        }
46    }
47
48    /// 注册 handler。
49    pub fn register_handler<F, Fut>(&mut self, name: impl Into<String>, handler: F)
50    where
51        F: Fn(Value) -> Fut + Send + Sync + 'static,
52        Fut: Future<Output = Result<Value>> + Send + 'static,
53    {
54        let key = name.into();
55        self.handlers.insert(
56            key,
57            Box::new(move |value| {
58                let fut = handler(value);
59                Box::pin(fut)
60            }),
61        );
62    }
63
64    /// 使用 builder 风格注册 handler。
65    #[must_use]
66    pub fn with_handler<F, Fut>(mut self, name: impl Into<String>, handler: F) -> Self
67    where
68        F: Fn(Value) -> Fut + Send + Sync + 'static,
69        Fut: Future<Output = Result<Value>> + Send + 'static,
70    {
71        self.register_handler(name, handler);
72        self
73    }
74}
75
76impl CallableTool for InlineCallableTool {
77    fn tool(&mut self) -> BoxFuture<'_, Result<Tool>> {
78        Box::pin(async move { Ok(self.tool.clone()) })
79    }
80
81    fn call_tool(&mut self, function_calls: Vec<FunctionCall>) -> BoxFuture<'_, Result<Vec<Part>>> {
82        Box::pin(async move {
83            let mut parts = Vec::new();
84            for call in function_calls {
85                let Some(name) = call.name.as_ref() else {
86                    continue;
87                };
88                let Some(handler) = self.handlers.get(name) else {
89                    continue;
90                };
91                let args = call.args.clone().unwrap_or(Value::Null);
92                let response_value = handler(args).await?;
93                let function_response = FunctionResponse {
94                    will_continue: None,
95                    scheduling: None,
96                    parts: None,
97                    id: call.id.clone(),
98                    name: Some(name.clone()),
99                    response: Some(response_value),
100                };
101                parts.push(Part::function_response(function_response));
102            }
103            Ok(parts)
104        })
105    }
106}
107
108/// 解析 callable tools,返回声明列表与函数映射。
109///
110/// # Errors
111/// 当工具声明重复或工具返回错误时返回错误。
112pub async fn resolve_callable_tools(
113    callable_tools: &mut [Box<dyn CallableTool>],
114) -> Result<CallableToolInfo> {
115    let mut tools = Vec::new();
116    let mut function_map: HashMap<String, usize> = HashMap::new();
117
118    for (index, tool) in callable_tools.iter_mut().enumerate() {
119        let declaration_tool = tool.tool().await?;
120        if let Some(declarations) = &declaration_tool.function_declarations {
121            for declaration in declarations {
122                if function_map.contains_key(&declaration.name) {
123                    return Err(Error::InvalidConfig {
124                        message: format!("Duplicate tool declaration name: {}", declaration.name),
125                    });
126                }
127                function_map.insert(declaration.name.clone(), index);
128            }
129        }
130        tools.push(declaration_tool);
131    }
132
133    Ok(CallableToolInfo {
134        tools,
135        function_map,
136    })
137}
138
139/// 调用 callable tools。
140///
141/// # Errors
142/// 当函数调用缺少工具或工具调用失败时返回错误。
143pub async fn call_callable_tools<S: BuildHasher + Sync>(
144    callable_tools: &mut [Box<dyn CallableTool>],
145    function_map: &HashMap<String, usize, S>,
146    function_calls: &[FunctionCall],
147) -> Result<Vec<Part>> {
148    let mut grouped: HashMap<usize, Vec<FunctionCall>> = HashMap::new();
149    for call in function_calls {
150        let name = call.name.as_ref().ok_or_else(|| Error::InvalidConfig {
151            message: "Function call name was not returned by the model.".into(),
152        })?;
153        let index = function_map.get(name).ok_or_else(|| Error::InvalidConfig {
154            message: format!(
155                "Automatic function calling was requested, but not all the tools the model used implement the CallableTool interface. Missing tool: {name}."
156            ),
157        })?;
158        grouped.entry(*index).or_default().push(call.clone());
159    }
160
161    let mut parts = Vec::new();
162    for (index, calls) in grouped {
163        let response_parts = callable_tools[index].call_tool(calls).await?;
164        parts.extend(response_parts);
165    }
166    Ok(parts)
167}
168
169/// callable tools 解析结果。
170pub struct CallableToolInfo<S = std::collections::hash_map::RandomState> {
171    pub tools: Vec<Tool>,
172    pub function_map: HashMap<String, usize, S>,
173}
174
175/// 判断是否应禁用 AFC。
176#[must_use]
177pub fn should_disable_afc(config: &GenerateContentConfig, has_callable_tools: bool) -> bool {
178    if !has_callable_tools {
179        return true;
180    }
181    if config
182        .automatic_function_calling
183        .as_ref()
184        .and_then(|cfg| cfg.disable)
185        .unwrap_or(false)
186    {
187        return true;
188    }
189    if let Some(max_calls) = config
190        .automatic_function_calling
191        .as_ref()
192        .and_then(|cfg| cfg.maximum_remote_calls)
193    {
194        if max_calls <= 0 {
195            return true;
196        }
197    }
198    false
199}
200
201/// 获取最大远程调用次数。
202#[must_use]
203pub fn max_remote_calls(config: &GenerateContentConfig) -> usize {
204    config
205        .automatic_function_calling
206        .as_ref()
207        .and_then(|cfg| cfg.maximum_remote_calls)
208        .and_then(|value| usize::try_from(value).ok())
209        .unwrap_or(DEFAULT_MAX_REMOTE_CALLS)
210}
211
212/// 是否应附加 AFC 历史。
213#[must_use]
214pub fn should_append_history(config: &GenerateContentConfig) -> bool {
215    !config
216        .automatic_function_calling
217        .as_ref()
218        .and_then(|cfg| cfg.ignore_call_history)
219        .unwrap_or(false)
220}
221
222/// 检查 AFC 兼容性(禁止未实现 `CallableTool` 的 function declarations)。
223///
224/// # Errors
225/// 当发现不兼容工具时返回错误。
226pub fn validate_afc_tools<S: BuildHasher>(
227    _callable_function_map: &HashMap<String, usize, S>,
228    tools: Option<&[Tool]>,
229) -> Result<()> {
230    let Some(tools) = tools else {
231        return Ok(());
232    };
233
234    for tool in tools {
235        if let Some(declarations) = &tool.function_declarations {
236            if !declarations.is_empty() {
237                return Err(Error::InvalidConfig {
238                    message: "Incompatible tools found. Automatic function calling does not support mixing CallableTools with basic function declarations.".into(),
239                });
240            }
241        }
242    }
243    Ok(())
244}
245
246/// 校验 AFC 与其他配置的冲突。
247///
248/// # Errors
249/// 当检测到不兼容配置时返回错误。
250pub fn validate_afc_config(config: &GenerateContentConfig) -> Result<()> {
251    if config
252        .tool_config
253        .as_ref()
254        .and_then(|cfg| cfg.function_calling_config.as_ref())
255        .and_then(|cfg| cfg.stream_function_call_arguments)
256        .unwrap_or(false)
257        && !config
258            .automatic_function_calling
259            .as_ref()
260            .and_then(|cfg| cfg.disable)
261            .unwrap_or(false)
262    {
263        return Err(Error::InvalidConfig {
264            message: "stream_function_call_arguments is not compatible with automatic function calling. Disable AFC or disable stream_function_call_arguments.".into(),
265        });
266    }
267    Ok(())
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::models::AutomaticFunctionCallingConfig;
    use rust_genai_types::tool::{FunctionDeclaration, Tool};
    use serde_json::json;

    /// Builds a declaration that only carries a name.
    fn bare_declaration(name: &str) -> FunctionDeclaration {
        FunctionDeclaration {
            name: name.to_string(),
            description: None,
            parameters: None,
            parameters_json_schema: None,
            response: None,
            response_json_schema: None,
            behavior: None,
        }
    }

    /// Builds a function call with the given id, name and arguments.
    fn function_call(id: Option<&str>, name: Option<&str>, args: Option<Value>) -> FunctionCall {
        FunctionCall {
            id: id.map(Into::into),
            name: name.map(Into::into),
            args,
            partial_args: None,
            will_continue: None,
        }
    }

    /// Builds a config whose only customization is the given AFC settings.
    fn config_with_afc(afc: AutomaticFunctionCallingConfig) -> GenerateContentConfig {
        GenerateContentConfig {
            automatic_function_calling: Some(afc),
            ..Default::default()
        }
    }

    #[test]
    fn test_should_disable_afc_when_max_calls_zero() {
        let config = config_with_afc(AutomaticFunctionCallingConfig {
            maximum_remote_calls: Some(0),
            ..Default::default()
        });
        assert!(should_disable_afc(&config, true));
    }

    #[test]
    fn test_should_append_history_respects_ignore_flag() {
        let config = config_with_afc(AutomaticFunctionCallingConfig {
            ignore_call_history: Some(true),
            ..Default::default()
        });
        assert!(!should_append_history(&config));
    }

    #[test]
    fn test_validate_afc_tools_rejects_plain_declarations() {
        let tool = Tool {
            function_declarations: Some(vec![bare_declaration("test_fn")]),
            ..Default::default()
        };
        let err = validate_afc_tools(&HashMap::new(), Some(&[tool])).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[tokio::test]
    async fn test_inline_callable_tool_roundtrip() {
        let mut tool = InlineCallableTool::from_declarations(vec![bare_declaration("sum")]);
        tool.register_handler("sum", |value| async move {
            let a = value["a"].as_i64().unwrap_or(0);
            let b = value["b"].as_i64().unwrap_or(0);
            Ok(json!({ "result": a + b }))
        });

        let mut tools: Vec<Box<dyn CallableTool>> = vec![Box::new(tool)];
        let info = resolve_callable_tools(&mut tools).await.unwrap();
        assert!(info.function_map.contains_key("sum"));

        let calls = vec![function_call(
            Some("call-1"),
            Some("sum"),
            Some(json!({"a": 1, "b": 2})),
        )];
        let parts = call_callable_tools(&mut tools, &info.function_map, &calls)
            .await
            .unwrap();
        assert_eq!(parts.len(), 1);
    }

    #[tokio::test]
    async fn test_call_callable_tools_rejects_missing_name() {
        let mut tools: Vec<Box<dyn CallableTool>> = Vec::new();
        let calls = vec![function_call(None, None, None)];
        let err = call_callable_tools(&mut tools, &HashMap::new(), &calls)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[tokio::test]
    async fn test_call_callable_tools_rejects_unknown_tool() {
        let calls = vec![function_call(Some("call-1"), Some("missing"), None)];
        let err = call_callable_tools(&mut [], &HashMap::new(), &calls)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }
}