1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem};
use tokio::sync::mpsc::OwnedPermit;
use tokio::sync::mpsc::Sender;

use super::ReusableBoxFuture;

/// Error produced by `PollSender` once no further sends are possible because the channel
/// (or the `PollSender` itself) has been closed.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the error and hands back the item it carries, if there is one.
    ///
    /// When the error came from `start_send`/`send_item`, this is the item the caller tried to
    /// send.  In every other case there is no item and `None` is returned.
    pub fn into_inner(self) -> Option<T> {
        let PollSendError(item) = self;
        item
    }
}

impl<T> fmt::Display for PollSendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message is the same regardless of whether an item is attached.
        f.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}

/// Internal state machine driven by `PollSender`.
#[derive(Debug)]
enum State<T> {
    /// No reservation in progress. Holds the sender that the next `poll_reserve` call hands to
    /// the acquire future.
    Idle(Sender<T>),
    /// A permit reservation is in flight; the sender currently lives inside the acquire future.
    Acquiring,
    /// A slot has been reserved; the permit is consumed by `send_item` or released by
    /// `abort_send`.
    ReadyToSend(OwnedPermit<T>),
    /// No further sends are possible from this `PollSender`: either `close` was called or the
    /// underlying channel was observed to be closed while reserving.
    Closed,
}

/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// Call `poll_reserve` until it returns `Poll::Ready(Ok(()))`, then pass the item to
/// `send_item`.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
#[derive(Debug)]
pub struct PollSender<T> {
    /// Our own handle to the channel, used to rebuild an idle state after `abort_send` and by
    /// `clone`. Set to `None` by `close`, which is how a deliberate close is recorded.
    sender: Option<Sender<T>>,
    /// Current position in the reserve/send state machine.
    state: State<T>,
    /// Reusable boxed future that reserves a permit; replaced via `set` whenever an acquisition
    /// starts or an in-flight one is abandoned.
    acquire: PollSenderFuture<T>,
}

// Builds the future that waits for channel capacity by reserving a permit on the underlying
// channel, so that a subsequent send is guaranteed to have a slot.
//
// Both the `Some` and the `None` case deliberately go through this one async fn: every future
// handed to ReusableBoxFuture therefore has the same concrete type, and hence the same size and
// alignment.
async fn make_acquire_future<T>(
    data: Option<Sender<T>>,
) -> Result<OwnedPermit<T>, PollSendError<T>> {
    let sender = match data {
        Some(sender) => sender,
        // `None` is only installed as a placeholder that is never polled.
        None => unreachable!("this future should not be pollable in this state"),
    };
    sender
        .reserve_owned()
        .await
        .map_err(|_| PollSendError(None))
}

/// Boxed, reusable future that resolves to a reserved permit or a `PollSendError`.
type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>;

/// Holder for the permit-acquisition future. The `'static` lifetime here is asserted through
/// `unsafe` casts in `new` and `set` rather than proven by the compiler.
#[derive(Debug)]
// TODO: This should be replaced with a type_alias_impl_trait to eliminate `'static` and all the transmutes
struct PollSenderFuture<T>(InnerFuture<'static, T>);

impl<T> PollSenderFuture<T> {
    /// Create with an empty inner future with no `Send` bound.
    ///
    /// The placeholder future panics if it is ever polled, so it must be replaced via `set`
    /// before `poll` is called. Used by the `Clone` impl of `PollSender`, which does not require
    /// `T: Send`.
    fn empty() -> Self {
        // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
        // compatible with the transitive bounds required by `Sender<T>`.
        Self(ReusableBoxFuture::new(async { unreachable!() }))
    }
}

impl<T: Send> PollSenderFuture<T> {
    /// Create with an empty inner future.
    ///
    /// The inner future is `make_acquire_future(None)`, which panics if polled, so `set` must be
    /// called with `Some(sender)` before `poll`.
    fn new() -> Self {
        let v = InnerFuture::new(make_acquire_future(None));
        // SAFETY: `make_acquire_future(None)` holds no borrowed data (its only input is an owned
        // `None`), so it is actually `'static`; the transmute only changes the lifetime
        // parameter of the `InnerFuture` type.
        Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
    }

    /// Poll the inner future.
    ///
    /// Panics if the inner future is the placeholder installed by `new` or by `set(None)`.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
        self.0.poll(cx)
    }

    /// Replace the inner future.
    ///
    /// `Some(sender)` installs a future that reserves a permit on `sender`; `None` installs a
    /// placeholder that must not be polled. Either way the previous future, including an
    /// in-flight acquisition, is replaced.
    fn set(&mut self, sender: Option<Sender<T>>) {
        let inner: *mut InnerFuture<'static, T> = &mut self.0;
        let inner: *mut InnerFuture<'_, T> = inner.cast();
        // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T`
        // becomes invalid, and this casts away the type-level lifetime check for that. However, the
        // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not
        // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed
        // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so
        // this is ok.
        let inner = unsafe { &mut *inner };
        inner.set(make_acquire_future(sender));
    }
}

impl<T: Send> PollSender<T> {
    /// Creates a new `PollSender` wrapping `sender`.
    ///
    /// A clone of `sender` is retained so the wrapper can return to the idle state after an
    /// aborted reservation (and so it can be cloned); the given handle is used for the first
    /// reservation.
    pub fn new(sender: Sender<T>) -> Self {
        Self {
            sender: Some(sender.clone()),
            state: State::Idle(sender),
            acquire: PollSenderFuture::new(),
        }
    }

    /// Moves the current state out, leaving `State::Closed` in its place.
    fn take_state(&mut self) -> State<T> {
        mem::replace(&mut self.state, State::Closed)
    }

    /// Attempts to prepare the sender to receive a value.
    ///
    /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
    /// `send_item`.
    ///
    /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
    /// by reserving a slot in the channel for the item to be sent. If this method returns
    /// `Poll::Pending`, the current task is registered to be notified (via
    /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
    ///
    /// Calling this again after a slot has already been reserved returns `Poll::Ready(Ok(()))`
    /// immediately without reserving a second slot.
    ///
    /// # Errors
    ///
    /// If the channel is closed, an error will be returned.  This is a permanent state.
    pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
        loop {
            let (result, next_state) = match self.take_state() {
                State::Idle(sender) => {
                    // Start trying to acquire a permit to reserve a slot for our send, and
                    // immediately loop back around to poll it the first time.
                    self.acquire.set(Some(sender));
                    (None, State::Acquiring)
                }
                State::Acquiring => match self.acquire.poll(cx) {
                    // Channel has capacity.
                    Poll::Ready(Ok(permit)) => {
                        (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
                    }
                    // Channel is closed.
                    Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
                    // Channel doesn't have capacity yet, so we need to wait.
                    Poll::Pending => (Some(Poll::Pending), State::Acquiring),
                },
                // We're closed, either by choice or because the underlying sender was closed.
                s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
                // We're already ready to send an item.
                s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
            };

            self.state = next_state;
            if let Some(result) = result {
                return result;
            }
        }
    }

    /// Sends an item to the channel.
    ///
    /// Before calling `send_item`, `poll_reserve` must be called with a successful return
    /// value of `Poll::Ready(Ok(()))`.
    ///
    /// # Errors
    ///
    /// If the channel is closed, an error will be returned.  This is a permanent state.  The
    /// rejected item can be recovered with [`PollSendError::into_inner`].
    ///
    /// # Panics
    ///
    /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
    /// will panic.  The sender's state is left unchanged by that panic, so if it is caught the
    /// sender can still be used normally.
    #[track_caller]
    pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
        let (result, next_state) = match self.take_state() {
            state @ State::Idle(_) | state @ State::Acquiring => {
                // Restore the untouched state before panicking; otherwise `take_state` would
                // leave this sender stuck in `Closed` if the panic is caught.
                self.state = state;
                panic!("`send_item` called without first calling `poll_reserve`")
            }
            // We have a permit to send our item, so go ahead, which gets us our sender back.
            State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
            // We're closed, either by choice or because the underlying sender was closed.
            // The item is handed back to the caller inside the error.
            State::Closed => (Err(PollSendError(Some(value))), State::Closed),
        };

        // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
        self.state = if self.sender.is_some() {
            next_state
        } else {
            State::Closed
        };
        result
    }

    /// Checks whether this sender has been closed.
    ///
    /// Returns `true` if `close` was called, or if a previous `poll_reserve` found the underlying
    /// channel closed.  The underlying channel that this sender was wrapping may still be open.
    pub fn is_closed(&self) -> bool {
        matches!(self.state, State::Closed) || self.sender.is_none()
    }

    /// Gets a reference to the `Sender` of the underlying channel.
    ///
    /// Returns `None` once `close` has been called on this `PollSender`; the underlying channel
    /// may still be open.  Note that this is not the same condition as `is_closed`: if the sender
    /// became closed because `poll_reserve` observed the channel closing, the stored `Sender` is
    /// still returned.
    pub fn get_ref(&self) -> Option<&Sender<T>> {
        self.sender.as_ref()
    }

    /// Closes this sender.
    ///
    /// No more messages will be able to be sent from this sender, but the underlying channel will
    /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
    ///
    /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
    /// to `send_item` in order to consume the reserved slot.  After that, no further sends will be
    /// possible.  If you do not intend to send another item, you can release the reserved slot back
    /// to the underlying sender by calling [`abort_send`].
    ///
    /// [`abort_send`]: crate::sync::PollSender::abort_send
    /// [`Receiver`]: tokio::sync::mpsc::Receiver
    pub fn close(&mut self) {
        // Mark ourselves officially closed by dropping our main sender.
        self.sender = None;

        // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
        // transition to the closed state.  Otherwise, leave the existing permit in place for the
        // caller if they want to complete the send.
        match self.state {
            State::Idle(_) => self.state = State::Closed,
            State::Acquiring => {
                // Dropping the in-flight acquisition also drops the sender it was holding.
                self.acquire.set(None);
                self.state = State::Closed;
            }
            _ => {}
        }
    }

    /// Aborts the current in-progress send, if any.
    ///
    /// Returns `true` if a send was aborted.  If the sender was closed prior to calling
    /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
    /// ready to attempt another send.
    pub fn abort_send(&mut self) -> bool {
        // We may have been closed in the meantime, after a call to `poll_reserve` already
        // succeeded.  We'll check if `self.sender` is `None` to see if we should transition to the
        // closed state when we actually abort a send, rather than resetting ourselves back to idle.

        let (result, next_state) = match self.take_state() {
            // We're currently trying to reserve a slot to send into.
            State::Acquiring => {
                // Replacing the future drops the in-flight one.
                self.acquire.set(None);

                // If we haven't closed yet, we have to clone our stored sender since we have no way
                // to get it back from the acquire future we just dropped.
                let state = match self.sender.clone() {
                    Some(sender) => State::Idle(sender),
                    None => State::Closed,
                };
                (true, state)
            }
            // We got the permit.  If we haven't closed yet, get the sender back.
            State::ReadyToSend(permit) => {
                let state = if self.sender.is_some() {
                    State::Idle(permit.release())
                } else {
                    State::Closed
                };
                (true, state)
            }
            // Idle or Closed: nothing in progress to abort.
            s => (false, s),
        };

        self.state = next_state;
        result
    }
}

impl<T> Clone for PollSender<T> {
    /// Clones this `PollSender`.
    ///
    /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`.
    fn clone(&self) -> PollSender<T> {
        let (sender, state) = match self.sender.clone() {
            Some(sender) => (Some(sender.clone()), State::Idle(sender)),
            None => (None, State::Closed),
        };

        Self {
            sender,
            state,
            acquire: PollSenderFuture::empty(),
        }
    }
}

impl<T: Send + 'static> Sink<T> for PollSender<T> {
    type Error = PollSendError<T>;

    /// Reserves a slot in the channel; see `PollSender::poll_reserve`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        this.poll_reserve(cx)
    }

    /// Always ready: `start_send` hands the item straight to the channel, so nothing is
    /// buffered here.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Sends `item` using the slot reserved by `poll_ready`; see `PollSender::send_item`.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let this = self.get_mut();
        this.send_item(item)
    }

    /// Closes this sender (see `PollSender::close`) and completes immediately.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        this.close();
        Poll::Ready(Ok(()))
    }
}