Skip to main content

vtcode_core/llm/providers/
huggingface.rs

1#![allow(clippy::bind_instead_of_map, clippy::collapsible_if)]
2
3use crate::config::TimeoutsConfig;
4use crate::config::constants::{env_vars, models, urls};
5use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
6use crate::llm::error_display::format_llm_error;
7use crate::llm::provider::{
8    LLMError, LLMErrorMetadata, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
9    MessageRole, ToolDefinition,
10};
11use crate::llm::providers::shared::{
12    NoopStreamTelemetry, StreamTelemetry, function_output_value_from_message_content,
13};
14use async_stream::try_stream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use reqwest::{Client as HttpClient, Response, StatusCode};
18use serde_json::{Value, json};
19
20use super::common::{
21    assistant_interleaved_history_text, ensure_model, impl_llm_client, is_minimax_m2_model,
22    map_finish_reason_common, normalize_reasoning_detail_objects, override_base_url,
23    parse_response_openai_format, resolve_model,
24};
25use super::error_handling::{format_network_error, format_parse_error};
26
/// Display name used as the provider label in formatted errors and error metadata.
const PROVIDER_NAME: &'static str = "HuggingFace";

/// Stable machine key for this provider (used for message validation and `name()`).
const PROVIDER_KEY: &'static str = "huggingface";

/// Text injected into the Responses API `instructions` field when JSON output is wanted.
const JSON_INSTRUCTION: &'static str = "Return JSON that matches the provided schema.";
30
31pub struct HuggingFaceProvider {
32    api_key: String,
33    http_client: HttpClient,
34    base_url: String,
35    model: String,
36    _timeouts: TimeoutsConfig,
37    model_behavior: Option<ModelConfig>,
38}
39
40impl HuggingFaceProvider {
41    pub fn new(api_key: String) -> Self {
42        Self::with_model_internal(
43            api_key,
44            models::huggingface::DEFAULT_MODEL.to_string(),
45            None,
46            None,
47            None,
48        )
49    }
50
51    pub fn with_model(api_key: String, model: String) -> Self {
52        Self::with_model_internal(api_key, model, None, None, None)
53    }
54
55    pub fn with_timeouts(api_key: String, timeouts: TimeoutsConfig) -> Self {
56        Self::with_model_internal(
57            api_key,
58            models::huggingface::DEFAULT_MODEL.to_string(),
59            None,
60            Some(timeouts),
61            None,
62        )
63    }
64
65    fn with_model_internal(
66        api_key: String,
67        model: String,
68        base_url: Option<String>,
69        timeouts: Option<TimeoutsConfig>,
70        model_behavior: Option<ModelConfig>,
71    ) -> Self {
72        use crate::llm::http_client::HttpClientFactory;
73
74        let timeouts = timeouts.unwrap_or_default();
75
76        Self {
77            api_key,
78            http_client: HttpClientFactory::for_llm(&timeouts),
79            base_url: override_base_url(
80                urls::HUGGINGFACE_API_BASE,
81                base_url,
82                Some(env_vars::HUGGINGFACE_BASE_URL),
83            ),
84            model,
85            _timeouts: timeouts,
86            model_behavior,
87        }
88    }
89
90    pub fn from_config(
91        api_key: Option<String>,
92        model: Option<String>,
93        base_url: Option<String>,
94        _prompt_cache: Option<PromptCachingConfig>,
95        timeouts: Option<TimeoutsConfig>,
96        _anthropic: Option<AnthropicConfig>,
97        model_behavior: Option<ModelConfig>,
98    ) -> Self {
99        let api_key_value = api_key.unwrap_or_default();
100        let model_value = resolve_model(model, models::huggingface::DEFAULT_MODEL);
101        Self::with_model_internal(
102            api_key_value,
103            model_value,
104            base_url,
105            timeouts,
106            model_behavior,
107        )
108    }
109
110    fn normalize_model_id(&self, model: &str) -> Result<String, LLMError> {
111        let model = model.trim();
112        let lower = model.to_ascii_lowercase();
113
114        if lower.starts_with(&models::huggingface::STEP_3_5_FLASH_BASE.to_ascii_lowercase()) {
115            if !model.contains(':') {
116                return Ok(format!(
117                    "{}:{}",
118                    models::huggingface::STEP_3_5_FLASH_BASE,
119                    models::huggingface::STEP_3_5_FLASH_PROVIDER
120                ));
121            }
122            if let Some((base, provider)) = model.rsplit_once(':')
123                && provider.eq_ignore_ascii_case("fastest")
124            {
125                return Ok(format!(
126                    "{}:{}",
127                    base,
128                    models::huggingface::STEP_3_5_FLASH_PROVIDER
129                ));
130            }
131        }
132
133        if lower.contains("minimax-m2") && !model.contains(':') {
134            return Err(LLMError::Provider {
135                message: format_llm_error(
136                    PROVIDER_NAME,
137                    "MiniMax models require explicit provider selection (:novita suffix). \n                    Use 'MiniMaxAI/MiniMax-M2.5:novita'.",
138                ),
139                metadata: None,
140            });
141        }
142
143        if lower.contains("glm-5") && !model.contains(':') {
144            return Err(LLMError::Provider {
145                message: format_llm_error(
146                    PROVIDER_NAME,
147                    "GLM models require explicit provider selection on HuggingFace.",
148                ),
149                metadata: None,
150            });
151        }
152
153        Ok(model.to_string())
154    }
155
156    fn serialize_tools_huggingface(&self, tools: &[ToolDefinition]) -> Option<Vec<Value>> {
157        crate::llm::providers::common::serialize_tools_openai_format(tools)
158    }
159
160    fn serialize_messages_huggingface_chat(
161        &self,
162        request: &LLMRequest,
163    ) -> Result<Vec<Value>, LLMError> {
164        use serde_json::{Map, json};
165
166        let mut messages = Vec::with_capacity(request.messages.len());
167
168        for message in &request.messages {
169            message
170                .validate_for_provider(PROVIDER_KEY)
171                .map_err(|e| LLMError::InvalidRequest {
172                    message: e,
173                    metadata: None,
174                })?;
175
176            let mut message_map = Map::with_capacity(4);
177            message_map.insert(
178                "role".to_owned(),
179                Value::String(message.role.as_generic_str().to_owned()),
180            );
181
182            if let Some(interleaved_content) =
183                assistant_interleaved_history_text(message, &request.model)
184            {
185                message_map.insert("content".to_owned(), Value::String(interleaved_content));
186            } else {
187                match &message.content {
188                    crate::llm::provider::MessageContent::Text(text) => {
189                        message_map.insert("content".to_owned(), Value::String(text.clone()));
190                    }
191                    crate::llm::provider::MessageContent::Parts(parts) => {
192                        let has_images = parts
193                            .iter()
194                            .any(crate::llm::provider::ContentPart::is_image);
195                        if has_images {
196                            let parts_json: Vec<Value> = parts
197                            .iter()
198                            .map(|part| match part {
199                                crate::llm::provider::ContentPart::Text { text } => {
200                                    json!({ "type": "text", "text": text })
201                                }
202                                crate::llm::provider::ContentPart::Image {
203                                    data,
204                                    mime_type,
205                                    ..
206                                } => {
207                                    json!({
208                                        "type": "image_url",
209                                        "image_url": {
210                                            "url": format!("data:{};base64,{}", mime_type, data)
211                                        }
212                                    })
213                                }
214                                crate::llm::provider::ContentPart::File {
215                                    filename,
216                                    file_id,
217                                    file_url,
218                                    ..
219                                } => {
220                                    let fallback = filename
221                                        .clone()
222                                        .or_else(|| file_id.clone())
223                                        .or_else(|| file_url.clone())
224                                        .unwrap_or_else(|| "attached file".to_string());
225                                    json!({ "type": "text", "text": format!("[File input not directly supported: {}]", fallback) })
226                                }
227                            })
228                            .collect();
229                            message_map.insert("content".to_owned(), Value::Array(parts_json));
230                        } else {
231                            let text = message.content.as_text().into_owned();
232                            message_map.insert("content".to_owned(), Value::String(text));
233                        }
234                    }
235                }
236            }
237
238            if let Some(tool_calls) = &message.tool_calls {
239                let serialized_calls = tool_calls
240                    .iter()
241                    .filter_map(|call| {
242                        call.function.as_ref().map(|func| {
243                            json!({
244                                "id": &call.id,
245                                "type": "function",
246                                "function": {
247                                    "name": &func.name,
248                                    "arguments": &func.arguments
249                                }
250                            })
251                        })
252                    })
253                    .collect::<Vec<_>>();
254                message_map.insert("tool_calls".to_owned(), Value::Array(serialized_calls));
255            }
256
257            if let Some(tool_call_id) = &message.tool_call_id {
258                message_map.insert(
259                    "tool_call_id".to_owned(),
260                    Value::String(tool_call_id.clone()),
261                );
262            }
263
264            if message.role == MessageRole::Assistant
265                && is_minimax_m2_model(&request.model)
266                && let Some(reasoning_details) = &message.reasoning_details
267                && !reasoning_details.is_empty()
268            {
269                let normalized_details = normalize_reasoning_detail_objects(reasoning_details);
270                if !normalized_details.is_empty() {
271                    message_map.insert(
272                        "reasoning_details".to_owned(),
273                        Value::Array(normalized_details),
274                    );
275                }
276            }
277
278            messages.push(Value::Object(message_map));
279        }
280
281        Ok(messages)
282    }
283
284    fn format_for_chat_completions(&self, request: &LLMRequest) -> Result<Value, LLMError> {
285        let mut messages = self.serialize_messages_huggingface_chat(request)?;
286        let is_glm = self.is_glm_model(&request.model);
287
288        if let Some(system) = &request.system_prompt {
289            let has_system = messages
290                .first()
291                .and_then(|m| m.get("role"))
292                .and_then(|r| r.as_str())
293                == Some("system");
294            if !has_system {
295                messages.insert(
296                    0,
297                    json!({
298                        "role": "system",
299                        "content": system
300                    }),
301                );
302            }
303        }
304
305        let mut payload = json!({
306            "model": request.model,
307            "messages": messages,
308            "stream": request.stream,
309        });
310
311        if request.stream && request.tools.is_some() && is_glm {
312            payload["tool_stream"] = json!(true);
313        }
314
315        if let Some(max_tokens) = request.max_tokens {
316            payload["max_tokens"] = json!(max_tokens);
317        }
318
319        if let Some(tools) = &request.tools {
320            if let Some(serialized) = self.serialize_tools_huggingface(tools) {
321                payload["tools"] = json!(serialized);
322
323                if let Some(choice) = &request.tool_choice {
324                    payload["tool_choice"] = choice.to_provider_format("openai");
325                }
326            }
327        }
328
329        if let Some(temperature) = request.temperature {
330            payload["temperature"] = json!(temperature);
331        }
332
333        if let Some(top_p) = request.top_p {
334            payload["top_p"] = json!(top_p);
335        }
336
337        if let Some(top_k) = request.top_k {
338            payload["top_k"] = json!(top_k);
339        }
340
341        if let Some(effort) = request.reasoning_effort {
342            use crate::config::models::Provider;
343            use crate::llm::rig_adapter::RigProviderCapabilities;
344            if let Some(reasoning_params) =
345                RigProviderCapabilities::new(Provider::HuggingFace, &request.model)
346                    .reasoning_parameters(effort)
347            {
348                if let Some(params_obj) = reasoning_params.as_object() {
349                    for (k, v) in params_obj {
350                        payload[k] = v.clone();
351                    }
352                }
353            }
354        }
355
356        if request.output_format.is_some() && !is_glm {
357            payload["response_format"] = json!({ "type": "json_object" });
358        }
359
360        Ok(payload)
361    }
362
363    fn is_glm_model(&self, model: &str) -> bool {
364        let lower = model.to_ascii_lowercase();
365        lower.contains("glm")
366    }
367
368    fn is_deepseek_model(&self, model: &str) -> bool {
369        let lower = model.to_ascii_lowercase();
370        lower.contains("deepseek")
371    }
372
373    fn is_minimax_model(&self, model: &str) -> bool {
374        let lower = model.to_ascii_lowercase();
375        lower.contains("minimax")
376    }
377
378    fn apply_model_defaults(&self, request: &mut LLMRequest) {
379        if self.is_minimax_model(&request.model) {
380            if request.temperature.is_none() {
381                request.temperature = Some(1.0);
382            }
383            if request.top_p.is_none() {
384                request.top_p = Some(0.95);
385            }
386            if request.top_k.is_none() {
387                request.top_k = Some(40);
388            }
389        }
390    }
391
392    fn add_json_instruction(&self, payload: &mut Value) -> Result<(), LLMError> {
393        if let Some(instructions) = payload.get_mut("instructions") {
394            if let Some(text) = instructions.as_str() {
395                if !text.contains("Return JSON") {
396                    *instructions = json!(format!("{}\n\n{}", text, JSON_INSTRUCTION));
397                }
398            }
399        } else {
400            payload["instructions"] = json!(JSON_INSTRUCTION);
401        }
402
403        Ok(())
404    }
405
406    fn format_for_responses_api(&self, request: &LLMRequest) -> Result<Value, LLMError> {
407        let mut input = Vec::new();
408
409        for msg in &request.messages {
410            let convert_parts = |parts: &[crate::llm::provider::ContentPart]| -> Value {
411                let parts_json: Vec<Value> = parts
412                    .iter()
413                    .map(|part| match part {
414                        crate::llm::provider::ContentPart::Text { text } => {
415                            json!({ "type": "input_text", "text": text })
416                        }
417                        crate::llm::provider::ContentPart::Image {
418                            data, mime_type, ..
419                        } => {
420                            json!({
421                                "type": "input_image",
422                                "image_url": format!("data:{};base64,{}", mime_type, data)
423                            })
424                        }
425                        crate::llm::provider::ContentPart::File {
426                            filename,
427                            file_id,
428                            file_url,
429                            ..
430                        } => {
431                            let fallback = filename
432                                .clone()
433                                .or_else(|| file_id.clone())
434                                .or_else(|| file_url.clone())
435                                .unwrap_or_else(|| "attached file".to_string());
436                            json!({
437                                "type": "input_text",
438                                "text": format!("[File input not directly supported: {}]", fallback)
439                            })
440                        }
441                    })
442                    .collect();
443                json!(parts_json)
444            };
445
446            match msg.role {
447                MessageRole::System | MessageRole::User => {
448                    if msg.role == MessageRole::System && request.system_prompt.is_some() {
449                        if let crate::llm::provider::MessageContent::Text(text) = &msg.content {
450                            if request.system_prompt.as_ref().map(|s| s.as_str())
451                                == Some(text.as_str())
452                            {
453                                continue;
454                            }
455                        }
456                    }
457
458                    let role = if msg.role == MessageRole::System {
459                        "system"
460                    } else {
461                        "user"
462                    };
463
464                    let mut message_obj = json!({
465                        "type": "message",
466                        "role": role,
467                    });
468
469                    match &msg.content {
470                        crate::llm::provider::MessageContent::Text(text) => {
471                            message_obj["content"] = json!(text);
472                        }
473                        crate::llm::provider::MessageContent::Parts(parts) => {
474                            message_obj["content"] = convert_parts(parts);
475                        }
476                    }
477
478                    input.push(message_obj);
479                }
480                MessageRole::Assistant => {
481                    let has_content = match &msg.content {
482                        crate::llm::provider::MessageContent::Text(text) => !text.is_empty(),
483                        crate::llm::provider::MessageContent::Parts(parts) => !parts.is_empty(),
484                    };
485
486                    if has_content {
487                        let mut message_obj = json!({
488                            "type": "message",
489                            "role": "assistant",
490                        });
491
492                        match &msg.content {
493                            crate::llm::provider::MessageContent::Text(text) => {
494                                message_obj["content"] = json!(text);
495                            }
496                            crate::llm::provider::MessageContent::Parts(parts) => {
497                                message_obj["content"] = convert_parts(parts);
498                            }
499                        }
500
501                        input.push(message_obj);
502                    }
503
504                    if let Some(tool_calls) = &msg.tool_calls {
505                        for tc in tool_calls {
506                            if let Some(func) = &tc.function {
507                                input.push(json!({
508                                    "type": "function_call",
509                                    "call_id": tc.id,
510                                    "name": func.name,
511                                    "arguments": func.arguments
512                                }));
513                            }
514                        }
515                    }
516                }
517                MessageRole::Tool => {
518                    input.push(json!({
519                        "type": "function_call_output",
520                        "call_id": msg.tool_call_id.clone().unwrap_or_default(),
521                        "output": function_output_value_from_message_content(&msg.content)
522                    }));
523                }
524            }
525        }
526
527        let mut payload = json!({
528            "model": request.model,
529            "input": input,
530            "stream": request.stream,
531        });
532
533        if let Some(system_prompt) = &request.system_prompt {
534            payload["instructions"] = json!(system_prompt);
535        }
536
537        if let Some(effort) = request.reasoning_effort {
538            use crate::config::types::ReasoningEffortLevel;
539            if effort != ReasoningEffortLevel::None {
540                payload["reasoning"] = json!({ "effort": effort.as_str() });
541            }
542        }
543
544        if let Some(max_tokens) = request.max_tokens {
545            payload["max_tokens"] = json!(max_tokens);
546        }
547        if let Some(temperature) = request.temperature {
548            payload["temperature"] = json!(temperature);
549        }
550        if let Some(top_p) = request.top_p {
551            payload["top_p"] = json!(top_p);
552        }
553        if let Some(top_k) = request.top_k {
554            payload["top_k"] = json!(top_k);
555        }
556
557        if let Some(tools) = &request.tools {
558            if let Some(serialized) = self.serialize_tools_huggingface(tools) {
559                payload["tools"] = json!(serialized);
560
561                if let Some(choice) = &request.tool_choice {
562                    payload["tool_choice"] = choice.to_provider_format("openai");
563                }
564            }
565        }
566
567        if request.output_format.is_some() || request.tools.is_some() {
568            self.add_json_instruction(&mut payload)?;
569        }
570
571        if request.output_format.is_some() && !self.is_glm_model(&request.model) {
572            payload["response_format"] = json!({ "type": "json_object" });
573        }
574
575        Ok(payload)
576    }
577
578    fn should_use_responses_api(&self, _request: &LLMRequest) -> bool {
579        false
580    }
581
582    fn format_error(&self, status: StatusCode, body: &str) -> LLMError {
583        let message = if body.contains("\"code\":\"model_not_supported\"")
584            && body.contains(models::huggingface::STEP_3_5_FLASH_BASE)
585        {
586            format!(
587                "HuggingFace API error ({}): Step 3.5 Flash requires the '{}' provider. \
588Enable that provider in your HuggingFace Inference Providers settings, or switch to another model.",
589                status,
590                models::huggingface::STEP_3_5_FLASH_PROVIDER
591            )
592        } else {
593            format!("HuggingFace API error ({}): {}", status, body)
594        };
595
596        LLMError::Provider {
597            message: format_llm_error(PROVIDER_NAME, &message),
598            metadata: Some(LLMErrorMetadata::new(
599                PROVIDER_NAME,
600                Some(status.as_u16()),
601                None,
602                None,
603                None,
604                None,
605                Some(body.to_string()),
606            )),
607        }
608    }
609
610    fn parse_responses_api_format(json: &Value, model: String) -> Result<LLMResponse, LLMError> {
611        let convenience_text = json.get("output_text").and_then(|t| t.as_str());
612
613        let json_obj = json.get("response").unwrap_or(json);
614
615        let output = json_obj.get("output").and_then(|v| v.as_array());
616
617        let output_arr = match output {
618            Some(arr) => arr,
619            None => {
620                if let Some(text) = convenience_text {
621                    return Ok(LLMResponse {
622                        content: Some(text.to_string()),
623                        tool_calls: None,
624                        model,
625                        usage: None,
626                        finish_reason: crate::llm::provider::FinishReason::Stop,
627                        reasoning: None,
628                        reasoning_details: None,
629                        tool_references: Vec::new(),
630                        request_id: None,
631                        organization_id: None,
632                        compaction: None,
633                    });
634                }
635
636                return Err(LLMError::Provider {
637                    message: format_llm_error(PROVIDER_NAME, "Not a Responses API format"),
638                    metadata: None,
639                });
640            }
641        };
642
643        let mut content_fragments: Vec<String> = Vec::new();
644        let mut reasoning_fragments: Vec<String> = Vec::new();
645        let mut tool_calls: Vec<crate::llm::provider::ToolCall> = Vec::new();
646
647        for item in output_arr {
648            let item_type = item.get("type").and_then(|t| t.as_str()).unwrap_or("");
649
650            match item_type {
651                "message" => {
652                    if let Some(content_arr) = item.get("content").and_then(|c| c.as_array()) {
653                        for entry in content_arr {
654                            let entry_type =
655                                entry.get("type").and_then(|t| t.as_str()).unwrap_or("");
656                            match entry_type {
657                                "text" | "output_text" => {
658                                    if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
659                                        if !text.is_empty() {
660                                            content_fragments.push(text.to_string());
661                                        }
662                                    }
663                                }
664                                "reasoning" => {
665                                    if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
666                                        if !text.is_empty() {
667                                            reasoning_fragments.push(text.to_string());
668                                        }
669                                    }
670                                }
671                                "function_call" | "tool_call" => {
672                                    if let Some(call) = Self::parse_responses_tool_call(entry) {
673                                        tool_calls.push(call);
674                                    }
675                                }
676                                _ => {}
677                            }
678                        }
679                    }
680                }
681                "function_call" | "tool_call" => {
682                    if let Some(call) = Self::parse_responses_tool_call(item) {
683                        tool_calls.push(call);
684                    }
685                }
686                "reasoning" => {
687                    if let Some(summary_arr) = item.get("summary").and_then(|s| s.as_array()) {
688                        for summary in summary_arr {
689                            if let Some(text) = summary.get("text").and_then(|t| t.as_str()) {
690                                if !text.is_empty() {
691                                    reasoning_fragments.push(text.to_string());
692                                }
693                            }
694                        }
695                    } else if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
696                        reasoning_fragments.push(text.to_string());
697                    }
698                }
699                _ => {}
700            }
701        }
702
703        let content = if content_fragments.is_empty() {
704            convenience_text.map(|t| t.to_string())
705        } else {
706            Some(content_fragments.join(""))
707        };
708
709        let reasoning = if reasoning_fragments.is_empty() {
710            None
711        } else {
712            Some(reasoning_fragments.join("\n\n"))
713        };
714
715        let finish_reason = if !tool_calls.is_empty() {
716            crate::llm::provider::FinishReason::ToolCalls
717        } else {
718            crate::llm::provider::FinishReason::Stop
719        };
720
721        let usage_value = json.get("usage").or_else(|| json_obj.get("usage"));
722        let usage = usage_value.map(|usage_value| crate::llm::provider::Usage {
723            prompt_tokens: usage_value
724                .get("input_tokens")
725                .or_else(|| usage_value.get("prompt_tokens"))
726                .and_then(|pt| pt.as_u64())
727                .unwrap_or(0) as u32,
728            completion_tokens: usage_value
729                .get("output_tokens")
730                .or_else(|| usage_value.get("completion_tokens"))
731                .and_then(|ct| ct.as_u64())
732                .unwrap_or(0) as u32,
733            total_tokens: usage_value
734                .get("total_tokens")
735                .and_then(|tt| tt.as_u64())
736                .unwrap_or(0) as u32,
737            cached_prompt_tokens: None,
738            cache_creation_tokens: None,
739            cache_read_tokens: None,
740        });
741
742        Ok(LLMResponse {
743            content,
744            tool_calls: if tool_calls.is_empty() {
745                None
746            } else {
747                Some(tool_calls)
748            },
749            model,
750            usage,
751            finish_reason,
752            reasoning,
753            reasoning_details: None,
754            tool_references: Vec::new(),
755            request_id: None,
756            organization_id: None,
757            compaction: None,
758        })
759    }
760
761    fn parse_responses_tool_call(item: &Value) -> Option<crate::llm::provider::ToolCall> {
762        let call_id = item.get("id").and_then(|v| v.as_str()).unwrap_or("");
763        let function_obj = item.get("function").and_then(|v| v.as_object());
764        let name = function_obj.and_then(|f| f.get("name").and_then(|n| n.as_str()))?;
765        let arguments = function_obj.and_then(|f| f.get("arguments"));
766
767        let serialized = arguments.map_or("{}".to_owned(), |args| {
768            if args.is_string() {
769                args.as_str().unwrap_or("{}").to_string()
770            } else {
771                args.to_string()
772            }
773        });
774
775        Some(crate::llm::provider::ToolCall::function(
776            call_id.to_string(),
777            name.to_string(),
778            serialized,
779        ))
780    }
781
782    async fn parse_response(
783        &self,
784        response: Response,
785        model: String,
786        use_responses_api: bool,
787    ) -> Result<LLMResponse, LLMError> {
788        let status = response.status();
789
790        if !status.is_success() {
791            let body = response.text().await.unwrap_or_default();
792            return Err(self.format_error(status, &body));
793        }
794
795        let json: Value = response
796            .json()
797            .await
798            .map_err(|err| format_parse_error(PROVIDER_NAME, &err))?;
799
800        if use_responses_api {
801            if json.get("output").is_some() {
802                return Self::parse_responses_api_format(&json, model);
803            }
804        }
805
806        parse_response_openai_format::<fn(&Value, &Value) -> Option<String>>(
807            json,
808            PROVIDER_NAME,
809            model,
810            false,
811            None,
812        )
813    }
814
815    pub fn available_models() -> Vec<String> {
816        models::huggingface::SUPPORTED_MODELS
817            .iter()
818            .map(|s| s.to_string())
819            .collect()
820    }
821
822    fn get_endpoint(&self, use_responses_api: bool) -> String {
823        let base = self.base_url.trim_end_matches('/');
824        if use_responses_api {
825            format!("{}/responses", base)
826        } else {
827            format!("{}/chat/completions", base)
828        }
829    }
830}
831
832#[async_trait]
833impl LLMProvider for HuggingFaceProvider {
834    fn name(&self) -> &str {
835        PROVIDER_KEY
836    }
837
838    fn supports_streaming(&self) -> bool {
839        true
840    }
841
842    fn supports_reasoning(&self, model: &str) -> bool {
843        // Codex-inspired robustness: Setting model_supports_reasoning to false
844        // does NOT disable it for known reasoning models.
845        models::huggingface::REASONING_MODELS.contains(&model)
846            || self
847                .model_behavior
848                .as_ref()
849                .and_then(|b| b.model_supports_reasoning)
850                .unwrap_or(false)
851    }
852
853    fn supports_reasoning_effort(&self, model: &str) -> bool {
854        // Same robustness logic for reasoning effort
855        self.is_glm_model(model)
856            || self.is_deepseek_model(model)
857            || self
858                .model_behavior
859                .as_ref()
860                .and_then(|b| b.model_supports_reasoning_effort)
861                .unwrap_or(false)
862    }
863
864    fn supports_tools(&self, _model: &str) -> bool {
865        true
866    }
867
868    fn supports_parallel_tool_config(&self, _model: &str) -> bool {
869        false
870    }
871
872    fn supports_structured_output(&self, _model: &str) -> bool {
873        true
874    }
875
876    fn supports_context_caching(&self, _model: &str) -> bool {
877        false
878    }
879
880    fn effective_context_size(&self, _model: &str) -> usize {
881        128_000
882    }
883
884    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
885        let model = ensure_model(&mut request, &self.model);
886
887        self.apply_model_defaults(&mut request);
888        self.validate_request(&request)?;
889
890        let model_id = self.normalize_model_id(&request.model)?;
891        request.model = model_id;
892
893        let use_responses_api = self.should_use_responses_api(&request);
894        let payload = if use_responses_api {
895            self.format_for_responses_api(&request)?
896        } else {
897            self.format_for_chat_completions(&request)?
898        };
899
900        let endpoint = self.get_endpoint(use_responses_api);
901
902        let response = self
903            .http_client
904            .post(&endpoint)
905            .header("Authorization", format!("Bearer {}", self.api_key))
906            .json(&payload)
907            .send()
908            .await
909            .map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
910
911        self.parse_response(response, model, use_responses_api)
912            .await
913    }
914
915    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
916        let model = ensure_model(&mut request, &self.model);
917
918        self.apply_model_defaults(&mut request);
919        self.validate_request(&request)?;
920        request.stream = true;
921
922        let model_id = self.normalize_model_id(&request.model)?;
923        request.model = model_id;
924
925        let use_responses_api = self.should_use_responses_api(&request);
926        let payload = if use_responses_api {
927            self.format_for_responses_api(&request)?
928        } else {
929            self.format_for_chat_completions(&request)?
930        };
931
932        let endpoint = self.get_endpoint(use_responses_api);
933
934        let response = self
935            .http_client
936            .post(&endpoint)
937            .header("Authorization", format!("Bearer {}", self.api_key))
938            .json(&payload)
939            .send()
940            .await
941            .map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
942
943        if !response.status().is_success() {
944            let status = response.status();
945            let body = response.text().await.unwrap_or_default();
946            return Err(self.format_error(status, &body));
947        }
948
949        self.create_stream(response, model, use_responses_api).await
950    }
951
952    fn supported_models(&self) -> Vec<String> {
953        Self::available_models()
954    }
955
956    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
957        if request.messages.is_empty() {
958            return Err(LLMError::InvalidRequest {
959                message: format_llm_error(PROVIDER_NAME, "Messages cannot be empty"),
960                metadata: None,
961            });
962        }
963
964        if request.model.trim().is_empty() {
965            return Err(LLMError::InvalidRequest {
966                message: format_llm_error(PROVIDER_NAME, "Model identifier cannot be empty"),
967                metadata: None,
968            });
969        }
970
971        Ok(())
972    }
973}
974
975impl HuggingFaceProvider {
976    async fn create_stream(
977        &self,
978        response: Response,
979        model: String,
980        use_responses_api: bool,
981    ) -> Result<LLMStream, LLMError> {
982        let mut bytes_stream = response.bytes_stream();
983        let mut buffer = String::with_capacity(4096);
984        let mut aggregator = crate::llm::providers::shared::StreamAggregator::new(model.clone());
985        let telemetry = NoopStreamTelemetry;
986
987        let stream = try_stream! {
988            'outer: while let Some(chunk_result) = bytes_stream.next().await {
989                let chunk = chunk_result.map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
990                let text = String::from_utf8_lossy(&chunk);
991                buffer.push_str(&text);
992
993                if buffer.len() > 128_000 {
994                    Err(LLMError::Provider {
995                        message: format_llm_error(PROVIDER_NAME, "Stream buffer exceeded maximum size (128KB)"),
996                        metadata: None,
997                    })?;
998                }
999
1000                while let Some(newline_pos) = buffer.find('\n') {
1001                    let line = buffer[..newline_pos].trim().to_string();
1002                    buffer.drain(..=newline_pos);
1003
1004                    if line.is_empty() || line.starts_with(':') {
1005                        continue;
1006                    }
1007
1008                    let data = if let Some(stripped) = line.strip_prefix("data: ") {
1009                        stripped
1010                    } else {
1011                        continue;
1012                    };
1013
1014                    if data == "[DONE]" {
1015                        break 'outer;
1016                    }
1017
1018                    let event: Value = match serde_json::from_str(data) {
1019                        Ok(v) => v,
1020                        Err(_) => continue,
1021                    };
1022
1023                    if use_responses_api {
1024                        let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
1025
1026                        match event_type {
1027                            "response.output_text.delta" | "output_text.delta" => {
1028                                if let Some(delta) = event.get("delta").and_then(|d| d.as_str()) {
1029                                    telemetry.on_content_delta(delta);
1030                                    for ev in aggregator.handle_content(delta) {
1031                                        yield ev;
1032                                    }
1033                                }
1034                                continue;
1035                            }
1036                            "response.reasoning.delta" | "reasoning.delta" => {
1037                                if let Some(delta) = event.get("delta").and_then(|d| d.as_str()) {
1038                                    if let Some(d) = aggregator.handle_reasoning(delta) {
1039                                        telemetry.on_reasoning_delta(&d);
1040                                        yield LLMStreamEvent::Reasoning { delta: d };
1041                                    }
1042                                }
1043                                continue;
1044                            }
1045                            "response.function_call_arguments.delta" | "tool_call.delta" => {
1046                                telemetry.on_tool_call_delta();
1047                                continue;
1048                            }
1049                            "response.completed" => {
1050                                if let Some(response_obj) = event.get("response") {
1051                                    if let Ok(response) = Self::parse_responses_api_format(response_obj, model.clone()) {
1052                                        let final_agg_response = aggregator.finalize();
1053                                        let mut merged_response = response;
1054                                        if merged_response.content.is_none() {
1055                                            merged_response.content = final_agg_response.content;
1056                                        }
1057                                        if merged_response.reasoning.is_none() {
1058                                            merged_response.reasoning = final_agg_response.reasoning;
1059                                        }
1060                                        if merged_response.tool_calls.is_none() {
1061                                            merged_response.tool_calls = final_agg_response.tool_calls;
1062                                        }
1063                                        if merged_response.usage.is_none() {
1064                                            merged_response.usage = final_agg_response.usage;
1065                                        }
1066                                        yield LLMStreamEvent::Completed { response: Box::new(merged_response) };
1067                                        return;
1068                                    }
1069                                }
1070                                break 'outer;
1071                            }
1072                            "response.done" => {
1073                                break 'outer;
1074                            }
1075                            _ => {}
1076                        }
1077                    }
1078
1079                    if let Some(choices_arr) = event.get("choices").and_then(|c| c.as_array()) {
1080                        if let Some(choice) = choices_arr.first() {
1081                            if let Some(delta_obj) = choice.get("delta") {
1082                                if let Some(content) = delta_obj.get("content").and_then(|c| c.as_str()) {
1083                                    telemetry.on_content_delta(content);
1084                                    for ev in aggregator.handle_content(content) {
1085                                        yield ev;
1086                                    }
1087                                }
1088
1089                                if let Some(reason) = delta_obj.get("reasoning_content").and_then(|r| r.as_str()) {
1090                                    if let Some(d) = aggregator.handle_reasoning(reason) {
1091                                        telemetry.on_reasoning_delta(&d);
1092                                        yield LLMStreamEvent::Reasoning { delta: d };
1093                                    }
1094                                }
1095
1096                                if let Some(reasoning_details) = delta_obj
1097                                    .get("reasoning_details")
1098                                    .and_then(|details| details.as_array())
1099                                {
1100                                    aggregator.set_reasoning_details(reasoning_details);
1101                                }
1102
1103                                if let Some(tool_calls_arr) = delta_obj.get("tool_calls").and_then(|tc| tc.as_array()) {
1104                                    aggregator.handle_tool_calls(tool_calls_arr);
1105                                    telemetry.on_tool_call_delta();
1106                                }
1107                            }
1108
1109                            if let Some(finish_reason_str) = choice.get("finish_reason").and_then(|fr| fr.as_str()) {
1110                                aggregator.set_finish_reason(map_finish_reason_common(finish_reason_str));
1111                                if let Some(usage_value) = event.get("usage") {
1112                                    aggregator.set_usage(crate::llm::provider::Usage {
1113                                        prompt_tokens: usage_value.get("prompt_tokens").and_then(|pt| pt.as_u64()).unwrap_or(0) as u32,
1114                                        completion_tokens: usage_value.get("completion_tokens").and_then(|ct| ct.as_u64()).unwrap_or(0) as u32,
1115                                        total_tokens: usage_value.get("total_tokens").and_then(|tt| tt.as_u64()).unwrap_or(0) as u32,
1116                                        cached_prompt_tokens: None,
1117                                        cache_creation_tokens: None,
1118                                        cache_read_tokens: None,
1119                                    });
1120                                }
1121
1122                                break 'outer;
1123                            }
1124                        }
1125                    }
1126                }
1127            }
1128
1129            yield LLMStreamEvent::Completed { response: Box::new(aggregator.finalize()) };
1130        };
1131
1132        Ok(Box::pin(stream))
1133    }
1134}
1135
1136impl_llm_client!(HuggingFaceProvider);
1137
#[cfg(test)]
mod tests {
    use super::HuggingFaceProvider;
    use crate::llm::provider::{LLMRequest, Message, ToolDefinition};
    use crate::llm::providers::common::{is_minimax_m2_model, normalize_reasoning_detail_object};
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a provider pinned to `model`, authenticated with a dummy key.
    fn provider_for(model: &str) -> HuggingFaceProvider {
        HuggingFaceProvider::with_model("test-key".to_string(), model.to_string())
    }

    #[test]
    fn minimax_model_detection_handles_variants() {
        for variant in ["MiniMaxAI/MiniMax-M2.5:novita", "minimax-m2.5"] {
            assert!(is_minimax_m2_model(variant));
        }
        assert!(!is_minimax_m2_model("deepseek-r1"));
    }

    #[test]
    fn normalize_reasoning_detail_decodes_stringified_object() {
        let stringified = json!("{\"type\":\"reasoning.text\",\"text\":\"step\"}");
        let parsed = normalize_reasoning_detail_object(&stringified)
            .expect("expected a parsed reasoning detail object");

        assert!(parsed.is_object());
        assert_eq!(parsed["type"], "reasoning.text");
    }

    #[test]
    fn serialize_messages_normalizes_minimax_reasoning_details() {
        const MODEL: &str = "MiniMaxAI/MiniMax-M2.5:novita";
        let provider = provider_for(MODEL);

        let stringified_detail = json!("{\"type\":\"reasoning.text\",\"text\":\"chain\"}");
        let assistant = Message::assistant("answer".to_string())
            .with_reasoning_details(Some(vec![stringified_detail]));
        let request = LLMRequest {
            model: MODEL.to_string(),
            messages: vec![assistant],
            ..Default::default()
        };

        let serialized = provider
            .serialize_messages_huggingface_chat(&request)
            .expect("message serialization should succeed");
        let details = &serialized[0]["reasoning_details"];

        assert!(details.is_array());
        assert!(details[0].is_object());
    }

    #[test]
    fn serialize_messages_rehydrates_glm_interleaved_history_into_content() {
        const MODEL: &str = "zai-org/GLM-5:novita";
        let provider = provider_for(MODEL);

        let assistant =
            Message::assistant("done".to_string()).with_reasoning(Some("trace".to_string()));
        let request = LLMRequest {
            model: MODEL.to_string(),
            messages: vec![assistant],
            ..Default::default()
        };

        let serialized = provider
            .serialize_messages_huggingface_chat(&request)
            .expect("message serialization should succeed");

        assert_eq!(serialized[0]["content"], json!("<think>trace</think>done"));
    }

    #[test]
    fn normalize_step35_flash_provider_suffix() {
        const EXPECTED: &str = "stepfun-ai/Step-3.5-Flash:featherless-ai";
        let provider = provider_for("stepfun-ai/Step-3.5-Flash");

        let bare = provider
            .normalize_model_id("stepfun-ai/Step-3.5-Flash")
            .expect("normalization should succeed");
        assert_eq!(bare, EXPECTED.to_string());

        let legacy = provider
            .normalize_model_id("stepfun-ai/Step-3.5-Flash:fastest")
            .expect("legacy suffix normalization should succeed");
        assert_eq!(legacy, EXPECTED.to_string());
    }

    #[test]
    fn format_for_chat_completions_keeps_apply_patch_as_function_tool() {
        const MODEL: &str = "Qwen/Qwen3-Coder-480B-A35B-Instruct";
        let provider = provider_for(MODEL);

        let tools = vec![ToolDefinition::apply_patch("Apply patches".to_string())];
        let request = LLMRequest {
            model: MODEL.to_string(),
            messages: vec![Message::user("apply a patch".to_string())],
            tools: Some(Arc::new(tools)),
            ..Default::default()
        };

        let payload = provider
            .format_for_chat_completions(&request)
            .expect("payload should serialize");
        let first_tool = &payload["tools"][0];

        assert_eq!(first_tool["type"], "function");
        assert_eq!(first_tool["function"]["name"], "apply_patch");
    }
}