1use crate::{Error, Parameter, ParameterType, Result};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9pub struct ToolDefinition {
10 pub name: String,
12
13 #[serde(default)]
15 pub description: String,
16
17 #[serde(default)]
19 pub parameters: Vec<Parameter>,
20}
21
22impl ToolDefinition {
23 pub fn new(name: impl Into<String>) -> Self {
25 Self {
26 name: name.into(),
27 description: String::new(),
28 parameters: Vec::new(),
29 }
30 }
31
32 pub fn builder(name: impl Into<String>) -> ToolDefinitionBuilder {
34 ToolDefinitionBuilder::new(name)
35 }
36
37 pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
39 self.parameters.iter().find(|p| p.name == name)
40 }
41
42 pub fn required_parameters(&self) -> impl Iterator<Item = &Parameter> {
44 self.parameters.iter().filter(|p| p.required)
45 }
46
47 pub fn validate_args(&self, args: &Value) -> Result<()> {
49 let empty_map = serde_json::Map::new();
50 let args_obj = args.as_object().unwrap_or(&empty_map);
51
52 for param in self.required_parameters() {
54 if !args_obj.contains_key(¶m.name) {
55 if param.default.is_none() {
57 return Err(Error::MissingParameter(param.name.clone()));
58 }
59 }
60 }
61
62 for (key, value) in args_obj {
64 if let Some(param) = self.get_parameter(key) {
65 if !param.param_type.matches(value) {
66 return Err(Error::InvalidParameterType {
67 name: key.clone(),
68 expected: param.param_type.as_str().to_string(),
69 actual: json_type_name(value).to_string(),
70 });
71 }
72
73 if !param.enum_values.is_empty() && !param.enum_values.contains(value) {
75 return Err(Error::InvalidConfig(format!(
76 "parameter '{}' must be one of: {:?}",
77 key, param.enum_values
78 )));
79 }
80 }
81 }
82
83 Ok(())
84 }
85
86 pub fn parse_mcp_input_schema(schema: &serde_json::Value) -> Result<Vec<Parameter>> {
88 let mut params = Vec::new();
89
90 if let Some(properties) = schema.get("properties") {
91 if let Some(props_obj) = properties.as_object() {
92 for (name, prop) in props_obj {
93 let param_type = if let Some(type_val) = prop.get("type") {
94 match type_val.as_str() {
95 Some("string") => ParameterType::String,
96 Some("number") | Some("integer") => ParameterType::Integer,
97 Some("boolean") => ParameterType::Boolean,
98 Some("array") => ParameterType::Array,
99 Some("object") => ParameterType::Object,
100 _ => ParameterType::String,
101 }
102 } else {
103 ParameterType::String
104 };
105
106 let description = prop
107 .get("description")
108 .and_then(|v| v.as_str())
109 .unwrap_or("")
110 .to_string();
111
112 let required = if let Some(req_array) = schema.get("required") {
113 if let Some(arr) = req_array.as_array() {
114 arr.iter().any(|v| v.as_str() == Some(name))
115 } else {
116 false
117 }
118 } else {
119 false
120 };
121
122 params.push(Parameter {
123 name: name.to_string(),
124 param_type,
125 description,
126 required,
127 default: None,
128 enum_values: vec![],
129 });
130 }
131 }
132 }
133
134 Ok(params)
135 }
136}
137
138fn json_type_name(value: &Value) -> &'static str {
140 match value {
141 Value::Null => "null",
142 Value::Bool(_) => "boolean",
143 Value::Number(_) => "number",
144 Value::String(_) => "string",
145 Value::Array(_) => "array",
146 Value::Object(_) => "object",
147 }
148}
149
150#[derive(Debug, Default)]
152pub struct ToolDefinitionBuilder {
153 name: String,
154 description: String,
155 parameters: Vec<Parameter>,
156}
157
158impl ToolDefinitionBuilder {
159 pub fn new(name: impl Into<String>) -> Self {
161 Self {
162 name: name.into(),
163 ..Default::default()
164 }
165 }
166
167 pub fn description(mut self, description: impl Into<String>) -> Self {
169 self.description = description.into();
170 self
171 }
172
173 pub fn parameter(mut self, parameter: Parameter) -> Self {
175 self.parameters.push(parameter);
176 self
177 }
178
179 pub fn parameters(mut self, parameters: impl IntoIterator<Item = Parameter>) -> Self {
181 self.parameters.extend(parameters);
182 self
183 }
184
185 pub fn build(self) -> ToolDefinition {
187 ToolDefinition {
188 name: self.name,
189 description: self.description,
190 parameters: self.parameters,
191 }
192 }
193}
194
195#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
197pub struct ToolCall {
198 pub tool: String,
200
201 #[serde(default)]
203 pub arguments: Value,
204}
205
206impl ToolCall {
207 pub fn new(tool: impl Into<String>) -> Self {
209 Self {
210 tool: tool.into(),
211 arguments: Value::Object(serde_json::Map::new()),
212 }
213 }
214
215 pub fn with_args(tool: impl Into<String>, arguments: Value) -> Self {
217 Self {
218 tool: tool.into(),
219 arguments,
220 }
221 }
222
223 pub fn builder(tool: impl Into<String>) -> ToolCallBuilder {
225 ToolCallBuilder::new(tool)
226 }
227}
228
229#[derive(Debug, Default)]
231pub struct ToolCallBuilder {
232 tool: String,
233 arguments: serde_json::Map<String, Value>,
234}
235
236impl ToolCallBuilder {
237 pub fn new(tool: impl Into<String>) -> Self {
239 Self {
240 tool: tool.into(),
241 arguments: serde_json::Map::new(),
242 }
243 }
244
245 pub fn arg(mut self, name: impl Into<String>, value: Value) -> Self {
247 self.arguments.insert(name.into(), value);
248 self
249 }
250
251 pub fn arg_str(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
253 self.arguments
254 .insert(name.into(), Value::String(value.into()));
255 self
256 }
257
258 pub fn arg_int(mut self, name: impl Into<String>, value: i64) -> Self {
260 self.arguments
261 .insert(name.into(), Value::Number(value.into()));
262 self
263 }
264
265 pub fn arg_bool(mut self, name: impl Into<String>, value: bool) -> Self {
267 self.arguments.insert(name.into(), Value::Bool(value));
268 self
269 }
270
271 pub fn build(self) -> ToolCall {
273 ToolCall {
274 tool: self.tool,
275 arguments: Value::Object(self.arguments),
276 }
277 }
278}
279
280#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
282pub struct ToolResult {
283 pub success: bool,
285
286 #[serde(default, skip_serializing_if = "Option::is_none")]
288 pub data: Option<Value>,
289
290 #[serde(default, skip_serializing_if = "Option::is_none")]
292 pub error: Option<String>,
293
294 #[serde(default, skip_serializing_if = "Option::is_none")]
296 pub duration_ms: Option<u64>,
297}
298
299impl ToolResult {
300 pub fn success(data: Value) -> Self {
302 Self {
303 success: true,
304 data: Some(data),
305 error: None,
306 duration_ms: None,
307 }
308 }
309
310 pub fn failure(error: impl Into<String>) -> Self {
312 Self {
313 success: false,
314 data: None,
315 error: Some(error.into()),
316 duration_ms: None,
317 }
318 }
319
320 pub fn with_duration(mut self, duration_ms: u64) -> Self {
322 self.duration_ms = Some(duration_ms);
323 self
324 }
325
326 pub fn is_success(&self) -> bool {
328 self.success
329 }
330
331 pub fn into_data(self) -> Result<Value> {
333 if self.success {
334 self.data
335 .ok_or_else(|| Error::ExecutionFailed("no data returned".to_string()))
336 } else {
337 Err(Error::ExecutionFailed(
338 self.error.unwrap_or_else(|| "unknown error".to_string()),
339 ))
340 }
341 }
342}
343
// Unit tests for tool definitions, calls, results and MCP schema parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ParameterType;
    use serde_json::json;

    // `new` yields a named tool with no description and no parameters.
    #[test]
    fn tool_definition_new() {
        let tool = ToolDefinition::new("test_tool");
        assert_eq!(tool.name, "test_tool");
        assert!(tool.description.is_empty());
        assert!(tool.parameters.is_empty());
    }

    // The builder records name, description and each added parameter.
    #[test]
    fn tool_definition_builder() {
        let tool = ToolDefinition::builder("read_file")
            .description("Read a file")
            .parameter(Parameter::required_string("path"))
            .build();

        assert_eq!(tool.name, "read_file");
        assert_eq!(tool.description, "Read a file");
        assert_eq!(tool.parameters.len(), 1);
    }

    // Lookup by name finds declared parameters and misses undeclared ones.
    #[test]
    fn tool_definition_get_parameter() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("path"))
            .parameter(Parameter::optional_string("encoding"))
            .build();

        assert!(tool.get_parameter("path").is_some());
        assert!(tool.get_parameter("encoding").is_some());
        assert!(tool.get_parameter("nonexistent").is_none());
    }

    // Only parameters flagged as required are yielded.
    #[test]
    fn tool_definition_required_parameters() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("required1"))
            .parameter(Parameter::optional_string("optional1"))
            .parameter(Parameter::required_string("required2"))
            .build();

        let required: Vec<_> = tool.required_parameters().collect();
        assert_eq!(required.len(), 2);
        assert!(required.iter().any(|p| p.name == "required1"));
        assert!(required.iter().any(|p| p.name == "required2"));
    }

    // Well-typed arguments for every declared parameter validate.
    #[test]
    fn tool_definition_validate_args_success() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("name"))
            .parameter(
                Parameter::builder("count")
                    .param_type(ParameterType::Integer)
                    .build(),
            )
            .build();

        let args = json!({"name": "test", "count": 5});
        assert!(tool.validate_args(&args).is_ok());
    }

    // Omitting a required parameter without a default is a MissingParameter error.
    #[test]
    fn tool_definition_validate_args_missing_required() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("name"))
            .build();

        let args = json!({});
        let result = tool.validate_args(&args);
        assert!(matches!(result, Err(Error::MissingParameter(_))));
    }

    // A required parameter with a default may be omitted.
    #[test]
    fn tool_definition_validate_args_with_default() {
        let tool = ToolDefinition::builder("test")
            .parameter(
                Parameter::builder("count")
                    .required(true)
                    .default(json!(10))
                    .build(),
            )
            .build();

        let args = json!({});
        assert!(tool.validate_args(&args).is_ok());
    }

    // A value of the wrong JSON type is an InvalidParameterType error.
    #[test]
    fn tool_definition_validate_args_wrong_type() {
        let tool = ToolDefinition::builder("test")
            .parameter(
                Parameter::builder("count")
                    .param_type(ParameterType::Integer)
                    .build(),
            )
            .build();

        let args = json!({"count": "not a number"});
        let result = tool.validate_args(&args);
        assert!(matches!(result, Err(Error::InvalidParameterType { .. })));
    }

    // Enum-constrained parameters accept listed values and reject others.
    #[test]
    fn tool_definition_validate_args_enum() {
        let tool = ToolDefinition::builder("test")
            .parameter(
                Parameter::builder("format")
                    .enum_value(json!("json"))
                    .enum_value(json!("yaml"))
                    .build(),
            )
            .build();

        assert!(tool.validate_args(&json!({"format": "json"})).is_ok());
        assert!(tool.validate_args(&json!({"format": "yaml"})).is_ok());
        assert!(tool.validate_args(&json!({"format": "toml"})).is_err());
    }

    // `ToolCall::new` starts with an (empty) object for arguments.
    #[test]
    fn tool_call_new() {
        let call = ToolCall::new("test_tool");
        assert_eq!(call.tool, "test_tool");
        assert!(call.arguments.is_object());
    }

    // `with_args` stores the given arguments verbatim.
    #[test]
    fn tool_call_with_args() {
        let args = json!({"path": "/tmp/test.txt"});
        let call = ToolCall::with_args("read_file", args.clone());
        assert_eq!(call.tool, "read_file");
        assert_eq!(call.arguments, args);
    }

    // Typed builder helpers insert string, integer and boolean values.
    #[test]
    fn tool_call_builder() {
        let call = ToolCall::builder("github.list_repos")
            .arg_str("owner", "octocat")
            .arg_int("per_page", 10)
            .arg_bool("include_forks", false)
            .build();

        assert_eq!(call.tool, "github.list_repos");
        assert_eq!(call.arguments["owner"], "octocat");
        assert_eq!(call.arguments["per_page"], 10);
        assert_eq!(call.arguments["include_forks"], false);
    }

    // A success result carries data and no error.
    #[test]
    fn tool_result_success() {
        let result = ToolResult::success(json!({"status": "ok"}));
        assert!(result.is_success());
        assert!(result.data.is_some());
        assert!(result.error.is_none());
    }

    // A failure result carries the error message and no data.
    #[test]
    fn tool_result_failure() {
        let result = ToolResult::failure("Something went wrong");
        assert!(!result.is_success());
        assert!(result.data.is_none());
        assert_eq!(result.error, Some("Something went wrong".to_string()));
    }

    // `with_duration` records the elapsed time.
    #[test]
    fn tool_result_with_duration() {
        let result = ToolResult::success(json!(null)).with_duration(250);
        assert_eq!(result.duration_ms, Some(250));
    }

    // `into_data` unwraps the payload of a success result.
    #[test]
    fn tool_result_into_data_success() {
        let result = ToolResult::success(json!({"value": 42}));
        let data = result.into_data().unwrap();
        assert_eq!(data["value"], 42);
    }

    // `into_data` on a failure result yields ExecutionFailed.
    #[test]
    fn tool_result_into_data_failure() {
        let result = ToolResult::failure("error");
        let err = result.into_data().unwrap_err();
        assert!(matches!(err, Error::ExecutionFailed(_)));
    }

    // A ToolDefinition survives a JSON round trip unchanged.
    #[test]
    fn tool_definition_serialization() {
        let tool = ToolDefinition::builder("test")
            .description("A test tool")
            .parameter(Parameter::required_string("name"))
            .build();

        let json = serde_json::to_string(&tool).unwrap();
        let parsed: ToolDefinition = serde_json::from_str(&json).unwrap();

        assert_eq!(tool, parsed);
    }

    // Array/object parameters: valid values, an omitted optional, and wrong types.
    #[test]
    fn tool_definition_validate_args_edge_cases() {
        let tool = ToolDefinition::builder("test")
            .parameter(
                Parameter::builder("array_param")
                    .param_type(ParameterType::Array)
                    .required(true)
                    .build(),
            )
            .parameter(
                Parameter::builder("object_param")
                    .param_type(ParameterType::Object)
                    .required(false)
                    .default(json!({}))
                    .build(),
            )
            .build();

        // Both parameters supplied with matching types.
        assert!(tool
            .validate_args(&json!({"array_param": [1, 2, 3], "object_param": {"key": "value"}}))
            .is_ok());

        // The optional object parameter may be omitted; an empty array is valid.
        assert!(tool.validate_args(&json!({"array_param": []})).is_ok());

        // An object where an array is declared is rejected.
        assert!(tool.validate_args(&json!({"array_param": {}})).is_err());

        // A string where an object is declared is rejected.
        assert!(tool
            .validate_args(&json!({"array_param": [], "object_param": "not an object"}))
            .is_err());
    }

    // One parameter of every ParameterType: all valid, then one mistyped at a time.
    #[test]
    fn tool_definition_validate_args_with_all_types() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("str_param"))
            .parameter(
                Parameter::builder("int_param")
                    .param_type(ParameterType::Integer)
                    .required(true)
                    .build(),
            )
            .parameter(
                Parameter::builder("num_param")
                    .param_type(ParameterType::Number)
                    .required(true)
                    .build(),
            )
            .parameter(
                Parameter::builder("bool_param")
                    .param_type(ParameterType::Boolean)
                    .required(true)
                    .build(),
            )
            .parameter(
                Parameter::builder("arr_param")
                    .param_type(ParameterType::Array)
                    .required(true)
                    .build(),
            )
            .parameter(
                Parameter::builder("obj_param")
                    .param_type(ParameterType::Object)
                    .required(true)
                    .build(),
            )
            .build();

        // 2.5 is used instead of 3.14: Clippy's deny-by-default `approx_constant`
        // lint flags 3.14 as an approximation of PI.
        let args = json!({
            "str_param": "test",
            "int_param": 42,
            "num_param": 2.5,
            "bool_param": true,
            "arr_param": [1, 2, 3],
            "obj_param": {"key": "value"}
        });

        assert!(tool.validate_args(&args).is_ok());

        // A number where a string is declared is rejected.
        assert!(tool
            .validate_args(&json!({
                "str_param": 42,
                "int_param": 42,
                "num_param": 2.5,
                "bool_param": true,
                "arr_param": [],
                "obj_param": {}
            }))
            .is_err());

        // A string where an integer is declared is rejected.
        assert!(tool
            .validate_args(&json!({
                "str_param": "test",
                "int_param": "not int",
                "num_param": 2.5,
                "bool_param": true,
                "arr_param": [],
                "obj_param": {}
            }))
            .is_err());
    }

    // Every required parameter must be present for validation to pass.
    #[test]
    fn tool_definition_validate_args_empty_required() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("param1"))
            .parameter(Parameter::required_string("param2"))
            .parameter(Parameter::required_string("param3"))
            .build();

        // None supplied.
        assert!(tool.validate_args(&json!({})).is_err());

        // Only some supplied.
        assert!(tool.validate_args(&json!({"param1": "value"})).is_err());

        // All supplied.
        assert!(tool
            .validate_args(&json!({"param1": "v1", "param2": "v2", "param3": "v3"}))
            .is_ok());
    }

    // A ToolCall survives a JSON round trip unchanged.
    #[test]
    fn tool_call_serialization() {
        let call = ToolCall::builder("test").arg_str("name", "value").build();

        let json = serde_json::to_string(&call).unwrap();
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();

        assert_eq!(call, parsed);
    }

    // A ToolResult (with data and duration) survives a JSON round trip unchanged.
    #[test]
    fn tool_result_serialization() {
        let result = ToolResult::success(json!({"data": [1, 2, 3]})).with_duration(100);

        let json = serde_json::to_string(&result).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result, parsed);
    }

    // Types, descriptions and the `required` list are read from the schema.
    #[test]
    fn parse_mcp_input_schema_basic() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "The name"
                },
                "age": {
                    "type": "integer",
                    "description": "The age"
                }
            },
            "required": ["name"]
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 2);

        let name_param = params.iter().find(|p| p.name == "name").unwrap();
        assert_eq!(name_param.param_type, ParameterType::String);
        assert_eq!(name_param.description, "The name");
        assert!(name_param.required);

        let age_param = params.iter().find(|p| p.name == "age").unwrap();
        assert_eq!(age_param.param_type, ParameterType::Integer);
        assert_eq!(age_param.description, "The age");
        assert!(!age_param.required);
    }

    // Pins the JSON Schema `type` -> ParameterType mapping (note: "number" -> Integer).
    #[test]
    fn parse_mcp_input_schema_all_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "str": {"type": "string"},
                "num": {"type": "number"},
                "int": {"type": "integer"},
                "bool": {"type": "boolean"},
                "arr": {"type": "array"},
                "obj": {"type": "object"}
            }
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 6);

        assert_eq!(
            params.iter().find(|p| p.name == "str").unwrap().param_type,
            ParameterType::String
        );
        assert_eq!(
            params.iter().find(|p| p.name == "num").unwrap().param_type,
            ParameterType::Integer
        );
        assert_eq!(
            params.iter().find(|p| p.name == "int").unwrap().param_type,
            ParameterType::Integer
        );
        assert_eq!(
            params.iter().find(|p| p.name == "bool").unwrap().param_type,
            ParameterType::Boolean
        );
        assert_eq!(
            params.iter().find(|p| p.name == "arr").unwrap().param_type,
            ParameterType::Array
        );
        assert_eq!(
            params.iter().find(|p| p.name == "obj").unwrap().param_type,
            ParameterType::Object
        );
    }

    // An empty schema yields no parameters.
    #[test]
    fn parse_mcp_input_schema_empty() {
        let schema = json!({});
        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 0);
    }

    // A schema without `properties` yields no parameters.
    #[test]
    fn parse_mcp_input_schema_no_properties() {
        let schema = json!({"type": "object"});
        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 0);
    }

    // Unrecognised type names fall back to String.
    #[test]
    fn parse_mcp_input_schema_unknown_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "unknown": {"type": "unknown_type"}
            }
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].param_type, ParameterType::String);
    }

    // A property without `type` falls back to String.
    #[test]
    fn parse_mcp_input_schema_missing_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field": {"description": "A field without type"}
            }
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].param_type, ParameterType::String);
    }

    // Every property listed in `required` is marked required.
    #[test]
    fn parse_mcp_input_schema_all_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field1": {"type": "string"},
                "field2": {"type": "string"},
                "field3": {"type": "string"}
            },
            "required": ["field1", "field2", "field3"]
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 3);
        assert!(params.iter().all(|p| p.required));
    }

    // Without a `required` array nothing is marked required.
    #[test]
    fn parse_mcp_input_schema_no_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field1": {"type": "string"},
                "field2": {"type": "string"}
            }
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 2);
        assert!(params.iter().all(|p| !p.required));
    }

    // A missing `description` becomes an empty string.
    #[test]
    fn parse_mcp_input_schema_no_description() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field": {"type": "string"}
            }
        });

        let params = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].description, "");
    }
}