Skip to main content

stormchaser_api/routes/
auth.rs

1use super::{AuthExchangeRequest, AuthExchangeResponse, AuthRefreshRequest};
2use crate::auth;
3use crate::{AppState, Claims, JWT_SECRET};
4use axum::{
5    extract::{Query, State},
6    http::StatusCode,
7    response::{IntoResponse, Redirect},
8    Json,
9};
10use jsonwebtoken::decode;
11use jsonwebtoken::decode_header;
12use jsonwebtoken::DecodingKey;
13use jsonwebtoken::Validation;
14use jsonwebtoken::{encode, EncodingKey, Header};
15use reqwest::Client;
16use serde::Deserialize;
17use std::time::{SystemTime, UNIX_EPOCH};
18
19#[derive(Debug, Deserialize)]
20/// Loginquery.
21pub struct LoginQuery {
22    /// The callback url.
23    pub callback_url: String,
24}
25
26#[utoipa::path(
27    get,
28    path = "/api/v1/auth/login",
29    params(
30        ("callback_url" = String, Query, description = "The local URL to redirect back to after login")
31    ),
32    responses(
33        (status = 303, description = "Redirect to OIDC provider")
34    ),
35    tag = "stormchaser"
36)]
37/// Login.
38pub async fn login(
39    State(state): State<AppState>,
40    Query(query): Query<LoginQuery>,
41) -> Result<impl IntoResponse, StatusCode> {
42    let oidc_config = state
43        .oidc_config
44        .as_ref()
45        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
46
47    let auth_url = format!(
48        "{}/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid+profile+email+offline_access",
49        oidc_config.external_issuer,
50        oidc_config.client_id,
51        urlencoding::encode(&query.callback_url)
52    );
53
54    Ok(Redirect::to(&auth_url))
55}
56
57#[derive(Debug, Deserialize)]
58struct TokenResponse {
59    id_token: String,
60    refresh_token: Option<String>,
61}
62
63#[utoipa::path(
64    post,
65    path = "/api/v1/auth/exchange",
66    request_body = AuthExchangeRequest,
67    responses(
68        (status = 200, description = "Token exchanged successfully", body = AuthExchangeResponse),
69        (status = 401, description = "Unauthorized")
70    ),
71    tag = "stormchaser"
72)]
73/// Exchange token.
74pub async fn exchange_token(
75    State(state): State<AppState>,
76    Json(payload): Json<AuthExchangeRequest>,
77) -> Result<impl IntoResponse, StatusCode> {
78    let oidc_config = state
79        .oidc_config
80        .as_ref()
81        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
82
83    // 1. Exchange code for tokens
84    let client = Client::new();
85    let token_url = format!("{}/token", oidc_config.issuer.trim_end_matches('/'));
86
87    let params = [
88        ("grant_type", "authorization_code"),
89        ("code", &payload.sso_token),
90        ("client_id", &oidc_config.client_id),
91        ("client_secret", &oidc_config.client_secret),
92        ("redirect_uri", &payload.callback_url),
93    ];
94
95    let res = client
96        .post(&token_url)
97        .form(&params)
98        .send()
99        .await
100        .map_err(|e| {
101            tracing::error!("Failed to send token exchange request: {:?}", e);
102            StatusCode::UNAUTHORIZED
103        })?;
104
105    if !res.status().is_success() {
106        let status = res.status();
107        let err_body = res.text().await.unwrap_or_default();
108        tracing::error!("Token exchange failed with status {}: {}", status, err_body);
109        return Err(StatusCode::UNAUTHORIZED);
110    }
111
112    let token_res: TokenResponse = res.json().await.map_err(|e| {
113        tracing::error!("Failed to parse token response: {:?}", e);
114        StatusCode::UNAUTHORIZED
115    })?;
116
117    // 2. Validate the ID Token
118    let header = match decode_header(&token_res.id_token) {
119        Ok(h) => h,
120        Err(e) => {
121            tracing::error!("Failed to decode id_token header: {:?}", e);
122            return Err(StatusCode::UNAUTHORIZED);
123        }
124    };
125
126    let kid = match header.kid {
127        Some(k) => k,
128        None => {
129            tracing::error!("No kid in id_token header");
130            return Err(StatusCode::UNAUTHORIZED);
131        }
132    };
133
134    let jwk_opt = state.jwks.read().await.get(&kid).cloned();
135    let jwk = match jwk_opt {
136        Some(j) => j,
137        None => {
138            tracing::warn!("kid {} not found in JWKS cache, attempting refresh", kid);
139            let new_jwks = auth::jwks::fetch_jwks(&oidc_config.jwks_url).await;
140            let mut jwks_write = state.jwks.write().await;
141            *jwks_write = new_jwks;
142
143            match jwks_write.get(&kid) {
144                Some(j) => j.clone(),
145                None => {
146                    tracing::error!("kid {} not found in JWKS cache even after refresh", kid);
147                    return Err(StatusCode::UNAUTHORIZED);
148                }
149            }
150        }
151    };
152
153    let mut validation = Validation::new(header.alg);
154    validation.set_audience(std::slice::from_ref(&oidc_config.client_id));
155    validation.set_issuer(&[
156        oidc_config.issuer.as_str(),
157        oidc_config.external_issuer.as_str(),
158    ]);
159
160    let decoding_key = match DecodingKey::from_jwk(&jwk) {
161        Ok(k) => k,
162        Err(e) => {
163            tracing::error!("Failed to create decoding key from JWK: {:?}", e);
164            return Err(StatusCode::UNAUTHORIZED);
165        }
166    };
167
168    let token_data = match decode::<Claims>(&token_res.id_token, &decoding_key, &validation) {
169        Ok(d) => d,
170        Err(e) => {
171            tracing::error!("Failed to validate id_token: {:?}", e);
172            return Err(StatusCode::UNAUTHORIZED);
173        }
174    };
175
176    // 3. Generate Stormchaser Access Token
177    let user_id = token_data.claims.sub;
178    let email = token_data.claims.email;
179
180    let expires_in = 3600;
181    let expiration = SystemTime::now()
182        .duration_since(UNIX_EPOCH)
183        .unwrap()
184        .as_secs() as usize
185        + expires_in;
186
187    let claims = Claims {
188        sub: user_id,
189        email,
190        exp: expiration,
191    };
192
193    let token = encode(
194        &Header::default(),
195        &claims,
196        &EncodingKey::from_secret(JWT_SECRET),
197    )
198    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
199
200    Ok(Json(AuthExchangeResponse {
201        access_token: token,
202        refresh_token: token_res.refresh_token,
203        token_type: "Bearer".to_string(),
204        expires_in,
205    }))
206}
207
208/// Refreshes the auth token.
209#[utoipa::path(
210    post,
211    path = "/api/v1/auth/refresh",
212    request_body = AuthRefreshRequest,
213    responses(
214        (status = 200, description = "Token refreshed successfully", body = AuthExchangeResponse),
215        (status = 401, description = "Unauthorized")
216    ),
217    tag = "stormchaser"
218)]
219/// Refresh token.
220pub async fn refresh_token(
221    State(state): State<AppState>,
222    Json(payload): Json<AuthRefreshRequest>,
223) -> Result<impl IntoResponse, StatusCode> {
224    let oidc_config = state
225        .oidc_config
226        .as_ref()
227        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
228
229    // Exchange refresh token for new tokens
230    let client = Client::new();
231    let token_url = format!("{}/token", oidc_config.issuer.trim_end_matches('/'));
232
233    let params = [
234        ("grant_type", "refresh_token"),
235        ("refresh_token", &payload.refresh_token),
236        ("client_id", &oidc_config.client_id),
237        ("client_secret", &oidc_config.client_secret),
238    ];
239
240    let res = client
241        .post(&token_url)
242        .form(&params)
243        .send()
244        .await
245        .map_err(|e| {
246            tracing::error!("Failed to send refresh request: {:?}", e);
247            StatusCode::UNAUTHORIZED
248        })?;
249
250    if !res.status().is_success() {
251        let status = res.status();
252        let err_body = res.text().await.unwrap_or_default();
253        tracing::error!("Token refresh failed with status {}: {}", status, err_body);
254        return Err(StatusCode::UNAUTHORIZED);
255    }
256
257    let token_res: TokenResponse = res.json().await.map_err(|e| {
258        tracing::error!("Failed to parse refresh response: {:?}", e);
259        StatusCode::UNAUTHORIZED
260    })?;
261
262    // Validate the new ID Token
263    let header = match decode_header(&token_res.id_token) {
264        Ok(h) => h,
265        Err(e) => {
266            tracing::error!("Failed to decode id_token header: {:?}", e);
267            return Err(StatusCode::UNAUTHORIZED);
268        }
269    };
270
271    let kid = header.kid.ok_or_else(|| {
272        tracing::error!("No kid in id_token header");
273        StatusCode::UNAUTHORIZED
274    })?;
275
276    let jwk_opt = state.jwks.read().await.get(&kid).cloned();
277    let jwk = match jwk_opt {
278        Some(j) => j,
279        None => {
280            tracing::warn!("kid {} not found in JWKS cache, attempting refresh", kid);
281            let new_jwks = auth::jwks::fetch_jwks(&oidc_config.jwks_url).await;
282            let mut jwks_write = state.jwks.write().await;
283            *jwks_write = new_jwks;
284
285            match jwks_write.get(&kid) {
286                Some(j) => j.clone(),
287                None => {
288                    tracing::error!("kid {} not found in JWKS cache even after refresh", kid);
289                    return Err(StatusCode::UNAUTHORIZED);
290                }
291            }
292        }
293    };
294
295    let mut validation = Validation::new(header.alg);
296    validation.set_audience(std::slice::from_ref(&oidc_config.client_id));
297    validation.set_issuer(&[
298        oidc_config.issuer.as_str(),
299        oidc_config.external_issuer.as_str(),
300    ]);
301
302    let decoding_key = DecodingKey::from_jwk(&jwk).map_err(|e| {
303        tracing::error!("Failed to create decoding key: {:?}", e);
304        StatusCode::UNAUTHORIZED
305    })?;
306
307    let token_data =
308        decode::<Claims>(&token_res.id_token, &decoding_key, &validation).map_err(|e| {
309            tracing::error!("Failed to validate id_token: {:?}", e);
310            StatusCode::UNAUTHORIZED
311        })?;
312
313    let user_id = token_data.claims.sub;
314    let email = token_data.claims.email;
315    let expires_in = 3600;
316    let expiration = SystemTime::now()
317        .duration_since(UNIX_EPOCH)
318        .unwrap()
319        .as_secs() as usize
320        + expires_in;
321
322    let claims = Claims {
323        sub: user_id,
324        email,
325        exp: expiration,
326    };
327
328    let token = encode(
329        &Header::default(),
330        &claims,
331        &EncodingKey::from_secret(JWT_SECRET),
332    )
333    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
334
335    Ok(Json(AuthExchangeResponse {
336        access_token: token,
337        refresh_token: token_res.refresh_token,
338        token_type: "Bearer".to_string(),
339        expires_in,
340    }))
341}