Skip to main content

upstream_rs/providers/github/
github_client.rs

1use anyhow::{Context, Result};
2use reqwest::{Client, header};
3use serde::Deserialize;
4use std::path::Path;
5
6use crate::{models::provider::RepositorySearchFilters, providers::download_handler};
7
8use super::github_dtos::{GithubReleaseDto, GithubRepositorySearchResponseDto};
9#[derive(Debug, Deserialize)]
10struct GithubCommitDto {
11    sha: String,
12}
13
14#[derive(Debug, Clone)]
15pub struct GithubClient {
16    client: Client,
17}
18
19impl GithubClient {
20    pub fn new(token: Option<&str>) -> Result<Self> {
21        let mut headers = header::HeaderMap::new();
22
23        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
24        headers.insert(
25            header::USER_AGENT,
26            header::HeaderValue::from_str(&user_agent)
27                .context("Failed to create user agent header")?,
28        );
29
30        if let Some(token) = token {
31            let auth_value = format!("Bearer {}", token);
32            headers.insert(
33                header::AUTHORIZATION,
34                header::HeaderValue::from_str(&auth_value)
35                    .context("Failed to create authorization header")?,
36            );
37        }
38
39        let client = Client::builder()
40            .default_headers(headers)
41            .build()
42            .context("Failed to build HTTP client")?;
43
44        Ok(Self { client })
45    }
46
47    async fn get_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
48        let response = self
49            .client
50            .get(url)
51            .send()
52            .await
53            .context(format!("Failed to send request to {}", url))?;
54
55        response
56            .error_for_status_ref()
57            .context(format!("GitHub API returned error for {}", url))?;
58
59        let data = response
60            .json::<T>()
61            .await
62            .context("Failed to parse JSON response")?;
63
64        Ok(data)
65    }
66
67    pub async fn download_file<F>(
68        &self,
69        url: &str,
70        destination: &Path,
71        progress: &mut Option<F>,
72    ) -> Result<()>
73    where
74        F: FnMut(u64, u64),
75    {
76        download_handler::download_file(&self.client, url, destination, progress).await
77    }
78
79    pub async fn get_release_by_tag(
80        &self,
81        owner_repo: &str,
82        tag: &str,
83    ) -> Result<GithubReleaseDto> {
84        let url = format!(
85            "https://api.github.com/repos/{}/releases/tags/{}",
86            owner_repo, tag
87        );
88        self.get_json(&url)
89            .await
90            .context(format!("Failed to get release for tag {}", tag))
91    }
92
93    pub async fn get_latest_release(&self, owner_repo: &str) -> Result<GithubReleaseDto> {
94        let url = format!(
95            "https://api.github.com/repos/{}/releases/latest",
96            owner_repo
97        );
98        self.get_json(&url)
99            .await
100            .context(format!("Failed to get latest release for {}", owner_repo))
101    }
102
103    pub async fn get_releases(
104        &self,
105        owner_repo: &str,
106        per_page: Option<u32>,
107        max_total: Option<u32>,
108    ) -> Result<Vec<GithubReleaseDto>> {
109        let per_page = per_page.unwrap_or(30);
110        let mut page = 1;
111        let mut releases = Vec::new();
112
113        loop {
114            let url = format!(
115                "https://api.github.com/repos/{}/releases?per_page={}&page={}",
116                owner_repo, per_page, page
117            );
118            let batch: Vec<GithubReleaseDto> = self
119                .get_json(&url)
120                .await
121                .context(format!("Failed to get releases page {}", page))?;
122
123            if batch.is_empty() {
124                break;
125            }
126
127            releases.extend(batch);
128
129            // Check if we've hit the total limit
130            if let Some(max) = max_total
131                && releases.len() >= max as usize
132            {
133                releases.truncate(max as usize);
134                break;
135            }
136
137            // Check if this was a partial page (last page)
138            if releases.len() % per_page as usize != 0 {
139                break;
140            }
141
142            page += 1;
143        }
144
145        Ok(releases)
146    }
147
148    pub async fn get_branch_head_sha(&self, owner_repo: &str, branch: &str) -> Result<String> {
149        let encoded_branch = branch.replace('/', "%2F");
150        let url = format!(
151            "https://api.github.com/repos/{}/commits/{}",
152            owner_repo, encoded_branch
153        );
154        let dto: GithubCommitDto = self.get_json(&url).await.context(format!(
155            "Failed to get branch head for {}/{}",
156            owner_repo, branch
157        ))?;
158        Ok(dto.sha)
159    }
160
161    pub async fn search_repositories(
162        &self,
163        query: &str,
164        limit: Option<u32>,
165        filters: &RepositorySearchFilters,
166    ) -> Result<GithubRepositorySearchResponseDto> {
167        let per_page = limit.unwrap_or(10).clamp(1, 100);
168        let search_query = Self::build_repository_search_query(query, filters);
169        let mut url = reqwest::Url::parse("https://api.github.com/search/repositories")
170            .context("Failed to build GitHub search URL")?;
171        url.query_pairs_mut()
172            .append_pair("q", &search_query)
173            .append_pair("per_page", &per_page.to_string());
174
175        self.get_json(url.as_str()).await.context(format!(
176            "Failed to search repositories for '{}'",
177            search_query
178        ))
179    }
180
181    fn build_repository_search_query(query: &str, filters: &RepositorySearchFilters) -> String {
182        let mut parts = vec![query.trim().to_string()];
183
184        if let Some(language) = &filters.language {
185            parts.push(format!(
186                "language:{}",
187                Self::format_search_qualifier_value(language)
188            ));
189        }
190        if let Some(topic) = &filters.topic {
191            parts.push(format!(
192                "topic:{}",
193                Self::format_search_qualifier_value(topic)
194            ));
195        }
196        if let Some(min_stars) = filters.min_stars {
197            parts.push(format!("stars:>={min_stars}"));
198        }
199        if let Some(max_stars) = filters.max_stars {
200            parts.push(format!("stars:<={max_stars}"));
201        }
202        if let Some(pushed_after) = filters.pushed_after {
203            parts.push(format!("pushed:>={pushed_after}"));
204        }
205        if filters.include_forks {
206            parts.push("fork:true".to_string());
207        }
208        if !filters.include_archived {
209            parts.push("archived:false".to_string());
210        }
211
212        let parts = parts
213            .into_iter()
214            .filter(|part| !part.is_empty())
215            .collect::<Vec<_>>();
216        if parts.is_empty() {
217            return "stars:>=0".to_string();
218        }
219
220        parts.join(" ")
221    }
222
223    fn format_search_qualifier_value(value: &str) -> String {
224        if value.chars().any(char::is_whitespace) {
225            format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
226        } else {
227            value.to_string()
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use chrono::NaiveDate;

    use crate::models::provider::RepositorySearchFilters;
    use crate::providers::github::GithubClient;
    use crate::providers::github::github_dtos::{
        GithubReleaseDto, GithubRepositorySearchResponseDto,
    };

    /// `null` string fields on a release and on its assets deserialize to empty strings.
    #[test]
    fn github_release_dto_accepts_nullable_string_fields() {
        let json = r#"
        {
          "id": 1,
          "tag_name": "v1.0.0",
          "name": null,
          "body": null,
          "prerelease": false,
          "draft": false,
          "published_at": null,
          "assets": [
            {
              "id": 42,
              "name": "tree-sitter-linux.tar.gz",
              "browser_download_url": "https://example.com/asset.tar.gz",
              "size": 1234,
              "content_type": null,
              "created_at": null
            }
          ]
        }
        "#;

        let release: GithubReleaseDto = serde_json::from_str(json).expect("valid release JSON");
        let first_asset = &release.assets[0];

        assert_eq!(release.name, "");
        assert_eq!(release.body, "");
        assert_eq!(release.published_at, "");
        assert_eq!(first_asset.content_type, "");
        assert_eq!(first_asset.created_at, "");
    }

    /// `null` string fields on a search result item deserialize to empty strings.
    #[test]
    fn github_search_dto_accepts_nullable_string_fields() {
        let json = r#"
        {
          "items": [
            {
              "full_name": "BurntSushi/ripgrep",
              "name": "ripgrep",
              "description": null,
              "stargazers_count": 10,
              "language": null,
              "updated_at": null,
              "archived": false,
              "fork": false
            }
          ]
        }
        "#;

        let response: GithubRepositorySearchResponseDto =
            serde_json::from_str(json).expect("valid search JSON");
        assert_eq!(response.items.len(), 1);

        let repo = &response.items[0];
        assert_eq!(repo.description, "");
        assert_eq!(repo.language, "");
        assert_eq!(repo.updated_at, "");
    }

    /// Every populated filter is rendered as a qualifier after the free-text query.
    #[test]
    fn build_repository_search_query_adds_discovery_filters() {
        let pushed_after = NaiveDate::from_ymd_opt(2026, 1, 2).expect("valid calendar date");
        let filters = RepositorySearchFilters::new(
            Some("Rust".to_string()),
            Some("cli".to_string()),
            Some(100),
            Some(50_000),
            Some(pushed_after),
            true,
            false,
        );

        let query = GithubClient::build_repository_search_query("fast search", &filters);

        assert_eq!(
            query,
            "fast search language:Rust topic:cli stars:>=100 stars:<=50000 pushed:>=2026-01-02 fork:true archived:false"
        );
    }

    /// A qualifier value containing whitespace is wrapped in double quotes.
    #[test]
    fn build_repository_search_query_quotes_multi_word_qualifier_values() {
        let filters = RepositorySearchFilters::new(
            Some("Common Lisp".to_string()),
            None,
            None,
            None,
            None,
            false,
            false,
        );

        let query = GithubClient::build_repository_search_query("editor", &filters);

        assert_eq!(query, "editor language:\"Common Lisp\" archived:false");
    }

    /// With no query text and no active filters, the match-everything fallback is used.
    #[test]
    fn build_repository_search_query_falls_back_when_query_and_filters_are_empty() {
        let filters = RepositorySearchFilters::new(None, None, None, None, None, false, true);

        let query = GithubClient::build_repository_search_query("", &filters);

        assert_eq!(query, "stars:>=0");
    }
}