1use std::{any::Any, collections::HashMap};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use serde::{Deserialize, Serialize};
9use wasm_bindgen::prelude::wasm_bindgen;
10
11use super::loader::{Lora, Reader, PAD_MAT};
12use crate::{
13 context::{Context, ContextBuilder},
14 impl_deserialize_seed,
15 num::Scalar,
16 tensor::{kind::ReadWrite, shape::Shape, TensorCpu, TensorError, TensorGpu, TensorGpuView},
17};
18
/// Version of the model architecture.
///
/// Exported to JavaScript via `wasm_bindgen` and (de)serializable with serde.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelVersion {
    V4,
    V5,
    V6,
    V7,
}
27
/// Structural description of a model: version and the dimensions used to
/// size its GPU buffers (see `max_non_head_buffer_size` / `head_buffer_size`).
#[wasm_bindgen]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Architecture version of the model.
    pub version: ModelVersion,
    /// Number of layers.
    pub num_layer: usize,
    /// Embedding dimension.
    pub num_emb: usize,
    /// Hidden dimension (multiplied with `num_emb` to size the largest non-head buffer).
    pub num_hidden: usize,
    /// Vocabulary size before padding (see `num_vocab_padded`).
    pub num_vocab: usize,
    /// Number of heads.
    pub num_head: usize,
    /// Version-specific extra info; not exposed to JavaScript (`wasm_bindgen(skip)`).
    #[wasm_bindgen(skip)]
    pub custom: ModelCustomInfo,
}
40
41impl ModelInfo {
42 pub const BUFFER_SIZE: usize = 256 << 20;
43 pub const STORAGE_BUFFER_BINDING_SIZE: usize = 128 << 20;
44}
45
46impl_deserialize_seed!(ModelInfo);
47
48#[wasm_bindgen]
49impl ModelInfo {
50 pub fn max_non_head_buffer_size(&self) -> usize {
52 self.num_emb * self.num_hidden * f16::size()
53 }
54
55 pub fn head_buffer_size(&self) -> usize {
57 self.num_emb * self.num_vocab_padded() * f16::size()
58 }
59
60 pub fn num_vocab_padded(&self) -> usize {
61 self.num_vocab.next_multiple_of(PAD_MAT[1])
62 }
63}
64
/// Version-specific model information.
///
/// Defaults to `None`; `V6` / `V7` carry the `CustomInfo` type defined in the
/// `super::v6` / `super::v7` modules respectively.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelCustomInfo {
    /// No version-specific info.
    #[default]
    None,
    /// Extra info for version 6 models.
    V6(super::v6::CustomInfo),
    /// Extra info for version 7 models.
    V7(super::v7::CustomInfo),
}
73
/// Exposes `self` as `&dyn Any`, so a value behind a trait object can be
/// recovered as its concrete type with `downcast_ref`.
pub trait AsAny {
    /// Returns `self` as a `&dyn Any`.
    fn as_any(&self) -> &dyn Any;
}
77
/// Runtime state of a model, organized by batch index and by layer.
///
/// NOTE(review): the descriptions below are derived from the signatures only;
/// the concrete implementations live elsewhere — confirm details there.
pub trait State {
    /// Number of batches this state holds.
    fn num_batch(&self) -> usize;
    /// Shape of the tensor produced by [`State::init`] (presumably one batch's
    /// state — confirm).
    fn init_shape(&self) -> Shape;
    /// Builds a CPU-side state tensor, presumably the initial state — confirm.
    fn init(&self) -> TensorCpu<f32>;
    /// GPU view of the `att` part of the state for `layer`
    /// (name suggests the attention part — confirm).
    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
    /// GPU view of the `ffn` part of the state for `layer`.
    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
    /// Loads a CPU `tensor` into the state for batch index `batch`.
    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError>;
    /// Asynchronously copies the state for batch index `batch` back to a CPU tensor.
    ///
    /// On non-wasm targets the returned future is `Send` (`BoxFuture`).
    #[cfg(not(target_arch = "wasm32"))]
    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
    /// Asynchronously copies the state for batch index `batch` back to a CPU tensor.
    ///
    /// On wasm32 the returned future need not be `Send` (`LocalBoxFuture`).
    #[cfg(target_arch = "wasm32")]
    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
    /// Writes a GPU `tensor` into the state for batch index `batch`.
    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError>;
    /// Returns a GPU tensor holding the state for batch index `batch`.
    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError>;
    /// Derives a CPU tensor for `layer` from a `backed` (CPU) state, such as one
    /// returned by `back`. Exact meaning is not visible here — confirm.
    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError>;
}
104
/// A model bundle that reports its [`ModelInfo`] and hands out a state object
/// and a serializable model value.
///
/// On non-wasm targets the returned values must also be `Send + Sync`; on
/// wasm32 those bounds are dropped.
pub trait Bundle {
    /// Structural information about the model.
    fn info(&self) -> ModelInfo;
    /// Returns a state object for this model (presumably a fresh one — confirm).
    #[cfg(not(target_arch = "wasm32"))]
    fn state(&self) -> impl State + AsAny + Send + Sync + 'static;
    /// Returns a state object for this model (wasm32: no `Send + Sync` bound).
    #[cfg(target_arch = "wasm32")]
    fn state(&self) -> impl State + AsAny + 'static;
    /// Returns a serializable value representing the model.
    #[cfg(not(target_arch = "wasm32"))]
    fn model(&self) -> impl Serialize + Send + Sync + 'static;
    /// Returns a serializable value representing the model (wasm32: no `Send + Sync` bound).
    #[cfg(target_arch = "wasm32")]
    fn model(&self) -> impl Serialize + 'static;
}
121
/// Weight quantization scheme; defaults to `None`.
///
/// Used per index in `ModelBuilder::quant`.
#[wasm_bindgen]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quant {
    /// No quantization.
    #[default]
    None,
    /// 8-bit integer quantization (per the variant name).
    Int8,
    /// 4-bit quantization; presumably NormalFloat4 — confirm against the kernels.
    NF4,
    /// 4-bit quantization variant `SF4`; exact format not visible here — confirm.
    SF4,
}
136
/// Builder gathering what is needed to construct a model from a [`Reader`].
pub struct ModelBuilder<R: Reader> {
    /// Context the model will be built on.
    pub context: Context,
    /// Reader providing the model weights.
    pub model: R,
    /// Optional `rescale` setting; `ModelBuilder::rescale(0)` stores `usize::MAX`.
    pub rescale: Option<usize>,
    /// Optional `sep` setting; `ModelBuilder::sep(0)` stores `usize::MAX`.
    pub sep: Option<usize>,
    /// LoRA adapters to apply, in insertion order.
    pub lora: Vec<Lora<R>>,
    /// Quantization choice keyed by index (presumably layer index — confirm).
    pub quant: HashMap<usize, Quant>,
}
145
146impl<R: Reader> ModelBuilder<R> {
147 pub fn new(context: &Context, model: R) -> Self {
148 Self {
149 context: context.clone(),
150 model,
151 rescale: None,
152 sep: None,
153 lora: vec![],
154 quant: Default::default(),
155 }
156 }
157
158 pub fn rescale(mut self, value: usize) -> Self {
160 self.rescale = match value {
161 0 => Some(usize::MAX),
162 x => Some(x),
163 };
164 self
165 }
166
167 pub fn sep(mut self, value: usize) -> Self {
169 self.sep = match value {
170 0 => Some(usize::MAX),
171 x => Some(x),
172 };
173 self
174 }
175
176 pub fn lora(mut self, value: Lora<R>) -> Self {
177 self.lora.push(value);
178 self
179 }
180
181 pub fn quant(mut self, value: HashMap<usize, Quant>) -> Self {
182 self.quant = value;
183 self
184 }
185}
186
/// Sizes a context's buffer limits so they can hold a given model's buffers.
pub trait ContextAutoLimits {
    /// Returns `self` with buffer-size limits derived from `info`.
    fn auto_limits(self, info: &ModelInfo) -> Self;
}
191
192impl ContextAutoLimits for ContextBuilder {
193 fn auto_limits(mut self, info: &ModelInfo) -> Self {
194 self.limits.max_buffer_size = ModelInfo::BUFFER_SIZE
195 .max(info.max_non_head_buffer_size())
196 .max(info.head_buffer_size()) as u64;
197 self.limits.max_storage_buffer_binding_size = ModelInfo::STORAGE_BUFFER_BINDING_SIZE
198 .max(info.max_non_head_buffer_size())
199 .max(info.head_buffer_size())
200 as u32;
201 self
202 }
203}