1use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use svod_dtype::{DType, ScalarDType};
11use svod_ir::{BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
12
13use super::types::{c_cast, c_dtype, c_math_fn};
14use crate::common::format_custom_template_strict;
15
16pub struct CContext {
18 names: HashMap<u64, String>,
20 ref_counts: HashMap<u64, usize>,
22 counter: usize,
24 depth: usize,
26 pending_reduces: HashMap<u64, (String, DType)>,
28 scope_escaping: HashSet<u64>,
30 pub hoisted_declarations: Vec<String>,
32 pending_error: Option<crate::Error>,
36}
37
38impl CContext {
39 pub fn new(ref_counts: HashMap<u64, usize>, scope_escaping: HashSet<u64>) -> Self {
40 Self {
41 names: HashMap::new(),
42 ref_counts,
43 counter: 0,
44 depth: 1,
45 pending_reduces: HashMap::new(),
46 scope_escaping,
47 hoisted_declarations: Vec::new(),
48 pending_error: None,
49 }
50 }
51
52 pub fn set_invalid_graph(&mut self, reason: impl Into<String>) {
54 if self.pending_error.is_none() {
55 self.pending_error = Some(crate::Error::InvalidGraph { reason: reason.into() });
56 }
57 }
58
59 pub fn take_error(&mut self) -> Option<crate::Error> {
61 self.pending_error.take()
62 }
63
64 pub fn get(&self, uop: &Arc<UOp>) -> &str {
66 self.names
67 .get(&uop.id)
68 .map(|s| s.as_str())
69 .unwrap_or_else(|| panic!("UOp {} ({}) not in C context", uop.id, uop.op().as_ref()))
70 }
71
72 pub fn register(&mut self, id: u64, expr: String) {
74 self.names.insert(id, expr);
75 }
76
77 pub fn should_inline(&self, id: u64) -> bool {
79 self.ref_counts.get(&id).copied().unwrap_or(0) <= 1
80 }
81
82 pub fn next_name(&mut self, prefix: &str) -> String {
84 let name = format!("{}{}", prefix, self.counter);
85 self.counter += 1;
86 name
87 }
88
89 pub fn indent(&self) -> String {
91 " ".repeat(self.depth)
92 }
93
94 pub fn push_indent(&mut self) {
96 self.depth += 1;
97 }
98
99 pub fn pop_indent(&mut self) {
101 self.depth = self.depth.saturating_sub(1);
102 }
103
104 pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_name: String, dtype: DType) {
106 self.pending_reduces.insert(reduce_id, (acc_name, dtype));
107 }
108
109 pub fn take_pending_reduces(&mut self) -> HashMap<u64, (String, DType)> {
111 std::mem::take(&mut self.pending_reduces)
112 }
113
114 pub fn emit_expr(&mut self, uop: &Arc<UOp>, expr: String, prefix: &str, kernel: &mut Vec<String>) -> String {
122 if self.should_inline(uop.id) {
123 self.register(uop.id, expr.clone());
124 expr
125 } else {
126 let name = self.next_name(prefix);
127 let dtype = c_dtype(&uop.dtype());
128 let indent = self.indent();
129 if self.scope_escaping.contains(&uop.id) {
130 self.hoisted_declarations.push(format!(" {dtype} {name};"));
132 kernel.push(format!("{indent}{name} = {expr};"));
133 } else {
134 kernel.push(format!("{indent}{dtype} {name} = {expr};"));
135 }
136 self.register(uop.id, name.clone());
137 name
138 }
139 }
140}
141
142pub fn render_uop(uop: &Arc<UOp>, ctx: &mut CContext, kernel: &mut Vec<String>) -> Option<()> {
146 match uop.op() {
147 Op::Const(_)
149 | Op::VConst { .. }
150 | Op::Param { device: None, .. }
151 | Op::DefineLocal(_)
152 | Op::DefineVar { .. }
153 | Op::Noop
154 | Op::Sink { .. }
155 | Op::Group { .. }
156 | Op::Buffer { .. }
157 | Op::Unique(_)
158 | Op::Device(_)
159 | Op::Call { .. }
160 | Op::Barrier { .. } => None,
161
162 Op::DefineReg { .. } => {
163 let (base_dtype, alloc_size) = match uop.dtype() {
167 DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
168 other => (other, 1),
169 };
170 let name = ctx.next_name("reg");
171 let indent = ctx.indent();
172 kernel.push(format!("{indent}{} {name}[{alloc_size}];", c_dtype(&base_dtype)));
173 ctx.register(uop.id, name);
174 Some(())
175 }
176
177 Op::Index { buffer, indices, .. } => {
178 let buf = ctx.get(buffer).to_string();
179
180 if indices.is_empty() {
181 ctx.register(uop.id, buf);
183 } else {
184 let idx = if indices.len() == 1 {
185 ctx.get(&indices[0]).to_string()
186 } else {
187 ctx.set_invalid_graph(format!(
188 "C renderer requires linearized INDEX (single-axis), found {} indices on uop {}",
189 indices.len(),
190 uop.id
191 ));
192 return None;
193 };
194 let expr = format!("{buf} + {idx}");
195 ctx.emit_expr(uop, expr, "idx", kernel);
196 }
197 Some(())
198 }
199
200 Op::PointerIndex { ptr, offset } => {
201 let ptr_val = ctx.get(ptr).to_string();
202 let off_val = ctx.get(offset).to_string();
203 let expr = format!("{ptr_val} + {off_val}");
204 ctx.emit_expr(uop, expr, "pidx", kernel);
205 Some(())
206 }
207
208 Op::Load { index, alt, .. } => {
209 let idx = ctx.get(index).to_string();
210 let load_dtype = uop.dtype();
211 let actual_index = match index.op() {
214 Op::Cast { src, .. } => src,
215 _ => index,
216 };
217 let gate_expr = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
218 Some(ctx.get(gate_uop).to_string())
219 } else {
220 None
221 };
222 let deref_expr = if load_dtype.vcount() > 1 {
223 let cast_type = c_dtype(&load_dtype);
224 format!("*(({cast_type}*)({idx}))")
225 } else {
226 format!("*({idx})")
227 };
228 let expr = if let Some(gate) = gate_expr {
229 let Some(alt_uop) = alt.as_ref() else {
230 ctx.set_invalid_graph(format!(
231 "gated LOAD on uop {} has no alt value; line_rewrite_cleanups must lift gated LOADs",
232 uop.id
233 ));
234 return None;
235 };
236 let alt_expr = ctx.get(alt_uop).to_string();
237 format!("({gate} ? {deref_expr} : {alt_expr})")
238 } else {
239 deref_expr
240 };
241 ctx.emit_expr(uop, expr, "val", kernel);
242 Some(())
243 }
244
245 Op::Store { index, value, .. } => {
246 let idx = ctx.get(index).to_string();
247 let val = ctx.get(value).to_string();
248 let indent = ctx.indent();
249 let val_dtype = value.dtype();
250 if val_dtype.vcount() > 1 {
253 let cast_type = c_dtype(&val_dtype);
254 kernel.push(format!("{indent}*(({cast_type}*)({idx})) = {val};"));
255 } else {
256 kernel.push(format!("{indent}*({idx}) = {val};"));
257 }
258 Some(())
259 }
260
261 Op::Binary(op, lhs, rhs) => {
262 let l = ctx.get(lhs).to_string();
263 let r = ctx.get(rhs).to_string();
264 let expr = render_binary(*op, &l, &r, &lhs.dtype());
265 ctx.emit_expr(uop, expr, "alu", kernel);
266 Some(())
267 }
268
269 Op::Unary(op, src) => {
270 let s = ctx.get(src).to_string();
271 let expr = render_unary(*op, &s, &src.dtype());
272 ctx.emit_expr(uop, expr, "alu", kernel);
273 Some(())
274 }
275
276 Op::Ternary(TernaryOp::Where, cond, t, f) => {
277 let c = ctx.get(cond).to_string();
278 let tv = ctx.get(t).to_string();
279 let fv = ctx.get(f).to_string();
280 let expr = format!("({c} ? {tv} : {fv})");
281 ctx.emit_expr(uop, expr, "alu", kernel);
282 Some(())
283 }
284
285 Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
286 let av = ctx.get(a).to_string();
287 let bv = ctx.get(b).to_string();
288 let cv = ctx.get(c).to_string();
289 let expr = if a.dtype().is_float() {
290 format!("{}({av}, {bv}, {cv})", c_math_fn("__builtin_fma", &a.dtype()))
291 } else {
292 format!("(({av} * {bv}) + {cv})")
293 };
294 ctx.emit_expr(uop, expr, "alu", kernel);
295 Some(())
296 }
297
298 Op::Cast { src, dtype } => {
299 let s = ctx.get(src).to_string();
300
301 if matches!(src.op(), Op::Index { .. }) && matches!(dtype, DType::Ptr { .. }) {
303 ctx.register(uop.id, s);
304 return Some(());
305 }
306
307 let expr = if dtype.vcount() > 1 && !matches!(dtype, DType::Ptr { .. }) {
310 format!("__builtin_convertvector({s}, {})", c_dtype(dtype))
311 } else {
312 c_cast(&s, &src.dtype(), dtype)
313 };
314 ctx.emit_expr(uop, expr, "cast", kernel);
315 Some(())
316 }
317
318 Op::BitCast { src, dtype } => {
319 let s = ctx.get(src).to_string();
320 let from_type = c_dtype(&src.dtype());
321 let to_type = c_dtype(dtype);
322 if from_type == to_type {
323 ctx.register(uop.id, s);
324 } else {
325 let expr = format!("__builtin_bit_cast({to_type}, ({from_type})({s}))");
326 ctx.emit_expr(uop, expr, "cast", kernel);
327 }
328 Some(())
329 }
330
331 Op::Reshape { src, .. } => {
332 let s = ctx.get(src).to_string();
333 ctx.register(uop.id, s);
334 Some(())
335 }
336
337 Op::Range { end, axis_id, .. } => {
338 let end_val = ctx.get(end).to_string();
339 let id = axis_id.value();
340 let range_dtype = c_dtype(&uop.dtype());
341 let var_name = format!("ridx{id}");
342 let indent = ctx.indent();
343 kernel.push(format!("{indent}for ({range_dtype} {var_name} = 0; {var_name} < {end_val}; {var_name}++) {{"));
344 ctx.register(uop.id, var_name);
345 ctx.push_indent();
346 Some(())
347 }
348
349 Op::End { ranges, .. } => {
350 for range in ranges.iter() {
351 if let Op::Range { .. } = range.op() {
352 ctx.pop_indent();
353 let indent = ctx.indent();
354 kernel.push(format!("{indent}}}"));
355 }
356 }
357
358 let pending = ctx.take_pending_reduces();
362 for (reduce_id, (acc_name, _dtype)) in pending {
363 ctx.register(reduce_id, acc_name);
366 }
367 Some(())
368 }
369
370 Op::Reduce { src, ranges, reduce_op } => {
371 let src_val = ctx.get(src).to_string();
372 let dtype = &uop.dtype();
373
374 if ranges.is_empty() {
375 ctx.register(uop.id, src_val);
377 } else {
378 let acc_name = ctx.get(uop).to_string();
380 let indent = ctx.indent();
381
382 let acc_expr = render_reduce_accumulate(*reduce_op, &acc_name, &src_val, dtype);
383 kernel.push(format!("{indent}{acc_expr}"));
384
385 ctx.register_reduce_pending(uop.id, acc_name, dtype.clone());
387 }
388 Some(())
389 }
390
391 Op::Gep { vector, indices } => {
392 let vec = ctx.get(vector).to_string();
393 if indices.len() == 1 {
394 let expr = format!("({vec})[{}]", indices[0]);
396 ctx.emit_expr(uop, expr, "gep", kernel);
397 } else {
398 let out_dtype = c_dtype(&uop.dtype());
400 let elements: Vec<String> = indices.iter().map(|&i| format!("({vec})[{i}]")).collect();
401 let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
402 ctx.emit_expr(uop, expr, "gep", kernel);
403 }
404 Some(())
405 }
406
407 Op::Vectorize { elements } => {
408 let vals: Vec<String> = elements.iter().map(|e| ctx.get(e).to_string()).collect();
409 if matches!(uop.dtype(), DType::Ptr { .. }) {
410 ctx.emit_expr(uop, vals[0].clone(), "vec", kernel);
413 } else {
414 let out_dtype = c_dtype(&uop.dtype());
415 let expr = format!("({out_dtype}){{{}}}", vals.join(", "));
416 ctx.emit_expr(uop, expr, "vec", kernel);
417 }
418 Some(())
419 }
420
421 Op::Cat { sources } => {
422 render_cat(uop, sources, ctx, kernel);
423 Some(())
424 }
425
426 Op::PtrCat { .. } => {
427 panic!(
428 "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
429 );
430 }
431
432 Op::Wmma { a, b, c, metadata } => {
433 let a_val = ctx.get(a).to_string();
434 let b_val = ctx.get(b).to_string();
435 let c_val = ctx.get(c).to_string();
436 let expr = format!("__{name}({a_val}, {b_val}, {c_val})", name = metadata.name);
437 ctx.emit_expr(uop, expr, "wmma", kernel);
438 Some(())
439 }
440
441 Op::CustomI { deps, code } => {
442 let args: Vec<String> = deps.iter().map(|dep| ctx.get(dep).to_string()).collect();
443 let expr = match format_custom_template_strict(code, &args) {
444 Ok(s) => s,
445 Err(e) => {
446 ctx.set_invalid_graph(format!("CUSTOMI template error on uop {}: {e}", uop.id));
447 return None;
448 }
449 };
450 ctx.register(uop.id, expr);
452 Some(())
453 }
454
455 Op::Custom { deps, code } => {
456 let args: Vec<String> = deps.iter().map(|dep| ctx.get(dep).to_string()).collect();
457 let rendered = match format_custom_template_strict(code, &args) {
458 Ok(s) => s,
459 Err(e) => {
460 ctx.set_invalid_graph(format!("CUSTOM template error on uop {}: {e}", uop.id));
461 return None;
462 }
463 };
464 let indent = ctx.indent();
465
466 if uop.dtype() == DType::Void {
467 let stmt = if rendered.trim_end().ends_with(';') { rendered } else { format!("{rendered};") };
468 kernel.push(format!("{indent}{stmt}"));
469 ctx.register(uop.id, String::new());
470 } else {
471 let name = ctx.next_name("custom");
472 let dtype = c_dtype(&uop.dtype());
473 if ctx.scope_escaping.contains(&uop.id) {
474 ctx.hoisted_declarations.push(format!(" {dtype} {name};"));
475 kernel.push(format!("{indent}{name} = {rendered};"));
476 } else {
477 kernel.push(format!("{indent}{dtype} {name} = {rendered};"));
478 }
479 ctx.register(uop.id, name);
480 }
481 Some(())
482 }
483
484 Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
485 let s = ctx.get(src).to_string();
486 ctx.register(uop.id, s);
487 None
488 }
489
490 Op::After { passthrough, .. } => {
491 assert!(
492 !matches!(passthrough.op(), Op::Group { .. }),
493 "BUG: AFTER passthrough is GROUP (id={}). AFTER tree:\n{}",
494 passthrough.id,
495 uop.tree()
496 );
497 let s = ctx.get(passthrough).to_string();
498 ctx.register(uop.id, s);
499 None
500 }
501
502 Op::Bind { var, value } => {
503 let v = ctx.get(value).to_string();
504 ctx.register(var.id, v);
505 None
506 }
507
508 Op::If { condition, .. } => {
509 let cond = ctx.get(condition).to_string();
510 let indent = ctx.indent();
511 kernel.push(format!("{indent}if ({cond}) {{"));
512 ctx.push_indent();
513 Some(())
514 }
515
516 Op::EndIf { .. } => {
517 ctx.pop_indent();
518 let indent = ctx.indent();
519 kernel.push(format!("{indent}}}"));
520 Some(())
521 }
522
523 _ => {
524 let indent = ctx.indent();
525 kernel.push(format!("{indent}/* UNSUPPORTED: {:?} */", uop.op().as_ref()));
526 None
527 }
528 }
529}
530
531fn render_binary(op: BinaryOp, l: &str, r: &str, dtype: &DType) -> String {
533 match op {
534 BinaryOp::Add => format!("({l} + {r})"),
535 BinaryOp::Sub => format!("({l} - {r})"),
536 BinaryOp::Mul => format!("({l} * {r})"),
537 BinaryOp::Fdiv => format!("({l} / {r})"),
538 BinaryOp::Idiv => format!("({l} / {r})"),
539 BinaryOp::Mod => {
540 if dtype.is_float() {
541 format!("{}({l}, {r})", c_math_fn("__builtin_fmod", dtype))
542 } else {
543 format!("({l} % {r})")
544 }
545 }
546 BinaryOp::Max => {
547 if dtype.is_float() {
548 format!("{}({l}, {r})", c_math_fn("__builtin_fmax", dtype))
549 } else {
550 format!("({l} > {r} ? {l} : {r})")
551 }
552 }
553 BinaryOp::Lt => format!("({l} < {r})"),
554 BinaryOp::Le => format!("({l} <= {r})"),
555 BinaryOp::Gt => format!("({l} > {r})"),
556 BinaryOp::Ge => format!("({l} >= {r})"),
557 BinaryOp::Eq => format!("({l} == {r})"),
558 BinaryOp::Ne => format!("({l} != {r})"),
559 BinaryOp::And => format!("({l} & {r})"),
560 BinaryOp::Or => format!("({l} | {r})"),
561 BinaryOp::Xor => format!("({l} ^ {r})"),
562 BinaryOp::Shl => format!("({l} << {r})"),
563 BinaryOp::Shr => format!("({l} >> {r})"),
564 BinaryOp::Pow => {
565 if dtype.is_float() {
566 format!("{}({l}, {r})", c_math_fn("__builtin_pow", dtype))
567 } else {
568 format!("(({})__builtin_pow((double){l}, (double){r}))", c_dtype(&DType::Scalar(dtype.base())))
570 }
571 }
572 BinaryOp::Threefry => format!("({l} ^ {r})"),
573 }
574}
575
576fn render_unary(op: UnaryOp, s: &str, dtype: &DType) -> String {
578 match op {
579 UnaryOp::Neg => {
580 format!("(-{s})")
581 }
582 UnaryOp::Not => {
583 if dtype.is_bool() {
584 format!("(!{s})")
585 } else {
586 format!("(~{s})")
587 }
588 }
589 UnaryOp::Abs => {
590 if dtype.is_float() {
591 format!("{}({s})", c_math_fn("__builtin_fabs", dtype))
592 } else {
593 format!("({s} < 0 ? -{s} : {s})")
594 }
595 }
596 UnaryOp::Sqrt => format!("{}({s})", c_math_fn("__builtin_sqrt", dtype)),
597 UnaryOp::Rsqrt => {
598 let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
599 format!("({one} / {}({s}))", c_math_fn("__builtin_sqrt", dtype))
600 }
601 UnaryOp::Reciprocal => {
602 let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
603 format!("({one} / {s})")
604 }
605 UnaryOp::Exp => format!("{}({s})", c_math_fn("__builtin_exp", dtype)),
606 UnaryOp::Exp2 => format!("{}({s})", c_math_fn("__builtin_exp2", dtype)),
607 UnaryOp::Log => format!("{}({s})", c_math_fn("__builtin_log", dtype)),
608 UnaryOp::Log2 => format!("{}({s})", c_math_fn("__builtin_log2", dtype)),
609 UnaryOp::Sin => format!("{}({s})", c_math_fn("__builtin_sin", dtype)),
610 UnaryOp::Cos => format!("{}({s})", c_math_fn("__builtin_cos", dtype)),
611 UnaryOp::Tan => format!("{}({s})", c_math_fn("__builtin_tan", dtype)),
612 UnaryOp::Floor => format!("{}({s})", c_math_fn("__builtin_floor", dtype)),
613 UnaryOp::Ceil => format!("{}({s})", c_math_fn("__builtin_ceil", dtype)),
614 UnaryOp::Trunc => format!("{}({s})", c_math_fn("__builtin_trunc", dtype)),
615 UnaryOp::Round => format!("{}({s})", c_math_fn("__builtin_rint", dtype)),
616 UnaryOp::Erf => format!("{}({s})", c_math_fn("__builtin_erf", dtype)),
617 UnaryOp::Sign => {
618 if dtype.is_float() {
619 let zero = if matches!(dtype.base(), ScalarDType::Float64) { "0.0" } else { "0.0f" };
620 format!("(({s} > {zero}) - ({s} < {zero}))")
621 } else {
622 format!("(({s} > 0) - ({s} < 0))")
623 }
624 }
625 UnaryOp::Square => format!("({s} * {s})"),
626 }
627}
628
629fn render_reduce_accumulate(op: ReduceOp, acc: &str, val: &str, dtype: &DType) -> String {
631 match op {
632 ReduceOp::Add => format!("{acc} += {val};"),
633 ReduceOp::Mul => format!("{acc} *= {val};"),
634 ReduceOp::Max => {
635 if dtype.is_float() {
636 format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmax", dtype))
637 } else {
638 format!("{acc} = ({acc} > {val} ? {acc} : {val});")
639 }
640 }
641 ReduceOp::Min => {
642 if dtype.is_float() {
643 format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmin", dtype))
644 } else {
645 format!("{acc} = ({acc} < {val} ? {acc} : {val});")
646 }
647 }
648 }
649}
650
651fn render_cat(uop: &Arc<UOp>, sources: &[Arc<UOp>], ctx: &mut CContext, kernel: &mut Vec<String>) {
653 let out_dtype = c_dtype(&uop.dtype());
654 let mut elements = Vec::new();
655
656 for src in sources {
657 let src_val = ctx.get(src).to_string();
658 let src_vcount = src.dtype().vcount();
659 if src_vcount == 1 {
660 elements.push(src_val);
661 } else {
662 for i in 0..src_vcount {
663 elements.push(format!("{src_val}[{i}]"));
664 }
665 }
666 }
667
668 let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
669 ctx.emit_expr(uop, expr, "cat", kernel);
670}
671
672pub fn count_references(nodes: &[Arc<UOp>]) -> HashMap<u64, usize> {
675 let mut counts: HashMap<u64, usize> = HashMap::new();
676 for node in nodes {
677 for child in node.op().children() {
678 *counts.entry(child.id).or_insert(0) += 1;
679 }
680 }
681 counts
682}