1use std::{
2 mem,
3 os::windows::prelude::AsRawHandle,
4 sync::{
5 atomic::{self, AtomicBool},
6 mpsc, Arc,
7 },
8 thread::{self, JoinHandle},
9};
10
11use parking_lot::Mutex;
12use windows::{
13 Foundation::AsyncActionCompletedHandler,
14 Graphics::Capture::GraphicsCaptureItem,
15 Win32::{
16 Foundation::{HANDLE, LPARAM, WPARAM},
17 Graphics::Direct3D11::{ID3D11Device, ID3D11DeviceContext},
18 System::{
19 Threading::{GetCurrentThreadId, GetThreadId},
20 WinRT::{
21 CreateDispatcherQueueController, DispatcherQueueOptions, RoInitialize,
22 RoUninitialize, DQTAT_COM_NONE, DQTYPE_THREAD_CURRENT, RO_INIT_MULTITHREADED,
23 },
24 },
25 UI::WindowsAndMessaging::{
26 DispatchMessageW, GetMessageW, PostQuitMessage, PostThreadMessageW, TranslateMessage,
27 MSG, WM_QUIT,
28 },
29 },
30};
31
32use crate::{
33 d3d11::{self, create_d3d_device},
34 frame::Frame,
35 graphics_capture_api::{self, GraphicsCaptureApi, InternalCaptureControl},
36 settings::Settings,
37};
38
39#[derive(thiserror::Error, Debug)]
40pub enum CaptureControlError<E> {
41 #[error("Failed to join thread")]
42 FailedToJoinThread,
43 #[error("Thread handle is taken out of struct")]
44 ThreadHandleIsTaken,
45 #[error("Failed to post thread message")]
46 FailedToPostThreadMessage,
47 #[error("Stopped handler error: {0}")]
48 StoppedHandlerError(E),
49 #[error("Windows capture error: {0}")]
50 GraphicsCaptureApiError(#[from] GraphicsCaptureApiError<E>),
51}
52
53pub struct CaptureControl<T: GraphicsCaptureApiHandler + Send + 'static, E> {
55 thread_handle: Option<JoinHandle<Result<(), GraphicsCaptureApiError<E>>>>,
56 halt_handle: Arc<AtomicBool>,
57 callback: Arc<Mutex<T>>,
58}
59
60impl<T: GraphicsCaptureApiHandler + Send + 'static, E> CaptureControl<T, E> {
61 #[must_use]
73 #[inline]
74 pub const fn new(
75 thread_handle: JoinHandle<Result<(), GraphicsCaptureApiError<E>>>,
76 halt_handle: Arc<AtomicBool>,
77 callback: Arc<Mutex<T>>,
78 ) -> Self {
79 Self {
80 thread_handle: Some(thread_handle),
81 halt_handle,
82 callback,
83 }
84 }
85
86 #[must_use]
92 #[inline]
93 pub fn is_finished(&self) -> bool {
94 self.thread_handle
95 .as_ref()
96 .map_or(true, std::thread::JoinHandle::is_finished)
97 }
98
99 #[must_use]
105 #[inline]
106 pub fn into_thread_handle(self) -> JoinHandle<Result<(), GraphicsCaptureApiError<E>>> {
107 self.thread_handle.unwrap()
108 }
109
110 #[must_use]
116 #[inline]
117 pub fn halt_handle(&self) -> Arc<AtomicBool> {
118 self.halt_handle.clone()
119 }
120
121 #[must_use]
127 #[inline]
128 pub fn callback(&self) -> Arc<Mutex<T>> {
129 self.callback.clone()
130 }
131
132 #[inline]
138 pub fn wait(mut self) -> Result<(), CaptureControlError<E>> {
139 if let Some(thread_handle) = self.thread_handle.take() {
140 match thread_handle.join() {
141 Ok(result) => result?,
142 Err(_) => {
143 return Err(CaptureControlError::FailedToJoinThread);
144 }
145 }
146 } else {
147 return Err(CaptureControlError::ThreadHandleIsTaken);
148 }
149
150 Ok(())
151 }
152
153 #[inline]
159 pub fn stop(mut self) -> Result<(), CaptureControlError<E>> {
160 self.halt_handle.store(true, atomic::Ordering::Relaxed);
161
162 if let Some(thread_handle) = self.thread_handle.take() {
163 let handle = thread_handle.as_raw_handle();
164 let handle = HANDLE(handle);
165 let therad_id = unsafe { GetThreadId(handle) };
166
167 loop {
168 match unsafe {
169 PostThreadMessageW(therad_id, WM_QUIT, WPARAM::default(), LPARAM::default())
170 } {
171 Ok(()) => break,
172 Err(e) => {
173 if thread_handle.is_finished() {
174 break;
175 }
176
177 if e.code().0 != -2_147_023_452 {
178 Err(e).map_err(|_| CaptureControlError::FailedToPostThreadMessage)?;
179 }
180 }
181 }
182 }
183
184 match thread_handle.join() {
185 Ok(result) => result?,
186 Err(_) => {
187 return Err(CaptureControlError::FailedToJoinThread);
188 }
189 }
190 } else {
191 return Err(CaptureControlError::ThreadHandleIsTaken);
192 }
193
194 Ok(())
195 }
196}
197
198#[derive(thiserror::Error, Eq, PartialEq, Clone, Debug)]
199pub enum GraphicsCaptureApiError<E> {
200 #[error("Failed to join thread")]
201 FailedToJoinThread,
202 #[error("Failed to initialize WinRT")]
203 FailedToInitWinRT,
204 #[error("Failed to create dispatcher queue controller")]
205 FailedToCreateDispatcherQueueController,
206 #[error("Failed to shutdown dispatcher queue")]
207 FailedToShutdownDispatcherQueue,
208 #[error("Failed to set dispatcher queue completed handler")]
209 FailedToSetDispatcherQueueCompletedHandler,
210 #[error("Failed to convert item to GraphicsCaptureItem")]
211 ItemConvertFailed,
212 #[error("DirectX error: {0}")]
213 DirectXError(#[from] d3d11::Error),
214 #[error("Graphics capture error: {0}")]
215 GraphicsCaptureApiError(graphics_capture_api::Error),
216 #[error("New handler error: {0}")]
217 NewHandlerError(E),
218 #[error("Frame handler error: {0}")]
219 FrameHandlerError(E),
220}
221
222pub struct Context<Flags> {
224 pub flags: Flags,
226 pub device: ID3D11Device,
228 pub device_context: ID3D11DeviceContext,
230}
231
232pub trait GraphicsCaptureApiHandler: Sized {
234 type Flags;
236
237 type Error: Send + Sync;
239
240 #[inline]
250 fn start<T: TryInto<GraphicsCaptureItem>>(
251 settings: Settings<Self::Flags, T>,
252 ) -> Result<(), GraphicsCaptureApiError<Self::Error>>
253 where
254 Self: Send + 'static,
255 <Self as GraphicsCaptureApiHandler>::Flags: Send,
256 {
257 unsafe {
259 RoInitialize(RO_INIT_MULTITHREADED)
260 .map_err(|_| GraphicsCaptureApiError::FailedToInitWinRT)?;
261 };
262
263 let options = DispatcherQueueOptions {
265 dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
266 threadType: DQTYPE_THREAD_CURRENT,
267 apartmentType: DQTAT_COM_NONE,
268 };
269 let controller = unsafe {
270 CreateDispatcherQueueController(options)
271 .map_err(|_| GraphicsCaptureApiError::FailedToCreateDispatcherQueueController)?
272 };
273
274 let thread_id = unsafe { GetCurrentThreadId() };
276
277 let (d3d_device, d3d_device_context) = create_d3d_device()?;
279
280 let result = Arc::new(Mutex::new(None));
282
283 let ctx = Context {
284 flags: settings.flags,
285 device: d3d_device.clone(),
286 device_context: d3d_device_context.clone(),
287 };
288
289 let callback = Arc::new(Mutex::new(
290 Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?,
291 ));
292
293 let item = settings
294 .item
295 .try_into()
296 .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
297
298 let mut capture = GraphicsCaptureApi::new(
299 d3d_device,
300 d3d_device_context,
301 item,
302 callback,
303 settings.cursor_capture,
304 settings.draw_border,
305 settings.color_format,
306 thread_id,
307 result.clone(),
308 )
309 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
310 capture
311 .start_capture()
312 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
313
314 let mut message = MSG::default();
316 unsafe {
317 while GetMessageW(&mut message, None, 0, 0).as_bool() {
318 let _ = TranslateMessage(&message);
319 DispatchMessageW(&message);
320 }
321 }
322
323 let async_action = controller
325 .ShutdownQueueAsync()
326 .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
327 async_action
328 .SetCompleted(&AsyncActionCompletedHandler::new(
329 move |_, _| -> Result<(), windows::core::Error> {
330 unsafe { PostQuitMessage(0) };
331 Ok(())
332 },
333 ))
334 .map_err(|_| GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler)?;
335
336 let mut message = MSG::default();
338 unsafe {
339 while GetMessageW(&mut message, None, 0, 0).as_bool() {
340 let _ = TranslateMessage(&message);
341 DispatchMessageW(&message);
342 }
343 }
344
345 capture.stop_capture();
347
348 unsafe { RoUninitialize() };
350
351 let result = result.lock().take();
353 if let Some(e) = result {
354 return Err(GraphicsCaptureApiError::FrameHandlerError(e));
355 }
356
357 Ok(())
358 }
359
360 #[inline]
370 fn start_free_threaded<T: TryInto<GraphicsCaptureItem> + Send + 'static>(
371 settings: Settings<Self::Flags, T>,
372 ) -> Result<CaptureControl<Self, Self::Error>, GraphicsCaptureApiError<Self::Error>>
373 where
374 Self: Send + 'static,
375 <Self as GraphicsCaptureApiHandler>::Flags: Send,
376 {
377 let (halt_sender, halt_receiver) = mpsc::channel::<Arc<AtomicBool>>();
378 let (callback_sender, callback_receiver) = mpsc::channel::<Arc<Mutex<Self>>>();
379
380 let thread_handle = thread::spawn(
381 move || -> Result<(), GraphicsCaptureApiError<Self::Error>> {
382 unsafe {
384 RoInitialize(RO_INIT_MULTITHREADED)
385 .map_err(|_| GraphicsCaptureApiError::FailedToInitWinRT)?;
386 };
387
388 let options = DispatcherQueueOptions {
390 dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
391 threadType: DQTYPE_THREAD_CURRENT,
392 apartmentType: DQTAT_COM_NONE,
393 };
394 let controller = unsafe {
395 CreateDispatcherQueueController(options).map_err(|_| {
396 GraphicsCaptureApiError::FailedToCreateDispatcherQueueController
397 })?
398 };
399
400 let thread_id = unsafe { GetCurrentThreadId() };
402
403 let (d3d_device, d3d_device_context) = create_d3d_device()?;
405
406 let result = Arc::new(Mutex::new(None));
408
409 let ctx = Context {
410 flags: settings.flags,
411 device: d3d_device.clone(),
412 device_context: d3d_device_context.clone(),
413 };
414
415 let callback = Arc::new(Mutex::new(
416 Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?,
417 ));
418
419 let item = settings
420 .item
421 .try_into()
422 .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
423
424 let mut capture = GraphicsCaptureApi::new(
425 d3d_device,
426 d3d_device_context,
427 item,
428 callback.clone(),
429 settings.cursor_capture,
430 settings.draw_border,
431 settings.color_format,
432 thread_id,
433 result.clone(),
434 )
435 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
436 capture
437 .start_capture()
438 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
439
440 let halt_handle = capture.halt_handle();
442 halt_sender.send(halt_handle).unwrap();
443
444 callback_sender.send(callback).unwrap();
446
447 let mut message = MSG::default();
449 unsafe {
450 while GetMessageW(&mut message, None, 0, 0).as_bool() {
451 let _ = TranslateMessage(&message);
452 DispatchMessageW(&message);
453 }
454 }
455
456 let async_action = controller
458 .ShutdownQueueAsync()
459 .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
460
461 async_action
462 .SetCompleted(&AsyncActionCompletedHandler::new(
463 move |_, _| -> Result<(), windows::core::Error> {
464 unsafe { PostQuitMessage(0) };
465 Ok(())
466 },
467 ))
468 .map_err(|_| {
469 GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler
470 })?;
471
472 let mut message = MSG::default();
474 unsafe {
475 while GetMessageW(&mut message, None, 0, 0).as_bool() {
476 let _ = TranslateMessage(&message);
477 DispatchMessageW(&message);
478 }
479 }
480
481 capture.stop_capture();
483
484 unsafe { RoUninitialize() };
486
487 let result = result.lock().take();
489 if let Some(e) = result {
490 return Err(GraphicsCaptureApiError::FrameHandlerError(e));
491 }
492
493 Ok(())
494 },
495 );
496
497 let Ok(halt_handle) = halt_receiver.recv() else {
498 match thread_handle.join() {
499 Ok(result) => return Err(result.err().unwrap()),
500 Err(_) => {
501 return Err(GraphicsCaptureApiError::FailedToJoinThread);
502 }
503 }
504 };
505
506 let Ok(callback) = callback_receiver.recv() else {
507 match thread_handle.join() {
508 Ok(result) => return Err(result.err().unwrap()),
509 Err(_) => {
510 return Err(GraphicsCaptureApiError::FailedToJoinThread);
511 }
512 }
513 };
514
515 Ok(CaptureControl::new(thread_handle, halt_handle, callback))
516 }
517
518 fn new(ctx: Context<Self::Flags>) -> Result<Self, Self::Error>;
528
529 fn on_frame_arrived(
540 &mut self,
541 frame: &mut Frame,
542 capture_control: InternalCaptureControl,
543 ) -> Result<(), Self::Error>;
544
545 #[inline]
551 fn on_closed(&mut self) -> Result<(), Self::Error> {
552 Ok(())
553 }
554}