From 9f85351b1c2e30c48dfc69d52da11f68fbdad592 Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 1/8] test: Validate `pwcsName` in `IStorage.Stat` for created files Ensures that the `pwcsName` field of the `tagSTATSTG` structure, returned by `IStorage.Stat`, correctly reflects the file path used during the creation of a compound file. This enhances the `test_Stat` in `test_storage.py` by verifying the consistency of the storage's reported name with its physical location. --- comtypes/test/test_storage.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index 4b49ee46..c20865db 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -1,3 +1,4 @@ +import os import tempfile import unittest from _ctypes import COMError @@ -175,6 +176,10 @@ def test_Stat(self): self.assertEqual(cm.exception.hresult, STG_E_INVALIDFLAG) stat = storage.Stat(STATFLAG_DEFAULT) self.assertIsInstance(stat, tagSTATSTG) + self.assertEqual( + os.path.normcase(os.path.normpath(Path(stat.pwcsName))), + os.path.normcase(os.path.normpath(tmpfile)), + ) del storage # Release the storage to prevent 'cannot access the file ...' self.assertEqual(stat.type, STGTY_STORAGE) # Due to header overhead and file system allocation, the size may be From 3f2339249e2060d20b3e97e5d6a8c479e67bf836 Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 2/8] test: Verify `tagSTATSTG.pwcsName` memory management in `IStorage` and `IStream` tests. --- comtypes/test/test_storage.py | 30 ++++++++++++++++++++++++++++-- comtypes/test/test_stream.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index c20865db..45b96095 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -1,3 +1,4 @@ +import ctypes import os import tempfile import unittest @@ -9,9 +10,10 @@ import comtypes import comtypes.client +from comtypes.malloc import IMalloc, _CoGetMalloc comtypes.client.GetModule("portabledeviceapi.dll") -from comtypes.gen.PortableDeviceApiLib import IStorage, tagSTATSTG +from comtypes.gen.PortableDeviceApiLib import WSTRING, IStorage, tagSTATSTG STGTY_STORAGE = 1 @@ -37,6 +39,17 @@ _StgCreateDocfile.restype = HRESULT +def _get_malloc() -> IMalloc: + malloc = POINTER(IMalloc)() + _CoGetMalloc(1, byref(malloc)) + assert bool(malloc) + return malloc # type: ignore + + +def _get_pwcsname(stat: tagSTATSTG) -> WSTRING: + return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset) + + class Test_IStorage(unittest.TestCase): RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED @@ -54,7 +67,8 @@ def test_CreateStream(self): # When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a # temporary filename. The file really exists on disk because Windows # creates an actual temporary file for the compound storage. - filepath = Path(storage.Stat(STATFLAG_DEFAULT).pwcsName) + stat = storage.Stat(STATFLAG_DEFAULT) + filepath = Path(stat.pwcsName) self.assertTrue(filepath.exists()) stream = storage.CreateStream("example", self.RW_EXCLUSIVE_CREATE, 0, 0) test_data = b"Some data" @@ -68,6 +82,12 @@ def test_CreateStream(self): self.assertTrue(filepath.exists()) del storage self.assertFalse(filepath.exists()) + name_ptr = _get_pwcsname(stat) + self.assertEqual(name_ptr.value, stat.pwcsName) + malloc = _get_malloc() + self.assertEqual(malloc.DidAlloc(name_ptr), 1) + del stat + self.assertEqual(malloc.DidAlloc(name_ptr), 0) # TODO: Auto-generated methods based on type info are remote-side and hard # to call from the client. @@ -190,3 +210,9 @@ def test_Stat(self): self.assertEqual(stat.grfLocksSupported, 0) self.assertEqual(stat.clsid, comtypes.GUID()) # CLSID_NULL for new creation. self.assertEqual(stat.grfStateBits, 0) + name_ptr = _get_pwcsname(stat) + self.assertEqual(name_ptr.value, stat.pwcsName) + malloc = _get_malloc() + self.assertEqual(malloc.DidAlloc(name_ptr), 1) + del stat + self.assertEqual(malloc.DidAlloc(name_ptr), 0) diff --git a/comtypes/test/test_stream.py b/comtypes/test/test_stream.py index c8d41637..105372bf 100644 --- a/comtypes/test/test_stream.py +++ b/comtypes/test/test_stream.py @@ -40,12 +40,13 @@ import comtypes.client from comtypes import hresult +from comtypes.malloc import IMalloc, _CoGetMalloc comtypes.client.GetModule("portabledeviceapi.dll") # The stdole module is generated automatically during the portabledeviceapi # module generation. import comtypes.gen.stdole as stdole -from comtypes.gen.PortableDeviceApiLib import IStream +from comtypes.gen.PortableDeviceApiLib import WSTRING, IStream, tagSTATSTG SIZE_T = c_size_t @@ -110,6 +111,17 @@ def _create_stream_on_file( return stream # type: ignore +def _get_malloc() -> IMalloc: + malloc = POINTER(IMalloc)() + _CoGetMalloc(1, byref(malloc)) + assert bool(malloc) + return malloc # type: ignore + + +def _get_pwcsname(stat: tagSTATSTG) -> WSTRING: + return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset) + + class Test_RemoteWrite(ut.TestCase): def test_RemoteWrite(self): stream = _create_stream_on_hglobal() @@ -219,6 +231,12 @@ def test_returns_statstg_from_no_modified_stream(self): self.assertEqual(statstg.grfLocksSupported, 0) self.assertEqual(statstg.clsid, comtypes.GUID()) self.assertEqual(statstg.grfStateBits, 0) + name_ptr = _get_pwcsname(statstg) + self.assertIsNone(name_ptr.value) + malloc = _get_malloc() + self.assertEqual(malloc.DidAlloc(name_ptr), -1) + del statstg + self.assertEqual(malloc.DidAlloc(name_ptr), -1) class Test_Clone(ut.TestCase): @@ -274,11 +292,18 @@ def test_can_lock_file_based_stream(self): # Cleanup: Close descriptors and release the lock os.close(fd) stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE) - buf, read = stm.RemoteRead(stm.Stat(STATFLAG_DEFAULT).cbSize) + statstg = stm.Stat(STATFLAG_DEFAULT) + buf, read = stm.RemoteRead(statstg.cbSize) # Verify that COM stream content reflects the successful out-of-lock write self.assertEqual(bytearray(buf)[0:read], b"\x00\x00\x00\x00\x00ABCDE") # Verify that the actual file content on disk matches the expected data self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE") + name_ptr = _get_pwcsname(statstg) + self.assertEqual(name_ptr.value, statstg.pwcsName) + malloc = _get_malloc() + self.assertEqual(malloc.DidAlloc(name_ptr), 1) + del statstg + self.assertEqual(malloc.DidAlloc(name_ptr), 0) # TODO: If there is a standard Windows `IStream` implementation that supports From ddb6a5274afdc9c7b5c70ad4aedf29f857eb893f Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 3/8] test: Validate `FILETIME` in `IStorage.Stat` test. Added validation for `ctime`, `atime`, and `mtime` `FILETIME` timestamps within the `tagSTATSTG` structure in `test_storage.py`. --- comtypes/test/test_storage.py | 64 +++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index 45b96095..48939fc9 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -3,8 +3,8 @@ import tempfile import unittest from _ctypes import COMError -from ctypes import HRESULT, POINTER, OleDLL, byref, c_ubyte -from ctypes.wintypes import DWORD, PWCHAR +from ctypes import HRESULT, POINTER, OleDLL, Structure, WinDLL, byref, c_ubyte +from ctypes.wintypes import BOOL, DWORD, FILETIME, LONG, PWCHAR, WORD from pathlib import Path from typing import Optional @@ -15,6 +15,32 @@ comtypes.client.GetModule("portabledeviceapi.dll") from comtypes.gen.PortableDeviceApiLib import WSTRING, IStorage, tagSTATSTG + +class SYSTEMTIME(Structure): + _fields_ = [ + ("wYear", WORD), + ("wMonth", WORD), + ("wDayOfWeek", WORD), + ("wDay", WORD), + ("wHour", WORD), + ("wMinute", WORD), + ("wSecond", WORD), + ("wMilliseconds", WORD), + ] + + +_kernel32 = WinDLL("kernel32") + +# https://learn.microsoft.com/en-us/windows/win32/api/timezoneapi/nf-timezoneapi-systemtimetofiletime +_SystemTimeToFileTime = _kernel32.SystemTimeToFileTime +_SystemTimeToFileTime.argtypes = [POINTER(SYSTEMTIME), POINTER(FILETIME)] +_SystemTimeToFileTime.restype = BOOL + +# https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-comparefiletime +_CompareFileTime = _kernel32.CompareFileTime +_CompareFileTime.argtypes = [POINTER(FILETIME), POINTER(FILETIME)] +_CompareFileTime.restype = LONG + STGTY_STORAGE = 1 STATFLAG_DEFAULT = 0 @@ -39,6 +65,16 @@ _StgCreateDocfile.restype = HRESULT +def _systemtime_to_filetime(st: SYSTEMTIME) -> FILETIME: + ft = FILETIME() + _SystemTimeToFileTime(byref(st), byref(ft)) + return ft + + +def _compare_filetime(ft1: FILETIME, ft2: FILETIME) -> int: + return _CompareFileTime(byref(ft1), byref(ft2)) + + def _get_malloc() -> IMalloc: malloc = POINTER(IMalloc)() _CoGetMalloc(1, byref(malloc)) @@ -62,6 +98,10 @@ def _create_docfile(self, mode: int, name: Optional[str] = None) -> IStorage: _StgCreateDocfile(name, mode, 0, byref(stg)) return stg # type: ignore + FIXED_TEST_FILETIME = _systemtime_to_filetime( + SYSTEMTIME(wYear=2000, wMonth=1, wDay=1) + ) + def test_CreateStream(self): storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) # When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a @@ -196,12 +236,24 @@ def test_Stat(self): self.assertEqual(cm.exception.hresult, STG_E_INVALIDFLAG) stat = storage.Stat(STATFLAG_DEFAULT) self.assertIsInstance(stat, tagSTATSTG) - self.assertEqual( - os.path.normcase(os.path.normpath(Path(stat.pwcsName))), - os.path.normcase(os.path.normpath(tmpfile)), - ) del storage # Release the storage to prevent 'cannot access the file ...' + # Validate each field: + self.assertEqual( + os.path.normcase(os.path.normpath(Path(stat.pwcsName))), + os.path.normcase(os.path.normpath(tmpfile)), + ) self.assertEqual(stat.type, STGTY_STORAGE) + # Timestamps (`mtime`, `ctime`, `atime`) are set by the underlying + # compound file implementation. + # In many cases (especially on modern Windows with NTFS), all three + # timestamps are set to the same value at creation time. However, this + # is not guaranteed by the OLE32 specification. + # Therefore, we only verify that each timestamp is a valid `FILETIME` + # (non-zero is sufficient for a newly created file). + zero_ft = FILETIME() + self.assertNotEqual(_compare_filetime(stat.ctime, zero_ft), 0) + self.assertNotEqual(_compare_filetime(stat.atime, zero_ft), 0) + self.assertNotEqual(_compare_filetime(stat.mtime, zero_ft), 0) # Due to header overhead and file system allocation, the size may be # greater than 0 bytes. self.assertGreaterEqual(stat.cbSize, 0) From 50af548112b4239ac1594463a6a24f2f9ab5c47b Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 4/8] test: Add `IStorage.SetElementTimes` functionality test. Introduced `test_SetElementTimes` in `test_storage.py` to verify the functionality of the `IStorage.SetElementTimes` method. --- comtypes/test/test_storage.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index 48939fc9..7cdc5a6e 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -209,6 +209,32 @@ def test_RenameElement(self): storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0) self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + def test_SetElementTimes(self): + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + sub_name = "SubStorageElement" + orig_stat = storage.CreateStorage(sub_name, self.CREATE_TESTDOC, 0, 0).Stat( + STATFLAG_DEFAULT + ) + storage.SetElementTimes( + sub_name, + None, # pctime (creation time) + None, # patime (access time) + self.FIXED_TEST_FILETIME, # pmtime (modification time) + ) + storage.Commit(STGC_DEFAULT) + modified_stat = storage.OpenStorage( + sub_name, None, self.RW_EXCLUSIVE_TX, None, 0 + ).Stat(STATFLAG_DEFAULT) + self.assertEqual(_compare_filetime(orig_stat.ctime, modified_stat.ctime), 0) + self.assertEqual(_compare_filetime(orig_stat.atime, modified_stat.atime), 0) + self.assertNotEqual(_compare_filetime(orig_stat.mtime, modified_stat.mtime), 0) + self.assertEqual( + _compare_filetime(self.FIXED_TEST_FILETIME, modified_stat.mtime), 0 + ) + with self.assertRaises(COMError) as cm: + storage.SetElementTimes("NonExistent", None, None, self.FIXED_TEST_FILETIME) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + def test_SetClass(self): storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) # Initial value is CLSID_NULL. From abc670a5c7e3e2a804c06d59e89c9b940e3ac656 Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 5/8] refactor: Rename `statstg` to `stat` in `test_stream.py`. --- comtypes/test/test_stream.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/comtypes/test/test_stream.py b/comtypes/test/test_stream.py index 105372bf..19691535 100644 --- a/comtypes/test/test_stream.py +++ b/comtypes/test/test_stream.py @@ -218,24 +218,24 @@ def test_RemoteCopyTo(self): class Test_Stat(ut.TestCase): # https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-istream-stat # https://learn.microsoft.com/en-us/windows/win32/api/objidl/ns-objidl-statstg - def test_returns_statstg_from_no_modified_stream(self): + def test_returns_stat_from_no_modified_stream(self): stream = _create_stream_on_hglobal() - statstg = stream.Stat(STATFLAG_DEFAULT) - self.assertIsNone(statstg.pwcsName) - self.assertEqual(statstg.type, STGTY_STREAM) - self.assertEqual(statstg.cbSize, 0) - mt, ct, at = statstg.mtime, statstg.ctime, statstg.atime + stat = stream.Stat(STATFLAG_DEFAULT) + self.assertIsNone(stat.pwcsName) + self.assertEqual(stat.type, STGTY_STREAM) + self.assertEqual(stat.cbSize, 0) + mt, ct, at = stat.mtime, stat.ctime, stat.atime self.assertTrue(mt.dwLowDateTime == ct.dwLowDateTime == at.dwLowDateTime) self.assertTrue(mt.dwHighDateTime == ct.dwHighDateTime == at.dwHighDateTime) - self.assertEqual(statstg.grfMode, 0) - self.assertEqual(statstg.grfLocksSupported, 0) - self.assertEqual(statstg.clsid, comtypes.GUID()) - self.assertEqual(statstg.grfStateBits, 0) - name_ptr = _get_pwcsname(statstg) + self.assertEqual(stat.grfMode, 0) + self.assertEqual(stat.grfLocksSupported, 0) + self.assertEqual(stat.clsid, comtypes.GUID()) + self.assertEqual(stat.grfStateBits, 0) + name_ptr = _get_pwcsname(stat) self.assertIsNone(name_ptr.value) malloc = _get_malloc() self.assertEqual(malloc.DidAlloc(name_ptr), -1) - del statstg + del stat self.assertEqual(malloc.DidAlloc(name_ptr), -1) @@ -292,17 +292,17 @@ def test_can_lock_file_based_stream(self): # Cleanup: Close descriptors and release the lock os.close(fd) stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE) - statstg = stm.Stat(STATFLAG_DEFAULT) - buf, read = stm.RemoteRead(statstg.cbSize) + stat = stm.Stat(STATFLAG_DEFAULT) + buf, read = stm.RemoteRead(stat.cbSize) # Verify that COM stream content reflects the successful out-of-lock write self.assertEqual(bytearray(buf)[0:read], b"\x00\x00\x00\x00\x00ABCDE") # Verify that the actual file content on disk matches the expected data self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE") - name_ptr = _get_pwcsname(statstg) - self.assertEqual(name_ptr.value, statstg.pwcsName) + name_ptr = _get_pwcsname(stat) + self.assertEqual(name_ptr.value, stat.pwcsName) malloc = _get_malloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) - del statstg + del stat self.assertEqual(malloc.DidAlloc(name_ptr), 0) From 0d5a80b597f3d006e11e1754702618eea812964e Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 6/8] refactor: Improve `SIZE_T` import in `malloc.py`. Directly import `c_size_t` as `SIZE_T` from `ctypes` to remove an unnecessary alias assignment, streamlining type definition. --- comtypes/malloc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comtypes/malloc.py b/comtypes/malloc.py index fcde6264..84516021 100644 --- a/comtypes/malloc.py +++ b/comtypes/malloc.py @@ -1,4 +1,5 @@ -from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_size_t, c_ulong, c_void_p +from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_ulong, c_void_p +from ctypes import c_size_t as SIZE_T from ctypes.wintypes import DWORD, LPVOID from typing import TYPE_CHECKING, Any, Optional @@ -34,7 +35,6 @@ def HeapMinimize(self) -> None: ... _ole32_nohresult = WinDLL("ole32") -SIZE_T = c_size_t _CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc _CoTaskMemAlloc.argtypes = [SIZE_T] _CoTaskMemAlloc.restype = LPVOID From c21d63fb0f25f246ff10e13a24264a9094f0c216 Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 7/8] refactor: Centralize `IMalloc` retrieval with `CoGetMalloc` utility function. Introduced a `CoGetMalloc` utility function in `comtypes/malloc.py` to encapsulate the low-level `_CoGetMalloc` API call. This new function streamlines the acquisition of the OLE task memory allocator by standardizing the `dwMemContext` parameter to `1`, which is the only supported value for `CoGetMalloc`. --- comtypes/malloc.py | 15 ++++++++++++++- comtypes/test/test_malloc.py | 13 +++---------- comtypes/test/test_outparam.py | 7 +++---- comtypes/test/test_storage.py | 13 +++---------- comtypes/test/test_stream.py | 13 +++---------- 5 files changed, 26 insertions(+), 35 deletions(-) diff --git a/comtypes/malloc.py b/comtypes/malloc.py index 84516021..b23cea32 100644 --- a/comtypes/malloc.py +++ b/comtypes/malloc.py @@ -1,4 +1,4 @@ -from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_ulong, c_void_p +from ctypes import HRESULT, POINTER, OleDLL, WinDLL, byref, c_int, c_ulong, c_void_p from ctypes import c_size_t as SIZE_T from ctypes.wintypes import DWORD, LPVOID from typing import TYPE_CHECKING, Any, Optional @@ -38,3 +38,16 @@ def HeapMinimize(self) -> None: ... _CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc _CoTaskMemAlloc.argtypes = [SIZE_T] _CoTaskMemAlloc.restype = LPVOID + + +def CoGetMalloc(dwMemContext: int = 1) -> IMalloc: + """Retrieves a pointer to the default OLE task memory allocator. + + https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cogetmalloc + """ + malloc = POINTER(IMalloc)() + _CoGetMalloc( + dwMemContext, # This parameter must be 1. + byref(malloc), + ) + return malloc # type: ignore diff --git a/comtypes/test/test_malloc.py b/comtypes/test/test_malloc.py index f3c5fe25..9333babb 100644 --- a/comtypes/test/test_malloc.py +++ b/comtypes/test/test_malloc.py @@ -4,7 +4,7 @@ from pathlib import Path from comtypes import GUID, hresult -from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemFree +from comtypes.malloc import CoGetMalloc, _CoTaskMemFree # Constants # KNOWNFOLDERID @@ -24,16 +24,9 @@ _SHGetKnownFolderPath.restype = HRESULT -def _get_malloc() -> IMalloc: - malloc = POINTER(IMalloc)() - _CoGetMalloc(1, byref(malloc)) - assert bool(malloc) - return malloc # type: ignore - - class Test(ut.TestCase): def test_Realloc(self): - malloc = _get_malloc() + malloc = CoGetMalloc() size1 = 4 ptr1 = malloc.Alloc(size1) self.assertEqual(malloc.DidAlloc(ptr1), 1) @@ -59,7 +52,7 @@ def test_SHGetKnownFolderPath(self): self.assertEqual(hr, hresult.S_OK) self.assertIsInstance(ptr.value, str) self.assertTrue(Path(ptr.value).exists()) # type: ignore - malloc = _get_malloc() + malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(ptr), 1) self.assertGreater(malloc.GetSize(ptr), 0) _CoTaskMemFree(ptr) diff --git a/comtypes/test/test_outparam.py b/comtypes/test/test_outparam.py index 356b6187..1650427d 100644 --- a/comtypes/test/test_outparam.py +++ b/comtypes/test/test_outparam.py @@ -1,15 +1,14 @@ import logging import unittest -from ctypes import POINTER, byref, c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at +from ctypes import c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at from unittest.mock import patch -from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree +from comtypes.malloc import CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree logger = logging.getLogger(__name__) -malloc = POINTER(IMalloc)() -_CoGetMalloc(1, byref(malloc)) +malloc = CoGetMalloc() assert bool(malloc) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index 7cdc5a6e..a1a68fff 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -10,7 +10,7 @@ import comtypes import comtypes.client -from comtypes.malloc import IMalloc, _CoGetMalloc +from comtypes.malloc import CoGetMalloc comtypes.client.GetModule("portabledeviceapi.dll") from comtypes.gen.PortableDeviceApiLib import WSTRING, IStorage, tagSTATSTG @@ -75,13 +75,6 @@ def _compare_filetime(ft1: FILETIME, ft2: FILETIME) -> int: return _CompareFileTime(byref(ft1), byref(ft2)) -def _get_malloc() -> IMalloc: - malloc = POINTER(IMalloc)() - _CoGetMalloc(1, byref(malloc)) - assert bool(malloc) - return malloc # type: ignore - - def _get_pwcsname(stat: tagSTATSTG) -> WSTRING: return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset) @@ -124,7 +117,7 @@ def test_CreateStream(self): self.assertFalse(filepath.exists()) name_ptr = _get_pwcsname(stat) self.assertEqual(name_ptr.value, stat.pwcsName) - malloc = _get_malloc() + malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) del stat self.assertEqual(malloc.DidAlloc(name_ptr), 0) @@ -290,7 +283,7 @@ def test_Stat(self): self.assertEqual(stat.grfStateBits, 0) name_ptr = _get_pwcsname(stat) self.assertEqual(name_ptr.value, stat.pwcsName) - malloc = _get_malloc() + malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) del stat self.assertEqual(malloc.DidAlloc(name_ptr), 0) diff --git a/comtypes/test/test_stream.py b/comtypes/test/test_stream.py index 19691535..ffa71b99 100644 --- a/comtypes/test/test_stream.py +++ b/comtypes/test/test_stream.py @@ -40,7 +40,7 @@ import comtypes.client from comtypes import hresult -from comtypes.malloc import IMalloc, _CoGetMalloc +from comtypes.malloc import CoGetMalloc comtypes.client.GetModule("portabledeviceapi.dll") # The stdole module is generated automatically during the portabledeviceapi @@ -111,13 +111,6 @@ def _create_stream_on_file( return stream # type: ignore -def _get_malloc() -> IMalloc: - malloc = POINTER(IMalloc)() - _CoGetMalloc(1, byref(malloc)) - assert bool(malloc) - return malloc # type: ignore - - def _get_pwcsname(stat: tagSTATSTG) -> WSTRING: return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset) @@ -233,7 +226,7 @@ def test_returns_stat_from_no_modified_stream(self): self.assertEqual(stat.grfStateBits, 0) name_ptr = _get_pwcsname(stat) self.assertIsNone(name_ptr.value) - malloc = _get_malloc() + malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), -1) del stat self.assertEqual(malloc.DidAlloc(name_ptr), -1) @@ -300,7 +293,7 @@ def test_can_lock_file_based_stream(self): self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE") name_ptr = _get_pwcsname(stat) self.assertEqual(name_ptr.value, stat.pwcsName) - malloc = _get_malloc() + malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) del stat self.assertEqual(malloc.DidAlloc(name_ptr), 0) From 65dd0e577b97db87a3badc63d29576dec76f1733 Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Jan 2026 11:13:14 +0900 Subject: [PATCH 8/8] fix: Remove `DidAlloc` checks on freed memory in `STATSTG` tests. Removed assertions for `IMalloc.DidAlloc` after `del stat` in `test_storage.py` and `test_stream.py`. --- comtypes/test/test_storage.py | 8 ++++---- comtypes/test/test_stream.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index a1a68fff..9f21aa32 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -119,8 +119,8 @@ def test_CreateStream(self): self.assertEqual(name_ptr.value, stat.pwcsName) malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) - del stat - self.assertEqual(malloc.DidAlloc(name_ptr), 0) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. # TODO: Auto-generated methods based on type info are remote-side and hard # to call from the client. @@ -285,5 +285,5 @@ def test_Stat(self): self.assertEqual(name_ptr.value, stat.pwcsName) malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) - del stat - self.assertEqual(malloc.DidAlloc(name_ptr), 0) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. diff --git a/comtypes/test/test_stream.py b/comtypes/test/test_stream.py index ffa71b99..1e270f5a 100644 --- a/comtypes/test/test_stream.py +++ b/comtypes/test/test_stream.py @@ -228,8 +228,8 @@ def test_returns_stat_from_no_modified_stream(self): self.assertIsNone(name_ptr.value) malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), -1) - del stat - self.assertEqual(malloc.DidAlloc(name_ptr), -1) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. class Test_Clone(ut.TestCase): @@ -295,8 +295,8 @@ def test_can_lock_file_based_stream(self): self.assertEqual(name_ptr.value, stat.pwcsName) malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(name_ptr), 1) - del stat - self.assertEqual(malloc.DidAlloc(name_ptr), 0) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. # TODO: If there is a standard Windows `IStream` implementation that supports