tf_idf_vectorizer/utils/datastruct/vector/
mod.rs

1use std::{alloc::Layout, iter::FusedIterator, mem, ptr::NonNull};
2
3use num_traits::Num;
4
5pub mod serde;
6
7const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<u8>>();
8static_assertions::const_assert!(TF_VECTOR_SIZE == 32);
9
10pub trait TFVectorTrait<N>
11where N: Num + Copy
12{
13    fn len(&self) -> u32;
14    fn nnz(&self) -> u32;
15    fn term_sum(&self) -> u32;
16    fn new() -> Self;
17    fn new_with_capacity(capacity: u32) -> Self;
18    fn shrink_to_fit(&mut self);
19    fn raw_iter(&self) -> RawTFVectorIter<'_, N>;
20    unsafe fn from_vec(ind_vec: Vec<u32>, val_vec: Vec<N>, len: u32, term_sum: u32) -> Self;
21    unsafe fn ind_ptr(&self) -> *mut u32;
22    unsafe fn val_ptr(&self) -> *mut N;
23    /// Power Jump Search
24    /// Returns Some((value, sp_vec_raw_ind)) if found, None otherwise
25    unsafe fn power_jump_search(&self, target: u32, start: usize) -> Option<(N, usize)>
26    where
27        N: Copy,
28    {
29        let nnz = self.nnz() as usize;
30        if start >= nnz {
31            return None;
32        }
33
34        let ind = unsafe { core::slice::from_raw_parts(self.ind_ptr(), nnz) };
35        let val = unsafe { core::slice::from_raw_parts(self.val_ptr(), nnz) };
36
37        // fast path
38        let mut lo = start;
39        let mut hi = start;
40
41        let s = ind[hi];
42        if s == target {
43            return Some((val[hi], hi));
44        }
45        if s > target {
46            return None; // forward-only
47        }
48
49        // galloping
50        let mut step = 1usize;
51        loop {
52            let next_hi = hi + step;
53            if next_hi >= nnz {
54                hi = nnz - 1;
55                break;
56            }
57            hi = next_hi;
58
59            if ind[hi] >= target {
60                break;
61            }
62
63            lo = hi;
64            step <<= 1;
65        }
66
67        // lower_bound in (lo, hi] => [lo+1, hi+1)
68        let mut l = lo + 1;
69        let mut r = hi + 1; // exclusive
70        while l < r {
71            let m = (l + r) >> 1;
72            if ind[m] < target {
73                l = m + 1;
74            } else {
75                r = m;
76            }
77        }
78
79        if l < nnz && ind[l] == target {
80            Some((val[l], l))
81        } else {
82            None
83        }
84    }
85    fn get_power_jump(&self, target: u32, cut_down: &mut usize) -> Option<N>
86    where
87        N: Copy,
88    {
89        unsafe {
90            if let Some((v, idx)) = self.power_jump_search(target, *cut_down) {
91                *cut_down = idx;
92                Some(v)
93            } else {
94                None
95            }
96        }
97    }
98    fn as_val_slice(&self) -> &[N] {
99        unsafe { core::slice::from_raw_parts(self.val_ptr(), self.nnz() as usize) }
100    }
101    fn as_ind_slice(&self) -> &[u32] {
102        unsafe { core::slice::from_raw_parts(self.ind_ptr(), self.nnz() as usize) }
103    }
104}
105
106impl<N> TFVectorTrait<N> for TFVector<N> 
107where N: Num + Copy
108{
109    fn new() -> Self {
110        Self::low_new()
111    }
112
113    fn new_with_capacity(capacity: u32) -> Self {
114        let mut vec = Self::low_new();
115        if capacity != 0 {
116            vec.set_cap(capacity);
117        }
118        vec
119    }
120
121    #[inline]
122    fn shrink_to_fit(&mut self) {
123        if self.nnz < self.cap {
124            self.set_cap(self.nnz);
125        }
126    }
127
128    #[inline]
129    fn raw_iter(&self) -> RawTFVectorIter<'_, N> {
130        RawTFVectorIter {
131            vec: self,
132            pos: 0,
133            end: self.nnz,
134        }
135    }
136
137    #[inline]
138    fn nnz(&self) -> u32 {
139        self.nnz
140    }
141
142    #[inline]
143    fn len(&self) -> u32 {
144        self.len
145    }
146
147    fn term_sum(&self) -> u32 {
148        self.term_sum
149    }
150
151    unsafe fn from_vec(mut ind_vec: Vec<u32>, mut val_vec: Vec<N>, len: u32, term_sum: u32) -> Self {
152        debug_assert_eq!(
153            ind_vec.len(),
154            val_vec.len(),
155            "ind_vec and val_vec must have the same length"
156        );
157
158        // sort
159        crate::utils::sort::radix_sort_u32_soa(&mut ind_vec, &mut val_vec);
160
161        let nnz = ind_vec.len() as u32;
162
163        if nnz == 0 {
164            let mut v = TFVector::low_new();
165            v.len = len;
166            v.term_sum = term_sum;
167            return v;
168        }
169
170        // Consume the Vecs and avoid an extra copy:
171        // Vec -> Box<[T]> guarantees allocation sized to exactly `len`,
172        // which matches `Layout::array::<T>(nnz)` used by `free_alloc()`.
173        let inds_box: Box<[u32]> = ind_vec.into_boxed_slice();
174        let vals_box: Box<[N]> = val_vec.into_boxed_slice();
175
176        let inds_ptr = Box::into_raw(inds_box) as *mut u32;
177        let vals_ptr = Box::into_raw(vals_box) as *mut N;
178
179        TFVector {
180            inds: unsafe { NonNull::new_unchecked(inds_ptr) },
181            vals: unsafe { NonNull::new_unchecked(vals_ptr) },
182            cap: nnz,
183            nnz,
184            len,
185            term_sum,
186        }
187    }
188
189    unsafe fn ind_ptr(&self) -> *mut u32 {
190        self.inds.as_ptr()
191    }
192
193    unsafe fn val_ptr(&self) -> *mut N {
194        self.vals.as_ptr()
195    }
196}
197
198
199pub struct RawTFVectorIter<'a, N>
200where
201    N: Num + 'a,
202{
203    vec: &'a TFVector<N>,
204    pos: u32, // front
205    end: u32, // back (exclusive)
206}
207
208impl<'a, N> RawTFVectorIter<'a, N>
209where
210    N: Num + 'a,
211{
212    #[inline]
213    pub fn new(vec: &'a TFVector<N>) -> Self {
214        Self { vec, pos: 0, end: vec.nnz }
215    }
216}
217
218impl<'a, N> Iterator for RawTFVectorIter<'a, N>
219where
220    N: Num + 'a + Copy,
221{
222    type Item = (u32, N);
223
224    #[inline]
225    fn next(&mut self) -> Option<Self::Item> {
226        if self.pos >= self.end {
227            return None;
228        }
229        unsafe {
230            let i = self.pos as usize;
231            self.pos += 1;
232            let ind = *self.vec.inds.as_ptr().add(i);
233            let val = *self.vec.vals.as_ptr().add(i);
234            Some((ind, val))
235        }
236    }
237
238    #[inline]
239    fn size_hint(&self) -> (usize, Option<usize>) {
240        let remaining = (self.end - self.pos) as usize;
241        (remaining, Some(remaining))
242    }
243}
244
245impl<'a, N> DoubleEndedIterator for RawTFVectorIter<'a, N>
246where
247    N: Num + 'a + Copy,
248{
249    #[inline]
250    fn next_back(&mut self) -> Option<Self::Item> {
251        if self.pos >= self.end {
252            return None;
253        }
254        self.end -= 1;
255        unsafe {
256            let i = self.end as usize;
257            let ind = *self.vec.inds.as_ptr().add(i);
258            let val = *self.vec.vals.as_ptr().add(i);
259            Some((ind, val))
260        }
261    }
262}
263
264impl<'a, N> ExactSizeIterator for RawTFVectorIter<'a, N>
265where
266    N: Num + 'a + Copy,
267{
268    #[inline]
269    fn len(&self) -> usize {
270        (self.end - self.pos) as usize
271    }
272}
273
274impl<'a, N> FusedIterator for RawTFVectorIter<'a, N>
275where
276    N: Num + 'a + Copy,
277{}
278
279/// ZeroSpVecの生実装
280#[derive(Debug)]
281pub struct TFVector<N> 
282where N: Num
283{
284    inds: NonNull<u32>,
285    vals: NonNull<N>,
286    cap: u32,
287    nnz: u32,
288    len: u32,
289    /// sum of terms of this document
290    /// denormalize number for this document
291    /// for reverse calculation to get term counts from tf values
292    term_sum: u32, // for future use
293}
294
295/// Low Level Implementation
296impl<N> TFVector<N> 
297where N: Num
298{
299    const VAL_SIZE: usize = mem::size_of::<N>();
300
301    #[inline]
302    fn low_new() -> Self {
303        // ZST は許さん
304        debug_assert!(Self::VAL_SIZE != 0, "Zero-sized type is not supported for TFVector");
305
306        TFVector {
307            // ダングリングポインタで初期化
308            inds: NonNull::dangling(),
309            vals: NonNull::dangling(),
310            cap: 0,
311            nnz: 0,
312            len: 0,
313            term_sum: 0,
314        }
315    }
316
317
318    #[inline]
319    #[allow(dead_code)]
320    fn grow(&mut self) {
321        let new_cap = if self.cap == 0 {
322            1
323        } else {
324            self.cap.checked_mul(2).expect("TFVector capacity overflowed")
325        };
326
327        self.set_cap(new_cap);
328    }
329
330    #[inline]
331    fn set_cap(&mut self, new_cap: u32) {
332        if new_cap == 0 {
333            // キャパシティを0にする場合はメモリを解放する
334            self.free_alloc();
335            return;
336        }
337        let new_inds_layout = Layout::array::<u32>(new_cap as usize).expect("Failed to create inds memory layout");
338        let new_vals_layout = Layout::array::<N>(new_cap as usize).expect("Failed to create vals memory layout");
339
340        if self.cap == 0 {
341            let new_inds_ptr = unsafe { std::alloc::alloc(new_inds_layout) };
342            let new_vals_ptr = unsafe { std::alloc::alloc(new_vals_layout) };
343            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
344                if new_inds_ptr.is_null() {
345                    oom(new_inds_layout);
346                } else {
347                    oom(new_vals_layout);
348                }
349            }
350
351            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
352            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
353            self.cap = new_cap;
354        } else {
355            let old_inds_layout = Layout::array::<u32>(self.cap as usize).expect("Failed to create old inds memory layout");
356            let old_vals_layout = Layout::array::<N>(self.cap as usize).expect("Failed to create old vals memory layout");
357
358            let new_inds_ptr = unsafe { std::alloc::realloc(
359                self.inds.as_ptr().cast::<u8>(),
360                old_inds_layout,
361                new_inds_layout.size(),
362            ) };
363            let new_vals_ptr = unsafe { std::alloc::realloc(
364                self.vals.as_ptr().cast::<u8>(),
365                old_vals_layout,
366                new_vals_layout.size(),
367            ) };
368            if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
369                if new_inds_ptr.is_null() {
370                    oom(new_inds_layout);
371                } else {
372                    oom(new_vals_layout);
373                }
374            }
375
376            self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
377            self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
378            self.cap = new_cap;
379        }
380    }
381
382    #[inline]
383    fn free_alloc(&mut self) {
384        if self.cap != 0 {
385            unsafe {
386                let inds_layout = Layout::array::<u32>(self.cap as usize).unwrap();
387                let vals_layout = Layout::array::<N>(self.cap as usize).unwrap();
388                std::alloc::dealloc(self.inds.as_ptr().cast::<u8>(), inds_layout);
389                std::alloc::dealloc(self.vals.as_ptr().cast::<u8>(), vals_layout);
390            }
391        }
392        self.inds = NonNull::dangling();
393        self.vals = NonNull::dangling();
394        self.cap = 0;
395    }
396}
397
398unsafe impl<N: Num + Send + Sync> Send for TFVector<N> {}
399unsafe impl<N: Num + Sync> Sync for TFVector<N> {}
400
401impl<N> Drop for TFVector<N> 
402where N: Num
403{
404    #[inline]
405    fn drop(&mut self) {
406        self.free_alloc();
407    }
408}
409
410impl<N> Clone for TFVector<N>
411where
412    N: Num + Copy,
413{
414    #[inline]
415    fn clone(&self) -> Self {
416        let mut new_vec = TFVector::low_new();
417        if self.nnz > 0 {
418            new_vec.set_cap(self.nnz);
419            new_vec.len = self.len;
420            new_vec.nnz = self.nnz;
421            new_vec.term_sum = self.term_sum;
422
423            unsafe {
424                std::ptr::copy_nonoverlapping(
425                    self.inds.as_ptr(),
426                    new_vec.inds.as_ptr(),
427                    self.nnz as usize,
428                );
429                std::ptr::copy_nonoverlapping(
430                    self.vals.as_ptr(),
431                    new_vec.vals.as_ptr(),
432                    self.nnz as usize,
433                );
434            }
435        }
436        new_vec
437    }
438}
439
440
441
/// Out-of-memory handler.
///
/// Terminates the process instead of returning.
/// Ideally `panic!` would be used, but panicking on OOM would itself consume
/// memory for the traceback, so the process is forcibly terminated instead.
/// In practice OOM is normally managed by the OS, which is expected to kill
/// the process before this point, so this has limited practical effect.
#[cold]
#[inline(never)]
fn oom(layout: Layout) -> ! {
    use std::alloc::handle_alloc_error;

    handle_alloc_error(layout)
}