1use crate::{
18 convert::from_tpm_object_to_vec,
19 crypto::{crypto_digest, crypto_hmac, crypto_kdfa, CryptoError},
20 device::{Auth, Device, DeviceError, TpmCommandObject},
21 uri::Uri,
22};
23use rand::{thread_rng, RngCore};
24use std::{
25 collections::{hash_map, HashMap, HashSet},
26 num::TryFromIntError,
27 path::{Path, PathBuf},
28 str::FromStr,
29};
30use thiserror::Error;
31use tpm2_protocol::{
32 constant::TPM_MAX_COMMAND_SIZE,
33 data::{
34 Tpm2bAuth, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCc, TpmRcBase, TpmRh, TpmSe, TpmaSession,
35 TpmsAuthCommand, TpmsAuthResponse, TpmsContext,
36 },
37 message::{TpmAuthResponses, TpmStartAuthSessionResponse},
38 tpm_hash_size, TpmBuffer, TpmBuild, TpmErrorKind, TpmHandle, TpmParse, TpmWriter,
39};
40
/// Errors produced while creating, persisting, or using cached TPM sessions.
#[derive(Debug, Error)]
pub enum SessionError {
    /// A cryptographic helper (for example key derivation, digest, or HMAC) failed.
    #[error("crypto: {0}")]
    Crypto(#[from] CryptoError),
    /// A device operation failed, a TPM structure could not be marshalled or
    /// unmarshalled, or a value was out of range.
    #[error("device: {0}")]
    Device(#[from] DeviceError),
    /// Creating, reading, writing, or removing a session file or its directory failed.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),
    /// No session is cached under the given URI.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// A session file contained bytes after the last expected field.
    #[error("session file has trailing data")]
    TrailingData,
    /// Raised by `build_auth_area` when `auth_list` has fewer entries than there
    /// are handles to authorize.
    #[error("authorization requires both a session and a password")]
    AuthPairingError,
}
56
57impl From<TryFromIntError> for SessionError {
58 fn from(_err: TryFromIntError) -> Self {
59 Self::Device(DeviceError::Tpm(TpmErrorKind::InvalidValue))
60 }
61}
62
/// A cached TPM authorization session.
///
/// Holds the saved TPM context plus the extra state needed to compute command
/// authorizations, so the session can be persisted and reused across runs.
#[derive(Debug, Clone)]
pub struct Session {
    /// Handle of the session while it is loaded in the TPM. `0` means "not
    /// loaded": `prepare_sessions` loads the context when it sees `0`.
    pub handle: TpmHandle,
    /// Saved TPM context. Its `saved_handle` also keys the cache entry and
    /// names the session file.
    pub context: TpmsContext,
    /// Most recent nonce received from the TPM for this session.
    pub nonce_tpm: Tpm2bNonce,
    /// Session attributes sent with each authorization.
    pub attributes: TpmaSession,
    /// Key mixed into HMAC computations; `Session::new` leaves it empty for
    /// non-HMAC session types.
    pub hmac_key: Tpm2bAuth,
    /// Hash algorithm of the session.
    pub auth_hash: TpmAlgId,
    /// Kind of session (HMAC, policy, ...).
    pub session_type: TpmSe,
}
74
75impl Session {
76 pub(crate) fn new(
78 session_type: TpmSe,
79 auth_hash: TpmAlgId,
80 nonce_caller: Tpm2bNonce,
81 resp: &TpmStartAuthSessionResponse,
82 ) -> Result<Self, SessionError> {
83 let digest_len =
84 tpm_hash_size(&auth_hash).ok_or(DeviceError::Tpm(TpmErrorKind::InvalidValue))?;
85
86 let hmac_key_bytes = if session_type == TpmSe::Hmac {
87 let key_bits = u16::try_from(digest_len * 8)?;
88 crypto_kdfa(
89 auth_hash,
90 &[],
91 "ATH",
92 &resp.nonce_tpm,
93 &nonce_caller,
94 key_bits,
95 )
96 .map_err(SessionError::Crypto)?
97 } else {
98 Vec::new()
99 };
100
101 Ok(Self {
102 handle: resp.session_handle,
103 context: TpmsContext {
104 sequence: 0,
105 saved_handle: resp.session_handle.0.into(),
106 hierarchy: TpmRh::Null,
107 context_blob: TpmBuffer::default(),
108 },
109 nonce_tpm: resp.nonce_tpm,
110 attributes: TpmaSession::CONTINUE_SESSION,
111 hmac_key: Tpm2bAuth::try_from(hmac_key_bytes.as_slice()).map_err(DeviceError::Tpm)?,
112 auth_hash,
113 session_type,
114 })
115 }
116 pub fn save_to_path(&self, path: &Path) -> Result<(), SessionError> {
122 let mut buf = vec![0u8; TPM_MAX_COMMAND_SIZE];
123 let len = {
124 let mut writer = TpmWriter::new(&mut buf);
125 self.session_type
126 .build(&mut writer)
127 .map_err(DeviceError::Tpm)?;
128 self.context.build(&mut writer).map_err(DeviceError::Tpm)?;
129 self.nonce_tpm
130 .build(&mut writer)
131 .map_err(DeviceError::Tpm)?;
132 self.attributes
133 .build(&mut writer)
134 .map_err(DeviceError::Tpm)?;
135 self.hmac_key.build(&mut writer).map_err(DeviceError::Tpm)?;
136 self.auth_hash
137 .build(&mut writer)
138 .map_err(DeviceError::Tpm)?;
139 writer.len()
140 };
141 buf.truncate(len);
142
143 std::fs::write(path, &buf)?;
144 Ok(())
145 }
146
147 pub fn load_from_path(path: &Path) -> Result<Self, SessionError> {
153 let session_bytes = std::fs::read(path)?;
154
155 let (session_type, remainder) = TpmSe::parse(&session_bytes).map_err(DeviceError::Tpm)?;
156 let (context, remainder) = TpmsContext::parse(remainder).map_err(DeviceError::Tpm)?;
157 let (nonce_tpm, remainder) = Tpm2bNonce::parse(remainder).map_err(DeviceError::Tpm)?;
158 let (attributes, remainder) = TpmaSession::parse(remainder).map_err(DeviceError::Tpm)?;
159 let (hmac_key, remainder) = Tpm2bAuth::parse(remainder).map_err(DeviceError::Tpm)?;
160 let (auth_hash, remainder) = TpmAlgId::parse(remainder).map_err(DeviceError::Tpm)?;
161
162 if !remainder.is_empty() {
163 return Err(SessionError::TrailingData);
164 }
165
166 Ok(Self {
167 handle: TpmHandle(0),
168 context,
169 nonce_tpm,
170 attributes,
171 hmac_key,
172 auth_hash,
173 session_type,
174 })
175 }
176}
177
/// In-memory cache of TPM sessions backed by one file per session on disk.
#[derive(Debug)]
pub struct SessionCache {
    /// Cached sessions keyed by `Uri::Session(handle).to_string()`.
    pub sessions: HashMap<String, Session>,
    /// URIs whose session files must be rewritten by `save`.
    pub dirty: HashSet<String>,
    /// Directory holding the `<hex handle>.session` files (`<cache_dir>/sessions`).
    pub sessions_dir: PathBuf,
}
184
185impl<'a> IntoIterator for &'a SessionCache {
186 type Item = (&'a String, &'a Session);
187 type IntoIter = hash_map::Iter<'a, String, Session>;
188
189 fn into_iter(self) -> Self::IntoIter {
190 self.iter()
191 }
192}
193
194impl SessionCache {
195 #[must_use]
197 pub fn new(cache_dir: &Path) -> Self {
198 Self {
199 sessions: HashMap::new(),
200 dirty: HashSet::new(),
201 sessions_dir: cache_dir.join("sessions"),
202 }
203 }
204
205 pub fn load_sessions(&mut self) -> Result<(), SessionError> {
212 std::fs::create_dir_all(&self.sessions_dir)?;
213
214 let entries = match std::fs::read_dir(&self.sessions_dir) {
215 Ok(entries) => entries.filter_map(Result::ok),
216 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
217 Err(e) => return Err(e.into()),
218 };
219
220 for entry in entries {
221 let path = entry.path();
222 if path.extension().and_then(|s| s.to_str()) != Some("session") {
223 continue;
224 }
225
226 let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) else {
227 continue;
228 };
229
230 let Ok(handle) = u32::from_str_radix(file_stem, 16) else {
231 log::warn!("Invalid session filename format: '{}'", path.display());
232 continue;
233 };
234
235 let session = match Session::load_from_path(&path) {
236 Ok(s) => s,
237 Err(e) => {
238 log::warn!(
239 "Failed to load session file '{}': {e}. Deleting.",
240 path.display()
241 );
242 let _ = std::fs::remove_file(path);
243 continue;
244 }
245 };
246
247 if session.context.saved_handle.0 != handle {
248 log::warn!(
249 "Session file '{}' has mismatched handle in its content. Deleting.",
250 path.display()
251 );
252 let _ = std::fs::remove_file(path);
253 continue;
254 }
255
256 let uri = Uri::Session(handle).to_string();
257 self.sessions.insert(uri, session);
258 }
259 Ok(())
260 }
261
262 pub fn refresh_sessions(&mut self, device: &mut Device) -> Result<(), SessionError> {
270 let uris_to_refresh: Vec<String> = self.sessions.keys().cloned().collect();
271 for uri in uris_to_refresh {
272 let session = match self.get(&uri) {
273 Ok(s) => s.clone(),
274 Err(_) => continue,
275 };
276
277 match device.load_context(session.context.clone()) {
278 Ok(live_handle) => match device.save_context(live_handle) {
279 Ok(new_context) => {
280 if let Ok(s) = self.get_mut(&uri) {
281 s.context = new_context;
282 }
283 }
284 Err(e) => log::warn!("{uri}: {e}"),
285 },
286 Err(DeviceError::TpmRc(rc))
287 if matches!(rc.base(), TpmRcBase::Handle | TpmRcBase::ReferenceH0) =>
288 {
289 log::debug!("Removing stale session file for {uri}");
290 if self.remove(&uri).is_err() {
291 log::warn!("Failed to remove stale session for {uri}");
292 }
293 }
294 Err(e) => log::warn!("{uri}: {e}"),
295 }
296 }
297
298 Ok(())
299 }
300
301 pub fn save(&mut self) -> Result<(), SessionError> {
307 if self.dirty.is_empty() {
308 return Ok(());
309 }
310
311 std::fs::create_dir_all(&self.sessions_dir)?;
312
313 for uri in self.dirty.drain() {
314 if let Some(session) = self.sessions.get(&uri) {
315 let handle = session.context.saved_handle.0;
316 let path = self.sessions_dir.join(format!("{handle:x}.session"));
317 session.save_to_path(&path)?;
318 }
319 }
320 Ok(())
321 }
322
323 pub fn add(&mut self, session: Session) -> String {
325 let handle = session.context.saved_handle.0;
326 let uri = Uri::Session(handle).to_string();
327 self.sessions.insert(uri.clone(), session);
328 self.dirty.insert(uri.clone());
329 uri
330 }
331
332 pub fn remove(&mut self, uri: &str) -> Result<Option<Session>, SessionError> {
339 let session = self.sessions.remove(uri);
340 self.dirty.remove(uri);
341
342 if let Ok(parsed_uri) = Uri::from_str(uri) {
343 if let Ok(handle) = parsed_uri.to_handle() {
344 let path = self.sessions_dir.join(format!("{handle:x}.session"));
345 if let Err(e) = std::fs::remove_file(path) {
346 if e.kind() != std::io::ErrorKind::NotFound {
347 return Err(e.into());
348 }
349 }
350 }
351 }
352 Ok(session)
353 }
354
355 pub fn reset(&mut self) -> Result<(), SessionError> {
361 for session in self.sessions.values() {
362 let handle = session.context.saved_handle.0;
363 let path = self.sessions_dir.join(format!("{handle:x}.session"));
364 if path.exists() {
365 std::fs::remove_file(path)?;
366 }
367 }
368 self.sessions.clear();
369 self.dirty.clear();
370 Ok(())
371 }
372
373 pub fn get(&self, uri: &str) -> Result<&Session, SessionError> {
379 self.sessions
380 .get(uri)
381 .ok_or_else(|| SessionError::NotFound(uri.to_string()))
382 }
383
384 pub fn get_mut(&mut self, uri: &str) -> Result<&mut Session, SessionError> {
390 self.dirty.insert(uri.to_string());
391 self.sessions
392 .get_mut(uri)
393 .ok_or_else(|| SessionError::NotFound(uri.to_string()))
394 }
395
396 #[must_use]
398 pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Session> {
399 self.sessions.iter()
400 }
401
402 pub fn prepare_sessions(
409 &mut self,
410 device: &mut Device,
411 auth_list: &[Auth],
412 ) -> Result<Vec<u32>, SessionError> {
413 let mut activated_handles = Vec::new();
414 for auth in auth_list {
415 if let Auth::Tracked(handle) = auth {
416 let uri = Uri::Session(*handle).to_string();
417 let session_is_loaded = {
418 let session = self.get(&uri)?;
419 session.handle.0 != 0
420 };
421 if !session_is_loaded {
422 let new_handle = {
423 let session = self.get(&uri)?;
424 device.load_context(session.context.clone())?
425 };
426 let session = self.get_mut(&uri)?;
427 session.handle = TpmHandle(new_handle);
428 activated_handles.push(new_handle);
429 }
430 }
431 }
432 Ok(activated_handles)
433 }
434
435 pub fn build_auth_area<C: TpmCommandObject>(
442 &self,
443 device: &mut Device,
444 command: &C,
445 handles: &[u32],
446 auth_list: &[Auth],
447 ) -> Result<(Vec<TpmsAuthCommand>, Vec<u32>), SessionError> {
448 build_auth_area(self, device, command, handles, auth_list)
449 }
450
451 pub fn teardown_sessions(
458 &mut self,
459 device: &mut Device,
460 session_handles: &[u32],
461 auth_responses: &TpmAuthResponses,
462 ) -> Result<(), SessionError> {
463 for (i, handle) in session_handles.iter().enumerate() {
464 let uri = Uri::Session(*handle).to_string();
465 let session_handle = self.get(&uri)?.handle;
466 if session_handle.0 == 0 {
467 continue;
468 }
469
470 match device.save_context(session_handle.0) {
471 Ok(new_context) => {
472 let session = self.get_mut(&uri)?;
473 session.context = new_context;
474 let auth: TpmsAuthResponse = auth_responses[i];
475 session.nonce_tpm = auth.nonce;
476 session.attributes = auth.session_attributes;
477 }
478 Err(e) => {
479 log::warn!("Failed to save session context for {uri}: {e}. Flushing handle.");
480
481 if device.flush_context(session_handle.0).is_err() {
482 log::warn!("Failed to flush orphaned session handle {session_handle}.");
483 }
484
485 if let Ok(session) = self.get_mut(&uri) {
486 session.handle = TpmHandle(0);
487 } else {
488 log::warn!("Session '{uri}' not found during error handling.");
489 }
490
491 return Err(e.into());
492 }
493 }
494 }
495 Ok(())
496 }
497}
498
499pub(crate) fn build_password_session(password: &[u8]) -> Result<TpmsAuthCommand, SessionError> {
500 Ok(TpmsAuthCommand {
501 session_handle: (tpm2_protocol::data::TpmRh::Pw as u32).into(),
502 nonce: Tpm2bNonce::default(),
503 session_attributes: TpmaSession::empty(),
504 hmac: Tpm2bAuth::try_from(password).map_err(DeviceError::Tpm)?,
505 })
506}
507
508fn build_auth_area<C: TpmCommandObject>(
509 session_map: &SessionCache,
510 device: &mut Device,
511 command: &C,
512 handles: &[u32],
513 auth_list: &[Auth],
514) -> Result<(Vec<TpmsAuthCommand>, Vec<u32>), SessionError> {
515 let mut built_auths = Vec::new();
516 let mut tracked_session_handles = Vec::new();
517
518 let mut auth_iter = auth_list.iter();
519 let params = from_tpm_object_to_vec(command).map_err(DeviceError::Tpm)?;
520
521 for handle in handles {
522 let auth = auth_iter.next().ok_or(SessionError::AuthPairingError)?;
523
524 match auth {
525 Auth::Password(password) => {
526 built_auths.push(build_password_session(password)?);
527 }
528 Auth::Tracked(session_handle) => {
529 tracked_session_handles.push(*session_handle);
530 let uri = Uri::Session(*session_handle).to_string();
531 let session = session_map.get(&uri)?;
532
533 let nonce_caller = new_nonce(session.auth_hash)?;
534
535 let result = create_auth(
536 device,
537 session,
538 &nonce_caller,
539 &[],
540 C::CC,
541 &[*handle],
542 ¶ms,
543 )?;
544 built_auths.push(result);
545 }
546 }
547 }
548 Ok((built_auths, tracked_session_handles))
549}
550
551fn new_nonce(hash_alg: TpmAlgId) -> Result<Tpm2bNonce, DeviceError> {
552 let nonce_size =
553 tpm_hash_size(&hash_alg).ok_or(DeviceError::Tpm(TpmErrorKind::InvalidValue))?;
554 let mut nonce_bytes = vec![0; nonce_size];
555 thread_rng().fill_bytes(&mut nonce_bytes);
556 Tpm2bNonce::try_from(nonce_bytes.as_slice()).map_err(Into::into)
557}
558
559#[allow(clippy::too_many_arguments)]
560fn create_auth(
561 device: &mut Device,
562 session: &Session,
563 nonce_caller: &Tpm2bNonce,
564 auth_value: &[u8],
565 command_code: TpmCc,
566 handles: &[u32],
567 parameters: &[u8],
568) -> Result<TpmsAuthCommand, SessionError> {
569 let handle_names: Vec<Tpm2bName> = handles
570 .iter()
571 .map(|&handle| device.get_handle_name(handle))
572 .collect::<Result<_, _>>()?;
573
574 let command_code_bytes = (command_code as u32).to_be_bytes();
575
576 let mut cp_hash_chunks: Vec<&[u8]> = Vec::with_capacity(2 + handle_names.len());
577 cp_hash_chunks.push(&command_code_bytes);
578 for name in &handle_names {
579 cp_hash_chunks.push(name.as_ref());
580 }
581 cp_hash_chunks.push(parameters);
582
583 let cp_hash = crypto_digest(session.auth_hash, &cp_hash_chunks)?;
584
585 let hmac_bytes = if session.session_type == TpmSe::Hmac {
586 let hmac_key = [session.hmac_key.as_ref(), auth_value].concat();
587
588 let mut hmac_payload: Vec<&[u8]> = Vec::with_capacity(8);
589 hmac_payload.push(&cp_hash);
590 hmac_payload.push(nonce_caller.as_ref());
591 hmac_payload.push(session.nonce_tpm.as_ref());
592
593 let attribute_bits = [session.attributes.bits()];
594 hmac_payload.push(&attribute_bits);
595
596 crypto_hmac(session.auth_hash, &hmac_key, &hmac_payload)?
597 } else {
598 Vec::new()
599 };
600
601 Ok(TpmsAuthCommand {
602 session_handle: session.handle,
603 nonce: *nonce_caller,
604 session_attributes: session.attributes,
605 hmac: Tpm2bAuth::try_from(hmac_bytes.as_slice()).map_err(DeviceError::Tpm)?,
606 })
607}