vtcode_core/llm/providers/
stepfun.rs1use async_stream::try_stream;
2use async_trait::async_trait;
3use reqwest::Client as HttpClient;
4use serde_json::{Map, Value};
5
6use crate::config::TimeoutsConfig;
7use crate::config::constants::{env_vars, models, urls};
8use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
9use crate::config::types::ReasoningEffortLevel;
10use crate::llm::error_display;
11use crate::llm::provider::{
12 LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
13};
14
15use super::common::{
16 ensure_model, impl_llm_client, map_finish_reason_common, override_base_url,
17 parse_json_response, parse_response_openai_format, resolve_model,
18 serialize_messages_openai_format, serialize_tools_openai_format, validate_supported_models,
19};
20use super::error_handling::handle_openai_http_error;
21use super::extract_reasoning_trace;
22
/// Lowercase provider identifier: returned by `name()` and passed as the
/// provider key when serializing messages and tool choices.
const PROVIDER_KEY: &str = "stepfun";
/// Display name passed to the error-formatting and response-parsing helpers.
const PROVIDER_NAME: &str = "StepFun";

/// Preferred environment variable holding the StepFun API key.
const PRIMARY_API_KEY_ENV: &str = "STEPFUN_API_KEY";
/// Older environment variable name, consulted only after `PRIMARY_API_KEY_ENV`.
const LEGACY_API_KEY_ENV: &str = "STEP_API_KEY";
27
28pub struct StepFunProvider {
29 api_key: String,
30 http_client: HttpClient,
31 base_url: String,
32 model: String,
33 model_behavior: Option<ModelConfig>,
34}
35
36impl StepFunProvider {
37 pub fn new(api_key: String) -> Self {
38 Self::with_model_internal(
39 api_key,
40 models::stepfun::DEFAULT_MODEL.to_string(),
41 None,
42 None,
43 None,
44 )
45 }
46
47 pub fn with_model(api_key: String, model: String) -> Self {
48 Self::with_model_internal(api_key, model, None, None, None)
49 }
50
51 pub fn new_with_client(
52 api_key: String,
53 model: String,
54 http_client: reqwest::Client,
55 base_url: String,
56 _timeouts: TimeoutsConfig,
57 ) -> Self {
58 Self {
59 api_key,
60 http_client,
61 base_url,
62 model,
63 model_behavior: None,
64 }
65 }
66
67 pub fn from_config(
68 api_key: Option<String>,
69 model: Option<String>,
70 base_url: Option<String>,
71 _prompt_cache: Option<PromptCachingConfig>,
72 timeouts: Option<TimeoutsConfig>,
73 _anthropic: Option<AnthropicConfig>,
74 model_behavior: Option<ModelConfig>,
75 ) -> Self {
76 let api_key_value = api_key
77 .filter(|key| !key.trim().is_empty())
78 .or_else(|| std::env::var(PRIMARY_API_KEY_ENV).ok())
79 .or_else(|| std::env::var(LEGACY_API_KEY_ENV).ok())
80 .unwrap_or_default();
81
82 Self::with_model_internal(
83 api_key_value,
84 resolve_model(model, models::stepfun::DEFAULT_MODEL),
85 base_url,
86 timeouts,
87 model_behavior,
88 )
89 }
90
91 fn with_model_internal(
92 api_key: String,
93 model: String,
94 base_url: Option<String>,
95 timeouts: Option<TimeoutsConfig>,
96 model_behavior: Option<ModelConfig>,
97 ) -> Self {
98 use crate::llm::http_client::HttpClientFactory;
99
100 let timeouts = timeouts.unwrap_or_default();
101
102 Self {
103 api_key,
104 http_client: HttpClientFactory::for_llm(&timeouts),
105 base_url: override_base_url(
106 urls::STEPFUN_API_BASE,
107 base_url,
108 Some(env_vars::STEPFUN_BASE_URL),
109 ),
110 model,
111 model_behavior,
112 }
113 }
114
115 fn float_to_json_number(value: f32) -> Result<serde_json::Number, LLMError> {
116 serde_json::Number::from_f64(value as f64).ok_or_else(|| LLMError::InvalidRequest {
117 message: "invalid numeric parameter value (NaN or infinity)".to_string(),
118 metadata: None,
119 })
120 }
121
122 fn reasoning_effort_value(effort: ReasoningEffortLevel) -> Option<&'static str> {
123 match effort {
124 ReasoningEffortLevel::None => None,
125 ReasoningEffortLevel::Minimal | ReasoningEffortLevel::Low => Some("low"),
126 ReasoningEffortLevel::Medium => Some("medium"),
127 ReasoningEffortLevel::High
128 | ReasoningEffortLevel::XHigh
129 | ReasoningEffortLevel::Max => Some("high"),
130 }
131 }
132
133 fn is_reasoning_enabled(request: &LLMRequest) -> bool {
134 request
135 .reasoning_effort
136 .is_some_and(|effort| effort != ReasoningEffortLevel::None)
137 }
138
139 fn convert_to_stepfun_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
140 let mut payload = Map::with_capacity(10);
141 payload.insert("model".to_owned(), Value::String(request.model.clone()));
142
143 let mut messages = serialize_messages_openai_format(request, PROVIDER_KEY)?;
144 if let Some(system_prompt) = &request.system_prompt {
145 let trimmed = system_prompt.trim();
146 if !trimmed.is_empty() {
147 messages.insert(
148 0,
149 serde_json::json!({ "role": "system", "content": trimmed }),
150 );
151 }
152 }
153 payload.insert("messages".to_owned(), Value::Array(messages));
154
155 if let Some(max_tokens) = request.max_tokens {
156 payload.insert(
157 "max_tokens".to_owned(),
158 Value::Number(serde_json::Number::from(max_tokens as u64)),
159 );
160 }
161
162 if !Self::is_reasoning_enabled(request) {
163 if let Some(temperature) = request.temperature {
164 payload.insert(
165 "temperature".to_owned(),
166 Value::Number(Self::float_to_json_number(temperature)?),
167 );
168 }
169
170 if let Some(top_p) = request.top_p {
171 payload.insert(
172 "top_p".to_owned(),
173 Value::Number(Self::float_to_json_number(top_p)?),
174 );
175 }
176 }
177
178 if request.stream {
179 payload.insert("stream".to_owned(), Value::Bool(true));
180 }
181
182 if let Some(tools) = &request.tools
183 && let Some(serialized_tools) = serialize_tools_openai_format(tools)
184 {
185 payload.insert("tools".to_owned(), Value::Array(serialized_tools));
186 }
187
188 if let Some(choice) = &request.tool_choice {
189 payload.insert(
190 "tool_choice".to_owned(),
191 choice.to_provider_format(PROVIDER_KEY),
192 );
193 }
194
195 if let Some(effort) = request.reasoning_effort
196 && let Some(mapped) = Self::reasoning_effort_value(effort)
197 {
198 payload.insert(
199 "reasoning_effort".to_owned(),
200 Value::String(mapped.to_string()),
201 );
202 }
203
204 Ok(Value::Object(payload))
205 }
206}
207
208#[async_trait]
209impl LLMProvider for StepFunProvider {
210 fn name(&self) -> &str {
211 PROVIDER_KEY
212 }
213
214 fn supports_streaming(&self) -> bool {
215 true
216 }
217
218 fn supports_tools(&self, _model: &str) -> bool {
219 true
220 }
221
222 fn supports_structured_output(&self, _model: &str) -> bool {
223 true
224 }
225
226 fn supports_vision(&self, _model: &str) -> bool {
227 true
228 }
229
230 fn supports_reasoning(&self, model: &str) -> bool {
231 let requested = if model.trim().is_empty() {
232 &self.model
233 } else {
234 model
235 };
236
237 self.model_behavior
238 .as_ref()
239 .and_then(|behavior| behavior.model_supports_reasoning)
240 .unwrap_or(false)
241 || models::stepfun::REASONING_MODELS.contains(&requested)
242 }
243
244 fn supports_reasoning_effort(&self, model: &str) -> bool {
245 let requested = if model.trim().is_empty() {
246 &self.model
247 } else {
248 model
249 };
250
251 self.model_behavior
252 .as_ref()
253 .and_then(|behavior| behavior.model_supports_reasoning_effort)
254 .unwrap_or(false)
255 || models::stepfun::REASONING_MODELS.contains(&requested)
256 }
257
258 fn effective_context_size(&self, model: &str) -> usize {
259 let requested = if model.trim().is_empty() {
260 &self.model
261 } else {
262 model
263 };
264
265 match requested {
266 models::stepfun::STEP_3_7_FLASH => 262_144,
267 _ => 262_144,
268 }
269 }
270
271 async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
272 let model = ensure_model(&mut request, &self.model);
273
274 let payload = self.convert_to_stepfun_format(&request)?;
275 let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
276
277 let response = self
278 .http_client
279 .post(&url)
280 .bearer_auth(&self.api_key)
281 .json(&payload)
282 .send()
283 .await
284 .map_err(|error| LLMError::Network {
285 message: error_display::format_llm_error(
286 PROVIDER_NAME,
287 &format!("network error: {error}"),
288 ),
289 metadata: None,
290 })?;
291
292 let response =
293 handle_openai_http_error(response, PROVIDER_NAME, PRIMARY_API_KEY_ENV).await?;
294 let response_json = parse_json_response(response, PROVIDER_NAME).await?;
295
296 let reasoning_extractor = |message: &Value, choice: &Value| {
297 message
298 .get("reasoning")
299 .and_then(extract_reasoning_trace)
300 .or_else(|| choice.get("reasoning").and_then(extract_reasoning_trace))
301 };
302
303 parse_response_openai_format(
304 response_json,
305 PROVIDER_NAME,
306 model,
307 false,
308 Some(reasoning_extractor),
309 )
310 }
311
312 async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
313 ensure_model(&mut request, &self.model);
314 self.validate_request(&request)?;
315 request.stream = true;
316 let model = request.model.clone();
317
318 let payload = self.convert_to_stepfun_format(&request)?;
319 let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
320
321 let response = self
322 .http_client
323 .post(&url)
324 .bearer_auth(&self.api_key)
325 .json(&payload)
326 .send()
327 .await
328 .map_err(|error| LLMError::Network {
329 message: error_display::format_llm_error(
330 PROVIDER_NAME,
331 &format!("network error: {error}"),
332 ),
333 metadata: None,
334 })?;
335
336 let response =
337 handle_openai_http_error(response, PROVIDER_NAME, PRIMARY_API_KEY_ENV).await?;
338
339 let bytes_stream = response.bytes_stream();
340 let (event_tx, event_rx) =
341 tokio::sync::mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
342 let tx = event_tx.clone();
343
344 let model_clone = model.clone();
345 tokio::spawn(async move {
346 let mut aggregator =
347 crate::llm::providers::shared::StreamAggregator::new(model_clone.clone());
348
349 let result = crate::llm::providers::shared::process_openai_stream(
350 bytes_stream,
351 PROVIDER_NAME,
352 model_clone,
353 |value| {
354 if let Some(choices) =
355 value.get("choices").and_then(|choices| choices.as_array())
356 && let Some(choice) = choices.first()
357 {
358 if let Some(delta) = choice.get("delta") {
359 if let Some(reasoning) = delta.get("reasoning").and_then(|v| v.as_str())
360 && let Some(delta) = aggregator.handle_reasoning(reasoning)
361 {
362 let _ = tx.send(Ok(LLMStreamEvent::Reasoning { delta }));
363 }
364
365 if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
366 for event in aggregator.handle_content(content) {
367 let _ = tx.send(Ok(event));
368 }
369 }
370
371 if let Some(tool_calls) =
372 delta.get("tool_calls").and_then(|calls| calls.as_array())
373 {
374 aggregator.handle_tool_calls(tool_calls);
375 }
376 }
377
378 if let Some(reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
379 aggregator.set_finish_reason(map_finish_reason_common(reason));
380 }
381 }
382
383 if let Some(_usage_value) = value.get("usage")
384 && let Some(usage) =
385 crate::llm::providers::common::parse_usage_openai_format(&value, false)
386 {
387 aggregator.set_usage(usage);
388 }
389 Ok(())
390 },
391 )
392 .await;
393
394 match result {
395 Ok(_) => {
396 let response = aggregator.finalize();
397 let _ = tx.send(Ok(LLMStreamEvent::Completed {
398 response: Box::new(response),
399 }));
400 }
401 Err(error) => {
402 let _ = tx.send(Err(error));
403 }
404 }
405 });
406
407 let stream = try_stream! {
408 let mut receiver = event_rx;
409 while let Some(event) = receiver.recv().await {
410 yield event?;
411 }
412 };
413
414 Ok(Box::pin(stream))
415 }
416
417 fn supported_models(&self) -> Vec<String> {
418 models::stepfun::SUPPORTED_MODELS
419 .iter()
420 .map(|model| model.to_string())
421 .collect()
422 }
423
424 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
425 validate_supported_models(
426 request,
427 PROVIDER_NAME,
428 PROVIDER_KEY,
429 models::stepfun::SUPPORTED_MODELS,
430 )
431 }
432}
433
434impl_llm_client!(StepFunProvider);
435
#[cfg(test)]
mod tests {
    use super::StepFunProvider;
    use crate::config::constants::models;
    use crate::config::types::ReasoningEffortLevel;
    use crate::llm::provider::{LLMRequest, Message};

    /// Minimal single-user-message request for `STEP_3_7_FLASH`. Tests layer
    /// their specific fields on top with struct-update syntax.
    fn base_request() -> LLMRequest {
        LLMRequest {
            model: models::stepfun::STEP_3_7_FLASH.to_string(),
            messages: vec![Message::user("hello".to_string())],
            ..Default::default()
        }
    }

    #[test]
    fn payload_maps_reasoning_effort() {
        let provider = StepFunProvider::new("test-key".to_string());
        // Sampling params are set so the absence checks below actually prove
        // they are dropped while reasoning is enabled.
        let payload = provider
            .convert_to_stepfun_format(&LLMRequest {
                reasoning_effort: Some(ReasoningEffortLevel::XHigh),
                temperature: Some(0.5),
                top_p: Some(0.25),
                ..base_request()
            })
            .expect("payload should be valid");

        assert_eq!(
            payload
                .get("reasoning_effort")
                .and_then(|value| value.as_str()),
            Some("high")
        );
        assert!(payload.get("temperature").is_none());
        assert!(payload.get("top_p").is_none());
    }

    #[test]
    fn payload_keeps_sampling_params_without_reasoning() {
        let provider = StepFunProvider::new("test-key".to_string());
        let payload = provider
            .convert_to_stepfun_format(&LLMRequest {
                reasoning_effort: Some(ReasoningEffortLevel::None),
                temperature: Some(0.5),
                top_p: Some(0.25),
                ..base_request()
            })
            .expect("payload should be valid");

        assert!(payload.get("reasoning_effort").is_none());
        assert_eq!(
            payload.get("temperature").and_then(|value| value.as_f64()),
            Some(0.5)
        );
        assert_eq!(
            payload.get("top_p").and_then(|value| value.as_f64()),
            Some(0.25)
        );
    }

    #[test]
    fn payload_includes_max_tokens_and_stream_only_when_set() {
        let provider = StepFunProvider::new("test-key".to_string());

        let plain = provider
            .convert_to_stepfun_format(&LLMRequest {
                max_tokens: None,
                stream: false,
                ..base_request()
            })
            .expect("payload should be valid");
        assert!(plain.get("max_tokens").is_none());
        assert!(plain.get("stream").is_none());

        let streaming = provider
            .convert_to_stepfun_format(&LLMRequest {
                max_tokens: Some(128),
                stream: true,
                ..base_request()
            })
            .expect("payload should be valid");
        assert_eq!(
            streaming.get("max_tokens").and_then(|value| value.as_u64()),
            Some(128)
        );
        assert_eq!(
            streaming.get("stream").and_then(|value| value.as_bool()),
            Some(true)
        );
        assert_eq!(
            streaming.get("model").and_then(|value| value.as_str()),
            Some(models::stepfun::STEP_3_7_FLASH)
        );
    }
}