Skip to content

Commit

Permalink
Merge pull request #179 from minitech/masking-security
Browse files Browse the repository at this point in the history
Fix cryptographic vulnerabilities in id masking
  • Loading branch information
mattrltrent authored Mar 7, 2024
2 parents 729b8ed + 79d3a73 commit e9f0dfd
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 84 deletions.
3 changes: 2 additions & 1 deletion env-example
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ APPCHECK_TOKEN="kXfeSRgYTnoUztu6MO8FndqiRayoBaJqyDKQmoqvX3V9sZVlep/cm7cP!mgd-B9H
HKDF_SECRET="some-secret-string"

# a 16-byte key
MASK_SECRET="your_16_byte_key"
# generate with `python3 -c 'import secrets; print(secrets.token_urlsafe(16))'`, for example
MASK_SECRET="gG9-td9Nvs3tNZDTXEXKaQ"

# Redis connection string
REDIS_CONN="redis:6379"
Expand Down
82 changes: 41 additions & 41 deletions lib/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,77 +3,77 @@ package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"math"
"os"
"strconv"
)

var secretKey []byte
var block cipher.Block

// Strictness ensures that decode(x) = decode(y) only if x = y (I hope). This isn't targeted at a concrete problem, but is a better default.
var encoding *base64.Encoding = base64.RawURLEncoding.Strict()

func init() {
// load from .env
m := os.Getenv("MASK_SECRET")
if m == "" {
panic("MASK_SECRET env not found")
}
secretKey = []byte(m)
}

func Hash(input uint) string {
hash := sha256.Sum256([]byte(fmt.Sprint(input)))
return base64.RawURLEncoding.EncodeToString(hash[:])
}
secretKey, err := encoding.DecodeString(m)
if err != nil {
panic(fmt.Errorf("couldn't decode MASK_SECRET: %w", err))
}

func Mask(id uint) (string, error) {
block, err := aes.NewCipher(secretKey)
block, err = aes.NewCipher(secretKey)
if err != nil {
return "", err
panic(fmt.Errorf("couldn't use MASK_SECRET as AES key: %w", err))
}
}

func encrypt(id uint32) string {
buf := make([]byte, aes.BlockSize)
binary.LittleEndian.PutUint32(buf[:4], id)
block.Encrypt(buf, buf)
return encoding.EncodeToString(buf)
}

ciphertext := make([]byte, aes.BlockSize+len(fmt.Sprintf("%d", id)))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
func Hash(id uint) string {
if id > math.MaxUint32 {
panic("id out of range")
}

ctr := cipher.NewCTR(block, iv)
plaintext := []byte(fmt.Sprintf("%d", id))
ctr.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)
return encrypt(uint32(id))
}

func Mask(id uint) (string, error) {
if id > math.MaxUint32 {
return "", fmt.Errorf("id out of range")
}

return base64.URLEncoding.EncodeToString(ciphertext), nil
return encrypt(uint32(id)), nil
}

func Unmask(ciphertext string) (uint, error) {
block, err := aes.NewCipher(secretKey)
if err != nil {
return 0, err
if len(ciphertext) != encoding.EncodedLen(aes.BlockSize) {
return 0, fmt.Errorf("invalid ciphertext length")
}

decodedCiphertext, err := base64.URLEncoding.DecodeString(ciphertext)
buf, err := encoding.DecodeString(ciphertext)
if err != nil {
return 0, err
}

if len(decodedCiphertext) < aes.BlockSize {
return 0, fmt.Errorf("invalid ciphertext length")
}
block.Decrypt(buf, buf)

iv := decodedCiphertext[:aes.BlockSize]
if len(decodedCiphertext) <= aes.BlockSize {
return 0, fmt.Errorf("ciphertext too short")
// 256 - 32 = 224 bits for authenticated encryption. This check doesn't need to be timing-safe.
for _, b := range(buf[4:]) {
if b != 0 {
return 0, fmt.Errorf("invalid ciphertext")
}
}

ctr := cipher.NewCTR(block, iv)
plaintext := make([]byte, len(decodedCiphertext)-aes.BlockSize)
ctr.XORKeyStream(plaintext, decodedCiphertext[aes.BlockSize:])

decryptedID, err := strconv.Atoi(string(plaintext))
if err != nil {
return 0, err
}
return uint(decryptedID), nil
return uint(binary.LittleEndian.Uint32(buf[:4])), nil
}
46 changes: 4 additions & 42 deletions lib/encryption/encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,43 +10,19 @@ import (
//! Tests require `MASK_SECRET` env var to be set to pass

func TestUniqueHash(t *testing.T) {
id := uint(78)
hash := Hash(id)
assert.Equal(t, "NJxBIBti24URkmZcUEs1D_mMa0X7YqiiFh94tlNNjek", hash, "Hashes do not match")
}

func TestUniqueMasksMapToSameID(t *testing.T) {

// Test case: Masking the same ID twice should result in different encrypted values
id := uint(5)

maskedID1, err := Mask(id)
assert.NoError(t, err, "Masking error")

maskedID2, err := Mask(id)
assert.NoError(t, err, "Masking error")

assert.NotEqual(t, maskedID1, maskedID2, "Masked IDs should not be equal")

// Test case: Unmasking the encrypted IDs should yield the same original ID
decryptedID1, err := Unmask(maskedID1)
assert.NoError(t, err, "Unmasking error")

decryptedID2, err := Unmask(maskedID2)
assert.NoError(t, err, "Unmasking error")

assert.Equal(t, id, decryptedID1, "Original and decrypted IDs do not match")
assert.Equal(t, id, decryptedID2, "Original and decrypted IDs do not match")
assert.Equal(t, Hash(1), Hash(1), "Hash should be deterministic")
assert.NotEqual(t, Hash(1), Hash(2), "Hash should be unique")
}

func TestEncryptionAndDecryption(t *testing.T) {
tests := []struct {
id uint
}{
{0}, // sub-test case 1
{12345212121224583}, // sub-test case 2
{123452121}, // sub-test case 2
{987654}, // sub-test case 3
{42}, // sub-test case 4
{123}, // sub-test case 5
}

for _, test := range tests {
Expand All @@ -65,17 +41,3 @@ func TestEncryptionAndDecryption(t *testing.T) {
})
}
}

func TestEncryptionAndDecryptionSimple(t *testing.T) {
val, err := Mask(123)
if err != nil {
t.Error("Encryption error:", err)
}

decrypted, err := Unmask(val)
if err != nil {
t.Error("Decryption error:", err)
}

assert.Equal(t, uint(123), decrypted)
}

0 comments on commit e9f0dfd

Please sign in to comment.