Skip to content
17 changes: 15 additions & 2 deletions comtypes/malloc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_size_t, c_ulong, c_void_p
from ctypes import HRESULT, POINTER, OleDLL, WinDLL, byref, c_int, c_ulong, c_void_p
from ctypes import c_size_t as SIZE_T
from ctypes.wintypes import DWORD, LPVOID
from typing import TYPE_CHECKING, Any, Optional

Expand Down Expand Up @@ -34,7 +35,19 @@ def HeapMinimize(self) -> None: ...

_ole32_nohresult = WinDLL("ole32")

SIZE_T = c_size_t
_CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc
_CoTaskMemAlloc.argtypes = [SIZE_T]
_CoTaskMemAlloc.restype = LPVOID


def CoGetMalloc(dwMemContext: int = 1) -> IMalloc:
"""Retrieves a pointer to the default OLE task memory allocator.

https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cogetmalloc
"""
malloc = POINTER(IMalloc)()
_CoGetMalloc(
dwMemContext, # This parameter must be 1.
byref(malloc),
)
return malloc # type: ignore
13 changes: 3 additions & 10 deletions comtypes/test/test_malloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

from comtypes import GUID, hresult
from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemFree
from comtypes.malloc import CoGetMalloc, _CoTaskMemFree

# Constants
# KNOWNFOLDERID
Expand All @@ -24,16 +24,9 @@
_SHGetKnownFolderPath.restype = HRESULT


def _get_malloc() -> IMalloc:
malloc = POINTER(IMalloc)()
_CoGetMalloc(1, byref(malloc))
assert bool(malloc)
return malloc # type: ignore


class Test(ut.TestCase):
def test_Realloc(self):
malloc = _get_malloc()
malloc = CoGetMalloc()
size1 = 4
ptr1 = malloc.Alloc(size1)
self.assertEqual(malloc.DidAlloc(ptr1), 1)
Expand All @@ -59,7 +52,7 @@ def test_SHGetKnownFolderPath(self):
self.assertEqual(hr, hresult.S_OK)
self.assertIsInstance(ptr.value, str)
self.assertTrue(Path(ptr.value).exists()) # type: ignore
malloc = _get_malloc()
malloc = CoGetMalloc()
self.assertEqual(malloc.DidAlloc(ptr), 1)
self.assertGreater(malloc.GetSize(ptr), 0)
_CoTaskMemFree(ptr)
Expand Down
7 changes: 3 additions & 4 deletions comtypes/test/test_outparam.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import logging
import unittest
from ctypes import POINTER, byref, c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at
from ctypes import c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at
from unittest.mock import patch

from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree
from comtypes.malloc import CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree

logger = logging.getLogger(__name__)


malloc = POINTER(IMalloc)()
_CoGetMalloc(1, byref(malloc))
malloc = CoGetMalloc()
assert bool(malloc)


Expand Down
110 changes: 106 additions & 4 deletions comtypes/test/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,45 @@
import ctypes
import os
import tempfile
import unittest
from _ctypes import COMError
from ctypes import HRESULT, POINTER, OleDLL, byref, c_ubyte
from ctypes.wintypes import DWORD, PWCHAR
from ctypes import HRESULT, POINTER, OleDLL, Structure, WinDLL, byref, c_ubyte
from ctypes.wintypes import BOOL, DWORD, FILETIME, LONG, PWCHAR, WORD
from pathlib import Path
from typing import Optional

import comtypes
import comtypes.client
from comtypes.malloc import CoGetMalloc

comtypes.client.GetModule("portabledeviceapi.dll")
from comtypes.gen.PortableDeviceApiLib import IStorage, tagSTATSTG
from comtypes.gen.PortableDeviceApiLib import WSTRING, IStorage, tagSTATSTG


class SYSTEMTIME(Structure):
_fields_ = [
("wYear", WORD),
("wMonth", WORD),
("wDayOfWeek", WORD),
("wDay", WORD),
("wHour", WORD),
("wMinute", WORD),
("wSecond", WORD),
("wMilliseconds", WORD),
]


_kernel32 = WinDLL("kernel32")

# https://learn.microsoft.com/en-us/windows/win32/api/timezoneapi/nf-timezoneapi-systemtimetofiletime
_SystemTimeToFileTime = _kernel32.SystemTimeToFileTime
_SystemTimeToFileTime.argtypes = [POINTER(SYSTEMTIME), POINTER(FILETIME)]
_SystemTimeToFileTime.restype = BOOL

# https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-comparefiletime
_CompareFileTime = _kernel32.CompareFileTime
_CompareFileTime.argtypes = [POINTER(FILETIME), POINTER(FILETIME)]
_CompareFileTime.restype = LONG

STGTY_STORAGE = 1

Expand All @@ -36,6 +65,20 @@
_StgCreateDocfile.restype = HRESULT


def _systemtime_to_filetime(st: SYSTEMTIME) -> FILETIME:
ft = FILETIME()
_SystemTimeToFileTime(byref(st), byref(ft))
return ft


def _compare_filetime(ft1: FILETIME, ft2: FILETIME) -> int:
return _CompareFileTime(byref(ft1), byref(ft2))


def _get_pwcsname(stat: tagSTATSTG) -> WSTRING:
return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset)


class Test_IStorage(unittest.TestCase):
RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE
RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED
Expand All @@ -48,12 +91,17 @@ def _create_docfile(self, mode: int, name: Optional[str] = None) -> IStorage:
_StgCreateDocfile(name, mode, 0, byref(stg))
return stg # type: ignore

FIXED_TEST_FILETIME = _systemtime_to_filetime(
SYSTEMTIME(wYear=2000, wMonth=1, wDay=1)
)

def test_CreateStream(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
# When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a
# temporary filename. The file really exists on disk because Windows
# creates an actual temporary file for the compound storage.
filepath = Path(storage.Stat(STATFLAG_DEFAULT).pwcsName)
stat = storage.Stat(STATFLAG_DEFAULT)
filepath = Path(stat.pwcsName)
self.assertTrue(filepath.exists())
stream = storage.CreateStream("example", self.RW_EXCLUSIVE_CREATE, 0, 0)
test_data = b"Some data"
Expand All @@ -67,6 +115,12 @@ def test_CreateStream(self):
self.assertTrue(filepath.exists())
del storage
self.assertFalse(filepath.exists())
name_ptr = _get_pwcsname(stat)
self.assertEqual(name_ptr.value, stat.pwcsName)
malloc = CoGetMalloc()
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
del stat # `pwcsName` is expected to be freed here.
# `DidAlloc` checks are skipped to avoid using a dangling pointer.

# TODO: Auto-generated methods based on type info are remote-side and hard
# to call from the client.
Expand Down Expand Up @@ -148,6 +202,32 @@ def test_RenameElement(self):
storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

def test_SetElementTimes(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
sub_name = "SubStorageElement"
orig_stat = storage.CreateStorage(sub_name, self.CREATE_TESTDOC, 0, 0).Stat(
STATFLAG_DEFAULT
)
storage.SetElementTimes(
sub_name,
None, # pctime (creation time)
None, # patime (access time)
self.FIXED_TEST_FILETIME, # pmtime (modification time)
)
storage.Commit(STGC_DEFAULT)
modified_stat = storage.OpenStorage(
sub_name, None, self.RW_EXCLUSIVE_TX, None, 0
).Stat(STATFLAG_DEFAULT)
self.assertEqual(_compare_filetime(orig_stat.ctime, modified_stat.ctime), 0)
self.assertEqual(_compare_filetime(orig_stat.atime, modified_stat.atime), 0)
self.assertNotEqual(_compare_filetime(orig_stat.mtime, modified_stat.mtime), 0)
self.assertEqual(
_compare_filetime(self.FIXED_TEST_FILETIME, modified_stat.mtime), 0
)
with self.assertRaises(COMError) as cm:
storage.SetElementTimes("NonExistent", None, None, self.FIXED_TEST_FILETIME)
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)

def test_SetClass(self):
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
# Initial value is CLSID_NULL.
Expand Down Expand Up @@ -176,7 +256,23 @@ def test_Stat(self):
stat = storage.Stat(STATFLAG_DEFAULT)
self.assertIsInstance(stat, tagSTATSTG)
del storage # Release the storage to prevent 'cannot access the file ...'
# Validate each field:
self.assertEqual(
os.path.normcase(os.path.normpath(Path(stat.pwcsName))),
os.path.normcase(os.path.normpath(tmpfile)),
)
self.assertEqual(stat.type, STGTY_STORAGE)
# Timestamps (`mtime`, `ctime`, `atime`) are set by the underlying
# compound file implementation.
# In many cases (especially on modern Windows with NTFS), all three
# timestamps are set to the same value at creation time. However, this
# is not guaranteed by the OLE32 specification.
# Therefore, we only verify that each timestamp is a valid `FILETIME`
# (non-zero is sufficient for a newly created file).
zero_ft = FILETIME()
self.assertNotEqual(_compare_filetime(stat.ctime, zero_ft), 0)
self.assertNotEqual(_compare_filetime(stat.atime, zero_ft), 0)
self.assertNotEqual(_compare_filetime(stat.mtime, zero_ft), 0)
# Due to header overhead and file system allocation, the size may be
# greater than 0 bytes.
self.assertGreaterEqual(stat.cbSize, 0)
Expand All @@ -185,3 +281,9 @@ def test_Stat(self):
self.assertEqual(stat.grfLocksSupported, 0)
self.assertEqual(stat.clsid, comtypes.GUID()) # CLSID_NULL for new creation.
self.assertEqual(stat.grfStateBits, 0)
name_ptr = _get_pwcsname(stat)
self.assertEqual(name_ptr.value, stat.pwcsName)
malloc = CoGetMalloc()
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
del stat # `pwcsName` is expected to be freed here.
# `DidAlloc` checks are skipped to avoid using a dangling pointer.
42 changes: 30 additions & 12 deletions comtypes/test/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@

import comtypes.client
from comtypes import hresult
from comtypes.malloc import CoGetMalloc

comtypes.client.GetModule("portabledeviceapi.dll")
# The stdole module is generated automatically during the portabledeviceapi
# module generation.
import comtypes.gen.stdole as stdole
from comtypes.gen.PortableDeviceApiLib import IStream
from comtypes.gen.PortableDeviceApiLib import WSTRING, IStream, tagSTATSTG

SIZE_T = c_size_t

Expand Down Expand Up @@ -110,6 +111,10 @@ def _create_stream_on_file(
return stream # type: ignore


def _get_pwcsname(stat: tagSTATSTG) -> WSTRING:
return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset)


class Test_RemoteWrite(ut.TestCase):
def test_RemoteWrite(self):
stream = _create_stream_on_hglobal()
Expand Down Expand Up @@ -206,19 +211,25 @@ def test_RemoteCopyTo(self):
class Test_Stat(ut.TestCase):
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-istream-stat
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/ns-objidl-statstg
def test_returns_statstg_from_no_modified_stream(self):
def test_returns_stat_from_no_modified_stream(self):
stream = _create_stream_on_hglobal()
statstg = stream.Stat(STATFLAG_DEFAULT)
self.assertIsNone(statstg.pwcsName)
self.assertEqual(statstg.type, STGTY_STREAM)
self.assertEqual(statstg.cbSize, 0)
mt, ct, at = statstg.mtime, statstg.ctime, statstg.atime
stat = stream.Stat(STATFLAG_DEFAULT)
self.assertIsNone(stat.pwcsName)
self.assertEqual(stat.type, STGTY_STREAM)
self.assertEqual(stat.cbSize, 0)
mt, ct, at = stat.mtime, stat.ctime, stat.atime
self.assertTrue(mt.dwLowDateTime == ct.dwLowDateTime == at.dwLowDateTime)
self.assertTrue(mt.dwHighDateTime == ct.dwHighDateTime == at.dwHighDateTime)
self.assertEqual(statstg.grfMode, 0)
self.assertEqual(statstg.grfLocksSupported, 0)
self.assertEqual(statstg.clsid, comtypes.GUID())
self.assertEqual(statstg.grfStateBits, 0)
self.assertEqual(stat.grfMode, 0)
self.assertEqual(stat.grfLocksSupported, 0)
self.assertEqual(stat.clsid, comtypes.GUID())
self.assertEqual(stat.grfStateBits, 0)
name_ptr = _get_pwcsname(stat)
self.assertIsNone(name_ptr.value)
malloc = CoGetMalloc()
self.assertEqual(malloc.DidAlloc(name_ptr), -1)
del stat # `pwcsName` is expected to be freed here.
# `DidAlloc` checks are skipped to avoid using a dangling pointer.


class Test_Clone(ut.TestCase):
Expand Down Expand Up @@ -274,11 +285,18 @@ def test_can_lock_file_based_stream(self):
# Cleanup: Close descriptors and release the lock
os.close(fd)
stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE)
buf, read = stm.RemoteRead(stm.Stat(STATFLAG_DEFAULT).cbSize)
stat = stm.Stat(STATFLAG_DEFAULT)
buf, read = stm.RemoteRead(stat.cbSize)
# Verify that COM stream content reflects the successful out-of-lock write
self.assertEqual(bytearray(buf)[0:read], b"\x00\x00\x00\x00\x00ABCDE")
# Verify that the actual file content on disk matches the expected data
self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE")
name_ptr = _get_pwcsname(stat)
self.assertEqual(name_ptr.value, stat.pwcsName)
malloc = CoGetMalloc()
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
del stat # `pwcsName` is expected to be freed here.
# `DidAlloc` checks are skipped to avoid using a dangling pointer.


# TODO: If there is a standard Windows `IStream` implementation that supports
Expand Down