1use std::borrow::Cow;
2
3use half::f16;
4use itertools::Itertools;
5use regex::Regex;
6use safetensors::{Dtype, SafeTensorError, SafeTensors};
7use thiserror::Error;
8use web_rwkv_derive::{Deref, DerefMut};
9
10use super::model::{ModelCustomInfo, ModelInfo, ModelVersion, Quant};
11use crate::{
12 context::Context,
13 num::Scalar,
14 tensor::{
15 kind::ReadWrite,
16 matrix::Matrix,
17 ops::{Activation, TensorOp},
18 shape::{Shape, TensorDimension},
19 TensorCpu, TensorError, TensorErrorKind, TensorGpu, TensorInit, TensorInto, TensorReshape,
20 TensorShape,
21 },
22};
23
/// Per-dimension pad units for vector tensors; presumably each dimension is
/// rounded up to a multiple of the given value by `TensorCpu::pad` — TODO confirm.
pub const PAD_VEC: [usize; 4] = [8, 1, 1, 1];
/// Per-dimension pad units for matrix tensors (dims 0 and 1 padded to 8).
pub const PAD_MAT: [usize; 4] = [8, 8, 1, 1];
26
/// Errors that can occur while probing or loading a model checkpoint.
#[derive(Debug, Error)]
pub enum LoaderError {
    /// The tensor names present do not match any known model version layout.
    #[error("invalid model version")]
    InvalidVersion,
    #[error("tensor error")]
    TensorError(#[from] TensorError),
    #[error("failed to load safe tensor")]
    SafeTensor(#[from] safetensors::SafeTensorError),
    #[error("failed to parse int")]
    ParseIntError(#[from] std::num::ParseIntError),
    #[error("failed to parse regex")]
    RegexError(#[from] regex::Error),
}
40
/// Raw tensor as produced by a [`Reader`]: `(dtype, shape as stored in the file, raw bytes)`.
pub type ReaderTensor<'a> = (Dtype, Vec<usize>, Cow<'a, [u8]>);
42
/// Read-only access to a tensor container (e.g. a safetensors file).
pub trait Reader {
    /// Names of all tensors in the container.
    fn names(&self) -> Vec<&str>;
    /// Whether a tensor called `name` exists.
    fn contains(&self, name: &str) -> bool;
    /// Shape of tensor `name`, in the order stored by the file.
    fn shape(&self, name: &str) -> Result<Vec<usize>, SafeTensorError>;
    /// Dtype, shape and raw bytes of tensor `name`.
    fn tensor(&self, name: &str) -> Result<ReaderTensor<'_>, SafeTensorError>;
}
50
impl Reader for SafeTensors<'_> {
    #[inline]
    fn names(&self) -> Vec<&str> {
        // `self.names()` resolves to the inherent `SafeTensors::names`
        // (inherent methods take precedence over trait methods)
        self.names().into_iter().map(AsRef::as_ref).collect()
    }

    #[inline]
    fn contains(&self, name: &str) -> bool {
        // linear scan over the name list; fine for one-shot loading
        self.names().contains(&name)
    }

    #[inline]
    fn shape(&self, name: &str) -> Result<Vec<usize>, SafeTensorError> {
        Ok(self.tensor(name)?.shape().to_vec())
    }

    #[inline]
    fn tensor(&self, name: &str) -> Result<ReaderTensor<'_>, SafeTensorError> {
        // fully qualified to disambiguate from this trait method itself
        let tensor = SafeTensors::tensor(self, name)?;
        let shape = tensor.shape().to_vec();
        // `into()` borrows the file's backing buffer (`Cow::Borrowed`); no copy here
        let data = tensor.data().into();
        Ok((tensor.dtype(), shape, data))
    }
}
75
/// Construction of a CPU tensor from a [`Reader`]'s raw `(dtype, shape, bytes)` output.
pub trait TensorFromReader<T: Scalar> {
    /// Build a tensor from reader output, checking that the stored dtype matches `T`.
    fn from_reader(reader: ReaderTensor) -> Result<TensorCpu<T>, TensorError>;
}
80
impl<T: Scalar> TensorFromReader<T> for TensorCpu<T> {
    fn from_reader((dt, shape, data): ReaderTensor) -> Result<Self, TensorError> {
        // reject containers whose on-disk dtype doesn't match `T`
        if T::DATA_TYPE != dt {
            Err(TensorErrorKind::Type)?;
        }
        // `from_slice_rev` converts the file's axis order into the
        // crate-internal (reversed) `Shape` order
        let shape = Shape::from_slice_rev(&shape)?;
        match data {
            // borrowed bytes can be reinterpreted in place without copying
            Cow::Borrowed(data) => Self::from_data(shape, bytemuck::cast_slice(data)),
            // owned bytes must be re-collected into a correctly-typed buffer
            Cow::Owned(data) => {
                let data = bytemuck::cast_slice(&data);
                let data = Cow::Owned(data.to_vec());
                Self::from_data(shape, data)
            }
        }
    }
}
97
/// A LoRA checkpoint plus the blend patterns that decide which model
/// tensors it patches, and how strongly.
#[derive(Clone)]
pub struct Lora<R> {
    /// Reader over the LoRA checkpoint's tensors.
    pub data: R,
    /// Ordered blend patterns matched against model tensor names.
    pub blend: LoraBlend,
}
108
/// An ordered list of [`LoraBlendPattern`]s. When several patterns match the
/// same tensor name, the last one in the list takes precedence (the lookup
/// scans from the back).
#[derive(Debug, Default, Clone, Deref, DerefMut)]
pub struct LoraBlend(pub Vec<LoraBlendPattern>);
112
113impl LoraBlend {
114 #[inline]
116 pub fn full(alpha: f32) -> Self {
117 Self::default().add_nominal(1.0).add_matrices(alpha)
118 }
119
120 #[inline]
122 pub fn add_nominal(mut self, alpha: f32) -> Self {
123 let pattern = LoraBlendPattern::new(r".+", alpha).unwrap();
124 self.push(pattern);
125 self
126 }
127
128 #[inline]
130 pub fn add_matrices(mut self, alpha: f32) -> Self {
131 let pattern = LoraBlendPattern::new(
132 r"blocks\.([0-9]+)\.(att|ffn)\.(key|value|receptance|gate|output)\.weight",
133 alpha,
134 )
135 .unwrap();
136 self.push(pattern);
137 self
138 }
139
140 pub fn add_layer_nominal(mut self, layer: usize, alpha: f32) -> Self {
142 let pattern = format!(r"blocks\.{layer}");
143 let pattern = LoraBlendPattern::new(&pattern, alpha).unwrap();
144 self.push(pattern);
145 self
146 }
147
148 pub fn add_layer_matrices(mut self, layer: usize, alpha: f32) -> Self {
150 let pattern =
151 format!(r"blocks\.{layer}\.(att|ffn)\.(key|value|receptance|gate|output)\.weight");
152 let pattern = LoraBlendPattern::new(&pattern, alpha).unwrap();
153 self.push(pattern);
154 self
155 }
156}
157
/// A regex over tensor names paired with a blend strength.
#[derive(Debug, Clone)]
pub struct LoraBlendPattern {
    // regex matched (as a substring search) against full tensor names
    pattern: Regex,
    // blend strength applied to tensors whose name matches
    alpha: f32,
}
166
167impl LoraBlendPattern {
168 #[inline]
169 pub fn new(pattern: &str, alpha: f32) -> Result<Self, LoaderError> {
170 Ok(Self {
171 pattern: Regex::new(pattern)?,
172 alpha,
173 })
174 }
175
176 #[inline]
177 pub fn alpha(&self) -> f32 {
178 self.alpha
179 }
180}
181
/// A full-tensor LoRA patch uploaded to the GPU, with its blend strength.
struct LoraVector {
    tensor: TensorGpu<f16, ReadWrite>,
    alpha: f32,
}
186
/// A low-rank LoRA matrix patch (`x`/`y` factor pair) uploaded to the GPU.
struct LoraMatrix {
    x: TensorGpu<f16, ReadWrite>,
    y: TensorGpu<f16, ReadWrite>,
    // LoRA rank, taken from dim 1 of the `x` factor
    rank: usize,
    alpha: f32,
}
193
/// Loads model tensors from a [`Reader`] onto the GPU, optionally patching
/// them with any number of LoRA checkpoints on the way.
#[derive(Clone)]
pub struct Loader<R> {
    pub context: Context,
    /// The base model checkpoint.
    pub model: R,
    /// LoRA checkpoints applied on top, in order.
    pub lora: Vec<Lora<R>>,
}
200
201impl<R: Reader> Loader<R> {
    /// Inspect the model's tensors and derive its [`ModelInfo`]: detected
    /// version, layer count, embedding/hidden/vocab sizes, head count and
    /// version-specific custom dimensions.
    pub fn info(model: &R) -> Result<ModelInfo, LoaderError> {
        // layer count = highest index seen in `blocks.<index>.…` names, plus one
        let num_layer = {
            let mut r: usize = 0;
            for i in model.names() {
                const PREFIX: &str = "blocks.";
                if let Some(i) = i.strip_prefix(PREFIX) {
                    // NOTE(review): if no `.` follows the index, this slice is
                    // empty and the parse below errors out — presumably every
                    // `blocks.` tensor name contains a further dot; confirm
                    let i = &i[..i.find('.').unwrap_or(0)];
                    r = r.max(i.parse::<usize>()?)
                }
            }
            r + 1
        };

        let embed = model.shape("emb.weight")?;
        let ffn = model.shape("blocks.0.ffn.key.weight")?;

        // Version detection: a version is recognized when all of its
        // signature tensors are present in block 0.
        let v4 = [
            "blocks.0.att.time_decay",
            "blocks.0.att.time_first",
            "blocks.0.att.time_mix_k",
            "blocks.0.att.time_mix_v",
            "blocks.0.att.time_mix_r",
        ]
        .into_iter()
        .all(|name| model.contains(name));
        let v5 = [
            "blocks.0.att.gate.weight",
            "blocks.0.att.ln_x.weight",
            "blocks.0.att.ln_x.bias",
        ]
        .into_iter()
        .all(|name| model.contains(name));
        let v6 = [
            "blocks.0.att.time_mix_x",
            "blocks.0.att.time_mix_w",
            "blocks.0.att.time_mix_k",
            "blocks.0.att.time_mix_v",
            "blocks.0.att.time_mix_r",
            "blocks.0.att.time_mix_g",
            "blocks.0.att.time_mix_w1",
            "blocks.0.att.time_mix_w2",
            "blocks.0.att.time_decay_w1",
            "blocks.0.att.time_decay_w2",
            "blocks.0.ffn.time_mix_k",
            "blocks.0.ffn.time_mix_r",
        ]
        .into_iter()
        .all(|name| model.contains(name));
        let v7 = [
            "blocks.0.att.x_r",
            "blocks.0.att.x_w",
            "blocks.0.att.x_k",
            "blocks.0.att.x_v",
            "blocks.0.att.x_a",
            "blocks.0.att.x_g",
            "blocks.0.att.w0",
            "blocks.0.att.w1",
            "blocks.0.att.w2",
            "blocks.0.att.a0",
            "blocks.0.att.a1",
            "blocks.0.att.a2",
            "blocks.0.att.g1",
            "blocks.0.att.g2",
            "blocks.0.att.r_k",
            "blocks.0.att.k_k",
            "blocks.0.att.k_a",
        ]
        .into_iter()
        .all(|name| model.contains(name));

        // Newer checkpoints may also carry older signature tensors, so the
        // wildcard arms let a later version win over an earlier one.
        let version = match (v4, v5, v6, v7) {
            (true, false, false, false) => ModelVersion::V4,
            (_, true, false, false) => ModelVersion::V5,
            (_, _, true, false) => ModelVersion::V6,
            (_, _, _, true) => ModelVersion::V7,
            _ => return Err(LoaderError::InvalidVersion),
        };

        // shapes here are in the reader's (file) axis order
        let num_emb = embed[1];
        let num_hidden = ffn[0];
        let num_vocab = embed[0];

        let num_head = match version {
            ModelVersion::V4 => 1,
            ModelVersion::V5 | ModelVersion::V6 => model.shape("blocks.0.att.time_first")?[0],
            ModelVersion::V7 => model.shape("blocks.0.att.r_k")?[0],
        };

        let custom = match version {
            ModelVersion::V6 => {
                // presumably `time_mix_w1` stacks 5 mix projections along
                // dim 0, hence the division — TODO confirm against v6 layout
                let time_mix = model.shape("blocks.0.att.time_mix_w1")?[0] / 5;
                let time_decay = model.shape("blocks.0.att.time_decay_w1")?[0];
                ModelCustomInfo::V6(super::v6::CustomInfo {
                    time_mix,
                    time_decay,
                })
            }
            ModelVersion::V7 => {
                let w = model.shape("blocks.0.att.w1")?[0];
                let a = model.shape("blocks.0.att.a1")?[0];
                let g = model.shape("blocks.0.att.g1")?[0];
                // NOTE(review): `v1` is read from block 1, not block 0 —
                // presumably block 0 lacks the `v` tensors; a single-layer
                // v7 model would fail here. Confirm.
                let v = model.shape("blocks.1.att.v1")?[0];
                ModelCustomInfo::V7(super::v7::CustomInfo { w, a, g, v })
            }
            _ => ModelCustomInfo::None,
        };

        Ok(ModelInfo {
            version,
            num_layer,
            num_emb,
            num_hidden,
            num_vocab,
            num_head,
            custom,
        })
    }
319
320 fn lora_vectors(&self, name: impl AsRef<str>) -> Result<Vec<LoraVector>, LoaderError> {
323 let context = &self.context;
324 let name = name.as_ref();
325
326 let mut vectors = vec![];
327 for lora in self.lora.iter() {
328 let Some(blend) = lora
329 .blend
330 .iter()
331 .filter(|blend| blend.pattern.is_match(name))
332 .next_back()
333 else {
334 continue;
335 };
336
337 let Ok(tensor) = lora.data.tensor(name) else {
338 continue;
339 };
340 let tensor = TensorCpu::from_reader(tensor)?.to(context);
341 let alpha = blend.alpha;
342 vectors.push(LoraVector { tensor, alpha });
343
344 log::info!("vector (LoRA) {name}, alpha: {alpha}");
345 }
346 Ok(vectors)
347 }
348
349 fn lora_matrices(&self, name: impl AsRef<str>) -> Result<Vec<LoraMatrix>, LoaderError> {
352 let context = &self.context;
353 let name = name.as_ref();
354
355 let mut matrices = vec![];
356 for lora in self.lora.iter() {
357 let Some(blend) = lora
358 .blend
359 .iter()
360 .filter(|blend| blend.pattern.is_match(name))
361 .next_back()
362 else {
363 continue;
364 };
365
366 let name = name.split('.').filter(|x| !x.contains("weight")).join(".");
367 let Ok(x) = lora.data.tensor(&format!("{name}.lora.0")) else {
368 continue;
369 };
370 let Ok(y) = lora.data.tensor(&format!("{name}.lora.1")) else {
371 continue;
372 };
373
374 let rank = x.1[1];
375 let alpha = blend.alpha;
376 let x = TensorCpu::from_reader(x)?.to(context);
377 let y = TensorCpu::from_reader(y)?.to(context);
378 matrices.push(LoraMatrix { x, y, rank, alpha });
379
380 log::info!("matrix (LoRA) {name}, alpha: {alpha}, rank: {rank}");
381 }
382 Ok(matrices)
383 }
384
385 pub fn tensor_shape(&self, name: impl AsRef<str>) -> Result<Shape, LoaderError> {
386 let shape = self.model.shape(name.as_ref())?;
387 Ok(Shape::from_slice_rev(&shape)?)
388 }
389
    /// Load a vector tensor to the GPU as `f32`, flattened to `[n, 1, 1, 1]`,
    /// blending in any matching LoRA full-tensor patches in one submission.
    pub fn load_vector_f32(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f32, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        // stored as f16; widen to f32 on the CPU, flatten, then upload
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| x.to_f32())
            .reshape(
                TensorDimension::Auto,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_vectors(name)? {
            // blend factor [alpha, 1 - alpha, 0, 0]; exact semantics are
            // defined by the `TensorOp::blend` shader
            let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

            // view the flat target with the LoRA tensor's shape so the op
            // sees matching dimensions
            let shape = lora.tensor.shape();
            let tensor = tensor.reshape(
                TensorDimension::Size(shape[0]),
                TensorDimension::Size(shape[1]),
                TensorDimension::Size(shape[2]),
                TensorDimension::Size(shape[3]),
            )?;

            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        // a single submission applies all queued LoRA blends
        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }
426
    /// Same as [`Self::load_vector_f32`], but appends an
    /// `Activation::OppositeExp` pass after the LoRA blends.
    pub fn load_vector_exp_f32(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f32, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        // stored as f16; widen to f32 on the CPU, flatten, then upload
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| x.to_f32())
            .reshape(
                TensorDimension::Auto,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_vectors(name)? {
            // blend factor [alpha, 1 - alpha, 0, 0]
            let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

            let shape = lora.tensor.shape();
            let tensor = tensor.reshape(
                TensorDimension::Size(shape[0]),
                TensorDimension::Size(shape[1]),
                TensorDimension::Size(shape[2]),
                TensorDimension::Size(shape[3]),
            )?;

            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        // applied AFTER the blends, so LoRA patches act on the raw values
        let op = TensorOp::activate(&tensor, Activation::OppositeExp)?;
        ops.push(op);

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }
467
    /// Same as [`Self::load_vector_f32`], but appends an
    /// `Activation::StableExp` pass after the LoRA blends.
    ///
    /// NOTE(review): despite the `exp_exp` name, only the single `StableExp`
    /// op is queued — presumably the shader computes the doubly-exponentiated
    /// form itself; confirm against the shader source.
    pub fn load_vector_exp_exp_f32(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f32, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        // stored as f16; widen to f32 on the CPU, flatten, then upload
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| x.to_f32())
            .reshape(
                TensorDimension::Auto,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_vectors(name)? {
            // blend factor [alpha, 1 - alpha, 0, 0]
            let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

            let shape = lora.tensor.shape();
            let tensor = tensor.reshape(
                TensorDimension::Size(shape[0]),
                TensorDimension::Size(shape[1]),
                TensorDimension::Size(shape[2]),
                TensorDimension::Size(shape[3]),
            )?;

            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        // applied AFTER the blends, so LoRA patches act on the raw values
        let op = TensorOp::activate(&tensor, Activation::StableExp)?;
        ops.push(op);

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }
509
    /// Load a vector tensor to the GPU as `f16`, flattened to `[n, 1, 1, 1]`.
    ///
    /// Fast path: with no matching LoRA patches, the f16 data is uploaded
    /// directly. Otherwise the blends are computed in f32 and the result is
    /// blitted back down to a fresh f16 buffer.
    pub fn load_vector_f16(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let lora = self.lora_vectors(name.as_ref())?;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor = if lora.is_empty() {
            // no patches: upload the stored f16 data as-is
            TensorCpu::from_reader(tensor)?
                .reshape(
                    TensorDimension::Auto,
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                )?
                .to(context)
        } else {
            // blend in f32 for precision, then convert back to f16
            let tensor_f32: TensorGpu<f32, _> = TensorCpu::<f16>::from_reader(tensor)?
                .map(|x| x.to_f32())
                .reshape(
                    TensorDimension::Auto,
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                )?
                .to(context);
            let tensor_f16: TensorGpu<f16, _> = context.tensor_init(tensor_f32.shape());

            let mut ops = vec![];
            for lora in lora {
                // blend factor [alpha, 1 - alpha, 0, 0]
                let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
                let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

                // view the flat target with the LoRA tensor's shape
                let shape = lora.tensor.shape();
                let tensor = tensor_f32.reshape(
                    TensorDimension::Size(shape[0]),
                    TensorDimension::Size(shape[1]),
                    TensorDimension::Size(shape[2]),
                    TensorDimension::Size(shape[3]),
                )?;

                let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
                ops.push(op);
            }

            // copy the blended f32 values into the f16 output buffer
            let op = TensorOp::blit(&tensor_f32, &tensor_f16)?;
            ops.push(op);

            context.queue.submit(context.encode(&TensorOp::List(ops)));
            tensor_f16
        };
        Ok(tensor)
    }
563
    /// Load a matrix tensor to the GPU as `f16`, applying all matching LoRA
    /// patches (low-rank pairs first, then full-tensor deltas) in one submit.
    pub fn load_matrix_f16(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor: TensorGpu<_, _> = TensorCpu::from_reader(tensor)?.to(context);

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            // low-rank update, scaled by alpha / rank
            let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, &tensor)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            // factor [alpha, 1, 0, 0] — note this differs from the
            // [alpha, 1 - alpha] factor used by the vector loaders
            let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }
589
    /// Like [`Self::load_matrix_f16`], but every value (base weights and
    /// LoRA contributions alike) is scaled by `discount`.
    pub fn load_matrix_f16_discount(
        &self,
        name: impl AsRef<str>,
        discount: f32,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        // apply the discount on the CPU before uploading
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| f16::from_f32(discount * x.to_f32()))
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            // low-rank update, scaled by discount * alpha / rank
            let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, &tensor)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            // full-tensor patch, also discounted
            let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }
618
    /// Load a matrix tensor into a caller-provided GPU buffer `matrix`,
    /// then apply matching LoRA patches; same patch logic as
    /// [`Self::load_matrix_f16`].
    pub fn load_in_place_matrix_f16(
        &self,
        matrix: &TensorGpu<f16, ReadWrite>,
        name: impl AsRef<str>,
    ) -> Result<(), LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor = TensorCpu::from_reader(tensor)?;
        // upload into the existing buffer instead of allocating a new one
        matrix.load(&tensor)?;

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            // low-rank update, scaled by alpha / rank
            let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, matrix)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            // full-tensor patch with factor [alpha, 1, 0, 0]
            let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, matrix)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(())
    }
646
    /// Like [`Self::load_in_place_matrix_f16`], but every value is scaled by
    /// `discount`, and the CPU tensor is flattened to two dimensions before
    /// being uploaded into `matrix`.
    pub fn load_in_place_matrix_f16_discount(
        &self,
        matrix: &TensorGpu<f16, ReadWrite>,
        name: impl AsRef<str>,
        discount: f32,
    ) -> Result<(), LoaderError> {
        let context = &self.context;

        let tensor = self.model.tensor(name.as_ref())?;
        // discount applied on the CPU; trailing dims collapsed to 1
        let tensor = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| f16::from_f32(discount * x.to_f32()))
            .reshape(
                TensorDimension::Full,
                TensorDimension::Full,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?;
        matrix.load(&tensor)?;

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            // low-rank update, scaled by discount * alpha / rank
            let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, matrix)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            // full-tensor patch, also discounted
            let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, matrix)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(())
    }
683
684 pub fn load_matrix_f16_padded_cpu(
685 &self,
686 name: impl AsRef<str>,
687 ) -> Result<TensorCpu<f16>, LoaderError> {
688 let (dt, shape, tensor) = self.model.tensor(name.as_ref())?;
689 let tensor = TensorCpu::from_reader((dt, shape, tensor))?.pad(PAD_MAT);
690 Ok(tensor)
691 }
692
693 pub fn load_matrix_f16_padded(
694 &self,
695 name: impl AsRef<str>,
696 ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
697 let context = &self.context;
698 let (dt, shape, tensor) = self.model.tensor(name.as_ref())?;
699 let tensor = TensorCpu::from_reader((dt, shape, tensor))?
700 .pad(PAD_MAT)
701 .to(context);
702 Ok(tensor)
703 }
704
705 pub fn load_matrix(&self, name: String, quant: Quant) -> Result<Matrix, LoaderError> {
706 let context = &self.context;
707 match quant {
708 Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16(name)?)),
709 Quant::Int8 => {
710 let shape = self.tensor_shape(&name)?;
711 let buffer = context.tensor_init(shape);
712 self.load_in_place_matrix_f16(&buffer, &name)?;
713 Ok(Matrix::quant_u8(&buffer)?)
714 }
715 Quant::NF4 => {
716 let shape = self.tensor_shape(&name)?;
717 let buffer = context.tensor_init(shape);
718 self.load_in_place_matrix_f16(&buffer, &name)?;
719 Ok(Matrix::quant_nf4(&buffer)?)
720 }
721 Quant::SF4 => {
722 let shape = self.tensor_shape(&name)?;
723 let buffer = context.tensor_init(shape);
724 self.load_in_place_matrix_f16(&buffer, &name)?;
725 Ok(Matrix::quant_sf4(&buffer, 5.0)?)
726 }
727 }
728 }
729
730 pub fn load_matrix_discount(
731 &self,
732 name: String,
733 quant: Quant,
734 discount: f32,
735 ) -> Result<Matrix, LoaderError> {
736 let context = &self.context;
737 match quant {
738 Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16_discount(name, discount)?)),
739 Quant::Int8 => {
740 let shape = self.tensor_shape(&name)?;
741 let buffer = context.tensor_init(shape);
742 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)?;
743 Ok(Matrix::quant_u8(&buffer)?)
744 }
745 Quant::NF4 => {
746 let shape = self.tensor_shape(&name)?;
747 let buffer = context.tensor_init(shape);
748 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)?;
749 Ok(Matrix::quant_nf4(&buffer)?)
750 }
751 Quant::SF4 => {
752 let shape = self.tensor_shape(&name)?;
753 let buffer = context.tensor_init(shape);
754 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)?;
755 Ok(Matrix::quant_sf4(&buffer, 5.0)?)
756 }
757 }
758 }
759}