tuitbot_core/source/google_drive/
mod.rs1mod jwt;
14
15use std::path::Path;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
19use async_trait::async_trait;
20
21use super::{ContentSourceProvider, SourceError, SourceFile};
22use crate::automation::watchtower::matches_patterns;
23use crate::source::connector::google_drive::GoogleDriveConnector;
24use crate::source::connector::{ConnectorError, RemoteConnector};
25use crate::storage::DbPool;
26
27pub enum DriveAuthStrategy {
33 ServiceAccount { key_path: String },
35 LinkedAccount {
37 connection_id: i64,
38 pool: DbPool,
39 connector_key: Vec<u8>,
40 connector: GoogleDriveConnector,
41 },
42}
43
44pub struct GoogleDriveProvider {
54 folder_id: String,
55 auth_strategy: DriveAuthStrategy,
56 http_client: reqwest::Client,
57 token_cache: Mutex<Option<CachedToken>>,
58}
59
/// An access token paired with the monotonic instant after which it is no longer valid.
struct CachedToken {
    /// Token sent as the bearer credential on Drive requests.
    access_token: String,
    /// Expiry derived from the server-reported lifetime at the time of fetching.
    expires_at: Instant,
}
64
65impl GoogleDriveProvider {
66 pub fn new(folder_id: String, service_account_key_path: String) -> Self {
68 Self {
69 folder_id,
70 auth_strategy: DriveAuthStrategy::ServiceAccount {
71 key_path: service_account_key_path,
72 },
73 http_client: reqwest::Client::new(),
74 token_cache: Mutex::new(None),
75 }
76 }
77
78 pub fn from_connection(
80 folder_id: String,
81 connection_id: i64,
82 pool: DbPool,
83 connector_key: Vec<u8>,
84 connector: GoogleDriveConnector,
85 ) -> Self {
86 Self {
87 folder_id,
88 auth_strategy: DriveAuthStrategy::LinkedAccount {
89 connection_id,
90 pool,
91 connector_key,
92 connector,
93 },
94 http_client: reqwest::Client::new(),
95 token_cache: Mutex::new(None),
96 }
97 }
98
99 #[cfg(test)]
101 pub fn with_client(
102 folder_id: String,
103 service_account_key_path: String,
104 client: reqwest::Client,
105 ) -> Self {
106 Self {
107 folder_id,
108 auth_strategy: DriveAuthStrategy::ServiceAccount {
109 key_path: service_account_key_path,
110 },
111 http_client: client,
112 token_cache: Mutex::new(None),
113 }
114 }
115
116 #[cfg(test)]
118 pub fn with_client_and_connection(
119 folder_id: String,
120 connection_id: i64,
121 pool: DbPool,
122 connector_key: Vec<u8>,
123 connector: GoogleDriveConnector,
124 client: reqwest::Client,
125 ) -> Self {
126 Self {
127 folder_id,
128 auth_strategy: DriveAuthStrategy::LinkedAccount {
129 connection_id,
130 pool,
131 connector_key,
132 connector,
133 },
134 http_client: client,
135 token_cache: Mutex::new(None),
136 }
137 }
138
139 async fn get_access_token(&self) -> Result<String, SourceError> {
141 if let Ok(cache) = self.token_cache.lock() {
143 if let Some(ref tok) = *cache {
144 if tok.expires_at > Instant::now() + Duration::from_secs(60) {
145 return Ok(tok.access_token.clone());
146 }
147 }
148 }
149
150 let token = match &self.auth_strategy {
151 DriveAuthStrategy::ServiceAccount { key_path } => {
152 self.fetch_service_account_token(key_path).await?
153 }
154 DriveAuthStrategy::LinkedAccount {
155 connection_id,
156 pool,
157 connector_key,
158 connector,
159 } => {
160 self.refresh_from_connection(*connection_id, pool, connector_key, connector)
161 .await?
162 }
163 };
164
165 let access_token = token.access_token.clone();
166
167 if let Ok(mut cache) = self.token_cache.lock() {
168 *cache = Some(token);
169 }
170
171 Ok(access_token)
172 }
173
174 async fn fetch_service_account_token(
177 &self,
178 key_path: &str,
179 ) -> Result<CachedToken, SourceError> {
180 let key_bytes = tokio::fs::read_to_string(key_path).await.map_err(|e| {
181 SourceError::Auth(format!("cannot read service account key {key_path}: {e}"))
182 })?;
183
184 let key_json: serde_json::Value = serde_json::from_str(&key_bytes)
185 .map_err(|e| SourceError::Auth(format!("invalid service account JSON: {e}")))?;
186
187 let client_email = key_json["client_email"]
188 .as_str()
189 .ok_or_else(|| SourceError::Auth("missing client_email in key".into()))?;
190
191 let private_key_pem = key_json["private_key"]
192 .as_str()
193 .ok_or_else(|| SourceError::Auth("missing private_key in key".into()))?;
194
195 let token_uri = key_json["token_uri"]
196 .as_str()
197 .unwrap_or("https://oauth2.googleapis.com/token");
198
199 let now = chrono::Utc::now().timestamp();
200 let claims = serde_json::json!({
201 "iss": client_email,
202 "scope": "https://www.googleapis.com/auth/drive.readonly",
203 "aud": token_uri,
204 "iat": now,
205 "exp": now + 3600,
206 });
207
208 let jwt_token = jwt::build_jwt(&claims, private_key_pem)?;
209
210 let resp = self
211 .http_client
212 .post(token_uri)
213 .form(&[
214 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
215 ("assertion", &jwt_token),
216 ])
217 .send()
218 .await
219 .map_err(|e| SourceError::Auth(format!("token exchange failed: {e}")))?;
220
221 if !resp.status().is_success() {
222 let body = resp.text().await.unwrap_or_default();
223 return Err(SourceError::Auth(format!(
224 "token endpoint returned error: {body}"
225 )));
226 }
227
228 let body: serde_json::Value = resp
229 .json()
230 .await
231 .map_err(|e| SourceError::Auth(format!("invalid token response: {e}")))?;
232
233 let access_token = body["access_token"]
234 .as_str()
235 .ok_or_else(|| SourceError::Auth("no access_token in response".into()))?
236 .to_string();
237
238 let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
239
240 Ok(CachedToken {
241 access_token,
242 expires_at: Instant::now() + Duration::from_secs(expires_in),
243 })
244 }
245
246 async fn refresh_from_connection(
255 &self,
256 connection_id: i64,
257 pool: &DbPool,
258 connector_key: &[u8],
259 connector: &GoogleDriveConnector,
260 ) -> Result<CachedToken, SourceError> {
261 let encrypted = crate::storage::watchtower::read_encrypted_credentials(pool, connection_id)
263 .await
264 .map_err(|e| SourceError::ConnectionBroken {
265 connection_id,
266 reason: format!("failed to read credentials: {e}"),
267 })?;
268
269 let encrypted = match encrypted {
270 Some(enc) => enc,
271 None => {
272 return Err(SourceError::ConnectionBroken {
273 connection_id,
274 reason: "no credentials found for connection".into(),
275 });
276 }
277 };
278
279 match connector
281 .refresh_access_token(&encrypted, connector_key)
282 .await
283 {
284 Ok(refreshed) => {
285 let expires_in = refreshed.expires_in_secs.max(0) as u64;
286 Ok(CachedToken {
287 access_token: refreshed.access_token,
288 expires_at: Instant::now() + Duration::from_secs(expires_in),
289 })
290 }
291 Err(ConnectorError::TokenRefresh(msg)) if is_revocation_error(&msg) => {
292 Err(SourceError::ConnectionBroken {
293 connection_id,
294 reason: format!("token revoked: {msg}"),
295 })
296 }
297 Err(ConnectorError::Encryption(msg)) => Err(SourceError::ConnectionBroken {
298 connection_id,
299 reason: format!("credential decryption failed: {msg}"),
300 }),
301 Err(e) => Err(SourceError::Auth(format!(
302 "token refresh failed for connection {connection_id}: {e}"
303 ))),
304 }
305 }
306}
307
/// Report whether a token-refresh error message indicates the grant is no longer usable.
///
/// Performs a case-insensitive substring search for any of the known markers: `revoked`,
/// `invalid_grant`, or `token has been expired or revoked`.
fn is_revocation_error(msg: &str) -> bool {
    const MARKERS: [&str; 3] = [
        "revoked",
        "invalid_grant",
        "token has been expired or revoked",
    ];

    let haystack = msg.to_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
315
316#[async_trait]
317impl ContentSourceProvider for GoogleDriveProvider {
318 fn source_type(&self) -> &str {
319 "google_drive"
320 }
321
322 async fn scan_for_changes(
323 &self,
324 since_cursor: Option<&str>,
325 patterns: &[String],
326 ) -> Result<Vec<SourceFile>, SourceError> {
327 let token = self.get_access_token().await?;
328
329 let mut q = format!("'{}' in parents and trashed = false", self.folder_id);
330
331 if let Some(cursor) = since_cursor {
332 q.push_str(&format!(" and modifiedTime > '{cursor}'"));
333 }
334
335 let resp = self
336 .http_client
337 .get("https://www.googleapis.com/drive/v3/files")
338 .bearer_auth(&token)
339 .query(&[
340 ("q", q.as_str()),
341 ("fields", "files(id,name,md5Checksum,modifiedTime,mimeType)"),
342 ("pageSize", "1000"),
343 ])
344 .send()
345 .await
346 .map_err(|e| SourceError::Network(format!("Drive list failed: {e}")))?;
347
348 if !resp.status().is_success() {
349 let body = resp.text().await.unwrap_or_default();
350 return Err(SourceError::Network(format!("Drive API error: {body}")));
351 }
352
353 let body: serde_json::Value = resp
354 .json()
355 .await
356 .map_err(|e| SourceError::Network(format!("invalid Drive response: {e}")))?;
357
358 let files = body["files"].as_array().cloned().unwrap_or_default();
359
360 let mut result = Vec::new();
361 for file in &files {
362 let id = match file["id"].as_str() {
363 Some(id) => id,
364 None => continue,
365 };
366 let name = file["name"].as_str().unwrap_or("unknown");
367
368 if !patterns.is_empty() && !matches_patterns(Path::new(name), patterns) {
369 continue;
370 }
371
372 let hash = file["md5Checksum"].as_str().unwrap_or("").to_string();
373 let modified = file["modifiedTime"].as_str().unwrap_or("").to_string();
374
375 result.push(SourceFile {
376 provider_id: format!("gdrive://{id}/{name}"),
377 display_name: name.to_string(),
378 content_hash: hash,
379 modified_at: modified,
380 });
381 }
382
383 Ok(result)
384 }
385
386 async fn read_content(&self, file_id: &str) -> Result<String, SourceError> {
387 let drive_id = extract_drive_id(file_id)?;
388 let token = self.get_access_token().await?;
389
390 let url = format!("https://www.googleapis.com/drive/v3/files/{drive_id}?alt=media");
391
392 let resp = self
393 .http_client
394 .get(&url)
395 .bearer_auth(&token)
396 .send()
397 .await
398 .map_err(|e| SourceError::Network(format!("Drive get failed: {e}")))?;
399
400 if resp.status() == reqwest::StatusCode::NOT_FOUND {
401 return Err(SourceError::NotFound(format!("file {drive_id} not found")));
402 }
403
404 if !resp.status().is_success() {
405 let body = resp.text().await.unwrap_or_default();
406 return Err(SourceError::Network(format!(
407 "Drive download error: {body}"
408 )));
409 }
410
411 resp.text()
412 .await
413 .map_err(|e| SourceError::Network(format!("read body failed: {e}")))
414 }
415}
416
#[cfg(test)]
impl GoogleDriveProvider {
    /// Test hook exposing the private `extract_drive_id` parser.
    ///
    /// # Panics
    ///
    /// Panics if `extract_drive_id` returns an error.
    pub fn extract_drive_id_for_test(provider_id: &str) -> String {
        let parsed = extract_drive_id(provider_id);
        parsed.unwrap()
    }
}
428
429fn extract_drive_id(provider_id: &str) -> Result<String, SourceError> {
432 if let Some(rest) = provider_id.strip_prefix("gdrive://") {
433 if let Some(slash) = rest.find('/') {
434 Ok(rest[..slash].to_string())
435 } else {
436 Ok(rest.to_string())
437 }
438 } else {
439 Ok(provider_id.to_string())
440 }
441}