1use crate::decoder::{
11 AtomicOp, BitOpType, CmpOp, ControlOp, CvtType, DecodedInstruction, Decoder, F16Op,
12 F16PackedOp, F64DivSqrtOp, F64Op, FUnaryOp, MemWidth, MiscOp, Opcode, SyncOp, WaveOpType,
13 WaveReduceType,
14};
15use crate::memory::{DeviceMemory, LocalMemory};
16use crate::shuffle;
17use crate::stats::{ExecutionStats, InstructionCategory, TraceWriter};
18use crate::wave::Wave;
19use crate::EmulatorError;
20use half::f16;
21
22pub struct Executor<'a> {
23 decoder: Decoder<'a>,
24 trace: TraceWriter,
25 workgroup_id: [u32; 3],
26}
27
28impl<'a> Executor<'a> {
29 pub fn new(code: &'a [u8], trace_enabled: bool, workgroup_id: [u32; 3]) -> Self {
30 Self {
31 decoder: Decoder::new(code),
32 trace: TraceWriter::new(trace_enabled),
33 workgroup_id,
34 }
35 }
36
37 pub fn step(
38 &mut self,
39 wave: &mut Wave,
40 local_memory: &mut LocalMemory,
41 device_memory: &mut DeviceMemory,
42 stats: &mut ExecutionStats,
43 ) -> Result<StepResult, EmulatorError> {
44 if wave.is_halted() {
45 return Ok(StepResult::Halted);
46 }
47
48 if wave.active_mask == 0 && wave.control_flow.is_empty() {
49 wave.halt();
50 return Ok(StepResult::Halted);
51 }
52
53 let inst = self.decoder.decode_at(wave.pc)?;
54
55 if self.trace.is_enabled() {
56 let disasm = self.decoder.disassemble(&inst);
57 self.trace
58 .trace_instruction(self.workgroup_id, wave.wave_id, wave.pc, &disasm);
59 }
60
61 let result = self.execute_instruction(wave, &inst, local_memory, device_memory, stats)?;
62
63 match result {
64 ExecuteResult::Continue => {
65 wave.advance_pc(inst.size);
66 Ok(StepResult::Continue)
67 }
68 ExecuteResult::Jump(target) => {
69 wave.set_pc(target);
70 Ok(StepResult::Continue)
71 }
72 ExecuteResult::Halt => {
73 wave.halt();
74 Ok(StepResult::Halted)
75 }
76 ExecuteResult::Barrier => Ok(StepResult::Barrier),
77 }
78 }
79
80 fn execute_instruction(
81 &mut self,
82 wave: &mut Wave,
83 inst: &DecodedInstruction,
84 local_memory: &mut LocalMemory,
85 device_memory: &mut DeviceMemory,
86 stats: &mut ExecutionStats,
87 ) -> Result<ExecuteResult, EmulatorError> {
88 let original_mask = wave.active_mask;
89 let is_control_sync = inst.opcode == Opcode::Control && inst.is_sync_op();
90 let is_halt = is_control_sync && inst.modifier == SyncOp::Halt as u8;
91
92 if inst.is_predicated() {
93 let pred_mask = self.compute_predicate_mask(wave, inst.pred_reg, inst.pred_neg);
94
95 if is_halt {
96 wave.active_mask &= !pred_mask;
97 if wave.active_mask == 0 {
98 return Ok(ExecuteResult::Halt);
99 }
100 return Ok(ExecuteResult::Continue);
101 } else if is_control_sync {
102 if pred_mask == 0 {
103 return Ok(ExecuteResult::Continue);
104 }
105 } else {
106 wave.active_mask &= pred_mask;
107 }
108 }
109
110 let result = match inst.opcode {
111 Opcode::Iadd
112 | Opcode::Isub
113 | Opcode::Imul
114 | Opcode::ImulHi
115 | Opcode::Idiv
116 | Opcode::Imod
117 | Opcode::Ineg
118 | Opcode::Iabs
119 | Opcode::Imin
120 | Opcode::Imax => {
121 self.execute_integer_op(wave, inst);
122 stats.record_instruction(InstructionCategory::Integer);
123 Ok(ExecuteResult::Continue)
124 }
125 Opcode::Imad | Opcode::Iclamp => {
126 self.execute_integer_extended(wave, inst);
127 stats.record_instruction(InstructionCategory::Integer);
128 Ok(ExecuteResult::Continue)
129 }
130 Opcode::And
131 | Opcode::Or
132 | Opcode::Xor
133 | Opcode::Not
134 | Opcode::Shl
135 | Opcode::Shr
136 | Opcode::Sar => {
137 self.execute_bitwise_op(wave, inst);
138 stats.record_instruction(InstructionCategory::Integer);
139 Ok(ExecuteResult::Continue)
140 }
141 Opcode::BitOps => {
142 self.execute_bit_ops(wave, inst);
143 stats.record_instruction(InstructionCategory::Integer);
144 Ok(ExecuteResult::Continue)
145 }
146 Opcode::Fadd
147 | Opcode::Fsub
148 | Opcode::Fmul
149 | Opcode::Fdiv
150 | Opcode::Fneg
151 | Opcode::Fabs
152 | Opcode::Fmin
153 | Opcode::Fmax
154 | Opcode::Fsqrt => {
155 self.execute_float_op(wave, inst);
156 stats.record_instruction(InstructionCategory::Float);
157 Ok(ExecuteResult::Continue)
158 }
159 Opcode::Fma | Opcode::Fclamp => {
160 self.execute_float_extended(wave, inst);
161 stats.record_instruction(InstructionCategory::Float);
162 Ok(ExecuteResult::Continue)
163 }
164 Opcode::FUnaryOps => {
165 self.execute_float_unary(wave, inst);
166 stats.record_instruction(InstructionCategory::Float);
167 Ok(ExecuteResult::Continue)
168 }
169 Opcode::F16Ops | Opcode::F16PackedOps => {
170 self.execute_f16_op(wave, inst);
171 stats.record_instruction(InstructionCategory::Float);
172 Ok(ExecuteResult::Continue)
173 }
174 Opcode::F64Ops | Opcode::F64DivSqrt => {
175 self.execute_f64_op(wave, inst);
176 stats.record_instruction(InstructionCategory::Float);
177 Ok(ExecuteResult::Continue)
178 }
179 Opcode::Icmp | Opcode::Ucmp | Opcode::Fcmp => {
180 self.execute_compare(wave, inst);
181 stats.record_instruction(InstructionCategory::Integer);
182 Ok(ExecuteResult::Continue)
183 }
184 Opcode::Select => {
185 self.execute_select(wave, inst);
186 stats.record_instruction(InstructionCategory::Integer);
187 Ok(ExecuteResult::Continue)
188 }
189 Opcode::Cvt => {
190 self.execute_convert(wave, inst);
191 stats.record_instruction(InstructionCategory::Float);
192 Ok(ExecuteResult::Continue)
193 }
194 Opcode::LocalLoad => {
195 self.execute_local_load(wave, inst, local_memory, stats)?;
196 Ok(ExecuteResult::Continue)
197 }
198 Opcode::LocalStore => {
199 self.execute_local_store(wave, inst, local_memory, stats)?;
200 Ok(ExecuteResult::Continue)
201 }
202 Opcode::DeviceLoad => {
203 self.execute_device_load(wave, inst, device_memory, stats)?;
204 Ok(ExecuteResult::Continue)
205 }
206 Opcode::DeviceStore => {
207 self.execute_device_store(wave, inst, device_memory, stats)?;
208 Ok(ExecuteResult::Continue)
209 }
210 Opcode::LocalAtomic => {
211 self.execute_local_atomic(wave, inst, local_memory, stats)?;
212 Ok(ExecuteResult::Continue)
213 }
214 Opcode::DeviceAtomic => {
215 self.execute_device_atomic(wave, inst, device_memory, stats)?;
216 Ok(ExecuteResult::Continue)
217 }
218 Opcode::WaveOp => {
219 self.execute_wave_op(wave, inst);
220 stats.record_instruction(InstructionCategory::WaveOp);
221 Ok(ExecuteResult::Continue)
222 }
223 Opcode::Control => self.execute_control(wave, inst, stats),
224 };
225
226 if inst.is_predicated() && !is_control_sync {
227 wave.active_mask = original_mask;
228 }
229
230 result
231 }
232
233 fn compute_predicate_mask(&self, wave: &Wave, pred_reg: u8, negated: bool) -> u64 {
234 let mut mask: u64 = 0;
235 for lane in 0..wave.wave_width {
236 if wave.is_thread_active(lane) {
237 let pred = wave.threads[lane as usize].read_predicate(pred_reg);
238 let value = if negated { !pred } else { pred };
239 if value {
240 mask |= 1u64 << lane;
241 }
242 }
243 }
244 mask
245 }
246
247 fn execute_integer_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
248 for lane in 0..wave.wave_width {
249 if !wave.is_thread_active(lane) {
250 continue;
251 }
252
253 let thread = &mut wave.threads[lane as usize];
254 let rs1 = thread.read_register(inst.rs1);
255 let rs2 = thread.read_register(inst.rs2);
256
257 let result = match inst.opcode {
258 Opcode::Iadd => rs1.wrapping_add(rs2),
259 Opcode::Isub => rs1.wrapping_sub(rs2),
260 Opcode::Imul => rs1.wrapping_mul(rs2),
261 Opcode::ImulHi => {
262 let wide = (rs1 as i64).wrapping_mul(rs2 as i64);
263 (wide >> 32) as u32
264 }
265 Opcode::Idiv => {
266 if rs2 == 0 {
267 0
268 } else {
269 (rs1 as i32).wrapping_div(rs2 as i32) as u32
270 }
271 }
272 Opcode::Imod => {
273 if rs2 == 0 {
274 0
275 } else {
276 (rs1 as i32).wrapping_rem(rs2 as i32) as u32
277 }
278 }
279 Opcode::Ineg => (-(rs1 as i32)) as u32,
280 Opcode::Iabs => (rs1 as i32).unsigned_abs(),
281 Opcode::Imin => (rs1 as i32).min(rs2 as i32) as u32,
282 Opcode::Imax => (rs1 as i32).max(rs2 as i32) as u32,
283 _ => 0,
284 };
285
286 thread.write_register(inst.rd, result);
287 }
288 }
289
290 fn execute_integer_extended(&self, wave: &mut Wave, inst: &DecodedInstruction) {
291 for lane in 0..wave.wave_width {
292 if !wave.is_thread_active(lane) {
293 continue;
294 }
295
296 let thread = &mut wave.threads[lane as usize];
297 let rs1 = thread.read_register(inst.rs1);
298 let rs2 = thread.read_register(inst.rs2);
299 let rs3 = thread.read_register(inst.rs3);
300
301 let result = match inst.opcode {
302 Opcode::Imad => rs1.wrapping_mul(rs2).wrapping_add(rs3),
303 Opcode::Iclamp => {
304 let val = rs1 as i32;
305 let lo = rs2 as i32;
306 let hi = rs3 as i32;
307 val.clamp(lo, hi) as u32
308 }
309 _ => 0,
310 };
311
312 thread.write_register(inst.rd, result);
313 }
314 }
315
316 fn execute_bitwise_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
317 for lane in 0..wave.wave_width {
318 if !wave.is_thread_active(lane) {
319 continue;
320 }
321
322 let thread = &mut wave.threads[lane as usize];
323 let rs1 = thread.read_register(inst.rs1);
324 let rs2 = thread.read_register(inst.rs2);
325
326 let result = match inst.opcode {
327 Opcode::And => rs1 & rs2,
328 Opcode::Or => rs1 | rs2,
329 Opcode::Xor => rs1 ^ rs2,
330 Opcode::Not => !rs1,
331 Opcode::Shl => rs1.wrapping_shl(rs2 & 0x1F),
332 Opcode::Shr => rs1.wrapping_shr(rs2 & 0x1F),
333 Opcode::Sar => ((rs1 as i32).wrapping_shr(rs2 & 0x1F)) as u32,
334 _ => 0,
335 };
336
337 thread.write_register(inst.rd, result);
338 }
339 }
340
341 fn execute_bit_ops(&self, wave: &mut Wave, inst: &DecodedInstruction) {
342 for lane in 0..wave.wave_width {
343 if !wave.is_thread_active(lane) {
344 continue;
345 }
346
347 let thread = &mut wave.threads[lane as usize];
348 let rs1 = thread.read_register(inst.rs1);
349 let rs2 = thread.read_register(inst.rs2);
350 let rs3 = thread.read_register(inst.rs3);
351 let rs4 = thread.read_register(inst.rs4);
352
353 let result = match inst.modifier {
354 m if m == BitOpType::Bitcount as u8 => rs1.count_ones(),
355 m if m == BitOpType::Bitfind as u8 => {
356 if rs1 == 0 {
357 u32::MAX
358 } else {
359 rs1.leading_zeros()
360 }
361 }
362 m if m == BitOpType::Bitrev as u8 => rs1.reverse_bits(),
363 m if m == BitOpType::Bfe as u8 => {
364 let offset = rs2 & 0x1F;
365 let width = rs3 & 0x1F;
366 if width == 0 {
367 0
368 } else {
369 (rs1 >> offset) & ((1 << width) - 1)
370 }
371 }
372 m if m == BitOpType::Bfi as u8 => {
373 let offset = rs3 & 0x1F;
374 let width = rs4 & 0x1F;
375 if width == 0 {
376 rs1
377 } else {
378 let mask = ((1u32 << width) - 1) << offset;
379 (rs1 & !mask) | ((rs2 << offset) & mask)
380 }
381 }
382 _ => 0,
383 };
384
385 thread.write_register(inst.rd, result);
386 }
387 }
388
389 fn execute_float_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
390 for lane in 0..wave.wave_width {
391 if !wave.is_thread_active(lane) {
392 continue;
393 }
394
395 let thread = &mut wave.threads[lane as usize];
396 let rs1 = f32::from_bits(thread.read_register(inst.rs1));
397 let rs2 = f32::from_bits(thread.read_register(inst.rs2));
398
399 let result = match inst.opcode {
400 Opcode::Fadd => rs1 + rs2,
401 Opcode::Fsub => rs1 - rs2,
402 Opcode::Fmul => rs1 * rs2,
403 Opcode::Fdiv => {
404 if rs2 == 0.0 {
405 f32::INFINITY
406 } else {
407 rs1 / rs2
408 }
409 }
410 Opcode::Fneg => -rs1,
411 Opcode::Fabs => rs1.abs(),
412 Opcode::Fmin => rs1.min(rs2),
413 Opcode::Fmax => rs1.max(rs2),
414 Opcode::Fsqrt => rs1.sqrt(),
415 _ => 0.0,
416 };
417
418 thread.write_register(inst.rd, result.to_bits());
419 }
420 }
421
422 fn execute_float_extended(&self, wave: &mut Wave, inst: &DecodedInstruction) {
423 for lane in 0..wave.wave_width {
424 if !wave.is_thread_active(lane) {
425 continue;
426 }
427
428 let thread = &mut wave.threads[lane as usize];
429 let rs1 = f32::from_bits(thread.read_register(inst.rs1));
430 let rs2 = f32::from_bits(thread.read_register(inst.rs2));
431 let rs3 = f32::from_bits(thread.read_register(inst.rs3));
432
433 let result = match inst.opcode {
434 Opcode::Fma => rs1.mul_add(rs2, rs3),
435 Opcode::Fclamp => rs1.clamp(rs2, rs3),
436 _ => 0.0,
437 };
438
439 thread.write_register(inst.rd, result.to_bits());
440 }
441 }
442
443 fn execute_float_unary(&self, wave: &mut Wave, inst: &DecodedInstruction) {
444 for lane in 0..wave.wave_width {
445 if !wave.is_thread_active(lane) {
446 continue;
447 }
448
449 let thread = &mut wave.threads[lane as usize];
450 let rs1 = f32::from_bits(thread.read_register(inst.rs1));
451
452 let result = match inst.modifier {
453 m if m == FUnaryOp::Frsqrt as u8 => 1.0 / rs1.sqrt(),
454 m if m == FUnaryOp::Frcp as u8 => 1.0 / rs1,
455 m if m == FUnaryOp::Ffloor as u8 => rs1.floor(),
456 m if m == FUnaryOp::Fceil as u8 => rs1.ceil(),
457 m if m == FUnaryOp::Fround as u8 => rs1.round(),
458 m if m == FUnaryOp::Ftrunc as u8 => rs1.trunc(),
459 m if m == FUnaryOp::Ffract as u8 => rs1.fract(),
460 m if m == FUnaryOp::Fsat as u8 => rs1.clamp(0.0, 1.0),
461 m if m == FUnaryOp::Fsin as u8 => rs1.sin(),
462 m if m == FUnaryOp::Fcos as u8 => rs1.cos(),
463 m if m == FUnaryOp::Fexp2 as u8 => rs1.exp2(),
464 m if m == FUnaryOp::Flog2 as u8 => rs1.log2(),
465 _ => 0.0,
466 };
467
468 thread.write_register(inst.rd, result.to_bits());
469 }
470 }
471
472 fn execute_f16_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
473 for lane in 0..wave.wave_width {
474 if !wave.is_thread_active(lane) {
475 continue;
476 }
477
478 let thread = &mut wave.threads[lane as usize];
479 let rs1_bits = thread.read_register(inst.rs1);
480 let rs2_bits = thread.read_register(inst.rs2);
481 let rs3_bits = thread.read_register(inst.rs3);
482
483 let result = if inst.opcode == Opcode::F16Ops {
484 let a = f16::from_bits(rs1_bits as u16);
485 let b = f16::from_bits(rs2_bits as u16);
486 let c = f16::from_bits(rs3_bits as u16);
487
488 let r = match inst.modifier {
489 m if m == F16Op::Hadd as u8 => f16::from_f32(a.to_f32() + b.to_f32()),
490 m if m == F16Op::Hsub as u8 => f16::from_f32(a.to_f32() - b.to_f32()),
491 m if m == F16Op::Hmul as u8 => f16::from_f32(a.to_f32() * b.to_f32()),
492 m if m == F16Op::Hma as u8 => {
493 f16::from_f32(a.to_f32().mul_add(b.to_f32(), c.to_f32()))
494 }
495 _ => f16::ZERO,
496 };
497 u32::from(r.to_bits())
498 } else {
499 let a_lo = f16::from_bits(rs1_bits as u16);
500 let a_hi = f16::from_bits((rs1_bits >> 16) as u16);
501 let b_lo = f16::from_bits(rs2_bits as u16);
502 let b_hi = f16::from_bits((rs2_bits >> 16) as u16);
503 let c_lo = f16::from_bits(rs3_bits as u16);
504 let c_hi = f16::from_bits((rs3_bits >> 16) as u16);
505
506 let (r_lo, r_hi) = match inst.modifier {
507 m if m == F16PackedOp::Hadd2 as u8 => (
508 f16::from_f32(a_lo.to_f32() + b_lo.to_f32()),
509 f16::from_f32(a_hi.to_f32() + b_hi.to_f32()),
510 ),
511 m if m == F16PackedOp::Hmul2 as u8 => (
512 f16::from_f32(a_lo.to_f32() * b_lo.to_f32()),
513 f16::from_f32(a_hi.to_f32() * b_hi.to_f32()),
514 ),
515 m if m == F16PackedOp::Hma2 as u8 => (
516 f16::from_f32(a_lo.to_f32().mul_add(b_lo.to_f32(), c_lo.to_f32())),
517 f16::from_f32(a_hi.to_f32().mul_add(b_hi.to_f32(), c_hi.to_f32())),
518 ),
519 _ => (f16::ZERO, f16::ZERO),
520 };
521 u32::from(r_lo.to_bits()) | (u32::from(r_hi.to_bits()) << 16)
522 };
523
524 thread.write_register(inst.rd, result);
525 }
526 }
527
528 fn execute_f64_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
529 for lane in 0..wave.wave_width {
530 if !wave.is_thread_active(lane) {
531 continue;
532 }
533
534 let thread = &mut wave.threads[lane as usize];
535
536 let rs1_lo = thread.read_register(inst.rs1);
537 let rs1_hi = thread.read_register(inst.rs1 + 1);
538 let a = f64::from_bits((u64::from(rs1_hi) << 32) | u64::from(rs1_lo));
539
540 let rs2_lo = thread.read_register(inst.rs2);
541 let rs2_hi = thread.read_register(inst.rs2 + 1);
542 let b = f64::from_bits((u64::from(rs2_hi) << 32) | u64::from(rs2_lo));
543
544 let result = if inst.opcode == Opcode::F64Ops {
545 let rs3_lo = thread.read_register(inst.rs3);
546 let rs3_hi = thread.read_register(inst.rs3 + 1);
547 let c = f64::from_bits((u64::from(rs3_hi) << 32) | u64::from(rs3_lo));
548
549 match inst.modifier {
550 m if m == F64Op::Dadd as u8 => a + b,
551 m if m == F64Op::Dsub as u8 => a - b,
552 m if m == F64Op::Dmul as u8 => a * b,
553 m if m == F64Op::Dma as u8 => a.mul_add(b, c),
554 _ => 0.0,
555 }
556 } else {
557 match inst.modifier {
558 m if m == F64DivSqrtOp::Ddiv as u8 => a / b,
559 m if m == F64DivSqrtOp::Dsqrt as u8 => a.sqrt(),
560 _ => 0.0,
561 }
562 };
563
564 let bits = result.to_bits();
565 thread.write_register(inst.rd, bits as u32);
566 thread.write_register(inst.rd + 1, (bits >> 32) as u32);
567 }
568 }
569
570 fn execute_compare(&self, wave: &mut Wave, inst: &DecodedInstruction) {
571 for lane in 0..wave.wave_width {
572 if !wave.is_thread_active(lane) {
573 continue;
574 }
575
576 let thread = &mut wave.threads[lane as usize];
577 let rs1 = thread.read_register(inst.rs1);
578 let rs2 = thread.read_register(inst.rs2);
579
580 let result = match inst.opcode {
581 Opcode::Icmp => {
582 let a = rs1 as i32;
583 let b = rs2 as i32;
584 match inst.modifier {
585 m if m == CmpOp::Eq as u8 => a == b,
586 m if m == CmpOp::Ne as u8 => a != b,
587 m if m == CmpOp::Lt as u8 => a < b,
588 m if m == CmpOp::Le as u8 => a <= b,
589 m if m == CmpOp::Gt as u8 => a > b,
590 m if m == CmpOp::Ge as u8 => a >= b,
591 _ => false,
592 }
593 }
594 Opcode::Ucmp => match inst.modifier {
595 m if m == CmpOp::Lt as u8 => rs1 < rs2,
596 m if m == CmpOp::Le as u8 => rs1 <= rs2,
597 _ => false,
598 },
599 Opcode::Fcmp => {
600 let a = f32::from_bits(rs1);
601 let b = f32::from_bits(rs2);
602 match inst.modifier {
603 m if m == CmpOp::Eq as u8 => a == b,
604 m if m == CmpOp::Ne as u8 => a != b,
605 m if m == CmpOp::Lt as u8 => a < b,
606 m if m == CmpOp::Le as u8 => a <= b,
607 m if m == CmpOp::Gt as u8 => a > b,
608 m if m == CmpOp::Ord as u8 => !a.is_nan() && !b.is_nan(),
609 m if m == CmpOp::Unord as u8 => a.is_nan() || b.is_nan(),
610 _ => false,
611 }
612 }
613 _ => false,
614 };
615
616 thread.write_predicate(inst.rd, result);
617 }
618 }
619
620 fn execute_select(&self, wave: &mut Wave, inst: &DecodedInstruction) {
621 for lane in 0..wave.wave_width {
622 if !wave.is_thread_active(lane) {
623 continue;
624 }
625
626 let thread = &mut wave.threads[lane as usize];
627 let pred = thread.read_predicate(inst.modifier);
628 let rs1 = thread.read_register(inst.rs1);
629 let rs2 = thread.read_register(inst.rs2);
630
631 let result = if pred { rs1 } else { rs2 };
632 thread.write_register(inst.rd, result);
633 }
634 }
635
636 fn execute_convert(&self, wave: &mut Wave, inst: &DecodedInstruction) {
637 for lane in 0..wave.wave_width {
638 if !wave.is_thread_active(lane) {
639 continue;
640 }
641
642 let thread = &mut wave.threads[lane as usize];
643 let rs1 = thread.read_register(inst.rs1);
644
645 let result = match inst.modifier {
646 m if m == CvtType::F32I32 as u8 => f32::from_bits(rs1) as i32 as u32, m if m == CvtType::F32U32 as u8 => f32::from_bits(rs1) as u32, m if m == CvtType::I32F32 as u8 => ((rs1 as i32) as f32).to_bits(), m if m == CvtType::U32F32 as u8 => (rs1 as f32).to_bits(), m if m == CvtType::F32F16 as u8 => f16::from_bits(rs1 as u16).to_f32().to_bits(),
651 m if m == CvtType::F16F32 as u8 => {
652 u32::from(f16::from_f32(f32::from_bits(rs1)).to_bits())
653 }
654 m if m == CvtType::F32F64 as u8 => {
655 let rs1_hi = thread.read_register(inst.rs1 + 1);
656 let d = f64::from_bits((u64::from(rs1_hi) << 32) | u64::from(rs1));
657 (d as f32).to_bits()
658 }
659 m if m == CvtType::F64F32 as u8 => {
660 let f = f32::from_bits(rs1);
661 let d = f64::from(f);
662 let bits = d.to_bits();
663 thread.write_register(inst.rd + 1, (bits >> 32) as u32);
664 bits as u32
665 }
666 _ => 0,
667 };
668
669 thread.write_register(inst.rd, result);
670 }
671 }
672
673 fn execute_local_load(
674 &self,
675 wave: &mut Wave,
676 inst: &DecodedInstruction,
677 local_memory: &mut LocalMemory,
678 stats: &mut ExecutionStats,
679 ) -> Result<(), EmulatorError> {
680 let width = match inst.modifier {
681 m if m == MemWidth::U8 as u8 => 1,
682 m if m == MemWidth::U16 as u8 => 2,
683 m if m == MemWidth::U32 as u8 => 4,
684 m if m == MemWidth::U64 as u8 => 8,
685 _ => 4,
686 };
687
688 for lane in 0..wave.wave_width {
689 if !wave.is_thread_active(lane) {
690 continue;
691 }
692
693 let thread = &mut wave.threads[lane as usize];
694 let addr = thread.read_register(inst.rs1);
695
696 let value = match width {
697 1 => u32::from(local_memory.read_u8(addr)?),
698 2 => u32::from(local_memory.read_u16(addr)?),
699 4 => local_memory.read_u32(addr)?,
700 8 => {
701 let val = local_memory.read_u64(addr)?;
702 thread.write_register(inst.rd + 1, (val >> 32) as u32);
703 val as u32
704 }
705 _ => 0,
706 };
707
708 thread.write_register(inst.rd, value);
709 stats.record_local_load(width as u64);
710 }
711
712 stats.record_instruction(InstructionCategory::Memory);
713 Ok(())
714 }
715
716 fn execute_local_store(
717 &self,
718 wave: &mut Wave,
719 inst: &DecodedInstruction,
720 local_memory: &mut LocalMemory,
721 stats: &mut ExecutionStats,
722 ) -> Result<(), EmulatorError> {
723 let width = match inst.modifier {
724 m if m == MemWidth::U8 as u8 => 1,
725 m if m == MemWidth::U16 as u8 => 2,
726 m if m == MemWidth::U32 as u8 => 4,
727 m if m == MemWidth::U64 as u8 => 8,
728 _ => 4,
729 };
730
731 for lane in 0..wave.wave_width {
732 if !wave.is_thread_active(lane) {
733 continue;
734 }
735
736 let thread = &wave.threads[lane as usize];
737 let addr = thread.read_register(inst.rs1);
738 let value = thread.read_register(inst.rs2);
739
740 match width {
741 1 => local_memory.write_u8(addr, value as u8)?,
742 2 => local_memory.write_u16(addr, value as u16)?,
743 4 => local_memory.write_u32(addr, value)?,
744 8 => {
745 let hi = thread.read_register(inst.rs2 + 1);
746 let val = (u64::from(hi) << 32) | u64::from(value);
747 local_memory.write_u64(addr, val)?;
748 }
749 _ => {}
750 }
751
752 stats.record_local_store(width as u64);
753 }
754
755 stats.record_instruction(InstructionCategory::Memory);
756 Ok(())
757 }
758
759 fn execute_device_load(
760 &self,
761 wave: &mut Wave,
762 inst: &DecodedInstruction,
763 device_memory: &mut DeviceMemory,
764 stats: &mut ExecutionStats,
765 ) -> Result<(), EmulatorError> {
766 let width = match inst.modifier {
767 m if m == MemWidth::U8 as u8 => 1,
768 m if m == MemWidth::U16 as u8 => 2,
769 m if m == MemWidth::U32 as u8 => 4,
770 m if m == MemWidth::U64 as u8 => 8,
771 m if m == MemWidth::U128 as u8 => 16,
772 _ => 4,
773 };
774
775 for lane in 0..wave.wave_width {
776 if !wave.is_thread_active(lane) {
777 continue;
778 }
779
780 let thread = &mut wave.threads[lane as usize];
781 let addr = u64::from(thread.read_register(inst.rs1));
782
783 match width {
784 1 => {
785 let val = device_memory.read_u8(addr)?;
786 thread.write_register(inst.rd, u32::from(val));
787 }
788 2 => {
789 let val = device_memory.read_u16(addr)?;
790 thread.write_register(inst.rd, u32::from(val));
791 }
792 4 => {
793 let val = device_memory.read_u32(addr)?;
794 thread.write_register(inst.rd, val);
795 }
796 8 => {
797 let val = device_memory.read_u64(addr)?;
798 thread.write_register(inst.rd, val as u32);
799 thread.write_register(inst.rd + 1, (val >> 32) as u32);
800 }
801 16 => {
802 let val = device_memory.read_u128(addr)?;
803 thread.write_register(inst.rd, val as u32);
804 thread.write_register(inst.rd + 1, (val >> 32) as u32);
805 thread.write_register(inst.rd + 2, (val >> 64) as u32);
806 thread.write_register(inst.rd + 3, (val >> 96) as u32);
807 }
808 _ => {}
809 }
810
811 stats.record_device_load(width as u64);
812 }
813
814 stats.record_instruction(InstructionCategory::Memory);
815 Ok(())
816 }
817
818 fn execute_device_store(
819 &self,
820 wave: &mut Wave,
821 inst: &DecodedInstruction,
822 device_memory: &mut DeviceMemory,
823 stats: &mut ExecutionStats,
824 ) -> Result<(), EmulatorError> {
825 let width = match inst.modifier {
826 m if m == MemWidth::U8 as u8 => 1,
827 m if m == MemWidth::U16 as u8 => 2,
828 m if m == MemWidth::U32 as u8 => 4,
829 m if m == MemWidth::U64 as u8 => 8,
830 m if m == MemWidth::U128 as u8 => 16,
831 _ => 4,
832 };
833
834 for lane in 0..wave.wave_width {
835 if !wave.is_thread_active(lane) {
836 continue;
837 }
838
839 let thread = &wave.threads[lane as usize];
840 let addr = u64::from(thread.read_register(inst.rs1));
841 let value = thread.read_register(inst.rs2);
842
843 match width {
844 1 => device_memory.write_u8(addr, value as u8)?,
845 2 => device_memory.write_u16(addr, value as u16)?,
846 4 => device_memory.write_u32(addr, value)?,
847 8 => {
848 let hi = thread.read_register(inst.rs2 + 1);
849 let val = (u64::from(hi) << 32) | u64::from(value);
850 device_memory.write_u64(addr, val)?;
851 }
852 16 => {
853 let w0 = value;
854 let w1 = thread.read_register(inst.rs2 + 1);
855 let w2 = thread.read_register(inst.rs2 + 2);
856 let w3 = thread.read_register(inst.rs2 + 3);
857 let val = u128::from(w0)
858 | (u128::from(w1) << 32)
859 | (u128::from(w2) << 64)
860 | (u128::from(w3) << 96);
861 device_memory.write_u128(addr, val)?;
862 }
863 _ => {}
864 }
865
866 stats.record_device_store(width as u64);
867 }
868
869 stats.record_instruction(InstructionCategory::Memory);
870 Ok(())
871 }
872
873 fn execute_local_atomic(
874 &self,
875 wave: &mut Wave,
876 inst: &DecodedInstruction,
877 local_memory: &mut LocalMemory,
878 stats: &mut ExecutionStats,
879 ) -> Result<(), EmulatorError> {
880 let non_returning = inst.is_non_returning_atomic();
881
882 for lane in 0..wave.wave_width {
883 if !wave.is_thread_active(lane) {
884 continue;
885 }
886
887 let thread = &mut wave.threads[lane as usize];
888 let addr = thread.read_register(inst.rs1);
889 let value = thread.read_register(inst.rs2);
890
891 let old = match inst.modifier {
892 m if m == AtomicOp::Add as u8 => local_memory.atomic_add(addr, value)?,
893 m if m == AtomicOp::Sub as u8 => local_memory.atomic_sub(addr, value)?,
894 m if m == AtomicOp::Min as u8 => local_memory.atomic_min(addr, value)?,
895 m if m == AtomicOp::Max as u8 => local_memory.atomic_max(addr, value)?,
896 m if m == AtomicOp::And as u8 => local_memory.atomic_and(addr, value)?,
897 m if m == AtomicOp::Or as u8 => local_memory.atomic_or(addr, value)?,
898 m if m == AtomicOp::Xor as u8 => local_memory.atomic_xor(addr, value)?,
899 m if m == AtomicOp::Exchange as u8 => local_memory.atomic_exchange(addr, value)?,
900 _ => {
901 let expected = thread.read_register(inst.rs3);
902 local_memory.atomic_cas(addr, expected, value)?
903 }
904 };
905
906 if !non_returning {
907 thread.write_register(inst.rd, old);
908 }
909
910 stats.atomic_ops += 1;
911 }
912
913 stats.record_instruction(InstructionCategory::Atomic);
914 Ok(())
915 }
916
917 fn execute_device_atomic(
918 &self,
919 wave: &mut Wave,
920 inst: &DecodedInstruction,
921 device_memory: &mut DeviceMemory,
922 stats: &mut ExecutionStats,
923 ) -> Result<(), EmulatorError> {
924 let non_returning = inst.is_non_returning_atomic();
925
926 for lane in 0..wave.wave_width {
927 if !wave.is_thread_active(lane) {
928 continue;
929 }
930
931 let thread = &mut wave.threads[lane as usize];
932 let addr = u64::from(thread.read_register(inst.rs1));
933 let value = thread.read_register(inst.rs2);
934
935 let old = match inst.modifier {
936 m if m == AtomicOp::Add as u8 => device_memory.atomic_add(addr, value)?,
937 m if m == AtomicOp::Sub as u8 => device_memory.atomic_sub(addr, value)?,
938 m if m == AtomicOp::Min as u8 => device_memory.atomic_min(addr, value)?,
939 m if m == AtomicOp::Max as u8 => device_memory.atomic_max(addr, value)?,
940 m if m == AtomicOp::And as u8 => device_memory.atomic_and(addr, value)?,
941 m if m == AtomicOp::Or as u8 => device_memory.atomic_or(addr, value)?,
942 m if m == AtomicOp::Xor as u8 => device_memory.atomic_xor(addr, value)?,
943 m if m == AtomicOp::Exchange as u8 => device_memory.atomic_exchange(addr, value)?,
944 _ => {
945 let expected = thread.read_register(inst.rs3);
946 device_memory.atomic_cas(addr, expected, value)?
947 }
948 };
949
950 if !non_returning {
951 thread.write_register(inst.rd, old);
952 }
953
954 stats.atomic_ops += 1;
955 }
956
957 stats.record_instruction(InstructionCategory::Atomic);
958 Ok(())
959 }
960
961 fn execute_wave_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
962 if inst.is_wave_reduce() {
963 let reduce_mod = inst.modifier - 8;
964 match reduce_mod {
965 m if m == WaveReduceType::PrefixSum as u8 => {
966 shuffle::wave_prefix_sum(wave, inst.rd, inst.rs1);
967 }
968 m if m == WaveReduceType::ReduceAdd as u8 => {
969 shuffle::wave_reduce_add(wave, inst.rd, inst.rs1);
970 }
971 m if m == WaveReduceType::ReduceMin as u8 => {
972 shuffle::wave_reduce_min(wave, inst.rd, inst.rs1);
973 }
974 m if m == WaveReduceType::ReduceMax as u8 => {
975 shuffle::wave_reduce_max(wave, inst.rd, inst.rs1);
976 }
977 _ => {}
978 }
979 } else {
980 match inst.modifier {
981 m if m == WaveOpType::Shuffle as u8 => {
982 shuffle::wave_shuffle(wave, inst.rd, inst.rs1, inst.rs2);
983 }
984 m if m == WaveOpType::ShuffleUp as u8 => {
985 shuffle::wave_shuffle_up(wave, inst.rd, inst.rs1, inst.rs2);
986 }
987 m if m == WaveOpType::ShuffleDown as u8 => {
988 shuffle::wave_shuffle_down(wave, inst.rd, inst.rs1, inst.rs2);
989 }
990 m if m == WaveOpType::ShuffleXor as u8 => {
991 shuffle::wave_shuffle_xor(wave, inst.rd, inst.rs1, inst.rs2);
992 }
993 m if m == WaveOpType::Broadcast as u8 => {
994 shuffle::wave_broadcast(wave, inst.rd, inst.rs1, inst.rs2);
995 }
996 m if m == WaveOpType::Ballot as u8 => {
997 shuffle::wave_ballot(wave, inst.rd, inst.rs1);
998 }
999 m if m == WaveOpType::Any as u8 => {
1000 shuffle::wave_any(wave, inst.rd, inst.rs1);
1001 }
1002 m if m == WaveOpType::All as u8 => {
1003 shuffle::wave_all(wave, inst.rd, inst.rs1);
1004 }
1005 _ => {}
1006 }
1007 }
1008 }
1009
1010 fn execute_control(
1011 &mut self,
1012 wave: &mut Wave,
1013 inst: &DecodedInstruction,
1014 stats: &mut ExecutionStats,
1015 ) -> Result<ExecuteResult, EmulatorError> {
1016 stats.record_instruction(InstructionCategory::Control);
1017
1018 if inst.is_sync_op() {
1019 return self.execute_sync_op(wave, inst);
1020 }
1021
1022 if inst.is_misc_op() {
1023 return self.execute_misc_op(wave, inst);
1024 }
1025
1026 self.execute_control_flow(wave, inst, stats)
1027 }
1028
1029 fn execute_sync_op(
1030 &self,
1031 wave: &mut Wave,
1032 inst: &DecodedInstruction,
1033 ) -> Result<ExecuteResult, EmulatorError> {
1034 match inst.modifier {
1035 m if m == SyncOp::Return as u8 => {
1036 if let Some(return_pc) = wave.pop_call() {
1037 Ok(ExecuteResult::Jump(return_pc))
1038 } else {
1039 Ok(ExecuteResult::Halt)
1040 }
1041 }
1042 m if m == SyncOp::Halt as u8 => Ok(ExecuteResult::Halt),
1043 m if m == SyncOp::Barrier as u8 => Ok(ExecuteResult::Barrier),
1044 m if m == SyncOp::Nop as u8 || m == SyncOp::Wait as u8 => Ok(ExecuteResult::Continue),
1045 _ => Ok(ExecuteResult::Continue),
1046 }
1047 }
1048
1049 fn execute_misc_op(
1050 &self,
1051 wave: &mut Wave,
1052 inst: &DecodedInstruction,
1053 ) -> Result<ExecuteResult, EmulatorError> {
1054 match inst.modifier {
1055 m if m == MiscOp::Mov as u8 => {
1056 for lane in 0..wave.wave_width {
1057 if wave.is_thread_active(lane) {
1058 let thread = &mut wave.threads[lane as usize];
1059 let value = thread.read_register(inst.rs1);
1060 thread.write_register(inst.rd, value);
1061 }
1062 }
1063 }
1064 m if m == MiscOp::MovImm as u8 => {
1065 for lane in 0..wave.wave_width {
1066 if wave.is_thread_active(lane) {
1067 wave.threads[lane as usize].write_register(inst.rd, inst.immediate);
1068 }
1069 }
1070 }
1071 m if m == MiscOp::MovSr as u8 => {
1072 for lane in 0..wave.wave_width {
1073 if wave.is_thread_active(lane) {
1074 let thread = &mut wave.threads[lane as usize];
1075 let value = thread.read_special(inst.rs1);
1076 thread.write_register(inst.rd, value);
1077 }
1078 }
1079 }
1080 _ => {}
1081 }
1082 Ok(ExecuteResult::Continue)
1083 }
1084
1085 fn execute_control_flow(
1086 &mut self,
1087 wave: &mut Wave,
1088 inst: &DecodedInstruction,
1089 stats: &mut ExecutionStats,
1090 ) -> Result<ExecuteResult, EmulatorError> {
1091 match inst.modifier {
1092 m if m == ControlOp::If as u8 => {
1093 let mut pred_mask: u64 = 0;
1094 for lane in 0..wave.wave_width {
1095 if wave.is_thread_active(lane) {
1096 if wave.threads[lane as usize].read_predicate(inst.rs1) {
1097 pred_mask |= 1u64 << lane;
1098 }
1099 }
1100 }
1101
1102 let then_mask = wave.active_mask & pred_mask;
1103 let else_mask = wave.active_mask & !pred_mask;
1104
1105 if then_mask != wave.active_mask && else_mask != 0 {
1106 stats.record_divergent_branch();
1107 }
1108
1109 let (new_mask, _) = wave.control_flow.handle_if(wave.active_mask, pred_mask)?;
1110 wave.active_mask = new_mask;
1111 Ok(ExecuteResult::Continue)
1112 }
1113 m if m == ControlOp::Else as u8 => {
1114 let (new_mask, _) = wave.control_flow.handle_else(wave.active_mask)?;
1115 wave.active_mask = new_mask;
1116 Ok(ExecuteResult::Continue)
1117 }
1118 m if m == ControlOp::Endif as u8 => {
1119 let new_mask = wave.control_flow.handle_endif()?;
1120 wave.active_mask = new_mask;
1121 Ok(ExecuteResult::Continue)
1122 }
1123 m if m == ControlOp::Loop as u8 => {
1124 let body_start = wave.pc + inst.size;
1125 let new_mask = wave
1126 .control_flow
1127 .handle_loop(wave.active_mask, body_start)?;
1128 wave.active_mask = new_mask;
1129 Ok(ExecuteResult::Continue)
1130 }
1131 m if m == ControlOp::Break as u8 => {
1132 let mut pred_mask: u64 = 0;
1133 for lane in 0..wave.wave_width {
1134 if wave.is_thread_active(lane) {
1135 if wave.threads[lane as usize].read_predicate(inst.rs1) {
1136 pred_mask |= 1u64 << lane;
1137 }
1138 }
1139 }
1140
1141 if self.trace.is_enabled() {
1142 eprintln!(
1143 " BREAK: active_mask=0x{:x}, pred_mask=0x{:x}, pred_reg=p{}",
1144 wave.active_mask, pred_mask, inst.rs1
1145 );
1146 }
1147
1148 let (new_mask, jump) = wave
1149 .control_flow
1150 .handle_break(wave.active_mask, pred_mask)?;
1151 wave.active_mask = new_mask;
1152
1153 if self.trace.is_enabled() {
1154 eprintln!(" BREAK: new_active_mask=0x{new_mask:x}, jump={jump:?}");
1155 }
1156
1157 Ok(ExecuteResult::Continue)
1158 }
1159 m if m == ControlOp::Continue as u8 => {
1160 let mut pred_mask: u64 = 0;
1161 for lane in 0..wave.wave_width {
1162 if wave.is_thread_active(lane) {
1163 if wave.threads[lane as usize].read_predicate(inst.rs1) {
1164 pred_mask |= 1u64 << lane;
1165 }
1166 }
1167 }
1168
1169 let (new_mask, jump) = wave
1170 .control_flow
1171 .handle_continue(wave.active_mask, pred_mask)?;
1172 wave.active_mask = new_mask;
1173 if let Some(target) = jump {
1174 Ok(ExecuteResult::Jump(target))
1175 } else {
1176 Ok(ExecuteResult::Continue)
1177 }
1178 }
1179 m if m == ControlOp::Endloop as u8 => {
1180 if self.trace.is_enabled() {
1181 eprintln!(" ENDLOOP: active_mask=0x{:x}", wave.active_mask);
1182 }
1183
1184 let (new_mask, jump) = wave.control_flow.handle_endloop(wave.active_mask)?;
1185 wave.active_mask = new_mask;
1186
1187 if self.trace.is_enabled() {
1188 eprintln!(" ENDLOOP: new_active_mask=0x{new_mask:x}, jump={jump:?}");
1189 }
1190
1191 if let Some(target) = jump {
1192 Ok(ExecuteResult::Jump(target))
1193 } else {
1194 Ok(ExecuteResult::Continue)
1195 }
1196 }
1197 m if m == ControlOp::Call as u8 => {
1198 let return_pc = wave.pc + inst.size;
1199 wave.push_call(return_pc)
1200 .map_err(|_| EmulatorError::StackOverflow {
1201 kind: "call".into(),
1202 })?;
1203 Ok(ExecuteResult::Jump(inst.immediate))
1204 }
1205 _ => Ok(ExecuteResult::Continue),
1206 }
1207 }
1208}
1209
/// Outcome of a single `Executor::step` call, as reported to the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StepResult {
    /// An instruction was executed and the wave's PC was updated.
    Continue,
    /// The wave is halted. It may already have been halted, have no active lanes
    /// and no pending control flow, or have executed a halting instruction.
    Halted,
    /// The wave executed a barrier instruction; its PC was not advanced.
    Barrier,
}
1216
/// Per-instruction outcome produced by the `execute_*` handlers. `Executor::step`
/// translates it into a PC update and a `StepResult`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ExecuteResult {
    /// Advance the PC past the current instruction.
    Continue,
    /// Set the PC to the contained address.
    Jump(u32),
    /// Halt the wave.
    Halt,
    /// Report a barrier without advancing the PC.
    Barrier,
}
1224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::MISC_OP_FLAG;

    /// Packs a base-format instruction word and returns its little-endian bytes.
    ///
    /// Layout (bit ranges, MSB first): opcode `[31:26]`, rd `[25:21]`,
    /// rs1 `[20:16]`, rs2 `[15:11]`, modifier `[10:7]`, flags `[1:0]`. Each field
    /// is masked to its width so out-of-range arguments cannot bleed into
    /// neighbouring fields.
    fn encode_base(opcode: u8, rd: u8, rs1: u8, rs2: u8, modifier: u8, flags: u8) -> Vec<u8> {
        let word = ((u32::from(opcode) & 0x3F) << 26)
            | ((u32::from(rd) & 0x1F) << 21)
            | ((u32::from(rs1) & 0x1F) << 16)
            | ((u32::from(rs2) & 0x1F) << 11)
            | ((u32::from(modifier) & 0x0F) << 7)
            | (u32::from(flags) & 0x03);
        word.to_le_bytes().to_vec()
    }

    /// Encodes an extended-format instruction: the base word produced by
    /// `encode_base`, followed by a 32-bit little-endian immediate.
    fn encode_extended(
        opcode: u8,
        rd: u8,
        rs1: u8,
        rs2: u8,
        modifier: u8,
        flags: u8,
        imm: u32,
    ) -> Vec<u8> {
        let mut code = encode_base(opcode, rd, rs1, rs2, modifier, flags);
        code.extend_from_slice(&imm.to_le_bytes());
        code
    }

    /// Builds the four-thread wave shared by every test in this module.
    fn four_lane_wave() -> Wave {
        Wave::new(4, 32, 0, [0, 0, 0], [4, 1, 1], [1, 1, 1], 0, 4, 1)
    }

    /// Executes exactly one instruction from `code` on `wave`, using fresh local
    /// memory, device memory, and statistics, and returns the step outcome.
    fn step_once(code: &[u8], wave: &mut Wave) -> StepResult {
        let mut executor = Executor::new(code, false, [0, 0, 0]);
        let mut local_memory = LocalMemory::new(1024);
        let mut device_memory = DeviceMemory::new(1024);
        let mut stats = ExecutionStats::new();
        executor
            .step(wave, &mut local_memory, &mut device_memory, &mut stats)
            .expect("single step should succeed")
    }

    #[test]
    fn test_executor_iadd() {
        // rd=3, rs1=1, rs2=2: every active lane computes r3 = r1 + r2.
        let code = encode_base(0x00, 3, 1, 2, 0, 0);
        let mut wave = four_lane_wave();

        for i in 0..4 {
            wave.threads[i].write_register(1, 10);
            wave.threads[i].write_register(2, 20);
        }

        assert_eq!(step_once(&code, &mut wave), StepResult::Continue);

        for i in 0..4 {
            assert_eq!(wave.threads[i].read_register(3), 30);
        }
    }

    #[test]
    fn test_executor_mov_imm() {
        let code = encode_extended(
            0x3F,
            5,
            0,
            0,
            MiscOp::MovImm as u8,
            MISC_OP_FLAG as u8,
            0xDEAD_BEEF,
        );
        let mut wave = four_lane_wave();

        assert_eq!(step_once(&code, &mut wave), StepResult::Continue);

        for i in 0..4 {
            assert_eq!(wave.threads[i].read_register(5), 0xDEAD_BEEF);
        }
    }

    #[test]
    fn test_executor_respects_active_mask() {
        let code = encode_base(0x00, 3, 1, 2, 0, 0);
        let mut wave = four_lane_wave();

        // Only lanes 0 and 2 are active; lanes 1 and 3 must keep their old r3.
        wave.active_mask = 0b0101;

        for i in 0..4 {
            wave.threads[i].write_register(1, 10);
            wave.threads[i].write_register(2, 20);
            wave.threads[i].write_register(3, 0);
        }

        assert_eq!(step_once(&code, &mut wave), StepResult::Continue);

        assert_eq!(wave.threads[0].read_register(3), 30);
        assert_eq!(wave.threads[1].read_register(3), 0);
        assert_eq!(wave.threads[2].read_register(3), 30);
        assert_eq!(wave.threads[3].read_register(3), 0);
    }
}