1use crate::device::GpuDevice;
4use std::sync::Arc;
5use tl_ai::TlTensor;
6use wgpu;
7
/// Element type of a [`GpuTensor`]'s logical data.
///
/// All GPU-side storage in this file is `f32` (see `GpuTensor::from_f32`);
/// `F64` exists to record the intended precision of the source data.
//
// Fix: fieldless enum derived `PartialEq` but not `Eq`/`Hash`, so `DType`
// could not be used as a `HashMap`/`HashSet` key. Both derives are valid
// (no float payload) and strictly backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    F32,
    F64,
}

impl std::fmt::Display for DType {
    /// Formats the dtype as its lowercase Rust type name (`"f32"` / `"f64"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match: adding a variant forces a compile error here.
        let name = match self {
            DType::F32 => "f32",
            DType::F64 => "f64",
        };
        f.write_str(name)
    }
}
23
/// A tensor whose element data lives in a GPU buffer.
pub struct GpuTensor {
    /// Raw element storage; filled with `f32` values by `from_f32`.
    pub buffer: wgpu::Buffer,
    /// Logical dimensions; the product is expected to equal `numel`
    /// (not validated at construction — a mismatch surfaces in `to_cpu`).
    pub shape: Vec<usize>,
    /// Logical element type; every constructor in this file sets `F32`.
    pub dtype: DType,
    /// Number of elements stored in `buffer`.
    pub numel: usize,
    /// Device/queue pair that owns `buffer`; shared via `Arc` so clones
    /// and readbacks can reach the same queue.
    pub device: Arc<GpuDevice>,
}
32
impl GpuTensor {
    /// Uploads a CPU tensor to the GPU.
    ///
    /// The CPU data is narrowed from `f64` to `f32` before upload, so
    /// values beyond `f32` precision lose accuracy.
    pub fn from_cpu(tensor: &TlTensor, device: Arc<GpuDevice>) -> Self {
        let data_f32: Vec<f32> = tensor.data.iter().map(|&v| v as f32).collect();
        Self::from_f32(&data_f32, tensor.data.shape().to_vec(), device)
    }

    /// Creates a tensor from raw `f32` data, copying it into a freshly
    /// allocated GPU buffer usable as a storage binding and as a copy
    /// source/destination.
    ///
    /// NOTE(review): `shape` is not checked against `data.len()` here; a
    /// mismatch only surfaces later as a shape error in `to_cpu`.
    pub fn from_f32(data: &[f32], shape: Vec<usize>, device: Arc<GpuDevice>) -> Self {
        // Reinterpret the f32 slice as raw bytes for the buffer upload.
        let bytes = bytemuck::cast_slice(data);
        let buffer = device
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu_tensor_data"),
                contents: bytes,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
            });

        let numel = data.len();
        GpuTensor {
            buffer,
            shape,
            dtype: DType::F32,
            numel,
            device,
        }
    }

    /// Copies the tensor back to the CPU as a `TlTensor`, widening each
    /// element from `f32` to `f64`.
    ///
    /// # Errors
    /// Returns an error string if the GPU readback fails or if the data
    /// length does not match `self.shape`.
    pub fn to_cpu(&self) -> Result<TlTensor, String> {
        let f32_data = self.read_f32()?;
        let f64_data: Vec<f64> = f32_data.iter().map(|&v| v as f64).collect();
        let shape = ndarray::IxDyn(&self.shape);
        let array = ndarray::ArrayD::from_shape_vec(shape, f64_data)
            .map_err(|e| format!("Shape mismatch: {e}"))?;
        Ok(TlTensor {
            data: array,
            name: None,
        })
    }

    /// Synchronously reads the buffer contents back as `f32` values.
    ///
    /// Copies into a `MAP_READ` staging buffer, submits the copy, then
    /// blocks until the map callback has fired.
    ///
    /// # Errors
    /// Returns an error string if the map callback reports failure or the
    /// channel is dropped without a result.
    pub fn read_f32(&self) -> Result<Vec<f32>, String> {
        let size = (self.numel * std::mem::size_of::<f32>()) as u64;
        // Storage buffers cannot be mapped directly; stage through a
        // dedicated MAP_READ | COPY_DST buffer.
        let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging_read"),
            size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device
                .device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("readback"),
                });
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &staging, 0, size);
        self.device.queue.submit(std::iter::once(encoder.finish()));

        let slice = staging.slice(..);
        // map_async completes only during device polling; send the result
        // over a channel so this thread can block on it.
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });
        // Maintain::Wait blocks until queued work finishes, so the map
        // callback has run (and sent on `tx`) by the time recv() returns.
        self.device.device.poll(wgpu::Maintain::Wait);
        rx.recv()
            .map_err(|e| format!("GPU readback channel error: {e}"))?
            .map_err(|e| format!("GPU readback error: {e}"))?;

        // Copy the data out while mapped, then drop the mapped view
        // before unmap() — unmapping with a live view would panic.
        let data = slice.get_mapped_range();
        let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        staging.unmap();

        Ok(result)
    }

    /// Size of the element data in bytes.
    ///
    /// Assumes 4-byte `f32` storage, which matches every constructor in
    /// this file regardless of the `dtype` label.
    pub fn byte_size(&self) -> u64 {
        (self.numel * std::mem::size_of::<f32>()) as u64
    }
}
118
impl Clone for GpuTensor {
    /// Deep-copies the tensor: allocates a new GPU buffer of the same
    /// byte size and enqueues a buffer-to-buffer copy of the contents.
    ///
    /// The copy is submitted but not awaited; the queue orders it before
    /// any later submissions that read the new buffer.
    fn clone(&self) -> Self {
        let size = self.byte_size();
        let new_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("gpu_tensor_clone"),
            size,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device
                .device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("clone"),
                });
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &new_buffer, 0, size);
        self.device.queue.submit(std::iter::once(encoder.finish()));

        GpuTensor {
            buffer: new_buffer,
            shape: self.shape.clone(),
            dtype: self.dtype,
            numel: self.numel,
            // Arc clone: both tensors share the same device/queue.
            device: self.device.clone(),
        }
    }
}
149
150impl std::fmt::Debug for GpuTensor {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 write!(
153 f,
154 "GpuTensor(shape={:?}, dtype={}, device={})",
155 self.shape, self.dtype, self.device.adapter_name
156 )
157 }
158}
159
160impl std::fmt::Display for GpuTensor {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 write!(
163 f,
164 "<gpu_tensor shape={:?} dtype={}>",
165 self.shape, self.dtype
166 )
167 }
168}
169
170use wgpu::util::DeviceExt;
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a small 1-D tensor CPU -> GPU -> CPU and checks the
    /// values survive the f64 -> f32 -> f64 conversion within tolerance.
    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        // Skip silently when no GPU adapter is available (e.g. CI).
        let Some(device) = GpuDevice::get() else {
            return;
        };

        let values = [1.0, 2.0, 3.0, 4.0];
        let cpu_tensor = TlTensor {
            data: ndarray::arr1(&values).into_dyn(),
            name: None,
        };

        let back = GpuTensor::from_cpu(&cpu_tensor, device)
            .to_cpu()
            .unwrap();

        for (a, b) in cpu_tensor.data.iter().zip(back.data.iter()) {
            assert!((a - b).abs() < 1e-6, "mismatch: {a} vs {b}");
        }
    }
}