sourcegraph/dev/sg/internal/run/command.go
Bolaji Olajide 776701ba9c
fix(sg): acknowledge command execution state to avoid recursion when executing short running commands (#64181)
Some commands, like
[`batcheshelper-builder`](https://sourcegraph.sourcegraph.com/github.com/sourcegraph/sourcegraph/-/blob/sg.config.yaml?L821),
aren't long-running commands.
This command is used to build an image and load it into Docker; its `cmd`
section ends with an `exit 0`. Combined with
`continueWatchOnExit`, this behavior results in an infinite loop in which the
process is continually restarted, because `sg` doesn't know that the process
has finished executing and isn't a long-running process.
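
Roughly what goes wrong, as a self-contained toy model (a hypothetical sketch for illustration only, not sg's actual run loop; `exitImmediately` and `watch` are made-up names):

```go
// Hypothetical toy model of the restart behavior; not sg's actual run loop.
package main

import "fmt"

// exitImmediately models Exit() for a command whose `cmd` ends with `exit 0`:
// the returned channel fires as soon as the process finishes.
func exitImmediately() <-chan error {
	ch := make(chan error, 1)
	ch <- nil
	return ch
}

func main() {
	watch := make(chan struct{}) // file-change notifications; never fire here
	for i := 0; i < 3; i++ {     // in the real bug this loops forever
		select {
		case <-exitImmediately():
			fmt.Println("command exited; continueWatchOnExit restarts it")
		case <-watch:
			fmt.Println("file changed; restart")
		}
	}
}
```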


https://github.com/user-attachments/assets/e7a027a1-6f93-403f-9240-6a791255fba9

An example of the behavior is shown in the recording above: running `sg start batches`
results in the `batcheshelper-builder` command being continually restarted.

The fix is quite simple: we return an empty receiver channel once the
process has finished executing, so that `sg` knows it's done and doesn't
restart the command unless there's a change.
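
In code, the idea is roughly the following sketch; the real implementation is `startedCmd.Exit` and `startedCmd.Wait` below (`cmdState` here is an illustrative stand-in for `startedCmd`):

```go
// Sketch of the fix pattern; the real code is startedCmd.Exit/Wait below.
package main

import (
	"fmt"
	"time"
)

type cmdState struct {
	finished bool
	result   chan error
}

func (c *cmdState) Exit() <-chan error {
	if c.finished {
		// A channel with no sender never receives, so callers stop
		// treating this short-running command as something to restart.
		return make(<-chan error)
	}
	return c.result
}

func main() {
	c := &cmdState{finished: true} // Wait() flips this once the process exits
	select {
	case err := <-c.Exit():
		fmt.Println("exited again:", err) // never reached once finished
	case <-time.After(50 * time.Millisecond):
		fmt.Println("done; no restart until a watched file changes")
	}
}
```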

## Test plan

* Manual testing: `go run ./dev/sg start batches` no longer results in
an infinite loop.
* Added unit tests.

## Changelog
2024-07-31 16:09:44 -05:00


package run
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
"github.com/grafana/regexp"
"github.com/sourcegraph/conc/pool"
"go.bobheadxi.dev/streamline/pipe"
"github.com/sourcegraph/sourcegraph/dev/sg/internal/secrets"
"github.com/sourcegraph/sourcegraph/dev/sg/internal/std"
"github.com/sourcegraph/sourcegraph/dev/sg/interrupt"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/output"
"github.com/sourcegraph/sourcegraph/lib/process"
)
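// Command describes a single command defined in sg.config.yaml: the shell
// command to run, how to install it, and which paths to watch for restarts.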
type Command struct {
Config SGConfigCommandOptions
Cmd string `yaml:"cmd"`
DefaultArgs string `yaml:"defaultArgs"`
Install string `yaml:"install"`
InstallFunc string `yaml:"install_func"`
CheckBinary string `yaml:"checkBinary"`
Watch []string `yaml:"watch"`
// ATTENTION: If you add a new field here, be sure to also handle that
// field in `Merge` (below).
}
// UnmarshalYAML implements the Unmarshaler interface for Command.
// This allows us to parse the flat YAML configuration into a nested struct.
func (cmd *Command) UnmarshalYAML(unmarshal func(any) error) error {
// In order to not recurse infinitely (calling UnmarshalYAML over and over) we create a
// temporary type alias.
// First parse the Command specific options
type rawCommand Command
if err := unmarshal((*rawCommand)(cmd)); err != nil {
return err
}
// Then parse the common options from the same list into a nested struct
return unmarshal(&cmd.Config)
}
func (cmd Command) GetConfig() SGConfigCommandOptions {
return cmd.Config
}
func (cmd Command) UpdateConfig(f func(*SGConfigCommandOptions)) SGConfigCommand {
f(&cmd.Config)
return cmd
}
func (cmd Command) GetName() string {
return cmd.Config.Name
}
func (cmd Command) GetBinaryLocation() (string, error) {
if cmd.CheckBinary != "" {
return filepath.Join(cmd.Config.RepositoryRoot, cmd.CheckBinary), nil
}
return "", noBinaryError{name: cmd.Config.Name}
}
func (cmd Command) GetBazelTarget() string {
return ""
}
func (cmd Command) GetExecCmd(ctx context.Context) (*exec.Cmd, error) {
return exec.CommandContext(ctx, "bash", "-c", cmd.Cmd), nil
}
func (cmd Command) RunInstall(ctx context.Context, parentEnv map[string]string) error {
if cmd.requiresInstall() {
if cmd.hasBashInstaller() {
return cmd.bashInstall(ctx, parentEnv)
} else {
return cmd.functionInstall(ctx, parentEnv)
}
}
return nil
}
// Standard commands ignore installer
func (cmd Command) SetInstallerOutput(chan<- output.FancyLine) {}
func (cmd Command) Count() int {
return 1
}
func (cmd Command) requiresInstall() bool {
return cmd.Install != "" || cmd.InstallFunc != ""
}
func (cmd Command) hasBashInstaller() bool {
return cmd.Install != "" || cmd.InstallFunc == ""
}
func (cmd Command) bashInstall(ctx context.Context, parentEnv map[string]string) error {
output, err := BashInRoot(ctx, cmd.Install, makeEnv(parentEnv, cmd.Config.Env))
if err != nil {
return installErr{cmdName: cmd.Config.Name, output: output, originalErr: err}
}
return nil
}
func (cmd Command) functionInstall(ctx context.Context, parentEnv map[string]string) error {
fn, ok := installFuncs[cmd.InstallFunc]
if !ok {
return installErr{cmdName: cmd.Config.Name, originalErr: errors.Newf("no install func with name %q found", cmd.InstallFunc)}
}
if err := fn(ctx, makeEnvMap(parentEnv, cmd.Config.Env)); err != nil {
return installErr{cmdName: cmd.Config.Name, originalErr: err}
}
return nil
}
func (cmd Command) getWatchPaths() []string {
fullPaths := make([]string, len(cmd.Watch))
for i, path := range cmd.Watch {
fullPaths[i] = filepath.Join(cmd.Config.RepositoryRoot, path)
}
return fullPaths
}
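// StartWatch starts watching the command's watch paths and returns a channel
// that receives a signal whenever one of them changes.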
func (cmd Command) StartWatch(ctx context.Context) (<-chan struct{}, error) {
return WatchPaths(ctx, cmd.getWatchPaths())
}
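// Merge returns a copy of c in which any non-empty fields of other override
// the corresponding fields of c.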
func (c Command) Merge(other Command) Command {
merged := c
merged.Config = c.Config.Merge(other.Config)
merged.Cmd = mergeStrings(c.Cmd, other.Cmd)
merged.Install = mergeStrings(c.Install, other.Install)
merged.InstallFunc = mergeStrings(c.InstallFunc, other.InstallFunc)
merged.Watch = mergeSlices(c.Watch, other.Watch)
return merged
}
func mergeStrings(a, b string) string {
if b != "" {
return b
}
return a
}
func mergeSlices[T any](a, b []T) []T {
if len(b) > 0 {
return b
}
return a
}
// mergeMaps properly merges the two maps, as opposed to the other merge
// helpers, which simply overwrite the first value with the second.
// This is to preserve the behavior of the original code.
func mergeMaps[K comparable, V any](a, b map[K]V) map[K]V {
if a == nil {
return b
}
for k, v := range b {
a[k] = v
}
return a
}
func equal(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
var sgConn net.Conn
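// OpenUnixSocket connects to the sg unix socket at /tmp/sg.sock. Once
// connected, command output is also streamed to that socket.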
func OpenUnixSocket() error {
var err error
sgConn, err = net.Dial("unix", "/tmp/sg.sock")
return err
}
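// getSecrets resolves the external secrets configured for the named command
// into an environment map. Failures to fetch individual secrets are collected
// into the returned error instead of aborting the whole lookup.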
func getSecrets(ctx context.Context, name string, extSecrets map[string]secrets.ExternalSecret) (map[string]string, error) {
secretsEnv := map[string]string{}
if len(extSecrets) == 0 {
return secretsEnv, nil
}
secretsStore, err := secrets.FromContext(ctx)
if err != nil {
return nil, errors.Errorf("failed to get secrets store: %v", err)
}
var errs error
for envName, secret := range extSecrets {
secretsEnv[envName], err = secretsStore.GetExternal(ctx, secret)
if err != nil {
errs = errors.Append(errs,
errors.Wrapf(err, "failed to access secret %q for command %q", envName, name))
}
}
return secretsEnv, errs
}
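// startedCmd wraps a running exec.Cmd together with its output plumbing,
// cancellation function, and exit state.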
type startedCmd struct {
*exec.Cmd
opts commandOptions
cancel func()
outEg *pool.ErrorPool
result chan error
finished bool
}
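// commandOptions collects everything needed to start a command: the exec.Cmd,
// working directory, environment, and per-stream output options.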
type commandOptions struct {
name string
exec *exec.Cmd
dir string
env []string
stdout outputOptions
stderr outputOptions
}
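// outputOptions controls how a single output stream (stdout or stderr) of a
// started command is handled.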
type outputOptions struct {
// When true, output will be ignored and not written to any writers
ignore bool
// When non-nil, all output will be flushed to this file and not to the terminal
logfile io.Writer
// When enabled, output will not be streamed to the writers until streaming
// is explicitly started (see StartOutput); until then it is only
// captured for later retrieval
buffer bool
// Buffer that captures the output for error logging
captured io.ReadWriter
// Additional writers to write output to
additionalWriters []io.Writer
// Channel that is used to signal that output should start streaming
// when buffer is true
start chan struct{}
}
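// startSgCmd prepares an SGConfigCommand for execution: it builds the exec.Cmd,
// resolves external secrets and the environment, sets up the optional log file,
// prints the preamble, and starts the command.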
func startSgCmd(ctx context.Context, cmd SGConfigCommand, parentEnv map[string]string) (*startedCmd, error) {
exec, err := cmd.GetExecCmd(ctx)
if err != nil {
return nil, err
}
conf := cmd.GetConfig()
secretsEnv, err := getSecrets(ctx, conf.Name, conf.ExternalSecrets)
if err != nil {
std.Out.WriteLine(output.Styledf(output.StyleWarning, "[%s] %s %s",
conf.Name, output.EmojiFailure, err.Error()))
}
opts := commandOptions{
name: conf.Name,
exec: exec,
env: makeEnv(parentEnv, secretsEnv, conf.Env),
dir: conf.RepositoryRoot,
stdout: outputOptions{ignore: conf.IgnoreStdout},
stderr: outputOptions{ignore: conf.IgnoreStderr},
}
if conf.Logfile != "" {
if logfile, err := initLogFile(conf.Logfile); err != nil {
return nil, err
} else {
opts.stdout.logfile = logfile
opts.stderr.logfile = logfile
}
}
if conf.Preamble != "" {
std.Out.WriteLine(output.Styledf(output.StyleOrange, "[%s] %s %s", conf.Name, output.EmojiInfo, conf.Preamble))
}
return startCmd(ctx, opts)
}
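// initLogFile expands a leading ~/ or $HOME in the logfile path, creates the
// parent directory if needed, and creates (truncating) the log file itself.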
func initLogFile(logfile string) (io.Writer, error) {
if strings.HasPrefix(logfile, "~/") || strings.HasPrefix(logfile, "$HOME") {
home, err := os.UserHomeDir()
if err != nil {
return nil, errors.Wrap(err, "failed to get user home directory")
}
logfile = filepath.Join(home, strings.Replace(strings.Replace(logfile, "~/", "", 1), "$HOME", "", 1))
}
parent := filepath.Dir(logfile)
if err := os.MkdirAll(parent, os.ModePerm); err != nil {
return nil, err
}
// we don't have to worry about the file existing already and growing large, since this will truncate the file if it exists
return os.Create(logfile)
}
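// startCmd starts the command in its own process group, connects its output,
// and registers an interrupt handler that shuts the whole group down with
// SIGINT so child processes are not orphaned.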
func startCmd(ctx context.Context, opts commandOptions) (*startedCmd, error) {
sc := &startedCmd{
opts: opts,
}
ctx, cancel := context.WithCancel(ctx)
sc.cancel = func() {
// The default cancel function will use a SIGKILL (9) which does
// not allow processes to cleanup. If they have spawned child processes
// those child processes will be orphaned and continue running.
// SIGINT will instead gracefully shut down the process and child processes.
if sc.Cmd.Process != nil {
// We created a process group above which we kill here.
pgid, err := syscall.Getpgid(sc.Cmd.Process.Pid)
if err != nil {
// Ignore Errno 3 (No such process) as this means the process has already exited
if !errors.Is(err, syscall.Errno(0x3)) {
panic(errors.Wrapf(err, "failed to get process group ID for %s (PID %d)", sc.opts.name, sc.Cmd.Process.Pid))
}
} else {
// note the minus sign; this signals that we want to kill the whole process group
if err := syscall.Kill(-pgid, syscall.SIGINT); err != nil {
panic(errors.Wrapf(err, "failed kill process group ID %d for cmd %s ", pgid, sc.opts.name))
}
<-sc.Exit()
}
}
cancel()
}
// Register an interrupt handler
interrupt.RegisterConcurrent(sc.cancel)
sc.Cmd = opts.exec
sc.Cmd.Dir = opts.dir
sc.Cmd.Env = opts.env
// This sets up a process group which we kill later.
// This allows us to ensure that any child processes are killed as well when this exits
// This will only work on POSIX systems
sc.Cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
if err := sc.connectOutput(ctx); err != nil {
sc.cancel()
return nil, err
}
if err := sc.Start(); err != nil {
sc.cancel()
return nil, err
}
return sc, nil
}
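// connectOutput pipes the command's stdout and stderr into the writers
// configured for each stream.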
func (sc *startedCmd) connectOutput(ctx context.Context) error {
stdoutWriter := sc.getOutputWriter(ctx, &sc.opts.stdout, "stdout")
stderrWriter := sc.getOutputWriter(ctx, &sc.opts.stderr, "stderr")
eg, err := process.PipeOutputUnbuffered(ctx, sc.Cmd, stdoutWriter, stderrWriter)
if err != nil {
return err
}
sc.outEg = eg
return nil
}
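// getOutputWriter builds the writer for a single output stream (stdout or
// stderr), honoring the ignore, logfile, and buffer options, capturing output
// for error reporting, and mirroring it to the sg socket when connected.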
func (sc *startedCmd) getOutputWriter(ctx context.Context, opts *outputOptions, outputName string) io.Writer {
writers := opts.additionalWriters
if writers == nil {
writers = []io.Writer{}
}
if opts.captured == nil {
opts.captured = &prefixSuffixSaver{N: 32 << 10}
}
writers = append(writers, opts.captured)
if opts.ignore {
std.Out.WriteLine(output.Styledf(output.StyleSuggestion, "Ignoring %s of %s", outputName, sc.opts.name))
} else if opts.logfile != nil {
return opts.logfile
} else {
// Create a channel to signal when output should start. If buffering is disabled, close
// the channel so output starts immediately.
opts.start = make(chan struct{})
if !opts.buffer {
close(opts.start)
}
writers = append(writers, newOutputPipe(ctx, sc.opts.name, std.Out.Output, opts.start))
}
if sgConn != nil {
w, stream := pipe.NewStream()
go func() {
err := stream.Stream(func(line string) {
_, _ = sgConn.Write([]byte(fmt.Sprintf("%s: %s\n", sc.opts.name, line)))
})
_ = w.CloseWithError(err)
}()
go func() {
<-ctx.Done()
_ = w.CloseWithError(ctx.Err())
}()
writers = append(writers, w)
}
return io.MultiWriter(writers...)
}
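// Exit returns a channel that receives the command's exit error. Once the
// command has finished executing it instead returns a channel that never
// receives, so callers don't keep restarting short-running commands.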
func (sc *startedCmd) Exit() <-chan error {
// We track the state of a single process to avoid an infinite loop
// for short-running commands. When the command is done executing,
// we simply return an empty receiver channel instead.
if sc.finished {
fakeChan := make(<-chan error)
return fakeChan
}
if sc.result == nil {
sc.result = make(chan error)
go func() {
sc.result <- sc.Wait()
close(sc.result)
}()
}
return sc.result
}
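// Wait blocks until the command and its output pipes have finished, marks the
// command as finished, and wraps non-zero exits in a runErr that includes the
// captured stdout and stderr.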
func (sc *startedCmd) Wait() error {
err := sc.wait()
// We are certain that the command is done executing at this point.
sc.finished = true
var e *exec.ExitError
if errors.As(err, &e) {
err = runErr{
cmdName: sc.opts.name,
exitCode: e.ExitCode(),
stderr: sc.CapturedStderr(),
stdout: sc.CapturedStdout(),
}
}
return err
}
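// mockStartedCmdWaitFunc, when set, replaces the behavior of wait in tests.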
var mockStartedCmdWaitFunc func() error
func (sc *startedCmd) wait() error {
if mockStartedCmdWaitFunc != nil {
return mockStartedCmdWaitFunc()
}
if err := sc.outEg.Wait(); err != nil {
return err
}
return sc.Cmd.Wait()
}
func (sc *startedCmd) CapturedStdout() string {
return captured(sc.opts.stdout)
}
func (sc *startedCmd) CapturedStderr() string {
return captured(sc.opts.stderr)
}
func captured(opts outputOptions) string {
if opts.captured == nil {
return ""
}
if output, err := io.ReadAll(opts.captured); err == nil {
return string(output)
}
return ""
}
// Begins writing output to StdOut and StdErr if it was previously buffered
func (sc *startedCmd) StartOutput() {
sc.startOutput(sc.opts.stdout)
sc.startOutput(sc.opts.stderr)
}
func (sc *startedCmd) startOutput(opts outputOptions) {
if opts.buffer && opts.start != nil {
close(opts.start)
}
}
// patternMatcher is a writer which looks for a regular expression in the
// written bytes and calls a callback if a match is found.
// By default it only looks for the matched pattern once.
type patternMatcher struct {
regex *regexp.Regexp
callback func()
buffer bytes.Buffer
multi bool
disabled bool
}
func (writer *patternMatcher) Write(p []byte) (int, error) {
if writer.disabled {
return len(p), nil
}
n, err := writer.buffer.Write(p)
if err != nil {
return n, err
}
if writer.regex.MatchReader(&writer.buffer) {
writer.callback()
if !writer.multi {
writer.disabled = true
}
}
return n, err
}