Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ script.Exec("ping 127.0.0.1").Stdout()

Note that `Exec` runs the command concurrently: it doesn't wait for the command to complete before returning any output. That's good, because this `ping` command will run forever (or until we get bored).

If you need to prevent a command from running indefinitely, you can set a timeout:

```go
script.NewPipe().WithTimeout(5*time.Second).Exec("potentially-hanging-command").Wait()
```

If the command doesn't complete within the timeout, it will be terminated and an error will be set.

Instead, when we read from the pipe using `Stdout`, we see each line of output as it's produced:

```
Expand Down Expand Up @@ -328,6 +336,7 @@ These are methods on a pipe that change its configuration:
| [`WithReader`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithReader) | pipe source |
| [`WithStderr`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithStderr) | standard error output stream for command |
| [`WithStdout`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithStdout) | standard output stream for pipe |
| [`WithTimeout`](https://pkg.go.dev/github.com/bitfield/script#Pipe.WithTimeout) | timeout for exec commands |

## Filters

Expand Down
82 changes: 72 additions & 10 deletions script.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package script
import (
"bufio"
"container/ring"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
Expand All @@ -22,6 +23,7 @@ import (
"strings"
"sync"
"text/template"
"time"

"github.com/itchyny/gojq"
"mvdan.cc/sh/v3/shell"
Expand All @@ -34,10 +36,11 @@ type Pipe struct {
stdout io.Writer
httpClient *http.Client

mu *sync.Mutex
err error
stderr io.Writer
env []string
mu *sync.Mutex
err error
stderr io.Writer
env []string
timeout time.Duration
}

// Args creates a pipe containing the program's command-line arguments from
Expand Down Expand Up @@ -390,6 +393,15 @@ func (p *Pipe) environment() []string {
return p.env
}

func (p *Pipe) getTimeout() time.Duration {
if p.mu == nil {
return 0
}
p.mu.Lock()
defer p.mu.Unlock()
return p.timeout
}

// Error returns any error present on the pipe, or nil otherwise.
// Error is not a sink and does not wait until the pipe reaches
// completion. To wait for completion before returning the error,
Expand All @@ -413,6 +425,11 @@ func (p *Pipe) Error() error {
// The command inherits the current process's environment, optionally modified
// by [Pipe.WithEnv].
//
// # Timeout
//
// If a timeout has been set using [Pipe.WithTimeout], the command will be
// terminated if it does not complete within the specified duration.
//
// # Error handling
//
// If the command had a non-zero exit status, the pipe's error status will also
Expand All @@ -432,7 +449,19 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {
if err != nil {
return err
}
cmd := exec.Command(args[0], args[1:]...)

timeout := p.getTimeout()
var cmd *exec.Cmd

if timeout > 0 {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

cmd = exec.CommandContext(ctx, args[0], args[1:]...)
} else {
cmd = exec.Command(args[0], args[1:]...)
}

cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
Expand All @@ -455,8 +484,8 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {

// ExecForEach renders cmdLine as a Go template for each line of input, running
// the resulting command, and produces the combined output of all these
// commands in sequence. See [Pipe.Exec] for details on error handling and
// environment variables.
// commands in sequence. See [Pipe.Exec] for details on error handling,
// environment variables, and timeout behavior.
//
// This is mostly useful for substituting data into commands using Go template
// syntax. For example:
Expand All @@ -468,6 +497,7 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
return p.WithError(err)
}
return p.Filter(func(r io.Reader, w io.Writer) error {
timeout := p.getTimeout()
scanner := newScanner(r)
for scanner.Scan() {
cmdLine := new(strings.Builder)
Expand All @@ -479,7 +509,18 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
if err != nil {
return err
}
cmd := exec.Command(args[0], args[1:]...)

var cmd *exec.Cmd

if timeout > 0 {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

cmd = exec.CommandContext(ctx, args[0], args[1:]...)
} else {
cmd = exec.Command(args[0], args[1:]...)
}

cmd.Stdout = w
cmd.Stderr = w
pipeStderr := p.stdErr()
Expand All @@ -497,6 +538,11 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
err = cmd.Wait()
if err != nil {
fmt.Fprintln(cmd.Stderr, err)
// For timeout errors, we should return them instead of continuing
if strings.Contains(err.Error(), "context deadline exceeded") ||
strings.Contains(err.Error(), "signal: killed") {
return err
}
continue
}
}
Expand Down Expand Up @@ -539,7 +585,7 @@ func (p *Pipe) Filter(filter func(io.Reader, io.Writer) error) *Pipe {
}
pr, pw := io.Pipe()
origReader := p.Reader
p = p.WithReader(pr)
p = p.WithReader(pr).WithTimeout(p.getTimeout())
go func() {
defer pw.Close()
err := filter(origReader, pw)
Expand Down Expand Up @@ -720,7 +766,7 @@ func (p *Pipe) Join() *Pipe {
// The exact dialect of JQ supported is that provided by
// [github.com/itchyny/gojq], whose documentation explains the differences
// between it and standard JQ.
//
//
// [JSONLines]: https://jsonlines.org/
func (p *Pipe) JQ(query string) *Pipe {
parsedQuery, err := gojq.Parse(query)
Expand Down Expand Up @@ -971,6 +1017,22 @@ func (p *Pipe) WithEnv(env []string) *Pipe {
return p
}

// WithTimeout sets a timeout for subsequent [Pipe.Exec] and [Pipe.ExecForEach]
// commands. If the command does not complete within the specified duration, it will be
// terminated and the pipe's error status will be set.
//
// A zero or negative duration means no timeout (the default).
//
// Example:
//
// script.Get("https://httpbin.org/delay/5").WithTimeout(500 * time.Millisecond).Wait()
func (p *Pipe) WithTimeout(timeout time.Duration) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()
p.timeout = timeout
return p
}

// WithError sets the error err on the pipe.
func (p *Pipe) WithError(err error) *Pipe {
p.SetError(err)
Expand Down
145 changes: 145 additions & 0 deletions script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"strings"
"testing"
"testing/iotest"
"time"

"github.com/bitfield/script"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -1265,6 +1266,150 @@ func TestExecRunsGoHelpAndGetsUsageMessage(t *testing.T) {
}
}

func TestExecWithTimeout_AllowsCommandToCompleteWithinTimeout(t *testing.T) {
t.Parallel()
_, err := script.NewPipe().WithTimeout(5 * time.Second).Exec("echo test").String()
if err != nil {
t.Errorf("unexpected error for command within timeout: %v", err)
}
}

func TestExecWithTimeout_NoTimeoutWhenNotSet(t *testing.T) {
t.Parallel()
_, err := script.Exec("echo test").String()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}

func TestExecWithTimeout_ZeroTimeoutMeansNoTimeout(t *testing.T) {
t.Parallel()
_, err := script.NewPipe().WithTimeout(0).Exec("echo test").String()
if err != nil {
t.Errorf("unexpected error with zero timeout: %v", err)
}
}

func TestExecWithTimeout_TerminatesLongRunningCommand(t *testing.T) {
t.Parallel()
p := script.NewPipe().WithTimeout(100 * time.Millisecond).Exec("sleep 2")
err := p.Wait()
if err == nil {
t.Error("want error for command exceeding timeout")
}

if !strings.Contains(err.Error(), "context deadline exceeded") &&
!strings.Contains(err.Error(), "signal: killed") {
t.Errorf("want context deadline exceeded or killed error, got: %v", err)
}
}

func TestExecForEachWithTimeout_RespectsTimeout(t *testing.T) {
t.Parallel()
p := script.Echo("test").WithTimeout(100 * time.Millisecond).ExecForEach("sleep 2")
err := p.Wait()
if err == nil {
t.Error("want error for ExecForEach exceeding timeout")
return
}
if !strings.Contains(err.Error(), "context deadline exceeded") &&
!strings.Contains(err.Error(), "signal: killed") {
t.Errorf("want context deadline exceeded or killed error, got: %v", err)
}
}

func TestExecForEachWithTimeout_AllowsQuickCommands(t *testing.T) {
t.Parallel()
output, err := script.Echo("hello\nworld").WithTimeout(5 * time.Second).ExecForEach("echo {{.}}").String()
if err != nil {
t.Errorf("unexpected error for quick commands: %v", err)
}
expected := "hello\nworld\n"
if output != expected {
t.Errorf("want %q, got %q", expected, output)
}
}

func TestExecForEachWithTimeout_MultipleCommands(t *testing.T) {
t.Parallel()
input := "cmd1\ncmd2\ncmd3"
p := script.Echo(input).WithTimeout(50 * time.Millisecond).ExecForEach("sleep 1")
err := p.Wait()
if err == nil {
t.Error("want error for multiple commands exceeding timeout")
}
}

func TestExecForEachWithTimeout_NoTimeout(t *testing.T) {
t.Parallel()
output, err := script.Echo("test").ExecForEach("echo {{.}}").String()
if err != nil {
t.Errorf("unexpected error without timeout: %v", err)
}
expected := "test\n"
if output != expected {
t.Errorf("want %q, got %q", expected, output)
}
}

func TestExecForEachWithTimeout_ZeroTimeout(t *testing.T) {
t.Parallel()
output, err := script.Echo("test").WithTimeout(0).ExecForEach("echo {{.}}").String()
if err != nil {
t.Errorf("unexpected error with zero timeout: %v", err)
}
expected := "test\n"
if output != expected {
t.Errorf("want %q, got %q", expected, output)
}
}

func TestExecForEachWithTimeout_DifferentCommands(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
command string
timeout time.Duration
shouldTimeout bool
}{
{
name: "echo command with long timeout",
command: "echo {{.}}",
timeout: 5 * time.Second,
shouldTimeout: false,
},
{
name: "sleep command with short timeout",
command: "sleep 2",
timeout: 100 * time.Millisecond,
shouldTimeout: true,
},
{
name: "ping command with short timeout",
command: "ping -c 5 127.0.0.1",
timeout: 50 * time.Millisecond,
shouldTimeout: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
p := script.Echo("test").WithTimeout(tc.timeout).ExecForEach(tc.command)
err := p.Wait()

if tc.shouldTimeout {
if err == nil {
t.Error("want error for command that should timeout")
}
} else {
if err != nil {
t.Errorf("unexpected error for command that should not timeout: %v", err)
}
}
})
}
}

func TestFileOutputsContentsOfSpecifiedFile(t *testing.T) {
t.Parallel()
want := "This is the first line in the file.\nHello, world.\nThis is another line in the file.\n"
Expand Down