diff --git a/dockermodelrunner/desktop/api.go b/dockermodelrunner/desktop/api.go new file mode 100644 index 00000000..801edaf9 --- /dev/null +++ b/dockermodelrunner/desktop/api.go @@ -0,0 +1,160 @@ +package desktop + +import ( + "encoding/json" + "fmt" + "time" +) + +// ProgressMessage represents a message sent during model pull operations +type ProgressMessage struct { + Type string `json:"type"` // "progress", "success", or "error" + Message string `json:"message"` // Human-readable message +} + +// OpenAIChatMessage represents a message sent during OpenAI chat operations +type OpenAIChatMessage struct { + // Role is the role of the message sender. + Role string `json:"role"` + + // Content is the content of the message. + Content string `json:"content"` +} + +// OpenAIChatRequest represents a request to the OpenAI chat API. +type OpenAIChatRequest struct { + // Model is the model to use for the chat. + Model string `json:"model"` + + // Messages is the list of messages to send to the chat. + Messages []OpenAIChatMessage `json:"messages"` + + // Stream is whether to stream the response. + Stream bool `json:"stream"` +} + +// OpenAIChatResponse represents a response from the OpenAI chat API. +type OpenAIChatResponse struct { + // ID is the ID of the chat. + ID string `json:"id"` + + // Object is the object type. + Object string `json:"object"` + + // Created is the creation time of the chat. + Created int64 `json:"created"` + + // Model is the model used for the chat. + Model string `json:"model"` + + // Choices is the list of choices from the chat. + Choices []struct { + // Delta is the delta of the choice. + Delta struct { + // Content is the content of the choice. + Content string `json:"content"` + + // Role is the role of the choice. + Role string `json:"role,omitempty"` + } `json:"delta"` + + // Index is the index of the choice. + Index int `json:"index"` + + // FinishReason is the reason the chat finished. 
+	FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +// OpenAIModel represents a model in the OpenAI API. +type OpenAIModel struct { + // ID is the ID of the model. + ID string `json:"id"` + + // Object is the object type. + Object string `json:"object"` + + // Created is the creation time of the model. + Created int64 `json:"created"` + + // OwnedBy is the owner of the model. + OwnedBy string `json:"owned_by"` +} + +// OpenAIModelList represents a list of models in the OpenAI API. +type OpenAIModelList struct { + // Object is the object type. + Object string `json:"object"` + + // Data is the list of models. + Data []*OpenAIModel `json:"data"` +} + +// Format represents the format of a model. +// TODO: To be replaced by the Model struct from pinata's common/pkg/inference/models/api.go. +// (https://github.com/docker/pinata/pull/33331) +type Format string + +// Config represents the configuration of a model. +type Config struct { + // Format is the format of the model. + Format Format `json:"format,omitempty"` + + // Quantization is the quantization of the model. + Quantization string `json:"quantization,omitempty"` + + // Parameters is the parameters of the model. + Parameters string `json:"parameters,omitempty"` + + // Architecture is the architecture of the model. + Architecture string `json:"architecture,omitempty"` + + // Size is the size of the model. + Size string `json:"size,omitempty"` +} + +// Model represents a model in the Docker Model Runner. +type Model struct { + // ID is the globally unique model identifier. + ID string `json:"id"` + + // Tags are the list of tags associated with the model. + Tags []string `json:"tags"` + + // Created is the Unix epoch timestamp corresponding to the model creation. + Created time.Time `json:"created"` + + // Config describes the model. + Config Config `json:"config"` +} + +// modelAlias is an alias for Model to avoid recursion in JSON marshaling/unmarshaling. 
+// This is necessary because we want Model to contain a time.Time field which is not directly +// compatible with JSON serialization/deserialization. +type modelAlias Model + +// modelResponseJSON is a struct used for JSON marshaling/unmarshaling of Model. +// It includes a Unix timestamp for the Created field to ensure compatibility with JSON. +type modelResponseJSON struct { + modelAlias + CreatedAt int64 `json:"created"` +} + +// UnmarshalJSON implements json.Unmarshaler. +func (mr *Model) UnmarshalJSON(b []byte) error { + var resp modelResponseJSON + if err := json.Unmarshal(b, &resp); err != nil { + return fmt.Errorf("unmarshal model response: %w", err) + } + *mr = Model(resp.modelAlias) + mr.Created = time.Unix(resp.CreatedAt, 0) + return nil +} + +// MarshalJSON implements json.Marshaler. +func (mr Model) MarshalJSON() ([]byte, error) { + return json.Marshal(modelResponseJSON{ + modelAlias: modelAlias(mr), + CreatedAt: mr.Created.Unix(), + }) +} diff --git a/dockermodelrunner/desktop/api_test.go b/dockermodelrunner/desktop/api_test.go new file mode 100644 index 00000000..43406ec1 --- /dev/null +++ b/dockermodelrunner/desktop/api_test.go @@ -0,0 +1,82 @@ +package desktop + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "config": { + "format": "format1", + "quantization": "quantization1", + "parameters": "parameters1", + "architecture": "architecture1", + "size": "size1" + }, + "created": 1682179200 + }` + + var response Model + err := json.Unmarshal([]byte(jsonData), &response) + require.NoError(t, err) + require.Equal(t, Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Config: Config{ + Format: "format1", + Quantization: "quantization1", + Parameters: "parameters1", + Architecture: "architecture1", + Size: "size1", + }, + Created: time.Unix(1682179200, 0), + }, response) +} + +func 
TestUnmarshalJSONError(t *testing.T) { + // Invalid JSON with malformed created timestamp + invalidJSON := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "config": { + "format": "format1", + "quantization": "quantization1", + "parameters": "parameters1", + "architecture": "architecture1", + "size": "size1" + }, + "created": "not-a-number" + }` + + var response Model + err := json.Unmarshal([]byte(invalidJSON), &response) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal model response") +} + +func TestMarshalJSON(t *testing.T) { + response := Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Config: Config{ + Format: "format1", + Quantization: "quantization1", + Parameters: "parameters1", + Architecture: "architecture1", + Size: "size1", + }, + Created: time.Unix(1682179200, 0), + } + + expectedJSON := `{"id":"model123","tags":["tag1","tag2"],"config":{"format":"format1","quantization":"quantization1","parameters":"parameters1","architecture":"architecture1","size":"size1"},"created":1682179200}` + + jsonData, err := json.Marshal(response) + require.NoError(t, err, "Failed to marshal JSON") + require.JSONEq(t, expectedJSON, string(jsonData), "Unexpected JSON output") +} diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go new file mode 100644 index 00000000..e11067a8 --- /dev/null +++ b/dockermodelrunner/desktop/desktop.go @@ -0,0 +1,489 @@ +package desktop + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "net/http" + "strconv" + "strings" + + "github.com/docker/docker-sdk-go/dockermodelrunner/inference" + "github.com/docker/docker-sdk-go/dockermodelrunner/models" +) + +var ( + // ErrNotFound is returned when a model is not found. + ErrNotFound = errors.New("model not found") + + // ErrServiceUnavailable is returned when the service is unavailable. 
+ ErrServiceUnavailable = errors.New("service unavailable") +) + +// Client is a client for the Docker Model Runner API. +type Client struct { + dockerClient DockerHTTPClient +} + +// DockerHTTPClient is an interface that can be used to mock the Docker client. +// +//go:generate mockgen -source=desktop.go -destination=../mocks/mock_desktop.go -package=mocks DockerHTTPClient +type DockerHTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// New creates a new Client. +func New(dockerClient DockerHTTPClient) *Client { + return &Client{dockerClient} +} + +// Status represents the status of the Docker Model Runner. +type Status struct { + Running bool `json:"running"` + Status []byte `json:"status"` + Error error `json:"error"` +} + +// Status returns the status of the Docker Model Runner. +func (c *Client) Status() Status { + // TODO: Query "/". + resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) + if err != nil { + err = c.handleQueryError(err, inference.ModelsPrefix) + if errors.Is(err, ErrServiceUnavailable) { + return Status{ + Running: false, + } + } + return Status{ + Running: false, + Error: err, + } + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return Status{ + Running: false, + Error: fmt.Errorf("unexpected status code: %d", resp.StatusCode), + } + } + + var status []byte + statusResp, err := c.doRequest(http.MethodGet, inference.InferencePrefix+"/status", nil) + if err != nil { + status = []byte(fmt.Sprintf("error querying status: %v", err)) + } else { + defer statusResp.Body.Close() + statusBody, err := io.ReadAll(statusResp.Body) + if err != nil { + status = []byte(fmt.Sprintf("error reading status body: %v", err)) + } else { + status = statusBody + } + } + return Status{ + Running: true, + Status: status, + } +} + +// Pull pulls a model from the Docker Model Runner. 
+func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { + jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) + if err != nil { + return "", false, fmt.Errorf("marshal request: %w", err) + } + + createPath := inference.ModelsPrefix + "/create" + resp, err := c.doRequest( + http.MethodPost, + createPath, + bytes.NewReader(jsonData), + ) + if err != nil { + return "", false, c.handleQueryError(err, createPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", false, fmt.Errorf("pull %s status=%d body=%s", model, resp.StatusCode, body) + } + + return scanProgress(resp, "pull", model, progress) +} + +// scanProgress scans the progress of a model for a given action. +func scanProgress(resp *http.Response, action string, model string, progress func(string)) (string, bool, error) { + progressShown := false + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + progressLine := scanner.Text() + if progressLine == "" { + continue + } + + // Parse the progress message + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { + return "", progressShown, fmt.Errorf("unmarshal progress message: %w", err) + } + + // Handle different message types + switch progressMsg.Type { + case "progress": + progress(progressMsg.Message) + progressShown = true + case "error": + return "", progressShown, fmt.Errorf("%s %s: %s", action, model, progressMsg.Message) + case "success": + return progressMsg.Message, progressShown, nil + default: + return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) + } + } + + return "", progressShown, fmt.Errorf("%s model %s: unexpected end of stream", action, model) +} + +// Push pushes a model to the Docker Model Runner. 
+func (c *Client) Push(model string, progress func(string)) (string, bool, error) { + pushPath := inference.ModelsPrefix + "/" + model + "/push" + resp, err := c.doRequest( + http.MethodPost, + pushPath, + nil, // Assuming no body is needed for the push request + ) + if err != nil { + return "", false, c.handleQueryError(err, pushPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", false, fmt.Errorf("push %s status=%d body=%s", model, resp.StatusCode, body) + } + + return scanProgress(resp, "push", model, progress) +} + +// List lists all models in the Docker Model Runner. +func (c *Client) List() ([]Model, error) { + modelsRoute := inference.ModelsPrefix + body, err := c.listRaw(modelsRoute, "") + if err != nil { + return nil, err + } + + var modelsJSON []Model + if err := json.Unmarshal(body, &modelsJSON); err != nil { + return nil, fmt.Errorf("unmarshal response body: %w", err) + } + + return modelsJSON, nil +} + +// ListOpenAI lists all models in the Docker Model Runner using the OpenAI API. +func (c *Client) ListOpenAI() (*OpenAIModelList, error) { + modelsRoute := inference.InferencePrefix + "/v1/models" + rawResponse, err := c.listRaw(modelsRoute, "") + if err != nil { + return nil, err + } + + var modelsJSON OpenAIModelList + if err := json.Unmarshal(rawResponse, &modelsJSON); err != nil { + return nil, fmt.Errorf("unmarshal response body: %w", err) + } + + return &modelsJSON, nil +} + +// Inspect inspects a model in the Docker Model Runner. +func (c *Client) Inspect(model string) (*Model, error) { + if model != "" { + if !strings.Contains(strings.Trim(model, "/"), "/") { + // Do an extra API call to check if the model parameter isn't a model ID. 
+ modelId, err := c.fullModelID(model) + if err != nil { + return nil, fmt.Errorf("invalid model name: %s", model) + } + model = modelId + } + } + + rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", inference.ModelsPrefix, model), model) + if err != nil { + return nil, err + } + + var modelInspect Model + if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { + return nil, fmt.Errorf("unmarshal response body: %w", err) + } + + return &modelInspect, nil +} + +// InspectOpenAI inspects a model in the Docker Model Runner using the OpenAI API. +func (c *Client) InspectOpenAI(model string) (*OpenAIModel, error) { + modelsRoute := inference.InferencePrefix + "/v1/models" + if !strings.Contains(strings.Trim(model, "/"), "/") { + // Do an extra API call to check if the model parameter isn't a model ID. + var err error + if model, err = c.fullModelID(model); err != nil { + return nil, fmt.Errorf("invalid model name: %s", model) + } + } + + rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) + if err != nil { + return nil, err + } + + var modelInspect OpenAIModel + if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { + return nil, fmt.Errorf("unmarshal response body: %w", err) + } + + return &modelInspect, nil +} + +// listRaw lists all models in the Docker Model Runner. +func (c *Client) listRaw(route string, model string) ([]byte, error) { + resp, err := c.doRequest(http.MethodGet, route, nil) + if err != nil { + return nil, c.handleQueryError(err, route) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if model != "" && resp.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("%w: %s", ErrNotFound, model) + } + return nil, fmt.Errorf("list models: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + return body, nil +} + +// fullModelID returns the full model ID for a given model ID. 
+func (c *Client) fullModelID(id string) (string, error) { + bodyResponse, err := c.listRaw(inference.ModelsPrefix, "") + if err != nil { + return "", err + } + + var modelsJSON []Model + if err := json.Unmarshal(bodyResponse, &modelsJSON); err != nil { + return "", fmt.Errorf("unmarshal response body: %w", err) + } + + for _, m := range modelsJSON { + if strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id || (len(m.ID) >= 19 && m.ID[7:19] == id) { + return m.ID, nil + } + } + + return "", fmt.Errorf("model with ID %s not found", id) +} + +// Chat chats with a model in the Docker Model Runner. +func (c *Client) Chat(model, prompt string) error { + if !strings.Contains(strings.Trim(model, "/"), "/") { + // Do an extra API call to check if the model parameter isn't a model ID. + if expanded, err := c.fullModelID(model); err == nil { + model = expanded + } + } + + reqBody := OpenAIChatRequest{ + Model: model, + Messages: []OpenAIChatMessage{ + { + Role: "user", + Content: prompt, + }, + }, + Stream: true, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal request: %w", err) + } + + chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions" + resp, err := c.doRequest( + http.MethodPost, + chatCompletionsPath, + bytes.NewReader(jsonData), + ) + if err != nil { + return c.handleQueryError(err, chatCompletionsPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("chat with %s status=%d body=%s", model, resp.StatusCode, body) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + if data == "[DONE]" { + break + } + + var streamResp OpenAIChatResponse + if err := json.Unmarshal([]byte(data), &streamResp); err != nil { + return fmt.Errorf("unmarshal stream response: %w", err) + 
} + + if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" { + chunk := streamResp.Choices[0].Delta.Content + fmt.Print(chunk) + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("read response stream: %w", err) + } + + return nil +} + +// Remove removes a model from the Docker Model Runner. +func (c *Client) Remove(models []string, force bool) (string, error) { + modelRemoved := "" + for _, model := range models { + // Check if not a model ID passed as parameter. + if !strings.Contains(model, "/") { + if expanded, err := c.fullModelID(model); err == nil { + model = expanded + } + } + + // Construct the URL with query parameters + removePath := fmt.Sprintf("%s/%s?force=%s", + inference.ModelsPrefix, + model, + strconv.FormatBool(force), + ) + + resp, err := c.doRequest(http.MethodDelete, removePath, nil) + if err != nil { + return modelRemoved, c.handleQueryError(err, removePath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusNotFound { + return modelRemoved, fmt.Errorf("no such model: %s", model) + } + var bodyStr string + body, err := io.ReadAll(resp.Body) + if err != nil { + bodyStr = fmt.Sprintf("(failed to read response body: %v)", err) + } else { + bodyStr = string(body) + } + return modelRemoved, fmt.Errorf("removing %s failed with status %s: %s", model, resp.Status, bodyStr) + } + modelRemoved += fmt.Sprintf("Model %s removed successfully\n", model) + } + return modelRemoved, nil +} + +// URL returns the URL for the Docker Model Runner. 
+func URL(path string) string { + return "http://localhost" + inference.ExperimentalEndpointsPrefix + path +} + +// doRequest is a helper function that performs HTTP requests and handles 503 responses +func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest(method, URL(path), body) + if err != nil { + return nil, fmt.Errorf("new %s request: %w", method, err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.dockerClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusServiceUnavailable { + resp.Body.Close() + return nil, ErrServiceUnavailable + } + + return resp, nil +} + +// handleQueryError is a helper function that handles query errors. +func (c *Client) handleQueryError(err error, path string) error { + if errors.Is(err, ErrServiceUnavailable) { + return ErrServiceUnavailable + } + return fmt.Errorf("query %s: %w", path, err) +} + +// Tag tags a model in the Docker Model Runner. 
+func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { + // Check if the source is a model ID, and expand it if necessary + if !strings.Contains(strings.Trim(source, "/"), "/") { + // Do an extra API call to check if the model parameter might be a model ID + if expanded, err := c.fullModelID(source); err == nil { + source = expanded + } + } + + // Construct the URL with query parameters + tagPath := fmt.Sprintf("%s/%s/tag?repo=%s&tag=%s", + inference.ModelsPrefix, + source, + targetRepo, + targetTag, + ) + + resp, err := c.doRequest(http.MethodPost, tagPath, nil) + if err != nil { + return "", c.handleQueryError(err, tagPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("tag %s:%s status=%d body=%s", targetRepo, targetTag, resp.StatusCode, body) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read response body: %w", err) + } + + return string(body), nil +} diff --git a/dockermodelrunner/go.mod b/dockermodelrunner/go.mod new file mode 100644 index 00000000..9d9ee2c8 --- /dev/null +++ b/dockermodelrunner/go.mod @@ -0,0 +1,14 @@ +module github.com/docker/docker-sdk-go/dockermodelrunner + +go 1.23.6 + +require github.com/stretchr/testify v1.10.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/dockermodelrunner/go.sum b/dockermodelrunner/go.sum new file mode 100644 index 00000000..e48aae2f --- /dev/null +++ b/dockermodelrunner/go.sum @@ -0,0 +1,23 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew 
v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/dockermodelrunner/inference/api.go b/dockermodelrunner/inference/api.go new file mode 100644 index 00000000..9d6428ca --- /dev/null +++ 
b/dockermodelrunner/inference/api.go @@ -0,0 +1,12 @@ +package inference + +// ExperimentalEndpointsPrefix is used to prefix all routes on the Docker +// socket while they are still in their experimental stage. This prefix doesn't +// apply to endpoints on model-runner.docker.internal. +const ExperimentalEndpointsPrefix = "/exp/vDD4.40" + +// InferencePrefix is the prefix for inference related routes. +var InferencePrefix = "/engines" + +// ModelsPrefix is the prefix for all model manager related routes. +var ModelsPrefix = "/models" diff --git a/dockermodelrunner/modeldistribution/types/types.go b/dockermodelrunner/modeldistribution/types/types.go new file mode 100644 index 00000000..fd915bc2 --- /dev/null +++ b/dockermodelrunner/modeldistribution/types/types.go @@ -0,0 +1,102 @@ +package types + +import ( + "encoding/json" + "fmt" + "time" +) + +// Store interface for model storage operations +type Store interface { + // Push a model to the store with given tags + Push(modelPath string, tags []string) error + + // Pull a model by tag + Pull(tag string, destPath string) error + + // List all models in the store + List() ([]Model, error) + + // GetByTag Get model info by tag + GetByTag(tag string) (*Model, error) + + // Delete a model by tag + Delete(tag string) error + + // AddTags Add tags to an existing model + AddTags(tag string, newTags []string) error + + // RemoveTags Remove tags from a model + RemoveTags(tags []string) error + + // Version Get store version + Version() string + + // Upgrade store to latest version + Upgrade() error +} + +// Model represents a model with its metadata and tags +type Model struct { + // ID is the globally unique model identifier. + ID string `json:"id"` + // Tags are the list of tags associated with the model. + Tags []string `json:"tags"` + // Files are the GGUF files associated with the model. + Files []string `json:"files"` + // Created is the Unix epoch timestamp corresponding to the model creation. 
+ Created time.Time `json:"created"` +} + +// modelAlias is an alias for Model to avoid recursion in JSON marshaling/unmarshaling. +// This is necessary because we want Model to contain a time.Time field which is not directly +// compatible with JSON serialization/deserialization. +type modelAlias Model + +// modelResponseJSON is a struct used for JSON marshaling/unmarshaling of Model. +// It includes a Unix timestamp for the Created field to ensure compatibility with JSON. +type modelResponseJSON struct { + modelAlias + CreatedAt int64 `json:"created"` +} + +// UnmarshalJSON implements json.Unmarshaler. +func (mr *Model) UnmarshalJSON(b []byte) error { + var resp modelResponseJSON + if err := json.Unmarshal(b, &resp); err != nil { + return fmt.Errorf("unmarshal model response: %w", err) + } + *mr = Model(resp.modelAlias) + mr.Created = time.Unix(resp.CreatedAt, 0) + return nil +} + +// MarshalJSON implements json.Marshaler. +func (mr Model) MarshalJSON() ([]byte, error) { + return json.Marshal(modelResponseJSON{ + modelAlias: modelAlias(mr), + CreatedAt: mr.Created.Unix(), + }) +} + +// ModelIndex represents the index of all models in the store +type ModelIndex struct { + Models []Model `json:"models"` +} + +// StoreLayout represents the layout information of the store +type StoreLayout struct { + Version string `json:"version"` +} + +// ManifestReference represents a reference to a manifest in the store +type ManifestReference struct { + Digest string `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` +} + +// StoreOptions represents options for creating a store +type StoreOptions struct { + RootPath string +} diff --git a/dockermodelrunner/modeldistribution/types/types_test.go b/dockermodelrunner/modeldistribution/types/types_test.go new file mode 100644 index 00000000..1be7fefc --- /dev/null +++ b/dockermodelrunner/modeldistribution/types/types_test.go @@ -0,0 +1,64 @@ +package types + +import ( + "encoding/json" + "testing" + 
"time" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "files": ["file1", "file2"], + "created": 1682179200 + }` + + var response Model + err := json.Unmarshal([]byte(jsonData), &response) + require.NoError(t, err) + require.Equal(t, Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Files: []string{ + "file1", + "file2", + }, + Created: time.Unix(1682179200, 0), + }, response) +} + +func TestUnmarshalJSONError(t *testing.T) { + // Invalid JSON with malformed created timestamp + invalidJSON := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "files": ["file1", "file2"], + "created": "not-a-number" + }` + + var response Model + err := json.Unmarshal([]byte(invalidJSON), &response) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal model response") +} + +func TestMarshalJSON(t *testing.T) { + response := Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Files: []string{ + "file1", + "file2", + }, + Created: time.Unix(1682179200, 0), + } + + expectedJSON := `{"id":"model123","tags":["tag1","tag2"],"files":["file1","file2"],"created":1682179200}` + + jsonData, err := json.Marshal(response) + require.NoError(t, err, "Failed to marshal JSON") + require.JSONEq(t, expectedJSON, string(jsonData), "Unexpected JSON output") +} diff --git a/dockermodelrunner/models/api.go b/dockermodelrunner/models/api.go new file mode 100644 index 00000000..715861b2 --- /dev/null +++ b/dockermodelrunner/models/api.go @@ -0,0 +1,99 @@ +package models + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/docker/docker-sdk-go/dockermodelrunner/modeldistribution/types" +) + +// ModelCreateRequest represents a model create request. It is designed to +// follow Docker Engine API conventions, most closely following the request +// associated with POST /images/create. 
At the moment is only designed to +// facilitate pulls, though in the future it may facilitate model building and +// refinement (such as fine tuning, quantization, or distillation). +type ModelCreateRequest struct { + // From is the name of the model to pull. + From string `json:"from"` +} + +// ToOpenAI converts a types.Model to its OpenAI API representation. +func ToOpenAI(m *types.Model) *OpenAIModel { + return &OpenAIModel{ + ID: m.Tags[0], + Object: "model", + Created: m.Created, + OwnedBy: "docker", + } +} + +// ModelList represents a list of models. +type ModelList []*types.Model + +// ToOpenAI converts the model list to its OpenAI API representation. This function never +// returns a nil slice (though it may return an empty slice). +func (l ModelList) ToOpenAI() *OpenAIModelList { + // Convert the constituent models. + models := make([]*OpenAIModel, len(l)) + for m, model := range l { + models[m] = ToOpenAI(model) + } + + // Create the OpenAI model list. + return &OpenAIModelList{ + Object: "list", + Data: models, + } +} + +// OpenAIModel represents a locally stored model using OpenAI conventions. +type OpenAIModel struct { + // ID is the model tag. + ID string `json:"id"` + // Object is the object type. For OpenAIModel, it is always "model". + Object string `json:"object"` + // Created is the Unix epoch timestamp corresponding to the model creation. + Created time.Time `json:"created"` + // OwnedBy is the model owner. At the moment, it is always "docker". + OwnedBy string `json:"owned_by"` +} + +// OpenAIModelList represents a list of models using OpenAI conventions. +type OpenAIModelList struct { + // Object is the object type. For OpenAIModelList, it is always "list". + Object string `json:"object"` + // Data is the list of models. + Data []*OpenAIModel `json:"data"` +} + +// openAIModelAlias is an alias for OpenAIModel to avoid recursion in JSON marshaling/unmarshaling. 
+// This is necessary because we want OpenAIModel to contain a time.Time field which is not directly +// compatible with JSON serialization/deserialization. +type openAIModelAlias OpenAIModel + +// openAIModelResponseJSON is a struct used for JSON marshaling/unmarshaling of OpenAIModel. +// It includes a Unix timestamp for the Created field to ensure compatibility with JSON. +type openAIModelResponseJSON struct { + openAIModelAlias + CreatedAt int64 `json:"created"` +} + +// UnmarshalJSON implements json.Unmarshaler. +func (mr *OpenAIModel) UnmarshalJSON(b []byte) error { + var resp openAIModelResponseJSON + if err := json.Unmarshal(b, &resp); err != nil { + return fmt.Errorf("unmarshal model response: %w", err) + } + *mr = OpenAIModel(resp.openAIModelAlias) + mr.Created = time.Unix(resp.CreatedAt, 0) + return nil +} + +// MarshalJSON implements json.Marshaler. +func (mr OpenAIModel) MarshalJSON() ([]byte, error) { + return json.Marshal(openAIModelResponseJSON{ + openAIModelAlias: openAIModelAlias(mr), + CreatedAt: mr.Created.Unix(), + }) +} diff --git a/dockermodelrunner/models/api_test.go b/dockermodelrunner/models/api_test.go new file mode 100644 index 00000000..43d88c88 --- /dev/null +++ b/dockermodelrunner/models/api_test.go @@ -0,0 +1,58 @@ +package models + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "id": "model123", + "object": "model", + "created": 1682179200, + "owned_by": "docker" + }` + + var response OpenAIModel + err := json.Unmarshal([]byte(jsonData), &response) + require.NoError(t, err) + require.Equal(t, OpenAIModel{ + ID: "model123", + Object: "model", + Created: time.Unix(1682179200, 0), + OwnedBy: "docker", + }, response) +} + +func TestUnmarshalJSONError(t *testing.T) { + // Invalid JSON with malformed created timestamp + invalidJSON := `{ + "id": "model123", + "object": "model", + "created": "not-a-number", + "owned_by": "docker" + 
}` + + var response OpenAIModel + err := json.Unmarshal([]byte(invalidJSON), &response) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal model response") +} + +func TestMarshalJSON(t *testing.T) { + response := OpenAIModel{ + ID: "model123", + Object: "model", + Created: time.Unix(1682179200, 0), + OwnedBy: "docker", + } + + expectedJSON := `{"id":"model123","object":"model","created":1682179200,"owned_by":"docker"}` + + jsonData, err := json.Marshal(response) + require.NoError(t, err, "Failed to marshal JSON") + require.JSONEq(t, expectedJSON, string(jsonData), "Unexpected JSON output") +} diff --git a/go.work b/go.work index f51212b7..8e9c2e9b 100644 --- a/go.work +++ b/go.work @@ -3,4 +3,5 @@ go 1.23.6 use ( ./dockerconfig ./dockercontext + ./dockermodelrunner )