Skip to main content

tuitbot_core/source/google_drive/
mod.rs

1//! Google Drive content source provider.
2//!
3//! Polls a Google Drive folder for `.md` and `.txt` files using the
4//! Drive API v3. Supports two authentication strategies:
5//!
6//! - **ServiceAccount** (legacy): reads a service-account JSON key file,
7//!   builds a JWT, and exchanges it for an access token.
8//! - **LinkedAccount** (new): uses encrypted OAuth refresh tokens stored
9//!   in the `connections` table, refreshing via the connector module.
10//!
11//! `connection_id` takes precedence when both are configured.
12
13mod jwt;
14
15use std::path::Path;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
19use async_trait::async_trait;
20
21use super::{ContentSourceProvider, SourceError, SourceFile};
22use crate::automation::watchtower::matches_patterns;
23use crate::source::connector::google_drive::GoogleDriveConnector;
24use crate::source::connector::{ConnectorError, RemoteConnector};
25use crate::storage::DbPool;
26
27// ---------------------------------------------------------------------------
28// Auth strategy
29// ---------------------------------------------------------------------------
30
31/// How the provider obtains access tokens for Drive API calls.
32pub enum DriveAuthStrategy {
33    /// Legacy: service-account JSON key file.
34    ServiceAccount { key_path: String },
35    /// New: linked-account credentials from connections table.
36    LinkedAccount {
37        connection_id: i64,
38        pool: DbPool,
39        connector_key: Vec<u8>,
40        connector: GoogleDriveConnector,
41    },
42}
43
44// ---------------------------------------------------------------------------
45// Provider
46// ---------------------------------------------------------------------------
47
48/// Google Drive content source provider.
49///
50/// Instantiated when a `google_drive` source is configured with a
51/// valid `folder_id` and either a `service_account_key` path or a
52/// `connection_id` referencing a linked account.
53pub struct GoogleDriveProvider {
54    folder_id: String,
55    auth_strategy: DriveAuthStrategy,
56    http_client: reqwest::Client,
57    token_cache: Mutex<Option<CachedToken>>,
58}
59
/// An access token paired with the monotonic instant at which it expires.
struct CachedToken {
    /// Bearer token sent with Drive API requests.
    access_token: String,
    /// Expiry time, derived from the token response's lifetime in seconds.
    expires_at: Instant,
}
64
65impl GoogleDriveProvider {
66    /// Create a provider using a service-account key file (legacy path).
67    pub fn new(folder_id: String, service_account_key_path: String) -> Self {
68        Self {
69            folder_id,
70            auth_strategy: DriveAuthStrategy::ServiceAccount {
71                key_path: service_account_key_path,
72            },
73            http_client: reqwest::Client::new(),
74            token_cache: Mutex::new(None),
75        }
76    }
77
78    /// Create a provider using linked-account OAuth credentials.
79    pub fn from_connection(
80        folder_id: String,
81        connection_id: i64,
82        pool: DbPool,
83        connector_key: Vec<u8>,
84        connector: GoogleDriveConnector,
85    ) -> Self {
86        Self {
87            folder_id,
88            auth_strategy: DriveAuthStrategy::LinkedAccount {
89                connection_id,
90                pool,
91                connector_key,
92                connector,
93            },
94            http_client: reqwest::Client::new(),
95            token_cache: Mutex::new(None),
96        }
97    }
98
99    /// Build with an explicit HTTP client (for testing with wiremock).
100    #[cfg(test)]
101    pub fn with_client(
102        folder_id: String,
103        service_account_key_path: String,
104        client: reqwest::Client,
105    ) -> Self {
106        Self {
107            folder_id,
108            auth_strategy: DriveAuthStrategy::ServiceAccount {
109                key_path: service_account_key_path,
110            },
111            http_client: client,
112            token_cache: Mutex::new(None),
113        }
114    }
115
116    /// Build with an explicit HTTP client and linked-account strategy (for testing).
117    #[cfg(test)]
118    pub fn with_client_and_connection(
119        folder_id: String,
120        connection_id: i64,
121        pool: DbPool,
122        connector_key: Vec<u8>,
123        connector: GoogleDriveConnector,
124        client: reqwest::Client,
125    ) -> Self {
126        Self {
127            folder_id,
128            auth_strategy: DriveAuthStrategy::LinkedAccount {
129                connection_id,
130                pool,
131                connector_key,
132                connector,
133            },
134            http_client: client,
135            token_cache: Mutex::new(None),
136        }
137    }
138
139    /// Obtain a valid access token, refreshing if expired.
140    async fn get_access_token(&self) -> Result<String, SourceError> {
141        // Check cache.
142        if let Ok(cache) = self.token_cache.lock() {
143            if let Some(ref tok) = *cache {
144                if tok.expires_at > Instant::now() + Duration::from_secs(60) {
145                    return Ok(tok.access_token.clone());
146                }
147            }
148        }
149
150        let token = match &self.auth_strategy {
151            DriveAuthStrategy::ServiceAccount { key_path } => {
152                self.fetch_service_account_token(key_path).await?
153            }
154            DriveAuthStrategy::LinkedAccount {
155                connection_id,
156                pool,
157                connector_key,
158                connector,
159            } => {
160                self.refresh_from_connection(*connection_id, pool, connector_key, connector)
161                    .await?
162            }
163        };
164
165        let access_token = token.access_token.clone();
166
167        if let Ok(mut cache) = self.token_cache.lock() {
168            *cache = Some(token);
169        }
170
171        Ok(access_token)
172    }
173
174    /// Read the service-account key, build a JWT, and exchange for an
175    /// access token via Google's token endpoint.
176    async fn fetch_service_account_token(
177        &self,
178        key_path: &str,
179    ) -> Result<CachedToken, SourceError> {
180        let key_bytes = tokio::fs::read_to_string(key_path).await.map_err(|e| {
181            SourceError::Auth(format!("cannot read service account key {key_path}: {e}"))
182        })?;
183
184        let key_json: serde_json::Value = serde_json::from_str(&key_bytes)
185            .map_err(|e| SourceError::Auth(format!("invalid service account JSON: {e}")))?;
186
187        let client_email = key_json["client_email"]
188            .as_str()
189            .ok_or_else(|| SourceError::Auth("missing client_email in key".into()))?;
190
191        let private_key_pem = key_json["private_key"]
192            .as_str()
193            .ok_or_else(|| SourceError::Auth("missing private_key in key".into()))?;
194
195        let token_uri = key_json["token_uri"]
196            .as_str()
197            .unwrap_or("https://oauth2.googleapis.com/token");
198
199        let now = chrono::Utc::now().timestamp();
200        let claims = serde_json::json!({
201            "iss": client_email,
202            "scope": "https://www.googleapis.com/auth/drive.readonly",
203            "aud": token_uri,
204            "iat": now,
205            "exp": now + 3600,
206        });
207
208        let jwt_token = jwt::build_jwt(&claims, private_key_pem)?;
209
210        let resp = self
211            .http_client
212            .post(token_uri)
213            .form(&[
214                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
215                ("assertion", &jwt_token),
216            ])
217            .send()
218            .await
219            .map_err(|e| SourceError::Auth(format!("token exchange failed: {e}")))?;
220
221        if !resp.status().is_success() {
222            let body = resp.text().await.unwrap_or_default();
223            return Err(SourceError::Auth(format!(
224                "token endpoint returned error: {body}"
225            )));
226        }
227
228        let body: serde_json::Value = resp
229            .json()
230            .await
231            .map_err(|e| SourceError::Auth(format!("invalid token response: {e}")))?;
232
233        let access_token = body["access_token"]
234            .as_str()
235            .ok_or_else(|| SourceError::Auth("no access_token in response".into()))?
236            .to_string();
237
238        let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
239
240        Ok(CachedToken {
241            access_token,
242            expires_at: Instant::now() + Duration::from_secs(expires_in),
243        })
244    }
245
246    /// Refresh an access token from linked-account credentials.
247    ///
248    /// 1. Check in-memory cache -- return if valid (>60s remaining).
249    /// 2. Read encrypted credentials from DB -- if None, return ConnectionBroken.
250    /// 3. Call connector's refresh_access_token.
251    /// 4. On success: cache the new token, return it.
252    /// 5. On revocation/irrecoverable error: return ConnectionBroken.
253    /// 6. On transient error: return Auth error.
254    async fn refresh_from_connection(
255        &self,
256        connection_id: i64,
257        pool: &DbPool,
258        connector_key: &[u8],
259        connector: &GoogleDriveConnector,
260    ) -> Result<CachedToken, SourceError> {
261        // Read encrypted credentials from DB.
262        let encrypted = crate::storage::watchtower::read_encrypted_credentials(pool, connection_id)
263            .await
264            .map_err(|e| SourceError::ConnectionBroken {
265                connection_id,
266                reason: format!("failed to read credentials: {e}"),
267            })?;
268
269        let encrypted = match encrypted {
270            Some(enc) => enc,
271            None => {
272                return Err(SourceError::ConnectionBroken {
273                    connection_id,
274                    reason: "no credentials found for connection".into(),
275                });
276            }
277        };
278
279        // Call connector refresh.
280        match connector
281            .refresh_access_token(&encrypted, connector_key)
282            .await
283        {
284            Ok(refreshed) => {
285                let expires_in = refreshed.expires_in_secs.max(0) as u64;
286                Ok(CachedToken {
287                    access_token: refreshed.access_token,
288                    expires_at: Instant::now() + Duration::from_secs(expires_in),
289                })
290            }
291            Err(ConnectorError::TokenRefresh(msg)) if is_revocation_error(&msg) => {
292                Err(SourceError::ConnectionBroken {
293                    connection_id,
294                    reason: format!("token revoked: {msg}"),
295                })
296            }
297            Err(ConnectorError::Encryption(msg)) => Err(SourceError::ConnectionBroken {
298                connection_id,
299                reason: format!("credential decryption failed: {msg}"),
300            }),
301            Err(e) => Err(SourceError::Auth(format!(
302                "token refresh failed for connection {connection_id}: {e}"
303            ))),
304        }
305    }
306}
307
/// Return `true` when a token-refresh failure message means the grant itself
/// is unusable (revoked, or rejected as `invalid_grant`), as opposed to a
/// transient failure. Matching is case-insensitive.
fn is_revocation_error(msg: &str) -> bool {
    // Google's "Token has been expired or revoked" wording contains
    // "revoked", so the first marker already covers it.
    const MARKERS: [&str; 2] = ["revoked", "invalid_grant"];
    let haystack = msg.to_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
315
316#[async_trait]
317impl ContentSourceProvider for GoogleDriveProvider {
318    fn source_type(&self) -> &str {
319        "google_drive"
320    }
321
322    async fn scan_for_changes(
323        &self,
324        since_cursor: Option<&str>,
325        patterns: &[String],
326    ) -> Result<Vec<SourceFile>, SourceError> {
327        let token = self.get_access_token().await?;
328
329        let mut q = format!("'{}' in parents and trashed = false", self.folder_id);
330
331        if let Some(cursor) = since_cursor {
332            q.push_str(&format!(" and modifiedTime > '{cursor}'"));
333        }
334
335        let resp = self
336            .http_client
337            .get("https://www.googleapis.com/drive/v3/files")
338            .bearer_auth(&token)
339            .query(&[
340                ("q", q.as_str()),
341                ("fields", "files(id,name,md5Checksum,modifiedTime,mimeType)"),
342                ("pageSize", "1000"),
343            ])
344            .send()
345            .await
346            .map_err(|e| SourceError::Network(format!("Drive list failed: {e}")))?;
347
348        if !resp.status().is_success() {
349            let body = resp.text().await.unwrap_or_default();
350            return Err(SourceError::Network(format!("Drive API error: {body}")));
351        }
352
353        let body: serde_json::Value = resp
354            .json()
355            .await
356            .map_err(|e| SourceError::Network(format!("invalid Drive response: {e}")))?;
357
358        let files = body["files"].as_array().cloned().unwrap_or_default();
359
360        let mut result = Vec::new();
361        for file in &files {
362            let id = match file["id"].as_str() {
363                Some(id) => id,
364                None => continue,
365            };
366            let name = file["name"].as_str().unwrap_or("unknown");
367
368            if !patterns.is_empty() && !matches_patterns(Path::new(name), patterns) {
369                continue;
370            }
371
372            let hash = file["md5Checksum"].as_str().unwrap_or("").to_string();
373            let modified = file["modifiedTime"].as_str().unwrap_or("").to_string();
374
375            result.push(SourceFile {
376                provider_id: format!("gdrive://{id}/{name}"),
377                display_name: name.to_string(),
378                content_hash: hash,
379                modified_at: modified,
380            });
381        }
382
383        Ok(result)
384    }
385
386    async fn read_content(&self, file_id: &str) -> Result<String, SourceError> {
387        let drive_id = extract_drive_id(file_id)?;
388        let token = self.get_access_token().await?;
389
390        let url = format!("https://www.googleapis.com/drive/v3/files/{drive_id}?alt=media");
391
392        let resp = self
393            .http_client
394            .get(&url)
395            .bearer_auth(&token)
396            .send()
397            .await
398            .map_err(|e| SourceError::Network(format!("Drive get failed: {e}")))?;
399
400        if resp.status() == reqwest::StatusCode::NOT_FOUND {
401            return Err(SourceError::NotFound(format!("file {drive_id} not found")));
402        }
403
404        if !resp.status().is_success() {
405            let body = resp.text().await.unwrap_or_default();
406            return Err(SourceError::Network(format!(
407                "Drive download error: {body}"
408            )));
409        }
410
411        resp.text()
412            .await
413            .map_err(|e| SourceError::Network(format!("read body failed: {e}")))
414    }
415}
416
417// ---------------------------------------------------------------------------
418// Helpers
419// ---------------------------------------------------------------------------
420
/// Test-only accessor for `extract_drive_id`.
#[cfg(test)]
impl GoogleDriveProvider {
    /// Expose the private `extract_drive_id` helper to tests, unwrapping its
    /// result (it returns `Ok` for every input).
    pub fn extract_drive_id_for_test(provider_id: &str) -> String {
        let parsed = extract_drive_id(provider_id);
        parsed.unwrap()
    }
}
428
429/// Extract Drive file ID from `gdrive://<id>/<name>` format.
430/// Also accepts a raw ID without the prefix.
431fn extract_drive_id(provider_id: &str) -> Result<String, SourceError> {
432    if let Some(rest) = provider_id.strip_prefix("gdrive://") {
433        if let Some(slash) = rest.find('/') {
434            Ok(rest[..slash].to_string())
435        } else {
436            Ok(rest.to_string())
437        }
438    } else {
439        Ok(provider_id.to_string())
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Service-account provider for construction tests. The key file is never
    /// opened because none of these tests requests a token.
    fn sa_provider(folder_id: &str, key_path: &str) -> GoogleDriveProvider {
        GoogleDriveProvider::new(folder_id.to_owned(), key_path.to_owned())
    }

    // ── extract_drive_id ─────────────────────────────────────────────

    #[test]
    fn extract_drive_id_with_prefix_and_name() {
        assert_eq!(extract_drive_id("gdrive://abc123/doc.md").unwrap(), "abc123");
    }

    #[test]
    fn extract_drive_id_with_prefix_no_name() {
        assert_eq!(extract_drive_id("gdrive://abc123").unwrap(), "abc123");
    }

    #[test]
    fn extract_drive_id_raw_id() {
        assert_eq!(extract_drive_id("abc123").unwrap(), "abc123");
    }

    #[test]
    fn extract_drive_id_with_nested_path() {
        // Only the first segment after the prefix is the id.
        assert_eq!(extract_drive_id("gdrive://xyz/folder/file.txt").unwrap(), "xyz");
    }

    #[test]
    fn extract_drive_id_empty_after_prefix() {
        assert_eq!(extract_drive_id("gdrive://").unwrap(), "");
    }

    #[test]
    fn extract_drive_id_long_id() {
        let long_id = "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgVE2upms";
        let provider_id = format!("gdrive://{long_id}/doc.md");
        assert_eq!(extract_drive_id(&provider_id).unwrap(), long_id);
    }

    #[test]
    fn extract_drive_id_special_chars_in_name() {
        assert_eq!(extract_drive_id("gdrive://abc123/my file (1).md").unwrap(), "abc123");
    }

    #[test]
    fn extract_drive_id_for_test_helper() {
        assert_eq!(
            GoogleDriveProvider::extract_drive_id_for_test("gdrive://myid/name.md"),
            "myid"
        );
    }

    // ── is_revocation_error ──────────────────────────────────────────

    #[test]
    fn revocation_error_detects_revoked() {
        assert!(is_revocation_error("Token has been revoked by user"));
    }

    #[test]
    fn revocation_error_detects_invalid_grant() {
        assert!(is_revocation_error("invalid_grant: token expired"));
    }

    #[test]
    fn revocation_error_detects_expired_or_revoked() {
        assert!(is_revocation_error("Token has been expired or revoked"));
    }

    #[test]
    fn revocation_error_with_surrounding_text() {
        assert!(is_revocation_error(
            "Error: token has been expired or revoked by user"
        ));
    }

    #[test]
    fn revocation_error_case_insensitive() {
        for msg in ["REVOKED", "Invalid_Grant"] {
            assert!(is_revocation_error(msg));
        }
    }

    #[test]
    fn revocation_error_returns_false_for_network_error() {
        assert!(!is_revocation_error("connection timeout"));
    }

    #[test]
    fn revocation_error_empty_string() {
        assert!(!is_revocation_error(""));
    }

    #[test]
    fn revocation_error_partial_match() {
        // A prefix of "revoked" must not count as a match.
        assert!(!is_revocation_error("revo"));
    }

    // ── GoogleDriveProvider construction ─────────────────────────────

    #[test]
    fn new_creates_service_account_provider() {
        let provider = sa_provider("folder123", "/path/to/key.json");
        assert_eq!(provider.folder_id, "folder123");
        assert!(matches!(
            provider.auth_strategy,
            DriveAuthStrategy::ServiceAccount { .. }
        ));
    }

    #[test]
    fn new_preserves_folder_id() {
        assert_eq!(sa_provider("folder_abc", "/key.json").folder_id, "folder_abc");
    }

    #[test]
    fn service_account_strategy_matches() {
        let provider = sa_provider("folder", "/path/key.json");
        assert!(matches!(
            provider.auth_strategy,
            DriveAuthStrategy::ServiceAccount { .. }
        ));
    }

    #[test]
    fn source_type_is_google_drive() {
        assert_eq!(
            sa_provider("folder123", "/path/to/key.json").source_type(),
            "google_drive"
        );
    }

    #[test]
    fn new_provider_source_type() {
        assert_eq!(sa_provider("f", "/k.json").source_type(), "google_drive");
    }

    #[test]
    fn with_client_creates_provider_with_custom_http_client() {
        let provider = GoogleDriveProvider::with_client(
            "folder123".to_owned(),
            "/path/to/key.json".to_owned(),
            reqwest::Client::new(),
        );
        assert_eq!(provider.folder_id, "folder123");
    }

    #[test]
    fn token_cache_starts_empty() {
        let provider = sa_provider("folder123", "/path/to/key.json");
        assert!(provider.token_cache.lock().unwrap().is_none());
    }

    // ── CachedToken ──────────────────────────────────────────────────

    #[test]
    fn cached_token_expires_at_is_in_future() {
        let one_hour = Duration::from_secs(3600);
        let token = CachedToken {
            access_token: "tok".to_owned(),
            expires_at: Instant::now() + one_hour,
        };
        assert!(token.expires_at > Instant::now());
    }

    #[test]
    fn cached_token_expired() {
        let one_second_ago = Instant::now() - Duration::from_secs(1);
        let token = CachedToken {
            access_token: "old_tok".to_owned(),
            expires_at: one_second_ago,
        };
        assert!(token.expires_at < Instant::now());
    }
}