1use super::*;
2
3pub trait SqlExecutor {
9 type Row;
11
12 fn execute<'a>(&'a mut self, statement: SqlStatement) -> AdapterFuture<'a, u64>;
14
15 fn fetch_all<'a>(&'a mut self, statement: SqlStatement) -> AdapterFuture<'a, Vec<Self::Row>>;
17
18 fn fetch_optional<'a>(
20 &'a mut self,
21 statement: SqlStatement,
22 ) -> AdapterFuture<'a, Option<Self::Row>>;
23
24 fn fetch_scalar_i64<'a>(&'a mut self, statement: SqlStatement) -> AdapterFuture<'a, i64>;
26}
27
28pub trait SqlRowReader<Row> {
30 fn value_at(&self, row: &Row, field: &DbField, alias: &str) -> Result<DbValue, RustAuthError>;
32
33 fn record(&self, row: &Row, selection: &[SqlSelectedField]) -> Result<DbRecord, RustAuthError> {
35 selection
36 .iter()
37 .map(|selected| {
38 self.value_at(row, &selected.field, &selected.alias)
39 .map(|value| (selected.logical_name.clone(), value))
40 })
41 .collect()
42 }
43}
44
45pub struct SqlAdapterRunner<'a, E, R> {
47 dialect: SqlDialect,
48 schema: &'a DbSchema,
49 executor: E,
50 row_reader: R,
51}
52
53impl<'a, E, R> SqlAdapterRunner<'a, E, R> {
54 pub fn new(dialect: SqlDialect, schema: &'a DbSchema, executor: E, row_reader: R) -> Self {
56 Self {
57 dialect,
58 schema,
59 executor,
60 row_reader,
61 }
62 }
63}
64
65impl<E, R> SqlAdapterRunner<'_, E, R>
66where
67 E: SqlExecutor,
68 R: SqlRowReader<E::Row>,
69{
70 pub async fn create(mut self, query: Create) -> Result<DbRecord, RustAuthError> {
71 let table = resolve_table(self.schema, &query.model)?;
72 let statement = create_statement(self.dialect, self.schema, &query)?;
73 if self.dialect.supports_insert_returning() && table_has_database_generated_id(table) {
74 let selection = create_returning_selection(self.schema, &query)?;
75 let row = self.executor.fetch_optional(statement).await?;
76 return row
77 .as_ref()
78 .map(|row| self.row_reader.record(row, &selection))
79 .transpose()?
80 .ok_or_else(|| {
81 RustAuthError::Adapter(
82 "sql adapter did not return inserted database-generated id".to_owned(),
83 )
84 });
85 }
86 self.executor.execute(statement).await?;
87 if self.dialect == SqlDialect::MySql
88 && table
89 .field("id")
90 .is_some_and(|field| field.generated_id == Some(IdGeneration::Serial))
91 {
92 let id = self
93 .executor
94 .fetch_scalar_i64(SqlStatement::new("SELECT CAST(LAST_INSERT_ID() AS SIGNED)"))
95 .await?;
96 let mut record = query.data;
97 record.insert("id".to_owned(), DbValue::Number(id));
98 return Ok(select_record(record, &query.select));
99 }
100 Ok(select_record(query.data, &query.select))
101 }
102
103 pub async fn find_one(mut self, query: FindOne) -> Result<Option<DbRecord>, RustAuthError> {
104 if !query.joins.is_empty() {
105 let mut find_many = FindMany::new(query.model);
106 find_many.where_clauses = query.where_clauses;
107 find_many.limit = Some(1);
108 find_many.select = query.select;
109 find_many.joins = query.joins;
110 return self
111 .find_many(find_many)
112 .await
113 .map(|records| records.into_iter().next());
114 }
115 let read = find_one_statement(self.dialect, self.schema, &query)?;
116 let row = self.executor.fetch_optional(read.statement).await?;
117 row.as_ref()
118 .map(|row| self.row_reader.record(row, &read.selection))
119 .transpose()
120 }
121
122 pub async fn find_many(mut self, query: FindMany) -> Result<Vec<DbRecord>, RustAuthError> {
123 if !query.joins.is_empty() {
124 return self.find_many_with_joins(query).await;
125 }
126 let read = find_many_statement(self.dialect, self.schema, &query)?;
127 let rows = self.executor.fetch_all(read.statement).await?;
128 rows.iter()
129 .map(|row| self.row_reader.record(row, &read.selection))
130 .collect()
131 }
132
133 async fn find_many_with_joins(
134 mut self,
135 query: FindMany,
136 ) -> Result<Vec<DbRecord>, RustAuthError> {
137 let read = find_many_with_joins_statement(self.dialect, self.schema, &query)?;
138 let rows = self.executor.fetch_all(read.statement).await?;
139 joined_rows(
140 &rows,
141 &read.base_selection,
142 &query.select,
143 &read.joins,
144 |row, field, alias| self.row_reader.value_at(row, field, alias),
145 )
146 }
147
148 pub async fn count(mut self, query: Count) -> Result<u64, RustAuthError> {
149 let statement = count_statement(self.dialect, self.schema, &query)?;
150 let count = self.executor.fetch_scalar_i64(statement).await?;
151 u64::try_from(count).map_err(|_| RustAuthError::NumericOutOfRange {
152 context: "SQL count result",
153 })
154 }
155
156 pub async fn update(mut self, query: Update) -> Result<Option<DbRecord>, RustAuthError> {
157 if query.data.is_empty() {
158 return Ok(None);
159 }
160 match update_one_plan(self.dialect, self.schema, &query)? {
161 SqlUpdateOnePlan::Returning(read) => {
162 let row = self.executor.fetch_optional(read.statement).await?;
163 row.as_ref()
164 .map(|row| self.row_reader.record(row, &read.selection))
165 .transpose()
166 }
167 SqlUpdateOnePlan::PreselectThenUpdate {
168 select,
169 update,
170 data,
171 } => {
172 let Some(row) = self.executor.fetch_optional(select.statement).await? else {
173 return Ok(None);
174 };
175 let mut record = self.row_reader.record(&row, &select.selection)?;
176 self.executor.execute(update).await?;
177 record.extend(data);
178 Ok(Some(record))
179 }
180 }
181 }
182
183 pub async fn update_many(mut self, query: UpdateMany) -> Result<u64, RustAuthError> {
184 if query.data.is_empty() {
185 return Ok(0);
186 }
187 let statement = update_many_statement(self.dialect, self.schema, &query)?;
188 self.executor.execute(statement).await
189 }
190
191 pub async fn delete(mut self, query: Delete) -> Result<(), RustAuthError> {
192 let plan = delete_one_statement(self.dialect, self.schema, &query)?;
193 self.executor.execute(plan.statement).await?;
194 Ok(())
195 }
196
197 pub async fn delete_many(mut self, query: DeleteMany) -> Result<u64, RustAuthError> {
198 let statement = delete_many_statement(self.dialect, self.schema, &query)?;
199 self.executor.execute(statement).await
200 }
201}