This commit is contained in:
soup 2024-10-21 19:35:35 -04:00
parent 7b26c4b3cc
commit 287d3d4933
No known key found for this signature in database
5 changed files with 50 additions and 68 deletions

View file

@ -1,11 +1 @@
use std::future::Future;
use crate::{aliases::IoResult, io_impl::IoImpl};
pub trait AsyncRead {
fn read(&mut self, buf: &mut [u8]) -> impl Future<Output = IoResult<usize>>;
}
pub trait AsyncWrite {
fn write(&mut self, buf: &[u8]) -> impl Future<Output = IoResult<usize>>;
}
pub use futures_lite::{AsyncRead, AsyncWrite};

View file

@ -1,14 +1,13 @@
use std::{
async_iter::AsyncIterator,
future::Future,
net::SocketAddr,
os::fd::{FromRawFd, IntoRawFd},
pin::Pin,
sync::atomic::{AtomicU64, Ordering},
sync::atomic::{AtomicBool, AtomicU64, Ordering},
task::{Context, Poll, Waker},
};
use futures_lite::Stream;
use futures_lite::{AsyncRead, AsyncReadExt, Stream};
use parking::Parker;
use crate::aliases::IoResult;
@ -189,10 +188,23 @@ impl Drop for TcpListener {
}
}
pub struct TcpStream(io_uring::types::Fd);
impl Drop for TcpStream {
pub struct TcpStream<'a>(&'a IoUring, io_uring::types::Fd, AtomicBool);
impl Drop for TcpStream<'_> {
fn drop(&mut self) {
unsafe { std::net::TcpStream::from_raw_fd(self.0 .0) };
unsafe { std::net::TcpStream::from_raw_fd(self.1 .0) };
}
}
impl AsyncRead for TcpStream<'_> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let sq = unsafe { self.0.uring.submission_shared() };
let entry = io_uring::opcode::Read::new(self.1, buf.as_mut_ptr(), buf.len() as u32)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
}
}
@ -215,7 +227,7 @@ impl IoImpl for IoUring {
Ok(())
}
type TcpListener = io_uring::types::Fd;
type TcpListener = TcpListener;
fn open_tcp_socket(
&self,
addr: std::net::SocketAddr,
@ -226,7 +238,7 @@ impl IoImpl for IoUring {
// the std implementation and cast it to a FileHandle
let listener = std::net::TcpListener::bind(addr);
async { IoResult::Ok(io_uring::types::Fd(listener?.into_raw_fd())) }
async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) }
/*
// let (tx, rx) = async_channel::bounded(1);
@ -279,7 +291,7 @@ impl IoImpl for IoUring {
async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult<SocketAddr> {
// FIXME(Blocking)
let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0) };
let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) };
let addr = listener.local_addr()?;
let _ = listener.into_raw_fd();
@ -295,7 +307,7 @@ impl IoImpl for IoUring {
let cb_id = UserData::new_boxed(tx, true).into_u64();
let entry = io_uring::opcode::AcceptMulti::new(*listener)
let entry = io_uring::opcode::AcceptMulti::new(listener.0)
.build()
.user_data(cb_id);
@ -312,8 +324,14 @@ impl IoImpl for IoUring {
}
}
type TcpStream = TcpStream;
async fn tcp_read(&self, stream: &mut Self::TcpStream, buf: &mut [u8]) -> IoResult<usize> {
type TcpStream<'a> = TcpStream<'a>;
async fn tcp_read(
&self,
stream: &mut Self::TcpStream<'_>,
buf: &mut [u8],
) -> IoResult<usize> {
return stream.read(buf).await;
let (tx, rx) = async_channel::bounded(1);
let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32)
@ -371,8 +389,7 @@ pub fn cancel(io: &IoUring, id: u64) {
io.uring.submission_shared().push(&entry).unwrap();
}
}
impl AsyncIterator for Incoming {
impl Stream for Incoming {
type Item = IoResult<TcpStream>;
fn poll_next(
@ -394,17 +411,6 @@ impl AsyncIterator for Incoming {
}
}
impl Stream for Incoming {
type Item = IoResult<TcpStream>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
AsyncIterator::poll_next(self, cx)
}
}
impl Drop for Incoming {
fn drop(&mut self) {
let io = unsafe { &*self.io };

View file

@ -1,3 +1,5 @@
use futures_lite::AsyncRead;
use crate::aliases::IoResult;
use std::{future::Future, net::SocketAddr, time::Duration};
@ -23,19 +25,19 @@ pub trait IoImpl {
listener: &mut Self::TcpListener,
) -> impl Future<Output = IoResult<Self::Incoming>>;
type TcpStream;
type TcpStream<'a>: AsyncRead + 'a;
fn tcp_read(
&self,
stream: &mut Self::TcpStream,
stream: &mut Self::TcpStream<'_>,
buf: &mut [u8],
) -> impl Future<Output = IoResult<usize>>;
fn tcp_write(
&self,
stream: &mut Self::TcpStream,
stream: &mut Self::TcpStream<'_>,
buf: &[u8],
) -> impl Future<Output = IoResult<usize>>;
fn tcp_connect(
&self,
socket: SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpStream>>;
) -> impl Future<Output = IoResult<Self::TcpStream<'_>>>;
}

View file

@ -1,7 +1,3 @@
#![feature(async_closure)]
#![feature(async_iterator)]
#![feature(never_type)]
mod aliases;
// pub mod fs;
pub mod io;

View file

@ -1,7 +1,8 @@
use std::{
async_iter::AsyncIterator,
net::{SocketAddr, ToSocketAddrs},
ops::DerefMut,
pin::Pin,
task::Poll,
};
use futures_lite::Stream;
@ -19,16 +20,20 @@ impl<'a, I: IoImpl> TcpStream<'a, I> {
}
}
impl<'a, I: IoImpl> AsyncRead for TcpStream<'a, I> {
async fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
self.0.tcp_read(&mut self.1, buf).await
impl<I: IoImpl> AsyncRead for TcpStream<'_, I>
where
I::TcpStream: Unpin + AsyncRead,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.1).poll_read(cx, buf)
}
}
impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> {
async fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
self.0.tcp_write(&mut self.1, buf).await
}
}
pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener);
@ -63,23 +68,6 @@ impl<'a, I: IoImpl> TcpListener<'a, I> {
pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming);
impl<'a, I: IoImpl> AsyncIterator for Incoming<'a, I>
where
I::Incoming: AsyncIterator<Item = IoResult<I::TcpStream>>,
{
type Item = IoResult<TcpStream<'a, I>>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let io = self.0;
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) };
AsyncIterator::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i))))
}
}
impl<'a, I: IoImpl> Stream for Incoming<'a, I>
where
I::Incoming: Stream<Item = IoResult<I::TcpStream>>,