Optional Scalars support for Rust (#6034)

* First draft of rust optionals

* Code cleanup around ftBool and ftVectorOfBool

* Tests for Rust optional scalars

* test bools too

Co-authored-by: Casper Neo <cneo@google.com>
This commit is contained in:
Casper
2020-07-23 16:30:27 -07:00
committed by GitHub
parent c8fa0afdfc
commit 043b52bd4a
12 changed files with 626 additions and 59 deletions

View File

@@ -661,22 +661,15 @@ class RustGenerator : public BaseGenerator {
return "VT_" + MakeUpper(Name(field));
}
std::string GetDefaultConstant(const FieldDef &field) {
return field.value.type.base_type == BASE_TYPE_FLOAT
? field.value.constant + ""
: field.value.constant;
}
std::string GetDefaultScalarValue(const FieldDef &field) {
switch (GetFullType(field.value.type)) {
case ftInteger: {
return GetDefaultConstant(field);
}
case ftInteger:
case ftFloat: {
return GetDefaultConstant(field);
return field.nullable ? "None" : field.value.constant;
}
case ftBool: {
return field.value.constant == "0" ? "false" : "true";
return field.nullable ? "None" :
field.value.constant == "0" ? "false" : "true";
}
case ftUnionKey:
case ftEnumKey: {
@@ -714,7 +707,7 @@ class RustGenerator : public BaseGenerator {
case ftFloat:
case ftBool: {
const auto typname = GetTypeBasic(type);
return typname;
return field.nullable ? "Option<" + typname + ">" : typname;
}
case ftStruct: {
const auto typname = WrapInNameSpace(*type.struct_def);
@@ -738,14 +731,12 @@ class RustGenerator : public BaseGenerator {
}
case ftVectorOfInteger:
case ftVectorOfBool:
case ftVectorOfFloat: {
const auto typname = GetTypeBasic(type.VectorType());
return "Option<flatbuffers::WIPOffset<flatbuffers::Vector<" + lifetime +
", " + typname + ">>>";
}
case ftVectorOfBool: {
return "Option<flatbuffers::WIPOffset<flatbuffers::Vector<" + lifetime +
", bool>>>";
// TODO(cneo): Fix whitespace in generated code.
}
case ftVectorOfEnumKey: {
const auto typname = WrapInNameSpace(*type.enum_def);
@@ -815,15 +806,12 @@ class RustGenerator : public BaseGenerator {
">>>>";
}
case ftVectorOfInteger:
case ftVectorOfBool:
case ftVectorOfFloat: {
const auto typname = GetTypeBasic(type.VectorType());
return "flatbuffers::WIPOffset<flatbuffers::Vector<" + lifetime + ", " +
typname + ">>";
}
case ftVectorOfBool: {
return "flatbuffers::WIPOffset<flatbuffers::Vector<" + lifetime +
", bool>>";
}
case ftVectorOfString: {
return "flatbuffers::WIPOffset<flatbuffers::Vector<" + lifetime +
", flatbuffers::ForwardsUOffset<&" + lifetime + " str>>>";
@@ -851,12 +839,9 @@ class RustGenerator : public BaseGenerator {
return "flatbuffers::WIPOffset<" + typname + "<" + lifetime + ">>";
}
case ftInteger:
case ftBool:
case ftFloat: {
const auto typname = GetTypeBasic(type);
return typname;
}
case ftBool: {
return "bool";
return GetTypeBasic(type);
}
case ftString: {
return "flatbuffers::WIPOffset<&" + lifetime + " str>";
@@ -878,14 +863,13 @@ class RustGenerator : public BaseGenerator {
switch (GetFullType(field.value.type)) {
case ftInteger:
case ftBool:
case ftFloat: {
const auto typname = GetTypeBasic(field.value.type);
return "self.fbb_.push_slot::<" + typname + ">";
return (field.nullable ?
"self.fbb_.push_slot_always::<" :
"self.fbb_.push_slot::<") + typname + ">";
}
case ftBool: {
return "self.fbb_.push_slot::<bool>";
}
case ftEnumKey:
case ftUnionKey: {
const auto underlying_typname = GetTypeBasic(type);
@@ -924,12 +908,10 @@ class RustGenerator : public BaseGenerator {
switch (GetFullType(field.value.type)) {
case ftInteger:
case ftFloat: {
const auto typname = GetTypeBasic(type);
return typname;
}
case ftFloat:
case ftBool: {
return "bool";
const auto typname = GetTypeBasic(type);
return field.nullable ? "Option<" + typname + ">" : typname;
}
case ftStruct: {
const auto typname = WrapInNameSpace(*type.struct_def);
@@ -956,6 +938,7 @@ class RustGenerator : public BaseGenerator {
field.required);
}
case ftVectorOfInteger:
case ftVectorOfBool:
case ftVectorOfFloat: {
const auto typname = GetTypeBasic(type.VectorType());
if (IsOneByte(type.VectorType().base_type)) {
@@ -966,10 +949,6 @@ class RustGenerator : public BaseGenerator {
"flatbuffers::Vector<" + lifetime + ", " + typname + ">",
field.required);
}
case ftVectorOfBool: {
return WrapInOptionIfNotRequired("&" + lifetime + " [bool]",
field.required);
}
case ftVectorOfEnumKey: {
const auto typname = WrapInNameSpace(*type.enum_def);
return WrapInOptionIfNotRequired(
@@ -1016,9 +995,13 @@ class RustGenerator : public BaseGenerator {
case ftFloat:
case ftBool: {
const auto typname = GetTypeBasic(type);
const auto default_value = GetDefaultScalarValue(field);
return "self._tab.get::<" + typname + ">(" + offset_name + ", Some(" +
default_value + ")).unwrap()";
if (field.nullable) {
return "self._tab.get::<" + typname + ">(" + offset_name + ", None)";
} else {
const auto default_value = GetDefaultScalarValue(field);
return "self._tab.get::<" + typname + ">(" + offset_name + ", Some(" +
default_value + ")).unwrap()";
}
}
case ftStruct: {
const auto typname = WrapInNameSpace(*type.struct_def);
@@ -1056,6 +1039,7 @@ class RustGenerator : public BaseGenerator {
}
case ftVectorOfInteger:
case ftVectorOfBool:
case ftVectorOfFloat: {
const auto typname = GetTypeBasic(type.VectorType());
std::string s =
@@ -1068,14 +1052,6 @@ class RustGenerator : public BaseGenerator {
}
return AddUnwrapIfRequired(s, field.required);
}
case ftVectorOfBool: {
return AddUnwrapIfRequired(
"self._tab.get::<flatbuffers::ForwardsUOffset<"
"flatbuffers::Vector<" +
lifetime + ", bool>>>(" + offset_name +
", None).map(|v| v.safe_slice())",
field.required);
}
case ftVectorOfEnumKey: {
const auto typname = WrapInNameSpace(*type.enum_def);
return AddUnwrapIfRequired(
@@ -1116,8 +1092,9 @@ class RustGenerator : public BaseGenerator {
return "INVALID_CODE_GENERATION"; // for return analysis
}
bool TableFieldReturnsOption(const Type &type) {
switch (GetFullType(type)) {
bool TableFieldReturnsOption(const FieldDef &field) {
if (field.nullable) return true;
switch (GetFullType(field.value.type)) {
case ftInteger:
case ftFloat:
case ftBool:
@@ -1205,7 +1182,7 @@ class RustGenerator : public BaseGenerator {
if (!field.deprecated && (!struct_def.sortbysize ||
size == SizeOf(field.value.type.base_type))) {
code_.SetValue("FIELD_NAME", Name(field));
if (TableFieldReturnsOption(field.value.type)) {
if (TableFieldReturnsOption(field)) {
code_ +=
" if let Some(x) = args.{{FIELD_NAME}} "
"{ builder.add_{{FIELD_NAME}}(x); }";
@@ -1421,7 +1398,6 @@ class RustGenerator : public BaseGenerator {
const auto &field = **it;
if (!field.deprecated) {
const bool is_scalar = IsScalar(field.value.type.base_type);
std::string offset = GetFieldOffsetName(field);
// Generate functions to add data, which take one of two forms.
@@ -1443,7 +1419,7 @@ class RustGenerator : public BaseGenerator {
code_ +=
" pub fn add_{{FIELD_NAME}}(&mut self, {{FIELD_NAME}}: "
"{{FIELD_TYPE}}) {";
if (is_scalar) {
if (is_scalar && !field.nullable) {
code_.SetValue("FIELD_DEFAULT_VALUE",
TableBuilderAddFuncDefaultValue(field));
code_ +=