Skip to main content

trit_vsa/
word.rs

1//! Word type representing 6 trits (values -364 to +364).
2//!
3//! A Word6 consists of 6 trits, capable of representing 729 distinct values (3^6).
4//! This is useful for balanced ternary arithmetic at the word level.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::ops::{Add, Mul, Neg};
9
10use crate::error::{Result, TernaryError};
11use crate::trit::Trit;
12use crate::tryte::Tryte3;
13
/// Smallest value a `Word6` can hold: every one of the 6 trits is -1, giving -364.
pub const WORD6_MIN: i32 = -WORD6_MAX;
/// Largest value a `Word6` can hold: every one of the 6 trits is +1,
/// i.e. (3^6 - 1) / 2 = +364.
pub const WORD6_MAX: i32 = (3_i32.pow(6) - 1) / 2;
18
19/// A balanced ternary word consisting of 6 trits.
20///
21/// # Value Range
22///
23/// A Word6 can represent values from -364 to +364:
24/// ```text
25/// Value = t0*1 + t1*3 + t2*9 + t3*27 + t4*81 + t5*243
26/// Min: -1 - 3 - 9 - 27 - 81 - 243 = -364
27/// Max: +1 + 3 + 9 + 27 + 81 + 243 = +364
28/// ```
29///
30/// # Internal Representation
31///
32/// Stored as a u16 with 2 bits per trit:
33/// - Bits 0-1: trit 0 (least significant)
34/// - Bits 2-3: trit 1
35/// - ...
36/// - Bits 10-11: trit 5 (most significant)
37///
38/// # Examples
39///
40/// ```
41/// use trit_vsa::Word6;
42///
43/// let w = Word6::from_value(100).unwrap();
44/// assert_eq!(w.value(), 100);
45///
46/// let w_neg = -w;
47/// assert_eq!(w_neg.value(), -100);
48/// ```
49#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50pub struct Word6(u16);
51
52impl Word6 {
53    /// Create a word from an integer value.
54    ///
55    /// # Arguments
56    ///
57    /// * `value` - Integer value (-364 to +364)
58    ///
59    /// # Errors
60    ///
61    /// Returns `TernaryError::InvalidWordValue` if value is outside range.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use trit_vsa::Word6;
67    ///
68    /// let w = Word6::from_value(123).unwrap();
69    /// assert_eq!(w.value(), 123);
70    ///
71    /// assert!(Word6::from_value(365).is_err());
72    /// assert!(Word6::from_value(-365).is_err());
73    /// ```
74    pub fn from_value(value: i32) -> Result<Self> {
75        if !(WORD6_MIN..=WORD6_MAX).contains(&value) {
76            return Err(TernaryError::InvalidWordValue(value));
77        }
78
79        let trits = Self::value_to_trits(value);
80        Ok(Self::from_trits(trits))
81    }
82
83    /// Create a word from six trits.
84    ///
85    /// # Arguments
86    ///
87    /// * `trits` - Array of 6 trits [t0, t1, t2, t3, t4, t5] where t0 is least significant
88    #[must_use]
89    pub fn from_trits(trits: [Trit; 6]) -> Self {
90        let mut encoded: u16 = 0;
91        for (i, &trit) in trits.iter().enumerate() {
92            encoded |= (Self::encode_trit(trit) as u16) << (i * 2);
93        }
94        Self(encoded)
95    }
96
97    /// Create a word from two trytes.
98    ///
99    /// # Arguments
100    ///
101    /// * `low` - Low tryte (trits 0-2)
102    /// * `high` - High tryte (trits 3-5)
103    #[must_use]
104    pub fn from_trytes(low: Tryte3, high: Tryte3) -> Self {
105        let low_trits = low.to_trits();
106        let high_trits = high.to_trits();
107        Self::from_trits([
108            low_trits[0],
109            low_trits[1],
110            low_trits[2],
111            high_trits[0],
112            high_trits[1],
113            high_trits[2],
114        ])
115    }
116
117    /// Get the integer value of the word.
118    #[must_use]
119    pub fn value(self) -> i32 {
120        let trits = self.to_trits();
121        let mut result: i32 = 0;
122        let mut power: i32 = 1;
123        for trit in trits {
124            result += trit.value() as i32 * power;
125            power *= 3;
126        }
127        result
128    }
129
130    /// Extract the six trits.
131    ///
132    /// # Returns
133    ///
134    /// Array [t0, ..., t5] where t0 is least significant.
135    #[must_use]
136    pub fn to_trits(self) -> [Trit; 6] {
137        let mut trits = [Trit::Z; 6];
138        for (i, trit) in trits.iter_mut().enumerate() {
139            *trit = Self::decode_trit(((self.0 >> (i * 2)) & 0b11) as u8);
140        }
141        trits
142    }
143
144    /// Split into two trytes.
145    ///
146    /// # Returns
147    ///
148    /// Tuple `(low, high)` where low contains trits 0-2 and high contains trits 3-5.
149    #[must_use]
150    pub fn to_trytes(self) -> (Tryte3, Tryte3) {
151        let trits = self.to_trits();
152        (
153            Tryte3::from_trits([trits[0], trits[1], trits[2]]),
154            Tryte3::from_trits([trits[3], trits[4], trits[5]]),
155        )
156    }
157
158    /// Get a specific trit by index.
159    ///
160    /// # Arguments
161    ///
162    /// * `index` - Trit index (0-5)
163    ///
164    /// # Panics
165    ///
166    /// Panics if index >= 6.
167    #[must_use]
168    pub fn get_trit(self, index: usize) -> Trit {
169        assert!(index < 6, "trit index out of bounds");
170        Self::decode_trit(((self.0 >> (index * 2)) & 0b11) as u8)
171    }
172
173    /// Create a zero word.
174    #[must_use]
175    pub const fn zero() -> Self {
176        // All trits are Z (encoded as 1): 01|01|01|01|01|01
177        Self(0b01_01_01_01_01_01)
178    }
179
180    /// Check if the word is zero.
181    #[must_use]
182    pub fn is_zero(self) -> bool {
183        self.value() == 0
184    }
185
186    /// Get the raw packed representation.
187    #[must_use]
188    pub const fn raw(self) -> u16 {
189        self.0
190    }
191
192    // Internal: convert value to trit array using balanced ternary conversion
193    fn value_to_trits(mut value: i32) -> [Trit; 6] {
194        let mut trits = [Trit::Z; 6];
195
196        for trit in &mut trits {
197            if value == 0 {
198                *trit = Trit::Z;
199                continue;
200            }
201
202            let mut rem = value % 3;
203            value /= 3;
204
205            // Adjust for balanced ternary
206            if rem == 2 {
207                rem = -1;
208                value += 1;
209            } else if rem == -2 {
210                rem = 1;
211                value -= 1;
212            }
213
214            *trit = match rem {
215                -1 => Trit::N,
216                0 => Trit::Z,
217                1 => Trit::P,
218                _ => unreachable!(),
219            };
220        }
221
222        trits
223    }
224
225    fn encode_trit(trit: Trit) -> u8 {
226        match trit {
227            Trit::N => 0,
228            Trit::Z => 1,
229            Trit::P => 2,
230        }
231    }
232
233    fn decode_trit(bits: u8) -> Trit {
234        match bits & 0b11 {
235            0 => Trit::N,
236            1 | 3 => Trit::Z, // 3 is invalid, treat as zero
237            2 => Trit::P,
238            _ => unreachable!(),
239        }
240    }
241}
242
243impl Default for Word6 {
244    fn default() -> Self {
245        Self::zero()
246    }
247}
248
249impl Neg for Word6 {
250    type Output = Self;
251
252    fn neg(self) -> Self::Output {
253        let trits = self.to_trits();
254        Self::from_trits([
255            -trits[0], -trits[1], -trits[2], -trits[3], -trits[4], -trits[5],
256        ])
257    }
258}
259
260impl Add for Word6 {
261    type Output = (Self, Trit);
262
263    /// Add two words, returning (result, carry).
264    fn add(self, other: Self) -> Self::Output {
265        let a = self.to_trits();
266        let b = other.to_trits();
267        let mut result = [Trit::Z; 6];
268        let mut carry = Trit::Z;
269
270        for i in 0..6 {
271            let (sum1, carry1) = a[i].add_with_carry(b[i]);
272            let (sum2, carry2) = sum1.add_with_carry(carry);
273
274            result[i] = sum2;
275            let (carry_sum, _) = carry1.add_with_carry(carry2);
276            carry = carry_sum;
277        }
278
279        (Self::from_trits(result), carry)
280    }
281}
282
283impl Mul for Word6 {
284    type Output = (Self, Self);
285
286    /// Multiply two words, returning (low, high) result.
287    ///
288    /// The full result is `low + high * 729`.
289    fn mul(self, other: Self) -> Self::Output {
290        let product = self.value() as i64 * other.value() as i64;
291
292        // Split into low and high words
293        let low_val = ((product % 729) + 729 + 364) % 729 - 364;
294        let high_val = (product - low_val) / 729;
295
296        (
297            Self::from_value(low_val as i32).unwrap_or_else(|_| Self::zero()),
298            Self::from_value(high_val.clamp(WORD6_MIN as i64, WORD6_MAX as i64) as i32)
299                .unwrap_or_else(|_| Self::zero()),
300        )
301    }
302}
303
304impl fmt::Debug for Word6 {
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        let trits = self.to_trits();
307        write!(
308            f,
309            "Word6({}{}{}{}{}{} = {})",
310            trits[5],
311            trits[4],
312            trits[3],
313            trits[2],
314            trits[1],
315            trits[0],
316            self.value()
317        )
318    }
319}
320
321impl fmt::Display for Word6 {
322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323        write!(f, "{}", self.value())
324    }
325}
326
327impl TryFrom<i32> for Word6 {
328    type Error = TernaryError;
329
330    fn try_from(value: i32) -> Result<Self> {
331        Self::from_value(value)
332    }
333}
334
335impl From<Word6> for i32 {
336    fn from(word: Word6) -> Self {
337        word.value()
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a word from a value known to be in range.
    fn word(v: i32) -> Word6 {
        Word6::from_value(v).expect("valid value")
    }

    #[test]
    fn test_word_range() {
        // Both bounds are accepted; one step past either bound is rejected.
        for in_range in [WORD6_MIN, WORD6_MAX] {
            assert!(Word6::from_value(in_range).is_ok());
        }
        for out_of_range in [WORD6_MIN - 1, WORD6_MAX + 1] {
            assert!(Word6::from_value(out_of_range).is_err());
        }

        // Exhaustive encode/decode roundtrip over every representable value.
        for v in WORD6_MIN..=WORD6_MAX {
            assert_eq!(word(v).value(), v, "failed for value {v}");
        }
    }

    #[test]
    fn test_word_zero() {
        let z = Word6::zero();
        assert!(z.is_zero());
        assert_eq!(z.value(), 0);
    }

    #[test]
    fn test_word_negation() {
        for v in [0, 1, -1, 100, -100, 364, -364] {
            let negated = -word(v);
            assert_eq!(negated.value(), -v);
        }
    }

    #[test]
    fn test_word_addition() {
        // First pair fits in one word; second pair overflows into the carry trit.
        for (x, y) in [(100, 50), (300, 200)] {
            let (result, carry) = word(x) + word(y);
            let total = result.value() + carry.value() as i32 * 729;
            assert_eq!(total, x + y);
        }
    }

    #[test]
    fn test_word_multiplication() {
        // First product fits in the low word; second spills into the high word.
        for (x, y) in [(10, 20), (50, 50)] {
            let (low, high) = word(x) * word(y);
            let total = low.value() + high.value() * 729;
            assert_eq!(total, x * y);
        }
    }

    #[test]
    fn test_word_tryte_conversion() {
        let (low, high) = word(100).to_trytes();

        // Splitting and rejoining must give back the same value.
        let rejoined = Word6::from_trytes(low, high);
        assert_eq!(rejoined.value(), 100);
    }

    #[test]
    fn test_word_get_trit() {
        let pattern = [Trit::N, Trit::Z, Trit::P, Trit::N, Trit::Z, Trit::P];
        let w = Word6::from_trits(pattern);
        for (index, expected) in pattern.iter().enumerate() {
            assert_eq!(w.get_trit(index), *expected);
        }
    }
}
429}