spec_ai_core/tools/builtin/
calculator.rs1use crate::tools::{Tool, ToolResult};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use serde::Deserialize;
5use serde_json::Value;
6
/// Stateless calculator tool.
///
/// The type carries no data; every invocation is driven entirely by the
/// arguments handed to the tool at execution time.
///
/// `Debug`, `Clone` and `Copy` are derived so the tool can be logged,
/// duplicated freely, and embedded in containers that themselves derive
/// `Debug`.
#[derive(Debug, Clone, Copy)]
pub struct MathTool;
9
10#[derive(Debug, Deserialize)]
11struct MathArgs {
12 operation: String,
13 a: f64,
14 b: f64,
15}
16
17impl MathTool {
18 pub fn new() -> Self {
19 Self
20 }
21
22 fn evaluate(&self, operation: &str, a: f64, b: f64) -> Result<f64> {
23 match operation {
24 "add" | "+" => Ok(a + b),
25 "subtract" | "-" => Ok(a - b),
26 "multiply" | "*" => Ok(a * b),
27 "divide" | "/" => {
28 if b == 0.0 {
29 anyhow::bail!("Division by zero");
30 }
31 Ok(a / b)
32 }
33 "power" | "**" => Ok(a.powf(b)),
34 "modulo" | "%" => {
35 if b == 0.0 {
36 anyhow::bail!("Modulo by zero");
37 }
38 Ok(a % b)
39 }
40 "sqrt" => {
42 if a < 0.0 {
43 anyhow::bail!("Cannot compute square root of a negative number");
44 }
45 Ok(a.sqrt())
46 }
47 "abs" => Ok(a.abs()),
48 "exp" => Ok(a.exp()),
49 "ln" => {
50 if a <= 0.0 {
51 anyhow::bail!("Natural logarithm is only defined for positive values");
52 }
53 Ok(a.ln())
54 }
55 "log10" => {
56 if a <= 0.0 {
57 anyhow::bail!("Base-10 logarithm is only defined for positive values");
58 }
59 Ok(a.log10())
60 }
61 "log2" => {
62 if a <= 0.0 {
63 anyhow::bail!("Base-2 logarithm is only defined for positive values");
64 }
65 Ok(a.log2())
66 }
67 "sin" => Ok(a.sin()),
68 "cos" => Ok(a.cos()),
69 "tan" => Ok(a.tan()),
70 "asin" => {
71 if !(-1.0..=1.0).contains(&a) {
72 anyhow::bail!("asin is only defined for inputs between -1 and 1");
73 }
74 Ok(a.asin())
75 }
76 "acos" => {
77 if !(-1.0..=1.0).contains(&a) {
78 anyhow::bail!("acos is only defined for inputs between -1 and 1");
79 }
80 Ok(a.acos())
81 }
82 "atan" => Ok(a.atan()),
83 "sinh" => Ok(a.sinh()),
84 "cosh" => Ok(a.cosh()),
85 "tanh" => Ok(a.tanh()),
86 "min" => Ok(a.min(b)),
88 "max" => Ok(a.max(b)),
89 "hypot" => Ok(a.hypot(b)),
90 "atan2" => Ok(a.atan2(b)),
91 _ => anyhow::bail!("Unsupported operation: {}", operation),
92 }
93 }
94}
95
96impl Default for MathTool {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102#[async_trait]
103impl Tool for MathTool {
104 fn name(&self) -> &str {
105 "calculator"
106 }
107
108 fn description(&self) -> &str {
109 "Calculator tool: performs mathematical operations using a small standard library (arithmetic, powers, modulo, roots, logs, trigonometric and hyperbolic functions, and simple two-argument operations like min/max)"
110 }
111
112 fn parameters(&self) -> Value {
113 serde_json::json!({
114 "type": "object",
115 "properties": {
116 "operation": {
117 "type": "string",
118 "description": "The operation to perform. Supports arithmetic (add, subtract, multiply, divide, power, modulo or +, -, *, /, **, %), common unary functions (sqrt, abs, exp, ln, log10, log2, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh), and simple binary functions (min, max, hypot, atan2). For unary functions, only `a` is used.",
119 "enum": [
120 "add", "subtract", "multiply", "divide", "power", "modulo",
121 "+", "-", "*", "/", "**", "%",
122 "sqrt", "abs", "exp", "ln", "log10", "log2",
123 "sin", "cos", "tan", "asin", "acos", "atan",
124 "sinh", "cosh", "tanh",
125 "min", "max", "hypot", "atan2"
126 ]
127 },
128 "a": {
129 "type": "number",
130 "description": "The first operand (or the sole operand for unary functions)"
131 },
132 "b": {
133 "type": "number",
134 "description": "The second operand (ignored for unary functions such as sqrt, ln, sin, etc.)"
135 }
136 },
137 "required": ["operation", "a", "b"]
138 })
139 }
140
141 async fn execute(&self, args: Value) -> Result<ToolResult> {
142 let math_args: MathArgs =
143 serde_json::from_value(args).context("Failed to parse math arguments")?;
144
145 match self.evaluate(&math_args.operation, math_args.a, math_args.b) {
146 Ok(result) => Ok(ToolResult::success(result.to_string())),
147 Err(e) => Ok(ToolResult::failure(e.to_string())),
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point outputs.
    const EPSILON: f64 = 0.0001;

    /// Executes the calculator with a complete argument set and returns the
    /// tool result, panicking if the call itself returns an error.
    async fn run(operation: &str, a: f64, b: f64) -> ToolResult {
        let payload = serde_json::json!({
            "operation": operation,
            "a": a,
            "b": b
        });

        MathTool::new().execute(payload).await.unwrap()
    }

    /// Asserts that the operation succeeds and prints exactly `expected`.
    async fn assert_output(operation: &str, a: f64, b: f64, expected: &str) {
        let outcome = run(operation, a, b).await;

        assert!(outcome.success);
        assert_eq!(outcome.output, expected);
    }

    /// Asserts that the operation succeeds with a value within `EPSILON` of
    /// `expected`.
    async fn assert_close(operation: &str, a: f64, b: f64, expected: f64) {
        let outcome = run(operation, a, b).await;

        assert!(outcome.success);
        let value: f64 = outcome.output.parse().unwrap();
        assert!((value - expected).abs() < EPSILON);
    }

    /// Asserts that the operation is reported as a failure with an error
    /// message attached, and returns that message.
    async fn expect_failure(operation: &str, a: f64, b: f64) -> String {
        let outcome = run(operation, a, b).await;

        assert!(!outcome.success);
        assert!(outcome.error.is_some());
        outcome.error.unwrap()
    }

    #[tokio::test]
    async fn test_math_tool_basic() {
        let calculator = MathTool::new();

        assert_eq!(calculator.name(), "calculator");
        assert!(!calculator.description().is_empty());
    }

    #[tokio::test]
    async fn test_math_tool_parameters() {
        let schema = MathTool::new().parameters();

        assert!(schema.is_object());
        for property in ["operation", "a", "b"] {
            assert!(schema["properties"][property].is_object());
        }
    }

    #[tokio::test]
    async fn test_math_tool_add() {
        assert_output("add", 5.0, 3.0, "8").await;
    }

    #[tokio::test]
    async fn test_math_tool_add_symbol() {
        assert_output("+", 10.5, 2.5, "13").await;
    }

    #[tokio::test]
    async fn test_math_tool_subtract() {
        assert_output("subtract", 10.0, 3.0, "7").await;
    }

    #[tokio::test]
    async fn test_math_tool_multiply() {
        assert_output("multiply", 4.0, 5.0, "20").await;
    }

    #[tokio::test]
    async fn test_math_tool_divide() {
        assert_output("divide", 15.0, 3.0, "5").await;
    }

    #[tokio::test]
    async fn test_math_tool_divide_by_zero() {
        let message = expect_failure("divide", 10.0, 0.0).await;

        assert!(message.contains("Division by zero"));
    }

    #[tokio::test]
    async fn test_math_tool_power() {
        assert_output("power", 2.0, 3.0, "8").await;
    }

    #[tokio::test]
    async fn test_math_tool_modulo() {
        assert_output("modulo", 10.0, 3.0, "1").await;
    }

    #[tokio::test]
    async fn test_math_tool_modulo_by_zero() {
        expect_failure("modulo", 10.0, 0.0).await;
    }

    #[tokio::test]
    async fn test_math_tool_invalid_operation() {
        expect_failure("invalid", 10.0, 3.0).await;
    }

    #[tokio::test]
    async fn test_math_tool_missing_arguments() {
        // Operands are omitted on purpose: parsing must fail before any
        // evaluation takes place.
        let incomplete = serde_json::json!({
            "operation": "add"
        });

        let outcome = MathTool::new().execute(incomplete).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_math_tool_negative_numbers() {
        assert_output("add", -5.0, 3.0, "-2").await;
    }

    #[tokio::test]
    async fn test_math_tool_decimal_numbers() {
        assert_close("multiply", 2.5, 4.2, 10.5).await;
    }

    #[tokio::test]
    async fn test_math_tool_sqrt() {
        assert_close("sqrt", 16.0, 0.0, 4.0).await;
    }

    #[tokio::test]
    async fn test_math_tool_sin() {
        assert_close("sin", std::f64::consts::FRAC_PI_2, 0.0, 1.0).await;
    }
}