watsonx_rs/orchestrate/chat.rs

1//! Chat and messaging operations
2
3use crate::error::{Error, Result};
4use super::types::{Message, MessagePayload, ChatWithDocsRequest, ChatWithDocsResponse, ChatWithDocsStatus};
5use super::OrchestrateClient;
6use std::collections::HashMap;
7use serde_json::Value;
8use futures::StreamExt;
9
10#[derive(serde::Deserialize)]
11struct EventData {
12    event: String,
13    data: Value,
14}
15
16impl OrchestrateClient {
17    /// Send a message to an agent and get response (matches wxo-client pattern)
18    /// Uses /runs/stream endpoint and maintains thread_id for conversation continuity
19    pub async fn send_message(&self, agent_id: &str, message: &str, thread_id: Option<String>) -> Result<(String, Option<String>)> {
20        let token = self.access_token.as_ref().ok_or_else(|| {
21            Error::Authentication("Not authenticated. Set access token (Bearer token) first.".to_string())
22        })?;
23
24        let base_url = self.config.get_base_url();
25        let url = format!("{}/runs/stream", base_url);
26
27        let payload = MessagePayload {
28            message: Message {
29                role: "user".to_string(),
30                content: message.to_string(),
31            },
32            additional_properties: HashMap::new(),
33            context: HashMap::new(),
34            agent_id: agent_id.to_string(),
35            thread_id: thread_id.clone(),
36        };
37
38        let response = self
39            .client
40            .post(&url)
41            .header("Authorization", format!("Bearer {}", token))
42            .header("Content-Type", "application/json")
43            .header("X-Instance-ID", &self.config.instance_id)
44            .json(&payload)
45            .send()
46            .await
47            .map_err(|e| Error::Network(e.to_string()))?;
48
49        if !response.status().is_success() {
50            let status = response.status();
51            let error_text = response
52                .text()
53                .await
54                .unwrap_or_else(|_| "Unknown error".to_string());
55            return Err(Error::Api(format!(
56                "Failed to send message: {} - {}",
57                status, error_text
58            )));
59        }
60
61        let text = response.text().await.map_err(|e| Error::Network(e.to_string()))?;
62        let mut answer = String::new();
63        let mut new_thread_id = thread_id;
64
65        for line in text.lines() {
66            if !line.is_empty() {
67                if let Ok(event_data) = serde_json::from_str::<EventData>(&line) {
68                    if event_data.event == "message.created" {
69                        if let Some(data_obj) = event_data.data.as_object() {
70                            if let Some(message_obj) = data_obj.get("message") {
71                                if let Some(content_array) = message_obj.get("content").and_then(|c| c.as_array()) {
72                                    if let Some(first_content) = content_array.first() {
73                                        if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
74                                            answer = text.to_string();
75                                        }
76                                    }
77                                }
78                            }
79                            if let Some(tid) = data_obj.get("thread_id").and_then(|t| t.as_str()) {
80                                new_thread_id = Some(tid.to_string());
81                            }
82                            break;
83                        }
84                    }
85                }
86            }
87        }
88
89        Ok((answer, new_thread_id))
90    }
91
92    /// Stream response from an agent (matches wxo-client pattern)
93    pub async fn stream_message<F>(
94        &self,
95        agent_id: &str,
96        message: &str,
97        thread_id: Option<String>,
98        mut callback: F,
99    ) -> Result<Option<String>>
100    where
101        F: FnMut(String) -> Result<()>,
102    {
103        let token = self.access_token.as_ref().ok_or_else(|| {
104            Error::Authentication("Not authenticated. Set access token (Bearer token) first.".to_string())
105        })?;
106
107        let base_url = self.config.get_base_url();
108        let url = format!("{}/runs/stream", base_url);
109
110        let payload = MessagePayload {
111            message: Message {
112                role: "user".to_string(),
113                content: message.to_string(),
114            },
115            additional_properties: HashMap::new(),
116            context: HashMap::new(),
117            agent_id: agent_id.to_string(),
118            thread_id: thread_id.clone(),
119        };
120
121        let response = self
122            .client
123            .post(&url)
124            .header("Authorization", format!("Bearer {}", token))
125            .header("Content-Type", "application/json")
126            .header("Accept", "text/event-stream")
127            .header("Cache-Control", "no-cache")
128            .header("Connection", "keep-alive")
129            .header("X-Accel-Buffering", "no")
130            .header("X-Instance-ID", &self.config.instance_id)
131            .json(&payload)
132            .send()
133            .await
134            .map_err(|e| Error::Network(e.to_string()))?;
135
136        if !response.status().is_success() {
137            let status = response.status();
138            let error_text = response
139                .text()
140                .await
141                .unwrap_or_else(|_| "Unknown error".to_string());
142            return Err(Error::Api(format!(
143                "Failed to stream message: {} - {}",
144                status, error_text
145            )));
146        }
147
148        let mut stream = response.bytes_stream();
149        let mut buffer = Vec::<u8>::new();
150        let mut new_thread_id = thread_id;
151        let mut chunk_count = 0;
152
153        while let Some(chunk_result) = stream.next().await {
154            let chunk = chunk_result.map_err(|e| Error::Network(e.to_string()))?;
155            chunk_count += 1;
156
157            if chunk_count > 1 {
158                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
159            }
160
161            buffer.extend_from_slice(&chunk);
162
163            loop {
164                let newline_pos = buffer.iter().position(|&b| b == b'\n');
165
166                if let Some(newline_pos) = newline_pos {
167                    let line_bytes = buffer[..newline_pos].to_vec();
168                    buffer = buffer[newline_pos + 1..].to_vec();
169
170                    if let Ok(line) = String::from_utf8(line_bytes) {
171                        let trimmed = line.trim();
172
173                        if !trimmed.is_empty() {
174                            if let Ok(event_data) = serde_json::from_str::<EventData>(trimmed) {
175                                if event_data.event == "message.delta" {
176                                    if let Some(data_obj) = event_data.data.as_object() {
177                                        if let Some(delta_obj) = data_obj.get("delta").and_then(|d| d.as_object()) {
178                                            if let Some(content_array) = delta_obj.get("content").and_then(|c| c.as_array()) {
179                                                if let Some(first_content) = content_array.first() {
180                                                    if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
181                                                        callback(text.to_string())?;
182                                                    }
183                                                }
184                                            }
185                                        } else if let Some(content_array) = data_obj.get("content").and_then(|c| c.as_array()) {
186                                            if let Some(first_content) = content_array.first() {
187                                                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
188                                                    callback(text.to_string())?;
189                                                }
190                                            }
191                                        }
192                                        if let Some(tid) = data_obj.get("thread_id").and_then(|t| t.as_str()) {
193                                            new_thread_id = Some(tid.to_string());
194                                        }
195                                    }
196                                } else if event_data.event == "message.created" {
197                                    if let Some(data_obj) = event_data.data.as_object() {
198                                        if let Some(tid) = data_obj.get("thread_id").and_then(|t| t.as_str()) {
199                                            new_thread_id = Some(tid.to_string());
200                                        }
201                                    }
202                                }
203                            }
204                        }
205                    }
206                } else {
207                    break;
208                }
209            }
210        }
211
212        if !buffer.is_empty() {
213            if let Ok(line) = String::from_utf8(buffer) {
214                let trimmed = line.trim();
215                if !trimmed.is_empty() {
216                    if let Ok(event_data) = serde_json::from_str::<EventData>(trimmed) {
217                        if event_data.event == "message.delta" {
218                            if let Some(data_obj) = event_data.data.as_object() {
219                                if let Some(delta_obj) = data_obj.get("delta").and_then(|d| d.as_object()) {
220                                    if let Some(content_array) = delta_obj.get("content").and_then(|c| c.as_array()) {
221                                        if let Some(first_content) = content_array.first() {
222                                            if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
223                                                callback(text.to_string())?;
224                                            }
225                                        }
226                                    }
227                                } else if let Some(content_array) = data_obj.get("content").and_then(|c| c.as_array()) {
228                                    if let Some(first_content) = content_array.first() {
229                                        if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
230                                            callback(text.to_string())?;
231                                        }
232                                    }
233                                }
234                                if let Some(tid) = data_obj.get("thread_id").and_then(|t| t.as_str()) {
235                                    new_thread_id = Some(tid.to_string());
236                                }
237                            }
238                        }
239                    }
240                }
241            }
242        }
243
244        Ok(new_thread_id)
245    }
246
247    /// Get the status of chat with documents knowledge base for a thread
248    pub async fn get_chat_with_docs_status(&self, agent_id: &str, thread_id: &str) -> Result<ChatWithDocsStatus> {
249        let token = self.access_token.as_ref().ok_or_else(|| {
250            Error::Authentication("Not authenticated. Set access token (Bearer token) first.".to_string())
251        })?;
252
253        let base_url = self.config.get_base_url();
254        
255        let endpoints = vec![
256            format!("{}/orchestrate/agents/{}/threads/{}/chat_with_docs_status", base_url, agent_id, thread_id),
257            format!("{}/agents/{}/threads/{}/chat_with_docs_status", base_url, agent_id, thread_id),
258            format!("{}/agents/{}/threads/{}/chat_with_docs/status", base_url, agent_id, thread_id),
259        ];
260
261        for url in endpoints {
262            let response = self
263                .client
264                .get(&url)
265                .header("Authorization", format!("Bearer {}", token))
266                .header("Content-Type", "application/json")
267                .header("X-Instance-ID", &self.config.instance_id)
268                .send()
269                .await
270                .map_err(|e| Error::Network(e.to_string()))?;
271
272            if response.status().is_success() {
273                let status: ChatWithDocsStatus = response
274                    .json()
275                    .await
276                    .map_err(|e| Error::Serialization(e.to_string()))?;
277                return Ok(status);
278            }
279        }
280
281        Err(Error::Api(format!(
282            "Failed to get chat with docs status: All endpoint paths returned 404. Chat with documents may not be available in this instance."
283        )))
284    }
285
286    /// Send a message with document context (chat with documents)
287    pub async fn chat_with_docs(&self, agent_id: &str, thread_id: &str, request: ChatWithDocsRequest) -> Result<ChatWithDocsResponse> {
288        let token = self.access_token.as_ref().ok_or_else(|| {
289            Error::Authentication("Not authenticated. Set access token (Bearer token) first.".to_string())
290        })?;
291
292        let base_url = self.config.get_base_url();
293        
294        let endpoints = vec![
295            format!("{}/orchestrate/agents/{}/threads/{}/chat_with_docs", base_url, agent_id, thread_id),
296            format!("{}/agents/{}/threads/{}/chat_with_docs", base_url, agent_id, thread_id),
297            format!("{}/orchestrate/agents/{}/threads/{}/runs/stream", base_url, agent_id, thread_id),
298            format!("{}/agents/{}/threads/{}/runs/stream", base_url, agent_id, thread_id),
299        ];
300
301        for url in endpoints {
302            let payload = if url.contains("chat_with_docs") {
303                serde_json::json!({
304                    "message": request.message,
305                    "document_content": request.document_content,
306                    "document_path": request.document_path,
307                    "context": request.context,
308                })
309            } else {
310                serde_json::json!({
311                    "message": {
312                        "role": "user",
313                        "content": request.message,
314                    },
315                    "agent_id": agent_id,
316                    "thread_id": thread_id,
317                    "document_content": request.document_content,
318                    "document_path": request.document_path,
319                    "context": request.context,
320                })
321            };
322
323            let response = self
324                .client
325                .post(&url)
326                .header("Authorization", format!("Bearer {}", token))
327                .header("Content-Type", "application/json")
328                .header("X-Instance-ID", &self.config.instance_id)
329                .json(&payload)
330                .send()
331                .await
332                .map_err(|e| Error::Network(e.to_string()))?;
333
334            if response.status().is_success() {
335                let text = response
336                    .text()
337                    .await
338                    .map_err(|e| Error::Network(e.to_string()))?;
339
340                if let Ok(chat_response) = serde_json::from_str::<ChatWithDocsResponse>(&text) {
341                    return Ok(chat_response);
342                }
343
344                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&text) {
345                    let message = value
346                        .get("message")
347                        .and_then(|m| m.as_str())
348                        .or_else(|| value.get("content").and_then(|c| c.as_str()))
349                        .unwrap_or("No response")
350                        .to_string();
351
352                    return Ok(ChatWithDocsResponse {
353                        message,
354                        documents_used: None,
355                        confidence: None,
356                        metadata: None,
357                    });
358                }
359
360                return Ok(ChatWithDocsResponse {
361                    message: text,
362                    documents_used: None,
363                    confidence: None,
364                    metadata: None,
365                });
366            }
367        }
368
369        Err(Error::Api(format!(
370            "Failed to chat with docs: All endpoint paths returned 404. Chat with documents may not be available in this instance."
371        )))
372    }
373
374    /// Stream chat with documents response
375    pub async fn stream_chat_with_docs<F>(
376        &self,
377        agent_id: &str,
378        thread_id: &str,
379        request: ChatWithDocsRequest,
380        mut callback: F,
381    ) -> Result<()>
382    where
383        F: FnMut(String) -> Result<()>,
384    {
385        let token = self.access_token.as_ref().ok_or_else(|| {
386            Error::Authentication("Not authenticated. Set access token (Bearer token) first.".to_string())
387        })?;
388
389        let base_url = self.config.get_base_url();
390        
391        let endpoints = vec![
392            format!("{}/orchestrate/agents/{}/threads/{}/chat_with_docs", base_url, agent_id, thread_id),
393            format!("{}/agents/{}/threads/{}/chat_with_docs", base_url, agent_id, thread_id),
394            format!("{}/orchestrate/agents/{}/threads/{}/runs/stream", base_url, agent_id, thread_id),
395            format!("{}/agents/{}/threads/{}/runs/stream", base_url, agent_id, thread_id),
396        ];
397
398        for url in endpoints {
399            let payload = if url.contains("chat_with_docs") {
400                serde_json::json!({
401                    "message": request.message,
402                    "document_content": request.document_content,
403                    "document_path": request.document_path,
404                    "context": request.context,
405                })
406            } else {
407                serde_json::json!({
408                    "message": {
409                        "role": "user",
410                        "content": request.message,
411                    },
412                    "agent_id": agent_id,
413                    "thread_id": thread_id,
414                    "document_content": request.document_content,
415                    "document_path": request.document_path,
416                    "context": request.context,
417                })
418            };
419
420            let response = self
421                .client
422                .post(&url)
423                .header("Authorization", format!("Bearer {}", token))
424                .header("Content-Type", "application/json")
425                .header("Accept", "text/event-stream")
426                .header("Cache-Control", "no-cache")
427                .header("Connection", "keep-alive")
428                .header("X-Accel-Buffering", "no")
429                .header("X-Instance-ID", &self.config.instance_id)
430                .json(&payload)
431                .send()
432                .await
433                .map_err(|e| Error::Network(e.to_string()))?;
434
435            if !response.status().is_success() {
436                continue;
437            }
438
439            let mut stream = response.bytes_stream();
440            let mut buffer = String::new();
441
442            while let Some(chunk_result) = stream.next().await {
443                let chunk = chunk_result.map_err(|e| Error::Network(e.to_string()))?;
444                let chunk_str = String::from_utf8_lossy(&chunk);
445                buffer.push_str(&chunk_str);
446
447                while let Some(line_end) = buffer.find('\n') {
448                    let line = buffer[..line_end].to_string();
449                    buffer = buffer[line_end + 1..].to_string();
450
451                    if !line.is_empty() && line.starts_with("data:") {
452                        let data_str = &line[5..].trim();
453                        if let Ok(event_data) = serde_json::from_str::<EventData>(data_str) {
454                            if event_data.event == "message.delta" {
455                                if let Some(data_obj) = event_data.data.as_object() {
456                                    if let Some(delta_obj) = data_obj.get("delta").and_then(|d| d.as_object()) {
457                                        if let Some(content_array) = delta_obj.get("content").and_then(|c| c.as_array()) {
458                                            if let Some(first_content) = content_array.first() {
459                                                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
460                                                    callback(text.to_string())?;
461                                                }
462                                            }
463                                        }
464                                    }
465                                }
466                            }
467                        }
468                    }
469                }
470            }
471
472            return Ok(());
473        }
474
475        Err(Error::Api(format!(
476            "Failed to stream chat with docs: All endpoint paths returned 404. Chat with documents may not be available in this instance."
477        )))
478    }
479}