1use std::collections::HashMap;
15use std::io;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use parking_lot::Mutex;
20use tokio::io::BufReader;
21use tokio::net::TcpStream;
22use tokio::sync::{Notify, oneshot};
23use tokio::task::JoinHandle;
24use tracing::{debug, warn};
25
26use super::codec::{read_response, write_request};
27use super::types::{Request, Response, VectorizerValue};
28
29#[derive(Debug, thiserror::Error)]
31pub enum RpcClientError {
32 #[error("network I/O error: {0}")]
34 Io(#[from] io::Error),
35
36 #[error("encode failed: {0}")]
39 Encode(#[from] rmp_serde::encode::Error),
40
41 #[error("server error: {0}")]
43 Server(String),
44
45 #[error("connection closed before response (reader task ended)")]
47 ConnectionClosed,
48
49 #[error("HELLO must succeed before any data-plane command can be issued")]
53 NotAuthenticated,
54}
55
56pub type Result<T> = std::result::Result<T, RpcClientError>;
58
/// Credentials and client metadata sent with the `HELLO` handshake.
///
/// `Default` yields `version: 0`; `HelloPayload::new` sets `version: 1`.
/// `Debug` output redacts `token` and `api_key` so credentials do not leak
/// into logs; it only shows whether each one is present.
#[derive(Clone, Default)]
pub struct HelloPayload {
    /// Token credential, sent under the `"token"` key when present.
    pub token: Option<String>,
    /// API-key credential, sent under the `"api_key"` key when present.
    pub api_key: Option<String>,
    /// Client identifier, sent under the `"client_name"` key when present.
    pub client_name: Option<String>,
    /// Handshake version, always sent under the `"version"` key.
    pub version: i64,
}

impl std::fmt::Debug for HelloPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a placeholder; `None` stays `None` so the
        // output still tells which credential kind is configured.
        let redact = |secret: &Option<String>| secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("HelloPayload")
            .field("token", &redact(&self.token))
            .field("api_key", &redact(&self.api_key))
            .field("client_name", &self.client_name)
            .field("version", &self.version)
            .finish()
    }
}
76
77impl HelloPayload {
78 pub fn new(client_name: impl Into<String>) -> Self {
82 Self {
83 client_name: Some(client_name.into()),
84 version: 1,
85 ..Default::default()
86 }
87 }
88
89 pub fn with_token(mut self, token: impl Into<String>) -> Self {
92 self.token = Some(token.into());
93 self.api_key = None;
94 self
95 }
96
97 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
99 self.api_key = Some(api_key.into());
100 self.token = None;
101 self
102 }
103
104 fn into_value(self) -> VectorizerValue {
105 let mut pairs = vec![(
106 VectorizerValue::Str("version".into()),
107 VectorizerValue::Int(self.version),
108 )];
109 if let Some(token) = self.token {
110 pairs.push((
111 VectorizerValue::Str("token".into()),
112 VectorizerValue::Str(token),
113 ));
114 }
115 if let Some(api_key) = self.api_key {
116 pairs.push((
117 VectorizerValue::Str("api_key".into()),
118 VectorizerValue::Str(api_key),
119 ));
120 }
121 if let Some(name) = self.client_name {
122 pairs.push((
123 VectorizerValue::Str("client_name".into()),
124 VectorizerValue::Str(name),
125 ));
126 }
127 VectorizerValue::Map(pairs)
128 }
129}
130
/// The server's reply to `HELLO`, decoded leniently: missing or mistyped
/// entries fall back to empty or zero/false defaults.
#[derive(Debug, Clone)]
pub struct HelloResponse {
    /// Value of the `"server_version"` entry (empty if absent).
    pub server_version: String,
    /// Value of the `"protocol_version"` entry (0 if absent).
    pub protocol_version: i64,
    /// Value of the `"authenticated"` entry; when `true`, `RpcClient::hello`
    /// unlocks data-plane commands.
    pub authenticated: bool,
    /// Value of the `"admin"` entry (presumably admin privileges; `false` if absent).
    pub admin: bool,
    /// String items of the `"capabilities"` array; non-string items are skipped.
    pub capabilities: Vec<String>,
}
146
147impl HelloResponse {
148 fn parse(value: &VectorizerValue) -> Self {
149 let server_version = value
150 .map_get("server_version")
151 .and_then(|v| v.as_str())
152 .map(str::to_owned)
153 .unwrap_or_default();
154 let protocol_version = value
155 .map_get("protocol_version")
156 .and_then(|v| v.as_int())
157 .unwrap_or(0);
158 let authenticated = value
159 .map_get("authenticated")
160 .and_then(|v| v.as_bool())
161 .unwrap_or(false);
162 let admin = value
163 .map_get("admin")
164 .and_then(|v| v.as_bool())
165 .unwrap_or(false);
166 let capabilities = value
167 .map_get("capabilities")
168 .and_then(|v| v.as_array())
169 .map(|arr| {
170 arr.iter()
171 .filter_map(|v| v.as_str().map(str::to_owned))
172 .collect()
173 })
174 .unwrap_or_default();
175 Self {
176 server_version,
177 protocol_version,
178 authenticated,
179 admin,
180 capabilities,
181 }
182 }
183}
184
185pub struct RpcClient {
187 writer: Arc<tokio::sync::Mutex<tokio::net::tcp::OwnedWriteHalf>>,
191 pending: Arc<Mutex<HashMap<u32, oneshot::Sender<Response>>>>,
193 next_id: AtomicU32,
195 reader_done: Arc<Notify>,
198 reader_task: Option<JoinHandle<()>>,
200 authenticated: Arc<Mutex<bool>>,
202}
203
204impl RpcClient {
205 pub async fn connect_url(url: &str) -> Result<Self> {
219 use super::endpoint::{Endpoint, parse_endpoint};
220 match parse_endpoint(url).map_err(|e| RpcClientError::Server(e.to_string()))? {
221 Endpoint::Rpc { host, port } => Self::connect(format!("{host}:{port}")).await,
222 Endpoint::Rest { url } => Err(RpcClientError::Server(format!(
223 "RpcClient cannot dial REST URL '{url}'; \
224 use the HTTP client (`vectorizer_sdk::VectorizerClient`) instead, \
225 or pass a `vectorizer://` URL"
226 ))),
227 }
228 }
229
230 pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> Result<Self> {
235 let stream = TcpStream::connect(addr).await?;
236 let (read_half, write_half) = stream.into_split();
237 let mut reader = BufReader::new(read_half);
238
239 let pending: Arc<Mutex<HashMap<u32, oneshot::Sender<Response>>>> =
240 Arc::new(Mutex::new(HashMap::new()));
241 let reader_done = Arc::new(Notify::new());
242
243 let pending_for_reader = Arc::clone(&pending);
246 let done_for_reader = Arc::clone(&reader_done);
247 let reader_task = tokio::spawn(async move {
248 loop {
249 match read_response(&mut reader).await {
250 Ok(resp) => {
251 let sender = {
252 let mut p = pending_for_reader.lock();
253 p.remove(&resp.id)
254 };
255 match sender {
256 Some(tx) => {
257 let _ = tx.send(resp);
258 }
259 None => {
260 warn!(
261 id = resp.id,
262 "RpcClient received response with no pending caller — dropping"
263 );
264 }
265 }
266 }
267 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
268 debug!("RpcClient reader: clean EOF");
269 break;
270 }
271 Err(e) => {
272 warn!(error = %e, "RpcClient reader error — connection closed");
273 break;
274 }
275 }
276 }
277 let mut p = pending_for_reader.lock();
279 p.clear();
280 done_for_reader.notify_waiters();
281 });
282
283 Ok(Self {
284 writer: Arc::new(tokio::sync::Mutex::new(write_half)),
285 pending,
286 next_id: AtomicU32::new(1),
287 reader_done,
288 reader_task: Some(reader_task),
289 authenticated: Arc::new(Mutex::new(false)),
290 })
291 }
292
293 pub async fn hello(&self, payload: HelloPayload) -> Result<HelloResponse> {
296 let value = payload.into_value();
297 let result = self.raw_call("HELLO", vec![value]).await?;
298 let parsed = HelloResponse::parse(&result);
299 if parsed.authenticated {
300 *self.authenticated.lock() = true;
301 }
302 Ok(parsed)
303 }
304
305 pub async fn ping(&self) -> Result<String> {
309 let result = self.raw_call("PING", vec![]).await?;
310 result
311 .as_str()
312 .map(str::to_owned)
313 .ok_or_else(|| RpcClientError::Server("PING returned non-string payload".into()))
314 }
315
316 pub async fn call(
319 &self,
320 command: impl Into<String>,
321 args: Vec<VectorizerValue>,
322 ) -> Result<VectorizerValue> {
323 let cmd = command.into();
324 let exempt = matches!(cmd.as_str(), "HELLO" | "PING");
326 if !exempt && !*self.authenticated.lock() {
327 return Err(RpcClientError::NotAuthenticated);
328 }
329 self.raw_call(cmd, args).await
330 }
331
332 async fn raw_call(
335 &self,
336 command: impl Into<String>,
337 args: Vec<VectorizerValue>,
338 ) -> Result<VectorizerValue> {
339 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
340 let (tx, rx) = oneshot::channel::<Response>();
341 {
342 let mut pending = self.pending.lock();
343 pending.insert(id, tx);
344 }
345
346 let req = Request {
347 id,
348 command: command.into(),
349 args,
350 };
351
352 {
355 let mut writer = self.writer.lock().await;
356 if let Err(e) = write_request(&mut *writer, &req).await {
357 self.pending.lock().remove(&id);
358 return Err(RpcClientError::from(e));
359 }
360 }
361
362 let resp = tokio::select! {
365 recv = rx => match recv {
366 Ok(resp) => resp,
367 Err(_) => return Err(RpcClientError::ConnectionClosed),
368 },
369 _ = self.reader_done.notified() => {
370 self.pending.lock().remove(&id);
371 return Err(RpcClientError::ConnectionClosed);
372 }
373 };
374
375 match resp.result {
376 Ok(value) => Ok(value),
377 Err(message) => Err(RpcClientError::Server(message)),
378 }
379 }
380
381 pub fn is_authenticated(&self) -> bool {
383 *self.authenticated.lock()
384 }
385
386 pub fn close(mut self) {
389 if let Some(handle) = self.reader_task.take() {
390 handle.abort();
391 }
392 }
393}
394
395impl Drop for RpcClient {
396 fn drop(&mut self) {
397 if let Some(handle) = self.reader_task.take() {
398 handle.abort();
399 }
400 }
401}