diff --git a/README.md b/README.md index 038e31e..4e91178 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,14 @@ script.Exec("ping 127.0.0.1").Stdout() Note that `Exec` runs the command concurrently: it doesn't wait for the command to complete before returning any output. That's good, because this `ping` command will run forever (or until we get bored). +If you need to prevent a command from running indefinitely, you can set a timeout: + +```go +script.NewPipe().WithTimeout(5*time.Second).Exec("potentially-hanging-command").Wait() +``` + +If the command doesn't complete within the timeout, it will be terminated and an error will be set. + Instead, when we read from the pipe using `Stdout`, we see each line of output as it's produced: ``` @@ -328,6 +336,7 @@ These are methods on a pipe that change its configuration: | [`WithReader`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithReader) | pipe source | | [`WithStderr`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithStderr) | standard error output stream for command | | [`WithStdout`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithStdout) | standard output stream for pipe | +| [`WithTimeout`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithTimeout) | timeout for exec commands | ## Filters diff --git a/script.go b/script.go index d7d1bc3..0905c64 100644 --- a/script.go +++ b/script.go @@ -3,6 +3,7 @@ package script import ( "bufio" "container/ring" + "context" "crypto/sha256" "encoding/base64" "encoding/hex" @@ -22,6 +23,7 @@ import ( "strings" "sync" "text/template" + "time" "github.com/itchyny/gojq" "mvdan.cc/sh/v3/shell" @@ -34,10 +36,11 @@ type Pipe struct { stdout io.Writer httpClient *http.Client - mu *sync.Mutex - err error - stderr io.Writer - env []string + mu *sync.Mutex + err error + stderr io.Writer + env []string + timeout time.Duration } // Args creates a pipe containing the program's command-line arguments from @@ -390,6 +393,15 @@ func (p *Pipe) environment() []string { return p.env } +func (p *Pipe) getTimeout() time.Duration { + if p.mu == nil { + return 0 + } + p.mu.Lock() + defer p.mu.Unlock() + return p.timeout +} + // Error returns any error present on the pipe, or nil otherwise. // Error is not a sink and does not wait until the pipe reaches // completion. To wait for completion before returning the error, @@ -413,6 +425,11 @@ func (p *Pipe) Error() error { // The command inherits the current process's environment, optionally modified // by [Pipe.WithEnv]. // +// # Timeout +// +// If a timeout has been set using [Pipe.WithTimeout], the command will be +// terminated if it does not complete within the specified duration. +// // # Error handling // // If the command had a non-zero exit status, the pipe's error status will also @@ -432,7 +449,19 @@ func (p *Pipe) Exec(cmdLine string) *Pipe { if err != nil { return err } - cmd := exec.Command(args[0], args[1:]...) + + timeout := p.getTimeout() + var cmd *exec.Cmd + + if timeout > 0 { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd = exec.CommandContext(ctx, args[0], args[1:]...) + } else { + cmd = exec.Command(args[0], args[1:]...) + } + cmd.Stdin = r cmd.Stdout = w cmd.Stderr = w @@ -455,8 +484,8 @@ func (p *Pipe) Exec(cmdLine string) *Pipe { // ExecForEach renders cmdLine as a Go template for each line of input, running // the resulting command, and produces the combined output of all these -// commands in sequence. See [Pipe.Exec] for details on error handling and -// environment variables. +// commands in sequence. See [Pipe.Exec] for details on error handling, +// environment variables, and timeout behavior. // // This is mostly useful for substituting data into commands using Go template // syntax. For example: @@ -468,6 +497,7 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe { return p.WithError(err) } return p.Filter(func(r io.Reader, w io.Writer) error { + timeout := p.getTimeout() scanner := newScanner(r) for scanner.Scan() { cmdLine := new(strings.Builder) @@ -479,7 +509,18 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe { if err != nil { return err } - cmd := exec.Command(args[0], args[1:]...) + + var cmd *exec.Cmd + + if timeout > 0 { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd = exec.CommandContext(ctx, args[0], args[1:]...) + } else { + cmd = exec.Command(args[0], args[1:]...) + } + cmd.Stdout = w cmd.Stderr = w pipeStderr := p.stdErr() @@ -497,6 +538,11 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe { err = cmd.Wait() if err != nil { fmt.Fprintln(cmd.Stderr, err) + // For timeout errors, we should return them instead of continuing + if strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "signal: killed") { + return err + } continue } } @@ -539,7 +585,7 @@ func (p *Pipe) Filter(filter func(io.Reader, io.Writer) error) *Pipe { } pr, pw := io.Pipe() origReader := p.Reader - p = p.WithReader(pr) + p = p.WithReader(pr).WithTimeout(p.getTimeout()) go func() { defer pw.Close() err := filter(origReader, pw) @@ -720,7 +766,7 @@ func (p *Pipe) Join() *Pipe { // The exact dialect of JQ supported is that provided by // [github.com/itchyny/gojq], whose documentation explains the differences // between it and standard JQ. -// +// // [JSONLines]: https://jsonlines.org/ func (p *Pipe) JQ(query string) *Pipe { parsedQuery, err := gojq.Parse(query) @@ -971,6 +1017,22 @@ func (p *Pipe) WithEnv(env []string) *Pipe { return p } +// WithTimeout sets a timeout for subsequent [Pipe.Exec] and [Pipe.ExecForEach] +// commands. If the command does not complete within the specified duration, it will be +// terminated and the pipe's error status will be set. +// +// A zero or negative duration means no timeout (the default). +// +// Example: +// +// script.Get("https://httpbin.org/delay/5").WithTimeout(500 * time.Millisecond).Wait() +func (p *Pipe) WithTimeout(timeout time.Duration) *Pipe { + p.mu.Lock() + defer p.mu.Unlock() + p.timeout = timeout + return p +} + // WithError sets the error err on the pipe. func (p *Pipe) WithError(err error) *Pipe { p.SetError(err) diff --git a/script_test.go b/script_test.go index 12f4d90..e181fec 100644 --- a/script_test.go +++ b/script_test.go @@ -18,6 +18,7 @@ import ( "strings" "testing" "testing/iotest" + "time" "github.com/bitfield/script" "github.com/google/go-cmp/cmp" @@ -1265,6 +1266,150 @@ func TestExecRunsGoHelpAndGetsUsageMessage(t *testing.T) { } } +func TestExecWithTimeout_AllowsCommandToCompleteWithinTimeout(t *testing.T) { + t.Parallel() + _, err := script.NewPipe().WithTimeout(5 * time.Second).Exec("echo test").String() + if err != nil { + t.Errorf("unexpected error for command within timeout: %v", err) + } +} + +func TestExecWithTimeout_NoTimeoutWhenNotSet(t *testing.T) { + t.Parallel() + _, err := script.Exec("echo test").String() + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestExecWithTimeout_ZeroTimeoutMeansNoTimeout(t *testing.T) { + t.Parallel() + _, err := script.NewPipe().WithTimeout(0).Exec("echo test").String() + if err != nil { + t.Errorf("unexpected error with zero timeout: %v", err) + } +} + +func TestExecWithTimeout_TerminatesLongRunningCommand(t *testing.T) { + t.Parallel() + p := script.NewPipe().WithTimeout(100 * time.Millisecond).Exec("sleep 2") + err := p.Wait() + if err == nil { + t.Error("want error for command exceeding timeout") + } + + if !strings.Contains(err.Error(), "context deadline exceeded") && + !strings.Contains(err.Error(), "signal: killed") { + t.Errorf("want context deadline exceeded or killed error, got: %v", err) + } +} + +func TestExecForEachWithTimeout_RespectsTimeout(t *testing.T) { + t.Parallel() + p := script.Echo("test").WithTimeout(100 * time.Millisecond).ExecForEach("sleep 2") + err := p.Wait() + if err == nil { + t.Error("want error for ExecForEach exceeding timeout") + return + } + if !strings.Contains(err.Error(), "context deadline exceeded") && + !strings.Contains(err.Error(), "signal: killed") { + t.Errorf("want context deadline exceeded or killed error, got: %v", err) + } +} + +func TestExecForEachWithTimeout_AllowsQuickCommands(t *testing.T) { + t.Parallel() + output, err := script.Echo("hello\nworld").WithTimeout(5 * time.Second).ExecForEach("echo {{.}}").String() + if err != nil { + t.Errorf("unexpected error for quick commands: %v", err) + } + expected := "hello\nworld\n" + if output != expected { + t.Errorf("want %q, got %q", expected, output) + } +} + +func TestExecForEachWithTimeout_MultipleCommands(t *testing.T) { + t.Parallel() + input := "cmd1\ncmd2\ncmd3" + p := script.Echo(input).WithTimeout(50 * time.Millisecond).ExecForEach("sleep 1") + err := p.Wait() + if err == nil { + t.Error("want error for multiple commands exceeding timeout") + } +} + +func TestExecForEachWithTimeout_NoTimeout(t *testing.T) { + t.Parallel() + output, err := script.Echo("test").ExecForEach("echo {{.}}").String() + if err != nil { + t.Errorf("unexpected error without timeout: %v", err) + } + expected := "test\n" + if output != expected { + t.Errorf("want %q, got %q", expected, output) + } +} + +func TestExecForEachWithTimeout_ZeroTimeout(t *testing.T) { + t.Parallel() + output, err := script.Echo("test").WithTimeout(0).ExecForEach("echo {{.}}").String() + if err != nil { + t.Errorf("unexpected error with zero timeout: %v", err) + } + expected := "test\n" + if output != expected { + t.Errorf("want %q, got %q", expected, output) + } +} + +func TestExecForEachWithTimeout_DifferentCommands(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + command string + timeout time.Duration + shouldTimeout bool + }{ + { + name: "echo command with long timeout", + command: "echo {{.}}", + timeout: 5 * time.Second, + shouldTimeout: false, + }, + { + name: "sleep command with short timeout", + command: "sleep 2", + timeout: 100 * time.Millisecond, + shouldTimeout: true, + }, + { + name: "ping command with short timeout", + command: "ping -c 5 127.0.0.1", + timeout: 50 * time.Millisecond, + shouldTimeout: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := script.Echo("test").WithTimeout(tc.timeout).ExecForEach(tc.command) + err := p.Wait() + + if tc.shouldTimeout { + if err == nil { + t.Error("want error for command that should timeout") + } + } else { + if err != nil { + t.Errorf("unexpected error for command that should not timeout: %v", err) + } + } + }) + } +} + func TestFileOutputsContentsOfSpecifiedFile(t *testing.T) { t.Parallel() want := "This is the first line in the file.\nHello, world.\nThis is another line in the file.\n"