]> source.dussan.org Git - gitea.git/commitdiff
Second attempt at preventing zombies (#16326)
authorzeripath <art27@cantab.net>
Wed, 14 Jul 2021 14:43:13 +0000 (15:43 +0100)
committerGitHub <noreply@github.com>
Wed, 14 Jul 2021 14:43:13 +0000 (10:43 -0400)
* Second attempt at preventing zombies

* Ensure that the pipes are closed in ssh.go
* Ensure that a cancellable context is passed up in cmd/* http requests
* Make cmd.fail return properly so defers are obeyed
* Ensure that something is sent to stdout in case of blocks here

Signed-off-by: Andrew Thornton <art27@cantab.net>
* placate lint

Signed-off-by: Andrew Thornton <art27@cantab.net>
* placate lint 2

Signed-off-by: Andrew Thornton <art27@cantab.net>
* placate lint 3

Signed-off-by: Andrew Thornton <art27@cantab.net>
* fixup

Signed-off-by: Andrew Thornton <art27@cantab.net>
* Apply suggestions from code review

Co-authored-by: 6543 <6543@obermui.de>
Co-authored-by: Lauris BH <lauris@nix.lv>
21 files changed:
cmd/cmd.go
cmd/hook.go
cmd/keys.go
cmd/mailer.go
cmd/manager.go
cmd/restore_repo.go
cmd/serv.go
integrations/mssql.ini.tmpl
integrations/mysql.ini.tmpl
integrations/mysql8.ini.tmpl
integrations/pgsql.ini.tmpl
integrations/sqlite.ini.tmpl
modules/httplib/httplib.go
modules/private/hook.go
modules/private/internal.go
modules/private/key.go
modules/private/mail.go
modules/private/manager.go
modules/private/restore_repo.go
modules/private/serv.go
modules/ssh/ssh.go

index bb768cc159d64b2dc52536ced0b45a3010bc9e32..8d9d1ee077ed57e1ce081d95d8fc613bd39eb431 100644 (file)
@@ -7,9 +7,13 @@
 package cmd
 
 import (
+       "context"
        "errors"
        "fmt"
+       "os"
+       "os/signal"
        "strings"
+       "syscall"
 
        "code.gitea.io/gitea/models"
        "code.gitea.io/gitea/modules/setting"
@@ -66,3 +70,25 @@ func initDBDisableConsole(disableConsole bool) error {
        }
        return nil
 }
+
+func installSignals() (context.Context, context.CancelFunc) {
+       ctx, cancel := context.WithCancel(context.Background())
+       go func() {
+               // install notify
+               signalChannel := make(chan os.Signal, 1)
+
+               signal.Notify(
+                       signalChannel,
+                       syscall.SIGINT,
+                       syscall.SIGTERM,
+               )
+               select {
+               case <-signalChannel:
+               case <-ctx.Done():
+               }
+               cancel()
+               signal.Reset()
+       }()
+
+       return ctx, cancel
+}
index 067a0bfb8ab9eabea5c5a7196d1fccb0b4c43889..87f1f37562e2e16a2fc8bb9f42ecc335a308a484 100644 (file)
@@ -152,17 +152,18 @@ func runHookPreReceive(c *cli.Context) error {
        if os.Getenv(models.EnvIsInternal) == "true" {
                return nil
        }
+       ctx, cancel := installSignals()
+       defer cancel()
 
        setup("hooks/pre-receive.log", c.Bool("debug"))
 
        if len(os.Getenv("SSH_ORIGINAL_COMMAND")) == 0 {
                if setting.OnlyAllowPushIfGiteaEnvironmentSet {
-                       fail(`Rejecting changes as Gitea environment not set.
+                       return fail(`Rejecting changes as Gitea environment not set.
 If you are pushing over SSH you must push with a key managed by
 Gitea or set your environment appropriately.`, "")
-               } else {
-                       return nil
                }
+               return nil
        }
 
        // the environment is set by serv command
@@ -235,14 +236,14 @@ Gitea or set your environment appropriately.`, "")
                                hookOptions.OldCommitIDs = oldCommitIDs
                                hookOptions.NewCommitIDs = newCommitIDs
                                hookOptions.RefFullNames = refFullNames
-                               statusCode, msg := private.HookPreReceive(username, reponame, hookOptions)
+                               statusCode, msg := private.HookPreReceive(ctx, username, reponame, hookOptions)
                                switch statusCode {
                                case http.StatusOK:
                                        // no-op
                                case http.StatusInternalServerError:
-                                       fail("Internal Server Error", msg)
+                                       return fail("Internal Server Error", msg)
                                default:
-                                       fail(msg, "")
+                                       return fail(msg, "")
                                }
                                count = 0
                                lastline = 0
@@ -263,12 +264,12 @@ Gitea or set your environment appropriately.`, "")
 
                fmt.Fprintf(out, " Checking %d references\n", count)
 
-               statusCode, msg := private.HookPreReceive(username, reponame, hookOptions)
+               statusCode, msg := private.HookPreReceive(ctx, username, reponame, hookOptions)
                switch statusCode {
                case http.StatusInternalServerError:
-                       fail("Internal Server Error", msg)
+                       return fail("Internal Server Error", msg)
                case http.StatusForbidden:
-                       fail(msg, "")
+                       return fail(msg, "")
                }
        } else if lastline > 0 {
                fmt.Fprintf(out, "\n")
@@ -285,8 +286,11 @@ func runHookUpdate(c *cli.Context) error {
 }
 
 func runHookPostReceive(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        // First of all run update-server-info no matter what
-       if _, err := git.NewCommand("update-server-info").Run(); err != nil {
+       if _, err := git.NewCommand("update-server-info").SetParentContext(ctx).Run(); err != nil {
                return fmt.Errorf("Failed to call 'git update-server-info': %v", err)
        }
 
@@ -299,12 +303,11 @@ func runHookPostReceive(c *cli.Context) error {
 
        if len(os.Getenv("SSH_ORIGINAL_COMMAND")) == 0 {
                if setting.OnlyAllowPushIfGiteaEnvironmentSet {
-                       fail(`Rejecting changes as Gitea environment not set.
+                       return fail(`Rejecting changes as Gitea environment not set.
 If you are pushing over SSH you must push with a key managed by
 Gitea or set your environment appropriately.`, "")
-               } else {
-                       return nil
                }
+               return nil
        }
 
        var out io.Writer
@@ -371,11 +374,11 @@ Gitea or set your environment appropriately.`, "")
                        hookOptions.OldCommitIDs = oldCommitIDs
                        hookOptions.NewCommitIDs = newCommitIDs
                        hookOptions.RefFullNames = refFullNames
-                       resp, err := private.HookPostReceive(repoUser, repoName, hookOptions)
+                       resp, err := private.HookPostReceive(ctx, repoUser, repoName, hookOptions)
                        if resp == nil {
                                _ = dWriter.Close()
                                hookPrintResults(results)
-                               fail("Internal Server Error", err)
+                               return fail("Internal Server Error", err)
                        }
                        wasEmpty = wasEmpty || resp.RepoWasEmpty
                        results = append(results, resp.Results...)
@@ -386,9 +389,9 @@ Gitea or set your environment appropriately.`, "")
        if count == 0 {
                if wasEmpty && masterPushed {
                        // We need to tell the repo to reset the default branch to master
-                       err := private.SetDefaultBranch(repoUser, repoName, "master")
+                       err := private.SetDefaultBranch(ctx, repoUser, repoName, "master")
                        if err != nil {
-                               fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err)
+                               return fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err)
                        }
                }
                fmt.Fprintf(out, "Processed %d references in total\n", total)
@@ -404,11 +407,11 @@ Gitea or set your environment appropriately.`, "")
 
        fmt.Fprintf(out, " Processing %d references\n", count)
 
-       resp, err := private.HookPostReceive(repoUser, repoName, hookOptions)
+       resp, err := private.HookPostReceive(ctx, repoUser, repoName, hookOptions)
        if resp == nil {
                _ = dWriter.Close()
                hookPrintResults(results)
-               fail("Internal Server Error", err)
+               return fail("Internal Server Error", err)
        }
        wasEmpty = wasEmpty || resp.RepoWasEmpty
        results = append(results, resp.Results...)
@@ -417,9 +420,9 @@ Gitea or set your environment appropriately.`, "")
 
        if wasEmpty && masterPushed {
                // We need to tell the repo to reset the default branch to master
-               err := private.SetDefaultBranch(repoUser, repoName, "master")
+               err := private.SetDefaultBranch(ctx, repoUser, repoName, "master")
                if err != nil {
-                       fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err)
+                       return fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err)
                }
        }
        _ = dWriter.Close()
index 7456815cd77bcd30e3504aea4d6b31230751bdf3..684aca64e22abfc018d4c174312328f95949376f 100644 (file)
@@ -62,9 +62,12 @@ func runKeys(c *cli.Context) error {
                return errors.New("No key type and content provided")
        }
 
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("keys.log", false)
 
-       authorizedString, err := private.AuthorizedPublicKeyByContent(content)
+       authorizedString, err := private.AuthorizedPublicKeyByContent(ctx, content)
        if err != nil {
                return err
        }
index ee11b56cc77bf2cf70747461f0f61194cc484b04..1a4b0902e268fc7e1b0bef45f8216238e23eecef 100644 (file)
@@ -14,6 +14,9 @@ import (
 )
 
 func runSendMail(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setting.NewContext()
 
        if err := argsSet(c, "title"); err != nil {
@@ -39,7 +42,7 @@ func runSendMail(c *cli.Context) error {
                }
        }
 
-       status, message := private.SendEmail(subject, body, nil)
+       status, message := private.SendEmail(ctx, subject, body, nil)
        if status != http.StatusOK {
                fmt.Printf("error: %s\n", message)
                return nil
index 20c7858682aca20af4b7476633bc62f2328784ae..99d283b4418e7509fde75ca59cf687ad5154bb67 100644 (file)
@@ -236,10 +236,13 @@ func runRemoveLogger(c *cli.Context) error {
                group = log.DEFAULT
        }
        name := c.Args().First()
-       statusCode, msg := private.RemoveLogger(group, name)
+       ctx, cancel := installSignals()
+       defer cancel()
+
+       statusCode, msg := private.RemoveLogger(ctx, group, name)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -371,10 +374,13 @@ func commonAddLogger(c *cli.Context, mode string, vals map[string]interface{}) e
        if c.IsSet("name") {
                name = c.String("name")
        }
-       statusCode, msg := private.AddLogger(group, name, mode, vals)
+       ctx, cancel := installSignals()
+       defer cancel()
+
+       statusCode, msg := private.AddLogger(ctx, group, name, mode, vals)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -382,11 +388,14 @@ func commonAddLogger(c *cli.Context, mode string, vals map[string]interface{}) e
 }
 
 func runShutdown(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("manager", c.Bool("debug"))
-       statusCode, msg := private.Shutdown()
+       statusCode, msg := private.Shutdown(ctx)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -394,11 +403,14 @@ func runShutdown(c *cli.Context) error {
 }
 
 func runRestart(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("manager", c.Bool("debug"))
-       statusCode, msg := private.Restart()
+       statusCode, msg := private.Restart(ctx)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -406,11 +418,14 @@ func runRestart(c *cli.Context) error {
 }
 
 func runFlushQueues(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("manager", c.Bool("debug"))
-       statusCode, msg := private.FlushQueues(c.Duration("timeout"), c.Bool("non-blocking"))
+       statusCode, msg := private.FlushQueues(ctx, c.Duration("timeout"), c.Bool("non-blocking"))
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -418,11 +433,14 @@ func runFlushQueues(c *cli.Context) error {
 }
 
 func runPauseLogging(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("manager", c.Bool("debug"))
-       statusCode, msg := private.PauseLogging()
+       statusCode, msg := private.PauseLogging(ctx)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -430,11 +448,14 @@ func runPauseLogging(c *cli.Context) error {
 }
 
 func runResumeLogging(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("manager", c.Bool("debug"))
-       statusCode, msg := private.ResumeLogging()
+       statusCode, msg := private.ResumeLogging(ctx)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
@@ -442,11 +463,14 @@ func runResumeLogging(c *cli.Context) error {
 }
 
 func runReleaseReopenLogging(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setup("manager", c.Bool("debug"))
-       statusCode, msg := private.ReleaseReopenLogging()
+       statusCode, msg := private.ReleaseReopenLogging(ctx)
        switch statusCode {
        case http.StatusInternalServerError:
-               fail("InternalServerError", msg)
+               return fail("InternalServerError", msg)
        }
 
        fmt.Fprintln(os.Stdout, msg)
index b832471928f4af122361b255ff46b6f85a5e68f8..1208796c9bde417b7cb1c9afb3da5b43b6cbbb1f 100644 (file)
@@ -40,20 +40,24 @@ var CmdRestoreRepository = cli.Command{
                cli.StringFlag{
                        Name:  "units",
                        Value: "",
-                       Usage: `Which items will be restored, one or more units should be separated as comma. 
+                       Usage: `Which items will be restored, one or more units should be separated as comma.
 wiki, issues, labels, releases, release_assets, milestones, pull_requests, comments are allowed. Empty means all units.`,
                },
        },
 }
 
-func runRestoreRepository(ctx *cli.Context) error {
+func runRestoreRepository(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        setting.NewContext()
 
        statusCode, errStr := private.RestoreRepo(
-               ctx.String("repo_dir"),
-               ctx.String("owner_name"),
-               ctx.String("repo_name"),
-               ctx.StringSlice("units"),
+               ctx,
+               c.String("repo_dir"),
+               c.String("owner_name"),
+               c.String("repo_name"),
+               c.StringSlice("units"),
        )
        if statusCode == http.StatusOK {
                return nil
index 40f8b89c9a98b44f99ca20b0483a5e839e077304..97ae901d270eecef4f492865b6c8ce88eead14b7 100644 (file)
@@ -6,17 +6,14 @@
 package cmd
 
 import (
-       "context"
        "fmt"
        "net/http"
        "net/url"
        "os"
        "os/exec"
-       "os/signal"
        "regexp"
        "strconv"
        "strings"
-       "syscall"
        "time"
 
        "code.gitea.io/gitea/models"
@@ -75,7 +72,10 @@ var (
        alphaDashDotPattern = regexp.MustCompile(`[^\w-\.]`)
 )
 
-func fail(userMessage, logMessage string, args ...interface{}) {
+func fail(userMessage, logMessage string, args ...interface{}) error {
+       // There appears to be a chance to cause a zombie process and failure to read the Exit status
+       // if nothing is outputted on stdout.
+       fmt.Fprintln(os.Stdout, "")
        fmt.Fprintln(os.Stderr, "Gitea:", userMessage)
 
        if len(logMessage) > 0 {
@@ -83,15 +83,19 @@ func fail(userMessage, logMessage string, args ...interface{}) {
                        fmt.Fprintf(os.Stderr, logMessage+"\n", args...)
                }
        }
+       ctx, cancel := installSignals()
+       defer cancel()
 
        if len(logMessage) > 0 {
-               _ = private.SSHLog(true, fmt.Sprintf(logMessage+": ", args...))
+               _ = private.SSHLog(ctx, true, fmt.Sprintf(logMessage+": ", args...))
        }
-
-       os.Exit(1)
+       return cli.NewExitError(fmt.Sprintf("Gitea: %s", userMessage), 1)
 }
 
 func runServ(c *cli.Context) error {
+       ctx, cancel := installSignals()
+       defer cancel()
+
        // FIXME: This needs to internationalised
        setup("serv.log", c.Bool("debug"))
 
@@ -109,18 +113,18 @@ func runServ(c *cli.Context) error {
 
        keys := strings.Split(c.Args()[0], "-")
        if len(keys) != 2 || keys[0] != "key" {
-               fail("Key ID format error", "Invalid key argument: %s", c.Args()[0])
+               return fail("Key ID format error", "Invalid key argument: %s", c.Args()[0])
        }
        keyID, err := strconv.ParseInt(keys[1], 10, 64)
        if err != nil {
-               fail("Key ID format error", "Invalid key argument: %s", c.Args()[1])
+               return fail("Key ID format error", "Invalid key argument: %s", c.Args()[1])
        }
 
        cmd := os.Getenv("SSH_ORIGINAL_COMMAND")
        if len(cmd) == 0 {
-               key, user, err := private.ServNoCommand(keyID)
+               key, user, err := private.ServNoCommand(ctx, keyID)
                if err != nil {
-                       fail("Internal error", "Failed to check provided key: %v", err)
+                       return fail("Internal error", "Failed to check provided key: %v", err)
                }
                switch key.Type {
                case models.KeyTypeDeploy:
@@ -138,11 +142,11 @@ func runServ(c *cli.Context) error {
 
        words, err := shellquote.Split(cmd)
        if err != nil {
-               fail("Error parsing arguments", "Failed to parse arguments: %v", err)
+               return fail("Error parsing arguments", "Failed to parse arguments: %v", err)
        }
 
        if len(words) < 2 {
-               fail("Too few arguments", "Too few arguments in cmd: %s", cmd)
+               return fail("Too few arguments", "Too few arguments in cmd: %s", cmd)
        }
 
        verb := words[0]
@@ -154,7 +158,7 @@ func runServ(c *cli.Context) error {
        var lfsVerb string
        if verb == lfsAuthenticateVerb {
                if !setting.LFS.StartServer {
-                       fail("Unknown git command", "LFS authentication request over SSH denied, LFS support is disabled")
+                       return fail("Unknown git command", "LFS authentication request over SSH denied, LFS support is disabled")
                }
 
                if len(words) > 2 {
@@ -167,37 +171,37 @@ func runServ(c *cli.Context) error {
 
        rr := strings.SplitN(repoPath, "/", 2)
        if len(rr) != 2 {
-               fail("Invalid repository path", "Invalid repository path: %v", repoPath)
+               return fail("Invalid repository path", "Invalid repository path: %v", repoPath)
        }
 
        username := strings.ToLower(rr[0])
        reponame := strings.ToLower(strings.TrimSuffix(rr[1], ".git"))
 
        if alphaDashDotPattern.MatchString(reponame) {
-               fail("Invalid repo name", "Invalid repo name: %s", reponame)
+               return fail("Invalid repo name", "Invalid repo name: %s", reponame)
        }
 
        if setting.EnablePprof || c.Bool("enable-pprof") {
                if err := os.MkdirAll(setting.PprofDataPath, os.ModePerm); err != nil {
-                       fail("Error while trying to create PPROF_DATA_PATH", "Error while trying to create PPROF_DATA_PATH: %v", err)
+                       return fail("Error while trying to create PPROF_DATA_PATH", "Error while trying to create PPROF_DATA_PATH: %v", err)
                }
 
                stopCPUProfiler, err := pprof.DumpCPUProfileForUsername(setting.PprofDataPath, username)
                if err != nil {
-                       fail("Internal Server Error", "Unable to start CPU profile: %v", err)
+                       return fail("Internal Server Error", "Unable to start CPU profile: %v", err)
                }
                defer func() {
                        stopCPUProfiler()
                        err := pprof.DumpMemProfileForUsername(setting.PprofDataPath, username)
                        if err != nil {
-                               fail("Internal Server Error", "Unable to dump Mem Profile: %v", err)
+                               _ = fail("Internal Server Error", "Unable to dump Mem Profile: %v", err)
                        }
                }()
        }
 
        requestedMode, has := allowedCommands[verb]
        if !has {
-               fail("Unknown git command", "Unknown git command %s", verb)
+               return fail("Unknown git command", "Unknown git command %s", verb)
        }
 
        if verb == lfsAuthenticateVerb {
@@ -206,21 +210,20 @@ func runServ(c *cli.Context) error {
                } else if lfsVerb == "download" {
                        requestedMode = models.AccessModeRead
                } else {
-                       fail("Unknown LFS verb", "Unknown lfs verb %s", lfsVerb)
+                       return fail("Unknown LFS verb", "Unknown lfs verb %s", lfsVerb)
                }
        }
 
-       results, err := private.ServCommand(keyID, username, reponame, requestedMode, verb, lfsVerb)
+       results, err := private.ServCommand(ctx, keyID, username, reponame, requestedMode, verb, lfsVerb)
        if err != nil {
                if private.IsErrServCommand(err) {
                        errServCommand := err.(private.ErrServCommand)
                        if errServCommand.StatusCode != http.StatusInternalServerError {
-                               fail("Unauthorized", "%s", errServCommand.Error())
-                       } else {
-                               fail("Internal Server Error", "%s", errServCommand.Error())
+                               return fail("Unauthorized", "%s", errServCommand.Error())
                        }
+                       return fail("Internal Server Error", "%s", errServCommand.Error())
                }
-               fail("Internal Server Error", "%s", err.Error())
+               return fail("Internal Server Error", "%s", err.Error())
        }
        os.Setenv(models.EnvRepoIsWiki, strconv.FormatBool(results.IsWiki))
        os.Setenv(models.EnvRepoName, results.RepoName)
@@ -253,7 +256,7 @@ func runServ(c *cli.Context) error {
                // Sign and get the complete encoded token as a string using the secret
                tokenString, err := token.SignedString(setting.LFS.JWTSecretBytes)
                if err != nil {
-                       fail("Internal error", "Failed to sign JWT token: %v", err)
+                       return fail("Internal error", "Failed to sign JWT token: %v", err)
                }
 
                tokenAuthentication := &models.LFSTokenResponse{
@@ -266,7 +269,7 @@ func runServ(c *cli.Context) error {
                enc := json.NewEncoder(os.Stdout)
                err = enc.Encode(tokenAuthentication)
                if err != nil {
-                       fail("Internal error", "Failed to encode LFS json response: %v", err)
+                       return fail("Internal error", "Failed to encode LFS json response: %v", err)
                }
                return nil
        }
@@ -276,25 +279,6 @@ func runServ(c *cli.Context) error {
                verb = strings.Replace(verb, "-", " ", 1)
        }
 
-       ctx, cancel := context.WithCancel(context.Background())
-       defer cancel()
-       go func() {
-               // install notify
-               signalChannel := make(chan os.Signal, 1)
-
-               signal.Notify(
-                       signalChannel,
-                       syscall.SIGINT,
-                       syscall.SIGTERM,
-               )
-               select {
-               case <-signalChannel:
-               case <-ctx.Done():
-               }
-               cancel()
-               signal.Reset()
-       }()
-
        var gitcmd *exec.Cmd
        verbs := strings.Split(verb, " ")
        if len(verbs) == 2 {
@@ -308,13 +292,13 @@ func runServ(c *cli.Context) error {
        gitcmd.Stdin = os.Stdin
        gitcmd.Stderr = os.Stderr
        if err = gitcmd.Run(); err != nil {
-               fail("Internal error", "Failed to execute git command: %v", err)
+               return fail("Internal error", "Failed to execute git command: %v", err)
        }
 
        // Update user key activity.
        if results.KeyID > 0 {
-               if err = private.UpdatePublicKeyInRepo(results.KeyID, results.RepoID); err != nil {
-                       fail("Internal error", "UpdatePublicKeyInRepo: %v", err)
+               if err = private.UpdatePublicKeyInRepo(ctx, results.KeyID, results.RepoID); err != nil {
+                       return fail("Internal error", "UpdatePublicKeyInRepo: %v", err)
                }
        }
 
index 1867070ff5ef2f136a0b7d53e07bc1bad1e7fb4e..0a7710fc5fe11f6515da7f0a4cea8a8f7cf675ce 100644 (file)
@@ -83,6 +83,7 @@ MODE                 = test,file
 ROOT_PATH            = mssql-log
 ROUTER               = ,
 XORM                 = file
+ENABLE_SSH_LOG       = true
 
 [log.test]
 LEVEL                = Info
index 176992cb26d0a106f70c3fd15bc6011205b84c37..a78b0425a198a546999a6ea99655e7ad1ff2509e 100644 (file)
@@ -101,6 +101,7 @@ MODE                 = test,file
 ROOT_PATH            = mysql-log
 ROUTER               = ,
 XORM                 = file
+ENABLE_SSH_LOG       = true
 
 [log.test]
 LEVEL                = Info
index 7c5bcb58dc5f586f39d20f2108b66ecd18654ab2..1151b6abc26a52d4d1316aff719bbf8195359969 100644 (file)
@@ -80,6 +80,7 @@ MODE                 = test,file
 ROOT_PATH            = mysql8-log
 ROUTER               = ,
 XORM                 = file
+ENABLE_SSH_LOG       = true
 
 [log.test]
 LEVEL                = Info
index 3a4a5e6c4fe52dc54a4853ab41e847e31cc2086b..f11d4faba5d93a1ec0dc962e4b1ceaf2ec189beb 100644 (file)
@@ -84,6 +84,7 @@ MODE                 = test,file
 ROOT_PATH            = pgsql-log
 ROUTER               = ,
 XORM                 = file
+ENABLE_SSH_LOG       = true
 
 [log.test]
 LEVEL                = Info
index 4a796e93178714784aaf158faa2ecfc84b64b5c9..71ac39a44baf5dd7c242bd5cea221aca37ea4c93 100644 (file)
@@ -79,6 +79,7 @@ MODE                 = test,file
 ROOT_PATH            = sqlite-log
 ROUTER               = ,
 XORM                 = file
+ENABLE_SSH_LOG       = true
 
 [log.test]
 LEVEL                = Info
index 294ad0b70b6721cccad81d4f264662567e0f2104..5c8eac8b4283ae2cad76619b7f2d676a076497d8 100644 (file)
@@ -7,6 +7,7 @@ package httplib
 
 import (
        "bytes"
+       "context"
        "crypto/tls"
        "encoding/xml"
        "io"
@@ -122,6 +123,12 @@ func (r *Request) Setting(setting Settings) *Request {
        return r
 }
 
+// SetContext sets the request's Context
+func (r *Request) SetContext(ctx context.Context) *Request {
+       r.req = r.req.WithContext(ctx)
+       return r
+}
+
 // SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
 func (r *Request) SetBasicAuth(username, password string) *Request {
        r.req.SetBasicAuth(username, password)
@@ -325,7 +332,7 @@ func (r *Request) getResponse() (*http.Response, error) {
                trans = &http.Transport{
                        TLSClientConfig: r.setting.TLSClientConfig,
                        Proxy:           proxy,
-                       Dial:            TimeoutDialer(r.setting.ConnectTimeout),
+                       DialContext:     TimeoutDialer(r.setting.ConnectTimeout),
                }
        } else if t, ok := trans.(*http.Transport); ok {
                if t.TLSClientConfig == nil {
@@ -334,8 +341,8 @@ func (r *Request) getResponse() (*http.Response, error) {
                if t.Proxy == nil {
                        t.Proxy = r.setting.Proxy
                }
-               if t.Dial == nil {
-                       t.Dial = TimeoutDialer(r.setting.ConnectTimeout)
+               if t.DialContext == nil {
+                       t.DialContext = TimeoutDialer(r.setting.ConnectTimeout)
                }
        }
 
@@ -458,9 +465,10 @@ func (r *Request) Response() (*http.Response, error) {
 }
 
 // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
-func TimeoutDialer(cTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
-       return func(netw, addr string) (net.Conn, error) {
-               conn, err := net.DialTimeout(netw, addr, cTimeout)
+func TimeoutDialer(cTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) {
+       return func(ctx context.Context, netw, addr string) (net.Conn, error) {
+               d := net.Dialer{Timeout: cTimeout}
+               conn, err := d.DialContext(ctx, netw, addr)
                if err != nil {
                        return nil, err
                }
index 82dcaf3fc9793c618658990e571dd2de609a8770..79fae052dd76530e0b02ed712e718464bdb79d30 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "encoding/json"
        "fmt"
        "net/http"
@@ -80,12 +81,12 @@ type HookPostReceiveBranchResult struct {
 }
 
 // HookPreReceive check whether the provided commits are allowed
-func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string) {
+func HookPreReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (int, string) {
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/pre-receive/%s/%s",
                url.PathEscape(ownerName),
                url.PathEscape(repoName),
        )
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req = req.Header("Content-Type", "application/json")
        json := jsoniter.ConfigCompatibleWithStandardLibrary
        jsonBytes, _ := json.Marshal(opts)
@@ -105,13 +106,13 @@ func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string)
 }
 
 // HookPostReceive updates services and users
-func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) {
+func HookPostReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) {
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/post-receive/%s/%s",
                url.PathEscape(ownerName),
                url.PathEscape(repoName),
        )
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req = req.Header("Content-Type", "application/json")
        req.SetTimeout(60*time.Second, time.Duration(60+len(opts.OldCommitIDs))*time.Second)
        json := jsoniter.ConfigCompatibleWithStandardLibrary
@@ -133,13 +134,13 @@ func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostRec
 }
 
 // SetDefaultBranch will set the default branch to the provided branch for the provided repository
-func SetDefaultBranch(ownerName, repoName, branch string) error {
+func SetDefaultBranch(ctx context.Context, ownerName, repoName, branch string) error {
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/set-default-branch/%s/%s/%s",
                url.PathEscape(ownerName),
                url.PathEscape(repoName),
                url.PathEscape(branch),
        )
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req = req.Header("Content-Type", "application/json")
 
        req.SetTimeout(60*time.Second, 60*time.Second)
@@ -155,9 +156,9 @@ func SetDefaultBranch(ownerName, repoName, branch string) error {
 }
 
 // SSHLog sends ssh error log response
-func SSHLog(isErr bool, msg string) error {
+func SSHLog(ctx context.Context, isErr bool, msg string) error {
        reqURL := setting.LocalURL + "api/internal/ssh/log"
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req = req.Header("Content-Type", "application/json")
 
        jsonBytes, _ := json.Marshal(&SSHLogOption{
@@ -171,6 +172,7 @@ func SSHLog(isErr bool, msg string) error {
        if err != nil {
                return fmt.Errorf("unable to contact gitea: %v", err)
        }
+
        defer resp.Body.Close()
        if resp.StatusCode != http.StatusOK {
                return fmt.Errorf("Error returned from gitea: %v", decodeJSONError(resp).Err)
index 360fae47b6e8c640270f96793563a928eb1ffc20..672ac74970edf23133d84fab33c92dc24d989a29 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "crypto/tls"
        "fmt"
        "net"
@@ -15,9 +16,11 @@ import (
        jsoniter "github.com/json-iterator/go"
 )
 
-func newRequest(url, method string) *httplib.Request {
-       return httplib.NewRequest(url, method).Header("Authorization",
-               fmt.Sprintf("Bearer %s", setting.InternalToken))
+func newRequest(ctx context.Context, url, method string) *httplib.Request {
+       return httplib.NewRequest(url, method).
+               SetContext(ctx).
+               Header("Authorization",
+                       fmt.Sprintf("Bearer %s", setting.InternalToken))
 }
 
 // Response internal request response
@@ -35,8 +38,8 @@ func decodeJSONError(resp *http.Response) *Response {
        return &res
 }
 
-func newInternalRequest(url, method string) *httplib.Request {
-       req := newRequest(url, method).SetTLSClientConfig(&tls.Config{
+func newInternalRequest(ctx context.Context, url, method string) *httplib.Request {
+       req := newRequest(ctx, url, method).SetTLSClientConfig(&tls.Config{
                InsecureSkipVerify: true,
                ServerName:         setting.Domain,
        })
@@ -45,6 +48,10 @@ func newInternalRequest(url, method string) *httplib.Request {
                        Dial: func(_, _ string) (net.Conn, error) {
                                return net.Dial("unix", setting.HTTPAddr)
                        },
+                       DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+                               var d net.Dialer
+                               return d.DialContext(ctx, "unix", setting.HTTPAddr)
+                       },
                })
        }
        return req
index bea7837906a0c949ee5daa8af96eacaa0898a364..d0b11a96e7ac2fccbba693dfbfa129ac470029ef 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "fmt"
        "io/ioutil"
        "net/http"
@@ -13,10 +14,10 @@ import (
 )
 
 // UpdatePublicKeyInRepo update public key and if necessary deploy key updates
-func UpdatePublicKeyInRepo(keyID, repoID int64) error {
+func UpdatePublicKeyInRepo(ctx context.Context, keyID, repoID int64) error {
        // Ask for running deliver hook and test pull request tasks.
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/ssh/%d/update/%d", keyID, repoID)
-       resp, err := newInternalRequest(reqURL, "POST").Response()
+       resp, err := newInternalRequest(ctx, reqURL, "POST").Response()
        if err != nil {
                return err
        }
@@ -32,10 +33,10 @@ func UpdatePublicKeyInRepo(keyID, repoID int64) error {
 
 // AuthorizedPublicKeyByContent searches content as prefix (leak e-mail part)
 // and returns public key found.
-func AuthorizedPublicKeyByContent(content string) (string, error) {
+func AuthorizedPublicKeyByContent(ctx context.Context, content string) (string, error) {
        // Ask for running deliver hook and test pull request tasks.
        reqURL := setting.LocalURL + "api/internal/ssh/authorized_keys"
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req.Param("content", content)
        resp, err := req.Response()
        if err != nil {
index 9c0912a6e349036705c05c876d460c24d30d5f52..4a5a3eedd794da898c37928e0af1bf52817f4ac9 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "fmt"
        "io/ioutil"
        "net/http"
@@ -27,10 +28,10 @@ type Email struct {
 //
 // If to list == nil its supposed to send an email to every
 // user present in DB
-func SendEmail(subject, message string, to []string) (int, string) {
+func SendEmail(ctx context.Context, subject, message string, to []string) (int, string) {
        reqURL := setting.LocalURL + "api/internal/mail/send"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req = req.Header("Content-Type", "application/json")
        json := jsoniter.ConfigCompatibleWithStandardLibrary
        jsonBytes, _ := json.Marshal(Email{
index 2bc6cec3b968af0ee688379dcc8ef3610343b356..0bcc3f8112484d6f7f9d7c4067b9ca9e45e2ec41 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "fmt"
        "net/http"
        "net/url"
@@ -15,10 +16,10 @@ import (
 )
 
 // Shutdown calls the internal shutdown function
-func Shutdown() (int, string) {
+func Shutdown(ctx context.Context) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/shutdown"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        resp, err := req.Response()
        if err != nil {
                return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
@@ -33,10 +34,10 @@ func Shutdown() (int, string) {
 }
 
 // Restart calls the internal restart function
-func Restart() (int, string) {
+func Restart(ctx context.Context) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/restart"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        resp, err := req.Response()
        if err != nil {
                return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
@@ -57,10 +58,10 @@ type FlushOptions struct {
 }
 
 // FlushQueues calls the internal flush-queues function
-func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) {
+func FlushQueues(ctx context.Context, timeout time.Duration, nonBlocking bool) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/flush-queues"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        if timeout > 0 {
                req.SetTimeout(timeout+10*time.Second, timeout+10*time.Second)
        }
@@ -85,10 +86,10 @@ func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) {
 }
 
 // PauseLogging pauses logging
-func PauseLogging() (int, string) {
+func PauseLogging(ctx context.Context) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/pause-logging"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        resp, err := req.Response()
        if err != nil {
                return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
@@ -103,10 +104,10 @@ func PauseLogging() (int, string) {
 }
 
 // ResumeLogging resumes logging
-func ResumeLogging() (int, string) {
+func ResumeLogging(ctx context.Context) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/resume-logging"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        resp, err := req.Response()
        if err != nil {
                return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
@@ -121,10 +122,10 @@ func ResumeLogging() (int, string) {
 }
 
 // ReleaseReopenLogging releases and reopens logging files
-func ReleaseReopenLogging() (int, string) {
+func ReleaseReopenLogging(ctx context.Context) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/release-and-reopen-logging"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        resp, err := req.Response()
        if err != nil {
                return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
@@ -147,10 +148,10 @@ type LoggerOptions struct {
 }
 
 // AddLogger adds a logger
-func AddLogger(group, name, mode string, config map[string]interface{}) (int, string) {
+func AddLogger(ctx context.Context, group, name, mode string, config map[string]interface{}) (int, string) {
        reqURL := setting.LocalURL + "api/internal/manager/add-logger"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req = req.Header("Content-Type", "application/json")
        json := jsoniter.ConfigCompatibleWithStandardLibrary
        jsonBytes, _ := json.Marshal(LoggerOptions{
@@ -175,10 +176,10 @@ func AddLogger(group, name, mode string, config map[string]interface{}) (int, st
 }
 
 // RemoveLogger removes a logger
-func RemoveLogger(group, name string) (int, string) {
+func RemoveLogger(ctx context.Context, group, name string) (int, string) {
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/manager/remove-logger/%s/%s", url.PathEscape(group), url.PathEscape(name))
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        resp, err := req.Response()
        if err != nil {
                return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
index 6fe2e6844b986de4cd000e5db8df2ba13cbb5494..66b60d8d124b21374f112a5f8e2637b8e7b560c3 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "fmt"
        "io/ioutil"
        "net/http"
@@ -23,10 +24,10 @@ type RestoreParams struct {
 }
 
 // RestoreRepo calls the internal RestoreRepo function
-func RestoreRepo(repoDir, ownerName, repoName string, units []string) (int, string) {
+func RestoreRepo(ctx context.Context, repoDir, ownerName, repoName string, units []string) (int, string) {
        reqURL := setting.LocalURL + "api/internal/restore_repo"
 
-       req := newInternalRequest(reqURL, "POST")
+       req := newInternalRequest(ctx, reqURL, "POST")
        req.SetTimeout(3*time.Second, 0) // since the request will spend much time, don't timeout
        req = req.Header("Content-Type", "application/json")
        json := jsoniter.ConfigCompatibleWithStandardLibrary
index 659af6dff5f97f9bf3d5a5d5acc60e0984b27032..9643dad679a1bd0a8a54be480cde69c90165e444 100644 (file)
@@ -5,6 +5,7 @@
 package private
 
 import (
+       "context"
        "fmt"
        "net/http"
        "net/url"
@@ -21,10 +22,10 @@ type KeyAndOwner struct {
 }
 
 // ServNoCommand returns information about the provided key
-func ServNoCommand(keyID int64) (*models.PublicKey, *models.User, error) {
+func ServNoCommand(ctx context.Context, keyID int64) (*models.PublicKey, *models.User, error) {
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/none/%d",
                keyID)
-       resp, err := newInternalRequest(reqURL, "GET").Response()
+       resp, err := newInternalRequest(ctx, reqURL, "GET").Response()
        if err != nil {
                return nil, nil, err
        }
@@ -73,7 +74,7 @@ func IsErrServCommand(err error) bool {
 }
 
 // ServCommand preps for a serv call
-func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) {
+func ServCommand(ctx context.Context, keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) {
        reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/command/%d/%s/%s?mode=%d",
                keyID,
                url.PathEscape(ownerName),
@@ -85,7 +86,7 @@ func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode
                }
        }
 
-       resp, err := newInternalRequest(reqURL, "GET").Response()
+       resp, err := newInternalRequest(ctx, reqURL, "GET").Response()
        if err != nil {
                return nil, err
        }
index c0897377c56fdd533b0533cab004681b779c7a0d..efe952534551bd4f210d9e05e654026ae6148f52 100644 (file)
@@ -6,6 +6,7 @@ package ssh
 
 import (
        "bytes"
+       "context"
        "crypto/rand"
        "crypto/rsa"
        "crypto/x509"
@@ -66,7 +67,11 @@ func sessionHandler(session ssh.Session) {
 
        args := []string{"serv", "key-" + keyID, "--config=" + setting.CustomConf}
        log.Trace("SSH: Arguments: %v", args)
-       cmd := exec.CommandContext(session.Context(), setting.AppPath, args...)
+
+       ctx, cancel := context.WithCancel(session.Context())
+       defer cancel()
+
+       cmd := exec.CommandContext(ctx, setting.AppPath, args...)
        cmd.Env = append(
                os.Environ(),
                "SSH_ORIGINAL_COMMAND="+command,
@@ -78,16 +83,21 @@ func sessionHandler(session ssh.Session) {
                log.Error("SSH: StdoutPipe: %v", err)
                return
        }
+       defer stdout.Close()
+
        stderr, err := cmd.StderrPipe()
        if err != nil {
                log.Error("SSH: StderrPipe: %v", err)
                return
        }
+       defer stderr.Close()
+
        stdin, err := cmd.StdinPipe()
        if err != nil {
                log.Error("SSH: StdinPipe: %v", err)
                return
        }
+       defer stdin.Close()
 
        wg := &sync.WaitGroup{}
        wg.Add(2)
@@ -106,6 +116,7 @@ func sessionHandler(session ssh.Session) {
 
        go func() {
                defer wg.Done()
+               defer stdout.Close()
                if _, err := io.Copy(session, stdout); err != nil {
                        log.Error("Failed to write stdout to session. %s", err)
                }
@@ -113,6 +124,7 @@ func sessionHandler(session ssh.Session) {
 
        go func() {
                defer wg.Done()
+               defer stderr.Close()
                if _, err := io.Copy(session.Stderr(), stderr); err != nil {
                        log.Error("Failed to write stderr to session. %s", err)
                }