Skip to main content

svod_codegen/c/
ops.rs

1//! C source code rendering for individual UOp operations.
2//!
3//! Generates C expressions/statements for each Op variant.
4//! Uses SSA inlining: single-use values are inlined as expressions,
5//! multi-use values get local variable declarations.
6
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use svod_dtype::{DType, ScalarDType};
11use svod_ir::{BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
12
13use super::types::{c_cast, c_dtype, c_math_fn};
14use crate::common::format_custom_template_strict;
15
16/// Context for C code generation, tracking variable names and SSA inlining.
17pub struct CContext {
18    /// UOp ID -> C expression or variable name
19    names: HashMap<u64, String>,
20    /// UOp ID -> reference count (how many times used)
21    ref_counts: HashMap<u64, usize>,
22    /// Variable counter for generating unique names
23    counter: usize,
24    /// Current indentation depth
25    depth: usize,
26    /// Pending reduce accumulator info: reduce_id -> (acc_name, dtype)
27    pending_reduces: HashMap<u64, (String, DType)>,
28    /// UOp IDs that escape their declaration scope — need function-scope declaration.
29    scope_escaping: HashSet<u64>,
30    /// Function-scope declarations for hoisted variables (emitted before kernel body).
31    pub hoisted_declarations: Vec<String>,
32    /// Side-channel error set by `render_uop` when it detects a graph invariant
33    /// violation. The render loop drains this after each call and propagates as
34    /// a typed [`crate::Error`].
35    pending_error: Option<crate::Error>,
36}
37
38impl CContext {
39    pub fn new(ref_counts: HashMap<u64, usize>, scope_escaping: HashSet<u64>) -> Self {
40        Self {
41            names: HashMap::new(),
42            ref_counts,
43            counter: 0,
44            depth: 1,
45            pending_reduces: HashMap::new(),
46            scope_escaping,
47            hoisted_declarations: Vec::new(),
48            pending_error: None,
49        }
50    }
51
52    /// Record an `InvalidGraph` error from a renderer op handler.
53    pub fn set_invalid_graph(&mut self, reason: impl Into<String>) {
54        if self.pending_error.is_none() {
55            self.pending_error = Some(crate::Error::InvalidGraph { reason: reason.into() });
56        }
57    }
58
59    /// Drain any error recorded via [`Self::set_invalid_graph`].
60    pub fn take_error(&mut self) -> Option<crate::Error> {
61        self.pending_error.take()
62    }
63
64    /// Get the C expression for a UOp. Panics if not registered.
65    pub fn get(&self, uop: &Arc<UOp>) -> &str {
66        self.names
67            .get(&uop.id)
68            .map(|s| s.as_str())
69            .unwrap_or_else(|| panic!("UOp {} ({}) not in C context", uop.id, uop.op().as_ref()))
70    }
71
72    /// Register a name/expression for a UOp ID.
73    pub fn register(&mut self, id: u64, expr: String) {
74        self.names.insert(id, expr);
75    }
76
77    /// Check if a value should be inlined (single-use, expression-safe).
78    pub fn should_inline(&self, id: u64) -> bool {
79        self.ref_counts.get(&id).copied().unwrap_or(0) <= 1
80    }
81
82    /// Generate a unique variable name with given prefix.
83    pub fn next_name(&mut self, prefix: &str) -> String {
84        let name = format!("{}{}", prefix, self.counter);
85        self.counter += 1;
86        name
87    }
88
89    /// Get current indentation string.
90    pub fn indent(&self) -> String {
91        "  ".repeat(self.depth)
92    }
93
94    /// Increase indentation depth.
95    pub fn push_indent(&mut self) {
96        self.depth += 1;
97    }
98
99    /// Decrease indentation depth.
100    pub fn pop_indent(&mut self) {
101        self.depth = self.depth.saturating_sub(1);
102    }
103
104    /// Register a pending reduce final load.
105    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_name: String, dtype: DType) {
106        self.pending_reduces.insert(reduce_id, (acc_name, dtype));
107    }
108
109    /// Take all pending reduces.
110    pub fn take_pending_reduces(&mut self) -> HashMap<u64, (String, DType)> {
111        std::mem::take(&mut self.pending_reduces)
112    }
113
114    /// Emit a C expression, either as an inline expression or a variable declaration.
115    /// Returns the name/expression to reference this value.
116    ///
117    /// Variables that escape their declaration scope are hoisted: declared at function
118    /// scope and assigned at current depth. This prevents "use of undeclared identifier"
119    /// errors when the linearizer places a shared node inside a loop but consumers exist
120    /// outside the loop.
121    pub fn emit_expr(&mut self, uop: &Arc<UOp>, expr: String, prefix: &str, kernel: &mut Vec<String>) -> String {
122        if self.should_inline(uop.id) {
123            self.register(uop.id, expr.clone());
124            expr
125        } else {
126            let name = self.next_name(prefix);
127            let dtype = c_dtype(&uop.dtype());
128            let indent = self.indent();
129            if self.scope_escaping.contains(&uop.id) {
130                // Hoist: declare at function scope, assign at current depth
131                self.hoisted_declarations.push(format!("  {dtype} {name};"));
132                kernel.push(format!("{indent}{name} = {expr};"));
133            } else {
134                kernel.push(format!("{indent}{dtype} {name} = {expr};"));
135            }
136            self.register(uop.id, name.clone());
137            name
138        }
139    }
140}
141
142/// Render a single UOp to C source code.
143///
144/// Returns `Some(())` if code was emitted, `None` for meta-ops.
145pub fn render_uop(uop: &Arc<UOp>, ctx: &mut CContext, kernel: &mut Vec<String>) -> Option<()> {
146    match uop.op() {
147        // Meta-ops: no code emitted
148        Op::Const(_)
149        | Op::VConst { .. }
150        | Op::Param { device: None, .. }
151        | Op::DefineLocal(_)
152        | Op::DefineVar { .. }
153        | Op::Noop
154        | Op::Sink { .. }
155        | Op::Group { .. }
156        | Op::Buffer { .. }
157        | Op::Unique(_)
158        | Op::Device(_)
159        | Op::Call { .. }
160        | Op::Barrier { .. } => None,
161
162        Op::DefineReg { .. } => {
163            // Read base type and size from dtype (matching Tinygrad's x.dtype.base/x.dtype.size).
164            // After devectorize's no_vectorized_buf, the dtype is the canonical source of truth:
165            // e.g. Ptr(base=Float32, size=35) instead of the Op's original size field.
166            let (base_dtype, alloc_size) = match uop.dtype() {
167                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
168                other => (other, 1),
169            };
170            let name = ctx.next_name("reg");
171            let indent = ctx.indent();
172            kernel.push(format!("{indent}{} {name}[{alloc_size}];", c_dtype(&base_dtype)));
173            ctx.register(uop.id, name);
174            Some(())
175        }
176
177        Op::Index { buffer, indices, .. } => {
178            let buf = ctx.get(buffer).to_string();
179
180            if indices.is_empty() {
181                // No index - just alias the buffer pointer
182                ctx.register(uop.id, buf);
183            } else {
184                let idx = if indices.len() == 1 {
185                    ctx.get(&indices[0]).to_string()
186                } else {
187                    ctx.set_invalid_graph(format!(
188                        "C renderer requires linearized INDEX (single-axis), found {} indices on uop {}",
189                        indices.len(),
190                        uop.id
191                    ));
192                    return None;
193                };
194                let expr = format!("{buf} + {idx}");
195                ctx.emit_expr(uop, expr, "idx", kernel);
196            }
197            Some(())
198        }
199
200        Op::PointerIndex { ptr, offset } => {
201            let ptr_val = ctx.get(ptr).to_string();
202            let off_val = ctx.get(offset).to_string();
203            let expr = format!("{ptr_val} + {off_val}");
204            ctx.emit_expr(uop, expr, "pidx", kernel);
205            Some(())
206        }
207
208        Op::Load { index, alt, .. } => {
209            let idx = ctx.get(index).to_string();
210            let load_dtype = uop.dtype();
211            // Gated LOAD follows Tinygrad semantics: conditional load with explicit alt value.
212            // The gate is carried by INDEX, possibly behind one CAST wrapper.
213            let actual_index = match index.op() {
214                Op::Cast { src, .. } => src,
215                _ => index,
216            };
217            let gate_expr = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
218                Some(ctx.get(gate_uop).to_string())
219            } else {
220                None
221            };
222            let deref_expr = if load_dtype.vcount() > 1 {
223                let cast_type = c_dtype(&load_dtype);
224                format!("*(({cast_type}*)({idx}))")
225            } else {
226                format!("*({idx})")
227            };
228            let expr = if let Some(gate) = gate_expr {
229                let Some(alt_uop) = alt.as_ref() else {
230                    ctx.set_invalid_graph(format!(
231                        "gated LOAD on uop {} has no alt value; line_rewrite_cleanups must lift gated LOADs",
232                        uop.id
233                    ));
234                    return None;
235                };
236                let alt_expr = ctx.get(alt_uop).to_string();
237                format!("({gate} ? {deref_expr} : {alt_expr})")
238            } else {
239                deref_expr
240            };
241            ctx.emit_expr(uop, expr, "val", kernel);
242            Some(())
243        }
244
245        Op::Store { index, value, .. } => {
246            let idx = ctx.get(index).to_string();
247            let val = ctx.get(value).to_string();
248            let indent = ctx.indent();
249            let val_dtype = value.dtype();
250            // Buffer pointers are declared as scalar types (e.g., float*) in C,
251            // so vector stores need an explicit pointer cast.
252            if val_dtype.vcount() > 1 {
253                let cast_type = c_dtype(&val_dtype);
254                kernel.push(format!("{indent}*(({cast_type}*)({idx})) = {val};"));
255            } else {
256                kernel.push(format!("{indent}*({idx}) = {val};"));
257            }
258            Some(())
259        }
260
261        Op::Binary(op, lhs, rhs) => {
262            let l = ctx.get(lhs).to_string();
263            let r = ctx.get(rhs).to_string();
264            let expr = render_binary(*op, &l, &r, &lhs.dtype());
265            ctx.emit_expr(uop, expr, "alu", kernel);
266            Some(())
267        }
268
269        Op::Unary(op, src) => {
270            let s = ctx.get(src).to_string();
271            let expr = render_unary(*op, &s, &src.dtype());
272            ctx.emit_expr(uop, expr, "alu", kernel);
273            Some(())
274        }
275
276        Op::Ternary(TernaryOp::Where, cond, t, f) => {
277            let c = ctx.get(cond).to_string();
278            let tv = ctx.get(t).to_string();
279            let fv = ctx.get(f).to_string();
280            let expr = format!("({c} ? {tv} : {fv})");
281            ctx.emit_expr(uop, expr, "alu", kernel);
282            Some(())
283        }
284
285        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
286            let av = ctx.get(a).to_string();
287            let bv = ctx.get(b).to_string();
288            let cv = ctx.get(c).to_string();
289            let expr = if a.dtype().is_float() {
290                format!("{}({av}, {bv}, {cv})", c_math_fn("__builtin_fma", &a.dtype()))
291            } else {
292                format!("(({av} * {bv}) + {cv})")
293            };
294            ctx.emit_expr(uop, expr, "alu", kernel);
295            Some(())
296        }
297
298        Op::Cast { src, dtype } => {
299            let s = ctx.get(src).to_string();
300
301            // INDEX to Ptr is a no-op in C (INDEX already produces a pointer)
302            if matches!(src.op(), Op::Index { .. }) && matches!(dtype, DType::Ptr { .. }) {
303                ctx.register(uop.id, s);
304                return Some(());
305            }
306
307            // Vector casts use __builtin_convertvector for element-wise conversion
308            // (a plain C cast would reinterpret bits, not convert values)
309            let expr = if dtype.vcount() > 1 && !matches!(dtype, DType::Ptr { .. }) {
310                format!("__builtin_convertvector({s}, {})", c_dtype(dtype))
311            } else {
312                c_cast(&s, &src.dtype(), dtype)
313            };
314            ctx.emit_expr(uop, expr, "cast", kernel);
315            Some(())
316        }
317
318        Op::BitCast { src, dtype } => {
319            let s = ctx.get(src).to_string();
320            let from_type = c_dtype(&src.dtype());
321            let to_type = c_dtype(dtype);
322            if from_type == to_type {
323                ctx.register(uop.id, s);
324            } else {
325                let expr = format!("__builtin_bit_cast({to_type}, ({from_type})({s}))");
326                ctx.emit_expr(uop, expr, "cast", kernel);
327            }
328            Some(())
329        }
330
331        Op::Reshape { src, .. } => {
332            let s = ctx.get(src).to_string();
333            ctx.register(uop.id, s);
334            Some(())
335        }
336
337        Op::Range { end, axis_id, .. } => {
338            let end_val = ctx.get(end).to_string();
339            let id = axis_id.value();
340            let range_dtype = c_dtype(&uop.dtype());
341            let var_name = format!("ridx{id}");
342            let indent = ctx.indent();
343            kernel.push(format!("{indent}for ({range_dtype} {var_name} = 0; {var_name} < {end_val}; {var_name}++) {{"));
344            ctx.register(uop.id, var_name);
345            ctx.push_indent();
346            Some(())
347        }
348
349        Op::End { ranges, .. } => {
350            for range in ranges.iter() {
351                if let Op::Range { .. } = range.op() {
352                    ctx.pop_indent();
353                    let indent = ctx.indent();
354                    kernel.push(format!("{indent}}}"));
355                }
356            }
357
358            // After closing loops, resolve pending reduces.
359            // In C, the accumulator variable already holds the final value
360            // (unlike LLVM where we need to load from alloca).
361            let pending = ctx.take_pending_reduces();
362            for (reduce_id, (acc_name, _dtype)) in pending {
363                // Re-register the reduce with the accumulator name
364                // so downstream users reference the accumulated value.
365                ctx.register(reduce_id, acc_name);
366            }
367            Some(())
368        }
369
370        Op::Reduce { src, ranges, reduce_op } => {
371            let src_val = ctx.get(src).to_string();
372            let dtype = &uop.dtype();
373
374            if ranges.is_empty() {
375                // Passthrough reduce
376                ctx.register(uop.id, src_val);
377            } else {
378                // Accumulator was pre-declared in mod.rs with name acc{uop.id}
379                let acc_name = ctx.get(uop).to_string();
380                let indent = ctx.indent();
381
382                let acc_expr = render_reduce_accumulate(*reduce_op, &acc_name, &src_val, dtype);
383                kernel.push(format!("{indent}{acc_expr}"));
384
385                // Register pending for End to emit the final value
386                ctx.register_reduce_pending(uop.id, acc_name, dtype.clone());
387            }
388            Some(())
389        }
390
391        Op::Gep { vector, indices } => {
392            let vec = ctx.get(vector).to_string();
393            if indices.len() == 1 {
394                // Parenthesize to handle precedence: *((float4*)ptr)[i] → (*((float4*)ptr))[i]
395                let expr = format!("({vec})[{}]", indices[0]);
396                ctx.emit_expr(uop, expr, "gep", kernel);
397            } else {
398                // Multi-element GEP: build a new vector from extracted elements
399                let out_dtype = c_dtype(&uop.dtype());
400                let elements: Vec<String> = indices.iter().map(|&i| format!("({vec})[{i}]")).collect();
401                let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
402                ctx.emit_expr(uop, expr, "gep", kernel);
403            }
404            Some(())
405        }
406
407        Op::Vectorize { elements } => {
408            let vals: Vec<String> = elements.iter().map(|e| ctx.get(e).to_string()).collect();
409            if matches!(uop.dtype(), DType::Ptr { .. }) {
410                // Ptr types can't be vectorized in C (no compound literal for pointers).
411                // All elements should be the same scalar pointer — use the first one.
412                ctx.emit_expr(uop, vals[0].clone(), "vec", kernel);
413            } else {
414                let out_dtype = c_dtype(&uop.dtype());
415                let expr = format!("({out_dtype}){{{}}}", vals.join(", "));
416                ctx.emit_expr(uop, expr, "vec", kernel);
417            }
418            Some(())
419        }
420
421        Op::Cat { sources } => {
422            render_cat(uop, sources, ctx, kernel);
423            Some(())
424        }
425
426        Op::PtrCat { .. } => {
427            panic!(
428                "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
429            );
430        }
431
432        Op::Wmma { a, b, c, metadata } => {
433            let a_val = ctx.get(a).to_string();
434            let b_val = ctx.get(b).to_string();
435            let c_val = ctx.get(c).to_string();
436            let expr = format!("__{name}({a_val}, {b_val}, {c_val})", name = metadata.name);
437            ctx.emit_expr(uop, expr, "wmma", kernel);
438            Some(())
439        }
440
441        Op::CustomI { deps, code } => {
442            let args: Vec<String> = deps.iter().map(|dep| ctx.get(dep).to_string()).collect();
443            let expr = match format_custom_template_strict(code, &args) {
444                Ok(s) => s,
445                Err(e) => {
446                    ctx.set_invalid_graph(format!("CUSTOMI template error on uop {}: {e}", uop.id));
447                    return None;
448                }
449            };
450            // CUSTOMI is always inline in Tinygrad's cstyle renderer.
451            ctx.register(uop.id, expr);
452            Some(())
453        }
454
455        Op::Custom { deps, code } => {
456            let args: Vec<String> = deps.iter().map(|dep| ctx.get(dep).to_string()).collect();
457            let rendered = match format_custom_template_strict(code, &args) {
458                Ok(s) => s,
459                Err(e) => {
460                    ctx.set_invalid_graph(format!("CUSTOM template error on uop {}: {e}", uop.id));
461                    return None;
462                }
463            };
464            let indent = ctx.indent();
465
466            if uop.dtype() == DType::Void {
467                let stmt = if rendered.trim_end().ends_with(';') { rendered } else { format!("{rendered};") };
468                kernel.push(format!("{indent}{stmt}"));
469                ctx.register(uop.id, String::new());
470            } else {
471                let name = ctx.next_name("custom");
472                let dtype = c_dtype(&uop.dtype());
473                if ctx.scope_escaping.contains(&uop.id) {
474                    ctx.hoisted_declarations.push(format!("  {dtype} {name};"));
475                    kernel.push(format!("{indent}{name} = {rendered};"));
476                } else {
477                    kernel.push(format!("{indent}{dtype} {name} = {rendered};"));
478                }
479                ctx.register(uop.id, name);
480            }
481            Some(())
482        }
483
484        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
485            let s = ctx.get(src).to_string();
486            ctx.register(uop.id, s);
487            None
488        }
489
490        Op::After { passthrough, .. } => {
491            assert!(
492                !matches!(passthrough.op(), Op::Group { .. }),
493                "BUG: AFTER passthrough is GROUP (id={}). AFTER tree:\n{}",
494                passthrough.id,
495                uop.tree()
496            );
497            let s = ctx.get(passthrough).to_string();
498            ctx.register(uop.id, s);
499            None
500        }
501
502        Op::Bind { var, value } => {
503            let v = ctx.get(value).to_string();
504            ctx.register(var.id, v);
505            None
506        }
507
508        Op::If { condition, .. } => {
509            let cond = ctx.get(condition).to_string();
510            let indent = ctx.indent();
511            kernel.push(format!("{indent}if ({cond}) {{"));
512            ctx.push_indent();
513            Some(())
514        }
515
516        Op::EndIf { .. } => {
517            ctx.pop_indent();
518            let indent = ctx.indent();
519            kernel.push(format!("{indent}}}"));
520            Some(())
521        }
522
523        _ => {
524            let indent = ctx.indent();
525            kernel.push(format!("{indent}/* UNSUPPORTED: {:?} */", uop.op().as_ref()));
526            None
527        }
528    }
529}
530
531/// Render a binary operation as a C expression.
532fn render_binary(op: BinaryOp, l: &str, r: &str, dtype: &DType) -> String {
533    match op {
534        BinaryOp::Add => format!("({l} + {r})"),
535        BinaryOp::Sub => format!("({l} - {r})"),
536        BinaryOp::Mul => format!("({l} * {r})"),
537        BinaryOp::Fdiv => format!("({l} / {r})"),
538        BinaryOp::Idiv => format!("({l} / {r})"),
539        BinaryOp::Mod => {
540            if dtype.is_float() {
541                format!("{}({l}, {r})", c_math_fn("__builtin_fmod", dtype))
542            } else {
543                format!("({l} % {r})")
544            }
545        }
546        BinaryOp::Max => {
547            if dtype.is_float() {
548                format!("{}({l}, {r})", c_math_fn("__builtin_fmax", dtype))
549            } else {
550                format!("({l} > {r} ? {l} : {r})")
551            }
552        }
553        BinaryOp::Lt => format!("({l} < {r})"),
554        BinaryOp::Le => format!("({l} <= {r})"),
555        BinaryOp::Gt => format!("({l} > {r})"),
556        BinaryOp::Ge => format!("({l} >= {r})"),
557        BinaryOp::Eq => format!("({l} == {r})"),
558        BinaryOp::Ne => format!("({l} != {r})"),
559        BinaryOp::And => format!("({l} & {r})"),
560        BinaryOp::Or => format!("({l} | {r})"),
561        BinaryOp::Xor => format!("({l} ^ {r})"),
562        BinaryOp::Shl => format!("({l} << {r})"),
563        BinaryOp::Shr => format!("({l} >> {r})"),
564        BinaryOp::Pow => {
565            if dtype.is_float() {
566                format!("{}({l}, {r})", c_math_fn("__builtin_pow", dtype))
567            } else {
568                // Integer pow via cast to double
569                format!("(({})__builtin_pow((double){l}, (double){r}))", c_dtype(&DType::Scalar(dtype.base())))
570            }
571        }
572        BinaryOp::Threefry => format!("({l} ^ {r})"),
573    }
574}
575
576/// Render a unary operation as a C expression.
577fn render_unary(op: UnaryOp, s: &str, dtype: &DType) -> String {
578    match op {
579        UnaryOp::Neg => {
580            format!("(-{s})")
581        }
582        UnaryOp::Not => {
583            if dtype.is_bool() {
584                format!("(!{s})")
585            } else {
586                format!("(~{s})")
587            }
588        }
589        UnaryOp::Abs => {
590            if dtype.is_float() {
591                format!("{}({s})", c_math_fn("__builtin_fabs", dtype))
592            } else {
593                format!("({s} < 0 ? -{s} : {s})")
594            }
595        }
596        UnaryOp::Sqrt => format!("{}({s})", c_math_fn("__builtin_sqrt", dtype)),
597        UnaryOp::Rsqrt => {
598            let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
599            format!("({one} / {}({s}))", c_math_fn("__builtin_sqrt", dtype))
600        }
601        UnaryOp::Reciprocal => {
602            let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
603            format!("({one} / {s})")
604        }
605        UnaryOp::Exp => format!("{}({s})", c_math_fn("__builtin_exp", dtype)),
606        UnaryOp::Exp2 => format!("{}({s})", c_math_fn("__builtin_exp2", dtype)),
607        UnaryOp::Log => format!("{}({s})", c_math_fn("__builtin_log", dtype)),
608        UnaryOp::Log2 => format!("{}({s})", c_math_fn("__builtin_log2", dtype)),
609        UnaryOp::Sin => format!("{}({s})", c_math_fn("__builtin_sin", dtype)),
610        UnaryOp::Cos => format!("{}({s})", c_math_fn("__builtin_cos", dtype)),
611        UnaryOp::Tan => format!("{}({s})", c_math_fn("__builtin_tan", dtype)),
612        UnaryOp::Floor => format!("{}({s})", c_math_fn("__builtin_floor", dtype)),
613        UnaryOp::Ceil => format!("{}({s})", c_math_fn("__builtin_ceil", dtype)),
614        UnaryOp::Trunc => format!("{}({s})", c_math_fn("__builtin_trunc", dtype)),
615        UnaryOp::Round => format!("{}({s})", c_math_fn("__builtin_rint", dtype)),
616        UnaryOp::Erf => format!("{}({s})", c_math_fn("__builtin_erf", dtype)),
617        UnaryOp::Sign => {
618            if dtype.is_float() {
619                let zero = if matches!(dtype.base(), ScalarDType::Float64) { "0.0" } else { "0.0f" };
620                format!("(({s} > {zero}) - ({s} < {zero}))")
621            } else {
622                format!("(({s} > 0) - ({s} < 0))")
623            }
624        }
625        UnaryOp::Square => format!("({s} * {s})"),
626    }
627}
628
629/// Render a reduce accumulation statement.
630fn render_reduce_accumulate(op: ReduceOp, acc: &str, val: &str, dtype: &DType) -> String {
631    match op {
632        ReduceOp::Add => format!("{acc} += {val};"),
633        ReduceOp::Mul => format!("{acc} *= {val};"),
634        ReduceOp::Max => {
635            if dtype.is_float() {
636                format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmax", dtype))
637            } else {
638                format!("{acc} = ({acc} > {val} ? {acc} : {val});")
639            }
640        }
641        ReduceOp::Min => {
642            if dtype.is_float() {
643                format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmin", dtype))
644            } else {
645                format!("{acc} = ({acc} < {val} ? {acc} : {val});")
646            }
647        }
648    }
649}
650
651/// Render a Cat operation (concatenate vectors).
652fn render_cat(uop: &Arc<UOp>, sources: &[Arc<UOp>], ctx: &mut CContext, kernel: &mut Vec<String>) {
653    let out_dtype = c_dtype(&uop.dtype());
654    let mut elements = Vec::new();
655
656    for src in sources {
657        let src_val = ctx.get(src).to_string();
658        let src_vcount = src.dtype().vcount();
659        if src_vcount == 1 {
660            elements.push(src_val);
661        } else {
662            for i in 0..src_vcount {
663                elements.push(format!("{src_val}[{i}]"));
664            }
665        }
666    }
667
668    let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
669    ctx.emit_expr(uop, expr, "cat", kernel);
670}
671
672/// Count references for each UOp ID in the linearized stream.
673/// Used to determine which values should be inlined vs declared.
674pub fn count_references(nodes: &[Arc<UOp>]) -> HashMap<u64, usize> {
675    let mut counts: HashMap<u64, usize> = HashMap::new();
676    for node in nodes {
677        for child in node.op().children() {
678            *counts.entry(child.id).or_insert(0) += 1;
679        }
680    }
681    counts
682}