unmtx_gpu/lib.rs
1//
2// Copyright (c) 2025 Łukasz Szpakowski
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at https://mozilla.org/MPL/2.0/.
7//
//! Micro neural matrix library for GPU is a small library that operates on matrices.
9//!
//! This library uses the GPU through the following computing platforms:
11//!
12//! - OpenCL
13//! - CUDA
14//!
//! If this library uses CUDA, this library can use the cuBLAS library for multiplication of
//! matrices.
17//!
//! A frontend-backend architecture is used by this library. The frontend of this library can use
//! one of two backends (OpenCL or CUDA). These backends allow the use of GPUs through the
//! computing platforms. The frontend and the backend can have many instances. This library
//! provides a high-level interface to matrix operations through the frontend and the methods of
//! the [`Matrix`] structure.
23//!
24//! # Examples
25//!
26//! ```
27//! # use unmtx_gpu::*;
28//! let a = matrix![
29//! [1.0, 2.0],
30//! [3.0, 4.0]
31//! ];
32//! let x = matrix![
33//! [5.0],
34//! [6.0]
35//! ];
36//! let b = matrix![
37//! [7.0],
38//! [8.0]
39//! ];
40//! let c = a * x + b;
41//! assert_eq!(vec![1.0 * 5.0 + 2.0 * 6.0 + 7.0, 3.0 * 5.0 + 4.0 * 6.0 + 8.0], c.elems());
42//! ```
43use std::ops::Add;
44use std::ops::AddAssign;
45use std::ops::Sub;
46use std::ops::SubAssign;
47use std::ops::Mul;
48use std::ops::MulAssign;
49use std::ops::Div;
50use std::ops::DivAssign;
51use std::error;
52use std::fmt;
53use std::result;
54use std::sync::Arc;
55use std::sync::Mutex;
56use std::sync::MutexGuard;
57
58#[cfg(feature = "opencl")]
59pub mod opencl;
60#[cfg(feature = "cuda")]
61pub mod cuda;
62
/// A backend trait.
///
/// The backend provides a low-level interface to a computing platform (OpenCL or CUDA) for basic
/// operations and functions on matrices. The backend methods operate on backend arrays which
/// refer to areas of the device memory. The backend is a low-level layer between a frontend and
/// a computing platform.
///
/// NOTE(review): the `n`, `m`, and `l` parameters of the methods presumably are matrix
/// dimensions; their exact meaning isn't visible here — confirm against a backend implementation.
pub trait Backend
{
    /// Returns the backend name.
    fn name(&self) -> &'static str;

    /// Returns `true` if the backend uses cuBLAS, otherwise `false`.
    fn has_cublas(&self) -> bool;

    /// Allocates a backend array.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the allocated elements are uninitialized (compare with
    /// [`alloc_and_store_zeros`](Self::alloc_and_store_zeros)) and must be written before they
    /// are read — confirm against the backend implementations.
    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>;

    /// Allocates a backend array and stores zeros in the backend array.
    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>;

    /// Allocates a backend array and stores the elements in the backend array.
    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>;

    /// Loads elements from the backend array.
    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>;

    /// Stores elements in the backend array.
    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>;

    /// Copies the `a` backend array to the `b` backend array.
    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>;

    /// Transposes the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the transposed `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the transposed `b` matrix and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` matrix from the `a` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` matrix from the transposed `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `b` matrix from the `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `b` matrix from the transposed `a` matrix and then the result is
    /// in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the `b` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the transposed `b` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the transposed `b` matrix and then the result is in
    /// the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the `a` matrix elements by the `b` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix elements by the `b` matrix elements and then the
    /// result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix elements by the transposed `b` matrix elements and then the
    /// result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mrow></math>).
    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix elements by the transposed `b` matrix elements and
    /// then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mrow></math>).
    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix elements by the `b` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix elements by the `b` matrix elements and then the result
    /// is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix elements by the transposed `b` matrix elements and then the result
    /// is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix elements by the transposed `b` matrix elements and then
    /// the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi>b</mi></mrow></math>).
    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><mi>b</mi></mrow></math>).
    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` scalar from the `a` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi>b</mi></mrow></math>).
    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` scalar from the transposed `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><mi>b</mi></mrow></math>).
    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `a` matrix from the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `a` matrix from the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi>b</mi></mrow></math>).
    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><mi>b</mi></mrow></math>).
    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix by the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><mi mathvariant="bold">A</mi><mi>b</mi></mfrac></mrow></math>).
    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix by the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mi>b</mi></mfrac></mrow></math>).
    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `b` scalar by the `a` matrix elements and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `b` scalar by the transposed `a` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates sigmoid function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates sigmoid function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates hyperbolic tangent function for the `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates hyperbolic tangent function for the transposed `a` matrix and then the result is
    /// in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates softmax function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates softmax function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Repeats the `a` vector as column
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>i</mi></msub></mrow></math>).
    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Repeats the `a` vector as row
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>j</mi></msub></mrow></math>).
    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;
}
283
/// An error enumeration.
///
/// Every variant has a human-readable message provided by the [`fmt::Display`] implementation of
/// this type. The OpenCL variant is only available with the `opencl` feature and the CUDA/cuBLAS
/// variants are only available with the `cuda` feature.
#[derive(Debug)]
pub enum Error
{
    /// Can't initialize a default backend.
    DefaultBackendInitialization,
    /// Mismatched sizes of matrices for a matrix operation. The fields are displayed as two
    /// sizes (`n1 x m1`, `n2 x m2`).
    OpSize(usize, usize, usize, usize),
    /// Mismatched sizes of matrices for a matrix multiplication. The fields are displayed as
    /// three sizes (`n1 x m1`, `n2 x m2`, `n3 x m3`).
    MulSize(usize, usize, usize, usize, usize, usize),
    /// Mismatched sizes of matrices for a matrix transposition. The fields are displayed as two
    /// sizes (`n1 x m1`, `n2 x m2`).
    TransposeSize(usize, usize, usize, usize),
    /// An argument matrix is transposed.
    ArgTransposition,
    /// A result matrix is transposed.
    ResTransposition,
    /// The number of matrix elements isn't equal to the number of elements.
    MatrixElemCount(usize, usize),
    /// A matrix isn't a vector.
    IsNotVector,
    /// A mutex can't be locked.
    Mutex,
    /// An OpenCL error.
    #[cfg(feature = "opencl")]
    OpenCl(opencl::ClError),
    /// A CUDA error.
    #[cfg(feature = "cuda")]
    Cuda(cuda::DriverError),
    /// A cuBLAS error.
    #[cfg(feature = "cuda")]
    Cublas(cuda::CublasError),
    /// No cuBLAS.
    #[cfg(feature = "cuda")]
    NoCublas,
    /// A compilation error. The field is the message that is displayed as is.
    Compilation(String),
    /// No platform.
    NoPlatform,
    /// No device.
    NoDevice,
    /// No kernel. The field is the kernel name.
    NoKernel(String),
    /// A type of device information is invalid.
    InvalidDeviceInfoType,
    /// The number of backend array elements isn't equal to the number of elements.
    BackendArrayElemCount(usize, usize),
    /// Two numbers of elements of backend arrays aren't equal.
    TwoBackendArrayElemCounts(usize, usize),
    /// A backend array is invalid.
    InvalidBackendArray,
}
335
// Marks `Error` as a standard error type; no method is overridden, so the default methods of
// the `std::error::Error` trait are used.
impl error::Error for Error
{}
338
339impl fmt::Display for Error
340{
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
342 {
343 match self {
344 Error::DefaultBackendInitialization => write!(f, "can't initialize default backend"),
345 Error::OpSize(n1, m1, n2, m2) => write!(f, "mismatched sizes of matrices ({}x{}, {}x{})", n1, m1, n2, m2),
346 Error::MulSize(n1, m1, n2, m2, n3, m3) => write!(f, "mismatched sizes of matrices for multiplication ({}x{}, {}x{}, {}x{})", n1, m1, n2, m2, n3, m3),
347 Error::TransposeSize(n1, m1, n2, m2) => write!(f, "mismatched sizes of matrices for transposition ({}x{}, {}x{})", n1, m1, n2, m2),
348 Error::ArgTransposition => write!(f, "argument matrix is transposed"),
349 Error::ResTransposition => write!(f, "result matrix is transposed"),
350 Error::MatrixElemCount(n1, n2) => write!(f, "number of matrix elements isn't equal to number of elements ({}, {})", n1, n2),
351 Error::IsNotVector => write!(f, "matrix isn't vector"),
352 Error::Mutex => write!(f, "can't lock mutex"),
353 #[cfg(feature = "opencl")]
354 Error::OpenCl(err) => write!(f, "OpenCL error: {}", err),
355 #[cfg(feature = "cuda")]
356 Error::Cuda(err) => write!(f, "CUDA error: {}", err),
357 #[cfg(feature = "cuda")]
358 Error::Cublas(err) => write!(f, "cuBLAS error: {}", err),
359 #[cfg(feature = "cuda")]
360 Error::NoCublas => write!(f, "no cuBLAS"),
361 Error::Compilation(msg) => write!(f, "{}", msg),
362 Error::NoPlatform => write!(f, "no platform"),
363 Error::NoDevice => write!(f, "no device"),
364 Error::NoKernel(name) => write!(f, "no kernel {}", name),
365 Error::InvalidDeviceInfoType => write!(f, "invalid device info type"),
366 Error::BackendArrayElemCount(n1, n2) => write!(f, "number of backend array elements isn't equal to number of elements ({}, {})", n1, n2),
367 Error::TwoBackendArrayElemCounts(n1, n2) => write!(f, "two numbers of elements of backend arrays aren't equal ({}, {})", n1, n2),
368 Error::InvalidBackendArray => write!(f, "invalid backend array"),
369 }
370 }
371}
372
/// A result type.
///
/// This type is the standard result type with the error type fixed to [`Error`].
pub type Result<T> = result::Result<T, Error>;
375
/// An enumeration of backend arrays.
///
/// This enumeration contains the reference to the area of the device memory for a computing
/// platform (OpenCL or CUDA). Each variant is only available if the corresponding feature
/// (`opencl` or `cuda`) is enabled.
#[derive(Debug)]
pub enum BackendArray
{
    /// A backend array for OpenCL.
    #[cfg(feature = "opencl")]
    OpenCl(opencl::ClBackendArray),
    /// A backend array for CUDA.
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaBackendArray),
}
390
// The default backend that is shared by the functions below; `None` means that the default
// backend is unset. The value is protected by the mutex.
//
// NOTE(review): `static mut` presumably is used because `dyn Backend` has no `Send + Sync`
// bounds, which a plain `static` with this mutex would require — confirm. Taking references to
// a `static mut` is discouraged by newer compilers (the `static_mut_refs` lint).
static mut DEFAULT_BACKEND: Mutex<Option<Arc<dyn Backend>>> = Mutex::new(None);
392
393fn mutex_lock<T>(mutex: &Mutex<T>) -> Result<MutexGuard<'_, T>>
394{
395 match mutex.lock() {
396 Ok(guard) => Ok(guard),
397 Err(_) => return Err(Error::Mutex),
398 }
399}
400
401/// Returns a default backend.
402pub fn get_default_backend() -> Result<Option<Arc<dyn Backend>>>
403{
404 unsafe {
405 let default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
406 Ok(default_backend_g.clone())
407 }
408}
409
410/// Sets a default backend.
411pub fn set_default_backend(backend: Arc<dyn Backend>) -> Result<()>
412{
413 unsafe {
414 let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
415 *default_backend_g = Some(backend);
416 }
417 Ok(())
418}
419
420/// Unsets a default backend.
421pub fn unset_default_backend() -> Result<()>
422{
423 unsafe {
424 let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
425 *default_backend_g = None;
426 }
427 Ok(())
428}
429
430/// Sets a default backend if the default backend is uninitialized and returns the default backend.
431///
432/// This method takes a closure that returns the backend and then the backend is set as the default
433/// backend if the default backend is uninitialized. The closure is only called if the backend is
434/// to be set.
435pub fn set_default_backend_for_uninitialized<F>(f: F) -> Result<Arc<dyn Backend>>
436 where F: FnOnce() -> Result<Arc<dyn Backend>>
437{
438 unsafe {
439 let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
440 match &*default_backend_g {
441 Some(default_backend) => Ok(default_backend.clone()),
442 None => {
443 let backend = f()?;
444 *default_backend_g = Some(backend.clone());
445 Ok(backend)
446 },
447 }
448 }
449}
450
/// Initializes a default backend if the backend is uninitialized and returns the default backend.
///
/// The backend that is created depends on the enabled features:
///
/// - with the `opencl` feature, an OpenCL backend is created,
/// - otherwise, with the `cuda` feature, a CUDA backend is created,
/// - without both features, [`Error::DefaultBackendInitialization`] is returned.
///
/// The backend is only created if the default backend is unset (see
/// [`set_default_backend_for_uninitialized`]).
pub fn initialize_default_backend_for_uninitialized() -> Result<Arc<dyn Backend>>
{
    // OpenCL has priority if both features are enabled.
    #[cfg(feature = "opencl")]
    let res = set_default_backend_for_uninitialized(|| Ok(Arc::new(opencl::ClBackend::new()?)));
    #[cfg(all(not(feature = "opencl"), feature = "cuda"))]
    let res = set_default_backend_for_uninitialized(|| Ok(Arc::new(cuda::CudaBackend::new()?)));
    // No computing platform is compiled in, so there is no backend to create.
    #[cfg(all(not(feature = "opencl"), not(feature = "cuda")))]
    let res: Result<Arc<dyn Backend>> = Err(Error::DefaultBackendInitialization);
    res
}
462
/// Finalizes a default backend.
///
/// This function only calls [`unset_default_backend`], so the default backend is unset and the
/// error of that function is returned.
pub fn finalize_default_backend() -> Result<()>
{ unset_default_backend() }
466
/// Creates a matrix from the arguments.
///
/// Each argument is a matrix row that is written as a bracketed list of comma-separated element
/// expressions. Trailing commas are accepted after the elements and after the rows. The macro
/// expands to a call of [`Matrix::new_with_elem_vecs`].
///
/// # Examples
///
/// ```
/// # use unmtx_gpu::*;
/// let a = matrix![
///     [1.0, 2.0, 3.0],
///     [4.0, 5.0, 6.0]
/// ];
/// assert_eq!(2, a.row_count());
/// assert_eq!(3, a.col_count());
/// assert_eq!(false, a.is_transposed());
/// assert_eq!(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], a.elems());
/// ```
#[macro_export]
macro_rules! matrix {
    // The rows are collected into a vector of vectors that is passed as a slice.
    ($([$($elem:expr),* $(,)*]),* $(,)*) => {
        $crate::Matrix::new_with_elem_vecs(vec![$(vec![$($elem),*]),*].as_slice())
    };
}
488
/// A matrix structure.
///
/// The matrix consists of the number of rows, the number of columns, the transpose flag, and the
/// shared reference to the backend array. A clone of the matrix shares the backend array with
/// the original matrix because the backend array is held by [`Arc`].
#[derive(Clone, Debug)]
pub struct Matrix
{
    // The number of matrix rows.
    row_count: usize,
    // The number of matrix columns.
    col_count: usize,
    // The transpose flag.
    is_transposed: bool,
    // The shared backend array of the matrix.
    array: Arc<BackendArray>,
}
498
499impl Matrix
500{
501 /// Creates a matrix with the number of rows and the number of columns.
502 pub fn new(row_count: usize, col_count: usize) -> Self
503 {
504 let frontend = Frontend::new().unwrap();
505 frontend.create_matrix_and_set_zeros(row_count, col_count).unwrap()
506 }
507
508 /// Creates a matrix with the number of rows, the number of columns, and the elements.
509 pub fn new_with_elems(row_count: usize, col_count: usize, elems: &[f32]) -> Self
510 {
511 let frontend = Frontend::new().unwrap();
512 frontend.create_matrix_and_set_elems(row_count, col_count, elems).unwrap()
513 }
514
515 /// Creates a matrix with the vector of rows.
516 pub fn new_with_elem_vecs(elem_vecs: &[Vec<f32>]) -> Self
517 {
518 let frontend = Frontend::new().unwrap();
519 let col_count = match elem_vecs.first() {
520 Some(elems) => elems.len(),
521 None => 0,
522 };
523 for row in elem_vecs {
524 assert_eq!(col_count, row.len());
525 }
526 let row_count = elem_vecs.len();
527 let elems: Vec<f32> = elem_vecs.iter().flatten().map(|e| *e).collect();
528 frontend.create_matrix_and_set_elems(row_count, col_count, elems.as_slice()).unwrap()
529 }
530
    /// Returns the number of matrix rows.
    ///
    /// This number is exchanged with the number of matrix columns by
    /// [`transpose`](Self::transpose).
    pub fn row_count(&self) -> usize
    { self.row_count }
534
    /// Returns the number of matrix columns.
    ///
    /// This number is exchanged with the number of matrix rows by
    /// [`transpose`](Self::transpose).
    pub fn col_count(&self) -> usize
    { self.col_count }
538
    /// Returns `true` if the matrix is transposed, otherwise `false`.
    ///
    /// This method actually returns the transpose flag of the matrix that is changed by
    /// [`transpose`](Self::transpose).
    pub fn is_transposed(&self) -> bool
    { self.is_transposed }
545
546 /// Returns the matrix elements.
547 pub fn elems(&self) -> Vec<f32>
548 {
549 let frontend = Frontend::new().unwrap();
550 frontend.elems_and_transpose_flag(self).unwrap().0
551 }
552
553 /// Creates a matrix copy.
554 ///
555 /// This method indeed copies the matrix array to a new matrix array.
556 pub fn copy(&self) -> Self
557 {
558 let frontend = Frontend::new().unwrap();
559 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
560 frontend.copy(self, &res).unwrap();
561 res
562 }
563
564 /// Transposes the matrix
565 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
566 ///
567 /// This method doesn't indeed transpose the matrix but changes the transpose flag and
568 /// exchanges the number of matrix rows with the number of matrix columns.
569 ///
570 /// # Examples
571 ///
572 /// ```
573 /// # use unmtx_gpu::*;
574 /// let a = matrix![
575 /// [1.0, 2.0, 3.0],
576 /// [4.0, 5.0, 6.0]
577 /// ];
578 /// let b = a.transpose();
579 /// assert_eq!(3, b.row_count());
580 /// assert_eq!(2, b.col_count());
581 /// assert_eq!(true, b.is_transposed());
582 /// assert_eq!(a.elems(), b.elems());
583 /// let c = b.transpose();
584 /// assert_eq!(2, c.row_count());
585 /// assert_eq!(3, c.col_count());
586 /// assert_eq!(false, c.is_transposed());
587 /// assert_eq!(a.elems(), c.elems());
588 /// ```
589 pub fn transpose(&self) -> Self
590 {
591 Matrix {
592 row_count: self.col_count,
593 col_count: self.row_count,
594 is_transposed: !self.is_transposed,
595 array: self.array.clone(),
596 }
597 }
598
599 /// Indeed transposes the matrix
600 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
601 ///
602 /// This method indeed transposes the matrix without changing the transpose flag.
603 ///
604 /// # Examples
605 ///
606 /// ```
607 /// # use unmtx_gpu::*;
608 /// let a = matrix![
609 /// [1.0, 2.0, 3.0],
610 /// [4.0, 5.0, 6.0]
611 /// ];
612 /// let b = a.really_transpose();
613 /// assert_eq!(3, b.row_count());
614 /// assert_eq!(2, b.col_count());
615 /// assert_eq!(false, b.is_transposed());
616 /// assert_eq!(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], b.elems());
617 /// ```
618 pub fn really_transpose(&self) -> Self
619 {
620 let frontend = Frontend::new().unwrap();
621 let res = unsafe { frontend.create_matrix(self.col_count, self.row_count) }.unwrap();
622 frontend.really_transpose(self, &res).unwrap();
623 res
624 }
625
626 /// Multiplies the matrix elements by the `b` matrix elements
627 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
628 ///
629 /// # Examples
630 ///
631 /// ```
632 /// # use unmtx_gpu::*;
633 /// let a = matrix![
634 /// [1.0, 2.0],
635 /// [3.0, 4.0]
636 /// ];
637 /// let b = matrix![
638 /// [5.0, 6.0],
639 /// [7.0, 8.0]
640 /// ];
641 /// let c = a.mul_elems(&b);
642 /// assert_eq!(vec![1.0 * 5.0, 2.0 * 6.0, 3.0 * 7.0, 4.0 * 8.0], c.elems());
643 /// ```
644 pub fn mul_elems(&self, b: &Self) -> Self
645 {
646 let frontend = Frontend::new().unwrap();
647 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
648 frontend.mul_elems(self, b, &res).unwrap();
649 res
650 }
651
652 /// Divides the matrix elements by the `b` matrix elements
653 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
654 ///
655 /// # Examples
656 ///
657 /// ```
658 /// # use unmtx_gpu::*;
659 /// let a = matrix![
660 /// [1.0, 2.0],
661 /// [3.0, 4.0]
662 /// ];
663 /// let b = matrix![
664 /// [5.0, 6.0],
665 /// [7.0, 8.0]
666 /// ];
667 /// let c = a.div_elems(&b);
668 /// let elems = c.elems();
669 /// assert!((1.0 / 5.0 - elems[0]).abs() < 0.001);
670 /// assert!((2.0 / 6.0 - elems[1]).abs() < 0.001);
671 /// assert!((3.0 / 7.0 - elems[2]).abs() < 0.001);
672 /// assert!((4.0 / 8.0 - elems[3]).abs() < 0.001);
673 /// ```
674 pub fn div_elems(&self, b: &Self) -> Self
675 {
676 let frontend = Frontend::new().unwrap();
677 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
678 frontend.div_elems(self, b, &res).unwrap();
679 res
680 }
681
682 /// Subtracts the matrix from the scalar
683 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
684 ///
685 /// # Examples
686 ///
687 /// ```
688 /// # use unmtx_gpu::*;
689 /// let a = matrix![
690 /// [1.0, 2.0],
691 /// [3.0, 4.0]
692 /// ];
693 /// let b = a.rsub(10.5);
694 /// assert_eq!(vec![10.5 - 1.0, 10.5 - 2.0, 10.5 - 3.0, 10.5 - 4.0], b.elems());
695 /// ```
696 pub fn rsub(&self, b: f32) -> Self
697 {
698 let frontend = Frontend::new().unwrap();
699 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
700 frontend.rsub_for_scalar(self, b, &res).unwrap();
701 res
702 }
703
704 /// Divides the scalar by the matrix elements
705 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
706 ///
707 /// # Examples
708 ///
709 /// ```
710 /// # use unmtx_gpu::*;
711 /// let a = matrix![
712 /// [1.0, 2.0],
713 /// [3.0, 4.0]
714 /// ];
715 /// let b = a.rdiv(10.5);
716 /// let elems = b.elems();
717 /// assert!((10.5 / 1.0 - elems[0]).abs() < 0.001);
718 /// assert!((10.5 / 2.0 - elems[1]).abs() < 0.001);
719 /// assert!((10.5 / 3.0 - elems[2]).abs() < 0.001);
720 /// assert!((10.5 / 4.0 - elems[3]).abs() < 0.001);
721 /// ```
722 pub fn rdiv(&self, b: f32) -> Self
723 {
724 let frontend = Frontend::new().unwrap();
725 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
726 frontend.rdiv_for_scalar(self, b, &res).unwrap();
727 res
728 }
729
730 /// Calculates sigmoid function for the matrix
731 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
732 ///
733 /// # Examples
734 ///
735 /// ```
736 /// # use unmtx_gpu::*;
737 /// let a = matrix![
738 /// [1.0, 2.0],
739 /// [3.0, 4.0]
740 /// ];
741 /// let b = a.sigmoid();
742 /// let elems = b.elems();
743 /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
744 /// assert!((1.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
745 /// assert!((1.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
746 /// assert!((1.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
747 /// ```
748 pub fn sigmoid(&self) -> Self
749 {
750 let frontend = Frontend::new().unwrap();
751 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
752 frontend.sigmoid(self, &res).unwrap();
753 res
754 }
755
756 /// Calculates hiperbolic tangent function for the matrix
757 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
758 ///
759 /// # Examples
760 ///
761 /// ```
762 /// # use unmtx_gpu::*;
763 /// let a = matrix![
764 /// [1.0, 2.0],
765 /// [3.0, 4.0]
766 /// ];
767 /// let b = a.tanh();
768 /// let elems = b.elems();
769 /// assert!((1.0f32.tanh() - elems[0]).abs() < 0.001);
770 /// assert!((2.0f32.tanh() - elems[1]).abs() < 0.001);
771 /// assert!((3.0f32.tanh() - elems[2]).abs() < 0.001);
772 /// assert!((4.0f32.tanh() - elems[3]).abs() < 0.001);
773 /// ```
774 pub fn tanh(&self) -> Self
775 {
776 let frontend = Frontend::new().unwrap();
777 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
778 frontend.tanh(self, &res).unwrap();
779 res
780 }
781
782 /// Calculates softmax function for the matrix
783 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
784 ///
785 /// # Examples
786 ///
787 /// ```
788 /// # use unmtx_gpu::*;
789 /// let a = matrix![
790 /// [1.0, 2.0],
791 /// [3.0, 4.0]
792 /// ];
793 /// let b = a.softmax();
794 /// let elems = b.elems();
795 /// let sum1 = 1.0f32.exp() + 3.0f32.exp();
796 /// let sum2 = 2.0f32.exp() + 4.0f32.exp();
797 /// assert!((1.0f32.exp() / sum1 - elems[0]).abs() < 0.001);
798 /// assert!((2.0f32.exp() / sum2 - elems[1]).abs() < 0.001);
799 /// assert!((3.0f32.exp() / sum1 - elems[2]).abs() < 0.001);
800 /// assert!((4.0f32.exp() / sum2 - elems[3]).abs() < 0.001);
801 /// ```
802 pub fn softmax(&self) -> Self
803 {
804 let frontend = Frontend::new().unwrap();
805 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
806 frontend.softmax(self, &res).unwrap();
807 res
808 }
809
810 /// Repeats the vector as column or a row.
811 ///
812 /// # Examples
813 ///
814 /// ```
815 /// # use unmtx_gpu::*;
816 /// let a = matrix![
817 /// [1.0],
818 /// [2.0]
819 /// ];
820 /// let b = a.repeat(3);
821 /// assert_eq!(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], b.elems());
822 /// let c = matrix![[1.0, 2.0, 3.0]];
823 /// let d = c.repeat(2);
824 /// assert_eq!(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], d.elems());
825 /// ```
826 pub fn repeat(&self, n: usize) -> Self
827 {
828 assert!(self.col_count == 1 || self.row_count == 1);
829 let frontend = Frontend::new().unwrap();
830 let res = if self.col_count == 1 {
831 unsafe { frontend.create_matrix(self.row_count, n) }.unwrap()
832 } else {
833 unsafe { frontend.create_matrix(n, self.col_count) }.unwrap()
834 };
835 frontend.repeat(self, &res).unwrap();
836 res
837 }
838}
839
840impl Add for Matrix
841{
842 type Output = Self;
843
844 fn add(self, rhs: Self) -> Self::Output
845 {
846 let frontend = Frontend::new().unwrap();
847 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
848 frontend.add(&self, &rhs, &res).unwrap();
849 res
850 }
851}
852
853impl Add<&Matrix> for Matrix
854{
855 type Output = Self;
856
857 fn add(self, rhs: &Matrix) -> Self::Output
858 {
859 let frontend = Frontend::new().unwrap();
860 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
861 frontend.add(&self, rhs, &res).unwrap();
862 res
863 }
864}
865
866impl Add<f32> for Matrix
867{
868 type Output = Self;
869
870 fn add(self, rhs: f32) -> Self::Output
871 {
872 let frontend = Frontend::new().unwrap();
873 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
874 frontend.add_for_scalar(&self, rhs, &res).unwrap();
875 res
876 }
877}
878
879impl Add<&f32> for Matrix
880{
881 type Output = Self;
882
883 fn add(self, rhs: &f32) -> Self::Output
884 {
885 let frontend = Frontend::new().unwrap();
886 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
887 frontend.add_for_scalar(&self, *rhs, &res).unwrap();
888 res
889 }
890}
891
892impl Add<Matrix> for &Matrix
893{
894 type Output = Matrix;
895
896 fn add(self, rhs: Matrix) -> Self::Output
897 {
898 let frontend = Frontend::new().unwrap();
899 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
900 frontend.add(self, &rhs, &res).unwrap();
901 res
902 }
903}
904
905impl Add<&Matrix> for &Matrix
906{
907 type Output = Matrix;
908
909 fn add(self, rhs: &Matrix) -> Self::Output
910 {
911 let frontend = Frontend::new().unwrap();
912 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
913 frontend.add(self, rhs, &res).unwrap();
914 res
915 }
916}
917
918impl Add<f32> for &Matrix
919{
920 type Output = Matrix;
921
922 fn add(self, rhs: f32) -> Self::Output
923 {
924 let frontend = Frontend::new().unwrap();
925 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
926 frontend.add_for_scalar(self, rhs, &res).unwrap();
927 res
928 }
929}
930
931impl Add<&f32> for &Matrix
932{
933 type Output = Matrix;
934
935 fn add(self, rhs: &f32) -> Self::Output
936 {
937 let frontend = Frontend::new().unwrap();
938 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
939 frontend.add_for_scalar(self, *rhs, &res).unwrap();
940 res
941 }
942}
943
944impl AddAssign for Matrix
945{
946 fn add_assign(&mut self, rhs: Self)
947 {
948 let frontend = Frontend::new().unwrap();
949 frontend.add(self, &rhs, &self).unwrap();
950 }
951}
952
953impl AddAssign<&Matrix> for Matrix
954{
955 fn add_assign(&mut self, rhs: &Self)
956 {
957 let frontend = Frontend::new().unwrap();
958 frontend.add(&self, rhs, &self).unwrap();
959 }
960}
961
962impl AddAssign<f32> for Matrix
963{
964 fn add_assign(&mut self, rhs: f32)
965 {
966 let frontend = Frontend::new().unwrap();
967 frontend.add_for_scalar(&self, rhs, &self).unwrap();
968 }
969}
970
971impl AddAssign<&f32> for Matrix
972{
973 fn add_assign(&mut self, rhs: &f32)
974 {
975 let frontend = Frontend::new().unwrap();
976 frontend.add_for_scalar(&self, *rhs, &self).unwrap();
977 }
978}
979
980impl Sub for Matrix
981{
982 type Output = Self;
983
984 fn sub(self, rhs: Self) -> Self::Output
985 {
986 let frontend = Frontend::new().unwrap();
987 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
988 frontend.sub(&self, &rhs, &res).unwrap();
989 res
990 }
991}
992
993impl Sub<&Matrix> for Matrix
994{
995 type Output = Self;
996
997 fn sub(self, rhs: &Matrix) -> Self::Output
998 {
999 let frontend = Frontend::new().unwrap();
1000 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1001 frontend.sub(&self, rhs, &res).unwrap();
1002 res
1003 }
1004}
1005
1006impl Sub<f32> for Matrix
1007{
1008 type Output = Self;
1009
1010 fn sub(self, rhs: f32) -> Self::Output
1011 {
1012 let frontend = Frontend::new().unwrap();
1013 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1014 frontend.sub_for_scalar(&self, rhs, &res).unwrap();
1015 res
1016 }
1017}
1018
1019impl Sub<&f32> for Matrix
1020{
1021 type Output = Self;
1022
1023 fn sub(self, rhs: &f32) -> Self::Output
1024 {
1025 let frontend = Frontend::new().unwrap();
1026 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1027 frontend.sub_for_scalar(&self, *rhs, &res).unwrap();
1028 res
1029 }
1030}
1031
1032impl Sub<Matrix> for &Matrix
1033{
1034 type Output = Matrix;
1035
1036 fn sub(self, rhs: Matrix) -> Self::Output
1037 {
1038 let frontend = Frontend::new().unwrap();
1039 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1040 frontend.sub(self, &rhs, &res).unwrap();
1041 res
1042 }
1043}
1044
1045impl Sub<&Matrix> for &Matrix
1046{
1047 type Output = Matrix;
1048
1049 fn sub(self, rhs: &Matrix) -> Self::Output
1050 {
1051 let frontend = Frontend::new().unwrap();
1052 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1053 frontend.sub(self, rhs, &res).unwrap();
1054 res
1055 }
1056}
1057
1058impl Sub<f32> for &Matrix
1059{
1060 type Output = Matrix;
1061
1062 fn sub(self, rhs: f32) -> Self::Output
1063 {
1064 let frontend = Frontend::new().unwrap();
1065 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1066 frontend.sub_for_scalar(self, rhs, &res).unwrap();
1067 res
1068 }
1069}
1070
1071impl Sub<&f32> for &Matrix
1072{
1073 type Output = Matrix;
1074
1075 fn sub(self, rhs: &f32) -> Self::Output
1076 {
1077 let frontend = Frontend::new().unwrap();
1078 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1079 frontend.sub_for_scalar(self, *rhs, &res).unwrap();
1080 res
1081 }
1082}
1083
1084impl SubAssign for Matrix
1085{
1086 fn sub_assign(&mut self, rhs: Self)
1087 {
1088 let frontend = Frontend::new().unwrap();
1089 frontend.sub(&self, &rhs, &self).unwrap();
1090 }
1091}
1092
1093impl SubAssign<&Matrix> for Matrix
1094{
1095 fn sub_assign(&mut self, rhs: &Self)
1096 {
1097 let frontend = Frontend::new().unwrap();
1098 frontend.sub(&self, rhs, &self).unwrap();
1099 }
1100}
1101
1102impl SubAssign<f32> for Matrix
1103{
1104 fn sub_assign(&mut self, rhs: f32)
1105 {
1106 let frontend = Frontend::new().unwrap();
1107 frontend.sub_for_scalar(&self, rhs, &self).unwrap();
1108 }
1109}
1110
1111impl SubAssign<&f32> for Matrix
1112{
1113 fn sub_assign(&mut self, rhs: &f32)
1114 {
1115 let frontend = Frontend::new().unwrap();
1116 frontend.sub_for_scalar(&self, *rhs, &self).unwrap();
1117 }
1118}
1119
1120impl Mul for Matrix
1121{
1122 type Output = Self;
1123
1124 fn mul(self, rhs: Self) -> Self::Output
1125 {
1126 let frontend = Frontend::new().unwrap();
1127 let res = if frontend.backend().has_cublas() {
1128 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1129 } else {
1130 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1131 };
1132 frontend.mul(&self, &rhs, &res).unwrap();
1133 res
1134 }
1135}
1136
1137impl Mul<&Matrix> for Matrix
1138{
1139 type Output = Self;
1140
1141 fn mul(self, rhs: &Matrix) -> Self::Output
1142 {
1143 let frontend = Frontend::new().unwrap();
1144 let res = if frontend.backend().has_cublas() {
1145 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1146 } else {
1147 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1148 };
1149 frontend.mul(&self, rhs, &res).unwrap();
1150 res
1151 }
1152}
1153
1154impl Mul<f32> for Matrix
1155{
1156 type Output = Self;
1157
1158 fn mul(self, rhs: f32) -> Self::Output
1159 {
1160 let frontend = Frontend::new().unwrap();
1161 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1162 frontend.mul_for_scalar(&self, rhs, &res).unwrap();
1163 res
1164 }
1165}
1166
1167impl Mul<&f32> for Matrix
1168{
1169 type Output = Self;
1170
1171 fn mul(self, rhs: &f32) -> Self::Output
1172 {
1173 let frontend = Frontend::new().unwrap();
1174 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1175 frontend.mul_for_scalar(&self, *rhs, &res).unwrap();
1176 res
1177 }
1178}
1179
1180impl Mul<Matrix> for &Matrix
1181{
1182 type Output = Matrix;
1183
1184 fn mul(self, rhs: Matrix) -> Self::Output
1185 {
1186 let frontend = Frontend::new().unwrap();
1187 let res = if frontend.backend().has_cublas() {
1188 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1189 } else {
1190 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1191 };
1192 frontend.mul(self, &rhs, &res).unwrap();
1193 res
1194 }
1195}
1196
1197impl Mul<&Matrix> for &Matrix
1198{
1199 type Output = Matrix;
1200
1201 fn mul(self, rhs: &Matrix) -> Self::Output
1202 {
1203 let frontend = Frontend::new().unwrap();
1204 let res = if frontend.backend().has_cublas() {
1205 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1206 } else {
1207 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1208 };
1209 frontend.mul(self, rhs, &res).unwrap();
1210 res
1211 }
1212}
1213
1214impl Mul<f32> for &Matrix
1215{
1216 type Output = Matrix;
1217
1218 fn mul(self, rhs: f32) -> Self::Output
1219 {
1220 let frontend = Frontend::new().unwrap();
1221 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1222 frontend.mul_for_scalar(self, rhs, &res).unwrap();
1223 res
1224 }
1225}
1226
1227impl Mul<&f32> for &Matrix
1228{
1229 type Output = Matrix;
1230
1231 fn mul(self, rhs: &f32) -> Self::Output
1232 {
1233 let frontend = Frontend::new().unwrap();
1234 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1235 frontend.mul_for_scalar(self, *rhs, &res).unwrap();
1236 res
1237 }
1238}
1239
1240impl MulAssign for Matrix
1241{
1242 fn mul_assign(&mut self, rhs: Self)
1243 {
1244 let frontend = Frontend::new().unwrap();
1245 let res = if frontend.backend().has_cublas() {
1246 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1247 } else {
1248 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1249 };
1250 frontend.mul(&self, &rhs, &res).unwrap();
1251 *self = res;
1252 }
1253}
1254
1255impl MulAssign<&Matrix> for Matrix
1256{
1257 fn mul_assign(&mut self, rhs: &Self)
1258 {
1259 let frontend = Frontend::new().unwrap();
1260 let res = if frontend.backend().has_cublas() {
1261 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1262 } else {
1263 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1264 };
1265 frontend.mul(&self, rhs, &res).unwrap();
1266 *self = res;
1267 }
1268}
1269
1270impl MulAssign<f32> for Matrix
1271{
1272 fn mul_assign(&mut self, rhs: f32)
1273 {
1274 let frontend = Frontend::new().unwrap();
1275 frontend.mul_for_scalar(&self, rhs, &self).unwrap();
1276 }
1277}
1278
1279impl MulAssign<&f32> for Matrix
1280{
1281 fn mul_assign(&mut self, rhs: &f32)
1282 {
1283 let frontend = Frontend::new().unwrap();
1284 frontend.mul_for_scalar(&self, *rhs, &self).unwrap();
1285 }
1286}
1287
1288impl Div<f32> for Matrix
1289{
1290 type Output = Self;
1291
1292 fn div(self, rhs: f32) -> Self::Output
1293 {
1294 let frontend = Frontend::new().unwrap();
1295 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1296 frontend.div_for_scalar(&self, rhs, &res).unwrap();
1297 res
1298 }
1299}
1300
1301impl Div<&f32> for Matrix
1302{
1303 type Output = Self;
1304
1305 fn div(self, rhs: &f32) -> Self::Output
1306 {
1307 let frontend = Frontend::new().unwrap();
1308 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1309 frontend.div_for_scalar(&self, *rhs, &res).unwrap();
1310 res
1311 }
1312}
1313
1314impl Div<f32> for &Matrix
1315{
1316 type Output = Matrix;
1317
1318 fn div(self, rhs: f32) -> Self::Output
1319 {
1320 let frontend = Frontend::new().unwrap();
1321 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1322 frontend.div_for_scalar(self, rhs, &res).unwrap();
1323 res
1324 }
1325}
1326
1327impl Div<&f32> for &Matrix
1328{
1329 type Output = Matrix;
1330
1331 fn div(self, rhs: &f32) -> Self::Output
1332 {
1333 let frontend = Frontend::new().unwrap();
1334 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1335 frontend.div_for_scalar(self, *rhs, &res).unwrap();
1336 res
1337 }
1338}
1339
1340impl DivAssign<f32> for Matrix
1341{
1342 fn div_assign(&mut self, rhs: f32)
1343 {
1344 initialize_default_backend_for_uninitialized().unwrap();
1345 let frontend = Frontend::new().unwrap();
1346 frontend.div_for_scalar(&self, rhs, &self).unwrap();
1347 }
1348}
1349
1350impl DivAssign<&f32> for Matrix
1351{
1352 fn div_assign(&mut self, rhs: &f32)
1353 {
1354 initialize_default_backend_for_uninitialized().unwrap();
1355 let frontend = Frontend::new().unwrap();
1356 frontend.div_for_scalar(&self, *rhs, &self).unwrap();
1357 }
1358}
1359
1360/// A frontend structure.
1361///
1362/// The frontend contains methods which operate on matrices or calculate functions for the
1363/// matrices. Backend methods are called by the frontend to operate the matrices. The frontend is
1364/// high-level layer that can be directly used by programmer or a [`Matrix`] structure.
1365pub struct Frontend
1366{
1367 backend: Arc<dyn Backend>,
1368}
1369
1370impl Frontend
1371{
1372 /// Creates a frontend with a default backend.
1373 ///
1374 /// This method also automatically initializes a default backend if the default backend is
1375 /// uninitialized.
1376 pub fn new() -> Result<Frontend>
1377 { Ok(Frontend { backend: initialize_default_backend_for_uninitialized()?, }) }
1378
1379 /// Creates a frotend with the backend.
1380 pub fn new_with_backend(backend: Arc<dyn Backend>) -> Frontend
1381 { Frontend { backend, } }
1382
1383 /// Returns the backend.
1384 pub fn backend(&self) -> Arc<dyn Backend>
1385 { self.backend.clone() }
1386
1387 /// Creates a matrix with unset elements.
1388 pub unsafe fn create_matrix(&self, row_count: usize, col_count: usize) -> Result<Matrix>
1389 {
1390 Ok(Matrix {
1391 row_count,
1392 col_count,
1393 is_transposed: false,
1394 array: Arc::new(self.backend.alloc(row_count * col_count)?),
1395 })
1396 }
1397
1398 /// Creates a matrix and sets the matrix elements on zeros.
1399 pub fn create_matrix_and_set_zeros(&self, row_count: usize, col_count: usize) -> Result<Matrix>
1400 {
1401 Ok(Matrix {
1402 row_count,
1403 col_count,
1404 is_transposed: false,
1405 array: Arc::new(self.backend.alloc_and_store_zeros(row_count * col_count)?),
1406 })
1407 }
1408
1409 /// Creates a matrix and sets the matrix elements.
1410 pub fn create_matrix_and_set_elems(&self, row_count: usize, col_count: usize, elems: &[f32]) -> Result<Matrix>
1411 {
1412 if row_count * col_count != elems.len() {
1413 return Err(Error::MatrixElemCount(row_count * col_count, elems.len()));
1414 }
1415 Ok(Matrix {
1416 row_count,
1417 col_count,
1418 is_transposed: false,
1419 array: Arc::new(self.backend.alloc_and_store(elems)?),
1420 })
1421 }
1422
1423 /// Sets the matrix elements.
1424 pub fn set_elems(&self, a: &Matrix, elems: &[f32]) -> Result<()>
1425 {
1426 if a.row_count() * a.col_count() != elems.len() {
1427 return Err(Error::MatrixElemCount(a.row_count() * a.col_count(), elems.len()));
1428 }
1429 self.backend.store(&*a.array, elems)
1430 }
1431
1432 /// Copies the `a` matrix to the `b` matrix.
1433 ///
1434 /// This method indeed copies the `a` matrix array to the `b` matrix array.
1435 pub fn copy(&self, a: &Matrix, b: &Matrix) -> Result<()>
1436 {
1437 if a.row_count != b.row_count || a.col_count != b.col_count {
1438 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1439 }
1440 self.backend.copy(&*a.array, &*b.array)
1441 }
1442
1443 /// Copies the matrix elements to the mutable slice and the transpose flag to the object that
1444 /// is referred by the reference.
1445 pub fn get_elems_and_transpose_flag(&self, a: &Matrix, elems: &mut [f32], is_transposed: &mut bool) -> Result<()>
1446 {
1447 if a.row_count * a.col_count != elems.len() {
1448 return Err(Error::MatrixElemCount(a.row_count * a.col_count, elems.len()));
1449 }
1450 if !a.is_transposed {
1451 self.backend.load(&*a.array, elems)?;
1452 } else {
1453 self.backend.load(&*a.array, elems)?;
1454 }
1455 *is_transposed = true;
1456 Ok(())
1457 }
1458
1459 /// Returns the elements and the transpose flag of matrix.
1460 pub fn elems_and_transpose_flag(&self, a: &Matrix) -> Result<(Vec<f32>, bool)>
1461 {
1462 let mut elems: Vec<f32> = vec![0.0; a.row_count * a.col_count];
1463 let mut is_transposed = false;
1464 self.get_elems_and_transpose_flag(a, elems.as_mut_slice(), &mut is_transposed)?;
1465 Ok((elems, is_transposed))
1466 }
1467
1468 /// Adds the `a` matrix onto the `b` matrix and then the result is in the `c` matrix
1469 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
1470 ///
1471 /// # Examples
1472 ///
1473 /// ```
1474 /// # use unmtx_gpu::*;
1475 /// let a = matrix![
1476 /// [1.0, 2.0],
1477 /// [3.0, 4.0]
1478 /// ];
1479 /// let b = matrix![
1480 /// [5.0, 6.0],
1481 /// [7.0, 8.0]
1482 /// ];
1483 /// let c = Matrix::new(2, 2);
1484 /// let frontend = Frontend::new().unwrap();
1485 /// frontend.add(&a, &b, &c).unwrap();
1486 /// assert_eq!(vec![1.0 + 5.0, 2.0 + 6.0, 3.0 + 7.0, 4.0 + 8.0], c.elems());
1487 /// ```
1488 pub fn add(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1489 {
1490 if a.row_count != b.row_count || a.col_count != b.col_count {
1491 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1492 }
1493 if a.row_count != c.row_count || a.col_count != c.col_count {
1494 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1495 }
1496 if c.is_transposed {
1497 return Err(Error::ResTransposition);
1498 }
1499 match (a.is_transposed, b.is_transposed) {
1500 (false, false) => self.backend.add_a_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1501 (true, false) => self.backend.add_at_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1502 (false, true) => self.backend.add_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1503 (true, true) => self.backend.add_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1504 }
1505 }
1506
1507 /// Subtracts the `b` matrix from the `a` matrix and then the result is in the `c` matrix
1508 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
1509 ///
1510 /// # Examples
1511 ///
1512 /// ```
1513 /// # use unmtx_gpu::*;
1514 /// let a = matrix![
1515 /// [1.0, 2.0],
1516 /// [3.0, 4.0]
1517 /// ];
1518 /// let b = matrix![
1519 /// [5.0, 6.0],
1520 /// [7.0, 8.0]
1521 /// ];
1522 /// let c = Matrix::new(2, 2);
1523 /// let frontend = Frontend::new().unwrap();
1524 /// frontend.sub(&a, &b, &c).unwrap();
1525 /// assert_eq!(vec![1.0 - 5.0, 2.0 - 6.0, 3.0 - 7.0, 4.0 - 8.0], c.elems());
1526 /// ```
1527 pub fn sub(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1528 {
1529 if a.row_count != b.row_count || a.col_count != b.col_count {
1530 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1531 }
1532 if a.row_count != c.row_count || a.col_count != c.col_count {
1533 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1534 }
1535 if c.is_transposed {
1536 return Err(Error::ResTransposition);
1537 }
1538 match (a.is_transposed, b.is_transposed) {
1539 (false, false) => self.backend.sub_a_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1540 (true, false) => self.backend.sub_at_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1541 (false, true) => self.backend.sub_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1542 (true, true) => self.backend.sub_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1543 }
1544 }
1545
1546 /// Multiplies the `a` matrix by the `b` matrix and then the result is in the `c` matrix
1547 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
1548 ///
1549 /// # Examples
1550 ///
1551 /// ```
1552 /// # use unmtx_gpu::*;
1553 /// let a = matrix![
1554 /// [1.0, 2.0, 3.0],
1555 /// [4.0, 5.0, 6.0]
1556 /// ];
1557 /// let b = matrix![
1558 /// [7.0, 8.0],
1559 /// [9.0, 10.0],
1560 /// [11.0, 12.0]
1561 /// ];
1562 /// let c = Matrix::new(2, 2);
1563 /// let frontend = Frontend::new().unwrap();
1564 /// frontend.mul(&a, &b, &c).unwrap();
1565 /// let c11: f32 = 1.0 * 7.0 + 2.0 * 9.0 + 3.0 * 11.0;
1566 /// let c12: f32 = 1.0 * 8.0 + 2.0 * 10.0 + 3.0 * 12.0;
1567 /// let c21: f32 = 4.0 * 7.0 + 5.0 * 9.0 + 6.0 * 11.0;
1568 /// let c22: f32 = 4.0 * 8.0 + 5.0 * 10.0 + 6.0 * 12.0;
1569 /// assert_eq!(vec![c11, c12, c21, c22], c.elems());
1570 /// ```
1571 pub fn mul(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1572 {
1573 if a.row_count != c.row_count {
1574 return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count));
1575 }
1576 if b.col_count != c.col_count {
1577 return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count));
1578 }
1579 if a.col_count != b.row_count {
1580 return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count));
1581 }
1582 if c.is_transposed {
1583 return Err(Error::ResTransposition);
1584 }
1585 match (a.is_transposed, b.is_transposed) {
1586 (false, false) => self.backend.mul_a_b(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1587 (true, false) => self.backend.mul_at_b(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1588 (false, true) => self.backend.mul_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1589 (true, true) => self.backend.mul_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1590 }
1591 }
1592
1593 /// Multiplies the `a` matrix elements by the `b` matrix elements and then the result is in the
1594 /// `c` matrix
1595 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
1596 ///
1597 /// # Examples
1598 ///
1599 /// ```
1600 /// # use unmtx_gpu::*;
1601 /// let a = matrix![
1602 /// [1.0, 2.0],
1603 /// [3.0, 4.0]
1604 /// ];
1605 /// let b = matrix![
1606 /// [5.0, 6.0],
1607 /// [7.0, 8.0]
1608 /// ];
1609 /// let c = Matrix::new(2, 2);
1610 /// let frontend = Frontend::new().unwrap();
1611 /// frontend.mul_elems(&a, &b, &c).unwrap();
1612 /// assert_eq!(vec![1.0 * 5.0, 2.0 * 6.0, 3.0 * 7.0, 4.0 * 8.0], c.elems());
1613 /// ```
1614 pub fn mul_elems(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1615 {
1616 if a.row_count != b.row_count || a.col_count != b.col_count {
1617 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1618 }
1619 if a.row_count != c.row_count || a.col_count != c.col_count {
1620 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1621 }
1622 if c.is_transposed {
1623 return Err(Error::ResTransposition);
1624 }
1625 match (a.is_transposed, b.is_transposed) {
1626 (false, false) => self.backend.mul_a_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1627 (true, false) => self.backend.mul_at_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1628 (false, true) => self.backend.mul_a_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1629 (true, true) => self.backend.mul_at_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1630 }
1631 }
1632
1633 /// Divides the `a` matrix elements by the `b` matrix elements and then the result is in the `c`
1634 /// matrix
1635 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
1636 ///
1637 /// # Examples
1638 ///
1639 /// ```
1640 /// # use unmtx_gpu::*;
1641 /// let a = matrix![
1642 /// [1.0, 2.0],
1643 /// [3.0, 4.0]
1644 /// ];
1645 /// let b = matrix![
1646 /// [5.0, 6.0],
1647 /// [7.0, 8.0]
1648 /// ];
1649 /// let c = Matrix::new(2, 2);
1650 /// let frontend = Frontend::new().unwrap();
1651 /// frontend.div_elems(&a, &b, &c).unwrap();
1652 /// let elems = c.elems();
1653 /// assert!((1.0 / 5.0 - elems[0]).abs() < 0.001);
1654 /// assert!((2.0 / 6.0 - elems[1]).abs() < 0.001);
1655 /// assert!((3.0 / 7.0 - elems[2]).abs() < 0.001);
1656 /// assert!((4.0 / 8.0 - elems[3]).abs() < 0.001);
1657 /// ```
1658 pub fn div_elems(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1659 {
1660 if a.row_count != b.row_count || a.col_count != b.col_count {
1661 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1662 }
1663 if a.row_count != c.row_count || a.col_count != c.col_count {
1664 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1665 }
1666 if c.is_transposed {
1667 return Err(Error::ResTransposition);
1668 }
1669 match (a.is_transposed, b.is_transposed) {
1670 (false, false) => self.backend.div_a_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1671 (true, false) => self.backend.div_at_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1672 (false, true) => self.backend.div_a_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1673 (true, true) => self.backend.div_at_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1674 }
1675 }
1676
1677 /// Adds the `a` matrix onto the `b` scalar and then the result is in the `c` matrix
1678 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi>b</mi></mrow></math>).
1679 ///
1680 /// # Examples
1681 ///
1682 /// ```
1683 /// # use unmtx_gpu::*;
1684 /// let a = matrix![
1685 /// [1.0, 2.0],
1686 /// [3.0, 4.0]
1687 /// ];
1688 /// let c = Matrix::new(2, 2);
1689 /// let frontend = Frontend::new().unwrap();
1690 /// frontend.add_for_scalar(&a, 10.5, &c).unwrap();
1691 /// assert_eq!(vec![1.0 + 10.5, 2.0 + 10.5, 3.0 + 10.5, 4.0 + 10.5], c.elems());
1692 /// ```
1693 pub fn add_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1694 {
1695 if a.row_count != c.row_count || a.col_count != c.col_count {
1696 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1697 }
1698 if c.is_transposed {
1699 return Err(Error::ResTransposition);
1700 }
1701 if !a.is_transposed {
1702 self.backend.add_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1703 } else {
1704 self.backend.add_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1705 }
1706 }
1707
1708 /// Subtracts the `b` scalar from the `a` matrix and then the result is in the `c` matrix
1709 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi>b</mi></mrow></math>).
1710 ///
1711 /// # Examples
1712 ///
1713 /// ```
1714 /// # use unmtx_gpu::*;
1715 /// let a = matrix![
1716 /// [1.0, 2.0],
1717 /// [3.0, 4.0]
1718 /// ];
1719 /// let c = Matrix::new(2, 2);
1720 /// let frontend = Frontend::new().unwrap();
1721 /// frontend.sub_for_scalar(&a, 10.5, &c).unwrap();
1722 /// assert_eq!(vec![1.0 - 10.5, 2.0 - 10.5, 3.0 - 10.5, 4.0 - 10.5], c.elems());
1723 /// ```
1724 pub fn sub_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1725 {
1726 if a.row_count != c.row_count || a.col_count != c.col_count {
1727 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1728 }
1729 if c.is_transposed {
1730 return Err(Error::ResTransposition);
1731 }
1732 if !a.is_transposed {
1733 self.backend.sub_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1734 } else {
1735 self.backend.sub_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1736 }
1737 }
1738
1739 /// Subtracts the `a` matrix from the `b` scalar and then the result is in the `c` matrix
1740 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
1741 ///
1742 /// # Examples
1743 ///
1744 /// ```
1745 /// # use unmtx_gpu::*;
1746 /// let a = matrix![
1747 /// [1.0, 2.0],
1748 /// [3.0, 4.0]
1749 /// ];
1750 /// let c = Matrix::new(2, 2);
1751 /// let frontend = Frontend::new().unwrap();
1752 /// frontend.rsub_for_scalar(&a, 10.5, &c).unwrap();
1753 /// assert_eq!(vec![10.5 - 1.0, 10.5 - 2.0, 10.5 - 3.0, 10.5 - 4.0], c.elems());
1754 /// ```
1755 pub fn rsub_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1756 {
1757 if a.row_count != c.row_count || a.col_count != c.col_count {
1758 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1759 }
1760 if c.is_transposed {
1761 return Err(Error::ResTransposition);
1762 }
1763 if !a.is_transposed {
1764 self.backend.rsub_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1765 } else {
1766 self.backend.rsub_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1767 }
1768 }
1769
1770 /// Multiplies the `a` matrix by the `b` scalar and then the result is in the `c` matrix
1771 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi>b</mi></mrow></math>).
1772 ///
1773 /// # Examples
1774 ///
1775 /// ```
1776 /// # use unmtx_gpu::*;
1777 /// let a = matrix![
1778 /// [1.0, 2.0],
1779 /// [3.0, 4.0]
1780 /// ];
1781 /// let c = Matrix::new(2, 2);
1782 /// let frontend = Frontend::new().unwrap();
1783 /// frontend.mul_for_scalar(&a, 10.5, &c).unwrap();
1784 /// assert_eq!(vec![1.0 * 10.5, 2.0 * 10.5, 3.0 * 10.5, 4.0 * 10.5], c.elems());
1785 /// ```
1786 pub fn mul_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1787 {
1788 if a.row_count != c.row_count || a.col_count != c.col_count {
1789 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1790 }
1791 if c.is_transposed {
1792 return Err(Error::ResTransposition);
1793 }
1794 if !a.is_transposed {
1795 self.backend.mul_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1796 } else {
1797 self.backend.mul_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1798 }
1799 }
1800
1801 /// Divides the `a` matrix by the `b` scalar and then the result is in the `c` matrix
1802 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><mi mathvariant="bold">A</mi><mi>b</mi></mfrac></mrow></math>).
1803 ///
1804 /// # Examples
1805 ///
1806 /// ```
1807 /// # use unmtx_gpu::*;
1808 /// let a = matrix![
1809 /// [1.0, 2.0],
1810 /// [3.0, 4.0]
1811 /// ];
1812 /// let c = Matrix::new(2, 2);
1813 /// let frontend = Frontend::new().unwrap();
1814 /// frontend.div_for_scalar(&a, 10.5, &c).unwrap();
1815 /// let elems = c.elems();
1816 /// assert!((1.0 / 10.5 - elems[0]).abs() < 0.001);
1817 /// assert!((2.0 / 10.5 - elems[1]).abs() < 0.001);
1818 /// assert!((3.0 / 10.5 - elems[2]).abs() < 0.001);
1819 /// assert!((4.0 / 10.5 - elems[3]).abs() < 0.001);
1820 /// ```
1821 pub fn div_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1822 {
1823 if a.row_count != c.row_count || a.col_count != c.col_count {
1824 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1825 }
1826 if c.is_transposed {
1827 return Err(Error::ResTransposition);
1828 }
1829 if !a.is_transposed {
1830 self.backend.div_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1831 } else {
1832 self.backend.div_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1833 }
1834 }
1835
1836 /// Divides the `b` scalar by the `a` matrix elements and then the result is in the `c` matrix
1837 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
1838 ///
1839 /// # Examples
1840 ///
1841 /// ```
1842 /// # use unmtx_gpu::*;
1843 /// let a = matrix![
1844 /// [1.0, 2.0],
1845 /// [3.0, 4.0]
1846 /// ];
1847 /// let c = Matrix::new(2, 2);
1848 /// let frontend = Frontend::new().unwrap();
1849 /// frontend.rdiv_for_scalar(&a, 10.5, &c).unwrap();
1850 /// let elems = c.elems();
1851 /// assert!((10.5 / 1.0- elems[0]).abs() < 0.001);
1852 /// assert!((10.5 / 2.0 - elems[1]).abs() < 0.001);
1853 /// assert!((10.5 / 3.0 - elems[2]).abs() < 0.001);
1854 /// assert!((10.5 / 4.0 - elems[3]).abs() < 0.001);
1855 /// ```
1856 pub fn rdiv_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1857 {
1858 if a.row_count != c.row_count || a.col_count != c.col_count {
1859 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1860 }
1861 if c.is_transposed {
1862 return Err(Error::ResTransposition);
1863 }
1864 if !a.is_transposed {
1865 self.backend.rdiv_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1866 } else {
1867 self.backend.rdiv_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1868 }
1869 }
1870
1871 /// Calculates sigmoid function for the `a` matrix and then the result is in the `b` matrix
1872 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1873 ///
1874 /// # Examples
1875 ///
1876 /// ```
1877 /// # use unmtx_gpu::*;
1878 /// let a = matrix![
1879 /// [1.0, 2.0],
1880 /// [3.0, 4.0]
1881 /// ];
1882 /// let b = Matrix::new(2, 2);
1883 /// let frontend = Frontend::new().unwrap();
1884 /// frontend.sigmoid(&a, &b).unwrap();
1885 /// let elems = b.elems();
1886 /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
1887 /// assert!((1.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
1888 /// assert!((1.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
1889 /// assert!((1.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
1890 /// ```
1891 pub fn sigmoid(&self, a: &Matrix, b: &Matrix) -> Result<()>
1892 {
1893 if a.row_count != b.row_count || a.col_count != b.col_count {
1894 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1895 }
1896 if b.is_transposed {
1897 return Err(Error::ResTransposition);
1898 }
1899 if !a.is_transposed {
1900 self.backend.sigmoid_a(&*a.array, &*b.array, a.row_count, a.col_count)
1901 } else {
1902 self.backend.sigmoid_at(&*a.array, &*b.array, a.row_count, a.col_count)
1903 }
1904 }
1905
1906 /// Calculates hyperbolic tangent function for the `a` matrix and then the result is in the `b`
1907 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1908 ///
1909 /// # Examples
1910 ///
1911 /// ```
1912 /// # use unmtx_gpu::*;
1913 /// let a = matrix![
1914 /// [1.0, 2.0],
1915 /// [3.0, 4.0]
1916 /// ];
1917 /// let b = Matrix::new(2, 2);
1918 /// let frontend = Frontend::new().unwrap();
1919 /// frontend.tanh(&a, &b).unwrap();
1920 /// let elems = b.elems();
1921 /// assert!((1.0f32.tanh() - elems[0]).abs() < 0.001);
1922 /// assert!((2.0f32.tanh() - elems[1]).abs() < 0.001);
1923 /// assert!((3.0f32.tanh() - elems[2]).abs() < 0.001);
1924 /// assert!((4.0f32.tanh() - elems[3]).abs() < 0.001);
1925 /// ```
1926 pub fn tanh(&self, a: &Matrix, b: &Matrix) -> Result<()>
1927 {
1928 if a.row_count != b.row_count || a.col_count != b.col_count {
1929 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1930 }
1931 if b.is_transposed {
1932 return Err(Error::ResTransposition);
1933 }
1934 if !a.is_transposed {
1935 self.backend.tanh_a(&*a.array, &*b.array, a.row_count, a.col_count)
1936 } else {
1937 self.backend.tanh_at(&*a.array, &*b.array, a.row_count, a.col_count)
1938 }
1939 }
1940
1941 /// Calculates softmax function for the `a` matrix and then the result is in the `b` matrix
1942 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1943 ///
1944 /// # Examples
1945 ///
1946 /// ```
1947 /// # use unmtx_gpu::*;
1948 /// let a = matrix![
1949 /// [1.0, 2.0],
1950 /// [3.0, 4.0]
1951 /// ];
1952 /// let b = Matrix::new(2, 2);
1953 /// let frontend = Frontend::new().unwrap();
1954 /// frontend.softmax(&a, &b).unwrap();
1955 /// let elems = b.elems();
1956 /// let sum1 = 1.0f32.exp() + 3.0f32.exp();
1957 /// let sum2 = 2.0f32.exp() + 4.0f32.exp();
1958 /// assert!((1.0f32.exp() / sum1 - elems[0]).abs() < 0.001);
1959 /// assert!((2.0f32.exp() / sum2 - elems[1]).abs() < 0.001);
1960 /// assert!((3.0f32.exp() / sum1 - elems[2]).abs() < 0.001);
1961 /// assert!((4.0f32.exp() / sum2 - elems[3]).abs() < 0.001);
1962 /// ```
1963 pub fn softmax(&self, a: &Matrix, b: &Matrix) -> Result<()>
1964 {
1965 if a.row_count != b.row_count || a.col_count != b.col_count {
1966 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1967 }
1968 if b.is_transposed {
1969 return Err(Error::ResTransposition);
1970 }
1971 if !a.is_transposed {
1972 self.backend.softmax_a(&*a.array, &*b.array, a.row_count, a.col_count)
1973 } else {
1974 self.backend.softmax_at(&*a.array, &*b.array, a.row_count, a.col_count)
1975 }
1976 }
1977
1978 /// Indeed transposes the `a` matrix and then the result is in the `b` matrix
1979 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
1980 ///
1981 /// This method indeed transposes the `a` matrix without changing the transpose flag.
1982 ///
1983 /// # Examples
1984 ///
1985 /// ```
1986 /// # use unmtx_gpu::*;
1987 /// let a = matrix![
1988 /// [1.0, 2.0, 3.0],
1989 /// [4.0, 5.0, 6.0]
1990 /// ];
1991 /// let b = Matrix::new(3, 2);
1992 /// let frontend = Frontend::new().unwrap();
1993 /// frontend.really_transpose(&a, &b).unwrap();
1994 /// assert_eq!(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], b.elems());
1995 /// ```
1996 pub fn really_transpose(&self, a: &Matrix, b: &Matrix) -> Result<()>
1997 {
1998 if a.row_count != b.col_count || a.col_count != b.row_count {
1999 return Err(Error::TransposeSize(a.row_count, a.col_count, b.row_count, b.col_count));
2000 }
2001 if a.is_transposed {
2002 return Err(Error::ArgTransposition);
2003 }
2004 if b.is_transposed {
2005 return Err(Error::ResTransposition);
2006 }
2007 self.backend.transpose_a(&*a.array, &*b.array, a.col_count, a.row_count)
2008 }
2009
2010 /// Repeats the `a` vector as column or a row
2011 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>i</mi></msub></mrow></math> or
2012 /// <math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>j</mi></msub></mrow></math>).
2013 ///
2014 /// # Examples
2015 ///
2016 /// ```
2017 /// # use unmtx_gpu::*;
2018 /// let a = matrix![
2019 /// [1.0],
2020 /// [2.0]
2021 /// ];
2022 /// let b = Matrix::new(2, 3);
2023 /// let frontend = Frontend::new().unwrap();
2024 /// frontend.repeat(&a, &b).unwrap();
2025 /// assert_eq!(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], b.elems());
2026 /// let c = matrix![[1.0, 2.0, 3.0]];
2027 /// let d = Matrix::new(2, 3);
2028 /// let frontend = Frontend::new().unwrap();
2029 /// frontend.repeat(&c, &d).unwrap();
2030 /// assert_eq!(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], d.elems());
2031 /// ```
2032 pub fn repeat(&self, a: &Matrix, b: &Matrix) -> Result<()>
2033 {
2034 if b.is_transposed {
2035 return Err(Error::ResTransposition);
2036 }
2037 if a.col_count == 1 {
2038 if a.row_count != b.row_count {
2039 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2040 }
2041 self.backend.repeat_col_a(&*a.array, &*b.array, a.row_count, b.col_count)
2042 } else if a.row_count == 1 {
2043 if a.col_count != b.col_count {
2044 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2045 }
2046 self.backend.repeat_row_a(&*a.array, &*b.array, b.row_count, a.col_count)
2047 } else {
2048 Err(Error::IsNotVector)
2049 }
2050 }
2051}
2052
2053#[cfg(test)]
2054mod test_helpers;
2055#[cfg(all(test, not(feature = "test_only_backend")))]
2056mod tests;