stepflow_client/client.rs
1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13//! High-level gRPC client for the Stepflow orchestrator.
14
15use std::collections::HashMap;
16
17use tonic::transport::Channel;
18
19use crate::error::{ClientError, ClientResult};
20use stepflow_flow::workflow::Flow;
21
22use stepflow_proto::{
23 CreateRunRequest, GetRunEventsRequest, GetRunItemsRequest, GetRunRequest, HealthCheckRequest,
24 ListRegisteredComponentsRequest, StoreFlowRequest,
25 components_service_client::ComponentsServiceClient, flows_service_client::FlowsServiceClient,
26 health_service_client::HealthServiceClient, runs_service_client::RunsServiceClient,
27};
28
29// ---------------------------------------------------------------------------
30// Public return types
31// ---------------------------------------------------------------------------
32
33/// Run status returned by [`StepflowClient::get_run`] and [`StepflowClient::run`].
34#[derive(Debug, Clone)]
35pub struct RunStatus {
36 /// The run's unique identifier.
37 pub run_id: String,
38 /// Numeric execution status (see `ExecutionStatus` proto enum).
39 pub status: i32,
40 /// Outputs for each item in the run.
41 ///
42 /// Populated by [`StepflowClient::run`] (synchronous execution).
43 /// Empty when returned by [`StepflowClient::get_run`] — use
44 /// [`StepflowClient::get_run_items`] to fetch outputs for a completed run.
45 pub outputs: Vec<serde_json::Value>,
46}
47
48/// A single registered component returned by [`StepflowClient::list_components`].
49#[derive(Debug, Clone)]
50pub struct ComponentInfo {
51 /// Component path (e.g. `/builtin/openai`, `/python/my_func`).
52 pub component: String,
53 /// Optional human-readable description.
54 pub description: Option<String>,
55 /// JSON Schema for the component's input, if schemas were requested.
56 pub input_schema: Option<serde_json::Value>,
57 /// JSON Schema for the component's output, if schemas were requested.
58 pub output_schema: Option<serde_json::Value>,
59}
60
61/// Result of [`StepflowClient::list_components`].
62#[derive(Debug, Clone)]
63pub struct ListComponentsResult {
64 /// All discovered components, sorted by path.
65 pub components: Vec<ComponentInfo>,
66 /// `true` if all plugins responded successfully.
67 ///
68 /// When `false`, check `failed_plugins` for plugins that could not be
69 /// reached during discovery.
70 pub complete: bool,
71 /// `(plugin_name, error_message)` pairs for plugins that failed discovery.
72 pub failed_plugins: Vec<(String, String)>,
73}
74
75/// A variable definition returned by [`StepflowClient::get_flow_variables`].
76#[derive(Debug, Clone)]
77pub struct FlowVariable {
78 /// Optional human-readable description.
79 pub description: Option<String>,
80 /// Default value for the variable.
81 pub default_value: Option<serde_json::Value>,
82 /// Whether the variable must be provided at run time.
83 pub required: bool,
84 /// JSON Schema for the variable's expected value.
85 pub schema: Option<serde_json::Value>,
86 /// Environment variable that populates this variable, if any.
87 pub env_var: Option<String>,
88}
89
90// ---------------------------------------------------------------------------
91// Type alias for the status event stream
92// ---------------------------------------------------------------------------
93
94/// A streaming response of [`stepflow_proto::StatusEvent`]s from [`StepflowClient::status_events`].
95///
96/// Drive the stream with [`futures::StreamExt::next`] or `while let Some(event) = stream.message().await`.
97pub type StatusEventStream = tonic::codec::Streaming<stepflow_proto::StatusEvent>;
98
99// ---------------------------------------------------------------------------
100// StepflowClient
101// ---------------------------------------------------------------------------
102
103/// High-level client for interacting with the Stepflow orchestrator.
104///
105/// Wraps the gRPC service clients for flows, runs, health, and component
106/// discovery, providing a convenient API for common operations.
107///
108/// # Example
109///
110/// ```rust,no_run
111/// use stepflow_client::{StepflowClient, FlowBuilder, ValueExpr};
112///
113/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
114/// let mut client = StepflowClient::connect("http://localhost:7840").await?;
115///
116/// let mut builder = FlowBuilder::new();
117/// builder.add_step("hello", "/builtin/eval", ValueExpr::null());
118/// let flow = builder.output(ValueExpr::step_output("hello")).build()?;
119///
120/// let flow_id = client.store_flow(&flow).await?;
121/// let output = client.run(&flow_id, serde_json::json!({"name": "world"})).await?;
122/// println!("{output}");
123/// # Ok(())
124/// # }
125/// ```
126pub struct StepflowClient {
127 flows: FlowsServiceClient<Channel>,
128 runs: RunsServiceClient<Channel>,
129 health: HealthServiceClient<Channel>,
130 components: ComponentsServiceClient<Channel>,
131}
132
133impl StepflowClient {
134 /// Connect to the Stepflow orchestrator at the given URL.
135 ///
136 /// The URL should be in the form `http://host:port` (or `https://...` for TLS).
137 pub async fn connect(url: impl Into<String>) -> ClientResult<Self> {
138 let url = url.into();
139 let channel = Channel::from_shared(url.clone())
140 .map_err(|e| ClientError::Connection {
141 url: url.clone(),
142 source: Box::new(e),
143 })?
144 .connect()
145 .await
146 .map_err(|e| ClientError::Connection {
147 url,
148 source: Box::new(e),
149 })?;
150
151 Ok(Self {
152 flows: FlowsServiceClient::new(channel.clone()),
153 runs: RunsServiceClient::new(channel.clone()),
154 health: HealthServiceClient::new(channel.clone()),
155 components: ComponentsServiceClient::new(channel),
156 })
157 }
158
159 /// Store a flow definition in the orchestrator, returning its flow ID.
160 ///
161 /// The returned flow ID can be passed to [`run`](Self::run) or
162 /// [`submit`](Self::submit).
163 pub async fn store_flow(&mut self, flow: &Flow) -> ClientResult<String> {
164 let flow_json = serde_json::to_value(flow)?;
165 let flow_value = json_to_proto_value(flow_json);
166 let flow_struct = match flow_value.kind {
167 Some(prost_wkt_types::value::Kind::StructValue(s)) => s,
168 _ => {
169 return Err(ClientError::InvalidResponse(
170 "Flow JSON must be an object".to_string(),
171 ));
172 }
173 };
174
175 let request = StoreFlowRequest {
176 flow: Some(flow_struct),
177 dry_run: false,
178 };
179 let response = self.flows.store_flow(request).await?.into_inner();
180 Ok(response.flow_id)
181 }
182
183 /// Execute a flow synchronously, blocking until it completes, and return the output.
184 ///
185 /// This is equivalent to `submit` + waiting for the run to complete.
186 pub async fn run(
187 &mut self,
188 flow_id: &str,
189 input: serde_json::Value,
190 ) -> ClientResult<serde_json::Value> {
191 let input_proto = json_to_proto_value(input);
192
193 let request = CreateRunRequest {
194 flow_id: flow_id.to_string(),
195 input: vec![input_proto],
196 wait: true,
197 ..Default::default()
198 };
199 let response = self.runs.create_run(request).await?.into_inner();
200
201 // Extract first item's output from the run results
202 if let Some(item) = response.results.first() {
203 if let Some(output) = &item.output {
204 return Ok(proto_value_to_json(output));
205 }
206 if let Some(msg) = &item.error_message {
207 return Err(ClientError::InvalidResponse(format!("Run failed: {msg}")));
208 }
209 }
210
211 Err(ClientError::InvalidResponse(
212 "Run completed but returned no output".to_string(),
213 ))
214 }
215
216 /// Submit a flow for asynchronous execution, returning the run ID.
217 ///
218 /// Use [`get_run`](Self::get_run) to poll for completion and
219 /// [`get_run_items`](Self::get_run_items) to fetch outputs.
220 pub async fn submit(
221 &mut self,
222 flow_id: &str,
223 input: serde_json::Value,
224 ) -> ClientResult<String> {
225 let input_proto = json_to_proto_value(input);
226
227 let request = CreateRunRequest {
228 flow_id: flow_id.to_string(),
229 input: vec![input_proto],
230 wait: false,
231 ..Default::default()
232 };
233 let response = self.runs.create_run(request).await?.into_inner();
234 Ok(response.summary.map(|s| s.run_id).unwrap_or_default())
235 }
236
237 /// Get the status of a run.
238 ///
239 /// If `wait` is true, the request will block until the run completes (or fails).
240 ///
241 /// Note: outputs are not included in the response — use
242 /// [`get_run_items`](Self::get_run_items) to fetch them after the run completes.
243 pub async fn get_run(&mut self, run_id: &str, wait: bool) -> ClientResult<RunStatus> {
244 let request = GetRunRequest {
245 run_id: run_id.to_string(),
246 wait,
247 timeout_secs: None,
248 };
249
250 let response = self.runs.get_run(request).await?.into_inner();
251 let summary = response.summary.unwrap_or_default();
252
253 Ok(RunStatus {
254 run_id: summary.run_id,
255 status: summary.status,
256 outputs: vec![],
257 })
258 }
259
260 /// Get the output of each item in a completed run.
261 ///
262 /// Returns one `serde_json::Value` per input item, in submission order.
263 /// Errors for individual items are currently surfaced as
264 /// [`ClientError::InvalidResponse`].
265 pub async fn get_run_items(&mut self, run_id: &str) -> ClientResult<Vec<serde_json::Value>> {
266 let request = GetRunItemsRequest {
267 run_id: run_id.to_string(),
268 result_order: 0, // RESULT_ORDER_UNSPECIFIED
269 };
270 let response = self.runs.get_run_items(request).await?.into_inner();
271
272 let mut outputs = Vec::with_capacity(response.results.len());
273 for item in &response.results {
274 if let Some(output) = &item.output {
275 outputs.push(proto_value_to_json(output));
276 } else if let Some(msg) = &item.error_message {
277 return Err(ClientError::InvalidResponse(format!(
278 "Run item failed: {msg}"
279 )));
280 } else {
281 outputs.push(serde_json::Value::Null);
282 }
283 }
284 Ok(outputs)
285 }
286
287 /// List all components registered across all plugins.
288 ///
289 /// Set `exclude_schemas` to `true` to omit JSON Schemas from the response
290 /// (faster when you only need component paths and descriptions).
291 ///
292 /// Note: this triggers on-demand component discovery from all plugins and
293 /// may take a moment if workers haven't connected yet.
294 pub async fn list_components(
295 &mut self,
296 exclude_schemas: bool,
297 ) -> ClientResult<ListComponentsResult> {
298 let request = ListRegisteredComponentsRequest { exclude_schemas };
299 let response = self
300 .components
301 .list_registered_components(request)
302 .await?
303 .into_inner();
304
305 let components = response
306 .components
307 .into_iter()
308 .map(|c| ComponentInfo {
309 component: c.component_id,
310 description: c.description,
311 input_schema: c.input_schema.map(proto_struct_to_json),
312 output_schema: c.output_schema.map(proto_struct_to_json),
313 })
314 .collect();
315
316 let failed_plugins = response
317 .failed_plugins
318 .into_iter()
319 .map(|e| (e.plugin, e.error))
320 .collect();
321
322 Ok(ListComponentsResult {
323 components,
324 complete: response.complete,
325 failed_plugins,
326 })
327 }
328
329 /// Stream execution events for a run.
330 ///
331 /// Returns a server-streaming response that emits [`stepflow_proto::StatusEvent`]s as the
332 /// run progresses. Drive the stream with `stream.message().await`.
333 ///
334 /// Set `include_sub_runs` to also receive events from nested sub-flows.
335 /// Set `include_results` to include step outputs in completion events.
336 ///
337 /// # Example
338 ///
339 /// ```rust,no_run
340 /// # async fn example(mut client: stepflow_client::StepflowClient, run_id: &str) -> Result<(), Box<dyn std::error::Error>> {
341 /// let mut stream = client.status_events(run_id, false, false).await?;
342 /// while let Some(event) = stream.message().await? {
343 /// println!("{event:?}");
344 /// }
345 /// # Ok(())
346 /// # }
347 /// ```
348 pub async fn status_events(
349 &mut self,
350 run_id: &str,
351 include_sub_runs: bool,
352 include_results: bool,
353 ) -> ClientResult<StatusEventStream> {
354 let request = GetRunEventsRequest {
355 run_id: run_id.to_string(),
356 since: None,
357 event_types: vec![],
358 include_sub_runs,
359 include_results,
360 };
361 let stream = self.runs.get_run_events(request).await?.into_inner();
362 Ok(stream)
363 }
364
365 /// Get the variable definitions declared in a flow.
366 ///
367 /// Returns a map of variable name → [`FlowVariable`] describing the schema,
368 /// default value, and optional environment variable mapping for each variable.
369 pub async fn get_flow_variables(
370 &mut self,
371 flow_id: &str,
372 ) -> ClientResult<HashMap<String, FlowVariable>> {
373 use stepflow_proto::GetFlowVariablesRequest;
374
375 let request = GetFlowVariablesRequest {
376 flow_id: flow_id.to_string(),
377 };
378 let response = self.flows.get_flow_variables(request).await?.into_inner();
379
380 let variables = response
381 .variables
382 .into_iter()
383 .map(|(name, v)| {
384 (
385 name,
386 FlowVariable {
387 description: v.description,
388 default_value: v.default_value.as_ref().map(proto_value_to_json),
389 required: v.required,
390 schema: v.schema.map(proto_struct_to_json),
391 env_var: v.env_var,
392 },
393 )
394 })
395 .collect();
396
397 Ok(variables)
398 }
399
400 /// Check whether the orchestrator is healthy.
401 pub async fn is_healthy(&mut self) -> bool {
402 self.health
403 .health_check(HealthCheckRequest {})
404 .await
405 .is_ok()
406 }
407}
408
409// ---------------------------------------------------------------------------
410// Proto ↔ JSON conversion helpers
411// ---------------------------------------------------------------------------
412
413/// Convert a `serde_json::Value` to `prost_wkt_types::Value`.
414pub(crate) fn json_to_proto_value(value: serde_json::Value) -> prost_wkt_types::Value {
415 use prost_wkt_types::value::Kind;
416 prost_wkt_types::Value {
417 kind: Some(match value {
418 serde_json::Value::Null => Kind::NullValue(0),
419 serde_json::Value::Bool(b) => Kind::BoolValue(b),
420 serde_json::Value::Number(n) => Kind::NumberValue(n.as_f64().unwrap_or(0.0)),
421 serde_json::Value::String(s) => Kind::StringValue(s),
422 serde_json::Value::Array(arr) => Kind::ListValue(prost_wkt_types::ListValue {
423 values: arr.into_iter().map(json_to_proto_value).collect(),
424 }),
425 serde_json::Value::Object(obj) => Kind::StructValue(prost_wkt_types::Struct {
426 fields: obj
427 .into_iter()
428 .map(|(k, v)| (k, json_to_proto_value(v)))
429 .collect(),
430 }),
431 }),
432 }
433}
434
435/// Convert a `prost_wkt_types::Value` to `serde_json::Value`, preserving integer
436/// types for whole-number floats (protobuf always uses f64 for numbers).
437pub(crate) fn proto_value_to_json(value: &prost_wkt_types::Value) -> serde_json::Value {
438 use prost_wkt_types::value::Kind;
439 match &value.kind {
440 Some(Kind::NullValue(_)) | None => serde_json::Value::Null,
441 Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
442 Some(Kind::NumberValue(n)) => {
443 let n = *n;
444 if n.is_finite() && n.fract() == 0.0 {
445 let i = n as i64;
446 if i as f64 == n {
447 return serde_json::Value::Number(i.into());
448 }
449 }
450 serde_json::Number::from_f64(n)
451 .map(serde_json::Value::Number)
452 .unwrap_or(serde_json::Value::Null)
453 }
454 Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
455 Some(Kind::StructValue(s)) => {
456 let map = s
457 .fields
458 .iter()
459 .map(|(k, v)| (k.clone(), proto_value_to_json(v)))
460 .collect();
461 serde_json::Value::Object(map)
462 }
463 Some(Kind::ListValue(l)) => {
464 serde_json::Value::Array(l.values.iter().map(proto_value_to_json).collect())
465 }
466 }
467}
468
469/// Convert a `prost_wkt_types::Struct` to a `serde_json::Value::Object`.
470fn proto_struct_to_json(s: prost_wkt_types::Struct) -> serde_json::Value {
471 let map = s
472 .fields
473 .into_iter()
474 .map(|(k, v)| (k, proto_value_to_json(&v)))
475 .collect();
476 serde_json::Value::Object(map)
477}