mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 15:12:02 +00:00
encryption: Introduce Encryptable (#40282)
This commit is contained in:
parent
cacd49ea0a
commit
a8cb7dc4c5
130
internal/encryption/encryptable.go
Normal file
130
internal/encryption/encryptable.go
Normal file
@ -0,0 +1,130 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
// Encryptable wraps a value and an encryption key and handles lazily encrypting and
|
||||
// decrypting that value. This struct should be used in all places where a value is
|
||||
// encrypted at-rest to maintain a consistent handling of data with security concerns.
|
||||
//
|
||||
// This struct should always be passed by reference.
|
||||
type Encryptable struct {
|
||||
mutex sync.Mutex
|
||||
decrypted *decryptedValue
|
||||
encrypted *EncryptedValue
|
||||
key Key
|
||||
}
|
||||
|
||||
type decryptedValue struct {
|
||||
value string
|
||||
err error
|
||||
}
|
||||
|
||||
// EncryptedValue wraps an encrypted value and serialized metadata about that key that
|
||||
// encrypted it.
|
||||
type EncryptedValue struct {
|
||||
Cipher string
|
||||
KeyID string
|
||||
}
|
||||
|
||||
// NewUnencrypted creates a new encryptable from a plaintext value.
|
||||
func NewUnencrypted(value string) *Encryptable {
|
||||
return NewUnencryptedWithKey(value, nil)
|
||||
}
|
||||
|
||||
func NewUnencryptedWithKey(value string, key Key) *Encryptable {
|
||||
return &Encryptable{
|
||||
decrypted: &decryptedValue{value, nil},
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEncrypted creates a new encryptable from an encrypted value and a relevant encryption key.
|
||||
func NewEncrypted(cipher, keyID string, key Key) *Encryptable {
|
||||
return &Encryptable{
|
||||
encrypted: &EncryptedValue{cipher, keyID},
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt returns the underlying plaintext value. This method may make an external API call to
|
||||
// decrypt the underlying encrypted value, but will memoize the result so that subsequent calls
|
||||
// will be cheap.
|
||||
func (e *Encryptable) Decrypt(ctx context.Context) (string, error) {
|
||||
e.mutex.Lock()
|
||||
defer e.mutex.Unlock()
|
||||
|
||||
return e.decryptLocked(ctx)
|
||||
}
|
||||
|
||||
func (e *Encryptable) decryptLocked(ctx context.Context) (string, error) {
|
||||
if e.decrypted != nil {
|
||||
return e.decrypted.value, e.decrypted.err
|
||||
}
|
||||
if e.encrypted == nil {
|
||||
return "", errors.New("no encrypted value")
|
||||
}
|
||||
|
||||
value, err := MaybeDecrypt(ctx, e.key, e.encrypted.Cipher, e.encrypted.KeyID)
|
||||
e.decrypted = &decryptedValue{value, err}
|
||||
return value, err
|
||||
}
|
||||
|
||||
// Encrypt returns the underlying encrypted value. This method may make an external API call to
|
||||
// encrypt the underlying plaintext value, but will memoize the result so that subsequent calls
|
||||
// will be cheap.
|
||||
func (e *Encryptable) Encrypt(ctx context.Context, key Key) (string, string, error) {
|
||||
if err := e.SetKey(ctx, key); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
e.mutex.Lock()
|
||||
defer e.mutex.Unlock()
|
||||
|
||||
if e.encrypted != nil {
|
||||
return e.encrypted.Cipher, e.encrypted.KeyID, nil
|
||||
}
|
||||
if e.decrypted == nil {
|
||||
return "", "", errors.New("nothing to encrypt")
|
||||
}
|
||||
|
||||
cipher, keyID, err := MaybeEncrypt(ctx, e.key, e.decrypted.value)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
e.encrypted = &EncryptedValue{cipher, keyID}
|
||||
return cipher, keyID, err
|
||||
}
|
||||
|
||||
// Set updates the underlying plaintext value.
|
||||
func (e *Encryptable) Set(value string) {
|
||||
e.mutex.Lock()
|
||||
defer e.mutex.Unlock()
|
||||
|
||||
e.decrypted = &decryptedValue{value, nil}
|
||||
e.encrypted = nil
|
||||
}
|
||||
|
||||
// SetKey updates the encryption key used with the encrypted value. This method may trigger an
|
||||
// external API call to decrypt the current value.
|
||||
func (e *Encryptable) SetKey(ctx context.Context, key Key) error {
|
||||
e.mutex.Lock()
|
||||
defer e.mutex.Unlock()
|
||||
|
||||
if e.key == key {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := e.decryptLocked(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.key = key
|
||||
e.encrypted = nil
|
||||
return nil
|
||||
}
|
||||
78
internal/encryption/encryptable_test.go
Normal file
78
internal/encryption/encryptable_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
base64Key := base64Key{}
|
||||
base64Key2 := base64PlusJunkKey{}
|
||||
keyID, _ := json.Marshal(base64KeyVersion)
|
||||
|
||||
for _, encryptable := range []*Encryptable{
|
||||
NewUnencrypted("foobar"),
|
||||
NewEncrypted("Zm9vYmFy", string(keyID), base64Key),
|
||||
} {
|
||||
// Test Decrypt
|
||||
decrypted, err := encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "foobar"; decrypted != want {
|
||||
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
|
||||
}
|
||||
|
||||
// Test Encrypt
|
||||
encrypted, keyID, err := encryptable.Encrypt(ctx, base64Key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "Zm9vYmFy"; encrypted != want {
|
||||
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
|
||||
}
|
||||
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
|
||||
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
|
||||
}
|
||||
|
||||
// Test SetKey
|
||||
if err := encryptable.SetKey(ctx, base64Key2); err != nil {
|
||||
t.Fatalf("unexpected error setting key: %s", err.Error())
|
||||
}
|
||||
|
||||
// Re-test Decrypt
|
||||
decrypted, err = encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "foobar"; decrypted != want {
|
||||
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
|
||||
}
|
||||
|
||||
// Test Set
|
||||
encryptable.Set("barbaz")
|
||||
|
||||
// Re-test Decrypt
|
||||
decrypted, err = encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "barbaz"; decrypted != want {
|
||||
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
|
||||
}
|
||||
|
||||
// Re-test Encrypt
|
||||
encrypted, keyID, err = encryptable.Encrypt(ctx, base64Key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "YmFyYmF6"; encrypted != want {
|
||||
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
|
||||
}
|
||||
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
|
||||
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
|
||||
}
|
||||
}
|
||||
}
|
||||
81
internal/encryption/json_encryptable.go
Normal file
81
internal/encryption/json_encryptable.go
Normal file
@ -0,0 +1,81 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// JSONEncryptable wraps a value of type T and an encryption key and handles lazily encoding/encrypting
|
||||
// and decrypting/decoding that value. This struct should be used in all places where a JSON-serialized
|
||||
// value is encrypted at-rest to maintain a consistent handling of data with security concerns.
|
||||
//
|
||||
// This struct should always be passed by reference.
|
||||
type JSONEncryptable[T any] struct {
|
||||
*Encryptable
|
||||
}
|
||||
|
||||
// NewUnencryptedJSON creates a new JSON encryptable from the given value.
|
||||
func NewUnencryptedJSON[T any](value T) (*JSONEncryptable[T], error) {
|
||||
return NewUnencryptedJSONWithKey(value, nil)
|
||||
}
|
||||
|
||||
func NewUnencryptedJSONWithKey[T any](value T, key Key) (*JSONEncryptable[T], error) {
|
||||
serialized, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &JSONEncryptable[T]{Encryptable: NewUnencryptedWithKey(string(serialized), key)}, nil
|
||||
}
|
||||
|
||||
// NewEncryptedJSON creates a new JSON encryptable an encrypted value and a relevant encryption key.
|
||||
func NewEncryptedJSON[T any](cipher, keyID string, key Key) *JSONEncryptable[T] {
|
||||
return &JSONEncryptable[T]{Encryptable: NewEncrypted(cipher, keyID, key)}
|
||||
}
|
||||
|
||||
// Decrypt decrypts and returns the underlying value as a T. This method may make an external API call
|
||||
// to decrypt the underlying encrypted value, but will memoize the result so that subsequent calls will
|
||||
// be cheap.
|
||||
func (e *JSONEncryptable[T]) Decrypt(ctx context.Context) (value T, _ error) {
|
||||
serialized, err := e.Encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(serialized), &value); err != nil {
|
||||
return value, err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// DecryptInto decrypts the underlying value and updates the given value. This method may make an external
|
||||
// API call to decrypt the underlying encrypted value, but will memoize the result so that subsequent calls
|
||||
// will be cheap.
|
||||
func (e *JSONEncryptable[T]) DecryptInto(ctx context.Context, value T) error {
|
||||
serialized, err := e.Encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(serialized), &value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set updates the underlying value.
|
||||
func (e *JSONEncryptable[T]) Set(value T) error {
|
||||
serialized, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.mutex.Lock()
|
||||
defer e.mutex.Unlock()
|
||||
|
||||
e.decrypted = &decryptedValue{string(serialized), nil}
|
||||
e.encrypted = nil
|
||||
return nil
|
||||
}
|
||||
85
internal/encryption/json_encryptable_test.go
Normal file
85
internal/encryption/json_encryptable_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJSONEncryptable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
base64Key := base64Key{}
|
||||
keyID, _ := json.Marshal(base64KeyVersion)
|
||||
|
||||
keyType := func(t *testing.T, keyID string) string {
|
||||
var key KeyVersion
|
||||
if err := json.Unmarshal([]byte(keyID), &key); err != nil {
|
||||
t.Fatalf("unexpected key identifier - not json: %s", err.Error())
|
||||
}
|
||||
|
||||
return key.Type
|
||||
}
|
||||
|
||||
type T struct {
|
||||
Foo int `json:"foo"`
|
||||
Bar int `json:"bar"`
|
||||
Baz int `json:"baz"`
|
||||
}
|
||||
v1 := T{1, 2, 3}
|
||||
v2 := T{7, 8, 9}
|
||||
|
||||
unencrypted, err := NewUnencryptedJSON(v1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating encryptable: %s", err.Error())
|
||||
}
|
||||
|
||||
for _, encryptable := range []*JSONEncryptable[T]{
|
||||
unencrypted,
|
||||
NewEncryptedJSON[T]("eyJmb28iOjEsImJhciI6MiwiYmF6IjozfQ==", string(keyID), base64Key),
|
||||
} {
|
||||
// Test Decrypt
|
||||
decrypted, err := encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := v1; decrypted != want {
|
||||
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
|
||||
}
|
||||
|
||||
// Test Encrypt
|
||||
encrypted, keyID, err := encryptable.Encrypt(ctx, base64Key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "eyJmb28iOjEsImJhciI6MiwiYmF6IjozfQ=="; encrypted != want {
|
||||
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
|
||||
}
|
||||
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
|
||||
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
|
||||
}
|
||||
|
||||
// Test Set
|
||||
encryptable.Set(v2)
|
||||
|
||||
// Re-test Decrypt
|
||||
decrypted, err = encryptable.Decrypt(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := v2; decrypted != want {
|
||||
t.Fatalf("unexpected decrypted value. want=%q have=%q", want, decrypted)
|
||||
}
|
||||
|
||||
// Re-test Encrypt
|
||||
encrypted, keyID, err = encryptable.Encrypt(ctx, base64Key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error encrypting: %s", err.Error())
|
||||
}
|
||||
if want := "eyJmb28iOjcsImJhciI6OCwiYmF6Ijo5fQ=="; encrypted != want {
|
||||
t.Fatalf("unexpected encrypted value. want=%q have=%q", want, encrypted)
|
||||
}
|
||||
if want := base64KeyVersion.Type; keyType(t, keyID) != want {
|
||||
t.Fatalf("unexpected key identifier. want=%q have=%q", want, keyType(t, keyID))
|
||||
}
|
||||
}
|
||||
}
|
||||
30
internal/encryption/testing/compare.go
Normal file
30
internal/encryption/testing/compare.go
Normal file
@ -0,0 +1,30 @@
|
||||
package testing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/encryption"
|
||||
)
|
||||
|
||||
var CompareEncryptable = cmp.Comparer(func(a, b *encryption.Encryptable) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
aValue, err := a.Decrypt(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
bValue, err := b.Decrypt(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return cmp.Diff(aValue, bValue) == ""
|
||||
})
|
||||
52
internal/encryption/utils_test.go
Normal file
52
internal/encryption/utils_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type base64Key struct{}
|
||||
|
||||
var base64KeyVersion = KeyVersion{Type: "base64"}
|
||||
|
||||
func (k base64Key) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
|
||||
return []byte(base64.StdEncoding.EncodeToString(plaintext)), nil
|
||||
}
|
||||
|
||||
func (k base64Key) Decrypt(ctx context.Context, ciphertext []byte) (*Secret, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(string(ciphertext))
|
||||
s := NewSecret(string(decoded))
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (k base64Key) Version(ctx context.Context) (KeyVersion, error) {
|
||||
return base64KeyVersion, nil
|
||||
}
|
||||
|
||||
type base64PlusJunkKey struct{ base64Key }
|
||||
|
||||
var base64PlusJunkKeyVersion = KeyVersion{Type: "base64-plus-junk"}
|
||||
|
||||
func (k base64PlusJunkKey) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
|
||||
encrypted, err := k.base64Key.Encrypt(ctx, plaintext)
|
||||
return append([]byte(`!@#$`), encrypted...), err
|
||||
}
|
||||
|
||||
func (k base64PlusJunkKey) Decrypt(ctx context.Context, ciphertext []byte) (*Secret, error) {
|
||||
return k.base64Key.Decrypt(ctx, ciphertext[4:])
|
||||
}
|
||||
|
||||
func (k base64PlusJunkKey) Version(ctx context.Context) (KeyVersion, error) {
|
||||
return base64PlusJunkKeyVersion, nil
|
||||
}
|
||||
|
||||
func keyType(t *testing.T, keyID string) string {
|
||||
var key KeyVersion
|
||||
if err := json.Unmarshal([]byte(keyID), &key); err != nil {
|
||||
t.Fatalf("unexpected key identifier - not json: %s", err.Error())
|
||||
}
|
||||
|
||||
return key.Type
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user