diff options
Diffstat (limited to 'vendor/github.com/xtaci/kcp-go/v5/crypt.go')
-rw-r--r-- | vendor/github.com/xtaci/kcp-go/v5/crypt.go | 618 |
1 files changed, 618 insertions, 0 deletions
diff --git a/vendor/github.com/xtaci/kcp-go/v5/crypt.go b/vendor/github.com/xtaci/kcp-go/v5/crypt.go new file mode 100644 index 0000000..d882852 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/crypt.go @@ -0,0 +1,618 @@ +package kcp + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/sha1" + "unsafe" + + xor "github.com/templexxx/xorsimd" + "github.com/tjfoc/gmsm/sm4" + + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/crypto/salsa20" + "golang.org/x/crypto/tea" + "golang.org/x/crypto/twofish" + "golang.org/x/crypto/xtea" +) + +var ( + initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107} + saltxor = `sH3CIVoF#rWLtJo6` +) + +// BlockCrypt defines encryption/decryption methods for a given byte slice. +// Notes on implementing: the data to be encrypted contains a builtin +// nonce at the first 16 bytes +type BlockCrypt interface { + // Encrypt encrypts the whole block in src into dst. + // Dst and src may point at the same memory. + Encrypt(dst, src []byte) + + // Decrypt decrypts the whole block in src into dst. + // Dst and src may point at the same memory. + Decrypt(dst, src []byte) +} + +type salsa20BlockCrypt struct { + key [32]byte +} + +// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20 +func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(salsa20BlockCrypt) + copy(c.key[:], key) + return c, nil +} + +func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) { + salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) + copy(dst[:8], src[:8]) +} +func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) { + salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) + copy(dst[:8], src[:8]) +} + +type sm4BlockCrypt struct { + encbuf [sm4.BlockSize]byte // 64bit alignment enc/dec buffer + decbuf [2 * sm4.BlockSize]byte + block cipher.Block +} + +// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4 +func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(sm4BlockCrypt) + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type twofishBlockCrypt struct { + encbuf [twofish.BlockSize]byte + decbuf [2 * twofish.BlockSize]byte + block cipher.Block +} + +// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish +func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(twofishBlockCrypt) + block, err := twofish.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type tripleDESBlockCrypt struct { + encbuf [des.BlockSize]byte + decbuf [2 * des.BlockSize]byte + block cipher.Block +} + +// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES +func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(tripleDESBlockCrypt) + block, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type cast5BlockCrypt struct { + encbuf [cast5.BlockSize]byte + decbuf [2 * cast5.BlockSize]byte + block cipher.Block +} + +// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128 +func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(cast5BlockCrypt) + block, err := cast5.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type blowfishBlockCrypt struct { + encbuf [blowfish.BlockSize]byte + decbuf [2 * blowfish.BlockSize]byte + block cipher.Block +} + +// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher) +func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(blowfishBlockCrypt) + block, err := blowfish.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type aesBlockCrypt struct { + encbuf [aes.BlockSize]byte + decbuf [2 * aes.BlockSize]byte + block cipher.Block +} + +// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard +func NewAESBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(aesBlockCrypt) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type teaBlockCrypt struct { + encbuf [tea.BlockSize]byte + decbuf [2 * tea.BlockSize]byte + block cipher.Block +} + +// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm +func NewTEABlockCrypt(key []byte) (BlockCrypt, error) { + c := new(teaBlockCrypt) + block, err := tea.NewCipherWithRounds(key, 16) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type xteaBlockCrypt struct { + encbuf [xtea.BlockSize]byte + decbuf [2 * xtea.BlockSize]byte + block cipher.Block +} + +// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA +func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) { + c := new(xteaBlockCrypt) + block, err := xtea.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type simpleXORBlockCrypt struct { + xortbl []byte +} + +// NewSimpleXORBlockCrypt simple xor with key expanding +func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(simpleXORBlockCrypt) + c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New) + return c, nil +} + +func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } +func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } + +type noneBlockCrypt struct{} + +// NewNoneBlockCrypt does nothing but copying +func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) { + return new(noneBlockCrypt), nil +} + +func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) } +func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) } + +// packet encryption with local CFB mode +func encrypt(block cipher.Block, dst, src, buf []byte) { + switch block.BlockSize() { + case 8: + encrypt8(block, dst, src, buf) + case 16: + encrypt16(block, dst, src, buf) + default: + panic("unsupported cipher block size") + } +} + +// optimized encryption for the ciphers which works in 8-bytes +func encrypt8(block cipher.Block, dst, src, buf []byte) { + tbl := buf[:8] + block.Encrypt(tbl, initialVector) + n := len(src) / 8 + base := 0 + repeat := n / 8 + left := n % 8 + ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0])) + + for i := 0; i < repeat; i++ { + s := src[base:][0:64] + d := dst[base:][0:64] + // 1 + *(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl + block.Encrypt(tbl, d[0:8]) + // 2 + *(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_tbl + block.Encrypt(tbl, d[8:16]) + // 3 + *(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl + block.Encrypt(tbl, d[16:24]) + // 4 + *(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_tbl + block.Encrypt(tbl, d[24:32]) + // 5 + *(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl + block.Encrypt(tbl, d[32:40]) + // 6 + *(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_tbl + block.Encrypt(tbl, d[40:48]) + // 7 + *(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl + block.Encrypt(tbl, d[48:56]) + // 8 + *(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_tbl + block.Encrypt(tbl, d[56:64]) + base += 64 + } + + switch left { + case 7: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 6: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 5: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 4: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 3: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 2: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 1: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +// optimized encryption for the ciphers which works in 16-bytes +func encrypt16(block cipher.Block, dst, src, buf []byte) { + tbl := buf[:16] + block.Encrypt(tbl, initialVector) + n := len(src) / 16 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:128] + d := dst[base:][0:128] + // 1 + xor.Bytes16Align(d[0:16], s[0:16], tbl) + block.Encrypt(tbl, d[0:16]) + // 2 + xor.Bytes16Align(d[16:32], s[16:32], tbl) + block.Encrypt(tbl, d[16:32]) + // 3 + xor.Bytes16Align(d[32:48], s[32:48], tbl) + block.Encrypt(tbl, d[32:48]) + // 4 + xor.Bytes16Align(d[48:64], s[48:64], tbl) + block.Encrypt(tbl, d[48:64]) + // 5 + xor.Bytes16Align(d[64:80], s[64:80], tbl) + block.Encrypt(tbl, d[64:80]) + // 6 + xor.Bytes16Align(d[80:96], s[80:96], tbl) + block.Encrypt(tbl, d[80:96]) + // 7 + xor.Bytes16Align(d[96:112], s[96:112], tbl) + block.Encrypt(tbl, d[96:112]) + // 8 + xor.Bytes16Align(d[112:128], s[112:128], tbl) + block.Encrypt(tbl, d[112:128]) + base += 128 + } + + switch left { + case 7: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 6: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 5: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 4: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 3: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 2: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 1: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +// decryption +func decrypt(block cipher.Block, dst, src, buf []byte) { + switch block.BlockSize() { + case 8: + decrypt8(block, dst, src, buf) + case 16: + decrypt16(block, dst, src, buf) + default: + panic("unsupported cipher block size") + } +} + +// decrypt 8 bytes block, all byte slices are supposed to be 64bit aligned +func decrypt8(block cipher.Block, dst, src, buf []byte) { + tbl := buf[0:8] + next := buf[8:16] + block.Encrypt(tbl, initialVector) + n := len(src) / 8 + base := 0 + repeat := n / 8 + left := n % 8 + ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0])) + ptr_next := (*uint64)(unsafe.Pointer(&next[0])) + + for i := 0; i < repeat; i++ { + s := src[base:][0:64] + d := dst[base:][0:64] + // 1 + block.Encrypt(next, s[0:8]) + *(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl + // 2 + block.Encrypt(tbl, s[8:16]) + *(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_next + // 3 + block.Encrypt(next, s[16:24]) + *(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl + // 4 + block.Encrypt(tbl, s[24:32]) + *(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_next + // 5 + block.Encrypt(next, s[32:40]) + *(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl + // 6 + block.Encrypt(tbl, s[40:48]) + *(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_next + // 7 + block.Encrypt(next, s[48:56]) + *(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl + // 8 + block.Encrypt(tbl, s[56:64]) + *(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_next + base += 64 + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 6: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 5: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 4: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 3: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 2: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 1: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +func decrypt16(block cipher.Block, dst, src, buf []byte) { + tbl := buf[0:16] + next := buf[16:32] + block.Encrypt(tbl, initialVector) + n := len(src) / 16 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:128] + d := dst[base:][0:128] + // 1 + block.Encrypt(next, s[0:16]) + xor.Bytes16Align(d[0:16], s[0:16], tbl) + // 2 + block.Encrypt(tbl, s[16:32]) + xor.Bytes16Align(d[16:32], s[16:32], next) + // 3 + block.Encrypt(next, s[32:48]) + xor.Bytes16Align(d[32:48], s[32:48], tbl) + // 4 + block.Encrypt(tbl, s[48:64]) + xor.Bytes16Align(d[48:64], s[48:64], next) + // 5 + block.Encrypt(next, s[64:80]) + xor.Bytes16Align(d[64:80], s[64:80], tbl) + // 6 + block.Encrypt(tbl, s[80:96]) + xor.Bytes16Align(d[80:96], s[80:96], next) + // 7 + block.Encrypt(next, s[96:112]) + xor.Bytes16Align(d[96:112], s[96:112], tbl) + // 8 + block.Encrypt(tbl, s[112:128]) + xor.Bytes16Align(d[112:128], s[112:128], next) + base += 128 + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 6: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 5: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 4: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 3: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 2: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 1: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +// per bytes xors +func xorBytes(dst, a, b []byte) int { + n := len(a) + if len(b) < n { + n = len(b) + } + if n == 0 { + return 0 + } + + for i := 0; i < n; i++ { + dst[i] = a[i] ^ b[i] + } + return n +} |