svod_tensor/rand/like.rs
//! `Tensor::rand_like` / `Tensor::randn_like` / `Tensor::randint_like` —
//! convenience wrappers that inherit shape and dtype (and, for `rand_like`,
//! device) from an existing tensor.
3//!
4//! Used by nn modules that want to sample noise matching some reference
5//! tensor (dropout masks, noise injection, gaussian-init layers, etc.).
6
7use snafu::ResultExt;
8use svod_dtype::DType;
9use svod_ir::shape::to_vec_usize;
10
11use crate::{Result, Tensor, UOpSnafu};
12
13impl Tensor {
14 /// `rand_like` with a dtype override (device and shape still inherited).
15 pub fn rand_like_with_dtype(&self, dtype: DType) -> Result<Tensor> {
16 let shape = to_vec_usize(&self.shape()?).context(UOpSnafu)?;
17 Self::rand_with(&shape, dtype, self.device())
18 }
19
20 /// Uniform `[0, 1)` random tensor with the same shape/dtype/device as `self`.
21 pub fn rand_like(&self) -> Result<Tensor> {
22 self.rand_like_with_dtype(self.uop().dtype())
23 }
24
25 /// `randn_like` with a dtype override.
26 ///
27 /// Internally generates f32 samples via Box-Muller, then casts to the
28 /// target dtype. Using f32 inside Box-Muller keeps cos/log/sqrt accurate
29 /// even when the caller wants low-precision output.
30 pub fn randn_like_with_dtype(&self, dtype: DType) -> Result<Tensor> {
31 let shape = to_vec_usize(&self.shape()?).context(UOpSnafu)?;
32 Tensor::randn(&shape)?.cast(dtype)
33 }
34
35 /// Standard normal `N(0, 1)` random tensor with the same shape/dtype/device as `self`.
36 pub fn randn_like(&self) -> Result<Tensor> {
37 self.randn_like_with_dtype(self.uop().dtype())
38 }
39
40 /// Uniform integer `[low, high)` random tensor with the same shape/dtype/device as `self`.
41 ///
42 /// The underlying `Tensor::randint` returns `Int32`; if `self`'s dtype
43 /// differs the result is cast to match (e.g. `Int64` template → `Int64`
44 /// result). Requires `low < high`.
45 pub fn randint_like(&self, low: i64, high: i64) -> Result<Tensor> {
46 let shape = to_vec_usize(&self.shape()?).context(UOpSnafu)?;
47 let r = Tensor::randint(&shape, low, high)?;
48 if r.uop().dtype() == self.uop().dtype() { Ok(r) } else { r.cast(self.uop().dtype()) }
49 }
50}