// zagens_core/engine/approval.rs

//! Tool-approval and user-input handshake for the agent loop (P2 PR4 → `zagens-core`).
//!
//! Front ends (TUI/Desktop) supply the policy type `P` (e.g. `SandboxPolicy`) and the user
//! response type `R` (e.g. `UserInputResponse`). Emitting the `request_user_input` event stays
//! in the L2 shell; this module only waits for the reply that matches a given tool id.
5
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8use zagens_tools::ToolError;
9
/// A user's answer to a tool-approval prompt, tagged with the `id` of the tool call it answers.
///
/// `P` is the policy type supplied by the front end (e.g. a sandbox policy).
#[derive(Clone, Debug)]
pub enum ApprovalDecision<P> {
    /// The user approved the tool call identified by `id`.
    Approved {
        id: String,
        /// Fingerprint for session-scoped approval cache (runtime-server).
        cache_key: Option<String>,
        /// When true, identical tool calls skip future prompts for this engine session.
        remember_for_session: bool,
    },
    /// The user denied the tool call identified by `id`.
    Denied { id: String },
    /// Retry a tool with an elevated sandbox policy.
    RetryWithPolicy { id: String, policy: P },
}
28
/// A user's answer to an input prompt, tagged with the `id` of the tool call it answers.
///
/// `R` is the response type supplied by the front end.
#[derive(Clone, Debug)]
pub enum UserInputDecision<R> {
    /// The user submitted `response` for the prompt identified by `id`.
    Submitted {
        id: String,
        response: R,
    },
    /// The prompt identified by `id` was cancelled; no response is available.
    Cancelled {
        id: String,
    },
}
34
35/// Result of awaiting tool approval from the user.
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub enum ApprovalResult<P> {
38    Approved {
39        cache_key: Option<String>,
40        remember_for_session: bool,
41    },
42    Denied,
43    RetryWithPolicy(P),
44}
45
46/// Block until the user approves, denies, or retries with a new policy for `tool_id`.
47pub async fn await_tool_approval<P>(
48    tool_id: &str,
49    cancel_token: &CancellationToken,
50    rx_approval: &mut mpsc::Receiver<ApprovalDecision<P>>,
51) -> Result<ApprovalResult<P>, ToolError>
52where
53    P: Clone,
54{
55    loop {
56        tokio::select! {
57            _ = cancel_token.cancelled() => {
58                return Err(ToolError::execution_failed(
59                    "Request cancelled while awaiting approval".to_string(),
60                ));
61            }
62            decision = rx_approval.recv() => {
63                let Some(decision) = decision else {
64                    return Err(ToolError::execution_failed(
65                        "Approval channel closed".to_string(),
66                    ));
67                };
68                match decision {
69                    ApprovalDecision::Approved {
70                        id,
71                        cache_key,
72                        remember_for_session,
73                    } if id == tool_id => {
74                        return Ok(ApprovalResult::Approved {
75                            cache_key,
76                            remember_for_session,
77                        });
78                    }
79                    ApprovalDecision::Denied { id } if id == tool_id => {
80                        return Ok(ApprovalResult::Denied);
81                    }
82                    ApprovalDecision::RetryWithPolicy { id, policy } if id == tool_id => {
83                        return Ok(ApprovalResult::RetryWithPolicy(policy));
84                    }
85                    _ => continue,
86                }
87            }
88        }
89    }
90}
91
92/// Block until the user submits or cancels input for `tool_id` (after the shell emits the prompt).
93pub async fn recv_user_input_for_tool<R>(
94    tool_id: &str,
95    cancel_token: &CancellationToken,
96    rx_user_input: &mut mpsc::Receiver<UserInputDecision<R>>,
97) -> Result<R, ToolError> {
98    loop {
99        tokio::select! {
100            _ = cancel_token.cancelled() => {
101                return Err(ToolError::execution_failed(
102                    "Request cancelled while awaiting user input".to_string(),
103                ));
104            }
105            decision = rx_user_input.recv() => {
106                let Some(decision) = decision else {
107                    return Err(ToolError::execution_failed(
108                        "User input channel closed".to_string(),
109                    ));
110                };
111                match decision {
112                    UserInputDecision::Submitted { id, response } if id == tool_id => {
113                        return Ok(response);
114                    }
115                    UserInputDecision::Cancelled { id } if id == tool_id => {
116                        return Err(ToolError::execution_failed(
117                            "User input cancelled".to_string(),
118                        ));
119                    }
120                    _ => continue,
121                }
122            }
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Minimal stand-in for a front-end policy type; `PartialEq` lets tests compare results.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestPolicy(u8);

    #[tokio::test]
    async fn await_tool_approval_matches_id() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        let tool_id = "tool-1";
        let task = tokio::spawn({
            let cancel = cancel.clone();
            async move { await_tool_approval::<TestPolicy>(tool_id, &cancel, &mut rx).await }
        });
        // A decision addressed to another tool must be skipped, not returned.
        tx.send(ApprovalDecision::Denied { id: "other".into() })
            .await
            .unwrap();
        tx.send(ApprovalDecision::Approved {
            id: tool_id.into(),
            cache_key: None,
            remember_for_session: false,
        })
        .await
        .unwrap();
        assert_eq!(
            task.await.unwrap().unwrap(),
            ApprovalResult::Approved {
                cache_key: None,
                remember_for_session: false,
            }
        );
    }

    #[tokio::test]
    async fn await_tool_approval_returns_retry_policy() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(ApprovalDecision::RetryWithPolicy {
            id: "tool-2".into(),
            policy: TestPolicy(2),
        })
        .await
        .unwrap();
        let result = await_tool_approval("tool-2", &cancel, &mut rx).await;
        assert_eq!(
            result.unwrap(),
            ApprovalResult::RetryWithPolicy(TestPolicy(2))
        );
    }

    #[tokio::test]
    async fn await_tool_approval_errors_on_cancel() {
        let cancel = CancellationToken::new();
        // Keep the sender alive so the error comes from cancellation, not channel closure.
        let (_tx, mut rx) = mpsc::channel::<ApprovalDecision<TestPolicy>>(4);
        cancel.cancel();
        let result = await_tool_approval("tool-3", &cancel, &mut rx).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn await_tool_approval_errors_when_channel_closes() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel::<ApprovalDecision<TestPolicy>>(4);
        drop(tx);
        let result = await_tool_approval("tool-4", &cancel, &mut rx).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_returns_response() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        let tool_id = "inp-1";
        let task = tokio::spawn({
            let cancel = cancel.clone();
            async move { recv_user_input_for_tool(tool_id, &cancel, &mut rx).await }
        });
        tx.send(UserInputDecision::Submitted {
            id: tool_id.into(),
            response: 42u32,
        })
        .await
        .unwrap();
        assert_eq!(task.await.unwrap().unwrap(), 42);
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_skips_other_ids() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(UserInputDecision::Submitted {
            id: "stale".into(),
            response: 1u32,
        })
        .await
        .unwrap();
        tx.send(UserInputDecision::Submitted {
            id: "inp-3".into(),
            response: 2u32,
        })
        .await
        .unwrap();
        let result = recv_user_input_for_tool("inp-3", &cancel, &mut rx).await;
        assert_eq!(result.unwrap(), 2);
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_errors_on_cancelled_input() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel::<UserInputDecision<u32>>(4);
        tx.send(UserInputDecision::Cancelled { id: "inp-2".into() })
            .await
            .unwrap();
        let result = recv_user_input_for_tool("inp-2", &cancel, &mut rx).await;
        assert!(result.is_err());
    }
}