1#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9 fcntl,
10 poll::{poll, PollFd, PollFlags},
11};
12use std::{
13 cell::RefCell,
14 collections::HashMap,
15 fs::{File, OpenOptions},
16 io::{Read, Write},
17 os::fd::{AsFd, AsRawFd},
18 path::{Path, PathBuf},
19 rc::Rc,
20 time::{Duration, Instant},
21};
22
23use thiserror::Error;
24use tpm2_protocol::{
25 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
26 data::{
27 Tpm2bName, TpmAlgId, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
28 TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
29 TpmtPublic, TpmuCapabilities,
30 },
31 frame::{
32 tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
33 TpmContextSaveCommand, TpmFlushContextCommand, TpmFrame, TpmGetCapabilityCommand,
34 TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
35 },
36 TpmHandle, TpmWriter,
37};
38use tracing::trace;
39
40#[derive(Debug, Error)]
42pub enum TpmDeviceError {
43 #[error("device is already borrowed")]
44 AlreadyBorrowed,
45 #[error("capability not found: {0}")]
46 CapabilityMissing(TpmCap),
47 #[error("operation interrupted by user")]
48 Interrupted,
49 #[error("invalid response")]
50 InvalidResponse,
51
52 #[error("I/O: {0}")]
53 Io(#[from] std::io::Error),
54
55 #[error("marshal: {0}")]
57 Marshal(tpm2_protocol::TpmProtocolError),
58
59 #[error("device not available")]
60 NotAvailable,
61 #[error("operation failed")]
62 OperationFailed,
63 #[error("PCR banks not available")]
64 PcrBanksNotAvailable,
65 #[error("PCR bank size mismatch")]
66 PcrBankSizeMismatch,
67
68 #[error("response mismatch: {0}")]
70 ResponseMismatch(TpmCc),
71
72 #[error("TPM command timed out")]
73 Timeout,
74 #[error("TPM return code: {0}")]
75 TpmRc(TpmRc),
76
77 #[error("unmarshal: {0}")]
79 Unmarshal(tpm2_protocol::TpmProtocolError),
80
81 #[error("unexpected EOF")]
82 UnexpectedEof,
83}
84
85impl From<TpmRc> for TpmDeviceError {
86 fn from(rc: TpmRc) -> Self {
87 Self::TpmRc(rc)
88 }
89}
90
91impl From<nix::Error> for TpmDeviceError {
92 fn from(err: nix::Error) -> Self {
93 Self::Io(std::io::Error::from_raw_os_error(err as i32))
94 }
95}
96
97pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
109where
110 F: FnOnce(&mut TpmDevice) -> Result<T, E>,
111 E: From<TpmDeviceError>,
112{
113 let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
114 let mut device_guard = device_rc
115 .try_borrow_mut()
116 .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
117 f(&mut device_guard)
118}
119
/// Configures and opens a [`TpmDevice`].
pub struct TpmDeviceBuilder {
    /// Path of the TPM character device node to open.
    path: PathBuf,
    /// Upper bound on how long a command waits for its complete response.
    timeout: Duration,
    /// Callback consulted while waiting for a response; `true` aborts the wait.
    interrupted: Box<dyn Fn() -> bool>,
}
126
127impl Default for TpmDeviceBuilder {
128 fn default() -> Self {
129 Self {
130 path: PathBuf::from("/dev/tpmrm0"),
131 timeout: Duration::from_secs(120),
132 interrupted: Box::new(|| false),
133 }
134 }
135}
136
137impl TpmDeviceBuilder {
138 #[must_use]
140 pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
141 self.path = path.as_ref().to_path_buf();
142 self
143 }
144
145 #[must_use]
147 pub fn with_timeout(mut self, timeout: Duration) -> Self {
148 self.timeout = timeout;
149 self
150 }
151
152 #[must_use]
154 pub fn with_interrupted<F>(mut self, handler: F) -> Self
155 where
156 F: Fn() -> bool + 'static,
157 {
158 self.interrupted = Box::new(handler);
159 self
160 }
161
162 pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
169 let file = OpenOptions::new()
170 .read(true)
171 .write(true)
172 .open(&self.path)
173 .map_err(TpmDeviceError::Io)?;
174
175 let fd = file.as_raw_fd();
176 let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
177 let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
178 oflags.insert(fcntl::OFlag::O_NONBLOCK);
179 fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
180
181 Ok(TpmDevice {
182 file,
183 name_cache: HashMap::new(),
184 interrupted: self.interrupted,
185 timeout: self.timeout,
186 command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
187 response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
188 })
189 }
190}
191
192pub struct TpmDevice {
193 file: File,
194 name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
195 interrupted: Box<dyn Fn() -> bool>,
196 timeout: Duration,
197 command: Vec<u8>,
198 response: Vec<u8>,
199}
200
201impl std::fmt::Debug for TpmDevice {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 f.debug_struct("Device")
204 .field("file", &self.file)
205 .field("name_cache", &self.name_cache)
206 .field("timeout", &self.timeout)
207 .finish_non_exhaustive()
208 }
209}
210
211impl TpmDevice {
212 const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
213
214 #[must_use]
216 pub fn builder() -> TpmDeviceBuilder {
217 TpmDeviceBuilder::default()
218 }
219
220 fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
221 let fd = self.file.as_fd();
222 let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
223
224 let num_events = match poll(&mut fds, 100u16) {
225 Ok(num) => num,
226 Err(nix::Error::EINTR) => return Ok(0),
227 Err(e) => return Err(e.into()),
228 };
229
230 if num_events == 0 {
231 return Ok(0);
232 }
233
234 let revents = fds[0].revents().unwrap_or(PollFlags::empty());
235
236 if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
237 return Err(TpmDeviceError::UnexpectedEof);
238 }
239
240 if revents.contains(PollFlags::POLLIN) {
241 match self.file.read(buf) {
242 Ok(0) => Err(TpmDeviceError::UnexpectedEof),
243 Ok(n) => Ok(n),
244 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
245 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
246 Err(e) => Err(e.into()),
247 }
248 } else if revents.contains(PollFlags::POLLHUP) {
249 Err(TpmDeviceError::UnexpectedEof)
250 } else {
251 Ok(0)
252 }
253 }
254
255 pub fn transmit<C: TpmFrame>(
275 &mut self,
276 command: &C,
277 sessions: &[TpmsAuthCommand],
278 ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
279 self.prepare_command(command, sessions)?;
280 let cc = command.cc();
281
282 self.file.write_all(&self.command)?;
283 self.file.flush()?;
284
285 let start_time = Instant::now();
286 self.response.clear();
287 let mut total_size: Option<usize> = None;
288 let mut temp_buf = [0u8; 1024];
289
290 loop {
291 if (self.interrupted)() {
292 return Err(TpmDeviceError::Interrupted);
293 }
294 if start_time.elapsed() > self.timeout {
295 return Err(TpmDeviceError::Timeout);
296 }
297
298 let n = self.receive(&mut temp_buf)?;
299 if n > 0 {
300 self.response.extend_from_slice(&temp_buf[..n]);
301 }
302
303 if total_size.is_none() && self.response.len() >= 10 {
304 let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
305 return Err(TpmDeviceError::InvalidResponse);
306 };
307 let size = u32::from_be_bytes(size_bytes) as usize;
308 if !(10..=TPM_MAX_COMMAND_SIZE as usize).contains(&size) {
309 return Err(TpmDeviceError::InvalidResponse);
310 }
311 total_size = Some(size);
312 }
313
314 if let Some(size) = total_size {
315 if self.response.len() == size {
316 break;
317 }
318 if self.response.len() > size {
319 return Err(TpmDeviceError::InvalidResponse);
320 }
321 }
322 }
323
324 let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
325 trace!("{} R: {}", cc, hex::encode(&self.response));
326 Ok(result??)
327 }
328
329 fn prepare_command<C: TpmFrame>(
330 &mut self,
331 command: &C,
332 sessions: &[TpmsAuthCommand],
333 ) -> Result<(), TpmDeviceError> {
334 let cc = command.cc();
335 let tag = if sessions.is_empty() {
336 TpmSt::NoSessions
337 } else {
338 TpmSt::Sessions
339 };
340
341 self.command.resize(TPM_MAX_COMMAND_SIZE as usize, 0);
342
343 let len = {
344 let mut writer = TpmWriter::new(&mut self.command);
345 tpm_marshal_command(command, tag, sessions, &mut writer)
346 .map_err(TpmDeviceError::Marshal)?;
347 writer.len()
348 };
349 self.command.truncate(len);
350
351 trace!("{} C: {}", cc, hex::encode(&self.command));
352 Ok(())
353 }
354
355 fn get_capability<T, F, N>(
364 &mut self,
365 cap: TpmCap,
366 property_start: u32,
367 count: u32,
368 mut extract: F,
369 next_prop: N,
370 ) -> Result<Vec<T>, TpmDeviceError>
371 where
372 T: Copy,
373 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
374 N: Fn(&T) -> u32,
375 {
376 let mut results = Vec::new();
377 let mut prop = property_start;
378 loop {
379 let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
380 let items: &[T] = extract(&cap_data.data)?;
381 results.extend_from_slice(items);
382
383 if more_data {
384 if let Some(last) = items.last() {
385 prop = next_prop(last);
386 } else {
387 break;
388 }
389 } else {
390 break;
391 }
392 }
393 Ok(results)
394 }
395
396 pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
407 self.get_capability(
408 TpmCap::Algs,
409 0,
410 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
411 |caps| match caps {
412 TpmuCapabilities::Algs(algs) => Ok(algs),
413 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
414 },
415 |last| last.alg as u32 + 1,
416 )
417 }
418
419 pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
430 self.get_capability(
431 TpmCap::Handles,
432 (class as u32) << 24,
433 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
434 |caps| match caps {
435 TpmuCapabilities::Handles(handles) => Ok(handles),
436 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
437 },
438 |last| *last + 1,
439 )
440 .map(|handles| handles.into_iter().map(TpmHandle).collect())
441 }
442
443 pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
454 self.get_capability(
455 TpmCap::EccCurves,
456 0,
457 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
458 |caps| match caps {
459 TpmuCapabilities::EccCurves(curves) => Ok(curves),
460 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
461 },
462 |last| *last as u32 + 1,
463 )
464 }
465
466 pub fn fetch_pcr_bank_list(&mut self) -> Result<(usize, Vec<TpmAlgId>), TpmDeviceError> {
481 let pcrs: Vec<TpmsPcrSelection> = self.get_capability(
482 TpmCap::Pcrs,
483 0,
484 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
485 |caps| match caps {
486 TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
487 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
488 },
489 |last| last.hash as u32 + 1,
490 )?;
491
492 if pcrs.is_empty() {
493 return Err(TpmDeviceError::PcrBanksNotAvailable);
494 }
495
496 let mut count = 0;
497 let mut algs = Vec::with_capacity(pcrs.len());
498
499 for bank in pcrs {
500 let next_count = bank.pcr_select.len();
501 if count == 0 {
502 count = next_count;
503 }
504 if next_count != count {
505 return Err(TpmDeviceError::PcrBankSizeMismatch);
506 }
507 algs.push(bank.hash);
508 }
509
510 algs.sort();
511 Ok((count, algs))
512 }
513
514 fn get_capability_page(
524 &mut self,
525 cap: TpmCap,
526 property: u32,
527 property_count: u32,
528 ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
529 let cmd = TpmGetCapabilityCommand {
530 cap,
531 property,
532 property_count,
533 handles: [],
534 };
535
536 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
537 let TpmGetCapabilityResponse {
538 more_data,
539 capability_data,
540 handles: [],
541 } = resp
542 .GetCapability()
543 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
544
545 Ok((more_data.into(), capability_data))
546 }
547
548 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, TpmDeviceError> {
557 let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
558
559 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
560 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
561 };
562
563 let Some(prop) = props.iter().find(|prop| prop.property == property) else {
564 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
565 };
566
567 Ok(prop.value)
568 }
569
570 pub fn read_public(
579 &mut self,
580 handle: TpmHandle,
581 ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
582 if let Some(cached) = self.name_cache.get(&handle.0) {
583 return Ok(cached.clone());
584 }
585
586 let cmd = TpmReadPublicCommand { handles: [handle] };
587 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
588
589 let read_public_resp = resp
590 .ReadPublic()
591 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
592
593 let public = read_public_resp.out_public.inner;
594 let name = read_public_resp.name;
595
596 self.name_cache.insert(handle.0, (public.clone(), name));
597 Ok((public, name))
598 }
599
600 pub fn find_persistent(
612 &mut self,
613 target_name: &Tpm2bName,
614 ) -> Result<Option<TpmHandle>, TpmDeviceError> {
615 for handle in self.fetch_handles(TpmHt::Persistent)? {
616 match self.read_public(handle) {
617 Ok((_, name)) => {
618 if name == *target_name {
619 return Ok(Some(handle));
620 }
621 }
622 Err(TpmDeviceError::TpmRc(rc)) => {
623 let base = rc.base();
624 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
625 continue;
626 }
627 return Err(TpmDeviceError::TpmRc(rc));
628 }
629 Err(e) => return Err(e),
630 }
631 }
632 Ok(None)
633 }
634
635 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
644 let cmd = TpmContextSaveCommand {
645 handles: [save_handle],
646 };
647 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
648 let save_resp = resp
649 .ContextSave()
650 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
651 Ok(save_resp.context)
652 }
653
654 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
663 let cmd = TpmContextLoadCommand {
664 context,
665 handles: [],
666 };
667 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
668 let resp_inner = resp
669 .ContextLoad()
670 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
671 Ok(resp_inner.handles[0])
672 }
673
674 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
682 self.name_cache.remove(&handle.0);
683 let cmd = TpmFlushContextCommand {
684 flush_handle: handle,
685 handles: [],
686 };
687 self.transmit(&cmd, Self::NO_SESSIONS)?;
688 Ok(())
689 }
690
691 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
703 match self.load_context(context) {
704 Ok(handle) => self.flush_context(handle),
705 Err(TpmDeviceError::TpmRc(rc)) => {
706 let base = rc.base();
707 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
708 Ok(())
709 } else {
710 Err(TpmDeviceError::TpmRc(rc))
711 }
712 }
713 Err(e) => Err(e),
714 }
715 }
716
717 pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
729 match self.load_context(context) {
730 Ok(handle) => match self.flush_context(handle) {
731 Ok(()) => Ok(true),
732 Err(e) => Err(e),
733 },
734 Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
735 Err(e) => Err(e),
736 }
737 }
738}