simple_whisper/
model.rs

1use std::path::PathBuf;
2
3use hf_hub::{Cache, Repo};
4use strum::{Display, EnumIter, EnumString};
5use tokio::sync::mpsc::UnboundedSender;
6
7use crate::{
8    Error, Event,
9    download::{ProgressType, download_file},
10};
11
/// Location of a model file on the Hugging Face Hub:
/// the repository it lives in plus the file name inside that repo.
struct HFCoordinates {
    // Hub repository (owner/name + revision) holding the ggml weights.
    repo: Repo,
    // File name of the .bin model inside the repository.
    model: String,
}
16
17/// OpenAI supported models
18#[derive(Default, Clone, Debug, EnumIter, EnumString, Display)]
19#[strum(serialize_all = "snake_case")]
20pub enum Model {
21    /// The tiny model.
22    #[strum(serialize = "tiny", to_string = "Tiny - tiny")]
23    Tiny,
24    /// The tiny-q5_1 model.
25    #[strum(serialize = "tiny-q5_1", to_string = "Tiny - tiny-q5_1")]
26    TinyQ5_1,
27    /// The tiny-q8_0 model.
28    #[strum(serialize = "tiny-q8_0", to_string = "Tiny - tiny-q8_0")]
29    TinyQ8_0,
30    /// The tiny model with only English support.
31    #[strum(serialize = "tiny_en", to_string = "TinyEn - tiny_en")]
32    TinyEn,
33    /// The tiny-q5_1 model with only English support.
34    #[strum(serialize = "tiny_en-q5_1", to_string = "TinyEn - tiny_en-q5_1")]
35    TinyEnQ5_1,
36    /// The tiny-q8_0 model with only English support.
37    #[strum(serialize = "tiny_en-q8_0", to_string = "Tiny - tiny_en-q8_0")]
38    TinyEnQ8_0,
39    /// The base model.
40    #[default]
41    #[strum(serialize = "base", to_string = "Base - base")]
42    Base,
43    /// The base-q5_1 model.
44    #[strum(serialize = "base-q5_1", to_string = "Base - base-q5_1")]
45    BaseQ5_1,
46    /// The base-q8_0 model.
47    #[strum(serialize = "base-q8_0", to_string = "Base - base-q8_0")]
48    BaseQ8_0,
49    /// The base model with only English support.
50    #[strum(serialize = "base_en", to_string = "BaseEn - base_en")]
51    BaseEn,
52    /// The base-q5_1 model with only English support.
53    #[strum(serialize = "base_en-q5_1", to_string = "BaseEn -base_en-q5_1")]
54    BaseEnQ5_1,
55    /// The base-q8_0 model with only English support.
56    #[strum(serialize = "base_en-q8_0", to_string = "BaseEn - base_en-q8_0")]
57    BaseEnQ8_0,
58    /// The small model.
59    #[strum(serialize = "small", to_string = "Small - small")]
60    Small,
61    /// The small-q5_1 model.
62    #[strum(serialize = "small-q5_1", to_string = "Small - small-q5_1")]
63    SmallQ5_1,
64    /// The small-q8_0 model.
65    #[strum(serialize = "small-q8_0", to_string = "Small - small-q8_0")]
66    SmallQ8_0,
67    /// The small model with only English support.
68    #[strum(serialize = "small_en", to_string = "SmallEn - small_en")]
69    SmallEn,
70    /// The small-q5_1 model with only English support.
71    #[strum(serialize = "small_en-q5_1", to_string = "SmallEn - small_en-q5_1")]
72    SmallEnQ5_1,
73    /// The small-q8_0 model with only English support.
74    #[strum(serialize = "small_en-q8_0", to_string = "SmallEn - small_en-q8_0")]
75    SmallEnQ8_0,
76    /// The medium model.
77    #[strum(serialize = "medium", to_string = "Medium - medium")]
78    Medium,
79    /// The medium-q5_0 model.
80    #[strum(serialize = "medium-q5_0", to_string = "Medium - medium-q5_0")]
81    MediumQ5_0,
82    /// The medium-q8_0 model.
83    #[strum(serialize = "medium-q8_0", to_string = "Medium - medium-q8_0")]
84    MediumQ8_0,
85    /// The medium model with only English support.
86    #[strum(serialize = "medium_en", to_string = "MediumEn - medium_en")]
87    MediumEn,
88    /// The medium-q5_0 model with only English support.
89    #[strum(serialize = "medium_en-q5_0	", to_string = "MediumEn - medium_en-q5_0")]
90    MediumEnQ5_0,
91    /// The medium-q8_0 model with only English support.
92    #[strum(serialize = "medium_en-q8_0", to_string = "MediumEn - medium_en-q8_0")]
93    MediumEnQ8_0,
94    /// The large model.
95    #[strum(serialize = "large", to_string = "Large V1 - large")]
96    Large,
97    /// The large model v2.
98    #[strum(serialize = "large_v2", to_string = "Large V2 - large_v2")]
99    LargeV2,
100    #[strum(serialize = "large_v2-q5_0", to_string = "Large V2 - large_v2-q5_0")]
101    LargeV2Q5_0,
102    #[strum(serialize = "large_v2-q8_0", to_string = "Large V2 - large_v2-q8_0")]
103    LargeV2Q8_0,
104    /// The large model v3.
105    #[strum(serialize = "large_v3", to_string = "Large V3 - large_v3")]
106    LargeV3,
107    /// The large_v3-q5_0 model v3.
108    #[strum(serialize = "large_v3-q5_0", to_string = "Large V3 - large_v3-q5_0")]
109    LargeV3Q5_0,
110    /// The large model v3 turbo.
111    #[strum(
112        serialize = "large_v3_turbo",
113        to_string = "Large V3 Turbo - large_v3_turbo"
114    )]
115    LargeV3Turbo,
116    /// The large_v3_turbo-q5_0 model v3 turbo.
117    #[strum(
118        serialize = "large_v3_turbo-q5_0",
119        to_string = "Large V3 Turbo - large_v3_turbo-q5_0"
120    )]
121    LargeV3TurboQ5_0,
122    /// The large_v3_turbo-q8_0 model v3 turbo.
123    #[strum(
124        serialize = "large_v3_turbo-q8_0",
125        to_string = "Large V3 Turbo - large_v3_turbo-q8_0"
126    )]
127    LargeV3TurboQ8_0,
128}
129
130impl Model {
131    fn hf_coordinates(&self) -> HFCoordinates {
132        let repo = Repo::with_revision(
133            "ggerganov/whisper.cpp".to_owned(),
134            hf_hub::RepoType::Model,
135            "main".to_owned(),
136        );
137        match self {
138            Model::Tiny => HFCoordinates {
139                repo,
140                model: "ggml-tiny.bin".to_owned(),
141            },
142            Model::TinyEn => HFCoordinates {
143                repo,
144                model: "ggml-tiny.en.bin".to_owned(),
145            },
146            Model::Base => HFCoordinates {
147                repo,
148                model: "ggml-base.bin".to_owned(),
149            },
150            Model::BaseEn => HFCoordinates {
151                repo,
152                model: "ggml-base.en.bin".to_owned(),
153            },
154            Model::Small => HFCoordinates {
155                repo,
156                model: "ggml-small.bin".to_owned(),
157            },
158            Model::SmallEn => HFCoordinates {
159                repo,
160                model: "ggml-small.en.bin".to_owned(),
161            },
162            Model::Medium => HFCoordinates {
163                repo,
164                model: "ggml-medium.bin".to_owned(),
165            },
166            Model::MediumEn => HFCoordinates {
167                repo,
168                model: "ggml-medium.en.bin".to_owned(),
169            },
170            Model::Large => HFCoordinates {
171                repo,
172                model: "ggml-large-v1.bin".to_owned(),
173            },
174            Model::LargeV2 => HFCoordinates {
175                repo,
176                model: "ggml-large-v2.bin".to_owned(),
177            },
178            Model::LargeV3 => HFCoordinates {
179                repo,
180                model: "ggml-large-v3.bin".to_owned(),
181            },
182            Model::TinyQ5_1 => HFCoordinates {
183                repo,
184                model: "ggml-tiny-q5_1.bin".to_owned(),
185            },
186            Model::TinyQ8_0 => HFCoordinates {
187                repo,
188                model: "ggml-tiny-q8_0.bin".to_owned(),
189            },
190            Model::TinyEnQ5_1 => HFCoordinates {
191                repo,
192                model: "ggml-tiny.en-q5_1.bin".to_owned(),
193            },
194            Model::TinyEnQ8_0 => HFCoordinates {
195                repo,
196                model: "ggml-tiny.en-q8_0.bin".to_owned(),
197            },
198            Model::BaseQ5_1 => HFCoordinates {
199                repo,
200                model: "ggml-base-q5_1.bin".to_owned(),
201            },
202            Model::BaseQ8_0 => HFCoordinates {
203                repo,
204                model: "ggml-base-q8_0.bin".to_owned(),
205            },
206            Model::BaseEnQ5_1 => HFCoordinates {
207                repo,
208                model: "ggml-base.en-q5_1.bin".to_owned(),
209            },
210            Model::BaseEnQ8_0 => HFCoordinates {
211                repo,
212                model: "ggml-base.en-q8_0.bin".to_owned(),
213            },
214            Model::SmallQ5_1 => HFCoordinates {
215                repo,
216                model: "ggml-small-q5_1.bin".to_owned(),
217            },
218            Model::SmallQ8_0 => HFCoordinates {
219                repo,
220                model: "ggml-small-q8_0.bin".to_owned(),
221            },
222            Model::SmallEnQ5_1 => HFCoordinates {
223                repo,
224                model: "ggml-small.en-q5_1.bin".to_owned(),
225            },
226            Model::SmallEnQ8_0 => HFCoordinates {
227                repo,
228                model: "ggml-small.en-q8_0.bin".to_owned(),
229            },
230            Model::MediumQ5_0 => HFCoordinates {
231                repo,
232                model: "ggml-medium-q5_0.bin".to_owned(),
233            },
234            Model::MediumQ8_0 => HFCoordinates {
235                repo,
236                model: "ggml-medium-q8_0.bin".to_owned(),
237            },
238            Model::MediumEnQ5_0 => HFCoordinates {
239                repo,
240                model: "ggml-medium.en-q5_0.bin".to_owned(),
241            },
242            Model::MediumEnQ8_0 => HFCoordinates {
243                repo,
244                model: "ggml-medium.en-q8_0.bin".to_owned(),
245            },
246            Model::LargeV2Q5_0 => HFCoordinates {
247                repo,
248                model: "ggml-large-v2-q5_0.bin".to_owned(),
249            },
250            Model::LargeV2Q8_0 => HFCoordinates {
251                repo,
252                model: "ggml-large-v2-q8_0.bin".to_owned(),
253            },
254            Model::LargeV3Q5_0 => HFCoordinates {
255                repo,
256                model: "ggml-large-v3-q5_0.bin".to_owned(),
257            },
258            Model::LargeV3Turbo => HFCoordinates {
259                repo,
260                model: "ggml-large-v3-turbo.bin".to_owned(),
261            },
262            Model::LargeV3TurboQ5_0 => HFCoordinates {
263                repo,
264                model: "ggml-large-v3-turbo-q5_0.bin".to_owned(),
265            },
266            Model::LargeV3TurboQ8_0 => HFCoordinates {
267                repo,
268                model: "ggml-large-v3-turbo-q8_0.bin".to_owned(),
269            },
270        }
271    }
272
273    /// True if the model supports multiple languages, false otherwise.
274    pub fn is_multilingual(&self) -> bool {
275        !self.to_string().contains("en")
276    }
277
278    /// Check if the file model has been cached before
279    pub fn cached(&self) -> bool {
280        let coordinates = self.hf_coordinates();
281        let cache = Cache::from_env().repo(coordinates.repo);
282        cache.get(&coordinates.model).is_some()
283    }
284
285    pub(crate) async fn internal_download_model(
286        &self,
287        force_download: bool,
288        progress: ProgressType,
289    ) -> Result<PathBuf, Error> {
290        let coordinates = self.hf_coordinates();
291
292        download_file(
293            &coordinates.model,
294            force_download,
295            progress,
296            coordinates.repo,
297        )
298        .await
299    }
300
301    pub async fn download_model(&self, force_download: bool) -> Result<PathBuf, Error> {
302        self.internal_download_model(force_download, ProgressType::ProgressBar)
303            .await
304    }
305
306    pub async fn download_model_listener(
307        &self,
308        force_download: bool,
309        tx: UnboundedSender<Event>,
310    ) -> Result<PathBuf, Error> {
311        self.internal_download_model(force_download, ProgressType::Callback(tx))
312            .await
313    }
314}