1use crate::barrier::BarrierManager;
10use crate::executor::{Executor, StepResult};
11use crate::memory::{DeviceMemory, LocalMemory};
12use crate::scheduler::Scheduler;
13use crate::stats::ExecutionStats;
14use crate::wave::Wave;
15use crate::EmulatorConfig;
16use crate::EmulatorError;
17
18pub struct Core<'a> {
19 waves: Vec<Wave>,
20 local_memory: LocalMemory,
21 device_memory: &'a mut DeviceMemory,
22 scheduler: Scheduler,
23 barrier_manager: BarrierManager,
24 executor: Executor<'a>,
25 stats: ExecutionStats,
26 workgroup_id: [u32; 3],
27 max_instructions: u64,
28 instructions_executed: u64,
29}
30
31impl<'a> Core<'a> {
32 pub fn new(
33 config: &EmulatorConfig,
34 code: &'a [u8],
35 device_memory: &'a mut DeviceMemory,
36 workgroup_id: [u32; 3],
37 ) -> Self {
38 let total_threads =
39 config.workgroup_dim[0] * config.workgroup_dim[1] * config.workgroup_dim[2];
40 let num_waves = total_threads.div_ceil(config.wave_width);
41
42 let mut waves = Vec::with_capacity(num_waves as usize);
43 for wave_id in 0..num_waves {
44 let base_thread_index = wave_id * config.wave_width;
45 let wave = Wave::new(
46 config.wave_width,
47 config.register_count,
48 wave_id,
49 workgroup_id,
50 config.workgroup_dim,
51 config.grid_dim,
52 base_thread_index,
53 total_threads,
54 num_waves,
55 );
56 waves.push(wave);
57 }
58
59 for wave in &mut waves {
61 for thread in &mut wave.threads {
62 for &(reg, val) in &config.initial_registers {
63 thread.write_register(reg, val);
64 }
65 }
66 }
67
68 let local_memory = LocalMemory::new(config.local_memory_size);
69 let scheduler = Scheduler::new(num_waves as usize);
70 let barrier_manager = BarrierManager::new(num_waves);
71 let executor = Executor::new(code, config.trace_enabled, workgroup_id);
72
73 Self {
74 waves,
75 local_memory,
76 device_memory,
77 scheduler,
78 barrier_manager,
79 executor,
80 stats: ExecutionStats::new(),
81 workgroup_id,
82 max_instructions: config.max_instructions,
83 instructions_executed: 0,
84 }
85 }
86
87 pub fn run(&mut self) -> Result<ExecutionStats, EmulatorError> {
88 for wave in &self.waves {
89 self.stats.record_wave();
90 let _ = wave;
91 }
92
93 loop {
94 if self.scheduler.all_halted(&self.waves) {
95 break;
96 }
97
98 if self.barrier_manager.all_at_barrier() {
99 self.barrier_manager.check_and_release(&mut self.waves);
100 self.stats.record_barrier();
101 continue;
102 }
103
104 let wave_index = match self.scheduler.pick_next_ready(&self.waves) {
105 Some(idx) => idx,
106 None => {
107 if self.scheduler.is_deadlocked(&self.waves) {
108 return Err(EmulatorError::Deadlock {
109 message: format!(
110 "workgroup ({},{},{}) deadlocked - waves waiting at different barriers",
111 self.workgroup_id[0], self.workgroup_id[1], self.workgroup_id[2]
112 ),
113 });
114 }
115 break;
116 }
117 };
118
119 let wave = &mut self.waves[wave_index];
120 let pc_before_step = wave.pc;
121
122 let result = self.executor.step(
123 wave,
124 &mut self.local_memory,
125 self.device_memory,
126 &mut self.stats,
127 )?;
128
129 self.instructions_executed += 1;
130
131 if self.max_instructions > 0 && self.instructions_executed > self.max_instructions {
132 return Err(EmulatorError::InstructionLimitExceeded {
133 limit: self.max_instructions,
134 executed: self.instructions_executed,
135 pc: pc_before_step,
136 });
137 }
138
139 match result {
140 StepResult::Continue => {}
141 StepResult::Halted => {}
142 StepResult::Barrier => {
143 self.barrier_manager.handle_barrier(wave);
144 }
145 }
146 }
147
148 Ok(self.stats.clone())
149 }
150
151 pub fn waves(&self) -> &[Wave] {
152 &self.waves
153 }
154
155 pub fn local_memory(&self) -> &LocalMemory {
156 &self.local_memory
157 }
158
159 pub fn stats(&self) -> &ExecutionStats {
160 &self.stats
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::SYNC_OP_FLAG;

    /// Encodes the single-word halt instruction used by these tests.
    fn encode_halt() -> Vec<u8> {
        let word = ((0x3Fu32) << 26) | ((1u32) << 7) | u32::from(SYNC_OP_FLAG);
        word.to_le_bytes().to_vec()
    }

    /// Encodes `mov rd, imm`: a 32-bit opcode word followed by the 32-bit
    /// little-endian immediate.
    fn encode_mov_imm(rd: u8, imm: u32) -> Vec<u8> {
        let word0 = ((0x3Fu32) << 26) | ((u32::from(rd) & 0x1F) << 21) | ((1u32) << 7) | 0x02;
        let mut code = word0.to_le_bytes().to_vec();
        code.extend_from_slice(&imm.to_le_bytes());
        code
    }

    /// Encodes a move from special register `sr_index` into `rd`
    /// (index 4 is lane id, per `test_core_mov_sr_lane_id`).
    fn encode_mov_sr(rd: u8, sr_index: u8) -> Vec<u8> {
        let word = ((0x3Fu32) << 26)
            | ((u32::from(rd) & 0x1F) << 21)
            | ((u32::from(sr_index) & 0x1F) << 16)
            | ((2u32) << 7)
            | 0x02;
        word.to_le_bytes().to_vec()
    }

    /// Encodes a 32-bit device-memory store of `value_reg` to the address held
    /// in `addr_reg`.
    fn encode_device_store_u32(addr_reg: u8, value_reg: u8) -> Vec<u8> {
        let word = ((0x39u32) << 26)
            | ((u32::from(addr_reg) & 0x1F) << 16)
            | ((u32::from(value_reg) & 0x1F) << 11)
            | ((2u32) << 7);
        word.to_le_bytes().to_vec()
    }

    /// Encodes `shl rd, rs1, rs2` (rd = rs1 << rs2).
    fn encode_shl(rd: u8, rs1: u8, rs2: u8) -> Vec<u8> {
        let word = ((0x24u32) << 26)
            | ((u32::from(rd) & 0x1F) << 21)
            | ((u32::from(rs1) & 0x1F) << 16)
            | ((u32::from(rs2) & 0x1F) << 11);
        word.to_le_bytes().to_vec()
    }

    #[test]
    fn test_core_single_wave_halt() {
        let code = encode_halt();

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [32, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 32,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        let stats = core.run().unwrap();
        assert_eq!(stats.waves_executed, 1);
        assert!(core.waves[0].is_halted());
    }

    #[test]
    fn test_core_two_waves() {
        let code = encode_halt();

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [64, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 32,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        let stats = core.run().unwrap();
        assert_eq!(stats.waves_executed, 2);
        assert!(core.waves[0].is_halted());
        assert!(core.waves[1].is_halted());
    }

    #[test]
    fn test_core_mov_imm_then_halt() {
        let mut code = encode_mov_imm(5, 0x12345678);
        code.extend_from_slice(&encode_halt());

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [4, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 4,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        for thread in &core.waves[0].threads {
            assert_eq!(thread.read_register(5), 0x12345678);
        }
    }

    #[test]
    fn test_core_thread_ids() {
        let code = encode_halt();

        let config = EmulatorConfig {
            grid_dim: [2, 2, 1],
            workgroup_dim: [8, 4, 2],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 32,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let core = Core::new(&config, &code, &mut device_memory, [1, 0, 0]);

        assert_eq!(
            core.waves[0].threads[0].special_registers.workgroup_id,
            [1, 0, 0]
        );
        assert_eq!(
            core.waves[0].threads[0].special_registers.thread_id,
            [0, 0, 0]
        );
        assert_eq!(
            core.waves[0].threads[8].special_registers.thread_id,
            [0, 1, 0]
        );
    }

    #[test]
    fn test_core_mov_sr_lane_id() {
        let mut code = encode_mov_sr(0, 4);
        code.extend_from_slice(&encode_halt());

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [4, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 4,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        for (i, thread) in core.waves[0].threads.iter().enumerate() {
            assert_eq!(
                thread.read_register(0),
                i as u32,
                "Thread {} should have lane_id {} in r0",
                i,
                i
            );
        }
    }

    #[test]
    fn test_core_mov_sr_and_device_store() {
        // r0 = lane_id; r1 = 2; r2 = r0 << r1 (byte address); store r0 at [r2]; halt.
        let mut code = encode_mov_sr(0, 4);
        code.extend_from_slice(&encode_mov_imm(1, 2));
        code.extend_from_slice(&encode_shl(2, 0, 1));
        code.extend_from_slice(&encode_device_store_u32(2, 0));
        code.extend_from_slice(&encode_halt());

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [4, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 4,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        for i in 0..4 {
            let addr = i * 4;
            let value = device_memory.read_u32(addr).unwrap();
            assert_eq!(value, i as u32, "Address {} should contain {}", addr, i);
        }
    }

    #[test]
    fn test_core_instruction_limit_exceeded() {
        // mov + halt is two wave-instructions: one more than the limit allows.
        let mut code = encode_mov_imm(5, 1);
        code.extend_from_slice(&encode_halt());

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [4, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 4,
            max_instructions: 1,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        let result = core.run();
        assert!(
            matches!(
                result,
                Err(EmulatorError::InstructionLimitExceeded {
                    limit: 1,
                    executed: 2,
                    ..
                })
            ),
            "expected InstructionLimitExceeded with limit 1 and executed 2"
        );
    }

    #[test]
    fn test_core_zero_instruction_limit_means_unlimited() {
        let mut code = encode_mov_imm(5, 7);
        code.extend_from_slice(&encode_halt());

        let config = EmulatorConfig {
            grid_dim: [1, 1, 1],
            workgroup_dim: [4, 1, 1],
            register_count: 32,
            local_memory_size: 1024,
            device_memory_size: 1024,
            wave_width: 4,
            max_instructions: 0,
            ..Default::default()
        };

        let mut device_memory = DeviceMemory::new(1024);
        let mut core = Core::new(&config, &code, &mut device_memory, [0, 0, 0]);

        core.run().unwrap();

        for thread in &core.waves[0].threads {
            assert_eq!(thread.read_register(5), 7);
        }
    }
}