// ruvector_robotics/mcp/executor.rs

//! MCP tool execution engine.
//!
//! [`ToolExecutor`] wires up the perception pipeline, spatial index, and
//! memory system to actually *execute* tool requests, turning the schema-only
//! registry into a working tool backend.

7use std::time::Instant;
8
9use crate::bridge::{Point3D, PointCloud, SceneObject, SpatialIndex};
10use crate::mcp::{ToolRequest, ToolResponse};
11use crate::perception::PerceptionPipeline;
12
/// Stateful executor that handles incoming [`ToolRequest`]s by dispatching to
/// the appropriate subsystem.
///
/// The perception tools (`detect_obstacles`, `build_scene_graph`,
/// `predict_trajectory`, `focus_attention`, `detect_anomalies`) are answered by
/// `pipeline`. The spatial index is state kept across calls: `insert_points`
/// writes into it and `spatial_search` queries it.
pub struct ToolExecutor {
    /// Perception backend used by the obstacle, scene-graph, trajectory,
    /// attention, and anomaly tools.
    pipeline: PerceptionPipeline,
    /// Spatial index populated by `insert_points` and queried by `spatial_search`.
    index: SpatialIndex,
}

20impl ToolExecutor {
21    /// Create a new executor with default subsystem configurations.
22    pub fn new() -> Self {
23        Self {
24            pipeline: PerceptionPipeline::with_thresholds(0.5, 2.0),
25            index: SpatialIndex::new(3),
26        }
27    }
28
29    /// Execute a tool request and return a response with timing.
30    pub fn execute(&mut self, request: &ToolRequest) -> ToolResponse {
31        let start = Instant::now();
32        let result = match request.tool_name.as_str() {
33            "detect_obstacles" => self.handle_detect_obstacles(request),
34            "build_scene_graph" => self.handle_build_scene_graph(request),
35            "predict_trajectory" => self.handle_predict_trajectory(request),
36            "focus_attention" => self.handle_focus_attention(request),
37            "detect_anomalies" => self.handle_detect_anomalies(request),
38            "spatial_search" => self.handle_spatial_search(request),
39            "insert_points" => self.handle_insert_points(request),
40            other => Err(format!("unknown tool: {other}")),
41        };
42        let latency_us = start.elapsed().as_micros() as u64;
43        match result {
44            Ok(value) => ToolResponse::ok(value, latency_us),
45            Err(msg) => ToolResponse::err(msg, latency_us),
46        }
47    }
48
49    /// Access the internal spatial index (e.g. for testing).
50    pub fn index(&self) -> &SpatialIndex {
51        &self.index
52    }
53
54    // -- handlers -----------------------------------------------------------
55
56    fn handle_detect_obstacles(
57        &self,
58        req: &ToolRequest,
59    ) -> std::result::Result<serde_json::Value, String> {
60        let cloud = parse_point_cloud(req, "point_cloud_json")?;
61        let pos = parse_position(req, "robot_position")?;
62        let max_dist = req
63            .arguments
64            .get("max_distance")
65            .and_then(|v| v.as_f64())
66            .unwrap_or(20.0);
67
68        let obstacles = self
69            .pipeline
70            .detect_obstacles(&cloud, pos, max_dist)
71            .map_err(|e| e.to_string())?;
72
73        serde_json::to_value(&obstacles).map_err(|e| e.to_string())
74    }
75
76    fn handle_build_scene_graph(
77        &self,
78        req: &ToolRequest,
79    ) -> std::result::Result<serde_json::Value, String> {
80        let objects: Vec<SceneObject> = parse_json_arg(req, "objects_json")?;
81        let max_edge = req
82            .arguments
83            .get("max_edge_distance")
84            .and_then(|v| v.as_f64())
85            .unwrap_or(5.0);
86
87        let graph = self
88            .pipeline
89            .build_scene_graph(&objects, max_edge)
90            .map_err(|e| e.to_string())?;
91
92        serde_json::to_value(&graph).map_err(|e| e.to_string())
93    }
94
95    fn handle_predict_trajectory(
96        &self,
97        req: &ToolRequest,
98    ) -> std::result::Result<serde_json::Value, String> {
99        let pos = parse_position(req, "position")?;
100        let vel = parse_position(req, "velocity")?;
101        let steps = req
102            .arguments
103            .get("steps")
104            .and_then(|v| v.as_u64())
105            .unwrap_or(10) as usize;
106        let dt = req
107            .arguments
108            .get("dt")
109            .and_then(|v| v.as_f64())
110            .unwrap_or(0.1);
111
112        let traj = self
113            .pipeline
114            .predict_trajectory(pos, vel, steps, dt)
115            .map_err(|e| e.to_string())?;
116
117        serde_json::to_value(&traj).map_err(|e| e.to_string())
118    }
119
120    fn handle_focus_attention(
121        &self,
122        req: &ToolRequest,
123    ) -> std::result::Result<serde_json::Value, String> {
124        let cloud = parse_point_cloud(req, "point_cloud_json")?;
125        let center = parse_position(req, "center")?;
126        let radius = req
127            .arguments
128            .get("radius")
129            .and_then(|v| v.as_f64())
130            .ok_or("missing 'radius'")?;
131
132        let focused = self
133            .pipeline
134            .focus_attention(&cloud, center, radius)
135            .map_err(|e| e.to_string())?;
136
137        serde_json::to_value(&focused).map_err(|e| e.to_string())
138    }
139
140    fn handle_detect_anomalies(
141        &self,
142        req: &ToolRequest,
143    ) -> std::result::Result<serde_json::Value, String> {
144        let cloud = parse_point_cloud(req, "point_cloud_json")?;
145        let anomalies = self
146            .pipeline
147            .detect_anomalies(&cloud)
148            .map_err(|e| e.to_string())?;
149        serde_json::to_value(&anomalies).map_err(|e| e.to_string())
150    }
151
152    fn handle_spatial_search(
153        &self,
154        req: &ToolRequest,
155    ) -> std::result::Result<serde_json::Value, String> {
156        let query: Vec<f32> = req
157            .arguments
158            .get("query")
159            .and_then(|v| v.as_array())
160            .map(|a| a.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect())
161            .ok_or("missing 'query'")?;
162        let k = req
163            .arguments
164            .get("k")
165            .and_then(|v| v.as_u64())
166            .unwrap_or(5) as usize;
167
168        let results = self
169            .index
170            .search_nearest(&query, k)
171            .map_err(|e| e.to_string())?;
172
173        let pairs: Vec<serde_json::Value> = results
174            .iter()
175            .map(|(idx, dist)| serde_json::json!({"index": idx, "distance": dist}))
176            .collect();
177        Ok(serde_json::json!(pairs))
178    }
179
180    fn handle_insert_points(
181        &mut self,
182        req: &ToolRequest,
183    ) -> std::result::Result<serde_json::Value, String> {
184        let points: Vec<Point3D> = parse_json_arg(req, "points_json")?;
185        let cloud = PointCloud::new(points, 0);
186        self.index.insert_point_cloud(&cloud);
187        Ok(serde_json::json!({"inserted": cloud.len(), "total": self.index.len()}))
188    }
189}
190
191impl Default for ToolExecutor {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
// -- argument parsers -------------------------------------------------------

199fn parse_point_cloud(
200    req: &ToolRequest,
201    key: &str,
202) -> std::result::Result<PointCloud, String> {
203    let raw = req
204        .arguments
205        .get(key)
206        .ok_or_else(|| format!("missing '{key}'"))?;
207
208    if let Some(s) = raw.as_str() {
209        serde_json::from_str(s).map_err(|e| format!("invalid point cloud JSON: {e}"))
210    } else {
211        serde_json::from_value(raw.clone()).map_err(|e| format!("invalid point cloud: {e}"))
212    }
213}
214
215fn parse_position(
216    req: &ToolRequest,
217    key: &str,
218) -> std::result::Result<[f64; 3], String> {
219    let arr = req
220        .arguments
221        .get(key)
222        .and_then(|v| v.as_array())
223        .ok_or_else(|| format!("missing '{key}'"))?;
224
225    if arr.len() < 3 {
226        return Err(format!("'{key}' must have at least 3 elements"));
227    }
228    let x = arr[0].as_f64().ok_or("non-numeric")?;
229    let y = arr[1].as_f64().ok_or("non-numeric")?;
230    let z = arr[2].as_f64().ok_or("non-numeric")?;
231    Ok([x, y, z])
232}
233
234fn parse_json_arg<T: serde::de::DeserializeOwned>(
235    req: &ToolRequest,
236    key: &str,
237) -> std::result::Result<T, String> {
238    let raw = req
239        .arguments
240        .get(key)
241        .ok_or_else(|| format!("missing '{key}'"))?;
242
243    if let Some(s) = raw.as_str() {
244        serde_json::from_str(s).map_err(|e| format!("invalid JSON for '{key}': {e}"))
245    } else {
246        serde_json::from_value(raw.clone()).map_err(|e| format!("invalid '{key}': {e}"))
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a [`ToolRequest`] from a tool name and a JSON object of arguments.
    fn make_request(tool: &str, args: serde_json::Value) -> ToolRequest {
        let arguments = serde_json::from_value::<HashMap<String, serde_json::Value>>(args)
            .expect("test arguments must be a JSON object");
        ToolRequest {
            tool_name: String::from(tool),
            arguments,
        }
    }

    #[test]
    fn test_detect_obstacles() {
        let mut executor = ToolExecutor::new();
        let points = vec![
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(2.1, 0.0, 0.0),
            Point3D::new(2.0, 0.1, 0.0),
        ];
        let encoded = serde_json::to_string(&PointCloud::new(points, 1000)).unwrap();
        let response = executor.execute(&make_request(
            "detect_obstacles",
            serde_json::json!({
                "point_cloud_json": encoded,
                "robot_position": [0.0, 0.0, 0.0],
            }),
        ));
        assert!(response.success);
    }

    #[test]
    fn test_predict_trajectory() {
        let mut executor = ToolExecutor::new();
        let request = make_request(
            "predict_trajectory",
            serde_json::json!({
                "position": [0.0, 0.0, 0.0],
                "velocity": [1.0, 0.0, 0.0],
                "steps": 5,
                "dt": 0.5,
            }),
        );
        let response = executor.execute(&request);
        assert!(response.success);
        let waypoints = response.result["waypoints"].as_array().unwrap();
        assert_eq!(waypoints.len(), 5);
    }

    #[test]
    fn test_insert_and_search() {
        let mut executor = ToolExecutor::new();

        // Populate the index with three points on the x axis.
        let encoded = serde_json::to_string(&vec![
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(10.0, 0.0, 0.0),
        ])
        .unwrap();
        let inserted = executor.execute(&make_request(
            "insert_points",
            serde_json::json!({ "points_json": encoded }),
        ));
        assert!(inserted.success);
        assert_eq!(inserted.result["total"], 3);

        // Query the two nearest neighbours of (1, 0, 0).
        let found = executor.execute(&make_request(
            "spatial_search",
            serde_json::json!({ "query": [1.0, 0.0, 0.0], "k": 2 }),
        ));
        assert!(found.success);
        assert_eq!(found.result.as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_unknown_tool() {
        let mut executor = ToolExecutor::new();
        let response = executor.execute(&make_request("nonexistent", serde_json::json!({})));
        assert!(!response.success);
        let message = response.error.unwrap();
        assert!(message.contains("unknown tool"));
    }

    #[test]
    fn test_build_scene_graph() {
        let mut executor = ToolExecutor::new();
        let encoded = serde_json::to_string(&vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ])
        .unwrap();
        let response = executor.execute(&make_request(
            "build_scene_graph",
            serde_json::json!({
                "objects_json": encoded,
                "max_edge_distance": 5.0,
            }),
        ));
        assert!(response.success);
        let edges = response.result["edges"].as_array().unwrap();
        assert_eq!(edges.len(), 1);
    }
}