[C++] Enable using struct and array of struct as key (#7741)

* add unit tests for support struct as key

* make changes to parser and add helper function to generate comparator for struct

* implement

* add more unit tests

* format

* just a test

* test done

* rerun generator

* restore build file

* address comment

* format

* rebase

* rebase

* add more unit tests

* rerun generator

* address some comments

* address comment

* update

* format

* address comment

Co-authored-by: Wen Sun <sunwen@google.com>
Co-authored-by: Derek Bailey <derekbailey@google.com>
This commit is contained in:
Wen Sun
2023-01-24 16:37:13 -08:00
committed by GitHub
parent ee848a02e1
commit 802a3a056a
7 changed files with 853 additions and 37 deletions

View File

@@ -2245,54 +2245,147 @@ class CppGenerator : public BaseGenerator {
}
}
void GenComparatorForStruct(const StructDef &struct_def, size_t space_size,
const std::string lhs_struct_literal,
const std::string rhs_struct_literal) {
code_.SetValue("LHS_PREFIX", lhs_struct_literal);
code_.SetValue("RHS_PREFIX", rhs_struct_literal);
std::string space(space_size, ' ');
for (const auto &curr_field : struct_def.fields.vec) {
const auto curr_field_name = Name(*curr_field);
code_.SetValue("CURR_FIELD_NAME", curr_field_name);
code_.SetValue("LHS", lhs_struct_literal + "_" + curr_field_name);
code_.SetValue("RHS", rhs_struct_literal + "_" + curr_field_name);
const bool is_scalar = IsScalar(curr_field->value.type.base_type);
const bool is_array = IsArray(curr_field->value.type);
const bool is_struct = IsStruct(curr_field->value.type);
// If encouter a key field, call KeyCompareWithValue to compare this field.
if (curr_field->key) {
code_ +=
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";
code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
continue;
}
code_ +=
space + "const auto {{LHS}} = {{LHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
code_ +=
space + "const auto {{RHS}} = {{RHS_PREFIX}}.{{CURR_FIELD_NAME}}();";
if (is_scalar) {
code_ += space + "if ({{LHS}} != {{RHS}})";
code_ += space +
" return static_cast<int>({{LHS}} > {{RHS}}) - "
"static_cast<int>({{LHS}} < {{RHS}});";
} else if (is_array) {
const auto &elem_type = curr_field->value.type.VectorType();
code_ +=
space +
"for (::flatbuffers::uoffset_t i = 0; i < {{LHS}}->size(); i++) {";
code_ += space + " const auto {{LHS}}_elem = {{LHS}}->Get(i);";
code_ += space + " const auto {{RHS}}_elem = {{RHS}}->Get(i);";
if (IsScalar(elem_type.base_type)) {
code_ += space + " if ({{LHS}}_elem != {{RHS}}_elem)";
code_ += space +
" return static_cast<int>({{LHS}}_elem > {{RHS}}_elem) - "
"static_cast<int>({{LHS}}_elem < {{RHS}}_elem);";
code_ += space + "}";
} else if (IsStruct(elem_type)) {
if (curr_field->key) {
code_ += space + "const auto {{CURR_FIELD_NAME}}_compare_result = {{LHS_PREFIX}}.KeyCompareWithValue({{RHS}});";
code_ += space + "if ({{CURR_FIELD_NAME}}_compare_result != 0)";
code_ += space + " return {{CURR_FIELD_NAME}}_compare_result;";
continue;
}
GenComparatorForStruct(
*curr_field->value.type.struct_def, space_size + 2,
code_.GetValue("LHS") + "_elem", code_.GetValue("RHS") + "_elem");
code_ += space + "}";
}
} else if (is_struct) {
GenComparatorForStruct(*curr_field->value.type.struct_def, space_size,
code_.GetValue("LHS"), code_.GetValue("RHS"));
}
}
}
// Generate CompareWithValue method for a key field.
void GenKeyFieldMethods(const FieldDef &field) {
FLATBUFFERS_ASSERT(field.key);
const bool is_string = IsString(field.value.type);
const bool is_array = IsArray(field.value.type);
const bool is_struct = IsStruct(field.value.type);
// Generate KeyCompareLessThan function
code_ +=
" bool KeyCompareLessThan(const {{STRUCT_NAME}} * const o) const {";
if (is_string) {
// use operator< of ::flatbuffers::String
code_ += " return *{{FIELD_NAME}}() < *o->{{FIELD_NAME}}();";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
if (IsScalar(elem_type.base_type)) {
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
}
} else {
} else if (is_array || is_struct) {
code_ += " return KeyCompareWithValue(o->{{FIELD_NAME}}()) < 0;";
}else {
code_ += " return {{FIELD_NAME}}() < o->{{FIELD_NAME}}();";
}
code_ += " }";
// Generate KeyCompareWithValue function
if (is_string) {
code_ += " int KeyCompareWithValue(const char *_{{FIELD_NAME}}) const {";
code_ += " return strcmp({{FIELD_NAME}}()->c_str(), _{{FIELD_NAME}});";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
std::string input_type = "::flatbuffers::Array<" +
GenTypeGet(elem_type, "", "", " ", false) +
", " + NumToString(elem_type.fixed_length) + ">";
code_.SetValue("INPUT_TYPE", input_type);
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
") const {";
code_ +=
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ +=
" for (::flatbuffers::uoffset_t i = 0; i < "
"curr_{{FIELD_NAME}}->size(); i++) {";
if (IsScalar(elem_type.base_type)) {
std::string input_type = "::flatbuffers::Array<" +
GenTypeBasic(elem_type, false) + ", " +
NumToString(elem_type.fixed_length) + ">";
code_.SetValue("INPUT_TYPE", input_type);
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} *_{{FIELD_NAME}}"
") const {";
code_ +=
" const {{INPUT_TYPE}} *curr_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ +=
" for (::flatbuffers::uoffset_t i = 0; i < "
"curr_{{FIELD_NAME}}->size(); i++) {";
code_ += " const auto lhs = curr_{{FIELD_NAME}}->Get(i);";
code_ += " const auto rhs = _{{FIELD_NAME}}->Get(i);";
code_ += " if(lhs != rhs)";
code_ += " if (lhs != rhs)";
code_ +=
" return static_cast<int>(lhs > rhs)"
" - static_cast<int>(lhs < rhs);";
code_ += " }";
code_ += " return 0;";
} else if (IsStruct(elem_type)) {
code_ +=
" const auto &lhs_{{FIELD_NAME}} = "
"*(curr_{{FIELD_NAME}}->Get(i));";
code_ +=
" const auto &rhs_{{FIELD_NAME}} = *(_{{FIELD_NAME}}->Get(i));";
GenComparatorForStruct(*elem_type.struct_def, 6,
"lhs_" + code_.GetValue("FIELD_NAME"),
"rhs_" + code_.GetValue("FIELD_NAME"));
}
code_ += " }";
code_ += " return 0;";
} else if (is_struct) {
const auto *struct_def = field.value.type.struct_def;
code_.SetValue("INPUT_TYPE",
GenTypeGet(field.value.type, "", "", "", false));
code_ +=
" int KeyCompareWithValue(const {{INPUT_TYPE}} &_{{FIELD_NAME}}) "
"const {";
code_ += " const auto &lhs_{{FIELD_NAME}} = {{FIELD_NAME}}();";
code_ += " const auto &rhs_{{FIELD_NAME}} = _{{FIELD_NAME}};";
GenComparatorForStruct(*struct_def, 4,
"lhs_" + code_.GetValue("FIELD_NAME"),
"rhs_" + code_.GetValue("FIELD_NAME"));
code_ += " return 0;";
} else {
FLATBUFFERS_ASSERT(IsScalar(field.value.type.base_type));
auto type = GenTypeBasic(field.value.type, false);