Channel: Code happens

hyper combinators in Rust

Recently I read Michael Snoyman’s post on combining Axum, Hyper, Tonic and Tower. While his solution worked, it irked me – it seemed like there should be a much tighter solution possible.

Perhaps I can deep-dive into the code in a later post, but I think there are four points of difference. First, since that post was written, Axum has started boxing its routes, so the enum-dispatch approach it takes, which is meant to keep overheads low, no longer has any benefit.

Second, while writing out the entire future type by hand has some benefits, async code is much pithier.

Third, the code in the post is entirely generic except for the routing function itself.

Fourth, the outer Service<AddrStream> is an unnecessary layer to abstract over: since the inner Service must take a Request<..> either way, it is possible to skip a couple of helpers and work directly with Service<Request<..>>.
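The second point is easiest to see side by side. The sketch below is std-only and uses hypothetical names (AddOne, add_one_async, and a toy block_on executor that simply busy-polls); it is not code from either post. It writes the same small combinator once as a hand-written Future struct and once as an async block.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hand-written combinator: adds one to the output of an inner future.
// Every combinator written this way needs its own struct and poll impl.
struct AddOne<F> {
    inner: F,
}

impl<F: Future<Output = u32> + Unpin> Future for AddOne<F> {
    type Output = u32;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<u32> {
        // Poll the inner future and map its output once it is ready.
        match Pin::new(&mut self.inner).poll(cx) {
            Poll::Ready(v) => Poll::Ready(v + 1),
            Poll::Pending => Poll::Pending,
        }
    }
}

// The same combinator as an async block: the compiler generates the state machine,
// and no Unpin bound is needed.
fn add_one_async<F: Future<Output = u32>>(inner: F) -> impl Future<Output = u32> {
    async move { inner.await + 1 }
}

// Toy executor for the demo: a no-op waker and a busy-poll loop.
// Only suitable for futures that complete without external wakeups.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    let a = block_on(AddOne { inner: std::future::ready(41) });
    let b = block_on(add_one_async(std::future::ready(41)));
    println!("{a} {b}");
}
```

Both versions yield 42; the async one needs no struct, no manual pinning and no Unpin bound, which is the same trade the Router makes for its response future.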

So, onto a pithier version.

First, the app server code itself.

use std::{convert::Infallible, net::SocketAddr};

use axum::routing::get;
use hyper::{server::conn::AddrStream, service::make_service_fn};
use hyper::{Body, Request};
use tonic::async_trait;

use demo::echo_server::{Echo, EchoServer};
use demo::{EchoReply, EchoRequest};

struct MyEcho;

#[async_trait]
impl Echo for MyEcho {
    async fn echo(
        &self,
        request: tonic::Request<EchoRequest>,
    ) -> Result<tonic::Response<EchoReply>, tonic::Status> {
        Ok(tonic::Response::new(EchoReply {
            message: format!("Echoing back: {}", request.get_ref().message),
        }))
    }
}

#[tokio::main]
async fn main() {
    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));

    let axum_service = axum::Router::new().route("/", get(|| async { "Hello world!" }));

    let grpc_service = tonic::transport::Server::builder()
        .add_service(EchoServer::new(MyEcho))
        .into_service();

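    let both_service =
        demo_router::Router::new(axum_service, grpc_service, |req: &Request<Body>| {
            // `true` routes to the gRPC service when content-type is exactly
            // "application/grpc"; everything else (`false`) goes to Axum.
            Ok::<bool, Infallible>(
                req.headers().get("content-type").map(|x| x.as_bytes())
                    == Some(&b"application/grpc"[..]),
            )
        });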

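    // hyper calls this once per connection; each connection gets its own
    // clone of the composed service.
    let make_service = make_service_fn(move |_conn: &AddrStream| {
        let both_service = both_service.clone();
        async move { Ok::<_, Infallible>(both_service) }
    });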

    let server = hyper::Server::bind(&addr).serve(make_service);

    if let Err(e) = server.await {
        eprintln!("server error: {}", e);
    }
}

Note the Router: it takes the two services and an Fn that decides which one should handle any given request. Then we just drop the composed service into make_service_fn and we're done.
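To make the routing rule concrete, here is a hypothetical, synchronous, std-only model of what the Router does with its discriminator. MiniRouter, FakeRequest and the handler functions are illustrative names, not part of hyper or tower. Unlike the closure in main above, this is_grpc treats a missing header as an error, so all three outcomes (first, second, error) show up.

```rust
// Hypothetical stand-in for an HTTP request: only the content-type header matters here.
struct FakeRequest {
    content_type: Option<&'static str>,
}

// Hypothetical, synchronous model of the Router: two handlers plus a discriminator.
struct MiniRouter<First, Second, F> {
    first: First,
    second: Second,
    discriminator: F,
}

impl<First, Second, F, E> MiniRouter<First, Second, F>
where
    First: Fn(&FakeRequest) -> String,
    Second: Fn(&FakeRequest) -> String,
    F: Fn(&FakeRequest) -> Result<bool, E>,
{
    fn call(&self, req: &FakeRequest) -> Result<String, E> {
        // Ok(false) selects `first`, Ok(true) selects `second`; a discriminator
        // error is returned to the caller instead of a response.
        if (self.discriminator)(req)? {
            Ok((self.second)(req))
        } else {
            Ok((self.first)(req))
        }
    }
}

fn web(_: &FakeRequest) -> String {
    "web".to_string()
}

fn grpc(_: &FakeRequest) -> String {
    "grpc".to_string()
}

// Stricter than the post's closure: a missing header is an error, not "not gRPC".
fn is_grpc(req: &FakeRequest) -> Result<bool, String> {
    match req.content_type {
        Some(ct) => Ok(ct == "application/grpc"),
        None => Err("missing content-type".to_string()),
    }
}

fn main() {
    let router = MiniRouter { first: web, second: grpc, discriminator: is_grpc };
    for ct in [Some("application/grpc"), Some("text/html"), None] {
        println!("{:?}", router.call(&FakeRequest { content_type: ct }));
    }
}
```

Passing the decision in as a function is what keeps the Router itself generic: swapping the rule (a path prefix, a header, a feature flag) needs no change to the routing code.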

Next up is the Router implementation. It is generic across any two Service<Request<Body>> types, as long as their response bodies' Data types are Into<Bytes> and their errors (both service and body errors) are Into<Box<dyn Error + Send + Sync>>.

use std::{future::Future, pin::Pin, task::Poll};

use http_body::combinators::UnsyncBoxBody;
use hyper::{body::HttpBody, Body, Request, Response};
use tower::Service;

#[derive(Clone)]
pub struct Router<First, Second, F> {
    first: First,
    second: Second,
    discriminator: F,
}

impl<First, Second, F> Router<First, Second, F> {
    pub fn new(first: First, second: Second, discriminator: F) -> Self {
        Self {
            first,
            second,
            discriminator,
        }
    }
}

impl<First, Second, FirstBody, FirstBodyError, SecondBody, SecondBodyError, F, FErr>
    Service<Request<Body>> for Router<First, Second, F>
where
    First: Service<Request<Body>, Response = Response<FirstBody>>,
    First::Error: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
    First::Future: Send + 'static,
    First::Response: 'static,
    Second: Service<Request<Body>, Response = Response<SecondBody>>,
    Second::Error: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
    Second::Future: Send + 'static,
    Second::Response: 'static,
    F: Fn(&Request<Body>) -> Result<bool, FErr>,
    FErr: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
    FirstBody: HttpBody<Error = FirstBodyError> + Send + 'static,
    FirstBody::Data: Into<bytes::Bytes>,
    FirstBodyError: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
    SecondBody: HttpBody<Error = SecondBodyError> + Send + 'static,
    SecondBody::Data: Into<bytes::Bytes>,
    SecondBodyError: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
{
    type Response = Response<
        UnsyncBoxBody<
            <hyper::Body as HttpBody>::Data,
            Box<dyn std::error::Error + Send + Sync + 'static>,
        >,
    >;
    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.first.poll_ready(cx) {
            Poll::Ready(Ok(())) => match self.second.poll_ready(cx) {
                Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
                Poll::Pending => Poll::Pending,
            },
            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let discriminant = (self.discriminator)(&req);
        // Call the chosen service here, outside the async block: the block then
        // owns only the returned future and never borrows `self`.
        let (first, second) = match discriminant {
            Ok(false) => (Some(self.first.call(req)), None),
            Ok(true) => (None, Some(self.second.call(req))),
            Err(_) => (None, None),
        };
        let f = async move {
            // Surface a discriminator error; otherwise await whichever future was
            // created and box its body into the common response type.
            Ok(match discriminant.map_err(Into::into)? {
                true => second
                    .unwrap()
                    .await
                    .map_err(Into::into)?
                    .map(|b| b.map_data(Into::into).map_err(Into::into).boxed_unsync()),
                false => first
                    .unwrap()
                    .await
                    .map_err(Into::into)?
                    .map(|b| b.map_data(Into::into).map_err(Into::into).boxed_unsync()),
            })
        };
        Box::pin(f)
    }
}

Interesting things here: I use boxed_unsync to abstract over the concrete body type, and I implement the future with an async block rather than as a separate struct. The result is much smaller, even after a few bits of extra type constraining.
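The body handling does two things: map_data(Into::into) normalises each body's Data to one type, and boxed_unsync() erases the concrete body type so both match arms produce the same type. The same two-step idea works with std iterators, which gives a dependency-free analogy. This is not the http-body API: here u64::from plays the role of map_data(Into::into), and Box<dyn Iterator + Send> plays the role of UnsyncBoxBody (which, likewise, is Send but not Sync). The function name body is hypothetical.

```rust
// Two sources with different concrete types *and* different item types,
// standing in for two services' response bodies (items are u8 vs u32 here).
fn body(use_second: bool) -> Box<dyn Iterator<Item = u64> + Send> {
    if use_second {
        // Map<Map<Range<u32>, closure>, fn>: convert items, then erase the type.
        Box::new((10u32..13).map(|x| x * 2).map(u64::from))
    } else {
        // Map<vec::IntoIter<u8>, fn>: a different concrete type, same erased type.
        Box::new(vec![1u8, 2, 3].into_iter().map(u64::from))
    }
}

fn main() {
    let a: Vec<u64> = body(false).collect();
    let b: Vec<u64> = body(true).collect();
    println!("{a:?} {b:?}");
}
```

Without the box, the two branches of the if would have different types and the function could not name a single return type; the cost is one allocation plus dynamic dispatch per poll.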

One thing that flummoxed me for a while was the need to create the underlying service's response future outside the async block. Failing to do so provokes a 'static requirement that was tricky to debug; fortunately there is already a rustc bug open about making this easier to diagnose. The underlying problem is that if you create the async block and then dereference self inside it, the block borrows self, so the type behind .first has to live for an arbitrary time. By capturing the future immediately, only the future has to live for an arbitrary time, which the existing First::Future: 'static and Second::Future: 'static bounds already cover, so the function signature doesn't need to change.
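A std-only sketch of the pattern, with hypothetical Inner and Outer types standing in for the wrapped service and the Router, and the same toy busy-polling block_on as a stand-in executor. Outer::call invokes the inner service before building its boxed 'static future, so the async block owns only that future.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Same shape as the Router's associated Future type.
type BoxFuture = Pin<Box<dyn Future<Output = String> + Send + 'static>>;

// Hypothetical inner service; its future owns everything it needs.
struct Inner {
    prefix: String,
}

impl Inner {
    fn call(&mut self, req: String) -> BoxFuture {
        let prefix = self.prefix.clone();
        Box::pin(async move { format!("{prefix}{req}") })
    }
}

// Hypothetical outer service wrapping `Inner`, like Router wraps `first`/`second`.
struct Outer {
    inner: Inner,
}

impl Outer {
    fn call(&mut self, req: String) -> BoxFuture {
        // Call the inner service now, while `&mut self` is still available...
        let fut = self.inner.call(req);
        // ...so the 'static async block only owns `fut`. Writing
        // `async move { self.inner.call(req).await }` instead would capture the
        // short-lived `&mut self`, and the compiler rejects that because the
        // boxed future must be 'static.
        Box::pin(async move {
            let resp = fut.await;
            format!("[{resp}]")
        })
    }
}

// Toy executor: no-op waker plus busy-polling; fine here as nothing waits on I/O.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    let mut svc = Outer { inner: Inner { prefix: "echo: ".to_string() } };
    println!("{}", block_on(svc.call("hi".to_string())));
}
```

The returned future outlives the `&mut svc` borrow used to create it, which is exactly what tower's Service::call needs.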

This is almost worth turning into a crate: I couldn't see an existing one when I looked, though it does end up rather small (under 100 lines). What do you all think?