Skip to main content

upstream_rs/providers/github/
github_client.rs

1use anyhow::{Context, Result};
2use reqwest::{Client, header};
3use serde::Deserialize;
4use std::path::Path;
5
6use crate::{
7    models::{provider::RepositorySearchFilters, upstream::DownloadConfig},
8    providers::{download_handler, http::http_status},
9};
10
11use super::github_dtos::{GithubReleaseDto, GithubRepositorySearchResponseDto};
12#[derive(Debug, Deserialize)]
13struct GithubCommitDto {
14    sha: String,
15}
16
17#[derive(Debug, Clone)]
18pub struct GithubClient {
19    client: Client,
20    download_config: DownloadConfig,
21}
22
23impl GithubClient {
24    pub fn new(token: Option<&str>, download_config: DownloadConfig) -> Result<Self> {
25        let mut headers = header::HeaderMap::new();
26
27        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
28        headers.insert(
29            header::USER_AGENT,
30            header::HeaderValue::from_str(&user_agent)
31                .context("Failed to create user agent header")?,
32        );
33
34        if let Some(token) = token {
35            let auth_value = format!("Bearer {}", token);
36            headers.insert(
37                header::AUTHORIZATION,
38                header::HeaderValue::from_str(&auth_value)
39                    .context("Failed to create authorization header")?,
40            );
41        }
42
43        let client = Client::builder()
44            .default_headers(headers)
45            .build()
46            .context("Failed to build HTTP client")?;
47
48        Ok(Self {
49            client,
50            download_config,
51        })
52    }
53
54    async fn get_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
55        let response = self
56            .client
57            .get(url)
58            .send()
59            .await
60            .context(format!("Failed to send request to {}", url))?;
61
62        http_status::error_for_status(&response, "GitHub API", url)?;
63
64        let data = response
65            .json::<T>()
66            .await
67            .context("Failed to parse JSON response")?;
68
69        Ok(data)
70    }
71
72    async fn get_text_with_accept(&self, url: &str, accept: &'static str) -> Result<String> {
73        let response = self
74            .client
75            .get(url)
76            .header(header::ACCEPT, accept)
77            .send()
78            .await
79            .context(format!("Failed to send request to {}", url))?;
80
81        http_status::error_for_status(&response, "GitHub API", url)?;
82
83        response
84            .text()
85            .await
86            .context("Failed to read text response")
87    }
88
89    pub async fn download_file<F>(
90        &self,
91        url: &str,
92        destination: &Path,
93        progress: &mut Option<F>,
94    ) -> Result<()>
95    where
96        F: FnMut(u64, u64),
97    {
98        download_handler::download_file(
99            &self.client,
100            url,
101            destination,
102            progress,
103            self.download_config,
104        )
105        .await
106    }
107
108    pub async fn check_token(&self) -> Result<reqwest::Response> {
109        let url = "https://api.github.com/user";
110        self.client
111            .get(url)
112            .send()
113            .await
114            .context(format!("Failed to send request to {}", url))
115    }
116
117    pub async fn get_release_by_tag(
118        &self,
119        owner_repo: &str,
120        tag: &str,
121    ) -> Result<GithubReleaseDto> {
122        let url = format!(
123            "https://api.github.com/repos/{}/releases/tags/{}",
124            owner_repo, tag
125        );
126        self.get_json(&url)
127            .await
128            .context(format!("Failed to get release for tag {}", tag))
129    }
130
131    pub async fn get_latest_release(&self, owner_repo: &str) -> Result<GithubReleaseDto> {
132        let url = format!(
133            "https://api.github.com/repos/{}/releases/latest",
134            owner_repo
135        );
136        self.get_json(&url)
137            .await
138            .context(format!("Failed to get latest release for {}", owner_repo))
139    }
140
141    pub async fn get_releases(
142        &self,
143        owner_repo: &str,
144        per_page: Option<u32>,
145        max_total: Option<u32>,
146    ) -> Result<Vec<GithubReleaseDto>> {
147        let per_page = per_page.unwrap_or(30);
148        let mut page = 1;
149        let mut releases = Vec::new();
150
151        loop {
152            let batch = self
153                .get_releases_page(owner_repo, per_page, page)
154                .await
155                .context(format!("Failed to get releases page {}", page))?;
156            let partial_page = batch.len() < per_page as usize;
157
158            if batch.is_empty() {
159                break;
160            }
161
162            releases.extend(batch);
163
164            // Check if we've hit the total limit
165            if let Some(max) = max_total
166                && releases.len() >= max as usize
167            {
168                releases.truncate(max as usize);
169                break;
170            }
171
172            if partial_page {
173                break;
174            }
175
176            page += 1;
177        }
178
179        Ok(releases)
180    }
181
182    pub async fn get_releases_page(
183        &self,
184        owner_repo: &str,
185        per_page: u32,
186        page: u32,
187    ) -> Result<Vec<GithubReleaseDto>> {
188        let url = format!(
189            "https://api.github.com/repos/{}/releases?per_page={}&page={}",
190            owner_repo, per_page, page
191        );
192        self.get_json(&url)
193            .await
194            .context(format!("Failed to get releases page {}", page))
195    }
196
197    pub async fn get_branch_head_sha(&self, owner_repo: &str, branch: &str) -> Result<String> {
198        let encoded_branch = branch.replace('/', "%2F");
199        let url = format!(
200            "https://api.github.com/repos/{}/commits/{}",
201            owner_repo, encoded_branch
202        );
203        let dto: GithubCommitDto = self.get_json(&url).await.context(format!(
204            "Failed to get branch head for {}/{}",
205            owner_repo, branch
206        ))?;
207        Ok(dto.sha)
208    }
209
210    pub async fn get_project_readme(&self, owner_repo: &str) -> Result<String> {
211        let url = format!("https://api.github.com/repos/{}/readme", owner_repo);
212        self.get_text_with_accept(&url, "application/vnd.github.raw")
213            .await
214            .context(format!("Failed to get README for {}", owner_repo))
215    }
216
217    pub async fn search_repositories(
218        &self,
219        query: &str,
220        limit: Option<u32>,
221        filters: &RepositorySearchFilters,
222    ) -> Result<GithubRepositorySearchResponseDto> {
223        let per_page = limit.unwrap_or(10).clamp(1, 100);
224        let search_query = Self::build_repository_search_query(query, filters);
225        let mut url = reqwest::Url::parse("https://api.github.com/search/repositories")
226            .context("Failed to build GitHub search URL")?;
227        url.query_pairs_mut()
228            .append_pair("q", &search_query)
229            .append_pair("per_page", &per_page.to_string());
230
231        self.get_json(url.as_str()).await.context(format!(
232            "Failed to search repositories for '{}'",
233            search_query
234        ))
235    }
236
237    fn build_repository_search_query(query: &str, filters: &RepositorySearchFilters) -> String {
238        let mut parts = vec![query.trim().to_string()];
239
240        if let Some(language) = &filters.language {
241            parts.push(format!(
242                "language:{}",
243                Self::format_search_qualifier_value(language)
244            ));
245        }
246        if let Some(topic) = &filters.topic {
247            parts.push(format!(
248                "topic:{}",
249                Self::format_search_qualifier_value(topic)
250            ));
251        }
252        if let Some(min_stars) = filters.min_stars {
253            parts.push(format!("stars:>={min_stars}"));
254        }
255        if let Some(max_stars) = filters.max_stars {
256            parts.push(format!("stars:<={max_stars}"));
257        }
258        if let Some(pushed_after) = filters.pushed_after {
259            parts.push(format!("pushed:>={pushed_after}"));
260        }
261        if filters.include_forks {
262            parts.push("fork:true".to_string());
263        }
264        if !filters.include_archived {
265            parts.push("archived:false".to_string());
266        }
267
268        let parts = parts
269            .into_iter()
270            .filter(|part| !part.is_empty())
271            .collect::<Vec<_>>();
272        if parts.is_empty() {
273            return "stars:>=0".to_string();
274        }
275
276        parts.join(" ")
277    }
278
279    fn format_search_qualifier_value(value: &str) -> String {
280        if value.chars().any(char::is_whitespace) {
281            format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
282        } else {
283            value.to_string()
284        }
285    }
286}
287
#[cfg(test)]
mod tests {
    use chrono::NaiveDate;

    use crate::models::provider::RepositorySearchFilters;
    use crate::providers::github::GithubClient;
    use crate::providers::github::github_dtos::{
        GithubReleaseDto, GithubRepositorySearchResponseDto,
    };

    /// `null` string fields in a release payload deserialize to empty strings.
    #[test]
    fn github_release_dto_accepts_nullable_string_fields() {
        let payload = r#"
        {
          "id": 1,
          "tag_name": "v1.0.0",
          "name": null,
          "body": null,
          "prerelease": false,
          "draft": false,
          "published_at": null,
          "assets": [
            {
              "id": 42,
              "name": "tree-sitter-linux.tar.gz",
              "browser_download_url": "https://example.com/asset.tar.gz",
              "size": 1234,
              "content_type": null,
              "created_at": null
            }
          ]
        }
        "#;

        let release: GithubReleaseDto =
            serde_json::from_str(payload).expect("valid release JSON");
        assert_eq!(release.name, "");
        assert_eq!(release.body, "");
        assert_eq!(release.published_at, "");

        let asset = &release.assets[0];
        assert_eq!(asset.content_type, "");
        assert_eq!(asset.created_at, "");
    }

    /// `null` string fields in a search result deserialize to empty strings.
    #[test]
    fn github_search_dto_accepts_nullable_string_fields() {
        let payload = r#"
        {
          "items": [
            {
              "full_name": "BurntSushi/ripgrep",
              "name": "ripgrep",
              "description": null,
              "stargazers_count": 10,
              "language": null,
              "updated_at": null,
              "archived": false,
              "fork": false
            }
          ]
        }
        "#;

        let response: GithubRepositorySearchResponseDto =
            serde_json::from_str(payload).expect("valid search JSON");
        assert_eq!(response.items.len(), 1);

        let repository = &response.items[0];
        assert_eq!(repository.description, "");
        assert_eq!(repository.language, "");
        assert_eq!(repository.updated_at, "");
    }

    /// Every active filter contributes its qualifier, in a fixed order.
    #[test]
    fn build_repository_search_query_adds_discovery_filters() {
        let pushed_after = NaiveDate::from_ymd_opt(2026, 1, 2).unwrap();
        let filters = RepositorySearchFilters::new(
            Some("Rust".to_string()),
            Some("cli".to_string()),
            Some(100),
            Some(50_000),
            Some(pushed_after),
            true,
            false,
        );

        let built = GithubClient::build_repository_search_query("fast search", &filters);
        assert_eq!(
            built,
            "fast search language:Rust topic:cli stars:>=100 stars:<=50000 pushed:>=2026-01-02 fork:true archived:false"
        );
    }

    /// Qualifier values containing whitespace are wrapped in double quotes.
    #[test]
    fn build_repository_search_query_quotes_multi_word_qualifier_values() {
        let language = Some("Common Lisp".to_string());
        let filters =
            RepositorySearchFilters::new(language, None, None, None, None, false, false);

        let built = GithubClient::build_repository_search_query("editor", &filters);
        assert_eq!(built, "editor language:\"Common Lisp\" archived:false");
    }

    /// With no query text and no active filter, the fallback term is used.
    #[test]
    fn build_repository_search_query_falls_back_when_query_and_filters_are_empty() {
        let filters = RepositorySearchFilters::new(None, None, None, None, None, false, true);

        let built = GithubClient::build_repository_search_query("", &filters);
        assert_eq!(built, "stars:>=0");
    }
}