Skip to main content

shape_ast/
int_width.rs

1//! IntWidth: shared width-semantics specification for first-class integer width types.
2//!
3//! This module is the single source of truth for integer width metadata: bit counts,
4//! signedness, masks, truncation, and width-joining rules. It lives in shape-ast
5//! (bottom of the dependency chain) so every crate can import it.
6//!
7//! `IntWidth` covers the sub-i64 and u64 widths. Plain `int` (i64) is NOT represented
8//! here — it remains the default integer type handled by existing codepaths.
9
10use serde::{Deserialize, Serialize};
11
12/// Integer width types with real width semantics.
13///
14/// Does NOT include i64 — that stays as the default `int` type.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub enum IntWidth {
17    I8,
18    U8,
19    I16,
20    U16,
21    I32,
22    U32,
23    U64,
24}
25
// Expands a per-variant spec table into the `impl IntWidth` accessor methods.
//
// Every accessor is a `match` over the table values, so the metadata lives in
// exactly one place and every accessor except `from_name` can be a `const fn`.
// The table must list exactly seven variants, matching the length of `ALL`.
macro_rules! define_int_width_spec {
    ($($variant:ident => {
        bits: $bits:expr,
        signed: $signed:expr,
        mask: $mask:expr,
        sign_shift: $sign_shift:expr,
        min_i64: $min_i64:expr,
        max_i64: $max_i64:expr,
        max_u64: $max_u64:expr,
        name: $name:expr,
    };)*) => {
        impl IntWidth {
            /// Every width variant, in spec-table order (seven in total).
            pub const ALL: [IntWidth; 7] = [$(Self::$variant),*];

            /// Width in bits (e.g. 8 for `i8`).
            #[inline]
            pub const fn bits(self) -> u32 {
                match self {
                    $(Self::$variant => $bits),*
                }
            }

            /// `true` for the signed widths.
            #[inline]
            pub const fn is_signed(self) -> bool {
                match self {
                    $(Self::$variant => $signed),*
                }
            }

            /// `true` for the unsigned widths (the complement of `is_signed`).
            #[inline]
            pub const fn is_unsigned(self) -> bool {
                !Self::is_signed(self)
            }

            /// Mask selecting the low `bits()` bits (e.g. `0xFF` for 8-bit widths).
            #[inline]
            pub const fn mask(self) -> u64 {
                match self {
                    $(Self::$variant => $mask),*
                }
            }

            /// Index of the most significant bit of this width, which is the
            /// sign bit for signed widths (e.g. 7 for `i8`).
            #[inline]
            pub const fn sign_shift(self) -> u32 {
                match self {
                    $(Self::$variant => $sign_shift),*
                }
            }

            /// Smallest representable value, as an `i64` (0 for unsigned widths).
            #[inline]
            pub const fn min_value(self) -> i64 {
                match self {
                    $(Self::$variant => $min_i64),*
                }
            }

            /// Largest representable value, as an `i64`.
            ///
            /// For `U64` this is capped at `i64::MAX`; use `max_unsigned` for
            /// the full range.
            #[inline]
            pub const fn max_value(self) -> i64 {
                match self {
                    $(Self::$variant => $max_i64),*
                }
            }

            /// Largest representable value, as a `u64` (the positive maximum
            /// for signed widths).
            #[inline]
            pub const fn max_unsigned(self) -> u64 {
                match self {
                    $(Self::$variant => $max_u64),*
                }
            }

            /// Source-level type name, e.g. `"i8"` or `"u64"`.
            #[inline]
            pub const fn type_name(self) -> &'static str {
                match self {
                    $(Self::$variant => $name),*
                }
            }

            /// Wraps an `i64` to this width with two's-complement semantics.
            ///
            /// - Signed widths keep the low `bits()` bits and sign-extend them.
            /// - `U64`-style full-width unsigned returns `value` unchanged,
            ///   because the i64 bit pattern already is the u64 value.
            /// - Narrower unsigned widths keep the low `bits()` bits, so the
            ///   result is non-negative.
            #[inline]
            pub const fn truncate(self, value: i64) -> i64 {
                match self {
                    $(Self::$variant => {
                        // Number of high bits that lie outside this width.
                        let spare = 64 - $bits;
                        if !$signed {
                            if $bits == 64 {
                                value
                            } else {
                                ((value as u64) & $mask) as i64
                            }
                        } else {
                            // Move the width's top bit into the i64 sign position,
                            // then arithmetic-shift back down to sign-extend.
                            (value << spare) >> spare
                        }
                    })*
                }
            }

            /// Wraps a `u64` to this width.
            ///
            /// Signed widths are sign-extended and returned as their u64 bit
            /// pattern. Unsigned widths are masked. Full 64-bit widths are
            /// returned unchanged.
            #[inline]
            pub const fn truncate_u64(self, value: u64) -> u64 {
                match self {
                    $(Self::$variant => {
                        let spare = 64 - $bits;
                        if $bits == 64 {
                            // Already full width: nothing to wrap.
                            value
                        } else if $signed {
                            // Sign-extend through an arithmetic shift on the i64 view.
                            (((value as i64) << spare) >> spare) as u64
                        } else {
                            value & $mask
                        }
                    })*
                }
            }

            /// Looks up a width by its exact, case-sensitive type name
            /// (e.g. `"i8"`, `"u64"`). Names not in the spec table yield `None`.
            pub fn from_name(name: &str) -> Option<IntWidth> {
                Self::ALL
                    .iter()
                    .copied()
                    .find(|candidate| candidate.type_name() == name)
            }
        }
    };
}
179
180define_int_width_spec! {
181    I8 => {
182        bits: 8,
183        signed: true,
184        mask: 0xFF_u64,
185        sign_shift: 7,
186        min_i64: -128_i64,
187        max_i64: 127_i64,
188        max_u64: 127_u64,
189        name: "i8",
190    };
191    U8 => {
192        bits: 8,
193        signed: false,
194        mask: 0xFF_u64,
195        sign_shift: 7,
196        min_i64: 0_i64,
197        max_i64: 255_i64,
198        max_u64: 255_u64,
199        name: "u8",
200    };
201    I16 => {
202        bits: 16,
203        signed: true,
204        mask: 0xFFFF_u64,
205        sign_shift: 15,
206        min_i64: -32768_i64,
207        max_i64: 32767_i64,
208        max_u64: 32767_u64,
209        name: "i16",
210    };
211    U16 => {
212        bits: 16,
213        signed: false,
214        mask: 0xFFFF_u64,
215        sign_shift: 15,
216        min_i64: 0_i64,
217        max_i64: 65535_i64,
218        max_u64: 65535_u64,
219        name: "u16",
220    };
221    I32 => {
222        bits: 32,
223        signed: true,
224        mask: 0xFFFF_FFFF_u64,
225        sign_shift: 31,
226        min_i64: -2147483648_i64,
227        max_i64: 2147483647_i64,
228        max_u64: 2147483647_u64,
229        name: "i32",
230    };
231    U32 => {
232        bits: 32,
233        signed: false,
234        mask: 0xFFFF_FFFF_u64,
235        sign_shift: 31,
236        min_i64: 0_i64,
237        max_i64: 4294967295_i64,
238        max_u64: 4294967295_u64,
239        name: "u32",
240    };
241    U64 => {
242        bits: 64,
243        signed: false,
244        mask: u64::MAX,
245        sign_shift: 63,
246        min_i64: 0_i64,
247        max_i64: i64::MAX,
248        max_u64: u64::MAX,
249        name: "u64",
250    };
251}
252
253impl IntWidth {
254    /// Join two widths for mixed-width arithmetic.
255    ///
256    /// Rules:
257    /// - Same width → Ok(same)
258    /// - Different widths, same signedness → Ok(wider)
259    /// - Mixed sign: u8+i8→I16, u16+i16→I32, u32+i32→I64 (widen to next signed)
260    /// - **u64 + any signed → Err(())** (compile error — no safe widening)
261    pub fn join(a: IntWidth, b: IntWidth) -> Result<IntWidth, ()> {
262        if a == b {
263            return Ok(a);
264        }
265
266        // Same signedness: pick wider
267        if a.is_signed() == b.is_signed() {
268            return Ok(if a.bits() >= b.bits() { a } else { b });
269        }
270
271        // Mixed sign: identify unsigned and signed
272        let (unsigned, signed) = if a.is_unsigned() { (a, b) } else { (b, a) };
273
274        // u64 + any signed → error
275        if unsigned == IntWidth::U64 {
276            return Err(());
277        }
278
279        // Widen to next signed width that fits both
280        match (unsigned, signed) {
281            // u8 (0..255) + i8 (-128..127) → i16 (-32768..32767)
282            (IntWidth::U8, IntWidth::I8) => Ok(IntWidth::I16),
283            // u8 + i16/i32 → the signed type is already wide enough
284            (IntWidth::U8, s) => Ok(s),
285
286            // u16 (0..65535) + i8/i16 → i32
287            (IntWidth::U16, IntWidth::I8 | IntWidth::I16) => Ok(IntWidth::I32),
288            // u16 + i32 → i32 is wide enough
289            (IntWidth::U16, IntWidth::I32) => Ok(IntWidth::I32),
290
291            // u32 (0..4B) + i8/i16/i32 → need i64 (default int)
292            // Return Err to signal "promote to i64" since IntWidth doesn't include i64
293            (IntWidth::U32, _) => Err(()),
294
295            _ => Err(()),
296        }
297    }
298
299    /// Check if a given i64 value is in range for this width.
300    #[inline]
301    pub const fn in_range_i64(self, value: i64) -> bool {
302        value >= self.min_value() && value <= self.max_value()
303    }
304
305    /// Check if a given u64 value is in range for this width.
306    #[inline]
307    pub const fn in_range_u64(self, value: u64) -> bool {
308        value <= self.max_unsigned()
309    }
310}
311
312impl std::fmt::Display for IntWidth {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        f.write_str(self.type_name())
315    }
316}
317
#[cfg(test)]
mod tests {
    //! Unit tests for the `IntWidth` spec: truncation, joining, range checks,
    //! name lookup and display.

    use super::*;

    /// Representative i64 inputs that hit the sign bits, carries and extremes
    /// of every width.
    const SAMPLES: [i64; 16] = [
        0,
        1,
        -1,
        127,
        128,
        -129,
        255,
        256,
        32_768,
        -32_769,
        65_536,
        i32::MIN as i64,
        i32::MAX as i64 + 1,
        u32::MAX as i64 + 1,
        i64::MIN,
        i64::MAX,
    ];

    #[test]
    fn truncate_i8_boundaries() {
        assert_eq!(IntWidth::I8.truncate(127), 127);
        assert_eq!(IntWidth::I8.truncate(128), -128);
        assert_eq!(IntWidth::I8.truncate(-128), -128);
        assert_eq!(IntWidth::I8.truncate(-129), 127);
        assert_eq!(IntWidth::I8.truncate(255), -1);
        assert_eq!(IntWidth::I8.truncate(256), 0);
    }

    #[test]
    fn truncate_u8_boundaries() {
        assert_eq!(IntWidth::U8.truncate(0), 0);
        assert_eq!(IntWidth::U8.truncate(255), 255);
        assert_eq!(IntWidth::U8.truncate(256), 0);
        assert_eq!(IntWidth::U8.truncate(-1), 255);
    }

    #[test]
    fn truncate_i16_boundaries() {
        assert_eq!(IntWidth::I16.truncate(32767), 32767);
        assert_eq!(IntWidth::I16.truncate(32768), -32768);
        assert_eq!(IntWidth::I16.truncate(-32768), -32768);
        assert_eq!(IntWidth::I16.truncate(-32769), 32767);
    }

    #[test]
    fn truncate_u16_boundaries() {
        assert_eq!(IntWidth::U16.truncate(0), 0);
        assert_eq!(IntWidth::U16.truncate(65535), 65535);
        assert_eq!(IntWidth::U16.truncate(65536), 0);
        assert_eq!(IntWidth::U16.truncate(-1), 65535);
    }

    #[test]
    fn truncate_i32_boundaries() {
        assert_eq!(IntWidth::I32.truncate(2147483647), 2147483647);
        assert_eq!(IntWidth::I32.truncate(2147483648), -2147483648);
        assert_eq!(IntWidth::I32.truncate(-2147483648), -2147483648);
    }

    #[test]
    fn truncate_u32_boundaries() {
        assert_eq!(IntWidth::U32.truncate(0), 0);
        assert_eq!(IntWidth::U32.truncate(4294967295), 4294967295);
        assert_eq!(IntWidth::U32.truncate(4294967296), 0);
        assert_eq!(IntWidth::U32.truncate(-1), 4294967295);
    }

    #[test]
    fn truncate_u64_identity() {
        assert_eq!(IntWidth::U64.truncate(0), 0);
        assert_eq!(IntWidth::U64.truncate(i64::MAX), i64::MAX);
        assert_eq!(IntWidth::U64.truncate(-1), -1); // bit pattern preserved
    }

    #[test]
    fn truncate_u64_unsigned() {
        assert_eq!(IntWidth::U64.truncate_u64(0), 0);
        assert_eq!(IntWidth::U64.truncate_u64(u64::MAX), u64::MAX);
        assert_eq!(IntWidth::U64.truncate_u64(u64::MAX - 1), u64::MAX - 1);
    }

    #[test]
    fn truncate_u64_signed_sign_extends() {
        assert_eq!(IntWidth::I8.truncate_u64(0x7F), 0x7F);
        assert_eq!(IntWidth::I8.truncate_u64(0x80), 0xFFFF_FFFF_FFFF_FF80);
        assert_eq!(IntWidth::I16.truncate_u64(0xFFFF), u64::MAX);
        assert_eq!(
            IntWidth::I32.truncate_u64(0x1_8000_0000),
            0xFFFF_FFFF_8000_0000
        );
        assert_eq!(IntWidth::U8.truncate_u64(0x1FF), 0xFF);
        assert_eq!(IntWidth::U16.truncate_u64(0x1_FFFF), 0xFFFF);
        assert_eq!(IntWidth::U32.truncate_u64(u64::MAX), 0xFFFF_FFFF);
    }

    #[test]
    fn truncate_and_truncate_u64_agree_on_bit_patterns() {
        for w in IntWidth::ALL {
            for &v in &SAMPLES {
                assert_eq!(
                    w.truncate_u64(v as u64),
                    w.truncate(v) as u64,
                    "width {} value {}",
                    w,
                    v
                );
            }
        }
    }

    #[test]
    fn truncate_is_idempotent_and_lands_in_range() {
        for w in IntWidth::ALL {
            for &v in &SAMPLES {
                let t = w.truncate(v);
                assert_eq!(w.truncate(t), t, "width {} value {}", w, v);
                // U64 keeps negative i64 bit patterns, which in_range_i64 rejects.
                if w != IntWidth::U64 {
                    assert!(w.in_range_i64(t), "width {} value {}", w, v);
                }
            }
        }
    }

    #[test]
    fn spec_table_is_self_consistent() {
        for w in IntWidth::ALL {
            let bits = w.bits();
            assert_eq!(w.sign_shift(), bits - 1);
            let expected_mask = if bits == 64 {
                u64::MAX
            } else {
                (1u64 << bits) - 1
            };
            assert_eq!(w.mask(), expected_mask);
            assert_eq!(w.is_unsigned(), !w.is_signed());
            if w.is_signed() {
                assert_eq!(w.min_value(), -(1i64 << (bits - 1)));
                assert_eq!(w.max_value(), (1i64 << (bits - 1)) - 1);
                assert_eq!(w.max_unsigned(), w.max_value() as u64);
            } else {
                assert_eq!(w.min_value(), 0);
                assert_eq!(w.max_unsigned(), expected_mask);
            }
        }
    }

    #[test]
    fn all_lists_each_width_once() {
        for (i, a) in IntWidth::ALL.iter().enumerate() {
            for b in &IntWidth::ALL[i + 1..] {
                assert_ne!(a, b);
                assert_ne!(a.type_name(), b.type_name());
            }
        }
    }

    #[test]
    fn join_same_width() {
        assert_eq!(IntWidth::join(IntWidth::I8, IntWidth::I8), Ok(IntWidth::I8));
        assert_eq!(
            IntWidth::join(IntWidth::U64, IntWidth::U64),
            Ok(IntWidth::U64)
        );
    }

    #[test]
    fn join_same_sign_different_width() {
        assert_eq!(
            IntWidth::join(IntWidth::I8, IntWidth::I16),
            Ok(IntWidth::I16)
        );
        assert_eq!(
            IntWidth::join(IntWidth::I16, IntWidth::I32),
            Ok(IntWidth::I32)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U8, IntWidth::U16),
            Ok(IntWidth::U16)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U16, IntWidth::U32),
            Ok(IntWidth::U32)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U32, IntWidth::U64),
            Ok(IntWidth::U64)
        );
        assert_eq!(
            IntWidth::join(IntWidth::I8, IntWidth::I32),
            Ok(IntWidth::I32)
        );
    }

    #[test]
    fn join_mixed_sign_widens() {
        assert_eq!(
            IntWidth::join(IntWidth::U8, IntWidth::I8),
            Ok(IntWidth::I16)
        );
        assert_eq!(
            IntWidth::join(IntWidth::I8, IntWidth::U8),
            Ok(IntWidth::I16)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U16, IntWidth::I16),
            Ok(IntWidth::I32)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U16, IntWidth::I8),
            Ok(IntWidth::I32)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U8, IntWidth::I16),
            Ok(IntWidth::I16)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U8, IntWidth::I32),
            Ok(IntWidth::I32)
        );
        assert_eq!(
            IntWidth::join(IntWidth::U16, IntWidth::I32),
            Ok(IntWidth::I32)
        );
    }

    #[test]
    fn join_u64_signed_error() {
        assert_eq!(IntWidth::join(IntWidth::U64, IntWidth::I8), Err(()));
        assert_eq!(IntWidth::join(IntWidth::U64, IntWidth::I16), Err(()));
        assert_eq!(IntWidth::join(IntWidth::U64, IntWidth::I32), Err(()));
        assert_eq!(IntWidth::join(IntWidth::I8, IntWidth::U64), Err(()));
    }

    #[test]
    fn join_u32_signed_promotes_to_i64() {
        // u32 + any signed → Err (needs i64, which is outside IntWidth)
        assert_eq!(IntWidth::join(IntWidth::U32, IntWidth::I8), Err(()));
        assert_eq!(IntWidth::join(IntWidth::U32, IntWidth::I16), Err(()));
        assert_eq!(IntWidth::join(IntWidth::U32, IntWidth::I32), Err(()));
        assert_eq!(IntWidth::join(IntWidth::I32, IntWidth::U32), Err(()));
    }

    #[test]
    fn join_is_symmetric_and_covers_both_ranges() {
        for a in IntWidth::ALL {
            for b in IntWidth::ALL {
                let ab = IntWidth::join(a, b);
                assert_eq!(ab, IntWidth::join(b, a), "{} vs {}", a, b);
                if let Ok(r) = ab {
                    assert!(r.min_value() <= a.min_value().min(b.min_value()));
                    assert!(r.max_value() >= a.max_value().max(b.max_value()));
                }
            }
        }
    }

    #[test]
    fn from_name_roundtrip() {
        for w in IntWidth::ALL {
            assert_eq!(IntWidth::from_name(w.type_name()), Some(w));
        }
        assert_eq!(IntWidth::from_name("i64"), None);
        assert_eq!(IntWidth::from_name("float"), None);
        assert_eq!(IntWidth::from_name("I8"), None);
        assert_eq!(IntWidth::from_name(""), None);
    }

    #[test]
    fn in_range_checks() {
        assert!(IntWidth::I8.in_range_i64(0));
        assert!(IntWidth::I8.in_range_i64(127));
        assert!(IntWidth::I8.in_range_i64(-128));
        assert!(!IntWidth::I8.in_range_i64(128));
        assert!(!IntWidth::I8.in_range_i64(-129));

        assert!(IntWidth::U8.in_range_i64(0));
        assert!(IntWidth::U8.in_range_i64(255));
        assert!(!IntWidth::U8.in_range_i64(-1));
        assert!(!IntWidth::U8.in_range_i64(256));

        assert!(IntWidth::U64.in_range_u64(u64::MAX));
        assert!(IntWidth::U64.in_range_u64(0));
    }

    #[test]
    fn in_range_32_and_64_bit_boundaries() {
        assert!(IntWidth::I32.in_range_i64(i32::MIN as i64));
        assert!(IntWidth::I32.in_range_i64(i32::MAX as i64));
        assert!(!IntWidth::I32.in_range_i64(i32::MAX as i64 + 1));
        assert!(!IntWidth::I32.in_range_i64(i32::MIN as i64 - 1));

        assert!(IntWidth::U32.in_range_i64(u32::MAX as i64));
        assert!(!IntWidth::U32.in_range_i64(u32::MAX as i64 + 1));
        assert!(!IntWidth::U32.in_range_i64(-1));

        assert!(IntWidth::U64.in_range_i64(i64::MAX));
        assert!(!IntWidth::U64.in_range_i64(-1));

        // Signed widths only accept their non-negative half as u64.
        assert!(IntWidth::I8.in_range_u64(127));
        assert!(!IntWidth::I8.in_range_u64(128));
        assert!(IntWidth::U32.in_range_u64(u32::MAX as u64));
        assert!(!IntWidth::U32.in_range_u64(u32::MAX as u64 + 1));
    }

    #[test]
    fn display_impl() {
        assert_eq!(format!("{}", IntWidth::I8), "i8");
        assert_eq!(format!("{}", IntWidth::U64), "u64");
    }
}