// sea_orm_tracing/connection.rs

1//! Traced database connection wrapper.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use async_trait::async_trait;
9use sea_orm::{
10    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
11    ExecResult, IsolationLevel, QueryResult, Statement, StreamTrait, TransactionError,
12    TransactionTrait,
13};
14use tracing::{field, Instrument, Span};
15
16use crate::config::TracingConfig;
17use crate::parser::ParsedSql;
18
19/// A traced wrapper around SeaORM's `DatabaseConnection`.
20///
21/// This wrapper implements `ConnectionTrait`, `StreamTrait`, and `TransactionTrait`,
22/// making it a drop-in replacement for `DatabaseConnection`. All database operations
23/// are automatically instrumented with tracing spans.
24///
25/// # Span Nesting
26///
27/// Spans created by `TracedConnection` automatically become children of the current
28/// tracing span context. This means if you're using tracing middleware in your web
29/// framework (e.g., `tower-http`'s `TraceLayer`), database spans will appear nested
30/// under HTTP request spans in your traces.
31///
32/// # Example
33///
34/// ```rust,ignore
35/// use sea_orm::Database;
36/// use sea_orm_tracing::TracedConnection;
37///
38/// let db = Database::connect("postgres://localhost/mydb").await?;
39/// let traced = TracedConnection::from(db);
40///
41/// // All queries are now traced
42/// let users = Users::find().all(&traced).await?;
43/// ```
44#[derive(Debug, Clone)]
45pub struct TracedConnection {
46    inner: DatabaseConnection,
47    config: Arc<TracingConfig>,
48}
49
50impl TracedConnection {
51    /// Create a new traced connection with the given configuration.
52    pub fn new(connection: DatabaseConnection, config: TracingConfig) -> Self {
53        Self {
54            inner: connection,
55            config: Arc::new(config),
56        }
57    }
58
59    /// Create a new traced connection with default configuration.
60    pub fn wrap(connection: DatabaseConnection) -> Self {
61        Self::new(connection, TracingConfig::default())
62    }
63
64    /// Get a reference to the underlying `DatabaseConnection`.
65    pub fn inner(&self) -> &DatabaseConnection {
66        &self.inner
67    }
68
69    /// Get the tracing configuration.
70    pub fn config(&self) -> &TracingConfig {
71        &self.config
72    }
73
74    /// Consume the wrapper and return the inner `DatabaseConnection`.
75    pub fn into_inner(self) -> DatabaseConnection {
76        self.inner
77    }
78
79    /// Get the database backend name for span attributes.
80    fn db_system(&self) -> &'static str {
81        match self.inner.get_database_backend() {
82            DbBackend::Postgres => "postgresql",
83            DbBackend::MySql => "mysql",
84            DbBackend::Sqlite => "sqlite",
85        }
86    }
87
88    /// Create a tracing span for a database operation.
89    fn create_span(&self, stmt: &Statement) -> Span {
90        let parsed = ParsedSql::parse(&stmt.sql);
91        let span_name = parsed.span_name();
92        let db_system = self.db_system();
93
94        let span = tracing::info_span!(
95            "db.query",
96            otel.name = %span_name,
97            db.system = %db_system,
98            db.operation = %parsed.operation.as_str(),
99            db.sql.table = field::Empty,
100            db.statement = field::Empty,
101            db.rows_affected = field::Empty,
102            db.duration_ms = field::Empty,
103            db.name = field::Empty,
104            server.address = field::Empty,
105            server.port = field::Empty,
106            peer.service = field::Empty,
107            otel.status_code = field::Empty,
108            error.message = field::Empty,
109            slow_query = field::Empty,
110        );
111
112        // Record table if available
113        if let Some(table) = &parsed.table {
114            span.record("db.sql.table", table.as_str());
115        }
116
117        // Record database name if configured
118        if let Some(db_name) = &self.config.database_name {
119            span.record("db.name", db_name.as_str());
120        }
121
122        // Record server address and port for X-Ray service map
123        if let Some(addr) = &self.config.server_address {
124            span.record("server.address", addr.as_str());
125        }
126        if let Some(port) = self.config.server_port {
127            span.record("server.port", port as i64);
128        }
129
130        // Record peer service for X-Ray trace map node naming
131        if let Some(peer) = &self.config.peer_service {
132            span.record("peer.service", peer.as_str());
133        }
134
135        // Record SQL statement if configured
136        if self.config.log_statements {
137            span.record("db.statement", stmt.sql.as_str());
138        }
139
140        span
141    }
142
143    /// Record the result of a database operation in the span.
144    fn record_result<T, E: std::fmt::Display>(
145        &self,
146        span: &Span,
147        result: &Result<T, E>,
148        start: Instant,
149        row_count: Option<u64>,
150    ) {
151        let duration_ms = start.elapsed().as_millis() as i64;
152        span.record("db.duration_ms", duration_ms);
153
154        // Record row count if available and configured
155        if self.config.record_row_counts {
156            if let Some(count) = row_count {
157                span.record("db.rows_affected", count);
158            }
159        }
160
161        // Check for slow query
162        if start.elapsed() > self.config.slow_query_threshold {
163            span.record("slow_query", true);
164            let threshold_ms = self.config.slow_query_threshold.as_millis() as i64;
165            tracing::warn!(
166                parent: span,
167                duration_ms = duration_ms,
168                threshold_ms = threshold_ms,
169                "Slow query detected"
170            );
171        }
172
173        match result {
174            Ok(_) => {
175                span.record("otel.status_code", "OK");
176            }
177            Err(e) => {
178                span.record("otel.status_code", "ERROR");
179                span.record("error.message", e.to_string().as_str());
180                tracing::error!(
181                    parent: span,
182                    error = %e,
183                    "Database query failed"
184                );
185            }
186        }
187    }
188}
189
190impl From<DatabaseConnection> for TracedConnection {
191    fn from(connection: DatabaseConnection) -> Self {
192        Self::wrap(connection)
193    }
194}
195
196impl AsRef<DatabaseConnection> for TracedConnection {
197    fn as_ref(&self) -> &DatabaseConnection {
198        &self.inner
199    }
200}
201
202#[async_trait]
203impl ConnectionTrait for TracedConnection {
204    fn get_database_backend(&self) -> DbBackend {
205        self.inner.get_database_backend()
206    }
207
208    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
209        let span = self.create_span(&stmt);
210        let start = Instant::now();
211
212        let result = self
213            .inner
214            .execute(stmt)
215            .instrument(span.clone())
216            .await;
217
218        let row_count = result.as_ref().ok().map(|r| r.rows_affected());
219        self.record_result(&span, &result, start, row_count);
220
221        result
222    }
223
224    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
225        let stmt = Statement::from_string(self.get_database_backend(), sql);
226        let span = self.create_span(&stmt);
227        let start = Instant::now();
228
229        let result = self
230            .inner
231            .execute_unprepared(sql)
232            .instrument(span.clone())
233            .await;
234
235        let row_count = result.as_ref().ok().map(|r| r.rows_affected());
236        self.record_result(&span, &result, start, row_count);
237
238        result
239    }
240
241    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
242        let span = self.create_span(&stmt);
243        let start = Instant::now();
244
245        let result = self
246            .inner
247            .query_one(stmt)
248            .instrument(span.clone())
249            .await;
250
251        let row_count = result.as_ref().ok().map(|opt| if opt.is_some() { 1 } else { 0 });
252        self.record_result(&span, &result, start, row_count);
253
254        result
255    }
256
257    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
258        let span = self.create_span(&stmt);
259        let start = Instant::now();
260
261        let result = self
262            .inner
263            .query_all(stmt)
264            .instrument(span.clone())
265            .await;
266
267        let row_count = result.as_ref().ok().map(|rows| rows.len() as u64);
268        self.record_result(&span, &result, start, row_count);
269
270        result
271    }
272
273    fn support_returning(&self) -> bool {
274        self.inner.support_returning()
275    }
276
277    fn is_mock_connection(&self) -> bool {
278        self.inner.is_mock_connection()
279    }
280}
281
282#[async_trait]
283impl StreamTrait for TracedConnection {
284    type Stream<'a> = <DatabaseConnection as StreamTrait>::Stream<'a>;
285
286    fn stream<'a>(
287        &'a self,
288        stmt: Statement,
289    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
290        let span = self.create_span(&stmt);
291        let start = Instant::now();
292        let config = self.config.clone();
293
294        Box::pin(async move {
295            let result = self.inner.stream(stmt).instrument(span.clone()).await;
296
297            // Record basic result info (we can't know row count for streams)
298            let duration_ms = start.elapsed().as_millis() as i64;
299            span.record("db.duration_ms", duration_ms);
300
301            if start.elapsed() > config.slow_query_threshold {
302                span.record("slow_query", true);
303            }
304
305            match &result {
306                Ok(_) => {
307                    span.record("otel.status_code", "OK");
308                }
309                Err(e) => {
310                    span.record("otel.status_code", "ERROR");
311                    span.record("error.message", e.to_string().as_str());
312                }
313            }
314
315            result
316        })
317    }
318}
319
320#[async_trait]
321impl TransactionTrait for TracedConnection {
322    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
323        let span = tracing::info_span!(
324            "db.transaction",
325            otel.name = "BEGIN",
326            db.system = %self.db_system(),
327            db.operation = "BEGIN",
328            otel.status_code = field::Empty,
329            error.message = field::Empty,
330        );
331
332        let result = self.inner.begin().instrument(span.clone()).await;
333
334        match &result {
335            Ok(_) => {
336                span.record("otel.status_code", "OK");
337            }
338            Err(e) => {
339                span.record("otel.status_code", "ERROR");
340                span.record("error.message", e.to_string().as_str());
341            }
342        }
343
344        result
345    }
346
347    async fn begin_with_config(
348        &self,
349        isolation_level: Option<IsolationLevel>,
350        access_mode: Option<AccessMode>,
351    ) -> Result<DatabaseTransaction, DbErr> {
352        let span = tracing::info_span!(
353            "db.transaction",
354            otel.name = "BEGIN",
355            db.system = %self.db_system(),
356            db.operation = "BEGIN",
357            db.transaction.isolation_level = ?isolation_level,
358            db.transaction.access_mode = ?access_mode,
359            otel.status_code = field::Empty,
360            error.message = field::Empty,
361        );
362
363        let result = self
364            .inner
365            .begin_with_config(isolation_level, access_mode)
366            .instrument(span.clone())
367            .await;
368
369        match &result {
370            Ok(_) => {
371                span.record("otel.status_code", "OK");
372            }
373            Err(e) => {
374                span.record("otel.status_code", "ERROR");
375                span.record("error.message", e.to_string().as_str());
376            }
377        }
378
379        result
380    }
381
382    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
383    where
384        F: for<'c> FnOnce(
385                &'c DatabaseTransaction,
386            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
387            + Send,
388        T: Send,
389        E: std::fmt::Display + std::fmt::Debug + Send,
390    {
391        let span = tracing::info_span!(
392            "db.transaction",
393            otel.name = "TRANSACTION",
394            db.system = %self.db_system(),
395            db.operation = "TRANSACTION",
396            otel.status_code = field::Empty,
397            error.message = field::Empty,
398        );
399
400        let result = self
401            .inner
402            .transaction(callback)
403            .instrument(span.clone())
404            .await;
405
406        match &result {
407            Ok(_) => {
408                span.record("otel.status_code", "OK");
409            }
410            Err(e) => {
411                span.record("otel.status_code", "ERROR");
412                span.record("error.message", format!("{:?}", e).as_str());
413            }
414        }
415
416        result
417    }
418
419    async fn transaction_with_config<F, T, E>(
420        &self,
421        callback: F,
422        isolation_level: Option<IsolationLevel>,
423        access_mode: Option<AccessMode>,
424    ) -> Result<T, TransactionError<E>>
425    where
426        F: for<'c> FnOnce(
427                &'c DatabaseTransaction,
428            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
429            + Send,
430        T: Send,
431        E: std::fmt::Display + std::fmt::Debug + Send,
432    {
433        let span = tracing::info_span!(
434            "db.transaction",
435            otel.name = "TRANSACTION",
436            db.system = %self.db_system(),
437            db.operation = "TRANSACTION",
438            db.transaction.isolation_level = ?isolation_level,
439            db.transaction.access_mode = ?access_mode,
440            otel.status_code = field::Empty,
441            error.message = field::Empty,
442        );
443
444        let result = self
445            .inner
446            .transaction_with_config(callback, isolation_level, access_mode)
447            .instrument(span.clone())
448            .await;
449
450        match &result {
451            Ok(_) => {
452                span.record("otel.status_code", "OK");
453            }
454            Err(e) => {
455                span.record("otel.status_code", "ERROR");
456                span.record("error.message", format!("{:?}", e).as_str());
457            }
458        }
459
460        result
461    }
462}
463
464/// Extension trait for easy wrapping of database connections.
465pub trait TracingExt {
466    /// Wrap this connection with tracing instrumentation.
467    fn with_tracing(self) -> TracedConnection;
468
469    /// Wrap this connection with custom tracing configuration.
470    fn with_tracing_config(self, config: TracingConfig) -> TracedConnection;
471}
472
473impl TracingExt for DatabaseConnection {
474    fn with_tracing(self) -> TracedConnection {
475        TracedConnection::wrap(self)
476    }
477
478    fn with_tracing_config(self, config: TracingConfig) -> TracedConnection {
479        TracedConnection::new(self, config)
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods set the fields they name.
    #[test]
    fn test_config_builder() {
        let cfg = TracingConfig::default()
            .with_statement_logging(true)
            .with_database_name("test_db");

        assert!(cfg.log_statements);
        assert_eq!(cfg.database_name.as_deref(), Some("test_db"));
    }

    /// The development preset logs both statements and parameters.
    #[test]
    fn test_development_config() {
        let dev = TracingConfig::development();
        assert!(dev.log_statements && dev.log_parameters);
    }

    /// The production preset logs neither statements nor parameters.
    #[test]
    fn test_production_config() {
        let prod = TracingConfig::production();
        assert!(!(prod.log_statements || prod.log_parameters));
    }
}