tower_llm/approvals/
mod.rs

1//! Guardrails and approvals
2//!
3//! What this module provides (spec)
4//! - Pluggable approval strategies applied before model/tool execution
5//! - Ability to deny, allow, or rewrite requests
6//!
7//! Exports
8//! - Models
//!   - `ApprovalRequest::{Model { request }, Tool { invocation }}`
//!   - `Decision::{Allow, Deny { reason }, ModifyModel { request }, ModifyTool { invocation }}`
11//!   - `Stage::{Model, Tool}`
12//! - Services
13//!   - `Approver: Service<ApprovalRequest, Response=Decision>`
14//! - Layers
//!   - `ModelApprovalLayer<A>` and `ToolApprovalLayer<A>` where `A: Approver`
16//!     - On Model stage: evaluate before calling provider; on Modify, replace request; on Deny, short-circuit
17//!     - On Tool stage: evaluate before invoking router
18//! - Utils
19//!   - Prebuilt approvers: `AllowListTools`, `MaxArgsSize`, `RequireReasoning`
20//!
21//! Implementation strategy
22//! - Keep `Approver` pure and side-effect free (unless intentionally stateful)
23//! - The layer inspects the stage and constructs `ApprovalRequest` appropriately
//! - The decision controls whether, and with which payload, the inner service is called
25//!
26//! Composition
//! - `ServiceBuilder::new().layer(ModelApprovalLayer::new(my_approver)).service(step)`
28//! - For tools, wrap the router separately if needed
29//!
30//! Testing strategy
31//! - Fake approver returning scripted decisions
32//! - Unit tests per stage: Model denial prevents provider call; Tool denial prevents router call; Modify rewrites inputs
33
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37
38use async_openai::types::CreateChatCompletionRequest;
39use tower::{BoxError, Layer, Service, ServiceExt};
40
41use crate::core::{StepOutcome, ToolInvocation, ToolOutput};
42
/// Stage at which an approval is evaluated.
///
/// This is a field-less tag, so besides `Debug`/`Clone` it also derives `Copy` and the
/// equality/hash traits; stages can therefore be compared directly and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Stage {
    /// Approval evaluated for a model (chat-completion) request.
    Model,
    /// Approval evaluated for a tool invocation.
    Tool,
}
49
50/// Approval request payload.
51#[derive(Debug, Clone)]
52pub enum ApprovalRequest {
53    Model {
54        request: Box<CreateChatCompletionRequest>,
55    },
56    Tool {
57        invocation: ToolInvocation,
58    },
59}
60
61/// Approval decision.
62#[derive(Debug, Clone)]
63pub enum Decision {
64    Allow,
65    Deny {
66        reason: String,
67    },
68    ModifyModel {
69        request: Box<CreateChatCompletionRequest>,
70    },
71    ModifyTool {
72        invocation: ToolInvocation,
73    },
74}
75
76/// Approver service trait alias: Service<ApprovalRequest, Decision>.
77pub trait Approver: Service<ApprovalRequest, Response = Decision, Error = BoxError> {}
78impl<T> Approver for T where T: Service<ApprovalRequest, Response = Decision, Error = BoxError> {}
79
/// Layer that evaluates approvals at the model stage before invoking the inner step service.
///
/// The layer owns the approver and gives a clone of it to every service it produces.
/// `Debug` and `Clone` are derived (conditional on `A`) so the layer can be logged and reused
/// when building several service stacks.
#[derive(Debug, Clone)]
pub struct ModelApprovalLayer<A> {
    approver: A,
}

impl<A> ModelApprovalLayer<A> {
    /// Creates a layer that consults `approver` for every model request.
    pub fn new(approver: A) -> Self {
        Self { approver }
    }
}
90
91pub struct ModelApproval<S, A> {
92    inner: Arc<tokio::sync::Mutex<S>>,
93    approver: A,
94}
95
96impl<S, A> Layer<S> for ModelApprovalLayer<A>
97where
98    A: Clone,
99{
100    type Service = ModelApproval<S, A>;
101    fn layer(&self, inner: S) -> Self::Service {
102        ModelApproval {
103            inner: Arc::new(tokio::sync::Mutex::new(inner)),
104            approver: self.approver.clone(),
105        }
106    }
107}
108
109impl<S, A> Service<CreateChatCompletionRequest> for ModelApproval<S, A>
110where
111    S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
112        + Send
113        + 'static,
114    S::Future: Send + 'static,
115    A: Approver + Clone + Send + 'static,
116    A::Future: Send + 'static,
117{
118    type Response = StepOutcome;
119    type Error = BoxError;
120    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
121
122    fn poll_ready(
123        &mut self,
124        cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Result<(), Self::Error>> {
126        let _ = cx;
127        std::task::Poll::Ready(Ok(()))
128    }
129
130    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
131        let mut approver = self.approver.clone();
132        let inner = self.inner.clone();
133        Box::pin(async move {
134            let decision = ServiceExt::ready(&mut approver)
135                .await?
136                .call(ApprovalRequest::Model {
137                    request: Box::new(req.clone()),
138                })
139                .await?;
140            match decision {
141                Decision::Allow | Decision::ModifyTool { .. } => {
142                    let mut guard = inner.lock().await;
143                    Service::call(&mut *guard, req).await
144                }
145                Decision::Deny { reason: _ } => Ok(StepOutcome::Done {
146                    messages: vec![],
147                    aux: Default::default(),
148                }),
149                Decision::ModifyModel { request } => {
150                    let mut guard = inner.lock().await;
151                    Service::call(&mut *guard, *request).await
152                }
153            }
154        })
155    }
156}
157
/// Wrapper for tool router that evaluates approvals per invocation.
///
/// The layer owns the approver and gives a clone of it to every service it produces.
/// `Debug` and `Clone` are derived (conditional on `A`) so the layer can be logged and reused
/// when building several service stacks.
#[derive(Debug, Clone)]
pub struct ToolApprovalLayer<A> {
    approver: A,
}

impl<A> ToolApprovalLayer<A> {
    /// Creates a layer that consults `approver` for every tool invocation.
    pub fn new(approver: A) -> Self {
        Self { approver }
    }
}
168
169pub struct ToolApproval<R, A> {
170    inner: Arc<tokio::sync::Mutex<R>>,
171    approver: A,
172}
173
174impl<R, A> Layer<R> for ToolApprovalLayer<A>
175where
176    A: Clone,
177{
178    type Service = ToolApproval<R, A>;
179    fn layer(&self, inner: R) -> Self::Service {
180        ToolApproval {
181            inner: Arc::new(tokio::sync::Mutex::new(inner)),
182            approver: self.approver.clone(),
183        }
184    }
185}
186
187impl<R, A> Service<ToolInvocation> for ToolApproval<R, A>
188where
189    R: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Send + 'static,
190    R::Future: Send + 'static,
191    A: Approver + Clone + Send + 'static,
192    A::Future: Send + 'static,
193{
194    type Response = ToolOutput;
195    type Error = BoxError;
196    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
197
198    fn poll_ready(
199        &mut self,
200        cx: &mut std::task::Context<'_>,
201    ) -> std::task::Poll<Result<(), Self::Error>> {
202        let _ = cx;
203        std::task::Poll::Ready(Ok(()))
204    }
205
206    fn call(&mut self, inv: ToolInvocation) -> Self::Future {
207        let mut approver = self.approver.clone();
208        let inner = self.inner.clone();
209        Box::pin(async move {
210            let decision = ServiceExt::ready(&mut approver)
211                .await?
212                .call(ApprovalRequest::Tool {
213                    invocation: inv.clone(),
214                })
215                .await?;
216            match decision {
217                Decision::Allow | Decision::ModifyModel { .. } => {
218                    let mut guard = inner.lock().await;
219                    Service::call(&mut *guard, inv).await
220                }
221                Decision::Deny { reason } => {
222                    Err::<ToolOutput, BoxError>(format!("denied: {}", reason).into())
223                }
224                Decision::ModifyTool { invocation } => {
225                    let mut guard = inner.lock().await;
226                    Service::call(&mut *guard, invocation).await
227                }
228            }
229        })
230    }
231}
232
233#[cfg(test)]
234mod tests {
235    use super::*;
236    use serde_json::Value;
237    use std::sync::atomic::{AtomicUsize, Ordering};
238    use tower::{service_fn, Service};
239
240    #[tokio::test]
241    async fn model_denial_short_circuits() {
242        static CALLED: AtomicUsize = AtomicUsize::new(0);
243        let inner = service_fn(|_req: CreateChatCompletionRequest| async move {
244            CALLED.fetch_add(1, Ordering::SeqCst);
245            Ok::<_, BoxError>(StepOutcome::Done {
246                messages: vec![],
247                aux: Default::default(),
248            })
249        });
250        let approver = service_fn(|_req: ApprovalRequest| async move {
251            Ok::<_, BoxError>(Decision::Deny {
252                reason: "no".into(),
253            })
254        });
255        let mut svc = ModelApprovalLayer::new(approver).layer(inner);
256        let req = CreateChatCompletionRequest {
257            model: "gpt-4o".into(),
258            messages: vec![],
259            ..Default::default()
260        };
261        let _ = Service::call(&mut svc, req).await.unwrap();
262        assert_eq!(CALLED.load(Ordering::SeqCst), 0);
263    }
264
265    #[tokio::test]
266    async fn tool_denial_prevents_call() {
267        static CALLED: AtomicUsize = AtomicUsize::new(0);
268        let inner = service_fn(|_inv: ToolInvocation| async move {
269            CALLED.fetch_add(1, Ordering::SeqCst);
270            Ok::<_, BoxError>(ToolOutput {
271                id: "1".into(),
272                result: Value::Null,
273            })
274        });
275        let approver = service_fn(|_req: ApprovalRequest| async move {
276            Ok::<_, BoxError>(Decision::Deny {
277                reason: "no".into(),
278            })
279        });
280        let mut svc = ToolApprovalLayer::new(approver).layer(inner);
281        let err = Service::call(
282            &mut svc,
283            ToolInvocation {
284                id: "1".into(),
285                name: "x".into(),
286                arguments: Value::Null,
287            },
288        )
289        .await
290        .unwrap_err();
291        assert!(format!("{}", err).contains("denied"));
292        assert_eq!(CALLED.load(Ordering::SeqCst), 0);
293    }
294
295    #[tokio::test]
296    async fn model_modify_rewrites_request() {
297        static CALLED: AtomicUsize = AtomicUsize::new(0);
298        // Inner asserts it receives the modified request (system message content == "APPROVED")
299        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
300            CALLED.fetch_add(1, Ordering::SeqCst);
301            // First message should be system with content "APPROVED"
302            if let Some(first) = req.messages.first() {
303                match first {
304                    async_openai::types::ChatCompletionRequestMessage::System(s) => {
305                        if let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) = &s.content {
306                            assert_eq!(t, "APPROVED");
307                        } else {
308                            panic!("expected text content");
309                        }
310                    }
311                    _ => panic!("expected system message first"),
312                }
313            } else {
314                panic!("no messages");
315            }
316            Ok::<_, BoxError>(StepOutcome::Done {
317                messages: req.messages,
318                aux: Default::default(),
319            })
320        });
321        // Approver transforms incoming request by injecting system message "APPROVED"
322        let approver = service_fn(|req: ApprovalRequest| async move {
323            match req {
324                ApprovalRequest::Model { request: _ } => {
325                    let mut b = async_openai::types::CreateChatCompletionRequestArgs::default();
326                    b.model("gpt-4o");
327                    let sys =
328                        async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
329                            .content("APPROVED")
330                            .build()
331                            .unwrap();
332                    b.messages(vec![sys.into()]);
333                    let modified = b.build().unwrap();
334                    Ok::<_, BoxError>(Decision::ModifyModel {
335                        request: Box::new(modified),
336                    })
337                }
338                _ => Ok::<_, BoxError>(Decision::Allow),
339            }
340        });
341        let mut svc = ModelApprovalLayer::new(approver).layer(inner);
342        let orig = async_openai::types::CreateChatCompletionRequestArgs::default()
343            .model("gpt-4o")
344            .messages(vec![])
345            .build()
346            .unwrap();
347        let _ = Service::call(&mut svc, orig).await.unwrap();
348        assert_eq!(CALLED.load(Ordering::SeqCst), 1);
349    }
350
351    #[tokio::test]
352    async fn tool_modify_rewrites_invocation() {
353        static CALLED: AtomicUsize = AtomicUsize::new(0);
354        // Inner asserts invocation name was modified to "modified_tool"
355        let inner = service_fn(|inv: ToolInvocation| async move {
356            CALLED.fetch_add(1, Ordering::SeqCst);
357            assert_eq!(inv.name, "modified_tool");
358            Ok::<_, BoxError>(ToolOutput {
359                id: inv.id,
360                result: Value::Null,
361            })
362        });
363        // Approver rewrites tool invocation name
364        let approver = service_fn(|req: ApprovalRequest| async move {
365            match req {
366                ApprovalRequest::Tool { mut invocation } => {
367                    invocation.name = "modified_tool".to_string();
368                    Ok::<_, BoxError>(Decision::ModifyTool { invocation })
369                }
370                _ => Ok::<_, BoxError>(Decision::Allow),
371            }
372        });
373        let mut svc = ToolApprovalLayer::new(approver).layer(inner);
374        let _ = Service::call(
375            &mut svc,
376            ToolInvocation {
377                id: "1".into(),
378                name: "orig".into(),
379                arguments: Value::Null,
380            },
381        )
382        .await
383        .unwrap();
384        assert_eq!(CALLED.load(Ordering::SeqCst), 1);
385    }
386}