smelte_rs/cpu/f32/ops.rs

use crate::cpu::f32::tensor::Tensor;
use crate::SmeltError;

#[cfg(feature = "matrixmultiply")]
use matrixmultiply::sgemm;

#[cfg(feature = "cblas")]
use cblas_sys::{
    cblas_sgemm as sgemm, CblasColMajor as ColMajor, CblasNoTrans as NoTr,
    CblasRowMajor as RowMajor, CblasTrans as Tr,
};

/// Operation for selecting entire rows within tensor `weights`. Each `id` is the index
/// of the row.
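///
/// A minimal usage sketch (the import path assumes the crate root is exposed
/// as `smelte_rs`); the values mirror the `simple_select` test below.
/// ```
/// use smelte_rs::cpu::f32::{ops::select, tensor::Tensor};
///
/// // 2 rows of width 2; pick row 1, then row 0 twice.
/// let weights = Tensor::borrowed(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let mut out = Tensor::zeros(vec![3, 2]);
/// select(&[1, 0, 0], &weights, &mut out).unwrap();
/// assert_eq!(out.data(), &[3.0, 4.0, 1.0, 2.0, 1.0, 2.0]);
/// ```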
pub fn select(ids: &[usize], weights: &Tensor, out: &mut Tensor) -> Result<(), SmeltError> {
    let sequence_length = ids.len();
    let vocab_size = weights.shape()[0];
    let hidden_dim = weights.shape()[1];
    if out.shape() != [sequence_length, hidden_dim] {
        return Err(SmeltError::DimensionMismatch {
            expected: vec![sequence_length, hidden_dim],
            got: out.shape().to_vec(),
        });
    }
    for (i, id) in ids.iter().enumerate() {
        let id = *id;
        if id >= vocab_size {
            return Err(SmeltError::OutOfVocabulary { vocab_size, id });
        }
        let weight_offset = id * hidden_dim;
        let data_offset = i * hidden_dim;
        out.data_mut()[data_offset..data_offset + hidden_dim]
            .copy_from_slice(&weights.data()[weight_offset..weight_offset + hidden_dim]);
    }
    Ok(())
}

/// Regular matrix multiplication
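///
/// A minimal sketch (import path assumed to be `smelte_rs::cpu::f32`), with
/// values taken from the `simple_matmul` test below:
/// ```
/// use smelte_rs::cpu::f32::{ops::matmul, tensor::Tensor};
///
/// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let b = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let mut c = Tensor::zeros(vec![2, 2]);
/// matmul(&a, &b, &mut c).unwrap();
/// assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);
/// ```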
pub fn matmul<'a>(a: &Tensor<'a>, b: &Tensor<'a>, out: &mut Tensor<'a>) -> Result<(), SmeltError> {
    g_matmul::<false>(a, b, out)
}

/// Matrix multiplication matmul(A, B.transposed())
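///
/// A minimal sketch (import path assumed), mirroring `simple_matmul_t` below:
/// here `b` holds A.T, so the product A·bᵀ equals A·A.
/// ```
/// use smelte_rs::cpu::f32::{ops::matmul_t, tensor::Tensor};
///
/// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let b = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
/// let mut c = Tensor::zeros(vec![2, 2]);
/// matmul_t(&a, &b, &mut c).unwrap();
/// assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);
/// ```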
pub fn matmul_t<'a>(
    a: &Tensor<'a>,
    b: &Tensor<'a>,
    out: &mut Tensor<'a>,
) -> Result<(), SmeltError> {
    g_matmul::<true>(a, b, out)
}

#[inline]
fn g_matmul<'a, const TRANSPOSE: bool>(
    a: &Tensor<'a>,
    b: &Tensor<'a>,
    c: &mut Tensor<'a>,
) -> Result<(), SmeltError> {
    let dim = a.shape().len();

    if dim < 2 {
        return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
    }
    if b.shape().len() != dim {
        return Err(SmeltError::InvalidRank { expected_rank: dim });
    }
    if c.shape().len() != dim {
        return Err(SmeltError::InvalidRank { expected_rank: dim });
    }

    let m = a.shape()[dim - 2];
    let k = a.shape()[dim - 1];

    let mut expected_c = a.shape().to_vec();
    let mut expected_b = a.shape().to_vec();

    let (expected_b, n) = if TRANSPOSE {
        let n = b.shape()[dim - 2];
        expected_b[dim - 2] = n;
        expected_b[dim - 1] = k;
        (expected_b, n)
    } else {
        let n = b.shape()[dim - 1];
        expected_b[dim - 2] = k;
        expected_b[dim - 1] = n;
        (expected_b, n)
    };

    expected_c[dim - 2] = m;
    expected_c[dim - 1] = n;

    if expected_b != b.shape() {
        return Err(SmeltError::DimensionMismatch {
            expected: expected_b,
            got: b.shape().to_vec(),
        });
    }

    if expected_c != c.shape() {
        return Err(SmeltError::DimensionMismatch {
            expected: expected_c,
            got: c.shape().to_vec(),
        });
    }

    // Zero out c
    c.data_mut().iter_mut().for_each(|v| *v = 0.0);

    let batching: usize = a.shape()[..dim - 2].iter().product();
    let a_skip: usize = m * k;
    let b_skip: usize = n * k;
    let c_skip: usize = m * n;

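    // Element strides (in numbers of elements) for the rank-2 views handed to
    // `sgemm`: `ar`/`ac` are A's row/column strides, and likewise for B and C.
    // With row-major storage the row stride is the row width and the column
    // stride is 1; transposing B amounts to swapping its two strides.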
    let ar = k as isize;
    let ac = 1;
    let (br, bc) = if TRANSPOSE {
        (1, b.shape()[dim - 1] as isize)
    } else {
        (b.shape()[dim - 1] as isize, 1)
    };
    let cr = n as isize;
    let cc = 1;

    (0..batching).for_each(|step| {
        let ap = &a.data()[step * a_skip..];
        let bp = &b.data()[step * b_skip..];
        let cp = &mut c.data_mut()[step * c_skip..];

        #[cfg(feature = "matrixmultiply")]
        unsafe {
            sgemm(
                m,
                k,
                n,
                1.0,
                ap.as_ptr(),
                ar,
                ac,
                bp.as_ptr(),
                br,
                bc,
                1.0,
                cp.as_mut_ptr(),
                cr,
                cc,
            );
        }

        #[cfg(feature = "cblas")]
        unsafe {
            let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
            let (layout, a_tr, b_tr, lda, ldb, ldc) = if cr < cc {
                let (lda, a_tr) = if ar < ac { (m, NoTr) } else { (k, Tr) };
                let (ldb, b_tr) = if br < bc { (k, NoTr) } else { (n, Tr) };
                (ColMajor, a_tr, b_tr, lda, ldb, m)
            } else {
                let (lda, a_tr) = if ar < ac { (m, Tr) } else { (k, NoTr) };
                let (ldb, b_tr) = if br < bc { (k, Tr) } else { (n, NoTr) };
                (RowMajor, a_tr, b_tr, lda, ldb, n)
            };
            sgemm(
                layout,
                a_tr,
                b_tr,
                m,
                n,
                k,
                1.0,
                ap.as_ptr(),
                lda,
                // a_skip as i32,
                bp.as_ptr(),
                ldb,
                // b_skip as i32,
                1.0,
                cp.as_mut_ptr(),
                ldc,
                // c_skip as i32,
                // batching as i32,
            )
        }
    });
    Ok(())
}

/// Tensor elementwise addition: `b += a`.
/// `a` is automatically broadcast over the leading dimension(s) of `b`.
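///
/// A minimal broadcast sketch (hedged: assumes rank-1 shapes are accepted by
/// `Tensor::new` and that the import path is `smelte_rs::cpu::f32`):
/// ```
/// use smelte_rs::cpu::f32::{ops::add, tensor::Tensor};
///
/// let bias = Tensor::new(vec![1.0, 2.0], vec![2]).unwrap();
/// let mut x = Tensor::new(vec![0.0, 0.0, 10.0, 10.0], vec![2, 2]).unwrap();
/// add(&bias, &mut x).unwrap();
/// assert_eq!(x.data(), &[1.0, 2.0, 11.0, 12.0]);
/// ```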
pub fn add(a: &Tensor, b: &mut Tensor) -> Result<(), SmeltError> {
    if a.shape() == b.shape() {
        a.data()
            .iter()
            .zip(b.data_mut().iter_mut())
            .for_each(|(left, right)| *right += left);
        Ok(())
    } else if &b.shape()[1..] == a.shape() {
        let n = b.shape()[0];
        // Each broadcast copy of `a` covers `a.data().len()` elements of `b`.
        (0..n).for_each(|i| {
            a.data()
                .iter()
                .zip(b.data_mut().iter_mut().skip(i * a.data().len()))
                .for_each(|(left, right)| *right += left);
        });
        Ok(())
    } else {
        Err(SmeltError::DimensionMismatch {
            expected: b.shape().to_vec(),
            got: a.shape().to_vec(),
        })
    }
}

/// Tensor elementwise multiplication: `b *= a`.
/// `a` is automatically broadcast over the leading dimension(s) of `b`.
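///
/// A minimal broadcast sketch (same assumptions as the [add] example):
/// ```
/// use smelte_rs::cpu::f32::{ops::mul, tensor::Tensor};
///
/// let scale = Tensor::new(vec![2.0, 0.5], vec![2]).unwrap();
/// let mut x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// mul(&scale, &mut x).unwrap();
/// assert_eq!(x.data(), &[2.0, 1.0, 6.0, 2.0]);
/// ```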
pub fn mul(a: &Tensor, b: &mut Tensor) -> Result<(), SmeltError> {
    if a.shape() == b.shape() {
        a.data()
            .iter()
            .zip(b.data_mut().iter_mut())
            .for_each(|(left, right)| *right *= left);
        Ok(())
    } else if &b.shape()[1..] == a.shape() {
        let n = b.shape()[0];
        // Each broadcast copy of `a` covers `a.data().len()` elements of `b`.
        (0..n).for_each(|i| {
            a.data()
                .iter()
                .zip(b.data_mut().iter_mut().skip(i * a.data().len()))
                .for_each(|(left, right)| *right *= left);
        });
        Ok(())
    } else {
        Err(SmeltError::DimensionMismatch {
            expected: b.shape().to_vec(),
            got: a.shape().to_vec(),
        })
    }
}

/// Basic operation for the layernorm, applied over the last dimension of `x`:
/// x = (x - x.mean()) / (x.var() + epsilon).sqrt()
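///
/// A minimal sketch (import path assumed), mirroring `simple_normalize` below:
/// ```
/// use smelte_rs::cpu::f32::{ops::normalize, tensor::Tensor};
///
/// let mut x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// normalize(&mut x, 1e-5).unwrap();
/// // Each row now has zero mean and roughly unit variance: [-1.0, 1.0, -1.0, 1.0].
/// ```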
pub fn normalize(x: &mut Tensor, epsilon: f32) -> Result<(), SmeltError> {
    let dim = x.shape().len();
    let size = x.shape()[dim - 1];
    x.data_mut().chunks_mut(size).for_each(|chunk| {
        let sum: f32 = chunk.iter().sum();
        let mean = sum / size as f32;
        chunk.iter_mut().for_each(|v| *v -= mean);
        let var: f32 = chunk.iter().map(|v| v * v).sum();
        let var = var / size as f32;
        let stddev: f32 = (var + epsilon).sqrt();
        chunk.iter_mut().for_each(|v| *v /= stddev);
    });
    Ok(())
}

#[inline]
fn g_softmax<const CAUSAL: bool>(
    x: &mut Tensor,
    past_sequence_length: usize,
) -> Result<(), SmeltError> {
    let dim = x.shape().len();

    let m = x.shape()[dim - 2];
    let n = x.shape()[dim - 1];

    x.data_mut()
        .chunks_mut(n)
        .enumerate()
        .for_each(|(i, chunk)| {
            let i = i % m;
            let mut current_max = f32::NEG_INFINITY;
            for (j, &v) in chunk.iter().enumerate() {
                if (!CAUSAL || i + past_sequence_length >= j) && v > current_max {
                    current_max = v;
                }
            }
            for v in chunk.iter_mut() {
                *v -= current_max;
                *v = (*v).exp();
            }
            let mut sum = 0.0;
            for (j, &v) in chunk.iter().enumerate() {
                if !CAUSAL || i + past_sequence_length >= j {
                    sum += v;
                }
            }
            for (j, v) in chunk.iter_mut().enumerate() {
                if !CAUSAL || i + past_sequence_length >= j {
                    *v /= sum;
                } else {
                    *v = 0.0;
                }
            }
        });
    Ok(())
}

/// Softmax on the last dimension for tensor `x`
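///
/// A minimal sketch (import path assumed), mirroring `simple_softmax` below:
/// ```
/// use smelte_rs::cpu::f32::{ops::softmax, tensor::Tensor};
///
/// let mut x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// softmax(&mut x).unwrap();
/// // Each row now sums to 1: roughly [0.2689, 0.7311] per row.
/// ```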
pub fn softmax(x: &mut Tensor) -> Result<(), SmeltError> {
    g_softmax::<false>(x, 0)
}

/// Causal softmax on the last dimension for tensor `x`. Row `i` of each
/// matrix only attends to columns `j` with `i + past_sequence_length >= j`,
/// i.e. `past_sequence_length` shifts the causal mask to account for
/// already-processed positions.
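///
/// A minimal sketch (import path assumed), mirroring `simple_causal_softmax`:
/// ```
/// use smelte_rs::cpu::f32::{ops::causal_softmax, tensor::Tensor};
///
/// let mut x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// causal_softmax(&mut x, 0).unwrap();
/// // Row 0 masks column 1: roughly [1.0, 0.0, 0.2689, 0.7311].
/// ```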
pub fn causal_softmax(x: &mut Tensor, past_sequence_length: usize) -> Result<(), SmeltError> {
    g_softmax::<true>(x, past_sequence_length)
}

/// Argmax of the last row of the rank-2 tensor `x`.
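///
/// A minimal sketch (import path assumed): only the last row is inspected.
/// ```
/// use smelte_rs::cpu::f32::{ops::special_argmax, tensor::Tensor};
///
/// let x = Tensor::new(vec![1.0, 2.0, 9.0, 4.0], vec![2, 2]).unwrap();
/// assert_eq!(special_argmax(&x).unwrap(), 0); // argmax of [9.0, 4.0]
/// ```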
pub fn special_argmax(x: &Tensor) -> Result<usize, SmeltError> {
    if x.shape().len() != 2 {
        return Err(SmeltError::InvalidRank { expected_rank: 2 });
    }
    let n = x.shape()[0];
    let m = x.shape()[1];

    let mut max = f32::NEG_INFINITY;
    let mut max_id = usize::MAX;
    for (i, &v) in x.data().iter().skip((n - 1) * m).enumerate() {
        if v > max {
            max = v;
            max_id = i;
        }
    }
    Ok(max_id)
}

/// utility function to use a faster but less precise tanh
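///
/// A minimal sketch (import path assumed): close to `f32::tanh` for small inputs.
/// ```
/// use smelte_rs::cpu::f32::ops::faster_tanh;
///
/// assert!((faster_tanh(0.5) - 0.5f32.tanh()).abs() < 1e-3);
/// ```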
pub fn faster_tanh(x: f32) -> f32 {
    let x2 = x * x;
    let x3 = x2 * x;
    let x5 = x3 * x2;

    let a = x + (0.16489087 * x3) + (0.00985468 * x5);

    a / (1.0 + (a * a)).sqrt()
}

/// utility function computing `tanh` inline via the identity
/// `tanh(x) = 1 - 2 / (1 + exp(2x))`
#[inline]
pub fn inline_tanh(x: f32) -> f32 {
    1.0 - (2.0 / (1.0 + (2.0 * x).exp()))
}

/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
/// but using [faster_tanh]
#[inline]
pub fn faster_gelu(v: f32) -> f32 {
    0.5 * v
        * (1.0 + faster_tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}

/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
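///
/// A minimal sketch (import path assumed):
/// ```
/// use smelte_rs::cpu::f32::ops::gelu;
///
/// assert_eq!(gelu(0.0), 0.0);
/// assert!((gelu(1.0) - 0.8412).abs() < 1e-3);
/// ```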
#[inline]
pub fn gelu(v: f32) -> f32 {
    0.5 * v
        * (1.0 + inline_tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}

/// Applies `func` to every item of the tensor
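///
/// A minimal sketch (import path assumed), applying [gelu] in place:
/// ```
/// use smelte_rs::cpu::f32::{ops::{apply, gelu}, tensor::Tensor};
///
/// let mut x = Tensor::new(vec![-1.0, 0.0, 1.0], vec![1, 3]).unwrap();
/// apply(&mut x, gelu);
/// ```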
pub fn apply<F: Fn(f32) -> f32 + Sync>(x: &mut Tensor, func: F) {
    x.data_mut().iter_mut().for_each(|v| *v = func(*v));
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::simplify;

    #[test]
    fn simple_matmul() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let a = Tensor::new(data, vec![2, 2]).unwrap();
        let data = [1.0, 2.0, 3.0, 4.0];
        let b = Tensor::borrowed(&data, vec![2, 2]).unwrap();
        let data = vec![0.0; 4];
        let mut c = Tensor::new(data, vec![2, 2]).unwrap();

        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);

        let data = vec![1.0, 2.0];
        let a = Tensor::new(data, vec![2, 1]).unwrap();
        let data = [3.0, 4.0];
        let b = Tensor::borrowed(&data, vec![1, 2]).unwrap();
        let data = vec![0.0; 4];
        let mut c = Tensor::new(data, vec![2, 2]).unwrap();
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[3.0, 4.0, 6.0, 8.0]);

        let data: Vec<_> = (0..6).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 3]).unwrap();
        let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![3, 2]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[16., 19., 52., 64.]);

        let data: Vec<_> = (0..12).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 2, 3]).unwrap();
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![2, 3, 2]).unwrap();
        let mut c: Tensor = Tensor::zeros(vec![2, 2, 2]);
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[16., 19., 52., 64., 214., 235., 304., 334.]);
    }

    #[test]
    fn simple_matmul_t() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        // A.T
        let b = Tensor::borrowed(&[1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);

        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);

        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = Tensor::borrowed(&[3.0, 4.0], vec![2, 1]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);
        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[3.0, 4.0, 6.0, 8.0]);

        let data: Vec<_> = (0..6).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 3]).unwrap();
        let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![2, 3]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);
        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[11., 20., 38., 74.]);

        let data: Vec<_> = (0..12).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 2, 3]).unwrap();
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![2, 2, 3]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2, 2]);
        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[11., 20., 38., 74., 191., 254., 272., 362.]);
    }

    #[test]
    fn simple_softmax() {
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        softmax(&mut a).unwrap();
        assert_eq!(
            simplify(a.data()),
            // Values obtained through python
            [0.2689, 0.7311, 0.2689, 0.7311]
        );
    }

    #[test]
    fn simple_causal_softmax() {
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        causal_softmax(&mut a, 0).unwrap();
        assert_eq!(
            simplify(a.data()),
            // Values obtained through python
            [1.0000, 0.0000, 0.2689, 0.7311]
        );

        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        causal_softmax(&mut a, 1).unwrap();
        assert_eq!(
            simplify(a.data()),
            // Values obtained through python
            [0.2689, 0.7311, 0.2689, 0.7311]
        );

        let data: Vec<_> = (0..12).map(|i| (i + 1) as f32).collect();
        let mut a = Tensor::new(data, vec![3, 2, 2]).unwrap();
        causal_softmax(&mut a, 0).unwrap();
        assert_eq!(
            simplify(a.data()),
            // Values obtained through python
            [
                1.0000, 0.0000, 0.2689, 0.7311, 1.0000, 0.0000, 0.2689, 0.7311, 1.0000, 0.0000,
                0.2689, 0.7311
            ]
        );

        let data: Vec<_> = (0..12).map(|i| (i + 1) as f32).collect();
        let mut a = Tensor::new(data, vec![2, 2, 3]).unwrap();
        causal_softmax(&mut a, 1).unwrap();
        assert_eq!(
            simplify(a.data()),
            // Values obtained through python
            [
                0.2689, 0.7311, 0.0, 0.09, 0.2447, 0.6652, 0.2689, 0.7311, 0.0, 0.09, 0.2447,
                0.6652
            ]
        );
    }

    #[test]
    fn simple_select() {
        let a = Tensor::borrowed(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let mut tensor = Tensor::zeros(vec![3, 2]);
        select(&[1, 0, 0], &a, &mut tensor).unwrap();
        assert_eq!(
            simplify(tensor.data()),
            // Values obtained through python
            [3.0, 4.0, 1.0, 2.0, 1.0, 2.0]
        );
    }

    #[test]
    fn simple_normalize() {
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let epsilon = 1e-5;
        normalize(&mut a, epsilon).unwrap();
        assert_eq!(
            simplify(a.data()),
            // Values obtained through python
            [-1.0, 1.0, -1.0, 1.0]
        );
    }
}