change return from newDecryptionReader to io.Reader

This commit is contained in:
alexmullins 2015-10-29 17:07:25 -05:00
parent abdb448165
commit 678a885320
1 changed files with 6 additions and 15 deletions

View File

@ -17,7 +17,6 @@ import (
"hash" "hash"
"hash/crc32" "hash/crc32"
"io" "io"
"io/ioutil"
"os" "os"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
@ -190,11 +189,11 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
return return
} }
rc = dcomp(r) rc = dcomp(r)
// TODO: if AE-2, skip CRC and possible dataDescriptor
var desr io.Reader var desr io.Reader
if f.hasDataDescriptor() { if f.hasDataDescriptor() {
desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen) desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
} }
// TODO: if AE-2, skip CRC
rc = &checksumReader{ rc = &checksumReader{
rc: rc, rc: rc,
hash: crc32.NewIEEE(), hash: crc32.NewIEEE(),
@ -204,18 +203,17 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
return return
} }
func newDecryptionReader(r io.Reader, f *File) (io.ReadCloser, error) { func newDecryptionReader(r io.Reader, f *File) (io.Reader, error) {
keyLen := aesKeyLen(f.aesStrength) keyLen := aesKeyLen(f.aesStrength)
saltLen := keyLen / 2 // salt is half of key len saltLen := keyLen / 2 // salt is half of key len
if saltLen == 0 { if saltLen == 0 {
return nil, ErrDecryption return nil, ErrDecryption
} }
// Is there a better method than reading in the entire contents?
content := make([]byte, f.CompressedSize64) content := make([]byte, f.CompressedSize64)
if _, err := io.ReadFull(r, content); err != nil { if _, err := io.ReadFull(r, content); err != nil {
return nil, ErrDecryption return nil, ErrDecryption
} }
// grab the salt, pwvv, data, and authcode // grab the salt, pwvv, data, and authcode
salt := content[:saltLen] salt := content[:saltLen]
pwvv := content[saltLen : saltLen+2] pwvv := content[saltLen : saltLen+2]
@ -223,35 +221,31 @@ func newDecryptionReader(r io.Reader, f *File) (io.ReadCloser, error) {
size := f.UncompressedSize64 size := f.UncompressedSize64
data := content[:size] data := content[:size]
authcode := content[size:] authcode := content[size:]
// generate keys // generate keys
decKey, authKey, pwv := generateKeys(f.password, salt, keyLen) decKey, authKey, pwv := generateKeys(f.password, salt, keyLen)
// check password verifier (pwv) // check password verifier (pwv)
if !bytes.Equal(pwv, pwvv) { if !bytes.Equal(pwv, pwvv) {
return nil, ErrDecryption return nil, ErrDecryption
} }
// check authentication // check authentication
if !checkAuthentication(data, authcode, authKey) { if !checkAuthentication(data, authcode, authKey) {
return nil, ErrDecryption return nil, ErrDecryption
} }
// set the IV // set the IV
// see: https://forum.golangbridge.org/t/iv-counter-help-for-aes-ctr/1369
var iv [aes.BlockSize]byte var iv [aes.BlockSize]byte
iv[0] = 1 iv[0] = 1
return decryptStream(data, decKey, iv[:]), nil return decryptStream(data, decKey, iv[:]), nil
} }
func decryptStream(ciphertext, key, iv []byte) io.ReadCloser { func decryptStream(ciphertext, key, iv []byte) io.Reader {
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil return nil
} }
stream := cipher.NewCTR(block, iv) stream := cipher.NewCTR(block, iv)
reader := cipher.StreamReader{S: stream, R: bytes.NewReader(ciphertext)} reader := cipher.StreamReader{S: stream, R: bytes.NewReader(ciphertext)}
return ioutil.NopCloser(reader) return reader
} }
func checkAuthentication(message, authcode, key []byte) bool { func checkAuthentication(message, authcode, key []byte) bool {
@ -399,13 +393,10 @@ func readDirectoryHeader(f *File, r io.Reader) error {
case winzipAesExtraId: case winzipAesExtraId:
// grab the AE version // grab the AE version
f.ae = eb.uint16() f.ae = eb.uint16()
// skip vendor ID // skip vendor ID
_ = eb.uint16() _ = eb.uint16()
// AES strength // AES strength
f.aesStrength = eb.uint8() f.aesStrength = eb.uint8()
// set the actual compression method. // set the actual compression method.
f.Method = eb.uint16() f.Method = eb.uint16()
} }