sylt_std/
network.rs

1use crate as sylt_std;
2
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::collections::hash_map::Entry;
6use std::io::Write;
7use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
8use std::ops::DerefMut;
9use std::rc::Rc;
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12use std::thread;
13use sylt_common::flat_value::{FlatValue, FlatValuePack};
14use sylt_common::{error::RuntimeError, RuntimeContext, Type, Value};
15
16const DEFAULT_PORT: u16 = 8588;
17
18type RPC = Vec<FlatValuePack>;
19
20std::thread_local! {
21    static RPC_QUEUE: Arc<Mutex<Vec<(RPC, Option<SocketAddr>)>>> = Arc::new(Mutex::new(Vec::new()));
22    static SERVER_HANDLE: RefCell<Option<TcpStream>> = RefCell::new(None);
23    static CLIENT_HANDLES: Arc<Mutex<Option<HashMap<SocketAddr, (TcpStream, bool)>>>> = Arc::new(Mutex::new(None));
24    static CURRENT_REQUEST_SOCKET_ADDR: RefCell<Option<SocketAddr>> = RefCell::new(None);
25}
26
27/// Listen for new connections and accept them.
28fn rpc_listen(
29    listener: TcpListener,
30    queue: Arc<Mutex<Vec<(RPC, Option<SocketAddr>)>>>,
31    handles: Arc<Mutex<Option<HashMap<SocketAddr, (TcpStream, bool)>>>>,
32) {
33    loop {
34        if let Ok((stream, addr)) = listener.accept() {
35            match stream.try_clone() {
36                Ok(stream) => {
37                    if let Some(handles) = handles.lock().unwrap().as_mut() {
38                        handles.insert(addr, (stream, true));
39                    } else {
40                        eprintln!("Server has been shutdown, ignoring connection from {:?}", addr);
41                        if let Err(e) = stream.shutdown(Shutdown::Both) {
42                            eprintln!("Error disconnecting client {:?}: {:?}", addr, e)
43                        }
44                    }
45                }
46                Err(e) => {
47                    eprintln!("Error accepting TCP connection: {:?}", e);
48                    eprintln!("Ignoring");
49                    continue;
50                }
51            }
52            let queue = Arc::clone(&queue);
53            thread::spawn(move || rpc_handle_stream(stream, Some(addr), queue));
54        }
55    }
56}
57
58/// Receive RPC values from a stream and queue them locally.
59fn rpc_handle_stream(
60    stream: TcpStream,
61    socket_addr: Option<SocketAddr>,
62    queue: Arc<Mutex<Vec<(RPC, Option<SocketAddr>)>>>,
63) {
64    loop {
65        let rpc = match bincode::deserialize_from(&stream) {
66            Ok(rpc) => rpc,
67            Err(e) => {
68                eprintln!("Error reading from client: {:?}", e);
69                return;
70            }
71        };
72        queue.lock().unwrap().push((rpc, socket_addr));
73    }
74}
75
76#[sylt_macro::sylt_doc(n_rpc_start_server, "Starts an RPC server on the specified port, returning success status.", [One(Int(port))] Type::Bool)]
77#[sylt_macro::sylt_link(n_rpc_start_server, "sylt_std::network")]
78pub fn n_rpc_start_server(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
79    if ctx.typecheck {
80        return Ok(Value::Bool(true));
81    }
82
83    // Get the port from the arguments.
84    let values = ctx.machine.stack_from_base(ctx.stack_base);
85    let port = match values.as_ref() {
86        [Value::Int(port)] => *port as u16,
87        _ => DEFAULT_PORT,
88    };
89    // Bind the server.
90    let listener = match TcpListener::bind(("0.0.0.0", port)) {
91        Ok(listener) => listener,
92        Err(e) => {
93            eprintln!("Error binding server to TCP: {:?}", e);
94            return Ok(Value::Bool(false));
95        }
96    };
97
98    // Initialize the thread local with our list of client handles.
99    CLIENT_HANDLES.with(|global_handles| {
100        global_handles.lock().unwrap().insert(HashMap::new());
101    });
102
103    // Start listening for new clients.
104    let rpc_queue = RPC_QUEUE.with(|queue| Arc::clone(queue));
105    let handles = CLIENT_HANDLES.with(|handles| Arc::clone(handles));
106    thread::spawn(|| rpc_listen(listener, rpc_queue, handles));
107
108    Ok(Value::Bool(true))
109}
110
111#[sylt_macro::sylt_link(n_rpc_stop_server, "sylt_std::network")]
112pub fn n_rpc_stop_server(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
113    if ctx.typecheck {
114        return Ok(Value::Bool(true));
115    }
116
117    if n_rpc_is_server(ctx)? == Value::Bool(false) {
118        return Ok(Value::Bool(false));
119    }
120
121    CLIENT_HANDLES.with(|handles| {
122        if let Some(handles) = handles.lock().unwrap().as_mut().take() {
123            for (addr, (stream, _)) in handles {
124                if let Err(e) = stream.shutdown(Shutdown::Both) {
125                    eprintln!("Error disconnecting client {:?}: {:?}", addr, e);
126                }
127            }
128        }
129    });
130
131    Ok(Value::Bool(true))
132}
133
134//NOTE(gu): We don't force a disconnect.
135#[sylt_macro::sylt_doc(n_rpc_connect, "Connects to an RPC server on the specified IP and port.", [One(String(ip)), One(Int(port))] Type::Bool)]
136#[sylt_macro::sylt_link(n_rpc_connect, "sylt_std::network")]
137pub fn n_rpc_connect(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
138    if ctx.typecheck {
139        return Ok(Value::Bool(true));
140    }
141
142    // Get the ip and port from the arguments.
143    let values = ctx.machine.stack_from_base(ctx.stack_base);
144    let socket_addr = match values.as_ref() {
145        [Value::String(ip), Value::Int(port)] => (ip.as_str(), *port as u16),
146        [Value::String(ip)] => (ip.as_str(), DEFAULT_PORT),
147        _ => {
148            return Err(RuntimeError::ExternTypeMismatch(
149                "n_rpc_connect".to_string(),
150                values.iter().map(Type::from).collect(),
151            ));
152        }
153    };
154    // Connect to the server.
155    let stream = match TcpStream::connect(socket_addr) {
156        Ok(stream) => stream,
157        Err(e) => {
158            eprintln!("Error connecting to server: {:?}", e);
159            return Ok(Value::Bool(false));
160        }
161    };
162    // Store the stream so we can send to it later.
163    match stream.try_clone() {
164        Ok(stream) => {
165            SERVER_HANDLE.with(|server_handle| {
166                server_handle
167                    .borrow_mut()
168                    .insert(stream);
169            });
170        },
171        Err(e) => {
172            eprintln!("Error connecting to server: {:?}", e);
173            return Ok(Value::Bool(false));
174        },
175    }
176
177    // Handle incoming RPCs by putting them on the queue.
178    let rpc_queue = RPC_QUEUE.with(|queue| Arc::clone(queue));
179    thread::spawn(|| rpc_handle_stream(stream, None, rpc_queue));
180
181    Ok(Value::Bool(true))
182}
183
184#[sylt_macro::sylt_doc(n_rpc_is_server, "Returns whether we've started a server or not.", [] Type::Bool)]
185#[sylt_macro::sylt_link(n_rpc_is_server, "sylt_std::network")]
186pub fn n_rpc_is_server(_: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
187    Ok(Value::Bool(
188        CLIENT_HANDLES.with(|handles| handles.lock().unwrap().is_some()),
189    ))
190}
191
192#[sylt_macro::sylt_doc(n_rpc_connected_clients, "Returns how many clients are currently connected.", [] Type::Int)]
193#[sylt_macro::sylt_link(n_rpc_connected_clients, "sylt_std::network")]
194pub fn n_rpc_connected_clients(_: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
195    Ok(Value::Int(CLIENT_HANDLES.with(|handles| {
196        handles
197            .lock()
198            .unwrap()
199            .as_ref()
200            .map(|handles| handles.len() as i64)
201            .unwrap_or(0)
202    })))
203}
204
205#[sylt_macro::sylt_doc(n_rpc_is_client, "Returns whether we've connected to a client or not.", [] Type::Bool)]
206#[sylt_macro::sylt_link(n_rpc_is_client, "sylt_std::network")]
207pub fn n_rpc_is_client(_: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
208    Ok(Value::Bool(
209        SERVER_HANDLE.with(|handle| handle.borrow().is_some()),
210    ))
211}
212
213/// Parse args given to an external function as rpc arguments, i.e. one callable followed by 0..n arguments.
214fn get_rpc_args(ctx: RuntimeContext<'_>, arg_offset: usize, func_name: &str) -> Result<Vec<FlatValuePack>, RuntimeError> {
215    let values = ctx.machine.stack_from_base(ctx.stack_base);
216    let flat_values: Vec<FlatValuePack> = values[arg_offset..].iter().map(|v| FlatValue::pack(v)).collect();
217
218    if flat_values.len() != 0 {
219        Ok(flat_values)
220    } else {
221        Err(RuntimeError::ExternTypeMismatch(
222            func_name.to_string(),
223            values.iter().map(Type::from).collect(),
224        ))
225    }
226}
227
228#[sylt_macro::sylt_doc(n_rpc_clients, "Performs an RPC on all connected clients.", [One(Value(callable)), One(List(args))] Type::Void)]
229#[sylt_macro::sylt_link(n_rpc_clients, "sylt_std::network")]
230pub fn n_rpc_clients(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
231    if ctx.typecheck {
232        return Ok(Value::Nil);
233    }
234
235    // Serialize the RPC.
236    let serialized = match bincode::serialize(&get_rpc_args(ctx, 0, "n_rpc_clients")?) {
237        Ok(serialized) => serialized,
238        Err(e) => {
239            eprintln!("Error serializing values: {:?}", e);
240            return Ok(Value::Bool(false));
241        }
242    };
243
244    // Send the serialized data to all clients.
245    CLIENT_HANDLES.with(|client_handles| {
246        if let Some(streams) = client_handles.lock().unwrap().as_mut() {
247            for (_, (stream, keep)) in streams.iter_mut() {
248                if let Err(e) = stream.write(&serialized) {
249                    eprintln!("Error sending data to a client: {:?}", e);
250                    *keep = false;
251                }
252            }
253            streams.retain(|_, (_, keep)| *keep);
254        } else {
255            eprintln!("A server hasn't been started");
256        }
257    });
258
259    Ok(Value::Nil)
260}
261
262
263#[sylt_macro::sylt_doc(n_rpc_client_ip, "Performs an RPC on a specific connected clients.", [One(Value(callable)), One(List(args))] Type::Bool)]
264#[sylt_macro::sylt_link(n_rpc_client_ip, "sylt_std::network")]
265pub fn n_rpc_client_ip(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
266    if ctx.typecheck {
267        return Ok(Value::Bool(true));
268    }
269
270    let ip = match ctx.machine.stack_from_base(ctx.stack_base).get(0) {
271        Some(Value::String(s)) => SocketAddr::from_str(s.as_ref()).unwrap(),
272        _ => {
273            return Ok(Value::Bool(false)); //TODO(gu): Type error here, probably.
274        }
275    };
276
277    // Serialize the RPC.
278    let serialized = match bincode::serialize(&get_rpc_args(ctx, 1, "n_rpc_client_ip")?) {
279        Ok(serialized) => serialized,
280        Err(e) => {
281            eprintln!("Error serializing values: {:?}", e);
282            return Ok(Value::Bool(false));
283        }
284    };
285
286    CLIENT_HANDLES.with(|client_handles| {
287        if let Some(streams) = client_handles.lock().unwrap().as_mut() {
288            if let Entry::Occupied(mut o) = streams.entry(ip) {
289                let (stream, _) = o.get_mut();
290                if let Err(e) = stream.write(&serialized) {
291                    eprintln!("Error sending data to a specific client {:?}: {:?}", ip, e);
292                    o.remove();
293                }
294                Ok(Value::Bool(true))
295            } else {
296                Ok(Value::Bool(false))
297            }
298        } else {
299            eprintln!("A server hasn't been started");
300            Ok(Value::Bool(false))
301        }
302    })
303}
304
305//TODO(gu): This doc is wrong since this takes variadic arguments.
306#[sylt_macro::sylt_doc(n_rpc_server, "Performs an RPC on the connected server, returning success status.", [One(Value(callable)), One(Value(args))] Type::Bool)]
307#[sylt_macro::sylt_link(n_rpc_server, "sylt_std::network")]
308pub fn n_rpc_server(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
309    if ctx.typecheck {
310        return Ok(Value::Bool(true));
311    }
312
313    // Serialize the RPC.
314    let serialized = match bincode::serialize(&get_rpc_args(ctx, 0, "n_rpc_server")?) {
315        Ok(serialized) => serialized,
316        Err(e) => {
317            eprintln!("Error serializing values: {:?}", e);
318            return Ok(Value::Bool(false));
319        }
320    };
321
322    // Send the serialized data to the server.
323    SERVER_HANDLE.with(|server_handle| {
324        if let Some(mut stream) = server_handle.borrow().as_ref() {
325            match stream.write(&serialized) {
326                Ok(_) => Ok(Value::Bool(true)),
327                Err(e) => {
328                    eprintln!("Error sending data to server: {:?}", e);
329                    Ok(Value::Bool(false))
330                },
331            }
332        } else {
333            Ok(Value::Bool(false))
334        }
335    })
336}
337
338#[sylt_macro::sylt_doc(n_rpc_disconnect, "Disconnect from the currently connected server.", [] Type::Void)]
339#[sylt_macro::sylt_link(n_rpc_disconnect, "sylt_std::network")]
340pub fn n_rpc_disconnect(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
341    if ctx.typecheck {
342        return Ok(Value::Nil);
343    }
344
345    SERVER_HANDLE.with(|server_handle| {
346        if let Some(handle) = server_handle.borrow_mut().take() {
347            if let Err(e) = handle.shutdown(Shutdown::Both) {
348                eprintln!("Error disconnecting from server: {:?}", e);
349            }
350        }
351    });
352
353    Ok(Value::Nil)
354}
355
356#[sylt_macro::sylt_doc(n_rpc_current_request_ip, "Get the socket address that sent the currently processed RPC. Empty string if not a server or not processing an RPC.", [] Type::String)]
357#[sylt_macro::sylt_link(n_rpc_current_request_ip, "sylt_std::network")]
358pub fn n_rpc_current_request_ip(_: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
359    CURRENT_REQUEST_SOCKET_ADDR.with(|current|
360        Ok(Value::String(Rc::new(
361            current
362                .borrow()
363                .map(|socket| socket.to_string())
364                .unwrap_or("".to_string())
365        )))
366    )
367}
368
369sylt_macro::extern_function!(
370    "sylt_std::network"
371    split_ip
372    ""
373    [One(String(ip_port))] -> Type::Tuple(vec![Type::String, Type::Int]) => {
374        let addr = SocketAddr::from_str(ip_port.as_str()).unwrap();
375        Ok(Value::Tuple(Rc::new(vec![
376            Value::String(Rc::new(addr.ip().to_string())),
377            Value::Int(addr.port() as i64),
378        ])))
379    },
380);
381
382
383#[sylt_macro::sylt_doc(n_rpc_resolve, "Resolves the queued RPCs that has been received since the last resolve.", [] Type::Void)]
384#[sylt_macro::sylt_link(n_rpc_resolve, "sylt_std::network")]
385pub fn n_rpc_resolve(ctx: RuntimeContext<'_>) -> Result<Value, RuntimeError> {
386    if ctx.typecheck {
387        return Ok(Value::Nil);
388    }
389
390    // Take the current queue.
391    let queue = RPC_QUEUE.with(|queue| {
392        std::mem::replace(
393            queue.lock().unwrap().deref_mut(),
394            Vec::new(),
395        )
396    });
397
398    // Convert the queue into Values that can be evaluated.
399    let queue = queue
400        .into_iter()
401        .map(|(rpc, addr)| (rpc.iter().map(FlatValue::unpack).collect::<Vec<_>>(), addr));
402
403    // Evaluate each RPC one a time.
404    for (values, addr) in queue {
405        if values.is_empty() {
406            eprintln!("Tried to resolve empty RPC");
407            continue;
408        }
409        CURRENT_REQUEST_SOCKET_ADDR.with(|current| *current.borrow_mut() = addr);
410        // Create a vec of references to the argument list. This is kinda weird
411        // but it's needed since the runtime usually doesn't handle owned
412        // values.
413        let borrowed_values: Vec<_> = values.iter().collect();
414        if let Err(e) = ctx.machine.eval_call(values[0].clone(), &borrowed_values[1..]) {
415            eprintln!("{}", e);
416            panic!("Error evaluating received RPC");
417        }
418        CURRENT_REQUEST_SOCKET_ADDR.with(|current| current.borrow_mut().take());
419    }
420    Ok(Value::Nil)
421}
422
423sylt_macro::sylt_link_gen!("sylt_std::network");