use std::{mem, ptr::NonNull, time::Duration, u8};
use libc::{c_int, c_uchar, c_uint};
use libusb1_sys::{constants::*, *};
use crate::{
config_descriptor::ConfigDescriptor,
device::{self, Device},
device_descriptor::DeviceDescriptor,
error::{self, Error},
fields::{request_type, Direction, Recipient, RequestType},
interface_descriptor::InterfaceDescriptor,
language::Language,
UsbContext,
};
/// Fixed-size bit set recording which interface numbers (`0..=255`) are claimed.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct ClaimedInterfaces {
    /// 256 membership bits: `inner[0]` holds interfaces 0..=127, `inner[1]` holds 128..=255.
    inner: [u128; 2],
}

impl ClaimedInterfaces {
    /// Creates a set with no interface marked.
    fn new() -> Self {
        ClaimedInterfaces { inner: [0; 2] }
    }

    /// Maps an interface number to `(word index, single-bit mask within that word)`.
    fn get_index_and_mask(interface: u8) -> (usize, u128) {
        let word = usize::from(interface / 128);
        let bit = interface % 128;
        (word, 1u128 << bit)
    }

    /// Marks `interface` as present; inserting an already-present value is a no-op.
    fn insert(&mut self, interface: u8) {
        let (word, bit) = Self::get_index_and_mask(interface);
        self.inner[word] |= bit;
    }

    /// Clears `interface`; removing an absent value is a no-op.
    fn remove(&mut self, interface: u8) {
        let (word, bit) = Self::get_index_and_mask(interface);
        self.inner[word] &= !bit;
    }

    /// Returns `true` when `interface` is marked in the set.
    fn contains(&self, interface: u8) -> bool {
        let (word, bit) = Self::get_index_and_mask(interface);
        (self.inner[word] & bit) != 0
    }

    /// Number of interfaces currently marked (population count over both words).
    fn size(&self) -> usize {
        let set_bits: u32 = self.inner.iter().map(|word| word.count_ones()).sum();
        set_bits as usize
    }

    /// Iterates over the marked interface numbers in ascending order.
    fn iter(&self) -> ClaimedInterfacesIter {
        ClaimedInterfacesIter::new(self)
    }
}

/// Ascending iterator over the members of a `ClaimedInterfaces` set.
struct ClaimedInterfacesIter<'a> {
    /// Next interface number to probe; `u16` so it can step one past `u8::MAX`.
    index: u16,
    /// Members not yet yielded; backs the exact `size_hint`.
    remaining: usize,
    /// Set being iterated.
    source: &'a ClaimedInterfaces,
}

impl<'a> ClaimedInterfacesIter<'a> {
    /// Starts an iteration at interface 0, with `remaining` taken from the set's current size.
    fn new<'source>(source: &'source ClaimedInterfaces) -> ClaimedInterfacesIter<'source> {
        let remaining = source.size();
        ClaimedInterfacesIter {
            index: 0,
            remaining,
            source,
        }
    }
}

impl<'a> Iterator for ClaimedInterfacesIter<'a> {
    type Item = u8;

    /// Probes interface numbers one by one and yields the next one that is marked.
    fn next(&mut self) -> Option<u8> {
        loop {
            if self.index > u16::from(u8::MAX) {
                return None;
            }
            let candidate = self.index as u8;
            self.index += 1;
            if self.source.contains(candidate) {
                self.remaining -= 1;
                return Some(candidate);
            }
        }
    }

    /// Exact bounds: both ends equal the number of members not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}
/// A handle to an open USB device.
///
/// Wraps a raw `libusb_device_handle` together with the context it is used with and
/// the set of interfaces claimed through this handle. Dropping the value releases
/// those interfaces and closes the raw handle (see the `Drop` impl).
#[derive(Eq, PartialEq)]
pub struct DeviceHandle<T: UsbContext> {
/// Context associated with the handle; cloned into the `Device` returned by `device()`.
context: T,
/// Raw libusb handle; passed to `libusb_close` on drop.
handle: NonNull<libusb_device_handle>,
/// Interfaces recorded by `claim_interface` and not yet removed by `release_interface`.
interfaces: ClaimedInterfaces,
}
impl<T: UsbContext> Drop for DeviceHandle<T> {
fn drop(&mut self) {
unsafe {
for iface in self.interfaces.iter() {
libusb_release_interface(self.handle.as_ptr(), iface as c_int);
}
libusb_close(self.handle.as_ptr());
}
}
}
// These impls assert that a `DeviceHandle` may be moved to, and shared between, threads
// even though it holds a raw `NonNull<libusb_device_handle>`.
// NOTE(review): nothing in this file demonstrates that; soundness presumably rests on
// libusb's own thread-safety guarantees for device handles and on `T` itself being
// `Send + Sync` (not implied by the `T: UsbContext` bound as written here) — confirm.
unsafe impl<T: UsbContext> Send for DeviceHandle<T> {}
unsafe impl<T: UsbContext> Sync for DeviceHandle<T> {}
impl<T: UsbContext> DeviceHandle<T> {
/// Returns the raw `libusb_device_handle` pointer wrapped by this handle.
///
/// The pointer stays owned by `self`: the same pointer is passed to `libusb_close`
/// when the `DeviceHandle` is dropped.
pub fn as_raw(&self) -> *mut libusb_device_handle {
self.handle.as_ptr()
}
pub fn device(&self) -> Device<T> {
unsafe {
device::Device::from_libusb(
self.context.clone(),
std::ptr::NonNull::new_unchecked(libusb_get_device(self.handle.as_ptr())),
)
}
}
/// Wraps an existing raw libusb device handle together with its context.
///
/// The returned `DeviceHandle` starts with no interfaces recorded as claimed.
///
/// # Safety
///
/// `handle` is subsequently passed to libusb calls and to `libusb_close` when the
/// returned value is dropped.
/// NOTE(review): presumably `handle` must be a valid, open handle that belongs to
/// `context` and is not closed by anyone else — confirm against callers.
pub unsafe fn from_libusb(
context: T,
handle: NonNull<libusb_device_handle>,
) -> DeviceHandle<T> {
DeviceHandle {
context,
handle,
interfaces: ClaimedInterfaces::new(),
}
}
pub fn active_configuration(&self) -> crate::Result<u8> {
let mut config = mem::MaybeUninit::<c_int>::uninit();
try_unsafe!(libusb_get_configuration(
self.handle.as_ptr(),
config.as_mut_ptr()
));
Ok(unsafe { config.assume_init() } as u8)
}
pub fn set_active_configuration(&mut self, config: u8) -> crate::Result<()> {
try_unsafe!(libusb_set_configuration(
self.handle.as_ptr(),
c_int::from(config)
));
Ok(())
}
pub fn unconfigure(&mut self) -> crate::Result<()> {
try_unsafe!(libusb_set_configuration(self.handle.as_ptr(), -1));
Ok(())
}
/// Resets the device by calling `libusb_reset_device` on this handle.
///
/// # Errors
///
/// Returns whatever error `try_unsafe!` produces when the libusb call fails.
pub fn reset(&mut self) -> crate::Result<()> {
    let raw = self.handle.as_ptr();
    try_unsafe!(libusb_reset_device(raw));
    Ok(())
}
/// Clears the halt condition on `endpoint` via `libusb_clear_halt`.
///
/// # Errors
///
/// Returns whatever error `try_unsafe!` produces when the libusb call fails.
pub fn clear_halt(&mut self, endpoint: u8) -> crate::Result<()> {
    let raw = self.handle.as_ptr();
    try_unsafe!(libusb_clear_halt(raw, endpoint));
    Ok(())
}
pub fn kernel_driver_active(&self, iface: u8) -> crate::Result<bool> {
match unsafe { libusb_kernel_driver_active(self.handle.as_ptr(), c_int::from(iface)) } {
0 => Ok(false),
1 => Ok(true),
err => Err(error::from_libusb(err)),
}
}
pub fn detach_kernel_driver(&mut self, iface: u8) -> crate::Result<()> {
try_unsafe!(libusb_detach_kernel_driver(
self.handle.as_ptr(),
c_int::from(iface)
));
Ok(())
}
pub fn attach_kernel_driver(&mut self, iface: u8) -> crate::Result<()> {
try_unsafe!(libusb_attach_kernel_driver(
self.handle.as_ptr(),
c_int::from(iface)
));
Ok(())
}
pub fn set_auto_detach_kernel_driver(&mut self, auto_detach: bool) -> crate::Result<()> {
try_unsafe!(libusb_set_auto_detach_kernel_driver(
self.handle.as_ptr(),
auto_detach.into()
));
Ok(())
}
/// Claims interface `iface` via `libusb_claim_interface` and records it on this handle.
///
/// Recorded interfaces are released when the handle is dropped.
///
/// # Errors
///
/// Returns whatever error `try_unsafe!` produces when the libusb call fails.
pub fn claim_interface(&mut self, iface: u8) -> crate::Result<()> {
    let raw = self.handle.as_ptr();
    try_unsafe!(libusb_claim_interface(raw, c_int::from(iface)));
    // Reached only when `try_unsafe!` did not return early.
    self.interfaces.insert(iface);
    Ok(())
}
/// Releases interface `iface` via `libusb_release_interface` and removes it from the
/// set of interfaces recorded on this handle.
///
/// # Errors
///
/// Returns whatever error `try_unsafe!` produces when the libusb call fails.
pub fn release_interface(&mut self, iface: u8) -> crate::Result<()> {
    let raw = self.handle.as_ptr();
    try_unsafe!(libusb_release_interface(raw, c_int::from(iface)));
    // Reached only when `try_unsafe!` did not return early.
    self.interfaces.remove(iface);
    Ok(())
}
pub fn set_alternate_setting(&mut self, iface: u8, setting: u8) -> crate::Result<()> {
try_unsafe!(libusb_set_interface_alt_setting(
self.handle.as_ptr(),
c_int::from(iface),
c_int::from(setting)
));
Ok(())
}
pub fn read_interrupt(
&self,
endpoint: u8,
buf: &mut [u8],
timeout: Duration,
) -> crate::Result<usize> {
if endpoint & LIBUSB_ENDPOINT_DIR_MASK != LIBUSB_ENDPOINT_IN {
return Err(Error::InvalidParam);
}
let mut transferred = mem::MaybeUninit::<c_int>::uninit();
unsafe {
match libusb_interrupt_transfer(
self.handle.as_ptr(),
endpoint,
buf.as_mut_ptr() as *mut c_uchar,
buf.len() as c_int,
transferred.as_mut_ptr(),
timeout.as_millis() as c_uint,
) {
0 => Ok(transferred.assume_init() as usize),
err if err == LIBUSB_ERROR_INTERRUPTED => {
let transferred = transferred.assume_init();
if transferred > 0 {
Ok(transferred as usize)
} else {
Err(error::from_libusb(err))
}
}
err => Err(error::from_libusb(err)),
}
}
}
pub fn write_interrupt(
&self,
endpoint: u8,
buf: &[u8],
timeout: Duration,
) -> crate::Result<usize> {
if endpoint & LIBUSB_ENDPOINT_DIR_MASK != LIBUSB_ENDPOINT_OUT {
return Err(Error::InvalidParam);
}
let mut transferred = mem::MaybeUninit::<c_int>::uninit();
unsafe {
match libusb_interrupt_transfer(
self.handle.as_ptr(),
endpoint,
buf.as_ptr() as *mut c_uchar,
buf.len() as c_int,
transferred.as_mut_ptr(),
timeout.as_millis() as c_uint,
) {
0 => Ok(transferred.assume_init() as usize),
err if err == LIBUSB_ERROR_INTERRUPTED => {
let transferred = transferred.assume_init();
if transferred > 0 {
Ok(transferred as usize)
} else {
Err(error::from_libusb(err))
}
}
err => Err(error::from_libusb(err)),
}
}
}
pub fn read_bulk(
&self,
endpoint: u8,
buf: &mut [u8],
timeout: Duration,
) -> crate::Result<usize> {
if endpoint & LIBUSB_ENDPOINT_DIR_MASK != LIBUSB_ENDPOINT_IN {
return Err(Error::InvalidParam);
}
let mut transferred = mem::MaybeUninit::<c_int>::uninit();
unsafe {
match libusb_bulk_transfer(
self.handle.as_ptr(),
endpoint,
buf.as_mut_ptr() as *mut c_uchar,
buf.len() as c_int,
transferred.as_mut_ptr(),
timeout.as_millis() as c_uint,
) {
0 => Ok(transferred.assume_init() as usize),
err if err == LIBUSB_ERROR_INTERRUPTED || err == LIBUSB_ERROR_TIMEOUT => {
let transferred = transferred.assume_init();
if transferred > 0 {
Ok(transferred as usize)
} else {
Err(error::from_libusb(err))
}
}
err => Err(error::from_libusb(err)),
}
}
}
pub fn write_bulk(&self, endpoint: u8, buf: &[u8], timeout: Duration) -> crate::Result<usize> {
if endpoint & LIBUSB_ENDPOINT_DIR_MASK != LIBUSB_ENDPOINT_OUT {
return Err(Error::InvalidParam);
}
let mut transferred = mem::MaybeUninit::<c_int>::uninit();
unsafe {
match libusb_bulk_transfer(
self.handle.as_ptr(),
endpoint,
buf.as_ptr() as *mut c_uchar,
buf.len() as c_int,
transferred.as_mut_ptr(),
timeout.as_millis() as c_uint,
) {
0 => Ok(transferred.assume_init() as usize),
err if err == LIBUSB_ERROR_INTERRUPTED || err == LIBUSB_ERROR_TIMEOUT => {
let transferred = transferred.assume_init();
if transferred > 0 {
Ok(transferred as usize)
} else {
Err(error::from_libusb(err))
}
}
err => Err(error::from_libusb(err)),
}
}
}
pub fn read_control(
&self,
request_type: u8,
request: u8,
value: u16,
index: u16,
buf: &mut [u8],
timeout: Duration,
) -> crate::Result<usize> {
if request_type & LIBUSB_ENDPOINT_DIR_MASK != LIBUSB_ENDPOINT_IN {
return Err(Error::InvalidParam);
}
let res = unsafe {
libusb_control_transfer(
self.handle.as_ptr(),
request_type,
request,
value,
index,
buf.as_mut_ptr() as *mut c_uchar,
buf.len() as u16,
timeout.as_millis() as c_uint,
)
};
if res < 0 {
Err(error::from_libusb(res))
} else {
Ok(res as usize)
}
}
pub fn write_control(
&self,
request_type: u8,
request: u8,
value: u16,
index: u16,
buf: &[u8],
timeout: Duration,
) -> crate::Result<usize> {
if request_type & LIBUSB_ENDPOINT_DIR_MASK != LIBUSB_ENDPOINT_OUT {
return Err(Error::InvalidParam);
}
let res = unsafe {
libusb_control_transfer(
self.handle.as_ptr(),
request_type,
request,
value,
index,
buf.as_ptr() as *mut c_uchar,
buf.len() as u16,
timeout.as_millis() as c_uint,
)
};
if res < 0 {
Err(error::from_libusb(res))
} else {
Ok(res as usize)
}
}
/// Reads string descriptor 0 and decodes it into the list of languages it carries.
///
/// Issues a standard device-level `GET_DESCRIPTOR` request for a string descriptor
/// with index 0 into a 255-byte buffer. After the 2-byte header, every little-endian
/// 16-bit value is converted with `crate::language::from_lang_id`.
///
/// # Errors
///
/// * Any error from `read_control`.
/// * `Error::BadDescriptor` when fewer than 2 bytes arrive, the first byte does not
///   equal the received length, or the length is odd.
pub fn read_languages(&self, timeout: Duration) -> crate::Result<Vec<Language>> {
    let mut descriptor = [0u8; 255];
    let bm_request_type = request_type(Direction::In, RequestType::Standard, Recipient::Device);
    let w_value = u16::from(LIBUSB_DT_STRING) << 8;
    let len = self.read_control(
        bm_request_type,
        LIBUSB_REQUEST_GET_DESCRIPTOR,
        w_value,
        0,
        &mut descriptor,
        timeout,
    )?;

    let well_formed = len >= 2 && descriptor[0] == len as u8 && len % 2 == 0;
    if !well_formed {
        return Err(Error::BadDescriptor);
    }
    // Header only: the device lists no languages.
    if len == 2 {
        return Ok(Vec::new());
    }

    // Skip the 2-byte header; the rest is a run of little-endian language ids.
    let languages: Vec<Language> = descriptor[2..len]
        .chunks(2)
        .map(|pair| crate::language::from_lang_id(u16::from_le_bytes([pair[0], pair[1]])))
        .collect();
    Ok(languages)
}
pub fn read_string_descriptor_ascii(&self, index: u8) -> crate::Result<String> {
let mut buf = Vec::<u8>::with_capacity(255);
let ptr = buf.as_mut_ptr() as *mut c_uchar;
let capacity = buf.capacity() as i32;
let res = unsafe {
libusb_get_string_descriptor_ascii(self.handle.as_ptr(), index, ptr, capacity)
};
if res < 0 {
return Err(error::from_libusb(res));
}
unsafe {
buf.set_len(res as usize);
}
String::from_utf8(buf).map_err(|_| Error::Other)
}
pub fn read_string_descriptor(
&self,
language: Language,
index: u8,
timeout: Duration,
) -> crate::Result<String> {
let mut buf = [0u8; 255];
let len = self.read_control(
request_type(Direction::In, RequestType::Standard, Recipient::Device),
LIBUSB_REQUEST_GET_DESCRIPTOR,
u16::from(LIBUSB_DT_STRING) << 8 | u16::from(index),
language.lang_id(),
&mut buf,
timeout,
)?;
if len < 2 || buf[0] != len as u8 || len & 0x01 != 0 {
return Err(Error::BadDescriptor);
}
if len == 2 {
return Ok(String::new());
}
let utf16: Vec<u16> = buf[..len]
.chunks(2)
.skip(1)
.map(|chunk| u16::from(chunk[0]) | u16::from(chunk[1]) << 8)
.collect();
String::from_utf16(&utf16).map_err(|_| Error::Other)
}
pub fn read_manufacturer_string_ascii(
&self,
device: &DeviceDescriptor,
) -> crate::Result<String> {
match device.manufacturer_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor_ascii(n),
}
}
pub fn read_manufacturer_string(
&self,
language: Language,
device: &DeviceDescriptor,
timeout: Duration,
) -> crate::Result<String> {
match device.manufacturer_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor(language, n, timeout),
}
}
pub fn read_product_string_ascii(&self, device: &DeviceDescriptor) -> crate::Result<String> {
match device.product_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor_ascii(n),
}
}
pub fn read_product_string(
&self,
language: Language,
device: &DeviceDescriptor,
timeout: Duration,
) -> crate::Result<String> {
match device.product_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor(language, n, timeout),
}
}
pub fn read_serial_number_string_ascii(
&self,
device: &DeviceDescriptor,
) -> crate::Result<String> {
match device.serial_number_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor_ascii(n),
}
}
pub fn read_serial_number_string(
&self,
language: Language,
device: &DeviceDescriptor,
timeout: Duration,
) -> crate::Result<String> {
match device.serial_number_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor(language, n, timeout),
}
}
pub fn read_configuration_string(
&self,
language: Language,
configuration: &ConfigDescriptor,
timeout: Duration,
) -> crate::Result<String> {
match configuration.description_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor(language, n, timeout),
}
}
pub fn read_interface_string(
&self,
language: Language,
interface: &InterfaceDescriptor,
timeout: Duration,
) -> crate::Result<String> {
match interface.description_string_index() {
None => Err(Error::InvalidParam),
Some(n) => self.read_string_descriptor(language, n, timeout),
}
}
}
#[cfg(test)]
mod tests {
    use super::ClaimedInterfaces;
    use std::u8;

    /// An empty set has size 0, contains no interface and yields nothing when iterated.
    #[test]
    fn claimed_interfaces_empty() {
        let empty = ClaimedInterfaces::new();
        assert_eq!(empty.size(), 0);
        for i in 0..=u8::MAX {
            assert!(!empty.contains(i), "empty set should not contain {}", i);
        }
        let mut iter = empty.iter();
        assert_eq!(iter.size_hint(), (0, Some(0)));
        assert_eq!(iter.next(), None);
    }

    /// A set with a single member reports exactly that member everywhere.
    #[test]
    fn claimed_interfaces_one_element() {
        let mut interfaces = ClaimedInterfaces::new();
        interfaces.insert(94);
        assert_eq!(interfaces.size(), 1);
        assert!(interfaces.contains(94));
        for i in 0..=u8::MAX {
            if i == 94 {
                continue;
            }
            assert!(
                !interfaces.contains(i),
                "interfaces should not contain {}",
                i
            );
        }
        let mut iter = interfaces.iter();
        assert_eq!(iter.size_hint(), (1, Some(1)));
        assert_eq!(iter.next(), Some(94));
        assert_eq!(iter.size_hint(), (0, Some(0)));
        assert_eq!(iter.next(), None);
    }

    /// Several members spread over both 128-bit words, including the extremes 0 and 255.
    #[test]
    fn claimed_interfaces_many_elements() {
        let mut interfaces = ClaimedInterfaces::new();
        let elements = vec![94, 0, 255, 17, 183, 6];
        for (index, &interface) in elements.iter().enumerate() {
            interfaces.insert(interface);
            assert_eq!(interfaces.size(), index + 1);
        }
        for &interface in elements.iter() {
            assert!(
                interfaces.contains(interface),
                "interfaces should contain {}",
                interface
            );
        }

        // `sort()` sorts in place and returns `()`. Comparing the results of two
        // `sort()` calls would compare `()` with `()` and always pass, so sort each
        // vector first and compare the vectors themselves.
        let mut contents = interfaces.iter().collect::<Vec<_>>();
        contents.sort();
        let mut expected = elements.clone();
        expected.sort();
        assert_eq!(contents, expected);

        // `size_hint` must count down exactly as elements are consumed.
        let mut iter = interfaces.iter();
        let mut read = 0;
        loop {
            assert!(
                read <= elements.len(),
                "read elements {} should not exceed elements size {}",
                read,
                elements.len()
            );
            let remaining = elements.len() - read;
            assert_eq!(iter.size_hint(), (remaining, Some(remaining)));
            match iter.next() {
                Some(_) => read += 1,
                None => break,
            }
        }
    }
}