Skip to main content

svod_codegen/c/
types.rs

1//! C type mapping and constant rendering for the C codegen backend.
2
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6use svod_dtype::{DType, ScalarDType};
7use svod_ir::{ConstValue, UOp};
8
9/// Convert a DType to its C scalar type string.
10pub fn c_scalar(s: ScalarDType) -> &'static str {
11    match s {
12        ScalarDType::Bool => "_Bool",
13        ScalarDType::Int8 => "signed char",
14        ScalarDType::UInt8 => "unsigned char",
15        ScalarDType::Int16 => "short",
16        ScalarDType::UInt16 => "unsigned short",
17        ScalarDType::Int32 => "int",
18        ScalarDType::UInt32 => "unsigned int",
19        ScalarDType::Int64 | ScalarDType::Index => "long long",
20        ScalarDType::UInt64 => "unsigned long long",
21        ScalarDType::Float16 => "_Float16",
22        ScalarDType::BFloat16 => "__bf16",
23        ScalarDType::Float32 => "float",
24        ScalarDType::Float64 => "double",
25        ScalarDType::Void => "void",
26        ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2 => "unsigned char",
27    }
28}
29
30/// Space-free identifier base for vector typedef names (e.g. `uchar4`, `llong2`).
31fn c_vector_base(s: ScalarDType) -> &'static str {
32    match s {
33        ScalarDType::Bool => "bool",
34        ScalarDType::Int8 => "schar",
35        ScalarDType::UInt8 | ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2 => "uchar",
36        ScalarDType::Int16 => "short",
37        ScalarDType::UInt16 => "ushort",
38        ScalarDType::Int32 => "int",
39        ScalarDType::UInt32 => "uint",
40        ScalarDType::Int64 | ScalarDType::Index => "llong",
41        ScalarDType::UInt64 => "ullong",
42        ScalarDType::Float16 => "half",
43        ScalarDType::BFloat16 => "bhalf",
44        ScalarDType::Float32 => "float",
45        ScalarDType::Float64 => "double",
46        ScalarDType::Void => "void",
47    }
48}
49
50/// Convert a DType to its C type string.
51///
52/// For vectors, returns the typedef name (e.g. `float4`).
53/// For pointers, returns `T*`.
54pub fn c_dtype(dtype: &DType) -> String {
55    match dtype {
56        DType::Scalar(s) => c_scalar(*s).to_string(),
57        DType::Vector { scalar, count } => {
58            format!("{}{}", c_vector_base(*scalar), count)
59        }
60        DType::Ptr { base, .. } => format!("{}*", c_dtype(base)),
61        DType::Image { .. } => "void*".to_string(),
62    }
63}
64
65/// Render a constant value as a C literal.
66pub fn c_const(val: &ConstValue, dtype: &DType) -> String {
67    match val {
68        ConstValue::Bool(b) => if *b { "1" } else { "0" }.to_string(),
69        ConstValue::Int(i) => {
70            let base = dtype.base();
71            match base {
72                ScalarDType::Int64 | ScalarDType::Index => format!("{i}LL"),
73                ScalarDType::UInt64 => format!("{}ULL", *i as u64),
74                _ => i.to_string(),
75            }
76        }
77        ConstValue::UInt(u) => {
78            let base = dtype.base();
79            match base {
80                ScalarDType::UInt64 => format!("{u}ULL"),
81                ScalarDType::UInt32 => format!("{u}u"),
82                _ => u.to_string(),
83            }
84        }
85        ConstValue::Float(f) => c_float(*f, dtype),
86    }
87}
88
89/// Render a float constant as a C literal.
90fn c_float(f: f64, dtype: &DType) -> String {
91    let base = dtype.base();
92
93    if f.is_nan() {
94        return match base {
95            ScalarDType::Float32 => "__builtin_nanf(\"\")".to_string(),
96            ScalarDType::Float64 => "__builtin_nan(\"\")".to_string(),
97            ScalarDType::Float16 => "((_Float16)__builtin_nanf(\"\"))".to_string(),
98            _ => "__builtin_nanf(\"\")".to_string(),
99        };
100    }
101
102    if f.is_infinite() {
103        let sign = if f.is_sign_negative() { "-" } else { "" };
104        return match base {
105            ScalarDType::Float32 => format!("{sign}__builtin_inff()"),
106            ScalarDType::Float64 => format!("{sign}__builtin_inf()"),
107            ScalarDType::Float16 => format!("((_Float16){sign}__builtin_inff())"),
108            _ => format!("{sign}__builtin_inff()"),
109        };
110    }
111
112    match base {
113        ScalarDType::Float32 => {
114            let f32_val = f as f32;
115            if f32_val == 0.0 && f.is_sign_negative() {
116                "-0.0f".to_string()
117            } else if f32_val.fract() == 0.0 && f32_val.abs() < 1e15 {
118                format!("{:.1}f", f32_val)
119            } else {
120                format!("{:e}f", f32_val)
121            }
122        }
123        ScalarDType::Float64 => {
124            if f == 0.0 && f.is_sign_negative() {
125                "-0.0".to_string()
126            } else if f.fract() == 0.0 && f.abs() < 1e15 {
127                format!("{:.1}", f)
128            } else {
129                format!("{:e}", f)
130            }
131        }
132        ScalarDType::Float16 => {
133            let f32_val = f as f32;
134            format!("((_Float16){}f)", format_f32_literal(f32_val))
135        }
136        ScalarDType::BFloat16 => {
137            let f32_val = f as f32;
138            format!("((__bf16){}f)", format_f32_literal(f32_val))
139        }
140        _ => format!("{:e}f", f as f32),
141    }
142}
143
/// Format an `f32` as the digits of a C floating literal, without any suffix.
///
/// Integral values with magnitude below `1e15` are printed with exactly one
/// fractional digit (`3.0`); everything else uses exponent notation
/// (`1.5e0`).
fn format_f32_literal(f: f32) -> String {
    let is_plain_integral = f.fract() == 0.0 && f.abs() < 1e15;
    match is_plain_integral {
        true => format!("{f:.1}"),
        false => format!("{f:e}"),
    }
}
148
149/// Render a vector constant as a C initializer.
150pub fn c_vconst(values: &[ConstValue], dtype: &DType) -> String {
151    let scalar_dtype = dtype.scalar_dtype();
152    let elements: Vec<String> = values.iter().map(|v| c_const(v, &scalar_dtype)).collect();
153    format!("({}){{{}}}", c_dtype(dtype), elements.join(", "))
154}
155
156/// Collect all vector types used in the linearized instruction stream
157/// and return the necessary typedef declarations.
158pub fn collect_vector_typedefs(nodes: &[Arc<UOp>]) -> Vec<String> {
159    let mut seen = BTreeSet::new();
160
161    for node in nodes {
162        collect_vec_dtype(&node.dtype(), &mut seen);
163        // Also check child dtypes for cases where vectors appear as operands
164        for child in node.op().children() {
165            collect_vec_dtype(&child.dtype(), &mut seen);
166        }
167    }
168
169    seen.into_iter()
170        .map(|(scalar, count)| {
171            // Bool can't be used as ext_vector_type base; store as unsigned char
172            let storage_scalar = if scalar == ScalarDType::Bool { "unsigned char" } else { c_scalar(scalar) };
173            let vec_name = format!("{}{}", c_vector_base(scalar), count);
174            let alignment = scalar.bytes() * count;
175            let alignment = alignment.next_power_of_two();
176            format!(
177                "typedef {storage_scalar} {vec_name} __attribute__((aligned({alignment}),ext_vector_type({count})));",
178            )
179        })
180        .collect()
181}
182
183fn collect_vec_dtype(dtype: &DType, seen: &mut BTreeSet<(ScalarDType, usize)>) {
184    match dtype {
185        DType::Vector { scalar, count } => {
186            seen.insert((*scalar, *count));
187        }
188        DType::Ptr { base, .. } => collect_vec_dtype(base, seen),
189        _ => {}
190    }
191}
192
193/// Get the C math function name for the given unary op suffix and dtype.
194/// Returns function name with type suffix (e.g. `sqrtf` for float, `sqrt` for double).
195pub fn c_math_fn(name: &str, dtype: &DType) -> String {
196    let base = dtype.base();
197    match base {
198        ScalarDType::Float32 => format!("{name}f"),
199        ScalarDType::Float64 => name.to_string(),
200        // For half/bfloat, cast through float
201        _ => format!("{name}f"),
202    }
203}
204
205/// Get the identity element for a reduce operation as a C literal.
206pub fn c_reduce_identity(op: svod_ir::ReduceOp, dtype: &DType) -> String {
207    use svod_ir::ReduceOp;
208    let is_f64 = matches!(dtype.base(), ScalarDType::Float64);
209    match op {
210        ReduceOp::Add => {
211            if dtype.is_float() {
212                if is_f64 { "0.0" } else { "0.0f" }.to_string()
213            } else {
214                "0".to_string()
215            }
216        }
217        ReduceOp::Mul => {
218            if dtype.is_float() {
219                if is_f64 { "1.0" } else { "1.0f" }.to_string()
220            } else {
221                "1".to_string()
222            }
223        }
224        ReduceOp::Max => {
225            if dtype.is_float() {
226                format!("-{}", c_math_fn("__builtin_inf", dtype))
227            } else if dtype.is_signed() {
228                match dtype.base() {
229                    ScalarDType::Int64 | ScalarDType::Index => format!("{}LL", i64::MIN),
230                    ScalarDType::Int32 => format!("{}", i32::MIN),
231                    ScalarDType::Int16 => format!("{}", i16::MIN),
232                    ScalarDType::Int8 => format!("{}", i8::MIN),
233                    _ => "0".to_string(),
234                }
235            } else {
236                "0".to_string()
237            }
238        }
239        ReduceOp::Min => {
240            if dtype.is_float() {
241                c_math_fn("__builtin_inf", dtype)
242            } else if dtype.is_signed() {
243                match dtype.base() {
244                    ScalarDType::Int64 | ScalarDType::Index => format!("{}LL", i64::MAX),
245                    ScalarDType::Int32 => format!("{}", i32::MAX),
246                    ScalarDType::Int16 => format!("{}", i16::MAX),
247                    ScalarDType::Int8 => format!("{}", i8::MAX),
248                    _ => "0".to_string(),
249                }
250            } else {
251                match dtype.base() {
252                    ScalarDType::UInt64 => format!("{}ULL", u64::MAX),
253                    ScalarDType::UInt32 => format!("{}u", u32::MAX),
254                    ScalarDType::UInt16 => format!("{}", u16::MAX),
255                    ScalarDType::UInt8 => format!("{}", u8::MAX),
256                    _ => "0".to_string(),
257                }
258            }
259        }
260    }
261}
262
263/// Get the C cast expression for converting between types.
264pub fn c_cast(val: &str, from: &DType, to: &DType) -> String {
265    let to_str = c_dtype(to);
266    // For pointer casts, use void* intermediate
267    if matches!(from, DType::Ptr { .. }) && !matches!(to, DType::Ptr { .. }) {
268        return format!("({})(long long){}", to_str, val);
269    }
270    format!("({}){}", to_str, val)
271}