diff --git a/src/idl_gen_python.cpp b/src/idl_gen_python.cpp index 78b969459..c9c289ba4 100644 --- a/src/idl_gen_python.cpp +++ b/src/idl_gen_python.cpp @@ -52,9 +52,9 @@ class PythonStubGenerator { public: PythonStubGenerator(const Parser &parser, const std::string &path, const Version &version) - : parser_{parser}, - namer_{WithFlagOptions(kStubConfig, parser.opts, path), - Keywords(version)}, + : parser_{ parser }, + namer_{ WithFlagOptions(kStubConfig, parser.opts, path), + Keywords(version) }, version_(version) {} bool Generate() { @@ -140,8 +140,7 @@ class PythonStubGenerator { return module; } - template - std::string ModuleFor(const T *def) const { + template std::string ModuleFor(const T *def) const { if (parser_.opts.one_file) return ModuleForFile(def->file); return namer_.NamespacedType(*def); } @@ -165,7 +164,7 @@ class PythonStubGenerator { return "None"; } - template + template std::string UnionType(const EnumDef &enum_def, Imports *imports, F type) const { imports->Import("typing"); @@ -181,14 +180,9 @@ class PythonStubGenerator { result += import.name; break; } - case BASE_TYPE_STRING: - result += "str"; - break; - case BASE_TYPE_NONE: - result += "None"; - break; - default: - break; + case BASE_TYPE_STRING: result += "str"; break; + case BASE_TYPE_NONE: result += "None"; break; + default: break; } } return "typing.Union[" + result + "]"; @@ -229,18 +223,14 @@ class PythonStubGenerator { namer_.Type(*type.struct_def)); return import.name; } - case BASE_TYPE_STRING: - return "str"; + case BASE_TYPE_STRING: return "str"; case BASE_TYPE_ARRAY: case BASE_TYPE_VECTOR: { imports->Import("typing"); return "typing.List[" + TypeOf(type.VectorType(), imports) + "]"; } - case BASE_TYPE_UNION: - return UnionType(*type.enum_def, imports); - default: - FLATBUFFERS_ASSERT(0); - return ""; + case BASE_TYPE_UNION: return UnionType(*type.enum_def, imports); + default: FLATBUFFERS_ASSERT(0); return ""; } } @@ -262,8 +252,7 @@ class PythonStubGenerator { namer_.ObjectType(*field_type.struct_def)); return field_name + ": " + import.name + " | None"; } - case BASE_TYPE_STRING: - return field_name + ": str | None"; + case BASE_TYPE_STRING: return field_name + ": str | None"; case BASE_TYPE_ARRAY: case BASE_TYPE_VECTOR: { imports->Import("typing"); @@ -282,8 +271,7 @@ class PythonStubGenerator { case BASE_TYPE_UNION: return field_name + ": " + UnionObjectType(*field->value.type.enum_def, imports); - default: - return field_name; + default: return field_name; } } @@ -312,9 +300,7 @@ class PythonStubGenerator { field_type = "'" + import_.name + "' | None"; break; } - case BASE_TYPE_STRING: - field_type = "str | None"; - break; + case BASE_TYPE_STRING: field_type = "str | None"; break; case BASE_TYPE_ARRAY: case BASE_TYPE_VECTOR: { imports->Import("typing"); @@ -334,9 +320,7 @@ class PythonStubGenerator { case BASE_TYPE_UNION: field_type = UnionObjectType(*type.enum_def, imports); break; - default: - field_type = "typing.Any"; - break; + default: field_type = "typing.Any"; break; } } stub << " " << field_name << ": " << field_type << " = ...,\n"; @@ -485,8 +469,7 @@ class PythonStubGenerator { stub << " def " << name << "(self) -> table.Table | None: ...\n"; break; } - default: - break; + default: break; } } } @@ -530,9 +513,7 @@ class PythonStubGenerator { stub << '\n'; stub << "def Create" + namer_.Type(*struct_def) << "(builder: flatbuffers.Builder"; - for (const std::string &arg : args) { - stub << ", " << arg; - } + for (const std::string &arg : args) { stub << ", " << arg; } stub << ") -> uoffset: ...\n"; } @@ -610,11 +591,10 @@ class PythonStubGenerator { imports->Import("typing", "cast"); - if (version_.major == 3){ + if (version_.major == 3) { imports->Import("enum", "IntEnum"); stub << "(IntEnum)"; - } - else { + } else { stub << "(object)"; } @@ -637,16 +617,15 @@ class PythonStubGenerator { ss << "from __future__ import annotations\n"; ss << '\n'; ss << "import flatbuffers\n"; - if (parser_.opts.python_gen_numpy) { - ss << "import numpy as np\n"; - } + if (parser_.opts.python_gen_numpy) { ss << "import numpy as np\n"; } ss << '\n'; std::set modules; std::map> names_by_module; for (const Import &import : imports.imports) { if (import.IsLocal()) continue; // skip all local imports - if (import.module == "flatbuffers" && import.name == "") continue; // skip double include hardcoded flatbuffers + if (import.module == "flatbuffers" && import.name == "") + continue; // skip double include hardcoded flatbuffers if (import.name == "") { modules.insert(import.module); } else { @@ -686,7 +665,8 @@ class PythonStubGenerator { const Parser &parser_; const IdlNamer namer_; const Version version_; -};} // namespace +}; +} // namespace class PythonGenerator : public BaseGenerator { public: @@ -695,8 +675,8 @@ class PythonGenerator : public BaseGenerator { : BaseGenerator(parser, path, file_name, "" /* not used */, "" /* not used */, "py"), float_const_gen_("float('nan')", "float('inf')", "float('-inf')"), - namer_(WithFlagOptions(kConfig, parser.opts, path), - Keywords(version)) {} + namer_(WithFlagOptions(kConfig, parser.opts, path), Keywords(version)) { + } // Most field accessors need to retrieve and test the field offset first, // this is the prefix code for that. @@ -886,9 +866,8 @@ class PythonGenerator : public BaseGenerator { GenReceiver(struct_def, code_ptr); code += namer_.Method(field); - const ImportMapEntry import_entry = { - GenPackageReference(field.value.type), TypeName(field) - }; + const ImportMapEntry import_entry = { GenPackageReference(field.value.type), + TypeName(field) }; if (parser_.opts.python_typing) { const std::string return_type = ReturnType(struct_def, field); @@ -948,9 +927,8 @@ class PythonGenerator : public BaseGenerator { GenReceiver(struct_def, code_ptr); code += namer_.Method(field) + "(self)"; - const ImportMapEntry import_entry = { - GenPackageReference(field.value.type), TypeName(field) - }; + const ImportMapEntry import_entry = { GenPackageReference(field.value.type), + TypeName(field) }; if (parser_.opts.python_typing) { const std::string return_type = ReturnType(struct_def, field); @@ -1036,11 +1014,8 @@ class PythonGenerator : public BaseGenerator { code += Indent + Indent + "return None\n\n"; } - template - std::string ModuleFor(const T *def) const { - if (!parser_.opts.one_file) { - return namer_.NamespacedType(*def); - } + template std::string ModuleFor(const T *def) const { + if (!parser_.opts.one_file) { return namer_.NamespacedType(*def); } std::string filename = StripExtension(def->file) + parser_.opts.filename_suffix; @@ -1070,9 +1045,8 @@ class PythonGenerator : public BaseGenerator { GenReceiver(struct_def, code_ptr); code += namer_.Method(field); - const ImportMapEntry import_entry = { - GenPackageReference(field.value.type), TypeName(field) - }; + const ImportMapEntry import_entry = { GenPackageReference(field.value.type), + TypeName(field) }; if (parser_.opts.python_typing) { const std::string return_type = ReturnType(struct_def, field); @@ -1195,8 +1169,7 @@ class PythonGenerator : public BaseGenerator { std::string qualified_name = NestedFlatbufferType(unqualified_name); if (qualified_name.empty()) { qualified_name = nested->constant; } - const ImportMapEntry import_entry = { qualified_name, - unqualified_name }; + const ImportMapEntry import_entry = { qualified_name, unqualified_name }; auto &code = *code_ptr; GenReceiver(struct_def, code_ptr); @@ -1808,8 +1781,8 @@ class PythonGenerator : public BaseGenerator { } field_type = "Optional[List[" + field_type + "]"; } else { - field_type = - "Optional[List[" + GetBasePythonTypeForScalarAndString(base_type) + "]]"; + field_type = "Optional[List[" + + GetBasePythonTypeForScalarAndString(base_type) + "]]"; } } @@ -1858,11 +1831,12 @@ class PythonGenerator : public BaseGenerator { const auto field_field = namer_.Field(field); // Build signature with keyword arguments, type hints, and default values. - signature_params += GenIndents(2) + field_field + " = " + default_value + ","; + signature_params += + GenIndents(2) + field_field + " = " + default_value + ","; // Build the body of the __init__ method. init_body += GenIndents(2) + "self." + field_field + " = " + field_field + - " # type: " + field_type; + " # type: " + field_type; } // Writes __init__ method. @@ -1954,10 +1928,16 @@ class PythonGenerator : public BaseGenerator { auto &field = **it; if (field.deprecated) continue; - // Wrties the comparison statement for this field. - const auto field_field = namer_.Field(field); - code += " and \\" + GenIndents(3) + "self." + field_field + - " == " + "other." + field_field; + // Writes the comparison statement for this field. + const auto field_name = namer_.Field(field); + if (parser_.opts.python_gen_numpy && + field.value.type.base_type == BASE_TYPE_VECTOR) { + code += " and \\" + GenIndents(3) + "np.array_equal(self." + + field_name + ", " + "other." + field_name + ")"; + } else { + code += " and \\" + GenIndents(3) + "self." + field_name + + " == " + "other." + field_name; + } } code += "\n"; } @@ -2154,7 +2134,6 @@ class PythonGenerator : public BaseGenerator { auto &field = **it; if (field.deprecated) continue; - auto field_type = TypeName(field); switch (field.value.type.base_type) { case BASE_TYPE_STRUCT: { GenUnPackForStruct(struct_def, field, &code); @@ -2338,9 +2317,9 @@ class PythonGenerator : public BaseGenerator { if (parser_.opts.python_gen_numpy) { code_prefix += GenIndents(3) + "if np is not None and type(self." + - field_field + ") is np.ndarray:"; + field_field + ") is np.ndarray:"; code_prefix += GenIndents(4) + field_field + - " = builder.CreateNumpyVector(self." + field_field + ")"; + " = builder.CreateNumpyVector(self." + field_field + ")"; code_prefix += GenIndents(3) + "else:"; GenPackForScalarVectorFieldHelper(struct_def, field, code_prefix_ptr, 4); code_prefix += "(self." + field_field + "[i])"; @@ -2788,9 +2767,7 @@ class PythonGenerator : public BaseGenerator { } } } - if (parser_.opts.python_gen_numpy) { - code += "np = import_numpy()\n\n"; - } + if (parser_.opts.python_gen_numpy) { code += "np = import_numpy()\n\n"; } } } @@ -2828,7 +2805,7 @@ class PythonGenerator : public BaseGenerator { static bool GeneratePython(const Parser &parser, const std::string &path, const std::string &file_name) { - python::Version version{parser.opts.python_version}; + python::Version version{ parser.opts.python_version }; if (!version.IsValid()) return false; python::PythonGenerator generator(parser, path, file_name, version); diff --git a/tests/MyGame/MonsterExtra.py b/tests/MyGame/MonsterExtra.py index 5c5c89252..e07362ba7 100644 --- a/tests/MyGame/MonsterExtra.py +++ b/tests/MyGame/MonsterExtra.py @@ -282,8 +282,8 @@ class MonsterExtraT(object): self.f1 == other.f1 and \ self.f2 == other.f2 and \ self.f3 == other.f3 and \ - self.dvec == other.dvec and \ - self.fvec == other.fvec + np.array_equal(self.dvec, other.dvec) and \ + np.array_equal(self.fvec, other.fvec) # MonsterExtraT def _UnPack(self, monsterExtra): diff --git a/tests/PythonTest.sh b/tests/PythonTest.sh index 647f3daf1..90bee5724 100755 --- a/tests/PythonTest.sh +++ b/tests/PythonTest.sh @@ -26,7 +26,7 @@ ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_test.fbs --g ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_test.fbs --gen-object-api --gen-onefile ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_extra.fbs --gen-object-api --python-typing --gen-compare ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test arrays_test.fbs --gen-object-api --python-typing -${test_dir}/../flatc -p -o ${gen_code_path} -I include_test nested_union_test.fbs --gen-object-api --python-typing +${test_dir}/../flatc -p -o ${gen_code_path} -I include_test nested_union_test.fbs --gen-object-api --python-typing --python-decode-obj-api-strings ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test service_test.fbs --grpc --grpc-python-typed-handlers --python-typing --no-python-gen-numpy --gen-onefile # Syntax: run_tests diff --git a/tests/py_test.py b/tests/py_test.py index 7d6c0a379..4749fe440 100644 --- a/tests/py_test.py +++ b/tests/py_test.py @@ -97,6 +97,27 @@ def assertRaises(test_case, fn, exception_class): test_case.assertTrue(isinstance(exc, exception_class)) +def byte_swap_array(np_version, arr): + """ + Performs byte swapping on a NumPy array, adapting to different NumPy versions. + + Args: + np_version: Version of NumPu (np.__version__) + arr: The input NumPy array. + + Returns: + A new NumPy array with byte order swapped. + """ + numpy_version_tuple = tuple(map(int, np_version.split('.')[:3])) + min_version_for_new_method_tuple = (2, 0, 0) + + if numpy_version_tuple >= min_version_for_new_method_tuple: + # 'S' indicates swap byte order. + return arr.byteswap().view(arr.dtype.newbyteorder('S')) + else: + return arr.byteswap().newbyteorder() + + class TestWireFormat(unittest.TestCase): def test_wire_format(self): @@ -1070,6 +1091,29 @@ class TestByteLayout(unittest.TestCase): # 1-byte pad: self.assertBuilderEquals(b, [3, 0, 0, 0, 1, 2, 3, 0]) + def test_comparison_of_np_arrays(self): + """ + MonsterT dvec and fvec are np.array types which can not be compared with == directly + This tests ensures that the __eq__ is generated correctly + """ + try: + # if numpy exists, then we should be able to get the + # vector as a numpy array + import numpy as np + vec1 = np.array([1, 2], dtype=np.float32) + vec2 = np.array([3, 4], dtype=np.float32) + + monsterA = MyGame.MonsterExtra.MonsterExtraT(d0=1, d1=1, d2=1, d3=1, f0=1, f1=1, f2=1, f3=1, dvec=vec1, fvec=vec2) + assert monsterA == monsterA + + monsterB = MyGame.MonsterExtra.MonsterExtraT(d0=2, d1=1, d2=1, d3=1, f0=1, f1=1, f2=1, f3=1, dvec=vec1, fvec=vec2) + assert monsterA != monsterB + except ImportError: + b = flatbuffers.Builder(0) + x = 0 + assertRaises(self, lambda: b.CreateNumpyVector(x), + NumpyRequiredForThisFeature) + def test_create_numpy_vector_int8(self): try: # if numpy exists, then we should be able to get the @@ -1095,7 +1139,7 @@ class TestByteLayout(unittest.TestCase): # Reverse endian: b = flatbuffers.Builder(0) - x_other_endian = x.byteswap().newbyteorder() + x_other_endian = byte_swap_array(np.__version__, x) b.CreateNumpyVector(x_other_endian) self.assertBuilderEquals( b, @@ -1144,7 +1188,7 @@ class TestByteLayout(unittest.TestCase): # Reverse endian: b = flatbuffers.Builder(0) - x_other_endian = x.byteswap().newbyteorder() + x_other_endian = byte_swap_array(np.__version__, x) b.CreateNumpyVector(x_other_endian) self.assertBuilderEquals( b, @@ -1213,7 +1257,7 @@ class TestByteLayout(unittest.TestCase): # Reverse endian: b = flatbuffers.Builder(0) - x_other_endian = x.byteswap().newbyteorder() + x_other_endian = byte_swap_array(np.__version__, x) b.CreateNumpyVector(x_other_endian) self.assertBuilderEquals( b, @@ -1287,7 +1331,7 @@ class TestByteLayout(unittest.TestCase): # Reverse endian: b = flatbuffers.Builder(0) - x_other_endian = x.byteswap().newbyteorder() + x_other_endian = byte_swap_array(np.__version__, x) b.CreateNumpyVector(x_other_endian) self.assertBuilderEquals( b, @@ -1361,7 +1405,7 @@ class TestByteLayout(unittest.TestCase): # Reverse endian: b = flatbuffers.Builder(0) - x_other_endian = x.byteswap().newbyteorder() + x_other_endian = byte_swap_array(np.__version__, x) b.CreateNumpyVector(x_other_endian) self.assertBuilderEquals( b, @@ -1427,7 +1471,7 @@ class TestByteLayout(unittest.TestCase): # Reverse endian: b = flatbuffers.Builder(0) - x_other_endian = x.byteswap().newbyteorder() + x_other_endian = byte_swap_array(np.__version__, x) b.CreateNumpyVector(x_other_endian) self.assertBuilderEquals( b, @@ -2712,7 +2756,7 @@ class TestNestedUnionTables(unittest.TestCase): def test_nested_union_tables(self): nestUnion = MyGame.Example.NestedUnion.NestedUnionTest.NestedUnionTestT() - nestUnion.name = b"testUnion1" + nestUnion.name = "testUnion1" nestUnion.id = 1 nestUnion.data = MyGame.Example.NestedUnion.Vec3.Vec3T() nestUnion.dataType = MyGame.Example.NestedUnion.Any.Any.Vec3 diff --git a/tests/service_test_generated.pyi b/tests/service_test_generated.pyi index 2189a94c6..de1aaeebb 100644 --- a/tests/service_test_generated.pyi +++ b/tests/service_test_generated.pyi @@ -2,7 +2,6 @@ from __future__ import annotations import flatbuffers -import flatbuffers import typing uoffset: typing.TypeAlias = flatbuffers.number_types.UOffsetTFlags.py_type