diff options
Diffstat (limited to 'modules/lfs')
-rw-r--r-- | modules/lfs/client.go | 10 | ||||
-rw-r--r-- | modules/lfs/client_test.go | 1 | ||||
-rw-r--r-- | modules/lfs/filesystem_client.go | 63 | ||||
-rw-r--r-- | modules/lfs/http_client.go | 131 | ||||
-rw-r--r-- | modules/lfs/http_client_test.go | 320 | ||||
-rw-r--r-- | modules/lfs/shared.go | 4 | ||||
-rw-r--r-- | modules/lfs/transferadapter.go | 102 | ||||
-rw-r--r-- | modules/lfs/transferadapter_test.go | 169 |
8 files changed, 662 insertions, 138 deletions
diff --git a/modules/lfs/client.go b/modules/lfs/client.go index ae35919d77..0a21440f73 100644 --- a/modules/lfs/client.go +++ b/modules/lfs/client.go @@ -10,9 +10,17 @@ import ( "net/url" ) +// DownloadCallback gets called for every requested LFS object to process its content +type DownloadCallback func(p Pointer, content io.ReadCloser, objectError error) error + +// UploadCallback gets called for every requested LFS object to provide its content +type UploadCallback func(p Pointer, objectError error) (io.ReadCloser, error) + // Client is used to communicate with a LFS source type Client interface { - Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) + BatchSize() int + Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error + Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error } // NewClient creates a LFS client diff --git a/modules/lfs/client_test.go b/modules/lfs/client_test.go index d4eb005469..1040b39925 100644 --- a/modules/lfs/client_test.go +++ b/modules/lfs/client_test.go @@ -6,7 +6,6 @@ package lfs import ( "net/url" - "testing" "github.com/stretchr/testify/assert" diff --git a/modules/lfs/filesystem_client.go b/modules/lfs/filesystem_client.go index 3a51564a82..dc72981a9e 100644 --- a/modules/lfs/filesystem_client.go +++ b/modules/lfs/filesystem_client.go @@ -19,6 +19,11 @@ type FilesystemClient struct { lfsdir string } +// BatchSize returns the preferred size of batchs to process +func (c *FilesystemClient) BatchSize() int { + return 1 +} + func newFilesystemClient(endpoint *url.URL) *FilesystemClient { path, _ := util.FileURLToPath(endpoint) @@ -33,18 +38,56 @@ func (c *FilesystemClient) objectPath(oid string) string { return filepath.Join(c.lfsdir, oid[0:2], oid[2:4], oid) } -// Download reads the specific LFS object from the target repository -func (c *FilesystemClient) Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) { - objectPath := c.objectPath(oid) +// Download reads the specific LFS object from the target path +func (c *FilesystemClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error { + for _, object := range objects { + p := Pointer{object.Oid, object.Size} - if _, err := os.Stat(objectPath); os.IsNotExist(err) { - return nil, err - } + objectPath := c.objectPath(p.Oid) + + f, err := os.Open(objectPath) + if err != nil { + return err + } - file, err := os.Open(objectPath) - if err != nil { - return nil, err + if err := callback(p, f, nil); err != nil { + return err + } } + return nil +} + +// Upload writes the specific LFS object to the target path +func (c *FilesystemClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error { + for _, object := range objects { + p := Pointer{object.Oid, object.Size} + + objectPath := c.objectPath(p.Oid) - return file, nil + if err := os.MkdirAll(filepath.Dir(objectPath), os.ModePerm); err != nil { + return err + } + + content, err := callback(p, nil) + if err != nil { + return err + } + + err = func() error { + defer content.Close() + + f, err := os.Create(objectPath) + if err != nil { + return err + } + + _, err = io.Copy(f, content) + + return err + }() + if err != nil { + return err + } + } + return nil } diff --git a/modules/lfs/http_client.go b/modules/lfs/http_client.go index fb45defda1..e799b80831 100644 --- a/modules/lfs/http_client.go +++ b/modules/lfs/http_client.go @@ -7,17 +7,19 @@ package lfs import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "strings" "code.gitea.io/gitea/modules/log" + + jsoniter "github.com/json-iterator/go" ) +const batchSize = 20 + // HTTPClient is used to communicate with the LFS server // https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md type HTTPClient struct { @@ -26,6 +28,11 @@ type HTTPClient struct { transfers map[string]TransferAdapter } +// BatchSize returns the preferred size of batchs to process +func (c *HTTPClient) BatchSize() int { + return batchSize +} + func newHTTPClient(endpoint *url.URL) *HTTPClient { hc := &http.Client{} @@ -55,21 +62,25 @@ func (c *HTTPClient) transferNames() []string { } func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Pointer) (*BatchResponse, error) { + log.Trace("BATCH operation with objects: %v", objects) + url := fmt.Sprintf("%s/objects/batch", c.endpoint) request := &BatchRequest{operation, c.transferNames(), nil, objects} payload := new(bytes.Buffer) - err := json.NewEncoder(payload).Encode(request) + err := jsoniter.NewEncoder(payload).Encode(request) if err != nil { - return nil, fmt.Errorf("lfs.HTTPClient.batch json.Encode: %w", err) + log.Error("Error encoding json: %v", err) + return nil, err } - log.Trace("lfs.HTTPClient.batch NewRequestWithContext: %s", url) + log.Trace("Calling: %s", url) req, err := http.NewRequestWithContext(ctx, "POST", url, payload) if err != nil { - return nil, fmt.Errorf("lfs.HTTPClient.batch http.NewRequestWithContext: %w", err) + log.Error("Error creating request: %v", err) + return nil, err } req.Header.Set("Content-type", MediaType) req.Header.Set("Accept", MediaType) @@ -81,18 +92,20 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin return nil, ctx.Err() default: } - return nil, fmt.Errorf("lfs.HTTPClient.batch http.Do: %w", err) + log.Error("Error while processing request: %v", err) + return nil, err } defer res.Body.Close() if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("lfs.HTTPClient.batch: Unexpected servers response: %s", res.Status) + return nil, fmt.Errorf("Unexpected server response: %s", res.Status) } var response BatchResponse - err = json.NewDecoder(res.Body).Decode(&response) + err = jsoniter.NewDecoder(res.Body).Decode(&response) if err != nil { - return nil, fmt.Errorf("lfs.HTTPClient.batch json.Decode: %w", err) + log.Error("Error decoding json: %v", err) + return nil, err } if len(response.Transfer) == 0 { @@ -103,27 +116,99 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin } // Download reads the specific LFS object from the LFS server -func (c *HTTPClient) Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) { - var objects []Pointer - objects = append(objects, Pointer{oid, size}) +func (c *HTTPClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error { + return c.performOperation(ctx, objects, callback, nil) +} + +// Upload sends the specific LFS object to the LFS server +func (c *HTTPClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error { + return c.performOperation(ctx, objects, nil, callback) +} - result, err := c.batch(ctx, "download", objects) +func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc DownloadCallback, uc UploadCallback) error { + if len(objects) == 0 { + return nil + } + + operation := "download" + if uc != nil { + operation = "upload" + } + + result, err := c.batch(ctx, operation, objects) if err != nil { - return nil, err + return err } transferAdapter, ok := c.transfers[result.Transfer] if !ok { - return nil, fmt.Errorf("lfs.HTTPClient.Download Transferadapter not found: %s", result.Transfer) + return fmt.Errorf("TransferAdapter not found: %s", result.Transfer) } - if len(result.Objects) == 0 { - return nil, errors.New("lfs.HTTPClient.Download: No objects in result") - } + for _, object := range result.Objects { + if object.Error != nil { + objectError := errors.New(object.Error.Message) + log.Trace("Error on object %v: %v", object.Pointer, objectError) + if uc != nil { + if _, err := uc(object.Pointer, objectError); err != nil { + return err + } + } else { + if err := dc(object.Pointer, nil, objectError); err != nil { + return err + } + } + continue + } - content, err := transferAdapter.Download(ctx, result.Objects[0]) - if err != nil { - return nil, err + if uc != nil { + if len(object.Actions) == 0 { + log.Trace("%v already present on server", object.Pointer) + continue + } + + link, ok := object.Actions["upload"] + if !ok { + log.Debug("%+v", object) + return errors.New("Missing action 'upload'") + } + + content, err := uc(object.Pointer, nil) + if err != nil { + return err + } + + err = transferAdapter.Upload(ctx, link, object.Pointer, content) + + content.Close() + + if err != nil { + return err + } + + link, ok = object.Actions["verify"] + if ok { + if err := transferAdapter.Verify(ctx, link, object.Pointer); err != nil { + return err + } + } + } else { + link, ok := object.Actions["download"] + if !ok { + log.Debug("%+v", object) + return errors.New("Missing action 'download'") + } + + content, err := transferAdapter.Download(ctx, link) + if err != nil { + return err + } + + if err := dc(object.Pointer, content, nil); err != nil { + return err + } + } } - return content, nil + + return nil } diff --git a/modules/lfs/http_client_test.go b/modules/lfs/http_client_test.go index 68ec947aa8..0f633ede54 100644 --- a/modules/lfs/http_client_test.go +++ b/modules/lfs/http_client_test.go @@ -7,13 +7,13 @@ package lfs import ( "bytes" "context" - "encoding/json" "io" "io/ioutil" "net/http" "strings" "testing" + jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" ) @@ -30,69 +30,253 @@ func (a *DummyTransferAdapter) Name() string { return "dummy" } -func (a *DummyTransferAdapter) Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) { +func (a *DummyTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewBufferString("dummy")), nil } -func TestHTTPClientDownload(t *testing.T) { - oid := "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041" - size := int64(6) +func (a *DummyTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error { + return nil +} + +func (a *DummyTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error { + return nil +} + +func lfsTestRoundtripHandler(req *http.Request) *http.Response { + var batchResponse *BatchResponse + url := req.URL.String() - roundTripHandler := func(req *http.Request) *http.Response { - url := req.URL.String() - if strings.Contains(url, "status-not-ok") { - return &http.Response{StatusCode: http.StatusBadRequest} + if strings.Contains(url, "status-not-ok") { + return &http.Response{StatusCode: http.StatusBadRequest} + } else if strings.Contains(url, "invalid-json-response") { + return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("invalid json"))} + } else if strings.Contains(url, "valid-batch-request-download") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "download": {}, + }, + }, + }, + } + } else if strings.Contains(url, "valid-batch-request-upload") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "upload": {}, + }, + }, + }, } - if strings.Contains(url, "invalid-json-response") { - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("invalid json"))} + } else if strings.Contains(url, "response-no-objects") { + batchResponse = &BatchResponse{Transfer: "dummy"} + } else if strings.Contains(url, "unknown-transfer-adapter") { + batchResponse = &BatchResponse{Transfer: "unknown_adapter"} + } else if strings.Contains(url, "error-in-response-objects") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Error: &ObjectError{ + Code: 404, + Message: "Object not found", + }, + }, + }, } - if strings.Contains(url, "valid-batch-request-download") { - assert.Equal(t, "POST", req.Method) - assert.Equal(t, MediaType, req.Header.Get("Content-type"), "case %s: error should match", url) - assert.Equal(t, MediaType, req.Header.Get("Accept"), "case %s: error should match", url) + } else if strings.Contains(url, "empty-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{}, + }, + }, + } + } else if strings.Contains(url, "download-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "download": {}, + }, + }, + }, + } + } else if strings.Contains(url, "upload-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "upload": {}, + }, + }, + }, + } + } else if strings.Contains(url, "verify-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "verify": {}, + }, + }, + }, + } + } else if strings.Contains(url, "unknown-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "unknown": {}, + }, + }, + }, + } + } else { + return nil + } - var batchRequest BatchRequest - err := json.NewDecoder(req.Body).Decode(&batchRequest) - assert.NoError(t, err) + payload := new(bytes.Buffer) + jsoniter.NewEncoder(payload).Encode(batchResponse) - assert.Equal(t, "download", batchRequest.Operation) - assert.Len(t, batchRequest.Objects, 1) - assert.Equal(t, oid, batchRequest.Objects[0].Oid) - assert.Equal(t, size, batchRequest.Objects[0].Size) + return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} +} - batchResponse := &BatchResponse{ - Transfer: "dummy", - Objects: make([]*ObjectResponse, 1), - } +func TestHTTPClientDownload(t *testing.T) { + p := Pointer{Oid: "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041", Size: 6} - payload := new(bytes.Buffer) - json.NewEncoder(payload).Encode(batchResponse) + hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) *http.Response { + assert.Equal(t, "POST", req.Method) + assert.Equal(t, MediaType, req.Header.Get("Content-type")) + assert.Equal(t, MediaType, req.Header.Get("Accept")) - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} - } - if strings.Contains(url, "invalid-response-no-objects") { - batchResponse := &BatchResponse{Transfer: "dummy"} + var batchRequest BatchRequest + err := jsoniter.NewDecoder(req.Body).Decode(&batchRequest) + assert.NoError(t, err) - payload := new(bytes.Buffer) - json.NewEncoder(payload).Encode(batchResponse) + assert.Equal(t, "download", batchRequest.Operation) + assert.Equal(t, 1, len(batchRequest.Objects)) + assert.Equal(t, p.Oid, batchRequest.Objects[0].Oid) + assert.Equal(t, p.Size, batchRequest.Objects[0].Size) - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} - } - if strings.Contains(url, "unknown-transfer-adapter") { - batchResponse := &BatchResponse{Transfer: "unknown_adapter"} + return lfsTestRoundtripHandler(req) + })} + dummy := &DummyTransferAdapter{} - payload := new(bytes.Buffer) - json.NewEncoder(payload).Encode(batchResponse) + var cases = []struct { + endpoint string + expectederror string + }{ + // case 0 + { + endpoint: "https://status-not-ok.io", + expectederror: "Unexpected server response: ", + }, + // case 1 + { + endpoint: "https://invalid-json-response.io", + expectederror: "invalid json", + }, + // case 2 + { + endpoint: "https://valid-batch-request-download.io", + expectederror: "", + }, + // case 3 + { + endpoint: "https://response-no-objects.io", + expectederror: "", + }, + // case 4 + { + endpoint: "https://unknown-transfer-adapter.io", + expectederror: "TransferAdapter not found: ", + }, + // case 5 + { + endpoint: "https://error-in-response-objects.io", + expectederror: "Object not found", + }, + // case 6 + { + endpoint: "https://empty-actions-map.io", + expectederror: "Missing action 'download'", + }, + // case 7 + { + endpoint: "https://download-actions-map.io", + expectederror: "", + }, + // case 8 + { + endpoint: "https://upload-actions-map.io", + expectederror: "Missing action 'download'", + }, + // case 9 + { + endpoint: "https://verify-actions-map.io", + expectederror: "Missing action 'download'", + }, + // case 10 + { + endpoint: "https://unknown-actions-map.io", + expectederror: "Missing action 'download'", + }, + } - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} + for n, c := range cases { + client := &HTTPClient{ + client: hc, + endpoint: c.endpoint, + transfers: make(map[string]TransferAdapter), } + client.transfers["dummy"] = dummy - t.Errorf("Unknown test case: %s", url) - - return nil + err := client.Download(context.Background(), []Pointer{p}, func(p Pointer, content io.ReadCloser, objectError error) error { + if objectError != nil { + return objectError + } + b, err := io.ReadAll(content) + assert.NoError(t, err) + assert.Equal(t, []byte("dummy"), b) + return nil + }) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } } +} + +func TestHTTPClientUpload(t *testing.T) { + p := Pointer{Oid: "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041", Size: 6} + + hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) *http.Response { + assert.Equal(t, "POST", req.Method) + assert.Equal(t, MediaType, req.Header.Get("Content-type")) + assert.Equal(t, MediaType, req.Header.Get("Accept")) + + var batchRequest BatchRequest + err := jsoniter.NewDecoder(req.Body).Decode(&batchRequest) + assert.NoError(t, err) - hc := &http.Client{Transport: RoundTripFunc(roundTripHandler)} + assert.Equal(t, "upload", batchRequest.Operation) + assert.Equal(t, 1, len(batchRequest.Objects)) + assert.Equal(t, p.Oid, batchRequest.Objects[0].Oid) + assert.Equal(t, p.Size, batchRequest.Objects[0].Size) + + return lfsTestRoundtripHandler(req) + })} dummy := &DummyTransferAdapter{} var cases = []struct { @@ -102,27 +286,57 @@ func TestHTTPClientDownload(t *testing.T) { // case 0 { endpoint: "https://status-not-ok.io", - expectederror: "Unexpected servers response: ", + expectederror: "Unexpected server response: ", }, // case 1 { endpoint: "https://invalid-json-response.io", - expectederror: "json.Decode: ", + expectederror: "invalid json", }, // case 2 { - endpoint: "https://valid-batch-request-download.io", + endpoint: "https://valid-batch-request-upload.io", expectederror: "", }, // case 3 { - endpoint: "https://invalid-response-no-objects.io", - expectederror: "No objects in result", + endpoint: "https://response-no-objects.io", + expectederror: "", }, // case 4 { endpoint: "https://unknown-transfer-adapter.io", - expectederror: "Transferadapter not found: ", + expectederror: "TransferAdapter not found: ", + }, + // case 5 + { + endpoint: "https://error-in-response-objects.io", + expectederror: "Object not found", + }, + // case 6 + { + endpoint: "https://empty-actions-map.io", + expectederror: "", + }, + // case 7 + { + endpoint: "https://download-actions-map.io", + expectederror: "Missing action 'upload'", + }, + // case 8 + { + endpoint: "https://upload-actions-map.io", + expectederror: "", + }, + // case 9 + { + endpoint: "https://verify-actions-map.io", + expectederror: "Missing action 'upload'", + }, + // case 10 + { + endpoint: "https://unknown-actions-map.io", + expectederror: "Missing action 'upload'", }, } @@ -134,7 +348,9 @@ func TestHTTPClientDownload(t *testing.T) { } client.transfers["dummy"] = dummy - _, err := client.Download(context.Background(), oid, size) + err := client.Upload(context.Background(), []Pointer{p}, func(p Pointer, objectError error) (io.ReadCloser, error) { + return ioutil.NopCloser(new(bytes.Buffer)), objectError + }) if len(c.expectederror) > 0 { assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) } else { diff --git a/modules/lfs/shared.go b/modules/lfs/shared.go index 9abbf85fbd..8343d12e1d 100644 --- a/modules/lfs/shared.go +++ b/modules/lfs/shared.go @@ -49,14 +49,14 @@ type ObjectResponse struct { Error *ObjectError `json:"error,omitempty"` } -// Link provides a structure used to build a hypermedia representation of an HTTP link. +// Link provides a structure with informations about how to access a object. type Link struct { Href string `json:"href"` Header map[string]string `json:"header,omitempty"` ExpiresAt *time.Time `json:"expires_at,omitempty"` } -// ObjectError defines the JSON structure returned to the client in case of an error +// ObjectError defines the JSON structure returned to the client in case of an error. type ObjectError struct { Code int `json:"code"` Message string `json:"message"` diff --git a/modules/lfs/transferadapter.go b/modules/lfs/transferadapter.go index ea3aff0000..8c40ab8c04 100644 --- a/modules/lfs/transferadapter.go +++ b/modules/lfs/transferadapter.go @@ -5,18 +5,24 @@ package lfs import ( + "bytes" "context" "errors" "fmt" "io" "net/http" + + "code.gitea.io/gitea/modules/log" + + jsoniter "github.com/json-iterator/go" ) // TransferAdapter represents an adapter for downloading/uploading LFS objects type TransferAdapter interface { Name() string - Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) - //Upload(ctx context.Context, reader io.Reader) error + Download(ctx context.Context, l *Link) (io.ReadCloser, error) + Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error + Verify(ctx context.Context, l *Link, p Pointer) error } // BasicTransferAdapter implements the "basic" adapter @@ -30,29 +36,101 @@ func (a *BasicTransferAdapter) Name() string { } // Download reads the download location and downloads the data -func (a *BasicTransferAdapter) Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) { - download, ok := r.Actions["download"] - if !ok { - return nil, errors.New("lfs.BasicTransferAdapter.Download: Action 'download' not found") +func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) { + resp, err := a.performRequest(ctx, "GET", l, nil, nil) + if err != nil { + return nil, err } + return resp.Body, nil +} - req, err := http.NewRequestWithContext(ctx, "GET", download.Href, nil) +// Upload sends the content to the LFS server +func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error { + _, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) { + if len(req.Header.Get("Content-Type")) == 0 { + req.Header.Set("Content-Type", "application/octet-stream") + } + + if req.Header.Get("Transfer-Encoding") == "chunked" { + req.TransferEncoding = []string{"chunked"} + } + + req.ContentLength = p.Size + }) if err != nil { - return nil, fmt.Errorf("lfs.BasicTransferAdapter.Download http.NewRequestWithContext: %w", err) + return err } - for key, value := range download.Header { + return nil +} + +// Verify calls the verify handler on the LFS server +func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error { + b, err := jsoniter.Marshal(p) + if err != nil { + log.Error("Error encoding json: %v", err) + return err + } + + _, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) { + req.Header.Set("Content-Type", MediaType) + }) + if err != nil { + return err + } + return nil +} + +func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) { + log.Trace("Calling: %s %s", method, l.Href) + + req, err := http.NewRequestWithContext(ctx, method, l.Href, body) + if err != nil { + log.Error("Error creating request: %v", err) + return nil, err + } + for key, value := range l.Header { req.Header.Set(key, value) } + req.Header.Set("Accept", MediaType) + + if callback != nil { + callback(req) + } res, err := a.client.Do(req) if err != nil { select { case <-ctx.Done(): - return nil, ctx.Err() + return res, ctx.Err() default: } - return nil, fmt.Errorf("lfs.BasicTransferAdapter.Download http.Do: %w", err) + log.Error("Error while processing request: %v", err) + return res, err + } + + if res.StatusCode != http.StatusOK { + return res, handleErrorResponse(res) + } + + return res, nil +} + +func handleErrorResponse(resp *http.Response) error { + defer resp.Body.Close() + + er, err := decodeReponseError(resp.Body) + if err != nil { + return fmt.Errorf("Request failed with status %s", resp.Status) } + log.Trace("ErrorRespone: %v", er) + return errors.New(er.Message) +} - return res.Body, nil +func decodeReponseError(r io.Reader) (ErrorResponse, error) { + var er ErrorResponse + err := jsoniter.NewDecoder(r).Decode(&er) + if err != nil { + log.Error("Error decoding json: %v", err) + } + return er, err } diff --git a/modules/lfs/transferadapter_test.go b/modules/lfs/transferadapter_test.go index 0eabd3faee..7dfdad417e 100644 --- a/modules/lfs/transferadapter_test.go +++ b/modules/lfs/transferadapter_test.go @@ -7,11 +7,13 @@ package lfs import ( "bytes" "context" + "io" "io/ioutil" "net/http" "strings" "testing" + jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" ) @@ -21,58 +23,151 @@ func TestBasicTransferAdapterName(t *testing.T) { assert.Equal(t, "basic", a.Name()) } -func TestBasicTransferAdapterDownload(t *testing.T) { +func TestBasicTransferAdapter(t *testing.T) { + p := Pointer{Oid: "b5a2c96250612366ea272ffac6d9744aaf4b45aacd96aa7cfcb931ee3b558259", Size: 5} + roundTripHandler := func(req *http.Request) *http.Response { + assert.Equal(t, MediaType, req.Header.Get("Accept")) + assert.Equal(t, "test-value", req.Header.Get("test-header")) + url := req.URL.String() - if strings.Contains(url, "valid-download-request") { + if strings.Contains(url, "download-request") { assert.Equal(t, "GET", req.Method) - assert.Equal(t, "test-value", req.Header.Get("test-header")) return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("dummy"))} - } + } else if strings.Contains(url, "upload-request") { + assert.Equal(t, "PUT", req.Method) + assert.Equal(t, "application/octet-stream", req.Header.Get("Content-Type")) + + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, "dummy", string(b)) - t.Errorf("Unknown test case: %s", url) + return &http.Response{StatusCode: http.StatusOK} + } else if strings.Contains(url, "verify-request") { + assert.Equal(t, "POST", req.Method) + assert.Equal(t, MediaType, req.Header.Get("Content-Type")) - return nil + var vp Pointer + err := jsoniter.NewDecoder(req.Body).Decode(&vp) + assert.NoError(t, err) + assert.Equal(t, p.Oid, vp.Oid) + assert.Equal(t, p.Size, vp.Size) + + return &http.Response{StatusCode: http.StatusOK} + } else if strings.Contains(url, "error-response") { + er := &ErrorResponse{ + Message: "Object not found", + } + payload := new(bytes.Buffer) + jsoniter.NewEncoder(payload).Encode(er) + + return &http.Response{StatusCode: http.StatusNotFound, Body: ioutil.NopCloser(payload)} + } else { + t.Errorf("Unknown test case: %s", url) + return nil + } } hc := &http.Client{Transport: RoundTripFunc(roundTripHandler)} a := &BasicTransferAdapter{hc} - var cases = []struct { - response *ObjectResponse - expectederror string - }{ - // case 0 - { - response: &ObjectResponse{}, - expectederror: "Action 'download' not found", - }, - // case 1 - { - response: &ObjectResponse{ - Actions: map[string]*Link{"upload": nil}, + t.Run("Download", func(t *testing.T) { + cases := []struct { + link *Link + expectederror string + }{ + // case 0 + { + link: &Link{ + Href: "https://download-request.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "", }, - expectederror: "Action 'download' not found", - }, - // case 2 - { - response: &ObjectResponse{ - Actions: map[string]*Link{"download": { - Href: "https://valid-download-request.io", + // case 1 + { + link: &Link{ + Href: "https://error-response.io", Header: map[string]string{"test-header": "test-value"}, - }}, + }, + expectederror: "Object not found", }, - expectederror: "", - }, - } + } - for n, c := range cases { - _, err := a.Download(context.Background(), c.response) - if len(c.expectederror) > 0 { - assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) - } else { - assert.NoError(t, err, "case %d", n) + for n, c := range cases { + _, err := a.Download(context.Background(), c.link) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } } - } + }) + + t.Run("Upload", func(t *testing.T) { + cases := []struct { + link *Link + expectederror string + }{ + // case 0 + { + link: &Link{ + Href: "https://upload-request.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "", + }, + // case 1 + { + link: &Link{ + Href: "https://error-response.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "Object not found", + }, + } + + for n, c := range cases { + err := a.Upload(context.Background(), c.link, p, bytes.NewBufferString("dummy")) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } + } + }) + + t.Run("Verify", func(t *testing.T) { + cases := []struct { + link *Link + expectederror string + }{ + // case 0 + { + link: &Link{ + Href: "https://verify-request.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "", + }, + // case 1 + { + link: &Link{ + Href: "https://error-response.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "Object not found", + }, + } + + for n, c := range cases { + err := a.Verify(context.Background(), c.link, p) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } + } + }) } |