diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index 0a2f2719..4b49ee46 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -1,17 +1,18 @@ -import contextlib +import tempfile import unittest from _ctypes import COMError from ctypes import HRESULT, POINTER, OleDLL, byref, c_ubyte from ctypes.wintypes import DWORD, PWCHAR from pathlib import Path +from typing import Optional import comtypes import comtypes.client -with contextlib.redirect_stdout(None): # supress warnings - mod = comtypes.client.GetModule("msvidctl.dll") +comtypes.client.GetModule("portabledeviceapi.dll") +from comtypes.gen.PortableDeviceApiLib import IStorage, tagSTATSTG -from comtypes.gen.MSVidCtlLib import IStorage +STGTY_STORAGE = 1 STATFLAG_DEFAULT = 0 STGC_DEFAULT = 0 @@ -26,7 +27,7 @@ STREAM_SEEK_SET = 0 STG_E_PATHNOTFOUND = -2147287038 - +STG_E_INVALIDFLAG = -2147286785 _ole32 = OleDLL("ole32") @@ -36,28 +37,25 @@ class Test_IStorage(unittest.TestCase): - CREATE_DOC_FLAG = ( - STGM_DIRECT - | STGM_READWRITE - | STGM_CREATE - | STGM_SHARE_EXCLUSIVE - | STGM_DELETEONRELEASE - ) - CREATE_STM_FLAG = STGM_CREATE | STGM_READWRITE | STGM_SHARE_EXCLUSIVE - OPEN_STM_FLAG = STGM_READ | STGM_SHARE_EXCLUSIVE - CREATE_STG_FLAG = STGM_TRANSACTED | STGM_READWRITE | STGM_SHARE_EXCLUSIVE - OPEN_STG_FLAG = STGM_TRANSACTED | STGM_READWRITE | STGM_SHARE_EXCLUSIVE - - def _create_docfile(self) -> IStorage: + RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE + RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED + RW_EXCLUSIVE_CREATE = RW_EXCLUSIVE | STGM_CREATE + CREATE_TESTDOC = STGM_DIRECT | STGM_CREATE | RW_EXCLUSIVE + CREATE_TEMP_TESTDOC = CREATE_TESTDOC | STGM_DELETEONRELEASE + + def _create_docfile(self, mode: int, name: Optional[str] = None) -> IStorage: stg = POINTER(IStorage)() - _StgCreateDocfile(None, self.CREATE_DOC_FLAG, 0, byref(stg)) + _StgCreateDocfile(name, mode, 0, byref(stg)) return stg # type: ignore def test_CreateStream(self): - storage = self._create_docfile() + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + # When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a + # temporary filename. The file really exists on disk because Windows + # creates an actual temporary file for the compound storage. filepath = Path(storage.Stat(STATFLAG_DEFAULT).pwcsName) self.assertTrue(filepath.exists()) - stream = storage.CreateStream("example", self.CREATE_STM_FLAG, 0, 0) + stream = storage.CreateStream("example", self.RW_EXCLUSIVE_CREATE, 0, 0) test_data = b"Some data" pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data)) stream.RemoteWrite(pv, len(test_data)) @@ -70,64 +68,120 @@ def test_CreateStream(self): del storage self.assertFalse(filepath.exists()) + # TODO: Auto-generated methods based on type info are remote-side and hard + # to call from the client. + # If a proper invocation method or workaround is found, testing + # becomes possible. + # See: https://github.com/enthought/comtypes/issues/607 + # def test_RemoteOpenStream(self): + # pass + def test_CreateStorage(self): - parent = self._create_docfile() - child = parent.CreateStorage("child", self.CREATE_STG_FLAG, 0, 0) + parent = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + child = parent.CreateStorage("child", self.RW_EXCLUSIVE_TX, 0, 0) self.assertEqual("child", child.Stat(STATFLAG_DEFAULT).pwcsName) def test_OpenStorage(self): - parent = self._create_docfile() - created_child = parent.CreateStorage("child", self.CREATE_STG_FLAG, 0, 0) - del created_child - opened_child = parent.OpenStorage("child", None, self.OPEN_STG_FLAG, None, 0) - self.assertEqual("child", opened_child.Stat(STATFLAG_DEFAULT).pwcsName) + parent = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + with self.assertRaises(COMError) as cm: + parent.OpenStorage("child", None, self.RW_EXCLUSIVE_TX, None, 0) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + parent.CreateStorage("child", self.RW_EXCLUSIVE_TX, 0, 0) + child = parent.OpenStorage("child", None, self.RW_EXCLUSIVE_TX, None, 0) + self.assertEqual("child", child.Stat(STATFLAG_DEFAULT).pwcsName) def test_RemoteCopyTo(self): - src_stg = self._create_docfile() - src_stg.CreateStorage("child", self.CREATE_STG_FLAG, 0, 0) - dst_stg = self._create_docfile() + src_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + src_stg.CreateStorage("child", self.RW_EXCLUSIVE_TX, 0, 0) + dst_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) src_stg.RemoteCopyTo(0, None, None, dst_stg) src_stg.Commit(STGC_DEFAULT) del src_stg - opened_stg = dst_stg.OpenStorage("child", None, self.OPEN_STG_FLAG, None, 0) + opened_stg = dst_stg.OpenStorage("child", None, self.RW_EXCLUSIVE_TX, None, 0) self.assertEqual("child", opened_stg.Stat(STATFLAG_DEFAULT).pwcsName) def test_MoveElementTo(self): - src_stg = self._create_docfile() - src_stg.CreateStorage("foo", self.CREATE_STG_FLAG, 0, 0) - dst_stg = self._create_docfile() + src_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + src_stg.CreateStorage("foo", self.RW_EXCLUSIVE_TX, 0, 0) + dst_stg = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) src_stg.MoveElementTo("foo", dst_stg, "bar", STGMOVE_MOVE) - opened_stg = dst_stg.OpenStorage("bar", None, self.OPEN_STG_FLAG, None, 0) + opened_stg = dst_stg.OpenStorage("bar", None, self.RW_EXCLUSIVE_TX, None, 0) self.assertEqual("bar", opened_stg.Stat(STATFLAG_DEFAULT).pwcsName) - with self.assertRaises(COMError) as ctx: - src_stg.OpenStorage("foo", None, self.OPEN_STG_FLAG, None, 0) - self.assertEqual(ctx.exception.hresult, STG_E_PATHNOTFOUND) + with self.assertRaises(COMError) as cm: + src_stg.OpenStorage("foo", None, self.RW_EXCLUSIVE_TX, None, 0) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) def test_Revert(self): - storage = self._create_docfile() - foo = storage.CreateStorage("foo", self.CREATE_STG_FLAG, 0, 0) - foo.CreateStorage("bar", self.CREATE_STG_FLAG, 0, 0) - bar = foo.OpenStorage("bar", None, self.OPEN_STG_FLAG, None, 0) + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + foo = storage.CreateStorage("foo", self.RW_EXCLUSIVE_TX, 0, 0) + foo.CreateStorage("bar", self.RW_EXCLUSIVE_TX, 0, 0) + bar = foo.OpenStorage("bar", None, self.RW_EXCLUSIVE_TX, None, 0) self.assertEqual("bar", bar.Stat(STATFLAG_DEFAULT).pwcsName) foo.Revert() - with self.assertRaises(COMError) as ctx: - foo.OpenStorage("bar", None, self.OPEN_STG_FLAG, None, 0) - self.assertEqual(ctx.exception.hresult, STG_E_PATHNOTFOUND) + with self.assertRaises(COMError) as cm: + foo.OpenStorage("bar", None, self.RW_EXCLUSIVE_TX, None, 0) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + + # TODO: Auto-generated methods based on type info are remote-side and hard + # to call from the client. + # If a proper invocation method or workaround is found, testing + # becomes possible. + # See: https://github.com/enthought/comtypes/issues/607 + # def test_RemoteEnumElements(self): + # pass def test_DestroyElement(self): - storage = self._create_docfile() - storage.CreateStorage("example", self.CREATE_STG_FLAG, 0, 0) + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + storage.CreateStorage("example", self.RW_EXCLUSIVE_TX, 0, 0) storage.DestroyElement("example") - with self.assertRaises(COMError) as ctx: - storage.OpenStorage("example", None, self.OPEN_STG_FLAG, None, 0) - self.assertEqual(ctx.exception.hresult, STG_E_PATHNOTFOUND) + with self.assertRaises(COMError) as cm: + storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) def test_RenameElement(self): - storage = self._create_docfile() - storage.CreateStorage("example", self.CREATE_STG_FLAG, 0, 0) + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + storage.CreateStorage("example", self.RW_EXCLUSIVE_TX, 0, 0) storage.RenameElement("example", "sample") - sample = storage.OpenStorage("sample", None, self.OPEN_STG_FLAG, None, 0) + sample = storage.OpenStorage("sample", None, self.RW_EXCLUSIVE_TX, None, 0) self.assertEqual("sample", sample.Stat(STATFLAG_DEFAULT).pwcsName) - with self.assertRaises(COMError) as ctx: - storage.OpenStorage("example", None, self.OPEN_STG_FLAG, None, 0) - self.assertEqual(ctx.exception.hresult, STG_E_PATHNOTFOUND) + with self.assertRaises(COMError) as cm: + storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + + def test_SetClass(self): + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + # Initial value is CLSID_NULL. + self.assertEqual(storage.Stat(STATFLAG_DEFAULT).clsid, comtypes.GUID()) + new_clsid = comtypes.GUID.create_new() + storage.SetClass(new_clsid) + self.assertEqual(storage.Stat(STATFLAG_DEFAULT).clsid, new_clsid) + # Re-set CLSID to CLSID_NULL and verify it is correctly set. + storage.SetClass(comtypes.GUID()) + self.assertEqual(storage.Stat(STATFLAG_DEFAULT).clsid, comtypes.GUID()) + + def test_Stat(self): + with tempfile.TemporaryDirectory() as t: + tmpdir = Path(t) + tmpfile = tmpdir / "test_docfile.cfs" + self.assertFalse(tmpfile.exists()) + # When created with `StgCreateDocfile(filepath_string, ...)`, the + # compound file is created at that location. + storage = self._create_docfile( + name=str(tmpfile), mode=self.CREATE_TEMP_TESTDOC + ) + self.assertTrue(tmpfile.exists()) + with self.assertRaises(COMError) as cm: + storage.Stat(0xFFFFFFFF) # Invalid flag + self.assertEqual(cm.exception.hresult, STG_E_INVALIDFLAG) + stat = storage.Stat(STATFLAG_DEFAULT) + self.assertIsInstance(stat, tagSTATSTG) + del storage # Release the storage to prevent 'cannot access the file ...' + self.assertEqual(stat.type, STGTY_STORAGE) + # Due to header overhead and file system allocation, the size may be + # greater than 0 bytes. + self.assertGreaterEqual(stat.cbSize, 0) + # `grfMode` should reflect the access mode flags from creation. + self.assertEqual(stat.grfMode, self.RW_EXCLUSIVE | STGM_DIRECT) + self.assertEqual(stat.grfLocksSupported, 0) + self.assertEqual(stat.clsid, comtypes.GUID()) # CLSID_NULL for new creation. + self.assertEqual(stat.grfStateBits, 0) diff --git a/comtypes/test/test_stream.py b/comtypes/test/test_stream.py index d6b3b6bb..c8d41637 100644 --- a/comtypes/test/test_stream.py +++ b/comtypes/test/test_stream.py @@ -1,7 +1,10 @@ import contextlib import ctypes +import os import struct +import tempfile import unittest as ut +from _ctypes import COMError from collections.abc import Iterator from ctypes import ( HRESULT, @@ -26,11 +29,13 @@ HWND, INT, LONG, + LPCWSTR, LPVOID, UINT, ULARGE_INTEGER, WORD, ) +from pathlib import Path from typing import Optional import comtypes.client @@ -46,11 +51,24 @@ STATFLAG_DEFAULT = 0 STGC_DEFAULT = 0 + +EACCES = 13 # Permission denied + STGTY_STREAM = 2 STREAM_SEEK_SET = 0 STREAM_SEEK_CUR = 1 STREAM_SEEK_END = 2 +STGM_CREATE = 0x00001000 +STGM_READWRITE = 0x00000002 +STGM_SHARE_DENY_NONE = 0x00000040 + +STG_E_INVALIDFUNCTION = -2147287039 # 0x80030001 + +LOCK_EXCLUSIVE = 2 + +FILE_ATTRIBUTE_NORMAL = 0x80 + _ole32 = OleDLL("ole32") _CreateStreamOnHGlobal = _ole32.CreateStreamOnHGlobal @@ -63,8 +81,19 @@ _IStream_Size.argtypes = [POINTER(IStream), POINTER(ULARGE_INTEGER)] _IStream_Size.restype = HRESULT +_SHCreateStreamOnFileEx = _shlwapi.SHCreateStreamOnFileEx +_SHCreateStreamOnFileEx.argtypes = [ + LPCWSTR, # pszFile + DWORD, # grfMode + DWORD, # dwAttributes + BOOL, # fCreate + POINTER(IStream), # pstmTemplate + POINTER(POINTER(IStream)), # ppstm +] +_SHCreateStreamOnFileEx.restype = HRESULT + -def _create_stream( +def _create_stream_on_hglobal( handle: Optional[int] = None, delete_on_release: bool = True ) -> IStream: # Create an IStream @@ -73,9 +102,17 @@ def _create_stream( return stream # type: ignore +def _create_stream_on_file( + filepath: Path, mode: int, attr: int, create: bool +) -> IStream: + stream = POINTER(IStream)() # type: ignore + _SHCreateStreamOnFileEx(str(filepath), mode, attr, create, None, byref(stream)) + return stream # type: ignore + + class Test_RemoteWrite(ut.TestCase): def test_RemoteWrite(self): - stream = _create_stream() + stream = _create_stream_on_hglobal() test_data = b"Some data" pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data)) @@ -87,7 +124,7 @@ def test_RemoteWrite(self): class Test_RemoteRead(ut.TestCase): def test_RemoteRead(self): - stream = _create_stream() + stream = _create_stream_on_hglobal() test_data = b"Some data" pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data)) stream.RemoteWrite(pv, len(test_data)) @@ -108,7 +145,7 @@ def test_RemoteRead(self): class Test_RemoteSeek(ut.TestCase): def _create_sample_stream(self) -> IStream: - stream = _create_stream() + stream = _create_stream_on_hglobal() test_data = b"spam egg bacon ham" pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data)) stream.RemoteWrite(pv, len(test_data)) @@ -141,7 +178,7 @@ def test_takes_STREAM_SEEK_END_as_origin(self): class Test_SetSize(ut.TestCase): def test_SetSize(self): - stream = _create_stream() + stream = _create_stream_on_hglobal() stream.SetSize(42) pui = pointer(c_ulonglong()) _IStream_Size(stream, pui) @@ -150,8 +187,8 @@ def test_SetSize(self): class Test_RemoteCopyTo(ut.TestCase): def test_RemoteCopyTo(self): - src = _create_stream() - dst = _create_stream() + src = _create_stream_on_hglobal() + dst = _create_stream_on_hglobal() test_data = b"parrot" pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data)) src_written = src.RemoteWrite(pv, len(test_data)) @@ -170,7 +207,7 @@ class Test_Stat(ut.TestCase): # https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-istream-stat # https://learn.microsoft.com/en-us/windows/win32/api/objidl/ns-objidl-statstg def test_returns_statstg_from_no_modified_stream(self): - stream = _create_stream() + stream = _create_stream_on_hglobal() statstg = stream.Stat(STATFLAG_DEFAULT) self.assertIsNone(statstg.pwcsName) self.assertEqual(statstg.type, STGTY_STREAM) @@ -186,7 +223,7 @@ def test_returns_statstg_from_no_modified_stream(self): class Test_Clone(ut.TestCase): def test_Clone(self): - orig = _create_stream() + orig = _create_stream_on_hglobal() test_data = b"spam egg bacon ham" pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data)) orig.RemoteWrite(pv, len(test_data)) @@ -197,6 +234,69 @@ def test_Clone(self): self.assertEqual(bytearray(buf)[0:read], test_data) +class Test_LockRegion_UnlockRegion(ut.TestCase): + def test_cannot_lock_memory_based_stream(self): + stm = _create_stream_on_hglobal() + # For memory-backed streams, `LockRegion` and `UnlockRegion` are + # typically not supported and will return `STG_E_INVALIDFUNCTION`. + with self.assertRaises(COMError) as cm: + stm.LockRegion(0, 5, LOCK_EXCLUSIVE) + self.assertEqual(cm.exception.hresult, STG_E_INVALIDFUNCTION) + with self.assertRaises(COMError) as cm: + stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE) + self.assertEqual(cm.exception.hresult, STG_E_INVALIDFUNCTION) + + def test_can_lock_file_based_stream(self): + with tempfile.TemporaryDirectory() as t: + tmpdir = Path(t) + tmpfile = tmpdir / "lock_test.txt" + # Create a file-backed stream to enable `LockRegion` support. + # This implementation maps directly to OS-level file locking, + # which is not available for memory-based streams. + stm = _create_stream_on_file( + tmpfile, + STGM_READWRITE | STGM_SHARE_DENY_NONE | STGM_CREATE, + FILE_ATTRIBUTE_NORMAL, + True, + ) + stm.SetSize(10) # Allocate file space + stm.LockRegion(0, 5, LOCK_EXCLUSIVE) # Lock the first 5 bytes (0-4) + # Open a separate file descriptor to simulate concurrent access + fd = os.open(tmpfile, os.O_RDWR) + # Writing to the LOCKED region must fail with EACCES + os.lseek(fd, 0, os.SEEK_SET) + with self.assertRaises(OSError) as cm: + os.write(fd, b"ABCDE") + self.assertEqual(cm.exception.errno, EACCES) + # Writing to the UNLOCKED region (offset 5+) must succeed + os.lseek(fd, 5, os.SEEK_SET) + os.write(fd, b"ABCDE") + # Cleanup: Close descriptors and release the lock + os.close(fd) + stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE) + buf, read = stm.RemoteRead(stm.Stat(STATFLAG_DEFAULT).cbSize) + # Verify that COM stream content reflects the successful out-of-lock write + self.assertEqual(bytearray(buf)[0:read], b"\x00\x00\x00\x00\x00ABCDE") + # Verify that the actual file content on disk matches the expected data + self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE") + + +# TODO: If there is a standard Windows `IStream` implementation that supports +# `Revert`, it should be used for testing. +# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-istream-revert +# +# - For memory-based streams (created by `CreateStreamOnHGlobal`), +# `IStream::Revert` has no effect because the object "is not transacted" +# per the specification. All writes are committed immediately to the +# underlying HGLOBAL. +# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-createstreamonhglobal +# +# - `IStream::Revert` is not implemented for the standard Compound File +# (Structured Storage) implementation. According to official documentation, +# `Revert` has no effect on these streams. +# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nn-objidl-istream#methods + + _user32 = WinDLL("user32") _GetDC = _user32.GetDC @@ -566,7 +666,7 @@ def test_load_from_handle_stream(self): with global_alloc(GMEM_FIXED | GMEM_ZEROINIT, len(data)) as handle: with global_lock(handle) as lp_mem: ctypes.memmove(lp_mem, data, len(data)) - pstm = _create_stream(handle, delete_on_release=False) + pstm = _create_stream_on_hglobal(handle, delete_on_release=False) # Load picture from the stream pic: stdole.IPicture = POINTER(stdole.IPicture)() # type: ignore hr = _OleLoadPicture( @@ -585,7 +685,7 @@ def test_load_from_handle_stream(self): def test_load_from_buffer_stream(self): width, height = 1, 1 data = create_24bit_pixel_data(0, 255, 0, width, height) # Green pixel - srcstm = _create_stream(delete_on_release=True) + srcstm = _create_stream_on_hglobal(delete_on_release=True) pv = (c_ubyte * len(data)).from_buffer(bytearray(data)) srcstm.RemoteWrite(pv, len(data)) srcstm.Commit(STGC_DEFAULT) @@ -618,7 +718,7 @@ def test_load_from_buffer_stream(self): # BGR, 1x1 pixel, green (0, 255, 0), in Windows GDI. self.assertEqual(gdi_data, b"\x00\xff\x00") # Save picture to the stream - dststm = _create_stream(delete_on_release=True) + dststm = _create_stream_on_hglobal(delete_on_release=True) pic.SaveAsFile(dststm, False) dststm.RemoteSeek(0, STREAM_SEEK_SET) buf, read = dststm.RemoteRead(dststm.Stat(STATFLAG_DEFAULT).cbSize) @@ -648,7 +748,7 @@ def test_save_created_bitmap_picture(self): ) self.assertEqual(hr, hresult.S_OK) self.assertEqual(pic.Type, PICTYPE_BITMAP) - dststm = _create_stream(delete_on_release=True) + dststm = _create_stream_on_hglobal(delete_on_release=True) pic.SaveAsFile(dststm, True) dststm.RemoteSeek(0, STREAM_SEEK_SET) buf, read = dststm.RemoteRead(dststm.Stat(STATFLAG_DEFAULT).cbSize)