train_station/tensor/indexing/masked_fill.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Fill masked elements with a specified value
6    ///
7    /// This operation returns a copy of the input tensor where elements are replaced
8    /// by the specified value wherever the corresponding boolean mask is true.
9    /// Elements where the mask is false retain their original values from the input tensor.
10    ///
11    /// The masked_fill operation is commonly used in machine learning for operations
12    /// like masking attention weights, zeroing out specific elements, and implementing
13    /// dropout-like functionality.
14    ///
15    /// # Arguments
16    ///
17    /// * `mask` - Boolean array with the same length as the number of tensor elements
18    /// * `value` - The value to fill masked positions with
19    ///
20    /// # Returns
21    ///
22    /// A new tensor with the same shape as the input, where masked elements are
23    /// replaced by `value` and unmasked elements retain their original values
24    ///
25    /// # Examples
26    ///
27    /// ## Basic Masked Fill
28    ///
29    /// ```
30    /// use train_station::Tensor;
31    ///
32    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34    ///
35    /// // Create a mask: [false, true, false, true, false, true]
36    /// let mask = [false, true, false, true, false, true];
37    /// let result = tensor.masked_fill(&mask, -1.0);
38    ///
39    /// // Result: [[0.0, -1.0, 2.0], [-1.0, 4.0, -1.0]]
40    /// assert_eq!(result.shape().dims, vec![2, 3]);
41    /// assert_eq!(result.get(&[0, 0]), 0.0);   // Unmasked
42    /// assert_eq!(result.get(&[0, 1]), -1.0);  // Masked
43    /// assert_eq!(result.get(&[0, 2]), 2.0);   // Unmasked
44    /// assert_eq!(result.get(&[1, 0]), -1.0);  // Masked
45    /// assert_eq!(result.get(&[1, 1]), 4.0);   // Unmasked
46    /// assert_eq!(result.get(&[1, 2]), -1.0);  // Masked
47    /// ```
48    ///
49    /// ## Masked Fill with Gradient Tracking
50    ///
51    /// ```
52    /// use train_station::Tensor;
53    ///
54    /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
55    ///     .with_requires_grad();
56    ///
57    /// // Create a mask with some true values
58    /// let mask = [false, true, false, true, false, false];
59    /// let mut result = tensor.masked_fill(&mask, 5.0);
60    ///
61    /// // Compute gradients
62    /// result.backward(None);
63    /// let grad = tensor.grad_by_value().expect("gradient missing");
64    ///
65    /// // Gradients should be zero where mask is true, 1 elsewhere
66    /// assert_eq!(grad.shape().dims, vec![2, 3]);
67    /// assert!((grad.get(&[0, 0]) - 1.0).abs() < 1e-6);   // Unmasked: gradient flows
68    /// assert!((grad.get(&[0, 1]) - 0.0).abs() < 1e-6);   // Masked: no gradient
69    /// assert!((grad.get(&[0, 2]) - 1.0).abs() < 1e-6);   // Unmasked: gradient flows
70    /// ```
71    ///
72    /// ## Zeroing Out Specific Elements
73    ///
74    /// ```
75    /// use train_station::Tensor;
76    ///
77    /// // Create a tensor with some values
78    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
79    ///
80    /// // Create a mask to zero out every other element
81    /// let mask = [true, false, true, false, true, false];
82    /// let result = tensor.masked_fill(&mask, 0.0);
83    ///
84    /// // Result: [[0.0, 2.0, 0.0], [4.0, 0.0, 6.0]]
85    /// assert_eq!(result.get(&[0, 0]), 0.0);  // Zeroed
86    /// assert_eq!(result.get(&[0, 1]), 2.0);  // Kept
87    /// assert_eq!(result.get(&[0, 2]), 0.0);  // Zeroed
88    /// assert_eq!(result.get(&[1, 0]), 4.0);  // Kept
89    /// assert_eq!(result.get(&[1, 1]), 0.0);  // Zeroed
90    /// assert_eq!(result.get(&[1, 2]), 6.0);  // Kept
91    /// ```
92    ///
93    /// # Performance Characteristics
94    ///
95    /// - **Time Complexity**: O(n) where n is the number of elements in the tensor
96    /// - **Memory Usage**: Creates a new tensor with the same size as the input
97    /// - **Optimization**: Uses efficient stride-based iteration for non-contiguous tensors
98    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
99    /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
100    ///
101    /// # Implementation Details
102    ///
103    /// The masked_fill operation works by:
104    /// 1. Validating that the mask length equals the number of tensor elements
105    /// 2. Creating a new contiguous output tensor with the same shape
106    /// 3. Iterating through all elements in logical order
107    /// 4. For each element, checking the corresponding mask value:
108    ///    - If mask is true: use the fill value
109    ///    - If mask is false: copy the original value from input tensor
110    /// 5. Computing source offsets using the input tensor's shape for non-contiguous tensors
111    /// 6. Registering the operation for gradient computation if needed
112    ///
113    /// # Safety
114    ///
115    /// This function performs bounds checking to ensure:
116    /// - The mask length equals the number of tensor elements
117    /// - Memory access is safe through proper offset calculations
118    /// - The operation handles both contiguous and non-contiguous tensors correctly
119    ///
120    /// # Panics
121    ///
122    /// This function will panic if:
123    /// - The mask length does not equal the number of tensor elements
124    ///
125    /// # Thread Safety
126    ///
127    /// This function is thread-safe and can be called concurrently on different tensors.
128    /// The operation does not modify the input tensor and creates a new output tensor.
129    ///
130    /// # GradTrack Behavior
131    ///
132    /// When gradient tracking is enabled:
133    /// - Gradients do not flow through masked positions (they are zeroed)
134    /// - Gradients flow normally through unmasked positions
135    /// - This behavior is useful for implementing operations like dropout
136    #[track_caller]
137    pub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor {
138        let numel = self.size();
139        assert_eq!(
140            mask.len(),
141            numel,
142            "mask length {} must equal tensor elements {}",
143            mask.len(),
144            numel
145        );
146
147        // Output is a contiguous copy with applied mask
148        let mut output = Tensor::new(self.shape().dims.clone());
149
150        // Iterate in logical order using strides if needed
151        let rank = self.shape().rank();
152        let mut coords = vec![0usize; rank];
153        for (lin, &m) in mask.iter().enumerate().take(numel) {
154            // Decode logical coords for mask mapping
155            let mut tmp = lin;
156            for i in (0..rank).rev() {
157                let s = self.shape().dims[i];
158                coords[i] = if s == 0 { 0 } else { tmp % s };
159                tmp /= s;
160            }
161            let src_off = self.shape().offset(&coords);
162            unsafe {
163                *output.as_mut_ptr().add(lin) = if m {
164                    value
165                } else {
166                    *self.as_ptr().add(src_off)
167                };
168            }
169        }
170
171        if self.requires_grad() {
172            output.set_requires_grad(true);
173            let grad_fn = GradFn::MaskedFill {
174                mask: mask.to_vec(),
175                input_shape: self.shape().dims.clone(),
176            };
177            output.set_grad_fn(grad_fn.clone());
178            GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
179        }
180
181        output
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_masked_fill_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let selection = vec![false, true, false, true, false, true];
        let filled = input.masked_fill(&selection, -1.0);

        // Shape is preserved; masked slots hold the fill value, others are untouched.
        assert_eq!(filled.shape().dims, vec![2, 3]);
        assert_eq!(filled.get(&[0, 0]), 0.0);
        assert_eq!(filled.get(&[0, 1]), -1.0);
        assert_eq!(filled.get(&[1, 0]), -1.0);
    }

    #[test]
    fn test_masked_fill_gradients() {
        let input = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let selection = vec![false, true, false, true, false, false];

        let mut filled = input.masked_fill(&selection, 5.0);
        filled.backward(None);

        // Upstream gradient is all ones; it must be blocked exactly where the
        // mask is true and flow unchanged everywhere else.
        let grad = input.grad_by_value().expect("grad missing");
        for (idx, &blocked) in selection.iter().enumerate() {
            let want = if blocked { 0.0 } else { 1.0 };
            let got = grad.get(&[idx / 3, idx % 3]);
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_masked_fill_coordinate_decoding_bug_fix() {
        // Regression test for the coordinate-decoding bug: use a 3D tensor so
        // linear mask indices must map through multi-dimensional coordinates.
        let values = vec![
            1.0, 2.0, 3.0, 4.0, // [0, 0, :]
            5.0, 6.0, 7.0, 8.0, // [0, 1, :]
            9.0, 10.0, 11.0, 12.0, // [1, 0, :]
            13.0, 14.0, 15.0, 16.0, // [1, 1, :]
        ];
        let input = Tensor::from_slice(&values, vec![2, 2, 4]).unwrap();

        // Select three scattered linear positions:
        // 0 -> [0,0,0] (1.0), 5 -> [0,1,1] (6.0), 10 -> [1,0,2] (11.0)
        let mut selection = vec![false; 16];
        for idx in [0usize, 5, 10] {
            selection[idx] = true;
        }

        let filled = input.masked_fill(&selection, 99.0);

        // Masked coordinates carry the fill value.
        assert_eq!(filled.get(&[0, 0, 0]), 99.0);
        assert_eq!(filled.get(&[0, 1, 1]), 99.0);
        assert_eq!(filled.get(&[1, 0, 2]), 99.0);

        // Neighbouring unmasked coordinates are unchanged.
        assert_eq!(filled.get(&[0, 0, 1]), 2.0);
        assert_eq!(filled.get(&[0, 1, 0]), 5.0);
        assert_eq!(filled.get(&[1, 1, 3]), 16.0);
    }
}
247}