Skip to main content

task_mcp/just/
mod.rs

1pub mod model;
2
3use std::collections::{HashMap, VecDeque};
4use std::path::{Path, PathBuf};
5use std::sync::Mutex;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use regex::Regex;
9use tokio::process::Command;
10use uuid::Uuid;
11
12use crate::config::TaskMode;
13use crate::just::model::{
14    JustDump, JustRecipe, Recipe, TaskError, TaskExecution, TaskExecutionSummary,
15};
16
17// =============================================================================
18// Error
19// =============================================================================
20
21#[derive(Debug, thiserror::Error)]
22pub enum JustError {
23    #[error("just command not found: {0}")]
24    NotFound(String),
25    #[error("just command failed (exit {code}): {stderr}")]
26    CommandFailed { code: i32, stderr: String },
27    #[error("failed to parse just dump json: {0}")]
28    ParseError(#[from] serde_json::Error),
29    #[error("I/O error while reading justfile: {0}")]
30    Io(#[from] std::io::Error),
31}
32
33// =============================================================================
34// Public API
35// =============================================================================
36
37/// Discover recipes from the justfile at `justfile_path`.
38///
39/// Filtering behaviour depends on `mode`:
40/// - `TaskMode::AgentOnly`: only recipes marked agent-safe are returned.
41/// - `TaskMode::All`: all non-private recipes are returned.
42///
43/// Agent-safe detection: patten A (group attribute) first, pattern B (comment regex) as fallback.
44pub async fn list_recipes(
45    justfile_path: &Path,
46    mode: &TaskMode,
47    workdir: Option<&Path>,
48) -> Result<Vec<Recipe>, JustError> {
49    let dump = dump_json(justfile_path, workdir).await?;
50
51    // Read justfile text for pattern B fallback
52    let justfile_text = tokio::fs::read_to_string(justfile_path).await.ok();
53
54    // Build comment-tag set from justfile text (pattern B)
55    let comment_tagged = justfile_text
56        .as_deref()
57        .map(extract_comment_tagged_recipes)
58        .unwrap_or_default();
59
60    let mut recipes: Vec<Recipe> = dump
61        .recipes
62        .into_values()
63        .filter(|r| !r.private)
64        .map(|raw| {
65            let allow_agent = is_allow_agent(&raw, &comment_tagged);
66            Recipe::from_just_recipe(raw, allow_agent)
67        })
68        .collect();
69
70    // Sort by name for deterministic output
71    recipes.sort_by(|a, b| a.name.cmp(&b.name));
72
73    match mode {
74        TaskMode::AgentOnly => Ok(recipes.into_iter().filter(|r| r.allow_agent).collect()),
75        TaskMode::All => Ok(recipes),
76    }
77}
78
79// =============================================================================
80// Internal helpers
81// =============================================================================
82
83/// Run `just --dump --dump-format json --unstable` and return parsed output.
84async fn dump_json(justfile_path: &Path, workdir: Option<&Path>) -> Result<JustDump, JustError> {
85    let mut cmd = Command::new("just");
86    cmd.arg("--justfile")
87        .arg(justfile_path)
88        .arg("--dump")
89        .arg("--dump-format")
90        .arg("json")
91        .arg("--unstable");
92    if let Some(dir) = workdir {
93        cmd.current_dir(dir);
94    }
95    let output = cmd
96        .output()
97        .await
98        .map_err(|e| JustError::NotFound(e.to_string()))?;
99
100    if !output.status.success() {
101        let code = output.status.code().unwrap_or(-1);
102        let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
103        return Err(JustError::CommandFailed { code, stderr });
104    }
105
106    let json_str = String::from_utf8_lossy(&output.stdout);
107    let dump: JustDump = serde_json::from_str(&json_str)?;
108    Ok(dump)
109}
110
111/// Pattern A: check if the recipe has a `[group: 'agent']` attribute.
112fn has_group_agent_attribute(recipe: &JustRecipe) -> bool {
113    recipe.attributes.iter().any(|a| {
114        a.group
115            .as_ref()
116            .map(|g| g.as_str() == "agent")
117            .unwrap_or(false)
118    })
119}
120
121/// Pattern B: extract recipe names that are preceded by a `# [allow-agent]` comment.
122///
123/// Regex matches lines like (justfile syntax, not Rust):
124/// ```text
125/// # [allow-agent]
126/// recipe_name:
127/// ```
128fn extract_comment_tagged_recipes(justfile_text: &str) -> std::collections::HashSet<String> {
129    // Allow-agent comment followed (with optional blank/doc lines) by a recipe line
130    // Simple approach: line-by-line scan with state
131    let allow_agent_re = Regex::new(r"^\s*#\s*\[allow-agent\]").expect("valid regex");
132    let recipe_name_re = Regex::new(r"^([a-zA-Z0-9_-]+)\s*(?:\S.*)?:").expect("valid regex");
133
134    let mut tagged = std::collections::HashSet::new();
135    let mut saw_allow_agent = false;
136
137    for line in justfile_text.lines() {
138        if allow_agent_re.is_match(line) {
139            saw_allow_agent = true;
140            continue;
141        }
142
143        if saw_allow_agent {
144            // Skip doc comment lines and blank lines between tag and recipe
145            let trimmed = line.trim();
146            if trimmed.is_empty() || trimmed.starts_with('#') {
147                continue;
148            }
149
150            // Try to extract recipe name
151            if let Some(cap) = recipe_name_re.captures(line) {
152                let name = cap[1].to_string();
153                // Exclude attribute lines like [group: 'agent']
154                if !name.starts_with('[') {
155                    tagged.insert(name);
156                }
157            }
158            // Once we see a non-comment, non-blank line, reset state
159            saw_allow_agent = false;
160        }
161    }
162
163    tagged
164}
165
166/// Determine if a recipe is agent-safe using pattern A first, then pattern B.
167fn is_allow_agent(recipe: &JustRecipe, comment_tagged: &std::collections::HashSet<String>) -> bool {
168    has_group_agent_attribute(recipe) || comment_tagged.contains(&recipe.name)
169}
170
/// Resolve the justfile location.
///
/// Precedence: an explicit `override_path` is used verbatim (even when `workdir` is set);
/// otherwise `justfile` inside `workdir`; otherwise a bare relative `justfile`.
pub fn resolve_justfile_path(override_path: Option<&str>, workdir: Option<&Path>) -> PathBuf {
    if let Some(explicit) = override_path {
        return PathBuf::from(explicit);
    }
    workdir.map_or_else(|| PathBuf::from("justfile"), |dir| dir.join("justfile"))
}
181
182// =============================================================================
183// Output truncation
184// =============================================================================
185
const MAX_OUTPUT_BYTES: usize = 100 * 1024; // 100 KB
const HEAD_BYTES: usize = 50 * 1024; // 50 KB
const TAIL_BYTES: usize = 50 * 1024; // 50 KB

/// Truncate output that is longer than `MAX_OUTPUT_BYTES`.
///
/// Returns the (possibly shortened) text and whether truncation happened. Output of at most
/// `MAX_OUTPUT_BYTES` bytes is returned unchanged. Longer output becomes
/// `{head}\n...[truncated {n} bytes]...\n{tail}`, where `head` is at most `HEAD_BYTES` from
/// the start, `tail` is at most `TAIL_BYTES` from the end, and `n` is the number of bytes
/// dropped between them. Because of the marker line, a truncated result is slightly longer
/// than `MAX_OUTPUT_BYTES`.
///
/// UTF-8 boundaries are respected: the head ends and the tail starts on character
/// boundaries, so each may be a few bytes shorter than its limit.
pub fn truncate_output(output: &str) -> (String, bool) {
    truncate_output_with_limits(output, MAX_OUTPUT_BYTES, HEAD_BYTES, TAIL_BYTES)
}

/// [`truncate_output`] with caller-supplied limits.
///
/// `output` is returned unchanged when `output.len() <= max_bytes`. Otherwise, at most
/// `head_bytes` from the start and at most `tail_bytes` from the end are kept around a
/// `...[truncated {n} bytes]...` marker line. If head and tail would overlap
/// (`head_bytes + tail_bytes > output.len()`), the tail starts where the head ends, so no
/// byte is emitted twice and the marker reports 0 bytes.
pub fn truncate_output_with_limits(
    output: &str,
    max_bytes: usize,
    head_bytes: usize,
    tail_bytes: usize,
) -> (String, bool) {
    if output.len() <= max_bytes {
        return (output.to_string(), false);
    }

    let head_end = safe_byte_boundary(output, head_bytes);
    // `.max(head_end)` only matters for custom limits whose head and tail overlap; it keeps
    // `tail_start - head_end` from underflowing.
    let tail_start =
        safe_tail_start(output, output.len().saturating_sub(tail_bytes)).max(head_end);

    let head = &output[..head_end];
    let tail = &output[tail_start..];
    let truncated_bytes = tail_start - head_end;

    (
        format!("{head}\n...[truncated {truncated_bytes} bytes]...\n{tail}"),
        true,
    )
}

/// Find the largest byte index `<= limit` that lies on a UTF-8 character boundary.
fn safe_byte_boundary(s: &str, limit: usize) -> usize {
    if limit >= s.len() {
        return s.len();
    }
    // Walk backwards from `limit` until we hit a valid char boundary (index 0 always is one).
    let mut idx = limit;
    while idx > 0 && !s.is_char_boundary(idx) {
        idx -= 1;
    }
    idx
}

/// Find the smallest byte index `>= hint` that lies on a UTF-8 character boundary.
fn safe_tail_start(s: &str, hint: usize) -> usize {
    if hint >= s.len() {
        return s.len();
    }
    // Walk forwards; `s.len()` is always a boundary, so the loop terminates.
    let mut idx = hint;
    while idx < s.len() && !s.is_char_boundary(idx) {
        idx += 1;
    }
    idx
}
242
243// =============================================================================
244// Argument validation
245// =============================================================================
246
247/// Reject argument values that contain shell meta-characters.
248///
249/// `tokio::process::Command` bypasses the shell, but `just` itself invokes a
250/// shell interpreter for recipe bodies.  Validating inputs here prevents
251/// injection attacks in case a recipe passes an argument through to the shell.
252pub fn validate_arg_value(value: &str) -> Result<(), TaskError> {
253    const DANGEROUS: &[&str] = &[";", "|", "&&", "||", "`", "$(", "${", "\n", "\r"];
254    for pattern in DANGEROUS {
255        if value.contains(pattern) {
256            return Err(TaskError::DangerousArgument(value.to_string()));
257        }
258    }
259    Ok(())
260}
261
262// =============================================================================
263// Recipe execution
264// =============================================================================
265
266/// Execute a recipe by name, passing `args` as positional parameters.
267///
268/// Steps:
269/// 1. Confirm the recipe exists in `list_recipes(justfile_path, mode)`.
270/// 2. Validate each argument value for dangerous characters.
271/// 3. Run `just --justfile {path} {recipe_name} {arg_values...}` with a
272///    timeout.
273/// 4. Capture stdout/stderr and apply truncation.
274/// 5. Return a `TaskExecution` record.
275pub async fn execute_recipe(
276    recipe_name: &str,
277    args: &HashMap<String, String>,
278    justfile_path: &Path,
279    timeout: Duration,
280    mode: &TaskMode,
281    workdir: Option<&Path>,
282) -> Result<TaskExecution, TaskError> {
283    // 1. Whitelist check
284    let recipes = list_recipes(justfile_path, mode, workdir).await?;
285    let recipe = recipes
286        .iter()
287        .find(|r| r.name == recipe_name)
288        .ok_or_else(|| TaskError::RecipeNotFound(recipe_name.to_string()))?;
289
290    // 2. Argument validation
291    for value in args.values() {
292        validate_arg_value(value)?;
293    }
294
295    // Build positional argument list in parameter definition order
296    let positional: Vec<&str> = recipe
297        .parameters
298        .iter()
299        .filter_map(|p| args.get(&p.name).map(|v| v.as_str()))
300        .collect();
301
302    // 3. Construct and run the command
303    let started_at = SystemTime::now()
304        .duration_since(UNIX_EPOCH)
305        .unwrap_or_default()
306        .as_secs();
307    let start_instant = std::time::Instant::now();
308
309    let mut cmd = Command::new("just");
310    cmd.arg("--justfile").arg(justfile_path).arg(recipe_name);
311    for arg in &positional {
312        cmd.arg(arg);
313    }
314    if let Some(dir) = workdir {
315        cmd.current_dir(dir);
316    }
317
318    let run_result = tokio::time::timeout(timeout, cmd.output()).await;
319
320    let duration_ms = start_instant.elapsed().as_millis() as u64;
321
322    let output = match run_result {
323        Err(_) => return Err(TaskError::Timeout),
324        Ok(Err(io_err)) => return Err(TaskError::Io(io_err)),
325        Ok(Ok(out)) => out,
326    };
327
328    let exit_code = output.status.code();
329
330    // 4. Capture + truncate
331    let raw_stdout = String::from_utf8_lossy(&output.stdout).into_owned();
332    let raw_stderr = String::from_utf8_lossy(&output.stderr).into_owned();
333
334    let (stdout, stdout_truncated) = truncate_output(&raw_stdout);
335    let (stderr, stderr_truncated) = truncate_output(&raw_stderr);
336    let truncated = stdout_truncated || stderr_truncated;
337
338    // 5. Return execution record
339    Ok(TaskExecution {
340        id: Uuid::new_v4().to_string(),
341        task_name: recipe_name.to_string(),
342        args: args.clone(),
343        exit_code,
344        stdout,
345        stderr,
346        started_at,
347        duration_ms,
348        truncated,
349    })
350}
351
352// =============================================================================
353// Task log store
354// =============================================================================
355
356/// In-memory ring buffer of recent task executions.
357///
358/// `Arc<TaskLogStore>` is `Clone` because `Arc<T>` implements `Clone` for any
359/// `T: ?Sized`.  `TaskLogStore` itself does not need to implement `Clone`.
360pub struct TaskLogStore {
361    logs: Mutex<VecDeque<TaskExecution>>,
362    max_entries: usize,
363}
364
365impl TaskLogStore {
366    pub fn new(max_entries: usize) -> Self {
367        Self {
368            logs: Mutex::new(VecDeque::new()),
369            max_entries,
370        }
371    }
372
373    /// Append an execution record, evicting the oldest entry when full.
374    pub fn push(&self, execution: TaskExecution) {
375        let mut guard = self.logs.lock().expect("log store lock poisoned");
376        if guard.len() >= self.max_entries {
377            guard.pop_front();
378        }
379        guard.push_back(execution);
380    }
381
382    /// Look up a specific execution by ID.  Returns a clone.
383    pub fn get(&self, id: &str) -> Option<TaskExecution> {
384        let guard = self.logs.lock().expect("log store lock poisoned");
385        guard.iter().find(|e| e.id == id).cloned()
386    }
387
388    /// Return summaries of the most recent `n` executions (newest first).
389    pub fn recent(&self, n: usize) -> Vec<TaskExecutionSummary> {
390        let guard = self.logs.lock().expect("log store lock poisoned");
391        guard
392            .iter()
393            .rev()
394            .take(n)
395            .map(TaskExecutionSummary::from_execution)
396            .collect()
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::just::model::RecipeAttribute;

    /// Build a public, non-quiet recipe with no parameters and no doc comment.
    fn make_recipe(name: &str, attributes: Vec<RecipeAttribute>) -> JustRecipe {
        JustRecipe {
            name: name.to_string(),
            namepath: name.to_string(),
            doc: None,
            attributes,
            parameters: vec![],
            private: false,
            quiet: false,
        }
    }

    /// Build a single `[group: '<group>']` attribute.
    fn group_attribute(group: &str) -> RecipeAttribute {
        RecipeAttribute {
            group: Some(group.to_string()),
        }
    }

    // -------------------------------------------------------------------------
    // Agent-safe detection
    // -------------------------------------------------------------------------

    #[test]
    fn has_group_agent_attribute_true() {
        let recipe = make_recipe("build", vec![group_attribute("agent")]);
        assert!(has_group_agent_attribute(&recipe));
    }

    #[test]
    fn has_group_agent_attribute_false_no_attrs() {
        let recipe = make_recipe("deploy", vec![]);
        assert!(!has_group_agent_attribute(&recipe));
    }

    #[test]
    fn has_group_agent_attribute_false_other_group() {
        let recipe = make_recipe("build", vec![group_attribute("ci")]);
        assert!(!has_group_agent_attribute(&recipe));
    }

    #[test]
    fn extract_comment_tagged_recipes_basic() {
        let text = "# [allow-agent]\nbuild:\n    cargo build\n\ndeploy:\n    ./deploy.sh\n";
        let tagged = extract_comment_tagged_recipes(text);
        assert!(tagged.contains("build"), "build should be tagged");
        assert!(!tagged.contains("deploy"), "deploy should not be tagged");
    }

    #[test]
    fn extract_comment_tagged_recipes_with_doc_comment() {
        let text = "# [allow-agent]\n# Run tests\ntest filter=\"\":\n    cargo test {{filter}}\n\ndeploy:\n    ./deploy.sh\n";
        let tagged = extract_comment_tagged_recipes(text);
        assert!(tagged.contains("test"), "test should be tagged");
        assert!(!tagged.contains("deploy"));
    }

    #[test]
    fn extract_comment_tagged_recipes_multiple() {
        let text = "# [allow-agent]\nbuild:\n    cargo build\n\n# [allow-agent]\ninfo:\n    echo info\n\ndeploy:\n    ./deploy.sh\n";
        let tagged = extract_comment_tagged_recipes(text);
        assert!(tagged.contains("build"));
        assert!(tagged.contains("info"));
        assert!(!tagged.contains("deploy"));
    }

    #[test]
    fn extract_comment_tagged_recipes_skips_blank_lines() {
        let text = "# [allow-agent]\n\n\nbuild:\n    cargo build\n";
        let tagged = extract_comment_tagged_recipes(text);
        assert!(tagged.contains("build"));
    }

    #[test]
    fn extract_comment_tagged_recipes_other_line_consumes_tag() {
        // The tag applies to the first non-blank, non-comment line only.
        let text = "# [allow-agent]\nx := \"1\"\n\nbuild:\n    cargo build\n";
        let tagged = extract_comment_tagged_recipes(text);
        assert!(!tagged.contains("build"), "tag must not carry past an unrelated line");
    }

    #[test]
    fn is_allow_agent_pattern_a() {
        // Not in the comment set; the attribute alone must be enough.
        let tagged = std::collections::HashSet::new();
        let recipe = make_recipe("build", vec![group_attribute("agent")]);
        assert!(is_allow_agent(&recipe, &tagged));
    }

    #[test]
    fn is_allow_agent_pattern_b() {
        let mut tagged = std::collections::HashSet::new();
        tagged.insert("build".to_string());
        let recipe = make_recipe("build", vec![]);
        assert!(is_allow_agent(&recipe, &tagged));
    }

    #[test]
    fn is_allow_agent_neither() {
        let tagged = std::collections::HashSet::new();
        let recipe = make_recipe("deploy", vec![]);
        assert!(!is_allow_agent(&recipe, &tagged));
    }

    // -------------------------------------------------------------------------
    // resolve_justfile_path tests
    // -------------------------------------------------------------------------

    #[test]
    fn resolve_justfile_path_override() {
        let p = resolve_justfile_path(Some("/custom/justfile"), None);
        assert_eq!(p, PathBuf::from("/custom/justfile"));
    }

    #[test]
    fn resolve_justfile_path_default() {
        let p = resolve_justfile_path(None, None);
        assert_eq!(p, PathBuf::from("justfile"));
    }

    #[test]
    fn resolve_justfile_path_with_workdir() {
        let workdir = Path::new("/some/project");
        let p = resolve_justfile_path(None, Some(workdir));
        assert_eq!(p, PathBuf::from("/some/project/justfile"));
    }

    #[test]
    fn resolve_justfile_path_override_ignores_workdir() {
        // override_path takes precedence over workdir
        let workdir = Path::new("/some/project");
        let p = resolve_justfile_path(Some("/custom/justfile"), Some(workdir));
        assert_eq!(p, PathBuf::from("/custom/justfile"));
    }

    // -------------------------------------------------------------------------
    // truncate_output tests
    // -------------------------------------------------------------------------

    #[test]
    fn truncate_output_short_input_unchanged() {
        let input = "hello";
        let (result, truncated) = truncate_output(input);
        assert!(!truncated);
        assert_eq!(result, input);
    }

    #[test]
    fn truncate_output_exact_limit_unchanged() {
        let input = "y".repeat(MAX_OUTPUT_BYTES);
        let (result, truncated) = truncate_output(&input);
        assert!(!truncated);
        assert_eq!(result, input);
    }

    #[test]
    fn truncate_output_long_input_truncated() {
        // Twice MAX_OUTPUT_BYTES (100 KB) of ASCII.
        let input = "x".repeat(200 * 1024);
        let (result, truncated) = truncate_output(&input);
        assert!(truncated);
        let dropped = input.len() - HEAD_BYTES - TAIL_BYTES;
        assert!(result.contains(&format!("...[truncated {dropped} bytes]...")));
        assert!(result.len() < input.len());
    }

    #[test]
    fn truncate_output_utf8_boundary() {
        // Each '日' is 3 bytes and HEAD_BYTES/TAIL_BYTES are not multiples of 3, so both cut
        // points fall inside a character and must be moved onto a boundary.
        let count = (MAX_OUTPUT_BYTES / 3) + 10;
        let input: String = std::iter::repeat('日').take(count).collect();
        let (result, truncated) = truncate_output(&input);
        assert!(truncated, "input exceeds MAX_OUTPUT_BYTES");

        let (head, rest) = result.split_once('\n').expect("marker line present");
        let (marker, tail) = rest.split_once('\n').expect("tail present");
        assert!(marker.starts_with("...[truncated"));
        assert!(head.chars().all(|c| c == '日'), "head must hold whole characters");
        assert!(tail.chars().all(|c| c == '日'), "tail must hold whole characters");
        assert!(head.len() <= HEAD_BYTES);
        assert!(tail.len() <= TAIL_BYTES);
    }

    // -------------------------------------------------------------------------
    // validate_arg_value tests
    // -------------------------------------------------------------------------

    #[test]
    fn validate_arg_value_safe_values() {
        assert!(validate_arg_value("hello world").is_ok());
        assert!(validate_arg_value("value_123-abc").is_ok());
        assert!(validate_arg_value("path/to/file.txt").is_ok());
    }

    #[test]
    fn validate_arg_value_semicolon_rejected() {
        assert!(validate_arg_value("foo; rm -rf /").is_err());
    }

    #[test]
    fn validate_arg_value_pipe_rejected() {
        assert!(validate_arg_value("foo | cat /etc/passwd").is_err());
    }

    #[test]
    fn validate_arg_value_and_and_rejected() {
        assert!(validate_arg_value("foo && evil").is_err());
    }

    #[test]
    fn validate_arg_value_or_or_rejected() {
        assert!(validate_arg_value("foo || evil").is_err());
    }

    #[test]
    fn validate_arg_value_backtick_rejected() {
        assert!(validate_arg_value("foo`id`").is_err());
    }

    #[test]
    fn validate_arg_value_dollar_paren_rejected() {
        assert!(validate_arg_value("$(id)").is_err());
    }

    #[test]
    fn validate_arg_value_dollar_brace_rejected() {
        assert!(validate_arg_value("${HOME}").is_err());
    }

    #[test]
    fn validate_arg_value_newline_rejected() {
        assert!(validate_arg_value("foo\nbar").is_err());
        assert!(validate_arg_value("foo\rbar").is_err());
    }

    // -------------------------------------------------------------------------
    // TaskLogStore tests
    // -------------------------------------------------------------------------

    /// Build a successful, empty execution record with the given id and task name.
    fn make_execution(id: &str, task_name: &str) -> TaskExecution {
        TaskExecution {
            id: id.to_string(),
            task_name: task_name.to_string(),
            args: HashMap::new(),
            exit_code: Some(0),
            stdout: String::new(),
            stderr: String::new(),
            started_at: 0,
            duration_ms: 0,
            truncated: false,
        }
    }

    #[test]
    fn task_log_store_push_and_get() {
        let store = TaskLogStore::new(10);
        store.push(make_execution("id-1", "build"));
        let retrieved = store.get("id-1").expect("should find id-1");
        assert_eq!(retrieved.task_name, "build");
    }

    #[test]
    fn task_log_store_get_missing() {
        let store = TaskLogStore::new(10);
        assert!(store.get("nonexistent").is_none());
    }

    #[test]
    fn task_log_store_evicts_oldest_when_full() {
        let store = TaskLogStore::new(3);
        store.push(make_execution("id-1", "a"));
        store.push(make_execution("id-2", "b"));
        store.push(make_execution("id-3", "c"));
        store.push(make_execution("id-4", "d")); // evicts id-1
        assert!(store.get("id-1").is_none(), "id-1 should be evicted");
        assert!(store.get("id-4").is_some(), "id-4 should exist");
    }

    #[test]
    fn task_log_store_recent_newest_first() {
        let store = TaskLogStore::new(10);
        store.push(make_execution("id-1", "a"));
        store.push(make_execution("id-2", "b"));
        store.push(make_execution("id-3", "c"));
        let recent = store.recent(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].id, "id-3", "newest should be first");
        assert_eq!(recent[1].id, "id-2");
    }

    #[test]
    fn task_log_store_recent_n_larger_than_store() {
        let store = TaskLogStore::new(10);
        store.push(make_execution("id-1", "a"));
        let recent = store.recent(5);
        assert_eq!(recent.len(), 1);
    }

    #[test]
    fn task_log_store_recent_zero() {
        let store = TaskLogStore::new(10);
        store.push(make_execution("id-1", "a"));
        assert!(store.recent(0).is_empty());
    }
}