sos_database/
audit_provider.rs

1//! Database audit log provider.
2use crate::{
3    entity::{AuditEntity, AuditRecord, AuditRow},
4    Error,
5};
6use async_sqlite::Client;
7use async_trait::async_trait;
8use futures::stream::BoxStream;
9use sos_audit::{AuditEvent, AuditStreamSink};
10use tokio_stream::wrappers::ReceiverStream;
11
12/// Audit provider that appends to a database table.
13pub struct AuditDatabaseProvider<E>
14where
15    E: std::error::Error
16        + std::fmt::Debug
17        + From<crate::Error>
18        + Send
19        + Sync
20        + 'static,
21{
22    client: Client,
23    marker: std::marker::PhantomData<E>,
24}
25
26impl<E> AuditDatabaseProvider<E>
27where
28    E: std::error::Error
29        + std::fmt::Debug
30        + From<crate::Error>
31        + Send
32        + Sync
33        + 'static,
34{
35    /// Create a new audit file provider.
36    pub fn new(client: Client) -> Self {
37        Self {
38            client,
39            marker: std::marker::PhantomData,
40        }
41    }
42}
43
44#[async_trait]
45impl<E> AuditStreamSink for AuditDatabaseProvider<E>
46where
47    E: std::error::Error
48        + std::fmt::Debug
49        + From<crate::Error>
50        + From<std::io::Error>
51        + Send
52        + Sync
53        + 'static,
54{
55    type Error = E;
56
57    async fn append_audit_events(
58        &self,
59        events: &[AuditEvent],
60    ) -> std::result::Result<(), Self::Error> {
61        let mut audit_events = Vec::new();
62        for event in events {
63            audit_events.push(event.try_into()?);
64        }
65        self.client
66            .conn(move |conn| {
67                let audit = AuditEntity::new(&conn);
68                audit.insert_audit_logs(audit_events.as_slice())?;
69                Ok(())
70            })
71            .await
72            .map_err(Error::from)?;
73        Ok(())
74    }
75
76    async fn audit_stream(
77        &self,
78        reverse: bool,
79    ) -> std::result::Result<
80        BoxStream<'static, std::result::Result<AuditEvent, Self::Error>>,
81        Self::Error,
82    > {
83        let (tx, rx) = tokio::sync::mpsc::channel::<
84            std::result::Result<AuditEvent, Self::Error>,
85        >(16);
86
87        let client = self.client.clone();
88        tokio::task::spawn(async move {
89            client
90                .conn_and_then(move |conn| {
91                    let mut stmt = if reverse {
92                        conn.prepare(
93                            "SELECT * FROM audit_logs ORDER BY log_id DESC",
94                        )?
95                    } else {
96                        conn.prepare(
97                            "SELECT * FROM audit_logs ORDER BY log_id ASC",
98                        )?
99                    };
100                    let mut rows = stmt.query([])?;
101
102                    while let Some(row) = rows.next()? {
103                        if tx.is_closed() {
104                            break;
105                        }
106                        let row: AuditRow = row.try_into()?;
107                        let record: AuditRecord = row.try_into()?;
108                        let inner_tx = tx.clone();
109                        let res = futures::executor::block_on(async move {
110                            inner_tx.send(Ok(record.event)).await
111                        });
112                        if let Err(e) = res {
113                            tracing::error!(error = %e);
114                            break;
115                        }
116                    }
117
118                    Ok::<_, Error>(())
119                })
120                .await?;
121            Ok::<_, Self::Error>(())
122        });
123
124        Ok(Box::pin(ReceiverStream::new(rx)))
125    }
126}