Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/frontend/prog.cc
Original file line number Diff line number Diff line change
@@ -1 +1,34 @@
int main() { return 0; }
#include "util/async_file.hh"
#include "util/eventloop.hh"
#include "util/temp_file.hh"

#include <array>
#include <iostream>
#include <unistd.h>

using namespace std;

static Task<void> example( EventLoop& loop )
{
TempFile file( "async_sample" );
FileDescriptor& fd = file.fd();

co_await async_write_all( loop, fd, "hello from async\n" );

::lseek( fd.fd_num(), 0, SEEK_SET );

array<char, 64> buffer;
simple_string_span span { buffer.data(), buffer.size() };
const size_t bytes = co_await async_read( loop, fd, span );

cout.write( span.data(), bytes );
cout.flush();
co_return;
}

int main()
{
EventLoop loop;
sync_wait( loop, example( loop ) );
return 0;
}
159 changes: 159 additions & 0 deletions src/util/async_file.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#pragma once

#include "eventloop.hh"
#include "file_descriptor.hh"
#include "task.hh"
#include <optional>
#include <type_traits>
#include <variant>

/** Awaitable that suspends until a file descriptor is readable. */
class WaitReadable
{
std::optional<EventLoop::RuleHandle> rule_ {};
std::coroutine_handle<> handle_ {};
EventLoop* loop_;
FileDescriptor* fd_;

public:
WaitReadable( EventLoop& loop, FileDescriptor& fd )
: loop_( &loop )
, fd_( &fd )
{
}

WaitReadable( const WaitReadable& ) = delete;
WaitReadable& operator=( const WaitReadable& ) = delete;

bool await_ready() const noexcept { return false; }

void await_suspend( std::coroutine_handle<> h )
{
handle_ = h;
rule_ = loop_->add_rule(
"await read",
Direction::In,
*fd_,
[this] {
rule_->cancel();
handle_.resume();
},
[] { return true; } );
loop_ = nullptr;
fd_ = nullptr;
}

void await_resume() const noexcept {}
};

/** Awaitable that suspends until a file descriptor is writeable. */
class WaitWriteable
{
std::optional<EventLoop::RuleHandle> rule_ {};
std::coroutine_handle<> handle_ {};
EventLoop* loop_;
FileDescriptor* fd_;

public:
WaitWriteable( EventLoop& loop, FileDescriptor& fd )
: loop_( &loop )
, fd_( &fd )
{
}

WaitWriteable( const WaitWriteable& ) = delete;
WaitWriteable& operator=( const WaitWriteable& ) = delete;

bool await_ready() const noexcept { return false; }

void await_suspend( std::coroutine_handle<> h )
{
handle_ = h;
rule_ = loop_->add_rule(
"await write",
Direction::Out,
*fd_,
[this] {
rule_->cancel();
handle_.resume();
},
[] { return true; } );
loop_ = nullptr;
fd_ = nullptr;
}

void await_resume() const noexcept {}
};

inline Task<size_t> async_read( EventLoop& loop, FileDescriptor& fd, simple_string_span buffer )
{
while ( true ) {
const size_t bytes = fd.read( buffer );
if ( bytes > 0 || fd.eof() ) {
co_return bytes;
}
co_await WaitReadable { loop, fd };
}
}

inline Task<size_t> async_write( EventLoop& loop, FileDescriptor& fd, std::string_view buffer )
{
while ( true ) {
const size_t bytes = fd.write( buffer );
if ( bytes > 0 || buffer.empty() ) {
co_return bytes;
}
co_await WaitWriteable { loop, fd };
}
}

inline Task<void> async_write_all( EventLoop& loop, FileDescriptor& fd, std::string_view buffer )
{
while ( not buffer.empty() ) {
const size_t written = co_await async_write( loop, fd, buffer );
buffer.remove_prefix( written );
}
co_return;
}

/** Run the event loop until the given task finishes and return its result. */
template<typename T>
T sync_wait( EventLoop& loop, Task<T> t )
{
using result_t = std::conditional_t<std::is_void_v<T>, std::monostate, std::optional<T>>;
result_t result {};
std::exception_ptr ep;
bool done = false;

auto wrapper = [&]() -> Task<void> {
try {
if constexpr ( std::is_void_v<T> ) {
co_await t;
} else {
result = co_await t;
}
} catch ( ... ) {
ep = std::current_exception();
}
done = true;
co_return;
}();

wrapper.start();

while ( not done ) {
if ( loop.wait_next_event( -1 ) == EventLoop::Result::Exit ) {
break;
}
}

if ( ep ) {
std::rethrow_exception( ep );
}

if constexpr ( std::is_void_v<T> ) {
return;
} else {
return *result;
}
}
181 changes: 181 additions & 0 deletions src/util/task.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#pragma once

#include <coroutine>
#include <exception>
#include <optional>
#include <utility>

/** A simple coroutine task similar to Python's awaitables. */

template<typename T>
class Task
{
public:
struct promise_type
{
std::optional<T> value_ {};
std::exception_ptr exception_ {};
std::coroutine_handle<> continuation_ {};

Task get_return_object() { return Task { std::coroutine_handle<promise_type>::from_promise( *this ) }; }

std::suspend_always initial_suspend() const noexcept { return {}; }

struct FinalAwaiter
{
bool await_ready() const noexcept { return false; }
void await_suspend( std::coroutine_handle<promise_type> h ) const noexcept
{
if ( h.promise().continuation_ ) {
h.promise().continuation_.resume();
}
}
void await_resume() const noexcept {}
};

auto final_suspend() const noexcept { return FinalAwaiter {}; }

template<typename U>
void return_value( U&& v )
{
value_ = std::forward<U>( v );
}

void unhandled_exception() { exception_ = std::current_exception(); }
};

using handle_type = std::coroutine_handle<promise_type>;

explicit Task( handle_type h )
: handle_( h )
{
}

Task( Task&& other ) noexcept
: handle_( other.handle_ )
{
other.handle_ = nullptr;
}

Task( const Task& ) = delete;
Task& operator=( const Task& ) = delete;

~Task()
{
if ( handle_ ) {
handle_.destroy();
}
}

void start()
{
if ( handle_ ) {
handle_.resume();
}
}

bool done() const { return not handle_ || handle_.done(); }

bool await_ready() const noexcept { return done(); }

void await_suspend( std::coroutine_handle<> h ) noexcept
{
handle_.promise().continuation_ = h;
handle_.resume();
}

T await_resume()
{
if ( handle_.promise().exception_ ) {
std::rethrow_exception( handle_.promise().exception_ );
}
return std::move( *handle_.promise().value_ );
}

private:
handle_type handle_ { nullptr };
};

// specialization for void
template<>
class Task<void>
{
public:
struct promise_type
{
std::exception_ptr exception_ {};
std::coroutine_handle<> continuation_ {};

Task get_return_object() { return Task { std::coroutine_handle<promise_type>::from_promise( *this ) }; }

std::suspend_always initial_suspend() const noexcept { return {}; }

struct FinalAwaiter
{
bool await_ready() const noexcept { return false; }
void await_suspend( std::coroutine_handle<promise_type> h ) const noexcept
{
if ( h.promise().continuation_ ) {
h.promise().continuation_.resume();
}
}
void await_resume() const noexcept {}
};

auto final_suspend() const noexcept { return FinalAwaiter {}; }

void return_void() {}

void unhandled_exception() { exception_ = std::current_exception(); }
};

using handle_type = std::coroutine_handle<promise_type>;

explicit Task( handle_type h )
: handle_( h )
{
}

Task( Task&& other ) noexcept
: handle_( other.handle_ )
{
other.handle_ = nullptr;
}

Task( const Task& ) = delete;
Task& operator=( const Task& ) = delete;

~Task()
{
if ( handle_ ) {
handle_.destroy();
}
}

void start()
{
if ( handle_ ) {
handle_.resume();
}
}

bool done() const { return not handle_ || handle_.done(); }

bool await_ready() const noexcept { return done(); }

void await_suspend( std::coroutine_handle<> h ) noexcept
{
handle_.promise().continuation_ = h;
handle_.resume();
}

void await_resume()
{
if ( handle_.promise().exception_ ) {
std::rethrow_exception( handle_.promise().exception_ );
}
}

private:
handle_type handle_ { nullptr };
};