1use std::collections::BTreeMap;
2
3use serde_json::Value;
4
5use crate::api::http_json::{snake_to_camel, user_to_http_value};
6use crate::context::AuthContext;
7use crate::db::{DbAdapter, DbFieldType, DbRecord, DbValue, FindOne, User};
8use crate::error::RustAuthError;
9use crate::options::{SessionAdditionalField, UserAdditionalField};
10
11pub trait AdditionalField {
12 fn field_type(&self) -> &DbFieldType;
13 fn required(&self) -> bool;
14 fn input(&self) -> bool;
15 fn returned(&self) -> bool;
16 fn default_value(&self) -> Option<&DbValue>;
17 fn db_name(&self) -> Option<&str>;
18}
19
20impl AdditionalField for UserAdditionalField {
21 fn field_type(&self) -> &DbFieldType {
22 &self.field_type
23 }
24
25 fn required(&self) -> bool {
26 self.required
27 }
28
29 fn input(&self) -> bool {
30 self.input
31 }
32
33 fn returned(&self) -> bool {
34 self.returned
35 }
36
37 fn default_value(&self) -> Option<&DbValue> {
38 self.default_value.as_ref()
39 }
40
41 fn db_name(&self) -> Option<&str> {
42 self.db_name.as_deref()
43 }
44}
45
46impl AdditionalField for SessionAdditionalField {
47 fn field_type(&self) -> &DbFieldType {
48 &self.field_type
49 }
50
51 fn required(&self) -> bool {
52 self.required
53 }
54
55 fn input(&self) -> bool {
56 self.input
57 }
58
59 fn returned(&self) -> bool {
60 self.returned
61 }
62
63 fn default_value(&self) -> Option<&DbValue> {
64 self.default_value.as_ref()
65 }
66
67 fn db_name(&self) -> Option<&str> {
68 self.db_name.as_deref()
69 }
70}
71
72pub fn create_values<F>(
73 fields: &BTreeMap<String, F>,
74 body: &serde_json::Map<String, Value>,
75) -> Result<DbRecord, AdditionalFieldError>
76where
77 F: AdditionalField,
78{
79 let mut values = DbRecord::new();
80 for (name, field) in fields {
81 match body.get(name) {
82 Some(value) => {
83 if !field.input() {
84 return Err(AdditionalFieldError::NotInput(name.clone()));
85 }
86 values.insert(
87 storage_name(name, field),
88 json_to_db_value(name, field.field_type(), value)
89 .map_err(AdditionalFieldError::InvalidType)?,
90 );
91 }
92 None => {
93 if let Some(value) = field.default_value() {
94 values.insert(storage_name(name, field), value.clone());
95 } else if field.required() {
96 return Err(AdditionalFieldError::MissingRequired(name.clone()));
97 } else {
98 values.insert(storage_name(name, field), DbValue::Null);
99 }
100 }
101 }
102 }
103 Ok(values)
104}
105
106pub fn update_values<F>(
107 fields: &BTreeMap<String, F>,
108 body: &serde_json::Map<String, Value>,
109) -> Result<DbRecord, AdditionalFieldError>
110where
111 F: AdditionalField,
112{
113 let mut values = DbRecord::new();
114 for (name, value) in body {
115 let Some(field) = fields.get(name) else {
116 continue;
117 };
118 if !field.input() {
119 return Err(AdditionalFieldError::NotInput(name.clone()));
120 }
121 values.insert(
122 storage_name(name, field),
123 json_to_db_value(name, field.field_type(), value)
124 .map_err(AdditionalFieldError::InvalidType)?,
125 );
126 }
127 Ok(values)
128}
129
130pub fn insert_returned_fields<F>(
131 object: &mut serde_json::Map<String, Value>,
132 fields: &BTreeMap<String, F>,
133 record: &DbRecord,
134) -> Result<(), RustAuthError>
135where
136 F: AdditionalField,
137{
138 for (name, field) in fields {
139 if !field.returned() {
140 continue;
141 }
142 let value = record
143 .get(name)
144 .or_else(|| field.db_name().and_then(|db_name| record.get(db_name)))
145 .or_else(|| field.default_value())
146 .unwrap_or(&DbValue::Null);
147 object.insert(name.clone(), db_value_to_json(value)?);
148 }
149 Ok(())
150}
151
152pub(crate) fn insert_returned_fields_http<F>(
154 object: &mut serde_json::Map<String, Value>,
155 fields: &BTreeMap<String, F>,
156 record: &DbRecord,
157) -> Result<(), RustAuthError>
158where
159 F: AdditionalField,
160{
161 for (name, field) in fields {
162 if !field.returned() {
163 continue;
164 }
165 let value = record
166 .get(name)
167 .or_else(|| field.db_name().and_then(|db_name| record.get(db_name)))
168 .or_else(|| field.default_value())
169 .unwrap_or(&DbValue::Null);
170 object.insert(snake_to_camel(name), db_value_to_json(value)?);
171 }
172 Ok(())
173}
174
175fn storage_name<F>(logical_name: &str, field: &F) -> String
176where
177 F: AdditionalField,
178{
179 field
180 .db_name()
181 .map(str::to_owned)
182 .unwrap_or_else(|| logical_name.to_owned())
183}
184
185pub fn db_value_to_json(value: &DbValue) -> Result<Value, RustAuthError> {
186 match value {
187 DbValue::String(value) => Ok(Value::String(value.clone())),
188 DbValue::Number(value) => Ok(Value::Number((*value).into())),
189 DbValue::Boolean(value) => Ok(Value::Bool(*value)),
190 DbValue::Timestamp(value) => {
191 serde_json::to_value(value).map_err(|error| RustAuthError::Serialization {
192 context: "serializing additional field timestamp",
193 message: error.to_string(),
194 })
195 }
196 DbValue::Json(value) => Ok(value.clone()),
197 DbValue::StringArray(values) => Ok(Value::Array(
198 values.iter().cloned().map(Value::String).collect(),
199 )),
200 DbValue::NumberArray(values) => Ok(Value::Array(
201 values
202 .iter()
203 .map(|value| Value::Number((*value).into()))
204 .collect(),
205 )),
206 DbValue::Record(record) => db_record_to_json(record),
207 DbValue::RecordArray(records) => records
208 .iter()
209 .map(db_record_to_json)
210 .collect::<Result<Vec<_>, _>>()
211 .map(Value::Array),
212 DbValue::Null => Ok(Value::Null),
213 }
214}
215
216pub fn json_to_db_value(
217 name: &str,
218 field_type: &DbFieldType,
219 value: &Value,
220) -> Result<DbValue, String> {
221 if value.is_null() {
222 return Ok(DbValue::Null);
223 }
224 match field_type {
225 DbFieldType::String => value
226 .as_str()
227 .map(|value| DbValue::String(value.to_owned())),
228 DbFieldType::Number => value.as_i64().map(DbValue::Number),
229 DbFieldType::Boolean => value.as_bool().map(DbValue::Boolean),
230 DbFieldType::Json => Some(DbValue::Json(value.clone())),
231 DbFieldType::StringArray => value.as_array().and_then(|values| {
232 values
233 .iter()
234 .map(|value| value.as_str().map(str::to_owned))
235 .collect::<Option<Vec<_>>>()
236 .map(DbValue::StringArray)
237 }),
238 DbFieldType::NumberArray => value.as_array().and_then(|values| {
239 values
240 .iter()
241 .map(Value::as_i64)
242 .collect::<Option<Vec<_>>>()
243 .map(DbValue::NumberArray)
244 }),
245 DbFieldType::Timestamp => None,
246 }
247 .ok_or_else(|| format!("invalid value for additional field `{name}`"))
248}
249
/// Validation failure produced while turning a request body into additional-field values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdditionalFieldError {
    /// A required field with no default was absent from a create request (holds the field name).
    MissingRequired(String),
    /// The request supplied a field configured to reject input (holds the field name).
    NotInput(String),
    /// A supplied value did not match the field's declared type (holds the full message).
    InvalidType(String),
}

impl AdditionalFieldError {
    /// Returns the human-readable description of this error; identical to its `Display` output.
    pub fn message(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for AdditionalFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingRequired(name) => {
                write!(f, "missing required additional field `{name}`")
            }
            Self::NotInput(name) => {
                write!(f, "additional field `{name}` is not accepted as input")
            }
            Self::InvalidType(message) => f.write_str(message),
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>`-style error types.
impl std::error::Error for AdditionalFieldError {}
266
267pub async fn user_response_value(
268 adapter: &dyn DbAdapter,
269 context: &AuthContext,
270 fields: &BTreeMap<String, UserAdditionalField>,
271 user: &User,
272) -> Result<Value, RustAuthError> {
273 if fields.is_empty() {
274 return user_to_http_value(user);
275 }
276 let users = context.schema().table("user")?;
277 let record = adapter
278 .find_one(
279 FindOne::new(users.model())
280 .where_clause(users.where_eq("id", DbValue::String(user.id.clone()))?),
281 )
282 .await?
283 .map(|record| users.map_record(record))
284 .transpose()?;
285 let mut value = user_to_http_value(user)?;
286 let Some(object) = value.as_object_mut() else {
287 return Err(RustAuthError::Serialization {
288 context: "serializing user output",
289 message: "expected JSON object".to_owned(),
290 });
291 };
292 if let Some(record) = record {
293 insert_returned_fields_http(object, fields, &record)?;
294 }
295 Ok(value)
296}
297
298fn db_record_to_json(record: &DbRecord) -> Result<Value, RustAuthError> {
299 record
300 .iter()
301 .map(|(field, value)| db_value_to_json(value).map(|value| (field.clone(), value)))
302 .collect::<Result<serde_json::Map<_, _>, _>>()
303 .map(Value::Object)
304}