Bugfix __eq__ for numpy data types (#8646)

* [Python] Sync PythonTest.sh flags with generate_code.py

* [Python] Update generated code to latest flatc version for tests

* [Python] Fix test support for numpy newer than 2.0.0

* [Python] Remove unused variable

* [Python] Fix __eq__ for numpy arrays

* [Python] Run clang-format over the entire file
This commit is contained in:
Felix
2025-07-26 19:31:38 +02:00
committed by GitHub
parent 860d645349
commit f32a7dcbd2
5 changed files with 108 additions and 88 deletions

View File

@@ -52,9 +52,9 @@ class PythonStubGenerator {
public:
PythonStubGenerator(const Parser &parser, const std::string &path,
const Version &version)
: parser_{parser},
namer_{WithFlagOptions(kStubConfig, parser.opts, path),
Keywords(version)},
: parser_{ parser },
namer_{ WithFlagOptions(kStubConfig, parser.opts, path),
Keywords(version) },
version_(version) {}
bool Generate() {
@@ -140,8 +140,7 @@ class PythonStubGenerator {
return module;
}
template <typename T>
std::string ModuleFor(const T *def) const {
template<typename T> std::string ModuleFor(const T *def) const {
if (parser_.opts.one_file) return ModuleForFile(def->file);
return namer_.NamespacedType(*def);
}
@@ -165,7 +164,7 @@ class PythonStubGenerator {
return "None";
}
template <typename F>
template<typename F>
std::string UnionType(const EnumDef &enum_def, Imports *imports,
F type) const {
imports->Import("typing");
@@ -181,14 +180,9 @@ class PythonStubGenerator {
result += import.name;
break;
}
case BASE_TYPE_STRING:
result += "str";
break;
case BASE_TYPE_NONE:
result += "None";
break;
default:
break;
case BASE_TYPE_STRING: result += "str"; break;
case BASE_TYPE_NONE: result += "None"; break;
default: break;
}
}
return "typing.Union[" + result + "]";
@@ -229,18 +223,14 @@ class PythonStubGenerator {
namer_.Type(*type.struct_def));
return import.name;
}
case BASE_TYPE_STRING:
return "str";
case BASE_TYPE_STRING: return "str";
case BASE_TYPE_ARRAY:
case BASE_TYPE_VECTOR: {
imports->Import("typing");
return "typing.List[" + TypeOf(type.VectorType(), imports) + "]";
}
case BASE_TYPE_UNION:
return UnionType(*type.enum_def, imports);
default:
FLATBUFFERS_ASSERT(0);
return "";
case BASE_TYPE_UNION: return UnionType(*type.enum_def, imports);
default: FLATBUFFERS_ASSERT(0); return "";
}
}
@@ -262,8 +252,7 @@ class PythonStubGenerator {
namer_.ObjectType(*field_type.struct_def));
return field_name + ": " + import.name + " | None";
}
case BASE_TYPE_STRING:
return field_name + ": str | None";
case BASE_TYPE_STRING: return field_name + ": str | None";
case BASE_TYPE_ARRAY:
case BASE_TYPE_VECTOR: {
imports->Import("typing");
@@ -282,8 +271,7 @@ class PythonStubGenerator {
case BASE_TYPE_UNION:
return field_name + ": " +
UnionObjectType(*field->value.type.enum_def, imports);
default:
return field_name;
default: return field_name;
}
}
@@ -312,9 +300,7 @@ class PythonStubGenerator {
field_type = "'" + import_.name + "' | None";
break;
}
case BASE_TYPE_STRING:
field_type = "str | None";
break;
case BASE_TYPE_STRING: field_type = "str | None"; break;
case BASE_TYPE_ARRAY:
case BASE_TYPE_VECTOR: {
imports->Import("typing");
@@ -334,9 +320,7 @@ class PythonStubGenerator {
case BASE_TYPE_UNION:
field_type = UnionObjectType(*type.enum_def, imports);
break;
default:
field_type = "typing.Any";
break;
default: field_type = "typing.Any"; break;
}
}
stub << " " << field_name << ": " << field_type << " = ...,\n";
@@ -485,8 +469,7 @@ class PythonStubGenerator {
stub << " def " << name << "(self) -> table.Table | None: ...\n";
break;
}
default:
break;
default: break;
}
}
}
@@ -530,9 +513,7 @@ class PythonStubGenerator {
stub << '\n';
stub << "def Create" + namer_.Type(*struct_def)
<< "(builder: flatbuffers.Builder";
for (const std::string &arg : args) {
stub << ", " << arg;
}
for (const std::string &arg : args) { stub << ", " << arg; }
stub << ") -> uoffset: ...\n";
}
@@ -610,11 +591,10 @@ class PythonStubGenerator {
imports->Import("typing", "cast");
if (version_.major == 3){
if (version_.major == 3) {
imports->Import("enum", "IntEnum");
stub << "(IntEnum)";
}
else {
} else {
stub << "(object)";
}
@@ -637,16 +617,15 @@ class PythonStubGenerator {
ss << "from __future__ import annotations\n";
ss << '\n';
ss << "import flatbuffers\n";
if (parser_.opts.python_gen_numpy) {
ss << "import numpy as np\n";
}
if (parser_.opts.python_gen_numpy) { ss << "import numpy as np\n"; }
ss << '\n';
std::set<std::string> modules;
std::map<std::string, std::set<std::string>> names_by_module;
for (const Import &import : imports.imports) {
if (import.IsLocal()) continue; // skip all local imports
if (import.module == "flatbuffers" && import.name == "") continue; // skip double include hardcoded flatbuffers
if (import.module == "flatbuffers" && import.name == "")
continue; // skip double include hardcoded flatbuffers
if (import.name == "") {
modules.insert(import.module);
} else {
@@ -686,7 +665,8 @@ class PythonStubGenerator {
const Parser &parser_;
const IdlNamer namer_;
const Version version_;
};} // namespace
};
} // namespace
class PythonGenerator : public BaseGenerator {
public:
@@ -695,8 +675,8 @@ class PythonGenerator : public BaseGenerator {
: BaseGenerator(parser, path, file_name, "" /* not used */,
"" /* not used */, "py"),
float_const_gen_("float('nan')", "float('inf')", "float('-inf')"),
namer_(WithFlagOptions(kConfig, parser.opts, path),
Keywords(version)) {}
namer_(WithFlagOptions(kConfig, parser.opts, path), Keywords(version)) {
}
// Most field accessors need to retrieve and test the field offset first,
// this is the prefix code for that.
@@ -886,9 +866,8 @@ class PythonGenerator : public BaseGenerator {
GenReceiver(struct_def, code_ptr);
code += namer_.Method(field);
const ImportMapEntry import_entry = {
GenPackageReference(field.value.type), TypeName(field)
};
const ImportMapEntry import_entry = { GenPackageReference(field.value.type),
TypeName(field) };
if (parser_.opts.python_typing) {
const std::string return_type = ReturnType(struct_def, field);
@@ -948,9 +927,8 @@ class PythonGenerator : public BaseGenerator {
GenReceiver(struct_def, code_ptr);
code += namer_.Method(field) + "(self)";
const ImportMapEntry import_entry = {
GenPackageReference(field.value.type), TypeName(field)
};
const ImportMapEntry import_entry = { GenPackageReference(field.value.type),
TypeName(field) };
if (parser_.opts.python_typing) {
const std::string return_type = ReturnType(struct_def, field);
@@ -1036,11 +1014,8 @@ class PythonGenerator : public BaseGenerator {
code += Indent + Indent + "return None\n\n";
}
template <typename T>
std::string ModuleFor(const T *def) const {
if (!parser_.opts.one_file) {
return namer_.NamespacedType(*def);
}
template<typename T> std::string ModuleFor(const T *def) const {
if (!parser_.opts.one_file) { return namer_.NamespacedType(*def); }
std::string filename =
StripExtension(def->file) + parser_.opts.filename_suffix;
@@ -1070,9 +1045,8 @@ class PythonGenerator : public BaseGenerator {
GenReceiver(struct_def, code_ptr);
code += namer_.Method(field);
const ImportMapEntry import_entry = {
GenPackageReference(field.value.type), TypeName(field)
};
const ImportMapEntry import_entry = { GenPackageReference(field.value.type),
TypeName(field) };
if (parser_.opts.python_typing) {
const std::string return_type = ReturnType(struct_def, field);
@@ -1195,8 +1169,7 @@ class PythonGenerator : public BaseGenerator {
std::string qualified_name = NestedFlatbufferType(unqualified_name);
if (qualified_name.empty()) { qualified_name = nested->constant; }
const ImportMapEntry import_entry = { qualified_name,
unqualified_name };
const ImportMapEntry import_entry = { qualified_name, unqualified_name };
auto &code = *code_ptr;
GenReceiver(struct_def, code_ptr);
@@ -1808,8 +1781,8 @@ class PythonGenerator : public BaseGenerator {
}
field_type = "Optional[List[" + field_type + "]";
} else {
field_type =
"Optional[List[" + GetBasePythonTypeForScalarAndString(base_type) + "]]";
field_type = "Optional[List[" +
GetBasePythonTypeForScalarAndString(base_type) + "]]";
}
}
@@ -1858,11 +1831,12 @@ class PythonGenerator : public BaseGenerator {
const auto field_field = namer_.Field(field);
// Build signature with keyword arguments, type hints, and default values.
signature_params += GenIndents(2) + field_field + " = " + default_value + ",";
signature_params +=
GenIndents(2) + field_field + " = " + default_value + ",";
// Build the body of the __init__ method.
init_body += GenIndents(2) + "self." + field_field + " = " + field_field +
" # type: " + field_type;
" # type: " + field_type;
}
// Writes __init__ method.
@@ -1954,10 +1928,16 @@ class PythonGenerator : public BaseGenerator {
auto &field = **it;
if (field.deprecated) continue;
// Wrties the comparison statement for this field.
const auto field_field = namer_.Field(field);
code += " and \\" + GenIndents(3) + "self." + field_field +
" == " + "other." + field_field;
// Writes the comparison statement for this field.
const auto field_name = namer_.Field(field);
if (parser_.opts.python_gen_numpy &&
field.value.type.base_type == BASE_TYPE_VECTOR) {
code += " and \\" + GenIndents(3) + "np.array_equal(self." +
field_name + ", " + "other." + field_name + ")";
} else {
code += " and \\" + GenIndents(3) + "self." + field_name +
" == " + "other." + field_name;
}
}
code += "\n";
}
@@ -2154,7 +2134,6 @@ class PythonGenerator : public BaseGenerator {
auto &field = **it;
if (field.deprecated) continue;
auto field_type = TypeName(field);
switch (field.value.type.base_type) {
case BASE_TYPE_STRUCT: {
GenUnPackForStruct(struct_def, field, &code);
@@ -2338,9 +2317,9 @@ class PythonGenerator : public BaseGenerator {
if (parser_.opts.python_gen_numpy) {
code_prefix += GenIndents(3) + "if np is not None and type(self." +
field_field + ") is np.ndarray:";
field_field + ") is np.ndarray:";
code_prefix += GenIndents(4) + field_field +
" = builder.CreateNumpyVector(self." + field_field + ")";
" = builder.CreateNumpyVector(self." + field_field + ")";
code_prefix += GenIndents(3) + "else:";
GenPackForScalarVectorFieldHelper(struct_def, field, code_prefix_ptr, 4);
code_prefix += "(self." + field_field + "[i])";
@@ -2788,9 +2767,7 @@ class PythonGenerator : public BaseGenerator {
}
}
}
if (parser_.opts.python_gen_numpy) {
code += "np = import_numpy()\n\n";
}
if (parser_.opts.python_gen_numpy) { code += "np = import_numpy()\n\n"; }
}
}
@@ -2828,7 +2805,7 @@ class PythonGenerator : public BaseGenerator {
static bool GeneratePython(const Parser &parser, const std::string &path,
const std::string &file_name) {
python::Version version{parser.opts.python_version};
python::Version version{ parser.opts.python_version };
if (!version.IsValid()) return false;
python::PythonGenerator generator(parser, path, file_name, version);