Skip to main content

upstream_rs/providers/github/
github_client.rs

1use anyhow::{Context, Result};
2use reqwest::{Client, header};
3use serde::Deserialize;
4use std::path::Path;
5
6use crate::{
7    models::{provider::RepositorySearchFilters, upstream::DownloadConfig},
8    providers::download_handler,
9};
10
11use super::github_dtos::{GithubReleaseDto, GithubRepositorySearchResponseDto};
12#[derive(Debug, Deserialize)]
13struct GithubCommitDto {
14    sha: String,
15}
16
17#[derive(Debug, Clone)]
18pub struct GithubClient {
19    client: Client,
20    download_config: DownloadConfig,
21}
22
23impl GithubClient {
24    pub fn new(token: Option<&str>) -> Result<Self> {
25        Self::new_with_download_config(token, DownloadConfig::default())
26    }
27
28    pub fn new_with_download_config(
29        token: Option<&str>,
30        download_config: DownloadConfig,
31    ) -> Result<Self> {
32        let mut headers = header::HeaderMap::new();
33
34        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
35        headers.insert(
36            header::USER_AGENT,
37            header::HeaderValue::from_str(&user_agent)
38                .context("Failed to create user agent header")?,
39        );
40
41        if let Some(token) = token {
42            let auth_value = format!("Bearer {}", token);
43            headers.insert(
44                header::AUTHORIZATION,
45                header::HeaderValue::from_str(&auth_value)
46                    .context("Failed to create authorization header")?,
47            );
48        }
49
50        let client = Client::builder()
51            .default_headers(headers)
52            .build()
53            .context("Failed to build HTTP client")?;
54
55        Ok(Self {
56            client,
57            download_config,
58        })
59    }
60
61    async fn get_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
62        let response = self
63            .client
64            .get(url)
65            .send()
66            .await
67            .context(format!("Failed to send request to {}", url))?;
68
69        response
70            .error_for_status_ref()
71            .context(format!("GitHub API returned error for {}", url))?;
72
73        let data = response
74            .json::<T>()
75            .await
76            .context("Failed to parse JSON response")?;
77
78        Ok(data)
79    }
80
81    pub async fn download_file<F>(
82        &self,
83        url: &str,
84        destination: &Path,
85        progress: &mut Option<F>,
86    ) -> Result<()>
87    where
88        F: FnMut(u64, u64),
89    {
90        download_handler::download_file_with_config(
91            &self.client,
92            url,
93            destination,
94            progress,
95            self.download_config,
96        )
97        .await
98    }
99
100    pub async fn get_release_by_tag(
101        &self,
102        owner_repo: &str,
103        tag: &str,
104    ) -> Result<GithubReleaseDto> {
105        let url = format!(
106            "https://api.github.com/repos/{}/releases/tags/{}",
107            owner_repo, tag
108        );
109        self.get_json(&url)
110            .await
111            .context(format!("Failed to get release for tag {}", tag))
112    }
113
114    pub async fn get_latest_release(&self, owner_repo: &str) -> Result<GithubReleaseDto> {
115        let url = format!(
116            "https://api.github.com/repos/{}/releases/latest",
117            owner_repo
118        );
119        self.get_json(&url)
120            .await
121            .context(format!("Failed to get latest release for {}", owner_repo))
122    }
123
124    pub async fn get_releases(
125        &self,
126        owner_repo: &str,
127        per_page: Option<u32>,
128        max_total: Option<u32>,
129    ) -> Result<Vec<GithubReleaseDto>> {
130        let per_page = per_page.unwrap_or(30);
131        let mut page = 1;
132        let mut releases = Vec::new();
133
134        loop {
135            let batch = self
136                .get_releases_page(owner_repo, per_page, page)
137                .await
138                .context(format!("Failed to get releases page {}", page))?;
139            let partial_page = batch.len() < per_page as usize;
140
141            if batch.is_empty() {
142                break;
143            }
144
145            releases.extend(batch);
146
147            // Check if we've hit the total limit
148            if let Some(max) = max_total
149                && releases.len() >= max as usize
150            {
151                releases.truncate(max as usize);
152                break;
153            }
154
155            if partial_page {
156                break;
157            }
158
159            page += 1;
160        }
161
162        Ok(releases)
163    }
164
165    pub async fn get_releases_page(
166        &self,
167        owner_repo: &str,
168        per_page: u32,
169        page: u32,
170    ) -> Result<Vec<GithubReleaseDto>> {
171        let url = format!(
172            "https://api.github.com/repos/{}/releases?per_page={}&page={}",
173            owner_repo, per_page, page
174        );
175        self.get_json(&url)
176            .await
177            .context(format!("Failed to get releases page {}", page))
178    }
179
180    pub async fn get_branch_head_sha(&self, owner_repo: &str, branch: &str) -> Result<String> {
181        let encoded_branch = branch.replace('/', "%2F");
182        let url = format!(
183            "https://api.github.com/repos/{}/commits/{}",
184            owner_repo, encoded_branch
185        );
186        let dto: GithubCommitDto = self.get_json(&url).await.context(format!(
187            "Failed to get branch head for {}/{}",
188            owner_repo, branch
189        ))?;
190        Ok(dto.sha)
191    }
192
193    pub async fn search_repositories(
194        &self,
195        query: &str,
196        limit: Option<u32>,
197        filters: &RepositorySearchFilters,
198    ) -> Result<GithubRepositorySearchResponseDto> {
199        let per_page = limit.unwrap_or(10).clamp(1, 100);
200        let search_query = Self::build_repository_search_query(query, filters);
201        let mut url = reqwest::Url::parse("https://api.github.com/search/repositories")
202            .context("Failed to build GitHub search URL")?;
203        url.query_pairs_mut()
204            .append_pair("q", &search_query)
205            .append_pair("per_page", &per_page.to_string());
206
207        self.get_json(url.as_str()).await.context(format!(
208            "Failed to search repositories for '{}'",
209            search_query
210        ))
211    }
212
213    fn build_repository_search_query(query: &str, filters: &RepositorySearchFilters) -> String {
214        let mut parts = vec![query.trim().to_string()];
215
216        if let Some(language) = &filters.language {
217            parts.push(format!(
218                "language:{}",
219                Self::format_search_qualifier_value(language)
220            ));
221        }
222        if let Some(topic) = &filters.topic {
223            parts.push(format!(
224                "topic:{}",
225                Self::format_search_qualifier_value(topic)
226            ));
227        }
228        if let Some(min_stars) = filters.min_stars {
229            parts.push(format!("stars:>={min_stars}"));
230        }
231        if let Some(max_stars) = filters.max_stars {
232            parts.push(format!("stars:<={max_stars}"));
233        }
234        if let Some(pushed_after) = filters.pushed_after {
235            parts.push(format!("pushed:>={pushed_after}"));
236        }
237        if filters.include_forks {
238            parts.push("fork:true".to_string());
239        }
240        if !filters.include_archived {
241            parts.push("archived:false".to_string());
242        }
243
244        let parts = parts
245            .into_iter()
246            .filter(|part| !part.is_empty())
247            .collect::<Vec<_>>();
248        if parts.is_empty() {
249            return "stars:>=0".to_string();
250        }
251
252        parts.join(" ")
253    }
254
255    fn format_search_qualifier_value(value: &str) -> String {
256        if value.chars().any(char::is_whitespace) {
257            format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
258        } else {
259            value.to_string()
260        }
261    }
262}
263
264#[cfg(test)]
265mod tests {
266    use chrono::NaiveDate;
267
268    use crate::models::provider::RepositorySearchFilters;
269    use crate::providers::github::GithubClient;
270    use crate::providers::github::github_dtos::{
271        GithubReleaseDto, GithubRepositorySearchResponseDto,
272    };
273
274    #[test]
275    fn github_release_dto_accepts_nullable_string_fields() {
276        let json = r#"
277        {
278          "id": 1,
279          "tag_name": "v1.0.0",
280          "name": null,
281          "body": null,
282          "prerelease": false,
283          "draft": false,
284          "published_at": null,
285          "assets": [
286            {
287              "id": 42,
288              "name": "tree-sitter-linux.tar.gz",
289              "browser_download_url": "https://example.com/asset.tar.gz",
290              "size": 1234,
291              "content_type": null,
292              "created_at": null
293            }
294          ]
295        }
296        "#;
297
298        let parsed = serde_json::from_str::<GithubReleaseDto>(json).expect("valid release JSON");
299        assert_eq!(parsed.name, "");
300        assert_eq!(parsed.body, "");
301        assert_eq!(parsed.published_at, "");
302        assert_eq!(parsed.assets[0].content_type, "");
303        assert_eq!(parsed.assets[0].created_at, "");
304    }
305
306    #[test]
307    fn github_search_dto_accepts_nullable_string_fields() {
308        let json = r#"
309        {
310          "items": [
311            {
312              "full_name": "BurntSushi/ripgrep",
313              "name": "ripgrep",
314              "description": null,
315              "stargazers_count": 10,
316              "language": null,
317              "updated_at": null,
318              "archived": false,
319              "fork": false
320            }
321          ]
322        }
323        "#;
324
325        let parsed = serde_json::from_str::<GithubRepositorySearchResponseDto>(json)
326            .expect("valid search JSON");
327        assert_eq!(parsed.items.len(), 1);
328        assert_eq!(parsed.items[0].description, "");
329        assert_eq!(parsed.items[0].language, "");
330        assert_eq!(parsed.items[0].updated_at, "");
331    }
332
333    #[test]
334    fn build_repository_search_query_adds_discovery_filters() {
335        let filters = RepositorySearchFilters::new(
336            Some("Rust".to_string()),
337            Some("cli".to_string()),
338            Some(100),
339            Some(50_000),
340            Some(NaiveDate::from_ymd_opt(2026, 1, 2).unwrap()),
341            true,
342            false,
343        );
344
345        assert_eq!(
346            GithubClient::build_repository_search_query("fast search", &filters),
347            "fast search language:Rust topic:cli stars:>=100 stars:<=50000 pushed:>=2026-01-02 fork:true archived:false"
348        );
349    }
350
351    #[test]
352    fn build_repository_search_query_quotes_multi_word_qualifier_values() {
353        let filters = RepositorySearchFilters::new(
354            Some("Common Lisp".to_string()),
355            None,
356            None,
357            None,
358            None,
359            false,
360            false,
361        );
362
363        assert_eq!(
364            GithubClient::build_repository_search_query("editor", &filters),
365            "editor language:\"Common Lisp\" archived:false"
366        );
367    }
368
369    #[test]
370    fn build_repository_search_query_falls_back_when_query_and_filters_are_empty() {
371        let filters = RepositorySearchFilters::new(None, None, None, None, None, false, true);
372
373        assert_eq!(
374            GithubClient::build_repository_search_query("", &filters),
375            "stars:>=0"
376        );
377    }
378}