1#![deny(missing_docs)]
2#[cfg(not(feature = "loom"))]
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6
7use std::{
8 pin::Pin,
9 sync::{Arc, Weak},
10 task::{Context, Poll},
11};
12
13#[cfg(feature = "loom")]
14use loom::sync::atomic::{AtomicBool, AtomicU64, Ordering};
15use pin_project::{pin_project, pinned_drop};
16use tokio::io::AsyncWrite;
17use tracing::{Span, debug, instrument};
18
19use crate::{Backpressure, Channel, id_factory::Id};
20
21#[pin_project(project = WriterProj)]
23pub enum Writer {
24 Strong(#[pin] StrongWriter),
26 Weak(#[pin] WeakWriter),
28}
29
30pub struct StrongWriter {
32 id: Id,
34 chan: Weak<Channel>,
36 pub(crate) pos: Arc<AtomicU64>,
40 pos_id: Option<usize>,
42 pub(crate) rem: usize,
44 fuse: AtomicBool,
46 span: Span,
48}
49
50#[pin_project(PinnedDrop)]
52pub struct WeakWriter {
53 id: Id,
55 chan: Weak<Channel>,
57 fuse: AtomicBool,
59 #[pin]
62 current: Option<StrongWriter>,
63 span: Span,
65}
66
67impl StrongWriter {
68 #[instrument(name = "StrongWriter", parent = &chan.span, skip_all, fields(id = id.get()))]
69 pub(crate) fn new(
70 id: Id,
71 chan: Arc<Channel>,
72 pos: Arc<AtomicU64>,
73 pos_id: Option<usize>,
74 ) -> Self {
75 Self {
76 id,
77 chan: Arc::downgrade(&chan),
78 pos,
79 pos_id,
80 rem: 0,
81 fuse: AtomicBool::new(false),
82 span: Span::current(),
83 }
84 }
85
86 fn release_tail(&mut self) {
87 if let Some(id) = self.pos_id.take()
88 && let Some(chan) = self.chan.upgrade()
89 {
90 chan.remove_tail(id);
91 }
92 }
93
94 fn is_idle(&self) -> bool {
95 self.rem == 0
96 }
97
98 pub fn writable_size(&self) -> u64 {
100 if let Some(chan) = self.chan.upgrade() {
101 chan.writable_size(self.pos.load(Ordering::Acquire))
102 } else {
103 0
104 }
105 }
106
107 #[instrument(parent = &self.span, skip(self))]
111 pub fn terminate(&mut self) {
112 debug!("terminate");
113
114 self.fuse.store(true, Ordering::Release);
115 self.release_tail();
116 }
117
118 #[instrument(parent = &self.span, skip(self))]
120 pub fn downgrade(mut self) -> WeakWriter {
121 debug!("downgrade this writer");
122
123 self.release_tail();
124
125 WeakWriter::new_with_state(
126 self.id.clone(),
127 self.chan.clone(),
128 self.fuse.load(Ordering::Acquire),
129 None,
130 )
131 }
132}
133
134impl Drop for StrongWriter {
135 fn drop(&mut self) {
136 self.release_tail();
137 }
138}
139
140impl AsyncWrite for StrongWriter {
141 #[instrument(parent = &self.span, skip_all)]
142 fn poll_write(
143 mut self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &[u8],
146 ) -> Poll<std::io::Result<usize>> {
147 let Some(chan) = self.chan.upgrade() else {
148 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)));
149 };
150
151 if self.fuse.load(Ordering::Acquire) || chan.terminated.load(Ordering::Acquire) {
153 return Poll::Ready(Err(std::io::Error::from(
154 std::io::ErrorKind::ConnectionAborted,
155 )));
156 }
157
158 if chan.draining.load(Ordering::Acquire) && self.rem == 0 {
160 return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)));
161 }
162
163 let len = buf.len();
164
165 let pos = if self.rem == 0 {
167 if matches!(chan.backpressure, Backpressure::Drop) {
169 let avail = chan.writable_size(self.pos.load(Ordering::Acquire)) as usize;
170 if avail < len {
171 return Poll::Ready(Ok(0));
172 }
173 let reserve = len as u64;
174 let pos = chan.reserve_slice(reserve);
175 self.pos.store(pos, Ordering::Release);
176 self.rem = len;
177 chan.register_frame(pos, reserve, self.id.get());
178 pos
179 } else {
180 let pos = chan.reserve_slice(len as u64);
181 self.pos.store(pos, Ordering::Release);
182 self.rem = len;
183 chan.register_frame(pos, len as u64, self.id.get());
184 pos
185 }
186 } else {
187 self.pos.load(Ordering::Acquire)
188 };
189
190 let written = chan.write(pos, &buf[..self.rem]);
191
192 if written == 0 {
193 debug!("writer poll_write made no progress");
194 if matches!(chan.backpressure, Backpressure::Drop) {
195 return Poll::Ready(Ok(0));
197 }
198 chan.enqueue(pos, cx.waker().to_owned());
199 Poll::Pending
200 } else {
201 self.pos.store(pos + written as u64, Ordering::Release);
202 self.rem -= written;
203 debug!(
204 pos = self.pos.load(Ordering::Acquire),
205 rem = self.rem,
206 written,
207 "writer poll_write committed bytes"
208 );
209 chan.schedule_readers();
210 Poll::Ready(Ok(written))
211 }
212 }
213
214 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
215 Poll::Ready(Ok(()))
216 }
217
218 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
219 Poll::Ready(Ok(()))
220 }
221}
222
223impl WeakWriter {
224 #[instrument(name = "WeakWriter", parent = &chan.span, skip_all, fields(id = id.get()))]
225 pub(crate) fn new(id: Id, chan: Arc<Channel>) -> Self {
226 Self {
227 id,
228 chan: Arc::downgrade(&chan),
229 fuse: AtomicBool::new(false),
230 current: None,
231 span: Span::current(),
232 }
233 }
234
235 #[instrument(name = "WeakWriter", parent = &chan.upgrade().expect("channel missing").span, skip_all, fields(id = id.get()))]
236 fn new_with_state(
237 id: Id,
238 chan: Weak<Channel>,
239 fuse_state: bool,
240 current: Option<StrongWriter>,
241 ) -> Self {
242 let mut writer = Self {
243 id,
244 chan,
245 fuse: AtomicBool::new(fuse_state),
246 current,
247 span: Span::current(),
248 };
249 if fuse_state {
250 writer.terminate();
251 }
252 writer
253 }
254
255 fn ensure_strong(self: Pin<&mut Self>) -> std::io::Result<Pin<&mut StrongWriter>> {
256 let Some(chan) = self.chan.upgrade() else {
257 return Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe));
258 };
259
260 let mut this = self.project();
261 if this.fuse.load(Ordering::Acquire) {
262 return Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted));
263 }
264 if chan.draining.load(Ordering::Acquire) {
265 return Err(std::io::Error::from(std::io::ErrorKind::Interrupted));
266 }
267 if this.current.is_none() {
268 let strong = chan.new_strong_writer_with_id(this.id.clone());
269 this.current.set(Some(strong));
270 }
271 Ok(this.current.as_pin_mut().expect("strong writer present"))
272 }
273
274 fn release_if_idle(self: Pin<&mut Self>) {
275 let mut this = self.project();
276 if let Some(mut strong) = this.current.as_mut().as_pin_mut()
277 && strong.is_idle()
278 {
279 strong.as_mut().get_mut().release_tail();
280 if let Some(chan) = this.chan.upgrade() {
281 chan.schedule_readers();
282 }
283 this.current.set(None);
284 }
285 }
286
287 #[instrument(parent = &self.span, skip(self))]
288 fn terminate(&mut self) {
289 debug!("terminate");
290
291 self.fuse.store(true, Ordering::Release);
292 if let Some(mut strong) = self.current.take() {
293 strong.terminate();
294 }
295 }
296}
297
298impl AsyncWrite for WeakWriter {
299 fn poll_write(
300 self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 buf: &[u8],
303 ) -> Poll<std::io::Result<usize>> {
304 let mut this = self;
305 match this.as_mut().ensure_strong() {
306 Ok(mut strong) => {
307 let result = strong.as_mut().poll_write(cx, buf);
308 if matches!(result, Poll::Ready(Ok(_))) {
309 this.release_if_idle();
310 }
311 result
312 }
313 Err(err) => Poll::Ready(Err(err)),
314 }
315 }
316
317 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
318 let mut this = self;
319 match this.as_mut().ensure_strong() {
320 Ok(mut strong) => {
321 let result = strong.as_mut().poll_flush(cx);
322 this.release_if_idle();
323 result
324 }
325 Err(err) => Poll::Ready(Err(err)),
326 }
327 }
328
329 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
330 let mut this = self;
331 match this.as_mut().ensure_strong() {
332 Ok(mut strong) => {
333 let result = strong.as_mut().poll_shutdown(cx);
334 this.release_if_idle();
335 result
336 }
337 Err(err) => Poll::Ready(Err(err)),
338 }
339 }
340}
341
342#[pinned_drop]
343impl PinnedDrop for WeakWriter {
344 fn drop(self: Pin<&mut Self>) {
345 let mut this = self.project();
346 if let Some(mut strong) = this.current.take() {
347 strong.release_tail();
348 if let Some(chan) = strong.chan.upgrade() {
349 chan.schedule_readers();
350 }
351 }
352 }
353}
354
355impl Writer {
356 pub fn terminate(&mut self) {
358 match self {
359 Writer::Strong(writer) => writer.terminate(),
360 Writer::Weak(writer) => writer.terminate(),
361 }
362 }
363
364 pub fn downgrade(self) -> Writer {
366 match self {
367 Writer::Strong(writer) => Writer::Weak(writer.downgrade()),
368 Writer::Weak(writer) => Writer::Weak(writer),
369 }
370 }
371
372 #[allow(clippy::result_large_err)]
374 pub fn into_strong(self) -> std::result::Result<StrongWriter, Self> {
375 match self {
376 Writer::Strong(strong) => Ok(strong),
377 Writer::Weak(_) => Err(self),
378 }
379 }
380
381 #[allow(clippy::result_large_err)]
383 pub fn into_weak(self) -> std::result::Result<WeakWriter, Self> {
384 match self {
385 Writer::Weak(weak) => Ok(weak),
386 Writer::Strong(strong) => Ok(strong.downgrade()),
387 }
388 }
389}
390
391impl AsyncWrite for Writer {
392 fn poll_write(
393 self: Pin<&mut Self>,
394 cx: &mut Context<'_>,
395 buf: &[u8],
396 ) -> Poll<std::io::Result<usize>> {
397 match self.project() {
398 WriterProj::Strong(strong) => strong.poll_write(cx, buf),
399 WriterProj::Weak(weak) => weak.poll_write(cx, buf),
400 }
401 }
402
403 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
404 match self.project() {
405 WriterProj::Strong(strong) => strong.poll_flush(cx),
406 WriterProj::Weak(weak) => weak.poll_flush(cx),
407 }
408 }
409
410 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
411 match self.project() {
412 WriterProj::Strong(strong) => strong.poll_shutdown(cx),
413 WriterProj::Weak(weak) => weak.poll_shutdown(cx),
414 }
415 }
416}