Skip to main content

tf_idf_vectorizer/utils/datastruct/vector/
tf.rs

1use std::{alloc::Layout, iter::FusedIterator, mem, ptr::NonNull};
2
3use num_traits::Num;
4
5#[allow(dead_code)]
6const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<u8>>();
7static_assertions::const_assert!(TF_VECTOR_SIZE == 32);
8
9pub trait TFVectorTrait<N>
10where N: Num + Copy
11{
12    fn len(&self) -> u32;
13    fn nnz(&self) -> u32;
14    fn cap(&self) -> u32;
15    fn term_sum(&self) -> u32;
16    fn new() -> Self;
17    fn new_with_capacity(capacity: u32) -> Self;
18    fn shrink_to_fit(&mut self);
19    fn raw_iter(&self) -> RawTFVectorIter<'_, N>;
20    unsafe fn from_vec(ind_vec: Vec<u32>, val_vec: Vec<N>, len: u32, term_sum: u32) -> Self;
21    unsafe fn ind_ptr(&self) -> *mut u32;
22    unsafe fn val_ptr(&self) -> *mut N;
23    /// Power Jump Search
24    /// Returns Some((value, sp_vec_raw_ind)) if found, None otherwise
25    #[inline(always)]
26    unsafe fn power_jump_search(&self, target: u32, start: usize) -> Option<(N, usize)>
27    where
28        N: Copy,
29    {
30        let nnz = self.nnz() as usize;
31        if start >= nnz {
32            return None;
33        }
34
35        let ind = unsafe { core::slice::from_raw_parts(self.ind_ptr(), nnz) };
36        let val = unsafe { core::slice::from_raw_parts(self.val_ptr(), nnz) };
37
38        // fast path
39        let mut lo = start;
40        let mut hi = start;
41
42        let s = ind[hi];
43        if s == target {
44            return Some((val[hi], hi));
45        }
46        if s > target {
47            return None; // forward-only
48        }
49
50        // galloping
51        let mut step = 1usize;
52        loop {
53            let next_hi = hi + step;
54            if next_hi >= nnz {
55                hi = nnz - 1;
56                break;
57            }
58            hi = next_hi;
59
60            if ind[hi] >= target {
61                break;
62            }
63
64            lo = hi;
65            step <<= 1;
66        }
67
68        // lower_bound in (lo, hi] => [lo+1, hi+1)
69        let mut l = lo + 1;
70        let mut r = hi + 1; // exclusive
71        while l < r {
72            let m = (l + r) >> 1;
73            if ind[m] < target {
74                l = m + 1;
75            } else {
76                r = m;
77            }
78        }
79
80        if l < nnz && ind[l] == target {
81            Some((val[l], l))
82        } else {
83            None
84        }
85    }
86    #[inline(always)]
87    fn get_power_jump(&self, target: u32, cut_down: &mut usize) -> Option<N>
88    where
89        N: Copy,
90    {
91        unsafe {
92            if let Some((v, idx)) = self.power_jump_search(target, *cut_down) {
93                *cut_down = idx;
94                Some(v)
95            } else {
96                None
97            }
98        }
99    }
100    #[inline(always)]
101    fn as_val_slice(&self) -> &[N] {
102        unsafe { core::slice::from_raw_parts(self.val_ptr(), self.nnz() as usize) }
103    }
104    #[inline(always)]
105    fn as_ind_slice(&self) -> &[u32] {
106        unsafe { core::slice::from_raw_parts(self.ind_ptr(), self.nnz() as usize) }
107    }
108    #[inline(always)]
109    fn perm(&mut self, perm_idxs: &[u32]) {
110        unsafe {
111            let mut_ind_slice = core::slice::from_raw_parts_mut(self.ind_ptr(), self.nnz() as usize);
112            let mut_val_slice = core::slice::from_raw_parts_mut(self.val_ptr(), self.nnz() as usize);
113            mut_ind_slice.iter_mut().for_each(|x| {
114                *x = perm_idxs[*x as usize];
115            });
116            crate::utils::sort::radix_sort_u32_soa(mut_ind_slice, mut_val_slice);
117        }
118    }
119}
120
121impl<N> TFVectorTrait<N> for TFVector<N> 
122where N: Num + Copy
123{
124    fn new() -> Self {
125        Self::low_new()
126    }
127
128    #[inline]
129    fn new_with_capacity(capacity: u32) -> Self {
130        let mut vec = Self::low_new();
131        if capacity != 0 {
132            vec.set_cap(capacity);
133        }
134        vec
135    }
136
137    #[inline]
138    fn shrink_to_fit(&mut self) {
139        if self.nnz < self.cap {
140            self.set_cap(self.nnz);
141        }
142    }
143
144    #[inline(always)]
145    fn raw_iter(&self) -> RawTFVectorIter<'_, N> {
146        RawTFVectorIter {
147            vec: self,
148            pos: 0,
149            end: self.nnz,
150        }
151    }
152
153    #[inline(always)]
154    fn nnz(&self) -> u32 {
155        self.nnz
156    }
157
158    #[inline(always)]
159    fn len(&self) -> u32 {
160        self.len
161    }
162
163    #[inline(always)]
164    fn cap(&self) -> u32 {
165        self.cap
166    }
167
168    #[inline(always)]
169    fn term_sum(&self) -> u32 {
170        self.term_sum
171    }
172
173    #[inline(always)]
174    unsafe fn from_vec(mut ind_vec: Vec<u32>, mut val_vec: Vec<N>, len: u32, term_sum: u32) -> Self {
175        debug_assert_eq!(
176            ind_vec.len(),
177            val_vec.len(),
178            "ind_vec and val_vec must have the same length"
179        );
180
181        // sort
182        crate::utils::sort::radix_sort_u32_soa(&mut ind_vec, &mut val_vec);
183
184        let nnz = ind_vec.len() as u32;
185
186        if nnz == 0 {
187            let mut v = TFVector::low_new();
188            v.len = len;
189            v.term_sum = term_sum;
190            return v;
191        }
192
193        // Consume the Vecs and avoid an extra copy:
194        // Vec -> Box<[T]> guarantees allocation sized to exactly `len`,
195        // which matches `Layout::array::<T>(nnz)` used by `free_alloc()`.
196        let inds_box: Box<[u32]> = ind_vec.into_boxed_slice();
197        let vals_box: Box<[N]> = val_vec.into_boxed_slice();
198
199        let inds_ptr = Box::into_raw(inds_box) as *mut u32;
200        let vals_ptr = Box::into_raw(vals_box) as *mut N;
201
202        TFVector {
203            inds: unsafe { NonNull::new_unchecked(inds_ptr) },
204            vals: unsafe { NonNull::new_unchecked(vals_ptr) },
205            cap: nnz,
206            nnz,
207            len,
208            term_sum,
209        }
210    }
211
212    #[inline(always)]
213    unsafe fn ind_ptr(&self) -> *mut u32 {
214        self.inds.as_ptr()
215    }
216
217    #[inline(always)]
218    unsafe fn val_ptr(&self) -> *mut N {
219        self.vals.as_ptr()
220    }
221}
222
223
224pub struct RawTFVectorIter<'a, N>
225where
226    N: Num + 'a,
227{
228    vec: &'a TFVector<N>,
229    pos: u32, // front
230    end: u32, // back (exclusive)
231}
232
233impl<'a, N> RawTFVectorIter<'a, N>
234where
235    N: Num + 'a,
236{
237    #[inline]
238    pub fn new(vec: &'a TFVector<N>) -> Self {
239        Self { vec, pos: 0, end: vec.nnz }
240    }
241}
242
243impl<'a, N> Iterator for RawTFVectorIter<'a, N>
244where
245    N: Num + 'a + Copy,
246{
247    type Item = (u32, N);
248
249    #[inline]
250    fn next(&mut self) -> Option<Self::Item> {
251        if self.pos >= self.end {
252            return None;
253        }
254        unsafe {
255            let i = self.pos as usize;
256            self.pos += 1;
257            let ind = *self.vec.inds.as_ptr().add(i);
258            let val = *self.vec.vals.as_ptr().add(i);
259            Some((ind, val))
260        }
261    }
262
263    #[inline]
264    fn size_hint(&self) -> (usize, Option<usize>) {
265        let remaining = (self.end - self.pos) as usize;
266        (remaining, Some(remaining))
267    }
268}
269
270impl<'a, N> DoubleEndedIterator for RawTFVectorIter<'a, N>
271where
272    N: Num + 'a + Copy,
273{
274    #[inline]
275    fn next_back(&mut self) -> Option<Self::Item> {
276        if self.pos >= self.end {
277            return None;
278        }
279        self.end -= 1;
280        unsafe {
281            let i = self.end as usize;
282            let ind = *self.vec.inds.as_ptr().add(i);
283            let val = *self.vec.vals.as_ptr().add(i);
284            Some((ind, val))
285        }
286    }
287}
288
289impl<'a, N> ExactSizeIterator for RawTFVectorIter<'a, N>
290where
291    N: Num + 'a + Copy,
292{
293    #[inline]
294    fn len(&self) -> usize {
295        (self.end - self.pos) as usize
296    }
297}
298
299impl<'a, N> FusedIterator for RawTFVectorIter<'a, N>
300where
301    N: Num + 'a + Copy,
302{}
303
304/// ZeroSpVecの生実装
305#[derive(Debug)]
306#[repr(align(32))] // どうなんだろうか
307pub struct TFVector<N> 
308where N: Num
309{
310    inds: NonNull<u32>,
311    vals: NonNull<N>,
312    cap: u32,
313    nnz: u32,
314    len: u32,
315    /// sum of terms of this document
316    /// denormalize number for this document
317    /// for reverse calculation to get term counts from tf values
318    term_sum: u32, // for future use
319}
320
321/// Low Level Implementation
322impl<N> TFVector<N> 
323where N: Num
324{
325    const VAL_SIZE: usize = mem::size_of::<N>();
326
327    #[inline]
328    fn low_new() -> Self {
329        // ZST は許さん
330        debug_assert!(Self::VAL_SIZE != 0, "Zero-sized type is not supported for TFVector");
331
332        TFVector {
333            // ダングリングポインタで初期化
334            inds: NonNull::dangling(),
335            vals: NonNull::dangling(),
336            cap: 0,
337            nnz: 0,
338            len: 0,
339            term_sum: 0,
340        }
341    }
342
343
344    #[inline]
345    #[allow(dead_code)]
346    fn grow(&mut self) {
347        let new_cap = if self.cap == 0 {
348            1
349        } else {
350            self.cap.checked_mul(2).expect("TFVector capacity overflowed")
351        };
352
353        self.set_cap(new_cap);
354    }
355
356    #[inline]
357    fn set_cap(&mut self, new_cap: u32) {
358        if new_cap == 0 {
359            // キャパシティを0にする場合はメモリを解放する
360            self.free_alloc();
361            return;
362        }
363        let new_inds_layout = Layout::array::<u32>(new_cap as usize).expect("Failed to create inds memory layout");
364        let new_vals_layout = Layout::array::<N>(new_cap as usize).expect("Failed to create vals memory layout");
365
366        if self.cap == 0 {
367            let new_inds_ptr = unsafe { std::alloc::alloc(new_inds_layout) };
368            let new_vals_ptr = unsafe { std::alloc::alloc(new_vals_layout) };
369            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
370                if new_inds_ptr.is_null() {
371                    oom(new_inds_layout);
372                } else {
373                    oom(new_vals_layout);
374                }
375            }
376
377            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
378            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
379            self.cap = new_cap;
380        } else {
381            let old_inds_layout = Layout::array::<u32>(self.cap as usize).expect("Failed to create old inds memory layout");
382            let old_vals_layout = Layout::array::<N>(self.cap as usize).expect("Failed to create old vals memory layout");
383
384            let new_inds_ptr = unsafe { std::alloc::realloc(
385                self.inds.as_ptr().cast::<u8>(),
386                old_inds_layout,
387                new_inds_layout.size(),
388            ) };
389            let new_vals_ptr = unsafe { std::alloc::realloc(
390                self.vals.as_ptr().cast::<u8>(),
391                old_vals_layout,
392                new_vals_layout.size(),
393            ) };
394            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
395                if new_inds_ptr.is_null() {
396                    oom(new_inds_layout);
397                } else {
398                    oom(new_vals_layout);
399                }
400            }
401
402            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
403            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
404            self.cap = new_cap;
405        }
406    }
407
408    #[inline]
409    fn free_alloc(&mut self) {
410        if self.cap != 0 {
411            unsafe {
412                let inds_layout = Layout::array::<u32>(self.cap as usize).unwrap();
413                let vals_layout = Layout::array::<N>(self.cap as usize).unwrap();
414                std::alloc::dealloc(self.inds.as_ptr().cast::<u8>(), inds_layout);
415                std::alloc::dealloc(self.vals.as_ptr().cast::<u8>(), vals_layout);
416            }
417        }
418        self.inds = NonNull::dangling();
419        self.vals = NonNull::dangling();
420        self.cap = 0;
421    }
422}
423
424unsafe impl<N: Num + Send + Sync> Send for TFVector<N> {}
425unsafe impl<N: Num + Sync> Sync for TFVector<N> {}
426
427impl<N> Drop for TFVector<N> 
428where N: Num
429{
430    #[inline]
431    fn drop(&mut self) {
432        self.free_alloc();
433    }
434}
435
436impl<N> Clone for TFVector<N>
437where
438    N: Num + Copy,
439{
440    #[inline]
441    fn clone(&self) -> Self {
442        let mut new_vec = TFVector::low_new();
443        if self.nnz > 0 {
444            new_vec.set_cap(self.nnz);
445            new_vec.len = self.len;
446            new_vec.nnz = self.nnz;
447            new_vec.term_sum = self.term_sum;
448
449            unsafe {
450                std::ptr::copy_nonoverlapping(
451                    self.inds.as_ptr(),
452                    new_vec.inds.as_ptr(),
453                    self.nnz as usize,
454                );
455                std::ptr::copy_nonoverlapping(
456                    self.vals.as_ptr(),
457                    new_vec.vals.as_ptr(),
458                    self.nnz as usize,
459                );
460            }
461        }
462        new_vec
463    }
464}
465
466
467
/// Out-of-memory handler: reports the failed `layout` through
/// `std::alloc::handle_alloc_error` and never returns.
///
/// Ideally this would be a `panic!`, but panicking on OOM makes the traceback
/// itself use memory, so the process is forcibly terminated instead.
/// In practice OOM is normally managed by the OS, which kills the process
/// before this point is reached, so this handler rarely matters.
#[cold]
#[inline(never)]
fn oom(layout: Layout) -> ! {
    use std::alloc::handle_alloc_error;

    handle_alloc_error(layout)
}