Skip to main content

systemprompt_logging/layer/
mod.rs

1#![allow(clippy::print_stderr)]
2
3mod visitor;
4
5use std::time::Duration;
6
7use chrono::Utc;
8use tokio::sync::mpsc;
9use tracing::{Event, Subscriber};
10use tracing_subscriber::layer::Context;
11use tracing_subscriber::registry::LookupSpan;
12use tracing_subscriber::Layer;
13
14use crate::models::{LogEntry, LogLevel};
15use systemprompt_database::DbPool;
16use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
17use visitor::{extract_span_context, FieldVisitor, SpanContext, SpanFields, SpanVisitor};
18
/// Seconds between timer-driven flushes; a non-empty buffer is written out on each tick
/// even if it has not reached [`BUFFER_FLUSH_SIZE`].
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;
/// Number of buffered entries at which the writer task flushes immediately.
const BUFFER_FLUSH_SIZE: usize = 100;
21
22enum LogCommand {
23    Entry(Box<LogEntry>),
24    FlushNow,
25}
26
27pub struct DatabaseLayer {
28    sender: mpsc::UnboundedSender<LogCommand>,
29}
30
31impl std::fmt::Debug for DatabaseLayer {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("DatabaseLayer").finish_non_exhaustive()
34    }
35}
36
37impl DatabaseLayer {
38    pub fn new(db_pool: DbPool) -> Self {
39        let (sender, receiver) = mpsc::unbounded_channel();
40
41        tokio::spawn(Self::batch_writer(db_pool, receiver));
42
43        Self { sender }
44    }
45
46    async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::UnboundedReceiver<LogCommand>) {
47        let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
48        let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
49
50        loop {
51            tokio::select! {
52                Some(command) = receiver.recv() => {
53                    match command {
54                        LogCommand::Entry(entry) => {
55                            buffer.push(*entry);
56                            if buffer.len() >= BUFFER_FLUSH_SIZE {
57                                Self::flush(&db_pool, &mut buffer).await;
58                            }
59                        }
60                        LogCommand::FlushNow => {
61                            if !buffer.is_empty() {
62                                Self::flush(&db_pool, &mut buffer).await;
63                            }
64                        }
65                    }
66                }
67                _ = interval.tick() => {
68                    if !buffer.is_empty() {
69                        Self::flush(&db_pool, &mut buffer).await;
70                    }
71                }
72            }
73        }
74    }
75
76    async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>) {
77        if let Err(e) = Self::batch_insert(db_pool, buffer).await {
78            let msg = e.to_string();
79            if !msg.contains("does not exist") {
80                eprintln!("Failed to flush logs: {e}");
81            }
82        }
83        buffer.clear();
84    }
85
86    async fn batch_insert(db_pool: &DbPool, entries: &[LogEntry]) -> anyhow::Result<()> {
87        let pool = db_pool.write_pool_arc()?;
88        for entry in entries {
89            let metadata_json: Option<String> = entry
90                .metadata
91                .as_ref()
92                .map(serde_json::to_string)
93                .transpose()?;
94
95            let entry_id = entry.id.as_str();
96            let level_str = entry.level.to_string();
97            let user_id = entry.user_id.as_str();
98            let session_id = entry.session_id.as_str();
99            let task_id = entry.task_id.as_ref().map(TaskId::as_str);
100            let trace_id = entry.trace_id.as_str();
101            let context_id = entry.context_id.as_ref().map(ContextId::as_str);
102            let client_id = entry.client_id.as_ref().map(ClientId::as_str);
103
104            sqlx::query!(
105                r"
106                INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
107                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
108                ",
109                entry_id,
110                level_str,
111                entry.module,
112                entry.message,
113                metadata_json,
114                user_id,
115                session_id,
116                task_id,
117                trace_id,
118                context_id,
119                client_id
120            )
121            .execute(pool.as_ref())
122            .await?;
123        }
124
125        Ok(())
126    }
127}
128
129impl<S> Layer<S> for DatabaseLayer
130where
131    S: Subscriber + for<'a> LookupSpan<'a>,
132{
133    fn on_new_span(
134        &self,
135        attrs: &tracing::span::Attributes<'_>,
136        id: &tracing::span::Id,
137        ctx: Context<'_, S>,
138    ) {
139        let Some(span) = ctx.span(id) else {
140            return;
141        };
142        let mut fields = SpanFields::default();
143        let mut context = SpanContext::default();
144        let mut visitor = SpanVisitor {
145            context: &mut context,
146        };
147        attrs.record(&mut visitor);
148
149        fields.user = context.user;
150        fields.session = context.session;
151        fields.task = context.task;
152        fields.trace = context.trace;
153        fields.context = context.context;
154        fields.client = context.client;
155
156        let mut extensions = span.extensions_mut();
157        extensions.insert(fields);
158    }
159
160    fn on_record(
161        &self,
162        id: &tracing::span::Id,
163        values: &tracing::span::Record<'_>,
164        ctx: Context<'_, S>,
165    ) {
166        if let Some(span) = ctx.span(id) {
167            let mut extensions = span.extensions_mut();
168            if let Some(fields) = extensions.get_mut::<SpanFields>() {
169                let mut context = SpanContext {
170                    user: fields.user.clone(),
171                    session: fields.session.clone(),
172                    task: fields.task.clone(),
173                    trace: fields.trace.clone(),
174                    context: fields.context.clone(),
175                    client: fields.client.clone(),
176                };
177                let mut visitor = SpanVisitor {
178                    context: &mut context,
179                };
180                values.record(&mut visitor);
181
182                fields.user = context.user;
183                fields.session = context.session;
184                fields.task = context.task;
185                fields.trace = context.trace;
186                fields.context = context.context;
187                fields.client = context.client;
188            }
189        }
190    }
191
192    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
193        let level = *event.metadata().level();
194        let module = event.metadata().target().to_string();
195
196        let mut visitor = FieldVisitor::default();
197        event.record(&mut visitor);
198
199        let span_context = ctx
200            .current_span()
201            .id()
202            .and_then(|id| ctx.span(id))
203            .map(extract_span_context);
204
205        let log_level = match level {
206            tracing::Level::ERROR => LogLevel::Error,
207            tracing::Level::WARN => LogLevel::Warn,
208            tracing::Level::INFO => LogLevel::Info,
209            tracing::Level::DEBUG => LogLevel::Debug,
210            tracing::Level::TRACE => LogLevel::Trace,
211        };
212
213        let is_error = log_level == LogLevel::Error;
214
215        let entry = LogEntry {
216            id: LogId::generate(),
217            timestamp: Utc::now(),
218            level: log_level,
219            module,
220            message: visitor.message,
221            metadata: visitor.fields,
222            user_id: span_context
223                .as_ref()
224                .and_then(|c| c.user.as_ref())
225                .map_or_else(UserId::system, |s| UserId::new(s.clone())),
226            session_id: span_context
227                .as_ref()
228                .and_then(|c| c.session.as_ref())
229                .map_or_else(SessionId::system, |s| SessionId::new(s.clone())),
230            task_id: span_context
231                .as_ref()
232                .and_then(|c| c.task.as_ref())
233                .map(|s| TaskId::new(s.clone())),
234            trace_id: span_context
235                .as_ref()
236                .and_then(|c| c.trace.as_ref())
237                .map_or_else(TraceId::system, |s| TraceId::new(s.clone())),
238            context_id: span_context
239                .as_ref()
240                .and_then(|c| c.context.as_ref())
241                .map(|s| ContextId::new(s.clone())),
242            client_id: span_context
243                .as_ref()
244                .and_then(|c| c.client.as_ref())
245                .map(|s| ClientId::new(s.clone())),
246        };
247
248        let _ = self.sender.send(LogCommand::Entry(Box::new(entry)));
249
250        if is_error {
251            let _ = self.sender.send(LogCommand::FlushNow);
252        }
253    }
254}