Skip to main content

vectorizer_sdk/rpc/
client.rs

//! `RpcClient`: connect, hello, call, ping, close.
//!
//! The client owns one TCP connection to the server. It runs a single
//! background reader task that demultiplexes responses by `Request.id`
//! into per-call `oneshot` channels, so concurrent in-flight calls
//! on the same connection don't block each other.
//!
//! Auth is **per-connection sticky** per wire spec § 4: the first
//! non-`PING` frame on a connection MUST be `HELLO`; every subsequent
//! call inherits the auth state. The client records the
//! `authenticated` flag from the HELLO response (queryable via
//! `RpcClient::is_authenticated`); the full response, including the
//! `admin` flag, is returned to the caller of `RpcClient::hello`.
13
14use std::collections::HashMap;
15use std::io;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use parking_lot::Mutex;
20use tokio::io::BufReader;
21use tokio::net::TcpStream;
22use tokio::sync::{Notify, oneshot};
23use tokio::task::JoinHandle;
24use tracing::{debug, warn};
25
26use super::codec::{read_response, write_request};
27use super::types::{Request, Response, VectorizerValue};
28
29/// Errors the [`RpcClient`] can return.
30#[derive(Debug, thiserror::Error)]
31pub enum RpcClientError {
32    /// Network-level I/O failure.
33    #[error("network I/O error: {0}")]
34    Io(#[from] io::Error),
35
36    /// MessagePack encode failure (should be unreachable for the v1
37    /// shapes — every type derives `Serialize`).
38    #[error("encode failed: {0}")]
39    Encode(#[from] rmp_serde::encode::Error),
40
41    /// Server returned `Result::Err(message)` for the call.
42    #[error("server error: {0}")]
43    Server(String),
44
45    /// The connection's reader task died before the response arrived.
46    #[error("connection closed before response (reader task ended)")]
47    ConnectionClosed,
48
49    /// Caller invoked a data-plane command before HELLO succeeded.
50    /// The server would reject this; the client surfaces it locally
51    /// so the offending caller sees a clear panic-free error.
52    #[error("HELLO must succeed before any data-plane command can be issued")]
53    NotAuthenticated,
54}
55
56/// Result type alias.
57pub type Result<T> = std::result::Result<T, RpcClientError>;
58
/// HELLO request payload — sent as the FIRST frame on a connection.
///
/// At least one of `token` / `api_key` should be populated when the
/// server has auth enabled. When the server runs in single-user mode
/// (`auth.enabled: false`), credentials are accepted-but-ignored and
/// the connection runs as the implicit local admin.
///
/// The [`Default`] value carries no credentials and no client name, and
/// speaks protocol version `1`.
#[derive(Debug, Clone)]
pub struct HelloPayload {
    /// Bearer JWT (same shape REST `/auth/login` returns).
    pub token: Option<String>,
    /// API key.
    pub api_key: Option<String>,
    /// User-Agent-style identifier surfaced in server-side tracing.
    pub client_name: Option<String>,
    /// Wire spec protocol version; defaults to 1.
    pub version: i64,
}

// Hand-written rather than derived: a derived `Default` would set
// `version` to 0, contradicting the documented default of 1 and making
// `HelloPayload::default()` / `..Default::default()` literals announce
// a protocol version that does not exist.
impl Default for HelloPayload {
    /// An empty payload (no credentials, no client name) speaking
    /// protocol version 1.
    fn default() -> Self {
        Self {
            token: None,
            api_key: None,
            client_name: None,
            version: 1,
        }
    }
}
76
77impl HelloPayload {
78    /// Build a minimal HELLO payload identifying the client by name.
79    /// No credentials — works against a server running in single-user
80    /// mode (`auth.enabled: false`).
81    pub fn new(client_name: impl Into<String>) -> Self {
82        Self {
83            client_name: Some(client_name.into()),
84            version: 1,
85            ..Default::default()
86        }
87    }
88
89    /// Attach a JWT bearer token. Replaces any previously set
90    /// token/api_key.
91    pub fn with_token(mut self, token: impl Into<String>) -> Self {
92        self.token = Some(token.into());
93        self.api_key = None;
94        self
95    }
96
97    /// Attach an API key. Replaces any previously set token/api_key.
98    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
99        self.api_key = Some(api_key.into());
100        self.token = None;
101        self
102    }
103
104    fn into_value(self) -> VectorizerValue {
105        let mut pairs = vec![(
106            VectorizerValue::Str("version".into()),
107            VectorizerValue::Int(self.version),
108        )];
109        if let Some(token) = self.token {
110            pairs.push((
111                VectorizerValue::Str("token".into()),
112                VectorizerValue::Str(token),
113            ));
114        }
115        if let Some(api_key) = self.api_key {
116            pairs.push((
117                VectorizerValue::Str("api_key".into()),
118                VectorizerValue::Str(api_key),
119            ));
120        }
121        if let Some(name) = self.client_name {
122            pairs.push((
123                VectorizerValue::Str("client_name".into()),
124                VectorizerValue::Str(name),
125            ));
126        }
127        VectorizerValue::Map(pairs)
128    }
129}
130
131/// What the server returns for a successful `HELLO`.
132#[derive(Debug, Clone)]
133pub struct HelloResponse {
134    /// Server crate version, e.g. `"3.0.0"`.
135    pub server_version: String,
136    /// Wire spec protocol version, currently always `1`.
137    pub protocol_version: i64,
138    /// `true` when the server accepted the supplied credentials (or
139    /// when auth is globally disabled).
140    pub authenticated: bool,
141    /// `true` when the authenticated principal carries `Role::Admin`.
142    pub admin: bool,
143    /// Capability names this connection can call.
144    pub capabilities: Vec<String>,
145}
146
147impl HelloResponse {
148    fn parse(value: &VectorizerValue) -> Self {
149        let server_version = value
150            .map_get("server_version")
151            .and_then(|v| v.as_str())
152            .map(str::to_owned)
153            .unwrap_or_default();
154        let protocol_version = value
155            .map_get("protocol_version")
156            .and_then(|v| v.as_int())
157            .unwrap_or(0);
158        let authenticated = value
159            .map_get("authenticated")
160            .and_then(|v| v.as_bool())
161            .unwrap_or(false);
162        let admin = value
163            .map_get("admin")
164            .and_then(|v| v.as_bool())
165            .unwrap_or(false);
166        let capabilities = value
167            .map_get("capabilities")
168            .and_then(|v| v.as_array())
169            .map(|arr| {
170                arr.iter()
171                    .filter_map(|v| v.as_str().map(str::to_owned))
172                    .collect()
173            })
174            .unwrap_or_default();
175        Self {
176            server_version,
177            protocol_version,
178            authenticated,
179            admin,
180            capabilities,
181        }
182    }
183}
184
185/// One connection to a Vectorizer RPC server.
186pub struct RpcClient {
187    /// Owned write half of the TCP socket. Wrapped in a mutex because
188    /// every `call` writes serially; the writer is the only one that
189    /// touches this half.
190    writer: Arc<tokio::sync::Mutex<tokio::net::tcp::OwnedWriteHalf>>,
191    /// Map from request id → oneshot sender for the matching response.
192    pending: Arc<Mutex<HashMap<u32, oneshot::Sender<Response>>>>,
193    /// Monotonic id allocator.
194    next_id: AtomicU32,
195    /// Notified when the reader task exits, so pending calls fail
196    /// fast instead of hanging forever.
197    reader_done: Arc<Notify>,
198    /// Handle to the spawned reader task; aborted on `Drop`.
199    reader_task: Option<JoinHandle<()>>,
200    /// `true` once HELLO succeeded.
201    authenticated: Arc<Mutex<bool>>,
202}
203
204impl RpcClient {
205    /// Convenience: parse a `vectorizer://host[:port]` URL and dial.
206    ///
207    /// Accepts every form documented at
208    /// [`crate::rpc::endpoint::parse_endpoint`]:
209    ///
210    /// - `vectorizer://host:port` → RPC on the given port.
211    /// - `vectorizer://host` → RPC on the default port 15503.
212    /// - `host:port` (no scheme) → RPC.
213    /// - `http(s)://...` → returns [`RpcClientError::Server`] with a
214    ///   clear message asking the caller to use the HTTP client
215    ///   instead. The SDK ships the `http` Cargo feature for that
216    ///   path; an `http://` URL is not a transport an RPC client can
217    ///   speak.
218    pub async fn connect_url(url: &str) -> Result<Self> {
219        use super::endpoint::{Endpoint, parse_endpoint};
220        match parse_endpoint(url).map_err(|e| RpcClientError::Server(e.to_string()))? {
221            Endpoint::Rpc { host, port } => Self::connect(format!("{host}:{port}")).await,
222            Endpoint::Rest { url } => Err(RpcClientError::Server(format!(
223                "RpcClient cannot dial REST URL '{url}'; \
224                 use the HTTP client (`vectorizer_sdk::VectorizerClient`) instead, \
225                 or pass a `vectorizer://` URL"
226            ))),
227        }
228    }
229
230    /// Open a TCP connection to `addr` (which must be `host:port`)
231    /// and start the background reader task. Does NOT send HELLO —
232    /// callers MUST call [`Self::hello`] before any data-plane
233    /// command, or the server will reject it.
234    pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> Result<Self> {
235        let stream = TcpStream::connect(addr).await?;
236        let (read_half, write_half) = stream.into_split();
237        let mut reader = BufReader::new(read_half);
238
239        let pending: Arc<Mutex<HashMap<u32, oneshot::Sender<Response>>>> =
240            Arc::new(Mutex::new(HashMap::new()));
241        let reader_done = Arc::new(Notify::new());
242
243        // Spawn the reader: read frames forever, dispatch to pending
244        // by id, close down on EOF.
245        let pending_for_reader = Arc::clone(&pending);
246        let done_for_reader = Arc::clone(&reader_done);
247        let reader_task = tokio::spawn(async move {
248            loop {
249                match read_response(&mut reader).await {
250                    Ok(resp) => {
251                        let sender = {
252                            let mut p = pending_for_reader.lock();
253                            p.remove(&resp.id)
254                        };
255                        match sender {
256                            Some(tx) => {
257                                let _ = tx.send(resp);
258                            }
259                            None => {
260                                warn!(
261                                    id = resp.id,
262                                    "RpcClient received response with no pending caller — dropping"
263                                );
264                            }
265                        }
266                    }
267                    Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
268                        debug!("RpcClient reader: clean EOF");
269                        break;
270                    }
271                    Err(e) => {
272                        warn!(error = %e, "RpcClient reader error — connection closed");
273                        break;
274                    }
275                }
276            }
277            // Drain pending — every waiting call gets ConnectionClosed.
278            let mut p = pending_for_reader.lock();
279            p.clear();
280            done_for_reader.notify_waiters();
281        });
282
283        Ok(Self {
284            writer: Arc::new(tokio::sync::Mutex::new(write_half)),
285            pending,
286            next_id: AtomicU32::new(1),
287            reader_done,
288            reader_task: Some(reader_task),
289            authenticated: Arc::new(Mutex::new(false)),
290        })
291    }
292
293    /// Issue the `HELLO` handshake. Must be the first call on a fresh
294    /// connection. Returns the server's capability list and auth flags.
295    pub async fn hello(&self, payload: HelloPayload) -> Result<HelloResponse> {
296        let value = payload.into_value();
297        let result = self.raw_call("HELLO", vec![value]).await?;
298        let parsed = HelloResponse::parse(&result);
299        if parsed.authenticated {
300            *self.authenticated.lock() = true;
301        }
302        Ok(parsed)
303    }
304
305    /// Health check. The server treats `PING` as auth-exempt so this
306    /// works even before HELLO; the typed wrapper still validates the
307    /// response shape.
308    pub async fn ping(&self) -> Result<String> {
309        let result = self.raw_call("PING", vec![]).await?;
310        result
311            .as_str()
312            .map(str::to_owned)
313            .ok_or_else(|| RpcClientError::Server("PING returned non-string payload".into()))
314    }
315
316    /// Generic call dispatcher. Most callers should use a typed
317    /// wrapper from [`crate::rpc::commands`] instead.
318    pub async fn call(
319        &self,
320        command: impl Into<String>,
321        args: Vec<VectorizerValue>,
322    ) -> Result<VectorizerValue> {
323        let cmd = command.into();
324        // Auth-exempt commands per wire spec § 4.
325        let exempt = matches!(cmd.as_str(), "HELLO" | "PING");
326        if !exempt && !*self.authenticated.lock() {
327            return Err(RpcClientError::NotAuthenticated);
328        }
329        self.raw_call(cmd, args).await
330    }
331
332    /// Skip the local auth check — used by the HELLO + PING paths so
333    /// the auth gate doesn't block the auth handshake itself.
334    async fn raw_call(
335        &self,
336        command: impl Into<String>,
337        args: Vec<VectorizerValue>,
338    ) -> Result<VectorizerValue> {
339        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
340        let (tx, rx) = oneshot::channel::<Response>();
341        {
342            let mut pending = self.pending.lock();
343            pending.insert(id, tx);
344        }
345
346        let req = Request {
347            id,
348            command: command.into(),
349            args,
350        };
351
352        // Write the frame under the writer mutex so concurrent calls
353        // don't interleave bytes.
354        {
355            let mut writer = self.writer.lock().await;
356            if let Err(e) = write_request(&mut *writer, &req).await {
357                self.pending.lock().remove(&id);
358                return Err(RpcClientError::from(e));
359            }
360        }
361
362        // Race the response against the reader-task-exited notifier so
363        // a torn connection fails fast instead of hanging.
364        let resp = tokio::select! {
365            recv = rx => match recv {
366                Ok(resp) => resp,
367                Err(_) => return Err(RpcClientError::ConnectionClosed),
368            },
369            _ = self.reader_done.notified() => {
370                self.pending.lock().remove(&id);
371                return Err(RpcClientError::ConnectionClosed);
372            }
373        };
374
375        match resp.result {
376            Ok(value) => Ok(value),
377            Err(message) => Err(RpcClientError::Server(message)),
378        }
379    }
380
381    /// Returns `true` once HELLO has succeeded on this connection.
382    pub fn is_authenticated(&self) -> bool {
383        *self.authenticated.lock()
384    }
385
386    /// Close the connection. Aborts the reader task; in-flight calls
387    /// receive `ConnectionClosed`.
388    pub fn close(mut self) {
389        if let Some(handle) = self.reader_task.take() {
390            handle.abort();
391        }
392    }
393}
394
395impl Drop for RpcClient {
396    fn drop(&mut self) {
397        if let Some(handle) = self.reader_task.take() {
398            handle.abort();
399        }
400    }
401}