mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 20:31:48 +00:00
feat/msp: use pgxpool instead of pgx.Conn (#62994)
During manual testing for https://github.com/sourcegraph/sourcegraph/pull/62934, I started realizing that I would run into database errors: ``` failed to deallocate cached statement(s): conn busy ``` Turns out, `pgx.Conn` is meant for non-concurrent use. What we really want is a connection pool, with an off-the-shelf offering from `pgxpool`. ## Test plan Integration tests pass, now with more cases using `t.Parallel()`. Also ran a quick sanity check by hand: ``` sg start ``` ``` for i in {1..10}; do curl --header "Content-Type: application/json" --header 'authorization: bearer $SAMS_TOKEN' --data '{"filters":[{"filter":{"is_archived":false}}]}' \ http://localhost:6081/enterpriseportal.subscriptions.v1.SubscriptionsService/ListEnterpriseSubscriptionLicenses & ; done ``` ## Changelog - The MSP runtime `lib/managedservicesplatform/contract.Contract`'s `ConnectToDatabase(...)` has been renamed to `GetConnectionPool(...)`, and now returns a `*pgxpool.Pool` instead of a `*pgx.Conn` - The MSP runtime `lib/managedservicesplatform/cloudsql` helper library's `Connect(...)` has been renamed to `GetConnectionPool(...)`, and now returns a `*pgxpool.Pool` instead of a `*pgx.Conn`
This commit is contained in:
parent
0fcffdd657
commit
21bf7229f2
@ -15,6 +15,7 @@ go_library(
|
||||
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
|
||||
"//lib/errors",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
],
|
||||
)
|
||||
|
||||
@ -37,7 +38,7 @@ go_test(
|
||||
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
|
||||
"//lib/pointers",
|
||||
"@com_github_jackc_pgx_v4//stdlib",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
"@com_github_stretchr_testify//require",
|
||||
|
||||
@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/license"
|
||||
"github.com/sourcegraph/sourcegraph/internal/licensing"
|
||||
@ -24,7 +25,7 @@ import (
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
conn *pgx.Conn
|
||||
db *pgxpool.Pool
|
||||
opts ReaderOptions
|
||||
}
|
||||
|
||||
@ -40,21 +41,28 @@ type ReaderOptions struct {
|
||||
//
|
||||
// 👷 This is intended to be a short-lived mechanism, and should be removed
|
||||
// as part of https://linear.app/sourcegraph/project/12f1d5047bd2/overview.
|
||||
func NewReader(conn *pgx.Conn, opts ReaderOptions) *Reader {
|
||||
return &Reader{conn: conn, opts: opts}
|
||||
func NewReader(db *pgxpool.Pool, opts ReaderOptions) *Reader {
|
||||
return &Reader{db: db, opts: opts}
|
||||
}
|
||||
|
||||
func (r *Reader) Ping(ctx context.Context) error {
|
||||
if err := r.conn.Ping(ctx); err != nil {
|
||||
// Execute ping steps within a single connection.
|
||||
conn, err := r.db.Acquire(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.Acquire")
|
||||
}
|
||||
defer conn.Release()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return errors.Wrap(err, "sqlDB.PingContext")
|
||||
}
|
||||
if _, err := r.conn.Exec(ctx, "SELECT current_user;"); err != nil {
|
||||
if _, err := conn.Exec(ctx, "SELECT current_user;"); err != nil {
|
||||
return errors.Wrap(err, "sqlDB.Exec SELECT current_user")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) Close(ctx context.Context) error { return r.conn.Close(ctx) }
|
||||
func (r *Reader) Close() { r.db.Close() }
|
||||
|
||||
type CodyGatewayAccessAttributes struct {
|
||||
SubscriptionID string
|
||||
@ -237,7 +245,7 @@ func (r *Reader) GetCodyGatewayAccessAttributesBySubscription(ctx context.Contex
|
||||
query := newCodyGatewayAccessQuery(queryConditions{
|
||||
whereClause: "subscription.id = $1",
|
||||
}, r.opts)
|
||||
row := r.conn.QueryRow(ctx, query,
|
||||
row := r.db.QueryRow(ctx, query,
|
||||
strings.TrimPrefix(subscriptionID, subscriptionsv1.EnterpriseSubscriptionIDPrefix))
|
||||
return scanCodyGatewayAccessAttributes(row)
|
||||
}
|
||||
@ -259,7 +267,7 @@ func (r *Reader) GetCodyGatewayAccessAttributesByAccessToken(ctx context.Context
|
||||
query := newCodyGatewayAccessQuery(queryConditions{
|
||||
havingClause: "$1 = ANY(array_agg(tokens.license_key_hash))",
|
||||
}, r.opts)
|
||||
row := r.conn.QueryRow(ctx, query, decoded)
|
||||
row := r.db.QueryRow(ctx, query, decoded)
|
||||
return scanCodyGatewayAccessAttributes(row)
|
||||
}
|
||||
|
||||
@ -289,7 +297,7 @@ func scanCodyGatewayAccessAttributes(row pgx.Row) (*CodyGatewayAccessAttributes,
|
||||
|
||||
func (r *Reader) GetAllCodyGatewayAccessAttributes(ctx context.Context) ([]*CodyGatewayAccessAttributes, error) {
|
||||
query := newCodyGatewayAccessQuery(queryConditions{}, r.opts)
|
||||
rows, err := r.conn.Query(ctx, query)
|
||||
rows, err := r.db.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get cody gateway access attributes")
|
||||
}
|
||||
@ -427,7 +435,7 @@ func (r *Reader) ListEnterpriseSubscriptionLicenses(
|
||||
}
|
||||
|
||||
query := newLicensesQuery(conds, r.opts)
|
||||
rows, err := r.conn.Query(ctx, query, args...)
|
||||
rows, err := r.db.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get cody gateway access attributes")
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
pgxstdlibv4 "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@ -58,7 +58,7 @@ func newTestDotcomReader(t *testing.T) (database.DB, *dotcomdb.Reader) {
|
||||
|
||||
// Now create a new connection using the conn string 😎
|
||||
t.Logf("pgx.Connect %q", connString)
|
||||
conn, err := pgx.Connect(ctx, connString)
|
||||
conn, err := pgxpool.New(ctx, connString)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure it works!
|
||||
@ -235,22 +235,29 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
|
||||
}} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc := tc
|
||||
t.Parallel() // parallel per newTestDotcomReader
|
||||
dotcomdb, dotcomreader := newTestDotcomReader(t)
|
||||
t.Parallel()
|
||||
|
||||
dotcomdb, dotcomreader := newTestDotcomReader(t)
|
||||
// First, set up a subscription and license and some other rubbish
|
||||
// data to ensure we only get the license we want.
|
||||
mock := setupDBAndInsertMockLicense(t, dotcomdb, tc.info, &tc.cgAccess)
|
||||
|
||||
t.Run("by subscription ID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attr, err := dotcomreader.GetCodyGatewayAccessAttributesBySubscription(ctx, mock.targetSubscriptionID)
|
||||
require.NoError(t, err)
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
|
||||
})
|
||||
|
||||
t.Run("by access token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, token := range mock.accessTokens {
|
||||
t.Run(fmt.Sprintf("token %d", i), func(t *testing.T) {
|
||||
token := token
|
||||
t.Parallel()
|
||||
|
||||
attr, err := dotcomreader.GetCodyGatewayAccessAttributesByAccessToken(ctx, token)
|
||||
require.NoError(t, err)
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
|
||||
@ -309,7 +316,7 @@ func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedDat
|
||||
}
|
||||
|
||||
func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
|
||||
t.Parallel() // parallel per newTestDotcomReader
|
||||
t.Parallel()
|
||||
dotcomdb, dotcomreader := newTestDotcomReader(t)
|
||||
|
||||
info := license.Info{
|
||||
@ -337,7 +344,8 @@ func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
t.Parallel() // parallel per newTestDotcomReader
|
||||
t.Parallel()
|
||||
|
||||
db, dotcomreader := newTestDotcomReader(t)
|
||||
info := license.Info{
|
||||
ExpiresAt: time.Now().Add(30 * time.Minute),
|
||||
@ -435,6 +443,9 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
},
|
||||
}} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc := tc
|
||||
t.Parallel()
|
||||
|
||||
licenses, err := dotcomreader.ListEnterpriseSubscriptionLicenses(ctx, tc.filters, tc.pageSize)
|
||||
require.NoError(t, err)
|
||||
for _, l := range licenses {
|
||||
|
||||
@ -24,7 +24,7 @@ go_library(
|
||||
"//lib/managedservicesplatform/cloudsql",
|
||||
"//lib/managedservicesplatform/runtime",
|
||||
"@com_connectrpc_grpcreflect//:grpcreflect",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
|
||||
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
|
||||
|
||||
@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -16,21 +16,17 @@ func newDotComDBConn(ctx context.Context, config Config) (*dotcomdb.Reader, erro
|
||||
}
|
||||
|
||||
if override := config.DotComDB.PGDSNOverride; override != nil {
|
||||
config, err := pgx.ParseConfig(*override)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "pgx.ParseConfig %q", *override)
|
||||
}
|
||||
conn, err := pgx.ConnectConfig(ctx, config)
|
||||
db, err := pgxpool.New(ctx, *override)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "pgx.ConnectConfig %q", *override)
|
||||
}
|
||||
return dotcomdb.NewReader(conn, readerOpts), nil
|
||||
return dotcomdb.NewReader(db, readerOpts), nil
|
||||
}
|
||||
|
||||
// Use IAM auth to connect to the Cloud SQL database.
|
||||
conn, err := cloudsql.Connect(ctx, config.DotComDB.ConnConfig)
|
||||
db, err := cloudsql.GetConnectionPool(ctx, config.DotComDB.ConnConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "contract.GetPostgreSQLDB")
|
||||
}
|
||||
return dotcomdb.NewReader(conn, readerOpts), nil
|
||||
return dotcomdb.NewReader(db, readerOpts), nil
|
||||
}
|
||||
|
||||
@ -120,10 +120,8 @@ func (Service) Initialize(ctx context.Context, logger log.Logger, contract runti
|
||||
background.CallbackRoutine{
|
||||
StopFunc: func(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
if err := dotcomDB.Close(ctx); err != nil {
|
||||
return errors.Wrap(err, "dotcomDB.Close")
|
||||
}
|
||||
logger.Info("database stopped", log.Duration("elapsed", time.Since(start)))
|
||||
dotcomDB.Close()
|
||||
logger.Info("database connection pool closed", log.Duration("elapsed", time.Since(start)))
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
@ -8,6 +8,7 @@ go_library(
|
||||
deps = [
|
||||
"//lib/errors",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@com_github_jackc_pgx_v5//stdlib",
|
||||
"@com_google_cloud_go_cloudsqlconn//:cloudsqlconn",
|
||||
"@io_opentelemetry_go_otel//:otel",
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
@ -103,24 +104,28 @@ func Open(
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get CloudSQL connection config")
|
||||
}
|
||||
return sql.Open("pgx", stdlib.RegisterConnConfig(config))
|
||||
return sql.Open("pgx", stdlib.RegisterConnConfig(config.ConnConfig))
|
||||
}
|
||||
|
||||
// Connect opens a *pgx.Conn connection to the CloudSQL instance specified by
|
||||
// the ConnConfig.
|
||||
// GetConnectionPool is an alternative to OpenDatabase that returns a
|
||||
// github.com/jackc/pgx/v5/pgxpool to the CloudSQL instance specified by
|
||||
// the ConnConfig, for services that prefer to use 'pgx' directly. A pool returns
|
||||
// without waiting for any connections to be established. Acquire a connection
|
||||
// immediately after creating the pool to check if a connection can successfully
|
||||
// be established.
|
||||
//
|
||||
// 🔔 If you are connecting to a MSP-provisioned Cloud SQL instance,
|
||||
// DO NOT use this - instead, use runtime.Contract.PostgreSQL.OpenDatabase
|
||||
// DO NOT use this - instead, use runtime.Contract.PostgreSQL.GetConnectionPool
|
||||
// instead.
|
||||
func Connect(
|
||||
func GetConnectionPool(
|
||||
ctx context.Context,
|
||||
cfg ConnConfig,
|
||||
) (*pgx.Conn, error) {
|
||||
) (*pgxpool.Pool, error) {
|
||||
config, err := getCloudSQLConnConfig(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get CloudSQL connection config")
|
||||
}
|
||||
return pgx.ConnectConfig(ctx, config)
|
||||
return pgxpool.NewWithConfig(ctx, config)
|
||||
}
|
||||
|
||||
// getCloudSQLConnConfig generates a pgx connection configuration for using
|
||||
@ -128,14 +133,14 @@ func Connect(
|
||||
func getCloudSQLConnConfig(
|
||||
ctx context.Context,
|
||||
cfg ConnConfig,
|
||||
) (*pgx.ConnConfig, error) {
|
||||
) (*pgxpool.Config, error) {
|
||||
if cfg.ConnectionName == nil || cfg.User == nil {
|
||||
return nil, errors.New("missing required PostgreSQL configuration")
|
||||
}
|
||||
|
||||
// https://github.com/GoogleCloudPlatform/cloud-sql-go-connector?tab=readme-ov-file#automatic-iam-database-authentication
|
||||
dsn := fmt.Sprintf("user=%s dbname=%s", *cfg.User, cfg.Database)
|
||||
config, err := pgx.ParseConfig(dsn)
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "pgx.ParseConfig")
|
||||
}
|
||||
@ -149,11 +154,11 @@ func getCloudSQLConnConfig(
|
||||
}
|
||||
// Use the Cloud SQL connector to handle connecting to the instance.
|
||||
// This approach does *NOT* require the Cloud SQL proxy.
|
||||
config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
config.ConnConfig.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return customDialer.Dial(ctx, *cfg.ConnectionName)
|
||||
}
|
||||
// Attach tracing
|
||||
config.Tracer = pgxTracer{}
|
||||
config.ConnConfig.Tracer = pgxTracer{}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ go_library(
|
||||
"//lib/pointers",
|
||||
"@com_github_getsentry_sentry_go//:sentry-go",
|
||||
"@com_github_google_uuid//:uuid",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@com_github_jackc_pgx_v5//stdlib",
|
||||
"@com_github_prometheus_client_golang//prometheus/promhttp",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -49,26 +49,29 @@ func (c postgreSQLContract) OpenDatabase(ctx context.Context, database string) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.Open("customdsn", stdlib.RegisterConnConfig(config))
|
||||
return sql.Open("customdsn", stdlib.RegisterConnConfig(config.ConnConfig))
|
||||
}
|
||||
return cloudsql.Open(ctx, c.getCloudSQLConnConfig(database))
|
||||
}
|
||||
|
||||
// ConnectToDatabase is similar to OpenDatabase, but returns a
|
||||
// github.com/jackc/pgx/v5 connection to the configured datbase instead for
|
||||
// services that prefer to use 'pgx' directly.
|
||||
// GetConnectionPool is an alternative to OpenDatabase that returns a
|
||||
// github.com/jackc/pgx/v5/pgxpool for connecting to the configured database
|
||||
// instead, for services that prefer to use 'pgx' directly. A pool returns
|
||||
// without waiting for any connections to be established. Acquire a connection
|
||||
// immediately after creating the pool to check if a connection can successfully
|
||||
// be established.
|
||||
//
|
||||
// In development, the connection can be overridden with the PGDSN environment
|
||||
// variable.
|
||||
func (c postgreSQLContract) ConnectToDatabase(ctx context.Context, database string) (*pgx.Conn, error) {
|
||||
func (c postgreSQLContract) GetConnectionPool(ctx context.Context, database string) (*pgxpool.Pool, error) {
|
||||
if c.customDSNTemplate != nil {
|
||||
config, err := parseCustomDSNTemplateConnConfig(*c.customDSNTemplate, database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pgx.ConnectConfig(ctx, config)
|
||||
return pgxpool.NewWithConfig(ctx, config)
|
||||
}
|
||||
return cloudsql.Connect(ctx, c.getCloudSQLConnConfig(database))
|
||||
return cloudsql.GetConnectionPool(ctx, c.getCloudSQLConnConfig(database))
|
||||
}
|
||||
|
||||
func (c postgreSQLContract) getCloudSQLConnConfig(database string) cloudsql.ConnConfig {
|
||||
@ -83,7 +86,7 @@ func (c postgreSQLContract) getCloudSQLConnConfig(database string) cloudsql.Conn
|
||||
}
|
||||
}
|
||||
|
||||
func parseCustomDSNTemplateConnConfig(customDSNTemplate, database string) (*pgx.ConnConfig, error) {
|
||||
func parseCustomDSNTemplateConnConfig(customDSNTemplate, database string) (*pgxpool.Config, error) {
|
||||
tmpl, err := template.New("PGDSN").Parse(customDSNTemplate)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "PGDSN is not a valid template")
|
||||
@ -92,7 +95,7 @@ func parseCustomDSNTemplateConnConfig(customDSNTemplate, database string) (*pgx.
|
||||
if err := tmpl.Execute(&dsn, struct{ Database string }{Database: database}); err != nil {
|
||||
return nil, errors.Wrap(err, "PGDSN template is invalid")
|
||||
}
|
||||
config, err := pgx.ParseConfig(dsn.String())
|
||||
config, err := pgxpool.ParseConfig(dsn.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "rendered PGDSN is invalid")
|
||||
}
|
||||
|
||||
@ -405,12 +405,13 @@ commands:
|
||||
DOTCOM_INCLUDE_PRODUCTION_LICENSES: 'true'
|
||||
# Used for authentication
|
||||
SAMS_URL: https://accounts.sgdev.org
|
||||
# client name: 'enterprise-portal-local-dev'
|
||||
ENTERPRISE_PORTAL_SAMS_CLIENT_ID: "sams_cid_018fc125-5a92-70fa-8dee-2c6df3adc100"
|
||||
externalSecrets:
|
||||
ENTERPRISE_PORTAL_SAMS_CLIENT_ID:
|
||||
project: sourcegraph-local-dev
|
||||
name: SG_LOCAL_DEV_SAMS_CLIENT_ID
|
||||
ENTERPRISE_PORTAL_SAMS_CLIENT_SECRET:
|
||||
project: sourcegraph-local-dev
|
||||
name: ENTERPRISE_PORTAL_LOCAL_SAMS_CLIENT_SECRET
|
||||
name: SG_LOCAL_DEV_SAMS_CLIENT_SECRET
|
||||
|
||||
watch:
|
||||
- lib
|
||||
|
||||
Loading…
Reference in New Issue
Block a user