Skip to main content

systemprompt_logging/layer/
proxy.rs

1use std::io::Write;
2use std::sync::{Arc, OnceLock};
3
4use chrono::Utc;
5use tracing::{Event, Subscriber};
6use tracing_subscriber::Layer;
7use tracing_subscriber::layer::Context;
8use tracing_subscriber::registry::LookupSpan;
9
10use super::DatabaseLayer;
11use super::visitor::{FieldVisitor, SpanContext, SpanFields, SpanVisitor, extract_span_context};
12use crate::models::{LogEntry, LogLevel};
13use systemprompt_database::DbPool;
14use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
15
16#[derive(Clone)]
17pub struct ProxyDatabaseLayer {
18    inner: Arc<OnceLock<DatabaseLayer>>,
19}
20
21impl std::fmt::Debug for ProxyDatabaseLayer {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("ProxyDatabaseLayer")
24            .field("attached", &self.inner.get().is_some())
25            .finish()
26    }
27}
28
29impl Default for ProxyDatabaseLayer {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl ProxyDatabaseLayer {
36    pub fn new() -> Self {
37        Self {
38            inner: Arc::new(OnceLock::new()),
39        }
40    }
41
42    pub fn attach(&self, db_pool: DbPool) {
43        if self.inner.set(DatabaseLayer::new(db_pool)).is_err() {
44            let _ = writeln!(
45                std::io::stderr(),
46                "ProxyDatabaseLayer: database layer already attached, ignoring duplicate"
47            );
48        }
49    }
50}
51
52impl<S> Layer<S> for ProxyDatabaseLayer
53where
54    S: Subscriber + for<'a> LookupSpan<'a>,
55{
56    fn on_new_span(
57        &self,
58        attrs: &tracing::span::Attributes<'_>,
59        id: &tracing::span::Id,
60        ctx: Context<'_, S>,
61    ) {
62        if let Some(db) = self.inner.get() {
63            db.on_new_span(attrs, id, ctx);
64        } else {
65            record_span_fields(attrs, id, &ctx);
66        }
67    }
68
69    fn on_record(
70        &self,
71        id: &tracing::span::Id,
72        values: &tracing::span::Record<'_>,
73        ctx: Context<'_, S>,
74    ) {
75        if let Some(db) = self.inner.get() {
76            db.on_record(id, values, ctx);
77        } else {
78            update_span_fields(id, values, &ctx);
79        }
80    }
81
82    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
83        if let Some(db) = self.inner.get() {
84            db.on_event(event, ctx);
85        }
86    }
87}
88
89pub fn record_span_fields<S>(
90    attrs: &tracing::span::Attributes<'_>,
91    id: &tracing::span::Id,
92    ctx: &Context<'_, S>,
93) where
94    S: Subscriber + for<'a> LookupSpan<'a>,
95{
96    let Some(span) = ctx.span(id) else {
97        return;
98    };
99    let mut fields = SpanFields::default();
100    let mut context = SpanContext::default();
101    let mut visitor = SpanVisitor {
102        context: &mut context,
103    };
104    attrs.record(&mut visitor);
105
106    fields.user = context.user;
107    fields.session = context.session;
108    fields.task = context.task;
109    fields.trace = context.trace;
110    fields.context = context.context;
111    fields.client = context.client;
112
113    let mut extensions = span.extensions_mut();
114    extensions.insert(fields);
115}
116
117pub fn update_span_fields<S>(
118    id: &tracing::span::Id,
119    values: &tracing::span::Record<'_>,
120    ctx: &Context<'_, S>,
121) where
122    S: Subscriber + for<'a> LookupSpan<'a>,
123{
124    if let Some(span) = ctx.span(id) {
125        let mut extensions = span.extensions_mut();
126        if let Some(fields) = extensions.get_mut::<SpanFields>() {
127            let mut context = SpanContext {
128                user: fields.user.clone(),
129                session: fields.session.clone(),
130                task: fields.task.clone(),
131                trace: fields.trace.clone(),
132                context: fields.context.clone(),
133                client: fields.client.clone(),
134            };
135            let mut visitor = SpanVisitor {
136                context: &mut context,
137            };
138            values.record(&mut visitor);
139
140            fields.user = context.user;
141            fields.session = context.session;
142            fields.task = context.task;
143            fields.trace = context.trace;
144            fields.context = context.context;
145            fields.client = context.client;
146        }
147    }
148}
149
150pub fn build_log_entry<S>(event: &Event<'_>, ctx: &Context<'_, S>) -> LogEntry
151where
152    S: Subscriber + for<'a> LookupSpan<'a>,
153{
154    let level = *event.metadata().level();
155    let module = event.metadata().target().to_string();
156
157    let mut visitor = FieldVisitor::default();
158    event.record(&mut visitor);
159
160    let span_context = ctx
161        .current_span()
162        .id()
163        .and_then(|id| ctx.span(id))
164        .map(extract_span_context);
165
166    let log_level = match level {
167        tracing::Level::ERROR => LogLevel::Error,
168        tracing::Level::WARN => LogLevel::Warn,
169        tracing::Level::INFO => LogLevel::Info,
170        tracing::Level::DEBUG => LogLevel::Debug,
171        tracing::Level::TRACE => LogLevel::Trace,
172    };
173
174    LogEntry {
175        id: LogId::generate(),
176        timestamp: Utc::now(),
177        level: log_level,
178        module,
179        message: visitor.message,
180        metadata: visitor.fields,
181        user_id: span_context
182            .as_ref()
183            .and_then(|c| c.user.as_ref())
184            .map_or_else(UserId::system, |s| UserId::new(s.clone())),
185        session_id: span_context
186            .as_ref()
187            .and_then(|c| c.session.as_ref())
188            .map_or_else(SessionId::system, |s| SessionId::new(s.clone())),
189        task_id: span_context
190            .as_ref()
191            .and_then(|c| c.task.as_ref())
192            .map(|s| TaskId::new(s.clone())),
193        trace_id: span_context
194            .as_ref()
195            .and_then(|c| c.trace.as_ref())
196            .map_or_else(TraceId::system, |s| TraceId::new(s.clone())),
197        context_id: span_context
198            .as_ref()
199            .and_then(|c| c.context.as_ref())
200            .map(|s| ContextId::new(s.clone())),
201        client_id: span_context
202            .as_ref()
203            .and_then(|c| c.client.as_ref())
204            .map(|s| ClientId::new(s.clone())),
205    }
206}