mirror of
https://github.com/google/flatbuffers.git
synced 2026-06-07 13:53:38 +00:00
Bugfix __eq__ for numpy data types (#8646)
* [Python] Sync PythonTest.sh flags with generate_code.py * [Python] Update generated code to latest flatc version for tests * [Python] Fix test support for numpy newer than 2.0.0 * [Python] Remove unused variable * [Python] Fix __eq__ for numpy arrays * [Python] Run clang-format over the entire file
This commit is contained in:
@@ -282,8 +282,8 @@ class MonsterExtraT(object):
|
||||
self.f1 == other.f1 and \
|
||||
self.f2 == other.f2 and \
|
||||
self.f3 == other.f3 and \
|
||||
self.dvec == other.dvec and \
|
||||
self.fvec == other.fvec
|
||||
np.array_equal(self.dvec, other.dvec) and \
|
||||
np.array_equal(self.fvec, other.fvec)
|
||||
|
||||
# MonsterExtraT
|
||||
def _UnPack(self, monsterExtra):
|
||||
|
||||
@@ -26,7 +26,7 @@ ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_test.fbs --g
|
||||
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_test.fbs --gen-object-api --gen-onefile
|
||||
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_extra.fbs --gen-object-api --python-typing --gen-compare
|
||||
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test arrays_test.fbs --gen-object-api --python-typing
|
||||
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test nested_union_test.fbs --gen-object-api --python-typing
|
||||
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test nested_union_test.fbs --gen-object-api --python-typing --python-decode-obj-api-strings
|
||||
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test service_test.fbs --grpc --grpc-python-typed-handlers --python-typing --no-python-gen-numpy --gen-onefile
|
||||
|
||||
# Syntax: run_tests <interpreter> <benchmark vtable dedupes>
|
||||
|
||||
@@ -97,6 +97,27 @@ def assertRaises(test_case, fn, exception_class):
|
||||
test_case.assertTrue(isinstance(exc, exception_class))
|
||||
|
||||
|
||||
def byte_swap_array(np_version, arr):
|
||||
"""
|
||||
Performs byte swapping on a NumPy array, adapting to different NumPy versions.
|
||||
|
||||
Args:
|
||||
np_version: Version of NumPu (np.__version__)
|
||||
arr: The input NumPy array.
|
||||
|
||||
Returns:
|
||||
A new NumPy array with byte order swapped.
|
||||
"""
|
||||
numpy_version_tuple = tuple(map(int, np_version.split('.')[:3]))
|
||||
min_version_for_new_method_tuple = (2, 0, 0)
|
||||
|
||||
if numpy_version_tuple >= min_version_for_new_method_tuple:
|
||||
# 'S' indicates swap byte order.
|
||||
return arr.byteswap().view(arr.dtype.newbyteorder('S'))
|
||||
else:
|
||||
return arr.byteswap().newbyteorder()
|
||||
|
||||
|
||||
class TestWireFormat(unittest.TestCase):
|
||||
|
||||
def test_wire_format(self):
|
||||
@@ -1070,6 +1091,29 @@ class TestByteLayout(unittest.TestCase):
|
||||
# 1-byte pad:
|
||||
self.assertBuilderEquals(b, [3, 0, 0, 0, 1, 2, 3, 0])
|
||||
|
||||
def test_comparison_of_np_arrays(self):
|
||||
"""
|
||||
MonsterT dvec and fvec are np.array types which can not be compared with == directly
|
||||
This tests ensures that the __eq__ is generated correctly
|
||||
"""
|
||||
try:
|
||||
# if numpy exists, then we should be able to get the
|
||||
# vector as a numpy array
|
||||
import numpy as np
|
||||
vec1 = np.array([1, 2], dtype=np.float32)
|
||||
vec2 = np.array([3, 4], dtype=np.float32)
|
||||
|
||||
monsterA = MyGame.MonsterExtra.MonsterExtraT(d0=1, d1=1, d2=1, d3=1, f0=1, f1=1, f2=1, f3=1, dvec=vec1, fvec=vec2)
|
||||
assert monsterA == monsterA
|
||||
|
||||
monsterB = MyGame.MonsterExtra.MonsterExtraT(d0=2, d1=1, d2=1, d3=1, f0=1, f1=1, f2=1, f3=1, dvec=vec1, fvec=vec2)
|
||||
assert monsterA != monsterB
|
||||
except ImportError:
|
||||
b = flatbuffers.Builder(0)
|
||||
x = 0
|
||||
assertRaises(self, lambda: b.CreateNumpyVector(x),
|
||||
NumpyRequiredForThisFeature)
|
||||
|
||||
def test_create_numpy_vector_int8(self):
|
||||
try:
|
||||
# if numpy exists, then we should be able to get the
|
||||
@@ -1095,7 +1139,7 @@ class TestByteLayout(unittest.TestCase):
|
||||
|
||||
# Reverse endian:
|
||||
b = flatbuffers.Builder(0)
|
||||
x_other_endian = x.byteswap().newbyteorder()
|
||||
x_other_endian = byte_swap_array(np.__version__, x)
|
||||
b.CreateNumpyVector(x_other_endian)
|
||||
self.assertBuilderEquals(
|
||||
b,
|
||||
@@ -1144,7 +1188,7 @@ class TestByteLayout(unittest.TestCase):
|
||||
|
||||
# Reverse endian:
|
||||
b = flatbuffers.Builder(0)
|
||||
x_other_endian = x.byteswap().newbyteorder()
|
||||
x_other_endian = byte_swap_array(np.__version__, x)
|
||||
b.CreateNumpyVector(x_other_endian)
|
||||
self.assertBuilderEquals(
|
||||
b,
|
||||
@@ -1213,7 +1257,7 @@ class TestByteLayout(unittest.TestCase):
|
||||
|
||||
# Reverse endian:
|
||||
b = flatbuffers.Builder(0)
|
||||
x_other_endian = x.byteswap().newbyteorder()
|
||||
x_other_endian = byte_swap_array(np.__version__, x)
|
||||
b.CreateNumpyVector(x_other_endian)
|
||||
self.assertBuilderEquals(
|
||||
b,
|
||||
@@ -1287,7 +1331,7 @@ class TestByteLayout(unittest.TestCase):
|
||||
|
||||
# Reverse endian:
|
||||
b = flatbuffers.Builder(0)
|
||||
x_other_endian = x.byteswap().newbyteorder()
|
||||
x_other_endian = byte_swap_array(np.__version__, x)
|
||||
b.CreateNumpyVector(x_other_endian)
|
||||
self.assertBuilderEquals(
|
||||
b,
|
||||
@@ -1361,7 +1405,7 @@ class TestByteLayout(unittest.TestCase):
|
||||
|
||||
# Reverse endian:
|
||||
b = flatbuffers.Builder(0)
|
||||
x_other_endian = x.byteswap().newbyteorder()
|
||||
x_other_endian = byte_swap_array(np.__version__, x)
|
||||
b.CreateNumpyVector(x_other_endian)
|
||||
self.assertBuilderEquals(
|
||||
b,
|
||||
@@ -1427,7 +1471,7 @@ class TestByteLayout(unittest.TestCase):
|
||||
|
||||
# Reverse endian:
|
||||
b = flatbuffers.Builder(0)
|
||||
x_other_endian = x.byteswap().newbyteorder()
|
||||
x_other_endian = byte_swap_array(np.__version__, x)
|
||||
b.CreateNumpyVector(x_other_endian)
|
||||
self.assertBuilderEquals(
|
||||
b,
|
||||
@@ -2712,7 +2756,7 @@ class TestNestedUnionTables(unittest.TestCase):
|
||||
|
||||
def test_nested_union_tables(self):
|
||||
nestUnion = MyGame.Example.NestedUnion.NestedUnionTest.NestedUnionTestT()
|
||||
nestUnion.name = b"testUnion1"
|
||||
nestUnion.name = "testUnion1"
|
||||
nestUnion.id = 1
|
||||
nestUnion.data = MyGame.Example.NestedUnion.Vec3.Vec3T()
|
||||
nestUnion.dataType = MyGame.Example.NestedUnion.Any.Any.Vec3
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import flatbuffers
|
||||
|
||||
import flatbuffers
|
||||
import typing
|
||||
|
||||
uoffset: typing.TypeAlias = flatbuffers.number_types.UOffsetTFlags.py_type
|
||||
|
||||
Reference in New Issue
Block a user