Skip to main content

tuitbot_core/source/
google_drive.rs

1//! Google Drive content source provider.
2//!
//! Polls a Google Drive folder for files whose names match the caller's
//! patterns (e.g. `.md` and `.txt` notes) using the Drive API v3 with
//! service-account authentication.  The provider records stable Google
//! Drive file IDs as `provider_id` values in
//! `gdrive://<file_id>/<filename>` format for deduplication.
7//!
8//! Authentication uses a service-account JSON key: the provider reads
9//! the key file, builds a JWT, and exchanges it for an access token
10//! via Google's OAuth2 token endpoint.  Tokens are cached in memory
11//! with expiry tracking.
12
13use std::path::Path;
14use std::sync::Mutex;
15use std::time::{Duration, Instant};
16
17use async_trait::async_trait;
18use sha2::{Digest, Sha256};
19
20use super::{ContentSourceProvider, SourceError, SourceFile};
21use crate::automation::watchtower::matches_patterns;
22
23// ---------------------------------------------------------------------------
24// Provider
25// ---------------------------------------------------------------------------
26
27/// Google Drive content source provider.
28///
29/// Instantiated only when a `google_drive` source is configured with a
30/// valid `folder_id` and `service_account_key` path.
31pub struct GoogleDriveProvider {
32    folder_id: String,
33    service_account_key_path: String,
34    http_client: reqwest::Client,
35    token_cache: Mutex<Option<CachedToken>>,
36}
37
38struct CachedToken {
39    access_token: String,
40    expires_at: Instant,
41}
42
43impl GoogleDriveProvider {
44    pub fn new(folder_id: String, service_account_key_path: String) -> Self {
45        Self {
46            folder_id,
47            service_account_key_path,
48            http_client: reqwest::Client::new(),
49            token_cache: Mutex::new(None),
50        }
51    }
52
53    /// Build with an explicit HTTP client (for testing with wiremock).
54    #[cfg(test)]
55    pub fn with_client(
56        folder_id: String,
57        service_account_key_path: String,
58        client: reqwest::Client,
59    ) -> Self {
60        Self {
61            folder_id,
62            service_account_key_path,
63            http_client: client,
64            token_cache: Mutex::new(None),
65        }
66    }
67
68    /// Obtain a valid access token, refreshing if expired.
69    async fn get_access_token(&self) -> Result<String, SourceError> {
70        // Check cache.
71        if let Ok(cache) = self.token_cache.lock() {
72            if let Some(ref tok) = *cache {
73                if tok.expires_at > Instant::now() + Duration::from_secs(60) {
74                    return Ok(tok.access_token.clone());
75                }
76            }
77        }
78
79        let token = self.fetch_new_token().await?;
80        let access_token = token.access_token.clone();
81
82        if let Ok(mut cache) = self.token_cache.lock() {
83            *cache = Some(token);
84        }
85
86        Ok(access_token)
87    }
88
89    /// Read the service-account key, build a JWT, and exchange for an
90    /// access token via Google's token endpoint.
91    async fn fetch_new_token(&self) -> Result<CachedToken, SourceError> {
92        let key_bytes = tokio::fs::read_to_string(&self.service_account_key_path)
93            .await
94            .map_err(|e| {
95                SourceError::Auth(format!(
96                    "cannot read service account key {}: {e}",
97                    self.service_account_key_path
98                ))
99            })?;
100
101        let key_json: serde_json::Value = serde_json::from_str(&key_bytes)
102            .map_err(|e| SourceError::Auth(format!("invalid service account JSON: {e}")))?;
103
104        let client_email = key_json["client_email"]
105            .as_str()
106            .ok_or_else(|| SourceError::Auth("missing client_email in key".into()))?;
107
108        let private_key_pem = key_json["private_key"]
109            .as_str()
110            .ok_or_else(|| SourceError::Auth("missing private_key in key".into()))?;
111
112        let token_uri = key_json["token_uri"]
113            .as_str()
114            .unwrap_or("https://oauth2.googleapis.com/token");
115
116        // Build JWT claims.
117        let now = chrono::Utc::now().timestamp();
118        let claims = serde_json::json!({
119            "iss": client_email,
120            "scope": "https://www.googleapis.com/auth/drive.readonly",
121            "aud": token_uri,
122            "iat": now,
123            "exp": now + 3600,
124        });
125
126        let jwt = build_jwt(&claims, private_key_pem)?;
127
128        // Exchange JWT for access token.
129        let resp = self
130            .http_client
131            .post(token_uri)
132            .form(&[
133                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
134                ("assertion", &jwt),
135            ])
136            .send()
137            .await
138            .map_err(|e| SourceError::Auth(format!("token exchange failed: {e}")))?;
139
140        if !resp.status().is_success() {
141            let body = resp.text().await.unwrap_or_default();
142            return Err(SourceError::Auth(format!(
143                "token endpoint returned error: {body}"
144            )));
145        }
146
147        let body: serde_json::Value = resp
148            .json()
149            .await
150            .map_err(|e| SourceError::Auth(format!("invalid token response: {e}")))?;
151
152        let access_token = body["access_token"]
153            .as_str()
154            .ok_or_else(|| SourceError::Auth("no access_token in response".into()))?
155            .to_string();
156
157        let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
158
159        Ok(CachedToken {
160            access_token,
161            expires_at: Instant::now() + Duration::from_secs(expires_in),
162        })
163    }
164}
165
166#[async_trait]
167impl ContentSourceProvider for GoogleDriveProvider {
168    fn source_type(&self) -> &str {
169        "google_drive"
170    }
171
172    async fn scan_for_changes(
173        &self,
174        since_cursor: Option<&str>,
175        patterns: &[String],
176    ) -> Result<Vec<SourceFile>, SourceError> {
177        let token = self.get_access_token().await?;
178
179        // Build query: files in this folder, not trashed.
180        let mut q = format!("'{}' in parents and trashed = false", self.folder_id);
181
182        // Filter by modified time if we have a cursor.
183        if let Some(cursor) = since_cursor {
184            q.push_str(&format!(" and modifiedTime > '{cursor}'"));
185        }
186
187        let resp = self
188            .http_client
189            .get("https://www.googleapis.com/drive/v3/files")
190            .bearer_auth(&token)
191            .query(&[
192                ("q", q.as_str()),
193                ("fields", "files(id,name,md5Checksum,modifiedTime,mimeType)"),
194                ("pageSize", "1000"),
195            ])
196            .send()
197            .await
198            .map_err(|e| SourceError::Network(format!("Drive list failed: {e}")))?;
199
200        if !resp.status().is_success() {
201            let body = resp.text().await.unwrap_or_default();
202            return Err(SourceError::Network(format!("Drive API error: {body}")));
203        }
204
205        let body: serde_json::Value = resp
206            .json()
207            .await
208            .map_err(|e| SourceError::Network(format!("invalid Drive response: {e}")))?;
209
210        let files = body["files"].as_array().cloned().unwrap_or_default();
211
212        let mut result = Vec::new();
213        for file in &files {
214            let id = match file["id"].as_str() {
215                Some(id) => id,
216                None => continue,
217            };
218            let name = file["name"].as_str().unwrap_or("unknown");
219
220            // Filter by patterns (match against filename).
221            if !patterns.is_empty() && !matches_patterns(Path::new(name), patterns) {
222                continue;
223            }
224
225            let hash = file["md5Checksum"].as_str().unwrap_or("").to_string();
226            let modified = file["modifiedTime"].as_str().unwrap_or("").to_string();
227
228            result.push(SourceFile {
229                provider_id: format!("gdrive://{id}/{name}"),
230                display_name: name.to_string(),
231                content_hash: hash,
232                modified_at: modified,
233            });
234        }
235
236        Ok(result)
237    }
238
239    async fn read_content(&self, file_id: &str) -> Result<String, SourceError> {
240        // Extract the Drive file ID from our provider_id format.
241        let drive_id = extract_drive_id(file_id)?;
242
243        let token = self.get_access_token().await?;
244
245        let url = format!("https://www.googleapis.com/drive/v3/files/{drive_id}?alt=media");
246
247        let resp = self
248            .http_client
249            .get(&url)
250            .bearer_auth(&token)
251            .send()
252            .await
253            .map_err(|e| SourceError::Network(format!("Drive get failed: {e}")))?;
254
255        if resp.status() == reqwest::StatusCode::NOT_FOUND {
256            return Err(SourceError::NotFound(format!("file {drive_id} not found")));
257        }
258
259        if !resp.status().is_success() {
260            let body = resp.text().await.unwrap_or_default();
261            return Err(SourceError::Network(format!(
262                "Drive download error: {body}"
263            )));
264        }
265
266        resp.text()
267            .await
268            .map_err(|e| SourceError::Network(format!("read body failed: {e}")))
269    }
270}
271
272// ---------------------------------------------------------------------------
273// Helpers
274// ---------------------------------------------------------------------------
275
/// Test-only accessor for `extract_drive_id`.
///
/// # Panics
///
/// Panics if `extract_drive_id` rejects `provider_id`.
#[cfg(test)]
impl GoogleDriveProvider {
    pub fn extract_drive_id_for_test(provider_id: &str) -> String {
        extract_drive_id(provider_id).expect("extract_drive_id rejected the provider id")
    }
}
283
284/// Extract Drive file ID from `gdrive://<id>/<name>` format.
285/// Also accepts a raw ID without the prefix.
286fn extract_drive_id(provider_id: &str) -> Result<String, SourceError> {
287    if let Some(rest) = provider_id.strip_prefix("gdrive://") {
288        if let Some(slash) = rest.find('/') {
289            Ok(rest[..slash].to_string())
290        } else {
291            Ok(rest.to_string())
292        }
293    } else {
294        Ok(provider_id.to_string())
295    }
296}
297
298/// Build a signed JWT for Google service-account auth.
299///
300/// Uses RS256 (RSA + SHA-256). The private key is parsed from PEM format.
301fn build_jwt(claims: &serde_json::Value, private_key_pem: &str) -> Result<String, SourceError> {
302    let header = base64_url_encode(
303        &serde_json::to_vec(&serde_json::json!({"alg": "RS256", "typ": "JWT"}))
304            .map_err(|e| SourceError::Auth(format!("JWT header: {e}")))?,
305    );
306    let payload = base64_url_encode(
307        &serde_json::to_vec(claims).map_err(|e| SourceError::Auth(format!("JWT payload: {e}")))?,
308    );
309
310    let signing_input = format!("{header}.{payload}");
311
312    let signature = rsa_sign_sha256(signing_input.as_bytes(), private_key_pem)?;
313    let sig_b64 = base64_url_encode(&signature);
314
315    Ok(format!("{signing_input}.{sig_b64}"))
316}
317
318/// RSA-SHA256 signing using the `rsa` crate (already an indirect dep via oauth2).
319fn rsa_sign_sha256(data: &[u8], pem: &str) -> Result<Vec<u8>, SourceError> {
320    // Parse PEM to DER.
321    let der = pem_to_der(pem)?;
322
323    // Hash with SHA-256.
324    let hash = Sha256::digest(data);
325
326    // Build PKCS#1 v1.5 DigestInfo for SHA-256.
327    let digest_info = build_pkcs1_digest_info(&hash);
328
329    // Parse RSA private key and sign.
330    rsa_pkcs1_sign(&der, &digest_info)
331}
332
333/// Decode a PEM-encoded RSA private key to DER bytes.
334fn pem_to_der(pem: &str) -> Result<Vec<u8>, SourceError> {
335    let pem = pem.trim();
336    let body: String = pem
337        .lines()
338        .filter(|line| !line.starts_with("-----"))
339        .collect::<Vec<_>>()
340        .join("");
341
342    use base64::Engine;
343    base64::engine::general_purpose::STANDARD
344        .decode(&body)
345        .map_err(|e| SourceError::Auth(format!("PEM decode failed: {e}")))
346}
347
/// Wrap a SHA-256 digest in the DER `DigestInfo` structure used by PKCS#1 v1.5
/// signatures: `SEQUENCE { SEQUENCE { OID sha256, NULL }, OCTET STRING hash }`.
///
/// The fixed 19-byte DER header is followed by `hash` verbatim.
fn build_pkcs1_digest_info(hash: &[u8]) -> Vec<u8> {
    // DER header for AlgorithmIdentifier(sha256) + OCTET STRING of 32 bytes.
    const SHA256_DIGEST_INFO_PREFIX: [u8; 19] = [
        0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01,
        0x05, 0x00, 0x04, 0x20,
    ];
    [&SHA256_DIGEST_INFO_PREFIX[..], hash].concat()
}
360
361/// Minimal RSA PKCS#1 v1.5 signing from DER-encoded private key.
362///
363/// Parses PKCS#8 (the format Google uses) and extracts modulus + private
364/// exponent, then performs raw RSA: `signature = message^d mod n`.
365fn rsa_pkcs1_sign(der: &[u8], digest_info: &[u8]) -> Result<Vec<u8>, SourceError> {
366    // Parse PKCS#8 wrapper to get the inner RSA key.
367    let rsa_key = parse_pkcs8_rsa(der)?;
368
369    let k = rsa_key.n_bytes.len(); // key length in bytes
370
371    // PKCS#1 v1.5 padding: 0x00 0x01 [0xFF...] 0x00 [DigestInfo]
372    if digest_info.len() + 11 > k {
373        return Err(SourceError::Auth("RSA key too small for signature".into()));
374    }
375
376    let mut em = vec![0x00, 0x01];
377    let ps_len = k - digest_info.len() - 3;
378    em.extend(std::iter::repeat(0xFF).take(ps_len));
379    em.push(0x00);
380    em.extend_from_slice(digest_info);
381
382    // Convert to big integer and compute m^d mod n.
383    let m = BigUint::from_bytes_be(&em);
384    let n = BigUint::from_bytes_be(&rsa_key.n_bytes);
385    let d = BigUint::from_bytes_be(&rsa_key.d_bytes);
386
387    let sig = mod_pow(&m, &d, &n);
388    let mut sig_bytes = sig.to_bytes_be();
389
390    // Left-pad to key length.
391    while sig_bytes.len() < k {
392        sig_bytes.insert(0, 0);
393    }
394
395    Ok(sig_bytes)
396}
397
398// ---------------------------------------------------------------------------
399// Minimal big-integer arithmetic for RSA
400// ---------------------------------------------------------------------------
401
/// Simple big unsigned integer backed by a byte vector (big-endian).
#[derive(Clone)]
struct BigUint {
    /// Big-endian bytes with no leading zeros; zero is represented as `[0]`.
    bytes: Vec<u8>,
}

impl BigUint {
    /// Build from big-endian bytes, stripping leading zeros.
    ///
    /// An empty or all-zero input yields the canonical zero `[0]`.
    fn from_bytes_be(b: &[u8]) -> Self {
        match b.iter().position(|&x| x != 0) {
            Some(start) => Self {
                bytes: b[start..].to_vec(),
            },
            None => Self { bytes: vec![0] },
        }
    }

    /// Canonical big-endian bytes (`[0]` for zero).
    fn to_bytes_be(&self) -> Vec<u8> {
        self.bytes.clone()
    }

    fn is_zero(&self) -> bool {
        self.bytes.iter().all(|&b| b == 0)
    }

    /// Whether the value is odd (least-significant bit set).
    fn is_odd(&self) -> bool {
        matches!(self.bytes.last(), Some(&b) if b & 1 == 1)
    }

    /// Number of significant bits (0 for zero).
    fn bit_len(&self) -> usize {
        if self.is_zero() {
            return 0;
        }
        let top = self.bytes[0];
        (self.bytes.len() - 1) * 8 + (8 - top.leading_zeros() as usize)
    }

    /// Bit `i`, counting from the least-significant bit; bits above the top
    /// byte read as `false`.
    fn bit(&self, i: usize) -> bool {
        // Check the range before subtracting so the index never underflows
        // (the old `len - 1 - i / 8` panicked in debug builds).
        let from_end = i / 8;
        if from_end >= self.bytes.len() {
            return false;
        }
        let byte = self.bytes[self.bytes.len() - 1 - from_end];
        (byte >> (i % 8)) & 1 == 1
    }

    /// `(a * b) mod m` via byte-wise schoolbook multiplication then `modulo`.
    fn mul_mod(a: &BigUint, b: &BigUint, m: &BigUint) -> BigUint {
        let a_len = a.bytes.len();
        let b_len = b.bytes.len();
        let mut result = vec![0u32; a_len + b_len];

        for i in (0..a_len).rev() {
            let mut carry: u32 = 0;
            for j in (0..b_len).rev() {
                let prod = (a.bytes[i] as u32) * (b.bytes[j] as u32) + result[i + j + 1] + carry;
                result[i + j + 1] = prod & 0xFF;
                carry = prod >> 8;
            }
            result[i] += carry;
        }

        let bytes: Vec<u8> = result.iter().map(|&x| x as u8).collect();
        let val = BigUint::from_bytes_be(&bytes);
        BigUint::modulo(&val, m)
    }

    /// `a mod m` by bit-serial binary long division.
    fn modulo(a: &BigUint, m: &BigUint) -> BigUint {
        if a.bytes.len() < m.bytes.len() {
            return a.clone();
        }
        let mut remainder = BigUint::from_bytes_be(&[0]);
        let total_bits = a.bytes.len() * 8;

        for i in (0..total_bits).rev() {
            // Shift remainder left by 1 and bring in the next bit of a.
            remainder = BigUint::shift_left_one(&remainder);
            if a.bit(i) {
                let last = remainder.bytes.len() - 1;
                remainder.bytes[last] |= 1;
            }
            if BigUint::gte(&remainder, m) {
                remainder = BigUint::sub(&remainder, m);
            }
        }
        remainder
    }

    fn shift_left_one(a: &BigUint) -> BigUint {
        let mut result = vec![0u8; a.bytes.len() + 1];
        let mut carry = 0u8;
        for i in (0..a.bytes.len()).rev() {
            let val = (a.bytes[i] as u16) * 2 + carry as u16;
            result[i + 1] = val as u8;
            carry = (val >> 8) as u8;
        }
        result[0] = carry;
        BigUint::from_bytes_be(&result)
    }

    /// `a >= b`, relying on both being canonical (no leading zeros).
    fn gte(a: &BigUint, b: &BigUint) -> bool {
        if a.bytes.len() != b.bytes.len() {
            return a.bytes.len() > b.bytes.len();
        }
        a.bytes >= b.bytes
    }

    /// `a - b`; requires `a >= b`.
    fn sub(a: &BigUint, b: &BigUint) -> BigUint {
        let len = a.bytes.len().max(b.bytes.len());
        let mut result = vec![0i16; len];
        let a_off = len - a.bytes.len();
        let b_off = len - b.bytes.len();

        for i in 0..a.bytes.len() {
            result[a_off + i] += a.bytes[i] as i16;
        }
        for i in 0..b.bytes.len() {
            result[b_off + i] -= b.bytes[i] as i16;
        }

        // Propagate borrows.
        for i in (1..len).rev() {
            if result[i] < 0 {
                result[i] += 256;
                result[i - 1] -= 1;
            }
        }

        let bytes: Vec<u8> = result.iter().map(|&x| x as u8).collect();
        BigUint::from_bytes_be(&bytes)
    }
}

/// Modular exponentiation: base^exp mod modulus.
///
/// Odd moduli greater than 1 (every RSA modulus) use Montgomery
/// multiplication over 32-bit limbs. This is orders of magnitude faster than
/// repeated `BigUint::mul_mod`, whose bit-serial reduction made a 2048-bit
/// signature take seconds. Other moduli use plain square-and-multiply with
/// `mul_mod`. Neither path is constant-time. `exp == 0` yields 1.
fn mod_pow(base: &BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
    let bits = exp.bit_len();
    if bits == 0 {
        return BigUint::from_bytes_be(&[1]);
    }

    let reduced = BigUint::modulo(base, modulus);

    // Montgomery requires a modulus coprime to 2^32 (odd) and greater than 1.
    if modulus.is_odd() && modulus.bit_len() > 1 {
        return mont_pow(&reduced, exp, modulus);
    }

    let mut result = BigUint::from_bytes_be(&[1]);
    let mut b = reduced;
    for i in 0..bits {
        if exp.bit(i) {
            result = BigUint::mul_mod(&result, &b, modulus);
        }
        b = BigUint::mul_mod(&b, &b, modulus);
    }
    result
}

/// Left-to-right Montgomery exponentiation.
///
/// Requires `modulus` odd and > 1, and `base < modulus`.
fn mont_pow(base: &BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
    let n_be = modulus.to_bytes_be();
    let s = n_be.len().div_ceil(4);
    let n = limbs_from_be(&n_be, s);
    let n_prime = mont_neg_inv(n[0]);
    let r2 = mont_r2(&n);

    let mut one = vec![0u32; s];
    one[0] = 1;

    // Move into Montgomery form: x -> x * R mod n.
    let base_m = mont_mul(&limbs_from_be(&base.to_bytes_be(), s), &r2, &n, n_prime);
    let mut acc = mont_mul(&one, &r2, &n, n_prime);

    for i in (0..exp.bit_len()).rev() {
        acc = mont_mul(&acc, &acc, &n, n_prime);
        if exp.bit(i) {
            acc = mont_mul(&acc, &base_m, &n, n_prime);
        }
    }

    // Leave Montgomery form: multiply by 1 (i.e. by R^-1).
    limbs_to_biguint(&mont_mul(&acc, &one, &n, n_prime))
}

/// Montgomery product `a * b * R^-1 mod n` with `R = 2^(32 * n.len())`
/// (CIOS). `a` and `b` must be `< n` with `n.len()` little-endian limbs.
fn mont_mul(a: &[u32], b: &[u32], n: &[u32], n_prime: u32) -> Vec<u32> {
    let s = n.len();
    let mut t = vec![0u32; s + 2];

    for &bi in b {
        // t += a * b_i
        let mut carry = 0u64;
        for (tj, &aj) in t.iter_mut().zip(a) {
            let x = u64::from(*tj) + u64::from(aj) * u64::from(bi) + carry;
            *tj = x as u32;
            carry = x >> 32;
        }
        let x = u64::from(t[s]) + carry;
        t[s] = x as u32;
        t[s + 1] = (x >> 32) as u32;

        // t += m * n, with m chosen so the low limb becomes zero.
        let m = t[0].wrapping_mul(n_prime);
        let mut carry = 0u64;
        for (tj, &nj) in t.iter_mut().zip(n) {
            let x = u64::from(*tj) + u64::from(m) * u64::from(nj) + carry;
            *tj = x as u32;
            carry = x >> 32;
        }
        let x = u64::from(t[s]) + carry;
        t[s] = x as u32;
        t[s + 1] += (x >> 32) as u32;

        // Divide by 2^32: drop the (now zero) low limb.
        t.copy_within(1.., 0);
        t[s + 1] = 0;
    }

    // t < 2n here, so one conditional subtraction lands in [0, n).
    if t[s] != 0 || !limbs_lt(&t[..s], n) {
        limbs_sub_assign(&mut t[..s], n);
    }
    t.truncate(s);
    t
}

/// `-n0^-1 mod 2^32` for odd `n0` (Newton iteration doubles the correct bits).
fn mont_neg_inv(n0: u32) -> u32 {
    let mut inv = n0; // correct to 3 bits for any odd n0
    for _ in 0..5 {
        inv = inv.wrapping_mul(2u32.wrapping_sub(n0.wrapping_mul(inv)));
    }
    inv.wrapping_neg()
}

/// `R^2 mod n` with `R = 2^(32 * n.len())`, by repeated modular doubling.
fn mont_r2(n: &[u32]) -> Vec<u32> {
    let s = n.len();
    let mut r = vec![0u32; s];
    r[0] = 1;
    if !limbs_lt(&r, n) {
        limbs_sub_assign(&mut r, n);
    }
    for _ in 0..64 * s {
        let mut carry = 0u32;
        for limb in r.iter_mut() {
            let next = *limb >> 31;
            *limb = (*limb << 1) | carry;
            carry = next;
        }
        if carry != 0 || !limbs_lt(&r, n) {
            limbs_sub_assign(&mut r, n);
        }
    }
    r
}

/// Canonical big-endian bytes -> `len` little-endian 32-bit limbs.
fn limbs_from_be(bytes: &[u8], len: usize) -> Vec<u32> {
    let mut limbs = vec![0u32; len];
    for (i, &byte) in bytes.iter().rev().enumerate() {
        limbs[i / 4] |= u32::from(byte) << (8 * (i % 4));
    }
    limbs
}

/// Little-endian 32-bit limbs -> canonical `BigUint`.
fn limbs_to_biguint(limbs: &[u32]) -> BigUint {
    let bytes: Vec<u8> = limbs.iter().rev().flat_map(|l| l.to_be_bytes()).collect();
    BigUint::from_bytes_be(&bytes)
}

/// `a < b` for equal-length little-endian limb slices.
fn limbs_lt(a: &[u32], b: &[u32]) -> bool {
    for (x, y) in a.iter().rev().zip(b.iter().rev()) {
        if x != y {
            return x < y;
        }
    }
    false
}

/// `a -= b` modulo `2^(32 * a.len())` for equal-length limb slices.
fn limbs_sub_assign(a: &mut [u32], b: &[u32]) {
    let mut borrow = false;
    for (x, &y) in a.iter_mut().zip(b) {
        let (d1, o1) = x.overflowing_sub(y);
        let (d2, o2) = d1.overflowing_sub(u32::from(borrow));
        *x = d2;
        borrow = o1 || o2;
    }
}
550
551// ---------------------------------------------------------------------------
552// ASN.1/DER parsing for PKCS#8 RSA keys
553// ---------------------------------------------------------------------------
554
/// The two RSA key components needed for signing, as big-endian bytes with
/// the DER sign byte (a single leading zero) removed.
struct RsaKeyParts {
    /// Public modulus `n`.
    n_bytes: Vec<u8>,
    /// Private exponent `d`.
    d_bytes: Vec<u8>,
}
559
560/// Parse a PKCS#8 DER-encoded RSA private key and extract (n, d).
561fn parse_pkcs8_rsa(der: &[u8]) -> Result<RsaKeyParts, SourceError> {
562    // PKCS#8 is a SEQUENCE containing:
563    //   INTEGER version
564    //   SEQUENCE { OID rsaEncryption, NULL }
565    //   OCTET STRING (containing PKCS#1 RSA private key)
566    let (_, inner) = parse_der_sequence(der)
567        .map_err(|_| SourceError::Auth("invalid PKCS#8 outer SEQUENCE".into()))?;
568
569    // Skip version INTEGER.
570    let rest =
571        skip_der_element(inner).map_err(|_| SourceError::Auth("invalid PKCS#8 version".into()))?;
572
573    // Skip algorithm SEQUENCE.
574    let rest =
575        skip_der_element(rest).map_err(|_| SourceError::Auth("invalid PKCS#8 algorithm".into()))?;
576
577    // Parse OCTET STRING containing the PKCS#1 key.
578    let (_, pkcs1_der) = parse_der_octet_string(rest)
579        .map_err(|_| SourceError::Auth("invalid PKCS#8 octet string".into()))?;
580
581    parse_pkcs1_rsa(pkcs1_der)
582}
583
584/// Parse a PKCS#1 DER-encoded RSA private key and extract (n, d).
585fn parse_pkcs1_rsa(der: &[u8]) -> Result<RsaKeyParts, SourceError> {
586    // SEQUENCE { version, n, e, d, p, q, dp, dq, qinv }
587    let (_, inner) =
588        parse_der_sequence(der).map_err(|_| SourceError::Auth("invalid PKCS#1 SEQUENCE".into()))?;
589
590    // Skip version.
591    let rest =
592        skip_der_element(inner).map_err(|_| SourceError::Auth("invalid PKCS#1 version".into()))?;
593
594    // Read n.
595    let (rest, n_bytes) =
596        parse_der_integer(rest).map_err(|_| SourceError::Auth("invalid PKCS#1 modulus".into()))?;
597
598    // Skip e.
599    let rest =
600        skip_der_element(rest).map_err(|_| SourceError::Auth("invalid PKCS#1 exponent".into()))?;
601
602    // Read d.
603    let (_rest, d_bytes) = parse_der_integer(rest)
604        .map_err(|_| SourceError::Auth("invalid PKCS#1 private exponent".into()))?;
605
606    Ok(RsaKeyParts { n_bytes, d_bytes })
607}
608
609// Minimal DER parsing helpers.
610
/// Decode a DER length field at the start of `data`.
///
/// Returns the content length and the bytes following the length field.
/// Supports the short form and long forms with 1–4 length bytes; the
/// indefinite form (`0x80`) and longer encodings are rejected.
fn parse_der_length(data: &[u8]) -> Result<(usize, &[u8]), ()> {
    let (&first, rest) = data.split_first().ok_or(())?;
    if first < 0x80 {
        return Ok((usize::from(first), rest));
    }
    let num_bytes = usize::from(first & 0x7F);
    if num_bytes == 0 || num_bytes > 4 || rest.len() < num_bytes {
        return Err(());
    }
    let (len_bytes, rest) = rest.split_at(num_bytes);
    let len = len_bytes
        .iter()
        .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
    Ok((len, rest))
}

/// Split one DER tag-length-value element off the front of `data`.
///
/// Returns `(tag, contents, rest)`. Only single-byte tags are supported,
/// which covers every element in PKCS#1/PKCS#8 RSA keys.
fn parse_der_tlv(data: &[u8]) -> Result<(u8, &[u8], &[u8]), ()> {
    let (&tag, after_tag) = data.split_first().ok_or(())?;
    let (len, body) = parse_der_length(after_tag)?;
    if body.len() < len {
        return Err(());
    }
    let (contents, rest) = body.split_at(len);
    Ok((tag, contents, rest))
}

/// Parse one element that must carry `expected_tag`; returns `(rest, contents)`.
fn parse_der_tagged(data: &[u8], expected_tag: u8) -> Result<(&[u8], &[u8]), ()> {
    match parse_der_tlv(data)? {
        (tag, contents, rest) if tag == expected_tag => Ok((rest, contents)),
        _ => Err(()),
    }
}

/// Parse a SEQUENCE (tag 0x30); returns `(rest, contents)`.
fn parse_der_sequence(data: &[u8]) -> Result<(&[u8], &[u8]), ()> {
    parse_der_tagged(data, 0x30)
}

/// Parse an OCTET STRING (tag 0x04); returns `(rest, contents)`.
fn parse_der_octet_string(data: &[u8]) -> Result<(&[u8], &[u8]), ()> {
    parse_der_tagged(data, 0x04)
}

/// Parse an INTEGER (tag 0x02); returns `(rest, magnitude_bytes)`.
///
/// A single leading `0x00` (DER's positive-sign byte) is removed when the
/// integer has more than one content byte.
fn parse_der_integer(data: &[u8]) -> Result<(&[u8], Vec<u8>), ()> {
    let (rest, contents) = parse_der_tagged(data, 0x02)?;
    let magnitude = match contents {
        [0, tail @ ..] if !tail.is_empty() => tail,
        _ => contents,
    };
    Ok((rest, magnitude.to_vec()))
}

/// Skip one element of any tag; returns the bytes after it.
fn skip_der_element(data: &[u8]) -> Result<&[u8], ()> {
    parse_der_tlv(data).map(|(_, _, rest)| rest)
}
678
679/// URL-safe Base64 encoding without padding.
680fn base64_url_encode(data: &[u8]) -> String {
681    use base64::Engine;
682    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
683}