diff --git a/Cargo.lock b/Cargo.lock index 776441e..e838a3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,17 +2,6 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "async-lock" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" -dependencies = [ - "event-listener", - "event-listener-strategy", - "pin-project-lite", -] - [[package]] name = "bitflags" version = "2.6.0" @@ -25,42 +14,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "concurrent-queue" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" - -[[package]] -name = "event-listener" -version = "5.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" -dependencies = [ - "event-listener", - "pin-project-lite", -] - [[package]] name = "fastrand" version = "2.1.1" @@ -125,7 +78,6 @@ checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" name = "wove" version = "0.0.0" dependencies = [ - "async-lock", "futures-lite", "io-uring", "libc", diff --git a/Cargo.toml b/Cargo.toml index 0c9c8f5..0f49399 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ io-uring = { version = '0.7.1', optional = true } libc = { version = "0.2.161", optional = true } futures-lite = { workspace = true, optional = true } parking = { version = "2.2.1", optional = true } -async-lock = { version = "3.4.0", optional = true } [dev-dependencies] futures-lite = { workspace = true } @@ -26,5 +25,4 @@ io_uring = [ "dep:libc", "dep:futures-lite", "dep:parking", - "dep:async-lock", ] diff --git a/flake.nix b/flake.nix index a91d238..a4b2fc0 100644 --- a/flake.nix +++ b/flake.nix @@ -28,10 +28,11 @@ devShells.default = pkgs.mkShell { packages = [ (pkgsFenix.combine [ - (complete.withComponents [ "rust-src" "rust-analyzer" "rustfmt" "clippy" ]) + (complete.withComponents [ "rust-src" "rust-analyzer" "rustfmt" "clippy" "miri" ]) (minimal.withComponents [ "cargo" "rustc" "rust-std" ]) ]) pkgs.cargo-watch + pkgs.gdb pkgs.valgrind ]; }; }); diff --git a/src/futures.rs b/src/futures.rs new file mode 100644 index 0000000..8a7ec95 --- /dev/null +++ b/src/futures.rs @@ -0,0 +1,119 @@ +use std::{ + cell::RefCell, + future::Future, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, +}; + +use futures_lite::Stream; + +type Slot = Rc>>; + +#[derive(Debug, Clone)] +pub struct Put(Slot); +impl Put { + pub fn put(&mut self, value: T) -> PutFut<'_, T> { + PutFut(self, Some(value), None) + } + + pub fn try_put(&mut self, value: T) -> Result<(), T> { + let mut slot = self.0.borrow_mut(); + if slot.is_some() { + return Err(value); + } + + slot.replace(value); + Ok(()) + } +} + +pub struct PutFut<'a, T>(&'a mut Put, Option, Option); +impl Unpin for PutFut<'_, T> { +} + +impl Future for PutFut<'_, T> +where + T: Unpin, +{ + type Output = Result<(), T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let value = match self.1.take() { + Some(v) => v, + None => panic!(), + }; + + let value = match self.0.try_put(value) { + Err(value) => value, + r => { + self.2.take().unwrap().wake(); + return Poll::Ready(r); + }, + }; + + self.1 = Some(value); + self.2 = Some(cx.waker().clone()); + + Poll::Pending + } +} + +#[derive(Debug)] +pub struct Get(Slot); +impl Get { + pub fn get(&mut self) -> GetFut<'_, T> { + GetFut(self, None) + } + + pub fn try_get(&mut self) -> Option { + let mut slot = self.0.borrow_mut(); + if let Some(value) = slot.take() { + return Some(value); + } + + None + } +} + +#[derive(Debug)] +pub struct GetFut<'a, T>(pub(crate) &'a mut Get, Option); +impl Unpin for GetFut<'_, T> { +} + +impl Future for GetFut<'_, T> +where + T: Unpin, +{ + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(value) = self.0.try_get() { + return Poll::Ready(value); + } + + self.1 = Some(cx.waker().clone()); + + Poll::Pending + } +} + +impl Stream for GetFut<'_, T> +where + T: Unpin, +{ + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Future::poll(self, cx).map(|v| Some(v)) + } +} + +pub fn handoff() -> (Put, Get) { + let slot = Rc::new(RefCell::new(None)); + + let put = Put(slot.clone()); + let get = Get(slot); + + (put, get) +} diff --git a/src/io.rs b/src/io.rs index 724ab82..11254d9 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,3 +1,5 @@ +use futures_lite::Stream; + use crate::aliases::IoResult; use std::future::Future; @@ -28,6 +30,20 @@ impl BufferMut for Box<[u8]> { pub trait AsyncReadLoan { fn read(&mut self, buf: B) -> impl Future)>; + + fn read_stream(&mut self) -> impl Stream>> { + futures_lite::stream::unfold(self, move |s| async move { + let (mut buf, res) = s.read(vec![0; 4096]).await; + match res { + Err(e) => Some((Err(e), s)), + Ok(0) => None, + Ok(read_amt) => { + buf.truncate(read_amt); + Some((Ok(buf), s)) + }, + } + }) + } } pub trait Buffer { diff --git a/src/io_impl/io_uring/mod.rs b/src/io_impl/io_uring/mod.rs index 4911304..061ec68 100644 --- a/src/io_impl/io_uring/mod.rs +++ b/src/io_impl/io_uring/mod.rs @@ -1,78 +1,39 @@ use std::{ future::Future, + mem::transmute, net::SocketAddr, ops::Deref, os::fd::{FromRawFd, IntoRawFd}, pin::Pin, rc::Rc, - sync::atomic::{AtomicI32, AtomicU64, Ordering}, + sync::atomic::{AtomicU64, Ordering}, task::{Context, Poll, Waker}, }; -use async_lock::{futures::BarrierWait, Barrier}; use futures_lite::{pin, Stream}; use parking::Parker; use crate::{ aliases::IoResult, + futures::{handoff, Get, GetFut, Put}, io::{AsyncReadLoan, AsyncWriteLoan}, }; use super::IoImpl; #[derive(Debug)] -struct ResultBarrier { - result: AtomicI32, - barrier: Barrier, -} -impl ResultBarrier { - fn new() -> Self { - Self { - result: AtomicI32::new(0), - barrier: Barrier::new(2), - } - } - - async fn wait(&self) { - self.barrier.wait().await; - } - - fn result(&self) -> i32 { - self.result.load(Ordering::Relaxed) - } - - fn set_result_and_block(&self, v: i32) { - self.result.store(v, Ordering::Relaxed); - self.barrier.wait_blocking(); - } - - async fn wait_result(&self) -> i32 { - self.wait().await; - - self.result() - } -} - -#[derive(Debug)] -pub struct UserData<'a> { +pub struct UserData { + put: Put, persist: bool, - rb: &'a ResultBarrier, + _opcode: u8, } -impl<'a> UserData<'a> { - fn new_boxed(rb: &'a ResultBarrier, persist: bool) -> Box { - Box::new(Self { rb, persist }) - } - - fn new_into_u64(rb: &'a ResultBarrier, persist: bool) -> u64 { - Self::new_boxed(rb, persist).into_u64() - } - +impl UserData { fn into_u64(self: Box) -> u64 { Box::leak(self) as *mut _ as _ } - unsafe fn from_u64(v: u64) -> Box> { + unsafe fn from_u64(v: u64) -> Box { let v = v as *mut UserData; unsafe { Box::from_raw(v) } @@ -135,28 +96,62 @@ impl IoUringInner { /// Cancel all events for the given fd. Does not return anything, and /// cancellations are made on a best-effort basis fn cancel_fd(&self, fd: io_uring::types::Fd) { - let rb = ResultBarrier::new(); - let entry = io_uring::opcode::AsyncCancel2::new(io_uring::types::CancelBuilder::fd(fd).all()) .build() - .user_data(UserData::new_into_u64(&rb, false)); + .user_data(0); - self.queue_op(entry); + self.op_queue(entry); } - fn queue_op(&self, op: io_uring::squeue::Entry) { + fn _cancel_ud(&self, ud: u64) { + let entry = io_uring::opcode::AsyncCancel2::new( + io_uring::types::CancelBuilder::user_data(ud).all(), + ) + .build() + .user_data(0); + + self.op_queue(entry); + } + + fn op_queue(&self, op: io_uring::squeue::Entry) { unsafe { self.uring.submission_shared().push(&op).unwrap() } } - async fn wait_op(&self, op: io_uring::squeue::Entry) -> IoResult { - let rb = ResultBarrier::new(); + async fn op_wait(&self, op: io_uring::squeue::Entry) -> IoResult { + let (put, mut get) = handoff(); - let entry = op.user_data(UserData::new_into_u64(&rb, false)); + let opcode = unsafe { *(&op as *const _ as *const u8) }; - self.queue_op(entry); + let user_data = UserData { + put, + _opcode: opcode, + persist: false, + }; - handle_error(rb.wait_result().await) + let entry = op.user_data(Box::new(user_data).into_u64()); + + self.op_queue(entry); + + handle_error(get.get().await) + } + + fn op_many(&self, op: io_uring::squeue::Entry) -> Get { + let (put, get) = handoff(); + + let opcode = unsafe { *(&op as *const _ as *const u8) }; + + let user_data = UserData { + put, + _opcode: opcode, + persist: true, + }; + + let entry = op.user_data(Box::new(user_data).into_u64()); + + self.op_queue(entry); + + get } pub fn submit(&self, wait_for: usize) -> IoResult { @@ -185,10 +180,12 @@ impl IoUringInner { continue; } - let ud = unsafe { UserData::from_u64(entry.user_data()) }; - ud.rb.set_result_and_block(entry.result()); - if ud.persist { - Box::leak(ud); + let mut user_data = unsafe { UserData::from_u64(ud) }; + + let _ = user_data.put.try_put(entry.result()); + + if user_data.persist { + Box::leak(user_data); } else { self.active_completions.fetch_sub(1, Ordering::Relaxed); } @@ -288,7 +285,7 @@ impl AsyncReadLoan for TcpStream { let res = self .uring .0 - .wait_op( + .op_wait( io_uring::opcode::Read::new( self.fd, buf.as_mut_ptr(), @@ -308,7 +305,7 @@ impl AsyncWriteLoan for TcpStream { let res = self .uring .0 - .wait_op( + .op_wait( io_uring::opcode::Write::new( self.fd, buf.as_ptr(), @@ -330,7 +327,7 @@ impl IoImpl for IoUring { .nsec(duration.subsec_nanos()); let entry = io_uring::opcode::Timeout::new(&ts as *const _).build(); - let _ = self.0.wait_op(entry).await; + let _ = self.0.op_wait(entry).await; Ok(()) } @@ -411,29 +408,12 @@ impl IoImpl for IoUring { &self, listener: &mut Self::TcpListener, ) -> impl Future> { - let rb = Box::pin(ResultBarrier::new()); + let get = self.op_many(io_uring::opcode::AcceptMulti::new(listener.0).build()); - let cb_id = UserData::new_into_u64(&rb, true); - - let entry = io_uring::opcode::AcceptMulti::new(listener.0) - .build() - .user_data(cb_id); - - unsafe { - self.0.uring.submission_shared().push(&entry).unwrap(); - } - - let wait = unsafe { - core::mem::transmute::, BarrierWait<'_>>(rb.barrier.wait()) - }; - async move { - Ok(Incoming { - uring: self.clone(), - rb, - wait, - fd: listener.0, - }) - } + let mut get = Box::pin(get); + let fut = get.get(); + let fut = unsafe { transmute::, GetFut<'_, _>>(fut) }; + async { Ok(Incoming(self.clone(), get, fut, listener.0)) } } type TcpStream = TcpStream; @@ -453,15 +433,12 @@ impl IoImpl for IoUring { } } -pub struct Incoming { - uring: IoUring, - rb: Pin>, - wait: BarrierWait<'static>, - fd: io_uring::types::Fd, -} -impl Unpin for Incoming { -} - +pub struct Incoming( + IoUring, + #[allow(dead_code)] Pin>>, + GetFut<'static, i32>, + io_uring::types::Fd, +); impl Stream for Incoming { type Item = IoResult; @@ -469,25 +446,25 @@ impl Stream for Incoming { mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.wait) }; + let fut = &mut self.2; pin!(fut); - fut.poll(cx) - .map(|_| { - let fd = handle_error(self.rb.result())?; + fut.poll_next(cx).map(|o| { + o.map(|v| { + let fd = handle_error(v)?; Ok(TcpStream { - uring: self.uring.clone(), + uring: self.0.clone(), fd: io_uring::types::Fd(fd), }) }) - .map(Some) + }) } } impl Drop for Incoming { fn drop(&mut self) { - self.uring.0.cancel_fd(self.fd); + self.0.cancel_fd(self.3); } } @@ -619,7 +596,7 @@ mod test { } mod net { - use std::{future::Future, net::SocketAddr}; + use std::{future::Future, net::SocketAddr, time::Duration}; use crate::{ aliases::IoResult, @@ -666,5 +643,52 @@ mod test { assert_eq!(&output[..], input) } + + #[test] + fn read_stream() { + let uring = &IoUring::new().unwrap(); + + let res = uring + .block_on(async { + let mut listener = + crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; + let addr = listener.local_addr().await?; + + let write_data = async { + let mut stream = listener.incoming().await?.next().await.unwrap()?; + let (_, res) = stream.write("Hello".as_bytes()).await; + res?; + + crate::time::sleep(uring, Duration::from_millis(500)).await?; + + let (_, res) = stream.write(", world".as_bytes()).await; + res?; + + IoResult::Ok(()) + }; + + let read_data = async { + let mut stream = crate::net::TcpStream::connect(uring, addr).await?; + + let data: Vec> = stream.read_stream().try_collect().await?; + + IoResult::Ok(data) + }; + + let (_, data) = + futures_lite::future::try_zip(write_data, read_data).await?; + + IoResult::Ok(data) + }) + .unwrap(); + + assert_eq!( + res, + vec![ + "Hello".to_string().into_bytes(), + ", world".to_string().into_bytes() + ] + ) + } } } diff --git a/src/lib.rs b/src/lib.rs index adbb35d..c1066ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod aliases; // pub mod fs; +pub mod futures; pub mod io; pub mod io_impl; pub mod net;