Skip to main content

svod_tensor/rand/
like.rs

//! `Tensor::rand_like` / `Tensor::randn_like` — convenience wrappers that
//! inherit shape, dtype, and device from an existing tensor.
//!
//! Used by nn modules that want to sample noise matching some reference
//! tensor (dropout masks, noise injection, gaussian-init layers, etc.).
6
7use snafu::ResultExt;
8use svod_dtype::DType;
9use svod_ir::shape::to_vec_usize;
10
11use crate::{Result, Tensor, UOpSnafu};
12
13impl Tensor {
14    /// `rand_like` with a dtype override (device and shape still inherited).
15    pub fn rand_like_with_dtype(&self, dtype: DType) -> Result<Tensor> {
16        let shape = to_vec_usize(&self.shape()?).context(UOpSnafu)?;
17        Self::rand_with(&shape, dtype, self.device())
18    }
19
20    /// Uniform `[0, 1)` random tensor with the same shape/dtype/device as `self`.
21    pub fn rand_like(&self) -> Result<Tensor> {
22        self.rand_like_with_dtype(self.uop().dtype())
23    }
24
25    /// `randn_like` with a dtype override.
26    ///
27    /// Internally generates f32 samples via Box-Muller, then casts to the
28    /// target dtype. Using f32 inside Box-Muller keeps cos/log/sqrt accurate
29    /// even when the caller wants low-precision output.
30    pub fn randn_like_with_dtype(&self, dtype: DType) -> Result<Tensor> {
31        let shape = to_vec_usize(&self.shape()?).context(UOpSnafu)?;
32        Tensor::randn(&shape)?.cast(dtype)
33    }
34
35    /// Standard normal `N(0, 1)` random tensor with the same shape/dtype/device as `self`.
36    pub fn randn_like(&self) -> Result<Tensor> {
37        self.randn_like_with_dtype(self.uop().dtype())
38    }
39
40    /// Uniform integer `[low, high)` random tensor with the same shape/dtype/device as `self`.
41    ///
42    /// The underlying `Tensor::randint` returns `Int32`; if `self`'s dtype
43    /// differs the result is cast to match (e.g. `Int64` template → `Int64`
44    /// result). Requires `low < high`.
45    pub fn randint_like(&self, low: i64, high: i64) -> Result<Tensor> {
46        let shape = to_vec_usize(&self.shape()?).context(UOpSnafu)?;
47        let r = Tensor::randint(&shape, low, high)?;
48        if r.uop().dtype() == self.uop().dtype() { Ok(r) } else { r.cast(self.uop().dtype()) }
49    }
50}