This commit is contained in:
soup 2024-10-21 19:35:35 -04:00
parent 7b26c4b3cc
commit 287d3d4933
No known key found for this signature in database
5 changed files with 50 additions and 68 deletions

View file

@ -1,11 +1 @@
use std::future::Future; pub use futures_lite::{AsyncRead, AsyncWrite};
use crate::{aliases::IoResult, io_impl::IoImpl};
pub trait AsyncRead {
fn read(&mut self, buf: &mut [u8]) -> impl Future<Output = IoResult<usize>>;
}
pub trait AsyncWrite {
fn write(&mut self, buf: &[u8]) -> impl Future<Output = IoResult<usize>>;
}

View file

@ -1,14 +1,13 @@
use std::{ use std::{
async_iter::AsyncIterator,
future::Future, future::Future,
net::SocketAddr, net::SocketAddr,
os::fd::{FromRawFd, IntoRawFd}, os::fd::{FromRawFd, IntoRawFd},
pin::Pin, pin::Pin,
sync::atomic::{AtomicU64, Ordering}, sync::atomic::{AtomicBool, AtomicU64, Ordering},
task::{Context, Poll, Waker}, task::{Context, Poll, Waker},
}; };
use futures_lite::Stream; use futures_lite::{AsyncRead, AsyncReadExt, Stream};
use parking::Parker; use parking::Parker;
use crate::aliases::IoResult; use crate::aliases::IoResult;
@ -189,10 +188,23 @@ impl Drop for TcpListener {
} }
} }
pub struct TcpStream(io_uring::types::Fd); pub struct TcpStream<'a>(&'a IoUring, io_uring::types::Fd, AtomicBool);
impl Drop for TcpStream { impl Drop for TcpStream<'_> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { std::net::TcpStream::from_raw_fd(self.0 .0) }; unsafe { std::net::TcpStream::from_raw_fd(self.1 .0) };
}
}
impl AsyncRead for TcpStream<'_> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let sq = unsafe { self.0.uring.submission_shared() };
let entry = io_uring::opcode::Read::new(self.1, buf.as_mut_ptr(), buf.len() as u32)
.build()
.user_data(UserData::new_boxed(tx, false).into_u64());
} }
} }
@ -215,7 +227,7 @@ impl IoImpl for IoUring {
Ok(()) Ok(())
} }
type TcpListener = io_uring::types::Fd; type TcpListener = TcpListener;
fn open_tcp_socket( fn open_tcp_socket(
&self, &self,
addr: std::net::SocketAddr, addr: std::net::SocketAddr,
@ -226,7 +238,7 @@ impl IoImpl for IoUring {
// the std implementation and cast it to a FileHandle // the std implementation and cast it to a FileHandle
let listener = std::net::TcpListener::bind(addr); let listener = std::net::TcpListener::bind(addr);
async { IoResult::Ok(io_uring::types::Fd(listener?.into_raw_fd())) } async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) }
/* /*
// let (tx, rx) = async_channel::bounded(1); // let (tx, rx) = async_channel::bounded(1);
@ -279,7 +291,7 @@ impl IoImpl for IoUring {
async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult<SocketAddr> { async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult<SocketAddr> {
// FIXME(Blocking) // FIXME(Blocking)
let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0) }; let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) };
let addr = listener.local_addr()?; let addr = listener.local_addr()?;
let _ = listener.into_raw_fd(); let _ = listener.into_raw_fd();
@ -295,7 +307,7 @@ impl IoImpl for IoUring {
let cb_id = UserData::new_boxed(tx, true).into_u64(); let cb_id = UserData::new_boxed(tx, true).into_u64();
let entry = io_uring::opcode::AcceptMulti::new(*listener) let entry = io_uring::opcode::AcceptMulti::new(listener.0)
.build() .build()
.user_data(cb_id); .user_data(cb_id);
@ -312,8 +324,14 @@ impl IoImpl for IoUring {
} }
} }
type TcpStream = TcpStream; type TcpStream<'a> = TcpStream<'a>;
async fn tcp_read(&self, stream: &mut Self::TcpStream, buf: &mut [u8]) -> IoResult<usize> { async fn tcp_read(
&self,
stream: &mut Self::TcpStream<'_>,
buf: &mut [u8],
) -> IoResult<usize> {
return stream.read(buf).await;
let (tx, rx) = async_channel::bounded(1); let (tx, rx) = async_channel::bounded(1);
let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32) let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32)
@ -371,8 +389,7 @@ pub fn cancel(io: &IoUring, id: u64) {
io.uring.submission_shared().push(&entry).unwrap(); io.uring.submission_shared().push(&entry).unwrap();
} }
} }
impl Stream for Incoming {
impl AsyncIterator for Incoming {
type Item = IoResult<TcpStream>; type Item = IoResult<TcpStream>;
fn poll_next( fn poll_next(
@ -394,17 +411,6 @@ impl AsyncIterator for Incoming {
} }
} }
impl Stream for Incoming {
type Item = IoResult<TcpStream>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
AsyncIterator::poll_next(self, cx)
}
}
impl Drop for Incoming { impl Drop for Incoming {
fn drop(&mut self) { fn drop(&mut self) {
let io = unsafe { &*self.io }; let io = unsafe { &*self.io };

View file

@ -1,3 +1,5 @@
use futures_lite::AsyncRead;
use crate::aliases::IoResult; use crate::aliases::IoResult;
use std::{future::Future, net::SocketAddr, time::Duration}; use std::{future::Future, net::SocketAddr, time::Duration};
@ -23,19 +25,19 @@ pub trait IoImpl {
listener: &mut Self::TcpListener, listener: &mut Self::TcpListener,
) -> impl Future<Output = IoResult<Self::Incoming>>; ) -> impl Future<Output = IoResult<Self::Incoming>>;
type TcpStream; type TcpStream<'a>: AsyncRead + 'a;
fn tcp_read( fn tcp_read(
&self, &self,
stream: &mut Self::TcpStream, stream: &mut Self::TcpStream<'_>,
buf: &mut [u8], buf: &mut [u8],
) -> impl Future<Output = IoResult<usize>>; ) -> impl Future<Output = IoResult<usize>>;
fn tcp_write( fn tcp_write(
&self, &self,
stream: &mut Self::TcpStream, stream: &mut Self::TcpStream<'_>,
buf: &[u8], buf: &[u8],
) -> impl Future<Output = IoResult<usize>>; ) -> impl Future<Output = IoResult<usize>>;
fn tcp_connect( fn tcp_connect(
&self, &self,
socket: SocketAddr, socket: SocketAddr,
) -> impl Future<Output = IoResult<Self::TcpStream>>; ) -> impl Future<Output = IoResult<Self::TcpStream<'_>>>;
} }

View file

@ -1,7 +1,3 @@
#![feature(async_closure)]
#![feature(async_iterator)]
#![feature(never_type)]
mod aliases; mod aliases;
// pub mod fs; // pub mod fs;
pub mod io; pub mod io;

View file

@ -1,7 +1,8 @@
use std::{ use std::{
async_iter::AsyncIterator,
net::{SocketAddr, ToSocketAddrs}, net::{SocketAddr, ToSocketAddrs},
ops::DerefMut,
pin::Pin, pin::Pin,
task::Poll,
}; };
use futures_lite::Stream; use futures_lite::Stream;
@ -19,16 +20,20 @@ impl<'a, I: IoImpl> TcpStream<'a, I> {
} }
} }
impl<'a, I: IoImpl> AsyncRead for TcpStream<'a, I> { impl<I: IoImpl> AsyncRead for TcpStream<'_, I>
async fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { where
self.0.tcp_read(&mut self.1, buf).await I::TcpStream: Unpin + AsyncRead,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.1).poll_read(cx, buf)
} }
} }
impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> { impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> {
async fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
self.0.tcp_write(&mut self.1, buf).await
}
} }
pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener); pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener);
@ -63,23 +68,6 @@ impl<'a, I: IoImpl> TcpListener<'a, I> {
pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming); pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming);
impl<'a, I: IoImpl> AsyncIterator for Incoming<'a, I>
where
I::Incoming: AsyncIterator<Item = IoResult<I::TcpStream>>,
{
type Item = IoResult<TcpStream<'a, I>>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let io = self.0;
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) };
AsyncIterator::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i))))
}
}
impl<'a, I: IoImpl> Stream for Incoming<'a, I> impl<'a, I: IoImpl> Stream for Incoming<'a, I>
where where
I::Incoming: Stream<Item = IoResult<I::TcpStream>>, I::Incoming: Stream<Item = IoResult<I::TcpStream>>,