1pub mod executor;
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "lowercase")]
20pub enum ParamType {
21 String,
22 Number,
23 Integer,
24 Boolean,
25 Array,
26 Object,
27}
28
29impl ParamType {
30 fn as_schema_str(self) -> &'static str {
31 match self {
32 Self::String => "string",
33 Self::Number => "number",
34 Self::Integer => "integer",
35 Self::Boolean => "boolean",
36 Self::Array => "array",
37 Self::Object => "object",
38 }
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
44pub struct ToolParameter {
45 pub name: String,
46 pub description: String,
47 pub param_type: ParamType,
48 pub required: bool,
49}
50
51impl ToolParameter {
52 pub fn new(name: &str, description: &str, param_type: ParamType, required: bool) -> Self {
53 Self {
54 name: name.to_string(),
55 description: description.to_string(),
56 param_type,
57 required,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
68pub enum ToolCategory {
69 Perception,
70 Navigation,
71 Cognition,
72 Swarm,
73 Memory,
74 Planning,
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83pub struct ToolDefinition {
84 pub name: String,
85 pub description: String,
86 pub parameters: Vec<ToolParameter>,
87 pub category: ToolCategory,
88}
89
90impl ToolDefinition {
91 pub fn new(
92 name: &str,
93 description: &str,
94 parameters: Vec<ToolParameter>,
95 category: ToolCategory,
96 ) -> Self {
97 Self {
98 name: name.to_string(),
99 description: description.to_string(),
100 parameters,
101 category,
102 }
103 }
104
105 fn to_schema(&self) -> serde_json::Value {
107 let mut properties = serde_json::Map::new();
108 let mut required: Vec<serde_json::Value> = Vec::new();
109
110 for param in &self.parameters {
111 let mut prop = serde_json::Map::new();
112 prop.insert(
113 "type".to_string(),
114 serde_json::Value::String(param.param_type.as_schema_str().to_string()),
115 );
116 prop.insert(
117 "description".to_string(),
118 serde_json::Value::String(param.description.clone()),
119 );
120 properties.insert(param.name.clone(), serde_json::Value::Object(prop));
121
122 if param.required {
123 required.push(serde_json::Value::String(param.name.clone()));
124 }
125 }
126
127 serde_json::json!({
128 "name": self.name,
129 "description": self.description,
130 "inputSchema": {
131 "type": "object",
132 "properties": properties,
133 "required": required,
134 }
135 })
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct ToolRequest {
146 pub tool_name: String,
147 pub arguments: HashMap<String, serde_json::Value>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ToolResponse {
153 pub success: bool,
154 pub result: serde_json::Value,
155 pub error: Option<String>,
156 pub latency_us: u64,
157}
158
159impl ToolResponse {
160 pub fn ok(result: serde_json::Value, latency_us: u64) -> Self {
162 Self { success: true, result, error: None, latency_us }
163 }
164
165 pub fn err(message: impl Into<String>, latency_us: u64) -> Self {
167 Self {
168 success: false,
169 result: serde_json::Value::Null,
170 error: Some(message.into()),
171 latency_us,
172 }
173 }
174}
175
176#[derive(Debug, Clone)]
186pub struct RoboticsToolRegistry {
187 tools: HashMap<String, ToolDefinition>,
188}
189
190impl Default for RoboticsToolRegistry {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196impl RoboticsToolRegistry {
197 pub fn new() -> Self {
199 let mut registry = Self { tools: HashMap::new() };
200 registry.register_defaults();
201 registry
202 }
203
204 pub fn empty() -> Self {
206 Self { tools: HashMap::new() }
207 }
208
209 pub fn register_tool(&mut self, tool: ToolDefinition) {
211 self.tools.insert(tool.name.clone(), tool);
212 }
213
214 pub fn list_tools(&self) -> Vec<&ToolDefinition> {
216 self.tools.values().collect()
217 }
218
219 pub fn get_tool(&self, name: &str) -> Option<&ToolDefinition> {
221 self.tools.get(name)
222 }
223
224 pub fn list_by_category(&self, category: ToolCategory) -> Vec<&ToolDefinition> {
226 self.tools.values().filter(|t| t.category == category).collect()
227 }
228
229 pub fn to_mcp_schema(&self) -> serde_json::Value {
231 let mut tools: Vec<serde_json::Value> =
232 self.tools.values().map(|t| t.to_schema()).collect();
233 tools.sort_by(|a, b| {
235 let na = a.get("name").and_then(|v| v.as_str()).unwrap_or("");
236 let nb = b.get("name").and_then(|v| v.as_str()).unwrap_or("");
237 na.cmp(nb)
238 });
239 serde_json::json!({ "tools": tools })
240 }
241
242 fn register_defaults(&mut self) {
245 self.register_tool(ToolDefinition::new(
246 "detect_obstacles",
247 "Detect obstacles in a point cloud relative to the robot position",
248 vec![
249 ToolParameter::new(
250 "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true,
251 ),
252 ToolParameter::new(
253 "robot_position", "Robot [x,y,z] position", ParamType::Array, true,
254 ),
255 ToolParameter::new(
256 "max_distance", "Maximum detection distance in meters", ParamType::Number, false,
257 ),
258 ],
259 ToolCategory::Perception,
260 ));
261
262 self.register_tool(ToolDefinition::new(
263 "build_scene_graph",
264 "Build a scene graph from detected objects with spatial edges",
265 vec![
266 ToolParameter::new(
267 "objects_json", "JSON array of scene objects", ParamType::String, true,
268 ),
269 ToolParameter::new(
270 "max_edge_distance", "Maximum edge distance in meters", ParamType::Number, false,
271 ),
272 ],
273 ToolCategory::Perception,
274 ));
275
276 self.register_tool(ToolDefinition::new(
277 "predict_trajectory",
278 "Predict future trajectory from current position and velocity",
279 vec![
280 ToolParameter::new("position", "Current [x,y,z] position", ParamType::Array, true),
281 ToolParameter::new("velocity", "Current [vx,vy,vz] velocity", ParamType::Array, true),
282 ToolParameter::new("steps", "Number of prediction steps", ParamType::Integer, true),
283 ToolParameter::new("dt", "Time step in seconds", ParamType::Number, false),
284 ],
285 ToolCategory::Navigation,
286 ));
287
288 self.register_tool(ToolDefinition::new(
289 "focus_attention",
290 "Extract a region of interest from a point cloud by center and radius",
291 vec![
292 ToolParameter::new(
293 "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true,
294 ),
295 ToolParameter::new("center", "Focus center [x,y,z]", ParamType::Array, true),
296 ToolParameter::new("radius", "Attention radius in meters", ParamType::Number, true),
297 ],
298 ToolCategory::Perception,
299 ));
300
301 self.register_tool(ToolDefinition::new(
302 "detect_anomalies",
303 "Detect anomalous points in a point cloud using statistical analysis",
304 vec![
305 ToolParameter::new(
306 "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true,
307 ),
308 ],
309 ToolCategory::Perception,
310 ));
311
312 self.register_tool(ToolDefinition::new(
313 "spatial_search",
314 "Search for nearest neighbours in the spatial index",
315 vec![
316 ToolParameter::new("query", "Query vector [x,y,z]", ParamType::Array, true),
317 ToolParameter::new("k", "Number of neighbours to return", ParamType::Integer, true),
318 ],
319 ToolCategory::Perception,
320 ));
321
322 self.register_tool(ToolDefinition::new(
323 "insert_points",
324 "Insert points into the spatial index for later retrieval",
325 vec![
326 ToolParameter::new(
327 "points_json", "JSON array of [x,y,z] points", ParamType::String, true,
328 ),
329 ],
330 ToolCategory::Perception,
331 ));
332
333 self.register_tool(ToolDefinition::new(
334 "store_memory",
335 "Store a vector in episodic memory with an importance score",
336 vec![
337 ToolParameter::new("key", "Unique memory key", ParamType::String, true),
338 ToolParameter::new("data", "Data vector to store", ParamType::Array, true),
339 ToolParameter::new(
340 "importance", "Importance weight 0.0-1.0", ParamType::Number, false,
341 ),
342 ],
343 ToolCategory::Memory,
344 ));
345
346 self.register_tool(ToolDefinition::new(
347 "recall_memory",
348 "Recall the k most similar memories to a query vector",
349 vec![
350 ToolParameter::new(
351 "query", "Query vector for similarity search", ParamType::Array, true,
352 ),
353 ToolParameter::new("k", "Number of memories to recall", ParamType::Integer, true),
354 ],
355 ToolCategory::Memory,
356 ));
357
358 self.register_tool(ToolDefinition::new(
359 "learn_skill",
360 "Learn a new skill from demonstration trajectories",
361 vec![
362 ToolParameter::new("name", "Skill name identifier", ParamType::String, true),
363 ToolParameter::new(
364 "demonstrations_json",
365 "JSON array of demonstration trajectories",
366 ParamType::String,
367 true,
368 ),
369 ],
370 ToolCategory::Cognition,
371 ));
372
373 self.register_tool(ToolDefinition::new(
374 "execute_skill",
375 "Execute a previously learned skill by name",
376 vec![
377 ToolParameter::new("name", "Name of the skill to execute", ParamType::String, true),
378 ],
379 ToolCategory::Cognition,
380 ));
381
382 self.register_tool(ToolDefinition::new(
383 "plan_behavior",
384 "Generate a behavior tree plan for a given goal and preconditions",
385 vec![
386 ToolParameter::new("goal", "Goal description", ParamType::String, true),
387 ToolParameter::new(
388 "conditions_json",
389 "JSON object of current conditions",
390 ParamType::String,
391 false,
392 ),
393 ],
394 ToolCategory::Planning,
395 ));
396
397 self.register_tool(ToolDefinition::new(
398 "coordinate_swarm",
399 "Coordinate a multi-robot swarm for a given task",
400 vec![
401 ToolParameter::new(
402 "task_json", "JSON-encoded task specification", ParamType::String, true,
403 ),
404 ],
405 ToolCategory::Swarm,
406 ));
407
408 self.register_tool(ToolDefinition::new(
409 "update_world_model",
410 "Update the internal world model with a new or changed object",
411 vec![
412 ToolParameter::new(
413 "object_json", "JSON-encoded object to upsert", ParamType::String, true,
414 ),
415 ],
416 ToolCategory::Cognition,
417 ));
418
419 self.register_tool(ToolDefinition::new(
420 "get_world_state",
421 "Retrieve the current world model state, optionally filtered by object id",
422 vec![
423 ToolParameter::new(
424 "object_id", "Optional object id to filter", ParamType::Integer, false,
425 ),
426 ],
427 ToolCategory::Cognition,
428 ));
429 }
430}
431
#[cfg(test)]
mod tests {
    // Unit tests for the registry's default contents, lookups and schema output, plus
    // serde round-trips of the public data types.
    use super::*;

    /// Every `ParamType` variant, shared by the exhaustive serde checks below.
    const ALL_PARAM_TYPES: [ParamType; 6] = [
        ParamType::String,
        ParamType::Number,
        ParamType::Integer,
        ParamType::Boolean,
        ParamType::Array,
        ParamType::Object,
    ];

    #[test]
    fn test_registry_has_15_default_tools() {
        let registry = RoboticsToolRegistry::new();
        assert_eq!(registry.list_tools().len(), 15);
    }

    #[test]
    fn test_list_tools_returns_all() {
        let registry = RoboticsToolRegistry::new();
        let tools = registry.list_tools();
        // Sort locally so the assertion does not depend on the registry's listing order.
        let mut names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        names.sort();

        let expected = vec![
            "build_scene_graph",
            "coordinate_swarm",
            "detect_anomalies",
            "detect_obstacles",
            "execute_skill",
            "focus_attention",
            "get_world_state",
            "insert_points",
            "learn_skill",
            "plan_behavior",
            "predict_trajectory",
            "recall_memory",
            "spatial_search",
            "store_memory",
            "update_world_model",
        ];
        assert_eq!(names, expected);
    }

    #[test]
    fn test_get_tool_by_name() {
        let registry = RoboticsToolRegistry::new();

        let tool = registry.get_tool("detect_obstacles").unwrap();
        assert_eq!(tool.category, ToolCategory::Perception);
        assert_eq!(tool.parameters.len(), 3);
        assert!(tool.parameters.iter().any(|p| p.name == "point_cloud_json" && p.required));

        let tool = registry.get_tool("predict_trajectory").unwrap();
        assert_eq!(tool.category, ToolCategory::Navigation);
        assert_eq!(tool.parameters.len(), 4);

        assert!(registry.get_tool("nonexistent").is_none());
    }

    #[test]
    fn test_list_by_category_perception() {
        let registry = RoboticsToolRegistry::new();
        let perception = registry.list_by_category(ToolCategory::Perception);
        assert_eq!(perception.len(), 6);
        for tool in &perception {
            assert_eq!(tool.category, ToolCategory::Perception);
        }
    }

    #[test]
    fn test_list_by_category_counts() {
        let registry = RoboticsToolRegistry::new();
        assert_eq!(registry.list_by_category(ToolCategory::Perception).len(), 6);
        assert_eq!(registry.list_by_category(ToolCategory::Navigation).len(), 1);
        assert_eq!(registry.list_by_category(ToolCategory::Cognition).len(), 4);
        assert_eq!(registry.list_by_category(ToolCategory::Memory).len(), 2);
        assert_eq!(registry.list_by_category(ToolCategory::Planning).len(), 1);
        assert_eq!(registry.list_by_category(ToolCategory::Swarm).len(), 1);
    }

    #[test]
    fn test_to_mcp_schema_valid_json() {
        let registry = RoboticsToolRegistry::new();
        let schema = registry.to_mcp_schema();

        let tools = schema.get("tools").unwrap().as_array().unwrap();
        assert_eq!(tools.len(), 15);

        // The schema listing must be ordered by tool name.
        let names: Vec<&str> = tools
            .iter()
            .map(|t| t.get("name").unwrap().as_str().unwrap())
            .collect();
        let mut sorted = names.clone();
        sorted.sort();
        assert_eq!(names, sorted);

        for tool in tools {
            assert!(tool.get("name").unwrap().is_string());
            assert!(tool.get("description").unwrap().is_string());
            let input = tool.get("inputSchema").unwrap();
            assert_eq!(input.get("type").unwrap().as_str().unwrap(), "object");
            assert!(input.get("properties").unwrap().is_object());
            assert!(input.get("required").unwrap().is_array());
        }
    }

    #[test]
    fn test_schema_required_fields() {
        let registry = RoboticsToolRegistry::new();
        let schema = registry.to_mcp_schema();
        let tools = schema["tools"].as_array().unwrap();

        let obs = tools.iter().find(|t| t["name"] == "detect_obstacles").unwrap();
        let required = obs["inputSchema"]["required"].as_array().unwrap();
        let req_names: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(req_names.contains(&"point_cloud_json"));
        assert!(req_names.contains(&"robot_position"));
        assert!(!req_names.contains(&"max_distance"));
    }

    #[test]
    fn test_schema_required_entries_are_unique_and_declared() {
        let schema = RoboticsToolRegistry::new().to_mcp_schema();
        for tool in schema["tools"].as_array().unwrap() {
            let input = &tool["inputSchema"];
            let properties = input["properties"].as_object().unwrap();
            let required: Vec<&str> = input["required"]
                .as_array()
                .unwrap()
                .iter()
                .map(|v| v.as_str().unwrap())
                .collect();

            // JSON Schema requires `required` items to be unique.
            let unique: std::collections::HashSet<&str> = required.iter().copied().collect();
            assert_eq!(
                unique.len(),
                required.len(),
                "duplicate required entry in {}",
                tool["name"]
            );
            for name in &required {
                assert!(
                    properties.contains_key(*name),
                    "{} requires undeclared property {}",
                    tool["name"],
                    name
                );
            }
        }
    }

    #[test]
    fn test_schema_property_types() {
        let registry = RoboticsToolRegistry::new();
        let schema = registry.to_mcp_schema();
        let tools = schema["tools"].as_array().unwrap();

        let traj = tools.iter().find(|t| t["name"] == "predict_trajectory").unwrap();
        let props = &traj["inputSchema"]["properties"];
        assert_eq!(props["position"]["type"], "array");
        assert_eq!(props["steps"]["type"], "integer");
        assert_eq!(props["dt"]["type"], "number");
        assert_eq!(props["dt"]["description"], "Time step in seconds");

        // A tool whose only parameter is optional must emit an empty `required` array.
        let state = tools.iter().find(|t| t["name"] == "get_world_state").unwrap();
        assert!(state["inputSchema"]["required"].as_array().unwrap().is_empty());
    }

    #[test]
    fn test_tool_request_serialization() {
        let mut args = HashMap::new();
        args.insert("k".to_string(), serde_json::json!(5));
        args.insert("query".to_string(), serde_json::json!([1.0, 2.0, 3.0]));

        let req = ToolRequest { tool_name: "spatial_search".to_string(), arguments: args };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: ToolRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_name, "spatial_search");
        assert_eq!(deserialized.arguments["k"], serde_json::json!(5));
        assert_eq!(deserialized.arguments["query"], serde_json::json!([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_tool_response_ok() {
        let resp = ToolResponse::ok(serde_json::json!({"obstacles": 3}), 420);
        assert!(resp.success);
        assert!(resp.error.is_none());
        assert_eq!(resp.latency_us, 420);
        assert_eq!(resp.result["obstacles"], 3);

        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: ToolResponse = serde_json::from_str(&json).unwrap();
        assert!(deserialized.success);
        assert!(deserialized.error.is_none());
        assert_eq!(deserialized.latency_us, 420);
        assert_eq!(deserialized.result, serde_json::json!({"obstacles": 3}));
    }

    #[test]
    fn test_tool_response_err() {
        let resp = ToolResponse::err("something went wrong", 100);
        assert!(!resp.success);
        assert_eq!(resp.error.as_deref(), Some("something went wrong"));
        assert!(resp.result.is_null());
    }

    #[test]
    fn test_tool_response_err_roundtrip() {
        let resp = ToolResponse::err(String::from("boom"), 7);
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: ToolResponse = serde_json::from_str(&json).unwrap();
        assert!(!deserialized.success);
        assert!(deserialized.result.is_null());
        assert_eq!(deserialized.error.as_deref(), Some("boom"));
        assert_eq!(deserialized.latency_us, 7);
    }

    #[test]
    fn test_register_custom_tool() {
        let mut registry = RoboticsToolRegistry::new();
        assert_eq!(registry.list_tools().len(), 15);

        let custom = ToolDefinition::new(
            "my_custom_tool",
            "A custom tool for testing",
            vec![ToolParameter::new("input", "The input data", ParamType::String, true)],
            ToolCategory::Cognition,
        );
        registry.register_tool(custom);
        assert_eq!(registry.list_tools().len(), 16);

        let tool = registry.get_tool("my_custom_tool").unwrap();
        assert_eq!(tool.description, "A custom tool for testing");
        assert_eq!(tool.parameters.len(), 1);
    }

    #[test]
    fn test_register_overwrites_existing() {
        let mut registry = RoboticsToolRegistry::new();
        let replacement = ToolDefinition::new(
            "detect_obstacles",
            "Replaced description",
            vec![],
            ToolCategory::Perception,
        );
        registry.register_tool(replacement);
        assert_eq!(registry.list_tools().len(), 15);
        let tool = registry.get_tool("detect_obstacles").unwrap();
        assert_eq!(tool.description, "Replaced description");
        assert!(tool.parameters.is_empty());
    }

    #[test]
    fn test_empty_registry() {
        let registry = RoboticsToolRegistry::empty();
        assert_eq!(registry.list_tools().len(), 0);
        assert!(registry.get_tool("detect_obstacles").is_none());
    }

    #[test]
    fn test_param_type_serde_roundtrip() {
        for &pt in ALL_PARAM_TYPES.iter() {
            let json = serde_json::to_string(&pt).unwrap();
            let deserialized: ParamType = serde_json::from_str(&json).unwrap();
            assert_eq!(pt, deserialized);
        }
    }

    #[test]
    fn test_param_type_serde_matches_schema_str() {
        // The schema is built from `as_schema_str()`, while serde uses
        // `rename_all = "lowercase"`; the two spellings must never drift apart.
        for &pt in ALL_PARAM_TYPES.iter() {
            let wire = serde_json::to_value(pt).unwrap();
            assert_eq!(wire, serde_json::Value::String(pt.as_schema_str().to_string()));
        }
    }

    #[test]
    fn test_tool_category_serde_roundtrip() {
        let categories = vec![
            ToolCategory::Perception,
            ToolCategory::Navigation,
            ToolCategory::Cognition,
            ToolCategory::Swarm,
            ToolCategory::Memory,
            ToolCategory::Planning,
        ];
        for cat in categories {
            let json = serde_json::to_string(&cat).unwrap();
            let deserialized: ToolCategory = serde_json::from_str(&json).unwrap();
            assert_eq!(cat, deserialized);
        }
    }

    #[test]
    fn test_tool_definition_serde_roundtrip() {
        let tool = ToolDefinition::new(
            "test_tool",
            "A tool for testing",
            vec![
                ToolParameter::new("a", "param a", ParamType::String, true),
                ToolParameter::new("b", "param b", ParamType::Number, false),
            ],
            ToolCategory::Navigation,
        );
        let json = serde_json::to_string(&tool).unwrap();
        let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(tool, deserialized);
    }

    #[test]
    fn test_default_trait() {
        let registry = RoboticsToolRegistry::default();
        assert_eq!(registry.list_tools().len(), 15);
    }
}