sos_database/
audit_provider.rs

1//! Database audit log provider.
2use crate::{
3    entity::{AuditEntity, AuditRecord, AuditRow},
4    Error,
5};
6use async_sqlite::Client;
7use async_trait::async_trait;
8use futures::stream::BoxStream;
9use sos_audit::{AuditEvent, AuditStreamSink};
10use tokio_stream::wrappers::ReceiverStream;
11
12/// Audit provider that appends to a database table.
13pub struct AuditDatabaseProvider<E>
14where
15    E: std::error::Error
16        + std::fmt::Debug
17        + From<crate::Error>
18        + Send
19        + Sync
20        + 'static,
21{
22    client: Client,
23    marker: std::marker::PhantomData<E>,
24}
25
26impl<E> AuditDatabaseProvider<E>
27where
28    E: std::error::Error
29        + std::fmt::Debug
30        + From<crate::Error>
31        + Send
32        + Sync
33        + 'static,
34{
35    /// Create a new audit file provider.
36    pub fn new(client: Client) -> Self {
37        Self {
38            client,
39            marker: std::marker::PhantomData,
40        }
41    }
42}
43
44#[async_trait]
45impl<E> AuditStreamSink for AuditDatabaseProvider<E>
46where
47    E: std::error::Error
48        + std::fmt::Debug
49        + From<crate::Error>
50        + From<std::io::Error>
51        + Send
52        + Sync
53        + 'static,
54{
55    type Error = E;
56
57    async fn append_audit_events(
58        &self,
59        events: &[AuditEvent],
60    ) -> std::result::Result<(), Self::Error> {
61        let mut audit_events = Vec::new();
62        for event in events {
63            audit_events.push(event.try_into()?);
64        }
65        self.client
66            .conn(move |conn| {
67                let audit = AuditEntity::new(&conn);
68                audit.insert_audit_logs(audit_events.as_slice())?;
69                Ok(())
70            })
71            .await
72            .map_err(Error::from)?;
73        Ok(())
74    }
75
76    async fn audit_stream(
77        &self,
78        reverse: bool,
79    ) -> std::result::Result<
80        BoxStream<'static, std::result::Result<AuditEvent, Self::Error>>,
81        Self::Error,
82    > {
83        let (tx, rx) = tokio::sync::mpsc::channel::<
84            std::result::Result<AuditEvent, Self::Error>,
85        >(16);
86
87        self.client
88            .conn_and_then(move |conn| {
89                let mut stmt = if reverse {
90                    conn.prepare(
91                        "SELECT * FROM audit_logs ORDER BY log_id DESC",
92                    )?
93                } else {
94                    conn.prepare(
95                        "SELECT * FROM audit_logs ORDER BY log_id ASC",
96                    )?
97                };
98                let mut rows = stmt.query([])?;
99
100                while let Some(row) = rows.next()? {
101                    let row: AuditRow = row.try_into()?;
102                    let record: AuditRecord = row.try_into()?;
103                    let inner_tx = tx.clone();
104                    futures::executor::block_on(async move {
105                        if let Err(e) = inner_tx.send(Ok(record.event)).await
106                        {
107                            tracing::error!(error = %e);
108                        }
109                    });
110                }
111
112                Ok::<_, Error>(())
113            })
114            .await
115            .map_err(Error::from)?;
116
117        Ok(Box::pin(ReceiverStream::new(rx)))
118    }
119}