//! Masked fill operation for tensors.
//!
//! Path: train_station/tensor/indexing/masked_fill.rs
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Fill masked elements with a specified value
6 ///
7 /// This operation returns a copy of the input tensor where elements are replaced
8 /// by the specified value wherever the corresponding boolean mask is true.
9 /// Elements where the mask is false retain their original values from the input tensor.
10 ///
11 /// The masked_fill operation is commonly used in machine learning for operations
12 /// like masking attention weights, zeroing out specific elements, and implementing
13 /// dropout-like functionality.
14 ///
15 /// # Arguments
16 ///
17 /// * `mask` - Boolean array with the same length as the number of tensor elements
18 /// * `value` - The value to fill masked positions with
19 ///
20 /// # Returns
21 ///
22 /// A new tensor with the same shape as the input, where masked elements are
23 /// replaced by `value` and unmasked elements retain their original values
24 ///
25 /// # Examples
26 ///
27 /// ## Basic Masked Fill
28 ///
29 /// ```
30 /// use train_station::Tensor;
31 ///
32 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34 ///
35 /// // Create a mask: [false, true, false, true, false, true]
36 /// let mask = [false, true, false, true, false, true];
37 /// let result = tensor.masked_fill(&mask, -1.0);
38 ///
39 /// // Result: [[0.0, -1.0, 2.0], [-1.0, 4.0, -1.0]]
40 /// assert_eq!(result.shape().dims, vec![2, 3]);
41 /// assert_eq!(result.get(&[0, 0]), 0.0); // Unmasked
42 /// assert_eq!(result.get(&[0, 1]), -1.0); // Masked
43 /// assert_eq!(result.get(&[0, 2]), 2.0); // Unmasked
44 /// assert_eq!(result.get(&[1, 0]), -1.0); // Masked
45 /// assert_eq!(result.get(&[1, 1]), 4.0); // Unmasked
46 /// assert_eq!(result.get(&[1, 2]), -1.0); // Masked
47 /// ```
48 ///
49 /// ## Masked Fill with Gradient Tracking
50 ///
51 /// ```
52 /// use train_station::Tensor;
53 ///
54 /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
55 /// .with_requires_grad();
56 ///
57 /// // Create a mask with some true values
58 /// let mask = [false, true, false, true, false, false];
59 /// let mut result = tensor.masked_fill(&mask, 5.0);
60 ///
61 /// // Compute gradients
62 /// result.backward(None);
63 /// let grad = tensor.grad_by_value().expect("gradient missing");
64 ///
65 /// // Gradients should be zero where mask is true, 1 elsewhere
66 /// assert_eq!(grad.shape().dims, vec![2, 3]);
67 /// assert!((grad.get(&[0, 0]) - 1.0).abs() < 1e-6); // Unmasked: gradient flows
68 /// assert!((grad.get(&[0, 1]) - 0.0).abs() < 1e-6); // Masked: no gradient
69 /// assert!((grad.get(&[0, 2]) - 1.0).abs() < 1e-6); // Unmasked: gradient flows
70 /// ```
71 ///
72 /// ## Zeroing Out Specific Elements
73 ///
74 /// ```
75 /// use train_station::Tensor;
76 ///
77 /// // Create a tensor with some values
78 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
79 ///
80 /// // Create a mask to zero out every other element
81 /// let mask = [true, false, true, false, true, false];
82 /// let result = tensor.masked_fill(&mask, 0.0);
83 ///
84 /// // Result: [[0.0, 2.0, 0.0], [4.0, 0.0, 6.0]]
85 /// assert_eq!(result.get(&[0, 0]), 0.0); // Zeroed
86 /// assert_eq!(result.get(&[0, 1]), 2.0); // Kept
87 /// assert_eq!(result.get(&[0, 2]), 0.0); // Zeroed
88 /// assert_eq!(result.get(&[1, 0]), 4.0); // Kept
89 /// assert_eq!(result.get(&[1, 1]), 0.0); // Zeroed
90 /// assert_eq!(result.get(&[1, 2]), 6.0); // Kept
91 /// ```
92 ///
93 /// # Performance Characteristics
94 ///
95 /// - **Time Complexity**: O(n) where n is the number of elements in the tensor
96 /// - **Memory Usage**: Creates a new tensor with the same size as the input
97 /// - **Optimization**: Uses efficient stride-based iteration for non-contiguous tensors
98 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
99 /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
100 ///
101 /// # Implementation Details
102 ///
103 /// The masked_fill operation works by:
104 /// 1. Validating that the mask length equals the number of tensor elements
105 /// 2. Creating a new contiguous output tensor with the same shape
106 /// 3. Iterating through all elements in logical order
107 /// 4. For each element, checking the corresponding mask value:
108 /// - If mask is true: use the fill value
109 /// - If mask is false: copy the original value from input tensor
110 /// 5. Computing source offsets using the input tensor's shape for non-contiguous tensors
111 /// 6. Registering the operation for gradient computation if needed
112 ///
113 /// # Safety
114 ///
115 /// This function performs bounds checking to ensure:
116 /// - The mask length equals the number of tensor elements
117 /// - Memory access is safe through proper offset calculations
118 /// - The operation handles both contiguous and non-contiguous tensors correctly
119 ///
120 /// # Panics
121 ///
122 /// This function will panic if:
123 /// - The mask length does not equal the number of tensor elements
124 ///
125 /// # Thread Safety
126 ///
127 /// This function is thread-safe and can be called concurrently on different tensors.
128 /// The operation does not modify the input tensor and creates a new output tensor.
129 ///
130 /// # GradTrack Behavior
131 ///
132 /// When gradient tracking is enabled:
133 /// - Gradients do not flow through masked positions (they are zeroed)
134 /// - Gradients flow normally through unmasked positions
135 /// - This behavior is useful for implementing operations like dropout
136 pub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor {
137 let numel = self.size();
138 assert_eq!(
139 mask.len(),
140 numel,
141 "mask length {} must equal tensor elements {}",
142 mask.len(),
143 numel
144 );
145
146 // Output is a contiguous copy with applied mask
147 let mut output = Tensor::new(self.shape().dims.clone());
148
149 // Iterate in logical order using strides if needed
150 let rank = self.shape().rank();
151 let mut coords = vec![0usize; rank];
152 for (lin, &m) in mask.iter().enumerate().take(numel) {
153 // Decode logical coords for mask mapping
154 let mut tmp = lin;
155 for i in (0..rank).rev() {
156 let s = self.shape().dims[i];
157 coords[i] = if s == 0 { 0 } else { tmp % s };
158 tmp /= s;
159 }
160 let src_off = self.shape().offset(&coords);
161 unsafe {
162 *output.as_mut_ptr().add(lin) = if m {
163 value
164 } else {
165 *self.as_ptr().add(src_off)
166 };
167 }
168 }
169
170 if self.requires_grad() {
171 output.set_requires_grad(true);
172 let grad_fn = GradFn::MaskedFill {
173 mask: mask.to_vec(),
174 input_shape: self.shape().dims.clone(),
175 };
176 output.set_grad_fn(grad_fn.clone());
177 GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
178 }
179
180 output
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Masked positions take the fill value; unmasked positions keep their data.
    #[test]
    fn test_masked_fill_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let mask = [false, true, false, true, false, true];

        let output = input.masked_fill(&mask, -1.0);

        assert_eq!(output.shape().dims, vec![2, 3]);
        assert_eq!(output.get(&[0, 0]), 0.0);
        assert_eq!(output.get(&[0, 1]), -1.0);
        assert_eq!(output.get(&[1, 0]), -1.0);
    }

    /// With an upstream gradient of ones, the gradient is blocked (0.0) at
    /// masked positions and passes through (1.0) everywhere else.
    #[test]
    fn test_masked_fill_gradients() {
        let input = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let mask = [false, true, false, true, false, false];

        let mut output = input.masked_fill(&mask, 5.0);
        output.backward(None);

        let grad = input.grad_by_value().expect("grad missing");
        for (flat, &masked) in mask.iter().enumerate() {
            let expected = if masked { 0.0 } else { 1.0 };
            let actual = grad.get(&[flat / 3, flat % 3]);
            assert!((actual - expected).abs() < 1e-6);
        }
    }

    /// Regression test for linear-index -> multi-dimensional coordinate
    /// decoding: masked linear positions must land on the right 3-D
    /// coordinates, and everything else must be untouched.
    #[test]
    fn test_masked_fill_coordinate_decoding_bug_fix() {
        // 2x2x4 tensor holding 1.0..=16.0 in row-major order.
        let values: Vec<f32> = (1..=16).map(|v| v as f32).collect();
        let tensor = Tensor::from_slice(&values, vec![2, 2, 4]).unwrap();

        // Mark three scattered linear positions:
        // 0 -> [0,0,0] (1.0), 5 -> [0,1,1] (6.0), 10 -> [1,0,2] (11.0)
        let mut mask = [false; 16];
        for &pos in &[0usize, 5, 10] {
            mask[pos] = true;
        }

        let filled = tensor.masked_fill(&mask, 99.0);

        // The masked coordinates carry the fill value.
        assert_eq!(filled.get(&[0, 0, 0]), 99.0); // linear position 0
        assert_eq!(filled.get(&[0, 1, 1]), 99.0); // linear position 5
        assert_eq!(filled.get(&[1, 0, 2]), 99.0); // linear position 10

        // Unmasked coordinates keep their original values.
        assert_eq!(filled.get(&[0, 0, 1]), 2.0); // linear position 1
        assert_eq!(filled.get(&[0, 1, 0]), 5.0); // linear position 4
        assert_eq!(filled.get(&[1, 1, 3]), 16.0); // linear position 15
    }
}
246}