1use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37
38use async_openai::types::CreateChatCompletionRequest;
39use tower::{BoxError, Layer, Service, ServiceExt};
40
41use crate::core::{StepOutcome, ToolInvocation, ToolOutput};
42
/// Which step of the agent loop an approval concerns: a model call or a tool call.
///
/// This is a plain fieldless tag, so it is `Copy` and can be compared, hashed and
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Stage {
    /// A chat-completion request sent to the model.
    Model,
    /// An invocation of a tool.
    Tool,
}
49
50#[derive(Debug, Clone)]
52pub enum ApprovalRequest {
53 Model {
54 request: Box<CreateChatCompletionRequest>,
55 },
56 Tool {
57 invocation: ToolInvocation,
58 },
59}
60
61#[derive(Debug, Clone)]
63pub enum Decision {
64 Allow,
65 Deny {
66 reason: String,
67 },
68 ModifyModel {
69 request: Box<CreateChatCompletionRequest>,
70 },
71 ModifyTool {
72 invocation: ToolInvocation,
73 },
74}
75
76pub trait Approver: Service<ApprovalRequest, Response = Decision, Error = BoxError> {}
78impl<T> Approver for T where T: Service<ApprovalRequest, Response = Decision, Error = BoxError> {}
79
/// Layer that wraps a model service in [`ModelApproval`], so that every
/// chat-completion request is first submitted to `approver`.
///
/// Derives `Clone` and `Debug` (when `A` does) so the layer can be duplicated
/// and inspected like other tower layers.
#[derive(Debug, Clone)]
pub struct ModelApprovalLayer<A> {
    approver: A,
}

impl<A> ModelApprovalLayer<A> {
    /// Creates a layer that hands a clone of `approver` to every service it wraps.
    pub fn new(approver: A) -> Self {
        Self { approver }
    }
}
90
91pub struct ModelApproval<S, A> {
92 inner: Arc<tokio::sync::Mutex<S>>,
93 approver: A,
94}
95
96impl<S, A> Layer<S> for ModelApprovalLayer<A>
97where
98 A: Clone,
99{
100 type Service = ModelApproval<S, A>;
101 fn layer(&self, inner: S) -> Self::Service {
102 ModelApproval {
103 inner: Arc::new(tokio::sync::Mutex::new(inner)),
104 approver: self.approver.clone(),
105 }
106 }
107}
108
109impl<S, A> Service<CreateChatCompletionRequest> for ModelApproval<S, A>
110where
111 S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
112 + Send
113 + 'static,
114 S::Future: Send + 'static,
115 A: Approver + Clone + Send + 'static,
116 A::Future: Send + 'static,
117{
118 type Response = StepOutcome;
119 type Error = BoxError;
120 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
121
122 fn poll_ready(
123 &mut self,
124 cx: &mut std::task::Context<'_>,
125 ) -> std::task::Poll<Result<(), Self::Error>> {
126 let _ = cx;
127 std::task::Poll::Ready(Ok(()))
128 }
129
130 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
131 let mut approver = self.approver.clone();
132 let inner = self.inner.clone();
133 Box::pin(async move {
134 let decision = ServiceExt::ready(&mut approver)
135 .await?
136 .call(ApprovalRequest::Model {
137 request: Box::new(req.clone()),
138 })
139 .await?;
140 match decision {
141 Decision::Allow | Decision::ModifyTool { .. } => {
142 let mut guard = inner.lock().await;
143 Service::call(&mut *guard, req).await
144 }
145 Decision::Deny { reason: _ } => Ok(StepOutcome::Done {
146 messages: vec![],
147 aux: Default::default(),
148 }),
149 Decision::ModifyModel { request } => {
150 let mut guard = inner.lock().await;
151 Service::call(&mut *guard, *request).await
152 }
153 }
154 })
155 }
156}
157
/// Layer that wraps a tool service in [`ToolApproval`], so that every tool
/// invocation is first submitted to `approver`.
///
/// Derives `Clone` and `Debug` (when `A` does) so the layer can be duplicated
/// and inspected like other tower layers.
#[derive(Debug, Clone)]
pub struct ToolApprovalLayer<A> {
    approver: A,
}

impl<A> ToolApprovalLayer<A> {
    /// Creates a layer that hands a clone of `approver` to every service it wraps.
    pub fn new(approver: A) -> Self {
        Self { approver }
    }
}
168
169pub struct ToolApproval<R, A> {
170 inner: Arc<tokio::sync::Mutex<R>>,
171 approver: A,
172}
173
174impl<R, A> Layer<R> for ToolApprovalLayer<A>
175where
176 A: Clone,
177{
178 type Service = ToolApproval<R, A>;
179 fn layer(&self, inner: R) -> Self::Service {
180 ToolApproval {
181 inner: Arc::new(tokio::sync::Mutex::new(inner)),
182 approver: self.approver.clone(),
183 }
184 }
185}
186
187impl<R, A> Service<ToolInvocation> for ToolApproval<R, A>
188where
189 R: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Send + 'static,
190 R::Future: Send + 'static,
191 A: Approver + Clone + Send + 'static,
192 A::Future: Send + 'static,
193{
194 type Response = ToolOutput;
195 type Error = BoxError;
196 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
197
198 fn poll_ready(
199 &mut self,
200 cx: &mut std::task::Context<'_>,
201 ) -> std::task::Poll<Result<(), Self::Error>> {
202 let _ = cx;
203 std::task::Poll::Ready(Ok(()))
204 }
205
206 fn call(&mut self, inv: ToolInvocation) -> Self::Future {
207 let mut approver = self.approver.clone();
208 let inner = self.inner.clone();
209 Box::pin(async move {
210 let decision = ServiceExt::ready(&mut approver)
211 .await?
212 .call(ApprovalRequest::Tool {
213 invocation: inv.clone(),
214 })
215 .await?;
216 match decision {
217 Decision::Allow | Decision::ModifyModel { .. } => {
218 let mut guard = inner.lock().await;
219 Service::call(&mut *guard, inv).await
220 }
221 Decision::Deny { reason } => {
222 Err::<ToolOutput, BoxError>(format!("denied: {}", reason).into())
223 }
224 Decision::ModifyTool { invocation } => {
225 let mut guard = inner.lock().await;
226 Service::call(&mut *guard, invocation).await
227 }
228 }
229 })
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::{service_fn, Service};

    /// Empty terminal outcome returned by the stub model services.
    fn empty_done() -> StepOutcome {
        StepOutcome::Done {
            messages: vec![],
            aux: Default::default(),
        }
    }

    /// Tool invocation with id "1", the given name and null arguments.
    fn invocation_named(name: &str) -> ToolInvocation {
        ToolInvocation {
            id: "1".into(),
            name: name.into(),
            arguments: Value::Null,
        }
    }

    /// A denied model request must never reach the inner service.
    #[tokio::test]
    async fn model_denial_short_circuits() {
        static CALLED: AtomicUsize = AtomicUsize::new(0);
        let inner = service_fn(|_req: CreateChatCompletionRequest| async move {
            CALLED.fetch_add(1, Ordering::SeqCst);
            Ok::<_, BoxError>(empty_done())
        });
        let approver = service_fn(|_req: ApprovalRequest| async move {
            Ok::<_, BoxError>(Decision::Deny {
                reason: "no".into(),
            })
        });
        let mut gated = ModelApprovalLayer::new(approver).layer(inner);
        let request = CreateChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            ..Default::default()
        };
        let _ = Service::call(&mut gated, request).await.unwrap();
        assert_eq!(CALLED.load(Ordering::SeqCst), 0);
    }

    /// A denied tool invocation fails with a "denied" error and skips the tool.
    #[tokio::test]
    async fn tool_denial_prevents_call() {
        static CALLED: AtomicUsize = AtomicUsize::new(0);
        let inner = service_fn(|_inv: ToolInvocation| async move {
            CALLED.fetch_add(1, Ordering::SeqCst);
            Ok::<_, BoxError>(ToolOutput {
                id: "1".into(),
                result: Value::Null,
            })
        });
        let approver = service_fn(|_req: ApprovalRequest| async move {
            Ok::<_, BoxError>(Decision::Deny {
                reason: "no".into(),
            })
        });
        let mut gated = ToolApprovalLayer::new(approver).layer(inner);
        let outcome = Service::call(&mut gated, invocation_named("x")).await;
        let failure = outcome.unwrap_err();
        assert!(failure.to_string().contains("denied"));
        assert_eq!(CALLED.load(Ordering::SeqCst), 0);
    }

    /// `ModifyModel` replaces the request that the inner service receives.
    #[tokio::test]
    async fn model_modify_rewrites_request() {
        static CALLED: AtomicUsize = AtomicUsize::new(0);
        // Inner stub checks that the first message is a system message with text "APPROVED".
        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            CALLED.fetch_add(1, Ordering::SeqCst);
            match req.messages.first() {
                Some(async_openai::types::ChatCompletionRequestMessage::System(system)) => {
                    if let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
                        text,
                    ) = &system.content
                    {
                        assert_eq!(text, "APPROVED");
                    } else {
                        panic!("expected text content");
                    }
                }
                Some(_) => panic!("expected system message first"),
                None => panic!("no messages"),
            }
            Ok::<_, BoxError>(StepOutcome::Done {
                messages: req.messages,
                aux: Default::default(),
            })
        });
        // Approver swaps every model request for one whose only message is "APPROVED".
        let approver = service_fn(|req: ApprovalRequest| async move {
            match req {
                ApprovalRequest::Model { .. } => {
                    let system =
                        async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
                            .content("APPROVED")
                            .build()
                            .unwrap();
                    let rewritten = async_openai::types::CreateChatCompletionRequestArgs::default()
                        .model("gpt-4o")
                        .messages(vec![system.into()])
                        .build()
                        .unwrap();
                    Ok::<_, BoxError>(Decision::ModifyModel {
                        request: Box::new(rewritten),
                    })
                }
                _ => Ok::<_, BoxError>(Decision::Allow),
            }
        });
        let mut gated = ModelApprovalLayer::new(approver).layer(inner);
        let original = async_openai::types::CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![])
            .build()
            .unwrap();
        let _ = Service::call(&mut gated, original).await.unwrap();
        assert_eq!(CALLED.load(Ordering::SeqCst), 1);
    }

    /// `ModifyTool` replaces the invocation that the inner tool service receives.
    #[tokio::test]
    async fn tool_modify_rewrites_invocation() {
        static CALLED: AtomicUsize = AtomicUsize::new(0);
        let inner = service_fn(|inv: ToolInvocation| async move {
            CALLED.fetch_add(1, Ordering::SeqCst);
            assert_eq!(inv.name, "modified_tool");
            Ok::<_, BoxError>(ToolOutput {
                id: inv.id,
                result: Value::Null,
            })
        });
        // Approver renames every tool invocation to "modified_tool".
        let approver = service_fn(|req: ApprovalRequest| async move {
            let verdict = match req {
                ApprovalRequest::Tool { mut invocation } => {
                    invocation.name = "modified_tool".to_string();
                    Decision::ModifyTool { invocation }
                }
                _ => Decision::Allow,
            };
            Ok::<_, BoxError>(verdict)
        });
        let mut gated = ToolApprovalLayer::new(approver).layer(inner);
        let _ = Service::call(&mut gated, invocation_named("orig"))
            .await
            .unwrap();
        assert_eq!(CALLED.load(Ordering::SeqCst), 1);
    }
}