// upstream_rs/providers/discovery.rs

1use anyhow::{Result, anyhow};
2use reqwest::Url;
3
4use crate::{
5    models::{
6        common::enums::{Channel, Filetype, Provider},
7        provider::Release,
8        upstream::Package,
9    },
10    providers::{asset_selector::AssetCandidate, provider_manager::ProviderManager},
11    utils::filename_parser::parse_filetype,
12};
13
/// Classification assigned to a user-supplied source string during inference.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SourceKind {
    /// A bare `owner/repo` slug with no URL around it.
    Repository,
    /// A URL on a recognised forge host from which a repository slug was extracted.
    ForgeUrl,
    /// A URL whose final path segment parses as a concrete downloadable file type.
    DirectAsset,
    /// Any other input; paired with the web scraper provider.
    DownloadPage,
}
21
22#[derive(Debug, Clone)]
23pub struct DiscoveredSource {
24    pub original: String,
25    pub repo_slug: String,
26    pub provider: Provider,
27    pub base_url: Option<String>,
28    pub kind: SourceKind,
29}
30
31#[derive(Debug, Clone)]
32pub struct DiscoveryResult {
33    pub source: DiscoveredSource,
34    pub releases: Vec<Release>,
35    pub candidates: Vec<AssetCandidate>,
36}
37
38#[derive(Debug, Clone)]
39pub struct DiscoveryRequest {
40    pub source: String,
41    pub channel: Channel,
42    pub package_name: String,
43    pub filetype: Filetype,
44    pub match_pattern: Option<String>,
45    pub exclude_pattern: Option<String>,
46    pub base_url_override: Option<String>,
47    pub limit: u32,
48}
49
50impl DiscoveryResult {
51    pub fn recommended_candidate(&self) -> Option<&AssetCandidate> {
52        self.candidates.first()
53    }
54
55    pub fn is_ambiguous(&self) -> bool {
56        let Some(top) = self.candidates.first() else {
57            return false;
58        };
59        let Some(next) = self.candidates.get(1) else {
60            return false;
61        };
62
63        next.score >= top.score.saturating_sub(20)
64    }
65}
66
67impl ProviderManager {
68    pub async fn discover_source(&self, request: DiscoveryRequest) -> Result<DiscoveryResult> {
69        let mut discovered = infer_source(&request.source)?;
70        if let Some(base_url) = request.base_url_override.as_deref() {
71            discovered.base_url = Some(base_url.to_string());
72        }
73
74        let mut releases = self
75            .get_releases(
76                &discovered.repo_slug,
77                &discovered.provider,
78                Some(request.limit),
79                Some(request.limit),
80                discovered.base_url.as_deref(),
81            )
82            .await?;
83
84        releases = filter_releases_by_channel(releases, &request.channel);
85        releases.sort_by(|a, b| b.version.cmp(&a.version));
86
87        let probe_package = Package::with_defaults(
88            request.package_name,
89            discovered.repo_slug.clone(),
90            request.filetype,
91            request.match_pattern,
92            request.exclude_pattern,
93            request.channel,
94            discovered.provider.clone(),
95            discovered.base_url.clone(),
96        );
97
98        let candidates = releases
99            .first()
100            .map(|release| self.get_candidate_assets(release, &probe_package))
101            .transpose()?
102            .unwrap_or_default();
103
104        Ok(DiscoveryResult {
105            source: discovered,
106            releases,
107            candidates,
108        })
109    }
110}
111
112pub fn infer_source(source: &str) -> Result<DiscoveredSource> {
113    let original = source.trim().to_string();
114    if original.is_empty() {
115        return Err(anyhow!("Source cannot be empty"));
116    }
117
118    if let Ok(url) = Url::parse(&original) {
119        return infer_url_source(&original, &url);
120    }
121
122    if looks_like_repo_slug(&original) {
123        return Ok(DiscoveredSource {
124            original,
125            repo_slug: source.trim_matches('/').to_string(),
126            provider: Provider::Github,
127            base_url: None,
128            kind: SourceKind::Repository,
129        });
130    }
131
132    Ok(DiscoveredSource {
133        original: original.clone(),
134        repo_slug: original,
135        provider: Provider::WebScraper,
136        base_url: None,
137        kind: SourceKind::DownloadPage,
138    })
139}
140
141pub fn normalize_source_for_provider(
142    source: &str,
143    provider: &Provider,
144    base_url: Option<&str>,
145) -> String {
146    let trimmed = source.trim();
147    if trimmed.is_empty() {
148        return String::new();
149    }
150
151    let Ok(url) = Url::parse(trimmed) else {
152        return trimmed.trim_matches('/').to_string();
153    };
154
155    let host = url.host_str().unwrap_or("").to_lowercase();
156    let segments: Vec<&str> = url
157        .path_segments()
158        .map(|parts| parts.filter(|part| !part.is_empty()).collect())
159        .unwrap_or_default();
160
161    match provider {
162        Provider::Github => {
163            if (host == "github.com" || host == "www.github.com")
164                && let Some(slug) = owner_repo_slug(&segments)
165            {
166                return slug;
167            }
168        }
169        Provider::Gitlab => {
170            if is_gitlab_host(&host)
171                && let Some(slug) = gitlab_slug(&segments)
172            {
173                return slug;
174            }
175            if let Some(base) = base_url
176                && let Ok(base_url_parsed) = Url::parse(base)
177                && same_host(&url, &base_url_parsed)
178                && let Some(slug) = gitlab_slug(&segments)
179            {
180                return slug;
181            }
182        }
183        Provider::Gitea => {
184            if (host == "gitea.com"
185                || host == "www.gitea.com"
186                || host == "codeberg.org"
187                || host == "www.codeberg.org")
188                && let Some(slug) = owner_repo_slug(&segments)
189            {
190                return slug;
191            }
192            if let Some(base) = base_url
193                && let Ok(base_url_parsed) = Url::parse(base)
194                && same_host(&url, &base_url_parsed)
195                && let Some(slug) = owner_repo_slug(&segments)
196            {
197                return slug;
198            }
199        }
200        Provider::Direct | Provider::WebScraper => {}
201    }
202
203    trimmed.to_string()
204}
205
206pub fn infer_package_name(
207    source: &str,
208    provider: Option<&Provider>,
209    base_url: Option<&str>,
210) -> Result<Option<String>> {
211    let source_info = if let Some(provider) = provider {
212        match provider {
213            Provider::Github | Provider::Gitlab | Provider::Gitea => {
214                let repo_slug = normalize_source_for_provider(source, provider, base_url);
215                return Ok(repo_name_from_slug(&repo_slug).map(str::to_string));
216            }
217            Provider::Direct | Provider::WebScraper => return Ok(None),
218        }
219    } else {
220        infer_source(source)?
221    };
222
223    if matches!(
224        source_info.kind,
225        SourceKind::Repository | SourceKind::ForgeUrl
226    ) && matches!(
227        source_info.provider,
228        Provider::Github | Provider::Gitlab | Provider::Gitea
229    ) {
230        return Ok(repo_name_from_slug(&source_info.repo_slug).map(str::to_string));
231    }
232
233    Ok(None)
234}
235
/// Returns the final path component of a slug with any trailing `.git` removed.
///
/// Leading and trailing `/` are ignored. Yields `None` when nothing is left after stripping.
fn repo_name_from_slug(repo_slug: &str) -> Option<&str> {
    let slug = repo_slug.trim_matches('/');
    let last_component = slug.rsplit_once('/').map_or(slug, |(_, tail)| tail);
    let name = last_component.trim_end_matches(".git");

    if name.is_empty() { None } else { Some(name) }
}
244
245fn infer_url_source(original: &str, url: &Url) -> Result<DiscoveredSource> {
246    let host = url.host_str().unwrap_or("").to_lowercase();
247    let segments: Vec<&str> = url
248        .path_segments()
249        .map(|parts| parts.filter(|part| !part.is_empty()).collect())
250        .unwrap_or_default();
251
252    if (host == "github.com" || host == "www.github.com")
253        && let Some(slug) = owner_repo_slug(&segments)
254    {
255        return Ok(DiscoveredSource {
256            original: original.to_string(),
257            repo_slug: slug,
258            provider: Provider::Github,
259            base_url: None,
260            kind: SourceKind::ForgeUrl,
261        });
262    }
263
264    if is_gitlab_host(&host)
265        && let Some(slug) = gitlab_slug(&segments)
266    {
267        return Ok(DiscoveredSource {
268            original: original.to_string(),
269            repo_slug: slug,
270            provider: Provider::Gitlab,
271            base_url: gitlab_base_url(url, &host),
272            kind: SourceKind::ForgeUrl,
273        });
274    }
275
276    if (host == "gitea.com"
277        || host == "www.gitea.com"
278        || host == "codeberg.org"
279        || host == "www.codeberg.org")
280        && let Some(slug) = owner_repo_slug(&segments)
281    {
282        return Ok(DiscoveredSource {
283            original: original.to_string(),
284            repo_slug: slug,
285            provider: Provider::Gitea,
286            base_url: Some(format!("{}://{}", url.scheme(), host)),
287            kind: SourceKind::ForgeUrl,
288        });
289    }
290
291    if is_direct_asset_url(url) {
292        return Ok(DiscoveredSource {
293            original: original.to_string(),
294            repo_slug: original.to_string(),
295            provider: Provider::Direct,
296            base_url: None,
297            kind: SourceKind::DirectAsset,
298        });
299    }
300
301    Ok(DiscoveredSource {
302        original: original.to_string(),
303        repo_slug: original.to_string(),
304        provider: Provider::WebScraper,
305        base_url: None,
306        kind: SourceKind::DownloadPage,
307    })
308}
309
/// Checks whether `value` has the shape `owner/repo`: exactly two `/`-separated parts, each
/// non-empty and free of whitespace.
fn looks_like_repo_slug(value: &str) -> bool {
    let well_formed = |piece: &str| !piece.is_empty() && !piece.chars().any(char::is_whitespace);

    let mut pieces = value.split('/');
    match (pieces.next(), pieces.next(), pieces.next()) {
        // A third piece would mean more than one separator.
        (Some(owner), Some(repo), None) => well_formed(owner) && well_formed(repo),
        _ => false,
    }
}
317
/// Builds an `owner/repo` slug from the first two path segments.
///
/// Any trailing `.git` on the repository segment is removed and further segments are ignored.
/// Returns `None` when there are fewer than two segments or either part ends up empty.
fn owner_repo_slug(segments: &[&str]) -> Option<String> {
    let [owner, repo, ..] = segments else {
        return None;
    };

    let repo = repo.trim_end_matches(".git");
    if owner.is_empty() || repo.is_empty() {
        None
    } else {
        Some(format!("{owner}/{repo}"))
    }
}
331
/// Extracts a GitLab project path (`group[/subgroup...]/project`) from URL path segments.
///
/// Segments are taken up to, but not including, the first stop marker (`-`, `releases`,
/// `downloads` or `packages`); everything before it is treated as the project path, so nested
/// groups are preserved. A trailing `.git` on the final segment is removed so that clone URLs
/// (`https://gitlab.com/group/project.git`) resolve to the same slug as browser URLs, matching
/// how `owner/repo` slugs are handled for the other forges.
///
/// Returns `None` when fewer than two segments precede the stop marker, or when removing the
/// `.git` suffix leaves an empty project name.
fn gitlab_slug(segments: &[&str]) -> Option<String> {
    const STOP_MARKERS: [&str; 4] = ["-", "releases", "downloads", "packages"];

    let mut parts: Vec<&str> = segments
        .iter()
        .copied()
        .take_while(|segment| !STOP_MARKERS.contains(segment))
        .collect();

    if parts.len() < 2 {
        return None;
    }

    if let Some(project) = parts.last_mut() {
        let raw = *project;
        // Strip at most one suffix: only the clone-URL decoration, never part of the name.
        *project = raw.strip_suffix(".git").unwrap_or(raw);
        if project.is_empty() {
            return None;
        }
    }

    Some(parts.join("/"))
}
346
/// Heuristically decides whether `host` is a GitLab instance.
///
/// After dropping one leading `www.`, any host that begins with `gitlab.` qualifies; this
/// covers `gitlab.com` itself as well as self-hosted names such as `gitlab.example.org`.
fn is_gitlab_host(host: &str) -> bool {
    let bare = match host.strip_prefix("www.") {
        Some(rest) => rest,
        None => host,
    };

    bare.starts_with("gitlab.")
}
351
352fn gitlab_base_url(url: &Url, host: &str) -> Option<String> {
353    if matches!(host, "gitlab.com" | "www.gitlab.com") {
354        None
355    } else {
356        Some(format!("{}://{}", url.scheme(), host))
357    }
358}
359
360fn same_host(a: &Url, b: &Url) -> bool {
361    match (a.host_str(), b.host_str()) {
362        (Some(ha), Some(hb)) => ha.eq_ignore_ascii_case(hb),
363        _ => false,
364    }
365}
366
367fn is_direct_asset_url(url: &Url) -> bool {
368    let filename = url
369        .path_segments()
370        .and_then(|mut parts| parts.next_back())
371        .unwrap_or("");
372
373    !matches!(
374        parse_filetype(filename),
375        Filetype::Binary | Filetype::Checksum | Filetype::Auto
376    )
377}
378
379fn filter_releases_by_channel(mut releases: Vec<Release>, channel: &Channel) -> Vec<Release> {
380    match channel {
381        Channel::Stable => {
382            releases.retain(|r| !r.is_prerelease && !ProviderManager::is_nightly_release(&r.tag))
383        }
384        Channel::Preview => releases.retain(ProviderManager::is_preview_release),
385        Channel::Nightly => releases.retain(|r| ProviderManager::is_nightly_release(&r.tag)),
386    }
387    releases
388}
389
#[cfg(test)]
mod tests {
    use super::{
        DiscoveredSource, DiscoveryResult, SourceKind, infer_package_name, infer_source,
        normalize_source_for_provider,
    };
    use crate::models::{
        common::{Version, enums::Provider},
        provider::{Asset, Release},
    };
    use crate::providers::asset_selector::AssetCandidate;
    use chrono::Utc;

    /// Classifies `input`, failing the test if inference returns an error.
    fn infer(input: &str) -> DiscoveredSource {
        infer_source(input).expect("infer source")
    }

    /// Infers a package name with no explicit provider or base URL.
    fn package_name(input: &str) -> Option<String> {
        infer_package_name(input, None, None).expect("infer name")
    }

    #[test]
    fn infer_source_keeps_owner_repo_as_github() {
        let inferred = infer("BurntSushi/ripgrep");

        assert_eq!(inferred.provider, Provider::Github);
        assert_eq!(inferred.repo_slug, "BurntSushi/ripgrep");
        assert_eq!(inferred.kind, SourceKind::Repository);
    }

    #[test]
    fn infer_source_normalizes_github_release_urls() {
        let inferred = infer("https://github.com/sharkdp/fd/releases/latest");

        assert_eq!(inferred.provider, Provider::Github);
        assert_eq!(inferred.repo_slug, "sharkdp/fd");
        assert_eq!(inferred.kind, SourceKind::ForgeUrl);
    }

    #[test]
    fn infer_source_normalizes_plain_github_repo_urls() {
        let inferred = infer("https://github.com/sharkdp/bat");

        assert_eq!(inferred.provider, Provider::Github);
        assert_eq!(inferred.repo_slug, "sharkdp/bat");
        assert_eq!(inferred.kind, SourceKind::ForgeUrl);
    }

    #[test]
    fn infer_source_normalizes_www_github_repo_urls() {
        let inferred = infer("https://www.github.com/sharkdp/bat/");

        assert_eq!(inferred.provider, Provider::Github);
        assert_eq!(inferred.repo_slug, "sharkdp/bat");
        assert_eq!(inferred.kind, SourceKind::ForgeUrl);
    }

    #[test]
    fn infer_source_strips_git_suffix_for_repo_urls() {
        let inferred = infer("https://github.com/sharkdp/bat.git");

        assert_eq!(inferred.provider, Provider::Github);
        assert_eq!(inferred.repo_slug, "sharkdp/bat");
        assert_eq!(inferred.kind, SourceKind::ForgeUrl);
    }

    #[test]
    fn infer_source_normalizes_self_hosted_gitlab_urls() {
        let inferred = infer("https://gitlab.futo.org/videostreaming/Grayjay.Desktop");

        assert_eq!(inferred.provider, Provider::Gitlab);
        assert_eq!(inferred.repo_slug, "videostreaming/Grayjay.Desktop");
        assert_eq!(
            inferred.base_url.as_deref(),
            Some("https://gitlab.futo.org")
        );
        assert_eq!(inferred.kind, SourceKind::ForgeUrl);
    }

    #[test]
    fn normalize_source_for_provider_extracts_slug_for_github_urls() {
        let slug = normalize_source_for_provider(
            "https://github.com/sharkdp/bat",
            &Provider::Github,
            None,
        );

        assert_eq!(slug, "sharkdp/bat");
    }

    #[test]
    fn normalize_source_for_provider_extracts_slug_for_self_hosted_gitlab_urls() {
        let slug = normalize_source_for_provider(
            "https://gitlab.futo.org/videostreaming/Grayjay.Desktop",
            &Provider::Gitlab,
            None,
        );

        assert_eq!(slug, "videostreaming/Grayjay.Desktop");
    }

    #[test]
    fn infer_source_normalizes_codeberg_urls_as_gitea() {
        let inferred = infer("https://codeberg.org/forgejo/forgejo/releases");

        assert_eq!(inferred.provider, Provider::Gitea);
        assert_eq!(inferred.repo_slug, "forgejo/forgejo");
        assert_eq!(inferred.base_url.as_deref(), Some("https://codeberg.org"));
    }

    #[test]
    fn infer_package_name_uses_git_repo_basename() {
        assert_eq!(
            package_name("BurntSushi/ripgrep"),
            Some("ripgrep".to_string())
        );
        assert_eq!(
            package_name("https://github.com/sharkdp/bat.git"),
            Some("bat".to_string())
        );
        assert_eq!(
            package_name("https://gitlab.futo.org/videostreaming/Grayjay.Desktop"),
            Some("Grayjay.Desktop".to_string())
        );
    }

    #[test]
    fn infer_package_name_returns_none_for_http_sources() {
        assert_eq!(package_name("https://example.invalid/downloads"), None);
        assert_eq!(package_name("https://example.invalid/tool.tar.gz"), None);
    }

    #[test]
    fn infer_source_detects_direct_assets() {
        let inferred = infer("https://example.invalid/download/tool-linux-x64.tar.gz");

        assert_eq!(inferred.provider, Provider::Direct);
        assert_eq!(inferred.kind, SourceKind::DirectAsset);
    }

    #[test]
    fn infer_source_uses_scraper_for_generic_pages() {
        let inferred = infer("https://example.invalid/downloads");

        assert_eq!(inferred.provider, Provider::WebScraper);
        assert_eq!(inferred.kind, SourceKind::DownloadPage);
    }

    #[test]
    fn discovery_result_marks_close_scores_as_ambiguous() {
        let page = "https://example.invalid/downloads";

        let release = Release {
            id: 1,
            tag: "v1.0.0".to_string(),
            name: "v1.0.0".to_string(),
            body: String::new(),
            is_draft: false,
            is_prerelease: false,
            assets: Vec::new(),
            version: Version::new(1, 0, 0, false),
            published_at: Utc::now(),
        };
        let source = DiscoveredSource {
            original: page.to_string(),
            repo_slug: page.to_string(),
            provider: Provider::WebScraper,
            base_url: None,
            kind: SourceKind::DownloadPage,
        };
        let asset = Asset::new(
            "https://example.invalid/tool.tar.gz".to_string(),
            1,
            "tool.tar.gz".to_string(),
            1000,
            Utc::now(),
        );

        // Scores 100 and 80 sit exactly on the 20-point ambiguity boundary.
        let top = AssetCandidate {
            asset: asset.clone(),
            score: 100,
        };
        let runner_up = AssetCandidate { asset, score: 80 };

        let result = DiscoveryResult {
            source,
            releases: vec![release],
            candidates: vec![top, runner_up],
        };

        assert!(result.is_ambiguous());
    }
}