systemprompt_logging/layer/
mod.rs1#![allow(clippy::print_stderr)]
2
3mod visitor;
4
5use std::time::Duration;
6
7use chrono::Utc;
8use tokio::sync::mpsc;
9use tracing::{Event, Subscriber};
10use tracing_subscriber::layer::Context;
11use tracing_subscriber::registry::LookupSpan;
12use tracing_subscriber::Layer;
13
14use crate::models::{LogEntry, LogLevel};
15use systemprompt_database::DbPool;
16use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
17use visitor::{extract_span_context, FieldVisitor, SpanContext, SpanFields, SpanVisitor};
18
/// Period, in seconds, of the timer on which the batch writer flushes any
/// non-empty buffer.
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;
/// Number of buffered entries at which the batch writer flushes immediately
/// instead of waiting for the timer.
const BUFFER_FLUSH_SIZE: usize = 100;
21
22enum LogCommand {
23 Entry(Box<LogEntry>),
24 FlushNow,
25}
26
27pub struct DatabaseLayer {
28 sender: mpsc::UnboundedSender<LogCommand>,
29}
30
31impl std::fmt::Debug for DatabaseLayer {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("DatabaseLayer").finish_non_exhaustive()
34 }
35}
36
37impl DatabaseLayer {
38 pub fn new(db_pool: DbPool) -> Self {
39 let (sender, receiver) = mpsc::unbounded_channel();
40
41 tokio::spawn(Self::batch_writer(db_pool, receiver));
42
43 Self { sender }
44 }
45
46 async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::UnboundedReceiver<LogCommand>) {
47 let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
48 let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
49
50 loop {
51 tokio::select! {
52 Some(command) = receiver.recv() => {
53 match command {
54 LogCommand::Entry(entry) => {
55 buffer.push(*entry);
56 if buffer.len() >= BUFFER_FLUSH_SIZE {
57 Self::flush(&db_pool, &mut buffer).await;
58 }
59 }
60 LogCommand::FlushNow => {
61 if !buffer.is_empty() {
62 Self::flush(&db_pool, &mut buffer).await;
63 }
64 }
65 }
66 }
67 _ = interval.tick() => {
68 if !buffer.is_empty() {
69 Self::flush(&db_pool, &mut buffer).await;
70 }
71 }
72 }
73 }
74 }
75
76 async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>) {
77 if let Err(e) = Self::batch_insert(db_pool, buffer).await {
78 eprintln!("Failed to flush logs: {e}");
79 }
80 buffer.clear();
81 }
82
83 async fn batch_insert(db_pool: &DbPool, entries: &[LogEntry]) -> anyhow::Result<()> {
84 let pool = db_pool.pool_arc()?;
85 for entry in entries {
86 let metadata_json: Option<String> = entry
87 .metadata
88 .as_ref()
89 .map(serde_json::to_string)
90 .transpose()?;
91
92 let entry_id = entry.id.as_str();
93 let level_str = entry.level.to_string();
94 let user_id = entry.user_id.as_str();
95 let session_id = entry.session_id.as_str();
96 let task_id = entry.task_id.as_ref().map(TaskId::as_str);
97 let trace_id = entry.trace_id.as_str();
98 let context_id = entry.context_id.as_ref().map(ContextId::as_str);
99 let client_id = entry.client_id.as_ref().map(ClientId::as_str);
100
101 sqlx::query!(
102 r"
103 INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
104 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
105 ",
106 entry_id,
107 level_str,
108 entry.module,
109 entry.message,
110 metadata_json,
111 user_id,
112 session_id,
113 task_id,
114 trace_id,
115 context_id,
116 client_id
117 )
118 .execute(pool.as_ref())
119 .await?;
120 }
121
122 Ok(())
123 }
124}
125
126impl<S> Layer<S> for DatabaseLayer
127where
128 S: Subscriber + for<'a> LookupSpan<'a>,
129{
130 fn on_new_span(
131 &self,
132 attrs: &tracing::span::Attributes<'_>,
133 id: &tracing::span::Id,
134 ctx: Context<'_, S>,
135 ) {
136 let Some(span) = ctx.span(id) else {
137 return;
138 };
139 let mut fields = SpanFields::default();
140 let mut context = SpanContext::default();
141 let mut visitor = SpanVisitor {
142 context: &mut context,
143 };
144 attrs.record(&mut visitor);
145
146 fields.user = context.user;
147 fields.session = context.session;
148 fields.task = context.task;
149 fields.trace = context.trace;
150 fields.context = context.context;
151 fields.client = context.client;
152
153 let mut extensions = span.extensions_mut();
154 extensions.insert(fields);
155 }
156
157 fn on_record(
158 &self,
159 id: &tracing::span::Id,
160 values: &tracing::span::Record<'_>,
161 ctx: Context<'_, S>,
162 ) {
163 if let Some(span) = ctx.span(id) {
164 let mut extensions = span.extensions_mut();
165 if let Some(fields) = extensions.get_mut::<SpanFields>() {
166 let mut context = SpanContext {
167 user: fields.user.clone(),
168 session: fields.session.clone(),
169 task: fields.task.clone(),
170 trace: fields.trace.clone(),
171 context: fields.context.clone(),
172 client: fields.client.clone(),
173 };
174 let mut visitor = SpanVisitor {
175 context: &mut context,
176 };
177 values.record(&mut visitor);
178
179 fields.user = context.user;
180 fields.session = context.session;
181 fields.task = context.task;
182 fields.trace = context.trace;
183 fields.context = context.context;
184 fields.client = context.client;
185 }
186 }
187 }
188
189 fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
190 let level = *event.metadata().level();
191 let module = event.metadata().target().to_string();
192
193 let mut visitor = FieldVisitor::default();
194 event.record(&mut visitor);
195
196 let span_context = ctx
197 .current_span()
198 .id()
199 .and_then(|id| ctx.span(id))
200 .map(extract_span_context);
201
202 let log_level = match level {
203 tracing::Level::ERROR => LogLevel::Error,
204 tracing::Level::WARN => LogLevel::Warn,
205 tracing::Level::INFO => LogLevel::Info,
206 tracing::Level::DEBUG => LogLevel::Debug,
207 tracing::Level::TRACE => LogLevel::Trace,
208 };
209
210 let is_error = log_level == LogLevel::Error;
211
212 let entry = LogEntry {
213 id: LogId::generate(),
214 timestamp: Utc::now(),
215 level: log_level,
216 module,
217 message: visitor.message,
218 metadata: visitor.fields,
219 user_id: span_context
220 .as_ref()
221 .and_then(|c| c.user.as_ref())
222 .map_or_else(UserId::system, |s| UserId::new(s.clone())),
223 session_id: span_context
224 .as_ref()
225 .and_then(|c| c.session.as_ref())
226 .map_or_else(SessionId::system, |s| SessionId::new(s.clone())),
227 task_id: span_context
228 .as_ref()
229 .and_then(|c| c.task.as_ref())
230 .map(|s| TaskId::new(s.clone())),
231 trace_id: span_context
232 .as_ref()
233 .and_then(|c| c.trace.as_ref())
234 .map_or_else(TraceId::system, |s| TraceId::new(s.clone())),
235 context_id: span_context
236 .as_ref()
237 .and_then(|c| c.context.as_ref())
238 .map(|s| ContextId::new(s.clone())),
239 client_id: span_context
240 .as_ref()
241 .and_then(|c| c.client.as_ref())
242 .map(|s| ClientId::new(s.clone())),
243 };
244
245 let _ = self.sender.send(LogCommand::Entry(Box::new(entry)));
246
247 if is_error {
248 let _ = self.sender.send(LogCommand::FlushNow);
249 }
250 }
251}