sayr_engine/
telemetry.rs

1use std::sync::{Arc, Mutex};
2use std::time::{Duration, SystemTime};
3
4use opentelemetry::global;
5use opentelemetry::trace::{Span, SpanKind, Tracer};
6use opentelemetry::KeyValue;
7use opentelemetry_otlp::WithExportConfig;
8use opentelemetry_sdk;
9use serde::{Deserialize, Serialize};
10use serde_json;
11use tokio::time::sleep;
12use tracing::{span, Level};
13use tracing_subscriber::layer::SubscriberExt;
14use tracing_subscriber::util::SubscriberInitExt;
15use tracing_subscriber::{EnvFilter, Registry};
16
17use crate::error::{AgnoError, Result};
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
20pub struct TelemetryLabels {
21    pub tenant: Option<String>,
22    pub tool: Option<String>,
23    pub workflow: Option<String>,
24}
25
26impl TelemetryLabels {
27    pub fn with_tenant(mut self, tenant: impl Into<String>) -> Self {
28        self.tenant = Some(tenant.into());
29        self
30    }
31
32    pub fn with_tool(mut self, tool: impl Into<String>) -> Self {
33        self.tool = Some(tool.into());
34        self
35    }
36
37    pub fn with_workflow(mut self, workflow: impl Into<String>) -> Self {
38        self.workflow = Some(workflow.into());
39        self
40    }
41
42    pub fn as_attributes(&self) -> Vec<KeyValue> {
43        let mut attrs = Vec::new();
44        if let Some(tenant) = &self.tenant {
45            attrs.push(KeyValue::new("tenant", tenant.clone()));
46        }
47        if let Some(tool) = &self.tool {
48            attrs.push(KeyValue::new("tool", tool.clone()));
49        }
50        if let Some(workflow) = &self.workflow {
51            attrs.push(KeyValue::new("workflow", workflow.clone()));
52        }
53        attrs
54    }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TelemetryEvent {
59    pub kind: String,
60    pub timestamp: SystemTime,
61    pub detail: serde_json::Value,
62    pub labels: TelemetryLabels,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct FailureRecord {
67    pub context: String,
68    pub error: String,
69    pub attempt: u32,
70    pub labels: TelemetryLabels,
71}
72
73#[derive(Default, Clone)]
74pub struct TelemetryCollector {
75    events: Arc<Mutex<Vec<TelemetryEvent>>>,
76    failures: Arc<Mutex<Vec<FailureRecord>>>,
77}
78
79impl TelemetryCollector {
80    pub fn record(
81        &self,
82        kind: impl Into<String>,
83        detail: serde_json::Value,
84        labels: TelemetryLabels,
85    ) {
86        self.events.lock().unwrap().push(TelemetryEvent {
87            kind: kind.into(),
88            timestamp: SystemTime::now(),
89            detail,
90            labels,
91        });
92    }
93
94    pub fn record_failure(
95        &self,
96        context: impl Into<String>,
97        error: impl Into<String>,
98        attempt: u32,
99        labels: TelemetryLabels,
100    ) {
101        self.failures.lock().unwrap().push(FailureRecord {
102            context: context.into(),
103            error: error.into(),
104            attempt,
105            labels,
106        });
107    }
108
109    pub fn drain(&self) -> (Vec<TelemetryEvent>, Vec<FailureRecord>) {
110        let mut events = self.events.lock().unwrap();
111        let mut failures = self.failures.lock().unwrap();
112        (std::mem::take(&mut *events), std::mem::take(&mut *failures))
113    }
114}
115
116#[derive(Default, Clone)]
117pub struct TelemetrySink {
118    buffer: Arc<Mutex<Vec<TelemetryEvent>>>,
119}
120
121impl TelemetrySink {
122    pub fn push(&self, event: TelemetryEvent) {
123        self.buffer.lock().unwrap().push(event);
124    }
125
126    pub fn flush(&self) -> Vec<TelemetryEvent> {
127        let mut guard = self.buffer.lock().unwrap();
128        std::mem::take(&mut *guard)
129    }
130}
131
/// Retry configuration with linearly growing backoff.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Number of retries allowed after the initial attempt.
    pub max_retries: u32,
    /// Base delay; the wait after zero-based attempt `n` is `backoff * (n + 1)`.
    pub backoff: Duration,
}
137
138impl RetryPolicy {
139    pub fn default_external_call() -> Self {
140        Self {
141            max_retries: 3,
142            backoff: Duration::from_millis(200),
143        }
144    }
145
146    pub async fn retry<F, Fut, T>(
147        &self,
148        mut f: F,
149        telemetry: Option<&TelemetryCollector>,
150        labels: TelemetryLabels,
151    ) -> Result<T>
152    where
153        F: FnMut(u32) -> Fut,
154        Fut: std::future::Future<Output = Result<T>>,
155    {
156        for attempt in 0..=self.max_retries {
157            match f(attempt).await {
158                Ok(value) => return Ok(value),
159                Err(err) => {
160                    if let Some(t) = telemetry {
161                        t.record_failure("retry", format!("{err}"), attempt, labels.clone());
162                    }
163                    let span = span!(
164                        Level::INFO,
165                        "retry_failure",
166                        attempt,
167                        tenant = labels.tenant.as_deref().unwrap_or(""),
168                        tool = labels.tool.as_deref().unwrap_or(""),
169                        workflow = labels.workflow.as_deref().unwrap_or("")
170                    );
171                    let _enter = span.enter();
172                    tracing::warn!("retry attempt {} failed: {}", attempt, err);
173                    if attempt == self.max_retries {
174                        return Err(err);
175                    }
176                    sleep(self.backoff * (attempt + 1)).await;
177                }
178            }
179        }
180        Err(AgnoError::Protocol("retry exhausted".into()))
181    }
182}
183
184#[derive(Clone)]
185pub struct FallbackChain<T> {
186    steps: Vec<(String, Arc<dyn Fn() -> Result<T> + Send + Sync>)>,
187}
188
189impl<T> std::fmt::Debug for FallbackChain<T> {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        let labels: Vec<&str> = self.steps.iter().map(|(label, _)| label.as_str()).collect();
192        f.debug_struct("FallbackChain")
193            .field("steps", &labels)
194            .finish()
195    }
196}
197
198impl<T> FallbackChain<T> {
199    pub fn new() -> Self {
200        Self { steps: Vec::new() }
201    }
202
203    pub fn with_step(
204        mut self,
205        label: impl Into<String>,
206        handler: impl Fn() -> Result<T> + Send + Sync + 'static,
207    ) -> Self {
208        self.steps.push((label.into(), Arc::new(handler)));
209        self
210    }
211
212    pub fn execute(
213        &self,
214        telemetry: Option<&TelemetryCollector>,
215        labels: TelemetryLabels,
216    ) -> Result<T> {
217        let mut last_error: Option<AgnoError> = None;
218        for (label, handler) in self.steps.iter() {
219            let span = span!(
220                Level::DEBUG,
221                "fallback_step",
222                step = label.as_str(),
223                tenant = labels.tenant.as_deref().unwrap_or(""),
224                tool = labels.tool.as_deref().unwrap_or(""),
225                workflow = labels.workflow.as_deref().unwrap_or("")
226            );
227            let _guard = span.enter();
228            match handler() {
229                Ok(value) => {
230                    if let Some(t) = telemetry {
231                        t.record(
232                            "fallback_success",
233                            serde_json::json!({ "step": label }),
234                            labels.clone(),
235                        );
236                    }
237                    tracing::info!("fallback step succeeded");
238                    return Ok(value);
239                }
240                Err(err) => {
241                    if let Some(t) = telemetry {
242                        t.record_failure(label.clone(), format!("{err}"), 0, labels.clone());
243                    }
244                    tracing::warn!("fallback step failed: {}", err);
245                    last_error = Some(err);
246                }
247            }
248        }
249        Err(last_error.unwrap_or_else(|| AgnoError::Protocol("fallback exhausted".into())))
250    }
251}
252
253pub fn span_with_labels(_name: &str, labels: &TelemetryLabels) -> tracing::Span {
254    span!(
255        Level::INFO,
256        "labeled_span",
257        tenant = labels.tenant.as_deref().unwrap_or(""),
258        tool = labels.tool.as_deref().unwrap_or(""),
259        workflow = labels.workflow.as_deref().unwrap_or("")
260    )
261}
262
263pub fn init_tracing(service_name: &str, otlp_endpoint: Option<&str>) -> Result<()> {
264    let trace_config = opentelemetry_sdk::trace::config().with_resource(
265        opentelemetry_sdk::Resource::new(vec![KeyValue::new(
266            "service.name",
267            service_name.to_owned(),
268        )]),
269    );
270
271    let tracer = if let Some(endpoint) = otlp_endpoint {
272        opentelemetry_otlp::new_pipeline()
273            .tracing()
274            .with_trace_config(trace_config)
275            .with_exporter(
276                opentelemetry_otlp::new_exporter()
277                    .tonic()
278                    .with_endpoint(endpoint),
279            )
280            .install_batch(opentelemetry_sdk::runtime::Tokio)
281            .map_err(|e| AgnoError::Telemetry(e.to_string()))?
282    } else {
283        opentelemetry_otlp::new_pipeline()
284            .tracing()
285            .with_trace_config(trace_config)
286            .with_exporter(opentelemetry_otlp::new_exporter().tonic())
287            .install_batch(opentelemetry_sdk::runtime::Tokio)
288            .map_err(|e| AgnoError::Telemetry(e.to_string()))?
289    };
290
291    let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
292    let fmt_layer = tracing_subscriber::fmt::layer().json().with_target(true);
293    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
294    Registry::default()
295        .with(env_filter)
296        .with(fmt_layer)
297        .with(telemetry)
298        .try_init()
299        .map_err(|e| AgnoError::Telemetry(format!("failed to init tracing: {e}")))?;
300    Ok(())
301}
302
303pub fn current_span_attributes(labels: &TelemetryLabels) {
304    let tracer = global::tracer("agno-tracer");
305    let mut span = tracer
306        .span_builder("context")
307        .with_kind(SpanKind::Internal)
308        .with_attributes(labels.as_attributes())
309        .start(&tracer);
310    span.add_event("context attached".to_string(), labels.as_attributes());
311    span.end();
312}
313
314pub fn flush_tracer() {
315    // OpenTelemetry 0.22 doesn't expose a flush method; shutdown flushes internally
316    global::shutdown_tracer_provider();
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fully populated label set shared by the tests below.
    fn sample_labels(tool: &str) -> TelemetryLabels {
        TelemetryLabels::default()
            .with_tenant("tenant-a")
            .with_tool(tool)
            .with_workflow("test")
    }

    /// The first call fails and the second succeeds, so exactly one failure
    /// must be recorded with the labels that were passed in.
    #[tokio::test]
    async fn retries_until_success() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let labels = sample_labels("retry");
        let telemetry = TelemetryCollector::default();
        let policy = RetryPolicy {
            max_retries: 2,
            backoff: Duration::from_millis(1),
        };
        let calls = Arc::new(AtomicU32::new(0));

        let outcome = policy
            .retry(
                |_attempt: u32| {
                    let calls = Arc::clone(&calls);
                    async move {
                        let seen = calls.fetch_add(1, Ordering::SeqCst) + 1;
                        if seen < 2 {
                            Err(AgnoError::Protocol("fail".into()))
                        } else {
                            Ok(42)
                        }
                    }
                },
                Some(&telemetry),
                labels.clone(),
            )
            .await;

        assert_eq!(outcome.unwrap(), 42);
        let (_events, failures) = telemetry.drain();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].labels, labels);
    }

    /// The primary step fails and the secondary succeeds, so the chain yields
    /// the secondary's value and records one failure.
    #[test]
    fn runs_fallbacks() {
        let labels = sample_labels("fallback");
        let telemetry = TelemetryCollector::default();
        let chain = FallbackChain::new()
            .with_step("primary", || Err(AgnoError::Protocol("nope".into())))
            .with_step("secondary", || Ok("ok"));

        let value = chain.execute(Some(&telemetry), labels.clone()).unwrap();

        assert_eq!(value, "ok");
        let (_events, failures) = telemetry.drain();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].labels, labels);
    }
}