From 8c5b34f1078024fd32f2fd15c6c38ba96a6e3e94 Mon Sep 17 00:00:00 2001 From: soup Date: Mon, 21 Oct 2024 23:38:48 -0400 Subject: [PATCH] More reworking --- Cargo.lock | 11 +- Cargo.toml | 4 +- src/io.rs | 71 ++++++++- src/io_impl/io_uring/mod.rs | 307 ++++++++++++++++++++++-------------- src/io_impl/mod.rs | 21 +-- src/net.rs | 45 +++--- src/time.rs | 6 +- 7 files changed, 297 insertions(+), 168 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0105c79..776441e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3,14 +3,13 @@ version = 4 [[package]] -name = "async-channel" -version = "2.3.1" +name = "async-lock" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "concurrent-queue", + "event-listener", "event-listener-strategy", - "futures-core", "pin-project-lite", ] @@ -126,7 +125,7 @@ checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" name = "wove" version = "0.0.0" dependencies = [ - "async-channel", + "async-lock", "futures-lite", "io-uring", "libc", diff --git a/Cargo.toml b/Cargo.toml index 56f745b..0c9c8f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ io-uring = { version = '0.7.1', optional = true } libc = { version = "0.2.161", optional = true } futures-lite = { workspace = true, optional = true } parking = { version = "2.2.1", optional = true } -async-channel = { version = "2.3.1", optional = true } +async-lock = { version = "3.4.0", optional = true } [dev-dependencies] futures-lite = { workspace = true } @@ -26,5 +26,5 @@ io_uring = [ "dep:libc", "dep:futures-lite", "dep:parking", - "dep:async-channel", + "dep:async-lock", ] diff --git a/src/io.rs b/src/io.rs index 4ac64d0..724ab82 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1 +1,70 @@ -pub use futures_lite::{AsyncRead, AsyncWrite}; +use crate::aliases::IoResult; +use std::future::Future; + +pub trait BufferMut { + fn as_mut_ptr(&mut self) -> *mut u8; + fn writable_bytes(&self) -> usize; +} + +impl BufferMut for Vec { + fn as_mut_ptr(&mut self) -> *mut u8 { + self.as_mut_ptr() + } + + fn writable_bytes(&self) -> usize { + self.len() + } +} + +impl BufferMut for Box<[u8]> { + fn as_mut_ptr(&mut self) -> *mut u8 { + self[..].as_mut_ptr() + } + + fn writable_bytes(&self) -> usize { + self.len() + } +} + +pub trait AsyncReadLoan { + fn read(&mut self, buf: B) -> impl Future)>; +} + +pub trait Buffer { + fn as_ptr(&self) -> *const u8; + fn readable_bytes(&self) -> usize; +} + +impl Buffer for Vec { + fn as_ptr(&self) -> *const u8 { + self.as_ptr() + } + + fn readable_bytes(&self) -> usize { + self.len() + } +} + +impl Buffer for Box<[u8]> { + fn as_ptr(&self) -> *const u8 { + self[..].as_ptr() + } + + fn readable_bytes(&self) -> usize { + self.len() + } +} + +impl Buffer for &'static [u8] { + fn as_ptr(&self) -> *const u8 { + self[..].as_ptr() + } + + fn readable_bytes(&self) -> usize { + self.len() + } +} + +pub trait AsyncWriteLoan { + fn write(&mut self, buf: B) -> impl Future)>; +} diff --git a/src/io_impl/io_uring/mod.rs b/src/io_impl/io_uring/mod.rs index bd1ec0f..4911304 100644 --- a/src/io_impl/io_uring/mod.rs +++ b/src/io_impl/io_uring/mod.rs @@ -1,35 +1,78 @@ use std::{ future::Future, net::SocketAddr, + ops::Deref, os::fd::{FromRawFd, IntoRawFd}, pin::Pin, - sync::atomic::{AtomicBool, AtomicU64, Ordering}, + rc::Rc, + sync::atomic::{AtomicI32, AtomicU64, Ordering}, task::{Context, Poll, Waker}, }; -use futures_lite::{AsyncRead, AsyncReadExt, Stream}; +use async_lock::{futures::BarrierWait, Barrier}; +use futures_lite::{pin, Stream}; use parking::Parker; -use crate::aliases::IoResult; +use crate::{ + aliases::IoResult, + io::{AsyncReadLoan, AsyncWriteLoan}, +}; use super::IoImpl; -type CqueueEntryReceiver = async_channel::Receiver; -type CqueueEntrySender = async_channel::Sender; -pub struct UserData { - tx: CqueueEntrySender, - persist: bool, +#[derive(Debug)] +struct ResultBarrier { + result: AtomicI32, + barrier: Barrier, } -impl UserData { - fn new_boxed(tx: CqueueEntrySender, persist: bool) -> Box { - Box::new(Self { tx, persist }) +impl ResultBarrier { + fn new() -> Self { + Self { + result: AtomicI32::new(0), + barrier: Barrier::new(2), + } + } + + async fn wait(&self) { + self.barrier.wait().await; + } + + fn result(&self) -> i32 { + self.result.load(Ordering::Relaxed) + } + + fn set_result_and_block(&self, v: i32) { + self.result.store(v, Ordering::Relaxed); + self.barrier.wait_blocking(); + } + + async fn wait_result(&self) -> i32 { + self.wait().await; + + self.result() + } +} + +#[derive(Debug)] +pub struct UserData<'a> { + persist: bool, + rb: &'a ResultBarrier, +} + +impl<'a> UserData<'a> { + fn new_boxed(rb: &'a ResultBarrier, persist: bool) -> Box { + Box::new(Self { rb, persist }) + } + + fn new_into_u64(rb: &'a ResultBarrier, persist: bool) -> u64 { + Self::new_boxed(rb, persist).into_u64() } fn into_u64(self: Box) -> u64 { Box::leak(self) as *mut _ as _ } - unsafe fn from_u64(v: u64) -> Box { + unsafe fn from_u64(v: u64) -> Box> { let v = v as *mut UserData; unsafe { Box::from_raw(v) } @@ -50,7 +93,22 @@ pub struct Tick { active_completions: usize, } -pub struct IoUring { +#[derive(Clone)] +pub struct IoUring(Rc); +impl IoUring { + pub fn new() -> IoResult { + Ok(Self(Rc::new(IoUringInner::new()?))) + } +} +impl Deref for IoUring { + type Target = IoUringInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub struct IoUringInner { uring: io_uring::IoUring, // TODO: analyze the atomic orderings passed to the atomic operations here @@ -62,7 +120,7 @@ pub struct IoUring { _pd: core::marker::PhantomData>, } -impl IoUring { +impl IoUringInner { pub fn new() -> IoResult { let uring = io_uring::IoUring::new(256)?; @@ -74,6 +132,33 @@ impl IoUring { }) } + /// Cancel all events for the given fd. Does not return anything, and + /// cancellations are made on a best-effort basis + fn cancel_fd(&self, fd: io_uring::types::Fd) { + let rb = ResultBarrier::new(); + + let entry = + io_uring::opcode::AsyncCancel2::new(io_uring::types::CancelBuilder::fd(fd).all()) + .build() + .user_data(UserData::new_into_u64(&rb, false)); + + self.queue_op(entry); + } + + fn queue_op(&self, op: io_uring::squeue::Entry) { + unsafe { self.uring.submission_shared().push(&op).unwrap() } + } + + async fn wait_op(&self, op: io_uring::squeue::Entry) -> IoResult { + let rb = ResultBarrier::new(); + + let entry = op.user_data(UserData::new_into_u64(&rb, false)); + + self.queue_op(entry); + + handle_error(rb.wait_result().await) + } + pub fn submit(&self, wait_for: usize) -> IoResult { let submitted_count = self.uring.submit_and_wait(wait_for)?; self.active_completions @@ -101,7 +186,7 @@ impl IoUring { } let ud = unsafe { UserData::from_u64(entry.user_data()) }; - ud.tx.send_blocking(entry).unwrap(); + ud.rb.set_result_and_block(entry.result()); if ud.persist { Box::leak(ud); } else { @@ -188,41 +273,64 @@ impl Drop for TcpListener { } } -pub struct TcpStream<'a>(&'a IoUring, io_uring::types::Fd, AtomicBool); -impl Drop for TcpStream<'_> { +pub struct TcpStream { + uring: IoUring, + fd: io_uring::types::Fd, +} +impl Drop for TcpStream { fn drop(&mut self) { - unsafe { std::net::TcpStream::from_raw_fd(self.1 .0) }; + self.uring.cancel_fd(self.fd); + unsafe { std::net::TcpStream::from_raw_fd(self.fd.0) }; } } -impl AsyncRead for TcpStream<'_> { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let sq = unsafe { self.0.uring.submission_shared() }; +impl AsyncReadLoan for TcpStream { + async fn read(&mut self, mut buf: B) -> (B, IoResult) { + let res = self + .uring + .0 + .wait_op( + io_uring::opcode::Read::new( + self.fd, + buf.as_mut_ptr(), + buf.writable_bytes() as u32, + ) + .build(), + ) + .await + .map(|v| v as usize); - let entry = io_uring::opcode::Read::new(self.1, buf.as_mut_ptr(), buf.len() as u32) - .build() - .user_data(UserData::new_boxed(tx, false).into_u64()); + (buf, res) + } +} + +impl AsyncWriteLoan for TcpStream { + async fn write(&mut self, buf: B) -> (B, IoResult) { + let res = self + .uring + .0 + .wait_op( + io_uring::opcode::Write::new( + self.fd, + buf.as_ptr(), + buf.readable_bytes() as u32, + ) + .build(), + ) + .await + .map(|v| v as usize); + + (buf, res) } } impl IoImpl for IoUring { async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> { - let (tx, rx) = async_channel::bounded(1); - let ts = io_uring::types::Timespec::new() .sec(duration.as_secs()) .nsec(duration.subsec_nanos()); - let entry = io_uring::opcode::Timeout::new(&ts as *const _) - .build() - .user_data(UserData::new_boxed(tx, false).into_u64()); + let entry = io_uring::opcode::Timeout::new(&ts as *const _).build(); - unsafe { self.uring.submission_shared().push(&entry).unwrap() } - - let entry = rx.recv().await.unwrap(); - handle_error(entry.result())?; + let _ = self.0.wait_op(entry).await; Ok(()) } @@ -303,51 +411,32 @@ impl IoImpl for IoUring { &self, listener: &mut Self::TcpListener, ) -> impl Future> { - let (tx, rx) = async_channel::unbounded(); + let rb = Box::pin(ResultBarrier::new()); - let cb_id = UserData::new_boxed(tx, true).into_u64(); + let cb_id = UserData::new_into_u64(&rb, true); let entry = io_uring::opcode::AcceptMulti::new(listener.0) .build() .user_data(cb_id); unsafe { - self.uring.submission_shared().push(&entry).unwrap(); + self.0.uring.submission_shared().push(&entry).unwrap(); } + let wait = unsafe { + core::mem::transmute::, BarrierWait<'_>>(rb.barrier.wait()) + }; async move { Ok(Incoming { - io: self as *const _, - rx: Box::pin(rx), - cb_id, + uring: self.clone(), + rb, + wait, + fd: listener.0, }) } } - type TcpStream<'a> = TcpStream<'a>; - async fn tcp_read( - &self, - stream: &mut Self::TcpStream<'_>, - buf: &mut [u8], - ) -> IoResult { - return stream.read(buf).await; - - let (tx, rx) = async_channel::bounded(1); - - let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32) - .build() - .user_data(UserData::new_boxed(tx, false).into_u64()); - - unsafe { - self.uring.submission_shared().push(&entry).unwrap(); - } - - let entry = rx.recv().await.unwrap(); - let read_amt = handle_error(entry.result())?; - - Ok(read_amt as usize) - } - + type TcpStream = TcpStream; fn tcp_connect( &self, socket: SocketAddr, @@ -355,57 +444,42 @@ impl IoImpl for IoUring { // FIXME(Blocking) let stream = std::net::TcpStream::connect(socket); - async { Ok(TcpStream(io_uring::types::Fd(stream?.into_raw_fd()))) } - } - - async fn tcp_write(&self, stream: &mut Self::TcpStream, buf: &[u8]) -> IoResult { - let (tx, rx) = async_channel::bounded(1); - - let entry = io_uring::opcode::Write::new(stream.0, buf.as_ptr(), buf.len() as u32) - .build() - .user_data(UserData::new_boxed(tx, false).into_u64()); - - unsafe { - self.uring.submission_shared().push(&entry).unwrap(); + async { + Ok(TcpStream { + uring: self.clone(), + fd: io_uring::types::Fd(stream?.into_raw_fd()), + }) } - - let entry = rx.recv().await.unwrap(); - let write_amt = handle_error(entry.result())?; - - Ok(write_amt as usize) } } pub struct Incoming { - io: *const IoUring, - rx: Pin>, - cb_id: u64, + uring: IoUring, + rb: Pin>, + wait: BarrierWait<'static>, + fd: io_uring::types::Fd, +} +impl Unpin for Incoming { } -pub fn cancel(io: &IoUring, id: u64) { - let entry = io_uring::opcode::AsyncCancel::new(id).build().user_data(0); - - unsafe { - io.uring.submission_shared().push(&entry).unwrap(); - } -} impl Stream for Incoming { type Item = IoResult; fn poll_next( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let rx = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }; - let mut fut = rx.recv(); - let pinned = unsafe { Pin::new_unchecked(&mut fut) }; + let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.wait) }; + pin!(fut); - pinned - .poll(cx) - .map(|entry| { - let fd = handle_error(entry.unwrap().result())?; + fut.poll(cx) + .map(|_| { + let fd = handle_error(self.rb.result())?; - Ok(TcpStream(io_uring::types::Fd(fd))) + Ok(TcpStream { + uring: self.uring.clone(), + fd: io_uring::types::Fd(fd), + }) }) .map(Some) } @@ -413,8 +487,7 @@ impl Stream for Incoming { impl Drop for Incoming { fn drop(&mut self) { - let io = unsafe { &*self.io }; - cancel(io, self.cb_id); + self.uring.0.cancel_fd(self.fd); } } @@ -439,7 +512,7 @@ pub enum ActiveRequestBehavior { Pending, } -pub struct Run<'a, F: Future>(&'a IoUring, F, ActiveRequestBehavior); +pub struct Run<'a, F: Future>(&'a IoUringInner, F, ActiveRequestBehavior); impl Future for Run<'_, F> { type Output = F::Output; @@ -502,7 +575,7 @@ mod test { #[test] fn simple() { - let uring = IoUring::new().unwrap(); + let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { 5 }); assert_eq!(out, 5); } @@ -511,7 +584,9 @@ mod test { fn sleep() { let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { - crate::time::sleep(uring, Duration::from_secs(1)).await; + crate::time::sleep(uring, Duration::from_secs(1)) + .await + .unwrap(); 5 }); @@ -533,7 +608,8 @@ mod test { crate::time::sleep(uring, Duration::from_secs(1)), ActiveRequestBehavior::Block, ) - .await; + .await + .unwrap(); 5 }); @@ -547,7 +623,7 @@ mod test { use crate::{ aliases::IoResult, - io::{AsyncRead, AsyncWrite}, + io::{AsyncReadLoan, AsyncWriteLoan}, io_impl::io_uring::IoUring, }; use futures_lite::StreamExt; @@ -555,14 +631,14 @@ mod test { async fn start_echo( uring: &IoUring, ) -> IoResult<(SocketAddr, impl Future>> + '_)> { - let listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; + let mut listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; - Ok((listener.local_addr().await?, async { + Ok((listener.local_addr().await?, async move { let mut incoming = listener.incoming().await?; let mut stream = incoming.next().await.unwrap()?; - let mut data = vec![0; 4096]; - let read_amt = stream.read(&mut data).await?; + let (mut data, read_amt) = stream.read(vec![0; 4096]).await; + let read_amt = read_amt?; data.truncate(read_amt); Ok(data.into_boxed_slice()) @@ -579,7 +655,8 @@ mod test { let (addr, read_data) = start_echo(uring).await?; let mut conn = crate::net::TcpStream::connect(uring, addr).await?; - conn.write(input).await?; + let (_, res) = conn.write(input).await; + res?; let data = read_data.await?; diff --git a/src/io_impl/mod.rs b/src/io_impl/mod.rs index 606b113..f5f4ff2 100644 --- a/src/io_impl/mod.rs +++ b/src/io_impl/mod.rs @@ -1,6 +1,7 @@ -use futures_lite::AsyncRead; - -use crate::aliases::IoResult; +use crate::{ + aliases::IoResult, + io::{AsyncReadLoan, AsyncWriteLoan}, +}; use std::{future::Future, net::SocketAddr, time::Duration}; #[cfg(feature = "io_uring")] @@ -25,19 +26,9 @@ pub trait IoImpl { listener: &mut Self::TcpListener, ) -> impl Future>; - type TcpStream<'a>: AsyncRead + 'a; - fn tcp_read( - &self, - stream: &mut Self::TcpStream<'_>, - buf: &mut [u8], - ) -> impl Future>; - fn tcp_write( - &self, - stream: &mut Self::TcpStream<'_>, - buf: &[u8], - ) -> impl Future>; + type TcpStream: AsyncReadLoan + AsyncWriteLoan; fn tcp_connect( &self, socket: SocketAddr, - ) -> impl Future>>; + ) -> impl Future>; } diff --git a/src/net.rs b/src/net.rs index 4731e35..bd7abf3 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,39 +1,33 @@ use std::{ net::{SocketAddr, ToSocketAddrs}, - ops::DerefMut, pin::Pin, - task::Poll, }; use futures_lite::Stream; use crate::{ aliases::IoResult, - io::{AsyncRead, AsyncWrite}, + io::{AsyncReadLoan, AsyncWriteLoan, Buffer, BufferMut}, io_impl::IoImpl, }; -pub struct TcpStream<'a, I: IoImpl>(&'a I, I::TcpStream); -impl<'a, I: IoImpl> TcpStream<'a, I> { - pub async fn connect(io: &'a I, addr: SocketAddr) -> IoResult { - Ok(Self(io, io.tcp_connect(addr).await?)) +pub struct TcpStream(I::TcpStream); +impl TcpStream { + pub async fn connect(io: &I, addr: SocketAddr) -> IoResult { + Ok(Self(io.tcp_connect(addr).await?)) } } -impl AsyncRead for TcpStream<'_, I> -where - I::TcpStream: Unpin + AsyncRead, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.1).poll_read(cx, buf) +impl AsyncReadLoan for TcpStream { + async fn read(&mut self, buf: B) -> (B, IoResult) { + self.0.read(buf).await } } -impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> { +impl AsyncWriteLoan for TcpStream { + async fn write(&mut self, buf: B) -> (B, IoResult) { + self.0.write(buf).await + } } pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener); @@ -61,26 +55,25 @@ impl<'a, I: IoImpl> TcpListener<'a, I> { self.0.listener_local_addr(&self.1).await } - pub async fn incoming(mut self) -> IoResult> { - Ok(Incoming(self.0, self.0.accept_many(&mut self.1).await?)) + pub async fn incoming(&mut self) -> IoResult> { + Ok(Incoming(self.0.accept_many(&mut self.1).await?)) } } -pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming); +pub struct Incoming(I::Incoming); -impl<'a, I: IoImpl> Stream for Incoming<'a, I> +impl Stream for Incoming where I::Incoming: Stream>, { - type Item = IoResult>; + type Item = IoResult>; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let io = self.0; - let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) }; + let inner = unsafe { self.map_unchecked_mut(|s| &mut s.0) }; - Stream::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i)))) + Stream::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(i)))) } } diff --git a/src/time.rs b/src/time.rs index 16a9541..690f526 100644 --- a/src/time.rs +++ b/src/time.rs @@ -1,7 +1,7 @@ use std::time::Duration; -use crate::io_impl::IoImpl; +use crate::{aliases::IoResult, io_impl::IoImpl}; -pub async fn sleep(io: &I, duration: Duration) { - let _ = io.sleep(duration).await; +pub async fn sleep(io: &I, duration: Duration) -> IoResult<()> { + io.sleep(duration).await }