1use std::future::Future;
49use std::pin::Pin;
50use tower::{Layer, Service};
51
/// Layer configuration holding up to three optional observer hooks.
///
/// Each hook receives a shared reference to the value it observes: the
/// request (`OnReq`), the successful response (`OnRes`), or the error
/// (`OnErr`). When no hook type is named explicitly, the hook types default
/// to plain function pointers.
pub struct TapLayer<Req, Res, Err, OnReq = fn(&Req), OnRes = fn(&Res), OnErr = fn(&Err)> {
    /// Hook for requests; `None` until one is registered.
    on_request: Option<OnReq>,
    /// Hook for successful responses; `None` until one is registered.
    on_response: Option<OnRes>,
    /// Hook for errors; `None` until one is registered.
    on_error: Option<OnErr>,
    /// Ties `Req`, `Res` and `Err` to the type without storing values of them.
    _phantom: std::marker::PhantomData<fn(Req, Res, Err)>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would also demand
// `Req: Clone`, `Res: Clone` and `Err: Clone`, although no value of those
// types is stored. Only the hooks need to be cloneable.
impl<Req, Res, Err, OnReq, OnRes, OnErr> Clone for TapLayer<Req, Res, Err, OnReq, OnRes, OnErr>
where
    OnReq: Clone,
    OnRes: Clone,
    OnErr: Clone,
{
    /// Clones the three optional hooks into a new layer.
    fn clone(&self) -> Self {
        Self {
            on_request: self.on_request.clone(),
            on_response: self.on_response.clone(),
            on_error: self.on_error.clone(),
            _phantom: std::marker::PhantomData,
        }
    }
}
59
60impl<Req, Res, Err, OnReq, OnRes, OnErr> Default for TapLayer<Req, Res, Err, OnReq, OnRes, OnErr> {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl<Req, Res, Err, OnReq, OnRes, OnErr> TapLayer<Req, Res, Err, OnReq, OnRes, OnErr> {
67 pub fn new() -> Self {
69 Self {
70 on_request: None,
71 on_response: None,
72 on_error: None,
73 _phantom: std::marker::PhantomData,
74 }
75 }
76
77 pub fn on_request<F>(self, f: F) -> TapLayer<Req, Res, Err, F, OnRes, OnErr>
79 where
80 F: Fn(&Req) + Send + Sync + 'static,
81 {
82 TapLayer {
83 on_request: Some(f),
84 on_response: self.on_response,
85 on_error: self.on_error,
86 _phantom: std::marker::PhantomData,
87 }
88 }
89
90 pub fn on_response<F>(self, f: F) -> TapLayer<Req, Res, Err, OnReq, F, OnErr>
92 where
93 F: Fn(&Res) + Send + Sync + 'static,
94 {
95 TapLayer {
96 on_request: self.on_request,
97 on_response: Some(f),
98 on_error: self.on_error,
99 _phantom: std::marker::PhantomData,
100 }
101 }
102
103 pub fn on_error<F>(self, f: F) -> TapLayer<Req, Res, Err, OnReq, OnRes, F>
105 where
106 F: Fn(&Err) + Send + Sync + 'static,
107 {
108 TapLayer {
109 on_request: self.on_request,
110 on_response: self.on_response,
111 on_error: Some(f),
112 _phantom: std::marker::PhantomData,
113 }
114 }
115}
116
/// Service wrapper pairing an inner service with up to three optional hooks.
///
/// `Req`, `Res` and `Err` name the request, response and error types the
/// hooks observe; no value of those types is stored in the wrapper.
pub struct Tap<S, Req, Res, Err, OnReq, OnRes, OnErr> {
    /// The wrapped service.
    inner: S,
    /// Hook for requests, if one was registered.
    on_request: Option<OnReq>,
    /// Hook for successful responses, if one was registered.
    on_response: Option<OnRes>,
    /// Hook for errors, if one was registered.
    on_error: Option<OnErr>,
    /// Ties `Req`, `Res` and `Err` to the type without storing values of them.
    _phantom: std::marker::PhantomData<fn(Req, Res, Err)>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would also demand
// `Req: Clone`, `Res: Clone` and `Err: Clone`, although only the inner
// service and the hooks are actually stored.
impl<S, Req, Res, Err, OnReq, OnRes, OnErr> Clone for Tap<S, Req, Res, Err, OnReq, OnRes, OnErr>
where
    S: Clone,
    OnReq: Clone,
    OnRes: Clone,
    OnErr: Clone,
{
    /// Clones the inner service and the three optional hooks.
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            on_request: self.on_request.clone(),
            on_response: self.on_response.clone(),
            on_error: self.on_error.clone(),
            _phantom: std::marker::PhantomData,
        }
    }
}
125
126impl<S, Req, Res, Err, OnReq, OnRes, OnErr> Layer<S>
127 for TapLayer<Req, Res, Err, OnReq, OnRes, OnErr>
128where
129 OnReq: Clone,
130 OnRes: Clone,
131 OnErr: Clone,
132{
133 type Service = Tap<S, Req, Res, Err, OnReq, OnRes, OnErr>;
134
135 fn layer(&self, inner: S) -> Self::Service {
136 Tap {
137 inner,
138 on_request: self.on_request.clone(),
139 on_response: self.on_response.clone(),
140 on_error: self.on_error.clone(),
141 _phantom: std::marker::PhantomData,
142 }
143 }
144}
145
146impl<S, Req, Res, Err, OnReq, OnRes, OnErr> Service<Req>
147 for Tap<S, Req, Res, Err, OnReq, OnRes, OnErr>
148where
149 S: Service<Req, Response = Res, Error = Err> + Send + 'static,
150 S::Future: Send + 'static,
151 OnReq: Fn(&Req) + Send + Sync + Clone + 'static,
152 OnRes: Fn(&Res) + Send + Sync + Clone + 'static,
153 OnErr: Fn(&Err) + Send + Sync + Clone + 'static,
154{
155 type Response = Res;
156 type Error = Err;
157 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
158
159 fn poll_ready(
160 &mut self,
161 cx: &mut std::task::Context<'_>,
162 ) -> std::task::Poll<Result<(), Self::Error>> {
163 self.inner.poll_ready(cx)
164 }
165
166 fn call(&mut self, req: Req) -> Self::Future {
167 if let Some(f) = &self.on_request {
168 f(&req);
169 }
170
171 let on_response = self.on_response.clone();
172 let on_error = self.on_error.clone();
173 let fut = self.inner.call(req);
174
175 Box::pin(async move {
176 match fut.await {
177 Ok(res) => {
178 if let Some(f) = &on_response {
179 f(&res);
180 }
181 Ok(res)
182 }
183 Err(err) => {
184 if let Some(f) = &on_error {
185 f(&err);
186 }
187 Err(err)
188 }
189 }
190 })
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionRequestArgs};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::{service_fn, BoxError, ServiceExt};

    use crate::core::StepOutcome;

    /// Layer type under test, with the default function-pointer hook types.
    type TestLayer = TapLayer<CreateChatCompletionRequest, StepOutcome, BoxError>;

    /// Builds the minimal request every test sends through the service.
    fn sample_request() -> CreateChatCompletionRequest {
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![])
            .build()
            .unwrap()
    }

    /// Builds an empty `Done` outcome for the stub inner services to return.
    fn done() -> StepOutcome {
        StepOutcome::Done {
            messages: vec![],
            aux: Default::default(),
        }
    }

    /// Waits for `svc` to become ready, sends one sample request, and returns
    /// the service's result.
    async fn call_once<S>(svc: &mut S) -> Result<S::Response, S::Error>
    where
        S: tower::Service<CreateChatCompletionRequest>,
    {
        ServiceExt::<CreateChatCompletionRequest>::ready(svc)
            .await?
            .call(sample_request())
            .await
    }

    #[tokio::test]
    async fn tap_invokes_request_and_response_hooks() {
        let req_count = Arc::new(AtomicUsize::new(0));
        let res_count = Arc::new(AtomicUsize::new(0));

        let rc1 = req_count.clone();
        let rc2 = res_count.clone();
        let layer = TestLayer::new()
            .on_request(move |_r: &CreateChatCompletionRequest| {
                rc1.fetch_add(1, Ordering::Relaxed);
            })
            .on_response(move |_o: &StepOutcome| {
                rc2.fetch_add(1, Ordering::Relaxed);
            });

        let inner =
            service_fn(|_req: CreateChatCompletionRequest| async move { Ok::<_, BoxError>(done()) });

        let mut svc = layer.layer(inner);
        let _ = call_once(&mut svc).await.unwrap();

        assert_eq!(req_count.load(Ordering::Relaxed), 1);
        assert_eq!(res_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn tap_invokes_error_hook() {
        let err_count = Arc::new(AtomicUsize::new(0));
        let res_count = Arc::new(AtomicUsize::new(0));
        let ec = err_count.clone();
        let rc = res_count.clone();

        let layer = TestLayer::new()
            .on_response(move |_o: &StepOutcome| {
                rc.fetch_add(1, Ordering::Relaxed);
            })
            .on_error(move |_e: &BoxError| {
                ec.fetch_add(1, Ordering::Relaxed);
            });

        let inner = service_fn(|_req: CreateChatCompletionRequest| async move {
            Err::<StepOutcome, BoxError>("boom".into())
        });

        let mut svc = layer.layer(inner);
        let result = call_once(&mut svc).await;

        // The hook only observes; the caller must still receive the error.
        assert!(result.is_err(), "inner error must reach the caller");
        assert_eq!(err_count.load(Ordering::Relaxed), 1);
        assert_eq!(res_count.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn tap_with_no_hooks_is_transparent() {
        let layer = TestLayer::new();
        let inner =
            service_fn(|_req: CreateChatCompletionRequest| async move { Ok::<_, BoxError>(done()) });

        let mut svc = layer.layer(inner);
        let out = call_once(&mut svc).await.unwrap();

        assert!(matches!(out, StepOutcome::Done { .. }), "expected Done");
    }

    #[tokio::test]
    async fn tap_response_hook_fires_on_next() {
        let resp_count = Arc::new(AtomicUsize::new(0));
        let rc = resp_count.clone();
        let layer = TestLayer::new().on_response(move |_o: &StepOutcome| {
            rc.fetch_add(1, Ordering::Relaxed);
        });

        let inner = service_fn(|_req: CreateChatCompletionRequest| async move {
            Ok::<_, BoxError>(StepOutcome::Next {
                messages: vec![],
                aux: Default::default(),
                invoked_tools: vec![],
            })
        });

        let mut svc = layer.layer(inner);
        let _ = call_once(&mut svc).await.unwrap();

        assert_eq!(resp_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn tap_layers_can_be_chained_and_both_fire() {
        let req_a = Arc::new(AtomicUsize::new(0));
        let req_b = Arc::new(AtomicUsize::new(0));
        let ra = req_a.clone();
        let rb = req_b.clone();

        let l1 = TestLayer::new().on_request(move |_r: &CreateChatCompletionRequest| {
            ra.fetch_add(1, Ordering::Relaxed);
        });
        let l2 = TestLayer::new().on_request(move |_r: &CreateChatCompletionRequest| {
            rb.fetch_add(1, Ordering::Relaxed);
        });

        let inner =
            service_fn(|_req: CreateChatCompletionRequest| async move { Ok::<_, BoxError>(done()) });

        // `l2` wraps the service produced by `l1`, so a single call passes
        // through both taps.
        let mut svc = l2.layer(l1.layer(inner));
        let _ = call_once(&mut svc).await.unwrap();

        assert_eq!(req_a.load(Ordering::Relaxed), 1);
        assert_eq!(req_b.load(Ordering::Relaxed), 1);
    }

    /// Inner service that counts how often `poll_ready` is invoked.
    #[derive(Clone, Default)]
    struct CountingReady {
        calls: Arc<AtomicUsize>,
    }

    impl tower::Service<CreateChatCompletionRequest> for CountingReady {
        type Response = StepOutcome;
        type Error = BoxError;
        type Future = std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: CreateChatCompletionRequest) -> Self::Future {
            Box::pin(async move { Ok::<_, BoxError>(done()) })
        }
    }

    #[tokio::test]
    async fn tap_poll_ready_is_delegated() {
        let inner = CountingReady::default();
        let layer = TestLayer::new();
        let calls = inner.calls.clone();

        let mut svc = layer.layer(inner);
        let _ = call_once(&mut svc).await.unwrap();

        assert!(calls.load(Ordering::Relaxed) >= 1);
    }
}