1use std::{fmt, time::SystemTime};
19
20use base64::{Engine as _, engine::general_purpose};
21use k256::{PublicKey, elliptic_curve::sec1::ToEncodedPoint};
22use rand_core::{OsRng, RngCore};
23use serde_json::Value;
24use sha2::{Digest as Sha2Digest, Sha256};
25use sha3::Keccak256;
26use thiserror::Error;
27
28use crate::{
29 config::{AttestationConfig, NvidiaRequirement, ProxyConfig},
30 util::json_kind,
31 venice::{VeniceClient, VeniceClientError},
32};
33
/// Number of random bytes in a freshly generated attestation nonce.
const ATTESTATION_NONCE_BYTES: usize = 32;
/// Length of the hex text form of a nonce: two hex digits per byte.
const ATTESTATION_NONCE_HEX_CHARS: usize = 2 * ATTESTATION_NONCE_BYTES;

/// Header `teeType` value that identifies an Intel TDX quote.
const TDX_TEE_TYPE: u32 = 0x81;
/// Size of the quote header that precedes the TD report body.
// NOTE(review): offsets below follow the v4 quote layout, where the body starts right after the
// header; v5 quotes insert a body descriptor first — confirm which versions upstream emits.
const TDX_QUOTE_HEADER_LEN: usize = 48;
/// Start of the little-endian `teeType` field inside the header.
const TDX_QUOTE_TEE_TYPE_OFFSET: usize = 4;
/// End (exclusive) of the `teeType` field, which is a `u32`.
const TDX_QUOTE_TEE_TYPE_END: usize = TDX_QUOTE_TEE_TYPE_OFFSET + std::mem::size_of::<u32>();
/// The TD report body begins immediately after the header.
const TDX_REPORT_BODY_OFFSET: usize = TDX_QUOTE_HEADER_LEN;
/// Start of the TD attributes field (body-relative offset 120).
const TDX_REPORT_TD_ATTRIBUTES_OFFSET: usize = TDX_REPORT_BODY_OFFSET + 120;
/// End (exclusive) of the TD attributes field, which is a `u64`.
const TDX_REPORT_TD_ATTRIBUTES_END: usize =
    TDX_REPORT_TD_ATTRIBUTES_OFFSET + std::mem::size_of::<u64>();
/// Start of the REPORTDATA field (body-relative offset 520).
const TDX_REPORT_DATA_OFFSET: usize = TDX_REPORT_BODY_OFFSET + 520;
/// Size of the REPORTDATA field in bytes.
const TDX_REPORT_DATA_LEN: usize = 64;
/// End (exclusive) of REPORTDATA; also the minimum quote length this module accepts.
const TDX_REPORT_DATA_END: usize = TDX_REPORT_DATA_OFFSET + TDX_REPORT_DATA_LEN;
46
47#[derive(Clone, Debug)]
49pub struct AttestationVerifier {
50 policy: AttestationConfig,
51 venice_client: VeniceClient,
52}
53
54impl AttestationVerifier {
55 pub fn from_config(config: &ProxyConfig, venice_client: VeniceClient) -> Self {
57 Self::new(config.attestation.clone(), venice_client)
58 }
59
60 pub fn new(policy: AttestationConfig, venice_client: VeniceClient) -> Self {
62 Self {
63 policy,
64 venice_client,
65 }
66 }
67
68 pub fn policy(&self) -> &AttestationConfig {
70 &self.policy
71 }
72
73 pub async fn verify_model_attestation(
76 &self,
77 model_id: &str,
78 ) -> Result<VerifiedAttestation, AttestationError> {
79 if model_id.trim().is_empty() {
80 return Err(AttestationError::InvalidRequest {
81 message: "model id must not be empty".to_owned(),
82 });
83 }
84
85 let nonce = AttestationNonce::generate();
86 let evidence = self
87 .venice_client
88 .fetch_attestation_evidence(model_id, nonce.as_str())
89 .await
90 .map_err(AttestationError::Fetch)?;
91
92 self.verify_evidence(model_id, nonce.as_str(), evidence)
93 }
94
95 pub fn verify_evidence(
97 &self,
98 requested_model_id: &str,
99 client_nonce: &str,
100 upstream_response: Value,
101 ) -> Result<VerifiedAttestation, AttestationError> {
102 verify_attestation_evidence(
103 &self.policy,
104 requested_model_id,
105 client_nonce,
106 upstream_response,
107 )
108 }
109}
110
111#[derive(Clone, PartialEq, Eq)]
113pub struct AttestationNonce(String);
114
115impl AttestationNonce {
116 pub fn generate() -> Self {
118 let mut bytes = [0_u8; ATTESTATION_NONCE_BYTES];
119 OsRng.fill_bytes(&mut bytes);
120 Self(hex::encode(bytes))
121 }
122
123 pub fn as_str(&self) -> &str {
125 &self.0
126 }
127}
128
129impl fmt::Debug for AttestationNonce {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_tuple("AttestationNonce").field(&self.0).finish()
133 }
134}
135
/// Result of a successful attestation check for one model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifiedAttestation {
    /// Model id the caller requested; the evidence's `model` field matched it exactly.
    pub model_id: String,
    /// Signing key as lowercase hex of its 65-byte uncompressed SEC1 encoding.
    pub model_public_key: String,
    /// The evidence's `signing_address`, normalized to `0x` + lowercase hex, when present.
    pub signing_address: Option<String>,
    /// The evidence's non-empty `tee_provider` string, when present.
    pub tee_provider: Option<String>,
    /// Outcome of evaluating the Intel TDX quote, if any.
    pub tdx: TdxVerificationSummary,
    /// Outcome of evaluating the NVIDIA payload, if any.
    pub nvidia: NvidiaVerificationSummary,
    /// Wall-clock time at which verification finished.
    pub verified_at: SystemTime,
    /// The complete upstream response as received (not just a nested `attestation` object).
    pub attestation_report: Value,
}
148
/// Outcome of evaluating the Intel TDX quote carried in the attestation evidence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TdxVerificationSummary {
    /// Whether the evidence contained a non-empty `intel_quote`.
    pub present: bool,
    /// Whether the quote was cryptographically verified.
    // NOTE(review): no DCAP/QVL backend is linked in this module, so this is currently `false`.
    pub verified: bool,
    /// Debug bit read from the quote's TD attributes, when a quote was parsed.
    pub debug: Option<bool>,
    /// `teeType` read from the quote header, when a quote was parsed.
    pub tee_type: Option<u32>,
}

impl TdxVerificationSummary {
    /// Summary for evidence that carries no TDX quote at all.
    fn not_present() -> Self {
        Self {
            tee_type: None,
            debug: None,
            verified: false,
            present: false,
        }
    }
}
169
/// Outcome of evaluating the `nvidia_payload` part of the attestation evidence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NvidiaVerificationSummary {
    /// Whether a non-null `nvidia_payload` was present.
    pub present: bool,
    /// How the payload, if any, was handled.
    pub verified: NvidiaVerificationStatus,
}

impl NvidiaVerificationSummary {
    /// Summary for evidence that carries no NVIDIA payload.
    fn not_present() -> Self {
        let verified = NvidiaVerificationStatus::NotPresent;
        Self {
            verified,
            present: false,
        }
    }
}

/// How NVIDIA GPU attestation evidence was handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NvidiaVerificationStatus {
    /// No payload was present.
    NotPresent,
    /// A payload was present but the policy (`NvidiaRequirement::Never`) ignores it.
    IgnoredByPolicy,
    /// A payload was present but no verifier backend was available to check it.
    PresentVerifierUnavailable,
}

impl NvidiaVerificationStatus {
    /// Short kebab-case label for this status (used as a header value).
    pub fn as_header_value(self) -> &'static str {
        use NvidiaVerificationStatus as Status;

        match self {
            Status::PresentVerifierUnavailable => "verifier-unavailable",
            Status::IgnoredByPolicy => "ignored",
            Status::NotPresent => "not-present",
        }
    }
}
205
/// Errors produced while fetching or verifying TEE attestation evidence.
///
/// Each variant maps to an API error type and code via `api_error_type` / `api_error_code`.
#[derive(Debug, Error)]
pub enum AttestationError {
    /// The caller's request was unusable (e.g. an empty model id or a malformed nonce).
    #[error("invalid attestation request: {message}")]
    InvalidRequest { message: String },
    /// The Venice client failed to return evidence; wraps the client error.
    #[error("TEE attestation fetch failed: {0}")]
    Fetch(#[from] VeniceClientError),
    /// The upstream response did not have the expected JSON shape or field types.
    #[error("TEE attestation response is malformed: {message}")]
    MalformedResponse { message: String },
    /// A field the verifier requires was absent from the evidence.
    #[error("TEE attestation evidence is missing required field {field}")]
    MissingField { field: &'static str },
    /// The evidence was well-formed but failed a verification or policy check.
    #[error("TEE attestation verification failed: {message}")]
    PolicyViolation {
        // Machine-readable reason; reported as the API error code.
        code: AttestationFailureCode,
        message: String,
    },
    /// The policy needs a verifier backend that is not available, so verification fails closed.
    #[error("TEE attestation verifier unavailable: {message}")]
    ExternalVerifierUnavailable {
        // Identifier of the missing backend, e.g. "tdx-dcap-qvl" or "nvidia-nras".
        verifier: &'static str,
        message: String,
    },
}
228
229impl AttestationError {
230 pub fn api_error_type(&self) -> &'static str {
232 match self {
233 Self::InvalidRequest { .. } => "invalid_request_error",
234 Self::ExternalVerifierUnavailable { .. } => "proxy_attestation_verifier_unavailable",
235 Self::Fetch(_)
236 | Self::MalformedResponse { .. }
237 | Self::MissingField { .. }
238 | Self::PolicyViolation { .. } => "proxy_attestation_error",
239 }
240 }
241
242 pub fn api_error_code(&self) -> &'static str {
244 match self {
245 Self::InvalidRequest { .. } => "invalid_attestation_request",
246 Self::Fetch(_) => "attestation_fetch_failed",
247 Self::MalformedResponse { .. } => "attestation_malformed_response",
248 Self::MissingField { .. } => "attestation_missing_required_field",
249 Self::PolicyViolation { code, .. } => code.as_str(),
250 Self::ExternalVerifierUnavailable { .. } => "attestation_verifier_unavailable",
251 }
252 }
253
254 pub fn verifier_unavailable(&self) -> bool {
256 matches!(self, Self::ExternalVerifierUnavailable { .. })
257 }
258}
259
/// Machine-readable reason attached to a policy-violation attestation error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttestationFailureCode {
    /// Upstream evidence had `verified: false`.
    UpstreamNotVerified,
    /// The evidence nonce differs from the nonce the client sent.
    NonceMismatch,
    /// The evidence model differs from the requested model.
    ModelMismatch,
    /// The signing key could not be decoded or is not a valid secp256k1 point.
    InvalidSigningKey,
    /// `signing_address` is malformed or does not derive from the signing key.
    SigningAddressMismatch,
    /// The evidence or quote reports debug mode while the policy forbids it.
    DebugModeDetected,
    /// TDX evidence is required but absent.
    MissingTdxEvidence,
    /// TDX evidence is present but unusable or inconsistent.
    InvalidTdxEvidence,
    /// NVIDIA evidence is required but absent.
    MissingNvidiaEvidence,
    /// NVIDIA evidence has an unsupported JSON type.
    InvalidNvidiaEvidence,
}

impl AttestationFailureCode {
    /// Snake_case code string; returned as the API error code for policy violations.
    pub fn as_str(self) -> &'static str {
        use AttestationFailureCode as Code;

        match self {
            Code::UpstreamNotVerified => "attestation_upstream_not_verified",
            Code::NonceMismatch => "attestation_nonce_mismatch",
            Code::ModelMismatch => "attestation_model_mismatch",
            Code::InvalidSigningKey => "attestation_invalid_signing_key",
            Code::SigningAddressMismatch => "attestation_signing_address_mismatch",
            Code::DebugModeDetected => "attestation_debug_mode_detected",
            Code::MissingTdxEvidence => "attestation_missing_tdx_evidence",
            Code::InvalidTdxEvidence => "attestation_invalid_tdx_evidence",
            Code::MissingNvidiaEvidence => "attestation_missing_nvidia_evidence",
            Code::InvalidNvidiaEvidence => "attestation_invalid_nvidia_evidence",
        }
    }
}
292
293fn verify_attestation_evidence(
295 policy: &AttestationConfig,
296 requested_model_id: &str,
297 client_nonce: &str,
298 upstream_response: Value,
299) -> Result<VerifiedAttestation, AttestationError> {
300 validate_nonce_hex(client_nonce)?;
301
302 let evidence = evidence_object(&upstream_response)?;
303 let verified = required_bool(evidence, "verified")?;
304
305 if !verified {
306 return policy_error(
307 AttestationFailureCode::UpstreamNotVerified,
308 "Venice did not mark the attestation evidence as verified",
309 );
310 }
311
312 let nonce = required_string(evidence, "nonce")?;
313
314 if nonce != client_nonce {
315 return policy_error(
316 AttestationFailureCode::NonceMismatch,
317 "attestation nonce does not match the client nonce; evidence may be stale or replayed",
318 );
319 }
320
321 let model = required_string(evidence, "model")?;
322
323 if model != requested_model_id {
324 return policy_error(
325 AttestationFailureCode::ModelMismatch,
326 format!(
327 "attestation model {model:?} does not match requested model {requested_model_id:?}"
328 ),
329 );
330 }
331
332 let signing_key = optional_non_empty_string(evidence, "signing_key")
333 .or_else(|| optional_non_empty_string(evidence, "signing_public_key"))
334 .ok_or(AttestationError::MissingField {
335 field: "signing_key|signing_public_key",
336 })?;
337 let normalized_signing_key = normalize_public_key_hex(signing_key)?;
338 let derived_address = ethereum_address_from_uncompressed_key_hex(&normalized_signing_key)?;
339 let signing_address = optional_non_empty_string(evidence, "signing_address")
340 .map(normalize_ethereum_address)
341 .transpose()?;
342
343 if let Some(signing_address) = &signing_address
344 && signing_address != &derived_address
345 {
346 return policy_error(
347 AttestationFailureCode::SigningAddressMismatch,
348 format!(
349 "signing_address {signing_address} does not match address {derived_address} derived from signing key"
350 ),
351 );
352 }
353
354 if top_level_debug(evidence) == Some(true) && !policy.allow_debug {
355 return policy_error(
356 AttestationFailureCode::DebugModeDetected,
357 "attestation evidence reports debug mode and attestation.allow_debug=false",
358 );
359 }
360
361 let tdx = evaluate_tdx_policy(
362 policy,
363 evidence,
364 &normalized_signing_key,
365 signing_address.as_deref(),
366 )?;
367 let nvidia = evaluate_nvidia_policy(policy, evidence)?;
368
369 Ok(VerifiedAttestation {
370 model_id: requested_model_id.to_owned(),
371 model_public_key: normalized_signing_key,
372 signing_address,
373 tee_provider: optional_non_empty_string(evidence, "tee_provider").map(ToOwned::to_owned),
374 tdx,
375 nvidia,
376 verified_at: SystemTime::now(),
377 attestation_report: upstream_response,
378 })
379}
380
381fn evaluate_tdx_policy(
383 policy: &AttestationConfig,
384 evidence: &serde_json::Map<String, Value>,
385 signing_key: &str,
386 signing_address: Option<&str>,
387) -> Result<TdxVerificationSummary, AttestationError> {
388 let Some(intel_quote) = optional_non_empty_string(evidence, "intel_quote") else {
389 return if policy.require_tdx {
390 policy_error(
391 AttestationFailureCode::MissingTdxEvidence,
392 "attestation.require_tdx=true but intel_quote is absent",
393 )
394 } else {
395 Ok(TdxVerificationSummary::not_present())
396 };
397 };
398
399 let parsed = parse_tdx_quote(intel_quote)?;
400
401 if parsed.tee_type != TDX_TEE_TYPE {
402 return policy_error(
403 AttestationFailureCode::InvalidTdxEvidence,
404 format!(
405 "Intel quote teeType 0x{:x} is not TDX teeType 0x{TDX_TEE_TYPE:x}",
406 parsed.tee_type
407 ),
408 );
409 }
410
411 if parsed.debug && !policy.allow_debug {
412 return policy_error(
413 AttestationFailureCode::DebugModeDetected,
414 "Intel TDX quote reports debug mode and attestation.allow_debug=false",
415 );
416 }
417
418 if let Some(reportdata) = optional_non_empty_string(evidence, "tdx_reportdata") {
419 verify_reportdata_binding(reportdata, signing_key, signing_address)?;
420 }
421
422 if policy.require_tdx {
423 let message = if policy.pccs_url.trim().is_empty() {
424 "attestation.require_tdx=true requires independent DCAP/QVL quote verification, but no DCAP verifier is linked and attestation.pccs_url is empty".to_owned()
425 } else {
426 "attestation.require_tdx=true requires independent DCAP/QVL quote verification; PCCS URL is configured but this v0.1 verifier has no DCAP/QVL backend linked".to_owned()
427 };
428
429 return Err(AttestationError::ExternalVerifierUnavailable {
430 verifier: "tdx-dcap-qvl",
431 message,
432 });
433 }
434
435 Ok(TdxVerificationSummary {
436 present: true,
437 verified: false,
438 debug: Some(parsed.debug),
439 tee_type: Some(parsed.tee_type),
440 })
441}
442
443fn evaluate_nvidia_policy(
445 policy: &AttestationConfig,
446 evidence: &serde_json::Map<String, Value>,
447) -> Result<NvidiaVerificationSummary, AttestationError> {
448 let nvidia_payload = evidence
449 .get("nvidia_payload")
450 .filter(|value| !value.is_null());
451
452 match (policy.require_nvidia, nvidia_payload) {
453 (NvidiaRequirement::Required, None) => policy_error(
454 AttestationFailureCode::MissingNvidiaEvidence,
455 "attestation.require_nvidia=required but nvidia_payload is absent",
456 ),
457 (NvidiaRequirement::Never, None) => Ok(NvidiaVerificationSummary::not_present()),
458 (NvidiaRequirement::Never, Some(_)) => Ok(NvidiaVerificationSummary {
459 present: true,
460 verified: NvidiaVerificationStatus::IgnoredByPolicy,
461 }),
462 (_, Some(Value::Object(_))) | (_, Some(Value::String(_))) => {
463 Err(AttestationError::ExternalVerifierUnavailable {
464 verifier: "nvidia-nras",
465 message: "NVIDIA attestation payload is present and policy requires verification, but this v0.1 verifier has no NRAS/local NVIDIA verifier backend linked".to_owned(),
466 })
467 }
468 (_, Some(_)) => policy_error(
469 AttestationFailureCode::InvalidNvidiaEvidence,
470 "nvidia_payload is present but is not an object or encoded string",
471 ),
472 (NvidiaRequirement::WhenPresent, None) => Ok(NvidiaVerificationSummary::not_present()),
473 }
474}
475
476#[derive(Debug, Clone, Copy, PartialEq, Eq)]
478struct ParsedTdxQuote {
479 tee_type: u32,
480 debug: bool,
481}
482
483fn parse_tdx_quote(value: &str) -> Result<ParsedTdxQuote, AttestationError> {
485 let bytes = decode_tdx_quote(value)?;
486
487 if bytes.len() < TDX_REPORT_DATA_END {
488 return policy_error(
489 AttestationFailureCode::InvalidTdxEvidence,
490 format!(
491 "Intel TDX quote is too short: got {} bytes, need at least {TDX_REPORT_DATA_END}",
492 bytes.len()
493 ),
494 );
495 }
496
497 let tee_type = u32::from_le_bytes(
498 bytes[TDX_QUOTE_TEE_TYPE_OFFSET..TDX_QUOTE_TEE_TYPE_END]
499 .try_into()
500 .expect("TDX tee_type slice length is fixed"),
501 );
502 let td_attributes = u64::from_le_bytes(
503 bytes[TDX_REPORT_TD_ATTRIBUTES_OFFSET..TDX_REPORT_TD_ATTRIBUTES_END]
504 .try_into()
505 .expect("TDX attributes slice length is fixed"),
506 );
507 let debug = td_attributes & 1 == 1;
508
509 Ok(ParsedTdxQuote { tee_type, debug })
510}
511
512fn decode_tdx_quote(value: &str) -> Result<Vec<u8>, AttestationError> {
514 let value = value.trim();
515 let hex = value.strip_prefix("0x").unwrap_or(value);
516 if !hex.is_empty()
518 && let Ok(bytes) = hex::decode(hex)
519 {
520 return Ok(bytes);
521 }
522
523 general_purpose::STANDARD
524 .decode(value)
525 .map_err(|source| AttestationError::PolicyViolation {
526 code: AttestationFailureCode::InvalidTdxEvidence,
527 message: format!("intel_quote is neither hex nor valid base64: {source}"),
528 })
529}
530
531fn verify_reportdata_binding(
533 reportdata_hex: &str,
534 signing_key: &str,
535 signing_address: Option<&str>,
536) -> Result<(), AttestationError> {
537 let reportdata =
538 hex::decode(reportdata_hex).map_err(|error| AttestationError::PolicyViolation {
539 code: AttestationFailureCode::InvalidTdxEvidence,
540 message: format!("tdx_reportdata is not valid hex: {error}"),
541 })?;
542 if reportdata.len() != TDX_REPORT_DATA_LEN {
543 return policy_error(
544 AttestationFailureCode::InvalidTdxEvidence,
545 format!(
546 "tdx_reportdata has {} bytes, expected {TDX_REPORT_DATA_LEN}",
547 reportdata.len()
548 ),
549 );
550 }
551
552 let signing_key_bytes =
553 hex::decode(signing_key).map_err(|error| AttestationError::PolicyViolation {
554 code: AttestationFailureCode::InvalidSigningKey,
555 message: format!("normalized signing key is not valid hex: {error}"),
556 })?;
557 let signing_key_hash = Sha256::digest(&signing_key_bytes);
558 if reportdata.starts_with(&signing_key_hash[..]) {
559 return Ok(());
560 }
561
562 if let Some(signing_address) = signing_address {
563 let signing_address_hash = Sha256::digest(signing_address.as_bytes());
564 if reportdata.starts_with(&signing_address_hash[..]) {
565 return Ok(());
566 }
567 }
568
569 policy_error(
570 AttestationFailureCode::InvalidTdxEvidence,
571 "TDX REPORTDATA does not bind the attested signing key or signing address",
572 )
573}
574
575fn evidence_object(response: &Value) -> Result<&serde_json::Map<String, Value>, AttestationError> {
577 if let Value::Object(root) = response {
578 if let Some(Value::Object(attestation)) = root.get("attestation") {
579 return Ok(attestation);
580 }
581 return Ok(root);
582 }
583
584 Err(AttestationError::MalformedResponse {
585 message: format!(
586 "expected attestation response object, got {}",
587 json_kind(response)
588 ),
589 })
590}
591
592fn required_bool(
594 object: &serde_json::Map<String, Value>,
595 field: &'static str,
596) -> Result<bool, AttestationError> {
597 match object.get(field) {
598 Some(Value::Bool(value)) => Ok(*value),
599 Some(other) => Err(AttestationError::MalformedResponse {
600 message: format!("field {field} must be a boolean, got {}", json_kind(other)),
601 }),
602 None => Err(AttestationError::MissingField { field }),
603 }
604}
605
606fn required_string<'a>(
608 object: &'a serde_json::Map<String, Value>,
609 field: &'static str,
610) -> Result<&'a str, AttestationError> {
611 match object.get(field) {
612 Some(Value::String(value)) if !value.trim().is_empty() => Ok(value),
613 Some(Value::String(_)) => Err(AttestationError::MalformedResponse {
614 message: format!("field {field} must not be empty"),
615 }),
616 Some(other) => Err(AttestationError::MalformedResponse {
617 message: format!("field {field} must be a string, got {}", json_kind(other)),
618 }),
619 None => Err(AttestationError::MissingField { field }),
620 }
621}
622
623fn optional_non_empty_string<'a>(
625 object: &'a serde_json::Map<String, Value>,
626 field: &'static str,
627) -> Option<&'a str> {
628 match object.get(field) {
629 Some(Value::String(value)) if !value.trim().is_empty() => Some(value.as_str()),
630 _ => None,
631 }
632}
633
634fn top_level_debug(object: &serde_json::Map<String, Value>) -> Option<bool> {
636 object
637 .get("debug")
638 .or_else(|| object.get("tdx_debug"))
639 .and_then(Value::as_bool)
640}
641
642fn normalize_public_key_hex(value: &str) -> Result<String, AttestationError> {
644 let value = value.trim().strip_prefix("0x").unwrap_or(value.trim());
645 let mut bytes = hex::decode(value).map_err(|error| AttestationError::PolicyViolation {
646 code: AttestationFailureCode::InvalidSigningKey,
647 message: error.to_string(),
648 })?;
649
650 if bytes.len() == 64 {
651 let mut uncompressed = Vec::with_capacity(65);
652 uncompressed.push(0x04);
653 uncompressed.extend_from_slice(&bytes);
654 bytes = uncompressed;
655 }
656
657 if !matches!(bytes.len(), 33 | 65) {
658 return policy_error(
659 AttestationFailureCode::InvalidSigningKey,
660 format!(
661 "signing key must be 33-byte compressed, 64-byte x/y, or 65-byte uncompressed SEC1 public key; got {} bytes",
662 bytes.len()
663 ),
664 );
665 }
666
667 let public_key =
668 PublicKey::from_sec1_bytes(&bytes).map_err(|_| AttestationError::PolicyViolation {
669 code: AttestationFailureCode::InvalidSigningKey,
670 message: "signing key is not a valid secp256k1 public key".to_owned(),
671 })?;
672 Ok(hex::encode(public_key.to_encoded_point(false).as_bytes()))
673}
674
675fn ethereum_address_from_uncompressed_key_hex(value: &str) -> Result<String, AttestationError> {
677 let bytes = hex::decode(value).map_err(|error| AttestationError::PolicyViolation {
678 code: AttestationFailureCode::InvalidSigningKey,
679 message: error.to_string(),
680 })?;
681 if bytes.len() != 65 || bytes.first() != Some(&0x04) {
682 return policy_error(
683 AttestationFailureCode::InvalidSigningKey,
684 "normalized signing key is not an uncompressed 65-byte SEC1 key",
685 );
686 }
687
688 let hash = Keccak256::digest(&bytes[1..]);
689 Ok(format!("0x{}", hex::encode(&hash[12..])))
690}
691
692fn normalize_ethereum_address(value: &str) -> Result<String, AttestationError> {
694 let value = value.trim();
695 let stripped = value.strip_prefix("0x").unwrap_or(value);
696 if stripped.len() != 40 || stripped.chars().any(|ch| !ch.is_ascii_hexdigit()) {
697 return policy_error(
698 AttestationFailureCode::SigningAddressMismatch,
699 "signing_address must be a 20-byte Ethereum address encoded as hex",
700 );
701 }
702 Ok(format!("0x{}", stripped.to_ascii_lowercase()))
703}
704
705fn validate_nonce_hex(value: &str) -> Result<(), AttestationError> {
707 if value.len() != ATTESTATION_NONCE_HEX_CHARS {
708 return Err(AttestationError::InvalidRequest {
709 message: format!(
710 "attestation nonce must be {ATTESTATION_NONCE_HEX_CHARS} hex characters"
711 ),
712 });
713 }
714 if value.chars().any(|ch| !ch.is_ascii_hexdigit()) {
715 return Err(AttestationError::InvalidRequest {
716 message: "attestation nonce must contain only hex characters".to_owned(),
717 });
718 }
719 Ok(())
720}
721
722fn policy_error<T>(
724 code: AttestationFailureCode,
725 message: impl Into<String>,
726) -> Result<T, AttestationError> {
727 Err(AttestationError::PolicyViolation {
728 code,
729 message: message.into(),
730 })
731}
732
733#[cfg(test)]
734mod tests {
735 use super::*;
736 use std::{collections::HashMap, net::SocketAddr, time::Duration};
737
738 use axum::{
739 Router,
740 body::Body,
741 extract::Query,
742 http::{Response, StatusCode},
743 response::IntoResponse,
744 routing::get,
745 };
746 use k256::SecretKey;
747 use serde_json::json;
748 use tokio::net::TcpListener;
749
750 const MODEL: &str = "e2ee-qwen3-5-122b-a10b";
751 const NONCE: &str = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
752
753 fn policy_for_basic_success() -> AttestationConfig {
754 AttestationConfig {
755 require_tdx: false,
756 require_nvidia: NvidiaRequirement::WhenPresent,
757 ..AttestationConfig::default()
758 }
759 }
760
761 fn verifier(policy: AttestationConfig) -> AttestationVerifier {
762 AttestationVerifier::new(policy, test_venice_client("http://127.0.0.1:1/api/v1"))
763 }
764
765 fn test_venice_client(base_url: &str) -> VeniceClient {
766 VeniceClient::new(base_url, "test-api-key", Duration::from_secs(1))
767 .expect("test Venice client should build")
768 }
769
770 fn key_material() -> (String, String) {
771 let secret_key = SecretKey::from_slice(&[7_u8; 32]).expect("fixed secret key is valid");
772 let public_key = secret_key.public_key();
773 let public_key_hex = hex::encode(public_key.to_encoded_point(false).as_bytes());
774 let address = ethereum_address_from_uncompressed_key_hex(&public_key_hex)
775 .expect("test public key should derive address");
776 (public_key_hex, address)
777 }
778
779 fn valid_evidence() -> Value {
780 let (signing_key, signing_address) = key_material();
781 json!({
782 "verified": true,
783 "nonce": NONCE,
784 "model": MODEL,
785 "tee_provider": "tdx",
786 "signing_key": signing_key,
787 "signing_address": signing_address
788 })
789 }
790
791 #[test]
792 fn generated_nonce_is_32_bytes_lower_hex() {
793 let nonce = AttestationNonce::generate();
794
795 assert_eq!(nonce.as_str().len(), 64);
796 assert!(nonce.as_str().chars().all(|ch| ch.is_ascii_hexdigit()));
797 assert!(!nonce.as_str().chars().any(|ch| ch.is_ascii_uppercase()));
798 }
799
800 #[test]
801 fn valid_basic_evidence_passes_without_optional_hardware_requirements() {
802 let result = verifier(policy_for_basic_success())
803 .verify_evidence(MODEL, NONCE, valid_evidence())
804 .expect("valid basic attestation should pass");
805
806 let (expected_key, expected_address) = key_material();
807 assert_eq!(result.model_id, MODEL);
808 assert_eq!(result.model_public_key, expected_key);
809 assert_eq!(
810 result.signing_address.as_deref(),
811 Some(expected_address.as_str())
812 );
813 assert_eq!(result.tee_provider.as_deref(), Some("tdx"));
814 assert!(!result.tdx.present);
815 assert_eq!(result.nvidia.verified, NvidiaVerificationStatus::NotPresent);
816 }
817
818 #[test]
819 fn missing_required_fields_fail_closed() {
820 let mut evidence = valid_evidence();
821 evidence.as_object_mut().unwrap().remove("verified");
822
823 let error = verifier(policy_for_basic_success())
824 .verify_evidence(MODEL, NONCE, evidence)
825 .expect_err("missing verified field must fail");
826
827 assert!(matches!(
828 error,
829 AttestationError::MissingField { field: "verified" }
830 ));
831 assert_eq!(error.api_error_code(), "attestation_missing_required_field");
832 }
833
834 #[test]
835 fn debug_evidence_fails_when_debug_is_not_allowed() {
836 let mut evidence = valid_evidence();
837 evidence
838 .as_object_mut()
839 .unwrap()
840 .insert("debug".to_owned(), json!(true));
841
842 let error = verifier(policy_for_basic_success())
843 .verify_evidence(MODEL, NONCE, evidence)
844 .expect_err("debug attestation must fail");
845
846 assert!(matches!(
847 error,
848 AttestationError::PolicyViolation {
849 code: AttestationFailureCode::DebugModeDetected,
850 ..
851 }
852 ));
853 }
854
855 #[test]
856 fn tdx_required_mode_fails_on_missing_tdx_evidence() {
857 let error = verifier(AttestationConfig {
858 require_tdx: true,
859 require_nvidia: NvidiaRequirement::Never,
860 ..AttestationConfig::default()
861 })
862 .verify_evidence(MODEL, NONCE, valid_evidence())
863 .expect_err("missing required TDX evidence must fail");
864
865 assert!(matches!(
866 error,
867 AttestationError::PolicyViolation {
868 code: AttestationFailureCode::MissingTdxEvidence,
869 ..
870 }
871 ));
872 }
873
874 #[test]
875 fn tdx_required_mode_fails_on_invalid_tdx_evidence() {
876 let mut evidence = valid_evidence();
877 evidence
878 .as_object_mut()
879 .unwrap()
880 .insert("intel_quote".to_owned(), json!("not quote encoding"));
881
882 let error = verifier(AttestationConfig {
883 require_tdx: true,
884 require_nvidia: NvidiaRequirement::Never,
885 ..AttestationConfig::default()
886 })
887 .verify_evidence(MODEL, NONCE, evidence)
888 .expect_err("invalid TDX evidence must fail");
889
890 assert!(matches!(
891 error,
892 AttestationError::PolicyViolation {
893 code: AttestationFailureCode::InvalidTdxEvidence,
894 ..
895 }
896 ));
897 }
898
899 #[test]
900 fn tdx_debug_quote_fails_when_debug_is_not_allowed() {
901 let mut evidence = valid_evidence();
902 evidence.as_object_mut().unwrap().insert(
903 "intel_quote".to_owned(),
904 json!(tdx_quote_hex(true, TDX_TEE_TYPE)),
905 );
906
907 let error = verifier(AttestationConfig {
908 require_tdx: false,
909 require_nvidia: NvidiaRequirement::Never,
910 allow_debug: false,
911 ..AttestationConfig::default()
912 })
913 .verify_evidence(MODEL, NONCE, evidence)
914 .expect_err("debug quote must fail");
915
916 assert!(matches!(
917 error,
918 AttestationError::PolicyViolation {
919 code: AttestationFailureCode::DebugModeDetected,
920 ..
921 }
922 ));
923 }
924
925 #[test]
926 fn tdx_optional_mode_accepts_legacy_base64_quote_encoding() {
927 let mut evidence = valid_evidence();
928 evidence.as_object_mut().unwrap().insert(
929 "intel_quote".to_owned(),
930 json!(tdx_quote_base64(false, TDX_TEE_TYPE)),
931 );
932
933 let result = verifier(AttestationConfig {
934 require_tdx: false,
935 require_nvidia: NvidiaRequirement::Never,
936 ..AttestationConfig::default()
937 })
938 .verify_evidence(MODEL, NONCE, evidence)
939 .expect("legacy base64-encoded TDX quote should parse when TDX is optional");
940
941 assert!(result.tdx.present);
942 assert_eq!(result.tdx.tee_type, Some(TDX_TEE_TYPE));
943 }
944
945 #[test]
946 fn tdx_required_mode_fails_closed_when_dcap_verifier_is_unavailable() {
947 let mut evidence = valid_evidence();
948 evidence.as_object_mut().unwrap().insert(
949 "intel_quote".to_owned(),
950 json!(tdx_quote_hex(false, TDX_TEE_TYPE)),
951 );
952
953 let error = verifier(AttestationConfig {
954 require_tdx: true,
955 require_nvidia: NvidiaRequirement::Never,
956 ..AttestationConfig::default()
957 })
958 .verify_evidence(MODEL, NONCE, evidence)
959 .expect_err("strict TDX should fail without DCAP verifier");
960
961 assert!(matches!(
962 error,
963 AttestationError::ExternalVerifierUnavailable {
964 verifier: "tdx-dcap-qvl",
965 ..
966 }
967 ));
968 assert_eq!(error.api_error_code(), "attestation_verifier_unavailable");
969 }
970
971 #[test]
972 fn nvidia_required_mode_fails_on_missing_nvidia_evidence() {
973 let error = verifier(AttestationConfig {
974 require_tdx: false,
975 require_nvidia: NvidiaRequirement::Required,
976 ..AttestationConfig::default()
977 })
978 .verify_evidence(MODEL, NONCE, valid_evidence())
979 .expect_err("missing required NVIDIA evidence must fail");
980
981 assert!(matches!(
982 error,
983 AttestationError::PolicyViolation {
984 code: AttestationFailureCode::MissingNvidiaEvidence,
985 ..
986 }
987 ));
988 }
989
990 #[test]
991 fn nvidia_required_mode_fails_on_invalid_nvidia_evidence() {
992 let mut evidence = valid_evidence();
993 evidence
994 .as_object_mut()
995 .unwrap()
996 .insert("nvidia_payload".to_owned(), json!(42));
997
998 let error = verifier(AttestationConfig {
999 require_tdx: false,
1000 require_nvidia: NvidiaRequirement::Required,
1001 ..AttestationConfig::default()
1002 })
1003 .verify_evidence(MODEL, NONCE, evidence)
1004 .expect_err("invalid NVIDIA evidence must fail");
1005
1006 assert!(matches!(
1007 error,
1008 AttestationError::PolicyViolation {
1009 code: AttestationFailureCode::InvalidNvidiaEvidence,
1010 ..
1011 }
1012 ));
1013 }
1014
1015 #[test]
1016 fn nvidia_payload_when_present_fails_closed_without_nras_verifier() {
1017 let mut evidence = valid_evidence();
1018 evidence
1019 .as_object_mut()
1020 .unwrap()
1021 .insert("nvidia_payload".to_owned(), json!({ "nonce": NONCE }));
1022
1023 let error = verifier(policy_for_basic_success())
1024 .verify_evidence(MODEL, NONCE, evidence)
1025 .expect_err("present NVIDIA evidence must be verified");
1026
1027 assert!(matches!(
1028 error,
1029 AttestationError::ExternalVerifierUnavailable {
1030 verifier: "nvidia-nras",
1031 ..
1032 }
1033 ));
1034 }
1035
1036 #[test]
1037 fn nonce_mismatch_fails_closed_as_stale_or_replayed_evidence() {
1038 let mut evidence = valid_evidence();
1039 evidence.as_object_mut().unwrap().insert(
1040 "nonce".to_owned(),
1041 json!("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"),
1042 );
1043
1044 let error = verifier(policy_for_basic_success())
1045 .verify_evidence(MODEL, NONCE, evidence)
1046 .expect_err("nonce mismatch must fail");
1047
1048 assert!(matches!(
1049 error,
1050 AttestationError::PolicyViolation {
1051 code: AttestationFailureCode::NonceMismatch,
1052 ..
1053 }
1054 ));
1055 }
1056
1057 #[test]
1058 fn signing_address_mismatch_fails_closed() {
1059 let mut evidence = valid_evidence();
1060 evidence.as_object_mut().unwrap().insert(
1061 "signing_address".to_owned(),
1062 json!("0x0000000000000000000000000000000000000000"),
1063 );
1064
1065 let error = verifier(policy_for_basic_success())
1066 .verify_evidence(MODEL, NONCE, evidence)
1067 .expect_err("address mismatch must fail");
1068
1069 assert!(matches!(
1070 error,
1071 AttestationError::PolicyViolation {
1072 code: AttestationFailureCode::SigningAddressMismatch,
1073 ..
1074 }
1075 ));
1076 }
1077
1078 #[test]
1079 fn malformed_upstream_response_shape_fails_closed() {
1080 let error = verifier(policy_for_basic_success())
1081 .verify_evidence(MODEL, NONCE, json!([]))
1082 .expect_err("array response must fail");
1083
1084 assert!(matches!(error, AttestationError::MalformedResponse { .. }));
1085 }
1086
1087 #[tokio::test]
1088 async fn fetches_attestation_with_model_and_nonce_then_verifies() {
1089 let base_url = spawn_attestation_server(|query| {
1090 assert_eq!(query.get("model").map(String::as_str), Some(MODEL));
1091 let nonce = query
1092 .get("nonce")
1093 .expect("nonce query parameter should be present");
1094 assert_eq!(nonce.len(), 64);
1095 assert!(nonce.chars().all(|ch| ch.is_ascii_hexdigit()));
1096
1097 let (signing_key, signing_address) = key_material();
1098 (
1099 StatusCode::OK,
1100 serde_json::to_vec(&json!({
1101 "verified": true,
1102 "nonce": nonce,
1103 "model": MODEL,
1104 "signing_key": signing_key,
1105 "signing_address": signing_address
1106 }))
1107 .expect("response should serialize"),
1108 )
1109 })
1110 .await;
1111 let verifier =
1112 AttestationVerifier::new(policy_for_basic_success(), test_venice_client(&base_url));
1113
1114 let result = verifier
1115 .verify_model_attestation(MODEL)
1116 .await
1117 .expect("mock attestation should verify");
1118
1119 assert_eq!(result.model_id, MODEL);
1120 assert_eq!(result.model_public_key, key_material().0);
1121 }
1122
1123 #[tokio::test]
1124 async fn malformed_upstream_json_fails_closed() {
1125 let base_url = spawn_raw_attestation_server(StatusCode::OK, b"{".to_vec()).await;
1126 let verifier =
1127 AttestationVerifier::new(policy_for_basic_success(), test_venice_client(&base_url));
1128
1129 let error = verifier
1130 .verify_model_attestation(MODEL)
1131 .await
1132 .expect_err("malformed upstream JSON must fail");
1133
1134 assert!(matches!(
1135 error,
1136 AttestationError::Fetch(VeniceClientError::MalformedAttestationPayload { .. })
1137 ));
1138 assert_eq!(error.api_error_code(), "attestation_fetch_failed");
1139 }
1140
1141 #[tokio::test]
1142 async fn upstream_fetch_errors_fail_closed() {
1143 let verifier = AttestationVerifier::new(
1144 policy_for_basic_success(),
1145 test_venice_client("http://127.0.0.1:1/api/v1"),
1146 );
1147
1148 let error = verifier
1149 .verify_model_attestation(MODEL)
1150 .await
1151 .expect_err("connection failure must fail closed");
1152
1153 assert!(matches!(error, AttestationError::Fetch(_)));
1154 assert_eq!(error.api_error_code(), "attestation_fetch_failed");
1155 }
1156
1157 fn tdx_quote_hex(debug: bool, tee_type: u32) -> String {
1158 hex::encode(tdx_quote_bytes(debug, tee_type))
1159 }
1160
1161 fn tdx_quote_base64(debug: bool, tee_type: u32) -> String {
1162 general_purpose::STANDARD.encode(tdx_quote_bytes(debug, tee_type))
1163 }
1164
1165 fn tdx_quote_bytes(debug: bool, tee_type: u32) -> Vec<u8> {
1166 let mut bytes = vec![0_u8; TDX_REPORT_DATA_END];
1167 bytes[TDX_QUOTE_TEE_TYPE_OFFSET..TDX_QUOTE_TEE_TYPE_END]
1168 .copy_from_slice(&tee_type.to_le_bytes());
1169 let td_attributes = if debug { 1_u64 } else { 0_u64 };
1170 bytes[TDX_REPORT_TD_ATTRIBUTES_OFFSET..TDX_REPORT_TD_ATTRIBUTES_END]
1171 .copy_from_slice(&td_attributes.to_le_bytes());
1172 bytes
1173 }
1174
1175 async fn spawn_attestation_server<F>(handler: F) -> String
1176 where
1177 F: Fn(HashMap<String, String>) -> (StatusCode, Vec<u8>) + Clone + Send + Sync + 'static,
1178 {
1179 async fn route<F>(
1180 Query(query): Query<HashMap<String, String>>,
1181 handler: F,
1182 ) -> Response<Body>
1183 where
1184 F: Fn(HashMap<String, String>) -> (StatusCode, Vec<u8>) + Clone + Send + Sync + 'static,
1185 {
1186 let (status, body) = handler(query);
1187 (status, body).into_response()
1188 }
1189
1190 let app = Router::new().route(
1191 "/api/v1/tee/attestation",
1192 get({
1193 let handler = handler.clone();
1194 move |query| route(query, handler.clone())
1195 }),
1196 );
1197 spawn_router(app).await
1198 }
1199
1200 async fn spawn_raw_attestation_server(status: StatusCode, body: Vec<u8>) -> String {
1201 let app = Router::new().route(
1202 "/api/v1/tee/attestation",
1203 get(move || async move { (status, body.clone()) }),
1204 );
1205 spawn_router(app).await
1206 }
1207
1208 async fn spawn_router(app: Router) -> String {
1209 let listener = TcpListener::bind(("127.0.0.1", 0))
1210 .await
1211 .expect("test listener should bind");
1212 let addr: SocketAddr = listener.local_addr().expect("listener should have address");
1213 tokio::spawn(async move {
1214 axum::serve(listener, app)
1215 .await
1216 .expect("test server should run");
1217 });
1218 format!("http://{addr}/api/v1")
1219 }
1220}