volt_client_grpc/
grpc_call.rs1use crate::constants::MethodType;
7use crate::crypto::{aes_decrypt, aes_encrypt, AesKey};
8use crate::error::{Result, VoltError};
9use futures::stream::Stream;
10use std::sync::Arc;
11use tokio::sync::mpsc;
12use tokio::sync::Mutex;
13
/// Lifecycle state of a [`GrpcCall`].
///
/// A freshly created call starts in [`CallState::NotStarted`]; requests are
/// considered writable only in [`CallState::KeyExchangePending`] and
/// [`CallState::Active`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CallState {
    /// The call has been created but not started.
    NotStarted,
    /// Waiting for a key exchange to finish; writes are still accepted
    /// (presumably queued until encryption is ready — confirm with callers).
    KeyExchangePending,
    /// The call is running and writable.
    Active,
    /// The call has finished; it is no longer writable.
    Ended,
}
26
27pub struct GrpcCall<Req, Resp> {
29 method_name: String,
30 method_type: MethodType,
31 state: Arc<Mutex<CallState>>,
32 encryption_key: Arc<Mutex<Option<[u8; 32]>>>,
33 pending_requests: Arc<Mutex<Vec<Req>>>,
34 _sender: Option<mpsc::Sender<Req>>,
35 _phantom: std::marker::PhantomData<Resp>,
36}
37
38impl<Req, Resp> GrpcCall<Req, Resp>
39where
40 Req: Send + 'static,
41 Resp: Send + 'static,
42{
43 pub fn new(method_name: impl Into<String>, method_type: MethodType) -> Self {
45 Self {
46 method_name: method_name.into(),
47 method_type,
48 state: Arc::new(Mutex::new(CallState::NotStarted)),
49 encryption_key: Arc::new(Mutex::new(None)),
50 pending_requests: Arc::new(Mutex::new(Vec::new())),
51 _sender: None,
52 _phantom: std::marker::PhantomData,
53 }
54 }
55
56 pub fn method_name(&self) -> &str {
58 &self.method_name
59 }
60
61 pub fn method_type(&self) -> MethodType {
63 self.method_type
64 }
65
66 pub async fn state(&self) -> CallState {
68 *self.state.lock().await
69 }
70
71 pub async fn is_writable(&self) -> bool {
73 let state = self.state.lock().await;
74 *state == CallState::Active || *state == CallState::KeyExchangePending
75 }
76
77 pub async fn set_encryption_key(&self, key: [u8; 32]) {
79 *self.encryption_key.lock().await = Some(key);
80 }
81
82 pub async fn encryption_key(&self) -> Option<[u8; 32]> {
84 *self.encryption_key.lock().await
85 }
86
87 pub async fn queue_request(&self, request: Req) {
89 self.pending_requests.lock().await.push(request);
90 }
91
92 pub async fn take_pending_requests(&self) -> Vec<Req> {
94 std::mem::take(&mut *self.pending_requests.lock().await)
95 }
96
97 pub async fn set_state(&self, state: CallState) {
99 *self.state.lock().await = state;
100 }
101}
102
103#[allow(dead_code)]
105pub struct GrpcCallBuilder {
106 method_name: String,
107 _method_type: MethodType,
108 is_relayed: bool,
109 service_relayed: bool,
110}
111
112impl GrpcCallBuilder {
113 pub fn new(method_name: impl Into<String>, _method_type: MethodType) -> Self {
114 Self {
115 method_name: method_name.into(),
116 _method_type,
117 is_relayed: false,
118 service_relayed: false,
119 }
120 }
121
122 pub fn relayed(mut self, is_relayed: bool) -> Self {
123 self.is_relayed = is_relayed;
124 self
125 }
126
127 pub fn service_relayed(mut self, service_relayed: bool) -> Self {
128 self.service_relayed = service_relayed;
129 self
130 }
131
132 pub fn needs_encryption(&self) -> bool {
133 self.is_relayed || self.service_relayed
134 }
135}
136
137pub fn encrypt_payload(key: &[u8; 32], payload: &[u8]) -> Result<(Vec<u8>, [u8; 12])> {
139 let aes_key = AesKey::generate();
140 let iv = *aes_key.iv();
141 let encrypted = aes_encrypt(key, &iv, payload)?;
142 Ok((encrypted, iv))
143}
144
145pub fn decrypt_payload(key: &[u8; 32], iv: &[u8; 12], ciphertext: &[u8]) -> Result<Vec<u8>> {
147 aes_decrypt(key, iv, ciphertext)
148}
149
150#[allow(dead_code)]
152pub struct EncryptedStream<S, T> {
153 inner: S,
154 encryption_key: Option<[u8; 32]>,
155 _phantom: std::marker::PhantomData<T>,
156}
157
158impl<S, T> EncryptedStream<S, T>
159where
160 S: Stream<Item = Result<T>> + Unpin,
161{
162 pub fn new(inner: S, encryption_key: Option<[u8; 32]>) -> Self {
163 Self {
164 inner,
165 encryption_key,
166 _phantom: std::marker::PhantomData,
167 }
168 }
169}
170
171pub type CallEventSender<T> = mpsc::Sender<CallEvent<T>>;
173pub type CallEventReceiver<T> = mpsc::Receiver<CallEvent<T>>;
174
175#[derive(Debug)]
177pub enum CallEvent<T> {
178 Data(T),
180 Error(VoltError),
182 End,
184}
185
186pub fn create_event_channel<T>(buffer_size: usize) -> (CallEventSender<T>, CallEventReceiver<T>) {
188 mpsc::channel(buffer_size)
189}