feat/msp: use pgxpool instead of pgx.Conn (#62994)

During manual testing for https://github.com/sourcegraph/sourcegraph/pull/62934, I started realizing that I would run into database errors:

```
failed to deallocate cached statement(s): conn busy
```

Turns out, `pgx.Conn` is meant for non-concurrent use. What we really want is a connection pool, with an off-the-shelf offering from `pgxpool`.

## Test plan

Integration tests pass, now with more cases using `t.Parallel()`. Also ran a quick sanity check by hand:

```
sg start
```

```
for i in {1..10}; do curl --header "Content-Type: application/json" --header 'authorization: bearer $SAMS_TOKEN' --data '{"filters":[{"filter":{"is_archived":false}}]}' \
    http://localhost:6081/enterpriseportal.subscriptions.v1.SubscriptionsService/ListEnterpriseSubscriptionLicenses & ; done
```

## Changelog

- The MSP runtime `lib/managedservicesplatform/contract.Contract`'s `ConnectToDatabase(...)` has been renamed to `GetConnectionPool(...)`, and now returns a `*pgxpool.Pool` instead of a `*pgx.Conn`
- The MSP runtime `lib/managedservicesplatform/cloudsql` helper library's `Connect(...)` has been renamed to `GetConnectionPool(...)`, and now returns a `*pgxpool.Pool` instead of a `*pgx.Conn`
This commit is contained in:
Robert Lin 2024-05-31 09:30:56 -07:00 committed by GitHub
parent 0fcffdd657
commit 21bf7229f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 80 additions and 56 deletions

View File

@ -15,6 +15,7 @@ go_library(
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/errors",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
],
)
@ -37,7 +38,7 @@ go_test(
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/pointers",
"@com_github_jackc_pgx_v4//stdlib",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/licensing"
@ -24,7 +25,7 @@ import (
)
type Reader struct {
conn *pgx.Conn
db *pgxpool.Pool
opts ReaderOptions
}
@ -40,21 +41,28 @@ type ReaderOptions struct {
//
// 👷 This is intended to be a short-lived mechanism, and should be removed
// as part of https://linear.app/sourcegraph/project/12f1d5047bd2/overview.
func NewReader(conn *pgx.Conn, opts ReaderOptions) *Reader {
return &Reader{conn: conn, opts: opts}
func NewReader(db *pgxpool.Pool, opts ReaderOptions) *Reader {
return &Reader{db: db, opts: opts}
}
func (r *Reader) Ping(ctx context.Context) error {
if err := r.conn.Ping(ctx); err != nil {
// Execute ping steps within a single connection.
conn, err := r.db.Acquire(ctx)
if err != nil {
return errors.Wrap(err, "db.Acquire")
}
defer conn.Release()
if err := conn.Ping(ctx); err != nil {
return errors.Wrap(err, "sqlDB.PingContext")
}
if _, err := r.conn.Exec(ctx, "SELECT current_user;"); err != nil {
if _, err := conn.Exec(ctx, "SELECT current_user;"); err != nil {
return errors.Wrap(err, "sqlDB.Exec SELECT current_user")
}
return nil
}
func (r *Reader) Close(ctx context.Context) error { return r.conn.Close(ctx) }
func (r *Reader) Close() { r.db.Close() }
type CodyGatewayAccessAttributes struct {
SubscriptionID string
@ -237,7 +245,7 @@ func (r *Reader) GetCodyGatewayAccessAttributesBySubscription(ctx context.Contex
query := newCodyGatewayAccessQuery(queryConditions{
whereClause: "subscription.id = $1",
}, r.opts)
row := r.conn.QueryRow(ctx, query,
row := r.db.QueryRow(ctx, query,
strings.TrimPrefix(subscriptionID, subscriptionsv1.EnterpriseSubscriptionIDPrefix))
return scanCodyGatewayAccessAttributes(row)
}
@ -259,7 +267,7 @@ func (r *Reader) GetCodyGatewayAccessAttributesByAccessToken(ctx context.Context
query := newCodyGatewayAccessQuery(queryConditions{
havingClause: "$1 = ANY(array_agg(tokens.license_key_hash))",
}, r.opts)
row := r.conn.QueryRow(ctx, query, decoded)
row := r.db.QueryRow(ctx, query, decoded)
return scanCodyGatewayAccessAttributes(row)
}
@ -289,7 +297,7 @@ func scanCodyGatewayAccessAttributes(row pgx.Row) (*CodyGatewayAccessAttributes,
func (r *Reader) GetAllCodyGatewayAccessAttributes(ctx context.Context) ([]*CodyGatewayAccessAttributes, error) {
query := newCodyGatewayAccessQuery(queryConditions{}, r.opts)
rows, err := r.conn.Query(ctx, query)
rows, err := r.db.Query(ctx, query)
if err != nil {
return nil, errors.Wrap(err, "failed to get cody gateway access attributes")
}
@ -427,7 +435,7 @@ func (r *Reader) ListEnterpriseSubscriptionLicenses(
}
query := newLicensesQuery(conds, r.opts)
rows, err := r.conn.Query(ctx, query, args...)
rows, err := r.db.Query(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "failed to get cody gateway access attributes")
}

View File

@ -7,7 +7,7 @@ import (
"time"
pgxstdlibv4 "github.com/jackc/pgx/v4/stdlib"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -58,7 +58,7 @@ func newTestDotcomReader(t *testing.T) (database.DB, *dotcomdb.Reader) {
// Now create a new connection using the conn string 😎
t.Logf("pgx.Connect %q", connString)
conn, err := pgx.Connect(ctx, connString)
conn, err := pgxpool.New(ctx, connString)
require.NoError(t, err)
// Make sure it works!
@ -235,22 +235,29 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
}} {
t.Run(tc.name, func(t *testing.T) {
tc := tc
t.Parallel() // parallel per newTestDotcomReader
dotcomdb, dotcomreader := newTestDotcomReader(t)
t.Parallel()
dotcomdb, dotcomreader := newTestDotcomReader(t)
// First, set up a subscription and license and some other rubbish
// data to ensure we only get the license we want.
mock := setupDBAndInsertMockLicense(t, dotcomdb, tc.info, &tc.cgAccess)
t.Run("by subscription ID", func(t *testing.T) {
t.Parallel()
attr, err := dotcomreader.GetCodyGatewayAccessAttributesBySubscription(ctx, mock.targetSubscriptionID)
require.NoError(t, err)
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
})
t.Run("by access token", func(t *testing.T) {
t.Parallel()
for i, token := range mock.accessTokens {
t.Run(fmt.Sprintf("token %d", i), func(t *testing.T) {
token := token
t.Parallel()
attr, err := dotcomreader.GetCodyGatewayAccessAttributesByAccessToken(ctx, token)
require.NoError(t, err)
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
@ -309,7 +316,7 @@ func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedDat
}
func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
t.Parallel() // parallel per newTestDotcomReader
t.Parallel()
dotcomdb, dotcomreader := newTestDotcomReader(t)
info := license.Info{
@ -337,7 +344,8 @@ func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
}
func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
t.Parallel() // parallel per newTestDotcomReader
t.Parallel()
db, dotcomreader := newTestDotcomReader(t)
info := license.Info{
ExpiresAt: time.Now().Add(30 * time.Minute),
@ -435,6 +443,9 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
},
}} {
t.Run(tc.name, func(t *testing.T) {
tc := tc
t.Parallel()
licenses, err := dotcomreader.ListEnterpriseSubscriptionLicenses(ctx, tc.filters, tc.pageSize)
require.NoError(t, err)
for _, l := range licenses {

View File

@ -24,7 +24,7 @@ go_library(
"//lib/managedservicesplatform/cloudsql",
"//lib/managedservicesplatform/runtime",
"@com_connectrpc_grpcreflect//:grpcreflect",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_sourcegraph_log//:log",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",

View File

@ -3,7 +3,7 @@ package service
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -16,21 +16,17 @@ func newDotComDBConn(ctx context.Context, config Config) (*dotcomdb.Reader, erro
}
if override := config.DotComDB.PGDSNOverride; override != nil {
config, err := pgx.ParseConfig(*override)
if err != nil {
return nil, errors.Wrapf(err, "pgx.ParseConfig %q", *override)
}
conn, err := pgx.ConnectConfig(ctx, config)
db, err := pgxpool.New(ctx, *override)
if err != nil {
return nil, errors.Wrapf(err, "pgx.ConnectConfig %q", *override)
}
return dotcomdb.NewReader(conn, readerOpts), nil
return dotcomdb.NewReader(db, readerOpts), nil
}
// Use IAM auth to connect to the Cloud SQL database.
conn, err := cloudsql.Connect(ctx, config.DotComDB.ConnConfig)
db, err := cloudsql.GetConnectionPool(ctx, config.DotComDB.ConnConfig)
if err != nil {
return nil, errors.Wrap(err, "contract.GetPostgreSQLDB")
}
return dotcomdb.NewReader(conn, readerOpts), nil
return dotcomdb.NewReader(db, readerOpts), nil
}

View File

@ -120,10 +120,8 @@ func (Service) Initialize(ctx context.Context, logger log.Logger, contract runti
background.CallbackRoutine{
StopFunc: func(ctx context.Context) error {
start := time.Now()
if err := dotcomDB.Close(ctx); err != nil {
return errors.Wrap(err, "dotcomDB.Close")
}
logger.Info("database stopped", log.Duration("elapsed", time.Since(start)))
dotcomDB.Close()
logger.Info("database connection pool closed", log.Duration("elapsed", time.Since(start)))
return nil
},
},

View File

@ -8,6 +8,7 @@ go_library(
deps = [
"//lib/errors",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_jackc_pgx_v5//stdlib",
"@com_google_cloud_go_cloudsqlconn//:cloudsqlconn",
"@io_opentelemetry_go_otel//:otel",

View File

@ -8,6 +8,7 @@ import (
"cloud.google.com/go/cloudsqlconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
@ -103,24 +104,28 @@ func Open(
if err != nil {
return nil, errors.Wrap(err, "get CloudSQL connection config")
}
return sql.Open("pgx", stdlib.RegisterConnConfig(config))
return sql.Open("pgx", stdlib.RegisterConnConfig(config.ConnConfig))
}
// Connect opens a *pgx.Conn connection to the CloudSQL instance specified by
// the ConnConfig.
// GetConnectionPool is an alternative to OpenDatabase that returns a
// github.com/jackc/pgx/v5/pgxpool to the CloudSQL instance specified by
// the ConnConfig, for services that prefer to use 'pgx' directly. A pool returns
// without waiting for any connections to be established. Acquire a connection
// immediately after creating the pool to check if a connection can successfully
// be established.
//
// 🔔 If you are connecting to a MSP-provisioned Cloud SQL instance,
// DO NOT use this - instead, use runtime.Contract.PostgreSQL.OpenDatabase
// DO NOT use this - instead, use runtime.Contract.PostgreSQL.GetConnectionPool
// instead.
func Connect(
func GetConnectionPool(
ctx context.Context,
cfg ConnConfig,
) (*pgx.Conn, error) {
) (*pgxpool.Pool, error) {
config, err := getCloudSQLConnConfig(ctx, cfg)
if err != nil {
return nil, errors.Wrap(err, "get CloudSQL connection config")
}
return pgx.ConnectConfig(ctx, config)
return pgxpool.NewWithConfig(ctx, config)
}
// getCloudSQLConnConfig generates a pgx connection configuration for using
@ -128,14 +133,14 @@ func Connect(
func getCloudSQLConnConfig(
ctx context.Context,
cfg ConnConfig,
) (*pgx.ConnConfig, error) {
) (*pgxpool.Config, error) {
if cfg.ConnectionName == nil || cfg.User == nil {
return nil, errors.New("missing required PostgreSQL configuration")
}
// https://github.com/GoogleCloudPlatform/cloud-sql-go-connector?tab=readme-ov-file#automatic-iam-database-authentication
dsn := fmt.Sprintf("user=%s dbname=%s", *cfg.User, cfg.Database)
config, err := pgx.ParseConfig(dsn)
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, errors.Wrap(err, "pgx.ParseConfig")
}
@ -149,11 +154,11 @@ func getCloudSQLConnConfig(
}
// Use the Cloud SQL connector to handle connecting to the instance.
// This approach does *NOT* require the Cloud SQL proxy.
config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
config.ConnConfig.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
return customDialer.Dial(ctx, *cfg.ConnectionName)
}
// Attach tracing
config.Tracer = pgxTracer{}
config.ConnConfig.Tracer = pgxTracer{}
return config, nil
}

View File

@ -21,7 +21,7 @@ go_library(
"//lib/pointers",
"@com_github_getsentry_sentry_go//:sentry-go",
"@com_github_google_uuid//:uuid",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_jackc_pgx_v5//stdlib",
"@com_github_prometheus_client_golang//prometheus/promhttp",
"@com_github_sourcegraph_log//:log",

View File

@ -7,7 +7,7 @@ import (
"text/template"
"cloud.google.com/go/cloudsqlconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -49,26 +49,29 @@ func (c postgreSQLContract) OpenDatabase(ctx context.Context, database string) (
if err != nil {
return nil, err
}
return sql.Open("customdsn", stdlib.RegisterConnConfig(config))
return sql.Open("customdsn", stdlib.RegisterConnConfig(config.ConnConfig))
}
return cloudsql.Open(ctx, c.getCloudSQLConnConfig(database))
}
// ConnectToDatabase is similar to OpenDatabase, but returns a
// github.com/jackc/pgx/v5 connection to the configured datbase instead for
// services that prefer to use 'pgx' directly.
// GetConnectionPool is an alternative to OpenDatabase that returns a
// github.com/jackc/pgx/v5/pgxpool for connecting to the configured database
// instead, for services that prefer to use 'pgx' directly. A pool returns
// without waiting for any connections to be established. Acquire a connection
// immediately after creating the pool to check if a connection can successfully
// be established.
//
// In development, the connection can be overridden with the PGDSN environment
// variable.
func (c postgreSQLContract) ConnectToDatabase(ctx context.Context, database string) (*pgx.Conn, error) {
func (c postgreSQLContract) GetConnectionPool(ctx context.Context, database string) (*pgxpool.Pool, error) {
if c.customDSNTemplate != nil {
config, err := parseCustomDSNTemplateConnConfig(*c.customDSNTemplate, database)
if err != nil {
return nil, err
}
return pgx.ConnectConfig(ctx, config)
return pgxpool.NewWithConfig(ctx, config)
}
return cloudsql.Connect(ctx, c.getCloudSQLConnConfig(database))
return cloudsql.GetConnectionPool(ctx, c.getCloudSQLConnConfig(database))
}
func (c postgreSQLContract) getCloudSQLConnConfig(database string) cloudsql.ConnConfig {
@ -83,7 +86,7 @@ func (c postgreSQLContract) getCloudSQLConnConfig(database string) cloudsql.Conn
}
}
func parseCustomDSNTemplateConnConfig(customDSNTemplate, database string) (*pgx.ConnConfig, error) {
func parseCustomDSNTemplateConnConfig(customDSNTemplate, database string) (*pgxpool.Config, error) {
tmpl, err := template.New("PGDSN").Parse(customDSNTemplate)
if err != nil {
return nil, errors.Wrap(err, "PGDSN is not a valid template")
@ -92,7 +95,7 @@ func parseCustomDSNTemplateConnConfig(customDSNTemplate, database string) (*pgx.
if err := tmpl.Execute(&dsn, struct{ Database string }{Database: database}); err != nil {
return nil, errors.Wrap(err, "PGDSN template is invalid")
}
config, err := pgx.ParseConfig(dsn.String())
config, err := pgxpool.ParseConfig(dsn.String())
if err != nil {
return nil, errors.Wrap(err, "rendered PGDSN is invalid")
}

View File

@ -405,12 +405,13 @@ commands:
DOTCOM_INCLUDE_PRODUCTION_LICENSES: 'true'
# Used for authentication
SAMS_URL: https://accounts.sgdev.org
# client name: 'enterprise-portal-local-dev'
ENTERPRISE_PORTAL_SAMS_CLIENT_ID: "sams_cid_018fc125-5a92-70fa-8dee-2c6df3adc100"
externalSecrets:
ENTERPRISE_PORTAL_SAMS_CLIENT_ID:
project: sourcegraph-local-dev
name: SG_LOCAL_DEV_SAMS_CLIENT_ID
ENTERPRISE_PORTAL_SAMS_CLIENT_SECRET:
project: sourcegraph-local-dev
name: ENTERPRISE_PORTAL_LOCAL_SAMS_CLIENT_SECRET
name: SG_LOCAL_DEV_SAMS_CLIENT_SECRET
watch:
- lib