tf_idf_vectorizer/utils/math/vector/
mod.rs

1pub mod math;
2pub mod math_normalized;
3pub mod serde;
4
5use std::{alloc::{alloc, dealloc, realloc, Layout}, fmt, marker::PhantomData, mem, ptr::{self, NonNull}};
6use std::ops::Index;
7use std::fmt::Debug;
8
9use num::Num;
10/// ZeroSpVecは0要素を疎とした過疎ベクトルを実装です
11/// indices と valuesを持ち
12/// indicesは要素のインデックスを保持し、
13/// valuesは要素の値を保持します
14/// 
15/// 要素はindicesの昇順でソートされていることを保証します
16pub struct ZeroSpVec<N>  
17where N: Num
18{
19    buf: RawZeroSpVec<N>,
20    len: usize,
21    nnz: usize,
22    zero: N,
23}
24
25impl<N> ZeroSpVec<N> 
26where N: Num
27{
28    #[inline]
29    fn ind_ptr(&self) -> *mut usize {
30        self.buf.ind_ptr.as_ptr()
31    }
32
33    #[inline]
34    fn val_ptr(&self) -> *mut N {
35        self.buf.val_ptr.as_ptr()
36    }
37
38    /// raw_pushは、要素を追加するためのメソッドです。
39    /// ただし全体の長さを考慮しませんのであとでlenを更新する必要があります。
40    /// 
41    /// # Arguments
42    /// - `index` - 追加する要素のインデックス
43    /// - `value` - 追加する要素の値
44    #[inline]
45    fn raw_push(&mut self, index: usize, value: N) {
46        if self.nnz == self.buf.cap {
47            self.buf.grow();
48        }
49        unsafe {
50            let val_ptr = self.val_ptr().add(self.nnz);
51            let ind_ptr = self.ind_ptr().add(self.nnz);
52            ptr::write(val_ptr, value);
53            ptr::write(ind_ptr, index);
54        }
55        self.nnz += 1;
56    }
57
58    #[inline]
59    fn ind_binary_search(&self, index: &usize) -> Result<usize, usize> {
60        // 要素が無い場合は「まだどこにも挿入されていない」ので Err(0)
61        if self.nnz == 0 {
62            return Err(0);
63        }
64
65        let mut left = 0;
66        let mut right = self.nnz - 1;
67        while left < right {
68            let mid = left + (right - left) / 2;
69            let mid_index = unsafe { ptr::read(self.ind_ptr().add(mid)) };
70            if mid_index == *index {
71                return Ok(mid);
72            } else if mid_index < *index {
73                left = mid + 1;
74            } else {
75                right = mid;
76            }
77        }
78
79        // ループ終了後 left == right の位置になっている
80        let final_index = unsafe { ptr::read(self.ind_ptr().add(left)) };
81        if final_index == *index {
82            Ok(left)
83        } else if final_index < *index {
84            Err(left + 1)
85        } else {
86            Err(left)
87        }
88    }
89
90    #[inline]
91    pub fn new() -> Self {
92        ZeroSpVec {
93            buf: RawZeroSpVec::new(),
94            len: 0,
95            nnz: 0,
96            zero: N::zero(),
97        }
98    }
99
100    #[inline]
101    pub fn with_capacity(cap: usize) -> Self {
102        let mut buf = RawZeroSpVec::new();
103        buf.cap = cap;
104        buf.cap_set();
105        ZeroSpVec {
106            buf: buf,
107            len: 0,
108            nnz: 0,
109            zero: N::zero(),
110        }
111    }
112
113    #[inline]
114    pub fn reserve(&mut self, additional: usize) {
115        let new_cap = self.nnz + additional;
116        if new_cap > self.buf.cap {
117            self.buf.cap = new_cap;
118            self.buf.re_cap_set();
119        }
120    }
121
122    #[inline]
123    pub fn shrink_to_fit(&mut self) {
124        if self.len < self.buf.cap {
125            let new_cap = self.nnz;
126            self.buf.cap = new_cap;
127            self.buf.re_cap_set();
128        }
129    }
130
131    #[inline]
132    pub fn is_empty(&self) -> bool {
133        self.len == 0
134    }
135
136    #[inline]
137    pub fn len(&self) -> usize {
138        self.len
139    }
140
141    #[inline]
142    pub fn capacity(&self) -> usize {
143        self.buf.cap
144    }
145
146    #[inline]
147    pub fn nnz(&self) -> usize {
148        self.nnz
149    }
150
151    #[inline]
152    pub fn add_dim(&mut self, dim: usize) {
153        self.len += dim;
154    }
155
156    #[inline]
157    pub fn clear(&mut self) {
158        while let Some(_) = self.pop() {
159            // do nothing
160        }
161    }
162
163    #[inline]
164    pub fn push(&mut self, elem: N) {
165        if self.nnz == self.buf.cap {
166            self.buf.grow();
167        }
168        if elem != N::zero() {
169            unsafe {
170                let val_ptr = self.val_ptr().add(self.nnz);
171                let ind_ptr = self.ind_ptr().add(self.nnz);
172                ptr::write(val_ptr, elem);
173                ptr::write(ind_ptr, self.len);
174            }
175            self.nnz += 1;
176        }
177        self.len += 1;
178    }
179
180    #[inline]
181    pub fn pop(&mut self) -> Option<N> {
182        if self.nnz == 0 {
183            return None;
184        }
185        let pop_element = if self.nnz == self.len {
186            self.nnz -= 1;
187            unsafe {
188                Some(ptr::read(self.val_ptr().add(self.nnz)))
189            }
190        } else {
191            Some(N::zero())
192        };
193        self.len -= 1;
194        pop_element
195    }
196
197    #[inline]
198    pub fn get(&self, index: usize) -> Option<&N> {
199        if index >= self.len {
200            return None;
201        }
202        match self.ind_binary_search(&index) {
203            Ok(idx) => {
204                unsafe {
205                    Some(&*self.val_ptr().add(idx))
206                }
207            },
208            Err(_) => {
209                Some(&self.zero)
210            }
211        }
212    }
213
214    #[inline]
215    pub fn get_ind(&self, index: usize) -> Option<usize> {
216        if index >= self.nnz {
217            return None;
218        }
219        unsafe {
220            Some(ptr::read(self.ind_ptr().add(index)))
221        }
222    }
223
224
225
226    /// removeメソッド
227    /// 
228    /// `index` 番目の要素を削除し、削除した要素を返します。
229    /// - 論理インデックス `index` が物理的に存在すれば、その値を返す
230    /// - 物理的になければ(= デフォルト扱いだった)デフォルト値を返す
231    /// 
232    /// # Arguments
233    /// - `index` - 削除する要素の論理インデックス
234    /// 
235    /// # Returns
236    /// - `N` - 削除した要素の値
237    #[inline]
238    pub fn remove(&mut self, index: usize) -> N {
239        debug_assert!(index < self.len, "index out of bounds");
240        
241        // 論理的な要素数は常に1つ減る
242        self.len -= 1;
243
244        match self.ind_binary_search(&index) {
245            Ok(i) => {
246                // 今回削除する要素を読みだす
247                let removed_val = unsafe {
248                    ptr::read(self.val_ptr().add(i))
249                };
250
251                // `i` 番目を削除するので、後ろを前にシフト
252                let count = self.nnz - i - 1;
253                if count > 0 {
254                    unsafe {
255                        // 値をコピーして前につめる
256                        ptr::copy(
257                            self.val_ptr().add(i + 1),
258                            self.val_ptr().add(i),
259                            count
260                        );
261                        // インデックスもコピーして前につめる
262                        ptr::copy(
263                            self.ind_ptr().add(i + 1),
264                            self.ind_ptr().add(i),
265                            count
266                        );
267                        // シフトした後のインデックスは全て -1 (1つ前に詰める)
268                        for offset in i..(self.nnz - 1) {
269                            *self.ind_ptr().add(offset) -= 1;
270                        }
271                    }
272                }
273                // nnzは 1 減
274                self.nnz -= 1;
275
276                // 取り除いた要素を返す
277                removed_val
278            }
279            Err(i) => {
280                // index は詰める必要があるので、i 以降の要素のインデックスを -1
281                // (たとえば “要素自体は無い” けど、後ろにある要素は
282                //  論理インデックスが 1 つ前になる)
283                if i < self.nnz {
284                    unsafe {
285                        for offset in i..self.nnz {
286                            *self.ind_ptr().add(offset) -= 1;
287                        }
288                    }
289                }
290
291                // 0返す
292                N::zero()
293            }
294        }
295    }
296
297    #[inline]
298    pub fn from_vec(vec: Vec<N>) -> Self {
299        let mut zero_sp_vec = ZeroSpVec::with_capacity(vec.len());
300        for entry in vec {
301            zero_sp_vec.push(entry);
302        }
303        zero_sp_vec
304    }
305
306    #[inline]
307    pub fn iter(&self) -> ZeroSpVecIter<N> {
308        ZeroSpVecIter {
309            vec: self,
310            pos: 0,
311        }
312    }
313
314    #[inline]
315    pub fn raw_iter(&self) -> ZeroSpVecRawIter<N> {
316        ZeroSpVecRawIter {
317            vec: self,
318            pos: 0,
319        }
320    }
321}
322
323unsafe impl <N: Num + Send> Send for ZeroSpVec<N> {}
324unsafe impl <N: Num + Sync> Sync for ZeroSpVec<N> {}
325
326impl<N> Clone for ZeroSpVec<N> 
327where N: Num
328{
329    #[inline]
330    fn clone(&self) -> Self {
331        ZeroSpVec {
332            buf: self.buf.clone(),
333            len: self.len,
334            nnz: self.nnz,
335            zero: N::zero(),
336        }
337    }
338}
339
340impl<N> Drop for ZeroSpVec<N> 
341where N: Num
342{
343    #[inline]
344    fn drop(&mut self) {
345        // RawZeroSpVecで実装済み
346    }
347}
348
349impl<N> Default for ZeroSpVec<N> 
350where N: Num
351{
352    #[inline]
353    fn default() -> Self {
354        ZeroSpVec::new()
355    }
356}
357
358impl<N> Index<usize> for ZeroSpVec<N> 
359where N: Num
360{
361    type Output = N;
362
363    #[inline]
364    fn index(&self, index: usize) -> &Self::Output {
365        self.get(index).expect("index out of bounds")
366    }
367}
368
369impl<N: Num + Debug> Debug for ZeroSpVec<N> {
370    #[inline]
371    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372        if f.sign_plus() {
373            f.debug_struct("DefaultSparseVec")
374                .field("buf", &self.buf)
375                .field("nnz", &self.nnz)
376                .field("len", &self.len)
377                .field("zero", &self.zero)
378                .finish()
379        } else if f.alternate() {
380            write!(f, "ZeroSpVec({:?})", self.iter().collect::<Vec<&N>>())
381        } else {
382            f.debug_list().entries((0..self.len).map(|i| self.get(i).unwrap())).finish()
383        }
384    }
385}
386
387pub struct ZeroSpVecIter<'a, N> 
388where N: Num
389{
390    vec: &'a ZeroSpVec<N>,
391    pos: usize,
392}
393
394impl<'a, N> Iterator for ZeroSpVecIter<'a, N> 
395where N: Num
396{
397    type Item = &'a N;
398
399    #[inline]
400    fn next(&mut self) -> Option<Self::Item> {
401        self.vec.get(self.pos).map(|val| {
402            self.pos += 1;
403            val
404        })
405    }
406}
407
408pub struct ZeroSpVecRawIter<'a, N> 
409where N: Num
410{
411    vec: &'a ZeroSpVec<N>,
412    pos: usize,
413}
414
415impl<'a, N> Iterator for ZeroSpVecRawIter<'a, N> 
416where N: Num
417{
418    type Item = (usize, &'a N);
419
420    #[inline]
421    fn next(&mut self) -> Option<Self::Item> {
422        if self.pos < self.vec.nnz() {
423            let index = unsafe { *self.vec.ind_ptr().add(self.pos) };
424            let value = unsafe { &*self.vec.val_ptr().add(self.pos) };
425            self.pos += 1;
426            Some((index, value))
427        } else {
428            None
429        }
430    }
431    
432}
433
434
435
436
437
438
439
440
441
442
443
444/// ZeroSpVecの生実装
445#[derive(Debug)]
446struct RawZeroSpVec<N> 
447where N: Num 
448{
449    val_ptr: NonNull<N>,
450    ind_ptr: NonNull<usize>,
451    /// cap 定義
452    /// 0 => メモリ未確保 (flag)
453    /// usize::MAX =>  zero size struct (ZST) として定義 処理の簡略化を実施 (flag)
454    /// _ => 実際のcapN
455    cap: usize,
456    _marker: PhantomData<N>, // 所有権管理用にPhantomDataを追加
457}
458
459impl<N> RawZeroSpVec<N> 
460where N: Num
461{
462    #[inline]
463    fn new() -> Self {
464        // zero size struct (ZST)をusize::MAXと定義 ある種のフラグとして使用
465        let cap = if mem::size_of::<N>() == 0 { std::usize::MAX } else { 0 }; 
466
467        RawZeroSpVec {
468            // 空のポインタを代入しておく メモリ確保を遅延させる
469            val_ptr: NonNull::dangling(),
470            // 空のポインタを代入しておく メモリ確保を遅延させる
471            ind_ptr: NonNull::dangling(),
472            cap: cap,
473            _marker: PhantomData,
474        }
475    }
476
477    #[inline]
478    fn grow(&mut self) {
479        unsafe {
480            let val_elem_size = mem::size_of::<N>();
481            let ind_elem_size = mem::size_of::<usize>();
482
483            // 安全性: ZSTの場合growはcapを超えた場合にしか呼ばれない
484            // これは必然的にオーバーフローしていることをしめしている
485            debug_assert!(val_elem_size != 0, "capacity overflow");
486
487            // アライメントの取得 適切なメモリ確保を行うため
488            let t_align = mem::align_of::<N>();
489            let usize_align = mem::align_of::<usize>();
490
491            // アロケーション
492            let (new_cap, val_ptr, ind_ptr): (usize, *mut N, *mut usize) = 
493                if self.cap == 0 {
494                    let new_val_layout = Layout::from_size_align(val_elem_size, t_align).expect("Failed to create memory layout");
495                    let new_ind_layout = Layout::from_size_align(ind_elem_size, usize_align).expect("Failed to create memory layout");
496                    (
497                        1,
498                        alloc(new_val_layout) as *mut N,
499                        alloc(new_ind_layout) as *mut usize,
500                    )
501                } else {
502                    // 効率化: cap * 2 でメモリを確保する 見た目上はO(log n)の増加を実現
503                    let new_cap = self.cap * 2;
504                    let new_val_layout = Layout::from_size_align(val_elem_size * self.cap, t_align).expect("Failed to create memory layout for reallocation");
505                    let new_ind_layout = Layout::from_size_align(ind_elem_size * self.cap, usize_align).expect("Failed to create memory layout for reallocation");
506                    (
507                        new_cap,
508                        realloc(self.val_ptr.as_ptr() as *mut u8, new_val_layout, val_elem_size * new_cap) as *mut N,
509                        realloc(self.ind_ptr.as_ptr() as *mut u8, new_ind_layout, ind_elem_size * new_cap) as *mut usize,
510                    )
511                };
512
513            // アロケーション失敗時の処理
514            if val_ptr.is_null() || ind_ptr.is_null() {
515                oom();
516            }
517
518            // selfに返却
519            self.val_ptr = NonNull::new_unchecked(val_ptr);
520            self.ind_ptr = NonNull::new_unchecked(ind_ptr);
521            self.cap = new_cap;
522        }
523    }
524    
525    #[inline]
526    fn cap_set(&mut self) {
527        unsafe {
528            let val_elem_size = mem::size_of::<N>();
529            let ind_elem_size = mem::size_of::<usize>();
530
531            let t_align = mem::align_of::<N>();
532            let usize_align = mem::align_of::<usize>();
533
534            let new_val_layout = Layout::from_size_align(val_elem_size * self.cap, t_align).expect("Failed to create memory layout");
535            let new_ind_layout = Layout::from_size_align(ind_elem_size * self.cap, usize_align).expect("Failed to create memory layout");
536            let new_val_ptr = alloc(new_val_layout) as *mut N;
537            let new_ind_ptr = alloc(new_ind_layout) as *mut usize;
538            if new_val_ptr.is_null() || new_ind_ptr.is_null() {
539                oom();
540            }
541            self.val_ptr = NonNull::new_unchecked(new_val_ptr);
542            self.ind_ptr = NonNull::new_unchecked(new_ind_ptr);
543        }
544    }
545
546    #[inline]
547    fn re_cap_set(&mut self) {
548        unsafe {
549            let val_elem_size = mem::size_of::<N>();
550            let ind_elem_size = mem::size_of::<usize>();
551
552            let t_align = mem::align_of::<N>();
553            let usize_align = mem::align_of::<usize>();
554
555            let new_val_layout = Layout::from_size_align(val_elem_size * self.cap, t_align).expect("Failed to create memory layout");
556            let new_ind_layout = Layout::from_size_align(ind_elem_size * self.cap, usize_align).expect("Failed to create memory layout");
557            let new_val_ptr = realloc(self.val_ptr.as_ptr() as *mut u8, new_val_layout, val_elem_size * self.cap) as *mut N;
558            let new_ind_ptr = realloc(self.ind_ptr.as_ptr() as *mut u8, new_ind_layout, ind_elem_size * self.cap) as *mut usize;
559            if new_val_ptr.is_null() || new_ind_ptr.is_null() {
560                oom();
561            }
562            self.val_ptr = NonNull::new_unchecked(new_val_ptr);
563            self.ind_ptr = NonNull::new_unchecked(new_ind_ptr);
564        }
565    }
566}
567
568impl<N> Clone for RawZeroSpVec<N> 
569where N: Num
570{
571    #[inline]
572    fn clone(&self) -> Self {
573        unsafe {
574            let val_elem_size = mem::size_of::<N>();
575            let ind_elem_size = mem::size_of::<usize>();
576
577            let t_align = mem::align_of::<N>();
578            let usize_align = mem::align_of::<usize>();
579
580            let new_val_layout = Layout::from_size_align(val_elem_size * self.cap, t_align).expect("Failed to create memory layout");
581            let new_ind_layout = Layout::from_size_align(ind_elem_size * self.cap, usize_align).expect("Failed to create memory layout");
582            let new_val_ptr = alloc(new_val_layout) as *mut N;
583            let new_ind_ptr = alloc(new_ind_layout) as *mut usize;
584            if new_val_ptr.is_null() || new_ind_ptr.is_null() {
585                oom();
586            }
587            ptr::copy_nonoverlapping(self.val_ptr.as_ptr(), new_val_ptr, self.cap);
588            ptr::copy_nonoverlapping(self.ind_ptr.as_ptr(), new_ind_ptr, self.cap);
589            
590            RawZeroSpVec {
591                val_ptr: NonNull::new_unchecked(new_val_ptr),
592                ind_ptr: NonNull::new_unchecked(new_ind_ptr),
593                cap: self.cap,
594                _marker: PhantomData,
595            }
596        }
597    }
598}
599
600unsafe impl<N: Num + Send> Send for RawZeroSpVec<N> {}
601unsafe impl<N: Num + Sync> Sync for RawZeroSpVec<N> {}
602
603impl<N> Drop for RawZeroSpVec<N> 
604where N: Num
605{
606    #[inline]
607    fn drop(&mut self) {
608        unsafe {
609            let val_elem_size = mem::size_of::<N>();
610            let ind_elem_size = mem::size_of::<usize>();
611
612            let t_align = mem::align_of::<N>();
613            let usize_align = mem::align_of::<usize>();
614
615            let new_val_layout = Layout::from_size_align(val_elem_size * self.cap, t_align).expect("Failed to create memory layout");
616            let new_ind_layout = Layout::from_size_align(ind_elem_size * self.cap, usize_align).expect("Failed to create memory layout");
617            dealloc(self.val_ptr.as_ptr() as *mut u8, new_val_layout);
618            dealloc(self.ind_ptr.as_ptr() as *mut u8, new_ind_layout);
619        }
620    }
621}
622
623/// OutOfMemoryへの対処用
624/// プロセスを終了させる
625/// 本来はpanic!を使用するべきだが、
626/// OOMの場合panic!を発生させるとTraceBackによるメモリ仕様が起きてしまうため
627/// 仕方なく強制終了させる
628/// 本来OOMはOSにより管理され発生前にKillされるはずなのであんまり意味はない。
629#[cold]
630fn oom() {
631    ::std::process::exit(-9999);
632}