toast_api/
server.rs

1use axum::{
2    extract::State,
3    http::StatusCode,
4    response::IntoResponse,
5    routing::{get, post},
6    Json, Router,
7};
8use clap::Parser;
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15use uuid::Uuid;
16
17use crate::api::{Claude, Session};
18
19/// Command line arguments for server configuration
20#[derive(Parser, Debug, Clone)]
21#[clap(author, version, about = "OpenAI-compatible Claude API server")]
22pub struct Args {
23    /// Port to listen on - public field
24    #[clap(long, short, default_value = "3000")]
25    pub port: u16,
26
27    /// Enable debug logging
28    #[clap(long)]
29    pub debug: bool,
30
31    /// Default Claude model to use
32    #[clap(long, default_value = "claude-3-7-sonnet-latest")]
33    pub model: String,
34
35    /// Host address to bind to
36    #[clap(long, default_value = "0.0.0.0")]
37    pub host: String,
38}
39
40/// Server state shared across all requests
41#[derive(Clone)]
42pub struct AppState {
43    pub claude: Arc<Claude>,
44    // Map to track conversations: Vec<Message> -> (conversation_id, Vec<Message>)
45    pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
46}
47
48/// Base response structure for OpenAI-compatible API
49#[derive(Serialize, Debug)]
50pub struct OpenAIResponse {
51    id: String,
52    object: &'static str,
53    created: u64,
54    model: String,
55    system_fingerprint: String,
56    choices: Vec<Choice>,
57    usage: Usage,
58}
59
60/// Token usage information
61#[derive(Serialize, Debug)]
62pub struct Usage {
63    prompt_tokens: u32,
64    completion_tokens: u32,
65    total_tokens: u32,
66}
67
68/// Chat message structure (OpenAI compatible)
69#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
70pub struct Message {
71    role: String,
72    content: ContentValue,
73}
74
75/// Content value can be string, array of strings, or structured content
76#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
77#[serde(untagged)]
78pub enum ContentValue {
79    String(String),
80    Array(Vec<String>),
81    Object(ContentObject),
82    ContentBlocks(Vec<ContentBlock>),
83}
84
85/// Content object with type and other fields
86#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
87pub struct ContentObject {
88    #[serde(default)]
89    pub text: Option<String>,
90    #[serde(default)]
91    pub r#type: Option<String>,
92    // Allow for other fields
93    #[serde(flatten)]
94    pub other: std::collections::HashMap<String, serde_json::Value>,
95}
96
97/// Structured content block
98#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
99pub struct ContentBlock {
100    pub r#type: String,
101    #[serde(default)]
102    pub text: Option<String>,
103    // Allow for other fields
104    #[serde(flatten)]
105    pub other: std::collections::HashMap<String, serde_json::Value>,
106}
107
108impl Message {
109    // Helper to get content as a string regardless of format
110    pub fn content_as_string(&self) -> String {
111        match &self.content {
112            ContentValue::String(content) => content.clone(),
113            ContentValue::Array(content) => content.join(" "),
114            ContentValue::Object(obj) => {
115                if let Some(text) = &obj.text {
116                    text.clone()
117                } else {
118                    // Try to convert to JSON string if no text field
119                    serde_json::to_string(obj).unwrap_or_default()
120                }
121            }
122            ContentValue::ContentBlocks(blocks) => blocks
123                .iter()
124                .filter_map(|block| block.text.as_ref())
125                .cloned()
126                .collect::<Vec<_>>()
127                .join(" "),
128        }
129    }
130}
131
132#[derive(Debug, Deserialize)]
133pub struct MessageInner {
134    role: String,
135    content: String,
136}
137
138impl From<MessageInner> for Message {
139    fn from(inner: MessageInner) -> Self {
140        Self {
141            role: inner.role,
142            content: ContentValue::String(inner.content),
143        }
144    }
145}
146
147/// Deserializer for message content that supports both strings and arrays
148
149/// Custom serialization/deserialization for message content
150/// Supports either a single string or an array of strings
151mod content_string_or_array {
152    use serde::{self, Deserialize, Deserializer, Serializer};
153
154    pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
155    where
156        S: Serializer,
157    {
158        serializer.serialize_str(content)
159    }
160
161    pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
162    where
163        D: Deserializer<'de>,
164    {
165        #[derive(Deserialize)]
166        #[serde(untagged)]
167        enum StringOrArray {
168            String(()),
169            Array(()),
170        }
171
172        let value = StringOrArray::deserialize(deserializer)?;
173        match value {
174            StringOrArray::String(_) => Ok(String::new()),
175            StringOrArray::Array(_) => Ok(String::new()),
176        }
177    }
178}
179
180/// Chat completion request (OpenAI compatible)
181#[derive(Deserialize, Serialize, Debug)]
182pub struct ChatCompletionRequest {
183    model: String,
184    messages: Vec<Message>,
185    #[serde(default)]
186    user: Option<String>,  // Optional user identifier for thread association
187}
188
189/// Chat completion response choices (OpenAI compatible)
190#[derive(Serialize, Debug)]
191pub struct Choice {
192    index: i32,
193    message: Message,
194    logprobs: Option<serde_json::Value>,
195    finish_reason: String,
196}
197
198/// Helper function to get model from request or default to Sonnet
199fn get_model(model_name: &str) -> &'static str {
200    match model_name {
201        "gpt-3.5-turbo" => crate::config::HAIKU_MODEL,
202        "gpt-4" | "gpt-4-turbo" => crate::config::SONNET_MODEL,
203        "gpt-4o" | "gpt-4.1" => crate::config::OPUS_MODEL,
204        _ => crate::config::SONNET_MODEL,
205    }
206}
207
/// Translate an OpenAI role name into the role label used in Claude prompts.
///
/// `"assistant"` is kept as-is. `"system"` and `"developer"` become `"system"`.
/// `"user"` and any unrecognised role become `"human"`.
fn _normalize_role(role: &str) -> String {
    let normalized = match role {
        "assistant" => "assistant",
        "system" | "developer" => "system",
        // "user" and anything unexpected are treated as the human side.
        _ => "human",
    };
    normalized.to_string()
}
217
218/// Convert messages to a single Claude prompt
219fn _messages_to_prompt(messages: &[Message]) -> String {
220    let mut prompt = String::new();
221
222    for msg in messages {
223        let role = _normalize_role(&msg.role);
224        let content = msg.content_as_string();
225        match role.as_str() {
226            "system" => prompt.push_str(&format!("System: {}\n\n", content)),
227            "human" => prompt.push_str(&format!("Human: {}\n\n", content)),
228            "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
229            _ => prompt.push_str(&format!("{}: {}\n\n", role, content)),
230        }
231    }
232
233    prompt
234}
235
236/// Generate a consistent conversation key based on message context
237fn generate_conversation_key(messages: &[Message]) -> String {
238    use std::collections::hash_map::DefaultHasher;
239    use std::hash::{Hash, Hasher};
240    
241    // Create a consistent hash based on the conversation context (all messages except the last user message)
242    let context_messages = if messages.len() > 1 {
243        &messages[..messages.len() - 1]
244    } else {
245        &[]
246    };
247    
248    let mut hasher = DefaultHasher::new();
249    for msg in context_messages {
250        msg.role.hash(&mut hasher);
251        msg.content_as_string().hash(&mut hasher);
252    }
253    
254    format!("ctx:{:016x}", hasher.finish())
255}
256
257/// Find a matching conversation by context hash
258/// Returns the key of the matching conversation if found
259fn find_matching_conversation(conversations: &DashMap<String, (String, Vec<Message>)>, messages: &[Message]) -> Option<String> {
260    if messages.is_empty() {
261        return None;
262    }
263
264    // Generate the context key for the current request
265    let context_key = generate_conversation_key(messages);
266    
267    // Check if we have an exact context match
268    if conversations.contains_key(&context_key) {
269        return Some(context_key);
270    }
271    
272    // If no exact match, try to find a conversation that could be extended
273    // Look for conversations where our context is a prefix of their context
274    let our_context = if messages.len() > 1 {
275        &messages[..messages.len() - 1]
276    } else {
277        &[]
278    };
279    
280    for item in conversations.iter() {
281        let (_, stored_history) = &*item.value();
282        
283        // Check if our context is a prefix of the stored history
284        if stored_history.len() >= our_context.len() {
285            let prefix_matches = our_context.iter().zip(stored_history.iter())
286                .all(|(a, b)| a.role == b.role && a.content_as_string() == b.content_as_string());
287            
288            if prefix_matches {
289                return Some(item.key().clone());
290            }
291        }
292    }
293    
294    None
295}
296
297/// Chat completion API handler
298async fn chat_completion_handler(
299    State(state): State<AppState>,
300    Json(request): Json<ChatCompletionRequest>,
301) -> impl IntoResponse {
302    // Debug the raw client request
303    // dbg!(&request);
304
305    // Log when array content is used
306    // Log message content types
307    for (i, msg) in request.messages.iter().enumerate() {
308        match &msg.content {
309            ContentValue::Array(_) => info!("Message at index {} has array content", i),
310            ContentValue::Object(_) => info!("Message at index {} has object content", i),
311            ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
312            _ => {} // String content is the default, no need to log
313        }
314    }
315
316    let model_id = get_model(&request.model);
317
318    if request.messages.is_empty() {
319        return (
320            StatusCode::BAD_REQUEST,
321            Json(serde_json::json!({
322                "error": "No messages provided in the request"
323            })),
324        );
325    }
326
327    let messages = request.messages;
328
329    // Generate a request identifier from the user field or a new UUID if not provided
330    // We don't directly use this anymore, but keep it for future use
331    let _request_id = request.user.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
332    
333    // The actual user query is the last message
334    let user_query = messages.last().unwrap().clone();
335
336    let claude = state.claude;
337    let conversations = state.conversations;
338
339    // Try to find existing conversation by matching message history
340    let mut conversation_id: Option<String> = None;
341    let mut message_history: Vec<Message> = Vec::new();
342    
343    info!("Looking for matching conversation by message history");
344    info!("Available conversations: {}", conversations.len());
345
346    // Find matching conversation based on message history
347    let matching_key = find_matching_conversation(&conversations, &messages);
348
349    if let Some(ref key) = matching_key {
350        if let Some(existing) = conversations.get(key) {
351            conversation_id = Some(existing.0.clone());
352            message_history = existing.1.clone();
353            info!("Found conversation with matching history, key: '{}'", key);
354        }
355    } else {
356        info!("No existing conversation with matching history found");
357    }
358
359    // If no matching conversation, create a new one
360    if conversation_id.is_none() {
361        match claude.create_chat().await {
362            Ok(new_id) => {
363                conversation_id = Some(new_id.clone());
364                message_history = messages[..messages.len() - 1].to_vec();
365
366                // Send system prompt if present and include it in history
367                if let Some(system_msg) = messages
368                    .iter()
369                    .find(|m| m.role == "system" || m.role == "developer")
370                {
371                    if let Err(e) = claude
372                        .send_message(&new_id, &system_msg.content_as_string(), &[])
373                        .await
374                    {
375                        warn!("Failed to send system prompt: {}", e);
376                    }
377                }
378
379                // Generate context-based key for new conversation
380                let context_key = generate_conversation_key(&messages);
381                conversations.insert(
382                    context_key,
383                    (new_id.clone(), message_history.clone()),
384                );
385            }
386            Err(e) => {
387                let error_response = serde_json::json!({
388                    "error": format!("Failed to create conversation: {}", e)
389                });
390                dbg!(&error_response);
391                return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
392            }
393        }
394    }
395
396    // Send the user message and get response
397    let completion = match claude
398        .send_message(
399            &conversation_id.clone().unwrap(),
400            &user_query.content_as_string(),
401            &[],
402        )
403        .await
404    {
405        Ok(response) => response,
406        Err(e) => {
407            let error_response = serde_json::json!({
408                "error": format!("Failed to get completion: {}", e)
409            });
410            dbg!(&error_response);
411            return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
412        }
413    };
414
415    // Update the message history with the new assistant response
416    let new_message = Message {
417        role: "assistant".to_string(),
418        content: ContentValue::String(completion.clone()),
419    };
420
421    // Update message history and store it
422    let mut updated_history = message_history.clone();
423    updated_history.push(user_query.clone());
424    updated_history.push(new_message);
425
426    let conv_id = conversation_id.unwrap();
427    
428    // Generate context-based key for the updated conversation
429    let mut extended_messages = messages.clone();
430    extended_messages.push(Message {
431        role: "assistant".to_string(),
432        content: ContentValue::String(completion.clone()),
433    });
434    let updated_key = generate_conversation_key(&extended_messages);
435    
436    // Remove old entry if we had a matching key and it's different from the new key
437    if let Some(ref old_key) = matching_key {
438        if old_key != &updated_key {
439            if let Some(_) = conversations.remove(old_key) {
440                info!("Removed old conversation entry with key: '{}'", old_key);
441            }
442        }
443    }
444
445    // Store with the updated context key
446    conversations.insert(
447        updated_key.clone(),
448        (conv_id.clone(), updated_history.clone()),
449    );
450    
451    // Log the conversation storage
452    info!(
453        "Stored conversation with ID: {} under context key '{}'",
454        conv_id,
455        updated_key
456    );
457
458    // Create OpenAI-compatible response
459    let response = OpenAIResponse {
460        id: Uuid::new_v4().to_string(),
461        object: "chat.completion",
462        created: chrono::Utc::now().timestamp() as u64,
463        model: model_id.to_string(),
464        system_fingerprint: "fp_toast".to_string(),
465        choices: vec![Choice {
466            index: 0,
467            message: Message {
468                role: "assistant".to_string(),
469                content: ContentValue::String(completion.clone()),
470            },
471            logprobs: None,
472            finish_reason: "stop".to_string(),
473        }],
474        usage: Usage {
475            prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, // Rough estimate
476            completion_tokens: completion.chars().count() as u32 / 4, // Rough estimate
477            total_tokens: (user_query.content_as_string().chars().count()
478                + completion.chars().count()) as u32
479                / 4,
480        },
481    };
482
483    // Debug the response being sent to the client
484    // dbg!(&response);
485
486    (
487        StatusCode::OK,
488        Json(serde_json::to_value(response).unwrap()),
489    )
490}
491
492/// Health check endpoint
493async fn health_check() -> impl IntoResponse {
494    (StatusCode::OK, "OK")
495}
496
497/// Start the Axum server
498pub async fn run() -> anyhow::Result<()> {
499    // For backward compatibility, parse arguments directly when run() is called
500    run_with_args(Args::parse()).await
501}
502
503/// Start the Axum server with provided arguments
504pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
505    // Use provided arguments instead of parsing them
506
507    // Initialize tracing
508    if args.debug {
509        // Full logging in debug mode
510        tracing_subscriber::fmt::init();
511    } else {
512        // Only show warnings and errors in normal mode
513        tracing_subscriber::fmt()
514            .with_max_level(tracing::Level::WARN)
515            .init();
516    }
517
518    // Load config
519    let config_dir = dirs::config_dir()
520        .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
521        .join("toast");
522
523    let cookie_path = config_dir.join("cookie");
524    let org_id_path = config_dir.join("org_id");
525
526    // Check if config directory exists
527    if !config_dir.exists() {
528        return Err(anyhow::anyhow!(
529            "Configuration directory does not exist at {:?}",
530            config_dir
531        ));
532    }
533
534    // Check and load cookie
535    let cookie = if cookie_path.exists() {
536        std::fs::read_to_string(&cookie_path)?.trim().to_string()
537    } else {
538        return Err(anyhow::anyhow!(
539            "Cookie file not found at {:?}",
540            cookie_path
541        ));
542    };
543
544    // Check and load org_id
545    let org_id = if org_id_path.exists() {
546        std::fs::read_to_string(&org_id_path)?.trim().to_string()
547    } else {
548        // Try to extract org_id from cookie
549        if let Some(extracted_org_id) = crate::cli::extract_org_id_from_cookie(&cookie) {
550            // Save the extracted org_id to the file for future use
551            std::fs::write(&org_id_path, &extracted_org_id)?;
552            info!(
553                "Extracted organization ID from cookie and saved to {:?}",
554                org_id_path
555            );
556            extracted_org_id
557        } else {
558            return Err(anyhow::anyhow!(
559                "Organization ID file not found at {:?} and couldn't extract it from cookie",
560                org_id_path
561            ));
562        }
563    };
564
565    let user_agent =
566        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
567            .to_string();
568
569    let session = Session {
570        cookie,
571        user_agent,
572        organization_id: org_id,
573    };
574
575    // Use the configurable model from arguments or default to Sonnet
576    let model_str = crate::config::SONNET_MODEL;
577
578    let claude = Arc::new(Claude::new(session.clone(), model_str)?);
579
580    // Create app state
581    let app_state = AppState {
582        claude,
583        conversations: Arc::new(DashMap::new()),
584    };
585
586    // Set up CORS
587    let cors = CorsLayer::new()
588        .allow_origin(Any)
589        .allow_methods(Any)
590        .allow_headers(Any);
591
592    // Create router
593    let app = Router::new()
594        .route("/health", get(health_check))
595        .route("/v1/chat/completions", post(chat_completion_handler))
596        .route("/", post(chat_completion_handler))
597        .route("/chat/completions", post(chat_completion_handler))
598        .layer(cors)
599        .with_state(app_state);
600
601    // Parse host address
602    let host_parts: Vec<u8> = args
603        .host
604        .split('.')
605        .map(|s| s.parse::<u8>().unwrap_or(0))
606        .collect();
607
608    let host_addr = if host_parts.len() == 4 {
609        [host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
610    } else {
611        [0, 0, 0, 0]
612    };
613
614    // Run server
615    let addr = SocketAddr::from((host_addr, args.port));
616    info!("Server listening on {}", addr);
617
618    // Print usage example
619    info!("Example request:");
620    info!(
621        "curl -X POST http://localhost:{}{} \\",
622        args.port, "/v1/chat/completions  # Also available at / and /chat/completions"
623    );
624    info!("  -H \"Content-Type: application/json\" \\");
625    info!("  -d '{{");
626    info!("    \"model\": \"gpt-4.1\",");
627    info!("    \"messages\": [");
628    info!("      {{");
629    info!("        \"role\": \"developer\",");
630    info!("        \"content\": \"You are a helpful assistant.\"");
631    info!("      }},");
632    info!("      {{");
633    info!("        \"role\": \"user\",");
634    info!("        \"content\": \"Hello! (Simple string content)\"");
635    info!("      }},");
636    info!("      {{");
637    info!("        \"role\": \"user\",");
638    info!("        \"content\": [\"Array of\", \"string content\"]");
639    info!("      }},");
640    info!("      {{");
641    info!("        \"role\": \"user\",");
642    info!("        \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
643    info!("      }},");
644    info!("      {{");
645    info!("        \"role\": \"user\",");
646    info!("        \"content\": [");
647    info!("          {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
648    info!("          {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
649    info!("        ]");
650    info!("      }}");
651    info!("    ]");
652    info!("  }}'");
653
654    let listener = tokio::net::TcpListener::bind(addr).await?;
655    axum::serve(listener, app).await?;
656
657    Ok(())
658}