// smos_application/types/chat_request.rs
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ChatRequest {
15 pub model: String,
16
17 pub messages: Vec<Value>,
21
22 #[serde(flatten)]
26 pub extra: Map<String, Value>,
27}
28
29impl ChatRequest {
30 pub fn new(model: impl Into<String>, messages: Vec<Value>) -> Self {
32 Self {
33 model: model.into(),
34 messages,
35 extra: Map::new(),
36 }
37 }
38
39 pub fn with_extra(mut self, key: impl Into<String>, value: Value) -> Self {
41 self.extra.insert(key.into(), value);
42 self
43 }
44
45 pub fn extra(&self, key: &str) -> Option<&Value> {
47 self.extra.get(key)
48 }
49
50 pub fn is_streaming(&self) -> bool {
52 self.extra
53 .get("stream")
54 .and_then(Value::as_bool)
55 .unwrap_or(false)
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn serialises_known_fields_at_top_level() {
        let req = ChatRequest::new("gpt-4o", vec![json!({"role": "user", "content": "hi"})]);
        let v: Value = serde_json::to_value(&req).unwrap();
        assert_eq!(v["model"], "gpt-4o");
        assert_eq!(v["messages"][0]["role"], "user");
        assert_eq!(v["messages"][0]["content"], "hi");
    }

    #[test]
    fn extra_fields_flatten_alongside_known_fields() {
        let req = ChatRequest::new("m", vec![]).with_extra("temperature", json!(0.7));
        let v: Value = serde_json::to_value(&req).unwrap();
        assert_eq!(v["temperature"], 0.7);
        // Flattening means there is no nested "extra" object in the output.
        assert!(v.get("extra").is_none());
    }

    #[test]
    fn deserialises_unknown_fields_into_extra() {
        let raw = json!({
            "model": "m",
            "messages": [],
            "temperature": 0.3,
            "tools": [{"type": "function"}],
        });
        let req: ChatRequest = serde_json::from_value(raw).unwrap();
        assert_eq!(req.model, "m");
        assert!(req.messages.is_empty());
        assert_eq!(req.extra("temperature"), Some(&json!(0.3)));
        assert_eq!(req.extra("tools"), Some(&json!([{"type": "function"}])));
        // Known keys are consumed by the typed fields and must not leak into `extra`.
        assert!(req.extra("model").is_none());
        assert!(req.extra("messages").is_none());
        assert_eq!(req.extra.len(), 2);
    }

    #[test]
    fn roundtrip_preserves_all_fields() {
        let req = ChatRequest::new("m", vec![json!({"role": "system"})])
            .with_extra("stream", json!(true))
            .with_extra("max_tokens", json!(128));
        let encoded = serde_json::to_string(&req).unwrap();
        let back: ChatRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(back.model, req.model);
        assert_eq!(back.messages, req.messages);
        assert_eq!(back.extra("stream"), Some(&json!(true)));
        assert_eq!(back.extra("max_tokens"), Some(&json!(128)));
        assert_eq!(back.extra, req.extra);
    }

    #[test]
    fn is_streaming_reads_extra_bool() {
        let streaming = ChatRequest::new("m", vec![]).with_extra("stream", json!(true));
        let non_streaming = ChatRequest::new("m", vec![]).with_extra("stream", json!(false));
        let non_bool = ChatRequest::new("m", vec![]).with_extra("stream", json!("true"));
        let unset = ChatRequest::new("m", vec![]);
        assert!(streaming.is_streaming());
        assert!(!non_streaming.is_streaming());
        // Only a real JSON boolean counts; a string "true" does not.
        assert!(!non_bool.is_streaming());
        assert!(!unset.is_streaming());
    }
}