Skip to main content

snc_core/
codegen.rs

1use crate::ir::*;
2use std::fmt::Write;
3
4/// Generate complete Rust source from the IR.
5pub fn generate(ir: &SeqIR) -> String {
6    let mut out = String::with_capacity(4096);
7    gen_header(&mut out, ir);
8    gen_constants(&mut out, ir);
9    gen_vars_struct(&mut out, ir);
10    gen_program_vars_impl(&mut out, ir);
11    gen_meta_impl(&mut out, ir);
12    gen_state_sets(&mut out, ir);
13    gen_main(&mut out, ir);
14    out
15}
16
17fn gen_header(out: &mut String, _ir: &SeqIR) {
18    writeln!(out, "//! Generated by snc — do not edit.").unwrap();
19    writeln!(out).unwrap();
20    writeln!(out, "use epics_seq::prelude::*;").unwrap();
21    writeln!(out).unwrap();
22}
23
24fn gen_constants(out: &mut String, ir: &SeqIR) {
25    // Channel ID constants
26    for ch in &ir.channels {
27        let name = ch.var_name.to_uppercase();
28        writeln!(out, "const CH_{name}: usize = {};", ch.id).unwrap();
29    }
30    if !ir.channels.is_empty() {
31        writeln!(out).unwrap();
32    }
33
34    // Event flag ID constants
35    for ef in &ir.event_flags {
36        let name = ef.name.to_uppercase();
37        writeln!(out, "const EF_{name}: usize = {};", ef.id).unwrap();
38    }
39    if !ir.event_flags.is_empty() {
40        writeln!(out).unwrap();
41    }
42
43    // State ID constants per state set
44    for ss in &ir.state_sets {
45        let ss_upper = ss.name.to_uppercase();
46        for state in &ss.states {
47            let st_upper = state.name.to_uppercase();
48            writeln!(out, "const {ss_upper}_{st_upper}: usize = {};", state.id).unwrap();
49        }
50    }
51    if !ir.state_sets.is_empty() {
52        writeln!(out).unwrap();
53    }
54}
55
56fn gen_vars_struct(out: &mut String, ir: &SeqIR) {
57    let name = &ir.program_name;
58    let struct_name = format!("{name}Vars");
59
60    writeln!(out, "#[derive(Clone)]").unwrap();
61    writeln!(out, "struct {struct_name} {{").unwrap();
62    for var in &ir.variables {
63        let rust_type = var.var_type.rust_type();
64        writeln!(out, "    {}: {rust_type},", var.name).unwrap();
65    }
66    writeln!(out, "}}").unwrap();
67    writeln!(out).unwrap();
68}
69
70fn gen_program_vars_impl(out: &mut String, ir: &SeqIR) {
71    let struct_name = format!("{}Vars", ir.program_name);
72
73    writeln!(out, "impl ProgramVars for {struct_name} {{").unwrap();
74
75    // get_channel_value
76    writeln!(
77        out,
78        "    fn get_channel_value(&self, ch_id: usize) -> EpicsValue {{"
79    )
80    .unwrap();
81    writeln!(out, "        match ch_id {{").unwrap();
82    for var in &ir.variables {
83        if let Some(ch_id) = var.channel_id {
84            let expr = var.var_type.to_epics_value_expr(&format!("self.{}", var.name));
85            writeln!(out, "            {ch_id} => {expr},").unwrap();
86        }
87    }
88    writeln!(out, "            _ => EpicsValue::Double(0.0),").unwrap();
89    writeln!(out, "        }}").unwrap();
90    writeln!(out, "    }}").unwrap();
91
92    // set_channel_value
93    writeln!(
94        out,
95        "    fn set_channel_value(&mut self, ch_id: usize, value: &EpicsValue) {{"
96    )
97    .unwrap();
98    writeln!(out, "        match ch_id {{").unwrap();
99    for var in &ir.variables {
100        if let Some(ch_id) = var.channel_id {
101            let expr = var.var_type.from_epics_value_expr("value");
102            writeln!(out, "            {ch_id} => self.{} = {expr},", var.name).unwrap();
103        }
104    }
105    writeln!(out, "            _ => {{}}").unwrap();
106    writeln!(out, "        }}").unwrap();
107    writeln!(out, "    }}").unwrap();
108
109    writeln!(out, "}}").unwrap();
110    writeln!(out).unwrap();
111}
112
113fn gen_meta_impl(out: &mut String, ir: &SeqIR) {
114    let struct_name = format!("{}Meta", ir.program_name);
115    let vars_name = format!("{}Vars", ir.program_name);
116
117    writeln!(out, "struct {struct_name};").unwrap();
118    writeln!(out).unwrap();
119    writeln!(out, "impl ProgramMeta for {struct_name} {{").unwrap();
120    writeln!(
121        out,
122        "    const NUM_CHANNELS: usize = {};",
123        ir.channels.len()
124    )
125    .unwrap();
126    writeln!(
127        out,
128        "    const NUM_EVENT_FLAGS: usize = {};",
129        ir.event_flags.len()
130    )
131    .unwrap();
132    writeln!(
133        out,
134        "    const NUM_STATE_SETS: usize = {};",
135        ir.state_sets.len()
136    )
137    .unwrap();
138    writeln!(out).unwrap();
139
140    // channel_defs()
141    writeln!(out, "    fn channel_defs() -> Vec<ChannelDef> {{").unwrap();
142    writeln!(out, "        vec![").unwrap();
143    for ch in &ir.channels {
144        let monitored = ch.monitored;
145        let sync_ef = match ch.sync_ef {
146            Some(id) => format!("Some({id})"),
147            None => "None".to_string(),
148        };
149        writeln!(out, "            ChannelDef {{").unwrap();
150        writeln!(out, "                var_name: \"{}\".into(),", ch.var_name).unwrap();
151        writeln!(out, "                pv_name: \"{}\".into(),", ch.pv_name).unwrap();
152        writeln!(out, "                monitored: {monitored},").unwrap();
153        writeln!(out, "                sync_ef: {sync_ef},").unwrap();
154        writeln!(out, "            }},").unwrap();
155    }
156    writeln!(out, "        ]").unwrap();
157    writeln!(out, "    }}").unwrap();
158    writeln!(out).unwrap();
159
160    // event_flag_sync_map()
161    writeln!(out, "    fn event_flag_sync_map() -> Vec<Vec<usize>> {{").unwrap();
162    writeln!(out, "        vec![").unwrap();
163    for ef in &ir.event_flags {
164        let chs: Vec<String> = ef.synced_channels.iter().map(|c| c.to_string()).collect();
165        writeln!(out, "            vec![{}],", chs.join(", ")).unwrap();
166    }
167    writeln!(out, "        ]").unwrap();
168    writeln!(out, "    }}").unwrap();
169
170    writeln!(out, "}}").unwrap();
171    writeln!(out).unwrap();
172
173    // Make the vars struct name and meta struct name available
174    // (they're referenced by gen_main)
175    let _ = vars_name;
176}
177
178fn gen_state_sets(out: &mut String, ir: &SeqIR) {
179    for ss in &ir.state_sets {
180        gen_one_state_set(out, ir, ss);
181    }
182}
183
184fn gen_one_state_set(out: &mut String, ir: &SeqIR, ss: &IRStateSet) {
185    let struct_name = format!("{}Vars", ir.program_name);
186    let fn_name = &ss.name;
187    let state_names: Vec<&str> = ss.states.iter().map(|s| s.name.as_str()).collect();
188
189    writeln!(
190        out,
191        "async fn {fn_name}(mut ctx: StateSetContext<{struct_name}>) -> SeqResult<()> {{"
192    )
193    .unwrap();
194
195    // State name array for debugging
196    let names_str: Vec<String> = state_names.iter().map(|n| format!("\"{n}\"")).collect();
197    writeln!(
198        out,
199        "    let _state_names = [{}];",
200        names_str.join(", ")
201    )
202    .unwrap();
203    writeln!(out).unwrap();
204
205    // SS-local variables
206    for var in &ss.local_vars {
207        let default_val = var.var_type.default_value();
208        let init = var
209            .init_value
210            .as_deref()
211            .unwrap_or(&default_val);
212        writeln!(
213            out,
214            "    let mut {}: {} = {init};",
215            var.name,
216            var.var_type.rust_type()
217        )
218        .unwrap();
219    }
220    if !ss.local_vars.is_empty() {
221        writeln!(out).unwrap();
222    }
223
224    writeln!(out, "    ctx.enter_state(0);").unwrap();
225    writeln!(out, "    ctx.wakeup().notify_one();").unwrap();
226    writeln!(out).unwrap();
227    writeln!(out, "    loop {{").unwrap();
228    writeln!(out, "        if ctx.is_shutdown() {{ break; }}").unwrap();
229    writeln!(out).unwrap();
230
231    // Entry block per state
232    writeln!(out, "        if ctx.should_run_entry() {{").unwrap();
233    writeln!(out, "            match ctx.current_state() {{").unwrap();
234    for state in &ss.states {
235        if let Some(entry) = &state.entry {
236            let ss_upper = ss.name.to_uppercase();
237            let st_upper = state.name.to_uppercase();
238            writeln!(out, "                {ss_upper}_{st_upper} => {{").unwrap();
239            for line in entry.code.lines() {
240                writeln!(out, "                    {line}").unwrap();
241            }
242            writeln!(out, "                }}").unwrap();
243        }
244    }
245    writeln!(out, "                _ => {{}}").unwrap();
246    writeln!(out, "            }}").unwrap();
247    writeln!(out, "        }}").unwrap();
248    writeln!(out).unwrap();
249
250    // Inner loop
251    writeln!(out, "        loop {{").unwrap();
252    writeln!(out, "            ctx.wait_for_wakeup().await;").unwrap();
253    writeln!(out, "            if ctx.is_shutdown() {{ return Ok(()); }}").unwrap();
254    writeln!(out, "            ctx.reset_wakeup();").unwrap();
255    writeln!(out, "            ctx.sync_dirty_vars();").unwrap();
256    writeln!(out).unwrap();
257
258    writeln!(out, "            match ctx.current_state() {{").unwrap();
259    for state in &ss.states {
260        let ss_upper = ss.name.to_uppercase();
261        let st_upper = state.name.to_uppercase();
262        writeln!(out, "                {ss_upper}_{st_upper} => {{").unwrap();
263
264        for (i, trans) in state.transitions.iter().enumerate() {
265            let kw = if i == 0 { "if" } else { "} else if" };
266            let cond = trans.condition.as_deref().unwrap_or("true");
267            writeln!(out, "                    {kw} {cond} {{").unwrap();
268
269            // Action code
270            for line in trans.action.code.lines() {
271                writeln!(out, "                        {line}").unwrap();
272            }
273
274            // Transition
275            match trans.target_state {
276                Some(target) => {
277                    // Find target state name
278                    let target_name = ss.states.get(target).map_or("?", |s| &s.name);
279                    let target_upper = target_name.to_uppercase();
280                    writeln!(
281                        out,
282                        "                        ctx.transition_to({ss_upper}_{target_upper});"
283                    )
284                    .unwrap();
285                }
286                None => {
287                    writeln!(out, "                        return Ok(());").unwrap();
288                }
289            }
290        }
291        if !state.transitions.is_empty() {
292            writeln!(out, "                    }}").unwrap();
293        }
294        writeln!(out, "                }}").unwrap();
295    }
296    writeln!(out, "                _ => return Err(SeqError::InvalidStateId(ctx.current_state())),").unwrap();
297    writeln!(out, "            }}").unwrap();
298    writeln!(out).unwrap();
299    writeln!(out, "            if ctx.has_transition() {{ break; }}").unwrap();
300    writeln!(out, "        }}").unwrap();
301    writeln!(out).unwrap();
302
303    // Exit block per state
304    writeln!(out, "        if ctx.should_run_exit() {{").unwrap();
305    writeln!(out, "            match ctx.current_state() {{").unwrap();
306    for state in &ss.states {
307        if let Some(exit) = &state.exit {
308            let ss_upper = ss.name.to_uppercase();
309            let st_upper = state.name.to_uppercase();
310            writeln!(out, "                {ss_upper}_{st_upper} => {{").unwrap();
311            for line in exit.code.lines() {
312                writeln!(out, "                    {line}").unwrap();
313            }
314            writeln!(out, "                }}").unwrap();
315        }
316    }
317    writeln!(out, "                _ => {{}}").unwrap();
318    writeln!(out, "            }}").unwrap();
319    writeln!(out, "        }}").unwrap();
320    writeln!(out).unwrap();
321
322    writeln!(out, "        if let Some(next) = ctx.take_transition() {{").unwrap();
323    writeln!(out, "            ctx.enter_state(next);").unwrap();
324    writeln!(out, "            ctx.wakeup().notify_one();").unwrap();
325    writeln!(out, "        }}").unwrap();
326    writeln!(out, "    }}").unwrap();
327    writeln!(out).unwrap();
328    writeln!(out, "    Ok(())").unwrap();
329    writeln!(out, "}}").unwrap();
330    writeln!(out).unwrap();
331}
332
333fn gen_main(out: &mut String, ir: &SeqIR) {
334    let name = &ir.program_name;
335    let vars_name = format!("{name}Vars");
336    let meta_name = format!("{name}Meta");
337
338    writeln!(out, "#[tokio::main]").unwrap();
339    writeln!(
340        out,
341        "async fn main() -> Result<(), Box<dyn std::error::Error>> {{"
342    )
343    .unwrap();
344    writeln!(out, "    tracing_subscriber::fmt::init();").unwrap();
345    writeln!(out).unwrap();
346    writeln!(
347        out,
348        "    let macro_str = std::env::args().nth(1).unwrap_or_default();"
349    )
350    .unwrap();
351    writeln!(out).unwrap();
352
353    // Initial variable values
354    writeln!(out, "    let initial = {vars_name} {{").unwrap();
355    for var in &ir.variables {
356        let default_val = var.var_type.default_value();
357        let init = var
358            .init_value
359            .as_deref()
360            .unwrap_or(&default_val);
361        writeln!(out, "        {}: {init},", var.name).unwrap();
362    }
363    writeln!(out, "    }};").unwrap();
364    writeln!(out).unwrap();
365
366    writeln!(
367        out,
368        "    ProgramBuilder::<{vars_name}, {meta_name}>::new(\"{name}\", initial)"
369    )
370    .unwrap();
371    writeln!(out, "        .macros(&macro_str)").unwrap();
372    for ss in &ir.state_sets {
373        let fn_name = &ss.name;
374        writeln!(
375            out,
376            "        .add_ss(Box::new(|ctx| Box::pin({fn_name}(ctx))))"
377        )
378        .unwrap();
379    }
380    writeln!(out, "        .run()").unwrap();
381    writeln!(out, "        .await?;").unwrap();
382    writeln!(out).unwrap();
383    writeln!(out, "    Ok(())").unwrap();
384    writeln!(out, "}}").unwrap();
385}