#[cfg(feature = "discover")]
use crate::discover::{Change, Discover};
#[cfg(feature = "discover")]
use futures_core::{ready, Stream};
#[cfg(feature = "discover")]
use pin_project_lite::pin_project;
#[cfg(feature = "discover")]
use std::pin::Pin;
use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
use super::Load;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower_service::Service;
/// Service wrapper whose [`Load`] metric is the number of outstanding [`Handle`]s.
///
/// Each `call` creates a [`Handle`] that shares this wrapper's reference counter and
/// gives it to the response future together with a clone of the completion tracker
/// `C`. The reported load is the number of those handles that are still alive.
#[derive(Debug)]
pub struct PendingRequests<S, C = CompleteOnResponse> {
/// The wrapped service that readiness checks and calls are forwarded to.
service: S,
/// Shared counter; a clone of it lives inside every outstanding [`Handle`].
ref_count: RefCount,
/// Completion tracker, cloned for each call's response future.
completion: C,
}
/// Shared counter backed by the strong count of an `Arc<()>`.
///
/// Cloning a `RefCount` raises the count by one; dropping a clone lowers it again.
/// The `Arc` carries no payload: only its reference count is of interest.
#[derive(Debug, Default, Clone)]
struct RefCount(Arc<()>);
#[cfg(feature = "discover")]
pin_project! {
/// Wraps the discovery source `D` so that every service it inserts is wrapped in a
/// [`PendingRequests`] that carries a clone of `completion`.
#[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
#[derive(Debug)]
pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> {
// Underlying discovery source; marked `#[pin]` so it can be polled through a pinned projection.
#[pin]
discover: D,
// Completion tracker cloned into each `PendingRequests` built for an inserted service.
completion: C,
}
}
/// Load metric reported by [`PendingRequests`]: a plain request count.
///
/// Counts compare and order by their inner value, and the default is zero.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Count(usize);
/// Marks one counted request.
///
/// A `Handle` holds a clone of the owning [`PendingRequests`]'s reference counter, so
/// that service's reported load includes this request until the handle is dropped.
#[derive(Debug)]
pub struct Handle(RefCount);
impl<S, C> PendingRequests<S, C> {
pub fn new(service: S, completion: C) -> Self {
Self {
service,
completion,
ref_count: RefCount::default(),
}
}
fn handle(&self) -> Handle {
Handle(self.ref_count.clone())
}
}
impl<S, C> Load for PendingRequests<S, C> {
    type Metric = Count;

    /// Returns the number of handles currently sharing this service's counter.
    fn load(&self) -> Self::Metric {
        // `self` always owns one reference; every other reference is a live handle.
        let total_refs = self.ref_count.ref_count();
        Count(total_refs - 1)
    }
}
impl<S, C, Request> Service<Request> for PendingRequests<S, C>
where
    S: Service<Request>,
    C: TrackCompletion<Handle, S::Response>,
{
    type Response = C::Output;
    type Error = S::Error;
    type Future = TrackCompletionFuture<S::Future, C, Handle>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Forwards `req` to the wrapped service and pairs the resulting future with a
    /// new [`Handle`] and a clone of the completion tracker.
    ///
    /// The handle is created here, so the load rises as soon as `call` returns.
    fn call(&mut self, req: Request) -> Self::Future {
        let completion = self.completion.clone();
        let handle = self.handle();
        let response = self.service.call(req);
        TrackCompletionFuture::new(completion, handle, response)
    }
}
#[cfg(feature = "discover")]
impl<D, C> PendingRequestsDiscover<D, C> {
    /// Wraps `discover`; each service it later inserts is wrapped in a
    /// [`PendingRequests`] that uses a clone of `completion`.
    ///
    /// The `Request` type parameter appears only in the bounds: it checks at
    /// construction time that `completion` can track responses of the discovered
    /// services for that request type.
    pub fn new<Request>(discover: D, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        PendingRequestsDiscover {
            discover,
            completion,
        }
    }
}
#[cfg(feature = "discover")]
impl<D, C> Stream for PendingRequestsDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>;

    /// Polls the wrapped discovery source and re-emits its changes.
    ///
    /// Inserted services are wrapped in [`PendingRequests`]; removals, errors, and
    /// end-of-stream are passed through as they are.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();

        // Unpack the three non-pending outcomes explicitly.
        let update = match ready!(projected.discover.poll_discover(cx)) {
            Some(Ok(update)) => update,
            Some(Err(error)) => return Poll::Ready(Some(Err(error))),
            None => return Poll::Ready(None),
        };

        let wrapped = match update {
            Change::Insert(key, service) => {
                let completion = projected.completion.clone();
                Change::Insert(key, PendingRequests::new(service, completion))
            }
            Change::Remove(key) => Change::Remove(key),
        };

        Poll::Ready(Some(Ok(wrapped)))
    }
}
impl RefCount {
pub(crate) fn ref_count(&self) -> usize {
Arc::strong_count(&self.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures_util::future;
use std::task::{Context, Poll};
// Minimal inner service: always ready, and every call resolves immediately to `Ok(())`.
struct Svc;
impl Service<()> for Svc {
type Response = ();
type Error = ();
type Future = future::Ready<Result<(), ()>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, (): ()) -> Self::Future {
future::ok(())
}
}
// With `CompleteOnResponse`, each call raises the load by one and driving that
// call's future to completion lowers it by one.
#[test]
fn default() {
let mut svc = PendingRequests::new(Svc, CompleteOnResponse);
// No calls yet, so nothing is counted.
assert_eq!(svc.load(), Count(0));
// The load rises as soon as `call` returns, before the future is polled.
let rsp0 = svc.call(());
assert_eq!(svc.load(), Count(1));
let rsp1 = svc.call(());
assert_eq!(svc.load(), Count(2));
// Completing each response releases that request's count.
let () = tokio_test::block_on(rsp0).unwrap();
assert_eq!(svc.load(), Count(1));
let () = tokio_test::block_on(rsp1).unwrap();
assert_eq!(svc.load(), Count(0));
}
// With a custom tracker that returns the `Handle` as the response output, a
// request stays counted after its future completes, until that handle is dropped.
#[test]
fn with_completion() {
// Completion tracker that hands the handle back to the caller unchanged.
#[derive(Clone)]
struct IntoHandle;
impl TrackCompletion<Handle, ()> for IntoHandle {
type Output = Handle;
fn track_completion(&self, i: Handle, (): ()) -> Handle {
i
}
}
let mut svc = PendingRequests::new(Svc, IntoHandle);
assert_eq!(svc.load(), Count(0));
let rsp = svc.call(());
assert_eq!(svc.load(), Count(1));
// The future has resolved, but the returned handle `i0` keeps the request counted.
let i0 = tokio_test::block_on(rsp).unwrap();
assert_eq!(svc.load(), Count(1));
let rsp = svc.call(());
assert_eq!(svc.load(), Count(2));
let i1 = tokio_test::block_on(rsp).unwrap();
assert_eq!(svc.load(), Count(2));
// Dropping each handle is what finally lowers the load.
drop(i1);
assert_eq!(svc.load(), Count(1));
drop(i0);
assert_eq!(svc.load(), Count(0));
}
}