Skip to main content

trojan_rules/provider/
http.rs

//! HTTP-based rule-set provider with local file caching.
//!
//! Fetches rule-sets from remote URLs and caches them to the local filesystem.
//! On fetch failure, falls back to the cached version if available.
5
6use std::path::{Path, PathBuf};
7use std::time::Duration;
8
9use crate::error::RulesError;
10use crate::rule::ParsedRule;
11
/// Request timeout used by [`HttpProvider::new`] until it is overridden with
/// [`HttpProvider::with_timeout`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Rule-set provider backed by an HTTP/HTTPS URL, with an optional local
/// cache file that is used when the remote fetch fails.
#[derive(Debug)]
pub struct HttpProvider {
    /// Remote location the rule-set is downloaded from.
    url: String,
    /// Where fetched content is cached on disk; `None` disables caching.
    cache_path: Option<PathBuf>,
    /// Rule-set format name handed to the parser ("surge" or "clash").
    format: String,
    /// Optional behavior hint handed to the parser
    /// ("domain", "ipcidr", "classical", "domain-set").
    behavior: Option<String>,
    /// Timeout applied to the HTTP client built for each fetch.
    timeout: Duration,
}
24
25impl HttpProvider {
26    /// Create a new HTTP provider.
27    ///
28    /// - `url`: Remote URL to fetch the rule-set from.
29    /// - `cache_path`: Optional local path to cache the fetched content.
30    /// - `format`: Rule-set format ("surge" or "clash").
31    /// - `behavior`: Optional behavior hint ("domain", "ipcidr", "classical", "domain-set").
32    pub fn new(
33        url: impl Into<String>,
34        cache_path: Option<PathBuf>,
35        format: impl Into<String>,
36        behavior: Option<String>,
37    ) -> Self {
38        Self {
39            url: url.into(),
40            cache_path,
41            format: format.into(),
42            behavior,
43            timeout: DEFAULT_TIMEOUT,
44        }
45    }
46
47    /// Set the HTTP request timeout.
48    pub fn with_timeout(mut self, timeout: Duration) -> Self {
49        self.timeout = timeout;
50        self
51    }
52
53    /// Fetch the rule-set content from the remote URL.
54    pub async fn fetch(&self) -> Result<String, RulesError> {
55        tracing::debug!(url = %self.url, "fetching remote rule-set");
56
57        let client = reqwest::Client::builder()
58            .timeout(self.timeout)
59            .build()
60            .map_err(|e| RulesError::Http(format!("failed to build HTTP client: {e}")))?;
61
62        let response =
63            client.get(&self.url).send().await.map_err(|e| {
64                RulesError::Http(format!("HTTP request failed for {}: {e}", self.url))
65            })?;
66
67        let status = response.status();
68        if !status.is_success() {
69            return Err(RulesError::Http(format!(
70                "HTTP {} for {}",
71                status, self.url
72            )));
73        }
74
75        let content = response
76            .text()
77            .await
78            .map_err(|e| RulesError::Http(format!("failed to read response body: {e}")))?;
79
80        tracing::debug!(url = %self.url, bytes = content.len(), "fetched remote rule-set");
81
82        // Update cache if path is configured
83        if let Some(ref cache_path) = self.cache_path
84            && let Err(e) = write_cache(cache_path, &content).await
85        {
86            tracing::warn!(path = %cache_path.display(), error = %e, "failed to write cache");
87        }
88
89        Ok(content)
90    }
91
92    /// Load the rule-set: try fetching from URL, fall back to cache on failure.
93    pub async fn load(&self) -> Result<Vec<ParsedRule>, RulesError> {
94        match self.fetch().await {
95            Ok(content) => self.parse(&content),
96            Err(fetch_err) => {
97                // Try cache fallback
98                if let Some(ref cache_path) = self.cache_path
99                    && cache_path.exists()
100                {
101                    tracing::warn!(
102                        url = %self.url,
103                        error = %fetch_err,
104                        cache = %cache_path.display(),
105                        "fetch failed, using cached rules"
106                    );
107                    let content = tokio::fs::read_to_string(cache_path)
108                        .await
109                        .map_err(RulesError::Io)?;
110                    return self.parse(&content);
111                }
112                Err(fetch_err)
113            }
114        }
115    }
116
117    /// Load from cache only (for startup before first fetch).
118    pub fn load_cached(&self) -> Result<Option<Vec<ParsedRule>>, RulesError> {
119        match &self.cache_path {
120            Some(path) if path.exists() => {
121                let content = std::fs::read_to_string(path)?;
122                Ok(Some(self.parse(&content)?))
123            }
124            _ => Ok(None),
125        }
126    }
127
128    /// Parse content using the configured format and behavior.
129    fn parse(&self, content: &str) -> Result<Vec<ParsedRule>, RulesError> {
130        crate::provider::FileProvider::parse(content, &self.format, self.behavior.as_deref())
131    }
132
133    /// Get the URL of this provider.
134    pub fn url(&self) -> &str {
135        &self.url
136    }
137
138    /// Get the cache path of this provider.
139    pub fn cache_path(&self) -> Option<&Path> {
140        self.cache_path.as_deref()
141    }
142}
143
144/// Write content to a cache file atomically (write-to-temp + rename).
145///
146/// This prevents truncated cache files if the process is killed mid-write.
147/// On Windows, the destination is removed first since `rename` fails when
148/// the target already exists.
149async fn write_cache(path: &Path, content: &str) -> Result<(), std::io::Error> {
150    if let Some(parent) = path.parent() {
151        tokio::fs::create_dir_all(parent).await?;
152    }
153    let tmp_path = path.with_extension("tmp");
154    tokio::fs::write(&tmp_path, content).await?;
155    // On Windows, rename fails if the destination exists; remove it first.
156    #[cfg(target_os = "windows")]
157    {
158        let _ = tokio::fs::remove_file(path).await;
159    }
160    tokio::fs::rename(&tmp_path, path).await
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the URL and cache path as given.
    #[test]
    fn http_provider_new() {
        let cache_file = PathBuf::from("/tmp/cache.txt");
        let provider = HttpProvider::new(
            "https://example.com/rules.txt",
            Some(cache_file),
            "surge",
            Some(String::from("domain-set")),
        );

        assert_eq!(provider.url(), "https://example.com/rules.txt");
        assert_eq!(provider.cache_path(), Some(Path::new("/tmp/cache.txt")));
    }

    /// Two surge-format lines yield two parsed rules.
    #[test]
    fn http_provider_parse_surge() {
        let provider = HttpProvider::new("http://example.com", None, "surge", None);
        let input = "DOMAIN,example.com\nDOMAIN-SUFFIX,test.com";

        let parsed = provider.parse(input).unwrap();

        assert_eq!(parsed.len(), 2);
    }

    /// A clash payload with two entries yields two parsed rules.
    #[test]
    fn http_provider_parse_clash() {
        let behavior = Some(String::from("domain"));
        let provider = HttpProvider::new("http://example.com", None, "clash", behavior);
        let input = "payload:\n  - 'example.com'\n  - '+.test.com'";

        let parsed = provider.parse(input).unwrap();

        assert_eq!(parsed.len(), 2);
    }

    /// Without a configured cache path there is nothing to load.
    #[test]
    fn load_cached_no_path() {
        let provider = HttpProvider::new("http://example.com", None, "surge", None);

        let cached = provider.load_cached().unwrap();

        assert!(cached.is_none());
    }

    /// A cache path that does not exist is reported as "no cache", not an error.
    #[test]
    fn load_cached_nonexistent_path() {
        let missing = PathBuf::from("/nonexistent/path/rules.txt");
        let provider = HttpProvider::new("http://example.com", Some(missing), "surge", None);

        let cached = provider.load_cached().unwrap();

        assert!(cached.is_none());
    }
}