diff --git a/comtypes/malloc.py b/comtypes/malloc.py index d4543d41..fcde6264 100644 --- a/comtypes/malloc.py +++ b/comtypes/malloc.py @@ -1,27 +1,9 @@ -import logging -from ctypes import ( - HRESULT, - POINTER, - OleDLL, - WinDLL, - byref, - c_int, - c_size_t, - c_ulong, - c_void_p, - c_wchar, - c_wchar_p, - cast, - memmove, - sizeof, - wstring_at, -) +from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_size_t, c_ulong, c_void_p from ctypes.wintypes import DWORD, LPVOID +from typing import TYPE_CHECKING, Any, Optional from comtypes import COMMETHOD, GUID, IUnknown -from comtypes.GUID import _CoTaskMemFree - -logger = logging.getLogger(__name__) +from comtypes.GUID import _CoTaskMemFree as _CoTaskMemFree class IMalloc(IUnknown): @@ -34,6 +16,14 @@ class IMalloc(IUnknown): COMMETHOD([], c_int, "DidAlloc", ([], c_void_p, "pv")), COMMETHOD([], None, "HeapMinimize"), # 25 ] + if TYPE_CHECKING: + + def Alloc(self, cb: int) -> Optional[int]: ... + def Realloc(self, pv: Any, cb: int) -> Optional[int]: ... + def Free(self, py: Any) -> None: ... + def GetSize(self, pv: Any) -> int: ... + def DidAlloc(self, pv: Any) -> int: ... + def HeapMinimize(self) -> None: ... _ole32 = OleDLL("ole32") @@ -48,30 +38,3 @@ class IMalloc(IUnknown): _CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc _CoTaskMemAlloc.argtypes = [SIZE_T] _CoTaskMemAlloc.restype = LPVOID - -malloc = POINTER(IMalloc)() -_CoGetMalloc(1, byref(malloc)) -assert bool(malloc) - - -def from_outparam(self): - if not self: - return None - result = wstring_at(self) - # `DidAlloc` method returns; - # * 1 (allocated) - # * 0 (not allocated) - # * -1 (cannot determine or NULL) - # https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-imalloc-didalloc - assert malloc.DidAlloc(self), "memory was NOT allocated by CoTaskMemAlloc" - _CoTaskMemFree(self) - return result - - -def comstring(text, typ=c_wchar_p): - size = (len(text) + 1) * sizeof(c_wchar) - mem = _CoTaskMemAlloc(size) - logger.debug("malloc'd 0x%x, %d bytes" % (mem, size)) - ptr = cast(mem, typ) - memmove(mem, text, size) - return ptr diff --git a/comtypes/test/test_malloc.py b/comtypes/test/test_malloc.py index 597e58b7..f3c5fe25 100644 --- a/comtypes/test/test_malloc.py +++ b/comtypes/test/test_malloc.py @@ -1,6 +1,68 @@ import unittest as ut +from ctypes import HRESULT, POINTER, OleDLL, byref +from ctypes.wintypes import DWORD, HANDLE, LPWSTR +from pathlib import Path -from comtypes.malloc import IMalloc # noqa +from comtypes import GUID, hresult +from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemFree +# Constants +# KNOWNFOLDERID +# https://learn.microsoft.com/en-us/windows/win32/shell/knownfolderid +FOLDERID_System = GUID("{1AC14E77-02E7-4E5D-B744-2EB1AE5198B7}") +# https://learn.microsoft.com/en-us/windows/win32/api/shlobj_core/ne-shlobj_core-known_folder_flag +KF_FLAG_DEFAULT = 0x00000000 -class Test(ut.TestCase): ... +_shell32 = OleDLL("shell32") +_SHGetKnownFolderPath = _shell32.SHGetKnownFolderPath +_SHGetKnownFolderPath.argtypes = [ + POINTER(GUID), # rfid + DWORD, # dwFlags + HANDLE, # hToken + POINTER(LPWSTR), # ppszPath +] +_SHGetKnownFolderPath.restype = HRESULT + + +def _get_malloc() -> IMalloc: + malloc = POINTER(IMalloc)() + _CoGetMalloc(1, byref(malloc)) + assert bool(malloc) + return malloc # type: ignore + + +class Test(ut.TestCase): + def test_Realloc(self): + malloc = _get_malloc() + size1 = 4 + ptr1 = malloc.Alloc(size1) + self.assertEqual(malloc.DidAlloc(ptr1), 1) + self.assertEqual(malloc.GetSize(ptr1), size1) + size2 = size1 - 1 + ptr2 = malloc.Realloc(ptr1, size2) + self.assertEqual(malloc.DidAlloc(ptr2), 1) + self.assertEqual(malloc.GetSize(ptr2), size2) + size3 = size1 + 1 + ptr3 = malloc.Realloc(ptr2, size3) + self.assertEqual(malloc.DidAlloc(ptr3), 1) + self.assertEqual(malloc.GetSize(ptr3), size3) + malloc.Free(ptr3) + self.assertEqual(malloc.DidAlloc(ptr3), 0) + malloc.HeapMinimize() + del ptr3 + + def test_SHGetKnownFolderPath(self): + ptr = LPWSTR() + hr = _SHGetKnownFolderPath( + byref(FOLDERID_System), KF_FLAG_DEFAULT, None, byref(ptr) + ) + self.assertEqual(hr, hresult.S_OK) + self.assertIsInstance(ptr.value, str) + self.assertTrue(Path(ptr.value).exists()) # type: ignore + malloc = _get_malloc() + self.assertEqual(malloc.DidAlloc(ptr), 1) + self.assertGreater(malloc.GetSize(ptr), 0) + _CoTaskMemFree(ptr) + self.assertEqual(malloc.DidAlloc(ptr), 0) + malloc.HeapMinimize() + del ptr diff --git a/comtypes/test/test_from_outparam.py b/comtypes/test/test_outparam.py similarity index 65% rename from comtypes/test/test_from_outparam.py rename to comtypes/test/test_outparam.py index bd4c05a8..356b6187 100644 --- a/comtypes/test/test_from_outparam.py +++ b/comtypes/test/test_outparam.py @@ -1,56 +1,13 @@ import logging import unittest -from ctypes import ( - HRESULT, - POINTER, - OleDLL, - WinDLL, - byref, - c_int, - c_size_t, - c_ulong, - c_void_p, - c_wchar, - c_wchar_p, - cast, - memmove, - sizeof, - wstring_at, -) -from ctypes.wintypes import DWORD, LPVOID +from ctypes import POINTER, byref, c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at from unittest.mock import patch -from comtypes import COMMETHOD, GUID, IUnknown -from comtypes.GUID import _CoTaskMemFree +from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree logger = logging.getLogger(__name__) -class IMalloc(IUnknown): - _iid_ = GUID("{00000002-0000-0000-C000-000000000046}") - _methods_ = [ - COMMETHOD([], c_void_p, "Alloc", ([], c_ulong, "cb")), - COMMETHOD([], c_void_p, "Realloc", ([], c_void_p, "pv"), ([], c_ulong, "cb")), - COMMETHOD([], None, "Free", ([], c_void_p, "py")), - COMMETHOD([], c_ulong, "GetSize", ([], c_void_p, "pv")), - COMMETHOD([], c_int, "DidAlloc", ([], c_void_p, "pv")), - COMMETHOD([], None, "HeapMinimize"), # 25 - ] - - -_ole32 = OleDLL("ole32") - -_CoGetMalloc = _ole32.CoGetMalloc -_CoGetMalloc.argtypes = [DWORD, POINTER(POINTER(IMalloc))] -_CoGetMalloc.restype = HRESULT - -_ole32_nohresult = WinDLL("ole32") - -SIZE_T = c_size_t -_CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc -_CoTaskMemAlloc.argtypes = [SIZE_T] -_CoTaskMemAlloc.restype = LPVOID - malloc = POINTER(IMalloc)() _CoGetMalloc(1, byref(malloc)) assert bool(malloc) diff --git a/comtypes/test/test_urlhistory.py b/comtypes/test/test_urlhistory.py index 62125737..083aa32b 100644 --- a/comtypes/test/test_urlhistory.py +++ b/comtypes/test/test_urlhistory.py @@ -4,7 +4,7 @@ from ctypes import * from comtypes.client import CreateObject, GetModule -from comtypes.GUID import _CoTaskMemFree +from comtypes.malloc import _CoTaskMemFree from comtypes.patcher import Patch # ./urlhist.tlb was downloaded somewhere from the internet (?) diff --git a/pyproject.toml b/pyproject.toml index b44e4dd3..f9cea138 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ ignore = ["E402"] "comtypes/test/test_client.py" = ["F401"] "comtypes/test/test_dict.py" = ["F841"] "comtypes/test/test_eventinterface.py" = ["F841"] -"comtypes/test/test_from_outparam.py" = ["F841"] +"comtypes/test/test_outparam.py" = ["F841"] "comtypes/test/test_sapi.py" = ["E401"] "comtypes/test/test_server.py" = ["F401", "F841"] "comtypes/test/test_subinterface.py" = ["E401", "F401", "F403", "F405"]