1use std::mem;
2use std::os::windows::prelude::AsRawHandle;
3use std::sync::atomic::{self, AtomicBool};
4use std::sync::{Arc, OnceLock, mpsc};
5use std::thread::{self, JoinHandle};
6
7use parking_lot::Mutex;
8use windows::Win32::Foundation::{HANDLE, LPARAM, S_FALSE, WPARAM};
9use windows::Win32::Graphics::Direct3D11::{ID3D11Device, ID3D11DeviceContext};
10use windows::Win32::System::Com::CoIncrementMTAUsage;
11use windows::Win32::System::Threading::{GetCurrentThreadId, GetThreadId};
12use windows::Win32::System::WinRT::{
13 CreateDispatcherQueueController, DQTAT_COM_NONE, DQTYPE_THREAD_CURRENT, DispatcherQueueOptions,
14 RO_INIT_MULTITHREADED, RoInitialize,
15};
16use windows::Win32::UI::WindowsAndMessaging::{
17 DispatchMessageW, GetMessageW, MSG, PostQuitMessage, PostThreadMessageW, TranslateMessage,
18 WM_QUIT,
19};
20use windows::core::Result as WindowsResult;
21use windows_future::AsyncActionCompletedHandler;
22
23use crate::d3d11::{self, create_d3d_device};
24use crate::frame::Frame;
25use crate::graphics_capture_api::{self, GraphicsCaptureApi, InternalCaptureControl};
26use crate::settings::{Settings, TryIntoCaptureItemWithType};
27
/// Errors returned by [`CaptureControl::wait`] and [`CaptureControl::stop`].
#[derive(thiserror::Error, Debug)]
pub enum CaptureControlError<E> {
    /// `JoinHandle::join` on the capture thread returned `Err` (the thread panicked).
    #[error("Failed to join thread")]
    FailedToJoinThread,
    /// The `CaptureControl` no longer holds its thread handle.
    #[error("Thread handle is taken out of the struct")]
    ThreadHandleIsTaken,
    /// `PostThreadMessageW` failed with an error other than `ERROR_INVALID_THREAD_ID`
    /// while the capture thread was still running.
    #[error("Failed to post thread message")]
    FailedToPostThreadMessage,
    /// Error raised by the handler while stopping.
    /// NOTE(review): not constructed anywhere in this section — confirm where it is produced.
    #[error("Stopped handler error: {0}")]
    StoppedHandlerError(E),
    /// The capture session returned an error; converted automatically via `From`.
    #[error("Windows capture error: {0}")]
    GraphicsCaptureApiError(#[from] GraphicsCaptureApiError<E>),
}
41
/// Handle to a capture session running on its own thread, as returned by
/// [`GraphicsCaptureApiHandler::start_free_threaded`].
pub struct CaptureControl<T: GraphicsCaptureApiHandler + Send + 'static, E> {
    /// Join handle of the capture thread; `None` once `wait`/`stop` has taken it.
    thread_handle: Option<JoinHandle<Result<(), GraphicsCaptureApiError<E>>>>,
    /// Halt flag shared with the capture session; `stop` stores `true` into it.
    halt_handle: Arc<AtomicBool>,
    /// The handler instance shared with the capture session.
    callback: Arc<Mutex<T>>,
}
48
49impl<T: GraphicsCaptureApiHandler + Send + 'static, E> CaptureControl<T, E> {
50 #[must_use]
63 #[inline]
64 pub const fn new(
65 thread_handle: JoinHandle<Result<(), GraphicsCaptureApiError<E>>>,
66 halt_handle: Arc<AtomicBool>,
67 callback: Arc<Mutex<T>>,
68 ) -> Self {
69 Self { thread_handle: Some(thread_handle), halt_handle, callback }
70 }
71
72 #[must_use]
78 #[inline]
79 pub fn is_finished(&self) -> bool {
80 self.thread_handle.as_ref().is_none_or(std::thread::JoinHandle::is_finished)
81 }
82
83 #[must_use]
89 #[inline]
90 pub fn into_thread_handle(self) -> JoinHandle<Result<(), GraphicsCaptureApiError<E>>> {
91 self.thread_handle.unwrap()
92 }
93
94 #[must_use]
100 #[inline]
101 pub fn halt_handle(&self) -> Arc<AtomicBool> {
102 self.halt_handle.clone()
103 }
104
105 #[must_use]
111 #[inline]
112 pub fn callback(&self) -> Arc<Mutex<T>> {
113 self.callback.clone()
114 }
115
116 #[inline]
122 pub fn wait(mut self) -> Result<(), CaptureControlError<E>> {
123 if let Some(thread_handle) = self.thread_handle.take() {
124 match thread_handle.join() {
125 Ok(result) => result?,
126 Err(_) => {
127 return Err(CaptureControlError::FailedToJoinThread);
128 }
129 }
130 } else {
131 return Err(CaptureControlError::ThreadHandleIsTaken);
132 }
133
134 Ok(())
135 }
136
137 #[inline]
143 pub fn stop(mut self) -> Result<(), CaptureControlError<E>> {
144 self.halt_handle.store(true, atomic::Ordering::Relaxed);
145
146 if let Some(thread_handle) = self.thread_handle.take() {
147 let handle = thread_handle.as_raw_handle();
148 let handle = HANDLE(handle);
149 let thread_id = unsafe { GetThreadId(handle) };
150
151 loop {
152 match unsafe {
153 PostThreadMessageW(thread_id, WM_QUIT, WPARAM::default(), LPARAM::default())
154 } {
155 Ok(()) => break,
156 Err(e) => {
157 if thread_handle.is_finished() {
158 break;
159 }
160
161 if e.code().0 != -2_147_023_452 {
162 Err(e).map_err(|_| CaptureControlError::FailedToPostThreadMessage)?;
163 }
164 }
165 }
166 }
167
168 match thread_handle.join() {
169 Ok(result) => result?,
170 Err(_) => {
171 return Err(CaptureControlError::FailedToJoinThread);
172 }
173 }
174 } else {
175 return Err(CaptureControlError::ThreadHandleIsTaken);
176 }
177
178 Ok(())
179 }
180}
181
/// Errors produced while setting up, running or tearing down a capture session.
#[derive(thiserror::Error, Eq, PartialEq, Clone, Debug)]
pub enum GraphicsCaptureApiError<E> {
    /// `start_free_threaded`: the capture thread panicked before reporting that it started.
    #[error("Failed to join thread")]
    FailedToJoinThread,
    /// `RoInitialize(RO_INIT_MULTITHREADED)` failed with a code other than `S_FALSE`.
    #[error("Failed to initialize WinRT")]
    FailedToInitWinRT,
    /// `CreateDispatcherQueueController` failed.
    #[error("Failed to create dispatcher queue controller")]
    FailedToCreateDispatcherQueueController,
    /// `ShutdownQueueAsync` on the dispatcher queue controller failed.
    #[error("Failed to shut down dispatcher queue")]
    FailedToShutdownDispatcherQueue,
    /// Registering the completion handler on the shutdown async action failed.
    #[error("Failed to set dispatcher queue completed handler")]
    FailedToSetDispatcherQueueCompletedHandler,
    /// `try_into_capture_item` on the settings item failed.
    #[error("Failed to convert item to `GraphicsCaptureItem`")]
    ItemConvertFailed,
    /// Creating the Direct3D 11 device failed; converted automatically via `From`.
    #[error("DirectX error: {0}")]
    DirectXError(#[from] d3d11::Error),
    /// Creating or starting the `GraphicsCaptureApi` session failed.
    #[error("Graphics capture error: {0}")]
    GraphicsCaptureApiError(graphics_capture_api::Error),
    /// The handler's `new` constructor returned an error.
    #[error("New handler error: {0}")]
    NewHandlerError(E),
    /// An error the capture session stored in its shared result slot, reported after the
    /// session ends (presumably from `on_frame_arrived` — confirm in `graphics_capture_api`).
    #[error("Frame handler error: {0}")]
    FrameHandlerError(E),
}
205
/// Values handed to [`GraphicsCaptureApiHandler::new`] when a capture session starts.
pub struct Context<Flags> {
    /// The `flags` field of the [`Settings`] the session was started with.
    pub flags: Flags,
    /// Direct3D 11 device created for the session (a clone of the one the capture uses).
    pub device: ID3D11Device,
    /// Device context returned alongside `device` by `create_d3d_device`.
    pub device_context: ID3D11DeviceContext,
}
215
216pub trait GraphicsCaptureApiHandler: Sized {
218 type Flags;
220
221 type Error: Send + Sync;
223
224 #[inline]
234 fn start<T: TryIntoCaptureItemWithType>(
235 settings: Settings<Self::Flags, T>,
236 ) -> Result<(), GraphicsCaptureApiError<Self::Error>>
237 where
238 Self: Send + 'static,
239 <Self as GraphicsCaptureApiHandler>::Flags: Send,
240 {
241 static INIT_MTA: OnceLock<()> = OnceLock::new();
243 INIT_MTA.get_or_init(|| {
244 unsafe {
245 CoIncrementMTAUsage().expect("Failed to increment MTA usage");
246 };
247 });
248
249 match unsafe { RoInitialize(RO_INIT_MULTITHREADED) } {
250 Ok(_) => (),
251 Err(e) => {
252 if e.code() == S_FALSE {
253 } else {
255 return Err(GraphicsCaptureApiError::FailedToInitWinRT);
256 }
257 }
258 }
259
260 let options = DispatcherQueueOptions {
262 dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
263 threadType: DQTYPE_THREAD_CURRENT,
264 apartmentType: DQTAT_COM_NONE,
265 };
266 let controller = unsafe {
267 CreateDispatcherQueueController(options)
268 .map_err(|_| GraphicsCaptureApiError::FailedToCreateDispatcherQueueController)?
269 };
270
271 let thread_id = unsafe { GetCurrentThreadId() };
273
274 let (d3d_device, d3d_device_context) = create_d3d_device()?;
276
277 let result = Arc::new(Mutex::new(None));
279
280 let ctx = Context {
281 flags: settings.flags,
282 device: d3d_device.clone(),
283 device_context: d3d_device_context.clone(),
284 };
285
286 let callback =
287 Arc::new(Mutex::new(Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?));
288
289 let (item, item_type) = settings
291 .item
292 .try_into_capture_item()
293 .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
294
295 let mut capture = GraphicsCaptureApi::new(
296 d3d_device,
297 d3d_device_context,
298 item,
299 item_type,
300 callback,
301 settings.cursor_capture_settings,
302 settings.draw_border_settings,
303 settings.secondary_window_settings,
304 settings.minimum_update_interval_settings,
305 settings.dirty_region_settings,
306 settings.color_format,
307 thread_id,
308 result.clone(),
309 )
310 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
311 capture.start_capture().map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
312
313 let mut message = MSG::default();
315 unsafe {
316 while GetMessageW(&mut message, None, 0, 0).as_bool() {
317 let _ = TranslateMessage(&message);
318 DispatchMessageW(&message);
319 }
320 }
321
322 let async_action = controller
324 .ShutdownQueueAsync()
325 .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
326
327 async_action
328 .SetCompleted(&AsyncActionCompletedHandler::new(move |_, _| -> WindowsResult<()> {
329 unsafe { PostQuitMessage(0) };
330 Ok(())
331 }))
332 .map_err(|_| GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler)?;
333
334 let mut message = MSG::default();
336 unsafe {
337 while GetMessageW(&mut message, None, 0, 0).as_bool() {
338 let _ = TranslateMessage(&message);
339 DispatchMessageW(&message);
340 }
341 }
342
343 capture.stop_capture();
345
346 let result = result.lock().take();
351 if let Some(e) = result {
352 return Err(GraphicsCaptureApiError::FrameHandlerError(e));
353 }
354
355 Ok(())
356 }
357
358 #[inline]
368 fn start_free_threaded<T: TryIntoCaptureItemWithType + Send + 'static>(
369 settings: Settings<Self::Flags, T>,
370 ) -> Result<CaptureControl<Self, Self::Error>, GraphicsCaptureApiError<Self::Error>>
371 where
372 Self: Send + 'static,
373 <Self as GraphicsCaptureApiHandler>::Flags: Send,
374 {
375 let (halt_sender, halt_receiver) = mpsc::channel::<Arc<AtomicBool>>();
376 let (callback_sender, callback_receiver) = mpsc::channel::<Arc<Mutex<Self>>>();
377
378 let thread_handle =
379 thread::spawn(move || -> Result<(), GraphicsCaptureApiError<Self::Error>> {
380 static INIT_MTA: OnceLock<()> = OnceLock::new();
382 INIT_MTA.get_or_init(|| {
383 unsafe {
384 CoIncrementMTAUsage().expect("Failed to increment MTA usage");
385 };
386 });
387
388 match unsafe { RoInitialize(RO_INIT_MULTITHREADED) } {
389 Ok(_) => (),
390 Err(e) => {
391 if e.code() == S_FALSE {
392 } else {
394 return Err(GraphicsCaptureApiError::FailedToInitWinRT);
395 }
396 }
397 }
398
399 let options = DispatcherQueueOptions {
401 dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
402 threadType: DQTYPE_THREAD_CURRENT,
403 apartmentType: DQTAT_COM_NONE,
404 };
405 let controller = unsafe {
406 CreateDispatcherQueueController(options).map_err(|_| {
407 GraphicsCaptureApiError::FailedToCreateDispatcherQueueController
408 })?
409 };
410
411 let thread_id = unsafe { GetCurrentThreadId() };
413
414 let (d3d_device, d3d_device_context) = create_d3d_device()?;
416
417 let result = Arc::new(Mutex::new(None));
419
420 let ctx = Context {
421 flags: settings.flags,
422 device: d3d_device.clone(),
423 device_context: d3d_device_context.clone(),
424 };
425
426 let callback = Arc::new(Mutex::new(
427 Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?,
428 ));
429
430 let (item, item_type) = settings
432 .item
433 .try_into_capture_item()
434 .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
435
436 let mut capture = GraphicsCaptureApi::new(
437 d3d_device,
438 d3d_device_context,
439 item,
440 item_type,
441 callback.clone(),
442 settings.cursor_capture_settings,
443 settings.draw_border_settings,
444 settings.secondary_window_settings,
445 settings.minimum_update_interval_settings,
446 settings.dirty_region_settings,
447 settings.color_format,
448 thread_id,
449 result.clone(),
450 )
451 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
452
453 capture
454 .start_capture()
455 .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
456
457 let halt_handle = capture.halt_handle();
459 halt_sender.send(halt_handle).unwrap();
460
461 callback_sender.send(callback).unwrap();
463
464 let mut message = MSG::default();
466 unsafe {
467 while GetMessageW(&mut message, None, 0, 0).as_bool() {
468 let _ = TranslateMessage(&message);
469 DispatchMessageW(&message);
470 }
471 }
472
473 let async_action = controller
475 .ShutdownQueueAsync()
476 .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
477
478 async_action
479 .SetCompleted(&AsyncActionCompletedHandler::new(
480 move |_, _| -> Result<(), windows::core::Error> {
481 unsafe { PostQuitMessage(0) };
482 Ok(())
483 },
484 ))
485 .map_err(|_| {
486 GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler
487 })?;
488
489 let mut message = MSG::default();
491 unsafe {
492 while GetMessageW(&mut message, None, 0, 0).as_bool() {
493 let _ = TranslateMessage(&message);
494 DispatchMessageW(&message);
495 }
496 }
497
498 capture.stop_capture();
500
501 let result = result.lock().take();
506 if let Some(e) = result {
507 return Err(GraphicsCaptureApiError::FrameHandlerError(e));
508 }
509
510 Ok(())
511 });
512
513 let Ok(halt_handle) = halt_receiver.recv() else {
514 match thread_handle.join() {
515 Ok(result) => return Err(result.err().unwrap()),
516 Err(_) => {
517 return Err(GraphicsCaptureApiError::FailedToJoinThread);
518 }
519 }
520 };
521
522 let Ok(callback) = callback_receiver.recv() else {
523 match thread_handle.join() {
524 Ok(result) => return Err(result.err().unwrap()),
525 Err(_) => {
526 return Err(GraphicsCaptureApiError::FailedToJoinThread);
527 }
528 }
529 };
530
531 Ok(CaptureControl::new(thread_handle, halt_handle, callback))
532 }
533
534 fn new(ctx: Context<Self::Flags>) -> Result<Self, Self::Error>;
545
546 fn on_frame_arrived(
557 &mut self,
558 frame: &mut Frame,
559 capture_control: InternalCaptureControl,
560 ) -> Result<(), Self::Error>;
561
562 #[inline]
568 fn on_closed(&mut self) -> Result<(), Self::Error> {
569 Ok(())
570 }
571}