1use std::path::{Path, PathBuf};
23
/// Returns the directory that model files are looked up in.
///
/// Resolution order:
/// 1. The `ZER_MODEL_DIR` environment variable, when it is set to a non-empty value. The value
///    is returned as-is; it is not checked for existence.
/// 2. `$HOME/.cache/zer/models`, when `HOME` is set and that path exists.
/// 3. The relative path `models`.
pub fn default_models_dir() -> PathBuf {
    // `var_os` rather than `var`: a path does not have to be valid UTF-8, and `var` would
    // silently discard such an override. An empty value is treated as "not set" so that
    // `ZER_MODEL_DIR=` falls back to the defaults instead of producing an empty path.
    let explicit = std::env::var_os("ZER_MODEL_DIR").filter(|dir| !dir.is_empty());
    if let Some(dir) = explicit {
        return PathBuf::from(dir);
    }

    let cache = std::env::var_os("HOME").map(|home| {
        PathBuf::from(home)
            .join(".cache")
            .join("zer")
            .join("models")
    });

    match cache {
        // Only use the per-user cache when it has actually been created.
        Some(dir) if dir.exists() => dir,
        _ => PathBuf::from("models"),
    }
}
47
/// Describes where a tokenizer definition comes from.
///
/// Implements equality and hashing so values can be compared directly with `assert_eq!`
/// or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenizerSource {
    /// A tokenizer file on the local filesystem (e.g. a `tokenizer.json`).
    File(PathBuf),
    /// A model identifier such as `cross-encoder/nli-deberta-v3-base`.
    /// NOTE(review): presumably resolved against the Hugging Face Hub by the loader — the
    /// loading code is not part of this file.
    HuggingFace(String),
}
59
60impl TokenizerSource {
61 pub fn file(path: impl AsRef<Path>) -> Self {
63 Self::File(path.as_ref().to_owned())
64 }
65
66 pub fn hub(model_id: impl Into<String>) -> Self {
68 Self::HuggingFace(model_id.into())
69 }
70}
71
72pub trait JudgeModelSpec: Send + Sync {
79 fn name(&self) -> &str;
81
82 fn model_path(&self) -> &Path;
84
85 fn tokenizer_source(&self) -> &TokenizerSource;
87
88 fn max_length(&self) -> usize;
90
91 fn entailment_idx(&self) -> usize;
95
96 fn vram_bytes(&self) -> u64;
98}
99
/// Selects which exported variant of a model to load.
///
/// Each variant maps to a directory name via `ModelPrecision::subfolder`. The default is
/// [`ModelPrecision::Fp16Fused`].
///
/// Equality and hashing are derived so values can be compared directly and used as keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ModelPrecision {
    /// The variant stored under the `base` subfolder.
    Base,
    /// The variant stored under the `fp16` subfolder.
    Fp16,
    /// The variant stored under the `fp16_fused` subfolder (default).
    #[default]
    Fp16Fused,
}
118
119impl ModelPrecision {
120 pub fn subfolder(self) -> &'static str {
121 match self {
122 Self::Base => "base",
123 Self::Fp16 => "fp16",
124 Self::Fp16Fused => "fp16_fused",
125 }
126 }
127}
128
129pub struct MiniLmSpec {
133 model_path: PathBuf,
134 tokenizer_source: TokenizerSource,
135}
136
137impl MiniLmSpec {
138 pub fn new(model_path: impl AsRef<Path>, tokenizer_source: TokenizerSource) -> Self {
139 Self {
140 model_path: model_path.as_ref().to_owned(),
141 tokenizer_source,
142 }
143 }
144
145 pub fn from_dir(dir: impl AsRef<Path>) -> Self {
148 let dir = dir.as_ref();
149 Self {
150 model_path: dir.join("model.onnx"),
151 tokenizer_source: TokenizerSource::file(dir.join("tokenizer.json")),
152 }
153 }
154
155 pub fn from_env(precision: ModelPrecision) -> Self {
161 let base = default_models_dir()
162 .join("nli-base")
163 .join(precision.subfolder())
164 .join("nli-minilm-onnx");
165 Self::from_dir(base)
166 }
167}
168
169impl JudgeModelSpec for MiniLmSpec {
170 fn name(&self) -> &str {
171 "cross-encoder/nli-MiniLM2-L6-H768"
172 }
173 fn model_path(&self) -> &Path {
174 &self.model_path
175 }
176 fn tokenizer_source(&self) -> &TokenizerSource {
177 &self.tokenizer_source
178 }
179 fn max_length(&self) -> usize {
180 512
181 }
182 fn entailment_idx(&self) -> usize {
183 1
184 }
185 fn vram_bytes(&self) -> u64 {
186 256 * 1024 * 1024
187 } }
189
190pub struct DebertaBaseSpec {
194 model_path: PathBuf,
195 tokenizer_source: TokenizerSource,
196}
197
198impl DebertaBaseSpec {
199 pub fn new(model_path: impl AsRef<Path>, tokenizer_source: TokenizerSource) -> Self {
200 Self {
201 model_path: model_path.as_ref().to_owned(),
202 tokenizer_source,
203 }
204 }
205
206 pub fn from_dir(dir: impl AsRef<Path>) -> Self {
207 let dir = dir.as_ref();
208 Self {
209 model_path: dir.join("model.onnx"),
210 tokenizer_source: TokenizerSource::file(dir.join("tokenizer.json")),
211 }
212 }
213
214 pub fn from_env(precision: ModelPrecision) -> Self {
219 let base = default_models_dir()
220 .join("nli-base")
221 .join(precision.subfolder())
222 .join("nli-deberta-v3-base-onnx");
223 Self::from_dir(base)
224 }
225}
226
227impl JudgeModelSpec for DebertaBaseSpec {
228 fn name(&self) -> &str {
229 "cross-encoder/nli-deberta-v3-base"
230 }
231 fn model_path(&self) -> &Path {
232 &self.model_path
233 }
234 fn tokenizer_source(&self) -> &TokenizerSource {
235 &self.tokenizer_source
236 }
237 fn max_length(&self) -> usize {
238 512
239 }
240 fn entailment_idx(&self) -> usize {
241 1
242 }
243 fn vram_bytes(&self) -> u64 {
244 2 * 1024 * 1024 * 1024
245 } }
247
248pub fn spec_from_env(
257 precision: ModelPrecision,
258 available_vram_bytes: u64,
259) -> Box<dyn JudgeModelSpec> {
260 let models_dir = default_models_dir()
261 .join("nli-base")
262 .join(precision.subfolder());
263 spec_from_vram(&models_dir, available_vram_bytes)
264}
265
266pub fn spec_from_vram(models_dir: &Path, available_vram_bytes: u64) -> Box<dyn JudgeModelSpec> {
283 let base = models_dir.join("nli-deberta-v3-base-onnx");
284 let mini = models_dir.join("nli-minilm-onnx");
285
286 if available_vram_bytes >= 2 * 1024 * 1024 * 1024 && base.exists() {
287 tracing::info!(
288 "judge: selecting DeBERTa-v3-base ({:.1} GB VRAM available)",
289 available_vram_bytes as f64 / 1e9
290 );
291 return Box::new(DebertaBaseSpec::from_dir(&base));
292 }
293
294 tracing::info!("judge: selecting MiniLM-L6 (CPU or low VRAM)");
295 Box::new(MiniLmSpec::from_dir(&mini))
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns `/nonexistent/<name>`, a path that is not expected to exist on disk.
    fn dummy_path(name: &str) -> PathBuf {
        let mut raw = String::from("/nonexistent/");
        raw.push_str(name);
        PathBuf::from(raw)
    }

    /// Unwraps the path of a file-backed tokenizer source, failing the test otherwise.
    fn expect_file(source: &TokenizerSource) -> &Path {
        match source {
            TokenizerSource::File(path) => path.as_path(),
            other => panic!("expected TokenizerSource::File, got {other:?}"),
        }
    }

    // --- MiniLM -------------------------------------------------------------------------

    #[test]
    fn minilm_from_dir_sets_expected_paths() {
        let spec = MiniLmSpec::from_dir("/some/dir");
        assert_eq!(spec.model_path(), Path::new("/some/dir/model.onnx"));
        assert_eq!(
            expect_file(spec.tokenizer_source()),
            Path::new("/some/dir/tokenizer.json")
        );
    }

    #[test]
    fn minilm_metadata() {
        let spec = MiniLmSpec::from_dir("/d");
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
        assert_eq!(spec.max_length(), 512);
        assert_eq!(spec.entailment_idx(), 1);
        assert_eq!(spec.vram_bytes(), 256 * 1024 * 1024);
    }

    // --- DeBERTa-v3-base ----------------------------------------------------------------

    #[test]
    fn deberta_base_from_dir_sets_expected_paths() {
        let spec = DebertaBaseSpec::from_dir("/fp16_fused/dir");
        assert_eq!(spec.model_path(), Path::new("/fp16_fused/dir/model.onnx"));
        assert_eq!(
            expect_file(spec.tokenizer_source()),
            Path::new("/fp16_fused/dir/tokenizer.json")
        );
    }

    #[test]
    fn deberta_base_metadata() {
        let spec = DebertaBaseSpec::from_dir("/d");
        assert_eq!(spec.name(), "cross-encoder/nli-deberta-v3-base");
        assert_eq!(spec.max_length(), 512);
        assert_eq!(spec.entailment_idx(), 1);
        assert_eq!(spec.vram_bytes(), 2 * 1024 * 1024 * 1024);
    }

    // --- Model selection ----------------------------------------------------------------

    #[test]
    fn spec_from_vram_no_dirs_returns_minilm() {
        // Plenty of VRAM, but the DeBERTa directory is missing, so MiniLM is the fallback.
        let plenty_of_vram = 16 * 1024 * 1024 * 1024;
        let spec = spec_from_vram(Path::new("/nonexistent"), plenty_of_vram);
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
    }

    #[test]
    fn spec_from_vram_selects_minilm_when_low_vram() {
        // 512 MiB is below the 2 GiB DeBERTa threshold.
        let low_vram = 512 * 1024 * 1024;
        let spec = spec_from_vram(Path::new("/nonexistent"), low_vram);
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
    }

    #[test]
    fn spec_from_vram_with_real_models_dir_selects_best_available() {
        // Only meaningful when the models directory is present relative to the test's
        // working directory; otherwise the test passes without asserting anything.
        let models_dir = Path::new("../../models/nli-base/fp16_fused");
        if models_dir.exists() {
            // With zero VRAM reported, MiniLM must be chosen even if DeBERTa is on disk.
            let spec = spec_from_vram(models_dir, 0);
            assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
        }
    }

    // --- TokenizerSource constructors ---------------------------------------------------

    #[test]
    fn token_source_file_convenience() {
        let ts = TokenizerSource::file("/tmp/tok.json");
        assert_eq!(expect_file(&ts), Path::new("/tmp/tok.json"));
    }

    #[test]
    fn token_source_hub_convenience() {
        match TokenizerSource::hub("cross-encoder/nli-deberta-v3-base") {
            TokenizerSource::HuggingFace(id) => {
                assert_eq!(id, "cross-encoder/nli-deberta-v3-base");
            }
            other => panic!("expected TokenizerSource::HuggingFace, got {other:?}"),
        }
    }

    #[test]
    fn minilm_new_constructor() {
        let model = dummy_path("model.onnx");
        let tokenizer = TokenizerSource::file(dummy_path("tok.json"));
        let spec = MiniLmSpec::new(model, tokenizer);
        assert_eq!(spec.model_path(), Path::new("/nonexistent/model.onnx"));
    }
}