diff --git a/examples/tcp_echo.rs b/examples/tcp_echo.rs index 108ad9d..a4eccf9 100644 --- a/examples/tcp_echo.rs +++ b/examples/tcp_echo.rs @@ -1,20 +1,42 @@ -use futures_lite::StreamExt; +use futures_lite::{future::zip, StreamExt}; use wove::{ - io::{AsyncReadLoan, AsyncWriteLoan, BufferResultExt}, + futures::Nexus, + io::{AsyncReadLoan, AsyncWriteLoan, Transpose}, io_impl::io_uring::IoUring, }; pub async fn go(uring: &IoUring) -> std::io::Result<()> { let mut listener = wove::net::TcpListener::bind(uring, "127.0.0.1:0").await?; let addr = listener.local_addr().await?; - println!("Listening on {addr}"); + println!("{addr}"); let mut incoming = listener.incoming().await?; - while let Some(conn) = incoming.next().await { - let mut conn = conn?; - let (buf, _read_amt) = conn.read(vec![0; 4096]).await.buf_ok()?; - conn.write(buf).await.buf_ok()?; - } + let handlers = Nexus::new(); + + let accept = async { + while let Some(conn) = incoming.next().await { + let mut conn = conn?; + + handlers.push(Box::pin(async move { + loop { + let (buf, read_amt) = conn.read(vec![0; 4096]).await.transpose()?; + if read_amt == 0 { + break; + } + conn.write(buf).await.transpose()?; + } + + std::io::Result::Ok(()) + })) + } + + std::io::Result::Ok(()) + }; + + let handle_connections = handlers.stream().fuse().last(); + let (accept_result, handle_result) = zip(accept, handle_connections).await; + accept_result?; + handle_result.transpose()?; Ok(()) } diff --git a/src/futures.rs b/src/futures.rs index 8a7ec95..d3aa322 100644 --- a/src/futures.rs +++ b/src/futures.rs @@ -1,12 +1,12 @@ use std::{ - cell::RefCell, + cell::{RefCell, UnsafeCell}, future::Future, pin::Pin, rc::Rc, task::{Context, Poll, Waker}, }; -use futures_lite::Stream; +use futures_lite::{FutureExt, Stream}; type Slot = Rc>>; @@ -117,3 +117,52 @@ pub fn handoff() -> (Put, Get) { (put, get) } + +pub struct Nexus(UnsafeCell>); +impl Nexus { + pub fn new() -> Self { + Self(UnsafeCell::new(Vec::new())) + } + + pub fn push(&self, fut: Fut) { + unsafe { &mut *self.0.get() }.push(fut); + } +} + +impl Default for Nexus { + fn default() -> Self { + Self::new() + } +} + +impl Nexus { + pub fn stream(&self) -> NexusStream<'_, Fut> { + NexusStream(self) + } +} + +pub struct NexusStream<'a, Fut>(&'a Nexus); +impl Stream for NexusStream<'_, Fut> { + type Item = Fut::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let futures = unsafe { &mut *self.0 .0.get() }; + + let out = futures.iter_mut().enumerate().find_map(|(idx, fut)| { + if let Poll::Ready(v) = fut.poll(cx) { + return Some((idx, v)); + } + + None + }); + + match out { + None => Poll::Pending, + Some((idx, v)) => { + futures.swap_remove(idx); + + Poll::Ready(Some(v)) + }, + } + } +} diff --git a/src/io.rs b/src/io.rs index af510b9..604a77e 100644 --- a/src/io.rs +++ b/src/io.rs @@ -3,11 +3,11 @@ use futures_lite::Stream; use crate::aliases::IoResult; use std::{future::Future, mem::MaybeUninit}; -pub trait BufferResultExt { - fn buf_ok(self) -> IoResult<(B, T)>; +pub trait Transpose { + fn transpose(self) -> IoResult<(B, T)>; } -impl BufferResultExt for (B, IoResult) { - fn buf_ok(self) -> IoResult<(B, T)> { +impl Transpose for (B, IoResult) { + fn transpose(self) -> IoResult<(B, T)> { let (buf, res) = self; res.map(|v| (buf, v)) diff --git a/src/io_impl/io_uring/mod.rs b/src/io_impl/io_uring/mod.rs index dfa9993..0b7340b 100644 --- a/src/io_impl/io_uring/mod.rs +++ b/src/io_impl/io_uring/mod.rs @@ -600,7 +600,7 @@ mod test { use crate::{ aliases::IoResult, - io::{AsyncReadLoan, AsyncWriteLoan, BufferResultExt}, + io::{AsyncReadLoan, AsyncWriteLoan, Transpose}, io_impl::io_uring::IoUring, }; use futures_lite::StreamExt; @@ -632,7 +632,7 @@ mod test { let (addr, read_data) = start_echo(uring).await?; let mut conn = crate::net::TcpStream::connect(uring, addr).await?; - conn.write(input).await.buf_ok()?; + conn.write(input).await.transpose()?; let data = read_data.await?; @@ -655,11 +655,11 @@ mod test { let write_data = async { let mut stream = listener.incoming().await?.next().await.unwrap()?; - stream.write("Hello".as_bytes()).await.buf_ok()?; + stream.write("Hello".as_bytes()).await.transpose()?; crate::time::sleep(uring, Duration::from_millis(500)).await?; - stream.write(", world".as_bytes()).await.buf_ok()?; + stream.write(", world".as_bytes()).await.transpose()?; IoResult::Ok(()) };