mirror of
https://github.com/google/flatbuffers.git
synced 2026-06-04 04:33:23 +00:00
[C++] Enable using struct and array of struct as key (#7741)
* add unit tests for support struct as key * make changes to parser and add helper function to generate comparator for struct * implement * add more unit tests * format * just a test * test done * rerun generator * restore build file * address comment * format * rebase * rebase * add more unit tests * rerun generator * address some comments * address comment * update * format * address comment Co-authored-by: Wen Sun <sunwen@google.com> Co-authored-by: Derek Bailey <derekbailey@google.com>
This commit is contained in:
@@ -2245,54 +2245,147 @@ class CppGenerator : public BaseGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
void GenComparatorForStruct(const StructDef &struct_def, size_t space_size,
|
||||
const std::string lhs_struct_literal,
|
||||
const std::string rhs_struct_literal) {
|
||||
code_.SetValue("LHS_PREFIX", lhs_struct_literal);
|
||||
code_.SetValue("RHS_PREFIX", rhs_struct_literal);
|
||||
std::string space(space_size, ' ');
|
||||
for (const auto &curr_field : struct_def.fields.vec) {
|
||||
const auto curr_field_name = Name(*curr_field);
|
||||
code_.SetValue("CURR_FIELD_NAME", curr_field_name);
|
||||
code_.SetValue("LHS", lhs_struct_literal + "_" + curr_field_name);
|
||||
code_.SetValue("RHS", rhs_struct_literal + "_" + curr_field_name);
|
||||
const bool is_scalar = IsScalar(curr_field->value.type.base_type);
|
||||
const bool is_array = IsArray(curr_field->value.type);
|
||||
const bool is_struct = IsStruct(curr_field->value.type);
|
||||
|
||||
// If encouter a key field, call KeyCompareWithValue to compare this field.
|
||||
if (curr_field->key) {
|
||||
code_ +=
|
||||
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
|
||||
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";
|
||||
|
||||
code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
|
||||
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
|
||||
continue;
|
||||
}
|
||||
|
||||
code_ +=
|
||||
space + "const auto {{LHS}} = {{LHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
|
||||
code_ +=
|
||||
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
|
||||
if (is_scalar) {
|
||||
code_ += space + "if ({{LHS}} != {{RHS}})";
|
||||
code_ += space +
|
||||
" return static_cast<int>({{LHS}} > {{RHS}}) - "
|
||||
"static_cast<int>({{LHS}} < {{RHS}});";
|
||||
} else if (is_array) {
|
||||
const auto &elem_type = curr_field->value.type.VectorType();
|
||||
code_ +=
|
||||
space +
|
||||
"for (::flatbuffers::uoffset_t i = 0; i < {{LHS}}->size(); i++) {";
|
||||
code_ += space + " const auto {{LHS}}_elem = {{LHS}}->Get(i);";
|
||||
code_ += space + " const auto {{RHS}}_elem = {{RHS}}->Get(i);";
|
||||
if (IsScalar(elem_type.base_type)) {
|
||||
code_ += space + " if ({{LHS}}_elem != {{RHS}}_elem)";
|
||||
code_ += space +
|
||||
" return static_cast<int>({{LHS}}_elem > {{RHS}}_elem) - "
|
||||
"static_cast<int>({{LHS}}_elem < {{RHS}}_elem);";
|
||||
code_ += space + "}";
|
||||
|
||||
} else if (IsStruct(elem_type)) {
|
||||
if (curr_field->key) {
|
||||
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";
|
||||
code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
|
||||
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
|
||||
continue;
|
||||
}
|
||||
GenComparatorForStruct(
|
||||
*curr_field->value.type.struct_def, space_size + 2,
|
||||
code_.GetValue("LHS") + "_elem", code_.GetValue("RHS") + "_elem");
|
||||
|
||||
code_ += space + "}";
|
||||
}
|
||||
|
||||
} else if (is_struct) {
|
||||
GenComparatorForStruct(*curr_field->value.type.struct_def, space_size,
|
||||
code_.GetValue("LHS"), code_.GetValue("RHS"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate CompareWithValue method for a key field.
|
||||
void GenKeyFieldMethods(const FieldDef &field) {
|
||||
FLATBUFFERS_ASSERT(field.key);
|
||||
const bool is_string = IsString(field.value.type);
|
||||
const bool is_array = IsArray(field.value.type);
|
||||
|
||||
const bool is_struct = IsStruct(field.value.type);
|
||||
// Generate KeyCompareLessThan function
|
||||
code_ +=
|
||||
" bool KeyCompareLessThan(const {{STRUCT_NAME}} * const o) const {";
|
||||
if (is_string) {
|
||||
// use operator< of ::flatbuffers::String
|
||||
code_ += " return *{{FIELD_NAME}}() < *o->{{FIELD_NAME}}();";
|
||||
} else if (is_array) {
|
||||
const auto &elem_type = field.value.type.VectorType();
|
||||
if (IsScalar(elem_type.base_type)) {
|
||||
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
|
||||
}
|
||||
} else {
|
||||
} else if (is_array || is_struct) {
|
||||
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
|
||||
}else {
|
||||
code_ += " return {{FIELD_NAME}}() < o->{{FIELD_NAME}}();";
|
||||
}
|
||||
code_ += " }";
|
||||
|
||||
// Generate KeyCompareWithValue function
|
||||
if (is_string) {
|
||||
code_ += " int KeyCompareWithValue(const char *_{{FIELD_NAME}}) const {";
|
||||
code_ += " return strcmp({{FIELD_NAME}}()->c_str(), _{{FIELD_NAME}});";
|
||||
} else if (is_array) {
|
||||
const auto &elem_type = field.value.type.VectorType();
|
||||
std::string input_type = "::flatbuffers::Array<" +
|
||||
GenTypeGet(elem_type, "", "", " ", false) +
|
||||
", " + NumToString(elem_type.fixed_length) + ">";
|
||||
code_.SetValue("INPUT_TYPE", input_type);
|
||||
code_ +=
|
||||
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
|
||||
") const {";
|
||||
code_ +=
|
||||
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
|
||||
code_ +=
|
||||
" for (::flatbuffers::uoffset_t i = 0; i < "
|
||||
"curr_{{FIELD_NAME}}->size(); i++) {";
|
||||
|
||||
if (IsScalar(elem_type.base_type)) {
|
||||
std::string input_type = "::flatbuffers::Array<" +
|
||||
GenTypeBasic(elem_type, false) + ", " +
|
||||
NumToString(elem_type.fixed_length) + ">";
|
||||
code_.SetValue("INPUT_TYPE", input_type);
|
||||
code_ +=
|
||||
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
|
||||
") const {";
|
||||
code_ +=
|
||||
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
|
||||
code_ +=
|
||||
" for (::flatbuffers::uoffset_t i = 0; i < "
|
||||
"curr_{{FIELD_NAME}}->size(); i++) {";
|
||||
code_ += " const auto lhs = curr_{{FIELD_NAME}}->Get(i);";
|
||||
code_ += " const auto rhs = _{{FIELD_NAME}}->Get(i);";
|
||||
code_ += " if(lhs != rhs)";
|
||||
code_ += " if (lhs != rhs)";
|
||||
code_ +=
|
||||
" return static_cast<int>(lhs > rhs)"
|
||||
" - static_cast<int>(lhs < rhs);";
|
||||
code_ += " }";
|
||||
code_ += " return 0;";
|
||||
} else if (IsStruct(elem_type)) {
|
||||
code_ +=
|
||||
" const auto &lhs_{{FIELD_NAME}} = "
|
||||
"*(curr_{{FIELD_NAME}}->Get(i));";
|
||||
code_ +=
|
||||
" const auto &rhs_{{FIELD_NAME}} = *(_{{FIELD_NAME}}->Get(i));";
|
||||
GenComparatorForStruct(*elem_type.struct_def, 6,
|
||||
"lhs_" + code_.GetValue("FIELD_NAME"),
|
||||
"rhs_" + code_.GetValue("FIELD_NAME"));
|
||||
}
|
||||
code_ += " }";
|
||||
code_ += " return 0;";
|
||||
} else if (is_struct) {
|
||||
const auto *struct_def = field.value.type.struct_def;
|
||||
code_.SetValue("INPUT_TYPE",
|
||||
GenTypeGet(field.value.type, "", "", "", false));
|
||||
code_ +=
|
||||
" int KeyCompareWithValue(const {{INPUT_TYPE}} &_{{FIELD_NAME}}) "
|
||||
"const {";
|
||||
code_ += " const auto &lhs_{{FIELD_NAME}} = {{FIELD_NAME}}();";
|
||||
code_ += " const auto &rhs_{{FIELD_NAME}} = _{{FIELD_NAME}};";
|
||||
GenComparatorForStruct(*struct_def, 4,
|
||||
"lhs_" + code_.GetValue("FIELD_NAME"),
|
||||
"rhs_" + code_.GetValue("FIELD_NAME"));
|
||||
code_ += " return 0;";
|
||||
|
||||
} else {
|
||||
FLATBUFFERS_ASSERT(IsScalar(field.value.type.base_type));
|
||||
auto type = GenTypeBasic(field.value.type, false);
|
||||
|
||||
Reference in New Issue
Block a user