tiny_varint/
value.rs

1use crate::{encode, decode, encode_zigzag, decode_zigzag, Error};
2
/// An integer of any primitive width and signedness, tagged so that it can be
/// serialized as a one-byte type identifier followed by a varint payload.
///
/// Unsigned variants are varint-encoded directly; signed variants are ZigZag-mapped
/// first, so small negative numbers stay as short as small positive ones.
///
/// The payload is always a plain integer, so the type is `Copy`, `Eq` and `Hash`
/// and can be used as a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VarintValue {
    /// Unsigned 8-bit integer (`u8`)
    U8(u8),
    /// Unsigned 16-bit integer (`u16`)
    U16(u16),
    /// Unsigned 32-bit integer (`u32`)
    U32(u32),
    /// Unsigned 64-bit integer (`u64`)
    U64(u64),
    /// Unsigned 128-bit integer (`u128`)
    U128(u128),
    /// Signed 8-bit integer (`i8`)
    I8(i8),
    /// Signed 16-bit integer (`i16`)
    I16(i16),
    /// Signed 32-bit integer (`i32`)
    I32(i32),
    /// Signed 64-bit integer (`i64`)
    I64(i64),
    /// Signed 128-bit integer (`i128`)
    I128(i128),
}
28
// Every encoded value starts with a one-byte type identifier, formed by OR-ing
// one signedness tag (held in the high 3 bits) with one width tag (held in the
// low 5 bits).

/// Signedness tag for unsigned variants (payload is a plain varint).
const TYPE_BITS_UNSIGNED: u8 = 0x00;
/// Signedness tag for signed variants (payload is a ZigZag varint).
const TYPE_BITS_SIGNED: u8 = 0x20;

/// Width tag for 8-bit variants.
const SIZE_BITS_8: u8 = 0x00;
/// Width tag for 16-bit variants.
const SIZE_BITS_16: u8 = 0x01;
/// Width tag for 32-bit variants.
const SIZE_BITS_32: u8 = 0x02;
/// Width tag for 64-bit variants.
const SIZE_BITS_64: u8 = 0x03;
/// Width tag for 128-bit variants.
const SIZE_BITS_128: u8 = 0x04;
41
/// Dispatches over every `VarintValue` variant.
///
/// `$on_unsigned` is called for the unsigned variants and `$on_signed` for the
/// signed ones, each as `(callback)(inner, type_id)`. Here `inner` is the variant's
/// integer, copied out of the matched reference, and `type_id` is the variant's
/// type identifier byte (signedness tag | width tag).
///
/// The callback expressions are pasted into every arm. A closure literal is
/// therefore a separate closure per arm, and may accept a different integer type
/// in each one.
macro_rules! for_all_types {
    ($target:expr, $on_unsigned:expr, $on_signed:expr) => {
        match $target {
            VarintValue::U8(inner) => ($on_unsigned)(*inner, TYPE_BITS_UNSIGNED | SIZE_BITS_8),
            VarintValue::U16(inner) => ($on_unsigned)(*inner, TYPE_BITS_UNSIGNED | SIZE_BITS_16),
            VarintValue::U32(inner) => ($on_unsigned)(*inner, TYPE_BITS_UNSIGNED | SIZE_BITS_32),
            VarintValue::U64(inner) => ($on_unsigned)(*inner, TYPE_BITS_UNSIGNED | SIZE_BITS_64),
            VarintValue::U128(inner) => {
                ($on_unsigned)(*inner, TYPE_BITS_UNSIGNED | SIZE_BITS_128)
            }
            VarintValue::I8(inner) => ($on_signed)(*inner, TYPE_BITS_SIGNED | SIZE_BITS_8),
            VarintValue::I16(inner) => ($on_signed)(*inner, TYPE_BITS_SIGNED | SIZE_BITS_16),
            VarintValue::I32(inner) => ($on_signed)(*inner, TYPE_BITS_SIGNED | SIZE_BITS_32),
            VarintValue::I64(inner) => ($on_signed)(*inner, TYPE_BITS_SIGNED | SIZE_BITS_64),
            VarintValue::I128(inner) => ($on_signed)(*inner, TYPE_BITS_SIGNED | SIZE_BITS_128),
        }
    };
}
59
60impl VarintValue {
61    /// Returns the type identifier byte for this value
62    #[inline]
63    pub fn get_type_id(&self) -> u8 {
64        for_all_types!(self, 
65            |_, type_id| type_id, 
66            |_, type_id| type_id
67        )
68    }
69    
70    /// Directly calculate the number of bytes needed to encode this value
71    #[inline]
72    fn direct_size_calculation(&self) -> usize {
73        let type_byte_size = 1; // 类型标识字节
74        
75        // 对于值为0的情况优化
76        match self {
77            VarintValue::U8(0) | VarintValue::U16(0) | VarintValue::U32(0) | 
78            VarintValue::U64(0) | VarintValue::U128(0) | VarintValue::I8(0) | 
79            VarintValue::I16(0) | VarintValue::I32(0) | VarintValue::I64(0) | 
80            VarintValue::I128(0) => return type_byte_size,
81            _ => {}
82        }
83        
84        // 计算值所需的字节数
85        let value_size = match self {
86            // 无符号类型 - 计算所需比特数然后转换为字节
87            VarintValue::U8(val) => {
88                if *val == 0 { 1 } else {
89                    let bits = 8 - val.leading_zeros() as usize;
90                    (bits + 6) / 7
91                }
92            },
93            VarintValue::U16(val) => {
94                if *val == 0 { 1 } else {
95                    let bits = 16 - val.leading_zeros() as usize;
96                    (bits + 6) / 7
97                }
98            },
99            VarintValue::U32(val) => {
100                if *val == 0 { 1 } else {
101                    let bits = 32 - val.leading_zeros() as usize;
102                    (bits + 6) / 7
103                }
104            },
105            VarintValue::U64(val) => {
106                if *val == 0 { 1 } else {
107                    let bits = 64 - val.leading_zeros() as usize;
108                    (bits + 6) / 7
109                }
110            },
111            VarintValue::U128(val) => {
112                if *val == 0 { 1 } else {
113                    let bits = 128 - val.leading_zeros() as usize;
114                    (bits + 6) / 7
115                }
116            },
117            
118            // 有符号类型 - 使用ZigZag编码计算
119            VarintValue::I8(val) => {
120                let zigzag_val = ((val << 1) ^ (val >> 7)) as u8;
121                if zigzag_val == 0 { 1 } else {
122                    let bits = 8 - zigzag_val.leading_zeros() as usize;
123                    (bits + 6) / 7
124                }
125            },
126            VarintValue::I16(val) => {
127                let zigzag_val = ((val << 1) ^ (val >> 15)) as u16;
128                if zigzag_val == 0 { 1 } else {
129                    let bits = 16 - zigzag_val.leading_zeros() as usize;
130                    (bits + 6) / 7
131                }
132            },
133            VarintValue::I32(val) => {
134                let zigzag_val = ((val << 1) ^ (val >> 31)) as u32;
135                if zigzag_val == 0 { 1 } else {
136                    let bits = 32 - zigzag_val.leading_zeros() as usize;
137                    (bits + 6) / 7
138                }
139            },
140            VarintValue::I64(val) => {
141                let zigzag_val = ((val << 1) ^ (val >> 63)) as u64;
142                if zigzag_val == 0 { 1 } else {
143                    let bits = 64 - zigzag_val.leading_zeros() as usize;
144                    (bits + 6) / 7
145                }
146            },
147            VarintValue::I128(val) => {
148                let zigzag_val = ((val << 1) ^ (val >> 127)) as u128;
149                if zigzag_val == 0 { 1 } else {
150                    let bits = 128 - zigzag_val.leading_zeros() as usize;
151                    (bits + 6) / 7
152                }
153            },
154        };
155        
156        type_byte_size + value_size
157    }
158    
159    /// Returns the number of bytes needed to serialize this value
160    #[inline]
161    pub fn serialized_size(&self) -> usize {
162        self.direct_size_calculation()
163    }
164    
165    /// Serializes the value into a byte buffer.
166    /// 
167    /// The first byte contains the type identifier, followed by the encoded integer value.
168    /// Unsigned integers use standard varint encoding, while signed integers use zigzag encoding.
169    ///
170    /// # Arguments
171    /// * `buffer` - The buffer to write into
172    ///
173    /// # Returns
174    /// * `Ok(size)` - The number of bytes written
175    /// * `Err(...)` - If encoding fails or buffer is too small
176    #[inline]
177    pub fn to_bytes(&self, buffer: &mut [u8]) -> Result<usize, Error> {
178        if buffer.is_empty() {
179            return Err(Error::BufferTooSmall { 
180                needed: 1,
181                actual: 0
182            });
183        }
184        
185        // 优化:处理零值的特殊情况
186        match self {
187            VarintValue::U8(0) | VarintValue::U16(0) | VarintValue::U32(0) | 
188            VarintValue::U64(0) | VarintValue::U128(0) | VarintValue::I8(0) | 
189            VarintValue::I16(0) | VarintValue::I32(0) | VarintValue::I64(0) | 
190            VarintValue::I128(0) => {
191                // 零值的特殊情况 - 只需要一个类型字节
192                buffer[0] = self.get_type_id();
193                return Ok(1);
194            },
195            _ => { /* 继续正常编码 */ }
196        }
197        
198        // 一般情况编码
199        buffer[0] = self.get_type_id();
200        
201        // 直接编码到缓冲区,避免临时缓冲区
202        match self {
203            // 无符号类型使用标准编码
204            VarintValue::U8(val) => {
205                let result = encode(*val, &mut buffer[1..]);
206                match result {
207                    Ok(bytes_written) => Ok(bytes_written + 1), // +1 表示类型字节
208                    Err(e) => Err(e),
209                }
210            },
211            VarintValue::U16(val) => {
212                let result = encode(*val, &mut buffer[1..]);
213                match result {
214                    Ok(bytes_written) => Ok(bytes_written + 1),
215                    Err(e) => Err(e),
216                }
217            },
218            VarintValue::U32(val) => {
219                let result = encode(*val, &mut buffer[1..]);
220                match result {
221                    Ok(bytes_written) => Ok(bytes_written + 1),
222                    Err(e) => Err(e),
223                }
224            },
225            VarintValue::U64(val) => {
226                let result = encode(*val, &mut buffer[1..]);
227                match result {
228                    Ok(bytes_written) => Ok(bytes_written + 1),
229                    Err(e) => Err(e),
230                }
231            },
232            VarintValue::U128(val) => {
233                let result = encode(*val, &mut buffer[1..]);
234                match result {
235                    Ok(bytes_written) => Ok(bytes_written + 1),
236                    Err(e) => Err(e),
237                }
238            },
239            
240            // 有符号类型使用zigzag编码
241            VarintValue::I8(val) => {
242                let result = encode_zigzag(*val, &mut buffer[1..]);
243                match result {
244                    Ok(bytes_written) => Ok(bytes_written + 1),
245                    Err(e) => Err(e),
246                }
247            },
248            VarintValue::I16(val) => {
249                let result = encode_zigzag(*val, &mut buffer[1..]);
250                match result {
251                    Ok(bytes_written) => Ok(bytes_written + 1),
252                    Err(e) => Err(e),
253                }
254            },
255            VarintValue::I32(val) => {
256                let result = encode_zigzag(*val, &mut buffer[1..]);
257                match result {
258                    Ok(bytes_written) => Ok(bytes_written + 1),
259                    Err(e) => Err(e),
260                }
261            },
262            VarintValue::I64(val) => {
263                let result = encode_zigzag(*val, &mut buffer[1..]);
264                match result {
265                    Ok(bytes_written) => Ok(bytes_written + 1),
266                    Err(e) => Err(e),
267                }
268            },
269            VarintValue::I128(val) => {
270                let result = encode_zigzag(*val, &mut buffer[1..]);
271                match result {
272                    Ok(bytes_written) => Ok(bytes_written + 1),
273                    Err(e) => Err(e),
274                }
275            },
276        }
277    }
278    
279    /// Deserializes a value from a byte buffer.
280    ///
281    /// # Arguments
282    /// * `bytes` - The byte buffer to read from
283    ///
284    /// # Returns
285    /// * `Ok((value, size))` - The deserialized value and number of bytes read
286    /// * `Err(...)` - If decoding fails
287    #[inline]
288    pub fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), Error> {
289        if bytes.is_empty() {
290            return Err(Error::InputTooShort);
291        }
292        
293        let type_byte = bytes[0];
294        let type_bits = type_byte & 0b111_00000; // Get high 3 bits
295        let size_bits = type_byte & 0b000_11111; // Get low 5 bits
296        
297        let data = &bytes[1..];
298        
299        // Check if it's the special case for zero
300        if data.is_empty() && (type_bits == TYPE_BITS_UNSIGNED || type_bits == TYPE_BITS_SIGNED) {
301            // This might be a zero value in compact form
302            match (type_bits, size_bits) {
303                (TYPE_BITS_UNSIGNED, SIZE_BITS_8) => return Ok((VarintValue::U8(0), 1)),
304                (TYPE_BITS_UNSIGNED, SIZE_BITS_16) => return Ok((VarintValue::U16(0), 1)),
305                (TYPE_BITS_UNSIGNED, SIZE_BITS_32) => return Ok((VarintValue::U32(0), 1)),
306                (TYPE_BITS_UNSIGNED, SIZE_BITS_64) => return Ok((VarintValue::U64(0), 1)),
307                (TYPE_BITS_UNSIGNED, SIZE_BITS_128) => return Ok((VarintValue::U128(0), 1)),
308                (TYPE_BITS_SIGNED, SIZE_BITS_8) => return Ok((VarintValue::I8(0), 1)),
309                (TYPE_BITS_SIGNED, SIZE_BITS_16) => return Ok((VarintValue::I16(0), 1)),
310                (TYPE_BITS_SIGNED, SIZE_BITS_32) => return Ok((VarintValue::I32(0), 1)),
311                (TYPE_BITS_SIGNED, SIZE_BITS_64) => return Ok((VarintValue::I64(0), 1)),
312                (TYPE_BITS_SIGNED, SIZE_BITS_128) => return Ok((VarintValue::I128(0), 1)),
313                _ => return Err(Error::InvalidEncoding),
314            }
315        }
316        
317        // Regular decoding based on type
318        match (type_bits, size_bits) {
319            (TYPE_BITS_UNSIGNED, SIZE_BITS_8) => {
320                let (val, bytes_read) = decode::<u8>(data)?;
321                Ok((VarintValue::U8(val), bytes_read + 1))
322            },
323            (TYPE_BITS_UNSIGNED, SIZE_BITS_16) => {
324                let (val, bytes_read) = decode::<u16>(data)?;
325                Ok((VarintValue::U16(val), bytes_read + 1))
326            },
327            (TYPE_BITS_UNSIGNED, SIZE_BITS_32) => {
328                let (val, bytes_read) = decode::<u32>(data)?;
329                Ok((VarintValue::U32(val), bytes_read + 1))
330            },
331            (TYPE_BITS_UNSIGNED, SIZE_BITS_64) => {
332                let (val, bytes_read) = decode::<u64>(data)?;
333                Ok((VarintValue::U64(val), bytes_read + 1))
334            },
335            (TYPE_BITS_UNSIGNED, SIZE_BITS_128) => {
336                let (val, bytes_read) = decode::<u128>(data)?;
337                Ok((VarintValue::U128(val), bytes_read + 1))
338            },
339            (TYPE_BITS_SIGNED, SIZE_BITS_8) => {
340                let (val, bytes_read) = decode_zigzag::<i8>(data)?;
341                Ok((VarintValue::I8(val), bytes_read + 1))
342            },
343            (TYPE_BITS_SIGNED, SIZE_BITS_16) => {
344                let (val, bytes_read) = decode_zigzag::<i16>(data)?;
345                Ok((VarintValue::I16(val), bytes_read + 1))
346            },
347            (TYPE_BITS_SIGNED, SIZE_BITS_32) => {
348                let (val, bytes_read) = decode_zigzag::<i32>(data)?;
349                Ok((VarintValue::I32(val), bytes_read + 1))
350            },
351            (TYPE_BITS_SIGNED, SIZE_BITS_64) => {
352                let (val, bytes_read) = decode_zigzag::<i64>(data)?;
353                Ok((VarintValue::I64(val), bytes_read + 1))
354            },
355            (TYPE_BITS_SIGNED, SIZE_BITS_128) => {
356                let (val, bytes_read) = decode_zigzag::<i128>(data)?;
357                Ok((VarintValue::I128(val), bytes_read + 1))
358            },
359            _ => Err(Error::InvalidEncoding),
360        }
361    }
362}
363
/// Builds a `VarintValue` from a width/signedness tag and a value expression.
///
/// `varint!(u32: 7)` expands to `VarintValue::U32(7)`, `varint!(i8: -1)` expands to
/// `VarintValue::I8(-1)`, and so on. The accepted tags are `u8`, `u16`, `u32`,
/// `u64`, `u128`, `i8`, `i16`, `i32`, `i64` and `i128`. The expression must have,
/// or be inferable as, the corresponding integer type.
#[macro_export]
macro_rules! varint {
    (u8: $value:expr) => ($crate::VarintValue::U8($value));
    (u16: $value:expr) => ($crate::VarintValue::U16($value));
    (u32: $value:expr) => ($crate::VarintValue::U32($value));
    (u64: $value:expr) => ($crate::VarintValue::U64($value));
    (u128: $value:expr) => ($crate::VarintValue::U128($value));
    (i8: $value:expr) => ($crate::VarintValue::I8($value));
    (i16: $value:expr) => ($crate::VarintValue::I16($value));
    (i32: $value:expr) => ($crate::VarintValue::I32($value));
    (i64: $value:expr) => ($crate::VarintValue::I64($value));
    (i128: $value:expr) => ($crate::VarintValue::I128($value));
}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// One representative non-zero value per variant.
    const SAMPLES: [VarintValue; 10] = [
        VarintValue::U8(42),
        VarintValue::U16(1000),
        VarintValue::U32(100_000),
        VarintValue::U64(1_000_000_000),
        VarintValue::U128(u128::MAX / 2),
        VarintValue::I8(-42),
        VarintValue::I16(-1000),
        VarintValue::I32(-100_000),
        VarintValue::I64(-1_000_000_000),
        VarintValue::I128(i128::MIN / 2),
    ];

    /// The zero value of every variant.
    const ZEROS: [VarintValue; 10] = [
        VarintValue::U8(0),
        VarintValue::U16(0),
        VarintValue::U32(0),
        VarintValue::U64(0),
        VarintValue::U128(0),
        VarintValue::I8(0),
        VarintValue::I16(0),
        VarintValue::I32(0),
        VarintValue::I64(0),
        VarintValue::I128(0),
    ];

    /// Boundary values of every width.
    const EXTREMES: [VarintValue; 15] = [
        VarintValue::U8(u8::MAX),
        VarintValue::U16(u16::MAX),
        VarintValue::U32(u32::MAX),
        VarintValue::U64(u64::MAX),
        VarintValue::U128(u128::MAX),
        VarintValue::I8(i8::MIN),
        VarintValue::I8(i8::MAX),
        VarintValue::I16(i16::MIN),
        VarintValue::I16(i16::MAX),
        VarintValue::I32(i32::MIN),
        VarintValue::I32(i32::MAX),
        VarintValue::I64(i64::MIN),
        VarintValue::I64(i64::MAX),
        VarintValue::I128(i128::MIN),
        VarintValue::I128(i128::MAX),
    ];

    /// Encodes `value`, decodes it back, and returns (decoded, bytes written, bytes read).
    fn round_trip(value: &VarintValue) -> (VarintValue, usize, usize) {
        let mut buffer = [0u8; 32];
        let written = value.to_bytes(&mut buffer).unwrap();
        let (decoded, read) = VarintValue::from_bytes(&buffer[..written]).unwrap();
        (decoded, written, read)
    }

    #[test]
    fn test_type_id_encoding() {
        for value in &SAMPLES {
            let type_id = value.get_type_id();
            // A lone type byte must decode to the zero value of the variant it names.
            // Any decode failure is a test failure; there is no fallback.
            let (decoded, bytes_read) = VarintValue::from_bytes(&[type_id]).unwrap();
            assert_eq!(bytes_read, 1);
            assert_eq!(decoded.get_type_id(), type_id, "type mismatch for {:?}", value);
            assert_eq!(decoded.serialized_size(), 1, "{:?} should decode to zero", value);
        }

        // Every variant must have its own identifier.
        let mut ids: Vec<u8> = SAMPLES.iter().map(VarintValue::get_type_id).collect();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), SAMPLES.len());
    }

    #[test]
    fn test_varint_value_serialization() {
        for value in SAMPLES.iter().chain(&EXTREMES) {
            let (decoded, bytes_written, bytes_read) = round_trip(value);
            assert_eq!(*value, decoded);
            assert_eq!(bytes_written, bytes_read);
        }
    }

    #[test]
    fn test_serialized_size_matches_to_bytes() {
        for value in SAMPLES.iter().chain(&ZEROS).chain(&EXTREMES) {
            let mut buffer = [0u8; 32];
            let written = value.to_bytes(&mut buffer).unwrap();
            assert_eq!(value.serialized_size(), written, "size mismatch for {:?}", value);
        }
    }

    #[test]
    fn test_trailing_bytes_are_not_consumed() {
        // Non-zero values are self-delimiting: extra bytes after them are left alone.
        for value in &SAMPLES {
            let mut buffer = [0xAAu8; 32];
            let written = value.to_bytes(&mut buffer).unwrap();
            let (decoded, read) = VarintValue::from_bytes(&buffer[..written + 1]).unwrap();
            assert_eq!(*value, decoded);
            assert_eq!(read, written);
        }
    }

    #[test]
    fn test_zero_optimization() {
        for value in &ZEROS {
            let (decoded, bytes_written, bytes_read) = round_trip(value);
            // Zero values are encoded as the type byte alone.
            assert_eq!(bytes_written, 1, "Zero value {:?} should be encoded in 1 byte", value);
            assert_eq!(*value, decoded);
            assert_eq!(bytes_written, bytes_read);
        }
    }

    #[test]
    fn test_varint_macro() {
        assert_eq!(varint!(u8: 42), VarintValue::U8(42));
        assert_eq!(varint!(u16: 1000), VarintValue::U16(1000));
        assert_eq!(varint!(u32: 1000000), VarintValue::U32(1000000));
        assert_eq!(varint!(u64: 7), VarintValue::U64(7));
        assert_eq!(varint!(u128: 9), VarintValue::U128(9));
        assert_eq!(varint!(i8: -8), VarintValue::I8(-8));
        assert_eq!(varint!(i16: -1000), VarintValue::I16(-1000));
        assert_eq!(varint!(i32: -3), VarintValue::I32(-3));
        assert_eq!(varint!(i64: -1000000000), VarintValue::I64(-1000000000));
        assert_eq!(varint!(i128: -11), VarintValue::I128(-11));
    }

    #[test]
    fn test_serialized_size() {
        let value = VarintValue::U64(128);
        assert_eq!(value.serialized_size(), 3); // 1 byte type + 2 bytes value

        let value = VarintValue::I32(-1);
        assert_eq!(value.serialized_size(), 2); // 1 byte type + 1 byte zigzag value

        // Zero value optimization
        let value = VarintValue::U32(0);
        assert_eq!(value.serialized_size(), 1); // just the type byte
    }

    #[test]
    fn test_error_handling() {
        let value = VarintValue::U64(1000000);

        // Buffer too small, including an empty one
        let mut small_buffer = [0u8; 2];
        assert!(value.to_bytes(&mut small_buffer).is_err());
        assert!(value.to_bytes(&mut []).is_err());

        // Empty input
        let empty: [u8; 0] = [];
        assert!(VarintValue::from_bytes(&empty).is_err());

        // Unknown type/size combinations, with and without a payload
        let bad_inputs: [&[u8]; 5] = [
            &[0xFF, 0x00],
            &[0xFF],
            &[SIZE_BITS_128 + 1],
            &[SIZE_BITS_128 + 1, 0x00],
            &[0b010_00000],
        ];
        for bad in bad_inputs.iter() {
            assert!(VarintValue::from_bytes(bad).is_err(), "{:?} should be rejected", bad);
        }
    }
}
538
// Test-only helper for `Result`. Despite its name, `unwrap_err_or_else` behaves like
// `Result::unwrap_or_else`; the only difference is that the fallback receives the
// error by reference instead of by value.
#[cfg(test)]
trait ResultExt<T, E> {
    /// Returns the `Ok` value, or the fallback `f` computes from a borrow of the error.
    fn unwrap_err_or_else<F>(self, f: F) -> T
    where
        F: FnOnce(&E) -> T;
}

#[cfg(test)]
impl<T, E> ResultExt<T, E> for Result<T, E> {
    fn unwrap_err_or_else<F>(self, f: F) -> T
    where
        F: FnOnce(&E) -> T,
    {
        self.unwrap_or_else(|err| f(&err))
    }
}