sync_lsp/connection/
transport.rs

use std::{
    io::{stdin, stdout, BufRead, BufReader, Error, ErrorKind, Read, StdinLock, StdoutLock, Write},
    net::{TcpListener, ToSocketAddrs},
};

#[cfg(feature = "mio")]
use std::time::Duration;

#[cfg(feature = "mio")]
use mio::net::TcpStream;
#[cfg(not(feature = "mio"))]
use std::net::TcpStream;

#[cfg(feature = "mio")]
use mio::{Events, Interest, Poll, Token};

#[cfg(unix)]
#[cfg(feature = "mio")]
use mio::unix::SourceFd;

use log::{error, warn};
35
36/// The transport defines how data is sent and received from the client.
37/// 
38/// The langauge server protocol commonly uses stdio and ipc, but
39/// tcp and custom transports are also supported.
40/// All errors that occur during sending and receiving will cause the
41/// [Server::serve](crate::Server::serve) method to immediately return with an error variant.
42pub struct Transport {
43    raw: RawTransport,
44    error: Option<Error>,
45    #[cfg(feature = "mio")]
46    poll: Option<Poll>,
47    #[cfg(feature = "mio")]
48    events: Events,
49    buffer: Vec<Vec<u8>>
50}
51
/// The concrete stream pairs a `Transport` can be built on.
// NOTE(review): `Tpc` looks like a misspelling of `Tcp`; renaming it requires
// updating the constructors and accessors that match on it at the same time.
enum RawTransport {
    /// The process's locked standard input and output.
    Stdio {
        input: StdinLock<'static>,
        output: StdoutLock<'static>,
    },
    /// Two handles to one accepted tcp connection: a buffered reader and a
    /// plain writer.
    Tpc {
        input: BufReader<TcpStream>,
        output: TcpStream,
    },
    /// Caller-supplied streams.
    Custom {
        input: Box<dyn BufRead>,
        output: Box<dyn Write>,
    },
}
66
67impl RawTransport {
68    fn input(&mut self) -> &mut dyn BufRead {
69        match self {
70            Self::Stdio { input, .. } => input,
71            Self::Tpc { input, .. } => input,
72            Self::Custom { input, .. } => input
73        }
74    }
75
76    fn output(&mut self) -> &mut dyn Write {
77        match self {
78            Self::Stdio { output, .. } => output,
79            Self::Tpc { output, .. } => output,
80            Self::Custom { output, .. } => output
81        }
82    }
83}
84
85impl Transport {
86
87    /// Creates a new transport from the given input and output streams.
88    /// This transport will not support polling and therefore will not be able to
89    /// support request cancellation.
90    /// 
91    /// # Arguments
92    /// * `input` - The input stream to read from.
93    pub fn custom(input: impl BufRead + 'static, output: impl Write + 'static) -> Transport {
94        Transport {
95            raw: RawTransport::Custom {
96                input: Box::new(input),
97                output: Box::new(output)
98            },
99            error: None,
100            #[cfg(feature = "mio")]
101            events: Events::with_capacity(1),
102            buffer: Vec::new(),
103            #[cfg(feature = "mio")]
104            poll: None
105        }
106    }
107
108    /// Opens a tcp connection to the given address and returns a transport.
109    /// 
110    /// # Argument
111    /// * `addr` - The address to connect to.
112    #[cfg(not(feature = "mio"))]
113    pub fn tcp<T: ToSocketAddrs>(addr: T) -> Result<Transport, Error> {
114        let listener = TcpListener::bind(addr)?;
115        let (output, ..) = listener.accept()?;
116        let input = output.try_clone()?;
117        let input = BufReader::new(input);
118
119        Ok(Transport {
120            raw: RawTransport::Tpc {
121                output,
122                input
123            },
124            error: None,
125            buffer: Vec::new(),
126        })
127    }
128
129    /// Opens a tcp connection to the given address and returns a transport.
130    /// 
131    /// # Argument
132    /// * `addr` - The address to connect to.
133    #[cfg(feature = "mio")]
134    pub fn tcp<T: ToSocketAddrs>(addr: T) -> Result<Transport, Error> {
135        let mut poll = Poll::new().ok();
136        let listener = TcpListener::bind(addr)?;
137        let (output, ..) = listener.accept()?;
138        let input = output.try_clone()?;
139        let mut input = TcpStream::from_std(input);
140
141        if let Some(poll) = poll.as_mut() {
142            poll.registry().register(
143                &mut input,
144                Token(0),
145                Interest::READABLE
146            ).ok();
147        }
148
149        let input = BufReader::new(input);
150        let output = TcpStream::from_std(output);
151
152        Ok(Transport {
153            raw: RawTransport::Tpc {
154                output,
155                input
156            },
157            error: None,
158            events: Events::with_capacity(1),
159            buffer: Vec::new(),
160            poll
161        })
162    }
163
164    /// Locks the standard input and output streams and returns a transport.
165    #[cfg(not(feature = "mio"))]
166    pub fn stdio() -> Transport {
167        Transport {
168            raw: RawTransport::Stdio {
169                output: stdout().lock(),
170                input: stdin().lock()
171            },
172            error: None,
173            buffer: Vec::new(),
174        }
175    }
176
177    /// Locks the standard input and output streams and returns a transport.
178    #[cfg(feature = "mio")]
179    pub fn stdio() -> Transport {
180        let poll = Poll::new().ok();
181        let input = stdin().lock();
182
183        #[cfg(unix)]
184        if let Some(poll) = poll.as_ref() {
185            use std::os::fd::AsRawFd;
186            poll.registry().register(
187                &mut SourceFd(&input.as_raw_fd()),
188                Token(0),
189                Interest::READABLE
190            ).ok();
191        }
192        
193        Transport {
194            raw: RawTransport::Stdio {
195                output: stdout().lock(),
196                input
197            },
198            error: None,
199            events: Events::with_capacity(1),
200            buffer: Vec::new(),
201            poll
202        }
203    }
204
205    pub(crate) fn error(&mut self) -> &mut Option<Error> {
206        &mut self.error
207    }
208
209    pub(crate) fn send(&mut self, message: String) {
210        if self.error().is_some() { return }
211        *self.error() = write!(self.raw.output(), "Content-Length: {}\r\n", message.len())
212            .or(write!(self.raw.output(), "Content-Type: {}\r\n", "application/vscode-jsonrpc; charset=utf-8"))
213            .or(write!(self.raw.output(), "\r\n{message}"))
214            .or(self.raw.output().flush()).err();
215    }
216
217    pub(crate) fn recv(&mut self) -> Option<Vec<u8>> {
218        if let Some(data) = self.buffer.pop() {
219            return Some(data)
220        }
221
222        if self.error().is_some() { return None }
223        match self.try_recv() {
224            Ok(message) => Some(message),
225            Err(error) => {
226                *self.error() = Some(error);
227                None
228            }
229        }
230    }
231
232
233    pub(crate) fn peek(&mut self) -> Option<Vec<u8>> {
234        if self.poll() && self.buffer.len() < 10192 {
235            let data = self.recv();
236            if let Some(data) = data.clone() {
237                self.buffer.push(data)
238            }
239            data
240        } else {
241            None
242        }
243    }
244
245    #[cfg(feature = "mio")]
246    fn poll(&mut self) -> bool {
247        self.events.clear();
248        if let Some(poll) = self.poll.as_mut() {
249            poll.poll(&mut self.events, Some(Duration::from_millis(1))).ok();
250        }
251        !self.events.is_empty()
252    }
253
254    #[cfg(not(feature = "mio"))]
255    fn poll(&mut self) -> bool {
256        false
257    }
258
259    fn try_recv(&mut self) -> Result<Vec<u8>, Error> {
260        loop {
261            let mut content_length: Option<usize> = None;
262    
263            for line in self.raw.input().lines() {
264
265                let line = line?;
266                if line.is_empty() { break }
267
268                match line.split_once(": ") {
269                    Some(("Content-Length", value)) => content_length = Some(
270                        if let Ok(content_length) = value.parse() {
271                            content_length
272                        } else {
273                            error!("Failed to parse Content-Length");
274                            continue
275                        }
276                    ),
277                    Some(("Content-Type", value)) => {
278                        if value != "application/vscode-jsonrpc; charset=utf-8" {
279                            error!("Invalid Content-Type: {value}");
280                            continue
281                        }
282                    },
283                    None => warn!("Invalid header: {line}"),
284                    Some((header, ..)) => warn!("Unknown header: {header}")
285                }
286            }
287
288            let Some(content_length) = content_length else {
289                error!("Received a message without a Content-Length");
290                continue
291            };
292
293            let mut buffer = vec![0; content_length];
294
295            self.raw.input()
296                .read_exact(&mut buffer)?;
297
298            //eprintln!("Received: {message}", message = String::from_utf8_lossy(&buffer));
299
300            return Ok(buffer)
301        }
302    }
303}