use crate::{
dim::{Dim1, Dim2, Dim3, Dim4, DimTrait},
index::Index0D,
matrix::{AsMutPtr, AsPtr, IndexAxisDyn, IndexAxisMutDyn, MatrixBase, ViewMutMatix},
matrix_impl::{matrix_into_dim, Matrix},
memory_impl::{ViewMem, ViewMutMem},
num::Num,
operation::copy_from::CopyFrom,
};
fn add_assign_1d_1d_cpu<T: Num, D: DimTrait>(
a: &mut Matrix<ViewMutMem<T>, D>,
b: &Matrix<ViewMem<T>, D>,
) {
let num_elm = a.shape().num_elm();
let inner_slice_a = a.stride()[a.shape().len() - 1];
let inner_slice_b = b.stride()[b.shape().len() - 1];
let a_slice = a.as_mut_slice();
let b_slice = b.as_slice();
if inner_slice_a == 1 && inner_slice_b == 1 {
for i in 0..num_elm {
a_slice[i] += b_slice[i];
}
} else {
for i in 0..num_elm {
a_slice[i * inner_slice_a] += b_slice[i * inner_slice_b];
}
}
}
fn add_1d_scalar_cpu<T: Num, D: DimTrait>(a: &mut Matrix<ViewMutMem<T>, D>, b: T) {
let num_elm = a.shape().num_elm();
let inner_slice_a = a.stride()[a.shape().len() - 1];
let a_slice =
unsafe { std::slice::from_raw_parts_mut(a.as_mut_ptr(), num_elm * inner_slice_a) };
if inner_slice_a == 1 {
for item in a_slice.iter_mut() {
*item += b;
}
} else {
for i in 0..num_elm {
a_slice[i * inner_slice_a] += b;
}
}
}
fn add_assign_matrix_scalar<T, D>(to: Matrix<ViewMutMem<T>, D>, other: T)
where
T: Num,
D: DimTrait,
{
match to.shape().slice() {
[] => {
let mut to = to;
unsafe { to.as_mut_ptr().write(*to.as_ptr() + other) }
}
[_] => {
let mut to: Matrix<ViewMutMem<T>, Dim1> = matrix_into_dim(to);
add_1d_scalar_cpu(&mut to, other);
}
[a, _] => {
let mut to: Matrix<ViewMutMem<T>, Dim2> = matrix_into_dim(to);
for i in 0..*a {
let to = to.index_axis_mut_dyn(Index0D::new(i));
add_assign_matrix_scalar(to, other);
}
}
[a, _, _] => {
let mut to: Matrix<ViewMutMem<T>, Dim3> = matrix_into_dim(to);
for i in 0..*a {
let to = to.index_axis_mut_dyn(Index0D::new(i));
add_assign_matrix_scalar(to, other);
}
}
[a, _, _, _] => {
let mut to: Matrix<ViewMutMem<T>, Dim4> = matrix_into_dim(to);
for i in 0..*a {
let to = to.index_axis_mut_dyn(Index0D::new(i));
add_assign_matrix_scalar(to, other);
}
}
_ => panic!("not implemented: this is bug. please report this bug."),
}
}
fn add_matrix_scalar<T, D>(to: Matrix<ViewMutMem<T>, D>, lhs: Matrix<ViewMem<T>, D>, rhs: T)
where
T: Num,
D: DimTrait,
{
assert_eq!(to.shape(), lhs.shape());
let mut to = to.into_dyn_dim();
let lhs = lhs.into_dyn_dim();
to.copy_from(&lhs);
add_assign_matrix_scalar(to, rhs);
}
fn add_assign_matrix_matrix<T, D1, D2>(
source: Matrix<ViewMutMem<T>, D1>,
other: Matrix<ViewMem<T>, D2>,
) where
T: Num,
D1: DimTrait,
D2: DimTrait,
{
let mut source = source.into_dyn_dim();
let other = other.into_dyn_dim();
assert!(source.shape().is_include(&other.shape()));
if source.shape().is_empty() {
unsafe {
source
.as_mut_ptr()
.write(*source.as_ptr() + *other.as_ptr());
}
return;
}
if other.shape().is_empty() {
let s = unsafe { *other.as_ptr() };
add_assign_matrix_scalar(source, s);
return;
}
if source.shape().len() == 1 {
add_assign_1d_1d_cpu(&mut source, &other);
} else if source.shape() == other.shape() {
macro_rules! same_dim {
($dim:ty) => {{
let mut source: Matrix<ViewMutMem<T>, $dim> = matrix_into_dim(source);
let other: Matrix<ViewMem<T>, $dim> = matrix_into_dim(other);
for i in 0..source.shape()[0] {
let source = source.index_axis_mut_dyn(Index0D::new(i));
let other = other.index_axis_dyn(Index0D::new(i));
add_assign_matrix_matrix(source, other);
}
}};
}
match source.shape().len() {
2 => same_dim!(Dim2),
3 => same_dim!(Dim3),
4 => same_dim!(Dim4),
_ => panic!("not implemented: this is bug. please report this bug."),
}
} else {
macro_rules! diff_dim {
($dim1:ty, $dim2:ty) => {{
let mut source: Matrix<ViewMutMem<T>, $dim1> = matrix_into_dim(source);
let other: Matrix<ViewMem<T>, $dim2> = matrix_into_dim(other);
for i in 0..source.shape()[0] {
let source = source.index_axis_mut_dyn(Index0D::new(i));
let other = other.clone();
add_assign_matrix_matrix(source, other);
}
}};
}
match (source.shape().len(), other.shape().len()) {
(2, 1) => diff_dim!(Dim2, Dim1),
(3, 1) => diff_dim!(Dim3, Dim1),
(4, 1) => diff_dim!(Dim4, Dim1),
(3, 2) => diff_dim!(Dim3, Dim2),
(4, 2) => diff_dim!(Dim4, Dim2),
(4, 3) => diff_dim!(Dim4, Dim3),
_ => panic!("not implemented: this is bug. please report this bug."),
}
}
}
fn add_matrix_matrix<T, D1, D2, D3>(
to: Matrix<ViewMutMem<T>, D1>,
lhs: Matrix<ViewMem<T>, D2>,
rhs: Matrix<ViewMem<T>, D3>,
) where
T: Num,
D1: DimTrait,
D2: DimTrait,
D3: DimTrait,
{
if lhs.shape().len() < rhs.shape().len() {
add_matrix_matrix(to, rhs, lhs);
return;
}
assert_eq!(to.shape().slice(), lhs.shape().slice());
let mut to = to.into_dyn_dim();
let lhs = lhs.into_dyn_dim();
to.copy_from(&lhs);
add_assign_matrix_matrix(to, rhs);
}
pub trait MatrixAdd<Rhs, Lhs>: ViewMutMatix + MatrixBase {
fn add(self, lhs: Rhs, rhs: Lhs);
}
pub trait MatrixAddAssign<Rhs>: ViewMutMatix + MatrixBase {
fn add_assign(self, rhs: Rhs);
}
impl<'a, 'b, T, D> MatrixAdd<Matrix<ViewMem<'a, T>, D>, T> for Matrix<ViewMutMem<'b, T>, D>
where
T: Num,
D: DimTrait,
{
fn add(self, lhs: Matrix<ViewMem<T>, D>, rhs: T) {
add_matrix_scalar(self, lhs, rhs);
}
}
impl<'a, 'b, 'c, T, D1, D2> MatrixAdd<Matrix<ViewMem<'a, T>, D1>, Matrix<ViewMem<'b, T>, D2>>
for Matrix<ViewMutMem<'c, T>, D1>
where
T: Num,
D1: DimTrait,
D2: DimTrait,
{
fn add(self, lhs: Matrix<ViewMem<T>, D1>, rhs: Matrix<ViewMem<T>, D2>) {
add_matrix_matrix(self, lhs, rhs);
}
}
impl<'a, T, D> MatrixAddAssign<T> for Matrix<ViewMutMem<'a, T>, D>
where
T: Num,
D: DimTrait,
{
fn add_assign(self, rhs: T) {
add_assign_matrix_scalar(self, rhs);
}
}
impl<'a, 'b, T, D1, D2> MatrixAddAssign<Matrix<ViewMem<'a, T>, D1>>
for Matrix<ViewMutMem<'b, T>, D2>
where
T: Num,
D1: DimTrait,
D2: DimTrait,
{
fn add_assign(self, rhs: Matrix<ViewMem<T>, D1>) {
add_assign_matrix_matrix(self, rhs);
}
}
#[cfg(test)]
mod add {
use crate::{
matrix::{IndexItem, MatrixSlice, OwnedMatrix, ToViewMatrix, ToViewMutMatrix},
matrix_impl::{OwnedMatrix0D, OwnedMatrix1D, OwnedMatrix2D, OwnedMatrix3D, OwnedMatrixDyn},
operation::zeros::Zeros,
slice,
};
use super::*;
#[test]
fn add_dyn_dyn() {
let a = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0], [3]);
let b = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0], [3]);
let ans = OwnedMatrix1D::<f32>::zeros([3]);
let a = a.into_dyn_dim();
let b = b.into_dyn_dim();
let mut ans = ans.into_dyn_dim();
ans.to_view_mut().add(a.to_view(), b.to_view());
}
#[test]
fn add_1d_scalar() {
let a = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0], [3]);
let mut ans = OwnedMatrix1D::<f32>::zeros([3]);
let b = OwnedMatrix0D::from_vec(vec![2.0], []);
ans.to_view_mut().add(a.to_view(), b.to_view());
assert_eq!(ans.index_item([0]), 3.0);
assert_eq!(ans.index_item([1]), 4.0);
assert_eq!(ans.index_item([2]), 5.0);
}
#[test]
fn add_1d_scalar_default_stride() {
let a = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0], [3]);
let mut ans = OwnedMatrix1D::<f32>::zeros([3]);
ans.to_view_mut().add(a.to_view(), 1.0);
assert_eq!(ans.index_item([0]), 2.0);
assert_eq!(ans.index_item([1]), 3.0);
assert_eq!(ans.index_item([2]), 4.0);
}
#[test]
fn add_1d_scalar_sliced() {
let a = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [6]);
let mut ans = OwnedMatrix1D::<f32>::zeros([3]);
let sliced = a.slice(slice!(..;2));
ans.to_view_mut().add(sliced.to_view(), 1.0);
assert_eq!(ans.index_item([0]), 2.0);
assert_eq!(ans.index_item([1]), 4.0);
assert_eq!(ans.index_item([2]), 6.0);
}
#[test]
fn add_3d_scalar_sliced() {
let a = OwnedMatrix3D::from_vec(
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
],
[3, 3, 4],
);
let mut ans = OwnedMatrix3D::<f32>::zeros([3, 3, 2]);
let sliced = a.slice(slice!(.., .., ..;2));
ans.to_view_mut().add(sliced.to_view(), 1.0);
assert_eq!(ans.index_item([0, 0, 0]), 2.0);
assert_eq!(ans.index_item([0, 0, 1]), 4.0);
assert_eq!(ans.index_item([0, 1, 0]), 6.0);
assert_eq!(ans.index_item([0, 1, 1]), 8.0);
assert_eq!(ans.index_item([0, 2, 0]), 10.0);
assert_eq!(ans.index_item([0, 2, 1]), 12.0);
assert_eq!(ans.index_item([1, 0, 0]), 14.0);
assert_eq!(ans.index_item([1, 0, 1]), 16.0);
assert_eq!(ans.index_item([1, 1, 0]), 18.0);
assert_eq!(ans.index_item([1, 1, 1]), 20.0);
assert_eq!(ans.index_item([1, 2, 0]), 22.0);
assert_eq!(ans.index_item([1, 2, 1]), 24.0);
assert_eq!(ans.index_item([2, 0, 0]), 26.0);
assert_eq!(ans.index_item([2, 0, 1]), 28.0);
assert_eq!(ans.index_item([2, 1, 0]), 30.0);
assert_eq!(ans.index_item([2, 1, 1]), 32.0);
assert_eq!(ans.index_item([2, 2, 0]), 34.0);
assert_eq!(ans.index_item([2, 2, 1]), 36.0);
}
#[test]
fn add_1d_1d_default_stride() {
let a = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0], [3]);
let b = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0], [3]);
let mut ans = OwnedMatrix1D::<f32>::zeros([3]);
ans.to_view_mut().add(a.to_view(), b.to_view());
assert_eq!(ans.index_item([0]), 2.0);
assert_eq!(ans.index_item([1]), 4.0);
assert_eq!(ans.index_item([2]), 6.0);
}
#[test]
fn add_1d_1d_sliced() {
let a = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [6]);
let b = OwnedMatrix1D::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [6]);
let mut ans = OwnedMatrix1D::<f32>::zeros([3]);
let sliced_a = a.slice(slice!(..;2));
let sliced_b = b.slice(slice!(1..;2));
ans.to_view_mut()
.add(sliced_a.to_view(), sliced_b.to_view());
assert_eq!(ans.index_item([0]), 3.0);
assert_eq!(ans.index_item([1]), 7.0);
assert_eq!(ans.index_item([2]), 11.0);
}
#[test]
fn add_2d_1d_default() {
let a = OwnedMatrix2D::from_vec(
vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
],
[4, 4],
);
let b = OwnedMatrix1D::from_vec(vec![1., 2., 3., 4., 5., 6., 7., 8.], [8]);
let mut ans = OwnedMatrix2D::<f32>::zeros([2, 2]);
let sliced_a = a.slice(slice!(..2, ..2));
let sliced_b = b.slice(slice!(..2));
ans.to_view_mut()
.add(sliced_a.to_view(), sliced_b.to_view());
assert_eq!(ans.index_item([0, 0]), 2.0);
assert_eq!(ans.index_item([0, 1]), 4.0);
assert_eq!(ans.index_item([1, 0]), 6.0);
assert_eq!(ans.index_item([1, 1]), 8.0);
}
#[test]
fn add_3d_1d_sliced() {
let mut v = Vec::new();
let num_elm = 4 * 4 * 4;
for i in 0..num_elm {
v.push(i as f32);
}
let a = OwnedMatrix3D::from_vec(v, [4, 4, 4]);
let b = OwnedMatrix1D::from_vec(vec![1., 2., 3., 4.], [4]);
let mut ans = OwnedMatrix3D::<f32>::zeros([2, 2, 2]);
let sliced_a = a.slice(slice!(..2, 1..;2, ..2));
let sliced_b = b.slice(slice!(..2));
ans.to_view_mut()
.add(sliced_a.to_view(), sliced_b.to_view());
assert_eq!(ans.index_item([0, 0, 0]), 5.);
assert_eq!(ans.index_item([0, 0, 1]), 7.);
assert_eq!(ans.index_item([0, 1, 0]), 13.);
assert_eq!(ans.index_item([0, 1, 1]), 15.);
assert_eq!(ans.index_item([1, 0, 0]), 21.);
assert_eq!(ans.index_item([1, 0, 1]), 23.);
assert_eq!(ans.index_item([1, 1, 0]), 29.);
assert_eq!(ans.index_item([1, 1, 1]), 31.);
}
#[test]
fn add_2d_2d_default() {
let a = OwnedMatrix2D::from_vec(
vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
],
[4, 4],
);
let b = OwnedMatrix2D::from_vec(
vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
],
[4, 4],
);
let mut ans = OwnedMatrix2D::<f32>::zeros([4, 4]);
ans.to_view_mut().add(a.to_view(), b.to_view());
assert_eq!(ans.index_item([0, 0]), 2.0);
assert_eq!(ans.index_item([0, 1]), 4.0);
assert_eq!(ans.index_item([0, 2]), 6.0);
assert_eq!(ans.index_item([0, 3]), 8.0);
assert_eq!(ans.index_item([1, 0]), 10.0);
assert_eq!(ans.index_item([1, 1]), 12.0);
assert_eq!(ans.index_item([1, 2]), 14.0);
assert_eq!(ans.index_item([1, 3]), 16.0);
assert_eq!(ans.index_item([2, 0]), 18.0);
assert_eq!(ans.index_item([2, 1]), 20.0);
assert_eq!(ans.index_item([2, 2]), 22.0);
assert_eq!(ans.index_item([2, 3]), 24.0);
assert_eq!(ans.index_item([3, 0]), 26.0);
assert_eq!(ans.index_item([3, 1]), 28.0);
assert_eq!(ans.index_item([3, 2]), 30.0);
assert_eq!(ans.index_item([3, 3]), 32.0);
}
#[test]
fn add_2d_0d() {
let a = OwnedMatrix2D::from_vec(
vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
],
[4, 4],
);
let b = OwnedMatrix0D::from_vec(vec![1.], []);
let mut ans = OwnedMatrix2D::<f32>::zeros([4, 4]);
ans.to_view_mut().add(a.to_view(), b.to_view());
assert_eq!(ans.index_item([0, 0]), 2.0);
assert_eq!(ans.index_item([0, 1]), 3.0);
assert_eq!(ans.index_item([0, 2]), 4.0);
assert_eq!(ans.index_item([0, 3]), 5.0);
assert_eq!(ans.index_item([1, 0]), 6.0);
assert_eq!(ans.index_item([1, 1]), 7.0);
assert_eq!(ans.index_item([1, 2]), 8.0);
assert_eq!(ans.index_item([1, 3]), 9.0);
assert_eq!(ans.index_item([2, 0]), 10.0);
assert_eq!(ans.index_item([2, 1]), 11.0);
assert_eq!(ans.index_item([2, 2]), 12.0);
assert_eq!(ans.index_item([2, 3]), 13.0);
assert_eq!(ans.index_item([3, 0]), 14.0);
assert_eq!(ans.index_item([3, 1]), 15.0);
assert_eq!(ans.index_item([3, 2]), 16.0);
assert_eq!(ans.index_item([3, 3]), 17.0);
}
#[test]
fn add_2d_0d_dyn() {
let a = OwnedMatrixDyn::from_vec(
vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
],
[4, 4],
);
let b = OwnedMatrixDyn::from_vec(vec![1.], []);
let mut ans = OwnedMatrixDyn::<f32>::zeros([4, 4]);
ans.to_view_mut().add(a.to_view(), b.to_view());
assert_eq!(ans.index_item([0, 0]), 2.0);
assert_eq!(ans.index_item([0, 1]), 3.0);
assert_eq!(ans.index_item([0, 2]), 4.0);
assert_eq!(ans.index_item([0, 3]), 5.0);
assert_eq!(ans.index_item([1, 0]), 6.0);
assert_eq!(ans.index_item([1, 1]), 7.0);
assert_eq!(ans.index_item([1, 2]), 8.0);
assert_eq!(ans.index_item([1, 3]), 9.0);
assert_eq!(ans.index_item([2, 0]), 10.0);
assert_eq!(ans.index_item([2, 1]), 11.0);
assert_eq!(ans.index_item([2, 2]), 12.0);
assert_eq!(ans.index_item([2, 3]), 13.0);
assert_eq!(ans.index_item([3, 0]), 14.0);
assert_eq!(ans.index_item([3, 1]), 15.0);
assert_eq!(ans.index_item([3, 2]), 16.0);
assert_eq!(ans.index_item([3, 3]), 17.0);
}
#[test]
fn add_4d_2d_dyn() {
let zeros_4d = OwnedMatrixDyn::<f32>::zeros([2, 2, 2, 2]);
let ones_2d = OwnedMatrixDyn::from_vec(vec![1., 1., 1., 1.], [2, 2]);
let mut ans = OwnedMatrixDyn::<f32>::zeros([2, 2, 2, 2]);
ans.to_view_mut().add(zeros_4d.to_view(), ones_2d.to_view());
}
}