1use std::collections::BTreeSet;
4use std::sync::Arc;
5
6use svod_dtype::{DType, ScalarDType};
7use svod_ir::{ConstValue, UOp};
8
9pub fn c_scalar(s: ScalarDType) -> &'static str {
11 match s {
12 ScalarDType::Bool => "_Bool",
13 ScalarDType::Int8 => "signed char",
14 ScalarDType::UInt8 => "unsigned char",
15 ScalarDType::Int16 => "short",
16 ScalarDType::UInt16 => "unsigned short",
17 ScalarDType::Int32 => "int",
18 ScalarDType::UInt32 => "unsigned int",
19 ScalarDType::Int64 | ScalarDType::Index => "long long",
20 ScalarDType::UInt64 => "unsigned long long",
21 ScalarDType::Float16 => "_Float16",
22 ScalarDType::BFloat16 => "__bf16",
23 ScalarDType::Float32 => "float",
24 ScalarDType::Float64 => "double",
25 ScalarDType::Void => "void",
26 ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2 => "unsigned char",
27 }
28}
29
30fn c_vector_base(s: ScalarDType) -> &'static str {
32 match s {
33 ScalarDType::Bool => "bool",
34 ScalarDType::Int8 => "schar",
35 ScalarDType::UInt8 | ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2 => "uchar",
36 ScalarDType::Int16 => "short",
37 ScalarDType::UInt16 => "ushort",
38 ScalarDType::Int32 => "int",
39 ScalarDType::UInt32 => "uint",
40 ScalarDType::Int64 | ScalarDType::Index => "llong",
41 ScalarDType::UInt64 => "ullong",
42 ScalarDType::Float16 => "half",
43 ScalarDType::BFloat16 => "bhalf",
44 ScalarDType::Float32 => "float",
45 ScalarDType::Float64 => "double",
46 ScalarDType::Void => "void",
47 }
48}
49
50pub fn c_dtype(dtype: &DType) -> String {
55 match dtype {
56 DType::Scalar(s) => c_scalar(*s).to_string(),
57 DType::Vector { scalar, count } => {
58 format!("{}{}", c_vector_base(*scalar), count)
59 }
60 DType::Ptr { base, .. } => format!("{}*", c_dtype(base)),
61 DType::Image { .. } => "void*".to_string(),
62 }
63}
64
65pub fn c_const(val: &ConstValue, dtype: &DType) -> String {
67 match val {
68 ConstValue::Bool(b) => if *b { "1" } else { "0" }.to_string(),
69 ConstValue::Int(i) => {
70 let base = dtype.base();
71 match base {
72 ScalarDType::Int64 | ScalarDType::Index => format!("{i}LL"),
73 ScalarDType::UInt64 => format!("{}ULL", *i as u64),
74 _ => i.to_string(),
75 }
76 }
77 ConstValue::UInt(u) => {
78 let base = dtype.base();
79 match base {
80 ScalarDType::UInt64 => format!("{u}ULL"),
81 ScalarDType::UInt32 => format!("{u}u"),
82 _ => u.to_string(),
83 }
84 }
85 ConstValue::Float(f) => c_float(*f, dtype),
86 }
87}
88
89fn c_float(f: f64, dtype: &DType) -> String {
91 let base = dtype.base();
92
93 if f.is_nan() {
94 return match base {
95 ScalarDType::Float32 => "__builtin_nanf(\"\")".to_string(),
96 ScalarDType::Float64 => "__builtin_nan(\"\")".to_string(),
97 ScalarDType::Float16 => "((_Float16)__builtin_nanf(\"\"))".to_string(),
98 _ => "__builtin_nanf(\"\")".to_string(),
99 };
100 }
101
102 if f.is_infinite() {
103 let sign = if f.is_sign_negative() { "-" } else { "" };
104 return match base {
105 ScalarDType::Float32 => format!("{sign}__builtin_inff()"),
106 ScalarDType::Float64 => format!("{sign}__builtin_inf()"),
107 ScalarDType::Float16 => format!("((_Float16){sign}__builtin_inff())"),
108 _ => format!("{sign}__builtin_inff()"),
109 };
110 }
111
112 match base {
113 ScalarDType::Float32 => {
114 let f32_val = f as f32;
115 if f32_val == 0.0 && f.is_sign_negative() {
116 "-0.0f".to_string()
117 } else if f32_val.fract() == 0.0 && f32_val.abs() < 1e15 {
118 format!("{:.1}f", f32_val)
119 } else {
120 format!("{:e}f", f32_val)
121 }
122 }
123 ScalarDType::Float64 => {
124 if f == 0.0 && f.is_sign_negative() {
125 "-0.0".to_string()
126 } else if f.fract() == 0.0 && f.abs() < 1e15 {
127 format!("{:.1}", f)
128 } else {
129 format!("{:e}", f)
130 }
131 }
132 ScalarDType::Float16 => {
133 let f32_val = f as f32;
134 format!("((_Float16){}f)", format_f32_literal(f32_val))
135 }
136 ScalarDType::BFloat16 => {
137 let f32_val = f as f32;
138 format!("((__bf16){}f)", format_f32_literal(f32_val))
139 }
140 _ => format!("{:e}f", f as f32),
141 }
142}
143
/// Formats an `f32` as the digits of a C literal, without any suffix.
///
/// Whole numbers with magnitude below `1e15` are printed with exactly one
/// decimal place (`3.0`); every other value uses exponent notation (`1.5e0`).
fn format_f32_literal(f: f32) -> String {
    let is_small_whole_number = f.fract() == 0.0 && f.abs() < 1e15;
    if is_small_whole_number {
        return format!("{f:.1}");
    }
    format!("{f:e}")
}
148
149pub fn c_vconst(values: &[ConstValue], dtype: &DType) -> String {
151 let scalar_dtype = dtype.scalar_dtype();
152 let elements: Vec<String> = values.iter().map(|v| c_const(v, &scalar_dtype)).collect();
153 format!("({}){{{}}}", c_dtype(dtype), elements.join(", "))
154}
155
156pub fn collect_vector_typedefs(nodes: &[Arc<UOp>]) -> Vec<String> {
159 let mut seen = BTreeSet::new();
160
161 for node in nodes {
162 collect_vec_dtype(&node.dtype(), &mut seen);
163 for child in node.op().children() {
165 collect_vec_dtype(&child.dtype(), &mut seen);
166 }
167 }
168
169 seen.into_iter()
170 .map(|(scalar, count)| {
171 let storage_scalar = if scalar == ScalarDType::Bool { "unsigned char" } else { c_scalar(scalar) };
173 let vec_name = format!("{}{}", c_vector_base(scalar), count);
174 let alignment = scalar.bytes() * count;
175 let alignment = alignment.next_power_of_two();
176 format!(
177 "typedef {storage_scalar} {vec_name} __attribute__((aligned({alignment}),ext_vector_type({count})));",
178 )
179 })
180 .collect()
181}
182
183fn collect_vec_dtype(dtype: &DType, seen: &mut BTreeSet<(ScalarDType, usize)>) {
184 match dtype {
185 DType::Vector { scalar, count } => {
186 seen.insert((*scalar, *count));
187 }
188 DType::Ptr { base, .. } => collect_vec_dtype(base, seen),
189 _ => {}
190 }
191}
192
193pub fn c_math_fn(name: &str, dtype: &DType) -> String {
196 let base = dtype.base();
197 match base {
198 ScalarDType::Float32 => format!("{name}f"),
199 ScalarDType::Float64 => name.to_string(),
200 _ => format!("{name}f"),
202 }
203}
204
205pub fn c_reduce_identity(op: svod_ir::ReduceOp, dtype: &DType) -> String {
207 use svod_ir::ReduceOp;
208 let is_f64 = matches!(dtype.base(), ScalarDType::Float64);
209 match op {
210 ReduceOp::Add => {
211 if dtype.is_float() {
212 if is_f64 { "0.0" } else { "0.0f" }.to_string()
213 } else {
214 "0".to_string()
215 }
216 }
217 ReduceOp::Mul => {
218 if dtype.is_float() {
219 if is_f64 { "1.0" } else { "1.0f" }.to_string()
220 } else {
221 "1".to_string()
222 }
223 }
224 ReduceOp::Max => {
225 if dtype.is_float() {
226 format!("-{}", c_math_fn("__builtin_inf", dtype))
227 } else if dtype.is_signed() {
228 match dtype.base() {
229 ScalarDType::Int64 | ScalarDType::Index => format!("{}LL", i64::MIN),
230 ScalarDType::Int32 => format!("{}", i32::MIN),
231 ScalarDType::Int16 => format!("{}", i16::MIN),
232 ScalarDType::Int8 => format!("{}", i8::MIN),
233 _ => "0".to_string(),
234 }
235 } else {
236 "0".to_string()
237 }
238 }
239 ReduceOp::Min => {
240 if dtype.is_float() {
241 c_math_fn("__builtin_inf", dtype)
242 } else if dtype.is_signed() {
243 match dtype.base() {
244 ScalarDType::Int64 | ScalarDType::Index => format!("{}LL", i64::MAX),
245 ScalarDType::Int32 => format!("{}", i32::MAX),
246 ScalarDType::Int16 => format!("{}", i16::MAX),
247 ScalarDType::Int8 => format!("{}", i8::MAX),
248 _ => "0".to_string(),
249 }
250 } else {
251 match dtype.base() {
252 ScalarDType::UInt64 => format!("{}ULL", u64::MAX),
253 ScalarDType::UInt32 => format!("{}u", u32::MAX),
254 ScalarDType::UInt16 => format!("{}", u16::MAX),
255 ScalarDType::UInt8 => format!("{}", u8::MAX),
256 _ => "0".to_string(),
257 }
258 }
259 }
260 }
261}
262
263pub fn c_cast(val: &str, from: &DType, to: &DType) -> String {
265 let to_str = c_dtype(to);
266 if matches!(from, DType::Ptr { .. }) && !matches!(to, DType::Ptr { .. }) {
268 return format!("({})(long long){}", to_str, val);
269 }
270 format!("({}){}", to_str, val)
271}