1pub mod common_proto {
8 #![allow(clippy::all, clippy::absolute_paths, unused_qualifications)]
9 tonic::include_proto!("smg.grpc.common");
10}
11pub mod abort_on_drop;
12pub mod channel;
13pub mod mlx_engine;
14pub mod sglang_scheduler;
15pub mod tokenizer_bundle;
16pub mod tokenspeed_scheduler;
17pub mod trtllm_service;
18pub mod vllm_engine;
19
20use std::sync::Arc;
22
23pub use abort_on_drop::{AbortOnDropClient, AbortOnDropStream};
24pub use channel::{connect_channel, normalize_grpc_endpoint};
25pub use mlx_engine::{proto as mlx_proto, MlxEngineClient};
26pub use sglang_scheduler::{proto as sglang_proto, SglangSchedulerClient};
27pub use tokenspeed_scheduler::{tokenspeed_proto, TokenSpeedSchedulerClient};
28use tonic::metadata::MetadataMap;
29pub use trtllm_service::{proto as trtllm_proto, TrtllmServiceClient};
30pub use vllm_engine::{proto as vllm_proto, VllmEngineClient};
31
/// Expands to a `get_tokenizer` method for an engine client.
///
/// The invoking type needs a cloneable `client` stub exposing a `get_tokenizer` RPC that takes a
/// `common_proto::GetTokenizerRequest`. The RPC's chunks are handed to
/// `tokenizer_bundle::collect_bundle_from_rpc` as `(data, sha256)` pairs together with a
/// 120-second duration.
macro_rules! impl_get_tokenizer {
    () => {
        /// Downloads the engine's tokenizer bundle via the `GetTokenizer` RPC.
        ///
        /// # Errors
        ///
        /// Returns whatever error `collect_bundle_from_rpc` reports (presumably RPC failures,
        /// checksum mismatches or exceeding the 120 s limit — TODO confirm in
        /// `tokenizer_bundle`).
        pub async fn get_tokenizer(
            &self,
        ) -> Result<
            $crate::tokenizer_bundle::StreamBundle,
            Box<dyn std::error::Error + Send + Sync>,
        > {
            let limit = std::time::Duration::from_secs(120);
            let mut stub = self.client.clone();
            let rpc = stub.get_tokenizer(tonic::Request::new(
                $crate::common_proto::GetTokenizerRequest {},
            ));
            // The closure stays inline so its parameter type is inferred from the callee.
            $crate::tokenizer_bundle::collect_bundle_from_rpc(
                rpc,
                |chunk| (chunk.data, chunk.sha256),
                limit,
            )
            .await
        }
    };
}
pub(crate) use impl_get_tokenizer;
59
/// Extra time (45 s) added on top of the caller's `timeout_s` when computing the client-side
/// deadline of a `FlushCache` RPC.
pub const FLUSH_RPC_DEADLINE_MARGIN: std::time::Duration = std::time::Duration::new(45, 0);
65
/// Client-side deadline (10 min 30 s) applied to both the `StartProfile` and `StopProfile` RPCs.
pub const PROFILE_RPC_DEADLINE: std::time::Duration = std::time::Duration::from_secs(10 * 60 + 30);
69
/// Computes the client-side deadline for a `FlushCache` RPC: `timeout_s` seconds plus `margin`.
///
/// Negative and NaN timeouts count as zero (`f32::max` returns the non-NaN operand). An
/// infinite timeout, or a finite one too large for a [`std::time::Duration`], saturates to
/// [`std::time::Duration::MAX`] instead of panicking as `Duration::from_secs_f32` would. Adding
/// `margin` saturates as well, so this function never panics. A saturated deadline effectively
/// disables the client-side timeout.
pub(crate) fn flush_rpc_deadline(
    timeout_s: f32,
    margin: std::time::Duration,
) -> std::time::Duration {
    std::time::Duration::try_from_secs_f32(timeout_s.max(0.0))
        .unwrap_or(std::time::Duration::MAX)
        .saturating_add(margin)
}

/// Expands to the admin RPCs shared by engine clients: `flush_cache`, `start_profile` and
/// `stop_profile`.
///
/// The invoking type needs a cloneable `client` stub exposing those RPCs and a
/// `trace_injector` field implementing `TraceInjector`. Every generated method:
///
/// * injects trace context into the request metadata, logging any injection failure and still
///   sending the request;
/// * bounds the call with `tokio::time::timeout`, mapping expiry to
///   `tonic::Status::deadline_exceeded`.
macro_rules! impl_admin_ops {
    () => {
        /// Asks the engine to flush its cache, forwarding `timeout_s` in the request
        /// (presumably the server's flush time budget — confirm against the server).
        ///
        /// The client waits up to `timeout_s` plus `FLUSH_RPC_DEADLINE_MARGIN`. Negative or NaN
        /// timeouts count as zero. Infinite or oversized ones leave the call effectively
        /// unbounded instead of panicking.
        ///
        /// # Errors
        ///
        /// Returns the RPC's `tonic::Status` on failure, or `deadline_exceeded` if the
        /// client-side deadline elapses first.
        pub async fn flush_cache(
            &self,
            timeout_s: f32,
        ) -> Result<$crate::common_proto::FlushCacheResponse, tonic::Status> {
            tracing::debug!("Requesting cache flush (timeout_s={timeout_s})");
            let mut request =
                tonic::Request::new($crate::common_proto::FlushCacheRequest { timeout_s });
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            // `Duration::from_secs_f32` and `Duration + Duration` panic on infinite or huge
            // inputs; the helper saturates instead so a bad `timeout_s` cannot crash the task.
            let deadline =
                $crate::flush_rpc_deadline(timeout_s, $crate::FLUSH_RPC_DEADLINE_MARGIN);
            let mut client = self.client.clone();
            let response = tokio::time::timeout(deadline, client.flush_cache(request))
                .await
                .map_err(|_| {
                    tonic::Status::deadline_exceeded(format!(
                        "FlushCache did not complete within {deadline:?}"
                    ))
                })??;
            Ok(response.into_inner())
        }

        /// Starts a profiling session on the engine using `req`.
        ///
        /// # Errors
        ///
        /// Returns the RPC's `tonic::Status` on failure, or `deadline_exceeded` if no response
        /// arrives within `PROFILE_RPC_DEADLINE`.
        pub async fn start_profile(
            &self,
            req: $crate::common_proto::StartProfileRequest,
        ) -> Result<$crate::common_proto::ProfileResponse, tonic::Status> {
            tracing::debug!("Requesting profile start");
            let mut request = tonic::Request::new(req);
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let mut client = self.client.clone();
            let response =
                tokio::time::timeout($crate::PROFILE_RPC_DEADLINE, client.start_profile(request))
                    .await
                    .map_err(|_| {
                        tonic::Status::deadline_exceeded(format!(
                            "StartProfile did not complete within {:?}",
                            $crate::PROFILE_RPC_DEADLINE
                        ))
                    })??;
            Ok(response.into_inner())
        }

        /// Stops the engine's profiling session.
        ///
        /// # Errors
        ///
        /// Returns the RPC's `tonic::Status` on failure, or `deadline_exceeded` if no response
        /// arrives within `PROFILE_RPC_DEADLINE`.
        pub async fn stop_profile(
            &self,
        ) -> Result<$crate::common_proto::ProfileResponse, tonic::Status> {
            tracing::debug!("Requesting profile stop");
            let mut request = tonic::Request::new($crate::common_proto::StopProfileRequest {});
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let mut client = self.client.clone();
            let response =
                tokio::time::timeout($crate::PROFILE_RPC_DEADLINE, client.stop_profile(request))
                    .await
                    .map_err(|_| {
                        tonic::Status::deadline_exceeded(format!(
                            "StopProfile did not complete within {:?}",
                            $crate::PROFILE_RPC_DEADLINE
                        ))
                    })??;
            Ok(response.into_inner())
        }
    };
}
pub(crate) use impl_admin_ops;
153
/// Expands to a `subscribe_kv_events` method for an engine client.
///
/// The invoking type needs a cloneable `client` stub exposing a `subscribe_kv_events` RPC that
/// takes a `common_proto::SubscribeKvEventsRequest` and streams `common_proto::KvEventBatch`.
macro_rules! impl_subscribe_kv_events {
    () => {
        /// Opens the engine's KV event stream, sending `start_sequence_number` in the request
        /// (presumably the first sequence number the caller wants — confirm server semantics).
        ///
        /// No trace context is injected and no client-side timeout is applied by this method.
        ///
        /// # Errors
        ///
        /// Returns the `tonic::Status` of a failed subscription call.
        pub async fn subscribe_kv_events(
            &self,
            start_sequence_number: u64,
        ) -> Result<tonic::Streaming<$crate::common_proto::KvEventBatch>, tonic::Status> {
            let mut stub = self.client.clone();
            let subscription = $crate::common_proto::SubscribeKvEventsRequest {
                start_sequence_number,
            };
            stub.subscribe_kv_events(tonic::Request::new(subscription))
                .await
                .map(tonic::Response::into_inner)
        }
    };
}
pub(crate) use impl_subscribe_kv_events;
177
178pub trait TraceInjector: Send + Sync {
183 fn inject(
187 &self,
188 metadata: &mut MetadataMap,
189 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
190}
191
192#[derive(Clone, Default)]
194pub struct NoopTraceInjector;
195
196impl TraceInjector for NoopTraceInjector {
197 fn inject(
198 &self,
199 _metadata: &mut MetadataMap,
200 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
201 Ok(())
202 }
203}
204
/// Shared, type-erased [`TraceInjector`] handle.
///
/// This is the type of the `trace_injector` field populated by the constructors generated by
/// `impl_engine_client_basics`. Cloning it clones the `Arc`, not the injector. Because the trait
/// requires `Send + Sync`, the handle can be shared across threads.
pub type BoxedTraceInjector = Arc<dyn TraceInjector>;
207
/// Expands to the constructors and basic read-only RPCs shared by engine clients.
///
/// Macro parameters:
///
/// * `$proto_client` — the tonic-generated client type. It is built with
///   `<$proto_client>::new(channel)` and stored in the `client` field.
/// * `$display_name` — an engine name used only in the connection log line.
///
/// Requirements on the invoking code:
///
/// * The invoking type must be a struct with exactly the fields `client` and `trace_injector`
///   (a `BoxedTraceInjector`).
/// * A module named `proto` providing the `HealthCheck*`, `GetModelInfo*` and `GetServerInfo*`
///   messages must be in scope at the invocation site, because `macro_rules!` resolves item
///   paths there.
macro_rules! impl_engine_client_basics {
    ($proto_client:path, $display_name:literal) => {
        /// Connects to `endpoint` using a `NoopTraceInjector` (no trace propagation).
        ///
        /// # Errors
        ///
        /// Propagates any error from establishing the gRPC channel.
        pub async fn connect(
            endpoint: &str,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            let noop: $crate::BoxedTraceInjector =
                std::sync::Arc::new($crate::NoopTraceInjector);
            Self::connect_with_trace_injector(endpoint, noop).await
        }

        /// Connects to `endpoint` via `channel::connect_channel` and stores `trace_injector`
        /// on the resulting client.
        ///
        /// # Errors
        ///
        /// Propagates any error returned by `connect_channel`.
        pub async fn connect_with_trace_injector(
            endpoint: &str,
            trace_injector: $crate::BoxedTraceInjector,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            tracing::debug!(
                "Connecting to {} gRPC server at {}",
                $display_name,
                endpoint
            );
            let channel = $crate::channel::connect_channel(endpoint).await?;
            Ok(Self {
                trace_injector,
                client: <$proto_client>::new(channel),
            })
        }

        /// Returns this client with its trace injector replaced by `trace_injector`.
        #[must_use]
        pub fn with_trace_injector(mut self, trace_injector: $crate::BoxedTraceInjector) -> Self {
            self.trace_injector = trace_injector;
            self
        }

        /// Calls the engine's `HealthCheck` RPC.
        ///
        /// This method adds no trace context and no timeout of its own.
        ///
        /// # Errors
        ///
        /// Returns the `tonic::Status` of a failed call.
        pub async fn health_check(&self) -> Result<proto::HealthCheckResponse, tonic::Status> {
            tracing::debug!("Sending health check request");
            let mut stub = self.client.clone();
            let body = stub
                .health_check(tonic::Request::new(proto::HealthCheckRequest {}))
                .await?
                .into_inner();
            tracing::debug!("Health check response received");
            Ok(body)
        }

        /// Calls the engine's `GetModelInfo` RPC.
        ///
        /// This method adds no trace context and no timeout of its own.
        ///
        /// # Errors
        ///
        /// Returns the `tonic::Status` of a failed call.
        pub async fn get_model_info(&self) -> Result<proto::GetModelInfoResponse, tonic::Status> {
            tracing::debug!("Requesting model info");
            let mut stub = self.client.clone();
            let body = stub
                .get_model_info(tonic::Request::new(proto::GetModelInfoRequest {}))
                .await?
                .into_inner();
            tracing::debug!("Model info response received");
            Ok(body)
        }

        /// Calls the engine's `GetServerInfo` RPC.
        ///
        /// This method adds no trace context and no timeout of its own.
        ///
        /// # Errors
        ///
        /// Returns the `tonic::Status` of a failed call.
        pub async fn get_server_info(&self) -> Result<proto::GetServerInfoResponse, tonic::Status> {
            tracing::debug!("Requesting server info");
            let mut stub = self.client.clone();
            let body = stub
                .get_server_info(tonic::Request::new(proto::GetServerInfoRequest {}))
                .await?
                .into_inner();
            tracing::debug!("Server info response received");
            Ok(body)
        }
    };
}
pub(crate) use impl_engine_client_basics;