Skip to main content

rust_lstar/knowledge_base/
network.rs

//! Network active knowledge base implementation
//! Communicates with a remote target via network sockets

use super::active::ActiveKnowledgeBase;
use super::base::{KnowledgeBase, KnowledgeBaseTrait};
use super::stats::KnowledgeBaseStats;
use crate::letter::Letter;
use crate::query::OutputQuery;
use crate::word::Word;
use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
12
13/// An active knowledge base that communicates with a remote target via network
14pub struct NetworkActiveKnowledgeBase {
15    base: KnowledgeBase,
16    target_host: String,
17    target_port: u16,
18    timeout: Duration,
19    target_running: bool,
20}
21
22impl NetworkActiveKnowledgeBase {
23    /// Creates a new network knowledge base
24    ///
25    /// # Arguments
26    /// * `target_host` - Hostname or IP address of the target
27    /// * `target_port` - Port number of the target service
28    /// * `timeout` - Socket timeout duration
29    pub fn new(target_host: String, target_port: u16, timeout: Duration) -> Self {
30        NetworkActiveKnowledgeBase {
31            base: KnowledgeBase::new(),
32            target_host,
33            target_port,
34            timeout,
35            target_running: false,
36        }
37    }
38
39    /// Gets the target hostname
40    pub fn target_host(&self) -> &str {
41        &self.target_host
42    }
43
44    /// Gets the target port
45    pub fn target_port(&self) -> u16 {
46        self.target_port
47    }
48
49    /// Gets the socket timeout
50    pub fn timeout(&self) -> Duration {
51        self.timeout
52    }
53
54    /// Sets the socket timeout
55    pub fn set_timeout(&mut self, timeout: Duration) {
56        self.timeout = timeout;
57    }
58
59    /// Gets the connection address
60    fn connection_addr(&self) -> String {
61        format!("{}:{}", self.target_host, self.target_port)
62    }
63
64    /// Sends data to the target and receives response
65    fn send_and_receive(&self, stream: &mut TcpStream, data: &[u8]) -> Result<Vec<u8>, String> {
66        // Send data
67        stream
68            .write_all(data)
69            .map_err(|e| format!("Failed to send data: {}", e))?;
70
71        // Receive response
72        let mut buffer = vec![0; 1024];
73        let n = stream
74            .read(&mut buffer)
75            .map_err(|e| format!("Failed to receive data: {}", e))?;
76
77        buffer.truncate(n);
78        Ok(buffer)
79    }
80
81    /// Submits a single letter to the network target
82    fn submit_letter(&self, stream: &mut TcpStream, letter: &Letter) -> Result<Letter, String> {
83        let data = letter.symbols().into_bytes();
84        let response = self.send_and_receive(stream, &data)?;
85
86        match String::from_utf8(response) {
87            Ok(response_str) => {
88                let trimmed = response_str.trim();
89                Ok(Letter::new(trimmed))
90            }
91            Err(_) => Ok(Letter::new("")),
92        }
93    }
94}
95
96impl ActiveKnowledgeBase for NetworkActiveKnowledgeBase {
97    fn start_target(&mut self) -> Result<(), String> {
98        // The underlying protocol is queried in submit_word() with a fresh
99        // connection for each word (same behavior as pylstar).
100        self.target_running = true;
101        Ok(())
102    }
103
104    fn stop_target(&mut self) -> Result<(), String> {
105        self.target_running = false;
106        Ok(())
107    }
108
109    fn submit_word(&mut self, word: &Word) -> Result<Word, String> {
110        let addr = self.connection_addr();
111        let mut stream = TcpStream::connect(&addr)
112            .map_err(|e| format!("Failed to connect to {}: {}", addr, e))?;
113
114        stream
115            .set_read_timeout(Some(self.timeout))
116            .map_err(|e| format!("Failed to set read timeout: {}", e))?;
117        stream
118            .set_write_timeout(Some(self.timeout))
119            .map_err(|e| format!("Failed to set write timeout: {}", e))?;
120
121        let mut output_letters = Vec::new();
122
123        for letter in word.letters() {
124            output_letters.push(self.submit_letter(&mut stream, letter)?);
125        }
126
127        Ok(Word::from_letters(output_letters))
128    }
129
130    fn is_target_running(&self) -> bool {
131        self.target_running
132    }
133}
134
135impl KnowledgeBaseTrait for NetworkActiveKnowledgeBase {
136    fn resolve_query(&mut self, query: &mut OutputQuery) -> Result<(), String> {
137        match self.base.resolve_query(query) {
138            Ok(_) => Ok(()),
139            Err(_) => {
140                self.start_target()?;
141                let submit_result = self.submit_word(&query.input_word);
142                let stop_result = self.stop_target();
143
144                let output = submit_result?;
145                stop_result?;
146
147                self.base.add_word(&query.input_word, &output)?;
148                query.set_result(output);
149                Ok(())
150            }
151        }
152    }
153
154    fn add_word(&mut self, input_word: &Word, output_word: &Word) -> Result<(), String> {
155        self.base.add_word(input_word, output_word)
156    }
157}
158
159impl NetworkActiveKnowledgeBase {
160    pub fn stats(&self) -> &KnowledgeBaseStats {
161        self.base.stats()
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a knowledge base for the given endpoint with a 5 second timeout
    fn kb_for(host: &str, port: u16) -> NetworkActiveKnowledgeBase {
        NetworkActiveKnowledgeBase::new(host.to_string(), port, Duration::from_secs(5))
    }

    /// A freshly created knowledge base exposes its endpoint and is not running
    #[test]
    fn test_creation() {
        let kb = kb_for("localhost", 3000);

        assert_eq!(kb.target_host(), "localhost");
        assert_eq!(kb.target_port(), 3000);
        assert!(!kb.is_target_running());
    }

    /// The connection address is rendered as `host:port`
    #[test]
    fn test_connection_addr() {
        let kb = kb_for("example.com", 8080);

        assert_eq!(kb.connection_addr(), "example.com:8080");
    }

    /// The timeout can be read back and replaced
    #[test]
    fn test_set_timeout() {
        let initial = Duration::from_secs(5);
        let updated = Duration::from_secs(10);
        let mut kb = kb_for("localhost", 3000);

        assert_eq!(kb.timeout(), initial);
        kb.set_timeout(updated);
        assert_eq!(kb.timeout(), updated);
    }
}