shimmy_llama_cpp_2/model/params.rs

//! A safe wrapper around `llama_model_params`.

use crate::model::params::kv_overrides::KvOverrides;
use std::ffi::{c_char, CStr, CString};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::ptr::null;

pub mod kv_overrides;

/// A safe wrapper around `llama_model_params`.
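///
/// # Examples
///
/// Parameters are usually built from [`Default`] and customized with the
/// builder-style `with_*` setters defined below (a minimal sketch that only
/// uses methods from this module):
///
/// ```rust
/// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
/// let params = LlamaModelParams::default()
///     .with_n_gpu_layers(32)
///     .with_use_mlock(false);
/// assert_eq!(params.n_gpu_layers(), 32);
/// assert!(!params.use_mlock());
/// ```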
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    pub(crate) params: shimmy_llama_cpp_sys_2::llama_model_params,
    kv_overrides: Vec<shimmy_llama_cpp_sys_2::llama_model_kv_override>,
    /// Storage for tensor buffer override patterns (keeps CStrings alive)
    tensor_override_patterns: Vec<CString>,
    /// Tensor buffer overrides (NULL-terminated array)
    tensor_overrides: Vec<shimmy_llama_cpp_sys_2::llama_model_tensor_buft_override>,
}

impl Debug for LlamaModelParams {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaModelParams")
            .field("n_gpu_layers", &self.params.n_gpu_layers)
            .field("main_gpu", &self.params.main_gpu)
            .field("vocab_only", &self.params.vocab_only)
            .field("use_mmap", &self.params.use_mmap)
            .field("use_mlock", &self.params.use_mlock)
            .field("kv_overrides", &"vec of kv_overrides")
            .finish()
    }
}

impl LlamaModelParams {
    /// See [`KvOverrides`]
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = Box::pin(LlamaModelParams::default());
    /// let kv_overrides = params.kv_overrides();
    /// let count = kv_overrides.into_iter().count();
    /// assert_eq!(count, 0);
    /// ```
    #[must_use]
    pub fn kv_overrides(&self) -> KvOverrides {
        KvOverrides::new(self)
    }

    /// Appends a key-value override to the model parameters. `self` must be pinned because this
    /// creates a self-referential struct.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::ffi::{CStr, CString};
    /// use std::pin::pin;
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// # use shimmy_llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
    /// let mut params = pin!(LlamaModelParams::default());
    /// let key = CString::new("key").expect("CString::new failed");
    /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50));
    ///
    /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
    /// assert_eq!(kv_overrides.len(), 1);
    ///
    /// let (k, v) = &kv_overrides[0];
    /// assert_eq!(v, &ParamOverrideValue::Int(50));
    ///
    /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
    /// ```
    #[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors
    pub fn append_kv_override(
        mut self: Pin<&mut Self>,
        key: &CStr,
        value: kv_overrides::ParamOverrideValue,
    ) {
        // Write into the trailing sentinel element (the vector always ends with a zeroed entry).
        let kv_override = self
            .kv_overrides
            .last_mut()
            .expect("kv_overrides did not have a next allocated");

        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");

        // There should be some way to do this without iterating over everything.
        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
        }

        kv_override.tag = value.tag();
        kv_override.__bindgen_anon_1 = value.value();

        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
        self.params.kv_overrides = null();

        // push the next one to ensure we maintain the iterator invariant of ending with a 0
        self.kv_overrides
            .push(shimmy_llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: shimmy_llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            });

        // set the pointer to the (potentially) new vector
        self.params.kv_overrides = self.kv_overrides.as_ptr();
    }
}

impl LlamaModelParams {
    /// Get the number of layers to offload to the GPU.
    #[must_use]
    pub fn n_gpu_layers(&self) -> i32 {
        self.params.n_gpu_layers
    }

    /// The GPU that is used for scratch and small tensors
    #[must_use]
    pub fn main_gpu(&self) -> i32 {
        self.params.main_gpu
    }

    /// only load the vocabulary, no weights
    #[must_use]
    pub fn vocab_only(&self) -> bool {
        self.params.vocab_only
    }

    /// use mmap if possible
    #[must_use]
    pub fn use_mmap(&self) -> bool {
        self.params.use_mmap
    }

    /// force system to keep model in RAM
    #[must_use]
    pub fn use_mlock(&self) -> bool {
        self.params.use_mlock
    }

    /// sets the number of GPU layers to offload to the GPU.
    ///
    /// # Examples
    ///
    /// ```
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default();
    /// let params = params.with_n_gpu_layers(1);
    /// assert_eq!(params.n_gpu_layers(), 1);
    /// ```
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
        // The only way this conversion can fail is if the u32 overflows the i32 - in which case we
        // clamp to i32::MAX
        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
        self.params.n_gpu_layers = n_gpu_layers;
        self
    }

    /// sets the main GPU
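    ///
    /// # Examples
    ///
    /// A minimal round-trip through the corresponding getter:
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_main_gpu(1);
    /// assert_eq!(params.main_gpu(), 1);
    /// ```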
    #[must_use]
    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
        self.params.main_gpu = main_gpu;
        self
    }

    /// sets `vocab_only`
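    ///
    /// # Examples
    ///
    /// Useful when only the tokenizer/vocabulary is needed and the weights should not be loaded:
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_vocab_only(true);
    /// assert!(params.vocab_only());
    /// ```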
    #[must_use]
    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
        self.params.vocab_only = vocab_only;
        self
    }

    /// sets `use_mlock`
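    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_use_mlock(true);
    /// assert!(params.use_mlock());
    /// ```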
    #[must_use]
    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
        self.params.use_mlock = use_mlock;
        self
    }

    /// Offload ALL MoE (Mixture of Experts) expert tensors to CPU.
    ///
    /// This reduces VRAM usage for large MoE models (e.g., GPT-OSS, Qwen) by keeping
    /// expert tensors in system RAM while attention layers remain on the GPU.
    ///
    /// Matches the behavior of the llama.cpp `--cpu-moe` flag.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default()
    ///     .with_n_gpu_layers(999)  // offload all non-expert layers to the GPU
    ///     .with_cpu_moe_all();     // but keep expert tensors on the CPU
    /// ```
    #[must_use]
    pub fn with_cpu_moe_all(mut self) -> Self {
        self.push_tensor_override(r"\.ffn_(up|down|gate)_exps");
        self
    }

    /// Offload the first N MoE layers' expert tensors to CPU.
    ///
    /// This allows fine-grained control over VRAM usage by offloading only some expert layers.
    /// Typically used when you have limited VRAM and want to balance GPU/CPU usage.
    ///
    /// Matches the behavior of the llama.cpp `--n-cpu-moe N` flag.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default()
    ///     .with_n_gpu_layers(999)  // offload layers to the GPU
    ///     .with_n_cpu_moe(10);     // but keep the first 10 layers' expert tensors on the CPU
    /// ```
    #[must_use]
    pub fn with_n_cpu_moe(mut self, n: usize) -> Self {
        for i in 0..n {
            let pattern = format!(r"blk\.{}\.ffn_(up|down|gate)_exps", i);
            self.push_tensor_override(&pattern);
        }
        self
    }

    /// Internal: Push a tensor buffer override pattern.
    ///
    /// Patterns are regular expressions matched against tensor names in the model.
    /// Matched tensors are allocated using the CPU buffer type instead of a GPU buffer type.
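    ///
    /// The override list handed to llama.cpp is always NULL-terminated; after two calls it
    /// looks roughly like `[{pattern_a, cpu_buft}, {pattern_b, cpu_buft}, {NULL, NULL}]`
    /// (pattern names here are illustrative).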
    fn push_tensor_override(&mut self, pattern: &str) {
        // Create and store CString to keep it alive
        let c_pattern = CString::new(pattern).expect("pattern must not contain NUL bytes");

        self.tensor_override_patterns.push(c_pattern);

        // Get CPU buffer type
        let cpu_buft = unsafe { shimmy_llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };

        // Create override entry pointing to the stored CString
        let override_entry = shimmy_llama_cpp_sys_2::llama_model_tensor_buft_override {
            pattern: self.tensor_override_patterns.last().unwrap().as_ptr(),
            buft: cpu_buft,
        };

        // Remove old NULL terminator if it exists
        if let Some(last) = self.tensor_overrides.last() {
            if last.pattern.is_null() {
                self.tensor_overrides.pop();
            }
        }

        // Add new entry
        self.tensor_overrides.push(override_entry);

        // Re-add NULL terminator (pattern=NULL signals end of array to C)
        self.tensor_overrides
            .push(shimmy_llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            });

        // Update C params pointer (the vector may have reallocated)
        self.params.tensor_buft_overrides = self.tensor_overrides.as_ptr();
    }
}

/// Default parameters for `LlamaModel`, as defined in llama.cpp by `llama_model_default_params`.
/// ```
/// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
/// let params = LlamaModelParams::default();
/// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999");
/// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
/// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
/// ```
impl Default for LlamaModelParams {
    fn default() -> Self {
        let default_params = unsafe { shimmy_llama_cpp_sys_2::llama_model_default_params() };
        LlamaModelParams {
            params: default_params,
            // start with a single zeroed entry so the overrides list always ends with a 0 sentinel
            kv_overrides: vec![shimmy_llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: shimmy_llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            }],
            tensor_override_patterns: Vec::new(),
            tensor_overrides: Vec::new(),
        }
    }
}