sourcegraph/internal/memcmd/observer_linux.go
Geoffrey Gilmore aa1121c6ba
feat/internal/memcmd: add internal/memcmd package to allow for memory tracking of exec.Cmd processes (#62803)
This PR adds a new package, memcmd, which introduces an abstraction called
"Observer" that tracks the memory used by a command (and all of its
children). (This package uses a polling approach with procfs, since
[maxRSS on Linux is otherwise
unreliable](https://jkz.wtf/random-linux-oddity-1-ru_maxrss) for our
purposes.)

## Example usage

```go

import (
	"context"
	"fmt"
	"os/exec"
	"time"

	"github.com/sourcegraph/sourcegraph/internal/memcmd"
)

func Example() {
	const template = `
#!/usr/bin/env bash
set -euo pipefail

word=$(head -c "$((10 * 1024 * 1024))" </dev/zero | tr '\0' '\141') # 10MB worth of 'a's
sleep 1
echo ${#word}
`

	cmd := exec.Command("bash", "-c", template)
	err := cmd.Start()
	if err != nil {
		panic(err)
	}

	observer, err := memcmd.NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond)
	if err != nil {
		panic(err)
	}

	observer.Start()
	defer observer.Stop()

	err = cmd.Wait()
	if err != nil {
		panic(err)
	}

	memoryUsage, err := observer.MaxMemoryUsage()
	if err != nil {
		panic(err)
	}

	fmt.Println((0 < memoryUsage && memoryUsage < 50*1024*1024)) // Output should be between 0 and 50MB

	// Output:
	// true
}

```

## Test plan

Unit tests

Note that some tests only work on darwin, so you'll have to run those
locally.

## Changelog 

This feature adds a package that allows us to track the memory usage of
commands invoked via exec.Cmd.

---------

Co-authored-by: Noah Santschi-Cooney <noah@santschi-cooney.ch>
2024-06-10 14:20:15 -07:00

322 lines
8.7 KiB
Go

//go:build linux
package memcmd
import (
"context"
"io/fs"
"os/exec"
"sync"
"syscall"
"time"
"github.com/prometheus/procfs"
"github.com/sourcegraph/sourcegraph/internal/bytesize"
"github.com/sourcegraph/sourcegraph/internal/env"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
// defaultSamplingInterval is the sampling interval used by NewDefaultObserver.
// It is evaluated once, during package initialization, from the
// MEMORY_OBSERVATION_DEFAULT_SAMPLING_INTERVAL environment variable and falls
// back to one millisecond.
var defaultSamplingInterval = env.MustGetDuration(
	"MEMORY_OBSERVATION_DEFAULT_SAMPLING_INTERVAL",
	time.Millisecond,
	"For memory observers spawned by NewDefaultObserver, the interval at which memory usage is sampled. This environment variable only has an effect on Linux.",
)
// NewDefaultObserver is a convenience wrapper around NewLinuxObserver that
// samples at defaultSamplingInterval, whose value comes from the
// MEMORY_OBSERVATION_DEFAULT_SAMPLING_INTERVAL environment variable.
//
// See NewLinuxObserver for more information.
func NewDefaultObserver(ctx context.Context, cmd *exec.Cmd) (Observer, error) {
	interval := defaultSamplingInterval
	return NewLinuxObserver(ctx, cmd, interval)
}
// linuxObserver is an Observer that observes the memory usage of a process and its children on Linux.
//
// It is constructed by NewLinuxObserver, begins sampling when Start is called, and stops
// sampling when Stop is called (MaxMemoryUsage also calls Stop) or when its parent context
// is canceled.
type linuxObserver struct {
	// ctx bounds the lifetime of the sampling goroutine; stopFunc cancels it.
	ctx context.Context
	// proc answers the per-process RSS and child-pid queries used while sampling.
	proc processInfoProvider
	// samplingInterval is the delay between consecutive samples; always > 0.
	samplingInterval time.Duration

	// startOnce ensures the sampling goroutine is launched at most once.
	startOnce sync.Once
	// started is closed by Start after launching the sampling goroutine;
	// MaxMemoryUsage checks it to detect an observer that was never started.
	started chan struct{}

	// stopOnce ensures stopFunc is invoked at most once.
	stopOnce sync.Once
	// stopFunc cancels ctx.
	stopFunc func()

	// cmd is the observed command; its Process.Pid is the root of the sampled tree.
	cmd *exec.Cmd

	// mu guards highestMemoryUsageBytes and errs, which are written by the
	// sampling goroutine and read by MaxMemoryUsage.
	mu sync.RWMutex
	// highestMemoryUsageBytes is the largest combined RSS seen in any sample.
	highestMemoryUsageBytes uint64
	// errs accumulates errors reported by samples.
	errs error
}
// NewLinuxObserver returns an Observer that tracks the combined memory usage of
// cmd's process and all of its descendants by polling procfs once every
// samplingInterval after Start is called.
//
// cmd must already have been started (cmd.Process must be non-nil) and
// samplingInterval must be positive; otherwise an error is returned. A nil ctx
// is treated as context.Background(). Sampling ends when ctx is canceled or
// when Stop is called on the returned Observer.
func NewLinuxObserver(ctx context.Context, cmd *exec.Cmd, samplingInterval time.Duration) (Observer, error) {
	switch {
	case cmd.Process == nil:
		// The process has not been started yet, so there is no pid to observe.
		return nil, errors.New("process has not been started yet")
	case samplingInterval <= 0:
		return nil, errors.New("samplingInterval must be greater than 0")
	}

	procFS, err := procfs.NewDefaultFS()
	if err != nil {
		return nil, errors.Wrap(err, "failed to create procfs")
	}

	if ctx == nil {
		ctx = context.Background()
	}
	observerCtx, cancel := context.WithCancel(ctx)

	observer := &linuxObserver{
		ctx:              observerCtx,
		proc:             &procfsProcessInfoProvider{fs: procFS},
		samplingInterval: samplingInterval,
		started:          make(chan struct{}),
		stopFunc:         cancel,
		cmd:              cmd,
	}
	return observer, nil
}
// MaxMemoryUsage stops the observer and returns the highest combined RSS that
// was sampled for the process and its descendants, together with any errors
// accumulated while sampling.
//
// It returns errObserverNotStarted if Start has not been called.
//
// NOTE(review): Stop only cancels the sampling context; it does not wait for
// the sampling goroutine to exit, so a sample already in flight may still
// update the recorded values after this method returns.
func (l *linuxObserver) MaxMemoryUsage() (bytesize.Bytes, error) {
	select {
	case <-l.started:
		l.Stop()
	default:
		return 0, errObserverNotStarted
	}

	l.mu.RLock()
	highest, errs := l.highestMemoryUsageBytes, l.errs
	l.mu.RUnlock()

	return bytesize.Bytes(highest), errs
}
// Start launches the background goroutine that samples memory usage and then
// marks the observer as started. Only the first call has any effect; later
// calls are no-ops.
func (l *linuxObserver) Start() {
	begin := func() {
		go l.observe()
		close(l.started)
	}
	l.startOnce.Do(begin)
}
// Stop cancels the observer's context, which ends sampling. It may be called
// any number of times; only the first call invokes the cancel function. It
// does not wait for the sampling goroutine to exit.
func (l *linuxObserver) Stop() {
	l.stopOnce.Do(l.stopFunc)
}
// observe samples the memory usage of the observed process tree until l.ctx is
// done. It records the highest total seen in l.highestMemoryUsageBytes and any
// sampling errors in l.errs.
//
// One sample is taken immediately and then one per samplingInterval. A
// time.Ticker drops ticks for slow receivers, so if a sample takes longer than
// the interval, missed ticks are coalesced rather than queued. Ticks are
// consumed directly in this goroutine, which leaves no helper goroutine behind
// that could stay blocked after observe returns.
//
// NOTE(review): sampling continues until Stop is called or the context is
// canceled, even after the process has exited. Once the process has been
// reaped its pid may be reused, so callers should stop the observer promptly
// after cmd.Wait returns.
func (l *linuxObserver) observe() {
	ticker := time.NewTicker(l.samplingInterval)
	defer ticker.Stop()

	for {
		usage, err := memoryUsageForPidAndChildren(l.ctx, l.proc, l.cmd.Process.Pid)
		if ctxErr := l.ctx.Err(); ctxErr != nil && errors.Is(err, ctxErr) {
			// The sample was cut short because the observer is stopping (Stop
			// was called or the parent context was canceled). That is the
			// normal way for observation to end, not a sampling failure, so it
			// must not surface as an error from MaxMemoryUsage.
			err = nil
		}

		l.mu.Lock()
		if err != nil {
			l.errs = errors.Append(l.errs, err)
		}
		// A partial (cut-short) sample can only under-count, so it is safe to
		// fold into the running maximum.
		if usage > l.highestMemoryUsageBytes {
			l.highestMemoryUsageBytes = usage
		}
		l.mu.Unlock()

		select {
		case <-l.ctx.Done():
			return
		case <-ticker.C:
		}
	}
}
// memoryUsageForPidAndChildren returns the sum of the resident set sizes (RSS)
// reported by proc for the process basePid and all of its descendants.
//
// The process tree is walked breadth-first: pids are taken from the front of a
// FIFO queue, their RSS is added to the total, and their children are appended
// to the back of the queue. Each pid is visited at most once. A pid reported
// under more than one parent (the tree can change between Children calls) is
// therefore neither double-counted nor able to make the walk loop forever.
//
// The walk is best-effort because it is not an atomic snapshot of the tree:
//   - processes created and destroyed between samples are missed;
//   - processes whose information is no longer available (e.g. a child that
//     the parent has already waited on) are reported by proc with an error
//     wrapping fs.ErrNotExist. Such errors are ignored and the process is
//     skipped.
//
// All other errors are accumulated and returned alongside the (partial) total.
// If ctx is done, the walk stops immediately and returns the partial total
// together with ctx.Err(); errors accumulated so far are not returned in that
// case.
func memoryUsageForPidAndChildren(ctx context.Context, proc processInfoProvider, basePid int) (currentMemoryUsageBytes uint64, err error) {
	var (
		totalRSSBytes uint64
		errs          error
	)

	queue := []int{basePid}
	visited := make(map[int]struct{})

	for len(queue) > 0 {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return totalRSSBytes, ctxErr // Return early if the context is done
		}

		pid := queue[0]
		queue = queue[1:]

		if _, seen := visited[pid]; seen {
			continue
		}
		visited[pid] = struct{}{}

		rss, rssErr := proc.RSS(pid)
		if rssErr != nil {
			if !errors.Is(rssErr, fs.ErrNotExist) { // Ignore no-longer-existent processes
				errs = errors.Append(errs, errors.Wrapf(rssErr, "failed to report memory usage for pid %d", pid))
			}
			continue
		}
		totalRSSBytes += rss

		children, childrenErr := proc.Children(pid)
		if childrenErr != nil {
			if !errors.Is(childrenErr, fs.ErrNotExist) { // Ignore no-longer-existent processes
				errs = errors.Append(errs, errors.Wrapf(childrenErr, "failed to list children of pid %d", pid))
			}
			continue
		}
		queue = append(queue, children...)
	}

	return totalRSSBytes, errs
}
// processInfoProvider abstracts the per-process queries that
// memoryUsageForPidAndChildren needs to measure the memory of a process tree.
type processInfoProvider interface {
	// RSS reports the resident set size, in bytes, of the process with the
	// given pid.
	//
	// If that process does not exist, the returned error wraps fs.ErrNotExist.
	RSS(pid int) (rssBytes uint64, err error)

	// Children reports the pids of the direct children of the process with the
	// given pid. Implementations may be best-effort and miss some children;
	// see the implementation-specific documentation.
	//
	// If that process does not exist, the returned error wraps fs.ErrNotExist.
	Children(pid int) (childrenPIDs []int, err error)
}
// procfsProcessInfoProvider implements processInfoProvider on top of the
// github.com/prometheus/procfs library.
type procfsProcessInfoProvider struct {
	// fs is the procfs filesystem to query; NewLinuxObserver sets it from
	// procfs.NewDefaultFS().
	fs procfs.FS
}
// RSS returns the VmRSS value from procfs's status for the process pid.
//
// Every error is passed through convertESRCH. An ESRCH failure (the process
// vanished mid-read) is therefore also reported as fs.ErrNotExist, as the
// processInfoProvider contract requires.
func (p *procfsProcessInfoProvider) RSS(pid int) (rssBytes uint64, err error) {
	process, err := p.fs.Proc(pid)
	if err != nil {
		return 0, convertESRCH(errors.Wrapf(err, "failed to get procfs"))
	}

	status, err := process.NewStatus()
	if err != nil {
		return 0, convertESRCH(errors.Wrapf(err, "failed to get status"))
	}

	return status.VmRSS, nil
}
// Children returns the pids of every process whose parent pid is parentPID.
//
// It scans all processes listed by procfs and reads each one's stat. The
// result is therefore best-effort rather than a snapshot: a child can be
// created and destroyed between listing the processes and reading their stats.
// Processes that disappear before their stat can be read are skipped. Any
// other failure aborts the scan and returns an error. Errors are passed
// through convertESRCH so that ESRCH is also reported as fs.ErrNotExist.
func (p *procfsProcessInfoProvider) Children(parentPID int) (pids []int, err error) {
	procs, err := p.fs.AllProcs()
	if err != nil {
		return nil, convertESRCH(errors.Wrapf(err, "failed to list all processes"))
	}

	var children []int
	for _, candidate := range procs {
		stat, statErr := candidate.Stat()
		if statErr != nil {
			if errors.Is(convertESRCH(statErr), fs.ErrNotExist) {
				continue // the process exited after it was listed
			}
			return nil, convertESRCH(errors.Wrapf(statErr, "failed to stat process %d", candidate.PID))
		}

		if stat.PPID == parentPID {
			children = append(children, candidate.PID)
		}
	}

	return children, nil
}
// convertESRCH marks an ESRCH ("no such process") error as fs.ErrNotExist as
// well. Callers of processInfoProvider can then detect vanished processes with
// a single errors.Is(err, fs.ErrNotExist) check.
//
// Errors that are not ESRCH, and errors that already match fs.ErrNotExist,
// are returned unchanged; a nil error stays nil.
func convertESRCH(err error) error {
	var errno syscall.Errno
	isESRCH := errors.As(err, &errno) && errno == syscall.ESRCH
	if !isESRCH || errors.Is(err, fs.ErrNotExist) {
		return err
	}
	return errors.Append(err, fs.ErrNotExist)
}
// Compile-time check that *procfsProcessInfoProvider implements processInfoProvider.
var _ processInfoProvider = (*procfsProcessInfoProvider)(nil)