1use crate::{Error, Parameter, ParameterType, Result};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9pub struct ToolDefinition {
10 pub name: String,
12
13 #[serde(default)]
15 pub description: String,
16
17 #[serde(default)]
19 pub parameters: Vec<Parameter>,
20}
21
22impl ToolDefinition {
23 pub fn new(name: impl Into<String>) -> Self {
25 Self {
26 name: name.into(),
27 description: String::new(),
28 parameters: Vec::new(),
29 }
30 }
31
32 pub fn builder(name: impl Into<String>) -> ToolDefinitionBuilder {
34 ToolDefinitionBuilder::new(name)
35 }
36
37 pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
39 self.parameters.iter().find(|p| p.name == name)
40 }
41
42 pub fn required_parameters(&self) -> impl Iterator<Item = &Parameter> {
44 self.parameters.iter().filter(|p| p.required)
45 }
46
47 pub fn validate_args(&self, args: &Value) -> Result<()> {
49 let empty_map = serde_json::Map::new();
50 let args_obj = args.as_object().unwrap_or(&empty_map);
51
52 for param in self.required_parameters() {
54 if !args_obj.contains_key(¶m.name) {
55 if param.default.is_none() {
57 return Err(Error::MissingParameter(param.name.clone()));
58 }
59 }
60 }
61
62 for (key, value) in args_obj {
64 if let Some(param) = self.get_parameter(key) {
65 if !param.param_type.matches(value) {
66 return Err(Error::InvalidParameterType {
67 name: key.clone(),
68 expected: param.param_type.as_str().to_string(),
69 actual: json_type_name(value).to_string(),
70 });
71 }
72
73 if !param.enum_values.is_empty() && !param.enum_values.contains(value) {
75 return Err(Error::InvalidConfig(format!(
76 "parameter '{}' must be one of: {:?}",
77 key, param.enum_values
78 )));
79 }
80 }
81 }
82
83 Ok(())
84 }
85
86 pub fn parse_mcp_input_schema(schema: &serde_json::Value) -> Result<Vec<Parameter>> {
88 let mut params = Vec::new();
89
90 if let Some(properties) = schema.get("properties") {
91 if let Some(props_obj) = properties.as_object() {
92 for (name, prop) in props_obj {
93 let param_type = if let Some(type_val) = prop.get("type") {
94 match type_val.as_str() {
95 Some("string") => ParameterType::String,
96 Some("integer") => ParameterType::Integer,
97 Some("number") => ParameterType::Number,
98 Some("boolean") => ParameterType::Boolean,
99 Some("array") => ParameterType::Array,
100 Some("object") => ParameterType::Object,
101 _ => ParameterType::String,
102 }
103 } else {
104 ParameterType::String
105 };
106
107 let description = prop
108 .get("description")
109 .and_then(|v| v.as_str())
110 .unwrap_or("")
111 .to_string();
112
113 let required = if let Some(req_array) = schema.get("required") {
114 if let Some(arr) = req_array.as_array() {
115 arr.iter().any(|v| v.as_str() == Some(name))
116 } else {
117 false
118 }
119 } else {
120 false
121 };
122
123 params.push(Parameter {
124 name: name.to_string(),
125 param_type,
126 description,
127 required,
128 default: None,
129 enum_values: vec![],
130 });
131 }
132 }
133 }
134
135 Ok(params)
136 }
137}
138
139fn json_type_name(value: &Value) -> &'static str {
141 match value {
142 Value::Null => "null",
143 Value::Bool(_) => "boolean",
144 Value::Number(_) => "number",
145 Value::String(_) => "string",
146 Value::Array(_) => "array",
147 Value::Object(_) => "object",
148 }
149}
150
151#[derive(Debug, Default)]
153pub struct ToolDefinitionBuilder {
154 name: String,
155 description: String,
156 parameters: Vec<Parameter>,
157}
158
159impl ToolDefinitionBuilder {
160 pub fn new(name: impl Into<String>) -> Self {
162 Self {
163 name: name.into(),
164 ..Default::default()
165 }
166 }
167
168 pub fn description(mut self, description: impl Into<String>) -> Self {
170 self.description = description.into();
171 self
172 }
173
174 pub fn parameter(mut self, parameter: Parameter) -> Self {
176 self.parameters.push(parameter);
177 self
178 }
179
180 pub fn parameters(mut self, parameters: impl IntoIterator<Item = Parameter>) -> Self {
182 self.parameters.extend(parameters);
183 self
184 }
185
186 pub fn build(self) -> ToolDefinition {
188 ToolDefinition {
189 name: self.name,
190 description: self.description,
191 parameters: self.parameters,
192 }
193 }
194}
195
196#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
198pub struct ToolCall {
199 pub tool: String,
201
202 #[serde(default)]
204 pub arguments: Value,
205}
206
207impl ToolCall {
208 pub fn new(tool: impl Into<String>) -> Self {
210 Self {
211 tool: tool.into(),
212 arguments: Value::Object(serde_json::Map::new()),
213 }
214 }
215
216 pub fn with_args(tool: impl Into<String>, arguments: Value) -> Self {
218 Self {
219 tool: tool.into(),
220 arguments,
221 }
222 }
223
224 pub fn builder(tool: impl Into<String>) -> ToolCallBuilder {
226 ToolCallBuilder::new(tool)
227 }
228}
229
230#[derive(Debug, Default)]
232pub struct ToolCallBuilder {
233 tool: String,
234 arguments: serde_json::Map<String, Value>,
235}
236
237impl ToolCallBuilder {
238 pub fn new(tool: impl Into<String>) -> Self {
240 Self {
241 tool: tool.into(),
242 arguments: serde_json::Map::new(),
243 }
244 }
245
246 pub fn arg(mut self, name: impl Into<String>, value: Value) -> Self {
248 self.arguments.insert(name.into(), value);
249 self
250 }
251
252 pub fn arg_str(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
254 self.arguments
255 .insert(name.into(), Value::String(value.into()));
256 self
257 }
258
259 pub fn arg_int(mut self, name: impl Into<String>, value: i64) -> Self {
261 self.arguments
262 .insert(name.into(), Value::Number(value.into()));
263 self
264 }
265
266 pub fn arg_bool(mut self, name: impl Into<String>, value: bool) -> Self {
268 self.arguments.insert(name.into(), Value::Bool(value));
269 self
270 }
271
272 pub fn build(self) -> ToolCall {
274 ToolCall {
275 tool: self.tool,
276 arguments: Value::Object(self.arguments),
277 }
278 }
279}
280
281#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
283pub struct ToolResult {
284 pub success: bool,
286
287 #[serde(default, skip_serializing_if = "Option::is_none")]
289 pub data: Option<Value>,
290
291 #[serde(default, skip_serializing_if = "Option::is_none")]
293 pub error: Option<String>,
294
295 #[serde(default, skip_serializing_if = "Option::is_none")]
297 pub duration_ms: Option<u64>,
298}
299
300impl ToolResult {
301 pub fn success(data: Value) -> Self {
303 Self {
304 success: true,
305 data: Some(data),
306 error: None,
307 duration_ms: None,
308 }
309 }
310
311 pub fn failure(error: impl Into<String>) -> Self {
313 Self {
314 success: false,
315 data: None,
316 error: Some(error.into()),
317 duration_ms: None,
318 }
319 }
320
321 pub fn with_duration(mut self, duration_ms: u64) -> Self {
323 self.duration_ms = Some(duration_ms);
324 self
325 }
326
327 pub fn is_success(&self) -> bool {
329 self.success
330 }
331
332 pub fn into_data(self) -> Result<Value> {
334 if self.success {
335 self.data
336 .ok_or_else(|| Error::ExecutionFailed("no data returned".to_string()))
337 } else {
338 Err(Error::ExecutionFailed(
339 self.error.unwrap_or_else(|| "unknown error".to_string()),
340 ))
341 }
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ParameterType;
    use serde_json::json;

    /// Shorthand for a parameter of `param_type` that is explicitly marked
    /// required.
    fn required_param(name: &'static str, param_type: ParameterType) -> Parameter {
        Parameter::builder(name)
            .param_type(param_type)
            .required(true)
            .build()
    }

    // ----- ToolDefinition: construction and lookup -----

    #[test]
    fn tool_definition_new() {
        let definition = ToolDefinition::new("test_tool");

        assert_eq!(definition.name, "test_tool");
        assert!(definition.description.is_empty());
        assert!(definition.parameters.is_empty());
    }

    #[test]
    fn tool_definition_builder() {
        let definition = ToolDefinition::builder("read_file")
            .description("Read a file")
            .parameter(Parameter::required_string("path"))
            .build();

        assert_eq!(definition.name, "read_file");
        assert_eq!(definition.description, "Read a file");
        assert_eq!(definition.parameters.len(), 1);
    }

    #[test]
    fn tool_definition_get_parameter() {
        let definition = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("path"))
            .parameter(Parameter::optional_string("encoding"))
            .build();

        assert!(definition.get_parameter("path").is_some());
        assert!(definition.get_parameter("encoding").is_some());
        assert!(definition.get_parameter("nonexistent").is_none());
    }

    #[test]
    fn tool_definition_required_parameters() {
        let definition = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("required1"))
            .parameter(Parameter::optional_string("optional1"))
            .parameter(Parameter::required_string("required2"))
            .build();

        let names: Vec<&str> = definition
            .required_parameters()
            .map(|p| p.name.as_str())
            .collect();

        assert_eq!(names.len(), 2);
        assert!(names.contains(&"required1"));
        assert!(names.contains(&"required2"));
    }

    // ----- ToolDefinition::validate_args -----

    #[test]
    fn tool_definition_validate_args_success() {
        let count = Parameter::builder("count")
            .param_type(ParameterType::Integer)
            .build();
        let definition = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("name"))
            .parameter(count)
            .build();

        let supplied = json!({"name": "test", "count": 5});
        assert!(definition.validate_args(&supplied).is_ok());
    }

    #[test]
    fn tool_definition_validate_args_missing_required() {
        let definition = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("name"))
            .build();

        let outcome = definition.validate_args(&json!({}));
        assert!(matches!(outcome, Err(Error::MissingParameter(_))));
    }

    #[test]
    fn tool_definition_validate_args_with_default() {
        // A required parameter that has a default may be omitted.
        let count = Parameter::builder("count")
            .required(true)
            .default(json!(10))
            .build();
        let definition = ToolDefinition::builder("test").parameter(count).build();

        let supplied = json!({});
        assert!(definition.validate_args(&supplied).is_ok());
    }

    #[test]
    fn tool_definition_validate_args_wrong_type() {
        let count = Parameter::builder("count")
            .param_type(ParameterType::Integer)
            .build();
        let definition = ToolDefinition::builder("test").parameter(count).build();

        let outcome = definition.validate_args(&json!({"count": "not a number"}));
        assert!(matches!(outcome, Err(Error::InvalidParameterType { .. })));
    }

    #[test]
    fn tool_definition_validate_args_enum() {
        let format = Parameter::builder("format")
            .enum_value(json!("json"))
            .enum_value(json!("yaml"))
            .build();
        let definition = ToolDefinition::builder("test").parameter(format).build();

        assert!(definition.validate_args(&json!({"format": "json"})).is_ok());
        assert!(definition.validate_args(&json!({"format": "yaml"})).is_ok());
        assert!(definition.validate_args(&json!({"format": "toml"})).is_err());
    }

    // ----- ToolCall -----

    #[test]
    fn tool_call_new() {
        let call = ToolCall::new("test_tool");

        assert_eq!(call.tool, "test_tool");
        assert!(call.arguments.is_object());
    }

    #[test]
    fn tool_call_with_args() {
        let arguments = json!({"path": "/tmp/test.txt"});
        let call = ToolCall::with_args("read_file", arguments.clone());

        assert_eq!(call.tool, "read_file");
        assert_eq!(call.arguments, arguments);
    }

    #[test]
    fn tool_call_builder() {
        let call = ToolCall::builder("github.list_repos")
            .arg_str("owner", "octocat")
            .arg_int("per_page", 10)
            .arg_bool("include_forks", false)
            .build();

        assert_eq!(call.tool, "github.list_repos");
        assert_eq!(call.arguments["owner"], "octocat");
        assert_eq!(call.arguments["per_page"], 10);
        assert_eq!(call.arguments["include_forks"], false);
    }

    // ----- ToolResult -----

    #[test]
    fn tool_result_success() {
        let outcome = ToolResult::success(json!({"status": "ok"}));

        assert!(outcome.is_success());
        assert!(outcome.data.is_some());
        assert!(outcome.error.is_none());
    }

    #[test]
    fn tool_result_failure() {
        let outcome = ToolResult::failure("Something went wrong");

        assert!(!outcome.is_success());
        assert!(outcome.data.is_none());
        assert_eq!(outcome.error, Some("Something went wrong".to_string()));
    }

    #[test]
    fn tool_result_with_duration() {
        let outcome = ToolResult::success(json!(null)).with_duration(250);
        assert_eq!(outcome.duration_ms, Some(250));
    }

    #[test]
    fn tool_result_into_data_success() {
        let payload = ToolResult::success(json!({"value": 42}))
            .into_data()
            .unwrap();
        assert_eq!(payload["value"], 42);
    }

    #[test]
    fn tool_result_into_data_failure() {
        let failure = ToolResult::failure("error").into_data().unwrap_err();
        assert!(matches!(failure, Error::ExecutionFailed(_)));
    }

    // ----- Serialization round trips and further validation cases -----

    #[test]
    fn tool_definition_serialization() {
        let original = ToolDefinition::builder("test")
            .description("A test tool")
            .parameter(Parameter::required_string("name"))
            .build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolDefinition = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn tool_definition_validate_args_edge_cases() {
        let object_param = Parameter::builder("object_param")
            .param_type(ParameterType::Object)
            .required(false)
            .default(json!({}))
            .build();
        let definition = ToolDefinition::builder("test")
            .parameter(required_param("array_param", ParameterType::Array))
            .parameter(object_param)
            .build();

        // Both parameters supplied with matching types.
        let both_valid = json!({"array_param": [1, 2, 3], "object_param": {"key": "value"}});
        assert!(definition.validate_args(&both_valid).is_ok());

        // Optional parameter omitted, empty array supplied.
        assert!(definition.validate_args(&json!({"array_param": []})).is_ok());

        // Object supplied where an array is declared.
        assert!(definition.validate_args(&json!({"array_param": {}})).is_err());

        // String supplied where an object is declared.
        let bad_object = json!({"array_param": [], "object_param": "not an object"});
        assert!(definition.validate_args(&bad_object).is_err());
    }

    #[test]
    fn tool_definition_validate_args_with_all_types() {
        let definition = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("str_param"))
            .parameter(required_param("int_param", ParameterType::Integer))
            .parameter(required_param("num_param", ParameterType::Number))
            .parameter(required_param("bool_param", ParameterType::Boolean))
            .parameter(required_param("arr_param", ParameterType::Array))
            .parameter(required_param("obj_param", ParameterType::Object))
            .build();

        let all_valid = json!({
            "str_param": "test",
            "int_param": 42,
            "num_param": 3.14,
            "bool_param": true,
            "arr_param": [1, 2, 3],
            "obj_param": {"key": "value"}
        });
        assert!(definition.validate_args(&all_valid).is_ok());

        // Number supplied for the string parameter.
        let bad_string = json!({"str_param": 42, "int_param": 42, "num_param": 3.14, "bool_param": true, "arr_param": [], "obj_param": {}});
        assert!(definition.validate_args(&bad_string).is_err());

        // String supplied for the integer parameter.
        let bad_integer = json!({"str_param": "test", "int_param": "not int", "num_param": 3.14, "bool_param": true, "arr_param": [], "obj_param": {}});
        assert!(definition.validate_args(&bad_integer).is_err());
    }

    #[test]
    fn tool_definition_validate_args_empty_required() {
        let definition = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("param1"))
            .parameter(Parameter::required_string("param2"))
            .parameter(Parameter::required_string("param3"))
            .build();

        // Nothing supplied.
        assert!(definition.validate_args(&json!({})).is_err());

        // Only one of three supplied.
        assert!(definition.validate_args(&json!({"param1": "value"})).is_err());

        // Everything supplied.
        let complete = json!({"param1": "v1", "param2": "v2", "param3": "v3"});
        assert!(definition.validate_args(&complete).is_ok());
    }

    #[test]
    fn tool_call_serialization() {
        let original = ToolCall::builder("test").arg_str("name", "value").build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolCall = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn tool_result_serialization() {
        let original = ToolResult::success(json!({"data": [1, 2, 3]})).with_duration(100);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    // ----- ToolDefinition::parse_mcp_input_schema -----

    #[test]
    fn parse_mcp_input_schema_basic() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "The name"
                },
                "age": {
                    "type": "integer",
                    "description": "The age"
                }
            },
            "required": ["name"]
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 2);

        let name = parsed.iter().find(|p| p.name == "name").unwrap();
        assert_eq!(name.param_type, ParameterType::String);
        assert_eq!(name.description, "The name");
        assert!(name.required);

        let age = parsed.iter().find(|p| p.name == "age").unwrap();
        assert_eq!(age.param_type, ParameterType::Integer);
        assert_eq!(age.description, "The age");
        assert!(!age.required);
    }

    #[test]
    fn parse_mcp_input_schema_all_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "str": {"type": "string"},
                "num": {"type": "number"},
                "int": {"type": "integer"},
                "bool": {"type": "boolean"},
                "arr": {"type": "array"},
                "obj": {"type": "object"}
            }
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 6);

        let expectations = [
            ("str", ParameterType::String),
            ("num", ParameterType::Number),
            ("int", ParameterType::Integer),
            ("bool", ParameterType::Boolean),
            ("arr", ParameterType::Array),
            ("obj", ParameterType::Object),
        ];
        for (name, expected_type) in &expectations {
            let found = parsed.iter().find(|p| p.name == *name).unwrap();
            assert_eq!(&found.param_type, expected_type);
        }
    }

    #[test]
    fn parse_mcp_input_schema_empty() {
        let parsed = ToolDefinition::parse_mcp_input_schema(&json!({})).unwrap();
        assert_eq!(parsed.len(), 0);
    }

    #[test]
    fn parse_mcp_input_schema_no_properties() {
        let parsed = ToolDefinition::parse_mcp_input_schema(&json!({"type": "object"})).unwrap();
        assert_eq!(parsed.len(), 0);
    }

    #[test]
    fn parse_mcp_input_schema_unknown_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "unknown": {"type": "unknown_type"}
            }
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 1);
        // Unrecognised type names fall back to String.
        assert_eq!(parsed[0].param_type, ParameterType::String);
    }

    #[test]
    fn parse_mcp_input_schema_missing_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field": {"description": "A field without type"}
            }
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 1);
        // A property without `type` falls back to String.
        assert_eq!(parsed[0].param_type, ParameterType::String);
    }

    #[test]
    fn parse_mcp_input_schema_all_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field1": {"type": "string"},
                "field2": {"type": "string"},
                "field3": {"type": "string"}
            },
            "required": ["field1", "field2", "field3"]
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 3);
        assert!(parsed.iter().all(|p| p.required));
    }

    #[test]
    fn parse_mcp_input_schema_no_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field1": {"type": "string"},
                "field2": {"type": "string"}
            }
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 2);
        assert!(parsed.iter().all(|p| !p.required));
    }

    #[test]
    fn parse_mcp_input_schema_no_description() {
        let schema = json!({
            "type": "object",
            "properties": {
                "field": {"type": "string"}
            }
        });

        let parsed = ToolDefinition::parse_mcp_input_schema(&schema).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].description, "");
    }
}