diff --git a/src/frontend/prog.cc b/src/frontend/prog.cc index 76e8197..a87e1c5 100644 --- a/src/frontend/prog.cc +++ b/src/frontend/prog.cc @@ -1 +1,34 @@ -int main() { return 0; } +#include "util/async_file.hh" +#include "util/eventloop.hh" +#include "util/temp_file.hh" + +#include +#include +#include + +using namespace std; + +static Task example( EventLoop& loop ) +{ + TempFile file( "async_sample" ); + FileDescriptor& fd = file.fd(); + + co_await async_write_all( loop, fd, "hello from async\n" ); + + ::lseek( fd.fd_num(), 0, SEEK_SET ); + + array buffer; + simple_string_span span { buffer.data(), buffer.size() }; + const size_t bytes = co_await async_read( loop, fd, span ); + + cout.write( span.data(), bytes ); + cout.flush(); + co_return; +} + +int main() +{ + EventLoop loop; + sync_wait( loop, example( loop ) ); + return 0; +} diff --git a/src/util/async_file.hh b/src/util/async_file.hh new file mode 100644 index 0000000..1e85029 --- /dev/null +++ b/src/util/async_file.hh @@ -0,0 +1,159 @@ +#pragma once + +#include "eventloop.hh" +#include "file_descriptor.hh" +#include "task.hh" +#include +#include +#include + +/** Awaitable that suspends until a file descriptor is readable. */ +class WaitReadable +{ + std::optional rule_ {}; + std::coroutine_handle<> handle_ {}; + EventLoop* loop_; + FileDescriptor* fd_; + +public: + WaitReadable( EventLoop& loop, FileDescriptor& fd ) + : loop_( &loop ) + , fd_( &fd ) + { + } + + WaitReadable( const WaitReadable& ) = delete; + WaitReadable& operator=( const WaitReadable& ) = delete; + + bool await_ready() const noexcept { return false; } + + void await_suspend( std::coroutine_handle<> h ) + { + handle_ = h; + rule_ = loop_->add_rule( + "await read", + Direction::In, + *fd_, + [this] { + rule_->cancel(); + handle_.resume(); + }, + [] { return true; } ); + loop_ = nullptr; + fd_ = nullptr; + } + + void await_resume() const noexcept {} +}; + +/** Awaitable that suspends until a file descriptor is writeable. */ +class WaitWriteable +{ + std::optional rule_ {}; + std::coroutine_handle<> handle_ {}; + EventLoop* loop_; + FileDescriptor* fd_; + +public: + WaitWriteable( EventLoop& loop, FileDescriptor& fd ) + : loop_( &loop ) + , fd_( &fd ) + { + } + + WaitWriteable( const WaitWriteable& ) = delete; + WaitWriteable& operator=( const WaitWriteable& ) = delete; + + bool await_ready() const noexcept { return false; } + + void await_suspend( std::coroutine_handle<> h ) + { + handle_ = h; + rule_ = loop_->add_rule( + "await write", + Direction::Out, + *fd_, + [this] { + rule_->cancel(); + handle_.resume(); + }, + [] { return true; } ); + loop_ = nullptr; + fd_ = nullptr; + } + + void await_resume() const noexcept {} +}; + +inline Task async_read( EventLoop& loop, FileDescriptor& fd, simple_string_span buffer ) +{ + while ( true ) { + const size_t bytes = fd.read( buffer ); + if ( bytes > 0 || fd.eof() ) { + co_return bytes; + } + co_await WaitReadable { loop, fd }; + } +} + +inline Task async_write( EventLoop& loop, FileDescriptor& fd, std::string_view buffer ) +{ + while ( true ) { + const size_t bytes = fd.write( buffer ); + if ( bytes > 0 || buffer.empty() ) { + co_return bytes; + } + co_await WaitWriteable { loop, fd }; + } +} + +inline Task async_write_all( EventLoop& loop, FileDescriptor& fd, std::string_view buffer ) +{ + while ( not buffer.empty() ) { + const size_t written = co_await async_write( loop, fd, buffer ); + buffer.remove_prefix( written ); + } + co_return; +} + +/** Run the event loop until the given task finishes and return its result. */ +template +T sync_wait( EventLoop& loop, Task t ) +{ + using result_t = std::conditional_t, std::monostate, std::optional>; + result_t result {}; + std::exception_ptr ep; + bool done = false; + + auto wrapper = [&]() -> Task { + try { + if constexpr ( std::is_void_v ) { + co_await t; + } else { + result = co_await t; + } + } catch ( ... ) { + ep = std::current_exception(); + } + done = true; + co_return; + }(); + + wrapper.start(); + + while ( not done ) { + if ( loop.wait_next_event( -1 ) == EventLoop::Result::Exit ) { + break; + } + } + + if ( ep ) { + std::rethrow_exception( ep ); + } + + if constexpr ( std::is_void_v ) { + return; + } else { + return *result; + } +} diff --git a/src/util/task.hh b/src/util/task.hh new file mode 100644 index 0000000..29ed5a2 --- /dev/null +++ b/src/util/task.hh @@ -0,0 +1,181 @@ +#pragma once + +#include +#include +#include +#include + +/** A simple coroutine task similar to Python's awaitables. */ + +template +class Task +{ +public: + struct promise_type + { + std::optional value_ {}; + std::exception_ptr exception_ {}; + std::coroutine_handle<> continuation_ {}; + + Task get_return_object() { return Task { std::coroutine_handle::from_promise( *this ) }; } + + std::suspend_always initial_suspend() const noexcept { return {}; } + + struct FinalAwaiter + { + bool await_ready() const noexcept { return false; } + void await_suspend( std::coroutine_handle h ) const noexcept + { + if ( h.promise().continuation_ ) { + h.promise().continuation_.resume(); + } + } + void await_resume() const noexcept {} + }; + + auto final_suspend() const noexcept { return FinalAwaiter {}; } + + template + void return_value( U&& v ) + { + value_ = std::forward( v ); + } + + void unhandled_exception() { exception_ = std::current_exception(); } + }; + + using handle_type = std::coroutine_handle; + + explicit Task( handle_type h ) + : handle_( h ) + { + } + + Task( Task&& other ) noexcept + : handle_( other.handle_ ) + { + other.handle_ = nullptr; + } + + Task( const Task& ) = delete; + Task& operator=( const Task& ) = delete; + + ~Task() + { + if ( handle_ ) { + handle_.destroy(); + } + } + + void start() + { + if ( handle_ ) { + handle_.resume(); + } + } + + bool done() const { return not handle_ || handle_.done(); } + + bool await_ready() const noexcept { return done(); } + + void await_suspend( std::coroutine_handle<> h ) noexcept + { + handle_.promise().continuation_ = h; + handle_.resume(); + } + + T await_resume() + { + if ( handle_.promise().exception_ ) { + std::rethrow_exception( handle_.promise().exception_ ); + } + return std::move( *handle_.promise().value_ ); + } + +private: + handle_type handle_ { nullptr }; +}; + +// specialization for void +template<> +class Task +{ +public: + struct promise_type + { + std::exception_ptr exception_ {}; + std::coroutine_handle<> continuation_ {}; + + Task get_return_object() { return Task { std::coroutine_handle::from_promise( *this ) }; } + + std::suspend_always initial_suspend() const noexcept { return {}; } + + struct FinalAwaiter + { + bool await_ready() const noexcept { return false; } + void await_suspend( std::coroutine_handle h ) const noexcept + { + if ( h.promise().continuation_ ) { + h.promise().continuation_.resume(); + } + } + void await_resume() const noexcept {} + }; + + auto final_suspend() const noexcept { return FinalAwaiter {}; } + + void return_void() {} + + void unhandled_exception() { exception_ = std::current_exception(); } + }; + + using handle_type = std::coroutine_handle; + + explicit Task( handle_type h ) + : handle_( h ) + { + } + + Task( Task&& other ) noexcept + : handle_( other.handle_ ) + { + other.handle_ = nullptr; + } + + Task( const Task& ) = delete; + Task& operator=( const Task& ) = delete; + + ~Task() + { + if ( handle_ ) { + handle_.destroy(); + } + } + + void start() + { + if ( handle_ ) { + handle_.resume(); + } + } + + bool done() const { return not handle_ || handle_.done(); } + + bool await_ready() const noexcept { return done(); } + + void await_suspend( std::coroutine_handle<> h ) noexcept + { + handle_.promise().continuation_ = h; + handle_.resume(); + } + + void await_resume() + { + if ( handle_.promise().exception_ ) { + std::rethrow_exception( handle_.promise().exception_ ); + } + } + +private: + handle_type handle_ { nullptr }; +};