From cbd6162457390b113551890e2ae2460622925db7 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Fri, 15 Apr 2016 08:48:13 -0400 Subject: [PATCH 1/2] Add support for char --- luasrc/ffi.lua | 5 ++++- tests/testData.lua | 2 -- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/luasrc/ffi.lua b/luasrc/ffi.lua index 6b2d199..7178c93 100644 --- a/luasrc/ffi.lua +++ b/luasrc/ffi.lua @@ -111,6 +111,8 @@ addConstants('h5t', { 'VLEN', 'ARRAY', 'NCLASSES', + 'SGN_NONE', + 'SGN_2', }, addH5t) local function addG(x) return addH5t(x) .. "_g" end @@ -310,7 +312,8 @@ function hdf5._getTorchType(typeID) local size = tonumber(hdf5.C.H5Tget_size(typeID)) if className == 'INTEGER' then if size == 1 then - return 'torch.ByteTensor' + local signed = hdf5.C.H5Tget_sign(typeID) == hdf5.h5t.SGN_2 + return signed and 'torch.CharTensor' or 'torch.ByteTensor' end if size == 2 then return 'torch.ShortTensor' diff --git a/tests/testData.lua b/tests/testData.lua index ceb0819..b54ac8a 100644 --- a/tests/testData.lua +++ b/tests/testData.lua @@ -42,14 +42,12 @@ local function intTensorEqual(typename, a, b) return a:add(-b):apply(function(x) return math.abs(tonumber(x)) end):sum() == 0 end ---[[ Not supported yet function myTests:testCharTensor() local k = 0 local testData = torch.CharTensor(4, 6):apply(function() k = k + 1; return k end) local got = writeAndReread(testData) tester:assert(intTensorEqual("torch.CharTensor", got, testData), "Data read does not match data written!") end -]] function myTests:testByteTensor() local k = 0 From a2aa4b05586859f5f4f6380c9a78d42c1b59ccad Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Fri, 15 Apr 2016 08:48:50 -0400 Subject: [PATCH 2/2] Add support for strings of chars --- luasrc/dataset.lua | 12 +++++++++--- luasrc/ffi.lua | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/luasrc/dataset.lua b/luasrc/dataset.lua index 36e2d84..8e8486c 100644 --- a/luasrc/dataset.lua +++ b/luasrc/dataset.lua @@ -6,7 +6,7 @@ local unpack = unpack or table.unpack local HDF5DataSet = torch.class("hdf5.HDF5DataSet") --[[ Get the sizes and max sizes of an HDF5 dataspace, returning them in Lua tables ]] -local function getDataspaceSize(nDims, spaceID) +local function getDataspaceSize(nDims, spaceID, datasetID) local size_t = hdf5.ffi.typeof("hsize_t[" .. nDims .. "]") local dims = size_t() local maxDims = size_t() @@ -19,6 +19,12 @@ local function getDataspaceSize(nDims, spaceID) size[k] = tonumber(dims[k-1]) maxSize[k] = tonumber(maxDims[k-1]) end + + local typeID = hdf5.C.H5Dget_type(datasetID) + if hdf5._datatypeName(typeID) == 'STRING' then + size[nDims+1] = tonumber(hdf5.C.H5Tget_size(typeID)) + end + return size, maxSize end @@ -65,7 +71,7 @@ function HDF5DataSet:all() -- Create a new tensor of the correct type and size local nDims = hdf5.C.H5Sget_simple_extent_ndims(self._dataspaceID) - local size = getDataspaceSize(nDims, self._dataspaceID) + local size = getDataspaceSize(nDims, self._dataspaceID, self._datasetID) local factory, nativeType = self:getTensorFactory() local tensor = factory():resize(unpack(size)) @@ -177,6 +183,6 @@ end function HDF5DataSet:dataspaceSize() local nDims = hdf5.C.H5Sget_simple_extent_ndims(self._dataspaceID) - local size = getDataspaceSize(nDims, self._dataspaceID) + local size = getDataspaceSize(nDims, self._dataspaceID, self._datasetID) return size end diff --git a/luasrc/ffi.lua b/luasrc/ffi.lua index 7178c93..cfd8c67 100644 --- a/luasrc/ffi.lua +++ b/luasrc/ffi.lua @@ -333,6 +333,8 @@ function hdf5._getTorchType(typeID) return 'torch.DoubleTensor' end error("Cannot support reading float data with size = " .. size .. " bytes") + elseif className == 'STRING' then + return 'torch.CharTensor' else error("Reading data of class " .. tostring(className) .. "(" .. typeID .. ") is unsupported")