Skip to main content

rustauth_plugins/device_authorization/
options.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use rustauth_core::error::RustAuthError;
6use time::Duration;
7
8pub type DeviceCodeGeneratorFuture = Pin<Box<dyn Future<Output = String> + Send>>;
9pub type AsyncDeviceCodeGenerator = Arc<dyn Fn() -> DeviceCodeGeneratorFuture + Send + Sync>;
10pub type ClientValidationFuture = Pin<Box<dyn Future<Output = Result<bool, RustAuthError>> + Send>>;
11pub type ClientValidator = Arc<dyn Fn(String) -> ClientValidationFuture + Send + Sync>;
12pub type DeviceAuthRequestFuture = Pin<Box<dyn Future<Output = Result<(), RustAuthError>> + Send>>;
13pub type DeviceAuthRequestHook =
14    Arc<dyn Fn(String, Option<String>) -> DeviceAuthRequestFuture + Send + Sync>;
15
/// Reasons a [`DeviceAuthorizationOptions`] value fails validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeviceAuthorizationOptionsError {
    /// `device_code_length` was zero.
    EmptyDeviceCodeLength,
    /// `user_code_length` was zero.
    EmptyUserCodeLength,
    /// `expires_in` was zero or negative.
    NonPositiveExpiresIn,
    /// `interval` was zero or negative.
    NonPositiveInterval,
}
23
24impl std::fmt::Display for DeviceAuthorizationOptionsError {
25    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        let message = match self {
27            Self::EmptyDeviceCodeLength => "device code length must be greater than zero",
28            Self::EmptyUserCodeLength => "user code length must be greater than zero",
29            Self::NonPositiveExpiresIn => "expires_in must be positive",
30            Self::NonPositiveInterval => "interval must be positive",
31        };
32        formatter.write_str(message)
33    }
34}
35
36impl std::error::Error for DeviceAuthorizationOptionsError {}
37
38#[derive(Clone)]
39pub struct DeviceAuthorizationOptions {
40    pub expires_in: Duration,
41    pub interval: Duration,
42    pub device_code_length: usize,
43    pub user_code_length: usize,
44    pub generate_device_code: Option<AsyncDeviceCodeGenerator>,
45    pub generate_user_code: Option<AsyncDeviceCodeGenerator>,
46    pub validate_client: Option<ClientValidator>,
47    pub on_device_auth_request: Option<DeviceAuthRequestHook>,
48    pub verification_uri: String,
49    pub schema: DeviceAuthorizationSchemaOptions,
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq)]
53pub struct DeviceAuthorizationSchemaOptions {
54    pub table_name: Option<String>,
55    pub fields: DeviceAuthorizationSchemaFields,
56}
57
58impl DeviceAuthorizationSchemaOptions {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    #[must_use]
64    pub fn table_name(mut self, table_name: impl Into<String>) -> Self {
65        self.table_name = Some(table_name.into());
66        self
67    }
68
69    #[must_use]
70    pub fn field_name(
71        mut self,
72        logical_name: impl Into<String>,
73        physical_name: impl Into<String>,
74    ) -> Self {
75        self.fields.set(logical_name.into(), physical_name.into());
76        self
77    }
78}
79
/// Optional physical column names for each logical device authorization field.
///
/// A `None` entry means no override was supplied for that column.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DeviceAuthorizationSchemaFields {
    /// Override for the `id` column.
    pub id: Option<String>,
    /// Override for the `device_code` column.
    pub device_code: Option<String>,
    /// Override for the `user_code` column.
    pub user_code: Option<String>,
    /// Override for the `user_id` column.
    pub user_id: Option<String>,
    /// Override for the `expires_at` column.
    pub expires_at: Option<String>,
    /// Override for the `status` column.
    pub status: Option<String>,
    /// Override for the `last_polled_at` column.
    pub last_polled_at: Option<String>,
    /// Override for the `polling_interval` column.
    pub polling_interval: Option<String>,
    /// Override for the `client_id` column.
    pub client_id: Option<String>,
    /// Override for the `scope` column.
    pub scope: Option<String>,
    /// Override for the `created_at` column.
    pub created_at: Option<String>,
    /// Override for the `updated_at` column.
    pub updated_at: Option<String>,
}
95
96impl DeviceAuthorizationSchemaFields {
97    fn set(&mut self, logical_name: String, physical_name: String) {
98        match logical_name.as_str() {
99            "id" => self.id = Some(physical_name),
100            "device_code" | "deviceCode" => self.device_code = Some(physical_name),
101            "user_code" | "userCode" => self.user_code = Some(physical_name),
102            "user_id" | "userId" => self.user_id = Some(physical_name),
103            "expires_at" | "expiresAt" => self.expires_at = Some(physical_name),
104            "status" => self.status = Some(physical_name),
105            "last_polled_at" | "lastPolledAt" => self.last_polled_at = Some(physical_name),
106            "polling_interval" | "pollingInterval" => {
107                self.polling_interval = Some(physical_name);
108            }
109            "client_id" | "clientId" => self.client_id = Some(physical_name),
110            "scope" => self.scope = Some(physical_name),
111            "created_at" | "createdAt" => self.created_at = Some(physical_name),
112            "updated_at" | "updatedAt" => self.updated_at = Some(physical_name),
113            _ => {}
114        }
115    }
116}
117
118impl Default for DeviceAuthorizationOptions {
119    fn default() -> Self {
120        Self {
121            expires_in: Duration::minutes(30),
122            interval: Duration::seconds(5),
123            device_code_length: 40,
124            user_code_length: 8,
125            generate_device_code: None,
126            generate_user_code: None,
127            validate_client: None,
128            on_device_auth_request: None,
129            verification_uri: "/device".to_owned(),
130            schema: DeviceAuthorizationSchemaOptions::default(),
131        }
132    }
133}
134
135impl DeviceAuthorizationOptions {
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    #[must_use]
141    pub fn builder() -> DeviceAuthorizationOptionsBuilder {
142        DeviceAuthorizationOptionsBuilder::default()
143    }
144
145    pub fn validate(&self) -> Result<(), DeviceAuthorizationOptionsError> {
146        if self.device_code_length == 0 {
147            return Err(DeviceAuthorizationOptionsError::EmptyDeviceCodeLength);
148        }
149        if self.user_code_length == 0 {
150            return Err(DeviceAuthorizationOptionsError::EmptyUserCodeLength);
151        }
152        if self.expires_in <= Duration::ZERO {
153            return Err(DeviceAuthorizationOptionsError::NonPositiveExpiresIn);
154        }
155        if self.interval <= Duration::ZERO {
156            return Err(DeviceAuthorizationOptionsError::NonPositiveInterval);
157        }
158        Ok(())
159    }
160
161    #[must_use]
162    pub fn expires_in(mut self, expires_in: Duration) -> Self {
163        self.expires_in = expires_in;
164        self
165    }
166
167    #[must_use]
168    pub fn interval(mut self, interval: Duration) -> Self {
169        self.interval = interval;
170        self
171    }
172
173    #[must_use]
174    pub fn device_code_length(mut self, length: usize) -> Self {
175        self.device_code_length = length;
176        self
177    }
178
179    #[must_use]
180    pub fn user_code_length(mut self, length: usize) -> Self {
181        self.user_code_length = length;
182        self
183    }
184
185    #[must_use]
186    pub fn generate_device_code<F>(mut self, generator: F) -> Self
187    where
188        F: Fn() -> String + Send + Sync + 'static,
189    {
190        self.generate_device_code =
191            Some(Arc::new(move || Box::pin(std::future::ready(generator()))));
192        self
193    }
194
195    #[must_use]
196    pub fn generate_user_code<F>(mut self, generator: F) -> Self
197    where
198        F: Fn() -> String + Send + Sync + 'static,
199    {
200        self.generate_user_code = Some(Arc::new(move || Box::pin(std::future::ready(generator()))));
201        self
202    }
203
204    #[must_use]
205    pub fn generate_device_code_async<F, Fut>(mut self, generator: F) -> Self
206    where
207        F: Fn() -> Fut + Send + Sync + 'static,
208        Fut: Future<Output = String> + Send + 'static,
209    {
210        self.generate_device_code = Some(Arc::new(move || Box::pin(generator())));
211        self
212    }
213
214    #[must_use]
215    pub fn generate_user_code_async<F, Fut>(mut self, generator: F) -> Self
216    where
217        F: Fn() -> Fut + Send + Sync + 'static,
218        Fut: Future<Output = String> + Send + 'static,
219    {
220        self.generate_user_code = Some(Arc::new(move || Box::pin(generator())));
221        self
222    }
223
224    #[must_use]
225    pub fn validate_client<F, Fut>(mut self, validator: F) -> Self
226    where
227        F: Fn(String) -> Fut + Send + Sync + 'static,
228        Fut: Future<Output = Result<bool, RustAuthError>> + Send + 'static,
229    {
230        self.validate_client = Some(Arc::new(move |client_id| Box::pin(validator(client_id))));
231        self
232    }
233
234    #[must_use]
235    pub fn on_device_auth_request<F, Fut>(mut self, hook: F) -> Self
236    where
237        F: Fn(String, Option<String>) -> Fut + Send + Sync + 'static,
238        Fut: Future<Output = Result<(), RustAuthError>> + Send + 'static,
239    {
240        self.on_device_auth_request = Some(Arc::new(move |client_id, scope| {
241            Box::pin(hook(client_id, scope))
242        }));
243        self
244    }
245
246    #[must_use]
247    pub fn verification_uri(mut self, uri: impl Into<String>) -> Self {
248        self.verification_uri = uri.into();
249        self
250    }
251
252    #[must_use]
253    pub fn schema(mut self, schema: DeviceAuthorizationSchemaOptions) -> Self {
254        self.schema = schema;
255        self
256    }
257}
258
259#[derive(Clone, Default)]
260pub struct DeviceAuthorizationOptionsBuilder {
261    expires_in: Option<Duration>,
262    interval: Option<Duration>,
263    device_code_length: Option<usize>,
264    user_code_length: Option<usize>,
265    generate_device_code: Option<AsyncDeviceCodeGenerator>,
266    generate_user_code: Option<AsyncDeviceCodeGenerator>,
267    validate_client: Option<ClientValidator>,
268    on_device_auth_request: Option<DeviceAuthRequestHook>,
269    verification_uri: Option<String>,
270    schema: Option<DeviceAuthorizationSchemaOptions>,
271}
272
273impl DeviceAuthorizationOptionsBuilder {
274    #[must_use]
275    pub fn expires_in(mut self, expires_in: Duration) -> Self {
276        self.expires_in = Some(expires_in);
277        self
278    }
279
280    #[must_use]
281    pub fn interval(mut self, interval: Duration) -> Self {
282        self.interval = Some(interval);
283        self
284    }
285
286    #[must_use]
287    pub fn device_code_length(mut self, length: usize) -> Self {
288        self.device_code_length = Some(length);
289        self
290    }
291
292    #[must_use]
293    pub fn user_code_length(mut self, length: usize) -> Self {
294        self.user_code_length = Some(length);
295        self
296    }
297
298    #[must_use]
299    pub fn generate_device_code(mut self, generator: AsyncDeviceCodeGenerator) -> Self {
300        self.generate_device_code = Some(generator);
301        self
302    }
303
304    #[must_use]
305    pub fn generate_user_code(mut self, generator: AsyncDeviceCodeGenerator) -> Self {
306        self.generate_user_code = Some(generator);
307        self
308    }
309
310    #[must_use]
311    pub fn validate_client(mut self, validator: ClientValidator) -> Self {
312        self.validate_client = Some(validator);
313        self
314    }
315
316    #[must_use]
317    pub fn on_device_auth_request(mut self, hook: DeviceAuthRequestHook) -> Self {
318        self.on_device_auth_request = Some(hook);
319        self
320    }
321
322    #[must_use]
323    pub fn verification_uri(mut self, uri: impl Into<String>) -> Self {
324        self.verification_uri = Some(uri.into());
325        self
326    }
327
328    #[must_use]
329    pub fn schema(mut self, schema: DeviceAuthorizationSchemaOptions) -> Self {
330        self.schema = Some(schema);
331        self
332    }
333
334    pub fn build(self) -> Result<DeviceAuthorizationOptions, RustAuthError> {
335        let defaults = DeviceAuthorizationOptions::default();
336        let options = DeviceAuthorizationOptions {
337            expires_in: self.expires_in.unwrap_or(defaults.expires_in),
338            interval: self.interval.unwrap_or(defaults.interval),
339            device_code_length: self
340                .device_code_length
341                .unwrap_or(defaults.device_code_length),
342            user_code_length: self.user_code_length.unwrap_or(defaults.user_code_length),
343            generate_device_code: self.generate_device_code.or(defaults.generate_device_code),
344            generate_user_code: self.generate_user_code.or(defaults.generate_user_code),
345            validate_client: self.validate_client.or(defaults.validate_client),
346            on_device_auth_request: self
347                .on_device_auth_request
348                .or(defaults.on_device_auth_request),
349            verification_uri: self.verification_uri.unwrap_or(defaults.verification_uri),
350            schema: self.schema.unwrap_or(defaults.schema),
351        };
352        options
353            .validate()
354            .map_err(|error| RustAuthError::InvalidConfig(error.to_string()))?;
355        Ok(options)
356    }
357}