Skip to main content

rustauth_core/db/sql/
executor.rs

1use super::*;
2
3/// Minimal async execution boundary required by the shared SQL runner.
4///
5/// Adapter crates implement this trait for their driver/pool/transaction
6/// context. The shared layer owns SQL planning, while this trait owns only
7/// driver execution and returning raw driver rows.
8pub trait SqlExecutor {
9    /// Driver-specific row type returned by fetch operations.
10    type Row;
11
12    /// Executes a statement that does not need decoded rows and returns affected rows.
13    fn execute<'a>(&'a mut self, statement: SqlStatement) -> AdapterFuture<'a, u64>;
14
15    /// Fetches all rows produced by a read statement.
16    fn fetch_all<'a>(&'a mut self, statement: SqlStatement) -> AdapterFuture<'a, Vec<Self::Row>>;
17
18    /// Fetches at most one row produced by a read statement.
19    fn fetch_optional<'a>(
20        &'a mut self,
21        statement: SqlStatement,
22    ) -> AdapterFuture<'a, Option<Self::Row>>;
23
24    /// Fetches a single signed integer scalar, used by count queries.
25    fn fetch_scalar_i64<'a>(&'a mut self, statement: SqlStatement) -> AdapterFuture<'a, i64>;
26}
27
28/// Driver row decoding boundary for converting raw driver rows into RustAuth values.
29pub trait SqlRowReader<Row> {
30    /// Reads a single projected field by SQL alias.
31    fn value_at(&self, row: &Row, field: &DbField, alias: &str) -> Result<DbValue, RustAuthError>;
32
33    /// Reads a complete record from the selected fields tracked by a read statement.
34    fn record(&self, row: &Row, selection: &[SqlSelectedField]) -> Result<DbRecord, RustAuthError> {
35        selection
36            .iter()
37            .map(|selected| {
38                self.value_at(row, &selected.field, &selected.alias)
39                    .map(|value| (selected.logical_name.clone(), value))
40            })
41            .collect()
42    }
43}
44
45/// Shared CRUD runner for SQL adapters that can execute raw SQL.
46pub struct SqlAdapterRunner<'a, E, R> {
47    dialect: SqlDialect,
48    schema: &'a DbSchema,
49    executor: E,
50    row_reader: R,
51}
52
53impl<'a, E, R> SqlAdapterRunner<'a, E, R> {
54    /// Creates a runner for one adapter operation.
55    pub fn new(dialect: SqlDialect, schema: &'a DbSchema, executor: E, row_reader: R) -> Self {
56        Self {
57            dialect,
58            schema,
59            executor,
60            row_reader,
61        }
62    }
63}
64
65impl<E, R> SqlAdapterRunner<'_, E, R>
66where
67    E: SqlExecutor,
68    R: SqlRowReader<E::Row>,
69{
70    pub async fn create(mut self, query: Create) -> Result<DbRecord, RustAuthError> {
71        let table = resolve_table(self.schema, &query.model)?;
72        let statement = create_statement(self.dialect, self.schema, &query)?;
73        if self.dialect.supports_insert_returning() && table_has_database_generated_id(table) {
74            let selection = create_returning_selection(self.schema, &query)?;
75            let row = self.executor.fetch_optional(statement).await?;
76            return row
77                .as_ref()
78                .map(|row| self.row_reader.record(row, &selection))
79                .transpose()?
80                .ok_or_else(|| {
81                    RustAuthError::Adapter(
82                        "sql adapter did not return inserted database-generated id".to_owned(),
83                    )
84                });
85        }
86        self.executor.execute(statement).await?;
87        if self.dialect == SqlDialect::MySql
88            && table
89                .field("id")
90                .is_some_and(|field| field.generated_id == Some(IdGeneration::Serial))
91        {
92            let id = self
93                .executor
94                .fetch_scalar_i64(SqlStatement::new("SELECT CAST(LAST_INSERT_ID() AS SIGNED)"))
95                .await?;
96            let mut record = query.data;
97            record.insert("id".to_owned(), DbValue::Number(id));
98            return Ok(select_record(record, &query.select));
99        }
100        Ok(select_record(query.data, &query.select))
101    }
102
103    pub async fn find_one(mut self, query: FindOne) -> Result<Option<DbRecord>, RustAuthError> {
104        if !query.joins.is_empty() {
105            let mut find_many = FindMany::new(query.model);
106            find_many.where_clauses = query.where_clauses;
107            find_many.limit = Some(1);
108            find_many.select = query.select;
109            find_many.joins = query.joins;
110            return self
111                .find_many(find_many)
112                .await
113                .map(|records| records.into_iter().next());
114        }
115        let read = find_one_statement(self.dialect, self.schema, &query)?;
116        let row = self.executor.fetch_optional(read.statement).await?;
117        row.as_ref()
118            .map(|row| self.row_reader.record(row, &read.selection))
119            .transpose()
120    }
121
122    pub async fn find_many(mut self, query: FindMany) -> Result<Vec<DbRecord>, RustAuthError> {
123        if !query.joins.is_empty() {
124            return self.find_many_with_joins(query).await;
125        }
126        let read = find_many_statement(self.dialect, self.schema, &query)?;
127        let rows = self.executor.fetch_all(read.statement).await?;
128        rows.iter()
129            .map(|row| self.row_reader.record(row, &read.selection))
130            .collect()
131    }
132
133    async fn find_many_with_joins(
134        mut self,
135        query: FindMany,
136    ) -> Result<Vec<DbRecord>, RustAuthError> {
137        let read = find_many_with_joins_statement(self.dialect, self.schema, &query)?;
138        let rows = self.executor.fetch_all(read.statement).await?;
139        joined_rows(
140            &rows,
141            &read.base_selection,
142            &query.select,
143            &read.joins,
144            |row, field, alias| self.row_reader.value_at(row, field, alias),
145        )
146    }
147
148    pub async fn count(mut self, query: Count) -> Result<u64, RustAuthError> {
149        let statement = count_statement(self.dialect, self.schema, &query)?;
150        let count = self.executor.fetch_scalar_i64(statement).await?;
151        u64::try_from(count).map_err(|_| RustAuthError::NumericOutOfRange {
152            context: "SQL count result",
153        })
154    }
155
156    pub async fn update(mut self, query: Update) -> Result<Option<DbRecord>, RustAuthError> {
157        if query.data.is_empty() {
158            return Ok(None);
159        }
160        match update_one_plan(self.dialect, self.schema, &query)? {
161            SqlUpdateOnePlan::Returning(read) => {
162                let row = self.executor.fetch_optional(read.statement).await?;
163                row.as_ref()
164                    .map(|row| self.row_reader.record(row, &read.selection))
165                    .transpose()
166            }
167            SqlUpdateOnePlan::PreselectThenUpdate {
168                select,
169                update,
170                data,
171            } => {
172                let Some(row) = self.executor.fetch_optional(select.statement).await? else {
173                    return Ok(None);
174                };
175                let mut record = self.row_reader.record(&row, &select.selection)?;
176                self.executor.execute(update).await?;
177                record.extend(data);
178                Ok(Some(record))
179            }
180        }
181    }
182
183    pub async fn update_many(mut self, query: UpdateMany) -> Result<u64, RustAuthError> {
184        if query.data.is_empty() {
185            return Ok(0);
186        }
187        let statement = update_many_statement(self.dialect, self.schema, &query)?;
188        self.executor.execute(statement).await
189    }
190
191    pub async fn delete(mut self, query: Delete) -> Result<(), RustAuthError> {
192        let plan = delete_one_statement(self.dialect, self.schema, &query)?;
193        self.executor.execute(plan.statement).await?;
194        Ok(())
195    }
196
197    pub async fn delete_many(mut self, query: DeleteMany) -> Result<u64, RustAuthError> {
198        let statement = delete_many_statement(self.dialect, self.schema, &query)?;
199        self.executor.execute(statement).await
200    }
201}