sqry_nl/classifier/model.rs

//! ONNX model loading and inference.

use crate::error::ClassifierError;
use crate::types::{ClassificationResult, Intent};
use ort::session::Session;
use ort::value::Tensor;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io::Read;
use std::path::Path;

use super::BAKED_MANIFEST;
use super::calibration::CalibrationParams;
use super::manifest::Manifest;
use super::resolve::TrustMode;

// ---------------------------------------------------------------------------
// NL08 — ONNX Runtime "missing dylib" detection
// ---------------------------------------------------------------------------
//
// The `ort` crate (with the `load-dynamic` feature, which sqry-nl uses)
// resolves `libonnxruntime` at first API call via `libloading`. If the
// shared library is absent, ort's `setup_api()` calls `.expect("Failed
// to load ONNX Runtime dylib")` — meaning the failure surfaces as a
// **panic**, not a typed `Result::Err`. Some downstream surfaces (e.g.
// symbol lookup after a successful library open) do return a typed
// `ort::Error` that carries the substring `"libonnxruntime"` /
// `"failed to load"` / `"OrtGetApiBase"` / `"dlopen"` / `"DyLib"` in its
// `Display` form.
//
// We therefore detect the missing-dylib condition through TWO channels:
//
//   1. `std::panic::catch_unwind` around the `Session::builder()` chain
//      to convert panics into a typed `OnnxRuntimeMissing` error.
//   2. String-pattern matching on the `Display` of any returned
//      `ort::Error` for the substrings above, so symbol-lookup failures
//      after a partial library load also surface as
//      `OnnxRuntimeMissing` instead of the opaque `OnnxError(_)`.
//
// A deterministic test seam — the `SQRY_NL_FORCE_ORT_MISSING` env var
// — short-circuits this path before any ORT call. The seam is gated on
// `debug_assertions` so it cannot be exploited in release binaries
// shipped to operators. Cargo test runs under `debug_assertions` by
// default, so the CLI / MCP / LSP integration tests can drive this path
// without needing an actual missing libonnxruntime on the host.
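//
// A minimal sketch of how a debug-build test might drive this seam
// (variable names and the assertion shape are illustrative; the error
// variant itself is defined in `crate::error`):
//
//     std::env::set_var("SQRY_NL_FORCE_ORT_MISSING", "1");
//     let err = IntentClassifier::load(&dir, false, TrustMode::Trusted).unwrap_err();
//     assert!(matches!(err, ClassifierError::OnnxRuntimeMissing { .. }));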

/// Return the platform-specific install hint for missing ONNX Runtime.
///
/// Used to populate
/// [`crate::error::ClassifierError::OnnxRuntimeMissing`] /
/// [`crate::error::NlError::OnnxRuntimeMissing`].
#[must_use]
pub fn onnx_runtime_install_hint() -> String {
    #[cfg(target_os = "linux")]
    {
        "Install via apt: 'sudo apt-get install libonnxruntime-dev' OR \
         download from https://github.com/microsoft/onnxruntime/releases"
            .to_string()
    }
    #[cfg(target_os = "macos")]
    {
        "Install via brew: 'brew install onnxruntime'".to_string()
    }
    #[cfg(target_os = "windows")]
    {
        "Download libonnxruntime.dll from \
         https://github.com/microsoft/onnxruntime/releases and place in PATH"
            .to_string()
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        // Other Unix-likes (FreeBSD, etc.) — mirror Linux guidance.
        "Install libonnxruntime via your platform package manager OR \
         download from https://github.com/microsoft/onnxruntime/releases"
            .to_string()
    }
}

/// Return `true` when the env-var test seam is active.
///
/// Gated on `debug_assertions` so release binaries do not honour the
/// override. `cargo test` builds with `debug_assertions` enabled by
/// default, so subprocess integration tests can drive this seam by
/// spawning the debug-built `sqry` binary; a release-built binary
/// (`cargo build --release`) ignores the variable.
#[cfg(debug_assertions)]
fn ort_missing_forced() -> bool {
    match std::env::var("SQRY_NL_FORCE_ORT_MISSING") {
        Ok(v) => {
            let v = v.trim();
            v.eq_ignore_ascii_case("1")
                || v.eq_ignore_ascii_case("true")
                || v.eq_ignore_ascii_case("yes")
                || v.eq_ignore_ascii_case("on")
        }
        Err(_) => false,
    }
}

#[cfg(not(debug_assertions))]
fn ort_missing_forced() -> bool {
    false
}

/// Returns `true` if the given error string looks like a dylib-load
/// failure for `libonnxruntime`. Matches the substrings ort emits in
/// the load-dynamic path. Case-insensitive on the substring tokens.
///
/// NL08 review iter-1: the broad tokens `"dylib"`, `"dlopen"`, and
/// `"failed to load"` were intentionally excluded from this OR set.
/// They false-positive on operator-supplied paths (e.g.
/// `SQRY_NL_MODEL_DIR=/some/dylib-models/...`) and on unrelated model
/// load errors carrying such paths in their message — a bad-ONNX-bytes
/// failure for a model under such a path would otherwise be
/// misclassified as `OnnxRuntimeMissing`. The three remaining tokens
/// (`libonnxruntime`, `onnxruntime.dll`, `ortgetapibase`) uniquely
/// identify the ort dylib-load surface and will never appear in a
/// legitimate file path or model parse error.
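///
/// A hedged illustration of the intended behaviour (the error strings
/// below are representative, not verbatim ort output):
///
/// ```ignore
/// assert!(looks_like_dylib_load_failure(
///     "Failed to load ONNX Runtime dylib: libonnxruntime.so: cannot open shared object file"
/// ));
/// assert!(!looks_like_dylib_load_failure(
///     "Failed to parse ONNX model at /opt/dylib-models/intent_classifier.onnx"
/// ));
/// ```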
fn looks_like_dylib_load_failure(msg: &str) -> bool {
    let lower = msg.to_ascii_lowercase();
    lower.contains("libonnxruntime")
        || lower.contains("onnxruntime.dll")
        || lower.contains("ortgetapibase")
}

/// Construct `ClassifierError::OnnxRuntimeMissing` with the platform hint.
fn onnx_runtime_missing_error() -> ClassifierError {
    ClassifierError::OnnxRuntimeMissing {
        hint: onnx_runtime_install_hint(),
    }
}

/// Intent classifier using an ONNX model (`all-MiniLM-L6-v2` or `DistilBERT`).
pub struct IntentClassifier {
    /// ONNX Runtime session
    session: Session,
    /// `HuggingFace` tokenizer
    tokenizer: tokenizers::Tokenizer,
    /// Calibration parameters for confidence scaling
    calibration: CalibrationParams,
    /// Model version string
    model_version: String,
    /// Whether the ONNX model declares `token_type_ids` as an input.
    /// BERT-architecture models (`MiniLM`) require it; `DistilBERT` does not.
    /// Passing an undeclared input to ort causes a runtime error.
    has_token_type_ids: bool,
}

/// Compute SHA256 hash of a file.
fn compute_file_hash(path: &Path) -> Result<String, ClassifierError> {
    let mut file = std::fs::File::open(path).map_err(|e| {
        ClassifierError::OnnxError(format!("Failed to open {}: {e}", path.display()))
    })?;

    let mut hasher = Sha256::new();
    let mut buffer = [0u8; 8192];

    loop {
        let bytes_read = file.read(&mut buffer).map_err(|e| {
            ClassifierError::OnnxError(format!("Failed to read {}: {e}", path.display()))
        })?;
        if bytes_read == 0 {
            break;
        }
        hasher.update(&buffer[..bytes_read]);
    }

    Ok(format!("{:x}", hasher.finalize()))
}

// ---------------------------------------------------------------------------
// NL04 Integrity Contract — AUTHORITATIVE
// ---------------------------------------------------------------------------
//
// `verify_integrity` is the single point at which on-disk model artifacts
// are validated against an expected-hash table. Two distinct failure modes
// must NEVER be conflated:
//
//   1. TAMPERING — a file is present on disk, its sha256 was checked, and
//      the computed hash does NOT match the expected hash. This ALWAYS
//      yields `Err(ChecksumMismatch { file, expected, actual })`,
//      regardless of `allow_unverified`. The escape hatch covers
//      missingness only; it never silences hash mismatch on a present
//      file. This matches spec FR-7 + FR-13.
//
//   2. MISSINGNESS — `checksums.json` itself is absent, or a file listed
//      in `checksums.json` is absent on disk. In strict mode
//      (`allow_unverified == false`, the default per FR-7), missingness
//      is a fatal error (`ChecksumsMissing` / `ChecksummedFileMissing`).
//      With `allow_unverified == true`, missingness downgrades to a
//      `tracing::warn!` and the loader continues — but ALL still-present
//      files are still hashed.
//
// Trust mode (FR-14):
//   - `TrustMode::Trusted` (resolver levels 4-5): the on-disk
//     `checksums.json` is hashed and cross-checked against
//     `BAKED_MANIFEST.files["checksums.json"]`. A mismatch ALWAYS errors,
//     even when `allow_unverified == true`. This anchors integrity in
//     the binary itself rather than the operator-supplied directory.
//   - `TrustMode::Custom` (resolver levels 1-3): the local
//     `manifest.json` (parsed from disk in the same directory) is the
//     trust root. `Translator::new` is responsible for emitting the
//     loud `tracing::warn!` that integrity is rooted in user-supplied
//     data; this function focuses on the actual verification.
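//
// For reference, `checksums.json` is a flat filename -> sha256-hex map,
// and `manifest.json` carries a `files` map of the same shape that also
// lists `checksums.json` itself (hash values below are placeholders,
// not real digests):
//
//     {
//       "intent_classifier.onnx": "<64-char sha256 hex>",
//       "tokenizer.json": "<64-char sha256 hex>"
//     }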
// ---------------------------------------------------------------------------

/// Load checksums from `checksums.json` if present.
///
/// Returns `Ok(None)` when the file is absent (caller decides whether
/// that is fatal based on `allow_unverified`). Returns `Ok(Some(map))`
/// when present and parseable. Returns `Err` only on parse / I/O
/// failure — those are always fatal.
fn try_load_checksums(
    checksums_path: &Path,
) -> Result<Option<HashMap<String, String>>, ClassifierError> {
    if !checksums_path.exists() {
        return Ok(None);
    }
    let content = std::fs::read_to_string(checksums_path)
        .map_err(|e| ClassifierError::OnnxError(format!("Failed to read checksums.json: {e}")))?;
    let map = serde_json::from_str(&content)
        .map_err(|e| ClassifierError::OnnxError(format!("Failed to parse checksums.json: {e}")))?;
    Ok(Some(map))
}

/// Verify model directory integrity per the NL04 contract documented above.
///
/// See the module-level "NL04 Integrity Contract — AUTHORITATIVE" comment
/// block for the full tampering-vs-missingness rules. A short summary:
///
/// - Tampering on a present file ALWAYS errors.
/// - Missingness errors only when `allow_unverified == false`.
/// - In `TrustMode::Trusted`, `checksums.json`'s own bytes are
///   cross-checked against `BAKED_MANIFEST.files["checksums.json"]` —
///   a mismatch is ALWAYS fatal.
fn verify_integrity(
    model_dir: &Path,
    allow_unverified: bool,
    trust_mode: TrustMode,
) -> Result<(), ClassifierError> {
    verify_integrity_with_trusted_manifest(model_dir, allow_unverified, trust_mode, &BAKED_MANIFEST)
}

fn verify_integrity_with_trusted_manifest(
    model_dir: &Path,
    allow_unverified: bool,
    trust_mode: TrustMode,
    trusted_manifest: &Manifest,
) -> Result<(), ClassifierError> {
    let checksums_path = model_dir.join("checksums.json");

    match trust_mode {
        TrustMode::Trusted => {
            verify_trusted_checksums_anchor(&checksums_path, allow_unverified, trusted_manifest)?;
        }
        TrustMode::Custom => verify_custom_checksums_anchor(model_dir, &checksums_path)?,
    }

    // Per-file pass over `checksums.json`. Same tampering-vs-missingness
    // rules apply file-by-file.
    let Some(checksums) = try_load_checksums(&checksums_path)? else {
        if allow_unverified {
            tracing::warn!(
                "No checksums.json found in {} — allow_unverified=true; \
                 skipping integrity verification (development workflow)",
                model_dir.display()
            );
            return Ok(());
        }
        return Err(ClassifierError::ChecksumsMissing);
    };

    let mut verified_count = 0usize;
    for (filename, expected_hash) in &checksums {
        let file_path = model_dir.join(filename);
        if !file_path.exists() {
            // MISSINGNESS — strict by default, warn-and-skip with hatch.
            if allow_unverified {
                tracing::warn!(
                    "Checksummed file missing: {filename} — allow_unverified=true; \
                     continuing (other listed files will still be hashed)"
                );
                continue;
            }
            return Err(ClassifierError::ChecksummedFileMissing(filename.clone()));
        }

        let actual_hash = compute_file_hash(&file_path)?;
        if &actual_hash != expected_hash {
            // TAMPERING — ALWAYS fatal, regardless of allow_unverified.
            return Err(ClassifierError::ChecksumMismatch {
                file: filename.clone(),
                expected: expected_hash.clone(),
                actual: actual_hash,
            });
        }
        verified_count += 1;
        tracing::debug!("Verified checksum for {filename}");
    }
    tracing::info!(
        "Model integrity verified: {} of {} listed files checked",
        verified_count,
        checksums.len()
    );
    Ok(())
}

fn verify_trusted_checksums_anchor(
    checksums_path: &Path,
    allow_unverified: bool,
    trusted_manifest: &Manifest,
) -> Result<(), ClassifierError> {
    let Some(expected_checksums_hash) = trusted_manifest.files.get("checksums.json") else {
        return Ok(());
    };

    if checksums_path.exists() {
        verify_checksums_json_hash(
            checksums_path,
            expected_checksums_hash,
            "Trusted-mode anchor OK: checksums.json matches BAKED_MANIFEST",
        )
    } else if allow_unverified {
        tracing::warn!(
            "checksums.json missing under Trusted resolver level — \
             allow_unverified=true downgrades to warn; baked-in trust \
             anchor cannot be cross-checked"
        );
        Ok(())
    } else {
        Err(ClassifierError::ChecksumsMissing)
    }
}

fn verify_custom_checksums_anchor(
    model_dir: &Path,
    checksums_path: &Path,
) -> Result<(), ClassifierError> {
    let local_manifest_path = model_dir.join("manifest.json");
    if !local_manifest_path.exists() {
        return Err(ClassifierError::ManifestAnchorInvalid(format!(
            "manifest.json missing at {}",
            local_manifest_path.display()
        )));
    }

    let local_manifest = Manifest::parse_path(&local_manifest_path).map_err(|err| {
        ClassifierError::ManifestAnchorInvalid(format!(
            "failed to parse manifest.json at {}: {err}",
            local_manifest_path.display()
        ))
    })?;
    let expected_checksums_hash = local_manifest.files.get("checksums.json").ok_or_else(|| {
        ClassifierError::ManifestAnchorInvalid(format!(
            "manifest.files[\"checksums.json\"] missing in {}",
            local_manifest_path.display()
        ))
    })?;

    if checksums_path.exists() {
        verify_checksums_json_hash(
            checksums_path,
            expected_checksums_hash,
            "Custom-mode anchor OK: checksums.json matches local manifest.json",
        )
    } else {
        tracing::warn!(
            target: "sqry_nl::classifier",
            "Custom-mode integrity anchor skipped: checksums.json missing at {} \
             (operator-supplied dir without a complete manifest)",
            checksums_path.display()
        );
        Ok(())
    }
}

fn verify_checksums_json_hash(
    checksums_path: &Path,
    expected_checksums_hash: &str,
    success_message: &str,
) -> Result<(), ClassifierError> {
    let actual = compute_file_hash(checksums_path)?;
    if actual != expected_checksums_hash {
        // TAMPERING — always fatal, no opt-out.
        return Err(ClassifierError::ChecksumMismatch {
            file: "checksums.json".to_string(),
            expected: expected_checksums_hash.to_string(),
            actual,
        });
    }
    tracing::debug!("{success_message}");
    Ok(())
}

/// Parse model version from version.txt content.
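///
/// Expected `version.txt` shape (mirrors the fixture in the unit test
/// below); a file without a `model_version=` line yields `"unknown"`:
///
/// ```text
/// model_version=1.0.0
/// model_date=2025-12-09T07:34:00Z
/// accuracy=0.9998
/// ```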
fn parse_model_version(content: &str) -> String {
    for line in content.lines() {
        let line = line.trim();
        if line.starts_with("model_version=") {
            return line
                .strip_prefix("model_version=")
                .unwrap_or("unknown")
                .to_string();
        }
    }
    "unknown".to_string()
}

impl IntentClassifier {
    /// Load classifier from model directory.
    ///
    /// Expected directory structure:
    /// ```text
    /// model_dir/
    /// ├── intent_classifier.onnx
    /// ├── tokenizer.json
    /// ├── config.json
    /// ├── calibration.json or temperature.json (optional)
    /// ├── checksums.json
    /// └── version.txt
    /// ```
    ///
    /// # Arguments
    ///
    /// * `model_dir` — Resolved model directory (output of NL02
    ///   resolver chain).
    /// * `allow_unverified` — Operator escape hatch. When `false`
    ///   (NL04 default per FR-7), missingness is fatal. When `true`,
    ///   missingness downgrades to `tracing::warn!`. **Tampering on a
    ///   present file ALWAYS errors regardless of this flag** — see
    ///   the inline contract documented at [`verify_integrity`].
    /// * `trust_mode` — Output of [`TrustMode::from(ResolverLevel)`].
    ///   Trusted mode anchors integrity in the binary's baked-in
    ///   manifest; Custom mode trusts the user-supplied
    ///   `manifest.json` shipped alongside the model directory.
    ///
    /// # Errors
    ///
    /// Returns [`ClassifierError`] if:
    /// - Model files not found
    /// - Checksum verification fails (AC-11.8 / NL04 integrity contract)
    /// - ONNX Runtime initialization fails
    pub fn load(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<Self, ClassifierError> {
        Self::load_inner(model_dir, allow_unverified, trust_mode)
    }

    /// Run only the NL04 integrity contract for a model directory,
    /// without invoking ONNX Runtime.
    ///
    /// Same contract as [`Self::load`]'s integrity pass — exists so
    /// integration tests can exercise the contract on synthetic
    /// fixtures (stub ONNX bytes) without the dylib dependency.
    ///
    /// # Errors
    ///
    /// Returns [`ClassifierError::ChecksumMismatch`] /
    /// [`ClassifierError::ChecksumsMissing`] /
    /// [`ClassifierError::ChecksummedFileMissing`] per the contract.
    #[doc(hidden)]
    pub fn verify_integrity_for_tests(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<(), ClassifierError> {
        verify_integrity(model_dir, allow_unverified, trust_mode)
    }

    /// Run the NL04 integrity contract with a test-supplied trusted
    /// manifest instead of the binary's baked model manifest.
    ///
    /// This keeps active integration tests hermetic: they can exercise
    /// the Trusted-mode anchor and strict per-file pass against
    /// synthetic model fixtures without committing the large external
    /// ONNX model tree.
    ///
    /// # Errors
    ///
    /// Returns the same [`ClassifierError`] variants as
    /// [`Self::verify_integrity_for_tests`].
    #[doc(hidden)]
    pub fn verify_integrity_with_manifest_for_tests(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
        trusted_manifest: &Manifest,
    ) -> Result<(), ClassifierError> {
        verify_integrity_with_trusted_manifest(
            model_dir,
            allow_unverified,
            trust_mode,
            trusted_manifest,
        )
    }

    fn load_inner(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<Self, ClassifierError> {
        // NL08: deterministic test seam — when
        // `SQRY_NL_FORCE_ORT_MISSING` is truthy AND we are running a
        // debug build (cargo test / cargo run), short-circuit straight
        // to `OnnxRuntimeMissing`. This lets the CLI / MCP / LSP
        // integration tests drive the missing-runtime path without
        // needing an actual missing libonnxruntime on the host. The
        // helper is a no-op in release builds.
        if ort_missing_forced() {
            return Err(onnx_runtime_missing_error());
        }

        // Check model directory exists
        if !model_dir.exists() {
            return Err(ClassifierError::ModelNotFound(
                model_dir.display().to_string(),
            ));
        }

        // Verify integrity BEFORE any artifact load — this is the
        // first-fail gate per the NL04 integrity contract. Tampering
        // detection happens here, prior to ONNX session creation, so
        // synthetic test fixtures (stub ONNX bytes) can exercise the
        // contract without invoking the inference engine.
        verify_integrity(model_dir, allow_unverified, trust_mode)?;

        let model_path = model_dir.join("intent_classifier.onnx");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(ClassifierError::ModelNotFound(
                model_path.display().to_string(),
            ));
        }

        if !tokenizer_path.exists() {
            return Err(ClassifierError::ModelNotFound(
                tokenizer_path.display().to_string(),
            ));
        }

        // Load ONNX session.
        //
        // NL08: the `ort` crate panics in `setup_api()` (with the
        // `load-dynamic` feature) if `libonnxruntime` cannot be loaded,
        // so we wrap the whole builder chain in `catch_unwind` and
        // reinterpret either a panic or any error string that looks
        // like a dylib-load failure as
        // `ClassifierError::OnnxRuntimeMissing` so callers can surface
        // an actionable platform-specific install hint.
        let model_path_for_load = model_path.clone();
        let session_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Session::builder()?
                .with_intra_threads(1)?
                .commit_from_file(&model_path_for_load)
        }));
        let session = match session_result {
            Ok(Ok(session)) => session,
            Ok(Err(e)) => {
                let msg = e.to_string();
                if looks_like_dylib_load_failure(&msg) {
                    return Err(onnx_runtime_missing_error());
                }
                return Err(ClassifierError::OnnxError(msg));
            }
            Err(panic_payload) => {
                let panic_msg = panic_payload
                    .downcast_ref::<&'static str>()
                    .map(|s| (*s).to_string())
                    .or_else(|| panic_payload.downcast_ref::<String>().cloned())
                    .unwrap_or_else(|| "ort panic with unknown payload".to_string());
                if looks_like_dylib_load_failure(&panic_msg) {
                    return Err(onnx_runtime_missing_error());
                }
                // Any other panic from ort is escalated as a generic
                // ONNX error rather than re-thrown — translator
                // construction must always return a typed error.
                return Err(ClassifierError::OnnxError(format!(
                    "ort panic during session init: {panic_msg}"
                )));
            }
        };

        // Detect whether model expects token_type_ids (BERT vs DistilBERT)
        let model_inputs = session.inputs();
        let has_token_type_ids = model_inputs
            .iter()
            .any(|input| input.name() == "token_type_ids");
        tracing::debug!(
            "Model inputs: {:?}, has_token_type_ids: {has_token_type_ids}",
            model_inputs
                .iter()
                .map(ort::value::Outlet::name)
                .collect::<Vec<_>>()
        );

        // Load tokenizer
        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| ClassifierError::TokenizationFailed(e.to_string()))?;

        // Load calibration (optional) — try calibration.json first, then temperature.json
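        //
        // An illustrative `temperature.json` payload (field set assumed
        // from `CalibrationParams`; a missing or unparsable file falls
        // back to `CalibrationParams::default()`):
        //
        //     { "temperature": 1.5 }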
        let calibration_path = model_dir.join("calibration.json");
        let temperature_path = model_dir.join("temperature.json");
        let calibration = if calibration_path.exists() {
            let content = std::fs::read_to_string(&calibration_path)
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            serde_json::from_str(&content).unwrap_or_default()
        } else if temperature_path.exists() {
            let content = std::fs::read_to_string(&temperature_path)
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            let params: CalibrationParams = serde_json::from_str(&content).unwrap_or_default();
            tracing::debug!(
                "Loaded calibration temperature={} from temperature.json",
                params.temperature
            );
            params
        } else {
            CalibrationParams::default()
        };

        // Load and parse version
        let version_path = model_dir.join("version.txt");
        let model_version = if version_path.exists() {
            std::fs::read_to_string(&version_path)
                .map_or_else(|_| "unknown".to_string(), |s| parse_model_version(&s))
        } else {
            "unknown".to_string()
        };

        Ok(Self {
            session,
            tokenizer,
            calibration,
            model_version,
            has_token_type_ids,
        })
    }

    /// Classify intent from natural language input.
    ///
    /// # Critical: `batch_size=1` enforcement (C1 mitigation)
    ///
    /// ONNX Runtime may crash with `batch_size` > 1. This method
    /// always processes exactly one input.
    ///
    /// # Errors
    ///
    /// Returns [`ClassifierError`] if tokenization or inference fails.
    ///
    /// # Note
    ///
    /// This method requires `&mut self` due to ort 2.0 API requirements.
    /// Use a Mutex wrapper if concurrent access is needed.
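    ///
    /// # Example
    ///
    /// A hedged usage sketch (the input string and public import path
    /// are assumptions; marked `ignore` so it is not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let mut classifier = IntentClassifier::load(model_dir, false, TrustMode::Trusted)?;
    /// let result = classifier.classify("show failed logins from the last hour")?;
    /// println!("{:?} (confidence {:.2})", result.intent, result.confidence);
    /// ```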
    pub fn classify(&mut self, input: &str) -> Result<ClassificationResult, ClassifierError> {
        // Tokenize input
        let encoding = self
            .tokenizer
            .encode(input, true)
            .map_err(|e| ClassifierError::TokenizationFailed(e.to_string()))?;

        let input_ids = encoding.get_ids();
        let attention_mask = encoding.get_attention_mask();

        // Truncate to max 512 tokens
        let seq_len = input_ids.len().min(512);
        if input_ids.len() > 512 {
            tracing::warn!("Input truncated from {} to 512 tokens", input_ids.len());
        }

        // Prepare input tensors (batch_size=1)
        let input_ids_i64: Vec<i64> = input_ids[..seq_len].iter().map(|&x| i64::from(x)).collect();
        let attention_mask_i64: Vec<i64> = attention_mask[..seq_len]
            .iter()
            .map(|&x| i64::from(x))
            .collect();

        // Create input tensors with shape [1, seq_len] - ort 2.0 requires Vec not slice
        let input_ids_tensor = Tensor::from_array(([1, seq_len], input_ids_i64))
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
        let attention_mask_tensor = Tensor::from_array(([1, seq_len], attention_mask_i64))
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;

        // Build inputs conditionally: BERT-family models (MiniLM) require token_type_ids,
        // while DistilBERT does not declare it. ort rejects undeclared input names.
        let inputs = if self.has_token_type_ids {
            let type_ids = encoding.get_type_ids();
            let token_type_ids_i64: Vec<i64> =
                type_ids[..seq_len].iter().map(|&x| i64::from(x)).collect();
            let token_type_ids_tensor = Tensor::from_array(([1, seq_len], token_type_ids_i64))
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor,
            ]
        } else {
            ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
            ]
        };

        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;

        // Extract logits from output
        let logits_tensor = outputs
            .get("logits")
            .ok_or_else(|| ClassifierError::OnnxError("No 'logits' output".to_string()))?;

        // try_extract_tensor returns (&Shape, &[T]) tuple in ort 2.0
        let (_, logits_data) = logits_tensor
            .try_extract_tensor::<f32>()
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;

        let logits: Vec<f32> = logits_data.to_vec();

        // Apply calibration and softmax
        let probabilities = self.calibration.apply_temperature_scaling(&logits);

        // Find argmax
        let (intent_idx, confidence) = probabilities
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or((Intent::NUM_CLASSES - 1, 0.0), |(idx, &conf)| (idx, conf)); // Default to Ambiguous

        let intent = Intent::from_index(intent_idx);

        Ok(ClassificationResult {
            intent,
            confidence,
            all_probabilities: probabilities,
            model_version: self.model_version.clone(),
        })
    }

    /// Get the model version.
    #[must_use]
    pub fn model_version(&self) -> &str {
        &self.model_version
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_version() {
        let content = r"
# sqry-nl Intent Classifier Model
model_version=1.0.0
model_date=2025-12-09T07:34:00Z
accuracy=0.9998
";
        assert_eq!(parse_model_version(content), "1.0.0");
    }

    #[test]
    fn test_parse_model_version_missing() {
        let content = "# No version here\naccuracy=0.99";
        assert_eq!(parse_model_version(content), "unknown");
    }

    #[test]
    fn test_parse_model_version_empty() {
        assert_eq!(parse_model_version(""), "unknown");
    }

    // Tests requiring actual model files are marked as ignored
    // and run during integration testing.

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_classifier_load() {
        // Would test model loading
    }

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_classifier_inference() {
        // Would test inference
    }

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_checksum_verification() {
        // Would test checksum verification against deployed model
    }
}