summaryrefslogtreecommitdiffstats
path: root/modules/lfs
diff options
context:
space:
mode:
Diffstat (limited to 'modules/lfs')
-rw-r--r--modules/lfs/client.go10
-rw-r--r--modules/lfs/client_test.go1
-rw-r--r--modules/lfs/filesystem_client.go63
-rw-r--r--modules/lfs/http_client.go131
-rw-r--r--modules/lfs/http_client_test.go320
-rw-r--r--modules/lfs/shared.go4
-rw-r--r--modules/lfs/transferadapter.go102
-rw-r--r--modules/lfs/transferadapter_test.go169
8 files changed, 662 insertions, 138 deletions
diff --git a/modules/lfs/client.go b/modules/lfs/client.go
index ae35919d77..0a21440f73 100644
--- a/modules/lfs/client.go
+++ b/modules/lfs/client.go
@@ -10,9 +10,17 @@ import (
"net/url"
)
+// DownloadCallback gets called for every requested LFS object to process its content
+type DownloadCallback func(p Pointer, content io.ReadCloser, objectError error) error
+
+// UploadCallback gets called for every requested LFS object to provide its content
+type UploadCallback func(p Pointer, objectError error) (io.ReadCloser, error)
+
// Client is used to communicate with a LFS source
type Client interface {
- Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error)
+ BatchSize() int
+ Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error
+ Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error
}
// NewClient creates a LFS client
diff --git a/modules/lfs/client_test.go b/modules/lfs/client_test.go
index d4eb005469..1040b39925 100644
--- a/modules/lfs/client_test.go
+++ b/modules/lfs/client_test.go
@@ -6,7 +6,6 @@ package lfs
import (
"net/url"
-
"testing"
"github.com/stretchr/testify/assert"
diff --git a/modules/lfs/filesystem_client.go b/modules/lfs/filesystem_client.go
index 3a51564a82..dc72981a9e 100644
--- a/modules/lfs/filesystem_client.go
+++ b/modules/lfs/filesystem_client.go
@@ -19,6 +19,11 @@ type FilesystemClient struct {
lfsdir string
}
+// BatchSize returns the preferred size of batchs to process
+func (c *FilesystemClient) BatchSize() int {
+ return 1
+}
+
func newFilesystemClient(endpoint *url.URL) *FilesystemClient {
path, _ := util.FileURLToPath(endpoint)
@@ -33,18 +38,56 @@ func (c *FilesystemClient) objectPath(oid string) string {
return filepath.Join(c.lfsdir, oid[0:2], oid[2:4], oid)
}
-// Download reads the specific LFS object from the target repository
-func (c *FilesystemClient) Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) {
- objectPath := c.objectPath(oid)
+// Download reads the specific LFS object from the target path
+func (c *FilesystemClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error {
+ for _, object := range objects {
+ p := Pointer{object.Oid, object.Size}
- if _, err := os.Stat(objectPath); os.IsNotExist(err) {
- return nil, err
- }
+ objectPath := c.objectPath(p.Oid)
+
+ f, err := os.Open(objectPath)
+ if err != nil {
+ return err
+ }
- file, err := os.Open(objectPath)
- if err != nil {
- return nil, err
+ if err := callback(p, f, nil); err != nil {
+ return err
+ }
}
+ return nil
+}
+
+// Upload writes the specific LFS object to the target path
+func (c *FilesystemClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error {
+ for _, object := range objects {
+ p := Pointer{object.Oid, object.Size}
+
+ objectPath := c.objectPath(p.Oid)
- return file, nil
+ if err := os.MkdirAll(filepath.Dir(objectPath), os.ModePerm); err != nil {
+ return err
+ }
+
+ content, err := callback(p, nil)
+ if err != nil {
+ return err
+ }
+
+ err = func() error {
+ defer content.Close()
+
+ f, err := os.Create(objectPath)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.Copy(f, content)
+
+ return err
+ }()
+ if err != nil {
+ return err
+ }
+ }
+ return nil
}
diff --git a/modules/lfs/http_client.go b/modules/lfs/http_client.go
index fb45defda1..e799b80831 100644
--- a/modules/lfs/http_client.go
+++ b/modules/lfs/http_client.go
@@ -7,17 +7,19 @@ package lfs
import (
"bytes"
"context"
- "encoding/json"
"errors"
"fmt"
- "io"
"net/http"
"net/url"
"strings"
"code.gitea.io/gitea/modules/log"
+
+ jsoniter "github.com/json-iterator/go"
)
+const batchSize = 20
+
// HTTPClient is used to communicate with the LFS server
// https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md
type HTTPClient struct {
@@ -26,6 +28,11 @@ type HTTPClient struct {
transfers map[string]TransferAdapter
}
+// BatchSize returns the preferred size of batchs to process
+func (c *HTTPClient) BatchSize() int {
+ return batchSize
+}
+
func newHTTPClient(endpoint *url.URL) *HTTPClient {
hc := &http.Client{}
@@ -55,21 +62,25 @@ func (c *HTTPClient) transferNames() []string {
}
func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Pointer) (*BatchResponse, error) {
+ log.Trace("BATCH operation with objects: %v", objects)
+
url := fmt.Sprintf("%s/objects/batch", c.endpoint)
request := &BatchRequest{operation, c.transferNames(), nil, objects}
payload := new(bytes.Buffer)
- err := json.NewEncoder(payload).Encode(request)
+ err := jsoniter.NewEncoder(payload).Encode(request)
if err != nil {
- return nil, fmt.Errorf("lfs.HTTPClient.batch json.Encode: %w", err)
+ log.Error("Error encoding json: %v", err)
+ return nil, err
}
- log.Trace("lfs.HTTPClient.batch NewRequestWithContext: %s", url)
+ log.Trace("Calling: %s", url)
req, err := http.NewRequestWithContext(ctx, "POST", url, payload)
if err != nil {
- return nil, fmt.Errorf("lfs.HTTPClient.batch http.NewRequestWithContext: %w", err)
+ log.Error("Error creating request: %v", err)
+ return nil, err
}
req.Header.Set("Content-type", MediaType)
req.Header.Set("Accept", MediaType)
@@ -81,18 +92,20 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin
return nil, ctx.Err()
default:
}
- return nil, fmt.Errorf("lfs.HTTPClient.batch http.Do: %w", err)
+ log.Error("Error while processing request: %v", err)
+ return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("lfs.HTTPClient.batch: Unexpected servers response: %s", res.Status)
+ return nil, fmt.Errorf("Unexpected server response: %s", res.Status)
}
var response BatchResponse
- err = json.NewDecoder(res.Body).Decode(&response)
+ err = jsoniter.NewDecoder(res.Body).Decode(&response)
if err != nil {
- return nil, fmt.Errorf("lfs.HTTPClient.batch json.Decode: %w", err)
+ log.Error("Error decoding json: %v", err)
+ return nil, err
}
if len(response.Transfer) == 0 {
@@ -103,27 +116,99 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin
}
// Download reads the specific LFS object from the LFS server
-func (c *HTTPClient) Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) {
- var objects []Pointer
- objects = append(objects, Pointer{oid, size})
+func (c *HTTPClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error {
+ return c.performOperation(ctx, objects, callback, nil)
+}
+
+// Upload sends the specific LFS object to the LFS server
+func (c *HTTPClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error {
+ return c.performOperation(ctx, objects, nil, callback)
+}
- result, err := c.batch(ctx, "download", objects)
+func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc DownloadCallback, uc UploadCallback) error {
+ if len(objects) == 0 {
+ return nil
+ }
+
+ operation := "download"
+ if uc != nil {
+ operation = "upload"
+ }
+
+ result, err := c.batch(ctx, operation, objects)
if err != nil {
- return nil, err
+ return err
}
transferAdapter, ok := c.transfers[result.Transfer]
if !ok {
- return nil, fmt.Errorf("lfs.HTTPClient.Download Transferadapter not found: %s", result.Transfer)
+ return fmt.Errorf("TransferAdapter not found: %s", result.Transfer)
}
- if len(result.Objects) == 0 {
- return nil, errors.New("lfs.HTTPClient.Download: No objects in result")
- }
+ for _, object := range result.Objects {
+ if object.Error != nil {
+ objectError := errors.New(object.Error.Message)
+ log.Trace("Error on object %v: %v", object.Pointer, objectError)
+ if uc != nil {
+ if _, err := uc(object.Pointer, objectError); err != nil {
+ return err
+ }
+ } else {
+ if err := dc(object.Pointer, nil, objectError); err != nil {
+ return err
+ }
+ }
+ continue
+ }
- content, err := transferAdapter.Download(ctx, result.Objects[0])
- if err != nil {
- return nil, err
+ if uc != nil {
+ if len(object.Actions) == 0 {
+ log.Trace("%v already present on server", object.Pointer)
+ continue
+ }
+
+ link, ok := object.Actions["upload"]
+ if !ok {
+ log.Debug("%+v", object)
+ return errors.New("Missing action 'upload'")
+ }
+
+ content, err := uc(object.Pointer, nil)
+ if err != nil {
+ return err
+ }
+
+ err = transferAdapter.Upload(ctx, link, object.Pointer, content)
+
+ content.Close()
+
+ if err != nil {
+ return err
+ }
+
+ link, ok = object.Actions["verify"]
+ if ok {
+ if err := transferAdapter.Verify(ctx, link, object.Pointer); err != nil {
+ return err
+ }
+ }
+ } else {
+ link, ok := object.Actions["download"]
+ if !ok {
+ log.Debug("%+v", object)
+ return errors.New("Missing action 'download'")
+ }
+
+ content, err := transferAdapter.Download(ctx, link)
+ if err != nil {
+ return err
+ }
+
+ if err := dc(object.Pointer, content, nil); err != nil {
+ return err
+ }
+ }
}
- return content, nil
+
+ return nil
}
diff --git a/modules/lfs/http_client_test.go b/modules/lfs/http_client_test.go
index 68ec947aa8..0f633ede54 100644
--- a/modules/lfs/http_client_test.go
+++ b/modules/lfs/http_client_test.go
@@ -7,13 +7,13 @@ package lfs
import (
"bytes"
"context"
- "encoding/json"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
+ jsoniter "github.com/json-iterator/go"
"github.com/stretchr/testify/assert"
)
@@ -30,69 +30,253 @@ func (a *DummyTransferAdapter) Name() string {
return "dummy"
}
-func (a *DummyTransferAdapter) Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) {
+func (a *DummyTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewBufferString("dummy")), nil
}
-func TestHTTPClientDownload(t *testing.T) {
- oid := "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041"
- size := int64(6)
+func (a *DummyTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error {
+ return nil
+}
+
+func (a *DummyTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error {
+ return nil
+}
+
+func lfsTestRoundtripHandler(req *http.Request) *http.Response {
+ var batchResponse *BatchResponse
+ url := req.URL.String()
- roundTripHandler := func(req *http.Request) *http.Response {
- url := req.URL.String()
- if strings.Contains(url, "status-not-ok") {
- return &http.Response{StatusCode: http.StatusBadRequest}
+ if strings.Contains(url, "status-not-ok") {
+ return &http.Response{StatusCode: http.StatusBadRequest}
+ } else if strings.Contains(url, "invalid-json-response") {
+ return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("invalid json"))}
+ } else if strings.Contains(url, "valid-batch-request-download") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{
+ "download": {},
+ },
+ },
+ },
+ }
+ } else if strings.Contains(url, "valid-batch-request-upload") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{
+ "upload": {},
+ },
+ },
+ },
}
- if strings.Contains(url, "invalid-json-response") {
- return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("invalid json"))}
+ } else if strings.Contains(url, "response-no-objects") {
+ batchResponse = &BatchResponse{Transfer: "dummy"}
+ } else if strings.Contains(url, "unknown-transfer-adapter") {
+ batchResponse = &BatchResponse{Transfer: "unknown_adapter"}
+ } else if strings.Contains(url, "error-in-response-objects") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Error: &ObjectError{
+ Code: 404,
+ Message: "Object not found",
+ },
+ },
+ },
}
- if strings.Contains(url, "valid-batch-request-download") {
- assert.Equal(t, "POST", req.Method)
- assert.Equal(t, MediaType, req.Header.Get("Content-type"), "case %s: error should match", url)
- assert.Equal(t, MediaType, req.Header.Get("Accept"), "case %s: error should match", url)
+ } else if strings.Contains(url, "empty-actions-map") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{},
+ },
+ },
+ }
+ } else if strings.Contains(url, "download-actions-map") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{
+ "download": {},
+ },
+ },
+ },
+ }
+ } else if strings.Contains(url, "upload-actions-map") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{
+ "upload": {},
+ },
+ },
+ },
+ }
+ } else if strings.Contains(url, "verify-actions-map") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{
+ "verify": {},
+ },
+ },
+ },
+ }
+ } else if strings.Contains(url, "unknown-actions-map") {
+ batchResponse = &BatchResponse{
+ Transfer: "dummy",
+ Objects: []*ObjectResponse{
+ {
+ Actions: map[string]*Link{
+ "unknown": {},
+ },
+ },
+ },
+ }
+ } else {
+ return nil
+ }
- var batchRequest BatchRequest
- err := json.NewDecoder(req.Body).Decode(&batchRequest)
- assert.NoError(t, err)
+ payload := new(bytes.Buffer)
+ jsoniter.NewEncoder(payload).Encode(batchResponse)
- assert.Equal(t, "download", batchRequest.Operation)
- assert.Len(t, batchRequest.Objects, 1)
- assert.Equal(t, oid, batchRequest.Objects[0].Oid)
- assert.Equal(t, size, batchRequest.Objects[0].Size)
+ return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)}
+}
- batchResponse := &BatchResponse{
- Transfer: "dummy",
- Objects: make([]*ObjectResponse, 1),
- }
+func TestHTTPClientDownload(t *testing.T) {
+ p := Pointer{Oid: "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041", Size: 6}
- payload := new(bytes.Buffer)
- json.NewEncoder(payload).Encode(batchResponse)
+ hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) *http.Response {
+ assert.Equal(t, "POST", req.Method)
+ assert.Equal(t, MediaType, req.Header.Get("Content-type"))
+ assert.Equal(t, MediaType, req.Header.Get("Accept"))
- return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)}
- }
- if strings.Contains(url, "invalid-response-no-objects") {
- batchResponse := &BatchResponse{Transfer: "dummy"}
+ var batchRequest BatchRequest
+ err := jsoniter.NewDecoder(req.Body).Decode(&batchRequest)
+ assert.NoError(t, err)
- payload := new(bytes.Buffer)
- json.NewEncoder(payload).Encode(batchResponse)
+ assert.Equal(t, "download", batchRequest.Operation)
+ assert.Equal(t, 1, len(batchRequest.Objects))
+ assert.Equal(t, p.Oid, batchRequest.Objects[0].Oid)
+ assert.Equal(t, p.Size, batchRequest.Objects[0].Size)
- return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)}
- }
- if strings.Contains(url, "unknown-transfer-adapter") {
- batchResponse := &BatchResponse{Transfer: "unknown_adapter"}
+ return lfsTestRoundtripHandler(req)
+ })}
+ dummy := &DummyTransferAdapter{}
- payload := new(bytes.Buffer)
- json.NewEncoder(payload).Encode(batchResponse)
+ var cases = []struct {
+ endpoint string
+ expectederror string
+ }{
+ // case 0
+ {
+ endpoint: "https://status-not-ok.io",
+ expectederror: "Unexpected server response: ",
+ },
+ // case 1
+ {
+ endpoint: "https://invalid-json-response.io",
+ expectederror: "invalid json",
+ },
+ // case 2
+ {
+ endpoint: "https://valid-batch-request-download.io",
+ expectederror: "",
+ },
+ // case 3
+ {
+ endpoint: "https://response-no-objects.io",
+ expectederror: "",
+ },
+ // case 4
+ {
+ endpoint: "https://unknown-transfer-adapter.io",
+ expectederror: "TransferAdapter not found: ",
+ },
+ // case 5
+ {
+ endpoint: "https://error-in-response-objects.io",
+ expectederror: "Object not found",
+ },
+ // case 6
+ {
+ endpoint: "https://empty-actions-map.io",
+ expectederror: "Missing action 'download'",
+ },
+ // case 7
+ {
+ endpoint: "https://download-actions-map.io",
+ expectederror: "",
+ },
+ // case 8
+ {
+ endpoint: "https://upload-actions-map.io",
+ expectederror: "Missing action 'download'",
+ },
+ // case 9
+ {
+ endpoint: "https://verify-actions-map.io",
+ expectederror: "Missing action 'download'",
+ },
+ // case 10
+ {
+ endpoint: "https://unknown-actions-map.io",
+ expectederror: "Missing action 'download'",
+ },
+ }
- return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)}
+ for n, c := range cases {
+ client := &HTTPClient{
+ client: hc,
+ endpoint: c.endpoint,
+ transfers: make(map[string]TransferAdapter),
}
+ client.transfers["dummy"] = dummy
- t.Errorf("Unknown test case: %s", url)
-
- return nil
+ err := client.Download(context.Background(), []Pointer{p}, func(p Pointer, content io.ReadCloser, objectError error) error {
+ if objectError != nil {
+ return objectError
+ }
+ b, err := io.ReadAll(content)
+ assert.NoError(t, err)
+ assert.Equal(t, []byte("dummy"), b)
+ return nil
+ })
+ if len(c.expectederror) > 0 {
+ assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror)
+ } else {
+ assert.NoError(t, err, "case %d", n)
+ }
}
+}
+
+func TestHTTPClientUpload(t *testing.T) {
+ p := Pointer{Oid: "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041", Size: 6}
+
+ hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) *http.Response {
+ assert.Equal(t, "POST", req.Method)
+ assert.Equal(t, MediaType, req.Header.Get("Content-type"))
+ assert.Equal(t, MediaType, req.Header.Get("Accept"))
+
+ var batchRequest BatchRequest
+ err := jsoniter.NewDecoder(req.Body).Decode(&batchRequest)
+ assert.NoError(t, err)
- hc := &http.Client{Transport: RoundTripFunc(roundTripHandler)}
+ assert.Equal(t, "upload", batchRequest.Operation)
+ assert.Equal(t, 1, len(batchRequest.Objects))
+ assert.Equal(t, p.Oid, batchRequest.Objects[0].Oid)
+ assert.Equal(t, p.Size, batchRequest.Objects[0].Size)
+
+ return lfsTestRoundtripHandler(req)
+ })}
dummy := &DummyTransferAdapter{}
var cases = []struct {
@@ -102,27 +286,57 @@ func TestHTTPClientDownload(t *testing.T) {
// case 0
{
endpoint: "https://status-not-ok.io",
- expectederror: "Unexpected servers response: ",
+ expectederror: "Unexpected server response: ",
},
// case 1
{
endpoint: "https://invalid-json-response.io",
- expectederror: "json.Decode: ",
+ expectederror: "invalid json",
},
// case 2
{
- endpoint: "https://valid-batch-request-download.io",
+ endpoint: "https://valid-batch-request-upload.io",
expectederror: "",
},
// case 3
{
- endpoint: "https://invalid-response-no-objects.io",
- expectederror: "No objects in result",
+ endpoint: "https://response-no-objects.io",
+ expectederror: "",
},
// case 4
{
endpoint: "https://unknown-transfer-adapter.io",
- expectederror: "Transferadapter not found: ",
+ expectederror: "TransferAdapter not found: ",
+ },
+ // case 5
+ {
+ endpoint: "https://error-in-response-objects.io",
+ expectederror: "Object not found",
+ },
+ // case 6
+ {
+ endpoint: "https://empty-actions-map.io",
+ expectederror: "",
+ },
+ // case 7
+ {
+ endpoint: "https://download-actions-map.io",
+ expectederror: "Missing action 'upload'",
+ },
+ // case 8
+ {
+ endpoint: "https://upload-actions-map.io",
+ expectederror: "",
+ },
+ // case 9
+ {
+ endpoint: "https://verify-actions-map.io",
+ expectederror: "Missing action 'upload'",
+ },
+ // case 10
+ {
+ endpoint: "https://unknown-actions-map.io",
+ expectederror: "Missing action 'upload'",
},
}
@@ -134,7 +348,9 @@ func TestHTTPClientDownload(t *testing.T) {
}
client.transfers["dummy"] = dummy
- _, err := client.Download(context.Background(), oid, size)
+ err := client.Upload(context.Background(), []Pointer{p}, func(p Pointer, objectError error) (io.ReadCloser, error) {
+ return ioutil.NopCloser(new(bytes.Buffer)), objectError
+ })
if len(c.expectederror) > 0 {
assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror)
} else {
diff --git a/modules/lfs/shared.go b/modules/lfs/shared.go
index 9abbf85fbd..8343d12e1d 100644
--- a/modules/lfs/shared.go
+++ b/modules/lfs/shared.go
@@ -49,14 +49,14 @@ type ObjectResponse struct {
Error *ObjectError `json:"error,omitempty"`
}
-// Link provides a structure used to build a hypermedia representation of an HTTP link.
+// Link provides a structure with informations about how to access a object.
type Link struct {
Href string `json:"href"`
Header map[string]string `json:"header,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
-// ObjectError defines the JSON structure returned to the client in case of an error
+// ObjectError defines the JSON structure returned to the client in case of an error.
type ObjectError struct {
Code int `json:"code"`
Message string `json:"message"`
diff --git a/modules/lfs/transferadapter.go b/modules/lfs/transferadapter.go
index ea3aff0000..8c40ab8c04 100644
--- a/modules/lfs/transferadapter.go
+++ b/modules/lfs/transferadapter.go
@@ -5,18 +5,24 @@
package lfs
import (
+ "bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
+
+ "code.gitea.io/gitea/modules/log"
+
+ jsoniter "github.com/json-iterator/go"
)
// TransferAdapter represents an adapter for downloading/uploading LFS objects
type TransferAdapter interface {
Name() string
- Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error)
- //Upload(ctx context.Context, reader io.Reader) error
+ Download(ctx context.Context, l *Link) (io.ReadCloser, error)
+ Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error
+ Verify(ctx context.Context, l *Link, p Pointer) error
}
// BasicTransferAdapter implements the "basic" adapter
@@ -30,29 +36,101 @@ func (a *BasicTransferAdapter) Name() string {
}
// Download reads the download location and downloads the data
-func (a *BasicTransferAdapter) Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) {
- download, ok := r.Actions["download"]
- if !ok {
- return nil, errors.New("lfs.BasicTransferAdapter.Download: Action 'download' not found")
+func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) {
+ resp, err := a.performRequest(ctx, "GET", l, nil, nil)
+ if err != nil {
+ return nil, err
}
+ return resp.Body, nil
+}
- req, err := http.NewRequestWithContext(ctx, "GET", download.Href, nil)
+// Upload sends the content to the LFS server
+func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error {
+ _, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) {
+ if len(req.Header.Get("Content-Type")) == 0 {
+ req.Header.Set("Content-Type", "application/octet-stream")
+ }
+
+ if req.Header.Get("Transfer-Encoding") == "chunked" {
+ req.TransferEncoding = []string{"chunked"}
+ }
+
+ req.ContentLength = p.Size
+ })
if err != nil {
- return nil, fmt.Errorf("lfs.BasicTransferAdapter.Download http.NewRequestWithContext: %w", err)
+ return err
}
- for key, value := range download.Header {
+ return nil
+}
+
+// Verify calls the verify handler on the LFS server
+func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error {
+ b, err := jsoniter.Marshal(p)
+ if err != nil {
+ log.Error("Error encoding json: %v", err)
+ return err
+ }
+
+ _, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) {
+ req.Header.Set("Content-Type", MediaType)
+ })
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) {
+ log.Trace("Calling: %s %s", method, l.Href)
+
+ req, err := http.NewRequestWithContext(ctx, method, l.Href, body)
+ if err != nil {
+ log.Error("Error creating request: %v", err)
+ return nil, err
+ }
+ for key, value := range l.Header {
req.Header.Set(key, value)
}
+ req.Header.Set("Accept", MediaType)
+
+ if callback != nil {
+ callback(req)
+ }
res, err := a.client.Do(req)
if err != nil {
select {
case <-ctx.Done():
- return nil, ctx.Err()
+ return res, ctx.Err()
default:
}
- return nil, fmt.Errorf("lfs.BasicTransferAdapter.Download http.Do: %w", err)
+ log.Error("Error while processing request: %v", err)
+ return res, err
+ }
+
+ if res.StatusCode != http.StatusOK {
+ return res, handleErrorResponse(res)
+ }
+
+ return res, nil
+}
+
+func handleErrorResponse(resp *http.Response) error {
+ defer resp.Body.Close()
+
+ er, err := decodeReponseError(resp.Body)
+ if err != nil {
+ return fmt.Errorf("Request failed with status %s", resp.Status)
}
+ log.Trace("ErrorRespone: %v", er)
+ return errors.New(er.Message)
+}
- return res.Body, nil
+func decodeReponseError(r io.Reader) (ErrorResponse, error) {
+ var er ErrorResponse
+ err := jsoniter.NewDecoder(r).Decode(&er)
+ if err != nil {
+ log.Error("Error decoding json: %v", err)
+ }
+ return er, err
}
diff --git a/modules/lfs/transferadapter_test.go b/modules/lfs/transferadapter_test.go
index 0eabd3faee..7dfdad417e 100644
--- a/modules/lfs/transferadapter_test.go
+++ b/modules/lfs/transferadapter_test.go
@@ -7,11 +7,13 @@ package lfs
import (
"bytes"
"context"
+ "io"
"io/ioutil"
"net/http"
"strings"
"testing"
+ jsoniter "github.com/json-iterator/go"
"github.com/stretchr/testify/assert"
)
@@ -21,58 +23,151 @@ func TestBasicTransferAdapterName(t *testing.T) {
assert.Equal(t, "basic", a.Name())
}
-func TestBasicTransferAdapterDownload(t *testing.T) {
+func TestBasicTransferAdapter(t *testing.T) {
+ p := Pointer{Oid: "b5a2c96250612366ea272ffac6d9744aaf4b45aacd96aa7cfcb931ee3b558259", Size: 5}
+
roundTripHandler := func(req *http.Request) *http.Response {
+ assert.Equal(t, MediaType, req.Header.Get("Accept"))
+ assert.Equal(t, "test-value", req.Header.Get("test-header"))
+
url := req.URL.String()
- if strings.Contains(url, "valid-download-request") {
+ if strings.Contains(url, "download-request") {
assert.Equal(t, "GET", req.Method)
- assert.Equal(t, "test-value", req.Header.Get("test-header"))
return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("dummy"))}
- }
+ } else if strings.Contains(url, "upload-request") {
+ assert.Equal(t, "PUT", req.Method)
+ assert.Equal(t, "application/octet-stream", req.Header.Get("Content-Type"))
+
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "dummy", string(b))
- t.Errorf("Unknown test case: %s", url)
+ return &http.Response{StatusCode: http.StatusOK}
+ } else if strings.Contains(url, "verify-request") {
+ assert.Equal(t, "POST", req.Method)
+ assert.Equal(t, MediaType, req.Header.Get("Content-Type"))
- return nil
+ var vp Pointer
+ err := jsoniter.NewDecoder(req.Body).Decode(&vp)
+ assert.NoError(t, err)
+ assert.Equal(t, p.Oid, vp.Oid)
+ assert.Equal(t, p.Size, vp.Size)
+
+ return &http.Response{StatusCode: http.StatusOK}
+ } else if strings.Contains(url, "error-response") {
+ er := &ErrorResponse{
+ Message: "Object not found",
+ }
+ payload := new(bytes.Buffer)
+ jsoniter.NewEncoder(payload).Encode(er)
+
+ return &http.Response{StatusCode: http.StatusNotFound, Body: ioutil.NopCloser(payload)}
+ } else {
+ t.Errorf("Unknown test case: %s", url)
+ return nil
+ }
}
hc := &http.Client{Transport: RoundTripFunc(roundTripHandler)}
a := &BasicTransferAdapter{hc}
- var cases = []struct {
- response *ObjectResponse
- expectederror string
- }{
- // case 0
- {
- response: &ObjectResponse{},
- expectederror: "Action 'download' not found",
- },
- // case 1
- {
- response: &ObjectResponse{
- Actions: map[string]*Link{"upload": nil},
+ t.Run("Download", func(t *testing.T) {
+ cases := []struct {
+ link *Link
+ expectederror string
+ }{
+ // case 0
+ {
+ link: &Link{
+ Href: "https://download-request.io",
+ Header: map[string]string{"test-header": "test-value"},
+ },
+ expectederror: "",
},
- expectederror: "Action 'download' not found",
- },
- // case 2
- {
- response: &ObjectResponse{
- Actions: map[string]*Link{"download": {
- Href: "https://valid-download-request.io",
+ // case 1
+ {
+ link: &Link{
+ Href: "https://error-response.io",
Header: map[string]string{"test-header": "test-value"},
- }},
+ },
+ expectederror: "Object not found",
},
- expectederror: "",
- },
- }
+ }
- for n, c := range cases {
- _, err := a.Download(context.Background(), c.response)
- if len(c.expectederror) > 0 {
- assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror)
- } else {
- assert.NoError(t, err, "case %d", n)
+ for n, c := range cases {
+ _, err := a.Download(context.Background(), c.link)
+ if len(c.expectederror) > 0 {
+ assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror)
+ } else {
+ assert.NoError(t, err, "case %d", n)
+ }
}
- }
+ })
+
+ t.Run("Upload", func(t *testing.T) {
+ cases := []struct {
+ link *Link
+ expectederror string
+ }{
+ // case 0
+ {
+ link: &Link{
+ Href: "https://upload-request.io",
+ Header: map[string]string{"test-header": "test-value"},
+ },
+ expectederror: "",
+ },
+ // case 1
+ {
+ link: &Link{
+ Href: "https://error-response.io",
+ Header: map[string]string{"test-header": "test-value"},
+ },
+ expectederror: "Object not found",
+ },
+ }
+
+ for n, c := range cases {
+ err := a.Upload(context.Background(), c.link, p, bytes.NewBufferString("dummy"))
+ if len(c.expectederror) > 0 {
+ assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror)
+ } else {
+ assert.NoError(t, err, "case %d", n)
+ }
+ }
+ })
+
+ t.Run("Verify", func(t *testing.T) {
+ cases := []struct {
+ link *Link
+ expectederror string
+ }{
+ // case 0
+ {
+ link: &Link{
+ Href: "https://verify-request.io",
+ Header: map[string]string{"test-header": "test-value"},
+ },
+ expectederror: "",
+ },
+ // case 1
+ {
+ link: &Link{
+ Href: "https://error-response.io",
+ Header: map[string]string{"test-header": "test-value"},
+ },
+ expectederror: "Object not found",
+ },
+ }
+
+ for n, c := range cases {
+ err := a.Verify(context.Background(), c.link, p)
+ if len(c.expectederror) > 0 {
+ assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror)
+ } else {
+ assert.NoError(t, err, "case %d", n)
+ }
+ }
+ })
}