diff --git a/dare.go b/dare.go index 5aaf4c0..ed11259 100644 --- a/dare.go +++ b/dare.go @@ -210,6 +210,7 @@ func (ae *authEncV20) seal(dst, src []byte, finalize bool) { type authDecV20 struct { authDec refHeader headerV20 + endSeqNum *uint32 finalized bool } @@ -227,6 +228,7 @@ func newAuthDecV20(cfg *Config) (authDecV20, error) { SeqNum: cfg.SequenceNumber, Ciphers: ciphers, }, + endSeqNum: cfg.EndSequenceNumber, }, nil } @@ -259,6 +261,8 @@ func (ad *authDecV20) Open(dst, src []byte) error { if header.IsFinal() { ad.finalized = true refNonce[0] |= 0x80 // set final flag + } else if ad.endSeqNum != nil && ad.SeqNum == *ad.endSeqNum { + ad.finalized = true } if subtle.ConstantTimeCompare(header.Nonce(), refNonce) != 1 { return errNonceMismatch diff --git a/sio.go b/sio.go index 32985d3..d01caa9 100644 --- a/sio.go +++ b/sio.go @@ -118,6 +118,11 @@ type Config struct { // stream. SequenceNumber uint32 + // The last expected sequence number. It should only + // be set manually when decrypting a range within a + // stream. + EndSequenceNumber *uint32 + // The RNG used to generate random values. If not set // the default value (crypto/rand.Reader) is used. Rand io.Reader diff --git a/sio_test.go b/sio_test.go index dc25326..a242025 100644 --- a/sio_test.go +++ b/sio_test.go @@ -168,6 +168,117 @@ func TestDecryptBuffer(t *testing.T) { } }) } + + t.Run("EndSequenceNumber", func(t *testing.T) { + datasize := maxPayloadSize * 3 + data := make([]byte, datasize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatalf("Failed to generate random data: %v", err) + } + + output := bytes.NewBuffer(nil) + + if _, err := Encrypt(output, bytes.NewReader(data), config); err != nil { + t.Errorf("Encryption failed: %v", err) + } + + t.Run("First package only", func(t *testing.T) { + var end uint32 = 0 + configEnd := Config{ + Key: config.Key, + EndSequenceNumber: &end, + } + + decrypted, err := DecryptBuffer(make([]byte, 0, maxPayloadSize), output.Bytes()[:maxPackageSize], configEnd) + if len(decrypted) != maxPayloadSize || err != nil { + t.Errorf("Decryption failed: number of bytes: %d vs. %d - %v", len(decrypted), maxPayloadSize, err) + return + } + if !bytes.Equal(data[:maxPayloadSize], decrypted) { + t.Errorf("Failed to encrypt and decrypt data") + } + }) + + t.Run("Second package only", func(t *testing.T) { + var end uint32 = 1 + configEnd := Config{ + Key: config.Key, + SequenceNumber: 1, + EndSequenceNumber: &end, + } + + decrypted, err := DecryptBuffer(make([]byte, 0, maxPayloadSize), output.Bytes()[maxPackageSize:maxPackageSize*2], configEnd) + if len(decrypted) != maxPayloadSize || err != nil { + t.Errorf("Decryption failed: number of bytes: %d vs. %d - %v", len(decrypted), maxPayloadSize*2, err) + return + } + if !bytes.Equal(data[maxPayloadSize:maxPayloadSize*2], decrypted) { + t.Errorf("Failed to encrypt and decrypt data") + } + }) + + t.Run("Last package only", func(t *testing.T) { + configEnd := Config{ + Key: config.Key, + SequenceNumber: 2, + } + + decrypted, err := DecryptBuffer(make([]byte, 0, maxPayloadSize), output.Bytes()[maxPackageSize*2:], configEnd) + if len(decrypted) != maxPayloadSize || err != nil { + t.Errorf("Decryption failed: number of bytes: %d vs. %d - %v", len(decrypted), maxPayloadSize, err) + return + } + if !bytes.Equal(data[maxPayloadSize*2:], decrypted) { + t.Errorf("Failed to encrypt and decrypt data") + } + }) + + t.Run("First & second package", func(t *testing.T) { + var end uint32 = 1 + configEnd := Config{ + Key: config.Key, + EndSequenceNumber: &end, + } + + decrypted, err := DecryptBuffer(make([]byte, 0, maxPayloadSize*2), output.Bytes()[:maxPackageSize*2], configEnd) + if len(decrypted) != maxPayloadSize*2 || err != nil { + t.Errorf("Decryption failed: number of bytes: %d vs. %d - %v", len(decrypted), maxPayloadSize*2, err) + return + } + if !bytes.Equal(data[:maxPayloadSize*2], decrypted) { + t.Errorf("Failed to encrypt and decrypt data") + } + }) + + t.Run("Second & last package", func(t *testing.T) { + var end uint32 = 2 + configEnd := Config{ + Key: config.Key, + SequenceNumber: 1, + EndSequenceNumber: &end, + } + + decrypted, err := DecryptBuffer(make([]byte, 0, maxPayloadSize*2), output.Bytes()[maxPackageSize:], configEnd) + if len(decrypted) != maxPayloadSize*2 || err != nil { + t.Errorf("Decryption failed: number of bytes: %d vs. %d - %v", len(decrypted), maxPayloadSize*2, err) + return + } + if !bytes.Equal(data[maxPayloadSize:], decrypted) { + t.Errorf("Failed to encrypt and decrypt data") + } + }) + + t.Run("End at second but no EndSequenceNumber", func(t *testing.T) { + configEnd := Config{ + Key: config.Key, + } + + _, err := DecryptBuffer(make([]byte, 0, maxPayloadSize*2), output.Bytes()[:maxPackageSize*2], configEnd) + if err != errUnexpectedEOF { + t.Errorf("No error but error expected") + } + }) + }) } func TestReader(t *testing.T) {