Skip to main content

vtcode_core/llm/providers/
huggingface.rs

1#![allow(clippy::bind_instead_of_map, clippy::collapsible_if)]
2
3use crate::config::TimeoutsConfig;
4use crate::config::constants::{env_vars, models, urls};
5use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
6use crate::llm::error_display::format_llm_error;
7use crate::llm::provider::{
8    LLMError, LLMErrorMetadata, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
9    MessageRole, ToolDefinition,
10};
11use crate::llm::providers::shared::{
12    NoopStreamTelemetry, StreamTelemetry, function_output_value_from_message_content,
13};
14use async_stream::try_stream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use reqwest::{Client as HttpClient, Response, StatusCode};
18use serde_json::{Value, json};
19
20use super::common::{
21    assistant_interleaved_history_text, ensure_model, impl_llm_client, is_minimax_m2_model,
22    map_finish_reason_common, normalize_reasoning_detail_objects, override_base_url,
23    parse_response_openai_format, resolve_model,
24};
25use super::error_handling::{format_network_error, format_parse_error};
26
/// Human-readable provider name used in formatted error messages and error metadata.
const PROVIDER_NAME: &str = "HuggingFace";
/// Machine key for this provider: returned by `name()` and passed to
/// `validate_for_provider` when checking outgoing messages.
const PROVIDER_KEY: &str = "huggingface";
/// Instruction merged into the Responses API `instructions` field when structured
/// output or tools are requested (see `add_json_instruction`).
const JSON_INSTRUCTION: &str = "Return JSON that matches the provided schema.";
30
31pub struct HuggingFaceProvider {
32    api_key: String,
33    http_client: HttpClient,
34    base_url: String,
35    model: String,
36    _timeouts: TimeoutsConfig,
37    model_behavior: Option<ModelConfig>,
38}
39
40impl HuggingFaceProvider {
41    pub fn new(api_key: String) -> Self {
42        Self::with_model_internal(
43            api_key,
44            models::huggingface::DEFAULT_MODEL.to_string(),
45            None,
46            None,
47            None,
48        )
49    }
50
51    pub fn with_model(api_key: String, model: String) -> Self {
52        Self::with_model_internal(api_key, model, None, None, None)
53    }
54
55    pub fn with_timeouts(api_key: String, timeouts: TimeoutsConfig) -> Self {
56        Self::with_model_internal(
57            api_key,
58            models::huggingface::DEFAULT_MODEL.to_string(),
59            None,
60            Some(timeouts),
61            None,
62        )
63    }
64
65    fn with_model_internal(
66        api_key: String,
67        model: String,
68        base_url: Option<String>,
69        timeouts: Option<TimeoutsConfig>,
70        model_behavior: Option<ModelConfig>,
71    ) -> Self {
72        use crate::llm::http_client::HttpClientFactory;
73
74        let timeouts = timeouts.unwrap_or_default();
75
76        Self {
77            api_key,
78            http_client: HttpClientFactory::for_llm(&timeouts),
79            base_url: override_base_url(
80                urls::HUGGINGFACE_API_BASE,
81                base_url,
82                Some(env_vars::HUGGINGFACE_BASE_URL),
83            ),
84            model,
85            _timeouts: timeouts,
86            model_behavior,
87        }
88    }
89
90    pub fn from_config(
91        api_key: Option<String>,
92        model: Option<String>,
93        base_url: Option<String>,
94        _prompt_cache: Option<PromptCachingConfig>,
95        timeouts: Option<TimeoutsConfig>,
96        _anthropic: Option<AnthropicConfig>,
97        model_behavior: Option<ModelConfig>,
98    ) -> Self {
99        let api_key_value = api_key.unwrap_or_default();
100        let model_value = resolve_model(model, models::huggingface::DEFAULT_MODEL);
101        Self::with_model_internal(
102            api_key_value,
103            model_value,
104            base_url,
105            timeouts,
106            model_behavior,
107        )
108    }
109
110    fn normalize_model_id(&self, model: &str) -> Result<String, LLMError> {
111        let model = model.trim();
112        let lower = model.to_ascii_lowercase();
113
114        if lower.starts_with(&models::huggingface::STEP_3_5_FLASH_BASE.to_ascii_lowercase()) {
115            if !model.contains(':') {
116                return Ok(format!(
117                    "{}:{}",
118                    models::huggingface::STEP_3_5_FLASH_BASE,
119                    models::huggingface::STEP_3_5_FLASH_PROVIDER
120                ));
121            }
122            if let Some((base, provider)) = model.rsplit_once(':')
123                && provider.eq_ignore_ascii_case("fastest")
124            {
125                return Ok(format!(
126                    "{}:{}",
127                    base,
128                    models::huggingface::STEP_3_5_FLASH_PROVIDER
129                ));
130            }
131        }
132
133        if lower.contains("minimax-m2") && !model.contains(':') {
134            return Err(LLMError::Provider {
135                message: format_llm_error(
136                    PROVIDER_NAME,
137                    "MiniMax models require explicit provider selection (:novita suffix). \n                    Use 'MiniMaxAI/MiniMax-M2.5:novita'.",
138                ),
139                metadata: None,
140            });
141        }
142
143        if lower.contains("glm-5") && !model.contains(':') {
144            return Err(LLMError::Provider {
145                message: format_llm_error(
146                    PROVIDER_NAME,
147                    "GLM models require explicit provider selection on HuggingFace.",
148                ),
149                metadata: None,
150            });
151        }
152
153        Ok(model.to_string())
154    }
155
156    fn serialize_tools_huggingface(&self, tools: &[ToolDefinition]) -> Option<Vec<Value>> {
157        crate::llm::providers::common::serialize_tools_openai_format(tools)
158    }
159
160    fn serialize_messages_huggingface_chat(
161        &self,
162        request: &LLMRequest,
163    ) -> Result<Vec<Value>, LLMError> {
164        use serde_json::{Map, json};
165
166        let mut messages = Vec::with_capacity(request.messages.len());
167
168        for message in &request.messages {
169            message
170                .validate_for_provider(PROVIDER_KEY)
171                .map_err(|e| LLMError::InvalidRequest {
172                    message: e,
173                    metadata: None,
174                })?;
175
176            let mut message_map = Map::with_capacity(4);
177            message_map.insert(
178                "role".to_owned(),
179                Value::String(message.role.as_generic_str().to_owned()),
180            );
181
182            if let Some(interleaved_content) =
183                assistant_interleaved_history_text(message, &request.model)
184            {
185                message_map.insert("content".to_owned(), Value::String(interleaved_content));
186            } else {
187                match &message.content {
188                    crate::llm::provider::MessageContent::Text(text) => {
189                        message_map.insert("content".to_owned(), Value::String(text.clone()));
190                    }
191                    crate::llm::provider::MessageContent::Parts(parts) => {
192                        let has_images = parts
193                            .iter()
194                            .any(crate::llm::provider::ContentPart::is_image);
195                        if has_images {
196                            let parts_json: Vec<Value> = parts
197                            .iter()
198                            .map(|part| match part {
199                                crate::llm::provider::ContentPart::Text { text } => {
200                                    json!({ "type": "text", "text": text })
201                                }
202                                crate::llm::provider::ContentPart::Image {
203                                    data,
204                                    mime_type,
205                                    ..
206                                } => {
207                                    json!({
208                                        "type": "image_url",
209                                        "image_url": {
210                                            "url": format!("data:{};base64,{}", mime_type, data)
211                                        }
212                                    })
213                                }
214                                crate::llm::provider::ContentPart::File {
215                                    filename,
216                                    file_id,
217                                    file_url,
218                                    ..
219                                } => {
220                                    let fallback = filename
221                                        .clone()
222                                        .or_else(|| file_id.clone())
223                                        .or_else(|| file_url.clone())
224                                        .unwrap_or_else(|| "attached file".to_string());
225                                    json!({ "type": "text", "text": format!("[File input not directly supported: {}]", fallback) })
226                                }
227                            })
228                            .collect();
229                            message_map.insert("content".to_owned(), Value::Array(parts_json));
230                        } else {
231                            let text = message.content.as_text().into_owned();
232                            message_map.insert("content".to_owned(), Value::String(text));
233                        }
234                    }
235                }
236            }
237
238            if let Some(tool_calls) = &message.tool_calls {
239                let serialized_calls = tool_calls
240                    .iter()
241                    .filter_map(|call| {
242                        call.function.as_ref().map(|func| {
243                            json!({
244                                "id": &call.id,
245                                "type": "function",
246                                "function": {
247                                    "name": &func.name,
248                                    "arguments": &func.arguments
249                                }
250                            })
251                        })
252                    })
253                    .collect::<Vec<_>>();
254                message_map.insert("tool_calls".to_owned(), Value::Array(serialized_calls));
255            }
256
257            if let Some(tool_call_id) = &message.tool_call_id {
258                message_map.insert(
259                    "tool_call_id".to_owned(),
260                    Value::String(tool_call_id.clone()),
261                );
262            }
263
264            if message.role == MessageRole::Assistant
265                && is_minimax_m2_model(&request.model)
266                && let Some(reasoning_details) = &message.reasoning_details
267                && !reasoning_details.is_empty()
268            {
269                let normalized_details = normalize_reasoning_detail_objects(reasoning_details);
270                if !normalized_details.is_empty() {
271                    message_map.insert(
272                        "reasoning_details".to_owned(),
273                        Value::Array(normalized_details),
274                    );
275                }
276            }
277
278            messages.push(Value::Object(message_map));
279        }
280
281        Ok(messages)
282    }
283
284    fn format_for_chat_completions(&self, request: &LLMRequest) -> Result<Value, LLMError> {
285        let mut messages = self.serialize_messages_huggingface_chat(request)?;
286        let is_glm = self.is_glm_model(&request.model);
287
288        if let Some(system) = &request.system_prompt {
289            let has_system = messages
290                .first()
291                .and_then(|m| m.get("role"))
292                .and_then(|r| r.as_str())
293                == Some("system");
294            if !has_system {
295                messages.insert(
296                    0,
297                    json!({
298                        "role": "system",
299                        "content": system
300                    }),
301                );
302            }
303        }
304
305        let mut payload = json!({
306            "model": request.model,
307            "messages": messages,
308            "stream": request.stream,
309        });
310
311        if request.stream && request.tools.is_some() && is_glm {
312            payload["tool_stream"] = json!(true);
313        }
314
315        if let Some(max_tokens) = request.max_tokens {
316            payload["max_tokens"] = json!(max_tokens);
317        }
318
319        if let Some(tools) = &request.tools {
320            if let Some(serialized) = self.serialize_tools_huggingface(tools) {
321                payload["tools"] = json!(serialized);
322
323                if let Some(choice) = &request.tool_choice {
324                    payload["tool_choice"] = choice.to_provider_format("openai");
325                }
326            }
327        }
328
329        if let Some(temperature) = request.temperature {
330            payload["temperature"] = json!(temperature);
331        }
332
333        if let Some(top_p) = request.top_p {
334            payload["top_p"] = json!(top_p);
335        }
336
337        if let Some(top_k) = request.top_k {
338            payload["top_k"] = json!(top_k);
339        }
340
341        if let Some(effort) = request.reasoning_effort {
342            use crate::config::models::Provider;
343            use crate::llm::rig_adapter::RigProviderCapabilities;
344            if let Some(reasoning_params) =
345                RigProviderCapabilities::new(Provider::HuggingFace, &request.model)
346                    .reasoning_parameters(effort)
347            {
348                if let Some(params_obj) = reasoning_params.as_object() {
349                    for (k, v) in params_obj {
350                        payload[k] = v.clone();
351                    }
352                }
353            }
354        }
355
356        if request.output_format.is_some() && !is_glm {
357            payload["response_format"] = json!({ "type": "json_object" });
358        }
359
360        Ok(payload)
361    }
362
363    fn is_glm_model(&self, model: &str) -> bool {
364        let lower = model.to_ascii_lowercase();
365        lower.contains("glm")
366    }
367
368    fn is_deepseek_model(&self, model: &str) -> bool {
369        let lower = model.to_ascii_lowercase();
370        lower.contains("deepseek")
371    }
372
373    fn is_minimax_model(&self, model: &str) -> bool {
374        let lower = model.to_ascii_lowercase();
375        lower.contains("minimax")
376    }
377
378    fn apply_model_defaults(&self, request: &mut LLMRequest) {
379        if self.is_minimax_model(&request.model) {
380            if request.temperature.is_none() {
381                request.temperature = Some(1.0);
382            }
383            if request.top_p.is_none() {
384                request.top_p = Some(0.95);
385            }
386            if request.top_k.is_none() {
387                request.top_k = Some(40);
388            }
389        }
390    }
391
392    fn add_json_instruction(&self, payload: &mut Value) -> Result<(), LLMError> {
393        if let Some(instructions) = payload.get_mut("instructions") {
394            if let Some(text) = instructions.as_str() {
395                if !text.contains("Return JSON") {
396                    *instructions = json!(format!("{}\n\n{}", text, JSON_INSTRUCTION));
397                }
398            }
399        } else {
400            payload["instructions"] = json!(JSON_INSTRUCTION);
401        }
402
403        Ok(())
404    }
405
406    fn format_for_responses_api(&self, request: &LLMRequest) -> Result<Value, LLMError> {
407        let mut input = Vec::new();
408
409        for msg in &request.messages {
410            let convert_parts = |parts: &[crate::llm::provider::ContentPart]| -> Value {
411                let parts_json: Vec<Value> = parts
412                    .iter()
413                    .map(|part| match part {
414                        crate::llm::provider::ContentPart::Text { text } => {
415                            json!({ "type": "input_text", "text": text })
416                        }
417                        crate::llm::provider::ContentPart::Image {
418                            data, mime_type, ..
419                        } => {
420                            json!({
421                                "type": "input_image",
422                                "image_url": format!("data:{};base64,{}", mime_type, data)
423                            })
424                        }
425                        crate::llm::provider::ContentPart::File {
426                            filename,
427                            file_id,
428                            file_url,
429                            ..
430                        } => {
431                            let fallback = filename
432                                .clone()
433                                .or_else(|| file_id.clone())
434                                .or_else(|| file_url.clone())
435                                .unwrap_or_else(|| "attached file".to_string());
436                            json!({
437                                "type": "input_text",
438                                "text": format!("[File input not directly supported: {}]", fallback)
439                            })
440                        }
441                    })
442                    .collect();
443                json!(parts_json)
444            };
445
446            match msg.role {
447                MessageRole::System | MessageRole::User => {
448                    if msg.role == MessageRole::System && request.system_prompt.is_some() {
449                        if let crate::llm::provider::MessageContent::Text(text) = &msg.content {
450                            if request.system_prompt.as_ref().map(|s| s.as_str())
451                                == Some(text.as_str())
452                            {
453                                continue;
454                            }
455                        }
456                    }
457
458                    let role = if msg.role == MessageRole::System {
459                        "system"
460                    } else {
461                        "user"
462                    };
463
464                    let mut message_obj = json!({
465                        "type": "message",
466                        "role": role,
467                    });
468
469                    match &msg.content {
470                        crate::llm::provider::MessageContent::Text(text) => {
471                            message_obj["content"] = json!(text);
472                        }
473                        crate::llm::provider::MessageContent::Parts(parts) => {
474                            message_obj["content"] = convert_parts(parts);
475                        }
476                    }
477
478                    input.push(message_obj);
479                }
480                MessageRole::Assistant => {
481                    let has_content = match &msg.content {
482                        crate::llm::provider::MessageContent::Text(text) => !text.is_empty(),
483                        crate::llm::provider::MessageContent::Parts(parts) => !parts.is_empty(),
484                    };
485
486                    if has_content {
487                        let mut message_obj = json!({
488                            "type": "message",
489                            "role": "assistant",
490                        });
491
492                        match &msg.content {
493                            crate::llm::provider::MessageContent::Text(text) => {
494                                message_obj["content"] = json!(text);
495                            }
496                            crate::llm::provider::MessageContent::Parts(parts) => {
497                                message_obj["content"] = convert_parts(parts);
498                            }
499                        }
500
501                        input.push(message_obj);
502                    }
503
504                    if let Some(tool_calls) = &msg.tool_calls {
505                        for tc in tool_calls {
506                            if let Some(func) = &tc.function {
507                                input.push(json!({
508                                    "type": "function_call",
509                                    "call_id": tc.id,
510                                    "name": func.name,
511                                    "arguments": func.arguments
512                                }));
513                            }
514                        }
515                    }
516                }
517                MessageRole::Tool => {
518                    input.push(json!({
519                        "type": "function_call_output",
520                        "call_id": msg.tool_call_id.clone().unwrap_or_default(),
521                        "output": function_output_value_from_message_content(&msg.content)
522                    }));
523                }
524            }
525        }
526
527        let mut payload = json!({
528            "model": request.model,
529            "input": input,
530            "stream": request.stream,
531        });
532
533        if let Some(system_prompt) = &request.system_prompt {
534            payload["instructions"] = json!(system_prompt);
535        }
536
537        if let Some(effort) = request.reasoning_effort {
538            use crate::config::types::ReasoningEffortLevel;
539            if effort != ReasoningEffortLevel::None {
540                payload["reasoning"] = json!({ "effort": effort.as_str() });
541            }
542        }
543
544        if let Some(max_tokens) = request.max_tokens {
545            payload["max_tokens"] = json!(max_tokens);
546        }
547        if let Some(temperature) = request.temperature {
548            payload["temperature"] = json!(temperature);
549        }
550        if let Some(top_p) = request.top_p {
551            payload["top_p"] = json!(top_p);
552        }
553        if let Some(top_k) = request.top_k {
554            payload["top_k"] = json!(top_k);
555        }
556
557        if let Some(tools) = &request.tools {
558            if let Some(serialized) = self.serialize_tools_huggingface(tools) {
559                payload["tools"] = json!(serialized);
560
561                if let Some(choice) = &request.tool_choice {
562                    payload["tool_choice"] = choice.to_provider_format("openai");
563                }
564            }
565        }
566
567        if request.output_format.is_some() || request.tools.is_some() {
568            self.add_json_instruction(&mut payload)?;
569        }
570
571        if request.output_format.is_some() && !self.is_glm_model(&request.model) {
572            payload["response_format"] = json!({ "type": "json_object" });
573        }
574
575        Ok(payload)
576    }
577
578    fn should_use_responses_api(&self, _request: &LLMRequest) -> bool {
579        false
580    }
581
582    fn format_error(&self, status: StatusCode, body: &str) -> LLMError {
583        let message = if body.contains("\"code\":\"model_not_supported\"")
584            && body.contains(models::huggingface::STEP_3_5_FLASH_BASE)
585        {
586            format!(
587                "HuggingFace API error ({}): Step 3.5 Flash requires the '{}' provider. \
588Enable that provider in your HuggingFace Inference Providers settings, or switch to another model.",
589                status,
590                models::huggingface::STEP_3_5_FLASH_PROVIDER
591            )
592        } else {
593            format!("HuggingFace API error ({}): {}", status, body)
594        };
595
596        LLMError::Provider {
597            message: format_llm_error(PROVIDER_NAME, &message),
598            metadata: Some(LLMErrorMetadata::new(
599                PROVIDER_NAME,
600                Some(status.as_u16()),
601                None,
602                None,
603                None,
604                None,
605                Some(body.to_string()),
606            )),
607        }
608    }
609
610    fn parse_responses_api_format(json: &Value, model: String) -> Result<LLMResponse, LLMError> {
611        let convenience_text = json.get("output_text").and_then(|t| t.as_str());
612
613        let json_obj = json.get("response").unwrap_or(json);
614
615        let output = json_obj.get("output").and_then(|v| v.as_array());
616
617        let output_arr = match output {
618            Some(arr) => arr,
619            None => {
620                if let Some(text) = convenience_text {
621                    return Ok(LLMResponse {
622                        content: Some(text.to_string()),
623                        tool_calls: None,
624                        model,
625                        usage: None,
626                        finish_reason: crate::llm::provider::FinishReason::Stop,
627                        reasoning: None,
628                        reasoning_details: None,
629                        tool_references: Vec::new(),
630                        request_id: None,
631                        organization_id: None,
632                        compaction: None,
633                    });
634                }
635
636                return Err(LLMError::Provider {
637                    message: format_llm_error(PROVIDER_NAME, "Not a Responses API format"),
638                    metadata: None,
639                });
640            }
641        };
642
643        let mut content_fragments: Vec<String> = Vec::new();
644        let mut reasoning_fragments: Vec<String> = Vec::new();
645        let mut tool_calls: Vec<crate::llm::provider::ToolCall> = Vec::new();
646
647        for item in output_arr {
648            let item_type = item.get("type").and_then(|t| t.as_str()).unwrap_or("");
649
650            match item_type {
651                "message" => {
652                    if let Some(content_arr) = item.get("content").and_then(|c| c.as_array()) {
653                        for entry in content_arr {
654                            let entry_type =
655                                entry.get("type").and_then(|t| t.as_str()).unwrap_or("");
656                            match entry_type {
657                                "text" | "output_text" => {
658                                    if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
659                                        if !text.is_empty() {
660                                            content_fragments.push(text.to_string());
661                                        }
662                                    }
663                                }
664                                "reasoning" => {
665                                    if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
666                                        if !text.is_empty() {
667                                            reasoning_fragments.push(text.to_string());
668                                        }
669                                    }
670                                }
671                                "function_call" | "tool_call" => {
672                                    if let Some(call) = Self::parse_responses_tool_call(entry) {
673                                        tool_calls.push(call);
674                                    }
675                                }
676                                _ => {}
677                            }
678                        }
679                    }
680                }
681                "function_call" | "tool_call" => {
682                    if let Some(call) = Self::parse_responses_tool_call(item) {
683                        tool_calls.push(call);
684                    }
685                }
686                "reasoning" => {
687                    if let Some(summary_arr) = item.get("summary").and_then(|s| s.as_array()) {
688                        for summary in summary_arr {
689                            if let Some(text) = summary.get("text").and_then(|t| t.as_str()) {
690                                if !text.is_empty() {
691                                    reasoning_fragments.push(text.to_string());
692                                }
693                            }
694                        }
695                    } else if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
696                        reasoning_fragments.push(text.to_string());
697                    }
698                }
699                _ => {}
700            }
701        }
702
703        let content = if content_fragments.is_empty() {
704            convenience_text.map(|t| t.to_string())
705        } else {
706            Some(content_fragments.join(""))
707        };
708
709        let reasoning = if reasoning_fragments.is_empty() {
710            None
711        } else {
712            Some(reasoning_fragments.join("\n\n"))
713        };
714
715        let finish_reason = if !tool_calls.is_empty() {
716            crate::llm::provider::FinishReason::ToolCalls
717        } else {
718            crate::llm::provider::FinishReason::Stop
719        };
720
721        let usage_value = json.get("usage").or_else(|| json_obj.get("usage"));
722        let usage = usage_value.map(|usage_value| crate::llm::provider::Usage {
723            prompt_tokens: usage_value
724                .get("input_tokens")
725                .or_else(|| usage_value.get("prompt_tokens"))
726                .and_then(|pt| pt.as_u64())
727                .unwrap_or(0) as u32,
728            completion_tokens: usage_value
729                .get("output_tokens")
730                .or_else(|| usage_value.get("completion_tokens"))
731                .and_then(|ct| ct.as_u64())
732                .unwrap_or(0) as u32,
733            total_tokens: usage_value
734                .get("total_tokens")
735                .and_then(|tt| tt.as_u64())
736                .unwrap_or(0) as u32,
737            cached_prompt_tokens: None,
738            cache_creation_tokens: None,
739            cache_read_tokens: None,
740            iterations: None,
741        });
742
743        Ok(LLMResponse {
744            content,
745            tool_calls: if tool_calls.is_empty() {
746                None
747            } else {
748                Some(tool_calls)
749            },
750            model,
751            usage,
752            finish_reason,
753            reasoning,
754            reasoning_details: None,
755            tool_references: Vec::new(),
756            request_id: None,
757            organization_id: None,
758            compaction: None,
759        })
760    }
761
762    fn parse_responses_tool_call(item: &Value) -> Option<crate::llm::provider::ToolCall> {
763        let call_id = item.get("id").and_then(|v| v.as_str()).unwrap_or("");
764        let function_obj = item.get("function").and_then(|v| v.as_object());
765        let name = function_obj.and_then(|f| f.get("name").and_then(|n| n.as_str()))?;
766        let arguments = function_obj.and_then(|f| f.get("arguments"));
767
768        let serialized = arguments.map_or("{}".to_owned(), |args| {
769            if args.is_string() {
770                args.as_str().unwrap_or("{}").to_string()
771            } else {
772                args.to_string()
773            }
774        });
775
776        Some(crate::llm::provider::ToolCall::function(
777            call_id.to_string(),
778            name.to_string(),
779            serialized,
780        ))
781    }
782
783    async fn parse_response(
784        &self,
785        response: Response,
786        model: String,
787        use_responses_api: bool,
788    ) -> Result<LLMResponse, LLMError> {
789        let status = response.status();
790
791        if !status.is_success() {
792            let body = response.text().await.unwrap_or_default();
793            return Err(self.format_error(status, &body));
794        }
795
796        let json: Value = response
797            .json()
798            .await
799            .map_err(|err| format_parse_error(PROVIDER_NAME, &err))?;
800
801        if use_responses_api {
802            if json.get("output").is_some() {
803                return Self::parse_responses_api_format(&json, model);
804            }
805        }
806
807        parse_response_openai_format::<fn(&Value, &Value) -> Option<String>>(
808            json,
809            PROVIDER_NAME,
810            model,
811            false,
812            None,
813        )
814    }
815
816    pub fn available_models() -> Vec<String> {
817        models::huggingface::SUPPORTED_MODELS
818            .iter()
819            .map(|s| s.to_string())
820            .collect()
821    }
822
823    fn get_endpoint(&self, use_responses_api: bool) -> String {
824        let base = self.base_url.trim_end_matches('/');
825        if use_responses_api {
826            format!("{}/responses", base)
827        } else {
828            format!("{}/chat/completions", base)
829        }
830    }
831}
832
833#[async_trait]
834impl LLMProvider for HuggingFaceProvider {
835    fn name(&self) -> &str {
836        PROVIDER_KEY
837    }
838
839    fn supports_streaming(&self) -> bool {
840        true
841    }
842
843    fn supports_reasoning(&self, model: &str) -> bool {
844        // Codex-inspired robustness: Setting model_supports_reasoning to false
845        // does NOT disable it for known reasoning models.
846        models::huggingface::REASONING_MODELS.contains(&model)
847            || self
848                .model_behavior
849                .as_ref()
850                .and_then(|b| b.model_supports_reasoning)
851                .unwrap_or(false)
852    }
853
854    fn supports_reasoning_effort(&self, model: &str) -> bool {
855        // Same robustness logic for reasoning effort
856        self.is_glm_model(model)
857            || self.is_deepseek_model(model)
858            || self
859                .model_behavior
860                .as_ref()
861                .and_then(|b| b.model_supports_reasoning_effort)
862                .unwrap_or(false)
863    }
864
865    fn supports_tools(&self, _model: &str) -> bool {
866        true
867    }
868
869    fn supports_parallel_tool_config(&self, _model: &str) -> bool {
870        false
871    }
872
873    fn supports_structured_output(&self, _model: &str) -> bool {
874        true
875    }
876
877    fn supports_context_caching(&self, _model: &str) -> bool {
878        false
879    }
880
881    fn effective_context_size(&self, _model: &str) -> usize {
882        128_000
883    }
884
885    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
886        let model = ensure_model(&mut request, &self.model);
887
888        self.apply_model_defaults(&mut request);
889        self.validate_request(&request)?;
890
891        let model_id = self.normalize_model_id(&request.model)?;
892        request.model = model_id;
893
894        let use_responses_api = self.should_use_responses_api(&request);
895        let payload = if use_responses_api {
896            self.format_for_responses_api(&request)?
897        } else {
898            self.format_for_chat_completions(&request)?
899        };
900
901        let endpoint = self.get_endpoint(use_responses_api);
902
903        let response = self
904            .http_client
905            .post(&endpoint)
906            .header("Authorization", format!("Bearer {}", self.api_key))
907            .json(&payload)
908            .send()
909            .await
910            .map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
911
912        self.parse_response(response, model, use_responses_api)
913            .await
914    }
915
916    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
917        let model = ensure_model(&mut request, &self.model);
918
919        self.apply_model_defaults(&mut request);
920        self.validate_request(&request)?;
921        request.stream = true;
922
923        let model_id = self.normalize_model_id(&request.model)?;
924        request.model = model_id;
925
926        let use_responses_api = self.should_use_responses_api(&request);
927        let payload = if use_responses_api {
928            self.format_for_responses_api(&request)?
929        } else {
930            self.format_for_chat_completions(&request)?
931        };
932
933        let endpoint = self.get_endpoint(use_responses_api);
934
935        let response = self
936            .http_client
937            .post(&endpoint)
938            .header("Authorization", format!("Bearer {}", self.api_key))
939            .json(&payload)
940            .send()
941            .await
942            .map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
943
944        if !response.status().is_success() {
945            let status = response.status();
946            let body = response.text().await.unwrap_or_default();
947            return Err(self.format_error(status, &body));
948        }
949
950        self.create_stream(response, model, use_responses_api).await
951    }
952
953    fn supported_models(&self) -> Vec<String> {
954        Self::available_models()
955    }
956
957    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
958        if request.messages.is_empty() {
959            return Err(LLMError::InvalidRequest {
960                message: format_llm_error(PROVIDER_NAME, "Messages cannot be empty"),
961                metadata: None,
962            });
963        }
964
965        if request.model.trim().is_empty() {
966            return Err(LLMError::InvalidRequest {
967                message: format_llm_error(PROVIDER_NAME, "Model identifier cannot be empty"),
968                metadata: None,
969            });
970        }
971
972        Ok(())
973    }
974}
975
976impl HuggingFaceProvider {
977    async fn create_stream(
978        &self,
979        response: Response,
980        model: String,
981        use_responses_api: bool,
982    ) -> Result<LLMStream, LLMError> {
983        let mut bytes_stream = response.bytes_stream();
984        let mut buffer = String::with_capacity(4096);
985        let mut aggregator = crate::llm::providers::shared::StreamAggregator::new(model.clone());
986        let telemetry = NoopStreamTelemetry;
987
988        let stream = try_stream! {
989            'outer: while let Some(chunk_result) = bytes_stream.next().await {
990                let chunk = chunk_result.map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
991                let text = String::from_utf8_lossy(&chunk);
992                buffer.push_str(&text);
993
994                if buffer.len() > 128_000 {
995                    Err(LLMError::Provider {
996                        message: format_llm_error(PROVIDER_NAME, "Stream buffer exceeded maximum size (128KB)"),
997                        metadata: None,
998                    })?;
999                }
1000
1001                while let Some(newline_pos) = buffer.find('\n') {
1002                    let line = buffer[..newline_pos].trim().to_string();
1003                    buffer.drain(..=newline_pos);
1004
1005                    if line.is_empty() || line.starts_with(':') {
1006                        continue;
1007                    }
1008
1009                    let data = if let Some(stripped) = line.strip_prefix("data: ") {
1010                        stripped
1011                    } else {
1012                        continue;
1013                    };
1014
1015                    if data == "[DONE]" {
1016                        break 'outer;
1017                    }
1018
1019                    let event: Value = match serde_json::from_str(data) {
1020                        Ok(v) => v,
1021                        Err(_) => continue,
1022                    };
1023
1024                    if use_responses_api {
1025                        let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
1026
1027                        match event_type {
1028                            "response.output_text.delta" | "output_text.delta" => {
1029                                if let Some(delta) = event.get("delta").and_then(|d| d.as_str()) {
1030                                    telemetry.on_content_delta(delta);
1031                                    for ev in aggregator.handle_content(delta) {
1032                                        yield ev;
1033                                    }
1034                                }
1035                                continue;
1036                            }
1037                            "response.reasoning.delta" | "reasoning.delta" => {
1038                                if let Some(delta) = event.get("delta").and_then(|d| d.as_str()) {
1039                                    if let Some(d) = aggregator.handle_reasoning(delta) {
1040                                        telemetry.on_reasoning_delta(&d);
1041                                        yield LLMStreamEvent::Reasoning { delta: d };
1042                                    }
1043                                }
1044                                continue;
1045                            }
1046                            "response.function_call_arguments.delta" | "tool_call.delta" => {
1047                                telemetry.on_tool_call_delta();
1048                                continue;
1049                            }
1050                            "response.completed" => {
1051                                if let Some(response_obj) = event.get("response") {
1052                                    if let Ok(response) = Self::parse_responses_api_format(response_obj, model.clone()) {
1053                                        let final_agg_response = aggregator.finalize();
1054                                        let mut merged_response = response;
1055                                        if merged_response.content.is_none() {
1056                                            merged_response.content = final_agg_response.content;
1057                                        }
1058                                        if merged_response.reasoning.is_none() {
1059                                            merged_response.reasoning = final_agg_response.reasoning;
1060                                        }
1061                                        if merged_response.tool_calls.is_none() {
1062                                            merged_response.tool_calls = final_agg_response.tool_calls;
1063                                        }
1064                                        if merged_response.usage.is_none() {
1065                                            merged_response.usage = final_agg_response.usage;
1066                                        }
1067                                        yield LLMStreamEvent::Completed { response: Box::new(merged_response) };
1068                                        return;
1069                                    }
1070                                }
1071                                break 'outer;
1072                            }
1073                            "response.done" => {
1074                                break 'outer;
1075                            }
1076                            _ => {}
1077                        }
1078                    }
1079
1080                    if let Some(choices_arr) = event.get("choices").and_then(|c| c.as_array()) {
1081                        if let Some(choice) = choices_arr.first() {
1082                            if let Some(delta_obj) = choice.get("delta") {
1083                                if let Some(content) = delta_obj.get("content").and_then(|c| c.as_str()) {
1084                                    telemetry.on_content_delta(content);
1085                                    for ev in aggregator.handle_content(content) {
1086                                        yield ev;
1087                                    }
1088                                }
1089
1090                                if let Some(reason) = delta_obj.get("reasoning_content").and_then(|r| r.as_str()) {
1091                                    if let Some(d) = aggregator.handle_reasoning(reason) {
1092                                        telemetry.on_reasoning_delta(&d);
1093                                        yield LLMStreamEvent::Reasoning { delta: d };
1094                                    }
1095                                }
1096
1097                                if let Some(reasoning_details) = delta_obj
1098                                    .get("reasoning_details")
1099                                    .and_then(|details| details.as_array())
1100                                {
1101                                    aggregator.set_reasoning_details(reasoning_details);
1102                                }
1103
1104                                if let Some(tool_calls_arr) = delta_obj.get("tool_calls").and_then(|tc| tc.as_array()) {
1105                                    aggregator.handle_tool_calls(tool_calls_arr);
1106                                    telemetry.on_tool_call_delta();
1107                                }
1108                            }
1109
1110                            if let Some(finish_reason_str) = choice.get("finish_reason").and_then(|fr| fr.as_str()) {
1111                                aggregator.set_finish_reason(map_finish_reason_common(finish_reason_str));
1112                                if let Some(usage_value) = event.get("usage") {
1113                                    aggregator.set_usage(crate::llm::provider::Usage {
1114                                        prompt_tokens: usage_value.get("prompt_tokens").and_then(|pt| pt.as_u64()).unwrap_or(0) as u32,
1115                                        completion_tokens: usage_value.get("completion_tokens").and_then(|ct| ct.as_u64()).unwrap_or(0) as u32,
1116                                        total_tokens: usage_value.get("total_tokens").and_then(|tt| tt.as_u64()).unwrap_or(0) as u32,
1117                                        cached_prompt_tokens: None,
1118                                        cache_creation_tokens: None,
1119                                        cache_read_tokens: None,
1120                                        iterations: None,
1121                                    });
1122                                }
1123
1124                                break 'outer;
1125                            }
1126                        }
1127                    }
1128                }
1129            }
1130
1131            yield LLMStreamEvent::Completed { response: Box::new(aggregator.finalize()) };
1132        };
1133
1134        Ok(Box::pin(stream))
1135    }
1136}
1137
// Invokes the shared `impl_llm_client!` macro (imported from `super::common`)
// for this provider. Presumably this generates the client-facing trait impl on
// top of the `LLMProvider` impl above; confirm in `super::common`.
impl_llm_client!(HuggingFaceProvider);
1139
#[cfg(test)]
mod tests {
    use super::HuggingFaceProvider;
    use crate::llm::provider::{LLMRequest, Message, ToolDefinition};
    use crate::llm::providers::common::{is_minimax_m2_model, normalize_reasoning_detail_object};
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a provider with a dummy API key pinned to `model`.
    fn provider_for(model: &str) -> HuggingFaceProvider {
        HuggingFaceProvider::with_model("test-key".to_string(), model.to_string())
    }

    /// Builds a request for `model` that carries only `messages`.
    fn request_for(model: &str, messages: Vec<Message>) -> LLMRequest {
        LLMRequest {
            model: model.to_string(),
            messages,
            ..Default::default()
        }
    }

    #[test]
    fn minimax_model_detection_handles_variants() {
        for positive in ["MiniMaxAI/MiniMax-M2.5:novita", "minimax-m2.5"] {
            assert!(is_minimax_m2_model(positive), "{positive} should be detected");
        }
        assert!(!is_minimax_m2_model("deepseek-r1"));
    }

    #[test]
    fn normalize_reasoning_detail_decodes_stringified_object() {
        let encoded = json!("{\"type\":\"reasoning.text\",\"text\":\"step\"}");
        let parsed = normalize_reasoning_detail_object(&encoded)
            .expect("expected a parsed reasoning detail object");
        assert!(parsed.is_object());
        assert_eq!(parsed["type"], "reasoning.text");
    }

    #[test]
    fn serialize_messages_normalizes_minimax_reasoning_details() {
        const MODEL: &str = "MiniMaxAI/MiniMax-M2.5:novita";
        let provider = provider_for(MODEL);
        let stringified = json!("{\"type\":\"reasoning.text\",\"text\":\"chain\"}");
        let message =
            Message::assistant("answer".to_string()).with_reasoning_details(Some(vec![stringified]));
        let request = request_for(MODEL, vec![message]);

        let messages = provider
            .serialize_messages_huggingface_chat(&request)
            .expect("message serialization should succeed");
        let details = &messages[0]["reasoning_details"];
        assert!(details.is_array());
        assert!(details[0].is_object());
    }

    #[test]
    fn serialize_messages_rehydrates_glm_interleaved_history_into_content() {
        const MODEL: &str = "zai-org/GLM-5:novita";
        let provider = provider_for(MODEL);
        let message =
            Message::assistant("done".to_string()).with_reasoning(Some("trace".to_string()));
        let request = request_for(MODEL, vec![message]);

        let messages = provider
            .serialize_messages_huggingface_chat(&request)
            .expect("message serialization should succeed");

        assert_eq!(messages[0]["content"], json!("<think>trace</think>done"));
    }

    #[test]
    fn normalize_step35_flash_provider_suffix() {
        let provider = provider_for("stepfun-ai/Step-3.5-Flash");
        let expected = "stepfun-ai/Step-3.5-Flash:featherless-ai".to_string();

        let cases = [
            ("stepfun-ai/Step-3.5-Flash", "normalization should succeed"),
            (
                "stepfun-ai/Step-3.5-Flash:fastest",
                "legacy suffix normalization should succeed",
            ),
        ];
        for (input, failure_note) in cases {
            let normalized = provider.normalize_model_id(input).expect(failure_note);
            assert_eq!(normalized, expected);
        }
    }

    #[test]
    fn format_for_chat_completions_keeps_apply_patch_as_function_tool() {
        const MODEL: &str = "Qwen/Qwen3-Coder-480B-A35B-Instruct";
        let provider = provider_for(MODEL);
        let tools = Arc::new(vec![ToolDefinition::apply_patch("Apply patches".to_string())]);
        let request = LLMRequest {
            tools: Some(tools),
            ..request_for(MODEL, vec![Message::user("apply a patch".to_string())])
        };

        let payload = provider
            .format_for_chat_completions(&request)
            .expect("payload should serialize");

        let first_tool = &payload["tools"][0];
        assert_eq!(first_tool["type"], "function");
        assert_eq!(first_tool["function"]["name"], "apply_patch");
    }
}