Skip to main content

twapi_oauth2/
lib.rs

1use std::time::Duration;
2
3use base64::prelude::*;
4use query_string_builder::QueryString;
5use reqwest::{RequestBuilder, StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9pub mod error;
10pub mod x;
11
12pub use reqwest;
13
14use crate::error::OAuth2Error;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TokenResult {
18    pub access_token: String,
19    pub refresh_token: String,
20    pub expires_in: u64,
21    pub scope: String,
22    pub token_type: String,
23}
24
/// Value of the OAuth 2.0 `response_type` authorization parameter.
pub(crate) enum ResponseType {
    /// Authorization-code flow (`response_type=code`).
    Code,
    /// Implicit flow (`response_type=token`); nothing constructs it yet, hence the allowance.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the lowercase wire value placed in the authorization query string.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire_value = match self {
            Self::Code => "code",
            Self::Token => "token",
        };
        f.write_str(wire_value)
    }
}
39
/// PKCE `code_challenge_method` parameter value.
pub(crate) enum CodeChallengeMethod {
    /// Challenge is the SHA-256 digest of the verifier (`S256`).
    S256,
    /// Challenge equals the verifier (`plain`); nothing constructs it outside tests.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the exact, case-sensitive wire value of the method.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire_value = match self {
            Self::S256 => "S256",
            Self::Plain => "plain",
        };
        f.write_str(wire_value)
    }
}
54
55pub(crate) struct PkceS256 {
56    pub code_challenge: String,
57    pub code_verifier: String,
58}
59
60impl PkceS256 {
61    pub fn new() -> Self {
62        let size = 32;
63        let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
64        let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
65        let code_challenge = {
66            let hash = sha2::Sha256::digest(code_verifier.as_bytes());
67            BASE64_URL_SAFE_NO_PAD.encode(hash)
68        };
69        Self {
70            code_challenge,
71            code_verifier,
72        }
73    }
74}
75
76impl Default for PkceS256 {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82#[allow(clippy::too_many_arguments)]
83pub(crate) fn authorize_url(
84    url: &str,
85    response_type: ResponseType,
86    client_id: &str,
87    redirect_uri: &str,
88    scopes: &str,
89    state: &str,
90    code_challenge: &str,
91    code_challenge_method: CodeChallengeMethod,
92) -> String {
93    let qs = QueryString::dynamic()
94        .with_value("response_type", response_type.to_string())
95        .with_value("client_id", client_id)
96        .with_value("redirect_uri", redirect_uri)
97        .with_value("scope", scopes)
98        .with_value("state", state)
99        .with_value("code_challenge", code_challenge)
100        .with_value("code_challenge_method", code_challenge_method.to_string());
101    format!("{}{}", url, qs)
102}
103
104#[allow(clippy::too_many_arguments)]
105pub(crate) async fn token(
106    url: &str,
107    client_id: &str,
108    client_secret: &str,
109    redirect_uri: &str,
110    code: &str,
111    code_verifier: &str,
112    grant_type: &str,
113    timeout: Duration,
114    try_count: usize,
115    retry_millis: u64,
116) -> Result<(TokenResult, StatusCode, HeaderMap), OAuth2Error> {
117    let params = [
118        ("grant_type", grant_type),
119        ("code", code),
120        ("redirect_uri", redirect_uri),
121        ("client_id", client_id),
122        ("code_verifier", code_verifier),
123    ];
124
125    let client = reqwest::Client::new();
126
127    execute_retry(
128        || {
129            client
130                .post(url)
131                .form(&params)
132                .basic_auth(client_id, Some(client_secret))
133                .timeout(timeout)
134        },
135        try_count,
136        retry_millis,
137    )
138    .await
139}
140
141pub(crate) async fn execute_retry<T>(
142    f: impl Fn() -> RequestBuilder,
143    try_count: usize,
144    retry_millis: u64,
145) -> Result<(T, StatusCode, HeaderMap), OAuth2Error>
146where
147    T: serde::de::DeserializeOwned,
148{
149    for i in 0..try_count {
150        let req = f();
151        let res = req.send().await?;
152        let status = res.status();
153        let headers = res.headers().clone();
154        if status.is_success() {
155            let json: T = res.json().await?;
156            return Ok((json, status, headers));
157        } else if status.is_client_error() {
158            let body = res.text().await.unwrap_or_default();
159            return Err(OAuth2Error::ClientError(body, status, headers));
160        }
161        if i + 1 < try_count {
162            // ジッターとエクスポーネンシャルバックオフを組み合わせる
163            let jitter: u64 = rand::random::<u64>() % retry_millis;
164            let exp_backoff = 2u64.pow(i as u32) * retry_millis;
165            let retry_duration = Duration::from_millis(exp_backoff + jitter);
166            tokio::time::sleep(retry_duration).await;
167        } else {
168            let body = res.text().await.unwrap_or_default();
169            return Err(OAuth2Error::RetryOver(body, status, headers));
170        }
171    }
172    unreachable!()
173}
174
#[cfg(test)]
mod tests {
    use crate::x::{X_AUTHORIZE_URL, XScope};

    use super::*;

    /// Builds an X authorization URL from environment-supplied credentials and
    /// checks that the fixed parameters appear in the query string.
    ///
    /// Requires both environment variables, e.g.
    /// `CLIENT_ID=xxx REDIRECT_URL=http://localhost/callback cargo test -- --nocapture`
    #[test]
    fn test_authorize() {
        let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID must be set");
        let redirect_url = std::env::var("REDIRECT_URL").expect("REDIRECT_URL must be set");
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);

        assert!(res.starts_with(X_AUTHORIZE_URL));
        assert!(res.contains("response_type=code"));
        assert!(res.contains("state=test_state"));
        assert!(res.contains("code_challenge=test_code_challenge"));
        assert!(res.contains("code_challenge_method=plain"));
    }
}