cli/
session.rs

1// SPDX-License-Identifier: GPL-3.0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
//! For sessions, context management works as follows:
//!
//! 1. `TPM2_ContextSave` transforms an active session into a saved session,
//!    which gets a handle from the same address range as HMAC sessions.
//! 2. `TPM2_ContextLoad` transforms a saved session back into an active
//!    session, and the session regains the handle assigned upon creation.
//!
//! Consequences:
//!
//! 1. `TPM2_FlushContext` must not be applied after `TPM2_ContextSave`.
//! 2. `TPM2_FlushContext` can only be applied to a loaded session.
16
17use crate::{
18    convert::from_tpm_object_to_vec,
19    crypto::{crypto_digest, crypto_hmac, crypto_kdfa, CryptoError},
20    device::{Auth, Device, DeviceError, TpmCommandObject},
21    uri::Uri,
22};
23use rand::{thread_rng, RngCore};
24use std::{
25    collections::{hash_map, HashMap, HashSet},
26    num::TryFromIntError,
27    path::{Path, PathBuf},
28    str::FromStr,
29};
30use thiserror::Error;
31use tpm2_protocol::{
32    constant::TPM_MAX_COMMAND_SIZE,
33    data::{
34        Tpm2bAuth, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCc, TpmRcBase, TpmRh, TpmSe, TpmaSession,
35        TpmsAuthCommand, TpmsAuthResponse, TpmsContext,
36    },
37    message::{TpmAuthResponses, TpmStartAuthSessionResponse},
38    tpm_hash_size, TpmBuffer, TpmBuild, TpmErrorKind, TpmHandle, TpmParse, TpmWriter,
39};
40
41#[derive(Debug, Error)]
42pub enum SessionError {
43    #[error("crypto: {0}")]
44    Crypto(#[from] CryptoError),
45    #[error("device: {0}")]
46    Device(#[from] DeviceError),
47    #[error("I/O: {0}")]
48    Io(#[from] std::io::Error),
49    #[error("session '{0}' not found")]
50    NotFound(String),
51    #[error("session file has trailing data")]
52    TrailingData,
53    #[error("authorization requires both a session and a password")]
54    AuthPairingError,
55}
56
57impl From<TryFromIntError> for SessionError {
58    fn from(_err: TryFromIntError) -> Self {
59        Self::Device(DeviceError::Tpm(TpmErrorKind::InvalidValue))
60    }
61}
62
63/// Manages the state of an active authorization session.
64#[derive(Debug, Clone)]
65pub struct Session {
66    pub handle: TpmHandle,
67    pub context: TpmsContext,
68    pub nonce_tpm: Tpm2bNonce,
69    pub attributes: TpmaSession,
70    pub hmac_key: Tpm2bAuth,
71    pub auth_hash: TpmAlgId,
72    pub session_type: TpmSe,
73}
74
75impl Session {
76    /// Creates a new session from a `StartAuthSession` response.
77    pub(crate) fn new(
78        session_type: TpmSe,
79        auth_hash: TpmAlgId,
80        nonce_caller: Tpm2bNonce,
81        resp: &TpmStartAuthSessionResponse,
82    ) -> Result<Self, SessionError> {
83        let digest_len =
84            tpm_hash_size(&auth_hash).ok_or(DeviceError::Tpm(TpmErrorKind::InvalidValue))?;
85
86        let hmac_key_bytes = if session_type == TpmSe::Hmac {
87            let key_bits = u16::try_from(digest_len * 8)?;
88            crypto_kdfa(
89                auth_hash,
90                &[],
91                "ATH",
92                &resp.nonce_tpm,
93                &nonce_caller,
94                key_bits,
95            )
96            .map_err(SessionError::Crypto)?
97        } else {
98            Vec::new()
99        };
100
101        Ok(Self {
102            handle: resp.session_handle,
103            context: TpmsContext {
104                sequence: 0,
105                saved_handle: resp.session_handle.0.into(),
106                hierarchy: TpmRh::Null,
107                context_blob: TpmBuffer::default(),
108            },
109            nonce_tpm: resp.nonce_tpm,
110            attributes: TpmaSession::CONTINUE_SESSION,
111            hmac_key: Tpm2bAuth::try_from(hmac_key_bytes.as_slice()).map_err(DeviceError::Tpm)?,
112            auth_hash,
113            session_type,
114        })
115    }
116    /// Saves a session's state to a binary file.
117    ///
118    /// # Errors
119    ///
120    /// Returns `SessionError::Io` on I/O failure.
121    pub fn save_to_path(&self, path: &Path) -> Result<(), SessionError> {
122        let mut buf = vec![0u8; TPM_MAX_COMMAND_SIZE];
123        let len = {
124            let mut writer = TpmWriter::new(&mut buf);
125            self.session_type
126                .build(&mut writer)
127                .map_err(DeviceError::Tpm)?;
128            self.context.build(&mut writer).map_err(DeviceError::Tpm)?;
129            self.nonce_tpm
130                .build(&mut writer)
131                .map_err(DeviceError::Tpm)?;
132            self.attributes
133                .build(&mut writer)
134                .map_err(DeviceError::Tpm)?;
135            self.hmac_key.build(&mut writer).map_err(DeviceError::Tpm)?;
136            self.auth_hash
137                .build(&mut writer)
138                .map_err(DeviceError::Tpm)?;
139            writer.len()
140        };
141        buf.truncate(len);
142
143        std::fs::write(path, &buf)?;
144        Ok(())
145    }
146
147    /// Loads a session from a binary file.
148    ///
149    /// # Errors
150    ///
151    /// Returns `SessionError::Io` on I/O failure.
152    pub fn load_from_path(path: &Path) -> Result<Self, SessionError> {
153        let session_bytes = std::fs::read(path)?;
154
155        let (session_type, remainder) = TpmSe::parse(&session_bytes).map_err(DeviceError::Tpm)?;
156        let (context, remainder) = TpmsContext::parse(remainder).map_err(DeviceError::Tpm)?;
157        let (nonce_tpm, remainder) = Tpm2bNonce::parse(remainder).map_err(DeviceError::Tpm)?;
158        let (attributes, remainder) = TpmaSession::parse(remainder).map_err(DeviceError::Tpm)?;
159        let (hmac_key, remainder) = Tpm2bAuth::parse(remainder).map_err(DeviceError::Tpm)?;
160        let (auth_hash, remainder) = TpmAlgId::parse(remainder).map_err(DeviceError::Tpm)?;
161
162        if !remainder.is_empty() {
163            return Err(SessionError::TrailingData);
164        }
165
166        Ok(Self {
167            handle: TpmHandle(0),
168            context,
169            nonce_tpm,
170            attributes,
171            hmac_key,
172            auth_hash,
173            session_type,
174        })
175    }
176}
177
178#[derive(Debug)]
179pub struct SessionCache {
180    pub sessions: HashMap<String, Session>,
181    pub dirty: HashSet<String>,
182    pub sessions_dir: PathBuf,
183}
184
185impl<'a> IntoIterator for &'a SessionCache {
186    type Item = (&'a String, &'a Session);
187    type IntoIter = hash_map::Iter<'a, String, Session>;
188
189    fn into_iter(self) -> Self::IntoIter {
190        self.iter()
191    }
192}
193
194impl SessionCache {
195    /// Creates a new, empty `SessionMap`.
196    #[must_use]
197    pub fn new(cache_dir: &Path) -> Self {
198        Self {
199            sessions: HashMap::new(),
200            dirty: HashSet::new(),
201            sessions_dir: cache_dir.join("sessions"),
202        }
203    }
204
205    /// Populates the session map by reading all available session files from the
206    /// cache directory.
207    ///
208    /// # Errors
209    ///
210    /// Returns `SessionError` on I/O failure.
211    pub fn load_sessions(&mut self) -> Result<(), SessionError> {
212        std::fs::create_dir_all(&self.sessions_dir)?;
213
214        let entries = match std::fs::read_dir(&self.sessions_dir) {
215            Ok(entries) => entries.filter_map(Result::ok),
216            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
217            Err(e) => return Err(e.into()),
218        };
219
220        for entry in entries {
221            let path = entry.path();
222            if path.extension().and_then(|s| s.to_str()) != Some("session") {
223                continue;
224            }
225
226            let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) else {
227                continue;
228            };
229
230            let Ok(handle) = u32::from_str_radix(file_stem, 16) else {
231                log::warn!("Invalid session filename format: '{}'", path.display());
232                continue;
233            };
234
235            let session = match Session::load_from_path(&path) {
236                Ok(s) => s,
237                Err(e) => {
238                    log::warn!(
239                        "Failed to load session file '{}': {e}. Deleting.",
240                        path.display()
241                    );
242                    let _ = std::fs::remove_file(path);
243                    continue;
244                }
245            };
246
247            if session.context.saved_handle.0 != handle {
248                log::warn!(
249                    "Session file '{}' has mismatched handle in its content. Deleting.",
250                    path.display()
251                );
252                let _ = std::fs::remove_file(path);
253                continue;
254            }
255
256            let uri = Uri::Session(handle).to_string();
257            self.sessions.insert(uri, session);
258        }
259        Ok(())
260    }
261
262    /// Validates all sessions at startup, and removes expired sessions from the
263    /// previous power cycle.
264    ///
265    /// # Errors
266    ///
267    /// Returns an aggregate `DeviceError` if any non-recoverable errors occur.
268    /// Individual session failures are logged as warnings.
269    pub fn refresh_sessions(&mut self, device: &mut Device) -> Result<(), SessionError> {
270        let uris_to_refresh: Vec<String> = self.sessions.keys().cloned().collect();
271        for uri in uris_to_refresh {
272            let session = match self.get(&uri) {
273                Ok(s) => s.clone(),
274                Err(_) => continue,
275            };
276
277            match device.load_context(session.context.clone()) {
278                Ok(live_handle) => match device.save_context(live_handle) {
279                    Ok(new_context) => {
280                        if let Ok(s) = self.get_mut(&uri) {
281                            s.context = new_context;
282                        }
283                    }
284                    Err(e) => log::warn!("{uri}: {e}"),
285                },
286                Err(DeviceError::TpmRc(rc))
287                    if matches!(rc.base(), TpmRcBase::Handle | TpmRcBase::ReferenceH0) =>
288                {
289                    log::debug!("Removing stale session file for {uri}");
290                    if self.remove(&uri).is_err() {
291                        log::warn!("Failed to remove stale session for {uri}");
292                    }
293                }
294                Err(e) => log::warn!("{uri}: {e}"),
295            }
296        }
297
298        Ok(())
299    }
300
301    /// Saves all sessions marked as dirty to their respective files.
302    ///
303    /// # Errors
304    ///
305    /// Returns an error on serialization or I/O failure.
306    pub fn save(&mut self) -> Result<(), SessionError> {
307        if self.dirty.is_empty() {
308            return Ok(());
309        }
310
311        std::fs::create_dir_all(&self.sessions_dir)?;
312
313        for uri in self.dirty.drain() {
314            if let Some(session) = self.sessions.get(&uri) {
315                let handle = session.context.saved_handle.0;
316                let path = self.sessions_dir.join(format!("{handle:x}.session"));
317                session.save_to_path(&path)?;
318            }
319        }
320        Ok(())
321    }
322
323    /// Adds a new session and returns its URI.
324    pub fn add(&mut self, session: Session) -> String {
325        let handle = session.context.saved_handle.0;
326        let uri = Uri::Session(handle).to_string();
327        self.sessions.insert(uri.clone(), session);
328        self.dirty.insert(uri.clone());
329        uri
330    }
331
332    /// Removes a session from the map and deletes its file from disk. This
333    /// operation is idempotent.
334    ///
335    /// # Errors
336    ///
337    /// Returns `SessionError` on I/O failure (e.g. permissions).
338    pub fn remove(&mut self, uri: &str) -> Result<Option<Session>, SessionError> {
339        let session = self.sessions.remove(uri);
340        self.dirty.remove(uri);
341
342        if let Ok(parsed_uri) = Uri::from_str(uri) {
343            if let Ok(handle) = parsed_uri.to_handle() {
344                let path = self.sessions_dir.join(format!("{handle:x}.session"));
345                if let Err(e) = std::fs::remove_file(path) {
346                    if e.kind() != std::io::ErrorKind::NotFound {
347                        return Err(e.into());
348                    }
349                }
350            }
351        }
352        Ok(session)
353    }
354
355    /// Removes all sessions from the map and deletes their files from disk.
356    ///
357    /// # Errors
358    ///
359    /// Returns `SessionError::Io` on I/O failure.
360    pub fn reset(&mut self) -> Result<(), SessionError> {
361        for session in self.sessions.values() {
362            let handle = session.context.saved_handle.0;
363            let path = self.sessions_dir.join(format!("{handle:x}.session"));
364            if path.exists() {
365                std::fs::remove_file(path)?;
366            }
367        }
368        self.sessions.clear();
369        self.dirty.clear();
370        Ok(())
371    }
372
373    /// Gets an immutable reference to a session.
374    ///
375    /// # Errors
376    ///
377    /// Returns `SessionError::NotFound` if no session is found.
378    pub fn get(&self, uri: &str) -> Result<&Session, SessionError> {
379        self.sessions
380            .get(uri)
381            .ok_or_else(|| SessionError::NotFound(uri.to_string()))
382    }
383
384    /// Gets a mutable reference to a session, marks it as dirty.
385    ///
386    /// # Errors
387    ///
388    /// Returns `SessionError::NotFound` if no session is found.
389    pub fn get_mut(&mut self, uri: &str) -> Result<&mut Session, SessionError> {
390        self.dirty.insert(uri.to_string());
391        self.sessions
392            .get_mut(uri)
393            .ok_or_else(|| SessionError::NotFound(uri.to_string()))
394    }
395
396    /// Returns an iterator over the sessions.
397    #[must_use]
398    pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Session> {
399        self.sessions.iter()
400    }
401
402    /// Loads session contexts if they are not already active.
403    ///
404    /// # Errors
405    ///
406    /// Returns a `SessionError` if a session URI is invalid or if loading a TPM
407    /// context fails.
408    pub fn prepare_sessions(
409        &mut self,
410        device: &mut Device,
411        auth_list: &[Auth],
412    ) -> Result<Vec<u32>, SessionError> {
413        let mut activated_handles = Vec::new();
414        for auth in auth_list {
415            if let Auth::Tracked(handle) = auth {
416                let uri = Uri::Session(*handle).to_string();
417                let session_is_loaded = {
418                    let session = self.get(&uri)?;
419                    session.handle.0 != 0
420                };
421                if !session_is_loaded {
422                    let new_handle = {
423                        let session = self.get(&uri)?;
424                        device.load_context(session.context.clone())?
425                    };
426                    let session = self.get_mut(&uri)?;
427                    session.handle = TpmHandle(new_handle);
428                    activated_handles.push(new_handle);
429                }
430            }
431        }
432        Ok(activated_handles)
433    }
434
435    /// Builds the authorization area for a command.
436    ///
437    /// # Errors
438    ///
439    /// Returns a `SessionError` if a session URI is not found, or if building
440    /// any part of the authorization command fails.
441    pub fn build_auth_area<C: TpmCommandObject>(
442        &self,
443        device: &mut Device,
444        command: &C,
445        handles: &[u32],
446        auth_list: &[Auth],
447    ) -> Result<(Vec<TpmsAuthCommand>, Vec<u32>), SessionError> {
448        build_auth_area(self, device, command, handles, auth_list)
449    }
450
451    /// Finalizes sessions after a command executes.
452    ///
453    /// # Errors
454    ///
455    /// Returns a `SessionError` if a session URI is not found, or if
456    /// saving/flushing the updated TPM context fails.
457    pub fn teardown_sessions(
458        &mut self,
459        device: &mut Device,
460        session_handles: &[u32],
461        auth_responses: &TpmAuthResponses,
462    ) -> Result<(), SessionError> {
463        for (i, handle) in session_handles.iter().enumerate() {
464            let uri = Uri::Session(*handle).to_string();
465            let session_handle = self.get(&uri)?.handle;
466            if session_handle.0 == 0 {
467                continue;
468            }
469
470            match device.save_context(session_handle.0) {
471                Ok(new_context) => {
472                    let session = self.get_mut(&uri)?;
473                    session.context = new_context;
474                    let auth: TpmsAuthResponse = auth_responses[i];
475                    session.nonce_tpm = auth.nonce;
476                    session.attributes = auth.session_attributes;
477                }
478                Err(e) => {
479                    log::warn!("Failed to save session context for {uri}: {e}. Flushing handle.");
480
481                    if device.flush_context(session_handle.0).is_err() {
482                        log::warn!("Failed to flush orphaned session handle {session_handle}.");
483                    }
484
485                    if let Ok(session) = self.get_mut(&uri) {
486                        session.handle = TpmHandle(0);
487                    } else {
488                        log::warn!("Session '{uri}' not found during error handling.");
489                    }
490
491                    return Err(e.into());
492                }
493            }
494        }
495        Ok(())
496    }
497}
498
499pub(crate) fn build_password_session(password: &[u8]) -> Result<TpmsAuthCommand, SessionError> {
500    Ok(TpmsAuthCommand {
501        session_handle: (tpm2_protocol::data::TpmRh::Pw as u32).into(),
502        nonce: Tpm2bNonce::default(),
503        session_attributes: TpmaSession::empty(),
504        hmac: Tpm2bAuth::try_from(password).map_err(DeviceError::Tpm)?,
505    })
506}
507
508fn build_auth_area<C: TpmCommandObject>(
509    session_map: &SessionCache,
510    device: &mut Device,
511    command: &C,
512    handles: &[u32],
513    auth_list: &[Auth],
514) -> Result<(Vec<TpmsAuthCommand>, Vec<u32>), SessionError> {
515    let mut built_auths = Vec::new();
516    let mut tracked_session_handles = Vec::new();
517
518    let mut auth_iter = auth_list.iter();
519    let params = from_tpm_object_to_vec(command).map_err(DeviceError::Tpm)?;
520
521    for handle in handles {
522        let auth = auth_iter.next().ok_or(SessionError::AuthPairingError)?;
523
524        match auth {
525            Auth::Password(password) => {
526                built_auths.push(build_password_session(password)?);
527            }
528            Auth::Tracked(session_handle) => {
529                tracked_session_handles.push(*session_handle);
530                let uri = Uri::Session(*session_handle).to_string();
531                let session = session_map.get(&uri)?;
532
533                let nonce_caller = new_nonce(session.auth_hash)?;
534
535                let result = create_auth(
536                    device,
537                    session,
538                    &nonce_caller,
539                    &[],
540                    C::CC,
541                    &[*handle],
542                    &params,
543                )?;
544                built_auths.push(result);
545            }
546        }
547    }
548    Ok((built_auths, tracked_session_handles))
549}
550
551fn new_nonce(hash_alg: TpmAlgId) -> Result<Tpm2bNonce, DeviceError> {
552    let nonce_size =
553        tpm_hash_size(&hash_alg).ok_or(DeviceError::Tpm(TpmErrorKind::InvalidValue))?;
554    let mut nonce_bytes = vec![0; nonce_size];
555    thread_rng().fill_bytes(&mut nonce_bytes);
556    Tpm2bNonce::try_from(nonce_bytes.as_slice()).map_err(Into::into)
557}
558
559#[allow(clippy::too_many_arguments)]
560fn create_auth(
561    device: &mut Device,
562    session: &Session,
563    nonce_caller: &Tpm2bNonce,
564    auth_value: &[u8],
565    command_code: TpmCc,
566    handles: &[u32],
567    parameters: &[u8],
568) -> Result<TpmsAuthCommand, SessionError> {
569    let handle_names: Vec<Tpm2bName> = handles
570        .iter()
571        .map(|&handle| device.get_handle_name(handle))
572        .collect::<Result<_, _>>()?;
573
574    let command_code_bytes = (command_code as u32).to_be_bytes();
575
576    let mut cp_hash_chunks: Vec<&[u8]> = Vec::with_capacity(2 + handle_names.len());
577    cp_hash_chunks.push(&command_code_bytes);
578    for name in &handle_names {
579        cp_hash_chunks.push(name.as_ref());
580    }
581    cp_hash_chunks.push(parameters);
582
583    let cp_hash = crypto_digest(session.auth_hash, &cp_hash_chunks)?;
584
585    let hmac_bytes = if session.session_type == TpmSe::Hmac {
586        let hmac_key = [session.hmac_key.as_ref(), auth_value].concat();
587
588        let mut hmac_payload: Vec<&[u8]> = Vec::with_capacity(8);
589        hmac_payload.push(&cp_hash);
590        hmac_payload.push(nonce_caller.as_ref());
591        hmac_payload.push(session.nonce_tpm.as_ref());
592
593        let attribute_bits = [session.attributes.bits()];
594        hmac_payload.push(&attribute_bits);
595
596        crypto_hmac(session.auth_hash, &hmac_key, &hmac_payload)?
597    } else {
598        Vec::new()
599    };
600
601    Ok(TpmsAuthCommand {
602        session_handle: session.handle,
603        nonce: *nonce_caller,
604        session_attributes: session.attributes,
605        hmac: Tpm2bAuth::try_from(hmac_bytes.as_slice()).map_err(DeviceError::Tpm)?,
606    })
607}