svod_codegen/llvm/common/
ctx.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use svod_ir::{ConstValue, Op, prelude::*};
10
11use super::types::{lconst, ldt};
12
/// Bookkeeping for a reduce registered with `RenderContext::register_reduce_pending`.
///
/// Entries are stored in the render context keyed by the reduce op's id and are
/// handed back in bulk by `RenderContext::take_pending_reduces`.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so the map returned by
/// `take_pending_reduces` can be inspected, logged and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingReduce {
    /// Name of the accumulator pointer value, as supplied by the caller
    /// (presumably an LLVM value name such as `%acc0` — verify against callers).
    pub acc_ptr: String,
    /// Type string of the accumulator, as supplied by the caller
    /// (presumably an LLVM type string — verify against callers).
    pub dtype: String,
}
18
19pub struct RenderContext {
21 names: HashMap<u64, String>,
22 range_values: HashMap<usize, String>,
23 counter: usize,
24 pending_reduces: HashMap<u64, PendingReduce>,
26 range_stack: Vec<usize>,
29 pending_error: Option<crate::Error>,
33}
34
35impl RenderContext {
36 pub fn new() -> Self {
37 Self {
38 names: HashMap::new(),
39 range_values: HashMap::new(),
40 counter: 0,
41 pending_reduces: HashMap::new(),
42 range_stack: Vec::new(),
43 pending_error: None,
44 }
45 }
46
47 pub fn set_invalid_graph(&mut self, reason: impl Into<String>) {
49 if self.pending_error.is_none() {
50 self.pending_error = Some(crate::Error::InvalidGraph { reason: reason.into() });
51 }
52 }
53
54 pub fn take_error(&mut self) -> Option<crate::Error> {
56 self.pending_error.take()
57 }
58
59 pub fn name(&mut self, uop: &Arc<UOp>) -> String {
65 if let Some(name) = self.names.get(&uop.id) {
66 return name.clone();
67 }
68
69 let name = match uop.op() {
70 Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
71 Op::VConst { values } => self.render_vconst(values, uop),
72 Op::Param { slot, device: None, .. } => format!("%data{slot}"),
73 Op::DefineLocal(id) => format!("%local{id}"),
74 Op::DefineVar { name, .. } => format!("%{name}"),
75 Op::DefineReg { .. } => {
76 let n = format!("%reg{}", self.counter);
77 self.counter += 1;
78 n
79 }
80 Op::Range { axis_id, .. } => {
81 format!("%r{}", axis_id.value())
83 }
84 _ => {
85 let n = format!("%v{}", self.counter);
86 self.counter += 1;
87 n
88 }
89 };
90
91 self.names.insert(uop.id, name.clone());
92 name
93 }
94
95 fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
97 let scalar_type = ldt(&uop.dtype().scalar_dtype());
98
99 let elements: Vec<String> = values
101 .iter()
102 .map(|v| {
103 let val = lconst(v, &uop.dtype());
104 format!("{scalar_type} {val}")
105 })
106 .collect();
107
108 format!("<{}>", elements.join(", "))
109 }
110
111 pub fn get(&self, uop: &Arc<UOp>) -> &str {
113 self.names
114 .get(&uop.id)
115 .map(|s| s.as_str())
116 .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
117 }
118
119 pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
121 self.names.get(&uop.id).map(|s| s.as_str())
122 }
123
124 pub fn contains(&self, id: u64) -> bool {
126 self.names.contains_key(&id)
127 }
128
129 pub fn alias(&mut self, id: u64, name: String) {
131 self.names.insert(id, name);
132 }
133
134 pub fn register(&mut self, id: u64, name: String) {
136 self.names.insert(id, name);
137 }
138
139 pub fn counter(&self) -> usize {
141 self.counter
142 }
143
144 pub fn register_range(&mut self, axis_id: usize, name: String) {
146 self.range_values.insert(axis_id, name);
147 }
148
149 pub fn get_range(&self, axis_id: usize) -> Option<&str> {
151 self.range_values.get(&axis_id).map(|s| s.as_str())
152 }
153
154 pub fn push_range(&mut self, axis_id: usize) {
156 self.range_stack.push(axis_id);
157 }
158
159 pub fn pop_range(&mut self) -> Option<usize> {
161 self.range_stack.pop()
162 }
163
164 pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
166 self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
167 }
168
169 pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
171 std::mem::take(&mut self.pending_reduces)
172 }
173
174 pub fn has_pending_reduces(&self) -> bool {
176 !self.pending_reduces.is_empty()
177 }
178}
179
180impl Default for RenderContext {
181 fn default() -> Self {
182 Self::new()
183 }
184}