Skip to main content

sr_core/
github.rs

1use crate::error::ReleaseError;
2use crate::release::VcsProvider;
3
/// A [`VcsProvider`] backed by the GitHub REST API.
///
/// Works against github.com as well as GitHub Enterprise Server hosts; `hostname`
/// selects which API root requests go to. Every request carries `token` as a
/// bearer token in the `Authorization` header.
pub struct GitHubProvider {
    /// Repository owner (user or organization), first path segment of the repo URL.
    owner: String,
    /// Repository name, second path segment of the repo URL.
    repo: String,
    /// Web host of the GitHub instance, e.g. `github.com` or a GHES hostname.
    hostname: String,
    /// API token sent as `Authorization: Bearer <token>`.
    token: String,
}
11
12#[derive(serde::Deserialize)]
13struct ReleaseResponse {
14    id: u64,
15    html_url: String,
16    upload_url: String,
17    #[serde(default)]
18    assets: Vec<ReleaseAsset>,
19}
20
21#[derive(serde::Deserialize, Debug, Clone)]
22struct ReleaseAsset {
23    #[allow(dead_code)]
24    id: u64,
25    name: String,
26    url: String,
27}
28
29impl GitHubProvider {
30    pub fn new(owner: String, repo: String, hostname: String, token: String) -> Self {
31        Self {
32            owner,
33            repo,
34            hostname,
35            token,
36        }
37    }
38
39    fn base_url(&self) -> String {
40        format!("https://{}/{}/{}", self.hostname, self.owner, self.repo)
41    }
42
43    fn api_url(&self) -> String {
44        if self.hostname == "github.com" {
45            "https://api.github.com".to_string()
46        } else {
47            format!("https://{}/api/v3", self.hostname)
48        }
49    }
50
51    fn agent(&self) -> ureq::Agent {
52        ureq::Agent::new_with_config(ureq::config::Config::builder().https_only(true).build())
53    }
54
55    fn get_release_by_tag(&self, tag: &str) -> Result<ReleaseResponse, ReleaseError> {
56        let url = format!(
57            "{}/repos/{}/{}/releases/tags/{tag}",
58            self.api_url(),
59            self.owner,
60            self.repo
61        );
62        let resp = self
63            .agent()
64            .get(&url)
65            .header("Authorization", &format!("Bearer {}", self.token))
66            .header("Accept", "application/vnd.github+json")
67            .header("X-GitHub-Api-Version", "2022-11-28")
68            .header("User-Agent", "sr-github")
69            .call()
70            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
71        let release: ReleaseResponse = resp
72            .into_body()
73            .read_json()
74            .map_err(|e| ReleaseError::Vcs(format!("failed to parse release response: {e}")))?;
75        Ok(release)
76    }
77}
78
79// ---------------------------------------------------------------------------
80// PR support
81// ---------------------------------------------------------------------------
82
83#[derive(Debug, serde::Deserialize)]
84pub struct PrMetadata {
85    pub number: u64,
86    pub title: String,
87    pub body: Option<String>,
88    pub user: PrUser,
89    pub base: PrRef,
90    pub head: PrRef,
91}
92
93#[derive(Debug, serde::Deserialize)]
94pub struct PrUser {
95    pub login: String,
96}
97
98#[derive(Debug, serde::Deserialize)]
99pub struct PrRef {
100    #[serde(rename = "ref")]
101    pub ref_name: String,
102}
103
104impl GitHubProvider {
105    /// Find the open PR number for a given branch.
106    pub fn get_pr_for_branch(&self, branch: &str) -> Result<PrMetadata, ReleaseError> {
107        let url = format!(
108            "{}/repos/{}/{}/pulls?head={}:{}&state=open&per_page=1",
109            self.api_url(),
110            self.owner,
111            self.repo,
112            self.owner,
113            branch
114        );
115        let resp = self
116            .agent()
117            .get(&url)
118            .header("Authorization", &format!("Bearer {}", self.token))
119            .header("Accept", "application/vnd.github+json")
120            .header("X-GitHub-Api-Version", "2022-11-28")
121            .header("User-Agent", "sr-github")
122            .call()
123            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
124        let prs: Vec<PrMetadata> = resp
125            .into_body()
126            .read_json()
127            .map_err(|e| ReleaseError::Vcs(format!("failed to parse PR list: {e}")))?;
128        prs.into_iter()
129            .next()
130            .ok_or_else(|| ReleaseError::Vcs(format!("no open PR found for branch '{branch}'")))
131    }
132
133    /// Fetch the diff for a PR.
134    pub fn get_pr_diff(&self, pr_number: u64) -> Result<String, ReleaseError> {
135        let url = format!(
136            "{}/repos/{}/{}/pulls/{pr_number}",
137            self.api_url(),
138            self.owner,
139            self.repo
140        );
141        let resp = self
142            .agent()
143            .get(&url)
144            .header("Authorization", &format!("Bearer {}", self.token))
145            .header("Accept", "application/vnd.github.v3.diff")
146            .header("X-GitHub-Api-Version", "2022-11-28")
147            .header("User-Agent", "sr-github")
148            .call()
149            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
150        resp.into_body()
151            .read_to_string()
152            .map_err(|e| ReleaseError::Vcs(format!("failed to read PR diff: {e}")))
153    }
154
155    /// Count open PRs.
156    pub fn count_open_prs(&self) -> Result<(usize, usize), ReleaseError> {
157        // Returns (ready_count, draft_count)
158        let url = format!(
159            "{}/repos/{}/{}/pulls?state=open&per_page=100",
160            self.api_url(),
161            self.owner,
162            self.repo
163        );
164        let resp = self
165            .agent()
166            .get(&url)
167            .header("Authorization", &format!("Bearer {}", self.token))
168            .header("Accept", "application/vnd.github+json")
169            .header("X-GitHub-Api-Version", "2022-11-28")
170            .header("User-Agent", "sr-github")
171            .call()
172            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
173
174        #[derive(serde::Deserialize)]
175        struct MinimalPr {
176            draft: Option<bool>,
177        }
178
179        let prs: Vec<MinimalPr> = resp
180            .into_body()
181            .read_json()
182            .map_err(|e| ReleaseError::Vcs(format!("failed to parse PR list: {e}")))?;
183
184        let draft_count = prs.iter().filter(|p| p.draft == Some(true)).count();
185        let ready_count = prs.len() - draft_count;
186        Ok((ready_count, draft_count))
187    }
188
189    /// Post a review summary on a PR.
190    pub fn post_pr_review(&self, pr_number: u64, body: &str) -> Result<(), ReleaseError> {
191        let url = format!(
192            "{}/repos/{}/{}/pulls/{pr_number}/reviews",
193            self.api_url(),
194            self.owner,
195            self.repo
196        );
197        let payload = serde_json::json!({
198            "body": body,
199            "event": "COMMENT",
200        });
201        self.agent()
202            .post(&url)
203            .header("Authorization", &format!("Bearer {}", self.token))
204            .header("Accept", "application/vnd.github+json")
205            .header("X-GitHub-Api-Version", "2022-11-28")
206            .header("User-Agent", "sr-github")
207            .send_json(&payload)
208            .map_err(|e| ReleaseError::Vcs(format!("GitHub API POST {url}: {e}")))?;
209        Ok(())
210    }
211}
212
213impl VcsProvider for GitHubProvider {
214    fn create_release(
215        &self,
216        tag: &str,
217        name: &str,
218        body: &str,
219        prerelease: bool,
220        draft: bool,
221    ) -> Result<String, ReleaseError> {
222        let url = format!(
223            "{}/repos/{}/{}/releases",
224            self.api_url(),
225            self.owner,
226            self.repo
227        );
228        let payload = serde_json::json!({
229            "tag_name": tag,
230            "name": name,
231            "body": body,
232            "prerelease": prerelease,
233            "draft": draft,
234        });
235
236        let resp = self
237            .agent()
238            .post(&url)
239            .header("Authorization", &format!("Bearer {}", self.token))
240            .header("Accept", "application/vnd.github+json")
241            .header("X-GitHub-Api-Version", "2022-11-28")
242            .header("User-Agent", "sr-github")
243            .send_json(&payload)
244            .map_err(|e| ReleaseError::Vcs(format!("GitHub API POST {url}: {e}")))?;
245
246        let release: ReleaseResponse = resp
247            .into_body()
248            .read_json()
249            .map_err(|e| ReleaseError::Vcs(format!("failed to parse release response: {e}")))?;
250
251        Ok(release.html_url)
252    }
253
254    fn compare_url(&self, base: &str, head: &str) -> Result<String, ReleaseError> {
255        Ok(format!("{}/compare/{base}...{head}", self.base_url()))
256    }
257
258    fn release_exists(&self, tag: &str) -> Result<bool, ReleaseError> {
259        let url = format!(
260            "{}/repos/{}/{}/releases/tags/{tag}",
261            self.api_url(),
262            self.owner,
263            self.repo
264        );
265        match self
266            .agent()
267            .get(&url)
268            .header("Authorization", &format!("Bearer {}", self.token))
269            .header("Accept", "application/vnd.github+json")
270            .header("X-GitHub-Api-Version", "2022-11-28")
271            .header("User-Agent", "sr-github")
272            .call()
273        {
274            Ok(_) => Ok(true),
275            Err(ureq::Error::StatusCode(404)) => Ok(false),
276            Err(e) => Err(ReleaseError::Vcs(format!("GitHub API GET {url}: {e}"))),
277        }
278    }
279
280    fn repo_url(&self) -> Option<String> {
281        Some(self.base_url())
282    }
283
284    fn delete_release(&self, tag: &str) -> Result<(), ReleaseError> {
285        let release = self.get_release_by_tag(tag)?;
286        let url = format!(
287            "{}/repos/{}/{}/releases/{}",
288            self.api_url(),
289            self.owner,
290            self.repo,
291            release.id
292        );
293        self.agent()
294            .delete(&url)
295            .header("Authorization", &format!("Bearer {}", self.token))
296            .header("Accept", "application/vnd.github+json")
297            .header("X-GitHub-Api-Version", "2022-11-28")
298            .header("User-Agent", "sr-github")
299            .call()
300            .map_err(|e| ReleaseError::Vcs(format!("GitHub API DELETE {url}: {e}")))?;
301        Ok(())
302    }
303
304    fn update_release(
305        &self,
306        tag: &str,
307        name: &str,
308        body: &str,
309        prerelease: bool,
310        draft: bool,
311    ) -> Result<String, ReleaseError> {
312        let release = self.get_release_by_tag(tag)?;
313        let url = format!(
314            "{}/repos/{}/{}/releases/{}",
315            self.api_url(),
316            self.owner,
317            self.repo,
318            release.id
319        );
320        let payload = serde_json::json!({
321            "name": name,
322            "body": body,
323            "prerelease": prerelease,
324            "draft": draft,
325        });
326        let resp = self
327            .agent()
328            .patch(&url)
329            .header("Authorization", &format!("Bearer {}", self.token))
330            .header("Accept", "application/vnd.github+json")
331            .header("X-GitHub-Api-Version", "2022-11-28")
332            .header("User-Agent", "sr-github")
333            .send_json(&payload)
334            .map_err(|e| ReleaseError::Vcs(format!("GitHub API PATCH {url}: {e}")))?;
335        let updated: ReleaseResponse = resp
336            .into_body()
337            .read_json()
338            .map_err(|e| ReleaseError::Vcs(format!("failed to parse release response: {e}")))?;
339        Ok(updated.html_url)
340    }
341
342    fn upload_assets(&self, tag: &str, files: &[&str]) -> Result<(), ReleaseError> {
343        let release = self.get_release_by_tag(tag)?;
344        // The upload_url from the API looks like:
345        //   https://uploads.github.com/repos/owner/repo/releases/123/assets{?name,label}
346        // Strip the {?name,label} template suffix.
347        let upload_base = release
348            .upload_url
349            .split('{')
350            .next()
351            .unwrap_or(&release.upload_url);
352
353        for file_path in files {
354            let path = std::path::Path::new(file_path);
355            let file_name = path
356                .file_name()
357                .and_then(|n| n.to_str())
358                .ok_or_else(|| ReleaseError::Vcs(format!("invalid file path: {file_path}")))?;
359
360            let data = std::fs::read(path)
361                .map_err(|e| ReleaseError::Vcs(format!("failed to read asset {file_path}: {e}")))?;
362
363            let content_type = mime_from_extension(file_name);
364            let url = format!("{upload_base}?name={file_name}");
365
366            // Retry up to 3 times for transient upload failures
367            let mut last_err = None;
368            for attempt in 0..3 {
369                if attempt > 0 {
370                    std::thread::sleep(std::time::Duration::from_secs(1 << attempt));
371                    eprintln!(
372                        "Retrying upload of {file_name} (attempt {}/3)...",
373                        attempt + 1
374                    );
375                }
376                match self
377                    .agent()
378                    .post(&url)
379                    .header("Authorization", &format!("Bearer {}", self.token))
380                    .header("Accept", "application/vnd.github+json")
381                    .header("X-GitHub-Api-Version", "2022-11-28")
382                    .header("User-Agent", "sr-github")
383                    .header("Content-Type", content_type)
384                    .send(&data[..])
385                {
386                    Ok(_) => {
387                        last_err = None;
388                        break;
389                    }
390                    Err(e) => {
391                        last_err = Some(format!("GitHub API upload asset {file_name}: {e}"));
392                    }
393                }
394            }
395            if let Some(err_msg) = last_err {
396                return Err(ReleaseError::Vcs(err_msg));
397            }
398        }
399
400        Ok(())
401    }
402
403    fn verify_release(&self, tag: &str) -> Result<(), ReleaseError> {
404        // GET the release by tag to confirm it exists and is accessible
405        self.get_release_by_tag(tag)?;
406        Ok(())
407    }
408
409    fn list_assets(&self, tag: &str) -> Result<Vec<String>, ReleaseError> {
410        match self.get_release_by_tag(tag) {
411            Ok(release) => Ok(release.assets.into_iter().map(|a| a.name).collect()),
412            // No release yet — no assets.
413            Err(ReleaseError::Vcs(msg)) if msg.contains("404") => Ok(Vec::new()),
414            Err(e) => Err(e),
415        }
416    }
417
418    fn fetch_asset(&self, tag: &str, name: &str) -> Result<Option<Vec<u8>>, ReleaseError> {
419        let release = match self.get_release_by_tag(tag) {
420            Ok(r) => r,
421            Err(ReleaseError::Vcs(msg)) if msg.contains("404") => return Ok(None),
422            Err(e) => return Err(e),
423        };
424        let asset = match release.assets.into_iter().find(|a| a.name == name) {
425            Some(a) => a,
426            None => return Ok(None),
427        };
428        // GET the asset URL with Accept: application/octet-stream returns bytes.
429        let resp = self
430            .agent()
431            .get(&asset.url)
432            .header("Authorization", &format!("Bearer {}", self.token))
433            .header("Accept", "application/octet-stream")
434            .header("X-GitHub-Api-Version", "2022-11-28")
435            .header("User-Agent", "sr-github")
436            .call()
437            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {}: {e}", asset.url)))?;
438        let bytes = resp
439            .into_body()
440            .read_to_vec()
441            .map_err(|e| ReleaseError::Vcs(format!("failed to read asset body: {e}")))?;
442        Ok(Some(bytes))
443    }
444}
445
/// Map a file name to the `Content-Type` used when uploading it as a GitHub release asset.
///
/// The extension is taken with [`std::path::Path::extension`] (text after the final `.`)
/// and compared case-insensitively, so `APP.ZIP` and `app.zip` map alike. Multi-part
/// extensions such as `.tar.gz` are classified by their last component. Names without
/// an extension (`LICENSE`, a file literally named `zip`, dotfiles) and unknown
/// extensions fall back to `application/octet-stream`.
fn mime_from_extension(filename: &str) -> &'static str {
    let ext = std::path::Path::new(filename)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();
    match ext.as_str() {
        "gz" | "tgz" => "application/gzip",
        "zip" => "application/zip",
        "tar" => "application/x-tar",
        "xz" => "application/x-xz",
        "bz2" => "application/x-bzip2",
        "zst" | "zstd" => "application/zstd",
        "deb" => "application/vnd.debian.binary-package",
        "rpm" => "application/x-rpm",
        "dmg" => "application/x-apple-diskimage",
        "msi" => "application/x-msi",
        "exe" => "application/vnd.microsoft.portable-executable",
        "sig" | "asc" => "application/pgp-signature",
        // Checksum files are plain text regardless of digest algorithm.
        "sha1" | "sha256" | "sha512" => "text/plain",
        "json" => "application/json",
        "txt" | "md" => "text/plain",
        _ => "application/octet-stream",
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;

    fn github_com_provider() -> GitHubProvider {
        GitHubProvider::new(
            "urmzd".into(),
            "sr".into(),
            "github.com".into(),
            "test-token".into(),
        )
    }

    fn ghes_provider() -> GitHubProvider {
        GitHubProvider::new(
            "org".into(),
            "repo".into(),
            "ghes.example.com".into(),
            "test-token".into(),
        )
    }

    #[test]
    fn test_api_url_github_com() {
        assert_eq!(github_com_provider().api_url(), "https://api.github.com");
    }

    #[test]
    fn test_api_url_ghes() {
        assert_eq!(ghes_provider().api_url(), "https://ghes.example.com/api/v3");
    }

    #[test]
    fn test_base_url() {
        assert_eq!(
            github_com_provider().base_url(),
            "https://github.com/urmzd/sr"
        );
        assert_eq!(
            ghes_provider().base_url(),
            "https://ghes.example.com/org/repo"
        );
    }

    #[test]
    fn test_compare_url() {
        let p = github_com_provider();
        assert_eq!(
            p.compare_url("v0.9.0", "v1.0.0").unwrap(),
            "https://github.com/urmzd/sr/compare/v0.9.0...v1.0.0"
        );
        assert_eq!(
            ghes_provider().compare_url("v1.0.0", "v1.1.0").unwrap(),
            "https://ghes.example.com/org/repo/compare/v1.0.0...v1.1.0"
        );
    }

    #[test]
    fn test_repo_url() {
        assert_eq!(
            github_com_provider().repo_url().unwrap(),
            "https://github.com/urmzd/sr"
        );
    }

    #[test]
    fn test_mime_from_extension_known_types() {
        assert_eq!(
            mime_from_extension("sr-x86_64-linux.tar.gz"),
            "application/gzip"
        );
        assert_eq!(mime_from_extension("sr-windows.zip"), "application/zip");
        assert_eq!(
            mime_from_extension("sr_1.0.0_amd64.deb"),
            "application/vnd.debian.binary-package"
        );
        assert_eq!(
            mime_from_extension("sr.tar.gz.sig"),
            "application/pgp-signature"
        );
        assert_eq!(mime_from_extension("checksums.txt"), "text/plain");
        assert_eq!(mime_from_extension("sr.sha512"), "text/plain");
    }

    #[test]
    fn test_mime_from_extension_unknown_falls_back() {
        assert_eq!(mime_from_extension("sr.bin"), "application/octet-stream");
    }
}