spec_ai_core/tools/builtin/
calculator.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use serde::Deserialize;
5use serde_json::Value;
6
/// Calculator tool that evaluates one named math operation on `f64` operands.
///
/// The type is a unit struct and carries no state, so it is cheap to copy and
/// can be debug-printed. Construct it with [`MathTool::new`] or
/// [`Default::default`].
#[derive(Debug, Clone, Copy)]
pub struct MathTool;
9
/// Arguments accepted by the calculator, deserialized from the JSON value
/// passed to `execute`.
///
/// All three fields are required: deserialization fails when any is absent.
#[derive(Debug, Deserialize)]
struct MathArgs {
    /// Operation name or symbolic alias (for example `"add"` or `"+"`).
    operation: String,
    /// First operand; the only operand read by unary operations.
    a: f64,
    /// Second operand; must be present but is ignored by unary operations.
    b: f64,
}
16
17impl MathTool {
18    pub fn new() -> Self {
19        Self
20    }
21
22    fn evaluate(&self, operation: &str, a: f64, b: f64) -> Result<f64> {
23        match operation {
24            "add" | "+" => Ok(a + b),
25            "subtract" | "-" => Ok(a - b),
26            "multiply" | "*" => Ok(a * b),
27            "divide" | "/" => {
28                if b == 0.0 {
29                    anyhow::bail!("Division by zero");
30                }
31                Ok(a / b)
32            }
33            "power" | "**" => Ok(a.powf(b)),
34            "modulo" | "%" => {
35                if b == 0.0 {
36                    anyhow::bail!("Modulo by zero");
37                }
38                Ok(a % b)
39            }
40            // Unary functions – operate on `a`, ignore `b`
41            "sqrt" => {
42                if a < 0.0 {
43                    anyhow::bail!("Cannot compute square root of a negative number");
44                }
45                Ok(a.sqrt())
46            }
47            "abs" => Ok(a.abs()),
48            "exp" => Ok(a.exp()),
49            "ln" => {
50                if a <= 0.0 {
51                    anyhow::bail!("Natural logarithm is only defined for positive values");
52                }
53                Ok(a.ln())
54            }
55            "log10" => {
56                if a <= 0.0 {
57                    anyhow::bail!("Base-10 logarithm is only defined for positive values");
58                }
59                Ok(a.log10())
60            }
61            "log2" => {
62                if a <= 0.0 {
63                    anyhow::bail!("Base-2 logarithm is only defined for positive values");
64                }
65                Ok(a.log2())
66            }
67            "sin" => Ok(a.sin()),
68            "cos" => Ok(a.cos()),
69            "tan" => Ok(a.tan()),
70            "asin" => {
71                if !(-1.0..=1.0).contains(&a) {
72                    anyhow::bail!("asin is only defined for inputs between -1 and 1");
73                }
74                Ok(a.asin())
75            }
76            "acos" => {
77                if !(-1.0..=1.0).contains(&a) {
78                    anyhow::bail!("acos is only defined for inputs between -1 and 1");
79                }
80                Ok(a.acos())
81            }
82            "atan" => Ok(a.atan()),
83            "sinh" => Ok(a.sinh()),
84            "cosh" => Ok(a.cosh()),
85            "tanh" => Ok(a.tanh()),
86            // Binary functions – use both `a` and `b`
87            "min" => Ok(a.min(b)),
88            "max" => Ok(a.max(b)),
89            "hypot" => Ok(a.hypot(b)),
90            "atan2" => Ok(a.atan2(b)),
91            _ => anyhow::bail!("Unsupported operation: {}", operation),
92        }
93    }
94}
95
96impl Default for MathTool {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102#[async_trait]
103impl Tool for MathTool {
104    fn name(&self) -> &str {
105        "calculator"
106    }
107
108    fn description(&self) -> &str {
109        "Calculator tool: performs mathematical operations using a small standard library (arithmetic, powers, modulo, roots, logs, trigonometric and hyperbolic functions, and simple two-argument operations like min/max)"
110    }
111
112    fn parameters(&self) -> Value {
113        serde_json::json!({
114            "type": "object",
115            "properties": {
116                "operation": {
117                    "type": "string",
118                    "description": "The operation to perform. Supports arithmetic (add, subtract, multiply, divide, power, modulo or +, -, *, /, **, %), common unary functions (sqrt, abs, exp, ln, log10, log2, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh), and simple binary functions (min, max, hypot, atan2). For unary functions, only `a` is used.",
119                    "enum": [
120                        "add", "subtract", "multiply", "divide", "power", "modulo",
121                        "+", "-", "*", "/", "**", "%",
122                        "sqrt", "abs", "exp", "ln", "log10", "log2",
123                        "sin", "cos", "tan", "asin", "acos", "atan",
124                        "sinh", "cosh", "tanh",
125                        "min", "max", "hypot", "atan2"
126                    ]
127                },
128                "a": {
129                    "type": "number",
130                    "description": "The first operand (or the sole operand for unary functions)"
131                },
132                "b": {
133                    "type": "number",
134                    "description": "The second operand (ignored for unary functions such as sqrt, ln, sin, etc.)"
135                }
136            },
137            "required": ["operation", "a", "b"]
138        })
139    }
140
141    async fn execute(&self, args: Value) -> Result<ToolResult> {
142        let math_args: MathArgs =
143            serde_json::from_value(args).context("Failed to parse math arguments")?;
144
145        match self.evaluate(&math_args.operation, math_args.a, math_args.b) {
146            Ok(result) => Ok(ToolResult::success(result.to_string())),
147            Err(e) => Ok(ToolResult::failure(e.to_string())),
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes one calculator call with both operands supplied and unwraps
    /// the outer `Result`, so tests can inspect the `ToolResult` directly.
    async fn run(operation: &str, a: f64, b: f64) -> ToolResult {
        let request = serde_json::json!({
            "operation": operation,
            "a": a,
            "b": b
        });
        MathTool::new().execute(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_math_tool_basic() {
        let calculator = MathTool::new();

        assert_eq!(calculator.name(), "calculator");
        assert!(!calculator.description().is_empty());
    }

    #[tokio::test]
    async fn test_math_tool_parameters() {
        let schema = MathTool::new().parameters();

        assert!(schema.is_object());
        for key in ["operation", "a", "b"] {
            assert!(schema["properties"][key].is_object());
        }
    }

    #[tokio::test]
    async fn test_math_tool_add() {
        let outcome = run("add", 5.0, 3.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "8");
    }

    #[tokio::test]
    async fn test_math_tool_add_symbol() {
        let outcome = run("+", 10.5, 2.5).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "13");
    }

    #[tokio::test]
    async fn test_math_tool_subtract() {
        let outcome = run("subtract", 10.0, 3.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "7");
    }

    #[tokio::test]
    async fn test_math_tool_multiply() {
        let outcome = run("multiply", 4.0, 5.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "20");
    }

    #[tokio::test]
    async fn test_math_tool_divide() {
        let outcome = run("divide", 15.0, 3.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "5");
    }

    #[tokio::test]
    async fn test_math_tool_divide_by_zero() {
        let outcome = run("divide", 10.0, 0.0).await;

        // Evaluation errors surface as a failed ToolResult, not an Err.
        assert!(!outcome.success);
        assert!(outcome.error.is_some());
        assert!(outcome.error.unwrap().contains("Division by zero"));
    }

    #[tokio::test]
    async fn test_math_tool_power() {
        let outcome = run("power", 2.0, 3.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "8");
    }

    #[tokio::test]
    async fn test_math_tool_modulo() {
        let outcome = run("modulo", 10.0, 3.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "1");
    }

    #[tokio::test]
    async fn test_math_tool_modulo_by_zero() {
        let outcome = run("modulo", 10.0, 0.0).await;

        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[tokio::test]
    async fn test_math_tool_invalid_operation() {
        let outcome = run("invalid", 10.0, 3.0).await;

        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[tokio::test]
    async fn test_math_tool_missing_arguments() {
        // Operands are omitted on purpose, so argument parsing itself fails.
        let request = serde_json::json!({
            "operation": "add"
        });

        let outcome = MathTool::new().execute(request).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_math_tool_negative_numbers() {
        let outcome = run("add", -5.0, 3.0).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, "-2");
    }

    #[tokio::test]
    async fn test_math_tool_decimal_numbers() {
        let outcome = run("multiply", 2.5, 4.2).await;

        assert!(outcome.success);
        let parsed: f64 = outcome.output.parse().unwrap();
        assert!((parsed - 10.5).abs() < 0.0001);
    }

    #[tokio::test]
    async fn test_math_tool_sqrt() {
        // `b` is required by the argument struct even though sqrt ignores it.
        let outcome = run("sqrt", 16.0, 0.0).await;

        assert!(outcome.success);
        let parsed: f64 = outcome.output.parse().unwrap();
        assert!((parsed - 4.0).abs() < 0.0001);
    }

    #[tokio::test]
    async fn test_math_tool_sin() {
        // sin(pi/2) ≈ 1.0
        let outcome = run("sin", std::f64::consts::FRAC_PI_2, 0.0).await;

        assert!(outcome.success);
        let parsed: f64 = outcome.output.parse().unwrap();
        assert!((parsed - 1.0).abs() < 0.0001);
    }
}