spacetimedb_client_api/routes/
identity.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
use std::time::Duration;

use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use axum::Extension;
use chrono::Utc;
use http::header::CONTENT_TYPE;
use http::StatusCode;
use rand::Rng;
use serde::{Deserialize, Serialize};

use spacetimedb::auth::identity::{encode_token, encode_token_with_expiry};
use spacetimedb::messages::control_db::IdentityEmail;
use spacetimedb_client_api_messages::recovery::{RecoveryCode, RecoveryCodeResponse};
use spacetimedb_lib::de::serde::DeserializeWrapper;
use spacetimedb_lib::{Address, Identity};

use crate::auth::{anon_auth_middleware, SpacetimeAuth, SpacetimeAuthRequired};
use crate::{log_and_500, ControlStateDelegate, ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate};

/// Query string accepted by [`create_identity`].
#[derive(Deserialize)]
pub struct CreateIdentityQueryParams {
    /// Optional email address to attach to the newly allocated identity.
    email: Option<email_address::EmailAddress>,
}

/// JSON body returned by [`create_identity`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateIdentityResponse {
    /// The freshly allocated identity.
    identity: Identity,
    /// Token taken from the credentials produced by `SpacetimeAuth::alloc`.
    token: String,
}

/// Handler for `POST /` on the identity router: allocates a brand-new identity.
///
/// If the `email` query parameter is present, it is associated with the new identity
/// before responding. Failures while storing it are mapped through `log_and_500`.
/// The response carries the identity and its token as JSON.
pub async fn create_identity<S: ControlStateDelegate + NodeDelegate>(
    State(ctx): State<S>,
    Query(CreateIdentityQueryParams { email }): Query<CreateIdentityQueryParams>,
) -> axum::response::Result<impl IntoResponse> {
    let auth = SpacetimeAuth::alloc(&ctx).await?;

    if let Some(address) = email {
        ctx.add_email(&auth.identity, address.as_str())
            .await
            .map_err(log_and_500)?;
    }

    let token = auth.creds.token().to_owned();
    Ok(axum::Json(CreateIdentityResponse {
        identity: auth.identity,
        token,
    }))
}

/// JSON body returned by [`get_identity`].
#[derive(Debug, Clone, Serialize)]
pub struct GetIdentityResponse {
    /// Every identity/email pairing found for the requested email.
    identities: Vec<GetIdentityResponseEntry>,
}

/// A single identity/email pairing inside a [`GetIdentityResponse`].
#[derive(Debug, Clone, Serialize)]
pub struct GetIdentityResponseEntry {
    identity: Identity,
    email: String,
}

/// Query string accepted by [`get_identity`].
#[derive(Deserialize)]
pub struct GetIdentityQueryParams {
    /// Email to look up. Optional at the type level, but requests without it are rejected.
    email: Option<String>,
}

/// Handler for `GET /` on the identity router.
///
/// Lists all identities associated with the `email` query parameter.
/// Responds with `400 Bad Request` when `email` is omitted.
pub async fn get_identity<S: ControlStateDelegate>(
    State(ctx): State<S>,
    Query(GetIdentityQueryParams { email }): Query<GetIdentityQueryParams>,
) -> axum::response::Result<impl IntoResponse> {
    let Some(email) = email else {
        return Err(StatusCode::BAD_REQUEST.into());
    };

    let identities = ctx
        .get_identities_for_email(email.as_str())
        .map_err(log_and_500)?
        .into_iter()
        .map(|record| GetIdentityResponseEntry {
            identity: record.identity,
            email: record.email,
        })
        .collect::<Vec<_>>();

    Ok(axum::Json(GetIdentityResponse { identities }))
}

/// URL-friendly wrapper around [`Identity`].
///
/// In SATS, `Identity` is modelled as a `ProductValue`, so deserializing it directly
/// expects that product wrapper. Route segments such as `/identity/<64 hex chars>/set-email`
/// carry only the bare identity. This newtype therefore deserializes straight from the
/// inner byte array and skips the enclosing `ProductValue`.
///
/// Convert it into a plain [`Identity`] with `Identity::from` / `.into()`; that conversion
/// is generated by `derive_more::Into`.
#[derive(derive_more::Into)]
pub struct IdentityForUrl(Identity);

impl<'de> serde::Deserialize<'de> for IdentityForUrl {
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        // Deserialize the raw bytes through the SATS-to-serde bridge, then rebuild the identity.
        let DeserializeWrapper(bytes) = serde::Deserialize::deserialize(de)?;
        Ok(Self(Identity::from_byte_array(bytes)))
    }
}

#[derive(Deserialize)]
pub struct SetEmailParams {
    identity: IdentityForUrl,
}

#[derive(Deserialize)]
pub struct SetEmailQueryParams {
    email: email_address::EmailAddress,
}

pub async fn set_email<S: ControlStateWriteAccess>(
    State(ctx): State<S>,
    Path(SetEmailParams { identity }): Path<SetEmailParams>,
    Query(SetEmailQueryParams { email }): Query<SetEmailQueryParams>,
    Extension(auth): Extension<SpacetimeAuth>,
) -> axum::response::Result<impl IntoResponse> {
    let identity = identity.into();

    if auth.identity != identity {
        return Err(StatusCode::UNAUTHORIZED.into());
    }
    ctx.add_email(&identity, email.as_str()).await.map_err(log_and_500)?;

    Ok(())
}

pub async fn check_email<S: ControlStateReadAccess>(
    State(ctx): State<S>,
    Path(SetEmailParams { identity }): Path<SetEmailParams>,
    Extension(auth): Extension<SpacetimeAuth>,
) -> axum::response::Result<impl IntoResponse> {
    let identity = identity.into();

    if auth.identity != identity {
        return Err(StatusCode::UNAUTHORIZED.into());
    }

    let emails = ctx
        .get_emails_for_identity(&identity)
        .map_err(log_and_500)?
        .into_iter()
        .map(|IdentityEmail { email, .. }| email)
        .collect::<Vec<_>>();

    Ok(axum::Json(emails))
}

#[derive(Deserialize)]
pub struct GetDatabasesParams {
    identity: IdentityForUrl,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetDatabasesResponse {
    addresses: Vec<Address>,
}

pub async fn get_databases<S: ControlStateDelegate>(
    State(ctx): State<S>,
    Path(GetDatabasesParams { identity }): Path<GetDatabasesParams>,
) -> axum::response::Result<impl IntoResponse> {
    let identity = identity.into();
    // Linear scan for all databases that have this identity, and return their addresses
    let all_dbs = ctx.get_databases().map_err(|e| {
        log::error!("Failure when retrieving databases for search: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    let addresses = all_dbs
        .iter()
        .filter(|db| db.owner_identity == identity)
        .map(|db| db.address)
        .collect();
    Ok(axum::Json(GetDatabasesResponse { addresses }))
}

/// JSON body returned by [`create_websocket_token`].
#[derive(Debug, Serialize)]
pub struct WebsocketTokenResponse {
    /// Short-lived token for the caller's identity.
    pub token: String,
}

/// Issues a token for the caller's identity that expires after 60 seconds.
///
/// The caller must present credentials (`SpacetimeAuthRequired`). The token is encoded
/// with `ctx.private_key()`, and encoding failures are mapped through `log_and_500`.
pub async fn create_websocket_token<S: NodeDelegate>(
    State(ctx): State<S>,
    SpacetimeAuthRequired(auth): SpacetimeAuthRequired,
) -> axum::response::Result<impl IntoResponse> {
    // How long the issued token remains valid.
    const TOKEN_LIFETIME: Duration = Duration::from_secs(60);

    let token = encode_token_with_expiry(ctx.private_key(), auth.identity, Some(TOKEN_LIFETIME))
        .map_err(log_and_500)?;
    Ok(axum::Json(WebsocketTokenResponse { token }))
}

#[derive(Deserialize)]
pub struct ValidateTokenParams {
    identity: IdentityForUrl,
}

pub async fn validate_token(
    Path(ValidateTokenParams { identity }): Path<ValidateTokenParams>,
    SpacetimeAuthRequired(auth): SpacetimeAuthRequired,
) -> axum::response::Result<impl IntoResponse> {
    let identity = Identity::from(identity);

    if auth.identity != identity {
        return Err(StatusCode::BAD_REQUEST.into());
    }

    Ok(StatusCode::NO_CONTENT)
}

/// Returns the node's public key bytes.
///
/// The response is labelled with the `application/pem-certificate-chain` content type.
pub async fn get_public_key<S: NodeDelegate>(State(ctx): State<S>) -> axum::response::Result<impl IntoResponse> {
    let headers = [(CONTENT_TYPE, "application/pem-certificate-chain")];
    let body = ctx.public_key_bytes().to_owned();
    Ok((headers, body))
}

/// Query string accepted by [`request_recovery_code`].
#[derive(Deserialize)]
pub struct RequestRecoveryCodeParams {
    /// Whether or not the client is requesting a login link for a web-login. This is false for CLI logins.
    #[serde(default)]
    link: bool,
    /// Email address the recovery code is sent to.
    email: String,
    /// Identity the caller wants to recover.
    identity: IdentityForUrl,
}

/// Generates a six-digit recovery code for `identity`, stores it, and emails it via SendGrid.
///
/// # Errors
/// - `500 SendGrid is disabled.` when `ctx.sendgrid_controller()` returns `None`.
/// - `400 Email is not associated with the provided identity.` when `identity` is not among
///   the identities linked to `email`.
/// - Errors from looking up identities, storing the code, or sending the email are
///   mapped through `log_and_500`.
pub async fn request_recovery_code<S: NodeDelegate + ControlStateDelegate>(
    State(ctx): State<S>,
    Query(RequestRecoveryCodeParams { link, email, identity }): Query<RequestRecoveryCodeParams>,
) -> axum::response::Result<impl IntoResponse> {
    let identity = Identity::from(identity);

    let sendgrid = match ctx.sendgrid_controller() {
        Some(controller) => controller,
        None => {
            log::error!("A recovery code was requested, but SendGrid is disabled.");
            return Err((StatusCode::INTERNAL_SERVER_ERROR, "SendGrid is disabled.").into());
        }
    };

    let email_matches_identity = ctx
        .get_identities_for_email(email.as_str())
        .map_err(log_and_500)?
        .iter()
        .any(|entry| entry.identity == identity);
    if !email_matches_identity {
        return Err((
            StatusCode::BAD_REQUEST,
            "Email is not associated with the provided identity.",
        )
            .into());
    }

    // Uniformly random value in 0..=999999, rendered as exactly six zero-padded digits.
    let code = rand::thread_rng().gen_range(0..=999999);
    let code = format!("{code:06}");

    ctx.insert_recovery_code(
        &identity,
        email.as_str(),
        RecoveryCode {
            code: code.clone(),
            generation_time: Utc::now(),
            identity,
        },
    )
    .await
    .map_err(log_and_500)?;

    sendgrid
        .send_recovery_email(email.as_str(), code.as_str(), &identity.to_hex(), link)
        .await
        .map_err(log_and_500)?;
    Ok(())
}

/// Query string accepted by [`confirm_recovery_code`].
#[derive(Deserialize)]
pub struct ConfirmRecoveryCodeParams {
    /// Email address the recovery code was sent to.
    pub email: String,
    /// Identity the caller wants a login token for.
    pub identity: IdentityForUrl,
    /// The recovery code the caller received (as issued by `request_recovery_code`).
    pub code: String,
}

/// Exchanges a recovery code for a login token for `identity`.
///
/// Note: We should be slightly more security conscious about this function because
///  we are providing a login token to the user initiating the request. We want to make
///  sure there aren't any logical issues in here that would allow a user to request a token
///  for an identity that they don't have authority over.
///
/// A token is returned only if all of the following hold:
/// - a code equal to `code` is stored for `email` *and* was issued for `identity`;
/// - the newest such code is no older than ten minutes;
/// - `identity` is still associated with `email`.
///
/// # Errors
/// - `404 Recovery code not found.` if no stored code for `email` equals `code`.
/// - `400 Recovery code doesn't match the provided identity.` if the code exists for `email`
///   but was issued only for other identities.
/// - `400 Recovery code expired.` if the newest matching code is older than ten minutes.
/// - `404 No identity associated with that email.` if `identity` is no longer linked to `email`.
/// - Control-state or token-encoding failures are mapped through `log_and_500`.
///
/// NOTE(review): this handler neither consumes the code after a successful exchange nor
/// limits the number of attempts. A code therefore stays reusable until it expires, and
/// the six-digit space is guarded only by the expiry window. Confirm whether this is
/// enforced elsewhere.
pub async fn confirm_recovery_code<S: ControlStateDelegate + NodeDelegate>(
    State(ctx): State<S>,
    Query(ConfirmRecoveryCodeParams { email, identity, code }): Query<ConfirmRecoveryCodeParams>,
) -> axum::response::Result<impl IntoResponse> {
    // Codes older than this many seconds are rejected.
    const RECOVERY_CODE_TTL_SECS: i64 = 60 * 10;

    let identity = Identity::from(identity);
    let recovery_codes = ctx.get_recovery_codes(email.as_str()).map_err(log_and_500)?;

    // Several stored codes for one email can share the same value: one email may be linked
    // to several identities, and a user may request codes repeatedly. Taking only the first
    // value match could pick a code issued for another identity, or an older expired one,
    // and wrongly reject a valid request.
    let matching: Vec<_> = recovery_codes
        .into_iter()
        .filter(|rc| rc.code == code.as_str())
        .collect();
    if matching.is_empty() {
        return Err((StatusCode::NOT_FOUND, "Recovery code not found.").into());
    }

    // Make sure the identity provided by the request matches the recovery code registration.
    // Among codes issued for this identity, prefer the newest so a stale duplicate cannot
    // shadow a fresh one.
    let Some(recovery_code) = matching
        .into_iter()
        .filter(|rc| rc.identity == identity)
        .max_by_key(|rc| rc.generation_time)
    else {
        return Err((
            StatusCode::BAD_REQUEST,
            "Recovery code doesn't match the provided identity.",
        )
            .into());
    };

    let age = Utc::now() - recovery_code.generation_time;
    if age.num_seconds() > RECOVERY_CODE_TTL_SECS {
        return Err((StatusCode::BAD_REQUEST, "Recovery code expired.").into());
    }

    if !ctx
        .get_identities_for_email(email.as_str())
        .map_err(log_and_500)?
        .iter()
        .any(|a| a.identity == identity)
    {
        // This can happen if someone changes their associated email during a recovery request.
        return Err((StatusCode::NOT_FOUND, "No identity associated with that email.").into());
    }

    // Recovery code is verified, return the identity and token to the user
    let token = encode_token(ctx.private_key(), identity).map_err(log_and_500)?;
    Ok(axum::Json(RecoveryCodeResponse { identity, token }))
}

/// Builds the axum router for the identity endpoints.
///
/// Only the `set-email` and `emails` routes are wrapped in `anon_auth_middleware`. Their
/// handlers extract a `SpacetimeAuth` request extension, which that middleware presumably
/// inserts — confirm. The `websocket_token` and `verify` handlers use the
/// `SpacetimeAuthRequired` extractor directly instead.
pub fn router<S>(ctx: S) -> axum::Router<S>
where
    S: NodeDelegate + ControlStateDelegate + Clone + 'static,
{
    use axum::routing::{get, post};

    let anon_auth = axum::middleware::from_fn_with_state(ctx, anon_auth_middleware::<S>);
    let set_email_route = post(set_email::<S>).route_layer(anon_auth.clone());
    let emails_route = get(check_email::<S>).route_layer(anon_auth);

    axum::Router::new()
        .route("/", get(get_identity::<S>).post(create_identity::<S>))
        .route("/public-key", get(get_public_key::<S>))
        .route("/request_recovery_code", post(request_recovery_code::<S>))
        .route("/confirm_recovery_code", post(confirm_recovery_code::<S>))
        .route("/websocket_token", post(create_websocket_token::<S>))
        .route("/:identity/verify", get(validate_token))
        .route("/:identity/set-email", set_email_route)
        .route("/:identity/emails", emails_route)
        .route("/:identity/databases", get(get_databases::<S>))
}