shimmy_llama_cpp_2/model/params.rs
//! A safe wrapper around `llama_model_params`.

use crate::model::params::kv_overrides::KvOverrides;
use std::ffi::{c_char, CStr, CString};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::ptr::null;

pub mod kv_overrides;

/// A safe wrapper around `llama_model_params`.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    pub(crate) params: shimmy_llama_cpp_sys_2::llama_model_params,
    kv_overrides: Vec<shimmy_llama_cpp_sys_2::llama_model_kv_override>,
    /// Storage for tensor buffer override patterns (keeps CStrings alive)
    tensor_override_patterns: Vec<CString>,
    /// Tensor buffer overrides (NULL-terminated array)
    tensor_overrides: Vec<shimmy_llama_cpp_sys_2::llama_model_tensor_buft_override>,
}

impl Debug for LlamaModelParams {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaModelParams")
            .field("n_gpu_layers", &self.params.n_gpu_layers)
            .field("main_gpu", &self.params.main_gpu)
            .field("vocab_only", &self.params.vocab_only)
            .field("use_mmap", &self.params.use_mmap)
            .field("use_mlock", &self.params.use_mlock)
            .field("kv_overrides", &"vec of kv_overrides")
            .finish()
    }
}

impl LlamaModelParams {
    /// See [`KvOverrides`]
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = Box::pin(LlamaModelParams::default());
    /// let kv_overrides = params.kv_overrides();
    /// let count = kv_overrides.into_iter().count();
    /// assert_eq!(count, 0);
    /// ```
    #[must_use]
    pub fn kv_overrides(&self) -> KvOverrides {
        KvOverrides::new(self)
    }

    /// Appends a key-value override to the model parameters. The struct must be
    /// pinned, as appending creates a self-referential struct.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::ffi::{CStr, CString};
    /// use std::pin::pin;
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// # use shimmy_llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
    /// let mut params = pin!(LlamaModelParams::default());
    /// let key = CString::new("key").expect("CString::new failed");
    /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50));
    ///
    /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
    /// assert_eq!(kv_overrides.len(), 1);
    ///
    /// let (k, v) = &kv_overrides[0];
    /// assert_eq!(v, &ParamOverrideValue::Int(50));
    ///
    /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
    /// ```
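    ///
    /// Appending more than one override keeps the terminating sentinel in place; a small
    /// sketch under the same crate-path assumption as the other examples in this module:
    ///
    /// ```rust
    /// # use std::ffi::CString;
    /// # use std::pin::pin;
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// # use shimmy_llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
    /// let mut params = pin!(LlamaModelParams::default());
    /// let first = CString::new("first").expect("CString::new failed");
    /// let second = CString::new("second").expect("CString::new failed");
    /// params.as_mut().append_kv_override(&first, ParamOverrideValue::Int(1));
    /// params.as_mut().append_kv_override(&second, ParamOverrideValue::Int(2));
    ///
    /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
    /// assert_eq!(kv_overrides.len(), 2);
    /// ```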
    #[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors
    pub fn append_kv_override(
        mut self: Pin<&mut Self>,
        key: &CStr,
        value: kv_overrides::ParamOverrideValue,
    ) {
        let kv_override = self
            .kv_overrides
            .last_mut()
            .expect("kv_overrides did not have a next allocated");

        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");

        // There should be some way to do this without iterating over everything.
        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
        }

        kv_override.tag = value.tag();
        kv_override.__bindgen_anon_1 = value.value();

        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
        self.params.kv_overrides = null();

        // push the next one to ensure we maintain the iterator invariant of ending with a 0
        self.kv_overrides
            .push(shimmy_llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: shimmy_llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            });

        // set the pointer to the (potentially) new vector
        self.params.kv_overrides = self.kv_overrides.as_ptr();
    }
}

impl LlamaModelParams {
    /// Get the number of layers to offload to the GPU.
    #[must_use]
    pub fn n_gpu_layers(&self) -> i32 {
        self.params.n_gpu_layers
    }

    /// The GPU that is used for scratch and small tensors
    #[must_use]
    pub fn main_gpu(&self) -> i32 {
        self.params.main_gpu
    }

    /// only load the vocabulary, no weights
    #[must_use]
    pub fn vocab_only(&self) -> bool {
        self.params.vocab_only
    }

    /// use mmap if possible
    #[must_use]
    pub fn use_mmap(&self) -> bool {
        self.params.use_mmap
    }

    /// force system to keep model in RAM
    #[must_use]
    pub fn use_mlock(&self) -> bool {
        self.params.use_mlock
    }

    /// sets the number of layers to offload to the GPU
    /// ```
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default();
    /// let params = params.with_n_gpu_layers(1);
    /// assert_eq!(params.n_gpu_layers(), 1);
    /// ```
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
        // The only way this conversion can fail is if the u32 overflows an i32, in which
        // case we clamp to i32::MAX.
        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
        self.params.n_gpu_layers = n_gpu_layers;
        self
    }

    /// sets the main GPU
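    ///
    /// # Examples
    ///
    /// A minimal round-trip sketch, mirroring the other doc tests in this module
    /// (the `shimmy_llama_cpp_2` crate path is assumed):
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_main_gpu(1);
    /// assert_eq!(params.main_gpu(), 1);
    /// ```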
    #[must_use]
    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
        self.params.main_gpu = main_gpu;
        self
    }

    /// sets `vocab_only`
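    ///
    /// # Examples
    ///
    /// A small round-trip sketch (same crate-path assumption as the other examples here):
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_vocab_only(true);
    /// assert!(params.vocab_only());
    /// ```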
    #[must_use]
    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
        self.params.vocab_only = vocab_only;
        self
    }

    /// sets `use_mlock`
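    ///
    /// # Examples
    ///
    /// A small round-trip sketch (same crate-path assumption as the other examples here);
    /// this only sets the flag, no memory is locked until the model is loaded:
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_use_mlock(true);
    /// assert!(params.use_mlock());
    /// ```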
    #[must_use]
    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
        self.params.use_mlock = use_mlock;
        self
    }

    /// Offload ALL MoE (Mixture of Experts) expert tensors to the CPU.
    ///
    /// This reduces VRAM usage for large MoE models (e.g., GPT-OSS, Qwen) by keeping
    /// expert tensors in system RAM while attention layers remain on the GPU.
    ///
    /// Matches the behavior of the llama.cpp `--cpu-moe` flag.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default()
    ///     .with_n_gpu_layers(999) // Offload all non-expert layers to the GPU
    ///     .with_cpu_moe_all();    // But keep expert tensors on the CPU
    /// ```
    #[must_use]
    pub fn with_cpu_moe_all(mut self) -> Self {
        self.push_tensor_override(r"\.ffn_(up|down|gate)_exps");
        self
    }

    /// Offload the first N MoE layers' expert tensors to the CPU.
    ///
    /// This allows fine-grained control over VRAM usage by offloading only some expert layers.
    /// Typically used when you have limited VRAM and want to balance GPU/CPU usage.
    ///
    /// Matches the behavior of the llama.cpp `--n-cpu-moe N` flag.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default()
    ///     .with_n_gpu_layers(999) // Offload layers to the GPU
    ///     .with_n_cpu_moe(10);    // Except the first 10 layers' expert tensors -> CPU
    /// ```
    #[must_use]
    pub fn with_n_cpu_moe(mut self, n: usize) -> Self {
        for i in 0..n {
            let pattern = format!(r"blk\.{}\.ffn_(up|down|gate)_exps", i);
            self.push_tensor_override(&pattern);
        }
        self
    }

    /// Internal: push a tensor buffer override pattern.
    ///
    /// Patterns are regular expressions matched against tensor names in the model.
    /// Matched tensors are allocated using the CPU buffer type instead of the GPU.
    fn push_tensor_override(&mut self, pattern: &str) {
        // Create and store the CString to keep it alive
        let c_pattern = CString::new(pattern).expect("pattern must not contain NUL bytes");
        self.tensor_override_patterns.push(c_pattern);

        // Get the CPU buffer type
        let cpu_buft = unsafe { shimmy_llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };

        // Create an override entry pointing to the stored CString
        let override_entry = shimmy_llama_cpp_sys_2::llama_model_tensor_buft_override {
            pattern: self.tensor_override_patterns.last().unwrap().as_ptr(),
            buft: cpu_buft,
        };

        // Remove the old NULL terminator if it exists
        if let Some(last) = self.tensor_overrides.last() {
            if last.pattern.is_null() {
                self.tensor_overrides.pop();
            }
        }

        // Add the new entry
        self.tensor_overrides.push(override_entry);

        // Re-add the NULL terminator (pattern == NULL signals the end of the array to C)
        self.tensor_overrides
            .push(shimmy_llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            });

        // Update the pointer handed to the C params
        self.params.tensor_buft_overrides = self.tensor_overrides.as_ptr();
    }
}

/// Default parameters for `LlamaModel`, as defined in llama.cpp by `llama_model_default_params`.
/// ```
/// # use shimmy_llama_cpp_2::model::params::LlamaModelParams;
/// let params = LlamaModelParams::default();
/// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999");
/// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
/// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
/// ```
impl Default for LlamaModelParams {
    fn default() -> Self {
        let default_params = unsafe { shimmy_llama_cpp_sys_2::llama_model_default_params() };
        LlamaModelParams {
            params: default_params,
            // push the first sentinel to maintain the iterator invariant of ending with a 0
            kv_overrides: vec![shimmy_llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: shimmy_llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            }],
            tensor_override_patterns: Vec::new(),
            tensor_overrides: Vec::new(),
        }
    }
}
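
// A minimal test sketch for the invariants kept by `push_tensor_override` and the MoE
// builder methods above. Like the doc tests in this module, it assumes the linked
// llama.cpp backend symbols are available when unit tests run.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tensor_overrides_stay_null_terminated() {
        let params = LlamaModelParams::default()
            .with_cpu_moe_all()
            .with_n_cpu_moe(2);

        // One pattern from `with_cpu_moe_all` plus one per layer from `with_n_cpu_moe(2)`.
        assert_eq!(params.tensor_override_patterns.len(), 3);

        // The array handed to C must end with a NULL-pattern sentinel entry.
        let last = params
            .tensor_overrides
            .last()
            .expect("override array should be populated");
        assert!(last.pattern.is_null());

        // The raw pointer stored in the C params must track the vector's current storage.
        assert_eq!(
            params.params.tensor_buft_overrides,
            params.tensor_overrides.as_ptr()
        );
    }
}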