tuitbot_core/source/google_drive/
mod.rs1mod jwt;
14
15use std::path::Path;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
19use async_trait::async_trait;
20
21use super::{ContentSourceProvider, SourceError, SourceFile};
22use crate::automation::watchtower::matches_patterns;
23use crate::source::connector::google_drive::GoogleDriveConnector;
24use crate::source::connector::{ConnectorError, RemoteConnector};
25use crate::storage::DbPool;
26
27pub enum DriveAuthStrategy {
33 ServiceAccount { key_path: String },
35 LinkedAccount {
37 connection_id: i64,
38 pool: DbPool,
39 connector_key: Vec<u8>,
40 connector: GoogleDriveConnector,
41 },
42}
43
44pub struct GoogleDriveProvider {
54 folder_id: String,
55 auth_strategy: DriveAuthStrategy,
56 http_client: reqwest::Client,
57 token_cache: Mutex<Option<CachedToken>>,
58}
59
/// An access token together with the moment it stops being valid.
struct CachedToken {
    /// Bearer token sent to the Drive API.
    access_token: String,
    /// Monotonic-clock instant at which the token expires.
    expires_at: Instant,
}
64
65impl GoogleDriveProvider {
66 pub fn new(folder_id: String, service_account_key_path: String) -> Self {
68 Self {
69 folder_id,
70 auth_strategy: DriveAuthStrategy::ServiceAccount {
71 key_path: service_account_key_path,
72 },
73 http_client: reqwest::Client::new(),
74 token_cache: Mutex::new(None),
75 }
76 }
77
78 pub fn from_connection(
80 folder_id: String,
81 connection_id: i64,
82 pool: DbPool,
83 connector_key: Vec<u8>,
84 connector: GoogleDriveConnector,
85 ) -> Self {
86 Self {
87 folder_id,
88 auth_strategy: DriveAuthStrategy::LinkedAccount {
89 connection_id,
90 pool,
91 connector_key,
92 connector,
93 },
94 http_client: reqwest::Client::new(),
95 token_cache: Mutex::new(None),
96 }
97 }
98
99 #[cfg(test)]
101 pub fn with_client(
102 folder_id: String,
103 service_account_key_path: String,
104 client: reqwest::Client,
105 ) -> Self {
106 Self {
107 folder_id,
108 auth_strategy: DriveAuthStrategy::ServiceAccount {
109 key_path: service_account_key_path,
110 },
111 http_client: client,
112 token_cache: Mutex::new(None),
113 }
114 }
115
116 #[cfg(test)]
118 pub fn with_client_and_connection(
119 folder_id: String,
120 connection_id: i64,
121 pool: DbPool,
122 connector_key: Vec<u8>,
123 connector: GoogleDriveConnector,
124 client: reqwest::Client,
125 ) -> Self {
126 Self {
127 folder_id,
128 auth_strategy: DriveAuthStrategy::LinkedAccount {
129 connection_id,
130 pool,
131 connector_key,
132 connector,
133 },
134 http_client: client,
135 token_cache: Mutex::new(None),
136 }
137 }
138
139 async fn get_access_token(&self) -> Result<String, SourceError> {
141 if let Ok(cache) = self.token_cache.lock() {
143 if let Some(ref tok) = *cache {
144 if tok.expires_at > Instant::now() + Duration::from_secs(60) {
145 return Ok(tok.access_token.clone());
146 }
147 }
148 }
149
150 let token = match &self.auth_strategy {
151 DriveAuthStrategy::ServiceAccount { key_path } => {
152 self.fetch_service_account_token(key_path).await?
153 }
154 DriveAuthStrategy::LinkedAccount {
155 connection_id,
156 pool,
157 connector_key,
158 connector,
159 } => {
160 self.refresh_from_connection(*connection_id, pool, connector_key, connector)
161 .await?
162 }
163 };
164
165 let access_token = token.access_token.clone();
166
167 if let Ok(mut cache) = self.token_cache.lock() {
168 *cache = Some(token);
169 }
170
171 Ok(access_token)
172 }
173
174 async fn fetch_service_account_token(
177 &self,
178 key_path: &str,
179 ) -> Result<CachedToken, SourceError> {
180 let key_bytes = tokio::fs::read_to_string(key_path).await.map_err(|e| {
181 SourceError::Auth(format!("cannot read service account key {key_path}: {e}"))
182 })?;
183
184 let key_json: serde_json::Value = serde_json::from_str(&key_bytes)
185 .map_err(|e| SourceError::Auth(format!("invalid service account JSON: {e}")))?;
186
187 let client_email = key_json["client_email"]
188 .as_str()
189 .ok_or_else(|| SourceError::Auth("missing client_email in key".into()))?;
190
191 let private_key_pem = key_json["private_key"]
192 .as_str()
193 .ok_or_else(|| SourceError::Auth("missing private_key in key".into()))?;
194
195 let token_uri = key_json["token_uri"]
196 .as_str()
197 .unwrap_or("https://oauth2.googleapis.com/token");
198
199 let now = chrono::Utc::now().timestamp();
200 let claims = serde_json::json!({
201 "iss": client_email,
202 "scope": "https://www.googleapis.com/auth/drive.readonly",
203 "aud": token_uri,
204 "iat": now,
205 "exp": now + 3600,
206 });
207
208 let jwt_token = jwt::build_jwt(&claims, private_key_pem)?;
209
210 let resp = self
211 .http_client
212 .post(token_uri)
213 .form(&[
214 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
215 ("assertion", &jwt_token),
216 ])
217 .send()
218 .await
219 .map_err(|e| SourceError::Auth(format!("token exchange failed: {e}")))?;
220
221 if !resp.status().is_success() {
222 let body = resp.text().await.unwrap_or_default();
223 return Err(SourceError::Auth(format!(
224 "token endpoint returned error: {body}"
225 )));
226 }
227
228 let body: serde_json::Value = resp
229 .json()
230 .await
231 .map_err(|e| SourceError::Auth(format!("invalid token response: {e}")))?;
232
233 let access_token = body["access_token"]
234 .as_str()
235 .ok_or_else(|| SourceError::Auth("no access_token in response".into()))?
236 .to_string();
237
238 let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
239
240 Ok(CachedToken {
241 access_token,
242 expires_at: Instant::now() + Duration::from_secs(expires_in),
243 })
244 }
245
246 async fn refresh_from_connection(
255 &self,
256 connection_id: i64,
257 pool: &DbPool,
258 connector_key: &[u8],
259 connector: &GoogleDriveConnector,
260 ) -> Result<CachedToken, SourceError> {
261 let encrypted = crate::storage::watchtower::read_encrypted_credentials(pool, connection_id)
263 .await
264 .map_err(|e| SourceError::ConnectionBroken {
265 connection_id,
266 reason: format!("failed to read credentials: {e}"),
267 })?;
268
269 let encrypted = match encrypted {
270 Some(enc) => enc,
271 None => {
272 return Err(SourceError::ConnectionBroken {
273 connection_id,
274 reason: "no credentials found for connection".into(),
275 });
276 }
277 };
278
279 match connector
281 .refresh_access_token(&encrypted, connector_key)
282 .await
283 {
284 Ok(refreshed) => {
285 let expires_in = refreshed.expires_in_secs.max(0) as u64;
286 Ok(CachedToken {
287 access_token: refreshed.access_token,
288 expires_at: Instant::now() + Duration::from_secs(expires_in),
289 })
290 }
291 Err(ConnectorError::TokenRefresh(msg)) if is_revocation_error(&msg) => {
292 Err(SourceError::ConnectionBroken {
293 connection_id,
294 reason: format!("token revoked: {msg}"),
295 })
296 }
297 Err(ConnectorError::Encryption(msg)) => Err(SourceError::ConnectionBroken {
298 connection_id,
299 reason: format!("credential decryption failed: {msg}"),
300 }),
301 Err(e) => Err(SourceError::Auth(format!(
302 "token refresh failed for connection {connection_id}: {e}"
303 ))),
304 }
305 }
306}
307
/// Reports whether a token-refresh error message indicates that the grant was
/// revoked or is otherwise permanently invalid.
///
/// Matching is case-insensitive and substring-based.
fn is_revocation_error(msg: &str) -> bool {
    // The last marker is already covered by "revoked"; it is kept to record
    // the full phrase being guarded against.
    const MARKERS: [&str; 3] = [
        "revoked",
        "invalid_grant",
        "token has been expired or revoked",
    ];

    let haystack = msg.to_lowercase();
    MARKERS.iter().any(|marker| haystack.contains(marker))
}
315
316#[async_trait]
317impl ContentSourceProvider for GoogleDriveProvider {
318 fn source_type(&self) -> &str {
319 "google_drive"
320 }
321
322 async fn scan_for_changes(
323 &self,
324 since_cursor: Option<&str>,
325 patterns: &[String],
326 ) -> Result<Vec<SourceFile>, SourceError> {
327 let token = self.get_access_token().await?;
328
329 let mut q = format!("'{}' in parents and trashed = false", self.folder_id);
330
331 if let Some(cursor) = since_cursor {
332 q.push_str(&format!(" and modifiedTime > '{cursor}'"));
333 }
334
335 let resp = self
336 .http_client
337 .get("https://www.googleapis.com/drive/v3/files")
338 .bearer_auth(&token)
339 .query(&[
340 ("q", q.as_str()),
341 ("fields", "files(id,name,md5Checksum,modifiedTime,mimeType)"),
342 ("pageSize", "1000"),
343 ])
344 .send()
345 .await
346 .map_err(|e| SourceError::Network(format!("Drive list failed: {e}")))?;
347
348 if !resp.status().is_success() {
349 let body = resp.text().await.unwrap_or_default();
350 return Err(SourceError::Network(format!("Drive API error: {body}")));
351 }
352
353 let body: serde_json::Value = resp
354 .json()
355 .await
356 .map_err(|e| SourceError::Network(format!("invalid Drive response: {e}")))?;
357
358 let files = body["files"].as_array().cloned().unwrap_or_default();
359
360 let mut result = Vec::new();
361 for file in &files {
362 let id = match file["id"].as_str() {
363 Some(id) => id,
364 None => continue,
365 };
366 let name = file["name"].as_str().unwrap_or("unknown");
367
368 if !patterns.is_empty() && !matches_patterns(Path::new(name), patterns) {
369 continue;
370 }
371
372 let hash = file["md5Checksum"].as_str().unwrap_or("").to_string();
373 let modified = file["modifiedTime"].as_str().unwrap_or("").to_string();
374
375 result.push(SourceFile {
376 provider_id: format!("gdrive://{id}/{name}"),
377 display_name: name.to_string(),
378 content_hash: hash,
379 modified_at: modified,
380 });
381 }
382
383 Ok(result)
384 }
385
386 async fn read_content(&self, file_id: &str) -> Result<String, SourceError> {
387 let drive_id = extract_drive_id(file_id)?;
388 let token = self.get_access_token().await?;
389
390 let url = format!("https://www.googleapis.com/drive/v3/files/{drive_id}?alt=media");
391
392 let resp = self
393 .http_client
394 .get(&url)
395 .bearer_auth(&token)
396 .send()
397 .await
398 .map_err(|e| SourceError::Network(format!("Drive get failed: {e}")))?;
399
400 if resp.status() == reqwest::StatusCode::NOT_FOUND {
401 return Err(SourceError::NotFound(format!("file {drive_id} not found")));
402 }
403
404 if !resp.status().is_success() {
405 let body = resp.text().await.unwrap_or_default();
406 return Err(SourceError::Network(format!(
407 "Drive download error: {body}"
408 )));
409 }
410
411 resp.text()
412 .await
413 .map_err(|e| SourceError::Network(format!("read body failed: {e}")))
414 }
415}
416
#[cfg(test)]
impl GoogleDriveProvider {
    /// Test-only wrapper around `extract_drive_id` that returns the ID
    /// directly and panics if extraction reports an error.
    pub fn extract_drive_id_for_test(provider_id: &str) -> String {
        let extracted = extract_drive_id(provider_id);
        extracted.unwrap()
    }
}
428
429fn extract_drive_id(provider_id: &str) -> Result<String, SourceError> {
432 if let Some(rest) = provider_id.strip_prefix("gdrive://") {
433 if let Some(slash) = rest.find('/') {
434 Ok(rest[..slash].to_string())
435 } else {
436 Ok(rest.to_string())
437 }
438 } else {
439 Ok(provider_id.to_string())
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a service-account provider from string slices.
    fn service_account_provider(folder_id: &str, key_path: &str) -> GoogleDriveProvider {
        GoogleDriveProvider::new(folder_id.to_string(), key_path.to_string())
    }

    /// Runs `extract_drive_id` and unwraps the result.
    fn drive_id_of(provider_id: &str) -> String {
        extract_drive_id(provider_id).unwrap()
    }

    // --- extract_drive_id ---

    #[test]
    fn extract_drive_id_with_prefix_and_name() {
        assert_eq!(drive_id_of("gdrive://abc123/doc.md"), "abc123");
    }

    #[test]
    fn extract_drive_id_with_prefix_no_name() {
        assert_eq!(drive_id_of("gdrive://abc123"), "abc123");
    }

    #[test]
    fn extract_drive_id_raw_id() {
        assert_eq!(drive_id_of("abc123"), "abc123");
    }

    #[test]
    fn extract_drive_id_with_nested_path() {
        assert_eq!(drive_id_of("gdrive://xyz/folder/file.txt"), "xyz");
    }

    #[test]
    fn extract_drive_id_empty_after_prefix() {
        assert_eq!(drive_id_of("gdrive://"), "");
    }

    // --- is_revocation_error ---

    #[test]
    fn revocation_error_detects_revoked() {
        let msg = "Token has been revoked by user";
        assert!(is_revocation_error(msg));
    }

    #[test]
    fn revocation_error_detects_invalid_grant() {
        let msg = "invalid_grant: token expired";
        assert!(is_revocation_error(msg));
    }

    #[test]
    fn revocation_error_detects_expired_or_revoked() {
        let msg = "Token has been expired or revoked";
        assert!(is_revocation_error(msg));
    }

    #[test]
    fn revocation_error_returns_false_for_network_error() {
        let msg = "connection timeout";
        assert!(!is_revocation_error(msg));
    }

    #[test]
    fn revocation_error_case_insensitive() {
        for msg in ["REVOKED", "Invalid_Grant"] {
            assert!(is_revocation_error(msg));
        }
    }

    // --- construction ---

    #[test]
    fn new_creates_service_account_provider() {
        let provider = service_account_provider("folder123", "/path/to/key.json");
        assert_eq!(provider.folder_id, "folder123");
        assert!(matches!(
            provider.auth_strategy,
            DriveAuthStrategy::ServiceAccount { .. }
        ));
    }

    #[test]
    fn source_type_is_google_drive() {
        let provider = service_account_provider("folder123", "/path/to/key.json");
        assert_eq!(provider.source_type(), "google_drive");
    }

    #[test]
    fn with_client_creates_provider_with_custom_http_client() {
        let provider = GoogleDriveProvider::with_client(
            "folder123".to_string(),
            "/path/to/key.json".to_string(),
            reqwest::Client::new(),
        );
        assert_eq!(provider.folder_id, "folder123");
    }

    #[test]
    fn token_cache_starts_empty() {
        let provider = service_account_provider("folder123", "/path/to/key.json");
        assert!(provider.token_cache.lock().unwrap().is_none());
    }

    #[test]
    fn extract_drive_id_for_test_helper() {
        let id = GoogleDriveProvider::extract_drive_id_for_test("gdrive://myid/name.md");
        assert_eq!(id, "myid");
    }

    // --- extract_drive_id: more shapes ---

    #[test]
    fn extract_drive_id_long_id() {
        let long_id = "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgVE2upms";
        assert_eq!(
            drive_id_of("gdrive://1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgVE2upms/doc.md"),
            long_id
        );
    }

    #[test]
    fn extract_drive_id_special_chars_in_name() {
        assert_eq!(drive_id_of("gdrive://abc123/my file (1).md"), "abc123");
    }

    // --- is_revocation_error: edge cases ---

    #[test]
    fn revocation_error_empty_string() {
        assert!(!is_revocation_error(""));
    }

    #[test]
    fn revocation_error_partial_match() {
        // A truncated marker must not count as a match.
        assert!(!is_revocation_error("revo"));
    }

    #[test]
    fn revocation_error_with_surrounding_text() {
        let msg = "Error: token has been expired or revoked by user";
        assert!(is_revocation_error(msg));
    }

    // --- construction: more cases ---

    #[test]
    fn new_preserves_folder_id() {
        let provider = service_account_provider("folder_abc", "/key.json");
        assert_eq!(provider.folder_id, "folder_abc");
    }

    #[test]
    fn new_provider_source_type() {
        let provider = service_account_provider("f", "/k.json");
        assert_eq!(provider.source_type(), "google_drive");
    }

    #[test]
    fn service_account_strategy_matches() {
        let provider = service_account_provider("folder", "/path/key.json");
        assert!(matches!(
            provider.auth_strategy,
            DriveAuthStrategy::ServiceAccount { .. }
        ));
    }

    // --- CachedToken ---

    #[test]
    fn cached_token_expires_at_is_in_future() {
        let one_hour = Duration::from_secs(3600);
        let token = CachedToken {
            access_token: "tok".to_string(),
            expires_at: Instant::now() + one_hour,
        };
        assert!(token.expires_at > Instant::now());
    }

    #[test]
    fn cached_token_expired() {
        let one_second = Duration::from_secs(1);
        let token = CachedToken {
            access_token: "old_tok".to_string(),
            expires_at: Instant::now() - one_second,
        };
        assert!(token.expires_at < Instant::now());
    }
}