diff --git a/go/table.go b/go/table.go index b273146fa..f991c80e5 100644 --- a/go/table.go +++ b/go/table.go @@ -31,10 +31,25 @@ func (t *Table) String(off UOffsetT) string { } // ByteVector gets a byte slice from data stored inside the flatbuffer. +// If the offset is invalid or out of bounds, returns nil to prevent crashes. func (t *Table) ByteVector(off UOffsetT) []byte { + n := UOffsetT(len(t.Bytes)) + // Need at least SizeUOffsetT bytes to read the relative vector offset. + u := UOffsetT(SizeUOffsetT) + if n < u || off > n-u { + return nil + } off += GetUOffsetT(t.Bytes[off:]) + // Need at least SizeUOffsetT bytes to read the vector length. + if n < u || off > n-u { + return nil + } start := off + UOffsetT(SizeUOffsetT) length := GetUOffsetT(t.Bytes[off:]) + // Avoid overflow by checking the length against the remaining buffer space. + if length > n-start { + return nil + } return t.Bytes[start : start+length] } diff --git a/tests/go_test.go b/tests/go_test.go index 05c79b458..f00d31e21 100644 --- a/tests/go_test.go +++ b/tests/go_test.go @@ -22,8 +22,8 @@ import ( pizza "Pizza" "encoding/json" optional_scalars "optional_scalars" // refers to generated code - required_strings "required_strings" // refers to generated code order "order" + required_strings "required_strings" // refers to generated code "bytes" "flag" @@ -132,6 +132,10 @@ func TestAll(t *testing.T) { CheckByteStringIsNestedError(t.Fatalf) CheckStructIsNotInlineError(t.Fatalf) CheckFinishedBytesError(t.Fatalf) + + // Verify bounds checking + CheckByteVectorBoundsChecking(t.Fatalf) + CheckSharedStrings(t.Fatalf) CheckEmptiedBuilder(t.Fatalf) @@ -2471,6 +2475,52 @@ func CheckByKey(fail func(string, ...interface{})) { expectEq("Mana Count", mpStat.Count(), uint16(0)) } +// CheckByteVectorBoundsChecking ensures ByteVector handles malformed input safely. +func CheckByteVectorBoundsChecking(fail func(string, ...interface{})) { + // Test case 1: Offset beyond buffer size + table := &flatbuffers.Table{ + Bytes: []byte{0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00}, // Small buffer + Pos: 0, + } + result := table.ByteVector(100) // Offset way beyond buffer + if result != nil { + fail("ByteVector should return nil for offset beyond buffer") + } + + // Test case 2: Malicious length field + // Construct: [relative offset: 4] [vector length: 0xFFFFFFFF] [data...] + maliciousBytes := make([]byte, 20) + // At position 0, set relative offset to point to position 4 + maliciousBytes[0] = 4 + maliciousBytes[1] = 0 + maliciousBytes[2] = 0 + maliciousBytes[3] = 0 + // At position 4, set malicious vector length + maliciousBytes[4] = 0xFF + maliciousBytes[5] = 0xFF + maliciousBytes[6] = 0xFF + maliciousBytes[7] = 0xFF + + table = &flatbuffers.Table{Bytes: maliciousBytes, Pos: 0} + result = table.ByteVector(0) + if result != nil { + fail("ByteVector should return nil for malicious length field") + } + + // Test case 3: Valid case should still work + // Construct: [relative offset: 4] [vector length: 3] [data: 'a', 'b', 'c'] + validBytes := []byte{ + 4, 0, 0, 0, // relative offset to vector data (at position 4) + 3, 0, 0, 0, // vector length (3 bytes) + 'a', 'b', 'c', // actual vector data + } + table = &flatbuffers.Table{Bytes: validBytes, Pos: 0} + result = table.ByteVector(0) + if result == nil || !bytes.Equal(result, []byte("abc")) { + fail("ByteVector should work correctly for valid data") + } +} + // BenchmarkVtableDeduplication measures the speed of vtable deduplication // by creating prePop vtables, then populating b.N objects with a // different single vtable.