1use rustauth_core::context::AuthContext;
2use rustauth_core::crypto::random::generate_random_string;
3use rustauth_core::db::{
4 DbAdapter, DbRecord, DbSchema, DbValue, Delete, FindMany, FindOne, SchemaTable, Update,
5};
6use rustauth_core::error::RustAuthError;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use time::OffsetDateTime;
10
11const PASSKEY_MODEL: &str = "passkey";
12
/// A stored WebAuthn passkey belonging to a user.
///
/// Built from a `passkey` database record by `passkey_from_record`. Serialized with
/// camelCase field names; optional fields are omitted from the output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Passkey {
    /// Primary key of the passkey row (generated with `generate_random_string(32)` on create).
    pub id: String,
    /// Optional user-facing label for the passkey.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Stored public key of the credential (as provided by the verified registration).
    pub public_key: String,
    /// Id of the owning user.
    pub user_id: String,
    /// Credential id; serialized as `credentialID` rather than the camelCase default.
    #[serde(rename = "credentialID")]
    pub credential_id: String,
    /// Authenticator counter; replaced with the verified value after each authentication.
    pub counter: i64,
    /// Device type string as reported by the verified credential.
    pub device_type: String,
    /// Backed-up flag as reported by the verified credential.
    pub backed_up: bool,
    /// Optional transports string as reported by the verified credential.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub transports: Option<String>,
    /// Creation time; set to the current UTC time when created through `PasskeyStore::create`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<OffsetDateTime>,
    /// Optional authenticator AAGUID.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aaguid: Option<String>,
    /// Full stored WebAuthn credential JSON. Never serialized or deserialized (`serde(skip)`);
    /// `Value::Null` when the record has no stored credential (e.g. legacy rows).
    #[serde(skip)]
    pub webauthn_credential: Value,
}
35
/// Data-access helper for the `passkey` model.
///
/// Borrows a database adapter for lifetime `'a` and owns a copy of the schema used to
/// resolve the passkey table and its field names.
#[derive(Clone)]
pub struct PasskeyStore<'a> {
    /// Adapter used to execute every query issued by this store.
    adapter: &'a dyn DbAdapter,
    /// Schema used to resolve the `passkey` model (see `PasskeyStore::passkeys`).
    schema: DbSchema,
}
41
42impl Passkey {
43 pub(crate) fn registration_exclude_value(&self) -> Value {
45 if !self.webauthn_credential.is_null() {
46 self.webauthn_credential.clone()
47 } else {
48 Value::String(self.credential_id.clone())
49 }
50 }
51
52 pub(crate) fn authentication_credential_value(&self) -> Result<Option<Value>, RustAuthError> {
57 if !self.webauthn_credential.is_null() {
58 return Ok(Some(self.webauthn_credential.clone()));
59 }
60 crate::webauthn::legacy_passkey_credential_value(
61 &self.credential_id,
62 &self.public_key,
63 self.counter,
64 &self.device_type,
65 self.backed_up,
66 self.transports.as_deref(),
67 )
68 .map(Some)
69 }
70}
71
72impl<'a> PasskeyStore<'a> {
73 pub fn with_schema(adapter: &'a dyn DbAdapter, schema: DbSchema) -> Self {
74 Self { adapter, schema }
75 }
76
77 pub fn from_context(context: &'a AuthContext) -> Result<Self, RustAuthError> {
78 Ok(Self::with_schema(
79 context.adapter_ref()?,
80 context.db_schema.clone(),
81 ))
82 }
83
84 pub fn new(context: &'a AuthContext) -> Result<Self, RustAuthError> {
86 Self::from_context(context)
87 }
88
89 fn passkeys(&self) -> Result<SchemaTable<'_>, RustAuthError> {
90 SchemaTable::new(&self.schema, PASSKEY_MODEL)
91 }
92
93 fn parse_passkey(&self, record: DbRecord) -> Result<Passkey, RustAuthError> {
94 passkey_from_record(self.passkeys()?.map_record(record)?)
95 }
96
97 pub async fn list_by_user(&self, user_id: &str) -> Result<Vec<Passkey>, RustAuthError> {
98 let passkeys = self.passkeys()?;
99 self.adapter
100 .find_many(
101 FindMany::new(passkeys.model()).where_clause(
102 passkeys.where_eq("user_id", DbValue::String(user_id.to_owned()))?,
103 ),
104 )
105 .await?
106 .into_iter()
107 .map(|record| self.parse_passkey(record))
108 .collect()
109 }
110
111 pub async fn find_by_id(&self, id: &str) -> Result<Option<Passkey>, RustAuthError> {
112 let passkeys = self.passkeys()?;
113 self.adapter
114 .find_one(
115 FindOne::new(passkeys.model())
116 .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?),
117 )
118 .await?
119 .map(|record| self.parse_passkey(record))
120 .transpose()
121 }
122
123 pub async fn find_by_credential_id(
124 &self,
125 credential_id: &str,
126 ) -> Result<Option<Passkey>, RustAuthError> {
127 let passkeys = self.passkeys()?;
128 self.adapter
129 .find_one(FindOne::new(passkeys.model()).where_clause(
130 passkeys.where_eq("credential_id", DbValue::String(credential_id.to_owned()))?,
131 ))
132 .await?
133 .map(|record| self.parse_passkey(record))
134 .transpose()
135 }
136
137 pub async fn create(
138 &self,
139 user_id: &str,
140 name: Option<String>,
141 credential: crate::webauthn::VerifiedPasskeyCredential,
142 ) -> Result<Passkey, RustAuthError> {
143 let passkeys = self.passkeys()?;
144 let now = OffsetDateTime::now_utc();
145 let record = self
146 .adapter
147 .create(
148 passkeys
149 .create()
150 .data("id", DbValue::String(generate_random_string(32)))
151 .data("name", optional_string(name))
152 .data("public_key", DbValue::String(credential.public_key))
153 .data("user_id", DbValue::String(user_id.to_owned()))
154 .data("credential_id", DbValue::String(credential.credential_id))
155 .data("counter", DbValue::Number(i64::from(credential.counter)))
156 .data("device_type", DbValue::String(credential.device_type))
157 .data("backed_up", DbValue::Boolean(credential.backed_up))
158 .data("transports", optional_string(credential.transports))
159 .data("created_at", DbValue::Timestamp(now))
160 .data("aaguid", optional_string(credential.aaguid))
161 .data("webauthn_credential", DbValue::Json(credential.credential))
162 .force_allow_id(),
163 )
164 .await?;
165 self.parse_passkey(record)
166 }
167
168 pub async fn update_name_for_user(
169 &self,
170 id: &str,
171 user_id: &str,
172 name: String,
173 ) -> Result<Option<Passkey>, RustAuthError> {
174 let passkeys = self.passkeys()?;
175 self.adapter
176 .update(
177 Update::new(passkeys.model())
178 .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?)
179 .where_clause(
180 passkeys.where_eq("user_id", DbValue::String(user_id.to_owned()))?,
181 )
182 .data("name", DbValue::String(name)),
183 )
184 .await?
185 .map(|record| self.parse_passkey(record))
186 .transpose()
187 }
188
189 pub async fn update_after_authentication(
190 &self,
191 id: &str,
192 expected_counter: i64,
193 verification: crate::webauthn::VerifiedAuthentication,
194 ) -> Result<Option<Passkey>, RustAuthError> {
195 let passkeys = self.passkeys()?;
196 let mut update = Update::new(passkeys.model())
197 .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?)
198 .where_clause(passkeys.where_eq("counter", DbValue::Number(expected_counter))?)
199 .data(
200 "counter",
201 DbValue::Number(i64::from(verification.new_counter)),
202 );
203 if let Some(credential) = verification.credential {
204 update = update.data("webauthn_credential", DbValue::Json(credential));
205 }
206 self.adapter
207 .update(update)
208 .await?
209 .map(|record| self.parse_passkey(record))
210 .transpose()
211 }
212
213 pub async fn delete_for_user(&self, id: &str, user_id: &str) -> Result<bool, RustAuthError> {
214 let passkeys = self.passkeys()?;
215 let Some(passkey) = self.find_by_id(id).await? else {
216 return Ok(false);
217 };
218 if passkey.user_id != user_id {
219 return Ok(false);
220 }
221 self.adapter
222 .delete(
223 Delete::new(passkeys.model())
224 .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?),
225 )
226 .await?;
227 Ok(true)
228 }
229}
230
231fn optional_string(value: Option<String>) -> DbValue {
232 value.map(DbValue::String).unwrap_or(DbValue::Null)
233}
234
235fn passkey_from_record(record: DbRecord) -> Result<Passkey, RustAuthError> {
236 Ok(Passkey {
237 id: required_string(&record, "id")?.to_owned(),
238 name: optional_string_field(&record, "name")?,
239 public_key: required_string(&record, "public_key")?.to_owned(),
240 user_id: required_string(&record, "user_id")?.to_owned(),
241 credential_id: required_string(&record, "credential_id")?.to_owned(),
242 counter: required_number(&record, "counter")?,
243 device_type: required_string(&record, "device_type")?.to_owned(),
244 backed_up: required_bool(&record, "backed_up")?,
245 transports: optional_string_field(&record, "transports")?,
246 created_at: optional_timestamp(&record, "created_at")?,
247 aaguid: optional_string_field(&record, "aaguid")?,
248 webauthn_credential: match record.get("webauthn_credential") {
249 Some(DbValue::Json(value)) => value.clone(),
250 Some(DbValue::Null) | None => Value::Null,
251 Some(_) => return Err(invalid_field("webauthn_credential", "json")),
252 },
253 })
254}
255
256fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
257 match record.get(field) {
258 Some(DbValue::String(value)) => Ok(value),
259 Some(_) => Err(invalid_field(field, "string")),
260 None => Err(missing_field(field)),
261 }
262}
263
264fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
265 match record.get(field) {
266 Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
267 Some(DbValue::Null) | None => Ok(None),
268 Some(_) => Err(invalid_field(field, "string or null")),
269 }
270}
271
272fn required_number(record: &DbRecord, field: &str) -> Result<i64, RustAuthError> {
273 match record.get(field) {
274 Some(DbValue::Number(value)) => Ok(*value),
275 Some(_) => Err(invalid_field(field, "number")),
276 None => Err(missing_field(field)),
277 }
278}
279
280fn required_bool(record: &DbRecord, field: &str) -> Result<bool, RustAuthError> {
281 match record.get(field) {
282 Some(DbValue::Boolean(value)) => Ok(*value),
283 Some(_) => Err(invalid_field(field, "boolean")),
284 None => Err(missing_field(field)),
285 }
286}
287
288fn optional_timestamp(
289 record: &DbRecord,
290 field: &str,
291) -> Result<Option<OffsetDateTime>, RustAuthError> {
292 match record.get(field) {
293 Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
294 Some(DbValue::Null) | None => Ok(None),
295 Some(_) => Err(invalid_field(field, "timestamp or null")),
296 }
297}
298
299fn missing_field(field: &str) -> RustAuthError {
300 RustAuthError::Adapter(format!("passkey record is missing `{field}`"))
301}
302
303fn invalid_field(field: &str, expected: &str) -> RustAuthError {
304 RustAuthError::Adapter(format!("passkey record field `{field}` must be {expected}"))
305}