1use crate::cache::constants::*;
7use crate::cache::source::{GitReference, SourceDetector, SourceType};
8use crate::cache::storage::CacheStorage;
9use crate::cache::tools::{
10 CacheCrateFromCratesIOParams, CacheCrateFromGitHubParams, CacheCrateFromLocalParams,
11};
12use crate::cache::utils::copy_directory_contents;
13use anyhow::{Context, Result, bail};
14use flate2::read::GzDecoder;
15use futures::StreamExt;
16use git2::{Cred, FetchOptions, RemoteCallbacks};
17use std::env;
18use std::fs::{self, File};
19use std::io::Write;
20use std::path::{Path, PathBuf};
21use tar::Archive;
22use zeroize::Zeroizing;
23
/// Seconds after which a download lock held by another process is no longer honored.
const LOCK_TIMEOUT_SECS: u64 = 60;
/// Milliseconds to sleep between checks while another process holds a download lock.
const LOCK_POLL_INTERVAL_MS: u64 = 100;
27
/// Deletes the lock file at `path` when dropped, releasing a download lock.
///
/// Deletion is best-effort: any error (for example, the file is already gone) is ignored.
struct LockGuard {
    /// Location of the lock file released by this guard.
    path: PathBuf,
}

impl Drop for LockGuard {
    fn drop(&mut self) {
        // A failed delete must never panic inside `drop`, so the result is discarded.
        std::fs::remove_file(&self.path).ok();
    }
}
38
39#[derive(Debug, Clone)]
41pub enum CrateSource {
42 CratesIO(CacheCrateFromCratesIOParams),
43 GitHub(CacheCrateFromGitHubParams),
44 LocalPath(CacheCrateFromLocalParams),
45}
46
47#[derive(Debug, Clone)]
49pub struct CrateDownloader {
50 storage: CacheStorage,
51 client: reqwest::Client,
52}
53
54impl CrateDownloader {
55 pub fn new(storage: CacheStorage) -> Self {
57 let client = Self::build_http_client();
58 Self { storage, client }
59 }
60
61 fn build_http_client() -> reqwest::Client {
63 let user_agent = Self::format_user_agent();
64
65 tracing::info!("Creating HTTP client with User-Agent: {}", user_agent);
66
67 reqwest::Client::builder()
68 .user_agent(user_agent)
69 .redirect(reqwest::redirect::Policy::limited(10))
70 .build()
71 .expect("Failed to create HTTP client") }
73
74 fn format_user_agent() -> String {
76 format!(
77 "{}/{} ({})",
78 env!("CARGO_PKG_NAME"),
79 env!("CARGO_PKG_VERSION"),
80 env!("CARGO_PKG_REPOSITORY")
81 )
82 }
83
84 pub async fn download_or_copy_crate(
86 &self,
87 name: &str,
88 version: &str,
89 source: Option<&str>,
90 ) -> Result<PathBuf> {
91 let source_type = SourceDetector::detect(source);
92
93 match source_type {
94 SourceType::CratesIo => self.download_crate(name, version).await,
95 SourceType::GitHub {
96 url,
97 reference,
98 repo_path,
99 } => {
100 let version_str = match reference {
101 GitReference::Branch(branch) => branch,
102 GitReference::Tag(tag) => tag,
103 GitReference::Default => "main".to_string(),
104 };
105 self.download_from_github(name, &version_str, &url, repo_path.as_deref())
106 .await
107 }
108 SourceType::Local { path } => self.copy_from_local(name, version, &path).await,
109 }
110 }
111
112 async fn download_crate(&self, name: &str, version: &str) -> Result<PathBuf> {
114 if self.storage.is_cached(name, version) {
116 tracing::info!("Crate {}-{} already cached", name, version);
117 return self.storage.source_path(name, version);
118 }
119
120 let crate_path = self.storage.crate_path(name, version)?;
122 let lock_path = crate_path.with_extension("lock");
123
124 if lock_path.exists() {
126 tracing::info!(
127 "Another process is downloading {}-{}, waiting...",
128 name,
129 version
130 );
131 let start = std::time::Instant::now();
133 while lock_path.exists()
134 && start.elapsed() < std::time::Duration::from_secs(LOCK_TIMEOUT_SECS)
135 {
136 tokio::time::sleep(std::time::Duration::from_millis(LOCK_POLL_INTERVAL_MS)).await;
137 }
138
139 if self.storage.is_cached(name, version) {
141 tracing::info!("Crate {}-{} was cached by another process", name, version);
142 return self.storage.source_path(name, version);
143 }
144 }
145
146 if let Some(parent) = lock_path.parent() {
148 self.storage.ensure_dir(parent)?;
149 }
150 std::fs::write(&lock_path, "downloading").context("Failed to create lock file")?;
151
152 let _lock_guard = LockGuard {
154 path: lock_path.clone(),
155 };
156
157 tracing::info!(
158 "Starting fresh download of {}-{} from crates.io",
159 name,
160 version
161 );
162
163 let url = format!("https://crates.io/api/v1/crates/{name}/{version}/download");
164 tracing::debug!("Download URL: {}", url);
165
166 let response = self
167 .client
168 .get(&url)
169 .send()
170 .await
171 .with_context(|| format!("Failed to download {name}-{version}"))?;
172
173 if !response.status().is_success() {
174 bail!(
175 "Failed to download {}-{}: HTTP {}",
176 name,
177 version,
178 response.status()
179 );
180 }
181
182 let temp_file_path = std::env::temp_dir().join(format!(
184 "{name}-{version}-{}-{}.tar.gz",
185 std::process::id(),
186 uuid::Uuid::new_v4().simple()
187 ));
188 let mut temp_file = File::create(&temp_file_path)
189 .with_context(|| format!("Failed to create temporary file for {name}-{version}"))?;
190
191 let mut stream = response.bytes_stream();
192 while let Some(chunk) = stream.next().await {
193 let chunk = chunk.context("Failed to read chunk from download stream")?;
194 temp_file
195 .write_all(&chunk)
196 .context("Failed to write to temporary file")?;
197 }
198
199 let source_path = self.storage.source_path(name, version)?;
201 self.storage.ensure_dir(&source_path)?;
202
203 let tar_gz = File::open(&temp_file_path).context("Failed to open downloaded file")?;
204 let tar = GzDecoder::new(tar_gz);
205 let mut archive = Archive::new(tar);
206
207 for entry in archive.entries()? {
209 let mut entry = entry?;
210 let path = entry.path()?;
211
212 let components: Vec<_> = path.components().collect();
214 if components.len() > 1 {
215 let relative_path: PathBuf = components[1..].iter().collect();
216
217 let has_parent_refs = relative_path
220 .components()
221 .any(|c| matches!(c, std::path::Component::ParentDir));
222
223 if has_parent_refs {
224 tracing::warn!(
225 "Skipping entry with parent directory reference: {}",
226 path.display()
227 );
228 continue;
229 }
230
231 let dest_path = source_path.join(&relative_path);
232
233 let canonical_source = source_path
235 .canonicalize()
236 .unwrap_or_else(|_| source_path.clone());
237
238 if let Ok(canonical_dest) = dest_path.canonicalize() {
239 if !canonical_dest.starts_with(&canonical_source) {
240 tracing::warn!(
241 "Skipping entry that would escape destination: {}",
242 path.display()
243 );
244 continue;
245 }
246 } else if let Some(parent) = dest_path.parent() {
247 if matches!(parent.canonicalize(), Ok(canonical_parent) if !canonical_parent.starts_with(&canonical_source))
249 {
250 tracing::warn!(
251 "Skipping entry with parent outside destination: {}",
252 path.display()
253 );
254 continue;
255 }
256 }
257
258 if let Some(parent) = dest_path.parent() {
259 std::fs::create_dir_all(parent)?;
260 }
261
262 entry.unpack(&dest_path)?;
263 }
264 }
265
266 std::fs::remove_file(&temp_file_path).ok();
268
269 self.storage.save_metadata(name, version)?;
271
272 tracing::info!("Successfully downloaded and extracted {}-{}", name, version);
273 Ok(source_path)
274 }
275
276 async fn download_from_github(
278 &self,
279 name: &str,
280 version: &str,
281 repo_url: &str,
282 repo_path: Option<&str>,
283 ) -> Result<PathBuf> {
284 if self.storage.is_cached(name, version) {
286 tracing::info!("Crate {}-{} already cached", name, version);
287 return self.storage.source_path(name, version);
288 }
289
290 let crate_path = self.storage.crate_path(name, version)?;
292 let lock_path = crate_path.with_extension("lock");
293
294 if lock_path.exists() {
296 tracing::info!(
297 "Another process is downloading {}-{}, waiting...",
298 name,
299 version
300 );
301 let start = std::time::Instant::now();
303 while lock_path.exists()
304 && start.elapsed() < std::time::Duration::from_secs(LOCK_TIMEOUT_SECS)
305 {
306 tokio::time::sleep(std::time::Duration::from_millis(LOCK_POLL_INTERVAL_MS)).await;
307 }
308
309 if self.storage.is_cached(name, version) {
311 tracing::info!("Crate {}-{} was cached by another process", name, version);
312 return self.storage.source_path(name, version);
313 }
314 }
315
316 if let Some(parent) = lock_path.parent() {
318 self.storage.ensure_dir(parent)?;
319 }
320 std::fs::write(&lock_path, "downloading").context("Failed to create lock file")?;
321
322 let _lock_guard = LockGuard {
324 path: lock_path.clone(),
325 };
326
327 tracing::info!(
328 "Downloading crate {}-{} from GitHub: {}",
329 name,
330 version,
331 repo_url
332 );
333
334 let temp_dir = std::env::temp_dir().join(format!("rust-docs-mcp-git-{name}-{version}"));
335
336 if temp_dir.exists() {
338 fs::remove_dir_all(&temp_dir).context("Failed to clean temp directory")?;
339 }
340
341 let github_token = env::var("GITHUB_TOKEN").ok().map(Zeroizing::new);
343 let has_token = github_token.is_some();
344
345 let mut fetch_options = FetchOptions::new();
347 let mut callbacks = RemoteCallbacks::new();
348
349 if let Some(token) = github_token {
350 tracing::debug!("Using GITHUB_TOKEN for authentication");
351 callbacks.credentials(move |_url, username_from_url, _allowed_types| {
352 Cred::userpass_plaintext(username_from_url.unwrap_or("git"), &token)
353 });
354 } else {
355 tracing::debug!("No GITHUB_TOKEN found, using unauthenticated access");
356 }
357
358 fetch_options.remote_callbacks(callbacks);
359
360 let mut builder = git2::build::RepoBuilder::new();
362 builder.fetch_options(fetch_options);
363
364 let repo = builder
365 .clone(repo_url, &temp_dir)
366 .with_context(|| {
367 let mut msg = format!("Failed to clone repository: {repo_url}");
368 if !has_token && repo_url.contains("github.com") {
369 msg.push_str("\nNote: Set GITHUB_TOKEN environment variable for private repositories and higher rate limits");
370 }
371 msg
372 })?;
373
374 if version != "main" && version != "master" {
377 if !Self::is_valid_git_ref(version) {
379 bail!("Invalid git reference name: {}", version);
380 }
381
382 let refname = format!("refs/remotes/origin/{version}");
384 if let Ok(reference) = repo.find_reference(&refname) {
385 let oid = reference
386 .target()
387 .ok_or_else(|| anyhow::anyhow!("Reference has no target"))?;
388 repo.set_head_detached(oid)
389 .with_context(|| format!("Failed to checkout branch: {version}"))?;
390 repo.checkout_head(Some(git2::build::CheckoutBuilder::default().force()))
391 .with_context(|| format!("Failed to checkout branch: {version}"))?;
392 } else {
393 let tag_ref = format!("refs/tags/{version}");
395 if let Ok(reference) = repo.find_reference(&tag_ref) {
396 let oid = reference
397 .target()
398 .ok_or_else(|| anyhow::anyhow!("Reference has no target"))?;
399 repo.set_head_detached(oid)
400 .with_context(|| format!("Failed to checkout tag: {version}"))?;
401 repo.checkout_head(Some(git2::build::CheckoutBuilder::default().force()))
402 .with_context(|| format!("Failed to checkout tag: {version}"))?;
403 } else {
404 bail!("Could not find branch or tag: {}", version);
405 }
406 }
407 }
408
409 let repo_source_path = if let Some(path) = repo_path {
411 temp_dir.join(path)
412 } else {
413 temp_dir.clone()
414 };
415
416 let cargo_toml = repo_source_path.join(CARGO_TOML);
418 if !cargo_toml.exists() {
419 bail!(
420 "No Cargo.toml found at path: {}",
421 repo_source_path.display()
422 );
423 }
424
425 let source_path = self.storage.source_path(name, version)?;
427 self.storage.ensure_dir(&source_path)?;
428
429 copy_directory_contents(&repo_source_path, &source_path)
430 .context("Failed to copy repository contents")?;
431
432 fs::remove_dir_all(&temp_dir).ok();
434
435 let source_info = match repo_path {
437 Some(path) => format!("{repo_url}#{path}"),
438 None => repo_url.to_string(),
439 };
440 self.storage.save_metadata_with_source(
441 name,
442 version,
443 "github",
444 Some(&source_info),
445 None,
446 )?;
447
448 tracing::info!(
449 "Successfully downloaded and extracted {}-{} from GitHub",
450 name,
451 version
452 );
453 Ok(source_path)
454 }
455
456 async fn copy_from_local(
458 &self,
459 name: &str,
460 version: &str,
461 local_path: &str,
462 ) -> Result<PathBuf> {
463 tracing::info!(
464 "Copying crate {}-{} from local path: {}",
465 name,
466 version,
467 local_path
468 );
469
470 let expanded_path = shellexpand::full(local_path)
472 .with_context(|| format!("Failed to expand path: {local_path}"))?;
473 let source_path_input = Path::new(expanded_path.as_ref());
474
475 if !source_path_input.exists() {
477 bail!("Local path does not exist: {}", source_path_input.display());
478 }
479
480 let cargo_toml = source_path_input.join(CARGO_TOML);
481 if !cargo_toml.exists() {
482 bail!(
483 "No Cargo.toml found at path: {}",
484 source_path_input.display()
485 );
486 }
487
488 let source_path = self.storage.source_path(name, version)?;
490 self.storage.ensure_dir(&source_path)?;
491
492 copy_directory_contents(source_path_input, &source_path)
493 .context("Failed to copy local directory contents")?;
494
495 self.storage
497 .save_metadata_with_source(name, version, "local", Some(local_path), None)?;
498
499 tracing::info!("Successfully copied {}-{} from local path", name, version);
500 Ok(source_path)
501 }
502
503 fn is_valid_git_ref(ref_name: &str) -> bool {
505 if ref_name.is_empty() || ref_name.contains("..") {
513 return false;
514 }
515
516 if ref_name.starts_with('.')
517 || ref_name.ends_with('.')
518 || ref_name.starts_with('/')
519 || ref_name.ends_with('/')
520 {
521 return false;
522 }
523
524 ref_name.chars().all(|c| {
527 c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/' || c == '+' })
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Installs a debug-level tracing subscriber; repeated calls are harmless no-ops.
    fn init_tracing() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter("rust_docs_mcp=debug")
            .try_init();
    }

    /// Builds a downloader whose cache lives inside `dir`.
    fn downloader_in(dir: &TempDir) -> CrateDownloader {
        let storage = CacheStorage::new(Some(dir.path().to_path_buf())).unwrap();
        CrateDownloader::new(storage)
    }

    #[test]
    fn test_downloader_creation() {
        let temp_dir = TempDir::new().unwrap();
        let rendered = format!("{:?}", downloader_in(&temp_dir));

        assert!(rendered.contains("CrateDownloader"));
    }

    /// Needs network access: a failure is tolerated, but never an HTTP 403
    /// (which would mean crates.io rejected our User-Agent).
    #[tokio::test]
    async fn test_user_agent_set() {
        init_tracing();

        let temp_dir = TempDir::new().unwrap();
        let downloader = downloader_in(&temp_dir);

        let result = downloader.download_crate("serde", "1.0.0").await;
        if let Err(e) = &result {
            let error_msg = e.to_string();
            assert!(!error_msg.contains("403"), "Got 403 error: {error_msg}");
        }
        if let Ok(path) = result {
            assert!(path.exists());
            println!("Successfully downloaded crate to: {path:?}");
        }
    }

    /// Needs network access; exercises a version string containing `+` build metadata.
    #[tokio::test]
    async fn test_problematic_crate_download() {
        init_tracing();

        let temp_dir = TempDir::new().unwrap();
        let downloader = downloader_in(&temp_dir);

        let path = downloader
            .download_crate("google-sheets4", "6.0.0+20240621")
            .await
            .unwrap_or_else(|e| panic!("Failed to download google-sheets4: {e}"));

        assert!(path.exists());
        println!("Successfully downloaded google-sheets4-6.0.0+20240621 to: {path:?}");
    }
}