Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13", "3.14"]

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "zenlib"
version = "3.1.8"
version = "3.1.9"
authors = [
{ name="Desultory", email="dev@pyl.onl" },
]
Expand Down
9 changes: 7 additions & 2 deletions src/zenlib/namespace/namespace_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from multiprocessing import Event, Pipe, Process, Queue
from os import chroot, chdir, getgid, getuid, setgid, setuid
from getpass import getuser
from multiprocessing import Event, Pipe, Process, Queue
from os import chdir, chroot, getgid, getuid, setgid, setuid

from .namespace import get_id_map, new_id_map, unshare_namespace

Expand All @@ -18,6 +18,7 @@ def __init__(self, target=None, args=None, kwargs=None, **ekwargs):
self.orig_uid = getuid()
self.orig_gid = getgid()
self.uidmapped = Event()
self.unshared = Event()
self.completed = Event()
self.exception_recv, self.exception_send = Pipe()
self.function_queue = Queue()
Expand All @@ -29,11 +30,15 @@ def map_ids(self):

def map_unshare_uids(self):
self.start()

self.unshared.wait()
self.map_ids()

self.uidmapped.set()

def run(self):
unshare_namespace()
self.unshared.set()
self.uidmapped.wait()
setuid(0)
setgid(0)
Expand Down
67 changes: 34 additions & 33 deletions src/zenlib/types/validated_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,42 @@
from typing import ForwardRef, Union, get_args, get_origin, get_type_hints


class ValidatedDataclass:
def __setattr__(self, attribute, value):
value = self._validate_attribute(attribute, value)
super().__setattr__(attribute, value)

def _validate_attribute(self, attribute, value):
"""Ensures the attribute is the correct type"""
if attribute == "logger":
return value
if value is None:
return

expected_type = self.__class__.__annotations__.get(attribute)
if not expected_type:
return value # No type hint, so we can't validate it
if get_origin(expected_type) is Union and isinstance(get_args(expected_type)[0], ForwardRef):
expected_type = get_type_hints(self.__class__)[attribute]

if not isinstance(value, expected_type):
try:
value = expected_type(value)
except ValueError:
raise TypeError(f"[{attribute}] Type mismatch: '{expected_type}' != {type(value)}")
return value


def validatedDataclass(cls):
from zenlib.logging import loggify
from zenlib.util import merge_class

cls = loggify(dataclass(cls))
base_annotations = {}
annotations = {}
for base in cls.__mro__:
base_annotations.update(getattr(base, "__annotations__", {}))

cls.__annotations__.update(base_annotations)

class ValidatedDataclass(cls):
def __setattr__(self, attribute, value):
value = self._validate_attribute(attribute, value)
super().__setattr__(attribute, value)

def _validate_attribute(self, attribute, value):
"""Ensures the attribute is the correct type"""
if attribute == "logger":
return value
if value is None:
return

expected_type = self.__class__.__annotations__.get(attribute)
if not expected_type:
return value # No type hint, so we can't validate it
if get_origin(expected_type) is Union and isinstance(get_args(expected_type)[0], ForwardRef):
expected_type = get_type_hints(self.__class__)[attribute]

if not isinstance(value, expected_type):
try:
value = expected_type(value)
except ValueError:
raise TypeError(f"[{attribute}] Type mismatch: '{expected_type}' != {type(value)}")
return value
annotations.update(getattr(base, "__annotations__", {}))

cls_dict = dict(cls.__dict__)
cls_dict["__annotations__"] = annotations

vdc = loggify(dataclass(type(cls.__name__, (ValidatedDataclass, cls), cls_dict)))

merge_class(cls, ValidatedDataclass, ignored_attributes=["__setattr__"])
return ValidatedDataclass
return vdc
20 changes: 20 additions & 0 deletions tests/test_validated_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,38 @@ class testDataClass:
b: str = None


@validatedDataclass
class anotherDataClass:
x: float = 0.0
y: str = "default"


class TestValidatedDataclass(TestCase):
def test_validated_dataclass(self):
c = testDataClass()
c.a = 1
c.b = "test"
self.assertTrue(hasattr(c, "logger"))

def test_default_values(self):
d = anotherDataClass()
self.assertEqual(d.x, 0.0)
self.assertEqual(d.y, "default")
d.x = 3.14
d.y = "hello"
self.assertEqual(d.x, 3.14)
self.assertEqual(d.y, "hello")

def test_bad_type(self):
c = testDataClass()
d = anotherDataClass()

with self.assertRaises(TypeError):
c.a = "test"

with self.assertRaises(TypeError):
d.x = "not a float"


if __name__ == "__main__":
main()