// upstream_rs/providers/github/github_client.rs
use std::path::Path;

use anyhow::{Context, Result, bail};
use reqwest::{Client, header};
use serde::Deserialize;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;

use super::github_dtos::GithubReleaseDto;

10#[derive(Debug, Clone)]
11pub struct GithubClient {
12    client: Client,
13}
14
#[cfg(test)]
mod tests {
    use crate::providers::github::github_dtos::GithubReleaseDto;

    // The GitHub API reports missing optional strings as JSON `null`; the
    // DTO is expected to normalize every such field to an empty string.
    #[test]
    fn github_release_dto_accepts_nullable_string_fields() {
        let raw = r#"
        {
          "id": 1,
          "tag_name": "v1.0.0",
          "name": null,
          "body": null,
          "prerelease": false,
          "draft": false,
          "published_at": null,
          "assets": [
            {
              "id": 42,
              "name": "tree-sitter-linux.tar.gz",
              "browser_download_url": "https://example.com/asset.tar.gz",
              "size": 1234,
              "content_type": null,
              "created_at": null
            }
          ]
        }
        "#;

        let release: GithubReleaseDto =
            serde_json::from_str(raw).expect("valid release JSON");

        // Release-level nullable fields default to "".
        assert_eq!(release.name, "");
        assert_eq!(release.body, "");
        assert_eq!(release.published_at, "");

        // Asset-level nullable fields default to "" as well.
        assert_eq!(release.assets[0].content_type, "");
        assert_eq!(release.assets[0].created_at, "");
    }
}

52impl GithubClient {
53    pub fn new(token: Option<&str>) -> Result<Self> {
54        let mut headers = header::HeaderMap::new();
55
56        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
57        headers.insert(
58            header::USER_AGENT,
59            header::HeaderValue::from_str(&user_agent)
60                .context("Failed to create user agent header")?,
61        );
62
63        if let Some(token) = token {
64            let auth_value = format!("Bearer {}", token);
65            headers.insert(
66                header::AUTHORIZATION,
67                header::HeaderValue::from_str(&auth_value)
68                    .context("Failed to create authorization header")?,
69            );
70        }
71
72        let client = Client::builder()
73            .default_headers(headers)
74            .build()
75            .context("Failed to build HTTP client")?;
76
77        Ok(Self { client })
78    }
79
80    async fn get_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
81        let response = self
82            .client
83            .get(url)
84            .send()
85            .await
86            .context(format!("Failed to send request to {}", url))?;
87
88        response
89            .error_for_status_ref()
90            .context(format!("GitHub API returned error for {}", url))?;
91
92        let data = response
93            .json::<T>()
94            .await
95            .context("Failed to parse JSON response")?;
96
97        Ok(data)
98    }
99
100    pub async fn download_file<F>(
101        &self,
102        url: &str,
103        destination: &Path,
104        progress: &mut Option<F>,
105    ) -> Result<()>
106    where
107        F: FnMut(u64, u64),
108    {
109        let response = self
110            .client
111            .get(url)
112            .send()
113            .await
114            .context(format!("Failed to download from {}", url))?;
115
116        response
117            .error_for_status_ref()
118            .context("Download request failed")?;
119
120        let total_bytes = response.content_length().unwrap_or(0);
121
122        let mut file = File::create(destination)
123            .await
124            .context(format!("Failed to create file at {:?}", destination))?;
125
126        let mut stream = response.bytes_stream();
127        let mut total_read: u64 = 0;
128
129        use futures_util::StreamExt;
130        while let Some(chunk) = stream.next().await {
131            let chunk = chunk.context("Failed to read download chunk")?;
132
133            file.write_all(&chunk)
134                .await
135                .context("Failed to write to file")?;
136
137            total_read += chunk.len() as u64;
138
139            if let Some(cb) = progress.as_mut() {
140                cb(total_read, total_bytes);
141            }
142        }
143
144        file.flush().await.context("Failed to flush file")?;
145
146        if total_bytes > 0 && total_read != total_bytes {
147            bail!(
148                "Download size mismatch: expected {} bytes, got {} bytes",
149                total_bytes,
150                total_read
151            );
152        }
153
154        Ok(())
155    }
156
157    pub async fn get_release_by_tag(
158        &self,
159        owner_repo: &str,
160        tag: &str,
161    ) -> Result<GithubReleaseDto> {
162        let url = format!(
163            "https://api.github.com/repos/{}/releases/tags/{}",
164            owner_repo, tag
165        );
166        self.get_json(&url)
167            .await
168            .context(format!("Failed to get release for tag {}", tag))
169    }
170
171    pub async fn get_latest_release(&self, owner_repo: &str) -> Result<GithubReleaseDto> {
172        let url = format!(
173            "https://api.github.com/repos/{}/releases/latest",
174            owner_repo
175        );
176        self.get_json(&url)
177            .await
178            .context(format!("Failed to get latest release for {}", owner_repo))
179    }
180
181    pub async fn get_releases(
182        &self,
183        owner_repo: &str,
184        per_page: Option<u32>,
185        max_total: Option<u32>,
186    ) -> Result<Vec<GithubReleaseDto>> {
187        let per_page = per_page.unwrap_or(30);
188        let mut page = 1;
189        let mut releases = Vec::new();
190
191        loop {
192            let url = format!(
193                "https://api.github.com/repos/{}/releases?per_page={}&page={}",
194                owner_repo, per_page, page
195            );
196            let batch: Vec<GithubReleaseDto> = self
197                .get_json(&url)
198                .await
199                .context(format!("Failed to get releases page {}", page))?;
200
201            if batch.is_empty() {
202                break;
203            }
204
205            releases.extend(batch);
206
207            // Check if we've hit the total limit
208            if let Some(max) = max_total
209                && releases.len() >= max as usize
210            {
211                releases.truncate(max as usize);
212                break;
213            }
214
215            // Check if this was a partial page (last page)
216            if releases.len() % per_page as usize != 0 {
217                break;
218            }
219
220            page += 1;
221        }
222
223        Ok(releases)
224    }
225}