1#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9 fcntl,
10 poll::{poll, PollFd, PollFlags},
11};
12use std::{
13 cell::RefCell,
14 collections::HashMap,
15 fs::{File, OpenOptions},
16 io::{Read, Write},
17 os::fd::{AsFd, AsRawFd},
18 path::{Path, PathBuf},
19 rc::Rc,
20 time::{Duration, Instant},
21};
22
23use thiserror::Error;
24use tpm2_protocol::{
25 basic::{TpmHandle, TpmUint32},
26 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
27 data::{
28 Tpm2bName, TpmAlgId, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
29 TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelect,
30 TpmsPcrSelection, TpmtPublic, TpmuCapabilities,
31 },
32 frame::{
33 tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
34 TpmContextSaveCommand, TpmFlushContextCommand, TpmFrame, TpmGetCapabilityCommand,
35 TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
36 },
37 TpmWriter,
38};
39use tracing::{debug, trace};
40
41#[derive(Debug, Error)]
43pub enum TpmDeviceError {
44 #[error("device is already borrowed")]
45 AlreadyBorrowed,
46 #[error("capability not found: {0}")]
47 CapabilityMissing(TpmCap),
48 #[error("operation interrupted by user")]
49 Interrupted,
50 #[error("invalid response")]
51 InvalidResponse,
52
53 #[error("I/O: {0}")]
54 Io(#[from] std::io::Error),
55
56 #[error("marshal: {0}")]
58 Marshal(tpm2_protocol::TpmProtocolError),
59
60 #[error("device not available")]
61 NotAvailable,
62 #[error("operation failed")]
63 OperationFailed,
64 #[error("PCR banks not available")]
65 PcrBanksNotAvailable,
66 #[error("PCR bank selection mismatch")]
67 PcrBankSelectionMismatch,
68
69 #[error("response mismatch: {0}")]
71 ResponseMismatch(TpmCc),
72
73 #[error("TPM command timed out")]
74 Timeout,
75 #[error("TPM return code: {0}")]
76 TpmRc(TpmRc),
77
78 #[error("unmarshal: {0}")]
80 Unmarshal(tpm2_protocol::TpmProtocolError),
81
82 #[error("unexpected EOF")]
83 UnexpectedEof,
84}
85
86impl From<TpmRc> for TpmDeviceError {
87 fn from(rc: TpmRc) -> Self {
88 Self::TpmRc(rc)
89 }
90}
91
92impl From<nix::Error> for TpmDeviceError {
93 fn from(err: nix::Error) -> Self {
94 Self::Io(std::io::Error::from_raw_os_error(err as i32))
95 }
96}
97
98pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
110where
111 F: FnOnce(&mut TpmDevice) -> Result<T, E>,
112 E: From<TpmDeviceError>,
113{
114 let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
115 let mut device_guard = device_rc
116 .try_borrow_mut()
117 .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
118 f(&mut device_guard)
119}
120
/// Configuration used to open a [`TpmDevice`].
pub struct TpmDeviceBuilder {
    /// Device node to open.
    path: PathBuf,
    /// How long a single command may wait for its complete response.
    timeout: Duration,
    /// Polled while waiting for a response; returning `true` aborts it.
    interrupted: Box<dyn Fn() -> bool + 'static>,
}
127
128impl Default for TpmDeviceBuilder {
129 fn default() -> Self {
130 Self {
131 path: PathBuf::from("/dev/tpmrm0"),
132 timeout: Duration::from_secs(120),
133 interrupted: Box::new(|| false),
134 }
135 }
136}
137
138impl TpmDeviceBuilder {
139 #[must_use]
141 pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
142 self.path = path.as_ref().to_path_buf();
143 self
144 }
145
146 #[must_use]
148 pub fn with_timeout(mut self, timeout: Duration) -> Self {
149 self.timeout = timeout;
150 self
151 }
152
153 #[must_use]
155 pub fn with_interrupted<F>(mut self, handler: F) -> Self
156 where
157 F: Fn() -> bool + 'static,
158 {
159 self.interrupted = Box::new(handler);
160 self
161 }
162
163 pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
170 let file = OpenOptions::new()
171 .read(true)
172 .write(true)
173 .open(&self.path)
174 .map_err(TpmDeviceError::Io)?;
175
176 let fd = file.as_raw_fd();
177 let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
178 let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
179 oflags.insert(fcntl::OFlag::O_NONBLOCK);
180 fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
181
182 Ok(TpmDevice {
183 file,
184 name_cache: HashMap::new(),
185 interrupted: self.interrupted,
186 timeout: self.timeout,
187 command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
188 response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
189 })
190 }
191}
192
193pub struct TpmDevice {
194 file: File,
195 name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
196 interrupted: Box<dyn Fn() -> bool>,
197 timeout: Duration,
198 command: Vec<u8>,
199 response: Vec<u8>,
200}
201
202impl std::fmt::Debug for TpmDevice {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("Device")
205 .field("file", &self.file)
206 .field("name_cache", &self.name_cache)
207 .field("timeout", &self.timeout)
208 .finish_non_exhaustive()
209 }
210}
211
212impl TpmDevice {
213 const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
214
215 #[must_use]
217 pub fn builder() -> TpmDeviceBuilder {
218 TpmDeviceBuilder::default()
219 }
220
221 fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
222 let fd = self.file.as_fd();
223 let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
224
225 let num_events = match poll(&mut fds, 100u16) {
226 Ok(num) => num,
227 Err(nix::Error::EINTR) => return Ok(0),
228 Err(e) => return Err(e.into()),
229 };
230
231 if num_events == 0 {
232 return Ok(0);
233 }
234
235 let revents = fds[0].revents().unwrap_or(PollFlags::empty());
236
237 if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
238 return Err(TpmDeviceError::UnexpectedEof);
239 }
240
241 if revents.contains(PollFlags::POLLIN) {
242 match self.file.read(buf) {
243 Ok(0) => Err(TpmDeviceError::UnexpectedEof),
244 Ok(n) => Ok(n),
245 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
246 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
247 Err(e) => Err(e.into()),
248 }
249 } else if revents.contains(PollFlags::POLLHUP) {
250 Err(TpmDeviceError::UnexpectedEof)
251 } else {
252 Ok(0)
253 }
254 }
255
256 pub fn transmit<C: TpmFrame>(
276 &mut self,
277 command: &C,
278 sessions: &[TpmsAuthCommand],
279 ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
280 self.prepare_command(command, sessions)?;
281 let cc = command.cc();
282
283 self.file.write_all(&self.command)?;
284 self.file.flush()?;
285
286 let start_time = Instant::now();
287 self.response.clear();
288 let mut total_size: Option<usize> = None;
289 let mut temp_buf = [0u8; 1024];
290
291 loop {
292 if (self.interrupted)() {
293 return Err(TpmDeviceError::Interrupted);
294 }
295 if start_time.elapsed() > self.timeout {
296 return Err(TpmDeviceError::Timeout);
297 }
298
299 let n = self.receive(&mut temp_buf)?;
300 if n > 0 {
301 self.response.extend_from_slice(&temp_buf[..n]);
302 }
303
304 if total_size.is_none() && self.response.len() >= 10 {
305 let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
306 return Err(TpmDeviceError::InvalidResponse);
307 };
308 let size = u32::from_be_bytes(size_bytes) as usize;
309 if !(10..={ TPM_MAX_COMMAND_SIZE }).contains(&size) {
310 return Err(TpmDeviceError::InvalidResponse);
311 }
312 total_size = Some(size);
313 }
314
315 if let Some(size) = total_size {
316 if self.response.len() == size {
317 break;
318 }
319 if self.response.len() > size {
320 return Err(TpmDeviceError::InvalidResponse);
321 }
322 }
323 }
324
325 let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
326 trace!("{} R: {}", cc, hex::encode(&self.response));
327 Ok(result??)
328 }
329
330 fn prepare_command<C: TpmFrame>(
331 &mut self,
332 command: &C,
333 sessions: &[TpmsAuthCommand],
334 ) -> Result<(), TpmDeviceError> {
335 let cc = command.cc();
336 let tag = if sessions.is_empty() {
337 TpmSt::NoSessions
338 } else {
339 TpmSt::Sessions
340 };
341
342 self.command.resize(TPM_MAX_COMMAND_SIZE, 0);
343
344 let len = {
345 let mut writer = TpmWriter::new(&mut self.command);
346 tpm_marshal_command(command, tag, sessions, &mut writer)
347 .map_err(TpmDeviceError::Marshal)?;
348 writer.len()
349 };
350 self.command.truncate(len);
351
352 trace!("{} C: {}", cc, hex::encode(&self.command));
353 Ok(())
354 }
355
356 fn get_capability<T, F, N>(
365 &mut self,
366 cap: TpmCap,
367 property_start: u32,
368 count: u32,
369 mut extract: F,
370 next_prop: N,
371 ) -> Result<Vec<T>, TpmDeviceError>
372 where
373 T: Copy,
374 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
375 N: Fn(&T) -> u32,
376 {
377 let mut results = Vec::new();
378 let mut prop = property_start;
379 loop {
380 let (more_data, cap_data) =
381 self.get_capability_page(cap, TpmUint32(prop), TpmUint32(count))?;
382 let items: &[T] = extract(&cap_data.data)?;
383 results.extend_from_slice(items);
384
385 if more_data {
386 if let Some(last) = items.last() {
387 prop = next_prop(last);
388 } else {
389 break;
390 }
391 } else {
392 break;
393 }
394 }
395 Ok(results)
396 }
397
398 pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
409 self.get_capability(
410 TpmCap::Algs,
411 0,
412 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
413 |caps| match caps {
414 TpmuCapabilities::Algs(algs) => Ok(algs),
415 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
416 },
417 |last| last.alg as u32 + 1,
418 )
419 }
420
421 pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
432 self.get_capability(
433 TpmCap::Handles,
434 (class as u32) << 24,
435 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
436 |caps| match caps {
437 TpmuCapabilities::Handles(handles) => Ok(handles),
438 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
439 },
440 |last| last.value() + 1,
441 )
442 .map(|handles| handles.into_iter().collect())
443 }
444
445 pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
456 self.get_capability(
457 TpmCap::EccCurves,
458 0,
459 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
460 |caps| match caps {
461 TpmuCapabilities::EccCurves(curves) => Ok(curves),
462 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
463 },
464 |last| *last as u32 + 1,
465 )
466 }
467
468 pub fn fetch_pcr_bank_list(
483 &mut self,
484 ) -> Result<(Vec<TpmAlgId>, TpmsPcrSelect), TpmDeviceError> {
485 let pcrs: Vec<TpmsPcrSelection> = self.get_capability(
486 TpmCap::Pcrs,
487 0,
488 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
489 |caps| match caps {
490 TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
491 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
492 },
493 |last| last.hash as u32 + 1,
494 )?;
495
496 if pcrs.is_empty() {
497 return Err(TpmDeviceError::PcrBanksNotAvailable);
498 }
499
500 let mut common_select: Option<TpmsPcrSelect> = None;
501 let mut algs = Vec::with_capacity(pcrs.len());
502
503 for bank in pcrs {
504 if bank.pcr_select.iter().all(|&b| b == 0) {
505 debug!(
506 "skipping unallocated bank {:?} (mask: {})",
507 bank.hash,
508 hex::encode(&*bank.pcr_select)
509 );
510 continue;
511 }
512
513 if let Some(ref select) = common_select {
514 if bank.pcr_select != *select {
515 return Err(TpmDeviceError::PcrBankSelectionMismatch);
516 }
517 } else {
518 common_select = Some(bank.pcr_select);
519 }
520 algs.push(bank.hash);
521 }
522
523 let select = common_select.ok_or(TpmDeviceError::PcrBanksNotAvailable)?;
524
525 algs.sort();
526 Ok((algs, select))
527 }
528
529 fn get_capability_page(
539 &mut self,
540 cap: TpmCap,
541 property: TpmUint32,
542 property_count: TpmUint32,
543 ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
544 let cmd = TpmGetCapabilityCommand {
545 cap,
546 property,
547 property_count,
548 handles: [],
549 };
550
551 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
552 let TpmGetCapabilityResponse {
553 more_data,
554 capability_data,
555 handles: [],
556 } = resp
557 .GetCapability()
558 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
559
560 Ok((more_data.into(), capability_data))
561 }
562
563 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<TpmUint32, TpmDeviceError> {
572 let (_, cap_data) = self.get_capability_page(
573 TpmCap::TpmProperties,
574 TpmUint32(property as u32),
575 TpmUint32(1),
576 )?;
577
578 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
579 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
580 };
581
582 let Some(prop) = props.iter().find(|prop| prop.property == property) else {
583 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
584 };
585
586 Ok(prop.value)
587 }
588
589 pub fn read_public(
598 &mut self,
599 handle: TpmHandle,
600 ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
601 if let Some(cached) = self.name_cache.get(&handle.0) {
602 return Ok(cached.clone());
603 }
604
605 let cmd = TpmReadPublicCommand { handles: [handle] };
606 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
607
608 let read_public_resp = resp
609 .ReadPublic()
610 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
611
612 let public = read_public_resp.out_public.inner;
613 let name = read_public_resp.name;
614
615 self.name_cache.insert(handle.0, (public.clone(), name));
616 Ok((public, name))
617 }
618
619 pub fn find_persistent(
631 &mut self,
632 target_name: &Tpm2bName,
633 ) -> Result<Option<TpmHandle>, TpmDeviceError> {
634 for handle in self.fetch_handles(TpmHt::Persistent)? {
635 match self.read_public(handle) {
636 Ok((_, name)) => {
637 if name == *target_name {
638 return Ok(Some(handle));
639 }
640 }
641 Err(TpmDeviceError::TpmRc(rc)) => {
642 let base = rc.base();
643 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
644 continue;
645 }
646 return Err(TpmDeviceError::TpmRc(rc));
647 }
648 Err(e) => return Err(e),
649 }
650 }
651 Ok(None)
652 }
653
654 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
663 let cmd = TpmContextSaveCommand {
664 handles: [save_handle],
665 };
666 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
667 let save_resp = resp
668 .ContextSave()
669 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
670 Ok(save_resp.context)
671 }
672
673 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
682 let cmd = TpmContextLoadCommand {
683 context,
684 handles: [],
685 };
686 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
687 let resp_inner = resp
688 .ContextLoad()
689 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
690 Ok(resp_inner.handles[0])
691 }
692
693 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
701 self.name_cache.remove(&handle.0);
702 let cmd = TpmFlushContextCommand {
703 flush_handle: handle,
704 handles: [],
705 };
706 self.transmit(&cmd, Self::NO_SESSIONS)?;
707 Ok(())
708 }
709
710 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
722 match self.load_context(context) {
723 Ok(handle) => self.flush_context(handle),
724 Err(TpmDeviceError::TpmRc(rc)) => {
725 let base = rc.base();
726 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
727 Ok(())
728 } else {
729 Err(TpmDeviceError::TpmRc(rc))
730 }
731 }
732 Err(e) => Err(e),
733 }
734 }
735
736 pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
748 match self.load_context(context) {
749 Ok(handle) => match self.flush_context(handle) {
750 Ok(()) => Ok(true),
751 Err(e) => Err(e),
752 },
753 Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
754 Err(e) => Err(e),
755 }
756 }
757}