From 287d3d49337e81f7a6dfdc9839168b775472b446 Mon Sep 17 00:00:00 2001 From: soup Date: Mon, 21 Oct 2024 19:35:35 -0400 Subject: [PATCH] . --- src/io.rs | 12 +------- src/io_impl/io_uring/mod.rs | 56 ++++++++++++++++++++----------------- src/io_impl/mod.rs | 10 ++++--- src/lib.rs | 4 --- src/net.rs | 36 ++++++++---------------- 5 files changed, 50 insertions(+), 68 deletions(-) diff --git a/src/io.rs b/src/io.rs index c86c9ea..4ac64d0 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,11 +1 @@ -use std::future::Future; - -use crate::{aliases::IoResult, io_impl::IoImpl}; - -pub trait AsyncRead { - fn read(&mut self, buf: &mut [u8]) -> impl Future>; -} - -pub trait AsyncWrite { - fn write(&mut self, buf: &[u8]) -> impl Future>; -} +pub use futures_lite::{AsyncRead, AsyncWrite}; diff --git a/src/io_impl/io_uring/mod.rs b/src/io_impl/io_uring/mod.rs index 80d2690..bd1ec0f 100644 --- a/src/io_impl/io_uring/mod.rs +++ b/src/io_impl/io_uring/mod.rs @@ -1,14 +1,13 @@ use std::{ - async_iter::AsyncIterator, future::Future, net::SocketAddr, os::fd::{FromRawFd, IntoRawFd}, pin::Pin, - sync::atomic::{AtomicU64, Ordering}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, task::{Context, Poll, Waker}, }; -use futures_lite::Stream; +use futures_lite::{AsyncRead, AsyncReadExt, Stream}; use parking::Parker; use crate::aliases::IoResult; @@ -189,10 +188,23 @@ impl Drop for TcpListener { } } -pub struct TcpStream(io_uring::types::Fd); -impl Drop for TcpStream { +pub struct TcpStream<'a>(&'a IoUring, io_uring::types::Fd, AtomicBool); +impl Drop for TcpStream<'_> { fn drop(&mut self) { - unsafe { std::net::TcpStream::from_raw_fd(self.0 .0) }; + unsafe { std::net::TcpStream::from_raw_fd(self.1 .0) }; + } +} +impl AsyncRead for TcpStream<'_> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let sq = unsafe { self.0.uring.submission_shared() }; + + let entry = io_uring::opcode::Read::new(self.1, buf.as_mut_ptr(), buf.len() as u32) + .build() + .user_data(UserData::new_boxed(tx, false).into_u64()); } } @@ -215,7 +227,7 @@ impl IoImpl for IoUring { Ok(()) } - type TcpListener = io_uring::types::Fd; + type TcpListener = TcpListener; fn open_tcp_socket( &self, addr: std::net::SocketAddr, @@ -226,7 +238,7 @@ impl IoImpl for IoUring { // the std implementation and cast it to a FileHandle let listener = std::net::TcpListener::bind(addr); - async { IoResult::Ok(io_uring::types::Fd(listener?.into_raw_fd())) } + async { IoResult::Ok(TcpListener(io_uring::types::Fd(listener?.into_raw_fd()))) } /* // let (tx, rx) = async_channel::bounded(1); @@ -279,7 +291,7 @@ impl IoImpl for IoUring { async fn listener_local_addr(&self, listener: &Self::TcpListener) -> IoResult { // FIXME(Blocking) - let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0) }; + let listener = unsafe { std::net::TcpListener::from_raw_fd(listener.0 .0) }; let addr = listener.local_addr()?; let _ = listener.into_raw_fd(); @@ -295,7 +307,7 @@ impl IoImpl for IoUring { let cb_id = UserData::new_boxed(tx, true).into_u64(); - let entry = io_uring::opcode::AcceptMulti::new(*listener) + let entry = io_uring::opcode::AcceptMulti::new(listener.0) .build() .user_data(cb_id); @@ -312,8 +324,14 @@ impl IoImpl for IoUring { } } - type TcpStream = TcpStream; - async fn tcp_read(&self, stream: &mut Self::TcpStream, buf: &mut [u8]) -> IoResult { + type TcpStream<'a> = TcpStream<'a>; + async fn tcp_read( + &self, + stream: &mut Self::TcpStream<'_>, + buf: &mut [u8], + ) -> IoResult { + return stream.read(buf).await; + let (tx, rx) = async_channel::bounded(1); let entry = io_uring::opcode::Read::new(stream.0, buf.as_mut_ptr(), buf.len() as u32) @@ -371,8 +389,7 @@ pub fn cancel(io: &IoUring, id: u64) { io.uring.submission_shared().push(&entry).unwrap(); } } - -impl AsyncIterator for Incoming { +impl Stream for Incoming { type Item = IoResult; fn poll_next( @@ -394,17 +411,6 @@ impl AsyncIterator for Incoming { } } -impl Stream for Incoming { - type Item = IoResult; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - AsyncIterator::poll_next(self, cx) - } -} - impl Drop for Incoming { fn drop(&mut self) { let io = unsafe { &*self.io }; diff --git a/src/io_impl/mod.rs b/src/io_impl/mod.rs index aee2e26..606b113 100644 --- a/src/io_impl/mod.rs +++ b/src/io_impl/mod.rs @@ -1,3 +1,5 @@ +use futures_lite::AsyncRead; + use crate::aliases::IoResult; use std::{future::Future, net::SocketAddr, time::Duration}; @@ -23,19 +25,19 @@ pub trait IoImpl { listener: &mut Self::TcpListener, ) -> impl Future>; - type TcpStream; + type TcpStream<'a>: AsyncRead + 'a; fn tcp_read( &self, - stream: &mut Self::TcpStream, + stream: &mut Self::TcpStream<'_>, buf: &mut [u8], ) -> impl Future>; fn tcp_write( &self, - stream: &mut Self::TcpStream, + stream: &mut Self::TcpStream<'_>, buf: &[u8], ) -> impl Future>; fn tcp_connect( &self, socket: SocketAddr, - ) -> impl Future>; + ) -> impl Future>>; } diff --git a/src/lib.rs b/src/lib.rs index eda836e..adbb35d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,3 @@ -#![feature(async_closure)] -#![feature(async_iterator)] -#![feature(never_type)] - mod aliases; // pub mod fs; pub mod io; diff --git a/src/net.rs b/src/net.rs index 3ec6d3e..4731e35 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,7 +1,8 @@ use std::{ - async_iter::AsyncIterator, net::{SocketAddr, ToSocketAddrs}, + ops::DerefMut, pin::Pin, + task::Poll, }; use futures_lite::Stream; @@ -19,16 +20,20 @@ impl<'a, I: IoImpl> TcpStream<'a, I> { } } -impl<'a, I: IoImpl> AsyncRead for TcpStream<'a, I> { - async fn read(&mut self, buf: &mut [u8]) -> IoResult { - self.0.tcp_read(&mut self.1, buf).await +impl AsyncRead for TcpStream<'_, I> +where + I::TcpStream: Unpin + AsyncRead, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.1).poll_read(cx, buf) } } impl<'a, I: IoImpl> AsyncWrite for TcpStream<'a, I> { - async fn write(&mut self, buf: &[u8]) -> IoResult { - self.0.tcp_write(&mut self.1, buf).await - } } pub struct TcpListener<'a, I: IoImpl>(&'a I, I::TcpListener); @@ -63,23 +68,6 @@ impl<'a, I: IoImpl> TcpListener<'a, I> { pub struct Incoming<'a, I: IoImpl>(&'a I, I::Incoming); -impl<'a, I: IoImpl> AsyncIterator for Incoming<'a, I> -where - I::Incoming: AsyncIterator>, -{ - type Item = IoResult>; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let io = self.0; - let inner = unsafe { self.map_unchecked_mut(|s| &mut s.1) }; - - AsyncIterator::poll_next(inner, cx).map(|o| o.map(|r| r.map(|i| TcpStream(io, i)))) - } -} - impl<'a, I: IoImpl> Stream for Incoming<'a, I> where I::Incoming: Stream>,