Skip to main content

wave_emu/
thread.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! Per-thread execution state. Each thread has a register file of `u32`
//! registers (sized at construction, typically 32), four predicate registers
//! (p0-p3), and read-only special registers populated at dispatch time
//! (thread/wave/workgroup IDs, dimensions, etc.).
8
9#[derive(Debug, Clone)]
10pub struct Thread {
11    pub registers: Vec<u32>,
12    pub predicates: [bool; 4],
13    pub special_registers: SpecialRegisters,
14}
15
/// Per-thread special-register values populated at dispatch time.
///
/// Programs read these by index through [`SpecialRegisters::get`]. The indices
/// follow field declaration order, with each array component taking its own
/// index (0-15 in total). The `Default` value is all zeros.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SpecialRegisters {
    /// Thread ID components (special registers 0-2). Whether these are
    /// workgroup-local or global is decided by the dispatcher — TODO confirm.
    pub thread_id: [u32; 3],
    /// Wave ID (special register 3).
    pub wave_id: u32,
    /// Lane ID, presumably the thread's position within its wave (special register 4).
    pub lane_id: u32,
    /// Workgroup ID components (special registers 5-7).
    pub workgroup_id: [u32; 3],
    /// Workgroup dimensions (special registers 8-10).
    pub workgroup_size: [u32; 3],
    /// Grid dimensions (special registers 11-13). Units (workgroups vs. threads)
    /// are set by the dispatcher — confirm against the caller.
    pub grid_size: [u32; 3],
    /// Wave width, presumably lanes per wave (special register 14).
    pub wave_width: u32,
    /// Wave count (special register 15). Its scope (per workgroup vs. per
    /// dispatch) is not established here.
    pub num_waves: u32,
}
27
28impl SpecialRegisters {
29    pub fn get(&self, index: u8) -> u32 {
30        match index {
31            0 => self.thread_id[0],
32            1 => self.thread_id[1],
33            2 => self.thread_id[2],
34            3 => self.wave_id,
35            4 => self.lane_id,
36            5 => self.workgroup_id[0],
37            6 => self.workgroup_id[1],
38            7 => self.workgroup_id[2],
39            8 => self.workgroup_size[0],
40            9 => self.workgroup_size[1],
41            10 => self.workgroup_size[2],
42            11 => self.grid_size[0],
43            12 => self.grid_size[1],
44            13 => self.grid_size[2],
45            14 => self.wave_width,
46            15 => self.num_waves,
47            _ => 0,
48        }
49    }
50}
51
52impl Thread {
53    pub fn new(register_count: u32) -> Self {
54        Self {
55            registers: vec![0; register_count as usize],
56            predicates: [false; 4],
57            special_registers: SpecialRegisters::default(),
58        }
59    }
60
61    pub fn with_special_registers(register_count: u32, special: SpecialRegisters) -> Self {
62        Self {
63            registers: vec![0; register_count as usize],
64            predicates: [false; 4],
65            special_registers: special,
66        }
67    }
68
69    pub fn read_register(&self, index: u8) -> u32 {
70        self.registers.get(index as usize).copied().unwrap_or(0)
71    }
72
73    pub fn write_register(&mut self, index: u8, value: u32) {
74        if (index as usize) < self.registers.len() {
75            self.registers[index as usize] = value;
76        }
77    }
78
79    pub fn read_predicate(&self, index: u8) -> bool {
80        self.predicates
81            .get(index as usize)
82            .copied()
83            .unwrap_or(false)
84    }
85
86    pub fn write_predicate(&mut self, index: u8, value: bool) {
87        if (index as usize) < self.predicates.len() {
88            self.predicates[index as usize] = value;
89        }
90    }
91
92    pub fn read_special(&self, index: u8) -> u32 {
93        self.special_registers.get(index)
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds special registers whose value at index `i` is `i + 1`, so every
    /// mapping can be checked unambiguously.
    fn sequential_special() -> SpecialRegisters {
        SpecialRegisters {
            thread_id: [1, 2, 3],
            wave_id: 4,
            lane_id: 5,
            workgroup_id: [6, 7, 8],
            workgroup_size: [9, 10, 11],
            grid_size: [12, 13, 14],
            wave_width: 15,
            num_waves: 16,
        }
    }

    #[test]
    fn test_thread_new() {
        let thread = Thread::new(32);
        assert_eq!(thread.registers.len(), 32);
        assert_eq!(thread.predicates, [false; 4]);
    }

    #[test]
    fn test_thread_new_defaults_special_registers() {
        let thread = Thread::new(32);
        for index in 0u8..16 {
            assert_eq!(thread.read_special(index), 0, "special register {}", index);
        }
    }

    #[test]
    fn test_thread_register_read_write() {
        let mut thread = Thread::new(32);
        thread.write_register(5, 0x12345678);
        assert_eq!(thread.read_register(5), 0x12345678);
    }

    #[test]
    fn test_thread_predicate_read_write() {
        let mut thread = Thread::new(32);
        thread.write_predicate(2, true);
        assert!(thread.read_predicate(2));
        assert!(!thread.read_predicate(0));
    }

    #[test]
    fn test_thread_special_registers() {
        let special = SpecialRegisters {
            thread_id: [10, 20, 30],
            wave_id: 2,
            lane_id: 15,
            workgroup_id: [1, 2, 3],
            workgroup_size: [64, 1, 1],
            grid_size: [4, 4, 1],
            wave_width: 32,
            num_waves: 2,
        };
        let thread = Thread::with_special_registers(32, special);

        assert_eq!(thread.read_special(0), 10);
        assert_eq!(thread.read_special(1), 20);
        assert_eq!(thread.read_special(2), 30);
        assert_eq!(thread.read_special(3), 2);
        assert_eq!(thread.read_special(4), 15);
        assert_eq!(thread.read_special(5), 1);
        assert_eq!(thread.read_special(14), 32);
    }

    #[test]
    fn test_thread_special_registers_all_indices() {
        let thread = Thread::with_special_registers(0, sequential_special());
        for index in 0u8..16 {
            assert_eq!(
                thread.read_special(index),
                u32::from(index) + 1,
                "special register {}",
                index
            );
        }
        // Unmapped indices read as zero.
        assert_eq!(thread.read_special(16), 0);
        assert_eq!(thread.read_special(u8::MAX), 0);
    }

    #[test]
    fn test_thread_out_of_bounds_register() {
        let thread = Thread::new(8);
        assert_eq!(thread.read_register(100), 0);
    }

    #[test]
    fn test_thread_out_of_bounds_register_write_is_ignored() {
        let mut thread = Thread::new(8);
        thread.write_register(8, 0xdead_beef);
        assert_eq!(thread.registers, vec![0; 8]);
        assert_eq!(thread.read_register(8), 0);
    }

    #[test]
    fn test_thread_out_of_bounds_predicate() {
        let mut thread = Thread::new(1);
        thread.write_predicate(4, true);
        assert_eq!(thread.predicates, [false; 4]);
        assert!(!thread.read_predicate(4));
        assert!(!thread.read_predicate(u8::MAX));
    }

    #[test]
    fn test_thread_zero_registers() {
        let mut thread = Thread::new(0);
        thread.write_register(0, 1);
        assert!(thread.registers.is_empty());
        assert_eq!(thread.read_register(0), 0);
    }
}