diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..d4369ff --- /dev/null +++ b/errors.go @@ -0,0 +1,73 @@ +package extism + +import "bytes" + +// errPrefix is a sentinel byte sequence used to identify errors originating from host functions. +// It helps distinguish these errors when serialized to bytes. +var errPrefix = []byte{0xFF, 0xFE, 0xFD} + +// hostFuncError wraps another error and identifies it as a host function error. +// When a host function is called and that host function wants to return an error, +// internally extism will wrap that error in this type before serializing the error +// using the bytes method, and writing the error into WASM memory so that the guest +// can read the error. +// +// The bytes method appends a set of sentinel bytes which the host can later read +// when calls `error_get` to see if the error that was previously set was set by +// the host or the guest. If we see the matching sentinel bytes in the prefix of +// the error bytes, then we know that the error was a host function error, and the +// host can ignore it. +// +// The purpose of this is to allow us to piggyback off the existing `error_get` and +// `error_set` extism kernel functions. These previously were only used by guests to +// communicate errors to the host. In order to prevent host plugin function calls from +// seeing their own host function errors, the plugin can check and see if the error +// was created via a host function using this type. +// +// This is an effort to preserve backwards compatibility with existing PDKs which +// may not know to call `error_get` to see if there are any host->guest errors. We +// need the host SDKs to handle the scenario where the host calls `error_set` but +// the guest never calls `error_get` resulting in the host seeing their own error. +type hostFuncError struct { + inner error // The underlying error being wrapped. +} + +// Error implements the error interface for hostFuncError. +// It returns the message of the wrapped error or an empty string if there is no inner error. +func (e *hostFuncError) Error() string { + if e.inner == nil { + return "" + } + return e.inner.Error() +} + +// bytes serializes the hostFuncError into a byte slice. +// If there is no inner error, it returns nil. Otherwise, it prefixes the error message +// with a sentinel byte sequence to facilitate identification during deserialization. +func (e *hostFuncError) bytes() []byte { + if e.inner == nil { + return nil + } + return append(errPrefix, []byte(e.inner.Error())...) +} + +// isHostFuncError checks if the given byte slice represents a serialized host function error. +// It verifies the presence of the sentinel prefix to make this determination. +func isHostFuncError(error []byte) bool { + if error == nil { + return false + } + if len(error) < len(errPrefix) { + return false // The slice is too short to contain the prefix. + } + return bytes.Equal(error[:len(errPrefix)], errPrefix) +} + +// newHostFuncError creates a new hostFuncError instance wrapping the provided error. +// If the input error is nil, it returns nil to avoid creating redundant wrappers. +func newHostFuncError(err error) *hostFuncError { + if err == nil { + return nil + } + return &hostFuncError{inner: err} +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..d57111b --- /dev/null +++ b/errors_test.go @@ -0,0 +1,120 @@ +package extism + +import ( + "bytes" + "errors" + "testing" +) + +func TestNewHostFuncError(t *testing.T) { + tests := []struct { + name string + inputErr error + wantNil bool + }{ + { + name: "nil error input", + inputErr: nil, + wantNil: true, + }, + { + name: "non-nil error input", + inputErr: errors.New("test error"), + wantNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := newHostFuncError(tt.inputErr) + if (err == nil) != tt.wantNil { + t.Errorf("got nil: %v, want nil: %v", err == nil, tt.wantNil) + } + }) + } +} + +func TestBytes(t *testing.T) { + tests := []struct { + name string + inputErr error + wantPrefix []byte + wantMsg string + wantNil bool + }{ + { + name: "nil inner error", + inputErr: nil, + wantPrefix: nil, + wantMsg: "", + wantNil: true, + }, + { + name: "non-nil inner error", + inputErr: errors.New("test error"), + wantPrefix: errPrefix, + wantMsg: "test error", + wantNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &hostFuncError{inner: tt.inputErr} + b := e.bytes() + + if tt.wantNil { + if b != nil { + t.Errorf("expected nil, got %x", b) + } + return + } + + if len(b) < len(tt.wantPrefix) { + t.Fatalf("returned bytes too short, got %x, want prefix %x", b, tt.wantPrefix) + } + + if !bytes.HasPrefix(b, tt.wantPrefix) { + t.Errorf("expected prefix %x, got %x", tt.wantPrefix, b[:len(tt.wantPrefix)]) + } + + gotMsg := string(b[len(tt.wantPrefix):]) + if gotMsg != tt.wantMsg { + t.Errorf("expected message %q, got %q", tt.wantMsg, gotMsg) + } + }) + } +} + +func TestIsHostFuncError(t *testing.T) { + tests := []struct { + name string + inputErr []byte + want bool + }{ + { + name: "nil error input", + inputErr: nil, + want: false, + }, + { + name: "not a hostFuncError", + inputErr: []byte("normal error"), + want: false, + }, + { + name: "valid hostFuncError", + inputErr: newHostFuncError(errors.New("host function error")).bytes(), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isHostFuncError(tt.inputErr) + if got != tt.want { + t.Errorf("isHostFuncError(%v) = %v, want %v", tt.inputErr, got, tt.want) + } + }) + } +} diff --git a/extism.go b/extism.go index 1e8c120..0ec4fdd 100644 --- a/extism.go +++ b/extism.go @@ -110,12 +110,9 @@ type Plugin struct { close []func(ctx context.Context) error extism api.Module - //Runtime *Runtime - //Main Module - module api.Module - Timeout time.Duration - Config map[string]string - // NOTE: maybe we can have some nice methods for getting/setting vars + module api.Module + Timeout time.Duration + Config map[string]string Var map[string][]byte AllowedHosts []string AllowedPaths map[string]string @@ -435,6 +432,15 @@ func (p *Plugin) GetErrorWithContext(ctx context.Context) string { } mem, _ := p.Memory().Read(uint32(errOffs[0]), uint32(errLen[0])) + + // A host function error is an error set by a host function during a guest->host function + // call. These errors are intended to be handled only by the guest. If the error makes it + // back here, the guest PDK most likely doesn't know to handle it, in which case we should + // ignore it here. + if isHostFuncError(mem) { + return "" + } + return string(mem) } diff --git a/extism_test.go b/extism_test.go index 38b3584..6135b9f 100644 --- a/extism_test.go +++ b/extism_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" observe "github.com/dylibso/observe-sdk/go" "github.com/dylibso/observe-sdk/go/adapter/stdout" @@ -1038,6 +1039,54 @@ func TestEnableExperimentalFeature(t *testing.T) { } } +// This test creates host functions that set errors. Previously, host functions +// would have to panic to communicate host function errors, but this unfortunately +// stopped execution of the guest function. In other words, there was no way to +// gracefully communicate errors from host->guest when the guest called a host +// function. This has since been fixed and this test proves that even when guests +// don't reset the error state, the host can still determine that the current error +// state was a host->guest error and not a guest->host error and ignores it. +func TestHostFunctionError(t *testing.T) { + manifest := manifest("host_multiple.wasm") + + hostGreenMessage := NewHostFunctionWithStack( + "hostGreenMessage", + func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) { + plugin.SetError(ctx, errors.New("this is an error")) + }, + []ValueType{ValueTypePTR}, + []ValueType{ValueTypePTR}, + ) + hostPurpleMessage := NewHostFunctionWithStack( + "hostPurpleMessage", + func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) { + plugin.SetError(ctx, errors.New("this is an error")) + }, + []ValueType{ValueTypePTR}, + []ValueType{ValueTypePTR}, + ) + + ctx := context.Background() + p, err := NewCompiledPlugin(ctx, manifest, PluginConfig{ + EnableWasi: true, + }, []HostFunction{ + hostGreenMessage, + hostPurpleMessage, + }) + require.NoError(t, err) + + pluginInst, err := p.Instance(ctx, PluginInstanceConfig{ + ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(), + }) + require.NoError(t, err) + + _, _, err = pluginInst.Call( + "say_green", + []byte("John Doe"), + ) + require.NoError(t, err, "the host function should have returned an error to the guest but it should not have propagated back to the host") +} + func BenchmarkInitialize(b *testing.B) { ctx := context.Background() cache := wazero.NewCompilationCache() diff --git a/host.go b/host.go index a32b456..9c85136 100644 --- a/host.go +++ b/host.go @@ -109,6 +109,24 @@ func (p *Plugin) currentPlugin() *CurrentPlugin { return &CurrentPlugin{p} } +// SetError allows the host function to set an error that will be +// gracefully returned by extism guest modules. +func (p *CurrentPlugin) SetError(ctx context.Context, err error) { + if err == nil { + return + } + + offset, err := p.WriteBytes(newHostFuncError(err).bytes()) + if err != nil { + panic(fmt.Sprintf("failed to write error message to memory: %v", err)) + } + + _, err = p.plugin.extism.ExportedFunction("error_set").Call(ctx, offset) + if err != nil { + panic(fmt.Sprintf("failed to set error: %v", err)) + } +} + func (p *CurrentPlugin) Log(level LogLevel, message string) { p.plugin.Log(level, message) } @@ -306,7 +324,7 @@ func instantiateEnvModule(ctx context.Context, rt wazero.Runtime) (api.Module, e WithGoModuleFunction(api.GoModuleFunc(store_u64), []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{}). Export("store_u64") - hostFunc := func(name string, f interface{}) { + hostFunc := func(name string, f any) { builder.NewFunctionBuilder().WithFunc(f).Export(name) } diff --git a/plugin.go b/plugin.go index c1139ab..99f8673 100644 --- a/plugin.go +++ b/plugin.go @@ -2,6 +2,7 @@ package extism import ( "context" + "encoding/base64" "errors" "fmt" observe "github.com/dylibso/observe-sdk/go" @@ -220,7 +221,12 @@ func (p *CompiledPlugin) Instance(ctx context.Context, config PluginInstanceConf if moduleConfig == nil { moduleConfig = wazero.NewModuleConfig() } - moduleConfig = moduleConfig.WithName(strconv.Itoa(int(p.instanceCount.Add(1)))) + moduleConfig = moduleConfig. + WithName(strconv.Itoa(int(p.instanceCount.Add(1)))). + // We can tell the guest module what the error prefix will be for errors that are set + // by host functions. Guests should trim this prefix off of their error messages when + // reading them. + WithEnv("EXTISM_HOST_FUNC_ERROR_PREFIX", base64.StdEncoding.EncodeToString(errPrefix)) // NOTE: this is only necessary for guest modules because // host modules have the same access privileges as the host itself