// smelte_rs/nn/models/bert.rs

use crate::cpu::f32::{matmul, matmul_t, softmax, Tensor as F32Tensor};
use crate::nn::layers::{Embedding, LayerNorm, Linear};
use crate::traits::{Tensor, TensorOps};
use crate::SmeltError;

// Debugging helper: takes a label and a tensor. The body is commented out, so
// the macro currently expands to nothing; uncomment the `println!` to trace
// the first and last values of intermediate tensors.
macro_rules! debug {
    ($str: expr, $tensor: expr) => {
        // println!(
        //     "{} {:?}..{:?}",
        //     $str,
        //     &$tensor.data()[..3],
        //     &$tensor.data()[$tensor.data().len() - 3..]
        // );
    };
}

/// Pre-allocated scratch buffers for a single BERT forward pass. Allocating
/// them once up front lets every layer run without intermediate allocations.
pub struct BertContext<T: Tensor> {
    input_ids: Vec<usize>,
    type_ids: Vec<usize>,
    position_ids: Vec<usize>,
    hidden_states: T,
    // Scratch buffer:
    // - Holds each embedding before it is added into `hidden_states`
    // - Used in the MLP to prevent cloning the skip connection
    // - Used in the attention for the output Linear layer
    hidden_states_copy: T,
    // Stores the hidden states after the attention (prevents a clone in the skip connection)
    hidden_states_attn_output: T,
    // Store the q split heads
    q_cache: T,
    // Store the k split heads
    k_cache: T,
    // Store the v split heads
    v_cache: T,
    // Store the q.k_t result
    qk: T,
    // Store the per-head attention output, before the heads are merged back
    qkv: T,
    // Intermediate MLP states (sequence_length, intermediate_dim)
    intermediate_states: T,
    // First-token ([CLS]) hidden state used by the pooler
    pool: T,
    pool_output: T,
    // Classifier output probabilities (1, num_classes)
    probs: T,
}

impl<T: Tensor> BertContext<T> {
    /// The output probabilities of the classifier, shape `(1, num_classes)`.
    pub fn probs(&self) -> &T {
        &self.probs
    }
}

// Rearranges a (sequence_length, hidden_dim) projection into
// (num_heads, sequence_length, head_dim) so that each head can be
// matrix-multiplied independently.
fn split_heads(q: &F32Tensor, out_q: &mut F32Tensor) -> Result<(), SmeltError> {
    let num_heads = out_q.shape()[0];
    let sequence_length = out_q.shape()[1];
    let head_dim = out_q.shape()[2];
    let hidden_dim = head_dim * num_heads;

    (0..num_heads).for_each(|i| {
        (0..sequence_length).for_each(|j| {
            (0..head_dim).for_each(|k| {
                let index = j * hidden_dim + i * head_dim + k;
                let out_index = i * sequence_length * head_dim + j * head_dim + k;
                out_q.data_mut()[out_index] = q.data()[index];
            });
        });
    });
    Ok(())
}

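#[cfg(test)]
mod split_heads_tests {
    use super::*;

    // A minimal sketch of the layout contract of `split_heads`, relying only
    // on `zeros`, `data` and `data_mut`, which are already used in this file;
    // adapt if the concrete `Tensor` API differs. It checks that element
    // (head i, position j, dim k) of the output comes from column
    // i * head_dim + k of row j in the input.
    #[test]
    fn split_heads_layout() {
        let (num_heads, sequence_length, head_dim) = (2, 3, 4);
        let hidden_dim = num_heads * head_dim;

        // Fill the input with its own flat indices so every element is unique.
        let mut input = F32Tensor::zeros(vec![sequence_length, hidden_dim]);
        for (idx, v) in input.data_mut().iter_mut().enumerate() {
            *v = idx as f32;
        }

        let mut out = F32Tensor::zeros(vec![num_heads, sequence_length, head_dim]);
        split_heads(&input, &mut out).unwrap();

        for i in 0..num_heads {
            for j in 0..sequence_length {
                for k in 0..head_dim {
                    let expected = (j * hidden_dim + i * head_dim + k) as f32;
                    let got = out.data()[(i * sequence_length + j) * head_dim + k];
                    assert_eq!(got, expected);
                }
            }
        }
    }
}
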
// Multi-head self-attention over `ctx.hidden_states`: project to Q, K and V,
// split the heads, attend per head, then merge the heads back into
// `ctx.hidden_states_attn_output`.
fn attention<'data, 'ctx>(
    q_weights: &Linear<F32Tensor<'data>>,
    k_weights: &Linear<F32Tensor<'data>>,
    v_weights: &Linear<F32Tensor<'data>>,
    ctx: &mut BertContext<F32Tensor<'ctx>>,
) -> Result<(), SmeltError>
where
    'data: 'ctx,
{
    q_weights.forward(&ctx.hidden_states, &mut ctx.hidden_states_copy)?;
    split_heads(&ctx.hidden_states_copy, &mut ctx.q_cache)?;

    debug!("Q heads split", ctx.q_cache);

    k_weights.forward(&ctx.hidden_states, &mut ctx.hidden_states_copy)?;
    split_heads(&ctx.hidden_states_copy, &mut ctx.k_cache)?;

    debug!("K heads split", ctx.k_cache);

    v_weights.forward(&ctx.hidden_states, &mut ctx.hidden_states_copy)?;
    split_heads(&ctx.hidden_states_copy, &mut ctx.v_cache)?;

    debug!("V heads split", ctx.v_cache);

    matmul_t(&ctx.q_cache, &ctx.k_cache, &mut ctx.qk)?;

    let num_heads = ctx.q_cache.shape()[0];
    let sequence_length = ctx.q_cache.shape()[1];
    let head_dim = ctx.q_cache.shape()[2];
    let hidden_dim = head_dim * num_heads;

    // Scale the attention scores by sqrt(head_dim) before the softmax.
    let scale = (head_dim as f32).sqrt();
    ctx.qk.data_mut().iter_mut().for_each(|v| *v /= scale);

    softmax(&mut ctx.qk)?;
    debug!("attention_probs", ctx.qk);
    matmul(&ctx.qk, &ctx.v_cache, &mut ctx.qkv)?;
    debug!("qkv", ctx.qkv);

    // Merge the heads back: (num_heads, sequence_length, head_dim) ->
    // (sequence_length, hidden_dim). This is the inverse of `split_heads`.
    let new_out = ctx.hidden_states_attn_output.data_mut();
    (0..num_heads).for_each(|i| {
        (0..sequence_length).for_each(|j| {
            (0..head_dim).for_each(|k| {
                let in_index = i * sequence_length * head_dim + j * head_dim + k;
                let out_index = j * hidden_dim + i * head_dim + k;
                new_out[out_index] = ctx.qkv.data()[in_index];
            });
        });
    });
    debug!("qkv (reshaped)", ctx.hidden_states_attn_output);

    Ok(())
}

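// For reference, `attention` above computes standard scaled dot-product
// attention independently for each head h:
//
//     Attention(Q_h, K_h, V_h) = softmax(Q_h . K_h^T / sqrt(head_dim)) . V_h
//
// and the final loop concatenates the per-head results back into a
// (sequence_length, hidden_dim) matrix.
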
/// Tensor types that provide the multi-head attention primitive used by
/// [`BertAttention`].
pub trait TensorAttention<T: Tensor> {
    /// Runs multi-head self-attention over `ctx.hidden_states` with the given
    /// projections, writing the merged result into
    /// `ctx.hidden_states_attn_output`.
    fn attention(
        query: &Linear<T>,
        key: &Linear<T>,
        value: &Linear<T>,
        ctx: &mut BertContext<T>,
    ) -> Result<(), SmeltError>;
}

impl<'a> TensorAttention<F32Tensor<'a>> for F32Tensor<'a> {
    fn attention(
        query: &Linear<F32Tensor<'a>>,
        key: &Linear<F32Tensor<'a>>,
        value: &Linear<F32Tensor<'a>>,
        ctx: &mut BertContext<F32Tensor<'a>>,
    ) -> Result<(), SmeltError> {
        attention(query, key, value, ctx)
    }
}

/// Read access to the raw `f32` buffer of a tensor, used by the `debug!`
/// macro above.
pub trait Debug<T: Tensor> {
    /// The flat underlying data of the tensor.
    fn data(&self) -> &[f32];
}

impl<'a> Debug<F32Tensor<'a>> for F32Tensor<'a> {
    fn data(&self) -> &[f32] {
        self.data()
    }
}

/// Umbrella trait gathering every tensor operation needed to run BERT.
pub trait BertOps<T: Tensor>: TensorOps<T> + TensorAttention<T> + Debug<T> {}

impl<'a> BertOps<F32Tensor<'a>> for F32Tensor<'a> {}

/// The self-attention block of a BERT layer: query/key/value projections,
/// an output projection and the post-attention LayerNorm.
#[derive(Clone)]
pub struct BertAttention<T: Tensor> {
    query: Linear<T>,
    key: Linear<T>,
    value: Linear<T>,
    output: Linear<T>,
    output_ln: LayerNorm<T>,
}

impl<T: Tensor + BertOps<T>> BertAttention<T> {
    /// Builds a [`BertAttention`] from its pre-loaded weights.
    pub fn new(
        query: Linear<T>,
        key: Linear<T>,
        value: Linear<T>,
        output: Linear<T>,
        output_ln: LayerNorm<T>,
    ) -> Self {
        Self {
            query,
            key,
            value,
            output,
            output_ln,
        }
    }

    /// Runs self-attention, the output projection, the skip connection and
    /// LayerNorm, in place on `ctx.hidden_states`.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        T::attention(&self.query, &self.key, &self.value, ctx)?;

        self.output
            .forward(&ctx.hidden_states_attn_output, &mut ctx.hidden_states_copy)?;
        // Skip connection followed by LayerNorm.
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        self.output_ln.forward(&mut ctx.hidden_states)?;
        Ok(())
    }
}

/// The feed-forward block of a BERT layer: an intermediate projection with
/// gelu activation, an output projection, a skip connection and LayerNorm.
#[derive(Clone)]
pub struct Mlp<T: Tensor> {
    intermediate: Linear<T>,
    output: Linear<T>,
    output_ln: LayerNorm<T>,
}

impl<T: Tensor + BertOps<T>> Mlp<T> {
    /// Builds an [`Mlp`] from its pre-loaded weights.
    pub fn new(intermediate: Linear<T>, output: Linear<T>, output_ln: LayerNorm<T>) -> Self {
        Self {
            intermediate,
            output,
            output_ln,
        }
    }

    /// Runs the feed-forward block in place on `ctx.hidden_states`.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        debug!("Before MLP", ctx.hidden_states);
        self.intermediate
            .forward(&ctx.hidden_states, &mut ctx.intermediate_states)?;
        debug!("Intermediate", ctx.intermediate_states);
        T::gelu(&mut ctx.intermediate_states)?;
        debug!("Intermediate (gelu)", ctx.intermediate_states);
        self.output
            .forward(&ctx.intermediate_states, &mut ctx.hidden_states_copy)?;
        debug!("output", ctx.hidden_states_copy);
        // Skip connection followed by LayerNorm.
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        debug!("output (skip)", ctx.hidden_states);
        self.output_ln.forward(&mut ctx.hidden_states)?;
        debug!("output ln", ctx.hidden_states);
        Ok(())
    }
}

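// In symbols, `Mlp::forward` rewrites x = hidden_states as
//
//     x <- LayerNorm(x + output(gelu(intermediate(x))))
//
// where `intermediate` maps hidden_dim -> intermediate_dim (4 * hidden_dim in
// standard BERT configurations) and `output` maps back to hidden_dim.
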
/// A single transformer layer: self-attention followed by the feed-forward
/// block.
#[derive(Clone)]
pub struct BertLayer<T: Tensor> {
    attention: BertAttention<T>,
    mlp: Mlp<T>,
}

impl<T: Tensor + BertOps<T>> BertLayer<T> {
    /// Builds a [`BertLayer`] from its two sub-blocks.
    pub fn new(attention: BertAttention<T>, mlp: Mlp<T>) -> Self {
        Self { attention, mlp }
    }

    /// Runs the layer in place on `ctx.hidden_states`.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        debug!("Before attention", ctx.hidden_states);
        self.attention.forward(ctx)?;
        debug!("After attention", ctx.hidden_states);
        self.mlp.forward(ctx)?;
        debug!("After mlp", ctx.hidden_states);
        Ok(())
    }
}

/// The stack of transformer layers.
#[derive(Clone)]
pub struct BertEncoder<T: Tensor> {
    layers: Vec<BertLayer<T>>,
}

impl<T: Tensor + BertOps<T>> BertEncoder<T> {
    /// Builds a [`BertEncoder`] from its layers.
    pub fn new(layers: Vec<BertLayer<T>>) -> Self {
        Self { layers }
    }

    /// Runs every layer in order, in place on `ctx.hidden_states`.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        for layer in &self.layers {
            layer.forward(ctx)?;
        }
        Ok(())
    }
}

/// The embedding block: word, token-type and position embeddings summed
/// together, then normalized with LayerNorm.
#[derive(Clone)]
pub struct BertEmbeddings<T: Tensor> {
    input_embeddings: Embedding<T>,
    position_embeddings: Embedding<T>,
    type_embeddings: Embedding<T>,
    layer_norm: LayerNorm<T>,
}

impl<T: Tensor + BertOps<T>> BertEmbeddings<T> {
    /// Builds a [`BertEmbeddings`] from its pre-loaded weights.
    pub fn new(
        input_embeddings: Embedding<T>,
        position_embeddings: Embedding<T>,
        type_embeddings: Embedding<T>,
        layer_norm: LayerNorm<T>,
    ) -> Self {
        Self {
            input_embeddings,
            position_embeddings,
            type_embeddings,
            layer_norm,
        }
    }

    /// Computes the initial `ctx.hidden_states` from the ids stored in `ctx`.
    /// Fails with [`SmeltError::InvalidLength`] if the three id vectors do
    /// not have the same length.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        let input_ids = &ctx.input_ids;
        let position_ids = &ctx.position_ids;
        let type_ids = &ctx.type_ids;

        if input_ids.len() != position_ids.len() {
            return Err(SmeltError::InvalidLength {
                expected: input_ids.len(),
                got: position_ids.len(),
            });
        }
        if input_ids.len() != type_ids.len() {
            return Err(SmeltError::InvalidLength {
                expected: input_ids.len(),
                got: type_ids.len(),
            });
        }

        self.input_embeddings
            .forward(input_ids, &mut ctx.hidden_states)?;

        debug!("input embeddings", ctx.hidden_states);

        self.type_embeddings
            .forward(type_ids, &mut ctx.hidden_states_copy)?;
        debug!("type embeddings", ctx.hidden_states_copy);
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        debug!("After add type embeddings", ctx.hidden_states);

        self.position_embeddings
            .forward(position_ids, &mut ctx.hidden_states_copy)?;
        debug!("position embeddings", ctx.hidden_states_copy);
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        debug!("After add position embeddings", ctx.hidden_states);

        self.layer_norm.forward(&mut ctx.hidden_states)?;

        debug!("After embeddings", ctx.hidden_states);
        Ok(())
    }
}

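// In symbols, the embedding block computes, for every position i:
//
//     hidden[i] = LayerNorm(input_emb[input_ids[i]]
//                         + type_emb[type_ids[i]]
//                         + position_emb[position_ids[i]])
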
/// The core BERT model: embeddings followed by the encoder stack.
pub struct Bert<T: Tensor + BertOps<T>> {
    embeddings: BertEmbeddings<T>,
    encoder: BertEncoder<T>,
}

impl<T: Tensor + BertOps<T>> Bert<T> {
    /// Builds a [`Bert`] from its two sub-modules.
    pub fn new(embeddings: BertEmbeddings<T>, encoder: BertEncoder<T>) -> Self {
        Self {
            embeddings,
            encoder,
        }
    }
    /// Runs the full model in place on `ctx`.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        self.embeddings.forward(ctx)?;
        self.encoder.forward(ctx)
    }
}

/// The pooling head: takes the hidden state of the first token ([CLS]) and
/// applies a Linear layer followed by a tanh.
#[derive(Clone)]
pub struct BertPooler<T: Tensor> {
    pooler: Linear<T>,
}

impl<T: Tensor + BertOps<T>> BertPooler<T> {
    /// Builds a [`BertPooler`] from its pre-loaded weights.
    pub fn new(pooler: Linear<T>) -> Self {
        Self { pooler }
    }

    /// Writes `tanh(pooler(hidden_states[0]))` into `ctx.pool_output`.
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        T::select(&[0], &ctx.hidden_states, &mut ctx.pool)?;
        self.pooler.forward(&ctx.pool, &mut ctx.pool_output)?;
        T::tanh(&mut ctx.pool_output)?;
        Ok(())
    }
}

/// A sequence-classification model: BERT, the pooler and a classification
/// head with softmax.
pub struct BertClassifier<T: Tensor + BertOps<T>> {
    bert: Bert<T>,
    pooler: BertPooler<T>,
    classifier: Linear<T>,
}

impl<T: Tensor + BertOps<T>> BertClassifier<T> {
    /// Builds a [`BertClassifier`] from its sub-modules.
    pub fn new(bert: Bert<T>, pooler: BertPooler<T>, classifier: Linear<T>) -> Self {
        Self {
            bert,
            pooler,
            classifier,
        }
    }

    /// Runs the full classifier; the resulting class probabilities can be
    /// read back through [`BertContext::probs`].
    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        self.bert.forward(ctx)?;
        self.pooler.forward(ctx)?;
        self.classifier.forward(&ctx.pool_output, &mut ctx.probs)?;
        T::softmax(&mut ctx.probs)?;
        Ok(())
    }

    /// Allocates a [`BertContext`] whose scratch buffers are sized for the
    /// given ids and number of attention heads. The hidden, intermediate and
    /// class dimensions are read from the model weights themselves.
    pub fn new_context(
        &self,
        input_ids: Vec<usize>,
        position_ids: Vec<usize>,
        type_ids: Vec<usize>,
        num_heads: usize,
    ) -> BertContext<T> {
        let hidden_dim = self.bert.embeddings.input_embeddings.weight().shape()[1];
        let intermediate_dim = self.bert.encoder.layers[0]
            .mlp
            .intermediate
            .weight()
            .shape()[0];
        let num_classes = self.classifier.weight().shape()[0];
        let head_dim = hidden_dim / num_heads;
        let sequence_length = input_ids.len();

        let hidden_states = T::zeros(vec![sequence_length, hidden_dim]);
        let hidden_states_copy = T::zeros(vec![sequence_length, hidden_dim]);
        let hidden_states_attn_output = T::zeros(vec![sequence_length, hidden_dim]);
        let intermediate_states = T::zeros(vec![sequence_length, intermediate_dim]);
        let q_cache = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let k_cache = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let v_cache = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let qk = T::zeros(vec![num_heads, sequence_length, sequence_length]);
        let qkv = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let pool = T::zeros(vec![1, hidden_dim]);
        let pool_output = T::zeros(vec![1, hidden_dim]);
        let probs = T::zeros(vec![1, num_classes]);
        BertContext {
            input_ids,
            position_ids,
            type_ids,
            hidden_states,
            hidden_states_copy,
            hidden_states_attn_output,
            intermediate_states,
            q_cache,
            k_cache,
            v_cache,
            qk,
            qkv,
            pool,
            pool_output,
            probs,
        }
    }
}
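
// A hypothetical usage sketch (the ids and the head count are made up; weight
// loading and tokenization are out of scope here). Assuming a fully
// constructed `model: BertClassifier<F32Tensor>` built with the `new`
// constructors above:
//
//     let input_ids = vec![101, 7592, 2088, 102];
//     let position_ids: Vec<usize> = (0..input_ids.len()).collect();
//     let type_ids = vec![0; input_ids.len()];
//     // 12 heads would match a BERT-base-sized checkpoint.
//     let mut ctx = model.new_context(input_ids, position_ids, type_ids, 12);
//     model.forward(&mut ctx)?;
//     let probs = ctx.probs(); // shape (1, num_classes), rows sum to 1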