Skip to main content

svod_codegen/llvm/common/
types.rs

1//! LLVM type and constant string generation.
2//!
3//! Provides functions for converting Svod types to LLVM IR text.
4//! Shared between CPU and GPU backends.
5
6use svod_dtype::{AddrSpace, DType, ScalarDType};
7use svod_ir::ConstValue;
8
9/// Convert a DType to LLVM type string.
10///
11/// Uses LLVM opaque pointer mode: all pointers are `ptr`, vectors of
12/// pointers are `<N x ptr>`. Typed pointer syntax (`float*`) is not emitted.
13pub fn ldt(dtype: &DType) -> String {
14    match dtype {
15        DType::Vector { scalar, count } => {
16            format!("<{} x {}>", count, ldt_scalar(*scalar))
17        }
18        DType::Ptr { vcount, .. } if *vcount > 1 => {
19            format!("<{} x ptr>", vcount)
20        }
21        DType::Ptr { .. } | DType::Image { .. } => "ptr".to_string(),
22        DType::Scalar(s) => ldt_scalar(*s).to_string(),
23    }
24}
25
26/// Convert a ScalarDType to LLVM type string.
27fn ldt_scalar(s: ScalarDType) -> &'static str {
28    match s {
29        ScalarDType::Bool => "i1",
30        ScalarDType::Int8 | ScalarDType::UInt8 => "i8",
31        ScalarDType::Int16 | ScalarDType::UInt16 => "i16",
32        ScalarDType::Int32 | ScalarDType::UInt32 => "i32",
33        ScalarDType::Int64 | ScalarDType::UInt64 | ScalarDType::Index => "i64",
34        ScalarDType::Float16 => "half",
35        ScalarDType::BFloat16 => "bfloat",
36        ScalarDType::Float32 => "float",
37        ScalarDType::Float64 => "double",
38        ScalarDType::Void => "void",
39        ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2 => "i8",
40    }
41}
42
43/// Convert a constant value to LLVM literal string.
44pub fn lconst(val: &ConstValue, dtype: &DType) -> String {
45    match val {
46        ConstValue::Int(i) => i.to_string(),
47        ConstValue::UInt(u) => (*u as i64).to_string(),
48        ConstValue::Float(f) => format_float(*f, dtype),
49        ConstValue::Bool(b) => if *b { "1" } else { "0" }.to_string(),
50    }
51}
52
53/// Format a float value for LLVM IR.
54fn format_float(f: f64, dtype: &DType) -> String {
55    let scalar = dtype.base();
56
57    if f.is_nan() {
58        // LLVM expects NaN in double-precision hex format for all float types
59        return match scalar {
60            ScalarDType::Float64 | ScalarDType::Float32 => "0x7FF8000000000000".to_string(),
61            ScalarDType::Float16 => "0xH7E00".to_string(),
62            ScalarDType::BFloat16 => "0xR7FC0".to_string(),
63            _ => "nan".to_string(),
64        };
65    }
66
67    if f.is_infinite() {
68        // LLVM expects infinity in hex format with sign encoded in bits
69        return match scalar {
70            ScalarDType::Float64 | ScalarDType::Float32 => {
71                // Use bit representation (sign is encoded in the high bit)
72                // +inf = 0x7FF0000000000000, -inf = 0xFFF0000000000000
73                format!("0x{:016X}", f.to_bits())
74            }
75            ScalarDType::Float16 => {
76                // Half precision: +inf = 0x7C00, -inf = 0xFC00
77                if f.is_sign_positive() { "0xH7C00".to_string() } else { "0xHFC00".to_string() }
78            }
79            ScalarDType::BFloat16 => {
80                // BFloat16: +inf = 0x7F80, -inf = 0xFF80
81                if f.is_sign_positive() { "0xR7F80".to_string() } else { "0xRFF80".to_string() }
82            }
83            _ => {
84                if f.is_sign_positive() {
85                    "inf".to_string()
86                } else {
87                    "-inf".to_string()
88                }
89            }
90        };
91    }
92
93    match scalar {
94        ScalarDType::Float64 => {
95            format!("0x{:016X}", f.to_bits())
96        }
97        ScalarDType::Float32 => {
98            // LLVM expects float32 constants in double-precision hex format
99            // Convert to f32 for precision, then back to f64 for LLVM encoding
100            let f32_val = f as f32;
101            let f64_val = f32_val as f64;
102            format!("0x{:016X}", f64_val.to_bits())
103        }
104        ScalarDType::Float16 => {
105            let f32_val = f as f32;
106            let half_bits = f32_to_f16_bits(f32_val);
107            format!("0xH{:04X}", half_bits)
108        }
109        ScalarDType::BFloat16 => {
110            let f32_val = f as f32;
111            let bf16_bits = (f32_val.to_bits() >> 16) as u16;
112            format!("0xR{:04X}", bf16_bits)
113        }
114        _ => format!("{:e}", f),
115    }
116}
117
/// Convert an `f32` to IEEE 754 binary16 (half precision) bits.
///
/// Rounds to nearest, ties to even — the same rounding `fptrunc float to half`
/// performs. The previous version truncated the mantissa, which rounds toward zero.
///
/// Behavior at the edges:
/// * Magnitudes that round above the largest finite half (65504) become ±infinity.
/// * Magnitudes at or below half the smallest subnormal (2^-25) become signed zero;
///   exactly 2^-25 is a tie and goes to the even value, zero.
/// * Subnormal results that round up past the largest subnormal become the smallest
///   normal.
/// * Every NaN is canonicalised to the quiet NaN `0x7E00`, keeping the input's sign.
fn f32_to_f16_bits(f: f32) -> u16 {
    let bits = f.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x007F_FFFF;

    // Round `kept` to nearest, ties to even, given the `shift` low-order bits
    // (`dropped`) that were shifted away to produce it.
    let round = |kept: u32, dropped: u32, shift: u32| -> u32 {
        let halfway = 1u32 << (shift - 1);
        if dropped > halfway || (dropped == halfway && (kept & 1) == 1) {
            kept + 1
        } else {
            kept
        }
    };

    // Infinity or NaN.
    if exp == 0xFF {
        return if mant == 0 { sign | 0x7C00 } else { sign | 0x7E00 };
    }

    let unbiased = exp - 127;

    // Exponent beyond half's range even before rounding: overflow to infinity.
    if unbiased > 15 {
        return sign | 0x7C00;
    }

    if unbiased >= -14 {
        // Normal half: rebias the exponent, keep the top 10 mantissa bits and round on
        // the 13 dropped ones. A rounding carry out of the mantissa correctly bumps the
        // exponent, yielding infinity when rounding up from 65504.
        let half_exp = ((unbiased + 15) as u32) << 10;
        let kept = half_exp | (mant >> 13);
        return sign | round(kept, mant & 0x1FFF, 13) as u16;
    }

    // Below half the smallest subnormal: rounds to zero. This also covers f32 zeros and
    // f32 subnormals (exp == 0, so unbiased == -127).
    if unbiased < -25 {
        return sign;
    }

    // Subnormal half: make the implicit leading 1 explicit and shift it into place.
    // value = full * 2^(exp - 150); in units of 2^-24 that is full >> (126 - exp).
    let full = mant | 0x0080_0000;
    let shift = (-1 - unbiased) as u32; // 14..=24
    let kept = full >> shift;
    let dropped = full & ((1u32 << shift) - 1);
    sign | round(kept, dropped, shift) as u16
}
143
144/// Get LLVM cast instruction name for a type conversion.
145///
146/// FP8 (E4M3/E5M2) types are mapped to `i8` in LLVM and cannot use `fpext`/`fptrunc`;
147/// FP8↔Float must be decomposed via the devectorize fp8 patterns before reaching LLVM,
148/// matching tinygrad's dedicated `f32_to_fp8` / `cvt.f32.fp8` intrinsics (`llvmir.py:226-230`).
149pub fn lcast(from: &DType, to: &DType) -> &'static str {
150    let from_scalar = from.base();
151    let to_scalar = to.base();
152
153    debug_assert!(
154        !(from_scalar.is_fp8() || to_scalar.is_fp8()),
155        "lcast does not support FP8 (mapped to i8); decompose via devectorize fp8 patterns first"
156    );
157
158    if matches!(from, DType::Ptr { .. }) || matches!(to, DType::Ptr { .. }) {
159        return if matches!(from, DType::Ptr { .. }) && matches!(to, DType::Ptr { .. }) {
160            "bitcast"
161        } else if matches!(from, DType::Ptr { .. }) {
162            "ptrtoint"
163        } else {
164            "inttoptr"
165        };
166    }
167
168    if from_scalar.is_float() && to_scalar.is_float() {
169        return if to_scalar.bytes() > from_scalar.bytes() { "fpext" } else { "fptrunc" };
170    }
171
172    if (from_scalar.is_unsigned() || from_scalar.is_bool()) && to_scalar.is_float() {
173        return "uitofp";
174    }
175    if (from_scalar.is_signed() || from_scalar == ScalarDType::Index) && to_scalar.is_float() {
176        return "sitofp";
177    }
178
179    if from_scalar.is_float() && to_scalar.is_unsigned() {
180        return "fptoui";
181    }
182    if from_scalar.is_float() && (to_scalar.is_signed() || to_scalar == ScalarDType::Index) {
183        return "fptosi";
184    }
185
186    // Integer-to-integer casts
187    let from_bytes = from_scalar.bytes();
188    let to_bytes = to_scalar.bytes();
189
190    // Bool (i1) to any integer type needs zext - i1 is always smaller than i8+
191    // Note: Bool.bytes() returns 1 (storage size) but LLVM i1 is 1 bit, not 1 byte
192    if from_scalar.is_bool() && !to_scalar.is_bool() {
193        return "zext";
194    }
195
196    // Any integer to Bool needs trunc - truncate to 1 bit
197    if !from_scalar.is_bool() && to_scalar.is_bool() {
198        return "trunc";
199    }
200
201    // Same size: bitcast (handles signed↔unsigned same-size casts)
202    if from_bytes == to_bytes {
203        return "bitcast";
204    }
205
206    // Narrowing: always trunc
207    if to_bytes < from_bytes {
208        return "trunc";
209    }
210
211    // Widening: use zext for unsigned/bool, sext for signed/Index
212    if from_scalar.is_unsigned() || from_scalar.is_bool() {
213        return "zext";
214    }
215
216    // Index type is treated as signed integer for casting purposes
217    if from_scalar.is_signed() || from_scalar == ScalarDType::Index {
218        return "sext";
219    }
220
221    "bitcast"
222}
223
224/// Get LLVM address space number.
225pub fn addr_space_num(addrspace: AddrSpace) -> u32 {
226    match addrspace {
227        AddrSpace::Global => 0,
228        AddrSpace::Local => 3,
229        AddrSpace::Reg => 5,
230    }
231}
232
233#[cfg(test)]
234#[path = "../../test/unit/llvm_common_types.rs"]
235mod tests;