use std::{ future::Future, net::SocketAddr, ops::Deref, os::fd::{FromRawFd, IntoRawFd}, pin::Pin, rc::Rc, sync::atomic::{AtomicI32, AtomicU64, Ordering}, task::{Context, Poll, Waker}, }; use async_lock::{futures::BarrierWait, Barrier}; use futures_lite::{pin, Stream}; use parking::Parker; use crate::{ aliases::IoResult, io::{AsyncReadLoan, AsyncWriteLoan}, }; use super::IoImpl; #[derive(Debug)] struct ResultBarrier { result: AtomicI32, barrier: Barrier, } impl ResultBarrier { fn new() -> Self { Self { result: AtomicI32::new(0), barrier: Barrier::new(2), } } async fn wait(&self) { self.barrier.wait().await; } fn result(&self) -> i32 { self.result.load(Ordering::Relaxed) } fn set_result_and_block(&self, v: i32) { self.result.store(v, Ordering::Relaxed); self.barrier.wait_blocking(); } async fn wait_result(&self) -> i32 { self.wait().await; self.result() } } #[derive(Debug)] pub struct UserData<'a> { persist: bool, rb: &'a ResultBarrier, } impl<'a> UserData<'a> { fn new_boxed(rb: &'a ResultBarrier, persist: bool) -> Box { Box::new(Self { rb, persist }) } fn new_into_u64(rb: &'a ResultBarrier, persist: bool) -> u64 { Self::new_boxed(rb, persist).into_u64() } fn into_u64(self: Box) -> u64 { Box::leak(self) as *mut _ as _ } unsafe fn from_u64(v: u64) -> Box> { let v = v as *mut UserData; unsafe { Box::from_raw(v) } } } fn handle_error(i: i32) -> IoResult { if i < 0 { return Err(std::io::Error::from_raw_os_error(-i)); } Ok(i) } pub struct Tick { did_handle: bool, submitted: usize, active_completions: usize, } #[derive(Clone)] pub struct IoUring(Rc); impl IoUring { pub fn new() -> IoResult { Ok(Self(Rc::new(IoUringInner::new()?))) } } impl Deref for IoUring { type Target = IoUringInner; fn deref(&self) -> &Self::Target { &self.0 } } pub struct IoUringInner { uring: io_uring::IoUring, // TODO: analyze the atomic orderings passed to the atomic operations here // to make sure they make sense // TODO: I'm not sure that atomics even make sense here, but they let this // method work with &self instead of &mut self so :shrug: active_completions: AtomicU64, _pd: core::marker::PhantomData>, } impl IoUringInner { pub fn new() -> IoResult { let uring = io_uring::IoUring::new(256)?; Ok(Self { uring, active_completions: AtomicU64::new(0), _pd: Default::default(), }) } /// Cancel all events for the given fd. Does not return anything, and /// cancellations are made on a best-effort basis fn cancel_fd(&self, fd: io_uring::types::Fd) { let rb = ResultBarrier::new(); let entry = io_uring::opcode::AsyncCancel2::new(io_uring::types::CancelBuilder::fd(fd).all()) .build() .user_data(UserData::new_into_u64(&rb, false)); self.queue_op(entry); } fn queue_op(&self, op: io_uring::squeue::Entry) { unsafe { self.uring.submission_shared().push(&op).unwrap() } } async fn wait_op(&self, op: io_uring::squeue::Entry) -> IoResult { let rb = ResultBarrier::new(); let entry = op.user_data(UserData::new_into_u64(&rb, false)); self.queue_op(entry); handle_error(rb.wait_result().await) } pub fn submit(&self, wait_for: usize) -> IoResult { let submitted_count = self.uring.submit_and_wait(wait_for)?; self.active_completions .fetch_add(submitted_count as u64, Ordering::Relaxed); Ok(submitted_count) } /// Returns the current number of active requests, and a boolean indicating /// if we actually handled any events. pub fn poll(&self) -> IoResult<(usize, bool)> { let mut did_handle = false; // SAFETY: // - this method is synchronous, and `cq` is dropped at the end of the scope // - [`IoUring`] is !Sync, so it should be impossible for 2 threads to be // running this at the same time let cq = unsafe { self.uring.completion_shared() }; for entry in cq { did_handle = true; let ud = entry.user_data(); if ud == 0 { self.active_completions.fetch_sub(1, Ordering::Relaxed); continue; } let ud = unsafe { UserData::from_u64(entry.user_data()) }; ud.rb.set_result_and_block(entry.result()); if ud.persist { Box::leak(ud); } else { self.active_completions.fetch_sub(1, Ordering::Relaxed); } } Ok(( self.active_completions.load(Ordering::Relaxed) as usize, did_handle, )) } pub fn tick(&self) -> IoResult { let submitted = self.submit(0)?; let (active_completions, did_handle) = self.poll()?; Ok(Tick { active_completions, did_handle, submitted, }) } /// Runs a future. Conceptually, this is similar to [`Self::block_on`], but /// instead of acting as its own executor, this allows us to be embedded in /// another runtime pub fn run(&self, fut: F, behavior: ActiveRequestBehavior) -> Run { Run(self, fut, behavior) } pub fn block_on(&self, fut: F) -> F::Output { futures_lite::pin!(fut); let parker = Parker::new(); let unparker = parker.unparker(); let waker = Waker::from(unparker); let cx = &mut Context::from_waker(&waker); loop { // Check if the future is ready. If so, return the value. if let Poll::Ready(v) = fut.as_mut().poll(cx) { return v; } let Tick { did_handle, active_completions, submitted, } = self.tick().unwrap(); if did_handle { // If we handled an event, it's likely that our future can make progress, // so continue the loop continue; } if submitted > 0 { // We submitted an event. It's possible that our future can make progress // once we poll, so continue the loop continue; } if active_completions > 0 { // We didn't submit an event, but we do have completions in-flight. // We should block until one of them completes, and then re-poll self.submit(1).unwrap(); continue; } // If we've gotten to this point, it's likely that we're waiting on a // future that depends on another thread to make progress. In that case, // park the current thread until the waker is called. parker.park(); } } } pub struct TcpListener(io_uring::types::Fd); impl Drop for TcpListener { fn drop(&mut self) { unsafe { std::net::TcpListener::from_raw_fd(self.0 .0) }; } } pub struct TcpStream { uring: IoUring, fd: io_uring::types::Fd, } impl Drop for TcpStream { fn drop(&mut self) { self.uring.cancel_fd(self.fd); unsafe { std::net::TcpStream::from_raw_fd(self.fd.0) }; } } impl AsyncReadLoan for TcpStream { async fn read(&mut self, mut buf: B) -> (B, IoResult) { let res = self .uring .0 .wait_op( io_uring::opcode::Read::new( self.fd, buf.as_mut_ptr(), buf.writable_bytes() as u32, ) .build(), ) .await .map(|v| v as usize); (buf, res) } } impl AsyncWriteLoan for TcpStream { async fn write(&mut self, buf: B) -> (B, IoResult) { let res = self .uring .0 .wait_op( io_uring::opcode::Write::new( self.fd, buf.as_ptr(), buf.readable_bytes() as u32, ) .build(), ) .await .map(|v| v as usize); (buf, res) } } impl IoImpl for IoUring { async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> { let ts = io_uring::types::Timespec::new() .sec(duration.as_secs()) .nsec(duration.subsec_nanos()); let entry = io_uring::opcode::Timeout::new(&ts as *const _).build(); let _ = self.0.wait_op(entry).await; Ok(()) } type TcpListener = TcpListener; fn open_tcp_socket( &self, addr: std::net::SocketAddr, ) -> impl Future> { // FIXME(Blocking) // There is some some magic missing in the commented out code to make the // socket clean up properly on process exit and whatnot. For now, just use // the std implementation and cast it to a FileHandle let listener = std::net::TcpListener::bind(addr); async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) } /* // let (tx, rx) = async_channel::bounded(1); let domain = match () { _ if socket_addr.is_ipv4() => libc::AF_INET, _ if socket_addr.is_ipv6() => libc::AF_INET6, _ => return Err(std::io::Error::other("Unsupported domain")), }; let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0) .build() .user_data(UserData::new_boxed(tx, false).into_u64()); unsafe { p.uring.submission_shared().push(&entry).unwrap(); } let entry = rx.recv().await.unwrap(); let fd = handle_error(entry.result())?; let sock = libc::sockaddr_in { sin_family: domain as _, sin_port: socket_addr.port().to_be(), sin_addr: libc::in_addr { s_addr: match socket_addr.ip() { IpAddr::V4(v) => v.to_bits().to_be(), IpAddr::V6(_) => panic!(), }, }, sin_zero: Default::default(), }; // FIXME(Blocking) handle_error(unsafe { libc::bind( fd, &sock as *const _ as *const _, core::mem::size_of_val(&sock) as u32, ) })?; // FIXME(Blocking) handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?; Ok(io_uring::types::Fd(fd)) */ } async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult { // FIXME(Blocking) let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) }; let addr = listener.local_addr()?; let _ = listener.into_raw_fd(); Ok(addr) } type Incoming = Incoming; fn accept_many( &self, listener: &mut Self::TcpListener, ) -> impl Future> { let rb = Box::pin(ResultBarrier::new()); let cb_id = UserData::new_into_u64(&rb, true); let entry = io_uring::opcode::AcceptMulti::new(listener.0) .build() .user_data(cb_id); unsafe { self.0.uring.submission_shared().push(&entry).unwrap(); } let wait = unsafe { core::mem::transmute::, BarrierWait<'_>>(rb.barrier.wait()) }; async move { Ok(Incoming { uring: self.clone(), rb, wait, fd: listener.0, }) } } type TcpStream = TcpStream; fn tcp_connect( &self, socket: SocketAddr, ) -> impl Future> { // FIXME(Blocking) let stream = std::net::TcpStream::connect(socket); async { Ok(TcpStream { uring: self.clone(), fd: io_uring::types::Fd(stream?.into_raw_fd()), }) } } } pub struct Incoming { uring: IoUring, rb: Pin>, wait: BarrierWait<'static>, fd: io_uring::types::Fd, } impl Unpin for Incoming { } impl Stream for Incoming { type Item = IoResult; fn poll_next( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.wait) }; pin!(fut); fut.poll(cx) .map(|_| { let fd = handle_error(self.rb.result())?; Ok(TcpStream { uring: self.uring.clone(), fd: io_uring::types::Fd(fd), }) }) .map(Some) } } impl Drop for Incoming { fn drop(&mut self) { self.uring.0.cancel_fd(self.fd); } } /// Behavior used in [`Run`] when the future is not ready, the last [`IoUring::tick`] /// didn't submit or handle any events, and there are requests in flight #[derive(Copy, Clone, Debug)] pub enum ActiveRequestBehavior { /// Cause a panic Panic, /// Block until a completion is returned Block, /// Return pending. /// ///
/// /// NOTE: this relies on the fact that the executor is going to poll this /// _eventually_. If it doesn't, it's likely that you'll get a deadlock. /// ///
Pending, } pub struct Run<'a, F: Future>(&'a IoUringInner, F, ActiveRequestBehavior); impl Future for Run<'_, F> { type Output = F::Output; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let io = self.0; let behavior = self.2; let fut = unsafe { self.map_unchecked_mut(|s| &mut s.1) }; if let Poll::Ready(out) = fut.poll(cx) { return Poll::Ready(out); } let Tick { did_handle, submitted, active_completions, } = io.tick().unwrap(); if did_handle { // We handled an event, it's likely that the future can make progress cx.waker().wake_by_ref(); return Poll::Pending; } if submitted > 0 { // We submitted an event, it's possible that the future can make progress, // so we should wake the executor and get re-polled cx.waker().wake_by_ref(); return Poll::Pending; } if active_completions > 0 { // We have completions in flight, but they're not ready yet. match behavior { ActiveRequestBehavior::Panic => { panic!("The future was not ready, and there are completions in-flight") }, ActiveRequestBehavior::Block => { io.submit(1).unwrap(); cx.waker().wake_by_ref(); return Poll::Pending; }, ActiveRequestBehavior::Pending => { return Poll::Pending; }, } } // The future likely depends on another thread, so return pending Poll::Pending } } #[cfg(test)] mod test { mod block_on { use std::time::Duration; use crate::io_impl::io_uring::IoUring; #[test] fn simple() { let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { 5 }); assert_eq!(out, 5); } #[test] fn sleep() { let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { crate::time::sleep(uring, Duration::from_secs(1)) .await .unwrap(); 5 }); assert_eq!(out, 5); } } mod run { use std::time::Duration; use crate::io_impl::io_uring::{ActiveRequestBehavior, IoUring}; #[test] fn sleep() { let uring = &IoUring::new().unwrap(); let out = futures_lite::future::block_on(async { uring .run( crate::time::sleep(uring, Duration::from_secs(1)), ActiveRequestBehavior::Block, ) .await .unwrap(); 5 }); assert_eq!(out, 5) } } mod net { use std::{future::Future, net::SocketAddr}; use crate::{ aliases::IoResult, io::{AsyncReadLoan, AsyncWriteLoan}, io_impl::io_uring::IoUring, }; use futures_lite::StreamExt; async fn start_echo( uring: &IoUring, ) -> IoResult<(SocketAddr, impl Future>> + '_)> { let mut listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; Ok((listener.local_addr().await?, async move { let mut incoming = listener.incoming().await?; let mut stream = incoming.next().await.unwrap()?; let (mut data, read_amt) = stream.read(vec![0; 4096]).await; let read_amt = read_amt?; data.truncate(read_amt); Ok(data.into_boxed_slice()) })) } #[test] fn basic() { let uring = &IoUring::new().unwrap(); let input = "Hello, world!".as_bytes(); let output = uring .block_on(async { let (addr, read_data) = start_echo(uring).await?; let mut conn = crate::net::TcpStream::connect(uring, addr).await?; let (_, res) = conn.write(input).await; res?; let data = read_data.await?; IoResult::Ok(data) }) .unwrap(); assert_eq!(&output[..], input) } } }