diff options
author | KN4CK3R <admin@oldschoolhack.me> | 2021-06-14 19:20:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-14 19:20:43 +0200 |
commit | 440039c0cce18622b12da5677bf6585caed6070a (patch) | |
tree | 8f8532a2d40983b35b3fdb5460b47218b26bbd89 /modules | |
parent | 5d113bdd1905c73fb8071f420ae2d248202971f9 (diff) | |
download | gitea-440039c0cce18622b12da5677bf6585caed6070a.tar.gz gitea-440039c0cce18622b12da5677bf6585caed6070a.zip |
Add push to remote mirror repository (#15157)
* Added push mirror model.
* Integrated push mirror into queue.
* Moved methods into own file.
* Added basic implementation.
* Mirror wiki too.
* Removed duplicated method.
* Get url for different remotes.
* Added migration.
* Unified remote url access.
* Add/Remove push mirror remotes.
* Prevent hangs with missing credentials.
* Moved code between files.
* Changed sanitizer interface.
* Added push mirror backend methods.
* Only update the mirror remote.
* Limit refs on push.
* Added UI part.
* Added missing table.
* Delete mirror if repository gets removed.
* Changed signature. Handle object errors.
* Added upload method.
* Added "upload" unit tests.
* Added transfer adapter unit tests.
* Send correct headers.
* Added pushing of LFS objects.
* Added more logging.
* Simpler body handling.
* Process files in batches to reduce HTTP calls.
* Added created timestamp.
* Fixed invalid column name.
* Changed name to prevent xorm auto setting.
* Remove table header if empty.
* Strip exit code from error message.
* Added docs page about mirroring.
* Fixed date.
* Fixed merge errors.
* Moved test to integrations.
* Added push mirror test.
* Added test.
Diffstat (limited to 'modules')
-rw-r--r-- | modules/context/repo.go | 6 | ||||
-rw-r--r-- | modules/git/remote.go | 31 | ||||
-rw-r--r-- | modules/git/repo.go | 24 | ||||
-rw-r--r-- | modules/lfs/client.go | 10 | ||||
-rw-r--r-- | modules/lfs/client_test.go | 1 | ||||
-rw-r--r-- | modules/lfs/filesystem_client.go | 63 | ||||
-rw-r--r-- | modules/lfs/http_client.go | 131 | ||||
-rw-r--r-- | modules/lfs/http_client_test.go | 320 | ||||
-rw-r--r-- | modules/lfs/shared.go | 4 | ||||
-rw-r--r-- | modules/lfs/transferadapter.go | 102 | ||||
-rw-r--r-- | modules/lfs/transferadapter_test.go | 169 | ||||
-rw-r--r-- | modules/repository/repo.go | 115 | ||||
-rw-r--r-- | modules/task/migrate.go | 2 | ||||
-rw-r--r-- | modules/task/task.go | 2 | ||||
-rw-r--r-- | modules/templates/helper.go | 34 | ||||
-rw-r--r-- | modules/util/sanitize.go | 63 | ||||
-rw-r--r-- | modules/util/sanitize_test.go | 159 |
17 files changed, 1004 insertions, 232 deletions
diff --git a/modules/context/repo.go b/modules/context/repo.go index 3e48b34b3d..72d1cf4c85 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -360,13 +360,17 @@ func repoAssignment(ctx *Context, repo *models.Repository) { var err error ctx.Repo.Mirror, err = models.GetMirrorByRepoID(repo.ID) if err != nil { - ctx.ServerError("GetMirror", err) + ctx.ServerError("GetMirrorByRepoID", err) return } ctx.Data["MirrorEnablePrune"] = ctx.Repo.Mirror.EnablePrune ctx.Data["MirrorInterval"] = ctx.Repo.Mirror.Interval ctx.Data["Mirror"] = ctx.Repo.Mirror } + if err = repo.LoadPushMirrors(); err != nil { + ctx.ServerError("LoadPushMirrors", err) + return + } ctx.Repo.Repository = repo ctx.Data["RepoName"] = ctx.Repo.Repository.Name diff --git a/modules/git/remote.go b/modules/git/remote.go new file mode 100644 index 0000000000..7ba2b35a5e --- /dev/null +++ b/modules/git/remote.go @@ -0,0 +1,31 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package git + +import "net/url" + +// GetRemoteAddress returns the url of a specific remote of the repository. 
+func GetRemoteAddress(repoPath, remoteName string) (*url.URL, error) { + err := LoadGitVersion() + if err != nil { + return nil, err + } + var cmd *Command + if CheckGitVersionAtLeast("2.7") == nil { + cmd = NewCommand("remote", "get-url", remoteName) + } else { + cmd = NewCommand("config", "--get", "remote."+remoteName+".url") + } + + result, err := cmd.RunInDir(repoPath) + if err != nil { + return nil, err + } + + if len(result) > 0 { + result = result[:len(result)-1] + } + return url.Parse(result) +} diff --git a/modules/git/repo.go b/modules/git/repo.go index 515899ab04..e06cd43935 100644 --- a/modules/git/repo.go +++ b/modules/git/repo.go @@ -182,10 +182,12 @@ func Pull(repoPath string, opts PullRemoteOptions) error { // PushOptions options when push to remote type PushOptions struct { - Remote string - Branch string - Force bool - Env []string + Remote string + Branch string + Force bool + Mirror bool + Env []string + Timeout time.Duration } // Push pushs local commits to given remote branch. 
@@ -194,10 +196,20 @@ func Push(repoPath string, opts PushOptions) error { if opts.Force { cmd.AddArguments("-f") } - cmd.AddArguments("--", opts.Remote, opts.Branch) + if opts.Mirror { + cmd.AddArguments("--mirror") + } + cmd.AddArguments("--", opts.Remote) + if len(opts.Branch) > 0 { + cmd.AddArguments(opts.Branch) + } var outbuf, errbuf strings.Builder - err := cmd.RunInDirTimeoutEnvPipeline(opts.Env, -1, repoPath, &outbuf, &errbuf) + if opts.Timeout == 0 { + opts.Timeout = -1 + } + + err := cmd.RunInDirTimeoutEnvPipeline(opts.Env, opts.Timeout, repoPath, &outbuf, &errbuf) if err != nil { if strings.Contains(errbuf.String(), "non-fast-forward") { return &ErrPushOutOfDate{ diff --git a/modules/lfs/client.go b/modules/lfs/client.go index ae35919d77..0a21440f73 100644 --- a/modules/lfs/client.go +++ b/modules/lfs/client.go @@ -10,9 +10,17 @@ import ( "net/url" ) +// DownloadCallback gets called for every requested LFS object to process its content +type DownloadCallback func(p Pointer, content io.ReadCloser, objectError error) error + +// UploadCallback gets called for every requested LFS object to provide its content +type UploadCallback func(p Pointer, objectError error) (io.ReadCloser, error) + // Client is used to communicate with a LFS source type Client interface { - Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) + BatchSize() int + Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error + Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error } // NewClient creates a LFS client diff --git a/modules/lfs/client_test.go b/modules/lfs/client_test.go index d4eb005469..1040b39925 100644 --- a/modules/lfs/client_test.go +++ b/modules/lfs/client_test.go @@ -6,7 +6,6 @@ package lfs import ( "net/url" - "testing" "github.com/stretchr/testify/assert" diff --git a/modules/lfs/filesystem_client.go b/modules/lfs/filesystem_client.go index 3a51564a82..dc72981a9e 100644 --- 
a/modules/lfs/filesystem_client.go +++ b/modules/lfs/filesystem_client.go @@ -19,6 +19,11 @@ type FilesystemClient struct { lfsdir string } +// BatchSize returns the preferred size of batchs to process +func (c *FilesystemClient) BatchSize() int { + return 1 +} + func newFilesystemClient(endpoint *url.URL) *FilesystemClient { path, _ := util.FileURLToPath(endpoint) @@ -33,18 +38,56 @@ func (c *FilesystemClient) objectPath(oid string) string { return filepath.Join(c.lfsdir, oid[0:2], oid[2:4], oid) } -// Download reads the specific LFS object from the target repository -func (c *FilesystemClient) Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) { - objectPath := c.objectPath(oid) +// Download reads the specific LFS object from the target path +func (c *FilesystemClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error { + for _, object := range objects { + p := Pointer{object.Oid, object.Size} - if _, err := os.Stat(objectPath); os.IsNotExist(err) { - return nil, err - } + objectPath := c.objectPath(p.Oid) + + f, err := os.Open(objectPath) + if err != nil { + return err + } - file, err := os.Open(objectPath) - if err != nil { - return nil, err + if err := callback(p, f, nil); err != nil { + return err + } } + return nil +} + +// Upload writes the specific LFS object to the target path +func (c *FilesystemClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error { + for _, object := range objects { + p := Pointer{object.Oid, object.Size} + + objectPath := c.objectPath(p.Oid) - return file, nil + if err := os.MkdirAll(filepath.Dir(objectPath), os.ModePerm); err != nil { + return err + } + + content, err := callback(p, nil) + if err != nil { + return err + } + + err = func() error { + defer content.Close() + + f, err := os.Create(objectPath) + if err != nil { + return err + } + + _, err = io.Copy(f, content) + + return err + }() + if err != nil { + return err + } + } + return nil } 
diff --git a/modules/lfs/http_client.go b/modules/lfs/http_client.go index fb45defda1..e799b80831 100644 --- a/modules/lfs/http_client.go +++ b/modules/lfs/http_client.go @@ -7,17 +7,19 @@ package lfs import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "strings" "code.gitea.io/gitea/modules/log" + + jsoniter "github.com/json-iterator/go" ) +const batchSize = 20 + // HTTPClient is used to communicate with the LFS server // https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md type HTTPClient struct { @@ -26,6 +28,11 @@ type HTTPClient struct { transfers map[string]TransferAdapter } +// BatchSize returns the preferred size of batchs to process +func (c *HTTPClient) BatchSize() int { + return batchSize +} + func newHTTPClient(endpoint *url.URL) *HTTPClient { hc := &http.Client{} @@ -55,21 +62,25 @@ func (c *HTTPClient) transferNames() []string { } func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Pointer) (*BatchResponse, error) { + log.Trace("BATCH operation with objects: %v", objects) + url := fmt.Sprintf("%s/objects/batch", c.endpoint) request := &BatchRequest{operation, c.transferNames(), nil, objects} payload := new(bytes.Buffer) - err := json.NewEncoder(payload).Encode(request) + err := jsoniter.NewEncoder(payload).Encode(request) if err != nil { - return nil, fmt.Errorf("lfs.HTTPClient.batch json.Encode: %w", err) + log.Error("Error encoding json: %v", err) + return nil, err } - log.Trace("lfs.HTTPClient.batch NewRequestWithContext: %s", url) + log.Trace("Calling: %s", url) req, err := http.NewRequestWithContext(ctx, "POST", url, payload) if err != nil { - return nil, fmt.Errorf("lfs.HTTPClient.batch http.NewRequestWithContext: %w", err) + log.Error("Error creating request: %v", err) + return nil, err } req.Header.Set("Content-type", MediaType) req.Header.Set("Accept", MediaType) @@ -81,18 +92,20 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin return nil, 
ctx.Err() default: } - return nil, fmt.Errorf("lfs.HTTPClient.batch http.Do: %w", err) + log.Error("Error while processing request: %v", err) + return nil, err } defer res.Body.Close() if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("lfs.HTTPClient.batch: Unexpected servers response: %s", res.Status) + return nil, fmt.Errorf("Unexpected server response: %s", res.Status) } var response BatchResponse - err = json.NewDecoder(res.Body).Decode(&response) + err = jsoniter.NewDecoder(res.Body).Decode(&response) if err != nil { - return nil, fmt.Errorf("lfs.HTTPClient.batch json.Decode: %w", err) + log.Error("Error decoding json: %v", err) + return nil, err } if len(response.Transfer) == 0 { @@ -103,27 +116,99 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin } // Download reads the specific LFS object from the LFS server -func (c *HTTPClient) Download(ctx context.Context, oid string, size int64) (io.ReadCloser, error) { - var objects []Pointer - objects = append(objects, Pointer{oid, size}) +func (c *HTTPClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error { + return c.performOperation(ctx, objects, callback, nil) +} + +// Upload sends the specific LFS object to the LFS server +func (c *HTTPClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error { + return c.performOperation(ctx, objects, nil, callback) +} - result, err := c.batch(ctx, "download", objects) +func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc DownloadCallback, uc UploadCallback) error { + if len(objects) == 0 { + return nil + } + + operation := "download" + if uc != nil { + operation = "upload" + } + + result, err := c.batch(ctx, operation, objects) if err != nil { - return nil, err + return err } transferAdapter, ok := c.transfers[result.Transfer] if !ok { - return nil, fmt.Errorf("lfs.HTTPClient.Download Transferadapter not found: %s", result.Transfer) + return 
fmt.Errorf("TransferAdapter not found: %s", result.Transfer) } - if len(result.Objects) == 0 { - return nil, errors.New("lfs.HTTPClient.Download: No objects in result") - } + for _, object := range result.Objects { + if object.Error != nil { + objectError := errors.New(object.Error.Message) + log.Trace("Error on object %v: %v", object.Pointer, objectError) + if uc != nil { + if _, err := uc(object.Pointer, objectError); err != nil { + return err + } + } else { + if err := dc(object.Pointer, nil, objectError); err != nil { + return err + } + } + continue + } - content, err := transferAdapter.Download(ctx, result.Objects[0]) - if err != nil { - return nil, err + if uc != nil { + if len(object.Actions) == 0 { + log.Trace("%v already present on server", object.Pointer) + continue + } + + link, ok := object.Actions["upload"] + if !ok { + log.Debug("%+v", object) + return errors.New("Missing action 'upload'") + } + + content, err := uc(object.Pointer, nil) + if err != nil { + return err + } + + err = transferAdapter.Upload(ctx, link, object.Pointer, content) + + content.Close() + + if err != nil { + return err + } + + link, ok = object.Actions["verify"] + if ok { + if err := transferAdapter.Verify(ctx, link, object.Pointer); err != nil { + return err + } + } + } else { + link, ok := object.Actions["download"] + if !ok { + log.Debug("%+v", object) + return errors.New("Missing action 'download'") + } + + content, err := transferAdapter.Download(ctx, link) + if err != nil { + return err + } + + if err := dc(object.Pointer, content, nil); err != nil { + return err + } + } } - return content, nil + + return nil } diff --git a/modules/lfs/http_client_test.go b/modules/lfs/http_client_test.go index 68ec947aa8..0f633ede54 100644 --- a/modules/lfs/http_client_test.go +++ b/modules/lfs/http_client_test.go @@ -7,13 +7,13 @@ package lfs import ( "bytes" "context" - "encoding/json" "io" "io/ioutil" "net/http" "strings" "testing" + jsoniter "github.com/json-iterator/go" 
"github.com/stretchr/testify/assert" ) @@ -30,69 +30,253 @@ func (a *DummyTransferAdapter) Name() string { return "dummy" } -func (a *DummyTransferAdapter) Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) { +func (a *DummyTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewBufferString("dummy")), nil } -func TestHTTPClientDownload(t *testing.T) { - oid := "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041" - size := int64(6) +func (a *DummyTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error { + return nil +} + +func (a *DummyTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error { + return nil +} + +func lfsTestRoundtripHandler(req *http.Request) *http.Response { + var batchResponse *BatchResponse + url := req.URL.String() - roundTripHandler := func(req *http.Request) *http.Response { - url := req.URL.String() - if strings.Contains(url, "status-not-ok") { - return &http.Response{StatusCode: http.StatusBadRequest} + if strings.Contains(url, "status-not-ok") { + return &http.Response{StatusCode: http.StatusBadRequest} + } else if strings.Contains(url, "invalid-json-response") { + return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("invalid json"))} + } else if strings.Contains(url, "valid-batch-request-download") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "download": {}, + }, + }, + }, + } + } else if strings.Contains(url, "valid-batch-request-upload") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "upload": {}, + }, + }, + }, } - if strings.Contains(url, "invalid-json-response") { - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("invalid json"))} + } else if strings.Contains(url, 
"response-no-objects") { + batchResponse = &BatchResponse{Transfer: "dummy"} + } else if strings.Contains(url, "unknown-transfer-adapter") { + batchResponse = &BatchResponse{Transfer: "unknown_adapter"} + } else if strings.Contains(url, "error-in-response-objects") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Error: &ObjectError{ + Code: 404, + Message: "Object not found", + }, + }, + }, } - if strings.Contains(url, "valid-batch-request-download") { - assert.Equal(t, "POST", req.Method) - assert.Equal(t, MediaType, req.Header.Get("Content-type"), "case %s: error should match", url) - assert.Equal(t, MediaType, req.Header.Get("Accept"), "case %s: error should match", url) + } else if strings.Contains(url, "empty-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{}, + }, + }, + } + } else if strings.Contains(url, "download-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "download": {}, + }, + }, + }, + } + } else if strings.Contains(url, "upload-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "upload": {}, + }, + }, + }, + } + } else if strings.Contains(url, "verify-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "verify": {}, + }, + }, + }, + } + } else if strings.Contains(url, "unknown-actions-map") { + batchResponse = &BatchResponse{ + Transfer: "dummy", + Objects: []*ObjectResponse{ + { + Actions: map[string]*Link{ + "unknown": {}, + }, + }, + }, + } + } else { + return nil + } - var batchRequest BatchRequest - err := json.NewDecoder(req.Body).Decode(&batchRequest) - assert.NoError(t, err) + payload := new(bytes.Buffer) + jsoniter.NewEncoder(payload).Encode(batchResponse) - 
assert.Equal(t, "download", batchRequest.Operation) - assert.Len(t, batchRequest.Objects, 1) - assert.Equal(t, oid, batchRequest.Objects[0].Oid) - assert.Equal(t, size, batchRequest.Objects[0].Size) + return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} +} - batchResponse := &BatchResponse{ - Transfer: "dummy", - Objects: make([]*ObjectResponse, 1), - } +func TestHTTPClientDownload(t *testing.T) { + p := Pointer{Oid: "fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041", Size: 6} - payload := new(bytes.Buffer) - json.NewEncoder(payload).Encode(batchResponse) + hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) *http.Response { + assert.Equal(t, "POST", req.Method) + assert.Equal(t, MediaType, req.Header.Get("Content-type")) + assert.Equal(t, MediaType, req.Header.Get("Accept")) - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} - } - if strings.Contains(url, "invalid-response-no-objects") { - batchResponse := &BatchResponse{Transfer: "dummy"} + var batchRequest BatchRequest + err := jsoniter.NewDecoder(req.Body).Decode(&batchRequest) + assert.NoError(t, err) - payload := new(bytes.Buffer) - json.NewEncoder(payload).Encode(batchResponse) + assert.Equal(t, "download", batchRequest.Operation) + assert.Equal(t, 1, len(batchRequest.Objects)) + assert.Equal(t, p.Oid, batchRequest.Objects[0].Oid) + assert.Equal(t, p.Size, batchRequest.Objects[0].Size) - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} - } - if strings.Contains(url, "unknown-transfer-adapter") { - batchResponse := &BatchResponse{Transfer: "unknown_adapter"} + return lfsTestRoundtripHandler(req) + })} + dummy := &DummyTransferAdapter{} - payload := new(bytes.Buffer) - json.NewEncoder(payload).Encode(batchResponse) + var cases = []struct { + endpoint string + expectederror string + }{ + // case 0 + { + endpoint: "https://status-not-ok.io", + expectederror: "Unexpected server response: 
", + }, + // case 1 + { + endpoint: "https://invalid-json-response.io", + expectederror: "invalid json", + }, + // case 2 + { + endpoint: "https://valid-batch-request-download.io", + expectederror: "", + }, + // case 3 + { + endpoint: "https://response-no-objects.io", + expectederror: "", + }, + // case 4 + { + endpoint: "https://unknown-transfer-adapter.io", + expectederror: "TransferAdapter not found: ", + }, + // case 5 + { + endpoint: "https://error-in-response-objects.io", + expectederror: "Object not found", + }, + // case 6 + { + endpoint: "https://empty-actions-map.io", + expectederror: "Missing action 'download'", + }, + // case 7 + { + endpoint: "https://download-actions-map.io", + expectederror: "", + }, + // case 8 + { + endpoint: "https://upload-actions-map.io", + expectederror: "Missing action 'download'", + }, + // case 9 + { + endpoint: "https://verify-actions-map.io", + expectederror: "Missing action 'download'", + }, + // case 10 + { + endpoint: "https://unknown-actions-map.io", + expectederror: "Missing action 'download'", + }, + } - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(payload)} + for n, c := range cases { + client := &HTTPClient{ + client: hc, + endpoint: c.endpoint, + transfers: make(map[string]TransferAdapter), } + client.transfers["dummy"] = dummy - t.Errorf("Unknown test case: %s", url) - - return nil + err := client.Download(context.Background(), []Pointer{p}, func(p Pointer, content io.ReadCloser, objectError error) error { + if objectError != nil { + return objectError + } + b, err := io.ReadAll(content) + assert.NoError(t, err) + assert.Equal(t, []byte("dummy"), b) + return nil + }) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } } +} + +func TestHTTPClientUpload(t *testing.T) { + p := Pointer{Oid: 
"fb8f7d8435968c4f82a726a92395be4d16f2f63116caf36c8ad35c60831ab041", Size: 6} + + hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) *http.Response { + assert.Equal(t, "POST", req.Method) + assert.Equal(t, MediaType, req.Header.Get("Content-type")) + assert.Equal(t, MediaType, req.Header.Get("Accept")) + + var batchRequest BatchRequest + err := jsoniter.NewDecoder(req.Body).Decode(&batchRequest) + assert.NoError(t, err) - hc := &http.Client{Transport: RoundTripFunc(roundTripHandler)} + assert.Equal(t, "upload", batchRequest.Operation) + assert.Equal(t, 1, len(batchRequest.Objects)) + assert.Equal(t, p.Oid, batchRequest.Objects[0].Oid) + assert.Equal(t, p.Size, batchRequest.Objects[0].Size) + + return lfsTestRoundtripHandler(req) + })} dummy := &DummyTransferAdapter{} var cases = []struct { @@ -102,27 +286,57 @@ func TestHTTPClientDownload(t *testing.T) { // case 0 { endpoint: "https://status-not-ok.io", - expectederror: "Unexpected servers response: ", + expectederror: "Unexpected server response: ", }, // case 1 { endpoint: "https://invalid-json-response.io", - expectederror: "json.Decode: ", + expectederror: "invalid json", }, // case 2 { - endpoint: "https://valid-batch-request-download.io", + endpoint: "https://valid-batch-request-upload.io", expectederror: "", }, // case 3 { - endpoint: "https://invalid-response-no-objects.io", - expectederror: "No objects in result", + endpoint: "https://response-no-objects.io", + expectederror: "", }, // case 4 { endpoint: "https://unknown-transfer-adapter.io", - expectederror: "Transferadapter not found: ", + expectederror: "TransferAdapter not found: ", + }, + // case 5 + { + endpoint: "https://error-in-response-objects.io", + expectederror: "Object not found", + }, + // case 6 + { + endpoint: "https://empty-actions-map.io", + expectederror: "", + }, + // case 7 + { + endpoint: "https://download-actions-map.io", + expectederror: "Missing action 'upload'", + }, + // case 8 + { + endpoint: 
"https://upload-actions-map.io", + expectederror: "", + }, + // case 9 + { + endpoint: "https://verify-actions-map.io", + expectederror: "Missing action 'upload'", + }, + // case 10 + { + endpoint: "https://unknown-actions-map.io", + expectederror: "Missing action 'upload'", }, } @@ -134,7 +348,9 @@ func TestHTTPClientDownload(t *testing.T) { } client.transfers["dummy"] = dummy - _, err := client.Download(context.Background(), oid, size) + err := client.Upload(context.Background(), []Pointer{p}, func(p Pointer, objectError error) (io.ReadCloser, error) { + return ioutil.NopCloser(new(bytes.Buffer)), objectError + }) if len(c.expectederror) > 0 { assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) } else { diff --git a/modules/lfs/shared.go b/modules/lfs/shared.go index 9abbf85fbd..8343d12e1d 100644 --- a/modules/lfs/shared.go +++ b/modules/lfs/shared.go @@ -49,14 +49,14 @@ type ObjectResponse struct { Error *ObjectError `json:"error,omitempty"` } -// Link provides a structure used to build a hypermedia representation of an HTTP link. +// Link provides a structure with informations about how to access a object. type Link struct { Href string `json:"href"` Header map[string]string `json:"header,omitempty"` ExpiresAt *time.Time `json:"expires_at,omitempty"` } -// ObjectError defines the JSON structure returned to the client in case of an error +// ObjectError defines the JSON structure returned to the client in case of an error. 
type ObjectError struct { Code int `json:"code"` Message string `json:"message"` diff --git a/modules/lfs/transferadapter.go b/modules/lfs/transferadapter.go index ea3aff0000..8c40ab8c04 100644 --- a/modules/lfs/transferadapter.go +++ b/modules/lfs/transferadapter.go @@ -5,18 +5,24 @@ package lfs import ( + "bytes" "context" "errors" "fmt" "io" "net/http" + + "code.gitea.io/gitea/modules/log" + + jsoniter "github.com/json-iterator/go" ) // TransferAdapter represents an adapter for downloading/uploading LFS objects type TransferAdapter interface { Name() string - Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) - //Upload(ctx context.Context, reader io.Reader) error + Download(ctx context.Context, l *Link) (io.ReadCloser, error) + Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error + Verify(ctx context.Context, l *Link, p Pointer) error } // BasicTransferAdapter implements the "basic" adapter @@ -30,29 +36,101 @@ func (a *BasicTransferAdapter) Name() string { } // Download reads the download location and downloads the data -func (a *BasicTransferAdapter) Download(ctx context.Context, r *ObjectResponse) (io.ReadCloser, error) { - download, ok := r.Actions["download"] - if !ok { - return nil, errors.New("lfs.BasicTransferAdapter.Download: Action 'download' not found") +func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) { + resp, err := a.performRequest(ctx, "GET", l, nil, nil) + if err != nil { + return nil, err } + return resp.Body, nil +} - req, err := http.NewRequestWithContext(ctx, "GET", download.Href, nil) +// Upload sends the content to the LFS server +func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error { + _, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) { + if len(req.Header.Get("Content-Type")) == 0 { + req.Header.Set("Content-Type", "application/octet-stream") + } + + if req.Header.Get("Transfer-Encoding") == 
"chunked" { + req.TransferEncoding = []string{"chunked"} + } + + req.ContentLength = p.Size + }) if err != nil { - return nil, fmt.Errorf("lfs.BasicTransferAdapter.Download http.NewRequestWithContext: %w", err) + return err } - for key, value := range download.Header { + return nil +} + +// Verify calls the verify handler on the LFS server +func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error { + b, err := jsoniter.Marshal(p) + if err != nil { + log.Error("Error encoding json: %v", err) + return err + } + + _, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) { + req.Header.Set("Content-Type", MediaType) + }) + if err != nil { + return err + } + return nil +} + +func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) { + log.Trace("Calling: %s %s", method, l.Href) + + req, err := http.NewRequestWithContext(ctx, method, l.Href, body) + if err != nil { + log.Error("Error creating request: %v", err) + return nil, err + } + for key, value := range l.Header { req.Header.Set(key, value) } + req.Header.Set("Accept", MediaType) + + if callback != nil { + callback(req) + } res, err := a.client.Do(req) if err != nil { select { case <-ctx.Done(): - return nil, ctx.Err() + return res, ctx.Err() default: } - return nil, fmt.Errorf("lfs.BasicTransferAdapter.Download http.Do: %w", err) + log.Error("Error while processing request: %v", err) + return res, err + } + + if res.StatusCode != http.StatusOK { + return res, handleErrorResponse(res) + } + + return res, nil +} + +func handleErrorResponse(resp *http.Response) error { + defer resp.Body.Close() + + er, err := decodeReponseError(resp.Body) + if err != nil { + return fmt.Errorf("Request failed with status %s", resp.Status) } + log.Trace("ErrorRespone: %v", er) + return errors.New(er.Message) +} - return res.Body, nil +func decodeReponseError(r io.Reader) 
(ErrorResponse, error) { + var er ErrorResponse + err := jsoniter.NewDecoder(r).Decode(&er) + if err != nil { + log.Error("Error decoding json: %v", err) + } + return er, err } diff --git a/modules/lfs/transferadapter_test.go b/modules/lfs/transferadapter_test.go index 0eabd3faee..7dfdad417e 100644 --- a/modules/lfs/transferadapter_test.go +++ b/modules/lfs/transferadapter_test.go @@ -7,11 +7,13 @@ package lfs import ( "bytes" "context" + "io" "io/ioutil" "net/http" "strings" "testing" + jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" ) @@ -21,58 +23,151 @@ func TestBasicTransferAdapterName(t *testing.T) { assert.Equal(t, "basic", a.Name()) } -func TestBasicTransferAdapterDownload(t *testing.T) { +func TestBasicTransferAdapter(t *testing.T) { + p := Pointer{Oid: "b5a2c96250612366ea272ffac6d9744aaf4b45aacd96aa7cfcb931ee3b558259", Size: 5} + roundTripHandler := func(req *http.Request) *http.Response { + assert.Equal(t, MediaType, req.Header.Get("Accept")) + assert.Equal(t, "test-value", req.Header.Get("test-header")) + url := req.URL.String() - if strings.Contains(url, "valid-download-request") { + if strings.Contains(url, "download-request") { assert.Equal(t, "GET", req.Method) - assert.Equal(t, "test-value", req.Header.Get("test-header")) return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(bytes.NewBufferString("dummy"))} - } + } else if strings.Contains(url, "upload-request") { + assert.Equal(t, "PUT", req.Method) + assert.Equal(t, "application/octet-stream", req.Header.Get("Content-Type")) + + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, "dummy", string(b)) - t.Errorf("Unknown test case: %s", url) + return &http.Response{StatusCode: http.StatusOK} + } else if strings.Contains(url, "verify-request") { + assert.Equal(t, "POST", req.Method) + assert.Equal(t, MediaType, req.Header.Get("Content-Type")) - return nil + var vp Pointer + err := jsoniter.NewDecoder(req.Body).Decode(&vp) + 
assert.NoError(t, err) + assert.Equal(t, p.Oid, vp.Oid) + assert.Equal(t, p.Size, vp.Size) + + return &http.Response{StatusCode: http.StatusOK} + } else if strings.Contains(url, "error-response") { + er := &ErrorResponse{ + Message: "Object not found", + } + payload := new(bytes.Buffer) + jsoniter.NewEncoder(payload).Encode(er) + + return &http.Response{StatusCode: http.StatusNotFound, Body: ioutil.NopCloser(payload)} + } else { + t.Errorf("Unknown test case: %s", url) + return nil + } } hc := &http.Client{Transport: RoundTripFunc(roundTripHandler)} a := &BasicTransferAdapter{hc} - var cases = []struct { - response *ObjectResponse - expectederror string - }{ - // case 0 - { - response: &ObjectResponse{}, - expectederror: "Action 'download' not found", - }, - // case 1 - { - response: &ObjectResponse{ - Actions: map[string]*Link{"upload": nil}, + t.Run("Download", func(t *testing.T) { + cases := []struct { + link *Link + expectederror string + }{ + // case 0 + { + link: &Link{ + Href: "https://download-request.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "", }, - expectederror: "Action 'download' not found", - }, - // case 2 - { - response: &ObjectResponse{ - Actions: map[string]*Link{"download": { - Href: "https://valid-download-request.io", + // case 1 + { + link: &Link{ + Href: "https://error-response.io", Header: map[string]string{"test-header": "test-value"}, - }}, + }, + expectederror: "Object not found", }, - expectederror: "", - }, - } + } - for n, c := range cases { - _, err := a.Download(context.Background(), c.response) - if len(c.expectederror) > 0 { - assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) - } else { - assert.NoError(t, err, "case %d", n) + for n, c := range cases { + _, err := a.Download(context.Background(), c.link) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case 
%d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } } - } + }) + + t.Run("Upload", func(t *testing.T) { + cases := []struct { + link *Link + expectederror string + }{ + // case 0 + { + link: &Link{ + Href: "https://upload-request.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "", + }, + // case 1 + { + link: &Link{ + Href: "https://error-response.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "Object not found", + }, + } + + for n, c := range cases { + err := a.Upload(context.Background(), c.link, p, bytes.NewBufferString("dummy")) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } + } + }) + + t.Run("Verify", func(t *testing.T) { + cases := []struct { + link *Link + expectederror string + }{ + // case 0 + { + link: &Link{ + Href: "https://verify-request.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "", + }, + // case 1 + { + link: &Link{ + Href: "https://error-response.io", + Header: map[string]string{"test-header": "test-value"}, + }, + expectederror: "Object not found", + }, + } + + for n, c := range cases { + err := a.Verify(context.Background(), c.link, p) + if len(c.expectederror) > 0 { + assert.True(t, strings.Contains(err.Error(), c.expectederror), "case %d: '%s' should contain '%s'", n, err.Error(), c.expectederror) + } else { + assert.NoError(t, err, "case %d", n) + } + } + }) } diff --git a/modules/repository/repo.go b/modules/repository/repo.go index 50eb185daa..08531c04ed 100644 --- a/modules/repository/repo.go +++ b/modules/repository/repo.go @@ -7,6 +7,7 @@ package repository import ( "context" "fmt" + "io" "net/url" "path" "strings" @@ -323,64 +324,90 @@ func StoreMissingLfsObjectsInRepository(ctx 
context.Context, repo *models.Reposi errChan := make(chan error, 1) go lfs.SearchPointerBlobs(ctx, gitRepo, pointerChan, errChan) - err := func() error { - for pointerBlob := range pointerChan { - meta, err := models.NewLFSMetaObject(&models.LFSMetaObject{Pointer: pointerBlob.Pointer, RepositoryID: repo.ID}) - if err != nil { - return fmt.Errorf("StoreMissingLfsObjectsInRepository models.NewLFSMetaObject: %w", err) - } - if meta.Existing { - continue + downloadObjects := func(pointers []lfs.Pointer) error { + err := client.Download(ctx, pointers, func(p lfs.Pointer, content io.ReadCloser, objectError error) error { + if objectError != nil { + return objectError } - log.Trace("StoreMissingLfsObjectsInRepository: LFS OID[%s] not present in repository %s", pointerBlob.Oid, repo.FullName()) + defer content.Close() - err = func() error { - exist, err := contentStore.Exists(pointerBlob.Pointer) - if err != nil { - return fmt.Errorf("StoreMissingLfsObjectsInRepository contentStore.Exists: %w", err) - } - if !exist { - if setting.LFS.MaxFileSize > 0 && pointerBlob.Size > setting.LFS.MaxFileSize { - log.Info("LFS OID[%s] download denied because of LFS_MAX_FILE_SIZE=%d < size %d", pointerBlob.Oid, setting.LFS.MaxFileSize, pointerBlob.Size) - return nil - } - - stream, err := client.Download(ctx, pointerBlob.Oid, pointerBlob.Size) - if err != nil { - return fmt.Errorf("StoreMissingLfsObjectsInRepository: LFS OID[%s] failed to download: %w", pointerBlob.Oid, err) - } - defer stream.Close() - - if err := contentStore.Put(pointerBlob.Pointer, stream); err != nil { - return fmt.Errorf("StoreMissingLfsObjectsInRepository LFS OID[%s] contentStore.Put: %w", pointerBlob.Oid, err) - } - } else { - log.Trace("StoreMissingLfsObjectsInRepository: LFS OID[%s] already present in content store", pointerBlob.Oid) - } - return nil - }() + _, err := models.NewLFSMetaObject(&models.LFSMetaObject{Pointer: p, RepositoryID: repo.ID}) if err != nil { - if _, err2 := 
repo.RemoveLFSMetaObjectByOid(meta.Oid); err2 != nil { - log.Error("StoreMissingLfsObjectsInRepository RemoveLFSMetaObjectByOid[Oid: %s]: %w", meta.Oid, err2) - } + log.Error("Error creating LFS meta object %v: %v", p, err) + return err + } - select { - case <-ctx.Done(): - return nil - default: + if err := contentStore.Put(p, content); err != nil { + log.Error("Error storing content for LFS meta object %v: %v", p, err) + if _, err2 := repo.RemoveLFSMetaObjectByOid(p.Oid); err2 != nil { + log.Error("Error removing LFS meta object %v: %v", p, err2) } return err } + return nil + }) + if err != nil { + select { + case <-ctx.Done(): + return nil + default: + } } - return nil - }() - if err != nil { return err } + var batch []lfs.Pointer + for pointerBlob := range pointerChan { + meta, err := repo.GetLFSMetaObjectByOid(pointerBlob.Oid) + if err != nil && err != models.ErrLFSObjectNotExist { + log.Error("Error querying LFS meta object %v: %v", pointerBlob.Pointer, err) + return err + } + if meta != nil { + log.Trace("Skipping unknown LFS meta object %v", pointerBlob.Pointer) + continue + } + + log.Trace("LFS object %v not present in repository %s", pointerBlob.Pointer, repo.FullName()) + + exist, err := contentStore.Exists(pointerBlob.Pointer) + if err != nil { + log.Error("Error checking if LFS object %v exists: %v", pointerBlob.Pointer, err) + return err + } + + if exist { + log.Trace("LFS object %v already present; creating meta object", pointerBlob.Pointer) + _, err := models.NewLFSMetaObject(&models.LFSMetaObject{Pointer: pointerBlob.Pointer, RepositoryID: repo.ID}) + if err != nil { + log.Error("Error creating LFS meta object %v: %v", pointerBlob.Pointer, err) + return err + } + } else { + if setting.LFS.MaxFileSize > 0 && pointerBlob.Size > setting.LFS.MaxFileSize { + log.Info("LFS object %v download denied because of LFS_MAX_FILE_SIZE=%d < size %d", pointerBlob.Pointer, setting.LFS.MaxFileSize, pointerBlob.Size) + continue + } + + batch = append(batch, 
pointerBlob.Pointer) + if len(batch) >= client.BatchSize() { + if err := downloadObjects(batch); err != nil { + return err + } + batch = nil + } + } + } + if len(batch) > 0 { + if err := downloadObjects(batch); err != nil { + return err + } + } + err, has := <-errChan if has { + log.Error("Error enumerating LFS objects for repository: %v", err) return err } diff --git a/modules/task/migrate.go b/modules/task/migrate.go index 57424abac3..fe9b984d44 100644 --- a/modules/task/migrate.go +++ b/modules/task/migrate.go @@ -118,7 +118,7 @@ func runMigrateTask(t *models.Task) (err error) { } // remoteAddr may contain credentials, so we sanitize it - err = util.URLSanitizedError(err, opts.CloneAddr) + err = util.NewStringURLSanitizedError(err, opts.CloneAddr, true) if strings.Contains(err.Error(), "Authentication failed") || strings.Contains(err.Error(), "could not read Username") { return fmt.Errorf("Authentication failed: %v", err.Error()) diff --git a/modules/task/task.go b/modules/task/task.go index 0685aa23d7..1c0a87e1f6 100644 --- a/modules/task/task.go +++ b/modules/task/task.go @@ -74,7 +74,7 @@ func CreateMigrateTask(doer, u *models.User, opts base.MigrateOptions) (*models. 
if err != nil { return nil, err } - opts.CloneAddr = util.SanitizeURLCredentials(opts.CloneAddr, true) + opts.CloneAddr = util.NewStringURLSanitizer(opts.CloneAddr, true).Replace(opts.CloneAddr) opts.AuthPasswordEncrypted, err = secret.EncryptSecret(setting.SecretKey, opts.AuthPassword) if err != nil { return nil, err diff --git a/modules/templates/helper.go b/modules/templates/helper.go index 9922cfb225..83359a6ef2 100644 --- a/modules/templates/helper.go +++ b/modules/templates/helper.go @@ -27,6 +27,7 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/emoji" + "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/repository" @@ -35,7 +36,6 @@ import ( "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/gitdiff" - mirror_service "code.gitea.io/gitea/services/mirror" "github.com/editorconfig/editorconfig-core-go/v2" jsoniter "github.com/json-iterator/go" @@ -294,11 +294,8 @@ func NewFuncMap() []template.FuncMap { } return float32(n) * 100 / float32(sum) }, - "CommentMustAsDiff": gitdiff.CommentMustAsDiff, - "MirrorAddress": mirror_service.Address, - "MirrorFullAddress": mirror_service.AddressNoCredentials, - "MirrorUserName": mirror_service.Username, - "MirrorPassword": mirror_service.Password, + "CommentMustAsDiff": gitdiff.CommentMustAsDiff, + "MirrorRemoteAddress": mirrorRemoteAddress, "CommitType": func(commit interface{}) string { switch commit.(type) { case models.SignCommitWithStatuses: @@ -963,3 +960,28 @@ func buildSubjectBodyTemplate(stpl *texttmpl.Template, btpl *template.Template, log.Warn("Failed to parse template [%s/body]: %v", name, err) } } + +type remoteAddress struct { + Address string + Username string + Password string +} + +func mirrorRemoteAddress(m models.RemoteMirrorer) remoteAddress { + a := remoteAddress{} + + u, err := 
git.GetRemoteAddress(m.GetRepository().RepoPath(), m.GetRemoteName()) + if err != nil { + log.Error("GetRemoteAddress %v", err) + return a + } + + if u.User != nil { + a.Username = u.User.Username() + a.Password, _ = u.User.Password() + } + u.User = nil + a.Address = u.String() + + return a +} diff --git a/modules/util/sanitize.go b/modules/util/sanitize.go index a4f5479dfb..de59ffaa2e 100644 --- a/modules/util/sanitize.go +++ b/modules/util/sanitize.go @@ -1,4 +1,4 @@ -// Copyright 2017 The Gitea Authors. All rights reserved. +// Copyright 2021 The Gitea Authors. All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. @@ -9,40 +9,53 @@ import ( "strings" ) -// urlSafeError wraps an error whose message may contain a sensitive URL -type urlSafeError struct { - err error - unsanitizedURL string +const userPlaceholder = "sanitized-credential" +const unparsableURL = "(unparsable url)" + +type sanitizedError struct { + err error + replacer *strings.Replacer } -func (err urlSafeError) Error() string { - return SanitizeMessage(err.err.Error(), err.unsanitizedURL) +func (err sanitizedError) Error() string { + return err.replacer.Replace(err.err.Error()) } -// URLSanitizedError returns the sanitized version an error whose message may -// contain a sensitive URL -func URLSanitizedError(err error, unsanitizedURL string) error { - return urlSafeError{err: err, unsanitizedURL: unsanitizedURL} +// NewSanitizedError wraps an error and replaces all old, new string pairs in the message text. 
+func NewSanitizedError(err error, oldnew ...string) error { + return sanitizedError{err: err, replacer: strings.NewReplacer(oldnew...)} } -// SanitizeMessage sanitizes a message which may contains a sensitive URL -func SanitizeMessage(message, unsanitizedURL string) string { - sanitizedURL := SanitizeURLCredentials(unsanitizedURL, true) - return strings.ReplaceAll(message, unsanitizedURL, sanitizedURL) +// NewURLSanitizedError wraps an error and replaces the url credential or removes them. +func NewURLSanitizedError(err error, u *url.URL, usePlaceholder bool) error { + return sanitizedError{err: err, replacer: NewURLSanitizer(u, usePlaceholder)} } -// SanitizeURLCredentials sanitizes a url, either removing user credentials -// or replacing them with a placeholder. -func SanitizeURLCredentials(unsanitizedURL string, usePlaceholder bool) string { - u, err := url.Parse(unsanitizedURL) - if err != nil { - // don't log the error, since it might contain unsanitized URL. - return "(unparsable url)" - } +// NewStringURLSanitizedError wraps an error and replaces the url credential or removes them. +// If the url can't get parsed it gets replaced with a placeholder string. +func NewStringURLSanitizedError(err error, unsanitizedURL string, usePlaceholder bool) error { + return sanitizedError{err: err, replacer: NewStringURLSanitizer(unsanitizedURL, usePlaceholder)} +} + +// NewURLSanitizer creates a replacer for the url with the credential sanitized or removed. +func NewURLSanitizer(u *url.URL, usePlaceholder bool) *strings.Replacer { + old := u.String() + if u.User != nil && usePlaceholder { - u.User = url.User("<credentials>") + u.User = url.User(userPlaceholder) } else { u.User = nil } - return u.String() + return strings.NewReplacer(old, u.String()) +} + +// NewStringURLSanitizer creates a replacer for the url with the credential sanitized or removed. 
+// If the url can't get parsed it gets replaced with a placeholder string +func NewStringURLSanitizer(unsanitizedURL string, usePlaceholder bool) *strings.Replacer { + u, err := url.Parse(unsanitizedURL) + if err != nil { + // don't log the error, since it might contain unsanitized URL. + return strings.NewReplacer(unsanitizedURL, unparsableURL) + } + return NewURLSanitizer(u, usePlaceholder) } diff --git a/modules/util/sanitize_test.go b/modules/util/sanitize_test.go index 4f07100675..578f75f518 100644 --- a/modules/util/sanitize_test.go +++ b/modules/util/sanitize_test.go @@ -1,25 +1,164 @@ -// Copyright 2020 The Gitea Authors. All rights reserved. +// Copyright 2021 The Gitea Authors. All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. package util import ( + "errors" "testing" "github.com/stretchr/testify/assert" ) -func TestSanitizeURLCredentials(t *testing.T) { - var kases = map[string]string{ - "https://github.com/go-gitea/test_repo.git": "https://github.com/go-gitea/test_repo.git", - "https://mytoken@github.com/go-gitea/test_repo.git": "https://github.com/go-gitea/test_repo.git", - "http://github.com/go-gitea/test_repo.git": "http://github.com/go-gitea/test_repo.git", - "/test/repos/repo1": "/test/repos/repo1", - "git@github.com:go-gitea/test_repo.git": "(unparsable url)", +func TestNewSanitizedError(t *testing.T) { + err := errors.New("error while secret on test") + err2 := NewSanitizedError(err) + assert.Equal(t, err.Error(), err2.Error()) + + var cases = []struct { + input error + oldnew []string + expected string + }{ + // case 0 + { + errors.New("error while secret on test"), + []string{"secret", "replaced"}, + "error while replaced on test", + }, + // case 1 + { + errors.New("error while sec-ret on test"), + []string{"secret", "replaced"}, + "error while sec-ret on test", + }, } - for source, value := range kases { - assert.EqualValues(t, value, SanitizeURLCredentials(source, 
false)) + for n, c := range cases { + err := NewSanitizedError(c.input, c.oldnew...) + + assert.Equal(t, c.expected, err.Error(), "case %d: error should match", n) + } +} + +func TestNewStringURLSanitizer(t *testing.T) { + var cases = []struct { + input string + placeholder bool + expected string + }{ + // case 0 + { + "https://github.com/go-gitea/test_repo.git", + true, + "https://github.com/go-gitea/test_repo.git", + }, + // case 1 + { + "https://github.com/go-gitea/test_repo.git", + false, + "https://github.com/go-gitea/test_repo.git", + }, + // case 2 + { + "https://mytoken@github.com/go-gitea/test_repo.git", + true, + "https://" + userPlaceholder + "@github.com/go-gitea/test_repo.git", + }, + // case 3 + { + "https://mytoken@github.com/go-gitea/test_repo.git", + false, + "https://github.com/go-gitea/test_repo.git", + }, + // case 4 + { + "https://user:password@github.com/go-gitea/test_repo.git", + true, + "https://" + userPlaceholder + "@github.com/go-gitea/test_repo.git", + }, + // case 5 + { + "https://user:password@github.com/go-gitea/test_repo.git", + false, + "https://github.com/go-gitea/test_repo.git", + }, + // case 6 + { + "https://gi\nthub.com/go-gitea/test_repo.git", + false, + unparsableURL, + }, + } + + for n, c := range cases { + // uses NewURLSanitizer internally + result := NewStringURLSanitizer(c.input, c.placeholder).Replace(c.input) + + assert.Equal(t, c.expected, result, "case %d: error should match", n) + } +} + +func TestNewStringURLSanitizedError(t *testing.T) { + var cases = []struct { + input string + placeholder bool + expected string + }{ + // case 0 + { + "https://github.com/go-gitea/test_repo.git", + true, + "https://github.com/go-gitea/test_repo.git", + }, + // case 1 + { + "https://github.com/go-gitea/test_repo.git", + false, + "https://github.com/go-gitea/test_repo.git", + }, + // case 2 + { + "https://mytoken@github.com/go-gitea/test_repo.git", + true, + "https://" + userPlaceholder + "@github.com/go-gitea/test_repo.git", + }, + 
// case 3 + { + "https://mytoken@github.com/go-gitea/test_repo.git", + false, + "https://github.com/go-gitea/test_repo.git", + }, + // case 4 + { + "https://user:password@github.com/go-gitea/test_repo.git", + true, + "https://" + userPlaceholder + "@github.com/go-gitea/test_repo.git", + }, + // case 5 + { + "https://user:password@github.com/go-gitea/test_repo.git", + false, + "https://github.com/go-gitea/test_repo.git", + }, + // case 6 + { + "https://gi\nthub.com/go-gitea/test_repo.git", + false, + unparsableURL, + }, + } + + encloseText := func(input string) string { + return "test " + input + " test" + } + + for n, c := range cases { + err := errors.New(encloseText(c.input)) + + result := NewStringURLSanitizedError(err, c.input, c.placeholder) + + assert.Equal(t, encloseText(c.expected), result.Error(), "case %d: error should match", n) } } |