Skip to main content

upstream_rs/providers/github/
github_client.rs

1use anyhow::{Context, Result};
2use reqwest::{Client, header};
3use serde::Deserialize;
4use std::path::Path;
5
6use crate::{
7    models::{provider::RepositorySearchFilters, upstream::DownloadConfig},
8    providers::download_handler,
9};
10
11use super::github_dtos::{GithubReleaseDto, GithubRepositorySearchResponseDto};
12#[derive(Debug, Deserialize)]
13struct GithubCommitDto {
14    sha: String,
15}
16
17#[derive(Debug, Clone)]
18pub struct GithubClient {
19    client: Client,
20    download_config: DownloadConfig,
21}
22
23impl GithubClient {
24    pub fn new(token: Option<&str>, download_config: DownloadConfig) -> Result<Self> {
25        let mut headers = header::HeaderMap::new();
26
27        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
28        headers.insert(
29            header::USER_AGENT,
30            header::HeaderValue::from_str(&user_agent)
31                .context("Failed to create user agent header")?,
32        );
33
34        if let Some(token) = token {
35            let auth_value = format!("Bearer {}", token);
36            headers.insert(
37                header::AUTHORIZATION,
38                header::HeaderValue::from_str(&auth_value)
39                    .context("Failed to create authorization header")?,
40            );
41        }
42
43        let client = Client::builder()
44            .default_headers(headers)
45            .build()
46            .context("Failed to build HTTP client")?;
47
48        Ok(Self {
49            client,
50            download_config,
51        })
52    }
53
54    async fn get_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
55        let response = self
56            .client
57            .get(url)
58            .send()
59            .await
60            .context(format!("Failed to send request to {}", url))?;
61
62        response
63            .error_for_status_ref()
64            .context(format!("GitHub API returned error for {}", url))?;
65
66        let data = response
67            .json::<T>()
68            .await
69            .context("Failed to parse JSON response")?;
70
71        Ok(data)
72    }
73
74    pub async fn download_file<F>(
75        &self,
76        url: &str,
77        destination: &Path,
78        progress: &mut Option<F>,
79    ) -> Result<()>
80    where
81        F: FnMut(u64, u64),
82    {
83        download_handler::download_file(
84            &self.client,
85            url,
86            destination,
87            progress,
88            self.download_config,
89        )
90        .await
91    }
92
93    pub async fn get_release_by_tag(
94        &self,
95        owner_repo: &str,
96        tag: &str,
97    ) -> Result<GithubReleaseDto> {
98        let url = format!(
99            "https://api.github.com/repos/{}/releases/tags/{}",
100            owner_repo, tag
101        );
102        self.get_json(&url)
103            .await
104            .context(format!("Failed to get release for tag {}", tag))
105    }
106
107    pub async fn get_latest_release(&self, owner_repo: &str) -> Result<GithubReleaseDto> {
108        let url = format!(
109            "https://api.github.com/repos/{}/releases/latest",
110            owner_repo
111        );
112        self.get_json(&url)
113            .await
114            .context(format!("Failed to get latest release for {}", owner_repo))
115    }
116
117    pub async fn get_releases(
118        &self,
119        owner_repo: &str,
120        per_page: Option<u32>,
121        max_total: Option<u32>,
122    ) -> Result<Vec<GithubReleaseDto>> {
123        let per_page = per_page.unwrap_or(30);
124        let mut page = 1;
125        let mut releases = Vec::new();
126
127        loop {
128            let batch = self
129                .get_releases_page(owner_repo, per_page, page)
130                .await
131                .context(format!("Failed to get releases page {}", page))?;
132            let partial_page = batch.len() < per_page as usize;
133
134            if batch.is_empty() {
135                break;
136            }
137
138            releases.extend(batch);
139
140            // Check if we've hit the total limit
141            if let Some(max) = max_total
142                && releases.len() >= max as usize
143            {
144                releases.truncate(max as usize);
145                break;
146            }
147
148            if partial_page {
149                break;
150            }
151
152            page += 1;
153        }
154
155        Ok(releases)
156    }
157
158    pub async fn get_releases_page(
159        &self,
160        owner_repo: &str,
161        per_page: u32,
162        page: u32,
163    ) -> Result<Vec<GithubReleaseDto>> {
164        let url = format!(
165            "https://api.github.com/repos/{}/releases?per_page={}&page={}",
166            owner_repo, per_page, page
167        );
168        self.get_json(&url)
169            .await
170            .context(format!("Failed to get releases page {}", page))
171    }
172
173    pub async fn get_branch_head_sha(&self, owner_repo: &str, branch: &str) -> Result<String> {
174        let encoded_branch = branch.replace('/', "%2F");
175        let url = format!(
176            "https://api.github.com/repos/{}/commits/{}",
177            owner_repo, encoded_branch
178        );
179        let dto: GithubCommitDto = self.get_json(&url).await.context(format!(
180            "Failed to get branch head for {}/{}",
181            owner_repo, branch
182        ))?;
183        Ok(dto.sha)
184    }
185
186    pub async fn search_repositories(
187        &self,
188        query: &str,
189        limit: Option<u32>,
190        filters: &RepositorySearchFilters,
191    ) -> Result<GithubRepositorySearchResponseDto> {
192        let per_page = limit.unwrap_or(10).clamp(1, 100);
193        let search_query = Self::build_repository_search_query(query, filters);
194        let mut url = reqwest::Url::parse("https://api.github.com/search/repositories")
195            .context("Failed to build GitHub search URL")?;
196        url.query_pairs_mut()
197            .append_pair("q", &search_query)
198            .append_pair("per_page", &per_page.to_string());
199
200        self.get_json(url.as_str()).await.context(format!(
201            "Failed to search repositories for '{}'",
202            search_query
203        ))
204    }
205
206    fn build_repository_search_query(query: &str, filters: &RepositorySearchFilters) -> String {
207        let mut parts = vec![query.trim().to_string()];
208
209        if let Some(language) = &filters.language {
210            parts.push(format!(
211                "language:{}",
212                Self::format_search_qualifier_value(language)
213            ));
214        }
215        if let Some(topic) = &filters.topic {
216            parts.push(format!(
217                "topic:{}",
218                Self::format_search_qualifier_value(topic)
219            ));
220        }
221        if let Some(min_stars) = filters.min_stars {
222            parts.push(format!("stars:>={min_stars}"));
223        }
224        if let Some(max_stars) = filters.max_stars {
225            parts.push(format!("stars:<={max_stars}"));
226        }
227        if let Some(pushed_after) = filters.pushed_after {
228            parts.push(format!("pushed:>={pushed_after}"));
229        }
230        if filters.include_forks {
231            parts.push("fork:true".to_string());
232        }
233        if !filters.include_archived {
234            parts.push("archived:false".to_string());
235        }
236
237        let parts = parts
238            .into_iter()
239            .filter(|part| !part.is_empty())
240            .collect::<Vec<_>>();
241        if parts.is_empty() {
242            return "stars:>=0".to_string();
243        }
244
245        parts.join(" ")
246    }
247
248    fn format_search_qualifier_value(value: &str) -> String {
249        if value.chars().any(char::is_whitespace) {
250            format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
251        } else {
252            value.to_string()
253        }
254    }
255}
256
#[cfg(test)]
mod tests {
    use chrono::NaiveDate;

    use crate::models::provider::RepositorySearchFilters;
    use crate::providers::github::github_dtos::{
        GithubReleaseDto, GithubRepositorySearchResponseDto,
    };
    use crate::providers::github::GithubClient;

    /// Shorthand for the query builder under test.
    fn search_query(query: &str, filters: &RepositorySearchFilters) -> String {
        GithubClient::build_repository_search_query(query, filters)
    }

    /// `null` string fields on a release and on its assets must deserialize as `""`.
    #[test]
    fn github_release_dto_accepts_nullable_string_fields() {
        let json = r#"
        {
          "id": 1,
          "tag_name": "v1.0.0",
          "name": null,
          "body": null,
          "prerelease": false,
          "draft": false,
          "published_at": null,
          "assets": [
            {
              "id": 42,
              "name": "tree-sitter-linux.tar.gz",
              "browser_download_url": "https://example.com/asset.tar.gz",
              "size": 1234,
              "content_type": null,
              "created_at": null
            }
          ]
        }
        "#;

        let release: GithubReleaseDto = serde_json::from_str(json).expect("valid release JSON");
        assert_eq!(release.name, "");
        assert_eq!(release.body, "");
        assert_eq!(release.published_at, "");

        let asset = &release.assets[0];
        assert_eq!(asset.content_type, "");
        assert_eq!(asset.created_at, "");
    }

    /// `null` string fields on a search hit must deserialize as `""`.
    #[test]
    fn github_search_dto_accepts_nullable_string_fields() {
        let json = r#"
        {
          "items": [
            {
              "full_name": "BurntSushi/ripgrep",
              "name": "ripgrep",
              "description": null,
              "stargazers_count": 10,
              "language": null,
              "updated_at": null,
              "archived": false,
              "fork": false
            }
          ]
        }
        "#;

        let response: GithubRepositorySearchResponseDto =
            serde_json::from_str(json).expect("valid search JSON");
        assert_eq!(response.items.len(), 1);

        let repo = &response.items[0];
        assert_eq!(repo.description, "");
        assert_eq!(repo.language, "");
        assert_eq!(repo.updated_at, "");
    }

    /// Every populated filter contributes its qualifier, in a fixed order.
    #[test]
    fn build_repository_search_query_adds_discovery_filters() {
        let pushed_after = NaiveDate::from_ymd_opt(2026, 1, 2).unwrap();
        let filters = RepositorySearchFilters::new(
            Some("Rust".to_string()),
            Some("cli".to_string()),
            Some(100),
            Some(50_000),
            Some(pushed_after),
            true,
            false,
        );

        let expected = "fast search language:Rust topic:cli stars:>=100 stars:<=50000 pushed:>=2026-01-02 fork:true archived:false";
        assert_eq!(search_query("fast search", &filters), expected);
    }

    /// Multi-word qualifier values are wrapped in double quotes.
    #[test]
    fn build_repository_search_query_quotes_multi_word_qualifier_values() {
        let filters = RepositorySearchFilters::new(
            Some("Common Lisp".to_string()),
            None,
            None,
            None,
            None,
            false,
            false,
        );

        let expected = r#"editor language:"Common Lisp" archived:false"#;
        assert_eq!(search_query("editor", &filters), expected);
    }

    /// With no query text and no active filters, a catch-all query is produced.
    #[test]
    fn build_repository_search_query_falls_back_when_query_and_filters_are_empty() {
        let filters = RepositorySearchFilters::new(None, None, None, None, None, false, true);

        let expected = "stars:>=0";
        assert_eq!(search_query("", &filters), expected);
    }
}