1use std::path::Path;
14use std::sync::Mutex;
15use std::time::{Duration, Instant};
16
17use async_trait::async_trait;
18use sha2::{Digest, Sha256};
19
20use super::{ContentSourceProvider, SourceError, SourceFile};
21use crate::automation::watchtower::matches_patterns;
22
23pub struct GoogleDriveProvider {
32 folder_id: String,
33 service_account_key_path: String,
34 http_client: reqwest::Client,
35 token_cache: Mutex<Option<CachedToken>>,
36}
37
/// An OAuth2 access token together with the moment it stops being valid.
struct CachedToken {
    /// Point in time after which the token is considered expired.
    expires_at: Instant,
    /// Bearer token string sent with Drive API requests.
    access_token: String,
}
42
43impl GoogleDriveProvider {
44 pub fn new(folder_id: String, service_account_key_path: String) -> Self {
45 Self {
46 folder_id,
47 service_account_key_path,
48 http_client: reqwest::Client::new(),
49 token_cache: Mutex::new(None),
50 }
51 }
52
53 #[cfg(test)]
55 pub fn with_client(
56 folder_id: String,
57 service_account_key_path: String,
58 client: reqwest::Client,
59 ) -> Self {
60 Self {
61 folder_id,
62 service_account_key_path,
63 http_client: client,
64 token_cache: Mutex::new(None),
65 }
66 }
67
68 async fn get_access_token(&self) -> Result<String, SourceError> {
70 if let Ok(cache) = self.token_cache.lock() {
72 if let Some(ref tok) = *cache {
73 if tok.expires_at > Instant::now() + Duration::from_secs(60) {
74 return Ok(tok.access_token.clone());
75 }
76 }
77 }
78
79 let token = self.fetch_new_token().await?;
80 let access_token = token.access_token.clone();
81
82 if let Ok(mut cache) = self.token_cache.lock() {
83 *cache = Some(token);
84 }
85
86 Ok(access_token)
87 }
88
89 async fn fetch_new_token(&self) -> Result<CachedToken, SourceError> {
92 let key_bytes = tokio::fs::read_to_string(&self.service_account_key_path)
93 .await
94 .map_err(|e| {
95 SourceError::Auth(format!(
96 "cannot read service account key {}: {e}",
97 self.service_account_key_path
98 ))
99 })?;
100
101 let key_json: serde_json::Value = serde_json::from_str(&key_bytes)
102 .map_err(|e| SourceError::Auth(format!("invalid service account JSON: {e}")))?;
103
104 let client_email = key_json["client_email"]
105 .as_str()
106 .ok_or_else(|| SourceError::Auth("missing client_email in key".into()))?;
107
108 let private_key_pem = key_json["private_key"]
109 .as_str()
110 .ok_or_else(|| SourceError::Auth("missing private_key in key".into()))?;
111
112 let token_uri = key_json["token_uri"]
113 .as_str()
114 .unwrap_or("https://oauth2.googleapis.com/token");
115
116 let now = chrono::Utc::now().timestamp();
118 let claims = serde_json::json!({
119 "iss": client_email,
120 "scope": "https://www.googleapis.com/auth/drive.readonly",
121 "aud": token_uri,
122 "iat": now,
123 "exp": now + 3600,
124 });
125
126 let jwt = build_jwt(&claims, private_key_pem)?;
127
128 let resp = self
130 .http_client
131 .post(token_uri)
132 .form(&[
133 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
134 ("assertion", &jwt),
135 ])
136 .send()
137 .await
138 .map_err(|e| SourceError::Auth(format!("token exchange failed: {e}")))?;
139
140 if !resp.status().is_success() {
141 let body = resp.text().await.unwrap_or_default();
142 return Err(SourceError::Auth(format!(
143 "token endpoint returned error: {body}"
144 )));
145 }
146
147 let body: serde_json::Value = resp
148 .json()
149 .await
150 .map_err(|e| SourceError::Auth(format!("invalid token response: {e}")))?;
151
152 let access_token = body["access_token"]
153 .as_str()
154 .ok_or_else(|| SourceError::Auth("no access_token in response".into()))?
155 .to_string();
156
157 let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
158
159 Ok(CachedToken {
160 access_token,
161 expires_at: Instant::now() + Duration::from_secs(expires_in),
162 })
163 }
164}
165
166#[async_trait]
167impl ContentSourceProvider for GoogleDriveProvider {
168 fn source_type(&self) -> &str {
169 "google_drive"
170 }
171
172 async fn scan_for_changes(
173 &self,
174 since_cursor: Option<&str>,
175 patterns: &[String],
176 ) -> Result<Vec<SourceFile>, SourceError> {
177 let token = self.get_access_token().await?;
178
179 let mut q = format!("'{}' in parents and trashed = false", self.folder_id);
181
182 if let Some(cursor) = since_cursor {
184 q.push_str(&format!(" and modifiedTime > '{cursor}'"));
185 }
186
187 let resp = self
188 .http_client
189 .get("https://www.googleapis.com/drive/v3/files")
190 .bearer_auth(&token)
191 .query(&[
192 ("q", q.as_str()),
193 ("fields", "files(id,name,md5Checksum,modifiedTime,mimeType)"),
194 ("pageSize", "1000"),
195 ])
196 .send()
197 .await
198 .map_err(|e| SourceError::Network(format!("Drive list failed: {e}")))?;
199
200 if !resp.status().is_success() {
201 let body = resp.text().await.unwrap_or_default();
202 return Err(SourceError::Network(format!("Drive API error: {body}")));
203 }
204
205 let body: serde_json::Value = resp
206 .json()
207 .await
208 .map_err(|e| SourceError::Network(format!("invalid Drive response: {e}")))?;
209
210 let files = body["files"].as_array().cloned().unwrap_or_default();
211
212 let mut result = Vec::new();
213 for file in &files {
214 let id = match file["id"].as_str() {
215 Some(id) => id,
216 None => continue,
217 };
218 let name = file["name"].as_str().unwrap_or("unknown");
219
220 if !patterns.is_empty() && !matches_patterns(Path::new(name), patterns) {
222 continue;
223 }
224
225 let hash = file["md5Checksum"].as_str().unwrap_or("").to_string();
226 let modified = file["modifiedTime"].as_str().unwrap_or("").to_string();
227
228 result.push(SourceFile {
229 provider_id: format!("gdrive://{id}/{name}"),
230 display_name: name.to_string(),
231 content_hash: hash,
232 modified_at: modified,
233 });
234 }
235
236 Ok(result)
237 }
238
239 async fn read_content(&self, file_id: &str) -> Result<String, SourceError> {
240 let drive_id = extract_drive_id(file_id)?;
242
243 let token = self.get_access_token().await?;
244
245 let url = format!("https://www.googleapis.com/drive/v3/files/{drive_id}?alt=media");
246
247 let resp = self
248 .http_client
249 .get(&url)
250 .bearer_auth(&token)
251 .send()
252 .await
253 .map_err(|e| SourceError::Network(format!("Drive get failed: {e}")))?;
254
255 if resp.status() == reqwest::StatusCode::NOT_FOUND {
256 return Err(SourceError::NotFound(format!("file {drive_id} not found")));
257 }
258
259 if !resp.status().is_success() {
260 let body = resp.text().await.unwrap_or_default();
261 return Err(SourceError::Network(format!(
262 "Drive download error: {body}"
263 )));
264 }
265
266 resp.text()
267 .await
268 .map_err(|e| SourceError::Network(format!("read body failed: {e}")))
269 }
270}
271
#[cfg(test)]
impl GoogleDriveProvider {
    /// Test hook exposing the private `extract_drive_id` helper.
    ///
    /// # Panics
    ///
    /// Panics if `extract_drive_id` returns an error.
    pub fn extract_drive_id_for_test(provider_id: &str) -> String {
        let parsed: Result<String, SourceError> = extract_drive_id(provider_id);
        parsed.unwrap()
    }
}
283
284fn extract_drive_id(provider_id: &str) -> Result<String, SourceError> {
287 if let Some(rest) = provider_id.strip_prefix("gdrive://") {
288 if let Some(slash) = rest.find('/') {
289 Ok(rest[..slash].to_string())
290 } else {
291 Ok(rest.to_string())
292 }
293 } else {
294 Ok(provider_id.to_string())
295 }
296}
297
298fn build_jwt(claims: &serde_json::Value, private_key_pem: &str) -> Result<String, SourceError> {
302 let header = base64_url_encode(
303 &serde_json::to_vec(&serde_json::json!({"alg": "RS256", "typ": "JWT"}))
304 .map_err(|e| SourceError::Auth(format!("JWT header: {e}")))?,
305 );
306 let payload = base64_url_encode(
307 &serde_json::to_vec(claims).map_err(|e| SourceError::Auth(format!("JWT payload: {e}")))?,
308 );
309
310 let signing_input = format!("{header}.{payload}");
311
312 let signature = rsa_sign_sha256(signing_input.as_bytes(), private_key_pem)?;
313 let sig_b64 = base64_url_encode(&signature);
314
315 Ok(format!("{signing_input}.{sig_b64}"))
316}
317
318fn rsa_sign_sha256(data: &[u8], pem: &str) -> Result<Vec<u8>, SourceError> {
320 let der = pem_to_der(pem)?;
322
323 let hash = Sha256::digest(data);
325
326 let digest_info = build_pkcs1_digest_info(&hash);
328
329 rsa_pkcs1_sign(&der, &digest_info)
331}
332
333fn pem_to_der(pem: &str) -> Result<Vec<u8>, SourceError> {
335 let pem = pem.trim();
336 let body: String = pem
337 .lines()
338 .filter(|line| !line.starts_with("-----"))
339 .collect::<Vec<_>>()
340 .join("");
341
342 use base64::Engine;
343 base64::engine::general_purpose::STANDARD
344 .decode(&body)
345 .map_err(|e| SourceError::Auth(format!("PEM decode failed: {e}")))
346}
347
/// Prepends the DER `DigestInfo` header for SHA-256 to `hash`.
///
/// The 19-byte prefix encodes the SHA-256 algorithm identifier followed by an
/// OCTET STRING header announcing 32 bytes. `hash` is appended as-is; its
/// length is not checked.
fn build_pkcs1_digest_info(hash: &[u8]) -> Vec<u8> {
    const SHA256_DIGEST_INFO_PREFIX: [u8; 19] = [
        0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01,
        0x05, 0x00, 0x04, 0x20,
    ];

    [&SHA256_DIGEST_INFO_PREFIX[..], hash].concat()
}
360
361fn rsa_pkcs1_sign(der: &[u8], digest_info: &[u8]) -> Result<Vec<u8>, SourceError> {
366 let rsa_key = parse_pkcs8_rsa(der)?;
368
369 let k = rsa_key.n_bytes.len(); if digest_info.len() + 11 > k {
373 return Err(SourceError::Auth("RSA key too small for signature".into()));
374 }
375
376 let mut em = vec![0x00, 0x01];
377 let ps_len = k - digest_info.len() - 3;
378 em.extend(std::iter::repeat(0xFF).take(ps_len));
379 em.push(0x00);
380 em.extend_from_slice(digest_info);
381
382 let m = BigUint::from_bytes_be(&em);
384 let n = BigUint::from_bytes_be(&rsa_key.n_bytes);
385 let d = BigUint::from_bytes_be(&rsa_key.d_bytes);
386
387 let sig = mod_pow(&m, &d, &n);
388 let mut sig_bytes = sig.to_bytes_be();
389
390 while sig_bytes.len() < k {
392 sig_bytes.insert(0, 0);
393 }
394
395 Ok(sig_bytes)
396}
397
/// Minimal arbitrary-precision unsigned integer used for the RSA private-key
/// operation.
///
/// The magnitude is stored big-endian with leading zero bytes removed; all
/// arithmetic is simple byte-wise schoolbook code.
#[derive(Clone)]
struct BigUint {
    /// Big-endian magnitude. Zero is stored as `[0]`, or as an empty vector
    /// when the value was built from an empty slice.
    bytes: Vec<u8>,
}

impl BigUint {
    /// Builds a value from big-endian bytes, dropping leading zero bytes.
    ///
    /// An all-zero input keeps a single `0` byte; an empty input yields an
    /// empty magnitude.
    fn from_bytes_be(b: &[u8]) -> Self {
        let start = b
            .iter()
            .position(|&x| x != 0)
            .unwrap_or(b.len().saturating_sub(1));
        Self {
            bytes: b[start..].to_vec(),
        }
    }

    /// Returns a copy of the big-endian magnitude.
    fn to_bytes_be(&self) -> Vec<u8> {
        self.bytes.clone()
    }

    /// Returns `true` when every stored byte is zero (including no bytes).
    fn is_zero(&self) -> bool {
        self.bytes.iter().all(|&b| b == 0)
    }

    /// Number of significant bits; `0` for zero.
    fn bit_len(&self) -> usize {
        if self.is_zero() {
            return 0;
        }
        let top = self.bytes[0];
        (self.bytes.len() - 1) * 8 + (8 - top.leading_zeros() as usize)
    }

    /// Returns bit `i`, where bit 0 is the least significant bit.
    ///
    /// Bits beyond the stored magnitude read as `false`.
    fn bit(&self, i: usize) -> bool {
        // Check the range before subtracting: computing `len - 1 - i / 8`
        // first would underflow for out-of-range `i` or an empty magnitude.
        let byte_from_end = i / 8;
        if byte_from_end >= self.bytes.len() {
            return false;
        }
        let byte_idx = self.bytes.len() - 1 - byte_from_end;
        (self.bytes[byte_idx] >> (i % 8)) & 1 == 1
    }

    /// Computes `(a * b) mod m` via schoolbook multiplication followed by
    /// [`BigUint::modulo`].
    fn mul_mod(a: &BigUint, b: &BigUint, m: &BigUint) -> BigUint {
        let a_len = a.bytes.len();
        let b_len = b.bytes.len();
        // One slot per output byte; u32 leaves room for product plus carries.
        let mut result = vec![0u32; a_len + b_len];

        for i in (0..a_len).rev() {
            let mut carry: u32 = 0;
            for j in (0..b_len).rev() {
                let prod = (a.bytes[i] as u32) * (b.bytes[j] as u32) + result[i + j + 1] + carry;
                result[i + j + 1] = prod & 0xFF;
                carry = prod >> 8;
            }
            result[i] += carry;
        }

        // Every slot has been reduced below 256, so the cast is lossless.
        let bytes: Vec<u8> = result.iter().map(|&x| x as u8).collect();
        let val = BigUint::from_bytes_be(&bytes);
        BigUint::modulo(&val, m)
    }

    /// Computes `a mod m` by bit-wise long division.
    ///
    /// `a` is returned unchanged when it has fewer bytes than `m`.
    fn modulo(a: &BigUint, m: &BigUint) -> BigUint {
        if a.bytes.len() < m.bytes.len() {
            return a.clone();
        }
        let mut remainder = BigUint::from_bytes_be(&[0]);
        let total_bits = a.bytes.len() * 8;

        // Feed the bits of `a` from most to least significant into the
        // remainder, subtracting `m` whenever the remainder reaches it.
        for i in (0..total_bits).rev() {
            remainder = BigUint::shift_left_one(&remainder);
            if a.bit(i) {
                let last = remainder.bytes.len() - 1;
                remainder.bytes[last] |= 1;
            }
            if BigUint::gte(&remainder, m) {
                remainder = BigUint::sub(&remainder, m);
            }
        }
        remainder
    }

    /// Returns `a * 2`.
    fn shift_left_one(a: &BigUint) -> BigUint {
        let mut result = vec![0u8; a.bytes.len() + 1];
        let mut carry = 0u8;
        for i in (0..a.bytes.len()).rev() {
            let val = (a.bytes[i] as u16) * 2 + carry as u16;
            result[i + 1] = val as u8;
            carry = (val >> 8) as u8;
        }
        result[0] = carry;
        BigUint::from_bytes_be(&result)
    }

    /// Returns `a >= b`.
    ///
    /// Relies on both magnitudes having no leading zero bytes, so a longer
    /// magnitude is always the larger value.
    fn gte(a: &BigUint, b: &BigUint) -> bool {
        if a.bytes.len() != b.bytes.len() {
            return a.bytes.len() > b.bytes.len();
        }
        a.bytes >= b.bytes
    }

    /// Returns `a - b`.
    ///
    /// The caller must ensure `a >= b`; otherwise the top byte wraps and the
    /// result is meaningless.
    fn sub(a: &BigUint, b: &BigUint) -> BigUint {
        let len = a.bytes.len().max(b.bytes.len());
        let mut result = vec![0i16; len];
        let a_off = len - a.bytes.len();
        let b_off = len - b.bytes.len();

        for i in 0..a.bytes.len() {
            result[a_off + i] += a.bytes[i] as i16;
        }
        for i in 0..b.bytes.len() {
            result[b_off + i] -= b.bytes[i] as i16;
        }

        // Propagate borrows from the least significant byte upwards.
        for i in (1..len).rev() {
            if result[i] < 0 {
                result[i] += 256;
                result[i - 1] -= 1;
            }
        }

        let bytes: Vec<u8> = result.iter().map(|&x| x as u8).collect();
        BigUint::from_bytes_be(&bytes)
    }
}
531
532fn mod_pow(base: &BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
534 let bits = exp.bit_len();
535 if bits == 0 {
536 return BigUint::from_bytes_be(&[1]);
537 }
538
539 let mut result = BigUint::from_bytes_be(&[1]);
540 let mut b = BigUint::modulo(base, modulus);
541
542 for i in 0..bits {
543 if exp.bit(i) {
544 result = BigUint::mul_mod(&result, &b, modulus);
545 }
546 b = BigUint::mul_mod(&b, &b, modulus);
547 }
548 result
549}
550
/// The two RSA private-key components needed for signing, as big-endian
/// byte strings.
struct RsaKeyParts {
    /// Private exponent `d`.
    d_bytes: Vec<u8>,
    /// Modulus `n`.
    n_bytes: Vec<u8>,
}
559
560fn parse_pkcs8_rsa(der: &[u8]) -> Result<RsaKeyParts, SourceError> {
562 let (_, inner) = parse_der_sequence(der)
567 .map_err(|_| SourceError::Auth("invalid PKCS#8 outer SEQUENCE".into()))?;
568
569 let rest =
571 skip_der_element(inner).map_err(|_| SourceError::Auth("invalid PKCS#8 version".into()))?;
572
573 let rest =
575 skip_der_element(rest).map_err(|_| SourceError::Auth("invalid PKCS#8 algorithm".into()))?;
576
577 let (_, pkcs1_der) = parse_der_octet_string(rest)
579 .map_err(|_| SourceError::Auth("invalid PKCS#8 octet string".into()))?;
580
581 parse_pkcs1_rsa(pkcs1_der)
582}
583
584fn parse_pkcs1_rsa(der: &[u8]) -> Result<RsaKeyParts, SourceError> {
586 let (_, inner) =
588 parse_der_sequence(der).map_err(|_| SourceError::Auth("invalid PKCS#1 SEQUENCE".into()))?;
589
590 let rest =
592 skip_der_element(inner).map_err(|_| SourceError::Auth("invalid PKCS#1 version".into()))?;
593
594 let (rest, n_bytes) =
596 parse_der_integer(rest).map_err(|_| SourceError::Auth("invalid PKCS#1 modulus".into()))?;
597
598 let rest =
600 skip_der_element(rest).map_err(|_| SourceError::Auth("invalid PKCS#1 exponent".into()))?;
601
602 let (_rest, d_bytes) = parse_der_integer(rest)
604 .map_err(|_| SourceError::Auth("invalid PKCS#1 private exponent".into()))?;
605
606 Ok(RsaKeyParts { n_bytes, d_bytes })
607}
608
/// Decodes a DER length field at the start of `data`.
///
/// Returns the decoded length and the bytes following the length field.
/// Short form (first byte below `0x80`) and long form with one to four
/// length bytes are accepted; empty input, a zero or larger byte count, and
/// truncated input are rejected with `Err(())`.
fn parse_der_length(data: &[u8]) -> Result<(usize, &[u8]), ()> {
    let (&first, tail) = data.split_first().ok_or(())?;
    if first < 0x80 {
        return Ok((usize::from(first), tail));
    }

    let count = usize::from(first & 0x7F);
    if count == 0 || count > 4 || tail.len() < count {
        return Err(());
    }

    let (len_bytes, rest) = tail.split_at(count);
    let len = len_bytes
        .iter()
        .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
    Ok((len, rest))
}
629
630fn parse_der_sequence(data: &[u8]) -> Result<(&[u8], &[u8]), ()> {
631 if data.is_empty() || data[0] != 0x30 {
632 return Err(());
633 }
634 let (len, rest) = parse_der_length(&data[1..])?;
635 if rest.len() < len {
636 return Err(());
637 }
638 Ok((&rest[len..], &rest[..len]))
639}
640
641fn parse_der_octet_string(data: &[u8]) -> Result<(&[u8], &[u8]), ()> {
642 if data.is_empty() || data[0] != 0x04 {
643 return Err(());
644 }
645 let (len, rest) = parse_der_length(&data[1..])?;
646 if rest.len() < len {
647 return Err(());
648 }
649 Ok((&rest[len..], &rest[..len]))
650}
651
652fn parse_der_integer(data: &[u8]) -> Result<(&[u8], Vec<u8>), ()> {
653 if data.is_empty() || data[0] != 0x02 {
654 return Err(());
655 }
656 let (len, rest) = parse_der_length(&data[1..])?;
657 if rest.len() < len {
658 return Err(());
659 }
660 let mut bytes = rest[..len].to_vec();
661 if bytes.len() > 1 && bytes[0] == 0 {
663 bytes.remove(0);
664 }
665 Ok((&rest[len..], bytes))
666}
667
668fn skip_der_element(data: &[u8]) -> Result<&[u8], ()> {
669 if data.is_empty() {
670 return Err(());
671 }
672 let (len, rest) = parse_der_length(&data[1..])?;
673 if rest.len() < len {
674 return Err(());
675 }
676 Ok(&rest[len..])
677}
678
679fn base64_url_encode(data: &[u8]) -> String {
681 use base64::Engine;
682 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
683}