1use prost::Message;
25use std::sync::atomic::{AtomicU64, Ordering};
26use std::sync::{Arc, OnceLock};
27use tokio::sync::{mpsc, Mutex};
28use tonic::Streaming;
29
30use crate::crypto::{
31 aes_cbc_decrypt, aes_cbc_encrypt, random_iv, to_base64, StaticKeyExchange, AES_CBC_IV_LENGTH,
32 AES_KEY_LENGTH,
33};
34use crate::error::{Result, VoltError};
35use crate::proto::volt::{
36 invoke_request, invoke_response, method_invoke, method_payload, remote_request,
37 remote_response, InvokeRequest, InvokeRequestKeyExchange, InvokeResponse, MethodEnd,
38 MethodInvoke, MethodPayload, MethodType, RemoteRequest, RemoteResponse,
39};
40
/// Reports whether verbose relay tracing is switched on.
///
/// `TDX_RELAY_DEBUG` is read once per process and the answer is cached. Any value
/// other than the empty string or `"0"` enables tracing; an unset (or non-UTF-8)
/// variable disables it.
fn relay_debug_enabled() -> bool {
    static RELAY_DEBUG: OnceLock<bool> = OnceLock::new();

    let cached = RELAY_DEBUG.get_or_init(|| {
        std::env::var("TDX_RELAY_DEBUG")
            .map(|flag| !flag.is_empty() && flag != "0")
            .unwrap_or(false)
    });
    *cached
}
48
49fn log_remote_response(
50 stage: &str,
51 invoke_id: u64,
52 remote_response: &RemoteResponse,
53 plaintext_b64: &str,
54 plaintext_len: usize,
55) {
56 if !relay_debug_enabled() {
57 return;
58 }
59 match serde_json::to_string(remote_response) {
60 Ok(json) => {
61 tracing::info!(
62 target = "relay_debug",
63 "remote_response[{}]: invoke_id={} plaintext_len={} plaintext_b64={} json={}",
64 stage,
65 invoke_id,
66 plaintext_len,
67 plaintext_b64,
68 json
69 );
70 }
71 Err(err) => {
72 tracing::warn!(
73 target = "relay_debug",
74 "remote_response[{}]: invoke_id={} failed_to_serialize={} plaintext_len={} plaintext_b64={}",
75 stage,
76 invoke_id,
77 err,
78 plaintext_len,
79 plaintext_b64
80 );
81 }
82 }
83}
84
85fn log_invoke_request(stage: &str, request: &InvokeRequest, plaintext_b64: Option<&str>) {
86 if !relay_debug_enabled() {
87 return;
88 }
89
90 let wire_b64 = to_base64(&request.encode_to_vec());
91
92 let iv_b64 = if request.iv.is_empty() {
93 String::new()
94 } else {
95 to_base64(&request.iv)
96 };
97
98 let (payload_type, payload_len, payload_b64) = match &request.request_payload {
99 Some(invoke_request::RequestPayload::Payload(bytes)) => {
100 ("payload", bytes.len(), Some(to_base64(bytes)))
101 }
102 Some(invoke_request::RequestPayload::JsonPayload(bytes)) => {
103 ("json_payload", bytes.len(), Some(to_base64(bytes)))
104 }
105 None => ("none", 0, None),
106 };
107
108 tracing::info!(
109 target = "relay_debug",
110 "invoke_request[{}]: id={} token={} target_did={:?} iv_b64={} client_end={} hop_index={} target_service_id={} payload_type={} payload_len={} payload_b64={} plaintext_b64={} wire_b64={}",
111 stage,
112 request.invoke_id,
113 request.token,
114 request.target_did,
115 iv_b64,
116 request.client_end,
117 request.hop_index,
118 request.target_service_id,
119 payload_type,
120 payload_len,
121 payload_b64.as_deref().unwrap_or(""),
122 plaintext_b64.unwrap_or(""),
123 wire_b64,
124 );
125}
126
/// Lifecycle of a `RelayContext` session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RelayState {
    /// Initial state: no key exchange has been started.
    Disconnected,
    /// A key-exchange request has been built; the server's reply is awaited.
    KeyExchangePending,
    /// A session key is established and payloads can be encrypted/decrypted.
    Connected,
    /// Terminal state. NOTE(review): not assigned anywhere in this file — confirm
    /// which component sets it.
    Closed,
}
139
140pub struct RelayContext {
147 state: RelayState,
149 key_exchange: StaticKeyExchange,
151 encryption_key: Option<[u8; AES_KEY_LENGTH]>,
153 invoke_id_counter: AtomicU64,
155 target_did: String,
157 target_public_key: Option<Vec<u8>>,
159}
160
161impl std::fmt::Debug for RelayContext {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("RelayContext")
164 .field("state", &self.state)
165 .field("target_did", &self.target_did)
166 .field("has_encryption_key", &self.encryption_key.is_some())
167 .finish()
168 }
169}
170
171impl RelayContext {
172 pub fn new(target_did: String) -> Self {
174 Self {
175 state: RelayState::Disconnected,
176 key_exchange: StaticKeyExchange::new(),
177 encryption_key: None,
178 invoke_id_counter: AtomicU64::new(1),
179 target_did,
180 target_public_key: None,
181 }
182 }
183
184 pub fn set_target_public_key(&mut self, public_key: Vec<u8>) {
186 self.target_public_key = Some(public_key);
187 }
188
189 pub fn next_invoke_id(&self) -> u64 {
191 self.invoke_id_counter.fetch_add(1, Ordering::SeqCst)
192 }
193
194 pub fn state(&self) -> &RelayState {
196 &self.state
197 }
198
199 pub fn public_key_bytes(&self) -> [u8; 32] {
201 self.key_exchange.public_key_bytes()
202 }
203
204 pub fn public_key_base64(&self) -> String {
206 const DER_PREFIX: [u8; 12] = [
207 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00, ];
212
213 let mut der = Vec::with_capacity(DER_PREFIX.len() + 32);
214 der.extend_from_slice(&DER_PREFIX);
215 der.extend_from_slice(&self.key_exchange.public_key_bytes());
216 to_base64(&der)
217 }
218
219 pub fn create_key_exchange_request(&mut self, token: &str) -> InvokeRequest {
223 self.state = RelayState::KeyExchangePending;
224
225 let request = InvokeRequest {
226 invoke_id: self.next_invoke_id(),
227 token: token.to_string(),
228 target_did: vec![self.target_did.clone()],
229 iv: Vec::new(), client_end: false,
231 hop_index: 0,
232 target_service_id: String::new(),
233 request_payload: None, };
235
236 log_invoke_request("key_exchange", &request, None);
237 request
238 }
239
240 pub fn process_key_exchange_response(
244 &mut self,
245 key_exchange: &InvokeRequestKeyExchange,
246 ) -> Result<()> {
247 let server_public_key: [u8; 32] = key_exchange
249 .encryption_key
250 .as_slice()
251 .try_into()
252 .map_err(|_| VoltError::crypto("Invalid encryption key length"))?;
253
254 if let Some(target_pk) = &self.target_public_key {
256 let mut message =
258 Vec::with_capacity(key_exchange.encryption_key.len() + key_exchange.nonce.len());
259 message.extend_from_slice(&key_exchange.encryption_key);
260 message.extend_from_slice(&key_exchange.nonce);
261
262 let target_pk_array: [u8; 32] = target_pk
264 .as_slice()
265 .try_into()
266 .map_err(|_| VoltError::crypto("Invalid target public key length"))?;
267
268 let verifying_key = ed25519_dalek::VerifyingKey::from_bytes(&target_pk_array)
269 .map_err(|e| VoltError::crypto(format!("Invalid public key: {}", e)))?;
270
271 let signature_array: [u8; 64] = key_exchange
272 .signature
273 .as_slice()
274 .try_into()
275 .map_err(|_| VoltError::crypto("Invalid signature length"))?;
276
277 let signature = ed25519_dalek::Signature::from_bytes(&signature_array);
278
279 verifying_key
280 .verify_strict(&message, &signature)
281 .map_err(|e| VoltError::crypto(format!("Signature verification failed: {}", e)))?;
282
283 tracing::debug!("Key exchange signature verified successfully");
284 } else {
285 tracing::warn!("No target public key set, skipping signature verification");
286 }
287
288 let encryption_key = self
290 .key_exchange
291 .derive_relay_encryption_key(&server_public_key)?;
292 self.encryption_key = Some(encryption_key);
293 self.state = RelayState::Connected;
294
295 tracing::info!("Relay key exchange completed successfully");
296 Ok(())
297 }
298
299 pub fn encrypt(&self, plaintext: &[u8]) -> Result<(Vec<u8>, [u8; AES_CBC_IV_LENGTH])> {
301 let key = self
302 .encryption_key
303 .ok_or_else(|| VoltError::crypto("No encryption key established"))?;
304
305 let iv = random_iv();
306 let ciphertext = aes_cbc_encrypt(&key, &iv, plaintext)?;
307
308 Ok((ciphertext, iv))
309 }
310
311 pub fn decrypt(&self, ciphertext: &[u8], iv: &[u8]) -> Result<Vec<u8>> {
313 let key = self
314 .encryption_key
315 .ok_or_else(|| VoltError::crypto("No encryption key established"))?;
316
317 let iv_array: [u8; AES_CBC_IV_LENGTH] = iv
318 .try_into()
319 .map_err(|_| VoltError::crypto("Invalid IV length"))?;
320
321 aes_cbc_decrypt(&key, &iv_array, ciphertext)
322 }
323
324 pub fn create_encrypted_request(
329 &self,
330 invoke_id: u64,
331 remote_response: &RemoteResponse,
332 ) -> Result<InvokeRequest> {
333 let mut buf = Vec::new();
335 remote_response.encode(&mut buf).map_err(|e| {
336 VoltError::serialization(format!("Failed to encode RemoteResponse: {}", e))
337 })?;
338
339 let plaintext_b64 = if relay_debug_enabled() {
340 let b64 = to_base64(&buf);
341 log_remote_response("request", invoke_id, remote_response, &b64, buf.len());
342 Some(b64)
343 } else {
344 None
345 };
346
347 let (encrypted, iv) = self.encrypt(&buf)?;
349
350 let request = InvokeRequest {
351 invoke_id,
352 token: String::new(), target_did: vec![self.target_did.clone()],
354 iv: iv.to_vec(),
355 client_end: false,
356 hop_index: 0,
357 target_service_id: String::new(),
358 request_payload: Some(invoke_request::RequestPayload::Payload(encrypted)),
359 };
360
361 log_invoke_request("encrypted", &request, plaintext_b64.as_deref());
362
363 Ok(request)
364 }
365
366 pub fn parse_response(&mut self, response: &InvokeResponse) -> Result<Option<RemoteRequest>> {
370 if let Some(ref key_exchange) = response.key_exchange {
372 self.process_key_exchange_response(key_exchange)?;
373 return Ok(None);
374 }
375
376 if let Some(invoke_response::ResponsePayload::Status(ref status)) =
378 response.response_payload
379 {
380 if status.code != 0 {
381 return Err(VoltError::server(status.code, &status.message));
382 }
383 return Ok(None);
384 }
385
386 let encrypted_payload = match &response.response_payload {
388 Some(invoke_response::ResponsePayload::Payload(data)) => data.as_slice(),
389 Some(invoke_response::ResponsePayload::JsonPayload(data)) => data.as_slice(),
390 _ => return Ok(None),
391 };
392
393 let decrypted = self.decrypt(encrypted_payload, &response.iv)?;
395
396 let remote_request = RemoteRequest::decode(decrypted.as_slice()).map_err(|e| {
398 VoltError::serialization(format!("Failed to decode RemoteRequest: {}", e))
399 })?;
400
401 Ok(Some(remote_request))
402 }
403
404 pub fn create_method_invoke(
406 &self,
407 invoke_id: u64,
408 service_id: &str,
409 method_name: &str,
410 method_type: MethodType,
411 request_data: Vec<u8>,
412 ) -> RemoteResponse {
413 let method_invoke = MethodInvoke {
414 id: invoke_id,
415 service_id: service_id.to_string(),
416 method_name: method_name.to_string(),
417 method_type: method_type as i32,
418 invoke_request: Some(method_invoke::InvokeRequest::Request(request_data)),
419 };
420
421 RemoteResponse {
422 payload: Some(remote_response::Payload::MethodInvoke(method_invoke)),
423 }
424 }
425
426 pub fn create_method_payload(&self, invoke_id: u64, payload: Vec<u8>) -> RemoteResponse {
428 let method_payload = MethodPayload {
429 id: invoke_id,
430 method_payload: Some(method_payload::MethodPayload::Payload(payload)),
431 };
432
433 RemoteResponse {
434 payload: Some(remote_response::Payload::MethodPayload(method_payload)),
435 }
436 }
437
438 pub fn create_method_end(&self, invoke_id: u64) -> RemoteResponse {
440 let method_end = MethodEnd {
441 id: invoke_id,
442 ended: true,
443 error: String::new(),
444 error_code: 0,
445 };
446
447 RemoteResponse {
448 payload: Some(remote_response::Payload::MethodEnd(method_end)),
449 }
450 }
451}
452
453pub struct RelayCall {
457 context: Arc<Mutex<RelayContext>>,
459 invoke_id: u64,
461 service_id: String,
463}
464
465impl RelayCall {
466 pub fn new(context: Arc<Mutex<RelayContext>>, service_id: String) -> Self {
468 Self {
470 context,
471 invoke_id: 0,
472 service_id,
473 }
474 }
475
476 pub async fn unary<Req, Resp>(
478 &mut self,
479 method_name: &str,
480 request: Req,
481 invoke_stream_sender: &mpsc::Sender<InvokeRequest>,
482 invoke_stream_receiver: &mut Streaming<InvokeResponse>,
483 ) -> Result<Resp>
484 where
485 Req: Message,
486 Resp: Message + Default,
487 {
488 let ctx = self.context.lock().await;
489
490 self.invoke_id = ctx.next_invoke_id();
492
493 let mut request_bytes = Vec::new();
495 request
496 .encode(&mut request_bytes)
497 .map_err(|e| VoltError::serialization(format!("Failed to encode request: {}", e)))?;
498
499 let remote_response = ctx.create_method_invoke(
501 self.invoke_id,
502 &self.service_id,
503 method_name,
504 MethodType::Unary,
505 request_bytes,
506 );
507
508 let invoke_request = ctx.create_encrypted_request(self.invoke_id, &remote_response)?;
510 drop(ctx); invoke_stream_sender
513 .send(invoke_request)
514 .await
515 .map_err(|e| VoltError::connection(format!("Failed to send request: {}", e)))?;
516
517 let response = invoke_stream_receiver
519 .message()
520 .await
521 .map_err(|e| VoltError::grpc(e.code(), e.message()))?
522 .ok_or_else(|| VoltError::connection("Stream closed unexpectedly"))?;
523
524 let mut ctx = self.context.lock().await;
526 let remote_request = ctx
527 .parse_response(&response)?
528 .ok_or_else(|| VoltError::protocol("Expected payload response"))?;
529
530 match remote_request.payload {
532 Some(remote_request::Payload::MethodPayload(mp)) => {
533 let payload = match mp.method_payload {
534 Some(method_payload::MethodPayload::Payload(p)) => p,
535 Some(method_payload::MethodPayload::JsonPayload(j)) => j.into_bytes(),
536 None => return Err(VoltError::protocol("Empty method payload")),
537 };
538
539 Resp::decode(payload.as_slice()).map_err(|e| {
540 VoltError::serialization(format!("Failed to decode response: {}", e))
541 })
542 }
543 _ => Err(VoltError::protocol("Expected MethodPayload response")),
544 }
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    const TARGET: &str = "did:tdx:test-volt";

    /// Context with a fixed all-zero session key, as if the key exchange had completed.
    fn connected_context() -> RelayContext {
        let mut ctx = RelayContext::new(TARGET.to_string());
        ctx.encryption_key = Some([0u8; AES_KEY_LENGTH]);
        ctx.state = RelayState::Connected;
        ctx
    }

    #[test]
    fn test_relay_context_creation() {
        let ctx = RelayContext::new(TARGET.to_string());
        assert_eq!(ctx.state(), &RelayState::Disconnected);
        assert!(ctx.encryption_key.is_none());
    }

    #[test]
    fn test_invoke_id_counter() {
        let ctx = RelayContext::new(TARGET.to_string());
        assert_eq!(ctx.next_invoke_id(), 1);
        assert_eq!(ctx.next_invoke_id(), 2);
        assert_eq!(ctx.next_invoke_id(), 3);
    }

    #[test]
    fn test_key_exchange_request_creation() {
        let mut ctx = RelayContext::new(TARGET.to_string());
        let request = ctx.create_key_exchange_request("test-jwt-token");

        assert_eq!(ctx.state(), &RelayState::KeyExchangePending);
        assert_eq!(request.token, "test-jwt-token");
        assert_eq!(request.target_did, vec![TARGET]);
        assert!(request.request_payload.is_none());
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let ctx = connected_context();

        let plaintext = b"Hello, relay world!";
        let (ciphertext, iv) = ctx.encrypt(plaintext).unwrap();
        let decrypted = ctx.decrypt(&ciphertext, &iv).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_crypto_without_key_fails() {
        let ctx = RelayContext::new(TARGET.to_string());
        assert!(ctx.encrypt(b"data").is_err());
        assert!(ctx.decrypt(b"data", &[0u8; AES_CBC_IV_LENGTH]).is_err());
    }

    #[test]
    fn test_decrypt_rejects_bad_iv_length() {
        let ctx = connected_context();
        let (ciphertext, iv) = ctx.encrypt(b"payload").unwrap();
        assert!(ctx
            .decrypt(&ciphertext, &iv[..AES_CBC_IV_LENGTH - 1])
            .is_err());
    }

    #[test]
    fn test_encrypted_request_roundtrip() {
        let ctx = connected_context();
        let invoke_id = ctx.next_invoke_id();
        let message = ctx.create_method_payload(invoke_id, b"hello".to_vec());

        let request = ctx.create_encrypted_request(invoke_id, &message).unwrap();
        assert_eq!(request.invoke_id, invoke_id);
        assert_eq!(request.target_did, vec![TARGET]);
        assert!(request.token.is_empty());
        assert_eq!(request.iv.len(), AES_CBC_IV_LENGTH);

        let Some(invoke_request::RequestPayload::Payload(ciphertext)) = &request.request_payload
        else {
            panic!("expected an encrypted Payload");
        };
        let plaintext = ctx.decrypt(ciphertext, &request.iv).unwrap();
        let decoded = RemoteResponse::decode(plaintext.as_slice()).unwrap();
        assert_eq!(decoded, message);
    }
}