Skip to main content

systemprompt_logging/layer/
proxy.rs

1use std::io::Write;
2use std::sync::{Arc, OnceLock};
3
4use chrono::Utc;
5use tracing::{Event, Subscriber};
6use tracing_subscriber::Layer;
7use tracing_subscriber::layer::Context;
8use tracing_subscriber::registry::LookupSpan;
9
10use super::DatabaseLayer;
11use super::visitor::{FieldVisitor, SpanContext, SpanFields, SpanVisitor, extract_span_context};
12use crate::models::{LogEntry, LogLevel};
13use systemprompt_database::DbPool;
14use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
15
/// A `tracing` layer that can be installed before a database connection is available.
///
/// A `DatabaseLayer` is attached at most once, via `ProxyDatabaseLayer::attach`. Until then:
/// - new spans only have their correlation fields recorded into span extensions;
/// - events are dropped.
///
/// After attachment, span and event callbacks are forwarded to the attached layer.
#[derive(Clone)]
pub struct ProxyDatabaseLayer {
    /// Write-once slot for the database layer. Because it sits behind an `Arc`, clones of the
    /// proxy share the same slot, so attaching through one clone is visible to all of them.
    inner: Arc<OnceLock<DatabaseLayer>>,
}
20
21impl std::fmt::Debug for ProxyDatabaseLayer {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("ProxyDatabaseLayer")
24            .field("attached", &self.inner.get().is_some())
25            .finish()
26    }
27}
28
29impl Default for ProxyDatabaseLayer {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl ProxyDatabaseLayer {
36    pub fn new() -> Self {
37        Self {
38            inner: Arc::new(OnceLock::new()),
39        }
40    }
41
42    pub fn attach(&self, db_pool: DbPool) {
43        if self.inner.set(DatabaseLayer::new(db_pool)).is_err() {
44            writeln!(
45                std::io::stderr(),
46                "ProxyDatabaseLayer: database layer already attached, ignoring duplicate"
47            )
48            .ok();
49        }
50    }
51}
52
53impl<S> Layer<S> for ProxyDatabaseLayer
54where
55    S: Subscriber + for<'a> LookupSpan<'a>,
56{
57    fn on_new_span(
58        &self,
59        attrs: &tracing::span::Attributes<'_>,
60        id: &tracing::span::Id,
61        ctx: Context<'_, S>,
62    ) {
63        if let Some(db) = self.inner.get() {
64            db.on_new_span(attrs, id, ctx);
65        } else {
66            record_span_fields(attrs, id, &ctx);
67        }
68    }
69
70    fn on_record(
71        &self,
72        id: &tracing::span::Id,
73        values: &tracing::span::Record<'_>,
74        ctx: Context<'_, S>,
75    ) {
76        if let Some(db) = self.inner.get() {
77            db.on_record(id, values, ctx);
78        } else {
79            update_span_fields(id, values, &ctx);
80        }
81    }
82
83    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
84        if let Some(db) = self.inner.get() {
85            db.on_event(event, ctx);
86        }
87    }
88}
89
90pub fn record_span_fields<S>(
91    attrs: &tracing::span::Attributes<'_>,
92    id: &tracing::span::Id,
93    ctx: &Context<'_, S>,
94) where
95    S: Subscriber + for<'a> LookupSpan<'a>,
96{
97    let Some(span) = ctx.span(id) else {
98        return;
99    };
100    let mut fields = SpanFields::default();
101    let mut context = SpanContext::default();
102    let mut visitor = SpanVisitor {
103        context: &mut context,
104    };
105    attrs.record(&mut visitor);
106
107    fields.user = context.user;
108    fields.session = context.session;
109    fields.task = context.task;
110    fields.trace = context.trace;
111    fields.context = context.context;
112    fields.client = context.client;
113
114    let mut extensions = span.extensions_mut();
115    extensions.insert(fields);
116}
117
118pub fn update_span_fields<S>(
119    id: &tracing::span::Id,
120    values: &tracing::span::Record<'_>,
121    ctx: &Context<'_, S>,
122) where
123    S: Subscriber + for<'a> LookupSpan<'a>,
124{
125    if let Some(span) = ctx.span(id) {
126        let mut extensions = span.extensions_mut();
127        if let Some(fields) = extensions.get_mut::<SpanFields>() {
128            let mut context = SpanContext {
129                user: fields.user.clone(),
130                session: fields.session.clone(),
131                task: fields.task.clone(),
132                trace: fields.trace.clone(),
133                context: fields.context.clone(),
134                client: fields.client.clone(),
135            };
136            let mut visitor = SpanVisitor {
137                context: &mut context,
138            };
139            values.record(&mut visitor);
140
141            fields.user = context.user;
142            fields.session = context.session;
143            fields.task = context.task;
144            fields.trace = context.trace;
145            fields.context = context.context;
146            fields.client = context.client;
147        }
148    }
149}
150
151pub fn build_log_entry<S>(event: &Event<'_>, ctx: &Context<'_, S>) -> LogEntry
152where
153    S: Subscriber + for<'a> LookupSpan<'a>,
154{
155    let level = *event.metadata().level();
156    let module = event.metadata().target().to_string();
157
158    let mut visitor = FieldVisitor::default();
159    event.record(&mut visitor);
160
161    let span_context = ctx
162        .current_span()
163        .id()
164        .and_then(|id| ctx.span(id))
165        .map(extract_span_context);
166
167    let log_level = match level {
168        tracing::Level::ERROR => LogLevel::Error,
169        tracing::Level::WARN => LogLevel::Warn,
170        tracing::Level::INFO => LogLevel::Info,
171        tracing::Level::DEBUG => LogLevel::Debug,
172        tracing::Level::TRACE => LogLevel::Trace,
173    };
174
175    LogEntry {
176        id: LogId::generate(),
177        timestamp: Utc::now(),
178        level: log_level,
179        module,
180        message: visitor.message,
181        metadata: visitor.fields,
182        user_id: span_context
183            .as_ref()
184            .and_then(|c| c.user.as_ref())
185            .map_or_else(UserId::system, |s| UserId::new(s.clone())),
186        session_id: span_context
187            .as_ref()
188            .and_then(|c| c.session.as_ref())
189            .map_or_else(SessionId::system, |s| SessionId::new(s.clone())),
190        task_id: span_context
191            .as_ref()
192            .and_then(|c| c.task.as_ref())
193            .map(|s| TaskId::new(s.clone())),
194        trace_id: span_context
195            .as_ref()
196            .and_then(|c| c.trace.as_ref())
197            .map_or_else(TraceId::system, |s| TraceId::new(s.clone())),
198        context_id: span_context
199            .as_ref()
200            .and_then(|c| c.context.as_ref())
201            .and_then(|s| ContextId::try_new(s.clone()).ok()),
202        client_id: span_context
203            .as_ref()
204            .and_then(|c| c.client.as_ref())
205            .map(|s| ClientId::new(s.clone())),
206    }
207}