vortex_gpu_kernels/
bit_unpack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fs::File;
5use std::io::Write;
6use std::path::Path;
7
8use fastlanes::FastLanes;
9
10use crate::indent::IndentedWriter;
11
12fn generate_lane_decoder<T: FastLanes, W: Write>(
13    output: &mut IndentedWriter<W>,
14    bit_width: usize,
15) -> anyhow::Result<()> {
16    let bits = <T>::T;
17    let lanes = T::LANES;
18
19    let func_name = format!("fls_unpack_{bit_width}bw_{bits}ow_lane");
20
21    writeln!(
22        output,
23        "__device__ void _{func_name}(const uint{bits}_t *__restrict in, uint{bits}_t *__restrict out, unsigned int lane) {{"
24    )?;
25
26    output.indent(|output| {
27        writeln!(output, "unsigned int LANE_COUNT = {lanes};")?;
28        if bit_width == 0 {
29            writeln!(output, "uint{bits}_t zero = 0ULL;")?;
30            writeln!(output)?;
31            for row in 0..bits {
32                writeln!(output, "out[INDEX({row}, lane)] = zero;")?;
33            }
34        } else if bit_width == bits {
35            writeln!(output)?;
36            for row in 0..bits {
37                writeln!(
38                    output,
39                    "out[INDEX({row}, lane)] = in[LANE_COUNT * {row} + lane];",
40                )?;
41            }
42        } else {
43            writeln!(output, "uint{bits}_t src;")?;
44            writeln!(output, "uint{bits}_t tmp;")?;
45
46            writeln!(output)?;
47            writeln!(output, "src = in[lane];")?;
48            for row in 0..bits {
49                let curr_word = (row * bit_width) / bits;
50                let next_word = ((row + 1) * bit_width) / bits;
51                let shift = (row * bit_width) % bits;
52
53                if next_word > curr_word {
54                    let remaining_bits = ((row + 1) * bit_width) % bits;
55                    let current_bits = bit_width - remaining_bits;
56                    writeln!(
57                        output,
58                        "tmp = (src >> {shift}) & MASK(uint{bits}_t, {current_bits});"
59                    )?;
60
61                    if next_word < bit_width {
62                        writeln!(output, "src = in[lane + LANE_COUNT * {next_word}];")?;
63                        writeln!(
64                            output,
65                            "tmp |= (src & MASK(uint{bits}_t, {remaining_bits})) << {current_bits};"
66                        )?;
67                    }
68                } else {
69                    writeln!(
70                        output,
71                        "tmp = (src >> {shift}) & MASK(uint{bits}_t, {bit_width});"
72                    )?;
73                }
74
75                writeln!(output, "out[INDEX({row}, lane)] = tmp;")?;
76            }
77        }
78        Ok(())
79    })?;
80
81    Ok(writeln!(output, "}}")?)
82}
83
84fn generate_device_kernel_for_width<T: FastLanes, W: Write>(
85    output: &mut IndentedWriter<W>,
86    bit_width: usize,
87    thread_count: usize,
88) -> anyhow::Result<()> {
89    let bits = <T>::T;
90    let lanes = T::LANES;
91    let per_thread_loop_count = lanes / thread_count;
92
93    let func_name = format!("fls_unpack_{bit_width}bw_{bits}ow_{thread_count}t");
94
95    let local_func_params = format!(
96        "(const uint{bits}_t *__restrict in, uint{bits}_t *__restrict out, int thread_idx)"
97    );
98
99    writeln!(output, "__device__ void _{func_name}{local_func_params} {{")?;
100
101    output.indent(|output| {
102        for thread_lane in 0..per_thread_loop_count {
103            writeln!(output, "_fls_unpack_{bit_width}bw_{bits}ow_lane(in, out, thread_idx * {per_thread_loop_count} + {thread_lane});")?;
104        }
105        Ok(())
106    })?;
107
108    Ok(writeln!(output, "}}")?)
109}
110
111fn generate_global_kernel_for_width<T: FastLanes, W: Write>(
112    output: &mut IndentedWriter<W>,
113    bit_width: usize,
114    thread_count: usize,
115) -> anyhow::Result<()> {
116    let bits = <T>::T;
117
118    let func_name = format!("fls_unpack_{bit_width}bw_{bits}ow_{thread_count}t");
119    let func_params =
120        format!("(const uint{bits}_t *__restrict full_in, uint{bits}_t *__restrict full_out)");
121
122    writeln!(
123        output,
124        "extern \"C\" __global__ void {func_name}{func_params} {{"
125    )?;
126
127    output.indent(|output| {
128        writeln!(output, "int thread_idx = threadIdx.x;")?;
129        writeln!(
130            output,
131            "auto in = full_in + (blockIdx.x * (128 * {bit_width} / sizeof(uint{bits}_t)));"
132        )?;
133        writeln!(output, "auto out = full_out + (blockIdx.x * 1024);")?;
134
135        writeln!(output, "_{func_name}(in, out, thread_idx);")
136    })?;
137
138    Ok(writeln!(output, "}}")?)
139}
140
141fn generate_unpack_for_width<T: FastLanes, W: Write>(
142    output: &mut IndentedWriter<W>,
143    thread_count: usize,
144) -> anyhow::Result<()> {
145    writeln!(
146        output,
147        "// Auto-generated by vortex-gpu-kernels. Do not edit by hand!"
148    )?;
149    writeln!(output, "#include <cuda.h>")?;
150    writeln!(output, "#include <cuda_runtime.h>")?;
151    writeln!(output, "#include <stdint.h>")?;
152    writeln!(output, "#include \"fastlanes_common.cuh\"")?;
153    writeln!(output)?;
154
155    for bit_width in 0..=<T>::T {
156        generate_lane_decoder::<T, _>(output, bit_width)?;
157        writeln!(output)?;
158        generate_device_kernel_for_width::<T, _>(output, bit_width, thread_count)?;
159        writeln!(output)?;
160
161        generate_global_kernel_for_width::<T, _>(output, bit_width, thread_count)?;
162        writeln!(output)?;
163    }
164
165    Ok(())
166}
167
168pub fn generate_unpack<T: FastLanes>(output_dir: &Path, thread_count: usize) -> anyhow::Result<()> {
169    let cu_filename = format!("gen/fls_{}_bit_unpack.cu", T::T);
170    let cu_path = output_dir.join(&cu_filename);
171    let mut cu_file = File::create(&cu_path)?;
172    let mut cu_writer = IndentedWriter::new(&mut cu_file);
173    generate_unpack_for_width::<T, _>(&mut cu_writer, thread_count)?;
174    Ok(())
175}