Skip to main content

rustorch_core/ops/
embedding.rs

1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::Tensor;
4use rayon::prelude::*;
5use std::sync::Arc;
6
/// Autograd node recorded by [`embedding`]: routes the output gradient back into the
/// rows of `weight` that were looked up in the forward pass.
#[derive(Debug)]
pub struct EmbeddingBackward {
    /// Index tensor from the forward pass; indices are stored as `f32` values.
    pub input: Tensor, // indices
    /// Embedding table of shape `(num_embeddings, embedding_dim)` that receives the gradient.
    pub weight: Tensor,
    /// Number of rows in `weight`.
    pub num_embeddings: usize,
    /// Width of each embedding row (the second dimension of `weight`).
    pub embedding_dim: usize,
    /// Index whose lookups contribute no gradient, if any.
    pub padding_idx: Option<usize>,
}
15
16impl BackwardOp for EmbeddingBackward {
17    fn backward(&self, grad: &Tensor) {
18        if self.weight.requires_grad() {
19            // Grad is (N, *, EmbeddingDim)
20            // Input is (N, *) indices
21            // We need to scatter add grad to weight.grad
22            // weight.grad shape: (NumEmbeddings, EmbeddingDim)
23
24            // This is sparse update.
25            // For simplicity, we can iterate over indices and accumulate.
26            // But we need to lock weight.grad.
27
28            // Let's create a dense grad tensor for weight first (inefficient but simple)
29            // Or better: accumulate directly if possible.
30            // Tensor::accumulate_grad expects a Tensor.
31
32            // We need to implement a "SparseAccumulate" or just create a dense Zero tensor and fill it.
33            let mut weight_grad_data = vec![0.0; self.num_embeddings * self.embedding_dim];
34
35            #[cfg(feature = "wgpu_backend")]
36            let (grad, input) = {
37                let g = if grad.storage().device().is_wgpu() {
38                    grad.to_cpu()
39                } else {
40                    grad.clone()
41                };
42                let i = if self.input.storage().device().is_wgpu() {
43                    self.input.to_cpu()
44                } else {
45                    self.input.clone()
46                };
47                (g, i)
48            };
49            #[cfg(not(feature = "wgpu_backend"))]
50            let input = self.input.clone();
51
52            let grad_guard = grad.data();
53            let grad_data = &*grad_guard;
54
55            let input_guard = input.data(); // These are f32, need to cast to usize
56            let input_data = &*input_guard;
57
58            // Check shapes
59            // Input: (B...)
60            // Grad: (B..., Dim)
61            // Input len * Dim == Grad len
62            let num_indices = input_data.len();
63            let dim = self.embedding_dim;
64
65            if grad_data.len() != num_indices * dim {
66                panic!("Embedding backward shape mismatch");
67            }
68
69            // Iterate and accumulate
70            // This part is hard to parallelize without atomic adds on weight_grad_data.
71            // So run serial or use localized buffers. Serial for now.
72            for (i, &idx_f) in input_data.iter().enumerate() {
73                let idx = idx_f as usize;
74                if let Some(pad) = self.padding_idx {
75                    if idx == pad {
76                        continue;
77                    }
78                }
79                if idx >= self.num_embeddings {
80                    // Index out of bounds, ignore or panic? PyTorch panics or errors.
81                    continue;
82                }
83
84                let grad_offset = i * dim;
85                let weight_offset = idx * dim;
86
87                for j in 0..dim {
88                    weight_grad_data[weight_offset + j] += grad_data[grad_offset + j];
89                }
90            }
91
92            let weight_grad = Tensor::new(&weight_grad_data, self.weight.shape());
93            self.weight.accumulate_grad(&weight_grad);
94            self.weight.backward_step();
95        }
96    }
97}
98
99pub fn embedding(
100    input: &Tensor,
101    weight: &Tensor,
102    padding_idx: Option<usize>,
103    _max_norm: Option<f32>,
104    _norm_type: f32,
105    _scale_grad_by_freq: bool,
106    _sparse: bool,
107) -> Tensor {
108    // Input: Indices (Arbitrary Shape) -> but stored as f32 in Tensor
109    // Weight: (NumEmbeddings, EmbeddingDim)
110    // Output: (InputShape..., EmbeddingDim)
111
112    let weight_shape = weight.shape();
113    if weight_shape.len() != 2 {
114        panic!("Embedding weight must be 2D");
115    }
116    let num_embeddings = weight_shape[0];
117    let embedding_dim = weight_shape[1];
118
119    #[cfg(feature = "wgpu_backend")]
120    let (input, weight) = {
121        let i = if input.storage().device().is_wgpu() {
122            input.to_cpu()
123        } else {
124            input.clone()
125        };
126        let w = if weight.storage().device().is_wgpu() {
127            weight.to_cpu()
128        } else {
129            weight.clone()
130        };
131        (i, w)
132    };
133    #[cfg(not(feature = "wgpu_backend"))]
134    let (input, weight) = (input.clone(), weight.clone());
135
136    let input_guard = input.data();
137    let input_data = &*input_guard;
138
139    let weight_guard = weight.data();
140    let weight_data = &*weight_guard;
141
142    let num_indices = input_data.len();
143    let mut output_data = vec![0.0; num_indices * embedding_dim];
144
145    // Parallel lookup
146    output_data
147        .par_chunks_mut(embedding_dim)
148        .enumerate()
149        .for_each(|(i, out_row)| {
150            let idx_f = input_data[i];
151            let idx = idx_f as usize;
152
153            if idx >= num_embeddings {
154                // Panic in real scenario
155                // panic!("Index {} out of bounds for embedding size {}", idx, num_embeddings);
156                // But inside parallel iterator panic is messy.
157                // Let's just fill 0 or clamp?
158                // PyTorch: runtime error.
159                return;
160            }
161
162            if let Some(pad) = padding_idx {
163                if idx == pad {
164                    // Zero vector
165                    out_row.fill(0.0);
166                    return;
167                }
168            }
169
170            let weight_offset = idx * embedding_dim;
171            let w_row = &weight_data[weight_offset..weight_offset + embedding_dim];
172            out_row.copy_from_slice(w_row);
173        });
174
175    let mut output_shape = input.shape().to_vec();
176    output_shape.push(embedding_dim);
177
178    let storage = Storage::new(output_data);
179    let mut tensor = Tensor::new_with_storage(storage, &output_shape);
180
181    if weight.requires_grad() {
182        tensor.set_requires_grad_mut(true);
183        tensor.set_op(Arc::new(EmbeddingBackward {
184            input: input.clone(),
185            weight: weight.clone(),
186            num_embeddings,
187            embedding_dim,
188            padding_idx,
189        }));
190    }
191
192    tensor
193}