1use crate::{
6 cli::LogFormat,
7 crypto::CryptoError,
8 key::{Tpm2shAlgId, Tpm2shEccCurve},
9 print::TpmPrint,
10 transport::{receive_from_stream, FileTransport, Transport},
11 TEARDOWN,
12};
13
14use std::{
15 cell::RefCell,
16 collections::HashMap,
17 io::{IsTerminal, Write},
18 num::TryFromIntError,
19 rc::Rc,
20 sync::atomic::Ordering,
21 time::{Duration, Instant},
22};
23
24use indicatif::{ProgressBar, ProgressStyle};
25use log::trace;
26use polling::{Event, Events, Poller};
27use rand::{thread_rng, RngCore};
28use thiserror::Error;
29use tpm2_protocol::{
30 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
31 data::{
32 Tpm2bEncryptedSecret, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCap, TpmCc, TpmHt, TpmPt, TpmRc,
33 TpmRcBase, TpmRh, TpmSe, TpmSt, TpmsAuthCommand, TpmsCapabilityData, TpmsContext,
34 TpmsRsaParms, TpmtPublic, TpmtPublicParms, TpmtSymDefObject, TpmuCapabilities,
35 TpmuPublicParms,
36 },
37 message::{
38 tpm_build_command, tpm_parse_response, TpmAuthResponses, TpmBodyBuild,
39 TpmContextLoadCommand, TpmContextSaveCommand, TpmFlushContextCommand,
40 TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmHeader, TpmReadPublicCommand,
41 TpmResponseBody, TpmStartAuthSessionCommand, TpmStartAuthSessionResponse,
42 TpmTestParmsCommand,
43 },
44 tpm_hash_size, TpmErrorKind, TpmHandle, TpmWriter,
45};
46
/// Entry count requested per `TPM2_GetCapability` call; in this file it is the
/// page size used by `Device::get_all_handles` when enumerating handles.
pub const TPM_CAP_PROPERTY_MAX: u32 = 128;
48
/// A TPM command that `Device::execute` can send: it provides a header (used
/// for the command code), can marshal its body, and can be pretty-printed for
/// tracing.
pub trait TpmCommandObject: TpmPrint + TpmHeader + TpmBodyBuild {}

// Blanket impl: every type meeting the three bounds is a command object.
impl<T> TpmCommandObject for T where T: TpmHeader + TpmBodyBuild + TpmPrint {}
52
/// Authorization to use for a command handle.
///
/// NOTE(review): not used in this file; semantics below are inferred from the
/// variant names — confirm against the callers.
#[derive(Debug, Clone)]
pub enum Auth {
    /// Authorization via a handle given as a raw `u32` (presumably a session
    /// tracked elsewhere in the CLI).
    Tracked(u32),
    /// Authorization with a plain password value.
    Password(Vec<u8>),
}

/// A list of [`Auth`] values, presumably one per authorized command handle.
pub type AuthList = Vec<Auth>;
64
/// Errors produced while communicating with the TPM through a [`Device`].
#[derive(Debug, Error)]
pub enum DeviceError {
    /// I/O failure, e.g. from the poller or the transport.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),
    /// A failed system call reported by `nix`.
    #[error("syscall: {0}")]
    Nix(#[from] nix::Error),
    /// The TPM response could not be interpreted.
    #[error("response corrupted")]
    ResponseCorrupted,
    /// The decoded response body was not the one expected for this command code.
    #[error("response mismatch: {0}")]
    ResponseMismatch(TpmCc),
    /// The wait for a response was aborted by a teardown request.
    #[error("operation interrupted by user")]
    Interrupted,
    /// A cryptographic operation failed.
    #[error("crypto: {0}")]
    Crypto(#[from] CryptoError),
    /// No name is known for the given handle.
    #[error("unknown handle name: {0:08x}")]
    UnknownHandleName(u32),
    /// Error from the `tpm2_protocol` layer; also used for out-of-range
    /// integer conversions (see the `TryFromIntError` conversion).
    #[error("TPM: {0}")]
    Tpm(TpmErrorKind),
    /// The TPM answered with a non-success response code.
    #[error("TPM RC: {0}")]
    TpmRc(TpmRc),
    /// No response arrived within the receive timeout.
    #[error("TPM command timed out")]
    Timeout,
    /// No device was supplied (see `with_device`).
    #[error("device not available")]
    NotAvailable,
    /// The shared device is already mutably borrowed.
    #[error("device is already borrowed")]
    AlreadyBorrowed,
    /// `TPM2_GetCapability` returned data of an unexpected kind or lacked the
    /// requested entry.
    #[error("capability not found: {0}")]
    CapabilityMissing(TpmCap),
}
94
95impl From<TpmErrorKind> for DeviceError {
96 fn from(err: TpmErrorKind) -> Self {
97 Self::Tpm(err)
98 }
99}
100
101impl From<TpmRc> for DeviceError {
102 fn from(rc: TpmRc) -> Self {
103 Self::TpmRc(rc)
104 }
105}
106
107impl From<TryFromIntError> for DeviceError {
108 fn from(_err: TryFromIntError) -> Self {
109 Self::Tpm(TpmErrorKind::InvalidValue)
110 }
111}
112
113pub fn with_device<F, T, E>(device: Option<Rc<RefCell<Device>>>, f: F) -> Result<T, E>
123where
124 F: FnOnce(&mut Device) -> Result<T, E>,
125 E: From<DeviceError>,
126{
127 let device_rc = device.ok_or(DeviceError::NotAvailable)?;
128 let mut device_guard = device_rc
129 .try_borrow_mut()
130 .map_err(|_| DeviceError::AlreadyBorrowed)?;
131 f(&mut device_guard)
132}
133
/// A TPM endpoint: sends marshaled commands and receives responses.
#[derive(Debug)]
pub struct Device {
    /// Channel used to send commands and receive responses.
    transport: Box<dyn Transport>,
    /// Used to wait for a `FileTransport` to become readable while a progress
    /// spinner is shown.
    poller: Poller,
    /// Selects pretty (decoded) or hex tracing of commands and responses.
    log_format: LogFormat,
    /// Handle -> TPM name cache, filled by `read_public` and `add_name_to_cache`.
    name_cache: HashMap<u32, Tpm2bName>,
}
141
142fn test_rsa_parms(device: &mut Device, key_bits: u16) -> Result<(), DeviceError> {
144 let cmd = TpmTestParmsCommand {
145 parameters: TpmtPublicParms {
146 object_type: TpmAlgId::Rsa,
147 parameters: TpmuPublicParms::Rsa(TpmsRsaParms {
148 key_bits,
149 ..Default::default()
150 }),
151 },
152 };
153 let sessions = vec![];
154 device.execute(&cmd, &sessions).map(|(_, _)| ())
155}
156
157impl Device {
158 pub fn new(
164 transport: impl Transport + 'static,
165 log_format: LogFormat,
166 ) -> Result<Self, DeviceError> {
167 let poller = Poller::new()?;
168 Ok(Self {
169 transport: Box::new(transport),
170 poller,
171 log_format,
172 name_cache: HashMap::new(),
173 })
174 }
175
176 pub fn add_name_to_cache(&mut self, handle: u32, name: Tpm2bName) {
178 self.name_cache.insert(handle, name);
179 }
180
181 pub(crate) fn get_handle_name(&mut self, handle: u32) -> Result<Tpm2bName, DeviceError> {
183 if let Some(name) = self.name_cache.get(&handle) {
184 return Ok(*name);
185 }
186
187 let mso = (handle >> 24) as u8;
188 if mso == TpmHt::Transient as u8 || mso == TpmHt::Persistent as u8 {
189 let (_, name) = self.read_public(handle.into())?;
190 Ok(name)
191 } else {
192 Tpm2bName::try_from(handle.to_be_bytes().as_slice()).map_err(Into::into)
193 }
194 }
195
196 fn receive_with_progress(&mut self) -> Result<Vec<u8>, DeviceError> {
197 if let Some(ft) = self.transport.as_any_mut().downcast_mut::<FileTransport>() {
198 let spinner = ProgressBar::new_spinner();
199 spinner.enable_steady_tick(Duration::from_millis(100));
200 spinner.set_style(
201 ProgressStyle::with_template("{spinner:.green} {msg}")
202 .expect("Invalid progress spinner template")
203 .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏ "),
204 );
205 spinner.set_message("Waiting for TPM...");
206
207 let mut events = Events::new();
208 let file = &mut ft.0;
209 unsafe { self.poller.add(&*file, Event::readable(0))? };
210
211 let start_time = Instant::now();
212 let result = loop {
213 if TEARDOWN.load(Ordering::Relaxed) {
214 break Err(DeviceError::Interrupted);
215 }
216 if start_time.elapsed() > Duration::from_secs(60) {
217 break Err(DeviceError::Timeout);
218 }
219
220 self.poller
221 .wait(&mut events, Some(Duration::from_millis(100)))?;
222 if !events.is_empty() {
223 break receive_from_stream(file);
224 }
225 };
226
227 spinner.finish_and_clear();
228 self.poller.delete(&*file)?;
229 result
230 } else {
231 self.transport.receive()
232 }
233 }
234
235 pub fn execute<C: TpmCommandObject>(
242 &mut self,
243 command: &C,
244 sessions: &[TpmsAuthCommand],
245 ) -> Result<(TpmResponseBody, TpmAuthResponses), DeviceError> {
246 let command_vec = self.build_command_buffer(command, sessions)?;
247 let cc = command.cc();
248 self.transport.send(&command_vec)?;
249 let resp_buf = if std::io::stderr().is_terminal() {
250 self.receive_with_progress()?
251 } else {
252 self.transport.receive()?
253 };
254 let result = tpm_parse_response(cc, &resp_buf);
255 if self.log_format == LogFormat::Pretty {
256 let mut buf = Vec::new();
257 match &result {
258 Ok(Ok((response, _))) => {
259 response.print(&mut buf, "Response", 1)?;
260 for line in String::from_utf8_lossy(&buf).lines() {
261 trace!(target: "cli::device", "{line}");
262 }
263 }
264 Ok(Err(_)) | Err(_) => {
265 trace!(
266 target: "cli::device",
267 "Response: {}",
268 hex::encode(&resp_buf)
269 );
270 }
271 }
272 } else {
273 trace!(
274 target: "cli::device",
275 "Response: {}",
276 hex::encode(&resp_buf)
277 );
278 }
279 Ok(result??)
280 }
281
282 fn build_command_buffer<C: TpmCommandObject>(
283 &self,
284 command: &C,
285 sessions: &[TpmsAuthCommand],
286 ) -> Result<Vec<u8>, DeviceError> {
287 let cc = command.cc();
288 let tag = if sessions.is_empty() {
289 TpmSt::NoSessions
290 } else {
291 TpmSt::Sessions
292 };
293 let mut buf = vec![0u8; TPM_MAX_COMMAND_SIZE];
294 let len = {
295 let mut writer = TpmWriter::new(&mut buf);
296 tpm_build_command(command, tag, sessions, &mut writer)?;
297 writer.len()
298 };
299 buf.truncate(len);
300
301 if self.log_format == LogFormat::Pretty {
302 let mut print_buf = Vec::new();
303 writeln!(&mut print_buf, "{cc}")?;
304 command.print(&mut print_buf, "Command", 1)?;
305 for line in String::from_utf8_lossy(&print_buf).lines() {
306 trace!(target: "cli::device", "{line}");
307 }
308 } else {
309 trace!(
310 target: "cli::device",
311 "Command: {}",
312 hex::encode(&buf)
313 );
314 }
315 Ok(buf)
316 }
317
318 pub fn get_all_algorithms(&mut self) -> Result<Vec<(TpmAlgId, String)>, DeviceError> {
324 let mut supported_algs = Vec::new();
325 let mut all_algs = Vec::new();
326 let mut prop = 0;
327 loop {
328 let (more_data, cap_data) =
329 self.get_capability(TpmCap::Algs, prop, u32::try_from(MAX_HANDLES)?)?;
330
331 if let TpmuCapabilities::Algs(p) = cap_data.data {
332 all_algs.extend(p.iter().map(|prop| prop.alg));
333 } else {
334 return Err(DeviceError::CapabilityMissing(TpmCap::Algs));
335 }
336
337 if more_data {
338 if let TpmuCapabilities::Algs(algs) = cap_data.data {
339 prop = algs.last().map_or(prop, |p| p.alg as u32 + 1);
340 }
341 } else {
342 break;
343 }
344 }
345 let all_algs: std::collections::HashSet<TpmAlgId> = all_algs.into_iter().collect();
346
347 let name_algs: Vec<TpmAlgId> = [TpmAlgId::Sha256, TpmAlgId::Sha384, TpmAlgId::Sha512]
348 .into_iter()
349 .filter(|alg| all_algs.contains(alg))
350 .collect();
351
352 if all_algs.contains(&TpmAlgId::Rsa) {
353 let rsa_key_sizes = [2048, 3072, 4096];
354 for key_bits in rsa_key_sizes {
355 match test_rsa_parms(self, key_bits) {
356 Ok(()) => {
357 for &name_alg in &name_algs {
358 supported_algs.push((
359 TpmAlgId::Rsa,
360 format!("rsa-{}:{}", key_bits, Tpm2shAlgId(name_alg)),
361 ));
362 }
363 }
364 Err(DeviceError::TpmRc(rc)) => {
365 if rc.base() != TpmRcBase::Value {
366 return Err(DeviceError::TpmRc(rc));
367 }
368 }
369 Err(e) => return Err(e),
370 }
371 }
372 }
373
374 if all_algs.contains(&TpmAlgId::Ecc) {
375 let mut supported_curves = Vec::new();
376 let mut prop = 0;
377 loop {
378 let (more_data, cap_data) =
379 self.get_capability(TpmCap::EccCurves, prop, u32::try_from(MAX_HANDLES)?)?;
380 if let TpmuCapabilities::EccCurves(curves) = &cap_data.data {
381 supported_curves.extend(curves.iter().copied());
382 } else {
383 return Err(DeviceError::CapabilityMissing(TpmCap::EccCurves));
384 }
385 if more_data {
386 if let TpmuCapabilities::EccCurves(curves) = cap_data.data {
387 prop = curves.last().map_or(prop, |&c| c as u32 + 1);
388 }
389 } else {
390 break;
391 }
392 }
393 for curve_id in supported_curves {
394 for &name_alg in &name_algs {
395 supported_algs.push((
396 TpmAlgId::Ecc,
397 format!(
398 "ecc-{}:{}",
399 Tpm2shEccCurve::from(curve_id),
400 Tpm2shAlgId(name_alg)
401 ),
402 ));
403 }
404 }
405 }
406
407 if all_algs.contains(&TpmAlgId::KeyedHash) {
408 for &name_alg in &name_algs {
409 supported_algs.push((
410 TpmAlgId::KeyedHash,
411 format!("keyedhash:{}", Tpm2shAlgId(name_alg)),
412 ));
413 }
414 }
415
416 Ok(supported_algs)
417 }
418
419 pub fn get_all_hashes(&mut self) -> Result<Vec<String>, DeviceError> {
425 let mut all_algs = Vec::new();
426 let mut prop = 0;
427
428 loop {
429 let (more_data, cap_data) =
430 self.get_capability(TpmCap::Algs, prop, u32::try_from(MAX_HANDLES)?)?;
431
432 if let TpmuCapabilities::Algs(p) = &cap_data.data {
433 all_algs.extend(p.iter().map(|prop| prop.alg));
434 } else {
435 return Err(DeviceError::CapabilityMissing(TpmCap::Algs));
436 }
437
438 if more_data {
439 if let TpmuCapabilities::Algs(algs) = cap_data.data {
440 prop = algs.last().map_or(prop, |p| p.alg as u32 + 1);
441 }
442 } else {
443 break;
444 }
445 }
446
447 let hashes: Vec<String> = all_algs
448 .iter()
449 .filter(|p| tpm_hash_size(p).is_some())
450 .map(|p| Tpm2shAlgId(*p).to_string())
451 .collect();
452 Ok(hashes)
453 }
454
455 pub fn get_all_handles(&mut self, handle_type: u32) -> Result<Vec<u32>, DeviceError> {
461 let mut all_handles = Vec::new();
462 let mut prop = handle_type;
463
464 loop {
465 let (more_data, cap_data) =
466 self.get_capability(TpmCap::Handles, prop, TPM_CAP_PROPERTY_MAX)?;
467
468 if let TpmuCapabilities::Handles(handles) = cap_data.data {
469 all_handles.extend(handles.iter().copied());
470 } else {
471 return Err(DeviceError::CapabilityMissing(TpmCap::Handles));
472 }
473
474 if more_data {
475 if let TpmuCapabilities::Handles(handles) = cap_data.data {
476 prop = handles.last().map_or(prop, |&h| h + 1);
477 }
478 } else {
479 break;
480 }
481 }
482
483 Ok(all_handles)
484 }
485
486 pub fn get_capability(
493 &mut self,
494 cap: TpmCap,
495 property: u32,
496 count: u32,
497 ) -> Result<(bool, TpmsCapabilityData), DeviceError> {
498 let cmd = TpmGetCapabilityCommand {
499 cap,
500 property,
501 property_count: count,
502 };
503 let sessions = vec![];
504
505 let (resp, _) = self.execute(&cmd, &sessions)?;
506 let TpmGetCapabilityResponse {
507 more_data,
508 capability_data,
509 } = resp
510 .GetCapability()
511 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::GetCapability))?;
512
513 Ok((more_data.into(), capability_data))
514 }
515
516 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, DeviceError> {
523 let (_, cap_data) = self.get_capability(TpmCap::TpmProperties, property as u32, 1)?;
524
525 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
526 return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
527 };
528
529 let Some(prop) = props.first() else {
530 return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
531 };
532
533 Ok(prop.value)
534 }
535
536 pub fn read_public(
543 &mut self,
544 handle: TpmHandle,
545 ) -> Result<(TpmtPublic, Tpm2bName), DeviceError> {
546 let cmd = TpmReadPublicCommand {
547 object_handle: handle,
548 };
549 let sessions = vec![];
550 let (resp, _) = self.execute(&cmd, &sessions)?;
551 let read_public_resp = resp
552 .ReadPublic()
553 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
554 let name = read_public_resp.name;
555 self.add_name_to_cache(handle.0, name);
556 Ok((read_public_resp.out_public.inner, name))
557 }
558
559 pub fn save_context(&mut self, handle: u32) -> Result<TpmsContext, DeviceError> {
566 let cmd = TpmContextSaveCommand {
567 save_handle: handle.into(),
568 };
569 let sessions = vec![];
570 let (resp, _) = self.execute(&cmd, &sessions)?;
571 let save_resp = resp
572 .ContextSave()
573 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextSave))?;
574 Ok(save_resp.context)
575 }
576
577 pub fn load_context(&mut self, context: TpmsContext) -> Result<u32, DeviceError> {
583 let cmd = TpmContextLoadCommand { context };
584 let sessions = vec![];
585 let (resp, _) = self.execute(&cmd, &sessions)?;
586 let resp_inner = resp
587 .ContextLoad()
588 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
589 Ok(resp_inner.loaded_handle.0)
590 }
591
592 pub fn flush_context(&mut self, handle: u32) -> Result<(), DeviceError> {
599 let cmd = TpmFlushContextCommand {
600 flush_handle: handle.into(),
601 };
602 let sessions = vec![];
603 self.execute(&cmd, &sessions)?;
604 Ok(())
605 }
606
607 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), DeviceError> {
613 match self.load_context(context) {
614 Ok(live_handle) => self.flush_context(live_handle),
615 Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(()),
616 Err(e) => Err(e),
617 }
618 }
619
620 pub fn start_session(
630 &mut self,
631 session_type: TpmSe,
632 auth_hash: TpmAlgId,
633 ) -> Result<(TpmStartAuthSessionResponse, Tpm2bNonce), DeviceError> {
634 let digest_len =
635 tpm_hash_size(&auth_hash).ok_or(DeviceError::Tpm(TpmErrorKind::InvalidValue))?;
636 let mut nonce_bytes = vec![0; digest_len];
637 thread_rng().fill_bytes(&mut nonce_bytes);
638 let nonce_caller = Tpm2bNonce::try_from(nonce_bytes.as_slice())?;
639
640 let cmd = TpmStartAuthSessionCommand {
641 tpm_key: (TpmRh::Null as u32).into(),
642 bind: (TpmRh::Null as u32).into(),
643 nonce_caller,
644 encrypted_salt: Tpm2bEncryptedSecret::default(),
645 session_type,
646 symmetric: TpmtSymDefObject::default(),
647 auth_hash,
648 };
649 let sessions = vec![];
650
651 let (response_body, _) = self.execute(&cmd, &sessions)?;
652
653 let resp = response_body
654 .StartAuthSession()
655 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::StartAuthSession))?;
656
657 Ok((resp, nonce_caller))
658 }
659}