diff options
author | zeripath <art27@cantab.net> | 2021-07-14 15:43:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-14 10:43:13 -0400 |
commit | 3dcb3e9073d825a4ada184f832892cf4bd5836a3 (patch) | |
tree | aab77b7726f0e20f34b452df166113950ff5fc62 /modules | |
parent | ee43d70a0c237ef9c02b99b9b49d1af348840319 (diff) | |
download | gitea-3dcb3e9073d825a4ada184f832892cf4bd5836a3.tar.gz gitea-3dcb3e9073d825a4ada184f832892cf4bd5836a3.zip |
Second attempt at preventing zombies (#16326)
* Second attempt at preventing zombies
* Ensure that the pipes are closed in ssh.go
* Ensure that a cancellable context is passed up in cmd/* http requests
* Make cmd.fail return properly so defers are obeyed
* Ensure that something is sent to stdout in case of blocks here
Signed-off-by: Andrew Thornton <art27@cantab.net>
* placate lint
Signed-off-by: Andrew Thornton <art27@cantab.net>
* placate lint 2
Signed-off-by: Andrew Thornton <art27@cantab.net>
* placate lint 3
Signed-off-by: Andrew Thornton <art27@cantab.net>
* fixup
Signed-off-by: Andrew Thornton <art27@cantab.net>
* Apply suggestions from code review
Co-authored-by: 6543 <6543@obermui.de>
Co-authored-by: Lauris BH <lauris@nix.lv>
Diffstat (limited to 'modules')
-rw-r--r-- | modules/httplib/httplib.go | 20 | ||||
-rw-r--r-- | modules/private/hook.go | 18 | ||||
-rw-r--r-- | modules/private/internal.go | 17 | ||||
-rw-r--r-- | modules/private/key.go | 9 | ||||
-rw-r--r-- | modules/private/mail.go | 5 | ||||
-rw-r--r-- | modules/private/manager.go | 33 | ||||
-rw-r--r-- | modules/private/restore_repo.go | 5 | ||||
-rw-r--r-- | modules/private/serv.go | 9 | ||||
-rw-r--r-- | modules/ssh/ssh.go | 14 |
9 files changed, 82 insertions, 48 deletions
diff --git a/modules/httplib/httplib.go b/modules/httplib/httplib.go index 294ad0b70b..5c8eac8b42 100644 --- a/modules/httplib/httplib.go +++ b/modules/httplib/httplib.go @@ -7,6 +7,7 @@ package httplib import ( "bytes" + "context" "crypto/tls" "encoding/xml" "io" @@ -122,6 +123,12 @@ func (r *Request) Setting(setting Settings) *Request { return r } +// SetContext sets the request's Context +func (r *Request) SetContext(ctx context.Context) *Request { + r.req = r.req.WithContext(ctx) + return r +} + // SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. func (r *Request) SetBasicAuth(username, password string) *Request { r.req.SetBasicAuth(username, password) @@ -325,7 +332,7 @@ func (r *Request) getResponse() (*http.Response, error) { trans = &http.Transport{ TLSClientConfig: r.setting.TLSClientConfig, Proxy: proxy, - Dial: TimeoutDialer(r.setting.ConnectTimeout), + DialContext: TimeoutDialer(r.setting.ConnectTimeout), } } else if t, ok := trans.(*http.Transport); ok { if t.TLSClientConfig == nil { @@ -334,8 +341,8 @@ func (r *Request) getResponse() (*http.Response, error) { if t.Proxy == nil { t.Proxy = r.setting.Proxy } - if t.Dial == nil { - t.Dial = TimeoutDialer(r.setting.ConnectTimeout) + if t.DialContext == nil { + t.DialContext = TimeoutDialer(r.setting.ConnectTimeout) } } @@ -458,9 +465,10 @@ func (r *Request) Response() (*http.Response, error) { } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. -func TimeoutDialer(cTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { - return func(netw, addr string) (net.Conn, error) { - conn, err := net.DialTimeout(netw, addr, cTimeout) +func TimeoutDialer(cTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) { + return func(ctx context.Context, netw, addr string) (net.Conn, error) { + d := net.Dialer{Timeout: cTimeout} + conn, err := d.DialContext(ctx, netw, addr) if err != nil { return nil, err } diff --git a/modules/private/hook.go b/modules/private/hook.go index 82dcaf3fc9..79fae052dd 100644 --- a/modules/private/hook.go +++ b/modules/private/hook.go @@ -5,6 +5,7 @@ package private import ( + "context" "encoding/json" "fmt" "net/http" @@ -80,12 +81,12 @@ type HookPostReceiveBranchResult struct { } // HookPreReceive check whether the provided commits are allowed -func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string) { +func HookPreReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (int, string) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/pre-receive/%s/%s", url.PathEscape(ownerName), url.PathEscape(repoName), ) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary jsonBytes, _ := json.Marshal(opts) @@ -105,13 +106,13 @@ func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string) } // HookPostReceive updates services and users -func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) { +func HookPostReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/post-receive/%s/%s", url.PathEscape(ownerName), url.PathEscape(repoName), ) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") req.SetTimeout(60*time.Second, time.Duration(60+len(opts.OldCommitIDs))*time.Second) json := jsoniter.ConfigCompatibleWithStandardLibrary @@ -133,13 +134,13 @@ func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostRec } // SetDefaultBranch will set the default branch to the provided branch for the provided repository -func SetDefaultBranch(ownerName, repoName, branch string) error { +func SetDefaultBranch(ctx context.Context, ownerName, repoName, branch string) error { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/set-default-branch/%s/%s/%s", url.PathEscape(ownerName), url.PathEscape(repoName), url.PathEscape(branch), ) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") req.SetTimeout(60*time.Second, 60*time.Second) @@ -155,9 +156,9 @@ func SetDefaultBranch(ownerName, repoName, branch string) error { } // SSHLog sends ssh error log response -func SSHLog(isErr bool, msg string) error { +func SSHLog(ctx context.Context, isErr bool, msg string) error { reqURL := setting.LocalURL + "api/internal/ssh/log" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") jsonBytes, _ := json.Marshal(&SSHLogOption{ @@ -171,6 +172,7 @@ func SSHLog(isErr bool, msg string) error { if err != nil { return fmt.Errorf("unable to contact gitea: %v", err) } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("Error returned from gitea: %v", decodeJSONError(resp).Err) diff --git a/modules/private/internal.go b/modules/private/internal.go index 360fae47b6..672ac74970 100644 --- a/modules/private/internal.go +++ b/modules/private/internal.go @@ -5,6 +5,7 @@ package private import ( + "context" "crypto/tls" "fmt" "net" @@ -15,9 +16,11 @@ import ( jsoniter "github.com/json-iterator/go" ) -func newRequest(url, method string) *httplib.Request { - return httplib.NewRequest(url, method).Header("Authorization", - fmt.Sprintf("Bearer %s", setting.InternalToken)) +func newRequest(ctx context.Context, url, method string) *httplib.Request { + return httplib.NewRequest(url, method). + SetContext(ctx). + Header("Authorization", + fmt.Sprintf("Bearer %s", setting.InternalToken)) } // Response internal request response @@ -35,8 +38,8 @@ func decodeJSONError(resp *http.Response) *Response { return &res } -func newInternalRequest(url, method string) *httplib.Request { - req := newRequest(url, method).SetTLSClientConfig(&tls.Config{ +func newInternalRequest(ctx context.Context, url, method string) *httplib.Request { + req := newRequest(ctx, url, method).SetTLSClientConfig(&tls.Config{ InsecureSkipVerify: true, ServerName: setting.Domain, }) @@ -45,6 +48,10 @@ func newInternalRequest(url, method string) *httplib.Request { Dial: func(_, _ string) (net.Conn, error) { return net.Dial("unix", setting.HTTPAddr) }, + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", setting.HTTPAddr) + }, }) } return req diff --git a/modules/private/key.go b/modules/private/key.go index bea7837906..d0b11a96e7 100644 --- a/modules/private/key.go +++ b/modules/private/key.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "io/ioutil" "net/http" @@ -13,10 +14,10 @@ import ( ) // UpdatePublicKeyInRepo update public key and if necessary deploy key updates -func UpdatePublicKeyInRepo(keyID, repoID int64) error { +func UpdatePublicKeyInRepo(ctx context.Context, keyID, repoID int64) error { // Ask for running deliver hook and test pull request tasks. reqURL := setting.LocalURL + fmt.Sprintf("api/internal/ssh/%d/update/%d", keyID, repoID) - resp, err := newInternalRequest(reqURL, "POST").Response() + resp, err := newInternalRequest(ctx, reqURL, "POST").Response() if err != nil { return err } @@ -32,10 +33,10 @@ func UpdatePublicKeyInRepo(keyID, repoID int64) error { // AuthorizedPublicKeyByContent searches content as prefix (leak e-mail part) // and returns public key found. -func AuthorizedPublicKeyByContent(content string) (string, error) { +func AuthorizedPublicKeyByContent(ctx context.Context, content string) (string, error) { // Ask for running deliver hook and test pull request tasks. reqURL := setting.LocalURL + "api/internal/ssh/authorized_keys" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req.Param("content", content) resp, err := req.Response() if err != nil { diff --git a/modules/private/mail.go b/modules/private/mail.go index 9c0912a6e3..4a5a3eedd7 100644 --- a/modules/private/mail.go +++ b/modules/private/mail.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "io/ioutil" "net/http" @@ -27,10 +28,10 @@ type Email struct { // // If to list == nil its supposed to send an email to every // user present in DB -func SendEmail(subject, message string, to []string) (int, string) { +func SendEmail(ctx context.Context, subject, message string, to []string) (int, string) { reqURL := setting.LocalURL + "api/internal/mail/send" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary jsonBytes, _ := json.Marshal(Email{ diff --git a/modules/private/manager.go b/modules/private/manager.go index 2bc6cec3b9..0bcc3f8112 100644 --- a/modules/private/manager.go +++ b/modules/private/manager.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "net/http" "net/url" @@ -15,10 +16,10 @@ import ( ) // Shutdown calls the internal shutdown function -func Shutdown() (int, string) { +func Shutdown(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/shutdown" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -33,10 +34,10 @@ func Shutdown() (int, string) { } // Restart calls the internal restart function -func Restart() (int, string) { +func Restart(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/restart" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -57,10 +58,10 @@ type FlushOptions struct { } // FlushQueues calls the internal flush-queues function -func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) { +func FlushQueues(ctx context.Context, timeout time.Duration, nonBlocking bool) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/flush-queues" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") if timeout > 0 { req.SetTimeout(timeout+10*time.Second, timeout+10*time.Second) } @@ -85,10 +86,10 @@ func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) { } // PauseLogging pauses logging -func PauseLogging() (int, string) { +func PauseLogging(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/pause-logging" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -103,10 +104,10 @@ func PauseLogging() (int, string) { } // ResumeLogging resumes logging -func ResumeLogging() (int, string) { +func ResumeLogging(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/resume-logging" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -121,10 +122,10 @@ func ResumeLogging() (int, string) { } // ReleaseReopenLogging releases and reopens logging files -func ReleaseReopenLogging() (int, string) { +func ReleaseReopenLogging(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/release-and-reopen-logging" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -147,10 +148,10 @@ type LoggerOptions struct { } // AddLogger adds a logger -func AddLogger(group, name, mode string, config map[string]interface{}) (int, string) { +func AddLogger(ctx context.Context, group, name, mode string, config map[string]interface{}) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/add-logger" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary jsonBytes, _ := json.Marshal(LoggerOptions{ @@ -175,10 +176,10 @@ func AddLogger(group, name, mode string, config map[string]interface{}) (int, st } // RemoveLogger removes a logger -func RemoveLogger(group, name string) (int, string) { +func RemoveLogger(ctx context.Context, group, name string) (int, string) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/manager/remove-logger/%s/%s", url.PathEscape(group), url.PathEscape(name)) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) diff --git a/modules/private/restore_repo.go b/modules/private/restore_repo.go index 6fe2e6844b..66b60d8d12 100644 --- a/modules/private/restore_repo.go +++ b/modules/private/restore_repo.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "io/ioutil" "net/http" @@ -23,10 +24,10 @@ type RestoreParams struct { } // RestoreRepo calls the internal RestoreRepo function -func RestoreRepo(repoDir, ownerName, repoName string, units []string) (int, string) { +func RestoreRepo(ctx context.Context, repoDir, ownerName, repoName string, units []string) (int, string) { reqURL := setting.LocalURL + "api/internal/restore_repo" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req.SetTimeout(3*time.Second, 0) // since the request will spend much time, don't timeout req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary diff --git a/modules/private/serv.go b/modules/private/serv.go index 659af6dff5..9643dad679 100644 --- a/modules/private/serv.go +++ b/modules/private/serv.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "net/http" "net/url" @@ -21,10 +22,10 @@ type KeyAndOwner struct { } // ServNoCommand returns information about the provided key -func ServNoCommand(keyID int64) (*models.PublicKey, *models.User, error) { +func ServNoCommand(ctx context.Context, keyID int64) (*models.PublicKey, *models.User, error) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/none/%d", keyID) - resp, err := newInternalRequest(reqURL, "GET").Response() + resp, err := newInternalRequest(ctx, reqURL, "GET").Response() if err != nil { return nil, nil, err } @@ -73,7 +74,7 @@ func IsErrServCommand(err error) bool { } // ServCommand preps for a serv call -func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) { +func ServCommand(ctx context.Context, keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/command/%d/%s/%s?mode=%d", keyID, url.PathEscape(ownerName), @@ -85,7 +86,7 @@ func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode } } - resp, err := newInternalRequest(reqURL, "GET").Response() + resp, err := newInternalRequest(ctx, reqURL, "GET").Response() if err != nil { return nil, err } diff --git a/modules/ssh/ssh.go b/modules/ssh/ssh.go index c0897377c5..efe9525345 100644 --- a/modules/ssh/ssh.go +++ b/modules/ssh/ssh.go @@ -6,6 +6,7 @@ package ssh import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -66,7 +67,11 @@ func sessionHandler(session ssh.Session) { args := []string{"serv", "key-" + keyID, "--config=" + setting.CustomConf} log.Trace("SSH: Arguments: %v", args) - cmd := exec.CommandContext(session.Context(), setting.AppPath, args...) + + ctx, cancel := context.WithCancel(session.Context()) + defer cancel() + + cmd := exec.CommandContext(ctx, setting.AppPath, args...) cmd.Env = append( os.Environ(), "SSH_ORIGINAL_COMMAND="+command, @@ -78,16 +83,21 @@ func sessionHandler(session ssh.Session) { log.Error("SSH: StdoutPipe: %v", err) return } + defer stdout.Close() + stderr, err := cmd.StderrPipe() if err != nil { log.Error("SSH: StderrPipe: %v", err) return } + defer stderr.Close() + stdin, err := cmd.StdinPipe() if err != nil { log.Error("SSH: StdinPipe: %v", err) return } + defer stdin.Close() wg := &sync.WaitGroup{} wg.Add(2) @@ -106,6 +116,7 @@ func sessionHandler(session ssh.Session) { go func() { defer wg.Done() + defer stdout.Close() if _, err := io.Copy(session, stdout); err != nil { log.Error("Failed to write stdout to session. %s", err) } @@ -113,6 +124,7 @@ func sessionHandler(session ssh.Session) { go func() { defer wg.Done() + defer stderr.Close() if _, err := io.Copy(session.Stderr(), stderr); err != nil { log.Error("Failed to write stderr to session. %s", err) } |