Skip to main content

upstream_rs/providers/gitlab/
gitlab_client.rs

1use anyhow::{Context, Result, bail};
2use reqwest::{Client, header};
3use serde::Deserialize;
4use std::path::Path;
5use tokio::fs::File;
6use tokio::io::AsyncWriteExt;
7
8use super::gitlab_dtos::GitlabReleaseDto;
9
10#[derive(Debug, Clone)]
11pub struct GitlabClient {
12    client: Client,
13    base_url: String,
14}
15
16impl GitlabClient {
17    pub fn new(token: Option<&str>, base_url: Option<&str>) -> Result<Self> {
18        let mut base = base_url.unwrap_or("https://gitlab.com").to_string();
19
20        if !base.starts_with("http://") && !base.starts_with("https://") {
21            base = format!("https://{}", base);
22        }
23
24        let mut headers = header::HeaderMap::new();
25        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
26        headers.insert(
27            header::USER_AGENT,
28            header::HeaderValue::from_str(&user_agent)
29                .context("Failed to create user agent header")?,
30        );
31
32        if let Some(token) = token {
33            headers.insert(
34                "PRIVATE-TOKEN",
35                header::HeaderValue::from_str(token)
36                    .context("Failed to create private token header")?,
37            );
38        }
39
40        let client = Client::builder()
41            .default_headers(headers)
42            .build()
43            .context("Failed to build HTTP client")?;
44
45        Ok(Self {
46            client,
47            base_url: base,
48        })
49    }
50
51    async fn get_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
52        let response = self
53            .client
54            .get(url)
55            .send()
56            .await
57            .context(format!("Failed to send request to {}", url))?;
58
59        response
60            .error_for_status_ref()
61            .context(format!("GitLab API returned error for {}", url))?;
62
63        let data = response
64            .json::<T>()
65            .await
66            .context("Failed to parse JSON response")?;
67
68        Ok(data)
69    }
70
71    pub async fn download_file<F>(
72        &self,
73        url: &str,
74        destination: &Path,
75        progress: &mut Option<F>,
76    ) -> Result<()>
77    where
78        F: FnMut(u64, u64),
79    {
80        let response = self
81            .client
82            .get(url)
83            .send()
84            .await
85            .context(format!("Failed to download from {}", url))?;
86
87        response
88            .error_for_status_ref()
89            .context("Download request failed")?;
90
91        let total_bytes = response.content_length().unwrap_or(0);
92        let mut file = File::create(destination)
93            .await
94            .context(format!("Failed to create file at {:?}", destination))?;
95
96        let mut stream = response.bytes_stream();
97        let mut total_read: u64 = 0;
98
99        use futures_util::StreamExt;
100        while let Some(chunk) = stream.next().await {
101            let chunk = chunk.context("Failed to read download chunk")?;
102            file.write_all(&chunk)
103                .await
104                .context("Failed to write to file")?;
105
106            total_read += chunk.len() as u64;
107            if let Some(cb) = progress.as_mut() {
108                cb(total_read, total_bytes);
109            }
110        }
111
112        file.flush().await.context("Failed to flush file")?;
113
114        if total_bytes > 0 && total_read != total_bytes {
115            bail!(
116                "Download size mismatch: expected {} bytes, got {} bytes",
117                total_bytes,
118                total_read
119            );
120        }
121
122        Ok(())
123    }
124
125    fn encode_project_path(project_path: &str) -> String {
126        project_path.replace('/', "%2F")
127    }
128
129    pub async fn get_release_by_tag(
130        &self,
131        project_path: &str,
132        tag: &str,
133    ) -> Result<GitlabReleaseDto> {
134        let encoded_path = Self::encode_project_path(project_path);
135        let url = format!(
136            "{}/api/v4/projects/{}/releases/{}",
137            self.base_url, encoded_path, tag
138        );
139        self.get_json(&url)
140            .await
141            .context(format!("Failed to get release for tag {}", tag))
142    }
143
144    pub async fn get_releases(
145        &self,
146        project_path: &str,
147        per_page: Option<u32>,
148        max_total: Option<u32>,
149    ) -> Result<Vec<GitlabReleaseDto>> {
150        let per_page = per_page.unwrap_or(20).min(100);
151        let encoded_path = Self::encode_project_path(project_path);
152        let mut page = 1;
153        let mut releases = Vec::new();
154
155        loop {
156            let url = format!(
157                "{}/api/v4/projects/{}/releases?per_page={}&page={}",
158                self.base_url, encoded_path, per_page, page
159            );
160            let batch: Vec<GitlabReleaseDto> = self
161                .get_json(&url)
162                .await
163                .context(format!("Failed to get releases page {}", page))?;
164
165            if batch.is_empty() {
166                break;
167            }
168
169            releases.extend(batch);
170
171            if let Some(max) = max_total
172                && releases.len() >= max as usize
173            {
174                releases.truncate(max as usize);
175                break;
176            }
177
178            if releases.len() % per_page as usize != 0 {
179                break;
180            }
181
182            page += 1;
183        }
184
185        Ok(releases)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::GitlabClient;
    use crate::providers::gitlab::gitlab_dtos::GitlabReleaseDto;

    /// A bare host name (no scheme) is upgraded to `https://`.
    #[test]
    fn new_normalizes_base_url_without_scheme() {
        let built = GitlabClient::new(None, Some("gitlab.example.com"));
        let client = built.expect("client");
        assert_eq!(client.base_url, "https://gitlab.example.com");
    }

    /// Every `/` of a nested group path is turned into `%2F`.
    #[test]
    fn encode_project_path_percent_encodes_slashes() {
        let encoded = GitlabClient::encode_project_path("group/subgroup/project");
        assert_eq!(encoded, "group%2Fsubgroup%2Fproject");
    }

    /// A small release payload with an empty asset list deserializes.
    #[test]
    fn gitlab_release_dto_deserializes_minimal_valid_payload() {
        let json = r#"
            {
              "tag_name": "v1.0.0",
              "name": "v1.0.0",
              "description": "notes",
              "created_at": "2026-02-21T00:00:00Z",
              "released_at": null,
              "upcoming_release": false,
              "assets": { "count": 0, "sources": [], "links": [] }
            }
            "#;

        let parsed: GitlabReleaseDto = serde_json::from_str(json).expect("parse release");
        assert_eq!(parsed.tag_name, "v1.0.0");

        let assets = &parsed.assets;
        assert_eq!(assets.count, 0);
        assert!(assets.links.is_empty());
    }
}