1use std::collections::HashMap;
15use std::io::{BufRead, BufReader, Read, Write};
16use std::net::{SocketAddr, TcpListener, TcpStream};
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, Mutex};
19use std::thread::{self, JoinHandle};
20
21use anyhow::{Context, Error};
22
/// A loopback TCP server that hands out named locks to connecting clients.
///
/// Constructed with [`LockServer::new`] and moved onto a background thread with
/// [`LockServer::start`].
pub struct LockServer {
    /// Listener bound to an ephemeral port on 127.0.0.1.
    listener: TcpListener,
    /// Local address of `listener`, cached so it can be handed to clients.
    addr: SocketAddr,
    /// One entry per lock name. Keys are the raw request line read from the
    /// client, including its trailing newline.
    threads: HashMap<String, ServerClient>,
    /// Shutdown flag shared with `LockServerStarted`; checked after each `accept`.
    done: Arc<AtomicBool>,
}
29
/// Handle to a [`LockServer`] running on a background thread.
///
/// Dropping it raises the shutdown flag, wakes the server's `accept` call and,
/// if that wake-up connection succeeds, joins the server thread.
pub struct LockServerStarted {
    /// Shared with the server thread; set to `true` to request shutdown.
    done: Arc<AtomicBool>,
    /// Address the server listens on, used to connect and wake its `accept`.
    addr: SocketAddr,
    /// The server thread; taken when it is joined during drop.
    thread: Option<JoinHandle<()>>,
}
35
/// A granted lock, held for as long as this value is alive.
///
/// The server keeps the lock assigned until this socket is closed, so dropping
/// the value releases the lock to the next waiter.
pub struct LockServerClient {
    /// Connection to the lock server; kept open only to hold the lock.
    _socket: TcpStream,
}
39
/// Server-side bookkeeping for a single lock name.
struct ServerClient {
    /// Worker thread that grants the lock to queued clients one at a time.
    thread: Option<JoinHandle<()>>,
    /// `(active, queue)`: `active` stays `true` while the worker is still in its
    /// serving loop. `queue` holds clients waiting for the lock, served from the
    /// front.
    lock: Arc<Mutex<(bool, Vec<TcpStream>)>>,
}
44
45impl LockServer {
46 pub fn new() -> Result<LockServer, Error> {
47 let listener = TcpListener::bind("127.0.0.1:0")
48 .with_context(|| "failed to bind TCP listener to manage locking")?;
49 let addr = listener.local_addr()?;
50 Ok(LockServer {
51 listener,
52 addr,
53 threads: HashMap::new(),
54 done: Arc::new(AtomicBool::new(false)),
55 })
56 }
57
58 pub fn addr(&self) -> &SocketAddr {
59 &self.addr
60 }
61
62 pub fn start(self) -> Result<LockServerStarted, Error> {
63 let addr = self.addr;
64 let done = self.done.clone();
65 let thread = thread::spawn(|| {
66 self.run();
67 });
68 Ok(LockServerStarted {
69 addr,
70 thread: Some(thread),
71 done,
72 })
73 }
74
75 fn run(mut self) {
76 while let Ok((client, _)) = self.listener.accept() {
77 if self.done.load(Ordering::SeqCst) {
78 break;
79 }
80
81 let mut client = BufReader::new(client);
84 let mut name = String::new();
85 if client.read_line(&mut name).is_err() {
86 continue;
87 }
88 let client = client.into_inner();
89
90 if let Some(t) = self.threads.get_mut(&name) {
94 let mut state = t.lock.lock().unwrap();
95 if state.0 {
96 state.1.push(client);
97 continue;
98 }
99 drop(t.thread.take().unwrap().join());
100 }
101
102 let lock = Arc::new(Mutex::new((true, vec![client])));
103 let lock2 = lock.clone();
104 let thread = thread::spawn(move || {
105 loop {
106 let mut client = {
107 let mut state = lock2.lock().unwrap();
108 if state.1.is_empty() {
109 state.0 = false;
110 break;
111 } else {
112 state.1.remove(0)
113 }
114 };
115 if client.write_all(&[1]).is_err() {
118 continue;
119 }
120 let mut dst = Vec::new();
121 drop(client.read_to_end(&mut dst));
122 }
123 });
124
125 self.threads.insert(
126 name,
127 ServerClient {
128 thread: Some(thread),
129 lock,
130 },
131 );
132 }
133 }
134}
135
136impl Drop for LockServer {
137 fn drop(&mut self) {
138 for (_, mut client) in self.threads.drain() {
139 if let Some(thread) = client.thread.take() {
140 drop(thread.join());
141 }
142 }
143 }
144}
145
146impl Drop for LockServerStarted {
147 fn drop(&mut self) {
148 self.done.store(true, Ordering::SeqCst);
149 if TcpStream::connect(&self.addr).is_err() {
151 return;
152 }
153 drop(self.thread.take().unwrap().join());
154 }
155}
156
157impl LockServerClient {
158 pub fn lock(addr: &SocketAddr, name: impl AsRef<[u8]>) -> Result<LockServerClient, Error> {
159 let mut client =
160 TcpStream::connect(&addr).with_context(|| "failed to connect to parent lock server")?;
161 client
162 .write_all(name.as_ref())
163 .and_then(|_| client.write_all(b"\n"))
164 .with_context(|| "failed to write to lock server")?;
165 let mut buf = [0];
166 client
167 .read_exact(&mut buf)
168 .with_context(|| "failed to acquire lock")?;
169 Ok(LockServerClient { _socket: client })
170 }
171}