Skip to main content

volt_client_grpc/
grpc_call.rs

1//! gRPC call wrapper for handling Volt API calls.
2//!
3//! This module provides abstractions for making gRPC calls to the Volt,
4//! handling encryption, and managing streaming calls.
5
6use crate::constants::MethodType;
7use crate::crypto::{aes_decrypt, aes_encrypt, AesKey};
8use crate::error::{Result, VoltError};
9use futures::stream::Stream;
10use std::sync::Arc;
11use tokio::sync::mpsc;
12use tokio::sync::Mutex;
13
/// Lifecycle stage of a gRPC call.
///
/// `GrpcCall::is_writable` treats both `KeyExchangePending` and `Active` as
/// stages in which the call can still accept writes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CallState {
    /// The call has been created but not started yet.
    NotStarted,
    /// The call is waiting for the key exchange to complete.
    KeyExchangePending,
    /// The call is up and running.
    Active,
    /// The call is over.
    Ended,
}
26
27/// A wrapper around gRPC streaming calls
28pub struct GrpcCall<Req, Resp> {
29    method_name: String,
30    method_type: MethodType,
31    state: Arc<Mutex<CallState>>,
32    encryption_key: Arc<Mutex<Option<[u8; 32]>>>,
33    pending_requests: Arc<Mutex<Vec<Req>>>,
34    _sender: Option<mpsc::Sender<Req>>,
35    _phantom: std::marker::PhantomData<Resp>,
36}
37
38impl<Req, Resp> GrpcCall<Req, Resp>
39where
40    Req: Send + 'static,
41    Resp: Send + 'static,
42{
43    /// Create a new gRPC call wrapper
44    pub fn new(method_name: impl Into<String>, method_type: MethodType) -> Self {
45        Self {
46            method_name: method_name.into(),
47            method_type,
48            state: Arc::new(Mutex::new(CallState::NotStarted)),
49            encryption_key: Arc::new(Mutex::new(None)),
50            pending_requests: Arc::new(Mutex::new(Vec::new())),
51            _sender: None,
52            _phantom: std::marker::PhantomData,
53        }
54    }
55
56    /// Get the method name
57    pub fn method_name(&self) -> &str {
58        &self.method_name
59    }
60
61    /// Get the method type
62    pub fn method_type(&self) -> MethodType {
63        self.method_type
64    }
65
66    /// Get the current state
67    pub async fn state(&self) -> CallState {
68        *self.state.lock().await
69    }
70
71    /// Check if the call is still writable
72    pub async fn is_writable(&self) -> bool {
73        let state = self.state.lock().await;
74        *state == CallState::Active || *state == CallState::KeyExchangePending
75    }
76
77    /// Set the encryption key
78    pub async fn set_encryption_key(&self, key: [u8; 32]) {
79        *self.encryption_key.lock().await = Some(key);
80    }
81
82    /// Get the encryption key
83    pub async fn encryption_key(&self) -> Option<[u8; 32]> {
84        *self.encryption_key.lock().await
85    }
86
87    /// Queue a request for later sending
88    pub async fn queue_request(&self, request: Req) {
89        self.pending_requests.lock().await.push(request);
90    }
91
92    /// Take all pending requests
93    pub async fn take_pending_requests(&self) -> Vec<Req> {
94        std::mem::take(&mut *self.pending_requests.lock().await)
95    }
96
97    /// Set the call state
98    pub async fn set_state(&self, state: CallState) {
99        *self.state.lock().await = state;
100    }
101}
102
103/// Builder for constructing gRPC calls with proper setup
104#[allow(dead_code)]
105pub struct GrpcCallBuilder {
106    method_name: String,
107    _method_type: MethodType,
108    is_relayed: bool,
109    service_relayed: bool,
110}
111
112impl GrpcCallBuilder {
113    pub fn new(method_name: impl Into<String>, _method_type: MethodType) -> Self {
114        Self {
115            method_name: method_name.into(),
116            _method_type,
117            is_relayed: false,
118            service_relayed: false,
119        }
120    }
121
122    pub fn relayed(mut self, is_relayed: bool) -> Self {
123        self.is_relayed = is_relayed;
124        self
125    }
126
127    pub fn service_relayed(mut self, service_relayed: bool) -> Self {
128        self.service_relayed = service_relayed;
129        self
130    }
131
132    pub fn needs_encryption(&self) -> bool {
133        self.is_relayed || self.service_relayed
134    }
135}
136
137/// Handle encryption for relayed calls
138pub fn encrypt_payload(key: &[u8; 32], payload: &[u8]) -> Result<(Vec<u8>, [u8; 12])> {
139    let aes_key = AesKey::generate();
140    let iv = *aes_key.iv();
141    let encrypted = aes_encrypt(key, &iv, payload)?;
142    Ok((encrypted, iv))
143}
144
145/// Handle decryption for relayed calls
146pub fn decrypt_payload(key: &[u8; 32], iv: &[u8; 12], ciphertext: &[u8]) -> Result<Vec<u8>> {
147    aes_decrypt(key, iv, ciphertext)
148}
149
150/// A stream adapter that handles encryption/decryption
151#[allow(dead_code)] 
152pub struct EncryptedStream<S, T> {
153    inner: S,
154    encryption_key: Option<[u8; 32]>,
155    _phantom: std::marker::PhantomData<T>,
156}
157
158impl<S, T> EncryptedStream<S, T>
159where
160    S: Stream<Item = Result<T>> + Unpin,
161{
162    pub fn new(inner: S, encryption_key: Option<[u8; 32]>) -> Self {
163        Self {
164            inner,
165            encryption_key,
166            _phantom: std::marker::PhantomData,
167        }
168    }
169}
170
171/// Event emitter for call events
172pub type CallEventSender<T> = mpsc::Sender<CallEvent<T>>;
173pub type CallEventReceiver<T> = mpsc::Receiver<CallEvent<T>>;
174
175/// Events that can be emitted by a call
176#[derive(Debug)]
177pub enum CallEvent<T> {
178    /// Data received
179    Data(T),
180    /// Error occurred
181    Error(VoltError),
182    /// Call ended
183    End,
184}
185
186/// Create an event channel for a call
187pub fn create_event_channel<T>(buffer_size: usize) -> (CallEventSender<T>, CallEventReceiver<T>) {
188    mpsc::channel(buffer_size)
189}