toast_api/
server.rs

1use axum::{
2    extract::State,
3    http::StatusCode,
4    response::IntoResponse,
5    routing::{get, post},
6    Json, Router,
7};
8use clap::Parser;
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15use uuid::Uuid;
16
17use crate::api::{Claude, Session};
18
19/// Command line arguments for server configuration
20#[derive(Parser, Debug, Clone)]
21#[clap(author, version, about = "OpenAI-compatible Claude API server")]
22pub struct Args {
23    /// Port to listen on - public field
24    #[clap(long, short, default_value = "3000")]
25    pub port: u16,
26
27    /// Enable debug logging
28    #[clap(long)]
29    pub debug: bool,
30
31    /// Default Claude model to use
32    #[clap(long, default_value = "claude-3-7-sonnet-latest")]
33    pub model: String,
34
35    /// Host address to bind to
36    #[clap(long, default_value = "0.0.0.0")]
37    pub host: String,
38}
39
40/// Server state shared across all requests
41#[derive(Clone)]
42pub struct AppState {
43    pub claude: Arc<Claude>,
44    // Map to track conversations: Vec<Message> -> (conversation_id, Vec<Message>)
45    pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
46}
47
48/// Base response structure for OpenAI-compatible API
49#[derive(Serialize, Debug)]
50pub struct OpenAIResponse {
51    id: String,
52    object: &'static str,
53    created: u64,
54    model: String,
55    system_fingerprint: String,
56    choices: Vec<Choice>,
57    usage: Usage,
58}
59
60/// Token usage information
61#[derive(Serialize, Debug)]
62pub struct Usage {
63    prompt_tokens: u32,
64    completion_tokens: u32,
65    total_tokens: u32,
66}
67
68/// Chat message structure (OpenAI compatible)
69#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
70pub struct Message {
71    role: String,
72    content: ContentValue,
73}
74
75/// Content value can be string, array of strings, or structured content
76#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
77#[serde(untagged)]
78pub enum ContentValue {
79    String(String),
80    Array(Vec<String>),
81    Object(ContentObject),
82    ContentBlocks(Vec<ContentBlock>),
83}
84
85/// Content object with type and other fields
86#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
87pub struct ContentObject {
88    #[serde(default)]
89    pub text: Option<String>,
90    #[serde(default)]
91    pub r#type: Option<String>,
92    // Allow for other fields
93    #[serde(flatten)]
94    pub other: std::collections::HashMap<String, serde_json::Value>,
95}
96
97/// Structured content block
98#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
99pub struct ContentBlock {
100    pub r#type: String,
101    #[serde(default)]
102    pub text: Option<String>,
103    // Allow for other fields
104    #[serde(flatten)]
105    pub other: std::collections::HashMap<String, serde_json::Value>,
106}
107
108impl Message {
109    // Helper to get content as a string regardless of format
110    pub fn content_as_string(&self) -> String {
111        match &self.content {
112            ContentValue::String(content) => content.clone(),
113            ContentValue::Array(content) => content.join(" "),
114            ContentValue::Object(obj) => {
115                if let Some(text) = &obj.text {
116                    text.clone()
117                } else {
118                    // Try to convert to JSON string if no text field
119                    serde_json::to_string(obj).unwrap_or_default()
120                }
121            }
122            ContentValue::ContentBlocks(blocks) => blocks
123                .iter()
124                .filter_map(|block| block.text.as_ref())
125                .cloned()
126                .collect::<Vec<_>>()
127                .join(" "),
128        }
129    }
130}
131
132#[derive(Debug, Deserialize)]
133pub struct MessageInner {
134    role: String,
135    content: String,
136}
137
138impl From<MessageInner> for Message {
139    fn from(inner: MessageInner) -> Self {
140        Self {
141            role: inner.role,
142            content: ContentValue::String(inner.content),
143        }
144    }
145}
146
147/// Deserializer for message content that supports both strings and arrays
148/// Custom serialization/deserialization for message content
149/// Supports either a single string or an array of strings
150mod content_string_or_array {
151    use serde::{self, Deserialize, Deserializer, Serializer};
152
153    pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
154    where
155        S: Serializer,
156    {
157        serializer.serialize_str(content)
158    }
159
160    pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
161    where
162        D: Deserializer<'de>,
163    {
164        #[derive(Deserialize)]
165        #[serde(untagged)]
166        enum StringOrArray {
167            String(()),
168            Array(()),
169        }
170
171        let value = StringOrArray::deserialize(deserializer)?;
172        match value {
173            StringOrArray::String(_) => Ok(String::new()),
174            StringOrArray::Array(_) => Ok(String::new()),
175        }
176    }
177}
178
179/// Chat completion request (OpenAI compatible)
180#[derive(Deserialize, Serialize, Debug)]
181pub struct ChatCompletionRequest {
182    model: String,
183    messages: Vec<Message>,
184    #[serde(default)]
185    user: Option<String>,  // Optional user identifier for thread association
186}
187
188/// Chat completion response choices (OpenAI compatible)
189#[derive(Serialize, Debug)]
190pub struct Choice {
191    index: i32,
192    message: Message,
193    logprobs: Option<serde_json::Value>,
194    finish_reason: String,
195}
196
197/// Helper function to get model from request or default to Sonnet
198fn get_model(model_name: &str) -> &'static str {
199    match model_name {
200        "gpt-3.5-turbo" => crate::config::HAIKU_MODEL,
201        "gpt-4" | "gpt-4-turbo" => crate::config::SONNET_MODEL,
202        "gpt-4o" | "gpt-4.1" => crate::config::OPUS_MODEL,
203        _ => crate::config::SONNET_MODEL,
204    }
205}
206
/// Translates an OpenAI role name into the role label used when building a prompt.
///
/// `assistant` stays `assistant`; `system` and `developer` both become
/// `system`; `user` and any unrecognised role become `human`.
fn _normalize_role(role: &str) -> String {
    let label = if role == "assistant" {
        "assistant"
    } else if matches!(role, "system" | "developer") {
        "system"
    } else {
        // "user" and every unknown role are treated as the human speaker.
        "human"
    };
    label.to_owned()
}
216
217/// Convert messages to a single Claude prompt
218fn _messages_to_prompt(messages: &[Message]) -> String {
219    let mut prompt = String::new();
220
221    for msg in messages {
222        let role = _normalize_role(&msg.role);
223        let content = msg.content_as_string();
224        match role.as_str() {
225            "system" => prompt.push_str(&format!("System: {content}\n\n")),
226            "human" => prompt.push_str(&format!("Human: {content}\n\n")),
227            "assistant" => prompt.push_str(&format!("Assistant: {content}\n\n")),
228            _ => prompt.push_str(&format!("{role}: {content}\n\n")),
229        }
230    }
231
232    prompt
233}
234
235/// Generate a consistent conversation key based on message context
236fn generate_conversation_key(messages: &[Message]) -> String {
237    use std::collections::hash_map::DefaultHasher;
238    use std::hash::{Hash, Hasher};
239    
240    // Create a consistent hash based on the conversation context (all messages except the last user message)
241    let context_messages = if messages.len() > 1 {
242        &messages[..messages.len() - 1]
243    } else {
244        &[]
245    };
246    
247    let mut hasher = DefaultHasher::new();
248    for msg in context_messages {
249        msg.role.hash(&mut hasher);
250        msg.content_as_string().hash(&mut hasher);
251    }
252    
253    format!("ctx:{:016x}", hasher.finish())
254}
255
256/// Find a matching conversation by context hash
257/// Returns the key of the matching conversation if found
258fn find_matching_conversation(conversations: &DashMap<String, (String, Vec<Message>)>, messages: &[Message]) -> Option<String> {
259    if messages.is_empty() {
260        return None;
261    }
262
263    // Generate the context key for the current request
264    let context_key = generate_conversation_key(messages);
265    
266    // Check if we have an exact context match
267    if conversations.contains_key(&context_key) {
268        return Some(context_key);
269    }
270    
271    // If no exact match, try to find a conversation that could be extended
272    // Look for conversations where our context is a prefix of their context
273    let our_context = if messages.len() > 1 {
274        &messages[..messages.len() - 1]
275    } else {
276        &[]
277    };
278    
279    for item in conversations.iter() {
280        let (_, stored_history) = item.value();
281        
282        // Check if our context is a prefix of the stored history
283        if stored_history.len() >= our_context.len() {
284            let prefix_matches = our_context.iter().zip(stored_history.iter())
285                .all(|(a, b)| a.role == b.role && a.content_as_string() == b.content_as_string());
286            
287            if prefix_matches {
288                return Some(item.key().clone());
289            }
290        }
291    }
292    
293    None
294}
295
296/// Chat completion API handler
297async fn chat_completion_handler(
298    State(state): State<AppState>,
299    Json(request): Json<ChatCompletionRequest>,
300) -> impl IntoResponse {
301    // Debug the raw client request
302    // dbg!(&request);
303
304    // Log when array content is used
305    // Log message content types
306    for (i, msg) in request.messages.iter().enumerate() {
307        match &msg.content {
308            ContentValue::Array(_) => info!("Message at index {} has array content", i),
309            ContentValue::Object(_) => info!("Message at index {} has object content", i),
310            ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
311            _ => {} // String content is the default, no need to log
312        }
313    }
314
315    let model_id = get_model(&request.model);
316
317    if request.messages.is_empty() {
318        return (
319            StatusCode::BAD_REQUEST,
320            Json(serde_json::json!({
321                "error": "No messages provided in the request"
322            })),
323        );
324    }
325
326    let messages = request.messages;
327
328    // Generate a request identifier from the user field or a new UUID if not provided
329    // We don't directly use this anymore, but keep it for future use
330    let _request_id = request.user.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
331    
332    // The actual user query is the last message
333    let user_query = messages.last().unwrap().clone();
334
335    let claude = state.claude;
336    let conversations = state.conversations;
337
338    // Try to find existing conversation by matching message history
339    let mut conversation_id: Option<String> = None;
340    let mut message_history: Vec<Message> = Vec::new();
341    
342    info!("Looking for matching conversation by message history");
343    info!("Available conversations: {}", conversations.len());
344
345    // Find matching conversation based on message history
346    let matching_key = find_matching_conversation(&conversations, &messages);
347
348    if let Some(ref key) = matching_key {
349        if let Some(existing) = conversations.get(key) {
350            conversation_id = Some(existing.0.clone());
351            message_history = existing.1.clone();
352            info!("Found conversation with matching history, key: '{}'", key);
353        }
354    } else {
355        info!("No existing conversation with matching history found");
356    }
357
358    // If no matching conversation, create a new one
359    if conversation_id.is_none() {
360        match claude.create_chat().await {
361            Ok(new_id) => {
362                conversation_id = Some(new_id.clone());
363                message_history = messages[..messages.len() - 1].to_vec();
364
365                // Send system prompt if present and include it in history
366                if let Some(system_msg) = messages
367                    .iter()
368                    .find(|m| m.role == "system" || m.role == "developer")
369                {
370                    if let Err(e) = claude
371                        .send_message(&new_id, &system_msg.content_as_string(), &[])
372                        .await
373                    {
374                        warn!("Failed to send system prompt: {}", e);
375                    }
376                }
377
378                // Generate context-based key for new conversation
379                let context_key = generate_conversation_key(&messages);
380                conversations.insert(
381                    context_key,
382                    (new_id.clone(), message_history.clone()),
383                );
384            }
385            Err(e) => {
386                let error_response = serde_json::json!({
387                    "error": format!("Failed to create conversation: {}", e)
388                });
389                dbg!(&error_response);
390                return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
391            }
392        }
393    }
394
395    // Send the user message and get response
396    let completion = match claude
397        .send_message(
398            &conversation_id.clone().unwrap(),
399            &user_query.content_as_string(),
400            &[],
401        )
402        .await
403    {
404        Ok(response) => response,
405        Err(e) => {
406            let error_response = serde_json::json!({
407                "error": format!("Failed to get completion: {}", e)
408            });
409            dbg!(&error_response);
410            return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
411        }
412    };
413
414    // Update the message history with the new assistant response
415    let new_message = Message {
416        role: "assistant".to_string(),
417        content: ContentValue::String(completion.clone()),
418    };
419
420    // Update message history and store it
421    let mut updated_history = message_history.clone();
422    updated_history.push(user_query.clone());
423    updated_history.push(new_message);
424
425    let conv_id = conversation_id.unwrap();
426    
427    // Generate context-based key for the updated conversation
428    let mut extended_messages = messages.clone();
429    extended_messages.push(Message {
430        role: "assistant".to_string(),
431        content: ContentValue::String(completion.clone()),
432    });
433    let updated_key = generate_conversation_key(&extended_messages);
434    
435    // Remove old entry if we had a matching key and it's different from the new key
436    if let Some(ref old_key) = matching_key {
437        if old_key != &updated_key
438            && conversations.remove(old_key).is_some() {
439                info!("Removed old conversation entry with key: '{}'", old_key);
440            }
441    }
442
443    // Store with the updated context key
444    conversations.insert(
445        updated_key.clone(),
446        (conv_id.clone(), updated_history.clone()),
447    );
448    
449    // Log the conversation storage
450    info!(
451        "Stored conversation with ID: {} under context key '{}'",
452        conv_id,
453        updated_key
454    );
455
456    // Create OpenAI-compatible response
457    let response = OpenAIResponse {
458        id: Uuid::new_v4().to_string(),
459        object: "chat.completion",
460        created: chrono::Utc::now().timestamp() as u64,
461        model: model_id.to_string(),
462        system_fingerprint: "fp_toast".to_string(),
463        choices: vec![Choice {
464            index: 0,
465            message: Message {
466                role: "assistant".to_string(),
467                content: ContentValue::String(completion.clone()),
468            },
469            logprobs: None,
470            finish_reason: "stop".to_string(),
471        }],
472        usage: Usage {
473            prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, // Rough estimate
474            completion_tokens: completion.chars().count() as u32 / 4, // Rough estimate
475            total_tokens: (user_query.content_as_string().chars().count()
476                + completion.chars().count()) as u32
477                / 4,
478        },
479    };
480
481    // Debug the response being sent to the client
482    // dbg!(&response);
483
484    (
485        StatusCode::OK,
486        Json(serde_json::to_value(response).unwrap()),
487    )
488}
489
490/// Health check endpoint
491async fn health_check() -> impl IntoResponse {
492    (StatusCode::OK, "OK")
493}
494
495/// Start the Axum server
496pub async fn run() -> anyhow::Result<()> {
497    // For backward compatibility, parse arguments directly when run() is called
498    run_with_args(Args::parse()).await
499}
500
501/// Start the Axum server with provided arguments
502pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
503    // Use provided arguments instead of parsing them
504
505    // Initialize tracing
506    if args.debug {
507        // Full logging in debug mode
508        tracing_subscriber::fmt::init();
509    } else {
510        // Only show warnings and errors in normal mode
511        tracing_subscriber::fmt()
512            .with_max_level(tracing::Level::WARN)
513            .init();
514    }
515
516    // Load config
517    let config_dir = dirs::config_dir()
518        .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
519        .join("toast");
520
521    let cookie_path = config_dir.join("cookie");
522    let org_id_path = config_dir.join("org_id");
523
524    // Check if config directory exists
525    if !config_dir.exists() {
526        return Err(anyhow::anyhow!(
527            "Configuration directory does not exist at {:?}",
528            config_dir
529        ));
530    }
531
532    // Check and load cookie
533    let cookie = if cookie_path.exists() {
534        std::fs::read_to_string(&cookie_path)?.trim().to_string()
535    } else {
536        return Err(anyhow::anyhow!(
537            "Cookie file not found at {:?}",
538            cookie_path
539        ));
540    };
541
542    // Check and load org_id
543    let org_id = if org_id_path.exists() {
544        std::fs::read_to_string(&org_id_path)?.trim().to_string()
545    } else {
546        // Try to extract org_id from cookie
547        if let Some(extracted_org_id) = crate::utils::extract_org_id_from_cookie(&cookie) {
548            // Save the extracted org_id to the file for future use
549            std::fs::write(&org_id_path, &extracted_org_id)?;
550            info!(
551                "Extracted organization ID from cookie and saved to {:?}",
552                org_id_path
553            );
554            extracted_org_id
555        } else {
556            return Err(anyhow::anyhow!(
557                "Organization ID file not found at {:?} and couldn't extract it from cookie",
558                org_id_path
559            ));
560        }
561    };
562
563    let user_agent =
564        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
565            .to_string();
566
567    let session = Session {
568        cookie,
569        user_agent,
570        organization_id: org_id,
571    };
572
573    // Use the configurable model from arguments or default to Sonnet
574    let model_str = crate::config::SONNET_MODEL;
575
576    let claude = Arc::new(Claude::new(session.clone(), model_str)?);
577
578    // Create app state
579    let app_state = AppState {
580        claude,
581        conversations: Arc::new(DashMap::new()),
582    };
583
584    // Set up CORS
585    let cors = CorsLayer::new()
586        .allow_origin(Any)
587        .allow_methods(Any)
588        .allow_headers(Any);
589
590    // Create router
591    let app = Router::new()
592        .route("/health", get(health_check))
593        .route("/v1/chat/completions", post(chat_completion_handler))
594        .route("/", post(chat_completion_handler))
595        .route("/chat/completions", post(chat_completion_handler))
596        .layer(cors)
597        .with_state(app_state);
598
599    // Parse host address
600    let host_parts: Vec<u8> = args
601        .host
602        .split('.')
603        .map(|s| s.parse::<u8>().unwrap_or(0))
604        .collect();
605
606    let host_addr = if host_parts.len() == 4 {
607        [host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
608    } else {
609        [0, 0, 0, 0]
610    };
611
612    // Run server
613    let addr = SocketAddr::from((host_addr, args.port));
614    info!("Server listening on {}", addr);
615
616    // Print usage example
617    info!("Example request:");
618    info!(
619        "curl -X POST http://localhost:{}{} \\",
620        args.port, "/v1/chat/completions  # Also available at / and /chat/completions"
621    );
622    info!("  -H \"Content-Type: application/json\" \\");
623    info!("  -d '{{");
624    info!("    \"model\": \"gpt-4.1\",");
625    info!("    \"messages\": [");
626    info!("      {{");
627    info!("        \"role\": \"developer\",");
628    info!("        \"content\": \"You are a helpful assistant.\"");
629    info!("      }},");
630    info!("      {{");
631    info!("        \"role\": \"user\",");
632    info!("        \"content\": \"Hello! (Simple string content)\"");
633    info!("      }},");
634    info!("      {{");
635    info!("        \"role\": \"user\",");
636    info!("        \"content\": [\"Array of\", \"string content\"]");
637    info!("      }},");
638    info!("      {{");
639    info!("        \"role\": \"user\",");
640    info!("        \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
641    info!("      }},");
642    info!("      {{");
643    info!("        \"role\": \"user\",");
644    info!("        \"content\": [");
645    info!("          {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
646    info!("          {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
647    info!("        ]");
648    info!("      }}");
649    info!("    ]");
650    info!("  }}'");
651
652    let listener = tokio::net::TcpListener::bind(addr).await?;
653    axum::serve(listener, app).await?;
654
655    Ok(())
656}