Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions linux/dump_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def make_dummy(header, wd):
else:
with open(os.path.join(wd, "dummy.c"), "w") as f:
f.write("/* Autogenerated! */\n")
f.write("#include <{}>\n".format(header))
f.write(f"#include <{header}>\n")


def run_make(struct, output, wd):
args = ["make", "LAYOUT_OUTPUT={}".format(os.path.abspath(output))]
args = ["make", f"LAYOUT_OUTPUT={os.path.abspath(output)}"]
if struct:
args.append("TARGET_STRUCT={}".format(struct))
args.append(f"TARGET_STRUCT={struct}")
subprocess.check_call(args, cwd=wd)


Expand Down
92 changes: 54 additions & 38 deletions python/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@ def __init__(self, total_size):
self.total_size = total_size

def __eq__(self, other):
if not isinstance(other, Type):
return NotImplemented

return self.total_size == other.total_size
return (
self.total_size == other.total_size
if isinstance(other, Type)
else NotImplemented
)


class Void(Type):
def __init__(self):
super(Void, self).__init__(0)

def __eq__(self, other):
if not isinstance(other, Void):
return NotImplemented

return super(Void, self).__eq__(other)
return (
super(Void, self).__eq__(other)
if isinstance(other, Void)
else NotImplemented
)

def __repr__(self):
return "Void()"
Expand All @@ -29,10 +31,11 @@ def __init__(self, struct_name):
self.struct_name = struct_name

def __eq__(self, other):
if not isinstance(other, UnknownStructType):
return NotImplemented

return super(UnknownStructType, self).__eq__(other)
return (
super(UnknownStructType, self).__eq__(other)
if isinstance(other, UnknownStructType)
else NotImplemented
)

def __repr__(self):
return "UnknownStruct({!r})".format(self.struct_name)
Expand All @@ -44,10 +47,11 @@ def __init__(self, total_size, signed):
self.signed = signed

def __eq__(self, other):
if not isinstance(other, Bitfield):
return NotImplemented

return self.signed == other.signed and super(Bitfield, self).__eq__(other)
return (
self.signed == other.signed and super(Bitfield, self).__eq__(other)
if isinstance(other, Bitfield)
else NotImplemented
)

def __repr__(self):
return "Bitfield({!r}, {})".format(self.total_size, self.signed)
Expand All @@ -60,11 +64,15 @@ def __init__(self, total_size, type_, signed):
self.signed = signed

def __eq__(self, other):
if not isinstance(other, Scalar):
return NotImplemented

return (self.type == other.type and self.signed == other.signed
and super(Scalar, self).__eq__(other))
return (
(
self.type == other.type
and self.signed == other.signed
and super(Scalar, self).__eq__(other)
)
if isinstance(other, Scalar)
else NotImplemented
)

def __repr__(self):
return "Scalar({!r}, {!r}, {!r})".format(self.total_size, self.type, self.signed)
Expand All @@ -76,10 +84,11 @@ def __init__(self, total_size, type_):
self.type = type_

def __eq__(self, other):
if not isinstance(other, StructField):
return NotImplemented

return self.type == other.type and super(StructField, self).__eq__(other)
return (
self.type == other.type and super(StructField, self).__eq__(other)
if isinstance(other, StructField)
else NotImplemented
)

def __repr__(self):
return "StructField({!r}, {!r})".format(self.total_size, self.type)
Expand All @@ -91,10 +100,11 @@ def __init__(self, type_=None):
self.type = type_

def __eq__(self, other):
if not isinstance(other, Function):
return NotImplemented

return self.type == other.type and super(Function, self).__eq__(other)
return (
self.type == other.type and super(Function, self).__eq__(other)
if isinstance(other, Function)
else NotImplemented
)

def __repr__(self):
return "Function({!r})".format(self.type)
Expand All @@ -106,10 +116,12 @@ def __init__(self, total_size, pointed_type):
self.pointed_type = pointed_type

def __eq__(self, other):
if not isinstance(other, Pointer):
return NotImplemented

return self.pointed_type == other.pointed_type and super(Pointer, self).__eq__(other)
return (
self.pointed_type == other.pointed_type
and super(Pointer, self).__eq__(other)
if isinstance(other, Pointer)
else NotImplemented
)

def __repr__(self):
return "Pointer({!r}, {!r})".format(self.total_size, self.pointed_type)
Expand All @@ -122,11 +134,15 @@ def __init__(self, total_size, num_elem, elem_type):
self.elem_type = elem_type

def __eq__(self, other):
if not isinstance(other, Array):
return NotImplemented

return (self.num_elem == other.num_elem and self.elem_type == other.elem_type
and super(Array, self).__eq__(other))
return (
(
self.num_elem == other.num_elem
and self.elem_type == other.elem_type
and super(Array, self).__eq__(other)
)
if isinstance(other, Array)
else NotImplemented
)

def __repr__(self):
return "Array({!r}, {!r}, {!r})".format(self.total_size, self.num_elem, self.elem_type)
Expand Down
86 changes: 42 additions & 44 deletions python/struct_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def lookup_struct(s):


def _access_addr(field, base, offset):
if 0 == base:
if base == 0:
raise ValueError("NULL deref! offset {!r} type {!r}".format(offset, field))
return _make_addr(base, offset)

Expand Down Expand Up @@ -131,9 +131,8 @@ def _check_value_overflow(value, bits, signed):
if signed:
if not -(1 << bits) <= value < (1 << (bits - 1)):
raise ValueError("{!r} doesn't fit in signed {}-bits!".format(value, bits))
else:
if not (0 <= value < (1 << bits)):
raise ValueError("{!r} doesn't fit in unsigned {}-bits!".format(value, bits))
elif not (0 <= value < (1 << bits)):
raise ValueError("{!r} doesn't fit in unsigned {}-bits!".format(value, bits))


def _write_accessor(field, base, offset, value):
Expand All @@ -152,18 +151,16 @@ def _write_accessor(field, base, offset, value):
value = to_int(value)
_check_value_overflow(value, field.total_size, False)
ACCESSORS[field.total_size](addr, value)
# give more indicative errors for struct / array
elif isinstance(field, StructField):
raise TypeError("Can't set a struct! Please set its fields instead")
elif isinstance(field, Array):
if isinstance(value, (str, bytes)):
if isinstance(value, str):
value = value.encode("ascii")
if len(value) > field.total_size // 8:
raise ValueError("Buffer overflow!")
ACCESSORS[0](addr, value, len(value))
else:
if not isinstance(value, (str, bytes)):
raise TypeError("Can't set an array! Please set its elements instead")
if isinstance(value, str):
value = value.encode("ascii")
if len(value) > field.total_size // 8:
raise ValueError("Buffer overflow!")
ACCESSORS[0](addr, value, len(value))
else:
raise NotImplementedError("_write_accessor for {!r}".format(field))

Expand All @@ -183,10 +180,11 @@ def __setitem__(self, key, value):
return _write_accessor(self._type, self.____ptr, key * self._type.total_size, value)

def __eq__(self, other):
if not isinstance(other, Ptr):
return NotImplemented

return self._type == other._type and self.____ptr == other.____ptr
return (
self._type == other._type and self.____ptr == other.____ptr
if isinstance(other, Ptr)
else NotImplemented
)

def __repr__(self):
return "Ptr({!r}, 0x{:x})".format(self._type, self.____ptr)
Expand All @@ -195,10 +193,7 @@ def __int__(self):
return self.____ptr

def __add__(self, other):
if not isinstance(other, int):
return NotImplemented

return self.____ptr + other
return self.____ptr + other if isinstance(other, int) else NotImplemented

def __call__(self, *args):
if not isinstance(self._type, Function):
Expand Down Expand Up @@ -230,11 +225,15 @@ def __setitem__(self, key, value):
return _write_accessor(self._elem_type, self.____ptr, key * self._elem_type.total_size, value)

def __eq__(self, other):
if not isinstance(other, ArrayPtr):
return NotImplemented

return (self.____ptr == other.____ptr and self._num_elem == other._num_elem and
self._elem_type == other._elem_type)
return (
(
self.____ptr == other.____ptr
and self._num_elem == other._num_elem
and self._elem_type == other._elem_type
)
if isinstance(other, ArrayPtr)
else NotImplemented
)

def __len__(self):
return self._num_elem
Expand All @@ -247,18 +246,14 @@ def __int__(self):

def read(self, n=None):
n = n if n is not None else self._num_elem
items = []
for i in range(n):
items.append(self[i])

if self._elem_type == ArrayPtr.CHAR_TYPE:
# special case: if type is "char", convert to string
s = "".join(map(chr, items))
if s.find('\x00') != -1:
s = s[:s.find('\x00')]
return s
else:
items = [self[i] for i in range(n)]
if self._elem_type != ArrayPtr.CHAR_TYPE:
return items
# special case: if type is "char", convert to string
s = "".join(map(chr, items))
if '\x00' in s:
s = s[:s.find('\x00')]
return s


def _get_sp_struct(sp):
Expand Down Expand Up @@ -300,11 +295,14 @@ def __dir__(self):
return list(_get_sp_struct(self).fields.keys())

def __eq__(self, other):
if not isinstance(other, StructPtr):
return NotImplemented

return (_get_sp_struct(self) == _get_sp_struct(other)
and _get_sp_ptr(self) == _get_sp_ptr(other))
return (
(
_get_sp_struct(self) == _get_sp_struct(other)
and _get_sp_ptr(self) == _get_sp_ptr(other)
)
if isinstance(other, StructPtr)
else NotImplemented
)

def __int__(self):
return _get_sp_ptr(self)
Expand All @@ -321,7 +319,7 @@ def to_int(p):
if n is not None:
return n

raise ValueError("Can't handle object of type {}".format(type(p)))
raise ValueError(f"Can't handle object of type {type(p)}")


def partial_struct(struct):
Expand Down Expand Up @@ -371,7 +369,7 @@ def _print_indented(s):
print(' ' * indent + s)

def _print_field_simple(field, val):
_print_indented(field + ' = ' + str(val))
_print_indented(f'{field} = {str(val)}')

fields = sp.____struct.fields
ordered_fields = sorted(fields.keys(), key=lambda k: fields[k][0])
Expand All @@ -395,6 +393,6 @@ def _print_field_simple(field, val):
_print_field_simple(field, val)
dump_struct(val, levels=levels - 1, indent=indent + 4)
elif isinstance(fields[field][1], Scalar):
_print_indented(fields[field][1].type + ' ' + field + ' = ' + str(val) + ' ' + hex(val))
_print_indented(f'{fields[field][1].type} {field} = {str(val)} {hex(val)}')
else:
_print_field_simple(field, val)
4 changes: 2 additions & 2 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_accessor_array_struct():
s = partial_struct(dump_struct_layout(
"struct x { long n; struct { int n; short s; char c; } a[3]; };", "x")["x"])(MEM_BASE)

for i in range(0, 3):
for i in range(3):
assert s.a[i].n == 3 * 10 ** i
assert s.a[i].s == 2 * 10 ** i
assert s.a[i].c == 1 * 10 ** i
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_accessor_set_array():
mem = set_memory_struct(">BBBBB", 0, 0, 0, 0, 0)
s = partial_struct(dump_struct_layout("struct x { char arr[5] };", "x")["x"])(MEM_BASE)

for i in range(0, len(s.arr)):
for i in range(len(s.arr)):
s.arr[i] = i + 1

assert mem == b"\x01\x02\x03\x04\x05"
Expand Down
9 changes: 6 additions & 3 deletions tests/test_struct_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@


def run_gcc(code_path, output_path, struct_name):
args = ["gcc", "-fplugin={}".format(STRUCT_LAYOUT_SO),
"-fplugin-arg-struct_layout-output={}".format(output_path)]
args = [
"gcc",
f"-fplugin={STRUCT_LAYOUT_SO}",
f"-fplugin-arg-struct_layout-output={output_path}",
]
if struct_name:
args.append("-fplugin-arg-struct_layout-struct={}".format(struct_name))
args.append(f"-fplugin-arg-struct_layout-struct={struct_name}")
args += ["-c", "-o", "/dev/null", "-x", "c", code_path]

subprocess.check_call(args)
Expand Down