1use core::convert::TryFrom;
16use core::convert::TryInto;
17use core::sync::atomic::Ordering;
18use ctaphid_dispatch::{app::Command, Requester};
22use heapless_bytes::Bytes;
23use ref_swap::OptionRefSwap;
24use trussed_core::InterruptFlag;
25use usb_device::{
27 bus::UsbBus,
28 endpoint::{EndpointAddress, EndpointIn, EndpointOut},
29 UsbError,
30 };
32
33use crate::{constants::PACKET_SIZE, types::KeepaliveStatus};
34
/// Error byte placed in the payload of a `Command::Error` response.
///
/// The discriminants ARE the on-the-wire codes (`#[repr(u8)]`), so both
/// `error as u8` and `u8::from(error)` yield the byte to transmit. Without
/// explicit discriminants an `as u8` cast would silently produce the variant's
/// ordinal (e.g. `ChannelBusy` → 0) instead of its protocol code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
enum AuthenticatorError {
    ChannelBusy = 0x06,
    InvalidChannel = 0x0B,
    InvalidCommand = 0x01,
    InvalidLength = 0x03,
    InvalidSeq = 0x04,
    Timeout = 0x05,
}

impl From<AuthenticatorError> for u8 {
    /// Returns the protocol error code for `error` (its `repr(u8)` discriminant).
    fn from(error: AuthenticatorError) -> Self {
        error as u8
    }
}
58
59#[derive(Copy, Clone, Debug, Eq, PartialEq)]
61pub struct Request {
62 channel: u32,
63 command: Command,
64 length: u16,
65 timestamp: u32,
66}
67
68#[derive(Copy, Clone, Debug, Eq, PartialEq)]
70pub struct Response {
71 channel: u32,
72 command: Command,
73 length: u16,
74}
75
76impl Response {
77 pub fn from_request_and_size(request: Request, size: usize) -> Self {
78 Self {
79 channel: request.channel,
80 command: request.command,
81 length: size as u16,
82 }
83 }
84
85 pub fn error_from_request(request: Request) -> Self {
86 Self::error_on_channel(request.channel)
87 }
88
89 pub fn error_on_channel(channel: u32) -> Self {
90 Self {
91 channel,
92 command: ctaphid_dispatch::app::Command::Error,
93 length: 1,
94 }
95 }
96}
97
98#[derive(Copy, Clone, Debug, Eq, PartialEq)]
99pub struct MessageState {
100 next_sequence: u8,
102 transmitted: usize,
104}
105
106impl Default for MessageState {
107 fn default() -> Self {
108 Self {
109 next_sequence: 0,
110 transmitted: PACKET_SIZE - 7,
111 }
112 }
113}
114
115impl MessageState {
116 pub fn absorb_packet(&mut self) {
118 self.next_sequence += 1;
119 self.transmitted += PACKET_SIZE - 5;
120 }
121}
122
123#[derive(Clone, Debug, Eq, PartialEq)]
124#[allow(unused)]
125pub enum State {
126 Idle,
127
128 Receiving((Request, MessageState)),
130
131 WaitingOnAuthenticator(Request),
138
139 WaitingToSend(Response),
140
141 Sending((Response, MessageState)),
142}
143
144pub struct Pipe<'alloc, 'pipe, 'interrupt, Bus: UsbBus, const N: usize> {
145 read_endpoint: EndpointOut<'alloc, Bus>,
146 write_endpoint: EndpointIn<'alloc, Bus>,
147 state: State,
148
149 interchange: Requester<'pipe, N>,
150 interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>,
151
152 buffer: [u8; N],
154
155 last_channel: u32,
158
159 pub(crate) implements: u8,
161
162 pub(crate) last_milliseconds: u32,
164
165 started_processing: bool,
167
168 needs_keepalive: bool,
169
170 pub(crate) version: crate::Version,
171}
172
173impl<'alloc, 'pipe, Bus: UsbBus, const N: usize> Pipe<'alloc, 'pipe, '_, Bus, N> {
174 pub(crate) fn new(
175 read_endpoint: EndpointOut<'alloc, Bus>,
176 write_endpoint: EndpointIn<'alloc, Bus>,
177 interchange: Requester<'pipe, N>,
178 initial_milliseconds: u32,
179 ) -> Self {
180 Self {
181 read_endpoint,
182 write_endpoint,
183 state: State::Idle,
184 interchange,
185 buffer: [0u8; N],
186 last_channel: 0,
187 interrupt: None,
188 implements: 0x80,
190 last_milliseconds: initial_milliseconds,
191 started_processing: false,
192 needs_keepalive: false,
193 version: Default::default(),
194 }
195 }
196}
197
198impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus, const N: usize>
199 Pipe<'alloc, 'pipe, 'interrupt, Bus, N>
200{
201 pub(crate) fn with_interrupt(
206 read_endpoint: EndpointOut<'alloc, Bus>,
207 write_endpoint: EndpointIn<'alloc, Bus>,
208 interchange: Requester<'pipe, N>,
209 interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>,
210 initial_milliseconds: u32,
211 ) -> Self {
212 Self {
213 read_endpoint,
214 write_endpoint,
215 state: State::Idle,
216 interchange,
217 buffer: [0u8; N],
218 last_channel: 0,
219 interrupt,
220 implements: 0x80,
222 last_milliseconds: initial_milliseconds,
223 started_processing: false,
224 needs_keepalive: false,
225 version: Default::default(),
226 }
227 }
228
229 pub(crate) fn set_version(&mut self, version: crate::Version) {
230 self.version = version;
231 }
232
233 pub fn read_address(&self) -> EndpointAddress {
234 self.read_endpoint.address()
235 }
236
237 pub fn write_address(&self) -> EndpointAddress {
238 self.write_endpoint.address()
239 }
240
241 pub(crate) fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> {
243 &self.read_endpoint
244 }
245
246 pub(crate) fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> {
248 &self.write_endpoint
249 }
250
251 fn cancel_ongoing_activity(&mut self) {
252 if matches!(self.state, State::WaitingOnAuthenticator(_)) {
253 info_now!("Interrupting request");
254 if let Some(Some(i)) = self.interrupt.map(|i| i.load(Ordering::Relaxed)) {
255 info_now!("Loaded some interrupter");
256 i.interrupt();
257 }
258 }
259 }
260
261 pub(crate) fn read_and_handle_packet(&mut self) {
266 let mut packet = [0u8; PACKET_SIZE];
268 match self.read_endpoint.read(&mut packet) {
269 Ok(PACKET_SIZE) => {}
270 Ok(_size) => {
271 info!("error unexpected size {}", _size);
278 return;
279 }
280 Err(_error) => {
285 info!("error no {}", _error as i32);
286 return;
287 }
288 };
289 info!(">> ");
290 info!("{}", hex_str!(&packet[..16]));
291
292 let channel = u32::from_be_bytes(packet[..4].try_into().unwrap());
294 let is_initialization = (packet[4] >> 7) != 0;
297 if is_initialization {
300 info!("init");
302
303 let command_number = packet[4] & !0x80;
304 let command = match Command::try_from(command_number) {
307 Ok(command) => command,
308 Err(_) => {
310 info!("Received invalid command.");
311 self.start_sending_error_on_channel(
312 channel,
313 AuthenticatorError::InvalidCommand,
314 );
315 return;
316 }
317 };
318
319 let length = u16::from_be_bytes(packet[5..][..2].try_into().unwrap());
321
322 let timestamp = self.last_milliseconds;
323 let current_request = Request {
324 channel,
325 command,
326 length,
327 timestamp,
328 };
329
330 if !(self.state == State::Idle) {
331 let request = match self.state {
332 State::WaitingOnAuthenticator(request) => request,
333 State::Receiving((request, _message_state)) => request,
334 _ => {
335 info_now!("Ignoring transaction as we're already transmitting.");
336 return;
337 }
338 };
339 if packet[4] == 0x86 {
340 info_now!("Resyncing!");
341 self.cancel_ongoing_activity();
342 } else {
343 if channel == request.channel {
344 if command == Command::Cancel {
345 info_now!("Cancelling");
346 self.cancel_ongoing_activity();
347 } else {
348 info_now!("Expected seq, {:?}", request.command);
349 self.start_sending_error(request, AuthenticatorError::InvalidSeq);
350 }
351 } else {
352 info_now!("busy.");
353 self.send_error_now(current_request, AuthenticatorError::ChannelBusy);
354 }
355
356 return;
357 }
358 }
359
360 if length > N as u16 {
361 info!("Error message too big.");
362 self.send_error_now(current_request, AuthenticatorError::InvalidLength);
363 return;
364 }
365
366 if length > PACKET_SIZE as u16 - 7 {
367 self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]);
370 self.state = State::Receiving((current_request, { MessageState::default() }));
371 } else {
373 self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]);
375 self.dispatch_request(current_request);
376 }
377 } else {
378 match self.state {
380 State::Receiving((request, mut message_state)) => {
381 let sequence = packet[4];
382 if sequence != message_state.next_sequence {
384 info!("Error invalid cont pkt");
388 self.start_sending_error(request, AuthenticatorError::InvalidSeq);
389 return;
390 }
391 if channel != request.channel {
392 info!("Ignore invalid channel");
396 return;
397 }
398
399 let payload_length = request.length as usize;
400 if message_state.transmitted + (PACKET_SIZE - 5) < payload_length {
401 self.buffer[message_state.transmitted..][..PACKET_SIZE - 5]
405 .copy_from_slice(&packet[5..]);
406 message_state.absorb_packet();
407 self.state = State::Receiving((request, message_state));
408 } else {
410 let missing = request.length as usize - message_state.transmitted;
411 self.buffer[message_state.transmitted..payload_length]
412 .copy_from_slice(&packet[5..][..missing]);
413 self.dispatch_request(request);
414 }
415 }
416 _ => {
417 info!("Ignore unexpected cont pkt");
419 }
420 }
421 }
422 }
423
424 pub fn check_timeout(&mut self, milliseconds: u32) {
425 let last = self.last_milliseconds;
428 self.last_milliseconds = milliseconds;
429 if let State::Receiving((request, _message_state)) = &mut self.state {
430 if (milliseconds - last) > 200 {
431 debug!(
435 "lapse in hid check.. {} {} {}",
436 request.timestamp, milliseconds, last
437 );
438 request.timestamp = milliseconds;
439 }
440 else if (milliseconds > request.timestamp && (milliseconds - request.timestamp) > 550)
442 || (milliseconds < request.timestamp && milliseconds > 550)
443 {
444 debug!(
445 "Channel timeout. {}, {}, {}",
446 request.timestamp, milliseconds, last
447 );
448 let req = *request;
449 self.start_sending_error(req, AuthenticatorError::Timeout);
450 }
451 }
452 }
453
454 fn dispatch_request(&mut self, request: Request) {
455 info!("Got request: {:?}", request.command);
456 match request.command {
457 Command::Init => {}
458 _ => {
459 if request.channel == 0xffffffff {
460 self.start_sending_error(request, AuthenticatorError::InvalidChannel);
461 return;
462 }
463 }
464 }
465 match request.command {
467 Command::Init => {
468 match request.channel {
471 0 => {
472 self.start_sending_error(request, AuthenticatorError::InvalidChannel);
474 }
475
476 cid => {
478 if request.length != 8 {
479 info!("Invalid length for init. ignore.");
481 } else {
482 self.last_channel += 1;
483 let _nonce = &self.buffer[..8];
486 let response = Response {
487 channel: cid,
488 command: request.command,
489 length: 17,
490 };
491
492 self.buffer[8..12].copy_from_slice(&self.last_channel.to_be_bytes());
493 self.buffer[12] = 2;
495 self.buffer[13] = self.version.major;
497 self.buffer[14] = self.version.minor;
499 self.buffer[15] = self.version.build;
501 self.buffer[16] = self.implements;
507 self.start_sending(response);
508 }
509 }
510 }
511 }
512
513 Command::Ping => {
514 let response = Response::from_request_and_size(request, request.length as usize);
515 self.start_sending(response);
516 }
517
518 Command::Cancel => {
519 info!("CTAPHID_CANCEL");
520 self.cancel_ongoing_activity();
521 }
522
523 _ => {
524 self.needs_keepalive = request.command == Command::Cbor;
525 if self.interchange.state() == interchange::State::Responded {
526 info!("dumping stale response");
527 self.interchange.take_response();
528 }
529 match self.interchange.request((
530 request.command,
531 Bytes::try_from(&self.buffer[..request.length as usize]).unwrap(),
532 )) {
533 Ok(_) => {
534 self.state = State::WaitingOnAuthenticator(request);
535 self.started_processing = true;
536 }
537 Err(_) => {
538 info_now!("STATE: {:?}", self.interchange.state());
540 info!("can't handle more than one authenticator request at a time.");
541 self.send_error_now(request, AuthenticatorError::ChannelBusy);
542 }
543 }
544 }
545 }
546 }
547
548 pub fn did_start_processing(&mut self) -> bool {
549 if self.started_processing {
550 self.started_processing = false;
551 true
552 } else {
553 false
554 }
555 }
556
557 pub fn send_keepalive(&mut self, is_waiting_for_user_presence: bool) -> bool {
558 if let State::WaitingOnAuthenticator(request) = &self.state {
559 if !self.needs_keepalive {
560 info!("cmd does not need keepalive messages");
562 false
563 } else {
564 info!("keepalive");
565
566 let mut packet = [0u8; PACKET_SIZE];
567
568 packet[..4].copy_from_slice(&request.channel.to_be_bytes());
569 packet[4] = 0x80 | 0x3B;
570 packet[5..7].copy_from_slice(&1u16.to_be_bytes());
571
572 if is_waiting_for_user_presence {
573 packet[7] = KeepaliveStatus::UpNeeded as u8;
574 } else {
575 packet[7] = KeepaliveStatus::Processing as u8;
576 }
577
578 self.write_endpoint.write(&packet).ok();
579
580 true
581 }
582 } else {
583 info!("keepalive done");
584 false
585 }
586 }
587
588 #[inline(never)]
589 pub fn handle_response(&mut self) {
590 if let State::WaitingOnAuthenticator(request) = self.state {
591 if let Ok(response) = self.interchange.response() {
592 match &response.0 {
593 Err(ctaphid_dispatch::app::Error::InvalidCommand) => {
594 info!("Got waiting reply from authenticator??");
595 self.start_sending_error(request, AuthenticatorError::InvalidCommand);
596 }
597 Err(ctaphid_dispatch::app::Error::InvalidLength) => {
598 info!("Error, payload needed app command.");
599 self.start_sending_error(request, AuthenticatorError::InvalidLength);
600 }
601 Err(ctaphid_dispatch::app::Error::NoResponse) => {
602 info!("Got waiting noresponse from authenticator??");
603 }
604
605 Ok(message) => {
606 if message.len() > self.buffer.len() {
607 error!(
608 "Message is longer than buffer ({} > {})",
609 message.len(),
610 self.buffer.len(),
611 );
612 self.start_sending_error(request, AuthenticatorError::InvalidLength);
613 } else {
614 info!(
615 "Got {} bytes response from authenticator, starting send",
616 message.len()
617 );
618 let response = Response::from_request_and_size(request, message.len());
619 self.buffer[..message.len()].copy_from_slice(message);
620 self.start_sending(response);
621 }
622 }
623 }
624 }
625 }
626 }
627
628 fn start_sending(&mut self, response: Response) {
629 self.state = State::WaitingToSend(response);
630 self.maybe_write_packet();
631 }
632
633 fn start_sending_error(&mut self, request: Request, error: AuthenticatorError) {
634 self.start_sending_error_on_channel(request.channel, error);
635 }
636
637 fn start_sending_error_on_channel(&mut self, channel: u32, error: AuthenticatorError) {
638 self.buffer[0] = error.into();
639 let response = Response::error_on_channel(channel);
640 self.start_sending(response);
641 }
642
643 fn send_error_now(&mut self, request: Request, error: AuthenticatorError) {
644 let last_state = core::mem::replace(&mut self.state, State::Idle);
645 let last_first_byte = self.buffer[0];
646
647 self.buffer[0] = error as u8;
648 let response = Response::error_from_request(request);
649 self.start_sending(response);
650 self.maybe_write_packet();
651
652 self.state = last_state;
653 self.buffer[0] = last_first_byte;
654 }
655
656 #[inline(never)]
658 pub(crate) fn maybe_write_packet(&mut self) {
659 match self.state {
660 State::WaitingToSend(response) => {
661 let mut packet = [0u8; PACKET_SIZE];
663 packet[..4].copy_from_slice(&response.channel.to_be_bytes());
664 packet[4] = response.command.into_u8() | 0x80;
666 packet[5..7].copy_from_slice(&response.length.to_be_bytes());
667
668 let fits_in_one_packet = 7 + response.length as usize <= PACKET_SIZE;
669 if fits_in_one_packet {
670 packet[7..][..response.length as usize]
671 .copy_from_slice(&self.buffer[..response.length as usize]);
672 self.state = State::Idle;
673 } else {
674 packet[7..].copy_from_slice(&self.buffer[..PACKET_SIZE - 7]);
675 }
676
677 let result = self.write_endpoint.write(&packet);
681
682 match result {
683 Err(UsbError::WouldBlock) => {
684 info!("hid usb WouldBlock");
687 }
688 Err(_) => {
689 panic!("unexpected error writing packet!");
691 }
692 Ok(PACKET_SIZE) => {
693 if fits_in_one_packet {
695 self.state = State::Idle;
696 } else {
699 self.state = State::Sending((response, MessageState::default()));
700 }
705 }
706 Ok(_) => {
707 panic!("unexpected size writing packet!");
709 }
710 };
711 }
712
713 State::Sending((response, mut message_state)) => {
714 let mut packet = [0u8; PACKET_SIZE];
716 packet[..4].copy_from_slice(&response.channel.to_be_bytes());
717 packet[4] = message_state.next_sequence;
718
719 let sent = message_state.transmitted;
720 let remaining = response.length as usize - sent;
721 let last_packet = 5 + remaining <= PACKET_SIZE;
722 if last_packet {
723 packet[5..][..remaining]
724 .copy_from_slice(&self.buffer[message_state.transmitted..][..remaining]);
725 } else {
726 packet[5..].copy_from_slice(
727 &self.buffer[message_state.transmitted..][..PACKET_SIZE - 5],
728 );
729 }
730
731 let result = self.write_endpoint.write(&packet);
735
736 match result {
737 Err(UsbError::WouldBlock) => {
738 }
743 Err(_) => {
744 panic!("unexpected error writing packet!");
746 }
747 Ok(PACKET_SIZE) => {
748 if last_packet {
750 self.state = State::Idle;
751 } else {
753 message_state.absorb_packet();
754 self.state = State::Sending((response, message_state));
758 }
759 }
760 Ok(_) => {
761 debug!("short write");
762 panic!("unexpected size writing packet!");
763 }
764 };
765 }
766
767 _ => {}
769 }
770 }
771}