1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
//! Error types for matrix and vector operations.
//!
//! This module contains two kinds of errors.
//! [`InvalidPositions`] represents errors
//! when building a vector or matrix with invalid positions.
//! [`IncompatibleDimensions`] represents errors
//! when two objects have incompatible dimensions for a given operation,
//! such as addition or multiplication.

use is_sorted::IsSorted;
use itertools::Itertools;
use std::fmt;

/// An error to represent invalid positions in a vector or matrix.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum InvalidPositions {
    /// Some positions are not in sorted order.
    Unsorted,
    /// Some position is greater than or equal to the length being indexed.
    OutOfBound,
    /// The same position appears more than once.
    Duplicated,
}

impl fmt::Display for InvalidPositions {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match *self {
            Self::Unsorted => "some positions are not sorted",
            Self::OutOfBound => "some positions are out of bound",
            Self::Duplicated => "some positions are duplicated",
        };
        // `pad` honours width/alignment flags, exactly as formatting a `&str` does.
        f.pad(message)
    }
}

impl std::error::Error for InvalidPositions {}

pub(crate) fn validate_positions(
    length: usize,
    positions: &[usize],
) -> Result<(), InvalidPositions> {
    for position in positions.iter() {
        if *position >= length {
            return Result::Err(InvalidPositions::OutOfBound);
        }
    }
    if !IsSorted::is_sorted(&mut positions.iter()) {
        return Result::Err(InvalidPositions::Unsorted);
    }
    if positions.iter().unique().count() != positions.len() {
        return Result::Err(InvalidPositions::Duplicated);
    }
    Ok(())
}

/// An error to represent incompatible dimensions
/// in matrix and vector operations.
///
/// `DL` is the dimension type of the left operand and `DR` the dimension
/// type of the right operand.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct IncompatibleDimensions<DL, DR> {
    left_dimensions: DL,
    right_dimensions: DR,
}

impl<DL, DR> IncompatibleDimensions<DL, DR> {
    /// Creates an error from the dimensions of the left and right operands.
    pub fn new(left_dimensions: DL, right_dimensions: DR) -> Self {
        Self {
            left_dimensions,
            right_dimensions,
        }
    }

    /// Returns the dimensions of the left operand.
    ///
    /// Lets callers inspect the mismatch programmatically instead of
    /// parsing the `Display`/`Debug` output.
    pub fn left_dimensions(&self) -> &DL {
        &self.left_dimensions
    }

    /// Returns the dimensions of the right operand.
    pub fn right_dimensions(&self) -> &DR {
        &self.right_dimensions
    }
}

/// Dimension mismatch between two vectors; each side is a `usize`.
pub type VecVecIncompatibleDimensions = IncompatibleDimensions<usize, usize>;
/// Dimension mismatch between a matrix (left, a `(usize, usize)` pair) and a vector (right).
pub type MatVecIncompatibleDimensions = IncompatibleDimensions<(usize, usize), usize>;
/// Dimension mismatch between a vector (left) and a matrix (right, a `(usize, usize)` pair).
pub type VecMatIncompatibleDimensions = IncompatibleDimensions<usize, (usize, usize)>;
/// Dimension mismatch between two matrices; each side is a `(usize, usize)` pair.
pub type MatMatIncompatibleDimensions = IncompatibleDimensions<(usize, usize), (usize, usize)>;

macro_rules! impl_dim_error {
    ($dl:ty, $dr:ty) => {
        impl fmt::Display for IncompatibleDimensions<$dl, $dr> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                format!(
                    "incompatible dimensions {:?} and {:?}",
                    self.left_dimensions, self.right_dimensions
                )
                .fmt(f)
            }
        }

        impl std::error::Error for IncompatibleDimensions<$dl, $dr> {}
    };
}

impl_dim_error!(usize, usize); // Vec - Vec
impl_dim_error!((usize, usize), usize); // Mat - Vec
impl_dim_error!(usize, (usize, usize)); // Vec - Mat
impl_dim_error!((usize, usize), (usize, usize)); // Mat - Mat