1use std::ffi::CString;
2use std::process::{ChildStderr, ChildStdin, ChildStdout, ExitStatus};
3use std::sync::mpsc::{self, Receiver, Sender};
4use std::thread::{self, JoinHandle};
5
6use bitflags::bitflags;
7use uuid::Uuid;
8use windows::core::{IUnknown, Interface, GUID, PCSTR, PCWSTR};
9use windows::Win32::Foundation::{CloseHandle, GetLastError, HANDLE};
10use windows::Win32::Networking::WinSock::WSAStartup;
11use windows::Win32::Storage::FileSystem::{
12 GetFileType, FILE_TYPE_CHAR, FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE,
13 FILE_TYPE_UNKNOWN,
14};
15use windows::Win32::System::Com::{
16 CoInitializeEx, CoInitializeSecurity, CoTaskMemFree, CoUninitialize, IClientSecurity,
17 COINIT_MULTITHREADED, EOAC_DYNAMIC_CLOAKING, EOAC_STATIC_CLOAKING,
18 EOLE_AUTHENTICATION_CAPABILITIES, RPC_C_AUTHN_LEVEL, RPC_C_AUTHN_LEVEL_CONNECT,
19 RPC_C_IMP_LEVEL, RPC_C_IMP_LEVEL_IDENTIFY, RPC_C_IMP_LEVEL_IMPERSONATE,
20};
21use windows::Win32::System::IO::DeviceIoControl;
22use wsl_com_api_sys::{
23 constants::*, get_lxss_user_session, ILxssUserSession, LxssHandleType, LXSS_ENUMERATE_INFO,
24 LXSS_HANDLE, LXSS_STD_HANDLES,
25};
26
27use wsl_com_api_sys::interop::LXBUS_IPC_LX_PROCESS_WAIT_FOR_TERMINATION_PARAMETERS;
28
29mod error;
30pub use error::*;
31mod interop;
32
#[cfg(windows)]
use std::os::windows::io::AsRawHandle;

/// Returns the raw Win32 handle behind `handle` without transferring ownership.
///
/// The returned `HANDLE` is only valid while `handle` is alive.
#[cfg(windows)]
fn to_handle(handle: &impl AsRawHandle) -> HANDLE {
    let raw = handle.as_raw_handle();
    HANDLE(raw as isize)
}

/// Takes ownership of `handle` and wraps it in `T` (for example `ChildStdin`).
///
/// NOTE(review): this is a safe function, but the resulting `T` closes `handle` when
/// dropped, so callers must pass a valid, open handle that nothing else will close.
#[cfg(windows)]
fn from_handle<T: From<std::os::windows::io::OwnedHandle>>(handle: HANDLE) -> T {
    use std::os::windows::io::FromRawHandle;

    // SAFETY: relies on the caller contract above (valid handle, ownership transferred).
    let owned = unsafe { std::os::windows::io::OwnedHandle::from_raw_handle(handle.0 as _) };
    T::from(owned)
}
50
// Unix stand-ins for the Windows handle helpers. Both always panic: the message states
// that only Windows is supported.
#[cfg(unix)]
use std::os::fd::AsRawFd as AsRawHandle;

use crate::interop::Interop;
/// Unix stub of `to_handle`; always panics.
#[cfg(unix)]
fn to_handle(_: &impl AsRawHandle) -> HANDLE {
    unreachable!("This should never be called on Unix: we only support Windows");
}
/// Unix stub of `from_handle`; always panics.
#[cfg(unix)]
fn from_handle<T>(_: HANDLE) -> T {
    unreachable!("This should never be called on Unix: we only support Windows");
}
63
64unsafe fn wait_for_wsl_process(process_handle: HANDLE, timeout_ms: u32) -> Result<u32, WslError> {
66 let mut parameters = LXBUS_IPC_LX_PROCESS_WAIT_FOR_TERMINATION_PARAMETERS {
67 Input: wsl_com_api_sys::interop::LXBUS_IPC_LX_PROCESS_WAIT_FOR_TERMINATION_INPUT {
68 TimeoutMs: timeout_ms,
69 },
70 };
71
72 DeviceIoControl(
73 process_handle,
74 LXBUS_IPC_LX_PROCESS_IOCTL_WAIT_FOR_TERMINATION,
75 Some(¶meters.Input as *const _ as *const _),
76 std::mem::size_of::<wsl_com_api_sys::interop::LXBUS_IPC_LX_PROCESS_WAIT_FOR_TERMINATION_INPUT>(
77 ) as u32,
78 Some(&mut parameters.Output as *mut _ as *mut _),
79 std::mem::size_of::<
80 wsl_com_api_sys::interop::LXBUS_IPC_LX_PROCESS_WAIT_FOR_TERMINATION_OUTPUT,
81 >() as u32,
82 None,
83 None,
84 )?;
85
86 Ok((parameters.Output.ExitStatus as u32) >> 8)
88}
89
90fn validate_file_handle(
92 name: &str,
93 handle: HANDLE,
94 expected_type: windows::Win32::Storage::FileSystem::FILE_TYPE,
95) -> Result<(), WslError> {
96 let file_type = unsafe { GetFileType(handle) };
97 if file_type == FILE_TYPE_UNKNOWN {
98 return Err(windows::core::Error::new(
99 wsl_com_api_sys::error::WSL_E_INVALID_USAGE,
100 format!(
101 "{} ({:x}) is not a valid file handle: {:?}",
102 name,
103 handle.0,
104 unsafe { GetLastError() }
105 ),
106 )
107 .into());
108 }
109 let type_to_string = |file_type: windows::Win32::Storage::FileSystem::FILE_TYPE| match file_type
110 {
111 FILE_TYPE_DISK => "file",
112 FILE_TYPE_PIPE => "pipe",
113 FILE_TYPE_CHAR => "character device",
114 FILE_TYPE_REMOTE => "remote file",
115 FILE_TYPE_UNKNOWN => "unknown type",
116 _ => "invalid type",
117 };
118
119 if file_type != expected_type {
120 let expected_type_name = type_to_string(expected_type);
121 return Err(windows::core::Error::new(
122 wsl_com_api_sys::error::WSL_E_INVALID_USAGE,
123 format!(
124 "{} ({:x}) must be a {} (got a {})",
125 name,
126 handle.0,
127 expected_type_name,
128 type_to_string(file_type)
129 ),
130 )
131 .into());
132 }
133 Ok(())
134}
135
136struct CoMultithreadedInterface<T: Interface>(T);
137
138unsafe impl<T: Interface> Send for CoMultithreadedInterface<T> {}
139
140pub struct Wsl2 {
142 sender: Sender<Box<dyn FnOnce(&ILxssUserSession) + Send>>,
144 session: CoMultithreadedInterface<ILxssUserSession>,
146 _background_thread: JoinHandle<()>,
148}
149
150impl Wsl2 {
151 pub fn new() -> Result<Self, WslError> {
153 let (sender, receiver) = mpsc::channel();
154 let (tx_init, rx_init) = mpsc::channel();
155
156 let background_thread = thread::spawn(move || {
157 Self::com_thread_worker(receiver, tx_init);
158 });
159
160 let session = rx_init
161 .recv()
162 .expect("thread died (init)?")
163 .map_err(WslError::from)?;
164
165 Ok(Wsl2 {
166 sender,
167 session,
168 _background_thread: background_thread,
169 })
170 }
171
172 fn com_thread_worker(
174 receiver: Receiver<Box<dyn FnOnce(&ILxssUserSession) + Send>>,
175 initialized: Sender<windows::core::Result<CoMultithreadedInterface<ILxssUserSession>>>,
176 ) {
177 unsafe {
178 let mut wsa_data = std::mem::zeroed();
182 let result = WSAStartup(0x0202, &mut wsa_data);
183 if result != 0 {
184 initialized
185 .send(Err(windows::core::Error::new(
186 wsl_com_api_sys::error::WSL_E_INVALID_USAGE,
187 format!("WSAStartup failed: 0x{:x}", result),
188 )
189 .into()))
190 .expect("thread died (init tx)?");
191 return;
192 }
193
194 let result = CoInitializeEx(None, COINIT_MULTITHREADED);
196 if result.is_err() {
197 initialized
198 .send(Err(result.into()))
199 .expect("thread died (init tx)?");
200 return;
201 }
202
203 let result = CoInitializeSecurity(
204 None,
205 -1,
206 None,
207 None,
208 RPC_C_AUTHN_LEVEL_CONNECT,
209 RPC_C_IMP_LEVEL_IDENTIFY,
210 None,
211 EOAC_STATIC_CLOAKING,
212 None,
213 );
214 if let Err(e) = result {
215 CoUninitialize();
216 initialized.send(Err(e)).expect("thread died (init tx)?");
217 return;
218 }
219 }
220
221 let session = match unsafe { get_lxss_user_session() } {
223 Ok(session) => session,
224 Err(e) => {
225 unsafe {
226 CoUninitialize();
227 }
228 initialized
229 .send(Err(e.into()))
230 .expect("thread died (init tx)?");
231 return;
232 }
233 };
234
235 let result = Self::set_session_blanket(&session);
236 if let Err(e) = result {
237 unsafe {
238 CoUninitialize();
239 }
240 initialized.send(Err(e)).expect("thread died (init tx)?");
241 return;
242 }
243
244 initialized
245 .send(Ok(CoMultithreadedInterface(session.clone())))
246 .expect("thread died (init tx)?");
247
248 for request in receiver {
250 request(&session);
251 }
252
253 unsafe {
255 CoUninitialize();
256 }
257 }
258
259 fn set_session_blanket(session: &ILxssUserSession) -> windows::core::Result<()> {
268 let client_security: IClientSecurity = session.cast()?;
269
270 let mut authn_svc = 0;
271 let mut authz_svc = 0;
272 let mut authn_lvl: RPC_C_AUTHN_LEVEL = RPC_C_AUTHN_LEVEL(0);
273 let mut imp_lvl: RPC_C_IMP_LEVEL = RPC_C_IMP_LEVEL(0);
274 let mut capabilities: EOLE_AUTHENTICATION_CAPABILITIES =
275 EOLE_AUTHENTICATION_CAPABILITIES(0);
276 unsafe {
277 client_security.QueryBlanket::<&IUnknown>(
278 &session.0,
279 std::ptr::from_mut(&mut authn_svc),
280 Some(std::ptr::from_mut(&mut authz_svc)),
281 std::ptr::null_mut(),
282 Some(std::ptr::from_mut(&mut authn_lvl)),
283 Some(std::ptr::from_mut(&mut imp_lvl)),
284 std::ptr::null_mut(),
285 Some(std::ptr::from_mut(&mut capabilities.0) as _),
286 )?;
287 }
288
289 capabilities.0 &= !EOAC_STATIC_CLOAKING.0;
290 capabilities.0 |= EOAC_DYNAMIC_CLOAKING.0;
291
292 unsafe {
293 client_security.SetBlanket::<&IUnknown, PCWSTR>(
294 &session.0,
295 authn_svc,
296 authz_svc,
297 PCWSTR::null(),
298 authn_lvl,
299 RPC_C_IMP_LEVEL_IMPERSONATE,
300 None,
301 capabilities,
302 )?;
303 }
304
305 Ok(())
306 }
307
308 fn execute_thread<F, T>(&self, f: F) -> Result<T, WslError>
310 where
311 F: FnOnce(&ILxssUserSession) -> Result<T, WslError> + Send + 'static,
312 T: Send + 'static,
313 {
314 let (tx, rx) = mpsc::channel();
315 self.sender
316 .send(Box::new(move |session| {
317 let result = f(session);
318 _ = tx.send(result);
319 }))
320 .expect("thread died (tx)?");
321 rx.recv().expect("thread died (rx)?")
322 }
323
324 fn execute<F, T>(&self, f: F) -> Result<T, WslError>
326 where
327 F: FnOnce(&ILxssUserSession) -> Result<T, WslError> + Send + 'static,
328 T: Send + 'static,
329 {
330 f(&self.session.0)
331 }
332
333 pub fn shutdown(self, force: bool) -> Result<(), WslError> {
335 self.execute_thread(move |session| unsafe {
336 session.Shutdown(force as i32)?;
337 Ok(())
338 })
339 }
340
341 pub fn get_default_distribution(&self) -> Result<Uuid, WslError> {
343 self.execute(|session| unsafe {
344 Ok(session
345 .GetDefaultDistribution()
346 .map(|guid| Uuid::from_u128(guid.to_u128()))?)
347 })
348 }
349
350 pub fn launch(
354 &self,
355 distro_guid: Uuid,
356 command: &str,
357 args: &[&str],
358 cwd: Option<&str>,
359 username: &str,
360 ) -> Result<WslProcess, WslError> {
361 let username = widestring::U16CString::from_str_truncate(username);
362 let command = CString::new(command).unwrap();
363 let cwd = cwd.map(|cwd| widestring::U16CString::from_str_truncate(cwd));
364 let nt_path = widestring::U16CString::from_str_truncate(
365 std::env::current_dir()
366 .unwrap_or_default()
367 .to_str()
368 .unwrap(),
369 );
370 let args = args
371 .iter()
372 .map(|arg| CString::new(*arg).unwrap())
373 .collect::<Vec<_>>();
374
375 let (stdin_r, stdin_w) = std::io::pipe().unwrap();
376 let (stdout_r, stdout_w) = std::io::pipe().unwrap();
377 let (stderr_r, stderr_w) = std::io::pipe().unwrap();
378
379 let pipe = (
380 to_handle(&stdin_r),
381 to_handle(&stdout_w),
382 to_handle(&stderr_w),
383 );
384
385 let handles = LXSS_STD_HANDLES {
386 StdIn: LXSS_HANDLE {
387 Handle: pipe.0 .0 as _,
388 HandleType: LxssHandleType::LxssHandleInput,
389 },
390 StdOut: LXSS_HANDLE {
391 Handle: pipe.1 .0 as _,
392 HandleType: LxssHandleType::LxssHandleOutput,
393 },
394 StdErr: LXSS_HANDLE {
395 Handle: pipe.2 .0 as _,
396 HandleType: LxssHandleType::LxssHandleOutput,
397 },
398 };
399
400 std::mem::forget(stderr_w);
401 std::mem::forget(stdout_w);
402 std::mem::forget(stdin_r);
403
404 self.execute(move |session| unsafe {
405 let arg_ptrs = args
406 .iter()
407 .map(|arg| arg.to_bytes_with_nul().as_ptr())
408 .collect::<Vec<_>>();
409 let result = session.CreateLxProcess(
410 GUID::from_u128(distro_guid.as_u128()),
411 PCSTR::from_raw(command.as_ptr() as *const u8),
412 args.len() as u32,
413 arg_ptrs.as_ptr() as *const PCSTR,
414 PCWSTR::from_raw(cwd.map(|cwd| cwd.as_ptr()).unwrap_or(std::ptr::null())),
415 PCWSTR::from_raw(nt_path.as_ptr()),
416 std::ptr::null_mut(), 0, PCWSTR::from_raw(username.as_ptr()),
419 80,
420 25,
421 0,
422 std::ptr::from_ref(&handles),
423 CreateInstanceFlags::empty().bits(),
424 )?;
425
426 #[allow(unreachable_code)]
427 let process = if result.ProcessHandle.is_invalid() {
428 #[cfg(unix)]
430 #[allow(unused)]
431 let tcp = { unreachable!("Unsupported platform") };
432
433 #[cfg(windows)]
434 let tcp = {
435 use std::net::TcpStream;
436 use std::os::windows::io::FromRawSocket;
437 TcpStream::from_raw_socket(result.InteropSocket.0 as _)
438 };
439
440 WslProcess {
441 stdin: Some(from_handle(result.StandardIn)),
442 stdout: Some(from_handle(result.StandardOut)),
443 stderr: Some(from_handle(result.StandardErr)),
444 pipe,
445 handle: WslProcessInner::WSL2(Interop::new(tcp), result.CommunicationChannel),
446 }
447 } else {
448 let process = WslProcess {
449 stdin: Some(from_handle(to_handle(&stdin_w))),
450 stdout: Some(from_handle(to_handle(&stdout_r))),
451 stderr: Some(from_handle(to_handle(&stderr_r))),
452 pipe,
453 handle: WslProcessInner::WSL1(result.ProcessHandle),
454 };
455
456 _ = CloseHandle(result.ServerHandle);
458
459 std::mem::forget(stdin_w);
460 std::mem::forget(stdout_r);
461 std::mem::forget(stderr_r);
462
463 process
464 };
465
466 Ok(process)
467 })
468 }
469
470 pub fn enumerate_distributions(&self) -> Result<Vec<Distribution>, WslError> {
472 self.execute(|session| unsafe {
473 let (count, distros) = session.EnumerateDistributions()?;
474 let distros_copy = {
475 let slice = std::slice::from_raw_parts(distros, count as usize);
476 slice
477 .iter()
478 .map(|distro| Distribution::from(distro))
479 .collect()
480 };
481 CoTaskMemFree(Some(distros as _));
482 Ok(distros_copy)
483 })
484 }
485
486 pub fn export_distribution(
488 &self,
489 distro_guid: Uuid,
490 file: impl AsRawHandle,
491 stderr: impl AsRawHandle,
492 flags: ExportFlags,
493 ) -> Result<(), WslError> {
494 let file_handle = to_handle(&file);
495 let stderr_handle = to_handle(&stderr);
496
497 let res = self.execute(move |session| unsafe {
498 validate_file_handle("stderr_handle", stderr_handle, FILE_TYPE_PIPE)?;
500 validate_file_handle("file_handle", file_handle, FILE_TYPE_DISK)?;
501
502 session.ExportDistribution(
503 GUID::from_u128(distro_guid.as_u128()),
504 file_handle,
505 stderr_handle,
506 flags.bits(),
507 )?;
508 Ok(())
509 });
510
511 drop(file);
512 drop(stderr);
513 res
514 }
515
516 pub fn register_distribution(
519 &self,
520 name: &str,
521 version: Version,
522 file: impl AsRawHandle,
523 stderr: impl AsRawHandle,
524 flags: ImportFlags,
525 ) -> Result<(Uuid, String), WslError> {
526 let file_handle = to_handle(&file);
527 let stderr_handle = to_handle(&stderr);
528 let wide_name = widestring::U16CString::from_str_truncate(name);
529
530 let res = self.execute(move |session| unsafe {
531 validate_file_handle("stderr_handle", stderr_handle, FILE_TYPE_PIPE)?;
533 validate_file_handle("file_handle", file_handle, FILE_TYPE_DISK)?;
534
535 let result = session.RegisterDistribution(
536 PCWSTR::from_raw(wide_name.as_ptr()),
537 version.into(),
538 file_handle,
539 stderr_handle,
540 PCWSTR::null(),
541 flags.bits(),
542 0,
543 PCWSTR::null(),
544 )?;
545 let name = result.InstalledName.to_string().unwrap_or_default();
546 CoTaskMemFree(Some(result.InstalledName.0 as _));
547 Ok((Uuid::from_u128(result.Guid.to_u128()), name))
548 });
549
550 drop(file);
551 drop(stderr);
552 res
553 }
554
555 pub fn set_version(
556 &self,
557 distribution: Uuid,
558 version: Version,
559 stderr: impl AsRawHandle,
560 ) -> Result<(), WslError> {
561 let handle = to_handle(&stderr);
562 let res = self.execute(move |session| unsafe {
563 session.SetVersion(
564 GUID::from_u128(distribution.as_u128()),
565 version.into(),
566 handle,
567 )?;
568 Ok(())
569 });
570
571 drop(stderr);
572 res
573 }
574}
575
576impl Drop for Wsl2 {
577 fn drop(&mut self) {
578 }
581}
582
583#[derive(Debug)]
584pub struct Distribution {
585 pub name: String,
586 pub uuid: Uuid,
587 pub version: Version,
588}
589
590impl From<&LXSS_ENUMERATE_INFO> for Distribution {
591 fn from(info: &LXSS_ENUMERATE_INFO) -> Self {
592 let name = unsafe {
593 PCWSTR::from_raw(info.DistroName.as_ptr())
594 .to_string()
595 .unwrap()
596 };
597 Self {
598 name: name.to_string(),
599 uuid: Uuid::from_u128(info.DistroGuid.to_u128()),
600 version: match info.Version {
601 1 => Version::WSL1,
602 2 => Version::WSL2,
603 _ => Version::Unknown(info.Version),
604 },
605 }
606 }
607}
608
/// WSL version of a distribution.
///
/// Values of this type are passed to `RegisterDistribution`/`SetVersion` as a `u32`, and
/// are also produced when decoding `EnumerateDistributions` entries.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare variants in
/// declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Version {
    /// Raw value `0`.
    /// NOTE(review): a reported `0` from enumeration decodes to `Unknown(0)`, not `Legacy`.
    Legacy,
    /// Raw value `1`.
    WSL1,
    /// Raw value `2`.
    WSL2,
    /// Any other raw value, preserved as-is.
    Unknown(u32),
}
616
617impl Into<u32> for Version {
618 fn into(self) -> u32 {
619 match self {
620 Version::Legacy => 0,
621 Version::WSL1 => 1,
622 Version::WSL2 => 2,
623 Version::Unknown(v) => v,
624 }
625 }
626}
627
bitflags! {
    /// Flags for [`Wsl2::export_distribution`], passed as raw bits to
    /// `ExportDistribution`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ExportFlags: u32 {
        const VHD = LXSS_EXPORT_DISTRO_FLAGS_VHD;
        const GZIP = LXSS_EXPORT_DISTRO_FLAGS_GZIP;
        const XZIP = LXSS_EXPORT_DISTRO_FLAGS_XZIP;
        const VERBOSE = LXSS_EXPORT_DISTRO_FLAGS_VERBOSE;
    }

    /// Flags for [`Wsl2::register_distribution`], passed as raw bits to
    /// `RegisterDistribution`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ImportFlags: u32 {
        const VHD = LXSS_IMPORT_DISTRO_FLAGS_VHD;
        const CREATE_SHORTCUT = LXSS_IMPORT_DISTRO_FLAGS_CREATE_SHORTCUT;
        const NO_OOBE = LXSS_IMPORT_DISTRO_FLAGS_NO_OOBE;
        const FIXED_VHD = LXSS_IMPORT_DISTRO_FLAGS_FIXED_VHD;
    }

    /// Flags passed as raw bits to `CreateLxProcess`; [`Wsl2::launch`] currently passes
    /// an empty set.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct CreateInstanceFlags: u32 {
        const ALLOW_FS_UPGRADE = LXSS_CREATE_INSTANCE_FLAGS_ALLOW_FS_UPGRADE;
        const OPEN_EXISTING = LXSS_CREATE_INSTANCE_FLAGS_OPEN_EXISTING;
        const IGNORE_CLIENT = LXSS_CREATE_INSTANCE_FLAGS_IGNORE_CLIENT;
        const USE_SYSTEM_DISTRO = LXSS_CREATE_INSTANCE_FLAGS_USE_SYSTEM_DISTRO;
        const SHELL_LOGIN = LXSS_CREATE_INSTANCE_FLAGS_SHELL_LOGIN;
    }
}
655
/// A process started by [`Wsl2::launch`].
///
/// Dropping it closes the process/channel handle and the raw handles in `pipe`.
#[derive(Debug)]
pub struct WslProcess {
    /// Writable end connected to the process's standard input.
    pub stdin: Option<ChildStdin>,
    /// Readable end connected to the process's standard output.
    pub stdout: Option<ChildStdout>,
    /// Readable end connected to the process's standard error.
    pub stderr: Option<ChildStderr>,
    /// Raw pipe ends handed to the process in `CreateLxProcess` (stdin read end,
    /// stdout and stderr write ends); closed in `Drop`.
    pipe: (HANDLE, HANDLE, HANDLE),
    /// Version-specific tracking state used by `wait` and `Drop`.
    handle: WslProcessInner,
}
664
/// Builds an [`ExitStatus`] from `exit_code` via the platform's `ExitStatusExt::from_raw`.
///
/// On Windows the value is used as the process exit code. On Unix it is interpreted as
/// a raw wait status.
/// NOTE(review): callers pass an already-decoded exit code, so the Unix interpretation
/// does not round-trip.
fn u32_to_exit_status(exit_code: u32) -> ExitStatus {
    #[cfg(windows)]
    let status = <ExitStatus as std::os::windows::process::ExitStatusExt>::from_raw(exit_code);
    #[cfg(unix)]
    let status =
        <ExitStatus as std::os::unix::process::ExitStatusExt>::from_raw(exit_code as i32);
    status
}
674
675impl WslProcess {
676 pub fn wait(self) -> Result<ExitStatus, WslError> {
677 match &self.handle {
678 WslProcessInner::WSL1(handle) => {
679 let exit_code = unsafe { wait_for_wsl_process(*handle, u32::MAX)? };
681 Ok(u32_to_exit_status(exit_code))
682 }
683 WslProcessInner::WSL2(interop, _) => {
684 let exit = interop.recv_exit_code();
685 Ok(exit.map(u32_to_exit_status).unwrap_or_default())
686 }
687 }
688 }
689}
690
691impl Drop for WslProcess {
692 fn drop(&mut self) {
693 match self.handle {
694 WslProcessInner::WSL2(_, handle) => unsafe { _ = CloseHandle(handle) },
695 WslProcessInner::WSL1(handle) => unsafe {
696 _ = CloseHandle(handle);
697 },
698 }
699
700 unsafe {
701 _ = CloseHandle(self.pipe.0);
702 _ = CloseHandle(self.pipe.1);
703 _ = CloseHandle(self.pipe.2);
704 }
705 }
706}
707
/// How a launched process is tracked; the variant is chosen in [`Wsl2::launch`] by
/// whether `CreateLxProcess` returned a valid process handle.
#[derive(Debug)]
enum WslProcessInner {
    /// Process handle returned by `CreateLxProcess`. `wait` waits on it via
    /// `wait_for_wsl_process`, and it is closed on drop.
    WSL1(HANDLE),
    /// Interop connection built from the returned interop socket (used to receive the
    /// exit code), plus the communication-channel handle, which is closed on drop.
    WSL2(Interop, HANDLE),
}