diff --git a/tests/server/api/v1/test_provider.py b/tests/server/api/v1/test_provider.py index fe9872e72..656cc36e0 100644 --- a/tests/server/api/v1/test_provider.py +++ b/tests/server/api/v1/test_provider.py @@ -6,7 +6,12 @@ from waterbutler.core.path import WaterButlerPath from waterbutler.server.api.v1.provider import ProviderHandler, list_or_value -from tests.utils import MockCoroutine, MockStream, MockWriter, MockProvider +from tests.utils import ( + MockCoroutine, + MockProvider, + MockRequestStream, + MockStream +) from tests.server.api.v1.fixtures import (http_request, handler, patch_auth_handler, handler_auth, patch_make_provider_core) @@ -130,13 +135,12 @@ async def test_data_received(self, handler): @pytest.mark.asyncio async def test_data_received_stream(self, handler): handler.path = WaterButlerPath('/folder/') - handler.stream = MockStream() - handler.writer = MockWriter() + handler.stream = MockRequestStream(handler.request) await handler.data_received(b'1234567890') assert handler.bytes_uploaded == 10 - handler.writer.write.assert_called_once_with(b'1234567890') + assert handler.stream._buffer == b'1234567890' @pytest.mark.asyncio async def test_on_finish_download_file(self, handler): diff --git a/tests/utils.py b/tests/utils.py index 42b467c75..ea72e33ac 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,6 +14,7 @@ from waterbutler.core import metadata, provider from waterbutler.core.path import WaterButlerPath from waterbutler.core.streams.file import FileStreamReader +from waterbutler.core.streams.http import RequestStreamReader class MockCoroutine(mock.Mock): @@ -75,6 +76,14 @@ def __init__(self): super().__init__(tempfile.TemporaryFile()) +class MockRequestStream(RequestStreamReader): + content_type = 'application/octet-stream' + size = 1334 + + def __init__(self, request): + super().__init__(request) + + class MockRequestBody(concurrent.Future): def __await__(self): diff --git a/waterbutler/core/streams/http.py b/waterbutler/core/streams/http.py index fb2955344..5b0e14925 100644 --- a/waterbutler/core/streams/http.py +++ b/waterbutler/core/streams/http.py @@ -1,8 +1,19 @@ import uuid import asyncio +from asyncio import Future +from asyncio.streams import _DEFAULT_LIMIT + + +from tornado import gen, ioloop from waterbutler.core.streams.base import BaseStream, MultiStream, StringStream +import logging +logger = logging.getLogger(__name__) + + +print(_DEFAULT_LIMIT) + class FormDataStream(MultiStream): """A child of MultiSteam used to create stream friendly multipart form data requests. @@ -173,26 +184,81 @@ async def _read(self, size): return chunk +class WritePendingError(): + pass + + class RequestStreamReader(BaseStream): - def __init__(self, request, inner): + def __init__(self, request, max_buffer_size=_DEFAULT_LIMIT): super().__init__() - self.inner = inner self.request = request + self.max_buffer_size = max_buffer_size + self.pending_feed = None @property def size(self): return int(self.request.headers.get('Content-Length')) - def at_eof(self): - return self.inner.at_eof() + def feed_data(self, chunk, timeout=None): + assert not self._eof, 'feed_data after feed_eof' + # Trying to write to the stream from several coroutines doesn't seem + # like a great idea, so limit it to one event loop, one coroutine. + if self.pending_feed is not None: + # Make sure the pending future is complete. + future, chunk = self.pending_feed + if not future.done(): + raise WritePendingError('Another coroutine is alreading waiting to write to this stream.') + self.pending_feed = None - async def _read(self, size): - if self.inner.at_eof(): - return b'' - if size < 0: - return (await self.inner.read(size)) - try: - return (await self.inner.readexactly(size)) - except asyncio.IncompleteReadError as e: - return e.partial + if not chunk: + # Nothing to add to the stream. + return + + future = Future() + + if len(self._buffer) > self.max_buffer_size: + # The buffer is full, and no more can be written to it until some + # of it has been consumed. We will always be able to write + # something to the buffer, because we don't check it for overflow. + # (Default limit still remains) + assert self.pending_feed is None + self.pending_feed = (future, chunk) + + future.add_done_callback(lambda _: self.clear_pending_feed()) + + if timeout: + # Let a caller specify a maximum amount of time to wait. + def on_timeout(): + if not future.done(): + future.set_exception(gen.TimeoutError()) + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) + + else: + # Sets the result of the Future. + self.feed_nowait(future, chunk) + + # Give the future back for it to get awaited somewhere. + return future + + def clear_pending_feed(self): + self.pending_feed = None + + def feed_nowait(self, future, chunk): + # We can put the chunk on the buffer. + self._buffer.extend(chunk) + future.set_result(None) + self.clear_pending_feed() + + # Let a waiting read know there's data. + self._wakeup_waiter() + + async def _read(self, n=-1): + data = await asyncio.StreamReader.read(self, n) + if self.pending_feed is not None and len(self._buffer) <= self.max_buffer_size: + future, chunk = self.pending_feed + if not future.done(): + self.feed_nowait(future, chunk) + return data diff --git a/waterbutler/server/api/v0/crud.py b/waterbutler/server/api/v0/crud.py index c1a52720b..9ac60f11a 100644 --- a/waterbutler/server/api/v0/crud.py +++ b/waterbutler/server/api/v0/crud.py @@ -1,5 +1,4 @@ import os -import socket import asyncio from http import HTTPStatus @@ -35,12 +34,7 @@ async def prepare(self): async def prepare_stream(self): if self.request.method in self.STREAM_METHODS: - self.rsock, self.wsock = socket.socketpair() - - self.reader, _ = await asyncio.open_unix_connection(sock=self.rsock) - _, self.writer = await asyncio.open_unix_connection(sock=self.wsock) - - self.stream = RequestStreamReader(self.request, self.reader) + self.stream = RequestStreamReader(self.request) self.uploader = asyncio.ensure_future(self.provider.upload(self.stream, **self.arguments)) @@ -51,8 +45,7 @@ async def data_received(self, chunk): """Note: Only called during uploads.""" self.bytes_uploaded += len(chunk) if self.stream: - self.writer.write(chunk) - await self.writer.drain() + await self.stream.feed_data(chunk) async def get(self): """Download a file.""" @@ -110,7 +103,7 @@ async def post(self): async def put(self): """Upload a file.""" - self.writer.write_eof() + self.stream.feed_eof() metadata, created = await self.uploader diff --git a/waterbutler/server/api/v1/provider/__init__.py b/waterbutler/server/api/v1/provider/__init__.py index dfc8787bf..d31022e42 100644 --- a/waterbutler/server/api/v1/provider/__init__.py +++ b/waterbutler/server/api/v1/provider/__init__.py @@ -1,5 +1,4 @@ import uuid -import socket import asyncio import logging from http import HTTPStatus @@ -123,21 +122,15 @@ async def data_received(self, chunk): """Note: Only called during uploads.""" self.bytes_uploaded += len(chunk) if self.stream: - self.writer.write(chunk) - await self.writer.drain() + await self.stream.feed_data(chunk) else: self.body += chunk async def prepare_stream(self): """Sets up an asyncio pipe from client to server - Only called on PUT when path is to a file + Only called on PUT when path is to a file. """ - self.rsock, self.wsock = socket.socketpair() - - self.reader, _ = await asyncio.open_unix_connection(sock=self.rsock) - _, self.writer = await asyncio.open_unix_connection(sock=self.wsock) - - self.stream = RequestStreamReader(self.request, self.reader) + self.stream = RequestStreamReader(self.request) self.uploader = asyncio.ensure_future(self.provider.upload(self.stream, self.target_path)) def on_finish(self): diff --git a/waterbutler/server/api/v1/provider/create.py b/waterbutler/server/api/v1/provider/create.py index 21c428404..13377f5e0 100644 --- a/waterbutler/server/api/v1/provider/create.py +++ b/waterbutler/server/api/v1/provider/create.py @@ -94,11 +94,9 @@ async def create_folder(self): self.write({'data': self.metadata.json_api_serialized(self.resource)}) async def upload_file(self): - self.writer.write_eof() + self.stream.feed_eof() self.metadata, created = await self.uploader - self.writer.close() - self.wsock.close() if created: self.set_status(201)