diff --git a/go/builder.go b/go/builder.go index d99b590bb..5d90e8ef9 100644 --- a/go/builder.go +++ b/go/builder.go @@ -1,5 +1,7 @@ package flatbuffers +import "sort" + // Builder is a state machine for creating FlatBuffer objects. // Use a Builder to construct object(s) starting from leaf nodes. // @@ -315,6 +317,25 @@ func (b *Builder) EndVector(vectorNumElems int) UOffsetT { return b.Offset() } +// CreateVectorOfTables serializes slice of table offsets into a vector. +func (b *Builder) CreateVectorOfTables(offsets []UOffsetT) UOffsetT { + b.assertNotNested() + b.StartVector(4, len(offsets), 4) + for i := len(offsets) - 1; i >= 0; i-- { + b.PrependUOffsetT(offsets[i]) + } + return b.EndVector(len(offsets)) +} + +type KeyCompare func(o1, o2 UOffsetT, buf []byte) bool + +func (b *Builder) CreateVectorOfSortedTables(offsets []UOffsetT, keyCompare KeyCompare) UOffsetT { + sort.Slice(offsets, func(i, j int) bool { + return keyCompare(offsets[i], offsets[j], b.Bytes) + }) + return b.CreateVectorOfTables(offsets) +} + // CreateSharedString Checks if the string is already written // to the buffer before calling CreateString func (b *Builder) CreateSharedString(s string) UOffsetT { diff --git a/go/lib.go b/go/lib.go index 9a333ff04..9333d8bd3 100644 --- a/go/lib.go +++ b/go/lib.go @@ -23,3 +23,8 @@ func GetSizePrefixedRootAs(buf []byte, offset UOffsetT, fb FlatBuffer) { func GetSizePrefix(buf []byte, offset UOffsetT) uint32 { return GetUint32(buf[offset:]) } + +// GetIndirectOffset retrives the relative offset in the provided buffer stored at `offset`. +func GetIndirectOffset(buf []byte, offset UOffsetT) UOffsetT { + return offset + GetUOffsetT(buf[offset:]) +} diff --git a/src/idl_gen_go.cpp b/src/idl_gen_go.cpp index a5e0c364f..54e886458 100644 --- a/src/idl_gen_go.cpp +++ b/src/idl_gen_go.cpp @@ -21,6 +21,7 @@ #include #include +#include "flatbuffers/base.h" #include "flatbuffers/code_generators.h" #include "flatbuffers/flatbuffers.h" #include "flatbuffers/idl.h" @@ -104,6 +105,7 @@ class GoGenerator : public BaseGenerator { ++it) { tracked_imported_namespaces_.clear(); needs_math_import_ = false; + needs_bytes_import_ = false; needs_imports = false; std::string enumcode; GenEnum(**it, &enumcode); @@ -124,6 +126,7 @@ class GoGenerator : public BaseGenerator { it != parser_.structs_.vec.end(); ++it) { tracked_imported_namespaces_.clear(); needs_math_import_ = false; + needs_bytes_import_ = false; std::string declcode; GenStruct(**it, &declcode); if (parser_.opts.one_file) { @@ -158,6 +161,7 @@ class GoGenerator : public BaseGenerator { }; std::set tracked_imported_namespaces_; bool needs_math_import_ = false; + bool needs_bytes_import_ = false; // Most field accessors need to retrieve and test the field offset first, // this is the prefix code for that. @@ -489,6 +493,34 @@ class GoGenerator : public BaseGenerator { code += "}\n\n"; } + void GetMemberOfVectorOfStructByKey(const StructDef &struct_def, + const FieldDef &field, + std::string *code_ptr) { + std::string &code = *code_ptr; + auto vectortype = field.value.type.VectorType(); + FLATBUFFERS_ASSERT(vectortype.struct_def->has_key); + + auto &vector_struct_fields = vectortype.struct_def->fields.vec; + auto kit = + std::find_if(vector_struct_fields.begin(), vector_struct_fields.end(), + [&](FieldDef *field) { return field->key; }); + + auto &key_field = **kit; + FLATBUFFERS_ASSERT(key_field.key); + + GenReceiver(struct_def, code_ptr); + code += " " + namer_.Field(field) + "ByKey"; + code += "(obj *" + TypeName(field); + code += ", key " + NativeType(key_field.value.type) + ") bool" + + OffsetPrefix(field); + code += "\t\tx := rcv._tab.Vector(o)\n"; + code += "\t\treturn "; + code += "obj.LookupByKey(key, x, rcv._tab.Bytes)\n"; + code += "\t}\n"; + code += "\treturn false\n"; + code += "}\n\n"; + } + // Get the value of a vector's non-struct member. void GetMemberOfVectorOfNonStruct(const StructDef &struct_def, const FieldDef &field, @@ -690,6 +722,12 @@ class GoGenerator : public BaseGenerator { auto vectortype = field.value.type.VectorType(); if (vectortype.base_type == BASE_TYPE_STRUCT) { GetMemberOfVectorOfStruct(struct_def, field, code_ptr); + // TODO(michaeltle): Support querying fixed struct by key. + // Currently, we only support keyed tables. + if (!vectortype.struct_def->fixed && + vectortype.struct_def->has_key) { + GetMemberOfVectorOfStructByKey(struct_def, field, code_ptr); + } } else { GetMemberOfVectorOfNonStruct(struct_def, field, code_ptr); } @@ -824,6 +862,12 @@ class GoGenerator : public BaseGenerator { GenStructAccessor(struct_def, field, code_ptr); GenStructMutator(struct_def, field, code_ptr); + // TODO(michaeltle): Support querying fixed struct by key. Currently, + // we only support keyed tables. + if (!struct_def.fixed && field.key) { + GenKeyCompare(struct_def, field, code_ptr); + GenLookupByKey(struct_def, field, code_ptr); + } } // Generate builders @@ -836,6 +880,79 @@ class GoGenerator : public BaseGenerator { } } + void GenKeyCompare(const StructDef &struct_def, const FieldDef &field, + std::string *code_ptr) { + FLATBUFFERS_ASSERT(struct_def.has_key); + FLATBUFFERS_ASSERT(field.key); + std::string &code = *code_ptr; + + code += "func " + namer_.Type(struct_def) + "KeyCompare("; + code += "o1, o2 flatbuffers.UOffsetT, buf []byte) bool {\n"; + code += "\tobj1 := &" + namer_.Type(struct_def) + "{}\n"; + code += "\tobj2 := &" + namer_.Type(struct_def) + "{}\n"; + code += "\tobj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1)\n"; + code += "\tobj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2)\n"; + if (IsString(field.value.type)) { + code += "\treturn string(obj1." + namer_.Function(field.name) + "()) < "; + code += "string(obj2." + namer_.Function(field.name) + "())\n"; + } else { + code += "\treturn obj1." + namer_.Function(field.name) + "() < "; + code += "obj2." + namer_.Function(field.name) + "()\n"; + } + code += "}\n\n"; + } + + void GenLookupByKey(const StructDef &struct_def, const FieldDef &field, + std::string *code_ptr) { + FLATBUFFERS_ASSERT(struct_def.has_key); + FLATBUFFERS_ASSERT(field.key); + std::string &code = *code_ptr; + + GenReceiver(struct_def, code_ptr); + code += " LookupByKey("; + code += "key " + NativeType(field.value.type) + ", "; + code += "vectorLocation flatbuffers.UOffsetT, "; + code += "buf []byte) bool {\n"; + code += "\tspan := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:])\n"; + code += "\tstart := flatbuffers.UOffsetT(0)\n"; + code += "\tfor span != 0 {\n"; + code += "\t\tmiddle := span / 2\n"; + code += "\t\ttableOffset := flatbuffers.GetIndirectOffset(buf, "; + code += "vectorLocation+ 4 * (start + middle))\n"; + + code += "\t\tobj := &" + namer_.Type(struct_def) + "{}\n"; + code += "\t\tobj.Init(buf, tableOffset)\n"; + + if (IsString(field.value.type)) { + code += "\t\tbKey := []byte(key)\n"; + needs_bytes_import_ = true; + code += + "\t\tcomp := bytes.Compare(obj." + namer_.Function(field.name) + "()"; + code += ", bKey)\n"; + } else { + code += "\t\tval := obj." + namer_.Function(field.name) + "()\n"; + code += "\t\tcomp := 0\n"; + code += "\t\tif val > key {\n"; + code += "\t\t\tcomp = 1\n"; + code += "\t\t} else if val < key {\n"; + code += "\t\t\tcomp = -1\n"; + code += "\t\t}\n"; + } + code += "\t\tif comp > 0 {\n"; + code += "\t\t\tspan = middle\n"; + code += "\t\t} else if comp < 0 {\n"; + code += "\t\t\tmiddle += 1\n"; + code += "\t\t\tstart += middle\n"; + code += "\t\t\tspan -= middle\n"; + code += "\t\t} else {\n"; + code += "\t\t\trcv.Init(buf, tableOffset)\n"; + code += "\t\t\treturn true\n"; + code += "\t\t}\n"; + code += "\t}\n"; + code += "\treturn false\n"; + code += "}\n\n"; + } + void GenNativeStruct(const StructDef &struct_def, std::string *code_ptr) { std::string &code = *code_ptr; @@ -1354,9 +1471,10 @@ class GoGenerator : public BaseGenerator { code += "package " + name_space_name + "\n\n"; if (needs_imports) { code += "import (\n"; - if (is_enum) { code += "\t\"strconv\"\n\n"; } + if (needs_bytes_import_) code += "\t\"bytes\"\n"; // math is needed to support non-finite scalar default values. - if (needs_math_import_) { code += "\t\"math\"\n\n"; } + if (needs_math_import_) { code += "\t\"math\"\n"; } + if (is_enum) { code += "\t\"strconv\"\n"; } if (!parser_.opts.go_import.empty()) { code += "\tflatbuffers \"" + parser_.opts.go_import + "\"\n"; } else { diff --git a/tests/MyGame/Example/Any.go b/tests/MyGame/Example/Any.go index 62664185b..3b7f6295d 100644 --- a/tests/MyGame/Example/Any.go +++ b/tests/MyGame/Example/Any.go @@ -4,7 +4,6 @@ package Example import ( "strconv" - flatbuffers "github.com/google/flatbuffers/go" MyGame__Example2 "MyGame/Example2" diff --git a/tests/MyGame/Example/AnyAmbiguousAliases.go b/tests/MyGame/Example/AnyAmbiguousAliases.go index cdb65c9b2..83e5f7d82 100644 --- a/tests/MyGame/Example/AnyAmbiguousAliases.go +++ b/tests/MyGame/Example/AnyAmbiguousAliases.go @@ -4,7 +4,6 @@ package Example import ( "strconv" - flatbuffers "github.com/google/flatbuffers/go" ) diff --git a/tests/MyGame/Example/AnyUniqueAliases.go b/tests/MyGame/Example/AnyUniqueAliases.go index 32cbe08b9..b36e61d9b 100644 --- a/tests/MyGame/Example/AnyUniqueAliases.go +++ b/tests/MyGame/Example/AnyUniqueAliases.go @@ -4,7 +4,6 @@ package Example import ( "strconv" - flatbuffers "github.com/google/flatbuffers/go" MyGame__Example2 "MyGame/Example2" diff --git a/tests/MyGame/Example/Monster.go b/tests/MyGame/Example/Monster.go index b64ced7da..6f8fae39b 100644 --- a/tests/MyGame/Example/Monster.go +++ b/tests/MyGame/Example/Monster.go @@ -3,8 +3,8 @@ package Example import ( + "bytes" "math" - flatbuffers "github.com/google/flatbuffers/go" MyGame "MyGame" @@ -568,6 +568,38 @@ func (rcv *Monster) Name() []byte { return nil } +func MonsterKeyCompare(o1, o2 flatbuffers.UOffsetT, buf []byte) bool { + obj1 := &Monster{} + obj2 := &Monster{} + obj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1) + obj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2) + return string(obj1.Name()) < string(obj2.Name()) +} + +func (rcv *Monster) LookupByKey(key string, vectorLocation flatbuffers.UOffsetT, buf []byte) bool { + span := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:]) + start := flatbuffers.UOffsetT(0) + for span != 0 { + middle := span / 2 + tableOffset := flatbuffers.GetIndirectOffset(buf, vectorLocation+ 4 * (start + middle)) + obj := &Monster{} + obj.Init(buf, tableOffset) + bKey := []byte(key) + comp := bytes.Compare(obj.Name(), bKey) + if comp > 0 { + span = middle + } else if comp < 0 { + middle += 1 + start += middle + span -= middle + } else { + rcv.Init(buf, tableOffset) + return true + } + } + return false +} + func (rcv *Monster) Inventory(j int) byte { o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) if o != 0 { @@ -685,6 +717,15 @@ func (rcv *Monster) Testarrayoftables(obj *Monster, j int) bool { return false } +func (rcv *Monster) TestarrayoftablesByKey(obj *Monster, key string) bool{ + o := flatbuffers.UOffsetT(rcv._tab.Offset(26)) + if o != 0 { + x := rcv._tab.Vector(o) + return obj.LookupByKey(key, x, rcv._tab.Bytes) + } + return false +} + func (rcv *Monster) TestarrayoftablesLength() int { o := flatbuffers.UOffsetT(rcv._tab.Offset(26)) if o != 0 { @@ -1091,6 +1132,15 @@ func (rcv *Monster) VectorOfReferrables(obj *Referrable, j int) bool { return false } +func (rcv *Monster) VectorOfReferrablesByKey(obj *Referrable, key uint64) bool{ + o := flatbuffers.UOffsetT(rcv._tab.Offset(74)) + if o != 0 { + x := rcv._tab.Vector(o) + return obj.LookupByKey(key, x, rcv._tab.Bytes) + } + return false +} + func (rcv *Monster) VectorOfReferrablesLength() int { o := flatbuffers.UOffsetT(rcv._tab.Offset(74)) if o != 0 { @@ -1149,6 +1199,15 @@ func (rcv *Monster) VectorOfStrongReferrables(obj *Referrable, j int) bool { return false } +func (rcv *Monster) VectorOfStrongReferrablesByKey(obj *Referrable, key uint64) bool{ + o := flatbuffers.UOffsetT(rcv._tab.Offset(80)) + if o != 0 { + x := rcv._tab.Vector(o) + return obj.LookupByKey(key, x, rcv._tab.Bytes) + } + return false +} + func (rcv *Monster) VectorOfStrongReferrablesLength() int { o := flatbuffers.UOffsetT(rcv._tab.Offset(80)) if o != 0 { @@ -1367,6 +1426,15 @@ func (rcv *Monster) ScalarKeySortedTables(obj *Stat, j int) bool { return false } +func (rcv *Monster) ScalarKeySortedTablesByKey(obj *Stat, key uint16) bool{ + o := flatbuffers.UOffsetT(rcv._tab.Offset(104)) + if o != 0 { + x := rcv._tab.Vector(o) + return obj.LookupByKey(key, x, rcv._tab.Bytes) + } + return false +} + func (rcv *Monster) ScalarKeySortedTablesLength() int { o := flatbuffers.UOffsetT(rcv._tab.Offset(104)) if o != 0 { diff --git a/tests/MyGame/Example/Referrable.go b/tests/MyGame/Example/Referrable.go index aa2707984..0b14beb2e 100644 --- a/tests/MyGame/Example/Referrable.go +++ b/tests/MyGame/Example/Referrable.go @@ -67,6 +67,43 @@ func (rcv *Referrable) MutateId(n uint64) bool { return rcv._tab.MutateUint64Slot(4, n) } +func ReferrableKeyCompare(o1, o2 flatbuffers.UOffsetT, buf []byte) bool { + obj1 := &Referrable{} + obj2 := &Referrable{} + obj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1) + obj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2) + return obj1.Id() < obj2.Id() +} + +func (rcv *Referrable) LookupByKey(key uint64, vectorLocation flatbuffers.UOffsetT, buf []byte) bool { + span := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:]) + start := flatbuffers.UOffsetT(0) + for span != 0 { + middle := span / 2 + tableOffset := flatbuffers.GetIndirectOffset(buf, vectorLocation+ 4 * (start + middle)) + obj := &Referrable{} + obj.Init(buf, tableOffset) + val := obj.Id() + comp := 0 + if val > key { + comp = 1 + } else if val < key { + comp = -1 + } + if comp > 0 { + span = middle + } else if comp < 0 { + middle += 1 + start += middle + span -= middle + } else { + rcv.Init(buf, tableOffset) + return true + } + } + return false +} + func ReferrableStart(builder *flatbuffers.Builder) { builder.StartObject(1) } diff --git a/tests/MyGame/Example/Stat.go b/tests/MyGame/Example/Stat.go index 714964098..d7976cd7b 100644 --- a/tests/MyGame/Example/Stat.go +++ b/tests/MyGame/Example/Stat.go @@ -94,6 +94,43 @@ func (rcv *Stat) MutateCount(n uint16) bool { return rcv._tab.MutateUint16Slot(8, n) } +func StatKeyCompare(o1, o2 flatbuffers.UOffsetT, buf []byte) bool { + obj1 := &Stat{} + obj2 := &Stat{} + obj1.Init(buf, flatbuffers.UOffsetT(len(buf)) - o1) + obj2.Init(buf, flatbuffers.UOffsetT(len(buf)) - o2) + return obj1.Count() < obj2.Count() +} + +func (rcv *Stat) LookupByKey(key uint16, vectorLocation flatbuffers.UOffsetT, buf []byte) bool { + span := flatbuffers.GetUOffsetT(buf[vectorLocation - 4:]) + start := flatbuffers.UOffsetT(0) + for span != 0 { + middle := span / 2 + tableOffset := flatbuffers.GetIndirectOffset(buf, vectorLocation+ 4 * (start + middle)) + obj := &Stat{} + obj.Init(buf, tableOffset) + val := obj.Count() + comp := 0 + if val > key { + comp = 1 + } else if val < key { + comp = -1 + } + if comp > 0 { + span = middle + } else if comp < 0 { + middle += 1 + start += middle + span -= middle + } else { + rcv.Init(buf, tableOffset) + return true + } + } + return false +} + func StatStart(builder *flatbuffers.Builder) { builder.StartObject(3) } diff --git a/tests/go_test.go b/tests/go_test.go index d454b5647..7cbac1e5e 100644 --- a/tests/go_test.go +++ b/tests/go_test.go @@ -186,9 +186,12 @@ func TestAll(t *testing.T) { // Check size-prefixed flatbuffers CheckSizePrefixedBuffer(t.Fatalf) - // Check that optional scalars work + // Check that optional scalars works CheckOptionalScalars(t.Fatalf) + // Check that getting vector element by key works + CheckByKey(t.Fatalf) + // If the filename of the FlatBuffers file generated by the Java test // is given, check that Go code can read it, and that Go code // generates an identical buffer when used to create the example data: @@ -2215,6 +2218,78 @@ func CheckOptionalScalars(fail func(string, ...interface{})) { expectEq("defaultEnum", obj.DefaultEnum, optional_scalars.OptionalByteTwo) } +func CheckByKey(fail func(string, ...interface{})) { + expectEq := func(what string, a, b interface{}) { + if a != b { + fail(FailString("Lookup by key: "+what, b, a)) + } + } + + b := flatbuffers.NewBuilder(0) + name := b.CreateString("Boss") + + slime := &example.MonsterT{Name: "Slime"} + pig := &example.MonsterT{Name: "Pig"} + slimeBoss := &example.MonsterT{Name: "SlimeBoss"} + mushroom := &example.MonsterT{Name: "Mushroom"} + ironPig := &example.MonsterT{Name: "Iron Pig"} + + monsterOffsets := make([]flatbuffers.UOffsetT, 5) + monsterOffsets[0] = slime.Pack(b) + monsterOffsets[1] = pig.Pack(b) + monsterOffsets[2] = slimeBoss.Pack(b) + monsterOffsets[3] = mushroom.Pack(b) + monsterOffsets[4] = ironPig.Pack(b) + testarrayoftables := b.CreateVectorOfSortedTables(monsterOffsets, example.MonsterKeyCompare) + + str := &example.StatT{Id: "Strength", Count: 42} + luk := &example.StatT{Id: "Luck", Count: 51} + hp := &example.StatT{Id: "Health", Count: 12} + // Test default count value of 0 + mp := &example.StatT{Id: "Mana"} + + statOffsets := make([]flatbuffers.UOffsetT, 4) + statOffsets[0] = str.Pack(b) + statOffsets[1] = luk.Pack(b) + statOffsets[2] = hp.Pack(b) + statOffsets[3] = mp.Pack(b) + scalarKeySortedTablesOffset := b.CreateVectorOfSortedTables(statOffsets, example.StatKeyCompare) + + example.MonsterStart(b) + example.MonsterAddName(b, name) + example.MonsterAddTestarrayoftables(b, testarrayoftables) + example.MonsterAddScalarKeySortedTables(b, scalarKeySortedTablesOffset) + moff := example.MonsterEnd(b) + b.Finish(moff) + + monster := example.GetRootAsMonster(b.Bytes, b.Head()) + slimeMon := &example.Monster{} + monster.TestarrayoftablesByKey(slimeMon, slime.Name) + mushroomMon := &example.Monster{} + monster.TestarrayoftablesByKey(mushroomMon, mushroom.Name) + slimeBossMon := &example.Monster{} + monster.TestarrayoftablesByKey(slimeBossMon, slimeBoss.Name) + + strStat := &example.Stat{} + monster.ScalarKeySortedTablesByKey(strStat, str.Count) + lukStat := &example.Stat{} + monster.ScalarKeySortedTablesByKey(lukStat, luk.Count) + mpStat := &example.Stat{} + monster.ScalarKeySortedTablesByKey(mpStat, mp.Count) + + expectEq("Boss name", string(monster.Name()), "Boss") + expectEq("Slime name", string(slimeMon.Name()), slime.Name) + expectEq("Mushroom name", string(mushroomMon.Name()), mushroom.Name) + expectEq("SlimeBoss name", string(slimeBossMon.Name()), slimeBoss.Name) + expectEq("Strength Id", string(strStat.Id()), str.Id) + expectEq("Strength Count", strStat.Count(), str.Count) + expectEq("Luck Id", string(lukStat.Id()), luk.Id) + expectEq("Luck Count", lukStat.Count(), luk.Count) + expectEq("Mana Id", string(mpStat.Id()), mp.Id) + // Use default count value as key + expectEq("Mana Count", mpStat.Count(), uint16(0)) +} + // BenchmarkVtableDeduplication measures the speed of vtable deduplication // by creating prePop vtables, then populating b.N objects with a // different single vtable.