feat/internal/memcmd: add internal/memcmd package to allow for memory tracking of exec.Cmd processes (#62803)

This PR adds a new package, memcmd, which introduces an abstraction called
"Observer" that lets you track the memory that a command (and all of its
children) is using. (On Linux, this package polls procfs, since [maxRSS on
Linux is otherwise
unreliable](https://jkz.wtf/random-linux-oddity-1-ru_maxrss) for our
purposes.)

Example usage

```go

import (
	"context"
	"fmt"
	"os/exec"
	"time"

	"github.com/sourcegraph/sourcegraph/internal/memcmd"
)

// Example starts a short bash script, observes it with a Linux observer, and
// prints whether the observed peak memory usage falls within the expected range.
func Example() {
	const template = `
#!/usr/bin/env bash
set -euo pipefail

word=$(head -c "$((10 * 1024 * 1024))" </dev/zero | tr '\0' '\141') # 10MB worth of 'a's
sleep 1
echo ${#word}
`

	// The command is started first; the observer is created for the running command.
	cmd := exec.Command("bash", "-c", template)
	err := cmd.Start()
	if err != nil {
		panic(err)
	}

	// Sample the memory usage every millisecond.
	observer, err := memcmd.NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond)
	if err != nil {
		panic(err)
	}

	observer.Start()
	defer observer.Stop()

	// Wait for the command to finish before asking for the peak memory usage.
	err = cmd.Wait()
	if err != nil {
		panic(err)
	}

	memoryUsage, err := observer.MaxMemoryUsage()
	if err != nil {
		panic(err)
	}

	fmt.Println((0 < memoryUsage && memoryUsage < 50*1024*1024)) // Prints true when the peak usage is between 0 and 50MB

	// Output:
	// true
}

```

## Test plan

Unit tests

Note that some tests only work on darwin, so you'll have to run those
locally.

## Changelog 

This feature adds a package that allows us to track the memory usage of
commands invoked via exec.Cmd.

---------

Co-authored-by: Noah Santschi-Cooney <noah@santschi-cooney.ch>
This commit is contained in:
Geoffrey Gilmore 2024-06-10 14:20:15 -07:00 committed by GitHub
parent ab305657ab
commit aa1121c6ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1197 additions and 1 deletions

2
go.mod
View File

@ -604,7 +604,7 @@ require (
github.com/pquerna/cachecontrol v0.2.0 // indirect
github.com/prometheus/client_model v0.6.0
github.com/prometheus/common/sigv4 v0.1.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/prometheus/procfs v0.12.0
github.com/pseudomuto/protoc-gen-doc v1.5.1
github.com/pseudomuto/protokit v0.2.1 // indirect
github.com/rivo/uniseg v0.4.6 // indirect

View File

@ -0,0 +1,71 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//dev:go_defs.bzl", "go_test")
# Library target for the memcmd package.
go_library(
name = "memcmd",
srcs = [
"observer.go",
"observer_darwin.go",
"observer_linux.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/memcmd",
visibility = ["//:__subpackages__"],
deps = [
"//internal/bytesize",
"//lib/errors",
# The env and procfs dependencies are only added on the android and linux platforms.
] + select({
"@io_bazel_rules_go//go/platform:android": [
"//internal/env",
"@com_github_prometheus_procfs//:procfs",
],
"@io_bazel_rules_go//go/platform:linux": [
"//internal/env",
"@com_github_prometheus_procfs//:procfs",
],
"//conditions:default": [],
}),
)
# Test target for the memcmd package.
go_test(
name = "memcmd_test",
srcs = [
"observer_darwin_test.go",
"observer_example_test.go",
"observer_linux_test.go",
"observer_test.go",
],
# The Go SDK binary is shipped as a data dependency; its runfiles location is
# passed to the tests through the GO_RLOCATIONPATH environment variable.
data = ["@go_sdk//:bin/go"],
embed = [":memcmd"],
env = {
"GO_RLOCATIONPATH": "$(rlocationpath @go_sdk//:bin/go)",
},
deps = [
"@io_bazel_rules_go//go/runfiles:go_default_library",
# Remaining test dependencies are selected per platform.
] + select({
"@io_bazel_rules_go//go/platform:android": [
"//internal/bytesize",
"//lib/errors",
"@com_github_dustin_go_humanize//:go-humanize",
"@com_github_google_go_cmp//cmp",
"@com_github_sourcegraph_conc//pool",
],
"@io_bazel_rules_go//go/platform:darwin": [
"//internal/bytesize",
"//lib/errors",
"@com_github_dustin_go_humanize//:go-humanize",
],
"@io_bazel_rules_go//go/platform:ios": [
"//internal/bytesize",
"//lib/errors",
"@com_github_dustin_go_humanize//:go-humanize",
],
"@io_bazel_rules_go//go/platform:linux": [
"//internal/bytesize",
"//lib/errors",
"@com_github_dustin_go_humanize//:go-humanize",
"@com_github_google_go_cmp//cmp",
"@com_github_sourcegraph_conc//pool",
],
"//conditions:default": [],
}),
)

View File

@ -0,0 +1,90 @@
package memcmd
import (
"sync"
"github.com/sourcegraph/sourcegraph/internal/bytesize"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// Observer is an interface for observing and tracking the memory usage of a process.
//
// Implementations of this interface should provide methods to start and stop the observation,
// as well as retrieve the maximum memory usage of the observed process.
//
// Callers must call Stop when they are done with the observer to release any associated resources.
type Observer interface {
// Start starts the observer. It should be called before any other method.
//
// After Start is called, callers must call Stop when they are done with the
// observer to release any resources.
//
// Calling Start() multiple times is safe and has no effect after the first invocation.
Start()
// Stop stops the observer and releases any associated resources. For accurate measurement,
// Stop must be called _after_ Wait has been called on the *exec.Cmd.
//
// Calling Stop() multiple times is safe and has no effect after the first invocation.
Stop()
// MaxMemoryUsage returns the maximum memory usage in bytes of the process since
// the observer was started.
//
// Calling this method will also stop the observer.
//
// It is only valid to call this method after:
// 1) Start() has been called and
// 2) the underlying process has stopped.
//
// The implementations in this package return an error when Start() has not
// been called yet.
//
// See the individual observer implementations for more details on how memory
// usage is calculated.
MaxMemoryUsage() (bytes bytesize.Bytes, err error)
}
// noopObserver is an Observer that performs no measurement. It only keeps
// track of whether Start and Stop have been called.
type noopObserver struct {
	startOnce sync.Once
	started   chan struct{} // closed by the first call to Start

	stopOnce sync.Once
	stopped  chan struct{} // closed by the first call to Stop
}

// Start marks the observer as started. Only the first call has an effect.
func (o *noopObserver) Start() {
	o.startOnce.Do(func() { close(o.started) })
}

// Stop marks the observer as stopped. Only the first call has an effect.
func (o *noopObserver) Stop() {
	o.stopOnce.Do(func() { close(o.stopped) })
}

// MaxMemoryUsage always reports zero bytes. It returns errObserverNotStarted
// when Start has not been called yet; otherwise it stops the observer first.
func (o *noopObserver) MaxMemoryUsage() (bytesize.Bytes, error) {
	select {
	case <-o.started:
		o.Stop()
		return 0, nil
	default:
		return 0, errObserverNotStarted
	}
}
// NewNoOpObserver returns an Observer that never measures anything and always
// reports zero memory usage. Use it in tests, or to switch memory usage
// tracking off.
func NewNoOpObserver() Observer {
	observer := new(noopObserver)
	observer.started = make(chan struct{})
	observer.stopped = make(chan struct{})
	return observer
}
// Compile-time check that noopObserver satisfies Observer.
var _ Observer = (*noopObserver)(nil)

// Sentinel errors shared by the observer implementations.
var (
	// errProcessNotStopped is returned when the observed command has not finished yet.
	errProcessNotStopped = errors.New("command has not stopped yet")
	// errObserverNotStarted is returned when a result is requested before Start was called.
	errObserverNotStarted = errors.New("observer has not started yet")
)

View File

@ -0,0 +1,85 @@
//go:build darwin
package memcmd
import (
"context"
"os/exec"
"sync"
"syscall"
"github.com/sourcegraph/sourcegraph/internal/bytesize"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// macObserver is an Observer for macOS that reads the maximum resident set
// size from the resource usage of the finished command.
type macObserver struct {
// startOnce ensures that started is closed at most once.
startOnce sync.Once
// started is closed by the first call to Start.
started chan struct{}
// stopOnce guards Stop, which has nothing to release on macOS.
stopOnce sync.Once
// cmd is the command being observed.
cmd *exec.Cmd
}
// NewDefaultObserver creates a new Observer for a command running on macOS.
// It is a thin wrapper around NewMacObserver; the context argument is unused
// on this platform.
//
// The command must have already been started before calling this function.
// The command must have also been started with its own process group ID (cmd.SysProcAttr.Setpgid == true).
func NewDefaultObserver(_ context.Context, cmd *exec.Cmd) (Observer, error) {
return NewMacObserver(cmd)
}
// NewMacObserver builds an Observer for a command running on macOS.
//
// It returns an error if the command has not been started yet, and
// errProcessNotWithinOwnProcessGroup if the command was not started with its
// own process group ID (cmd.SysProcAttr.Setpgid == true).
func NewMacObserver(cmd *exec.Cmd) (Observer, error) {
	if cmd.Process == nil {
		return nil, errors.New("command has not started")
	}

	if attr := cmd.SysProcAttr; attr == nil || !attr.Setpgid {
		return nil, errProcessNotWithinOwnProcessGroup
	}

	observer := &macObserver{
		started: make(chan struct{}),
		cmd:     cmd,
	}
	return observer, nil
}
// Start marks the observer as started. Only the first call has an effect.
func (o *macObserver) Start() {
	o.startOnce.Do(func() { close(o.started) })
}

// Stop is a no-op on macOS: there are no resources to release.
func (o *macObserver) Stop() {
	o.stopOnce.Do(func() {})
}
// MaxMemoryUsage returns the maximum resident set size recorded in the
// resource usage of the finished command.
//
// It returns errObserverNotStarted if Start has not been called, and
// errProcessNotStopped if the command has no process state yet (i.e. it has
// not been waited on).
func (o *macObserver) MaxMemoryUsage() (bytesize.Bytes, error) {
	select {
	case <-o.started:
	default:
		return 0, errObserverNotStarted
	}

	o.Stop()

	processState := o.cmd.ProcessState
	if processState == nil {
		return 0, errProcessNotStopped
	}

	rusage, isRusage := processState.SysUsage().(*syscall.Rusage)
	if !isRusage {
		return 0, errors.New("failed to get rusage")
	}

	// On macOS, MAXRSS is the maximum resident set size used (in bytes, not kilobytes).
	// See getrusage(2) for more information.
	maxRSSBytes := bytesize.Bytes(rusage.Maxrss)
	return maxRSSBytes, nil
}
// Compile-time check that macObserver satisfies Observer.
var _ Observer = (*macObserver)(nil)

// errProcessNotWithinOwnProcessGroup is returned by NewMacObserver when the
// command was not configured with Setpgid.
var errProcessNotWithinOwnProcessGroup = errors.New("command must be started with its own process group ID (cmd.SysProcAttr.Setpgid = true)")

View File

@ -0,0 +1,124 @@
//go:build darwin
package memcmd
import (
"bytes"
"context"
"os/exec"
"syscall"
"testing"
"github.com/dustin/go-humanize"
"github.com/sourcegraph/sourcegraph/internal/bytesize"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// TestNewMacObserverIntegration runs a program that allocates 250MB and checks
// that the macOS observer reports a peak memory usage between 200MB and 300MB.
func TestNewMacObserverIntegration(t *testing.T) {
	var stderr bytes.Buffer

	cmd := allocatingGoProgram(t, 250*1024*1024)
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
	cmd.Stderr = &stderr

	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start test program: %v, stdErr: %s", err, stderr.String())
	}

	observer, err := NewMacObserver(cmd)
	if err != nil {
		t.Fatalf("failed to create observer: %v", err)
	}

	observer.Start()
	defer observer.Stop()

	if err := cmd.Wait(); err != nil {
		t.Fatalf("failed to wait for test program: %v stdErr: %s", err, stderr.String())
	}

	peak, err := observer.MaxMemoryUsage()
	if err != nil {
		t.Fatalf("failed to get memory usage: %v", err)
	}

	t.Logf("memory usage: %s", humanize.Bytes(uint64(peak)))

	lowerBound := 200 * bytesize.MB
	upperBound := 300 * bytesize.MB

	if peak <= lowerBound || peak >= upperBound {
		t.Fatalf("memory usage is not in the expected range (low: %s, high: %s): %s", humanize.Bytes(uint64(lowerBound)), humanize.Bytes(uint64(upperBound)), humanize.Bytes(uint64(peak)))
	}
}
// TestMaxMemoryUsageErrorObserverNotStarted checks that MaxMemoryUsage fails
// with errObserverNotStarted when Start was never called.
func TestMaxMemoryUsageErrorObserverNotStarted(t *testing.T) {
	cmd := allocatingGoProgram(t, 50*1024*1024) // 50 MB
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start test program: %v", err)
	}
	defer func() { _ = cmd.Process.Kill() }()

	observer, err := NewMacObserver(cmd)
	if err != nil {
		t.Fatalf("failed to create observer: %v", err)
	}

	// Start is deliberately not called here.
	if _, err := observer.MaxMemoryUsage(); !errors.Is(err, errObserverNotStarted) {
		t.Errorf("expected errObserverNotStarted, got: %v", err)
	}
}
// TestMaxMemoryUsageErrorProcessNotCompleted checks that MaxMemoryUsage fails
// with errProcessNotStopped while the observed command is still running.
func TestMaxMemoryUsageErrorProcessNotCompleted(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cmd := exec.CommandContext(ctx, "sleep", "10s")
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start test program: %v", err)
	}
	defer func() { _ = cmd.Process.Kill() }()

	observer, err := NewMacObserver(cmd)
	if err != nil {
		t.Fatalf("failed to create observer: %v", err)
	}

	observer.Start()
	defer observer.Stop()

	// The command has not been waited on, so no process state is available yet.
	if _, err := observer.MaxMemoryUsage(); !errors.Is(err, errProcessNotStopped) {
		t.Errorf("expected errProcessNotStopped, got: %v", err)
	}
}
// TestComplainAboutProcessNotWithinOwnGroup checks that NewMacObserver rejects
// a command that was started without Setpgid.
func TestComplainAboutProcessNotWithinOwnGroup(t *testing.T) {
	cmd := exec.Command("sleep", "10s")

	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start test program: %v", err)
	}
	defer func() { _ = cmd.Process.Kill() }()

	if _, err := NewMacObserver(cmd); !errors.Is(err, errProcessNotWithinOwnProcessGroup) {
		t.Errorf("expected errProcessNotWithinOwnProcessGroup, got: %v", err)
	}
}

View File

@ -0,0 +1,67 @@
//go:build linux
package memcmd_test
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"time"
"github.com/sourcegraph/sourcegraph/internal/memcmd"
)
// Example writes a small bash script to a temporary directory, runs it under a
// Linux observer, and prints whether the observed peak memory usage is
// between 0 and 100MB.
func Example() {
	// The script pushes 50MB of zero bytes through tail and then sleeps.
	const template = `
#!/usr/bin/env bash
set -euo pipefail
</dev/zero head -c $((1024**2*50)) | tail
sleep 1
`

	scriptDir, err := os.MkdirTemp("", "foo")
	if err != nil {
		panic(err)
	}
	defer func() { _ = os.RemoveAll(scriptDir) }()

	scriptPath := filepath.Join(scriptDir, "/script.sh")
	if err := os.WriteFile(scriptPath, []byte(template), 0755); err != nil {
		panic(err)
	}

	cmd := exec.Command("bash", "-c", scriptPath)
	if err := cmd.Start(); err != nil {
		panic(err)
	}

	observer, err := memcmd.NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond)
	if err != nil {
		panic(err)
	}

	observer.Start()
	defer observer.Stop()

	if err := cmd.Wait(); err != nil {
		panic(err)
	}

	memoryUsage, err := observer.MaxMemoryUsage()
	if err != nil {
		panic(err)
	}

	withinExpectedRange := 0 < memoryUsage && memoryUsage < 100*1024*1024
	fmt.Println(withinExpectedRange)
	// Output:
	// true
}

View File

@ -0,0 +1,321 @@
//go:build linux
package memcmd
import (
"context"
"io/fs"
"os/exec"
"sync"
"syscall"
"time"
"github.com/prometheus/procfs"
"github.com/sourcegraph/sourcegraph/internal/bytesize"
"github.com/sourcegraph/sourcegraph/internal/env"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// defaultSamplingInterval is the sampling interval used by NewDefaultObserver.
// It is read from the MEMORY_OBSERVATION_DEFAULT_SAMPLING_INTERVAL environment
// variable, with 1ms passed as the default.
var defaultSamplingInterval = env.MustGetDuration("MEMORY_OBSERVATION_DEFAULT_SAMPLING_INTERVAL", 1*time.Millisecond, "For memory observers spawned by NewDefaultObserver, the interval at which memory usage is sampled. This environment variable only has an effect on Linux.")
// NewDefaultObserver creates a new Observer that observes the memory usage of a process and its children on Linux.
// It is a convenience wrapper around NewLinuxObserver that uses the default sampling interval specified by the
// MEMORY_OBSERVATION_DEFAULT_SAMPLING_INTERVAL environment variable.
//
// See NewLinuxObserver for more information.
func NewDefaultObserver(ctx context.Context, cmd *exec.Cmd) (Observer, error) {
return NewLinuxObserver(ctx, cmd, defaultSamplingInterval)
}
// linuxObserver is an Observer that observes the memory usage of a process and its children on Linux.
type linuxObserver struct {
// ctx is cancelled by stopFunc; the sampling loop exits once it is done.
ctx context.Context
// proc provides per-process RSS and child-process lookups.
proc processInfoProvider
// samplingInterval is the period between memory usage samples.
samplingInterval time.Duration
// startOnce ensures that the sampling goroutine is launched at most once.
startOnce sync.Once
// started is closed by the first call to Start.
started chan struct{}
// stopOnce ensures that stopFunc is invoked at most once.
stopOnce sync.Once
// stopFunc cancels ctx.
stopFunc func()
// cmd is the command whose process tree is observed.
cmd *exec.Cmd
mu sync.RWMutex // mutex ensures that we can read and write the memory usage from different goroutines
// highestMemoryUsageBytes is the largest sample seen so far (guarded by mu).
highestMemoryUsageBytes uint64
// errs accumulates the errors encountered while sampling (guarded by mu).
errs error
}
// NewLinuxObserver builds an Observer that tracks the memory usage of a
// process and its children on Linux.
//
// Once started, the observer samples the memory usage of the process tree
// every samplingInterval. The command must already have been started and
// samplingInterval must be positive; otherwise an error is returned. A nil ctx
// is treated as context.Background().
func NewLinuxObserver(ctx context.Context, cmd *exec.Cmd, samplingInterval time.Duration) (Observer, error) {
	switch {
	case cmd.Process == nil:
		// The process has not been started yet
		return nil, errors.New("process has not been started yet")
	case samplingInterval <= 0:
		return nil, errors.New("samplingInterval must be greater than 0")
	}

	procFS, err := procfs.NewDefaultFS()
	if err != nil {
		return nil, errors.Wrap(err, "failed to create procfs")
	}

	if ctx == nil {
		ctx = context.Background()
	}
	// The cancel function is kept on the observer and invoked by Stop.
	observerCtx, cancel := context.WithCancel(ctx)

	observer := &linuxObserver{
		ctx:              observerCtx,
		proc:             &procfsProcessInfoProvider{fs: procFS},
		cmd:              cmd,
		started:          make(chan struct{}),
		stopFunc:         cancel,
		samplingInterval: samplingInterval,
	}
	return observer, nil
}
// MaxMemoryUsage stops the observer and returns the highest memory usage
// sampled so far, together with any errors collected while sampling.
//
// It returns errObserverNotStarted if Start has not been called.
func (l *linuxObserver) MaxMemoryUsage() (bytesize.Bytes, error) {
	select {
	case <-l.started:
	default:
		return 0, errObserverNotStarted
	}

	l.Stop()

	l.mu.RLock()
	peakBytes, samplingErrs := l.highestMemoryUsageBytes, l.errs
	l.mu.RUnlock()

	return bytesize.Bytes(peakBytes), samplingErrs
}
// Start launches the sampling goroutine and marks the observer as started.
// Only the first call has an effect.
func (l *linuxObserver) Start() {
	launch := func() {
		go l.observe()
		close(l.started)
	}
	l.startOnce.Do(launch)
}

// Stop cancels the observer's context, which ends the sampling loop. Only the
// first call has an effect.
func (l *linuxObserver) Stop() {
	l.stopOnce.Do(func() { l.stopFunc() })
}
// observe samples the memory usage of the command's process tree until the
// observer's context is cancelled. The first sample is taken immediately, and
// subsequent samples are taken every samplingInterval.
//
// The ticker is consumed directly by this loop (a time.Ticker drops ticks for
// slow receivers), so no helper goroutine is needed and nothing can be left
// blocked on a channel send once the observer is stopped.
func (l *linuxObserver) observe() {
	ticker := time.NewTicker(l.samplingInterval)
	defer ticker.Stop()

	for {
		if l.ctx.Err() != nil {
			return
		}

		currentMemoryUsageBytes, err := memoryUsageForPidAndChildren(l.ctx, l.proc, l.cmd.Process.Pid)

		// A sample that is interrupted by the cancellation of the context fails
		// with the context's error. That is the normal shutdown path and not a
		// measurement failure, so it must not be reported via MaxMemoryUsage.
		interrupted := l.ctx.Err() != nil

		l.mu.Lock()
		if err != nil && !interrupted {
			l.errs = errors.Append(l.errs, err)
		}
		if currentMemoryUsageBytes > l.highestMemoryUsageBytes {
			l.highestMemoryUsageBytes = currentMemoryUsageBytes
		}
		l.mu.Unlock()

		select {
		case <-l.ctx.Done():
			return
		case <-ticker.C: // Trigger memory collection at regular intervals
		}
	}
}
// memoryUsageForPidAndChildren returns the sum of the RSS of the process with
// PID basePid and all of its descendants, as reported by proc.
//
// The process tree rooted at basePid is walked breadth-first using a queue.
// For each iteration:
// 1) we take the PID at the front of the queue
// 2) add its memory usage to the total
// 3) append its children to the back of the queue
//
// We continue this process until the queue is empty.
//
// This process is best-effort. We might miss some processes if they
// are created and destroyed between iterations.
//
// Some processes' memory information might also be unavailable to us (e.g. the parent process might have already waited
// on the child process, and the information is no longer available). In this specific case, we will ignore
// the error (it wraps fs.ErrNotExist) and continue.
//
// In the end, we return the sum of all the RSS memory usage of the processes in the tree, and any errors that
// occurred during the iteration. If ctx is done, the walk stops early and the sum collected so far is returned
// together with ctx.Err().
func memoryUsageForPidAndChildren(ctx context.Context, proc processInfoProvider, basePid int) (currentMemoryUsageBytes uint64, err error) {
	var allRSSMemoryBytes uint64
	var errs error

	pidQueue := []int{basePid}
	for len(pidQueue) > 0 {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return allRSSMemoryBytes, ctxErr // Return early if the context is done
		}

		currentPid := pidQueue[0]
		pidQueue = pidQueue[1:]

		rss, err := proc.RSS(currentPid)
		if err != nil {
			if !errors.Is(err, fs.ErrNotExist) { // Ignore no-longer-existent processes
				err = errors.Wrapf(err, "failed to report memory usage for pid %d", currentPid)
				errs = errors.Append(errs, err)
			}
			continue
		}
		allRSSMemoryBytes += rss

		children, err := proc.Children(currentPid)
		if err != nil {
			if !errors.Is(err, fs.ErrNotExist) { // Ignore no-longer-existent processes
				err = errors.Wrapf(err, "failed to list children of pid %d", currentPid)
				errs = errors.Append(errs, err)
			}
			continue
		}
		pidQueue = append(pidQueue, children...)
	}

	return allRSSMemoryBytes, errs
}
// processInfoProvider abstracts the per-process lookups needed to compute the
// memory usage of a process tree, so that they can be replaced in tests.
type processInfoProvider interface {
// Children returns the PIDs of the children of the process with the given PID, or an error.
// This is a best-effort operation that might miss some children. See the implementation-specific documentation for
// more information.
//
// If the process does not exist, an error that wraps fs.ErrNotExist is returned.
Children(pid int) (childrenPIDs []int, err error)
// RSS returns the resident set size (RSS) of the process with the given PID, or an error.
//
// If the process does not exist, an error that wraps fs.ErrNotExist is returned.
RSS(pid int) (rssBytes uint64, err error)
}
// procfsProcessInfoProvider is a processInfoProvider backed by procfs.
type procfsProcessInfoProvider struct {
// fs is the procfs handle used for all lookups.
fs procfs.FS
}
// RSS returns the VmRSS value from the procfs status of the process with the
// given PID.
//
// Any error is passed through convertESRCH, so that ESRCH errors also match
// fs.ErrNotExist.
func (p *procfsProcessInfoProvider) RSS(pid int) (rssBytes uint64, err error) {
	proc, err := p.fs.Proc(pid)
	if err != nil {
		return 0, convertESRCH(errors.Wrapf(err, "failed to get procfs"))
	}

	status, err := proc.NewStatus()
	if err != nil {
		return 0, convertESRCH(errors.Wrapf(err, "failed to get status"))
	}

	return status.VmRSS, nil
}
// Children lists every process known to procfs and returns the PIDs of those
// whose parent PID equals parentPID.
//
// This is a best-effort operation that might miss some children since it doesn't represent a snapshot of the process tree.
// (e.g. a child process might be created and destroyed between the time we list the processes and the time we list the children).
//
// Processes that disappear while being inspected are skipped. Any other error
// is passed through convertESRCH before being returned.
func (p *procfsProcessInfoProvider) Children(parentPID int) (pids []int, err error) {
	allProcs, err := p.fs.AllProcs()
	if err != nil {
		return nil, convertESRCH(errors.Wrapf(err, "failed to list all processes"))
	}

	var children []int
	for _, candidate := range allProcs {
		stat, statErr := candidate.Stat()
		if statErr != nil {
			if errors.Is(convertESRCH(statErr), fs.ErrNotExist) { // Ignore no-longer-existent processes
				continue
			}
			return nil, convertESRCH(errors.Wrapf(statErr, "failed to stat process %d", candidate.PID))
		}

		if stat.PPID == parentPID {
			children = append(children, candidate.PID)
		}
	}

	return children, nil
}
// convertESRCH makes ESRCH errors also match fs.ErrNotExist, which is the
// error that the processInfoProvider interface promises for missing processes
// (and which makes it easier to check for errors).
//
// Errors that are not ESRCH, or that already match fs.ErrNotExist, are
// returned unchanged.
func convertESRCH(err error) error {
	var errno syscall.Errno
	if !errors.As(err, &errno) {
		return err
	}

	if errno != syscall.ESRCH || errors.Is(err, fs.ErrNotExist) {
		return err
	}

	// Append fs.ErrNotExist to the ESRCH error.
	return errors.Append(err, fs.ErrNotExist)
}
// Compile-time check that procfsProcessInfoProvider satisfies processInfoProvider.
var _ processInfoProvider = (*procfsProcessInfoProvider)(nil)

View File

@ -0,0 +1,326 @@
//go:build linux
package memcmd
import (
"bytes"
"context"
"io/fs"
"runtime"
"sort"
"syscall"
"testing"
"time"
"github.com/dustin/go-humanize"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/sourcegraph/internal/bytesize"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// TestObserverIntegration runs a program that allocates 250MB and checks that
// the Linux observer reports a peak memory usage between 200MB and 350MB.
func TestObserverIntegration(t *testing.T) {
	var stderr bytes.Buffer

	cmd := allocatingGoProgram(t, 250*1024*1024) // 250 MB
	cmd.Stderr = &stderr

	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start test program: %v, stdErr: %s", err, stderr.String())
	}

	observer, err := NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond)
	if err != nil {
		t.Fatalf("failed to create observer: %v", err)
	}

	observer.Start()
	defer observer.Stop()

	if err := cmd.Wait(); err != nil {
		t.Fatalf("failed to wait for test program: %v stdErr: %s", err, stderr.String())
	}

	peak, err := observer.MaxMemoryUsage()
	if err != nil {
		t.Fatalf("failed to get memory usage: %v", err)
	}

	t.Logf("memory usage: %s", humanize.Bytes(uint64(peak)))

	lowerBound := bytesize.Bytes(200 << 20) // 200 MB
	upperBound := bytesize.Bytes(350 << 20) // 350 MB

	if peak <= lowerBound || peak >= upperBound {
		t.Fatalf("memory usage is not in the expected range (low: %s, high: %s): %s", humanize.Bytes(uint64(lowerBound)), humanize.Bytes(uint64(upperBound)), humanize.Bytes(uint64(peak)))
	}
}
// TestConvertESRCH checks that convertESRCH appends fs.ErrNotExist to ESRCH
// errors (plain, wrapped, or inside a PathError) and leaves every other error
// unchanged.
//
// Errors are compared by flattening them into the sorted list of their leaf
// error messages.
func TestConvertESRCH(t *testing.T) {
	tests := []struct {
		name     string
		err      error
		expected error
	}{
		{
			name:     "Nil error",
			err:      nil,
			expected: nil,
		},
		{
			name:     "Non-ESRCH syscall error",
			err:      syscall.ENOENT,
			expected: syscall.ENOENT,
		},
		{
			name:     "ESRCH error",
			err:      syscall.ESRCH,
			expected: errors.Append(syscall.ESRCH, fs.ErrNotExist),
		},
		{
			name:     "Wrapped ESRCH error",
			err:      errors.Wrap(syscall.ESRCH, "wrapped error"),
			expected: errors.Append(errors.Wrap(syscall.ESRCH, "wrapped error"), fs.ErrNotExist),
		},
		{
			name: "Path error including ESRCH",
			err: &fs.PathError{
				Op:   "open",
				Path: "/proc/1234",
				Err:  syscall.ESRCH,
			},
			expected: errors.Append(&fs.PathError{
				Op:   "open",
				Path: "/proc/1234",
				Err:  syscall.ESRCH}, fs.ErrNotExist),
		},
		{
			// Use the real sentinel: a fresh errors.New with the same message
			// would not match fs.ErrNotExist.
			name:     "Error already including fs.ErrNotExist",
			err:      fs.ErrNotExist,
			expected: fs.ErrNotExist,
		},
		{
			name:     "Wrapped error already including fs.ErrNotExist",
			err:      errors.Wrap(fs.ErrNotExist, "wrapped error"),
			expected: errors.Wrap(fs.ErrNotExist, "wrapped error"),
		},
	}

	// flattenErrStrings expands multi-errors and returns the sorted messages of
	// all leaf errors, skipping nil entries.
	flattenErrStrings := func(e error) []string {
		var out []string

		queue := []error{e}
		for len(queue) > 0 {
			err := queue[0]
			queue = queue[1:]

			if err == nil {
				continue
			}

			if errs, ok := err.(errors.MultiError); ok {
				queue = append(queue, errs.Errors()...)
				continue
			}

			out = append(out, err.Error())
		}

		sort.Strings(out)
		return out
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			actualErrs := convertESRCH(tt.err)

			if diff := cmp.Diff(flattenErrStrings(tt.expected), flattenErrStrings(actualErrs)); diff != "" {
				t.Errorf("convertESRCH() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestMemoryUsageForPidAndChildren walks a mocked process tree and checks the
// order of the provider calls as well as the summed memory usage.
func TestMemoryUsageForPidAndChildren(t *testing.T) {
	var (
		rssCalls      []int
		childrenCalls []int
	)

	// The mocked process tree is:
	//
	//	1 -> 2, 3
	//	3 -> 4
	//
	// where each process has a memory usage of 2^pid (so that we can easily check the sum of memory usage).
	//
	// Process 2 and 4 will disappear during the call to memoryUsageForPidAndChildren, and we should handle that
	// gracefully.
	tree := map[int][]int{
		1: {2, 3},
		3: {4},
	}

	proc := &mockProcLike{
		t: t,
		mockRSS: func(t *testing.T, pid int) (uint64, error) {
			rssCalls = append(rssCalls, pid)

			switch pid {
			case 2, 4:
				// Say that the process has disappeared
				return 0, fs.ErrNotExist
			default:
				return uint64(1 << pid), nil
			}
		},
		mockChildren: func(t *testing.T, pid int) ([]int, error) {
			childrenCalls = append(childrenCalls, pid)

			if children, ok := tree[pid]; ok {
				return children, nil
			}
			return nil, fs.ErrNotExist
		},
	}

	usage, err := memoryUsageForPidAndChildren(context.Background(), proc, 1)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	if diff := cmp.Diff([]int{1, 2, 3, 4}, rssCalls); diff != "" {
		t.Fatalf("RSS calls mismatch (-want +got):\n%s", diff)
	}

	if diff := cmp.Diff([]int{1, 3}, childrenCalls); diff != "" {
		t.Fatalf("Children calls mismatch (-want +got):\n%s", diff)
	}

	// Only processes 1 and 3 report memory usage.
	if expectedUsage := uint64(1<<1 + 1<<3); usage != expectedUsage {
		t.Errorf("Expected memory usage %d, got %d", expectedUsage, usage)
	}
}
// TestMaxMemoryUsageErrorObserverNotStarted checks that MaxMemoryUsage fails
// with errObserverNotStarted when Start was never called.
func TestMaxMemoryUsageErrorObserverNotStarted(t *testing.T) {
	cmd := allocatingGoProgram(t, 50*1024*1024) // 50 MB

	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start test program: %v", err)
	}
	defer func() { _ = cmd.Process.Kill() }()

	observer, err := NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond)
	if err != nil {
		t.Fatalf("failed to create observer: %v", err)
	}

	// Start is deliberately not called here.
	if _, err := observer.MaxMemoryUsage(); !errors.Is(err, errObserverNotStarted) {
		t.Errorf("expected errObserverNotStarted, got: %v", err)
	}
}
// mockProcLike is a processInfoProvider whose behavior is supplied by the
// test through function fields. Calling a method whose function field is nil
// fails the test.
type mockProcLike struct {
	t *testing.T

	mockRSS      func(*testing.T, int) (uint64, error)
	mockChildren func(*testing.T, int) ([]int, error)
}

// RSS delegates to mockRSS.
func (m *mockProcLike) RSS(pid int) (uint64, error) {
	if m.mockRSS == nil {
		m.t.Fatal("RSS not implemented")
		return 0, nil
	}
	return m.mockRSS(m.t, pid)
}

// Children delegates to mockChildren.
func (m *mockProcLike) Children(pid int) ([]int, error) {
	if m.mockChildren == nil {
		m.t.Fatal("Children not implemented")
		return nil, nil
	}
	return m.mockChildren(m.t, pid)
}
// BenchmarkLinuxObservationApproaches compares running commands under a Linux
// observer at several sampling intervals against running them with the no-op
// observer.
func BenchmarkLinuxObservationApproaches(b *testing.B) {
	samplingIntervals := []time.Duration{
		1 * time.Millisecond,
		10 * time.Millisecond,
		100 * time.Millisecond,
	}

	b.Run("Observer", func(b *testing.B) {
		for _, interval := range samplingIntervals {
			b.Run(interval.String(), func(b *testing.B) { benchFunc(b, interval) })
		}
	})

	// An interval of 0 makes benchFunc use the no-op observer.
	b.Run("NoObserver", func(b *testing.B) { benchFunc(b, 0) })
}
func benchFunc(b *testing.B, observerInterval time.Duration) {
for range b.N {
workerPool := pool.New().WithErrors()
for range runtime.NumCPU() {
workerPool.Go(func() error {
cmd := allocatingGoProgram(b, 50*1024*1024)
var buf bytes.Buffer
cmd.Stderr = &buf
err := cmd.Start()
if err != nil {
return errors.Errorf("starting command: %v, stdErr: %s", err, buf.String())
}
observer := NewNoOpObserver()
if observerInterval > 0 {
observer, err = NewLinuxObserver(context.Background(), cmd, observerInterval)
if err != nil {
return errors.Errorf("failed to create observer: %v", err)
}
}
observer.Start()
defer observer.Stop()
err = cmd.Wait()
if err != nil {
return errors.Errorf("waiting for command: %v, stdErr: %s", err, buf.String())
}
_, isNoOpObserver := observer.(*noopObserver)
if isNoOpObserver {
return nil
}
memory, err := observer.MaxMemoryUsage()
if err != nil {
return errors.Errorf("getting memory usage: %v", err)
}
memoryLow := bytesize.Bytes(10 << 20) // 10MB
memoryHigh := bytesize.Bytes(100 << 20) // 100MB
if !(memoryLow < memory && memory < memoryHigh) {
return errors.Errorf("memory usage is not in the expected range (low: %s, high: %s): %s", humanize.Bytes(uint64(memoryLow)), humanize.Bytes(uint64(memoryHigh)), humanize.Bytes(uint64(memory)))
}
return nil
})
if err := workerPool.Wait(); err != nil {
b.Fatalf("error in worker pool: %v", err)
}
}
}
}
// Compile-time check that mockProcLike satisfies processInfoProvider.
var _ processInfoProvider = (*mockProcLike)(nil)

View File

@ -0,0 +1,112 @@
package memcmd
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/bazelbuild/rules_go/go/runfiles"
)
// goBinary is the Go toolchain binary used to compile test programs. It
// defaults to "go" and is replaced by the runfiles location of the Go SDK
// binary when GO_RLOCATIONPATH is set.
var goBinary = "go"

// init resolves goBinary from GO_RLOCATIONPATH, if set. It panics when the
// runfiles location cannot be resolved.
func init() {
	rlocationPath := os.Getenv("GO_RLOCATIONPATH")
	if rlocationPath == "" {
		return
	}

	resolved, err := runfiles.Rlocation(rlocationPath)
	if err != nil {
		panic(err)
	}
	goBinary = resolved
}
// allocatingGoProgram compiles a small Go program that allocates and fills a
// byte slice of allocationSizeBytes bytes, sleeps for 500ms, and prints the
// length of the slice.
//
// It returns a command (not yet started) that runs the compiled binary from a
// bash script. The test fails immediately if the program cannot be written or
// compiled; in the latter case the output of the build is included in the
// failure message.
func allocatingGoProgram(t testing.TB, allocationSizeBytes uint64) *exec.Cmd {
	t.Helper()

	const goTemplate = `
package main
import (
"fmt"
"time"
"os"
)
func main() {
var slice []byte
if len(os.Args) > 0 { // Conditional that's always true to force the slice to be allocated on the heap
slice = make([]byte, %d)
for i := 0; i < len(slice); i++ {
slice[i] = byte(i & 0xff)
}
}
time.Sleep(500 * time.Millisecond)
fmt.Println(len(slice)) // Don't optimize the slice away
}`

	goSource := fmt.Sprintf(goTemplate, allocationSizeBytes)
	goFile := filepath.Join(t.TempDir(), "main.go")

	err := os.WriteFile(goFile, []byte(goSource), 0o644) // permissions: -rw-r--r--
	if err != nil {
		t.Fatalf("failed to write test program: %v", err)
	}

	binaryPath := filepath.Join(t.TempDir(), "main")

	const bashTemplateGoBuild = `
#!/usr/bin/env bash
set -euxo pipefail
%s build -o %s %s
`

	ctx := context.Background()
	args := []string{
		"--login", // -l: login shell (so that we know that the PATH is set correctly for asdf if needed)
		"-c",
		fmt.Sprintf(bashTemplateGoBuild, goBinary, binaryPath, goFile),
	}

	goBuildCmd := exec.CommandContext(ctx, "bash", args...)
	// Use a fresh build cache for this build.
	goBuildCmd.Env = append(goBuildCmd.Env, fmt.Sprintf("GOCACHE=%s", t.TempDir()))

	{
		// Ensure that the HOME environment variable is set. This is required for
		// asdf to work correctly.
		hasHome := false
		for _, env := range goBuildCmd.Env {
			if strings.HasPrefix(env, "HOME=") {
				hasHome = true
				break
			}
		}

		if !hasHome {
			if home, err := os.UserHomeDir(); err == nil {
				goBuildCmd.Env = append(goBuildCmd.Env, fmt.Sprintf("HOME=%s", home))
			}
		}
	}

	// Capture stdout and stderr so that a failed build can be diagnosed from
	// the test output.
	buildOutput, err := goBuildCmd.CombinedOutput()
	if err != nil {
		t.Fatalf("failed to compile test program: %v, output: %s", err, buildOutput)
	}

	const bashTemplateRunCmd = `
#!/usr/bin/env bash
set -euxo pipefail
%s
echo "done" # force bash to fork
`

	return exec.Command("bash", "-c", fmt.Sprintf(bashTemplateRunCmd, binaryPath))
}