zagens_core/engine/
approval.rs1use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8use zagens_tools::ToolError;
9
/// A reply to a pending tool-approval request.
///
/// Every variant carries the `id` of the tool call it answers.
/// [`await_tool_approval`] consumes decisions from its channel and discards any
/// whose `id` does not match the tool call it is waiting on.
///
/// `PartialEq`/`Eq` are derived for consistency with [`ApprovalResult`], so
/// decisions can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApprovalDecision<P> {
    /// The tool call may proceed.
    Approved {
        /// Identifier of the tool call this decision answers.
        id: String,
        /// Forwarded unchanged into [`ApprovalResult::Approved`].
        /// NOTE(review): its meaning is defined by the consumer of that result.
        cache_key: Option<String>,
        /// Forwarded unchanged into [`ApprovalResult::Approved`]; presumably asks
        /// the consumer to remember this approval for the session — confirm.
        remember_for_session: bool,
    },
    /// The tool call must not proceed.
    Denied {
        /// Identifier of the tool call this decision answers.
        id: String,
    },
    /// Retry the tool call with `policy`.
    ///
    /// `policy` is forwarded unchanged as [`ApprovalResult::RetryWithPolicy`].
    RetryWithPolicy {
        /// Identifier of the tool call this decision answers.
        id: String,
        /// Policy handed back to the waiter.
        policy: P,
    },
}
28
/// A reply to a pending user-input request, addressed to the tool call `id`.
///
/// [`recv_user_input_for_tool`] returns the `response` of a matching
/// `Submitted` decision. It turns a matching `Cancelled` decision into an
/// error. Decisions for other ids are discarded.
///
/// `PartialEq`/`Eq` are derived for consistency with [`ApprovalResult`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserInputDecision<R> {
    /// The user answered; `response` is handed back to the waiter.
    Submitted { id: String, response: R },
    /// The user dismissed the request without answering.
    Cancelled { id: String },
}
34
/// Outcome of [`await_tool_approval`] once a decision for the awaited tool
/// call has arrived.
///
/// Unlike [`ApprovalDecision`], it carries no `id`: the id has already been
/// matched against the tool call being awaited.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ApprovalResult<P> {
    /// Approved; both fields are copied verbatim from
    /// [`ApprovalDecision::Approved`].
    Approved {
        cache_key: Option<String>,
        remember_for_session: bool,
    },
    /// Denied by the approver.
    Denied,
    /// The approver asked for a retry under the enclosed policy.
    RetryWithPolicy(P),
}
45
46pub async fn await_tool_approval<P>(
48 tool_id: &str,
49 cancel_token: &CancellationToken,
50 rx_approval: &mut mpsc::Receiver<ApprovalDecision<P>>,
51) -> Result<ApprovalResult<P>, ToolError>
52where
53 P: Clone,
54{
55 loop {
56 tokio::select! {
57 _ = cancel_token.cancelled() => {
58 return Err(ToolError::execution_failed(
59 "Request cancelled while awaiting approval".to_string(),
60 ));
61 }
62 decision = rx_approval.recv() => {
63 let Some(decision) = decision else {
64 return Err(ToolError::execution_failed(
65 "Approval channel closed".to_string(),
66 ));
67 };
68 match decision {
69 ApprovalDecision::Approved {
70 id,
71 cache_key,
72 remember_for_session,
73 } if id == tool_id => {
74 return Ok(ApprovalResult::Approved {
75 cache_key,
76 remember_for_session,
77 });
78 }
79 ApprovalDecision::Denied { id } if id == tool_id => {
80 return Ok(ApprovalResult::Denied);
81 }
82 ApprovalDecision::RetryWithPolicy { id, policy } if id == tool_id => {
83 return Ok(ApprovalResult::RetryWithPolicy(policy));
84 }
85 _ => continue,
86 }
87 }
88 }
89 }
90}
91
92pub async fn recv_user_input_for_tool<R>(
94 tool_id: &str,
95 cancel_token: &CancellationToken,
96 rx_user_input: &mut mpsc::Receiver<UserInputDecision<R>>,
97) -> Result<R, ToolError> {
98 loop {
99 tokio::select! {
100 _ = cancel_token.cancelled() => {
101 return Err(ToolError::execution_failed(
102 "Request cancelled while awaiting user input".to_string(),
103 ));
104 }
105 decision = rx_user_input.recv() => {
106 let Some(decision) = decision else {
107 return Err(ToolError::execution_failed(
108 "User input channel closed".to_string(),
109 ));
110 };
111 match decision {
112 UserInputDecision::Submitted { id, response } if id == tool_id => {
113 return Ok(response);
114 }
115 UserInputDecision::Cancelled { id } if id == tool_id => {
116 return Err(ToolError::execution_failed(
117 "User input cancelled".to_string(),
118 ));
119 }
120 _ => continue,
121 }
122 }
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestPolicy(u8);

    #[tokio::test]
    async fn await_tool_approval_matches_id() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        let tool_id = "tool-1";
        let task = tokio::spawn({
            let cancel = cancel.clone();
            async move { await_tool_approval::<TestPolicy>(tool_id, &cancel, &mut rx).await }
        });
        // A decision for another tool must be skipped, not returned.
        tx.send(ApprovalDecision::Denied { id: "other".into() })
            .await
            .unwrap();
        tx.send(ApprovalDecision::Approved {
            id: tool_id.into(),
            cache_key: None,
            remember_for_session: false,
        })
        .await
        .unwrap();
        assert_eq!(
            task.await.unwrap().unwrap(),
            ApprovalResult::Approved {
                cache_key: None,
                remember_for_session: false,
            }
        );
    }

    #[tokio::test]
    async fn await_tool_approval_returns_denied() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel::<ApprovalDecision<TestPolicy>>(4);
        tx.send(ApprovalDecision::Denied { id: "tool-1".into() })
            .await
            .unwrap();
        assert_eq!(
            await_tool_approval("tool-1", &cancel, &mut rx).await.unwrap(),
            ApprovalResult::Denied
        );
    }

    #[tokio::test]
    async fn await_tool_approval_returns_retry_policy() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(ApprovalDecision::RetryWithPolicy {
            id: "tool-1".into(),
            policy: TestPolicy(7),
        })
        .await
        .unwrap();
        assert_eq!(
            await_tool_approval("tool-1", &cancel, &mut rx).await.unwrap(),
            ApprovalResult::RetryWithPolicy(TestPolicy(7))
        );
    }

    #[tokio::test]
    async fn await_tool_approval_errors_when_cancelled() {
        let cancel = CancellationToken::new();
        // Keep the sender alive so only cancellation can end the wait.
        let (_tx, mut rx) = mpsc::channel::<ApprovalDecision<TestPolicy>>(4);
        cancel.cancel();
        assert!(await_tool_approval("tool-1", &cancel, &mut rx).await.is_err());
    }

    #[tokio::test]
    async fn await_tool_approval_errors_when_channel_closed() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel::<ApprovalDecision<TestPolicy>>(4);
        drop(tx);
        assert!(await_tool_approval("tool-1", &cancel, &mut rx).await.is_err());
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_returns_response() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        let tool_id = "inp-1";
        let task = tokio::spawn({
            let cancel = cancel.clone();
            async move { recv_user_input_for_tool(tool_id, &cancel, &mut rx).await }
        });
        tx.send(UserInputDecision::Submitted {
            id: tool_id.into(),
            response: 42u32,
        })
        .await
        .unwrap();
        assert_eq!(task.await.unwrap().unwrap(), 42);
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_skips_other_ids() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(UserInputDecision::Submitted {
            id: "other".into(),
            response: 1u32,
        })
        .await
        .unwrap();
        tx.send(UserInputDecision::Submitted {
            id: "inp-1".into(),
            response: 2u32,
        })
        .await
        .unwrap();
        assert_eq!(
            recv_user_input_for_tool("inp-1", &cancel, &mut rx).await.unwrap(),
            2
        );
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_errors_when_user_cancels() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel::<UserInputDecision<u32>>(4);
        tx.send(UserInputDecision::Cancelled { id: "inp-1".into() })
            .await
            .unwrap();
        assert!(recv_user_input_for_tool("inp-1", &cancel, &mut rx).await.is_err());
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_errors_when_cancelled() {
        let cancel = CancellationToken::new();
        // Keep the sender alive so only cancellation can end the wait.
        let (_tx, mut rx) = mpsc::channel::<UserInputDecision<u32>>(4);
        cancel.cancel();
        assert!(recv_user_input_for_tool("inp-1", &cancel, &mut rx).await.is_err());
    }

    #[tokio::test]
    async fn recv_user_input_for_tool_errors_when_channel_closed() {
        let cancel = CancellationToken::new();
        let (tx, mut rx) = mpsc::channel::<UserInputDecision<u32>>(4);
        drop(tx);
        assert!(recv_user_input_for_tool("inp-1", &cancel, &mut rx).await.is_err());
    }
}