// rustorch_core/ops/embedding.rs
use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::Tensor;
4use rayon::prelude::*;
5use std::sync::Arc;
6
/// Saved context for the backward pass of the `embedding` op.
///
/// Scatters each row of the incoming output gradient back onto the
/// embedding-table row selected by the corresponding index in `input`.
#[derive(Debug)]
pub struct EmbeddingBackward {
    // Index tensor from the forward pass (indices live in the float buffer).
    pub input: Tensor,
    // The embedding table; receives the accumulated gradient.
    pub weight: Tensor,
    // Number of rows in the embedding table.
    pub num_embeddings: usize,
    // Width of one embedding row.
    pub embedding_dim: usize,
    // Row whose gradient is skipped (stays zero), if any.
    pub padding_idx: Option<usize>,
}
15
16impl BackwardOp for EmbeddingBackward {
17 fn backward(&self, grad: &Tensor) {
18 if self.weight.requires_grad() {
19 let mut weight_grad_data = vec![0.0; self.num_embeddings * self.embedding_dim];
34
35 #[cfg(feature = "wgpu_backend")]
36 let (grad, input) = {
37 let g = if grad.storage().device().is_wgpu() {
38 grad.to_cpu()
39 } else {
40 grad.clone()
41 };
42 let i = if self.input.storage().device().is_wgpu() {
43 self.input.to_cpu()
44 } else {
45 self.input.clone()
46 };
47 (g, i)
48 };
49 #[cfg(not(feature = "wgpu_backend"))]
50 let input = self.input.clone();
51
52 let grad_guard = grad.data();
53 let grad_data = &*grad_guard;
54
55 let input_guard = input.data(); let input_data = &*input_guard;
57
58 let num_indices = input_data.len();
63 let dim = self.embedding_dim;
64
65 if grad_data.len() != num_indices * dim {
66 panic!("Embedding backward shape mismatch");
67 }
68
69 for (i, &idx_f) in input_data.iter().enumerate() {
73 let idx = idx_f as usize;
74 if let Some(pad) = self.padding_idx {
75 if idx == pad {
76 continue;
77 }
78 }
79 if idx >= self.num_embeddings {
80 continue;
82 }
83
84 let grad_offset = i * dim;
85 let weight_offset = idx * dim;
86
87 for j in 0..dim {
88 weight_grad_data[weight_offset + j] += grad_data[grad_offset + j];
89 }
90 }
91
92 let weight_grad = Tensor::new(&weight_grad_data, self.weight.shape());
93 self.weight.accumulate_grad(&weight_grad);
94 self.weight.backward_step();
95 }
96 }
97}
98
99pub fn embedding(
100 input: &Tensor,
101 weight: &Tensor,
102 padding_idx: Option<usize>,
103 _max_norm: Option<f32>,
104 _norm_type: f32,
105 _scale_grad_by_freq: bool,
106 _sparse: bool,
107) -> Tensor {
108 let weight_shape = weight.shape();
113 if weight_shape.len() != 2 {
114 panic!("Embedding weight must be 2D");
115 }
116 let num_embeddings = weight_shape[0];
117 let embedding_dim = weight_shape[1];
118
119 #[cfg(feature = "wgpu_backend")]
120 let (input, weight) = {
121 let i = if input.storage().device().is_wgpu() {
122 input.to_cpu()
123 } else {
124 input.clone()
125 };
126 let w = if weight.storage().device().is_wgpu() {
127 weight.to_cpu()
128 } else {
129 weight.clone()
130 };
131 (i, w)
132 };
133 #[cfg(not(feature = "wgpu_backend"))]
134 let (input, weight) = (input.clone(), weight.clone());
135
136 let input_guard = input.data();
137 let input_data = &*input_guard;
138
139 let weight_guard = weight.data();
140 let weight_data = &*weight_guard;
141
142 let num_indices = input_data.len();
143 let mut output_data = vec![0.0; num_indices * embedding_dim];
144
145 output_data
147 .par_chunks_mut(embedding_dim)
148 .enumerate()
149 .for_each(|(i, out_row)| {
150 let idx_f = input_data[i];
151 let idx = idx_f as usize;
152
153 if idx >= num_embeddings {
154 return;
160 }
161
162 if let Some(pad) = padding_idx {
163 if idx == pad {
164 out_row.fill(0.0);
166 return;
167 }
168 }
169
170 let weight_offset = idx * embedding_dim;
171 let w_row = &weight_data[weight_offset..weight_offset + embedding_dim];
172 out_row.copy_from_slice(w_row);
173 });
174
175 let mut output_shape = input.shape().to_vec();
176 output_shape.push(embedding_dim);
177
178 let storage = Storage::new(output_data);
179 let mut tensor = Tensor::new_with_storage(storage, &output_shape);
180
181 if weight.requires_grad() {
182 tensor.set_requires_grad_mut(true);
183 tensor.set_op(Arc::new(EmbeddingBackward {
184 input: input.clone(),
185 weight: weight.clone(),
186 num_embeddings,
187 embedding_dim,
188 padding_idx,
189 }));
190 }
191
192 tensor
193}