1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use tokio::sync::RwLock;
11use zagens_protocol::{ToolKind, ToolOutput, ToolPayload};
12
13mod dag_scheduler;
14mod policy_engine;
15mod resource_locks;
16mod tool_manifest;
17
18pub use dag_scheduler::{
19 DagPlanView, ScheduleResource, SchedulerShadowStats, build_execution_waves,
20 record_scheduler_shadow_diff, scheduler_shadow_stats, wave_parallel_eligible,
21};
22pub use policy_engine::{
23 ApprovalNeed, ParallelResourceKey, PolicyDecision, PolicyEngine, PolicyInput, PolicyPlanMeta,
24 PolicySessionMode, PolicyShadowStats, SandboxClass, policy_shadow_stats,
25 record_policy_shadow_diff,
26};
27pub use resource_locks::{ResourceLockMode, resource_lock_order, resource_lock_targets};
28pub use tool_manifest::{
29 Footprint, FootprintProvenance, ResourceSet, SpawnClass, ToolManifest,
30 derive_conservative_footprint,
31};
32
/// Capability flags a tool can advertise.
///
/// Only the set of flags is defined here; how each flag is consumed is decided
/// by code outside this definition.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ToolCapability {
    /// Flag: read-only.
    ReadOnly,
    /// Flag: writes files.
    WritesFiles,
    /// Flag: executes code.
    ExecutesCode,
    /// Flag: uses the network.
    Network,
    /// Flag: can run sandboxed.
    Sandboxable,
    /// Flag: requires approval.
    RequiresApproval,
}
49
/// How strongly a tool run needs user approval.
///
/// `Auto` is the [`Default`] variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ApprovalRequirement {
    /// Level: automatic (the default).
    #[default]
    Auto,
    /// Level: suggest.
    Suggest,
    /// Level: required.
    Required,
}
61
/// Error produced while validating a tool's input or executing the tool.
///
/// Every variant renders through `Display` as a sentence beginning with
/// "Failed to ...", so messages surfaced to callers share one shape.
#[derive(Debug, Clone)]
pub enum ToolError {
    /// Input was supplied but is malformed; `message` explains how.
    InvalidInput { message: String },
    /// A required input field is absent.
    MissingField { field: String },
    /// A path argument escapes the workspace.
    PathEscape { path: PathBuf },
    /// The tool ran and failed.
    ExecutionFailed { message: String },
    /// The operation timed out after `seconds`.
    Timeout { seconds: u64 },
    /// The tool could not be located.
    NotAvailable { message: String },
    /// Tool execution was not authorized.
    PermissionDenied { message: String },
}

impl std::fmt::Display for ToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidInput { message } => {
                write!(f, "Failed to validate input: {message}")
            }
            Self::MissingField { field } => {
                write!(
                    f,
                    "Failed to validate input: missing required field '{field}'"
                )
            }
            Self::PathEscape { path } => {
                write!(
                    f,
                    "Failed to resolve path '{}': path escapes workspace",
                    path.display()
                )
            }
            Self::ExecutionFailed { message } => {
                write!(f, "Failed to execute tool: {message}")
            }
            Self::Timeout { seconds } => {
                write!(
                    f,
                    "Failed to execute tool: operation timed out after {seconds}s"
                )
            }
            Self::NotAvailable { message } => {
                write!(f, "Failed to locate tool: {message}")
            }
            Self::PermissionDenied { message } => {
                write!(f, "Failed to authorize tool execution: {message}")
            }
        }
    }
}

impl std::error::Error for ToolError {}

impl ToolError {
    /// Builds a [`ToolError::InvalidInput`] with the given explanation.
    #[must_use]
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput {
            message: msg.into(),
        }
    }

    /// Builds a [`ToolError::MissingField`] naming the absent field.
    #[must_use]
    pub fn missing_field(field: impl Into<String>) -> Self {
        Self::MissingField {
            field: field.into(),
        }
    }

    /// Builds a [`ToolError::ExecutionFailed`] with the given failure message.
    #[must_use]
    pub fn execution_failed(msg: impl Into<String>) -> Self {
        Self::ExecutionFailed {
            message: msg.into(),
        }
    }

    /// Builds a [`ToolError::PathEscape`] for the offending path.
    #[must_use]
    pub fn path_escape(path: impl Into<PathBuf>) -> Self {
        Self::PathEscape { path: path.into() }
    }

    /// Builds a [`ToolError::NotAvailable`] with the given message.
    #[must_use]
    pub fn not_available(msg: impl Into<String>) -> Self {
        Self::NotAvailable {
            message: msg.into(),
        }
    }

    /// Builds a [`ToolError::PermissionDenied`] with the given message.
    #[must_use]
    pub fn permission_denied(msg: impl Into<String>) -> Self {
        Self::PermissionDenied {
            message: msg.into(),
        }
    }

    /// Builds a [`ToolError::Timeout`] for an operation that ran past `seconds`.
    ///
    /// Previously `Timeout` was the only variant without a constructor.
    #[must_use]
    pub fn timeout(seconds: u64) -> Self {
        Self::Timeout { seconds }
    }
}
155
/// Result of a tool run as reported back to the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Output text; for results built with `ToolResult::error`, the error message.
    pub content: String,
    /// `true` for successful results, `false` for errors.
    pub success: bool,
    /// Optional structured data attached via `with_metadata`.
    // Omitted from serialized output entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
}
167
168impl ToolResult {
169 #[must_use]
171 pub fn success(content: impl Into<String>) -> Self {
172 Self {
173 content: content.into(),
174 success: true,
175 metadata: None,
176 }
177 }
178
179 #[must_use]
181 pub fn error(message: impl Into<String>) -> Self {
182 Self {
183 content: message.into(),
184 success: false,
185 metadata: None,
186 }
187 }
188
189 pub fn json<T: Serialize>(value: &T) -> std::result::Result<Self, serde_json::Error> {
191 Ok(Self {
192 content: serde_json::to_string_pretty(value)?,
193 success: true,
194 metadata: None,
195 })
196 }
197
198 #[must_use]
200 pub fn with_metadata(mut self, metadata: Value) -> Self {
201 self.metadata = Some(metadata);
202 self
203 }
204}
205
206pub fn required_str<'a>(input: &'a Value, field: &str) -> std::result::Result<&'a str, ToolError> {
208 input.get(field).and_then(Value::as_str).ok_or_else(|| {
209 let provided: Vec<&str> = input
212 .as_object()
213 .map(|obj| obj.keys().map(|k| k.as_str()).collect())
214 .unwrap_or_default();
215 if provided.is_empty() {
216 ToolError::missing_field(field)
217 } else {
218 let hint = format!(
219 "missing required field '{field}'. Input provided: {}",
220 provided.join(", ")
221 );
222 ToolError::invalid_input(hint)
223 }
224 })
225}
226
227#[must_use]
229pub fn optional_str<'a>(input: &'a Value, field: &str) -> Option<&'a str> {
230 input.get(field).and_then(Value::as_str)
231}
232
233pub fn required_u64(input: &Value, field: &str) -> std::result::Result<u64, ToolError> {
235 input
236 .get(field)
237 .and_then(Value::as_u64)
238 .ok_or_else(|| ToolError::missing_field(field))
239}
240
241#[must_use]
243pub fn optional_u64(input: &Value, field: &str, default: u64) -> u64 {
244 input.get(field).and_then(Value::as_u64).unwrap_or(default)
245}
246
247#[must_use]
249pub fn optional_bool(input: &Value, field: &str, default: bool) -> bool {
250 input.get(field).and_then(Value::as_bool).unwrap_or(default)
251}
252
/// Static description of a tool, as passed to `ToolRegistry::register`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpec {
    /// Name the tool is registered and dispatched under.
    pub name: String,
    /// Schema of the tool's input (presumably a JSON Schema document — TODO confirm).
    pub input_schema: Value,
    /// Schema of the tool's output (presumably a JSON Schema document — TODO confirm).
    pub output_schema: Value,
    /// Whether the tool may run concurrently with other parallel-capable tools.
    /// Copied into `ConfiguredToolSpec` at registration.
    pub supports_parallel_tool_calls: bool,
    /// When set, dispatch abandons the handler after this many milliseconds and
    /// reports `FunctionCallError::TimedOut`.
    pub timeout_ms: Option<u64>,
}
261
/// A registered tool spec plus the concurrency setting dispatch acts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfiguredToolSpec {
    /// The spec as registered.
    pub spec: ToolSpec,
    /// Selects the shared (read) vs exclusive (write) execution lock in
    /// `ToolRegistry::dispatch`; initialized from `spec.supports_parallel_tool_calls`.
    pub supports_parallel_tool_calls: bool,
}
267
/// Where a tool call originated.
///
/// Serialized in snake_case: `"direct"` and `"js_repl"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallSource {
    /// Presumably a call issued directly by the model — TODO confirm.
    Direct,
    /// Presumably a call issued from a JS REPL session — TODO confirm.
    JsRepl,
}
274
/// A request to run a registered tool, as received by `ToolRegistry::dispatch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Registered tool name to dispatch to.
    pub name: String,
    /// Call arguments; the payload's kind must match the handler's kind.
    pub payload: ToolPayload,
    /// Where the call originated.
    pub source: ToolCallSource,
    /// Caller-supplied call id. When `None`, dispatch generates `tool-call-<uuid v4>`.
    pub raw_tool_call_id: Option<String>,
}
282
283impl ToolCall {
284 pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
285 match &self.payload {
286 ToolPayload::LocalShell { params } => (
287 params.command.clone(),
288 params
289 .cwd
290 .clone()
291 .unwrap_or_else(|| fallback_cwd.to_string()),
292 "shell",
293 ),
294 _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
295 }
296 }
297}
298
/// A resolved tool call handed to a `ToolHandler`.
#[derive(Debug, Clone)]
pub struct ToolInvocation {
    /// The caller's `raw_tool_call_id`, or a generated `tool-call-<uuid v4>` id.
    pub call_id: String,
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Call arguments, moved from the originating `ToolCall`.
    pub payload: ToolPayload,
    /// Where the call originated.
    pub source: ToolCallSource,
}
306
/// Error returned when dispatching or running a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FunctionCallError {
    /// No handler or spec is registered under `name`.
    ToolNotFound { name: String },
    /// The handler does not accept the payload's kind.
    KindMismatch { expected: ToolKind, got: ToolKind },
    /// The handler is mutating and the dispatch did not allow mutation.
    MutatingToolRejected { name: String },
    /// The spec's `timeout_ms` elapsed before the handler finished.
    TimedOut { name: String, timeout_ms: u64 },
    // NOTE(review): the two variants below are not constructed in this file;
    // presumably handlers return them — confirm.
    /// The call was cancelled.
    Cancelled { name: String },
    /// The tool ran and failed with `error`.
    ExecutionFailed { name: String, error: String },
}
316
317#[async_trait]
318pub trait ToolHandler: Send + Sync {
319 fn kind(&self) -> ToolKind;
320 fn matches_kind(&self, kind: ToolKind) -> bool {
321 self.kind() == kind
322 }
323 fn is_mutating(&self) -> bool {
324 false
325 }
326 async fn handle(
327 &self,
328 invocation: ToolInvocation,
329 ) -> std::result::Result<ToolOutput, FunctionCallError>;
330}
331
332#[derive(Debug, Default)]
333pub struct ToolCallRuntime {
334 pub parallel_execution: Arc<RwLock<()>>,
335}
336
337#[derive(Default)]
338pub struct ToolRegistry {
339 handlers: HashMap<String, Arc<dyn ToolHandler>>,
340 specs: HashMap<String, ConfiguredToolSpec>,
341 runtime: ToolCallRuntime,
342}
343
344impl ToolRegistry {
345 pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
346 let name = spec.name.clone();
347 self.specs.insert(
348 name.clone(),
349 ConfiguredToolSpec {
350 supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
351 spec,
352 },
353 );
354 self.handlers.insert(name, handler);
355 Ok(())
356 }
357
358 pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
359 self.specs.values().cloned().collect()
360 }
361
362 pub async fn dispatch(
363 &self,
364 call: ToolCall,
365 allow_mutating: bool,
366 ) -> std::result::Result<ToolOutput, FunctionCallError> {
367 let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
368 FunctionCallError::ToolNotFound {
369 name: call.name.clone(),
370 }
371 })?;
372 let configured =
373 self.specs
374 .get(&call.name)
375 .cloned()
376 .ok_or_else(|| FunctionCallError::ToolNotFound {
377 name: call.name.clone(),
378 })?;
379
380 let payload_kind = tool_payload_kind(&call.payload);
381 let expected = handler.kind();
382 if !handler.matches_kind(payload_kind) {
383 return Err(FunctionCallError::KindMismatch {
384 expected,
385 got: payload_kind,
386 });
387 }
388 if handler.is_mutating() && !allow_mutating {
389 return Err(FunctionCallError::MutatingToolRejected { name: call.name });
390 }
391
392 let invocation = ToolInvocation {
393 call_id: call
394 .raw_tool_call_id
395 .clone()
396 .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
397 tool_name: call.name.clone(),
398 payload: call.payload,
399 source: call.source,
400 };
401
402 if configured.supports_parallel_tool_calls {
403 let _guard = self.runtime.parallel_execution.read().await;
404 self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
405 .await
406 } else {
407 let _guard = self.runtime.parallel_execution.write().await;
408 self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
409 .await
410 }
411 }
412
413 async fn execute_with_timeout(
414 &self,
415 handler: Arc<dyn ToolHandler>,
416 timeout_ms: Option<u64>,
417 invocation: ToolInvocation,
418 ) -> std::result::Result<ToolOutput, FunctionCallError> {
419 if let Some(timeout_ms) = timeout_ms {
420 let name = invocation.tool_name.clone();
421 match tokio::time::timeout(
422 Duration::from_millis(timeout_ms),
423 handler.handle(invocation),
424 )
425 .await
426 {
427 Ok(result) => result,
428 Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
429 }
430 } else {
431 handler.handle(invocation).await
432 }
433 }
434}
435
436fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
437 match payload {
438 ToolPayload::Mcp { .. } => ToolKind::Mcp,
439 ToolPayload::Function { .. }
440 | ToolPayload::Custom { .. }
441 | ToolPayload::LocalShell { .. } => ToolKind::Function,
442 }
443}
444
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn tool_result_json_round_trips_content() {
        let result = ToolResult::json(&json!({"ok": true})).expect("json");
        assert!(result.success);
        assert!(result.content.contains("\"ok\": true"));
    }

    #[test]
    fn helper_extractors_validate_shape() {
        let input = json!({"name": "demo", "count": 7, "enabled": true});
        assert_eq!(required_str(&input, "name").expect("name"), "demo");
        assert_eq!(optional_u64(&input, "count", 0), 7);
        assert!(optional_bool(&input, "enabled", false));
        assert!(matches!(
            required_u64(&input, "name"),
            Err(ToolError::MissingField { .. })
        ));
    }

    #[test]
    fn optional_extractors_fall_back_to_defaults() {
        let input = json!({"count": "seven", "enabled": "yes", "label": 3});
        assert_eq!(optional_str(&input, "missing"), None);
        assert_eq!(optional_str(&input, "label"), None);
        assert_eq!(optional_u64(&input, "count", 3), 3);
        assert_eq!(optional_u64(&input, "absent", 9), 9);
        assert!(!optional_bool(&input, "enabled", false));
        assert!(optional_bool(&input, "absent", true));
        assert!(matches!(
            required_u64(&input, "absent"),
            Err(ToolError::MissingField { .. })
        ));
    }

    #[test]
    fn required_str_without_other_fields_reports_missing_field() {
        for input in [json!({}), json!(null), json!("scalar"), json!([1, 2])] {
            assert!(matches!(
                required_str(&input, "path"),
                Err(ToolError::MissingField { .. })
            ));
        }
    }

    #[test]
    fn required_str_reports_provided_fields_on_missing_required_field() {
        let input = json!({"path": "src/lib.rs", "content": "new body"});
        let err = required_str(&input, "replace").expect_err("replace is missing");
        let message = err.to_string();
        assert!(message.contains("missing required field 'replace'"));
        assert!(message.contains("Input provided:"));
        assert!(message.contains("path"));
        assert!(message.contains("content"));
    }

    #[test]
    fn tool_result_builders_set_status_and_metadata() {
        let ok = ToolResult::success("done").with_metadata(json!({"k": 1}));
        assert!(ok.success);
        assert_eq!(ok.content, "done");
        assert_eq!(ok.metadata, Some(json!({"k": 1})));

        let err = ToolResult::error("boom");
        assert!(!err.success);
        assert_eq!(err.content, "boom");
        assert!(err.metadata.is_none());

        let serialized = serde_json::to_value(ToolResult::success("x")).expect("serialize");
        assert!(serialized.get("metadata").is_none());
    }

    #[test]
    fn tool_error_display_matches_legacy_text() {
        let err = ToolError::missing_field("path");
        assert_eq!(
            err.to_string(),
            "Failed to validate input: missing required field 'path'"
        );
    }
}