1use std::time::Instant;
8
9use crate::bridge::{Point3D, PointCloud, SceneObject, SpatialIndex};
10use crate::mcp::{ToolRequest, ToolResponse};
11use crate::perception::PerceptionPipeline;
12
/// Dispatches MCP tool requests to a perception pipeline and a spatial
/// index that this executor owns.
pub struct ToolExecutor {
    /// Perception algorithms backing the `detect_obstacles`,
    /// `build_scene_graph`, `predict_trajectory`, `focus_attention` and
    /// `detect_anomalies` tools.
    pipeline: PerceptionPipeline,
    /// Point index mutated by `insert_points` and queried by `spatial_search`.
    index: SpatialIndex,
}
19
20impl ToolExecutor {
21 pub fn new() -> Self {
23 Self {
24 pipeline: PerceptionPipeline::with_thresholds(0.5, 2.0),
25 index: SpatialIndex::new(3),
26 }
27 }
28
29 pub fn execute(&mut self, request: &ToolRequest) -> ToolResponse {
31 let start = Instant::now();
32 let result = match request.tool_name.as_str() {
33 "detect_obstacles" => self.handle_detect_obstacles(request),
34 "build_scene_graph" => self.handle_build_scene_graph(request),
35 "predict_trajectory" => self.handle_predict_trajectory(request),
36 "focus_attention" => self.handle_focus_attention(request),
37 "detect_anomalies" => self.handle_detect_anomalies(request),
38 "spatial_search" => self.handle_spatial_search(request),
39 "insert_points" => self.handle_insert_points(request),
40 other => Err(format!("unknown tool: {other}")),
41 };
42 let latency_us = start.elapsed().as_micros() as u64;
43 match result {
44 Ok(value) => ToolResponse::ok(value, latency_us),
45 Err(msg) => ToolResponse::err(msg, latency_us),
46 }
47 }
48
49 pub fn index(&self) -> &SpatialIndex {
51 &self.index
52 }
53
54 fn handle_detect_obstacles(
57 &self,
58 req: &ToolRequest,
59 ) -> std::result::Result<serde_json::Value, String> {
60 let cloud = parse_point_cloud(req, "point_cloud_json")?;
61 let pos = parse_position(req, "robot_position")?;
62 let max_dist = req
63 .arguments
64 .get("max_distance")
65 .and_then(|v| v.as_f64())
66 .unwrap_or(20.0);
67
68 let obstacles = self
69 .pipeline
70 .detect_obstacles(&cloud, pos, max_dist)
71 .map_err(|e| e.to_string())?;
72
73 serde_json::to_value(&obstacles).map_err(|e| e.to_string())
74 }
75
76 fn handle_build_scene_graph(
77 &self,
78 req: &ToolRequest,
79 ) -> std::result::Result<serde_json::Value, String> {
80 let objects: Vec<SceneObject> = parse_json_arg(req, "objects_json")?;
81 let max_edge = req
82 .arguments
83 .get("max_edge_distance")
84 .and_then(|v| v.as_f64())
85 .unwrap_or(5.0);
86
87 let graph = self
88 .pipeline
89 .build_scene_graph(&objects, max_edge)
90 .map_err(|e| e.to_string())?;
91
92 serde_json::to_value(&graph).map_err(|e| e.to_string())
93 }
94
95 fn handle_predict_trajectory(
96 &self,
97 req: &ToolRequest,
98 ) -> std::result::Result<serde_json::Value, String> {
99 let pos = parse_position(req, "position")?;
100 let vel = parse_position(req, "velocity")?;
101 let steps = req
102 .arguments
103 .get("steps")
104 .and_then(|v| v.as_u64())
105 .unwrap_or(10) as usize;
106 let dt = req
107 .arguments
108 .get("dt")
109 .and_then(|v| v.as_f64())
110 .unwrap_or(0.1);
111
112 let traj = self
113 .pipeline
114 .predict_trajectory(pos, vel, steps, dt)
115 .map_err(|e| e.to_string())?;
116
117 serde_json::to_value(&traj).map_err(|e| e.to_string())
118 }
119
120 fn handle_focus_attention(
121 &self,
122 req: &ToolRequest,
123 ) -> std::result::Result<serde_json::Value, String> {
124 let cloud = parse_point_cloud(req, "point_cloud_json")?;
125 let center = parse_position(req, "center")?;
126 let radius = req
127 .arguments
128 .get("radius")
129 .and_then(|v| v.as_f64())
130 .ok_or("missing 'radius'")?;
131
132 let focused = self
133 .pipeline
134 .focus_attention(&cloud, center, radius)
135 .map_err(|e| e.to_string())?;
136
137 serde_json::to_value(&focused).map_err(|e| e.to_string())
138 }
139
140 fn handle_detect_anomalies(
141 &self,
142 req: &ToolRequest,
143 ) -> std::result::Result<serde_json::Value, String> {
144 let cloud = parse_point_cloud(req, "point_cloud_json")?;
145 let anomalies = self
146 .pipeline
147 .detect_anomalies(&cloud)
148 .map_err(|e| e.to_string())?;
149 serde_json::to_value(&anomalies).map_err(|e| e.to_string())
150 }
151
152 fn handle_spatial_search(
153 &self,
154 req: &ToolRequest,
155 ) -> std::result::Result<serde_json::Value, String> {
156 let query: Vec<f32> = req
157 .arguments
158 .get("query")
159 .and_then(|v| v.as_array())
160 .map(|a| a.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect())
161 .ok_or("missing 'query'")?;
162 let k = req
163 .arguments
164 .get("k")
165 .and_then(|v| v.as_u64())
166 .unwrap_or(5) as usize;
167
168 let results = self
169 .index
170 .search_nearest(&query, k)
171 .map_err(|e| e.to_string())?;
172
173 let pairs: Vec<serde_json::Value> = results
174 .iter()
175 .map(|(idx, dist)| serde_json::json!({"index": idx, "distance": dist}))
176 .collect();
177 Ok(serde_json::json!(pairs))
178 }
179
180 fn handle_insert_points(
181 &mut self,
182 req: &ToolRequest,
183 ) -> std::result::Result<serde_json::Value, String> {
184 let points: Vec<Point3D> = parse_json_arg(req, "points_json")?;
185 let cloud = PointCloud::new(points, 0);
186 self.index.insert_point_cloud(&cloud);
187 Ok(serde_json::json!({"inserted": cloud.len(), "total": self.index.len()}))
188 }
189}
190
191impl Default for ToolExecutor {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197fn parse_point_cloud(
200 req: &ToolRequest,
201 key: &str,
202) -> std::result::Result<PointCloud, String> {
203 let raw = req
204 .arguments
205 .get(key)
206 .ok_or_else(|| format!("missing '{key}'"))?;
207
208 if let Some(s) = raw.as_str() {
209 serde_json::from_str(s).map_err(|e| format!("invalid point cloud JSON: {e}"))
210 } else {
211 serde_json::from_value(raw.clone()).map_err(|e| format!("invalid point cloud: {e}"))
212 }
213}
214
215fn parse_position(
216 req: &ToolRequest,
217 key: &str,
218) -> std::result::Result<[f64; 3], String> {
219 let arr = req
220 .arguments
221 .get(key)
222 .and_then(|v| v.as_array())
223 .ok_or_else(|| format!("missing '{key}'"))?;
224
225 if arr.len() < 3 {
226 return Err(format!("'{key}' must have at least 3 elements"));
227 }
228 let x = arr[0].as_f64().ok_or("non-numeric")?;
229 let y = arr[1].as_f64().ok_or("non-numeric")?;
230 let z = arr[2].as_f64().ok_or("non-numeric")?;
231 Ok([x, y, z])
232}
233
234fn parse_json_arg<T: serde::de::DeserializeOwned>(
235 req: &ToolRequest,
236 key: &str,
237) -> std::result::Result<T, String> {
238 let raw = req
239 .arguments
240 .get(key)
241 .ok_or_else(|| format!("missing '{key}'"))?;
242
243 if let Some(s) = raw.as_str() {
244 serde_json::from_str(s).map_err(|e| format!("invalid JSON for '{key}': {e}"))
245 } else {
246 serde_json::from_value(raw.clone()).map_err(|e| format!("invalid '{key}': {e}"))
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `ToolRequest` for `tool` whose arguments come from the JSON
    /// object `args`.
    fn make_request(tool: &str, args: serde_json::Value) -> ToolRequest {
        let arguments = serde_json::from_value::<HashMap<String, serde_json::Value>>(args).unwrap();
        ToolRequest {
            tool_name: tool.to_string(),
            arguments,
        }
    }

    #[test]
    fn test_detect_obstacles() {
        let mut executor = ToolExecutor::new();
        let points = vec![
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(2.1, 0.0, 0.0),
            Point3D::new(2.0, 0.1, 0.0),
        ];
        let encoded_cloud = serde_json::to_string(&PointCloud::new(points, 1000)).unwrap();
        let request = make_request(
            "detect_obstacles",
            serde_json::json!({
                "point_cloud_json": encoded_cloud,
                "robot_position": [0.0, 0.0, 0.0],
            }),
        );
        let response = executor.execute(&request);
        assert!(response.success);
    }

    #[test]
    fn test_predict_trajectory() {
        let mut executor = ToolExecutor::new();
        let request = make_request(
            "predict_trajectory",
            serde_json::json!({
                "position": [0.0, 0.0, 0.0],
                "velocity": [1.0, 0.0, 0.0],
                "steps": 5,
                "dt": 0.5,
            }),
        );
        let response = executor.execute(&request);
        assert!(response.success);
        let waypoints = response.result["waypoints"].as_array().unwrap();
        assert_eq!(waypoints.len(), 5);
    }

    #[test]
    fn test_insert_and_search() {
        let mut executor = ToolExecutor::new();

        // Populate the index through the tool interface first.
        let encoded_points = serde_json::to_string(&vec![
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(10.0, 0.0, 0.0),
        ])
        .unwrap();
        let insert = make_request(
            "insert_points",
            serde_json::json!({
                "points_json": encoded_points,
            }),
        );
        let inserted = executor.execute(&insert);
        assert!(inserted.success);
        assert_eq!(inserted.result["total"], 3);

        // Then query the same executor.
        let search = make_request(
            "spatial_search",
            serde_json::json!({
                "query": [1.0, 0.0, 0.0],
                "k": 2,
            }),
        );
        let found = executor.execute(&search);
        assert!(found.success);
        let hits = found.result.as_array().unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_unknown_tool() {
        let mut executor = ToolExecutor::new();
        let response = executor.execute(&make_request("nonexistent", serde_json::json!({})));
        assert!(!response.success);
        let message = response.error.unwrap();
        assert!(message.contains("unknown tool"));
    }

    #[test]
    fn test_build_scene_graph() {
        let mut executor = ToolExecutor::new();
        let encoded_objects = serde_json::to_string(&vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ])
        .unwrap();
        let request = make_request(
            "build_scene_graph",
            serde_json::json!({
                "objects_json": encoded_objects,
                "max_edge_distance": 5.0,
            }),
        );
        let response = executor.execute(&request);
        assert!(response.success);
        let edges = response.result["edges"].as_array().unwrap();
        assert_eq!(edges.len(), 1);
    }
}