1#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9 fcntl,
10 poll::{poll, PollFd, PollFlags},
11};
12use std::{
13 cell::RefCell,
14 collections::HashMap,
15 fs::{File, OpenOptions},
16 io::{Read, Write},
17 num::TryFromIntError,
18 os::fd::{AsFd, AsRawFd},
19 path::{Path, PathBuf},
20 rc::Rc,
21 time::{Duration, Instant},
22};
23
24use thiserror::Error;
25use tpm2_protocol::{
26 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
27 data::{
28 Tpm2bName, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
29 TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
30 TpmtPublic, TpmuCapabilities,
31 },
32 frame::{
33 tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
34 TpmContextSaveCommand, TpmEvictControlCommand, TpmFlushContextCommand, TpmFrame,
35 TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
36 },
37 TpmHandle, TpmWriter,
38};
39use tracing::trace;
40
41#[derive(Debug, Error)]
43pub enum TpmDeviceError {
44 #[error("device is already borrowed")]
45 AlreadyBorrowed,
46 #[error("capability not found: {0}")]
47 CapabilityMissing(TpmCap),
48 #[error("operation interrupted by user")]
49 Interrupted,
50 #[error("invalid response")]
51 InvalidResponse,
52 #[error("device not available")]
53 NotAvailable,
54
55 #[error("marshal: {0}")]
57 Marshal(tpm2_protocol::TpmProtocolError),
58
59 #[error("unmarshal: {0}")]
61 Unmarshal(tpm2_protocol::TpmProtocolError),
62
63 #[error("response mismatch: {0}")]
64 ResponseMismatch(TpmCc),
65 #[error("TPM command timed out")]
66 Timeout,
67 #[error("unexpected EOF")]
68 UnexpectedEof,
69 #[error("int decode: {0}")]
70 IntDecode(#[from] TryFromIntError),
71 #[error("I/O: {0}")]
72 Io(#[from] std::io::Error),
73 #[error("syscall: {0}")]
74 Nix(#[from] nix::Error),
75 #[error("TPM return code: {0}")]
76 TpmRc(TpmRc),
77}
78
79impl From<TpmRc> for TpmDeviceError {
80 fn from(rc: TpmRc) -> Self {
81 Self::TpmRc(rc)
82 }
83}
84
85pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
97where
98 F: FnOnce(&mut TpmDevice) -> Result<T, E>,
99 E: From<TpmDeviceError>,
100{
101 let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
102 let mut device_guard = device_rc
103 .try_borrow_mut()
104 .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
105 f(&mut device_guard)
106}
107
/// Configures and opens a [`TpmDevice`].
pub struct TpmDeviceBuilder {
    /// Path of the TPM character device to open.
    path: PathBuf,
    /// Longest time to wait for a complete command response.
    timeout: Duration,
    /// Callback consulted while waiting; returning `true` aborts the command.
    interrupted: Box<dyn Fn() -> bool>,
}
114
115impl Default for TpmDeviceBuilder {
116 fn default() -> Self {
117 Self {
118 path: PathBuf::from("/dev/tpmrm0"),
119 timeout: Duration::from_secs(120),
120 interrupted: Box::new(|| false),
121 }
122 }
123}
124
125impl TpmDeviceBuilder {
126 #[must_use]
128 pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
129 self.path = path.as_ref().to_path_buf();
130 self
131 }
132
133 #[must_use]
135 pub fn with_timeout(mut self, timeout: Duration) -> Self {
136 self.timeout = timeout;
137 self
138 }
139
140 #[must_use]
142 pub fn with_interrupted<F>(mut self, handler: F) -> Self
143 where
144 F: Fn() -> bool + 'static,
145 {
146 self.interrupted = Box::new(handler);
147 self
148 }
149
150 pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
158 let file = OpenOptions::new()
159 .read(true)
160 .write(true)
161 .open(&self.path)
162 .map_err(TpmDeviceError::Io)?;
163
164 let fd = file.as_raw_fd();
165 let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
166 let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
167 oflags.insert(fcntl::OFlag::O_NONBLOCK);
168 fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
169
170 Ok(TpmDevice {
171 file,
172 name_cache: HashMap::new(),
173 interrupted: self.interrupted,
174 timeout: self.timeout,
175 command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
176 response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
177 })
178 }
179}
180
181pub struct TpmDevice {
182 file: File,
183 name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
184 interrupted: Box<dyn Fn() -> bool>,
185 timeout: Duration,
186 command: Vec<u8>,
187 response: Vec<u8>,
188}
189
190impl std::fmt::Debug for TpmDevice {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.debug_struct("Device")
193 .field("file", &self.file)
194 .field("name_cache", &self.name_cache)
195 .field("timeout", &self.timeout)
196 .finish_non_exhaustive()
197 }
198}
199
200impl TpmDevice {
201 const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
202
203 #[must_use]
205 pub fn builder() -> TpmDeviceBuilder {
206 TpmDeviceBuilder::default()
207 }
208
209 fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
210 let fd = self.file.as_fd();
211 let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
212
213 let num_events = match poll(&mut fds, 100u16) {
214 Ok(num) => num,
215 Err(nix::Error::EINTR) => return Ok(0),
216 Err(e) => return Err(e.into()),
217 };
218
219 if num_events == 0 {
220 return Ok(0);
221 }
222
223 let revents = fds[0].revents().unwrap_or(PollFlags::empty());
224
225 if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
226 return Err(TpmDeviceError::UnexpectedEof);
227 }
228
229 if revents.contains(PollFlags::POLLIN) {
230 match self.file.read(buf) {
231 Ok(0) => Err(TpmDeviceError::UnexpectedEof),
232 Ok(n) => Ok(n),
233 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
234 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
235 Err(e) => Err(e.into()),
236 }
237 } else if revents.contains(PollFlags::POLLHUP) {
238 Err(TpmDeviceError::UnexpectedEof)
239 } else {
240 Ok(0)
241 }
242 }
243
244 pub fn transmit<C: TpmFrame>(
265 &mut self,
266 command: &C,
267 sessions: &[TpmsAuthCommand],
268 ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
269 self.prepare_command(command, sessions)?;
270 let cc = command.cc();
271
272 self.file.write_all(&self.command)?;
273 self.file.flush()?;
274
275 let start_time = Instant::now();
276 self.response.clear();
277 let mut total_size: Option<usize> = None;
278 let mut temp_buf = [0u8; 1024];
279
280 loop {
281 if (self.interrupted)() {
282 return Err(TpmDeviceError::Interrupted);
283 }
284 if start_time.elapsed() > self.timeout {
285 return Err(TpmDeviceError::Timeout);
286 }
287
288 let n = self.receive(&mut temp_buf)?;
289 if n > 0 {
290 self.response.extend_from_slice(&temp_buf[..n]);
291 }
292
293 if total_size.is_none() && self.response.len() >= 10 {
294 let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
295 return Err(TpmDeviceError::InvalidResponse);
296 };
297 let size = u32::from_be_bytes(size_bytes) as usize;
298 if !(10..=TPM_MAX_COMMAND_SIZE as usize).contains(&size) {
299 return Err(TpmDeviceError::InvalidResponse);
300 }
301 total_size = Some(size);
302 }
303
304 if let Some(size) = total_size {
305 if self.response.len() == size {
306 break;
307 }
308 if self.response.len() > size {
309 return Err(TpmDeviceError::InvalidResponse);
310 }
311 }
312 }
313
314 let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
315 trace!("{} R: {}", cc, hex::encode(&self.response));
316 Ok(result??)
317 }
318
319 fn prepare_command<C: TpmFrame>(
320 &mut self,
321 command: &C,
322 sessions: &[TpmsAuthCommand],
323 ) -> Result<(), TpmDeviceError> {
324 let cc = command.cc();
325 let tag = if sessions.is_empty() {
326 TpmSt::NoSessions
327 } else {
328 TpmSt::Sessions
329 };
330
331 self.command.resize(TPM_MAX_COMMAND_SIZE as usize, 0);
332
333 let len = {
334 let mut writer = TpmWriter::new(&mut self.command);
335 tpm_marshal_command(command, tag, sessions, &mut writer)
336 .map_err(TpmDeviceError::Marshal)?;
337 writer.len()
338 };
339 self.command.truncate(len);
340
341 trace!("{} C: {}", cc, hex::encode(&self.command));
342 Ok(())
343 }
344
345 fn get_capability<T, F, N>(
354 &mut self,
355 cap: TpmCap,
356 property_start: u32,
357 count: u32,
358 mut extract: F,
359 next_prop: N,
360 ) -> Result<Vec<T>, TpmDeviceError>
361 where
362 T: Copy,
363 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
364 N: Fn(&T) -> u32,
365 {
366 let mut results = Vec::new();
367 let mut prop = property_start;
368 loop {
369 let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
370 let items: &[T] = extract(&cap_data.data)?;
371 results.extend_from_slice(items);
372
373 if more_data {
374 if let Some(last) = items.last() {
375 prop = next_prop(last);
376 } else {
377 break;
378 }
379 } else {
380 break;
381 }
382 }
383 Ok(results)
384 }
385
386 pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
397 self.get_capability(
398 TpmCap::Algs,
399 0,
400 u32::try_from(MAX_HANDLES)?,
401 |caps| match caps {
402 TpmuCapabilities::Algs(algs) => Ok(algs),
403 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
404 },
405 |last| last.alg as u32 + 1,
406 )
407 }
408
409 pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
420 self.get_capability(
421 TpmCap::Handles,
422 (class as u32) << 24,
423 u32::try_from(MAX_HANDLES)?,
424 |caps| match caps {
425 TpmuCapabilities::Handles(handles) => Ok(handles),
426 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
427 },
428 |last| *last + 1,
429 )
430 .map(|handles| handles.into_iter().map(TpmHandle).collect())
431 }
432
433 pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
444 self.get_capability(
445 TpmCap::EccCurves,
446 0,
447 u32::try_from(MAX_HANDLES)?,
448 |caps| match caps {
449 TpmuCapabilities::EccCurves(curves) => Ok(curves),
450 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
451 },
452 |last| *last as u32 + 1,
453 )
454 }
455
456 pub fn fetch_pcr_banks(&mut self) -> Result<Vec<TpmsPcrSelection>, TpmDeviceError> {
467 self.get_capability(
468 TpmCap::Pcrs,
469 0,
470 u32::try_from(MAX_HANDLES)?,
471 |caps| match caps {
472 TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
473 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
474 },
475 |last| last.hash as u32 + 1,
476 )
477 }
478
479 fn get_capability_page(
489 &mut self,
490 cap: TpmCap,
491 property: u32,
492 count: u32,
493 ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
494 let cmd = TpmGetCapabilityCommand {
495 cap,
496 property,
497 property_count: count,
498 handles: [],
499 };
500
501 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
502 let TpmGetCapabilityResponse {
503 more_data,
504 capability_data,
505 handles: [],
506 } = resp
507 .GetCapability()
508 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
509
510 Ok((more_data.into(), capability_data))
511 }
512
513 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, TpmDeviceError> {
522 let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
523
524 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
525 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
526 };
527
528 let Some(prop) = props.first() else {
529 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
530 };
531
532 Ok(prop.value)
533 }
534
535 pub fn read_public(
544 &mut self,
545 handle: TpmHandle,
546 ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
547 if let Some(cached) = self.name_cache.get(&handle.0) {
548 return Ok(cached.clone());
549 }
550
551 let cmd = TpmReadPublicCommand { handles: [handle] };
552 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
553
554 let read_public_resp = resp
555 .ReadPublic()
556 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
557
558 let public = read_public_resp.out_public.inner;
559 let name = read_public_resp.name;
560
561 self.name_cache.insert(handle.0, (public.clone(), name));
562 Ok((public, name))
563 }
564
565 pub fn find_persistent(
577 &mut self,
578 target_name: &Tpm2bName,
579 ) -> Result<Option<TpmHandle>, TpmDeviceError> {
580 for handle in self.fetch_handles(TpmHt::Persistent)? {
581 match self.read_public(handle) {
582 Ok((_, name)) => {
583 if name == *target_name {
584 return Ok(Some(handle));
585 }
586 }
587 Err(TpmDeviceError::TpmRc(rc)) => {
588 let base = rc.base();
589 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
590 continue;
591 }
592 return Err(TpmDeviceError::TpmRc(rc));
593 }
594 Err(e) => return Err(e),
595 }
596 }
597 Ok(None)
598 }
599
600 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
609 let cmd = TpmContextSaveCommand {
610 handles: [save_handle],
611 };
612 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
613 let save_resp = resp
614 .ContextSave()
615 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
616 Ok(save_resp.context)
617 }
618
619 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
628 let cmd = TpmContextLoadCommand {
629 context,
630 handles: [],
631 };
632 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
633 let resp_inner = resp
634 .ContextLoad()
635 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
636 Ok(resp_inner.handles[0])
637 }
638
639 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
647 self.name_cache.remove(&handle.0);
648 let cmd = TpmFlushContextCommand {
649 flush_handle: handle,
650 handles: [],
651 };
652 self.transmit(&cmd, Self::NO_SESSIONS)?;
653 Ok(())
654 }
655
656 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
668 match self.load_context(context) {
669 Ok(handle) => self.flush_context(handle),
670 Err(TpmDeviceError::TpmRc(rc)) => {
671 let base = rc.base();
672 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
673 Ok(())
674 } else {
675 Err(TpmDeviceError::TpmRc(rc))
676 }
677 }
678 Err(e) => Err(e),
679 }
680 }
681
682 pub fn evict_control(
691 &mut self,
692 auth: TpmHandle,
693 object_handle: TpmHandle,
694 persistent_handle: TpmHandle,
695 sessions: &[TpmsAuthCommand],
696 ) -> Result<(), TpmDeviceError> {
697 let cmd = TpmEvictControlCommand {
698 handles: [auth, object_handle],
699 persistent_handle,
700 };
701
702 let (resp, _) = self.transmit(&cmd, sessions)?;
703
704 resp.EvictControl()
705 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::EvictControl))?;
706
707 Ok(())
708 }
709
710 pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
722 match self.load_context(context) {
723 Ok(handle) => match self.flush_context(handle) {
724 Ok(()) => Ok(true),
725 Err(e) => Err(e),
726 },
727 Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
728 Err(e) => Err(e),
729 }
730 }
731}