//! Masked fill indexing operation for `Tensor`.
//!
//! Path: train_station/tensor/indexing/masked_fill.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Fill masked elements with a specified value
6    ///
7    /// This operation returns a copy of the input tensor where elements are replaced
8    /// by the specified value wherever the corresponding boolean mask is true.
9    /// Elements where the mask is false retain their original values from the input tensor.
10    ///
11    /// The masked_fill operation is commonly used in machine learning for operations
12    /// like masking attention weights, zeroing out specific elements, and implementing
13    /// dropout-like functionality.
14    ///
15    /// # Arguments
16    ///
17    /// * `mask` - Boolean array with the same length as the number of tensor elements
18    /// * `value` - The value to fill masked positions with
19    ///
20    /// # Returns
21    ///
22    /// A new tensor with the same shape as the input, where masked elements are
23    /// replaced by `value` and unmasked elements retain their original values
24    ///
25    /// # Examples
26    ///
27    /// ## Basic Masked Fill
28    ///
29    /// ```
30    /// use train_station::Tensor;
31    ///
32    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34    ///
35    /// // Create a mask: [false, true, false, true, false, true]
36    /// let mask = [false, true, false, true, false, true];
37    /// let result = tensor.masked_fill(&mask, -1.0);
38    ///
39    /// // Result: [[0.0, -1.0, 2.0], [-1.0, 4.0, -1.0]]
40    /// assert_eq!(result.shape().dims, vec![2, 3]);
41    /// assert_eq!(result.get(&[0, 0]), 0.0);   // Unmasked
42    /// assert_eq!(result.get(&[0, 1]), -1.0);  // Masked
43    /// assert_eq!(result.get(&[0, 2]), 2.0);   // Unmasked
44    /// assert_eq!(result.get(&[1, 0]), -1.0);  // Masked
45    /// assert_eq!(result.get(&[1, 1]), 4.0);   // Unmasked
46    /// assert_eq!(result.get(&[1, 2]), -1.0);  // Masked
47    /// ```
48    ///
49    /// ## Masked Fill with Gradient Tracking
50    ///
51    /// ```
52    /// use train_station::Tensor;
53    ///
54    /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
55    ///     .with_requires_grad();
56    ///
57    /// // Create a mask with some true values
58    /// let mask = [false, true, false, true, false, false];
59    /// let mut result = tensor.masked_fill(&mask, 5.0);
60    ///
61    /// // Compute gradients
62    /// result.backward(None);
63    /// let grad = tensor.grad_by_value().expect("gradient missing");
64    ///
65    /// // Gradients should be zero where mask is true, 1 elsewhere
66    /// assert_eq!(grad.shape().dims, vec![2, 3]);
67    /// assert!((grad.get(&[0, 0]) - 1.0).abs() < 1e-6);   // Unmasked: gradient flows
68    /// assert!((grad.get(&[0, 1]) - 0.0).abs() < 1e-6);   // Masked: no gradient
69    /// assert!((grad.get(&[0, 2]) - 1.0).abs() < 1e-6);   // Unmasked: gradient flows
70    /// ```
71    ///
72    /// ## Zeroing Out Specific Elements
73    ///
74    /// ```
75    /// use train_station::Tensor;
76    ///
77    /// // Create a tensor with some values
78    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
79    ///
80    /// // Create a mask to zero out every other element
81    /// let mask = [true, false, true, false, true, false];
82    /// let result = tensor.masked_fill(&mask, 0.0);
83    ///
84    /// // Result: [[0.0, 2.0, 0.0], [4.0, 0.0, 6.0]]
85    /// assert_eq!(result.get(&[0, 0]), 0.0);  // Zeroed
86    /// assert_eq!(result.get(&[0, 1]), 2.0);  // Kept
87    /// assert_eq!(result.get(&[0, 2]), 0.0);  // Zeroed
88    /// assert_eq!(result.get(&[1, 0]), 4.0);  // Kept
89    /// assert_eq!(result.get(&[1, 1]), 0.0);  // Zeroed
90    /// assert_eq!(result.get(&[1, 2]), 6.0);  // Kept
91    /// ```
92    ///
93    /// # Performance Characteristics
94    ///
95    /// - **Time Complexity**: O(n) where n is the number of elements in the tensor
96    /// - **Memory Usage**: Creates a new tensor with the same size as the input
97    /// - **Optimization**: Uses efficient stride-based iteration for non-contiguous tensors
98    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
99    /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
100    ///
101    /// # Implementation Details
102    ///
103    /// The masked_fill operation works by:
104    /// 1. Validating that the mask length equals the number of tensor elements
105    /// 2. Creating a new contiguous output tensor with the same shape
106    /// 3. Iterating through all elements in logical order
107    /// 4. For each element, checking the corresponding mask value:
108    ///    - If mask is true: use the fill value
109    ///    - If mask is false: copy the original value from input tensor
110    /// 5. Computing source offsets using the input tensor's shape for non-contiguous tensors
111    /// 6. Registering the operation for gradient computation if needed
112    ///
113    /// # Safety
114    ///
115    /// This function performs bounds checking to ensure:
116    /// - The mask length equals the number of tensor elements
117    /// - Memory access is safe through proper offset calculations
118    /// - The operation handles both contiguous and non-contiguous tensors correctly
119    ///
120    /// # Panics
121    ///
122    /// This function will panic if:
123    /// - The mask length does not equal the number of tensor elements
124    ///
125    /// # Thread Safety
126    ///
127    /// This function is thread-safe and can be called concurrently on different tensors.
128    /// The operation does not modify the input tensor and creates a new output tensor.
129    ///
130    /// # GradTrack Behavior
131    ///
132    /// When gradient tracking is enabled:
133    /// - Gradients do not flow through masked positions (they are zeroed)
134    /// - Gradients flow normally through unmasked positions
135    /// - This behavior is useful for implementing operations like dropout
136    pub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor {
137        let numel = self.size();
138        assert_eq!(
139            mask.len(),
140            numel,
141            "mask length {} must equal tensor elements {}",
142            mask.len(),
143            numel
144        );
145
146        // Output is a contiguous copy with applied mask
147        let mut output = Tensor::new(self.shape().dims.clone());
148
149        // Iterate in logical order using strides if needed
150        let rank = self.shape().rank();
151        let mut coords = vec![0usize; rank];
152        for (lin, &m) in mask.iter().enumerate().take(numel) {
153            // Decode logical coords for mask mapping
154            let mut tmp = lin;
155            for i in (0..rank).rev() {
156                let s = self.shape().dims[i];
157                coords[i] = if s == 0 { 0 } else { tmp % s };
158                tmp /= s;
159            }
160            let src_off = self.shape().offset(&coords);
161            unsafe {
162                *output.as_mut_ptr().add(lin) = if m {
163                    value
164                } else {
165                    *self.as_ptr().add(src_off)
166                };
167            }
168        }
169
170        if self.requires_grad() {
171            output.set_requires_grad(true);
172            let grad_fn = GradFn::MaskedFill {
173                mask: mask.to_vec(),
174                input_shape: self.shape().dims.clone(),
175            };
176            output.set_grad_fn(grad_fn.clone());
177            GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
178        }
179
180        output
181    }
182}
183
184#[cfg(test)]
185mod tests {
186    use super::*;
187
188    #[test]
189    fn test_masked_fill_basic() {
190        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
191        let mask = vec![false, true, false, true, false, true];
192        let y = x.masked_fill(&mask, -1.0);
193        assert_eq!(y.shape().dims, vec![2, 3]);
194        assert_eq!(y.get(&[0, 0]), 0.0);
195        assert_eq!(y.get(&[0, 1]), -1.0);
196        assert_eq!(y.get(&[1, 0]), -1.0);
197    }
198
199    #[test]
200    fn test_masked_fill_gradients() {
201        let x = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
202            .unwrap()
203            .with_requires_grad();
204        let mask = vec![false, true, false, true, false, false];
205        let mut y = x.masked_fill(&mask, 5.0);
206        y.backward(None);
207        let gx = x.grad_by_value().expect("grad missing");
208        // Grad should be zero where mask is true, 1 elsewhere (from upstream ones)
209        for (i, &m) in mask.iter().enumerate().take(6) {
210            let expected = if m { 0.0 } else { 1.0 };
211            assert!((gx.get(&[i / 3, i % 3]) - expected).abs() < 1e-6);
212        }
213    }
214
215    #[test]
216    fn test_masked_fill_coordinate_decoding_bug_fix() {
217        // This test specifically targets the coordinate decoding bug that was fixed
218        // Create a 3D tensor to test multi-dimensional coordinate mapping
219        let data = vec![
220            1.0, 2.0, 3.0, 4.0, // [0, 0, :]
221            5.0, 6.0, 7.0, 8.0, // [0, 1, :]
222            9.0, 10.0, 11.0, 12.0, // [1, 0, :]
223            13.0, 14.0, 15.0, 16.0, // [1, 1, :]
224        ];
225        let tensor = Tensor::from_slice(&data, vec![2, 2, 4]).unwrap();
226
227        // Create a mask that selects specific positions
228        // Position 0: [0,0,0] -> 1.0, Position 5: [0,1,1] -> 6.0, Position 10: [1,0,2] -> 11.0
229        let mut mask = vec![false; 16];
230        mask[0] = true; // [0,0,0]
231        mask[5] = true; // [0,1,1]
232        mask[10] = true; // [1,0,2]
233
234        let result = tensor.masked_fill(&mask, 99.0);
235
236        // Verify the correct positions were masked
237        assert_eq!(result.get(&[0, 0, 0]), 99.0); // position 0, was 1.0
238        assert_eq!(result.get(&[0, 1, 1]), 99.0); // position 5, was 6.0
239        assert_eq!(result.get(&[1, 0, 2]), 99.0); // position 10, was 11.0
240
241        // Verify unmasked positions remain unchanged
242        assert_eq!(result.get(&[0, 0, 1]), 2.0); // position 1, unchanged
243        assert_eq!(result.get(&[0, 1, 0]), 5.0); // position 4, unchanged
244        assert_eq!(result.get(&[1, 1, 3]), 16.0); // position 15, unchanged
245    }
246}