From 0cfb7eb80b05c058e19e50fb575263908e601469 Mon Sep 17 00:00:00 2001 From: mpawlowski-eyeo <73581124+mpawlowski-eyeo@users.noreply.github.com> Date: Mon, 25 Mar 2024 18:39:51 +0100 Subject: [PATCH] Fix handling non null-terminated string_views in LookupByKey (#8203) * Reproduce the error in a unit test Reproduces #8200 * Overload KeyCompareWithValue to work for string-like objects This fixes #8200. * Extra tests --------- Co-authored-by: Derek Bailey --- include/flatbuffers/reflection_generated.h | 42 +++++++++++++++++++ src/idl_gen_cpp.cpp | 12 ++++++ .../generated_cpp17/monster_test_generated.h | 6 +++ tests/key_field/key_field_sample_generated.h | 6 +++ tests/monster_test.cpp | 27 ++++++++++++ tests/monster_test_generated.h | 6 +++ .../ext_only/monster_test_generated.hpp | 6 +++ .../filesuffix_only/monster_test_suffix.h | 6 +++ .../monster_test_suffix.hpp | 6 +++ 9 files changed, 117 insertions(+) diff --git a/include/flatbuffers/reflection_generated.h b/include/flatbuffers/reflection_generated.h index 79e15a5b7..9035128a0 100644 --- a/include/flatbuffers/reflection_generated.h +++ b/include/flatbuffers/reflection_generated.h @@ -274,6 +274,12 @@ struct KeyValue FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_key) const { return strcmp(key()->c_str(), _key); } + template + int KeyCompareWithValue(const StringType& _key) const { + if (key()->c_str() < _key) return -1; + if (_key < key()->c_str()) return 1; + return 0; + } const ::flatbuffers::String *value() const { return GetPointer(VT_VALUE); } @@ -464,6 +470,12 @@ struct Enum FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector<::flatbuffers::Offset> *values() const { return GetPointer> *>(VT_VALUES); } @@ -616,6 +628,12 @@ struct Field FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const reflection::Type *type() const { return GetPointer(VT_TYPE); } @@ -834,6 +852,12 @@ struct Object FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector<::flatbuffers::Offset> *fields() const { return GetPointer> *>(VT_FIELDS); } @@ -986,6 +1010,12 @@ struct RPCCall FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const reflection::Object *request() const { return GetPointer(VT_REQUEST); } @@ -1102,6 +1132,12 @@ struct Service FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector<::flatbuffers::Offset> *calls() const { return GetPointer> *>(VT_CALLS); } @@ -1221,6 +1257,12 @@ struct SchemaFile FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_filename) const { return strcmp(filename()->c_str(), _filename); } + template + int KeyCompareWithValue(const StringType& _filename) const { + if (filename()->c_str() < _filename) return -1; + if (_filename < filename()->c_str()) return 1; + return 0; + } /// Names of included files, relative to project root. const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *included_filenames() const { return GetPointer> *>(VT_INCLUDED_FILENAMES); diff --git a/src/idl_gen_cpp.cpp b/src/idl_gen_cpp.cpp index 621ea191a..301c08d2c 100644 --- a/src/idl_gen_cpp.cpp +++ b/src/idl_gen_cpp.cpp @@ -2417,8 +2417,20 @@ class CppGenerator : public BaseGenerator { // Generate KeyCompareWithValue function if (is_string) { + // Compares key against a null-terminated char array. code_ += " int KeyCompareWithValue(const char *_{{FIELD_NAME}}) const {"; code_ += " return strcmp({{FIELD_NAME}}()->c_str(), _{{FIELD_NAME}});"; + code_ += " }"; + // Compares key against any string-like object (e.g. std::string_view or + // std::string) that implements operator< comparison with const char*. + code_ += " template"; + code_ += + " int KeyCompareWithValue(const StringType& _{{FIELD_NAME}}) const " + "{"; + code_ += + " if ({{FIELD_NAME}}()->c_str() < _{{FIELD_NAME}}) return -1;"; + code_ += " if (_{{FIELD_NAME}} < {{FIELD_NAME}}()->c_str()) return 1;"; + code_ += " return 0;"; } else if (is_array) { const auto &elem_type = field.value.type.VectorType(); std::string input_type = "::flatbuffers::Array<" + diff --git a/tests/cpp17/generated_cpp17/monster_test_generated.h b/tests/cpp17/generated_cpp17/monster_test_generated.h index be3371b9d..acbfb9728 100644 --- a/tests/cpp17/generated_cpp17/monster_test_generated.h +++ b/tests/cpp17/generated_cpp17/monster_test_generated.h @@ -1436,6 +1436,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector *inventory() const { return GetPointer *>(VT_INVENTORY); } diff --git a/tests/key_field/key_field_sample_generated.h b/tests/key_field/key_field_sample_generated.h index ea8e0411b..b98eaeb98 100644 --- a/tests/key_field/key_field_sample_generated.h +++ b/tests/key_field/key_field_sample_generated.h @@ -598,6 +598,12 @@ struct FooTable FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_c) const { return strcmp(c()->c_str(), _c); } + template + int KeyCompareWithValue(const StringType& _c) const { + if (c()->c_str() < _c) return -1; + if (_c < c()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector *d() const { return GetPointer *>(VT_D); } diff --git a/tests/monster_test.cpp b/tests/monster_test.cpp index 8ec031a3d..b546d2074 100644 --- a/tests/monster_test.cpp +++ b/tests/monster_test.cpp @@ -313,6 +313,33 @@ void AccessFlatBufferTest(const uint8_t *flatbuf, size_t length, bool pooled) { TEST_NOTNULL(vecoftables->LookupByKey("Fred")); TEST_NOTNULL(vecoftables->LookupByKey("Wilma")); + // Verify the same objects are returned for char*-based and string-based + // lookups. + TEST_EQ(vecoftables->LookupByKey("Barney"), + vecoftables->LookupByKey(std::string("Barney"))); + TEST_EQ(vecoftables->LookupByKey("Fred"), + vecoftables->LookupByKey(std::string("Fred"))); + TEST_EQ(vecoftables->LookupByKey("Wilma"), + vecoftables->LookupByKey(std::string("Wilma"))); + +#ifdef FLATBUFFERS_HAS_STRING_VIEW + // Tests for LookupByKey with a key that is a truncated + // version of a longer, invalid key. + const std::string invalid_key = "Barney123"; + std::string_view valid_truncated_key = invalid_key; + valid_truncated_key.remove_suffix(3); // "Barney" + TEST_NOTNULL(vecoftables->LookupByKey(valid_truncated_key)); + TEST_EQ(vecoftables->LookupByKey("Barney"), + vecoftables->LookupByKey(valid_truncated_key)); + + // Tests for LookupByKey with a key that is a truncated + // version of a longer, valid key. + const std::string valid_key = "Barney"; + std::string_view invalid_truncated_key = valid_key; + invalid_truncated_key.remove_suffix(3); // "Bar" + TEST_NULL(vecoftables->LookupByKey(invalid_truncated_key)); +#endif // FLATBUFFERS_HAS_STRING_VIEW + // Test accessing a vector of sorted structs auto vecofstructs = monster->testarrayofsortedstruct(); if (vecofstructs) { // not filled in monster_test.bfbs diff --git a/tests/monster_test_generated.h b/tests/monster_test_generated.h index e93010e3b..923202983 100644 --- a/tests/monster_test_generated.h +++ b/tests/monster_test_generated.h @@ -1432,6 +1432,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector *inventory() const { return GetPointer *>(VT_INVENTORY); } diff --git a/tests/monster_test_suffix/ext_only/monster_test_generated.hpp b/tests/monster_test_suffix/ext_only/monster_test_generated.hpp index bfd7a71ef..64e506518 100644 --- a/tests/monster_test_suffix/ext_only/monster_test_generated.hpp +++ b/tests/monster_test_suffix/ext_only/monster_test_generated.hpp @@ -1423,6 +1423,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector *inventory() const { return GetPointer *>(VT_INVENTORY); } diff --git a/tests/monster_test_suffix/filesuffix_only/monster_test_suffix.h b/tests/monster_test_suffix/filesuffix_only/monster_test_suffix.h index bfd7a71ef..64e506518 100644 --- a/tests/monster_test_suffix/filesuffix_only/monster_test_suffix.h +++ b/tests/monster_test_suffix/filesuffix_only/monster_test_suffix.h @@ -1423,6 +1423,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector *inventory() const { return GetPointer *>(VT_INVENTORY); } diff --git a/tests/monster_test_suffix/monster_test_suffix.hpp b/tests/monster_test_suffix/monster_test_suffix.hpp index bfd7a71ef..64e506518 100644 --- a/tests/monster_test_suffix/monster_test_suffix.hpp +++ b/tests/monster_test_suffix/monster_test_suffix.hpp @@ -1423,6 +1423,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int KeyCompareWithValue(const char *_name) const { return strcmp(name()->c_str(), _name); } + template + int KeyCompareWithValue(const StringType& _name) const { + if (name()->c_str() < _name) return -1; + if (_name < name()->c_str()) return 1; + return 0; + } const ::flatbuffers::Vector *inventory() const { return GetPointer *>(VT_INVENTORY); }