use crate::cpu::f32::tensor::Tensor;
use crate::SmeltError;

#[cfg(feature = "matrixmultiply")]
use matrixmultiply::sgemm;

#[cfg(feature = "cblas")]
use cblas_sys::{
    cblas_sgemm as sgemm, CblasColMajor as ColMajor, CblasNoTrans as NoTr,
    CblasRowMajor as RowMajor, CblasTrans as Tr,
};

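/// Embedding selection: copies the rows of `weights` (shaped
/// `[vocab_size, hidden_dim]`) at the given `ids` into `out`, which must be
/// pre-allocated as `[ids.len(), hidden_dim]`. Ids outside the vocabulary
/// yield `SmeltError::OutOfVocabulary`.
///
/// ```ignore
/// // Sketch, mirroring the test below:
/// let weights = Tensor::borrowed(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let mut out = Tensor::zeros(vec![3, 2]);
/// select(&[1, 0, 0], &weights, &mut out).unwrap();
/// assert_eq!(out.data(), &[3.0, 4.0, 1.0, 2.0, 1.0, 2.0]);
/// ```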
pub fn select(ids: &[usize], weights: &Tensor, out: &mut Tensor) -> Result<(), SmeltError> {
    let sequence_length = ids.len();
    let vocab_size = weights.shape()[0];
    let hidden_dim = weights.shape()[1];
    if out.shape() != [sequence_length, hidden_dim] {
        return Err(SmeltError::DimensionMismatch {
            expected: vec![sequence_length, hidden_dim],
            got: out.shape().to_vec(),
        });
    }
    for (i, id) in ids.iter().enumerate() {
        let id = *id;
        if id >= vocab_size {
            return Err(SmeltError::OutOfVocabulary { vocab_size, id });
        }
        let weight_offset = id * hidden_dim;
        let data_offset = i * hidden_dim;
        out.data_mut()[data_offset..data_offset + hidden_dim]
            .copy_from_slice(&weights.data()[weight_offset..weight_offset + hidden_dim]);
    }
    Ok(())
}

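/// Matrix multiplication `out = a @ b` over the last two dimensions; any
/// leading dimensions are treated as a batch. All shapes are checked and
/// `out` is fully overwritten.
///
/// ```ignore
/// // Sketch, values from the tests below:
/// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let b = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// let mut c = Tensor::zeros(vec![2, 2]);
/// matmul(&a, &b, &mut c).unwrap();
/// assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);
/// ```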
pub fn matmul<'a>(a: &Tensor<'a>, b: &Tensor<'a>, out: &mut Tensor<'a>) -> Result<(), SmeltError> {
    g_matmul::<false>(a, b, out)
}

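/// Matrix multiplication with `b` transposed, `out = a @ b^T`: for an
/// `[.., m, k]` input `a`, `b` must be `[.., n, k]` and `out` `[.., m, n]`.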
pub fn matmul_t<'a>(
    a: &Tensor<'a>,
    b: &Tensor<'a>,
    out: &mut Tensor<'a>,
) -> Result<(), SmeltError> {
    g_matmul::<true>(a, b, out)
}

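/// Shared batched GEMM driver behind [`matmul`] and [`matmul_t`]. Validates
/// ranks and shapes, zeroes `c`, then runs one sgemm per batch element through
/// whichever backend feature is enabled (`matrixmultiply` or `cblas`).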
#[inline]
fn g_matmul<'a, const TRANSPOSE: bool>(
    a: &Tensor<'a>,
    b: &Tensor<'a>,
    c: &mut Tensor<'a>,
) -> Result<(), SmeltError> {
    let dim = a.shape().len();

    if dim < 2 {
        return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
    }
    if b.shape().len() != dim {
        return Err(SmeltError::InvalidRank { expected_rank: dim });
    }
    if c.shape().len() != dim {
        return Err(SmeltError::InvalidRank { expected_rank: dim });
    }

    let m = a.shape()[dim - 2];
    let k = a.shape()[dim - 1];

    let mut expected_c = a.shape().to_vec();
    let mut expected_b = a.shape().to_vec();

    let (expected_b, n) = if TRANSPOSE {
        let n = b.shape()[dim - 2];
        expected_b[dim - 2] = n;
        expected_b[dim - 1] = k;
        (expected_b, n)
    } else {
        let n = b.shape()[dim - 1];
        expected_b[dim - 2] = k;
        expected_b[dim - 1] = n;
        (expected_b, n)
    };

    expected_c[dim - 2] = m;
    expected_c[dim - 1] = n;

    if expected_b != b.shape() {
        return Err(SmeltError::DimensionMismatch {
            expected: expected_b,
            got: b.shape().to_vec(),
        });
    }

    if expected_c != c.shape() {
        return Err(SmeltError::DimensionMismatch {
            expected: expected_c,
            got: c.shape().to_vec(),
        });
    }

    // Zero the output first: both backends are called with beta = 1.0 and
    // accumulate into `c`.
    c.data_mut().iter_mut().for_each(|v| *v = 0.0);

    let batching: usize = a.shape()[..dim - 2].iter().product();
    let a_skip: usize = m * k;
    let b_skip: usize = n * k;
    let c_skip: usize = m * n;

    // Row/column strides of each operand, in elements. Transposing `b` just
    // swaps its strides; the data is never moved.
    let ar = k as isize;
    let ac = 1;
    let (br, bc) = if TRANSPOSE {
        (1, b.shape()[dim - 1] as isize)
    } else {
        (b.shape()[dim - 1] as isize, 1)
    };
    let cr = n as isize;
    let cc = 1;

    (0..batching).for_each(|step| {
        let ap = &a.data()[step * a_skip..];
        let bp = &b.data()[step * b_skip..];
        let cp = &mut c.data_mut()[step * c_skip..];

        #[cfg(feature = "matrixmultiply")]
        unsafe {
            sgemm(
                m,
                k,
                n,
                1.0,
                ap.as_ptr(),
                ar,
                ac,
                bp.as_ptr(),
                br,
                bc,
                1.0,
                cp.as_mut_ptr(),
                cr,
                cc,
            );
        }

        #[cfg(feature = "cblas")]
        unsafe {
            let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
            // Recover the CBLAS layout, transpose flags, and leading
            // dimensions from the strides computed above.
            let (layout, a_tr, b_tr, lda, ldb, ldc) = if cr < cc {
                let (lda, a_tr) = if ar < ac { (m, NoTr) } else { (k, Tr) };
                let (ldb, b_tr) = if br < bc { (k, NoTr) } else { (n, Tr) };
                (ColMajor, a_tr, b_tr, lda, ldb, m)
            } else {
                let (lda, a_tr) = if ar < ac { (m, Tr) } else { (k, NoTr) };
                let (ldb, b_tr) = if br < bc { (k, Tr) } else { (n, NoTr) };
                (RowMajor, a_tr, b_tr, lda, ldb, n)
            };
            sgemm(
                layout,
                a_tr,
                b_tr,
                m,
                n,
                k,
                1.0,
                ap.as_ptr(),
                lda,
                bp.as_ptr(),
                ldb,
                1.0,
                cp.as_mut_ptr(),
                ldc,
            );
        }
    });
    Ok(())
}

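/// In-place addition `b += a`. `a` must either match `b`'s shape exactly, or
/// match `b.shape()[1..]`, in which case it is broadcast over the leading
/// dimension (the usual bias-add pattern).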
pub fn add(a: &Tensor, b: &mut Tensor) -> Result<(), SmeltError> {
    if a.shape() == b.shape() {
        a.data()
            .iter()
            .zip(b.data_mut().iter_mut())
            .for_each(|(left, right)| *right += left);
        Ok(())
    } else if &b.shape()[1..] == a.shape() {
        let n = b.shape()[0];
        // Broadcast: offset into `b` by whole copies of `a` (its full element
        // count, so ranks > 1 broadcast correctly too).
        let chunk = a.data().len();
        (0..n).for_each(|i| {
            a.data()
                .iter()
                .zip(b.data_mut().iter_mut().skip(i * chunk))
                .for_each(|(left, right)| *right += left);
        });
        Ok(())
    } else {
        Err(SmeltError::DimensionMismatch {
            expected: b.shape().to_vec(),
            got: a.shape().to_vec(),
        })
    }
}

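/// In-place element-wise multiplication `b *= a`, with the same exact-shape
/// or broadcast-over-leading-dimension rules as [`add`].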
pub fn mul(a: &Tensor, b: &mut Tensor) -> Result<(), SmeltError> {
    if a.shape() == b.shape() {
        a.data()
            .iter()
            .zip(b.data_mut().iter_mut())
            .for_each(|(left, right)| *right *= left);
        Ok(())
    } else if &b.shape()[1..] == a.shape() {
        let n = b.shape()[0];
        // Same broadcast offset as in `add`: whole copies of `a`.
        let chunk = a.data().len();
        (0..n).for_each(|i| {
            a.data()
                .iter()
                .zip(b.data_mut().iter_mut().skip(i * chunk))
                .for_each(|(left, right)| *right *= left);
        });
        Ok(())
    } else {
        Err(SmeltError::DimensionMismatch {
            expected: b.shape().to_vec(),
            got: a.shape().to_vec(),
        })
    }
}

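/// Layer-norm style normalization over the last dimension: each row is
/// shifted to zero mean and scaled to unit variance, with `epsilon` added to
/// the variance for numerical stability. Scale and bias are left to the
/// caller.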
pub fn normalize(x: &mut Tensor, epsilon: f32) -> Result<(), SmeltError> {
    let dim = x.shape().len();
    let size = x.shape()[dim - 1];
    x.data_mut().chunks_mut(size).for_each(|chunk| {
        let sum: f32 = chunk.iter().sum();
        let mean = sum / size as f32;
        chunk.iter_mut().for_each(|v| *v -= mean);
        let var: f32 = chunk.iter().map(|v| v * v).sum();
        let var = var / size as f32;
        let stddev: f32 = (var + epsilon).sqrt();
        chunk.iter_mut().for_each(|v| *v /= stddev);
    });
    Ok(())
}

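/// Shared softmax kernel over the last dimension. When `CAUSAL` is set, row
/// `i` within each `m`-row block only sees columns
/// `j <= i + past_sequence_length`; masked columns end up exactly 0.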
#[inline]
fn g_softmax<const CAUSAL: bool>(
    x: &mut Tensor,
    past_sequence_length: usize,
) -> Result<(), SmeltError> {
    let dim = x.shape().len();

    let m = x.shape()[dim - 2];
    let n = x.shape()[dim - 1];

    x.data_mut()
        .chunks_mut(n)
        .enumerate()
        .for_each(|(i, chunk)| {
            let i = i % m;
            // Take the max over the visible positions only, for numerical
            // stability of the exponentials.
            let mut current_max = f32::NEG_INFINITY;
            for (j, &v) in chunk.iter().enumerate() {
                if (!CAUSAL || i + past_sequence_length >= j) && v > current_max {
                    current_max = v;
                }
            }
            for v in chunk.iter_mut() {
                *v -= current_max;
                *v = (*v).exp();
            }
            // Normalize over the visible positions; masked ones are zeroed.
            let mut sum = 0.0;
            for (j, &v) in chunk.iter().enumerate() {
                if !CAUSAL || i + past_sequence_length >= j {
                    sum += v;
                }
            }
            for (j, v) in chunk.iter_mut().enumerate() {
                if !CAUSAL || i + past_sequence_length >= j {
                    *v /= sum;
                } else {
                    *v = 0.0;
                }
            }
        });
    Ok(())
}

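/// Softmax over the last dimension of `x`, in place.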
pub fn softmax(x: &mut Tensor) -> Result<(), SmeltError> {
    g_softmax::<false>(x, 0)
}

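/// Causally masked softmax over the last dimension, in place.
/// `past_sequence_length` accounts for cached positions: row `i` sees columns
/// `0..=i + past_sequence_length`.
///
/// ```ignore
/// // Sketch, values from the tests below:
/// let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// causal_softmax(&mut a, 0).unwrap();
/// // Row 0 sees only column 0; row 1 sees both columns:
/// // [1.0, 0.0, 0.2689, 0.7311]
/// ```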
pub fn causal_softmax(x: &mut Tensor, past_sequence_length: usize) -> Result<(), SmeltError> {
    g_softmax::<true>(x, past_sequence_length)
}

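/// Argmax over the last row of a rank-2 tensor, e.g. picking the next token
/// id from `[sequence_length, vocab_size]` logits.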
pub fn special_argmax(x: &Tensor) -> Result<usize, SmeltError> {
    if x.shape().len() != 2 {
        return Err(SmeltError::InvalidRank { expected_rank: 2 });
    }
    let n = x.shape()[0];
    let m = x.shape()[1];

    let mut max = f32::NEG_INFINITY;
    let mut max_id = usize::MAX;
    // Only scan the last row: skip the first (n - 1) rows of m elements.
    for (i, &v) in x.data().iter().skip((n - 1) * m).enumerate() {
        if v > max {
            max = v;
            max_id = i;
        }
    }
    Ok(max_id)
}

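/// Cheap `tanh` approximation: a small odd polynomial in `x` fed through
/// `a / sqrt(1 + a^2)`, trading some accuracy for speed by avoiding the
/// `exp`-based evaluation.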
pub fn faster_tanh(x: f32) -> f32 {
    let x2 = x * x;
    let x3 = x2 * x;
    let x5 = x3 * x2;

    let a = x + (0.16489087 * x3) + (0.00985468 * x5);

    a / (1.0 + (a * a)).sqrt()
}

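/// Exact `tanh` via the identity `tanh(x) = 1 - 2 / (1 + e^(2x))`, written
/// out so it can be inlined.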
#[inline]
pub fn inline_tanh(x: f32) -> f32 {
    1.0 - (2.0 / (1.0 + (2.0 * x).exp()))
}

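/// GELU with the usual tanh approximation,
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`,
/// using [`faster_tanh`] for the inner `tanh`.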
#[inline]
pub fn faster_gelu(v: f32) -> f32 {
    0.5 * v
        * (1.0 + faster_tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}

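/// Same tanh-approximated GELU as [`faster_gelu`], but with the exact
/// [`inline_tanh`] for the inner `tanh`.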
#[inline]
pub fn gelu(v: f32) -> f32 {
    0.5 * v
        * (1.0 + inline_tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}

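/// Applies `func` to every element of `x` in place, e.g.
/// `apply(&mut x, gelu)`.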
pub fn apply<F: Fn(f32) -> f32 + Sync>(x: &mut Tensor, func: F) {
    x.data_mut().iter_mut().for_each(|v| *v = func(*v));
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::simplify;

    #[test]
    fn simple_matmul() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let a = Tensor::new(data, vec![2, 2]).unwrap();
        let data = [1.0, 2.0, 3.0, 4.0];
        let b = Tensor::borrowed(&data, vec![2, 2]).unwrap();
        let data = vec![0.0; 4];
        let mut c = Tensor::new(data, vec![2, 2]).unwrap();

        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);
        // Running the same matmul twice must not accumulate: `c` is zeroed first.
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);

        let data = vec![1.0, 2.0];
        let a = Tensor::new(data, vec![2, 1]).unwrap();
        let data = [3.0, 4.0];
        let b = Tensor::borrowed(&data, vec![1, 2]).unwrap();
        let data = vec![0.0; 4];
        let mut c = Tensor::new(data, vec![2, 2]).unwrap();
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[3.0, 4.0, 6.0, 8.0]);

        let data: Vec<_> = (0..6).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 3]).unwrap();
        let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![3, 2]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[16., 19., 52., 64.]);

        let data: Vec<_> = (0..12).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 2, 3]).unwrap();
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![2, 3, 2]).unwrap();
        let mut c: Tensor = Tensor::zeros(vec![2, 2, 2]);
        matmul(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[16., 19., 52., 64., 214., 235., 304., 334.]);
    }

    #[test]
    fn simple_matmul_t() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        // `b` here is the transpose of the `b` used in `simple_matmul`.
        let b = Tensor::borrowed(&[1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);

        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[7.0, 10.0, 15.0, 22.0]);

        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = Tensor::borrowed(&[3.0, 4.0], vec![2, 1]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);
        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[3.0, 4.0, 6.0, 8.0]);

        let data: Vec<_> = (0..6).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 3]).unwrap();
        let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![2, 3]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2]);
        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[11., 20., 38., 74.]);

        let data: Vec<_> = (0..12).map(|i| i as f32).collect();
        let a = Tensor::new(data, vec![2, 2, 3]).unwrap();
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
        let b = Tensor::new(data, vec![2, 2, 3]).unwrap();
        let mut c = Tensor::zeros(vec![2, 2, 2]);
        matmul_t(&a, &b, &mut c).unwrap();
        assert_eq!(c.data(), &[11., 20., 38., 74., 191., 254., 272., 362.]);
    }

    #[test]
    fn simple_softmax() {
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        softmax(&mut a).unwrap();
        assert_eq!(simplify(a.data()), [0.2689, 0.7311, 0.2689, 0.7311]);
    }

    #[test]
    fn simple_causal_softmax() {
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        causal_softmax(&mut a, 0).unwrap();
        assert_eq!(simplify(a.data()), [1.0000, 0.0000, 0.2689, 0.7311]);

        // With a past of 1, the first row can already see both columns.
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        causal_softmax(&mut a, 1).unwrap();
        assert_eq!(simplify(a.data()), [0.2689, 0.7311, 0.2689, 0.7311]);

        let data: Vec<_> = (0..12).map(|i| (i + 1) as f32).collect();
        let mut a = Tensor::new(data, vec![3, 2, 2]).unwrap();
        causal_softmax(&mut a, 0).unwrap();
        assert_eq!(
            simplify(a.data()),
            [
                1.0000, 0.0000, 0.2689, 0.7311, 1.0000, 0.0000, 0.2689, 0.7311, 1.0000, 0.0000,
                0.2689, 0.7311
            ]
        );

        let data: Vec<_> = (0..12).map(|i| (i + 1) as f32).collect();
        let mut a = Tensor::new(data, vec![2, 2, 3]).unwrap();
        causal_softmax(&mut a, 1).unwrap();
        assert_eq!(
            simplify(a.data()),
            [
                0.2689, 0.7311, 0.0, 0.09, 0.2447, 0.6652, 0.2689, 0.7311, 0.0, 0.09, 0.2447,
                0.6652
            ]
        );
    }

    #[test]
    fn simple_select() {
        let a = Tensor::borrowed(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let mut tensor = Tensor::zeros(vec![3, 2]);
        select(&[1, 0, 0], &a, &mut tensor).unwrap();
        assert_eq!(simplify(tensor.data()), [3.0, 4.0, 1.0, 2.0, 1.0, 2.0]);
    }

    #[test]
    fn simple_normalize() {
        let mut a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let epsilon = 1e-5;
        normalize(&mut a, epsilon).unwrap();
        assert_eq!(simplify(a.data()), [-1.0, 1.0, -1.0, 1.0]);
    }
}