// vtcode_core/llm/providers/stepfun.rs

1use async_stream::try_stream;
2use async_trait::async_trait;
3use reqwest::Client as HttpClient;
4use serde_json::{Map, Value};
5
6use crate::config::TimeoutsConfig;
7use crate::config::constants::{env_vars, models, urls};
8use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
9use crate::config::types::ReasoningEffortLevel;
10use crate::llm::error_display;
11use crate::llm::provider::{
12    LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
13};
14
15use super::common::{
16    ensure_model, impl_llm_client, map_finish_reason_common, override_base_url,
17    parse_json_response, parse_response_openai_format, resolve_model,
18    serialize_messages_openai_format, serialize_tools_openai_format, validate_supported_models,
19};
20use super::error_handling::handle_openai_http_error;
21use super::extract_reasoning_trace;
22
/// Identifier returned by `name()` and handed to the shared serialization/validation helpers.
const PROVIDER_KEY: &str = "stepfun";
/// Display name passed to the shared error-formatting and response-parsing helpers.
const PROVIDER_NAME: &str = "StepFun";

/// Environment variable consulted first for the API key (also reported in HTTP auth errors).
const PRIMARY_API_KEY_ENV: &str = "STEPFUN_API_KEY";
/// Older environment variable name, consulted after `PRIMARY_API_KEY_ENV`.
const LEGACY_API_KEY_ENV: &str = "STEP_API_KEY";
27
28pub struct StepFunProvider {
29    api_key: String,
30    http_client: HttpClient,
31    base_url: String,
32    model: String,
33    model_behavior: Option<ModelConfig>,
34}
35
36impl StepFunProvider {
37    pub fn new(api_key: String) -> Self {
38        Self::with_model_internal(
39            api_key,
40            models::stepfun::DEFAULT_MODEL.to_string(),
41            None,
42            None,
43            None,
44        )
45    }
46
47    pub fn with_model(api_key: String, model: String) -> Self {
48        Self::with_model_internal(api_key, model, None, None, None)
49    }
50
51    pub fn new_with_client(
52        api_key: String,
53        model: String,
54        http_client: reqwest::Client,
55        base_url: String,
56        _timeouts: TimeoutsConfig,
57    ) -> Self {
58        Self {
59            api_key,
60            http_client,
61            base_url,
62            model,
63            model_behavior: None,
64        }
65    }
66
67    pub fn from_config(
68        api_key: Option<String>,
69        model: Option<String>,
70        base_url: Option<String>,
71        _prompt_cache: Option<PromptCachingConfig>,
72        timeouts: Option<TimeoutsConfig>,
73        _anthropic: Option<AnthropicConfig>,
74        model_behavior: Option<ModelConfig>,
75    ) -> Self {
76        let api_key_value = api_key
77            .filter(|key| !key.trim().is_empty())
78            .or_else(|| std::env::var(PRIMARY_API_KEY_ENV).ok())
79            .or_else(|| std::env::var(LEGACY_API_KEY_ENV).ok())
80            .unwrap_or_default();
81
82        Self::with_model_internal(
83            api_key_value,
84            resolve_model(model, models::stepfun::DEFAULT_MODEL),
85            base_url,
86            timeouts,
87            model_behavior,
88        )
89    }
90
91    fn with_model_internal(
92        api_key: String,
93        model: String,
94        base_url: Option<String>,
95        timeouts: Option<TimeoutsConfig>,
96        model_behavior: Option<ModelConfig>,
97    ) -> Self {
98        use crate::llm::http_client::HttpClientFactory;
99
100        let timeouts = timeouts.unwrap_or_default();
101
102        Self {
103            api_key,
104            http_client: HttpClientFactory::for_llm(&timeouts),
105            base_url: override_base_url(
106                urls::STEPFUN_API_BASE,
107                base_url,
108                Some(env_vars::STEPFUN_BASE_URL),
109            ),
110            model,
111            model_behavior,
112        }
113    }
114
115    fn float_to_json_number(value: f32) -> Result<serde_json::Number, LLMError> {
116        serde_json::Number::from_f64(value as f64).ok_or_else(|| LLMError::InvalidRequest {
117            message: "invalid numeric parameter value (NaN or infinity)".to_string(),
118            metadata: None,
119        })
120    }
121
122    fn reasoning_effort_value(effort: ReasoningEffortLevel) -> Option<&'static str> {
123        match effort {
124            ReasoningEffortLevel::None => None,
125            ReasoningEffortLevel::Minimal | ReasoningEffortLevel::Low => Some("low"),
126            ReasoningEffortLevel::Medium => Some("medium"),
127            ReasoningEffortLevel::High
128            | ReasoningEffortLevel::XHigh
129            | ReasoningEffortLevel::Max => Some("high"),
130        }
131    }
132
133    fn is_reasoning_enabled(request: &LLMRequest) -> bool {
134        request
135            .reasoning_effort
136            .is_some_and(|effort| effort != ReasoningEffortLevel::None)
137    }
138
139    fn convert_to_stepfun_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
140        let mut payload = Map::with_capacity(10);
141        payload.insert("model".to_owned(), Value::String(request.model.clone()));
142
143        let mut messages = serialize_messages_openai_format(request, PROVIDER_KEY)?;
144        if let Some(system_prompt) = &request.system_prompt {
145            let trimmed = system_prompt.trim();
146            if !trimmed.is_empty() {
147                messages.insert(
148                    0,
149                    serde_json::json!({ "role": "system", "content": trimmed }),
150                );
151            }
152        }
153        payload.insert("messages".to_owned(), Value::Array(messages));
154
155        if let Some(max_tokens) = request.max_tokens {
156            payload.insert(
157                "max_tokens".to_owned(),
158                Value::Number(serde_json::Number::from(max_tokens as u64)),
159            );
160        }
161
162        if !Self::is_reasoning_enabled(request) {
163            if let Some(temperature) = request.temperature {
164                payload.insert(
165                    "temperature".to_owned(),
166                    Value::Number(Self::float_to_json_number(temperature)?),
167                );
168            }
169
170            if let Some(top_p) = request.top_p {
171                payload.insert(
172                    "top_p".to_owned(),
173                    Value::Number(Self::float_to_json_number(top_p)?),
174                );
175            }
176        }
177
178        if request.stream {
179            payload.insert("stream".to_owned(), Value::Bool(true));
180        }
181
182        if let Some(tools) = &request.tools
183            && let Some(serialized_tools) = serialize_tools_openai_format(tools)
184        {
185            payload.insert("tools".to_owned(), Value::Array(serialized_tools));
186        }
187
188        if let Some(choice) = &request.tool_choice {
189            payload.insert(
190                "tool_choice".to_owned(),
191                choice.to_provider_format(PROVIDER_KEY),
192            );
193        }
194
195        if let Some(effort) = request.reasoning_effort
196            && let Some(mapped) = Self::reasoning_effort_value(effort)
197        {
198            payload.insert(
199                "reasoning_effort".to_owned(),
200                Value::String(mapped.to_string()),
201            );
202        }
203
204        Ok(Value::Object(payload))
205    }
206}
207
208#[async_trait]
209impl LLMProvider for StepFunProvider {
210    fn name(&self) -> &str {
211        PROVIDER_KEY
212    }
213
214    fn supports_streaming(&self) -> bool {
215        true
216    }
217
218    fn supports_tools(&self, _model: &str) -> bool {
219        true
220    }
221
222    fn supports_structured_output(&self, _model: &str) -> bool {
223        true
224    }
225
226    fn supports_vision(&self, _model: &str) -> bool {
227        true
228    }
229
230    fn supports_reasoning(&self, model: &str) -> bool {
231        let requested = if model.trim().is_empty() {
232            &self.model
233        } else {
234            model
235        };
236
237        self.model_behavior
238            .as_ref()
239            .and_then(|behavior| behavior.model_supports_reasoning)
240            .unwrap_or(false)
241            || models::stepfun::REASONING_MODELS.contains(&requested)
242    }
243
244    fn supports_reasoning_effort(&self, model: &str) -> bool {
245        let requested = if model.trim().is_empty() {
246            &self.model
247        } else {
248            model
249        };
250
251        self.model_behavior
252            .as_ref()
253            .and_then(|behavior| behavior.model_supports_reasoning_effort)
254            .unwrap_or(false)
255            || models::stepfun::REASONING_MODELS.contains(&requested)
256    }
257
258    fn effective_context_size(&self, model: &str) -> usize {
259        let requested = if model.trim().is_empty() {
260            &self.model
261        } else {
262            model
263        };
264
265        match requested {
266            models::stepfun::STEP_3_7_FLASH => 262_144,
267            _ => 262_144,
268        }
269    }
270
271    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
272        let model = ensure_model(&mut request, &self.model);
273
274        let payload = self.convert_to_stepfun_format(&request)?;
275        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
276
277        let response = self
278            .http_client
279            .post(&url)
280            .bearer_auth(&self.api_key)
281            .json(&payload)
282            .send()
283            .await
284            .map_err(|error| LLMError::Network {
285                message: error_display::format_llm_error(
286                    PROVIDER_NAME,
287                    &format!("network error: {error}"),
288                ),
289                metadata: None,
290            })?;
291
292        let response =
293            handle_openai_http_error(response, PROVIDER_NAME, PRIMARY_API_KEY_ENV).await?;
294        let response_json = parse_json_response(response, PROVIDER_NAME).await?;
295
296        let reasoning_extractor = |message: &Value, choice: &Value| {
297            message
298                .get("reasoning")
299                .and_then(extract_reasoning_trace)
300                .or_else(|| choice.get("reasoning").and_then(extract_reasoning_trace))
301        };
302
303        parse_response_openai_format(
304            response_json,
305            PROVIDER_NAME,
306            model,
307            false,
308            Some(reasoning_extractor),
309        )
310    }
311
312    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
313        ensure_model(&mut request, &self.model);
314        self.validate_request(&request)?;
315        request.stream = true;
316        let model = request.model.clone();
317
318        let payload = self.convert_to_stepfun_format(&request)?;
319        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
320
321        let response = self
322            .http_client
323            .post(&url)
324            .bearer_auth(&self.api_key)
325            .json(&payload)
326            .send()
327            .await
328            .map_err(|error| LLMError::Network {
329                message: error_display::format_llm_error(
330                    PROVIDER_NAME,
331                    &format!("network error: {error}"),
332                ),
333                metadata: None,
334            })?;
335
336        let response =
337            handle_openai_http_error(response, PROVIDER_NAME, PRIMARY_API_KEY_ENV).await?;
338
339        let bytes_stream = response.bytes_stream();
340        let (event_tx, event_rx) =
341            tokio::sync::mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
342        let tx = event_tx.clone();
343
344        let model_clone = model.clone();
345        tokio::spawn(async move {
346            let mut aggregator =
347                crate::llm::providers::shared::StreamAggregator::new(model_clone.clone());
348
349            let result = crate::llm::providers::shared::process_openai_stream(
350                bytes_stream,
351                PROVIDER_NAME,
352                model_clone,
353                |value| {
354                    if let Some(choices) =
355                        value.get("choices").and_then(|choices| choices.as_array())
356                        && let Some(choice) = choices.first()
357                    {
358                        if let Some(delta) = choice.get("delta") {
359                            if let Some(reasoning) = delta.get("reasoning").and_then(|v| v.as_str())
360                                && let Some(delta) = aggregator.handle_reasoning(reasoning)
361                            {
362                                let _ = tx.send(Ok(LLMStreamEvent::Reasoning { delta }));
363                            }
364
365                            if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
366                                for event in aggregator.handle_content(content) {
367                                    let _ = tx.send(Ok(event));
368                                }
369                            }
370
371                            if let Some(tool_calls) =
372                                delta.get("tool_calls").and_then(|calls| calls.as_array())
373                            {
374                                aggregator.handle_tool_calls(tool_calls);
375                            }
376                        }
377
378                        if let Some(reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
379                            aggregator.set_finish_reason(map_finish_reason_common(reason));
380                        }
381                    }
382
383                    if let Some(_usage_value) = value.get("usage")
384                        && let Some(usage) =
385                            crate::llm::providers::common::parse_usage_openai_format(&value, false)
386                    {
387                        aggregator.set_usage(usage);
388                    }
389                    Ok(())
390                },
391            )
392            .await;
393
394            match result {
395                Ok(_) => {
396                    let response = aggregator.finalize();
397                    let _ = tx.send(Ok(LLMStreamEvent::Completed {
398                        response: Box::new(response),
399                    }));
400                }
401                Err(error) => {
402                    let _ = tx.send(Err(error));
403                }
404            }
405        });
406
407        let stream = try_stream! {
408            let mut receiver = event_rx;
409            while let Some(event) = receiver.recv().await {
410                yield event?;
411            }
412        };
413
414        Ok(Box::pin(stream))
415    }
416
417    fn supported_models(&self) -> Vec<String> {
418        models::stepfun::SUPPORTED_MODELS
419            .iter()
420            .map(|model| model.to_string())
421            .collect()
422    }
423
424    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
425        validate_supported_models(
426            request,
427            PROVIDER_NAME,
428            PROVIDER_KEY,
429            models::stepfun::SUPPORTED_MODELS,
430        )
431    }
432}
433
// Generates the shared LLM-client glue for `StepFunProvider`; see
// `super::common::impl_llm_client!` for what the expansion contains.
impl_llm_client!(StepFunProvider);
435
#[cfg(test)]
mod tests {
    use super::StepFunProvider;
    use crate::config::constants::models;
    use crate::config::types::ReasoningEffortLevel;
    use crate::llm::provider::{LLMRequest, Message};

    /// Minimal single-message request for the flash model; tests override fields as needed.
    fn base_request() -> LLMRequest {
        LLMRequest {
            model: models::stepfun::STEP_3_7_FLASH.to_string(),
            messages: vec![Message::user("hello".to_string())],
            ..Default::default()
        }
    }

    /// Serializes `request` with a throwaway provider.
    fn payload_for(request: LLMRequest) -> serde_json::Value {
        StepFunProvider::new("test-key".to_string())
            .convert_to_stepfun_format(&request)
            .expect("payload should be valid")
    }

    #[test]
    fn payload_maps_reasoning_effort() {
        // Sampling params are set so the absence checks below exercise the reasoning gate
        // instead of passing vacuously.
        let payload = payload_for(LLMRequest {
            reasoning_effort: Some(ReasoningEffortLevel::XHigh),
            temperature: Some(0.5),
            top_p: Some(0.5),
            ..base_request()
        });

        assert_eq!(
            payload
                .get("reasoning_effort")
                .and_then(|value| value.as_str()),
            Some("high")
        );
        assert!(payload.get("temperature").is_none());
        assert!(payload.get("top_p").is_none());
    }

    #[test]
    fn payload_maps_other_effort_levels() {
        for (effort, expected) in [
            (ReasoningEffortLevel::Minimal, "low"),
            (ReasoningEffortLevel::Low, "low"),
            (ReasoningEffortLevel::Medium, "medium"),
            (ReasoningEffortLevel::High, "high"),
            (ReasoningEffortLevel::Max, "high"),
        ] {
            let payload = payload_for(LLMRequest {
                reasoning_effort: Some(effort),
                ..base_request()
            });

            assert_eq!(
                payload
                    .get("reasoning_effort")
                    .and_then(|value| value.as_str()),
                Some(expected)
            );
        }
    }

    #[test]
    fn payload_keeps_sampling_params_without_reasoning() {
        // Both "no effort" and an explicit `None` effort leave sampling params untouched.
        for effort in [None, Some(ReasoningEffortLevel::None)] {
            let payload = payload_for(LLMRequest {
                reasoning_effort: effort,
                temperature: Some(0.5),
                top_p: Some(0.25),
                ..base_request()
            });

            assert!(payload.get("reasoning_effort").is_none());
            assert_eq!(
                payload.get("temperature").and_then(|value| value.as_f64()),
                Some(0.5)
            );
            assert_eq!(
                payload.get("top_p").and_then(|value| value.as_f64()),
                Some(0.25)
            );
        }
    }

    #[test]
    fn payload_sets_stream_flag_only_when_requested() {
        let non_streaming = payload_for(LLMRequest {
            stream: false,
            ..base_request()
        });
        assert!(non_streaming.get("stream").is_none());

        let streaming = payload_for(LLMRequest {
            stream: true,
            ..base_request()
        });
        assert_eq!(
            streaming.get("stream").and_then(|value| value.as_bool()),
            Some(true)
        );
    }
}