1use crate::{
6 cli::LogFormat,
7 crypto::crypto_hash_size,
8 handle::{Handle, HandleClass},
9 print::TpmPrint,
10 spinner::Spinner,
11 TEARDOWN,
12};
13use log::trace;
14use polling::{Event, Events, Poller};
15use rand::{thread_rng, RngCore};
16use std::{
17 cell::RefCell,
18 collections::HashMap,
19 fs::File,
20 io::{Read, Write},
21 num::TryFromIntError,
22 rc::Rc,
23 sync::atomic::Ordering,
24 time::{Duration, Instant},
25};
26use thiserror::Error;
27use tpm2_protocol::{
28 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
29 data::{
30 Tpm2bEncryptedSecret, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCap, TpmCc, TpmHt, TpmPt, TpmRc,
31 TpmRcBase, TpmRh, TpmSe, TpmSt, TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData,
32 TpmsContext, TpmsRsaParms, TpmtPublic, TpmtPublicParms, TpmtSymDefObject, TpmuCapabilities,
33 TpmuPublicParms,
34 },
35 message::{
36 tpm_build_command, tpm_parse_response, TpmAuthResponses, TpmBodyBuild,
37 TpmContextLoadCommand, TpmContextSaveCommand, TpmEvictControlCommand,
38 TpmFlushContextCommand, TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmHeader,
39 TpmReadPublicCommand, TpmResponseBody, TpmStartAuthSessionCommand,
40 TpmStartAuthSessionResponse, TpmTestParmsCommand,
41 },
42 TpmError, TpmHandle, TpmWriter,
43};
44
45pub trait TpmCommandObject: TpmPrint + TpmHeader + TpmBodyBuild {}
47impl<T> TpmCommandObject for T where T: TpmHeader + TpmBodyBuild + TpmPrint {}
48
49#[derive(Debug, Error)]
50pub enum DeviceError {
51 #[error("device is already borrowed")]
52 AlreadyBorrowed,
53 #[error("capability not found: {0}")]
54 CapabilityMissing(TpmCap),
55 #[error("operation interrupted by user")]
56 Interrupted,
57 #[error("invalid response")]
58 InvalidResponse,
59 #[error("device not available")]
60 NotAvailable,
61 #[error("response mismatch: {0}")]
62 ResponseMismatch(TpmCc),
63 #[error("TPM command timed out")]
64 Timeout,
65 #[error("int decode: {0}")]
66 IntDecode(#[from] TryFromIntError),
67 #[error("I/O: {0}")]
68 Io(#[from] std::io::Error),
69 #[error("syscall: {0}")]
70 Nix(#[from] nix::Error),
71 #[error("protocol: {0}")]
72 TpmProtocol(TpmError),
73 #[error("TPM return code: {0}")]
74 TpmRc(TpmRc),
75}
76
77impl From<TpmError> for DeviceError {
78 fn from(err: TpmError) -> Self {
79 Self::TpmProtocol(err)
80 }
81}
82
83impl From<TpmRc> for DeviceError {
84 fn from(rc: TpmRc) -> Self {
85 Self::TpmRc(rc)
86 }
87}
88
89pub fn with_device<F, T, E>(device: Option<Rc<RefCell<Device>>>, f: F) -> Result<T, E>
99where
100 F: FnOnce(&mut Device) -> Result<T, E>,
101 E: From<DeviceError>,
102{
103 let device_rc = device.ok_or(DeviceError::NotAvailable)?;
104 let mut device_guard = device_rc
105 .try_borrow_mut()
106 .map_err(|_| DeviceError::AlreadyBorrowed)?;
107 f(&mut device_guard)
108}
109
110#[derive(Debug)]
111pub struct Device {
112 file: File,
113 poller: Poller,
114 log_format: LogFormat,
115 name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
116}
117
118pub(crate) fn test_rsa_parms(device: &mut Device, key_bits: u16) -> Result<(), DeviceError> {
120 let cmd = TpmTestParmsCommand {
121 parameters: TpmtPublicParms {
122 object_type: TpmAlgId::Rsa,
123 parameters: TpmuPublicParms::Rsa(TpmsRsaParms {
124 key_bits,
125 ..Default::default()
126 }),
127 },
128 };
129 let sessions = vec![];
130 device.execute(&cmd, &sessions).map(|(_, _)| ())
131}
132
133impl Device {
134 pub fn new(file: File, log_format: LogFormat) -> Result<Self, DeviceError> {
140 let poller = Poller::new()?;
141 Ok(Self {
142 file,
143 poller,
144 log_format,
145 name_cache: HashMap::new(),
146 })
147 }
148
149 fn receive_from_stream(&mut self) -> Result<Vec<u8>, DeviceError> {
150 let mut header = [0u8; 10];
151 self.file.read_exact(&mut header)?;
152 let Ok(size_bytes): Result<[u8; 4], _> = header[2..6].try_into() else {
153 return Err(DeviceError::InvalidResponse);
154 };
155 let size = u32::from_be_bytes(size_bytes) as usize;
156 if size < header.len() || size > TPM_MAX_COMMAND_SIZE {
157 return Err(DeviceError::InvalidResponse);
158 }
159 let mut resp_buf = header.to_vec();
160 resp_buf.resize(size, 0);
161 self.file.read_exact(&mut resp_buf[header.len()..])?;
162 Ok(resp_buf)
163 }
164
165 pub fn execute<C: TpmCommandObject>(
180 &mut self,
181 command: &C,
182 sessions: &[TpmsAuthCommand],
183 ) -> Result<(TpmResponseBody, TpmAuthResponses), DeviceError> {
184 let command_vec = self.build_command_buffer(command, sessions)?;
185 let cc = command.cc();
186
187 let mut spinner = Spinner::new("Waiting for TPM...");
188
189 self.file.write_all(&command_vec)?;
190 self.file.flush()?;
191
192 let mut events = Events::new();
193 unsafe { self.poller.add(&self.file, Event::readable(0))? };
194
195 let start_time = Instant::now();
196 let resp_buf = loop {
197 if TEARDOWN.load(Ordering::Relaxed) {
198 spinner.finish();
199 let _ = self.poller.delete(&self.file);
200 break Err(DeviceError::Interrupted);
201 }
202 if start_time.elapsed() > Duration::from_secs(60) {
203 spinner.finish();
204 let _ = self.poller.delete(&self.file);
205 break Err(DeviceError::Timeout);
206 }
207
208 spinner.tick();
209
210 self.poller
211 .wait(&mut events, Some(Duration::from_millis(100)))?;
212
213 if !events.is_empty() {
214 let _ = self.poller.delete(&self.file);
215 break self.receive_from_stream();
216 }
217 }?;
218
219 let result = tpm_parse_response(cc, &resp_buf);
220 if self.log_format == LogFormat::Pretty {
221 let mut buf = Vec::new();
222 match &result {
223 Ok(Ok((response, _))) => {
224 response.print(&mut buf, "Response", 1)?;
225 for line in String::from_utf8_lossy(&buf).lines() {
226 trace!(target: "cli::device", "{line}");
227 }
228 }
229 Ok(Err(_)) | Err(_) => {
230 trace!(
231 target: "cli::device",
232 "Response: {}",
233 hex::encode(&resp_buf)
234 );
235 }
236 }
237 } else {
238 trace!(
239 target: "cli::device",
240 "Response: {}",
241 hex::encode(&resp_buf)
242 );
243 }
244 Ok(result??)
245 }
246
247 fn build_command_buffer<C: TpmCommandObject>(
248 &self,
249 command: &C,
250 sessions: &[TpmsAuthCommand],
251 ) -> Result<Vec<u8>, DeviceError> {
252 let cc = command.cc();
253 let tag = if sessions.is_empty() {
254 TpmSt::NoSessions
255 } else {
256 TpmSt::Sessions
257 };
258 let mut buf = vec![0u8; TPM_MAX_COMMAND_SIZE];
259 let len = {
260 let mut writer = TpmWriter::new(&mut buf);
261 tpm_build_command(command, tag, sessions, &mut writer)?;
262 writer.len()
263 };
264 buf.truncate(len);
265
266 if self.log_format == LogFormat::Pretty {
267 let mut print_buf = Vec::new();
268 writeln!(&mut print_buf, "{cc}")?;
269 command.print(&mut print_buf, "Command", 1)?;
270 for line in String::from_utf8_lossy(&print_buf).lines() {
271 trace!(target: "cli::device", "{line}");
272 }
273 } else {
274 trace!(
275 target: "cli::device",
276 "Command: {}",
277 hex::encode(&buf)
278 );
279 }
280 Ok(buf)
281 }
282
283 pub fn get_capability<T, F, N>(
290 &mut self,
291 cap: TpmCap,
292 property_start: u32,
293 count: u32,
294 mut extract: F,
295 next_prop: N,
296 ) -> Result<Vec<T>, DeviceError>
297 where
298 T: Copy,
299 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], DeviceError>,
300 N: Fn(&T) -> u32,
301 {
302 let mut results = Vec::new();
303 let mut prop = property_start;
304 loop {
305 let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
306 let items: &[T] = extract(&cap_data.data)?;
307 results.extend_from_slice(items);
308
309 if more_data {
310 if let Some(last) = items.last() {
311 prop = next_prop(last);
312 } else {
313 break;
314 }
315 } else {
316 break;
317 }
318 }
319 Ok(results)
320 }
321
322 pub(crate) fn fetch_algorithm_properties(
324 &mut self,
325 ) -> Result<Vec<TpmsAlgProperty>, DeviceError> {
326 self.get_capability(
327 TpmCap::Algs,
328 0,
329 u32::try_from(MAX_HANDLES)?,
330 |caps| match caps {
331 TpmuCapabilities::Algs(algs) => Ok(algs),
332 _ => Err(DeviceError::CapabilityMissing(TpmCap::Algs)),
333 },
334 |last| last.alg as u32 + 1,
335 )
336 }
337
338 pub fn fetch_handles(&mut self, class: u32) -> Result<Vec<Handle>, DeviceError> {
344 self.get_capability(
345 TpmCap::Handles,
346 class,
347 u32::try_from(MAX_HANDLES)?,
348 |caps| match caps {
349 TpmuCapabilities::Handles(handles) => Ok(handles),
350 _ => Err(DeviceError::CapabilityMissing(TpmCap::Handles)),
351 },
352 |last| *last + 1,
353 )
354 .map(|handles| {
355 handles
356 .into_iter()
357 .map(|h| Handle((HandleClass::Tpm, h)))
358 .collect()
359 })
360 }
361
362 pub fn get_capability_page(
369 &mut self,
370 cap: TpmCap,
371 property: u32,
372 count: u32,
373 ) -> Result<(bool, TpmsCapabilityData), DeviceError> {
374 let cmd = TpmGetCapabilityCommand {
375 cap,
376 property,
377 property_count: count,
378 };
379 let sessions = vec![];
380
381 let (resp, _) = self.execute(&cmd, &sessions)?;
382 let TpmGetCapabilityResponse {
383 more_data,
384 capability_data,
385 } = resp
386 .GetCapability()
387 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::GetCapability))?;
388
389 Ok((more_data.into(), capability_data))
390 }
391
392 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, DeviceError> {
399 let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
400
401 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
402 return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
403 };
404
405 let Some(prop) = props.first() else {
406 return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
407 };
408
409 Ok(prop.value)
410 }
411
412 pub fn read_public(
419 &mut self,
420 handle: TpmHandle,
421 ) -> Result<(TpmtPublic, Tpm2bName), DeviceError> {
422 if let Some(cached) = self.name_cache.get(&handle.0) {
423 return Ok(cached.clone());
424 }
425
426 let cmd = TpmReadPublicCommand {
427 object_handle: handle,
428 };
429 let sessions = vec![];
430 let (resp, _) = self.execute(&cmd, &sessions)?;
431
432 let read_public_resp = resp
433 .ReadPublic()
434 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
435
436 let public = read_public_resp.out_public.inner;
437 let name = read_public_resp.name;
438
439 self.name_cache.insert(handle.0, (public.clone(), name));
440 Ok((public, name))
441 }
442
443 pub fn find_persistent(
449 &mut self,
450 target: &TpmtPublic,
451 ) -> Result<Option<(TpmHandle, Tpm2bName)>, DeviceError> {
452 let handles = self.fetch_handles((TpmHt::Persistent as u32) << 24)?;
453 for handle in handles {
454 if let Ok((public, name)) = self.read_public(handle.value().into()) {
455 if public == *target {
456 return Ok(Some((handle.value().into(), name)));
457 }
458 }
459 }
460 Ok(None)
461 }
462
463 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, DeviceError> {
470 let cmd = TpmContextSaveCommand { save_handle };
471 let sessions = vec![];
472 let (resp, _) = self.execute(&cmd, &sessions)?;
473 let save_resp = resp
474 .ContextSave()
475 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextSave))?;
476 Ok(save_resp.context)
477 }
478
479 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, DeviceError> {
485 let cmd = TpmContextLoadCommand { context };
486 let sessions = vec![];
487 let (resp, _) = self.execute(&cmd, &sessions)?;
488 let resp_inner = resp
489 .ContextLoad()
490 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
491 Ok(resp_inner.loaded_handle)
492 }
493
494 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), DeviceError> {
501 self.name_cache.remove(&handle.0);
502 let cmd = TpmFlushContextCommand {
503 flush_handle: handle,
504 };
505 let sessions = vec![];
506 self.execute(&cmd, &sessions)?;
507 Ok(())
508 }
509
510 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), DeviceError> {
516 match self.load_context(context) {
517 Ok(handle) => self.flush_context(handle),
518 Err(DeviceError::TpmRc(rc)) => {
519 let base = rc.base();
520 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
521 Ok(())
522 } else {
523 Err(DeviceError::TpmRc(rc))
524 }
525 }
526 Err(e) => Err(e),
527 }
528 }
529
530 pub fn start_session(
540 &mut self,
541 session_type: TpmSe,
542 auth_hash: TpmAlgId,
543 bind: TpmHandle,
544 ) -> Result<(TpmStartAuthSessionResponse, Tpm2bNonce), DeviceError> {
545 let digest_len =
546 crypto_hash_size(auth_hash).ok_or(DeviceError::TpmProtocol(TpmError::MalformedData))?;
547 let mut nonce_bytes = vec![0; digest_len];
548 thread_rng().fill_bytes(&mut nonce_bytes);
549 let nonce_caller = Tpm2bNonce::try_from(nonce_bytes.as_slice())?;
550
551 let cmd = TpmStartAuthSessionCommand {
552 tpm_key: (TpmRh::Null as u32).into(),
553 bind,
554 nonce_caller,
555 encrypted_salt: Tpm2bEncryptedSecret::default(),
556 session_type,
557 symmetric: TpmtSymDefObject::default(),
558 auth_hash,
559 };
560 let sessions = vec![];
561
562 let (response_body, _) = self.execute(&cmd, &sessions)?;
563
564 let resp = response_body
565 .StartAuthSession()
566 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::StartAuthSession))?;
567
568 Ok((resp, nonce_caller))
569 }
570
571 pub fn evict_control(
577 &mut self,
578 auth: TpmHandle,
579 object_handle: TpmHandle,
580 persistent_handle: TpmHandle,
581 sessions: &[TpmsAuthCommand],
582 ) -> Result<(), DeviceError> {
583 let cmd = TpmEvictControlCommand {
584 auth,
585 object_handle: object_handle.0.into(),
586 persistent_handle,
587 };
588 let (resp, _) = self.execute(&cmd, sessions)?;
589
590 resp.EvictControl()
591 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::EvictControl))?;
592 Ok(())
593 }
594}