1#[derive(Debug, Clone)]
10pub struct Thread {
11 pub registers: Vec<u32>,
12 pub predicates: [bool; 4],
13 pub special_registers: SpecialRegisters,
14}
15
/// Special registers visible to a thread, readable by numeric index through
/// `SpecialRegisters::get`.
///
/// The `Default` value has every field set to zero. `PartialEq`/`Eq`/`Hash`
/// are derived so snapshots can be compared and used as map/set keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SpecialRegisters {
    /// Thread identifier, three components (special indices 0-2).
    pub thread_id: [u32; 3],
    /// Wave identifier (special index 3).
    pub wave_id: u32,
    /// Lane identifier (special index 4).
    pub lane_id: u32,
    /// Workgroup identifier, three components (special indices 5-7).
    pub workgroup_id: [u32; 3],
    /// Workgroup size, three components (special indices 8-10).
    pub workgroup_size: [u32; 3],
    /// Grid size, three components (special indices 11-13).
    pub grid_size: [u32; 3],
    /// Wave width (special index 14).
    pub wave_width: u32,
    /// Number of waves (special index 15).
    pub num_waves: u32,
}
27
28impl SpecialRegisters {
29 pub fn get(&self, index: u8) -> u32 {
30 match index {
31 0 => self.thread_id[0],
32 1 => self.thread_id[1],
33 2 => self.thread_id[2],
34 3 => self.wave_id,
35 4 => self.lane_id,
36 5 => self.workgroup_id[0],
37 6 => self.workgroup_id[1],
38 7 => self.workgroup_id[2],
39 8 => self.workgroup_size[0],
40 9 => self.workgroup_size[1],
41 10 => self.workgroup_size[2],
42 11 => self.grid_size[0],
43 12 => self.grid_size[1],
44 13 => self.grid_size[2],
45 14 => self.wave_width,
46 15 => self.num_waves,
47 _ => 0,
48 }
49 }
50}
51
52impl Thread {
53 pub fn new(register_count: u32) -> Self {
54 Self {
55 registers: vec![0; register_count as usize],
56 predicates: [false; 4],
57 special_registers: SpecialRegisters::default(),
58 }
59 }
60
61 pub fn with_special_registers(register_count: u32, special: SpecialRegisters) -> Self {
62 Self {
63 registers: vec![0; register_count as usize],
64 predicates: [false; 4],
65 special_registers: special,
66 }
67 }
68
69 pub fn read_register(&self, index: u8) -> u32 {
70 self.registers.get(index as usize).copied().unwrap_or(0)
71 }
72
73 pub fn write_register(&mut self, index: u8, value: u32) {
74 if (index as usize) < self.registers.len() {
75 self.registers[index as usize] = value;
76 }
77 }
78
79 pub fn read_predicate(&self, index: u8) -> bool {
80 self.predicates
81 .get(index as usize)
82 .copied()
83 .unwrap_or(false)
84 }
85
86 pub fn write_predicate(&mut self, index: u8, value: bool) {
87 if (index as usize) < self.predicates.len() {
88 self.predicates[index as usize] = value;
89 }
90 }
91
92 pub fn read_special(&self, index: u8) -> u32 {
93 self.special_registers.get(index)
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Special-register fixture with a distinct value in each field component
    /// where possible, so index-mapping mistakes show up in assertions.
    fn sample_special() -> SpecialRegisters {
        SpecialRegisters {
            thread_id: [10, 20, 30],
            wave_id: 2,
            lane_id: 15,
            workgroup_id: [1, 2, 3],
            workgroup_size: [64, 1, 1],
            grid_size: [4, 4, 1],
            wave_width: 32,
            num_waves: 2,
        }
    }

    #[test]
    fn test_thread_new() {
        let thread = Thread::new(32);
        assert_eq!(thread.registers.len(), 32);
        assert_eq!(thread.predicates, [false; 4]);
    }

    #[test]
    fn test_thread_new_special_registers_are_zero() {
        let thread = Thread::new(4);
        for index in 0u8..=15 {
            assert_eq!(thread.read_special(index), 0, "special register {}", index);
        }
    }

    #[test]
    fn test_thread_register_read_write() {
        let mut thread = Thread::new(32);
        thread.write_register(5, 0x12345678);
        assert_eq!(thread.read_register(5), 0x12345678);
    }

    #[test]
    fn test_thread_predicate_read_write() {
        let mut thread = Thread::new(32);
        thread.write_predicate(2, true);
        assert!(thread.read_predicate(2));
        assert!(!thread.read_predicate(0));
    }

    #[test]
    fn test_thread_special_registers() {
        let thread = Thread::with_special_registers(32, sample_special());

        // Expected value for every defined special index, in index order.
        let expected: [u32; 16] = [10, 20, 30, 2, 15, 1, 2, 3, 64, 1, 1, 4, 4, 1, 32, 2];
        for (index, want) in (0u8..).zip(expected) {
            assert_eq!(thread.read_special(index), want, "special register {}", index);
        }
    }

    #[test]
    fn test_unknown_special_register_reads_zero() {
        let thread = Thread::with_special_registers(32, sample_special());
        assert_eq!(thread.read_special(16), 0);
        assert_eq!(thread.read_special(u8::MAX), 0);
    }

    #[test]
    fn test_thread_out_of_bounds_register() {
        let thread = Thread::new(8);
        assert_eq!(thread.read_register(100), 0);
    }

    #[test]
    fn test_thread_out_of_bounds_register_write_is_ignored() {
        let mut thread = Thread::new(8);
        // Must not panic, and must not disturb in-range registers.
        thread.write_register(100, 0xdead_beef);
        assert_eq!(thread.read_register(100), 0);
        assert_eq!(thread.registers.len(), 8);
        assert!(thread.registers.iter().all(|&r| r == 0));
    }

    #[test]
    fn test_thread_out_of_bounds_predicate_is_ignored() {
        let mut thread = Thread::new(8);
        thread.write_predicate(4, true);
        assert!(!thread.read_predicate(4));
        assert_eq!(thread.predicates, [false; 4]);
    }
}