encryption: Introduce Encryptable (#40282)

This commit is contained in:
Eric Fritz 2022-08-15 09:15:51 -05:00 committed by GitHub
parent cacd49ea0a
commit a8cb7dc4c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 456 additions and 0 deletions

View File

@ -0,0 +1,130 @@
package encryption
import (
"context"
"sync"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// Encryptable wraps a value and an encryption key and handles lazily encrypting and
// decrypting that value. This struct should be used in all places where a value is
// encrypted at-rest to maintain a consistent handling of data with security concerns.
//
// This struct should always be passed by reference.
type Encryptable struct {
mutex sync.Mutex
decrypted *decryptedValue
encrypted *EncryptedValue
key Key
}
type decryptedValue struct {
value string
err error
}
// EncryptedValue wraps an encrypted value and serialized metadata about that key that
// encrypted it.
type EncryptedValue struct {
Cipher string
KeyID string
}
// NewUnencrypted creates a new encryptable from a plaintext value.
func NewUnencrypted(value string) *Encryptable {
return NewUnencryptedWithKey(value, nil)
}
func NewUnencryptedWithKey(value string, key Key) *Encryptable {
return &Encryptable{
decrypted: &decryptedValue{value, nil},
key: key,
}
}
// NewEncrypted creates a new encryptable from an encrypted value and a relevant encryption key.
func NewEncrypted(cipher, keyID string, key Key) *Encryptable {
return &Encryptable{
encrypted: &EncryptedValue{cipher, keyID},
key: key,
}
}
// Decrypt returns the underlying plaintext value. This method may make an external API call to
// decrypt the underlying encrypted value, but will memoize the result so that subsequent calls
// will be cheap.
func (e *Encryptable) Decrypt(ctx context.Context) (string, error) {
e.mutex.Lock()
defer e.mutex.Unlock()
return e.decryptLocked(ctx)
}
func (e *Encryptable) decryptLocked(ctx context.Context) (string, error) {
if e.decrypted != nil {
return e.decrypted.value, e.decrypted.err
}
if e.encrypted == nil {
return "", errors.New("no encrypted value")
}
value, err := MaybeDecrypt(ctx, e.key, e.encrypted.Cipher, e.encrypted.KeyID)
e.decrypted = &decryptedValue{value, err}
return value, err
}
// Encrypt returns the underlying encrypted value. This method may make an external API call to
// encrypt the underlying plaintext value, but will memoize the result so that subsequent calls
// will be cheap.
func (e *Encryptable) Encrypt(ctx context.Context, key Key) (string, string, error) {
if err := e.SetKey(ctx, key); err != nil {
return "", "", err
}
e.mutex.Lock()
defer e.mutex.Unlock()
if e.encrypted != nil {
return e.encrypted.Cipher, e.encrypted.KeyID, nil
}
if e.decrypted == nil {
return "", "", errors.New("nothing to encrypt")
}
cipher, keyID, err := MaybeEncrypt(ctx, e.key, e.decrypted.value)
if err != nil {
return "", "", err
}
e.encrypted = &EncryptedValue{cipher, keyID}
return cipher, keyID, err
}
// Set updates the underlying plaintext value.
func (e *Encryptable) Set(value string) {
e.mutex.Lock()
defer e.mutex.Unlock()
e.decrypted = &decryptedValue{value, nil}
e.encrypted = nil
}
// SetKey updates the encryption key used with the encrypted value. This method may trigger an
// external API call to decrypt the current value.
func (e *Encryptable) SetKey(ctx context.Context, key Key) error {
e.mutex.Lock()
defer e.mutex.Unlock()
if e.key == key {
return nil
}
if _, err := e.decryptLocked(ctx); err != nil {
return err
}
e.key = key
e.encrypted = nil
return nil
}

View File

@ -0,0 +1,78 @@
package encryption
import (
"context"
"encoding/json"
"testing"
)
func TestEncryptable(t *testing.T) {
ctx := context.Background()
base64Key := base64Key{}
base64Key2 := base64PlusJunkKey{}
keyID, _ := json.Marshal(base64KeyVersion)
for _, encryptable := range []*Encryptable{
NewUnencrypted("foobar"),
NewEncrypted("Zm9vYmFy", string(keyID), base64Key),
} {
// Test Decrypt
decrypted, err := encryptable.Decrypt(ctx)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "foobar"; decrypted != want {
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
}
// Test Encrypt
encrypted, keyID, err := encryptable.Encrypt(ctx, base64Key)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "Zm9vYmFy"; encrypted != want {
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
}
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
}
// Test SetKey
if err := encryptable.SetKey(ctx, base64Key2); err != nil {
t.Fatalf("unexpected error setting key: %s", err.Error())
}
// Re-test Decrypt
decrypted, err = encryptable.Decrypt(ctx)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "foobar"; decrypted != want {
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
}
// Test Set
encryptable.Set("barbaz")
// Re-test Decrypt
decrypted, err = encryptable.Decrypt(ctx)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "barbaz"; decrypted != want {
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
}
// Re-test Encrypt
encrypted, keyID, err = encryptable.Encrypt(ctx, base64Key)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "YmFyYmF6"; encrypted != want {
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
}
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
}
}
}

View File

@ -0,0 +1,81 @@
package encryption
import (
"context"
"encoding/json"
)
// JSONEncryptable wraps a value of type T and an encryption key and handles lazily encoding/encrypting
// and decrypting/decoding that value. This struct should be used in all places where a JSON-serialized
// value is encrypted at-rest to maintain a consistent handling of data with security concerns.
//
// This struct should always be passed by reference.
type JSONEncryptable[T any] struct {
*Encryptable
}
// NewUnencryptedJSON creates a new JSON encryptable from the given value.
func NewUnencryptedJSON[T any](value T) (*JSONEncryptable[T], error) {
return NewUnencryptedJSONWithKey(value, nil)
}
func NewUnencryptedJSONWithKey[T any](value T, key Key) (*JSONEncryptable[T], error) {
serialized, err := json.Marshal(value)
if err != nil {
return nil, err
}
return &JSONEncryptable[T]{Encryptable: NewUnencryptedWithKey(string(serialized), key)}, nil
}
// NewEncryptedJSON creates a new JSON encryptable an encrypted value and a relevant encryption key.
func NewEncryptedJSON[T any](cipher, keyID string, key Key) *JSONEncryptable[T] {
return &JSONEncryptable[T]{Encryptable: NewEncrypted(cipher, keyID, key)}
}
// Decrypt decrypts and returns the underlying value as a T. This method may make an external API call
// to decrypt the underlying encrypted value, but will memoize the result so that subsequent calls will
// be cheap.
func (e *JSONEncryptable[T]) Decrypt(ctx context.Context) (value T, _ error) {
serialized, err := e.Encryptable.Decrypt(ctx)
if err != nil {
return value, err
}
if err := json.Unmarshal([]byte(serialized), &value); err != nil {
return value, err
}
return value, nil
}
// DecryptInto decrypts the underlying value and updates the given value. This method may make an external
// API call to decrypt the underlying encrypted value, but will memoize the result so that subsequent calls
// will be cheap.
func (e *JSONEncryptable[T]) DecryptInto(ctx context.Context, value T) error {
serialized, err := e.Encryptable.Decrypt(ctx)
if err != nil {
return err
}
if err := json.Unmarshal([]byte(serialized), &value); err != nil {
return err
}
return nil
}
// Set updates the underlying value.
func (e *JSONEncryptable[T]) Set(value T) error {
serialized, err := json.Marshal(value)
if err != nil {
return err
}
e.mutex.Lock()
defer e.mutex.Unlock()
e.decrypted = &decryptedValue{string(serialized), nil}
e.encrypted = nil
return nil
}

View File

@ -0,0 +1,85 @@
package encryption
import (
"context"
"encoding/json"
"testing"
)
func TestJSONEncryptable(t *testing.T) {
ctx := context.Background()
base64Key := base64Key{}
keyID, _ := json.Marshal(base64KeyVersion)
keyType := func(t *testing.T, keyID string) string {
var key KeyVersion
if err := json.Unmarshal([]byte(keyID), &key); err != nil {
t.Fatalf("unexpected key identifier - not json: %s", err.Error())
}
return key.Type
}
type T struct {
Foo int `json:"foo"`
Bar int `json:"bar"`
Baz int `json:"baz"`
}
v1 := T{1, 2, 3}
v2 := T{7, 8, 9}
unencrypted, err := NewUnencryptedJSON(v1)
if err != nil {
t.Fatalf("unexpected error creating encryptable: %s", err.Error())
}
for _, encryptable := range []*JSONEncryptable[T]{
unencrypted,
NewEncryptedJSON[T]("eyJmb28iOjEsImJhciI6MiwiYmF6IjozfQ==", string(keyID), base64Key),
} {
// Test Decrypt
decrypted, err := encryptable.Decrypt(ctx)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := v1; decrypted != want {
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
}
// Test Encrypt
encrypted, keyID, err := encryptable.Encrypt(ctx, base64Key)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "eyJmb28iOjEsImJhciI6MiwiYmF6IjozfQ=="; encrypted != want {
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
}
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
}
// Test Set
encryptable.Set(v2)
// Re-test Decrypt
decrypted, err = encryptable.Decrypt(ctx)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := v2; decrypted != want {
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
}
// Re-test Encrypt
encrypted, keyID, err = encryptable.Encrypt(ctx, base64Key)
if err != nil {
t.Fatalf("unexpected error encrypting: %s", err.Error())
}
if want := "eyJmb28iOjcsImJhciI6OCwiYmF6Ijo5fQ=="; encrypted != want {
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
}
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
}
}
}

View File

@ -0,0 +1,30 @@
package testing
import (
"context"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/sourcegraph/internal/encryption"
)
var CompareEncryptable = cmp.Comparer(func(a, b *encryption.Encryptable) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
aValue, err := a.Decrypt(context.Background())
if err != nil {
return false
}
bValue, err := b.Decrypt(context.Background())
if err != nil {
return false
}
return cmp.Diff(aValue, bValue) == ""
})

View File

@ -0,0 +1,52 @@
package encryption
import (
"context"
"encoding/base64"
"encoding/json"
"testing"
)
type base64Key struct{}
var base64KeyVersion = KeyVersion{Type: "base64"}
func (k base64Key) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
return []byte(base64.StdEncoding.EncodeToString(plaintext)), nil
}
func (k base64Key) Decrypt(ctx context.Context, ciphertext []byte) (*Secret, error) {
decoded, err := base64.StdEncoding.DecodeString(string(ciphertext))
s := NewSecret(string(decoded))
return &s, err
}
func (k base64Key) Version(ctx context.Context) (KeyVersion, error) {
return base64KeyVersion, nil
}
type base64PlusJunkKey struct{ base64Key }
var base64PlusJunkKeyVersion = KeyVersion{Type: "base64-plus-junk"}
func (k base64PlusJunkKey) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
encrypted, err := k.base64Key.Encrypt(ctx, plaintext)
return append([]byte(`!@#$`), encrypted...), err
}
func (k base64PlusJunkKey) Decrypt(ctx context.Context, ciphertext []byte) (*Secret, error) {
return k.base64Key.Decrypt(ctx, ciphertext[4:])
}
func (k base64PlusJunkKey) Version(ctx context.Context) (KeyVersion, error) {
return base64PlusJunkKeyVersion, nil
}
func keyType(t *testing.T, keyID string) string {
var key KeyVersion
if err := json.Unmarshal([]byte(keyID), &key); err != nil {
t.Fatalf("unexpected key identifier - not json: %s", err.Error())
}
return key.Type
}