1use serde::{Deserialize, Serialize};
7
/// Version of this crate, captured at compile time from Cargo's `CARGO_PKG_VERSION`.
pub const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");

/// GitHub account that owns the repository whose releases are queried and installed.
const GITHUB_OWNER: &str = "DevJadhav";
/// Name of the GitHub repository whose releases are queried and installed.
const GITHUB_REPO: &str = "Rustant";
14
/// User-facing settings for update checks.
///
/// Each field carries a serde default, so a serialized config that omits any
/// of them still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    /// Defaults to `true` when missing.
    // NOTE(review): presumably tells callers whether to run checks automatically;
    // nothing in this module reads it — confirm against the call sites.
    #[serde(default = "default_true")]
    pub auto_check: bool,
    /// Defaults to `24` when missing.
    // NOTE(review): by its name this is the spacing between automatic checks, in
    // hours; nothing in this module reads it — confirm against the call sites.
    #[serde(default = "default_check_interval")]
    pub check_interval_hours: u64,
    /// Release channel; defaults to `"stable"` when missing.
    ///
    /// `UpdateChecker::check` compares this against `"stable"`: on that channel a
    /// release flagged as a prerelease is never reported as an available update.
    #[serde(default = "default_channel")]
    pub channel: String,
}
28
/// Serde default helper that always yields `true`.
///
/// Named by `#[serde(default = "default_true")]` on `UpdateConfig::auto_check`.
fn default_true() -> bool {
    true
}
32
/// Serde default helper for `UpdateConfig::check_interval_hours`.
///
/// Returns `24`, i.e. one day expressed in the field's unit (hours).
fn default_check_interval() -> u64 {
    const ONE_DAY_IN_HOURS: u64 = 24;
    ONE_DAY_IN_HOURS
}
36
/// Serde default helper for `UpdateConfig::channel`.
///
/// Returns an owned copy of the channel name `"stable"`.
fn default_channel() -> String {
    String::from("stable")
}
40
41impl Default for UpdateConfig {
42 fn default() -> Self {
43 Self {
44 auto_check: true,
45 check_interval_hours: 24,
46 channel: "stable".into(),
47 }
48 }
49}
50
/// Outcome of a single update check, as produced by `UpdateChecker::check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateCheckResult {
    /// Version of the running build; `check` fills this from `CURRENT_VERSION`.
    pub current_version: String,
    /// Version of the release that was found. `check` always sets this to the
    /// release tag with any leading `v` characters removed.
    pub latest_version: Option<String>,
    /// Whether the found release should be offered as an update.
    pub update_available: bool,
    /// Web page of the release; `check` fills this from the release's `html_url`.
    pub release_url: Option<String>,
    /// Release description; `check` stores an empty string when the release
    /// has no body.
    pub release_notes: Option<String>,
}
65
/// Checks GitHub for a newer release, using the stored [`UpdateConfig`].
pub struct UpdateChecker {
    /// Settings supplied at construction; `check` reads `channel` from it.
    config: UpdateConfig,
}
70
71impl UpdateChecker {
72 pub fn new(config: UpdateConfig) -> Self {
74 Self { config }
75 }
76
77 pub async fn check(&self) -> Result<UpdateCheckResult, UpdateError> {
79 let client = reqwest::Client::builder()
80 .user_agent(format!("rustant/{}", CURRENT_VERSION))
81 .build()
82 .map_err(|e| UpdateError::NetworkError(e.to_string()))?;
83
84 let url = format!(
85 "https://api.github.com/repos/{}/{}/releases/latest",
86 GITHUB_OWNER, GITHUB_REPO
87 );
88
89 let response = client
90 .get(&url)
91 .send()
92 .await
93 .map_err(|e| UpdateError::NetworkError(e.to_string()))?;
94
95 if !response.status().is_success() {
96 return Err(UpdateError::NetworkError(format!(
97 "GitHub API returned status {}",
98 response.status()
99 )));
100 }
101
102 let release: GitHubRelease = response
103 .json()
104 .await
105 .map_err(|e| UpdateError::ParseError(e.to_string()))?;
106
107 let latest_version = release.tag_name.trim_start_matches('v').to_string();
108 let update_available = is_newer_version(&latest_version, CURRENT_VERSION);
109
110 if self.config.channel == "stable" && release.prerelease {
112 return Ok(UpdateCheckResult {
113 current_version: CURRENT_VERSION.into(),
114 latest_version: Some(latest_version),
115 update_available: false,
116 release_url: Some(release.html_url),
117 release_notes: Some(release.body.unwrap_or_default()),
118 });
119 }
120
121 Ok(UpdateCheckResult {
122 current_version: CURRENT_VERSION.into(),
123 latest_version: Some(latest_version),
124 update_available,
125 release_url: Some(release.html_url),
126 release_notes: Some(release.body.unwrap_or_default()),
127 })
128 }
129
130 pub fn config(&self) -> &UpdateConfig {
132 &self.config
133 }
134}
135
/// Stateless namespace type for the self-update operation (see `Updater::update`).
pub struct Updater;
138
139impl Updater {
140 pub fn update() -> Result<(), UpdateError> {
142 let status = self_update::backends::github::Update::configure()
143 .repo_owner(GITHUB_OWNER)
144 .repo_name(GITHUB_REPO)
145 .bin_name("rustant")
146 .current_version(CURRENT_VERSION)
147 .show_output(true)
148 .show_download_progress(true)
149 .build()
150 .map_err(|e| UpdateError::UpdateFailed(e.to_string()))?
151 .update()
152 .map_err(|e| UpdateError::UpdateFailed(e.to_string()))?;
153
154 tracing::info!(
155 old_version = CURRENT_VERSION,
156 new_version = %status.version(),
157 "Updated successfully"
158 );
159
160 Ok(())
161 }
162}
163
/// Errors produced while checking for or installing an update.
///
/// Each variant carries the stringified message of the underlying failure.
#[derive(Debug, thiserror::Error)]
pub enum UpdateError {
    /// The HTTP client could not be built, the request could not be sent, or
    /// GitHub answered with a non-success status.
    #[error("Network error: {0}")]
    NetworkError(String),
    /// The release response could not be decoded as JSON.
    #[error("Parse error: {0}")]
    ParseError(String),
    /// Configuring or running the self-update failed.
    #[error("Update failed: {0}")]
    UpdateFailed(String),
}
174
/// The fields of a GitHub release response that `UpdateChecker::check` reads.
#[derive(Debug, Deserialize)]
struct GitHubRelease {
    /// Git tag of the release; `check` strips leading `v` characters from it.
    tag_name: String,
    /// Link reported to callers as the release URL.
    html_url: String,
    /// Release description; optional in the decoded response.
    body: Option<String>,
    /// Prerelease flag; used to hide the release on the `"stable"` channel.
    prerelease: bool,
}
183
/// Returns `true` when `latest` denotes a strictly newer version than `current`.
///
/// Both arguments are read as `MAJOR.MINOR.PATCH` with these allowances:
///
/// * surrounding whitespace and one leading `v` are ignored;
/// * build metadata (everything from the first `+`) is ignored;
/// * missing or non-numeric components count as `0` *in their own position*
///   (they never shift later components to the left);
/// * components after the third are ignored;
/// * when the numeric parts are equal, a version without a pre-release suffix
///   (`-...`) is newer than one with a suffix. Two pre-releases of the same
///   numeric version are treated as equal, so neither is newer.
pub fn is_newer_version(latest: &str, current: &str) -> bool {
    version_key(latest) > version_key(current)
}

/// Converts a version string into an orderable `(numeric parts, is_release)` key.
///
/// `is_release` is `false` for pre-release versions, so for equal numeric
/// parts the tuple ordering ranks a pre-release below the final release.
fn version_key(version: &str) -> ([u32; 3], bool) {
    let trimmed = version.trim();
    let trimmed = trimmed.strip_prefix('v').unwrap_or(trimmed);

    // Build metadata never takes part in precedence.
    let without_build = trimmed.split('+').next().unwrap_or(trimmed);

    let (core, pre_release) = match without_build.split_once('-') {
        Some((core, suffix)) => (core, suffix),
        None => (without_build, ""),
    };

    // Fill positionally: an unparsable component becomes 0 instead of being
    // dropped, which would otherwise move the following components up a slot.
    let mut numbers = [0u32; 3];
    for (slot, piece) in numbers.iter_mut().zip(core.split('.')) {
        *slot = piece.trim().parse().unwrap_or(0);
    }

    (numbers, pre_release.is_empty())
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every `(latest, current)` pair yields `expected` from
    /// `is_newer_version`.
    fn assert_comparisons(pairs: &[(&str, &str)], expected: bool) {
        for &(latest, current) in pairs {
            assert_eq!(
                is_newer_version(latest, current),
                expected,
                "is_newer_version({:?}, {:?})",
                latest,
                current
            );
        }
    }

    // Strictly greater versions in any of the three positions are newer.
    #[test]
    fn test_is_newer_version() {
        let newer = [
            ("1.1.0", "1.0.0"),
            ("2.0.0", "1.9.9"),
            ("0.2.0", "0.1.0"),
            ("0.1.1", "0.1.0"),
        ];
        assert_comparisons(&newer, true);
    }

    // Equal or older versions are never reported as newer.
    #[test]
    fn test_is_not_newer_version() {
        let not_newer = [
            ("1.0.0", "1.0.0"),
            ("0.9.0", "1.0.0"),
            ("0.1.0", "0.2.0"),
        ];
        assert_comparisons(&not_newer, false);
    }

    // `Default` enables checks, uses a 24-hour interval and the stable channel.
    #[test]
    fn test_update_config_defaults() {
        let defaults = UpdateConfig::default();
        assert!(defaults.auto_check);
        assert_eq!(defaults.check_interval_hours, 24);
        assert_eq!(defaults.channel, "stable");
    }

    // Non-default values survive a JSON round trip.
    #[test]
    fn test_update_config_serialization() {
        let original = UpdateConfig {
            auto_check: false,
            check_interval_hours: 12,
            channel: "beta".into(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<UpdateConfig>(&encoded).unwrap();
        assert!(!decoded.auto_check);
        assert_eq!(decoded.check_interval_hours, 12);
        assert_eq!(decoded.channel, "beta");
    }

    // A populated check result survives a JSON round trip.
    #[test]
    fn test_update_check_result_serialization() {
        let original = UpdateCheckResult {
            current_version: "0.1.0".into(),
            latest_version: Some("0.2.0".into()),
            update_available: true,
            release_url: Some("https://github.com/DevJadhav/Rustant/releases/v0.2.0".into()),
            release_notes: Some("New features".into()),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<UpdateCheckResult>(&encoded).unwrap();
        assert!(decoded.update_available);
        assert_eq!(decoded.latest_version, Some("0.2.0".into()));
    }

    // The compile-time version constant is populated.
    #[test]
    fn test_current_version_defined() {
        assert!(!CURRENT_VERSION.is_empty());
    }

    // A checker exposes the configuration it was built with.
    #[test]
    fn test_update_checker_creation() {
        let checker = UpdateChecker::new(UpdateConfig::default());
        assert!(checker.config().auto_check);
    }

    // All-zero versions, single patch bumps, and multi-digit components.
    #[test]
    fn test_version_comparison_edge_cases() {
        assert_comparisons(&[("0.0.0", "0.0.0")], false);
        assert_comparisons(&[("0.0.1", "0.0.0"), ("10.0.0", "9.9.9")], true);
    }
}