//! Masked fill operation for tensors (train_station/tensor/indexing/masked_fill.rs).
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Fill masked elements with a specified value
6 ///
7 /// This operation returns a copy of the input tensor where elements are replaced
8 /// by the specified value wherever the corresponding boolean mask is true.
9 /// Elements where the mask is false retain their original values from the input tensor.
10 ///
11 /// The masked_fill operation is commonly used in machine learning for operations
12 /// like masking attention weights, zeroing out specific elements, and implementing
13 /// dropout-like functionality.
14 ///
15 /// # Arguments
16 ///
17 /// * `mask` - Boolean array with the same length as the number of tensor elements
18 /// * `value` - The value to fill masked positions with
19 ///
20 /// # Returns
21 ///
22 /// A new tensor with the same shape as the input, where masked elements are
23 /// replaced by `value` and unmasked elements retain their original values
24 ///
25 /// # Examples
26 ///
27 /// ## Basic Masked Fill
28 ///
29 /// ```
30 /// use train_station::Tensor;
31 ///
32 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34 ///
35 /// // Create a mask: [false, true, false, true, false, true]
36 /// let mask = [false, true, false, true, false, true];
37 /// let result = tensor.masked_fill(&mask, -1.0);
38 ///
39 /// // Result: [[0.0, -1.0, 2.0], [-1.0, 4.0, -1.0]]
40 /// assert_eq!(result.shape().dims, vec![2, 3]);
41 /// assert_eq!(result.get(&[0, 0]), 0.0); // Unmasked
42 /// assert_eq!(result.get(&[0, 1]), -1.0); // Masked
43 /// assert_eq!(result.get(&[0, 2]), 2.0); // Unmasked
44 /// assert_eq!(result.get(&[1, 0]), -1.0); // Masked
45 /// assert_eq!(result.get(&[1, 1]), 4.0); // Unmasked
46 /// assert_eq!(result.get(&[1, 2]), -1.0); // Masked
47 /// ```
48 ///
49 /// ## Masked Fill with Gradient Tracking
50 ///
51 /// ```
52 /// use train_station::Tensor;
53 ///
54 /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
55 /// .with_requires_grad();
56 ///
57 /// // Create a mask with some true values
58 /// let mask = [false, true, false, true, false, false];
59 /// let mut result = tensor.masked_fill(&mask, 5.0);
60 ///
61 /// // Compute gradients
62 /// result.backward(None);
63 /// let grad = tensor.grad_by_value().expect("gradient missing");
64 ///
65 /// // Gradients should be zero where mask is true, 1 elsewhere
66 /// assert_eq!(grad.shape().dims, vec![2, 3]);
67 /// assert!((grad.get(&[0, 0]) - 1.0).abs() < 1e-6); // Unmasked: gradient flows
68 /// assert!((grad.get(&[0, 1]) - 0.0).abs() < 1e-6); // Masked: no gradient
69 /// assert!((grad.get(&[0, 2]) - 1.0).abs() < 1e-6); // Unmasked: gradient flows
70 /// ```
71 ///
72 /// ## Zeroing Out Specific Elements
73 ///
74 /// ```
75 /// use train_station::Tensor;
76 ///
77 /// // Create a tensor with some values
78 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
79 ///
80 /// // Create a mask to zero out every other element
81 /// let mask = [true, false, true, false, true, false];
82 /// let result = tensor.masked_fill(&mask, 0.0);
83 ///
84 /// // Result: [[0.0, 2.0, 0.0], [4.0, 0.0, 6.0]]
85 /// assert_eq!(result.get(&[0, 0]), 0.0); // Zeroed
86 /// assert_eq!(result.get(&[0, 1]), 2.0); // Kept
87 /// assert_eq!(result.get(&[0, 2]), 0.0); // Zeroed
88 /// assert_eq!(result.get(&[1, 0]), 4.0); // Kept
89 /// assert_eq!(result.get(&[1, 1]), 0.0); // Zeroed
90 /// assert_eq!(result.get(&[1, 2]), 6.0); // Kept
91 /// ```
92 ///
93 /// # Performance Characteristics
94 ///
95 /// - **Time Complexity**: O(n) where n is the number of elements in the tensor
96 /// - **Memory Usage**: Creates a new tensor with the same size as the input
97 /// - **Optimization**: Uses efficient stride-based iteration for non-contiguous tensors
98 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
99 /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
100 ///
101 /// # Implementation Details
102 ///
103 /// The masked_fill operation works by:
104 /// 1. Validating that the mask length equals the number of tensor elements
105 /// 2. Creating a new contiguous output tensor with the same shape
106 /// 3. Iterating through all elements in logical order
107 /// 4. For each element, checking the corresponding mask value:
108 /// - If mask is true: use the fill value
109 /// - If mask is false: copy the original value from input tensor
110 /// 5. Computing source offsets using the input tensor's shape for non-contiguous tensors
111 /// 6. Registering the operation for gradient computation if needed
112 ///
113 /// # Safety
114 ///
115 /// This function performs bounds checking to ensure:
116 /// - The mask length equals the number of tensor elements
117 /// - Memory access is safe through proper offset calculations
118 /// - The operation handles both contiguous and non-contiguous tensors correctly
119 ///
120 /// # Panics
121 ///
122 /// This function will panic if:
123 /// - The mask length does not equal the number of tensor elements
124 ///
125 /// # Thread Safety
126 ///
127 /// This function is thread-safe and can be called concurrently on different tensors.
128 /// The operation does not modify the input tensor and creates a new output tensor.
129 ///
130 /// # GradTrack Behavior
131 ///
132 /// When gradient tracking is enabled:
133 /// - Gradients do not flow through masked positions (they are zeroed)
134 /// - Gradients flow normally through unmasked positions
135 /// - This behavior is useful for implementing operations like dropout
136 #[track_caller]
137 pub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor {
138 let numel = self.size();
139 assert_eq!(
140 mask.len(),
141 numel,
142 "mask length {} must equal tensor elements {}",
143 mask.len(),
144 numel
145 );
146
147 // Output is a contiguous copy with applied mask
148 let mut output = Tensor::new(self.shape().dims.clone());
149
150 // Iterate in logical order using strides if needed
151 let rank = self.shape().rank();
152 let mut coords = vec![0usize; rank];
153 for (lin, &m) in mask.iter().enumerate().take(numel) {
154 // Decode logical coords for mask mapping
155 let mut tmp = lin;
156 for i in (0..rank).rev() {
157 let s = self.shape().dims[i];
158 coords[i] = if s == 0 { 0 } else { tmp % s };
159 tmp /= s;
160 }
161 let src_off = self.shape().offset(&coords);
162 unsafe {
163 *output.as_mut_ptr().add(lin) = if m {
164 value
165 } else {
166 *self.as_ptr().add(src_off)
167 };
168 }
169 }
170
171 if self.requires_grad() {
172 output.set_requires_grad(true);
173 let grad_fn = GradFn::MaskedFill {
174 mask: mask.to_vec(),
175 input_shape: self.shape().dims.clone(),
176 };
177 output.set_grad_fn(grad_fn.clone());
178 GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
179 }
180
181 output
182 }
183}
184
185#[cfg(test)]
186mod tests {
187 use super::*;
188
189 #[test]
190 fn test_masked_fill_basic() {
191 let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
192 let mask = vec![false, true, false, true, false, true];
193 let y = x.masked_fill(&mask, -1.0);
194 assert_eq!(y.shape().dims, vec![2, 3]);
195 assert_eq!(y.get(&[0, 0]), 0.0);
196 assert_eq!(y.get(&[0, 1]), -1.0);
197 assert_eq!(y.get(&[1, 0]), -1.0);
198 }
199
200 #[test]
201 fn test_masked_fill_gradients() {
202 let x = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
203 .unwrap()
204 .with_requires_grad();
205 let mask = vec![false, true, false, true, false, false];
206 let mut y = x.masked_fill(&mask, 5.0);
207 y.backward(None);
208 let gx = x.grad_by_value().expect("grad missing");
209 // Grad should be zero where mask is true, 1 elsewhere (from upstream ones)
210 for (i, &m) in mask.iter().enumerate().take(6) {
211 let expected = if m { 0.0 } else { 1.0 };
212 assert!((gx.get(&[i / 3, i % 3]) - expected).abs() < 1e-6);
213 }
214 }
215
216 #[test]
217 fn test_masked_fill_coordinate_decoding_bug_fix() {
218 // This test specifically targets the coordinate decoding bug that was fixed
219 // Create a 3D tensor to test multi-dimensional coordinate mapping
220 let data = vec![
221 1.0, 2.0, 3.0, 4.0, // [0, 0, :]
222 5.0, 6.0, 7.0, 8.0, // [0, 1, :]
223 9.0, 10.0, 11.0, 12.0, // [1, 0, :]
224 13.0, 14.0, 15.0, 16.0, // [1, 1, :]
225 ];
226 let tensor = Tensor::from_slice(&data, vec![2, 2, 4]).unwrap();
227
228 // Create a mask that selects specific positions
229 // Position 0: [0,0,0] -> 1.0, Position 5: [0,1,1] -> 6.0, Position 10: [1,0,2] -> 11.0
230 let mut mask = vec![false; 16];
231 mask[0] = true; // [0,0,0]
232 mask[5] = true; // [0,1,1]
233 mask[10] = true; // [1,0,2]
234
235 let result = tensor.masked_fill(&mask, 99.0);
236
237 // Verify the correct positions were masked
238 assert_eq!(result.get(&[0, 0, 0]), 99.0); // position 0, was 1.0
239 assert_eq!(result.get(&[0, 1, 1]), 99.0); // position 5, was 6.0
240 assert_eq!(result.get(&[1, 0, 2]), 99.0); // position 10, was 11.0
241
242 // Verify unmasked positions remain unchanged
243 assert_eq!(result.get(&[0, 0, 1]), 2.0); // position 1, unchanged
244 assert_eq!(result.get(&[0, 1, 0]), 5.0); // position 4, unchanged
245 assert_eq!(result.get(&[1, 1, 3]), 16.0); // position 15, unchanged
246 }
247}