Skip to main content

stepflow_client/
client.rs

1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13//! High-level gRPC client for the Stepflow orchestrator.
14
15use std::collections::HashMap;
16
17use tonic::transport::Channel;
18
19use crate::error::{ClientError, ClientResult};
20use stepflow_flow::workflow::Flow;
21
22use stepflow_proto::{
23    CreateRunRequest, GetRunEventsRequest, GetRunItemsRequest, GetRunRequest, HealthCheckRequest,
24    ListRegisteredComponentsRequest, StoreFlowRequest,
25    components_service_client::ComponentsServiceClient, flows_service_client::FlowsServiceClient,
26    health_service_client::HealthServiceClient, runs_service_client::RunsServiceClient,
27};
28
29// ---------------------------------------------------------------------------
30// Public return types
31// ---------------------------------------------------------------------------
32
33/// Run status returned by [`StepflowClient::get_run`] and [`StepflowClient::run`].
34#[derive(Debug, Clone)]
35pub struct RunStatus {
36    /// The run's unique identifier.
37    pub run_id: String,
38    /// Numeric execution status (see `ExecutionStatus` proto enum).
39    pub status: i32,
40    /// Outputs for each item in the run.
41    ///
42    /// Populated by [`StepflowClient::run`] (synchronous execution).
43    /// Empty when returned by [`StepflowClient::get_run`] — use
44    /// [`StepflowClient::get_run_items`] to fetch outputs for a completed run.
45    pub outputs: Vec<serde_json::Value>,
46}
47
48/// A single registered component returned by [`StepflowClient::list_components`].
49#[derive(Debug, Clone)]
50pub struct ComponentInfo {
51    /// Component path (e.g. `/builtin/openai`, `/python/my_func`).
52    pub component: String,
53    /// Optional human-readable description.
54    pub description: Option<String>,
55    /// JSON Schema for the component's input, if schemas were requested.
56    pub input_schema: Option<serde_json::Value>,
57    /// JSON Schema for the component's output, if schemas were requested.
58    pub output_schema: Option<serde_json::Value>,
59}
60
61/// Result of [`StepflowClient::list_components`].
62#[derive(Debug, Clone)]
63pub struct ListComponentsResult {
64    /// All discovered components, sorted by path.
65    pub components: Vec<ComponentInfo>,
66    /// `true` if all plugins responded successfully.
67    ///
68    /// When `false`, check `failed_plugins` for plugins that could not be
69    /// reached during discovery.
70    pub complete: bool,
71    /// `(plugin_name, error_message)` pairs for plugins that failed discovery.
72    pub failed_plugins: Vec<(String, String)>,
73}
74
75/// A variable definition returned by [`StepflowClient::get_flow_variables`].
76#[derive(Debug, Clone)]
77pub struct FlowVariable {
78    /// Optional human-readable description.
79    pub description: Option<String>,
80    /// Default value for the variable.
81    pub default_value: Option<serde_json::Value>,
82    /// Whether the variable must be provided at run time.
83    pub required: bool,
84    /// JSON Schema for the variable's expected value.
85    pub schema: Option<serde_json::Value>,
86    /// Environment variable that populates this variable, if any.
87    pub env_var: Option<String>,
88}
89
90// ---------------------------------------------------------------------------
91// Type alias for the status event stream
92// ---------------------------------------------------------------------------
93
94/// A streaming response of [`stepflow_proto::StatusEvent`]s from [`StepflowClient::status_events`].
95///
96/// Drive the stream with [`futures::StreamExt::next`] or `while let Some(event) = stream.message().await`.
97pub type StatusEventStream = tonic::codec::Streaming<stepflow_proto::StatusEvent>;
98
99// ---------------------------------------------------------------------------
100// StepflowClient
101// ---------------------------------------------------------------------------
102
103/// High-level client for interacting with the Stepflow orchestrator.
104///
105/// Wraps the gRPC service clients for flows, runs, health, and component
106/// discovery, providing a convenient API for common operations.
107///
108/// # Example
109///
110/// ```rust,no_run
111/// use stepflow_client::{StepflowClient, FlowBuilder, ValueExpr};
112///
113/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
114/// let mut client = StepflowClient::connect("http://localhost:7840").await?;
115///
116/// let mut builder = FlowBuilder::new();
117/// builder.add_step("hello", "/builtin/eval", ValueExpr::null());
118/// let flow = builder.output(ValueExpr::step_output("hello")).build()?;
119///
120/// let flow_id = client.store_flow(&flow).await?;
121/// let output = client.run(&flow_id, serde_json::json!({"name": "world"})).await?;
122/// println!("{output}");
123/// # Ok(())
124/// # }
125/// ```
126pub struct StepflowClient {
127    flows: FlowsServiceClient<Channel>,
128    runs: RunsServiceClient<Channel>,
129    health: HealthServiceClient<Channel>,
130    components: ComponentsServiceClient<Channel>,
131}
132
133impl StepflowClient {
134    /// Connect to the Stepflow orchestrator at the given URL.
135    ///
136    /// The URL should be in the form `http://host:port` (or `https://...` for TLS).
137    pub async fn connect(url: impl Into<String>) -> ClientResult<Self> {
138        let url = url.into();
139        let channel = Channel::from_shared(url.clone())
140            .map_err(|e| ClientError::Connection {
141                url: url.clone(),
142                source: Box::new(e),
143            })?
144            .connect()
145            .await
146            .map_err(|e| ClientError::Connection {
147                url,
148                source: Box::new(e),
149            })?;
150
151        Ok(Self {
152            flows: FlowsServiceClient::new(channel.clone()),
153            runs: RunsServiceClient::new(channel.clone()),
154            health: HealthServiceClient::new(channel.clone()),
155            components: ComponentsServiceClient::new(channel),
156        })
157    }
158
159    /// Store a flow definition in the orchestrator, returning its flow ID.
160    ///
161    /// The returned flow ID can be passed to [`run`](Self::run) or
162    /// [`submit`](Self::submit).
163    pub async fn store_flow(&mut self, flow: &Flow) -> ClientResult<String> {
164        let flow_json = serde_json::to_value(flow)?;
165        let flow_value = json_to_proto_value(flow_json);
166        let flow_struct = match flow_value.kind {
167            Some(prost_wkt_types::value::Kind::StructValue(s)) => s,
168            _ => {
169                return Err(ClientError::InvalidResponse(
170                    "Flow JSON must be an object".to_string(),
171                ));
172            }
173        };
174
175        let request = StoreFlowRequest {
176            flow: Some(flow_struct),
177            dry_run: false,
178        };
179        let response = self.flows.store_flow(request).await?.into_inner();
180        Ok(response.flow_id)
181    }
182
183    /// Execute a flow synchronously, blocking until it completes, and return the output.
184    ///
185    /// This is equivalent to `submit` + waiting for the run to complete.
186    pub async fn run(
187        &mut self,
188        flow_id: &str,
189        input: serde_json::Value,
190    ) -> ClientResult<serde_json::Value> {
191        let input_proto = json_to_proto_value(input);
192
193        let request = CreateRunRequest {
194            flow_id: flow_id.to_string(),
195            input: vec![input_proto],
196            wait: true,
197            ..Default::default()
198        };
199        let response = self.runs.create_run(request).await?.into_inner();
200
201        // Extract first item's output from the run results
202        if let Some(item) = response.results.first() {
203            if let Some(output) = &item.output {
204                return Ok(proto_value_to_json(output));
205            }
206            if let Some(msg) = &item.error_message {
207                return Err(ClientError::InvalidResponse(format!("Run failed: {msg}")));
208            }
209        }
210
211        Err(ClientError::InvalidResponse(
212            "Run completed but returned no output".to_string(),
213        ))
214    }
215
216    /// Submit a flow for asynchronous execution, returning the run ID.
217    ///
218    /// Use [`get_run`](Self::get_run) to poll for completion and
219    /// [`get_run_items`](Self::get_run_items) to fetch outputs.
220    pub async fn submit(
221        &mut self,
222        flow_id: &str,
223        input: serde_json::Value,
224    ) -> ClientResult<String> {
225        let input_proto = json_to_proto_value(input);
226
227        let request = CreateRunRequest {
228            flow_id: flow_id.to_string(),
229            input: vec![input_proto],
230            wait: false,
231            ..Default::default()
232        };
233        let response = self.runs.create_run(request).await?.into_inner();
234        Ok(response.summary.map(|s| s.run_id).unwrap_or_default())
235    }
236
237    /// Get the status of a run.
238    ///
239    /// If `wait` is true, the request will block until the run completes (or fails).
240    ///
241    /// Note: outputs are not included in the response — use
242    /// [`get_run_items`](Self::get_run_items) to fetch them after the run completes.
243    pub async fn get_run(&mut self, run_id: &str, wait: bool) -> ClientResult<RunStatus> {
244        let request = GetRunRequest {
245            run_id: run_id.to_string(),
246            wait,
247            timeout_secs: None,
248        };
249
250        let response = self.runs.get_run(request).await?.into_inner();
251        let summary = response.summary.unwrap_or_default();
252
253        Ok(RunStatus {
254            run_id: summary.run_id,
255            status: summary.status,
256            outputs: vec![],
257        })
258    }
259
260    /// Get the output of each item in a completed run.
261    ///
262    /// Returns one `serde_json::Value` per input item, in submission order.
263    /// Errors for individual items are currently surfaced as
264    /// [`ClientError::InvalidResponse`].
265    pub async fn get_run_items(&mut self, run_id: &str) -> ClientResult<Vec<serde_json::Value>> {
266        let request = GetRunItemsRequest {
267            run_id: run_id.to_string(),
268            result_order: 0, // RESULT_ORDER_UNSPECIFIED
269        };
270        let response = self.runs.get_run_items(request).await?.into_inner();
271
272        let mut outputs = Vec::with_capacity(response.results.len());
273        for item in &response.results {
274            if let Some(output) = &item.output {
275                outputs.push(proto_value_to_json(output));
276            } else if let Some(msg) = &item.error_message {
277                return Err(ClientError::InvalidResponse(format!(
278                    "Run item failed: {msg}"
279                )));
280            } else {
281                outputs.push(serde_json::Value::Null);
282            }
283        }
284        Ok(outputs)
285    }
286
287    /// List all components registered across all plugins.
288    ///
289    /// Set `exclude_schemas` to `true` to omit JSON Schemas from the response
290    /// (faster when you only need component paths and descriptions).
291    ///
292    /// Note: this triggers on-demand component discovery from all plugins and
293    /// may take a moment if workers haven't connected yet.
294    pub async fn list_components(
295        &mut self,
296        exclude_schemas: bool,
297    ) -> ClientResult<ListComponentsResult> {
298        let request = ListRegisteredComponentsRequest { exclude_schemas };
299        let response = self
300            .components
301            .list_registered_components(request)
302            .await?
303            .into_inner();
304
305        let components = response
306            .components
307            .into_iter()
308            .map(|c| ComponentInfo {
309                component: c.component_id,
310                description: c.description,
311                input_schema: c.input_schema.map(proto_struct_to_json),
312                output_schema: c.output_schema.map(proto_struct_to_json),
313            })
314            .collect();
315
316        let failed_plugins = response
317            .failed_plugins
318            .into_iter()
319            .map(|e| (e.plugin, e.error))
320            .collect();
321
322        Ok(ListComponentsResult {
323            components,
324            complete: response.complete,
325            failed_plugins,
326        })
327    }
328
329    /// Stream execution events for a run.
330    ///
331    /// Returns a server-streaming response that emits [`stepflow_proto::StatusEvent`]s as the
332    /// run progresses.  Drive the stream with `stream.message().await`.
333    ///
334    /// Set `include_sub_runs` to also receive events from nested sub-flows.
335    /// Set `include_results` to include step outputs in completion events.
336    ///
337    /// # Example
338    ///
339    /// ```rust,no_run
340    /// # async fn example(mut client: stepflow_client::StepflowClient, run_id: &str) -> Result<(), Box<dyn std::error::Error>> {
341    /// let mut stream = client.status_events(run_id, false, false).await?;
342    /// while let Some(event) = stream.message().await? {
343    ///     println!("{event:?}");
344    /// }
345    /// # Ok(())
346    /// # }
347    /// ```
348    pub async fn status_events(
349        &mut self,
350        run_id: &str,
351        include_sub_runs: bool,
352        include_results: bool,
353    ) -> ClientResult<StatusEventStream> {
354        let request = GetRunEventsRequest {
355            run_id: run_id.to_string(),
356            since: None,
357            event_types: vec![],
358            include_sub_runs,
359            include_results,
360        };
361        let stream = self.runs.get_run_events(request).await?.into_inner();
362        Ok(stream)
363    }
364
365    /// Get the variable definitions declared in a flow.
366    ///
367    /// Returns a map of variable name → [`FlowVariable`] describing the schema,
368    /// default value, and optional environment variable mapping for each variable.
369    pub async fn get_flow_variables(
370        &mut self,
371        flow_id: &str,
372    ) -> ClientResult<HashMap<String, FlowVariable>> {
373        use stepflow_proto::GetFlowVariablesRequest;
374
375        let request = GetFlowVariablesRequest {
376            flow_id: flow_id.to_string(),
377        };
378        let response = self.flows.get_flow_variables(request).await?.into_inner();
379
380        let variables = response
381            .variables
382            .into_iter()
383            .map(|(name, v)| {
384                (
385                    name,
386                    FlowVariable {
387                        description: v.description,
388                        default_value: v.default_value.as_ref().map(proto_value_to_json),
389                        required: v.required,
390                        schema: v.schema.map(proto_struct_to_json),
391                        env_var: v.env_var,
392                    },
393                )
394            })
395            .collect();
396
397        Ok(variables)
398    }
399
400    /// Check whether the orchestrator is healthy.
401    pub async fn is_healthy(&mut self) -> bool {
402        self.health
403            .health_check(HealthCheckRequest {})
404            .await
405            .is_ok()
406    }
407}
408
409// ---------------------------------------------------------------------------
410// Proto ↔ JSON conversion helpers
411// ---------------------------------------------------------------------------
412
413/// Convert a `serde_json::Value` to `prost_wkt_types::Value`.
414pub(crate) fn json_to_proto_value(value: serde_json::Value) -> prost_wkt_types::Value {
415    use prost_wkt_types::value::Kind;
416    prost_wkt_types::Value {
417        kind: Some(match value {
418            serde_json::Value::Null => Kind::NullValue(0),
419            serde_json::Value::Bool(b) => Kind::BoolValue(b),
420            serde_json::Value::Number(n) => Kind::NumberValue(n.as_f64().unwrap_or(0.0)),
421            serde_json::Value::String(s) => Kind::StringValue(s),
422            serde_json::Value::Array(arr) => Kind::ListValue(prost_wkt_types::ListValue {
423                values: arr.into_iter().map(json_to_proto_value).collect(),
424            }),
425            serde_json::Value::Object(obj) => Kind::StructValue(prost_wkt_types::Struct {
426                fields: obj
427                    .into_iter()
428                    .map(|(k, v)| (k, json_to_proto_value(v)))
429                    .collect(),
430            }),
431        }),
432    }
433}
434
435/// Convert a `prost_wkt_types::Value` to `serde_json::Value`, preserving integer
436/// types for whole-number floats (protobuf always uses f64 for numbers).
437pub(crate) fn proto_value_to_json(value: &prost_wkt_types::Value) -> serde_json::Value {
438    use prost_wkt_types::value::Kind;
439    match &value.kind {
440        Some(Kind::NullValue(_)) | None => serde_json::Value::Null,
441        Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
442        Some(Kind::NumberValue(n)) => {
443            let n = *n;
444            if n.is_finite() && n.fract() == 0.0 {
445                let i = n as i64;
446                if i as f64 == n {
447                    return serde_json::Value::Number(i.into());
448                }
449            }
450            serde_json::Number::from_f64(n)
451                .map(serde_json::Value::Number)
452                .unwrap_or(serde_json::Value::Null)
453        }
454        Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
455        Some(Kind::StructValue(s)) => {
456            let map = s
457                .fields
458                .iter()
459                .map(|(k, v)| (k.clone(), proto_value_to_json(v)))
460                .collect();
461            serde_json::Value::Object(map)
462        }
463        Some(Kind::ListValue(l)) => {
464            serde_json::Value::Array(l.values.iter().map(proto_value_to_json).collect())
465        }
466    }
467}
468
469/// Convert a `prost_wkt_types::Struct` to a `serde_json::Value::Object`.
470fn proto_struct_to_json(s: prost_wkt_types::Struct) -> serde_json::Value {
471    let map = s
472        .fields
473        .into_iter()
474        .map(|(k, v)| (k, proto_value_to_json(&v)))
475        .collect();
476    serde_json::Value::Object(map)
477}