tiny_recursive_rs/
utils.rs

//! Utility functions for TRM.
2use candle_core::{Result, Tensor, Device};
3use num_traits::Float;
4
5/// Truncated normal initialization
6pub fn trunc_normal_init<F: Float>(std: F, a: F, b: F) -> impl Fn() -> F {
7    move || {
8        // TODO: Implement truncated normal
9        // For now, return 0
10        F::zero()
11    }
12}
13
14/// Calculate the number of parameters in a tensor
15pub fn count_parameters(tensor: &Tensor) -> usize {
16    tensor.dims().iter().product()
17}
18
19/// Create a causal mask for attention
20pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
21    // TODO: Implement causal mask creation
22    todo!("Implement causal mask")
23}