Skip to main content

systemprompt_api/routes/oauth/endpoints/token/
validation.rs

1use super::{TokenError, TokenResult};
2use anyhow::Result;
3use systemprompt_identifiers::{AuthorizationCode, ClientId};
4use systemprompt_oauth::repository::{AuthCodeValidationResult, OAuthRepository};
5use systemprompt_oauth::services::validation::validate_client_credentials as validate_client_credentials_shared;
6
7pub fn extract_required_field<'a>(
8    field: Option<&'a str>,
9    field_name: &str,
10) -> TokenResult<&'a str> {
11    field.ok_or_else(|| TokenError::InvalidRequest {
12        field: field_name.to_string(),
13        message: "is required".to_string(),
14    })
15}
16
17pub async fn validate_client_credentials(
18    repo: &OAuthRepository,
19    client_id: &ClientId,
20    client_secret: Option<&str>,
21) -> Result<()> {
22    validate_client_credentials_shared(repo, client_id, client_secret)
23        .await
24        .map_err(Into::into)
25}
26
27#[derive(Debug)]
28pub struct AuthCodeValidationParams<'a> {
29    pub repo: &'a OAuthRepository,
30    pub code: &'a AuthorizationCode,
31    pub client_id: &'a ClientId,
32    pub redirect_uri: Option<&'a str>,
33    pub code_verifier: Option<&'a str>,
34    pub request_resource: Option<&'a str>,
35}
36
37pub async fn validate_authorization_code(
38    params: AuthCodeValidationParams<'_>,
39) -> Result<AuthCodeValidationResult> {
40    let result = params
41        .repo
42        .validate_authorization_code(
43            params.code,
44            params.client_id,
45            params.redirect_uri,
46            params.code_verifier,
47        )
48        .await?;
49
50    if let Some(req_resource) = params.request_resource {
51        if let Some(ref stored_resource) = result.resource {
52            if req_resource != stored_resource {
53                return Err(anyhow::anyhow!(
54                    "Resource parameter mismatch: expected '{}', got '{}'",
55                    stored_resource,
56                    req_resource
57                ));
58            }
59        }
60    }
61
62    Ok(result)
63}