rustic_zen/image/
vulkan.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5//! VulkanImage contains the full colour depth framebuffer and a GPU rendering Context, and provides functions to export it.
6use std::sync::Arc;
7use std::thread::{self, JoinHandle};
8
9use crossbeam::channel::{bounded, Receiver, Sender};
10
11use vulkano::buffer::{Buffer, BufferContents, BufferCreateInfo, BufferUsage, Subbuffer};
12use vulkano::command_buffer::{
13    CopyImageToBufferInfo, RenderPassBeginInfo, SubpassBeginInfo, SubpassContents,
14};
15use vulkano::format::Format;
16use vulkano::image::view::ImageView;
17use vulkano::image::Image;
18use vulkano::memory::allocator::{AllocationCreateInfo, MemoryTypeFilter};
19use vulkano::pipeline::graphics::color_blend::{
20    AttachmentBlend, ColorBlendAttachmentState, ColorBlendState,
21};
22use vulkano::pipeline::graphics::input_assembly::{InputAssemblyState, PrimitiveTopology};
23use vulkano::pipeline::graphics::multisample::MultisampleState;
24use vulkano::pipeline::graphics::rasterization::RasterizationState;
25use vulkano::pipeline::graphics::vertex_input::{Vertex, VertexDefinition};
26use vulkano::pipeline::graphics::viewport::{Viewport, ViewportState};
27use vulkano::pipeline::graphics::GraphicsPipelineCreateInfo;
28use vulkano::pipeline::layout::PipelineDescriptorSetLayoutCreateInfo;
29use vulkano::pipeline::{GraphicsPipeline, PipelineLayout, PipelineShaderStageCreateInfo};
30use vulkano::render_pass::{Framebuffer, FramebufferCreateInfo, Subpass};
31
32use super::vulkan_context::VulkanContext;
33use super::{ExportImage, RenderImage};
34use crate::ray::RayResult;
35
36/// Provides a context for rendering an image on the GPU
37pub struct VulkanImage {
38    image_buffer: Subbuffer<[f32]>,
39    width: usize,
40    height: usize,
41    lightpower: f32,
42    sender: Sender<RayMsg>,
43    render_thread: Option<JoinHandle<()>>,
44}
45
/// One endpoint of a ray segment, as stored in the GPU vertex buffer.
///
/// `#[repr(C)]` fixes the field layout; the fields correspond to the vertex shader's
/// `location = 0` (`vec2 position`) and `location = 1` (`vec4 color`) inputs, so their
/// order and formats must not change independently of the shader.
#[derive(BufferContents, Vertex)]
#[repr(C)]
struct RayVertex {
    /// Position in image pixel coordinates; the vertex shader divides it by the image
    /// bounds to map it into clip space.
    #[format(R32G32_SFLOAT)]
    position: [f32; 2],
    /// RGBA colour of the ray.
    #[format(R32G32B32A32_SFLOAT)]
    color: [f32; 4],
}
54
/// Push-constant data for the vertex shader.
///
/// Mirrors the shader's `Constants` push-constant block (`vec2 image_bounds`);
/// `#[repr(C)]` keeps the layout in step with it.
#[derive(BufferContents)]
#[repr(C)]
struct ImageBounds {
    /// Image `[width, height]` in pixels.
    bounds: [f32; 2],
}
60
61enum RayMsg {
62    Ray(RayResult),
63    EndOfFrame(Sender<()>),
64    Shutdown,
65}
66
67impl VulkanImage {
68    pub(crate) const BATCH_SIZE: usize = 1_000_000;
69
70    /// Create new Image on the GPU for GPU rendering, with the given size.
71    pub fn new(width: usize, height: usize) -> Self {
72        let ctx = VulkanContext::new();
73
74        let image_buffer = Buffer::new_slice::<f32>(
75            ctx.memory_allocator(),
76            BufferCreateInfo {
77                usage: BufferUsage::TRANSFER_DST,
78                ..Default::default()
79            },
80            AllocationCreateInfo {
81                memory_type_filter: MemoryTypeFilter::PREFER_HOST
82                    | MemoryTypeFilter::HOST_RANDOM_ACCESS,
83                ..Default::default()
84            },
85            (width * height * 4) as u64,
86        )
87        .expect("failed to create buffer");
88
89        let (tx, rx) = bounded(Self::BATCH_SIZE);
90
91        let mut this = Self {
92            image_buffer: image_buffer.clone(),
93            width,
94            height,
95            lightpower: 0.0,
96            sender: tx,
97            render_thread: None,
98        };
99
100        let w = width as u32;
101        let h = height as u32;
102        let render_thread = thread::spawn(move || Self::draw(ctx, rx, w, h, image_buffer));
103
104        this.render_thread = Some(render_thread);
105
106        this
107    }
108
109    fn draw(
110        mut ctx: VulkanContext,
111        recv: Receiver<RayMsg>,
112        width: u32,
113        height: u32,
114        dest: Subbuffer<[f32]>,
115    ) {
116        let image = ctx.new_framebuffer(width, height);
117        let mut result_buffer = Vec::with_capacity(Self::BATCH_SIZE * 2);
118        loop {
119            match recv.recv().unwrap() {
120                RayMsg::Ray(ray) => {
121                    let (r, g, b) = ray.color::<f32>();
122
123                    let start = RayVertex {
124                        position: [ray.origin.x as f32, ray.origin.y as f32],
125                        color: [r, g, b, 1.0],
126                    };
127
128                    let end = RayVertex {
129                        position: [ray.termination.x as f32, ray.termination.y as f32],
130                        color: [r, g, b, 1.0],
131                    };
132
133                    result_buffer.push(start);
134                    result_buffer.push(end);
135
136                    if result_buffer.len() == Self::BATCH_SIZE * 2 {
137                        Self::draw_batch(
138                            &mut ctx,
139                            width,
140                            height,
141                            result_buffer.drain(..),
142                            image.clone(),
143                            dest.clone(),
144                        );
145                    }
146                }
147                RayMsg::EndOfFrame(ack) => {
148                    Self::draw_batch(
149                        &mut ctx,
150                        width,
151                        height,
152                        result_buffer.drain(..),
153                        image.clone(),
154                        dest.clone(),
155                    );
156                    ctx.wait_gpu();
157                    ack.send(()).unwrap();
158                }
159                RayMsg::Shutdown => {
160                    break;
161                }
162            }
163        }
164    }
165
166    fn draw_batch(
167        ctx: &mut VulkanContext,
168        width: u32,
169        height: u32,
170        rays: impl ExactSizeIterator<Item = RayVertex>,
171        image: Arc<Image>,
172        dest: Subbuffer<[f32]>,
173    ) {
174        let rays = rays.into_iter();
175        let vertex_buffer = Buffer::from_iter(
176            ctx.memory_allocator(),
177            BufferCreateInfo {
178                usage: BufferUsage::VERTEX_BUFFER,
179                ..Default::default()
180            },
181            AllocationCreateInfo {
182                memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
183                    | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
184                ..Default::default()
185            },
186            rays,
187        )
188        .unwrap();
189
190        let render_pass = vulkano::single_pass_renderpass!(ctx.device(),
191            attachments: {
192                color: {
193                    format: Format::R32G32B32A32_SFLOAT,
194                    samples: 1,
195                    load_op: DontCare,
196                    store_op: Store,
197                },
198            },
199            pass: {
200                color: [color],
201                depth_stencil: {},
202            },
203        )
204        .unwrap();
205
206        let view = ImageView::new_default(image.clone()).unwrap();
207        let frame_buffer = Framebuffer::new(
208            render_pass.clone(),
209            FramebufferCreateInfo {
210                attachments: vec![view],
211                ..Default::default()
212            },
213        )
214        .unwrap();
215
216        let vs = shaders::vertex::load(ctx.device()).expect("failed to create shader module");
217        let fs = shaders::fragment::load(ctx.device()).expect("failed to create shader module");
218
219        let viewport = Viewport {
220            offset: [0.0, 0.0],
221            extent: [width as f32, height as f32],
222            depth_range: 0.0..=1.0,
223        };
224
225        let vs = vs.entry_point("main").unwrap();
226        let fs = fs.entry_point("main").unwrap();
227
228        let vertex_input_state = RayVertex::per_vertex()
229            .definition(&vs.info().input_interface)
230            .unwrap();
231
232        let stages = [
233            PipelineShaderStageCreateInfo::new(vs),
234            PipelineShaderStageCreateInfo::new(fs),
235        ];
236
237        let layout = PipelineLayout::new(
238            ctx.device(),
239            PipelineDescriptorSetLayoutCreateInfo::from_stages(&stages)
240                .into_pipeline_layout_create_info(ctx.device())
241                .unwrap(),
242        )
243        .unwrap();
244
245        let subpass = Subpass::from(render_pass.clone(), 0).unwrap();
246
247        let pipeline = GraphicsPipeline::new(
248            ctx.device(),
249            None,
250            GraphicsPipelineCreateInfo {
251                stages: stages.into_iter().collect(),
252                vertex_input_state: Some(vertex_input_state),
253                input_assembly_state: Some(InputAssemblyState {
254                    topology: PrimitiveTopology::LineList,
255                    primitive_restart_enable: false,
256                    ..Default::default()
257                }),
258                viewport_state: Some(ViewportState {
259                    viewports: [viewport].into_iter().collect(),
260                    ..Default::default()
261                }),
262                rasterization_state: Some(RasterizationState::default()),
263                multisample_state: Some(MultisampleState::default()),
264                color_blend_state: Some(ColorBlendState::with_attachment_states(
265                    subpass.num_color_attachments(),
266                    ColorBlendAttachmentState {
267                        blend: Some(AttachmentBlend::additive()),
268                        ..Default::default()
269                    },
270                )),
271                subpass: Some(subpass.into()),
272                ..GraphicsPipelineCreateInfo::layout(layout.clone())
273            },
274        )
275        .unwrap();
276
277        let num_vertices = vertex_buffer.len();
278
279        let mut builder = ctx.command_builder();
280
281        builder
282            .begin_render_pass(
283                RenderPassBeginInfo {
284                    clear_values: vec![None],
285                    ..RenderPassBeginInfo::framebuffer(frame_buffer)
286                },
287                SubpassBeginInfo {
288                    contents: SubpassContents::Inline,
289                    ..Default::default()
290                },
291            )
292            .unwrap()
293            .bind_pipeline_graphics(pipeline)
294            .unwrap()
295            .bind_vertex_buffers(0, vertex_buffer)
296            .unwrap()
297            .push_constants(
298                layout,
299                0,
300                ImageBounds {
301                    bounds: [width as f32, height as f32],
302                },
303            )
304            .unwrap()
305            .draw(
306                num_vertices as u32,
307                1,
308                0,
309                0, // 3 is the number of vertices, 1 is the number of instances
310            )
311            .unwrap()
312            .end_render_pass(Default::default())
313            .unwrap()
314            .copy_image_to_buffer(CopyImageToBufferInfo::image_buffer(image, dest))
315            .unwrap();
316
317        let command_buffer = builder.build().unwrap();
318
319        ctx.run_command_buffer(command_buffer);
320    }
321}
322
323impl RenderImage for VulkanImage {
324    fn draw_line(&self, ray: RayResult) {
325        self.sender.send(RayMsg::Ray(ray)).unwrap();
326    }
327
328    fn prepare_render(&mut self, lightpower: f32) {
329        self.lightpower = lightpower;
330    }
331
332    fn finish_render(&mut self) {
333        let (tx, rx) = bounded(1);
334        self.sender.send(RayMsg::EndOfFrame(tx)).unwrap();
335        rx.recv().unwrap();
336    }
337}
338
339impl ExportImage for VulkanImage {
340    fn get_size(&self) -> (usize, usize) {
341        (self.width, self.height)
342    }
343
344    fn get_lightpower(&self) -> f32 {
345        self.lightpower
346    }
347
348    fn to_rgbaf32(&self) -> Vec<f32> {
349        self.image_buffer
350            .read()
351            .expect("failed to read frambuffer")
352            .to_vec()
353    }
354}
355
356impl Drop for VulkanImage {
357    fn drop(&mut self) {
358        self.sender.send(RayMsg::Shutdown).unwrap();
359    }
360}
361
/// GLSL shaders used to rasterise rays, compiled by `vulkano_shaders::shader!`.
///
/// Their inputs must stay in step with `RayVertex` (vertex attributes) and
/// `ImageBounds` (push constants).
mod shaders {
    /// Vertex shader: maps pixel-space positions into clip space using the image bounds
    /// from the push constants, and forwards the per-vertex colour.
    pub mod vertex {
        vulkano_shaders::shader! {
            ty: "vertex",
            src: "
                #version 460               
                layout(location = 0) in vec2 position;
                layout(location = 1) in vec4 color;
                
                layout(location=0) out vec4 outcolor;

                layout( push_constant ) uniform Constants
                {
                    vec2 image_bounds;
                } constants;

                void main() {
                    gl_Position = vec4((2.0 * (position.x / constants.image_bounds.x)) - 1.0, (2.0 * (position.y / constants.image_bounds.y)) - 1.0, 0.0, 1.0);
                    outcolor = color;
                }
            ",
        }
    }

    /// Fragment shader: writes the incoming colour to the colour attachment unchanged.
    pub mod fragment {
        vulkano_shaders::shader! {
            ty: "fragment",
            src: r"
                #version 460
                layout(location=0) in vec4 incolor;

                layout(location = 0) out vec4 f_color;

                void main() {
                    f_color = incolor;
                }
            ",
        }
    }
}
402
#[cfg(test)]
mod tests {
    use super::VulkanImage;
    use crate::image::{ExportImage, RenderImage};
    use crate::ray::RayResult;

    use itertools::Itertools as _;

    /// Sums the R, G and B channels of an interleaved RGBA buffer, in order.
    fn channel_sums(pixels: &[f32]) -> (f32, f32, f32) {
        pixels
            .iter()
            .tuples()
            .fold((0.0, 0.0, 0.0), |(r, g, b), (pr, pg, pb, _)| {
                (r + pr, g + pg, b + pb)
            })
    }

    #[test]
    fn traced_ray_is_not_black() {
        let mut image = VulkanImage::new(100, 100);
        image.prepare_render(0.0);
        image.draw_line(RayResult::new((10.0, 10.0), (90.0, 90.0), 620.0)); // red
        image.draw_line(RayResult::new((20.0, 10.0), (90.0, 80.0), 520.0)); // green
        image.draw_line(RayResult::new((10.0, 20.0), (80.0, 90.0), 470.0)); // blue
        image.finish_render();

        let (r_sum, g_sum, b_sum) = channel_sums(&image.to_rgbaf32());
        assert_ne!(r_sum, 0.0);
        assert_ne!(g_sum, 0.0);
        assert_ne!(b_sum, 0.0);
    }

    #[test]
    fn empty_image_is_black() {
        let image = VulkanImage::new(1920, 1080);
        assert!(image.to_rgbaf32().iter().all(|&value| value == 0.0));
    }

    #[test]
    fn output_len_u8() {
        let image = VulkanImage::new(1920, 1080);
        let expected = 1920 * 1080 * 4;
        assert_eq!(image.to_rgba8(0, 1.0, 1.0).len(), expected);
    }

    #[test]
    fn output_len_f32() {
        let image = VulkanImage::new(1920, 1080);
        let expected = 1920 * 1080 * 4;
        assert_eq!(image.to_rgbaf32().len(), expected);
    }
}