// wove/src/io_impl/io_uring/mod.rs
use std::{
future::Future,
net::SocketAddr,
ops::Deref,
os::fd::{FromRawFd, IntoRawFd},
pin::Pin,
rc::Rc,
sync::atomic::{AtomicI32, AtomicU64, Ordering},
task::{Context, Poll, Waker},
};
use async_lock::{futures::BarrierWait, Barrier};
use futures_lite::{pin, Stream};
use parking::Parker;
use crate::{
aliases::IoResult,
io::{AsyncReadLoan, AsyncWriteLoan},
};
use super::IoImpl;
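/// A one-shot rendezvous between an in-flight io_uring operation and the
/// future awaiting it: the completion side stores the raw CQE result and then
/// joins the two-party barrier, releasing the task blocked in
/// [`Self::wait_result`].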
#[derive(Debug)]
struct ResultBarrier {
result: AtomicI32,
barrier: Barrier,
}
impl ResultBarrier {
fn new() -> Self {
Self {
result: AtomicI32::new(0),
barrier: Barrier::new(2),
}
}
async fn wait(&self) {
self.barrier.wait().await;
}
fn result(&self) -> i32 {
self.result.load(Ordering::Relaxed)
}
fn set_result_and_block(&self, v: i32) {
self.result.store(v, Ordering::Relaxed);
self.barrier.wait_blocking();
}
async fn wait_result(&self) -> i32 {
self.wait().await;
self.result()
}
}
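/// Per-operation state carried in an SQE's `user_data` field. The box is
/// leaked into a raw pointer on submission and reconstructed in
/// [`IoUringInner::poll`]; `persist` marks multishot requests whose user data
/// must survive more than one completion.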
#[derive(Debug)]
pub struct UserData<'a> {
persist: bool,
rb: &'a ResultBarrier,
}
impl<'a> UserData<'a> {
fn new_boxed(rb: &'a ResultBarrier, persist: bool) -> Box<Self> {
Box::new(Self { rb, persist })
}
fn new_into_u64(rb: &'a ResultBarrier, persist: bool) -> u64 {
Self::new_boxed(rb, persist).into_u64()
}
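    /// Leak the box and return its address; [`Self::from_u64`] reclaims
    /// ownership.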
fn into_u64(self: Box<Self>) -> u64 {
Box::leak(self) as *mut _ as _
}
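    /// # Safety
    ///
    /// `v` must be an address previously produced by [`Self::into_u64`] whose
    /// ownership has not already been reclaimed.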
unsafe fn from_u64(v: u64) -> Box<UserData<'a>> {
let v = v as *mut UserData;
unsafe { Box::from_raw(v) }
}
}
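/// Map a raw CQE result to an [`IoResult`]: negative values are negated errno
/// codes, non-negative values are passed through.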
fn handle_error(i: i32) -> IoResult<i32> {
if i < 0 {
return Err(std::io::Error::from_raw_os_error(-i));
}
Ok(i)
}
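/// Snapshot of one [`IoUringInner::tick`] pass: whether any completions were
/// handled, how many SQEs were submitted, and how many requests remain in
/// flight.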
pub struct Tick {
did_handle: bool,
submitted: usize,
active_completions: usize,
}
#[derive(Clone)]
pub struct IoUring(Rc<IoUringInner>);
impl IoUring {
pub fn new() -> IoResult<Self> {
Ok(Self(Rc::new(IoUringInner::new()?)))
}
}
impl Deref for IoUring {
type Target = IoUringInner;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct IoUringInner {
uring: io_uring::IoUring,
    // TODO: analyze the atomic orderings passed to the atomic operations here
    // to make sure they make sense
    // TODO: I'm not sure that atomics even make sense here, but they let these
    // methods work with &self instead of &mut self, so :shrug:
active_completions: AtomicU64,
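    // `RefCell<()>` is Send but !Sync, which makes this type !Sync (relied on
    // by the SAFETY comments below) without giving up Send.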
_pd: core::marker::PhantomData<std::cell::RefCell<()>>,
}
impl IoUringInner {
pub fn new() -> IoResult<Self> {
let uring = io_uring::IoUring::new(256)?;
Ok(Self {
uring,
active_completions: AtomicU64::new(0),
_pd: Default::default(),
})
}
    /// Cancel all in-flight events for the given fd. Cancellation is
    /// best-effort and the result is not reported.
    fn cancel_fd(&self, fd: io_uring::types::Fd) {
        // The result is ignored, so submit with user data 0, which `poll`
        // treats as a fire-and-forget completion. Attaching a stack-local
        // `ResultBarrier` here would dangle once this method returns.
        let entry =
            io_uring::opcode::AsyncCancel2::new(io_uring::types::CancelBuilder::fd(fd).all())
                .build()
                .user_data(0);
        self.queue_op(entry);
    }
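    /// Push an entry onto the submission queue without submitting it. Panics
    /// if the submission queue is full.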
fn queue_op(&self, op: io_uring::squeue::Entry) {
unsafe { self.uring.submission_shared().push(&op).unwrap() }
}
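    /// Queue `op`, await its completion, and map a negative CQE result to an
    /// [`std::io::Error`].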
async fn wait_op(&self, op: io_uring::squeue::Entry) -> IoResult<i32> {
let rb = ResultBarrier::new();
let entry = op.user_data(UserData::new_into_u64(&rb, false));
self.queue_op(entry);
handle_error(rb.wait_result().await)
}
pub fn submit(&self, wait_for: usize) -> IoResult<usize> {
let submitted_count = self.uring.submit_and_wait(wait_for)?;
self.active_completions
.fetch_add(submitted_count as u64, Ordering::Relaxed);
Ok(submitted_count)
}
/// Returns the current number of active requests, and a boolean indicating
/// if we actually handled any events.
pub fn poll(&self) -> IoResult<(usize, bool)> {
let mut did_handle = false;
// SAFETY:
// - this method is synchronous, and `cq` is dropped at the end of the scope
// - [`IoUring`] is !Sync, so it should be impossible for 2 threads to be
// running this at the same time
let cq = unsafe { self.uring.completion_shared() };
for entry in cq {
did_handle = true;
let ud = entry.user_data();
if ud == 0 {
self.active_completions.fetch_sub(1, Ordering::Relaxed);
continue;
}
let ud = unsafe { UserData::from_u64(entry.user_data()) };
ud.rb.set_result_and_block(entry.result());
if ud.persist {
Box::leak(ud);
} else {
self.active_completions.fetch_sub(1, Ordering::Relaxed);
}
}
Ok((
self.active_completions.load(Ordering::Relaxed) as usize,
did_handle,
))
}
pub fn tick(&self) -> IoResult<Tick> {
let submitted = self.submit(0)?;
let (active_completions, did_handle) = self.poll()?;
Ok(Tick {
active_completions,
did_handle,
submitted,
})
}
    /// Runs a future. Conceptually this is similar to [`Self::block_on`], but
    /// instead of acting as its own executor, it allows the ring to be driven
    /// from inside another runtime.
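    ///
    /// A usage sketch, mirroring the `run::sleep` test at the bottom of this
    /// module:
    ///
    /// ```ignore
    /// futures_lite::future::block_on(async {
    ///     uring
    ///         .run(
    ///             crate::time::sleep(uring, Duration::from_secs(1)),
    ///             ActiveRequestBehavior::Block,
    ///         )
    ///         .await
    ///         .unwrap();
    /// });
    /// ```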
pub fn run<F: Future>(&self, fut: F, behavior: ActiveRequestBehavior) -> Run<F> {
Run(self, fut, behavior)
}
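    /// Drive `fut` to completion on the current thread, draining io_uring
    /// completions between polls and parking the thread when no local
    /// progress can be made.
    ///
    /// A minimal usage sketch (see the `block_on` tests below):
    ///
    /// ```ignore
    /// let uring = IoUring::new().unwrap();
    /// assert_eq!(uring.block_on(async { 5 }), 5);
    /// ```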
pub fn block_on<F: Future>(&self, fut: F) -> F::Output {
futures_lite::pin!(fut);
let parker = Parker::new();
let unparker = parker.unparker();
let waker = Waker::from(unparker);
let cx = &mut Context::from_waker(&waker);
loop {
// Check if the future is ready. If so, return the value.
if let Poll::Ready(v) = fut.as_mut().poll(cx) {
return v;
}
let Tick {
did_handle,
active_completions,
submitted,
} = self.tick().unwrap();
if did_handle {
// If we handled an event, it's likely that our future can make progress,
// so continue the loop
continue;
}
if submitted > 0 {
// We submitted an event. It's possible that our future can make progress
// once we poll, so continue the loop
continue;
}
if active_completions > 0 {
// We didn't submit an event, but we do have completions in-flight.
// We should block until one of them completes, and then re-poll
self.submit(1).unwrap();
continue;
}
// If we've gotten to this point, it's likely that we're waiting on a
// future that depends on another thread to make progress. In that case,
// park the current thread until the waker is called.
parker.park();
}
}
}
pub struct TcpListener(io_uring::types::Fd);
impl Drop for TcpListener {
fn drop(&mut self) {
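        // Rebuild the std listener from the raw fd and drop it immediately so
        // the fd is closed.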
unsafe { std::net::TcpListener::from_raw_fd(self.0 .0) };
}
}
pub struct TcpStream {
uring: IoUring,
fd: io_uring::types::Fd,
}
impl Drop for TcpStream {
fn drop(&mut self) {
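        // Best-effort cancel any in-flight requests on this fd, then close it
        // by rebuilding and dropping the std stream.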
self.uring.cancel_fd(self.fd);
unsafe { std::net::TcpStream::from_raw_fd(self.fd.0) };
}
}
impl AsyncReadLoan for TcpStream {
async fn read<B: crate::io::BufferMut>(&mut self, mut buf: B) -> (B, IoResult<usize>) {
let res = self
.uring
.0
.wait_op(
io_uring::opcode::Read::new(
self.fd,
buf.as_mut_ptr(),
buf.writable_bytes() as u32,
)
.build(),
)
.await
.map(|v| v as usize);
(buf, res)
}
}
impl AsyncWriteLoan for TcpStream {
async fn write<B: crate::io::Buffer>(&mut self, buf: B) -> (B, IoResult<usize>) {
let res = self
.uring
.0
.wait_op(
io_uring::opcode::Write::new(
self.fd,
buf.as_ptr(),
buf.readable_bytes() as u32,
)
.build(),
)
.await
.map(|v| v as usize);
(buf, res)
}
}
impl IoImpl for IoUring {
async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> {
let ts = io_uring::types::Timespec::new()
.sec(duration.as_secs())
.nsec(duration.subsec_nanos());
let entry = io_uring::opcode::Timeout::new(&ts as *const _).build();
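        // A timeout completes with -ETIME when it expires, which
        // `handle_error` surfaces as an error, so the result is deliberately
        // ignored.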
let _ = self.0.wait_op(entry).await;
Ok(())
}
type TcpListener = TcpListener;
fn open_tcp_socket(
&self,
addr: std::net::SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpListener>> {
        // FIXME(Blocking)
        // There is some magic missing in the commented-out code below to make
        // the socket clean up properly on process exit and whatnot. For now,
        // just use the std implementation and take over its raw fd.
let listener = std::net::TcpListener::bind(addr);
async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) }
/*
// let (tx, rx) = async_channel::bounded(1);
let domain = match () {
_ if socket_addr.is_ipv4() => libc::AF_INET,
_ if socket_addr.is_ipv6() => libc::AF_INET6,
_ => return Err(std::io::Error::other("Unsupported domain")),
};
let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
unsafe {
p.uring.submission_shared().push(&entry).unwrap();
}
let entry = rx.recv().await.unwrap();
let fd = handle_error(entry.result())?;
let sock = libc::sockaddr_in {
sin_family: domain as _,
sin_port: socket_addr.port().to_be(),
sin_addr: libc::in_addr {
s_addr: match socket_addr.ip() {
IpAddr::V4(v) => v.to_bits().to_be(),
IpAddr::V6(_) => panic!(),
},
},
sin_zero: Default::default(),
};
// FIXME(Blocking)
handle_error(unsafe {
libc::bind(
fd,
&sock as *const _ as *const _,
core::mem::size_of_val(&sock) as u32,
)
})?;
// FIXME(Blocking)
handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?;
Ok(io_uring::types::Fd(fd))
*/
}
async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult<SocketAddr> {
// FIXME(Blocking)
let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) };
let addr = listener.local_addr()?;
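        // `into_raw_fd` releases ownership of the fd so it is not closed when
        // the temporary listener goes away.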
let _ = listener.into_raw_fd();
Ok(addr)
}
type Incoming = Incoming;
fn accept_many(
&self,
listener: &mut Self::TcpListener,
) -> impl Future<Output = IoResult<Self::Incoming>> {
let rb = Box::pin(ResultBarrier::new());
let cb_id = UserData::new_into_u64(&rb, true);
let entry = io_uring::opcode::AcceptMulti::new(listener.0)
.build()
.user_data(cb_id);
unsafe {
self.0.uring.submission_shared().push(&entry).unwrap();
}
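        // SAFETY: the barrier lives in a pinned heap allocation that is moved
        // into the returned `Incoming` together with this wait future, so the
        // wait future never outlives the barrier it borrows despite the
        // 'static lifetime.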
let wait = unsafe {
core::mem::transmute::<BarrierWait<'_>, BarrierWait<'_>>(rb.barrier.wait())
};
async move {
Ok(Incoming {
uring: self.clone(),
rb,
wait,
fd: listener.0,
})
}
}
type TcpStream = TcpStream;
fn tcp_connect(
&self,
socket: SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpStream>> {
// FIXME(Blocking)
let stream = std::net::TcpStream::connect(socket);
async {
Ok(TcpStream {
uring: self.clone(),
fd: io_uring::types::Fd(stream?.into_raw_fd()),
})
}
}
}
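/// Stream of accepted connections backed by a multishot accept request. Each
/// completion releases the stored barrier once and yields one [`TcpStream`];
/// dropping the stream cancels the outstanding request.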
pub struct Incoming {
uring: IoUring,
rb: Pin<Box<ResultBarrier>>,
wait: BarrierWait<'static>,
fd: io_uring::types::Fd,
}
impl Unpin for Incoming {}
impl Stream for Incoming {
type Item = IoResult<TcpStream>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.wait) };
pin!(fut);
fut.poll(cx)
.map(|_| {
let fd = handle_error(self.rb.result())?;
Ok(TcpStream {
uring: self.uring.clone(),
fd: io_uring::types::Fd(fd),
})
})
.map(Some)
}
}
impl Drop for Incoming {
fn drop(&mut self) {
self.uring.0.cancel_fd(self.fd);
}
}
/// Behavior used by [`Run`] when the future is not ready, the last
/// [`IoUringInner::tick`] neither submitted nor handled any events, and there
/// are requests in flight.
#[derive(Copy, Clone, Debug)]
pub enum ActiveRequestBehavior {
/// Cause a panic
Panic,
/// Block until a completion is returned
Block,
/// Return pending.
///
/// <div class="warning">
///
/// NOTE: this relies on the fact that the executor is going to poll this
/// _eventually_. If it doesn't, it's likely that you'll get a deadlock.
///
/// </div>
Pending,
}
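/// Future returned by [`IoUringInner::run`]. Each poll first polls the inner
/// future, then drives the ring once, falling back to the configured
/// [`ActiveRequestBehavior`] when the ring made no progress but still has
/// requests in flight.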
pub struct Run<'a, F: Future>(&'a IoUringInner, F, ActiveRequestBehavior);
impl<F: Future> Future for Run<'_, F> {
type Output = F::Output;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let io = self.0;
let behavior = self.2;
let fut = unsafe { self.map_unchecked_mut(|s| &mut s.1) };
if let Poll::Ready(out) = fut.poll(cx) {
return Poll::Ready(out);
}
let Tick {
did_handle,
submitted,
active_completions,
} = io.tick().unwrap();
if did_handle {
// We handled an event, it's likely that the future can make progress
cx.waker().wake_by_ref();
return Poll::Pending;
}
if submitted > 0 {
// We submitted an event, it's possible that the future can make progress,
// so we should wake the executor and get re-polled
cx.waker().wake_by_ref();
return Poll::Pending;
}
if active_completions > 0 {
// We have completions in flight, but they're not ready yet.
match behavior {
ActiveRequestBehavior::Panic => {
panic!("The future was not ready, and there are completions in-flight")
},
ActiveRequestBehavior::Block => {
io.submit(1).unwrap();
cx.waker().wake_by_ref();
return Poll::Pending;
},
ActiveRequestBehavior::Pending => {
return Poll::Pending;
},
}
}
// The future likely depends on another thread, so return pending
Poll::Pending
}
}
#[cfg(test)]
mod test {
mod block_on {
use std::time::Duration;
use crate::io_impl::io_uring::IoUring;
#[test]
fn simple() {
let uring = &IoUring::new().unwrap();
let out = uring.block_on(async { 5 });
assert_eq!(out, 5);
}
#[test]
fn sleep() {
let uring = &IoUring::new().unwrap();
let out = uring.block_on(async {
crate::time::sleep(uring, Duration::from_secs(1))
.await
.unwrap();
5
});
assert_eq!(out, 5);
}
}
mod run {
use std::time::Duration;
use crate::io_impl::io_uring::{ActiveRequestBehavior, IoUring};
#[test]
fn sleep() {
let uring = &IoUring::new().unwrap();
let out = futures_lite::future::block_on(async {
uring
.run(
crate::time::sleep(uring, Duration::from_secs(1)),
ActiveRequestBehavior::Block,
)
.await
.unwrap();
5
});
assert_eq!(out, 5)
}
}
mod net {
use std::{future::Future, net::SocketAddr};
use crate::{
aliases::IoResult,
io::{AsyncReadLoan, AsyncWriteLoan},
io_impl::io_uring::IoUring,
};
use futures_lite::StreamExt;
async fn start_echo(
uring: &IoUring,
) -> IoResult<(SocketAddr, impl Future<Output = IoResult<Box<[u8]>>> + '_)> {
let mut listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?;
Ok((listener.local_addr().await?, async move {
let mut incoming = listener.incoming().await?;
let mut stream = incoming.next().await.unwrap()?;
let (mut data, read_amt) = stream.read(vec![0; 4096]).await;
let read_amt = read_amt?;
data.truncate(read_amt);
Ok(data.into_boxed_slice())
}))
}
#[test]
fn basic() {
let uring = &IoUring::new().unwrap();
let input = "Hello, world!".as_bytes();
let output = uring
.block_on(async {
let (addr, read_data) = start_echo(uring).await?;
let mut conn = crate::net::TcpStream::connect(uring, addr).await?;
let (_, res) = conn.write(input).await;
res?;
let data = read_data.await?;
IoResult::Ok(data)
})
.unwrap();
assert_eq!(&output[..], input)
}
}
}