Skip to main content

task_graph_mcp/tools/
mod.rs

1//! MCP tool implementations.
2
3pub mod agents;
4pub mod attachments;
5pub mod claiming;
6pub mod context;
7pub mod deps;
8pub mod feedback;
9pub mod files;
10pub mod gates;
11pub mod query;
12pub mod schema;
13pub mod search;
14pub mod skills;
15pub mod tasks;
16pub mod tracking;
17pub mod workflows;
18
19pub use context::ToolContext;
20
21use crate::config::{AppConfig, Prompts, ServerPaths, workflows::WorkflowsConfig};
22use crate::db::Database;
23use crate::error::ToolError;
24use crate::format::{OutputFormat, ToolResult};
25use anyhow::Result;
26use rmcp::model::Tool;
27use serde_json::Value;
28use std::path::PathBuf;
29use std::sync::Arc;
30
31/// Tool handler that processes MCP tool calls.
32pub struct ToolHandler {
33    pub db: Arc<Database>,
34    pub media_dir: PathBuf,
35    pub skills_dir: PathBuf,
36    pub server_paths: Arc<ServerPaths>,
37    pub prompts: Arc<Prompts>,
38    /// Consolidated application configuration.
39    pub config: AppConfig,
40    pub default_format: OutputFormat,
41    pub default_page_size: i32,
42    pub path_mapper: Arc<crate::paths::PathMapper>,
43}
44
45impl ToolHandler {
46    #[allow(clippy::too_many_arguments)]
47    pub fn new(
48        db: Arc<Database>,
49        media_dir: PathBuf,
50        skills_dir: PathBuf,
51        server_paths: Arc<ServerPaths>,
52        prompts: Arc<Prompts>,
53        config: AppConfig,
54        default_format: OutputFormat,
55        default_page_size: i32,
56        path_mapper: Arc<crate::paths::PathMapper>,
57    ) -> Self {
58        Self {
59            db,
60            media_dir,
61            skills_dir,
62            server_paths,
63            prompts,
64            config,
65            default_format,
66            default_page_size,
67            path_mapper,
68        }
69    }
70
71    /// Get the workflow config for a worker.
72    /// Looks up the worker's workflow name and returns the corresponding config,
73    /// or falls back to the configured default workflow, or the base config.
74    /// If the worker has overlays, applies them on top of the base workflow
75    /// and caches the merged result for reuse.
76    pub fn get_workflow_for_worker(&self, worker_id: &str) -> Arc<WorkflowsConfig> {
77        if let Ok(Some(worker)) = self.db.get_worker(worker_id) {
78            // Resolve the base workflow
79            let base = if let Some(ref workflow_name) = worker.workflow {
80                self.config
81                    .workflows
82                    .get_named_workflow(workflow_name)
83                    .map(Arc::clone)
84            } else {
85                None
86            }
87            .or_else(|| self.config.workflows.get_default_workflow().map(Arc::clone))
88            .unwrap_or_else(|| Arc::clone(&self.config.workflows));
89
90            // If the worker has overlays, build a merged config
91            if !worker.overlays.is_empty() {
92                // Check cache first (composite key: "workflow+overlay1+overlay2")
93                let cache_key = format!(
94                    "{}+{}",
95                    worker.workflow.as_deref().unwrap_or("default"),
96                    worker.overlays.join("+")
97                );
98                if let Some(cached) = self.config.workflows.get_named_workflow(&cache_key) {
99                    return Arc::clone(cached);
100                }
101
102                // Build merged workflow by applying overlays in order
103                let mut merged = (*base).clone();
104                for name in &worker.overlays {
105                    if let Some(overlay) = self.config.workflows.named_overlays.get(name) {
106                        merged.apply_overlay(overlay);
107                    }
108                }
109                merged.active_overlays = worker.overlays.clone();
110                return Arc::new(merged);
111            }
112
113            return base;
114        }
115        // Fall back to configured default workflow, or base config
116        if let Some(default_workflow) = self.config.workflows.get_default_workflow() {
117            Arc::clone(default_workflow)
118        } else {
119            Arc::clone(&self.config.workflows)
120        }
121    }
122
123    /// Get all available tools.
124    pub fn get_tools(&self) -> Vec<Tool> {
125        let mut tools = Vec::new();
126
127        // Worker tools
128        tools.extend(agents::get_tools(&self.prompts));
129
130        // Task tools (with dynamic state schema)
131        tools.extend(tasks::get_tools(&self.prompts, &self.config.states));
132
133        // Tracking tools
134        tools.extend(tracking::get_tools(&self.prompts, &self.config.states));
135
136        // Dependency tools
137        tools.extend(deps::get_tools(&self.prompts, &self.config.deps));
138
139        // Claiming tools (with dynamic state schema)
140        tools.extend(claiming::get_tools(&self.prompts, &self.config.states));
141
142        // File coordination tools
143        tools.extend(files::get_tools(&self.prompts));
144
145        // Attachment tools
146        tools.extend(attachments::get_tools(&self.prompts));
147
148        // Skill tools (no prompts needed, always available)
149        tools.extend(skills::get_tools());
150
151        // Schema introspection tools
152        tools.extend(schema::get_tools());
153
154        // Search tools
155        tools.extend(search::get_tools(&self.prompts));
156
157        // Query tools (read-only SQL)
158        tools.extend(query::get_tools());
159
160        // Gate checking tools
161        tools.extend(gates::get_tools(&self.prompts));
162
163        // Workflow discovery tools (no auth needed, callable before connect)
164        tools.extend(workflows::get_tools());
165
166        // Feedback tools (conditionally enabled)
167        if self.config.feedback.enabled {
168            tools.extend(feedback::get_tools());
169        }
170
171        tools
172    }
173
174    /// Call a tool by name.
175    #[allow(unused_variables)]
176    pub async fn call_tool(
177        &self,
178        name: &str,
179        arguments: Value,
180        ctx: &ToolContext,
181    ) -> Result<ToolResult> {
182        // Helper to wrap JSON results
183        let json = |r: Result<Value>| r.map(ToolResult::Json);
184
185        match name {
186            // Worker tools
187            "connect" => {
188                // Resolve base workflow from args (worker isn't registered yet during connect)
189                let base_workflow = arguments
190                    .get("workflow")
191                    .and_then(|v| v.as_str())
192                    .and_then(|name| self.config.workflows.get_named_workflow(name))
193                    .map(Arc::clone)
194                    .or_else(|| self.config.workflows.get_default_workflow().map(Arc::clone))
195                    .unwrap_or_else(|| Arc::clone(&self.config.workflows));
196
197                // Apply overlays if specified
198                let overlay_names: Vec<String> = arguments
199                    .get("overlays")
200                    .and_then(|v| v.as_array())
201                    .map(|arr| {
202                        arr.iter()
203                            .filter_map(|v| v.as_str().map(String::from))
204                            .collect()
205                    })
206                    .unwrap_or_default();
207
208                let workflow = if overlay_names.is_empty() {
209                    base_workflow
210                } else {
211                    let mut merged = (*base_workflow).clone();
212                    for name in &overlay_names {
213                        if let Some(overlay) = self.config.workflows.named_overlays.get(name) {
214                            merged.apply_overlay(overlay);
215                        }
216                    }
217                    merged.active_overlays = overlay_names;
218                    Arc::new(merged)
219                };
220
221                json(agents::connect(
222                    agents::ConnectOptions {
223                        db: &self.db,
224                        server_paths: &self.server_paths,
225                        config: &self.config,
226                        workflows: &workflow,
227                    },
228                    arguments,
229                ))
230            }
231            "disconnect" => json(agents::disconnect(&self.db, &self.config.states, arguments)),
232            "list_agents" => agents::list_agents(
233                &self.db,
234                &self.config.states,
235                self.default_format,
236                arguments,
237            ),
238            "cleanup_stale" => json(agents::cleanup_stale(
239                &self.db,
240                &self.config.states,
241                arguments,
242            )),
243            "add_overlay" => json(agents::add_overlay(&self.db, &self.config, arguments)),
244            "remove_overlay" => json(agents::remove_overlay(&self.db, &self.config, arguments)),
245
246            // Task tools
247            "create" => json(tasks::create(&self.db, &self.config, arguments)),
248            "create_tree" => json(tasks::create_tree(&self.db, &self.config, arguments)),
249            "get" => json(tasks::get(&self.db, self.default_format, arguments)),
250            "list_tasks" => json(tasks::list_tasks(
251                &self.db,
252                &self.config.states,
253                &self.config.deps,
254                self.default_format,
255                arguments,
256            )),
257            "update" => {
258                // Look up worker's workflow for prompts
259                let worker_id = arguments
260                    .get("worker_id")
261                    .and_then(|v| v.as_str())
262                    .unwrap_or("");
263                let workflow = self.get_workflow_for_worker(worker_id);
264                json(tasks::update(
265                    tasks::UpdateOptions {
266                        db: &self.db,
267                        config: &self.config,
268                        workflows: &workflow,
269                    },
270                    arguments,
271                ))
272            }
273            "delete" => json(tasks::delete(&self.db, arguments)),
274            "rename" => json(tasks::rename(&self.db, arguments)),
275            "scan" => json(tasks::scan(&self.db, self.default_format, arguments)),
276
277            // Tracking tools
278            "thinking" => json(tracking::thinking(&self.db, arguments)),
279            "task_history" => json(tracking::task_history(
280                &self.db,
281                &self.config.states,
282                self.default_format,
283                arguments,
284            )),
285            "log_metrics" => json(tracking::log_metrics(&self.db, arguments)),
286            "get_metrics" => json(tracking::get_metrics(&self.db, arguments)),
287            "project_history" => json(tracking::project_history(
288                &self.db,
289                self.default_format,
290                arguments,
291            )),
292
293            // Dependency tools
294            "link" => json(deps::link(&self.db, &self.config.deps, arguments)),
295            "unlink" => json(deps::unlink(&self.db, arguments)),
296            "relink" => json(deps::relink(&self.db, &self.config.deps, arguments)),
297
298            // Claiming tools
299            "claim" => {
300                // Look up worker's workflow for prompts
301                let worker_id = arguments
302                    .get("worker_id")
303                    .and_then(|v| v.as_str())
304                    .unwrap_or("");
305                let workflow = self.get_workflow_for_worker(worker_id);
306                json(claiming::claim(
307                    &self.db,
308                    &self.config,
309                    &workflow,
310                    arguments,
311                ))
312            }
313
314            // File coordination tools
315            "mark_file" => json(files::mark_file(&self.db, arguments)),
316            "unmark_file" => json(files::unmark_file(&self.db, arguments)),
317            "list_marks" => json(files::list_marks(&self.db, self.default_format, arguments)),
318            "mark_updates" => {
319                json(files::mark_updates_async(std::sync::Arc::clone(&self.db), arguments).await)
320            }
321
322            // Attachment tools
323            "attach" => json(attachments::attach(
324                &self.db,
325                &self.media_dir,
326                &self.config.attachments,
327                arguments,
328            )),
329            "attachments" => json(attachments::attachments(
330                &self.db,
331                &self.media_dir,
332                self.default_format,
333                arguments,
334            )),
335            "detach" => json(attachments::detach(&self.db, &self.media_dir, arguments)),
336
337            // Skill tools
338            name if skills::is_skill_tool(name) => {
339                json(skills::call_tool(&self.skills_dir, name, &arguments))
340            }
341
342            // Schema introspection tools
343            "get_schema" => json(schema::get_schema(&self.db, arguments)),
344
345            // Search tools
346            "search" => json(search::search(&self.db, self.default_page_size, arguments)),
347
348            // Query tools (read-only SQL)
349            "query" => query::query(&self.db, self.default_format, arguments),
350
351            // Gate checking tools
352            "check_gates" => {
353                // Look up worker's workflow for gate definitions
354                // Since check_gates doesn't require worker_id, use base workflow
355                json(gates::check_gates(
356                    &self.db,
357                    &self.config.workflows,
358                    arguments,
359                ))
360            }
361
362            // Workflow discovery tools (no connection required)
363            "list_workflows" => json(workflows::list_workflows(&self.config.workflows)),
364
365            // Feedback tools (gated by config)
366            "give_feedback" | "list_feedback" if !self.config.feedback.enabled => {
367                Err(ToolError::unknown_tool(name).into())
368            }
369            "give_feedback" => {
370                let db_dir = self
371                    .server_paths
372                    .db_path
373                    .parent()
374                    .unwrap_or(std::path::Path::new("."));
375                json(feedback::give_feedback(db_dir, arguments))
376            }
377            "list_feedback" => {
378                let db_dir = self
379                    .server_paths
380                    .db_path
381                    .parent()
382                    .unwrap_or(std::path::Path::new("."));
383                json(feedback::list_feedback(db_dir))
384            }
385
386            _ => Err(ToolError::unknown_tool(name).into()),
387        }
388    }
389}
390
391/// Helper to create a tool definition.
392pub fn make_tool(name: &str, description: &str, properties: Value, required: Vec<&str>) -> Tool {
393    let input_schema = rmcp::model::JsonObject::from_iter([
394        ("type".to_string(), serde_json::json!("object")),
395        ("properties".to_string(), properties),
396        ("required".to_string(), serde_json::json!(required)),
397    ]);
398
399    Tool::new(name.to_string(), description.to_string(), input_schema)
400}
401
402/// Helper to create a tool definition with prompt overrides.
403/// Looks up the tool description in prompts, falls back to default_description.
404pub fn make_tool_with_prompts(
405    name: &str,
406    default_description: &str,
407    properties: Value,
408    required: Vec<&str>,
409    prompts: &Prompts,
410) -> Tool {
411    let description = prompts
412        .get_tool_description(name)
413        .unwrap_or(default_description);
414    make_tool(name, description, properties, required)
415}
416
417/// Helper to get a string from arguments.
418pub fn get_string(args: &Value, key: &str) -> Option<String> {
419    args.get(key).and_then(|v| v.as_str().map(String::from))
420}
421
422/// Helper to get an i32 from arguments.
423pub fn get_i32(args: &Value, key: &str) -> Option<i32> {
424    args.get(key).and_then(|v| v.as_i64().map(|n| n as i32))
425}
426
427/// Helper to get an i64 from arguments.
428pub fn get_i64(args: &Value, key: &str) -> Option<i64> {
429    args.get(key).and_then(|v| v.as_i64())
430}
431
432/// Helper to get an f64 from arguments.
433pub fn get_f64(args: &Value, key: &str) -> Option<f64> {
434    args.get(key).and_then(|v| v.as_f64())
435}
436
437/// Helper to get a bool from arguments.
438pub fn get_bool(args: &Value, key: &str) -> Option<bool> {
439    args.get(key).and_then(|v| v.as_bool())
440}
441
442/// Helper to get a string array from arguments.
443pub fn get_string_array(args: &Value, key: &str) -> Option<Vec<String>> {
444    args.get(key).and_then(|v| {
445        v.as_array().map(|arr| {
446            arr.iter()
447                .filter_map(|v| v.as_str().map(String::from))
448                .collect()
449        })
450    })
451}
452
453/// Helper to get either a single string or array of strings from arguments.
454/// Normalizes to a Vec<String>.
455pub fn get_string_or_array(args: &Value, key: &str) -> Option<Vec<String>> {
456    args.get(key).and_then(|v| {
457        if let Some(s) = v.as_str() {
458            // Single string - wrap in vec
459            Some(vec![s.to_string()])
460        } else {
461            v.as_array().map(|arr| {
462                arr.iter()
463                    .filter_map(|item| item.as_str().map(String::from))
464                    .collect()
465            })
466        }
467    })
468}
469
/// Parsed result that may be a list of IDs or a wildcard "*".
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IdList {
    /// An explicit list of IDs.
    Ids(Vec<String>),
    /// The wildcard sentinel (`"*"`).
    Wildcard,
}
475
476/// Like get_string_or_array, but recognizes "*" as a wildcard sentinel.
477pub fn get_string_or_array_or_wildcard(args: &Value, key: &str) -> Option<IdList> {
478    let vals = get_string_or_array(args, key)?;
479    if vals.len() == 1 && vals[0] == "*" {
480        Some(IdList::Wildcard)
481    } else {
482        Some(IdList::Ids(vals))
483    }
484}