rust_docs_mcp/cache/
downloader.rs

1//! Crate downloading and source management
2//!
3//! This module handles downloading crates from various sources including
4//! crates.io, GitHub repositories, and local filesystem paths.
5
6use crate::cache::constants::*;
7use crate::cache::source::{GitReference, SourceDetector, SourceType};
8use crate::cache::storage::CacheStorage;
9use crate::cache::tools::{
10    CacheCrateFromCratesIOParams, CacheCrateFromGitHubParams, CacheCrateFromLocalParams,
11};
12use crate::cache::utils::copy_directory_contents;
13use anyhow::{Context, Result, bail};
14use flate2::read::GzDecoder;
15use futures::StreamExt;
16use git2::{Cred, FetchOptions, RemoteCallbacks};
17use std::env;
18use std::fs::{self, File};
19use std::io::Write;
20use std::path::{Path, PathBuf};
21use tar::Archive;
22use zeroize::Zeroizing;
23
/// Maximum time, in seconds, to wait for another downloader's lock file to disappear.
const LOCK_TIMEOUT_SECS: u64 = 60;
/// Delay, in milliseconds, between successive checks for the lock file while waiting.
const LOCK_POLL_INTERVAL_MS: u64 = 100;
27
/// Guard that deletes the file at `path` when it goes out of scope.
///
/// Used to make sure a download lock file is removed on every exit path.
struct LockGuard {
    path: PathBuf,
}

impl Drop for LockGuard {
    /// Best-effort removal: a failure to delete the file is deliberately ignored.
    fn drop(&mut self) {
        std::fs::remove_file(&self.path).ok();
    }
}
38
/// Unified crate source enum that reuses the parameter structs from tools.
///
/// Each variant wraps the parameter struct of the corresponding caching tool.
#[derive(Debug, Clone)]
pub enum CrateSource {
    /// Crate obtained from crates.io.
    CratesIO(CacheCrateFromCratesIOParams),
    /// Crate obtained from a GitHub repository.
    GitHub(CacheCrateFromGitHubParams),
    /// Crate obtained from a local filesystem path.
    LocalPath(CacheCrateFromLocalParams),
}
46
/// Service for downloading crates from various sources
#[derive(Debug, Clone)]
pub struct CrateDownloader {
    /// Cache storage used to resolve crate paths, check cached state and save metadata.
    storage: CacheStorage,
    /// HTTP client used for crates.io downloads; built once in `new`.
    client: reqwest::Client,
}
53
54impl CrateDownloader {
55    /// Create a new crate downloader
56    pub fn new(storage: CacheStorage) -> Self {
57        let client = Self::build_http_client();
58        Self { storage, client }
59    }
60
61    /// Build the HTTP client with proper configuration
62    fn build_http_client() -> reqwest::Client {
63        let user_agent = Self::format_user_agent();
64
65        tracing::info!("Creating HTTP client with User-Agent: {}", user_agent);
66
67        reqwest::Client::builder()
68            .user_agent(user_agent)
69            .redirect(reqwest::redirect::Policy::limited(10))
70            .build()
71            .expect("Failed to create HTTP client") // HTTP client creation should not fail with proper configuration
72    }
73
74    /// Format the user-agent string for API compliance
75    fn format_user_agent() -> String {
76        format!(
77            "{}/{} ({})",
78            env!("CARGO_PKG_NAME"),
79            env!("CARGO_PKG_VERSION"),
80            env!("CARGO_PKG_REPOSITORY")
81        )
82    }
83
84    /// Download or copy a crate from the specified source
85    pub async fn download_or_copy_crate(
86        &self,
87        name: &str,
88        version: &str,
89        source: Option<&str>,
90    ) -> Result<PathBuf> {
91        let source_type = SourceDetector::detect(source);
92
93        match source_type {
94            SourceType::CratesIo => self.download_crate(name, version).await,
95            SourceType::GitHub {
96                url,
97                reference,
98                repo_path,
99            } => {
100                let version_str = match reference {
101                    GitReference::Branch(branch) => branch,
102                    GitReference::Tag(tag) => tag,
103                    GitReference::Default => "main".to_string(),
104                };
105                self.download_from_github(name, &version_str, &url, repo_path.as_deref())
106                    .await
107            }
108            SourceType::Local { path } => self.copy_from_local(name, version, &path).await,
109        }
110    }
111
112    /// Download a crate from crates.io
113    async fn download_crate(&self, name: &str, version: &str) -> Result<PathBuf> {
114        // Check if already cached
115        if self.storage.is_cached(name, version) {
116            tracing::info!("Crate {}-{} already cached", name, version);
117            return self.storage.source_path(name, version);
118        }
119
120        // Create a lock file to prevent concurrent downloads
121        let crate_path = self.storage.crate_path(name, version)?;
122        let lock_path = crate_path.with_extension("lock");
123
124        // Check if another process is already downloading
125        if lock_path.exists() {
126            tracing::info!(
127                "Another process is downloading {}-{}, waiting...",
128                name,
129                version
130            );
131            // Wait for the other process to finish (simple polling)
132            let start = std::time::Instant::now();
133            while lock_path.exists()
134                && start.elapsed() < std::time::Duration::from_secs(LOCK_TIMEOUT_SECS)
135            {
136                tokio::time::sleep(std::time::Duration::from_millis(LOCK_POLL_INTERVAL_MS)).await;
137            }
138
139            // Check if it was successfully cached by the other process
140            if self.storage.is_cached(name, version) {
141                tracing::info!("Crate {}-{} was cached by another process", name, version);
142                return self.storage.source_path(name, version);
143            }
144        }
145
146        // Create lock file
147        if let Some(parent) = lock_path.parent() {
148            self.storage.ensure_dir(parent)?;
149        }
150        std::fs::write(&lock_path, "downloading").context("Failed to create lock file")?;
151
152        // Ensure lock file is removed on exit
153        let _lock_guard = LockGuard {
154            path: lock_path.clone(),
155        };
156
157        tracing::info!(
158            "Starting fresh download of {}-{} from crates.io",
159            name,
160            version
161        );
162
163        let url = format!("https://crates.io/api/v1/crates/{name}/{version}/download");
164        tracing::debug!("Download URL: {}", url);
165
166        let response = self
167            .client
168            .get(&url)
169            .send()
170            .await
171            .with_context(|| format!("Failed to download {name}-{version}"))?;
172
173        if !response.status().is_success() {
174            bail!(
175                "Failed to download {}-{}: HTTP {}",
176                name,
177                version,
178                response.status()
179            );
180        }
181
182        // Save to a temporary file first - make path unique to avoid concurrent conflicts
183        let temp_file_path = std::env::temp_dir().join(format!(
184            "{name}-{version}-{}-{}.tar.gz",
185            std::process::id(),
186            uuid::Uuid::new_v4().simple()
187        ));
188        let mut temp_file = File::create(&temp_file_path)
189            .with_context(|| format!("Failed to create temporary file for {name}-{version}"))?;
190
191        let mut stream = response.bytes_stream();
192        while let Some(chunk) = stream.next().await {
193            let chunk = chunk.context("Failed to read chunk from download stream")?;
194            temp_file
195                .write_all(&chunk)
196                .context("Failed to write to temporary file")?;
197        }
198
199        // Extract the crate
200        let source_path = self.storage.source_path(name, version)?;
201        self.storage.ensure_dir(&source_path)?;
202
203        let tar_gz = File::open(&temp_file_path).context("Failed to open downloaded file")?;
204        let tar = GzDecoder::new(tar_gz);
205        let mut archive = Archive::new(tar);
206
207        // Extract with proper path handling
208        for entry in archive.entries()? {
209            let mut entry = entry?;
210            let path = entry.path()?;
211
212            // Skip the top-level directory (crate-version/)
213            let components: Vec<_> = path.components().collect();
214            if components.len() > 1 {
215                let relative_path: PathBuf = components[1..].iter().collect();
216
217                // Validate that the path doesn't escape the destination directory
218                // Check for path traversal attempts
219                let has_parent_refs = relative_path
220                    .components()
221                    .any(|c| matches!(c, std::path::Component::ParentDir));
222
223                if has_parent_refs {
224                    tracing::warn!(
225                        "Skipping entry with parent directory reference: {}",
226                        path.display()
227                    );
228                    continue;
229                }
230
231                let dest_path = source_path.join(&relative_path);
232
233                // Additional validation: ensure the destination is within source_path
234                let canonical_source = source_path
235                    .canonicalize()
236                    .unwrap_or_else(|_| source_path.clone());
237
238                if let Ok(canonical_dest) = dest_path.canonicalize() {
239                    if !canonical_dest.starts_with(&canonical_source) {
240                        tracing::warn!(
241                            "Skipping entry that would escape destination: {}",
242                            path.display()
243                        );
244                        continue;
245                    }
246                } else if let Some(parent) = dest_path.parent() {
247                    // For files that don't exist yet, check the parent
248                    if matches!(parent.canonicalize(), Ok(canonical_parent) if !canonical_parent.starts_with(&canonical_source))
249                    {
250                        tracing::warn!(
251                            "Skipping entry with parent outside destination: {}",
252                            path.display()
253                        );
254                        continue;
255                    }
256                }
257
258                if let Some(parent) = dest_path.parent() {
259                    std::fs::create_dir_all(parent)?;
260                }
261
262                entry.unpack(&dest_path)?;
263            }
264        }
265
266        // Clean up temp file
267        std::fs::remove_file(&temp_file_path).ok();
268
269        // Save metadata for the cached crate
270        self.storage.save_metadata(name, version)?;
271
272        tracing::info!("Successfully downloaded and extracted {}-{}", name, version);
273        Ok(source_path)
274    }
275
276    /// Download a crate from GitHub repository
277    async fn download_from_github(
278        &self,
279        name: &str,
280        version: &str,
281        repo_url: &str,
282        repo_path: Option<&str>,
283    ) -> Result<PathBuf> {
284        // Check if already cached
285        if self.storage.is_cached(name, version) {
286            tracing::info!("Crate {}-{} already cached", name, version);
287            return self.storage.source_path(name, version);
288        }
289
290        // Create a lock file to prevent concurrent downloads
291        let crate_path = self.storage.crate_path(name, version)?;
292        let lock_path = crate_path.with_extension("lock");
293
294        // Check if another process is already downloading
295        if lock_path.exists() {
296            tracing::info!(
297                "Another process is downloading {}-{}, waiting...",
298                name,
299                version
300            );
301            // Wait for the other process to finish (simple polling)
302            let start = std::time::Instant::now();
303            while lock_path.exists()
304                && start.elapsed() < std::time::Duration::from_secs(LOCK_TIMEOUT_SECS)
305            {
306                tokio::time::sleep(std::time::Duration::from_millis(LOCK_POLL_INTERVAL_MS)).await;
307            }
308
309            // Check if it was successfully cached by the other process
310            if self.storage.is_cached(name, version) {
311                tracing::info!("Crate {}-{} was cached by another process", name, version);
312                return self.storage.source_path(name, version);
313            }
314        }
315
316        // Create lock file
317        if let Some(parent) = lock_path.parent() {
318            self.storage.ensure_dir(parent)?;
319        }
320        std::fs::write(&lock_path, "downloading").context("Failed to create lock file")?;
321
322        // Ensure lock file is removed on exit
323        let _lock_guard = LockGuard {
324            path: lock_path.clone(),
325        };
326
327        tracing::info!(
328            "Downloading crate {}-{} from GitHub: {}",
329            name,
330            version,
331            repo_url
332        );
333
334        let temp_dir = std::env::temp_dir().join(format!("rust-docs-mcp-git-{name}-{version}"));
335
336        // Clean up any existing temp directory
337        if temp_dir.exists() {
338            fs::remove_dir_all(&temp_dir).context("Failed to clean temp directory")?;
339        }
340
341        // Set up GitHub authentication if token is available
342        let github_token = env::var("GITHUB_TOKEN").ok().map(Zeroizing::new);
343        let has_token = github_token.is_some();
344
345        // Configure git authentication callbacks
346        let mut fetch_options = FetchOptions::new();
347        let mut callbacks = RemoteCallbacks::new();
348
349        if let Some(token) = github_token {
350            tracing::debug!("Using GITHUB_TOKEN for authentication");
351            callbacks.credentials(move |_url, username_from_url, _allowed_types| {
352                Cred::userpass_plaintext(username_from_url.unwrap_or("git"), &token)
353            });
354        } else {
355            tracing::debug!("No GITHUB_TOKEN found, using unauthenticated access");
356        }
357
358        fetch_options.remote_callbacks(callbacks);
359
360        // Clone the repository with authentication
361        let mut builder = git2::build::RepoBuilder::new();
362        builder.fetch_options(fetch_options);
363
364        let repo = builder
365            .clone(repo_url, &temp_dir)
366            .with_context(|| {
367                let mut msg = format!("Failed to clone repository: {repo_url}");
368                if !has_token && repo_url.contains("github.com") {
369                    msg.push_str("\nNote: Set GITHUB_TOKEN environment variable for private repositories and higher rate limits");
370                }
371                msg
372            })?;
373
374        // Checkout the specific branch or tag (version contains the branch/tag name)
375        // The version parameter here is actually the branch or tag name
376        if version != "main" && version != "master" {
377            // Validate git reference name to prevent potential issues
378            if !Self::is_valid_git_ref(version) {
379                bail!("Invalid git reference name: {}", version);
380            }
381
382            // Try to checkout as a branch first
383            let refname = format!("refs/remotes/origin/{version}");
384            if let Ok(reference) = repo.find_reference(&refname) {
385                let oid = reference
386                    .target()
387                    .ok_or_else(|| anyhow::anyhow!("Reference has no target"))?;
388                repo.set_head_detached(oid)
389                    .with_context(|| format!("Failed to checkout branch: {version}"))?;
390                repo.checkout_head(Some(git2::build::CheckoutBuilder::default().force()))
391                    .with_context(|| format!("Failed to checkout branch: {version}"))?;
392            } else {
393                // Try as a tag
394                let tag_ref = format!("refs/tags/{version}");
395                if let Ok(reference) = repo.find_reference(&tag_ref) {
396                    let oid = reference
397                        .target()
398                        .ok_or_else(|| anyhow::anyhow!("Reference has no target"))?;
399                    repo.set_head_detached(oid)
400                        .with_context(|| format!("Failed to checkout tag: {version}"))?;
401                    repo.checkout_head(Some(git2::build::CheckoutBuilder::default().force()))
402                        .with_context(|| format!("Failed to checkout tag: {version}"))?;
403                } else {
404                    bail!("Could not find branch or tag: {}", version);
405                }
406            }
407        }
408
409        // Determine source path within the repository
410        let repo_source_path = if let Some(path) = repo_path {
411            temp_dir.join(path)
412        } else {
413            temp_dir.clone()
414        };
415
416        // Verify Cargo.toml exists
417        let cargo_toml = repo_source_path.join(CARGO_TOML);
418        if !cargo_toml.exists() {
419            bail!(
420                "No Cargo.toml found at path: {}",
421                repo_source_path.display()
422            );
423        }
424
425        // Copy to cache location
426        let source_path = self.storage.source_path(name, version)?;
427        self.storage.ensure_dir(&source_path)?;
428
429        copy_directory_contents(&repo_source_path, &source_path)
430            .context("Failed to copy repository contents")?;
431
432        // Clean up temp directory
433        fs::remove_dir_all(&temp_dir).ok();
434
435        // Save metadata with source information
436        let source_info = match repo_path {
437            Some(path) => format!("{repo_url}#{path}"),
438            None => repo_url.to_string(),
439        };
440        self.storage.save_metadata_with_source(
441            name,
442            version,
443            "github",
444            Some(&source_info),
445            None,
446        )?;
447
448        tracing::info!(
449            "Successfully downloaded and extracted {}-{} from GitHub",
450            name,
451            version
452        );
453        Ok(source_path)
454    }
455
456    /// Copy a crate from local file system
457    async fn copy_from_local(
458        &self,
459        name: &str,
460        version: &str,
461        local_path: &str,
462    ) -> Result<PathBuf> {
463        tracing::info!(
464            "Copying crate {}-{} from local path: {}",
465            name,
466            version,
467            local_path
468        );
469
470        // Expand tilde and other shell expansions
471        let expanded_path = shellexpand::full(local_path)
472            .with_context(|| format!("Failed to expand path: {local_path}"))?;
473        let source_path_input = Path::new(expanded_path.as_ref());
474
475        // Verify the path exists and contains Cargo.toml
476        if !source_path_input.exists() {
477            bail!("Local path does not exist: {}", source_path_input.display());
478        }
479
480        let cargo_toml = source_path_input.join(CARGO_TOML);
481        if !cargo_toml.exists() {
482            bail!(
483                "No Cargo.toml found at path: {}",
484                source_path_input.display()
485            );
486        }
487
488        // Copy to cache location
489        let source_path = self.storage.source_path(name, version)?;
490        self.storage.ensure_dir(&source_path)?;
491
492        copy_directory_contents(source_path_input, &source_path)
493            .context("Failed to copy local directory contents")?;
494
495        // Save metadata with source information
496        self.storage
497            .save_metadata_with_source(name, version, "local", Some(local_path), None)?;
498
499        tracing::info!("Successfully copied {}-{} from local path", name, version);
500        Ok(source_path)
501    }
502
503    /// Validate git reference name to prevent potential issues
504    fn is_valid_git_ref(ref_name: &str) -> bool {
505        // Git references must not:
506        // - Be empty
507        // - Contain ".." (directory traversal)
508        // - Start or end with dots or slashes
509        // - Contain control characters or spaces
510        // - Contain characters that could be problematic in shell contexts
511
512        if ref_name.is_empty() || ref_name.contains("..") {
513            return false;
514        }
515
516        if ref_name.starts_with('.')
517            || ref_name.ends_with('.')
518            || ref_name.starts_with('/')
519            || ref_name.ends_with('/')
520        {
521            return false;
522        }
523
524        // Allow alphanumeric, dots, slashes, hyphens, underscores
525        // Common for tags like "v1.0.0" or branches like "feature/new-thing"
526        ref_name.chars().all(|c| {
527            c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/' || c == '+' // Allow for version tags like "1.0.0+20240621"
528        })
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Install a debug-level tracing subscriber; repeated calls are harmless.
    fn init_test_logging() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter("rust_docs_mcp=debug")
            .try_init();
    }

    /// Build a downloader backed by a fresh temporary cache directory.
    ///
    /// The `TempDir` is returned as well so the directory lives as long as the test needs it.
    fn temp_downloader() -> (TempDir, CrateDownloader) {
        let cache_root = TempDir::new().unwrap();
        let storage = CacheStorage::new(Some(cache_root.path().to_path_buf())).unwrap();
        (cache_root, CrateDownloader::new(storage))
    }

    #[test]
    fn test_downloader_creation() {
        let (_cache_root, downloader) = temp_downloader();

        // Just verify it was created successfully
        let debug_repr = format!("{downloader:?}");
        assert!(debug_repr.contains("CrateDownloader"));
    }

    #[tokio::test]
    async fn test_user_agent_set() {
        init_test_logging();
        let (_cache_root, downloader) = temp_downloader();

        // Test that download doesn't fail with 403
        // Note: This is an integration test that requires internet access
        let outcome = downloader.download_crate("serde", "1.0.0").await;
        if let Err(e) = &outcome {
            // If it fails, it should not be a 403 error
            let error_msg = format!("{e}");
            assert!(!error_msg.contains("403"), "Got 403 error: {error_msg}");
        }
        if let Ok(path) = outcome {
            assert!(path.exists());
            println!("Successfully downloaded crate to: {path:?}");
        }
    }

    #[tokio::test]
    async fn test_problematic_crate_download() {
        init_test_logging();

        // Test downloading the specific crate that was failing
        let (_cache_root, downloader) = temp_downloader();

        let path = downloader
            .download_crate("google-sheets4", "6.0.0+20240621")
            .await
            .unwrap_or_else(|e| panic!("Failed to download google-sheets4: {e}"));

        assert!(path.exists());
        println!("Successfully downloaded google-sheets4-6.0.0+20240621 to: {path:?}");
    }
}