Skip to main content

volt_client_grpc/
relay.rs

//! Relay protocol implementation for Volt client.
//!
//! This module implements the relay protocol used when connecting to a Volt
//! through a relay server. The relay protocol uses:
//!
//! 1. **X25519 key exchange** - For establishing a shared secret
//! 2. **HKDF key derivation** - For deriving the AES encryption key
//! 3. **AES-CBC encryption** - For encrypting payloads
//! 4. **VoltAPI.Invoke** - The bidirectional streaming RPC method
//!
//! ## Protocol Flow
//!
//! 1. Client opens an `Invoke` stream with the relay
//! 2. Client sends first `InvokeRequest` with:
//!    - JWT token for authentication
//!    - target_did (the Volt DID we want to reach)
//!    - NO payload (key exchange request)
//! 3. Relay responds with `InvokeResponse` containing:
//!    - key_exchange: { encryption_key, nonce, signature }
//! 4. Client verifies the Ed25519 signature over `encryption_key || nonce`
//!    using the target's public key, if one was configured with
//!    `RelayContext::set_target_public_key` (otherwise verification is
//!    skipped and a warning is logged)
//! 5. Client derives shared AES key from X25519 ECDH + HKDF
//! 6. All subsequent requests/responses are AES-CBC encrypted, each with a
//!    fresh random IV carried in the message's `iv` field

24use prost::Message;
25use std::sync::atomic::{AtomicU64, Ordering};
26use std::sync::{Arc, OnceLock};
27use tokio::sync::{mpsc, Mutex};
28use tonic::Streaming;
29
30use crate::crypto::{
31    aes_cbc_decrypt, aes_cbc_encrypt, random_iv, to_base64, StaticKeyExchange, AES_CBC_IV_LENGTH,
32    AES_KEY_LENGTH,
33};
34use crate::error::{Result, VoltError};
35use crate::proto::volt::{
36    invoke_request, invoke_response, method_invoke, method_payload, remote_request,
37    remote_response, InvokeRequest, InvokeRequestKeyExchange, InvokeResponse, MethodEnd,
38    MethodInvoke, MethodPayload, MethodType, RemoteRequest, RemoteResponse,
39};
40
/// Reports whether relay debug logging was requested via `TDX_RELAY_DEBUG`.
///
/// The environment is consulted only once; the answer is cached for the rest
/// of the process. Any non-empty value other than `"0"` turns debugging on,
/// while an unset (or non-Unicode) variable leaves it off.
fn relay_debug_enabled() -> bool {
    static RELAY_DEBUG: OnceLock<bool> = OnceLock::new();
    *RELAY_DEBUG.get_or_init(|| {
        std::env::var("TDX_RELAY_DEBUG")
            .map(|flag| !flag.is_empty() && flag != "0")
            .unwrap_or(false)
    })
}
48
49fn log_remote_response(
50    stage: &str,
51    invoke_id: u64,
52    remote_response: &RemoteResponse,
53    plaintext_b64: &str,
54    plaintext_len: usize,
55) {
56    if !relay_debug_enabled() {
57        return;
58    }
59    match serde_json::to_string(remote_response) {
60        Ok(json) => {
61            tracing::info!(
62                target = "relay_debug",
63                "remote_response[{}]: invoke_id={} plaintext_len={} plaintext_b64={} json={}",
64                stage,
65                invoke_id,
66                plaintext_len,
67                plaintext_b64,
68                json
69            );
70        }
71        Err(err) => {
72            tracing::warn!(
73                target = "relay_debug",
74                "remote_response[{}]: invoke_id={} failed_to_serialize={} plaintext_len={} plaintext_b64={}",
75                stage,
76                invoke_id,
77                err,
78                plaintext_len,
79                plaintext_b64
80            );
81        }
82    }
83}
84
85fn log_invoke_request(stage: &str, request: &InvokeRequest, plaintext_b64: Option<&str>) {
86    if !relay_debug_enabled() {
87        return;
88    }
89
90    let wire_b64 = to_base64(&request.encode_to_vec());
91
92    let iv_b64 = if request.iv.is_empty() {
93        String::new()
94    } else {
95        to_base64(&request.iv)
96    };
97
98    let (payload_type, payload_len, payload_b64) = match &request.request_payload {
99        Some(invoke_request::RequestPayload::Payload(bytes)) => {
100            ("payload", bytes.len(), Some(to_base64(bytes)))
101        }
102        Some(invoke_request::RequestPayload::JsonPayload(bytes)) => {
103            ("json_payload", bytes.len(), Some(to_base64(bytes)))
104        }
105        None => ("none", 0, None),
106    };
107
108    tracing::info!(
109        target = "relay_debug",
110        "invoke_request[{}]: id={} token={} target_did={:?} iv_b64={} client_end={} hop_index={} target_service_id={} payload_type={} payload_len={} payload_b64={} plaintext_b64={} wire_b64={}",
111        stage,
112        request.invoke_id,
113        request.token,
114        request.target_did,
115        iv_b64,
116        request.client_end,
117        request.hop_index,
118        request.target_service_id,
119        payload_type,
120        payload_len,
121        payload_b64.as_deref().unwrap_or(""),
122        plaintext_b64.unwrap_or(""),
123        wire_b64,
124    );
125}
126
/// State of the relay connection.
///
/// `Eq` is derived alongside `PartialEq` because equality on this field-less
/// enum is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RelayState {
    /// Not connected (the initial state of a new `RelayContext`).
    Disconnected,
    /// Key exchange request created; waiting for the relay's key exchange response.
    KeyExchangePending,
    /// Key exchange completed; ready for encrypted communication.
    Connected,
    /// Connection closed.
    Closed,
}
139
140/// Relay connection context
141///
142/// This manages the state for a single relay connection, including:
143/// - Key exchange state
144/// - Encryption key
145/// - Invoke ID counter
146pub struct RelayContext {
147    /// Current state
148    state: RelayState,
149    /// Our X25519 key exchange
150    key_exchange: StaticKeyExchange,
151    /// Derived AES encryption key (after key exchange)
152    encryption_key: Option<[u8; AES_KEY_LENGTH]>,
153    /// Invoke ID counter
154    invoke_id_counter: AtomicU64,
155    /// Target Volt DID
156    target_did: String,
157    /// Target public key (for signature verification)
158    target_public_key: Option<Vec<u8>>,
159}
160
161impl std::fmt::Debug for RelayContext {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("RelayContext")
164            .field("state", &self.state)
165            .field("target_did", &self.target_did)
166            .field("has_encryption_key", &self.encryption_key.is_some())
167            .finish()
168    }
169}
170
171impl RelayContext {
172    /// Create a new relay context
173    pub fn new(target_did: String) -> Self {
174        Self {
175            state: RelayState::Disconnected,
176            key_exchange: StaticKeyExchange::new(),
177            encryption_key: None,
178            invoke_id_counter: AtomicU64::new(1),
179            target_did,
180            target_public_key: None,
181        }
182    }
183
184    /// Set the target's public key for signature verification
185    pub fn set_target_public_key(&mut self, public_key: Vec<u8>) {
186        self.target_public_key = Some(public_key);
187    }
188
189    /// Get the next invoke ID
190    pub fn next_invoke_id(&self) -> u64 {
191        self.invoke_id_counter.fetch_add(1, Ordering::SeqCst)
192    }
193
194    /// Get current state
195    pub fn state(&self) -> &RelayState {
196        &self.state
197    }
198
199    /// Get our public key bytes
200    pub fn public_key_bytes(&self) -> [u8; 32] {
201        self.key_exchange.public_key_bytes()
202    }
203
204    /// Get our public key as base64 (SubjectPublicKeyInfo DER, matching JS client)
205    pub fn public_key_base64(&self) -> String {
206        const DER_PREFIX: [u8; 12] = [
207            0x30, 0x2a, // SEQUENCE, len 42
208            0x30, 0x05, // SEQUENCE, len 5
209            0x06, 0x03, 0x2b, 0x65, 0x6e, // OID 1.3.101.110 (X25519)
210            0x03, 0x21, 0x00, // BIT STRING, len 33, 0 unused bits
211        ];
212
213        let mut der = Vec::with_capacity(DER_PREFIX.len() + 32);
214        der.extend_from_slice(&DER_PREFIX);
215        der.extend_from_slice(&self.key_exchange.public_key_bytes());
216        to_base64(&der)
217    }
218
219    /// Create the initial key exchange request
220    ///
221    /// This is the first message sent on the Invoke stream.
222    pub fn create_key_exchange_request(&mut self, token: &str) -> InvokeRequest {
223        self.state = RelayState::KeyExchangePending;
224
225        let request = InvokeRequest {
226            invoke_id: self.next_invoke_id(),
227            token: token.to_string(),
228            target_did: vec![self.target_did.clone()],
229            iv: Vec::new(), // No IV for key exchange request
230            client_end: false,
231            hop_index: 0,
232            target_service_id: String::new(),
233            request_payload: None, // No payload for key exchange
234        };
235
236        log_invoke_request("key_exchange", &request, None);
237        request
238    }
239
240    /// Process the key exchange response
241    ///
242    /// This verifies the signature and derives the shared encryption key.
243    pub fn process_key_exchange_response(
244        &mut self,
245        key_exchange: &InvokeRequestKeyExchange,
246    ) -> Result<()> {
247        // Extract the server's public key
248        let server_public_key: [u8; 32] = key_exchange
249            .encryption_key
250            .as_slice()
251            .try_into()
252            .map_err(|_| VoltError::crypto("Invalid encryption key length"))?;
253
254        // Verify the signature if we have the target's public key
255        if let Some(target_pk) = &self.target_public_key {
256            // The signed message is encryption_key || nonce
257            let mut message =
258                Vec::with_capacity(key_exchange.encryption_key.len() + key_exchange.nonce.len());
259            message.extend_from_slice(&key_exchange.encryption_key);
260            message.extend_from_slice(&key_exchange.nonce);
261
262            // Verify using Ed25519
263            let target_pk_array: [u8; 32] = target_pk
264                .as_slice()
265                .try_into()
266                .map_err(|_| VoltError::crypto("Invalid target public key length"))?;
267
268            let verifying_key = ed25519_dalek::VerifyingKey::from_bytes(&target_pk_array)
269                .map_err(|e| VoltError::crypto(format!("Invalid public key: {}", e)))?;
270
271            let signature_array: [u8; 64] = key_exchange
272                .signature
273                .as_slice()
274                .try_into()
275                .map_err(|_| VoltError::crypto("Invalid signature length"))?;
276
277            let signature = ed25519_dalek::Signature::from_bytes(&signature_array);
278
279            verifying_key
280                .verify_strict(&message, &signature)
281                .map_err(|e| VoltError::crypto(format!("Signature verification failed: {}", e)))?;
282
283            tracing::debug!("Key exchange signature verified successfully");
284        } else {
285            tracing::warn!("No target public key set, skipping signature verification");
286        }
287
288        // Derive the shared encryption key
289        let encryption_key = self
290            .key_exchange
291            .derive_relay_encryption_key(&server_public_key)?;
292        self.encryption_key = Some(encryption_key);
293        self.state = RelayState::Connected;
294
295        tracing::info!("Relay key exchange completed successfully");
296        Ok(())
297    }
298
299    /// Encrypt a payload for sending
300    pub fn encrypt(&self, plaintext: &[u8]) -> Result<(Vec<u8>, [u8; AES_CBC_IV_LENGTH])> {
301        let key = self
302            .encryption_key
303            .ok_or_else(|| VoltError::crypto("No encryption key established"))?;
304
305        let iv = random_iv();
306        let ciphertext = aes_cbc_encrypt(&key, &iv, plaintext)?;
307
308        Ok((ciphertext, iv))
309    }
310
311    /// Decrypt a received payload
312    pub fn decrypt(&self, ciphertext: &[u8], iv: &[u8]) -> Result<Vec<u8>> {
313        let key = self
314            .encryption_key
315            .ok_or_else(|| VoltError::crypto("No encryption key established"))?;
316
317        let iv_array: [u8; AES_CBC_IV_LENGTH] = iv
318            .try_into()
319            .map_err(|_| VoltError::crypto("Invalid IV length"))?;
320
321        aes_cbc_decrypt(&key, &iv_array, ciphertext)
322    }
323
324    /// Create an encrypted InvokeRequest wrapping a RemoteResponse
325    ///
326    /// Note: In the relay protocol, the client sends "responses" to the relay
327    /// which forwards them to the target Volt (which treats them as requests).
328    pub fn create_encrypted_request(
329        &self,
330        invoke_id: u64,
331        remote_response: &RemoteResponse,
332    ) -> Result<InvokeRequest> {
333        // Serialize the RemoteResponse
334        let mut buf = Vec::new();
335        remote_response.encode(&mut buf).map_err(|e| {
336            VoltError::serialization(format!("Failed to encode RemoteResponse: {}", e))
337        })?;
338
339        let plaintext_b64 = if relay_debug_enabled() {
340            let b64 = to_base64(&buf);
341            log_remote_response("request", invoke_id, remote_response, &b64, buf.len());
342            Some(b64)
343        } else {
344            None
345        };
346
347        // Encrypt the serialized payload
348        let (encrypted, iv) = self.encrypt(&buf)?;
349
350        let request = InvokeRequest {
351            invoke_id,
352            token: String::new(), // Token only needed on first request
353            target_did: vec![self.target_did.clone()],
354            iv: iv.to_vec(),
355            client_end: false,
356            hop_index: 0,
357            target_service_id: String::new(),
358            request_payload: Some(invoke_request::RequestPayload::Payload(encrypted)),
359        };
360
361        log_invoke_request("encrypted", &request, plaintext_b64.as_deref());
362
363        Ok(request)
364    }
365
366    /// Parse an InvokeResponse and extract the RemoteRequest
367    ///
368    /// Handles both key exchange responses and encrypted payload responses.
369    pub fn parse_response(&mut self, response: &InvokeResponse) -> Result<Option<RemoteRequest>> {
370        // Check for key exchange response
371        if let Some(ref key_exchange) = response.key_exchange {
372            self.process_key_exchange_response(key_exchange)?;
373            return Ok(None);
374        }
375
376        // Check for error status
377        if let Some(invoke_response::ResponsePayload::Status(ref status)) =
378            response.response_payload
379        {
380            if status.code != 0 {
381                return Err(VoltError::server(status.code, &status.message));
382            }
383            return Ok(None);
384        }
385
386        // Extract and decrypt the payload
387        let encrypted_payload = match &response.response_payload {
388            Some(invoke_response::ResponsePayload::Payload(data)) => data.as_slice(),
389            Some(invoke_response::ResponsePayload::JsonPayload(data)) => data.as_slice(),
390            _ => return Ok(None),
391        };
392
393        // Decrypt the payload
394        let decrypted = self.decrypt(encrypted_payload, &response.iv)?;
395
396        // Deserialize as RemoteRequest
397        let remote_request = RemoteRequest::decode(decrypted.as_slice()).map_err(|e| {
398            VoltError::serialization(format!("Failed to decode RemoteRequest: {}", e))
399        })?;
400
401        Ok(Some(remote_request))
402    }
403
404    /// Create a MethodInvoke payload for calling a service method
405    pub fn create_method_invoke(
406        &self,
407        invoke_id: u64,
408        service_id: &str,
409        method_name: &str,
410        method_type: MethodType,
411        request_data: Vec<u8>,
412    ) -> RemoteResponse {
413        let method_invoke = MethodInvoke {
414            id: invoke_id,
415            service_id: service_id.to_string(),
416            method_name: method_name.to_string(),
417            method_type: method_type as i32,
418            invoke_request: Some(method_invoke::InvokeRequest::Request(request_data)),
419        };
420
421        RemoteResponse {
422            payload: Some(remote_response::Payload::MethodInvoke(method_invoke)),
423        }
424    }
425
426    /// Create a MethodPayload for streaming data
427    pub fn create_method_payload(&self, invoke_id: u64, payload: Vec<u8>) -> RemoteResponse {
428        let method_payload = MethodPayload {
429            id: invoke_id,
430            method_payload: Some(method_payload::MethodPayload::Payload(payload)),
431        };
432
433        RemoteResponse {
434            payload: Some(remote_response::Payload::MethodPayload(method_payload)),
435        }
436    }
437
438    /// Create a MethodEnd to signal end of stream
439    pub fn create_method_end(&self, invoke_id: u64) -> RemoteResponse {
440        let method_end = MethodEnd {
441            id: invoke_id,
442            ended: true,
443            error: String::new(),
444            error_code: 0,
445        };
446
447        RemoteResponse {
448            payload: Some(remote_response::Payload::MethodEnd(method_end)),
449        }
450    }
451}
452
453/// A relay-wrapped gRPC call
454///
455/// This wraps the normal gRPC call logic to work through the relay protocol.
456pub struct RelayCall {
457    /// The relay context for this call
458    context: Arc<Mutex<RelayContext>>,
459    /// The invoke ID for this call
460    invoke_id: u64,
461    /// Service ID
462    service_id: String,
463}
464
465impl RelayCall {
466    /// Create a new relay call
467    pub fn new(context: Arc<Mutex<RelayContext>>, service_id: String) -> Self {
468        // We'll get the invoke_id when we make the call
469        Self {
470            context,
471            invoke_id: 0,
472            service_id,
473        }
474    }
475
476    /// Make a unary call through the relay
477    pub async fn unary<Req, Resp>(
478        &mut self,
479        method_name: &str,
480        request: Req,
481        invoke_stream_sender: &mpsc::Sender<InvokeRequest>,
482        invoke_stream_receiver: &mut Streaming<InvokeResponse>,
483    ) -> Result<Resp>
484    where
485        Req: Message,
486        Resp: Message + Default,
487    {
488        let ctx = self.context.lock().await;
489
490        // Get invoke ID
491        self.invoke_id = ctx.next_invoke_id();
492
493        // Serialize the request
494        let mut request_bytes = Vec::new();
495        request
496            .encode(&mut request_bytes)
497            .map_err(|e| VoltError::serialization(format!("Failed to encode request: {}", e)))?;
498
499        // Create the method invoke
500        let remote_response = ctx.create_method_invoke(
501            self.invoke_id,
502            &self.service_id,
503            method_name,
504            MethodType::Unary,
505            request_bytes,
506        );
507
508        // Encrypt and send
509        let invoke_request = ctx.create_encrypted_request(self.invoke_id, &remote_response)?;
510        drop(ctx); // Release lock before awaiting
511
512        invoke_stream_sender
513            .send(invoke_request)
514            .await
515            .map_err(|e| VoltError::connection(format!("Failed to send request: {}", e)))?;
516
517        // Receive response
518        let response = invoke_stream_receiver
519            .message()
520            .await
521            .map_err(|e| VoltError::grpc(e.code(), e.message()))?
522            .ok_or_else(|| VoltError::connection("Stream closed unexpectedly"))?;
523
524        // Parse and decrypt the response
525        let mut ctx = self.context.lock().await;
526        let remote_request = ctx
527            .parse_response(&response)?
528            .ok_or_else(|| VoltError::protocol("Expected payload response"))?;
529
530        // Extract the method payload
531        match remote_request.payload {
532            Some(remote_request::Payload::MethodPayload(mp)) => {
533                let payload = match mp.method_payload {
534                    Some(method_payload::MethodPayload::Payload(p)) => p,
535                    Some(method_payload::MethodPayload::JsonPayload(j)) => j.into_bytes(),
536                    None => return Err(VoltError::protocol("Empty method payload")),
537                };
538
539                Resp::decode(payload.as_slice()).map_err(|e| {
540                    VoltError::serialization(format!("Failed to decode response: {}", e))
541                })
542            }
543            _ => Err(VoltError::protocol("Expected MethodPayload response")),
544        }
545    }
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    const TARGET: &str = "did:tdx:test-volt";

    #[test]
    fn test_relay_context_creation() {
        let ctx = RelayContext::new(TARGET.to_string());
        assert_eq!(ctx.state(), &RelayState::Disconnected);
        assert!(ctx.encryption_key.is_none());
    }

    #[test]
    fn test_invoke_id_counter() {
        let ctx = RelayContext::new(TARGET.to_string());
        assert_eq!(ctx.next_invoke_id(), 1);
        assert_eq!(ctx.next_invoke_id(), 2);
        assert_eq!(ctx.next_invoke_id(), 3);
    }

    #[test]
    fn test_key_exchange_request_creation() {
        let mut ctx = RelayContext::new(TARGET.to_string());
        let request = ctx.create_key_exchange_request("test-jwt-token");

        assert_eq!(ctx.state(), &RelayState::KeyExchangePending);
        assert_eq!(request.token, "test-jwt-token");
        assert_eq!(request.target_did, vec![TARGET]);
        assert!(request.request_payload.is_none());
    }

    #[test]
    fn test_key_exchange_request_uses_first_invoke_id() {
        let mut ctx = RelayContext::new(TARGET.to_string());
        let request = ctx.create_key_exchange_request("t");

        assert_eq!(request.invoke_id, 1);
        assert!(request.iv.is_empty());
        assert_eq!(ctx.next_invoke_id(), 2);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let mut ctx = RelayContext::new(TARGET.to_string());

        // Manually set an encryption key for testing
        ctx.encryption_key = Some([0u8; AES_KEY_LENGTH]);
        ctx.state = RelayState::Connected;

        let plaintext = b"Hello, relay world!";
        let (ciphertext, iv) = ctx.encrypt(plaintext).unwrap();
        let decrypted = ctx.decrypt(&ciphertext, &iv).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_and_decrypt_require_key() {
        let ctx = RelayContext::new(TARGET.to_string());

        assert!(ctx.encrypt(b"payload").is_err());
        assert!(ctx.decrypt(&[0u8; 16], &[0u8; AES_CBC_IV_LENGTH]).is_err());
    }

    #[test]
    fn test_decrypt_rejects_bad_iv_length() {
        let mut ctx = RelayContext::new(TARGET.to_string());
        ctx.encryption_key = Some([0u8; AES_KEY_LENGTH]);

        let (ciphertext, _iv) = ctx.encrypt(b"payload").unwrap();
        assert!(ctx
            .decrypt(&ciphertext, &[0u8; AES_CBC_IV_LENGTH - 1])
            .is_err());
    }

    #[test]
    fn test_encrypted_request_roundtrip() {
        let mut ctx = RelayContext::new(TARGET.to_string());
        ctx.encryption_key = Some([7u8; AES_KEY_LENGTH]);
        ctx.state = RelayState::Connected;

        let id = ctx.next_invoke_id();
        let remote_response = ctx.create_method_payload(id, b"abc".to_vec());
        let request = ctx.create_encrypted_request(id, &remote_response).unwrap();

        assert_eq!(request.invoke_id, id);
        assert!(request.token.is_empty());
        assert_eq!(request.target_did, vec![TARGET]);

        let ciphertext = match &request.request_payload {
            Some(invoke_request::RequestPayload::Payload(bytes)) => bytes.clone(),
            other => panic!("unexpected request payload: {:?}", other),
        };
        let plaintext = ctx.decrypt(&ciphertext, &request.iv).unwrap();
        let decoded = RemoteResponse::decode(plaintext.as_slice()).unwrap();

        assert_eq!(decoded, remote_response);
    }
}