Skip to main content

walrus_core/protocol/api/
client.rs

1//! Client trait — transport primitives plus typed provided methods.
2
3use crate::protocol::message::{
4    DownloadEvent, DownloadRequest, HubRequest, MemoryOp, MemoryResult, SendRequest, SendResponse,
5    StreamEvent, StreamRequest, TaskEvent, client::ClientMessage, server::ServerMessage,
6};
7use anyhow::Result;
8use futures_core::Stream;
9use futures_util::StreamExt;
10
11/// Client-side protocol interface.
12///
13/// Implementors provide two transport primitives — [`request`](Client::request)
14/// for request-response and [`request_stream`](Client::request_stream) for
15/// streaming operations. All typed methods are provided defaults that delegate
16/// to these primitives.
17pub trait Client: Send {
18    /// Send a `ClientMessage` and receive a single `ServerMessage`.
19    fn request(
20        &mut self,
21        msg: ClientMessage,
22    ) -> impl std::future::Future<Output = Result<ServerMessage>> + Send;
23
24    /// Send a `ClientMessage` and receive a stream of `ServerMessage`s.
25    ///
26    /// This is a raw transport primitive — the stream reads indefinitely.
27    /// Callers must detect the terminal sentinel (e.g. `StreamEnd`,
28    /// `DownloadEnd`) and stop consuming. The typed streaming methods
29    /// handle this automatically.
30    fn request_stream(
31        &mut self,
32        msg: ClientMessage,
33    ) -> impl Stream<Item = Result<ServerMessage>> + Send + '_;
34
35    /// Send a message to an agent and receive a complete response.
36    fn send(
37        &mut self,
38        req: SendRequest,
39    ) -> impl std::future::Future<Output = Result<SendResponse>> + Send {
40        async move { SendResponse::try_from(self.request(req.into()).await?) }
41    }
42
43    /// Send a message to an agent and receive a streamed response.
44    fn stream(
45        &mut self,
46        req: StreamRequest,
47    ) -> impl Stream<Item = Result<StreamEvent>> + Send + '_ {
48        self.request_stream(req.into())
49            .take_while(|r| {
50                std::future::ready(!matches!(
51                    r,
52                    Ok(ServerMessage::Stream(StreamEvent::End { .. }))
53                ))
54            })
55            .map(|r| r.and_then(StreamEvent::try_from))
56    }
57
58    /// Download a model's files with progress reporting.
59    fn download(
60        &mut self,
61        req: DownloadRequest,
62    ) -> impl Stream<Item = Result<DownloadEvent>> + Send + '_ {
63        self.request_stream(req.into())
64            .take_while(|r| {
65                std::future::ready(!matches!(
66                    r,
67                    Ok(ServerMessage::Download(DownloadEvent::Completed { .. }))
68                ))
69            })
70            .map(|r| r.and_then(DownloadEvent::try_from))
71    }
72
73    /// Install or uninstall a hub package, streaming download events.
74    fn hub(&mut self, req: HubRequest) -> impl Stream<Item = Result<DownloadEvent>> + Send + '_ {
75        self.request_stream(req.into())
76            .take_while(|r| {
77                std::future::ready(!matches!(
78                    r,
79                    Ok(ServerMessage::Download(DownloadEvent::Completed { .. }))
80                ))
81            })
82            .map(|r| r.and_then(DownloadEvent::try_from))
83    }
84
85    /// Ping the server (keepalive).
86    fn ping(&mut self) -> impl std::future::Future<Output = Result<()>> + Send {
87        async move {
88            match self.request(ClientMessage::Ping).await? {
89                ServerMessage::Pong => Ok(()),
90                ServerMessage::Error { code, message } => {
91                    anyhow::bail!("server error ({code}): {message}")
92                }
93                other => anyhow::bail!("unexpected response: {other:?}"),
94            }
95        }
96    }
97
98    /// Subscribe to task lifecycle events.
99    ///
100    /// Streams `TaskEvent`s indefinitely until the connection closes.
101    fn subscribe_tasks(&mut self) -> impl Stream<Item = Result<TaskEvent>> + Send + '_ {
102        self.request_stream(ClientMessage::SubscribeTasks)
103            .map(|r| r.and_then(TaskEvent::try_from))
104    }
105
106    /// Subscribe to download lifecycle events.
107    ///
108    /// Streams `DownloadEvent`s indefinitely until the connection closes.
109    fn subscribe_downloads(&mut self) -> impl Stream<Item = Result<DownloadEvent>> + Send + '_ {
110        self.request_stream(ClientMessage::SubscribeDownloads)
111            .map(|r| r.and_then(DownloadEvent::try_from))
112    }
113
114    /// Get the full daemon config as JSON.
115    fn get_config(&mut self) -> impl std::future::Future<Output = Result<String>> + Send {
116        async move {
117            match self.request(ClientMessage::GetConfig).await? {
118                ServerMessage::Config { config } => Ok(config),
119                ServerMessage::Error { code, message } => {
120                    anyhow::bail!("server error ({code}): {message}")
121                }
122                other => anyhow::bail!("unexpected response: {other:?}"),
123            }
124        }
125    }
126
127    /// Replace the full daemon config from JSON.
128    fn set_config(
129        &mut self,
130        config: String,
131    ) -> impl std::future::Future<Output = Result<()>> + Send {
132        async move {
133            match self.request(ClientMessage::SetConfig { config }).await? {
134                ServerMessage::Pong => Ok(()),
135                ServerMessage::Error { code, message } => {
136                    anyhow::bail!("server error ({code}): {message}")
137                }
138                other => anyhow::bail!("unexpected response: {other:?}"),
139            }
140        }
141    }
142
143    /// Query the memory graph.
144    fn memory_query(
145        &mut self,
146        query: MemoryOp,
147    ) -> impl std::future::Future<Output = Result<MemoryResult>> + Send {
148        async move { MemoryResult::try_from(self.request(ClientMessage::MemoryQuery { query }).await?) }
149    }
150}