From a8cb7dc4c518d24630d977e0ffc6e651ab01bba3 Mon Sep 17 00:00:00 2001 From: Eric Fritz Date: Mon, 15 Aug 2022 09:15:51 -0500 Subject: [PATCH] encryption: Introduce `Encryptable` (#40282) --- internal/encryption/encryptable.go | 130 +++++++++++++++++++ internal/encryption/encryptable_test.go | 78 +++++++++++ internal/encryption/json_encryptable.go | 81 ++++++++++++ internal/encryption/json_encryptable_test.go | 85 ++++++++++++ internal/encryption/testing/compare.go | 30 +++++ internal/encryption/utils_test.go | 52 ++++++++ 6 files changed, 456 insertions(+) create mode 100644 internal/encryption/encryptable.go create mode 100644 internal/encryption/encryptable_test.go create mode 100644 internal/encryption/json_encryptable.go create mode 100644 internal/encryption/json_encryptable_test.go create mode 100644 internal/encryption/testing/compare.go create mode 100644 internal/encryption/utils_test.go diff --git a/internal/encryption/encryptable.go b/internal/encryption/encryptable.go new file mode 100644 index 00000000000..19d6dc5a24d --- /dev/null +++ b/internal/encryption/encryptable.go @@ -0,0 +1,130 @@ +package encryption + +import ( + "context" + "sync" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +// Encryptable wraps a value and an encryption key and handles lazily encrypting and +// decrypting that value. This struct should be used in all places where a value is +// encrypted at-rest to maintain a consistent handling of data with security concerns. +// +// This struct should always be passed by reference. +type Encryptable struct { + mutex sync.Mutex + decrypted *decryptedValue + encrypted *EncryptedValue + key Key +} + +type decryptedValue struct { + value string + err error +} + +// EncryptedValue wraps an encrypted value and serialized metadata about that key that +// encrypted it. +type EncryptedValue struct { + Cipher string + KeyID string +} + +// NewUnencrypted creates a new encryptable from a plaintext value. +func NewUnencrypted(value string) *Encryptable { + return NewUnencryptedWithKey(value, nil) +} + +func NewUnencryptedWithKey(value string, key Key) *Encryptable { + return &Encryptable{ + decrypted: &decryptedValue{value, nil}, + key: key, + } +} + +// NewEncrypted creates a new encryptable from an encrypted value and a relevant encryption key. +func NewEncrypted(cipher, keyID string, key Key) *Encryptable { + return &Encryptable{ + encrypted: &EncryptedValue{cipher, keyID}, + key: key, + } +} + +// Decrypt returns the underlying plaintext value. This method may make an external API call to +// decrypt the underlying encrypted value, but will memoize the result so that subsequent calls +// will be cheap. +func (e *Encryptable) Decrypt(ctx context.Context) (string, error) { + e.mutex.Lock() + defer e.mutex.Unlock() + + return e.decryptLocked(ctx) +} + +func (e *Encryptable) decryptLocked(ctx context.Context) (string, error) { + if e.decrypted != nil { + return e.decrypted.value, e.decrypted.err + } + if e.encrypted == nil { + return "", errors.New("no encrypted value") + } + + value, err := MaybeDecrypt(ctx, e.key, e.encrypted.Cipher, e.encrypted.KeyID) + e.decrypted = &decryptedValue{value, err} + return value, err +} + +// Encrypt returns the underlying encrypted value. This method may make an external API call to +// encrypt the underlying plaintext value, but will memoize the result so that subsequent calls +// will be cheap. +func (e *Encryptable) Encrypt(ctx context.Context, key Key) (string, string, error) { + if err := e.SetKey(ctx, key); err != nil { + return "", "", err + } + + e.mutex.Lock() + defer e.mutex.Unlock() + + if e.encrypted != nil { + return e.encrypted.Cipher, e.encrypted.KeyID, nil + } + if e.decrypted == nil { + return "", "", errors.New("nothing to encrypt") + } + + cipher, keyID, err := MaybeEncrypt(ctx, e.key, e.decrypted.value) + if err != nil { + return "", "", err + } + + e.encrypted = &EncryptedValue{cipher, keyID} + return cipher, keyID, err +} + +// Set updates the underlying plaintext value. +func (e *Encryptable) Set(value string) { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.decrypted = &decryptedValue{value, nil} + e.encrypted = nil +} + +// SetKey updates the encryption key used with the encrypted value. This method may trigger an +// external API call to decrypt the current value. +func (e *Encryptable) SetKey(ctx context.Context, key Key) error { + e.mutex.Lock() + defer e.mutex.Unlock() + + if e.key == key { + return nil + } + + if _, err := e.decryptLocked(ctx); err != nil { + return err + } + + e.key = key + e.encrypted = nil + return nil +} diff --git a/internal/encryption/encryptable_test.go b/internal/encryption/encryptable_test.go new file mode 100644 index 00000000000..218a86674db --- /dev/null +++ b/internal/encryption/encryptable_test.go @@ -0,0 +1,78 @@ +package encryption + +import ( + "context" + "encoding/json" + "testing" +) + +func TestEncryptable(t *testing.T) { + ctx := context.Background() + base64Key := base64Key{} + base64Key2 := base64PlusJunkKey{} + keyID, _ := json.Marshal(base64KeyVersion) + + for _, encryptable := range []*Encryptable{ + NewUnencrypted("foobar"), + NewEncrypted("Zm9vYmFy", string(keyID), base64Key), + } { + // Test Decrypt + decrypted, err := encryptable.Decrypt(ctx) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "foobar"; decrypted != want { + t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted) + } + + // Test Encrypt + encrypted, keyID, err := encryptable.Encrypt(ctx, base64Key) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "Zm9vYmFy"; encrypted != want { + t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted) + } + if want := base64KeyVersion.Type; keyType(t, keyID) != want { + t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID)) + } + + // Test SetKey + if err := encryptable.SetKey(ctx, base64Key2); err != nil { + t.Fatalf("unexpected error setting key: %s", err.Error()) + } + + // Re-test Decrypt + decrypted, err = encryptable.Decrypt(ctx) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "foobar"; decrypted != want { + t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted) + } + + // Test Set + encryptable.Set("barbaz") + + // Re-test Decrypt + decrypted, err = encryptable.Decrypt(ctx) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "barbaz"; decrypted != want { + t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted) + } + + // Re-test Encrypt + encrypted, keyID, err = encryptable.Encrypt(ctx, base64Key) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "YmFyYmF6"; encrypted != want { + t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted) + } + if want := base64KeyVersion.Type; keyType(t, keyID) != want { + t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID)) + } + } +} diff --git a/internal/encryption/json_encryptable.go b/internal/encryption/json_encryptable.go new file mode 100644 index 00000000000..45ddcf7ace7 --- /dev/null +++ b/internal/encryption/json_encryptable.go @@ -0,0 +1,81 @@ +package encryption + +import ( + "context" + "encoding/json" +) + +// JSONEncryptable wraps a value of type T and an encryption key and handles lazily encoding/encrypting +// and decrypting/decoding that value. This struct should be used in all places where a JSON-serialized +// value is encrypted at-rest to maintain a consistent handling of data with security concerns. +// +// This struct should always be passed by reference. +type JSONEncryptable[T any] struct { + *Encryptable +} + +// NewUnencryptedJSON creates a new JSON encryptable from the given value. +func NewUnencryptedJSON[T any](value T) (*JSONEncryptable[T], error) { + return NewUnencryptedJSONWithKey(value, nil) +} + +func NewUnencryptedJSONWithKey[T any](value T, key Key) (*JSONEncryptable[T], error) { + serialized, err := json.Marshal(value) + if err != nil { + return nil, err + } + + return &JSONEncryptable[T]{Encryptable: NewUnencryptedWithKey(string(serialized), key)}, nil +} + +// NewEncryptedJSON creates a new JSON encryptable an encrypted value and a relevant encryption key. +func NewEncryptedJSON[T any](cipher, keyID string, key Key) *JSONEncryptable[T] { + return &JSONEncryptable[T]{Encryptable: NewEncrypted(cipher, keyID, key)} +} + +// Decrypt decrypts and returns the underlying value as a T. This method may make an external API call +// to decrypt the underlying encrypted value, but will memoize the result so that subsequent calls will +// be cheap. +func (e *JSONEncryptable[T]) Decrypt(ctx context.Context) (value T, _ error) { + serialized, err := e.Encryptable.Decrypt(ctx) + if err != nil { + return value, err + } + + if err := json.Unmarshal([]byte(serialized), &value); err != nil { + return value, err + } + + return value, nil +} + +// DecryptInto decrypts the underlying value and updates the given value. This method may make an external +// API call to decrypt the underlying encrypted value, but will memoize the result so that subsequent calls +// will be cheap. +func (e *JSONEncryptable[T]) DecryptInto(ctx context.Context, value T) error { + serialized, err := e.Encryptable.Decrypt(ctx) + if err != nil { + return err + } + + if err := json.Unmarshal([]byte(serialized), &value); err != nil { + return err + } + + return nil +} + +// Set updates the underlying value. +func (e *JSONEncryptable[T]) Set(value T) error { + serialized, err := json.Marshal(value) + if err != nil { + return err + } + + e.mutex.Lock() + defer e.mutex.Unlock() + + e.decrypted = &decryptedValue{string(serialized), nil} + e.encrypted = nil + return nil +} diff --git a/internal/encryption/json_encryptable_test.go b/internal/encryption/json_encryptable_test.go new file mode 100644 index 00000000000..95154a9e3ba --- /dev/null +++ b/internal/encryption/json_encryptable_test.go @@ -0,0 +1,85 @@ +package encryption + +import ( + "context" + "encoding/json" + "testing" +) + +func TestJSONEncryptable(t *testing.T) { + ctx := context.Background() + base64Key := base64Key{} + keyID, _ := json.Marshal(base64KeyVersion) + + keyType := func(t *testing.T, keyID string) string { + var key KeyVersion + if err := json.Unmarshal([]byte(keyID), &key); err != nil { + t.Fatalf("unexpected key identifier - not json: %s", err.Error()) + } + + return key.Type + } + + type T struct { + Foo int `json:"foo"` + Bar int `json:"bar"` + Baz int `json:"baz"` + } + v1 := T{1, 2, 3} + v2 := T{7, 8, 9} + + unencrypted, err := NewUnencryptedJSON(v1) + if err != nil { + t.Fatalf("unexpected error creating encryptable: %s", err.Error()) + } + + for _, encryptable := range []*JSONEncryptable[T]{ + unencrypted, + NewEncryptedJSON[T]("eyJmb28iOjEsImJhciI6MiwiYmF6IjozfQ==", string(keyID), base64Key), + } { + // Test Decrypt + decrypted, err := encryptable.Decrypt(ctx) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := v1; decrypted != want { + t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted) + } + + // Test Encrypt + encrypted, keyID, err := encryptable.Encrypt(ctx, base64Key) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "eyJmb28iOjEsImJhciI6MiwiYmF6IjozfQ=="; encrypted != want { + t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted) + } + if want := base64KeyVersion.Type; keyType(t, keyID) != want { + t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID)) + } + + // Test Set + encryptable.Set(v2) + + // Re-test Decrypt + decrypted, err = encryptable.Decrypt(ctx) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := v2; decrypted != want { + t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted) + } + + // Re-test Encrypt + encrypted, keyID, err = encryptable.Encrypt(ctx, base64Key) + if err != nil { + t.Fatalf("unexpected error encrypting: %s", err.Error()) + } + if want := "eyJmb28iOjcsImJhciI6OCwiYmF6Ijo5fQ=="; encrypted != want { + t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted) + } + if want := base64KeyVersion.Type; keyType(t, keyID) != want { + t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID)) + } + } +} diff --git a/internal/encryption/testing/compare.go b/internal/encryption/testing/compare.go new file mode 100644 index 00000000000..078a8ce1a05 --- /dev/null +++ b/internal/encryption/testing/compare.go @@ -0,0 +1,30 @@ +package testing + +import ( + "context" + + "github.com/google/go-cmp/cmp" + + "github.com/sourcegraph/sourcegraph/internal/encryption" +) + +var CompareEncryptable = cmp.Comparer(func(a, b *encryption.Encryptable) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + aValue, err := a.Decrypt(context.Background()) + if err != nil { + return false + } + + bValue, err := b.Decrypt(context.Background()) + if err != nil { + return false + } + + return cmp.Diff(aValue, bValue) == "" +}) diff --git a/internal/encryption/utils_test.go b/internal/encryption/utils_test.go new file mode 100644 index 00000000000..f578203a36b --- /dev/null +++ b/internal/encryption/utils_test.go @@ -0,0 +1,52 @@ +package encryption + +import ( + "context" + "encoding/base64" + "encoding/json" + "testing" +) + +type base64Key struct{} + +var base64KeyVersion = KeyVersion{Type: "base64"} + +func (k base64Key) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) { + return []byte(base64.StdEncoding.EncodeToString(plaintext)), nil +} + +func (k base64Key) Decrypt(ctx context.Context, ciphertext []byte) (*Secret, error) { + decoded, err := base64.StdEncoding.DecodeString(string(ciphertext)) + s := NewSecret(string(decoded)) + return &s, err +} + +func (k base64Key) Version(ctx context.Context) (KeyVersion, error) { + return base64KeyVersion, nil +} + +type base64PlusJunkKey struct{ base64Key } + +var base64PlusJunkKeyVersion = KeyVersion{Type: "base64-plus-junk"} + +func (k base64PlusJunkKey) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) { + encrypted, err := k.base64Key.Encrypt(ctx, plaintext) + return append([]byte(`!@#$`), encrypted...), err +} + +func (k base64PlusJunkKey) Decrypt(ctx context.Context, ciphertext []byte) (*Secret, error) { + return k.base64Key.Decrypt(ctx, ciphertext[4:]) +} + +func (k base64PlusJunkKey) Version(ctx context.Context) (KeyVersion, error) { + return base64PlusJunkKeyVersion, nil +} + +func keyType(t *testing.T, keyID string) string { + var key KeyVersion + if err := json.Unmarshal([]byte(keyID), &key); err != nil { + t.Fatalf("unexpected key identifier - not json: %s", err.Error()) + } + + return key.Type +}