systemprompt_logging/layer/
proxy.rs1use std::io::Write;
2use std::sync::{Arc, OnceLock};
3
4use chrono::Utc;
5use tracing::{Event, Subscriber};
6use tracing_subscriber::Layer;
7use tracing_subscriber::layer::Context;
8use tracing_subscriber::registry::LookupSpan;
9
10use super::DatabaseLayer;
11use super::visitor::{FieldVisitor, SpanContext, SpanFields, SpanVisitor, extract_span_context};
12use crate::models::{LogEntry, LogLevel};
13use systemprompt_database::DbPool;
14use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
15
16#[derive(Clone)]
17pub struct ProxyDatabaseLayer {
18 inner: Arc<OnceLock<DatabaseLayer>>,
19}
20
21impl std::fmt::Debug for ProxyDatabaseLayer {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("ProxyDatabaseLayer")
24 .field("attached", &self.inner.get().is_some())
25 .finish()
26 }
27}
28
29impl Default for ProxyDatabaseLayer {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl ProxyDatabaseLayer {
36 pub fn new() -> Self {
37 Self {
38 inner: Arc::new(OnceLock::new()),
39 }
40 }
41
42 pub fn attach(&self, db_pool: DbPool) {
43 if self.inner.set(DatabaseLayer::new(db_pool)).is_err() {
44 writeln!(
45 std::io::stderr(),
46 "ProxyDatabaseLayer: database layer already attached, ignoring duplicate"
47 )
48 .ok();
49 }
50 }
51}
52
53impl<S> Layer<S> for ProxyDatabaseLayer
54where
55 S: Subscriber + for<'a> LookupSpan<'a>,
56{
57 fn on_new_span(
58 &self,
59 attrs: &tracing::span::Attributes<'_>,
60 id: &tracing::span::Id,
61 ctx: Context<'_, S>,
62 ) {
63 if let Some(db) = self.inner.get() {
64 db.on_new_span(attrs, id, ctx);
65 } else {
66 record_span_fields(attrs, id, &ctx);
67 }
68 }
69
70 fn on_record(
71 &self,
72 id: &tracing::span::Id,
73 values: &tracing::span::Record<'_>,
74 ctx: Context<'_, S>,
75 ) {
76 if let Some(db) = self.inner.get() {
77 db.on_record(id, values, ctx);
78 } else {
79 update_span_fields(id, values, &ctx);
80 }
81 }
82
83 fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
84 if let Some(db) = self.inner.get() {
85 db.on_event(event, ctx);
86 }
87 }
88}
89
90pub fn record_span_fields<S>(
91 attrs: &tracing::span::Attributes<'_>,
92 id: &tracing::span::Id,
93 ctx: &Context<'_, S>,
94) where
95 S: Subscriber + for<'a> LookupSpan<'a>,
96{
97 let Some(span) = ctx.span(id) else {
98 return;
99 };
100 let mut fields = SpanFields::default();
101 let mut context = SpanContext::default();
102 let mut visitor = SpanVisitor {
103 context: &mut context,
104 };
105 attrs.record(&mut visitor);
106
107 fields.user = context.user;
108 fields.session = context.session;
109 fields.task = context.task;
110 fields.trace = context.trace;
111 fields.context = context.context;
112 fields.client = context.client;
113
114 let mut extensions = span.extensions_mut();
115 extensions.insert(fields);
116}
117
118pub fn update_span_fields<S>(
119 id: &tracing::span::Id,
120 values: &tracing::span::Record<'_>,
121 ctx: &Context<'_, S>,
122) where
123 S: Subscriber + for<'a> LookupSpan<'a>,
124{
125 if let Some(span) = ctx.span(id) {
126 let mut extensions = span.extensions_mut();
127 if let Some(fields) = extensions.get_mut::<SpanFields>() {
128 let mut context = SpanContext {
129 user: fields.user.clone(),
130 session: fields.session.clone(),
131 task: fields.task.clone(),
132 trace: fields.trace.clone(),
133 context: fields.context.clone(),
134 client: fields.client.clone(),
135 };
136 let mut visitor = SpanVisitor {
137 context: &mut context,
138 };
139 values.record(&mut visitor);
140
141 fields.user = context.user;
142 fields.session = context.session;
143 fields.task = context.task;
144 fields.trace = context.trace;
145 fields.context = context.context;
146 fields.client = context.client;
147 }
148 }
149}
150
151pub fn build_log_entry<S>(event: &Event<'_>, ctx: &Context<'_, S>) -> LogEntry
152where
153 S: Subscriber + for<'a> LookupSpan<'a>,
154{
155 let level = *event.metadata().level();
156 let module = event.metadata().target().to_string();
157
158 let mut visitor = FieldVisitor::default();
159 event.record(&mut visitor);
160
161 let span_context = ctx
162 .current_span()
163 .id()
164 .and_then(|id| ctx.span(id))
165 .map(extract_span_context);
166
167 let log_level = match level {
168 tracing::Level::ERROR => LogLevel::Error,
169 tracing::Level::WARN => LogLevel::Warn,
170 tracing::Level::INFO => LogLevel::Info,
171 tracing::Level::DEBUG => LogLevel::Debug,
172 tracing::Level::TRACE => LogLevel::Trace,
173 };
174
175 LogEntry {
176 id: LogId::generate(),
177 timestamp: Utc::now(),
178 level: log_level,
179 module,
180 message: visitor.message,
181 metadata: visitor.fields,
182 user_id: span_context
183 .as_ref()
184 .and_then(|c| c.user.as_ref())
185 .map_or_else(UserId::system, |s| UserId::new(s.clone())),
186 session_id: span_context
187 .as_ref()
188 .and_then(|c| c.session.as_ref())
189 .map_or_else(SessionId::system, |s| SessionId::new(s.clone())),
190 task_id: span_context
191 .as_ref()
192 .and_then(|c| c.task.as_ref())
193 .map(|s| TaskId::new(s.clone())),
194 trace_id: span_context
195 .as_ref()
196 .and_then(|c| c.trace.as_ref())
197 .map_or_else(TraceId::system, |s| TraceId::new(s.clone())),
198 context_id: span_context
199 .as_ref()
200 .and_then(|c| c.context.as_ref())
201 .and_then(|s| ContextId::try_new(s.clone()).ok()),
202 client_id: span_context
203 .as_ref()
204 .and_then(|c| c.client.as_ref())
205 .map(|s| ClientId::new(s.clone())),
206 }
207}