Skip to main content

sr_core/
github.rs

1use crate::error::ReleaseError;
2use crate::release::VcsProvider;
3
/// [`VcsProvider`] implementation that talks to the GitHub REST API, for both
/// github.com and self-hosted (GitHub Enterprise Server) hosts.
///
/// `Debug` is not derived, which keeps the API token out of debug output.
pub struct GitHubProvider {
    /// Repository owner: the first path segment of the repository URL.
    owner: String,
    /// Repository name: the second path segment of the repository URL.
    repo: String,
    /// Web host of the repository; also used to derive the API endpoint.
    hostname: String,
    /// Credential sent as `Authorization: Bearer <token>` on every request.
    token: String,
}
11
12#[derive(serde::Deserialize)]
13struct ReleaseResponse {
14    id: u64,
15    html_url: String,
16    upload_url: String,
17}
18
19impl GitHubProvider {
20    pub fn new(owner: String, repo: String, hostname: String, token: String) -> Self {
21        Self {
22            owner,
23            repo,
24            hostname,
25            token,
26        }
27    }
28
29    fn base_url(&self) -> String {
30        format!("https://{}/{}/{}", self.hostname, self.owner, self.repo)
31    }
32
33    fn api_url(&self) -> String {
34        if self.hostname == "github.com" {
35            "https://api.github.com".to_string()
36        } else {
37            format!("https://{}/api/v3", self.hostname)
38        }
39    }
40
41    fn agent(&self) -> ureq::Agent {
42        ureq::Agent::new_with_config(ureq::config::Config::builder().https_only(true).build())
43    }
44
45    fn get_release_by_tag(&self, tag: &str) -> Result<ReleaseResponse, ReleaseError> {
46        let url = format!(
47            "{}/repos/{}/{}/releases/tags/{tag}",
48            self.api_url(),
49            self.owner,
50            self.repo
51        );
52        let resp = self
53            .agent()
54            .get(&url)
55            .header("Authorization", &format!("Bearer {}", self.token))
56            .header("Accept", "application/vnd.github+json")
57            .header("X-GitHub-Api-Version", "2022-11-28")
58            .header("User-Agent", "sr-github")
59            .call()
60            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
61        let release: ReleaseResponse = resp
62            .into_body()
63            .read_json()
64            .map_err(|e| ReleaseError::Vcs(format!("failed to parse release response: {e}")))?;
65        Ok(release)
66    }
67}
68
69// ---------------------------------------------------------------------------
70// PR support
71// ---------------------------------------------------------------------------
72
73#[derive(Debug, serde::Deserialize)]
74pub struct PrMetadata {
75    pub number: u64,
76    pub title: String,
77    pub body: Option<String>,
78    pub user: PrUser,
79    pub base: PrRef,
80    pub head: PrRef,
81}
82
83#[derive(Debug, serde::Deserialize)]
84pub struct PrUser {
85    pub login: String,
86}
87
88#[derive(Debug, serde::Deserialize)]
89pub struct PrRef {
90    #[serde(rename = "ref")]
91    pub ref_name: String,
92}
93
94impl GitHubProvider {
95    /// Find the open PR number for a given branch.
96    pub fn get_pr_for_branch(&self, branch: &str) -> Result<PrMetadata, ReleaseError> {
97        let url = format!(
98            "{}/repos/{}/{}/pulls?head={}:{}&state=open&per_page=1",
99            self.api_url(),
100            self.owner,
101            self.repo,
102            self.owner,
103            branch
104        );
105        let resp = self
106            .agent()
107            .get(&url)
108            .header("Authorization", &format!("Bearer {}", self.token))
109            .header("Accept", "application/vnd.github+json")
110            .header("X-GitHub-Api-Version", "2022-11-28")
111            .header("User-Agent", "sr-github")
112            .call()
113            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
114        let prs: Vec<PrMetadata> = resp
115            .into_body()
116            .read_json()
117            .map_err(|e| ReleaseError::Vcs(format!("failed to parse PR list: {e}")))?;
118        prs.into_iter()
119            .next()
120            .ok_or_else(|| ReleaseError::Vcs(format!("no open PR found for branch '{branch}'")))
121    }
122
123    /// Fetch the diff for a PR.
124    pub fn get_pr_diff(&self, pr_number: u64) -> Result<String, ReleaseError> {
125        let url = format!(
126            "{}/repos/{}/{}/pulls/{pr_number}",
127            self.api_url(),
128            self.owner,
129            self.repo
130        );
131        let resp = self
132            .agent()
133            .get(&url)
134            .header("Authorization", &format!("Bearer {}", self.token))
135            .header("Accept", "application/vnd.github.v3.diff")
136            .header("X-GitHub-Api-Version", "2022-11-28")
137            .header("User-Agent", "sr-github")
138            .call()
139            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
140        resp.into_body()
141            .read_to_string()
142            .map_err(|e| ReleaseError::Vcs(format!("failed to read PR diff: {e}")))
143    }
144
145    /// Count open PRs.
146    pub fn count_open_prs(&self) -> Result<(usize, usize), ReleaseError> {
147        // Returns (ready_count, draft_count)
148        let url = format!(
149            "{}/repos/{}/{}/pulls?state=open&per_page=100",
150            self.api_url(),
151            self.owner,
152            self.repo
153        );
154        let resp = self
155            .agent()
156            .get(&url)
157            .header("Authorization", &format!("Bearer {}", self.token))
158            .header("Accept", "application/vnd.github+json")
159            .header("X-GitHub-Api-Version", "2022-11-28")
160            .header("User-Agent", "sr-github")
161            .call()
162            .map_err(|e| ReleaseError::Vcs(format!("GitHub API GET {url}: {e}")))?;
163
164        #[derive(serde::Deserialize)]
165        struct MinimalPr {
166            draft: Option<bool>,
167        }
168
169        let prs: Vec<MinimalPr> = resp
170            .into_body()
171            .read_json()
172            .map_err(|e| ReleaseError::Vcs(format!("failed to parse PR list: {e}")))?;
173
174        let draft_count = prs.iter().filter(|p| p.draft == Some(true)).count();
175        let ready_count = prs.len() - draft_count;
176        Ok((ready_count, draft_count))
177    }
178
179    /// Post a review summary on a PR.
180    pub fn post_pr_review(&self, pr_number: u64, body: &str) -> Result<(), ReleaseError> {
181        let url = format!(
182            "{}/repos/{}/{}/pulls/{pr_number}/reviews",
183            self.api_url(),
184            self.owner,
185            self.repo
186        );
187        let payload = serde_json::json!({
188            "body": body,
189            "event": "COMMENT",
190        });
191        self.agent()
192            .post(&url)
193            .header("Authorization", &format!("Bearer {}", self.token))
194            .header("Accept", "application/vnd.github+json")
195            .header("X-GitHub-Api-Version", "2022-11-28")
196            .header("User-Agent", "sr-github")
197            .send_json(&payload)
198            .map_err(|e| ReleaseError::Vcs(format!("GitHub API POST {url}: {e}")))?;
199        Ok(())
200    }
201}
202
203impl VcsProvider for GitHubProvider {
204    fn create_release(
205        &self,
206        tag: &str,
207        name: &str,
208        body: &str,
209        prerelease: bool,
210        draft: bool,
211    ) -> Result<String, ReleaseError> {
212        let url = format!(
213            "{}/repos/{}/{}/releases",
214            self.api_url(),
215            self.owner,
216            self.repo
217        );
218        let payload = serde_json::json!({
219            "tag_name": tag,
220            "name": name,
221            "body": body,
222            "prerelease": prerelease,
223            "draft": draft,
224        });
225
226        let resp = self
227            .agent()
228            .post(&url)
229            .header("Authorization", &format!("Bearer {}", self.token))
230            .header("Accept", "application/vnd.github+json")
231            .header("X-GitHub-Api-Version", "2022-11-28")
232            .header("User-Agent", "sr-github")
233            .send_json(&payload)
234            .map_err(|e| ReleaseError::Vcs(format!("GitHub API POST {url}: {e}")))?;
235
236        let release: ReleaseResponse = resp
237            .into_body()
238            .read_json()
239            .map_err(|e| ReleaseError::Vcs(format!("failed to parse release response: {e}")))?;
240
241        Ok(release.html_url)
242    }
243
244    fn compare_url(&self, base: &str, head: &str) -> Result<String, ReleaseError> {
245        Ok(format!("{}/compare/{base}...{head}", self.base_url()))
246    }
247
248    fn release_exists(&self, tag: &str) -> Result<bool, ReleaseError> {
249        let url = format!(
250            "{}/repos/{}/{}/releases/tags/{tag}",
251            self.api_url(),
252            self.owner,
253            self.repo
254        );
255        match self
256            .agent()
257            .get(&url)
258            .header("Authorization", &format!("Bearer {}", self.token))
259            .header("Accept", "application/vnd.github+json")
260            .header("X-GitHub-Api-Version", "2022-11-28")
261            .header("User-Agent", "sr-github")
262            .call()
263        {
264            Ok(_) => Ok(true),
265            Err(ureq::Error::StatusCode(404)) => Ok(false),
266            Err(e) => Err(ReleaseError::Vcs(format!("GitHub API GET {url}: {e}"))),
267        }
268    }
269
270    fn repo_url(&self) -> Option<String> {
271        Some(self.base_url())
272    }
273
274    fn delete_release(&self, tag: &str) -> Result<(), ReleaseError> {
275        let release = self.get_release_by_tag(tag)?;
276        let url = format!(
277            "{}/repos/{}/{}/releases/{}",
278            self.api_url(),
279            self.owner,
280            self.repo,
281            release.id
282        );
283        self.agent()
284            .delete(&url)
285            .header("Authorization", &format!("Bearer {}", self.token))
286            .header("Accept", "application/vnd.github+json")
287            .header("X-GitHub-Api-Version", "2022-11-28")
288            .header("User-Agent", "sr-github")
289            .call()
290            .map_err(|e| ReleaseError::Vcs(format!("GitHub API DELETE {url}: {e}")))?;
291        Ok(())
292    }
293
294    fn update_release(
295        &self,
296        tag: &str,
297        name: &str,
298        body: &str,
299        prerelease: bool,
300        draft: bool,
301    ) -> Result<String, ReleaseError> {
302        let release = self.get_release_by_tag(tag)?;
303        let url = format!(
304            "{}/repos/{}/{}/releases/{}",
305            self.api_url(),
306            self.owner,
307            self.repo,
308            release.id
309        );
310        let payload = serde_json::json!({
311            "name": name,
312            "body": body,
313            "prerelease": prerelease,
314            "draft": draft,
315        });
316        let resp = self
317            .agent()
318            .patch(&url)
319            .header("Authorization", &format!("Bearer {}", self.token))
320            .header("Accept", "application/vnd.github+json")
321            .header("X-GitHub-Api-Version", "2022-11-28")
322            .header("User-Agent", "sr-github")
323            .send_json(&payload)
324            .map_err(|e| ReleaseError::Vcs(format!("GitHub API PATCH {url}: {e}")))?;
325        let updated: ReleaseResponse = resp
326            .into_body()
327            .read_json()
328            .map_err(|e| ReleaseError::Vcs(format!("failed to parse release response: {e}")))?;
329        Ok(updated.html_url)
330    }
331
332    fn upload_assets(&self, tag: &str, files: &[&str]) -> Result<(), ReleaseError> {
333        let release = self.get_release_by_tag(tag)?;
334        // The upload_url from the API looks like:
335        //   https://uploads.github.com/repos/owner/repo/releases/123/assets{?name,label}
336        // Strip the {?name,label} template suffix.
337        let upload_base = release
338            .upload_url
339            .split('{')
340            .next()
341            .unwrap_or(&release.upload_url);
342
343        for file_path in files {
344            let path = std::path::Path::new(file_path);
345            let file_name = path
346                .file_name()
347                .and_then(|n| n.to_str())
348                .ok_or_else(|| ReleaseError::Vcs(format!("invalid file path: {file_path}")))?;
349
350            let data = std::fs::read(path)
351                .map_err(|e| ReleaseError::Vcs(format!("failed to read asset {file_path}: {e}")))?;
352
353            let content_type = mime_from_extension(file_name);
354            let url = format!("{upload_base}?name={file_name}");
355
356            // Retry up to 3 times for transient upload failures
357            let mut last_err = None;
358            for attempt in 0..3 {
359                if attempt > 0 {
360                    std::thread::sleep(std::time::Duration::from_secs(1 << attempt));
361                    eprintln!(
362                        "Retrying upload of {file_name} (attempt {}/3)...",
363                        attempt + 1
364                    );
365                }
366                match self
367                    .agent()
368                    .post(&url)
369                    .header("Authorization", &format!("Bearer {}", self.token))
370                    .header("Accept", "application/vnd.github+json")
371                    .header("X-GitHub-Api-Version", "2022-11-28")
372                    .header("User-Agent", "sr-github")
373                    .header("Content-Type", content_type)
374                    .send(&data[..])
375                {
376                    Ok(_) => {
377                        last_err = None;
378                        break;
379                    }
380                    Err(e) => {
381                        last_err = Some(format!("GitHub API upload asset {file_name}: {e}"));
382                    }
383                }
384            }
385            if let Some(err_msg) = last_err {
386                return Err(ReleaseError::Vcs(err_msg));
387            }
388        }
389
390        Ok(())
391    }
392
393    fn verify_release(&self, tag: &str) -> Result<(), ReleaseError> {
394        // GET the release by tag to confirm it exists and is accessible
395        self.get_release_by_tag(tag)?;
396        Ok(())
397    }
398}
399
/// Map a file name's extension to the MIME type sent as `Content-Type` when
/// uploading a GitHub release asset.
///
/// Only the last extension counts (`app.tar.gz` maps to gzip), and matching is
/// ASCII case-insensitive (`APP.ZIP` maps to zip). Names without a `.` and
/// unknown extensions fall back to `application/octet-stream`.
fn mime_from_extension(filename: &str) -> &'static str {
    // `rsplit_once` (unlike `rsplit(..).next()`) yields nothing for a dot-less
    // name, so a file literally called `zip` is not mistaken for a zip archive.
    let extension = filename
        .rsplit_once('.')
        .map(|(_, ext)| ext.to_ascii_lowercase())
        .unwrap_or_default();
    match extension.as_str() {
        "gz" | "tgz" => "application/gzip",
        "zip" => "application/zip",
        "tar" => "application/x-tar",
        "xz" => "application/x-xz",
        "bz2" => "application/x-bzip2",
        "zst" | "zstd" => "application/zstd",
        "deb" => "application/vnd.debian.binary-package",
        "rpm" => "application/x-rpm",
        "dmg" => "application/x-apple-diskimage",
        "msi" => "application/x-msi",
        "exe" => "application/vnd.microsoft.portable-executable",
        "sig" | "asc" => "application/pgp-signature",
        // Checksum files are plain-text digest listings.
        "sha256" | "sha512" => "text/plain",
        "json" => "application/json",
        "txt" | "md" => "text/plain",
        _ => "application/octet-stream",
    }
}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a provider for `owner/repo` on `hostname` with a dummy token.
    fn provider(owner: &str, repo: &str, hostname: &str) -> GitHubProvider {
        GitHubProvider::new(
            owner.to_string(),
            repo.to_string(),
            hostname.to_string(),
            String::from("test-token"),
        )
    }

    /// A provider pointing at public github.com.
    fn github_com_provider() -> GitHubProvider {
        provider("urmzd", "sr", "github.com")
    }

    /// A provider pointing at a GitHub Enterprise Server host.
    fn ghes_provider() -> GitHubProvider {
        provider("org", "repo", "ghes.example.com")
    }

    #[test]
    fn test_api_url_github_com() {
        let api = github_com_provider().api_url();
        assert_eq!(api, "https://api.github.com");
    }

    #[test]
    fn test_api_url_ghes() {
        let api = ghes_provider().api_url();
        assert_eq!(api, "https://ghes.example.com/api/v3");
    }

    #[test]
    fn test_base_url() {
        let cases = [
            (github_com_provider(), "https://github.com/urmzd/sr"),
            (ghes_provider(), "https://ghes.example.com/org/repo"),
        ];
        for (p, expected) in cases {
            assert_eq!(p.base_url(), expected);
        }
    }

    #[test]
    fn test_compare_url() {
        let url = github_com_provider()
            .compare_url("v0.9.0", "v1.0.0")
            .unwrap();
        assert_eq!(url, "https://github.com/urmzd/sr/compare/v0.9.0...v1.0.0");
    }

    #[test]
    fn test_repo_url() {
        let url = github_com_provider().repo_url().unwrap();
        assert_eq!(url, "https://github.com/urmzd/sr");
    }
}