Skip to main content

rustauth_core/api/
additional_fields.rs

1use std::collections::BTreeMap;
2
3use serde_json::Value;
4
5use crate::api::http_json::{snake_to_camel, user_to_http_value};
6use crate::context::AuthContext;
7use crate::db::{DbAdapter, DbFieldType, DbRecord, DbValue, FindOne, User};
8use crate::error::RustAuthError;
9use crate::options::{SessionAdditionalField, UserAdditionalField};
10
11pub trait AdditionalField {
12    fn field_type(&self) -> &DbFieldType;
13    fn required(&self) -> bool;
14    fn input(&self) -> bool;
15    fn returned(&self) -> bool;
16    fn default_value(&self) -> Option<&DbValue>;
17    fn db_name(&self) -> Option<&str>;
18}
19
20impl AdditionalField for UserAdditionalField {
21    fn field_type(&self) -> &DbFieldType {
22        &self.field_type
23    }
24
25    fn required(&self) -> bool {
26        self.required
27    }
28
29    fn input(&self) -> bool {
30        self.input
31    }
32
33    fn returned(&self) -> bool {
34        self.returned
35    }
36
37    fn default_value(&self) -> Option<&DbValue> {
38        self.default_value.as_ref()
39    }
40
41    fn db_name(&self) -> Option<&str> {
42        self.db_name.as_deref()
43    }
44}
45
46impl AdditionalField for SessionAdditionalField {
47    fn field_type(&self) -> &DbFieldType {
48        &self.field_type
49    }
50
51    fn required(&self) -> bool {
52        self.required
53    }
54
55    fn input(&self) -> bool {
56        self.input
57    }
58
59    fn returned(&self) -> bool {
60        self.returned
61    }
62
63    fn default_value(&self) -> Option<&DbValue> {
64        self.default_value.as_ref()
65    }
66
67    fn db_name(&self) -> Option<&str> {
68        self.db_name.as_deref()
69    }
70}
71
72pub fn create_values<F>(
73    fields: &BTreeMap<String, F>,
74    body: &serde_json::Map<String, Value>,
75) -> Result<DbRecord, AdditionalFieldError>
76where
77    F: AdditionalField,
78{
79    let mut values = DbRecord::new();
80    for (name, field) in fields {
81        match body.get(name) {
82            Some(value) => {
83                if !field.input() {
84                    return Err(AdditionalFieldError::NotInput(name.clone()));
85                }
86                values.insert(
87                    storage_name(name, field),
88                    json_to_db_value(name, field.field_type(), value)
89                        .map_err(AdditionalFieldError::InvalidType)?,
90                );
91            }
92            None => {
93                if let Some(value) = field.default_value() {
94                    values.insert(storage_name(name, field), value.clone());
95                } else if field.required() {
96                    return Err(AdditionalFieldError::MissingRequired(name.clone()));
97                } else {
98                    values.insert(storage_name(name, field), DbValue::Null);
99                }
100            }
101        }
102    }
103    Ok(values)
104}
105
106pub fn update_values<F>(
107    fields: &BTreeMap<String, F>,
108    body: &serde_json::Map<String, Value>,
109) -> Result<DbRecord, AdditionalFieldError>
110where
111    F: AdditionalField,
112{
113    let mut values = DbRecord::new();
114    for (name, value) in body {
115        let Some(field) = fields.get(name) else {
116            continue;
117        };
118        if !field.input() {
119            return Err(AdditionalFieldError::NotInput(name.clone()));
120        }
121        values.insert(
122            storage_name(name, field),
123            json_to_db_value(name, field.field_type(), value)
124                .map_err(AdditionalFieldError::InvalidType)?,
125        );
126    }
127    Ok(values)
128}
129
130pub fn insert_returned_fields<F>(
131    object: &mut serde_json::Map<String, Value>,
132    fields: &BTreeMap<String, F>,
133    record: &DbRecord,
134) -> Result<(), RustAuthError>
135where
136    F: AdditionalField,
137{
138    for (name, field) in fields {
139        if !field.returned() {
140            continue;
141        }
142        let value = record
143            .get(name)
144            .or_else(|| field.db_name().and_then(|db_name| record.get(db_name)))
145            .or_else(|| field.default_value())
146            .unwrap_or(&DbValue::Null);
147        object.insert(name.clone(), db_value_to_json(value)?);
148    }
149    Ok(())
150}
151
152/// Like [`insert_returned_fields`] but emits camelCase HTTP keys.
153pub(crate) fn insert_returned_fields_http<F>(
154    object: &mut serde_json::Map<String, Value>,
155    fields: &BTreeMap<String, F>,
156    record: &DbRecord,
157) -> Result<(), RustAuthError>
158where
159    F: AdditionalField,
160{
161    for (name, field) in fields {
162        if !field.returned() {
163            continue;
164        }
165        let value = record
166            .get(name)
167            .or_else(|| field.db_name().and_then(|db_name| record.get(db_name)))
168            .or_else(|| field.default_value())
169            .unwrap_or(&DbValue::Null);
170        object.insert(snake_to_camel(name), db_value_to_json(value)?);
171    }
172    Ok(())
173}
174
175fn storage_name<F>(logical_name: &str, field: &F) -> String
176where
177    F: AdditionalField,
178{
179    field
180        .db_name()
181        .map(str::to_owned)
182        .unwrap_or_else(|| logical_name.to_owned())
183}
184
185pub fn db_value_to_json(value: &DbValue) -> Result<Value, RustAuthError> {
186    match value {
187        DbValue::String(value) => Ok(Value::String(value.clone())),
188        DbValue::Number(value) => Ok(Value::Number((*value).into())),
189        DbValue::Boolean(value) => Ok(Value::Bool(*value)),
190        DbValue::Timestamp(value) => {
191            serde_json::to_value(value).map_err(|error| RustAuthError::Serialization {
192                context: "serializing additional field timestamp",
193                message: error.to_string(),
194            })
195        }
196        DbValue::Json(value) => Ok(value.clone()),
197        DbValue::StringArray(values) => Ok(Value::Array(
198            values.iter().cloned().map(Value::String).collect(),
199        )),
200        DbValue::NumberArray(values) => Ok(Value::Array(
201            values
202                .iter()
203                .map(|value| Value::Number((*value).into()))
204                .collect(),
205        )),
206        DbValue::Record(record) => db_record_to_json(record),
207        DbValue::RecordArray(records) => records
208            .iter()
209            .map(db_record_to_json)
210            .collect::<Result<Vec<_>, _>>()
211            .map(Value::Array),
212        DbValue::Null => Ok(Value::Null),
213    }
214}
215
216pub fn json_to_db_value(
217    name: &str,
218    field_type: &DbFieldType,
219    value: &Value,
220) -> Result<DbValue, String> {
221    if value.is_null() {
222        return Ok(DbValue::Null);
223    }
224    match field_type {
225        DbFieldType::String => value
226            .as_str()
227            .map(|value| DbValue::String(value.to_owned())),
228        DbFieldType::Number => value.as_i64().map(DbValue::Number),
229        DbFieldType::Boolean => value.as_bool().map(DbValue::Boolean),
230        DbFieldType::Json => Some(DbValue::Json(value.clone())),
231        DbFieldType::StringArray => value.as_array().and_then(|values| {
232            values
233                .iter()
234                .map(|value| value.as_str().map(str::to_owned))
235                .collect::<Option<Vec<_>>>()
236                .map(DbValue::StringArray)
237        }),
238        DbFieldType::NumberArray => value.as_array().and_then(|values| {
239            values
240                .iter()
241                .map(Value::as_i64)
242                .collect::<Option<Vec<_>>>()
243                .map(DbValue::NumberArray)
244        }),
245        DbFieldType::Timestamp => None,
246    }
247    .ok_or_else(|| format!("invalid value for additional field `{name}`"))
248}
249
/// Reasons why a request body cannot be turned into additional-field values.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// formatted with `{}` and boxed or propagated like any other error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdditionalFieldError {
    /// A required field (carried by name) was absent and has no default.
    MissingRequired(String),
    /// A field (carried by name) was supplied although it is not accepted as input.
    NotInput(String),
    /// A supplied value did not match the declared type; carries the full message.
    InvalidType(String),
}

impl AdditionalFieldError {
    /// Returns the human-readable description of this error.
    ///
    /// Identical to the [`std::fmt::Display`] output.
    pub fn message(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for AdditionalFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingRequired(name) => {
                write!(f, "missing required additional field `{name}`")
            }
            Self::NotInput(name) => {
                write!(f, "additional field `{name}` is not accepted as input")
            }
            // The payload is already a complete message.
            Self::InvalidType(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for AdditionalFieldError {}
266
267pub async fn user_response_value(
268    adapter: &dyn DbAdapter,
269    context: &AuthContext,
270    fields: &BTreeMap<String, UserAdditionalField>,
271    user: &User,
272) -> Result<Value, RustAuthError> {
273    if fields.is_empty() {
274        return user_to_http_value(user);
275    }
276    let users = context.schema().table("user")?;
277    let record = adapter
278        .find_one(
279            FindOne::new(users.model())
280                .where_clause(users.where_eq("id", DbValue::String(user.id.clone()))?),
281        )
282        .await?
283        .map(|record| users.map_record(record))
284        .transpose()?;
285    let mut value = user_to_http_value(user)?;
286    let Some(object) = value.as_object_mut() else {
287        return Err(RustAuthError::Serialization {
288            context: "serializing user output",
289            message: "expected JSON object".to_owned(),
290        });
291    };
292    if let Some(record) = record {
293        insert_returned_fields_http(object, fields, &record)?;
294    }
295    Ok(value)
296}
297
298fn db_record_to_json(record: &DbRecord) -> Result<Value, RustAuthError> {
299    record
300        .iter()
301        .map(|(field, value)| db_value_to_json(value).map(|value| (field.clone(), value)))
302        .collect::<Result<serde_json::Map<_, _>, _>>()
303        .map(Value::Object)
304}