// walrus_model/local/mod.rs
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::Context as _;

pub mod download;
mod provider;

14#[derive(Clone)]
16pub struct Local {
17 model: Arc<mistralrs::Model>,
18}
19
20impl Local {
21 pub fn from_model(model: mistralrs::Model) -> Self {
23 Self {
24 model: Arc::new(model),
25 }
26 }
27
28 pub async fn from_text(
32 model_id: &str,
33 isq: Option<mistralrs::IsqType>,
34 chat_template: Option<&str>,
35 ) -> anyhow::Result<Self> {
36 let mut builder = mistralrs::TextModelBuilder::new(model_id)
37 .with_logging()
38 .from_hf_cache_pathf(cache_dir());
39 if let Some(isq) = isq {
40 builder = builder.with_isq(isq);
41 }
42 if let Some(template) = chat_template {
43 builder = builder.with_chat_template(template);
44 }
45 let model = builder.build().await?;
46 Ok(Self::from_model(model))
47 }
48
49 pub async fn from_gguf(model_id: &str, chat_template: Option<&str>) -> anyhow::Result<Self> {
54 let mut builder =
56 mistralrs::GgufModelBuilder::new(model_id, Vec::<String>::new()).with_logging();
57 if let Some(template) = chat_template {
58 builder = builder.with_chat_template(template);
59 }
60 let model = builder.build().await?;
61 Ok(Self::from_model(model))
62 }
63
64 pub fn model(&self) -> &mistralrs::Model {
66 &self.model
67 }
68
69 pub fn context_length(&self, model: &str) -> Option<usize> {
73 self.model
74 .max_sequence_length_with_model(Some(model))
75 .ok()
76 .flatten()
77 }
78
79 pub async fn from_vision(
83 model_id: &str,
84 isq: Option<mistralrs::IsqType>,
85 chat_template: Option<&str>,
86 ) -> anyhow::Result<Self> {
87 let mut builder = mistralrs::VisionModelBuilder::new(model_id)
88 .with_logging()
89 .from_hf_cache_pathf(cache_dir());
90 if let Some(isq) = isq {
91 builder = builder.with_isq(isq);
92 }
93 if let Some(template) = chat_template {
94 builder = builder.with_chat_template(template);
95 }
96 let model = builder.build().await?;
97 Ok(Self::from_model(model))
98 }
99}
100
101pub(crate) fn cache_dir() -> PathBuf {
103 dirs::home_dir()
104 .expect("no home directory")
105 .join(".walrus")
106 .join("hf")
107}