Skip to main content

zagens_core/engine/
dispatch.rs

//! Tool-input parsing and batch-policy helpers (P2 PR4).
//!
//! Types that are tied to the live engine batch driver (`ToolExecutionPlan`,
//! lock guards) remain in the `deepseek-runtime` engine dispatch glue.
5
6use serde_json::{Value, json};
7
8use zagens_tools::{ToolError, ToolResult};
9
10use super::streaming::ToolUseState;
11use crate::chat::{Tool, ToolCaller};
12
/// Parallel-batch eligibility flags for a single planned tool invocation.
///
/// `should_parallelize_tool_batch` approves a batch only when every member is
/// read-only, parallel-capable, approval-free, and non-interactive.
#[derive(Debug, Clone, Copy)]
pub struct ToolParallelPlanFlags {
    /// The tool is classified as read-only.
    pub read_only: bool,
    /// The tool declares that it may run alongside other tools.
    pub supports_parallel: bool,
    /// Running the tool requires an approval step.
    pub approval_required: bool,
    /// The tool is classified as interactive.
    pub interactive: bool,
}
21
22/// Promote a streaming `ToolUseState` to a finalized JSON input.
23#[must_use]
24pub fn final_tool_input(state: &ToolUseState) -> Value {
25    if !state.input_buffer.trim().is_empty()
26        && let Some(parsed) = parse_tool_input_json(&state.input_buffer)
27    {
28        return parsed;
29    }
30    state.input.clone()
31}
32
33/// Parse streamed tool arguments (fences, segments, double-encoded JSON).
34///
35/// Callers that need TUI `arg_repair` should try that first, then fall back to
36/// this function.
37#[must_use]
38pub fn parse_tool_input_json(buffer: &str) -> Option<Value> {
39    let trimmed = buffer.trim();
40    if trimmed.is_empty() {
41        return None;
42    }
43    if let Some(stripped) = strip_code_fences(trimmed)
44        && let Ok(value) = serde_json::from_str::<Value>(&stripped)
45    {
46        return Some(value);
47    }
48    if let Ok(Value::String(inner)) = serde_json::from_str::<Value>(trimmed)
49        && let Ok(value) = serde_json::from_str::<Value>(&inner)
50    {
51        return Some(value);
52    }
53    extract_json_segment(trimmed).and_then(|segment| serde_json::from_str::<Value>(&segment).ok())
54}
55
56#[must_use]
57pub fn caller_type_for_tool_use(caller: Option<&ToolCaller>) -> &str {
58    caller.map_or("direct", |c| c.caller_type.as_str())
59}
60
61#[must_use]
62pub fn caller_allowed_for_tool(caller: Option<&ToolCaller>, tool_def: Option<&Tool>) -> bool {
63    let requested = caller_type_for_tool_use(caller);
64    if let Some(def) = tool_def
65        && let Some(allowed) = &def.allowed_callers
66    {
67        if allowed.is_empty() {
68            return requested == "direct";
69        }
70        return allowed.iter().any(|item| item == requested);
71    }
72    requested == "direct"
73}
74
75#[must_use]
76pub fn format_tool_error(err: &ToolError, tool_name: &str) -> String {
77    match err {
78        ToolError::InvalidInput { message } => {
79            format!("Invalid input for tool '{tool_name}': {message}")
80        }
81        ToolError::MissingField { field } => {
82            format!("Tool '{tool_name}' is missing required field '{field}'")
83        }
84        ToolError::PathEscape { path } => format!(
85            "Path escapes workspace: {}. Use a workspace-relative path or enable trust mode.",
86            path.display()
87        ),
88        ToolError::ExecutionFailed { message } => message.clone(),
89        ToolError::Timeout { seconds } => format!(
90            "Tool '{tool_name}' timed out after {seconds}s. Try a narrower scope or a longer timeout."
91        ),
92        ToolError::NotAvailable { message } => {
93            let lower = message.to_ascii_lowercase();
94            if lower.contains("current tool catalog") || lower.contains("did you mean:") {
95                message.clone()
96            } else {
97                format!(
98                    "Tool '{tool_name}' is not available: {message}. Check mode, feature flags, or tool name."
99                )
100            }
101        }
102        ToolError::PermissionDenied { message } => format!(
103            "Tool '{tool_name}' was denied: {message}. Adjust approval mode or request permission."
104        ),
105    }
106}
107
108pub fn parse_parallel_tool_calls(input: &Value) -> Result<Vec<(String, Value)>, ToolError> {
109    let tool_uses = input
110        .get("tool_uses")
111        .and_then(|v| v.as_array())
112        .ok_or_else(|| ToolError::missing_field("tool_uses"))?;
113    if tool_uses.is_empty() {
114        return Err(ToolError::invalid_input(
115            "multi_tool_use.parallel requires at least one tool call",
116        ));
117    }
118
119    let mut calls = Vec::with_capacity(tool_uses.len());
120    for item in tool_uses {
121        let name = item
122            .get("recipient_name")
123            .or_else(|| item.get("tool_name"))
124            .or_else(|| item.get("name"))
125            .or_else(|| item.get("tool"))
126            .and_then(|v| v.as_str())
127            .ok_or_else(|| ToolError::missing_field("recipient_name"))?;
128        let params = item
129            .get("parameters")
130            .or_else(|| item.get("input"))
131            .or_else(|| item.get("args"))
132            .or_else(|| item.get("arguments"))
133            .cloned()
134            .unwrap_or_else(|| json!({}));
135        calls.push((normalize_parallel_tool_name(name), params));
136    }
137
138    Ok(calls)
139}
140
141#[must_use]
142pub fn should_parallelize_tool_batch(plans: &[ToolParallelPlanFlags]) -> bool {
143    !plans.is_empty()
144        && plans.iter().all(|plan| {
145            plan.read_only && plan.supports_parallel && !plan.approval_required && !plan.interactive
146        })
147}
148
149#[must_use]
150pub fn should_stop_after_plan_tool(
151    is_plan_mode: bool,
152    tool_name: &str,
153    result: &Result<ToolResult, ToolError>,
154) -> bool {
155    is_plan_mode && tool_name == "update_plan" && result.is_ok()
156}
157
/// Heuristic deciding whether `update_plan` should be forced first for a
/// plan-mode request.
///
/// Returns `true` only when `is_plan_mode` is set, the message (compared
/// ASCII-case-insensitively by substring) asks for a quick or direct plan
/// ("quick plan", "3-step plan", "make a plan", …), and it does not also ask
/// for repository exploration ("inspect the repo", "investigate",
/// "based on the codebase", …).
#[must_use]
pub fn should_force_update_plan_first(is_plan_mode: bool, content: &str) -> bool {
    const DIRECT_PLAN_CUES: &[&str] = &[
        "quick plan",
        "short plan",
        "simple plan",
        "3-step plan",
        "3 step plan",
        "three-step plan",
        "three step plan",
        "high-level plan",
        "high level plan",
        "give me a plan",
        "make a plan",
        "outline a plan",
        "draft a plan",
    ];
    const EXPLORATION_CUES: &[&str] = &[
        "inspect the repo",
        "inspect the code",
        "explore the repo",
        "search the repo",
        "read the code",
        "review the code",
        "analyze the code",
        "investigate",
        "look through",
        "understand the current",
        "ground it in the codebase",
        "based on the codebase",
    ];

    if !is_plan_mode {
        return false;
    }
    let haystack = content.to_ascii_lowercase();
    let mentions_any = |cues: &[&str]| cues.iter().any(|cue| haystack.contains(cue));
    mentions_any(DIRECT_PLAN_CUES) && !mentions_any(EXPLORATION_CUES)
}
206
/// Whether `name` is dispatched through the MCP pool (vs. the native
/// `ToolRegistry`). Mirrors the body of
/// `tui::mcp::McpPool::is_mcp_tool` so the core turn loop and
/// [`McpHost`](crate::engine::hosts::McpHost) implementations can
/// answer the same question without depending on the tui crate.
///
/// Matches every `mcp_`-prefixed name plus the three resource helpers
/// `list_mcp_resources`, `list_mcp_resource_templates`, and
/// `read_mcp_resource`.
///
/// **Drift guard**: the
/// `is_mcp_tool_name_matches_tui_mcp_pool` test in `tui::mcp` asserts
/// this function and `McpPool::is_mcp_tool` produce identical output
/// on a curated name set.
#[must_use]
pub fn is_mcp_tool_name(name: &str) -> bool {
    const RESOURCE_HELPERS: [&str; 3] = [
        "list_mcp_resources",
        "list_mcp_resource_templates",
        "read_mcp_resource",
    ];
    RESOURCE_HELPERS.contains(&name) || name.starts_with("mcp_")
}
225
/// Whether MCP tool `name` is on the fixed allow-list of tools considered safe
/// to run inside a parallel batch; any other name returns `false`.
#[must_use]
pub fn mcp_tool_is_parallel_safe(name: &str) -> bool {
    const PARALLEL_SAFE: [&str; 5] = [
        "list_mcp_resources",
        "list_mcp_resource_templates",
        "mcp_read_resource",
        "read_mcp_resource",
        "mcp_get_prompt",
    ];
    PARALLEL_SAFE.iter().any(|&candidate| candidate == name)
}
237
/// Whether MCP tool `name` is one of the known read-only MCP tools
/// (resource listing/reading and prompt retrieval). Unknown names return
/// `false`.
#[must_use]
pub fn mcp_tool_is_read_only(name: &str) -> bool {
    const READ_ONLY: &[&str] = &[
        "list_mcp_resources",
        "list_mcp_resource_templates",
        "mcp_read_resource",
        "read_mcp_resource",
        "mcp_get_prompt",
    ];
    READ_ONLY.contains(&name)
}
249
250#[must_use]
251pub fn mcp_tool_approval_description(name: &str) -> String {
252    if mcp_tool_is_read_only(name) {
253        format!("Read-only MCP tool '{name}'")
254    } else {
255        format!("MCP tool '{name}' may have side effects")
256    }
257}
258
/// Remove Markdown fence lines from `text`.
///
/// Returns `None` when `text` contains no triple-backtick sequence at all, or
/// when nothing but whitespace remains after the fence lines are removed. A
/// fence line is any line whose first non-whitespace characters are three
/// backticks (e.g. an opening fence with a `json` tag).
fn strip_code_fences(text: &str) -> Option<String> {
    if !text.contains("```") {
        return None;
    }
    let body = text
        .lines()
        .filter(|line| !line.trim_start().starts_with("```"))
        .collect::<Vec<_>>()
        .join("\n");
    let body = body.trim();
    (!body.is_empty()).then(|| body.to_string())
}
278
279fn extract_json_segment(text: &str) -> Option<String> {
280    extract_balanced_segment(text, '{', '}').or_else(|| extract_balanced_segment(text, '[', ']'))
281}
282
/// Return the balanced `open`…`close` segment of `text` that begins at the
/// first occurrence of `open`.
///
/// Delimiters inside JSON string literals (double-quoted, with backslash
/// escapes) are ignored. Values such as `"}"` or a regex like `"\\{"` therefore
/// neither end the segment early nor leave it unbalanced. Returns `None` when
/// `open` does not occur or the segment never closes.
fn extract_balanced_segment(text: &str, open: char, close: char) -> Option<String> {
    let start = text.find(open)?;
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;
    for (offset, ch) in text[start..].char_indices() {
        if in_string {
            // Inside a string literal only escapes and the closing quote matter.
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            // The scan starts on `open`, so depth >= 1 whenever `close` is seen.
            depth -= 1;
            if depth == 0 {
                let end = start + offset + ch.len_utf8();
                return Some(text[start..end].to_string());
            }
        }
    }
    None
}
300
/// Trim surrounding whitespace from a tool name emitted inside a parallel call
/// and strip at most one namespace prefix. Prefixes are checked in the order
/// `functions.`, `tools.`, `tool.`, and the first match wins.
fn normalize_parallel_tool_name(raw: &str) -> String {
    let trimmed = raw.trim();
    ["functions.", "tools.", "tool."]
        .iter()
        .find_map(|prefix| trimmed.strip_prefix(prefix))
        .unwrap_or(trimmed)
        .to_string()
}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags for a plan that satisfies every parallel-batch requirement.
    fn parallel_ready() -> ToolParallelPlanFlags {
        ToolParallelPlanFlags {
            read_only: true,
            supports_parallel: true,
            approval_required: false,
            interactive: false,
        }
    }

    #[test]
    fn parallel_batch_requires_read_only_parallel_tools() {
        let ok = parallel_ready();
        assert!(should_parallelize_tool_batch(&[ok, ok]));

        let writer = ToolParallelPlanFlags {
            read_only: false,
            ..ok
        };
        assert!(!should_parallelize_tool_batch(&[writer]));
    }

    #[test]
    fn plan_mode_stops_after_update_plan() {
        let success: Result<ToolResult, ToolError> = Ok(ToolResult::success("ok"));
        assert!(should_stop_after_plan_tool(true, "update_plan", &success));
        assert!(!should_stop_after_plan_tool(false, "update_plan", &success));
    }

    #[test]
    fn is_mcp_tool_name_covers_prefix_and_resource_helpers() {
        for name in [
            "mcp_filesystem_read",
            "mcp_git_status",
            "list_mcp_resources",
            "list_mcp_resource_templates",
            "read_mcp_resource",
        ] {
            assert!(is_mcp_tool_name(name), "{name} should be an MCP tool");
        }
        for name in ["read_file", "exec_shell", ""] {
            assert!(!is_mcp_tool_name(name), "{name:?} should not be an MCP tool");
        }
    }

    #[test]
    fn quick_plan_forces_update_plan_first() {
        let direct = "Give me a quick 3-step plan.";
        let exploratory = "Inspect the repo and give me a quick plan.";
        assert!(should_force_update_plan_first(true, direct));
        assert!(!should_force_update_plan_first(true, exploratory));
    }
}
368}