tf_idf_vectorizer/utils/datastruct/vector/
mod.rs

1use std::{alloc::Layout, iter::FusedIterator, mem, ptr::NonNull};
2
3use num_traits::Num;
4
5pub mod serde;
6
/// Size in bytes of `TFVector<u8>`, evaluated at compile time.
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<u8>>();
// Compile-time guard: the build fails if `TFVector<u8>` is not exactly 32 bytes.
// NOTE(review): two pointers plus four `u32` fields add up to 32 bytes only with
// 8-byte pointers, so this presumably limits the crate to 64-bit targets — confirm.
static_assertions::const_assert!(TF_VECTOR_SIZE == 32);
9
10pub trait TFVectorTrait<N>
11where N: Num + Copy
12{
13    fn len(&self) -> u32;
14    fn nnz(&self) -> u32;
15    fn term_sum(&self) -> u32;
16    fn new() -> Self;
17    fn new_with_capacity(capacity: u32) -> Self;
18    fn shrink_to_fit(&mut self);
19    fn raw_iter(&self) -> RawTFVectorIter<'_, N>;
20    unsafe fn from_vec(ind_vec: Vec<u32>, val_vec: Vec<N>, len: u32, term_sum: u32) -> Self;
21    unsafe fn ind_ptr(&self) -> *mut u32;
22    unsafe fn val_ptr(&self) -> *mut N;
23    /// Power Jump Search
24    /// Returns Some((value, sp_vec_raw_ind)) if found, None otherwise
25    #[inline(always)]
26    unsafe fn power_jump_search(&self, target: u32, start: usize) -> Option<(N, usize)>
27    where
28        N: Copy,
29    {
30        let nnz = self.nnz() as usize;
31        if start >= nnz {
32            return None;
33        }
34
35        let ind = unsafe { core::slice::from_raw_parts(self.ind_ptr(), nnz) };
36        let val = unsafe { core::slice::from_raw_parts(self.val_ptr(), nnz) };
37
38        // fast path
39        let mut lo = start;
40        let mut hi = start;
41
42        let s = ind[hi];
43        if s == target {
44            return Some((val[hi], hi));
45        }
46        if s > target {
47            return None; // forward-only
48        }
49
50        // galloping
51        let mut step = 1usize;
52        loop {
53            let next_hi = hi + step;
54            if next_hi >= nnz {
55                hi = nnz - 1;
56                break;
57            }
58            hi = next_hi;
59
60            if ind[hi] >= target {
61                break;
62            }
63
64            lo = hi;
65            step <<= 1;
66        }
67
68        // lower_bound in (lo, hi] => [lo+1, hi+1)
69        let mut l = lo + 1;
70        let mut r = hi + 1; // exclusive
71        while l < r {
72            let m = (l + r) >> 1;
73            if ind[m] < target {
74                l = m + 1;
75            } else {
76                r = m;
77            }
78        }
79
80        if l < nnz && ind[l] == target {
81            Some((val[l], l))
82        } else {
83            None
84        }
85    }
86    #[inline(always)]
87    fn get_power_jump(&self, target: u32, cut_down: &mut usize) -> Option<N>
88    where
89        N: Copy,
90    {
91        unsafe {
92            if let Some((v, idx)) = self.power_jump_search(target, *cut_down) {
93                *cut_down = idx;
94                Some(v)
95            } else {
96                None
97            }
98        }
99    }
100    #[inline(always)]
101    fn as_val_slice(&self) -> &[N] {
102        unsafe { core::slice::from_raw_parts(self.val_ptr(), self.nnz() as usize) }
103    }
104    #[inline(always)]
105    fn as_ind_slice(&self) -> &[u32] {
106        unsafe { core::slice::from_raw_parts(self.ind_ptr(), self.nnz() as usize) }
107    }
108}
109
110impl<N> TFVectorTrait<N> for TFVector<N> 
111where N: Num + Copy
112{
113    fn new() -> Self {
114        Self::low_new()
115    }
116
117    #[inline]
118    fn new_with_capacity(capacity: u32) -> Self {
119        let mut vec = Self::low_new();
120        if capacity != 0 {
121            vec.set_cap(capacity);
122        }
123        vec
124    }
125
126    #[inline]
127    fn shrink_to_fit(&mut self) {
128        if self.nnz < self.cap {
129            self.set_cap(self.nnz);
130        }
131    }
132
133    #[inline(always)]
134    fn raw_iter(&self) -> RawTFVectorIter<'_, N> {
135        RawTFVectorIter {
136            vec: self,
137            pos: 0,
138            end: self.nnz,
139        }
140    }
141
142    #[inline(always)]
143    fn nnz(&self) -> u32 {
144        self.nnz
145    }
146
147    #[inline(always)]
148    fn len(&self) -> u32 {
149        self.len
150    }
151
152    #[inline(always)]
153    fn term_sum(&self) -> u32 {
154        self.term_sum
155    }
156
157    #[inline(always)]
158    unsafe fn from_vec(mut ind_vec: Vec<u32>, mut val_vec: Vec<N>, len: u32, term_sum: u32) -> Self {
159        debug_assert_eq!(
160            ind_vec.len(),
161            val_vec.len(),
162            "ind_vec and val_vec must have the same length"
163        );
164
165        // sort
166        crate::utils::sort::radix_sort_u32_soa(&mut ind_vec, &mut val_vec);
167
168        let nnz = ind_vec.len() as u32;
169
170        if nnz == 0 {
171            let mut v = TFVector::low_new();
172            v.len = len;
173            v.term_sum = term_sum;
174            return v;
175        }
176
177        // Consume the Vecs and avoid an extra copy:
178        // Vec -> Box<[T]> guarantees allocation sized to exactly `len`,
179        // which matches `Layout::array::<T>(nnz)` used by `free_alloc()`.
180        let inds_box: Box<[u32]> = ind_vec.into_boxed_slice();
181        let vals_box: Box<[N]> = val_vec.into_boxed_slice();
182
183        let inds_ptr = Box::into_raw(inds_box) as *mut u32;
184        let vals_ptr = Box::into_raw(vals_box) as *mut N;
185
186        TFVector {
187            inds: unsafe { NonNull::new_unchecked(inds_ptr) },
188            vals: unsafe { NonNull::new_unchecked(vals_ptr) },
189            cap: nnz,
190            nnz,
191            len,
192            term_sum,
193        }
194    }
195
196    #[inline(always)]
197    unsafe fn ind_ptr(&self) -> *mut u32 {
198        self.inds.as_ptr()
199    }
200
201    #[inline(always)]
202    unsafe fn val_ptr(&self) -> *mut N {
203        self.vals.as_ptr()
204    }
205}
206
207
/// Iterator over the stored `(index, value)` pairs of a [`TFVector`].
///
/// Keeps a front cursor and a back cursor so it can be consumed from both ends.
pub struct RawTFVectorIter<'a, N>
where
    N: Num + 'a,
{
    /// Vector being iterated.
    vec: &'a TFVector<N>,
    pos: u32, // front
    end: u32, // back (exclusive)
}
216
217impl<'a, N> RawTFVectorIter<'a, N>
218where
219    N: Num + 'a,
220{
221    #[inline]
222    pub fn new(vec: &'a TFVector<N>) -> Self {
223        Self { vec, pos: 0, end: vec.nnz }
224    }
225}
226
227impl<'a, N> Iterator for RawTFVectorIter<'a, N>
228where
229    N: Num + 'a + Copy,
230{
231    type Item = (u32, N);
232
233    #[inline]
234    fn next(&mut self) -> Option<Self::Item> {
235        if self.pos >= self.end {
236            return None;
237        }
238        unsafe {
239            let i = self.pos as usize;
240            self.pos += 1;
241            let ind = *self.vec.inds.as_ptr().add(i);
242            let val = *self.vec.vals.as_ptr().add(i);
243            Some((ind, val))
244        }
245    }
246
247    #[inline]
248    fn size_hint(&self) -> (usize, Option<usize>) {
249        let remaining = (self.end - self.pos) as usize;
250        (remaining, Some(remaining))
251    }
252}
253
254impl<'a, N> DoubleEndedIterator for RawTFVectorIter<'a, N>
255where
256    N: Num + 'a + Copy,
257{
258    #[inline]
259    fn next_back(&mut self) -> Option<Self::Item> {
260        if self.pos >= self.end {
261            return None;
262        }
263        self.end -= 1;
264        unsafe {
265            let i = self.end as usize;
266            let ind = *self.vec.inds.as_ptr().add(i);
267            let val = *self.vec.vals.as_ptr().add(i);
268            Some((ind, val))
269        }
270    }
271}
272
273impl<'a, N> ExactSizeIterator for RawTFVectorIter<'a, N>
274where
275    N: Num + 'a + Copy,
276{
277    #[inline]
278    fn len(&self) -> usize {
279        (self.end - self.pos) as usize
280    }
281}
282
// Marker impl: once `pos >= end`, `next` and `next_back` return `None` without
// touching the cursors, so the iterator keeps returning `None` after exhaustion.
impl<'a, N> FusedIterator for RawTFVectorIter<'a, N>
where
    N: Num + 'a + Copy,
{}
287
/// Raw implementation of ZeroSpVec.
///
/// Sparse vector stored as two parallel, manually managed buffers: `inds`
/// (entry indices) and `vals` (entry values). Both buffers are allocated with
/// room for `cap` elements; both pointers are dangling while `cap == 0`.
#[derive(Debug)]
pub struct TFVector<N> 
where N: Num
{
    /// Pointer to the index buffer.
    inds: NonNull<u32>,
    /// Pointer to the value buffer.
    vals: NonNull<N>,
    /// Number of elements each buffer is allocated for (0 = no allocation).
    cap: u32,
    /// Number of stored entries.
    nnz: u32,
    /// Stored length value.
    /// NOTE(review): presumably the logical (dense) dimension — confirm.
    len: u32,
    /// sum of terms of this document
    /// denormalize number for this document
    /// for reverse calculation to get term counts from tf values
    term_sum: u32, // for future use
}
303
304/// Low Level Implementation
305impl<N> TFVector<N> 
306where N: Num
307{
308    const VAL_SIZE: usize = mem::size_of::<N>();
309
310    #[inline]
311    fn low_new() -> Self {
312        // ZST は許さん
313        debug_assert!(Self::VAL_SIZE != 0, "Zero-sized type is not supported for TFVector");
314
315        TFVector {
316            // ダングリングポインタで初期化
317            inds: NonNull::dangling(),
318            vals: NonNull::dangling(),
319            cap: 0,
320            nnz: 0,
321            len: 0,
322            term_sum: 0,
323        }
324    }
325
326
327    #[inline]
328    #[allow(dead_code)]
329    fn grow(&mut self) {
330        let new_cap = if self.cap == 0 {
331            1
332        } else {
333            self.cap.checked_mul(2).expect("TFVector capacity overflowed")
334        };
335
336        self.set_cap(new_cap);
337    }
338
339    #[inline]
340    fn set_cap(&mut self, new_cap: u32) {
341        if new_cap == 0 {
342            // キャパシティを0にする場合はメモリを解放する
343            self.free_alloc();
344            return;
345        }
346        let new_inds_layout = Layout::array::<u32>(new_cap as usize).expect("Failed to create inds memory layout");
347        let new_vals_layout = Layout::array::<N>(new_cap as usize).expect("Failed to create vals memory layout");
348
349        if self.cap == 0 {
350            let new_inds_ptr = unsafe { std::alloc::alloc(new_inds_layout) };
351            let new_vals_ptr = unsafe { std::alloc::alloc(new_vals_layout) };
352            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
353                if new_inds_ptr.is_null() {
354                    oom(new_inds_layout);
355                } else {
356                    oom(new_vals_layout);
357                }
358            }
359
360            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
361            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
362            self.cap = new_cap;
363        } else {
364            let old_inds_layout = Layout::array::<u32>(self.cap as usize).expect("Failed to create old inds memory layout");
365            let old_vals_layout = Layout::array::<N>(self.cap as usize).expect("Failed to create old vals memory layout");
366
367            let new_inds_ptr = unsafe { std::alloc::realloc(
368                self.inds.as_ptr().cast::<u8>(),
369                old_inds_layout,
370                new_inds_layout.size(),
371            ) };
372            let new_vals_ptr = unsafe { std::alloc::realloc(
373                self.vals.as_ptr().cast::<u8>(),
374                old_vals_layout,
375                new_vals_layout.size(),
376            ) };
377            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
378                if new_inds_ptr.is_null() {
379                    oom(new_inds_layout);
380                } else {
381                    oom(new_vals_layout);
382                }
383            }
384
385            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
386            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
387            self.cap = new_cap;
388        }
389    }
390
391    #[inline]
392    fn free_alloc(&mut self) {
393        if self.cap != 0 {
394            unsafe {
395                let inds_layout = Layout::array::<u32>(self.cap as usize).unwrap();
396                let vals_layout = Layout::array::<N>(self.cap as usize).unwrap();
397                std::alloc::dealloc(self.inds.as_ptr().cast::<u8>(), inds_layout);
398                std::alloc::dealloc(self.vals.as_ptr().cast::<u8>(), vals_layout);
399            }
400        }
401        self.inds = NonNull::dangling();
402        self.vals = NonNull::dangling();
403        self.cap = 0;
404    }
405}
406
// SAFETY: the raw pointers only refer to buffers this file allocates and frees
// for the vector itself (`set_cap`, `from_vec`, `free_alloc`), and `Clone` makes
// a deep copy, so no other value visible here aliases them.
// NOTE(review): `Send` additionally demands `N: Sync`, which is stricter than
// what owning containers usually require — presumably intentional, confirm.
unsafe impl<N: Num + Send + Sync> Send for TFVector<N> {}
// SAFETY: the trait methods visible in this file that take `&self` only read
// the buffers.
unsafe impl<N: Num + Sync> Sync for TFVector<N> {}
409
/// Releases the index and value buffers when the vector goes out of scope.
impl<N> Drop for TFVector<N> 
where N: Num
{
    #[inline]
    fn drop(&mut self) {
        // `free_alloc` does nothing but reset the fields when `cap == 0`.
        self.free_alloc();
    }
}
418
419impl<N> Clone for TFVector<N>
420where
421    N: Num + Copy,
422{
423    #[inline]
424    fn clone(&self) -> Self {
425        let mut new_vec = TFVector::low_new();
426        if self.nnz > 0 {
427            new_vec.set_cap(self.nnz);
428            new_vec.len = self.len;
429            new_vec.nnz = self.nnz;
430            new_vec.term_sum = self.term_sum;
431
432            unsafe {
433                std::ptr::copy_nonoverlapping(
434                    self.inds.as_ptr(),
435                    new_vec.inds.as_ptr(),
436                    self.nnz as usize,
437                );
438                std::ptr::copy_nonoverlapping(
439                    self.vals.as_ptr(),
440                    new_vec.vals.as_ptr(),
441                    self.nnz as usize,
442                );
443            }
444        }
445        new_vec
446    }
447}
448
449
450
/// Handler for out-of-memory conditions.
/// Terminates the process.
/// Ideally `panic!` would be used here, but when memory is exhausted a panic
/// would itself consume memory while producing the traceback,
/// so the process is reluctantly terminated by force instead.
/// In practice OOM is managed by the OS, which should kill the process before
/// this point is reached, so this has little practical effect.
///
/// Never returns; delegates to `std::alloc::handle_alloc_error` with the
/// layout of the allocation that failed.
#[cold]
#[inline(never)]
fn oom(layout: Layout) -> ! {
    std::alloc::handle_alloc_error(layout)
}
461}