1#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9 fcntl,
10 poll::{poll, PollFd, PollFlags},
11};
12use std::{
13 cell::RefCell,
14 collections::HashMap,
15 fs::{File, OpenOptions},
16 io::{Read, Write},
17 num::TryFromIntError,
18 os::fd::{AsFd, AsRawFd},
19 path::{Path, PathBuf},
20 rc::Rc,
21 time::{Duration, Instant},
22};
23
24use thiserror::Error;
25use tpm2_crypto::{tpm_make_name, TpmCryptoError};
26use tpm2_protocol::{
27 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
28 data::{
29 Tpm2bName, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
30 TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
31 TpmtPublic, TpmuCapabilities,
32 },
33 frame::{
34 tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
35 TpmContextSaveCommand, TpmEvictControlCommand, TpmFlushContextCommand, TpmFrame,
36 TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
37 },
38 TpmHandle, TpmWriter,
39};
40use tracing::trace;
41
42pub trait TpmCommandObject: TpmFrame {}
44impl<T> TpmCommandObject for T where T: TpmFrame {}
45
46#[derive(Debug, Error)]
48pub enum TpmDeviceError {
49 #[error("device is already borrowed")]
50 AlreadyBorrowed,
51 #[error("capability not found: {0}")]
52 CapabilityMissing(TpmCap),
53
54 #[error("crypto: {0}")]
56 Crypto(#[from] TpmCryptoError),
57
58 #[error("operation interrupted by user")]
59 Interrupted,
60 #[error("invalid response")]
61 InvalidResponse,
62 #[error("device not available")]
63 NotAvailable,
64
65 #[error("marshal: {0}")]
67 Marshal(tpm2_protocol::TpmProtocolError),
68
69 #[error("unmarshal: {0}")]
71 Unmarshal(tpm2_protocol::TpmProtocolError),
72
73 #[error("response mismatch: {0}")]
74 ResponseMismatch(TpmCc),
75 #[error("TPM command timed out")]
76 Timeout,
77 #[error("unexpected EOF")]
78 UnexpectedEof,
79 #[error("int decode: {0}")]
80 IntDecode(#[from] TryFromIntError),
81 #[error("I/O: {0}")]
82 Io(#[from] std::io::Error),
83 #[error("syscall: {0}")]
84 Nix(#[from] nix::Error),
85 #[error("TPM return code: {0}")]
86 TpmRc(TpmRc),
87}
88
89impl From<TpmRc> for TpmDeviceError {
90 fn from(rc: TpmRc) -> Self {
91 Self::TpmRc(rc)
92 }
93}
94
95pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
107where
108 F: FnOnce(&mut TpmDevice) -> Result<T, E>,
109 E: From<TpmDeviceError>,
110{
111 let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
112 let mut device_guard = device_rc
113 .try_borrow_mut()
114 .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
115 f(&mut device_guard)
116}
117
/// Configures and opens a [`TpmDevice`]; obtain one via [`TpmDevice::builder`].
pub struct TpmDeviceBuilder {
    /// Character device opened by [`TpmDeviceBuilder::build`].
    path: PathBuf,
    /// Upper bound on how long [`TpmDevice::transmit`] waits for a response.
    timeout: Duration,
    /// Hook polled while waiting for a response; returning `true` aborts the wait.
    interrupted: Box<dyn Fn() -> bool>,
}
124
125impl Default for TpmDeviceBuilder {
126 fn default() -> Self {
127 Self {
128 path: PathBuf::from("/dev/tpmrm0"),
129 timeout: Duration::from_secs(120),
130 interrupted: Box::new(|| false),
131 }
132 }
133}
134
135impl TpmDeviceBuilder {
136 #[must_use]
138 pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
139 self.path = path.as_ref().to_path_buf();
140 self
141 }
142
143 #[must_use]
145 pub fn with_timeout(mut self, timeout: Duration) -> Self {
146 self.timeout = timeout;
147 self
148 }
149
150 #[must_use]
152 pub fn with_interrupted<F>(mut self, handler: F) -> Self
153 where
154 F: Fn() -> bool + 'static,
155 {
156 self.interrupted = Box::new(handler);
157 self
158 }
159
160 pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
168 let file = OpenOptions::new()
169 .read(true)
170 .write(true)
171 .open(&self.path)
172 .map_err(TpmDeviceError::Io)?;
173
174 let fd = file.as_raw_fd();
175 let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
176 let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
177 oflags.insert(fcntl::OFlag::O_NONBLOCK);
178 fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
179
180 Ok(TpmDevice {
181 file,
182 name_cache: HashMap::new(),
183 interrupted: self.interrupted,
184 timeout: self.timeout,
185 command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
186 response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
187 })
188 }
189}
190
/// An open TPM character device plus the state needed to exchange commands with it.
pub struct TpmDevice {
    /// Device file; [`TpmDeviceBuilder::build`] switches it to non-blocking mode.
    file: File,
    /// Results of [`TpmDevice::read_public`], keyed by the raw handle value.
    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
    /// Cancellation hook checked while waiting for a response.
    interrupted: Box<dyn Fn() -> bool>,
    /// Maximum time to wait for a complete response.
    timeout: Duration,
    /// Reusable buffer holding the marshaled command.
    command: Vec<u8>,
    /// Reusable buffer accumulating the raw response bytes.
    response: Vec<u8>,
}
199
200impl std::fmt::Debug for TpmDevice {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("Device")
203 .field("file", &self.file)
204 .field("name_cache", &self.name_cache)
205 .field("timeout", &self.timeout)
206 .finish_non_exhaustive()
207 }
208}
209
210impl TpmDevice {
211 const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
212
213 #[must_use]
215 pub fn builder() -> TpmDeviceBuilder {
216 TpmDeviceBuilder::default()
217 }
218
219 fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
220 let fd = self.file.as_fd();
221 let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
222
223 let num_events = match poll(&mut fds, 100u16) {
224 Ok(num) => num,
225 Err(nix::Error::EINTR) => return Ok(0),
226 Err(e) => return Err(e.into()),
227 };
228
229 if num_events == 0 {
230 return Ok(0);
231 }
232
233 let revents = fds[0].revents().unwrap_or(PollFlags::empty());
234
235 if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
236 return Err(TpmDeviceError::UnexpectedEof);
237 }
238
239 if revents.contains(PollFlags::POLLIN) {
240 match self.file.read(buf) {
241 Ok(0) => Err(TpmDeviceError::UnexpectedEof),
242 Ok(n) => Ok(n),
243 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
244 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
245 Err(e) => Err(e.into()),
246 }
247 } else if revents.contains(PollFlags::POLLHUP) {
248 Err(TpmDeviceError::UnexpectedEof)
249 } else {
250 Ok(0)
251 }
252 }
253
254 pub fn transmit<C: TpmCommandObject>(
275 &mut self,
276 command: &C,
277 sessions: &[TpmsAuthCommand],
278 ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
279 self.prepare_command(command, sessions)?;
280 let cc = command.cc();
281
282 self.file.write_all(&self.command)?;
283 self.file.flush()?;
284
285 let start_time = Instant::now();
286 self.response.clear();
287 let mut total_size: Option<usize> = None;
288 let mut temp_buf = [0u8; 1024];
289
290 loop {
291 if (self.interrupted)() {
292 return Err(TpmDeviceError::Interrupted);
293 }
294 if start_time.elapsed() > self.timeout {
295 return Err(TpmDeviceError::Timeout);
296 }
297
298 let n = self.receive(&mut temp_buf)?;
299 if n > 0 {
300 self.response.extend_from_slice(&temp_buf[..n]);
301 }
302
303 if total_size.is_none() && self.response.len() >= 10 {
304 let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
305 return Err(TpmDeviceError::InvalidResponse);
306 };
307 let size = u32::from_be_bytes(size_bytes) as usize;
308 if !(10..=TPM_MAX_COMMAND_SIZE as usize).contains(&size) {
309 return Err(TpmDeviceError::InvalidResponse);
310 }
311 total_size = Some(size);
312 }
313
314 if let Some(size) = total_size {
315 if self.response.len() == size {
316 break;
317 }
318 if self.response.len() > size {
319 return Err(TpmDeviceError::InvalidResponse);
320 }
321 }
322 }
323
324 let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
325 trace!("{} R: {}", cc, hex::encode(&self.response));
326 Ok(result??)
327 }
328
329 fn prepare_command<C: TpmCommandObject>(
330 &mut self,
331 command: &C,
332 sessions: &[TpmsAuthCommand],
333 ) -> Result<(), TpmDeviceError> {
334 let cc = command.cc();
335 let tag = if sessions.is_empty() {
336 TpmSt::NoSessions
337 } else {
338 TpmSt::Sessions
339 };
340
341 self.command.resize(TPM_MAX_COMMAND_SIZE as usize, 0);
342
343 let len = {
344 let mut writer = TpmWriter::new(&mut self.command);
345 tpm_marshal_command(command, tag, sessions, &mut writer)
346 .map_err(TpmDeviceError::Marshal)?;
347 writer.len()
348 };
349 self.command.truncate(len);
350
351 trace!("{} C: {}", cc, hex::encode(&self.command));
352 Ok(())
353 }
354
355 fn get_capability<T, F, N>(
364 &mut self,
365 cap: TpmCap,
366 property_start: u32,
367 count: u32,
368 mut extract: F,
369 next_prop: N,
370 ) -> Result<Vec<T>, TpmDeviceError>
371 where
372 T: Copy,
373 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
374 N: Fn(&T) -> u32,
375 {
376 let mut results = Vec::new();
377 let mut prop = property_start;
378 loop {
379 let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
380 let items: &[T] = extract(&cap_data.data)?;
381 results.extend_from_slice(items);
382
383 if more_data {
384 if let Some(last) = items.last() {
385 prop = next_prop(last);
386 } else {
387 break;
388 }
389 } else {
390 break;
391 }
392 }
393 Ok(results)
394 }
395
396 pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
407 self.get_capability(
408 TpmCap::Algs,
409 0,
410 u32::try_from(MAX_HANDLES)?,
411 |caps| match caps {
412 TpmuCapabilities::Algs(algs) => Ok(algs),
413 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
414 },
415 |last| last.alg as u32 + 1,
416 )
417 }
418
419 pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
430 self.get_capability(
431 TpmCap::Handles,
432 (class as u32) << 24,
433 u32::try_from(MAX_HANDLES)?,
434 |caps| match caps {
435 TpmuCapabilities::Handles(handles) => Ok(handles),
436 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
437 },
438 |last| *last + 1,
439 )
440 .map(|handles| handles.into_iter().map(TpmHandle).collect())
441 }
442
443 pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
454 self.get_capability(
455 TpmCap::EccCurves,
456 0,
457 u32::try_from(MAX_HANDLES)?,
458 |caps| match caps {
459 TpmuCapabilities::EccCurves(curves) => Ok(curves),
460 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
461 },
462 |last| *last as u32 + 1,
463 )
464 }
465
466 pub fn fetch_pcr_banks(&mut self) -> Result<Vec<TpmsPcrSelection>, TpmDeviceError> {
477 self.get_capability(
478 TpmCap::Pcrs,
479 0,
480 u32::try_from(MAX_HANDLES)?,
481 |caps| match caps {
482 TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
483 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
484 },
485 |last| last.hash as u32 + 1,
486 )
487 }
488
489 fn get_capability_page(
499 &mut self,
500 cap: TpmCap,
501 property: u32,
502 count: u32,
503 ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
504 let cmd = TpmGetCapabilityCommand {
505 cap,
506 property,
507 property_count: count,
508 handles: [],
509 };
510
511 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
512 let TpmGetCapabilityResponse {
513 more_data,
514 capability_data,
515 handles: [],
516 } = resp
517 .GetCapability()
518 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
519
520 Ok((more_data.into(), capability_data))
521 }
522
523 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, TpmDeviceError> {
532 let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
533
534 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
535 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
536 };
537
538 let Some(prop) = props.first() else {
539 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
540 };
541
542 Ok(prop.value)
543 }
544
545 pub fn read_public(
554 &mut self,
555 handle: TpmHandle,
556 ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
557 if let Some(cached) = self.name_cache.get(&handle.0) {
558 return Ok(cached.clone());
559 }
560
561 let cmd = TpmReadPublicCommand { handles: [handle] };
562 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
563
564 let read_public_resp = resp
565 .ReadPublic()
566 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
567
568 let public = read_public_resp.out_public.inner;
569 let name = read_public_resp.name;
570
571 self.name_cache.insert(handle.0, (public.clone(), name));
572 Ok((public, name))
573 }
574
575 pub fn find_persistent(
587 &mut self,
588 target: &TpmtPublic,
589 ) -> Result<Option<(TpmHandle, Tpm2bName)>, TpmDeviceError> {
590 for handle in self.fetch_handles(TpmHt::Persistent)? {
591 match self.read_public(handle) {
592 Ok((public, name)) => {
593 if public == *target {
594 return Ok(Some((handle, name)));
595 }
596 }
597 Err(TpmDeviceError::TpmRc(rc)) => {
598 let base = rc.base();
599 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
600 continue;
601 }
602 return Err(TpmDeviceError::TpmRc(rc));
603 }
604 Err(e) => return Err(e),
605 }
606 }
607 Ok(None)
608 }
609
610 pub fn find_persistent_by_name(
624 &mut self,
625 target_name: &Tpm2bName,
626 ) -> Result<Option<TpmHandle>, TpmDeviceError> {
627 for handle in self.fetch_handles(TpmHt::Persistent)? {
628 match self.read_public(handle) {
629 Ok((public, name)) => {
630 if name == *target_name {
631 return Ok(Some(handle));
632 }
633 let calculated_name = tpm_make_name(&public)?;
634 if calculated_name == *target_name {
635 return Ok(Some(handle));
636 }
637 }
638 Err(TpmDeviceError::TpmRc(rc)) => {
639 let base = rc.base();
640 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
641 continue;
642 }
643 return Err(TpmDeviceError::TpmRc(rc));
644 }
645 Err(e) => return Err(e),
646 }
647 }
648 Ok(None)
649 }
650
651 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
660 let cmd = TpmContextSaveCommand {
661 handles: [save_handle],
662 };
663 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
664 let save_resp = resp
665 .ContextSave()
666 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
667 Ok(save_resp.context)
668 }
669
670 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
679 let cmd = TpmContextLoadCommand {
680 context,
681 handles: [],
682 };
683 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
684 let resp_inner = resp
685 .ContextLoad()
686 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
687 Ok(resp_inner.handles[0])
688 }
689
690 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
698 self.name_cache.remove(&handle.0);
699 let cmd = TpmFlushContextCommand {
700 flush_handle: handle,
701 handles: [],
702 };
703 self.transmit(&cmd, Self::NO_SESSIONS)?;
704 Ok(())
705 }
706
707 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
719 match self.load_context(context) {
720 Ok(handle) => self.flush_context(handle),
721 Err(TpmDeviceError::TpmRc(rc)) => {
722 let base = rc.base();
723 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
724 Ok(())
725 } else {
726 Err(TpmDeviceError::TpmRc(rc))
727 }
728 }
729 Err(e) => Err(e),
730 }
731 }
732
733 pub fn evict_control(
742 &mut self,
743 auth: TpmHandle,
744 object_handle: TpmHandle,
745 persistent_handle: TpmHandle,
746 sessions: &[TpmsAuthCommand],
747 ) -> Result<(), TpmDeviceError> {
748 let cmd = TpmEvictControlCommand {
749 handles: [auth, object_handle],
750 persistent_handle,
751 };
752
753 let (resp, _) = self.transmit(&cmd, sessions)?;
754
755 resp.EvictControl()
756 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::EvictControl))?;
757
758 Ok(())
759 }
760
761 pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
773 match self.load_context(context) {
774 Ok(handle) => match self.flush_context(handle) {
775 Ok(()) => Ok(true),
776 Err(e) => Err(e),
777 },
778 Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
779 Err(e) => Err(e),
780 }
781 }
782}