1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use async_trait::async_trait;
9use sea_orm::{
10 AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
11 ExecResult, IsolationLevel, QueryResult, Statement, StreamTrait, TransactionError,
12 TransactionTrait,
13};
14use tracing::{field, Instrument, Span};
15
16use crate::config::TracingConfig;
17use crate::parser::ParsedSql;
18
19#[derive(Debug, Clone)]
45pub struct TracedConnection {
46 inner: DatabaseConnection,
47 config: Arc<TracingConfig>,
48}
49
50impl TracedConnection {
51 pub fn new(connection: DatabaseConnection, config: TracingConfig) -> Self {
53 Self {
54 inner: connection,
55 config: Arc::new(config),
56 }
57 }
58
59 pub fn wrap(connection: DatabaseConnection) -> Self {
61 Self::new(connection, TracingConfig::default())
62 }
63
64 pub fn inner(&self) -> &DatabaseConnection {
66 &self.inner
67 }
68
69 pub fn config(&self) -> &TracingConfig {
71 &self.config
72 }
73
74 pub fn into_inner(self) -> DatabaseConnection {
76 self.inner
77 }
78
79 fn db_system(&self) -> &'static str {
81 match self.inner.get_database_backend() {
82 DbBackend::Postgres => "postgresql",
83 DbBackend::MySql => "mysql",
84 DbBackend::Sqlite => "sqlite",
85 }
86 }
87
88 fn create_span(&self, stmt: &Statement) -> Span {
90 let parsed = ParsedSql::parse(&stmt.sql);
91 let span_name = parsed.span_name();
92 let db_system = self.db_system();
93
94 let span = tracing::info_span!(
95 "db.query",
96 otel.name = %span_name,
97 db.system = %db_system,
98 db.operation = %parsed.operation.as_str(),
99 db.sql.table = field::Empty,
100 db.statement = field::Empty,
101 db.rows_affected = field::Empty,
102 db.duration_ms = field::Empty,
103 db.name = field::Empty,
104 server.address = field::Empty,
105 server.port = field::Empty,
106 peer.service = field::Empty,
107 otel.status_code = field::Empty,
108 error.message = field::Empty,
109 slow_query = field::Empty,
110 );
111
112 if let Some(table) = &parsed.table {
114 span.record("db.sql.table", table.as_str());
115 }
116
117 if let Some(db_name) = &self.config.database_name {
119 span.record("db.name", db_name.as_str());
120 }
121
122 if let Some(addr) = &self.config.server_address {
124 span.record("server.address", addr.as_str());
125 }
126 if let Some(port) = self.config.server_port {
127 span.record("server.port", port as i64);
128 }
129
130 if let Some(peer) = &self.config.peer_service {
132 span.record("peer.service", peer.as_str());
133 }
134
135 if self.config.log_statements {
137 span.record("db.statement", stmt.sql.as_str());
138 }
139
140 span
141 }
142
143 fn record_result<T, E: std::fmt::Display>(
145 &self,
146 span: &Span,
147 result: &Result<T, E>,
148 start: Instant,
149 row_count: Option<u64>,
150 ) {
151 let duration_ms = start.elapsed().as_millis() as i64;
152 span.record("db.duration_ms", duration_ms);
153
154 if self.config.record_row_counts {
156 if let Some(count) = row_count {
157 span.record("db.rows_affected", count);
158 }
159 }
160
161 if start.elapsed() > self.config.slow_query_threshold {
163 span.record("slow_query", true);
164 let threshold_ms = self.config.slow_query_threshold.as_millis() as i64;
165 tracing::warn!(
166 parent: span,
167 duration_ms = duration_ms,
168 threshold_ms = threshold_ms,
169 "Slow query detected"
170 );
171 }
172
173 match result {
174 Ok(_) => {
175 span.record("otel.status_code", "OK");
176 }
177 Err(e) => {
178 span.record("otel.status_code", "ERROR");
179 span.record("error.message", e.to_string().as_str());
180 tracing::error!(
181 parent: span,
182 error = %e,
183 "Database query failed"
184 );
185 }
186 }
187 }
188}
189
190impl From<DatabaseConnection> for TracedConnection {
191 fn from(connection: DatabaseConnection) -> Self {
192 Self::wrap(connection)
193 }
194}
195
196impl AsRef<DatabaseConnection> for TracedConnection {
197 fn as_ref(&self) -> &DatabaseConnection {
198 &self.inner
199 }
200}
201
202#[async_trait]
203impl ConnectionTrait for TracedConnection {
204 fn get_database_backend(&self) -> DbBackend {
205 self.inner.get_database_backend()
206 }
207
208 async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
209 let span = self.create_span(&stmt);
210 let start = Instant::now();
211
212 let result = self
213 .inner
214 .execute(stmt)
215 .instrument(span.clone())
216 .await;
217
218 let row_count = result.as_ref().ok().map(|r| r.rows_affected());
219 self.record_result(&span, &result, start, row_count);
220
221 result
222 }
223
224 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
225 let stmt = Statement::from_string(self.get_database_backend(), sql);
226 let span = self.create_span(&stmt);
227 let start = Instant::now();
228
229 let result = self
230 .inner
231 .execute_unprepared(sql)
232 .instrument(span.clone())
233 .await;
234
235 let row_count = result.as_ref().ok().map(|r| r.rows_affected());
236 self.record_result(&span, &result, start, row_count);
237
238 result
239 }
240
241 async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
242 let span = self.create_span(&stmt);
243 let start = Instant::now();
244
245 let result = self
246 .inner
247 .query_one(stmt)
248 .instrument(span.clone())
249 .await;
250
251 let row_count = result.as_ref().ok().map(|opt| if opt.is_some() { 1 } else { 0 });
252 self.record_result(&span, &result, start, row_count);
253
254 result
255 }
256
257 async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
258 let span = self.create_span(&stmt);
259 let start = Instant::now();
260
261 let result = self
262 .inner
263 .query_all(stmt)
264 .instrument(span.clone())
265 .await;
266
267 let row_count = result.as_ref().ok().map(|rows| rows.len() as u64);
268 self.record_result(&span, &result, start, row_count);
269
270 result
271 }
272
273 fn support_returning(&self) -> bool {
274 self.inner.support_returning()
275 }
276
277 fn is_mock_connection(&self) -> bool {
278 self.inner.is_mock_connection()
279 }
280}
281
282#[async_trait]
283impl StreamTrait for TracedConnection {
284 type Stream<'a> = <DatabaseConnection as StreamTrait>::Stream<'a>;
285
286 fn stream<'a>(
287 &'a self,
288 stmt: Statement,
289 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
290 let span = self.create_span(&stmt);
291 let start = Instant::now();
292 let config = self.config.clone();
293
294 Box::pin(async move {
295 let result = self.inner.stream(stmt).instrument(span.clone()).await;
296
297 let duration_ms = start.elapsed().as_millis() as i64;
299 span.record("db.duration_ms", duration_ms);
300
301 if start.elapsed() > config.slow_query_threshold {
302 span.record("slow_query", true);
303 }
304
305 match &result {
306 Ok(_) => {
307 span.record("otel.status_code", "OK");
308 }
309 Err(e) => {
310 span.record("otel.status_code", "ERROR");
311 span.record("error.message", e.to_string().as_str());
312 }
313 }
314
315 result
316 })
317 }
318}
319
320#[async_trait]
321impl TransactionTrait for TracedConnection {
322 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
323 let span = tracing::info_span!(
324 "db.transaction",
325 otel.name = "BEGIN",
326 db.system = %self.db_system(),
327 db.operation = "BEGIN",
328 otel.status_code = field::Empty,
329 error.message = field::Empty,
330 );
331
332 let result = self.inner.begin().instrument(span.clone()).await;
333
334 match &result {
335 Ok(_) => {
336 span.record("otel.status_code", "OK");
337 }
338 Err(e) => {
339 span.record("otel.status_code", "ERROR");
340 span.record("error.message", e.to_string().as_str());
341 }
342 }
343
344 result
345 }
346
347 async fn begin_with_config(
348 &self,
349 isolation_level: Option<IsolationLevel>,
350 access_mode: Option<AccessMode>,
351 ) -> Result<DatabaseTransaction, DbErr> {
352 let span = tracing::info_span!(
353 "db.transaction",
354 otel.name = "BEGIN",
355 db.system = %self.db_system(),
356 db.operation = "BEGIN",
357 db.transaction.isolation_level = ?isolation_level,
358 db.transaction.access_mode = ?access_mode,
359 otel.status_code = field::Empty,
360 error.message = field::Empty,
361 );
362
363 let result = self
364 .inner
365 .begin_with_config(isolation_level, access_mode)
366 .instrument(span.clone())
367 .await;
368
369 match &result {
370 Ok(_) => {
371 span.record("otel.status_code", "OK");
372 }
373 Err(e) => {
374 span.record("otel.status_code", "ERROR");
375 span.record("error.message", e.to_string().as_str());
376 }
377 }
378
379 result
380 }
381
382 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
383 where
384 F: for<'c> FnOnce(
385 &'c DatabaseTransaction,
386 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
387 + Send,
388 T: Send,
389 E: std::fmt::Display + std::fmt::Debug + Send,
390 {
391 let span = tracing::info_span!(
392 "db.transaction",
393 otel.name = "TRANSACTION",
394 db.system = %self.db_system(),
395 db.operation = "TRANSACTION",
396 otel.status_code = field::Empty,
397 error.message = field::Empty,
398 );
399
400 let result = self
401 .inner
402 .transaction(callback)
403 .instrument(span.clone())
404 .await;
405
406 match &result {
407 Ok(_) => {
408 span.record("otel.status_code", "OK");
409 }
410 Err(e) => {
411 span.record("otel.status_code", "ERROR");
412 span.record("error.message", format!("{:?}", e).as_str());
413 }
414 }
415
416 result
417 }
418
419 async fn transaction_with_config<F, T, E>(
420 &self,
421 callback: F,
422 isolation_level: Option<IsolationLevel>,
423 access_mode: Option<AccessMode>,
424 ) -> Result<T, TransactionError<E>>
425 where
426 F: for<'c> FnOnce(
427 &'c DatabaseTransaction,
428 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
429 + Send,
430 T: Send,
431 E: std::fmt::Display + std::fmt::Debug + Send,
432 {
433 let span = tracing::info_span!(
434 "db.transaction",
435 otel.name = "TRANSACTION",
436 db.system = %self.db_system(),
437 db.operation = "TRANSACTION",
438 db.transaction.isolation_level = ?isolation_level,
439 db.transaction.access_mode = ?access_mode,
440 otel.status_code = field::Empty,
441 error.message = field::Empty,
442 );
443
444 let result = self
445 .inner
446 .transaction_with_config(callback, isolation_level, access_mode)
447 .instrument(span.clone())
448 .await;
449
450 match &result {
451 Ok(_) => {
452 span.record("otel.status_code", "OK");
453 }
454 Err(e) => {
455 span.record("otel.status_code", "ERROR");
456 span.record("error.message", format!("{:?}", e).as_str());
457 }
458 }
459
460 result
461 }
462}
463
464pub trait TracingExt {
466 fn with_tracing(self) -> TracedConnection;
468
469 fn with_tracing_config(self, config: TracingConfig) -> TracedConnection;
471}
472
473impl TracingExt for DatabaseConnection {
474 fn with_tracing(self) -> TracedConnection {
475 TracedConnection::wrap(self)
476 }
477
478 fn with_tracing_config(self, config: TracingConfig) -> TracedConnection {
479 TracedConnection::new(self, config)
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Values set through the builder methods show up on the resulting config.
    #[test]
    fn test_config_builder() {
        let built = TracingConfig::default()
            .with_statement_logging(true)
            .with_database_name("test_db");

        assert!(built.log_statements);
        assert_eq!(built.database_name, Some(String::from("test_db")));
    }

    /// The development preset turns on statement and parameter logging.
    #[test]
    fn test_development_config() {
        let dev = TracingConfig::development();

        assert!(dev.log_statements);
        assert!(dev.log_parameters);
    }

    /// The production preset leaves statement and parameter logging off.
    #[test]
    fn test_production_config() {
        let prod = TracingConfig::production();

        assert!(!prod.log_statements);
        assert!(!prod.log_parameters);
    }
}
510}