Skip to main content

wax/
score.rs

1use std::collections::HashMap;
2
3use crate::cli::SortMode;
4use crate::model::{CandidateRecord, OwnedAlbum, SeedAlbum};
5
6#[derive(Debug, Clone)]
7pub struct ScoreOptions {
8    pub min_overlap: usize,
9    pub exclude_artist: bool,
10    pub exclude_label: bool,
11    pub required_tags: Vec<String>,
12    pub source_label_plural: &'static str,
13    pub sort: SortMode,
14    pub limit: usize,
15}
16
17#[derive(Debug, Default)]
18struct Aggregate {
19    album: Option<OwnedAlbum>,
20    collectors: Vec<String>,
21}
22
23pub fn rank_candidates(
24    seed: &SeedAlbum,
25    collector_albums: Vec<(String, Vec<OwnedAlbum>)>,
26    options: &ScoreOptions,
27) -> Vec<CandidateRecord> {
28    let mut aggregates: HashMap<String, Aggregate> = HashMap::new();
29    let scanned = collector_albums.len().max(1);
30
31    for (collector, albums) in collector_albums {
32        for album in albums {
33            if album.url == seed.url {
34                continue;
35            }
36            if options.exclude_artist && album.artist.eq_ignore_ascii_case(&seed.artist) {
37                continue;
38            }
39            if options.exclude_label
40                && seed.label.is_some()
41                && album.label.is_some()
42                && seed.label == album.label
43            {
44                continue;
45            }
46            if !options.required_tags.is_empty()
47                && !options.required_tags.iter().all(|tag| {
48                    album
49                        .tags
50                        .iter()
51                        .any(|value| value.eq_ignore_ascii_case(tag))
52                })
53            {
54                continue;
55            }
56
57            let entry = aggregates.entry(album.url.clone()).or_default();
58            if entry.album.is_none() {
59                entry.album = Some(album);
60            }
61            entry.collectors.push(collector.clone());
62        }
63    }
64
65    let mut ranked = Vec::new();
66    for aggregate in aggregates.into_values() {
67        let Some(album) = aggregate.album else {
68            continue;
69        };
70
71        let overlap_count = aggregate.collectors.len();
72        if overlap_count < options.min_overlap {
73            continue;
74        }
75
76        let overlap_ratio = overlap_count as f64 / scanned as f64;
77        let same_artist_penalty = if album.artist.eq_ignore_ascii_case(&seed.artist) {
78            4.0
79        } else {
80            0.0
81        };
82        let score = (overlap_count as f64 * 1.5) + (overlap_ratio * 10.0) - same_artist_penalty;
83
84        ranked.push(CandidateRecord {
85            rank: 0,
86            title: album.title.clone(),
87            artist: album.artist.clone(),
88            url: album.url,
89            overlap_count,
90            overlap_ratio,
91            score,
92            reason: format!(
93                "Seen in {overlap_count} of {scanned} sampled {}",
94                options.source_label_plural
95            ),
96            collectors: aggregate.collectors,
97        });
98    }
99
100    match options.sort {
101        SortMode::Score => ranked.sort_by(|a, b| b.score.total_cmp(&a.score)),
102        SortMode::Overlap => ranked.sort_by(|a, b| b.overlap_count.cmp(&a.overlap_count)),
103    }
104
105    for (index, item) in ranked.iter_mut().take(options.limit).enumerate() {
106        item.rank = index + 1;
107    }
108
109    ranked.truncate(options.limit);
110    ranked
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ItemKind, Platform};

    /// Builds a Bandcamp album fixture by an artist other than the seed's.
    fn owned(title: &str, url: &str) -> OwnedAlbum {
        OwnedAlbum {
            platform: Platform::Bandcamp,
            kind: ItemKind::Album,
            title: title.to_string(),
            artist: "Other".to_string(),
            url: url.to_string(),
            tags: vec![],
            label: None,
        }
    }

    #[test]
    fn ranks_candidates_by_overlap() {
        let seed = SeedAlbum {
            platform: Platform::Bandcamp,
            kind: ItemKind::Album,
            title: "Seed".to_string(),
            artist: "Seed Artist".to_string(),
            url: "https://seed.bandcamp.com/album/seed".to_string(),
            artist_url: None,
            tags: vec![],
            label: None,
            release_id: None,
        };

        // "A" is owned by both fans and "B" by only one, so "A" must rank first.
        let albums = vec![
            (
                "fan_a".to_string(),
                vec![
                    owned("B", "https://x.bandcamp.com/album/b"),
                    owned("A", "https://x.bandcamp.com/album/a"),
                ],
            ),
            (
                "fan_b".to_string(),
                vec![owned("A", "https://x.bandcamp.com/album/a")],
            ),
        ];

        let ranked = rank_candidates(
            &seed,
            albums,
            &ScoreOptions {
                min_overlap: 1,
                exclude_artist: false,
                exclude_label: false,
                required_tags: vec![],
                source_label_plural: "collectors",
                sort: SortMode::Score,
                limit: 10,
            },
        );

        assert_eq!(ranked.len(), 2);

        assert_eq!(ranked[0].url, "https://x.bandcamp.com/album/a");
        assert_eq!(ranked[0].overlap_count, 2);
        assert_eq!(ranked[0].rank, 1);

        assert_eq!(ranked[1].url, "https://x.bandcamp.com/album/b");
        assert_eq!(ranked[1].overlap_count, 1);
        assert_eq!(ranked[1].rank, 2);
    }
}
176}