More reworking, tcp connections

This commit is contained in:
soup 2024-10-21 13:54:56 -04:00
parent bdcf820ca7
commit 7b26c4b3cc
No known key found for this signature in database
6 changed files with 382 additions and 186 deletions

View file

@ -3,17 +3,9 @@ use std::future::Future;
use crate::{aliases::IoResult, io_impl::IoImpl};
pub trait AsyncRead {
fn read<I: IoImpl>(
&mut self,
io: &I,
buf: Box<[u8]>,
) -> impl Future<Output = IoResult<Box<[u8]>>>;
fn read(&mut self, buf: &mut [u8]) -> impl Future<Output = IoResult<usize>>;
}
pub trait AsyncWrite {
fn write<I: IoImpl>(
&mut self,
io: &I,
buf: Box<[u8]>,
) -> impl Future<Output = IoResult<usize>>;
fn write(&mut self, buf: &[u8]) -> impl Future<Output = IoResult<usize>>;
}

View file

@ -1,9 +1,14 @@
use std::{
async_iter::AsyncIterator,
future::Future,
net::SocketAddr,
os::fd::{FromRawFd, IntoRawFd},
pin::Pin,
sync::atomic::{AtomicU64, Ordering},
task::{Context, Poll, Waker},
};
use futures_lite::Stream;
use parking::Parker;
use crate::aliases::IoResult;
@ -90,6 +95,12 @@ impl IoUring {
for entry in cq {
did_handle = true;
let ud = entry.user_data();
if ud == 0 {
self.active_completions.fetch_sub(1, Ordering::Relaxed);
continue;
}
let ud = unsafe { UserData::from_u64(entry.user_data()) };
ud.tx.send_blocking(entry).unwrap();
if ud.persist {
@ -119,8 +130,8 @@ impl IoUring {
/// Runs a future. Conceptually, this is similar to [`Self::block_on`], but
/// instead of acting as its own executor, this allows us to be embedded in
/// another runtime
pub async fn run<F: Future>(&self, fut: F, behavior: ActiveRequestBehavior) -> Run<F> {
Run(&self, fut, behavior)
pub fn run<F: Future>(&self, fut: F, behavior: ActiveRequestBehavior) -> Run<F> {
Run(self, fut, behavior)
}
pub fn block_on<F: Future>(&self, fut: F) -> F::Output {
@ -171,6 +182,20 @@ impl IoUring {
}
}
pub struct TcpListener(io_uring::types::Fd);
impl Drop for TcpListener {
fn drop(&mut self) {
unsafe { std::net::TcpListener::from_raw_fd(self.0 .0) };
}
}
pub struct TcpStream(io_uring::types::Fd);
impl Drop for TcpStream {
fn drop(&mut self) {
unsafe { std::net::TcpStream::from_raw_fd(self.0 .0) };
}
}
impl IoImpl for IoUring {
async fn sleep(&self, duration: std::time::Duration) -> IoResult<()> {
let (tx, rx) = async_channel::bounded(1);
@ -189,6 +214,202 @@ impl IoImpl for IoUring {
Ok(())
}
type TcpListener = io_uring::types::Fd;
fn open_tcp_socket(
&self,
addr: std::net::SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpListener>> {
// FIXME(Blocking)
// There is some some magic missing in the commented out code to make the
// socket clean up properly on process exit and whatnot. For now, just use
// the std implementation and cast it to a FileHandle
let listener = std::net::TcpListener::bind(addr);
async { IoResult::Ok(io_uring::types::Fd(listener?.into_raw_fd())) }
/*
// let (tx, rx) = async_channel::bounded(1);
let domain = match () {
_ if socket_addr.is_ipv4() => libc::AF_INET,
_ if socket_addr.is_ipv6() => libc::AF_INET6,
_ => return Err(std::io::Error::other("Unsupported domain")),
};
let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
unsafe {
p.uring.submission_shared().push(&entry).unwrap();
}
let entry = rx.recv().await.unwrap();
let fd = handle_error(entry.result())?;
let sock = libc::sockaddr_in {
sin_family: domain as _,
sin_port: socket_addr.port().to_be(),
sin_addr: libc::in_addr {
s_addr: match socket_addr.ip() {
IpAddr::V4(v) => v.to_bits().to_be(),
IpAddr::V6(_) => panic!(),
},
},
sin_zero: Default::default(),
};
// FIXME(Blocking)
handle_error(unsafe {
libc::bind(
fd,
&sock as *const _ as *const _,
core::mem::size_of_val(&sock) as u32,
)
})?;
// FIXME(Blocking)
handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?;
Ok(io_uring::types::Fd(fd))
*/
}
async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult<SocketAddr> {
// FIXME(Blocking)
let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0) };
let addr = listener.local_addr()?;
let _ = listener.into_raw_fd();
Ok(addr)
}
type Incoming = Incoming;
fn accept_many(
&self,
listener: &mut Self::TcpListener,
) -> impl Future<Output = IoResult<Self::Incoming>> {
let (tx, rx) = async_channel::unbounded();
let cb_id = UserData::new_boxed(tx, true).into_u64();
let entry = io_uring::opcode::AcceptMulti::new(*listener)
.build()
.user_data(cb_id);
unsafe {
self.uring.submission_shared().push(&entry).unwrap();
}
async move {
Ok(Incoming {
io: self as *const _,
rx: Box::pin(rx),
cb_id,
})
}
}
type TcpStream = TcpStream;
async fn tcp_read(&self, stream: &mut Self::TcpStream, buf: &mut [u8]) -> IoResult<usize> {
let (tx, rx) = async_channel::bounded(1);
let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
unsafe {
self.uring.submission_shared().push(&entry).unwrap();
}
let entry = rx.recv().await.unwrap();
let read_amt = handle_error(entry.result())?;
Ok(read_amt as usize)
}
fn tcp_connect(
&self,
socket: SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpStream>> {
// FIXME(Blocking)
let stream = std::net::TcpStream::connect(socket);
async { Ok(TcpStream(io_uring::types::Fd(stream?.into_raw_fd()))) }
}
async fn tcp_write(&self, stream: &mut Self::TcpStream, buf: &[u8]) -> IoResult<usize> {
let (tx, rx) = async_channel::bounded(1);
let entry = io_uring::opcode::Write::new(stream.0, buf.as_ptr(), buf.len() as u32)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
unsafe {
self.uring.submission_shared().push(&entry).unwrap();
}
let entry = rx.recv().await.unwrap();
let write_amt = handle_error(entry.result())?;
Ok(write_amt as usize)
}
}
pub struct Incoming {
io: *const IoUring,
rx: Pin<Box<CqueueEntryReceiver>>,
cb_id: u64,
}
pub fn cancel(io: &IoUring, id: u64) {
let entry = io_uring::opcode::AsyncCancel::new(id).build().user_data(0);
unsafe {
io.uring.submission_shared().push(&entry).unwrap();
}
}
impl AsyncIterator for Incoming {
type Item = IoResult<TcpStream>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let rx = unsafe { self.map_unchecked_mut(|s| &mut s.rx) };
let mut fut = rx.recv();
let pinned = unsafe { Pin::new_unchecked(&mut fut) };
pinned
.poll(cx)
.map(|entry| {
let fd = handle_error(entry.unwrap().result())?;
Ok(TcpStream(io_uring::types::Fd(fd)))
})
.map(Some)
}
}
impl Stream for Incoming {
type Item = IoResult<TcpStream>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
AsyncIterator::poll_next(self, cx)
}
}
impl Drop for Incoming {
fn drop(&mut self) {
let io = unsafe { &*self.io };
cancel(io, self.cb_id);
}
}
/// Behavior used in [`Run`] when the future is not ready, the last [`IoUring::tick`]
@ -252,6 +473,7 @@ impl<F: Future> Future for Run<'_, F> {
},
ActiveRequestBehavior::Block => {
io.submit(1).unwrap();
cx.waker().wake_by_ref();
return Poll::Pending;
},
ActiveRequestBehavior::Pending => {
@ -280,7 +502,7 @@ mod test {
}
#[test]
fn timer() {
fn sleep() {
let uring = &IoUring::new().unwrap();
let out = uring.block_on(async {
crate::time::sleep(uring, Duration::from_secs(1)).await;
@ -290,4 +512,76 @@ mod test {
assert_eq!(out, 5);
}
}
mod run {
use std::time::Duration;
use crate::io_impl::io_uring::{ActiveRequestBehavior, IoUring};
#[test]
fn sleep() {
let uring = &IoUring::new().unwrap();
let out = futures_lite::future::block_on(async {
uring
.run(
crate::time::sleep(uring, Duration::from_secs(1)),
ActiveRequestBehavior::Block,
)
.await;
5
});
assert_eq!(out, 5)
}
}
mod net {
use std::{future::Future, net::SocketAddr};
use crate::{
aliases::IoResult,
io::{AsyncRead, AsyncWrite},
io_impl::io_uring::IoUring,
};
use futures_lite::StreamExt;
async fn start_echo(
uring: &IoUring,
) -> IoResult<(SocketAddr, impl Future<Output = IoResult<Box<[u8]>>> + '_)> {
let listener = crate::net::TcpListener::bind(uring, "127.0.0.1:0").await?;
Ok((listener.local_addr().await?, async {
let mut incoming = listener.incoming().await?;
let mut stream = incoming.next().await.unwrap()?;
let mut data = vec![0; 4096];
let read_amt = stream.read(&mut data).await?;
data.truncate(read_amt);
Ok(data.into_boxed_slice())
}))
}
#[test]
fn basic() {
let uring = &IoUring::new().unwrap();
let input = "Hello, world!".as_bytes();
let output = uring
.block_on(async {
let (addr, read_data) = start_echo(uring).await?;
let mut conn = crate::net::TcpStream::connect(uring, addr).await?;
conn.write(input).await?;
let data = read_data.await?;
IoResult::Ok(data)
})
.unwrap();
assert_eq!(&output[..], input)
}
}
}

View file

@ -1,9 +1,41 @@
use crate::aliases::IoResult;
use std::time::Duration;
use std::{future::Future, net::SocketAddr, time::Duration};
#[cfg(feature = "io_uring")]
pub mod io_uring;
pub trait IoImpl {
async fn sleep(&self, duration: Duration) -> IoResult<()>;
fn sleep(&self, duration: Duration) -> impl Future<Output = IoResult<()>>;
type TcpListener;
fn open_tcp_socket(
&self,
addr: SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpListener>>;
fn listener_local_addr(
&self,
listener: &Self::TcpListener,
) -> impl Future<Output = IoResult<SocketAddr>>;
type Incoming;
fn accept_many(
&self,
listener: &mut Self::TcpListener,
) -> impl Future<Output = IoResult<Self::Incoming>>;
type TcpStream;
fn tcp_read(
&self,
stream: &mut Self::TcpStream,
buf: &mut [u8],
) -> impl Future<Output = IoResult<usize>>;
fn tcp_write(
&self,
stream: &mut Self::TcpStream,
buf: &[u8],
) -> impl Future<Output = IoResult<usize>>;
fn tcp_connect(
&self,
socket: SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpStream>>;
}

View file

@ -6,9 +6,5 @@ mod aliases;
// pub mod fs;
pub mod io;
pub mod io_impl;
// pub mod net;
pub mod net;
pub mod time;
use std::{future::Future, pin::Pin, task::Poll};
use aliases::IoResult;

View file

@ -1,45 +1,46 @@
use std::{async_iter::AsyncIterator, net::ToSocketAddrs, pin::Pin};
use std::{
async_iter::AsyncIterator,
net::{SocketAddr, ToSocketAddrs},
pin::Pin,
};
use futures_lite::Stream;
use crate::{
aliases::IoResult,
io::{AsyncRead, AsyncWrite},
plat_impl, Wove,
io_impl::IoImpl,
};
pub struct TcpStream(pub(crate) plat_impl::FileHandle);
impl AsyncRead for TcpStream {
fn read(
&mut self,
wove: &Wove,
buf: Box<[u8]>,
) -> impl std::future::Future<Output = IoResult<Box<[u8]>>> {
plat_impl::read(&wove.platform, &mut self.0, 0, buf)
pub struct TcpStream<'a, I: IoImpl>(&'a I, I::TcpStream);
impl<'a, I: IoImpl> TcpStream<'a, I> {
pub async fn connect(io: &'a I, addr: SocketAddr) -> IoResult<Self> {
Ok(Self(io, io.tcp_connect(addr).await?))
}
}
impl AsyncWrite for TcpStream {
fn write(
&mut self,
wove: &Wove,
buf: Box<[u8]>,
) -> impl std::future::Future<Output = IoResult<usize>> {
plat_impl::write(&wove.platform, &mut self.0, 0, buf)
impl<'a, I: IoImpl> AsyncRead for TcpStream<'a, I> {
async fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
self.0.tcp_read(&mut self.1, buf).await
}
}
pub struct TcpListener(plat_impl::FileHandle);
impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> {
async fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
self.0.tcp_write(&mut self.1, buf).await
}
}
impl TcpListener {
pub async fn bind(wove: &Wove, addrs: impl ToSocketAddrs) -> IoResult<Self> {
pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener);
impl<'a, I: IoImpl> TcpListener<'a, I> {
pub async fn bind(io: &'a I, addrs: impl ToSocketAddrs) -> IoResult<Self> {
// TODO(Blocking): to_socket_addrs can block
let mut last_err = None;
for addr in addrs.to_socket_addrs()? {
match plat_impl::open_tcp_socket(&wove.platform, addr).await {
Ok(v) => return Ok(TcpListener(v)),
match io.open_tcp_socket(addr).await {
Ok(v) => return Ok(TcpListener(io, v)),
Err(e) => last_err = Some(e),
}
}
@ -51,38 +52,47 @@ impl TcpListener {
Err(std::io::Error::other("No addrs returned"))
}
pub fn incoming<'a>(&'a mut self, wove: &'a Wove) -> Incoming<'a> {
Incoming::register(wove, self)
pub async fn local_addr(&self) -> IoResult<SocketAddr> {
self.0.listener_local_addr(&self.1).await
}
pub async fn incoming(mut self) -> IoResult<Incoming<'a, I>> {
Ok(Incoming(self.0, self.0.accept_many(&mut self.1).await?))
}
}
pub struct Incoming<'a>(plat_impl::Incoming<'a>);
impl<'a> Incoming<'a> {
fn register(wove: &'a Wove, listener: &'a mut TcpListener) -> Self {
Incoming(plat_impl::accept_many(&wove.platform, &mut listener.0))
}
}
pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming);
impl AsyncIterator for Incoming<'_> {
type Item = IoResult<TcpStream>;
impl<'a, I: IoImpl> AsyncIterator for Incoming<'a, I>
where
I::Incoming: AsyncIterator<Item = IoResult<I::TcpStream>>,
{
type Item = IoResult<TcpStream<'a, I>>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.0) };
let io = self.0;
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) };
AsyncIterator::poll_next(inner, cx)
AsyncIterator::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i))))
}
}
impl Stream for Incoming<'_> {
type Item = IoResult<TcpStream>;
impl<'a, I: IoImpl> Stream for Incoming<'a, I>
where
I::Incoming: Stream<Item = IoResult<I::TcpStream>>,
{
type Item = IoResult<TcpStream<'a, I>>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
AsyncIterator::poll_next(self, cx)
let io = self.0;
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) };
Stream::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i))))
}
}

View file

@ -69,61 +69,6 @@ pub async fn open_tcp_socket(
p: &PlatformLinux,
socket_addr: SocketAddr,
) -> IoResult<FileHandle> {
// FIXME(Blocking)
// There is some some magic missing in the commented out code to make the
// socket clean up properly on process exit and whatnot. For now, just use
// the std implementation and cast it to a FileHandle
let listener = std::net::TcpListener::bind(socket_addr)?;
Ok(io_uring::types::Fd(listener.into_raw_fd()))
/*
// let (tx, rx) = async_channel::bounded(1);
let domain = match () {
_ if socket_addr.is_ipv4() => libc::AF_INET,
_ if socket_addr.is_ipv6() => libc::AF_INET6,
_ => return Err(std::io::Error::other("Unsupported domain")),
};
let entry = io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM, 0)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
unsafe {
p.uring.submission_shared().push(&entry).unwrap();
}
let entry = rx.recv().await.unwrap();
let fd = handle_error(entry.result())?;
let sock = libc::sockaddr_in {
sin_family: domain as _,
sin_port: socket_addr.port().to_be(),
sin_addr: libc::in_addr {
s_addr: match socket_addr.ip() {
IpAddr::V4(v) => v.to_bits().to_be(),
IpAddr::V6(_) => panic!(),
},
},
sin_zero: Default::default(),
};
// FIXME(Blocking)
handle_error(unsafe {
libc::bind(
fd,
&sock as *const _ as *const _,
core::mem::size_of_val(&sock) as u32,
)
})?;
// FIXME(Blocking)
handle_error(unsafe { libc::listen(fd, libc::SOMAXCONN) })?;
Ok(io_uring::types::Fd(fd))
*/
}
pub async fn read(
@ -175,77 +120,4 @@ pub async fn write(
Ok(write_amt as usize)
}
pub(crate) struct Incoming<'a> {
plat: &'a PlatformLinux,
rx: Pin<Box<CqueueEntryReceiver>>,
cb_id: u64,
}
pub fn accept_many<'a>(p: &'a PlatformLinux, f: &mut FileHandle) -> Incoming<'a> {
let (tx, rx) = async_channel::unbounded();
let cb_id = UserData::new_boxed(tx, true).into_u64();
let entry = io_uring::opcode::AcceptMulti::new(*f)
.build()
.user_data(cb_id);
unsafe {
p.uring.submission_shared().push(&entry).unwrap();
}
Incoming {
plat: p,
rx: Box::pin(rx),
cb_id,
}
}
pub fn cancel(p: &PlatformLinux, id: u64) {
let entry = io_uring::opcode::AsyncCancel::new(id).build();
unsafe {
p.uring.submission_shared().push(&entry).unwrap();
}
}
impl AsyncIterator for Incoming<'_> {
type Item = IoResult<TcpStream>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let rx = unsafe { self.map_unchecked_mut(|s| &mut s.rx) };
let mut fut = rx.recv();
let pinned = unsafe { Pin::new_unchecked(&mut fut) };
pinned
.poll(cx)
.map(|entry| {
let fd = handle_error(entry.unwrap().result())?;
Ok(crate::net::TcpStream(io_uring::types::Fd(fd)))
})
.map(Some)
}
}
impl Stream for Incoming<'_> {
type Item = IoResult<TcpStream>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
AsyncIterator::poll_next(self, cx)
}
}
impl Drop for Incoming<'_> {
fn drop(&mut self) {
cancel(self.plat, self.cb_id);
}
}
pub type Platform = PlatformLinux;