Skip to main content

ytmapi_rs/auth/
oauth.rs

1use super::{AuthToken, RawResult, fallback_client_version};
2use crate::client::Client;
3use crate::error::{Error, Result};
4use crate::parse::ProcessedResult;
5use crate::utils::constants::{
6    OAUTH_CODE_URL, OAUTH_GRANT_URL, OAUTH_SCOPE, OAUTH_TOKEN_URL, OAUTH_USER_AGENT, USER_AGENT,
7    YTM_URL,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use std::borrow::Cow;
12use std::time::{SystemTime, UNIX_EPOCH};
13
/// Safety margin, in seconds, applied when deciding whether a token has
/// expired.
///
/// Expiry is detected on the client side, so treating the token as expired
/// this many seconds early reduces the risk of racing the server-side expiry.
const REFRESH_S_BEFORE_EXPIRING: u64 = 60;
18
19// The original reason for the two different structs was that we did not save
20// the refresh token. But now we do, so consider simply making this only one
21// struct. Otherwise the only difference is not including Scope which is not
22// super relevant.
23#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
24pub struct OAuthToken {
25    token_type: String,
26    access_token: String,
27    refresh_token: String,
28    expires_in: usize,
29    request_time: SystemTime,
30    client_id: String,
31    client_secret: String,
32}
33// TODO: Lock down construction of this type.
34#[derive(Clone, Deserialize)]
35pub struct OAuthDeviceCode(String);
36
37#[derive(Clone, Deserialize)]
38struct GoogleOAuthToken {
39    pub access_token: String,
40    /// Currently it seems Google gives out these for around 6 min.
41    pub expires_in: usize,
42    pub refresh_token: String,
43    // Unused currently - for future use
44    #[allow(dead_code)]
45    pub scope: String,
46    pub token_type: String,
47}
48#[derive(Clone, Deserialize)]
49struct GoogleOAuthRefreshToken {
50    pub access_token: String,
51    pub expires_in: usize,
52    // Unused currently - for future use
53    #[allow(dead_code)]
54    pub scope: String,
55    pub token_type: String,
56}
57#[derive(Clone, Deserialize)]
58pub struct OAuthTokenGenerator {
59    pub device_code: OAuthDeviceCode,
60    pub expires_in: usize,
61    pub interval: usize,
62    pub user_code: String,
63    pub verification_url: String,
64}
65
66impl OAuthToken {
67    fn from_google_refresh_token(
68        google_token: GoogleOAuthRefreshToken,
69        request_time: SystemTime,
70        refresh_token: String,
71        client_id: String,
72        client_secret: String,
73    ) -> Self {
74        // See comment above on OAuthToken
75        let GoogleOAuthRefreshToken {
76            access_token,
77            expires_in,
78            token_type,
79            ..
80        } = google_token;
81        Self {
82            token_type,
83            refresh_token,
84            access_token,
85            request_time,
86            expires_in,
87            client_id,
88            client_secret,
89        }
90    }
91    fn from_google_token(
92        google_token: GoogleOAuthToken,
93        request_time: SystemTime,
94        client_id: String,
95        client_secret: String,
96    ) -> Self {
97        // See comment above on OAuthToken
98        let GoogleOAuthToken {
99            access_token,
100            expires_in,
101            token_type,
102            refresh_token,
103            ..
104        } = google_token;
105        Self {
106            token_type,
107            refresh_token,
108            access_token,
109            request_time,
110            expires_in,
111            client_id,
112            client_secret,
113        }
114    }
115}
116
117impl OAuthDeviceCode {
118    pub fn new(code: String) -> Self {
119        Self(code)
120    }
121    pub fn get_code(&self) -> &str {
122        &self.0
123    }
124}
125
126impl AuthToken for OAuthToken {
127    fn deserialize_response<Q>(
128        raw: RawResult<Q, Self>,
129    ) -> Result<crate::parse::ProcessedResult<Q>> {
130        let processed = ProcessedResult::try_from(raw)?;
131        // Guard against error codes in json response.
132        // TODO: Add a test for this
133        if let Some(error) = processed.get_json().pointer("/error") {
134            let Some(code) = error.pointer("/code").and_then(|v| v.as_u64()) else {
135                return Err(Error::response("API reported an error but no code"));
136            };
137            let message = error
138                .pointer("/message")
139                .and_then(|s| s.as_str())
140                .map(|s| s.to_string())
141                .unwrap_or_default();
142            return Err(Error::other_code(code, message));
143        }
144        Ok(processed)
145    }
146    fn headers(&self) -> Result<impl IntoIterator<Item = (&str, Cow<'_, str>)>> {
147        let request_time_unix = self.request_time.duration_since(UNIX_EPOCH)?.as_secs();
148        let now_unix = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
149        // TODO: Better handling for expiration case.
150        if now_unix + REFRESH_S_BEFORE_EXPIRING > request_time_unix + self.expires_in as u64 {
151            return Err(Error::oauth_token_expired(self));
152        }
153        Ok([
154            // TODO: Confirm if parsing for expired user agent also relevant here.
155            ("User-Agent", USER_AGENT.into()),
156            ("X-Origin", YTM_URL.into()),
157            ("Content-Type", "application/json".into()),
158            (
159                "Authorization",
160                format!("{} {}", self.token_type, self.access_token).into(),
161            ),
162            ("X-Goog-Request-Time", request_time_unix.to_string().into()),
163        ])
164    }
165    fn client_version(&self) -> Cow<'_, str> {
166        let now_datetime: chrono::DateTime<chrono::Utc> = SystemTime::now().into();
167        fallback_client_version(&now_datetime).into()
168    }
169}
170
171impl OAuthToken {
172    pub async fn from_code(
173        client: &Client,
174        code: OAuthDeviceCode,
175        client_id: impl Into<String>,
176        client_secret: impl Into<String>,
177    ) -> Result<OAuthToken> {
178        let client_id = client_id.into();
179        let client_secret = client_secret.into();
180        let body = json!({
181            "client_secret" : &client_secret,
182            "grant_type" : OAUTH_GRANT_URL,
183            "code" : code.get_code(),
184            "client_id" : &client_id
185        });
186        let headers = [("User-Agent", OAUTH_USER_AGENT.into())];
187        let result = client
188            .post_json_query(OAUTH_TOKEN_URL, headers, &body, &())
189            .await?;
190        let google_token: GoogleOAuthToken =
191            serde_json::from_str(&result.text).map_err(|_| Error::response(&result.text))?;
192        Ok(OAuthToken::from_google_token(
193            google_token,
194            SystemTime::now(),
195            client_id,
196            client_secret,
197        ))
198    }
199    pub async fn refresh(&self, client: &Client) -> Result<OAuthToken> {
200        let body = json!({
201            "grant_type" : "refresh_token",
202            "refresh_token" : self.refresh_token,
203            "client_secret" : self.client_secret,
204            "client_id" : self.client_id
205        });
206        let headers = [("User-Agent", OAUTH_USER_AGENT.into())];
207        let result = client
208            .post_json_query(OAUTH_TOKEN_URL, headers, &body, &())
209            .await?;
210        let google_token: GoogleOAuthRefreshToken = serde_json::from_str(&result.text)
211            .map_err(|e| Error::unable_to_serialize_oauth(&result.text, e))?;
212        Ok(OAuthToken::from_google_refresh_token(
213            google_token,
214            SystemTime::now(),
215            // TODO: Remove clone.
216            self.refresh_token.clone(),
217            self.client_id.clone(),
218            self.client_secret.clone(),
219        ))
220    }
221}
222
223impl OAuthTokenGenerator {
224    pub async fn new(client: &Client, client_id: impl Into<String>) -> Result<OAuthTokenGenerator> {
225        let body = json!({
226            "scope" : OAUTH_SCOPE,
227            "client_id" : client_id.into()
228        });
229        let headers = [("User-Agent", OAUTH_USER_AGENT.into())];
230        let result_text = client
231            .post_json_query(OAUTH_CODE_URL, headers, &body, &())
232            .await?
233            .text;
234        serde_json::from_str(&result_text).map_err(|_| Error::response(&result_text))
235    }
236}
237// Don't use default Debug implementation for OAuthToken - contents are
238// private
239// TODO: Display some fields, such as time.
240impl std::fmt::Debug for OAuthToken {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        write!(f, "Private OAuthToken")
243    }
244}