1use super::{RefreshAction, VtpmContext, VtpmError};
6use crate::{
7 crypto::{crypto_digest, crypto_hash_size, crypto_hmac, crypto_kdfa},
8 device::{Device, DeviceError},
9 key::Tpm2shAlgId,
10 write_object,
11};
12use std::{any::Any, fs, path::Path};
13use tpm2_protocol::{
14 basic::TpmBuffer,
15 data::{
16 Tpm2bAuth, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCc, TpmHt, TpmRcBase, TpmRh, TpmaSession,
17 TpmsAuthCommand, TpmsContext,
18 },
19 message::TpmStartAuthSessionResponse,
20 TpmBuild, TpmError, TpmParse, TpmSized, TpmWriter,
21};
22
23#[derive(Debug, Clone)]
25pub struct VtpmSession {
26 pub context: TpmsContext,
27 pub nonce_tpm: Tpm2bNonce,
28 pub attributes: TpmaSession,
29 pub hmac_key: Tpm2bAuth,
30 pub auth_hash: TpmAlgId,
31}
32
33impl VtpmSession {
34 pub fn new(
40 auth_hash: TpmAlgId,
41 nonce_caller: Tpm2bNonce,
42 resp: &TpmStartAuthSessionResponse,
43 auth_value: &[u8],
44 ) -> Result<Self, VtpmError> {
45 let digest_len = crypto_hash_size(auth_hash)
46 .ok_or(VtpmError::UnsupportedNameAlgorithm(Tpm2shAlgId(auth_hash)))?;
47
48 let hmac_key_bytes = if (resp.session_handle.0 >> 24) as u8 == TpmHt::HmacSession as u8 {
49 if auth_value.is_empty() {
50 Vec::new()
51 } else {
52 let key_bits = u16::try_from(digest_len * 8)
53 .map_err(|_| VtpmError::InvalidKeyBits(digest_len.to_string()))?;
54 crypto_kdfa(
55 auth_hash,
56 auth_value,
57 "ATH",
58 &resp.nonce_tpm,
59 &nonce_caller,
60 key_bits,
61 )?
62 }
63 } else {
64 Vec::new()
65 };
66
67 Ok(Self {
68 context: TpmsContext {
69 sequence: 0,
70 saved_handle: resp.session_handle.0.into(),
71 hierarchy: TpmRh::Null,
72 context_blob: TpmBuffer::default(),
73 },
74 nonce_tpm: resp.nonce_tpm,
75 attributes: TpmaSession::CONTINUE_SESSION,
76 hmac_key: Tpm2bAuth::try_from(hmac_key_bytes.as_slice())?,
77 auth_hash,
78 })
79 }
80
81 pub(super) fn load_from_path(path: &Path) -> Result<Self, VtpmError> {
83 let session_bytes = fs::read(path)?;
84 let (context, remainder) = TpmsContext::parse(&session_bytes)?;
85 let (nonce_tpm, remainder) = Tpm2bNonce::parse(remainder)?;
86 let (attributes, remainder) = TpmaSession::parse(remainder)?;
87 let (hmac_key, remainder) = Tpm2bAuth::parse(remainder)?;
88 let (auth_hash, _) = TpmAlgId::parse(remainder)?;
89
90 Ok(Self {
91 context,
92 nonce_tpm,
93 attributes,
94 hmac_key,
95 auth_hash,
96 })
97 }
98}
99
100impl VtpmContext for VtpmSession {
101 fn as_any(&self) -> &dyn Any {
102 self
103 }
104
105 fn as_any_mut(&mut self) -> &mut dyn Any {
106 self
107 }
108
109 fn handle(&self) -> u32 {
110 self.context.saved_handle.0
111 }
112
113 fn class(&self) -> &'static str {
114 if (self.handle() >> 24) as u8 == TpmHt::PolicySession as u8 {
115 "policy"
116 } else {
117 "hmac"
118 }
119 }
120
121 fn details(&self) -> String {
122 String::new()
123 }
124
125 fn save(&self, path: &Path) -> Result<(), VtpmError> {
126 let bytes = write_object(self)?;
127 fs::write(path, bytes)?;
128 Ok(())
129 }
130
131 fn delete(&self, device: &mut Device, cache_dir: &Path, vhandle: u32) -> Result<(), VtpmError> {
132 let path = cache_dir.join(format!("{vhandle:08x}.bin"));
133 if let Err(e) = fs::remove_file(path) {
134 if e.kind() != std::io::ErrorKind::NotFound {
135 return Err(e.into());
136 }
137 }
138 match device.flush_session(self.context.clone()) {
139 Ok(()) => {}
140 Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => {
141 log::debug!("vtpm session:{vhandle:08x} stale");
142 }
143 Err(e) => return Err(e.into()),
144 }
145 Ok(())
146 }
147
148 fn refresh(&mut self, device: &mut Device) -> Result<RefreshAction, VtpmError> {
149 let vhandle = self.handle();
150 match device.load_context(self.context.clone()) {
151 Ok(phandle) => match device.save_context(phandle) {
152 Ok(context) => match device.flush_context(phandle) {
153 Ok(()) => Ok(RefreshAction::Updated(Box::new(context))),
154 Err(e) => {
155 log::warn!("vtpm:{vhandle:08x}: {e}");
156 Ok(RefreshAction::Stale)
157 }
158 },
159 Err(e) => {
160 log::warn!("vtpm:{vhandle:08x}: {e}");
161 if let Err(e) = device.flush_context(phandle) {
162 log::warn!("vtpm:{vhandle:08x}: {e}");
163 }
164 if matches!(&e, DeviceError::TpmRc(rc) if rc.base() == TpmRcBase::ReferenceH0) {
165 Ok(RefreshAction::Stale)
166 } else {
167 Err(e.into())
168 }
169 }
170 },
171 Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => {
172 Ok(RefreshAction::Stale)
173 }
174 Err(e) => Err(e.into()),
175 }
176 }
177}
178
179impl TpmSized for VtpmSession {
180 const SIZE: usize = 0;
181 fn len(&self) -> usize {
182 self.context.len()
183 + self.nonce_tpm.len()
184 + self.attributes.len()
185 + self.hmac_key.len()
186 + self.auth_hash.len()
187 }
188}
189
190impl TpmBuild for VtpmSession {
191 fn build(&self, writer: &mut TpmWriter) -> Result<(), TpmError> {
192 self.context.build(writer)?;
193 self.nonce_tpm.build(writer)?;
194 self.attributes.build(writer)?;
195 self.hmac_key.build(writer)?;
196 self.auth_hash.build(writer)
197 }
198}
199
200pub fn build_password_session(password: &[u8]) -> Result<TpmsAuthCommand, VtpmError> {
206 Ok(TpmsAuthCommand {
207 session_handle: (tpm2_protocol::data::TpmRh::Pw as u32).into(),
208 nonce: Tpm2bNonce::default(),
209 session_attributes: TpmaSession::empty(),
210 hmac: Tpm2bAuth::try_from(password)?,
211 })
212}
213
214#[allow(clippy::too_many_arguments)]
221pub fn create_auth(
222 device: &mut Device,
223 session: &VtpmSession,
224 nonce_caller: &Tpm2bNonce,
225 auth_value: &[u8],
226 command_code: TpmCc,
227 handles: &[u32],
228 parameters: &[u8],
229 nonce_decrypt: Option<&Tpm2bNonce>,
230 nonce_encrypt: Option<&Tpm2bNonce>,
231) -> Result<TpmsAuthCommand, VtpmError> {
232 let handle_names: Vec<Tpm2bName> = handles
233 .iter()
234 .map(|&handle| device.read_public(handle.into()).map(|(_, name)| name))
235 .collect::<Result<_, DeviceError>>()?;
236
237 let command_code_bytes = (command_code as u32).to_be_bytes();
238
239 let mut cp_hash_chunks: Vec<&[u8]> = Vec::with_capacity(2 + handle_names.len());
240 cp_hash_chunks.push(&command_code_bytes);
241 for name in &handle_names {
242 cp_hash_chunks.push(name.as_ref());
243 }
244 cp_hash_chunks.push(parameters);
245
246 let cp_hash = crypto_digest(session.auth_hash, &cp_hash_chunks)?;
247
248 let hmac_bytes = if (session.context.saved_handle.0 >> 24) as u8 == TpmHt::HmacSession as u8 {
249 let hmac_key = [session.hmac_key.as_ref(), auth_value].concat();
250
251 let mut hmac_payload: Vec<&[u8]> = Vec::with_capacity(8);
252 hmac_payload.push(&cp_hash);
253 hmac_payload.push(nonce_caller.as_ref());
254 hmac_payload.push(session.nonce_tpm.as_ref());
255
256 if let Some(nonce) = nonce_decrypt {
257 hmac_payload.push(nonce.as_ref());
258 }
259 if let Some(nonce) = nonce_encrypt {
260 if nonce_decrypt.map_or(true, |d| d.as_ref() != nonce.as_ref()) {
261 hmac_payload.push(nonce.as_ref());
262 }
263 }
264
265 let attribute_bits = [session.attributes.bits()];
266 hmac_payload.push(&attribute_bits);
267
268 crypto_hmac(session.auth_hash, &hmac_key, &hmac_payload)?
269 } else {
270 Vec::new()
271 };
272
273 Ok(TpmsAuthCommand {
274 session_handle: session.context.saved_handle,
275 nonce: *nonce_caller,
276 session_attributes: session.attributes,
277 hmac: Tpm2bAuth::try_from(hmac_bytes.as_slice())?,
278 })
279}