// walrus_model/local/mod.rs
//! Local LLM provider via mistralrs.
//!
//! Wraps `mistralrs::Model` for native on-device inference.
//! No HTTP transport — inference runs in-process.
//! Provides per-builder constructors: `from_text()`, `from_gguf()`,
//! `from_vision()`. All use the walrus model cache directory.
8use std::path::PathBuf;
9use std::sync::Arc;
10
11pub mod download;
12mod provider;
13
14/// Local LLM provider wrapping a mistralrs `Model`.
15#[derive(Clone)]
16pub struct Local {
17    model: Arc<mistralrs::Model>,
18}
19
20impl Local {
21    /// Construct from a pre-built mistralrs `Model`.
22    pub fn from_model(model: mistralrs::Model) -> Self {
23        Self {
24            model: Arc::new(model),
25        }
26    }
27
28    /// Build using `TextModelBuilder`.
29    ///
30    /// Standard text models from HuggingFace.
31    pub async fn from_text(
32        model_id: &str,
33        isq: Option<mistralrs::IsqType>,
34        chat_template: Option<&str>,
35    ) -> anyhow::Result<Self> {
36        let mut builder = mistralrs::TextModelBuilder::new(model_id)
37            .with_logging()
38            .from_hf_cache_pathf(cache_dir());
39        if let Some(isq) = isq {
40            builder = builder.with_isq(isq);
41        }
42        if let Some(template) = chat_template {
43            builder = builder.with_chat_template(template);
44        }
45        let model = builder.build().await?;
46        Ok(Self::from_model(model))
47    }
48
49    /// Build using `GgufModelBuilder`.
50    ///
51    /// GGUF quantized models from HuggingFace. The `model_id` is the HF repo
52    /// ID; mistralrs auto-discovers GGUF files in the repo.
53    pub async fn from_gguf(model_id: &str, chat_template: Option<&str>) -> anyhow::Result<Self> {
54        // Pass empty files vec — mistralrs will auto-detect GGUF files.
55        let mut builder =
56            mistralrs::GgufModelBuilder::new(model_id, Vec::<String>::new()).with_logging();
57        if let Some(template) = chat_template {
58            builder = builder.with_chat_template(template);
59        }
60        let model = builder.build().await?;
61        Ok(Self::from_model(model))
62    }
63
64    /// Access the inner mistralrs `Model`.
65    pub fn model(&self) -> &mistralrs::Model {
66        &self.model
67    }
68
69    /// Query the context length for a given model ID.
70    ///
71    /// Returns None if the model doesn't report a sequence length.
72    pub fn context_length(&self, model: &str) -> Option<usize> {
73        self.model
74            .max_sequence_length_with_model(Some(model))
75            .ok()
76            .flatten()
77    }
78
79    /// Build using `VisionModelBuilder`.
80    ///
81    /// Vision-language models from HuggingFace.
82    pub async fn from_vision(
83        model_id: &str,
84        isq: Option<mistralrs::IsqType>,
85        chat_template: Option<&str>,
86    ) -> anyhow::Result<Self> {
87        let mut builder = mistralrs::VisionModelBuilder::new(model_id)
88            .with_logging()
89            .from_hf_cache_pathf(cache_dir());
90        if let Some(isq) = isq {
91            builder = builder.with_isq(isq);
92        }
93        if let Some(template) = chat_template {
94            builder = builder.with_chat_template(template);
95        }
96        let model = builder.build().await?;
97        Ok(Self::from_model(model))
98    }
99}
100
101/// Walrus HF cache directory: `~/.walrus/hf`.
102pub(crate) fn cache_dir() -> PathBuf {
103    dirs::home_dir()
104        .expect("no home directory")
105        .join(".walrus")
106        .join("hf")
107}