diff --git a/src/io.rs b/src/io.rs index a72c2d3..c86c9ea 100644 --- a/src/io.rs +++ b/src/io.rs @@ -3,17 +3,9 @@ use std::future::Future; use crate::{aliases::IoResult, io_impl::IoImpl}; pub trait AsyncRead { - fn read( - &mut self, - io: &I, - buf: Box<[u8]>, - ) -> impl Future>>; + fn read(&mut self, buf: &mut [u8]) -> impl Future>; } pub trait AsyncWrite { - fn write( - &mut self, - io: &I, - buf: Box<[u8]>, - ) -> impl Future>; + fn write(&mut self, buf: &[u8]) -> impl Future>; } diff --git a/src/io_impl/io_uring/mod.rs b/src/io_impl/io_uring/mod.rs index d799b17..80d2690 100644 --- a/src/io_impl/io_uring/mod.rs +++ b/src/io_impl/io_uring/mod.rs @@ -1,9 +1,14 @@ use std::{ + async_iter::AsyncIterator, future::Future, + net::SocketAddr, + os::fd::{FromRawFd, IntoRawFd}, + pin::Pin, sync::atomic::{AtomicU64, Ordering}, task::{Context, Poll, Waker}, }; +use futures_lite::Stream; use parking::Parker; use crate::aliases::IoResult; @@ -90,6 +95,12 @@ impl IoUring { for entry in cq { did_handle = true; + let ud = entry.user_data(); + if ud == 0 { + self.active_completions.fetch_sub(1, Ordering::Relaxed); + continue; + } + let ud = unsafe { UserData::from_u64(entry.user_data()) }; ud.tx.send_blocking(entry).unwrap(); if ud.persist { @@ -119,8 +130,8 @@ impl IoUring { /// Runs a future. Conceptually, this is similar to [`Self::block_on`], but /// instead of acting as its own executor, this allows us to be embedded in /// another runtime - pub async fn run(&self, fut: F, behavior: ActiveRequestBehavior) -> Run { - Run(&self, fut, behavior) + pub fn run(&self, fut: F, behavior: ActiveRequestBehavior) -> Run { + Run(self, fut, behavior) } pub fn block_on(&self, fut: F) -> F::Output { @@ -171,6 +182,20 @@ impl IoUring { } } +pub struct TcpListener(io_uring::types::Fd); +impl Drop for TcpListener { + fn drop(&mut self) { + unsafe { std::net::TcpListener::from_raw_fd(self.0 .0) }; + } +} + +pub struct TcpStream(io_uring::types::Fd); +impl Drop for TcpStream { + fn drop(&mut self) { + unsafe { std::net::TcpStream::from_raw_fd(self.0 .0) }; + } +} + impl IoImpl for IoUring { async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> { let (tx, rx) = async_channel::bounded(1); @@ -189,6 +214,202 @@ impl IoImpl for IoUring { Ok(()) } + + type TcpListener = io_uring::types::Fd; + fn open_tcp_socket( + &self, + addr: std::net::SocketAddr, + ) -> impl Future> { + // FIXME(Blocking) + // There is some some magic missing in the commented out code to make the + // socket clean up properly on process exit and whatnot. For now, just use + // the std implementation and cast it to a FileHandle + + let listener = std::net::TcpListener::bind(addr); + async { IoResult::Ok(io_uring::types::Fd(listener?.into_raw_fd())) } + + /* + // let (tx, rx) = async_channel::bounded(1); + + let domain = match () { + _ if socket_addr.is_ipv4() => libc::AF_INET, + _ if socket_addr.is_ipv6() => libc::AF_INET6, + _ => return Err(std::io::Error::other("Unsupported domain")), + }; + + let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0) + .build() + .user_data(UserData::new_boxed(tx, false).into_u64()); + + unsafe { + p.uring.submission_shared().push(&entry).unwrap(); + } + + let entry = rx.recv().await.unwrap(); + + let fd = handle_error(entry.result())?; + + let sock = libc::sockaddr_in { + sin_family: domain as _, + sin_port: socket_addr.port().to_be(), + sin_addr: libc::in_addr { + s_addr: match socket_addr.ip() { + IpAddr::V4(v) => v.to_bits().to_be(), + IpAddr::V6(_) => panic!(), + }, + }, + sin_zero: Default::default(), + }; + + // FIXME(Blocking) + handle_error(unsafe { + libc::bind( + fd, + &sock as *const _ as *const _, + core::mem::size_of_val(&sock) as u32, + ) + })?; + + // FIXME(Blocking) + handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?; + + Ok(io_uring::types::Fd(fd)) + */ + } + + async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult { + // FIXME(Blocking) + let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0) }; + let addr = listener.local_addr()?; + let _ = listener.into_raw_fd(); + + Ok(addr) + } + + type Incoming = Incoming; + fn accept_many( + &self, + listener: &mut Self::TcpListener, + ) -> impl Future> { + let (tx, rx) = async_channel::unbounded(); + + let cb_id = UserData::new_boxed(tx, true).into_u64(); + + let entry = io_uring::opcode::AcceptMulti::new(*listener) + .build() + .user_data(cb_id); + + unsafe { + self.uring.submission_shared().push(&entry).unwrap(); + } + + async move { + Ok(Incoming { + io: self as *const _, + rx: Box::pin(rx), + cb_id, + }) + } + } + + type TcpStream = TcpStream; + async fn tcp_read(&self, stream: &mut Self::TcpStream, buf: &mut [u8]) -> IoResult { + let (tx, rx) = async_channel::bounded(1); + + let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32) + .build() + .user_data(UserData::new_boxed(tx, false).into_u64()); + + unsafe { + self.uring.submission_shared().push(&entry).unwrap(); + } + + let entry = rx.recv().await.unwrap(); + let read_amt = handle_error(entry.result())?; + + Ok(read_amt as usize) + } + + fn tcp_connect( + &self, + socket: SocketAddr, + ) -> impl Future> { + // FIXME(Blocking) + let stream = std::net::TcpStream::connect(socket); + + async { Ok(TcpStream(io_uring::types::Fd(stream?.into_raw_fd()))) } + } + + async fn tcp_write(&self, stream: &mut Self::TcpStream, buf: &[u8]) -> IoResult { + let (tx, rx) = async_channel::bounded(1); + + let entry = io_uring::opcode::Write::new(stream.0, buf.as_ptr(), buf.len() as u32) + .build() + .user_data(UserData::new_boxed(tx, false).into_u64()); + + unsafe { + self.uring.submission_shared().push(&entry).unwrap(); + } + + let entry = rx.recv().await.unwrap(); + let write_amt = handle_error(entry.result())?; + + Ok(write_amt as usize) + } +} + +pub struct Incoming { + io: *const IoUring, + rx: Pin>, + cb_id: u64, +} + +pub fn cancel(io: &IoUring, id: u64) { + let entry = io_uring::opcode::AsyncCancel::new(id).build().user_data(0); + + unsafe { + io.uring.submission_shared().push(&entry).unwrap(); + } +} + +impl AsyncIterator for Incoming { + type Item = IoResult; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let rx = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }; + let mut fut = rx.recv(); + let pinned = unsafe { Pin::new_unchecked(&mut fut) }; + + pinned + .poll(cx) + .map(|entry| { + let fd = handle_error(entry.unwrap().result())?; + + Ok(TcpStream(io_uring::types::Fd(fd))) + }) + .map(Some) + } +} + +impl Stream for Incoming { + type Item = IoResult; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + AsyncIterator::poll_next(self, cx) + } +} + +impl Drop for Incoming { + fn drop(&mut self) { + let io = unsafe { &*self.io }; + cancel(io, self.cb_id); + } } /// Behavior used in [`Run`] when the future is not ready, the last [`IoUring::tick`] @@ -252,6 +473,7 @@ impl Future for Run<'_, F> { }, ActiveRequestBehavior::Block => { io.submit(1).unwrap(); + cx.waker().wake_by_ref(); return Poll::Pending; }, ActiveRequestBehavior::Pending => { @@ -280,7 +502,7 @@ mod test { } #[test] - fn timer() { + fn sleep() { let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { crate::time::sleep(uring, Duration::from_secs(1)).await; @@ -290,4 +512,76 @@ mod test { assert_eq!(out, 5); } } + + mod run { + use std::time::Duration; + + use crate::io_impl::io_uring::{ActiveRequestBehavior, IoUring}; + + #[test] + fn sleep() { + let uring = &IoUring::new().unwrap(); + let out = futures_lite::future::block_on(async { + uring + .run( + crate::time::sleep(uring, Duration::from_secs(1)), + ActiveRequestBehavior::Block, + ) + .await; + + 5 + }); + + assert_eq!(out, 5) + } + } + + mod net { + use std::{future::Future, net::SocketAddr}; + + use crate::{ + aliases::IoResult, + io::{AsyncRead, AsyncWrite}, + io_impl::io_uring::IoUring, + }; + use futures_lite::StreamExt; + + async fn start_echo( + uring: &IoUring, + ) -> IoResult<(SocketAddr, impl Future>> + '_)> { + let listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; + + Ok((listener.local_addr().await?, async { + let mut incoming = listener.incoming().await?; + let mut stream = incoming.next().await.unwrap()?; + + let mut data = vec![0; 4096]; + let read_amt = stream.read(&mut data).await?; + data.truncate(read_amt); + + Ok(data.into_boxed_slice()) + })) + } + + #[test] + fn basic() { + let uring = &IoUring::new().unwrap(); + + let input = "Hello, world!".as_bytes(); + let output = uring + .block_on(async { + let (addr, read_data) = start_echo(uring).await?; + let mut conn = crate::net::TcpStream::connect(uring, addr).await?; + + conn.write(input).await?; + + let data = read_data.await?; + + IoResult::Ok(data) + }) + .unwrap(); + + assert_eq!(&output[..], input) + } + } } diff --git a/src/io_impl/mod.rs b/src/io_impl/mod.rs index 7d62027..aee2e26 100644 --- a/src/io_impl/mod.rs +++ b/src/io_impl/mod.rs @@ -1,9 +1,41 @@ use crate::aliases::IoResult; -use std::time::Duration; +use std::{future::Future, net::SocketAddr, time::Duration}; #[cfg(feature = "io_uring")] pub mod io_uring; pub trait IoImpl { - async fn sleep(&self, duration: Duration) -> IoResult<()>; + fn sleep(&self, duration: Duration) -> impl Future>; + + type TcpListener; + fn open_tcp_socket( + &self, + addr: SocketAddr, + ) -> impl Future>; + fn listener_local_addr( + &self, + listener: &Self::TcpListener, + ) -> impl Future>; + + type Incoming; + fn accept_many( + &self, + listener: &mut Self::TcpListener, + ) -> impl Future>; + + type TcpStream; + fn tcp_read( + &self, + stream: &mut Self::TcpStream, + buf: &mut [u8], + ) -> impl Future>; + fn tcp_write( + &self, + stream: &mut Self::TcpStream, + buf: &[u8], + ) -> impl Future>; + fn tcp_connect( + &self, + socket: SocketAddr, + ) -> impl Future>; } diff --git a/src/lib.rs b/src/lib.rs index 0876e8f..eda836e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,9 +6,5 @@ mod aliases; // pub mod fs; pub mod io; pub mod io_impl; -// pub mod net; +pub mod net; pub mod time; - -use std::{future::Future, pin::Pin, task::Poll}; - -use aliases::IoResult; diff --git a/src/net.rs b/src/net.rs index 4defd88..3ec6d3e 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,45 +1,46 @@ -use std::{async_iter::AsyncIterator, net::ToSocketAddrs, pin::Pin}; +use std::{ + async_iter::AsyncIterator, + net::{SocketAddr, ToSocketAddrs}, + pin::Pin, +}; use futures_lite::Stream; use crate::{ aliases::IoResult, io::{AsyncRead, AsyncWrite}, - plat_impl, Wove, + io_impl::IoImpl, }; -pub struct TcpStream(pub(crate) plat_impl::FileHandle); - -impl AsyncRead for TcpStream { - fn read( - &mut self, - wove: &Wove, - buf: Box<[u8]>, - ) -> impl std::future::Future>> { - plat_impl::read(&wove.platform, &mut self.0, 0, buf) +pub struct TcpStream<'a, I: IoImpl>(&'a I, I::TcpStream); +impl<'a, I: IoImpl> TcpStream<'a, I> { + pub async fn connect(io: &'a I, addr: SocketAddr) -> IoResult { + Ok(Self(io, io.tcp_connect(addr).await?)) } } -impl AsyncWrite for TcpStream { - fn write( - &mut self, - wove: &Wove, - buf: Box<[u8]>, - ) -> impl std::future::Future> { - plat_impl::write(&wove.platform, &mut self.0, 0, buf) +impl<'a, I: IoImpl> AsyncRead for TcpStream<'a, I> { + async fn read(&mut self, buf: &mut [u8]) -> IoResult { + self.0.tcp_read(&mut self.1, buf).await } } -pub struct TcpListener(plat_impl::FileHandle); +impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> { + async fn write(&mut self, buf: &[u8]) -> IoResult { + self.0.tcp_write(&mut self.1, buf).await + } +} -impl TcpListener { - pub async fn bind(wove: &Wove, addrs: impl ToSocketAddrs) -> IoResult { +pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener); + +impl<'a, I: IoImpl> TcpListener<'a, I> { + pub async fn bind(io: &'a I, addrs: impl ToSocketAddrs) -> IoResult { // TODO(Blocking): to_socket_addrs can block let mut last_err = None; for addr in addrs.to_socket_addrs()? { - match plat_impl::open_tcp_socket(&wove.platform, addr).await { - Ok(v) => return Ok(TcpListener(v)), + match io.open_tcp_socket(addr).await { + Ok(v) => return Ok(TcpListener(io, v)), Err(e) => last_err = Some(e), } } @@ -51,38 +52,47 @@ impl TcpListener { Err(std::io::Error::other("No addrs returned")) } - pub fn incoming<'a>(&'a mut self, wove: &'a Wove) -> Incoming<'a> { - Incoming::register(wove, self) + pub async fn local_addr(&self) -> IoResult { + self.0.listener_local_addr(&self.1).await + } + + pub async fn incoming(mut self) -> IoResult> { + Ok(Incoming(self.0, self.0.accept_many(&mut self.1).await?)) } } -pub struct Incoming<'a>(plat_impl::Incoming<'a>); -impl<'a> Incoming<'a> { - fn register(wove: &'a Wove, listener: &'a mut TcpListener) -> Self { - Incoming(plat_impl::accept_many(&wove.platform, &mut listener.0)) - } -} +pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming); -impl AsyncIterator for Incoming<'_> { - type Item = IoResult; +impl<'a, I: IoImpl> AsyncIterator for Incoming<'a, I> +where + I::Incoming: AsyncIterator>, +{ + type Item = IoResult>; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let inner = unsafe { self.map_unchecked_mut(|s| &mut s.0) }; + let io = self.0; + let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) }; - AsyncIterator::poll_next(inner, cx) + AsyncIterator::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i)))) } } -impl Stream for Incoming<'_> { - type Item = IoResult; +impl<'a, I: IoImpl> Stream for Incoming<'a, I> +where + I::Incoming: Stream>, +{ + type Item = IoResult>; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - AsyncIterator::poll_next(self, cx) + let io = self.0; + let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) }; + + Stream::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i)))) } } diff --git a/src/plat/linux.rs b/src/plat/linux.rs index 72dc6b1..be93f3e 100644 --- a/src/plat/linux.rs +++ b/src/plat/linux.rs @@ -69,61 +69,6 @@ pub async fn open_tcp_socket( p: &PlatformLinux, socket_addr: SocketAddr, ) -> IoResult { - // FIXME(Blocking) - // There is some some magic missing in the commented out code to make the - // socket clean up properly on process exit and whatnot. For now, just use - // the std implementation and cast it to a FileHandle - - let listener = std::net::TcpListener::bind(socket_addr)?; - Ok(io_uring::types::Fd(listener.into_raw_fd())) - - /* - // let (tx, rx) = async_channel::bounded(1); - - let domain = match () { - _ if socket_addr.is_ipv4() => libc::AF_INET, - _ if socket_addr.is_ipv6() => libc::AF_INET6, - _ => return Err(std::io::Error::other("Unsupported domain")), - }; - - let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0) - .build() - .user_data(UserData::new_boxed(tx, false).into_u64()); - - unsafe { - p.uring.submission_shared().push(&entry).unwrap(); - } - - let entry = rx.recv().await.unwrap(); - - let fd = handle_error(entry.result())?; - - let sock = libc::sockaddr_in { - sin_family: domain as _, - sin_port: socket_addr.port().to_be(), - sin_addr: libc::in_addr { - s_addr: match socket_addr.ip() { - IpAddr::V4(v) => v.to_bits().to_be(), - IpAddr::V6(_) => panic!(), - }, - }, - sin_zero: Default::default(), - }; - - // FIXME(Blocking) - handle_error(unsafe { - libc::bind( - fd, - &sock as *const _ as *const _, - core::mem::size_of_val(&sock) as u32, - ) - })?; - - // FIXME(Blocking) - handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?; - - Ok(io_uring::types::Fd(fd)) - */ } pub async fn read( @@ -175,77 +120,4 @@ pub async fn write( Ok(write_amt as usize) } -pub(crate) struct Incoming<'a> { - plat: &'a PlatformLinux, - rx: Pin>, - cb_id: u64, -} - -pub fn accept_many<'a>(p: &'a PlatformLinux, f: &mut FileHandle) -> Incoming<'a> { - let (tx, rx) = async_channel::unbounded(); - - let cb_id = UserData::new_boxed(tx, true).into_u64(); - - let entry = io_uring::opcode::AcceptMulti::new(*f) - .build() - .user_data(cb_id); - - unsafe { - p.uring.submission_shared().push(&entry).unwrap(); - } - - Incoming { - plat: p, - rx: Box::pin(rx), - cb_id, - } -} - -pub fn cancel(p: &PlatformLinux, id: u64) { - let entry = io_uring::opcode::AsyncCancel::new(id).build(); - - unsafe { - p.uring.submission_shared().push(&entry).unwrap(); - } -} - -impl AsyncIterator for Incoming<'_> { - type Item = IoResult; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let rx = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }; - let mut fut = rx.recv(); - let pinned = unsafe { Pin::new_unchecked(&mut fut) }; - - pinned - .poll(cx) - .map(|entry| { - let fd = handle_error(entry.unwrap().result())?; - - Ok(crate::net::TcpStream(io_uring::types::Fd(fd))) - }) - .map(Some) - } -} - -impl Stream for Incoming<'_> { - type Item = IoResult; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - AsyncIterator::poll_next(self, cx) - } -} - -impl Drop for Incoming<'_> { - fn drop(&mut self) { - cancel(self.plat, self.cb_id); - } -} - pub type Platform = PlatformLinux;