trojan_rules/provider/
http.rs1use std::path::{Path, PathBuf};
7use std::time::Duration;
8
9use crate::error::RulesError;
10use crate::rule::ParsedRule;
11
/// Request timeout applied by [`HttpProvider::new`] until overridden through
/// [`HttpProvider::with_timeout`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Rule-set provider that downloads rules over HTTP(S), optionally mirroring the
/// downloaded content into a local cache file.
#[derive(Debug)]
pub struct HttpProvider {
    /// Remote location the rule-set is requested from.
    url: String,
    /// Local file used for caching; `None` disables caching entirely.
    cache_path: Option<PathBuf>,
    /// Rule-set format name handed to the parser (e.g. `"surge"`, `"clash"`).
    format: String,
    /// Optional behavior hint handed to the parser together with `format`.
    behavior: Option<String>,
    /// Timeout applied to each HTTP request.
    timeout: Duration,
}

25impl HttpProvider {
26 pub fn new(
33 url: impl Into<String>,
34 cache_path: Option<PathBuf>,
35 format: impl Into<String>,
36 behavior: Option<String>,
37 ) -> Self {
38 Self {
39 url: url.into(),
40 cache_path,
41 format: format.into(),
42 behavior,
43 timeout: DEFAULT_TIMEOUT,
44 }
45 }
46
47 pub fn with_timeout(mut self, timeout: Duration) -> Self {
49 self.timeout = timeout;
50 self
51 }
52
53 pub async fn fetch(&self) -> Result<String, RulesError> {
55 tracing::debug!(url = %self.url, "fetching remote rule-set");
56
57 let client = reqwest::Client::builder()
58 .timeout(self.timeout)
59 .build()
60 .map_err(|e| RulesError::Http(format!("failed to build HTTP client: {e}")))?;
61
62 let response =
63 client.get(&self.url).send().await.map_err(|e| {
64 RulesError::Http(format!("HTTP request failed for {}: {e}", self.url))
65 })?;
66
67 let status = response.status();
68 if !status.is_success() {
69 return Err(RulesError::Http(format!(
70 "HTTP {} for {}",
71 status, self.url
72 )));
73 }
74
75 let content = response
76 .text()
77 .await
78 .map_err(|e| RulesError::Http(format!("failed to read response body: {e}")))?;
79
80 tracing::debug!(url = %self.url, bytes = content.len(), "fetched remote rule-set");
81
82 if let Some(ref cache_path) = self.cache_path
84 && let Err(e) = write_cache(cache_path, &content).await
85 {
86 tracing::warn!(path = %cache_path.display(), error = %e, "failed to write cache");
87 }
88
89 Ok(content)
90 }
91
92 pub async fn load(&self) -> Result<Vec<ParsedRule>, RulesError> {
94 match self.fetch().await {
95 Ok(content) => self.parse(&content),
96 Err(fetch_err) => {
97 if let Some(ref cache_path) = self.cache_path
99 && cache_path.exists()
100 {
101 tracing::warn!(
102 url = %self.url,
103 error = %fetch_err,
104 cache = %cache_path.display(),
105 "fetch failed, using cached rules"
106 );
107 let content = tokio::fs::read_to_string(cache_path)
108 .await
109 .map_err(RulesError::Io)?;
110 return self.parse(&content);
111 }
112 Err(fetch_err)
113 }
114 }
115 }
116
117 pub fn load_cached(&self) -> Result<Option<Vec<ParsedRule>>, RulesError> {
119 match &self.cache_path {
120 Some(path) if path.exists() => {
121 let content = std::fs::read_to_string(path)?;
122 Ok(Some(self.parse(&content)?))
123 }
124 _ => Ok(None),
125 }
126 }
127
128 fn parse(&self, content: &str) -> Result<Vec<ParsedRule>, RulesError> {
130 crate::provider::FileProvider::parse(content, &self.format, self.behavior.as_deref())
131 }
132
133 pub fn url(&self) -> &str {
135 &self.url
136 }
137
138 pub fn cache_path(&self) -> Option<&Path> {
140 self.cache_path.as_deref()
141 }
142}
143
144async fn write_cache(path: &Path, content: &str) -> Result<(), std::io::Error> {
150 if let Some(parent) = path.parent() {
151 tokio::fs::create_dir_all(parent).await?;
152 }
153 let tmp_path = path.with_extension("tmp");
154 tokio::fs::write(&tmp_path, content).await?;
155 #[cfg(target_os = "windows")]
157 {
158 let _ = tokio::fs::remove_file(path).await;
159 }
160 tokio::fs::rename(&tmp_path, path).await
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider for a dummy URL with no cache path.
    fn uncached(format: &str, behavior: Option<&str>) -> HttpProvider {
        HttpProvider::new(
            "http://example.com",
            None,
            format,
            behavior.map(str::to_owned),
        )
    }

    #[test]
    fn http_provider_new() {
        let url = "https://example.com/rules.txt";
        let cache = PathBuf::from("/tmp/cache.txt");
        let provider = HttpProvider::new(url, Some(cache), "surge", Some("domain-set".into()));

        assert_eq!(provider.url(), url);
        assert_eq!(provider.cache_path(), Some(Path::new("/tmp/cache.txt")));
    }

    #[test]
    fn http_provider_parse_surge() {
        let surge_text = "DOMAIN,example.com\nDOMAIN-SUFFIX,test.com";
        let parsed = uncached("surge", None)
            .parse(surge_text)
            .expect("surge rules should parse");
        assert_eq!(parsed.len(), 2);
    }

    #[test]
    fn http_provider_parse_clash() {
        let yaml = "payload:\n - 'example.com'\n - '+.test.com'";
        let parsed = uncached("clash", Some("domain"))
            .parse(yaml)
            .expect("clash rules should parse");
        assert_eq!(parsed.len(), 2);
    }

    #[test]
    fn load_cached_no_path() {
        let cached = uncached("surge", None).load_cached();
        assert!(matches!(cached, Ok(None)));
    }

    #[test]
    fn load_cached_nonexistent_path() {
        let provider = HttpProvider::new(
            "http://example.com",
            Some(PathBuf::from("/nonexistent/path/rules.txt")),
            "surge",
            None,
        );
        assert!(matches!(provider.load_cached(), Ok(None)));
    }
}