From aa1121c6ba0122598280281b8d02c0e045dbd031 Mon Sep 17 00:00:00 2001 From: Geoffrey Gilmore Date: Mon, 10 Jun 2024 14:20:15 -0700 Subject: [PATCH] feat/internal/memcmd: add internal/memcmd package to allow for memory tracking of exec.Cmd processes (#62803) This PR adds a new package memcmd, that adds a new abstraction called "Observer" that allows you to track the memory that a command (and all of its children) is using. (This package uses a polling approach with procfs, since [maxRSS on Linux is otherwise unreliable](https://jkz.wtf/random-linux-oddity-1-ru_maxrss) for our purposes). Example usage ```go import ( "context" "fmt" "os/exec" "time" "github.com/sourcegraph/sourcegraph/internal/memcmd" ) func Example() { const template = ` #!/usr/bin/env bash set -euo pipefail word=$(head -c "$((10 * 1024 * 1024))" --- go.mod | 2 +- internal/memcmd/BUILD.bazel | 71 +++++ internal/memcmd/observer.go | 90 +++++++ internal/memcmd/observer_darwin.go | 85 ++++++ internal/memcmd/observer_darwin_test.go | 124 +++++++++ internal/memcmd/observer_example_test.go | 67 +++++ internal/memcmd/observer_linux.go | 321 ++++++++++++++++++++++ internal/memcmd/observer_linux_test.go | 326 +++++++++++++++++++++++ internal/memcmd/observer_test.go | 112 ++++++++ 9 files changed, 1197 insertions(+), 1 deletion(-) create mode 100644 internal/memcmd/BUILD.bazel create mode 100644 internal/memcmd/observer.go create mode 100644 internal/memcmd/observer_darwin.go create mode 100644 internal/memcmd/observer_darwin_test.go create mode 100644 internal/memcmd/observer_example_test.go create mode 100644 internal/memcmd/observer_linux.go create mode 100644 internal/memcmd/observer_linux_test.go create mode 100644 internal/memcmd/observer_test.go diff --git a/go.mod b/go.mod index 5c67c728348..8851587e998 100644 --- a/go.mod +++ b/go.mod @@ -604,7 +604,7 @@ require ( github.com/pquerna/cachecontrol v0.2.0 // indirect github.com/prometheus/client_model v0.6.0 github.com/prometheus/common/sigv4 v0.1.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/procfs v0.12.0 github.com/pseudomuto/protoc-gen-doc v1.5.1 github.com/pseudomuto/protokit v0.2.1 // indirect github.com/rivo/uniseg v0.4.6 // indirect diff --git a/internal/memcmd/BUILD.bazel b/internal/memcmd/BUILD.bazel new file mode 100644 index 00000000000..3f995a0500d --- /dev/null +++ b/internal/memcmd/BUILD.bazel @@ -0,0 +1,71 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//dev:go_defs.bzl", "go_test") + +go_library( + name = "memcmd", + srcs = [ + "observer.go", + "observer_darwin.go", + "observer_linux.go", + ], + importpath = "github.com/sourcegraph/sourcegraph/internal/memcmd", + visibility = ["//:__subpackages__"], + deps = [ + "//internal/bytesize", + "//lib/errors", + ] + select({ + "@io_bazel_rules_go//go/platform:android": [ + "//internal/env", + "@com_github_prometheus_procfs//:procfs", + ], + "@io_bazel_rules_go//go/platform:linux": [ + "//internal/env", + "@com_github_prometheus_procfs//:procfs", + ], + "//conditions:default": [], + }), +) + +go_test( + name = "memcmd_test", + srcs = [ + "observer_darwin_test.go", + "observer_example_test.go", + "observer_linux_test.go", + "observer_test.go", + ], + data = ["@go_sdk//:bin/go"], + embed = [":memcmd"], + env = { + "GO_RLOCATIONPATH": "$(rlocationpath @go_sdk//:bin/go)", + }, + deps = [ + "@io_bazel_rules_go//go/runfiles:go_default_library", + ] + select({ + "@io_bazel_rules_go//go/platform:android": [ + "//internal/bytesize", + "//lib/errors", + "@com_github_dustin_go_humanize//:go-humanize", + "@com_github_google_go_cmp//cmp", + "@com_github_sourcegraph_conc//pool", + ], + "@io_bazel_rules_go//go/platform:darwin": [ + "//internal/bytesize", + "//lib/errors", + "@com_github_dustin_go_humanize//:go-humanize", + ], + "@io_bazel_rules_go//go/platform:ios": [ + "//internal/bytesize", + "//lib/errors", + "@com_github_dustin_go_humanize//:go-humanize", + ], + "@io_bazel_rules_go//go/platform:linux": [ + "//internal/bytesize", + "//lib/errors", + "@com_github_dustin_go_humanize//:go-humanize", + "@com_github_google_go_cmp//cmp", + "@com_github_sourcegraph_conc//pool", + ], + "//conditions:default": [], + }), +) diff --git a/internal/memcmd/observer.go b/internal/memcmd/observer.go new file mode 100644 index 00000000000..ca1d6270afe --- /dev/null +++ b/internal/memcmd/observer.go @@ -0,0 +1,90 @@ +package memcmd + +import ( + "sync" + + "github.com/sourcegraph/sourcegraph/internal/bytesize" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +// Observer is an interface for observing and tracking the memory usage of a process. +// +// Implementations of this interface should provide methods to start and stop the observation, +// as well as retrieve the maximum memory usage of the observed process. +// +// Callers must call Stop when they are done with the observer to release any associated resources. +type Observer interface { + // Start starts the observer. It should be called before any other method. + // + // After Start is called, callers must call Stop when they are done with the + // observer to release any resources. + // + // Calling Start() multiple times is safe and has no effect after the first invocation. + Start() + + // Stop stops the observer and releases any associated resources. For accurate measurement, + // Stop must be called _after_ Wait has been called on the *exec.Cmd. + // Stop stops the observer and releases any associated resources. + // + // Calling Stop() multiple times is safe and has no effect after the first invocation. + Stop() + + // MaxMemoryUsage returns the maximum memory usage in bytes of the process since + // the observer was started. + // + // Calling this method will also stop the observer + // + // It is only valid to call this method after: + // 1) Start() has been called and + // 2) the underlying process has stopped. + // + // See the individual observer implementations for more details on how memory + // usage is calculated. + MaxMemoryUsage() (bytes bytesize.Bytes, err error) +} + +type noopObserver struct { + startOnce sync.Once + started chan struct{} + + stopOnce sync.Once + stopped chan struct{} +} + +func (o *noopObserver) Start() { + o.startOnce.Do(func() { + close(o.started) + }) +} + +func (o *noopObserver) Stop() { + o.stopOnce.Do(func() { + close(o.stopped) + }) +} + +func (o *noopObserver) MaxMemoryUsage() (bytesize.Bytes, error) { + select { + case <-o.started: + default: + return 0, errObserverNotStarted + } + + o.Stop() + + return 0, nil +} + +// NewNoOpObserver returns an observer that does nothing. It is useful for +// testing or when you want to disable memory usage tracking. +func NewNoOpObserver() Observer { + return &noopObserver{ + started: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +var _ Observer = &noopObserver{} + +var errProcessNotStopped = errors.New("command has not stopped yet") +var errObserverNotStarted = errors.New("observer has not started yet") diff --git a/internal/memcmd/observer_darwin.go b/internal/memcmd/observer_darwin.go new file mode 100644 index 00000000000..84603f55556 --- /dev/null +++ b/internal/memcmd/observer_darwin.go @@ -0,0 +1,85 @@ +//go:build darwin + +package memcmd + +import ( + "context" + "os/exec" + "sync" + "syscall" + + "github.com/sourcegraph/sourcegraph/internal/bytesize" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +type macObserver struct { + startOnce sync.Once + started chan struct{} + + stopOnce sync.Once + cmd *exec.Cmd +} + +// NewDefaultObserver creates a new Observer for a command running on macOS. +// The command must have already been started before calling this function. +// The command must have also been started with its own process group ID (cmd.SysProcAttr.Setpgid == true). +func NewDefaultObserver(_ context.Context, cmd *exec.Cmd) (Observer, error) { + return NewMacObserver(cmd) +} + +// NewMacObserver creates a new Observer for a command running on macOS. +// The command must have already been started before calling this function. +// The command must have also been started with its own process group ID (cmd.SysProcAttr.Setpgid == true). +func NewMacObserver(cmd *exec.Cmd) (Observer, error) { + if cmd.Process == nil { + return nil, errors.New("command has not started") + } + + attr := cmd.SysProcAttr + if !(attr != nil && attr.Setpgid) { + return nil, errProcessNotWithinOwnProcessGroup + } + + return &macObserver{ + started: make(chan struct{}), + cmd: cmd, + }, nil +} + +func (o *macObserver) Start() { + o.startOnce.Do(func() { + close(o.started) + }) +} + +func (o *macObserver) Stop() { + o.stopOnce.Do(func() {}) +} + +func (o *macObserver) MaxMemoryUsage() (bytesize.Bytes, error) { + select { + case <-o.started: + default: + return 0, errObserverNotStarted + } + + o.Stop() + + state := o.cmd.ProcessState + if state == nil { + return 0, errProcessNotStopped + } + + usage, ok := state.SysUsage().(*syscall.Rusage) + if !ok { + return 0, errors.New("failed to get rusage") + } + + // On macOS, MAXRSS is the maximum resident set size used (in bytes, not kilobytes). + // See getrusage(2) for more information. + return bytesize.Bytes(usage.Maxrss), nil +} + +var _ Observer = &macObserver{} + +var errProcessNotWithinOwnProcessGroup = errors.New("command must be started with its own process group ID (cmd.SysProcAttr.Setpgid = true)") diff --git a/internal/memcmd/observer_darwin_test.go b/internal/memcmd/observer_darwin_test.go new file mode 100644 index 00000000000..3c6bbf2668b --- /dev/null +++ b/internal/memcmd/observer_darwin_test.go @@ -0,0 +1,124 @@ +//go:build darwin + +package memcmd + +import ( + "bytes" + "context" + "os/exec" + "syscall" + "testing" + + "github.com/dustin/go-humanize" + + "github.com/sourcegraph/sourcegraph/internal/bytesize" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +func TestNewMacObserverIntegration(t *testing.T) { + cmd := allocatingGoProgram(t, 250*1024*1024) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + var buf bytes.Buffer + cmd.Stderr = &buf + err := cmd.Start() + if err != nil { + t.Fatalf("failed to start test program: %v, stdErr: %s", err, buf.String()) + } + + observer, err := NewMacObserver(cmd) + if err != nil { + t.Fatalf("failed to create observer: %v", err) + } + + observer.Start() + defer observer.Stop() + + err = cmd.Wait() + if err != nil { + t.Fatalf("failed to wait for test program: %v stdErr: %s", err, buf.String()) + } + + memoryUsage, err := observer.MaxMemoryUsage() + if err != nil { + t.Fatalf("failed to get memory usage: %v", err) + } + + t.Logf("memory usage: %s", humanize.Bytes(uint64(memoryUsage))) + memoryLow := 200 * bytesize.MB + memoryHigh := 300 * bytesize.MB + + if !(memoryLow < memoryUsage && memoryUsage < memoryHigh) { + t.Fatalf("memory usage is not in the expected range (low: %s, high: %s): %s", humanize.Bytes(uint64(memoryLow)), humanize.Bytes(uint64(memoryHigh)), humanize.Bytes(uint64(memoryUsage))) + } +} + +func TestMaxMemoryUsageErrorObserverNotStarted(t *testing.T) { + cmd := allocatingGoProgram(t, 50*1024*1024) // 50 MB + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + err := cmd.Start() + if err != nil { + t.Fatalf("failed to start test program: %v", err) + } + + defer func() { + _ = cmd.Process.Kill() + }() + + observer, err := NewMacObserver(cmd) + if err != nil { + t.Fatalf("failed to create observer: %v", err) + } + + _, err = observer.MaxMemoryUsage() + if !errors.Is(err, errObserverNotStarted) { + t.Errorf("expected errObserverNotStarted, got: %v", err) + } +} + +func TestMaxMemoryUsageErrorProcessNotCompleted(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cmd := exec.CommandContext(ctx, "sleep", "10s") + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + err := cmd.Start() + if err != nil { + t.Fatalf("failed to start test program: %v", err) + } + + defer func() { + _ = cmd.Process.Kill() + }() + + observer, err := NewMacObserver(cmd) + if err != nil { + t.Fatalf("failed to create observer: %v", err) + } + + observer.Start() + defer observer.Stop() + + _, err = observer.MaxMemoryUsage() + if !errors.Is(err, errProcessNotStopped) { + t.Errorf("expected errProcessNotStopped, got: %v", err) + } +} + +func TestComplainAboutProcessNotWithinOwnGroup(t *testing.T) { + cmd := exec.Command("sleep", "10s") + err := cmd.Start() + if err != nil { + t.Fatalf("failed to start test program: %v", err) + } + + defer func() { + _ = cmd.Process.Kill() + }() + + _, err = NewMacObserver(cmd) + if !errors.Is(err, errProcessNotWithinOwnProcessGroup) { + t.Errorf("expected errProcessNotWithinOwnProcessGroup, got: %v", err) + } +} diff --git a/internal/memcmd/observer_example_test.go b/internal/memcmd/observer_example_test.go new file mode 100644 index 00000000000..f1347bfdd37 --- /dev/null +++ b/internal/memcmd/observer_example_test.go @@ -0,0 +1,67 @@ +//go:build linux + +package memcmd_test + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/sourcegraph/sourcegraph/internal/memcmd" +) + +func Example() { + const template = ` +#!/usr/bin/env bash +set -euo pipefail + + l.highestMemoryUsageBytes { + l.highestMemoryUsageBytes = currentMemoryUsageBytes + } + + l.mu.Unlock() + } + } +} + +func memoryUsageForPidAndChildren(ctx context.Context, proc processInfoProvider, basePid int) (currentMemoryUsageBytes uint64, err error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() // Return early if the context is done + default: + } + + var allRSSMemoryBytes uint64 + var errs error + + // This is a depth-first search of the process tree rooted at basePID. + // For each iteration: + // 1) we pop the first element from the stack + // 2) add its memory usage to the total + // 3) add its children to the stack + // + // We continue this process until the stack is empty. + // + // This process is best-effort. We might miss some processes if they + // are created and destroyed between iterations. + // + // Some processes' memory information might also be unavailable to us (e.g. the parent process might have already waited + // on the child process, and the information is no longer available). In this specific case, we will ignore + // the error (will be an os.IsNotExist error since we are using procfs) and continue. + // + // In the end, we return the sum of all the RSS memory usage of the processes in the tree, and any errors that occurred during the iteration. + + pidStack := []int{basePid} + for len(pidStack) > 0 { + select { + case <-ctx.Done(): + return allRSSMemoryBytes, ctx.Err() // Return early if the context is done + default: + } + + currentPid := pidStack[0] + pidStack = pidStack[1:] + + rss, err := proc.RSS(currentPid) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { // Ignore no-longer-existent processes + err = errors.Wrapf(err, "failed to report memory usage for pid %d", currentPid) + errs = errors.Append(errs, err) + } + + continue + } + + allRSSMemoryBytes += rss + + children, err := proc.Children(currentPid) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { // Ignore no-longer-existent processes + err = errors.Wrapf(err, "failed to list all processes") + errs = errors.Append(errs, err) + } + + continue + } + + pidStack = append(pidStack, children...) + } + + return allRSSMemoryBytes, errs +} + +type processInfoProvider interface { + // Children returns the PIDs of the children of the process with the given PID, or an error. + // This is a best-effort operation that might miss some children. See the implementation-specific documentation for + // more information. + // + // If the process does not exist, an error that wraps fs.ErrNotExist is returned. + Children(pid int) (childrenPIDs []int, err error) + + // RSS returns the resident set size (RSS) of the process with the given PID, or an error. + // + // If the process does not exist, an error that wraps fs.ErrNotExist is returned. + RSS(pid int) (rssBytes uint64, err error) +} + +type procfsProcessInfoProvider struct { + fs procfs.FS +} + +func (p *procfsProcessInfoProvider) RSS(pid int) (rssBytes uint64, err error) { + memory, err := func() (uint64, error) { + proc, err := p.fs.Proc(pid) + if err != nil { + return 0, errors.Wrapf(err, "failed to get procfs") + } + + status, err := proc.NewStatus() + if err != nil { + return 0, errors.Wrapf(err, "failed to get status") + } + return status.VmRSS, nil + }() + + if err != nil { + err = convertESRCH(err) // Ensure that we convert ESRCH errors to fs.ErrNotExist + } + + return memory, err +} + +// Children returns the PIDs of the children of the process with the given PID, or an error. +// +// This is a best-effort operation that might miss some children since it doesn't represent a snapshot of the process tree. +// (e.g. a child process might be created and destroyed between the time we list the processes and the time we list the children). +func (p *procfsProcessInfoProvider) Children(parentPID int) (pids []int, err error) { + pids, err = func() ([]int, error) { + procs, err := p.fs.AllProcs() + if err != nil { + return nil, errors.Wrapf(err, "failed to list all processes") + } + + var children []int + for _, p := range procs { + stat, err := p.Stat() + if err != nil { + if e := convertESRCH(err); !errors.Is(e, fs.ErrNotExist) { // Ignore no-longer-existent processes + err = errors.Wrapf(err, "failed to stat process %d", p.PID) + return nil, err + } + + continue + } + + if stat.PPID == parentPID { + children = append(children, p.PID) + } + } + + return children, nil + }() + + if err != nil { + err = convertESRCH(err) // Ensure that we wrap ESRCH errors with fs.ErrNotExist + } + + return pids, err +} + +// convertESRCH wraps an ESRCH error with fs.ErrNotExist +// to conform to the interface of the processInfoProvider +// (which makes it easier to check for errors). +func convertESRCH(err error) error { + var e syscall.Errno + if errors.As(err, &e) { + // Append fs.ErrNotExist to the error if the error is an ESRCH error (and we haven't already done so) + if e == syscall.ESRCH && !errors.Is(err, fs.ErrNotExist) { + return errors.Append(err, fs.ErrNotExist) + } + } + + return err +} + +var _ processInfoProvider = &procfsProcessInfoProvider{} diff --git a/internal/memcmd/observer_linux_test.go b/internal/memcmd/observer_linux_test.go new file mode 100644 index 00000000000..880c49f7254 --- /dev/null +++ b/internal/memcmd/observer_linux_test.go @@ -0,0 +1,326 @@ +//go:build linux + +package memcmd + +import ( + "bytes" + "context" + "io/fs" + "runtime" + "sort" + "syscall" + "testing" + "time" + + "github.com/dustin/go-humanize" + "github.com/google/go-cmp/cmp" + "github.com/sourcegraph/conc/pool" + + "github.com/sourcegraph/sourcegraph/internal/bytesize" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +func TestObserverIntegration(t *testing.T) { + cmd := allocatingGoProgram(t, 250*1024*1024) // 250 MB + + var buf bytes.Buffer + cmd.Stderr = &buf + err := cmd.Start() + if err != nil { + t.Fatalf("failed to start test program: %v, stdErr: %s", err, buf.String()) + } + + observer, err := NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond) + if err != nil { + t.Fatalf("failed to create observer: %v", err) + } + + observer.Start() + defer observer.Stop() + + err = cmd.Wait() + if err != nil { + t.Fatalf("failed to wait for test program: %v stdErr: %s", err, buf.String()) + } + + memoryUsage, err := observer.MaxMemoryUsage() + if err != nil { + t.Fatalf("failed to get memory usage: %v", err) + } + + t.Logf("memory usage: %s", humanize.Bytes(uint64(memoryUsage))) + + memoryLow := bytesize.Bytes(200 << 20) // 200 MB + memoryHigh := bytesize.Bytes(350 << 20) // 350 MB + + if !(memoryLow < memoryUsage && memoryUsage < memoryHigh) { + t.Fatalf("memory usage is not in the expected range (low: %s, high: %s): %s", humanize.Bytes(uint64(memoryLow)), humanize.Bytes(uint64(memoryHigh)), humanize.Bytes(uint64(memoryUsage))) + } +} + +func TestConvertESRCH(t *testing.T) { + tests := []struct { + name string + err error + expected error + }{ + { + name: "Nil error", + err: nil, + expected: nil, + }, + { + name: "Non-ESRCH syscall error", + err: syscall.ENOENT, + expected: syscall.ENOENT, + }, + { + name: "ESRCH error", + err: syscall.ESRCH, + expected: errors.Append(syscall.ESRCH, fs.ErrNotExist), + }, + { + name: "Wrapped ESRCH error", + err: errors.Wrap(syscall.ESRCH, "wrapped error"), + expected: errors.Append(errors.Wrap(syscall.ESRCH, "wrapped error"), fs.ErrNotExist), + }, + { + name: "Path error including ESRCH", + err: &fs.PathError{ + Op: "open", + Path: "/proc/1234", + Err: syscall.ESRCH, + }, + expected: errors.Append(&fs.PathError{ + Op: "open", + Path: "/proc/1234", + Err: syscall.ESRCH}, fs.ErrNotExist), + }, + { + name: "Error already including fs.ErrNotExist", + err: errors.New("file does not exist"), + expected: errors.New("file does not exist"), + }, + { + name: "Wrapped error already including fs.ErrNotExist", + err: errors.Wrap(fs.ErrNotExist, "wrapped error"), + expected: errors.Wrap(fs.ErrNotExist, "wrapped error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + flattenErrStrings := func(e error) []string { + var out []string + + for errStack := []error{e}; len(errStack) > 0; errStack = errStack[1:] { + err := errStack[0] + + if err == nil { + continue + } + + if errs, ok := err.(errors.MultiError); ok { + errStack = append(errStack, errs.Errors()...) + continue + } + + out = append(out, err.Error()) + } + + sort.Strings(out) + return out + } + + actualErrs := convertESRCH(tt.err) + + if diff := cmp.Diff(flattenErrStrings(tt.expected), flattenErrStrings(actualErrs)); diff != "" { + t.Errorf("convertESRCH() mismatch (-want +got):\n%s", diff) + } + + }) + } +} + +func TestMemoryUsageForPidAndChildren(t *testing.T) { + ctx := context.Background() + + var spyRSSCalls []int + var spyChildrenCalls []int + + // Set up a process tree with PIDs 1: 2, 3 + // 3: 4 + // where each process has a memory usage of 2^pid (so that we can easily check the sum of memory usage). + // + // Process 2 and 4 will disappear during the call to memoryUsageForPidAndChildren, and we should handle that + // gracefully. + + proc := &mockProcLike{ + t: t, + mockRSS: func(t *testing.T, pid int) (uint64, error) { + spyRSSCalls = append(spyRSSCalls, pid) + + if pid == 2 || pid == 4 { + // Say that the process has disappeared + return 0, fs.ErrNotExist + } + + return uint64(1 << pid), nil + }, + mockChildren: func(t *testing.T, pid int) ([]int, error) { + spyChildrenCalls = append(spyChildrenCalls, pid) + + if pid == 1 { + // Simulate a child process with PIDs 2, 3 + return []int{2, 3}, nil + } + + if pid == 3 { + // Simulate a child process with PID 4 + return []int{4}, nil + } + + return nil, fs.ErrNotExist + }, + } + + usage, err := memoryUsageForPidAndChildren(ctx, proc, 1) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if diff := cmp.Diff([]int{1, 2, 3, 4}, spyRSSCalls); diff != "" { + t.Fatalf("RSS calls mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff([]int{1, 3}, spyChildrenCalls); diff != "" { + t.Fatalf("Children calls mismatch (-want +got):\n%s", diff) + } + + expectedUsage := uint64(1<<1 + 1<<3) + if usage != expectedUsage { + t.Errorf("Expected memory usage %d, got %d", expectedUsage, usage) + } +} + +func TestMaxMemoryUsageErrorObserverNotStarted(t *testing.T) { + cmd := allocatingGoProgram(t, 50*1024*1024) // 50 MB + err := cmd.Start() + if err != nil { + t.Fatalf("failed to start test program: %v", err) + } + defer func() { + _ = cmd.Process.Kill() + }() + + observer, err := NewLinuxObserver(context.Background(), cmd, 1*time.Millisecond) + if err != nil { + t.Fatalf("failed to create observer: %v", err) + } + + _, err = observer.MaxMemoryUsage() + if !errors.Is(err, errObserverNotStarted) { + t.Errorf("expected errObserverNotStarted, got: %v", err) + } +} + +type mockProcLike struct { + t *testing.T + + mockRSS func(*testing.T, int) (uint64, error) + mockChildren func(*testing.T, int) ([]int, error) +} + +func (m *mockProcLike) RSS(pid int) (uint64, error) { + if m.mockRSS != nil { + return m.mockRSS(m.t, pid) + } + + m.t.Fatal("RSS not implemented") + return 0, nil +} + +func (m *mockProcLike) Children(pid int) ([]int, error) { + if m.mockChildren != nil { + return m.mockChildren(m.t, pid) + } + + m.t.Fatal("Children not implemented") + return nil, nil +} + +func BenchmarkLinuxObservationApproaches(b *testing.B) { + b.Run("Observer", func(b *testing.B) { + for _, interval := range []time.Duration{1 * time.Millisecond, 10 * time.Millisecond, 100 * time.Millisecond} { + b.Run(interval.String(), func(b *testing.B) { + benchFunc(b, interval) + }) + } + }) + + b.Run("NoObserver", func(b *testing.B) { + benchFunc(b, 0) + }) +} + +func benchFunc(b *testing.B, observerInterval time.Duration) { + for range b.N { + workerPool := pool.New().WithErrors() + + for range runtime.NumCPU() { + workerPool.Go(func() error { + cmd := allocatingGoProgram(b, 50*1024*1024) + + var buf bytes.Buffer + cmd.Stderr = &buf + err := cmd.Start() + if err != nil { + return errors.Errorf("starting command: %v, stdErr: %s", err, buf.String()) + } + + observer := NewNoOpObserver() + + if observerInterval > 0 { + observer, err = NewLinuxObserver(context.Background(), cmd, observerInterval) + if err != nil { + return errors.Errorf("failed to create observer: %v", err) + } + } + + observer.Start() + defer observer.Stop() + + err = cmd.Wait() + if err != nil { + return errors.Errorf("waiting for command: %v, stdErr: %s", err, buf.String()) + } + + _, isNoOpObserver := observer.(*noopObserver) + if isNoOpObserver { + return nil + } + + memory, err := observer.MaxMemoryUsage() + if err != nil { + return errors.Errorf("getting memory usage: %v", err) + } + + memoryLow := bytesize.Bytes(10 << 20) // 10MB + memoryHigh := bytesize.Bytes(100 << 20) // 100MB + + if !(memoryLow < memory && memory < memoryHigh) { + return errors.Errorf("memory usage is not in the expected range (low: %s, high: %s): %s", humanize.Bytes(uint64(memoryLow)), humanize.Bytes(uint64(memoryHigh)), humanize.Bytes(uint64(memory))) + } + + return nil + }) + + if err := workerPool.Wait(); err != nil { + b.Fatalf("error in worker pool: %v", err) + } + } + } +} + +var _ processInfoProvider = &mockProcLike{} diff --git a/internal/memcmd/observer_test.go b/internal/memcmd/observer_test.go new file mode 100644 index 00000000000..705f11a5b48 --- /dev/null +++ b/internal/memcmd/observer_test.go @@ -0,0 +1,112 @@ +package memcmd + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/bazelbuild/rules_go/go/runfiles" +) + +var goBinary = "go" + +func init() { + if path := os.Getenv("GO_RLOCATIONPATH"); path != "" { + var err error + goBinary, err = runfiles.Rlocation(path) + if err != nil { + panic(err) + } + } +} + +func allocatingGoProgram(t testing.TB, allocationSizeBytes uint64) *exec.Cmd { + t.Helper() + + const goTemplate = ` +package main + +import ( + "fmt" + "time" + "os" +) + +func main() { + var slice []byte + + if len(os.Args) > 0 { // Conditional that's always true to force the slice to be allocated on the heap + slice = make([]byte, %d) + for i := 0; i < len(slice); i++ { + slice[i] = byte(i & 0xff) + } + } + + time.Sleep(500 * time.Millisecond) + fmt.Println(len(slice)) // Don't optimize the slice away +}` + + goSource := fmt.Sprintf(goTemplate, allocationSizeBytes) + + goFile := filepath.Join(t.TempDir(), "main.go") + err := os.WriteFile(goFile, []byte(goSource), 0o644) // permissions: -rw-r--r-- + if err != nil { + t.Fatalf("failed to write test program: %v", err) + } + + binaryPath := filepath.Join(t.TempDir(), "main") + + const bashTemplateGoBuild = ` +#!/usr/bin/env bash +set -euxo pipefail + +%s build -o %s %s +` + + ctx := context.Background() + + args := []string{ + "--login", // -l: login shell (so that we know that the PATH is set correctly for asdf if needed) + "-c", + fmt.Sprintf(bashTemplateGoBuild, goBinary, binaryPath, goFile), + } + + goBuildCmd := exec.CommandContext(ctx, "bash", args...) + goBuildCmd.Env = append(goBuildCmd.Env, fmt.Sprintf("GOCACHE=%s", t.TempDir())) + + { + // Ensure that the HOME environment variable is set. This is required for + // asdf to work correctly. + hasHome := false + for _, env := range goBuildCmd.Env { + if strings.HasPrefix(env, "HOME=") { + hasHome = true + break + } + } + + if !hasHome { + if home, err := os.UserHomeDir(); err == nil { + goBuildCmd.Env = append(goBuildCmd.Env, fmt.Sprintf("HOME=%s", home)) + } + } + } + + _, err = goBuildCmd.Output() + if err != nil { + t.Fatalf("failed to compile test program: %v", err) + } + + const bashTemplateRunCmd = ` +#!/usr/bin/env bash +set -euxo pipefail + +%s +echo "done" # force bash to fork +` + return exec.Command("bash", "-c", fmt.Sprintf(bashTemplateRunCmd, binaryPath)) +}