mirror of
https://github.com/google/flatbuffers.git
synced 2026-06-02 12:05:50 +00:00
[Python] (scalar) vector reading speedup via numpy (#4390)
* Add numpy accessor to python flatbuffers scalar vectors * Update python tests to test numpy vector accessor * Update appveyor CI to run Python tests, save generated code as artifact * Update example generated python code * Add numpy info to python usage docs * Update test schema and python tests w/ multi-byte vector * did not mean to push profiling code * adding float64 numpy tests
This commit is contained in:
committed by
Wouter van Oortmerssen
parent
89a68942ac
commit
3282a84e30
@@ -12,9 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" A tiny version of `six` to help with backwards compability. """
|
||||
""" A tiny version of `six` to help with backwards compability. Also includes
|
||||
compatibility helpers for numpy. """
|
||||
|
||||
import sys
|
||||
import imp
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
PY26 = sys.version_info[0:2] == (2, 6)
|
||||
@@ -43,4 +45,37 @@ else:
|
||||
memoryview_type = memoryview
|
||||
struct_bool_decl = "?"
|
||||
|
||||
# Helper functions to facilitate making numpy optional instead of required
|
||||
|
||||
def import_numpy():
|
||||
"""
|
||||
Returns the numpy module if it exists on the system,
|
||||
otherwise returns None.
|
||||
"""
|
||||
try:
|
||||
imp.find_module('numpy')
|
||||
numpy_exists = True
|
||||
except ImportError:
|
||||
numpy_exists = False
|
||||
|
||||
if numpy_exists:
|
||||
# We do this outside of try/except block in case numpy exists
|
||||
# but is not installed correctly. We do not want to catch an
|
||||
# incorrect installation which would manifest as an
|
||||
# ImportError.
|
||||
import numpy as np
|
||||
else:
|
||||
np = None
|
||||
|
||||
return np
|
||||
|
||||
|
||||
class NumpyRequiredForThisFeature(RuntimeError):
|
||||
"""
|
||||
Error raised when user tries to use a feature that
|
||||
requires numpy without having numpy installed.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# NOTE: Future Jython support may require code here (look at `six`).
|
||||
|
||||
@@ -15,13 +15,26 @@
|
||||
from . import number_types as N
|
||||
from . import packer
|
||||
from .compat import memoryview_type
|
||||
from .compat import import_numpy, NumpyRequiredForThisFeature
|
||||
|
||||
np = import_numpy()
|
||||
|
||||
def Get(packer_type, buf, head):
|
||||
""" Get decodes a value at buf[head:] using `packer_type`. """
|
||||
""" Get decodes a value at buf[head] using `packer_type`. """
|
||||
return packer_type.unpack_from(memoryview_type(buf), head)[0]
|
||||
|
||||
|
||||
def GetVectorAsNumpy(numpy_type, buf, count, offset):
|
||||
""" GetVecAsNumpy decodes values starting at buf[head] as
|
||||
`numpy_type`, where `numpy_type` is a numpy dtype. """
|
||||
if np is not None:
|
||||
# TODO: could set .flags.writeable = False to make users jump through
|
||||
# hoops before modifying...
|
||||
return np.frombuffer(buf, dtype=numpy_type, count=count, offset=offset)
|
||||
else:
|
||||
raise NumpyRequiredForThisFeature('Numpy was not found.')
|
||||
|
||||
|
||||
def Write(packer_type, buf, head, n):
|
||||
""" Write encodes `n` at buf[head:] using `packer_type`. """
|
||||
""" Write encodes `n` at buf[head] using `packer_type`. """
|
||||
packer_type.pack_into(buf, head, n)
|
||||
|
||||
@@ -16,7 +16,9 @@ import collections
|
||||
import struct
|
||||
|
||||
from . import packer
|
||||
from .compat import import_numpy, NumpyRequiredForThisFeature
|
||||
|
||||
np = import_numpy()
|
||||
|
||||
# For reference, see:
|
||||
# https://docs.python.org/2/library/ctypes.html#ctypes-fundamental-data-types-2
|
||||
@@ -170,3 +172,10 @@ def uint64_to_float64(n):
|
||||
packed = struct.pack("<1Q", n)
|
||||
(unpacked,) = struct.unpack("<1d", packed)
|
||||
return unpacked
|
||||
|
||||
|
||||
def to_numpy_type(number_type):
|
||||
if np is not None:
|
||||
return np.dtype(number_type.name).newbyteorder('<')
|
||||
else:
|
||||
raise NumpyRequiredForThisFeature('Numpy was not found.')
|
||||
|
||||
@@ -101,6 +101,18 @@ class Table(object):
|
||||
return d
|
||||
return self.Get(validator_flags, self.Pos + off)
|
||||
|
||||
def GetVectorAsNumpy(self, flags, off):
|
||||
"""
|
||||
GetVectorAsNumpy returns the vector that starts at `Vector(off)`
|
||||
as a numpy array with the type specified by `flags`. The array is
|
||||
a `view` into Bytes, so modifying the returned array will
|
||||
modify Bytes in place.
|
||||
"""
|
||||
offset = self.Vector(off)
|
||||
length = self.VectorLen(off) # TODO: length accounts for bytewidth, right?
|
||||
numpy_dtype = N.to_numpy_type(flags)
|
||||
return encode.GetVectorAsNumpy(numpy_dtype, self.Bytes, length, offset)
|
||||
|
||||
def GetVOffsetTSlot(self, slot, d):
|
||||
"""
|
||||
GetVOffsetTSlot retrieves the VOffsetT that the given vtable location
|
||||
|
||||
Reference in New Issue
Block a user