1mod activations;
42mod complex;
43pub mod constructors;
44mod conversions;
45mod expression;
46mod math_ops;
47mod sparse;
48pub mod transformations;
49mod utils;
50
51#[cfg(test)]
52mod property_tests;
53
54use crate::errors::Result;
55use scirs2_core::ndarray::{ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
56use scirs2_core::Complex;
57use scirs2_core::{Complex32, Complex64};
58use serde::{Deserialize, Serialize};
59
/// Element-type tag describing how a [`Tensor`]'s values are stored.
///
/// Per-element storage sizes are reported by [`DType::size_in_bytes`];
/// complex types count both components (e.g. `C32` is 8 bytes).
///
/// NOTE(review): the integer/unsigned/`Bool` tags have no dense variant in
/// the `Tensor` enum below (only `I64` does) — confirm whether they are
/// reachable only through sparse/backend tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// 32-bit float (4 bytes).
    F32,
    /// 16-bit IEEE half-precision float (2 bytes).
    F16,
    /// 16-bit bfloat16 (2 bytes).
    BF16,
    /// 64-bit float (8 bytes).
    F64,
    /// Complex with two f32 components (8 bytes).
    C32,
    /// Complex with two f64 components (16 bytes).
    C64,
    /// Complex with two f16 components (4 bytes).
    CF16,
    /// Complex with two bf16 components (4 bytes).
    CBF16,
    /// Unsigned 8-bit integer (1 byte).
    U8,
    /// Unsigned 16-bit integer (2 bytes).
    U16,
    /// Unsigned 32-bit integer (4 bytes).
    U32,
    /// Unsigned 64-bit integer (8 bytes).
    U64,
    /// Signed 8-bit integer (1 byte).
    I8,
    /// Signed 16-bit integer (2 bytes).
    I16,
    /// Signed 32-bit integer (4 bytes).
    I32,
    /// Signed 64-bit integer (8 bytes).
    I64,
    /// Boolean, stored as one full byte per `size_in_bytes`.
    Bool,
}
98
99impl DType {
100 pub fn size_in_bytes(&self) -> usize {
102 match self {
103 DType::F32 => 4,
104 DType::F16 => 2,
105 DType::BF16 => 2,
106 DType::F64 => 8,
107 DType::C32 => 8, DType::C64 => 16, DType::CF16 => 4, DType::CBF16 => 4, DType::U8 => 1,
112 DType::U16 => 2,
113 DType::U32 => 4,
114 DType::U64 => 8,
115 DType::I8 => 1,
116 DType::I16 => 2,
117 DType::I32 => 4,
118 DType::I64 => 8,
119 DType::Bool => 1,
120 }
121 }
122}
123
/// CPU-side handle for a tensor whose storage lives in a Metal GPU buffer.
///
/// Holds only metadata — the id of the backing buffer plus logical shape
/// and dtype — not the data itself. Compiled only on macOS with the
/// `metal` feature enabled.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[derive(Debug)]
pub struct MetalTensorData {
    /// Identifier of the backing buffer in `crate::gpu_ops::metal`.
    pub buffer_id: crate::gpu_ops::metal::BufferId,
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type of the buffer contents.
    pub dtype: DType,
}
164
#[cfg(all(target_os = "macos", feature = "metal"))]
impl Clone for MetalTensorData {
    // Field-wise clone of the metadata. The buffer id is copied by value,
    // so the clone refers to the same underlying Metal buffer — the GPU
    // data itself is not duplicated.
    fn clone(&self) -> Self {
        MetalTensorData {
            dtype: self.dtype,
            shape: self.shape.clone(),
            buffer_id: self.buffer_id,
        }
    }
}
177
/// CPU-side handle for a tensor whose storage lives in a CUDA device buffer.
///
/// Mirrors `MetalTensorData`: metadata only (buffer id, shape, dtype), no
/// device data. Compiled only with the `cuda` feature enabled.
#[cfg(feature = "cuda")]
#[derive(Debug)]
pub struct CudaTensorData {
    /// Identifier of the backing buffer in `crate::gpu_ops::cuda`.
    pub buffer_id: crate::gpu_ops::cuda::BufferId,
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type of the buffer contents.
    pub dtype: DType,
}
186
#[cfg(feature = "cuda")]
impl Clone for CudaTensorData {
    // Field-wise clone of the metadata. The buffer id is copied by value,
    // so the clone refers to the same underlying CUDA buffer — no device
    // memory is duplicated.
    fn clone(&self) -> Self {
        CudaTensorData {
            dtype: self.dtype,
            shape: self.shape.clone(),
            buffer_id: self.buffer_id,
        }
    }
}
199
/// Dynamically-typed n-dimensional tensor.
///
/// Dense CPU data is held as `ArrayD` values, one variant per element
/// type; `Sparse` wraps the crate's sparse representation; the remaining
/// variants are feature-gated handles into external backends (torch,
/// candle) or GPU buffers (Metal, CUDA).
///
/// `Clone` and `Debug` are implemented manually below rather than derived,
/// because some payloads need special treatment (e.g. `tch::Tensor` is
/// cloned via `shallow_clone`).
pub enum Tensor {
    /// Dense f32 data.
    F32(ArrayD<f32>),
    /// Dense f64 data.
    F64(ArrayD<f64>),
    /// Dense IEEE half-precision data.
    F16(ArrayD<half::f16>),
    /// Dense bfloat16 data.
    BF16(ArrayD<half::bf16>),
    /// Dense i64 data (the only dense integer variant; cf. `DType`).
    I64(ArrayD<i64>),
    /// Dense complex data with f32 components.
    C32(ArrayD<Complex32>),
    /// Dense complex data with f64 components.
    C64(ArrayD<Complex64>),
    /// Dense complex data with f16 components.
    CF16(ArrayD<Complex<half::f16>>),
    /// Dense complex data with bf16 components.
    CBF16(ArrayD<Complex<half::bf16>>),
    /// Sparse tensor (see `crate::sparse_tensor`).
    Sparse(crate::sparse_tensor::SparseTensor),
    /// Handle to a PyTorch tensor via `tch`.
    #[cfg(feature = "torch")]
    Torch(tch::Tensor),
    /// Handle to a candle tensor.
    #[cfg(feature = "candle")]
    Candle(candle_core::Tensor),
    /// Metadata handle to data resident in a Metal GPU buffer.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    Metal(MetalTensorData),
    /// Metadata handle to data resident in a CUDA device buffer.
    #[cfg(feature = "cuda")]
    CUDA(CudaTensorData),
}
227
// Manual `Clone`: not every payload is derivable (`tch::Tensor` is not
// `Clone`), and the GPU variants clone metadata only.
impl Clone for Tensor {
    fn clone(&self) -> Self {
        match self {
            // Dense variants: deep copies of the underlying arrays.
            Tensor::F32(arr) => Tensor::F32(arr.clone()),
            Tensor::F64(arr) => Tensor::F64(arr.clone()),
            Tensor::F16(arr) => Tensor::F16(arr.clone()),
            Tensor::BF16(arr) => Tensor::BF16(arr.clone()),
            Tensor::I64(arr) => Tensor::I64(arr.clone()),
            Tensor::C32(arr) => Tensor::C32(arr.clone()),
            Tensor::C64(arr) => Tensor::C64(arr.clone()),
            Tensor::CF16(arr) => Tensor::CF16(arr.clone()),
            Tensor::CBF16(arr) => Tensor::CBF16(arr.clone()),
            Tensor::Sparse(s) => Tensor::Sparse(s.clone()),
            // NOTE(review): `shallow_clone` copies the handle, not
            // (necessarily) the data — unlike the deep array clones above.
            // Confirm callers don't rely on Torch clones being independent.
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => Tensor::Torch(t.shallow_clone()),
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => Tensor::Candle(t.clone()),
            // GPU variants clone metadata only; the clone shares the same
            // device buffer id (see the `Clone` impls on the data structs).
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => Tensor::Metal(data.clone()),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => Tensor::CUDA(data.clone()),
        }
    }
}
253
// Manual `Debug`: prints shape/dtype metadata rather than element data, so
// debug-printing a large dense tensor stays cheap. (Exception: the Sparse
// arm delegates to `SparseTensor`'s own Debug — NOTE(review): confirm that
// impl also avoids dumping element data.)
impl std::fmt::Debug for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Tensor::F32(_) => write!(f, "Tensor::F32(shape: {:?}, dtype: F32)", self.shape()),
            Tensor::F64(_) => write!(f, "Tensor::F64(shape: {:?}, dtype: F64)", self.shape()),
            Tensor::F16(_) => write!(f, "Tensor::F16(shape: {:?}, dtype: F16)", self.shape()),
            Tensor::BF16(_) => write!(f, "Tensor::BF16(shape: {:?}, dtype: BF16)", self.shape()),
            Tensor::I64(_) => write!(f, "Tensor::I64(shape: {:?}, dtype: I64)", self.shape()),
            Tensor::C32(_) => write!(f, "Tensor::C32(shape: {:?}, dtype: C32)", self.shape()),
            Tensor::C64(_) => write!(f, "Tensor::C64(shape: {:?}, dtype: C64)", self.shape()),
            Tensor::CF16(_) => write!(f, "Tensor::CF16(shape: {:?}, dtype: CF16)", self.shape()),
            Tensor::CBF16(_) => write!(f, "Tensor::CBF16(shape: {:?}, dtype: CBF16)", self.shape()),
            Tensor::Sparse(s) => write!(f, "Tensor::Sparse({:?})", s),
            #[cfg(feature = "torch")]
            Tensor::Torch(_) => write!(f, "Tensor::Torch(shape: {:?})", self.shape()),
            #[cfg(feature = "candle")]
            Tensor::Candle(_) => write!(f, "Tensor::Candle(shape: {:?})", self.shape()),
            // GPU variants carry their metadata inline, so report it
            // directly instead of going through `self.shape()`.
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => write!(
                f,
                "Tensor::Metal(shape: {:?}, dtype: {:?}, buffer_id: {:?})",
                data.shape, data.dtype, data.buffer_id
            ),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => write!(
                f,
                "Tensor::CUDA(shape: {:?}, dtype: {:?}, buffer_id: {:?})",
                data.shape, data.dtype, data.buffer_id
            ),
        }
    }
}
287
// SAFETY(review): asserts that `&Tensor` may be shared across threads when
// the `torch` or `candle` feature is enabled (without those features the
// auto impl applies and this is unnecessary). Nothing visible in this file
// proves that `tch::Tensor` / `candle_core::Tensor` tolerate concurrent
// shared access — confirm against those crates' thread-safety guarantees.
// Note that only `Sync` is asserted here; `Send` is not implemented.
#[cfg(any(feature = "torch", feature = "candle"))]
unsafe impl Sync for Tensor {}
295
296impl From<ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>> for Tensor {
299 fn from(arr: ArrayD<f32>) -> Self {
300 Tensor::F32(arr)
301 }
302}
303
304impl From<ArrayBase<OwnedRepr<f64>, Dim<IxDynImpl>>> for Tensor {
305 fn from(arr: ArrayD<f64>) -> Self {
306 Tensor::F64(arr)
307 }
308}
309
// Operator sugar for addition. All combinations delegate to the inherent
// `Tensor::add(&self, &other)` (defined elsewhere in the crate, presumably
// in `math_ops`), which returns `Result` — so `a + b` yields
// `Result<Tensor>`, not `Tensor`. Fallibility is presumably for shape/dtype
// mismatches — TODO confirm against `Tensor::add`.
impl std::ops::Add for Tensor {
    type Output = Result<Tensor>;

    // Owned + owned: borrow both operands and delegate.
    fn add(self, other: Tensor) -> Self::Output {
        Tensor::add(&self, &other)
    }
}

impl std::ops::Add for &Tensor {
    type Output = Result<Tensor>;

    // Reference + reference.
    fn add(self, other: &Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}

// The `&&Tensor` combinations let code that ends up holding double
// references (e.g. from iterator adapters) still use `+`; the extra
// reference level is removed by deref coercion at the call.
impl std::ops::Add<&&Tensor> for &Tensor {
    type Output = Result<Tensor>;

    fn add(self, other: &&Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}

impl std::ops::Add<&Tensor> for &&Tensor {
    type Output = Result<Tensor>;

    fn add(self, other: &Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}
342
// Owned − owned subtraction sugar; borrows both sides and delegates to the
// inherent `Tensor::sub`. Returns `Result<Tensor>` like the other operators.
impl std::ops::Sub for Tensor {
    type Output = Result<Tensor>;

    fn sub(self, other: Tensor) -> Self::Output {
        Tensor::sub(&self, &other)
    }
}
350
// Scalar multiplication sugar, delegating to the inherent `scalar_mul`,
// which takes `f32`.
// NOTE(review): the `f64` impls narrow with `as f32`, silently losing
// precision (and saturating outside f32 range under Rust's float-cast
// rules) — confirm this is acceptable for callers passing f64 scalars.
impl std::ops::Mul<f32> for Tensor {
    type Output = Result<Tensor>;

    fn mul(self, scalar: f32) -> Self::Output {
        self.scalar_mul(scalar)
    }
}

impl std::ops::Mul<f32> for &Tensor {
    type Output = Result<Tensor>;

    fn mul(self, scalar: f32) -> Self::Output {
        self.scalar_mul(scalar)
    }
}

impl std::ops::Mul<f64> for Tensor {
    type Output = Result<Tensor>;

    // f64 scalar narrowed to f32 before delegating (see note above).
    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar as f32)
    }
}

impl std::ops::Mul<f64> for &Tensor {
    type Output = Result<Tensor>;

    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar as f32)
    }
}
383
// Tensor × tensor multiplication sugar, forwarding to the inherent
// `Tensor::mul`. The semantics (element-wise vs. matrix product) are
// defined by that method elsewhere in the crate — TODO confirm before
// documenting further.
// NOTE(review): `Tensor * Tensor` (both owned) is not implemented here,
// while the other owned/borrowed combinations are — possibly an oversight,
// or implemented elsewhere; verify before adding it.
impl std::ops::Mul<&Tensor> for &Tensor {
    type Output = Result<Tensor>;

    fn mul(self, other: &Tensor) -> Self::Output {
        Tensor::mul(self, other)
    }
}

impl std::ops::Mul<Tensor> for &Tensor {
    type Output = Result<Tensor>;

    fn mul(self, other: Tensor) -> Self::Output {
        Tensor::mul(self, &other)
    }
}

impl std::ops::Mul<&Tensor> for Tensor {
    type Output = Result<Tensor>;

    fn mul(self, other: &Tensor) -> Self::Output {
        Tensor::mul(&self, other)
    }
}
408
// Scalar division sugar, delegating to the inherent `scalar_div` (which
// takes `f32`).
// NOTE(review): as with multiplication, `f64` divisors are narrowed with
// `as f32`, silently losing precision — confirm intended.
impl std::ops::Div<f32> for Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f32) -> Self::Output {
        self.scalar_div(scalar)
    }
}

impl std::ops::Div<f32> for &Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f32) -> Self::Output {
        self.scalar_div(scalar)
    }
}

impl std::ops::Div<f64> for Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f64) -> Self::Output {
        self.scalar_div(scalar as f32)
    }
}

impl std::ops::Div<f64> for &Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f64) -> Self::Output {
        self.scalar_div(scalar as f32)
    }
}

// Double-reference divisor support (mirrors the `&&Tensor` Add impls);
// `*self` peels one reference so method resolution finds `scalar_div`.
impl std::ops::Div<f64> for &&Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f64) -> Self::Output {
        (*self).scalar_div(scalar as f32)
    }
}
449
// Reference − reference subtraction sugar, mirroring the owned `Sub` impl;
// delegates to the inherent `Tensor::sub`.
impl std::ops::Sub for &Tensor {
    type Output = Result<Tensor>;

    fn sub(self, other: &Tensor) -> Self::Output {
        Tensor::sub(self, other)
    }
}
458
/// Alias for [`DType`] — presumably retained for backward compatibility
/// with older call sites (TODO confirm before removing).
pub type TensorType = DType;

// Re-export the expression-graph API from the private `expression` module.
pub use expression::{EvalContext, ExprNode, OpType, OptimizationHints, TensorExpr};

// Re-export gradient-mode toggles from the private `utils` module.
pub use utils::{clear_gradients, disable_grad, enable_grad, is_grad_enabled};