1#![allow(clippy::bind_instead_of_map, clippy::collapsible_if)]
2
3use crate::config::TimeoutsConfig;
4use crate::config::constants::{env_vars, models, urls};
5use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
6use crate::llm::error_display::format_llm_error;
7use crate::llm::provider::{
8 LLMError, LLMErrorMetadata, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
9 MessageRole, ToolDefinition,
10};
11use crate::llm::providers::shared::{
12 NoopStreamTelemetry, StreamTelemetry, function_output_value_from_message_content,
13};
14use async_stream::try_stream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use reqwest::{Client as HttpClient, Response, StatusCode};
18use serde_json::{Value, json};
19
20use super::common::{
21 assistant_interleaved_history_text, ensure_model, impl_llm_client, is_minimax_m2_model,
22 map_finish_reason_common, normalize_reasoning_detail_objects, override_base_url,
23 parse_response_openai_format, resolve_model,
24};
25use super::error_handling::{format_network_error, format_parse_error};
26
/// Human-readable provider name used when formatting errors and error metadata.
const PROVIDER_NAME: &str = "HuggingFace";
/// Lowercase provider key returned by `LLMProvider::name` and passed to message validation.
const PROVIDER_KEY: &str = "huggingface";
/// Instruction text that `add_json_instruction` writes into a payload's `instructions` field.
const JSON_INSTRUCTION: &str = "Return JSON that matches the provided schema.";
30
/// LLM provider client for the HuggingFace inference endpoints.
///
/// Bundles the credentials, HTTP client and model defaults used by the
/// request-building and response-parsing methods in this file.
pub struct HuggingFaceProvider {
    /// API key sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// HTTP client produced by `HttpClientFactory::for_llm`.
    http_client: HttpClient,
    /// Base URL to which `/chat/completions` or `/responses` is appended.
    base_url: String,
    /// Provider-level model handed to `ensure_model` alongside each request
    /// (presumably the fallback when the request names no model — confirm in `ensure_model`).
    model: String,
    /// Timeout configuration the HTTP client was built from; not otherwise read in the visible code.
    _timeouts: TimeoutsConfig,
    /// Optional per-model behavior overrides consulted by the `supports_reasoning*` checks.
    model_behavior: Option<ModelConfig>,
}
39
40impl HuggingFaceProvider {
41 pub fn new(api_key: String) -> Self {
42 Self::with_model_internal(
43 api_key,
44 models::huggingface::DEFAULT_MODEL.to_string(),
45 None,
46 None,
47 None,
48 )
49 }
50
51 pub fn with_model(api_key: String, model: String) -> Self {
52 Self::with_model_internal(api_key, model, None, None, None)
53 }
54
55 pub fn with_timeouts(api_key: String, timeouts: TimeoutsConfig) -> Self {
56 Self::with_model_internal(
57 api_key,
58 models::huggingface::DEFAULT_MODEL.to_string(),
59 None,
60 Some(timeouts),
61 None,
62 )
63 }
64
65 fn with_model_internal(
66 api_key: String,
67 model: String,
68 base_url: Option<String>,
69 timeouts: Option<TimeoutsConfig>,
70 model_behavior: Option<ModelConfig>,
71 ) -> Self {
72 use crate::llm::http_client::HttpClientFactory;
73
74 let timeouts = timeouts.unwrap_or_default();
75
76 Self {
77 api_key,
78 http_client: HttpClientFactory::for_llm(&timeouts),
79 base_url: override_base_url(
80 urls::HUGGINGFACE_API_BASE,
81 base_url,
82 Some(env_vars::HUGGINGFACE_BASE_URL),
83 ),
84 model,
85 _timeouts: timeouts,
86 model_behavior,
87 }
88 }
89
90 pub fn from_config(
91 api_key: Option<String>,
92 model: Option<String>,
93 base_url: Option<String>,
94 _prompt_cache: Option<PromptCachingConfig>,
95 timeouts: Option<TimeoutsConfig>,
96 _anthropic: Option<AnthropicConfig>,
97 model_behavior: Option<ModelConfig>,
98 ) -> Self {
99 let api_key_value = api_key.unwrap_or_default();
100 let model_value = resolve_model(model, models::huggingface::DEFAULT_MODEL);
101 Self::with_model_internal(
102 api_key_value,
103 model_value,
104 base_url,
105 timeouts,
106 model_behavior,
107 )
108 }
109
110 fn normalize_model_id(&self, model: &str) -> Result<String, LLMError> {
111 let model = model.trim();
112 let lower = model.to_ascii_lowercase();
113
114 if lower.starts_with(&models::huggingface::STEP_3_5_FLASH_BASE.to_ascii_lowercase()) {
115 if !model.contains(':') {
116 return Ok(format!(
117 "{}:{}",
118 models::huggingface::STEP_3_5_FLASH_BASE,
119 models::huggingface::STEP_3_5_FLASH_PROVIDER
120 ));
121 }
122 if let Some((base, provider)) = model.rsplit_once(':')
123 && provider.eq_ignore_ascii_case("fastest")
124 {
125 return Ok(format!(
126 "{}:{}",
127 base,
128 models::huggingface::STEP_3_5_FLASH_PROVIDER
129 ));
130 }
131 }
132
133 if lower.contains("minimax-m2") && !model.contains(':') {
134 return Err(LLMError::Provider {
135 message: format_llm_error(
136 PROVIDER_NAME,
137 "MiniMax models require explicit provider selection (:novita suffix). \n Use 'MiniMaxAI/MiniMax-M2.5:novita'.",
138 ),
139 metadata: None,
140 });
141 }
142
143 if lower.contains("glm-5") && !model.contains(':') {
144 return Err(LLMError::Provider {
145 message: format_llm_error(
146 PROVIDER_NAME,
147 "GLM models require explicit provider selection on HuggingFace.",
148 ),
149 metadata: None,
150 });
151 }
152
153 Ok(model.to_string())
154 }
155
156 fn serialize_tools_huggingface(&self, tools: &[ToolDefinition]) -> Option<Vec<Value>> {
157 crate::llm::providers::common::serialize_tools_openai_format(tools)
158 }
159
160 fn serialize_messages_huggingface_chat(
161 &self,
162 request: &LLMRequest,
163 ) -> Result<Vec<Value>, LLMError> {
164 use serde_json::{Map, json};
165
166 let mut messages = Vec::with_capacity(request.messages.len());
167
168 for message in &request.messages {
169 message
170 .validate_for_provider(PROVIDER_KEY)
171 .map_err(|e| LLMError::InvalidRequest {
172 message: e,
173 metadata: None,
174 })?;
175
176 let mut message_map = Map::with_capacity(4);
177 message_map.insert(
178 "role".to_owned(),
179 Value::String(message.role.as_generic_str().to_owned()),
180 );
181
182 if let Some(interleaved_content) =
183 assistant_interleaved_history_text(message, &request.model)
184 {
185 message_map.insert("content".to_owned(), Value::String(interleaved_content));
186 } else {
187 match &message.content {
188 crate::llm::provider::MessageContent::Text(text) => {
189 message_map.insert("content".to_owned(), Value::String(text.clone()));
190 }
191 crate::llm::provider::MessageContent::Parts(parts) => {
192 let has_images = parts
193 .iter()
194 .any(crate::llm::provider::ContentPart::is_image);
195 if has_images {
196 let parts_json: Vec<Value> = parts
197 .iter()
198 .map(|part| match part {
199 crate::llm::provider::ContentPart::Text { text } => {
200 json!({ "type": "text", "text": text })
201 }
202 crate::llm::provider::ContentPart::Image {
203 data,
204 mime_type,
205 ..
206 } => {
207 json!({
208 "type": "image_url",
209 "image_url": {
210 "url": format!("data:{};base64,{}", mime_type, data)
211 }
212 })
213 }
214 crate::llm::provider::ContentPart::File {
215 filename,
216 file_id,
217 file_url,
218 ..
219 } => {
220 let fallback = filename
221 .clone()
222 .or_else(|| file_id.clone())
223 .or_else(|| file_url.clone())
224 .unwrap_or_else(|| "attached file".to_string());
225 json!({ "type": "text", "text": format!("[File input not directly supported: {}]", fallback) })
226 }
227 })
228 .collect();
229 message_map.insert("content".to_owned(), Value::Array(parts_json));
230 } else {
231 let text = message.content.as_text().into_owned();
232 message_map.insert("content".to_owned(), Value::String(text));
233 }
234 }
235 }
236 }
237
238 if let Some(tool_calls) = &message.tool_calls {
239 let serialized_calls = tool_calls
240 .iter()
241 .filter_map(|call| {
242 call.function.as_ref().map(|func| {
243 json!({
244 "id": &call.id,
245 "type": "function",
246 "function": {
247 "name": &func.name,
248 "arguments": &func.arguments
249 }
250 })
251 })
252 })
253 .collect::<Vec<_>>();
254 message_map.insert("tool_calls".to_owned(), Value::Array(serialized_calls));
255 }
256
257 if let Some(tool_call_id) = &message.tool_call_id {
258 message_map.insert(
259 "tool_call_id".to_owned(),
260 Value::String(tool_call_id.clone()),
261 );
262 }
263
264 if message.role == MessageRole::Assistant
265 && is_minimax_m2_model(&request.model)
266 && let Some(reasoning_details) = &message.reasoning_details
267 && !reasoning_details.is_empty()
268 {
269 let normalized_details = normalize_reasoning_detail_objects(reasoning_details);
270 if !normalized_details.is_empty() {
271 message_map.insert(
272 "reasoning_details".to_owned(),
273 Value::Array(normalized_details),
274 );
275 }
276 }
277
278 messages.push(Value::Object(message_map));
279 }
280
281 Ok(messages)
282 }
283
284 fn format_for_chat_completions(&self, request: &LLMRequest) -> Result<Value, LLMError> {
285 let mut messages = self.serialize_messages_huggingface_chat(request)?;
286 let is_glm = self.is_glm_model(&request.model);
287
288 if let Some(system) = &request.system_prompt {
289 let has_system = messages
290 .first()
291 .and_then(|m| m.get("role"))
292 .and_then(|r| r.as_str())
293 == Some("system");
294 if !has_system {
295 messages.insert(
296 0,
297 json!({
298 "role": "system",
299 "content": system
300 }),
301 );
302 }
303 }
304
305 let mut payload = json!({
306 "model": request.model,
307 "messages": messages,
308 "stream": request.stream,
309 });
310
311 if request.stream && request.tools.is_some() && is_glm {
312 payload["tool_stream"] = json!(true);
313 }
314
315 if let Some(max_tokens) = request.max_tokens {
316 payload["max_tokens"] = json!(max_tokens);
317 }
318
319 if let Some(tools) = &request.tools {
320 if let Some(serialized) = self.serialize_tools_huggingface(tools) {
321 payload["tools"] = json!(serialized);
322
323 if let Some(choice) = &request.tool_choice {
324 payload["tool_choice"] = choice.to_provider_format("openai");
325 }
326 }
327 }
328
329 if let Some(temperature) = request.temperature {
330 payload["temperature"] = json!(temperature);
331 }
332
333 if let Some(top_p) = request.top_p {
334 payload["top_p"] = json!(top_p);
335 }
336
337 if let Some(top_k) = request.top_k {
338 payload["top_k"] = json!(top_k);
339 }
340
341 if let Some(effort) = request.reasoning_effort {
342 use crate::config::models::Provider;
343 use crate::llm::rig_adapter::RigProviderCapabilities;
344 if let Some(reasoning_params) =
345 RigProviderCapabilities::new(Provider::HuggingFace, &request.model)
346 .reasoning_parameters(effort)
347 {
348 if let Some(params_obj) = reasoning_params.as_object() {
349 for (k, v) in params_obj {
350 payload[k] = v.clone();
351 }
352 }
353 }
354 }
355
356 if request.output_format.is_some() && !is_glm {
357 payload["response_format"] = json!({ "type": "json_object" });
358 }
359
360 Ok(payload)
361 }
362
363 fn is_glm_model(&self, model: &str) -> bool {
364 let lower = model.to_ascii_lowercase();
365 lower.contains("glm")
366 }
367
368 fn is_deepseek_model(&self, model: &str) -> bool {
369 let lower = model.to_ascii_lowercase();
370 lower.contains("deepseek")
371 }
372
373 fn is_minimax_model(&self, model: &str) -> bool {
374 let lower = model.to_ascii_lowercase();
375 lower.contains("minimax")
376 }
377
378 fn apply_model_defaults(&self, request: &mut LLMRequest) {
379 if self.is_minimax_model(&request.model) {
380 if request.temperature.is_none() {
381 request.temperature = Some(1.0);
382 }
383 if request.top_p.is_none() {
384 request.top_p = Some(0.95);
385 }
386 if request.top_k.is_none() {
387 request.top_k = Some(40);
388 }
389 }
390 }
391
392 fn add_json_instruction(&self, payload: &mut Value) -> Result<(), LLMError> {
393 if let Some(instructions) = payload.get_mut("instructions") {
394 if let Some(text) = instructions.as_str() {
395 if !text.contains("Return JSON") {
396 *instructions = json!(format!("{}\n\n{}", text, JSON_INSTRUCTION));
397 }
398 }
399 } else {
400 payload["instructions"] = json!(JSON_INSTRUCTION);
401 }
402
403 Ok(())
404 }
405
406 fn format_for_responses_api(&self, request: &LLMRequest) -> Result<Value, LLMError> {
407 let mut input = Vec::new();
408
409 for msg in &request.messages {
410 let convert_parts = |parts: &[crate::llm::provider::ContentPart]| -> Value {
411 let parts_json: Vec<Value> = parts
412 .iter()
413 .map(|part| match part {
414 crate::llm::provider::ContentPart::Text { text } => {
415 json!({ "type": "input_text", "text": text })
416 }
417 crate::llm::provider::ContentPart::Image {
418 data, mime_type, ..
419 } => {
420 json!({
421 "type": "input_image",
422 "image_url": format!("data:{};base64,{}", mime_type, data)
423 })
424 }
425 crate::llm::provider::ContentPart::File {
426 filename,
427 file_id,
428 file_url,
429 ..
430 } => {
431 let fallback = filename
432 .clone()
433 .or_else(|| file_id.clone())
434 .or_else(|| file_url.clone())
435 .unwrap_or_else(|| "attached file".to_string());
436 json!({
437 "type": "input_text",
438 "text": format!("[File input not directly supported: {}]", fallback)
439 })
440 }
441 })
442 .collect();
443 json!(parts_json)
444 };
445
446 match msg.role {
447 MessageRole::System | MessageRole::User => {
448 if msg.role == MessageRole::System && request.system_prompt.is_some() {
449 if let crate::llm::provider::MessageContent::Text(text) = &msg.content {
450 if request.system_prompt.as_ref().map(|s| s.as_str())
451 == Some(text.as_str())
452 {
453 continue;
454 }
455 }
456 }
457
458 let role = if msg.role == MessageRole::System {
459 "system"
460 } else {
461 "user"
462 };
463
464 let mut message_obj = json!({
465 "type": "message",
466 "role": role,
467 });
468
469 match &msg.content {
470 crate::llm::provider::MessageContent::Text(text) => {
471 message_obj["content"] = json!(text);
472 }
473 crate::llm::provider::MessageContent::Parts(parts) => {
474 message_obj["content"] = convert_parts(parts);
475 }
476 }
477
478 input.push(message_obj);
479 }
480 MessageRole::Assistant => {
481 let has_content = match &msg.content {
482 crate::llm::provider::MessageContent::Text(text) => !text.is_empty(),
483 crate::llm::provider::MessageContent::Parts(parts) => !parts.is_empty(),
484 };
485
486 if has_content {
487 let mut message_obj = json!({
488 "type": "message",
489 "role": "assistant",
490 });
491
492 match &msg.content {
493 crate::llm::provider::MessageContent::Text(text) => {
494 message_obj["content"] = json!(text);
495 }
496 crate::llm::provider::MessageContent::Parts(parts) => {
497 message_obj["content"] = convert_parts(parts);
498 }
499 }
500
501 input.push(message_obj);
502 }
503
504 if let Some(tool_calls) = &msg.tool_calls {
505 for tc in tool_calls {
506 if let Some(func) = &tc.function {
507 input.push(json!({
508 "type": "function_call",
509 "call_id": tc.id,
510 "name": func.name,
511 "arguments": func.arguments
512 }));
513 }
514 }
515 }
516 }
517 MessageRole::Tool => {
518 input.push(json!({
519 "type": "function_call_output",
520 "call_id": msg.tool_call_id.clone().unwrap_or_default(),
521 "output": function_output_value_from_message_content(&msg.content)
522 }));
523 }
524 }
525 }
526
527 let mut payload = json!({
528 "model": request.model,
529 "input": input,
530 "stream": request.stream,
531 });
532
533 if let Some(system_prompt) = &request.system_prompt {
534 payload["instructions"] = json!(system_prompt);
535 }
536
537 if let Some(effort) = request.reasoning_effort {
538 use crate::config::types::ReasoningEffortLevel;
539 if effort != ReasoningEffortLevel::None {
540 payload["reasoning"] = json!({ "effort": effort.as_str() });
541 }
542 }
543
544 if let Some(max_tokens) = request.max_tokens {
545 payload["max_tokens"] = json!(max_tokens);
546 }
547 if let Some(temperature) = request.temperature {
548 payload["temperature"] = json!(temperature);
549 }
550 if let Some(top_p) = request.top_p {
551 payload["top_p"] = json!(top_p);
552 }
553 if let Some(top_k) = request.top_k {
554 payload["top_k"] = json!(top_k);
555 }
556
557 if let Some(tools) = &request.tools {
558 if let Some(serialized) = self.serialize_tools_huggingface(tools) {
559 payload["tools"] = json!(serialized);
560
561 if let Some(choice) = &request.tool_choice {
562 payload["tool_choice"] = choice.to_provider_format("openai");
563 }
564 }
565 }
566
567 if request.output_format.is_some() || request.tools.is_some() {
568 self.add_json_instruction(&mut payload)?;
569 }
570
571 if request.output_format.is_some() && !self.is_glm_model(&request.model) {
572 payload["response_format"] = json!({ "type": "json_object" });
573 }
574
575 Ok(payload)
576 }
577
    /// Decides whether a request is sent to the `/responses` endpoint rather
    /// than `/chat/completions`.
    ///
    /// Always returns `false` and ignores the request, so callers of this
    /// method take the chat-completions path.
    fn should_use_responses_api(&self, _request: &LLMRequest) -> bool {
        false
    }
581
582 fn format_error(&self, status: StatusCode, body: &str) -> LLMError {
583 let message = if body.contains("\"code\":\"model_not_supported\"")
584 && body.contains(models::huggingface::STEP_3_5_FLASH_BASE)
585 {
586 format!(
587 "HuggingFace API error ({}): Step 3.5 Flash requires the '{}' provider. \
588Enable that provider in your HuggingFace Inference Providers settings, or switch to another model.",
589 status,
590 models::huggingface::STEP_3_5_FLASH_PROVIDER
591 )
592 } else {
593 format!("HuggingFace API error ({}): {}", status, body)
594 };
595
596 LLMError::Provider {
597 message: format_llm_error(PROVIDER_NAME, &message),
598 metadata: Some(LLMErrorMetadata::new(
599 PROVIDER_NAME,
600 Some(status.as_u16()),
601 None,
602 None,
603 None,
604 None,
605 Some(body.to_string()),
606 )),
607 }
608 }
609
610 fn parse_responses_api_format(json: &Value, model: String) -> Result<LLMResponse, LLMError> {
611 let convenience_text = json.get("output_text").and_then(|t| t.as_str());
612
613 let json_obj = json.get("response").unwrap_or(json);
614
615 let output = json_obj.get("output").and_then(|v| v.as_array());
616
617 let output_arr = match output {
618 Some(arr) => arr,
619 None => {
620 if let Some(text) = convenience_text {
621 return Ok(LLMResponse {
622 content: Some(text.to_string()),
623 tool_calls: None,
624 model,
625 usage: None,
626 finish_reason: crate::llm::provider::FinishReason::Stop,
627 reasoning: None,
628 reasoning_details: None,
629 tool_references: Vec::new(),
630 request_id: None,
631 organization_id: None,
632 compaction: None,
633 });
634 }
635
636 return Err(LLMError::Provider {
637 message: format_llm_error(PROVIDER_NAME, "Not a Responses API format"),
638 metadata: None,
639 });
640 }
641 };
642
643 let mut content_fragments: Vec<String> = Vec::new();
644 let mut reasoning_fragments: Vec<String> = Vec::new();
645 let mut tool_calls: Vec<crate::llm::provider::ToolCall> = Vec::new();
646
647 for item in output_arr {
648 let item_type = item.get("type").and_then(|t| t.as_str()).unwrap_or("");
649
650 match item_type {
651 "message" => {
652 if let Some(content_arr) = item.get("content").and_then(|c| c.as_array()) {
653 for entry in content_arr {
654 let entry_type =
655 entry.get("type").and_then(|t| t.as_str()).unwrap_or("");
656 match entry_type {
657 "text" | "output_text" => {
658 if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
659 if !text.is_empty() {
660 content_fragments.push(text.to_string());
661 }
662 }
663 }
664 "reasoning" => {
665 if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
666 if !text.is_empty() {
667 reasoning_fragments.push(text.to_string());
668 }
669 }
670 }
671 "function_call" | "tool_call" => {
672 if let Some(call) = Self::parse_responses_tool_call(entry) {
673 tool_calls.push(call);
674 }
675 }
676 _ => {}
677 }
678 }
679 }
680 }
681 "function_call" | "tool_call" => {
682 if let Some(call) = Self::parse_responses_tool_call(item) {
683 tool_calls.push(call);
684 }
685 }
686 "reasoning" => {
687 if let Some(summary_arr) = item.get("summary").and_then(|s| s.as_array()) {
688 for summary in summary_arr {
689 if let Some(text) = summary.get("text").and_then(|t| t.as_str()) {
690 if !text.is_empty() {
691 reasoning_fragments.push(text.to_string());
692 }
693 }
694 }
695 } else if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
696 reasoning_fragments.push(text.to_string());
697 }
698 }
699 _ => {}
700 }
701 }
702
703 let content = if content_fragments.is_empty() {
704 convenience_text.map(|t| t.to_string())
705 } else {
706 Some(content_fragments.join(""))
707 };
708
709 let reasoning = if reasoning_fragments.is_empty() {
710 None
711 } else {
712 Some(reasoning_fragments.join("\n\n"))
713 };
714
715 let finish_reason = if !tool_calls.is_empty() {
716 crate::llm::provider::FinishReason::ToolCalls
717 } else {
718 crate::llm::provider::FinishReason::Stop
719 };
720
721 let usage_value = json.get("usage").or_else(|| json_obj.get("usage"));
722 let usage = usage_value.map(|usage_value| crate::llm::provider::Usage {
723 prompt_tokens: usage_value
724 .get("input_tokens")
725 .or_else(|| usage_value.get("prompt_tokens"))
726 .and_then(|pt| pt.as_u64())
727 .unwrap_or(0) as u32,
728 completion_tokens: usage_value
729 .get("output_tokens")
730 .or_else(|| usage_value.get("completion_tokens"))
731 .and_then(|ct| ct.as_u64())
732 .unwrap_or(0) as u32,
733 total_tokens: usage_value
734 .get("total_tokens")
735 .and_then(|tt| tt.as_u64())
736 .unwrap_or(0) as u32,
737 cached_prompt_tokens: None,
738 cache_creation_tokens: None,
739 cache_read_tokens: None,
740 });
741
742 Ok(LLMResponse {
743 content,
744 tool_calls: if tool_calls.is_empty() {
745 None
746 } else {
747 Some(tool_calls)
748 },
749 model,
750 usage,
751 finish_reason,
752 reasoning,
753 reasoning_details: None,
754 tool_references: Vec::new(),
755 request_id: None,
756 organization_id: None,
757 compaction: None,
758 })
759 }
760
761 fn parse_responses_tool_call(item: &Value) -> Option<crate::llm::provider::ToolCall> {
762 let call_id = item.get("id").and_then(|v| v.as_str()).unwrap_or("");
763 let function_obj = item.get("function").and_then(|v| v.as_object());
764 let name = function_obj.and_then(|f| f.get("name").and_then(|n| n.as_str()))?;
765 let arguments = function_obj.and_then(|f| f.get("arguments"));
766
767 let serialized = arguments.map_or("{}".to_owned(), |args| {
768 if args.is_string() {
769 args.as_str().unwrap_or("{}").to_string()
770 } else {
771 args.to_string()
772 }
773 });
774
775 Some(crate::llm::provider::ToolCall::function(
776 call_id.to_string(),
777 name.to_string(),
778 serialized,
779 ))
780 }
781
782 async fn parse_response(
783 &self,
784 response: Response,
785 model: String,
786 use_responses_api: bool,
787 ) -> Result<LLMResponse, LLMError> {
788 let status = response.status();
789
790 if !status.is_success() {
791 let body = response.text().await.unwrap_or_default();
792 return Err(self.format_error(status, &body));
793 }
794
795 let json: Value = response
796 .json()
797 .await
798 .map_err(|err| format_parse_error(PROVIDER_NAME, &err))?;
799
800 if use_responses_api {
801 if json.get("output").is_some() {
802 return Self::parse_responses_api_format(&json, model);
803 }
804 }
805
806 parse_response_openai_format::<fn(&Value, &Value) -> Option<String>>(
807 json,
808 PROVIDER_NAME,
809 model,
810 false,
811 None,
812 )
813 }
814
815 pub fn available_models() -> Vec<String> {
816 models::huggingface::SUPPORTED_MODELS
817 .iter()
818 .map(|s| s.to_string())
819 .collect()
820 }
821
822 fn get_endpoint(&self, use_responses_api: bool) -> String {
823 let base = self.base_url.trim_end_matches('/');
824 if use_responses_api {
825 format!("{}/responses", base)
826 } else {
827 format!("{}/chat/completions", base)
828 }
829 }
830}
831
832#[async_trait]
833impl LLMProvider for HuggingFaceProvider {
834 fn name(&self) -> &str {
835 PROVIDER_KEY
836 }
837
838 fn supports_streaming(&self) -> bool {
839 true
840 }
841
842 fn supports_reasoning(&self, model: &str) -> bool {
843 models::huggingface::REASONING_MODELS.contains(&model)
846 || self
847 .model_behavior
848 .as_ref()
849 .and_then(|b| b.model_supports_reasoning)
850 .unwrap_or(false)
851 }
852
853 fn supports_reasoning_effort(&self, model: &str) -> bool {
854 self.is_glm_model(model)
856 || self.is_deepseek_model(model)
857 || self
858 .model_behavior
859 .as_ref()
860 .and_then(|b| b.model_supports_reasoning_effort)
861 .unwrap_or(false)
862 }
863
864 fn supports_tools(&self, _model: &str) -> bool {
865 true
866 }
867
868 fn supports_parallel_tool_config(&self, _model: &str) -> bool {
869 false
870 }
871
872 fn supports_structured_output(&self, _model: &str) -> bool {
873 true
874 }
875
876 fn supports_context_caching(&self, _model: &str) -> bool {
877 false
878 }
879
880 fn effective_context_size(&self, _model: &str) -> usize {
881 128_000
882 }
883
884 async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
885 let model = ensure_model(&mut request, &self.model);
886
887 self.apply_model_defaults(&mut request);
888 self.validate_request(&request)?;
889
890 let model_id = self.normalize_model_id(&request.model)?;
891 request.model = model_id;
892
893 let use_responses_api = self.should_use_responses_api(&request);
894 let payload = if use_responses_api {
895 self.format_for_responses_api(&request)?
896 } else {
897 self.format_for_chat_completions(&request)?
898 };
899
900 let endpoint = self.get_endpoint(use_responses_api);
901
902 let response = self
903 .http_client
904 .post(&endpoint)
905 .header("Authorization", format!("Bearer {}", self.api_key))
906 .json(&payload)
907 .send()
908 .await
909 .map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
910
911 self.parse_response(response, model, use_responses_api)
912 .await
913 }
914
915 async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
916 let model = ensure_model(&mut request, &self.model);
917
918 self.apply_model_defaults(&mut request);
919 self.validate_request(&request)?;
920 request.stream = true;
921
922 let model_id = self.normalize_model_id(&request.model)?;
923 request.model = model_id;
924
925 let use_responses_api = self.should_use_responses_api(&request);
926 let payload = if use_responses_api {
927 self.format_for_responses_api(&request)?
928 } else {
929 self.format_for_chat_completions(&request)?
930 };
931
932 let endpoint = self.get_endpoint(use_responses_api);
933
934 let response = self
935 .http_client
936 .post(&endpoint)
937 .header("Authorization", format!("Bearer {}", self.api_key))
938 .json(&payload)
939 .send()
940 .await
941 .map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
942
943 if !response.status().is_success() {
944 let status = response.status();
945 let body = response.text().await.unwrap_or_default();
946 return Err(self.format_error(status, &body));
947 }
948
949 self.create_stream(response, model, use_responses_api).await
950 }
951
952 fn supported_models(&self) -> Vec<String> {
953 Self::available_models()
954 }
955
956 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
957 if request.messages.is_empty() {
958 return Err(LLMError::InvalidRequest {
959 message: format_llm_error(PROVIDER_NAME, "Messages cannot be empty"),
960 metadata: None,
961 });
962 }
963
964 if request.model.trim().is_empty() {
965 return Err(LLMError::InvalidRequest {
966 message: format_llm_error(PROVIDER_NAME, "Model identifier cannot be empty"),
967 metadata: None,
968 });
969 }
970
971 Ok(())
972 }
973}
974
975impl HuggingFaceProvider {
976 async fn create_stream(
977 &self,
978 response: Response,
979 model: String,
980 use_responses_api: bool,
981 ) -> Result<LLMStream, LLMError> {
982 let mut bytes_stream = response.bytes_stream();
983 let mut buffer = String::with_capacity(4096);
984 let mut aggregator = crate::llm::providers::shared::StreamAggregator::new(model.clone());
985 let telemetry = NoopStreamTelemetry;
986
987 let stream = try_stream! {
988 'outer: while let Some(chunk_result) = bytes_stream.next().await {
989 let chunk = chunk_result.map_err(|err| format_network_error(PROVIDER_NAME, &err))?;
990 let text = String::from_utf8_lossy(&chunk);
991 buffer.push_str(&text);
992
993 if buffer.len() > 128_000 {
994 Err(LLMError::Provider {
995 message: format_llm_error(PROVIDER_NAME, "Stream buffer exceeded maximum size (128KB)"),
996 metadata: None,
997 })?;
998 }
999
1000 while let Some(newline_pos) = buffer.find('\n') {
1001 let line = buffer[..newline_pos].trim().to_string();
1002 buffer.drain(..=newline_pos);
1003
1004 if line.is_empty() || line.starts_with(':') {
1005 continue;
1006 }
1007
1008 let data = if let Some(stripped) = line.strip_prefix("data: ") {
1009 stripped
1010 } else {
1011 continue;
1012 };
1013
1014 if data == "[DONE]" {
1015 break 'outer;
1016 }
1017
1018 let event: Value = match serde_json::from_str(data) {
1019 Ok(v) => v,
1020 Err(_) => continue,
1021 };
1022
1023 if use_responses_api {
1024 let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
1025
1026 match event_type {
1027 "response.output_text.delta" | "output_text.delta" => {
1028 if let Some(delta) = event.get("delta").and_then(|d| d.as_str()) {
1029 telemetry.on_content_delta(delta);
1030 for ev in aggregator.handle_content(delta) {
1031 yield ev;
1032 }
1033 }
1034 continue;
1035 }
1036 "response.reasoning.delta" | "reasoning.delta" => {
1037 if let Some(delta) = event.get("delta").and_then(|d| d.as_str()) {
1038 if let Some(d) = aggregator.handle_reasoning(delta) {
1039 telemetry.on_reasoning_delta(&d);
1040 yield LLMStreamEvent::Reasoning { delta: d };
1041 }
1042 }
1043 continue;
1044 }
1045 "response.function_call_arguments.delta" | "tool_call.delta" => {
1046 telemetry.on_tool_call_delta();
1047 continue;
1048 }
1049 "response.completed" => {
1050 if let Some(response_obj) = event.get("response") {
1051 if let Ok(response) = Self::parse_responses_api_format(response_obj, model.clone()) {
1052 let final_agg_response = aggregator.finalize();
1053 let mut merged_response = response;
1054 if merged_response.content.is_none() {
1055 merged_response.content = final_agg_response.content;
1056 }
1057 if merged_response.reasoning.is_none() {
1058 merged_response.reasoning = final_agg_response.reasoning;
1059 }
1060 if merged_response.tool_calls.is_none() {
1061 merged_response.tool_calls = final_agg_response.tool_calls;
1062 }
1063 if merged_response.usage.is_none() {
1064 merged_response.usage = final_agg_response.usage;
1065 }
1066 yield LLMStreamEvent::Completed { response: Box::new(merged_response) };
1067 return;
1068 }
1069 }
1070 break 'outer;
1071 }
1072 "response.done" => {
1073 break 'outer;
1074 }
1075 _ => {}
1076 }
1077 }
1078
1079 if let Some(choices_arr) = event.get("choices").and_then(|c| c.as_array()) {
1080 if let Some(choice) = choices_arr.first() {
1081 if let Some(delta_obj) = choice.get("delta") {
1082 if let Some(content) = delta_obj.get("content").and_then(|c| c.as_str()) {
1083 telemetry.on_content_delta(content);
1084 for ev in aggregator.handle_content(content) {
1085 yield ev;
1086 }
1087 }
1088
1089 if let Some(reason) = delta_obj.get("reasoning_content").and_then(|r| r.as_str()) {
1090 if let Some(d) = aggregator.handle_reasoning(reason) {
1091 telemetry.on_reasoning_delta(&d);
1092 yield LLMStreamEvent::Reasoning { delta: d };
1093 }
1094 }
1095
1096 if let Some(reasoning_details) = delta_obj
1097 .get("reasoning_details")
1098 .and_then(|details| details.as_array())
1099 {
1100 aggregator.set_reasoning_details(reasoning_details);
1101 }
1102
1103 if let Some(tool_calls_arr) = delta_obj.get("tool_calls").and_then(|tc| tc.as_array()) {
1104 aggregator.handle_tool_calls(tool_calls_arr);
1105 telemetry.on_tool_call_delta();
1106 }
1107 }
1108
1109 if let Some(finish_reason_str) = choice.get("finish_reason").and_then(|fr| fr.as_str()) {
1110 aggregator.set_finish_reason(map_finish_reason_common(finish_reason_str));
1111 if let Some(usage_value) = event.get("usage") {
1112 aggregator.set_usage(crate::llm::provider::Usage {
1113 prompt_tokens: usage_value.get("prompt_tokens").and_then(|pt| pt.as_u64()).unwrap_or(0) as u32,
1114 completion_tokens: usage_value.get("completion_tokens").and_then(|ct| ct.as_u64()).unwrap_or(0) as u32,
1115 total_tokens: usage_value.get("total_tokens").and_then(|tt| tt.as_u64()).unwrap_or(0) as u32,
1116 cached_prompt_tokens: None,
1117 cache_creation_tokens: None,
1118 cache_read_tokens: None,
1119 });
1120 }
1121
1122 break 'outer;
1123 }
1124 }
1125 }
1126 }
1127 }
1128
1129 yield LLMStreamEvent::Completed { response: Box::new(aggregator.finalize()) };
1130 };
1131
1132 Ok(Box::pin(stream))
1133 }
1134}
1135
// Expands the shared `impl_llm_client!` macro (imported from `super::common`) for this provider.
// NOTE(review): presumably generates the LLM client trait impl on top of `LLMProvider` — confirm in the macro definition.
impl_llm_client!(HuggingFaceProvider);
1137
// Unit tests for model detection, reasoning-detail normalization, message
// serialization, model-id normalization and chat-completions payload building.
#[cfg(test)]
mod tests {
    use super::HuggingFaceProvider;
    use crate::llm::provider::{LLMRequest, Message, ToolDefinition};
    use crate::llm::providers::common::{is_minimax_m2_model, normalize_reasoning_detail_object};
    use serde_json::json;
    use std::sync::Arc;

    // MiniMax-M2 detection accepts a provider-qualified id and a bare
    // lowercase id, and rejects an unrelated model.
    #[test]
    fn minimax_model_detection_handles_variants() {
        assert!(is_minimax_m2_model("MiniMaxAI/MiniMax-M2.5:novita"));
        assert!(is_minimax_m2_model("minimax-m2.5"));
        assert!(!is_minimax_m2_model("deepseek-r1"));
    }

    // A reasoning detail supplied as a JSON-encoded string is decoded into an object.
    #[test]
    fn normalize_reasoning_detail_decodes_stringified_object() {
        let parsed = normalize_reasoning_detail_object(&json!(
            "{\"type\":\"reasoning.text\",\"text\":\"step\"}"
        ))
        .expect("expected a parsed reasoning detail object");
        assert!(parsed.is_object());
        assert_eq!(parsed["type"], "reasoning.text");
    }

    // For MiniMax models, assistant reasoning details are serialized as an
    // array of objects even when they were stored as strings.
    #[test]
    fn serialize_messages_normalizes_minimax_reasoning_details() {
        let provider = HuggingFaceProvider::with_model(
            "test-key".to_string(),
            "MiniMaxAI/MiniMax-M2.5:novita".to_string(),
        );
        let request = LLMRequest {
            model: "MiniMaxAI/MiniMax-M2.5:novita".to_string(),
            messages: vec![
                Message::assistant("answer".to_string()).with_reasoning_details(Some(vec![json!(
                    "{\"type\":\"reasoning.text\",\"text\":\"chain\"}"
                )])),
            ],
            ..Default::default()
        };

        let messages = provider
            .serialize_messages_huggingface_chat(&request)
            .expect("message serialization should succeed");
        assert!(messages[0]["reasoning_details"].is_array());
        assert!(messages[0]["reasoning_details"][0].is_object());
    }

    // For GLM models, assistant reasoning is folded back into the content
    // inside `<think>` tags.
    #[test]
    fn serialize_messages_rehydrates_glm_interleaved_history_into_content() {
        let provider =
            HuggingFaceProvider::with_model("test-key".to_string(), "zai-org/GLM-5:novita".into());
        let request = LLMRequest {
            model: "zai-org/GLM-5:novita".to_string(),
            messages: vec![
                Message::assistant("done".to_string()).with_reasoning(Some("trace".to_string())),
            ],
            ..Default::default()
        };

        let messages = provider
            .serialize_messages_huggingface_chat(&request)
            .expect("message serialization should succeed");

        assert_eq!(messages[0]["content"], json!("<think>trace</think>done"));
    }

    // Step 3.5 Flash ids get the `featherless-ai` provider suffix, both when
    // no suffix is given and when the legacy `:fastest` suffix is used.
    #[test]
    fn normalize_step35_flash_provider_suffix() {
        let provider = HuggingFaceProvider::with_model(
            "test-key".to_string(),
            "stepfun-ai/Step-3.5-Flash".to_string(),
        );

        let normalized = provider
            .normalize_model_id("stepfun-ai/Step-3.5-Flash")
            .expect("normalization should succeed");
        assert_eq!(
            normalized,
            "stepfun-ai/Step-3.5-Flash:featherless-ai".to_string()
        );

        let normalized_legacy = provider
            .normalize_model_id("stepfun-ai/Step-3.5-Flash:fastest")
            .expect("legacy suffix normalization should succeed");
        assert_eq!(
            normalized_legacy,
            "stepfun-ai/Step-3.5-Flash:featherless-ai".to_string()
        );
    }

    // The apply_patch tool is sent as a regular `function` tool in the
    // chat-completions payload.
    #[test]
    fn format_for_chat_completions_keeps_apply_patch_as_function_tool() {
        let provider = HuggingFaceProvider::with_model(
            "test-key".to_string(),
            "Qwen/Qwen3-Coder-480B-A35B-Instruct".to_string(),
        );
        let request = LLMRequest {
            model: "Qwen/Qwen3-Coder-480B-A35B-Instruct".to_string(),
            messages: vec![Message::user("apply a patch".to_string())],
            tools: Some(Arc::new(vec![ToolDefinition::apply_patch(
                "Apply patches".to_string(),
            )])),
            ..Default::default()
        };

        let payload = provider
            .format_for_chat_completions(&request)
            .expect("payload should serialize");

        assert_eq!(payload["tools"][0]["type"], "function");
        assert_eq!(payload["tools"][0]["function"]["name"], "apply_patch");
    }
}