Skip to main content

tf_idf_vectorizer/utils/datastruct/vector/
tf.rs

1use std::{alloc::Layout, iter::FusedIterator, mem, ptr::NonNull};
2
3use num_traits::Num;
4
5#[allow(dead_code)]
6const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<u8>>();
7static_assertions::const_assert!(TF_VECTOR_SIZE == 32);
8
9pub trait TFVectorTrait<N>
10where N: Num + Copy
11{
12    fn len(&self) -> u32;
13    fn nnz(&self) -> u32;
14    fn term_sum(&self) -> u32;
15    fn new() -> Self;
16    fn new_with_capacity(capacity: u32) -> Self;
17    fn shrink_to_fit(&mut self);
18    fn raw_iter(&self) -> RawTFVectorIter<'_, N>;
19    unsafe fn from_vec(ind_vec: Vec<u32>, val_vec: Vec<N>, len: u32, term_sum: u32) -> Self;
20    unsafe fn ind_ptr(&self) -> *mut u32;
21    unsafe fn val_ptr(&self) -> *mut N;
22    /// Power Jump Search
23    /// Returns Some((value, sp_vec_raw_ind)) if found, None otherwise
24    #[inline(always)]
25    unsafe fn power_jump_search(&self, target: u32, start: usize) -> Option<(N, usize)>
26    where
27        N: Copy,
28    {
29        let nnz = self.nnz() as usize;
30        if start >= nnz {
31            return None;
32        }
33
34        let ind = unsafe { core::slice::from_raw_parts(self.ind_ptr(), nnz) };
35        let val = unsafe { core::slice::from_raw_parts(self.val_ptr(), nnz) };
36
37        // fast path
38        let mut lo = start;
39        let mut hi = start;
40
41        let s = ind[hi];
42        if s == target {
43            return Some((val[hi], hi));
44        }
45        if s > target {
46            return None; // forward-only
47        }
48
49        // galloping
50        let mut step = 1usize;
51        loop {
52            let next_hi = hi + step;
53            if next_hi >= nnz {
54                hi = nnz - 1;
55                break;
56            }
57            hi = next_hi;
58
59            if ind[hi] >= target {
60                break;
61            }
62
63            lo = hi;
64            step <<= 1;
65        }
66
67        // lower_bound in (lo, hi] => [lo+1, hi+1)
68        let mut l = lo + 1;
69        let mut r = hi + 1; // exclusive
70        while l < r {
71            let m = (l + r) >> 1;
72            if ind[m] < target {
73                l = m + 1;
74            } else {
75                r = m;
76            }
77        }
78
79        if l < nnz && ind[l] == target {
80            Some((val[l], l))
81        } else {
82            None
83        }
84    }
85    #[inline(always)]
86    fn get_power_jump(&self, target: u32, cut_down: &mut usize) -> Option<N>
87    where
88        N: Copy,
89    {
90        unsafe {
91            if let Some((v, idx)) = self.power_jump_search(target, *cut_down) {
92                *cut_down = idx;
93                Some(v)
94            } else {
95                None
96            }
97        }
98    }
99    #[inline(always)]
100    fn as_val_slice(&self) -> &[N] {
101        unsafe { core::slice::from_raw_parts(self.val_ptr(), self.nnz() as usize) }
102    }
103    #[inline(always)]
104    fn as_ind_slice(&self) -> &[u32] {
105        unsafe { core::slice::from_raw_parts(self.ind_ptr(), self.nnz() as usize) }
106    }
107}
108
109impl<N> TFVectorTrait<N> for TFVector<N> 
110where N: Num + Copy
111{
112    fn new() -> Self {
113        Self::low_new()
114    }
115
116    #[inline]
117    fn new_with_capacity(capacity: u32) -> Self {
118        let mut vec = Self::low_new();
119        if capacity != 0 {
120            vec.set_cap(capacity);
121        }
122        vec
123    }
124
125    #[inline]
126    fn shrink_to_fit(&mut self) {
127        if self.nnz < self.cap {
128            self.set_cap(self.nnz);
129        }
130    }
131
132    #[inline(always)]
133    fn raw_iter(&self) -> RawTFVectorIter<'_, N> {
134        RawTFVectorIter {
135            vec: self,
136            pos: 0,
137            end: self.nnz,
138        }
139    }
140
141    #[inline(always)]
142    fn nnz(&self) -> u32 {
143        self.nnz
144    }
145
146    #[inline(always)]
147    fn len(&self) -> u32 {
148        self.len
149    }
150
151    #[inline(always)]
152    fn term_sum(&self) -> u32 {
153        self.term_sum
154    }
155
156    #[inline(always)]
157    unsafe fn from_vec(mut ind_vec: Vec<u32>, mut val_vec: Vec<N>, len: u32, term_sum: u32) -> Self {
158        debug_assert_eq!(
159            ind_vec.len(),
160            val_vec.len(),
161            "ind_vec and val_vec must have the same length"
162        );
163
164        // sort
165        crate::utils::sort::radix_sort_u32_soa(&mut ind_vec, &mut val_vec);
166
167        let nnz = ind_vec.len() as u32;
168
169        if nnz == 0 {
170            let mut v = TFVector::low_new();
171            v.len = len;
172            v.term_sum = term_sum;
173            return v;
174        }
175
176        // Consume the Vecs and avoid an extra copy:
177        // Vec -> Box<[T]> guarantees allocation sized to exactly `len`,
178        // which matches `Layout::array::<T>(nnz)` used by `free_alloc()`.
179        let inds_box: Box<[u32]> = ind_vec.into_boxed_slice();
180        let vals_box: Box<[N]> = val_vec.into_boxed_slice();
181
182        let inds_ptr = Box::into_raw(inds_box) as *mut u32;
183        let vals_ptr = Box::into_raw(vals_box) as *mut N;
184
185        TFVector {
186            inds: unsafe { NonNull::new_unchecked(inds_ptr) },
187            vals: unsafe { NonNull::new_unchecked(vals_ptr) },
188            cap: nnz,
189            nnz,
190            len,
191            term_sum,
192        }
193    }
194
195    #[inline(always)]
196    unsafe fn ind_ptr(&self) -> *mut u32 {
197        self.inds.as_ptr()
198    }
199
200    #[inline(always)]
201    unsafe fn val_ptr(&self) -> *mut N {
202        self.vals.as_ptr()
203    }
204}
205
206
207pub struct RawTFVectorIter<'a, N>
208where
209    N: Num + 'a,
210{
211    vec: &'a TFVector<N>,
212    pos: u32, // front
213    end: u32, // back (exclusive)
214}
215
216impl<'a, N> RawTFVectorIter<'a, N>
217where
218    N: Num + 'a,
219{
220    #[inline]
221    pub fn new(vec: &'a TFVector<N>) -> Self {
222        Self { vec, pos: 0, end: vec.nnz }
223    }
224}
225
226impl<'a, N> Iterator for RawTFVectorIter<'a, N>
227where
228    N: Num + 'a + Copy,
229{
230    type Item = (u32, N);
231
232    #[inline]
233    fn next(&mut self) -> Option<Self::Item> {
234        if self.pos >= self.end {
235            return None;
236        }
237        unsafe {
238            let i = self.pos as usize;
239            self.pos += 1;
240            let ind = *self.vec.inds.as_ptr().add(i);
241            let val = *self.vec.vals.as_ptr().add(i);
242            Some((ind, val))
243        }
244    }
245
246    #[inline]
247    fn size_hint(&self) -> (usize, Option<usize>) {
248        let remaining = (self.end - self.pos) as usize;
249        (remaining, Some(remaining))
250    }
251}
252
253impl<'a, N> DoubleEndedIterator for RawTFVectorIter<'a, N>
254where
255    N: Num + 'a + Copy,
256{
257    #[inline]
258    fn next_back(&mut self) -> Option<Self::Item> {
259        if self.pos >= self.end {
260            return None;
261        }
262        self.end -= 1;
263        unsafe {
264            let i = self.end as usize;
265            let ind = *self.vec.inds.as_ptr().add(i);
266            let val = *self.vec.vals.as_ptr().add(i);
267            Some((ind, val))
268        }
269    }
270}
271
272impl<'a, N> ExactSizeIterator for RawTFVectorIter<'a, N>
273where
274    N: Num + 'a + Copy,
275{
276    #[inline]
277    fn len(&self) -> usize {
278        (self.end - self.pos) as usize
279    }
280}
281
282impl<'a, N> FusedIterator for RawTFVectorIter<'a, N>
283where
284    N: Num + 'a + Copy,
285{}
286
287/// ZeroSpVecの生実装
288#[derive(Debug)]
289#[repr(align(32))] // どうなんだろうか
290pub struct TFVector<N> 
291where N: Num
292{
293    inds: NonNull<u32>,
294    vals: NonNull<N>,
295    cap: u32,
296    nnz: u32,
297    len: u32,
298    /// sum of terms of this document
299    /// denormalize number for this document
300    /// for reverse calculation to get term counts from tf values
301    term_sum: u32, // for future use
302}
303
304/// Low Level Implementation
305impl<N> TFVector<N> 
306where N: Num
307{
308    const VAL_SIZE: usize = mem::size_of::<N>();
309
310    #[inline]
311    fn low_new() -> Self {
312        // ZST は許さん
313        debug_assert!(Self::VAL_SIZE != 0, "Zero-sized type is not supported for TFVector");
314
315        TFVector {
316            // ダングリングポインタで初期化
317            inds: NonNull::dangling(),
318            vals: NonNull::dangling(),
319            cap: 0,
320            nnz: 0,
321            len: 0,
322            term_sum: 0,
323        }
324    }
325
326
327    #[inline]
328    #[allow(dead_code)]
329    fn grow(&mut self) {
330        let new_cap = if self.cap == 0 {
331            1
332        } else {
333            self.cap.checked_mul(2).expect("TFVector capacity overflowed")
334        };
335
336        self.set_cap(new_cap);
337    }
338
339    #[inline]
340    fn set_cap(&mut self, new_cap: u32) {
341        if new_cap == 0 {
342            // キャパシティを0にする場合はメモリを解放する
343            self.free_alloc();
344            return;
345        }
346        let new_inds_layout = Layout::array::<u32>(new_cap as usize).expect("Failed to create inds memory layout");
347        let new_vals_layout = Layout::array::<N>(new_cap as usize).expect("Failed to create vals memory layout");
348
349        if self.cap == 0 {
350            let new_inds_ptr = unsafe { std::alloc::alloc(new_inds_layout) };
351            let new_vals_ptr = unsafe { std::alloc::alloc(new_vals_layout) };
352            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
353                if new_inds_ptr.is_null() {
354                    oom(new_inds_layout);
355                } else {
356                    oom(new_vals_layout);
357                }
358            }
359
360            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
361            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
362            self.cap = new_cap;
363        } else {
364            let old_inds_layout = Layout::array::<u32>(self.cap as usize).expect("Failed to create old inds memory layout");
365            let old_vals_layout = Layout::array::<N>(self.cap as usize).expect("Failed to create old vals memory layout");
366
367            let new_inds_ptr = unsafe { std::alloc::realloc(
368                self.inds.as_ptr().cast::<u8>(),
369                old_inds_layout,
370                new_inds_layout.size(),
371            ) };
372            let new_vals_ptr = unsafe { std::alloc::realloc(
373                self.vals.as_ptr().cast::<u8>(),
374                old_vals_layout,
375                new_vals_layout.size(),
376            ) };
377            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
378                if new_inds_ptr.is_null() {
379                    oom(new_inds_layout);
380                } else {
381                    oom(new_vals_layout);
382                }
383            }
384
385            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
386            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
387            self.cap = new_cap;
388        }
389    }
390
391    #[inline]
392    fn free_alloc(&mut self) {
393        if self.cap != 0 {
394            unsafe {
395                let inds_layout = Layout::array::<u32>(self.cap as usize).unwrap();
396                let vals_layout = Layout::array::<N>(self.cap as usize).unwrap();
397                std::alloc::dealloc(self.inds.as_ptr().cast::<u8>(), inds_layout);
398                std::alloc::dealloc(self.vals.as_ptr().cast::<u8>(), vals_layout);
399            }
400        }
401        self.inds = NonNull::dangling();
402        self.vals = NonNull::dangling();
403        self.cap = 0;
404    }
405}
406
407unsafe impl<N: Num + Send + Sync> Send for TFVector<N> {}
408unsafe impl<N: Num + Sync> Sync for TFVector<N> {}
409
410impl<N> Drop for TFVector<N> 
411where N: Num
412{
413    #[inline]
414    fn drop(&mut self) {
415        self.free_alloc();
416    }
417}
418
419impl<N> Clone for TFVector<N>
420where
421    N: Num + Copy,
422{
423    #[inline]
424    fn clone(&self) -> Self {
425        let mut new_vec = TFVector::low_new();
426        if self.nnz > 0 {
427            new_vec.set_cap(self.nnz);
428            new_vec.len = self.len;
429            new_vec.nnz = self.nnz;
430            new_vec.term_sum = self.term_sum;
431
432            unsafe {
433                std::ptr::copy_nonoverlapping(
434                    self.inds.as_ptr(),
435                    new_vec.inds.as_ptr(),
436                    self.nnz as usize,
437                );
438                std::ptr::copy_nonoverlapping(
439                    self.vals.as_ptr(),
440                    new_vec.vals.as_ptr(),
441                    self.nnz as usize,
442                );
443            }
444        }
445        new_vec
446    }
447}
448
449
450
/// Out-of-memory handler.
/// Terminates the process.
/// Ideally `panic!` would be used here, but panicking on OOM would make the
/// traceback machinery use memory itself, so the process is force-terminated
/// instead.
/// In practice the OS is expected to manage OOM and kill the process before
/// this is ever reached, so this rarely matters.
#[cold]
#[inline(never)]
fn oom(layout: Layout) -> ! {
    use std::alloc::handle_alloc_error;

    handle_alloc_error(layout)
}
461}