// upstream_rs/providers/pattern_matcher.rs

1use std::collections::HashSet;
2use std::fmt;
3
4use serde::de::{self, SeqAccess, Visitor};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6
7use crate::models::common::Version;
8use crate::models::provider::Asset;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct GeneratedAssetPatterns {
12    pub match_pattern: PatternTable,
13    pub exclude_pattern: PatternTable,
14}
15
/// An ordered list of asset-name patterns.
///
/// Every constructor trims and ASCII-lowercases its input, drops empty entries and
/// removes duplicates while keeping first-seen order. The table serializes as a plain
/// list of strings; deserialization additionally accepts `null` and a legacy
/// comma/whitespace separated string.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PatternTable {
    patterns: Vec<String>,
}
20
21impl PatternTable {
22    pub fn empty() -> Self {
23        Self::default()
24    }
25
26    pub fn from_cli_arg(value: Option<String>) -> Self {
27        value
28            .map(|value| Self::from_comma_separated(&value))
29            .unwrap_or_default()
30    }
31
32    pub fn from_patterns<I, S>(patterns: I) -> Self
33    where
34        I: IntoIterator<Item = S>,
35        S: AsRef<str>,
36    {
37        let mut seen = HashSet::new();
38        let mut out = Vec::new();
39        for pattern in patterns {
40            let normalized = normalize_pattern(pattern.as_ref());
41            if !normalized.is_empty() && seen.insert(normalized.clone()) {
42                out.push(normalized);
43            }
44        }
45        Self { patterns: out }
46    }
47
48    pub fn from_comma_separated(value: &str) -> Self {
49        Self::from_patterns(value.split(','))
50    }
51
52    fn from_legacy_string(value: &str) -> Self {
53        let mut seen = HashSet::new();
54        let mut out = Vec::new();
55        for chunk in value.split(',') {
56            for pattern in chunk.split_whitespace() {
57                let normalized = normalize_pattern(pattern);
58                if !normalized.is_empty() && seen.insert(normalized.clone()) {
59                    out.push(normalized);
60                }
61            }
62        }
63        Self { patterns: out }
64    }
65
66    pub fn is_empty(&self) -> bool {
67        self.patterns.is_empty()
68    }
69
70    pub fn as_slice(&self) -> &[String] {
71        &self.patterns
72    }
73
74    pub fn match_ratio(&self, value: &str) -> f64 {
75        if self.patterns.is_empty() {
76            return 0.0;
77        }
78
79        let matched = self
80            .patterns
81            .iter()
82            .filter(|pattern| pattern_matches_value(value, pattern))
83            .count();
84        matched as f64 / self.patterns.len() as f64
85    }
86}
87
88impl fmt::Display for PatternTable {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        write!(f, "{}", self.patterns.join(","))
91    }
92}
93
94impl Serialize for PatternTable {
95    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
96    where
97        S: Serializer,
98    {
99        self.patterns.serialize(serializer)
100    }
101}
102
103impl<'de> Deserialize<'de> for PatternTable {
104    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
105    where
106        D: Deserializer<'de>,
107    {
108        deserializer.deserialize_any(PatternTableVisitor)
109    }
110}
111
112struct PatternTableVisitor;
113
114impl<'de> Visitor<'de> for PatternTableVisitor {
115    type Value = PatternTable;
116
117    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
118        formatter.write_str("null, a string, or an array of pattern strings")
119    }
120
121    fn visit_unit<E>(self) -> Result<Self::Value, E>
122    where
123        E: de::Error,
124    {
125        Ok(PatternTable::empty())
126    }
127
128    fn visit_none<E>(self) -> Result<Self::Value, E>
129    where
130        E: de::Error,
131    {
132        Ok(PatternTable::empty())
133    }
134
135    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
136    where
137        E: de::Error,
138    {
139        Ok(PatternTable::from_legacy_string(value))
140    }
141
142    fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
143    where
144        E: de::Error,
145    {
146        Ok(PatternTable::from_legacy_string(&value))
147    }
148
149    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
150    where
151        A: SeqAccess<'de>,
152    {
153        let mut patterns = Vec::new();
154        while let Some(value) = seq.next_element::<String>()? {
155            patterns.push(value);
156        }
157        Ok(PatternTable::from_patterns(patterns))
158    }
159}
160
161pub fn pattern_match_ratio(value: &str, patterns: &PatternTable) -> f64 {
162    patterns.match_ratio(value)
163}
164
165fn pattern_matches_value(value: &str, pattern: &str) -> bool {
166    let value = value.to_ascii_lowercase();
167    let pattern = normalize_pattern(pattern);
168    if pattern.is_empty() {
169        return false;
170    }
171
172    if value.contains(&pattern) {
173        return true;
174    }
175
176    let value_tokens: HashSet<String> = asset_pattern_tokens(&value).into_iter().collect();
177    let pattern_tokens = asset_pattern_tokens(&pattern);
178    !pattern_tokens.is_empty()
179        && pattern_tokens
180            .iter()
181            .all(|token| value_tokens.contains(token))
182}
183
/// Canonical form shared by patterns and tokens.
///
/// Strips surrounding whitespace and lowercases ASCII letters; non-ASCII characters are
/// left untouched.
fn normalize_pattern(value: &str) -> String {
    let mut normalized = value.trim().to_owned();
    normalized.make_ascii_lowercase();
    normalized
}
187
188fn asset_pattern_tokens(value: &str) -> Vec<String> {
189    let mut tokens = Vec::new();
190    for segment in value
191        .split(|ch: char| !ch.is_ascii_alphanumeric() && ch != '.')
192        .filter(|segment| !segment.is_empty())
193    {
194        if is_semver_like_token(segment) {
195            continue;
196        }
197
198        if segment.contains('.') {
199            for part in segment.split('.') {
200                push_asset_pattern_token(&mut tokens, part);
201            }
202        } else {
203            push_asset_pattern_token(&mut tokens, segment);
204        }
205    }
206
207    dedupe_preserving_order(tokens)
208}
209
210fn push_asset_pattern_token(tokens: &mut Vec<String>, value: &str) {
211    let normalized = normalize_pattern(value);
212    if normalized.is_empty() || is_semver_like_token(&normalized) {
213        return;
214    }
215    tokens.push(normalized);
216}
217
/// Removes repeated tokens, keeping each token at the position of its first occurrence.
fn dedupe_preserving_order(tokens: Vec<String>) -> Vec<String> {
    let mut seen: HashSet<String> = HashSet::new();
    let mut unique = Vec::with_capacity(tokens.len());
    for token in tokens {
        if seen.insert(token.clone()) {
            unique.push(token);
        }
    }
    unique
}
225
226fn is_semver_like_token(token: &str) -> bool {
227    let trimmed = token.strip_prefix('v').unwrap_or(token);
228    trimmed.contains('.') && Version::parse(trimmed).is_ok()
229}
230
231fn pattern_tokens_for_asset(asset_name: &str) -> Vec<String> {
232    asset_pattern_tokens(asset_name)
233}
234
235fn pattern_set_for_asset(asset_name: &str) -> HashSet<String> {
236    pattern_tokens_for_asset(asset_name).into_iter().collect()
237}
238
239fn pattern_table_from_set(tokens: HashSet<String>) -> PatternTable {
240    let mut tokens: Vec<String> = tokens.into_iter().collect();
241    tokens.sort();
242    PatternTable::from_patterns(tokens)
243}
244
245pub fn generate_patterns_for_asset(
246    selected: &Asset,
247    release_assets: &[Asset],
248    package_name: &str,
249) -> GeneratedAssetPatterns {
250    let package_tokens: HashSet<String> =
251        pattern_tokens_for_asset(package_name).into_iter().collect();
252    let mut selected_set = pattern_set_for_asset(&selected.name);
253    selected_set.retain(|token| !package_tokens.contains(token));
254
255    if selected_set.is_empty() {
256        selected_set.extend(asset_platform_tokens(selected));
257    }
258
259    let mut exclude_tokens = HashSet::new();
260    for asset in release_assets {
261        if asset.id == selected.id {
262            continue;
263        }
264
265        if asset.filetype != selected.filetype {
266            continue;
267        }
268
269        let mut other_tokens = pattern_set_for_asset(&asset.name);
270        other_tokens.retain(|token| !package_tokens.contains(token));
271        for token in other_tokens.difference(&selected_set) {
272            exclude_tokens.insert(token.clone());
273        }
274    }
275
276    GeneratedAssetPatterns {
277        match_pattern: pattern_table_from_set(selected_set),
278        exclude_pattern: pattern_table_from_set(exclude_tokens),
279    }
280}
281
282fn asset_platform_tokens(asset: &Asset) -> Vec<String> {
283    let mut tokens = Vec::new();
284    if let Some(os) = &asset.target_os {
285        tokens.push(format!("{os:?}").to_ascii_lowercase());
286    }
287    if let Some(arch) = &asset.target_arch {
288        tokens.push(format!("{arch:?}").to_ascii_lowercase());
289    }
290    tokens
291}
292
#[cfg(test)]
mod tests {
    use super::{PatternTable, generate_patterns_for_asset, pattern_match_ratio};
    #[cfg(target_os = "linux")]
    use crate::models::common::Version;
    #[cfg(target_os = "linux")]
    use crate::models::common::enums::{Channel, Filetype, Provider};
    use crate::models::provider::Asset;
    #[cfg(target_os = "linux")]
    use crate::models::provider::Release;
    #[cfg(target_os = "linux")]
    use crate::models::upstream::Package;
    #[cfg(target_os = "linux")]
    use crate::providers::asset_selector::AssetSelector;
    use chrono::Utc;

    // Builds an `Asset` served from `https://example.invalid/<name>` with a fixed size.
    // A macro (rather than a fn) lets `Asset::new` infer its own argument types.
    macro_rules! asset {
        ($id:expr, $name:expr) => {
            Asset::new(
                format!("https://example.invalid/{}", $name),
                $id,
                $name.to_string(),
                200_000,
                Utc::now(),
            )
        };
    }

    /// Returns whether `table` contains exactly `token`.
    fn has(table: &PatternTable, token: &str) -> bool {
        table.as_slice().iter().any(|pattern| pattern == token)
    }

    #[cfg(target_os = "linux")]
    fn make_release(assets: Vec<Asset>, prerelease: bool, tag: &str) -> Release {
        Release {
            id: 1,
            tag: tag.to_string(),
            name: tag.to_string(),
            body: String::new(),
            is_draft: false,
            is_prerelease: prerelease,
            assets,
            version: Version::new(1, 0, 0, prerelease),
            published_at: Utc::now(),
        }
    }

    #[test]
    fn pattern_match_ratio_scores_matched_tokens_as_percentage() {
        let patterns = PatternTable::from_patterns(["x86", "64", "linux", "musl"]);
        assert_eq!(pattern_match_ratio("tool-x86_64-linux-musl.tar.gz", &patterns), 1.0);
        assert_eq!(
            pattern_match_ratio("tool-x86_64-linux-gnu.tar.gz", &patterns),
            3.0 / 4.0
        );
        assert_eq!(
            pattern_match_ratio("tool-aarch64-darwin.tar.gz", &patterns),
            1.0 / 4.0
        );
    }

    #[test]
    fn cli_patterns_split_on_commas_only() {
        let table = PatternTable::from_cli_arg(Some("linux-x86_64,musl".to_string()));
        assert_eq!(table.as_slice(), ["linux-x86_64", "musl"]);
    }

    #[test]
    fn legacy_strings_split_on_whitespace_and_commas() {
        let json = r#""x86_64 linux,musl""#;
        let table: PatternTable = serde_json::from_str(json).expect("legacy table");
        assert_eq!(table.as_slice(), ["x86_64", "linux", "musl"]);
    }

    #[test]
    fn generate_patterns_for_selected_asset_keeps_stable_platform_tokens() {
        let selected = asset!(1, "tool-v1.2.3-x86_64-unknown-linux-musl.tar.gz");
        let release_assets = vec![
            selected.clone(),
            asset!(2, "tool-v1.2.3-aarch64-unknown-linux-musl.tar.gz"),
        ];

        let generated = generate_patterns_for_asset(&selected, &release_assets, "tool");
        for token in ["x86", "64", "linux", "musl"] {
            assert!(has(&generated.match_pattern, token), "missing match token {token}");
        }
        assert!(!generated.match_pattern.to_string().contains("1.2.3"));
        assert!(has(&generated.exclude_pattern, "aarch64"));
    }

    #[test]
    fn generate_patterns_for_selected_asset_keeps_flavor_tokens() {
        let selected = asset!(1, "ffmpeg-release-essentials.7z");
        let release_assets = vec![
            selected.clone(),
            asset!(2, "ffmpeg-release-full.7z"),
            asset!(3, "ffmpeg-release-full-shared.7z"),
        ];

        let generated = generate_patterns_for_asset(&selected, &release_assets, "ffmpeg");
        assert!(has(&generated.match_pattern, "essentials"));
        assert!(has(&generated.exclude_pattern, "full"));
        assert!(has(&generated.exclude_pattern, "shared"));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn generated_patterns_select_similar_asset_on_future_release() {
        let selector = AssetSelector::new();
        let selected = asset!(1, "tool-v1.2.3-x86_64-unknown-linux-musl.tar.gz");
        let generated = generate_patterns_for_asset(
            &selected,
            &[
                selected.clone(),
                asset!(2, "tool-v1.2.3-x86_64-unknown-linux-gnu.tar.gz"),
            ],
            "tool",
        );
        let package = Package::with_defaults(
            "tool".to_string(),
            "owner/tool".to_string(),
            Filetype::Archive,
            Some(generated.match_pattern.to_string()),
            Some(generated.exclude_pattern.to_string()),
            Channel::Stable,
            Provider::Github,
            None,
        );
        let future_release = make_release(
            vec![
                asset!(3, "tool-v1.3.0-x86_64-unknown-linux-gnu.tar.gz"),
                asset!(4, "tool-v1.3.0-x86_64-unknown-linux-musl.tar.gz"),
            ],
            false,
            "v1.3.0",
        );

        let best = selector
            .find_recommended_asset(&future_release, &package)
            .expect("best asset");
        assert!(best.name.contains("musl"));
    }
}
528}