diff --git a/comtypes/malloc.py b/comtypes/malloc.py index fcde6264..b23cea32 100644 --- a/comtypes/malloc.py +++ b/comtypes/malloc.py @@ -1,4 +1,5 @@ -from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_size_t, c_ulong, c_void_p +from ctypes import HRESULT, POINTER, OleDLL, WinDLL, byref, c_int, c_ulong, c_void_p +from ctypes import c_size_t as SIZE_T from ctypes.wintypes import DWORD, LPVOID from typing import TYPE_CHECKING, Any, Optional @@ -34,7 +35,19 @@ def HeapMinimize(self) -> None: ... _ole32_nohresult = WinDLL("ole32") -SIZE_T = c_size_t _CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc _CoTaskMemAlloc.argtypes = [SIZE_T] _CoTaskMemAlloc.restype = LPVOID + + +def CoGetMalloc(dwMemContext: int = 1) -> IMalloc: + """Retrieves a pointer to the default OLE task memory allocator. + + https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cogetmalloc + """ + malloc = POINTER(IMalloc)() + _CoGetMalloc( + dwMemContext, # This parameter must be 1. + byref(malloc), + ) + return malloc # type: ignore diff --git a/comtypes/test/test_malloc.py b/comtypes/test/test_malloc.py index f3c5fe25..9333babb 100644 --- a/comtypes/test/test_malloc.py +++ b/comtypes/test/test_malloc.py @@ -4,7 +4,7 @@ from pathlib import Path from comtypes import GUID, hresult -from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemFree +from comtypes.malloc import CoGetMalloc, _CoTaskMemFree # Constants # KNOWNFOLDERID @@ -24,16 +24,9 @@ _SHGetKnownFolderPath.restype = HRESULT -def _get_malloc() -> IMalloc: - malloc = POINTER(IMalloc)() - _CoGetMalloc(1, byref(malloc)) - assert bool(malloc) - return malloc # type: ignore - - class Test(ut.TestCase): def test_Realloc(self): - malloc = _get_malloc() + malloc = CoGetMalloc() size1 = 4 ptr1 = malloc.Alloc(size1) self.assertEqual(malloc.DidAlloc(ptr1), 1) @@ -59,7 +52,7 @@ def test_SHGetKnownFolderPath(self): self.assertEqual(hr, hresult.S_OK) self.assertIsInstance(ptr.value, str) self.assertTrue(Path(ptr.value).exists()) # type: ignore - malloc = _get_malloc() + malloc = CoGetMalloc() self.assertEqual(malloc.DidAlloc(ptr), 1) self.assertGreater(malloc.GetSize(ptr), 0) _CoTaskMemFree(ptr) diff --git a/comtypes/test/test_outparam.py b/comtypes/test/test_outparam.py index 356b6187..1650427d 100644 --- a/comtypes/test/test_outparam.py +++ b/comtypes/test/test_outparam.py @@ -1,15 +1,14 @@ import logging import unittest -from ctypes import POINTER, byref, c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at +from ctypes import c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at from unittest.mock import patch -from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree +from comtypes.malloc import CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree logger = logging.getLogger(__name__) -malloc = POINTER(IMalloc)() -_CoGetMalloc(1, byref(malloc)) +malloc = CoGetMalloc() assert bool(malloc) diff --git a/comtypes/test/test_storage.py b/comtypes/test/test_storage.py index 4b49ee46..9f21aa32 100644 --- a/comtypes/test/test_storage.py +++ b/comtypes/test/test_storage.py @@ -1,16 +1,45 @@ +import ctypes +import os import tempfile import unittest from _ctypes import COMError -from ctypes import HRESULT, POINTER, OleDLL, byref, c_ubyte -from ctypes.wintypes import DWORD, PWCHAR +from ctypes import HRESULT, POINTER, OleDLL, Structure, WinDLL, byref, c_ubyte +from ctypes.wintypes import BOOL, DWORD, FILETIME, LONG, PWCHAR, WORD from pathlib import Path from typing import Optional import comtypes import comtypes.client +from comtypes.malloc import CoGetMalloc comtypes.client.GetModule("portabledeviceapi.dll") -from comtypes.gen.PortableDeviceApiLib import IStorage, tagSTATSTG +from comtypes.gen.PortableDeviceApiLib import WSTRING, IStorage, tagSTATSTG + + +class SYSTEMTIME(Structure): + _fields_ = [ + ("wYear", WORD), + ("wMonth", WORD), + ("wDayOfWeek", WORD), + ("wDay", WORD), + ("wHour", WORD), + ("wMinute", WORD), + ("wSecond", WORD), + ("wMilliseconds", WORD), + ] + + +_kernel32 = WinDLL("kernel32") + +# https://learn.microsoft.com/en-us/windows/win32/api/timezoneapi/nf-timezoneapi-systemtimetofiletime +_SystemTimeToFileTime = _kernel32.SystemTimeToFileTime +_SystemTimeToFileTime.argtypes = [POINTER(SYSTEMTIME), POINTER(FILETIME)] +_SystemTimeToFileTime.restype = BOOL + +# https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-comparefiletime +_CompareFileTime = _kernel32.CompareFileTime +_CompareFileTime.argtypes = [POINTER(FILETIME), POINTER(FILETIME)] +_CompareFileTime.restype = LONG STGTY_STORAGE = 1 @@ -36,6 +65,20 @@ _StgCreateDocfile.restype = HRESULT +def _systemtime_to_filetime(st: SYSTEMTIME) -> FILETIME: + ft = FILETIME() + _SystemTimeToFileTime(byref(st), byref(ft)) + return ft + + +def _compare_filetime(ft1: FILETIME, ft2: FILETIME) -> int: + return _CompareFileTime(byref(ft1), byref(ft2)) + + +def _get_pwcsname(stat: tagSTATSTG) -> WSTRING: + return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset) + + class Test_IStorage(unittest.TestCase): RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED @@ -48,12 +91,17 @@ def _create_docfile(self, mode: int, name: Optional[str] = None) -> IStorage: _StgCreateDocfile(name, mode, 0, byref(stg)) return stg # type: ignore + FIXED_TEST_FILETIME = _systemtime_to_filetime( + SYSTEMTIME(wYear=2000, wMonth=1, wDay=1) + ) + def test_CreateStream(self): storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) # When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a # temporary filename. The file really exists on disk because Windows # creates an actual temporary file for the compound storage. - filepath = Path(storage.Stat(STATFLAG_DEFAULT).pwcsName) + stat = storage.Stat(STATFLAG_DEFAULT) + filepath = Path(stat.pwcsName) self.assertTrue(filepath.exists()) stream = storage.CreateStream("example", self.RW_EXCLUSIVE_CREATE, 0, 0) test_data = b"Some data" @@ -67,6 +115,12 @@ def test_CreateStream(self): self.assertTrue(filepath.exists()) del storage self.assertFalse(filepath.exists()) + name_ptr = _get_pwcsname(stat) + self.assertEqual(name_ptr.value, stat.pwcsName) + malloc = CoGetMalloc() + self.assertEqual(malloc.DidAlloc(name_ptr), 1) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. # TODO: Auto-generated methods based on type info are remote-side and hard # to call from the client. @@ -148,6 +202,32 @@ def test_RenameElement(self): storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0) self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + def test_SetElementTimes(self): + storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) + sub_name = "SubStorageElement" + orig_stat = storage.CreateStorage(sub_name, self.CREATE_TESTDOC, 0, 0).Stat( + STATFLAG_DEFAULT + ) + storage.SetElementTimes( + sub_name, + None, # pctime (creation time) + None, # patime (access time) + self.FIXED_TEST_FILETIME, # pmtime (modification time) + ) + storage.Commit(STGC_DEFAULT) + modified_stat = storage.OpenStorage( + sub_name, None, self.RW_EXCLUSIVE_TX, None, 0 + ).Stat(STATFLAG_DEFAULT) + self.assertEqual(_compare_filetime(orig_stat.ctime, modified_stat.ctime), 0) + self.assertEqual(_compare_filetime(orig_stat.atime, modified_stat.atime), 0) + self.assertNotEqual(_compare_filetime(orig_stat.mtime, modified_stat.mtime), 0) + self.assertEqual( + _compare_filetime(self.FIXED_TEST_FILETIME, modified_stat.mtime), 0 + ) + with self.assertRaises(COMError) as cm: + storage.SetElementTimes("NonExistent", None, None, self.FIXED_TEST_FILETIME) + self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND) + def test_SetClass(self): storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC) # Initial value is CLSID_NULL. @@ -176,7 +256,23 @@ def test_Stat(self): stat = storage.Stat(STATFLAG_DEFAULT) self.assertIsInstance(stat, tagSTATSTG) del storage # Release the storage to prevent 'cannot access the file ...' + # Validate each field: + self.assertEqual( + os.path.normcase(os.path.normpath(Path(stat.pwcsName))), + os.path.normcase(os.path.normpath(tmpfile)), + ) self.assertEqual(stat.type, STGTY_STORAGE) + # Timestamps (`mtime`, `ctime`, `atime`) are set by the underlying + # compound file implementation. + # In many cases (especially on modern Windows with NTFS), all three + # timestamps are set to the same value at creation time. However, this + # is not guaranteed by the OLE32 specification. + # Therefore, we only verify that each timestamp is a valid `FILETIME` + # (non-zero is sufficient for a newly created file). + zero_ft = FILETIME() + self.assertNotEqual(_compare_filetime(stat.ctime, zero_ft), 0) + self.assertNotEqual(_compare_filetime(stat.atime, zero_ft), 0) + self.assertNotEqual(_compare_filetime(stat.mtime, zero_ft), 0) # Due to header overhead and file system allocation, the size may be # greater than 0 bytes. self.assertGreaterEqual(stat.cbSize, 0) @@ -185,3 +281,9 @@ def test_Stat(self): self.assertEqual(stat.grfLocksSupported, 0) self.assertEqual(stat.clsid, comtypes.GUID()) # CLSID_NULL for new creation. self.assertEqual(stat.grfStateBits, 0) + name_ptr = _get_pwcsname(stat) + self.assertEqual(name_ptr.value, stat.pwcsName) + malloc = CoGetMalloc() + self.assertEqual(malloc.DidAlloc(name_ptr), 1) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. diff --git a/comtypes/test/test_stream.py b/comtypes/test/test_stream.py index c8d41637..1e270f5a 100644 --- a/comtypes/test/test_stream.py +++ b/comtypes/test/test_stream.py @@ -40,12 +40,13 @@ import comtypes.client from comtypes import hresult +from comtypes.malloc import CoGetMalloc comtypes.client.GetModule("portabledeviceapi.dll") # The stdole module is generated automatically during the portabledeviceapi # module generation. import comtypes.gen.stdole as stdole -from comtypes.gen.PortableDeviceApiLib import IStream +from comtypes.gen.PortableDeviceApiLib import WSTRING, IStream, tagSTATSTG SIZE_T = c_size_t @@ -110,6 +111,10 @@ def _create_stream_on_file( return stream # type: ignore +def _get_pwcsname(stat: tagSTATSTG) -> WSTRING: + return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset) + + class Test_RemoteWrite(ut.TestCase): def test_RemoteWrite(self): stream = _create_stream_on_hglobal() @@ -206,19 +211,25 @@ def test_RemoteCopyTo(self): class Test_Stat(ut.TestCase): # https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-istream-stat # https://learn.microsoft.com/en-us/windows/win32/api/objidl/ns-objidl-statstg - def test_returns_statstg_from_no_modified_stream(self): + def test_returns_stat_from_no_modified_stream(self): stream = _create_stream_on_hglobal() - statstg = stream.Stat(STATFLAG_DEFAULT) - self.assertIsNone(statstg.pwcsName) - self.assertEqual(statstg.type, STGTY_STREAM) - self.assertEqual(statstg.cbSize, 0) - mt, ct, at = statstg.mtime, statstg.ctime, statstg.atime + stat = stream.Stat(STATFLAG_DEFAULT) + self.assertIsNone(stat.pwcsName) + self.assertEqual(stat.type, STGTY_STREAM) + self.assertEqual(stat.cbSize, 0) + mt, ct, at = stat.mtime, stat.ctime, stat.atime self.assertTrue(mt.dwLowDateTime == ct.dwLowDateTime == at.dwLowDateTime) self.assertTrue(mt.dwHighDateTime == ct.dwHighDateTime == at.dwHighDateTime) - self.assertEqual(statstg.grfMode, 0) - self.assertEqual(statstg.grfLocksSupported, 0) - self.assertEqual(statstg.clsid, comtypes.GUID()) - self.assertEqual(statstg.grfStateBits, 0) + self.assertEqual(stat.grfMode, 0) + self.assertEqual(stat.grfLocksSupported, 0) + self.assertEqual(stat.clsid, comtypes.GUID()) + self.assertEqual(stat.grfStateBits, 0) + name_ptr = _get_pwcsname(stat) + self.assertIsNone(name_ptr.value) + malloc = CoGetMalloc() + self.assertEqual(malloc.DidAlloc(name_ptr), -1) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. class Test_Clone(ut.TestCase): @@ -274,11 +285,18 @@ def test_can_lock_file_based_stream(self): # Cleanup: Close descriptors and release the lock os.close(fd) stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE) - buf, read = stm.RemoteRead(stm.Stat(STATFLAG_DEFAULT).cbSize) + stat = stm.Stat(STATFLAG_DEFAULT) + buf, read = stm.RemoteRead(stat.cbSize) # Verify that COM stream content reflects the successful out-of-lock write self.assertEqual(bytearray(buf)[0:read], b"\x00\x00\x00\x00\x00ABCDE") # Verify that the actual file content on disk matches the expected data self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE") + name_ptr = _get_pwcsname(stat) + self.assertEqual(name_ptr.value, stat.pwcsName) + malloc = CoGetMalloc() + self.assertEqual(malloc.DidAlloc(name_ptr), 1) + del stat # `pwcsName` is expected to be freed here. + # `DidAlloc` checks are skipped to avoid using a dangling pointer. # TODO: If there is a standard Windows `IStream` implementation that supports