rust2vec/
storage.rs

1//! Embedding matrix representations.
2
3use std::fs::File;
4use std::io::{BufReader, Read, Seek, SeekFrom, Write};
5use std::mem::size_of;
6
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8use failure::{ensure, format_err, Error};
9use memmap::{Mmap, MmapOptions};
10use ndarray::{Array, Array1, Array2, ArrayView, ArrayView2, Dimension, Ix1, Ix2};
11use rand::{FromEntropy, Rng};
12use rand_xorshift::XorShiftRng;
13use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
14
15use crate::io::private::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
16
/// Copy-on-write wrapper for `Array`/`ArrayView`.
///
/// The `CowArray` type stores an owned array or an array view. In
/// both cases a view (`as_view`) or an owned array (`into_owned`) can
/// be obtained. If the wrapped array is a view, retrieving an owned
/// array will copy the underlying data.
///
/// Storage types in this module return `Borrowed` when they can hand
/// out a row of an in-memory matrix directly (`NdArray`) and `Owned`
/// when the embedding has to be copied or reconstructed (`MmapArray`,
/// `QuantizedArray`).
pub enum CowArray<'a, A, D> {
    /// Array view borrowing data for the lifetime `'a`.
    Borrowed(ArrayView<'a, A, D>),
    /// Array that owns its data.
    Owned(Array<A, D>),
}
27
28impl<'a, A, D> CowArray<'a, A, D>
29where
30    D: Dimension,
31{
32    pub fn as_view(&self) -> ArrayView<A, D> {
33        match self {
34            CowArray::Borrowed(borrow) => borrow.view(),
35            CowArray::Owned(owned) => owned.view(),
36        }
37    }
38}
39
40impl<'a, A, D> CowArray<'a, A, D>
41where
42    A: Clone,
43    D: Dimension,
44{
45    pub fn into_owned(self) -> Array<A, D> {
46        match self {
47            CowArray::Borrowed(borrow) => borrow.to_owned(),
48            CowArray::Owned(owned) => owned,
49        }
50    }
51}
52
/// 1D copy-on-write array.
///
/// This is the type returned by `Storage::embedding`.
pub type CowArray1<'a, A> = CowArray<'a, A, Ix1>;
55
/// Memory-mapped matrix.
///
/// The matrix is not read into memory. Instead, the region of the file
/// that holds the matrix data of an `NdArray` chunk is memory mapped
/// and reinterpreted as an `f32` matrix on access.
pub struct MmapArray {
    // Mapping of the matrix bytes only (the chunk header is not mapped).
    map: Mmap,
    // Matrix shape as (rows, columns).
    shape: Ix2,
}
61
62impl MmapChunk for MmapArray {
63    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error> {
64        ensure!(
65            read.read_u32::<LittleEndian>()? == ChunkIdentifier::NdArray as u32,
66            "invalid chunk identifier for NdArray"
67        );
68
69        // Read and discard chunk length.
70        read.read_u64::<LittleEndian>()?;
71
72        let rows = read.read_u64::<LittleEndian>()? as usize;
73        let cols = read.read_u32::<LittleEndian>()? as usize;
74        let shape = Ix2(rows, cols);
75
76        ensure!(
77            read.read_u32::<LittleEndian>()? == f32::type_id(),
78            "Expected single precision floating point matrix for NdArray."
79        );
80
81        let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0))?);
82        read.seek(SeekFrom::Current(n_padding as i64))?;
83
84        // Set up memory mapping.
85        let matrix_len = shape.size() * size_of::<f32>();
86        let offset = read.seek(SeekFrom::Current(0))?;
87        let mut mmap_opts = MmapOptions::new();
88        let map = unsafe {
89            mmap_opts
90                .offset(offset)
91                .len(matrix_len)
92                .map(&read.get_ref())?
93        };
94
95        // Position the reader after the matrix.
96        read.seek(SeekFrom::Current(matrix_len as i64))?;
97
98        Ok(MmapArray { map, shape })
99    }
100}
101
102impl WriteChunk for MmapArray {
103    fn chunk_identifier(&self) -> ChunkIdentifier {
104        ChunkIdentifier::NdArray
105    }
106
107    fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
108    where
109        W: Write + Seek,
110    {
111        NdArray::write_ndarray_chunk(self.view(), write)
112    }
113}
114
/// In-memory `ndarray` matrix.
///
/// Each row of the wrapped matrix is one embedding.
#[derive(Debug)]
pub struct NdArray(pub Array2<f32>);
118
119impl NdArray {
120    fn write_ndarray_chunk<W>(data: ArrayView2<f32>, write: &mut W) -> Result<(), Error>
121    where
122        W: Write + Seek,
123    {
124        write.write_u32::<LittleEndian>(ChunkIdentifier::NdArray as u32)?;
125        let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0))?);
126        // Chunk size: rows (u64), columns (u32), type id (u32),
127        //             padding ([0,4) bytes), matrix.
128        let chunk_len = size_of::<u64>()
129            + size_of::<u32>()
130            + size_of::<u32>()
131            + n_padding as usize
132            + (data.rows() * data.cols() * size_of::<f32>());
133        write.write_u64::<LittleEndian>(chunk_len as u64)?;
134        write.write_u64::<LittleEndian>(data.rows() as u64)?;
135        write.write_u32::<LittleEndian>(data.cols() as u32)?;
136        write.write_u32::<LittleEndian>(f32::type_id())?;
137
138        // Write padding, such that the embedding matrix starts on at
139        // a multiple of the size of f32 (4 bytes). This is necessary
140        // for memory mapping a matrix. Interpreting the raw u8 data
141        // as a proper f32 array requires that the data is aligned in
142        // memory. However, we cannot always memory map the starting
143        // offset of the matrix directly, since mmap(2) requires a
144        // file offset that is page-aligned. Since the page size is
145        // always a larger power of 2 (e.g. 2^12), which is divisible
146        // by 4, the offset of the matrix with regards to the page
147        // boundary is also a multiple of 4.
148
149        let padding = vec![0; n_padding as usize];
150        write.write_all(&padding)?;
151
152        for row in data.outer_iter() {
153            for col in row.iter() {
154                write.write_f32::<LittleEndian>(*col)?;
155            }
156        }
157
158        Ok(())
159    }
160}
161
162impl ReadChunk for NdArray {
163    fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
164    where
165        R: Read + Seek,
166    {
167        let chunk_id = read.read_u32::<LittleEndian>()?;
168        let chunk_id = ChunkIdentifier::try_from(chunk_id)
169            .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
170        ensure!(
171            chunk_id == ChunkIdentifier::NdArray,
172            "Cannot read chunk {:?} as NdArray",
173            chunk_id
174        );
175
176        // Read and discard chunk length.
177        read.read_u64::<LittleEndian>()?;
178
179        let rows = read.read_u64::<LittleEndian>()? as usize;
180        let cols = read.read_u32::<LittleEndian>()? as usize;
181
182        ensure!(
183            read.read_u32::<LittleEndian>()? == f32::type_id(),
184            "Expected single precision floating point matrix for NdArray."
185        );
186
187        let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0))?);
188        read.seek(SeekFrom::Current(n_padding as i64))?;
189
190        let mut data = vec![0f32; rows * cols];
191        read.read_f32_into::<LittleEndian>(&mut data)?;
192
193        Ok(NdArray(Array2::from_shape_vec((rows, cols), data)?))
194    }
195}
196
197impl WriteChunk for NdArray {
198    fn chunk_identifier(&self) -> ChunkIdentifier {
199        ChunkIdentifier::NdArray
200    }
201
202    fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
203    where
204        W: Write + Seek,
205    {
206        Self::write_ndarray_chunk(self.0.view(), write)
207    }
208}
209
/// Quantized embedding matrix.
///
/// Embeddings are stored as product-quantizer codes and reconstructed
/// on access (see `Storage::embedding`).
pub struct QuantizedArray {
    // Product quantizer used to encode and reconstruct embeddings.
    quantizer: PQ<f32>,
    // One row of codes per embedding, one u8 code per subquantizer.
    quantized: Array2<u8>,
    // Optional per-embedding norms; when present, reconstructed vectors
    // are multiplied by the norm of the embedding.
    norms: Option<Array1<f32>>,
}
216
217impl ReadChunk for QuantizedArray {
218    fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
219    where
220        R: Read + Seek,
221    {
222        let chunk_id = read.read_u32::<LittleEndian>()?;
223        let chunk_id = ChunkIdentifier::try_from(chunk_id)
224            .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
225        ensure!(
226            chunk_id == ChunkIdentifier::QuantizedArray,
227            "Cannot read chunk {:?} as QuantizedArray",
228            chunk_id
229        );
230
231        // Read and discard chunk length.
232        read.read_u64::<LittleEndian>()?;
233
234        let projection = read.read_u32::<LittleEndian>()? != 0;
235        let read_norms = read.read_u32::<LittleEndian>()? != 0;
236        let quantized_len = read.read_u32::<LittleEndian>()? as usize;
237        let reconstructed_len = read.read_u32::<LittleEndian>()? as usize;
238        let n_centroids = read.read_u32::<LittleEndian>()? as usize;
239        let n_embeddings = read.read_u64::<LittleEndian>()? as usize;
240
241        ensure!(
242            read.read_u32::<LittleEndian>()? == u8::type_id(),
243            "Expected unsigned byte quantized embedding matrices."
244        );
245
246        ensure!(
247            read.read_u32::<LittleEndian>()? == f32::type_id(),
248            "Expected single precision floating point matrix quantizer matrices."
249        );
250
251        let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0))?);
252        read.seek(SeekFrom::Current(n_padding as i64))?;
253
254        let projection = if projection {
255            let mut projection_vec = vec![0f32; reconstructed_len * reconstructed_len];
256            read.read_f32_into::<LittleEndian>(&mut projection_vec)?;
257            Some(Array2::from_shape_vec(
258                (reconstructed_len, reconstructed_len),
259                projection_vec,
260            )?)
261        } else {
262            None
263        };
264
265        let mut quantizers = Vec::with_capacity(quantized_len);
266        for _ in 0..quantized_len {
267            let mut subquantizer_vec =
268                vec![0f32; n_centroids * (reconstructed_len / quantized_len)];
269            read.read_f32_into::<LittleEndian>(&mut subquantizer_vec)?;
270            let subquantizer = Array2::from_shape_vec(
271                (n_centroids, reconstructed_len / quantized_len),
272                subquantizer_vec,
273            )?;
274            quantizers.push(subquantizer);
275        }
276
277        let norms = if read_norms {
278            let mut norms_vec = vec![0f32; n_embeddings];
279            read.read_f32_into::<LittleEndian>(&mut norms_vec)?;
280            Some(Array1::from_vec(norms_vec))
281        } else {
282            None
283        };
284
285        let mut quantized_embeddings_vec = vec![0u8; n_embeddings * quantized_len];
286        read.read_exact(&mut quantized_embeddings_vec)?;
287        let quantized =
288            Array2::from_shape_vec((n_embeddings, quantized_len), quantized_embeddings_vec)?;
289
290        Ok(QuantizedArray {
291            quantizer: PQ::new(projection, quantizers),
292            quantized,
293            norms,
294        })
295    }
296}
297
298impl WriteChunk for QuantizedArray {
299    fn chunk_identifier(&self) -> ChunkIdentifier {
300        ChunkIdentifier::QuantizedArray
301    }
302
303    fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
304    where
305        W: Write + Seek,
306    {
307        write.write_u32::<LittleEndian>(ChunkIdentifier::QuantizedArray as u32)?;
308
309        // projection (u32), use_norms (u32), quantized_len (u32),
310        // reconstructed_len (u32), n_centroids (u32), rows (u64),
311        // types (2 x u32 bytes), padding, projection matrix,
312        // centroids, norms, quantized data.
313        let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0))?);
314        let chunk_size = size_of::<u32>()
315            + size_of::<u32>()
316            + size_of::<u32>()
317            + size_of::<u32>()
318            + size_of::<u32>()
319            + size_of::<u64>()
320            + 2 * size_of::<u32>()
321            + n_padding as usize
322            + self.quantizer.projection().is_some() as usize
323                * self.quantizer.reconstructed_len()
324                * self.quantizer.reconstructed_len()
325                * size_of::<f32>()
326            + self.quantizer.quantized_len()
327                * self.quantizer.n_quantizer_centroids()
328                * (self.quantizer.reconstructed_len() / self.quantizer.quantized_len())
329                * size_of::<f32>()
330            + self.norms.is_some() as usize * self.quantized.rows() * size_of::<f32>()
331            + self.quantized.rows() * self.quantizer.quantized_len();
332
333        write.write_u64::<LittleEndian>(chunk_size as u64)?;
334
335        write.write_u32::<LittleEndian>(self.quantizer.projection().is_some() as u32)?;
336        write.write_u32::<LittleEndian>(self.norms.is_some() as u32)?;
337        write.write_u32::<LittleEndian>(self.quantizer.quantized_len() as u32)?;
338        write.write_u32::<LittleEndian>(self.quantizer.reconstructed_len() as u32)?;
339        write.write_u32::<LittleEndian>(self.quantizer.n_quantizer_centroids() as u32)?;
340        write.write_u64::<LittleEndian>(self.quantized.rows() as u64)?;
341
342        // Quantized and reconstruction types.
343        write.write_u32::<LittleEndian>(u8::type_id())?;
344        write.write_u32::<LittleEndian>(f32::type_id())?;
345
346        let padding = vec![0u8; n_padding as usize];
347        write.write_all(&padding)?;
348
349        // Write projection matrix.
350        if let Some(projection) = self.quantizer.projection() {
351            for row in projection.outer_iter() {
352                for &col in row {
353                    write.write_f32::<LittleEndian>(col)?;
354                }
355            }
356        }
357
358        // Write subquantizers.
359        for subquantizer in self.quantizer.subquantizers() {
360            for row in subquantizer.outer_iter() {
361                for &col in row {
362                    write.write_f32::<LittleEndian>(col)?;
363                }
364            }
365        }
366
367        // Write norms.
368        if let Some(ref norms) = self.norms {
369            for row in norms.outer_iter() {
370                for &col in row {
371                    write.write_f32::<LittleEndian>(col)?;
372                }
373            }
374        }
375
376        // Write quantized embedding matrix.
377        for row in self.quantized.outer_iter() {
378            for &col in row {
379                write.write_u8(col)?;
380            }
381        }
382
383        Ok(())
384    }
385}
386
/// Storage types wrapper.
///
/// This crate makes it possible to create fine-grained embedding
/// types, such as `Embeddings<SimpleVocab, NdArray>` or
/// `Embeddings<SubwordVocab, QuantizedArray>`. However, in some cases
/// it is more pleasant to have a single type that covers all
/// vocabulary and storage types. `VocabWrap` and `StorageWrap` wrap
/// all the vocabularies and storage types known to this crate such
/// that the type `Embeddings<VocabWrap, StorageWrap>` covers all
/// variations.
pub enum StorageWrap {
    /// In-memory matrix (produced by `ReadChunk`).
    NdArray(NdArray),
    /// Quantized matrix (produced by `ReadChunk`).
    QuantizedArray(QuantizedArray),
    /// Memory-mapped matrix (produced by `MmapChunk`).
    MmapArray(MmapArray),
}
402
403impl From<MmapArray> for StorageWrap {
404    fn from(s: MmapArray) -> Self {
405        StorageWrap::MmapArray(s)
406    }
407}
408
409impl From<NdArray> for StorageWrap {
410    fn from(s: NdArray) -> Self {
411        StorageWrap::NdArray(s)
412    }
413}
414
415impl From<QuantizedArray> for StorageWrap {
416    fn from(s: QuantizedArray) -> Self {
417        StorageWrap::QuantizedArray(s)
418    }
419}
420
421impl ReadChunk for StorageWrap {
422    fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
423    where
424        R: Read + Seek,
425    {
426        let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
427
428        let chunk_id = read.read_u32::<LittleEndian>()?;
429        let chunk_id = ChunkIdentifier::try_from(chunk_id)
430            .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
431
432        read.seek(SeekFrom::Start(chunk_start_pos))?;
433
434        match chunk_id {
435            ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StorageWrap::NdArray),
436            ChunkIdentifier::QuantizedArray => {
437                QuantizedArray::read_chunk(read).map(StorageWrap::QuantizedArray)
438            }
439            _ => Err(format_err!(
440                "Chunk type {:?} cannot be read as storage",
441                chunk_id
442            )),
443        }
444    }
445}
446
447impl MmapChunk for StorageWrap {
448    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error> {
449        let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
450
451        let chunk_id = read.read_u32::<LittleEndian>()?;
452        let chunk_id = ChunkIdentifier::try_from(chunk_id)
453            .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
454
455        read.seek(SeekFrom::Start(chunk_start_pos))?;
456
457        match chunk_id {
458            ChunkIdentifier::NdArray => MmapArray::mmap_chunk(read).map(StorageWrap::MmapArray),
459            _ => Err(format_err!(
460                "Chunk type {:?} cannot be memory mapped as viewable storage",
461                chunk_id
462            )),
463        }
464    }
465}
466
467impl WriteChunk for StorageWrap {
468    fn chunk_identifier(&self) -> ChunkIdentifier {
469        match self {
470            StorageWrap::MmapArray(inner) => inner.chunk_identifier(),
471            StorageWrap::NdArray(inner) => inner.chunk_identifier(),
472            StorageWrap::QuantizedArray(inner) => inner.chunk_identifier(),
473        }
474    }
475
476    fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
477    where
478        W: Write + Seek,
479    {
480        match self {
481            StorageWrap::MmapArray(inner) => inner.write_chunk(write),
482            StorageWrap::NdArray(inner) => inner.write_chunk(write),
483            StorageWrap::QuantizedArray(inner) => inner.write_chunk(write),
484        }
485    }
486}
487
/// Wrapper for storage types that implement views.
///
/// This type covers the subset of storage types that implement
/// `StorageView`. See the `StorageWrap` type for more information.
pub enum StorageViewWrap {
    /// Memory-mapped matrix (produced by `MmapChunk`).
    MmapArray(MmapArray),
    /// In-memory matrix (produced by `ReadChunk`).
    NdArray(NdArray),
}
496
497impl From<MmapArray> for StorageViewWrap {
498    fn from(s: MmapArray) -> Self {
499        StorageViewWrap::MmapArray(s)
500    }
501}
502
503impl From<NdArray> for StorageViewWrap {
504    fn from(s: NdArray) -> Self {
505        StorageViewWrap::NdArray(s)
506    }
507}
508
509impl ReadChunk for StorageViewWrap {
510    fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
511    where
512        R: Read + Seek,
513    {
514        let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
515
516        let chunk_id = read.read_u32::<LittleEndian>()?;
517        let chunk_id = ChunkIdentifier::try_from(chunk_id)
518            .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
519
520        read.seek(SeekFrom::Start(chunk_start_pos))?;
521
522        match chunk_id {
523            ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StorageViewWrap::NdArray),
524            _ => Err(format_err!(
525                "Chunk type {:?} cannot be read as viewable storage",
526                chunk_id
527            )),
528        }
529    }
530}
531
532impl WriteChunk for StorageViewWrap {
533    fn chunk_identifier(&self) -> ChunkIdentifier {
534        match self {
535            StorageViewWrap::MmapArray(inner) => inner.chunk_identifier(),
536            StorageViewWrap::NdArray(inner) => inner.chunk_identifier(),
537        }
538    }
539
540    fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
541    where
542        W: Write + Seek,
543    {
544        match self {
545            StorageViewWrap::MmapArray(inner) => inner.write_chunk(write),
546            StorageViewWrap::NdArray(inner) => inner.write_chunk(write),
547        }
548    }
549}
550
551impl MmapChunk for StorageViewWrap {
552    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error> {
553        let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
554
555        let chunk_id = read.read_u32::<LittleEndian>()?;
556        let chunk_id = ChunkIdentifier::try_from(chunk_id)
557            .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
558
559        read.seek(SeekFrom::Start(chunk_start_pos))?;
560
561        match chunk_id {
562            ChunkIdentifier::NdArray => MmapArray::mmap_chunk(read).map(StorageViewWrap::MmapArray),
563            _ => Err(format_err!(
564                "Chunk type {:?} cannot be memory mapped as viewable storage",
565                chunk_id
566            )),
567        }
568    }
569}
570
/// Embedding matrix storage.
///
/// To allow for embeddings to be stored in different manners (e.g.
/// regular *n x d* matrix or as quantized vectors), this trait
/// abstracts over concrete storage types.
pub trait Storage {
    /// Get the embedding in row `idx`.
    ///
    /// Depending on the storage type, the embedding is either borrowed
    /// from the storage or copied/reconstructed into an owned array. The
    /// implementations in this module index rows with ndarray's `row`,
    /// which panics when `idx` is out of bounds.
    fn embedding(&self, idx: usize) -> CowArray1<f32>;

    /// Get the shape of the embedding matrix as (number of embeddings,
    /// embedding dimensionality).
    fn shape(&self) -> (usize, usize);
}
581
582impl Storage for MmapArray {
583    fn embedding(&self, idx: usize) -> CowArray1<f32> {
584        CowArray::Owned(
585            // Alignment is ok, padding guarantees that the pointer is at
586            // a multiple of 4.
587            #[allow(clippy::cast_ptr_alignment)]
588            unsafe { ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32) }
589                .row(idx)
590                .to_owned(),
591        )
592    }
593
594    fn shape(&self) -> (usize, usize) {
595        self.shape.into_pattern()
596    }
597}
598
599impl Storage for NdArray {
600    fn embedding(&self, idx: usize) -> CowArray1<f32> {
601        CowArray::Borrowed(self.0.row(idx))
602    }
603
604    fn shape(&self) -> (usize, usize) {
605        self.0.dim()
606    }
607}
608
609impl Storage for QuantizedArray {
610    fn embedding(&self, idx: usize) -> CowArray1<f32> {
611        let mut reconstructed = self.quantizer.reconstruct_vector(self.quantized.row(idx));
612        if let Some(ref norms) = self.norms {
613            reconstructed *= norms[idx];
614        }
615
616        CowArray::Owned(reconstructed)
617    }
618
619    fn shape(&self) -> (usize, usize) {
620        (self.quantized.rows(), self.quantizer.reconstructed_len())
621    }
622}
623
624impl Storage for StorageWrap {
625    fn embedding(&self, idx: usize) -> CowArray1<f32> {
626        match self {
627            StorageWrap::MmapArray(inner) => inner.embedding(idx),
628            StorageWrap::NdArray(inner) => inner.embedding(idx),
629            StorageWrap::QuantizedArray(inner) => inner.embedding(idx),
630        }
631    }
632
633    fn shape(&self) -> (usize, usize) {
634        match self {
635            StorageWrap::MmapArray(inner) => inner.shape(),
636            StorageWrap::NdArray(inner) => inner.shape(),
637            StorageWrap::QuantizedArray(inner) => inner.shape(),
638        }
639    }
640}
641
642impl Storage for StorageViewWrap {
643    fn embedding(&self, idx: usize) -> CowArray1<f32> {
644        match self {
645            StorageViewWrap::MmapArray(inner) => inner.embedding(idx),
646            StorageViewWrap::NdArray(inner) => inner.embedding(idx),
647        }
648    }
649
650    fn shape(&self) -> (usize, usize) {
651        match self {
652            StorageViewWrap::MmapArray(inner) => inner.shape(),
653            StorageViewWrap::NdArray(inner) => inner.shape(),
654        }
655    }
656}
657
/// Storage that provides a view of the embedding matrix.
///
/// Implemented by storage types that keep the full matrix as `f32`
/// values, so that it can be viewed without reconstruction.
pub trait StorageView: Storage {
    /// Get a view of the embedding matrix.
    fn view(&self) -> ArrayView2<f32>;
}
663
664impl StorageView for NdArray {
665    fn view(&self) -> ArrayView2<f32> {
666        self.0.view()
667    }
668}
669
670impl StorageView for MmapArray {
671    fn view(&self) -> ArrayView2<f32> {
672        // Alignment is ok, padding guarantees that the pointer is at
673        // a multiple of 4.
674        #[allow(clippy::cast_ptr_alignment)]
675        unsafe {
676            ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32)
677        }
678    }
679}
680
681impl StorageView for StorageViewWrap {
682    fn view(&self) -> ArrayView2<f32> {
683        match self {
684            StorageViewWrap::MmapArray(inner) => inner.view(),
685            StorageViewWrap::NdArray(inner) => inner.view(),
686        }
687    }
688}
689
690/// Quantizable embedding matrix.
691pub trait Quantize {
692    /// Quantize the embedding matrix.
693    ///
694    /// This method trains a quantizer for the embedding matrix and
695    /// then quantizes the matrix using this quantizer.
696    ///
697    /// The xorshift PRNG is used for picking the initial quantizer
698    /// centroids.
699    fn quantize<T>(
700        &self,
701        n_subquantizers: usize,
702        n_subquantizer_bits: u32,
703        n_iterations: usize,
704        n_attempts: usize,
705        normalize: bool,
706    ) -> QuantizedArray
707    where
708        T: TrainPQ<f32>,
709    {
710        self.quantize_using::<T, _>(
711            n_subquantizers,
712            n_subquantizer_bits,
713            n_iterations,
714            n_attempts,
715            normalize,
716            &mut XorShiftRng::from_entropy(),
717        )
718    }
719
720    /// Quantize the embedding matrix using the provided RNG.
721    ///
722    /// This method trains a quantizer for the embedding matrix and
723    /// then quantizes the matrix using this quantizer.
724    fn quantize_using<T, R>(
725        &self,
726        n_subquantizers: usize,
727        n_subquantizer_bits: u32,
728        n_iterations: usize,
729        n_attempts: usize,
730        normalize: bool,
731        rng: &mut R,
732    ) -> QuantizedArray
733    where
734        T: TrainPQ<f32>,
735        R: Rng;
736}
737
738impl<S> Quantize for S
739where
740    S: StorageView,
741{
742    /// Quantize the embedding matrix.
743    ///
744    /// This method trains a quantizer for the embedding matrix and
745    /// then quantizes the matrix using this quantizer.
746    fn quantize_using<T, R>(
747        &self,
748        n_subquantizers: usize,
749        n_subquantizer_bits: u32,
750        n_iterations: usize,
751        n_attempts: usize,
752        normalize: bool,
753        rng: &mut R,
754    ) -> QuantizedArray
755    where
756        T: TrainPQ<f32>,
757        R: Rng,
758    {
759        let (embeds, norms) = if normalize {
760            let norms = self.view().outer_iter().map(|e| e.dot(&e).sqrt()).collect();
761            let mut normalized = self.view().to_owned();
762            for (mut embedding, &norm) in normalized.outer_iter_mut().zip(&norms) {
763                embedding /= norm;
764            }
765            (CowArray::Owned(normalized), Some(norms))
766        } else {
767            (CowArray::Borrowed(self.view()), None)
768        };
769
770        let quantizer = T::train_pq_using(
771            n_subquantizers,
772            n_subquantizer_bits,
773            n_iterations,
774            n_attempts,
775            embeds.as_view(),
776            rng,
777        );
778
779        let quantized = quantizer.quantize_batch(embeds.as_view());
780
781        QuantizedArray {
782            quantizer,
783            quantized,
784            norms,
785        }
786    }
787}
788
/// Number of padding bytes to write (or skip) at stream position `pos`
/// so that the data that follows starts at a multiple of `size_of::<T>()`.
///
/// The result is always in `1..=size_of::<T>()`: an already aligned
/// position gets a full `size_of::<T>()` bytes of padding rather than
/// none. Chunk writers and readers in this module both derive their
/// padding from this function, so this behavior is part of the on-disk
/// format. Returning 0 for aligned positions would misread existing files.
fn padding<T>(pos: u64) -> u64 {
    let alignment = size_of::<T>() as u64;
    let misalignment = pos % alignment;
    alignment - misalignment
}
793
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use byteorder::{LittleEndian, ReadBytesExt};
    use ndarray::Array2;
    use reductive::pq::PQ;

    use crate::io::private::{ReadChunk, WriteChunk};
    use crate::storage::{padding, NdArray, Quantize, QuantizedArray, StorageView};

    const N_ROWS: usize = 100;
    const N_COLS: usize = 100;

    fn test_ndarray() -> NdArray {
        let test_data = Array2::from_shape_fn((N_ROWS, N_COLS), |(r, c)| {
            r as f32 * N_COLS as f32 + c as f32
        });

        NdArray(test_data)
    }

    fn test_quantized_array(norms: bool) -> QuantizedArray {
        let ndarray = test_ndarray();
        ndarray.quantize::<PQ<f32>>(10, 4, 5, 1, norms)
    }

    fn read_chunk_size(read: &mut impl Read) -> u64 {
        // Skip identifier.
        read.read_u32::<LittleEndian>().unwrap();

        // Return chunk length.
        read.read_u64::<LittleEndian>().unwrap()
    }

    /// Serialize `chunk` and check that the chunk length in its header
    /// equals the number of bytes that follow the length field.
    fn assert_chunk_size_matches(chunk: &impl WriteChunk) {
        let mut cursor = Cursor::new(Vec::new());
        chunk.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();

        let chunk_size = read_chunk_size(&mut cursor);
        assert_eq!(
            cursor.read_to_end(&mut Vec::new()).unwrap(),
            chunk_size as usize
        );
    }

    #[test]
    fn ndarray_correct_chunk_size() {
        assert_chunk_size_matches(&test_ndarray());
    }

    #[test]
    fn ndarray_write_read_roundtrip() {
        let check_arr = test_ndarray();
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let arr = NdArray::read_chunk(&mut cursor).unwrap();
        assert_eq!(arr.view(), check_arr.view());
    }

    #[test]
    fn quantized_array_correct_chunk_size() {
        assert_chunk_size_matches(&test_quantized_array(false));
    }

    #[test]
    fn quantized_array_norms_correct_chunk_size() {
        assert_chunk_size_matches(&test_quantized_array(true));
    }

    #[test]
    fn quantized_array_read_write_roundtrip() {
        let check_arr = test_quantized_array(true);
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let arr = QuantizedArray::read_chunk(&mut cursor).unwrap();
        assert_eq!(arr.quantizer, check_arr.quantizer);
        assert_eq!(arr.quantized, check_arr.quantized);
        // Norms are written as raw f32 values, so they must roundtrip exactly.
        assert_eq!(arr.norms, check_arr.norms);
    }

    #[test]
    fn padding_is_between_one_and_type_size() {
        // Aligned positions get a full `size_of::<T>()` bytes of padding;
        // this is part of the on-disk format and must not change.
        assert_eq!(padding::<f32>(0), 4);
        assert_eq!(padding::<f32>(1), 3);
        assert_eq!(padding::<f32>(3), 1);
        assert_eq!(padding::<f32>(8), 4);
    }
}