diff --git a/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp b/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp index 5de75f60..8dbf568e 100644 --- a/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp +++ b/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp @@ -32,6 +32,9 @@ void LexiconFreeSeq2SeqDecoder::decodeStep( hyp_[0].clear(); hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, -1, nullptr); + // Size of each group + int grpSize = opt_.beamSize / opt_.numBeamGroups; + // Decode frame by frame int t = 0; for (; t < maxOutputLength_; t++) { @@ -61,87 +64,104 @@ void LexiconFreeSeq2SeqDecoder::decodeStep( std::vector idx(emittingModelScores.back().size()); - // Generate new hypothesis - for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) { - const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo]; - // Change nothing for completed hypothesis - if (prevHyp.token == eos_) { - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score, - prevHyp.lmState, - &prevHyp, - eos_, - nullptr, - prevHyp.emittingModelScore, - prevHyp.lmScore, - hypo); - continue; - } - - const EmittingModelStatePtr& outState = outStates[validHypo]; - if (!outState) { - validHypo++; - continue; - } - - std::iota(idx.begin(), idx.end(), 0); - if (emittingModelScores[validHypo].size() > opt_.beamSizeToken) { - std::partial_sort( - idx.begin(), - idx.begin() + opt_.beamSizeToken, - idx.end(), - [&emittingModelScores, &validHypo]( - const size_t& l, const size_t& r) { - return emittingModelScores[validHypo][l] > - emittingModelScores[validHypo][r]; - }); - } - - for (int r = 0; r < - std::min(emittingModelScores[validHypo].size(), - (size_t)opt_.beamSizeToken); - r++) { - int n = idx[r]; - double emittingModelScore = emittingModelScores[validHypo][n]; - - if (n == eos_) { /* (1) Try eos */ - auto lmStateScorePair = lm_->finish(prevHyp.lmState); - auto lmScore = lmStateScorePair.second; - + // Iterate through 
groups, if only one group, just vanilla BS + int hypo = 0; + int validHypo = 0; + + uniqueCandidateTokens_.clear(); + + for (int grp = 0; grp < opt_.numBeamGroups; grp++) { + // Generate new hypotheses for this group's slice of the beam; hypo carries over from the previous group and the last group also takes any remainder + for (; hypo < (grp + 1 == opt_.numBeamGroups ? hyp_[t].size() : std::min(hyp_[t].size(), (size_t)(grp + 1) * grpSize)); hypo++) { + const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo]; + // Change nothing for completed hypothesis + if (prevHyp.token == eos_) { candidatesAdd( candidates_, candidatesBestScore_, opt_.beamThreshold, - prevHyp.score + emittingModelScore + opt_.eosScore + - opt_.lmWeight * lmScore, - lmStateScorePair.first, + prevHyp.score, + prevHyp.lmState, &prevHyp, - n, + eos_, nullptr, - prevHyp.emittingModelScore + emittingModelScore, - prevHyp.lmScore + lmScore, - hypo); - } else { /* (2) Try normal token */ - auto lmStateScorePair = lm_->score(prevHyp.lmState, n); - auto lmScore = lmStateScorePair.second; - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + emittingModelScore + opt_.lmWeight * lmScore, - lmStateScorePair.first, - &prevHyp, - n, - outState, - prevHyp.emittingModelScore + emittingModelScore, - prevHyp.lmScore + lmScore, + prevHyp.emittingModelScore, + prevHyp.lmScore, hypo); + continue; } + + const EmittingModelStatePtr& outState = outStates[validHypo]; + if (!outState) { + validHypo++; + continue; + } + + std::iota(idx.begin(), idx.end(), 0); + if (emittingModelScores[validHypo].size() > opt_.beamSizeToken) { + std::partial_sort( + idx.begin(), + idx.begin() + opt_.beamSizeToken, + idx.end(), + [&emittingModelScores, &validHypo]( + const size_t& l, const size_t& r) { + return emittingModelScores[validHypo][l] > + emittingModelScores[validHypo][r]; + }); + } + + for (int r = 0; r < + std::min(emittingModelScores[validHypo].size(), + (size_t)opt_.beamSizeToken); + r++) { + int n = idx[r]; + + double diversityFactor = 0.0; + if (grp > 0) { + // Need to get a set of all the tokens chosen from other groups + // Can only apply the 
diversity factor after first run through + diversityFactor = diversityFunction_(uniqueCandidateTokens_, n); + } + // Augment log probabilities with diversity penalty + double emittingModelScore = emittingModelScores[validHypo][n] + diversityFactor; + + if (n == eos_) { /* (1) Try eos */ + auto lmStateScorePair = lm_->finish(prevHyp.lmState); + auto lmScore = lmStateScorePair.second; + + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + emittingModelScore + opt_.eosScore + + opt_.lmWeight * lmScore, + lmStateScorePair.first, + &prevHyp, + n, + nullptr, + prevHyp.emittingModelScore + emittingModelScore, + prevHyp.lmScore + lmScore, + hypo); + } else { /* (2) Try normal token */ + auto lmStateScorePair = lm_->score(prevHyp.lmState, n); + auto lmScore = lmStateScorePair.second; + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + emittingModelScore + opt_.lmWeight * lmScore, + lmStateScorePair.first, + &prevHyp, + n, + outState, + prevHyp.emittingModelScore + emittingModelScore, + prevHyp.lmScore + lmScore, + hypo); + uniqueCandidateTokens_.insert(n); + } + } + validHypo++; } - validHypo++; } candidatesStore( candidates_, diff --git a/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h b/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h index 0d9b9ee3..55ca1927 100644 --- a/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h +++ b/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h @@ -26,6 +26,7 @@ struct LexiconFreeSeq2SeqDecoderOptions { double lmWeight; // Weight of lm double eosScore; // Score for inserting an EOS bool logAdd; // If or not use logadd when merging hypothesis + int numBeamGroups = 1; // For diverse beam search, number of beam groups to utilize. Defaults to 1 (non-diverse beam search). 
}; /** @@ -78,7 +79,7 @@ struct LexiconFreeSeq2SeqDecoderState { int getWord() const { return -1; } -}; +}; /** * Decoder implements a beam seach decoder that finds the token transcription @@ -100,12 +101,14 @@ class LexiconFreeSeq2SeqDecoder : public Decoder { const LMPtr& lm, const int eos, EmittingModelUpdateFunc emittingModelUpdateFunc, - const int maxOutputLength) + const int maxOutputLength, + DiversityFunction diversityFunction) : opt_(std::move(opt)), lm_(lm), eos_(eos), emittingModelUpdateFunc_(emittingModelUpdateFunc), - maxOutputLength_(maxOutputLength) {} + maxOutputLength_(maxOutputLength), + diversityFunction_(diversityFunction) {} void decodeStep(const float* emissions, int T, int N) override; @@ -130,6 +133,10 @@ class LexiconFreeSeq2SeqDecoder : public Decoder { std::vector rawPrevStates_; int maxOutputLength_; + DiversityFunction diversityFunction_; + + std::unordered_set uniqueCandidateTokens_; + std::vector candidates_; std::vector candidatePtrs_; double candidatesBestScore_; diff --git a/flashlight/lib/text/decoder/Utils.h b/flashlight/lib/text/decoder/Utils.h index 7df6c8c0..5a8d3d1e 100644 --- a/flashlight/lib/text/decoder/Utils.h +++ b/flashlight/lib/text/decoder/Utils.h @@ -116,6 +116,8 @@ using EmittingModelUpdateFunc = std::function< int& // The current time step being decoded -- 0 --> T )>; +using DiversityFunction = std::function(const std::vector&, const int); + /* ===================== Candidate-related operations ===================== */ template