Skip to main content

rz_cli/
telegram.rs

1//! Telegram-to-NATS bridge — makes a Telegram chat appear as a regular rz agent.
2//!
3//! Long-polls Telegram for messages, publishes them to NATS. Subscribes to a
4//! NATS subject, forwards agent replies to Telegram. Registers in KV for
5//! discovery via `rz list --all`.
6
7use std::collections::HashMap;
8use std::time::Duration;
9
10use eyre::{Result, bail};
11use futures::StreamExt;
12
13use rz_agent_protocol::{Envelope, MessageKind};
14
/// Long-poll timeout in seconds passed to `getUpdates`; the HTTP request
/// itself is allowed 10 s more than this.
const LONG_POLL_TIMEOUT: u64 = 30;
/// Ceiling for the doubling retry delay used after failed polls.
const MAX_BACKOFF: Duration = Duration::from_secs(60);
/// Retry delay after the first failure; restored after every successful poll.
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Maximum size, in bytes, of one outgoing Telegram message; longer text is split.
const MAX_MESSAGE_LEN: usize = 4_096;
/// How many Telegram message id -> rz envelope id pairs are kept for replies.
const MAX_REF_MAP_SIZE: usize = 100;
/// Seconds between refreshes of the bridge's KV registration.
const KV_REFRESH_SECS: u64 = 20;
21
/// Bridges one Telegram chat onto NATS so it behaves like a regular rz agent.
pub struct TelegramBridge {
    /// Bot API token. It is embedded in request URLs, so never log it.
    token: String,
    /// The single Telegram chat that is relayed in both directions.
    chat_id: i64,
    /// Agent name of the bridge itself; it listens on `agent.{name}`.
    name: String,
    /// Agent that receives Telegram messages without an `@agent` prefix.
    default_agent: String,
}
29
30impl TelegramBridge {
31    pub fn new(
32        token: String,
33        chat_id: i64,
34        name: String,
35        default_agent: String,
36    ) -> Self {
37        Self { token, chat_id, name, default_agent }
38    }
39
40    fn api_url(&self) -> String {
41        format!("https://api.telegram.org/bot{}", self.token)
42    }
43
44    fn own_subject(&self) -> String {
45        format!("agent.{}", self.name)
46    }
47
48    fn from_identity(&self) -> String {
49        self.own_subject()
50    }
51
52    /// Run the bridge. Long-polls Telegram + subscribes to NATS.
53    /// Blocks until Ctrl-C or an unrecoverable error.
54    pub async fn run(
55        &self,
56        nats_client: &async_nats::Client,
57        kv: Option<&async_nats::jetstream::kv::Store>,
58    ) -> Result<()> {
59        let http = reqwest::Client::new();
60        let api = self.api_url();
61
62        // Validate the bot token.
63        validate_token(&http, &api).await?;
64
65        // Subscribe to our NATS subject for inbound agent messages.
66        let subject = self.own_subject();
67        let mut nats_sub = nats_client
68            .subscribe(async_nats::Subject::from(subject.clone()))
69            .await
70            .map_err(|e| eyre::eyre!("NATS subscribe to {subject}: {e}"))?;
71        eprintln!("subscribed to NATS subject: {subject}");
72
73        // Register in KV immediately.
74        self.register_kv(kv).await;
75
76        let mut offset: Option<i64> = None;
77        let mut backoff = INITIAL_BACKOFF;
78        let mut ref_map = RefMap::new();
79        let mut kv_interval = tokio::time::interval(Duration::from_secs(KV_REFRESH_SECS));
80
81        loop {
82            tokio::select! {
83                // NATS → Telegram: agent messages forwarded to Telegram chat.
84                msg = nats_sub.next() => {
85                    let Some(msg) = msg else {
86                        bail!("NATS subscription closed");
87                    };
88                    let payload = String::from_utf8_lossy(&msg.payload);
89                    if let Ok(env) = Envelope::decode(&payload) {
90                        if let Some(text) = extract_display_text(&env.kind) {
91                            let sender = env.from.rsplit('.').next().unwrap_or(&env.from);
92                            let display = format!("[{sender}] {text}");
93                            if let Some(tg_msg_id) = send_message(&http, &api, self.chat_id, &display).await {
94                                ref_map.insert(tg_msg_id, env.id.clone());
95                            }
96                        }
97                    }
98                }
99
100                // Telegram → NATS: poll for new messages.
101                updates = get_updates(&http, &api, &mut offset, &mut backoff) => {
102                    for update in updates {
103                        if let Some(parsed) = parse_update(&update) {
104                            // Only process messages from the configured chat.
105                            if parsed.chat_id != self.chat_id {
106                                continue;
107                            }
108
109                            // Route: @agent prefix or default target.
110                            let (target_agent, text) = parse_target(&parsed.text, &self.default_agent);
111                            let target_subject = format!("agent.{target_agent}");
112
113                            // Build envelope.
114                            let mut envelope = Envelope::chat(
115                                &self.from_identity(),
116                                text,
117                            );
118
119                            // Correlate replies via ref map.
120                            if let Some(reply_to_msg_id) = parsed.reply_to_message_id {
121                                if let Some(rz_id) = ref_map.get(reply_to_msg_id) {
122                                    envelope = envelope.with_ref(rz_id);
123                                }
124                            }
125
126                            // Publish to NATS.
127                            let wire = match envelope.encode() {
128                                Ok(w) => w,
129                                Err(e) => {
130                                    eprintln!("encode error: {e}");
131                                    continue;
132                                }
133                            };
134                            // Publish via JetStream so durable consumers receive the message.
135                            // Ensure stream exists, then publish.
136                            let js = async_nats::jetstream::new(nats_client.clone());
137                            let target_agent_name = target_agent.to_string();
138                            let stream_name = format!("RZ_{}", target_agent_name.replace('.', "_").replace('-', "_"));
139                            let _ = js.get_or_create_stream(async_nats::jetstream::stream::Config {
140                                name: stream_name,
141                                subjects: vec![target_subject.clone()],
142                                max_messages: 10_000,
143                                ..Default::default()
144                            }).await;
145                            match js.publish(
146                                async_nats::Subject::from(target_subject.clone()),
147                                wire.into_bytes().into(),
148                            ).await {
149                                Ok(ack_future) => {
150                                    if let Err(e) = ack_future.await {
151                                        eprintln!("NATS jetstream ack failed for {target_subject}: {e}");
152                                    }
153                                }
154                                Err(e) => {
155                                    eprintln!("NATS publish to {target_subject}: {e}");
156                                    continue;
157                                }
158                            }
159                            eprintln!("  → {target_subject}: {}", truncate(text, 80));
160                        }
161                    }
162                }
163
164                // Periodic KV refresh.
165                _ = kv_interval.tick() => {
166                    self.register_kv(kv).await;
167                }
168
169                // Ctrl-C.
170                _ = tokio::signal::ctrl_c() => {
171                    eprintln!("stopping telegram bridge");
172                    break;
173                }
174            }
175        }
176
177        Ok(())
178    }
179
180    async fn register_kv(&self, kv: Option<&async_nats::jetstream::kv::Store>) {
181        let Some(kv) = kv else { return };
182        let now_ms = std::time::SystemTime::now()
183            .duration_since(std::time::UNIX_EPOCH)
184            .unwrap_or_default()
185            .as_millis() as u64;
186        let value = serde_json::json!({
187            "name": self.name,
188            "id": format!("telegram-{}", self.chat_id),
189            "transport": "nats",
190            "endpoint": self.name,
191            "capabilities": [],
192            "permanent": true,
193            "registered_at": now_ms,
194            "last_seen": now_ms,
195        });
196        let _ = kv.put(&self.name, value.to_string().into()).await;
197    }
198}
199
200// ---------------------------------------------------------------------------
201// Telegram API helpers
202// ---------------------------------------------------------------------------
203
204async fn validate_token(client: &reqwest::Client, api: &str) -> Result<()> {
205    let resp: serde_json::Value = client
206        .post(format!("{api}/getMe"))
207        .send()
208        .await
209        .map_err(|e| eyre::eyre!("getMe request failed: {e}"))?
210        .json()
211        .await
212        .map_err(|e| eyre::eyre!("getMe parse failed: {e}"))?;
213
214    if resp["ok"].as_bool() != Some(true) {
215        let desc = resp["description"].as_str().unwrap_or("unknown error");
216        bail!("Telegram getMe failed: {desc}");
217    }
218
219    let bot_name = resp["result"]["username"].as_str().unwrap_or("unknown");
220    eprintln!("Telegram bot @{bot_name} connected");
221    Ok(())
222}
223
224/// Long-poll Telegram for updates. Returns parsed update objects.
225/// Manages backoff internally — on error, sleeps and returns empty.
226async fn get_updates(
227    client: &reqwest::Client,
228    api: &str,
229    offset: &mut Option<i64>,
230    backoff: &mut Duration,
231) -> Vec<serde_json::Value> {
232    let mut params = serde_json::json!({
233        "timeout": LONG_POLL_TIMEOUT,
234        "allowed_updates": ["message"],
235    });
236    if let Some(off) = offset {
237        params["offset"] = serde_json::json!(*off);
238    }
239
240    let result = client
241        .post(format!("{api}/getUpdates"))
242        .json(&params)
243        .timeout(Duration::from_secs(LONG_POLL_TIMEOUT + 10))
244        .send()
245        .await;
246
247    let resp = match result {
248        Ok(r) => r,
249        Err(e) => {
250            eprintln!("Telegram poll error: {e}, retrying in {backoff:?}");
251            tokio::time::sleep(*backoff).await;
252            *backoff = (*backoff * 2).min(MAX_BACKOFF);
253            return Vec::new();
254        }
255    };
256
257    let status = resp.status();
258
259    if status.as_u16() == 429 {
260        let body: serde_json::Value = resp.json().await.unwrap_or_default();
261        let retry = body["parameters"]["retry_after"].as_u64().unwrap_or(5);
262        eprintln!("Telegram rate limited, retry after {retry}s");
263        tokio::time::sleep(Duration::from_secs(retry)).await;
264        return Vec::new();
265    }
266
267    if !status.is_success() {
268        eprintln!("Telegram getUpdates failed ({status}), retrying in {backoff:?}");
269        tokio::time::sleep(*backoff).await;
270        *backoff = (*backoff * 2).min(MAX_BACKOFF);
271        return Vec::new();
272    }
273
274    let body: serde_json::Value = match resp.json().await {
275        Ok(v) => v,
276        Err(e) => {
277            eprintln!("Telegram parse error: {e}");
278            tokio::time::sleep(*backoff).await;
279            *backoff = (*backoff * 2).min(MAX_BACKOFF);
280            return Vec::new();
281        }
282    };
283
284    *backoff = INITIAL_BACKOFF;
285
286    let Some(updates) = body["result"].as_array() else {
287        return Vec::new();
288    };
289
290    // Advance offset past all received updates.
291    for update in updates {
292        if let Some(update_id) = update["update_id"].as_i64() {
293            *offset = Some(update_id + 1);
294        }
295    }
296
297    updates.clone()
298}
299
300/// Send a text message to Telegram. Returns the message_id on success.
301/// Splits messages that exceed Telegram's 4096-char limit.
302async fn send_message(
303    client: &reqwest::Client,
304    api: &str,
305    chat_id: i64,
306    text: &str,
307) -> Option<i64> {
308    let chunks = split_message(text);
309    let mut last_msg_id = None;
310
311    for chunk in &chunks {
312        let body = serde_json::json!({
313            "chat_id": chat_id,
314            "text": chunk,
315        });
316
317        let resp = match client
318            .post(format!("{api}/sendMessage"))
319            .json(&body)
320            .send()
321            .await
322        {
323            Ok(r) => r,
324            Err(e) => {
325                eprintln!("Telegram sendMessage error: {e}");
326                return last_msg_id;
327            }
328        };
329
330        if resp.status().is_success() {
331            if let Ok(val) = resp.json::<serde_json::Value>().await {
332                if let Some(msg_id) = val["result"]["message_id"].as_i64() {
333                    last_msg_id = Some(msg_id);
334                }
335            }
336        } else {
337            let err = resp.text().await.unwrap_or_default();
338            eprintln!("Telegram sendMessage failed: {err}");
339        }
340    }
341
342    last_msg_id
343}
344
345/// Split text into chunks that fit Telegram's 4096-char limit.
346fn split_message(text: &str) -> Vec<String> {
347    if text.len() <= MAX_MESSAGE_LEN {
348        return vec![text.to_string()];
349    }
350
351    let mut chunks = Vec::new();
352    let mut remaining = text;
353
354    while !remaining.is_empty() {
355        if remaining.len() <= MAX_MESSAGE_LEN {
356            chunks.push(remaining.to_string());
357            break;
358        }
359
360        let byte_limit = floor_char_boundary(remaining, MAX_MESSAGE_LEN);
361        let search = &remaining[..byte_limit];
362        let break_pos = search
363            .rfind('\n')
364            .or_else(|| search.rfind(' '))
365            .map(|p| p + 1)
366            .unwrap_or(byte_limit);
367
368        chunks.push(remaining[..break_pos].to_string());
369        remaining = remaining[break_pos..].trim_start();
370    }
371
372    chunks
373}
374
/// Largest index `<= max_bytes` that lies on a UTF-8 character boundary of
/// `s`, clamped to `s.len()`. Slicing `&s[..idx]` at the result never splits
/// a character.
fn floor_char_boundary(s: &str, max_bytes: usize) -> usize {
    if s.len() <= max_bytes {
        return s.len();
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0)
}
385
386// ---------------------------------------------------------------------------
387// Message parsing
388// ---------------------------------------------------------------------------
389
/// The parts of a Telegram `message` update that the bridge relays.
struct ParsedMessage {
    /// Chat the message was posted in (`message.chat.id`).
    chat_id: i64,
    /// Message text (`message.text`).
    text: String,
    /// `message_id` of the message this one replies to, if any.
    reply_to_message_id: Option<i64>,
}
395
396fn parse_update(update: &serde_json::Value) -> Option<ParsedMessage> {
397    let message = update.get("message")?;
398    let text = message["text"].as_str()?;
399    let chat_id = message["chat"]["id"].as_i64()?;
400    let reply_to_message_id = message
401        .get("reply_to_message")
402        .and_then(|r| r["message_id"].as_i64());
403
404    Some(ParsedMessage {
405        chat_id,
406        text: text.to_string(),
407        reply_to_message_id,
408    })
409}
410
/// Parse an optional `@agent_name rest of message` routing prefix.
///
/// Returns `(target_agent, message_text)`. The prefix is honoured only when
/// all of the following hold:
/// * the agent name is followed by whitespace (space, newline, tab, ...);
/// * the rest of the message is non-empty;
/// * the name is usable as a literal NATS subject token
///   (see [`is_routable_agent_name`]).
///
/// Otherwise the whole text goes to `default_agent` unchanged.
fn parse_target<'a>(text: &'a str, default_agent: &'a str) -> (&'a str, &'a str) {
    let routed = text.strip_prefix('@').and_then(|rest| {
        let (agent, msg) = rest.split_once(char::is_whitespace)?;
        let msg = msg.trim_start();
        (is_routable_agent_name(agent) && !msg.is_empty()).then_some((agent, msg))
    });
    routed.unwrap_or((default_agent, text))
}

/// Whether `name` can be spliced into `agent.{name}` as a literal subject.
///
/// Rejects:
/// * empty names and empty dot-separated tokens (`.a`, `a..b`, `a.`);
/// * any character other than alphanumerics, `-`, `_` or `.`, which notably
///   excludes the NATS wildcards `*` and `>`.
fn is_routable_agent_name(name: &str) -> bool {
    !name.is_empty()
        && name.split('.').all(|token| !token.is_empty())
        && name
            .chars()
            .all(|c| c.is_alphanumeric() || matches!(c, '-' | '_' | '.'))
}
425
426/// Extract displayable text from an envelope kind.
427/// Returns None for protocol-internal messages (Ping, Pong, Hello).
428fn extract_display_text(kind: &MessageKind) -> Option<String> {
429    match kind {
430        MessageKind::Chat { text } => Some(text.clone()),
431        MessageKind::Error { message } => Some(format!("Error: {message}")),
432        MessageKind::Timer { label } => Some(format!("Timer: {label}")),
433        MessageKind::Status { state, detail } => {
434            Some(format!("[{state}] {detail}"))
435        }
436        MessageKind::ToolCall { name, .. } => {
437            Some(format!("(calling tool: {name})"))
438        }
439        MessageKind::ToolResult { result, is_error, .. } => {
440            let prefix = if *is_error { "Tool error" } else { "Tool result" };
441            Some(format!("{prefix}: {}", truncate(result, 200)))
442        }
443        MessageKind::Delegate { task, .. } => {
444            Some(format!("(delegating: {})", truncate(task, 200)))
445        }
446        // Internal protocol — don't forward.
447        MessageKind::Ping
448        | MessageKind::Pong
449        | MessageKind::Hello { .. } => None,
450    }
451}
452
/// Return the longest prefix of `s` that is at most `max` bytes long and does
/// not cut a UTF-8 character in half.
fn truncate(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Walk back from `max` to the nearest char boundary; 0 always qualifies.
    let end = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..end]
}
461
462// ---------------------------------------------------------------------------
463// Bounded ref map: telegram_message_id → rz_envelope_id
464// ---------------------------------------------------------------------------
465
466struct RefMap {
467    map: HashMap<i64, String>,
468    order: Vec<i64>,
469}
470
471impl RefMap {
472    fn new() -> Self {
473        Self {
474            map: HashMap::new(),
475            order: Vec::new(),
476        }
477    }
478
479    fn insert(&mut self, tg_msg_id: i64, rz_id: String) {
480        if self.order.len() >= MAX_REF_MAP_SIZE {
481            if let Some(oldest) = self.order.first().copied() {
482                self.map.remove(&oldest);
483                self.order.remove(0);
484            }
485        }
486        self.map.insert(tg_msg_id, rz_id);
487        self.order.push(tg_msg_id);
488    }
489
490    fn get(&self, tg_msg_id: i64) -> Option<String> {
491        self.map.get(&tg_msg_id).cloned()
492    }
493}
494
495// ---------------------------------------------------------------------------
496// Tests
497// ---------------------------------------------------------------------------
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_target_with_prefix() {
        let (agent, msg) = parse_target("@coder fix the tests", "default");
        assert_eq!(agent, "coder");
        assert_eq!(msg, "fix the tests");
    }

    #[test]
    fn parse_target_without_prefix() {
        let (agent, msg) = parse_target("just a message", "default");
        assert_eq!(agent, "default");
        assert_eq!(msg, "just a message");
    }

    #[test]
    fn parse_target_bare_at() {
        let (agent, msg) = parse_target("@", "default");
        assert_eq!(agent, "default");
        assert_eq!(msg, "@");
    }

    #[test]
    fn parse_target_at_no_space() {
        let (agent, msg) = parse_target("@agent", "default");
        assert_eq!(agent, "default");
        assert_eq!(msg, "@agent");
    }

    #[test]
    fn ref_map_bounded() {
        let mut rm = RefMap::new();
        for i in 0..150 {
            rm.insert(i, format!("id-{i}"));
        }
        assert!(rm.map.len() <= MAX_REF_MAP_SIZE);
        // Oldest entries evicted.
        assert!(rm.get(0).is_none());
        assert!(rm.get(149).is_some());
    }

    #[test]
    fn ref_map_get_returns_stored() {
        let mut rm = RefMap::new();
        rm.insert(42, "abc123".into());
        assert_eq!(rm.get(42), Some("abc123".into()));
        assert_eq!(rm.get(99), None);
    }

    #[test]
    fn ref_map_reinsert_updates_value() {
        let mut rm = RefMap::new();
        rm.insert(1, "first".into());
        rm.insert(1, "second".into());
        assert_eq!(rm.get(1), Some("second".into()));
    }

    #[test]
    fn split_message_short() {
        let chunks = split_message("hello");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "hello");
    }

    #[test]
    fn split_message_long() {
        let text = "a ".repeat(3000); // 6000 chars
        let chunks = split_message(&text);
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.len() <= MAX_MESSAGE_LEN);
        }
    }

    #[test]
    fn split_message_prefers_newline() {
        let text = format!("{}\n{}", "a".repeat(4000), "b".repeat(200));
        let chunks = split_message(&text);
        assert_eq!(chunks, vec![format!("{}\n", "a".repeat(4000)), "b".repeat(200)]);
    }

    #[test]
    fn split_message_multibyte_keeps_every_char() {
        let text = "é".repeat(3000); // 6000 bytes, no whitespace to break on
        let chunks = split_message(&text);
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|c| c.len() <= MAX_MESSAGE_LEN));
        assert_eq!(chunks.concat(), text);
    }

    #[test]
    fn floor_char_boundary_never_splits_a_char() {
        assert_eq!(floor_char_boundary("héllo", 2), 1);
        assert_eq!(floor_char_boundary("héllo", 3), 3);
        assert_eq!(floor_char_boundary("abc", 10), 3);
    }

    #[test]
    fn truncate_respects_char_boundary() {
        assert_eq!(truncate("héllo", 2), "h");
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn parse_update_extracts_fields() {
        let update = serde_json::json!({
            "update_id": 7,
            "message": {
                "message_id": 10,
                "text": "hi",
                "chat": { "id": -42 },
                "reply_to_message": { "message_id": 9 }
            }
        });
        let parsed = parse_update(&update).expect("text message should parse");
        assert_eq!(parsed.chat_id, -42);
        assert_eq!(parsed.text, "hi");
        assert_eq!(parsed.reply_to_message_id, Some(9));
    }

    #[test]
    fn parse_update_ignores_non_text() {
        let no_text = serde_json::json!({ "update_id": 8, "message": { "chat": { "id": 1 } } });
        assert!(parse_update(&no_text).is_none());
        assert!(parse_update(&serde_json::json!({ "update_id": 9 })).is_none());
    }

    #[test]
    fn extract_display_text_chat() {
        let kind = MessageKind::Chat { text: "hello".into() };
        assert_eq!(extract_display_text(&kind), Some("hello".into()));
    }

    #[test]
    fn extract_display_text_ping_is_none() {
        assert_eq!(extract_display_text(&MessageKind::Ping), None);
    }

    #[test]
    fn extract_display_text_error() {
        let kind = MessageKind::Error { message: "boom".into() };
        assert_eq!(extract_display_text(&kind), Some("Error: boom".into()));
    }
}