Add key lookup support for tables in Go (#7644)

* Add support for key lookup for tables in Go

* Run clang format

* Run go fmt on tests

* Remove TODO in tests

* Update LookupByKey API

* Update LookupByKey API

* Don't use resolvePointer in expectEq

* Use generated getters instead of reading values directly from buffer

* Fix typo

Co-authored-by: Derek Bailey <derekbailey@google.com>
This commit is contained in:
Michael Le
2022-11-22 14:08:19 -08:00
committed by GitHub
parent 1cba8b2b49
commit 60975d6f7e
10 changed files with 365 additions and 7 deletions

View File

@@ -1,5 +1,7 @@
package flatbuffers
import "sort"
// Builder is a state machine for creating FlatBuffer objects.
// Use a Builder to construct object(s) starting from leaf nodes.
//
@@ -315,6 +317,25 @@ func (b *Builder) EndVector(vectorNumElems int) UOffsetT {
return b.Offset()
}
// CreateVectorOfTables serializes slice of table offsets into a vector.
func (b *Builder) CreateVectorOfTables(offsets []UOffsetT) UOffsetT {
b.assertNotNested()
b.StartVector(4, len(offsets), 4)
for i := len(offsets) - 1; i >= 0; i-- {
b.PrependUOffsetT(offsets[i])
}
return b.EndVector(len(offsets))
}
type KeyCompare func(o1, o2 UOffsetT, buf []byte) bool
func (b *Builder) CreateVectorOfSortedTables(offsets []UOffsetT, keyCompare KeyCompare) UOffsetT {
sort.Slice(offsets, func(i, j int) bool {
return keyCompare(offsets[i], offsets[j], b.Bytes)
})
return b.CreateVectorOfTables(offsets)
}
// CreateSharedString Checks if the string is already written
// to the buffer before calling CreateString
func (b *Builder) CreateSharedString(s string) UOffsetT {

View File

@@ -23,3 +23,8 @@ func GetSizePrefixedRootAs(buf []byte, offset UOffsetT, fb FlatBuffer) {
func GetSizePrefix(buf []byte, offset UOffsetT) uint32 {
return GetUint32(buf[offset:])
}
// GetIndirectOffset retrives the relative offset in the provided buffer stored at `offset`.
func GetIndirectOffset(buf []byte, offset UOffsetT) UOffsetT {
return offset + GetUOffsetT(buf[offset:])
}

View File

@@ -21,6 +21,7 @@
#include <sstream>
#include <string>
#include "flatbuffers/base.h"
#include "flatbuffers/code_generators.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/idl.h"
@@ -104,6 +105,7 @@ class GoGenerator : public BaseGenerator {
++it) {
tracked_imported_namespaces_.clear();
needs_math_import_ = false;
needs_bytes_import_ = false;
needs_imports = false;
std::string enumcode;
GenEnum(**it, &enumcode);
@@ -124,6 +126,7 @@ class GoGenerator : public BaseGenerator {
it != parser_.structs_.vec.end(); ++it) {
tracked_imported_namespaces_.clear();
needs_math_import_ = false;
needs_bytes_import_ = false;
std::string declcode;
GenStruct(**it, &declcode);
if (parser_.opts.one_file) {
@@ -158,6 +161,7 @@ class GoGenerator : public BaseGenerator {
};
std::set<const Definition *, NamespacePtrLess> tracked_imported_namespaces_;
bool needs_math_import_ = false;
bool needs_bytes_import_ = false;
// Most field accessors need to retrieve and test the field offset first,
// this is the prefix code for that.
@@ -489,6 +493,34 @@ class GoGenerator : public BaseGenerator {
code += "}\n\n";
}
void GetMemberOfVectorOfStructByKey(const StructDef &struct_def,
const FieldDef &field,
std::string *code_ptr) {
std::string &code = *code_ptr;
auto vectortype = field.value.type.VectorType();
FLATBUFFERS_ASSERT(vectortype.struct_def->has_key);
auto &vector_struct_fields = vectortype.struct_def->fields.vec;
auto kit =
std::find_if(vector_struct_fields.begin(), vector_struct_fields.end(),
[&](FieldDef *field) { return field->key; });
auto &key_field = **kit;
FLATBUFFERS_ASSERT(key_field.key);
GenReceiver(struct_def, code_ptr);
code += " " + namer_.Field(field) + "ByKey";
code += "(obj *" + TypeName(field);
code += ", key " + NativeType(key_field.value.type) + ") bool" +
OffsetPrefix(field);
code += "\t\tx := rcv._tab.Vector(o)\n";
code += "\t\treturn ";
code += "obj.LookupByKey(key, x, rcv._tab.Bytes)\n";
code += "\t}\n";
code += "\treturn false\n";
code += "}\n\n";
}
// Get the value of a vector's non-struct member.
void GetMemberOfVectorOfNonStruct(const StructDef &struct_def,
const FieldDef &field,
@@ -690,6 +722,12 @@ class GoGenerator : public BaseGenerator {
auto vectortype = field.value.type.VectorType();
if (vectortype.base_type == BASE_TYPE_STRUCT) {
GetMemberOfVectorOfStruct(struct_def, field, code_ptr);
// TODO(michaeltle): Support querying fixed struct by key.
// Currently, we only support keyed tables.
if (!vectortype.struct_def->fixed &&
vectortype.struct_def->has_key) {
GetMemberOfVectorOfStructByKey(struct_def, field, code_ptr);
}
} else {
GetMemberOfVectorOfNonStruct(struct_def, field, code_ptr);
}
@@ -824,6 +862,12 @@ class GoGenerator : public BaseGenerator {
GenStructAccessor(struct_def, field, code_ptr);
GenStructMutator(struct_def, field, code_ptr);
// TODO(michaeltle): Support querying fixed struct by key. Currently,
// we only support keyed tables.
if (!struct_def.fixed && field.key) {
GenKeyCompare(struct_def, field, code_ptr);
GenLookupByKey(struct_def, field, code_ptr);
}
}
// Generate builders
@@ -836,6 +880,79 @@ class GoGenerator : public BaseGenerator {
}
}
void GenKeyCompare(const StructDef &struct_def, const FieldDef &field,
std::string *code_ptr) {
FLATBUFFERS_ASSERT(struct_def.has_key);
FLATBUFFERS_ASSERT(field.key);
std::string &code = *code_ptr;
code += "func " + namer_.Type(struct_def) + "KeyCompare(";
code += "o1, o2 flatbuffers.UOffsetT, buf []byte) bool {\n";
code += "\tobj1 := &" + namer_.Type(struct_def) + "{}\n";
code += "\tobj2 := &" + namer_.Type(struct_def) + "{}\n";
code += "\tobj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1)\n";
code += "\tobj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2)\n";
if (IsString(field.value.type)) {
code += "\treturn string(obj1." + namer_.Function(field.name) + "()) < ";
code += "string(obj2." + namer_.Function(field.name) + "())\n";
} else {
code += "\treturn obj1." + namer_.Function(field.name) + "() < ";
code += "obj2." + namer_.Function(field.name) + "()\n";
}
code += "}\n\n";
}
void GenLookupByKey(const StructDef &struct_def, const FieldDef &field,
std::string *code_ptr) {
FLATBUFFERS_ASSERT(struct_def.has_key);
FLATBUFFERS_ASSERT(field.key);
std::string &code = *code_ptr;
GenReceiver(struct_def, code_ptr);
code += " LookupByKey(";
code += "key " + NativeType(field.value.type) + ", ";
code += "vectorLocation flatbuffers.UOffsetT, ";
code += "buf []byte) bool {\n";
code += "\tspan := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:])\n";
code += "\tstart := flatbuffers.UOffsetT(0)\n";
code += "\tfor span != 0 {\n";
code += "\t\tmiddle := span / 2\n";
code += "\t\ttableOffset := flatbuffers.GetIndirectOffset(buf, ";
code += "vectorLocation+ 4 * (start + middle))\n";
code += "\t\tobj := &" + namer_.Type(struct_def) + "{}\n";
code += "\t\tobj.Init(buf, tableOffset)\n";
if (IsString(field.value.type)) {
code += "\t\tbKey := []byte(key)\n";
needs_bytes_import_ = true;
code +=
"\t\tcomp := bytes.Compare(obj." + namer_.Function(field.name) + "()";
code += ", bKey)\n";
} else {
code += "\t\tval := obj." + namer_.Function(field.name) + "()\n";
code += "\t\tcomp := 0\n";
code += "\t\tif val > key {\n";
code += "\t\t\tcomp = 1\n";
code += "\t\t} else if val < key {\n";
code += "\t\t\tcomp = -1\n";
code += "\t\t}\n";
}
code += "\t\tif comp > 0 {\n";
code += "\t\t\tspan = middle\n";
code += "\t\t} else if comp < 0 {\n";
code += "\t\t\tmiddle += 1\n";
code += "\t\t\tstart += middle\n";
code += "\t\t\tspan -= middle\n";
code += "\t\t} else {\n";
code += "\t\t\trcv.Init(buf, tableOffset)\n";
code += "\t\t\treturn true\n";
code += "\t\t}\n";
code += "\t}\n";
code += "\treturn false\n";
code += "}\n\n";
}
void GenNativeStruct(const StructDef &struct_def, std::string *code_ptr) {
std::string &code = *code_ptr;
@@ -1354,9 +1471,10 @@ class GoGenerator : public BaseGenerator {
code += "package " + name_space_name + "\n\n";
if (needs_imports) {
code += "import (\n";
if (is_enum) { code += "\t\"strconv\"\n\n"; }
if (needs_bytes_import_) code += "\t\"bytes\"\n";
// math is needed to support non-finite scalar default values.
if (needs_math_import_) { code += "\t\"math\"\n\n"; }
if (needs_math_import_) { code += "\t\"math\"\n"; }
if (is_enum) { code += "\t\"strconv\"\n"; }
if (!parser_.opts.go_import.empty()) {
code += "\tflatbuffers \"" + parser_.opts.go_import + "\"\n";
} else {

View File

@@ -4,7 +4,6 @@ package Example
import (
"strconv"
flatbuffers "github.com/google/flatbuffers/go"
MyGame__Example2 "MyGame/Example2"

View File

@@ -4,7 +4,6 @@ package Example
import (
"strconv"
flatbuffers "github.com/google/flatbuffers/go"
)

View File

@@ -4,7 +4,6 @@ package Example
import (
"strconv"
flatbuffers "github.com/google/flatbuffers/go"
MyGame__Example2 "MyGame/Example2"

View File

@@ -3,8 +3,8 @@
package Example
import (
"bytes"
"math"
flatbuffers "github.com/google/flatbuffers/go"
MyGame "MyGame"
@@ -568,6 +568,38 @@ func (rcv *Monster) Name() []byte {
return nil
}
func MonsterKeyCompare(o1, o2 flatbuffers.UOffsetT, buf []byte) bool {
obj1 := &Monster{}
obj2 := &Monster{}
obj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1)
obj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2)
return string(obj1.Name()) < string(obj2.Name())
}
func (rcv *Monster) LookupByKey(key string, vectorLocation flatbuffers.UOffsetT, buf []byte) bool {
span := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:])
start := flatbuffers.UOffsetT(0)
for span != 0 {
middle := span / 2
tableOffset := flatbuffers.GetIndirectOffset(buf, vectorLocation+ 4 * (start + middle))
obj := &Monster{}
obj.Init(buf, tableOffset)
bKey := []byte(key)
comp := bytes.Compare(obj.Name(), bKey)
if comp > 0 {
span = middle
} else if comp < 0 {
middle += 1
start += middle
span -= middle
} else {
rcv.Init(buf, tableOffset)
return true
}
}
return false
}
func (rcv *Monster) Inventory(j int) byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
@@ -685,6 +717,15 @@ func (rcv *Monster) Testarrayoftables(obj *Monster, j int) bool {
return false
}
func (rcv *Monster) TestarrayoftablesByKey(obj *Monster, key string) bool{
o := flatbuffers.UOffsetT(rcv._tab.Offset(26))
if o != 0 {
x := rcv._tab.Vector(o)
return obj.LookupByKey(key, x, rcv._tab.Bytes)
}
return false
}
func (rcv *Monster) TestarrayoftablesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(26))
if o != 0 {
@@ -1091,6 +1132,15 @@ func (rcv *Monster) VectorOfReferrables(obj *Referrable, j int) bool {
return false
}
func (rcv *Monster) VectorOfReferrablesByKey(obj *Referrable, key uint64) bool{
o := flatbuffers.UOffsetT(rcv._tab.Offset(74))
if o != 0 {
x := rcv._tab.Vector(o)
return obj.LookupByKey(key, x, rcv._tab.Bytes)
}
return false
}
func (rcv *Monster) VectorOfReferrablesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(74))
if o != 0 {
@@ -1149,6 +1199,15 @@ func (rcv *Monster) VectorOfStrongReferrables(obj *Referrable, j int) bool {
return false
}
func (rcv *Monster) VectorOfStrongReferrablesByKey(obj *Referrable, key uint64) bool{
o := flatbuffers.UOffsetT(rcv._tab.Offset(80))
if o != 0 {
x := rcv._tab.Vector(o)
return obj.LookupByKey(key, x, rcv._tab.Bytes)
}
return false
}
func (rcv *Monster) VectorOfStrongReferrablesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(80))
if o != 0 {
@@ -1367,6 +1426,15 @@ func (rcv *Monster) ScalarKeySortedTables(obj *Stat, j int) bool {
return false
}
func (rcv *Monster) ScalarKeySortedTablesByKey(obj *Stat, key uint16) bool{
o := flatbuffers.UOffsetT(rcv._tab.Offset(104))
if o != 0 {
x := rcv._tab.Vector(o)
return obj.LookupByKey(key, x, rcv._tab.Bytes)
}
return false
}
func (rcv *Monster) ScalarKeySortedTablesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(104))
if o != 0 {

View File

@@ -67,6 +67,43 @@ func (rcv *Referrable) MutateId(n uint64) bool {
return rcv._tab.MutateUint64Slot(4, n)
}
func ReferrableKeyCompare(o1, o2 flatbuffers.UOffsetT, buf []byte) bool {
obj1 := &Referrable{}
obj2 := &Referrable{}
obj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1)
obj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2)
return obj1.Id() < obj2.Id()
}
func (rcv *Referrable) LookupByKey(key uint64, vectorLocation flatbuffers.UOffsetT, buf []byte) bool {
span := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:])
start := flatbuffers.UOffsetT(0)
for span != 0 {
middle := span / 2
tableOffset := flatbuffers.GetIndirectOffset(buf, vectorLocation+ 4 * (start + middle))
obj := &Referrable{}
obj.Init(buf, tableOffset)
val := obj.Id()
comp := 0
if val > key {
comp = 1
} else if val < key {
comp = -1
}
if comp > 0 {
span = middle
} else if comp < 0 {
middle += 1
start += middle
span -= middle
} else {
rcv.Init(buf, tableOffset)
return true
}
}
return false
}
func ReferrableStart(builder *flatbuffers.Builder) {
builder.StartObject(1)
}

View File

@@ -94,6 +94,43 @@ func (rcv *Stat) MutateCount(n uint16) bool {
return rcv._tab.MutateUint16Slot(8, n)
}
func StatKeyCompare(o1, o2 flatbuffers.UOffsetT, buf []byte) bool {
obj1 := &Stat{}
obj2 := &Stat{}
obj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1)
obj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2)
return obj1.Count() < obj2.Count()
}
func (rcv *Stat) LookupByKey(key uint16, vectorLocation flatbuffers.UOffsetT, buf []byte) bool {
span := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:])
start := flatbuffers.UOffsetT(0)
for span != 0 {
middle := span / 2
tableOffset := flatbuffers.GetIndirectOffset(buf, vectorLocation+ 4 * (start + middle))
obj := &Stat{}
obj.Init(buf, tableOffset)
val := obj.Count()
comp := 0
if val > key {
comp = 1
} else if val < key {
comp = -1
}
if comp > 0 {
span = middle
} else if comp < 0 {
middle += 1
start += middle
span -= middle
} else {
rcv.Init(buf, tableOffset)
return true
}
}
return false
}
func StatStart(builder *flatbuffers.Builder) {
builder.StartObject(3)
}

View File

@@ -186,9 +186,12 @@ func TestAll(t *testing.T) {
// Check size-prefixed flatbuffers
CheckSizePrefixedBuffer(t.Fatalf)
// Check that optional scalars work
// Check that optional scalars works
CheckOptionalScalars(t.Fatalf)
// Check that getting vector element by key works
CheckByKey(t.Fatalf)
// If the filename of the FlatBuffers file generated by the Java test
// is given, check that Go code can read it, and that Go code
// generates an identical buffer when used to create the example data:
@@ -2215,6 +2218,78 @@ func CheckOptionalScalars(fail func(string, ...interface{})) {
expectEq("defaultEnum", obj.DefaultEnum, optional_scalars.OptionalByteTwo)
}
func CheckByKey(fail func(string, ...interface{})) {
expectEq := func(what string, a, b interface{}) {
if a != b {
fail(FailString("Lookup by key: "+what, b, a))
}
}
b := flatbuffers.NewBuilder(0)
name := b.CreateString("Boss")
slime := &example.MonsterT{Name: "Slime"}
pig := &example.MonsterT{Name: "Pig"}
slimeBoss := &example.MonsterT{Name: "SlimeBoss"}
mushroom := &example.MonsterT{Name: "Mushroom"}
ironPig := &example.MonsterT{Name: "Iron Pig"}
monsterOffsets := make([]flatbuffers.UOffsetT, 5)
monsterOffsets[0] = slime.Pack(b)
monsterOffsets[1] = pig.Pack(b)
monsterOffsets[2] = slimeBoss.Pack(b)
monsterOffsets[3] = mushroom.Pack(b)
monsterOffsets[4] = ironPig.Pack(b)
testarrayoftables := b.CreateVectorOfSortedTables(monsterOffsets, example.MonsterKeyCompare)
str := &example.StatT{Id: "Strength", Count: 42}
luk := &example.StatT{Id: "Luck", Count: 51}
hp := &example.StatT{Id: "Health", Count: 12}
// Test default count value of 0
mp := &example.StatT{Id: "Mana"}
statOffsets := make([]flatbuffers.UOffsetT, 4)
statOffsets[0] = str.Pack(b)
statOffsets[1] = luk.Pack(b)
statOffsets[2] = hp.Pack(b)
statOffsets[3] = mp.Pack(b)
scalarKeySortedTablesOffset := b.CreateVectorOfSortedTables(statOffsets, example.StatKeyCompare)
example.MonsterStart(b)
example.MonsterAddName(b, name)
example.MonsterAddTestarrayoftables(b, testarrayoftables)
example.MonsterAddScalarKeySortedTables(b, scalarKeySortedTablesOffset)
moff := example.MonsterEnd(b)
b.Finish(moff)
monster := example.GetRootAsMonster(b.Bytes, b.Head())
slimeMon := &example.Monster{}
monster.TestarrayoftablesByKey(slimeMon, slime.Name)
mushroomMon := &example.Monster{}
monster.TestarrayoftablesByKey(mushroomMon, mushroom.Name)
slimeBossMon := &example.Monster{}
monster.TestarrayoftablesByKey(slimeBossMon, slimeBoss.Name)
strStat := &example.Stat{}
monster.ScalarKeySortedTablesByKey(strStat, str.Count)
lukStat := &example.Stat{}
monster.ScalarKeySortedTablesByKey(lukStat, luk.Count)
mpStat := &example.Stat{}
monster.ScalarKeySortedTablesByKey(mpStat, mp.Count)
expectEq("Boss name", string(monster.Name()), "Boss")
expectEq("Slime name", string(slimeMon.Name()), slime.Name)
expectEq("Mushroom name", string(mushroomMon.Name()), mushroom.Name)
expectEq("SlimeBoss name", string(slimeBossMon.Name()), slimeBoss.Name)
expectEq("Strength Id", string(strStat.Id()), str.Id)
expectEq("Strength Count", strStat.Count(), str.Count)
expectEq("Luck Id", string(lukStat.Id()), luk.Id)
expectEq("Luck Count", lukStat.Count(), luk.Count)
expectEq("Mana Id", string(mpStat.Id()), mp.Id)
// Use default count value as key
expectEq("Mana Count", mpStat.Count(), uint16(0))
}
// BenchmarkVtableDeduplication measures the speed of vtable deduplication
// by creating prePop vtables, then populating b.N objects with a
// different single vtable.