diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 56b5498a35..7469d6ff40 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -23,6 +23,7 @@
 import numpy as np
 import math
 
+from singa import tensor
 from .tensor import Tensor
 from . import layer
 from singa.proto import model_pb2
@@ -798,7 +799,7 @@ def __call__(self, x):
         self.handle.device_id = x.device.id()
 
         y = batchnorm_2d(self.handle, x, self.scale, self.bias,
-                          self.running_mean, self.running_var)
+                         self.running_mean, self.running_var)
         return y
 
 
@@ -962,3 +963,163 @@ def __init__(self, kernel_size, stride=None, padding=0):
             stride = kernel_size
         super(MaxPool2d, self).__init__(
             (1, kernel_size), (0, stride), (0, padding), False)
+
+
+class _RNN(Operation):
+
+    def __init__(self, handle):
+        self.handle = handle
+
+    # def forward(self, X, h0, c0, W):
+    def forward(self, X, h0, W, c0=None):
+        # X of shape (seq_len, batch, input_size)
+        # h0_c0: (h0, c0) if lstm, else (h0,)
+        # h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
+        if c0 is None:
+            assert self.handle.rnn_mode_ != 'lstm'
+            c0 = CTensor([])  # CTensor([]) and Tensor cx are the same?
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            if training:
+                Y, hout, cout = singa.GpuRNNForwardTraining(
+                    self.handle, X, h0, c0, W)
+                self.cache = (X, Y, h0, c0, W)
+            else:
+                Y, hout, cout = singa.GpuRNNForwardInference(
+                    self.handle, X, h0, c0, W)
+
+        # Y of shape (seq_len, batch, hidden_size * num_directions)
+        # hout_cout: (hout, cout) if lstm, else (hout,)
+        # hout, cout of shape (num_layers * num_directions, batch,
+        # hidden_size)
+
+        # outputs = _1dTo3d(Y)
+        shape = (self.handle.seq_length_, self.handle.batch_size_, self.handle.hidden_size_)
+        outputs = singa.Reshape(Y, shape)
+
+        if self.handle.rnn_mode_ != 'lstm':
+            return outputs, hout
+        else:
+            return outputs, hout, cout
+
+    def backward(self, dY, dh=CTensor([]), dc=CTensor([])):
+        assert training is True and hasattr(
+            self, 'cache'), 'Please set training to True before doing BP.'
+
+        # dY_1d = _3dTo1d(dY)
+
+        if dY.device().id() != self.handle.device_id:
+            dY.ToDevice(self.cache[0].device())
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
+                self.handle, dY, dh, dc, self.cache)
+
+        # dX = _1dTo3d(dX_1d)
+        shape = (self.handle.seq_length_, self.handle.batch_size_, self.handle.input_size_)
+        dX = singa.Reshape(dX_1d, shape)
+
+        if self.handle.rnn_mode_ != 'lstm':
+            return dX, dhout, dW
+        else:
+            return dX, dhout, dW, dcout
+
+
+# def rnn(handle, x, h0, c0, W):
+#     return _RNN(handle)(x, h0, c0, W)
+
+def rnn(handle, x, h0, W, c0):
+    if c0 is None:
+        return _RNN(handle)(x, h0, W)
+    else:
+        return _RNN(handle)(x, h0, W, c0)
+
+
+class RNN(Layer):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False, rnn_mode='tanh'):
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = dropout
+        self.bidirectional = bidirectional
+        self.rnn_mode = rnn_mode
+
+        if bias is not True or batch_first is not False:
+            raise NotImplementedError
+
+        mult = 1
+        if self.rnn_mode == 'tanh' or self.rnn_mode == 'relu':
+            mult *= 1
+        elif self.rnn_mode == 'lstm':
+            mult *= 4
+        elif self.rnn_mode == 'gru':
+            mult *= 3
+        else:
+            raise ValueError
+
+        if self.bidirectional:
+            mult *= 2
+
+        W_Size = 0
+        for k in range(num_layers):
+            if k == 0:
+                w_size = self.hidden_size * \
+                    (self.input_size + self.hidden_size + 2)
+            else:
+                w_size = self.hidden_size * \
+                    (self.hidden_size + self.hidden_size + 2)
+            W_Size += mult * w_size
+
+        self.W_Size = W_Size
+        self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)  # TODO: assign value of Wi separately
+        self.W.uniform(0.0, 1.0)
+
+    def __call__(self, inputs, h0, c0=None):
+        # inputs of shape (seq_len, batch, input_size)
+        # h0_c0: (h0, c0) if lstm, else (h0,)
+        # h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
+
+        self.device_check(inputs, h0, self.W)
+
+        if self.rnn_mode == 'lstm':
+            assert c0 is not None, 'Please input c0.'
+            self.device_check(h0, c0)
+        else:
+            assert c0 is None, 'only lstm needs input c0'
+
+        if not hasattr(self, 'handle'):
+            self.handle = singa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
+                                               self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
+        elif inputs.shape[0] != self.handle.seq_length_ or inputs.shape[1] != self.handle.batch_size_:
+            self.handle = singa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
+                                               self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
+
+        self.handle.device_id = inputs.device.id()
+
+        # X = _3dTo1d(inputs)
+        X = inputs
+        outputs = rnn(self.handle, X, h0, self.W, c0)
+        # outputs = rnn(self.handle, X, h0, self.W)
+        # outputs = tensor.to_numpy(outputs[0])
+        # print(outputs.shape)
+        # print(outputs)
+        return outputs
+
+
+class LSTM(RNN):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
+        super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='lstm')
+
+
+class GRU(RNN):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
+        super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='gru')
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 435ff1c502..5cec92e4ad 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -7,7 +7,7 @@
 #include "../src/model/operation/convolution.h"
 #include "../src/model/operation/batchnorm.h"
 #include "../src/model/operation/pooling.h"
-
+#include "../src/model/operation/rnn.h"
 %}
 
 namespace singa {
@@ -51,6 +51,17 @@ class PoolingHandle {
   int pooled_width;
 };
 
+class RNNHandle {
+public:
+  RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+            const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);
+
+  size_t batch_size_;
+  size_t seq_length_;
+  size_t input_size_;
+  size_t hidden_size_;
+  std::string rnn_mode_;
+};
 
 #if USE_CUDNN
 
 class CudnnConvHandle: public ConvHandle {
@@ -106,6 +117,27 @@ Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x);
 
 Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy,
                           const Tensor& x, const Tensor& y);
+
+class CudnnRNNHandle: public RNNHandle {
+public:
+  CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                 const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);
+
+  size_t batch_size_;
+  size_t seq_length_;
+  size_t input_size_;
+  size_t hidden_size_;
+  std::string rnn_mode_;
+
+};
+
+std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
+
+std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
+
+std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache);
+
+
 #endif  // USE_CUDNN
 
 } //namespace singa
diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc
new file mode 100755
index 0000000000..e8c614e072
--- /dev/null
+++ b/src/model/operation/rnn.cc
@@ -0,0 +1,461 @@
+#include "./rnn.h"
+#include
+namespace singa {
+
+RNNHandle::RNNHandle(const
Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size) { + + CHECK_EQ(input.shape(2), Input_size); + batch_size_ = input.shape(1); + seq_length_= input.shape(0); + + input_size_ = Input_size; + CHECK_GT(input_size_, 0u); + hidden_size_ = Hidden_size; + CHECK_GT(hidden_size_, 0u); + num_stacks_ = Num_stacks; + CHECK_GT(num_stacks_, 0u); + dropout_ = Dropout; // drop probability + CHECK_GE(dropout_, 0); + + if (bidirectional) + num_directions_ = 2; + else + num_directions_ = 1; + + rnn_mode_ = Rnn_mode; + if (rnn_mode_ == "lstm") { + has_cell_ = true; + } else if (rnn_mode_ != "relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") { + LOG(FATAL) << "RNN memory unit (mode) of " << rnn_mode_ + << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'"; + } + // the first constant (4) is the size of float + // the second constant (2, 8, 6) is the number of sets of params + weight_size= Weight_size; + +}; + +#ifdef USE_CUDNN + +CudnnRNNHandle::CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size): + RNNHandle(input, Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional, Weight_size) { + + DataType dtype = input.data_type(); + dtype_ = GetCudnnDataType(dtype); + + UpdateIODescriptors(input); + ResetHiddenAndCellDescriptors(); + SetRNNDescriptor(input.device()); + UpdateSpaces(seq_length_, input.device()); +}; + +CudnnRNNHandle::~CudnnRNNHandle() { + if (weight_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_)); + if (dropout_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_)); + if (rnn_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_)); + if (hx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_)); + if (hy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_)); + if (cx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_)); + if (cy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_)); + if (dhx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_)); + if (dhy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_)); + if (dcx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_)); + if (dcy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_)); + DestroyIODescriptors(); +}; + +void CudnnRNNHandle::DestroyIODescriptors() { + if (x_descs_ != nullptr) { + for (size_t i = 0; i < seq_length_; i++) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i])); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i])); + } + delete [] x_descs_; + delete [] dx_descs_; + } + if (y_descs_ != nullptr) { + for (size_t i = 0; i < seq_length_; i++) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i])); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i])); + } + delete [] y_descs_; + delete [] dy_descs_; + } +}; + + +void CudnnRNNHandle::UpdateIODescriptors(const Tensor &input) { + x_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + dx_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + y_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + dy_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + for (size_t i = 0; i < seq_length_; i++) { + 
CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i])); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i])); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i])); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i])); + } + + for (size_t i = 0; i < seq_length_; i++) { + CHECK_EQ(input.shape(2), input_size_); + int d[3] = {1, 1, 1}, s[3] = {1, 1, 1}; + d[0] = static_cast(batch_size_); + CHECK_GT(d[0], 0); + d[1] = static_cast(input_size_); + s[0] = d[1] * d[2]; + s[1] = d[2]; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s)); + + d[0] = static_cast(batch_size_); + d[1] = static_cast(hidden_size_ * num_directions_); + s[0] = d[1] * d[2]; + s[1] = d[2]; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s)); + } +}; + +void CudnnRNNHandle::ResetHiddenAndCellDescriptors() { + if (cx_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_)); + if (dcx_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_)); + if (cy_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_)); + if (dcy_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_)); + if (hx_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_)); + if (dhx_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_)); + if (hy_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_)); + if (dhy_desc_ == nullptr) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_)); + + int dim[3] = {1, 1, 1}; + dim[0] = static_cast(num_stacks_ * num_directions_); + dim[1] = static_cast(batch_size_); + dim[2] = static_cast(hidden_size_); + int stride[3] = {1, 1, 1}; + stride[0] = dim[1] * dim[2]; + stride[1] = dim[2]; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dim, stride)); +}; + +void CudnnRNNHandle::SetRNNDescriptor(shared_ptr dev) { + auto ctx = dev->context(0); + CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_)); + size_t state_size; + CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size)); + dropout_state_ = Tensor(Shape{state_size}, dev, kChar); + CUDNN_CHECK(cudnnSetDropoutDescriptor( + dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability + dropout_state_.block()->mutable_data(), state_size, seed_)); + + CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_)); + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + //if (input_mode_ == "skip") + //input_mode = CUDNN_SKIP_INPUT; + + cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL; + if (num_directions_ == 2) + direction = CUDNN_BIDIRECTIONAL; + + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + if (rnn_mode_ == "relu") + rnn_mode = CUDNN_RNN_RELU; + else if (rnn_mode_ == "tanh") + rnn_mode = CUDNN_RNN_TANH; + else if (rnn_mode_ == "gru") + rnn_mode = CUDNN_GRU; +#if CUDNN_MAJOR <= 5 + 
CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_, + dropout_desc_, input_mode, direction, + rnn_mode, dtype_)); +#else + CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_, + dropout_desc_, input_mode, direction, + rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_)); +#endif + size_t weight_size_; + CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0], + &weight_size_, dtype_)); + // check the size manually calculated + //std::cout<(weight_size_), 1, 1}; + CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_)); + CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, dtype_, + CUDNN_TENSOR_NCHW, 3, filter_dim)); +}; + +void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr dev) { + size_t count; + auto ctx = dev->context(0); + CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_, + seq_length, x_descs_, &count)); + if (workspace_.Size() != count) { + workspace_ = Tensor(Shape{count}, dev, kChar); + // workspace_.SetValue(0); + } + + CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_, + seq_length, x_descs_, &count)); + if (reserve_space_.Size() != count) { + reserve_space_ = Tensor(Shape{count}, dev, kChar); + // reserve_space_.SetValue(0); + } +}; + +Tensor MergeInputs(size_t num, const vector &in) { + if (num == 1) + return in.at(0); + size_t size = 0; + for (size_t i = 0; i < num; i++) size += in.at(i).Size(); + Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type()); + for (size_t i = 0, offset = 0; i < num; i++) { + CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset); + offset += in.at(i).Size(); + } + return out; +}; + +vector SplitOutput(size_t num, size_t dim, + const vector &in, + const Tensor output) { + vector outputs; + if (num == 1) { + outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim})); + } else { + for (size_t i = 0, offset = 0; offset < output.Size(); i++) { + Shape s{in.at(i).shape(0), dim}; + Tensor out(s, output.device(), output.data_type()); + CopyDataToFrom(&out, output, out.Size(), 0, offset); + outputs.push_back(out); + offset += out.Size(); + } + CHECK_EQ(num, outputs.size()); + } + return outputs; +}; + +std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) { + DataType dtype = input.data_type(); + auto dev = input.device(); + + + Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; + Tensor output(outshape, dev, dtype); + // LOG(INFO) << "output size " << output.Size(); + + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + Tensor hy(state_shape, dev, dtype); + + Tensor cy; + if (crh.has_cell_) { + cy.ResetLike(hy); + } + + int did = input.device()->id(); + CHECK_EQ(did, output.device()->id()); + if (hx.Size()) { + CHECK_EQ(did, hx.device()->id()); + CHECK_EQ(hx.device()->lang(), kCuda); + } + if (cx.Size()) { + CHECK_EQ(did, cx.device()->id()); + CHECK_EQ(cx.device()->lang(), kCuda); + } + CHECK_EQ(did, W.device()->id()); + CHECK_EQ(did, crh.workspace_.device()->id()); + CHECK_EQ(input.device()->lang(), kCuda); + CHECK_EQ(output.device()->lang(), kCuda); + CHECK_EQ(W.device()->lang(), kCuda); + CHECK_EQ(crh.workspace_.device()->lang(), kCuda); + CHECK_EQ(crh.reserve_space_.device()->lang(), kCuda); + CHECK_EQ(did, crh.reserve_space_.device()->id()); + + Block *inb = input.block(), *outb = output.block(), + *wb = W.block(), *hxb = hx.block(), *cxb = cx.block(), + 
*hyb = hy.block(), *cyb = cy.block(), + *wspace = crh.workspace_.block(), + *rspace = crh.reserve_space_.block(); + + dev->Exec( + [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNForwardTraining( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, inb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.weight_desc_, wb->data(), + crh.y_descs_, outb->mutable_data(), + crh.hy_desc_, hyb->mutable_data(), + crh.cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), + wspace->mutable_data(), + crh.workspace_.Size(), rspace->mutable_data(), + crh.reserve_space_.Size()); + // clang-format on + }, + {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); + + return {output, hy, cy}; +}; + +std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) { + DataType dtype = input.data_type(); + auto dev = input.device(); + + Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; + Tensor output(outshape, dev, dtype); + // LOG(INFO) << "output size " << output.Size(); + + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + Tensor hy(state_shape, dev, dtype); + + Tensor cy; + if (crh.has_cell_) { + cy.ResetLike(hy); + } + + int did = input.device()->id(); + CHECK_EQ(did, output.device()->id()); + if (hx.Size()) { + CHECK_EQ(did, hx.device()->id()); + CHECK_EQ(hx.device()->lang(), kCuda); + } + if (cx.Size()) { + CHECK_EQ(did, cx.device()->id()); + CHECK_EQ(cx.device()->lang(), kCuda); + } + CHECK_EQ(did, W.device()->id()); + CHECK_EQ(did, crh.workspace_.device()->id()); + CHECK_EQ(input.device()->lang(), kCuda); + CHECK_EQ(output.device()->lang(), kCuda); + CHECK_EQ(W.device()->lang(), kCuda); + CHECK_EQ(crh.workspace_.device()->lang(), kCuda); + + Block *inb = input.block(), *outb = output.block(), + *wb = W.block(), *hxb = hx.block(), *cxb = cx.block(), + *hyb = hy.block(), *cyb = cy.block(), + *wspace = crh.workspace_.block(); + + dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNForwardInference( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, inb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.weight_desc_, wb->data(), + crh.y_descs_, outb->mutable_data(), + crh.hy_desc_, hyb->mutable_data(), + crh.cy_desc_, cyb == nullptr ? 
nullptr : cyb->mutable_data(), + wspace->mutable_data(), crh.workspace_.Size()); + // clang-format on + }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace}); + + return {output, hy, cy}; +}; + +std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector &cache) { + const Tensor x = cache[0]; + const Tensor y = cache[1]; + const Tensor hx = cache[2]; + const Tensor cx = cache[3]; + const Tensor W = cache[4]; + + auto dev = y.device(); + auto dtype = y.data_type(); + + CHECK_EQ(dY.Size(), y.Size()); + + Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_}; + Tensor dx(xshape, dev, dtype); + + Tensor dw(W.shape(), dev, dtype); + + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + Tensor dhx(state_shape, dev, dtype); + + Tensor dcx; + if (crh.has_cell_) + dcx.ResetLike(dhx); + + dw.SetValue(0.0f); + Block *yb = y.block(), *dyb = dY.block(), *dhyb = dhy.block(), + *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(), + *wb = W.block(), *dwb = dw.block(), *hxb = hx.block(), + *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(), + *wspace = crh.workspace_.block(), *rspace = crh.reserve_space_.block(); + + y.device()->Exec( + [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace, + rspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNBackwardData( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.y_descs_, yb->data(), + crh.dy_descs_, dyb->data(), + crh.dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(), + crh.dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(), + crh.weight_desc_, wb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.dx_descs_, dxb->mutable_data(), + crh.dhx_desc_, dhxb->mutable_data(), + crh.dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(), + wspace->mutable_data(), crh.workspace_.Size(), + rspace->mutable_data(), crh.reserve_space_.Size()); + cudnnRNNBackwardWeights( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, xb->data(), + crh.hx_desc_, hxb == nullptr ? 
nullptr : hxb->data(), + crh.y_descs_, yb->data(), + wspace->data(), crh.workspace_.Size(), + crh.dweight_desc_, dwb->mutable_data(), + rspace->data(), crh.reserve_space_.Size()); + // clang-format on + }, + {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace}, + {dxb, dwb, dhxb, dcxb, wspace, rspace}); + + return {dx, dhx, dcx, dw}; +}; + +#endif // USE_CUDNN + +} // namespace singa + + diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h new file mode 100755 index 0000000000..88c24e4292 --- /dev/null +++ b/src/model/operation/rnn.h @@ -0,0 +1,88 @@ +#ifndef SINGA_MODEL_OPERATION_CUDNN_RNN_H_ +#define SINGA_MODEL_OPERATION_CUDNN_RNN_H_ + +#include +#include +#include "singa/core/tensor.h" + + +#ifdef USE_CUDNN +#include +#include "../layer/cudnn_utils.h" +#endif // USE_CUDNN + + +namespace singa { + +class RNNHandle { +public: + RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size); + + size_t input_size_; + size_t hidden_size_; + size_t num_stacks_; + float dropout_; + size_t seed_ = 0x1234567; + size_t num_directions_; + std::string rnn_mode_; + bool has_cell_; + size_t weight_size; + + size_t batch_size_; + size_t seq_length_; + +}; + +#ifdef USE_CUDNN + +class CudnnRNNHandle: public RNNHandle { +public: + CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size); + ~CudnnRNNHandle(); + + void DestroyIODescriptors(); + void UpdateIODescriptors(const Tensor &input); + void ResetHiddenAndCellDescriptors() ; + void SetRNNDescriptor(shared_ptr dev); + void UpdateSpaces(size_t seq_length, shared_ptr dev); + + + cudnnTensorDescriptor_t* x_descs_ = nullptr; + cudnnTensorDescriptor_t* dx_descs_ = nullptr; + cudnnTensorDescriptor_t* y_descs_ = nullptr; + cudnnTensorDescriptor_t* dy_descs_ = nullptr; + cudnnTensorDescriptor_t hx_desc_ = nullptr; + cudnnTensorDescriptor_t dhx_desc_ = nullptr; + cudnnTensorDescriptor_t cx_desc_ = nullptr; + cudnnTensorDescriptor_t dcx_desc_ = nullptr; + cudnnTensorDescriptor_t hy_desc_ = nullptr; + cudnnTensorDescriptor_t dhy_desc_ = nullptr; + cudnnTensorDescriptor_t cy_desc_ = nullptr; + cudnnTensorDescriptor_t dcy_desc_ = nullptr; + cudnnFilterDescriptor_t weight_desc_ = nullptr; + cudnnFilterDescriptor_t dweight_desc_ = nullptr; + cudnnRNNDescriptor_t rnn_desc_ = nullptr; + cudnnDropoutDescriptor_t dropout_desc_ = nullptr; + cudnnDataType_t dtype_ = CUDNN_DATA_FLOAT; + Tensor workspace_; + Tensor reserve_space_; + Tensor dropout_state_; +}; +Tensor MergeInputs(size_t num, const vector &in); + +vector SplitOutput(size_t num, size_t dim, + const vector &in, + const Tensor output); + +std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) ; + +std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W); + +std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const vector &cache); + +#endif // USE_CUDNN + +} // namespace singa +#endif // SINGA_MODEL_OPERATION_CUDNN_RNN_H_
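
Reviewer note (not part of the patch): a minimal usage sketch of the Python API added in autograd.py. It assumes a CUDA-enabled build, that `device.create_cuda_gpu()`, `tensor.Tensor(shape=..., device=...)`, `gaussian(...)` and `set_value(...)` behave as in the existing Python package, and that `autograd.training` is the module-level flag read by `_RNN.forward`; the sizes used are illustrative only. Shapes follow the comments in the diff: inputs are (seq_len, batch, input_size), h0/c0 are (num_layers * num_directions, batch, hidden_size), and the first output is (seq_len, batch, hidden_size * num_directions).

    # Hypothetical sketch, not part of the patch: exercising the new autograd.LSTM layer.
    from singa import autograd, device, tensor

    dev = device.create_cuda_gpu()   # the cuDNN RNN handle requires a GPU device
    autograd.training = True         # _RNN.forward then dispatches to GpuRNNForwardTraining

    seq_len, batch, input_size, hidden_size = 4, 2, 8, 16

    x = tensor.Tensor(shape=(seq_len, batch, input_size), device=dev)
    x.gaussian(0.0, 1.0)
    h0 = tensor.Tensor(shape=(1, batch, hidden_size), device=dev)  # num_layers * num_directions = 1
    h0.set_value(0.0)
    c0 = tensor.Tensor(shape=(1, batch, hidden_size), device=dev)  # required because rnn_mode is 'lstm'
    c0.set_value(0.0)

    lstm = autograd.LSTM(input_size, hidden_size)
    y, hy, cy = lstm(x, h0, c0)      # y: (seq_len, batch, hidden_size)

On the weight size: RNN.__init__ follows cuDNN's per-gate parameter layout, allocating hidden_size * (input_width + hidden_size + 2) values per gate and per layer (input weights, recurrent weights, two bias vectors), where input_width is input_size for the first layer and hidden_size for deeper layers, multiplied by the number of gates (1 for tanh/relu, 3 for gru, 4 for lstm) and by 2 when bidirectional. For stacked bidirectional networks this Python estimate may differ from what cuDNN expects; SetRNNDescriptor in rnn.cc already queries cudnnGetRNNParamsSize, which is the authoritative value and is worth cross-checking against W_Size.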