1use core::iter::Sum;
4use core::ops::{Deref, DerefMut, Mul};
5
6use stride::Stride;
7
/// A borrowed view of one row of an `M`x`N` matrix.
///
/// The row's elements are reached through a `Stride<T, M>`, i.e. with a step of
/// `M` elements through the backing storage.
/// NOTE(review): a step of `M` for rows (and 1 for `Column`) suggests the parent
/// matrix stores its data column-major; confirm against the matrix type.
///
/// `#[repr(transparent)]` is load-bearing: `Row::new` / `Row::new_mut` cast a
/// slice pointer directly to a `Row` pointer, which is only valid while `Row`
/// has exactly the layout of its single `Stride<T, M>` field.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct Row<const M: usize, const N: usize, T> {
    /// Strided view over the backing storage (step `M` between row elements).
    data: Stride<T, M>,
}
18
19impl<T, const M: usize, const N: usize> Row<M, N, T> {
20 pub(crate) fn new(data: &[T]) -> &Self {
21 unsafe { &*(data as *const [T] as *const Self) }
23 }
24
25 pub(crate) fn new_mut(data: &mut [T]) -> &mut Self {
26 unsafe { &mut *(data as *mut [T] as *mut Self) }
28 }
29}
30
31impl<T, const M: usize, const N: usize> Deref for Row<M, N, T> {
32 type Target = Stride<T, M>;
33
34 fn deref(&self) -> &Self::Target {
35 &self.data
36 }
37}
38
39impl<T, const M: usize, const N: usize> DerefMut for Row<M, N, T> {
40 fn deref_mut(&mut self) -> &mut Self::Target {
41 &mut self.data
42 }
43}
44
45impl<T, U, const M: usize, const N: usize, const S: usize> PartialEq<Stride<U, S>> for Row<M, N, T>
46where
47 T: PartialEq<U>,
48{
49 fn eq(&self, other: &Stride<U, S>) -> bool {
50 self.data.eq(other)
51 }
52}
53
54impl<T, U, const M: usize, const N: usize> PartialEq<[U]> for Row<M, N, T>
55where
56 T: PartialEq<U>,
57{
58 fn eq(&self, other: &[U]) -> bool {
59 self.data.eq(other)
60 }
61}
62
63impl<T, U, const M: usize, const N: usize, const P: usize> PartialEq<[U; P]> for Row<M, N, T>
64where
65 T: PartialEq<U>,
66{
67 fn eq(&self, other: &[U; P]) -> bool {
68 self.data.eq(other)
69 }
70}
71
/// A borrowed view of one column of an `M`x`N` matrix.
///
/// The column's elements are reached through a `Stride<T, 1>`, i.e. they are
/// consecutive elements of the backing storage.
/// NOTE(review): a step of 1 for columns (and `M` for `Row`) suggests the parent
/// matrix stores its data column-major; confirm against the matrix type.
///
/// `#[repr(transparent)]` is load-bearing: `Column::new` / `Column::new_mut`
/// cast a slice pointer directly to a `Column` pointer, which is only valid
/// while `Column` has exactly the layout of its single `Stride<T, 1>` field.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct Column<const M: usize, const N: usize, T> {
    /// Strided view over the backing storage (step 1 between column elements).
    data: Stride<T, 1>,
}
82
83impl<T, const M: usize, const N: usize> Column<M, N, T> {
84 pub(crate) fn new(data: &[T]) -> &Self {
85 unsafe { &*(data as *const [T] as *const Self) }
87 }
88
89 pub(crate) fn new_mut(data: &mut [T]) -> &mut Self {
90 unsafe { &mut *(data as *mut [T] as *mut Self) }
92 }
93}
94
95impl<T, const M: usize, const N: usize> Deref for Column<M, N, T> {
96 type Target = Stride<T, 1>;
97
98 fn deref(&self) -> &Self::Target {
99 &self.data
100 }
101}
102
103impl<T, const M: usize, const N: usize> DerefMut for Column<M, N, T> {
104 fn deref_mut(&mut self) -> &mut Self::Target {
105 &mut self.data
106 }
107}
108
109impl<T, U, const M: usize, const N: usize, const S: usize> PartialEq<Stride<U, S>>
110 for Column<M, N, T>
111where
112 T: PartialEq<U>,
113{
114 fn eq(&self, other: &Stride<U, S>) -> bool {
115 self.data.eq(other)
116 }
117}
118
119impl<T, U, const M: usize, const N: usize> PartialEq<[U]> for Column<M, N, T>
120where
121 T: PartialEq<U>,
122{
123 fn eq(&self, other: &[U]) -> bool {
124 self.data.eq(other)
125 }
126}
127
128impl<T, U, const M: usize, const N: usize, const P: usize> PartialEq<[U; P]> for Column<M, N, T>
129where
130 T: PartialEq<U>,
131{
132 fn eq(&self, other: &[U; P]) -> bool {
133 self.data.eq(other)
134 }
135}
136
137impl<T, const M: usize, const N: usize> Row<M, N, T> {
142 #[inline]
143 pub fn dot<const P: usize>(&self, other: &Column<N, P, T>) -> T
144 where
145 T: Copy + Mul<Output = T> + Sum,
146 {
147 (0..N).map(|i| self[i] * other[i]).sum()
148 }
149
150 #[inline]
152 pub fn dot_partial<const P: usize>(
153 &self,
154 other: &Column<N, P, T>,
155 range: core::ops::Range<usize>,
156 ) -> T
157 where
158 T: Copy + Mul<Output = T> + Sum,
159 {
160 (0..N)
161 .skip(range.start)
162 .take(range.count())
163 .map(|i| self[i] * other[i])
164 .sum()
165 }
166}
167
#[test]
fn iter() {
    use super::*;
    let m = matrix![
        1.0, 2.0, 3.0, 4.0;
        5.0, 6.0, 7.0, 8.0;
    ];

    // Interior window of row 1 (columns 1 and 2): the iterator must yield
    // exactly these two elements and then stop.
    let row_part: Vec<_> = m.row(1).get(1..3).unwrap().iter().copied().collect();
    assert_eq!(row_part, [6.0, 7.0]);

    // Full height of column 2.
    let col_part: Vec<_> = m.column(2).get(0..2).unwrap().iter().copied().collect();
    assert_eq!(col_part, [3.0, 7.0]);
}
185
#[test]
fn dot_partial() {
    use super::*;
    let m = matrix![
        1.0, 2.0, 3.0, 4.0;
        5.0, 6.0, 7.0, 8.0;
        9.0, 10.0, 12.0, 13.0;
        14.0, 15.0, 16.0, 17.0;
    ];
    let row = m.row(1); // [5, 6, 7, 8]
    let col = m.column(2); // [3, 7, 12, 16]

    // In-range window: 6*7 + 7*12.
    assert_eq!(row.dot_partial(col, 1..3), 126.0);

    // Full dot product (15 + 42 + 84 + 128), and the full window agrees with it.
    assert_eq!(row.dot(col), 269.0);
    assert_eq!(row.dot_partial(col, 0..4), row.dot(col));

    // An end past N is clamped to N: 7*12 + 8*16.
    assert_eq!(row.dot_partial(col, 2..10), 212.0);

    // Empty windows (zero-width, or starting at N) give the empty sum.
    assert_eq!(row.dot_partial(col, 2..2), 0.0);
    assert_eq!(row.dot_partial(col, 4..8), 0.0);
}