use std::{ future::Future, mem::transmute, net::SocketAddr, ops::Deref, os::fd::{FromRawFd, IntoRawFd}, pin::Pin, rc::Rc, sync::atomic::{AtomicU64, Ordering}, task::{Context, Poll, Waker}, }; use futures_lite::{pin, Stream}; use parking::Parker; use crate::{ aliases::IoResult, futures::{handoff, Get, GetFut, Put}, io::{AsyncReadLoan, AsyncWriteLoan}, }; use super::IoImpl; #[derive(Debug)] pub struct UserData { put: Put, persist: bool, _opcode: u8, } impl UserData { fn into_u64(self: Box) -> u64 { Box::leak(self) as *mut _ as _ } unsafe fn from_u64(v: u64) -> Box { let v = v as *mut UserData; unsafe { Box::from_raw(v) } } } fn handle_error(i: i32) -> IoResult { if i < 0 { return Err(std::io::Error::from_raw_os_error(-i)); } Ok(i) } pub struct Tick { did_handle: bool, submitted: usize, active_completions: usize, } #[derive(Clone)] pub struct IoUring(Rc); impl IoUring { pub fn new() -> IoResult { Ok(Self(Rc::new(IoUringInner::new()?))) } } impl Deref for IoUring { type Target = IoUringInner; fn deref(&self) -> &Self::Target { &self.0 } } pub struct IoUringInner { uring: io_uring::IoUring, // TODO: analyze the atomic orderings passed to the atomic operations here // to make sure they make sense // TODO: I'm not sure that atomics even make sense here, but they let this // method work with &self instead of &mut self so :shrug: active_completions: AtomicU64, _pd: core::marker::PhantomData>, } impl IoUringInner { pub fn new() -> IoResult { let uring = io_uring::IoUring::new(256)?; Ok(Self { uring, active_completions: AtomicU64::new(0), _pd: Default::default(), }) } /// Cancel all events for the given fd. Does not return anything, and /// cancellations are made on a best-effort basis fn cancel_fd(&self, fd: io_uring::types::Fd) { let entry = io_uring::opcode::AsyncCancel2::new(io_uring::types::CancelBuilder::fd(fd).all()) .build() .user_data(0); self.op_queue(entry); } fn _cancel_ud(&self, ud: u64) { let entry = io_uring::opcode::AsyncCancel2::new( io_uring::types::CancelBuilder::user_data(ud).all(), ) .build() .user_data(0); self.op_queue(entry); } fn op_queue(&self, op: io_uring::squeue::Entry) { unsafe { self.uring.submission_shared().push(&op).unwrap() } } async fn op_wait(&self, op: io_uring::squeue::Entry) -> IoResult { let (put, mut get) = handoff(); let opcode = unsafe { *(&op as *const _ as *const u8) }; let user_data = UserData { put, _opcode: opcode, persist: false, }; let entry = op.user_data(Box::new(user_data).into_u64()); self.op_queue(entry); handle_error(get.get().await) } fn op_many(&self, op: io_uring::squeue::Entry) -> Get { let (put, get) = handoff(); let opcode = unsafe { *(&op as *const _ as *const u8) }; let user_data = UserData { put, _opcode: opcode, persist: true, }; let entry = op.user_data(Box::new(user_data).into_u64()); self.op_queue(entry); get } pub fn submit(&self, wait_for: usize) -> IoResult { let submitted_count = self.uring.submit_and_wait(wait_for)?; self.active_completions .fetch_add(submitted_count as u64, Ordering::Relaxed); Ok(submitted_count) } /// Returns the current number of active requests, and a boolean indicating /// if we actually handled any events. pub fn poll(&self) -> IoResult<(usize, bool)> { let mut did_handle = false; // SAFETY: // - this method is synchronous, and `cq` is dropped at the end of the scope // - [`IoUring`] is !Sync, so it should be impossible for 2 threads to be // running this at the same time let cq = unsafe { self.uring.completion_shared() }; for entry in cq { did_handle = true; let ud = entry.user_data(); if ud == 0 { self.active_completions.fetch_sub(1, Ordering::Relaxed); continue; } let mut user_data = unsafe { UserData::from_u64(ud) }; let _ = user_data.put.try_put(entry.result()); if user_data.persist { Box::leak(user_data); } else { self.active_completions.fetch_sub(1, Ordering::Relaxed); } } Ok(( self.active_completions.load(Ordering::Relaxed) as usize, did_handle, )) } pub fn tick(&self) -> IoResult { let submitted = self.submit(0)?; let (active_completions, did_handle) = self.poll()?; Ok(Tick { active_completions, did_handle, submitted, }) } /// Runs a future. Conceptually, this is similar to [`Self::block_on`], but /// instead of acting as its own executor, this allows us to be embedded in /// another runtime pub fn run(&self, fut: F, behavior: ActiveRequestBehavior) -> Run { Run(self, fut, behavior) } pub fn block_on(&self, fut: F) -> F::Output { futures_lite::pin!(fut); let parker = Parker::new(); let unparker = parker.unparker(); let waker = Waker::from(unparker); let cx = &mut Context::from_waker(&waker); loop { // Check if the future is ready. If so, return the value. if let Poll::Ready(v) = fut.as_mut().poll(cx) { return v; } let Tick { did_handle, active_completions, submitted, } = self.tick().unwrap(); if did_handle { // If we handled an event, it's likely that our future can make progress, // so continue the loop continue; } if submitted > 0 { // We submitted an event. It's possible that our future can make progress // once we poll, so continue the loop continue; } if active_completions > 0 { // We didn't submit an event, but we do have completions in-flight. // We should block until one of them completes, and then re-poll self.submit(1).unwrap(); continue; } // If we've gotten to this point, it's likely that we're waiting on a // future that depends on another thread to make progress. In that case, // park the current thread until the waker is called. parker.park(); } } } pub struct TcpListener(io_uring::types::Fd); impl Drop for TcpListener { fn drop(&mut self) { unsafe { std::net::TcpListener::from_raw_fd(self.0 .0) }; } } pub struct TcpStream { uring: IoUring, fd: io_uring::types::Fd, } impl Drop for TcpStream { fn drop(&mut self) { self.uring.cancel_fd(self.fd); unsafe { std::net::TcpStream::from_raw_fd(self.fd.0) }; } } impl AsyncReadLoan for TcpStream { async fn read(&mut self, mut buf: B) -> (B, IoResult) { let res = self .uring .0 .op_wait( io_uring::opcode::Read::new( self.fd, buf.as_mut_ptr(), buf.writable_bytes() as u32, ) .build(), ) .await .map(|v| v as usize); (buf, res) } } impl AsyncWriteLoan for TcpStream { async fn write(&mut self, buf: B) -> (B, IoResult) { let res = self .uring .0 .op_wait( io_uring::opcode::Write::new( self.fd, buf.as_ptr(), buf.readable_bytes() as u32, ) .build(), ) .await .map(|v| v as usize); (buf, res) } } impl IoImpl for IoUring { async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> { let ts = io_uring::types::Timespec::new() .sec(duration.as_secs()) .nsec(duration.subsec_nanos()); let entry = io_uring::opcode::Timeout::new(&ts as *const _).build(); let _ = self.0.op_wait(entry).await; Ok(()) } type TcpListener = TcpListener; fn open_tcp_socket( &self, addr: std::net::SocketAddr, ) -> impl Future> { // FIXME(Blocking) // There is some some magic missing in the commented out code to make the // socket clean up properly on process exit and whatnot. For now, just use // the std implementation and cast it to a FileHandle let listener = std::net::TcpListener::bind(addr); async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) } /* // let (tx, rx) = async_channel::bounded(1); let domain = match () { _ if socket_addr.is_ipv4() => libc::AF_INET, _ if socket_addr.is_ipv6() => libc::AF_INET6, _ => return Err(std::io::Error::other("Unsupported domain")), }; let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0) .build() .user_data(UserData::new_boxed(tx, false).into_u64()); unsafe { p.uring.submission_shared().push(&entry).unwrap(); } let entry = rx.recv().await.unwrap(); let fd = handle_error(entry.result())?; let sock = libc::sockaddr_in { sin_family: domain as _, sin_port: socket_addr.port().to_be(), sin_addr: libc::in_addr { s_addr: match socket_addr.ip() { IpAddr::V4(v) => v.to_bits().to_be(), IpAddr::V6(_) => panic!(), }, }, sin_zero: Default::default(), }; // FIXME(Blocking) handle_error(unsafe { libc::bind( fd, &sock as *const _ as *const _, core::mem::size_of_val(&sock) as u32, ) })?; // FIXME(Blocking) handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?; Ok(io_uring::types::Fd(fd)) */ } async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult { // FIXME(Blocking) let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) }; let addr = listener.local_addr()?; let _ = listener.into_raw_fd(); Ok(addr) } type Incoming = Incoming; fn accept_many( &self, listener: &mut Self::TcpListener, ) -> impl Future> { let get = self.op_many(io_uring::opcode::AcceptMulti::new(listener.0).build()); let mut get = Box::pin(get); let fut = get.get(); let fut = unsafe { transmute::, GetFut<'_, _>>(fut) }; async { Ok(Incoming(self.clone(), get, fut, listener.0)) } } type TcpStream = TcpStream; fn tcp_connect( &self, socket: SocketAddr, ) -> impl Future> { // FIXME(Blocking) let stream = std::net::TcpStream::connect(socket); async { Ok(TcpStream { uring: self.clone(), fd: io_uring::types::Fd(stream?.into_raw_fd()), }) } } } pub struct Incoming( IoUring, #[allow(dead_code)] Pin>>, GetFut<'static, i32>, io_uring::types::Fd, ); impl Stream for Incoming { type Item = IoResult; fn poll_next( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let fut = &mut self.2; pin!(fut); fut.poll_next(cx).map(|o| { o.map(|v| { let fd = handle_error(v)?; Ok(TcpStream { uring: self.0.clone(), fd: io_uring::types::Fd(fd), }) }) }) } } impl Drop for Incoming { fn drop(&mut self) { self.0.cancel_fd(self.3); } } /// Behavior used in [`Run`] when the future is not ready, the last [`IoUring::tick`] /// didn't submit or handle any events, and there are requests in flight #[derive(Copy, Clone, Debug)] pub enum ActiveRequestBehavior { /// Cause a panic Panic, /// Block until a completion is returned Block, /// Return pending. /// ///
/// /// NOTE: this relies on the fact that the executor is going to poll this /// _eventually_. If it doesn't, it's likely that you'll get a deadlock. /// ///
Pending, } pub struct Run<'a, F: Future>(&'a IoUringInner, F, ActiveRequestBehavior); impl Future for Run<'_, F> { type Output = F::Output; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let io = self.0; let behavior = self.2; let fut = unsafe { self.map_unchecked_mut(|s| &mut s.1) }; if let Poll::Ready(out) = fut.poll(cx) { return Poll::Ready(out); } let Tick { did_handle, submitted, active_completions, } = io.tick().unwrap(); if did_handle { // We handled an event, it's likely that the future can make progress cx.waker().wake_by_ref(); return Poll::Pending; } if submitted > 0 { // We submitted an event, it's possible that the future can make progress, // so we should wake the executor and get re-polled cx.waker().wake_by_ref(); return Poll::Pending; } if active_completions > 0 { // We have completions in flight, but they're not ready yet. match behavior { ActiveRequestBehavior::Panic => { panic!("The future was not ready, and there are completions in-flight") }, ActiveRequestBehavior::Block => { io.submit(1).unwrap(); cx.waker().wake_by_ref(); return Poll::Pending; }, ActiveRequestBehavior::Pending => { return Poll::Pending; }, } } // The future likely depends on another thread, so return pending Poll::Pending } } #[cfg(test)] mod test { mod block_on { use std::time::Duration; use crate::io_impl::io_uring::IoUring; #[test] fn simple() { let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { 5 }); assert_eq!(out, 5); } #[test] fn sleep() { let uring = &IoUring::new().unwrap(); let out = uring.block_on(async { crate::time::sleep(uring, Duration::from_secs(1)) .await .unwrap(); 5 }); assert_eq!(out, 5); } } mod run { use std::time::Duration; use crate::io_impl::io_uring::{ActiveRequestBehavior, IoUring}; #[test] fn sleep() { let uring = &IoUring::new().unwrap(); let out = futures_lite::future::block_on(async { uring .run( crate::time::sleep(uring, Duration::from_secs(1)), ActiveRequestBehavior::Block, ) .await .unwrap(); 5 }); assert_eq!(out, 5) } } mod net { use std::{future::Future, net::SocketAddr, time::Duration}; use crate::{ aliases::IoResult, io::{AsyncReadLoan, AsyncWriteLoan}, io_impl::io_uring::IoUring, }; use futures_lite::StreamExt; async fn start_echo( uring: &IoUring, ) -> IoResult<(SocketAddr, impl Future>> + '_)> { let mut listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; Ok((listener.local_addr().await?, async move { let mut incoming = listener.incoming().await?; let mut stream = incoming.next().await.unwrap()?; let (mut data, read_amt) = stream.read(vec![0; 4096]).await; let read_amt = read_amt?; data.truncate(read_amt); Ok(data.into_boxed_slice()) })) } #[test] fn basic() { let uring = &IoUring::new().unwrap(); let input = "Hello, world!".as_bytes(); let output = uring .block_on(async { let (addr, read_data) = start_echo(uring).await?; let mut conn = crate::net::TcpStream::connect(uring, addr).await?; let (_, res) = conn.write(input).await; res?; let data = read_data.await?; IoResult::Ok(data) }) .unwrap(); assert_eq!(&output[..], input) } #[test] fn read_stream() { let uring = &IoUring::new().unwrap(); let res = uring .block_on(async { let mut listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?; let addr = listener.local_addr().await?; let write_data = async { let mut stream = listener.incoming().await?.next().await.unwrap()?; let (_, res) = stream.write("Hello".as_bytes()).await; res?; crate::time::sleep(uring, Duration::from_millis(500)).await?; let (_, res) = stream.write(", world".as_bytes()).await; res?; IoResult::Ok(()) }; let read_data = async { let mut stream = crate::net::TcpStream::connect(uring, addr).await?; let data: Vec> = stream.read_stream().try_collect().await?; IoResult::Ok(data) }; let (_, data) = futures_lite::future::try_zip(write_data, read_data).await?; IoResult::Ok(data) }) .unwrap(); assert_eq!( res, vec![ "Hello".to_string().into_bytes(), ", world".to_string().into_bytes() ] ) } } }