web_rwkv/runtime/model.rs

use std::{any::Any, collections::HashMap};

#[cfg(not(target_arch = "wasm32"))]
use futures::future::BoxFuture;
#[cfg(target_arch = "wasm32")]
use futures::future::LocalBoxFuture;
use half::f16;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::wasm_bindgen;

use super::loader::{Lora, Reader, PAD_MAT};
use crate::{
    context::{Context, ContextBuilder},
    impl_deserialize_seed,
    num::Scalar,
    tensor::{kind::ReadWrite, shape::Shape, TensorCpu, TensorError, TensorGpu, TensorGpuView},
};

#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelVersion {
    V4,
    V5,
    V6,
    V7,
}

#[wasm_bindgen]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelInfo {
    pub version: ModelVersion,
    pub num_layer: usize,
    pub num_emb: usize,
    pub num_hidden: usize,
    pub num_vocab: usize,
    pub num_head: usize,
    #[wasm_bindgen(skip)]
    pub custom: ModelCustomInfo,
}

impl ModelInfo {
    /// Minimum `max_buffer_size` limit used by [`ContextAutoLimits`] (256 MiB).
    pub const BUFFER_SIZE: usize = 256 << 20;
    /// Minimum `max_storage_buffer_binding_size` limit used by [`ContextAutoLimits`] (128 MiB).
    pub const STORAGE_BUFFER_BINDING_SIZE: usize = 128 << 20;
}

impl_deserialize_seed!(ModelInfo);

#[wasm_bindgen]
impl ModelInfo {
    /// The largest storage buffer size required by any matrix other than the head.
    pub fn max_non_head_buffer_size(&self) -> usize {
        self.num_emb * self.num_hidden * f16::size()
    }

    /// The size of the head and embedding matrices in bytes.
    pub fn head_buffer_size(&self) -> usize {
        self.num_emb * self.num_vocab_padded() * f16::size()
    }

    /// The vocabulary size padded to the next multiple of `PAD_MAT[1]`.
    pub fn num_vocab_padded(&self) -> usize {
        self.num_vocab.next_multiple_of(PAD_MAT[1])
    }
}
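
// Example (illustrative sketch with hypothetical dimensions; it assumes
// `PAD_MAT[1]` divides 65536 so padding leaves the vocabulary unchanged):
//
//     let info = ModelInfo {
//         version: ModelVersion::V6,
//         num_layer: 32,
//         num_emb: 4096,
//         num_hidden: 14336,
//         num_vocab: 65536,
//         num_head: 64,
//         custom: ModelCustomInfo::None,
//     };
//     // 4096 * 14336 * 2 bytes = 112 MiB for the largest non-head matrix.
//     let non_head = info.max_non_head_buffer_size();
//     // 4096 * 65536 * 2 bytes = 512 MiB for the head matrix.
//     let head = info.head_buffer_size();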

/// Info about the model's inner LoRA dimensions.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelCustomInfo {
    #[default]
    None,
    V6(super::v6::CustomInfo),
    V7(super::v7::CustomInfo),
}

/// Upcast to [`std::any::Any`], allowing dynamic downcasting to the concrete type.
pub trait AsAny {
    fn as_any(&self) -> &dyn Any;
}

pub trait State {
    /// The number of batches in this state.
    fn num_batch(&self) -> usize;
    /// Shape of the initialized one-batch CPU state.
    fn init_shape(&self) -> Shape;
    /// Initialize a one-batch state on CPU.
    fn init(&self) -> TensorCpu<f32>;
    /// The part of the state that is used in an `att` layer.
    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
    /// The part of the state that is used in an `ffn` layer.
    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
    /// Load a batch of the state from CPU to GPU.
    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError>;
    /// Read back a batch of the state from GPU to CPU.
    #[cfg(not(target_arch = "wasm32"))]
    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
    /// Read back a batch of the state from GPU to CPU.
    #[cfg(target_arch = "wasm32")]
    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
    /// Write into the state from a GPU tensor.
    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError>;
    /// Read the state out into a GPU tensor.
    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError>;
    /// Get an embed vector from a backed state.
    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError>;
}
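
// Example (illustrative sketch of a hypothetical helper; non-wasm build assumed):
// the `back`/`load` pair is enough to copy one batch of a state into another
// slot by round-tripping through the CPU.
//
//     async fn copy_batch(state: &impl State, from: usize, to: usize) -> Result<(), TensorError> {
//         let backed = state.back(from).await?;
//         state.load(backed, to)
//     }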

pub trait Bundle {
    /// The model info.
    fn info(&self) -> ModelInfo;
    /// Get the state from the bundle.
    #[cfg(not(target_arch = "wasm32"))]
    fn state(&self) -> impl State + AsAny + Send + Sync + 'static;
    /// Get the state from the bundle.
    #[cfg(target_arch = "wasm32")]
    fn state(&self) -> impl State + AsAny + 'static;
    /// Get the model from the bundle.
    #[cfg(not(target_arch = "wasm32"))]
    fn model(&self) -> impl Serialize + Send + Sync + 'static;
    /// Get the model from the bundle.
    #[cfg(target_arch = "wasm32")]
    fn model(&self) -> impl Serialize + 'static;
}
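
// Example (illustrative sketch of a hypothetical helper): code that only needs
// model metadata can stay generic over `Bundle` without naming a concrete version.
//
//     fn vocab_size(bundle: &impl Bundle) -> usize {
//         bundle.info().num_vocab
//     }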

/// Quantization of a layer.
#[wasm_bindgen]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quant {
    /// No quantization.
    #[default]
    None,
    /// Use `Int8` quantization.
    Int8,
    /// Use `NF4` quantization.
    NF4,
    /// Use `SF4` quantization with `nu` set to 5.
    SF4,
}
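
// Example (illustrative sketch with a hypothetical layer count): a per-layer
// quantization map as consumed by `ModelBuilder::quant` below, quantizing the
// first 24 layers with `Int8` and leaving the rest unquantized.
//
//     let quant: HashMap<usize, Quant> = (0..24).map(|layer| (layer, Quant::Int8)).collect();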

pub struct ModelBuilder<R: Reader> {
    pub context: Context,
    pub model: R,
    pub rescale: Option<usize>,
    pub sep: Option<usize>,
    pub lora: Vec<Lora<R>>,
    pub quant: HashMap<usize, Quant>,
}

impl<R: Reader> ModelBuilder<R> {
    pub fn new(context: &Context, model: R) -> Self {
        Self {
            context: context.clone(),
            model,
            rescale: None,
            sep: None,
            lora: vec![],
            quant: Default::default(),
        }
    }

    /// Halve the layer weights and activations every `value` layers (`0` disables rescaling).
    pub fn rescale(mut self, value: usize) -> Self {
        self.rescale = match value {
            0 => Some(usize::MAX),
            x => Some(x),
        };
        self
    }

    /// Separately encode commands every `value` layers (`0` disables the separation).
    pub fn sep(mut self, value: usize) -> Self {
        self.sep = match value {
            0 => Some(usize::MAX),
            x => Some(x),
        };
        self
    }

    /// Add a LoRA to apply when loading the model.
    pub fn lora(mut self, value: Lora<R>) -> Self {
        self.lora.push(value);
        self
    }

    /// Set the per-layer quantization map.
    pub fn quant(mut self, value: HashMap<usize, Quant>) -> Self {
        self.quant = value;
        self
    }
}
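
// Example (illustrative sketch; `context` and `model` are hypothetical, with
// `model` being any `Reader` implementation from the loader module): a typical
// builder chain, reusing the `quant` map sketched above.
//
//     let builder = ModelBuilder::new(&context, model)
//         .rescale(6)
//         .quant(quant);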

pub trait ContextAutoLimits {
    /// Compute the limits automatically based on the given model build info.
    fn auto_limits(self, info: &ModelInfo) -> Self;
}

impl ContextAutoLimits for ContextBuilder {
    fn auto_limits(mut self, info: &ModelInfo) -> Self {
        self.limits.max_buffer_size = ModelInfo::BUFFER_SIZE
            .max(info.max_non_head_buffer_size())
            .max(info.head_buffer_size()) as u64;
        self.limits.max_storage_buffer_binding_size = ModelInfo::STORAGE_BUFFER_BINDING_SIZE
            .max(info.max_non_head_buffer_size())
            .max(info.head_buffer_size())
            as u32;
        self
    }
}
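
// Example (illustrative sketch; `adapter` and `info` are hypothetical, and the
// `ContextBuilder` constructor and async `build` are assumed from the context
// module): raising the buffer limits before creating a context.
//
//     let context = ContextBuilder::new(adapter)
//         .auto_limits(&info)
//         .build()
//         .await?;
//
// With the hypothetical 4096 x 65536 head above, `head_buffer_size()` (512 MiB)
// exceeds `BUFFER_SIZE` (256 MiB), so `max_buffer_size` is raised to match it.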