Skip to main content

svod_codegen/llvm/common/
ctx.rs

1//! Render context for LLVM IR text generation.
2//!
3//! Maps UOp IDs to LLVM variable names and manages naming.
4//! Shared between CPU and GPU backends.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use svod_ir::{ConstValue, Op, prelude::*};
10
11use super::types::{lconst, ldt};
12
/// Deferred final load of a reduce accumulator.
///
/// Recorded via `RenderContext::register_reduce_pending` and drained with
/// `RenderContext::take_pending_reduces`, keyed by the reduce UOp ID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingReduce {
    /// LLVM name of the accumulator pointer to load the reduced value from.
    pub acc_ptr: String,
    /// Type text for the final load (presumably a rendered LLVM type, e.g.
    /// from `ldt` — TODO confirm at the call sites).
    pub dtype: String,
}
18
19/// Maps UOp ID → LLVM variable name.
20pub struct RenderContext {
21    names: HashMap<u64, String>,
22    range_values: HashMap<usize, String>,
23    counter: usize,
24    /// Pending reduce final loads: reduce_id -> (acc_ptr, dtype)
25    pending_reduces: HashMap<u64, PendingReduce>,
26    /// Stack of currently open RANGE axis_ids (for correct END footer ordering).
27    /// Pushed on RANGE emission, popped on END emission.
28    range_stack: Vec<usize>,
29    /// Side-channel error set by `render_uop` when it detects a graph invariant
30    /// violation. The render loop drains this after each call and propagates as
31    /// a typed [`crate::Error`].
32    pending_error: Option<crate::Error>,
33}
34
35impl RenderContext {
36    pub fn new() -> Self {
37        Self {
38            names: HashMap::new(),
39            range_values: HashMap::new(),
40            counter: 0,
41            pending_reduces: HashMap::new(),
42            range_stack: Vec::new(),
43            pending_error: None,
44        }
45    }
46
47    /// Record an `InvalidGraph` error from a renderer op handler.
48    pub fn set_invalid_graph(&mut self, reason: impl Into<String>) {
49        if self.pending_error.is_none() {
50            self.pending_error = Some(crate::Error::InvalidGraph { reason: reason.into() });
51        }
52    }
53
54    /// Drain any error recorded via [`Self::set_invalid_graph`].
55    pub fn take_error(&mut self) -> Option<crate::Error> {
56        self.pending_error.take()
57    }
58
59    /// Get or create variable name for UOp.
60    ///
61    /// For constants, returns literal value.
62    /// For definitions, returns argument name.
63    /// For other ops, returns a generated variable name.
64    pub fn name(&mut self, uop: &Arc<UOp>) -> String {
65        if let Some(name) = self.names.get(&uop.id) {
66            return name.clone();
67        }
68
69        let name = match uop.op() {
70            Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
71            Op::VConst { values } => self.render_vconst(values, uop),
72            Op::Param { slot, device: None, .. } => format!("%data{slot}"),
73            Op::DefineLocal(id) => format!("%local{id}"),
74            Op::DefineVar { name, .. } => format!("%{name}"),
75            Op::DefineReg { .. } => {
76                let n = format!("%reg{}", self.counter);
77                self.counter += 1;
78                n
79            }
80            Op::Range { axis_id, .. } => {
81                // Range variables are named by axis_id
82                format!("%r{}", axis_id.value())
83            }
84            _ => {
85                let n = format!("%v{}", self.counter);
86                self.counter += 1;
87                n
88            }
89        };
90
91        self.names.insert(uop.id, name.clone());
92        name
93    }
94
95    /// Render a vector constant.
96    fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
97        let scalar_type = ldt(&uop.dtype().scalar_dtype());
98
99        // Format as LLVM vector constant: <type val, type val, ...>
100        let elements: Vec<String> = values
101            .iter()
102            .map(|v| {
103                let val = lconst(v, &uop.dtype());
104                format!("{scalar_type} {val}")
105            })
106            .collect();
107
108        format!("<{}>", elements.join(", "))
109    }
110
111    /// Get existing name (panics if not found).
112    pub fn get(&self, uop: &Arc<UOp>) -> &str {
113        self.names
114            .get(&uop.id)
115            .map(|s| s.as_str())
116            .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
117    }
118
119    /// Try to get existing name.
120    pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
121        self.names.get(&uop.id).map(|s| s.as_str())
122    }
123
124    /// Check if a UOp is already registered.
125    pub fn contains(&self, id: u64) -> bool {
126        self.names.contains_key(&id)
127    }
128
129    /// Alias one ID to another's name.
130    pub fn alias(&mut self, id: u64, name: String) {
131        self.names.insert(id, name);
132    }
133
134    /// Pre-register a name for a UOp ID.
135    pub fn register(&mut self, id: u64, name: String) {
136        self.names.insert(id, name);
137    }
138
139    /// Get current variable counter.
140    pub fn counter(&self) -> usize {
141        self.counter
142    }
143
144    /// Register a range value by axis_id.
145    pub fn register_range(&mut self, axis_id: usize, name: String) {
146        self.range_values.insert(axis_id, name);
147    }
148
149    /// Get a range value by axis_id.
150    pub fn get_range(&self, axis_id: usize) -> Option<&str> {
151        self.range_values.get(&axis_id).map(|s| s.as_str())
152    }
153
154    /// Push a range axis_id onto the open-range stack (called during RANGE codegen).
155    pub fn push_range(&mut self, axis_id: usize) {
156        self.range_stack.push(axis_id);
157    }
158
159    /// Pop the innermost open range axis_id (called during END codegen).
160    pub fn pop_range(&mut self) -> Option<usize> {
161        self.range_stack.pop()
162    }
163
164    /// Register a pending reduce final load.
165    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
166        self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
167    }
168
169    /// Take all pending reduces (empties map).
170    pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
171        std::mem::take(&mut self.pending_reduces)
172    }
173
174    /// Check if there are pending reduces.
175    pub fn has_pending_reduces(&self) -> bool {
176        !self.pending_reduces.is_empty()
177    }
178}
179
180impl Default for RenderContext {
181    fn default() -> Self {
182        Self::new()
183    }
184}