Skip to main content

wave_emu/
core.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! Core simulation engine. Manages waves within a workgroup, coordinates barrier
//! synchronization, schedules wave execution, and drives the instruction executor.
//!
//! A single Core instance handles one workgroup's complete execution.
8
9use crate::barrier::BarrierManager;
10use crate::executor::{Executor, StepResult};
11use crate::memory::{DeviceMemory, LocalMemory};
12use crate::scheduler::Scheduler;
13use crate::stats::ExecutionStats;
14use crate::wave::Wave;
15use crate::EmulatorConfig;
16use crate::EmulatorError;
17
18pub struct Core<'a> {
19    waves: Vec<Wave>,
20    local_memory: LocalMemory,
21    device_memory: &'a mut DeviceMemory,
22    scheduler: Scheduler,
23    barrier_manager: BarrierManager,
24    executor: Executor<'a>,
25    stats: ExecutionStats,
26    workgroup_id: [u32; 3],
27    max_instructions: u64,
28    instructions_executed: u64,
29}
30
31impl<'a> Core<'a> {
32    pub fn new(
33        config: &EmulatorConfig,
34        code: &'a [u8],
35        device_memory: &'a mut DeviceMemory,
36        workgroup_id: [u32; 3],
37    ) -> Self {
38        let total_threads =
39            config.workgroup_dim[0] * config.workgroup_dim[1] * config.workgroup_dim[2];
40        let num_waves = total_threads.div_ceil(config.wave_width);
41
42        let mut waves = Vec::with_capacity(num_waves as usize);
43        for wave_id in 0..num_waves {
44            let base_thread_index = wave_id * config.wave_width;
45            let wave = Wave::new(
46                config.wave_width,
47                config.register_count,
48                wave_id,
49                workgroup_id,
50                config.workgroup_dim,
51                config.grid_dim,
52                base_thread_index,
53                total_threads,
54                num_waves,
55            );
56            waves.push(wave);
57        }
58
59        // Apply initial register values to all threads in all waves
60        for wave in &mut waves {
61            for thread in &mut wave.threads {
62                for &(reg, val) in &config.initial_registers {
63                    thread.write_register(reg, val);
64                }
65            }
66        }
67
68        let local_memory = LocalMemory::new(config.local_memory_size);
69        let scheduler = Scheduler::new(num_waves as usize);
70        let barrier_manager = BarrierManager::new(num_waves);
71        let executor = Executor::new(code, config.trace_enabled, workgroup_id);
72
73        Self {
74            waves,
75            local_memory,
76            device_memory,
77            scheduler,
78            barrier_manager,
79            executor,
80            stats: ExecutionStats::new(),
81            workgroup_id,
82            max_instructions: config.max_instructions,
83            instructions_executed: 0,
84        }
85    }
86
87    pub fn run(&mut self) -> Result<ExecutionStats, EmulatorError> {
88        for wave in &self.waves {
89            self.stats.record_wave();
90            let _ = wave;
91        }
92
93        loop {
94            if self.scheduler.all_halted(&self.waves) {
95                break;
96            }
97
98            if self.barrier_manager.all_at_barrier() {
99                self.barrier_manager.check_and_release(&mut self.waves);
100                self.stats.record_barrier();
101                continue;
102            }
103
104            let wave_index = match self.scheduler.pick_next_ready(&self.waves) {
105                Some(idx) => idx,
106                None => {
107                    if self.scheduler.is_deadlocked(&self.waves) {
108                        return Err(EmulatorError::Deadlock {
109                            message: format!(
110                                "workgroup ({},{},{}) deadlocked - waves waiting at different barriers",
111                                self.workgroup_id[0], self.workgroup_id[1], self.workgroup_id[2]
112                            ),
113                        });
114                    }
115                    break;
116                }
117            };
118
119            let wave = &mut self.waves[wave_index];
120            let pc_before_step = wave.pc;
121
122            let result = self.executor.step(
123                wave,
124                &mut self.local_memory,
125                self.device_memory,
126                &mut self.stats,
127            )?;
128
129            self.instructions_executed += 1;
130
131            if self.max_instructions > 0 && self.instructions_executed > self.max_instructions {
132                return Err(EmulatorError::InstructionLimitExceeded {
133                    limit: self.max_instructions,
134                    executed: self.instructions_executed,
135                    pc: pc_before_step,
136                });
137            }
138
139            match result {
140                StepResult::Continue => {}
141                StepResult::Halted => {}
142                StepResult::Barrier => {
143                    self.barrier_manager.handle_barrier(wave);
144                }
145            }
146        }
147
148        Ok(self.stats.clone())
149    }
150
151    pub fn waves(&self) -> &[Wave] {
152        &self.waves
153    }
154
155    pub fn local_memory(&self) -> &LocalMemory {
156        &self.local_memory
157    }
158
159    pub fn stats(&self) -> &ExecutionStats {
160        &self.stats
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::SYNC_OP_FLAG;

    /// Places the low five bits of `reg` at bit offset `shift` of an instruction word.
    fn reg_field(reg: u8, shift: u32) -> u32 {
        (u32::from(reg) & 0x1F) << shift
    }

    /// Configuration fields shared by every test; individual tests override the
    /// workgroup shape and wave width through struct-update syntax.
    fn base_config() -> EmulatorConfig {
        EmulatorConfig {
            grid_dim: [1, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            ..Default::default()
        }
    }

    /// Encodes the single-word halt instruction as little-endian bytes.
    fn encode_halt() -> Vec<u8> {
        let word = (0x3Fu32 << 26) | (1u32 << 7) | u32::from(SYNC_OP_FLAG);
        Vec::from(word.to_le_bytes())
    }

    /// Encodes a two-word "move immediate into `rd`" instruction: the
    /// instruction word followed by the 32-bit immediate.
    fn encode_mov_imm(rd: u8, imm: u32) -> Vec<u8> {
        let header = (0x3Fu32 << 26) | reg_field(rd, 21) | (1u32 << 7) | 0x02;
        [header.to_le_bytes(), imm.to_le_bytes()].concat()
    }

    /// Encodes a "move special register `sr_index` into `rd`" instruction.
    fn encode_mov_sr(rd: u8, sr_index: u8) -> Vec<u8> {
        let word =
            (0x3Fu32 << 26) | reg_field(rd, 21) | reg_field(sr_index, 16) | (2u32 << 7) | 0x02;
        Vec::from(word.to_le_bytes())
    }

    /// Encodes a device-memory store of `value_reg` at the address in `addr_reg`.
    fn encode_device_store_u32(addr_reg: u8, value_reg: u8) -> Vec<u8> {
        let word =
            (0x39u32 << 26) | reg_field(addr_reg, 16) | reg_field(value_reg, 11) | (2u32 << 7);
        Vec::from(word.to_le_bytes())
    }

    /// Encodes a shift-left instruction over registers `rd`, `rs1` and `rs2`.
    fn encode_shl(rd: u8, rs1: u8, rs2: u8) -> Vec<u8> {
        let word =
            (0x24u32 << 26) | reg_field(rd, 21) | reg_field(rs1, 16) | reg_field(rs2, 11);
        Vec::from(word.to_le_bytes())
    }

    #[test]
    fn test_core_single_wave_halt() {
        let program = encode_halt();
        let config = EmulatorConfig {
            workgroup_dim: [32, 1, 1],
            wave_width: 32,
            ..base_config()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &program, &mut device_memory, [0, 0, 0]);

        let stats = core.run().unwrap();
        assert_eq!(stats.waves_executed, 1);
        assert!(core.waves[0].is_halted());
    }

    #[test]
    fn test_core_two_waves() {
        let program = encode_halt();
        let config = EmulatorConfig {
            workgroup_dim: [64, 1, 1],
            wave_width: 32,
            ..base_config()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &program, &mut device_memory, [0, 0, 0]);

        let stats = core.run().unwrap();
        assert_eq!(stats.waves_executed, 2);
        assert!(core.waves[0].is_halted());
        assert!(core.waves[1].is_halted());
    }

    #[test]
    fn test_core_mov_imm_then_halt() {
        let program = [encode_mov_imm(5, 0x12345678), encode_halt()].concat();
        let config = EmulatorConfig {
            workgroup_dim: [4, 1, 1],
            wave_width: 4,
            ..base_config()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &program, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        // Every thread of the only wave must have received the immediate.
        for thread in &core.waves[0].threads {
            assert_eq!(thread.read_register(5), 0x12345678);
        }
    }

    #[test]
    fn test_core_thread_ids() {
        let program = encode_halt();
        let config = EmulatorConfig {
            grid_dim: [2, 2, 1],
            workgroup_dim: [8, 4, 2],
            wave_width: 32,
            ..base_config()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let core = Core::new(&config, &program, &mut device_memory, [1, 0, 0]);

        let threads = &core.waves[0].threads;
        assert_eq!(threads[0].special_registers.workgroup_id, [1, 0, 0]);
        assert_eq!(threads[0].special_registers.thread_id, [0, 0, 0]);
        assert_eq!(threads[8].special_registers.thread_id, [0, 1, 0]);
    }

    #[test]
    fn test_core_mov_sr_lane_id() {
        let program = [encode_mov_sr(0, 4), encode_halt()].concat();
        let config = EmulatorConfig {
            workgroup_dim: [4, 1, 1],
            wave_width: 4,
            ..base_config()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &program, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        for (lane, thread) in core.waves[0].threads.iter().enumerate() {
            assert_eq!(
                thread.read_register(0),
                lane as u32,
                "Thread {} should have lane_id {} in r0",
                lane,
                lane
            );
        }
    }

    #[test]
    fn test_core_mov_sr_and_device_store() {
        let program = [
            encode_mov_sr(0, 4),           // r0 = lane_id
            encode_mov_imm(1, 2),          // r1 = 2
            encode_shl(2, 0, 1),           // r2 = r0 << r1 = lane_id * 4
            encode_device_store_u32(2, 0), // store r0 at addr r2
            encode_halt(),
        ]
        .concat();
        let config = EmulatorConfig {
            workgroup_dim: [4, 1, 1],
            wave_width: 4,
            ..base_config()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &program, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        // Each lane stored its own id at `lane * 4`.
        for i in 0..4 {
            let addr = i * 4;
            let value = device_memory.read_u32(addr).unwrap();
            assert_eq!(value, i as u32, "Address {} should contain {}", addr, i);
        }
    }
}