//! Provides an alternative implementation for `Vec::drain_filter`.
//!
//! Import `VecDrainWhereExt` to extend `Vec` with an
//! `e_drain_where` method which drains all elements for which
//! a predicate returns `true`. The `e_` prefix is there to prevent
//! name collisions/confusion, as `drain_filter` might be
//! stabilized as `drain_where`. Also, in contrast to
//! `drain_filter`, this implementation doesn't run to
//! completion when dropped, which allows stopping the draining
//! from the outside (through combinators or a `for` loop `break`),
//! and it is not prone to double panics/panics on drop.
#[cfg(test)]
extern crate quickcheck;

use std::{isize, ptr, mem};

/// Extension trait adding `e_drain_where` to `Vec`.
pub trait VecDrainWhereExt<Item> {
    /// Drains all elements from the vector for which the predicate
    /// returns `true`.
    ///
    /// The predicate gets a `&mut Item`, so it may modify elements
    /// while deciding whether they are drained.
    ///
    /// Note that dropping the iterator early will stop the process
    /// of draining. So for example if you add a combinator to the
    /// drain iterator which short circuits (e.g. `any`/`all`) this
    /// will stop draining once short circuiting is hit. So use it
    /// with care, or use the `fold` based replacements described in
    /// the tip at the end of this documentation.
    ///
    /// # Leak Behavior
    ///
    /// For safety reasons the length of the original vector
    /// is set to 0 while the drain iterator lives. The length is
    /// only restored when the iterator is dropped, so leaking the
    /// iterator (e.g. with `mem::forget`) leaves the vector empty.
    ///
    /// # Panics
    ///
    /// Panics if the vector contains more than `isize::MAX` elements.
    ///
    /// # Panic/Drop Behavior
    ///
    /// When the iterator is dropped due to a panic in
    /// the predicate, the element it panicked on is leaked,
    /// but all elements which have already been decided
    /// to not be drained, and those which have not yet been
    /// decided about, will still be in the vector safely.
    /// I.e. if the panic also causes the vector to drop
    /// they are dropped normally; if not, the vector can
    /// still be used normally.
    ///
    /// # Tip: non short circuiting `all`/`any`
    ///
    /// Instead of `iter.any(pred)` use
    /// `iter.fold(false, |s,i| s|pred(i))`.
    ///
    /// Instead of `iter.all(pred)` use
    /// `iter.fold(true, |s,i| s&pred(i))`.
    ///
    /// And if it is fine to not call `pred` once
    /// the result is known (found for `any`, shown to
    /// not hold for `all`), but it is still required
    /// to run the iterator to its end, replace the `|`
    /// with `||` and the `&` with `&&`.
    fn e_drain_where<F>(&mut self, predicate: F)
        -> VecDrainWhere<Item, F>
        where F: FnMut(&mut Item) -> bool;
}

impl<Item> VecDrainWhereExt<Item> for Vec<Item> {
    fn e_drain_where<F>(&mut self, predicate: F)
        -> VecDrainWhere<Item, F>
        where F: FnMut(&mut Item) -> bool
    {
        let ptr = self.as_mut_ptr();
        let len = self.len();
        if len == 0 {
            let nptr = 0 as *mut _;
            return VecDrainWhere {
                pos: nptr,
                gap_pos: nptr,
                end: nptr,
                self_ref: self,
                predicate
            };
        }

        if len > isize::MAX as usize {
            panic!("can not handle more then isize::MAX elements");
        }

        // leak amplification for safety
        unsafe { self.set_len(0) }

        let end = unsafe { ptr.offset(len as isize) };

        VecDrainWhere {
            pos: ptr,
            gap_pos: ptr,
            end,
            self_ref: self,
            predicate
        }
    }
}

/// Iterator for draining a vector conditionally.
#[must_use]
#[derive(Debug)]
pub struct VecDrainWhere<'a, Item: 'a, Pred> {
    pos: *mut Item,
    gap_pos: *mut Item,
    end: *mut Item,
    predicate: Pred,
    self_ref: &'a mut Vec<Item>
}

impl<'a, I: 'a, P> Iterator for VecDrainWhere<'a, I, P>
    where P: FnMut(&mut I) -> bool
{
    type Item = I;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.pos.is_null() || self.pos >= self.end {
                return None;
            } else {
                unsafe {
                    let ref_to_current = &mut *self.pos;
                    self.pos = self.pos.offset(1);
                    let should_be_drained = (self.predicate)(ref_to_current);
                    if should_be_drained {
                        let item = ptr::read(ref_to_current);
                        return Some(item);
                    } else {
                        if self.gap_pos < ref_to_current {
                            ptr::copy_nonoverlapping(ref_to_current, self.gap_pos, 1);
                        }
                        self.gap_pos = self.gap_pos.offset(1);
                    }
                }
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.self_ref.len()))
    }
}

impl<'a, I: 'a, P> Drop for VecDrainWhere<'a, I, P> {
    /// If the iterator was run to completion this will
    /// set the len to the new len after drop. I.e. it
    /// will undo the leak amplification.
    ///
    /// If the iterator is dropped before completion this
    /// will move the remaining elements to the (single)
    /// gap (still) left from draining elements and then
    /// sets the new length.
    ///
    /// If the iterator is dropped because the called
    /// predicate panicked the element it panicked on
    /// is _leaked_. This is because its simply to easy
    /// to leaf the `&mut T` value in a illegal state
    /// likely to panic drop or even behave unsafely
    /// (through it surly shouldn't behave this way).
    fn drop(&mut self) {
        let pos = self.pos as usize;
        if self.pos.is_null() {
            return
        }
        let start  = self.self_ref.as_mut_ptr() as usize;
        let end = self.end as usize;
        let gap = self.gap_pos as usize;
        let item_size: usize = mem::size_of::<I>();
        unsafe {
            let cur_len = (gap - start)/item_size;
            let rem_len = (end - pos)/item_size;
            ptr::copy(self.pos, self.gap_pos, rem_len);
            self.self_ref.set_len(cur_len + rem_len);
        }
    }
}


#[cfg(test)]
mod tests {
    use quickcheck::TestResult;
    //Uhm, this is not unused at all, so it being displayed
    // as such is a rustc bug (is in the bug tracker).
    #[allow(unused_imports)]
    use super::VecDrainWhereExt;

    mod check_with_mask {
        use super::*;

        fn cmp_with_mask(mask: Vec<bool>) -> TestResult {
            let mut data = (0..mask.len()).collect::<Vec<_>>();
            let data2 = data.clone();
            let new_len = mask.len() - mask.iter().fold(0, |s,i| if *i { s + 1 } else { s });
            let mut mask_iter = mask.clone().into_iter();
            let mut last_el: Option<usize> = None;

            let mut failed = None;
            data.e_drain_where(|el| {
                if let Some(lel) = last_el {
                    if lel + 1 != *el {
                        failed = Some(TestResult::error(
                            format!("unexpected element (exp {}, got {})", lel + 1, el)));
                    }
                }
                last_el = Some(*el);

                if let Some(mask) = mask_iter.next() {
                    mask
                } else {
                    failed = Some(TestResult::error("called predicate to often"));
                    false
                }
            }).for_each(drop);

            if let Some(f) = failed {
                return f;
            }

            if new_len != data.len() {
                return TestResult::error(format!(
                    "rem count: {}, found count: {} - {:?} | {:?}",
                    new_len, data.len(), data, mask
                ))
            }

            let expected = data2.iter().zip(mask.iter())
                    .filter(|&(_d, p)| *p)
                    .map(|(d, _p)| *d)
                    .collect::<Vec<_>>();

            if expected != data {
                TestResult::error("unexpected data");
            }
            TestResult::passed()

        }

        #[test]
        fn qc_cmp_with_mask() {
            ::quickcheck::quickcheck(cmp_with_mask as fn(Vec<bool>) -> TestResult);
        }


        #[test]
        fn fix_divide_byte_len_by_size_of() {
            let res = cmp_with_mask(vec![false]);
            assert!(!res.is_error(), "{:?}", res)
        }

        #[test]
        fn fix_update_last_el_in_test() {
            let res = cmp_with_mask(vec![false, false, false]);
            assert!(!res.is_error(), "{:?}", res)
        }
    }

    mod check_with_panic {
        use super::*;

        /// Drains `0..mask.len()` where every mask entry is a pair of
        /// `(drain this element, panic in the predicate)`.
        ///
        /// Verifies that a panic is observed exactly when one was
        /// requested and that the vector has the expected length
        /// afterwards: kept elements plus everything after the element
        /// the predicate panicked on (which itself is leaked).
        fn panic_situations(mask: Vec<(bool, bool)>) -> TestResult {
            let mut data = (0..mask.len()).collect::<Vec<_>>();
            let mut steps = mask.clone().into_iter();
            let mut fail = None;

            // Replay the mask to compute how many elements have to remain.
            let mut expect_panic = false;
            let mut expected_len = 0;
            for &(drain, panics) in mask.iter() {
                if expect_panic {
                    // everything after the panic was never visited
                    expected_len += 1;
                } else if panics {
                    // the element panicked on is leaked
                    expect_panic = true;
                } else if !drain {
                    expected_len += 1;
                }
            }

            let res = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                let drain = data.e_drain_where(|_item| {
                    let (drain_it, do_panic) = match steps.next() {
                        Some(step) => step,
                        None => {
                            fail = Some(TestResult::error("unexpected no more masks"));
                            (false, false)
                        }
                    };

                    if do_panic {
                        panic!("-- yes panic --");
                    }
                    drain_it
                });
                for item in drain {
                    drop(item);
                }
            }));

            if let Some(failure) = fail {
                return failure;
            }

            if expect_panic && res.is_ok() {
                return TestResult::error(format!(
                    "unexpectedly no panic? exp {}, len {}, ({:?})",
                    expected_len, mask.len(), mask
                ))
            }

            if !expect_panic && res.is_err() {
                return TestResult::error(format!(
                    "unexpectedly error? exp {}, len {}, ({:?})",
                    expected_len, mask.len(), mask
                ))
            }

            if data.len() != expected_len {
                return TestResult::error(format!(
                    "unexpected resulting len {}, exp {} ({:?} - {:?})",
                    data.len(), expected_len, data, mask
                ));
            }

            TestResult::passed()
        }


        #[test]
        fn qc_panic_test() {
            ::quickcheck::quickcheck(panic_situations as fn(Vec<(bool,bool)>) -> TestResult)
        }

        /// Regression test: a drained element without a panic must not
        /// be counted as remaining.
        #[test]
        fn fix_messed_up_test() {
            let res = panic_situations(vec![(true, false)]);
            assert!(!res.is_error(), "{:?}", res);
        }
    }

}