web_rwkv/runtime/loader.rs

1use std::borrow::Cow;
2
3use half::f16;
4use itertools::Itertools;
5use regex::Regex;
6use safetensors::{Dtype, SafeTensorError, SafeTensors};
7use thiserror::Error;
8use web_rwkv_derive::{Deref, DerefMut};
9
10use super::model::{ModelCustomInfo, ModelInfo, ModelVersion, Quant};
11use crate::{
12    context::Context,
13    num::Scalar,
14    tensor::{
15        kind::ReadWrite,
16        matrix::Matrix,
17        ops::{Activation, TensorOp},
18        shape::{Shape, TensorDimension},
19        TensorCpu, TensorError, TensorErrorKind, TensorGpu, TensorInit, TensorInto, TensorReshape,
20        TensorShape,
21    },
22};
23
/// Per-dimension padding multiples for vectors (innermost dimension first).
// NOTE(review): not referenced in this file's visible code — presumably used by
// callers alongside `TensorCpu::pad`, like `PAD_MAT` below; confirm at call sites.
pub const PAD_VEC: [usize; 4] = [8, 1, 1, 1];
/// Per-dimension padding multiples for matrices; used by the
/// `load_matrix_f16_padded*` loaders to round both matrix dims up to 8.
pub const PAD_MAT: [usize; 4] = [8, 8, 1, 1];
26
/// Errors that can occur while loading a model from a safetensors blob.
#[derive(Debug, Error)]
pub enum LoaderError {
    /// None of the known RWKV version signatures matched the model's tensor names.
    #[error("invalid model version")]
    InvalidVersion,
    /// A tensor-level failure (shape mismatch, type mismatch, ...).
    #[error("tensor error")]
    TensorError(#[from] TensorError),
    /// The underlying safetensors blob could not be read.
    #[error("failed to load safe tensor")]
    SafeTensor(#[from] safetensors::SafeTensorError),
    /// A layer index embedded in a tensor name was not a valid integer.
    #[error("failed to parse int")]
    ParseIntError(#[from] std::num::ParseIntError),
    /// A LoRA blend pattern was not a valid regular expression.
    #[error("failed to parse regex")]
    RegexError(#[from] regex::Error),
}
40
41pub type ReaderTensor<'a> = (Dtype, Vec<usize>, Cow<'a, [u8]>);
42
/// Interface accessing a safetensors data blob.
pub trait Reader {
    /// Names of all tensors contained in the blob.
    fn names(&self) -> Vec<&str>;
    /// Whether a tensor with the given name exists.
    fn contains(&self, name: &str) -> bool;
    /// Shape of the named tensor, in the order stored in the file.
    fn shape(&self, name: &str) -> Result<Vec<usize>, SafeTensorError>;
    /// Data type, shape and raw bytes of the named tensor.
    fn tensor(&self, name: &str) -> Result<ReaderTensor<'_>, SafeTensorError>;
}
50
impl Reader for SafeTensors<'_> {
    // NOTE(review): `self.names()` below resolves to the *inherent*
    // `SafeTensors::names` (inherent methods take precedence over trait
    // methods), so this is not a recursive call.
    #[inline]
    fn names(&self) -> Vec<&str> {
        self.names().into_iter().map(AsRef::as_ref).collect()
    }

    // Linear scan over all names; fine for the load path, which runs once.
    #[inline]
    fn contains(&self, name: &str) -> bool {
        self.names().contains(&name)
    }

    #[inline]
    fn shape(&self, name: &str) -> Result<Vec<usize>, SafeTensorError> {
        Ok(self.tensor(name)?.shape().to_vec())
    }

    // Fully-qualified call distinguishes the inherent accessor from this
    // trait method of the same name.
    #[inline]
    fn tensor(&self, name: &str) -> Result<ReaderTensor<'_>, SafeTensorError> {
        let tensor = SafeTensors::tensor(self, name)?;
        let shape = tensor.shape().to_vec();
        let data = tensor.data().into();
        Ok((tensor.dtype(), shape, data))
    }
}
75
/// Conversion from a raw [`ReaderTensor`] into a typed CPU tensor.
pub trait TensorFromReader<T: Scalar> {
    /// Create a tensor from safetensors reader.
    ///
    /// Fails when the stored dtype does not match `T` or the shape is invalid.
    fn from_reader(reader: ReaderTensor) -> Result<TensorCpu<T>, TensorError>;
}
80
81impl<T: Scalar> TensorFromReader<T> for TensorCpu<T> {
82    fn from_reader((dt, shape, data): ReaderTensor) -> Result<Self, TensorError> {
83        if T::DATA_TYPE != dt {
84            Err(TensorErrorKind::Type)?;
85        }
86        let shape = Shape::from_slice_rev(&shape)?;
87        match data {
88            Cow::Borrowed(data) => Self::from_data(shape, bytemuck::cast_slice(data)),
89            Cow::Owned(data) => {
90                let data = bytemuck::cast_slice(&data);
91                let data = Cow::Owned(data.to_vec());
92                Self::from_data(shape, data)
93            }
94        }
95    }
96}
97
/// A LoRA that adds to the model when loading.
#[derive(Clone)]
pub struct Lora<R> {
    /// Binary safetensors LoRA content.
    pub data: R,
    /// A list of LoRA blend patterns.
    /// A blend pattern is a regex that matches the name of multiple tensors, and a blend factor.
    /// When applying the patterns, they are applied in order; for a given
    /// tensor, the last matching pattern wins.
    pub blend: LoraBlend,
}
108
/// A list of LoRA blend patterns, applied in order (last match wins).
#[derive(Debug, Default, Clone, Deref, DerefMut)]
pub struct LoraBlend(pub Vec<LoraBlendPattern>);
112
113impl LoraBlend {
114    /// Build a blend pattern that replaces all vectors, and adds to all matrices with `alpha`.
115    #[inline]
116    pub fn full(alpha: f32) -> Self {
117        Self::default().add_nominal(1.0).add_matrices(alpha)
118    }
119
120    /// Add a blend pattern that interpolates tensors with factor `alpha` from 0 to 1.
121    #[inline]
122    pub fn add_nominal(mut self, alpha: f32) -> Self {
123        let pattern = LoraBlendPattern::new(r".+", alpha).unwrap();
124        self.push(pattern);
125        self
126    }
127
128    /// Add a blend pattern that adds to all matrices with `alpha`.
129    #[inline]
130    pub fn add_matrices(mut self, alpha: f32) -> Self {
131        let pattern = LoraBlendPattern::new(
132            r"blocks\.([0-9]+)\.(att|ffn)\.(key|value|receptance|gate|output)\.weight",
133            alpha,
134        )
135        .unwrap();
136        self.push(pattern);
137        self
138    }
139
140    /// Add a blend pattern that interpolates tensors in a layer with factor `alpha` from 0 to 1.
141    pub fn add_layer_nominal(mut self, layer: usize, alpha: f32) -> Self {
142        let pattern = format!(r"blocks\.{layer}");
143        let pattern = LoraBlendPattern::new(&pattern, alpha).unwrap();
144        self.push(pattern);
145        self
146    }
147
148    /// Add a blend pattern that adds to all matrices in a layer with `alpha`.
149    pub fn add_layer_matrices(mut self, layer: usize, alpha: f32) -> Self {
150        let pattern =
151            format!(r"blocks\.{layer}\.(att|ffn)\.(key|value|receptance|gate|output)\.weight");
152        let pattern = LoraBlendPattern::new(&pattern, alpha).unwrap();
153        self.push(pattern);
154        self
155    }
156}
157
/// A blend pattern is a regex that matches the name of multiple tensors, and a blend factor.
#[derive(Debug, Clone)]
pub struct LoraBlendPattern {
    /// A regex pattern that matches tensors in the model.
    pattern: Regex,
    /// The blend factor; interpretation (interpolation weight vs. additive
    /// scale) depends on how the loader applies it.
    alpha: f32,
}
166
167impl LoraBlendPattern {
168    #[inline]
169    pub fn new(pattern: &str, alpha: f32) -> Result<Self, LoaderError> {
170        Ok(Self {
171            pattern: Regex::new(pattern)?,
172            alpha,
173        })
174    }
175
176    #[inline]
177    pub fn alpha(&self) -> f32 {
178        self.alpha
179    }
180}
181
/// A LoRA delta for a vector tensor, uploaded to the GPU, plus its blend factor.
struct LoraVector {
    // The full replacement/delta vector from the LoRA file.
    tensor: TensorGpu<f16, ReadWrite>,
    // Blend factor from the matching pattern.
    alpha: f32,
}
186
/// A low-rank LoRA update for a matrix tensor, uploaded to the GPU.
struct LoraMatrix {
    // First low-rank factor ("<name>.lora.0" in the LoRA file).
    x: TensorGpu<f16, ReadWrite>,
    // Second low-rank factor ("<name>.lora.1" in the LoRA file).
    y: TensorGpu<f16, ReadWrite>,
    // LoRA rank, taken from the factors' shape; the applied scale is alpha / rank.
    rank: usize,
    // Blend factor from the matching pattern.
    alpha: f32,
}
193
/// Loads model weights onto the GPU, optionally patched by one or more LoRAs.
#[derive(Clone)]
pub struct Loader<R> {
    /// GPU context that tensors are uploaded to and ops are encoded on.
    pub context: Context,
    /// Reader over the base model's safetensors data.
    pub model: R,
    /// LoRA patches applied on top of the base model, in order.
    pub lora: Vec<Lora<R>>,
}
200
impl<R: Reader> Loader<R> {
    /// Probe the model blob and derive its [`ModelInfo`] (version, layer count
    /// and dimensions) without uploading any weight data.
    pub fn info(model: &R) -> Result<ModelInfo, LoaderError> {
        // Layer count is 1 + the largest index appearing in "blocks.<n>.…" names.
        let num_layer = {
            let mut r: usize = 0;
            for i in model.names() {
                const PREFIX: &str = "blocks.";
                if let Some(i) = i.strip_prefix(PREFIX) {
                    // NOTE(review): assumes a '.' follows the layer index; if a
                    // name were exactly "blocks.<n>" this slices an empty
                    // string and the parse below fails.
                    let i = &i[..i.find('.').unwrap_or(0)];
                    r = r.max(i.parse::<usize>()?)
                }
            }
            r + 1
        };

        let embed = model.shape("emb.weight")?;
        let ffn = model.shape("blocks.0.ffn.key.weight")?;

        // Version detection: each RWKV generation is recognized by tensor
        // names that exist in that generation's first block.
        let v4 = [
            "blocks.0.att.time_decay",
            "blocks.0.att.time_first",
            "blocks.0.att.time_mix_k",
            "blocks.0.att.time_mix_v",
            "blocks.0.att.time_mix_r",
        ]
        .into_iter()
        .all(|name| model.contains(name));
        let v5 = [
            "blocks.0.att.gate.weight",
            "blocks.0.att.ln_x.weight",
            "blocks.0.att.ln_x.bias",
        ]
        .into_iter()
        .all(|name| model.contains(name));
        let v6 = [
            "blocks.0.att.time_mix_x",
            "blocks.0.att.time_mix_w",
            "blocks.0.att.time_mix_k",
            "blocks.0.att.time_mix_v",
            "blocks.0.att.time_mix_r",
            "blocks.0.att.time_mix_g",
            "blocks.0.att.time_mix_w1",
            "blocks.0.att.time_mix_w2",
            "blocks.0.att.time_decay_w1",
            "blocks.0.att.time_decay_w2",
            "blocks.0.ffn.time_mix_k",
            "blocks.0.ffn.time_mix_r",
        ]
        .into_iter()
        .all(|name| model.contains(name));
        let v7 = [
            "blocks.0.att.x_r",
            "blocks.0.att.x_w",
            "blocks.0.att.x_k",
            "blocks.0.att.x_v",
            "blocks.0.att.x_a",
            "blocks.0.att.x_g",
            "blocks.0.att.w0",
            "blocks.0.att.w1",
            "blocks.0.att.w2",
            "blocks.0.att.a0",
            "blocks.0.att.a1",
            "blocks.0.att.a2",
            "blocks.0.att.g1",
            "blocks.0.att.g2",
            "blocks.0.att.r_k",
            "blocks.0.att.k_k",
            "blocks.0.att.k_a",
        ]
        .into_iter()
        .all(|name| model.contains(name));

        // The newest matching signature wins; the wildcards in earlier
        // positions allow newer models to also carry older-generation names.
        let version = match (v4, v5, v6, v7) {
            (true, false, false, false) => ModelVersion::V4,
            (_, true, false, false) => ModelVersion::V5,
            (_, _, true, false) => ModelVersion::V6,
            (_, _, _, true) => ModelVersion::V7,
            _ => return Err(LoaderError::InvalidVersion),
        };

        let num_emb = embed[1];
        let num_hidden = ffn[0];
        let num_vocab = embed[0];

        // v4 is treated as single-head; later versions store the head count in
        // the leading dimension of a per-head tensor.
        let num_head = match version {
            ModelVersion::V4 => 1,
            ModelVersion::V5 | ModelVersion::V6 => model.shape("blocks.0.att.time_first")?[0],
            ModelVersion::V7 => model.shape("blocks.0.att.r_k")?[0],
        };

        let custom = match version {
            ModelVersion::V6 => {
                // `/ 5`: presumably `time_mix_w1` stacks the five time-mix
                // projections along its first dimension — confirm against the
                // v6 checkpoint layout.
                let time_mix = model.shape("blocks.0.att.time_mix_w1")?[0] / 5;
                let time_decay = model.shape("blocks.0.att.time_decay_w1")?[0];
                ModelCustomInfo::V6(super::v6::CustomInfo {
                    time_mix,
                    time_decay,
                })
            }
            ModelVersion::V7 => {
                let w = model.shape("blocks.0.att.w1")?[0];
                let a = model.shape("blocks.0.att.a1")?[0];
                let g = model.shape("blocks.0.att.g1")?[0];
                // NOTE(review): `v1` is read from block *1*, presumably
                // because block 0 lacks the value-residual tensors — confirm.
                let v = model.shape("blocks.1.att.v1")?[0];
                ModelCustomInfo::V7(super::v7::CustomInfo { w, a, g, v })
            }
            _ => ModelCustomInfo::None,
        };

        Ok(ModelInfo {
            version,
            num_layer,
            num_emb,
            num_hidden,
            num_vocab,
            num_head,
            custom,
        })
    }

    /// Load all lora and blend factors about the vector with a given name.
    /// In each LoRA, only the last matched pattern is loaded.
    fn lora_vectors(&self, name: impl AsRef<str>) -> Result<Vec<LoraVector>, LoaderError> {
        let context = &self.context;
        let name = name.as_ref();

        let mut vectors = vec![];
        for lora in self.lora.iter() {
            // `next_back` picks the *last* matching pattern, so later patterns
            // override earlier ones within a single LoRA.
            let Some(blend) = lora
                .blend
                .iter()
                .filter(|blend| blend.pattern.is_match(name))
                .next_back()
            else {
                continue;
            };

            // A LoRA file is free to omit this tensor; that is not an error.
            let Ok(tensor) = lora.data.tensor(name) else {
                continue;
            };
            let tensor = TensorCpu::from_reader(tensor)?.to(context);
            let alpha = blend.alpha;
            vectors.push(LoraVector { tensor, alpha });

            log::info!("vector (LoRA) {name}, alpha: {alpha}");
        }
        Ok(vectors)
    }

    /// Load all lora and blend factors about the matrix with a given name.
    /// In each LoRA, only the last matched pattern is loaded.
    fn lora_matrices(&self, name: impl AsRef<str>) -> Result<Vec<LoraMatrix>, LoaderError> {
        let context = &self.context;
        let name = name.as_ref();

        let mut matrices = vec![];
        for lora in self.lora.iter() {
            // Last matching pattern wins (see `lora_vectors`).
            let Some(blend) = lora
                .blend
                .iter()
                .filter(|blend| blend.pattern.is_match(name))
                .next_back()
            else {
                continue;
            };

            // Drop the trailing "weight" segment: the low-rank factors are
            // stored under "<name>.lora.0" and "<name>.lora.1".
            let name = name.split('.').filter(|x| !x.contains("weight")).join(".");
            let Ok(x) = lora.data.tensor(&format!("{name}.lora.0")) else {
                continue;
            };
            let Ok(y) = lora.data.tensor(&format!("{name}.lora.1")) else {
                continue;
            };

            // Rank is the second dimension of the first factor's stored shape.
            let rank = x.1[1];
            let alpha = blend.alpha;
            let x = TensorCpu::from_reader(x)?.to(context);
            let y = TensorCpu::from_reader(y)?.to(context);
            matrices.push(LoraMatrix { x, y, rank, alpha });

            log::info!("matrix (LoRA) {name}, alpha: {alpha}, rank: {rank}");
        }
        Ok(matrices)
    }

    /// Shape of a tensor in the base model, converted into this crate's
    /// (reversed) dimension order.
    pub fn tensor_shape(&self, name: impl AsRef<str>) -> Result<Shape, LoaderError> {
        let shape = self.model.shape(name.as_ref())?;
        Ok(Shape::from_slice_rev(&shape)?)
    }

    /// Load an f16-stored vector as f32 onto the GPU, applying LoRA blends.
    pub fn load_vector_f32(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f32, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        // Widen f16 -> f32 on the CPU, flatten to a single column, upload.
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| x.to_f32())
            .reshape(
                TensorDimension::Auto,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_vectors(name)? {
            // Vector LoRAs interpolate: alpha * lora + (1 - alpha) * base.
            let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

            // View the flat vector with the LoRA tensor's shape so the blend
            // op sees matching layouts.
            let shape = lora.tensor.shape();
            let tensor = tensor.reshape(
                TensorDimension::Size(shape[0]),
                TensorDimension::Size(shape[1]),
                TensorDimension::Size(shape[2]),
                TensorDimension::Size(shape[3]),
            )?;

            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }

    /// Like [`Self::load_vector_f32`], but additionally applies an
    /// `OppositeExp` activation on the GPU after all LoRA blends.
    pub fn load_vector_exp_f32(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f32, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            // Formerly done on the CPU; now applied on the GPU *after* the
            // LoRA blends (see the `Activation::OppositeExp` op below).
            // .map(|x| -x.to_f32().exp())
            .map(|x| x.to_f32())
            .reshape(
                TensorDimension::Auto,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_vectors(name)? {
            // Interpolating blend, applied in the pre-activation domain.
            let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

            let shape = lora.tensor.shape();
            let tensor = tensor.reshape(
                TensorDimension::Size(shape[0]),
                TensorDimension::Size(shape[1]),
                TensorDimension::Size(shape[2]),
                TensorDimension::Size(shape[3]),
            )?;

            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        // Activation runs last so LoRA blends operate on the raw parameters.
        let op = TensorOp::activate(&tensor, Activation::OppositeExp)?;
        ops.push(op);

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }

    /// Like [`Self::load_vector_f32`], but additionally applies a `StableExp`
    /// activation on the GPU after all LoRA blends.
    pub fn load_vector_exp_exp_f32(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f32, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            // Formerly a double exponential on the CPU; now applied on the GPU
            // *after* the LoRA blends (see `Activation::StableExp` below).
            // .map(|x| -x.to_f32().exp())
            // .map(|x| x.exp())
            .map(|x| x.to_f32())
            .reshape(
                TensorDimension::Auto,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_vectors(name)? {
            // Interpolating blend in the pre-activation domain.
            let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

            let shape = lora.tensor.shape();
            let tensor = tensor.reshape(
                TensorDimension::Size(shape[0]),
                TensorDimension::Size(shape[1]),
                TensorDimension::Size(shape[2]),
                TensorDimension::Size(shape[3]),
            )?;

            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        let op = TensorOp::activate(&tensor, Activation::StableExp)?;
        ops.push(op);

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }

    /// Load a vector keeping f16 precision.
    ///
    /// Without LoRAs this is a direct upload; with LoRAs the blend is done in
    /// f32 and the result converted back to f16 via a blit.
    pub fn load_vector_f16(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let lora = self.lora_vectors(name.as_ref())?;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor = if lora.is_empty() {
            // Fast path: no LoRA touches this tensor, upload f16 directly.
            TensorCpu::from_reader(tensor)?
                .reshape(
                    TensorDimension::Auto,
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                )?
                .to(context)
        } else {
            // Blend in f32 for precision, then narrow back to f16.
            let tensor_f32: TensorGpu<f32, _> = TensorCpu::<f16>::from_reader(tensor)?
                .map(|x| x.to_f32())
                .reshape(
                    TensorDimension::Auto,
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                    TensorDimension::Size(1),
                )?
                .to(context);
            let tensor_f16: TensorGpu<f16, _> = context.tensor_init(tensor_f32.shape());

            let mut ops = vec![];
            for lora in lora {
                // Interpolating blend: alpha * lora + (1 - alpha) * base.
                let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
                let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;

                let shape = lora.tensor.shape();
                let tensor = tensor_f32.reshape(
                    TensorDimension::Size(shape[0]),
                    TensorDimension::Size(shape[1]),
                    TensorDimension::Size(shape[2]),
                    TensorDimension::Size(shape[3]),
                )?;

                let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
                ops.push(op);
            }

            // Final f32 -> f16 conversion into the output buffer.
            let op = TensorOp::blit(&tensor_f32, &tensor_f16)?;
            ops.push(op);

            context.queue.submit(context.encode(&TensorOp::List(ops)));
            tensor_f16
        };
        Ok(tensor)
    }

    /// Load an f16 matrix onto the GPU, applying LoRA updates.
    ///
    /// Low-rank updates are scaled by `alpha / rank`; full-tensor vector
    /// patches matching a matrix name are applied additively (factor
    /// `[alpha, 1, 0, 0]`, unlike the interpolating vector loaders).
    pub fn load_matrix_f16(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor: TensorGpu<_, _> = TensorCpu::from_reader(tensor)?.to(context);

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, &tensor)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }

    /// Like [`Self::load_matrix_f16`], but scales the base weights *and* all
    /// LoRA contributions by `discount` (used e.g. for layer-wise rescaling).
    pub fn load_matrix_f16_discount(
        &self,
        name: impl AsRef<str>,
        discount: f32,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        // Discount is baked into the base weights on the CPU before upload.
        let tensor: TensorGpu<_, _> = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| f16::from_f32(discount * x.to_f32()))
            .to(context);

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            // LoRA terms are discounted too, keeping the overall scale consistent.
            let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, &tensor)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(tensor)
    }

    /// Load an f16 matrix (plus LoRA updates) into a caller-provided GPU
    /// buffer instead of allocating a new one.
    pub fn load_in_place_matrix_f16(
        &self,
        matrix: &TensorGpu<f16, ReadWrite>,
        name: impl AsRef<str>,
    ) -> Result<(), LoaderError> {
        let context = &self.context;
        let tensor = self.model.tensor(name.as_ref())?;
        let tensor = TensorCpu::from_reader(tensor)?;
        matrix.load(&tensor)?;

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            // Low-rank update scaled by alpha / rank, added onto the base.
            let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, matrix)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, matrix)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(())
    }

    /// Like [`Self::load_in_place_matrix_f16`], with a `discount` scale baked
    /// into the base weights and LoRA factors (see `load_matrix_f16_discount`).
    pub fn load_in_place_matrix_f16_discount(
        &self,
        matrix: &TensorGpu<f16, ReadWrite>,
        name: impl AsRef<str>,
        discount: f32,
    ) -> Result<(), LoaderError> {
        let context = &self.context;

        let tensor = self.model.tensor(name.as_ref())?;
        // Flatten to 2-D (rows x cols x 1 x 1) so the destination buffer's
        // layout is matched regardless of the stored rank.
        let tensor = TensorCpu::<f16>::from_reader(tensor)?
            .map(|x| f16::from_f32(discount * x.to_f32()))
            .reshape(
                TensorDimension::Full,
                TensorDimension::Full,
                TensorDimension::Size(1),
                TensorDimension::Size(1),
            )?;
        matrix.load(&tensor)?;

        let mut ops = vec![];
        for lora in self.lora_matrices(name.as_ref())? {
            let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend_lora(&factor, &lora.x, &lora.y, matrix)?;
            ops.push(op);
        }
        for lora in self.lora_vectors(name.as_ref())? {
            let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0];
            let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
            let op = TensorOp::blend(&factor, &lora.tensor, matrix)?;
            ops.push(op);
        }

        context.queue.submit(context.encode(&TensorOp::List(ops)));
        Ok(())
    }

    /// Load an f16 matrix on the CPU, padded to multiples of [`PAD_MAT`].
    /// No LoRA patches are applied on this path.
    pub fn load_matrix_f16_padded_cpu(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorCpu<f16>, LoaderError> {
        let (dt, shape, tensor) = self.model.tensor(name.as_ref())?;
        let tensor = TensorCpu::from_reader((dt, shape, tensor))?.pad(PAD_MAT);
        Ok(tensor)
    }

    /// Load an f16 matrix padded to multiples of [`PAD_MAT`] and upload it.
    /// No LoRA patches are applied on this path.
    pub fn load_matrix_f16_padded(
        &self,
        name: impl AsRef<str>,
    ) -> Result<TensorGpu<f16, ReadWrite>, LoaderError> {
        let context = &self.context;
        let (dt, shape, tensor) = self.model.tensor(name.as_ref())?;
        let tensor = TensorCpu::from_reader((dt, shape, tensor))?
            .pad(PAD_MAT)
            .to(context);
        Ok(tensor)
    }

    /// Load a matrix (with LoRA patches) and quantize it as requested.
    ///
    /// Quantized variants stage the full f16 matrix in a temporary buffer,
    /// then quantize on the GPU.
    pub fn load_matrix(&self, name: String, quant: Quant) -> Result<Matrix, LoaderError> {
        let context = &self.context;
        match quant {
            Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16(name)?)),
            Quant::Int8 => {
                let shape = self.tensor_shape(&name)?;
                let buffer = context.tensor_init(shape);
                self.load_in_place_matrix_f16(&buffer, &name)?;
                Ok(Matrix::quant_u8(&buffer)?)
            }
            Quant::NF4 => {
                let shape = self.tensor_shape(&name)?;
                let buffer = context.tensor_init(shape);
                self.load_in_place_matrix_f16(&buffer, &name)?;
                Ok(Matrix::quant_nf4(&buffer)?)
            }
            Quant::SF4 => {
                let shape = self.tensor_shape(&name)?;
                let buffer = context.tensor_init(shape);
                self.load_in_place_matrix_f16(&buffer, &name)?;
                // NOTE(review): 5.0 is the SF4 scale hyperparameter passed to
                // `Matrix::quant_sf4`; its meaning is defined there.
                Ok(Matrix::quant_sf4(&buffer, 5.0)?)
            }
        }
    }

    /// Like [`Self::load_matrix`], with a `discount` scale applied to the
    /// weights before quantization (see `load_matrix_f16_discount`).
    pub fn load_matrix_discount(
        &self,
        name: String,
        quant: Quant,
        discount: f32,
    ) -> Result<Matrix, LoaderError> {
        let context = &self.context;
        match quant {
            Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16_discount(name, discount)?)),
            Quant::Int8 => {
                let shape = self.tensor_shape(&name)?;
                let buffer = context.tensor_init(shape);
                self.load_in_place_matrix_f16_discount(&buffer, &name, discount)?;
                Ok(Matrix::quant_u8(&buffer)?)
            }
            Quant::NF4 => {
                let shape = self.tensor_shape(&name)?;
                let buffer = context.tensor_init(shape);
                self.load_in_place_matrix_f16_discount(&buffer, &name, discount)?;
                Ok(Matrix::quant_nf4(&buffer)?)
            }
            Quant::SF4 => {
                let shape = self.tensor_shape(&name)?;
                let buffer = context.tensor_init(shape);
                self.load_in_place_matrix_f16_discount(&buffer, &name, discount)?;
                Ok(Matrix::quant_sf4(&buffer, 5.0)?)
            }
        }
    }
}