Skip to main content

task_graph_mcp/tools/
agents.rs

1//! Worker connection and management tools.
2
3use super::{get_bool, get_i32, get_string, get_string_array, make_tool_with_prompts};
4use crate::config::workflows::WorkflowsConfig;
5use crate::config::{AppConfig, Prompts, ServerPaths, StatesConfig};
6use crate::db::Database;
7use crate::error::ToolError;
8use crate::format::{OutputFormat, ToolResult, format_workers_markdown};
9use anyhow::Result;
10use rmcp::model::Tool;
11use serde_json::{Value, json};
12
13/// Options for connecting a worker to the task graph.
14pub struct ConnectOptions<'a> {
15    pub db: &'a Database,
16    pub server_paths: &'a ServerPaths,
17    pub config: &'a AppConfig,
18    /// Per-connect workflow (may differ from config.workflows for named workflows).
19    pub workflows: &'a WorkflowsConfig,
20}
21
22pub fn get_tools(prompts: &Prompts) -> Vec<Tool> {
23    vec![
24        make_tool_with_prompts(
25            "connect",
26            "Connect as a worker. Call this FIRST before using other tools. Returns worker_id (save it for all subsequent calls). Tags enable task affinity matching.",
27            json!({
28                "worker_id": {
29                    "type": "string",
30                    "description": "Only use if assigned a unique name (e.g., 'worker-17', 'coordinator'). Avoid generic names like 'claude'. Leave empty for an auto-generated petname."
31                },
32                "tags": {
33                    "type": "array",
34                    "items": { "type": "string" },
35                    "description": "Freeform tags for capabilities, roles, etc."
36                },
37                "force": {
38                    "type": "boolean",
39                    "description": "Force reconnection if worker ID already exists (default: false). Use for stuck worker recovery."
40                },
41                "db_path": {
42                    "type": "string",
43                    "description": "Override database file path (same as TASK_GRAPH_DB_PATH env var). Note: Can only be set before server starts."
44                },
45                "media_dir": {
46                    "type": "string",
47                    "description": "Override media directory path (same as TASK_GRAPH_MEDIA_DIR env var). Note: Can only be set before server starts."
48                },
49                "log_dir": {
50                    "type": "string",
51                    "description": "Override log directory path (same as TASK_GRAPH_LOG_DIR env var). Note: Can only be set before server starts."
52                },
53                "config_path": {
54                    "type": "string",
55                    "description": "Override config file path (same as TASK_GRAPH_CONFIG_PATH env var). Note: Can only be set before server starts."
56                },
57                "workflow": {
58                    "type": "string",
59                    "description": "Named workflow to use (e.g., 'swarm' for workflow-swarm.yaml). If not specified, uses default workflows.yaml."
60                }
61            }),
62            vec![],
63            prompts,
64        ),
65        make_tool_with_prompts(
66            "disconnect",
67            "Disconnect a worker, releasing all claims and locks.",
68            json!({
69                "worker_id": {
70                    "type": "string",
71                    "description": "The worker's ID"
72                },
73                "final_status": {
74                    "type": "string",
75                    "enum": ["pending", "completed", "cancelled", "failed"],
76                    "description": "Status to set released tasks to (default: config disconnect_status, typically 'pending'). Must be an untimed status."
77                }
78            }),
79            vec!["worker_id"],
80            prompts,
81        ),
82        make_tool_with_prompts(
83            "list_agents",
84            "List all connected workers with their current status, claim counts, and what they're working on. Automatically evicts stale workers (no heartbeat within timeout).",
85            json!({
86                "tags": {
87                    "type": "array",
88                    "items": { "type": "string" },
89                    "description": "Filter workers that have ALL of these tags"
90                },
91                "file": {
92                    "type": "string",
93                    "description": "Filter workers that have claimed this file"
94                },
95                "task": {
96                    "type": "string",
97                    "description": "Filter workers related to this task ID"
98                },
99                "depth": {
100                    "type": "integer",
101                    "description": "Task relationship depth (-3 to 3). Negative: ancestors, positive: descendants. Used with 'task' filter."
102                },
103                "stale_timeout": {
104                    "type": "integer",
105                    "description": "Seconds without heartbeat before a worker is considered stale and evicted. Set to 0 to disable auto-cleanup. Default: 300 (5 minutes)."
106                }
107            }),
108            vec![],
109            prompts,
110        ),
111        make_tool_with_prompts(
112            "cleanup_stale",
113            "Evict stale workers that haven't sent a heartbeat within the timeout period. Releases their task claims and file locks.",
114            json!({
115                "timeout": {
116                    "type": "integer",
117                    "description": "Seconds without heartbeat before a worker is considered stale. Default: 300 (5 minutes)."
118                },
119                "final_status": {
120                    "type": "string",
121                    "enum": ["pending", "completed", "cancelled", "failed"],
122                    "description": "Status to set released tasks to (default: config disconnect_status, typically 'pending'). Must be an untimed status."
123                }
124            }),
125            vec![],
126            prompts,
127        ),
128    ]
129}
130
131pub fn connect(opts: ConnectOptions<'_>, args: Value) -> Result<Value> {
132    let ConnectOptions {
133        db,
134        server_paths,
135        config,
136        workflows,
137    } = opts;
138
139    let states_config = &config.states;
140    let phases_config = &config.phases;
141    let deps_config = &config.deps;
142    let tags_config = &config.tags;
143    let ids_config = &config.ids;
144
145    let worker_id = get_string(&args, "worker_id");
146    let tags = get_string_array(&args, "tags").unwrap_or_default();
147    let force = get_bool(&args, "force").unwrap_or(false);
148    let workflow = get_string(&args, "workflow");
149
150    // Validate tags if provided
151    let tag_warnings = tags_config.validate_tags(&tags)?;
152
153    // Check for path override requests (informational - paths are set at server startup)
154    let mut path_notes: Vec<String> = Vec::new();
155
156    if let Some(requested_db) = get_string(&args, "db_path")
157        && server_paths.db_path.to_string_lossy() != requested_db
158    {
159        path_notes.push(format!(
160                "db_path: requested '{}' but server is using '{}' (set TASK_GRAPH_DB_PATH before starting server)",
161                requested_db,
162                server_paths.db_path.display()
163            ));
164    }
165
166    if let Some(requested_media) = get_string(&args, "media_dir")
167        && server_paths.media_dir.to_string_lossy() != requested_media
168    {
169        path_notes.push(format!(
170                "media_dir: requested '{}' but server is using '{}' (set TASK_GRAPH_MEDIA_DIR before starting server)",
171                requested_media,
172                server_paths.media_dir.display()
173            ));
174    }
175
176    if let Some(requested_log) = get_string(&args, "log_dir")
177        && server_paths.log_dir.to_string_lossy() != requested_log
178    {
179        path_notes.push(format!(
180                "log_dir: requested '{}' but server is using '{}' (set TASK_GRAPH_LOG_DIR before starting server)",
181                requested_log,
182                server_paths.log_dir.display()
183            ));
184    }
185
186    if let Some(requested_config) = get_string(&args, "config_path") {
187        let current_config = server_paths
188            .config_path
189            .as_ref()
190            .map(|p| p.to_string_lossy().to_string());
191        if current_config.as_deref() != Some(&requested_config) {
192            path_notes.push(format!(
193                "config_path: requested '{}' but server is using '{}' (set TASK_GRAPH_CONFIG_PATH before starting server)",
194                requested_config,
195                current_config.unwrap_or_else(|| "default locations".to_string())
196            ));
197        }
198    }
199
200    let worker = db.register_worker(worker_id, tags, force, ids_config, workflow)?;
201
202    // Build config summary for the response
203    let timed_states: Vec<&str> = states_config
204        .definitions
205        .iter()
206        .filter(|(_, def)| def.timed)
207        .map(|(name, _)| name.as_str())
208        .collect();
209
210    let terminal_states: Vec<&str> = states_config
211        .definitions
212        .iter()
213        .filter(|(_, def)| def.exits.is_empty())
214        .map(|(name, _)| name.as_str())
215        .collect();
216
217    let mut response = json!({
218        "version": env!("CARGO_PKG_VERSION"),
219        "worker_id": &worker.id,
220        "tags": worker.tags,
221        "max_claims": worker.max_claims,
222        "registered_at": worker.registered_at,
223        "workflow": worker.workflow,
224        "paths": {
225            "db_path": server_paths.db_path.to_string_lossy(),
226            "media_dir": server_paths.media_dir.to_string_lossy(),
227            "log_dir": server_paths.log_dir.to_string_lossy(),
228            "config_path": server_paths.config_path.as_ref().map(|p| p.to_string_lossy().to_string())
229        },
230        "config": {
231            "states": states_config.state_names(),
232            "initial_state": &states_config.initial,
233            "timed_states": timed_states,
234            "terminal_states": terminal_states,
235            "blocking_states": &states_config.blocking_states,
236            "phases": phases_config.phase_names(),
237            "dependency_types": deps_config.dep_type_names(),
238            "known_tags": tags_config.tag_names()
239        }
240    });
241
242    if !path_notes.is_empty() {
243        response["path_warnings"] = json!(path_notes);
244    }
245
246    if !tag_warnings.is_empty() {
247        response["tag_warnings"] = json!(tag_warnings);
248    }
249
250    // Deliver workflow-specific role information and prompts
251    if let Some(role_name) = workflows.match_role(&worker.tags) {
252        let mut role_info = json!({
253            "role": &role_name,
254        });
255
256        // Include role definition details
257        if let Some(role_def) = workflows.get_role(&role_name) {
258            if let Some(ref desc) = role_def.description {
259                role_info["description"] = json!(desc);
260            }
261            if let Some(max) = role_def.max_claims {
262                role_info["max_claims"] = json!(max);
263            }
264            if let Some(can_assign) = role_def.can_assign {
265                role_info["can_assign"] = json!(can_assign);
266            }
267        }
268
269        response["role"] = role_info;
270
271        // Include role-specific prompts
272        let prompts = workflows.get_role_prompts(&role_name);
273        if !prompts.is_empty() {
274            response["role_prompts"] = json!(prompts);
275        }
276    }
277
278    // Include workflow description if available
279    if let Some(ref desc) = workflows.description {
280        response["workflow_description"] = json!(desc);
281    }
282
283    Ok(response)
284}
285
286pub fn disconnect(db: &Database, states_config: &StatesConfig, args: Value) -> Result<Value> {
287    let worker_id =
288        get_string(&args, "worker_id").ok_or_else(|| ToolError::missing_field("worker_id"))?;
289
290    // Get final_status from args or fall back to config
291    let final_status =
292        get_string(&args, "final_status").unwrap_or_else(|| states_config.disconnect_state.clone());
293
294    // Validate final_status is untimed
295    if states_config.is_timed_state(&final_status) {
296        return Err(ToolError::invalid_value(
297            "final_status",
298            &format!(
299                "must be an untimed status, got '{}'. Valid statuses: {:?}",
300                final_status,
301                states_config.untimed_state_names()
302            ),
303        )
304        .into());
305    }
306
307    // Release worker locks before unregistering (close claim_sequence records)
308    let _ = db.release_worker_locks(&worker_id);
309
310    // Unregister and get summary
311    let summary = db.unregister_worker(&worker_id, &final_status)?;
312
313    Ok(json!({
314        "success": true,
315        "tasks_released": summary.tasks_released,
316        "files_released": summary.files_released,
317        "final_status": summary.final_status
318    }))
319}
320
321pub fn list_agents(
322    db: &Database,
323    states_config: &StatesConfig,
324    format: OutputFormat,
325    args: Value,
326) -> Result<ToolResult> {
327    // Extract filter parameters
328    let tags = get_string_array(&args, "tags");
329    let file = get_string(&args, "file");
330    let task = get_string(&args, "task");
331    let depth = get_i32(&args, "depth").unwrap_or(0).clamp(-3, 3);
332
333    // Auto-cleanup stale workers (default 5 minutes, 0 to disable)
334    let stale_timeout = get_i32(&args, "stale_timeout").unwrap_or(300);
335    let cleanup_summary = if stale_timeout > 0 {
336        let final_status = states_config.disconnect_state.clone();
337        db.cleanup_stale_workers(stale_timeout as i64, &final_status)
338            .ok()
339    } else {
340        None
341    };
342
343    // Get workers with filters
344    let workers =
345        db.list_workers_filtered(tags.as_ref(), file.as_deref(), task.as_deref(), depth)?;
346
347    // Get current time for heartbeat age calculation
348    let now = std::time::SystemTime::now()
349        .duration_since(std::time::UNIX_EPOCH)
350        .map(|d| d.as_millis() as i64)
351        .unwrap_or(0);
352
353    match format {
354        OutputFormat::Markdown => {
355            let mut output = String::new();
356            if let Some(ref summary) = cleanup_summary
357                && summary.workers_evicted > 0
358            {
359                output.push_str(&format!(
360                    "**Evicted {} stale worker(s)**: {} (released {} task(s), {} file(s))\n\n",
361                    summary.workers_evicted,
362                    summary.evicted_worker_ids.join(", "),
363                    summary.tasks_released,
364                    summary.files_released
365                ));
366            }
367            output.push_str(&format_workers_markdown(&workers));
368            Ok(ToolResult::Raw(output))
369        }
370        OutputFormat::Json => {
371            let mut result = json!({
372                "workers": workers.iter().map(|w| json!({
373                    "id": w.id,
374                    "tags": w.tags,
375                    "max_claims": w.max_claims,
376                    "claim_count": w.claim_count,
377                    "current_thought": w.current_thought,
378                    "registered_at": w.registered_at,
379                    "last_heartbeat": w.last_heartbeat,
380                    "heartbeat_age_ms": now - w.last_heartbeat,
381                    "workflow": w.workflow
382                })).collect::<Vec<_>>()
383            });
384
385            if let Some(summary) = cleanup_summary
386                && summary.workers_evicted > 0
387            {
388                result["cleanup"] = json!({
389                    "workers_evicted": summary.workers_evicted,
390                    "evicted_worker_ids": summary.evicted_worker_ids,
391                    "tasks_released": summary.tasks_released,
392                    "files_released": summary.files_released
393                });
394            }
395
396            Ok(ToolResult::Json(result))
397        }
398    }
399}
400
401pub fn cleanup_stale(db: &Database, states_config: &StatesConfig, args: Value) -> Result<Value> {
402    // Default timeout: 5 minutes
403    let timeout = get_i32(&args, "timeout").unwrap_or(300) as i64;
404
405    // Get final_status from args or fall back to config
406    let final_status =
407        get_string(&args, "final_status").unwrap_or_else(|| states_config.disconnect_state.clone());
408
409    // Validate final_status is untimed
410    if states_config.is_timed_state(&final_status) {
411        return Err(ToolError::invalid_value(
412            "final_status",
413            &format!(
414                "must be an untimed status, got '{}'. Valid statuses: {:?}",
415                final_status,
416                states_config.untimed_state_names()
417            ),
418        )
419        .into());
420    }
421
422    let summary = db.cleanup_stale_workers(timeout, &final_status)?;
423
424    Ok(json!({
425        "workers_evicted": summary.workers_evicted,
426        "evicted_worker_ids": summary.evicted_worker_ids,
427        "tasks_released": summary.tasks_released,
428        "files_released": summary.files_released,
429        "final_status": summary.final_status
430    }))
431}