web_rwkv/
context.rs

1use std::{borrow::Cow, collections::BTreeMap, sync::Arc};
2
3use futures::Future;
4use thiserror::Error;
5use wasm_bindgen::prelude::wasm_bindgen;
6use web_rwkv_derive::{Deref, DerefMut};
7use wgpu::{
8    util::{BufferInitDescriptor, DeviceExt},
9    Adapter, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
10    BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferDescriptor, BufferUsages,
11    ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, ExperimentalFeatures,
12    Features, Instance, Limits, MemoryHints, PipelineLayoutDescriptor, PowerPreference, Queue,
13    RequestAdapterOptions, ShaderModuleDescriptor, Trace,
14};
15
16use crate::tensor::{
17    cache::{ResourceCache, SharedResourceCache},
18    shape::{IntoBytes, Shape},
19    ResourceKey, TensorResource, View,
20};
21
/// Extension methods for [`Instance`].
pub trait InstanceExt {
    /// Request a GPU adapter with the given power preference
    /// (no compatible surface, no fallback adapter).
    ///
    /// # Errors
    /// Resolves to [`ContextError::RequestAdapterFailed`] if no adapter is available.
    fn adapter(
        &self,
        power_preference: PowerPreference,
    ) -> impl Future<Output = Result<Adapter, ContextError>>;
}
28
29impl InstanceExt for Instance {
30    async fn adapter(&self, power_preference: PowerPreference) -> Result<Adapter, ContextError> {
31        self.request_adapter(&RequestAdapterOptions {
32            power_preference,
33            force_fallback_adapter: false,
34            compatible_surface: None,
35        })
36        .await
37        .or(Err(ContextError::RequestAdapterFailed))
38    }
39}
40
/// Marker type that namespaces the [`uid::Id`]s issued to contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ContextId;
43
/// A buffer read-back request serviced by the context's background thread
/// (native targets only; see `ContextBuilder::build`).
#[cfg(not(target_arch = "wasm32"))]
pub struct ContextEvent {
    /// The mappable buffer whose contents should be read back.
    pub buffer: Arc<Buffer>,
    /// Channel on which the buffer's bytes are sent back to the requester.
    pub sender: flume::Sender<Box<[u8]>>,
}
49
/// A handle to a GPU device plus per-context resource caches.
///
/// Equality and hashing are based solely on the unique `id`.
#[derive(Debug, Clone)]
pub struct Context {
    /// Unique identifier distinguishing this context from all others.
    pub id: uid::Id<ContextId>,
    pub adapter: Adapter,
    pub device: Device,
    pub queue: Queue,

    /// Compiled compute pipelines keyed by shader name, entry point and macros.
    pipelines: SharedResourceCache<PipelineKey, CachedPipeline>,
    /// Uniform buffers holding tensor shape/view metadata, keyed by view.
    shapes: ResourceCache<View, Buffer>,
    /// Reusable storage buffers keyed by size and usage flags.
    buffers: ResourceCache<BufferKey, Buffer>,
    /// Cached bind groups keyed by pipeline plus touched resources.
    bindings: SharedResourceCache<BindGroupKey, BindGroup>,

    /// Feeds read-back requests to the native background thread.
    #[cfg(not(target_arch = "wasm32"))]
    event: flume::Sender<ContextEvent>,
}
65
#[cfg(not(target_arch = "wasm32"))]
impl Drop for Context {
    fn drop(&mut self) {
        // Every clone of `Context` holds a copy of `event`, so a sender count
        // of 1 means this is the last live handle: perform teardown only then.
        if self.event.sender_count() <= 1 {
            self.clear_buffers();
            // Flush pending work and block until the device is idle so the
            // cleared buffers are actually released before the device drops.
            self.queue.submit(None);
            _ = self.device.poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            });
        }
    }
}
79
80impl PartialEq for Context {
81    fn eq(&self, other: &Self) -> bool {
82        self.id == other.id
83    }
84}
85
/// Builder configuring the device features and limits requested when
/// creating a [`Context`].
pub struct ContextBuilder {
    pub adapter: Adapter,
    /// Device features to request (see `ContextBuilder::new` for the default set).
    pub features: Features,
    /// Device limits to request (defaults to `Limits::default()`).
    pub limits: Limits,
}
91
92#[wasm_bindgen]
93#[derive(Debug, Error)]
94pub enum ContextError {
95    #[error("failed to request adaptor")]
96    RequestAdapterFailed,
97    #[error("failed to request device")]
98    RequestDeviceFailed,
99}
100
impl ContextBuilder {
    /// Create a builder for the given adapter with default limits.
    ///
    /// Starts from an empty feature set; when the `subgroup-ops` crate feature
    /// is enabled, `Features::SUBGROUP` is additionally requested.
    pub fn new(adapter: Adapter) -> Self {
        let features = Features::empty();
        #[cfg(feature = "subgroup-ops")]
        let features = features | Features::SUBGROUP;
        Self {
            adapter,
            features,
            limits: Default::default(),
        }
    }

    /// Request a device and queue from the adapter and assemble a [`Context`].
    ///
    /// On native (non-wasm) targets this also spawns a background thread that
    /// services [`ContextEvent`]s: for each event it maps the buffer, reads it
    /// back, and sends the bytes over the event's reply channel. The thread
    /// exits when all `Context` clones (the only senders) are dropped.
    ///
    /// # Errors
    /// Returns [`ContextError::RequestDeviceFailed`] if the device request fails.
    pub async fn build(self) -> Result<Context, ContextError> {
        let Self {
            adapter,
            features,
            limits,
        } = self;

        let (device, queue) = adapter
            .request_device(&DeviceDescriptor {
                label: None,
                required_features: features,
                required_limits: limits,
                memory_hints: MemoryHints::Performance,
                trace: Trace::Off,
                experimental_features: ExperimentalFeatures::disabled(),
            })
            .await
            .map_err(|_| ContextError::RequestDeviceFailed)?;

        #[cfg(not(target_arch = "wasm32"))]
        let (event, receiver) = flume::unbounded();

        let context = Context {
            id: uid::Id::new(),
            adapter,
            device,
            queue,
            pipelines: Default::default(),
            shapes: Default::default(),
            // Cache capacities: small for transient buffers, larger for
            // bind groups which are cheap to keep but costly to rebuild.
            buffers: ResourceCache::new(4),
            bindings: SharedResourceCache::new(64),
            #[cfg(not(target_arch = "wasm32"))]
            event,
        };

        // start a thread for reading back buffers
        #[cfg(not(target_arch = "wasm32"))]
        {
            let id = context.id;
            let device = context.device.clone();
            std::thread::spawn(move || {
                // `recv` fails once every sender (every `Context` clone) is
                // dropped, which terminates the loop and the thread.
                while let Ok(ContextEvent { buffer, sender }) = receiver.recv() {
                    #[cfg(feature = "trace")]
                    let _span = tracing::trace_span!("device").entered();
                    let data = read_back_buffer(&device, &buffer);
                    // Receiver may have been dropped; a failed send is fine.
                    let _ = sender.send(data);
                }
                log::info!("context dropped: {id}");
            });
        }

        Ok(context)
    }

    /// Replace the requested device limits.
    pub fn limits(mut self, limits: Limits) -> Self {
        self.limits = limits;
        self
    }

    /// Modify the requested device limits in place.
    pub fn update_limits(mut self, f: impl FnOnce(&mut Limits)) -> Self {
        f(&mut self.limits);
        self
    }

    /// Replace the requested device features.
    pub fn features(mut self, features: Features) -> Self {
        self.features = features;
        self
    }

    /// Modify the requested device features in place.
    pub fn update_features(mut self, f: impl FnOnce(&mut Features)) -> Self {
        f(&mut self.features);
        self
    }
}
187
/// A container of macro definitions in shader.
///
/// Backed by a `BTreeMap` so that `compile` yields the definitions in a
/// deterministic, sorted order — keeping pipeline cache keys canonical.
#[derive(Debug, Default, Clone, Deref, DerefMut, PartialEq, Eq, Hash)]
pub struct Macros(BTreeMap<String, String>);
191
192impl Macros {
193    pub fn new() -> Self {
194        Default::default()
195    }
196
197    pub fn compile(self) -> Vec<(String, String)> {
198        self.0.into_iter().collect()
199    }
200}
201
/// Cache key identifying a compiled compute pipeline.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PipelineKey {
    /// Shader/pipeline label.
    name: String,
    /// Entry-point function inside the shader module.
    entry_point: String,
    /// Sorted `(name, value)` macro pairs the shader was preprocessed with.
    macros: Vec<(String, String)>,
}
208
209impl PipelineKey {
210    pub fn new(name: impl AsRef<str>, entry_point: impl AsRef<str>, macros: Macros) -> Self {
211        let name = name.as_ref().into();
212        let entry_point = entry_point.as_ref().into();
213        let macros = macros.compile();
214        Self {
215            name,
216            entry_point,
217            macros,
218        }
219    }
220}
221
/// A compiled compute pipeline together with the bind group layout it was
/// created with.
#[derive(Debug, Clone)]
pub struct CachedPipeline {
    pub pipeline: ComputePipeline,
    pub layout: BindGroupLayout,
}
227
/// Cache key for reusable GPU buffers; buffers with identical size and usage
/// flags are interchangeable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BufferKey {
    size: usize,
    usage: BufferUsages,
}
233
/// Cache key for bind groups: the pipeline plus the `(binding, resource)`
/// pairs that were touched while building.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BindGroupKey {
    pipeline: PipelineKey,
    bindings: Vec<(u32, ResourceKey)>,
}
239
/// Incrementally assembles a (possibly cached) bind group for a pipeline.
pub struct BindGroupBuilder<'a, 'b> {
    context: &'b Context,
    layout: &'b BindGroupLayout,
    /// Cache key accumulated via `touch`/`bind`.
    key: BindGroupKey,
    /// Entries handed to `create_bind_group` on a cache miss.
    entries: Vec<BindGroupEntry<'a>>,
}
246
247impl<'a, 'b> BindGroupBuilder<'a, 'b> {
248    pub fn new(key: &PipelineKey, context: &'b Context, layout: &'b BindGroupLayout) -> Self {
249        Self {
250            context,
251            layout,
252            key: BindGroupKey {
253                pipeline: key.clone(),
254                bindings: vec![],
255            },
256            entries: vec![],
257        }
258    }
259
260    /// Mark a resource as being touched.
261    /// How resources are touched determines whether the bind group can be found in cache.
262    pub fn touch(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
263        let key = tensor.resource_key();
264        self.key.bindings.push((binding, key));
265        self
266    }
267
268    /// Insert an entry into the bind group.
269    pub fn bind(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
270        let resource = tensor.binding();
271        self.entries.push(BindGroupEntry { binding, resource });
272        self.touch(binding, tensor)
273    }
274
275    /// Insert an entry into the bind group.
276    pub fn bind_meta(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
277        let resource = tensor.meta_binding();
278        self.entries.push(BindGroupEntry { binding, resource });
279        // self.touch(binding, tensor)
280        self
281    }
282
283    pub fn build(self) -> Arc<BindGroup> {
284        let name = self.key.pipeline.name.clone();
285        self.context.bindings.checkout(self.key, || {
286            self.context.device.create_bind_group(&BindGroupDescriptor {
287                label: Some(&name),
288                layout: self.layout,
289                entries: &self.entries,
290            })
291        })
292    }
293}
294
// `PartialEq` compares only the unique `id`, which is reflexive, so `Eq` holds.
impl Eq for Context {}
296
impl Context {
    /// Fetch the compute pipeline for `key` from the cache, compiling on a miss.
    ///
    /// On a miss, the WGSL `source` is run through the `gpp` preprocessor with
    /// the macro definitions stored in `key`, compiled into a shader module,
    /// and paired with a bind group layout built from `entries`.
    ///
    /// # Panics
    /// Panics if `gpp` fails to preprocess the shader source.
    pub fn checkout_pipeline(
        &self,
        key: &PipelineKey,
        source: impl AsRef<str>,
        entries: &[BindGroupLayoutEntry],
    ) -> Arc<CachedPipeline> {
        self.pipelines.checkout(key.clone(), || {
            use gpp::{process_str, Context};
            let mut context = Context::new();
            context.macros = key.macros.iter().cloned().collect();

            let shader = process_str(source.as_ref(), &mut context).unwrap();
            let module = &self.device.create_shader_module(ShaderModuleDescriptor {
                label: Some(&key.name),
                source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
            });

            let layout = self
                .device
                .create_bind_group_layout(&BindGroupLayoutDescriptor {
                    label: Some(&key.name),
                    entries,
                });
            let pipeline_layout = self
                .device
                .create_pipeline_layout(&PipelineLayoutDescriptor {
                    label: Some(&key.name),
                    bind_group_layouts: &[&layout],
                    push_constant_ranges: &[],
                });

            let pipeline = self
                .device
                .create_compute_pipeline(&ComputePipelineDescriptor {
                    label: Some(&key.name),
                    layout: Some(&pipeline_layout),
                    module,
                    entry_point: Some(&key.entry_point),
                    compilation_options: Default::default(),
                    cache: None,
                });
            CachedPipeline { pipeline, layout }
        })
    }

    /// Fetch (or create) a uniform buffer holding the metadata of a contiguous
    /// view over `shape` (stride == shape, zero offset).
    pub(crate) fn checkout_shape_uniform(&self, shape: Shape) -> Arc<Buffer> {
        let view = View {
            shape,
            stride: shape,
            offset: Shape::new(0, 0, 0, 0),
        };
        let desc = BufferInitDescriptor {
            label: None,
            contents: &view.into_bytes(),
            usage: BufferUsages::UNIFORM,
        };
        self.shapes
            .checkout(view, || self.device.create_buffer_init(&desc))
    }

    /// Fetch (or create) a uniform buffer holding the metadata of `view`.
    pub(crate) fn checkout_view_uniform(&self, view: View) -> Arc<Buffer> {
        let desc = BufferInitDescriptor {
            label: None,
            contents: &view.into_bytes(),
            usage: BufferUsages::UNIFORM,
        };
        self.shapes
            .checkout(view, || self.device.create_buffer_init(&desc))
    }

    /// Create a buffer initialized with `contents`.
    ///
    /// NOTE: the cache path is currently disabled (see the commented-out code
    /// below); a fresh buffer is created on every call. `_key` is kept so the
    /// cached variant can be restored easily.
    pub(crate) fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
        let size = std::mem::size_of_val(contents);
        let _key = BufferKey { size, usage };
        let desc = BufferInitDescriptor {
            label: None,
            contents,
            usage,
        };
        // self.buffer_cache.checkout(
        //     key,
        //     || self.device.create_buffer_init(&desc),
        //     |buffer| self.queue.write_buffer(buffer, 0, contents),
        // )
        self.device.create_buffer_init(&desc).into()
    }

    /// Fetch (or create) an uninitialized buffer of `size` bytes with the
    /// given usage, reusing a cached buffer with matching size/usage if any.
    pub(crate) fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
        let key = BufferKey { size, usage };
        let desc = BufferDescriptor {
            label: None,
            size: size as u64,
            usage,
            mapped_at_creation: false,
        };
        self.buffers
            .checkout(key, || self.device.create_buffer(&desc))
    }

    // pub(crate) fn checkout_buffer_uncached(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
    //     self.device
    //         .create_buffer(&BufferDescriptor {
    //             label: None,
    //             size: size as u64,
    //             usage,
    //             mapped_at_creation: false,
    //         })
    //         .into()
    // }

    /// Maintain resource caches.
    #[inline]
    pub fn maintain(&self) {
        self.pipelines.maintain();
        self.shapes.maintain();
        self.buffers.maintain();
        self.bindings.maintain();
    }

    /// Clear resource caches.
    /// Only buffer-backed caches are cleared; pipelines and bind groups are kept.
    #[inline]
    pub fn clear_buffers(&self) {
        self.shapes.clear();
        self.buffers.clear();
    }

    /// Clone of the sender feeding the native read-back thread.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) fn event(&self) -> flume::Sender<ContextEvent> {
        self.event.clone()
    }

    /// Minimum subgroup size reported by the adapter.
    #[cfg(feature = "subgroup-ops")]
    pub fn min_subgroup_size(&self) -> u32 {
        self.adapter.limits().min_subgroup_size
    }

    /// Maximum subgroup size reported by the adapter.
    #[cfg(feature = "subgroup-ops")]
    pub fn max_subgroup_size(&self) -> u32 {
        self.adapter.limits().max_subgroup_size
    }
}
438
/// Synchronously read back the contents of a mappable buffer.
///
/// Maps the whole buffer, blocks on `device.poll` until the map completes,
/// copies the mapped bytes into a fresh heap allocation, and unmaps the buffer.
///
/// # Panics
/// Panics if `buffer` lacks `BufferUsages::MAP_READ`, or if mapping fails.
#[cfg(not(target_arch = "wasm32"))]
fn read_back_buffer(device: &Device, buffer: &Buffer) -> Box<[u8]> {
    assert!(buffer.usage().contains(BufferUsages::MAP_READ));

    // One-shot channel; the map callback fires during `device.poll` below.
    let (sender, receiver) = flume::bounded(1);
    let slice = buffer.slice(..);
    slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

    // Block until outstanding work — including the map request — completes.
    _ = device.poll(wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    });
    receiver
        .recv()
        .expect("failed to receive read back buffer")
        .expect("failed to map buffer");

    let data = {
        let map = slice.get_mapped_range();
        let len = map.len();
        let size = std::mem::size_of::<u32>();
        // Backing the copy with a `[u32]` allocation makes the returned bytes
        // 4-byte aligned — presumably so callers can reinterpret them as u32
        // slices without a misalignment panic (TODO confirm against callers).
        let data = vec![0u32; len.div_ceil(size)].into_boxed_slice();
        // SAFETY-NOTE(review): this allocates as `Box<[u32]>` (align 4) but
        // reconstructs — and will later deallocate — the memory as `Box<[u8]>`
        // (align 1). Deallocating with a layout different from the one used to
        // allocate violates the `GlobalAlloc` contract and is UB; consider
        // restructuring (e.g. keep the `[u32]` box and expose bytes some other
        // way). Also, `copy_from_slice` requires `len` to be a multiple of 4;
        // wgpu's 4-byte buffer-size alignment makes that hold in practice —
        // verify.
        unsafe {
            let data = Box::leak(data);
            let data: &mut [u8] = bytemuck::cast_slice_mut(data);
            data.copy_from_slice(&map);
            Box::from_raw(data)
        }
    };
    buffer.unmap();
    data
}