1use std::future::Future;
29use std::pin::Pin;
30
31use async_openai::types::CreateChatCompletionRequest;
32use tower::{BoxError, Layer, Service, ServiceExt};
33use tracing::{info, info_span, Instrument};
34
35use crate::core::StepOutcome;
36
/// Token counts, split into prompt and completion tokens.
///
/// `Default` yields zero for both counters. The type is `Copy` and comparable so values can be
/// passed around and asserted on directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Usage {
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Number of completion tokens.
    pub completion_tokens: usize,
}
42
/// A single named metric observation handed to a metrics collector service.
///
/// Records are comparable so collectors and tests can match on exact values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MetricRecord {
    /// A counter observation identified by `name`.
    Counter { name: &'static str, value: u64 },
    /// A histogram observation identified by `name`.
    Histogram { name: &'static str, value: u64 },
}
48
49pub trait MetricsCollector: Service<MetricRecord, Response = (), Error = BoxError> {}
50impl<T> MetricsCollector for T where T: Service<MetricRecord, Response = (), Error = BoxError> {}
51
/// Stateless layer marker type.
///
/// It carries no data, so it is freely copyable; `new()` and `default()` produce the same
/// value.
#[derive(Debug, Clone, Copy, Default)]
pub struct TracingLayer;

impl TracingLayer {
    /// Creates the layer. Equivalent to [`TracingLayer::default`].
    pub fn new() -> Self {
        Self
    }
}
64
/// Wrapper that holds an inner service.
///
/// `Clone` and `Debug` are available whenever the inner service provides them, which lets the
/// wrapper be used where a cloneable service is required.
#[derive(Debug, Clone)]
pub struct Tracing<S> {
    /// The wrapped service.
    inner: S,
}
68
69impl<S> Layer<S> for TracingLayer {
70 type Service = Tracing<S>;
71 fn layer(&self, inner: S) -> Self::Service {
72 Tracing { inner }
73 }
74}
75
76impl<S> Service<CreateChatCompletionRequest> for Tracing<S>
77where
78 S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
79 + Send
80 + 'static,
81 S::Future: Send + 'static,
82{
83 type Response = StepOutcome;
84 type Error = BoxError;
85 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
86
87 fn poll_ready(
88 &mut self,
89 cx: &mut std::task::Context<'_>,
90 ) -> std::task::Poll<Result<(), Self::Error>> {
91 self.inner.poll_ready(cx)
92 }
93
94 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
95 let model = req.model.clone();
96 let span = info_span!("step", model = %model);
97 let fut = self.inner.call(req).instrument(span);
98 Box::pin(async move {
99 let out = fut.await?;
100 match &out {
101 StepOutcome::Next {
102 aux, invoked_tools, ..
103 } => {
104 info!(prompt = aux.prompt_tokens, completion = aux.completion_tokens, tools = aux.tool_invocations, invoked = ?invoked_tools, "step next")
105 }
106 StepOutcome::Done { aux, .. } => info!(
107 prompt = aux.prompt_tokens,
108 completion = aux.completion_tokens,
109 tools = aux.tool_invocations,
110 "step done"
111 ),
112 }
113 Ok(out)
114 })
115 }
116}
117
/// Layer that stores a collector value of type `C`.
///
/// `Clone` and `Debug` are available whenever the collector provides them.
#[derive(Debug, Clone)]
pub struct MetricsLayer<C> {
    /// The stored collector.
    collector: C,
}

impl<C> MetricsLayer<C> {
    /// Creates a layer that stores `collector`.
    pub fn new(collector: C) -> Self {
        Self { collector }
    }
}
127
/// Wrapper that pairs an inner service with a collector.
///
/// `Clone` and `Debug` are available whenever both the inner service and the collector provide
/// them.
#[derive(Debug, Clone)]
pub struct Metrics<S, C> {
    /// The wrapped service.
    inner: S,
    /// The collector paired with it.
    collector: C,
}
132
133impl<S, C> Layer<S> for MetricsLayer<C>
134where
135 C: Clone,
136{
137 type Service = Metrics<S, C>;
138 fn layer(&self, inner: S) -> Self::Service {
139 Metrics {
140 inner,
141 collector: self.collector.clone(),
142 }
143 }
144}
145
146impl<S, C> Service<CreateChatCompletionRequest> for Metrics<S, C>
147where
148 S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
149 + Send
150 + 'static,
151 S::Future: Send + 'static,
152 C: MetricsCollector + Clone + Send + 'static,
153 C::Future: Send + 'static,
154{
155 type Response = StepOutcome;
156 type Error = BoxError;
157 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
158
159 fn poll_ready(
160 &mut self,
161 cx: &mut std::task::Context<'_>,
162 ) -> std::task::Poll<Result<(), Self::Error>> {
163 self.inner.poll_ready(cx)
164 }
165
166 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
167 let mut collector = self.collector.clone();
168 let fut = self.inner.call(req);
169 Box::pin(async move {
170 let out = fut.await?;
171 let (prompt, completion, tools) = match &out {
172 StepOutcome::Next { aux, .. } | StepOutcome::Done { aux, .. } => (
173 aux.prompt_tokens,
174 aux.completion_tokens,
175 aux.tool_invocations,
176 ),
177 };
178 let _ = ServiceExt::ready(&mut collector)
179 .await?
180 .call(MetricRecord::Counter {
181 name: "prompt_tokens",
182 value: prompt as u64,
183 })
184 .await;
185 let _ = ServiceExt::ready(&mut collector)
186 .await?
187 .call(MetricRecord::Counter {
188 name: "completion_tokens",
189 value: completion as u64,
190 })
191 .await;
192 let _ = ServiceExt::ready(&mut collector)
193 .await?
194 .call(MetricRecord::Counter {
195 name: "tool_invocations",
196 value: tools as u64,
197 })
198 .await;
199 Ok(out)
200 })
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tower::service_fn;

    /// Drives one request through a `Metrics`-wrapped stub service and checks that the
    /// collector received a counter for each of the three usage figures.
    #[tokio::test]
    async fn metrics_layer_updates_collector() {
        // Stub step service: always finishes with fixed usage numbers.
        let inner = service_fn(|_req: CreateChatCompletionRequest| async move {
            let aux = crate::core::StepAux {
                prompt_tokens: 3,
                completion_tokens: 7,
                tool_invocations: 1,
            };
            Ok::<_, BoxError>(StepOutcome::Done {
                messages: vec![],
                aux,
            })
        });

        // Collector that appends every counter it sees to a shared vector.
        let recorded = Arc::new(Mutex::new(Vec::<(&'static str, u64)>::new()));
        let collector = {
            let shared = Arc::clone(&recorded);
            service_fn(move |rec: MetricRecord| {
                let sink = Arc::clone(&shared);
                async move {
                    if let MetricRecord::Counter { name, value } = rec {
                        sink.lock().await.push((name, value));
                    }
                    Ok::<(), BoxError>(())
                }
            })
        };

        let mut svc = MetricsLayer::new(collector).layer(inner);
        let req = CreateChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            ..Default::default()
        };

        let ready = ServiceExt::<CreateChatCompletionRequest>::ready(&mut svc)
            .await
            .unwrap();
        let _ = ready.call(req).await.unwrap();

        let data = recorded.lock().await.clone();
        for expected in [
            ("prompt_tokens", 3),
            ("completion_tokens", 7),
            ("tool_invocations", 1),
        ] {
            assert!(data.contains(&expected));
        }
    }
}