671 lines
15 KiB
Rust
671 lines
15 KiB
Rust
use std::{
|
|
future::Future,
|
|
net::SocketAddr,
|
|
ops::Deref,
|
|
os::fd::{FromRawFd, IntoRawFd},
|
|
pin::Pin,
|
|
rc::Rc,
|
|
sync::atomic::{AtomicI32, AtomicU64, Ordering},
|
|
task::{Context, Poll, Waker},
|
|
};
|
|
|
|
use async_lock::{futures::BarrierWait, Barrier};
|
|
use futures_lite::{pin, Stream};
|
|
use parking::Parker;
|
|
|
|
use crate::{
|
|
aliases::IoResult,
|
|
io::{AsyncReadLoan, AsyncWriteLoan},
|
|
};
|
|
|
|
use super::IoImpl;
|
|
|
|
#[derive(Debug)]
|
|
struct ResultBarrier {
|
|
result: AtomicI32,
|
|
barrier: Barrier,
|
|
}
|
|
impl ResultBarrier {
|
|
fn new() -> Self {
|
|
Self {
|
|
result: AtomicI32::new(0),
|
|
barrier: Barrier::new(2),
|
|
}
|
|
}
|
|
|
|
async fn wait(&self) {
|
|
self.barrier.wait().await;
|
|
}
|
|
|
|
fn result(&self) -> i32 {
|
|
self.result.load(Ordering::Relaxed)
|
|
}
|
|
|
|
fn set_result_and_block(&self, v: i32) {
|
|
self.result.store(v, Ordering::Relaxed);
|
|
self.barrier.wait_blocking();
|
|
}
|
|
|
|
async fn wait_result(&self) -> i32 {
|
|
self.wait().await;
|
|
|
|
self.result()
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct UserData<'a> {
|
|
persist: bool,
|
|
rb: &'a ResultBarrier,
|
|
}
|
|
|
|
impl<'a> UserData<'a> {
|
|
fn new_boxed(rb: &'a ResultBarrier, persist: bool) -> Box<Self> {
|
|
Box::new(Self { rb, persist })
|
|
}
|
|
|
|
fn new_into_u64(rb: &'a ResultBarrier, persist: bool) -> u64 {
|
|
Self::new_boxed(rb, persist).into_u64()
|
|
}
|
|
|
|
fn into_u64(self: Box<Self>) -> u64 {
|
|
Box::leak(self) as *mut _ as _
|
|
}
|
|
|
|
unsafe fn from_u64(v: u64) -> Box<UserData<'a>> {
|
|
let v = v as *mut UserData;
|
|
|
|
unsafe { Box::from_raw(v) }
|
|
}
|
|
}
|
|
|
|
fn handle_error(i: i32) -> IoResult<i32> {
|
|
if i < 0 {
|
|
return Err(std::io::Error::from_raw_os_error(-i));
|
|
}
|
|
|
|
Ok(i)
|
|
}
|
|
|
|
/// Outcome of a single [`IoUringInner::tick`]: one submit followed by one
/// poll of the completion queue.
#[derive(Debug)]
pub struct Tick {
    /// Whether at least one completion was handled during the poll.
    did_handle: bool,

    /// Number of entries reported as submitted by the submit step.
    submitted: usize,

    /// Number of requests still counted as in flight after the poll.
    active_completions: usize,
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct IoUring(Rc<IoUringInner>);
|
|
impl IoUring {
|
|
pub fn new() -> IoResult<Self> {
|
|
Ok(Self(Rc::new(IoUringInner::new()?)))
|
|
}
|
|
}
|
|
impl Deref for IoUring {
|
|
type Target = IoUringInner;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
pub struct IoUringInner {
|
|
uring: io_uring::IoUring,
|
|
|
|
// TODO: analyze the atomic orderings passed to the atomic operations here
|
|
// to make sure they make sense
|
|
|
|
// TODO: I'm not sure that atomics even make sense here, but they let this
|
|
// method work with &self instead of &mut self so :shrug:
|
|
active_completions: AtomicU64,
|
|
|
|
_pd: core::marker::PhantomData<std::cell::RefCell<()>>,
|
|
}
|
|
impl IoUringInner {
|
|
pub fn new() -> IoResult<Self> {
|
|
let uring = io_uring::IoUring::new(256)?;
|
|
|
|
Ok(Self {
|
|
uring,
|
|
active_completions: AtomicU64::new(0),
|
|
|
|
_pd: Default::default(),
|
|
})
|
|
}
|
|
|
|
/// Cancel all events for the given fd. Does not return anything, and
|
|
/// cancellations are made on a best-effort basis
|
|
fn cancel_fd(&self, fd: io_uring::types::Fd) {
|
|
let rb = ResultBarrier::new();
|
|
|
|
let entry =
|
|
io_uring::opcode::AsyncCancel2::new(io_uring::types::CancelBuilder::fd(fd).all())
|
|
.build()
|
|
.user_data(UserData::new_into_u64(&rb, false));
|
|
|
|
self.queue_op(entry);
|
|
}
|
|
|
|
/// Pushes `op` onto the submission queue. The entry is only queued; nothing
/// is submitted here.
///
/// # Panics
///
/// Panics if the push fails.
fn queue_op(&self, op: io_uring::squeue::Entry) {
    // NOTE(review): the requirements of `submission_shared` / `push` are not
    // verified in this function — confirm them at the call sites.
    let pushed = unsafe { self.uring.submission_shared().push(&op) };

    pushed.unwrap()
}
|
|
|
|
async fn wait_op(&self, op: io_uring::squeue::Entry) -> IoResult<i32> {
|
|
let rb = ResultBarrier::new();
|
|
|
|
let entry = op.user_data(UserData::new_into_u64(&rb, false));
|
|
|
|
self.queue_op(entry);
|
|
|
|
handle_error(rb.wait_result().await)
|
|
}
|
|
|
|
pub fn submit(&self, wait_for: usize) -> IoResult<usize> {
|
|
let submitted_count = self.uring.submit_and_wait(wait_for)?;
|
|
self.active_completions
|
|
.fetch_add(submitted_count as u64, Ordering::Relaxed);
|
|
|
|
Ok(submitted_count)
|
|
}
|
|
|
|
/// Returns the current number of active requests, and a boolean indicating
|
|
/// if we actually handled any events.
|
|
pub fn poll(&self) -> IoResult<(usize, bool)> {
|
|
let mut did_handle = false;
|
|
// SAFETY:
|
|
// - this method is synchronous, and `cq` is dropped at the end of the scope
|
|
// - [`IoUring`] is !Sync, so it should be impossible for 2 threads to be
|
|
// running this at the same time
|
|
let cq = unsafe { self.uring.completion_shared() };
|
|
for entry in cq {
|
|
did_handle = true;
|
|
|
|
let ud = entry.user_data();
|
|
if ud == 0 {
|
|
self.active_completions.fetch_sub(1, Ordering::Relaxed);
|
|
continue;
|
|
}
|
|
|
|
let ud = unsafe { UserData::from_u64(entry.user_data()) };
|
|
ud.rb.set_result_and_block(entry.result());
|
|
if ud.persist {
|
|
Box::leak(ud);
|
|
} else {
|
|
self.active_completions.fetch_sub(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
Ok((
|
|
self.active_completions.load(Ordering::Relaxed) as usize,
|
|
did_handle,
|
|
))
|
|
}
|
|
|
|
pub fn tick(&self) -> IoResult<Tick> {
|
|
let submitted = self.submit(0)?;
|
|
let (active_completions, did_handle) = self.poll()?;
|
|
|
|
Ok(Tick {
|
|
active_completions,
|
|
did_handle,
|
|
submitted,
|
|
})
|
|
}
|
|
|
|
/// Runs a future. Conceptually, this is similar to [`Self::block_on`], but
|
|
/// instead of acting as its own executor, this allows us to be embedded in
|
|
/// another runtime
|
|
pub fn run<F: Future>(&self, fut: F, behavior: ActiveRequestBehavior) -> Run<F> {
|
|
Run(self, fut, behavior)
|
|
}
|
|
|
|
pub fn block_on<F: Future>(&self, fut: F) -> F::Output {
|
|
futures_lite::pin!(fut);
|
|
|
|
let parker = Parker::new();
|
|
let unparker = parker.unparker();
|
|
let waker = Waker::from(unparker);
|
|
|
|
let cx = &mut Context::from_waker(&waker);
|
|
|
|
loop {
|
|
// Check if the future is ready. If so, return the value.
|
|
if let Poll::Ready(v) = fut.as_mut().poll(cx) {
|
|
return v;
|
|
}
|
|
|
|
let Tick {
|
|
did_handle,
|
|
active_completions,
|
|
submitted,
|
|
} = self.tick().unwrap();
|
|
|
|
if did_handle {
|
|
// If we handled an event, it's likely that our future can make progress,
|
|
// so continue the loop
|
|
continue;
|
|
}
|
|
|
|
if submitted > 0 {
|
|
// We submitted an event. It's possible that our future can make progress
|
|
// once we poll, so continue the loop
|
|
continue;
|
|
}
|
|
|
|
if active_completions > 0 {
|
|
// We didn't submit an event, but we do have completions in-flight.
|
|
// We should block until one of them completes, and then re-poll
|
|
self.submit(1).unwrap();
|
|
continue;
|
|
}
|
|
|
|
// If we've gotten to this point, it's likely that we're waiting on a
|
|
// future that depends on another thread to make progress. In that case,
|
|
// park the current thread until the waker is called.
|
|
parker.park();
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct TcpListener(io_uring::types::Fd);
|
|
impl Drop for TcpListener {
|
|
fn drop(&mut self) {
|
|
unsafe { std::net::TcpListener::from_raw_fd(self.0 .0) };
|
|
}
|
|
}
|
|
|
|
pub struct TcpStream {
|
|
uring: IoUring,
|
|
fd: io_uring::types::Fd,
|
|
}
|
|
impl Drop for TcpStream {
|
|
fn drop(&mut self) {
|
|
self.uring.cancel_fd(self.fd);
|
|
unsafe { std::net::TcpStream::from_raw_fd(self.fd.0) };
|
|
}
|
|
}
|
|
impl AsyncReadLoan for TcpStream {
|
|
async fn read<B: crate::io::BufferMut>(&mut self, mut buf: B) -> (B, IoResult<usize>) {
|
|
let res = self
|
|
.uring
|
|
.0
|
|
.wait_op(
|
|
io_uring::opcode::Read::new(
|
|
self.fd,
|
|
buf.as_mut_ptr(),
|
|
buf.writable_bytes() as u32,
|
|
)
|
|
.build(),
|
|
)
|
|
.await
|
|
.map(|v| v as usize);
|
|
|
|
(buf, res)
|
|
}
|
|
}
|
|
|
|
impl AsyncWriteLoan for TcpStream {
|
|
async fn write<B: crate::io::Buffer>(&mut self, buf: B) -> (B, IoResult<usize>) {
|
|
let res = self
|
|
.uring
|
|
.0
|
|
.wait_op(
|
|
io_uring::opcode::Write::new(
|
|
self.fd,
|
|
buf.as_ptr(),
|
|
buf.readable_bytes() as u32,
|
|
)
|
|
.build(),
|
|
)
|
|
.await
|
|
.map(|v| v as usize);
|
|
|
|
(buf, res)
|
|
}
|
|
}
|
|
|
|
impl IoImpl for IoUring {
|
|
async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> {
|
|
let ts = io_uring::types::Timespec::new()
|
|
.sec(duration.as_secs())
|
|
.nsec(duration.subsec_nanos());
|
|
let entry = io_uring::opcode::Timeout::new(&ts as *const _).build();
|
|
|
|
let _ = self.0.wait_op(entry).await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
type TcpListener = TcpListener;
|
|
fn open_tcp_socket(
|
|
&self,
|
|
addr: std::net::SocketAddr,
|
|
) -> impl Future<Output = IoResult<Self::TcpListener>> {
|
|
// FIXME(Blocking)
|
|
// There is some some magic missing in the commented out code to make the
|
|
// socket clean up properly on process exit and whatnot. For now, just use
|
|
// the std implementation and cast it to a FileHandle
|
|
|
|
let listener = std::net::TcpListener::bind(addr);
|
|
async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) }
|
|
|
|
/*
|
|
// let (tx, rx) = async_channel::bounded(1);
|
|
|
|
let domain = match () {
|
|
_ if socket_addr.is_ipv4() => libc::AF_INET,
|
|
_ if socket_addr.is_ipv6() => libc::AF_INET6,
|
|
_ => return Err(std::io::Error::other("Unsupported domain")),
|
|
};
|
|
|
|
let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0)
|
|
.build()
|
|
.user_data(UserData::new_boxed(tx, false).into_u64());
|
|
|
|
unsafe {
|
|
p.uring.submission_shared().push(&entry).unwrap();
|
|
}
|
|
|
|
let entry = rx.recv().await.unwrap();
|
|
|
|
let fd = handle_error(entry.result())?;
|
|
|
|
let sock = libc::sockaddr_in {
|
|
sin_family: domain as _,
|
|
sin_port: socket_addr.port().to_be(),
|
|
sin_addr: libc::in_addr {
|
|
s_addr: match socket_addr.ip() {
|
|
IpAddr::V4(v) => v.to_bits().to_be(),
|
|
IpAddr::V6(_) => panic!(),
|
|
},
|
|
},
|
|
sin_zero: Default::default(),
|
|
};
|
|
|
|
// FIXME(Blocking)
|
|
handle_error(unsafe {
|
|
libc::bind(
|
|
fd,
|
|
&sock as *const _ as *const _,
|
|
core::mem::size_of_val(&sock) as u32,
|
|
)
|
|
})?;
|
|
|
|
// FIXME(Blocking)
|
|
handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?;
|
|
|
|
Ok(io_uring::types::Fd(fd))
|
|
*/
|
|
}
|
|
|
|
async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult<SocketAddr> {
|
|
// FIXME(Blocking)
|
|
let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) };
|
|
let addr = listener.local_addr()?;
|
|
let _ = listener.into_raw_fd();
|
|
|
|
Ok(addr)
|
|
}
|
|
|
|
type Incoming = Incoming;
|
|
fn accept_many(
|
|
&self,
|
|
listener: &mut Self::TcpListener,
|
|
) -> impl Future<Output = IoResult<Self::Incoming>> {
|
|
let rb = Box::pin(ResultBarrier::new());
|
|
|
|
let cb_id = UserData::new_into_u64(&rb, true);
|
|
|
|
let entry = io_uring::opcode::AcceptMulti::new(listener.0)
|
|
.build()
|
|
.user_data(cb_id);
|
|
|
|
unsafe {
|
|
self.0.uring.submission_shared().push(&entry).unwrap();
|
|
}
|
|
|
|
let wait = unsafe {
|
|
core::mem::transmute::<BarrierWait<'_>, BarrierWait<'_>>(rb.barrier.wait())
|
|
};
|
|
async move {
|
|
Ok(Incoming {
|
|
uring: self.clone(),
|
|
rb,
|
|
wait,
|
|
fd: listener.0,
|
|
})
|
|
}
|
|
}
|
|
|
|
type TcpStream = TcpStream;
|
|
fn tcp_connect(
|
|
&self,
|
|
socket: SocketAddr,
|
|
) -> impl Future<Output = IoResult<Self::TcpStream>> {
|
|
// FIXME(Blocking)
|
|
let stream = std::net::TcpStream::connect(socket);
|
|
|
|
async {
|
|
Ok(TcpStream {
|
|
uring: self.clone(),
|
|
fd: io_uring::types::Fd(stream?.into_raw_fd()),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct Incoming {
|
|
uring: IoUring,
|
|
rb: Pin<Box<ResultBarrier>>,
|
|
wait: BarrierWait<'static>,
|
|
fd: io_uring::types::Fd,
|
|
}
|
|
impl Unpin for Incoming {
|
|
}
|
|
|
|
impl Stream for Incoming {
|
|
type Item = IoResult<TcpStream>;
|
|
|
|
fn poll_next(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Option<Self::Item>> {
|
|
let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.wait) };
|
|
pin!(fut);
|
|
|
|
fut.poll(cx)
|
|
.map(|_| {
|
|
let fd = handle_error(self.rb.result())?;
|
|
|
|
Ok(TcpStream {
|
|
uring: self.uring.clone(),
|
|
fd: io_uring::types::Fd(fd),
|
|
})
|
|
})
|
|
.map(Some)
|
|
}
|
|
}
|
|
|
|
impl Drop for Incoming {
|
|
fn drop(&mut self) {
|
|
self.uring.0.cancel_fd(self.fd);
|
|
}
|
|
}
|
|
|
|
/// What a [`Run`] future does when the wrapped future is not ready, the last
/// [`IoUring::tick`] neither submitted nor handled anything, and requests are
/// still in flight.
#[derive(Copy, Clone, Debug)]
pub enum ActiveRequestBehavior {
    /// Panic.
    Panic,

    /// Block the current thread until a completion is returned.
    Block,

    /// Return pending without requesting a wake-up.
    ///
    /// <div class="warning">
    ///
    /// NOTE: this relies on the executor polling the future again
    /// _eventually_. If it never does, a deadlock is the likely outcome.
    ///
    /// </div>
    Pending,
}
|
|
|
|
pub struct Run<'a, F: Future>(&'a IoUringInner, F, ActiveRequestBehavior);
|
|
impl<F: Future> Future for Run<'_, F> {
|
|
type Output = F::Output;
|
|
|
|
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let io = self.0;
|
|
let behavior = self.2;
|
|
let fut = unsafe { self.map_unchecked_mut(|s| &mut s.1) };
|
|
|
|
if let Poll::Ready(out) = fut.poll(cx) {
|
|
return Poll::Ready(out);
|
|
}
|
|
|
|
let Tick {
|
|
did_handle,
|
|
submitted,
|
|
active_completions,
|
|
} = io.tick().unwrap();
|
|
|
|
if did_handle {
|
|
// We handled an event, it's likely that the future can make progress
|
|
cx.waker().wake_by_ref();
|
|
return Poll::Pending;
|
|
}
|
|
|
|
if submitted > 0 {
|
|
// We submitted an event, it's possible that the future can make progress,
|
|
// so we should wake the executor and get re-polled
|
|
cx.waker().wake_by_ref();
|
|
return Poll::Pending;
|
|
}
|
|
|
|
if active_completions > 0 {
|
|
// We have completions in flight, but they're not ready yet.
|
|
match behavior {
|
|
ActiveRequestBehavior::Panic => {
|
|
panic!("The future was not ready, and there are completions in-flight")
|
|
},
|
|
ActiveRequestBehavior::Block => {
|
|
io.submit(1).unwrap();
|
|
cx.waker().wake_by_ref();
|
|
return Poll::Pending;
|
|
},
|
|
ActiveRequestBehavior::Pending => {
|
|
return Poll::Pending;
|
|
},
|
|
}
|
|
}
|
|
|
|
// The future likely depends on another thread, so return pending
|
|
Poll::Pending
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    /// Tests driving futures with `IoUring::block_on`.
    mod block_on {
        use std::time::Duration;

        use crate::io_impl::io_uring::IoUring;

        /// A future that is immediately ready yields its value.
        #[test]
        fn simple() {
            let ring = &IoUring::new().unwrap();

            let value = ring.block_on(async { 5 });

            assert_eq!(value, 5);
        }

        /// A one second sleep on the ring completes and the future finishes.
        #[test]
        fn sleep() {
            let ring = &IoUring::new().unwrap();

            let value = ring.block_on(async {
                let slept = crate::time::sleep(ring, Duration::from_secs(1)).await;
                slept.unwrap();

                5
            });

            assert_eq!(value, 5);
        }
    }

    /// Tests driving futures with `IoUring::run` inside another executor.
    mod run {
        use std::time::Duration;

        use crate::io_impl::io_uring::{ActiveRequestBehavior, IoUring};

        /// A one second sleep completes when run under `futures_lite`'s
        /// `block_on` with the blocking behavior.
        #[test]
        fn sleep() {
            let ring = &IoUring::new().unwrap();

            let value = futures_lite::future::block_on(async {
                let nap = crate::time::sleep(ring, Duration::from_secs(1));

                ring.run(nap, ActiveRequestBehavior::Block).await.unwrap();

                5
            });

            assert_eq!(value, 5)
        }
    }

    /// Tests for the TCP listener and stream.
    mod net {
        use std::{future::Future, net::SocketAddr};

        use crate::{
            aliases::IoResult,
            io::{AsyncReadLoan, AsyncWriteLoan},
            io_impl::io_uring::IoUring,
        };
        use futures_lite::StreamExt;

        /// Binds a listener on an ephemeral local port. Returns its address
        /// and a future that accepts one connection and resolves to the bytes
        /// of a single read from it.
        async fn start_echo(
            uring: &IoUring,
        ) -> IoResult<(SocketAddr, impl Future<Output = IoResult<Box<[u8]>>> + '_)> {
            let mut listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?;
            let local_addr = listener.local_addr().await?;

            let read_once = async move {
                let mut incoming = listener.incoming().await?;
                let mut stream = incoming.next().await.unwrap()?;

                let (mut bytes, outcome) = stream.read(vec![0; 4096]).await;
                bytes.truncate(outcome?);

                Ok(bytes.into_boxed_slice())
            };

            Ok((local_addr, read_once))
        }

        /// Bytes written by a client are what the accepting side reads.
        #[test]
        fn basic() {
            let ring = &IoUring::new().unwrap();

            let input = "Hello, world!".as_bytes();

            let exchange = ring.block_on(async {
                let (addr, read_data) = start_echo(ring).await?;
                let mut conn = crate::net::TcpStream::connect(ring, addr).await?;

                let (_, written) = conn.write(input).await;
                written?;

                let received = read_data.await?;

                IoResult::Ok(received)
            });
            let output = exchange.unwrap();

            assert_eq!(&output[..], input)
        }
    }
}
|