diff --git a/google/cloud/storage/asyncio/async_write_object_stream.py b/google/cloud/storage/asyncio/async_write_object_stream.py index 721183962..319f394dd 100644 --- a/google/cloud/storage/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/asyncio/async_write_object_stream.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import List, Optional, Tuple +import grpc from google.cloud import _storage_v2 from google.cloud.storage.asyncio import _utils from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient @@ -181,9 +182,17 @@ async def close(self) -> None: async def requests_done(self): """Signals that all requests have been sent.""" - await self.socket_like_rpc.send(None) - _utils.update_write_handle_if_exists(self, await self.socket_like_rpc.recv()) + + # The server may send a final "EOF" response immediately, or it may + # first send an intermediate response followed by the EOF response depending on whether the object was finalized or not. + first_resp = await self.socket_like_rpc.recv() + _utils.update_write_handle_if_exists(self, first_resp) + + if first_resp != grpc.aio.EOF: + self.persisted_size = first_resp.persisted_size + second_resp = await self.socket_like_rpc.recv() + assert second_resp == grpc.aio.EOF async def send( self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index 77e2ef091..4e952336b 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -15,6 +15,8 @@ import unittest.mock as mock from unittest.mock import AsyncMock, MagicMock import pytest +import grpc + from google.cloud.storage.asyncio.async_write_object_stream import ( _AsyncWriteObjectStream, @@ -194,11 +196,57 @@ async def test_close_success(self, mock_client): stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) stream._is_stream_open = True stream.socket_like_rpc = AsyncMock() + + stream.socket_like_rpc.send = AsyncMock() + first_resp = _storage_v2.BidiWriteObjectResponse(persisted_size=100) + stream.socket_like_rpc.recv = AsyncMock(side_effect=[first_resp, grpc.aio.EOF]) stream.socket_like_rpc.close = AsyncMock() await stream.close() stream.socket_like_rpc.close.assert_awaited_once() assert not stream.is_stream_open + assert stream.persisted_size == 100 + + @pytest.mark.asyncio + async def test_close_with_persisted_size_then_eof(self, mock_client): + """Test close when first recv has persisted_size, second is EOF.""" + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + stream.socket_like_rpc = AsyncMock() + + # First response has persisted_size (NOT EOF, intermediate) + persisted_resp = _storage_v2.BidiWriteObjectResponse(persisted_size=500) + # Second response is EOF (None) + eof_resp = grpc.aio.EOF + + stream.socket_like_rpc.send = AsyncMock() + stream.socket_like_rpc.recv = AsyncMock(side_effect=[persisted_resp, eof_resp]) + stream.socket_like_rpc.close = AsyncMock() + + await stream.close() + + # Verify two recv calls: first has persisted_size (NOT EOF), so read second (EOF) + assert stream.socket_like_rpc.recv.await_count == 2 + assert stream.persisted_size == 500 + assert not stream.is_stream_open + + @pytest.mark.asyncio + async def test_close_with_grpc_aio_eof_response(self, mock_client): + """Test close when first recv is grpc.aio.EOF sentinel.""" + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + stream.socket_like_rpc = AsyncMock() + + # First recv returns grpc.aio.EOF (explicit sentinel from finalize) + stream.socket_like_rpc.send = AsyncMock() + stream.socket_like_rpc.recv = AsyncMock(return_value=grpc.aio.EOF) + stream.socket_like_rpc.close = AsyncMock() + + await stream.close() + + # Verify only one recv call (grpc.aio.EOF=EOF, so don't read second) + assert stream.socket_like_rpc.recv.await_count == 1 + assert not stream.is_stream_open @pytest.mark.asyncio async def test_methods_require_open_raises(self, mock_client):