1use crate::{Device, LazyBuffer as LB, Result};
2
3pub mod lb {
4 use super::*;
5
6 pub fn add<D: Device>(dev: &D) -> Result<()> {
7 let lhs = LB::cst(40., (5, 2), dev)?;
8 let rhs = LB::cst(2., (5, 2), dev)?;
9 let lb = lhs.binary(crate::lang::BinaryOp::Add, rhs)?;
10 let data = lb.data_vec::<f32>()?;
11 assert_eq!(data, [42., 42., 42., 42., 42., 42., 42., 42., 42., 42.]);
12 Ok(())
13 }
14
15 pub fn mm<D: Device>(dev: &D) -> Result<()> {
16 let lhs = LB::cst(1., (4, 2), dev)?;
17 let rhs = LB::cst(2., (2, 3), dev)?;
18 let lb = lhs.matmul(rhs)?;
19 let data = lb.data_vec::<f32>()?;
20 assert_eq!(data, [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]);
21 Ok(())
22 }
23
24 pub fn cat<D: Device>(dev: &D) -> Result<()> {
25 let arg0 = LB::copy([1f32, 2.0, 3.0, 4.0].as_slice(), (2, 2), dev)?;
26 let arg1 = LB::copy([1.1f32, 2.1, 3.1, 4.1].as_slice(), (2, 2), dev)?;
27 let arg2 = LB::copy([1.2f32, 2.2, 3.2, 4.2].as_slice(), (2, 2), dev)?;
28 let lb = LB::<D>::cat(&[&arg0, &arg1, &arg2], 0)?;
29 let data = lb.data_vec::<f32>()?;
30 assert_eq!(data, [1.0, 2.0, 3.0, 4.0, 1.1, 2.1, 3.1, 4.1, 1.2, 2.2, 3.2, 4.2]);
31 let lb = LB::<D>::cat(&[&arg0, &arg1, &arg2], 1)?;
32 let data = lb.data_vec::<f32>()?;
33 assert_eq!(data, [1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.2, 4.2, 0.0, 0.0, 0.0, 0.0]);
34 Ok(())
35 }
36}