systemprompt_logging/layer/
mod.rs1mod proxy;
8mod visitor;
9
10use std::io::Write;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::Duration;
14
15use tokio::sync::mpsc;
16use tracing::{Event, Subscriber};
17use tracing_subscriber::Layer;
18use tracing_subscriber::layer::Context;
19use tracing_subscriber::registry::LookupSpan;
20
21pub use proxy::ProxyDatabaseLayer;
22use proxy::{build_log_entry, record_span_fields, update_span_fields};
23
24use crate::models::{LogEntry, LogLevel};
25use systemprompt_database::DbPool;
26use systemprompt_identifiers::{ClientId, ContextId, TaskId};
27
/// Buffered-entry count at which the writer flushes without waiting for the timer.
const BUFFER_FLUSH_SIZE: usize = 100;
/// Period, in seconds, of the timer that flushes a partially filled buffer.
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;

/// Number of commands that may be queued between the layer and the writer
/// task; sends beyond this are rejected rather than awaited.
const CHANNEL_CAPACITY: usize = 8 * 1024;
35
36enum LogCommand {
37 Entry(Box<LogEntry>),
38 FlushNow,
39}
40
41struct LogChannel {
45 sender: mpsc::Sender<LogCommand>,
46 dropped: Arc<AtomicU64>,
47}
48
49impl LogChannel {
50 fn new(capacity: usize) -> (Self, mpsc::Receiver<LogCommand>) {
51 let (sender, receiver) = mpsc::channel(capacity);
52 let channel = Self {
53 sender,
54 dropped: Arc::new(AtomicU64::new(0)),
55 };
56 (channel, receiver)
57 }
58
59 fn send(&self, command: LogCommand) {
60 if let Err(mpsc::error::TrySendError::Full(_)) = self.sender.try_send(command) {
61 self.dropped.fetch_add(1, Ordering::Relaxed);
62 }
63 }
64
65 fn dropped(&self) -> u64 {
66 self.dropped.load(Ordering::Relaxed)
67 }
68}
69
70pub struct DatabaseLayer {
71 channel: LogChannel,
72}
73
74impl std::fmt::Debug for DatabaseLayer {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("DatabaseLayer")
77 .field("dropped", &self.channel.dropped())
78 .finish_non_exhaustive()
79 }
80}
81
82impl DatabaseLayer {
83 pub fn new(db_pool: DbPool) -> Self {
84 let (channel, receiver) = LogChannel::new(CHANNEL_CAPACITY);
85
86 tokio::spawn(Self::batch_writer(db_pool, receiver));
87
88 Self { channel }
89 }
90
91 async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::Receiver<LogCommand>) {
92 let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
93 let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
94 let mut failed_total: u64 = 0;
95
96 loop {
97 tokio::select! {
98 Some(command) = receiver.recv() => {
99 match command {
100 LogCommand::Entry(entry) => {
101 buffer.push(*entry);
102 if buffer.len() >= BUFFER_FLUSH_SIZE {
103 Self::flush(&db_pool, &mut buffer, &mut failed_total).await;
104 }
105 }
106 LogCommand::FlushNow => {
107 if !buffer.is_empty() {
108 Self::flush(&db_pool, &mut buffer, &mut failed_total).await;
109 }
110 }
111 }
112 }
113 _ = interval.tick() => {
114 if !buffer.is_empty() {
115 Self::flush(&db_pool, &mut buffer, &mut failed_total).await;
116 }
117 }
118 }
119 }
120 }
121
122 async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>, failed_total: &mut u64) {
123 if let Err(e) = Self::batch_insert(db_pool, buffer).await {
124 let lost = u64::try_from(buffer.len()).unwrap_or(u64::MAX);
125 *failed_total = failed_total.saturating_add(lost);
126 writeln!(
127 std::io::stderr(),
128 "DATABASE LOG FLUSH FAILED ({lost} entries lost this flush, {failed_total} total lost since start): {e}"
129 )
130 .ok();
131 }
132 buffer.clear();
133 }
134
135 async fn batch_insert(
136 db_pool: &DbPool,
137 entries: &[LogEntry],
138 ) -> Result<(), crate::models::LoggingError> {
139 let pool = db_pool.write_pool_arc()?;
140 for entry in entries {
141 let metadata_json: Option<String> = entry
142 .metadata
143 .as_ref()
144 .map(serde_json::to_string)
145 .transpose()?;
146
147 let entry_id = entry.id.as_str();
148 let level_str = entry.level.to_string();
149 let user_id = entry.user_id.as_str();
150 let session_id = entry.session_id.as_str();
151 let task_id = entry.task_id.as_ref().map(TaskId::as_str);
152 let trace_id = entry.trace_id.as_str();
153 let context_id = entry.context_id.as_ref().map(ContextId::as_str);
154 let client_id = entry.client_id.as_ref().map(ClientId::as_str);
155
156 sqlx::query!(
157 r"
158 INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
159 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
160 ",
161 entry_id,
162 level_str,
163 entry.module,
164 entry.message,
165 metadata_json,
166 user_id,
167 session_id,
168 task_id,
169 trace_id,
170 context_id,
171 client_id
172 )
173 .execute(pool.as_ref())
174 .await?;
175 }
176
177 Ok(())
178 }
179}
180
181impl DatabaseLayer {
182 fn send_entry(&self, entry: LogEntry) {
183 let is_error = entry.level == LogLevel::Error;
184 self.channel.send(LogCommand::Entry(Box::new(entry)));
185 if is_error {
186 self.channel.send(LogCommand::FlushNow);
187 }
188 }
189}
190
191impl<S> Layer<S> for DatabaseLayer
192where
193 S: Subscriber + for<'a> LookupSpan<'a>,
194{
195 fn on_new_span(
196 &self,
197 attrs: &tracing::span::Attributes<'_>,
198 id: &tracing::span::Id,
199 ctx: Context<'_, S>,
200 ) {
201 record_span_fields(attrs, id, &ctx);
202 }
203
204 fn on_record(
205 &self,
206 id: &tracing::span::Id,
207 values: &tracing::span::Record<'_>,
208 ctx: Context<'_, S>,
209 ) {
210 update_span_fields(id, values, &ctx);
211 }
212
213 fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
214 if let Some(entry) = build_log_entry(event, &ctx) {
215 self.send_entry(entry);
216 }
217 }
218}