1use std::fs;
47use std::path::PathBuf;
48use std::time::{Duration, SystemTime};
49
/// Describes an available update: the version the caller is running and the
/// newer version reported by crates.io.
///
/// Only produced by [`UpdateChecker::check`] when the published version is
/// strictly greater (by semver ordering) than the current one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpdateInfo {
    /// The running version, exactly as passed to [`UpdateChecker::new`].
    pub current: String,
    /// The newer version string obtained from crates.io (or the local cache).
    pub latest: String,
}
58
/// Errors that can occur while checking for an update.
///
/// Each variant carries a human-readable description of the failure.
#[derive(Debug)]
pub enum Error {
    /// The request to crates.io failed, or its response body could not be read.
    HttpError(String),
    /// The crates.io response did not contain the expected `newest_version` field.
    ParseError(String),
    /// The current or latest version string is not valid semver.
    VersionError(String),
    /// Cache failure. NOTE(review): not constructed anywhere in this file;
    /// cache I/O errors are currently ignored — confirm whether this is still used.
    CacheError(String),
}
71
72impl std::fmt::Display for Error {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 Self::HttpError(msg) => write!(f, "HTTP error: {msg}"),
76 Self::ParseError(msg) => write!(f, "Parse error: {msg}"),
77 Self::VersionError(msg) => write!(f, "Version error: {msg}"),
78 Self::CacheError(msg) => write!(f, "Cache error: {msg}"),
79 }
80 }
81}
82
83impl std::error::Error for Error {}
84
/// Configurable checker that compares a crate's running version against the
/// version published on crates.io, optionally caching the lookup on disk.
///
/// Construct with [`UpdateChecker::new`] and adjust with the builder methods
/// (`cache_duration`, `timeout`, `cache_dir`) before calling `check`.
#[derive(Debug, Clone)]
pub struct UpdateChecker {
    /// Crate name used both in the crates.io API URL and in the cache file name.
    crate_name: String,
    /// The running version; must be valid semver for `check` to succeed.
    current_version: String,
    /// Maximum age of a cached result; `Duration::ZERO` skips reading the cache.
    cache_duration: Duration,
    /// Global timeout applied to the crates.io HTTP request.
    timeout: Duration,
    /// Directory holding the cache file; `None` disables caching entirely.
    cache_dir: Option<PathBuf>,
}
107
108impl UpdateChecker {
109 #[must_use]
116 pub fn new(crate_name: impl Into<String>, current_version: impl Into<String>) -> Self {
117 Self {
118 crate_name: crate_name.into(),
119 current_version: current_version.into(),
120 cache_duration: Duration::from_secs(24 * 60 * 60), timeout: Duration::from_secs(5),
122 cache_dir: dirs::cache_dir(),
123 }
124 }
125
126 #[must_use]
130 pub const fn cache_duration(mut self, duration: Duration) -> Self {
131 self.cache_duration = duration;
132 self
133 }
134
135 #[must_use]
137 pub const fn timeout(mut self, timeout: Duration) -> Self {
138 self.timeout = timeout;
139 self
140 }
141
142 #[must_use]
146 pub fn cache_dir(mut self, dir: Option<PathBuf>) -> Self {
147 self.cache_dir = dir;
148 self
149 }
150
151 pub fn check(&self) -> Result<Option<UpdateInfo>, Error> {
162 let latest = self.get_latest_version()?;
163
164 let current = semver::Version::parse(&self.current_version)
165 .map_err(|e| Error::VersionError(format!("Invalid current version: {e}")))?;
166 let latest_ver = semver::Version::parse(&latest)
167 .map_err(|e| Error::VersionError(format!("Invalid latest version: {e}")))?;
168
169 if latest_ver > current {
170 Ok(Some(UpdateInfo {
171 current: self.current_version.clone(),
172 latest,
173 }))
174 } else {
175 Ok(None)
176 }
177 }
178
179 fn get_latest_version(&self) -> Result<String, Error> {
181 let cache_path = self.cache_path();
182
183 if self.cache_duration > Duration::ZERO {
185 if let Some(ref path) = cache_path {
186 if let Some(cached) = self.read_cache(path) {
187 return Ok(cached);
188 }
189 }
190 }
191
192 let latest = self.fetch_latest_version()?;
194
195 if let Some(ref path) = cache_path {
197 let _ = fs::write(path, &latest);
198 }
199
200 Ok(latest)
201 }
202
203 fn cache_path(&self) -> Option<PathBuf> {
205 self.cache_dir
206 .as_ref()
207 .map(|d| d.join(format!("{}-update-check", self.crate_name)))
208 }
209
210 fn read_cache(&self, path: &PathBuf) -> Option<String> {
212 let metadata = fs::metadata(path).ok()?;
213 let modified = metadata.modified().ok()?;
214 let age = SystemTime::now().duration_since(modified).ok()?;
215
216 if age < self.cache_duration {
217 fs::read_to_string(path).ok().map(|s| s.trim().to_string())
218 } else {
219 None
220 }
221 }
222
223 fn fetch_latest_version(&self) -> Result<String, Error> {
225 let url = format!("https://crates.io/api/v1/crates/{}", self.crate_name);
226
227 let tls_config = build_tls_config();
228 let config = ureq::Agent::config_builder()
229 .timeout_global(Some(self.timeout))
230 .user_agent(concat!(
231 env!("CARGO_PKG_NAME"),
232 "/",
233 env!("CARGO_PKG_VERSION")
234 ))
235 .tls_config(tls_config)
236 .build();
237 let agent: ureq::Agent = config.into();
238
239 let body = agent
240 .get(&url)
241 .call()
242 .map_err(|e| Error::HttpError(e.to_string()))?
243 .into_body()
244 .read_to_string()
245 .map_err(|e| Error::HttpError(e.to_string()))?;
246
247 let marker = r#""newest_version":""#;
249 let start = body
250 .find(marker)
251 .ok_or_else(|| Error::ParseError("newest_version not found".to_string()))?
252 + marker.len();
253 let end = body[start..]
254 .find('"')
255 .ok_or_else(|| Error::ParseError("version end quote not found".to_string()))?
256 + start;
257
258 Ok(body[start..end].to_string())
259 }
260}
261
262fn build_tls_config() -> ureq::tls::TlsConfig {
264 #[cfg(feature = "native-tls")]
265 {
266 ureq::tls::TlsConfig::builder()
267 .provider(ureq::tls::TlsProvider::NativeTls)
268 .build()
269 }
270
271 #[cfg(all(feature = "rustls", not(feature = "native-tls")))]
272 {
273 ureq::tls::TlsConfig::builder()
274 .provider(ureq::tls::TlsProvider::Rustls)
275 .build()
276 }
277
278 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
279 {
280 compile_error!("Either 'native-tls' or 'rustls' feature must be enabled");
281 }
282}
283
284pub fn check(
298 crate_name: impl Into<String>,
299 current_version: impl Into<String>,
300) -> Result<Option<UpdateInfo>, Error> {
301 UpdateChecker::new(crate_name, current_version).check()
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// `UpdateInfo` is plain data; its fields hold what they were given.
    /// (Renamed from `test_update_info_display`: there is no `Display` impl.)
    #[test]
    fn test_update_info_fields() {
        let info = UpdateInfo {
            current: "1.0.0".to_string(),
            latest: "2.0.0".to_string(),
        };
        assert_eq!(info.current, "1.0.0");
        assert_eq!(info.latest, "2.0.0");
    }

    /// The constructor applies the documented defaults.
    #[test]
    fn test_checker_defaults() {
        let checker = UpdateChecker::new("test-crate", "1.0.0");
        assert_eq!(checker.cache_duration, Duration::from_secs(24 * 60 * 60));
        assert_eq!(checker.timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_checker_builder() {
        let checker = UpdateChecker::new("test-crate", "1.0.0")
            .cache_duration(Duration::from_secs(3600))
            .timeout(Duration::from_secs(10));

        assert_eq!(checker.crate_name, "test-crate");
        assert_eq!(checker.current_version, "1.0.0");
        assert_eq!(checker.cache_duration, Duration::from_secs(3600));
        assert_eq!(checker.timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_cache_disabled() {
        let checker = UpdateChecker::new("test-crate", "1.0.0")
            .cache_duration(Duration::ZERO)
            .cache_dir(None);

        assert_eq!(checker.cache_duration, Duration::ZERO);
        assert!(checker.cache_dir.is_none());
    }

    /// The cache file is named after the crate inside the cache directory,
    /// and there is no cache path when the directory is unset.
    #[test]
    fn test_cache_path() {
        let checker =
            UpdateChecker::new("test-crate", "1.0.0").cache_dir(Some(PathBuf::from("cache")));
        assert_eq!(
            checker.cache_path(),
            Some(PathBuf::from("cache").join("test-crate-update-check"))
        );
        assert_eq!(checker.cache_dir(None).cache_path(), None);
    }

    #[test]
    fn test_error_display() {
        let err = Error::HttpError("connection failed".to_string());
        assert_eq!(err.to_string(), "HTTP error: connection failed");

        let err = Error::ParseError("invalid json".to_string());
        assert_eq!(err.to_string(), "Parse error: invalid json");

        let err = Error::VersionError("bad semver".to_string());
        assert_eq!(err.to_string(), "Version error: bad semver");

        let err = Error::CacheError("unwritable".to_string());
        assert_eq!(err.to_string(), "Cache error: unwritable");
    }
}