tiny_recursive_rs/utils.rs
//! Utility functions for TRM.
2use candle_core::{Result, Tensor, Device};
3use num_traits::Float;
4
5/// Truncated normal initialization
6pub fn trunc_normal_init<F: Float>(std: F, a: F, b: F) -> impl Fn() -> F {
7 move || {
8 // TODO: Implement truncated normal
9 // For now, return 0
10 F::zero()
11 }
12}
13
14/// Calculate the number of parameters in a tensor
15pub fn count_parameters(tensor: &Tensor) -> usize {
16 tensor.dims().iter().product()
17}
18
19/// Create a causal mask for attention
20pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
21 // TODO: Implement causal mask creation
22 todo!("Implement causal mask")
23}