Skip to main content

rustauth_plugins/additional_fields/
mod.rs

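//! Additional fields plugin: declares extra fields on the `user` and `session`
//! schemas and registers a matching runtime descriptor for each one.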
2
3use std::collections::BTreeMap;
4
5use rustauth_core::db::{DbField, DbFieldType, DbValue};
6use rustauth_core::options::{SessionAdditionalField, UserAdditionalField};
7use rustauth_core::plugin::{AuthPlugin, PluginInitOutput, PluginSchemaContribution};
8
9pub const UPSTREAM_PLUGIN_ID: &str = "additional-fields";
10
11#[derive(Debug, Clone, Default, PartialEq)]
12pub struct AdditionalFieldsOptions {
13    pub user: BTreeMap<String, AdditionalField>,
14    pub session: BTreeMap<String, AdditionalField>,
15}
16
17impl AdditionalFieldsOptions {
18    pub fn new() -> Self {
19        Self::default()
20    }
21
22    #[must_use]
23    pub fn builder() -> AdditionalFieldsOptionsBuilder {
24        AdditionalFieldsOptionsBuilder::default()
25    }
26
27    #[must_use]
28    pub fn user_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
29        self.user.insert(name.into(), field);
30        self
31    }
32
33    #[must_use]
34    pub fn session_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
35        self.session.insert(name.into(), field);
36        self
37    }
38}
39
40#[derive(Debug, Clone, PartialEq)]
41pub struct AdditionalField {
42    pub field_type: DbFieldType,
43    pub required: bool,
44    pub input: bool,
45    pub returned: bool,
46    pub unique: bool,
47    pub index: bool,
48    pub default_value: Option<DbValue>,
49    pub db_name: Option<String>,
50}
51
52impl AdditionalField {
53    pub fn new(field_type: DbFieldType) -> Self {
54        Self {
55            field_type,
56            required: true,
57            input: true,
58            returned: true,
59            unique: false,
60            index: false,
61            default_value: None,
62            db_name: None,
63        }
64    }
65
66    #[must_use]
67    pub fn optional(mut self) -> Self {
68        self.required = false;
69        self
70    }
71
72    #[must_use]
73    pub fn generated(mut self) -> Self {
74        self.input = false;
75        self
76    }
77
78    #[must_use]
79    pub fn hidden(mut self) -> Self {
80        self.returned = false;
81        self
82    }
83
84    #[must_use]
85    pub fn unique(mut self) -> Self {
86        self.unique = true;
87        self
88    }
89
90    #[must_use]
91    pub fn indexed(mut self) -> Self {
92        self.index = true;
93        self
94    }
95
96    #[must_use]
97    pub fn default_value(mut self, value: DbValue) -> Self {
98        self.default_value = Some(value);
99        self
100    }
101
102    #[must_use]
103    pub fn db_name(mut self, db_name: impl Into<String>) -> Self {
104        self.db_name = Some(db_name.into());
105        self
106    }
107}
108
109#[derive(Debug, Clone, Default)]
110pub struct AdditionalFieldsOptionsBuilder {
111    user: Option<BTreeMap<String, AdditionalField>>,
112    session: Option<BTreeMap<String, AdditionalField>>,
113}
114
115impl AdditionalFieldsOptionsBuilder {
116    #[must_use]
117    pub fn user(mut self, user: BTreeMap<String, AdditionalField>) -> Self {
118        self.user = Some(user);
119        self
120    }
121
122    #[must_use]
123    pub fn user_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
124        self.user
125            .get_or_insert_with(BTreeMap::new)
126            .insert(name.into(), field);
127        self
128    }
129
130    #[must_use]
131    pub fn session(mut self, session: BTreeMap<String, AdditionalField>) -> Self {
132        self.session = Some(session);
133        self
134    }
135
136    #[must_use]
137    pub fn session_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
138        self.session
139            .get_or_insert_with(BTreeMap::new)
140            .insert(name.into(), field);
141        self
142    }
143
144    #[must_use]
145    pub fn build(self) -> AdditionalFieldsOptions {
146        let defaults = AdditionalFieldsOptions::default();
147        AdditionalFieldsOptions {
148            user: self.user.unwrap_or(defaults.user),
149            session: self.session.unwrap_or(defaults.session),
150        }
151    }
152}
153
154/// Create the additional-fields plugin.
155#[must_use]
156pub fn additional_fields(options: AdditionalFieldsOptions) -> AuthPlugin {
157    AuthPlugin::new(UPSTREAM_PLUGIN_ID).with_init(move |_context| {
158        let mut output = PluginInitOutput::new();
159        for (name, field) in &options.user {
160            output = output
161                .schema(PluginSchemaContribution::field(
162                    "user",
163                    name.clone(),
164                    field.schema_field(name),
165                ))
166                .user_additional_field(name.clone(), field.user_runtime_field());
167        }
168        for (name, field) in &options.session {
169            output = output
170                .schema(PluginSchemaContribution::field(
171                    "session",
172                    name.clone(),
173                    field.schema_field(name),
174                ))
175                .session_additional_field(name.clone(), field.session_runtime_field());
176        }
177        Ok(output)
178    })
179}
180
181impl AdditionalField {
182    fn schema_field(&self, logical_name: &str) -> DbField {
183        let mut field = DbField::new(
184            self.db_name
185                .clone()
186                .unwrap_or_else(|| logical_name.to_owned()),
187            self.field_type.clone(),
188        );
189        if !self.required {
190            field = field.optional();
191        }
192        if self.unique {
193            field = field.unique();
194        }
195        if self.index {
196            field = field.indexed();
197        }
198        if !self.returned {
199            field = field.hidden();
200        }
201        if !self.input {
202            field = field.generated();
203        }
204        field
205    }
206
207    fn user_runtime_field(&self) -> UserAdditionalField {
208        let mut field = UserAdditionalField::new(self.field_type.clone());
209        field.required = self.required;
210        field.input = self.input;
211        field.returned = self.returned;
212        field.default_value = self.default_value.clone();
213        field.db_name = self.db_name.clone();
214        field
215    }
216
217    fn session_runtime_field(&self) -> SessionAdditionalField {
218        let mut field = SessionAdditionalField::new(self.field_type.clone());
219        field.required = self.required;
220        field.input = self.input;
221        field.returned = self.returned;
222        field.default_value = self.default_value.clone();
223        field.db_name = self.db_name.clone();
224        field
225    }
226}