[[projects]]
branch = "master"
- digest = "1:296fd9dfbae66f6feeb09c7163ec39c262de425289154430a55d0a248c520486"
+ digest = "1:ebd587087cf937b6d3db7dde843a557d157fd68820a9d3d0157a8d8f4011ad29"
name = "code.gitea.io/git"
packages = ["."]
pruneopts = "NUT"
- revision = "d945eda535aa7d6b3c1f486279df2a3f7d05f78b"
+ revision = "578ad8f1259b0d660d19b05a011596f8fd3fea37"
[[projects]]
branch = "master"
- digest = "1:b194da40b41ae99546dfeec5a85f1fec2a6c51350d438e511ef90f4293c6dcd7"
+ digest = "1:4d2822cfcdf270183cee220e79e7bba55d5214a9c2bfa9b1fd6c6daaf5016eda"
name = "code.gitea.io/sdk"
packages = ["gitea"]
pruneopts = "NUT"
- revision = "4f96d9ac89886e78c50de8c835ebe87461578a5e"
+ revision = "59ddbdc4be1423ab3d5f30b859193ac0308df147"
[[projects]]
digest = "1:3fcef06a1a6561955c94af6c7757a6fa37605eb653f0d06ab960e5bb80092195"
pruneopts = "NUT"
revision = "57eb5e1fc594ad4b0b1dbea7b286d299e0cb43c2"
+[[projects]]
+ digest = "1:b498b36dbb2b306d1c5205ee5236c9e60352be8f9eea9bf08186723a9f75b4f3"
+ name = "github.com/emirpasic/gods"
+ packages = [
+ "containers",
+ "lists",
+ "lists/arraylist",
+ "trees",
+ "trees/binaryheap",
+ "utils",
+ ]
+ pruneopts = "NUT"
+ revision = "1615341f118ae12f353cc8a983f35b584342c9b3"
+ version = "v1.12.0"
+
[[projects]]
digest = "1:8603f74d35c93b37c615a02ba297be2cf2efc9ff6f1ff2b458a903990b568e48"
name = "github.com/ethantkoenig/rupture"
pruneopts = "NUT"
revision = "8fb95d837f7d6db1913fecfd7bcc5333e6499596"
+[[projects]]
+ branch = "master"
+ digest = "1:62fe3a7ea2050ecbd753a71889026f83d73329337ada66325cbafd5dea5f713d"
+ name = "github.com/jbenet/go-context"
+ packages = ["io"]
+ pruneopts = "NUT"
+ revision = "d14ea06fba99483203c19d92cfcd13ebe73135f4"
+
[[projects]]
digest = "1:6342cf70eaae592f7b8e2552037f2a9d4d16fa321c6e36f09c3bc450add2de19"
name = "github.com/kballard/go-shellquote"
pruneopts = "NUT"
revision = "cd60e84ee657ff3dc51de0b4f55dd299a3e136f2"
+[[projects]]
+ digest = "1:29e44e9481a689be0093a0033299b95741d394a97b28e0273c21afe697873a22"
+ name = "github.com/kevinburke/ssh_config"
+ packages = ["."]
+ pruneopts = "NUT"
+ revision = "81db2a75821ed34e682567d48be488a1c3121088"
+ version = "0.5"
+
[[projects]]
digest = "1:b32126992771fddadf6a778fe7ab29150665ed78f31ce4eb550a9db3bc0e650c"
name = "github.com/keybase/go-crypto"
pruneopts = "NUT"
revision = "f77f16ffc87a6a58814e64ae72d55f9c41374e6d"
+[[projects]]
+ digest = "1:a4df73029d2c42fabcb6b41e327d2f87e685284ec03edf76921c267d9cfc9c23"
+ name = "github.com/mitchellh/go-homedir"
+ packages = ["."]
+ pruneopts = "NUT"
+ revision = "ae18d6b8b3205b561c79e8e5f69bff09736185f4"
+ version = "v1.0.0"
+
[[projects]]
digest = "1:c7dc71a7e144df03332152d730f9c5ae22cf1cfd55454cb001ba8ffcb78aa7f0"
name = "github.com/mrjones/oauth"
pruneopts = "NUT"
revision = "891127d8d1b52734debe1b3c3d7e747502b6c366"
+[[projects]]
+ digest = "1:cf254277d898b713195cc6b4a3fac8bf738b9f1121625df27843b52b267eec6c"
+ name = "github.com/pelletier/go-buffruneio"
+ packages = ["."]
+ pruneopts = "NUT"
+ revision = "c37440a7cf42ac63b919c752ca73a85067e05992"
+ version = "v0.2.0"
+
[[projects]]
digest = "1:44c66ad69563dbe3f8e76d7d6cad21a03626e53f1875b5ab163ded419e01ca7a"
name = "github.com/philhofer/fwd"
pruneopts = "NUT"
revision = "1dba4b3954bc059efc3991ec364f9f9a35f597d2"
+[[projects]]
+ digest = "1:89fd77d603a74a6540d60067debad9397865bf040955d907362c95d364baeba6"
+ name = "github.com/src-d/gcfg"
+ packages = [
+ ".",
+ "scanner",
+ "token",
+ "types",
+ ]
+ pruneopts = "NUT"
+ revision = "1ac3a1ac202429a54835fe8408a92880156b489d"
+ version = "v1.4.0"
+
[[projects]]
branch = "master"
digest = "1:69177343ca227319b4580441a67d9d889e9ac7fcbfb89fbaa36d3283e6ab0139"
pruneopts = "NUT"
revision = "8ce1146b8621c95164efd9c8b1124cfa9b8afb4e"
+[[projects]]
+ digest = "1:3148cb3478c26a92b4c1a18abb9428234b281e278af6267840721a24b6cbc6a3"
+ name = "github.com/xanzy/ssh-agent"
+ packages = ["."]
+ pruneopts = "NUT"
+ revision = "640f0ab560aeb89d523bb6ac322b1244d5c3796c"
+ version = "v0.2.0"
+
[[projects]]
digest = "1:27d050258a4b19ca3b7a1bf26f4a04c5c66bbf0670b346ee509ebb0ad82257a6"
name = "github.com/yohcop/openid-go"
revision = "2c050d2dae5345c417db301f11fda6fbf5ad0f0a"
[[projects]]
- digest = "1:e4ea859df4986eb46feebbb84a2d163a4a314e87668177ca13b3b0adecaf50e8"
+ digest = "1:c3d6b9e2cf3936ba9927da2e8858651aad69890b9dd3349f1316b4003b25d7a3"
name = "golang.org/x/crypto"
packages = [
"acme",
"acme/autocert",
+ "cast5",
"curve25519",
"ed25519",
"ed25519/internal/edwards25519",
"internal/chacha20",
"md4",
+ "openpgp",
+ "openpgp/armor",
+ "openpgp/elgamal",
+ "openpgp/errors",
+ "openpgp/packet",
+ "openpgp/s2k",
"pbkdf2",
"poly1305",
"ssh",
+ "ssh/agent",
+ "ssh/knownhosts",
]
pruneopts = "NUT"
revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
revision = "e6179049628164864e6e84e973cfb56335748dea"
version = "v2.3.2"
+[[projects]]
+ digest = "1:1cf1388ec8c73b7ecc711d9f279ab631ea0a6964d1ccc32809a6be90c33fa2a0"
+ name = "gopkg.in/src-d/go-billy.v4"
+ packages = [
+ ".",
+ "helper/chroot",
+ "helper/polyfill",
+ "osfs",
+ "util",
+ ]
+ pruneopts = "NUT"
+ revision = "982626487c60a5252e7d0b695ca23fb0fa2fd670"
+ version = "v4.3.0"
+
+[[projects]]
+ digest = "1:8a0efb153cc5b7e0e129d716834217be483e2b326e72f3dcca8b03cd3207e9e4"
+ name = "gopkg.in/src-d/go-git.v4"
+ packages = [
+ ".",
+ "config",
+ "internal/revision",
+ "plumbing",
+ "plumbing/cache",
+ "plumbing/filemode",
+ "plumbing/format/config",
+ "plumbing/format/diff",
+ "plumbing/format/gitignore",
+ "plumbing/format/idxfile",
+ "plumbing/format/index",
+ "plumbing/format/objfile",
+ "plumbing/format/packfile",
+ "plumbing/format/pktline",
+ "plumbing/object",
+ "plumbing/protocol/packp",
+ "plumbing/protocol/packp/capability",
+ "plumbing/protocol/packp/sideband",
+ "plumbing/revlist",
+ "plumbing/storer",
+ "plumbing/transport",
+ "plumbing/transport/client",
+ "plumbing/transport/file",
+ "plumbing/transport/git",
+ "plumbing/transport/http",
+ "plumbing/transport/internal/common",
+ "plumbing/transport/server",
+ "plumbing/transport/ssh",
+ "storage",
+ "storage/filesystem",
+ "storage/filesystem/dotgit",
+ "storage/memory",
+ "utils/binary",
+ "utils/diff",
+ "utils/ioutil",
+ "utils/merkletrie",
+ "utils/merkletrie/filesystem",
+ "utils/merkletrie/index",
+ "utils/merkletrie/internal/frame",
+ "utils/merkletrie/noder",
+ ]
+ pruneopts = "NUT"
+ revision = "f62cd8e3495579a8323455fa0c4e6c44bb0d5e09"
+ version = "v4.8.0"
+
[[projects]]
digest = "1:9c541fc507676a69ea8aaed1af53278a5241d26ce0f192c993fec2ac5b78f795"
name = "gopkg.in/testfixtures.v2"
revision = "fa3fb89109b0b31957a5430cef3e93e535de362b"
version = "v2.5.0"
+[[projects]]
+ digest = "1:b233ad4ec87ac916e7bf5e678e98a2cb9e8b52f6de6ad3e11834fc7a71b8e3bf"
+ name = "gopkg.in/warnings.v0"
+ packages = ["."]
+ pruneopts = "NUT"
+ revision = "ec4a0fea49c7b46c2aeb0b51aac55779c607e52b"
+ version = "v0.1.2"
+
[[projects]]
digest = "1:ad6f94355d292690137613735965bd3688844880fdab90eccf66321910344942"
name = "gopkg.in/yaml.v2"
--- /dev/null
+// Copyright 2018 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package integrations
+
+import (
+ "net/http"
+ "testing"
+
+ "code.gitea.io/gitea/models"
+)
+
+func TestAPIReposGitRefs(t *testing.T) {
+ prepareTestEnv(t)
+ user := models.AssertExistsAndLoadBean(t, &models.User{ID: 2}).(*models.User)
+ // Login as User2.
+ session := loginUser(t, user.Name)
+ token := getTokenForLoggedInUser(t, session)
+
+ for _, ref := range [...]string{
+ "refs/heads/master", // Branch
+ "refs/tags/v1.1", // Tag
+ } {
+ req := NewRequestf(t, "GET", "/api/v1/repos/%s/repo1/git/%s?token="+token, user.Name, ref)
+ session.MakeRequest(t, req, http.StatusOK)
+ }
+ // Test getting all refs
+ req := NewRequestf(t, "GET", "/api/v1/repos/%s/repo1/git/refs?token="+token, user.Name)
+ session.MakeRequest(t, req, http.StatusOK)
+ // Test getting non-existent refs
+ req = NewRequestf(t, "GET", "/api/v1/repos/%s/repo1/git/refs/heads/unknown?token="+token, user.Name)
+ session.MakeRequest(t, req, http.StatusNotFound)
+}
m.Get("/status", repo.GetCombinedCommitStatusByRef)
m.Get("/statuses", repo.GetCommitStatusesByRef)
})
+ m.Group("/git", func() {
+ m.Get("/refs", repo.GetGitAllRefs)
+ m.Get("/refs/*", repo.GetGitRefs)
+ })
}, repoAssignment())
})
--- /dev/null
+// Copyright 2018 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package repo
+
+import (
+ "code.gitea.io/gitea/modules/context"
+
+ "code.gitea.io/git"
+ api "code.gitea.io/sdk/gitea"
+)
+
+// GetGitAllRefs get ref or an list all the refs of a repository
+func GetGitAllRefs(ctx *context.APIContext) {
+ // swagger:operation GET /repos/{owner}/{repo}/git/refs repository repoListAllGitRefs
+ // ---
+ // summary: Get specified ref or filtered repository's refs
+ // produces:
+ // - application/json
+ // parameters:
+ // - name: owner
+ // in: path
+ // description: owner of the repo
+ // type: string
+ // required: true
+ // - name: repo
+ // in: path
+ // description: name of the repo
+ // type: string
+ // required: true
+ // responses:
+ // "200":
+ // "$ref": "#/responses/Reference"
+ // "$ref": "#/responses/ReferenceList"
+ // "404":
+ // "$ref": "#/responses/notFound"
+
+ getGitRefsInternal(ctx, "")
+}
+
+// GetGitRefs get ref or an filteresd list of refs of a repository
+func GetGitRefs(ctx *context.APIContext) {
+ // swagger:operation GET /repos/{owner}/{repo}/git/refs/{ref} repository repoListGitRefs
+ // ---
+ // summary: Get specified ref or filtered repository's refs
+ // produces:
+ // - application/json
+ // parameters:
+ // - name: owner
+ // in: path
+ // description: owner of the repo
+ // type: string
+ // required: true
+ // - name: repo
+ // in: path
+ // description: name of the repo
+ // type: string
+ // required: true
+ // - name: ref
+ // in: path
+ // description: part or full name of the ref
+ // type: string
+ // required: true
+ // responses:
+ // "200":
+ // "$ref": "#/responses/Reference"
+ // "$ref": "#/responses/ReferenceList"
+ // "404":
+ // "$ref": "#/responses/notFound"
+
+ getGitRefsInternal(ctx, ctx.Params("*"))
+}
+
+func getGitRefsInternal(ctx *context.APIContext, filter string) {
+ gitRepo, err := git.OpenRepository(ctx.Repo.Repository.RepoPath())
+ if err != nil {
+ ctx.Error(500, "OpenRepository", err)
+ return
+ }
+ if len(filter) > 0 {
+ filter = "refs/" + filter
+ }
+
+ refs, err := gitRepo.GetRefsFiltered(filter)
+ if err != nil {
+ ctx.Error(500, "GetRefsFiltered", err)
+ return
+ }
+
+ if len(refs) == 0 {
+ ctx.Status(404)
+ return
+ }
+
+ apiRefs := make([]*api.Reference, len(refs))
+ for i := range refs {
+ apiRefs[i] = &api.Reference{
+ Ref: refs[i].Name,
+ URL: ctx.Repo.Repository.APIURL() + "/git/" + refs[i].Name,
+ Object: &api.GitObject{
+ SHA: refs[i].Object.String(),
+ Type: refs[i].Type,
+ // TODO: Add commit/tag info URL
+ //URL: ctx.Repo.Repository.APIURL() + "/git/" + refs[i].Type + "s/" + refs[i].Object.String(),
+ },
+ }
+ }
+ // If single reference is found and it matches filter exactly return it as object
+ if len(apiRefs) == 1 && apiRefs[0].Ref == filter {
+ ctx.JSON(200, &apiRefs[0])
+ return
+ }
+ ctx.JSON(200, &apiRefs)
+}
Body []api.Branch `json:"body"`
}
+// Reference
+// swagger:response Reference
+type swaggerResponseReference struct {
+ // in:body
+ Body api.Reference `json:"body"`
+}
+
+// ReferenceList
+// swagger:response ReferenceList
+type swaggerResponseReferenceList struct {
+ // in:body
+ Body []api.Reference `json:"body"`
+}
+
// Hook
// swagger:response Hook
type swaggerResponseHook struct {
}
}
},
+ "/repos/{owner}/{repo}/git/refs": {
+ "get": {
+ "produces": [
+ "application/json"
+ ],
+ "tags": [
+ "repository"
+ ],
+ "summary": "Get specified ref or filtered repository's refs",
+ "operationId": "repoListAllGitRefs",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "owner of the repo",
+ "name": "owner",
+ "in": "path",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "name of the repo",
+ "name": "repo",
+ "in": "path",
+ "required": true
+ }
+ ],
+ "responses": {
+ "200": {
+ "$ref": "#/responses/ReferenceList"
+ },
+ "404": {
+ "$ref": "#/responses/notFound"
+ }
+ }
+ }
+ },
+ "/repos/{owner}/{repo}/git/refs/{ref}": {
+ "get": {
+ "produces": [
+ "application/json"
+ ],
+ "tags": [
+ "repository"
+ ],
+ "summary": "Get specified ref or filtered repository's refs",
+ "operationId": "repoListGitRefs",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "owner of the repo",
+ "name": "owner",
+ "in": "path",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "name of the repo",
+ "name": "repo",
+ "in": "path",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "part or full name of the ref",
+ "name": "ref",
+ "in": "path",
+ "required": true
+ }
+ ],
+ "responses": {
+ "200": {
+ "$ref": "#/responses/ReferenceList"
+ },
+ "404": {
+ "$ref": "#/responses/notFound"
+ }
+ }
+ }
+ },
"/repos/{owner}/{repo}/hooks": {
"get": {
"produces": [
},
"x-go-package": "code.gitea.io/gitea/vendor/code.gitea.io/sdk/gitea"
},
+ "GitObject": {
+ "type": "object",
+ "title": "GitObject represents a Git object.",
+ "properties": {
+ "sha": {
+ "type": "string",
+ "x-go-name": "SHA"
+ },
+ "type": {
+ "type": "string",
+ "x-go-name": "Type"
+ },
+ "url": {
+ "type": "string",
+ "x-go-name": "URL"
+ }
+ },
+ "x-go-package": "code.gitea.io/gitea/vendor/code.gitea.io/sdk/gitea"
+ },
"Issue": {
"description": "Issue represents an issue in a repository",
"type": "object",
},
"x-go-package": "code.gitea.io/gitea/vendor/code.gitea.io/sdk/gitea"
},
+ "Reference": {
+ "type": "object",
+ "title": "Reference represents a Git reference.",
+ "properties": {
+ "object": {
+ "$ref": "#/definitions/GitObject"
+ },
+ "ref": {
+ "type": "string",
+ "x-go-name": "Ref"
+ },
+ "url": {
+ "type": "string",
+ "x-go-name": "URL"
+ }
+ },
+ "x-go-package": "code.gitea.io/gitea/vendor/code.gitea.io/sdk/gitea"
+ },
"Release": {
"description": "Release represents a repository release",
"type": "object",
}
}
},
+ "Reference": {
+ "description": "Reference",
+ "schema": {
+ "$ref": "#/definitions/Reference"
+ }
+ },
+ "ReferenceList": {
+ "description": "ReferenceList",
+ "schema": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Reference"
+ }
+ }
+ },
"Release": {
"description": "Release",
"schema": {
--- /dev/null
+// Copyright 2018 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package git
+
+// Reference represents a Git ref.
+type Reference struct {
+ Name string
+ repo *Repository
+ Object SHA1 // The id of this commit object
+ Type string
+}
+
+// Commit return the commit of the reference
+func (ref *Reference) Commit() (*Commit, error) {
+ return ref.repo.getCommit(ref.Object)
+}
// Copyright 2015 The Gogs Authors. All rights reserved.
+// Copyright 2018 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
import (
"fmt"
"strings"
+
+ "gopkg.in/src-d/go-git.v4"
+ "gopkg.in/src-d/go-git.v4/plumbing"
)
// BranchPrefix base dir of the branch information file store on git
// GetBranches returns all branches of the repository.
func (repo *Repository) GetBranches() ([]string, error) {
- stdout, err := NewCommand("for-each-ref", "--format=%(refname)", BranchPrefix).RunInDir(repo.Path)
+ r, err := git.PlainOpen(repo.Path)
if err != nil {
return nil, err
}
- refs := strings.Split(stdout, "\n")
- branches := make([]string, len(refs)-1)
- for i, ref := range refs[:len(refs)-1] {
- branches[i] = strings.TrimPrefix(ref, BranchPrefix)
+ branchIter, err := r.Branches()
+ if err != nil {
+ return nil, err
}
+ branches := make([]string, 0)
+ if err = branchIter.ForEach(func(branch *plumbing.Reference) error {
+ branches = append(branches, branch.Name().Short())
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
return branches, nil
}
--- /dev/null
+// Copyright 2018 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package git
+
+import (
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// GetRefs returns all references of the repository.
+func (repo *Repository) GetRefs() ([]*Reference, error) {
+ return repo.GetRefsFiltered("")
+}
+
+// GetRefsFiltered returns all references of the repository that matches patterm exactly or starting with.
+func (repo *Repository) GetRefsFiltered(pattern string) ([]*Reference, error) {
+ r, err := git.PlainOpen(repo.Path)
+ if err != nil {
+ return nil, err
+ }
+
+ refsIter, err := r.References()
+ if err != nil {
+ return nil, err
+ }
+ refs := make([]*Reference, 0)
+ if err = refsIter.ForEach(func(ref *plumbing.Reference) error {
+ if ref.Name() != plumbing.HEAD && !ref.Name().IsRemote() &&
+ (pattern == "" || strings.HasPrefix(ref.Name().String(), pattern)) {
+ r := &Reference{
+ Name: ref.Name().String(),
+ Object: SHA1(ref.Hash()),
+ Type: string(ObjectCommit),
+ repo: repo,
+ }
+ if ref.Name().IsTag() {
+ r.Type = string(ObjectTag)
+ }
+ refs = append(refs, r)
+ }
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
+ return refs, nil
+}
--- /dev/null
+// Copyright 2018 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package gitea
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+)
+
+// Reference represents a Git reference.
+type Reference struct {
+ Ref string `json:"ref"`
+ URL string `json:"url"`
+ Object *GitObject `json:"object"`
+}
+
+// GitObject represents a Git object.
+type GitObject struct {
+ Type string `json:"type"`
+ SHA string `json:"sha"`
+ URL string `json:"url"`
+}
+
+// GetRepoRef get one ref's information of one repository
+func (c *Client) GetRepoRef(user, repo, ref string) (*Reference, error) {
+ ref = strings.TrimPrefix(ref, "refs/")
+ r := new(Reference)
+ err := c.getParsedResponse("GET", fmt.Sprintf("/repos/%s/%s/git/refs/%s", user, repo, ref), nil, nil, &r)
+ if _, ok := err.(*json.UnmarshalTypeError); ok {
+ // Multiple refs
+ return nil, errors.New("no exact match found for this ref")
+ } else if err != nil {
+ return nil, err
+ }
+
+ return r, nil
+}
+
+// GetRepoRefs get list of ref's information of one repository
+func (c *Client) GetRepoRefs(user, repo, ref string) ([]*Reference, error) {
+ ref = strings.TrimPrefix(ref, "refs/")
+ resp, err := c.getResponse("GET", fmt.Sprintf("/repos/%s/%s/git/refs/%s", user, repo, ref), nil, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Attempt to unmarshal single returned ref.
+ r := new(Reference)
+ refErr := json.Unmarshal(resp, r)
+ if refErr == nil {
+ return []*Reference{r}, nil
+ }
+
+ // Attempt to unmarshal multiple refs.
+ var rs []*Reference
+ refsErr := json.Unmarshal(resp, &rs)
+ if refsErr == nil {
+ if len(rs) == 0 {
+ return nil, errors.New("unexpected response: an array of refs with length 0")
+ }
+ return rs, nil
+ }
+
+ return nil, fmt.Errorf("unmarshalling failed for both single and multiple refs: %s and %s", refErr, refsErr)
+}
--- /dev/null
+Copyright (c) 2015, Emir Pasic
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+-------------------------------------------------------------------------------
+
+AVL Tree:
+
+Copyright (c) 2017 Benjamin Scher Purcell <benjapurcell@gmail.com>
+
+Permission to use, copy, modify, and distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package containers provides core interfaces and functions for data structures.
+//
+// Container is the base interface for all data structures to implement.
+//
+// Iterators provide stateful iterators.
+//
+// Enumerable provides Ruby inspired (each, select, map, find, any?, etc.) container functions.
+//
+// Serialization provides serializers (marshalers) and deserializers (unmarshalers).
+package containers
+
+import "github.com/emirpasic/gods/utils"
+
+// Container is base interface that all data structures implement.
+type Container interface {
+ Empty() bool
+ Size() int
+ Clear()
+ Values() []interface{}
+}
+
+// GetSortedValues returns sorted container's elements with respect to the passed comparator.
+// Does not effect the ordering of elements within the container.
+func GetSortedValues(container Container, comparator utils.Comparator) []interface{} {
+ values := container.Values()
+ if len(values) < 2 {
+ return values
+ }
+ utils.Sort(values, comparator)
+ return values
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package containers
+
+// EnumerableWithIndex provides functions for ordered containers whose values can be fetched by an index.
+type EnumerableWithIndex interface {
+ // Each calls the given function once for each element, passing that element's index and value.
+ Each(func(index int, value interface{}))
+
+ // Map invokes the given function once for each element and returns a
+ // container containing the values returned by the given function.
+ // TODO would appreciate help on how to enforce this in containers (don't want to type assert when chaining)
+ // Map(func(index int, value interface{}) interface{}) Container
+
+ // Select returns a new container containing all elements for which the given function returns a true value.
+ // TODO need help on how to enforce this in containers (don't want to type assert when chaining)
+ // Select(func(index int, value interface{}) bool) Container
+
+ // Any passes each element of the container to the given function and
+ // returns true if the function ever returns true for any element.
+ Any(func(index int, value interface{}) bool) bool
+
+ // All passes each element of the container to the given function and
+ // returns true if the function returns true for all elements.
+ All(func(index int, value interface{}) bool) bool
+
+ // Find passes each element of the container to the given function and returns
+ // the first (index,value) for which the function is true or -1,nil otherwise
+ // if no element matches the criteria.
+ Find(func(index int, value interface{}) bool) (int, interface{})
+}
+
+// EnumerableWithKey provides functions for ordered containers whose values whose elements are key/value pairs.
+type EnumerableWithKey interface {
+ // Each calls the given function once for each element, passing that element's key and value.
+ Each(func(key interface{}, value interface{}))
+
+ // Map invokes the given function once for each element and returns a container
+ // containing the values returned by the given function as key/value pairs.
+ // TODO need help on how to enforce this in containers (don't want to type assert when chaining)
+ // Map(func(key interface{}, value interface{}) (interface{}, interface{})) Container
+
+ // Select returns a new container containing all elements for which the given function returns a true value.
+ // TODO need help on how to enforce this in containers (don't want to type assert when chaining)
+ // Select(func(key interface{}, value interface{}) bool) Container
+
+ // Any passes each element of the container to the given function and
+ // returns true if the function ever returns true for any element.
+ Any(func(key interface{}, value interface{}) bool) bool
+
+ // All passes each element of the container to the given function and
+ // returns true if the function returns true for all elements.
+ All(func(key interface{}, value interface{}) bool) bool
+
+ // Find passes each element of the container to the given function and returns
+ // the first (key,value) for which the function is true or nil,nil otherwise if no element
+ // matches the criteria.
+ Find(func(key interface{}, value interface{}) bool) (interface{}, interface{})
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package containers
+
+// IteratorWithIndex is stateful iterator for ordered containers whose values can be fetched by an index.
+type IteratorWithIndex interface {
+ // Next moves the iterator to the next element and returns true if there was a next element in the container.
+ // If Next() returns true, then next element's index and value can be retrieved by Index() and Value().
+ // If Next() was called for the first time, then it will point the iterator to the first element if it exists.
+ // Modifies the state of the iterator.
+ Next() bool
+
+ // Value returns the current element's value.
+ // Does not modify the state of the iterator.
+ Value() interface{}
+
+ // Index returns the current element's index.
+ // Does not modify the state of the iterator.
+ Index() int
+
+ // Begin resets the iterator to its initial state (one-before-first)
+ // Call Next() to fetch the first element if any.
+ Begin()
+
+ // First moves the iterator to the first element and returns true if there was a first element in the container.
+ // If First() returns true, then first element's index and value can be retrieved by Index() and Value().
+ // Modifies the state of the iterator.
+ First() bool
+}
+
+// IteratorWithKey is a stateful iterator for ordered containers whose elements are key value pairs.
+type IteratorWithKey interface {
+ // Next moves the iterator to the next element and returns true if there was a next element in the container.
+ // If Next() returns true, then next element's key and value can be retrieved by Key() and Value().
+ // If Next() was called for the first time, then it will point the iterator to the first element if it exists.
+ // Modifies the state of the iterator.
+ Next() bool
+
+ // Value returns the current element's value.
+ // Does not modify the state of the iterator.
+ Value() interface{}
+
+ // Key returns the current element's key.
+ // Does not modify the state of the iterator.
+ Key() interface{}
+
+ // Begin resets the iterator to its initial state (one-before-first)
+ // Call Next() to fetch the first element if any.
+ Begin()
+
+ // First moves the iterator to the first element and returns true if there was a first element in the container.
+ // If First() returns true, then first element's key and value can be retrieved by Key() and Value().
+ // Modifies the state of the iterator.
+ First() bool
+}
+
+// ReverseIteratorWithIndex is stateful iterator for ordered containers whose values can be fetched by an index.
+//
+// Essentially it is the same as IteratorWithIndex, but provides additional:
+//
+// Prev() function to enable traversal in reverse
+//
+// Last() function to move the iterator to the last element.
+//
+// End() function to move the iterator past the last element (one-past-the-end).
+type ReverseIteratorWithIndex interface {
+ // Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
+ // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value().
+ // Modifies the state of the iterator.
+ Prev() bool
+
+ // End moves the iterator past the last element (one-past-the-end).
+ // Call Prev() to fetch the last element if any.
+ End()
+
+ // Last moves the iterator to the last element and returns true if there was a last element in the container.
+ // If Last() returns true, then last element's index and value can be retrieved by Index() and Value().
+ // Modifies the state of the iterator.
+ Last() bool
+
+ IteratorWithIndex
+}
+
+// ReverseIteratorWithKey is a stateful iterator for ordered containers whose elements are key value pairs.
+//
+// Essentially it is the same as IteratorWithKey, but provides additional:
+//
+// Prev() function to enable traversal in reverse
+//
+// Last() function to move the iterator to the last element.
+type ReverseIteratorWithKey interface {
+ // Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
+ // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value().
+ // Modifies the state of the iterator.
+ Prev() bool
+
+ // End moves the iterator past the last element (one-past-the-end).
+ // Call Prev() to fetch the last element if any.
+ End()
+
+ // Last moves the iterator to the last element and returns true if there was a last element in the container.
+ // If Last() returns true, then last element's key and value can be retrieved by Key() and Value().
+ // Modifies the state of the iterator.
+ Last() bool
+
+ IteratorWithKey
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package containers
+
+// JSONSerializer provides JSON serialization
+type JSONSerializer interface {
+ // ToJSON outputs the JSON representation of containers's elements.
+ ToJSON() ([]byte, error)
+}
+
+// JSONDeserializer provides JSON deserialization
+type JSONDeserializer interface {
+ // FromJSON populates containers's elements from the input JSON representation.
+ FromJSON([]byte) error
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package arraylist implements the array list.
+//
+// Structure is not thread safe.
+//
+// Reference: https://en.wikipedia.org/wiki/List_%28abstract_data_type%29
+package arraylist
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/emirpasic/gods/lists"
+ "github.com/emirpasic/gods/utils"
+)
+
+func assertListImplementation() {
+ var _ lists.List = (*List)(nil)
+}
+
+// List holds the elements in a slice
+type List struct {
+ elements []interface{}
+ size int
+}
+
+const (
+ growthFactor = float32(2.0) // growth by 100%
+ shrinkFactor = float32(0.25) // shrink when size is 25% of capacity (0 means never shrink)
+)
+
+// New instantiates a new list and adds the passed values, if any, to the list
+func New(values ...interface{}) *List {
+ list := &List{}
+ if len(values) > 0 {
+ list.Add(values...)
+ }
+ return list
+}
+
+// Add appends a value at the end of the list
+func (list *List) Add(values ...interface{}) {
+ list.growBy(len(values))
+ for _, value := range values {
+ list.elements[list.size] = value
+ list.size++
+ }
+}
+
+// Get returns the element at index.
+// Second return parameter is true if index is within bounds of the array and array is not empty, otherwise false.
+func (list *List) Get(index int) (interface{}, bool) {
+
+ if !list.withinRange(index) {
+ return nil, false
+ }
+
+ return list.elements[index], true
+}
+
+// Remove removes the element at the given index from the list.
+func (list *List) Remove(index int) {
+
+ if !list.withinRange(index) {
+ return
+ }
+
+ list.elements[index] = nil // cleanup reference
+ copy(list.elements[index:], list.elements[index+1:list.size]) // shift to the left by one (slow operation, need ways to optimize this)
+ list.size--
+
+ list.shrink()
+}
+
+// Contains checks if elements (one or more) are present in the set.
+// All elements have to be present in the set for the method to return true.
+// Performance time complexity of n^2.
+// Returns true if no arguments are passed at all, i.e. set is always super-set of empty set.
+func (list *List) Contains(values ...interface{}) bool {
+
+ for _, searchValue := range values {
+ found := false
+ for _, element := range list.elements {
+ if element == searchValue {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return false
+ }
+ }
+ return true
+}
+
+// Values returns all elements in the list.
+func (list *List) Values() []interface{} {
+ newElements := make([]interface{}, list.size, list.size)
+ copy(newElements, list.elements[:list.size])
+ return newElements
+}
+
+//IndexOf returns index of provided element
+func (list *List) IndexOf(value interface{}) int {
+ if list.size == 0 {
+ return -1
+ }
+ for index, element := range list.elements {
+ if element == value {
+ return index
+ }
+ }
+ return -1
+}
+
+// Empty returns true if list does not contain any elements.
+func (list *List) Empty() bool {
+ return list.size == 0
+}
+
+// Size returns number of elements within the list.
+func (list *List) Size() int {
+ return list.size
+}
+
+// Clear removes all elements from the list.
+func (list *List) Clear() {
+ list.size = 0
+ list.elements = []interface{}{}
+}
+
+// Sort sorts values (in-place) using.
+func (list *List) Sort(comparator utils.Comparator) {
+ if len(list.elements) < 2 {
+ return
+ }
+ utils.Sort(list.elements[:list.size], comparator)
+}
+
+// Swap swaps the two values at the specified positions.
+func (list *List) Swap(i, j int) {
+ if list.withinRange(i) && list.withinRange(j) {
+ list.elements[i], list.elements[j] = list.elements[j], list.elements[i]
+ }
+}
+
+// Insert inserts values at specified index position shifting the value at that position (if any) and any subsequent elements to the right.
+// Does not do anything if position is negative or bigger than list's size
+// Note: position equal to list's size is valid, i.e. append.
+func (list *List) Insert(index int, values ...interface{}) {
+
+ if !list.withinRange(index) {
+ // Append
+ if index == list.size {
+ list.Add(values...)
+ }
+ return
+ }
+
+ l := len(values)
+ list.growBy(l)
+ list.size += l
+ copy(list.elements[index+l:], list.elements[index:list.size-l])
+ copy(list.elements[index:], values)
+}
+
+// Set the value at specified index
+// Does not do anything if position is negative or bigger than list's size
+// Note: position equal to list's size is valid, i.e. append.
+func (list *List) Set(index int, value interface{}) {
+
+ if !list.withinRange(index) {
+ // Append
+ if index == list.size {
+ list.Add(value)
+ }
+ return
+ }
+
+ list.elements[index] = value
+}
+
+// String returns a string representation of container
+func (list *List) String() string {
+ str := "ArrayList\n"
+ values := []string{}
+ for _, value := range list.elements[:list.size] {
+ values = append(values, fmt.Sprintf("%v", value))
+ }
+ str += strings.Join(values, ", ")
+ return str
+}
+
+// Check that the index is within bounds of the list
+func (list *List) withinRange(index int) bool {
+ return index >= 0 && index < list.size
+}
+
+func (list *List) resize(cap int) {
+ newElements := make([]interface{}, cap, cap)
+ copy(newElements, list.elements)
+ list.elements = newElements
+}
+
+// Expand the array if necessary, i.e. capacity will be reached if we add n elements
+func (list *List) growBy(n int) {
+ // When capacity is reached, grow by a factor of growthFactor and add number of elements
+ currentCapacity := cap(list.elements)
+ if list.size+n >= currentCapacity {
+ newCapacity := int(growthFactor * float32(currentCapacity+n))
+ list.resize(newCapacity)
+ }
+}
+
+// Shrink the array if necessary, i.e. when size is shrinkFactor percent of current capacity
+func (list *List) shrink() {
+ if shrinkFactor == 0.0 {
+ return
+ }
+ // Shrink when size is at shrinkFactor * capacity
+ currentCapacity := cap(list.elements)
+ if list.size <= int(float32(currentCapacity)*shrinkFactor) {
+ list.resize(list.size)
+ }
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package arraylist
+
+import "github.com/emirpasic/gods/containers"
+
+func assertEnumerableImplementation() {
+ var _ containers.EnumerableWithIndex = (*List)(nil)
+}
+
+// Each calls the given function once for each element, passing that element's index and value.
+func (list *List) Each(f func(index int, value interface{})) {
+ iterator := list.Iterator()
+ for iterator.Next() {
+ f(iterator.Index(), iterator.Value())
+ }
+}
+
+// Map invokes the given function once for each element and returns a
+// container containing the values returned by the given function.
+func (list *List) Map(f func(index int, value interface{}) interface{}) *List {
+ newList := &List{}
+ iterator := list.Iterator()
+ for iterator.Next() {
+ newList.Add(f(iterator.Index(), iterator.Value()))
+ }
+ return newList
+}
+
+// Select returns a new container containing all elements for which the given function returns a true value.
+func (list *List) Select(f func(index int, value interface{}) bool) *List {
+ newList := &List{}
+ iterator := list.Iterator()
+ for iterator.Next() {
+ if f(iterator.Index(), iterator.Value()) {
+ newList.Add(iterator.Value())
+ }
+ }
+ return newList
+}
+
+// Any passes each element of the collection to the given function and
+// returns true if the function ever returns true for any element.
+func (list *List) Any(f func(index int, value interface{}) bool) bool {
+ iterator := list.Iterator()
+ for iterator.Next() {
+ if f(iterator.Index(), iterator.Value()) {
+ return true
+ }
+ }
+ return false
+}
+
+// All passes each element of the collection to the given function and
+// returns true if the function returns true for all elements.
+func (list *List) All(f func(index int, value interface{}) bool) bool {
+ iterator := list.Iterator()
+ for iterator.Next() {
+ if !f(iterator.Index(), iterator.Value()) {
+ return false
+ }
+ }
+ return true
+}
+
+// Find passes each element of the container to the given function and returns
+// the first (index,value) for which the function is true or -1,nil otherwise
+// if no element matches the criteria.
+func (list *List) Find(f func(index int, value interface{}) bool) (int, interface{}) {
+ iterator := list.Iterator()
+ for iterator.Next() {
+ if f(iterator.Index(), iterator.Value()) {
+ return iterator.Index(), iterator.Value()
+ }
+ }
+ return -1, nil
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package arraylist
+
+import "github.com/emirpasic/gods/containers"
+
+func assertIteratorImplementation() {
+ var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil)
+}
+
+// Iterator holding the iterator's state
+type Iterator struct {
+ list *List
+ index int
+}
+
+// Iterator returns a stateful iterator whose values can be fetched by an index.
+func (list *List) Iterator() Iterator {
+ return Iterator{list: list, index: -1}
+}
+
+// Next moves the iterator to the next element and returns true if there was a next element in the container.
+// If Next() returns true, then next element's index and value can be retrieved by Index() and Value().
+// If Next() was called for the first time, then it will point the iterator to the first element if it exists.
+// Modifies the state of the iterator.
+func (iterator *Iterator) Next() bool {
+ if iterator.index < iterator.list.size {
+ iterator.index++
+ }
+ return iterator.list.withinRange(iterator.index)
+}
+
+// Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
+// If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value().
+// Modifies the state of the iterator.
+func (iterator *Iterator) Prev() bool {
+ if iterator.index >= 0 {
+ iterator.index--
+ }
+ return iterator.list.withinRange(iterator.index)
+}
+
+// Value returns the current element's value.
+// Does not modify the state of the iterator.
+func (iterator *Iterator) Value() interface{} {
+ return iterator.list.elements[iterator.index]
+}
+
+// Index returns the current element's index.
+// Does not modify the state of the iterator.
+func (iterator *Iterator) Index() int {
+ return iterator.index
+}
+
+// Begin resets the iterator to its initial state (one-before-first)
+// Call Next() to fetch the first element if any.
+func (iterator *Iterator) Begin() {
+ iterator.index = -1
+}
+
+// End moves the iterator past the last element (one-past-the-end).
+// Call Prev() to fetch the last element if any.
+func (iterator *Iterator) End() {
+ iterator.index = iterator.list.size
+}
+
+// First moves the iterator to the first element and returns true if there was a first element in the container.
+// If First() returns true, then first element's index and value can be retrieved by Index() and Value().
+// Modifies the state of the iterator.
+func (iterator *Iterator) First() bool {
+ iterator.Begin()
+ return iterator.Next()
+}
+
+// Last moves the iterator to the last element and returns true if there was a last element in the container.
+// If Last() returns true, then last element's index and value can be retrieved by Index() and Value().
+// Modifies the state of the iterator.
+func (iterator *Iterator) Last() bool {
+ iterator.End()
+ return iterator.Prev()
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package arraylist
+
+import (
+ "encoding/json"
+ "github.com/emirpasic/gods/containers"
+)
+
+func assertSerializationImplementation() {
+ var _ containers.JSONSerializer = (*List)(nil)
+ var _ containers.JSONDeserializer = (*List)(nil)
+}
+
+// ToJSON outputs the JSON representation of list's elements.
+func (list *List) ToJSON() ([]byte, error) {
+ return json.Marshal(list.elements[:list.size])
+}
+
+// FromJSON populates list's elements from the input JSON representation.
+func (list *List) FromJSON(data []byte) error {
+ err := json.Unmarshal(data, &list.elements)
+ if err == nil {
+ list.size = len(list.elements)
+ }
+ return err
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package lists provides an abstract List interface.
+//
+// In computer science, a list or sequence is an abstract data type that represents an ordered sequence of values, where the same value may occur more than once. An instance of a list is a computer representation of the mathematical concept of a finite sequence; the (potentially) infinite analog of a list is a stream. Lists are a basic example of containers, as they contain other values. If the same value occurs multiple times, each occurrence is considered a distinct item.
+//
+// Reference: https://en.wikipedia.org/wiki/List_%28abstract_data_type%29
+package lists
+
+import (
+ "github.com/emirpasic/gods/containers"
+ "github.com/emirpasic/gods/utils"
+)
+
+// List interface that all lists implement
+type List interface {
+ Get(index int) (interface{}, bool)
+ Remove(index int)
+ Add(values ...interface{})
+ Contains(values ...interface{}) bool
+ Sort(comparator utils.Comparator)
+ Swap(index1, index2 int)
+ Insert(index int, values ...interface{})
+ Set(index int, value interface{})
+
+ containers.Container
+ // Empty() bool
+ // Size() int
+ // Clear()
+ // Values() []interface{}
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package binaryheap implements a binary heap backed by array list.
+//
+// Comparator defines this heap as either min or max heap.
+//
+// Structure is not thread safe.
+//
+// References: http://en.wikipedia.org/wiki/Binary_heap
+package binaryheap
+
+import (
+ "fmt"
+ "github.com/emirpasic/gods/lists/arraylist"
+ "github.com/emirpasic/gods/trees"
+ "github.com/emirpasic/gods/utils"
+ "strings"
+)
+
+func assertTreeImplementation() {
+ var _ trees.Tree = (*Heap)(nil)
+}
+
+// Heap holds elements in an array-list
+type Heap struct {
+ list *arraylist.List
+ Comparator utils.Comparator
+}
+
+// NewWith instantiates a new empty heap tree with the custom comparator.
+func NewWith(comparator utils.Comparator) *Heap {
+ return &Heap{list: arraylist.New(), Comparator: comparator}
+}
+
+// NewWithIntComparator instantiates a new empty heap with the IntComparator, i.e. elements are of type int.
+func NewWithIntComparator() *Heap {
+ return &Heap{list: arraylist.New(), Comparator: utils.IntComparator}
+}
+
+// NewWithStringComparator instantiates a new empty heap with the StringComparator, i.e. elements are of type string.
+func NewWithStringComparator() *Heap {
+ return &Heap{list: arraylist.New(), Comparator: utils.StringComparator}
+}
+
+// Push adds a value onto the heap and bubbles it up accordingly.
+func (heap *Heap) Push(values ...interface{}) {
+ if len(values) == 1 {
+ heap.list.Add(values[0])
+ heap.bubbleUp()
+ } else {
+ // Reference: https://en.wikipedia.org/wiki/Binary_heap#Building_a_heap
+ for _, value := range values {
+ heap.list.Add(value)
+ }
+ size := heap.list.Size()/2 + 1
+ for i := size; i >= 0; i-- {
+ heap.bubbleDownIndex(i)
+ }
+ }
+}
+
+// Pop removes top element on heap and returns it, or nil if heap is empty.
+// Second return parameter is true, unless the heap was empty and there was nothing to pop.
+func (heap *Heap) Pop() (value interface{}, ok bool) {
+ value, ok = heap.list.Get(0)
+ if !ok {
+ return
+ }
+ lastIndex := heap.list.Size() - 1
+ heap.list.Swap(0, lastIndex)
+ heap.list.Remove(lastIndex)
+ heap.bubbleDown()
+ return
+}
+
+// Peek returns top element on the heap without removing it, or nil if heap is empty.
+// Second return parameter is true, unless the heap was empty and there was nothing to peek.
+func (heap *Heap) Peek() (value interface{}, ok bool) {
+ return heap.list.Get(0)
+}
+
+// Empty returns true if heap does not contain any elements.
+func (heap *Heap) Empty() bool {
+ return heap.list.Empty()
+}
+
+// Size returns number of elements within the heap.
+func (heap *Heap) Size() int {
+ return heap.list.Size()
+}
+
+// Clear removes all elements from the heap.
+func (heap *Heap) Clear() {
+ heap.list.Clear()
+}
+
+// Values returns all elements in the heap.
+func (heap *Heap) Values() []interface{} {
+ return heap.list.Values()
+}
+
+// String returns a string representation of container
+func (heap *Heap) String() string {
+ str := "BinaryHeap\n"
+ values := []string{}
+ for _, value := range heap.list.Values() {
+ values = append(values, fmt.Sprintf("%v", value))
+ }
+ str += strings.Join(values, ", ")
+ return str
+}
+
+// Performs the "bubble down" operation. This is to place the element that is at the root
+// of the heap in its correct place so that the heap maintains the min/max-heap order property.
+func (heap *Heap) bubbleDown() {
+ heap.bubbleDownIndex(0)
+}
+
+// Performs the "bubble down" operation. This is to place the element that is at the index
+// of the heap in its correct place so that the heap maintains the min/max-heap order property.
+func (heap *Heap) bubbleDownIndex(index int) {
+ size := heap.list.Size()
+ for leftIndex := index<<1 + 1; leftIndex < size; leftIndex = index<<1 + 1 {
+ rightIndex := index<<1 + 2
+ smallerIndex := leftIndex
+ leftValue, _ := heap.list.Get(leftIndex)
+ rightValue, _ := heap.list.Get(rightIndex)
+ if rightIndex < size && heap.Comparator(leftValue, rightValue) > 0 {
+ smallerIndex = rightIndex
+ }
+ indexValue, _ := heap.list.Get(index)
+ smallerValue, _ := heap.list.Get(smallerIndex)
+ if heap.Comparator(indexValue, smallerValue) > 0 {
+ heap.list.Swap(index, smallerIndex)
+ } else {
+ break
+ }
+ index = smallerIndex
+ }
+}
+
+// Performs the "bubble up" operation. This is to place a newly inserted
+// element (i.e. last element in the list) in its correct place so that
+// the heap maintains the min/max-heap order property.
+func (heap *Heap) bubbleUp() {
+ index := heap.list.Size() - 1
+ for parentIndex := (index - 1) >> 1; index > 0; parentIndex = (index - 1) >> 1 {
+ indexValue, _ := heap.list.Get(index)
+ parentValue, _ := heap.list.Get(parentIndex)
+ if heap.Comparator(parentValue, indexValue) <= 0 {
+ break
+ }
+ heap.list.Swap(index, parentIndex)
+ index = parentIndex
+ }
+}
+
+// Check that the index is within bounds of the list
+func (heap *Heap) withinRange(index int) bool {
+ return index >= 0 && index < heap.list.Size()
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package binaryheap
+
+import "github.com/emirpasic/gods/containers"
+
+func assertIteratorImplementation() {
+ var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil)
+}
+
+// Iterator returns a stateful iterator whose values can be fetched by an index.
+type Iterator struct {
+ heap *Heap
+ index int
+}
+
+// Iterator returns a stateful iterator whose values can be fetched by an index.
+func (heap *Heap) Iterator() Iterator {
+ return Iterator{heap: heap, index: -1}
+}
+
+// Next moves the iterator to the next element and returns true if there was a next element in the container.
+// If Next() returns true, then next element's index and value can be retrieved by Index() and Value().
+// If Next() was called for the first time, then it will point the iterator to the first element if it exists.
+// Modifies the state of the iterator.
+func (iterator *Iterator) Next() bool {
+ if iterator.index < iterator.heap.Size() {
+ iterator.index++
+ }
+ return iterator.heap.withinRange(iterator.index)
+}
+
+// Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
+// If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value().
+// Modifies the state of the iterator.
+func (iterator *Iterator) Prev() bool {
+ if iterator.index >= 0 {
+ iterator.index--
+ }
+ return iterator.heap.withinRange(iterator.index)
+}
+
+// Value returns the current element's value.
+// Does not modify the state of the iterator.
+func (iterator *Iterator) Value() interface{} {
+ value, _ := iterator.heap.list.Get(iterator.index)
+ return value
+}
+
+// Index returns the current element's index.
+// Does not modify the state of the iterator.
+func (iterator *Iterator) Index() int {
+ return iterator.index
+}
+
+// Begin resets the iterator to its initial state (one-before-first)
+// Call Next() to fetch the first element if any.
+func (iterator *Iterator) Begin() {
+ iterator.index = -1
+}
+
+// End moves the iterator past the last element (one-past-the-end).
+// Call Prev() to fetch the last element if any.
+func (iterator *Iterator) End() {
+ iterator.index = iterator.heap.Size()
+}
+
+// First moves the iterator to the first element and returns true if there was a first element in the container.
+// If First() returns true, then first element's index and value can be retrieved by Index() and Value().
+// Modifies the state of the iterator.
+func (iterator *Iterator) First() bool {
+ iterator.Begin()
+ return iterator.Next()
+}
+
+// Last moves the iterator to the last element and returns true if there was a last element in the container.
+// If Last() returns true, then last element's index and value can be retrieved by Index() and Value().
+// Modifies the state of the iterator.
+func (iterator *Iterator) Last() bool {
+ iterator.End()
+ return iterator.Prev()
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package binaryheap
+
+import "github.com/emirpasic/gods/containers"
+
+func assertSerializationImplementation() {
+ var _ containers.JSONSerializer = (*Heap)(nil)
+ var _ containers.JSONDeserializer = (*Heap)(nil)
+}
+
+// ToJSON outputs the JSON representation of the heap.
+func (heap *Heap) ToJSON() ([]byte, error) {
+ return heap.list.ToJSON()
+}
+
+// FromJSON populates the heap from the input JSON representation.
+func (heap *Heap) FromJSON(data []byte) error {
+ return heap.list.FromJSON(data)
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package trees provides an abstract Tree interface.
+//
+// In computer science, a tree is a widely used abstract data type (ADT) or data structure implementing this ADT that simulates a hierarchical tree structure, with a root value and subtrees of children with a parent node, represented as a set of linked nodes.
+//
+// Reference: https://en.wikipedia.org/wiki/Tree_%28data_structure%29
+package trees
+
+import "github.com/emirpasic/gods/containers"
+
+// Tree interface that all trees implement
+type Tree interface {
+ containers.Container
+ // Empty() bool
+ // Size() int
+ // Clear()
+ // Values() []interface{}
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import "time"
+
+// Comparator will make type assertion (see IntComparator for example),
+// which will panic if a or b are not of the asserted type.
+//
+// Should return a number:
+// negative , if a < b
+// zero , if a == b
+// positive , if a > b
+type Comparator func(a, b interface{}) int
+
+// StringComparator provides a fast comparison on strings
+func StringComparator(a, b interface{}) int {
+ s1 := a.(string)
+ s2 := b.(string)
+ min := len(s2)
+ if len(s1) < len(s2) {
+ min = len(s1)
+ }
+ diff := 0
+ for i := 0; i < min && diff == 0; i++ {
+ diff = int(s1[i]) - int(s2[i])
+ }
+ if diff == 0 {
+ diff = len(s1) - len(s2)
+ }
+ if diff < 0 {
+ return -1
+ }
+ if diff > 0 {
+ return 1
+ }
+ return 0
+}
+
+// IntComparator provides a basic comparison on int
+func IntComparator(a, b interface{}) int {
+ aAsserted := a.(int)
+ bAsserted := b.(int)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// Int8Comparator provides a basic comparison on int8
+func Int8Comparator(a, b interface{}) int {
+ aAsserted := a.(int8)
+ bAsserted := b.(int8)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// Int16Comparator provides a basic comparison on int16
+func Int16Comparator(a, b interface{}) int {
+ aAsserted := a.(int16)
+ bAsserted := b.(int16)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// Int32Comparator provides a basic comparison on int32
+func Int32Comparator(a, b interface{}) int {
+ aAsserted := a.(int32)
+ bAsserted := b.(int32)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// Int64Comparator provides a basic comparison on int64
+func Int64Comparator(a, b interface{}) int {
+ aAsserted := a.(int64)
+ bAsserted := b.(int64)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// UIntComparator provides a basic comparison on uint
+func UIntComparator(a, b interface{}) int {
+ aAsserted := a.(uint)
+ bAsserted := b.(uint)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// UInt8Comparator provides a basic comparison on uint8
+func UInt8Comparator(a, b interface{}) int {
+ aAsserted := a.(uint8)
+ bAsserted := b.(uint8)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// UInt16Comparator provides a basic comparison on uint16
+func UInt16Comparator(a, b interface{}) int {
+ aAsserted := a.(uint16)
+ bAsserted := b.(uint16)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// UInt32Comparator provides a basic comparison on uint32
+func UInt32Comparator(a, b interface{}) int {
+ aAsserted := a.(uint32)
+ bAsserted := b.(uint32)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// UInt64Comparator provides a basic comparison on uint64
+func UInt64Comparator(a, b interface{}) int {
+ aAsserted := a.(uint64)
+ bAsserted := b.(uint64)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// Float32Comparator provides a basic comparison on float32
+func Float32Comparator(a, b interface{}) int {
+ aAsserted := a.(float32)
+ bAsserted := b.(float32)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// Float64Comparator provides a basic comparison on float64
+func Float64Comparator(a, b interface{}) int {
+ aAsserted := a.(float64)
+ bAsserted := b.(float64)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// ByteComparator provides a basic comparison on byte
+func ByteComparator(a, b interface{}) int {
+ aAsserted := a.(byte)
+ bAsserted := b.(byte)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// RuneComparator provides a basic comparison on rune
+func RuneComparator(a, b interface{}) int {
+ aAsserted := a.(rune)
+ bAsserted := b.(rune)
+ switch {
+ case aAsserted > bAsserted:
+ return 1
+ case aAsserted < bAsserted:
+ return -1
+ default:
+ return 0
+ }
+}
+
+// TimeComparator provides a basic comparison on time.Time
+func TimeComparator(a, b interface{}) int {
+ aAsserted := a.(time.Time)
+ bAsserted := b.(time.Time)
+
+ switch {
+ case aAsserted.After(bAsserted):
+ return 1
+ case aAsserted.Before(bAsserted):
+ return -1
+ default:
+ return 0
+ }
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import "sort"
+
+// Sort sorts values (in-place) with respect to the given comparator.
+//
+// Uses Go's sort (hybrid of quicksort for large and then insertion sort for smaller slices).
+func Sort(values []interface{}, comparator Comparator) {
+ sort.Sort(sortable{values, comparator})
+}
+
+type sortable struct {
+ values []interface{}
+ comparator Comparator
+}
+
+func (s sortable) Len() int {
+ return len(s.values)
+}
+func (s sortable) Swap(i, j int) {
+ s.values[i], s.values[j] = s.values[j], s.values[i]
+}
+func (s sortable) Less(i, j int) bool {
+ return s.comparator(s.values[i], s.values[j]) < 0
+}
--- /dev/null
+// Copyright (c) 2015, Emir Pasic. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package utils provides common utility functions.
+//
+// Provided functionalities:
+// - sorting
+// - comparators
+package utils
+
+import (
+ "fmt"
+ "strconv"
+)
+
+// ToString converts a value to string.
+func ToString(value interface{}) string {
+ switch value.(type) {
+ case string:
+ return value.(string)
+ case int8:
+ return strconv.FormatInt(int64(value.(int8)), 10)
+ case int16:
+ return strconv.FormatInt(int64(value.(int16)), 10)
+ case int32:
+ return strconv.FormatInt(int64(value.(int32)), 10)
+ case int64:
+ return strconv.FormatInt(int64(value.(int64)), 10)
+ case uint8:
+ return strconv.FormatUint(uint64(value.(uint8)), 10)
+ case uint16:
+ return strconv.FormatUint(uint64(value.(uint16)), 10)
+ case uint32:
+ return strconv.FormatUint(uint64(value.(uint32)), 10)
+ case uint64:
+ return strconv.FormatUint(uint64(value.(uint64)), 10)
+ case float32:
+ return strconv.FormatFloat(float64(value.(float32)), 'g', -1, 64)
+ case float64:
+ return strconv.FormatFloat(float64(value.(float64)), 'g', -1, 64)
+ case bool:
+ return strconv.FormatBool(value.(bool))
+ default:
+ return fmt.Sprintf("%+v", value)
+ }
+}
--- /dev/null
+The MIT License (MIT)
+
+Copyright (c) 2014 Juan Batiz-Benet
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
--- /dev/null
+// Package ctxio provides io.Reader and io.Writer wrappers that
+// respect context.Contexts. Use these at the interface between
+// your context code and your io.
+//
+// WARNING: read the code. see how writes and reads will continue
+// until you cancel the io. Maybe this package should provide
+// versions of io.ReadCloser and io.WriteCloser that automatically
+// call .Close when the context expires. But for now -- since in my
+// use cases I have long-lived connections with ephemeral io wrappers
+// -- this has yet to be a need.
+package ctxio
+
+import (
+ "io"
+
+ context "golang.org/x/net/context"
+)
+
+type ioret struct {
+ n int
+ err error
+}
+
+type Writer interface {
+ io.Writer
+}
+
+type ctxWriter struct {
+ w io.Writer
+ ctx context.Context
+}
+
+// NewWriter wraps a writer to make it respect given Context.
+// If there is a blocking write, the returned Writer will return
+// whenever the context is cancelled (the return values are n=0
+// and err=ctx.Err().)
+//
+// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
+// write-- there is no way to do that with the standard go io
+// interface. So the read and write _will_ happen or hang. So, use
+// this sparingly, make sure to cancel the read or write as necesary
+// (e.g. closing a connection whose context is up, etc.)
+//
+// Furthermore, in order to protect your memory from being read
+// _after_ you've cancelled the context, this io.Writer will
+// first make a **copy** of the buffer.
+func NewWriter(ctx context.Context, w io.Writer) *ctxWriter {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return &ctxWriter{ctx: ctx, w: w}
+}
+
+func (w *ctxWriter) Write(buf []byte) (int, error) {
+ buf2 := make([]byte, len(buf))
+ copy(buf2, buf)
+
+ c := make(chan ioret, 1)
+
+ go func() {
+ n, err := w.w.Write(buf2)
+ c <- ioret{n, err}
+ close(c)
+ }()
+
+ select {
+ case r := <-c:
+ return r.n, r.err
+ case <-w.ctx.Done():
+ return 0, w.ctx.Err()
+ }
+}
+
+type Reader interface {
+ io.Reader
+}
+
+type ctxReader struct {
+ r io.Reader
+ ctx context.Context
+}
+
+// NewReader wraps a reader to make it respect given Context.
+// If there is a blocking read, the returned Reader will return
+// whenever the context is cancelled (the return values are n=0
+// and err=ctx.Err().)
+//
+// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
+// write-- there is no way to do that with the standard go io
+// interface. So the read and write _will_ happen or hang. So, use
+// this sparingly, make sure to cancel the read or write as necesary
+// (e.g. closing a connection whose context is up, etc.)
+//
+// Furthermore, in order to protect your memory from being read
+// _before_ you've cancelled the context, this io.Reader will
+// allocate a buffer of the same size, and **copy** into the client's
+// if the read succeeds in time.
+func NewReader(ctx context.Context, r io.Reader) *ctxReader {
+ return &ctxReader{ctx: ctx, r: r}
+}
+
+func (r *ctxReader) Read(buf []byte) (int, error) {
+ buf2 := make([]byte, len(buf))
+
+ c := make(chan ioret, 1)
+
+ go func() {
+ n, err := r.r.Read(buf2)
+ c <- ioret{n, err}
+ close(c)
+ }()
+
+ select {
+ case ret := <-c:
+ copy(buf, buf2)
+ return ret.n, ret.err
+ case <-r.ctx.Done():
+ return 0, r.ctx.Err()
+ }
+}
--- /dev/null
+Eugene Terentev <eugene@terentev.net>
+Kevin Burke <kev@inburke.com>
+Sergey Lukjanov <me@slukjanov.name>
+Wayne Ashley Berry <wayneashleyberry@gmail.com>
--- /dev/null
+Copyright (c) 2017 Kevin Burke.
+
+Permission is hereby granted, free of charge, to any person
+obtaining a copy of this software and associated documentation
+files (the "Software"), to deal in the Software without
+restriction, including without limitation the rights to use,
+copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+===================
+
+The lexer and parser borrow heavily from github.com/pelletier/go-toml. The
+license for that project is copied below.
+
+The MIT License (MIT)
+
+Copyright (c) 2013 - 2017 Thomas Pelletier, Eric Anderton
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
--- /dev/null
+// Package ssh_config provides tools for manipulating SSH config files.
+//
+// Importantly, this parser attempts to preserve comments in a given file, so
+// you can manipulate a `ssh_config` file from a program, if your heart desires.
+//
+// The Get() and GetStrict() functions will attempt to read values from
+// $HOME/.ssh/config, falling back to /etc/ssh/ssh_config. The first argument is
+// the host name to match on ("example.com"), and the second argument is the key
+// you want to retrieve ("Port"). The keywords are case insensitive.
+//
+// port := ssh_config.Get("myhost", "Port")
+//
+// You can also manipulate an SSH config file and then print it or write it back
+// to disk.
+//
+// f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
+// cfg, _ := ssh_config.Decode(f)
+// for _, host := range cfg.Hosts {
+// fmt.Println("patterns:", host.Patterns)
+// for _, node := range host.Nodes {
+// fmt.Println(node.String())
+// }
+// }
+//
+// // Write the cfg back to disk:
+// fmt.Println(cfg.String())
+//
+// BUG: the Match directive is currently unsupported; parsing a config with
+// a Match directive will trigger an error.
+package ssh_config
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ osuser "os/user"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strings"
+ "sync"
+)
+
+const version = "0.5"
+
+type configFinder func() string
+
+// UserSettings checks ~/.ssh and /etc/ssh for configuration files. The config
+// files are parsed and cached the first time Get() or GetStrict() is called.
+type UserSettings struct {
+ IgnoreErrors bool
+ systemConfig *Config
+ systemConfigFinder configFinder
+ userConfig *Config
+ userConfigFinder configFinder
+ loadConfigs sync.Once
+ onceErr error
+}
+
+func homedir() string {
+ user, err := osuser.Current()
+ if err == nil {
+ return user.HomeDir
+ } else {
+ return os.Getenv("HOME")
+ }
+}
+
+func userConfigFinder() string {
+ return filepath.Join(homedir(), ".ssh", "config")
+}
+
+// DefaultUserSettings is the default UserSettings and is used by Get and
+// GetStrict. It checks both $HOME/.ssh/config and /etc/ssh/ssh_config for keys,
+// and it will return parse errors (if any) instead of swallowing them.
+var DefaultUserSettings = &UserSettings{
+ IgnoreErrors: false,
+ systemConfigFinder: systemConfigFinder,
+ userConfigFinder: userConfigFinder,
+}
+
+func systemConfigFinder() string {
+ return filepath.Join("/", "etc", "ssh", "ssh_config")
+}
+
+func findVal(c *Config, alias, key string) (string, error) {
+ if c == nil {
+ return "", nil
+ }
+ val, err := c.Get(alias, key)
+ if err != nil || val == "" {
+ return "", err
+ }
+ if err := validate(key, val); err != nil {
+ return "", err
+ }
+ return val, nil
+}
+
+// Get finds the first value for key within a declaration that matches the
+// alias. Get returns the empty string if no value was found, or if IgnoreErrors
+// is false and we could not parse the configuration file. Use GetStrict to
+// disambiguate the latter cases.
+//
+// The match for key is case insensitive.
+//
+// Get is a wrapper around DefaultUserSettings.Get.
+func Get(alias, key string) string {
+ return DefaultUserSettings.Get(alias, key)
+}
+
+// GetStrict finds the first value for key within a declaration that matches the
+// alias. If key has a default value and no matching configuration is found, the
+// default will be returned. For more information on default values and the way
+// patterns are matched, see the manpage for ssh_config.
+//
+// error will be non-nil if and only if a user's configuration file or the
+// system configuration file could not be parsed, and u.IgnoreErrors is false.
+//
+// GetStrict is a wrapper around DefaultUserSettings.GetStrict.
+func GetStrict(alias, key string) (string, error) {
+ return DefaultUserSettings.GetStrict(alias, key)
+}
+
+// Get finds the first value for key within a declaration that matches the
+// alias. Get returns the empty string if no value was found, or if IgnoreErrors
+// is false and we could not parse the configuration file. Use GetStrict to
+// disambiguate the latter cases.
+//
+// The match for key is case insensitive.
+func (u *UserSettings) Get(alias, key string) string {
+ val, err := u.GetStrict(alias, key)
+ if err != nil {
+ return ""
+ }
+ return val
+}
+
+// GetStrict finds the first value for key within a declaration that matches the
+// alias. If key has a default value and no matching configuration is found, the
+// default will be returned. For more information on default values and the way
+// patterns are matched, see the manpage for ssh_config.
+//
+// error will be non-nil if and only if a user's configuration file or the
+// system configuration file could not be parsed, and u.IgnoreErrors is false.
+func (u *UserSettings) GetStrict(alias, key string) (string, error) {
+ u.loadConfigs.Do(func() {
+ // can't parse user file, that's ok.
+ var filename string
+ if u.userConfigFinder == nil {
+ filename = userConfigFinder()
+ } else {
+ filename = u.userConfigFinder()
+ }
+ var err error
+ u.userConfig, err = parseFile(filename)
+ if err != nil && os.IsNotExist(err) == false {
+ u.onceErr = err
+ return
+ }
+ if u.systemConfigFinder == nil {
+ filename = systemConfigFinder()
+ } else {
+ filename = u.systemConfigFinder()
+ }
+ u.systemConfig, err = parseFile(filename)
+ if err != nil && os.IsNotExist(err) == false {
+ u.onceErr = err
+ return
+ }
+ })
+ if u.onceErr != nil && u.IgnoreErrors == false {
+ return "", u.onceErr
+ }
+ val, err := findVal(u.userConfig, alias, key)
+ if err != nil || val != "" {
+ return val, err
+ }
+ val2, err2 := findVal(u.systemConfig, alias, key)
+ if err2 != nil || val2 != "" {
+ return val2, err2
+ }
+ return Default(key), nil
+}
+
+func parseFile(filename string) (*Config, error) {
+ return parseWithDepth(filename, 0)
+}
+
+func parseWithDepth(filename string, depth uint8) (*Config, error) {
+ f, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ return decode(f, isSystem(filename), depth)
+}
+
+func isSystem(filename string) bool {
+ // TODO i'm not sure this is the best way to detect a system repo
+ return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh")
+}
+
+// Decode reads r into a Config, or returns an error if r could not be parsed as
+// an SSH config file.
+func Decode(r io.Reader) (*Config, error) {
+ return decode(r, false, 0)
+}
+
+func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ if e, ok := r.(error); ok && e == ErrDepthExceeded {
+ err = e
+ return
+ }
+ err = errors.New(r.(string))
+ }
+ }()
+
+ c = parseSSH(lexSSH(r), system, depth)
+ return c, err
+}
+
+// Config represents an SSH config file.
+type Config struct {
+ // A list of hosts to match against. The file begins with an implicit
+ // "Host *" declaration matching all hosts.
+ Hosts []*Host
+ depth uint8
+ position Position
+}
+
+// Get finds the first value in the configuration that matches the alias and
+// contains key. Get returns the empty string if no value was found, or if the
+// Config contains an invalid conditional Include value.
+//
+// The match for key is case insensitive.
+func (c *Config) Get(alias, key string) (string, error) {
+ lowerKey := strings.ToLower(key)
+ for _, host := range c.Hosts {
+ if !host.Matches(alias) {
+ continue
+ }
+ for _, node := range host.Nodes {
+ switch t := node.(type) {
+ case *Empty:
+ continue
+ case *KV:
+ // "keys are case insensitive" per the spec
+ lkey := strings.ToLower(t.Key)
+ if lkey == "match" {
+ panic("can't handle Match directives")
+ }
+ if lkey == lowerKey {
+ return t.Value, nil
+ }
+ case *Include:
+ val := t.Get(alias, key)
+ if val != "" {
+ return val, nil
+ }
+ default:
+ return "", fmt.Errorf("unknown Node type %v", t)
+ }
+ }
+ }
+ return "", nil
+}
+
+// String returns a string representation of the Config file.
+func (c Config) String() string {
+ return marshal(c).String()
+}
+
+func (c Config) MarshalText() ([]byte, error) {
+ return marshal(c).Bytes(), nil
+}
+
+func marshal(c Config) *bytes.Buffer {
+ var buf bytes.Buffer
+ for i := range c.Hosts {
+ buf.WriteString(c.Hosts[i].String())
+ }
+ return &buf
+}
+
+// Pattern is a pattern in a Host declaration. Patterns are read-only values;
+// create a new one with NewPattern().
+type Pattern struct {
+ str string // Its appearance in the file, not the value that gets compiled.
+ regex *regexp.Regexp
+ not bool // True if this is a negated match
+}
+
+// String prints the string representation of the pattern.
+func (p Pattern) String() string {
+ return p.str
+}
+
+// Copied from regexp.go with * and ? removed.
+var specialBytes = []byte(`\.+()|[]{}^$`)
+
+func special(b byte) bool {
+ return bytes.IndexByte(specialBytes, b) >= 0
+}
+
+// NewPattern creates a new Pattern for matching hosts. NewPattern("*") creates
+// a Pattern that matches all hosts.
+//
+// From the manpage, a pattern consists of zero or more non-whitespace
+// characters, `*' (a wildcard that matches zero or more characters), or `?' (a
+// wildcard that matches exactly one character). For example, to specify a set
+// of declarations for any host in the ".co.uk" set of domains, the following
+// pattern could be used:
+//
+// Host *.co.uk
+//
+// The following pattern would match any host in the 192.168.0.[0-9] network range:
+//
+// Host 192.168.0.?
+func NewPattern(s string) (*Pattern, error) {
+ if s == "" {
+ return nil, errors.New("ssh_config: empty pattern")
+ }
+ negated := false
+ if s[0] == '!' {
+ negated = true
+ s = s[1:]
+ }
+ var buf bytes.Buffer
+ buf.WriteByte('^')
+ for i := 0; i < len(s); i++ {
+ // A byte loop is correct because all metacharacters are ASCII.
+ switch b := s[i]; b {
+ case '*':
+ buf.WriteString(".*")
+ case '?':
+ buf.WriteString(".?")
+ default:
+ // borrowing from QuoteMeta here.
+ if special(b) {
+ buf.WriteByte('\\')
+ }
+ buf.WriteByte(b)
+ }
+ }
+ buf.WriteByte('$')
+ r, err := regexp.Compile(buf.String())
+ if err != nil {
+ return nil, err
+ }
+ return &Pattern{str: s, regex: r, not: negated}, nil
+}
+
+// Host describes a Host directive and the keywords that follow it.
+type Host struct {
+ // A list of host patterns that should match this host.
+ Patterns []*Pattern
+ // A Node is either a key/value pair or a comment line.
+ Nodes []Node
+ // EOLComment is the comment (if any) terminating the Host line.
+ EOLComment string
+ hasEquals bool
+ leadingSpace uint16 // TODO: handle spaces vs tabs here.
+ // The file starts with an implicit "Host *" declaration.
+ implicit bool
+}
+
+// Matches returns true if the Host matches for the given alias. For
+// a description of the rules that provide a match, see the manpage for
+// ssh_config.
+func (h *Host) Matches(alias string) bool {
+ found := false
+ for i := range h.Patterns {
+ if h.Patterns[i].regex.MatchString(alias) {
+ if h.Patterns[i].not == true {
+ // Negated match. "A pattern entry may be negated by prefixing
+ // it with an exclamation mark (`!'). If a negated entry is
+ // matched, then the Host entry is ignored, regardless of
+ // whether any other patterns on the line match. Negated matches
+ // are therefore useful to provide exceptions for wildcard
+ // matches."
+ return false
+ }
+ found = true
+ }
+ }
+ return found
+}
+
+// String prints h as it would appear in a config file. Minor tweaks may be
+// present in the whitespace in the printed file.
+func (h *Host) String() string {
+ var buf bytes.Buffer
+ if h.implicit == false {
+ buf.WriteString(strings.Repeat(" ", int(h.leadingSpace)))
+ buf.WriteString("Host")
+ if h.hasEquals {
+ buf.WriteString(" = ")
+ } else {
+ buf.WriteString(" ")
+ }
+ for i, pat := range h.Patterns {
+ buf.WriteString(pat.String())
+ if i < len(h.Patterns)-1 {
+ buf.WriteString(" ")
+ }
+ }
+ if h.EOLComment != "" {
+ buf.WriteString(" #")
+ buf.WriteString(h.EOLComment)
+ }
+ buf.WriteByte('\n')
+ }
+ for i := range h.Nodes {
+ buf.WriteString(h.Nodes[i].String())
+ buf.WriteByte('\n')
+ }
+ return buf.String()
+}
+
+// Node represents a line in a Config.
+type Node interface {
+ Pos() Position
+ String() string
+}
+
+// KV is a line in the config file that contains a key, a value, and possibly
+// a comment.
+type KV struct {
+ Key string
+ Value string
+ Comment string
+ hasEquals bool
+ leadingSpace uint16 // Space before the key. TODO handle spaces vs tabs.
+ position Position
+}
+
+// Pos returns k's Position.
+func (k *KV) Pos() Position {
+ return k.position
+}
+
+// String prints k as it was parsed in the config file. There may be slight
+// changes to the whitespace between values.
+func (k *KV) String() string {
+ if k == nil {
+ return ""
+ }
+ equals := " "
+ if k.hasEquals {
+ equals = " = "
+ }
+ line := fmt.Sprintf("%s%s%s%s", strings.Repeat(" ", int(k.leadingSpace)), k.Key, equals, k.Value)
+ if k.Comment != "" {
+ line += " #" + k.Comment
+ }
+ return line
+}
+
+// Empty is a line in the config file that contains only whitespace or comments.
+type Empty struct {
+ Comment string
+ leadingSpace uint16 // TODO handle spaces vs tabs.
+ position Position
+}
+
+// Pos returns e's Position.
+func (e *Empty) Pos() Position {
+ return e.position
+}
+
+// String prints e as it was parsed in the config file.
+func (e *Empty) String() string {
+ if e == nil {
+ return ""
+ }
+ if e.Comment == "" {
+ return ""
+ }
+ return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
+}
+
+// Include holds the result of an Include directive, including the config files
+// that have been parsed as part of that directive. At most 5 levels of Include
+// statements will be parsed.
+type Include struct {
+ // Comment is the contents of any comment at the end of the Include
+ // statement.
+ Comment string
+ parsed bool
+ // an include directive can include several different files, and wildcards
+ directives []string
+
+ mu sync.Mutex
+ // 1:1 mapping between matches and keys in files array; matches preserves
+ // ordering
+ matches []string
+ // actual filenames are listed here
+ files map[string]*Config
+ leadingSpace uint16
+ position Position
+ depth uint8
+ hasEquals bool
+}
+
+const maxRecurseDepth = 5
+
+// ErrDepthExceeded is returned if too many Include directives are parsed.
+// Usually this indicates a recursive loop (an Include directive pointing to the
+// file it contains).
+var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded")
+
+func removeDups(arr []string) []string {
+ // Use map to record duplicates as we find them.
+ encountered := make(map[string]bool, len(arr))
+ result := make([]string, 0)
+
+ for v := range arr {
+ if encountered[arr[v]] == false {
+ encountered[arr[v]] = true
+ result = append(result, arr[v])
+ }
+ }
+ return result
+}
+
+// NewInclude creates a new Include with a list of file globs to include.
+// Configuration files are parsed greedily (e.g. as soon as this function runs).
+// Any error encountered while parsing nested configuration files will be
+// returned.
+func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system bool, depth uint8) (*Include, error) {
+ if depth > maxRecurseDepth {
+ return nil, ErrDepthExceeded
+ }
+ inc := &Include{
+ Comment: comment,
+ directives: directives,
+ files: make(map[string]*Config),
+ position: pos,
+ leadingSpace: uint16(pos.Col) - 1,
+ depth: depth,
+ hasEquals: hasEquals,
+ }
+ // no need for inc.mu.Lock() since nothing else can access this inc
+ matches := make([]string, 0)
+ for i := range directives {
+ var path string
+ if filepath.IsAbs(directives[i]) {
+ path = directives[i]
+ } else if system {
+ path = filepath.Join("/etc/ssh", directives[i])
+ } else {
+ path = filepath.Join(homedir(), ".ssh", directives[i])
+ }
+ theseMatches, err := filepath.Glob(path)
+ if err != nil {
+ return nil, err
+ }
+ matches = append(matches, theseMatches...)
+ }
+ matches = removeDups(matches)
+ inc.matches = matches
+ for i := range matches {
+ config, err := parseWithDepth(matches[i], depth)
+ if err != nil {
+ return nil, err
+ }
+ inc.files[matches[i]] = config
+ }
+ return inc, nil
+}
+
+// Pos returns the position of the Include directive in the larger file.
+func (i *Include) Pos() Position {
+ return i.position
+}
+
+// Get finds the first value in the Include statement matching the alias and the
+// given key.
+func (inc *Include) Get(alias, key string) string {
+ inc.mu.Lock()
+ defer inc.mu.Unlock()
+ // TODO: we search files in any order which is not correct
+ for i := range inc.matches {
+ cfg := inc.files[inc.matches[i]]
+ if cfg == nil {
+ panic("nil cfg")
+ }
+ val, err := cfg.Get(alias, key)
+ if err == nil && val != "" {
+ return val
+ }
+ }
+ return ""
+}
+
+// String prints out a string representation of this Include directive. Note
+// included Config files are not printed as part of this representation.
+func (inc *Include) String() string {
+ equals := " "
+ if inc.hasEquals {
+ equals = " = "
+ }
+ line := fmt.Sprintf("%sInclude%s%s", strings.Repeat(" ", int(inc.leadingSpace)), equals, strings.Join(inc.directives, " "))
+ if inc.Comment != "" {
+ line += " #" + inc.Comment
+ }
+ return line
+}
+
+var matchAll *Pattern
+
+func init() {
+ var err error
+ matchAll, err = NewPattern("*")
+ if err != nil {
+ panic(err)
+ }
+}
+
+func newConfig() *Config {
+ return &Config{
+ Hosts: []*Host{
+ &Host{
+ implicit: true,
+ Patterns: []*Pattern{matchAll},
+ Nodes: make([]Node, 0),
+ },
+ },
+ depth: 0,
+ }
+}
--- /dev/null
+package ssh_config
+
+import (
+ "io"
+
+ buffruneio "github.com/pelletier/go-buffruneio"
+)
+
+// Define state functions
+type sshLexStateFn func() sshLexStateFn
+
+type sshLexer struct {
+ input *buffruneio.Reader // Textual source
+ buffer []rune // Runes composing the current token
+ tokens chan token
+ line uint32
+ col uint16
+ endbufferLine uint32
+ endbufferCol uint16
+}
+
+func (s *sshLexer) lexComment(previousState sshLexStateFn) sshLexStateFn {
+ return func() sshLexStateFn {
+ growingString := ""
+ for next := s.peek(); next != '\n' && next != eof; next = s.peek() {
+ if next == '\r' && s.follow("\r\n") {
+ break
+ }
+ growingString += string(next)
+ s.next()
+ }
+ s.emitWithValue(tokenComment, growingString)
+ s.skip()
+ return previousState
+ }
+}
+
+// lex the space after an equals sign in a function
+func (s *sshLexer) lexRspace() sshLexStateFn {
+ for {
+ next := s.peek()
+ if !isSpace(next) {
+ break
+ }
+ s.skip()
+ }
+ return s.lexRvalue
+}
+
+func (s *sshLexer) lexEquals() sshLexStateFn {
+ for {
+ next := s.peek()
+ if next == '=' {
+ s.emit(tokenEquals)
+ s.skip()
+ return s.lexRspace
+ }
+ // TODO error handling here; newline eof etc.
+ if !isSpace(next) {
+ break
+ }
+ s.skip()
+ }
+ return s.lexRvalue
+}
+
+func (s *sshLexer) lexKey() sshLexStateFn {
+ growingString := ""
+
+ for r := s.peek(); isKeyChar(r); r = s.peek() {
+ // simplified a lot here
+ if isSpace(r) || r == '=' {
+ s.emitWithValue(tokenKey, growingString)
+ s.skip()
+ return s.lexEquals
+ }
+ growingString += string(r)
+ s.next()
+ }
+ s.emitWithValue(tokenKey, growingString)
+ return s.lexEquals
+}
+
+func (s *sshLexer) lexRvalue() sshLexStateFn {
+ growingString := ""
+ for {
+ next := s.peek()
+ switch next {
+ case '\r':
+ if s.follow("\r\n") {
+ s.emitWithValue(tokenString, growingString)
+ s.skip()
+ return s.lexVoid
+ }
+ case '\n':
+ s.emitWithValue(tokenString, growingString)
+ s.skip()
+ return s.lexVoid
+ case '#':
+ s.emitWithValue(tokenString, growingString)
+ s.skip()
+ return s.lexComment(s.lexVoid)
+ case eof:
+ s.next()
+ }
+ if next == eof {
+ break
+ }
+ growingString += string(next)
+ s.next()
+ }
+ s.emit(tokenEOF)
+ return nil
+}
+
+func (s *sshLexer) read() rune {
+ r, _, err := s.input.ReadRune()
+ if err != nil {
+ panic(err)
+ }
+ if r == '\n' {
+ s.endbufferLine++
+ s.endbufferCol = 1
+ } else {
+ s.endbufferCol++
+ }
+ return r
+}
+
+func (s *sshLexer) next() rune {
+ r := s.read()
+
+ if r != eof {
+ s.buffer = append(s.buffer, r)
+ }
+ return r
+}
+
+func (s *sshLexer) lexVoid() sshLexStateFn {
+ for {
+ next := s.peek()
+ switch next {
+ case '#':
+ s.skip()
+ return s.lexComment(s.lexVoid)
+ case '\r':
+ fallthrough
+ case '\n':
+ s.emit(tokenEmptyLine)
+ s.skip()
+ continue
+ }
+
+ if isSpace(next) {
+ s.skip()
+ }
+
+ if isKeyStartChar(next) {
+ return s.lexKey
+ }
+
+ // removed IsKeyStartChar and lexKey. probably will need to readd
+
+ if next == eof {
+ s.next()
+ break
+ }
+ }
+
+ s.emit(tokenEOF)
+ return nil
+}
+
+func (s *sshLexer) ignore() {
+ s.buffer = make([]rune, 0)
+ s.line = s.endbufferLine
+ s.col = s.endbufferCol
+}
+
+func (s *sshLexer) skip() {
+ s.next()
+ s.ignore()
+}
+
+func (s *sshLexer) emit(t tokenType) {
+ s.emitWithValue(t, string(s.buffer))
+}
+
+func (s *sshLexer) emitWithValue(t tokenType, value string) {
+ tok := token{
+ Position: Position{s.line, s.col},
+ typ: t,
+ val: value,
+ }
+ s.tokens <- tok
+ s.ignore()
+}
+
+func (s *sshLexer) peek() rune {
+ r, _, err := s.input.ReadRune()
+ if err != nil {
+ panic(err)
+ }
+ s.input.UnreadRune()
+ return r
+}
+
+func (s *sshLexer) follow(next string) bool {
+ for _, expectedRune := range next {
+ r, _, err := s.input.ReadRune()
+ defer s.input.UnreadRune()
+ if err != nil {
+ panic(err)
+ }
+ if expectedRune != r {
+ return false
+ }
+ }
+ return true
+}
+
+func (s *sshLexer) run() {
+ for state := s.lexVoid; state != nil; {
+ state = state()
+ }
+ close(s.tokens)
+}
+
+func lexSSH(input io.Reader) chan token {
+ bufferedInput := buffruneio.NewReader(input)
+ l := &sshLexer{
+ input: bufferedInput,
+ tokens: make(chan token),
+ line: 1,
+ col: 1,
+ endbufferLine: 1,
+ endbufferCol: 1,
+ }
+ go l.run()
+ return l.tokens
+}
--- /dev/null
+package ssh_config
+
+import (
+ "fmt"
+ "strings"
+)
+
+type sshParser struct {
+ flow chan token
+ config *Config
+ tokensBuffer []token
+ currentTable []string
+ seenTableKeys []string
+ // /etc/ssh parser or local parser - used to find the default for relative
+ // filepaths in the Include directive
+ system bool
+ depth uint8
+}
+
+type sshParserStateFn func() sshParserStateFn
+
+// Formats and panics an error message based on a token
+func (p *sshParser) raiseErrorf(tok *token, msg string, args ...interface{}) {
+ // TODO this format is ugly
+ panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...))
+}
+
+func (p *sshParser) raiseError(tok *token, err error) {
+ if err == ErrDepthExceeded {
+ panic(err)
+ }
+ // TODO this format is ugly
+ panic(tok.Position.String() + ": " + err.Error())
+}
+
+func (p *sshParser) run() {
+ for state := p.parseStart; state != nil; {
+ state = state()
+ }
+}
+
+func (p *sshParser) peek() *token {
+ if len(p.tokensBuffer) != 0 {
+ return &(p.tokensBuffer[0])
+ }
+
+ tok, ok := <-p.flow
+ if !ok {
+ return nil
+ }
+ p.tokensBuffer = append(p.tokensBuffer, tok)
+ return &tok
+}
+
+func (p *sshParser) getToken() *token {
+ if len(p.tokensBuffer) != 0 {
+ tok := p.tokensBuffer[0]
+ p.tokensBuffer = p.tokensBuffer[1:]
+ return &tok
+ }
+ tok, ok := <-p.flow
+ if !ok {
+ return nil
+ }
+ return &tok
+}
+
+func (p *sshParser) parseStart() sshParserStateFn {
+ tok := p.peek()
+
+ // end of stream, parsing is finished
+ if tok == nil {
+ return nil
+ }
+
+ switch tok.typ {
+ case tokenComment, tokenEmptyLine:
+ return p.parseComment
+ case tokenKey:
+ return p.parseKV
+ case tokenEOF:
+ return nil
+ default:
+ p.raiseErrorf(tok, fmt.Sprintf("unexpected token %q\n", tok))
+ }
+ return nil
+}
+
+func (p *sshParser) parseKV() sshParserStateFn {
+ key := p.getToken()
+ hasEquals := false
+ val := p.getToken()
+ if val.typ == tokenEquals {
+ hasEquals = true
+ val = p.getToken()
+ }
+ comment := ""
+ tok := p.peek()
+ if tok == nil {
+ tok = &token{typ: tokenEOF}
+ }
+ if tok.typ == tokenComment && tok.Position.Line == val.Position.Line {
+ tok = p.getToken()
+ comment = tok.val
+ }
+ if strings.ToLower(key.val) == "match" {
+ // https://github.com/kevinburke/ssh_config/issues/6
+ p.raiseErrorf(val, "ssh_config: Match directive parsing is unsupported")
+ return nil
+ }
+ if strings.ToLower(key.val) == "host" {
+ strPatterns := strings.Split(val.val, " ")
+ patterns := make([]*Pattern, 0)
+ for i := range strPatterns {
+ if strPatterns[i] == "" {
+ continue
+ }
+ pat, err := NewPattern(strPatterns[i])
+ if err != nil {
+ p.raiseErrorf(val, "Invalid host pattern: %v", err)
+ return nil
+ }
+ patterns = append(patterns, pat)
+ }
+ p.config.Hosts = append(p.config.Hosts, &Host{
+ Patterns: patterns,
+ Nodes: make([]Node, 0),
+ EOLComment: comment,
+ hasEquals: hasEquals,
+ })
+ return p.parseStart
+ }
+ lastHost := p.config.Hosts[len(p.config.Hosts)-1]
+ if strings.ToLower(key.val) == "include" {
+ inc, err := NewInclude(strings.Split(val.val, " "), hasEquals, key.Position, comment, p.system, p.depth+1)
+ if err == ErrDepthExceeded {
+ p.raiseError(val, err)
+ return nil
+ }
+ if err != nil {
+ p.raiseErrorf(val, "Error parsing Include directive: %v", err)
+ return nil
+ }
+ lastHost.Nodes = append(lastHost.Nodes, inc)
+ return p.parseStart
+ }
+ kv := &KV{
+ Key: key.val,
+ Value: val.val,
+ Comment: comment,
+ hasEquals: hasEquals,
+ leadingSpace: uint16(key.Position.Col) - 1,
+ position: key.Position,
+ }
+ lastHost.Nodes = append(lastHost.Nodes, kv)
+ return p.parseStart
+}
+
+func (p *sshParser) parseComment() sshParserStateFn {
+ comment := p.getToken()
+ lastHost := p.config.Hosts[len(p.config.Hosts)-1]
+ lastHost.Nodes = append(lastHost.Nodes, &Empty{
+ Comment: comment.val,
+ // account for the "#" as well
+ leadingSpace: comment.Position.Col - 2,
+ position: comment.Position,
+ })
+ return p.parseStart
+}
+
+func parseSSH(flow chan token, system bool, depth uint8) *Config {
+ result := newConfig()
+ result.position = Position{1, 1}
+ parser := &sshParser{
+ flow: flow,
+ config: result,
+ tokensBuffer: make([]token, 0),
+ currentTable: make([]string, 0),
+ seenTableKeys: make([]string, 0),
+ system: system,
+ depth: depth,
+ }
+ parser.run()
+ return result
+}
--- /dev/null
+package ssh_config
+
+import "fmt"
+
+// Position of a document element within a SSH document.
+//
+// Line and Col are both 1-indexed positions for the element's line number and
+// column number, respectively. Values of zero or less will cause Invalid(),
+// to return true.
+type Position struct {
+ Line uint32 // line within the document
+ Col uint16 // column within the line
+}
+
+// String representation of the position.
+// Displays 1-indexed line and column numbers.
+func (p Position) String() string {
+ return fmt.Sprintf("(%d, %d)", p.Line, p.Col)
+}
+
+// Invalid returns whether or not the position is valid (i.e. with negative or
+// null values)
+func (p Position) Invalid() bool {
+ return p.Line <= 0 || p.Col <= 0
+}
--- /dev/null
+package ssh_config
+
+import "fmt"
+
+type token struct {
+ Position
+ typ tokenType
+ val string
+}
+
+func (t token) String() string {
+ switch t.typ {
+ case tokenEOF:
+ return "EOF"
+ }
+ return fmt.Sprintf("%q", t.val)
+}
+
+type tokenType int
+
+const (
+ eof = -(iota + 1)
+)
+
+const (
+ tokenError tokenType = iota
+ tokenEOF
+ tokenEmptyLine
+ tokenComment
+ tokenKey
+ tokenEquals
+ tokenString
+)
+
+func isSpace(r rune) bool {
+ return r == ' ' || r == '\t'
+}
+
+func isKeyStartChar(r rune) bool {
+ return !(isSpace(r) || r == '\r' || r == '\n' || r == eof)
+}
+
+// I'm not sure that this is correct
+func isKeyChar(r rune) bool {
+ // Keys start with the first character that isn't whitespace or [ and end
+ // with the last non-whitespace character before the equals sign. Keys
+ // cannot contain a # character."
+ return !(r == '\r' || r == '\n' || r == eof || r == '=')
+}
--- /dev/null
+package ssh_config
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// Default returns the default value for the given keyword, for example "22" if
+// the keyword is "Port". Default returns the empty string if the keyword has no
+// default, or if the keyword is unknown. Keyword matching is case-insensitive.
+//
+// Default values are provided by OpenSSH_7.4p1 on a Mac.
+func Default(keyword string) string {
+ return defaults[strings.ToLower(keyword)]
+}
+
+// Arguments where the value must be "yes" or "no" and *only* yes or no.
+var yesnos = map[string]bool{
+ strings.ToLower("BatchMode"): true,
+ strings.ToLower("CanonicalizeFallbackLocal"): true,
+ strings.ToLower("ChallengeResponseAuthentication"): true,
+ strings.ToLower("CheckHostIP"): true,
+ strings.ToLower("ClearAllForwardings"): true,
+ strings.ToLower("Compression"): true,
+ strings.ToLower("EnableSSHKeysign"): true,
+ strings.ToLower("ExitOnForwardFailure"): true,
+ strings.ToLower("ForwardAgent"): true,
+ strings.ToLower("ForwardX11"): true,
+ strings.ToLower("ForwardX11Trusted"): true,
+ strings.ToLower("GatewayPorts"): true,
+ strings.ToLower("GSSAPIAuthentication"): true,
+ strings.ToLower("GSSAPIDelegateCredentials"): true,
+ strings.ToLower("HostbasedAuthentication"): true,
+ strings.ToLower("IdentitiesOnly"): true,
+ strings.ToLower("KbdInteractiveAuthentication"): true,
+ strings.ToLower("NoHostAuthenticationForLocalhost"): true,
+ strings.ToLower("PasswordAuthentication"): true,
+ strings.ToLower("PermitLocalCommand"): true,
+ strings.ToLower("PubkeyAuthentication"): true,
+ strings.ToLower("RhostsRSAAuthentication"): true,
+ strings.ToLower("RSAAuthentication"): true,
+ strings.ToLower("StreamLocalBindUnlink"): true,
+ strings.ToLower("TCPKeepAlive"): true,
+ strings.ToLower("UseKeychain"): true,
+ strings.ToLower("UsePrivilegedPort"): true,
+ strings.ToLower("VisualHostKey"): true,
+}
+
+var uints = map[string]bool{
+ strings.ToLower("CanonicalizeMaxDots"): true,
+ strings.ToLower("CompressionLevel"): true, // 1 to 9
+ strings.ToLower("ConnectionAttempts"): true,
+ strings.ToLower("ConnectTimeout"): true,
+ strings.ToLower("NumberOfPasswordPrompts"): true,
+ strings.ToLower("Port"): true,
+ strings.ToLower("ServerAliveCountMax"): true,
+ strings.ToLower("ServerAliveInterval"): true,
+}
+
+func mustBeYesOrNo(lkey string) bool {
+ return yesnos[lkey]
+}
+
+func mustBeUint(lkey string) bool {
+ return uints[lkey]
+}
+
+func validate(key, val string) error {
+ lkey := strings.ToLower(key)
+ if mustBeYesOrNo(lkey) && (val != "yes" && val != "no") {
+ return fmt.Errorf("ssh_config: value for key %q must be 'yes' or 'no', got %q", key, val)
+ }
+ if mustBeUint(lkey) {
+ _, err := strconv.ParseUint(val, 10, 64)
+ if err != nil {
+ return fmt.Errorf("ssh_config: %v", err)
+ }
+ }
+ return nil
+}
+
+var defaults = map[string]string{
+ strings.ToLower("AddKeysToAgent"): "no",
+ strings.ToLower("AddressFamily"): "any",
+ strings.ToLower("BatchMode"): "no",
+ strings.ToLower("CanonicalizeFallbackLocal"): "yes",
+ strings.ToLower("CanonicalizeHostname"): "no",
+ strings.ToLower("CanonicalizeMaxDots"): "1",
+ strings.ToLower("ChallengeResponseAuthentication"): "yes",
+ strings.ToLower("CheckHostIP"): "yes",
+ // TODO is this still the correct cipher
+ strings.ToLower("Cipher"): "3des",
+ strings.ToLower("Ciphers"): "chacha20-poly1305@openssh.com,aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,aes128-cbc,aes192-cbc,aes256-cbc",
+ strings.ToLower("ClearAllForwardings"): "no",
+ strings.ToLower("Compression"): "no",
+ strings.ToLower("CompressionLevel"): "6",
+ strings.ToLower("ConnectionAttempts"): "1",
+ strings.ToLower("ControlMaster"): "no",
+ strings.ToLower("EnableSSHKeysign"): "no",
+ strings.ToLower("EscapeChar"): "~",
+ strings.ToLower("ExitOnForwardFailure"): "no",
+ strings.ToLower("FingerprintHash"): "sha256",
+ strings.ToLower("ForwardAgent"): "no",
+ strings.ToLower("ForwardX11"): "no",
+ strings.ToLower("ForwardX11Timeout"): "20m",
+ strings.ToLower("ForwardX11Trusted"): "no",
+ strings.ToLower("GatewayPorts"): "no",
+ strings.ToLower("GlobalKnownHostsFile"): "/etc/ssh/ssh_known_hosts /etc/ssh/ssh_known_hosts2",
+ strings.ToLower("GSSAPIAuthentication"): "no",
+ strings.ToLower("GSSAPIDelegateCredentials"): "no",
+ strings.ToLower("HashKnownHosts"): "no",
+ strings.ToLower("HostbasedAuthentication"): "no",
+
+ strings.ToLower("HostbasedKeyTypes"): "ecdsa-sha2-nistp256-cert-v01@openssh.com,ecdsa-sha2-nistp384-cert-v01@openssh.com,ecdsa-sha2-nistp521-cert-v01@openssh.com,ssh-ed25519-cert-v01@openssh.com,ssh-rsa-cert-v01@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519,ssh-rsa",
+ strings.ToLower("HostKeyAlgorithms"): "ecdsa-sha2-nistp256-cert-v01@openssh.com,ecdsa-sha2-nistp384-cert-v01@openssh.com,ecdsa-sha2-nistp521-cert-v01@openssh.com,ssh-ed25519-cert-v01@openssh.com,ssh-rsa-cert-v01@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519,ssh-rsa",
+ // HostName has a dynamic default (the value passed at the command line).
+
+ strings.ToLower("IdentitiesOnly"): "no",
+ strings.ToLower("IdentityFile"): "~/.ssh/identity",
+
+ // IPQoS has a dynamic default based on interactive or non-interactive
+ // sessions.
+
+ strings.ToLower("KbdInteractiveAuthentication"): "yes",
+
+ strings.ToLower("KexAlgorithms"): "curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group-exchange-sha1,diffie-hellman-group14-sha1",
+ strings.ToLower("LogLevel"): "INFO",
+ strings.ToLower("MACs"): "umac-64-etm@openssh.com,umac-128-etm@openssh.com,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha1-etm@openssh.com,umac-64@openssh.com,umac-128@openssh.com,hmac-sha2-256,hmac-sha2-512,hmac-sha1",
+
+ strings.ToLower("NoHostAuthenticationForLocalhost"): "no",
+ strings.ToLower("NumberOfPasswordPrompts"): "3",
+ strings.ToLower("PasswordAuthentication"): "yes",
+ strings.ToLower("PermitLocalCommand"): "no",
+ strings.ToLower("Port"): "22",
+
+ strings.ToLower("PreferredAuthentications"): "gssapi-with-mic,hostbased,publickey,keyboard-interactive,password",
+ strings.ToLower("Protocol"): "2",
+ strings.ToLower("ProxyUseFdpass"): "no",
+ strings.ToLower("PubkeyAcceptedKeyTypes"): "ecdsa-sha2-nistp256-cert-v01@openssh.com,ecdsa-sha2-nistp384-cert-v01@openssh.com,ecdsa-sha2-nistp521-cert-v01@openssh.com,ssh-ed25519-cert-v01@openssh.com,ssh-rsa-cert-v01@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519,ssh-rsa",
+ strings.ToLower("PubkeyAuthentication"): "yes",
+ strings.ToLower("RekeyLimit"): "default none",
+ strings.ToLower("RhostsRSAAuthentication"): "no",
+ strings.ToLower("RSAAuthentication"): "yes",
+
+ strings.ToLower("ServerAliveCountMax"): "3",
+ strings.ToLower("ServerAliveInterval"): "0",
+ strings.ToLower("StreamLocalBindMask"): "0177",
+ strings.ToLower("StreamLocalBindUnlink"): "no",
+ strings.ToLower("StrictHostKeyChecking"): "ask",
+ strings.ToLower("TCPKeepAlive"): "yes",
+ strings.ToLower("Tunnel"): "no",
+ strings.ToLower("TunnelDevice"): "any:any",
+ strings.ToLower("UpdateHostKeys"): "no",
+ strings.ToLower("UseKeychain"): "no",
+ strings.ToLower("UsePrivilegedPort"): "no",
+
+ strings.ToLower("UserKnownHostsFile"): "~/.ssh/known_hosts ~/.ssh/known_hosts2",
+ strings.ToLower("VerifyHostKeyDNS"): "no",
+ strings.ToLower("VisualHostKey"): "no",
+ strings.ToLower("XAuthLocation"): "/usr/X11R6/bin/xauth",
+}
--- /dev/null
+The MIT License (MIT)
+
+Copyright (c) 2013 Mitchell Hashimoto
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
--- /dev/null
+package homedir
+
+import (
+ "bytes"
+ "errors"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+// DisableCache will disable caching of the home directory. Caching is enabled
+// by default.
+var DisableCache bool
+
+var homedirCache string
+var cacheLock sync.RWMutex
+
+// Dir returns the home directory for the executing user.
+//
+// This uses an OS-specific method for discovering the home directory.
+// An error is returned if a home directory cannot be detected.
+func Dir() (string, error) {
+ if !DisableCache {
+ cacheLock.RLock()
+ cached := homedirCache
+ cacheLock.RUnlock()
+ if cached != "" {
+ return cached, nil
+ }
+ }
+
+ cacheLock.Lock()
+ defer cacheLock.Unlock()
+
+ var result string
+ var err error
+ if runtime.GOOS == "windows" {
+ result, err = dirWindows()
+ } else {
+ // Unix-like system, so just assume Unix
+ result, err = dirUnix()
+ }
+
+ if err != nil {
+ return "", err
+ }
+ homedirCache = result
+ return result, nil
+}
+
+// Expand expands the path to include the home directory if the path
+// is prefixed with `~`. If it isn't prefixed with `~`, the path is
+// returned as-is.
+func Expand(path string) (string, error) {
+ if len(path) == 0 {
+ return path, nil
+ }
+
+ if path[0] != '~' {
+ return path, nil
+ }
+
+ if len(path) > 1 && path[1] != '/' && path[1] != '\\' {
+ return "", errors.New("cannot expand user-specific home dir")
+ }
+
+ dir, err := Dir()
+ if err != nil {
+ return "", err
+ }
+
+ return filepath.Join(dir, path[1:]), nil
+}
+
+func dirUnix() (string, error) {
+ homeEnv := "HOME"
+ if runtime.GOOS == "plan9" {
+ // On plan9, env vars are lowercase.
+ homeEnv = "home"
+ }
+
+ // First prefer the HOME environmental variable
+ if home := os.Getenv(homeEnv); home != "" {
+ return home, nil
+ }
+
+ var stdout bytes.Buffer
+
+ // If that fails, try OS specific commands
+ if runtime.GOOS == "darwin" {
+ cmd := exec.Command("sh", "-c", `dscl -q . -read /Users/"$(whoami)" NFSHomeDirectory | sed 's/^[^ ]*: //'`)
+ cmd.Stdout = &stdout
+ if err := cmd.Run(); err == nil {
+ result := strings.TrimSpace(stdout.String())
+ if result != "" {
+ return result, nil
+ }
+ }
+ } else {
+ cmd := exec.Command("getent", "passwd", strconv.Itoa(os.Getuid()))
+ cmd.Stdout = &stdout
+ if err := cmd.Run(); err != nil {
+ // If the error is ErrNotFound, we ignore it. Otherwise, return it.
+ if err != exec.ErrNotFound {
+ return "", err
+ }
+ } else {
+ if passwd := strings.TrimSpace(stdout.String()); passwd != "" {
+ // username:password:uid:gid:gecos:home:shell
+ passwdParts := strings.SplitN(passwd, ":", 7)
+ if len(passwdParts) > 5 {
+ return passwdParts[5], nil
+ }
+ }
+ }
+ }
+
+ // If all else fails, try the shell
+ stdout.Reset()
+ cmd := exec.Command("sh", "-c", "cd && pwd")
+ cmd.Stdout = &stdout
+ if err := cmd.Run(); err != nil {
+ return "", err
+ }
+
+ result := strings.TrimSpace(stdout.String())
+ if result == "" {
+ return "", errors.New("blank output when reading home directory")
+ }
+
+ return result, nil
+}
+
+func dirWindows() (string, error) {
+ // First prefer the HOME environmental variable
+ if home := os.Getenv("HOME"); home != "" {
+ return home, nil
+ }
+
+ // Prefer standard environment variable USERPROFILE
+ if home := os.Getenv("USERPROFILE"); home != "" {
+ return home, nil
+ }
+
+ drive := os.Getenv("HOMEDRIVE")
+ path := os.Getenv("HOMEPATH")
+ home := drive + path
+ if drive == "" || path == "" {
+ return "", errors.New("HOMEDRIVE, HOMEPATH, or USERPROFILE are blank")
+ }
+
+ return home, nil
+}
--- /dev/null
+// Package buffruneio is a wrapper around bufio to provide buffered runes access with unlimited unreads.
+package buffruneio
+
+import (
+ "bufio"
+ "container/list"
+ "errors"
+ "io"
+)
+
+// Rune to indicate end of file.
+const (
+ EOF = -(iota + 1)
+)
+
+// ErrNoRuneToUnread is returned by UnreadRune() when the read index is already at the beginning of the buffer.
+var ErrNoRuneToUnread = errors.New("no rune to unwind")
+
+// Reader implements runes buffering for an io.Reader object.
+type Reader struct {
+ buffer *list.List
+ current *list.Element
+ input *bufio.Reader
+}
+
+// NewReader returns a new Reader.
+func NewReader(rd io.Reader) *Reader {
+ return &Reader{
+ buffer: list.New(),
+ input: bufio.NewReader(rd),
+ }
+}
+
+type runeWithSize struct {
+ r rune
+ size int
+}
+
+func (rd *Reader) feedBuffer() error {
+ r, size, err := rd.input.ReadRune()
+
+ if err != nil {
+ if err != io.EOF {
+ return err
+ }
+ r = EOF
+ }
+
+ newRuneWithSize := runeWithSize{r, size}
+
+ rd.buffer.PushBack(newRuneWithSize)
+ if rd.current == nil {
+ rd.current = rd.buffer.Back()
+ }
+ return nil
+}
+
+// ReadRune reads the next rune from buffer, or from the underlying reader if needed.
+func (rd *Reader) ReadRune() (rune, int, error) {
+ if rd.current == rd.buffer.Back() || rd.current == nil {
+ err := rd.feedBuffer()
+ if err != nil {
+ return EOF, 0, err
+ }
+ }
+
+ runeWithSize := rd.current.Value.(runeWithSize)
+ rd.current = rd.current.Next()
+ return runeWithSize.r, runeWithSize.size, nil
+}
+
+// UnreadRune pushes back the previously read rune in the buffer, extending it if needed.
+func (rd *Reader) UnreadRune() error {
+ if rd.current == rd.buffer.Front() {
+ return ErrNoRuneToUnread
+ }
+ if rd.current == nil {
+ rd.current = rd.buffer.Back()
+ } else {
+ rd.current = rd.current.Prev()
+ }
+ return nil
+}
+
+// Forget removes runes stored before the current stream position index.
+func (rd *Reader) Forget() {
+ if rd.current == nil {
+ rd.current = rd.buffer.Back()
+ }
+ for ; rd.current != rd.buffer.Front(); rd.buffer.Remove(rd.current.Prev()) {
+ }
+}
+
+// PeekRune returns at most the next n runes, reading from the uderlying source if
+// needed. Does not move the current index. It includes EOF if reached.
+func (rd *Reader) PeekRunes(n int) []rune {
+ res := make([]rune, 0, n)
+ cursor := rd.current
+ for i := 0; i < n; i++ {
+ if cursor == nil {
+ err := rd.feedBuffer()
+ if err != nil {
+ return res
+ }
+ cursor = rd.buffer.Back()
+ }
+ if cursor != nil {
+ r := cursor.Value.(runeWithSize).r
+ res = append(res, r)
+ if r == EOF {
+ return res
+ }
+ cursor = cursor.Next()
+ }
+ }
+ return res
+}
--- /dev/null
+Copyright (c) 2012 Péter Surányi. Portions Copyright (c) 2009 The Go
+Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--- /dev/null
+// Package gcfg reads "INI-style" text-based configuration files with
+// "name=value" pairs grouped into sections (gcfg files).
+//
+// This package is still a work in progress; see the sections below for planned
+// changes.
+//
+// Syntax
+//
+// The syntax is based on that used by git config:
+// http://git-scm.com/docs/git-config#_syntax .
+// There are some (planned) differences compared to the git config format:
+// - improve data portability:
+// - must be encoded in UTF-8 (for now) and must not contain the 0 byte
+// - include and "path" type is not supported
+// (path type may be implementable as a user-defined type)
+// - internationalization
+// - section and variable names can contain unicode letters, unicode digits
+// (as defined in http://golang.org/ref/spec#Characters ) and hyphens
+// (U+002D), starting with a unicode letter
+// - disallow potentially ambiguous or misleading definitions:
+// - `[sec.sub]` format is not allowed (deprecated in gitconfig)
+// - `[sec ""]` is not allowed
+// - use `[sec]` for section name "sec" and empty subsection name
+// - (planned) within a single file, definitions must be contiguous for each:
+// - section: '[secA]' -> '[secB]' -> '[secA]' is an error
+// - subsection: '[sec "A"]' -> '[sec "B"]' -> '[sec "A"]' is an error
+// - multivalued variable: 'multi=a' -> 'other=x' -> 'multi=b' is an error
+//
+// Data structure
+//
+// The functions in this package read values into a user-defined struct.
+// Each section corresponds to a struct field in the config struct, and each
+// variable in a section corresponds to a data field in the section struct.
+// The mapping of each section or variable name to fields is done either based
+// on the "gcfg" struct tag or by matching the name of the section or variable,
+// ignoring case. In the latter case, hyphens '-' in section and variable names
+// correspond to underscores '_' in field names.
+// Fields must be exported; to use a section or variable name starting with a
+// letter that is neither upper- or lower-case, prefix the field name with 'X'.
+// (See https://code.google.com/p/go/issues/detail?id=5763#c4 .)
+//
+// For sections with subsections, the corresponding field in config must be a
+// map, rather than a struct, with string keys and pointer-to-struct values.
+// Values for subsection variables are stored in the map with the subsection
+// name used as the map key.
+// (Note that unlike section and variable names, subsection names are case
+// sensitive.)
+// When using a map, and there is a section with the same section name but
+// without a subsection name, its values are stored with the empty string used
+// as the key.
+// It is possible to provide default values for subsections in the section
+// "default-<sectionname>" (or by setting values in the corresponding struct
+// field "Default_<sectionname>").
+//
+// The functions in this package panic if config is not a pointer to a struct,
+// or when a field is not of a suitable type (either a struct or a map with
+// string keys and pointer-to-struct values).
+//
+// Parsing of values
+//
+// The section structs in the config struct may contain single-valued or
+// multi-valued variables. Variables of unnamed slice type (that is, a type
+// starting with `[]`) are treated as multi-value; all others (including named
+// slice types) are treated as single-valued variables.
+//
+// Single-valued variables are handled based on the type as follows.
+// Unnamed pointer types (that is, types starting with `*`) are dereferenced,
+// and if necessary, a new instance is allocated.
+//
+// For types implementing the encoding.TextUnmarshaler interface, the
+// UnmarshalText method is used to set the value. Implementing this method is
+// the recommended way for parsing user-defined types.
+//
+// For fields of string kind, the value string is assigned to the field, after
+// unquoting and unescaping as needed.
+// For fields of bool kind, the field is set to true if the value is "true",
+// "yes", "on" or "1", and set to false if the value is "false", "no", "off" or
+// "0", ignoring case. In addition, single-valued bool fields can be specified
+// with a "blank" value (variable name without equals sign and value); in such
+// case the value is set to true.
+//
+// Predefined integer types [u]int(|8|16|32|64) and big.Int are parsed as
+// decimal or hexadecimal (if having '0x' prefix). (This is to prevent
+// unintuitively handling zero-padded numbers as octal.) Other types having
+// [u]int* as the underlying type, such as os.FileMode and uintptr allow
+// decimal, hexadecimal, or octal values.
+// Parsing mode for integer types can be overridden using the struct tag option
+// ",int=mode" where mode is a combination of the 'd', 'h', and 'o' characters
+// (each standing for decimal, hexadecimal, and octal, respectively.)
+//
+// All other types are parsed using fmt.Sscanf with the "%v" verb.
+//
+// For multi-valued variables, each individual value is parsed as above and
+// appended to the slice. If the first value is specified as a "blank" value
+// (variable name without equals sign and value), a new slice is allocated;
+// that is any values previously set in the slice will be ignored.
+//
+// The types subpackage for provides helpers for parsing "enum-like" and integer
+// types.
+//
+// Error handling
+//
+// There are 3 types of errors:
+//
+// - programmer errors / panics:
+// - invalid configuration structure
+// - data errors:
+// - fatal errors:
+// - invalid configuration syntax
+// - warnings:
+// - data that doesn't belong to any part of the config structure
+//
+// Programmer errors trigger panics. These are should be fixed by the programmer
+// before releasing code that uses gcfg.
+//
+// Data errors cause gcfg to return a non-nil error value. This includes the
+// case when there are extra unknown key-value definitions in the configuration
+// data (extra data).
+// However, in some occasions it is desirable to be able to proceed in
+// situations when the only data error is that of extra data.
+// These errors are handled at a different (warning) priority and can be
+// filtered out programmatically. To ignore extra data warnings, wrap the
+// gcfg.Read*Into invocation into a call to gcfg.FatalOnly.
+//
+// TODO
+//
+// The following is a list of changes under consideration:
+// - documentation
+// - self-contained syntax documentation
+// - more practical examples
+// - move TODOs to issue tracker (eventually)
+// - syntax
+// - reconsider valid escape sequences
+// (gitconfig doesn't support \r in value, \t in subsection name, etc.)
+// - reading / parsing gcfg files
+// - define internal representation structure
+// - support multiple inputs (readers, strings, files)
+// - support declaring encoding (?)
+// - support varying fields sets for subsections (?)
+// - writing gcfg files
+// - error handling
+// - make error context accessible programmatically?
+// - limit input size?
+//
+package gcfg // import "github.com/src-d/gcfg"
--- /dev/null
+package gcfg
+
+import (
+ "gopkg.in/warnings.v0"
+)
+
+// FatalOnly filters the results of a Read*Into invocation and returns only
+// fatal errors. That is, errors (warnings) indicating data for unknown
+// sections / variables is ignored. Example invocation:
+//
+// err := gcfg.FatalOnly(gcfg.ReadFileInto(&cfg, configFile))
+// if err != nil {
+// ...
+//
+func FatalOnly(err error) error {
+ return warnings.FatalOnly(err)
+}
+
+func isFatal(err error) bool {
+ _, ok := err.(extraData)
+ return !ok
+}
+
+type extraData struct {
+ section string
+ subsection *string
+ variable *string
+}
+
+func (e extraData) Error() string {
+ s := "can't store data at section \"" + e.section + "\""
+ if e.subsection != nil {
+ s += ", subsection \"" + *e.subsection + "\""
+ }
+ if e.variable != nil {
+ s += ", variable \"" + *e.variable + "\""
+ }
+ return s
+}
+
+var _ error = extraData{}
--- /dev/null
+// +build !go1.2
+
+package gcfg
+
+type textUnmarshaler interface {
+ UnmarshalText(text []byte) error
+}
--- /dev/null
+// +build go1.2
+
+package gcfg
+
+import (
+ "encoding"
+)
+
+type textUnmarshaler encoding.TextUnmarshaler
--- /dev/null
+package gcfg
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "strings"
+
+ "github.com/src-d/gcfg/scanner"
+ "github.com/src-d/gcfg/token"
+ "gopkg.in/warnings.v0"
+)
+
+var unescape = map[rune]rune{'\\': '\\', '"': '"', 'n': '\n', 't': '\t', 'b': '\b'}
+
+// no error: invalid literals should be caught by scanner
+func unquote(s string) string {
+ u, q, esc := make([]rune, 0, len(s)), false, false
+ for _, c := range s {
+ if esc {
+ uc, ok := unescape[c]
+ switch {
+ case ok:
+ u = append(u, uc)
+ fallthrough
+ case !q && c == '\n':
+ esc = false
+ continue
+ }
+ panic("invalid escape sequence")
+ }
+ switch c {
+ case '"':
+ q = !q
+ case '\\':
+ esc = true
+ default:
+ u = append(u, c)
+ }
+ }
+ if q {
+ panic("missing end quote")
+ }
+ if esc {
+ panic("invalid escape sequence")
+ }
+ return string(u)
+}
+
+func read(c *warnings.Collector, callback func(string, string, string, string, bool) error,
+ fset *token.FileSet, file *token.File, src []byte) error {
+ //
+ var s scanner.Scanner
+ var errs scanner.ErrorList
+ s.Init(file, src, func(p token.Position, m string) { errs.Add(p, m) }, 0)
+ sect, sectsub := "", ""
+ pos, tok, lit := s.Scan()
+ errfn := func(msg string) error {
+ return fmt.Errorf("%s: %s", fset.Position(pos), msg)
+ }
+ for {
+ if errs.Len() > 0 {
+ if err := c.Collect(errs.Err()); err != nil {
+ return err
+ }
+ }
+ switch tok {
+ case token.EOF:
+ return nil
+ case token.EOL, token.COMMENT:
+ pos, tok, lit = s.Scan()
+ case token.LBRACK:
+ pos, tok, lit = s.Scan()
+ if errs.Len() > 0 {
+ if err := c.Collect(errs.Err()); err != nil {
+ return err
+ }
+ }
+ if tok != token.IDENT {
+ if err := c.Collect(errfn("expected section name")); err != nil {
+ return err
+ }
+ }
+ sect, sectsub = lit, ""
+ pos, tok, lit = s.Scan()
+ if errs.Len() > 0 {
+ if err := c.Collect(errs.Err()); err != nil {
+ return err
+ }
+ }
+ if tok == token.STRING {
+ sectsub = unquote(lit)
+ if sectsub == "" {
+ if err := c.Collect(errfn("empty subsection name")); err != nil {
+ return err
+ }
+ }
+ pos, tok, lit = s.Scan()
+ if errs.Len() > 0 {
+ if err := c.Collect(errs.Err()); err != nil {
+ return err
+ }
+ }
+ }
+ if tok != token.RBRACK {
+ if sectsub == "" {
+ if err := c.Collect(errfn("expected subsection name or right bracket")); err != nil {
+ return err
+ }
+ }
+ if err := c.Collect(errfn("expected right bracket")); err != nil {
+ return err
+ }
+ }
+ pos, tok, lit = s.Scan()
+ if tok != token.EOL && tok != token.EOF && tok != token.COMMENT {
+ if err := c.Collect(errfn("expected EOL, EOF, or comment")); err != nil {
+ return err
+ }
+ }
+ // If a section/subsection header was found, ensure a
+ // container object is created, even if there are no
+ // variables further down.
+ err := c.Collect(callback(sect, sectsub, "", "", true))
+ if err != nil {
+ return err
+ }
+ case token.IDENT:
+ if sect == "" {
+ if err := c.Collect(errfn("expected section header")); err != nil {
+ return err
+ }
+ }
+ n := lit
+ pos, tok, lit = s.Scan()
+ if errs.Len() > 0 {
+ return errs.Err()
+ }
+ blank, v := tok == token.EOF || tok == token.EOL || tok == token.COMMENT, ""
+ if !blank {
+ if tok != token.ASSIGN {
+ if err := c.Collect(errfn("expected '='")); err != nil {
+ return err
+ }
+ }
+ pos, tok, lit = s.Scan()
+ if errs.Len() > 0 {
+ if err := c.Collect(errs.Err()); err != nil {
+ return err
+ }
+ }
+ if tok != token.STRING {
+ if err := c.Collect(errfn("expected value")); err != nil {
+ return err
+ }
+ }
+ v = unquote(lit)
+ pos, tok, lit = s.Scan()
+ if errs.Len() > 0 {
+ if err := c.Collect(errs.Err()); err != nil {
+ return err
+ }
+ }
+ if tok != token.EOL && tok != token.EOF && tok != token.COMMENT {
+ if err := c.Collect(errfn("expected EOL, EOF, or comment")); err != nil {
+ return err
+ }
+ }
+ }
+ err := c.Collect(callback(sect, sectsub, n, v, blank))
+ if err != nil {
+ return err
+ }
+ default:
+ if sect == "" {
+ if err := c.Collect(errfn("expected section header")); err != nil {
+ return err
+ }
+ }
+ if err := c.Collect(errfn("expected section header or variable declaration")); err != nil {
+ return err
+ }
+ }
+ }
+ panic("never reached")
+}
+
+func readInto(config interface{}, fset *token.FileSet, file *token.File,
+ src []byte) error {
+ //
+ c := warnings.NewCollector(isFatal)
+ firstPassCallback := func(s string, ss string, k string, v string, bv bool) error {
+ return set(c, config, s, ss, k, v, bv, false)
+ }
+ err := read(c, firstPassCallback, fset, file, src)
+ if err != nil {
+ return err
+ }
+ secondPassCallback := func(s string, ss string, k string, v string, bv bool) error {
+ return set(c, config, s, ss, k, v, bv, true)
+ }
+ err = read(c, secondPassCallback, fset, file, src)
+ if err != nil {
+ return err
+ }
+ return c.Done()
+}
+
+// ReadWithCallback reads gcfg formatted data from reader and calls
+// callback with each section and option found.
+//
+// Callback is called with section, subsection, option key, option value
+// and blank value flag as arguments.
+//
+// When a section is found, callback is called with nil subsection, option key
+// and option value.
+//
+// When a subsection is found, callback is called with nil option key and
+// option value.
+//
+// If blank value flag is true, it means that the value was not set for an option
+// (as opposed to set to empty string).
+//
+// If callback returns an error, ReadWithCallback terminates with an error too.
+func ReadWithCallback(reader io.Reader, callback func(string, string, string, string, bool) error) error {
+ src, err := ioutil.ReadAll(reader)
+ if err != nil {
+ return err
+ }
+
+ fset := token.NewFileSet()
+ file := fset.AddFile("", fset.Base(), len(src))
+ c := warnings.NewCollector(isFatal)
+
+ return read(c, callback, fset, file, src)
+}
+
+// ReadInto reads gcfg formatted data from reader and sets the values into the
+// corresponding fields in config.
+func ReadInto(config interface{}, reader io.Reader) error {
+ src, err := ioutil.ReadAll(reader)
+ if err != nil {
+ return err
+ }
+ fset := token.NewFileSet()
+ file := fset.AddFile("", fset.Base(), len(src))
+ return readInto(config, fset, file, src)
+}
+
+// ReadStringInto reads gcfg formatted data from str and sets the values into
+// the corresponding fields in config.
+func ReadStringInto(config interface{}, str string) error {
+ r := strings.NewReader(str)
+ return ReadInto(config, r)
+}
+
+// ReadFileInto reads gcfg formatted data from the file filename and sets the
+// values into the corresponding fields in config.
+func ReadFileInto(config interface{}, filename string) error {
+ f, err := os.Open(filename)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ src, err := ioutil.ReadAll(f)
+ if err != nil {
+ return err
+ }
+ fset := token.NewFileSet()
+ file := fset.AddFile(filename, fset.Base(), len(src))
+ return readInto(config, fset, file, src)
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package scanner
+
+import (
+ "fmt"
+ "io"
+ "sort"
+)
+
+import (
+ "github.com/src-d/gcfg/token"
+)
+
+// In an ErrorList, an error is represented by an *Error.
+// The position Pos, if valid, points to the beginning of
+// the offending token, and the error condition is described
+// by Msg.
+//
+type Error struct {
+ Pos token.Position
+ Msg string
+}
+
+// Error implements the error interface.
+func (e Error) Error() string {
+ if e.Pos.Filename != "" || e.Pos.IsValid() {
+ // don't print "<unknown position>"
+ // TODO(gri) reconsider the semantics of Position.IsValid
+ return e.Pos.String() + ": " + e.Msg
+ }
+ return e.Msg
+}
+
+// ErrorList is a list of *Errors.
+// The zero value for an ErrorList is an empty ErrorList ready to use.
+//
+type ErrorList []*Error
+
+// Add adds an Error with given position and error message to an ErrorList.
+func (p *ErrorList) Add(pos token.Position, msg string) {
+ *p = append(*p, &Error{pos, msg})
+}
+
+// Reset resets an ErrorList to no errors.
+func (p *ErrorList) Reset() { *p = (*p)[0:0] }
+
+// ErrorList implements the sort Interface.
+func (p ErrorList) Len() int { return len(p) }
+func (p ErrorList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+
+func (p ErrorList) Less(i, j int) bool {
+ e := &p[i].Pos
+ f := &p[j].Pos
+ if e.Filename < f.Filename {
+ return true
+ }
+ if e.Filename == f.Filename {
+ return e.Offset < f.Offset
+ }
+ return false
+}
+
+// Sort sorts an ErrorList. *Error entries are sorted by position,
+// other errors are sorted by error message, and before any *Error
+// entry.
+//
+func (p ErrorList) Sort() {
+ sort.Sort(p)
+}
+
+// RemoveMultiples sorts an ErrorList and removes all but the first error per line.
+func (p *ErrorList) RemoveMultiples() {
+ sort.Sort(p)
+ var last token.Position // initial last.Line is != any legal error line
+ i := 0
+ for _, e := range *p {
+ if e.Pos.Filename != last.Filename || e.Pos.Line != last.Line {
+ last = e.Pos
+ (*p)[i] = e
+ i++
+ }
+ }
+ (*p) = (*p)[0:i]
+}
+
+// An ErrorList implements the error interface.
+func (p ErrorList) Error() string {
+ switch len(p) {
+ case 0:
+ return "no errors"
+ case 1:
+ return p[0].Error()
+ }
+ return fmt.Sprintf("%s (and %d more errors)", p[0], len(p)-1)
+}
+
+// Err returns an error equivalent to this error list.
+// If the list is empty, Err returns nil.
+func (p ErrorList) Err() error {
+ if len(p) == 0 {
+ return nil
+ }
+ return p
+}
+
+// PrintError is a utility function that prints a list of errors to w,
+// one error per line, if the err parameter is an ErrorList. Otherwise
+// it prints the err string.
+//
+func PrintError(w io.Writer, err error) {
+ if list, ok := err.(ErrorList); ok {
+ for _, e := range list {
+ fmt.Fprintf(w, "%s\n", e)
+ }
+ } else if err != nil {
+ fmt.Fprintf(w, "%s\n", err)
+ }
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package scanner implements a scanner for gcfg configuration text.
+// It takes a []byte as source which can then be tokenized
+// through repeated calls to the Scan method.
+//
+// Note that the API for the scanner package may change to accommodate new
+// features or implementation changes in gcfg.
+//
+package scanner
+
+import (
+ "fmt"
+ "path/filepath"
+ "unicode"
+ "unicode/utf8"
+)
+
+import (
+ "github.com/src-d/gcfg/token"
+)
+
+// An ErrorHandler may be provided to Scanner.Init. If a syntax error is
+// encountered and a handler was installed, the handler is called with a
+// position and an error message. The position points to the beginning of
+// the offending token.
+//
+type ErrorHandler func(pos token.Position, msg string)
+
+// A Scanner holds the scanner's internal state while processing
+// a given text. It can be allocated as part of another data
+// structure but must be initialized via Init before use.
+//
+type Scanner struct {
+ // immutable state
+ file *token.File // source file handle
+ dir string // directory portion of file.Name()
+ src []byte // source
+ err ErrorHandler // error reporting; or nil
+ mode Mode // scanning mode
+
+ // scanning state
+ ch rune // current character
+ offset int // character offset
+ rdOffset int // reading offset (position after current character)
+ lineOffset int // current line offset
+ nextVal bool // next token is expected to be a value
+
+ // public state - ok to modify
+ ErrorCount int // number of errors encountered
+}
+
+// Read the next Unicode char into s.ch.
+// s.ch < 0 means end-of-file.
+//
+func (s *Scanner) next() {
+ if s.rdOffset < len(s.src) {
+ s.offset = s.rdOffset
+ if s.ch == '\n' {
+ s.lineOffset = s.offset
+ s.file.AddLine(s.offset)
+ }
+ r, w := rune(s.src[s.rdOffset]), 1
+ switch {
+ case r == 0:
+ s.error(s.offset, "illegal character NUL")
+ case r >= 0x80:
+ // not ASCII
+ r, w = utf8.DecodeRune(s.src[s.rdOffset:])
+ if r == utf8.RuneError && w == 1 {
+ s.error(s.offset, "illegal UTF-8 encoding")
+ }
+ }
+ s.rdOffset += w
+ s.ch = r
+ } else {
+ s.offset = len(s.src)
+ if s.ch == '\n' {
+ s.lineOffset = s.offset
+ s.file.AddLine(s.offset)
+ }
+ s.ch = -1 // eof
+ }
+}
+
+// A mode value is a set of flags (or 0).
+// They control scanner behavior.
+//
+type Mode uint
+
+const (
+ ScanComments Mode = 1 << iota // return comments as COMMENT tokens
+)
+
+// Init prepares the scanner s to tokenize the text src by setting the
+// scanner at the beginning of src. The scanner uses the file set file
+// for position information and it adds line information for each line.
+// It is ok to re-use the same file when re-scanning the same file as
+// line information which is already present is ignored. Init causes a
+// panic if the file size does not match the src size.
+//
+// Calls to Scan will invoke the error handler err if they encounter a
+// syntax error and err is not nil. Also, for each error encountered,
+// the Scanner field ErrorCount is incremented by one. The mode parameter
+// determines how comments are handled.
+//
+// Note that Init may call err if there is an error in the first character
+// of the file.
+//
+func (s *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode Mode) {
+ // Explicitly initialize all fields since a scanner may be reused.
+ if file.Size() != len(src) {
+ panic(fmt.Sprintf("file size (%d) does not match src len (%d)", file.Size(), len(src)))
+ }
+ s.file = file
+ s.dir, _ = filepath.Split(file.Name())
+ s.src = src
+ s.err = err
+ s.mode = mode
+
+ s.ch = ' '
+ s.offset = 0
+ s.rdOffset = 0
+ s.lineOffset = 0
+ s.ErrorCount = 0
+ s.nextVal = false
+
+ s.next()
+}
+
+func (s *Scanner) error(offs int, msg string) {
+ if s.err != nil {
+ s.err(s.file.Position(s.file.Pos(offs)), msg)
+ }
+ s.ErrorCount++
+}
+
+func (s *Scanner) scanComment() string {
+ // initial [;#] already consumed
+ offs := s.offset - 1 // position of initial [;#]
+
+ for s.ch != '\n' && s.ch >= 0 {
+ s.next()
+ }
+ return string(s.src[offs:s.offset])
+}
+
+func isLetter(ch rune) bool {
+ return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch >= 0x80 && unicode.IsLetter(ch)
+}
+
+func isDigit(ch rune) bool {
+ return '0' <= ch && ch <= '9' || ch >= 0x80 && unicode.IsDigit(ch)
+}
+
+func (s *Scanner) scanIdentifier() string {
+ offs := s.offset
+ for isLetter(s.ch) || isDigit(s.ch) || s.ch == '-' {
+ s.next()
+ }
+ return string(s.src[offs:s.offset])
+}
+
+func (s *Scanner) scanEscape(val bool) {
+ offs := s.offset
+ ch := s.ch
+ s.next() // always make progress
+ switch ch {
+ case '\\', '"':
+ // ok
+ case 'n', 't', 'b':
+ if val {
+ break // ok
+ }
+ fallthrough
+ default:
+ s.error(offs, "unknown escape sequence")
+ }
+}
+
+func (s *Scanner) scanString() string {
+ // '"' opening already consumed
+ offs := s.offset - 1
+
+ for s.ch != '"' {
+ ch := s.ch
+ s.next()
+ if ch == '\n' || ch < 0 {
+ s.error(offs, "string not terminated")
+ break
+ }
+ if ch == '\\' {
+ s.scanEscape(false)
+ }
+ }
+
+ s.next()
+
+ return string(s.src[offs:s.offset])
+}
+
+func stripCR(b []byte) []byte {
+ c := make([]byte, len(b))
+ i := 0
+ for _, ch := range b {
+ if ch != '\r' {
+ c[i] = ch
+ i++
+ }
+ }
+ return c[:i]
+}
+
+func (s *Scanner) scanValString() string {
+ offs := s.offset
+
+ hasCR := false
+ end := offs
+ inQuote := false
+loop:
+ for inQuote || s.ch >= 0 && s.ch != '\n' && s.ch != ';' && s.ch != '#' {
+ ch := s.ch
+ s.next()
+ switch {
+ case inQuote && ch == '\\':
+ s.scanEscape(true)
+ case !inQuote && ch == '\\':
+ if s.ch == '\r' {
+ hasCR = true
+ s.next()
+ }
+ if s.ch != '\n' {
+ s.scanEscape(true)
+ } else {
+ s.next()
+ }
+ case ch == '"':
+ inQuote = !inQuote
+ case ch == '\r':
+ hasCR = true
+ case ch < 0 || inQuote && ch == '\n':
+ s.error(offs, "string not terminated")
+ break loop
+ }
+ if inQuote || !isWhiteSpace(ch) {
+ end = s.offset
+ }
+ }
+
+ lit := s.src[offs:end]
+ if hasCR {
+ lit = stripCR(lit)
+ }
+
+ return string(lit)
+}
+
+func isWhiteSpace(ch rune) bool {
+ return ch == ' ' || ch == '\t' || ch == '\r'
+}
+
+func (s *Scanner) skipWhitespace() {
+ for isWhiteSpace(s.ch) {
+ s.next()
+ }
+}
+
+// Scan scans the next token and returns the token position, the token,
+// and its literal string if applicable. The source end is indicated by
+// token.EOF.
+//
+// If the returned token is a literal (token.IDENT, token.STRING) or
+// token.COMMENT, the literal string has the corresponding value.
+//
+// If the returned token is token.ILLEGAL, the literal string is the
+// offending character.
+//
+// In all other cases, Scan returns an empty literal string.
+//
+// For more tolerant parsing, Scan will return a valid token if
+// possible even if a syntax error was encountered. Thus, even
+// if the resulting token sequence contains no illegal tokens,
+// a client may not assume that no error occurred. Instead it
+// must check the scanner's ErrorCount or the number of calls
+// of the error handler, if there was one installed.
+//
+// Scan adds line information to the file added to the file
+// set with Init. Token positions are relative to that file
+// and thus relative to the file set.
+//
+func (s *Scanner) Scan() (pos token.Pos, tok token.Token, lit string) {
+scanAgain:
+ s.skipWhitespace()
+
+ // current token start
+ pos = s.file.Pos(s.offset)
+
+ // determine token value
+ switch ch := s.ch; {
+ case s.nextVal:
+ lit = s.scanValString()
+ tok = token.STRING
+ s.nextVal = false
+ case isLetter(ch):
+ lit = s.scanIdentifier()
+ tok = token.IDENT
+ default:
+ s.next() // always make progress
+ switch ch {
+ case -1:
+ tok = token.EOF
+ case '\n':
+ tok = token.EOL
+ case '"':
+ tok = token.STRING
+ lit = s.scanString()
+ case '[':
+ tok = token.LBRACK
+ case ']':
+ tok = token.RBRACK
+ case ';', '#':
+ // comment
+ lit = s.scanComment()
+ if s.mode&ScanComments == 0 {
+ // skip comment
+ goto scanAgain
+ }
+ tok = token.COMMENT
+ case '=':
+ tok = token.ASSIGN
+ s.nextVal = true
+ default:
+ s.error(s.file.Offset(pos), fmt.Sprintf("illegal character %#U", ch))
+ tok = token.ILLEGAL
+ lit = string(ch)
+ }
+ }
+
+ return
+}
--- /dev/null
+package gcfg
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+ "math/big"
+ "reflect"
+ "strings"
+ "unicode"
+ "unicode/utf8"
+
+ "github.com/src-d/gcfg/types"
+ "gopkg.in/warnings.v0"
+)
+
+type tag struct {
+ ident string
+ intMode string
+}
+
+func newTag(ts string) tag {
+ t := tag{}
+ s := strings.Split(ts, ",")
+ t.ident = s[0]
+ for _, tse := range s[1:] {
+ if strings.HasPrefix(tse, "int=") {
+ t.intMode = tse[len("int="):]
+ }
+ }
+ return t
+}
+
+func fieldFold(v reflect.Value, name string) (reflect.Value, tag) {
+ var n string
+ r0, _ := utf8.DecodeRuneInString(name)
+ if unicode.IsLetter(r0) && !unicode.IsLower(r0) && !unicode.IsUpper(r0) {
+ n = "X"
+ }
+ n += strings.Replace(name, "-", "_", -1)
+ f, ok := v.Type().FieldByNameFunc(func(fieldName string) bool {
+ if !v.FieldByName(fieldName).CanSet() {
+ return false
+ }
+ f, _ := v.Type().FieldByName(fieldName)
+ t := newTag(f.Tag.Get("gcfg"))
+ if t.ident != "" {
+ return strings.EqualFold(t.ident, name)
+ }
+ return strings.EqualFold(n, fieldName)
+ })
+ if !ok {
+ return reflect.Value{}, tag{}
+ }
+ return v.FieldByName(f.Name), newTag(f.Tag.Get("gcfg"))
+}
+
+type setter func(destp interface{}, blank bool, val string, t tag) error
+
+var errUnsupportedType = fmt.Errorf("unsupported type")
+var errBlankUnsupported = fmt.Errorf("blank value not supported for type")
+
+var setters = []setter{
+ typeSetter, textUnmarshalerSetter, kindSetter, scanSetter,
+}
+
+func textUnmarshalerSetter(d interface{}, blank bool, val string, t tag) error {
+ dtu, ok := d.(textUnmarshaler)
+ if !ok {
+ return errUnsupportedType
+ }
+ if blank {
+ return errBlankUnsupported
+ }
+ return dtu.UnmarshalText([]byte(val))
+}
+
+func boolSetter(d interface{}, blank bool, val string, t tag) error {
+ if blank {
+ reflect.ValueOf(d).Elem().Set(reflect.ValueOf(true))
+ return nil
+ }
+ b, err := types.ParseBool(val)
+ if err == nil {
+ reflect.ValueOf(d).Elem().Set(reflect.ValueOf(b))
+ }
+ return err
+}
+
+func intMode(mode string) types.IntMode {
+ var m types.IntMode
+ if strings.ContainsAny(mode, "dD") {
+ m |= types.Dec
+ }
+ if strings.ContainsAny(mode, "hH") {
+ m |= types.Hex
+ }
+ if strings.ContainsAny(mode, "oO") {
+ m |= types.Oct
+ }
+ return m
+}
+
+var typeModes = map[reflect.Type]types.IntMode{
+ reflect.TypeOf(int(0)): types.Dec | types.Hex,
+ reflect.TypeOf(int8(0)): types.Dec | types.Hex,
+ reflect.TypeOf(int16(0)): types.Dec | types.Hex,
+ reflect.TypeOf(int32(0)): types.Dec | types.Hex,
+ reflect.TypeOf(int64(0)): types.Dec | types.Hex,
+ reflect.TypeOf(uint(0)): types.Dec | types.Hex,
+ reflect.TypeOf(uint8(0)): types.Dec | types.Hex,
+ reflect.TypeOf(uint16(0)): types.Dec | types.Hex,
+ reflect.TypeOf(uint32(0)): types.Dec | types.Hex,
+ reflect.TypeOf(uint64(0)): types.Dec | types.Hex,
+ // use default mode (allow dec/hex/oct) for uintptr type
+ reflect.TypeOf(big.Int{}): types.Dec | types.Hex,
+}
+
+func intModeDefault(t reflect.Type) types.IntMode {
+ m, ok := typeModes[t]
+ if !ok {
+ m = types.Dec | types.Hex | types.Oct
+ }
+ return m
+}
+
+func intSetter(d interface{}, blank bool, val string, t tag) error {
+ if blank {
+ return errBlankUnsupported
+ }
+ mode := intMode(t.intMode)
+ if mode == 0 {
+ mode = intModeDefault(reflect.TypeOf(d).Elem())
+ }
+ return types.ParseInt(d, val, mode)
+}
+
+func stringSetter(d interface{}, blank bool, val string, t tag) error {
+ if blank {
+ return errBlankUnsupported
+ }
+ dsp, ok := d.(*string)
+ if !ok {
+ return errUnsupportedType
+ }
+ *dsp = val
+ return nil
+}
+
+var kindSetters = map[reflect.Kind]setter{
+ reflect.String: stringSetter,
+ reflect.Bool: boolSetter,
+ reflect.Int: intSetter,
+ reflect.Int8: intSetter,
+ reflect.Int16: intSetter,
+ reflect.Int32: intSetter,
+ reflect.Int64: intSetter,
+ reflect.Uint: intSetter,
+ reflect.Uint8: intSetter,
+ reflect.Uint16: intSetter,
+ reflect.Uint32: intSetter,
+ reflect.Uint64: intSetter,
+ reflect.Uintptr: intSetter,
+}
+
+var typeSetters = map[reflect.Type]setter{
+ reflect.TypeOf(big.Int{}): intSetter,
+}
+
+func typeSetter(d interface{}, blank bool, val string, tt tag) error {
+ t := reflect.ValueOf(d).Type().Elem()
+ setter, ok := typeSetters[t]
+ if !ok {
+ return errUnsupportedType
+ }
+ return setter(d, blank, val, tt)
+}
+
+func kindSetter(d interface{}, blank bool, val string, tt tag) error {
+ k := reflect.ValueOf(d).Type().Elem().Kind()
+ setter, ok := kindSetters[k]
+ if !ok {
+ return errUnsupportedType
+ }
+ return setter(d, blank, val, tt)
+}
+
+func scanSetter(d interface{}, blank bool, val string, tt tag) error {
+ if blank {
+ return errBlankUnsupported
+ }
+ return types.ScanFully(d, val, 'v')
+}
+
+func newValue(c *warnings.Collector, sect string, vCfg reflect.Value,
+ vType reflect.Type) (reflect.Value, error) {
+ //
+ pv := reflect.New(vType)
+ dfltName := "default-" + sect
+ dfltField, _ := fieldFold(vCfg, dfltName)
+ var err error
+ if dfltField.IsValid() {
+ b := bytes.NewBuffer(nil)
+ ge := gob.NewEncoder(b)
+ if err = c.Collect(ge.EncodeValue(dfltField)); err != nil {
+ return pv, err
+ }
+ gd := gob.NewDecoder(bytes.NewReader(b.Bytes()))
+ if err = c.Collect(gd.DecodeValue(pv.Elem())); err != nil {
+ return pv, err
+ }
+ }
+ return pv, nil
+}
+
+func set(c *warnings.Collector, cfg interface{}, sect, sub, name string,
+ value string, blankValue bool, subsectPass bool) error {
+ //
+ vPCfg := reflect.ValueOf(cfg)
+ if vPCfg.Kind() != reflect.Ptr || vPCfg.Elem().Kind() != reflect.Struct {
+ panic(fmt.Errorf("config must be a pointer to a struct"))
+ }
+ vCfg := vPCfg.Elem()
+ vSect, _ := fieldFold(vCfg, sect)
+ if !vSect.IsValid() {
+ err := extraData{section: sect}
+ return c.Collect(err)
+ }
+ isSubsect := vSect.Kind() == reflect.Map
+ if subsectPass != isSubsect {
+ return nil
+ }
+ if isSubsect {
+ vst := vSect.Type()
+ if vst.Key().Kind() != reflect.String ||
+ vst.Elem().Kind() != reflect.Ptr ||
+ vst.Elem().Elem().Kind() != reflect.Struct {
+ panic(fmt.Errorf("map field for section must have string keys and "+
+ " pointer-to-struct values: section %q", sect))
+ }
+ if vSect.IsNil() {
+ vSect.Set(reflect.MakeMap(vst))
+ }
+ k := reflect.ValueOf(sub)
+ pv := vSect.MapIndex(k)
+ if !pv.IsValid() {
+ vType := vSect.Type().Elem().Elem()
+ var err error
+ if pv, err = newValue(c, sect, vCfg, vType); err != nil {
+ return err
+ }
+ vSect.SetMapIndex(k, pv)
+ }
+ vSect = pv.Elem()
+ } else if vSect.Kind() != reflect.Struct {
+ panic(fmt.Errorf("field for section must be a map or a struct: "+
+ "section %q", sect))
+ } else if sub != "" {
+ err := extraData{section: sect, subsection: &sub}
+ return c.Collect(err)
+ }
+ // Empty name is a special value, meaning that only the
+ // section/subsection object is to be created, with no values set.
+ if name == "" {
+ return nil
+ }
+ vVar, t := fieldFold(vSect, name)
+ if !vVar.IsValid() {
+ var err error
+ if isSubsect {
+ err = extraData{section: sect, subsection: &sub, variable: &name}
+ } else {
+ err = extraData{section: sect, variable: &name}
+ }
+ return c.Collect(err)
+ }
+ // vVal is either single-valued var, or newly allocated value within multi-valued var
+ var vVal reflect.Value
+ // multi-value if unnamed slice type
+ isMulti := vVar.Type().Name() == "" && vVar.Kind() == reflect.Slice ||
+ vVar.Type().Name() == "" && vVar.Kind() == reflect.Ptr && vVar.Type().Elem().Name() == "" && vVar.Type().Elem().Kind() == reflect.Slice
+ if isMulti && vVar.Kind() == reflect.Ptr {
+ if vVar.IsNil() {
+ vVar.Set(reflect.New(vVar.Type().Elem()))
+ }
+ vVar = vVar.Elem()
+ }
+ if isMulti && blankValue {
+ vVar.Set(reflect.Zero(vVar.Type()))
+ return nil
+ }
+ if isMulti {
+ vVal = reflect.New(vVar.Type().Elem()).Elem()
+ } else {
+ vVal = vVar
+ }
+ isDeref := vVal.Type().Name() == "" && vVal.Type().Kind() == reflect.Ptr
+ isNew := isDeref && vVal.IsNil()
+ // vAddr is address of value to set (dereferenced & allocated as needed)
+ var vAddr reflect.Value
+ switch {
+ case isNew:
+ vAddr = reflect.New(vVal.Type().Elem())
+ case isDeref && !isNew:
+ vAddr = vVal
+ default:
+ vAddr = vVal.Addr()
+ }
+ vAddrI := vAddr.Interface()
+ err, ok := error(nil), false
+ for _, s := range setters {
+ err = s(vAddrI, blankValue, value, t)
+ if err == nil {
+ ok = true
+ break
+ }
+ if err != errUnsupportedType {
+ return err
+ }
+ }
+ if !ok {
+ // in case all setters returned errUnsupportedType
+ return err
+ }
+ if isNew { // set reference if it was dereferenced and newly allocated
+ vVal.Set(vAddr)
+ }
+ if isMulti { // append if multi-valued
+ vVar.Set(reflect.Append(vVar, vVal))
+ }
+ return nil
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// TODO(gri) consider making this a separate package outside the go directory.
+
+package token
+
+import (
+ "fmt"
+ "sort"
+ "sync"
+)
+
+// -----------------------------------------------------------------------------
+// Positions
+
+// Position describes an arbitrary source position
+// including the file, line, and column location.
+// A Position is valid if the line number is > 0.
+//
+type Position struct {
+ Filename string // filename, if any
+ Offset int // offset, starting at 0
+ Line int // line number, starting at 1
+ Column int // column number, starting at 1 (character count)
+}
+
+// IsValid returns true if the position is valid.
+func (pos *Position) IsValid() bool { return pos.Line > 0 }
+
+// String returns a string in one of several forms:
+//
+// file:line:column valid position with file name
+// line:column valid position without file name
+// file invalid position with file name
+// - invalid position without file name
+//
+func (pos Position) String() string {
+ s := pos.Filename
+ if pos.IsValid() {
+ if s != "" {
+ s += ":"
+ }
+ s += fmt.Sprintf("%d:%d", pos.Line, pos.Column)
+ }
+ if s == "" {
+ s = "-"
+ }
+ return s
+}
+
+// Pos is a compact encoding of a source position within a file set.
+// It can be converted into a Position for a more convenient, but much
+// larger, representation.
+//
+// The Pos value for a given file is a number in the range [base, base+size],
+// where base and size are specified when adding the file to the file set via
+// AddFile.
+//
+// To create the Pos value for a specific source offset, first add
+// the respective file to the current file set (via FileSet.AddFile)
+// and then call File.Pos(offset) for that file. Given a Pos value p
+// for a specific file set fset, the corresponding Position value is
+// obtained by calling fset.Position(p).
+//
+// Pos values can be compared directly with the usual comparison operators:
+// If two Pos values p and q are in the same file, comparing p and q is
+// equivalent to comparing the respective source file offsets. If p and q
+// are in different files, p < q is true if the file implied by p was added
+// to the respective file set before the file implied by q.
+//
+type Pos int
+
+// The zero value for Pos is NoPos; there is no file and line information
+// associated with it, and NoPos().IsValid() is false. NoPos is always
+// smaller than any other Pos value. The corresponding Position value
+// for NoPos is the zero value for Position.
+//
+const NoPos Pos = 0
+
+// IsValid returns true if the position is valid.
+func (p Pos) IsValid() bool {
+ return p != NoPos
+}
+
+// -----------------------------------------------------------------------------
+// File
+
+// A File is a handle for a file belonging to a FileSet.
+// A File has a name, size, and line offset table.
+//
+type File struct {
+ set *FileSet
+ name string // file name as provided to AddFile
+ base int // Pos value range for this file is [base...base+size]
+ size int // file size as provided to AddFile
+
+ // lines and infos are protected by set.mutex
+ lines []int
+ infos []lineInfo
+}
+
+// Name returns the file name of file f as registered with AddFile.
+func (f *File) Name() string {
+ return f.name
+}
+
+// Base returns the base offset of file f as registered with AddFile.
+func (f *File) Base() int {
+ return f.base
+}
+
+// Size returns the size of file f as registered with AddFile.
+func (f *File) Size() int {
+ return f.size
+}
+
+// LineCount returns the number of lines in file f.
+func (f *File) LineCount() int {
+ f.set.mutex.RLock()
+ n := len(f.lines)
+ f.set.mutex.RUnlock()
+ return n
+}
+
+// AddLine adds the line offset for a new line.
+// The line offset must be larger than the offset for the previous line
+// and smaller than the file size; otherwise the line offset is ignored.
+//
+func (f *File) AddLine(offset int) {
+ f.set.mutex.Lock()
+ if i := len(f.lines); (i == 0 || f.lines[i-1] < offset) && offset < f.size {
+ f.lines = append(f.lines, offset)
+ }
+ f.set.mutex.Unlock()
+}
+
+// SetLines sets the line offsets for a file and returns true if successful.
+// The line offsets are the offsets of the first character of each line;
+// for instance for the content "ab\nc\n" the line offsets are {0, 3}.
+// An empty file has an empty line offset table.
+// Each line offset must be larger than the offset for the previous line
+// and smaller than the file size; otherwise SetLines fails and returns
+// false.
+//
+func (f *File) SetLines(lines []int) bool {
+ // verify validity of lines table
+ size := f.size
+ for i, offset := range lines {
+ if i > 0 && offset <= lines[i-1] || size <= offset {
+ return false
+ }
+ }
+
+ // set lines table
+ f.set.mutex.Lock()
+ f.lines = lines
+ f.set.mutex.Unlock()
+ return true
+}
+
+// SetLinesForContent sets the line offsets for the given file content.
+func (f *File) SetLinesForContent(content []byte) {
+ var lines []int
+ line := 0
+ for offset, b := range content {
+ if line >= 0 {
+ lines = append(lines, line)
+ }
+ line = -1
+ if b == '\n' {
+ line = offset + 1
+ }
+ }
+
+ // set lines table
+ f.set.mutex.Lock()
+ f.lines = lines
+ f.set.mutex.Unlock()
+}
+
+// A lineInfo object describes alternative file and line number
+// information (such as provided via a //line comment in a .go
+// file) for a given file offset.
+type lineInfo struct {
+ // fields are exported to make them accessible to gob
+ Offset int
+ Filename string
+ Line int
+}
+
+// AddLineInfo adds alternative file and line number information for
+// a given file offset. The offset must be larger than the offset for
+// the previously added alternative line info and smaller than the
+// file size; otherwise the information is ignored.
+//
+// AddLineInfo is typically used to register alternative position
+// information for //line filename:line comments in source files.
+//
+func (f *File) AddLineInfo(offset int, filename string, line int) {
+ f.set.mutex.Lock()
+ if i := len(f.infos); i == 0 || f.infos[i-1].Offset < offset && offset < f.size {
+ f.infos = append(f.infos, lineInfo{offset, filename, line})
+ }
+ f.set.mutex.Unlock()
+}
+
+// Pos returns the Pos value for the given file offset;
+// the offset must be <= f.Size().
+// f.Pos(f.Offset(p)) == p.
+//
+func (f *File) Pos(offset int) Pos {
+ if offset > f.size {
+ panic("illegal file offset")
+ }
+ return Pos(f.base + offset)
+}
+
+// Offset returns the offset for the given file position p;
+// p must be a valid Pos value in that file.
+// f.Offset(f.Pos(offset)) == offset.
+//
+func (f *File) Offset(p Pos) int {
+ if int(p) < f.base || int(p) > f.base+f.size {
+ panic("illegal Pos value")
+ }
+ return int(p) - f.base
+}
+
+// Line returns the line number for the given file position p;
+// p must be a Pos value in that file or NoPos.
+//
+func (f *File) Line(p Pos) int {
+ // TODO(gri) this can be implemented much more efficiently
+ return f.Position(p).Line
+}
+
+func searchLineInfos(a []lineInfo, x int) int {
+ return sort.Search(len(a), func(i int) bool { return a[i].Offset > x }) - 1
+}
+
+// info returns the file name, line, and column number for a file offset.
+func (f *File) info(offset int) (filename string, line, column int) {
+ filename = f.name
+ if i := searchInts(f.lines, offset); i >= 0 {
+ line, column = i+1, offset-f.lines[i]+1
+ }
+ if len(f.infos) > 0 {
+ // almost no files have extra line infos
+ if i := searchLineInfos(f.infos, offset); i >= 0 {
+ alt := &f.infos[i]
+ filename = alt.Filename
+ if i := searchInts(f.lines, alt.Offset); i >= 0 {
+ line += alt.Line - i - 1
+ }
+ }
+ }
+ return
+}
+
+func (f *File) position(p Pos) (pos Position) {
+ offset := int(p) - f.base
+ pos.Offset = offset
+ pos.Filename, pos.Line, pos.Column = f.info(offset)
+ return
+}
+
+// Position returns the Position value for the given file position p;
+// p must be a Pos value in that file or NoPos.
+//
+func (f *File) Position(p Pos) (pos Position) {
+ if p != NoPos {
+ if int(p) < f.base || int(p) > f.base+f.size {
+ panic("illegal Pos value")
+ }
+ pos = f.position(p)
+ }
+ return
+}
+
+// -----------------------------------------------------------------------------
+// FileSet
+
+// A FileSet represents a set of source files.
+// Methods of file sets are synchronized; multiple goroutines
+// may invoke them concurrently.
+//
+type FileSet struct {
+ mutex sync.RWMutex // protects the file set
+ base int // base offset for the next file
+ files []*File // list of files in the order added to the set
+ last *File // cache of last file looked up
+}
+
+// NewFileSet creates a new file set.
+func NewFileSet() *FileSet {
+ s := new(FileSet)
+ s.base = 1 // 0 == NoPos
+ return s
+}
+
+// Base returns the minimum base offset that must be provided to
+// AddFile when adding the next file.
+//
+func (s *FileSet) Base() int {
+ s.mutex.RLock()
+ b := s.base
+ s.mutex.RUnlock()
+ return b
+
+}
+
+// AddFile adds a new file with a given filename, base offset, and file size
+// to the file set s and returns the file. Multiple files may have the same
+// name. The base offset must not be smaller than the FileSet's Base(), and
+// size must not be negative.
+//
+// Adding the file will set the file set's Base() value to base + size + 1
+// as the minimum base value for the next file. The following relationship
+// exists between a Pos value p for a given file offset offs:
+//
+// int(p) = base + offs
+//
+// with offs in the range [0, size] and thus p in the range [base, base+size].
+// For convenience, File.Pos may be used to create file-specific position
+// values from a file offset.
+//
+func (s *FileSet) AddFile(filename string, base, size int) *File {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+ if base < s.base || size < 0 {
+ panic("illegal base or size")
+ }
+ // base >= s.base && size >= 0
+ f := &File{s, filename, base, size, []int{0}, nil}
+ base += size + 1 // +1 because EOF also has a position
+ if base < 0 {
+ panic("token.Pos offset overflow (> 2G of source code in file set)")
+ }
+ // add the file to the file set
+ s.base = base
+ s.files = append(s.files, f)
+ s.last = f
+ return f
+}
+
+// Iterate calls f for the files in the file set in the order they were added
+// until f returns false.
+//
+func (s *FileSet) Iterate(f func(*File) bool) {
+ for i := 0; ; i++ {
+ var file *File
+ s.mutex.RLock()
+ if i < len(s.files) {
+ file = s.files[i]
+ }
+ s.mutex.RUnlock()
+ if file == nil || !f(file) {
+ break
+ }
+ }
+}
+
+func searchFiles(a []*File, x int) int {
+ return sort.Search(len(a), func(i int) bool { return a[i].base > x }) - 1
+}
+
+func (s *FileSet) file(p Pos) *File {
+ // common case: p is in last file
+ if f := s.last; f != nil && f.base <= int(p) && int(p) <= f.base+f.size {
+ return f
+ }
+ // p is not in last file - search all files
+ if i := searchFiles(s.files, int(p)); i >= 0 {
+ f := s.files[i]
+ // f.base <= int(p) by definition of searchFiles
+ if int(p) <= f.base+f.size {
+ s.last = f
+ return f
+ }
+ }
+ return nil
+}
+
+// File returns the file that contains the position p.
+// If no such file is found (for instance for p == NoPos),
+// the result is nil.
+//
+func (s *FileSet) File(p Pos) (f *File) {
+ if p != NoPos {
+ s.mutex.RLock()
+ f = s.file(p)
+ s.mutex.RUnlock()
+ }
+ return
+}
+
+// Position converts a Pos in the fileset into a general Position.
+func (s *FileSet) Position(p Pos) (pos Position) {
+ if p != NoPos {
+ s.mutex.RLock()
+ if f := s.file(p); f != nil {
+ pos = f.position(p)
+ }
+ s.mutex.RUnlock()
+ }
+ return
+}
+
+// -----------------------------------------------------------------------------
+// Helper functions
+
+func searchInts(a []int, x int) int {
+ // This function body is a manually inlined version of:
+ //
+ // return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1
+ //
+ // With better compiler optimizations, this may not be needed in the
+ // future, but at the moment this change improves the go/printer
+ // benchmark performance by ~30%. This has a direct impact on the
+ // speed of gofmt and thus seems worthwhile (2011-04-29).
+ // TODO(gri): Remove this when compilers have caught up.
+ i, j := 0, len(a)
+ for i < j {
+ h := i + (j-i)/2 // avoid overflow when computing h
+ // i ≤ h < j
+ if a[h] <= x {
+ i = h + 1
+ } else {
+ j = h
+ }
+ }
+ return i - 1
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package token
+
+type serializedFile struct {
+ // fields correspond 1:1 to fields with same (lower-case) name in File
+ Name string
+ Base int
+ Size int
+ Lines []int
+ Infos []lineInfo
+}
+
+type serializedFileSet struct {
+ Base int
+ Files []serializedFile
+}
+
+// Read calls decode to deserialize a file set into s; s must not be nil.
+func (s *FileSet) Read(decode func(interface{}) error) error {
+ var ss serializedFileSet
+ if err := decode(&ss); err != nil {
+ return err
+ }
+
+ s.mutex.Lock()
+ s.base = ss.Base
+ files := make([]*File, len(ss.Files))
+ for i := 0; i < len(ss.Files); i++ {
+ f := &ss.Files[i]
+ files[i] = &File{s, f.Name, f.Base, f.Size, f.Lines, f.Infos}
+ }
+ s.files = files
+ s.last = nil
+ s.mutex.Unlock()
+
+ return nil
+}
+
+// Write calls encode to serialize the file set s.
+func (s *FileSet) Write(encode func(interface{}) error) error {
+ var ss serializedFileSet
+
+ s.mutex.Lock()
+ ss.Base = s.base
+ files := make([]serializedFile, len(s.files))
+ for i, f := range s.files {
+ files[i] = serializedFile{f.name, f.base, f.size, f.lines, f.infos}
+ }
+ ss.Files = files
+ s.mutex.Unlock()
+
+ return encode(ss)
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package token defines constants representing the lexical tokens of the gcfg
+// configuration syntax and basic operations on tokens (printing, predicates).
+//
+// Note that the API for the token package may change to accommodate new
+// features or implementation changes in gcfg.
+//
+package token
+
+import "strconv"
+
+// Token is the set of lexical tokens of the gcfg configuration syntax.
+type Token int
+
+// The list of tokens.
+const (
+ // Special tokens
+ ILLEGAL Token = iota
+ EOF
+ COMMENT
+
+ literal_beg
+ // Identifiers and basic type literals
+ // (these tokens stand for classes of literals)
+ IDENT // section-name, variable-name
+ STRING // "subsection-name", variable value
+ literal_end
+
+ operator_beg
+ // Operators and delimiters
+ ASSIGN // =
+ LBRACK // [
+ RBRACK // ]
+ EOL // \n
+ operator_end
+)
+
+var tokens = [...]string{
+ ILLEGAL: "ILLEGAL",
+
+ EOF: "EOF",
+ COMMENT: "COMMENT",
+
+ IDENT: "IDENT",
+ STRING: "STRING",
+
+ ASSIGN: "=",
+ LBRACK: "[",
+ RBRACK: "]",
+ EOL: "\n",
+}
+
+// String returns the string corresponding to the token tok.
+// For operators and delimiters, the string is the actual token character
+// sequence (e.g., for the token ASSIGN, the string is "="). For all other
+// tokens the string corresponds to the token constant name (e.g. for the
+// token IDENT, the string is "IDENT").
+//
+func (tok Token) String() string {
+ s := ""
+ if 0 <= tok && tok < Token(len(tokens)) {
+ s = tokens[tok]
+ }
+ if s == "" {
+ s = "token(" + strconv.Itoa(int(tok)) + ")"
+ }
+ return s
+}
+
+// Predicates
+
+// IsLiteral returns true for tokens corresponding to identifiers
+// and basic type literals; it returns false otherwise.
+//
+func (tok Token) IsLiteral() bool { return literal_beg < tok && tok < literal_end }
+
+// IsOperator returns true for tokens corresponding to operators and
+// delimiters; it returns false otherwise.
+//
+func (tok Token) IsOperator() bool { return operator_beg < tok && tok < operator_end }
--- /dev/null
+package types
+
+// BoolValues defines the name and value mappings for ParseBool.
+var BoolValues = map[string]interface{}{
+ "true": true, "yes": true, "on": true, "1": true,
+ "false": false, "no": false, "off": false, "0": false,
+}
+
+var boolParser = func() *EnumParser {
+ ep := &EnumParser{}
+ ep.AddVals(BoolValues)
+ return ep
+}()
+
+// ParseBool parses bool values according to the definitions in BoolValues.
+// Parsing is case-insensitive.
+func ParseBool(s string) (bool, error) {
+ v, err := boolParser.Parse(s)
+ if err != nil {
+ return false, err
+ }
+ return v.(bool), nil
+}
--- /dev/null
+// Package types defines helpers for type conversions.
+//
+// The API for this package is not finalized yet.
+package types
--- /dev/null
+package types
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// EnumParser parses "enum" values; i.e. a predefined set of strings to
+// predefined values.
+type EnumParser struct {
+ Type string // type name; if not set, use type of first value added
+ CaseMatch bool // if true, matching of strings is case-sensitive
+ // PrefixMatch bool
+ vals map[string]interface{}
+}
+
+// AddVals adds strings and values to an EnumParser.
+func (ep *EnumParser) AddVals(vals map[string]interface{}) {
+ if ep.vals == nil {
+ ep.vals = make(map[string]interface{})
+ }
+ for k, v := range vals {
+ if ep.Type == "" {
+ ep.Type = reflect.TypeOf(v).Name()
+ }
+ if !ep.CaseMatch {
+ k = strings.ToLower(k)
+ }
+ ep.vals[k] = v
+ }
+}
+
+// Parse parses the string and returns the value or an error.
+func (ep EnumParser) Parse(s string) (interface{}, error) {
+ if !ep.CaseMatch {
+ s = strings.ToLower(s)
+ }
+ v, ok := ep.vals[s]
+ if !ok {
+ return false, fmt.Errorf("failed to parse %s %#q", ep.Type, s)
+ }
+ return v, nil
+}
--- /dev/null
+package types
+
+import (
+ "fmt"
+ "strings"
+)
+
+// An IntMode is a mode for parsing integer values, representing a set of
+// accepted bases.
+type IntMode uint8
+
+// IntMode values for ParseInt; can be combined using binary or.
+const (
+ Dec IntMode = 1 << iota
+ Hex
+ Oct
+)
+
+// String returns a string representation of IntMode; e.g. `IntMode(Dec|Hex)`.
+func (m IntMode) String() string {
+ var modes []string
+ if m&Dec != 0 {
+ modes = append(modes, "Dec")
+ }
+ if m&Hex != 0 {
+ modes = append(modes, "Hex")
+ }
+ if m&Oct != 0 {
+ modes = append(modes, "Oct")
+ }
+ return "IntMode(" + strings.Join(modes, "|") + ")"
+}
+
+var errIntAmbig = fmt.Errorf("ambiguous integer value; must include '0' prefix")
+
+func prefix0(val string) bool {
+ return strings.HasPrefix(val, "0") || strings.HasPrefix(val, "-0")
+}
+
+func prefix0x(val string) bool {
+ return strings.HasPrefix(val, "0x") || strings.HasPrefix(val, "-0x")
+}
+
+// ParseInt parses val using mode into intptr, which must be a pointer to an
+// integer kind type. Non-decimal value require prefix `0` or `0x` in the cases
+// when mode permits ambiguity of base; otherwise the prefix can be omitted.
+func ParseInt(intptr interface{}, val string, mode IntMode) error {
+ val = strings.TrimSpace(val)
+ verb := byte(0)
+ switch mode {
+ case Dec:
+ verb = 'd'
+ case Dec + Hex:
+ if prefix0x(val) {
+ verb = 'v'
+ } else {
+ verb = 'd'
+ }
+ case Dec + Oct:
+ if prefix0(val) && !prefix0x(val) {
+ verb = 'v'
+ } else {
+ verb = 'd'
+ }
+ case Dec + Hex + Oct:
+ verb = 'v'
+ case Hex:
+ if prefix0x(val) {
+ verb = 'v'
+ } else {
+ verb = 'x'
+ }
+ case Oct:
+ verb = 'o'
+ case Hex + Oct:
+ if prefix0(val) {
+ verb = 'v'
+ } else {
+ return errIntAmbig
+ }
+ }
+ if verb == 0 {
+ panic("unsupported mode")
+ }
+ return ScanFully(intptr, val, verb)
+}
--- /dev/null
+package types
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+)
+
+// ScanFully uses fmt.Sscanf with verb to fully scan val into ptr.
+func ScanFully(ptr interface{}, val string, verb byte) error {
+ t := reflect.ValueOf(ptr).Elem().Type()
+ // attempt to read extra bytes to make sure the value is consumed
+ var b []byte
+ n, err := fmt.Sscanf(val, "%"+string(verb)+"%s", ptr, &b)
+ switch {
+ case n < 1 || n == 1 && err != io.EOF:
+ return fmt.Errorf("failed to parse %q as %v: %v", val, t, err)
+ case n > 1:
+ return fmt.Errorf("failed to parse %q as %v: extra characters %q", val, t, string(b))
+ }
+ // n == 1 && err == io.EOF
+ return nil
+}
--- /dev/null
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
--- /dev/null
+//
+// Copyright (c) 2014 David Mzareulyan
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without restriction,
+// including without limitation the rights to use, copy, modify, merge, publish, distribute,
+// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software
+// is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or substantial
+// portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+//
+
+// +build windows
+
+package sshagent
+
+// see https://github.com/Yasushi/putty/blob/master/windows/winpgntc.c#L155
+// see https://github.com/paramiko/paramiko/blob/master/paramiko/win_pageant.py
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "sync"
+ "syscall"
+ "unsafe"
+)
+
+// Maximum size of message can be sent to pageant
+const MaxMessageLen = 8192
+
+var (
+ ErrPageantNotFound = errors.New("pageant process not found")
+ ErrSendMessage = errors.New("error sending message")
+
+ ErrMessageTooLong = errors.New("message too long")
+ ErrInvalidMessageFormat = errors.New("invalid message format")
+ ErrResponseTooLong = errors.New("response too long")
+)
+
+const (
+ agentCopydataID = 0x804e50ba
+ wmCopydata = 74
+)
+
+type copyData struct {
+ dwData uintptr
+ cbData uint32
+ lpData unsafe.Pointer
+}
+
+var (
+ lock sync.Mutex
+
+ winFindWindow = winAPI("user32.dll", "FindWindowW")
+ winGetCurrentThreadID = winAPI("kernel32.dll", "GetCurrentThreadId")
+ winSendMessage = winAPI("user32.dll", "SendMessageW")
+)
+
+func winAPI(dllName, funcName string) func(...uintptr) (uintptr, uintptr, error) {
+ proc := syscall.MustLoadDLL(dllName).MustFindProc(funcName)
+ return func(a ...uintptr) (uintptr, uintptr, error) { return proc.Call(a...) }
+}
+
+// Available returns true if Pageant is running
+func Available() bool { return pageantWindow() != 0 }
+
+// Query sends message msg to Pageant and returns response or error.
+// 'msg' is raw agent request with length prefix
+// Response is raw agent response with length prefix
+func query(msg []byte) ([]byte, error) {
+ if len(msg) > MaxMessageLen {
+ return nil, ErrMessageTooLong
+ }
+
+ msgLen := binary.BigEndian.Uint32(msg[:4])
+ if len(msg) != int(msgLen)+4 {
+ return nil, ErrInvalidMessageFormat
+ }
+
+ lock.Lock()
+ defer lock.Unlock()
+
+ paWin := pageantWindow()
+
+ if paWin == 0 {
+ return nil, ErrPageantNotFound
+ }
+
+ thID, _, _ := winGetCurrentThreadID()
+ mapName := fmt.Sprintf("PageantRequest%08x", thID)
+ pMapName, _ := syscall.UTF16PtrFromString(mapName)
+
+ mmap, err := syscall.CreateFileMapping(syscall.InvalidHandle, nil, syscall.PAGE_READWRITE, 0, MaxMessageLen+4, pMapName)
+ if err != nil {
+ return nil, err
+ }
+ defer syscall.CloseHandle(mmap)
+
+ ptr, err := syscall.MapViewOfFile(mmap, syscall.FILE_MAP_WRITE, 0, 0, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer syscall.UnmapViewOfFile(ptr)
+
+ mmSlice := (*(*[MaxMessageLen]byte)(unsafe.Pointer(ptr)))[:]
+
+ copy(mmSlice, msg)
+
+ mapNameBytesZ := append([]byte(mapName), 0)
+
+ cds := copyData{
+ dwData: agentCopydataID,
+ cbData: uint32(len(mapNameBytesZ)),
+ lpData: unsafe.Pointer(&(mapNameBytesZ[0])),
+ }
+
+ resp, _, _ := winSendMessage(paWin, wmCopydata, 0, uintptr(unsafe.Pointer(&cds)))
+
+ if resp == 0 {
+ return nil, ErrSendMessage
+ }
+
+ respLen := binary.BigEndian.Uint32(mmSlice[:4])
+ if respLen > MaxMessageLen-4 {
+ return nil, ErrResponseTooLong
+ }
+
+ respData := make([]byte, respLen+4)
+ copy(respData, mmSlice)
+
+ return respData, nil
+}
+
+func pageantWindow() uintptr {
+ nameP, _ := syscall.UTF16PtrFromString("Pageant")
+ h, _, _ := winFindWindow(uintptr(unsafe.Pointer(nameP)), uintptr(unsafe.Pointer(nameP)))
+ return h
+}
--- /dev/null
+//
+// Copyright 2015, Sander van Harmelen
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// +build !windows
+
+package sshagent
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "os"
+
+ "golang.org/x/crypto/ssh/agent"
+)
+
+// New returns a new agent.Agent that uses a unix socket
+func New() (agent.Agent, net.Conn, error) {
+ if !Available() {
+ return nil, nil, errors.New("SSH agent requested but SSH_AUTH_SOCK not-specified")
+ }
+
+ sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
+
+ conn, err := net.Dial("unix", sshAuthSock)
+ if err != nil {
+ return nil, nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err)
+ }
+
+ return agent.NewClient(conn), conn, nil
+}
+
+// Available returns true is a auth socket is defined
+func Available() bool {
+ return os.Getenv("SSH_AUTH_SOCK") != ""
+}
--- /dev/null
+//
+// Copyright (c) 2014 David Mzareulyan
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without restriction,
+// including without limitation the rights to use, copy, modify, merge, publish, distribute,
+// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software
+// is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or substantial
+// portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+//
+
+// +build windows
+
+package sshagent
+
+import (
+ "errors"
+ "io"
+ "net"
+ "sync"
+
+ "golang.org/x/crypto/ssh/agent"
+)
+
+// New returns a new agent.Agent and the (custom) connection it uses
+// to communicate with a running pagent.exe instance (see README.md)
+func New() (agent.Agent, net.Conn, error) {
+ if !Available() {
+ return nil, nil, errors.New("SSH agent requested but Pageant not running")
+ }
+
+ return agent.NewClient(&conn{}), nil, nil
+}
+
+type conn struct {
+ sync.Mutex
+ buf []byte
+}
+
+func (c *conn) Close() {
+ c.Lock()
+ defer c.Unlock()
+ c.buf = nil
+}
+
+func (c *conn) Write(p []byte) (int, error) {
+ c.Lock()
+ defer c.Unlock()
+
+ resp, err := query(p)
+ if err != nil {
+ return 0, err
+ }
+
+ c.buf = append(c.buf, resp...)
+
+ return len(p), nil
+}
+
+func (c *conn) Read(p []byte) (int, error) {
+ c.Lock()
+ defer c.Unlock()
+
+ if len(c.buf) == 0 {
+ return 0, io.EOF
+ }
+
+ n := copy(p, c.buf)
+ c.buf = c.buf[n:]
+
+ return n, nil
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common
+// OpenPGP cipher.
+package cast5 // import "golang.org/x/crypto/cast5"
+
+import "errors"
+
+const BlockSize = 8
+const KeySize = 16
+
+type Cipher struct {
+ masking [16]uint32
+ rotate [16]uint8
+}
+
+func NewCipher(key []byte) (c *Cipher, err error) {
+ if len(key) != KeySize {
+ return nil, errors.New("CAST5: keys must be 16 bytes")
+ }
+
+ c = new(Cipher)
+ c.keySchedule(key)
+ return
+}
+
+func (c *Cipher) BlockSize() int {
+ return BlockSize
+}
+
+func (c *Cipher) Encrypt(dst, src []byte) {
+ l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
+ r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
+
+ l, r = r, l^f1(r, c.masking[0], c.rotate[0])
+ l, r = r, l^f2(r, c.masking[1], c.rotate[1])
+ l, r = r, l^f3(r, c.masking[2], c.rotate[2])
+ l, r = r, l^f1(r, c.masking[3], c.rotate[3])
+
+ l, r = r, l^f2(r, c.masking[4], c.rotate[4])
+ l, r = r, l^f3(r, c.masking[5], c.rotate[5])
+ l, r = r, l^f1(r, c.masking[6], c.rotate[6])
+ l, r = r, l^f2(r, c.masking[7], c.rotate[7])
+
+ l, r = r, l^f3(r, c.masking[8], c.rotate[8])
+ l, r = r, l^f1(r, c.masking[9], c.rotate[9])
+ l, r = r, l^f2(r, c.masking[10], c.rotate[10])
+ l, r = r, l^f3(r, c.masking[11], c.rotate[11])
+
+ l, r = r, l^f1(r, c.masking[12], c.rotate[12])
+ l, r = r, l^f2(r, c.masking[13], c.rotate[13])
+ l, r = r, l^f3(r, c.masking[14], c.rotate[14])
+ l, r = r, l^f1(r, c.masking[15], c.rotate[15])
+
+ dst[0] = uint8(r >> 24)
+ dst[1] = uint8(r >> 16)
+ dst[2] = uint8(r >> 8)
+ dst[3] = uint8(r)
+ dst[4] = uint8(l >> 24)
+ dst[5] = uint8(l >> 16)
+ dst[6] = uint8(l >> 8)
+ dst[7] = uint8(l)
+}
+
+func (c *Cipher) Decrypt(dst, src []byte) {
+ l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
+ r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
+
+ l, r = r, l^f1(r, c.masking[15], c.rotate[15])
+ l, r = r, l^f3(r, c.masking[14], c.rotate[14])
+ l, r = r, l^f2(r, c.masking[13], c.rotate[13])
+ l, r = r, l^f1(r, c.masking[12], c.rotate[12])
+
+ l, r = r, l^f3(r, c.masking[11], c.rotate[11])
+ l, r = r, l^f2(r, c.masking[10], c.rotate[10])
+ l, r = r, l^f1(r, c.masking[9], c.rotate[9])
+ l, r = r, l^f3(r, c.masking[8], c.rotate[8])
+
+ l, r = r, l^f2(r, c.masking[7], c.rotate[7])
+ l, r = r, l^f1(r, c.masking[6], c.rotate[6])
+ l, r = r, l^f3(r, c.masking[5], c.rotate[5])
+ l, r = r, l^f2(r, c.masking[4], c.rotate[4])
+
+ l, r = r, l^f1(r, c.masking[3], c.rotate[3])
+ l, r = r, l^f3(r, c.masking[2], c.rotate[2])
+ l, r = r, l^f2(r, c.masking[1], c.rotate[1])
+ l, r = r, l^f1(r, c.masking[0], c.rotate[0])
+
+ dst[0] = uint8(r >> 24)
+ dst[1] = uint8(r >> 16)
+ dst[2] = uint8(r >> 8)
+ dst[3] = uint8(r)
+ dst[4] = uint8(l >> 24)
+ dst[5] = uint8(l >> 16)
+ dst[6] = uint8(l >> 8)
+ dst[7] = uint8(l)
+}
+
+type keyScheduleA [4][7]uint8
+type keyScheduleB [4][5]uint8
+
+// keyScheduleRound contains the magic values for a round of the key schedule.
+// The keyScheduleA deals with the lines like:
+// z0z1z2z3 = x0x1x2x3 ^ S5[xD] ^ S6[xF] ^ S7[xC] ^ S8[xE] ^ S7[x8]
+// Conceptually, both x and z are in the same array, x first. The first
+// element describes which word of this array gets written to and the
+// second, which word gets read. So, for the line above, it's "4, 0", because
+// it's writing to the first word of z, which, being after x, is word 4, and
+// reading from the first word of x: word 0.
+//
+// Next are the indexes into the S-boxes. Now the array is treated as bytes. So
+// "xD" is 0xd. The first byte of z is written as "16 + 0", just to be clear
+// that it's z that we're indexing.
+//
+// keyScheduleB deals with lines like:
+// K1 = S5[z8] ^ S6[z9] ^ S7[z7] ^ S8[z6] ^ S5[z2]
+// "K1" is ignored because key words are always written in order. So the five
+// elements are the S-box indexes. They use the same form as in keyScheduleA,
+// above.
+
+type keyScheduleRound struct{}
+type keySchedule []keyScheduleRound
+
+var schedule = []struct {
+ a keyScheduleA
+ b keyScheduleB
+}{
+ {
+ keyScheduleA{
+ {4, 0, 0xd, 0xf, 0xc, 0xe, 0x8},
+ {5, 2, 16 + 0, 16 + 2, 16 + 1, 16 + 3, 0xa},
+ {6, 3, 16 + 7, 16 + 6, 16 + 5, 16 + 4, 9},
+ {7, 1, 16 + 0xa, 16 + 9, 16 + 0xb, 16 + 8, 0xb},
+ },
+ keyScheduleB{
+ {16 + 8, 16 + 9, 16 + 7, 16 + 6, 16 + 2},
+ {16 + 0xa, 16 + 0xb, 16 + 5, 16 + 4, 16 + 6},
+ {16 + 0xc, 16 + 0xd, 16 + 3, 16 + 2, 16 + 9},
+ {16 + 0xe, 16 + 0xf, 16 + 1, 16 + 0, 16 + 0xc},
+ },
+ },
+ {
+ keyScheduleA{
+ {0, 6, 16 + 5, 16 + 7, 16 + 4, 16 + 6, 16 + 0},
+ {1, 4, 0, 2, 1, 3, 16 + 2},
+ {2, 5, 7, 6, 5, 4, 16 + 1},
+ {3, 7, 0xa, 9, 0xb, 8, 16 + 3},
+ },
+ keyScheduleB{
+ {3, 2, 0xc, 0xd, 8},
+ {1, 0, 0xe, 0xf, 0xd},
+ {7, 6, 8, 9, 3},
+ {5, 4, 0xa, 0xb, 7},
+ },
+ },
+ {
+ keyScheduleA{
+ {4, 0, 0xd, 0xf, 0xc, 0xe, 8},
+ {5, 2, 16 + 0, 16 + 2, 16 + 1, 16 + 3, 0xa},
+ {6, 3, 16 + 7, 16 + 6, 16 + 5, 16 + 4, 9},
+ {7, 1, 16 + 0xa, 16 + 9, 16 + 0xb, 16 + 8, 0xb},
+ },
+ keyScheduleB{
+ {16 + 3, 16 + 2, 16 + 0xc, 16 + 0xd, 16 + 9},
+ {16 + 1, 16 + 0, 16 + 0xe, 16 + 0xf, 16 + 0xc},
+ {16 + 7, 16 + 6, 16 + 8, 16 + 9, 16 + 2},
+ {16 + 5, 16 + 4, 16 + 0xa, 16 + 0xb, 16 + 6},
+ },
+ },
+ {
+ keyScheduleA{
+ {0, 6, 16 + 5, 16 + 7, 16 + 4, 16 + 6, 16 + 0},
+ {1, 4, 0, 2, 1, 3, 16 + 2},
+ {2, 5, 7, 6, 5, 4, 16 + 1},
+ {3, 7, 0xa, 9, 0xb, 8, 16 + 3},
+ },
+ keyScheduleB{
+ {8, 9, 7, 6, 3},
+ {0xa, 0xb, 5, 4, 7},
+ {0xc, 0xd, 3, 2, 8},
+ {0xe, 0xf, 1, 0, 0xd},
+ },
+ },
+}
+
+func (c *Cipher) keySchedule(in []byte) {
+ var t [8]uint32
+ var k [32]uint32
+
+ for i := 0; i < 4; i++ {
+ j := i * 4
+ t[i] = uint32(in[j])<<24 | uint32(in[j+1])<<16 | uint32(in[j+2])<<8 | uint32(in[j+3])
+ }
+
+ x := []byte{6, 7, 4, 5}
+ ki := 0
+
+ for half := 0; half < 2; half++ {
+ for _, round := range schedule {
+ for j := 0; j < 4; j++ {
+ var a [7]uint8
+ copy(a[:], round.a[j][:])
+ w := t[a[1]]
+ w ^= sBox[4][(t[a[2]>>2]>>(24-8*(a[2]&3)))&0xff]
+ w ^= sBox[5][(t[a[3]>>2]>>(24-8*(a[3]&3)))&0xff]
+ w ^= sBox[6][(t[a[4]>>2]>>(24-8*(a[4]&3)))&0xff]
+ w ^= sBox[7][(t[a[5]>>2]>>(24-8*(a[5]&3)))&0xff]
+ w ^= sBox[x[j]][(t[a[6]>>2]>>(24-8*(a[6]&3)))&0xff]
+ t[a[0]] = w
+ }
+
+ for j := 0; j < 4; j++ {
+ var b [5]uint8
+ copy(b[:], round.b[j][:])
+ w := sBox[4][(t[b[0]>>2]>>(24-8*(b[0]&3)))&0xff]
+ w ^= sBox[5][(t[b[1]>>2]>>(24-8*(b[1]&3)))&0xff]
+ w ^= sBox[6][(t[b[2]>>2]>>(24-8*(b[2]&3)))&0xff]
+ w ^= sBox[7][(t[b[3]>>2]>>(24-8*(b[3]&3)))&0xff]
+ w ^= sBox[4+j][(t[b[4]>>2]>>(24-8*(b[4]&3)))&0xff]
+ k[ki] = w
+ ki++
+ }
+ }
+ }
+
+ for i := 0; i < 16; i++ {
+ c.masking[i] = k[i]
+ c.rotate[i] = uint8(k[16+i] & 0x1f)
+ }
+}
+
+// These are the three 'f' functions. See RFC 2144, section 2.2.
+func f1(d, m uint32, r uint8) uint32 {
+ t := m + d
+ I := (t << r) | (t >> (32 - r))
+ return ((sBox[0][I>>24] ^ sBox[1][(I>>16)&0xff]) - sBox[2][(I>>8)&0xff]) + sBox[3][I&0xff]
+}
+
+func f2(d, m uint32, r uint8) uint32 {
+ t := m ^ d
+ I := (t << r) | (t >> (32 - r))
+ return ((sBox[0][I>>24] - sBox[1][(I>>16)&0xff]) + sBox[2][(I>>8)&0xff]) ^ sBox[3][I&0xff]
+}
+
+func f3(d, m uint32, r uint8) uint32 {
+ t := m - d
+ I := (t << r) | (t >> (32 - r))
+ return ((sBox[0][I>>24] + sBox[1][(I>>16)&0xff]) ^ sBox[2][(I>>8)&0xff]) - sBox[3][I&0xff]
+}
+
+var sBox = [8][256]uint32{
+ {
+ 0x30fb40d4, 0x9fa0ff0b, 0x6beccd2f, 0x3f258c7a, 0x1e213f2f, 0x9c004dd3, 0x6003e540, 0xcf9fc949,
+ 0xbfd4af27, 0x88bbbdb5, 0xe2034090, 0x98d09675, 0x6e63a0e0, 0x15c361d2, 0xc2e7661d, 0x22d4ff8e,
+ 0x28683b6f, 0xc07fd059, 0xff2379c8, 0x775f50e2, 0x43c340d3, 0xdf2f8656, 0x887ca41a, 0xa2d2bd2d,
+ 0xa1c9e0d6, 0x346c4819, 0x61b76d87, 0x22540f2f, 0x2abe32e1, 0xaa54166b, 0x22568e3a, 0xa2d341d0,
+ 0x66db40c8, 0xa784392f, 0x004dff2f, 0x2db9d2de, 0x97943fac, 0x4a97c1d8, 0x527644b7, 0xb5f437a7,
+ 0xb82cbaef, 0xd751d159, 0x6ff7f0ed, 0x5a097a1f, 0x827b68d0, 0x90ecf52e, 0x22b0c054, 0xbc8e5935,
+ 0x4b6d2f7f, 0x50bb64a2, 0xd2664910, 0xbee5812d, 0xb7332290, 0xe93b159f, 0xb48ee411, 0x4bff345d,
+ 0xfd45c240, 0xad31973f, 0xc4f6d02e, 0x55fc8165, 0xd5b1caad, 0xa1ac2dae, 0xa2d4b76d, 0xc19b0c50,
+ 0x882240f2, 0x0c6e4f38, 0xa4e4bfd7, 0x4f5ba272, 0x564c1d2f, 0xc59c5319, 0xb949e354, 0xb04669fe,
+ 0xb1b6ab8a, 0xc71358dd, 0x6385c545, 0x110f935d, 0x57538ad5, 0x6a390493, 0xe63d37e0, 0x2a54f6b3,
+ 0x3a787d5f, 0x6276a0b5, 0x19a6fcdf, 0x7a42206a, 0x29f9d4d5, 0xf61b1891, 0xbb72275e, 0xaa508167,
+ 0x38901091, 0xc6b505eb, 0x84c7cb8c, 0x2ad75a0f, 0x874a1427, 0xa2d1936b, 0x2ad286af, 0xaa56d291,
+ 0xd7894360, 0x425c750d, 0x93b39e26, 0x187184c9, 0x6c00b32d, 0x73e2bb14, 0xa0bebc3c, 0x54623779,
+ 0x64459eab, 0x3f328b82, 0x7718cf82, 0x59a2cea6, 0x04ee002e, 0x89fe78e6, 0x3fab0950, 0x325ff6c2,
+ 0x81383f05, 0x6963c5c8, 0x76cb5ad6, 0xd49974c9, 0xca180dcf, 0x380782d5, 0xc7fa5cf6, 0x8ac31511,
+ 0x35e79e13, 0x47da91d0, 0xf40f9086, 0xa7e2419e, 0x31366241, 0x051ef495, 0xaa573b04, 0x4a805d8d,
+ 0x548300d0, 0x00322a3c, 0xbf64cddf, 0xba57a68e, 0x75c6372b, 0x50afd341, 0xa7c13275, 0x915a0bf5,
+ 0x6b54bfab, 0x2b0b1426, 0xab4cc9d7, 0x449ccd82, 0xf7fbf265, 0xab85c5f3, 0x1b55db94, 0xaad4e324,
+ 0xcfa4bd3f, 0x2deaa3e2, 0x9e204d02, 0xc8bd25ac, 0xeadf55b3, 0xd5bd9e98, 0xe31231b2, 0x2ad5ad6c,
+ 0x954329de, 0xadbe4528, 0xd8710f69, 0xaa51c90f, 0xaa786bf6, 0x22513f1e, 0xaa51a79b, 0x2ad344cc,
+ 0x7b5a41f0, 0xd37cfbad, 0x1b069505, 0x41ece491, 0xb4c332e6, 0x032268d4, 0xc9600acc, 0xce387e6d,
+ 0xbf6bb16c, 0x6a70fb78, 0x0d03d9c9, 0xd4df39de, 0xe01063da, 0x4736f464, 0x5ad328d8, 0xb347cc96,
+ 0x75bb0fc3, 0x98511bfb, 0x4ffbcc35, 0xb58bcf6a, 0xe11f0abc, 0xbfc5fe4a, 0xa70aec10, 0xac39570a,
+ 0x3f04442f, 0x6188b153, 0xe0397a2e, 0x5727cb79, 0x9ceb418f, 0x1cacd68d, 0x2ad37c96, 0x0175cb9d,
+ 0xc69dff09, 0xc75b65f0, 0xd9db40d8, 0xec0e7779, 0x4744ead4, 0xb11c3274, 0xdd24cb9e, 0x7e1c54bd,
+ 0xf01144f9, 0xd2240eb1, 0x9675b3fd, 0xa3ac3755, 0xd47c27af, 0x51c85f4d, 0x56907596, 0xa5bb15e6,
+ 0x580304f0, 0xca042cf1, 0x011a37ea, 0x8dbfaadb, 0x35ba3e4a, 0x3526ffa0, 0xc37b4d09, 0xbc306ed9,
+ 0x98a52666, 0x5648f725, 0xff5e569d, 0x0ced63d0, 0x7c63b2cf, 0x700b45e1, 0xd5ea50f1, 0x85a92872,
+ 0xaf1fbda7, 0xd4234870, 0xa7870bf3, 0x2d3b4d79, 0x42e04198, 0x0cd0ede7, 0x26470db8, 0xf881814c,
+ 0x474d6ad7, 0x7c0c5e5c, 0xd1231959, 0x381b7298, 0xf5d2f4db, 0xab838653, 0x6e2f1e23, 0x83719c9e,
+ 0xbd91e046, 0x9a56456e, 0xdc39200c, 0x20c8c571, 0x962bda1c, 0xe1e696ff, 0xb141ab08, 0x7cca89b9,
+ 0x1a69e783, 0x02cc4843, 0xa2f7c579, 0x429ef47d, 0x427b169c, 0x5ac9f049, 0xdd8f0f00, 0x5c8165bf,
+ },
+ {
+ 0x1f201094, 0xef0ba75b, 0x69e3cf7e, 0x393f4380, 0xfe61cf7a, 0xeec5207a, 0x55889c94, 0x72fc0651,
+ 0xada7ef79, 0x4e1d7235, 0xd55a63ce, 0xde0436ba, 0x99c430ef, 0x5f0c0794, 0x18dcdb7d, 0xa1d6eff3,
+ 0xa0b52f7b, 0x59e83605, 0xee15b094, 0xe9ffd909, 0xdc440086, 0xef944459, 0xba83ccb3, 0xe0c3cdfb,
+ 0xd1da4181, 0x3b092ab1, 0xf997f1c1, 0xa5e6cf7b, 0x01420ddb, 0xe4e7ef5b, 0x25a1ff41, 0xe180f806,
+ 0x1fc41080, 0x179bee7a, 0xd37ac6a9, 0xfe5830a4, 0x98de8b7f, 0x77e83f4e, 0x79929269, 0x24fa9f7b,
+ 0xe113c85b, 0xacc40083, 0xd7503525, 0xf7ea615f, 0x62143154, 0x0d554b63, 0x5d681121, 0xc866c359,
+ 0x3d63cf73, 0xcee234c0, 0xd4d87e87, 0x5c672b21, 0x071f6181, 0x39f7627f, 0x361e3084, 0xe4eb573b,
+ 0x602f64a4, 0xd63acd9c, 0x1bbc4635, 0x9e81032d, 0x2701f50c, 0x99847ab4, 0xa0e3df79, 0xba6cf38c,
+ 0x10843094, 0x2537a95e, 0xf46f6ffe, 0xa1ff3b1f, 0x208cfb6a, 0x8f458c74, 0xd9e0a227, 0x4ec73a34,
+ 0xfc884f69, 0x3e4de8df, 0xef0e0088, 0x3559648d, 0x8a45388c, 0x1d804366, 0x721d9bfd, 0xa58684bb,
+ 0xe8256333, 0x844e8212, 0x128d8098, 0xfed33fb4, 0xce280ae1, 0x27e19ba5, 0xd5a6c252, 0xe49754bd,
+ 0xc5d655dd, 0xeb667064, 0x77840b4d, 0xa1b6a801, 0x84db26a9, 0xe0b56714, 0x21f043b7, 0xe5d05860,
+ 0x54f03084, 0x066ff472, 0xa31aa153, 0xdadc4755, 0xb5625dbf, 0x68561be6, 0x83ca6b94, 0x2d6ed23b,
+ 0xeccf01db, 0xa6d3d0ba, 0xb6803d5c, 0xaf77a709, 0x33b4a34c, 0x397bc8d6, 0x5ee22b95, 0x5f0e5304,
+ 0x81ed6f61, 0x20e74364, 0xb45e1378, 0xde18639b, 0x881ca122, 0xb96726d1, 0x8049a7e8, 0x22b7da7b,
+ 0x5e552d25, 0x5272d237, 0x79d2951c, 0xc60d894c, 0x488cb402, 0x1ba4fe5b, 0xa4b09f6b, 0x1ca815cf,
+ 0xa20c3005, 0x8871df63, 0xb9de2fcb, 0x0cc6c9e9, 0x0beeff53, 0xe3214517, 0xb4542835, 0x9f63293c,
+ 0xee41e729, 0x6e1d2d7c, 0x50045286, 0x1e6685f3, 0xf33401c6, 0x30a22c95, 0x31a70850, 0x60930f13,
+ 0x73f98417, 0xa1269859, 0xec645c44, 0x52c877a9, 0xcdff33a6, 0xa02b1741, 0x7cbad9a2, 0x2180036f,
+ 0x50d99c08, 0xcb3f4861, 0xc26bd765, 0x64a3f6ab, 0x80342676, 0x25a75e7b, 0xe4e6d1fc, 0x20c710e6,
+ 0xcdf0b680, 0x17844d3b, 0x31eef84d, 0x7e0824e4, 0x2ccb49eb, 0x846a3bae, 0x8ff77888, 0xee5d60f6,
+ 0x7af75673, 0x2fdd5cdb, 0xa11631c1, 0x30f66f43, 0xb3faec54, 0x157fd7fa, 0xef8579cc, 0xd152de58,
+ 0xdb2ffd5e, 0x8f32ce19, 0x306af97a, 0x02f03ef8, 0x99319ad5, 0xc242fa0f, 0xa7e3ebb0, 0xc68e4906,
+ 0xb8da230c, 0x80823028, 0xdcdef3c8, 0xd35fb171, 0x088a1bc8, 0xbec0c560, 0x61a3c9e8, 0xbca8f54d,
+ 0xc72feffa, 0x22822e99, 0x82c570b4, 0xd8d94e89, 0x8b1c34bc, 0x301e16e6, 0x273be979, 0xb0ffeaa6,
+ 0x61d9b8c6, 0x00b24869, 0xb7ffce3f, 0x08dc283b, 0x43daf65a, 0xf7e19798, 0x7619b72f, 0x8f1c9ba4,
+ 0xdc8637a0, 0x16a7d3b1, 0x9fc393b7, 0xa7136eeb, 0xc6bcc63e, 0x1a513742, 0xef6828bc, 0x520365d6,
+ 0x2d6a77ab, 0x3527ed4b, 0x821fd216, 0x095c6e2e, 0xdb92f2fb, 0x5eea29cb, 0x145892f5, 0x91584f7f,
+ 0x5483697b, 0x2667a8cc, 0x85196048, 0x8c4bacea, 0x833860d4, 0x0d23e0f9, 0x6c387e8a, 0x0ae6d249,
+ 0xb284600c, 0xd835731d, 0xdcb1c647, 0xac4c56ea, 0x3ebd81b3, 0x230eabb0, 0x6438bc87, 0xf0b5b1fa,
+ 0x8f5ea2b3, 0xfc184642, 0x0a036b7a, 0x4fb089bd, 0x649da589, 0xa345415e, 0x5c038323, 0x3e5d3bb9,
+ 0x43d79572, 0x7e6dd07c, 0x06dfdf1e, 0x6c6cc4ef, 0x7160a539, 0x73bfbe70, 0x83877605, 0x4523ecf1,
+ },
+ {
+ 0x8defc240, 0x25fa5d9f, 0xeb903dbf, 0xe810c907, 0x47607fff, 0x369fe44b, 0x8c1fc644, 0xaececa90,
+ 0xbeb1f9bf, 0xeefbcaea, 0xe8cf1950, 0x51df07ae, 0x920e8806, 0xf0ad0548, 0xe13c8d83, 0x927010d5,
+ 0x11107d9f, 0x07647db9, 0xb2e3e4d4, 0x3d4f285e, 0xb9afa820, 0xfade82e0, 0xa067268b, 0x8272792e,
+ 0x553fb2c0, 0x489ae22b, 0xd4ef9794, 0x125e3fbc, 0x21fffcee, 0x825b1bfd, 0x9255c5ed, 0x1257a240,
+ 0x4e1a8302, 0xbae07fff, 0x528246e7, 0x8e57140e, 0x3373f7bf, 0x8c9f8188, 0xa6fc4ee8, 0xc982b5a5,
+ 0xa8c01db7, 0x579fc264, 0x67094f31, 0xf2bd3f5f, 0x40fff7c1, 0x1fb78dfc, 0x8e6bd2c1, 0x437be59b,
+ 0x99b03dbf, 0xb5dbc64b, 0x638dc0e6, 0x55819d99, 0xa197c81c, 0x4a012d6e, 0xc5884a28, 0xccc36f71,
+ 0xb843c213, 0x6c0743f1, 0x8309893c, 0x0feddd5f, 0x2f7fe850, 0xd7c07f7e, 0x02507fbf, 0x5afb9a04,
+ 0xa747d2d0, 0x1651192e, 0xaf70bf3e, 0x58c31380, 0x5f98302e, 0x727cc3c4, 0x0a0fb402, 0x0f7fef82,
+ 0x8c96fdad, 0x5d2c2aae, 0x8ee99a49, 0x50da88b8, 0x8427f4a0, 0x1eac5790, 0x796fb449, 0x8252dc15,
+ 0xefbd7d9b, 0xa672597d, 0xada840d8, 0x45f54504, 0xfa5d7403, 0xe83ec305, 0x4f91751a, 0x925669c2,
+ 0x23efe941, 0xa903f12e, 0x60270df2, 0x0276e4b6, 0x94fd6574, 0x927985b2, 0x8276dbcb, 0x02778176,
+ 0xf8af918d, 0x4e48f79e, 0x8f616ddf, 0xe29d840e, 0x842f7d83, 0x340ce5c8, 0x96bbb682, 0x93b4b148,
+ 0xef303cab, 0x984faf28, 0x779faf9b, 0x92dc560d, 0x224d1e20, 0x8437aa88, 0x7d29dc96, 0x2756d3dc,
+ 0x8b907cee, 0xb51fd240, 0xe7c07ce3, 0xe566b4a1, 0xc3e9615e, 0x3cf8209d, 0x6094d1e3, 0xcd9ca341,
+ 0x5c76460e, 0x00ea983b, 0xd4d67881, 0xfd47572c, 0xf76cedd9, 0xbda8229c, 0x127dadaa, 0x438a074e,
+ 0x1f97c090, 0x081bdb8a, 0x93a07ebe, 0xb938ca15, 0x97b03cff, 0x3dc2c0f8, 0x8d1ab2ec, 0x64380e51,
+ 0x68cc7bfb, 0xd90f2788, 0x12490181, 0x5de5ffd4, 0xdd7ef86a, 0x76a2e214, 0xb9a40368, 0x925d958f,
+ 0x4b39fffa, 0xba39aee9, 0xa4ffd30b, 0xfaf7933b, 0x6d498623, 0x193cbcfa, 0x27627545, 0x825cf47a,
+ 0x61bd8ba0, 0xd11e42d1, 0xcead04f4, 0x127ea392, 0x10428db7, 0x8272a972, 0x9270c4a8, 0x127de50b,
+ 0x285ba1c8, 0x3c62f44f, 0x35c0eaa5, 0xe805d231, 0x428929fb, 0xb4fcdf82, 0x4fb66a53, 0x0e7dc15b,
+ 0x1f081fab, 0x108618ae, 0xfcfd086d, 0xf9ff2889, 0x694bcc11, 0x236a5cae, 0x12deca4d, 0x2c3f8cc5,
+ 0xd2d02dfe, 0xf8ef5896, 0xe4cf52da, 0x95155b67, 0x494a488c, 0xb9b6a80c, 0x5c8f82bc, 0x89d36b45,
+ 0x3a609437, 0xec00c9a9, 0x44715253, 0x0a874b49, 0xd773bc40, 0x7c34671c, 0x02717ef6, 0x4feb5536,
+ 0xa2d02fff, 0xd2bf60c4, 0xd43f03c0, 0x50b4ef6d, 0x07478cd1, 0x006e1888, 0xa2e53f55, 0xb9e6d4bc,
+ 0xa2048016, 0x97573833, 0xd7207d67, 0xde0f8f3d, 0x72f87b33, 0xabcc4f33, 0x7688c55d, 0x7b00a6b0,
+ 0x947b0001, 0x570075d2, 0xf9bb88f8, 0x8942019e, 0x4264a5ff, 0x856302e0, 0x72dbd92b, 0xee971b69,
+ 0x6ea22fde, 0x5f08ae2b, 0xaf7a616d, 0xe5c98767, 0xcf1febd2, 0x61efc8c2, 0xf1ac2571, 0xcc8239c2,
+ 0x67214cb8, 0xb1e583d1, 0xb7dc3e62, 0x7f10bdce, 0xf90a5c38, 0x0ff0443d, 0x606e6dc6, 0x60543a49,
+ 0x5727c148, 0x2be98a1d, 0x8ab41738, 0x20e1be24, 0xaf96da0f, 0x68458425, 0x99833be5, 0x600d457d,
+ 0x282f9350, 0x8334b362, 0xd91d1120, 0x2b6d8da0, 0x642b1e31, 0x9c305a00, 0x52bce688, 0x1b03588a,
+ 0xf7baefd5, 0x4142ed9c, 0xa4315c11, 0x83323ec5, 0xdfef4636, 0xa133c501, 0xe9d3531c, 0xee353783,
+ },
+ {
+ 0x9db30420, 0x1fb6e9de, 0xa7be7bef, 0xd273a298, 0x4a4f7bdb, 0x64ad8c57, 0x85510443, 0xfa020ed1,
+ 0x7e287aff, 0xe60fb663, 0x095f35a1, 0x79ebf120, 0xfd059d43, 0x6497b7b1, 0xf3641f63, 0x241e4adf,
+ 0x28147f5f, 0x4fa2b8cd, 0xc9430040, 0x0cc32220, 0xfdd30b30, 0xc0a5374f, 0x1d2d00d9, 0x24147b15,
+ 0xee4d111a, 0x0fca5167, 0x71ff904c, 0x2d195ffe, 0x1a05645f, 0x0c13fefe, 0x081b08ca, 0x05170121,
+ 0x80530100, 0xe83e5efe, 0xac9af4f8, 0x7fe72701, 0xd2b8ee5f, 0x06df4261, 0xbb9e9b8a, 0x7293ea25,
+ 0xce84ffdf, 0xf5718801, 0x3dd64b04, 0xa26f263b, 0x7ed48400, 0x547eebe6, 0x446d4ca0, 0x6cf3d6f5,
+ 0x2649abdf, 0xaea0c7f5, 0x36338cc1, 0x503f7e93, 0xd3772061, 0x11b638e1, 0x72500e03, 0xf80eb2bb,
+ 0xabe0502e, 0xec8d77de, 0x57971e81, 0xe14f6746, 0xc9335400, 0x6920318f, 0x081dbb99, 0xffc304a5,
+ 0x4d351805, 0x7f3d5ce3, 0xa6c866c6, 0x5d5bcca9, 0xdaec6fea, 0x9f926f91, 0x9f46222f, 0x3991467d,
+ 0xa5bf6d8e, 0x1143c44f, 0x43958302, 0xd0214eeb, 0x022083b8, 0x3fb6180c, 0x18f8931e, 0x281658e6,
+ 0x26486e3e, 0x8bd78a70, 0x7477e4c1, 0xb506e07c, 0xf32d0a25, 0x79098b02, 0xe4eabb81, 0x28123b23,
+ 0x69dead38, 0x1574ca16, 0xdf871b62, 0x211c40b7, 0xa51a9ef9, 0x0014377b, 0x041e8ac8, 0x09114003,
+ 0xbd59e4d2, 0xe3d156d5, 0x4fe876d5, 0x2f91a340, 0x557be8de, 0x00eae4a7, 0x0ce5c2ec, 0x4db4bba6,
+ 0xe756bdff, 0xdd3369ac, 0xec17b035, 0x06572327, 0x99afc8b0, 0x56c8c391, 0x6b65811c, 0x5e146119,
+ 0x6e85cb75, 0xbe07c002, 0xc2325577, 0x893ff4ec, 0x5bbfc92d, 0xd0ec3b25, 0xb7801ab7, 0x8d6d3b24,
+ 0x20c763ef, 0xc366a5fc, 0x9c382880, 0x0ace3205, 0xaac9548a, 0xeca1d7c7, 0x041afa32, 0x1d16625a,
+ 0x6701902c, 0x9b757a54, 0x31d477f7, 0x9126b031, 0x36cc6fdb, 0xc70b8b46, 0xd9e66a48, 0x56e55a79,
+ 0x026a4ceb, 0x52437eff, 0x2f8f76b4, 0x0df980a5, 0x8674cde3, 0xedda04eb, 0x17a9be04, 0x2c18f4df,
+ 0xb7747f9d, 0xab2af7b4, 0xefc34d20, 0x2e096b7c, 0x1741a254, 0xe5b6a035, 0x213d42f6, 0x2c1c7c26,
+ 0x61c2f50f, 0x6552daf9, 0xd2c231f8, 0x25130f69, 0xd8167fa2, 0x0418f2c8, 0x001a96a6, 0x0d1526ab,
+ 0x63315c21, 0x5e0a72ec, 0x49bafefd, 0x187908d9, 0x8d0dbd86, 0x311170a7, 0x3e9b640c, 0xcc3e10d7,
+ 0xd5cad3b6, 0x0caec388, 0xf73001e1, 0x6c728aff, 0x71eae2a1, 0x1f9af36e, 0xcfcbd12f, 0xc1de8417,
+ 0xac07be6b, 0xcb44a1d8, 0x8b9b0f56, 0x013988c3, 0xb1c52fca, 0xb4be31cd, 0xd8782806, 0x12a3a4e2,
+ 0x6f7de532, 0x58fd7eb6, 0xd01ee900, 0x24adffc2, 0xf4990fc5, 0x9711aac5, 0x001d7b95, 0x82e5e7d2,
+ 0x109873f6, 0x00613096, 0xc32d9521, 0xada121ff, 0x29908415, 0x7fbb977f, 0xaf9eb3db, 0x29c9ed2a,
+ 0x5ce2a465, 0xa730f32c, 0xd0aa3fe8, 0x8a5cc091, 0xd49e2ce7, 0x0ce454a9, 0xd60acd86, 0x015f1919,
+ 0x77079103, 0xdea03af6, 0x78a8565e, 0xdee356df, 0x21f05cbe, 0x8b75e387, 0xb3c50651, 0xb8a5c3ef,
+ 0xd8eeb6d2, 0xe523be77, 0xc2154529, 0x2f69efdf, 0xafe67afb, 0xf470c4b2, 0xf3e0eb5b, 0xd6cc9876,
+ 0x39e4460c, 0x1fda8538, 0x1987832f, 0xca007367, 0xa99144f8, 0x296b299e, 0x492fc295, 0x9266beab,
+ 0xb5676e69, 0x9bd3ddda, 0xdf7e052f, 0xdb25701c, 0x1b5e51ee, 0xf65324e6, 0x6afce36c, 0x0316cc04,
+ 0x8644213e, 0xb7dc59d0, 0x7965291f, 0xccd6fd43, 0x41823979, 0x932bcdf6, 0xb657c34d, 0x4edfd282,
+ 0x7ae5290c, 0x3cb9536b, 0x851e20fe, 0x9833557e, 0x13ecf0b0, 0xd3ffb372, 0x3f85c5c1, 0x0aef7ed2,
+ },
+ {
+ 0x7ec90c04, 0x2c6e74b9, 0x9b0e66df, 0xa6337911, 0xb86a7fff, 0x1dd358f5, 0x44dd9d44, 0x1731167f,
+ 0x08fbf1fa, 0xe7f511cc, 0xd2051b00, 0x735aba00, 0x2ab722d8, 0x386381cb, 0xacf6243a, 0x69befd7a,
+ 0xe6a2e77f, 0xf0c720cd, 0xc4494816, 0xccf5c180, 0x38851640, 0x15b0a848, 0xe68b18cb, 0x4caadeff,
+ 0x5f480a01, 0x0412b2aa, 0x259814fc, 0x41d0efe2, 0x4e40b48d, 0x248eb6fb, 0x8dba1cfe, 0x41a99b02,
+ 0x1a550a04, 0xba8f65cb, 0x7251f4e7, 0x95a51725, 0xc106ecd7, 0x97a5980a, 0xc539b9aa, 0x4d79fe6a,
+ 0xf2f3f763, 0x68af8040, 0xed0c9e56, 0x11b4958b, 0xe1eb5a88, 0x8709e6b0, 0xd7e07156, 0x4e29fea7,
+ 0x6366e52d, 0x02d1c000, 0xc4ac8e05, 0x9377f571, 0x0c05372a, 0x578535f2, 0x2261be02, 0xd642a0c9,
+ 0xdf13a280, 0x74b55bd2, 0x682199c0, 0xd421e5ec, 0x53fb3ce8, 0xc8adedb3, 0x28a87fc9, 0x3d959981,
+ 0x5c1ff900, 0xfe38d399, 0x0c4eff0b, 0x062407ea, 0xaa2f4fb1, 0x4fb96976, 0x90c79505, 0xb0a8a774,
+ 0xef55a1ff, 0xe59ca2c2, 0xa6b62d27, 0xe66a4263, 0xdf65001f, 0x0ec50966, 0xdfdd55bc, 0x29de0655,
+ 0x911e739a, 0x17af8975, 0x32c7911c, 0x89f89468, 0x0d01e980, 0x524755f4, 0x03b63cc9, 0x0cc844b2,
+ 0xbcf3f0aa, 0x87ac36e9, 0xe53a7426, 0x01b3d82b, 0x1a9e7449, 0x64ee2d7e, 0xcddbb1da, 0x01c94910,
+ 0xb868bf80, 0x0d26f3fd, 0x9342ede7, 0x04a5c284, 0x636737b6, 0x50f5b616, 0xf24766e3, 0x8eca36c1,
+ 0x136e05db, 0xfef18391, 0xfb887a37, 0xd6e7f7d4, 0xc7fb7dc9, 0x3063fcdf, 0xb6f589de, 0xec2941da,
+ 0x26e46695, 0xb7566419, 0xf654efc5, 0xd08d58b7, 0x48925401, 0xc1bacb7f, 0xe5ff550f, 0xb6083049,
+ 0x5bb5d0e8, 0x87d72e5a, 0xab6a6ee1, 0x223a66ce, 0xc62bf3cd, 0x9e0885f9, 0x68cb3e47, 0x086c010f,
+ 0xa21de820, 0xd18b69de, 0xf3f65777, 0xfa02c3f6, 0x407edac3, 0xcbb3d550, 0x1793084d, 0xb0d70eba,
+ 0x0ab378d5, 0xd951fb0c, 0xded7da56, 0x4124bbe4, 0x94ca0b56, 0x0f5755d1, 0xe0e1e56e, 0x6184b5be,
+ 0x580a249f, 0x94f74bc0, 0xe327888e, 0x9f7b5561, 0xc3dc0280, 0x05687715, 0x646c6bd7, 0x44904db3,
+ 0x66b4f0a3, 0xc0f1648a, 0x697ed5af, 0x49e92ff6, 0x309e374f, 0x2cb6356a, 0x85808573, 0x4991f840,
+ 0x76f0ae02, 0x083be84d, 0x28421c9a, 0x44489406, 0x736e4cb8, 0xc1092910, 0x8bc95fc6, 0x7d869cf4,
+ 0x134f616f, 0x2e77118d, 0xb31b2be1, 0xaa90b472, 0x3ca5d717, 0x7d161bba, 0x9cad9010, 0xaf462ba2,
+ 0x9fe459d2, 0x45d34559, 0xd9f2da13, 0xdbc65487, 0xf3e4f94e, 0x176d486f, 0x097c13ea, 0x631da5c7,
+ 0x445f7382, 0x175683f4, 0xcdc66a97, 0x70be0288, 0xb3cdcf72, 0x6e5dd2f3, 0x20936079, 0x459b80a5,
+ 0xbe60e2db, 0xa9c23101, 0xeba5315c, 0x224e42f2, 0x1c5c1572, 0xf6721b2c, 0x1ad2fff3, 0x8c25404e,
+ 0x324ed72f, 0x4067b7fd, 0x0523138e, 0x5ca3bc78, 0xdc0fd66e, 0x75922283, 0x784d6b17, 0x58ebb16e,
+ 0x44094f85, 0x3f481d87, 0xfcfeae7b, 0x77b5ff76, 0x8c2302bf, 0xaaf47556, 0x5f46b02a, 0x2b092801,
+ 0x3d38f5f7, 0x0ca81f36, 0x52af4a8a, 0x66d5e7c0, 0xdf3b0874, 0x95055110, 0x1b5ad7a8, 0xf61ed5ad,
+ 0x6cf6e479, 0x20758184, 0xd0cefa65, 0x88f7be58, 0x4a046826, 0x0ff6f8f3, 0xa09c7f70, 0x5346aba0,
+ 0x5ce96c28, 0xe176eda3, 0x6bac307f, 0x376829d2, 0x85360fa9, 0x17e3fe2a, 0x24b79767, 0xf5a96b20,
+ 0xd6cd2595, 0x68ff1ebf, 0x7555442c, 0xf19f06be, 0xf9e0659a, 0xeeb9491d, 0x34010718, 0xbb30cab8,
+ 0xe822fe15, 0x88570983, 0x750e6249, 0xda627e55, 0x5e76ffa8, 0xb1534546, 0x6d47de08, 0xefe9e7d4,
+ },
+ {
+ 0xf6fa8f9d, 0x2cac6ce1, 0x4ca34867, 0xe2337f7c, 0x95db08e7, 0x016843b4, 0xeced5cbc, 0x325553ac,
+ 0xbf9f0960, 0xdfa1e2ed, 0x83f0579d, 0x63ed86b9, 0x1ab6a6b8, 0xde5ebe39, 0xf38ff732, 0x8989b138,
+ 0x33f14961, 0xc01937bd, 0xf506c6da, 0xe4625e7e, 0xa308ea99, 0x4e23e33c, 0x79cbd7cc, 0x48a14367,
+ 0xa3149619, 0xfec94bd5, 0xa114174a, 0xeaa01866, 0xa084db2d, 0x09a8486f, 0xa888614a, 0x2900af98,
+ 0x01665991, 0xe1992863, 0xc8f30c60, 0x2e78ef3c, 0xd0d51932, 0xcf0fec14, 0xf7ca07d2, 0xd0a82072,
+ 0xfd41197e, 0x9305a6b0, 0xe86be3da, 0x74bed3cd, 0x372da53c, 0x4c7f4448, 0xdab5d440, 0x6dba0ec3,
+ 0x083919a7, 0x9fbaeed9, 0x49dbcfb0, 0x4e670c53, 0x5c3d9c01, 0x64bdb941, 0x2c0e636a, 0xba7dd9cd,
+ 0xea6f7388, 0xe70bc762, 0x35f29adb, 0x5c4cdd8d, 0xf0d48d8c, 0xb88153e2, 0x08a19866, 0x1ae2eac8,
+ 0x284caf89, 0xaa928223, 0x9334be53, 0x3b3a21bf, 0x16434be3, 0x9aea3906, 0xefe8c36e, 0xf890cdd9,
+ 0x80226dae, 0xc340a4a3, 0xdf7e9c09, 0xa694a807, 0x5b7c5ecc, 0x221db3a6, 0x9a69a02f, 0x68818a54,
+ 0xceb2296f, 0x53c0843a, 0xfe893655, 0x25bfe68a, 0xb4628abc, 0xcf222ebf, 0x25ac6f48, 0xa9a99387,
+ 0x53bddb65, 0xe76ffbe7, 0xe967fd78, 0x0ba93563, 0x8e342bc1, 0xe8a11be9, 0x4980740d, 0xc8087dfc,
+ 0x8de4bf99, 0xa11101a0, 0x7fd37975, 0xda5a26c0, 0xe81f994f, 0x9528cd89, 0xfd339fed, 0xb87834bf,
+ 0x5f04456d, 0x22258698, 0xc9c4c83b, 0x2dc156be, 0x4f628daa, 0x57f55ec5, 0xe2220abe, 0xd2916ebf,
+ 0x4ec75b95, 0x24f2c3c0, 0x42d15d99, 0xcd0d7fa0, 0x7b6e27ff, 0xa8dc8af0, 0x7345c106, 0xf41e232f,
+ 0x35162386, 0xe6ea8926, 0x3333b094, 0x157ec6f2, 0x372b74af, 0x692573e4, 0xe9a9d848, 0xf3160289,
+ 0x3a62ef1d, 0xa787e238, 0xf3a5f676, 0x74364853, 0x20951063, 0x4576698d, 0xb6fad407, 0x592af950,
+ 0x36f73523, 0x4cfb6e87, 0x7da4cec0, 0x6c152daa, 0xcb0396a8, 0xc50dfe5d, 0xfcd707ab, 0x0921c42f,
+ 0x89dff0bb, 0x5fe2be78, 0x448f4f33, 0x754613c9, 0x2b05d08d, 0x48b9d585, 0xdc049441, 0xc8098f9b,
+ 0x7dede786, 0xc39a3373, 0x42410005, 0x6a091751, 0x0ef3c8a6, 0x890072d6, 0x28207682, 0xa9a9f7be,
+ 0xbf32679d, 0xd45b5b75, 0xb353fd00, 0xcbb0e358, 0x830f220a, 0x1f8fb214, 0xd372cf08, 0xcc3c4a13,
+ 0x8cf63166, 0x061c87be, 0x88c98f88, 0x6062e397, 0x47cf8e7a, 0xb6c85283, 0x3cc2acfb, 0x3fc06976,
+ 0x4e8f0252, 0x64d8314d, 0xda3870e3, 0x1e665459, 0xc10908f0, 0x513021a5, 0x6c5b68b7, 0x822f8aa0,
+ 0x3007cd3e, 0x74719eef, 0xdc872681, 0x073340d4, 0x7e432fd9, 0x0c5ec241, 0x8809286c, 0xf592d891,
+ 0x08a930f6, 0x957ef305, 0xb7fbffbd, 0xc266e96f, 0x6fe4ac98, 0xb173ecc0, 0xbc60b42a, 0x953498da,
+ 0xfba1ae12, 0x2d4bd736, 0x0f25faab, 0xa4f3fceb, 0xe2969123, 0x257f0c3d, 0x9348af49, 0x361400bc,
+ 0xe8816f4a, 0x3814f200, 0xa3f94043, 0x9c7a54c2, 0xbc704f57, 0xda41e7f9, 0xc25ad33a, 0x54f4a084,
+ 0xb17f5505, 0x59357cbe, 0xedbd15c8, 0x7f97c5ab, 0xba5ac7b5, 0xb6f6deaf, 0x3a479c3a, 0x5302da25,
+ 0x653d7e6a, 0x54268d49, 0x51a477ea, 0x5017d55b, 0xd7d25d88, 0x44136c76, 0x0404a8c8, 0xb8e5a121,
+ 0xb81a928a, 0x60ed5869, 0x97c55b96, 0xeaec991b, 0x29935913, 0x01fdb7f1, 0x088e8dfa, 0x9ab6f6f5,
+ 0x3b4cbf9f, 0x4a5de3ab, 0xe6051d35, 0xa0e1d855, 0xd36b4cf1, 0xf544edeb, 0xb0e93524, 0xbebb8fbd,
+ 0xa2d762cf, 0x49c92f54, 0x38b5f331, 0x7128a454, 0x48392905, 0xa65b1db8, 0x851c97bd, 0xd675cf2f,
+ },
+ {
+ 0x85e04019, 0x332bf567, 0x662dbfff, 0xcfc65693, 0x2a8d7f6f, 0xab9bc912, 0xde6008a1, 0x2028da1f,
+ 0x0227bce7, 0x4d642916, 0x18fac300, 0x50f18b82, 0x2cb2cb11, 0xb232e75c, 0x4b3695f2, 0xb28707de,
+ 0xa05fbcf6, 0xcd4181e9, 0xe150210c, 0xe24ef1bd, 0xb168c381, 0xfde4e789, 0x5c79b0d8, 0x1e8bfd43,
+ 0x4d495001, 0x38be4341, 0x913cee1d, 0x92a79c3f, 0x089766be, 0xbaeeadf4, 0x1286becf, 0xb6eacb19,
+ 0x2660c200, 0x7565bde4, 0x64241f7a, 0x8248dca9, 0xc3b3ad66, 0x28136086, 0x0bd8dfa8, 0x356d1cf2,
+ 0x107789be, 0xb3b2e9ce, 0x0502aa8f, 0x0bc0351e, 0x166bf52a, 0xeb12ff82, 0xe3486911, 0xd34d7516,
+ 0x4e7b3aff, 0x5f43671b, 0x9cf6e037, 0x4981ac83, 0x334266ce, 0x8c9341b7, 0xd0d854c0, 0xcb3a6c88,
+ 0x47bc2829, 0x4725ba37, 0xa66ad22b, 0x7ad61f1e, 0x0c5cbafa, 0x4437f107, 0xb6e79962, 0x42d2d816,
+ 0x0a961288, 0xe1a5c06e, 0x13749e67, 0x72fc081a, 0xb1d139f7, 0xf9583745, 0xcf19df58, 0xbec3f756,
+ 0xc06eba30, 0x07211b24, 0x45c28829, 0xc95e317f, 0xbc8ec511, 0x38bc46e9, 0xc6e6fa14, 0xbae8584a,
+ 0xad4ebc46, 0x468f508b, 0x7829435f, 0xf124183b, 0x821dba9f, 0xaff60ff4, 0xea2c4e6d, 0x16e39264,
+ 0x92544a8b, 0x009b4fc3, 0xaba68ced, 0x9ac96f78, 0x06a5b79a, 0xb2856e6e, 0x1aec3ca9, 0xbe838688,
+ 0x0e0804e9, 0x55f1be56, 0xe7e5363b, 0xb3a1f25d, 0xf7debb85, 0x61fe033c, 0x16746233, 0x3c034c28,
+ 0xda6d0c74, 0x79aac56c, 0x3ce4e1ad, 0x51f0c802, 0x98f8f35a, 0x1626a49f, 0xeed82b29, 0x1d382fe3,
+ 0x0c4fb99a, 0xbb325778, 0x3ec6d97b, 0x6e77a6a9, 0xcb658b5c, 0xd45230c7, 0x2bd1408b, 0x60c03eb7,
+ 0xb9068d78, 0xa33754f4, 0xf430c87d, 0xc8a71302, 0xb96d8c32, 0xebd4e7be, 0xbe8b9d2d, 0x7979fb06,
+ 0xe7225308, 0x8b75cf77, 0x11ef8da4, 0xe083c858, 0x8d6b786f, 0x5a6317a6, 0xfa5cf7a0, 0x5dda0033,
+ 0xf28ebfb0, 0xf5b9c310, 0xa0eac280, 0x08b9767a, 0xa3d9d2b0, 0x79d34217, 0x021a718d, 0x9ac6336a,
+ 0x2711fd60, 0x438050e3, 0x069908a8, 0x3d7fedc4, 0x826d2bef, 0x4eeb8476, 0x488dcf25, 0x36c9d566,
+ 0x28e74e41, 0xc2610aca, 0x3d49a9cf, 0xbae3b9df, 0xb65f8de6, 0x92aeaf64, 0x3ac7d5e6, 0x9ea80509,
+ 0xf22b017d, 0xa4173f70, 0xdd1e16c3, 0x15e0d7f9, 0x50b1b887, 0x2b9f4fd5, 0x625aba82, 0x6a017962,
+ 0x2ec01b9c, 0x15488aa9, 0xd716e740, 0x40055a2c, 0x93d29a22, 0xe32dbf9a, 0x058745b9, 0x3453dc1e,
+ 0xd699296e, 0x496cff6f, 0x1c9f4986, 0xdfe2ed07, 0xb87242d1, 0x19de7eae, 0x053e561a, 0x15ad6f8c,
+ 0x66626c1c, 0x7154c24c, 0xea082b2a, 0x93eb2939, 0x17dcb0f0, 0x58d4f2ae, 0x9ea294fb, 0x52cf564c,
+ 0x9883fe66, 0x2ec40581, 0x763953c3, 0x01d6692e, 0xd3a0c108, 0xa1e7160e, 0xe4f2dfa6, 0x693ed285,
+ 0x74904698, 0x4c2b0edd, 0x4f757656, 0x5d393378, 0xa132234f, 0x3d321c5d, 0xc3f5e194, 0x4b269301,
+ 0xc79f022f, 0x3c997e7e, 0x5e4f9504, 0x3ffafbbd, 0x76f7ad0e, 0x296693f4, 0x3d1fce6f, 0xc61e45be,
+ 0xd3b5ab34, 0xf72bf9b7, 0x1b0434c0, 0x4e72b567, 0x5592a33d, 0xb5229301, 0xcfd2a87f, 0x60aeb767,
+ 0x1814386b, 0x30bcc33d, 0x38a0c07d, 0xfd1606f2, 0xc363519b, 0x589dd390, 0x5479f8e6, 0x1cb8d647,
+ 0x97fd61a9, 0xea7759f4, 0x2d57539d, 0x569a58cf, 0xe84e63ad, 0x462e1b78, 0x6580f87e, 0xf3817914,
+ 0x91da55f4, 0x40a230f3, 0xd1988f35, 0xb6e318d2, 0x3ffa50bc, 0x3d40f021, 0xc3c0bdae, 0x4958c24c,
+ 0x518f36b2, 0x84b1d370, 0x0fedce83, 0x878ddada, 0xf2a279c7, 0x94e01be8, 0x90716f4b, 0x954b8aa3,
+ },
+ {
+ 0xe216300d, 0xbbddfffc, 0xa7ebdabd, 0x35648095, 0x7789f8b7, 0xe6c1121b, 0x0e241600, 0x052ce8b5,
+ 0x11a9cfb0, 0xe5952f11, 0xece7990a, 0x9386d174, 0x2a42931c, 0x76e38111, 0xb12def3a, 0x37ddddfc,
+ 0xde9adeb1, 0x0a0cc32c, 0xbe197029, 0x84a00940, 0xbb243a0f, 0xb4d137cf, 0xb44e79f0, 0x049eedfd,
+ 0x0b15a15d, 0x480d3168, 0x8bbbde5a, 0x669ded42, 0xc7ece831, 0x3f8f95e7, 0x72df191b, 0x7580330d,
+ 0x94074251, 0x5c7dcdfa, 0xabbe6d63, 0xaa402164, 0xb301d40a, 0x02e7d1ca, 0x53571dae, 0x7a3182a2,
+ 0x12a8ddec, 0xfdaa335d, 0x176f43e8, 0x71fb46d4, 0x38129022, 0xce949ad4, 0xb84769ad, 0x965bd862,
+ 0x82f3d055, 0x66fb9767, 0x15b80b4e, 0x1d5b47a0, 0x4cfde06f, 0xc28ec4b8, 0x57e8726e, 0x647a78fc,
+ 0x99865d44, 0x608bd593, 0x6c200e03, 0x39dc5ff6, 0x5d0b00a3, 0xae63aff2, 0x7e8bd632, 0x70108c0c,
+ 0xbbd35049, 0x2998df04, 0x980cf42a, 0x9b6df491, 0x9e7edd53, 0x06918548, 0x58cb7e07, 0x3b74ef2e,
+ 0x522fffb1, 0xd24708cc, 0x1c7e27cd, 0xa4eb215b, 0x3cf1d2e2, 0x19b47a38, 0x424f7618, 0x35856039,
+ 0x9d17dee7, 0x27eb35e6, 0xc9aff67b, 0x36baf5b8, 0x09c467cd, 0xc18910b1, 0xe11dbf7b, 0x06cd1af8,
+ 0x7170c608, 0x2d5e3354, 0xd4de495a, 0x64c6d006, 0xbcc0c62c, 0x3dd00db3, 0x708f8f34, 0x77d51b42,
+ 0x264f620f, 0x24b8d2bf, 0x15c1b79e, 0x46a52564, 0xf8d7e54e, 0x3e378160, 0x7895cda5, 0x859c15a5,
+ 0xe6459788, 0xc37bc75f, 0xdb07ba0c, 0x0676a3ab, 0x7f229b1e, 0x31842e7b, 0x24259fd7, 0xf8bef472,
+ 0x835ffcb8, 0x6df4c1f2, 0x96f5b195, 0xfd0af0fc, 0xb0fe134c, 0xe2506d3d, 0x4f9b12ea, 0xf215f225,
+ 0xa223736f, 0x9fb4c428, 0x25d04979, 0x34c713f8, 0xc4618187, 0xea7a6e98, 0x7cd16efc, 0x1436876c,
+ 0xf1544107, 0xbedeee14, 0x56e9af27, 0xa04aa441, 0x3cf7c899, 0x92ecbae6, 0xdd67016d, 0x151682eb,
+ 0xa842eedf, 0xfdba60b4, 0xf1907b75, 0x20e3030f, 0x24d8c29e, 0xe139673b, 0xefa63fb8, 0x71873054,
+ 0xb6f2cf3b, 0x9f326442, 0xcb15a4cc, 0xb01a4504, 0xf1e47d8d, 0x844a1be5, 0xbae7dfdc, 0x42cbda70,
+ 0xcd7dae0a, 0x57e85b7a, 0xd53f5af6, 0x20cf4d8c, 0xcea4d428, 0x79d130a4, 0x3486ebfb, 0x33d3cddc,
+ 0x77853b53, 0x37effcb5, 0xc5068778, 0xe580b3e6, 0x4e68b8f4, 0xc5c8b37e, 0x0d809ea2, 0x398feb7c,
+ 0x132a4f94, 0x43b7950e, 0x2fee7d1c, 0x223613bd, 0xdd06caa2, 0x37df932b, 0xc4248289, 0xacf3ebc3,
+ 0x5715f6b7, 0xef3478dd, 0xf267616f, 0xc148cbe4, 0x9052815e, 0x5e410fab, 0xb48a2465, 0x2eda7fa4,
+ 0xe87b40e4, 0xe98ea084, 0x5889e9e1, 0xefd390fc, 0xdd07d35b, 0xdb485694, 0x38d7e5b2, 0x57720101,
+ 0x730edebc, 0x5b643113, 0x94917e4f, 0x503c2fba, 0x646f1282, 0x7523d24a, 0xe0779695, 0xf9c17a8f,
+ 0x7a5b2121, 0xd187b896, 0x29263a4d, 0xba510cdf, 0x81f47c9f, 0xad1163ed, 0xea7b5965, 0x1a00726e,
+ 0x11403092, 0x00da6d77, 0x4a0cdd61, 0xad1f4603, 0x605bdfb0, 0x9eedc364, 0x22ebe6a8, 0xcee7d28a,
+ 0xa0e736a0, 0x5564a6b9, 0x10853209, 0xc7eb8f37, 0x2de705ca, 0x8951570f, 0xdf09822b, 0xbd691a6c,
+ 0xaa12e4f2, 0x87451c0f, 0xe0f6a27a, 0x3ada4819, 0x4cf1764f, 0x0d771c2b, 0x67cdb156, 0x350d8384,
+ 0x5938fa0f, 0x42399ef3, 0x36997b07, 0x0e84093d, 0x4aa93e61, 0x8360d87b, 0x1fa98b0c, 0x1149382c,
+ 0xe97625a5, 0x0614d1b7, 0x0e25244b, 0x0c768347, 0x589e8d82, 0x0d2059d1, 0xa466bb1e, 0xf8da0a82,
+ 0x04f19130, 0xba6e4ec0, 0x99265164, 0x1ee7230d, 0x50b2ad80, 0xeaee6801, 0x8db2a283, 0xea8bf59e,
+ },
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package armor implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is
+// very similar to PEM except that it has an additional CRC checksum.
+package armor // import "golang.org/x/crypto/openpgp/armor"
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/base64"
+ "golang.org/x/crypto/openpgp/errors"
+ "io"
+)
+
+// A Block represents an OpenPGP armored structure.
+//
+// The encoded form is:
+// -----BEGIN Type-----
+// Headers
+//
+// base64-encoded Bytes
+// '=' base64 encoded checksum
+// -----END Type-----
+// where Headers is a possibly empty sequence of Key: Value lines.
+//
+// Since the armored data can be very large, this package presents a streaming
+// interface.
+type Block struct {
+ Type string // The type, taken from the preamble (i.e. "PGP SIGNATURE").
+ Header map[string]string // Optional headers.
+ Body io.Reader // A Reader from which the contents can be read
+ lReader lineReader
+ oReader openpgpReader
+}
+
+var ArmorCorrupt error = errors.StructuralError("armor invalid")
+
+const crc24Init = 0xb704ce
+const crc24Poly = 0x1864cfb
+const crc24Mask = 0xffffff
+
+// crc24 calculates the OpenPGP checksum as specified in RFC 4880, section 6.1
+func crc24(crc uint32, d []byte) uint32 {
+ for _, b := range d {
+ crc ^= uint32(b) << 16
+ for i := 0; i < 8; i++ {
+ crc <<= 1
+ if crc&0x1000000 != 0 {
+ crc ^= crc24Poly
+ }
+ }
+ }
+ return crc
+}
+
+var armorStart = []byte("-----BEGIN ")
+var armorEnd = []byte("-----END ")
+var armorEndOfLine = []byte("-----")
+
+// lineReader wraps a line based reader. It watches for the end of an armor
+// block and records the expected CRC value.
+type lineReader struct {
+ in *bufio.Reader
+ buf []byte
+ eof bool
+ crc uint32
+}
+
+func (l *lineReader) Read(p []byte) (n int, err error) {
+ if l.eof {
+ return 0, io.EOF
+ }
+
+ if len(l.buf) > 0 {
+ n = copy(p, l.buf)
+ l.buf = l.buf[n:]
+ return
+ }
+
+ line, isPrefix, err := l.in.ReadLine()
+ if err != nil {
+ return
+ }
+ if isPrefix {
+ return 0, ArmorCorrupt
+ }
+
+ if len(line) == 5 && line[0] == '=' {
+ // This is the checksum line
+ var expectedBytes [3]byte
+ var m int
+ m, err = base64.StdEncoding.Decode(expectedBytes[0:], line[1:])
+ if m != 3 || err != nil {
+ return
+ }
+ l.crc = uint32(expectedBytes[0])<<16 |
+ uint32(expectedBytes[1])<<8 |
+ uint32(expectedBytes[2])
+
+ line, _, err = l.in.ReadLine()
+ if err != nil && err != io.EOF {
+ return
+ }
+ if !bytes.HasPrefix(line, armorEnd) {
+ return 0, ArmorCorrupt
+ }
+
+ l.eof = true
+ return 0, io.EOF
+ }
+
+ if len(line) > 96 {
+ return 0, ArmorCorrupt
+ }
+
+ n = copy(p, line)
+ bytesToSave := len(line) - n
+ if bytesToSave > 0 {
+ if cap(l.buf) < bytesToSave {
+ l.buf = make([]byte, 0, bytesToSave)
+ }
+ l.buf = l.buf[0:bytesToSave]
+ copy(l.buf, line[n:])
+ }
+
+ return
+}
+
+// openpgpReader passes Read calls to the underlying base64 decoder, but keeps
+// a running CRC of the resulting data and checks the CRC against the value
+// found by the lineReader at EOF.
+type openpgpReader struct {
+ lReader *lineReader
+ b64Reader io.Reader
+ currentCRC uint32
+}
+
+func (r *openpgpReader) Read(p []byte) (n int, err error) {
+ n, err = r.b64Reader.Read(p)
+ r.currentCRC = crc24(r.currentCRC, p[:n])
+
+ if err == io.EOF {
+ if r.lReader.crc != uint32(r.currentCRC&crc24Mask) {
+ return 0, ArmorCorrupt
+ }
+ }
+
+ return
+}
+
+// Decode reads a PGP armored block from the given Reader. It will ignore
+// leading garbage. If it doesn't find a block, it will return nil, io.EOF. The
+// given Reader is not usable after calling this function: an arbitrary amount
+// of data may have been read past the end of the block.
+func Decode(in io.Reader) (p *Block, err error) {
+ r := bufio.NewReaderSize(in, 100)
+ var line []byte
+ ignoreNext := false
+
+TryNextBlock:
+ p = nil
+
+ // Skip leading garbage
+ for {
+ ignoreThis := ignoreNext
+ line, ignoreNext, err = r.ReadLine()
+ if err != nil {
+ return
+ }
+ if ignoreNext || ignoreThis {
+ continue
+ }
+ line = bytes.TrimSpace(line)
+ if len(line) > len(armorStart)+len(armorEndOfLine) && bytes.HasPrefix(line, armorStart) {
+ break
+ }
+ }
+
+ p = new(Block)
+ p.Type = string(line[len(armorStart) : len(line)-len(armorEndOfLine)])
+ p.Header = make(map[string]string)
+ nextIsContinuation := false
+ var lastKey string
+
+ // Read headers
+ for {
+ isContinuation := nextIsContinuation
+ line, nextIsContinuation, err = r.ReadLine()
+ if err != nil {
+ p = nil
+ return
+ }
+ if isContinuation {
+ p.Header[lastKey] += string(line)
+ continue
+ }
+ line = bytes.TrimSpace(line)
+ if len(line) == 0 {
+ break
+ }
+
+ i := bytes.Index(line, []byte(": "))
+ if i == -1 {
+ goto TryNextBlock
+ }
+ lastKey = string(line[:i])
+ p.Header[lastKey] = string(line[i+2:])
+ }
+
+ p.lReader.in = r
+ p.oReader.currentCRC = crc24Init
+ p.oReader.lReader = &p.lReader
+ p.oReader.b64Reader = base64.NewDecoder(base64.StdEncoding, &p.lReader)
+ p.Body = &p.oReader
+
+ return
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package armor
+
+import (
+ "encoding/base64"
+ "io"
+)
+
+var armorHeaderSep = []byte(": ")
+var blockEnd = []byte("\n=")
+var newline = []byte("\n")
+var armorEndOfLineOut = []byte("-----\n")
+
+// writeSlices writes its arguments to the given Writer.
+func writeSlices(out io.Writer, slices ...[]byte) (err error) {
+ for _, s := range slices {
+ _, err = out.Write(s)
+ if err != nil {
+ return err
+ }
+ }
+ return
+}
+
+// lineBreaker breaks data across several lines, all of the same byte length
+// (except possibly the last). Lines are broken with a single '\n'.
+type lineBreaker struct {
+ lineLength int
+ line []byte
+ used int
+ out io.Writer
+ haveWritten bool
+}
+
+func newLineBreaker(out io.Writer, lineLength int) *lineBreaker {
+ return &lineBreaker{
+ lineLength: lineLength,
+ line: make([]byte, lineLength),
+ used: 0,
+ out: out,
+ }
+}
+
+func (l *lineBreaker) Write(b []byte) (n int, err error) {
+ n = len(b)
+
+ if n == 0 {
+ return
+ }
+
+ if l.used == 0 && l.haveWritten {
+ _, err = l.out.Write([]byte{'\n'})
+ if err != nil {
+ return
+ }
+ }
+
+ if l.used+len(b) < l.lineLength {
+ l.used += copy(l.line[l.used:], b)
+ return
+ }
+
+ l.haveWritten = true
+ _, err = l.out.Write(l.line[0:l.used])
+ if err != nil {
+ return
+ }
+ excess := l.lineLength - l.used
+ l.used = 0
+
+ _, err = l.out.Write(b[0:excess])
+ if err != nil {
+ return
+ }
+
+ _, err = l.Write(b[excess:])
+ return
+}
+
+func (l *lineBreaker) Close() (err error) {
+ if l.used > 0 {
+ _, err = l.out.Write(l.line[0:l.used])
+ if err != nil {
+ return
+ }
+ }
+
+ return
+}
+
+// encoding keeps track of a running CRC24 over the data which has been written
+// to it and outputs a OpenPGP checksum when closed, followed by an armor
+// trailer.
+//
+// It's built into a stack of io.Writers:
+// encoding -> base64 encoder -> lineBreaker -> out
+type encoding struct {
+ out io.Writer
+ breaker *lineBreaker
+ b64 io.WriteCloser
+ crc uint32
+ blockType []byte
+}
+
+func (e *encoding) Write(data []byte) (n int, err error) {
+ e.crc = crc24(e.crc, data)
+ return e.b64.Write(data)
+}
+
+func (e *encoding) Close() (err error) {
+ err = e.b64.Close()
+ if err != nil {
+ return
+ }
+ e.breaker.Close()
+
+ var checksumBytes [3]byte
+ checksumBytes[0] = byte(e.crc >> 16)
+ checksumBytes[1] = byte(e.crc >> 8)
+ checksumBytes[2] = byte(e.crc)
+
+ var b64ChecksumBytes [4]byte
+ base64.StdEncoding.Encode(b64ChecksumBytes[:], checksumBytes[:])
+
+ return writeSlices(e.out, blockEnd, b64ChecksumBytes[:], newline, armorEnd, e.blockType, armorEndOfLine)
+}
+
+// Encode returns a WriteCloser which will encode the data written to it in
+// OpenPGP armor.
+func Encode(out io.Writer, blockType string, headers map[string]string) (w io.WriteCloser, err error) {
+ bType := []byte(blockType)
+ err = writeSlices(out, armorStart, bType, armorEndOfLineOut)
+ if err != nil {
+ return
+ }
+
+ for k, v := range headers {
+ err = writeSlices(out, []byte(k), armorHeaderSep, []byte(v), newline)
+ if err != nil {
+ return
+ }
+ }
+
+ _, err = out.Write(newline)
+ if err != nil {
+ return
+ }
+
+ e := &encoding{
+ out: out,
+ breaker: newLineBreaker(out, 64),
+ crc: crc24Init,
+ blockType: bType,
+ }
+ e.b64 = base64.NewEncoder(base64.StdEncoding, e.breaker)
+ return e, nil
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package openpgp
+
+import "hash"
+
+// NewCanonicalTextHash reformats text written to it into the canonical
+// form and then applies the hash h. See RFC 4880, section 5.2.1.
+func NewCanonicalTextHash(h hash.Hash) hash.Hash {
+ return &canonicalTextHash{h, 0}
+}
+
+type canonicalTextHash struct {
+ h hash.Hash
+ s int
+}
+
+var newline = []byte{'\r', '\n'}
+
+func (cth *canonicalTextHash) Write(buf []byte) (int, error) {
+ start := 0
+
+ for i, c := range buf {
+ switch cth.s {
+ case 0:
+ if c == '\r' {
+ cth.s = 1
+ } else if c == '\n' {
+ cth.h.Write(buf[start:i])
+ cth.h.Write(newline)
+ start = i + 1
+ }
+ case 1:
+ cth.s = 0
+ }
+ }
+
+ cth.h.Write(buf[start:])
+ return len(buf), nil
+}
+
+func (cth *canonicalTextHash) Sum(in []byte) []byte {
+ return cth.h.Sum(in)
+}
+
+func (cth *canonicalTextHash) Reset() {
+ cth.h.Reset()
+ cth.s = 0
+}
+
+func (cth *canonicalTextHash) Size() int {
+ return cth.h.Size()
+}
+
+func (cth *canonicalTextHash) BlockSize() int {
+ return cth.h.BlockSize()
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package elgamal implements ElGamal encryption, suitable for OpenPGP,
+// as specified in "A Public-Key Cryptosystem and a Signature Scheme Based on
+// Discrete Logarithms," IEEE Transactions on Information Theory, v. IT-31,
+// n. 4, 1985, pp. 469-472.
+//
+// This form of ElGamal embeds PKCS#1 v1.5 padding, which may make it
+// unsuitable for other protocols. RSA should be used in preference in any
+// case.
+package elgamal // import "golang.org/x/crypto/openpgp/elgamal"
+
+import (
+ "crypto/rand"
+ "crypto/subtle"
+ "errors"
+ "io"
+ "math/big"
+)
+
+// PublicKey represents an ElGamal public key.
+type PublicKey struct {
+ G, P, Y *big.Int
+}
+
+// PrivateKey represents an ElGamal private key.
+type PrivateKey struct {
+ PublicKey
+ X *big.Int
+}
+
+// Encrypt encrypts the given message to the given public key. The result is a
+// pair of integers. Errors can result from reading random, or because msg is
+// too large to be encrypted to the public key.
+func Encrypt(random io.Reader, pub *PublicKey, msg []byte) (c1, c2 *big.Int, err error) {
+ pLen := (pub.P.BitLen() + 7) / 8
+ if len(msg) > pLen-11 {
+ err = errors.New("elgamal: message too long")
+ return
+ }
+
+ // EM = 0x02 || PS || 0x00 || M
+ em := make([]byte, pLen-1)
+ em[0] = 2
+ ps, mm := em[1:len(em)-len(msg)-1], em[len(em)-len(msg):]
+ err = nonZeroRandomBytes(ps, random)
+ if err != nil {
+ return
+ }
+ em[len(em)-len(msg)-1] = 0
+ copy(mm, msg)
+
+ m := new(big.Int).SetBytes(em)
+
+ k, err := rand.Int(random, pub.P)
+ if err != nil {
+ return
+ }
+
+ c1 = new(big.Int).Exp(pub.G, k, pub.P)
+ s := new(big.Int).Exp(pub.Y, k, pub.P)
+ c2 = s.Mul(s, m)
+ c2.Mod(c2, pub.P)
+
+ return
+}
+
+// Decrypt takes two integers, resulting from an ElGamal encryption, and
+// returns the plaintext of the message. An error can result only if the
+// ciphertext is invalid. Users should keep in mind that this is a padding
+// oracle and thus, if exposed to an adaptive chosen ciphertext attack, can
+// be used to break the cryptosystem. See ``Chosen Ciphertext Attacks
+// Against Protocols Based on the RSA Encryption Standard PKCS #1'', Daniel
+// Bleichenbacher, Advances in Cryptology (Crypto '98),
+func Decrypt(priv *PrivateKey, c1, c2 *big.Int) (msg []byte, err error) {
+ s := new(big.Int).Exp(c1, priv.X, priv.P)
+ s.ModInverse(s, priv.P)
+ s.Mul(s, c2)
+ s.Mod(s, priv.P)
+ em := s.Bytes()
+
+ firstByteIsTwo := subtle.ConstantTimeByteEq(em[0], 2)
+
+ // The remainder of the plaintext must be a string of non-zero random
+ // octets, followed by a 0, followed by the message.
+ // lookingForIndex: 1 iff we are still looking for the zero.
+ // index: the offset of the first zero byte.
+ var lookingForIndex, index int
+ lookingForIndex = 1
+
+ for i := 1; i < len(em); i++ {
+ equals0 := subtle.ConstantTimeByteEq(em[i], 0)
+ index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
+ lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
+ }
+
+ if firstByteIsTwo != 1 || lookingForIndex != 0 || index < 9 {
+ return nil, errors.New("elgamal: decryption error")
+ }
+ return em[index+1:], nil
+}
+
+// nonZeroRandomBytes fills the given slice with non-zero random octets.
+func nonZeroRandomBytes(s []byte, rand io.Reader) (err error) {
+ _, err = io.ReadFull(rand, s)
+ if err != nil {
+ return
+ }
+
+ for i := 0; i < len(s); i++ {
+ for s[i] == 0 {
+ _, err = io.ReadFull(rand, s[i:i+1])
+ if err != nil {
+ return
+ }
+ }
+ }
+
+ return
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package errors contains common error types for the OpenPGP packages.
+package errors // import "golang.org/x/crypto/openpgp/errors"
+
+import (
+ "strconv"
+)
+
+// A StructuralError is returned when OpenPGP data is found to be syntactically
+// invalid.
+type StructuralError string
+
+func (s StructuralError) Error() string {
+ return "openpgp: invalid data: " + string(s)
+}
+
+// UnsupportedError indicates that, although the OpenPGP data is valid, it
+// makes use of currently unimplemented features.
+type UnsupportedError string
+
+func (s UnsupportedError) Error() string {
+ return "openpgp: unsupported feature: " + string(s)
+}
+
+// InvalidArgumentError indicates that the caller is in error and passed an
+// incorrect value.
+type InvalidArgumentError string
+
+func (i InvalidArgumentError) Error() string {
+ return "openpgp: invalid argument: " + string(i)
+}
+
+// SignatureError indicates that a syntactically valid signature failed to
+// validate.
+type SignatureError string
+
+func (b SignatureError) Error() string {
+ return "openpgp: invalid signature: " + string(b)
+}
+
+type keyIncorrectError int
+
+func (ki keyIncorrectError) Error() string {
+ return "openpgp: incorrect key"
+}
+
+var ErrKeyIncorrect error = keyIncorrectError(0)
+
+type unknownIssuerError int
+
+func (unknownIssuerError) Error() string {
+ return "openpgp: signature made by unknown entity"
+}
+
+var ErrUnknownIssuer error = unknownIssuerError(0)
+
+type keyRevokedError int
+
+func (keyRevokedError) Error() string {
+ return "openpgp: signature made by revoked key"
+}
+
+var ErrKeyRevoked error = keyRevokedError(0)
+
+type UnknownPacketTypeError uint8
+
+func (upte UnknownPacketTypeError) Error() string {
+ return "openpgp: unknown packet type: " + strconv.Itoa(int(upte))
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package openpgp
+
+import (
+ "crypto/rsa"
+ "io"
+ "time"
+
+ "golang.org/x/crypto/openpgp/armor"
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/packet"
+)
+
+// PublicKeyType is the armor type for a PGP public key.
+var PublicKeyType = "PGP PUBLIC KEY BLOCK"
+
+// PrivateKeyType is the armor type for a PGP private key.
+var PrivateKeyType = "PGP PRIVATE KEY BLOCK"
+
+// An Entity represents the components of an OpenPGP key: a primary public key
+// (which must be a signing key), one or more identities claimed by that key,
+// and zero or more subkeys, which may be encryption keys.
+type Entity struct {
+ PrimaryKey *packet.PublicKey
+ PrivateKey *packet.PrivateKey
+ Identities map[string]*Identity // indexed by Identity.Name
+ Revocations []*packet.Signature
+ Subkeys []Subkey
+}
+
+// An Identity represents an identity claimed by an Entity and zero or more
+// assertions by other entities about that claim.
+type Identity struct {
+ Name string // by convention, has the form "Full Name (comment) <email@example.com>"
+ UserId *packet.UserId
+ SelfSignature *packet.Signature
+ Signatures []*packet.Signature
+}
+
+// A Subkey is an additional public key in an Entity. Subkeys can be used for
+// encryption.
+type Subkey struct {
+ PublicKey *packet.PublicKey
+ PrivateKey *packet.PrivateKey
+ Sig *packet.Signature
+}
+
+// A Key identifies a specific public key in an Entity. This is either the
+// Entity's primary key or a subkey.
+type Key struct {
+ Entity *Entity
+ PublicKey *packet.PublicKey
+ PrivateKey *packet.PrivateKey
+ SelfSignature *packet.Signature
+}
+
+// A KeyRing provides access to public and private keys.
+type KeyRing interface {
+ // KeysById returns the set of keys that have the given key id.
+ KeysById(id uint64) []Key
+ // KeysByIdAndUsage returns the set of keys with the given id
+ // that also meet the key usage given by requiredUsage.
+ // The requiredUsage is expressed as the bitwise-OR of
+ // packet.KeyFlag* values.
+ KeysByIdUsage(id uint64, requiredUsage byte) []Key
+ // DecryptionKeys returns all private keys that are valid for
+ // decryption.
+ DecryptionKeys() []Key
+}
+
+// primaryIdentity returns the Identity marked as primary or the first identity
+// if none are so marked.
+func (e *Entity) primaryIdentity() *Identity {
+ var firstIdentity *Identity
+ for _, ident := range e.Identities {
+ if firstIdentity == nil {
+ firstIdentity = ident
+ }
+ if ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
+ return ident
+ }
+ }
+ return firstIdentity
+}
+
+// encryptionKey returns the best candidate Key for encrypting a message to the
+// given Entity.
+func (e *Entity) encryptionKey(now time.Time) (Key, bool) {
+ candidateSubkey := -1
+
+ // Iterate the keys to find the newest key
+ var maxTime time.Time
+ for i, subkey := range e.Subkeys {
+ if subkey.Sig.FlagsValid &&
+ subkey.Sig.FlagEncryptCommunications &&
+ subkey.PublicKey.PubKeyAlgo.CanEncrypt() &&
+ !subkey.Sig.KeyExpired(now) &&
+ (maxTime.IsZero() || subkey.Sig.CreationTime.After(maxTime)) {
+ candidateSubkey = i
+ maxTime = subkey.Sig.CreationTime
+ }
+ }
+
+ if candidateSubkey != -1 {
+ subkey := e.Subkeys[candidateSubkey]
+ return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}, true
+ }
+
+ // If we don't have any candidate subkeys for encryption and
+ // the primary key doesn't have any usage metadata then we
+ // assume that the primary key is ok. Or, if the primary key is
+ // marked as ok to encrypt to, then we can obviously use it.
+ i := e.primaryIdentity()
+ if !i.SelfSignature.FlagsValid || i.SelfSignature.FlagEncryptCommunications &&
+ e.PrimaryKey.PubKeyAlgo.CanEncrypt() &&
+ !i.SelfSignature.KeyExpired(now) {
+ return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}, true
+ }
+
+ // This Entity appears to be signing only.
+ return Key{}, false
+}
+
+// signingKey return the best candidate Key for signing a message with this
+// Entity.
+func (e *Entity) signingKey(now time.Time) (Key, bool) {
+ candidateSubkey := -1
+
+ for i, subkey := range e.Subkeys {
+ if subkey.Sig.FlagsValid &&
+ subkey.Sig.FlagSign &&
+ subkey.PublicKey.PubKeyAlgo.CanSign() &&
+ !subkey.Sig.KeyExpired(now) {
+ candidateSubkey = i
+ break
+ }
+ }
+
+ if candidateSubkey != -1 {
+ subkey := e.Subkeys[candidateSubkey]
+ return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}, true
+ }
+
+ // If we have no candidate subkey then we assume that it's ok to sign
+ // with the primary key.
+ i := e.primaryIdentity()
+ if !i.SelfSignature.FlagsValid || i.SelfSignature.FlagSign &&
+ !i.SelfSignature.KeyExpired(now) {
+ return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}, true
+ }
+
+ return Key{}, false
+}
+
+// An EntityList contains one or more Entities.
+type EntityList []*Entity
+
+// KeysById returns the set of keys that have the given key id.
+func (el EntityList) KeysById(id uint64) (keys []Key) {
+ for _, e := range el {
+ if e.PrimaryKey.KeyId == id {
+ var selfSig *packet.Signature
+ for _, ident := range e.Identities {
+ if selfSig == nil {
+ selfSig = ident.SelfSignature
+ } else if ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
+ selfSig = ident.SelfSignature
+ break
+ }
+ }
+ keys = append(keys, Key{e, e.PrimaryKey, e.PrivateKey, selfSig})
+ }
+
+ for _, subKey := range e.Subkeys {
+ if subKey.PublicKey.KeyId == id {
+ keys = append(keys, Key{e, subKey.PublicKey, subKey.PrivateKey, subKey.Sig})
+ }
+ }
+ }
+ return
+}
+
+// KeysByIdAndUsage returns the set of keys with the given id that also meet
+// the key usage given by requiredUsage. The requiredUsage is expressed as
+// the bitwise-OR of packet.KeyFlag* values.
+func (el EntityList) KeysByIdUsage(id uint64, requiredUsage byte) (keys []Key) {
+ for _, key := range el.KeysById(id) {
+ if len(key.Entity.Revocations) > 0 {
+ continue
+ }
+
+ if key.SelfSignature.RevocationReason != nil {
+ continue
+ }
+
+ if key.SelfSignature.FlagsValid && requiredUsage != 0 {
+ var usage byte
+ if key.SelfSignature.FlagCertify {
+ usage |= packet.KeyFlagCertify
+ }
+ if key.SelfSignature.FlagSign {
+ usage |= packet.KeyFlagSign
+ }
+ if key.SelfSignature.FlagEncryptCommunications {
+ usage |= packet.KeyFlagEncryptCommunications
+ }
+ if key.SelfSignature.FlagEncryptStorage {
+ usage |= packet.KeyFlagEncryptStorage
+ }
+ if usage&requiredUsage != requiredUsage {
+ continue
+ }
+ }
+
+ keys = append(keys, key)
+ }
+ return
+}
+
+// DecryptionKeys returns all private keys that are valid for decryption.
+func (el EntityList) DecryptionKeys() (keys []Key) {
+ for _, e := range el {
+ for _, subKey := range e.Subkeys {
+ if subKey.PrivateKey != nil && (!subKey.Sig.FlagsValid || subKey.Sig.FlagEncryptStorage || subKey.Sig.FlagEncryptCommunications) {
+ keys = append(keys, Key{e, subKey.PublicKey, subKey.PrivateKey, subKey.Sig})
+ }
+ }
+ }
+ return
+}
+
+// ReadArmoredKeyRing reads one or more public/private keys from an armor keyring file.
+func ReadArmoredKeyRing(r io.Reader) (EntityList, error) {
+ block, err := armor.Decode(r)
+ if err == io.EOF {
+ return nil, errors.InvalidArgumentError("no armored data found")
+ }
+ if err != nil {
+ return nil, err
+ }
+ if block.Type != PublicKeyType && block.Type != PrivateKeyType {
+ return nil, errors.InvalidArgumentError("expected public or private key block, got: " + block.Type)
+ }
+
+ return ReadKeyRing(block.Body)
+}
+
+// ReadKeyRing reads one or more public/private keys. Unsupported keys are
+// ignored as long as at least a single valid key is found.
+func ReadKeyRing(r io.Reader) (el EntityList, err error) {
+ packets := packet.NewReader(r)
+ var lastUnsupportedError error
+
+ for {
+ var e *Entity
+ e, err = ReadEntity(packets)
+ if err != nil {
+ // TODO: warn about skipped unsupported/unreadable keys
+ if _, ok := err.(errors.UnsupportedError); ok {
+ lastUnsupportedError = err
+ err = readToNextPublicKey(packets)
+ } else if _, ok := err.(errors.StructuralError); ok {
+ // Skip unreadable, badly-formatted keys
+ lastUnsupportedError = err
+ err = readToNextPublicKey(packets)
+ }
+ if err == io.EOF {
+ err = nil
+ break
+ }
+ if err != nil {
+ el = nil
+ break
+ }
+ } else {
+ el = append(el, e)
+ }
+ }
+
+ if len(el) == 0 && err == nil {
+ err = lastUnsupportedError
+ }
+ return
+}
+
+// readToNextPublicKey reads packets until the start of the entity and leaves
+// the first packet of the new entity in the Reader.
+func readToNextPublicKey(packets *packet.Reader) (err error) {
+ var p packet.Packet
+ for {
+ p, err = packets.Next()
+ if err == io.EOF {
+ return
+ } else if err != nil {
+ if _, ok := err.(errors.UnsupportedError); ok {
+ err = nil
+ continue
+ }
+ return
+ }
+
+ if pk, ok := p.(*packet.PublicKey); ok && !pk.IsSubkey {
+ packets.Unread(p)
+ return
+ }
+ }
+}
+
+// ReadEntity reads an entity (public key, identities, subkeys etc) from the
+// given Reader.
+func ReadEntity(packets *packet.Reader) (*Entity, error) {
+ e := new(Entity)
+ e.Identities = make(map[string]*Identity)
+
+ p, err := packets.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ var ok bool
+ if e.PrimaryKey, ok = p.(*packet.PublicKey); !ok {
+ if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
+ packets.Unread(p)
+ return nil, errors.StructuralError("first packet was not a public/private key")
+ }
+ e.PrimaryKey = &e.PrivateKey.PublicKey
+ }
+
+ if !e.PrimaryKey.PubKeyAlgo.CanSign() {
+ return nil, errors.StructuralError("primary key cannot be used for signatures")
+ }
+
+ var current *Identity
+ var revocations []*packet.Signature
+EachPacket:
+ for {
+ p, err := packets.Next()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ switch pkt := p.(type) {
+ case *packet.UserId:
+ current = new(Identity)
+ current.Name = pkt.Id
+ current.UserId = pkt
+ e.Identities[pkt.Id] = current
+
+ for {
+ p, err = packets.Next()
+ if err == io.EOF {
+ return nil, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return nil, err
+ }
+
+ sig, ok := p.(*packet.Signature)
+ if !ok {
+ return nil, errors.StructuralError("user ID packet not followed by self-signature")
+ }
+
+ if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
+ if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, e.PrimaryKey, sig); err != nil {
+ return nil, errors.StructuralError("user ID self-signature invalid: " + err.Error())
+ }
+ current.SelfSignature = sig
+ break
+ }
+ current.Signatures = append(current.Signatures, sig)
+ }
+ case *packet.Signature:
+ if pkt.SigType == packet.SigTypeKeyRevocation {
+ revocations = append(revocations, pkt)
+ } else if pkt.SigType == packet.SigTypeDirectSignature {
+ // TODO: RFC4880 5.2.1 permits signatures
+ // directly on keys (eg. to bind additional
+ // revocation keys).
+ } else if current == nil {
+ return nil, errors.StructuralError("signature packet found before user id packet")
+ } else {
+ current.Signatures = append(current.Signatures, pkt)
+ }
+ case *packet.PrivateKey:
+ if pkt.IsSubkey == false {
+ packets.Unread(p)
+ break EachPacket
+ }
+ err = addSubkey(e, packets, &pkt.PublicKey, pkt)
+ if err != nil {
+ return nil, err
+ }
+ case *packet.PublicKey:
+ if pkt.IsSubkey == false {
+ packets.Unread(p)
+ break EachPacket
+ }
+ err = addSubkey(e, packets, pkt, nil)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ // we ignore unknown packets
+ }
+ }
+
+ if len(e.Identities) == 0 {
+ return nil, errors.StructuralError("entity without any identities")
+ }
+
+ for _, revocation := range revocations {
+ err = e.PrimaryKey.VerifyRevocationSignature(revocation)
+ if err == nil {
+ e.Revocations = append(e.Revocations, revocation)
+ } else {
+ // TODO: RFC 4880 5.2.3.15 defines revocation keys.
+ return nil, errors.StructuralError("revocation signature signed by alternate key")
+ }
+ }
+
+ return e, nil
+}
+
+func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *packet.PrivateKey) error {
+ var subKey Subkey
+ subKey.PublicKey = pub
+ subKey.PrivateKey = priv
+ p, err := packets.Next()
+ if err == io.EOF {
+ return io.ErrUnexpectedEOF
+ }
+ if err != nil {
+ return errors.StructuralError("subkey signature invalid: " + err.Error())
+ }
+ var ok bool
+ subKey.Sig, ok = p.(*packet.Signature)
+ if !ok {
+ return errors.StructuralError("subkey packet not followed by signature")
+ }
+ if subKey.Sig.SigType != packet.SigTypeSubkeyBinding && subKey.Sig.SigType != packet.SigTypeSubkeyRevocation {
+ return errors.StructuralError("subkey signature with wrong type")
+ }
+ err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig)
+ if err != nil {
+ return errors.StructuralError("subkey signature invalid: " + err.Error())
+ }
+ e.Subkeys = append(e.Subkeys, subKey)
+ return nil
+}
+
+const defaultRSAKeyBits = 2048
+
+// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
+// single identity composed of the given full name, comment and email, any of
+// which may be empty but must not contain any of "()<>\x00".
+// If config is nil, sensible defaults will be used.
+func NewEntity(name, comment, email string, config *packet.Config) (*Entity, error) {
+ currentTime := config.Now()
+
+ bits := defaultRSAKeyBits
+ if config != nil && config.RSABits != 0 {
+ bits = config.RSABits
+ }
+
+ uid := packet.NewUserId(name, comment, email)
+ if uid == nil {
+ return nil, errors.InvalidArgumentError("user id field contained invalid characters")
+ }
+ signingPriv, err := rsa.GenerateKey(config.Random(), bits)
+ if err != nil {
+ return nil, err
+ }
+ encryptingPriv, err := rsa.GenerateKey(config.Random(), bits)
+ if err != nil {
+ return nil, err
+ }
+
+ e := &Entity{
+ PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey),
+ PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv),
+ Identities: make(map[string]*Identity),
+ }
+ isPrimaryId := true
+ e.Identities[uid.Id] = &Identity{
+ Name: uid.Id,
+ UserId: uid,
+ SelfSignature: &packet.Signature{
+ CreationTime: currentTime,
+ SigType: packet.SigTypePositiveCert,
+ PubKeyAlgo: packet.PubKeyAlgoRSA,
+ Hash: config.Hash(),
+ IsPrimaryId: &isPrimaryId,
+ FlagsValid: true,
+ FlagSign: true,
+ FlagCertify: true,
+ IssuerKeyId: &e.PrimaryKey.KeyId,
+ },
+ }
+
+ // If the user passes in a DefaultHash via packet.Config,
+ // set the PreferredHash for the SelfSignature.
+ if config != nil && config.DefaultHash != 0 {
+ e.Identities[uid.Id].SelfSignature.PreferredHash = []uint8{hashToHashId(config.DefaultHash)}
+ }
+
+ // Likewise for DefaultCipher.
+ if config != nil && config.DefaultCipher != 0 {
+ e.Identities[uid.Id].SelfSignature.PreferredSymmetric = []uint8{uint8(config.DefaultCipher)}
+ }
+
+ e.Subkeys = make([]Subkey, 1)
+ e.Subkeys[0] = Subkey{
+ PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey),
+ PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv),
+ Sig: &packet.Signature{
+ CreationTime: currentTime,
+ SigType: packet.SigTypeSubkeyBinding,
+ PubKeyAlgo: packet.PubKeyAlgoRSA,
+ Hash: config.Hash(),
+ FlagsValid: true,
+ FlagEncryptStorage: true,
+ FlagEncryptCommunications: true,
+ IssuerKeyId: &e.PrimaryKey.KeyId,
+ },
+ }
+ e.Subkeys[0].PublicKey.IsSubkey = true
+ e.Subkeys[0].PrivateKey.IsSubkey = true
+
+ return e, nil
+}
+
+// SerializePrivate serializes an Entity, including private key material, to
+// the given Writer. For now, it must only be used on an Entity returned from
+// NewEntity.
+// If config is nil, sensible defaults will be used.
+func (e *Entity) SerializePrivate(w io.Writer, config *packet.Config) (err error) {
+ err = e.PrivateKey.Serialize(w)
+ if err != nil {
+ return
+ }
+ for _, ident := range e.Identities {
+ err = ident.UserId.Serialize(w)
+ if err != nil {
+ return
+ }
+ err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey, config)
+ if err != nil {
+ return
+ }
+ err = ident.SelfSignature.Serialize(w)
+ if err != nil {
+ return
+ }
+ }
+ for _, subkey := range e.Subkeys {
+ err = subkey.PrivateKey.Serialize(w)
+ if err != nil {
+ return
+ }
+ err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey, config)
+ if err != nil {
+ return
+ }
+ err = subkey.Sig.Serialize(w)
+ if err != nil {
+ return
+ }
+ }
+ return nil
+}
+
+// Serialize writes the public part of the given Entity to w. (No private
+// key material will be output).
+func (e *Entity) Serialize(w io.Writer) error {
+ err := e.PrimaryKey.Serialize(w)
+ if err != nil {
+ return err
+ }
+ for _, ident := range e.Identities {
+ err = ident.UserId.Serialize(w)
+ if err != nil {
+ return err
+ }
+ err = ident.SelfSignature.Serialize(w)
+ if err != nil {
+ return err
+ }
+ for _, sig := range ident.Signatures {
+ err = sig.Serialize(w)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ for _, subkey := range e.Subkeys {
+ err = subkey.PublicKey.Serialize(w)
+ if err != nil {
+ return err
+ }
+ err = subkey.Sig.Serialize(w)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// SignIdentity adds a signature to e, from signer, attesting that identity is
+// associated with e. The provided identity must already be an element of
+// e.Identities and the private key of signer must have been decrypted if
+// necessary.
+// If config is nil, sensible defaults will be used.
+func (e *Entity) SignIdentity(identity string, signer *Entity, config *packet.Config) error {
+ if signer.PrivateKey == nil {
+ return errors.InvalidArgumentError("signing Entity must have a private key")
+ }
+ if signer.PrivateKey.Encrypted {
+ return errors.InvalidArgumentError("signing Entity's private key must be decrypted")
+ }
+ ident, ok := e.Identities[identity]
+ if !ok {
+ return errors.InvalidArgumentError("given identity string not found in Entity")
+ }
+
+ sig := &packet.Signature{
+ SigType: packet.SigTypeGenericCert,
+ PubKeyAlgo: signer.PrivateKey.PubKeyAlgo,
+ Hash: config.Hash(),
+ CreationTime: config.Now(),
+ IssuerKeyId: &signer.PrivateKey.KeyId,
+ }
+ if err := sig.SignUserId(identity, e.PrimaryKey, signer.PrivateKey, config); err != nil {
+ return err
+ }
+ ident.Signatures = append(ident.Signatures, sig)
+ return nil
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "compress/bzip2"
+ "compress/flate"
+ "compress/zlib"
+ "golang.org/x/crypto/openpgp/errors"
+ "io"
+ "strconv"
+)
+
+// Compressed represents a compressed OpenPGP packet. The decompressed contents
+// will contain more OpenPGP packets. See RFC 4880, section 5.6.
+type Compressed struct {
+ Body io.Reader
+}
+
+const (
+ NoCompression = flate.NoCompression
+ BestSpeed = flate.BestSpeed
+ BestCompression = flate.BestCompression
+ DefaultCompression = flate.DefaultCompression
+)
+
+// CompressionConfig contains compressor configuration settings.
+type CompressionConfig struct {
+ // Level is the compression level to use. It must be set to
+ // between -1 and 9, with -1 causing the compressor to use the
+ // default compression level, 0 causing the compressor to use
+ // no compression and 1 to 9 representing increasing (better,
+ // slower) compression levels. If Level is less than -1 or
+ // more then 9, a non-nil error will be returned during
+ // encryption. See the constants above for convenient common
+ // settings for Level.
+ Level int
+}
+
+func (c *Compressed) parse(r io.Reader) error {
+ var buf [1]byte
+ _, err := readFull(r, buf[:])
+ if err != nil {
+ return err
+ }
+
+ switch buf[0] {
+ case 1:
+ c.Body = flate.NewReader(r)
+ case 2:
+ c.Body, err = zlib.NewReader(r)
+ case 3:
+ c.Body = bzip2.NewReader(r)
+ default:
+ err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
+ }
+
+ return err
+}
+
+// compressedWriterCloser represents the serialized compression stream
+// header and the compressor. Its Close() method ensures that both the
+// compressor and serialized stream header are closed. Its Write()
+// method writes to the compressor.
+type compressedWriteCloser struct {
+ sh io.Closer // Stream Header
+ c io.WriteCloser // Compressor
+}
+
+func (cwc compressedWriteCloser) Write(p []byte) (int, error) {
+ return cwc.c.Write(p)
+}
+
+func (cwc compressedWriteCloser) Close() (err error) {
+ err = cwc.c.Close()
+ if err != nil {
+ return err
+ }
+
+ return cwc.sh.Close()
+}
+
+// SerializeCompressed serializes a compressed data packet to w and
+// returns a WriteCloser to which the literal data packets themselves
+// can be written and which MUST be closed on completion. If cc is
+// nil, sensible defaults will be used to configure the compression
+// algorithm.
+func SerializeCompressed(w io.WriteCloser, algo CompressionAlgo, cc *CompressionConfig) (literaldata io.WriteCloser, err error) {
+ compressed, err := serializeStreamHeader(w, packetTypeCompressed)
+ if err != nil {
+ return
+ }
+
+ _, err = compressed.Write([]byte{uint8(algo)})
+ if err != nil {
+ return
+ }
+
+ level := DefaultCompression
+ if cc != nil {
+ level = cc.Level
+ }
+
+ var compressor io.WriteCloser
+ switch algo {
+ case CompressionZIP:
+ compressor, err = flate.NewWriter(compressed, level)
+ case CompressionZLIB:
+ compressor, err = zlib.NewWriterLevel(compressed, level)
+ default:
+ s := strconv.Itoa(int(algo))
+ err = errors.UnsupportedError("Unsupported compression algorithm: " + s)
+ }
+ if err != nil {
+ return
+ }
+
+ literaldata = compressedWriteCloser{compressed, compressor}
+
+ return
+}
--- /dev/null
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "crypto"
+ "crypto/rand"
+ "io"
+ "time"
+)
+
+// Config collects a number of parameters along with sensible defaults.
+// A nil *Config is valid and results in all default values.
+type Config struct {
+ // Rand provides the source of entropy.
+ // If nil, the crypto/rand Reader is used.
+ Rand io.Reader
+ // DefaultHash is the default hash function to be used.
+ // If zero, SHA-256 is used.
+ DefaultHash crypto.Hash
+ // DefaultCipher is the cipher to be used.
+ // If zero, AES-128 is used.
+ DefaultCipher CipherFunction
+ // Time returns the current time as the number of seconds since the
+ // epoch. If Time is nil, time.Now is used.
+ Time func() time.Time
+ // DefaultCompressionAlgo is the compression algorithm to be
+ // applied to the plaintext before encryption. If zero, no
+ // compression is done.
+ DefaultCompressionAlgo CompressionAlgo
+ // CompressionConfig configures the compression settings.
+ CompressionConfig *CompressionConfig
+ // S2KCount is only used for symmetric encryption. It
+ // determines the strength of the passphrase stretching when
+ // the said passphrase is hashed to produce a key. S2KCount
+ // should be between 1024 and 65011712, inclusive. If Config
+ // is nil or S2KCount is 0, the value 65536 used. Not all
+ // values in the above range can be represented. S2KCount will
+ // be rounded up to the next representable value if it cannot
+ // be encoded exactly. When set, it is strongly encrouraged to
+ // use a value that is at least 65536. See RFC 4880 Section
+ // 3.7.1.3.
+ S2KCount int
+ // RSABits is the number of bits in new RSA keys made with NewEntity.
+ // If zero, then 2048 bit keys are created.
+ RSABits int
+}
+
+func (c *Config) Random() io.Reader {
+ if c == nil || c.Rand == nil {
+ return rand.Reader
+ }
+ return c.Rand
+}
+
+func (c *Config) Hash() crypto.Hash {
+ if c == nil || uint(c.DefaultHash) == 0 {
+ return crypto.SHA256
+ }
+ return c.DefaultHash
+}
+
+func (c *Config) Cipher() CipherFunction {
+ if c == nil || uint8(c.DefaultCipher) == 0 {
+ return CipherAES128
+ }
+ return c.DefaultCipher
+}
+
+func (c *Config) Now() time.Time {
+ if c == nil || c.Time == nil {
+ return time.Now()
+ }
+ return c.Time()
+}
+
+func (c *Config) Compression() CompressionAlgo {
+ if c == nil {
+ return CompressionNone
+ }
+ return c.DefaultCompressionAlgo
+}
+
+func (c *Config) PasswordHashIterations() int {
+ if c == nil || c.S2KCount == 0 {
+ return 0
+ }
+ return c.S2KCount
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "crypto/rsa"
+ "encoding/binary"
+ "io"
+ "math/big"
+ "strconv"
+
+ "golang.org/x/crypto/openpgp/elgamal"
+ "golang.org/x/crypto/openpgp/errors"
+)
+
+const encryptedKeyVersion = 3
+
+// EncryptedKey represents a public-key encrypted session key. See RFC 4880,
+// section 5.1.
+type EncryptedKey struct {
+ KeyId uint64
+ Algo PublicKeyAlgorithm
+ CipherFunc CipherFunction // only valid after a successful Decrypt
+ Key []byte // only valid after a successful Decrypt
+
+ encryptedMPI1, encryptedMPI2 parsedMPI
+}
+
+func (e *EncryptedKey) parse(r io.Reader) (err error) {
+ var buf [10]byte
+ _, err = readFull(r, buf[:])
+ if err != nil {
+ return
+ }
+ if buf[0] != encryptedKeyVersion {
+ return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
+ }
+ e.KeyId = binary.BigEndian.Uint64(buf[1:9])
+ e.Algo = PublicKeyAlgorithm(buf[9])
+ switch e.Algo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
+ e.encryptedMPI1.bytes, e.encryptedMPI1.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ case PubKeyAlgoElGamal:
+ e.encryptedMPI1.bytes, e.encryptedMPI1.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ e.encryptedMPI2.bytes, e.encryptedMPI2.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ }
+ _, err = consumeAll(r)
+ return
+}
+
+func checksumKeyMaterial(key []byte) uint16 {
+ var checksum uint16
+ for _, v := range key {
+ checksum += uint16(v)
+ }
+ return checksum
+}
+
+// Decrypt decrypts an encrypted session key with the given private key. The
+// private key must have been decrypted first.
+// If config is nil, sensible defaults will be used.
+func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error {
+ var err error
+ var b []byte
+
+ // TODO(agl): use session key decryption routines here to avoid
+ // padding oracle attacks.
+ switch priv.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
+ k := priv.PrivateKey.(*rsa.PrivateKey)
+ b, err = rsa.DecryptPKCS1v15(config.Random(), k, padToKeySize(&k.PublicKey, e.encryptedMPI1.bytes))
+ case PubKeyAlgoElGamal:
+ c1 := new(big.Int).SetBytes(e.encryptedMPI1.bytes)
+ c2 := new(big.Int).SetBytes(e.encryptedMPI2.bytes)
+ b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
+ default:
+ err = errors.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
+ }
+
+ if err != nil {
+ return err
+ }
+
+ e.CipherFunc = CipherFunction(b[0])
+ e.Key = b[1 : len(b)-2]
+ expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
+ checksum := checksumKeyMaterial(e.Key)
+ if checksum != expectedChecksum {
+ return errors.StructuralError("EncryptedKey checksum incorrect")
+ }
+
+ return nil
+}
+
+// Serialize writes the encrypted key packet, e, to w.
+func (e *EncryptedKey) Serialize(w io.Writer) error {
+ var mpiLen int
+ switch e.Algo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
+ mpiLen = 2 + len(e.encryptedMPI1.bytes)
+ case PubKeyAlgoElGamal:
+ mpiLen = 2 + len(e.encryptedMPI1.bytes) + 2 + len(e.encryptedMPI2.bytes)
+ default:
+ return errors.InvalidArgumentError("don't know how to serialize encrypted key type " + strconv.Itoa(int(e.Algo)))
+ }
+
+ serializeHeader(w, packetTypeEncryptedKey, 1 /* version */ +8 /* key id */ +1 /* algo */ +mpiLen)
+
+ w.Write([]byte{encryptedKeyVersion})
+ binary.Write(w, binary.BigEndian, e.KeyId)
+ w.Write([]byte{byte(e.Algo)})
+
+ switch e.Algo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
+ writeMPIs(w, e.encryptedMPI1)
+ case PubKeyAlgoElGamal:
+ writeMPIs(w, e.encryptedMPI1, e.encryptedMPI2)
+ default:
+ panic("internal error")
+ }
+
+ return nil
+}
+
+// SerializeEncryptedKey serializes an encrypted key packet to w that contains
+// key, encrypted to pub.
+// If config is nil, sensible defaults will be used.
+func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, config *Config) error {
+ var buf [10]byte
+ buf[0] = encryptedKeyVersion
+ binary.BigEndian.PutUint64(buf[1:9], pub.KeyId)
+ buf[9] = byte(pub.PubKeyAlgo)
+
+ keyBlock := make([]byte, 1 /* cipher type */ +len(key)+2 /* checksum */)
+ keyBlock[0] = byte(cipherFunc)
+ copy(keyBlock[1:], key)
+ checksum := checksumKeyMaterial(key)
+ keyBlock[1+len(key)] = byte(checksum >> 8)
+ keyBlock[1+len(key)+1] = byte(checksum)
+
+ switch pub.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
+ return serializeEncryptedKeyRSA(w, config.Random(), buf, pub.PublicKey.(*rsa.PublicKey), keyBlock)
+ case PubKeyAlgoElGamal:
+ return serializeEncryptedKeyElGamal(w, config.Random(), buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
+ case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
+ return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
+ }
+
+ return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
+}
+
+func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error {
+ cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
+ if err != nil {
+ return errors.InvalidArgumentError("RSA encryption failed: " + err.Error())
+ }
+
+ packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
+
+ err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(header[:])
+ if err != nil {
+ return err
+ }
+ return writeMPI(w, 8*uint16(len(cipherText)), cipherText)
+}
+
+func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error {
+ c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
+ if err != nil {
+ return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
+ }
+
+ packetLen := 10 /* header length */
+ packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8
+ packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8
+
+ err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(header[:])
+ if err != nil {
+ return err
+ }
+ err = writeBig(w, c1)
+ if err != nil {
+ return err
+ }
+ return writeBig(w, c2)
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "encoding/binary"
+ "io"
+)
+
+// LiteralData represents an encrypted file. See RFC 4880, section 5.9.
+type LiteralData struct {
+ IsBinary bool
+ FileName string
+ Time uint32 // Unix epoch time. Either creation time or modification time. 0 means undefined.
+ Body io.Reader
+}
+
+// ForEyesOnly returns whether the contents of the LiteralData have been marked
+// as especially sensitive.
+func (l *LiteralData) ForEyesOnly() bool {
+ return l.FileName == "_CONSOLE"
+}
+
+func (l *LiteralData) parse(r io.Reader) (err error) {
+ var buf [256]byte
+
+ _, err = readFull(r, buf[:2])
+ if err != nil {
+ return
+ }
+
+ l.IsBinary = buf[0] == 'b'
+ fileNameLen := int(buf[1])
+
+ _, err = readFull(r, buf[:fileNameLen])
+ if err != nil {
+ return
+ }
+
+ l.FileName = string(buf[:fileNameLen])
+
+ _, err = readFull(r, buf[:4])
+ if err != nil {
+ return
+ }
+
+ l.Time = binary.BigEndian.Uint32(buf[:4])
+ l.Body = r
+ return
+}
+
+// SerializeLiteral serializes a literal data packet to w and returns a
+// WriteCloser to which the data itself can be written and which MUST be closed
+// on completion. The fileName is truncated to 255 bytes.
+func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err error) {
+ var buf [4]byte
+ buf[0] = 't'
+ if isBinary {
+ buf[0] = 'b'
+ }
+ if len(fileName) > 255 {
+ fileName = fileName[:255]
+ }
+ buf[1] = byte(len(fileName))
+
+ inner, err := serializeStreamHeader(w, packetTypeLiteralData)
+ if err != nil {
+ return
+ }
+
+ _, err = inner.Write(buf[:2])
+ if err != nil {
+ return
+ }
+ _, err = inner.Write([]byte(fileName))
+ if err != nil {
+ return
+ }
+ binary.BigEndian.PutUint32(buf[:], time)
+ _, err = inner.Write(buf[:])
+ if err != nil {
+ return
+ }
+
+ plaintext = inner
+ return
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// OpenPGP CFB Mode. http://tools.ietf.org/html/rfc4880#section-13.9
+
+package packet
+
+import (
+ "crypto/cipher"
+)
+
+type ocfbEncrypter struct {
+ b cipher.Block
+ fre []byte
+ outUsed int
+}
+
+// An OCFBResyncOption determines if the "resynchronization step" of OCFB is
+// performed.
+type OCFBResyncOption bool
+
+const (
+ OCFBResync OCFBResyncOption = true
+ OCFBNoResync OCFBResyncOption = false
+)
+
+// NewOCFBEncrypter returns a cipher.Stream which encrypts data with OpenPGP's
+// cipher feedback mode using the given cipher.Block, and an initial amount of
+// ciphertext. randData must be random bytes and be the same length as the
+// cipher.Block's block size. Resync determines if the "resynchronization step"
+// from RFC 4880, 13.9 step 7 is performed. Different parts of OpenPGP vary on
+// this point.
+func NewOCFBEncrypter(block cipher.Block, randData []byte, resync OCFBResyncOption) (cipher.Stream, []byte) {
+ blockSize := block.BlockSize()
+ if len(randData) != blockSize {
+ return nil, nil
+ }
+
+ x := &ocfbEncrypter{
+ b: block,
+ fre: make([]byte, blockSize),
+ outUsed: 0,
+ }
+ prefix := make([]byte, blockSize+2)
+
+ block.Encrypt(x.fre, x.fre)
+ for i := 0; i < blockSize; i++ {
+ prefix[i] = randData[i] ^ x.fre[i]
+ }
+
+ block.Encrypt(x.fre, prefix[:blockSize])
+ prefix[blockSize] = x.fre[0] ^ randData[blockSize-2]
+ prefix[blockSize+1] = x.fre[1] ^ randData[blockSize-1]
+
+ if resync {
+ block.Encrypt(x.fre, prefix[2:])
+ } else {
+ x.fre[0] = prefix[blockSize]
+ x.fre[1] = prefix[blockSize+1]
+ x.outUsed = 2
+ }
+ return x, prefix
+}
+
+func (x *ocfbEncrypter) XORKeyStream(dst, src []byte) {
+ for i := 0; i < len(src); i++ {
+ if x.outUsed == len(x.fre) {
+ x.b.Encrypt(x.fre, x.fre)
+ x.outUsed = 0
+ }
+
+ x.fre[x.outUsed] ^= src[i]
+ dst[i] = x.fre[x.outUsed]
+ x.outUsed++
+ }
+}
+
+type ocfbDecrypter struct {
+ b cipher.Block
+ fre []byte
+ outUsed int
+}
+
+// NewOCFBDecrypter returns a cipher.Stream which decrypts data with OpenPGP's
+// cipher feedback mode using the given cipher.Block. Prefix must be the first
+// blockSize + 2 bytes of the ciphertext, where blockSize is the cipher.Block's
+// block size. If an incorrect key is detected then nil is returned. On
+// successful exit, blockSize+2 bytes of decrypted data are written into
+// prefix. Resync determines if the "resynchronization step" from RFC 4880,
+// 13.9 step 7 is performed. Different parts of OpenPGP vary on this point.
+func NewOCFBDecrypter(block cipher.Block, prefix []byte, resync OCFBResyncOption) cipher.Stream {
+ blockSize := block.BlockSize()
+ if len(prefix) != blockSize+2 {
+ return nil
+ }
+
+ x := &ocfbDecrypter{
+ b: block,
+ fre: make([]byte, blockSize),
+ outUsed: 0,
+ }
+ prefixCopy := make([]byte, len(prefix))
+ copy(prefixCopy, prefix)
+
+ block.Encrypt(x.fre, x.fre)
+ for i := 0; i < blockSize; i++ {
+ prefixCopy[i] ^= x.fre[i]
+ }
+
+ block.Encrypt(x.fre, prefix[:blockSize])
+ prefixCopy[blockSize] ^= x.fre[0]
+ prefixCopy[blockSize+1] ^= x.fre[1]
+
+ if prefixCopy[blockSize-2] != prefixCopy[blockSize] ||
+ prefixCopy[blockSize-1] != prefixCopy[blockSize+1] {
+ return nil
+ }
+
+ if resync {
+ block.Encrypt(x.fre, prefix[2:])
+ } else {
+ x.fre[0] = prefix[blockSize]
+ x.fre[1] = prefix[blockSize+1]
+ x.outUsed = 2
+ }
+ copy(prefix, prefixCopy)
+ return x
+}
+
+func (x *ocfbDecrypter) XORKeyStream(dst, src []byte) {
+ for i := 0; i < len(src); i++ {
+ if x.outUsed == len(x.fre) {
+ x.b.Encrypt(x.fre, x.fre)
+ x.outUsed = 0
+ }
+
+ c := src[i]
+ dst[i] = x.fre[x.outUsed] ^ src[i]
+ x.fre[x.outUsed] = c
+ x.outUsed++
+ }
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "crypto"
+ "encoding/binary"
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/s2k"
+ "io"
+ "strconv"
+)
+
+// OnePassSignature represents a one-pass signature packet. See RFC 4880,
+// section 5.4.
+type OnePassSignature struct {
+ SigType SignatureType
+ Hash crypto.Hash
+ PubKeyAlgo PublicKeyAlgorithm
+ KeyId uint64
+ IsLast bool
+}
+
+const onePassSignatureVersion = 3
+
+func (ops *OnePassSignature) parse(r io.Reader) (err error) {
+ var buf [13]byte
+
+ _, err = readFull(r, buf[:])
+ if err != nil {
+ return
+ }
+ if buf[0] != onePassSignatureVersion {
+ err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
+ }
+
+ var ok bool
+ ops.Hash, ok = s2k.HashIdToHash(buf[2])
+ if !ok {
+ return errors.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
+ }
+
+ ops.SigType = SignatureType(buf[1])
+ ops.PubKeyAlgo = PublicKeyAlgorithm(buf[3])
+ ops.KeyId = binary.BigEndian.Uint64(buf[4:12])
+ ops.IsLast = buf[12] != 0
+ return
+}
+
+// Serialize marshals the given OnePassSignature to w.
+func (ops *OnePassSignature) Serialize(w io.Writer) error {
+ var buf [13]byte
+ buf[0] = onePassSignatureVersion
+ buf[1] = uint8(ops.SigType)
+ var ok bool
+ buf[2], ok = s2k.HashToHashId(ops.Hash)
+ if !ok {
+ return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
+ }
+ buf[3] = uint8(ops.PubKeyAlgo)
+ binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
+ if ops.IsLast {
+ buf[12] = 1
+ }
+
+ if err := serializeHeader(w, packetTypeOnePassSignature, len(buf)); err != nil {
+ return err
+ }
+ _, err := w.Write(buf[:])
+ return err
+}
--- /dev/null
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+
+ "golang.org/x/crypto/openpgp/errors"
+)
+
+// OpaquePacket represents an OpenPGP packet as raw, unparsed data. This is
+// useful for splitting and storing the original packet contents separately,
+// handling unsupported packet types or accessing parts of the packet not yet
+// implemented by this package.
+type OpaquePacket struct {
+ // Packet type
+ Tag uint8
+ // Reason why the packet was parsed opaquely
+ Reason error
+ // Binary contents of the packet data
+ Contents []byte
+}
+
+func (op *OpaquePacket) parse(r io.Reader) (err error) {
+ op.Contents, err = ioutil.ReadAll(r)
+ return
+}
+
+// Serialize marshals the packet to a writer in its original form, including
+// the packet header.
+func (op *OpaquePacket) Serialize(w io.Writer) (err error) {
+ err = serializeHeader(w, packetType(op.Tag), len(op.Contents))
+ if err == nil {
+ _, err = w.Write(op.Contents)
+ }
+ return
+}
+
+// Parse attempts to parse the opaque contents into a structure supported by
+// this package. If the packet is not known then the result will be another
+// OpaquePacket.
+func (op *OpaquePacket) Parse() (p Packet, err error) {
+ hdr := bytes.NewBuffer(nil)
+ err = serializeHeader(hdr, packetType(op.Tag), len(op.Contents))
+ if err != nil {
+ op.Reason = err
+ return op, err
+ }
+ p, err = Read(io.MultiReader(hdr, bytes.NewBuffer(op.Contents)))
+ if err != nil {
+ op.Reason = err
+ p = op
+ }
+ return
+}
+
+// OpaqueReader reads OpaquePackets from an io.Reader.
+type OpaqueReader struct {
+ r io.Reader
+}
+
+func NewOpaqueReader(r io.Reader) *OpaqueReader {
+ return &OpaqueReader{r: r}
+}
+
+// Read the next OpaquePacket.
+func (or *OpaqueReader) Next() (op *OpaquePacket, err error) {
+ tag, _, contents, err := readHeader(or.r)
+ if err != nil {
+ return
+ }
+ op = &OpaquePacket{Tag: uint8(tag), Reason: err}
+ err = op.parse(contents)
+ if err != nil {
+ consumeAll(contents)
+ }
+ return
+}
+
+// OpaqueSubpacket represents an unparsed OpenPGP subpacket,
+// as found in signature and user attribute packets.
+type OpaqueSubpacket struct {
+ SubType uint8
+ Contents []byte
+}
+
+// OpaqueSubpackets extracts opaque, unparsed OpenPGP subpackets from
+// their byte representation.
+func OpaqueSubpackets(contents []byte) (result []*OpaqueSubpacket, err error) {
+ var (
+ subHeaderLen int
+ subPacket *OpaqueSubpacket
+ )
+ for len(contents) > 0 {
+ subHeaderLen, subPacket, err = nextSubpacket(contents)
+ if err != nil {
+ break
+ }
+ result = append(result, subPacket)
+ contents = contents[subHeaderLen+len(subPacket.Contents):]
+ }
+ return
+}
+
+func nextSubpacket(contents []byte) (subHeaderLen int, subPacket *OpaqueSubpacket, err error) {
+ // RFC 4880, section 5.2.3.1
+ var subLen uint32
+ if len(contents) < 1 {
+ goto Truncated
+ }
+ subPacket = &OpaqueSubpacket{}
+ switch {
+ case contents[0] < 192:
+ subHeaderLen = 2 // 1 length byte, 1 subtype byte
+ if len(contents) < subHeaderLen {
+ goto Truncated
+ }
+ subLen = uint32(contents[0])
+ contents = contents[1:]
+ case contents[0] < 255:
+ subHeaderLen = 3 // 2 length bytes, 1 subtype
+ if len(contents) < subHeaderLen {
+ goto Truncated
+ }
+ subLen = uint32(contents[0]-192)<<8 + uint32(contents[1]) + 192
+ contents = contents[2:]
+ default:
+ subHeaderLen = 6 // 5 length bytes, 1 subtype
+ if len(contents) < subHeaderLen {
+ goto Truncated
+ }
+ subLen = uint32(contents[1])<<24 |
+ uint32(contents[2])<<16 |
+ uint32(contents[3])<<8 |
+ uint32(contents[4])
+ contents = contents[5:]
+ }
+ if subLen > uint32(len(contents)) || subLen == 0 {
+ goto Truncated
+ }
+ subPacket.SubType = contents[0]
+ subPacket.Contents = contents[1:subLen]
+ return
+Truncated:
+ err = errors.StructuralError("subpacket truncated")
+ return
+}
+
+func (osp *OpaqueSubpacket) Serialize(w io.Writer) (err error) {
+ buf := make([]byte, 6)
+ n := serializeSubpacketLength(buf, len(osp.Contents)+1)
+ buf[n] = osp.SubType
+ if _, err = w.Write(buf[:n+1]); err != nil {
+ return
+ }
+ _, err = w.Write(osp.Contents)
+ return
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package packet implements parsing and serialization of OpenPGP packets, as
+// specified in RFC 4880.
+package packet // import "golang.org/x/crypto/openpgp/packet"
+
+import (
+ "bufio"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/des"
+ "crypto/rsa"
+ "io"
+ "math/big"
+
+ "golang.org/x/crypto/cast5"
+ "golang.org/x/crypto/openpgp/errors"
+)
+
+// readFull is the same as io.ReadFull except that reading zero bytes returns
+// ErrUnexpectedEOF rather than EOF.
+func readFull(r io.Reader, buf []byte) (n int, err error) {
+ n, err = io.ReadFull(r, buf)
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return
+}
+
+// readLength reads an OpenPGP length from r. See RFC 4880, section 4.2.2.
+func readLength(r io.Reader) (length int64, isPartial bool, err error) {
+ var buf [4]byte
+ _, err = readFull(r, buf[:1])
+ if err != nil {
+ return
+ }
+ switch {
+ case buf[0] < 192:
+ length = int64(buf[0])
+ case buf[0] < 224:
+ length = int64(buf[0]-192) << 8
+ _, err = readFull(r, buf[0:1])
+ if err != nil {
+ return
+ }
+ length += int64(buf[0]) + 192
+ case buf[0] < 255:
+ length = int64(1) << (buf[0] & 0x1f)
+ isPartial = true
+ default:
+ _, err = readFull(r, buf[0:4])
+ if err != nil {
+ return
+ }
+ length = int64(buf[0])<<24 |
+ int64(buf[1])<<16 |
+ int64(buf[2])<<8 |
+ int64(buf[3])
+ }
+ return
+}
+
+// partialLengthReader wraps an io.Reader and handles OpenPGP partial lengths.
+// The continuation lengths are parsed and removed from the stream and EOF is
+// returned at the end of the packet. See RFC 4880, section 4.2.2.4.
+type partialLengthReader struct {
+ r io.Reader
+ remaining int64
+ isPartial bool
+}
+
+func (r *partialLengthReader) Read(p []byte) (n int, err error) {
+ for r.remaining == 0 {
+ if !r.isPartial {
+ return 0, io.EOF
+ }
+ r.remaining, r.isPartial, err = readLength(r.r)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ toRead := int64(len(p))
+ if toRead > r.remaining {
+ toRead = r.remaining
+ }
+
+ n, err = r.r.Read(p[:int(toRead)])
+ r.remaining -= int64(n)
+ if n < int(toRead) && err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return
+}
+
+// partialLengthWriter writes a stream of data using OpenPGP partial lengths.
+// See RFC 4880, section 4.2.2.4.
+type partialLengthWriter struct {
+ w io.WriteCloser
+ lengthByte [1]byte
+}
+
+func (w *partialLengthWriter) Write(p []byte) (n int, err error) {
+ for len(p) > 0 {
+ for power := uint(14); power < 32; power-- {
+ l := 1 << power
+ if len(p) >= l {
+ w.lengthByte[0] = 224 + uint8(power)
+ _, err = w.w.Write(w.lengthByte[:])
+ if err != nil {
+ return
+ }
+ var m int
+ m, err = w.w.Write(p[:l])
+ n += m
+ if err != nil {
+ return
+ }
+ p = p[l:]
+ break
+ }
+ }
+ }
+ return
+}
+
+func (w *partialLengthWriter) Close() error {
+ w.lengthByte[0] = 0
+ _, err := w.w.Write(w.lengthByte[:])
+ if err != nil {
+ return err
+ }
+ return w.w.Close()
+}
+
+// A spanReader is an io.LimitReader, but it returns ErrUnexpectedEOF if the
+// underlying Reader returns EOF before the limit has been reached.
+type spanReader struct {
+ r io.Reader
+ n int64
+}
+
+func (l *spanReader) Read(p []byte) (n int, err error) {
+ if l.n <= 0 {
+ return 0, io.EOF
+ }
+ if int64(len(p)) > l.n {
+ p = p[0:l.n]
+ }
+ n, err = l.r.Read(p)
+ l.n -= int64(n)
+ if l.n > 0 && err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return
+}
+
+// readHeader parses a packet header and returns an io.Reader which will return
+// the contents of the packet. See RFC 4880, section 4.2.
+func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader, err error) {
+ var buf [4]byte
+ _, err = io.ReadFull(r, buf[:1])
+ if err != nil {
+ return
+ }
+ if buf[0]&0x80 == 0 {
+ err = errors.StructuralError("tag byte does not have MSB set")
+ return
+ }
+ if buf[0]&0x40 == 0 {
+ // Old format packet
+ tag = packetType((buf[0] & 0x3f) >> 2)
+ lengthType := buf[0] & 3
+ if lengthType == 3 {
+ length = -1
+ contents = r
+ return
+ }
+ lengthBytes := 1 << lengthType
+ _, err = readFull(r, buf[0:lengthBytes])
+ if err != nil {
+ return
+ }
+ for i := 0; i < lengthBytes; i++ {
+ length <<= 8
+ length |= int64(buf[i])
+ }
+ contents = &spanReader{r, length}
+ return
+ }
+
+ // New format packet
+ tag = packetType(buf[0] & 0x3f)
+ length, isPartial, err := readLength(r)
+ if err != nil {
+ return
+ }
+ if isPartial {
+ contents = &partialLengthReader{
+ remaining: length,
+ isPartial: true,
+ r: r,
+ }
+ length = -1
+ } else {
+ contents = &spanReader{r, length}
+ }
+ return
+}
+
+// serializeHeader writes an OpenPGP packet header to w. See RFC 4880, section
+// 4.2.
+func serializeHeader(w io.Writer, ptype packetType, length int) (err error) {
+ var buf [6]byte
+ var n int
+
+ buf[0] = 0x80 | 0x40 | byte(ptype)
+ if length < 192 {
+ buf[1] = byte(length)
+ n = 2
+ } else if length < 8384 {
+ length -= 192
+ buf[1] = 192 + byte(length>>8)
+ buf[2] = byte(length)
+ n = 3
+ } else {
+ buf[1] = 255
+ buf[2] = byte(length >> 24)
+ buf[3] = byte(length >> 16)
+ buf[4] = byte(length >> 8)
+ buf[5] = byte(length)
+ n = 6
+ }
+
+ _, err = w.Write(buf[:n])
+ return
+}
+
+// serializeStreamHeader writes an OpenPGP packet header to w where the
+// length of the packet is unknown. It returns a io.WriteCloser which can be
+// used to write the contents of the packet. See RFC 4880, section 4.2.
+func serializeStreamHeader(w io.WriteCloser, ptype packetType) (out io.WriteCloser, err error) {
+ var buf [1]byte
+ buf[0] = 0x80 | 0x40 | byte(ptype)
+ _, err = w.Write(buf[:])
+ if err != nil {
+ return
+ }
+ out = &partialLengthWriter{w: w}
+ return
+}
+
+// Packet represents an OpenPGP packet. Users are expected to try casting
+// instances of this interface to specific packet types.
+type Packet interface {
+ parse(io.Reader) error
+}
+
+// consumeAll reads from the given Reader until error, returning the number of
+// bytes read.
+func consumeAll(r io.Reader) (n int64, err error) {
+ var m int
+ var buf [1024]byte
+
+ for {
+ m, err = r.Read(buf[:])
+ n += int64(m)
+ if err == io.EOF {
+ err = nil
+ return
+ }
+ if err != nil {
+ return
+ }
+ }
+}
+
+// packetType represents the numeric ids of the different OpenPGP packet types. See
+// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-2
+type packetType uint8
+
+const (
+ packetTypeEncryptedKey packetType = 1
+ packetTypeSignature packetType = 2
+ packetTypeSymmetricKeyEncrypted packetType = 3
+ packetTypeOnePassSignature packetType = 4
+ packetTypePrivateKey packetType = 5
+ packetTypePublicKey packetType = 6
+ packetTypePrivateSubkey packetType = 7
+ packetTypeCompressed packetType = 8
+ packetTypeSymmetricallyEncrypted packetType = 9
+ packetTypeLiteralData packetType = 11
+ packetTypeUserId packetType = 13
+ packetTypePublicSubkey packetType = 14
+ packetTypeUserAttribute packetType = 17
+ packetTypeSymmetricallyEncryptedMDC packetType = 18
+)
+
+// peekVersion detects the version of a public key packet about to
+// be read. A bufio.Reader at the original position of the io.Reader
+// is returned.
+func peekVersion(r io.Reader) (bufr *bufio.Reader, ver byte, err error) {
+ bufr = bufio.NewReader(r)
+ var verBuf []byte
+ if verBuf, err = bufr.Peek(1); err != nil {
+ return
+ }
+ ver = verBuf[0]
+ return
+}
+
+// Read reads a single OpenPGP packet from the given io.Reader. If there is an
+// error parsing a packet, the whole packet is consumed from the input.
+func Read(r io.Reader) (p Packet, err error) {
+ tag, _, contents, err := readHeader(r)
+ if err != nil {
+ return
+ }
+
+ switch tag {
+ case packetTypeEncryptedKey:
+ p = new(EncryptedKey)
+ case packetTypeSignature:
+ var version byte
+ // Detect signature version
+ if contents, version, err = peekVersion(contents); err != nil {
+ return
+ }
+ if version < 4 {
+ p = new(SignatureV3)
+ } else {
+ p = new(Signature)
+ }
+ case packetTypeSymmetricKeyEncrypted:
+ p = new(SymmetricKeyEncrypted)
+ case packetTypeOnePassSignature:
+ p = new(OnePassSignature)
+ case packetTypePrivateKey, packetTypePrivateSubkey:
+ pk := new(PrivateKey)
+ if tag == packetTypePrivateSubkey {
+ pk.IsSubkey = true
+ }
+ p = pk
+ case packetTypePublicKey, packetTypePublicSubkey:
+ var version byte
+ if contents, version, err = peekVersion(contents); err != nil {
+ return
+ }
+ isSubkey := tag == packetTypePublicSubkey
+ if version < 4 {
+ p = &PublicKeyV3{IsSubkey: isSubkey}
+ } else {
+ p = &PublicKey{IsSubkey: isSubkey}
+ }
+ case packetTypeCompressed:
+ p = new(Compressed)
+ case packetTypeSymmetricallyEncrypted:
+ p = new(SymmetricallyEncrypted)
+ case packetTypeLiteralData:
+ p = new(LiteralData)
+ case packetTypeUserId:
+ p = new(UserId)
+ case packetTypeUserAttribute:
+ p = new(UserAttribute)
+ case packetTypeSymmetricallyEncryptedMDC:
+ se := new(SymmetricallyEncrypted)
+ se.MDC = true
+ p = se
+ default:
+ err = errors.UnknownPacketTypeError(tag)
+ }
+ if p != nil {
+ err = p.parse(contents)
+ }
+ if err != nil {
+ consumeAll(contents)
+ }
+ return
+}
+
+// SignatureType represents the different semantic meanings of an OpenPGP
+// signature. See RFC 4880, section 5.2.1.
+type SignatureType uint8
+
+const (
+ SigTypeBinary SignatureType = 0
+ SigTypeText = 1
+ SigTypeGenericCert = 0x10
+ SigTypePersonaCert = 0x11
+ SigTypeCasualCert = 0x12
+ SigTypePositiveCert = 0x13
+ SigTypeSubkeyBinding = 0x18
+ SigTypePrimaryKeyBinding = 0x19
+ SigTypeDirectSignature = 0x1F
+ SigTypeKeyRevocation = 0x20
+ SigTypeSubkeyRevocation = 0x28
+)
+
+// PublicKeyAlgorithm represents the different public key system specified for
+// OpenPGP. See
+// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-12
+type PublicKeyAlgorithm uint8
+
+const (
+ PubKeyAlgoRSA PublicKeyAlgorithm = 1
+ PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2
+ PubKeyAlgoRSASignOnly PublicKeyAlgorithm = 3
+ PubKeyAlgoElGamal PublicKeyAlgorithm = 16
+ PubKeyAlgoDSA PublicKeyAlgorithm = 17
+ // RFC 6637, Section 5.
+ PubKeyAlgoECDH PublicKeyAlgorithm = 18
+ PubKeyAlgoECDSA PublicKeyAlgorithm = 19
+)
+
+// CanEncrypt returns true if it's possible to encrypt a message to a public
+// key of the given type.
+func (pka PublicKeyAlgorithm) CanEncrypt() bool {
+ switch pka {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal:
+ return true
+ }
+ return false
+}
+
+// CanSign returns true if it's possible for a public key of the given type to
+// sign a message.
+func (pka PublicKeyAlgorithm) CanSign() bool {
+ switch pka {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA:
+ return true
+ }
+ return false
+}
+
+// CipherFunction represents the different block ciphers specified for OpenPGP. See
+// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-13
+type CipherFunction uint8
+
+const (
+ Cipher3DES CipherFunction = 2
+ CipherCAST5 CipherFunction = 3
+ CipherAES128 CipherFunction = 7
+ CipherAES192 CipherFunction = 8
+ CipherAES256 CipherFunction = 9
+)
+
+// KeySize returns the key size, in bytes, of cipher.
+func (cipher CipherFunction) KeySize() int {
+ switch cipher {
+ case Cipher3DES:
+ return 24
+ case CipherCAST5:
+ return cast5.KeySize
+ case CipherAES128:
+ return 16
+ case CipherAES192:
+ return 24
+ case CipherAES256:
+ return 32
+ }
+ return 0
+}
+
+// blockSize returns the block size, in bytes, of cipher.
+func (cipher CipherFunction) blockSize() int {
+ switch cipher {
+ case Cipher3DES:
+ return des.BlockSize
+ case CipherCAST5:
+ return 8
+ case CipherAES128, CipherAES192, CipherAES256:
+ return 16
+ }
+ return 0
+}
+
+// new returns a fresh instance of the given cipher.
+func (cipher CipherFunction) new(key []byte) (block cipher.Block) {
+ switch cipher {
+ case Cipher3DES:
+ block, _ = des.NewTripleDESCipher(key)
+ case CipherCAST5:
+ block, _ = cast5.NewCipher(key)
+ case CipherAES128, CipherAES192, CipherAES256:
+ block, _ = aes.NewCipher(key)
+ }
+ return
+}
+
+// readMPI reads a big integer from r. The bit length returned is the bit
+// length that was specified in r. This is preserved so that the integer can be
+// reserialized exactly.
+func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err error) {
+ var buf [2]byte
+ _, err = readFull(r, buf[0:])
+ if err != nil {
+ return
+ }
+ bitLength = uint16(buf[0])<<8 | uint16(buf[1])
+ numBytes := (int(bitLength) + 7) / 8
+ mpi = make([]byte, numBytes)
+ _, err = readFull(r, mpi)
+ // According to RFC 4880 3.2. we should check that the MPI has no leading
+ // zeroes (at least when not an encrypted MPI?), but this implementation
+ // does generate leading zeroes, so we keep accepting them.
+ return
+}
+
+// writeMPI serializes a big integer to w.
+func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err error) {
+ // Note that we can produce leading zeroes, in violation of RFC 4880 3.2.
+ // Implementations seem to be tolerant of them, and stripping them would
+ // make it complex to guarantee matching re-serialization.
+ _, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})
+ if err == nil {
+ _, err = w.Write(mpiBytes)
+ }
+ return
+}
+
+// writeBig serializes a *big.Int to w.
+func writeBig(w io.Writer, i *big.Int) error {
+ return writeMPI(w, uint16(i.BitLen()), i.Bytes())
+}
+
+// padToKeySize left-pads a MPI with zeroes to match the length of the
+// specified RSA public.
+func padToKeySize(pub *rsa.PublicKey, b []byte) []byte {
+ k := (pub.N.BitLen() + 7) / 8
+ if len(b) >= k {
+ return b
+ }
+ bb := make([]byte, k)
+ copy(bb[len(bb)-len(b):], b)
+ return bb
+}
+
+// CompressionAlgo Represents the different compression algorithms
+// supported by OpenPGP (except for BZIP2, which is not currently
+// supported). See Section 9.3 of RFC 4880.
+type CompressionAlgo uint8
+
+const (
+ CompressionNone CompressionAlgo = 0
+ CompressionZIP CompressionAlgo = 1
+ CompressionZLIB CompressionAlgo = 2
+)
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/cipher"
+ "crypto/dsa"
+ "crypto/ecdsa"
+ "crypto/rsa"
+ "crypto/sha1"
+ "io"
+ "io/ioutil"
+ "math/big"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/openpgp/elgamal"
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/s2k"
+)
+
+// PrivateKey represents a possibly encrypted private key. See RFC 4880,
+// section 5.5.3.
+type PrivateKey struct {
+ PublicKey
+ Encrypted bool // if true then the private key is unavailable until Decrypt has been called.
+ encryptedData []byte
+ cipher CipherFunction
+ s2k func(out, in []byte)
+ PrivateKey interface{} // An *{rsa|dsa|ecdsa}.PrivateKey or a crypto.Signer.
+ sha1Checksum bool
+ iv []byte
+}
+
+func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey) *PrivateKey {
+ pk := new(PrivateKey)
+ pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey)
+ pk.PrivateKey = priv
+ return pk
+}
+
+func NewDSAPrivateKey(currentTime time.Time, priv *dsa.PrivateKey) *PrivateKey {
+ pk := new(PrivateKey)
+ pk.PublicKey = *NewDSAPublicKey(currentTime, &priv.PublicKey)
+ pk.PrivateKey = priv
+ return pk
+}
+
+func NewElGamalPrivateKey(currentTime time.Time, priv *elgamal.PrivateKey) *PrivateKey {
+ pk := new(PrivateKey)
+ pk.PublicKey = *NewElGamalPublicKey(currentTime, &priv.PublicKey)
+ pk.PrivateKey = priv
+ return pk
+}
+
+func NewECDSAPrivateKey(currentTime time.Time, priv *ecdsa.PrivateKey) *PrivateKey {
+ pk := new(PrivateKey)
+ pk.PublicKey = *NewECDSAPublicKey(currentTime, &priv.PublicKey)
+ pk.PrivateKey = priv
+ return pk
+}
+
+// NewSignerPrivateKey creates a sign-only PrivateKey from a crypto.Signer that
+// implements RSA or ECDSA.
+func NewSignerPrivateKey(currentTime time.Time, signer crypto.Signer) *PrivateKey {
+ pk := new(PrivateKey)
+ switch pubkey := signer.Public().(type) {
+ case rsa.PublicKey:
+ pk.PublicKey = *NewRSAPublicKey(currentTime, &pubkey)
+ pk.PubKeyAlgo = PubKeyAlgoRSASignOnly
+ case ecdsa.PublicKey:
+ pk.PublicKey = *NewECDSAPublicKey(currentTime, &pubkey)
+ default:
+ panic("openpgp: unknown crypto.Signer type in NewSignerPrivateKey")
+ }
+ pk.PrivateKey = signer
+ return pk
+}
+
+func (pk *PrivateKey) parse(r io.Reader) (err error) {
+ err = (&pk.PublicKey).parse(r)
+ if err != nil {
+ return
+ }
+ var buf [1]byte
+ _, err = readFull(r, buf[:])
+ if err != nil {
+ return
+ }
+
+ s2kType := buf[0]
+
+ switch s2kType {
+ case 0:
+ pk.s2k = nil
+ pk.Encrypted = false
+ case 254, 255:
+ _, err = readFull(r, buf[:])
+ if err != nil {
+ return
+ }
+ pk.cipher = CipherFunction(buf[0])
+ pk.Encrypted = true
+ pk.s2k, err = s2k.Parse(r)
+ if err != nil {
+ return
+ }
+ if s2kType == 254 {
+ pk.sha1Checksum = true
+ }
+ default:
+ return errors.UnsupportedError("deprecated s2k function in private key")
+ }
+
+ if pk.Encrypted {
+ blockSize := pk.cipher.blockSize()
+ if blockSize == 0 {
+ return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
+ }
+ pk.iv = make([]byte, blockSize)
+ _, err = readFull(r, pk.iv)
+ if err != nil {
+ return
+ }
+ }
+
+ pk.encryptedData, err = ioutil.ReadAll(r)
+ if err != nil {
+ return
+ }
+
+ if !pk.Encrypted {
+ return pk.parsePrivateKey(pk.encryptedData)
+ }
+
+ return
+}
+
+func mod64kHash(d []byte) uint16 {
+ var h uint16
+ for _, b := range d {
+ h += uint16(b)
+ }
+ return h
+}
+
+func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
+ // TODO(agl): support encrypted private keys
+ buf := bytes.NewBuffer(nil)
+ err = pk.PublicKey.serializeWithoutHeaders(buf)
+ if err != nil {
+ return
+ }
+ buf.WriteByte(0 /* no encryption */)
+
+ privateKeyBuf := bytes.NewBuffer(nil)
+
+ switch priv := pk.PrivateKey.(type) {
+ case *rsa.PrivateKey:
+ err = serializeRSAPrivateKey(privateKeyBuf, priv)
+ case *dsa.PrivateKey:
+ err = serializeDSAPrivateKey(privateKeyBuf, priv)
+ case *elgamal.PrivateKey:
+ err = serializeElGamalPrivateKey(privateKeyBuf, priv)
+ case *ecdsa.PrivateKey:
+ err = serializeECDSAPrivateKey(privateKeyBuf, priv)
+ default:
+ err = errors.InvalidArgumentError("unknown private key type")
+ }
+ if err != nil {
+ return
+ }
+
+ ptype := packetTypePrivateKey
+ contents := buf.Bytes()
+ privateKeyBytes := privateKeyBuf.Bytes()
+ if pk.IsSubkey {
+ ptype = packetTypePrivateSubkey
+ }
+ err = serializeHeader(w, ptype, len(contents)+len(privateKeyBytes)+2)
+ if err != nil {
+ return
+ }
+ _, err = w.Write(contents)
+ if err != nil {
+ return
+ }
+ _, err = w.Write(privateKeyBytes)
+ if err != nil {
+ return
+ }
+
+ checksum := mod64kHash(privateKeyBytes)
+ var checksumBytes [2]byte
+ checksumBytes[0] = byte(checksum >> 8)
+ checksumBytes[1] = byte(checksum)
+ _, err = w.Write(checksumBytes[:])
+
+ return
+}
+
+func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) error {
+ err := writeBig(w, priv.D)
+ if err != nil {
+ return err
+ }
+ err = writeBig(w, priv.Primes[1])
+ if err != nil {
+ return err
+ }
+ err = writeBig(w, priv.Primes[0])
+ if err != nil {
+ return err
+ }
+ return writeBig(w, priv.Precomputed.Qinv)
+}
+
+func serializeDSAPrivateKey(w io.Writer, priv *dsa.PrivateKey) error {
+ return writeBig(w, priv.X)
+}
+
+func serializeElGamalPrivateKey(w io.Writer, priv *elgamal.PrivateKey) error {
+ return writeBig(w, priv.X)
+}
+
+func serializeECDSAPrivateKey(w io.Writer, priv *ecdsa.PrivateKey) error {
+ return writeBig(w, priv.D)
+}
+
+// Decrypt decrypts an encrypted private key using a passphrase.
+func (pk *PrivateKey) Decrypt(passphrase []byte) error {
+ if !pk.Encrypted {
+ return nil
+ }
+
+ key := make([]byte, pk.cipher.KeySize())
+ pk.s2k(key, passphrase)
+ block := pk.cipher.new(key)
+ cfb := cipher.NewCFBDecrypter(block, pk.iv)
+
+ data := make([]byte, len(pk.encryptedData))
+ cfb.XORKeyStream(data, pk.encryptedData)
+
+ if pk.sha1Checksum {
+ if len(data) < sha1.Size {
+ return errors.StructuralError("truncated private key data")
+ }
+ h := sha1.New()
+ h.Write(data[:len(data)-sha1.Size])
+ sum := h.Sum(nil)
+ if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
+ return errors.StructuralError("private key checksum failure")
+ }
+ data = data[:len(data)-sha1.Size]
+ } else {
+ if len(data) < 2 {
+ return errors.StructuralError("truncated private key data")
+ }
+ var sum uint16
+ for i := 0; i < len(data)-2; i++ {
+ sum += uint16(data[i])
+ }
+ if data[len(data)-2] != uint8(sum>>8) ||
+ data[len(data)-1] != uint8(sum) {
+ return errors.StructuralError("private key checksum failure")
+ }
+ data = data[:len(data)-2]
+ }
+
+ return pk.parsePrivateKey(data)
+}
+
+func (pk *PrivateKey) parsePrivateKey(data []byte) (err error) {
+ switch pk.PublicKey.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoRSAEncryptOnly:
+ return pk.parseRSAPrivateKey(data)
+ case PubKeyAlgoDSA:
+ return pk.parseDSAPrivateKey(data)
+ case PubKeyAlgoElGamal:
+ return pk.parseElGamalPrivateKey(data)
+ case PubKeyAlgoECDSA:
+ return pk.parseECDSAPrivateKey(data)
+ }
+ panic("impossible")
+}
+
+func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err error) {
+ rsaPub := pk.PublicKey.PublicKey.(*rsa.PublicKey)
+ rsaPriv := new(rsa.PrivateKey)
+ rsaPriv.PublicKey = *rsaPub
+
+ buf := bytes.NewBuffer(data)
+ d, _, err := readMPI(buf)
+ if err != nil {
+ return
+ }
+ p, _, err := readMPI(buf)
+ if err != nil {
+ return
+ }
+ q, _, err := readMPI(buf)
+ if err != nil {
+ return
+ }
+
+ rsaPriv.D = new(big.Int).SetBytes(d)
+ rsaPriv.Primes = make([]*big.Int, 2)
+ rsaPriv.Primes[0] = new(big.Int).SetBytes(p)
+ rsaPriv.Primes[1] = new(big.Int).SetBytes(q)
+ if err := rsaPriv.Validate(); err != nil {
+ return err
+ }
+ rsaPriv.Precompute()
+ pk.PrivateKey = rsaPriv
+ pk.Encrypted = false
+ pk.encryptedData = nil
+
+ return nil
+}
+
+func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err error) {
+ dsaPub := pk.PublicKey.PublicKey.(*dsa.PublicKey)
+ dsaPriv := new(dsa.PrivateKey)
+ dsaPriv.PublicKey = *dsaPub
+
+ buf := bytes.NewBuffer(data)
+ x, _, err := readMPI(buf)
+ if err != nil {
+ return
+ }
+
+ dsaPriv.X = new(big.Int).SetBytes(x)
+ pk.PrivateKey = dsaPriv
+ pk.Encrypted = false
+ pk.encryptedData = nil
+
+ return nil
+}
+
+func (pk *PrivateKey) parseElGamalPrivateKey(data []byte) (err error) {
+ pub := pk.PublicKey.PublicKey.(*elgamal.PublicKey)
+ priv := new(elgamal.PrivateKey)
+ priv.PublicKey = *pub
+
+ buf := bytes.NewBuffer(data)
+ x, _, err := readMPI(buf)
+ if err != nil {
+ return
+ }
+
+ priv.X = new(big.Int).SetBytes(x)
+ pk.PrivateKey = priv
+ pk.Encrypted = false
+ pk.encryptedData = nil
+
+ return nil
+}
+
+func (pk *PrivateKey) parseECDSAPrivateKey(data []byte) (err error) {
+ ecdsaPub := pk.PublicKey.PublicKey.(*ecdsa.PublicKey)
+
+ buf := bytes.NewBuffer(data)
+ d, _, err := readMPI(buf)
+ if err != nil {
+ return
+ }
+
+ pk.PrivateKey = &ecdsa.PrivateKey{
+ PublicKey: *ecdsaPub,
+ D: new(big.Int).SetBytes(d),
+ }
+ pk.Encrypted = false
+ pk.encryptedData = nil
+
+ return nil
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/dsa"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rsa"
+ "crypto/sha1"
+ _ "crypto/sha256"
+ _ "crypto/sha512"
+ "encoding/binary"
+ "fmt"
+ "hash"
+ "io"
+ "math/big"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/openpgp/elgamal"
+ "golang.org/x/crypto/openpgp/errors"
+)
+
+var (
+ // NIST curve P-256
+ oidCurveP256 []byte = []byte{0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07}
+ // NIST curve P-384
+ oidCurveP384 []byte = []byte{0x2B, 0x81, 0x04, 0x00, 0x22}
+ // NIST curve P-521
+ oidCurveP521 []byte = []byte{0x2B, 0x81, 0x04, 0x00, 0x23}
+)
+
+const maxOIDLength = 8
+
+// ecdsaKey stores the algorithm-specific fields for ECDSA keys.
+// as defined in RFC 6637, Section 9.
+type ecdsaKey struct {
+ // oid contains the OID byte sequence identifying the elliptic curve used
+ oid []byte
+ // p contains the elliptic curve point that represents the public key
+ p parsedMPI
+}
+
+// parseOID reads the OID for the curve as defined in RFC 6637, Section 9.
+func parseOID(r io.Reader) (oid []byte, err error) {
+ buf := make([]byte, maxOIDLength)
+ if _, err = readFull(r, buf[:1]); err != nil {
+ return
+ }
+ oidLen := buf[0]
+ if int(oidLen) > len(buf) {
+ err = errors.UnsupportedError("invalid oid length: " + strconv.Itoa(int(oidLen)))
+ return
+ }
+ oid = buf[:oidLen]
+ _, err = readFull(r, oid)
+ return
+}
+
+func (f *ecdsaKey) parse(r io.Reader) (err error) {
+ if f.oid, err = parseOID(r); err != nil {
+ return err
+ }
+ f.p.bytes, f.p.bitLength, err = readMPI(r)
+ return
+}
+
+func (f *ecdsaKey) serialize(w io.Writer) (err error) {
+ buf := make([]byte, maxOIDLength+1)
+ buf[0] = byte(len(f.oid))
+ copy(buf[1:], f.oid)
+ if _, err = w.Write(buf[:len(f.oid)+1]); err != nil {
+ return
+ }
+ return writeMPIs(w, f.p)
+}
+
+func (f *ecdsaKey) newECDSA() (*ecdsa.PublicKey, error) {
+ var c elliptic.Curve
+ if bytes.Equal(f.oid, oidCurveP256) {
+ c = elliptic.P256()
+ } else if bytes.Equal(f.oid, oidCurveP384) {
+ c = elliptic.P384()
+ } else if bytes.Equal(f.oid, oidCurveP521) {
+ c = elliptic.P521()
+ } else {
+ return nil, errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", f.oid))
+ }
+ x, y := elliptic.Unmarshal(c, f.p.bytes)
+ if x == nil {
+ return nil, errors.UnsupportedError("failed to parse EC point")
+ }
+ return &ecdsa.PublicKey{Curve: c, X: x, Y: y}, nil
+}
+
+func (f *ecdsaKey) byteLen() int {
+ return 1 + len(f.oid) + 2 + len(f.p.bytes)
+}
+
+type kdfHashFunction byte
+type kdfAlgorithm byte
+
+// ecdhKdf stores key derivation function parameters
+// used for ECDH encryption. See RFC 6637, Section 9.
+type ecdhKdf struct {
+ KdfHash kdfHashFunction
+ KdfAlgo kdfAlgorithm
+}
+
+func (f *ecdhKdf) parse(r io.Reader) (err error) {
+ buf := make([]byte, 1)
+ if _, err = readFull(r, buf); err != nil {
+ return
+ }
+ kdfLen := int(buf[0])
+ if kdfLen < 3 {
+ return errors.UnsupportedError("Unsupported ECDH KDF length: " + strconv.Itoa(kdfLen))
+ }
+ buf = make([]byte, kdfLen)
+ if _, err = readFull(r, buf); err != nil {
+ return
+ }
+ reserved := int(buf[0])
+ f.KdfHash = kdfHashFunction(buf[1])
+ f.KdfAlgo = kdfAlgorithm(buf[2])
+ if reserved != 0x01 {
+ return errors.UnsupportedError("Unsupported KDF reserved field: " + strconv.Itoa(reserved))
+ }
+ return
+}
+
+func (f *ecdhKdf) serialize(w io.Writer) (err error) {
+ buf := make([]byte, 4)
+ // See RFC 6637, Section 9, Algorithm-Specific Fields for ECDH keys.
+ buf[0] = byte(0x03) // Length of the following fields
+ buf[1] = byte(0x01) // Reserved for future extensions, must be 1 for now
+ buf[2] = byte(f.KdfHash)
+ buf[3] = byte(f.KdfAlgo)
+ _, err = w.Write(buf[:])
+ return
+}
+
+func (f *ecdhKdf) byteLen() int {
+ return 4
+}
+
+// PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2.
+type PublicKey struct {
+ CreationTime time.Time
+ PubKeyAlgo PublicKeyAlgorithm
+ PublicKey interface{} // *rsa.PublicKey, *dsa.PublicKey or *ecdsa.PublicKey
+ Fingerprint [20]byte
+ KeyId uint64
+ IsSubkey bool
+
+ n, e, p, q, g, y parsedMPI
+
+ // RFC 6637 fields
+ ec *ecdsaKey
+ ecdh *ecdhKdf
+}
+
+// signingKey provides a convenient abstraction over signature verification
+// for v3 and v4 public keys.
+type signingKey interface {
+ SerializeSignaturePrefix(io.Writer)
+ serializeWithoutHeaders(io.Writer) error
+}
+
+func fromBig(n *big.Int) parsedMPI {
+ return parsedMPI{
+ bytes: n.Bytes(),
+ bitLength: uint16(n.BitLen()),
+ }
+}
+
+// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
+func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey) *PublicKey {
+ pk := &PublicKey{
+ CreationTime: creationTime,
+ PubKeyAlgo: PubKeyAlgoRSA,
+ PublicKey: pub,
+ n: fromBig(pub.N),
+ e: fromBig(big.NewInt(int64(pub.E))),
+ }
+
+ pk.setFingerPrintAndKeyId()
+ return pk
+}
+
+// NewDSAPublicKey returns a PublicKey that wraps the given dsa.PublicKey.
+func NewDSAPublicKey(creationTime time.Time, pub *dsa.PublicKey) *PublicKey {
+ pk := &PublicKey{
+ CreationTime: creationTime,
+ PubKeyAlgo: PubKeyAlgoDSA,
+ PublicKey: pub,
+ p: fromBig(pub.P),
+ q: fromBig(pub.Q),
+ g: fromBig(pub.G),
+ y: fromBig(pub.Y),
+ }
+
+ pk.setFingerPrintAndKeyId()
+ return pk
+}
+
+// NewElGamalPublicKey returns a PublicKey that wraps the given elgamal.PublicKey.
+func NewElGamalPublicKey(creationTime time.Time, pub *elgamal.PublicKey) *PublicKey {
+ pk := &PublicKey{
+ CreationTime: creationTime,
+ PubKeyAlgo: PubKeyAlgoElGamal,
+ PublicKey: pub,
+ p: fromBig(pub.P),
+ g: fromBig(pub.G),
+ y: fromBig(pub.Y),
+ }
+
+ pk.setFingerPrintAndKeyId()
+ return pk
+}
+
+func NewECDSAPublicKey(creationTime time.Time, pub *ecdsa.PublicKey) *PublicKey {
+ pk := &PublicKey{
+ CreationTime: creationTime,
+ PubKeyAlgo: PubKeyAlgoECDSA,
+ PublicKey: pub,
+ ec: new(ecdsaKey),
+ }
+
+ switch pub.Curve {
+ case elliptic.P256():
+ pk.ec.oid = oidCurveP256
+ case elliptic.P384():
+ pk.ec.oid = oidCurveP384
+ case elliptic.P521():
+ pk.ec.oid = oidCurveP521
+ default:
+ panic("unknown elliptic curve")
+ }
+
+ pk.ec.p.bytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y)
+
+ // The bit length is 3 (for the 0x04 specifying an uncompressed key)
+ // plus two field elements (for x and y), which are rounded up to the
+ // nearest byte. See https://tools.ietf.org/html/rfc6637#section-6
+ fieldBytes := (pub.Curve.Params().BitSize + 7) & ^7
+ pk.ec.p.bitLength = uint16(3 + fieldBytes + fieldBytes)
+
+ pk.setFingerPrintAndKeyId()
+ return pk
+}
+
+func (pk *PublicKey) parse(r io.Reader) (err error) {
+ // RFC 4880, section 5.5.2
+ var buf [6]byte
+ _, err = readFull(r, buf[:])
+ if err != nil {
+ return
+ }
+ if buf[0] != 4 {
+ return errors.UnsupportedError("public key version")
+ }
+ pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
+ pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ err = pk.parseRSA(r)
+ case PubKeyAlgoDSA:
+ err = pk.parseDSA(r)
+ case PubKeyAlgoElGamal:
+ err = pk.parseElGamal(r)
+ case PubKeyAlgoECDSA:
+ pk.ec = new(ecdsaKey)
+ if err = pk.ec.parse(r); err != nil {
+ return err
+ }
+ pk.PublicKey, err = pk.ec.newECDSA()
+ case PubKeyAlgoECDH:
+ pk.ec = new(ecdsaKey)
+ if err = pk.ec.parse(r); err != nil {
+ return
+ }
+ pk.ecdh = new(ecdhKdf)
+ if err = pk.ecdh.parse(r); err != nil {
+ return
+ }
+ // The ECDH key is stored in an ecdsa.PublicKey for convenience.
+ pk.PublicKey, err = pk.ec.newECDSA()
+ default:
+ err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
+ }
+ if err != nil {
+ return
+ }
+
+ pk.setFingerPrintAndKeyId()
+ return
+}
+
+func (pk *PublicKey) setFingerPrintAndKeyId() {
+ // RFC 4880, section 12.2
+ fingerPrint := sha1.New()
+ pk.SerializeSignaturePrefix(fingerPrint)
+ pk.serializeWithoutHeaders(fingerPrint)
+ copy(pk.Fingerprint[:], fingerPrint.Sum(nil))
+ pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20])
+}
+
+// parseRSA parses RSA public key material from the given Reader. See RFC 4880,
+// section 5.5.2.
+func (pk *PublicKey) parseRSA(r io.Reader) (err error) {
+ pk.n.bytes, pk.n.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ pk.e.bytes, pk.e.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+
+ if len(pk.e.bytes) > 3 {
+ err = errors.UnsupportedError("large public exponent")
+ return
+ }
+ rsa := &rsa.PublicKey{
+ N: new(big.Int).SetBytes(pk.n.bytes),
+ E: 0,
+ }
+ for i := 0; i < len(pk.e.bytes); i++ {
+ rsa.E <<= 8
+ rsa.E |= int(pk.e.bytes[i])
+ }
+ pk.PublicKey = rsa
+ return
+}
+
+// parseDSA parses DSA public key material from the given Reader. See RFC 4880,
+// section 5.5.2.
+func (pk *PublicKey) parseDSA(r io.Reader) (err error) {
+ pk.p.bytes, pk.p.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ pk.q.bytes, pk.q.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ pk.g.bytes, pk.g.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ pk.y.bytes, pk.y.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+
+ dsa := new(dsa.PublicKey)
+ dsa.P = new(big.Int).SetBytes(pk.p.bytes)
+ dsa.Q = new(big.Int).SetBytes(pk.q.bytes)
+ dsa.G = new(big.Int).SetBytes(pk.g.bytes)
+ dsa.Y = new(big.Int).SetBytes(pk.y.bytes)
+ pk.PublicKey = dsa
+ return
+}
+
+// parseElGamal parses ElGamal public key material from the given Reader. See
+// RFC 4880, section 5.5.2.
+func (pk *PublicKey) parseElGamal(r io.Reader) (err error) {
+ pk.p.bytes, pk.p.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ pk.g.bytes, pk.g.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+ pk.y.bytes, pk.y.bitLength, err = readMPI(r)
+ if err != nil {
+ return
+ }
+
+ elgamal := new(elgamal.PublicKey)
+ elgamal.P = new(big.Int).SetBytes(pk.p.bytes)
+ elgamal.G = new(big.Int).SetBytes(pk.g.bytes)
+ elgamal.Y = new(big.Int).SetBytes(pk.y.bytes)
+ pk.PublicKey = elgamal
+ return
+}
+
+// SerializeSignaturePrefix writes the prefix for this public key to the given Writer.
+// The prefix is used when calculating a signature over this public key. See
+// RFC 4880, section 5.2.4.
+func (pk *PublicKey) SerializeSignaturePrefix(h io.Writer) {
+ var pLength uint16
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ pLength += 2 + uint16(len(pk.n.bytes))
+ pLength += 2 + uint16(len(pk.e.bytes))
+ case PubKeyAlgoDSA:
+ pLength += 2 + uint16(len(pk.p.bytes))
+ pLength += 2 + uint16(len(pk.q.bytes))
+ pLength += 2 + uint16(len(pk.g.bytes))
+ pLength += 2 + uint16(len(pk.y.bytes))
+ case PubKeyAlgoElGamal:
+ pLength += 2 + uint16(len(pk.p.bytes))
+ pLength += 2 + uint16(len(pk.g.bytes))
+ pLength += 2 + uint16(len(pk.y.bytes))
+ case PubKeyAlgoECDSA:
+ pLength += uint16(pk.ec.byteLen())
+ case PubKeyAlgoECDH:
+ pLength += uint16(pk.ec.byteLen())
+ pLength += uint16(pk.ecdh.byteLen())
+ default:
+ panic("unknown public key algorithm")
+ }
+ pLength += 6
+ h.Write([]byte{0x99, byte(pLength >> 8), byte(pLength)})
+ return
+}
+
+func (pk *PublicKey) Serialize(w io.Writer) (err error) {
+ length := 6 // 6 byte header
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ length += 2 + len(pk.n.bytes)
+ length += 2 + len(pk.e.bytes)
+ case PubKeyAlgoDSA:
+ length += 2 + len(pk.p.bytes)
+ length += 2 + len(pk.q.bytes)
+ length += 2 + len(pk.g.bytes)
+ length += 2 + len(pk.y.bytes)
+ case PubKeyAlgoElGamal:
+ length += 2 + len(pk.p.bytes)
+ length += 2 + len(pk.g.bytes)
+ length += 2 + len(pk.y.bytes)
+ case PubKeyAlgoECDSA:
+ length += pk.ec.byteLen()
+ case PubKeyAlgoECDH:
+ length += pk.ec.byteLen()
+ length += pk.ecdh.byteLen()
+ default:
+ panic("unknown public key algorithm")
+ }
+
+ packetType := packetTypePublicKey
+ if pk.IsSubkey {
+ packetType = packetTypePublicSubkey
+ }
+ err = serializeHeader(w, packetType, length)
+ if err != nil {
+ return
+ }
+ return pk.serializeWithoutHeaders(w)
+}
+
+// serializeWithoutHeaders marshals the PublicKey to w in the form of an
+// OpenPGP public key packet, not including the packet header.
+func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
+ var buf [6]byte
+ buf[0] = 4
+ t := uint32(pk.CreationTime.Unix())
+ buf[1] = byte(t >> 24)
+ buf[2] = byte(t >> 16)
+ buf[3] = byte(t >> 8)
+ buf[4] = byte(t)
+ buf[5] = byte(pk.PubKeyAlgo)
+
+ _, err = w.Write(buf[:])
+ if err != nil {
+ return
+ }
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ return writeMPIs(w, pk.n, pk.e)
+ case PubKeyAlgoDSA:
+ return writeMPIs(w, pk.p, pk.q, pk.g, pk.y)
+ case PubKeyAlgoElGamal:
+ return writeMPIs(w, pk.p, pk.g, pk.y)
+ case PubKeyAlgoECDSA:
+ return pk.ec.serialize(w)
+ case PubKeyAlgoECDH:
+ if err = pk.ec.serialize(w); err != nil {
+ return
+ }
+ return pk.ecdh.serialize(w)
+ }
+ return errors.InvalidArgumentError("bad public-key algorithm")
+}
+
+// CanSign returns true iff this public key can generate signatures
+func (pk *PublicKey) CanSign() bool {
+ return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal
+}
+
+// VerifySignature returns nil iff sig is a valid signature, made by this
+// public key, of the data hashed into signed. signed is mutated by this call.
+func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) {
+ if !pk.CanSign() {
+ return errors.InvalidArgumentError("public key cannot generate signatures")
+ }
+
+ signed.Write(sig.HashSuffix)
+ hashBytes := signed.Sum(nil)
+
+ if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
+ return errors.SignatureError("hash tag doesn't match")
+ }
+
+ if pk.PubKeyAlgo != sig.PubKeyAlgo {
+ return errors.InvalidArgumentError("public key and signature use different algorithms")
+ }
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
+ err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, padToKeySize(rsaPublicKey, sig.RSASignature.bytes))
+ if err != nil {
+ return errors.SignatureError("RSA verification failure")
+ }
+ return nil
+ case PubKeyAlgoDSA:
+ dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
+ // Need to truncate hashBytes to match FIPS 186-3 section 4.6.
+ subgroupSize := (dsaPublicKey.Q.BitLen() + 7) / 8
+ if len(hashBytes) > subgroupSize {
+ hashBytes = hashBytes[:subgroupSize]
+ }
+ if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
+ return errors.SignatureError("DSA verification failure")
+ }
+ return nil
+ case PubKeyAlgoECDSA:
+ ecdsaPublicKey := pk.PublicKey.(*ecdsa.PublicKey)
+ if !ecdsa.Verify(ecdsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.ECDSASigR.bytes), new(big.Int).SetBytes(sig.ECDSASigS.bytes)) {
+ return errors.SignatureError("ECDSA verification failure")
+ }
+ return nil
+ default:
+ return errors.SignatureError("Unsupported public key algorithm used in signature")
+ }
+}
+
+// VerifySignatureV3 returns nil iff sig is a valid signature, made by this
+// public key, of the data hashed into signed. signed is mutated by this call.
+func (pk *PublicKey) VerifySignatureV3(signed hash.Hash, sig *SignatureV3) (err error) {
+ if !pk.CanSign() {
+ return errors.InvalidArgumentError("public key cannot generate signatures")
+ }
+
+ suffix := make([]byte, 5)
+ suffix[0] = byte(sig.SigType)
+ binary.BigEndian.PutUint32(suffix[1:], uint32(sig.CreationTime.Unix()))
+ signed.Write(suffix)
+ hashBytes := signed.Sum(nil)
+
+ if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
+ return errors.SignatureError("hash tag doesn't match")
+ }
+
+ if pk.PubKeyAlgo != sig.PubKeyAlgo {
+ return errors.InvalidArgumentError("public key and signature use different algorithms")
+ }
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ rsaPublicKey := pk.PublicKey.(*rsa.PublicKey)
+ if err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, padToKeySize(rsaPublicKey, sig.RSASignature.bytes)); err != nil {
+ return errors.SignatureError("RSA verification failure")
+ }
+ return
+ case PubKeyAlgoDSA:
+ dsaPublicKey := pk.PublicKey.(*dsa.PublicKey)
+ // Need to truncate hashBytes to match FIPS 186-3 section 4.6.
+ subgroupSize := (dsaPublicKey.Q.BitLen() + 7) / 8
+ if len(hashBytes) > subgroupSize {
+ hashBytes = hashBytes[:subgroupSize]
+ }
+ if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
+ return errors.SignatureError("DSA verification failure")
+ }
+ return nil
+ default:
+ panic("shouldn't happen")
+ }
+}
+
+// keySignatureHash returns a Hash of the message that needs to be signed for
+// pk to assert a subkey relationship to signed.
+func keySignatureHash(pk, signed signingKey, hashFunc crypto.Hash) (h hash.Hash, err error) {
+ if !hashFunc.Available() {
+ return nil, errors.UnsupportedError("hash function")
+ }
+ h = hashFunc.New()
+
+ // RFC 4880, section 5.2.4
+ pk.SerializeSignaturePrefix(h)
+ pk.serializeWithoutHeaders(h)
+ signed.SerializeSignaturePrefix(h)
+ signed.serializeWithoutHeaders(h)
+ return
+}
+
+// VerifyKeySignature returns nil iff sig is a valid signature, made by this
+// public key, of signed.
+func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error {
+ h, err := keySignatureHash(pk, signed, sig.Hash)
+ if err != nil {
+ return err
+ }
+ if err = pk.VerifySignature(h, sig); err != nil {
+ return err
+ }
+
+ if sig.FlagSign {
+ // Signing subkeys must be cross-signed. See
+ // https://www.gnupg.org/faq/subkey-cross-certify.html.
+ if sig.EmbeddedSignature == nil {
+ return errors.StructuralError("signing subkey is missing cross-signature")
+ }
+ // Verify the cross-signature. This is calculated over the same
+ // data as the main signature, so we cannot just recursively
+ // call signed.VerifyKeySignature(...)
+ if h, err = keySignatureHash(pk, signed, sig.EmbeddedSignature.Hash); err != nil {
+ return errors.StructuralError("error while hashing for cross-signature: " + err.Error())
+ }
+ if err := signed.VerifySignature(h, sig.EmbeddedSignature); err != nil {
+ return errors.StructuralError("error while verifying cross-signature: " + err.Error())
+ }
+ }
+
+ return nil
+}
+
+func keyRevocationHash(pk signingKey, hashFunc crypto.Hash) (h hash.Hash, err error) {
+ if !hashFunc.Available() {
+ return nil, errors.UnsupportedError("hash function")
+ }
+ h = hashFunc.New()
+
+ // RFC 4880, section 5.2.4
+ pk.SerializeSignaturePrefix(h)
+ pk.serializeWithoutHeaders(h)
+
+ return
+}
+
+// VerifyRevocationSignature returns nil iff sig is a valid signature, made by this
+// public key.
+func (pk *PublicKey) VerifyRevocationSignature(sig *Signature) (err error) {
+ h, err := keyRevocationHash(pk, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return pk.VerifySignature(h, sig)
+}
+
+// userIdSignatureHash returns a Hash of the message that needs to be signed
+// to assert that pk is a valid key for id.
+func userIdSignatureHash(id string, pk *PublicKey, hashFunc crypto.Hash) (h hash.Hash, err error) {
+ if !hashFunc.Available() {
+ return nil, errors.UnsupportedError("hash function")
+ }
+ h = hashFunc.New()
+
+ // RFC 4880, section 5.2.4
+ pk.SerializeSignaturePrefix(h)
+ pk.serializeWithoutHeaders(h)
+
+ var buf [5]byte
+ buf[0] = 0xb4
+ buf[1] = byte(len(id) >> 24)
+ buf[2] = byte(len(id) >> 16)
+ buf[3] = byte(len(id) >> 8)
+ buf[4] = byte(len(id))
+ h.Write(buf[:])
+ h.Write([]byte(id))
+
+ return
+}
+
+// VerifyUserIdSignature returns nil iff sig is a valid signature, made by this
+// public key, that id is the identity of pub.
+func (pk *PublicKey) VerifyUserIdSignature(id string, pub *PublicKey, sig *Signature) (err error) {
+ h, err := userIdSignatureHash(id, pub, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return pk.VerifySignature(h, sig)
+}
+
+// VerifyUserIdSignatureV3 returns nil iff sig is a valid signature, made by this
+// public key, that id is the identity of pub.
+func (pk *PublicKey) VerifyUserIdSignatureV3(id string, pub *PublicKey, sig *SignatureV3) (err error) {
+ h, err := userIdSignatureV3Hash(id, pub, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return pk.VerifySignatureV3(h, sig)
+}
+
+// KeyIdString returns the public key's fingerprint in capital hex
+// (e.g. "6C7EE1B8621CC013").
+func (pk *PublicKey) KeyIdString() string {
+ return fmt.Sprintf("%X", pk.Fingerprint[12:20])
+}
+
+// KeyIdShortString returns the short form of public key's fingerprint
+// in capital hex, as shown by gpg --list-keys (e.g. "621CC013").
+func (pk *PublicKey) KeyIdShortString() string {
+ return fmt.Sprintf("%X", pk.Fingerprint[16:20])
+}
+
+// A parsedMPI is used to store the contents of a big integer, along with the
+// bit length that was specified in the original input. This allows the MPI to
+// be reserialized exactly.
+type parsedMPI struct {
+ bytes []byte
+ bitLength uint16
+}
+
+// writeMPIs is a utility function for serializing several big integers to the
+// given Writer.
+func writeMPIs(w io.Writer, mpis ...parsedMPI) (err error) {
+ for _, mpi := range mpis {
+ err = writeMPI(w, mpi.bitLength, mpi.bytes)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+// BitLength returns the bit length for the given public key.
+func (pk *PublicKey) BitLength() (bitLength uint16, err error) {
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ bitLength = pk.n.bitLength
+ case PubKeyAlgoDSA:
+ bitLength = pk.p.bitLength
+ case PubKeyAlgoElGamal:
+ bitLength = pk.p.bitLength
+ default:
+ err = errors.InvalidArgumentError("bad public-key algorithm")
+ }
+ return
+}
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "crypto"
+ "crypto/md5"
+ "crypto/rsa"
+ "encoding/binary"
+ "fmt"
+ "hash"
+ "io"
+ "math/big"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/openpgp/errors"
+)
+
+// PublicKeyV3 represents older, version 3 public keys. These keys are less secure and
+// should not be used for signing or encrypting. They are supported here only for
+// parsing version 3 key material and validating signatures.
+// See RFC 4880, section 5.5.2.
+type PublicKeyV3 struct {
+ CreationTime time.Time
+ DaysToExpire uint16
+ PubKeyAlgo PublicKeyAlgorithm
+ PublicKey *rsa.PublicKey
+ Fingerprint [16]byte
+ KeyId uint64
+ IsSubkey bool
+
+ n, e parsedMPI
+}
+
+// newRSAPublicKeyV3 returns a PublicKey that wraps the given rsa.PublicKey.
+// Included here for testing purposes only. RFC 4880, section 5.5.2:
+// "an implementation MUST NOT generate a V3 key, but MAY accept it."
+func newRSAPublicKeyV3(creationTime time.Time, pub *rsa.PublicKey) *PublicKeyV3 {
+ pk := &PublicKeyV3{
+ CreationTime: creationTime,
+ PublicKey: pub,
+ n: fromBig(pub.N),
+ e: fromBig(big.NewInt(int64(pub.E))),
+ }
+
+ pk.setFingerPrintAndKeyId()
+ return pk
+}
+
+func (pk *PublicKeyV3) parse(r io.Reader) (err error) {
+ // RFC 4880, section 5.5.2
+ var buf [8]byte
+ if _, err = readFull(r, buf[:]); err != nil {
+ return
+ }
+ if buf[0] < 2 || buf[0] > 3 {
+ return errors.UnsupportedError("public key version")
+ }
+ pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
+ pk.DaysToExpire = binary.BigEndian.Uint16(buf[5:7])
+ pk.PubKeyAlgo = PublicKeyAlgorithm(buf[7])
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ err = pk.parseRSA(r)
+ default:
+ err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
+ }
+ if err != nil {
+ return
+ }
+
+ pk.setFingerPrintAndKeyId()
+ return
+}
+
+func (pk *PublicKeyV3) setFingerPrintAndKeyId() {
+ // RFC 4880, section 12.2
+ fingerPrint := md5.New()
+ fingerPrint.Write(pk.n.bytes)
+ fingerPrint.Write(pk.e.bytes)
+ fingerPrint.Sum(pk.Fingerprint[:0])
+ pk.KeyId = binary.BigEndian.Uint64(pk.n.bytes[len(pk.n.bytes)-8:])
+}
+
+// parseRSA parses RSA public key material from the given Reader. See RFC 4880,
+// section 5.5.2.
+func (pk *PublicKeyV3) parseRSA(r io.Reader) (err error) {
+ if pk.n.bytes, pk.n.bitLength, err = readMPI(r); err != nil {
+ return
+ }
+ if pk.e.bytes, pk.e.bitLength, err = readMPI(r); err != nil {
+ return
+ }
+
+ // RFC 4880 Section 12.2 requires the low 8 bytes of the
+ // modulus to form the key id.
+ if len(pk.n.bytes) < 8 {
+ return errors.StructuralError("v3 public key modulus is too short")
+ }
+ if len(pk.e.bytes) > 3 {
+ err = errors.UnsupportedError("large public exponent")
+ return
+ }
+ rsa := &rsa.PublicKey{N: new(big.Int).SetBytes(pk.n.bytes)}
+ for i := 0; i < len(pk.e.bytes); i++ {
+ rsa.E <<= 8
+ rsa.E |= int(pk.e.bytes[i])
+ }
+ pk.PublicKey = rsa
+ return
+}
+
+// SerializeSignaturePrefix writes the prefix for this public key to the given Writer.
+// The prefix is used when calculating a signature over this public key. See
+// RFC 4880, section 5.2.4.
+func (pk *PublicKeyV3) SerializeSignaturePrefix(w io.Writer) {
+ var pLength uint16
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ pLength += 2 + uint16(len(pk.n.bytes))
+ pLength += 2 + uint16(len(pk.e.bytes))
+ default:
+ panic("unknown public key algorithm")
+ }
+ pLength += 6
+ w.Write([]byte{0x99, byte(pLength >> 8), byte(pLength)})
+ return
+}
+
+func (pk *PublicKeyV3) Serialize(w io.Writer) (err error) {
+ length := 8 // 8 byte header
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ length += 2 + len(pk.n.bytes)
+ length += 2 + len(pk.e.bytes)
+ default:
+ panic("unknown public key algorithm")
+ }
+
+ packetType := packetTypePublicKey
+ if pk.IsSubkey {
+ packetType = packetTypePublicSubkey
+ }
+ if err = serializeHeader(w, packetType, length); err != nil {
+ return
+ }
+ return pk.serializeWithoutHeaders(w)
+}
+
+// serializeWithoutHeaders marshals the PublicKey to w in the form of an
+// OpenPGP public key packet, not including the packet header.
+func (pk *PublicKeyV3) serializeWithoutHeaders(w io.Writer) (err error) {
+ var buf [8]byte
+ // Version 3
+ buf[0] = 3
+ // Creation time
+ t := uint32(pk.CreationTime.Unix())
+ buf[1] = byte(t >> 24)
+ buf[2] = byte(t >> 16)
+ buf[3] = byte(t >> 8)
+ buf[4] = byte(t)
+ // Days to expire
+ buf[5] = byte(pk.DaysToExpire >> 8)
+ buf[6] = byte(pk.DaysToExpire)
+ // Public key algorithm
+ buf[7] = byte(pk.PubKeyAlgo)
+
+ if _, err = w.Write(buf[:]); err != nil {
+ return
+ }
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ return writeMPIs(w, pk.n, pk.e)
+ }
+ return errors.InvalidArgumentError("bad public-key algorithm")
+}
+
+// CanSign returns true iff this public key can generate signatures
+func (pk *PublicKeyV3) CanSign() bool {
+ return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly
+}
+
+// VerifySignatureV3 returns nil iff sig is a valid signature, made by this
+// public key, of the data hashed into signed. signed is mutated by this call.
+func (pk *PublicKeyV3) VerifySignatureV3(signed hash.Hash, sig *SignatureV3) (err error) {
+ if !pk.CanSign() {
+ return errors.InvalidArgumentError("public key cannot generate signatures")
+ }
+
+ suffix := make([]byte, 5)
+ suffix[0] = byte(sig.SigType)
+ binary.BigEndian.PutUint32(suffix[1:], uint32(sig.CreationTime.Unix()))
+ signed.Write(suffix)
+ hashBytes := signed.Sum(nil)
+
+ if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
+ return errors.SignatureError("hash tag doesn't match")
+ }
+
+ if pk.PubKeyAlgo != sig.PubKeyAlgo {
+ return errors.InvalidArgumentError("public key and signature use different algorithms")
+ }
+
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ if err = rsa.VerifyPKCS1v15(pk.PublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes); err != nil {
+ return errors.SignatureError("RSA verification failure")
+ }
+ return
+ default:
+ // V3 public keys only support RSA.
+ panic("shouldn't happen")
+ }
+}
+
+// VerifyUserIdSignatureV3 returns nil iff sig is a valid signature, made by this
+// public key, that id is the identity of pub.
+func (pk *PublicKeyV3) VerifyUserIdSignatureV3(id string, pub *PublicKeyV3, sig *SignatureV3) (err error) {
+ h, err := userIdSignatureV3Hash(id, pk, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return pk.VerifySignatureV3(h, sig)
+}
+
+// VerifyKeySignatureV3 returns nil iff sig is a valid signature, made by this
+// public key, of signed.
+func (pk *PublicKeyV3) VerifyKeySignatureV3(signed *PublicKeyV3, sig *SignatureV3) (err error) {
+ h, err := keySignatureHash(pk, signed, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return pk.VerifySignatureV3(h, sig)
+}
+
+// userIdSignatureV3Hash returns a Hash of the message that needs to be signed
+// to assert that pk is a valid key for id.
+func userIdSignatureV3Hash(id string, pk signingKey, hfn crypto.Hash) (h hash.Hash, err error) {
+ if !hfn.Available() {
+ return nil, errors.UnsupportedError("hash function")
+ }
+ h = hfn.New()
+
+ // RFC 4880, section 5.2.4
+ pk.SerializeSignaturePrefix(h)
+ pk.serializeWithoutHeaders(h)
+
+ h.Write([]byte(id))
+
+ return
+}
+
+// KeyIdString returns the public key's fingerprint in capital hex
+// (e.g. "6C7EE1B8621CC013").
+func (pk *PublicKeyV3) KeyIdString() string {
+ return fmt.Sprintf("%X", pk.KeyId)
+}
+
+// KeyIdShortString returns the short form of public key's fingerprint
+// in capital hex, as shown by gpg --list-keys (e.g. "621CC013").
+func (pk *PublicKeyV3) KeyIdShortString() string {
+ return fmt.Sprintf("%X", pk.KeyId&0xFFFFFFFF)
+}
+
+// BitLength returns the bit length for the given public key.
+func (pk *PublicKeyV3) BitLength() (bitLength uint16, err error) {
+ switch pk.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
+ bitLength = pk.n.bitLength
+ default:
+ err = errors.InvalidArgumentError("bad public-key algorithm")
+ }
+ return
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "golang.org/x/crypto/openpgp/errors"
+ "io"
+)
+
+// Reader reads packets from an io.Reader and allows packets to be 'unread' so
+// that they result from the next call to Next.
+type Reader struct {
+ q []Packet
+ readers []io.Reader
+}
+
+// New io.Readers are pushed when a compressed or encrypted packet is processed
+// and recursively treated as a new source of packets. However, a carefully
+// crafted packet can trigger an infinite recursive sequence of packets. See
+// http://mumble.net/~campbell/misc/pgp-quine
+// https://web.nvd.nist.gov/view/vuln/detail?vulnId=CVE-2013-4402
+// This constant limits the number of recursive packets that may be pushed.
+const maxReaders = 32
+
+// Next returns the most recently unread Packet, or reads another packet from
+// the top-most io.Reader. Unknown packet types are skipped.
+func (r *Reader) Next() (p Packet, err error) {
+ if len(r.q) > 0 {
+ p = r.q[len(r.q)-1]
+ r.q = r.q[:len(r.q)-1]
+ return
+ }
+
+ for len(r.readers) > 0 {
+ p, err = Read(r.readers[len(r.readers)-1])
+ if err == nil {
+ return
+ }
+ if err == io.EOF {
+ r.readers = r.readers[:len(r.readers)-1]
+ continue
+ }
+ if _, ok := err.(errors.UnknownPacketTypeError); !ok {
+ return nil, err
+ }
+ }
+
+ return nil, io.EOF
+}
+
+// Push causes the Reader to start reading from a new io.Reader. When an EOF
+// error is seen from the new io.Reader, it is popped and the Reader continues
+// to read from the next most recent io.Reader. Push returns a StructuralError
+// if pushing the reader would exceed the maximum recursion level, otherwise it
+// returns nil.
+func (r *Reader) Push(reader io.Reader) (err error) {
+ if len(r.readers) >= maxReaders {
+ return errors.StructuralError("too many layers of packets")
+ }
+ r.readers = append(r.readers, reader)
+ return nil
+}
+
+// Unread causes the given Packet to be returned from the next call to Next.
+func (r *Reader) Unread(p Packet) {
+ r.q = append(r.q, p)
+}
+
+func NewReader(r io.Reader) *Reader {
+ return &Reader{
+ q: nil,
+ readers: []io.Reader{r},
+ }
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/dsa"
+ "crypto/ecdsa"
+ "encoding/asn1"
+ "encoding/binary"
+ "hash"
+ "io"
+ "math/big"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/s2k"
+)
+
+const (
+ // See RFC 4880, section 5.2.3.21 for details.
+ KeyFlagCertify = 1 << iota
+ KeyFlagSign
+ KeyFlagEncryptCommunications
+ KeyFlagEncryptStorage
+)
+
+// Signature represents a signature. See RFC 4880, section 5.2.
+type Signature struct {
+ SigType SignatureType
+ PubKeyAlgo PublicKeyAlgorithm
+ Hash crypto.Hash
+
+ // HashSuffix is extra data that is hashed in after the signed data.
+ HashSuffix []byte
+ // HashTag contains the first two bytes of the hash for fast rejection
+ // of bad signed data.
+ HashTag [2]byte
+ CreationTime time.Time
+
+ RSASignature parsedMPI
+ DSASigR, DSASigS parsedMPI
+ ECDSASigR, ECDSASigS parsedMPI
+
+ // rawSubpackets contains the unparsed subpackets, in order.
+ rawSubpackets []outputSubpacket
+
+ // The following are optional so are nil when not included in the
+ // signature.
+
+ SigLifetimeSecs, KeyLifetimeSecs *uint32
+ PreferredSymmetric, PreferredHash, PreferredCompression []uint8
+ IssuerKeyId *uint64
+ IsPrimaryId *bool
+
+ // FlagsValid is set if any flags were given. See RFC 4880, section
+ // 5.2.3.21 for details.
+ FlagsValid bool
+ FlagCertify, FlagSign, FlagEncryptCommunications, FlagEncryptStorage bool
+
+ // RevocationReason is set if this signature has been revoked.
+ // See RFC 4880, section 5.2.3.23 for details.
+ RevocationReason *uint8
+ RevocationReasonText string
+
+ // MDC is set if this signature has a feature packet that indicates
+ // support for MDC subpackets.
+ MDC bool
+
+ // EmbeddedSignature, if non-nil, is a signature of the parent key, by
+ // this key. This prevents an attacker from claiming another's signing
+ // subkey as their own.
+ EmbeddedSignature *Signature
+
+ outSubpackets []outputSubpacket
+}
+
+func (sig *Signature) parse(r io.Reader) (err error) {
+ // RFC 4880, section 5.2.3
+ var buf [5]byte
+ _, err = readFull(r, buf[:1])
+ if err != nil {
+ return
+ }
+ if buf[0] != 4 {
+ err = errors.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0])))
+ return
+ }
+
+ _, err = readFull(r, buf[:5])
+ if err != nil {
+ return
+ }
+ sig.SigType = SignatureType(buf[0])
+ sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1])
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA:
+ default:
+ err = errors.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
+ return
+ }
+
+ var ok bool
+ sig.Hash, ok = s2k.HashIdToHash(buf[2])
+ if !ok {
+ return errors.UnsupportedError("hash function " + strconv.Itoa(int(buf[2])))
+ }
+
+ hashedSubpacketsLength := int(buf[3])<<8 | int(buf[4])
+ l := 6 + hashedSubpacketsLength
+ sig.HashSuffix = make([]byte, l+6)
+ sig.HashSuffix[0] = 4
+ copy(sig.HashSuffix[1:], buf[:5])
+ hashedSubpackets := sig.HashSuffix[6:l]
+ _, err = readFull(r, hashedSubpackets)
+ if err != nil {
+ return
+ }
+ // See RFC 4880, section 5.2.4
+ trailer := sig.HashSuffix[l:]
+ trailer[0] = 4
+ trailer[1] = 0xff
+ trailer[2] = uint8(l >> 24)
+ trailer[3] = uint8(l >> 16)
+ trailer[4] = uint8(l >> 8)
+ trailer[5] = uint8(l)
+
+ err = parseSignatureSubpackets(sig, hashedSubpackets, true)
+ if err != nil {
+ return
+ }
+
+ _, err = readFull(r, buf[:2])
+ if err != nil {
+ return
+ }
+ unhashedSubpacketsLength := int(buf[0])<<8 | int(buf[1])
+ unhashedSubpackets := make([]byte, unhashedSubpacketsLength)
+ _, err = readFull(r, unhashedSubpackets)
+ if err != nil {
+ return
+ }
+ err = parseSignatureSubpackets(sig, unhashedSubpackets, false)
+ if err != nil {
+ return
+ }
+
+ _, err = readFull(r, sig.HashTag[:2])
+ if err != nil {
+ return
+ }
+
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ sig.RSASignature.bytes, sig.RSASignature.bitLength, err = readMPI(r)
+ case PubKeyAlgoDSA:
+ sig.DSASigR.bytes, sig.DSASigR.bitLength, err = readMPI(r)
+ if err == nil {
+ sig.DSASigS.bytes, sig.DSASigS.bitLength, err = readMPI(r)
+ }
+ case PubKeyAlgoECDSA:
+ sig.ECDSASigR.bytes, sig.ECDSASigR.bitLength, err = readMPI(r)
+ if err == nil {
+ sig.ECDSASigS.bytes, sig.ECDSASigS.bitLength, err = readMPI(r)
+ }
+ default:
+ panic("unreachable")
+ }
+ return
+}
+
+// parseSignatureSubpackets parses subpackets of the main signature packet. See
+// RFC 4880, section 5.2.3.1.
+func parseSignatureSubpackets(sig *Signature, subpackets []byte, isHashed bool) (err error) {
+ for len(subpackets) > 0 {
+ subpackets, err = parseSignatureSubpacket(sig, subpackets, isHashed)
+ if err != nil {
+ return
+ }
+ }
+
+ if sig.CreationTime.IsZero() {
+ err = errors.StructuralError("no creation time in signature")
+ }
+
+ return
+}
+
+type signatureSubpacketType uint8
+
+const (
+ creationTimeSubpacket signatureSubpacketType = 2
+ signatureExpirationSubpacket signatureSubpacketType = 3
+ keyExpirationSubpacket signatureSubpacketType = 9
+ prefSymmetricAlgosSubpacket signatureSubpacketType = 11
+ issuerSubpacket signatureSubpacketType = 16
+ prefHashAlgosSubpacket signatureSubpacketType = 21
+ prefCompressionSubpacket signatureSubpacketType = 22
+ primaryUserIdSubpacket signatureSubpacketType = 25
+ keyFlagsSubpacket signatureSubpacketType = 27
+ reasonForRevocationSubpacket signatureSubpacketType = 29
+ featuresSubpacket signatureSubpacketType = 30
+ embeddedSignatureSubpacket signatureSubpacketType = 32
+)
+
+// parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1.
+func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (rest []byte, err error) {
+ // RFC 4880, section 5.2.3.1
+ var (
+ length uint32
+ packetType signatureSubpacketType
+ isCritical bool
+ )
+ switch {
+ case subpacket[0] < 192:
+ length = uint32(subpacket[0])
+ subpacket = subpacket[1:]
+ case subpacket[0] < 255:
+ if len(subpacket) < 2 {
+ goto Truncated
+ }
+ length = uint32(subpacket[0]-192)<<8 + uint32(subpacket[1]) + 192
+ subpacket = subpacket[2:]
+ default:
+ if len(subpacket) < 5 {
+ goto Truncated
+ }
+ length = uint32(subpacket[1])<<24 |
+ uint32(subpacket[2])<<16 |
+ uint32(subpacket[3])<<8 |
+ uint32(subpacket[4])
+ subpacket = subpacket[5:]
+ }
+ if length > uint32(len(subpacket)) {
+ goto Truncated
+ }
+ rest = subpacket[length:]
+ subpacket = subpacket[:length]
+ if len(subpacket) == 0 {
+ err = errors.StructuralError("zero length signature subpacket")
+ return
+ }
+ packetType = signatureSubpacketType(subpacket[0] & 0x7f)
+ isCritical = subpacket[0]&0x80 == 0x80
+ subpacket = subpacket[1:]
+ sig.rawSubpackets = append(sig.rawSubpackets, outputSubpacket{isHashed, packetType, isCritical, subpacket})
+ switch packetType {
+ case creationTimeSubpacket:
+ if !isHashed {
+ err = errors.StructuralError("signature creation time in non-hashed area")
+ return
+ }
+ if len(subpacket) != 4 {
+ err = errors.StructuralError("signature creation time not four bytes")
+ return
+ }
+ t := binary.BigEndian.Uint32(subpacket)
+ sig.CreationTime = time.Unix(int64(t), 0)
+ case signatureExpirationSubpacket:
+ // Signature expiration time, section 5.2.3.10
+ if !isHashed {
+ return
+ }
+ if len(subpacket) != 4 {
+ err = errors.StructuralError("expiration subpacket with bad length")
+ return
+ }
+ sig.SigLifetimeSecs = new(uint32)
+ *sig.SigLifetimeSecs = binary.BigEndian.Uint32(subpacket)
+ case keyExpirationSubpacket:
+ // Key expiration time, section 5.2.3.6
+ if !isHashed {
+ return
+ }
+ if len(subpacket) != 4 {
+ err = errors.StructuralError("key expiration subpacket with bad length")
+ return
+ }
+ sig.KeyLifetimeSecs = new(uint32)
+ *sig.KeyLifetimeSecs = binary.BigEndian.Uint32(subpacket)
+ case prefSymmetricAlgosSubpacket:
+ // Preferred symmetric algorithms, section 5.2.3.7
+ if !isHashed {
+ return
+ }
+ sig.PreferredSymmetric = make([]byte, len(subpacket))
+ copy(sig.PreferredSymmetric, subpacket)
+ case issuerSubpacket:
+ // Issuer, section 5.2.3.5
+ if len(subpacket) != 8 {
+ err = errors.StructuralError("issuer subpacket with bad length")
+ return
+ }
+ sig.IssuerKeyId = new(uint64)
+ *sig.IssuerKeyId = binary.BigEndian.Uint64(subpacket)
+ case prefHashAlgosSubpacket:
+ // Preferred hash algorithms, section 5.2.3.8
+ if !isHashed {
+ return
+ }
+ sig.PreferredHash = make([]byte, len(subpacket))
+ copy(sig.PreferredHash, subpacket)
+ case prefCompressionSubpacket:
+ // Preferred compression algorithms, section 5.2.3.9
+ if !isHashed {
+ return
+ }
+ sig.PreferredCompression = make([]byte, len(subpacket))
+ copy(sig.PreferredCompression, subpacket)
+ case primaryUserIdSubpacket:
+ // Primary User ID, section 5.2.3.19
+ if !isHashed {
+ return
+ }
+ if len(subpacket) != 1 {
+ err = errors.StructuralError("primary user id subpacket with bad length")
+ return
+ }
+ sig.IsPrimaryId = new(bool)
+ if subpacket[0] > 0 {
+ *sig.IsPrimaryId = true
+ }
+ case keyFlagsSubpacket:
+ // Key flags, section 5.2.3.21
+ if !isHashed {
+ return
+ }
+ if len(subpacket) == 0 {
+ err = errors.StructuralError("empty key flags subpacket")
+ return
+ }
+ sig.FlagsValid = true
+ if subpacket[0]&KeyFlagCertify != 0 {
+ sig.FlagCertify = true
+ }
+ if subpacket[0]&KeyFlagSign != 0 {
+ sig.FlagSign = true
+ }
+ if subpacket[0]&KeyFlagEncryptCommunications != 0 {
+ sig.FlagEncryptCommunications = true
+ }
+ if subpacket[0]&KeyFlagEncryptStorage != 0 {
+ sig.FlagEncryptStorage = true
+ }
+ case reasonForRevocationSubpacket:
+ // Reason For Revocation, section 5.2.3.23
+ if !isHashed {
+ return
+ }
+ if len(subpacket) == 0 {
+ err = errors.StructuralError("empty revocation reason subpacket")
+ return
+ }
+ sig.RevocationReason = new(uint8)
+ *sig.RevocationReason = subpacket[0]
+ sig.RevocationReasonText = string(subpacket[1:])
+ case featuresSubpacket:
+ // Features subpacket, section 5.2.3.24 specifies a very general
+ // mechanism for OpenPGP implementations to signal support for new
+ // features. In practice, the subpacket is used exclusively to
+ // indicate support for MDC-protected encryption.
+ sig.MDC = len(subpacket) >= 1 && subpacket[0]&1 == 1
+ case embeddedSignatureSubpacket:
+ // Only usage is in signatures that cross-certify
+ // signing subkeys. section 5.2.3.26 describes the
+ // format, with its usage described in section 11.1
+ if sig.EmbeddedSignature != nil {
+ err = errors.StructuralError("Cannot have multiple embedded signatures")
+ return
+ }
+ sig.EmbeddedSignature = new(Signature)
+ // Embedded signatures are required to be v4 signatures see
+ // section 12.1. However, we only parse v4 signatures in this
+ // file anyway.
+ if err := sig.EmbeddedSignature.parse(bytes.NewBuffer(subpacket)); err != nil {
+ return nil, err
+ }
+ if sigType := sig.EmbeddedSignature.SigType; sigType != SigTypePrimaryKeyBinding {
+ return nil, errors.StructuralError("cross-signature has unexpected type " + strconv.Itoa(int(sigType)))
+ }
+ default:
+ if isCritical {
+ err = errors.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
+ return
+ }
+ }
+ return
+
+Truncated:
+ err = errors.StructuralError("signature subpacket truncated")
+ return
+}
+
+// subpacketLengthLength returns the length, in bytes, of an encoded length value.
+func subpacketLengthLength(length int) int {
+ if length < 192 {
+ return 1
+ }
+ if length < 16320 {
+ return 2
+ }
+ return 5
+}
+
+// serializeSubpacketLength marshals the given length into to.
+func serializeSubpacketLength(to []byte, length int) int {
+ // RFC 4880, Section 4.2.2.
+ if length < 192 {
+ to[0] = byte(length)
+ return 1
+ }
+ if length < 16320 {
+ length -= 192
+ to[0] = byte((length >> 8) + 192)
+ to[1] = byte(length)
+ return 2
+ }
+ to[0] = 255
+ to[1] = byte(length >> 24)
+ to[2] = byte(length >> 16)
+ to[3] = byte(length >> 8)
+ to[4] = byte(length)
+ return 5
+}
+
+// subpacketsLength returns the serialized length, in bytes, of the given
+// subpackets.
+func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
+ for _, subpacket := range subpackets {
+ if subpacket.hashed == hashed {
+ length += subpacketLengthLength(len(subpacket.contents) + 1)
+ length += 1 // type byte
+ length += len(subpacket.contents)
+ }
+ }
+ return
+}
+
+// serializeSubpackets marshals the given subpackets into to.
+func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
+ for _, subpacket := range subpackets {
+ if subpacket.hashed == hashed {
+ n := serializeSubpacketLength(to, len(subpacket.contents)+1)
+ to[n] = byte(subpacket.subpacketType)
+ to = to[1+n:]
+ n = copy(to, subpacket.contents)
+ to = to[n:]
+ }
+ }
+ return
+}
+
+// KeyExpired returns whether sig is a self-signature of a key that has
+// expired.
+func (sig *Signature) KeyExpired(currentTime time.Time) bool {
+ if sig.KeyLifetimeSecs == nil {
+ return false
+ }
+ expiry := sig.CreationTime.Add(time.Duration(*sig.KeyLifetimeSecs) * time.Second)
+ return currentTime.After(expiry)
+}
+
+// buildHashSuffix constructs the HashSuffix member of sig in preparation for signing.
+func (sig *Signature) buildHashSuffix() (err error) {
+ hashedSubpacketsLen := subpacketsLength(sig.outSubpackets, true)
+
+ var ok bool
+ l := 6 + hashedSubpacketsLen
+ sig.HashSuffix = make([]byte, l+6)
+ sig.HashSuffix[0] = 4
+ sig.HashSuffix[1] = uint8(sig.SigType)
+ sig.HashSuffix[2] = uint8(sig.PubKeyAlgo)
+ sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash)
+ if !ok {
+ sig.HashSuffix = nil
+ return errors.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
+ }
+ sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
+ sig.HashSuffix[5] = byte(hashedSubpacketsLen)
+ serializeSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true)
+ trailer := sig.HashSuffix[l:]
+ trailer[0] = 4
+ trailer[1] = 0xff
+ trailer[2] = byte(l >> 24)
+ trailer[3] = byte(l >> 16)
+ trailer[4] = byte(l >> 8)
+ trailer[5] = byte(l)
+ return
+}
+
+func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err error) {
+ err = sig.buildHashSuffix()
+ if err != nil {
+ return
+ }
+
+ h.Write(sig.HashSuffix)
+ digest = h.Sum(nil)
+ copy(sig.HashTag[:], digest)
+ return
+}
+
+// Sign signs a message with a private key. The hash, h, must contain
+// the hash of the message to be signed and will be mutated by this function.
+// On success, the signature is stored in sig. Call Serialize to write it out.
+// If config is nil, sensible defaults will be used.
+func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey, config *Config) (err error) {
+ sig.outSubpackets = sig.buildSubpackets()
+ digest, err := sig.signPrepareHash(h)
+ if err != nil {
+ return
+ }
+
+ switch priv.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ // supports both *rsa.PrivateKey and crypto.Signer
+ sig.RSASignature.bytes, err = priv.PrivateKey.(crypto.Signer).Sign(config.Random(), digest, sig.Hash)
+ sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes))
+ case PubKeyAlgoDSA:
+ dsaPriv := priv.PrivateKey.(*dsa.PrivateKey)
+
+ // Need to truncate hashBytes to match FIPS 186-3 section 4.6.
+ subgroupSize := (dsaPriv.Q.BitLen() + 7) / 8
+ if len(digest) > subgroupSize {
+ digest = digest[:subgroupSize]
+ }
+ r, s, err := dsa.Sign(config.Random(), dsaPriv, digest)
+ if err == nil {
+ sig.DSASigR.bytes = r.Bytes()
+ sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes))
+ sig.DSASigS.bytes = s.Bytes()
+ sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes))
+ }
+ case PubKeyAlgoECDSA:
+ var r, s *big.Int
+ if pk, ok := priv.PrivateKey.(*ecdsa.PrivateKey); ok {
+ // direct support, avoid asn1 wrapping/unwrapping
+ r, s, err = ecdsa.Sign(config.Random(), pk, digest)
+ } else {
+ var b []byte
+ b, err = priv.PrivateKey.(crypto.Signer).Sign(config.Random(), digest, nil)
+ if err == nil {
+ r, s, err = unwrapECDSASig(b)
+ }
+ }
+ if err == nil {
+ sig.ECDSASigR = fromBig(r)
+ sig.ECDSASigS = fromBig(s)
+ }
+ default:
+ err = errors.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
+ }
+
+ return
+}
+
+// unwrapECDSASig parses the two integer components of an ASN.1-encoded ECDSA
+// signature.
+func unwrapECDSASig(b []byte) (r, s *big.Int, err error) {
+ var ecsdaSig struct {
+ R, S *big.Int
+ }
+ _, err = asn1.Unmarshal(b, &ecsdaSig)
+ if err != nil {
+ return
+ }
+ return ecsdaSig.R, ecsdaSig.S, nil
+}
+
+// SignUserId computes a signature from priv, asserting that pub is a valid
+// key for the identity id. On success, the signature is stored in sig. Call
+// Serialize to write it out.
+// If config is nil, sensible defaults will be used.
+func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey, config *Config) error {
+ h, err := userIdSignatureHash(id, pub, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return sig.Sign(h, priv, config)
+}
+
+// SignKey computes a signature from priv, asserting that pub is a subkey. On
+// success, the signature is stored in sig. Call Serialize to write it out.
+// If config is nil, sensible defaults will be used.
+func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey, config *Config) error {
+ h, err := keySignatureHash(&priv.PublicKey, pub, sig.Hash)
+ if err != nil {
+ return err
+ }
+ return sig.Sign(h, priv, config)
+}
+
+// Serialize marshals sig to w. Sign, SignUserId or SignKey must have been
+// called first.
+func (sig *Signature) Serialize(w io.Writer) (err error) {
+ if len(sig.outSubpackets) == 0 {
+ sig.outSubpackets = sig.rawSubpackets
+ }
+ if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil && sig.ECDSASigR.bytes == nil {
+ return errors.InvalidArgumentError("Signature: need to call Sign, SignUserId or SignKey before Serialize")
+ }
+
+ sigLength := 0
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ sigLength = 2 + len(sig.RSASignature.bytes)
+ case PubKeyAlgoDSA:
+ sigLength = 2 + len(sig.DSASigR.bytes)
+ sigLength += 2 + len(sig.DSASigS.bytes)
+ case PubKeyAlgoECDSA:
+ sigLength = 2 + len(sig.ECDSASigR.bytes)
+ sigLength += 2 + len(sig.ECDSASigS.bytes)
+ default:
+ panic("impossible")
+ }
+
+ unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
+ length := len(sig.HashSuffix) - 6 /* trailer not included */ +
+ 2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
+ 2 /* hash tag */ + sigLength
+ err = serializeHeader(w, packetTypeSignature, length)
+ if err != nil {
+ return
+ }
+
+ _, err = w.Write(sig.HashSuffix[:len(sig.HashSuffix)-6])
+ if err != nil {
+ return
+ }
+
+ unhashedSubpackets := make([]byte, 2+unhashedSubpacketsLen)
+ unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8)
+ unhashedSubpackets[1] = byte(unhashedSubpacketsLen)
+ serializeSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false)
+
+ _, err = w.Write(unhashedSubpackets)
+ if err != nil {
+ return
+ }
+ _, err = w.Write(sig.HashTag[:])
+ if err != nil {
+ return
+ }
+
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ err = writeMPIs(w, sig.RSASignature)
+ case PubKeyAlgoDSA:
+ err = writeMPIs(w, sig.DSASigR, sig.DSASigS)
+ case PubKeyAlgoECDSA:
+ err = writeMPIs(w, sig.ECDSASigR, sig.ECDSASigS)
+ default:
+ panic("impossible")
+ }
+ return
+}
+
+// outputSubpacket represents a subpacket to be marshaled.
+type outputSubpacket struct {
+ hashed bool // true if this subpacket is in the hashed area.
+ subpacketType signatureSubpacketType
+ isCritical bool
+ contents []byte
+}
+
+func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
+ creationTime := make([]byte, 4)
+ binary.BigEndian.PutUint32(creationTime, uint32(sig.CreationTime.Unix()))
+ subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, false, creationTime})
+
+ if sig.IssuerKeyId != nil {
+ keyId := make([]byte, 8)
+ binary.BigEndian.PutUint64(keyId, *sig.IssuerKeyId)
+ subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId})
+ }
+
+ if sig.SigLifetimeSecs != nil && *sig.SigLifetimeSecs != 0 {
+ sigLifetime := make([]byte, 4)
+ binary.BigEndian.PutUint32(sigLifetime, *sig.SigLifetimeSecs)
+ subpackets = append(subpackets, outputSubpacket{true, signatureExpirationSubpacket, true, sigLifetime})
+ }
+
+ // Key flags may only appear in self-signatures or certification signatures.
+
+ if sig.FlagsValid {
+ var flags byte
+ if sig.FlagCertify {
+ flags |= KeyFlagCertify
+ }
+ if sig.FlagSign {
+ flags |= KeyFlagSign
+ }
+ if sig.FlagEncryptCommunications {
+ flags |= KeyFlagEncryptCommunications
+ }
+ if sig.FlagEncryptStorage {
+ flags |= KeyFlagEncryptStorage
+ }
+ subpackets = append(subpackets, outputSubpacket{true, keyFlagsSubpacket, false, []byte{flags}})
+ }
+
+ // The following subpackets may only appear in self-signatures
+
+ if sig.KeyLifetimeSecs != nil && *sig.KeyLifetimeSecs != 0 {
+ keyLifetime := make([]byte, 4)
+ binary.BigEndian.PutUint32(keyLifetime, *sig.KeyLifetimeSecs)
+ subpackets = append(subpackets, outputSubpacket{true, keyExpirationSubpacket, true, keyLifetime})
+ }
+
+ if sig.IsPrimaryId != nil && *sig.IsPrimaryId {
+ subpackets = append(subpackets, outputSubpacket{true, primaryUserIdSubpacket, false, []byte{1}})
+ }
+
+ if len(sig.PreferredSymmetric) > 0 {
+ subpackets = append(subpackets, outputSubpacket{true, prefSymmetricAlgosSubpacket, false, sig.PreferredSymmetric})
+ }
+
+ if len(sig.PreferredHash) > 0 {
+ subpackets = append(subpackets, outputSubpacket{true, prefHashAlgosSubpacket, false, sig.PreferredHash})
+ }
+
+ if len(sig.PreferredCompression) > 0 {
+ subpackets = append(subpackets, outputSubpacket{true, prefCompressionSubpacket, false, sig.PreferredCompression})
+ }
+
+ return
+}
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "crypto"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/s2k"
+)
+
+// SignatureV3 represents older version 3 signatures. These signatures are less secure
+// than version 4 and should not be used to create new signatures. They are included
+// here for backwards compatibility to read and validate with older key material.
+// See RFC 4880, section 5.2.2.
+type SignatureV3 struct {
+ SigType SignatureType
+ CreationTime time.Time
+ IssuerKeyId uint64
+ PubKeyAlgo PublicKeyAlgorithm
+ Hash crypto.Hash
+ HashTag [2]byte
+
+ RSASignature parsedMPI
+ DSASigR, DSASigS parsedMPI
+}
+
+func (sig *SignatureV3) parse(r io.Reader) (err error) {
+ // RFC 4880, section 5.2.2
+ var buf [8]byte
+ if _, err = readFull(r, buf[:1]); err != nil {
+ return
+ }
+ if buf[0] < 2 || buf[0] > 3 {
+ err = errors.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0])))
+ return
+ }
+ if _, err = readFull(r, buf[:1]); err != nil {
+ return
+ }
+ if buf[0] != 5 {
+ err = errors.UnsupportedError(
+ "invalid hashed material length " + strconv.Itoa(int(buf[0])))
+ return
+ }
+
+ // Read hashed material: signature type + creation time
+ if _, err = readFull(r, buf[:5]); err != nil {
+ return
+ }
+ sig.SigType = SignatureType(buf[0])
+ t := binary.BigEndian.Uint32(buf[1:5])
+ sig.CreationTime = time.Unix(int64(t), 0)
+
+ // Eight-octet Key ID of signer.
+ if _, err = readFull(r, buf[:8]); err != nil {
+ return
+ }
+ sig.IssuerKeyId = binary.BigEndian.Uint64(buf[:])
+
+ // Public-key and hash algorithm
+ if _, err = readFull(r, buf[:2]); err != nil {
+ return
+ }
+ sig.PubKeyAlgo = PublicKeyAlgorithm(buf[0])
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
+ default:
+ err = errors.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
+ return
+ }
+ var ok bool
+ if sig.Hash, ok = s2k.HashIdToHash(buf[1]); !ok {
+ return errors.UnsupportedError("hash function " + strconv.Itoa(int(buf[2])))
+ }
+
+ // Two-octet field holding left 16 bits of signed hash value.
+ if _, err = readFull(r, sig.HashTag[:2]); err != nil {
+ return
+ }
+
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ sig.RSASignature.bytes, sig.RSASignature.bitLength, err = readMPI(r)
+ case PubKeyAlgoDSA:
+ if sig.DSASigR.bytes, sig.DSASigR.bitLength, err = readMPI(r); err != nil {
+ return
+ }
+ sig.DSASigS.bytes, sig.DSASigS.bitLength, err = readMPI(r)
+ default:
+ panic("unreachable")
+ }
+ return
+}
+
+// Serialize marshals sig to w. Sign, SignUserId or SignKey must have been
+// called first.
+func (sig *SignatureV3) Serialize(w io.Writer) (err error) {
+ buf := make([]byte, 8)
+
+ // Write the sig type and creation time
+ buf[0] = byte(sig.SigType)
+ binary.BigEndian.PutUint32(buf[1:5], uint32(sig.CreationTime.Unix()))
+ if _, err = w.Write(buf[:5]); err != nil {
+ return
+ }
+
+ // Write the issuer long key ID
+ binary.BigEndian.PutUint64(buf[:8], sig.IssuerKeyId)
+ if _, err = w.Write(buf[:8]); err != nil {
+ return
+ }
+
+ // Write public key algorithm, hash ID, and hash value
+ buf[0] = byte(sig.PubKeyAlgo)
+ hashId, ok := s2k.HashToHashId(sig.Hash)
+ if !ok {
+ return errors.UnsupportedError(fmt.Sprintf("hash function %v", sig.Hash))
+ }
+ buf[1] = hashId
+ copy(buf[2:4], sig.HashTag[:])
+ if _, err = w.Write(buf[:4]); err != nil {
+ return
+ }
+
+ if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil {
+ return errors.InvalidArgumentError("Signature: need to call Sign, SignUserId or SignKey before Serialize")
+ }
+
+ switch sig.PubKeyAlgo {
+ case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+ err = writeMPIs(w, sig.RSASignature)
+ case PubKeyAlgoDSA:
+ err = writeMPIs(w, sig.DSASigR, sig.DSASigS)
+ default:
+ panic("impossible")
+ }
+ return
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "bytes"
+ "crypto/cipher"
+ "io"
+ "strconv"
+
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/s2k"
+)
+
+// This is the largest session key that we'll support. Since no 512-bit cipher
+// has even been seriously used, this is comfortably large.
+const maxSessionKeySizeInBytes = 64
+
+// SymmetricKeyEncrypted represents a passphrase protected session key. See RFC
+// 4880, section 5.3.
+type SymmetricKeyEncrypted struct {
+ CipherFunc CipherFunction
+ s2k func(out, in []byte)
+ encryptedKey []byte
+}
+
+const symmetricKeyEncryptedVersion = 4
+
+func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error {
+ // RFC 4880, section 5.3.
+ var buf [2]byte
+ if _, err := readFull(r, buf[:]); err != nil {
+ return err
+ }
+ if buf[0] != symmetricKeyEncryptedVersion {
+ return errors.UnsupportedError("SymmetricKeyEncrypted version")
+ }
+ ske.CipherFunc = CipherFunction(buf[1])
+
+ if ske.CipherFunc.KeySize() == 0 {
+ return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
+ }
+
+ var err error
+ ske.s2k, err = s2k.Parse(r)
+ if err != nil {
+ return err
+ }
+
+ encryptedKey := make([]byte, maxSessionKeySizeInBytes)
+ // The session key may follow. We just have to try and read to find
+ // out. If it exists then we limit it to maxSessionKeySizeInBytes.
+ n, err := readFull(r, encryptedKey)
+ if err != nil && err != io.ErrUnexpectedEOF {
+ return err
+ }
+
+ if n != 0 {
+ if n == maxSessionKeySizeInBytes {
+ return errors.UnsupportedError("oversized encrypted session key")
+ }
+ ske.encryptedKey = encryptedKey[:n]
+ }
+
+ return nil
+}
+
+// Decrypt attempts to decrypt an encrypted session key and returns the key and
+// the cipher to use when decrypting a subsequent Symmetrically Encrypted Data
+// packet.
+func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) ([]byte, CipherFunction, error) {
+ key := make([]byte, ske.CipherFunc.KeySize())
+ ske.s2k(key, passphrase)
+
+ if len(ske.encryptedKey) == 0 {
+ return key, ske.CipherFunc, nil
+ }
+
+ // the IV is all zeros
+ iv := make([]byte, ske.CipherFunc.blockSize())
+ c := cipher.NewCFBDecrypter(ske.CipherFunc.new(key), iv)
+ plaintextKey := make([]byte, len(ske.encryptedKey))
+ c.XORKeyStream(plaintextKey, ske.encryptedKey)
+ cipherFunc := CipherFunction(plaintextKey[0])
+ if cipherFunc.blockSize() == 0 {
+ return nil, ske.CipherFunc, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
+ }
+ plaintextKey = plaintextKey[1:]
+ if l, cipherKeySize := len(plaintextKey), cipherFunc.KeySize(); l != cipherFunc.KeySize() {
+ return nil, cipherFunc, errors.StructuralError("length of decrypted key (" + strconv.Itoa(l) + ") " +
+ "not equal to cipher keysize (" + strconv.Itoa(cipherKeySize) + ")")
+ }
+ return plaintextKey, cipherFunc, nil
+}
+
+// SerializeSymmetricKeyEncrypted serializes a symmetric key packet to w. The
+// packet contains a random session key, encrypted by a key derived from the
+// given passphrase. The session key is returned and must be passed to
+// SerializeSymmetricallyEncrypted.
+// If config is nil, sensible defaults will be used.
+func SerializeSymmetricKeyEncrypted(w io.Writer, passphrase []byte, config *Config) (key []byte, err error) {
+ cipherFunc := config.Cipher()
+ keySize := cipherFunc.KeySize()
+ if keySize == 0 {
+ return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
+ }
+
+ s2kBuf := new(bytes.Buffer)
+ keyEncryptingKey := make([]byte, keySize)
+ // s2k.Serialize salts and stretches the passphrase, and writes the
+ // resulting key to keyEncryptingKey and the s2k descriptor to s2kBuf.
+ err = s2k.Serialize(s2kBuf, keyEncryptingKey, config.Random(), passphrase, &s2k.Config{Hash: config.Hash(), S2KCount: config.PasswordHashIterations()})
+ if err != nil {
+ return
+ }
+ s2kBytes := s2kBuf.Bytes()
+
+ packetLength := 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize
+ err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength)
+ if err != nil {
+ return
+ }
+
+ var buf [2]byte
+ buf[0] = symmetricKeyEncryptedVersion
+ buf[1] = byte(cipherFunc)
+ _, err = w.Write(buf[:])
+ if err != nil {
+ return
+ }
+ _, err = w.Write(s2kBytes)
+ if err != nil {
+ return
+ }
+
+ sessionKey := make([]byte, keySize)
+ _, err = io.ReadFull(config.Random(), sessionKey)
+ if err != nil {
+ return
+ }
+ iv := make([]byte, cipherFunc.blockSize())
+ c := cipher.NewCFBEncrypter(cipherFunc.new(keyEncryptingKey), iv)
+ encryptedCipherAndKey := make([]byte, keySize+1)
+ c.XORKeyStream(encryptedCipherAndKey, buf[1:])
+ c.XORKeyStream(encryptedCipherAndKey[1:], sessionKey)
+ _, err = w.Write(encryptedCipherAndKey)
+ if err != nil {
+ return
+ }
+
+ key = sessionKey
+ return
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "crypto/cipher"
+ "crypto/sha1"
+ "crypto/subtle"
+ "golang.org/x/crypto/openpgp/errors"
+ "hash"
+ "io"
+ "strconv"
+)
+
+// SymmetricallyEncrypted represents a symmetrically encrypted byte string. The
+// encrypted contents will consist of more OpenPGP packets. See RFC 4880,
+// sections 5.7 and 5.13.
+type SymmetricallyEncrypted struct {
+ MDC bool // true iff this is a type 18 packet and thus has an embedded MAC.
+ contents io.Reader
+ prefix []byte
+}
+
+const symmetricallyEncryptedVersion = 1
+
+func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
+ if se.MDC {
+ // See RFC 4880, section 5.13.
+ var buf [1]byte
+ _, err := readFull(r, buf[:])
+ if err != nil {
+ return err
+ }
+ if buf[0] != symmetricallyEncryptedVersion {
+ return errors.UnsupportedError("unknown SymmetricallyEncrypted version")
+ }
+ }
+ se.contents = r
+ return nil
+}
+
+// Decrypt returns a ReadCloser, from which the decrypted contents of the
+// packet can be read. An incorrect key can, with high probability, be detected
+// immediately and this will result in a KeyIncorrect error being returned.
+func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) {
+ keySize := c.KeySize()
+ if keySize == 0 {
+ return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
+ }
+ if len(key) != keySize {
+ return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
+ }
+
+ if se.prefix == nil {
+ se.prefix = make([]byte, c.blockSize()+2)
+ _, err := readFull(se.contents, se.prefix)
+ if err != nil {
+ return nil, err
+ }
+ } else if len(se.prefix) != c.blockSize()+2 {
+ return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
+ }
+
+ ocfbResync := OCFBResync
+ if se.MDC {
+ // MDC packets use a different form of OCFB mode.
+ ocfbResync = OCFBNoResync
+ }
+
+ s := NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
+ if s == nil {
+ return nil, errors.ErrKeyIncorrect
+ }
+
+ plaintext := cipher.StreamReader{S: s, R: se.contents}
+
+ if se.MDC {
+ // MDC packets have an embedded hash that we need to check.
+ h := sha1.New()
+ h.Write(se.prefix)
+ return &seMDCReader{in: plaintext, h: h}, nil
+ }
+
+ // Otherwise, we just need to wrap plaintext so that it's a valid ReadCloser.
+ return seReader{plaintext}, nil
+}
+
+// seReader wraps an io.Reader with a no-op Close method.
+type seReader struct {
+ in io.Reader
+}
+
+func (ser seReader) Read(buf []byte) (int, error) {
+ return ser.in.Read(buf)
+}
+
+func (ser seReader) Close() error {
+ return nil
+}
+
+const mdcTrailerSize = 1 /* tag byte */ + 1 /* length byte */ + sha1.Size
+
+// An seMDCReader wraps an io.Reader, maintains a running hash and keeps hold
+// of the most recent 22 bytes (mdcTrailerSize). Upon EOF, those bytes form an
+// MDC packet containing a hash of the previous contents which is checked
+// against the running hash. See RFC 4880, section 5.13.
+type seMDCReader struct {
+ in io.Reader
+ h hash.Hash
+ trailer [mdcTrailerSize]byte
+ scratch [mdcTrailerSize]byte
+ trailerUsed int
+ error bool
+ eof bool
+}
+
+func (ser *seMDCReader) Read(buf []byte) (n int, err error) {
+ if ser.error {
+ err = io.ErrUnexpectedEOF
+ return
+ }
+ if ser.eof {
+ err = io.EOF
+ return
+ }
+
+ // If we haven't yet filled the trailer buffer then we must do that
+ // first.
+ for ser.trailerUsed < mdcTrailerSize {
+ n, err = ser.in.Read(ser.trailer[ser.trailerUsed:])
+ ser.trailerUsed += n
+ if err == io.EOF {
+ if ser.trailerUsed != mdcTrailerSize {
+ n = 0
+ err = io.ErrUnexpectedEOF
+ ser.error = true
+ return
+ }
+ ser.eof = true
+ n = 0
+ return
+ }
+
+ if err != nil {
+ n = 0
+ return
+ }
+ }
+
+ // If it's a short read then we read into a temporary buffer and shift
+ // the data into the caller's buffer.
+ if len(buf) <= mdcTrailerSize {
+ n, err = readFull(ser.in, ser.scratch[:len(buf)])
+ copy(buf, ser.trailer[:n])
+ ser.h.Write(buf[:n])
+ copy(ser.trailer[:], ser.trailer[n:])
+ copy(ser.trailer[mdcTrailerSize-n:], ser.scratch[:])
+ if n < len(buf) {
+ ser.eof = true
+ err = io.EOF
+ }
+ return
+ }
+
+ n, err = ser.in.Read(buf[mdcTrailerSize:])
+ copy(buf, ser.trailer[:])
+ ser.h.Write(buf[:n])
+ copy(ser.trailer[:], buf[n:])
+
+ if err == io.EOF {
+ ser.eof = true
+ }
+ return
+}
+
+// This is a new-format packet tag byte for a type 19 (MDC) packet.
+const mdcPacketTagByte = byte(0x80) | 0x40 | 19
+
+func (ser *seMDCReader) Close() error {
+ if ser.error {
+ return errors.SignatureError("error during reading")
+ }
+
+ for !ser.eof {
+ // We haven't seen EOF so we need to read to the end
+ var buf [1024]byte
+ _, err := ser.Read(buf[:])
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return errors.SignatureError("error during reading")
+ }
+ }
+
+ if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
+ return errors.SignatureError("MDC packet not found")
+ }
+ ser.h.Write(ser.trailer[:2])
+
+ final := ser.h.Sum(nil)
+ if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
+ return errors.SignatureError("hash mismatch")
+ }
+ return nil
+}
+
+// An seMDCWriter writes through to an io.WriteCloser while maintains a running
+// hash of the data written. On close, it emits an MDC packet containing the
+// running hash.
+type seMDCWriter struct {
+ w io.WriteCloser
+ h hash.Hash
+}
+
+func (w *seMDCWriter) Write(buf []byte) (n int, err error) {
+ w.h.Write(buf)
+ return w.w.Write(buf)
+}
+
+func (w *seMDCWriter) Close() (err error) {
+ var buf [mdcTrailerSize]byte
+
+ buf[0] = mdcPacketTagByte
+ buf[1] = sha1.Size
+ w.h.Write(buf[:2])
+ digest := w.h.Sum(nil)
+ copy(buf[2:], digest)
+
+ _, err = w.w.Write(buf[:])
+ if err != nil {
+ return
+ }
+ return w.w.Close()
+}
+
+// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
+type noOpCloser struct {
+ w io.Writer
+}
+
+func (c noOpCloser) Write(data []byte) (n int, err error) {
+ return c.w.Write(data)
+}
+
+func (c noOpCloser) Close() error {
+ return nil
+}
+
+// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
+// to w and returns a WriteCloser to which the to-be-encrypted packets can be
+// written.
+// If config is nil, sensible defaults will be used.
+func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte, config *Config) (contents io.WriteCloser, err error) {
+ if c.KeySize() != len(key) {
+ return nil, errors.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
+ }
+ writeCloser := noOpCloser{w}
+ ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
+ if err != nil {
+ return
+ }
+
+ _, err = ciphertext.Write([]byte{symmetricallyEncryptedVersion})
+ if err != nil {
+ return
+ }
+
+ block := c.new(key)
+ blockSize := block.BlockSize()
+ iv := make([]byte, blockSize)
+ _, err = config.Random().Read(iv)
+ if err != nil {
+ return
+ }
+ s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync)
+ _, err = ciphertext.Write(prefix)
+ if err != nil {
+ return
+ }
+ plaintext := cipher.StreamWriter{S: s, W: ciphertext}
+
+ h := sha1.New()
+ h.Write(iv)
+ h.Write(iv[blockSize-2:])
+ contents = &seMDCWriter{w: plaintext, h: h}
+ return
+}
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "bytes"
+ "image"
+ "image/jpeg"
+ "io"
+ "io/ioutil"
+)
+
+const UserAttrImageSubpacket = 1
+
+// UserAttribute is capable of storing other types of data about a user
+// beyond name, email and a text comment. In practice, user attributes are typically used
+// to store a signed thumbnail photo JPEG image of the user.
+// See RFC 4880, section 5.12.
+type UserAttribute struct {
+ Contents []*OpaqueSubpacket
+}
+
+// NewUserAttributePhoto creates a user attribute packet
+// containing the given images.
+func NewUserAttributePhoto(photos ...image.Image) (uat *UserAttribute, err error) {
+ uat = new(UserAttribute)
+ for _, photo := range photos {
+ var buf bytes.Buffer
+ // RFC 4880, Section 5.12.1.
+ data := []byte{
+ 0x10, 0x00, // Little-endian image header length (16 bytes)
+ 0x01, // Image header version 1
+ 0x01, // JPEG
+ 0, 0, 0, 0, // 12 reserved octets, must be all zero.
+ 0, 0, 0, 0,
+ 0, 0, 0, 0}
+ if _, err = buf.Write(data); err != nil {
+ return
+ }
+ if err = jpeg.Encode(&buf, photo, nil); err != nil {
+ return
+ }
+ uat.Contents = append(uat.Contents, &OpaqueSubpacket{
+ SubType: UserAttrImageSubpacket,
+ Contents: buf.Bytes()})
+ }
+ return
+}
+
+// NewUserAttribute creates a new user attribute packet containing the given subpackets.
+func NewUserAttribute(contents ...*OpaqueSubpacket) *UserAttribute {
+ return &UserAttribute{Contents: contents}
+}
+
+func (uat *UserAttribute) parse(r io.Reader) (err error) {
+ // RFC 4880, section 5.13
+ b, err := ioutil.ReadAll(r)
+ if err != nil {
+ return
+ }
+ uat.Contents, err = OpaqueSubpackets(b)
+ return
+}
+
+// Serialize marshals the user attribute to w in the form of an OpenPGP packet, including
+// header.
+func (uat *UserAttribute) Serialize(w io.Writer) (err error) {
+ var buf bytes.Buffer
+ for _, sp := range uat.Contents {
+ sp.Serialize(&buf)
+ }
+ if err = serializeHeader(w, packetTypeUserAttribute, buf.Len()); err != nil {
+ return err
+ }
+ _, err = w.Write(buf.Bytes())
+ return
+}
+
+// ImageData returns zero or more byte slices, each containing
+// JPEG File Interchange Format (JFIF), for each photo in the
+// the user attribute packet.
+func (uat *UserAttribute) ImageData() (imageData [][]byte) {
+ for _, sp := range uat.Contents {
+ if sp.SubType == UserAttrImageSubpacket && len(sp.Contents) > 16 {
+ imageData = append(imageData, sp.Contents[16:])
+ }
+ }
+ return
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+ "io"
+ "io/ioutil"
+ "strings"
+)
+
+// UserId contains text that is intended to represent the name and email
+// address of the key holder. See RFC 4880, section 5.11. By convention, this
+// takes the form "Full Name (Comment) <email@example.com>"
+type UserId struct {
+ Id string // By convention, this takes the form "Full Name (Comment) <email@example.com>" which is split out in the fields below.
+
+ Name, Comment, Email string
+}
+
+func hasInvalidCharacters(s string) bool {
+ for _, c := range s {
+ switch c {
+ case '(', ')', '<', '>', 0:
+ return true
+ }
+ }
+ return false
+}
+
+// NewUserId returns a UserId or nil if any of the arguments contain invalid
+// characters. The invalid characters are '\x00', '(', ')', '<' and '>'
+func NewUserId(name, comment, email string) *UserId {
+ // RFC 4880 doesn't deal with the structure of userid strings; the
+ // name, comment and email form is just a convention. However, there's
+ // no convention about escaping the metacharacters and GPG just refuses
+ // to create user ids where, say, the name contains a '('. We mirror
+ // this behaviour.
+
+ if hasInvalidCharacters(name) || hasInvalidCharacters(comment) || hasInvalidCharacters(email) {
+ return nil
+ }
+
+ uid := new(UserId)
+ uid.Name, uid.Comment, uid.Email = name, comment, email
+ uid.Id = name
+ if len(comment) > 0 {
+ if len(uid.Id) > 0 {
+ uid.Id += " "
+ }
+ uid.Id += "("
+ uid.Id += comment
+ uid.Id += ")"
+ }
+ if len(email) > 0 {
+ if len(uid.Id) > 0 {
+ uid.Id += " "
+ }
+ uid.Id += "<"
+ uid.Id += email
+ uid.Id += ">"
+ }
+ return uid
+}
+
+func (uid *UserId) parse(r io.Reader) (err error) {
+ // RFC 4880, section 5.11
+ b, err := ioutil.ReadAll(r)
+ if err != nil {
+ return
+ }
+ uid.Id = string(b)
+ uid.Name, uid.Comment, uid.Email = parseUserId(uid.Id)
+ return
+}
+
+// Serialize marshals uid to w in the form of an OpenPGP packet, including
+// header.
+func (uid *UserId) Serialize(w io.Writer) error {
+ err := serializeHeader(w, packetTypeUserId, len(uid.Id))
+ if err != nil {
+ return err
+ }
+ _, err = w.Write([]byte(uid.Id))
+ return err
+}
+
+// parseUserId extracts the name, comment and email from a user id string that
+// is formatted as "Full Name (Comment) <email@example.com>".
+func parseUserId(id string) (name, comment, email string) {
+ var n, c, e struct {
+ start, end int
+ }
+ var state int
+
+ for offset, rune := range id {
+ switch state {
+ case 0:
+ // Entering name
+ n.start = offset
+ state = 1
+ fallthrough
+ case 1:
+ // In name
+ if rune == '(' {
+ state = 2
+ n.end = offset
+ } else if rune == '<' {
+ state = 5
+ n.end = offset
+ }
+ case 2:
+ // Entering comment
+ c.start = offset
+ state = 3
+ fallthrough
+ case 3:
+ // In comment
+ if rune == ')' {
+ state = 4
+ c.end = offset
+ }
+ case 4:
+ // Between comment and email
+ if rune == '<' {
+ state = 5
+ }
+ case 5:
+ // Entering email
+ e.start = offset
+ state = 6
+ fallthrough
+ case 6:
+ // In email
+ if rune == '>' {
+ state = 7
+ e.end = offset
+ }
+ default:
+ // After email
+ }
+ }
+ switch state {
+ case 1:
+ // ended in the name
+ n.end = len(id)
+ case 3:
+ // ended in comment
+ c.end = len(id)
+ case 6:
+ // ended in email
+ e.end = len(id)
+ }
+
+ name = strings.TrimSpace(id[n.start:n.end])
+ comment = strings.TrimSpace(id[c.start:c.end])
+ email = strings.TrimSpace(id[e.start:e.end])
+ return
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package openpgp implements high level operations on OpenPGP messages.
+package openpgp // import "golang.org/x/crypto/openpgp"
+
+import (
+ "crypto"
+ _ "crypto/sha256"
+ "hash"
+ "io"
+ "strconv"
+
+ "golang.org/x/crypto/openpgp/armor"
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/packet"
+)
+
+// SignatureType is the armor type for a PGP signature.
+var SignatureType = "PGP SIGNATURE"
+
+// readArmored reads an armored block with the given type.
+func readArmored(r io.Reader, expectedType string) (body io.Reader, err error) {
+ block, err := armor.Decode(r)
+ if err != nil {
+ return
+ }
+
+ if block.Type != expectedType {
+ return nil, errors.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
+ }
+
+ return block.Body, nil
+}
+
+// MessageDetails contains the result of parsing an OpenPGP encrypted and/or
+// signed message.
+type MessageDetails struct {
+ IsEncrypted bool // true if the message was encrypted.
+ EncryptedToKeyIds []uint64 // the list of recipient key ids.
+ IsSymmetricallyEncrypted bool // true if a passphrase could have decrypted the message.
+ DecryptedWith Key // the private key used to decrypt the message, if any.
+ IsSigned bool // true if the message is signed.
+ SignedByKeyId uint64 // the key id of the signer, if any.
+ SignedBy *Key // the key of the signer, if available.
+ LiteralData *packet.LiteralData // the metadata of the contents
+ UnverifiedBody io.Reader // the contents of the message.
+
+ // If IsSigned is true and SignedBy is non-zero then the signature will
+ // be verified as UnverifiedBody is read. The signature cannot be
+ // checked until the whole of UnverifiedBody is read so UnverifiedBody
+ // must be consumed until EOF before the data can be trusted. Even if a
+ // message isn't signed (or the signer is unknown) the data may contain
+ // an authentication code that is only checked once UnverifiedBody has
+ // been consumed. Once EOF has been seen, the following fields are
+ // valid. (An authentication code failure is reported as a
+ // SignatureError error when reading from UnverifiedBody.)
+ SignatureError error // nil if the signature is good.
+ Signature *packet.Signature // the signature packet itself, if v4 (default)
+ SignatureV3 *packet.SignatureV3 // the signature packet if it is a v2 or v3 signature
+
+ decrypted io.ReadCloser
+}
+
+// A PromptFunction is used as a callback by functions that may need to decrypt
+// a private key, or prompt for a passphrase. It is called with a list of
+// acceptable, encrypted private keys and a boolean that indicates whether a
+// passphrase is usable. It should either decrypt a private key or return a
+// passphrase to try. If the decrypted private key or given passphrase isn't
+// correct, the function will be called again, forever. Any error returned will
+// be passed up.
+type PromptFunction func(keys []Key, symmetric bool) ([]byte, error)
+
+// A keyEnvelopePair is used to store a private key with the envelope that
+// contains a symmetric key, encrypted with that key.
+type keyEnvelopePair struct {
+ key Key
+ encryptedKey *packet.EncryptedKey
+}
+
+// ReadMessage parses an OpenPGP message that may be signed and/or encrypted.
+// The given KeyRing should contain both public keys (for signature
+// verification) and, possibly encrypted, private keys for decrypting.
+// If config is nil, sensible defaults will be used.
+func ReadMessage(r io.Reader, keyring KeyRing, prompt PromptFunction, config *packet.Config) (md *MessageDetails, err error) {
+ var p packet.Packet
+
+ var symKeys []*packet.SymmetricKeyEncrypted
+ var pubKeys []keyEnvelopePair
+ var se *packet.SymmetricallyEncrypted
+
+ packets := packet.NewReader(r)
+ md = new(MessageDetails)
+ md.IsEncrypted = true
+
+ // The message, if encrypted, starts with a number of packets
+ // containing an encrypted decryption key. The decryption key is either
+ // encrypted to a public key, or with a passphrase. This loop
+ // collects these packets.
+ParsePackets:
+ for {
+ p, err = packets.Next()
+ if err != nil {
+ return nil, err
+ }
+ switch p := p.(type) {
+ case *packet.SymmetricKeyEncrypted:
+ // This packet contains the decryption key encrypted with a passphrase.
+ md.IsSymmetricallyEncrypted = true
+ symKeys = append(symKeys, p)
+ case *packet.EncryptedKey:
+ // This packet contains the decryption key encrypted to a public key.
+ md.EncryptedToKeyIds = append(md.EncryptedToKeyIds, p.KeyId)
+ switch p.Algo {
+ case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal:
+ break
+ default:
+ continue
+ }
+ var keys []Key
+ if p.KeyId == 0 {
+ keys = keyring.DecryptionKeys()
+ } else {
+ keys = keyring.KeysById(p.KeyId)
+ }
+ for _, k := range keys {
+ pubKeys = append(pubKeys, keyEnvelopePair{k, p})
+ }
+ case *packet.SymmetricallyEncrypted:
+ se = p
+ break ParsePackets
+ case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature:
+ // This message isn't encrypted.
+ if len(symKeys) != 0 || len(pubKeys) != 0 {
+ return nil, errors.StructuralError("key material not followed by encrypted message")
+ }
+ packets.Unread(p)
+ return readSignedMessage(packets, nil, keyring)
+ }
+ }
+
+ var candidates []Key
+ var decrypted io.ReadCloser
+
+ // Now that we have the list of encrypted keys we need to decrypt at
+ // least one of them or, if we cannot, we need to call the prompt
+ // function so that it can decrypt a key or give us a passphrase.
+FindKey:
+ for {
+ // See if any of the keys already have a private key available
+ candidates = candidates[:0]
+ candidateFingerprints := make(map[string]bool)
+
+ for _, pk := range pubKeys {
+ if pk.key.PrivateKey == nil {
+ continue
+ }
+ if !pk.key.PrivateKey.Encrypted {
+ if len(pk.encryptedKey.Key) == 0 {
+ pk.encryptedKey.Decrypt(pk.key.PrivateKey, config)
+ }
+ if len(pk.encryptedKey.Key) == 0 {
+ continue
+ }
+ decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key)
+ if err != nil && err != errors.ErrKeyIncorrect {
+ return nil, err
+ }
+ if decrypted != nil {
+ md.DecryptedWith = pk.key
+ break FindKey
+ }
+ } else {
+ fpr := string(pk.key.PublicKey.Fingerprint[:])
+ if v := candidateFingerprints[fpr]; v {
+ continue
+ }
+ candidates = append(candidates, pk.key)
+ candidateFingerprints[fpr] = true
+ }
+ }
+
+ if len(candidates) == 0 && len(symKeys) == 0 {
+ return nil, errors.ErrKeyIncorrect
+ }
+
+ if prompt == nil {
+ return nil, errors.ErrKeyIncorrect
+ }
+
+ passphrase, err := prompt(candidates, len(symKeys) != 0)
+ if err != nil {
+ return nil, err
+ }
+
+ // Try the symmetric passphrase first
+ if len(symKeys) != 0 && passphrase != nil {
+ for _, s := range symKeys {
+ key, cipherFunc, err := s.Decrypt(passphrase)
+ if err == nil {
+ decrypted, err = se.Decrypt(cipherFunc, key)
+ if err != nil && err != errors.ErrKeyIncorrect {
+ return nil, err
+ }
+ if decrypted != nil {
+ break FindKey
+ }
+ }
+
+ }
+ }
+ }
+
+ md.decrypted = decrypted
+ if err := packets.Push(decrypted); err != nil {
+ return nil, err
+ }
+ return readSignedMessage(packets, md, keyring)
+}
+
+// readSignedMessage reads a possibly signed message if mdin is non-zero then
+// that structure is updated and returned. Otherwise a fresh MessageDetails is
+// used.
+func readSignedMessage(packets *packet.Reader, mdin *MessageDetails, keyring KeyRing) (md *MessageDetails, err error) {
+ if mdin == nil {
+ mdin = new(MessageDetails)
+ }
+ md = mdin
+
+ var p packet.Packet
+ var h hash.Hash
+ var wrappedHash hash.Hash
+FindLiteralData:
+ for {
+ p, err = packets.Next()
+ if err != nil {
+ return nil, err
+ }
+ switch p := p.(type) {
+ case *packet.Compressed:
+ if err := packets.Push(p.Body); err != nil {
+ return nil, err
+ }
+ case *packet.OnePassSignature:
+ if !p.IsLast {
+ return nil, errors.UnsupportedError("nested signatures")
+ }
+
+ h, wrappedHash, err = hashForSignature(p.Hash, p.SigType)
+ if err != nil {
+ md = nil
+ return
+ }
+
+ md.IsSigned = true
+ md.SignedByKeyId = p.KeyId
+ keys := keyring.KeysByIdUsage(p.KeyId, packet.KeyFlagSign)
+ if len(keys) > 0 {
+ md.SignedBy = &keys[0]
+ }
+ case *packet.LiteralData:
+ md.LiteralData = p
+ break FindLiteralData
+ }
+ }
+
+ if md.SignedBy != nil {
+ md.UnverifiedBody = &signatureCheckReader{packets, h, wrappedHash, md}
+ } else if md.decrypted != nil {
+ md.UnverifiedBody = checkReader{md}
+ } else {
+ md.UnverifiedBody = md.LiteralData.Body
+ }
+
+ return md, nil
+}
+
+// hashForSignature returns a pair of hashes that can be used to verify a
+// signature. The signature may specify that the contents of the signed message
+// should be preprocessed (i.e. to normalize line endings). Thus this function
+// returns two hashes. The second should be used to hash the message itself and
+// performs any needed preprocessing.
+func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) {
+ if !hashId.Available() {
+ return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
+ }
+ h := hashId.New()
+
+ switch sigType {
+ case packet.SigTypeBinary:
+ return h, h, nil
+ case packet.SigTypeText:
+ return h, NewCanonicalTextHash(h), nil
+ }
+
+ return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
+}
+
+// checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF
+// it closes the ReadCloser from any SymmetricallyEncrypted packet to trigger
+// MDC checks.
+type checkReader struct {
+ md *MessageDetails
+}
+
+func (cr checkReader) Read(buf []byte) (n int, err error) {
+ n, err = cr.md.LiteralData.Body.Read(buf)
+ if err == io.EOF {
+ mdcErr := cr.md.decrypted.Close()
+ if mdcErr != nil {
+ err = mdcErr
+ }
+ }
+ return
+}
+
+// signatureCheckReader wraps an io.Reader from a LiteralData packet and hashes
+// the data as it is read. When it sees an EOF from the underlying io.Reader
+// it parses and checks a trailing Signature packet and triggers any MDC checks.
+type signatureCheckReader struct {
+ packets *packet.Reader
+ h, wrappedHash hash.Hash
+ md *MessageDetails
+}
+
+func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) {
+ n, err = scr.md.LiteralData.Body.Read(buf)
+ scr.wrappedHash.Write(buf[:n])
+ if err == io.EOF {
+ var p packet.Packet
+ p, scr.md.SignatureError = scr.packets.Next()
+ if scr.md.SignatureError != nil {
+ return
+ }
+
+ var ok bool
+ if scr.md.Signature, ok = p.(*packet.Signature); ok {
+ scr.md.SignatureError = scr.md.SignedBy.PublicKey.VerifySignature(scr.h, scr.md.Signature)
+ } else if scr.md.SignatureV3, ok = p.(*packet.SignatureV3); ok {
+ scr.md.SignatureError = scr.md.SignedBy.PublicKey.VerifySignatureV3(scr.h, scr.md.SignatureV3)
+ } else {
+ scr.md.SignatureError = errors.StructuralError("LiteralData not followed by Signature")
+ return
+ }
+
+ // The SymmetricallyEncrypted packet, if any, might have an
+ // unsigned hash of its own. In order to check this we need to
+ // close that Reader.
+ if scr.md.decrypted != nil {
+ mdcErr := scr.md.decrypted.Close()
+ if mdcErr != nil {
+ err = mdcErr
+ }
+ }
+ }
+ return
+}
+
+// CheckDetachedSignature takes a signed file and a detached signature and
+// returns the signer if the signature is valid. If the signer isn't known,
+// ErrUnknownIssuer is returned.
+func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signer *Entity, err error) {
+ var issuerKeyId uint64
+ var hashFunc crypto.Hash
+ var sigType packet.SignatureType
+ var keys []Key
+ var p packet.Packet
+
+ packets := packet.NewReader(signature)
+ for {
+ p, err = packets.Next()
+ if err == io.EOF {
+ return nil, errors.ErrUnknownIssuer
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ switch sig := p.(type) {
+ case *packet.Signature:
+ if sig.IssuerKeyId == nil {
+ return nil, errors.StructuralError("signature doesn't have an issuer")
+ }
+ issuerKeyId = *sig.IssuerKeyId
+ hashFunc = sig.Hash
+ sigType = sig.SigType
+ case *packet.SignatureV3:
+ issuerKeyId = sig.IssuerKeyId
+ hashFunc = sig.Hash
+ sigType = sig.SigType
+ default:
+ return nil, errors.StructuralError("non signature packet found")
+ }
+
+ keys = keyring.KeysByIdUsage(issuerKeyId, packet.KeyFlagSign)
+ if len(keys) > 0 {
+ break
+ }
+ }
+
+ if len(keys) == 0 {
+ panic("unreachable")
+ }
+
+ h, wrappedHash, err := hashForSignature(hashFunc, sigType)
+ if err != nil {
+ return nil, err
+ }
+
+ if _, err := io.Copy(wrappedHash, signed); err != nil && err != io.EOF {
+ return nil, err
+ }
+
+ for _, key := range keys {
+ switch sig := p.(type) {
+ case *packet.Signature:
+ err = key.PublicKey.VerifySignature(h, sig)
+ case *packet.SignatureV3:
+ err = key.PublicKey.VerifySignatureV3(h, sig)
+ default:
+ panic("unreachable")
+ }
+
+ if err == nil {
+ return key.Entity, nil
+ }
+ }
+
+ return nil, err
+}
+
+// CheckArmoredDetachedSignature performs the same actions as
+// CheckDetachedSignature but expects the signature to be armored.
+func CheckArmoredDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signer *Entity, err error) {
+ body, err := readArmored(signature, SignatureType)
+ if err != nil {
+ return
+ }
+
+ return CheckDetachedSignature(keyring, signed, body)
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package s2k implements the various OpenPGP string-to-key transforms as
+// specified in RFC 4800 section 3.7.1.
+package s2k // import "golang.org/x/crypto/openpgp/s2k"
+
+import (
+ "crypto"
+ "hash"
+ "io"
+ "strconv"
+
+ "golang.org/x/crypto/openpgp/errors"
+)
+
+// Config collects configuration parameters for s2k key-stretching
+// transformatioms. A nil *Config is valid and results in all default
+// values. Currently, Config is used only by the Serialize function in
+// this package.
+type Config struct {
+ // Hash is the default hash function to be used. If
+ // nil, SHA1 is used.
+ Hash crypto.Hash
+ // S2KCount is only used for symmetric encryption. It
+ // determines the strength of the passphrase stretching when
+ // the said passphrase is hashed to produce a key. S2KCount
+ // should be between 1024 and 65011712, inclusive. If Config
+ // is nil or S2KCount is 0, the value 65536 used. Not all
+ // values in the above range can be represented. S2KCount will
+ // be rounded up to the next representable value if it cannot
+ // be encoded exactly. When set, it is strongly encrouraged to
+ // use a value that is at least 65536. See RFC 4880 Section
+ // 3.7.1.3.
+ S2KCount int
+}
+
+func (c *Config) hash() crypto.Hash {
+ if c == nil || uint(c.Hash) == 0 {
+ // SHA1 is the historical default in this package.
+ return crypto.SHA1
+ }
+
+ return c.Hash
+}
+
+func (c *Config) encodedCount() uint8 {
+ if c == nil || c.S2KCount == 0 {
+ return 96 // The common case. Correspoding to 65536
+ }
+
+ i := c.S2KCount
+ switch {
+ // Behave like GPG. Should we make 65536 the lowest value used?
+ case i < 1024:
+ i = 1024
+ case i > 65011712:
+ i = 65011712
+ }
+
+ return encodeCount(i)
+}
+
+// encodeCount converts an iterative "count" in the range 1024 to
+// 65011712, inclusive, to an encoded count. The return value is the
+// octet that is actually stored in the GPG file. encodeCount panics
+// if i is not in the above range (encodedCount above takes care to
+// pass i in the correct range). See RFC 4880 Section 3.7.7.1.
+func encodeCount(i int) uint8 {
+ if i < 1024 || i > 65011712 {
+ panic("count arg i outside the required range")
+ }
+
+ for encoded := 0; encoded < 256; encoded++ {
+ count := decodeCount(uint8(encoded))
+ if count >= i {
+ return uint8(encoded)
+ }
+ }
+
+ return 255
+}
+
+// decodeCount returns the s2k mode 3 iterative "count" corresponding to
+// the encoded octet c.
+func decodeCount(c uint8) int {
+ return (16 + int(c&15)) << (uint32(c>>4) + 6)
+}
+
+// Simple writes to out the result of computing the Simple S2K function (RFC
+// 4880, section 3.7.1.1) using the given hash and input passphrase.
+func Simple(out []byte, h hash.Hash, in []byte) {
+ Salted(out, h, in, nil)
+}
+
+var zero [1]byte
+
+// Salted writes to out the result of computing the Salted S2K function (RFC
+// 4880, section 3.7.1.2) using the given hash, input passphrase and salt.
+func Salted(out []byte, h hash.Hash, in []byte, salt []byte) {
+ done := 0
+ var digest []byte
+
+ for i := 0; done < len(out); i++ {
+ h.Reset()
+ for j := 0; j < i; j++ {
+ h.Write(zero[:])
+ }
+ h.Write(salt)
+ h.Write(in)
+ digest = h.Sum(digest[:0])
+ n := copy(out[done:], digest)
+ done += n
+ }
+}
+
+// Iterated writes to out the result of computing the Iterated and Salted S2K
+// function (RFC 4880, section 3.7.1.3) using the given hash, input passphrase,
+// salt and iteration count.
+func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) {
+ combined := make([]byte, len(in)+len(salt))
+ copy(combined, salt)
+ copy(combined[len(salt):], in)
+
+ if count < len(combined) {
+ count = len(combined)
+ }
+
+ done := 0
+ var digest []byte
+ for i := 0; done < len(out); i++ {
+ h.Reset()
+ for j := 0; j < i; j++ {
+ h.Write(zero[:])
+ }
+ written := 0
+ for written < count {
+ if written+len(combined) > count {
+ todo := count - written
+ h.Write(combined[:todo])
+ written = count
+ } else {
+ h.Write(combined)
+ written += len(combined)
+ }
+ }
+ digest = h.Sum(digest[:0])
+ n := copy(out[done:], digest)
+ done += n
+ }
+}
+
+// Parse reads a binary specification for a string-to-key transformation from r
+// and returns a function which performs that transform.
+func Parse(r io.Reader) (f func(out, in []byte), err error) {
+ var buf [9]byte
+
+ _, err = io.ReadFull(r, buf[:2])
+ if err != nil {
+ return
+ }
+
+ hash, ok := HashIdToHash(buf[1])
+ if !ok {
+ return nil, errors.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
+ }
+ if !hash.Available() {
+ return nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
+ }
+ h := hash.New()
+
+ switch buf[0] {
+ case 0:
+ f := func(out, in []byte) {
+ Simple(out, h, in)
+ }
+ return f, nil
+ case 1:
+ _, err = io.ReadFull(r, buf[:8])
+ if err != nil {
+ return
+ }
+ f := func(out, in []byte) {
+ Salted(out, h, in, buf[:8])
+ }
+ return f, nil
+ case 3:
+ _, err = io.ReadFull(r, buf[:9])
+ if err != nil {
+ return
+ }
+ count := decodeCount(buf[8])
+ f := func(out, in []byte) {
+ Iterated(out, h, in, buf[:8], count)
+ }
+ return f, nil
+ }
+
+ return nil, errors.UnsupportedError("S2K function")
+}
+
+// Serialize salts and stretches the given passphrase and writes the
+// resulting key into key. It also serializes an S2K descriptor to
+// w. The key stretching can be configured with c, which may be
+// nil. In that case, sensible defaults will be used.
+func Serialize(w io.Writer, key []byte, rand io.Reader, passphrase []byte, c *Config) error {
+ var buf [11]byte
+ buf[0] = 3 /* iterated and salted */
+ buf[1], _ = HashToHashId(c.hash())
+ salt := buf[2:10]
+ if _, err := io.ReadFull(rand, salt); err != nil {
+ return err
+ }
+ encodedCount := c.encodedCount()
+ count := decodeCount(encodedCount)
+ buf[10] = encodedCount
+ if _, err := w.Write(buf[:]); err != nil {
+ return err
+ }
+
+ Iterated(key, c.hash().New(), passphrase, salt, count)
+ return nil
+}
+
+// hashToHashIdMapping contains pairs relating OpenPGP's hash identifier with
+// Go's crypto.Hash type. See RFC 4880, section 9.4.
+var hashToHashIdMapping = []struct {
+ id byte
+ hash crypto.Hash
+ name string
+}{
+ {1, crypto.MD5, "MD5"},
+ {2, crypto.SHA1, "SHA1"},
+ {3, crypto.RIPEMD160, "RIPEMD160"},
+ {8, crypto.SHA256, "SHA256"},
+ {9, crypto.SHA384, "SHA384"},
+ {10, crypto.SHA512, "SHA512"},
+ {11, crypto.SHA224, "SHA224"},
+}
+
+// HashIdToHash returns a crypto.Hash which corresponds to the given OpenPGP
+// hash id.
+func HashIdToHash(id byte) (h crypto.Hash, ok bool) {
+ for _, m := range hashToHashIdMapping {
+ if m.id == id {
+ return m.hash, true
+ }
+ }
+ return 0, false
+}
+
+// HashIdToString returns the name of the hash function corresponding to the
+// given OpenPGP hash id.
+func HashIdToString(id byte) (name string, ok bool) {
+ for _, m := range hashToHashIdMapping {
+ if m.id == id {
+ return m.name, true
+ }
+ }
+
+ return "", false
+}
+
+// HashIdToHash returns an OpenPGP hash id which corresponds the given Hash.
+func HashToHashId(h crypto.Hash) (id byte, ok bool) {
+ for _, m := range hashToHashIdMapping {
+ if m.hash == h {
+ return m.id, true
+ }
+ }
+ return 0, false
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package openpgp
+
+import (
+ "crypto"
+ "hash"
+ "io"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/openpgp/armor"
+ "golang.org/x/crypto/openpgp/errors"
+ "golang.org/x/crypto/openpgp/packet"
+ "golang.org/x/crypto/openpgp/s2k"
+)
+
+// DetachSign signs message with the private key from signer (which must
+// already have been decrypted) and writes the signature to w.
+// If config is nil, sensible defaults will be used.
+func DetachSign(w io.Writer, signer *Entity, message io.Reader, config *packet.Config) error {
+ return detachSign(w, signer, message, packet.SigTypeBinary, config)
+}
+
+// ArmoredDetachSign signs message with the private key from signer (which
+// must already have been decrypted) and writes an armored signature to w.
+// If config is nil, sensible defaults will be used.
+func ArmoredDetachSign(w io.Writer, signer *Entity, message io.Reader, config *packet.Config) (err error) {
+ return armoredDetachSign(w, signer, message, packet.SigTypeBinary, config)
+}
+
+// DetachSignText signs message (after canonicalising the line endings) with
+// the private key from signer (which must already have been decrypted) and
+// writes the signature to w.
+// If config is nil, sensible defaults will be used.
+func DetachSignText(w io.Writer, signer *Entity, message io.Reader, config *packet.Config) error {
+ return detachSign(w, signer, message, packet.SigTypeText, config)
+}
+
+// ArmoredDetachSignText signs message (after canonicalising the line endings)
+// with the private key from signer (which must already have been decrypted)
+// and writes an armored signature to w.
+// If config is nil, sensible defaults will be used.
+func ArmoredDetachSignText(w io.Writer, signer *Entity, message io.Reader, config *packet.Config) error {
+ return armoredDetachSign(w, signer, message, packet.SigTypeText, config)
+}
+
+func armoredDetachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType, config *packet.Config) (err error) {
+ out, err := armor.Encode(w, SignatureType, nil)
+ if err != nil {
+ return
+ }
+ err = detachSign(out, signer, message, sigType, config)
+ if err != nil {
+ return
+ }
+ return out.Close()
+}
+
+func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType, config *packet.Config) (err error) {
+ if signer.PrivateKey == nil {
+ return errors.InvalidArgumentError("signing key doesn't have a private key")
+ }
+ if signer.PrivateKey.Encrypted {
+ return errors.InvalidArgumentError("signing key is encrypted")
+ }
+
+ sig := new(packet.Signature)
+ sig.SigType = sigType
+ sig.PubKeyAlgo = signer.PrivateKey.PubKeyAlgo
+ sig.Hash = config.Hash()
+ sig.CreationTime = config.Now()
+ sig.IssuerKeyId = &signer.PrivateKey.KeyId
+
+ h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType)
+ if err != nil {
+ return
+ }
+ io.Copy(wrappedHash, message)
+
+ err = sig.Sign(h, signer.PrivateKey, config)
+ if err != nil {
+ return
+ }
+
+ return sig.Serialize(w)
+}
+
+// FileHints contains metadata about encrypted files. This metadata is, itself,
+// encrypted.
+type FileHints struct {
+ // IsBinary can be set to hint that the contents are binary data.
+ IsBinary bool
+ // FileName hints at the name of the file that should be written. It's
+ // truncated to 255 bytes if longer. It may be empty to suggest that the
+ // file should not be written to disk. It may be equal to "_CONSOLE" to
+ // suggest the data should not be written to disk.
+ FileName string
+ // ModTime contains the modification time of the file, or the zero time if not applicable.
+ ModTime time.Time
+}
+
+// SymmetricallyEncrypt acts like gpg -c: it encrypts a file with a passphrase.
+// The resulting WriteCloser must be closed after the contents of the file have
+// been written.
+// If config is nil, sensible defaults will be used.
+func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHints, config *packet.Config) (plaintext io.WriteCloser, err error) {
+ if hints == nil {
+ hints = &FileHints{}
+ }
+
+ key, err := packet.SerializeSymmetricKeyEncrypted(ciphertext, passphrase, config)
+ if err != nil {
+ return
+ }
+ w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, config.Cipher(), key, config)
+ if err != nil {
+ return
+ }
+
+ literaldata := w
+ if algo := config.Compression(); algo != packet.CompressionNone {
+ var compConfig *packet.CompressionConfig
+ if config != nil {
+ compConfig = config.CompressionConfig
+ }
+ literaldata, err = packet.SerializeCompressed(w, algo, compConfig)
+ if err != nil {
+ return
+ }
+ }
+
+ var epochSeconds uint32
+ if !hints.ModTime.IsZero() {
+ epochSeconds = uint32(hints.ModTime.Unix())
+ }
+ return packet.SerializeLiteral(literaldata, hints.IsBinary, hints.FileName, epochSeconds)
+}
+
+// intersectPreferences mutates and returns a prefix of a that contains only
+// the values in the intersection of a and b. The order of a is preserved.
+func intersectPreferences(a []uint8, b []uint8) (intersection []uint8) {
+ var j int
+ for _, v := range a {
+ for _, v2 := range b {
+ if v == v2 {
+ a[j] = v
+ j++
+ break
+ }
+ }
+ }
+
+ return a[:j]
+}
+
+func hashToHashId(h crypto.Hash) uint8 {
+ v, ok := s2k.HashToHashId(h)
+ if !ok {
+ panic("tried to convert unknown hash")
+ }
+ return v
+}
+
+// Encrypt encrypts a message to a number of recipients and, optionally, signs
+// it. hints contains optional information, that is also encrypted, that aids
+// the recipients in processing the message. The resulting WriteCloser must
+// be closed after the contents of the file have been written.
+// If config is nil, sensible defaults will be used.
+func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHints, config *packet.Config) (plaintext io.WriteCloser, err error) {
+ var signer *packet.PrivateKey
+ if signed != nil {
+ signKey, ok := signed.signingKey(config.Now())
+ if !ok {
+ return nil, errors.InvalidArgumentError("no valid signing keys")
+ }
+ signer = signKey.PrivateKey
+ if signer == nil {
+ return nil, errors.InvalidArgumentError("no private key in signing key")
+ }
+ if signer.Encrypted {
+ return nil, errors.InvalidArgumentError("signing key must be decrypted")
+ }
+ }
+
+ // These are the possible ciphers that we'll use for the message.
+ candidateCiphers := []uint8{
+ uint8(packet.CipherAES128),
+ uint8(packet.CipherAES256),
+ uint8(packet.CipherCAST5),
+ }
+ // These are the possible hash functions that we'll use for the signature.
+ candidateHashes := []uint8{
+ hashToHashId(crypto.SHA256),
+ hashToHashId(crypto.SHA512),
+ hashToHashId(crypto.SHA1),
+ hashToHashId(crypto.RIPEMD160),
+ }
+ // In the event that a recipient doesn't specify any supported ciphers
+ // or hash functions, these are the ones that we assume that every
+ // implementation supports.
+ defaultCiphers := candidateCiphers[len(candidateCiphers)-1:]
+ defaultHashes := candidateHashes[len(candidateHashes)-1:]
+
+ encryptKeys := make([]Key, len(to))
+ for i := range to {
+ var ok bool
+ encryptKeys[i], ok = to[i].encryptionKey(config.Now())
+ if !ok {
+ return nil, errors.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
+ }
+
+ sig := to[i].primaryIdentity().SelfSignature
+
+ preferredSymmetric := sig.PreferredSymmetric
+ if len(preferredSymmetric) == 0 {
+ preferredSymmetric = defaultCiphers
+ }
+ preferredHashes := sig.PreferredHash
+ if len(preferredHashes) == 0 {
+ preferredHashes = defaultHashes
+ }
+ candidateCiphers = intersectPreferences(candidateCiphers, preferredSymmetric)
+ candidateHashes = intersectPreferences(candidateHashes, preferredHashes)
+ }
+
+ if len(candidateCiphers) == 0 || len(candidateHashes) == 0 {
+ return nil, errors.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
+ }
+
+ cipher := packet.CipherFunction(candidateCiphers[0])
+ // If the cipher specified by config is a candidate, we'll use that.
+ configuredCipher := config.Cipher()
+ for _, c := range candidateCiphers {
+ cipherFunc := packet.CipherFunction(c)
+ if cipherFunc == configuredCipher {
+ cipher = cipherFunc
+ break
+ }
+ }
+
+ var hash crypto.Hash
+ for _, hashId := range candidateHashes {
+ if h, ok := s2k.HashIdToHash(hashId); ok && h.Available() {
+ hash = h
+ break
+ }
+ }
+
+ // If the hash specified by config is a candidate, we'll use that.
+ if configuredHash := config.Hash(); configuredHash.Available() {
+ for _, hashId := range candidateHashes {
+ if h, ok := s2k.HashIdToHash(hashId); ok && h == configuredHash {
+ hash = h
+ break
+ }
+ }
+ }
+
+ if hash == 0 {
+ hashId := candidateHashes[0]
+ name, ok := s2k.HashIdToString(hashId)
+ if !ok {
+ name = "#" + strconv.Itoa(int(hashId))
+ }
+ return nil, errors.InvalidArgumentError("cannot encrypt because no candidate hash functions are compiled in. (Wanted " + name + " in this case.)")
+ }
+
+ symKey := make([]byte, cipher.KeySize())
+ if _, err := io.ReadFull(config.Random(), symKey); err != nil {
+ return nil, err
+ }
+
+ for _, key := range encryptKeys {
+ if err := packet.SerializeEncryptedKey(ciphertext, key.PublicKey, cipher, symKey, config); err != nil {
+ return nil, err
+ }
+ }
+
+ encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, cipher, symKey, config)
+ if err != nil {
+ return
+ }
+
+ if signer != nil {
+ ops := &packet.OnePassSignature{
+ SigType: packet.SigTypeBinary,
+ Hash: hash,
+ PubKeyAlgo: signer.PubKeyAlgo,
+ KeyId: signer.KeyId,
+ IsLast: true,
+ }
+ if err := ops.Serialize(encryptedData); err != nil {
+ return nil, err
+ }
+ }
+
+ if hints == nil {
+ hints = &FileHints{}
+ }
+
+ w := encryptedData
+ if signer != nil {
+ // If we need to write a signature packet after the literal
+ // data then we need to stop literalData from closing
+ // encryptedData.
+ w = noOpCloser{encryptedData}
+
+ }
+ var epochSeconds uint32
+ if !hints.ModTime.IsZero() {
+ epochSeconds = uint32(hints.ModTime.Unix())
+ }
+ literalData, err := packet.SerializeLiteral(w, hints.IsBinary, hints.FileName, epochSeconds)
+ if err != nil {
+ return nil, err
+ }
+
+ if signer != nil {
+ return signatureWriter{encryptedData, literalData, hash, hash.New(), signer, config}, nil
+ }
+ return literalData, nil
+}
+
+// signatureWriter hashes the contents of a message while passing it along to
+// literalData. When closed, it closes literalData, writes a signature packet
+// to encryptedData and then also closes encryptedData.
+type signatureWriter struct {
+ encryptedData io.WriteCloser
+ literalData io.WriteCloser
+ hashType crypto.Hash
+ h hash.Hash
+ signer *packet.PrivateKey
+ config *packet.Config
+}
+
+func (s signatureWriter) Write(data []byte) (int, error) {
+ s.h.Write(data)
+ return s.literalData.Write(data)
+}
+
+func (s signatureWriter) Close() error {
+ sig := &packet.Signature{
+ SigType: packet.SigTypeBinary,
+ PubKeyAlgo: s.signer.PubKeyAlgo,
+ Hash: s.hashType,
+ CreationTime: s.config.Now(),
+ IssuerKeyId: &s.signer.KeyId,
+ }
+
+ if err := sig.Sign(s.h, s.signer, s.config); err != nil {
+ return err
+ }
+ if err := s.literalData.Close(); err != nil {
+ return err
+ }
+ if err := sig.Serialize(s.encryptedData); err != nil {
+ return err
+ }
+ return s.encryptedData.Close()
+}
+
+// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
+// TODO: we have two of these in OpenPGP packages alone. This probably needs
+// to be promoted somewhere more common.
+type noOpCloser struct {
+ w io.Writer
+}
+
+func (c noOpCloser) Write(data []byte) (n int, err error) {
+ return c.w.Write(data)
+}
+
+func (c noOpCloser) Close() error {
+ return nil
+}
--- /dev/null
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package agent implements the ssh-agent protocol, and provides both
+// a client and a server. The client can talk to a standard ssh-agent
+// that uses UNIX sockets, and one could implement an alternative
+// ssh-agent process using the sample server.
+//
+// References:
+// [PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
+package agent // import "golang.org/x/crypto/ssh/agent"
+
+import (
+ "bytes"
+ "crypto/dsa"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "sync"
+
+ "golang.org/x/crypto/ed25519"
+ "golang.org/x/crypto/ssh"
+)
+
+// Agent represents the capabilities of an ssh-agent.
+type Agent interface {
+ // List returns the identities known to the agent.
+ List() ([]*Key, error)
+
+ // Sign has the agent sign the data using a protocol 2 key as defined
+ // in [PROTOCOL.agent] section 2.6.2.
+ Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
+
+ // Add adds a private key to the agent.
+ Add(key AddedKey) error
+
+ // Remove removes all identities with the given public key.
+ Remove(key ssh.PublicKey) error
+
+ // RemoveAll removes all identities.
+ RemoveAll() error
+
+ // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
+ Lock(passphrase []byte) error
+
+ // Unlock undoes the effect of Lock
+ Unlock(passphrase []byte) error
+
+ // Signers returns signers for all the known keys.
+ Signers() ([]ssh.Signer, error)
+}
+
+// ConstraintExtension describes an optional constraint defined by users.
+type ConstraintExtension struct {
+ // ExtensionName consist of a UTF-8 string suffixed by the
+ // implementation domain following the naming scheme defined
+ // in Section 4.2 of [RFC4251], e.g. "foo@example.com".
+ ExtensionName string
+ // ExtensionDetails contains the actual content of the extended
+ // constraint.
+ ExtensionDetails []byte
+}
+
+// AddedKey describes an SSH key to be added to an Agent.
+type AddedKey struct {
+ // PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
+ // *ecdsa.PrivateKey, which will be inserted into the agent.
+ PrivateKey interface{}
+ // Certificate, if not nil, is communicated to the agent and will be
+ // stored with the key.
+ Certificate *ssh.Certificate
+ // Comment is an optional, free-form string.
+ Comment string
+ // LifetimeSecs, if not zero, is the number of seconds that the
+ // agent will store the key for.
+ LifetimeSecs uint32
+ // ConfirmBeforeUse, if true, requests that the agent confirm with the
+ // user before each use of this key.
+ ConfirmBeforeUse bool
+ // ConstraintExtensions are the experimental or private-use constraints
+ // defined by users.
+ ConstraintExtensions []ConstraintExtension
+}
+
+// See [PROTOCOL.agent], section 3.
+const (
+ agentRequestV1Identities = 1
+ agentRemoveAllV1Identities = 9
+
+ // 3.2 Requests from client to agent for protocol 2 key operations
+ agentAddIdentity = 17
+ agentRemoveIdentity = 18
+ agentRemoveAllIdentities = 19
+ agentAddIDConstrained = 25
+
+ // 3.3 Key-type independent requests from client to agent
+ agentAddSmartcardKey = 20
+ agentRemoveSmartcardKey = 21
+ agentLock = 22
+ agentUnlock = 23
+ agentAddSmartcardKeyConstrained = 26
+
+ // 3.7 Key constraint identifiers
+ agentConstrainLifetime = 1
+ agentConstrainConfirm = 2
+ agentConstrainExtension = 3
+)
+
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
+// is a sanity check, not a limit in the spec.
+const maxAgentResponseBytes = 16 << 20
+
+// Agent messages:
+// These structures mirror the wire format of the corresponding ssh agent
+// messages found in [PROTOCOL.agent].
+
+// 3.4 Generic replies from agent to client
+const agentFailure = 5
+
+type failureAgentMsg struct{}
+
+const agentSuccess = 6
+
+type successAgentMsg struct{}
+
+// See [PROTOCOL.agent], section 2.5.2.
+const agentRequestIdentities = 11
+
+type requestIdentitiesAgentMsg struct{}
+
+// See [PROTOCOL.agent], section 2.5.2.
+const agentIdentitiesAnswer = 12
+
+type identitiesAnswerAgentMsg struct {
+ NumKeys uint32 `sshtype:"12"`
+ Keys []byte `ssh:"rest"`
+}
+
+// See [PROTOCOL.agent], section 2.6.2.
+const agentSignRequest = 13
+
+type signRequestAgentMsg struct {
+ KeyBlob []byte `sshtype:"13"`
+ Data []byte
+ Flags uint32
+}
+
+// See [PROTOCOL.agent], section 2.6.2.
+
+// 3.6 Replies from agent to client for protocol 2 key operations
+const agentSignResponse = 14
+
+type signResponseAgentMsg struct {
+ SigBlob []byte `sshtype:"14"`
+}
+
+type publicKey struct {
+ Format string
+ Rest []byte `ssh:"rest"`
+}
+
+// 3.7 Key constraint identifiers
+type constrainLifetimeAgentMsg struct {
+ LifetimeSecs uint32 `sshtype:"1"`
+}
+
+type constrainExtensionAgentMsg struct {
+ ExtensionName string `sshtype:"3"`
+ ExtensionDetails []byte
+
+ // Rest is a field used for parsing, not part of message
+ Rest []byte `ssh:"rest"`
+}
+
+// Key represents a protocol 2 public key as defined in
+// [PROTOCOL.agent], section 2.5.2.
+type Key struct {
+ Format string
+ Blob []byte
+ Comment string
+}
+
+func clientErr(err error) error {
+ return fmt.Errorf("agent: client error: %v", err)
+}
+
+// String returns the storage form of an agent key with the format, base64
+// encoded serialized key, and the comment if it is not empty.
+func (k *Key) String() string {
+ s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob)
+
+ if k.Comment != "" {
+ s += " " + k.Comment
+ }
+
+ return s
+}
+
+// Type returns the public key type.
+func (k *Key) Type() string {
+ return k.Format
+}
+
+// Marshal returns key blob to satisfy the ssh.PublicKey interface.
+func (k *Key) Marshal() []byte {
+ return k.Blob
+}
+
+// Verify satisfies the ssh.PublicKey interface.
+func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
+ pubKey, err := ssh.ParsePublicKey(k.Blob)
+ if err != nil {
+ return fmt.Errorf("agent: bad public key: %v", err)
+ }
+ return pubKey.Verify(data, sig)
+}
+
+type wireKey struct {
+ Format string
+ Rest []byte `ssh:"rest"`
+}
+
+func parseKey(in []byte) (out *Key, rest []byte, err error) {
+ var record struct {
+ Blob []byte
+ Comment string
+ Rest []byte `ssh:"rest"`
+ }
+
+ if err := ssh.Unmarshal(in, &record); err != nil {
+ return nil, nil, err
+ }
+
+ var wk wireKey
+ if err := ssh.Unmarshal(record.Blob, &wk); err != nil {
+ return nil, nil, err
+ }
+
+ return &Key{
+ Format: wk.Format,
+ Blob: record.Blob,
+ Comment: record.Comment,
+ }, record.Rest, nil
+}
+
+// client is a client for an ssh-agent process.
+type client struct {
+ // conn is typically a *net.UnixConn
+ conn io.ReadWriter
+ // mu is used to prevent concurrent access to the agent
+ mu sync.Mutex
+}
+
+// NewClient returns an Agent that talks to an ssh-agent process over
+// the given connection.
+func NewClient(rw io.ReadWriter) Agent {
+ return &client{conn: rw}
+}
+
+// call sends an RPC to the agent. On success, the reply is
+// unmarshaled into reply and replyType is set to the first byte of
+// the reply, which contains the type of the message.
+func (c *client) call(req []byte) (reply interface{}, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ msg := make([]byte, 4+len(req))
+ binary.BigEndian.PutUint32(msg, uint32(len(req)))
+ copy(msg[4:], req)
+ if _, err = c.conn.Write(msg); err != nil {
+ return nil, clientErr(err)
+ }
+
+ var respSizeBuf [4]byte
+ if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil {
+ return nil, clientErr(err)
+ }
+ respSize := binary.BigEndian.Uint32(respSizeBuf[:])
+ if respSize > maxAgentResponseBytes {
+ return nil, clientErr(err)
+ }
+
+ buf := make([]byte, respSize)
+ if _, err = io.ReadFull(c.conn, buf); err != nil {
+ return nil, clientErr(err)
+ }
+ reply, err = unmarshal(buf)
+ if err != nil {
+ return nil, clientErr(err)
+ }
+ return reply, err
+}
+
+func (c *client) simpleCall(req []byte) error {
+ resp, err := c.call(req)
+ if err != nil {
+ return err
+ }
+ if _, ok := resp.(*successAgentMsg); ok {
+ return nil
+ }
+ return errors.New("agent: failure")
+}
+
+func (c *client) RemoveAll() error {
+ return c.simpleCall([]byte{agentRemoveAllIdentities})
+}
+
+func (c *client) Remove(key ssh.PublicKey) error {
+ req := ssh.Marshal(&agentRemoveIdentityMsg{
+ KeyBlob: key.Marshal(),
+ })
+ return c.simpleCall(req)
+}
+
+func (c *client) Lock(passphrase []byte) error {
+ req := ssh.Marshal(&agentLockMsg{
+ Passphrase: passphrase,
+ })
+ return c.simpleCall(req)
+}
+
+func (c *client) Unlock(passphrase []byte) error {
+ req := ssh.Marshal(&agentUnlockMsg{
+ Passphrase: passphrase,
+ })
+ return c.simpleCall(req)
+}
+
+// List returns the identities known to the agent.
+func (c *client) List() ([]*Key, error) {
+ // see [PROTOCOL.agent] section 2.5.2.
+ req := []byte{agentRequestIdentities}
+
+ msg, err := c.call(req)
+ if err != nil {
+ return nil, err
+ }
+
+ switch msg := msg.(type) {
+ case *identitiesAnswerAgentMsg:
+ if msg.NumKeys > maxAgentResponseBytes/8 {
+ return nil, errors.New("agent: too many keys in agent reply")
+ }
+ keys := make([]*Key, msg.NumKeys)
+ data := msg.Keys
+ for i := uint32(0); i < msg.NumKeys; i++ {
+ var key *Key
+ var err error
+ if key, data, err = parseKey(data); err != nil {
+ return nil, err
+ }
+ keys[i] = key
+ }
+ return keys, nil
+ case *failureAgentMsg:
+ return nil, errors.New("agent: failed to list keys")
+ }
+ panic("unreachable")
+}
+
+// Sign has the agent sign the data using a protocol 2 key as defined
+// in [PROTOCOL.agent] section 2.6.2.
+func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+ req := ssh.Marshal(signRequestAgentMsg{
+ KeyBlob: key.Marshal(),
+ Data: data,
+ })
+
+ msg, err := c.call(req)
+ if err != nil {
+ return nil, err
+ }
+
+ switch msg := msg.(type) {
+ case *signResponseAgentMsg:
+ var sig ssh.Signature
+ if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
+ return nil, err
+ }
+
+ return &sig, nil
+ case *failureAgentMsg:
+ return nil, errors.New("agent: failed to sign challenge")
+ }
+ panic("unreachable")
+}
+
+// unmarshal parses an agent message in packet, returning the parsed
+// form and the message type of packet.
+func unmarshal(packet []byte) (interface{}, error) {
+ if len(packet) < 1 {
+ return nil, errors.New("agent: empty packet")
+ }
+ var msg interface{}
+ switch packet[0] {
+ case agentFailure:
+ return new(failureAgentMsg), nil
+ case agentSuccess:
+ return new(successAgentMsg), nil
+ case agentIdentitiesAnswer:
+ msg = new(identitiesAnswerAgentMsg)
+ case agentSignResponse:
+ msg = new(signResponseAgentMsg)
+ case agentV1IdentitiesAnswer:
+ msg = new(agentV1IdentityMsg)
+ default:
+ return nil, fmt.Errorf("agent: unknown type tag %d", packet[0])
+ }
+ if err := ssh.Unmarshal(packet, msg); err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+type rsaKeyMsg struct {
+ Type string `sshtype:"17|25"`
+ N *big.Int
+ E *big.Int
+ D *big.Int
+ Iqmp *big.Int // IQMP = Inverse Q Mod P
+ P *big.Int
+ Q *big.Int
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+type dsaKeyMsg struct {
+ Type string `sshtype:"17|25"`
+ P *big.Int
+ Q *big.Int
+ G *big.Int
+ Y *big.Int
+ X *big.Int
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+type ecdsaKeyMsg struct {
+ Type string `sshtype:"17|25"`
+ Curve string
+ KeyBytes []byte
+ D *big.Int
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+type ed25519KeyMsg struct {
+ Type string `sshtype:"17|25"`
+ Pub []byte
+ Priv []byte
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+// Insert adds a private key to the agent.
+func (c *client) insertKey(s interface{}, comment string, constraints []byte) error {
+ var req []byte
+ switch k := s.(type) {
+ case *rsa.PrivateKey:
+ if len(k.Primes) != 2 {
+ return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
+ }
+ k.Precompute()
+ req = ssh.Marshal(rsaKeyMsg{
+ Type: ssh.KeyAlgoRSA,
+ N: k.N,
+ E: big.NewInt(int64(k.E)),
+ D: k.D,
+ Iqmp: k.Precomputed.Qinv,
+ P: k.Primes[0],
+ Q: k.Primes[1],
+ Comments: comment,
+ Constraints: constraints,
+ })
+ case *dsa.PrivateKey:
+ req = ssh.Marshal(dsaKeyMsg{
+ Type: ssh.KeyAlgoDSA,
+ P: k.P,
+ Q: k.Q,
+ G: k.G,
+ Y: k.Y,
+ X: k.X,
+ Comments: comment,
+ Constraints: constraints,
+ })
+ case *ecdsa.PrivateKey:
+ nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
+ req = ssh.Marshal(ecdsaKeyMsg{
+ Type: "ecdsa-sha2-" + nistID,
+ Curve: nistID,
+ KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
+ D: k.D,
+ Comments: comment,
+ Constraints: constraints,
+ })
+ case *ed25519.PrivateKey:
+ req = ssh.Marshal(ed25519KeyMsg{
+ Type: ssh.KeyAlgoED25519,
+ Pub: []byte(*k)[32:],
+ Priv: []byte(*k),
+ Comments: comment,
+ Constraints: constraints,
+ })
+ default:
+ return fmt.Errorf("agent: unsupported key type %T", s)
+ }
+
+ // if constraints are present then the message type needs to be changed.
+ if len(constraints) != 0 {
+ req[0] = agentAddIDConstrained
+ }
+
+ resp, err := c.call(req)
+ if err != nil {
+ return err
+ }
+ if _, ok := resp.(*successAgentMsg); ok {
+ return nil
+ }
+ return errors.New("agent: failure")
+}
+
+type rsaCertMsg struct {
+ Type string `sshtype:"17|25"`
+ CertBytes []byte
+ D *big.Int
+ Iqmp *big.Int // IQMP = Inverse Q Mod P
+ P *big.Int
+ Q *big.Int
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+type dsaCertMsg struct {
+ Type string `sshtype:"17|25"`
+ CertBytes []byte
+ X *big.Int
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+type ecdsaCertMsg struct {
+ Type string `sshtype:"17|25"`
+ CertBytes []byte
+ D *big.Int
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+type ed25519CertMsg struct {
+ Type string `sshtype:"17|25"`
+ CertBytes []byte
+ Pub []byte
+ Priv []byte
+ Comments string
+ Constraints []byte `ssh:"rest"`
+}
+
+// Add adds a private key to the agent. If a certificate is given,
+// that certificate is added instead as public key.
+func (c *client) Add(key AddedKey) error {
+ var constraints []byte
+
+ if secs := key.LifetimeSecs; secs != 0 {
+ constraints = append(constraints, ssh.Marshal(constrainLifetimeAgentMsg{secs})...)
+ }
+
+ if key.ConfirmBeforeUse {
+ constraints = append(constraints, agentConstrainConfirm)
+ }
+
+ cert := key.Certificate
+ if cert == nil {
+ return c.insertKey(key.PrivateKey, key.Comment, constraints)
+ }
+ return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
+}
+
+func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
+ var req []byte
+ switch k := s.(type) {
+ case *rsa.PrivateKey:
+ if len(k.Primes) != 2 {
+ return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
+ }
+ k.Precompute()
+ req = ssh.Marshal(rsaCertMsg{
+ Type: cert.Type(),
+ CertBytes: cert.Marshal(),
+ D: k.D,
+ Iqmp: k.Precomputed.Qinv,
+ P: k.Primes[0],
+ Q: k.Primes[1],
+ Comments: comment,
+ Constraints: constraints,
+ })
+ case *dsa.PrivateKey:
+ req = ssh.Marshal(dsaCertMsg{
+ Type: cert.Type(),
+ CertBytes: cert.Marshal(),
+ X: k.X,
+ Comments: comment,
+ Constraints: constraints,
+ })
+ case *ecdsa.PrivateKey:
+ req = ssh.Marshal(ecdsaCertMsg{
+ Type: cert.Type(),
+ CertBytes: cert.Marshal(),
+ D: k.D,
+ Comments: comment,
+ Constraints: constraints,
+ })
+ case *ed25519.PrivateKey:
+ req = ssh.Marshal(ed25519CertMsg{
+ Type: cert.Type(),
+ CertBytes: cert.Marshal(),
+ Pub: []byte(*k)[32:],
+ Priv: []byte(*k),
+ Comments: comment,
+ Constraints: constraints,
+ })
+ default:
+ return fmt.Errorf("agent: unsupported key type %T", s)
+ }
+
+ // if constraints are present then the message type needs to be changed.
+ if len(constraints) != 0 {
+ req[0] = agentAddIDConstrained
+ }
+
+ signer, err := ssh.NewSignerFromKey(s)
+ if err != nil {
+ return err
+ }
+ if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
+ return errors.New("agent: signer and cert have different public key")
+ }
+
+ resp, err := c.call(req)
+ if err != nil {
+ return err
+ }
+ if _, ok := resp.(*successAgentMsg); ok {
+ return nil
+ }
+ return errors.New("agent: failure")
+}
+
+// Signers provides a callback for client authentication.
+func (c *client) Signers() ([]ssh.Signer, error) {
+ keys, err := c.List()
+ if err != nil {
+ return nil, err
+ }
+
+ var result []ssh.Signer
+ for _, k := range keys {
+ result = append(result, &agentKeyringSigner{c, k})
+ }
+ return result, nil
+}
+
+type agentKeyringSigner struct {
+ agent *client
+ pub ssh.PublicKey
+}
+
+func (s *agentKeyringSigner) PublicKey() ssh.PublicKey {
+ return s.pub
+}
+
+func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
+ // The agent has its own entropy source, so the rand argument is ignored.
+ return s.agent.Sign(s.pub, data)
+}
--- /dev/null
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "errors"
+ "io"
+ "net"
+ "sync"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// RequestAgentForwarding sets up agent forwarding for the session.
+// ForwardToAgent or ForwardToRemote should be called to route
+// the authentication requests.
+func RequestAgentForwarding(session *ssh.Session) error {
+ ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
+ if err != nil {
+ return err
+ }
+ if !ok {
+ return errors.New("forwarding request denied")
+ }
+ return nil
+}
+
+// ForwardToAgent routes authentication requests to the given keyring.
+func ForwardToAgent(client *ssh.Client, keyring Agent) error {
+ channels := client.HandleChannelOpen(channelType)
+ if channels == nil {
+ return errors.New("agent: already have handler for " + channelType)
+ }
+
+ go func() {
+ for ch := range channels {
+ channel, reqs, err := ch.Accept()
+ if err != nil {
+ continue
+ }
+ go ssh.DiscardRequests(reqs)
+ go func() {
+ ServeAgent(keyring, channel)
+ channel.Close()
+ }()
+ }
+ }()
+ return nil
+}
+
+const channelType = "auth-agent@openssh.com"
+
+// ForwardToRemote routes authentication requests to the ssh-agent
+// process serving on the given unix socket.
+func ForwardToRemote(client *ssh.Client, addr string) error {
+ channels := client.HandleChannelOpen(channelType)
+ if channels == nil {
+ return errors.New("agent: already have handler for " + channelType)
+ }
+ conn, err := net.Dial("unix", addr)
+ if err != nil {
+ return err
+ }
+ conn.Close()
+
+ go func() {
+ for ch := range channels {
+ channel, reqs, err := ch.Accept()
+ if err != nil {
+ continue
+ }
+ go ssh.DiscardRequests(reqs)
+ go forwardUnixSocket(channel, addr)
+ }
+ }()
+ return nil
+}
+
+func forwardUnixSocket(channel ssh.Channel, addr string) {
+ conn, err := net.Dial("unix", addr)
+ if err != nil {
+ return
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ io.Copy(conn, channel)
+ conn.(*net.UnixConn).CloseWrite()
+ wg.Done()
+ }()
+ go func() {
+ io.Copy(channel, conn)
+ channel.CloseWrite()
+ wg.Done()
+ }()
+
+ wg.Wait()
+ conn.Close()
+ channel.Close()
+}
--- /dev/null
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/subtle"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+)
+
+type privKey struct {
+ signer ssh.Signer
+ comment string
+ expire *time.Time
+}
+
+type keyring struct {
+ mu sync.Mutex
+ keys []privKey
+
+ locked bool
+ passphrase []byte
+}
+
+var errLocked = errors.New("agent: locked")
+
+// NewKeyring returns an Agent that holds keys in memory. It is safe
+// for concurrent use by multiple goroutines.
+func NewKeyring() Agent {
+ return &keyring{}
+}
+
+// RemoveAll removes all identities.
+func (r *keyring) RemoveAll() error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ return errLocked
+ }
+
+ r.keys = nil
+ return nil
+}
+
+// removeLocked does the actual key removal. The caller must already be holding the
+// keyring mutex.
+func (r *keyring) removeLocked(want []byte) error {
+ found := false
+ for i := 0; i < len(r.keys); {
+ if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
+ found = true
+ r.keys[i] = r.keys[len(r.keys)-1]
+ r.keys = r.keys[:len(r.keys)-1]
+ continue
+ } else {
+ i++
+ }
+ }
+
+ if !found {
+ return errors.New("agent: key not found")
+ }
+ return nil
+}
+
+// Remove removes all identities with the given public key.
+func (r *keyring) Remove(key ssh.PublicKey) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ return errLocked
+ }
+
+ return r.removeLocked(key.Marshal())
+}
+
+// Lock locks the agent. Sign and Remove will fail, and List will return an empty list.
+func (r *keyring) Lock(passphrase []byte) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ return errLocked
+ }
+
+ r.locked = true
+ r.passphrase = passphrase
+ return nil
+}
+
+// Unlock undoes the effect of Lock
+func (r *keyring) Unlock(passphrase []byte) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if !r.locked {
+ return errors.New("agent: not locked")
+ }
+ if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
+ return fmt.Errorf("agent: incorrect passphrase")
+ }
+
+ r.locked = false
+ r.passphrase = nil
+ return nil
+}
+
+// expireKeysLocked removes expired keys from the keyring. If a key was added
+// with a lifetimesecs contraint and seconds >= lifetimesecs seconds have
+// ellapsed, it is removed. The caller *must* be holding the keyring mutex.
+func (r *keyring) expireKeysLocked() {
+ for _, k := range r.keys {
+ if k.expire != nil && time.Now().After(*k.expire) {
+ r.removeLocked(k.signer.PublicKey().Marshal())
+ }
+ }
+}
+
+// List returns the identities known to the agent.
+func (r *keyring) List() ([]*Key, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ // section 2.7: locked agents return empty.
+ return nil, nil
+ }
+
+ r.expireKeysLocked()
+ var ids []*Key
+ for _, k := range r.keys {
+ pub := k.signer.PublicKey()
+ ids = append(ids, &Key{
+ Format: pub.Type(),
+ Blob: pub.Marshal(),
+ Comment: k.comment})
+ }
+ return ids, nil
+}
+
+// Insert adds a private key to the keyring. If a certificate
+// is given, that certificate is added as public key. Note that
+// any constraints given are ignored.
+func (r *keyring) Add(key AddedKey) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ return errLocked
+ }
+ signer, err := ssh.NewSignerFromKey(key.PrivateKey)
+
+ if err != nil {
+ return err
+ }
+
+ if cert := key.Certificate; cert != nil {
+ signer, err = ssh.NewCertSigner(cert, signer)
+ if err != nil {
+ return err
+ }
+ }
+
+ p := privKey{
+ signer: signer,
+ comment: key.Comment,
+ }
+
+ if key.LifetimeSecs > 0 {
+ t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second)
+ p.expire = &t
+ }
+
+ r.keys = append(r.keys, p)
+
+ return nil
+}
+
+// Sign returns a signature for the data.
+func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ return nil, errLocked
+ }
+
+ r.expireKeysLocked()
+ wanted := key.Marshal()
+ for _, k := range r.keys {
+ if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
+ return k.signer.Sign(rand.Reader, data)
+ }
+ }
+ return nil, errors.New("not found")
+}
+
+// Signers returns signers for all the known keys.
+func (r *keyring) Signers() ([]ssh.Signer, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.locked {
+ return nil, errLocked
+ }
+
+ r.expireKeysLocked()
+ s := make([]ssh.Signer, 0, len(r.keys))
+ for _, k := range r.keys {
+ s = append(s, k.signer)
+ }
+ return s, nil
+}
--- /dev/null
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "crypto/dsa"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rsa"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "math/big"
+
+ "golang.org/x/crypto/ed25519"
+ "golang.org/x/crypto/ssh"
+)
+
+// Server wraps an Agent and uses it to implement the agent side of
+// the SSH-agent, wire protocol.
+type server struct {
+ agent Agent
+}
+
+func (s *server) processRequestBytes(reqData []byte) []byte {
+ rep, err := s.processRequest(reqData)
+ if err != nil {
+ if err != errLocked {
+ // TODO(hanwen): provide better logging interface?
+ log.Printf("agent %d: %v", reqData[0], err)
+ }
+ return []byte{agentFailure}
+ }
+
+ if err == nil && rep == nil {
+ return []byte{agentSuccess}
+ }
+
+ return ssh.Marshal(rep)
+}
+
+func marshalKey(k *Key) []byte {
+ var record struct {
+ Blob []byte
+ Comment string
+ }
+ record.Blob = k.Marshal()
+ record.Comment = k.Comment
+
+ return ssh.Marshal(&record)
+}
+
+// See [PROTOCOL.agent], section 2.5.1.
+const agentV1IdentitiesAnswer = 2
+
+type agentV1IdentityMsg struct {
+ Numkeys uint32 `sshtype:"2"`
+}
+
+type agentRemoveIdentityMsg struct {
+ KeyBlob []byte `sshtype:"18"`
+}
+
+type agentLockMsg struct {
+ Passphrase []byte `sshtype:"22"`
+}
+
+type agentUnlockMsg struct {
+ Passphrase []byte `sshtype:"23"`
+}
+
+func (s *server) processRequest(data []byte) (interface{}, error) {
+ switch data[0] {
+ case agentRequestV1Identities:
+ return &agentV1IdentityMsg{0}, nil
+
+ case agentRemoveAllV1Identities:
+ return nil, nil
+
+ case agentRemoveIdentity:
+ var req agentRemoveIdentityMsg
+ if err := ssh.Unmarshal(data, &req); err != nil {
+ return nil, err
+ }
+
+ var wk wireKey
+ if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
+ return nil, err
+ }
+
+ return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
+
+ case agentRemoveAllIdentities:
+ return nil, s.agent.RemoveAll()
+
+ case agentLock:
+ var req agentLockMsg
+ if err := ssh.Unmarshal(data, &req); err != nil {
+ return nil, err
+ }
+
+ return nil, s.agent.Lock(req.Passphrase)
+
+ case agentUnlock:
+ var req agentUnlockMsg
+ if err := ssh.Unmarshal(data, &req); err != nil {
+ return nil, err
+ }
+ return nil, s.agent.Unlock(req.Passphrase)
+
+ case agentSignRequest:
+ var req signRequestAgentMsg
+ if err := ssh.Unmarshal(data, &req); err != nil {
+ return nil, err
+ }
+
+ var wk wireKey
+ if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
+ return nil, err
+ }
+
+ k := &Key{
+ Format: wk.Format,
+ Blob: req.KeyBlob,
+ }
+
+ sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
+ if err != nil {
+ return nil, err
+ }
+ return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
+
+ case agentRequestIdentities:
+ keys, err := s.agent.List()
+ if err != nil {
+ return nil, err
+ }
+
+ rep := identitiesAnswerAgentMsg{
+ NumKeys: uint32(len(keys)),
+ }
+ for _, k := range keys {
+ rep.Keys = append(rep.Keys, marshalKey(k)...)
+ }
+ return rep, nil
+
+ case agentAddIDConstrained, agentAddIdentity:
+ return nil, s.insertIdentity(data)
+ }
+
+ return nil, fmt.Errorf("unknown opcode %d", data[0])
+}
+
+func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
+ for len(constraints) != 0 {
+ switch constraints[0] {
+ case agentConstrainLifetime:
+ lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
+ constraints = constraints[5:]
+ case agentConstrainConfirm:
+ confirmBeforeUse = true
+ constraints = constraints[1:]
+ case agentConstrainExtension:
+ var msg constrainExtensionAgentMsg
+ if err = ssh.Unmarshal(constraints, &msg); err != nil {
+ return 0, false, nil, err
+ }
+ extensions = append(extensions, ConstraintExtension{
+ ExtensionName: msg.ExtensionName,
+ ExtensionDetails: msg.ExtensionDetails,
+ })
+ constraints = msg.Rest
+ default:
+ return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
+ }
+ }
+ return
+}
+
+func setConstraints(key *AddedKey, constraintBytes []byte) error {
+ lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
+ if err != nil {
+ return err
+ }
+
+ key.LifetimeSecs = lifetimeSecs
+ key.ConfirmBeforeUse = confirmBeforeUse
+ key.ConstraintExtensions = constraintExtensions
+ return nil
+}
+
+func parseRSAKey(req []byte) (*AddedKey, error) {
+ var k rsaKeyMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+ if k.E.BitLen() > 30 {
+ return nil, errors.New("agent: RSA public exponent too large")
+ }
+ priv := &rsa.PrivateKey{
+ PublicKey: rsa.PublicKey{
+ E: int(k.E.Int64()),
+ N: k.N,
+ },
+ D: k.D,
+ Primes: []*big.Int{k.P, k.Q},
+ }
+ priv.Precompute()
+
+ addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func parseEd25519Key(req []byte) (*AddedKey, error) {
+ var k ed25519KeyMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+ priv := ed25519.PrivateKey(k.Priv)
+
+ addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func parseDSAKey(req []byte) (*AddedKey, error) {
+ var k dsaKeyMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+ priv := &dsa.PrivateKey{
+ PublicKey: dsa.PublicKey{
+ Parameters: dsa.Parameters{
+ P: k.P,
+ Q: k.Q,
+ G: k.G,
+ },
+ Y: k.Y,
+ },
+ X: k.X,
+ }
+
+ addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
+ priv = &ecdsa.PrivateKey{
+ D: privScalar,
+ }
+
+ switch curveName {
+ case "nistp256":
+ priv.Curve = elliptic.P256()
+ case "nistp384":
+ priv.Curve = elliptic.P384()
+ case "nistp521":
+ priv.Curve = elliptic.P521()
+ default:
+ return nil, fmt.Errorf("agent: unknown curve %q", curveName)
+ }
+
+ priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes)
+ if priv.X == nil || priv.Y == nil {
+ return nil, errors.New("agent: point not on curve")
+ }
+
+ return priv, nil
+}
+
+func parseEd25519Cert(req []byte) (*AddedKey, error) {
+ var k ed25519CertMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+ pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+ if err != nil {
+ return nil, err
+ }
+ priv := ed25519.PrivateKey(k.Priv)
+ cert, ok := pubKey.(*ssh.Certificate)
+ if !ok {
+ return nil, errors.New("agent: bad ED25519 certificate")
+ }
+
+ addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func parseECDSAKey(req []byte) (*AddedKey, error) {
+ var k ecdsaKeyMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+
+ priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
+ if err != nil {
+ return nil, err
+ }
+
+ addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func parseRSACert(req []byte) (*AddedKey, error) {
+ var k rsaCertMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+
+ pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ cert, ok := pubKey.(*ssh.Certificate)
+ if !ok {
+ return nil, errors.New("agent: bad RSA certificate")
+ }
+
+ // An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
+ var rsaPub struct {
+ Name string
+ E *big.Int
+ N *big.Int
+ }
+ if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
+ return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
+ }
+
+ if rsaPub.E.BitLen() > 30 {
+ return nil, errors.New("agent: RSA public exponent too large")
+ }
+
+ priv := rsa.PrivateKey{
+ PublicKey: rsa.PublicKey{
+ E: int(rsaPub.E.Int64()),
+ N: rsaPub.N,
+ },
+ D: k.D,
+ Primes: []*big.Int{k.Q, k.P},
+ }
+ priv.Precompute()
+
+ addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func parseDSACert(req []byte) (*AddedKey, error) {
+ var k dsaCertMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+ pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+ if err != nil {
+ return nil, err
+ }
+ cert, ok := pubKey.(*ssh.Certificate)
+ if !ok {
+ return nil, errors.New("agent: bad DSA certificate")
+ }
+
+ // A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
+ var w struct {
+ Name string
+ P, Q, G, Y *big.Int
+ }
+ if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
+ return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
+ }
+
+ priv := &dsa.PrivateKey{
+ PublicKey: dsa.PublicKey{
+ Parameters: dsa.Parameters{
+ P: w.P,
+ Q: w.Q,
+ G: w.G,
+ },
+ Y: w.Y,
+ },
+ X: k.X,
+ }
+
+ addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func parseECDSACert(req []byte) (*AddedKey, error) {
+ var k ecdsaCertMsg
+ if err := ssh.Unmarshal(req, &k); err != nil {
+ return nil, err
+ }
+
+ pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+ if err != nil {
+ return nil, err
+ }
+ cert, ok := pubKey.(*ssh.Certificate)
+ if !ok {
+ return nil, errors.New("agent: bad ECDSA certificate")
+ }
+
+ // An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
+ var ecdsaPub struct {
+ Name string
+ ID string
+ Key []byte
+ }
+ if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
+ return nil, err
+ }
+
+ priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
+ if err != nil {
+ return nil, err
+ }
+
+ addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
+}
+
+func (s *server) insertIdentity(req []byte) error {
+ var record struct {
+ Type string `sshtype:"17|25"`
+ Rest []byte `ssh:"rest"`
+ }
+
+ if err := ssh.Unmarshal(req, &record); err != nil {
+ return err
+ }
+
+ var addedKey *AddedKey
+ var err error
+
+ switch record.Type {
+ case ssh.KeyAlgoRSA:
+ addedKey, err = parseRSAKey(req)
+ case ssh.KeyAlgoDSA:
+ addedKey, err = parseDSAKey(req)
+ case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
+ addedKey, err = parseECDSAKey(req)
+ case ssh.KeyAlgoED25519:
+ addedKey, err = parseEd25519Key(req)
+ case ssh.CertAlgoRSAv01:
+ addedKey, err = parseRSACert(req)
+ case ssh.CertAlgoDSAv01:
+ addedKey, err = parseDSACert(req)
+ case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
+ addedKey, err = parseECDSACert(req)
+ case ssh.CertAlgoED25519v01:
+ addedKey, err = parseEd25519Cert(req)
+ default:
+ return fmt.Errorf("agent: not implemented: %q", record.Type)
+ }
+
+ if err != nil {
+ return err
+ }
+ return s.agent.Add(*addedKey)
+}
+
+// ServeAgent serves the agent protocol on the given connection. It
+// returns when an I/O error occurs.
+func ServeAgent(agent Agent, c io.ReadWriter) error {
+ s := &server{agent}
+
+ var length [4]byte
+ for {
+ if _, err := io.ReadFull(c, length[:]); err != nil {
+ return err
+ }
+ l := binary.BigEndian.Uint32(length[:])
+ if l > maxAgentResponseBytes {
+ // We also cap requests.
+ return fmt.Errorf("agent: request too large: %d", l)
+ }
+
+ req := make([]byte, l)
+ if _, err := io.ReadFull(c, req); err != nil {
+ return err
+ }
+
+ repData := s.processRequestBytes(req)
+ if len(repData) > maxAgentResponseBytes {
+ return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
+ }
+
+ binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
+ if _, err := c.Write(length[:]); err != nil {
+ return err
+ }
+ if _, err := c.Write(repData); err != nil {
+ return err
+ }
+ }
+}
--- /dev/null
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package knownhosts implements a parser for the OpenSSH known_hosts
+// host key database, and provides utility functions for writing
+// OpenSSH compliant known_hosts files.
+package knownhosts
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha1"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strings"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// See the sshd manpage
+// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
+// background.
+
+type addr struct{ host, port string }
+
+func (a *addr) String() string {
+ h := a.host
+ if strings.Contains(h, ":") {
+ h = "[" + h + "]"
+ }
+ return h + ":" + a.port
+}
+
+type matcher interface {
+ match(addr) bool
+}
+
+type hostPattern struct {
+ negate bool
+ addr addr
+}
+
+func (p *hostPattern) String() string {
+ n := ""
+ if p.negate {
+ n = "!"
+ }
+
+ return n + p.addr.String()
+}
+
+type hostPatterns []hostPattern
+
+func (ps hostPatterns) match(a addr) bool {
+ matched := false
+ for _, p := range ps {
+ if !p.match(a) {
+ continue
+ }
+ if p.negate {
+ return false
+ }
+ matched = true
+ }
+ return matched
+}
+
+// See
+// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
+// The matching of * has no regard for separators, unlike filesystem globs
+func wildcardMatch(pat []byte, str []byte) bool {
+ for {
+ if len(pat) == 0 {
+ return len(str) == 0
+ }
+ if len(str) == 0 {
+ return false
+ }
+
+ if pat[0] == '*' {
+ if len(pat) == 1 {
+ return true
+ }
+
+ for j := range str {
+ if wildcardMatch(pat[1:], str[j:]) {
+ return true
+ }
+ }
+ return false
+ }
+
+ if pat[0] == '?' || pat[0] == str[0] {
+ pat = pat[1:]
+ str = str[1:]
+ } else {
+ return false
+ }
+ }
+}
+
+func (p *hostPattern) match(a addr) bool {
+ return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
+}
+
+type keyDBLine struct {
+ cert bool
+ matcher matcher
+ knownKey KnownKey
+}
+
+func serialize(k ssh.PublicKey) string {
+ return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
+}
+
+func (l *keyDBLine) match(a addr) bool {
+ return l.matcher.match(a)
+}
+
+type hostKeyDB struct {
+ // Serialized version of revoked keys
+ revoked map[string]*KnownKey
+ lines []keyDBLine
+}
+
+func newHostKeyDB() *hostKeyDB {
+ db := &hostKeyDB{
+ revoked: make(map[string]*KnownKey),
+ }
+
+ return db
+}
+
+func keyEq(a, b ssh.PublicKey) bool {
+ return bytes.Equal(a.Marshal(), b.Marshal())
+}
+
+// IsAuthorityForHost can be used as a callback in ssh.CertChecker
+func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
+ h, p, err := net.SplitHostPort(address)
+ if err != nil {
+ return false
+ }
+ a := addr{host: h, port: p}
+
+ for _, l := range db.lines {
+ if l.cert && keyEq(l.knownKey.Key, remote) && l.match(a) {
+ return true
+ }
+ }
+ return false
+}
+
+// IsRevoked can be used as a callback in ssh.CertChecker
+func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
+ _, ok := db.revoked[string(key.Marshal())]
+ return ok
+}
+
+const markerCert = "@cert-authority"
+const markerRevoked = "@revoked"
+
+func nextWord(line []byte) (string, []byte) {
+ i := bytes.IndexAny(line, "\t ")
+ if i == -1 {
+ return string(line), nil
+ }
+
+ return string(line[:i]), bytes.TrimSpace(line[i:])
+}
+
+func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
+ if w, next := nextWord(line); w == markerCert || w == markerRevoked {
+ marker = w
+ line = next
+ }
+
+ host, line = nextWord(line)
+ if len(line) == 0 {
+ return "", "", nil, errors.New("knownhosts: missing host pattern")
+ }
+
+ // ignore the keytype as it's in the key blob anyway.
+ _, line = nextWord(line)
+ if len(line) == 0 {
+ return "", "", nil, errors.New("knownhosts: missing key type pattern")
+ }
+
+ keyBlob, _ := nextWord(line)
+
+ keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
+ if err != nil {
+ return "", "", nil, err
+ }
+ key, err = ssh.ParsePublicKey(keyBytes)
+ if err != nil {
+ return "", "", nil, err
+ }
+
+ return marker, host, key, nil
+}
+
+func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
+ marker, pattern, key, err := parseLine(line)
+ if err != nil {
+ return err
+ }
+
+ if marker == markerRevoked {
+ db.revoked[string(key.Marshal())] = &KnownKey{
+ Key: key,
+ Filename: filename,
+ Line: linenum,
+ }
+
+ return nil
+ }
+
+ entry := keyDBLine{
+ cert: marker == markerCert,
+ knownKey: KnownKey{
+ Filename: filename,
+ Line: linenum,
+ Key: key,
+ },
+ }
+
+ if pattern[0] == '|' {
+ entry.matcher, err = newHashedHost(pattern)
+ } else {
+ entry.matcher, err = newHostnameMatcher(pattern)
+ }
+
+ if err != nil {
+ return err
+ }
+
+ db.lines = append(db.lines, entry)
+ return nil
+}
+
+func newHostnameMatcher(pattern string) (matcher, error) {
+ var hps hostPatterns
+ for _, p := range strings.Split(pattern, ",") {
+ if len(p) == 0 {
+ continue
+ }
+
+ var a addr
+ var negate bool
+ if p[0] == '!' {
+ negate = true
+ p = p[1:]
+ }
+
+ if len(p) == 0 {
+ return nil, errors.New("knownhosts: negation without following hostname")
+ }
+
+ var err error
+ if p[0] == '[' {
+ a.host, a.port, err = net.SplitHostPort(p)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ a.host, a.port, err = net.SplitHostPort(p)
+ if err != nil {
+ a.host = p
+ a.port = "22"
+ }
+ }
+ hps = append(hps, hostPattern{
+ negate: negate,
+ addr: a,
+ })
+ }
+ return hps, nil
+}
+
+// KnownKey represents a key declared in a known_hosts file.
+type KnownKey struct {
+ Key ssh.PublicKey
+ Filename string
+ Line int
+}
+
+func (k *KnownKey) String() string {
+ return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key))
+}
+
+// KeyError is returned if we did not find the key in the host key
+// database, or there was a mismatch. Typically, in batch
+// applications, this should be interpreted as failure. Interactive
+// applications can offer an interactive prompt to the user.
+type KeyError struct {
+ // Want holds the accepted host keys. For each key algorithm,
+ // there can be one hostkey. If Want is empty, the host is
+ // unknown. If Want is non-empty, there was a mismatch, which
+ // can signify a MITM attack.
+ Want []KnownKey
+}
+
+func (u *KeyError) Error() string {
+ if len(u.Want) == 0 {
+ return "knownhosts: key is unknown"
+ }
+ return "knownhosts: key mismatch"
+}
+
+// RevokedError is returned if we found a key that was revoked.
+type RevokedError struct {
+ Revoked KnownKey
+}
+
+func (r *RevokedError) Error() string {
+ return "knownhosts: key is revoked"
+}
+
+// check checks a key against the host database. This should not be
+// used for verifying certificates.
+func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
+ if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
+ return &RevokedError{Revoked: *revoked}
+ }
+
+ host, port, err := net.SplitHostPort(remote.String())
+ if err != nil {
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
+ }
+
+ hostToCheck := addr{host, port}
+ if address != "" {
+ // Give preference to the hostname if available.
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
+ }
+
+ hostToCheck = addr{host, port}
+ }
+
+ return db.checkAddr(hostToCheck, remoteKey)
+}
+
+// checkAddrs checks if we can find the given public key for any of
+// the given addresses. If we only find an entry for the IP address,
+// or only the hostname, then this still succeeds.
+func (db *hostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error {
+ // TODO(hanwen): are these the right semantics? What if there
+ // is just a key for the IP address, but not for the
+ // hostname?
+
+ // Algorithm => key.
+ knownKeys := map[string]KnownKey{}
+ for _, l := range db.lines {
+ if l.match(a) {
+ typ := l.knownKey.Key.Type()
+ if _, ok := knownKeys[typ]; !ok {
+ knownKeys[typ] = l.knownKey
+ }
+ }
+ }
+
+ keyErr := &KeyError{}
+ for _, v := range knownKeys {
+ keyErr.Want = append(keyErr.Want, v)
+ }
+
+ // Unknown remote host.
+ if len(knownKeys) == 0 {
+ return keyErr
+ }
+
+ // If the remote host starts using a different, unknown key type, we
+ // also interpret that as a mismatch.
+ if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) {
+ return keyErr
+ }
+
+ return nil
+}
+
+// The Read function parses file contents.
+func (db *hostKeyDB) Read(r io.Reader, filename string) error {
+ scanner := bufio.NewScanner(r)
+
+ lineNum := 0
+ for scanner.Scan() {
+ lineNum++
+ line := scanner.Bytes()
+ line = bytes.TrimSpace(line)
+ if len(line) == 0 || line[0] == '#' {
+ continue
+ }
+
+ if err := db.parseLine(line, filename, lineNum); err != nil {
+ return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
+ }
+ }
+ return scanner.Err()
+}
+
+// New creates a host key callback from the given OpenSSH host key
+// files. The returned callback is for use in
+// ssh.ClientConfig.HostKeyCallback. By preference, the key check
+// operates on the hostname if available, i.e. if a server changes its
+// IP address, the host key check will still succeed, even though a
+// record of the new IP address is not available.
+func New(files ...string) (ssh.HostKeyCallback, error) {
+ db := newHostKeyDB()
+ for _, fn := range files {
+ f, err := os.Open(fn)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ if err := db.Read(f, fn); err != nil {
+ return nil, err
+ }
+ }
+
+ var certChecker ssh.CertChecker
+ certChecker.IsHostAuthority = db.IsHostAuthority
+ certChecker.IsRevoked = db.IsRevoked
+ certChecker.HostKeyFallback = db.check
+
+ return certChecker.CheckHostKey, nil
+}
+
+// Normalize normalizes an address into the form used in known_hosts
+func Normalize(address string) string {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ host = address
+ port = "22"
+ }
+ entry := host
+ if port != "22" {
+ entry = "[" + entry + "]:" + port
+ } else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
+ entry = "[" + entry + "]"
+ }
+ return entry
+}
+
+// Line returns a line to add append to the known_hosts files.
+func Line(addresses []string, key ssh.PublicKey) string {
+ var trimmed []string
+ for _, a := range addresses {
+ trimmed = append(trimmed, Normalize(a))
+ }
+
+ return strings.Join(trimmed, ",") + " " + serialize(key)
+}
+
+// HashHostname hashes the given hostname. The hostname is not
+// normalized before hashing.
+func HashHostname(hostname string) string {
+ // TODO(hanwen): check if we can safely normalize this always.
+ salt := make([]byte, sha1.Size)
+
+ _, err := rand.Read(salt)
+ if err != nil {
+ panic(fmt.Sprintf("crypto/rand failure %v", err))
+ }
+
+ hash := hashHost(hostname, salt)
+ return encodeHash(sha1HashType, salt, hash)
+}
+
+func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
+ if len(encoded) == 0 || encoded[0] != '|' {
+ err = errors.New("knownhosts: hashed host must start with '|'")
+ return
+ }
+ components := strings.Split(encoded, "|")
+ if len(components) != 4 {
+ err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
+ return
+ }
+
+ hashType = components[1]
+ if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
+ return
+ }
+ if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
+ return
+ }
+ return
+}
+
+func encodeHash(typ string, salt []byte, hash []byte) string {
+ return strings.Join([]string{"",
+ typ,
+ base64.StdEncoding.EncodeToString(salt),
+ base64.StdEncoding.EncodeToString(hash),
+ }, "|")
+}
+
+// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+func hashHost(hostname string, salt []byte) []byte {
+ mac := hmac.New(sha1.New, salt)
+ mac.Write([]byte(hostname))
+ return mac.Sum(nil)
+}
+
+type hashedHost struct {
+ salt []byte
+ hash []byte
+}
+
+const sha1HashType = "1"
+
+func newHashedHost(encoded string) (*hashedHost, error) {
+ typ, salt, hash, err := decodeHash(encoded)
+ if err != nil {
+ return nil, err
+ }
+
+ // The type field seems for future algorithm agility, but it's
+ // actually hardcoded in openssh currently, see
+ // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+ if typ != sha1HashType {
+ return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
+ }
+
+ return &hashedHost{salt: salt, hash: hash}, nil
+}
+
+func (h *hashedHost) match(a addr) bool {
+ return bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash)
+}
--- /dev/null
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2017 Sourced Technologies S.L.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
--- /dev/null
+package billy
+
+import (
+ "errors"
+ "io"
+ "os"
+ "time"
+)
+
+var (
+ ErrReadOnly = errors.New("read-only filesystem")
+ ErrNotSupported = errors.New("feature not supported")
+ ErrCrossedBoundary = errors.New("chroot boundary crossed")
+)
+
+// Capability holds the supported features of a billy filesystem. This does
+// not mean that the capability has to be supported by the underlying storage.
+// For example, a billy filesystem may support WriteCapability but the
+// storage be mounted in read only mode.
+type Capability uint64
+
+const (
+ // WriteCapability means that the fs is writable.
+ WriteCapability Capability = 1 << iota
+ // ReadCapability means that the fs is readable.
+ ReadCapability
+ // ReadAndWriteCapability is the ability to open a file in read and write mode.
+ ReadAndWriteCapability
+ // SeekCapability means it is able to move position inside the file.
+ SeekCapability
+ // TruncateCapability means that a file can be truncated.
+ TruncateCapability
+ // LockCapability is the ability to lock a file.
+ LockCapability
+
+ // DefaultCapabilities lists all capable features supported by filesystems
+ // without Capability interface. This list should not be changed until a
+ // major version is released.
+ DefaultCapabilities Capability = WriteCapability | ReadCapability |
+ ReadAndWriteCapability | SeekCapability | TruncateCapability |
+ LockCapability
+
+ // AllCapabilities lists all capable features.
+ AllCapabilities Capability = WriteCapability | ReadCapability |
+ ReadAndWriteCapability | SeekCapability | TruncateCapability |
+ LockCapability
+)
+
+// Filesystem abstract the operations in a storage-agnostic interface.
+// Each method implementation mimics the behavior of the equivalent functions
+// at the os package from the standard library.
+type Filesystem interface {
+ Basic
+ TempFile
+ Dir
+ Symlink
+ Chroot
+}
+
+// Basic abstract the basic operations in a storage-agnostic interface as
+// an extension to the Basic interface.
+type Basic interface {
+ // Create creates the named file with mode 0666 (before umask), truncating
+ // it if it already exists. If successful, methods on the returned File can
+ // be used for I/O; the associated file descriptor has mode O_RDWR.
+ Create(filename string) (File, error)
+ // Open opens the named file for reading. If successful, methods on the
+ // returned file can be used for reading; the associated file descriptor has
+ // mode O_RDONLY.
+ Open(filename string) (File, error)
+ // OpenFile is the generalized open call; most users will use Open or Create
+ // instead. It opens the named file with specified flag (O_RDONLY etc.) and
+ // perm, (0666 etc.) if applicable. If successful, methods on the returned
+ // File can be used for I/O.
+ OpenFile(filename string, flag int, perm os.FileMode) (File, error)
+ // Stat returns a FileInfo describing the named file.
+ Stat(filename string) (os.FileInfo, error)
+ // Rename renames (moves) oldpath to newpath. If newpath already exists and
+ // is not a directory, Rename replaces it. OS-specific restrictions may
+ // apply when oldpath and newpath are in different directories.
+ Rename(oldpath, newpath string) error
+ // Remove removes the named file or directory.
+ Remove(filename string) error
+ // Join joins any number of path elements into a single path, adding a
+ // Separator if necessary. Join calls filepath.Clean on the result; in
+ // particular, all empty strings are ignored. On Windows, the result is a
+ // UNC path if and only if the first path element is a UNC path.
+ Join(elem ...string) string
+}
+
+type TempFile interface {
+ // TempFile creates a new temporary file in the directory dir with a name
+ // beginning with prefix, opens the file for reading and writing, and
+ // returns the resulting *os.File. If dir is the empty string, TempFile
+ // uses the default directory for temporary files (see os.TempDir).
+ // Multiple programs calling TempFile simultaneously will not choose the
+ // same file. The caller can use f.Name() to find the pathname of the file.
+ // It is the caller's responsibility to remove the file when no longer
+ // needed.
+ TempFile(dir, prefix string) (File, error)
+}
+
+// Dir abstract the dir related operations in a storage-agnostic interface as
+// an extension to the Basic interface.
+type Dir interface {
+ // ReadDir reads the directory named by dirname and returns a list of
+ // directory entries sorted by filename.
+ ReadDir(path string) ([]os.FileInfo, error)
+ // MkdirAll creates a directory named path, along with any necessary
+ // parents, and returns nil, or else returns an error. The permission bits
+ // perm are used for all directories that MkdirAll creates. If path is/
+ // already a directory, MkdirAll does nothing and returns nil.
+ MkdirAll(filename string, perm os.FileMode) error
+}
+
+// Symlink abstract the symlink related operations in a storage-agnostic
+// interface as an extension to the Basic interface.
+type Symlink interface {
+ // Lstat returns a FileInfo describing the named file. If the file is a
+ // symbolic link, the returned FileInfo describes the symbolic link. Lstat
+ // makes no attempt to follow the link.
+ Lstat(filename string) (os.FileInfo, error)
+ // Symlink creates a symbolic-link from link to target. target may be an
+ // absolute or relative path, and need not refer to an existing node.
+ // Parent directories of link are created as necessary.
+ Symlink(target, link string) error
+ // Readlink returns the target path of link.
+ Readlink(link string) (string, error)
+}
+
+// Change abstract the FileInfo change related operations in a storage-agnostic
+// interface as an extension to the Basic interface
+type Change interface {
+ // Chmod changes the mode of the named file to mode. If the file is a
+ // symbolic link, it changes the mode of the link's target.
+ Chmod(name string, mode os.FileMode) error
+ // Lchown changes the numeric uid and gid of the named file. If the file is
+ // a symbolic link, it changes the uid and gid of the link itself.
+ Lchown(name string, uid, gid int) error
+ // Chown changes the numeric uid and gid of the named file. If the file is a
+ // symbolic link, it changes the uid and gid of the link's target.
+ Chown(name string, uid, gid int) error
+ // Chtimes changes the access and modification times of the named file,
+ // similar to the Unix utime() or utimes() functions.
+ //
+ // The underlying filesystem may truncate or round the values to a less
+ // precise time unit.
+ Chtimes(name string, atime time.Time, mtime time.Time) error
+}
+
+// Chroot abstract the chroot related operations in a storage-agnostic interface
+// as an extension to the Basic interface.
+type Chroot interface {
+ // Chroot returns a new filesystem from the same type where the new root is
+ // the given path. Files outside of the designated directory tree cannot be
+ // accessed.
+ Chroot(path string) (Filesystem, error)
+ // Root returns the root path of the filesystem.
+ Root() string
+}
+
+// File represent a file, being a subset of the os.File
+type File interface {
+ // Name returns the name of the file as presented to Open.
+ Name() string
+ io.Writer
+ io.Reader
+ io.ReaderAt
+ io.Seeker
+ io.Closer
+ // Lock locks the file like e.g. flock. It protects against access from
+ // other processes.
+ Lock() error
+ // Unlock unlocks the file.
+ Unlock() error
+ // Truncate the file.
+ Truncate(size int64) error
+}
+
+// Capable interface can return the available features of a filesystem.
+type Capable interface {
+ // Capabilities returns the capabilities of a filesystem in bit flags.
+ Capabilities() Capability
+}
+
+// Capabilities returns the features supported by a filesystem. If the FS
+// does not implement Capable interface it returns all features.
+func Capabilities(fs Basic) Capability {
+ capable, ok := fs.(Capable)
+ if !ok {
+ return DefaultCapabilities
+ }
+
+ return capable.Capabilities()
+}
+
+// CapabilityCheck tests the filesystem for the provided capabilities and
+// returns true in case it supports all of them.
+func CapabilityCheck(fs Basic, capabilities Capability) bool {
+ fsCaps := Capabilities(fs)
+ return fsCaps&capabilities == capabilities
+}
--- /dev/null
+package chroot
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-billy.v4/helper/polyfill"
+)
+
+// ChrootHelper is a helper to implement billy.Chroot.
+type ChrootHelper struct {
+ underlying billy.Filesystem
+ base string
+}
+
+// New creates a new filesystem wrapping up the given 'fs'.
+// The created filesystem has its base in the given ChrootHelperectory of the
+// underlying filesystem.
+func New(fs billy.Basic, base string) billy.Filesystem {
+ return &ChrootHelper{
+ underlying: polyfill.New(fs),
+ base: base,
+ }
+}
+
+func (fs *ChrootHelper) underlyingPath(filename string) (string, error) {
+ if isCrossBoundaries(filename) {
+ return "", billy.ErrCrossedBoundary
+ }
+
+ return fs.Join(fs.Root(), filename), nil
+}
+
+func isCrossBoundaries(path string) bool {
+ path = filepath.ToSlash(path)
+ path = filepath.Clean(path)
+
+ return strings.HasPrefix(path, ".."+string(filepath.Separator))
+}
+
+func (fs *ChrootHelper) Create(filename string) (billy.File, error) {
+ fullpath, err := fs.underlyingPath(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ f, err := fs.underlying.Create(fullpath)
+ if err != nil {
+ return nil, err
+ }
+
+ return newFile(fs, f, filename), nil
+}
+
+func (fs *ChrootHelper) Open(filename string) (billy.File, error) {
+ fullpath, err := fs.underlyingPath(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ f, err := fs.underlying.Open(fullpath)
+ if err != nil {
+ return nil, err
+ }
+
+ return newFile(fs, f, filename), nil
+}
+
+func (fs *ChrootHelper) OpenFile(filename string, flag int, mode os.FileMode) (billy.File, error) {
+ fullpath, err := fs.underlyingPath(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ f, err := fs.underlying.OpenFile(fullpath, flag, mode)
+ if err != nil {
+ return nil, err
+ }
+
+ return newFile(fs, f, filename), nil
+}
+
+func (fs *ChrootHelper) Stat(filename string) (os.FileInfo, error) {
+ fullpath, err := fs.underlyingPath(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ return fs.underlying.Stat(fullpath)
+}
+
+func (fs *ChrootHelper) Rename(from, to string) error {
+ var err error
+ from, err = fs.underlyingPath(from)
+ if err != nil {
+ return err
+ }
+
+ to, err = fs.underlyingPath(to)
+ if err != nil {
+ return err
+ }
+
+ return fs.underlying.Rename(from, to)
+}
+
+func (fs *ChrootHelper) Remove(path string) error {
+ fullpath, err := fs.underlyingPath(path)
+ if err != nil {
+ return err
+ }
+
+ return fs.underlying.Remove(fullpath)
+}
+
+func (fs *ChrootHelper) Join(elem ...string) string {
+ return fs.underlying.Join(elem...)
+}
+
+func (fs *ChrootHelper) TempFile(dir, prefix string) (billy.File, error) {
+ fullpath, err := fs.underlyingPath(dir)
+ if err != nil {
+ return nil, err
+ }
+
+ f, err := fs.underlying.(billy.TempFile).TempFile(fullpath, prefix)
+ if err != nil {
+ return nil, err
+ }
+
+ return newFile(fs, f, fs.Join(dir, filepath.Base(f.Name()))), nil
+}
+
+func (fs *ChrootHelper) ReadDir(path string) ([]os.FileInfo, error) {
+ fullpath, err := fs.underlyingPath(path)
+ if err != nil {
+ return nil, err
+ }
+
+ return fs.underlying.(billy.Dir).ReadDir(fullpath)
+}
+
+func (fs *ChrootHelper) MkdirAll(filename string, perm os.FileMode) error {
+ fullpath, err := fs.underlyingPath(filename)
+ if err != nil {
+ return err
+ }
+
+ return fs.underlying.(billy.Dir).MkdirAll(fullpath, perm)
+}
+
+func (fs *ChrootHelper) Lstat(filename string) (os.FileInfo, error) {
+ fullpath, err := fs.underlyingPath(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ return fs.underlying.(billy.Symlink).Lstat(fullpath)
+}
+
+func (fs *ChrootHelper) Symlink(target, link string) error {
+ target = filepath.FromSlash(target)
+
+ // only rewrite target if it's already absolute
+ if filepath.IsAbs(target) || strings.HasPrefix(target, string(filepath.Separator)) {
+ target = fs.Join(fs.Root(), target)
+ target = filepath.Clean(filepath.FromSlash(target))
+ }
+
+ link, err := fs.underlyingPath(link)
+ if err != nil {
+ return err
+ }
+
+ return fs.underlying.(billy.Symlink).Symlink(target, link)
+}
+
+func (fs *ChrootHelper) Readlink(link string) (string, error) {
+ fullpath, err := fs.underlyingPath(link)
+ if err != nil {
+ return "", err
+ }
+
+ target, err := fs.underlying.(billy.Symlink).Readlink(fullpath)
+ if err != nil {
+ return "", err
+ }
+
+ if !filepath.IsAbs(target) && !strings.HasPrefix(target, string(filepath.Separator)) {
+ return target, nil
+ }
+
+ target, err = filepath.Rel(fs.base, target)
+ if err != nil {
+ return "", err
+ }
+
+ return string(os.PathSeparator) + target, nil
+}
+
+func (fs *ChrootHelper) Chroot(path string) (billy.Filesystem, error) {
+ fullpath, err := fs.underlyingPath(path)
+ if err != nil {
+ return nil, err
+ }
+
+ return New(fs.underlying, fullpath), nil
+}
+
+func (fs *ChrootHelper) Root() string {
+ return fs.base
+}
+
+func (fs *ChrootHelper) Underlying() billy.Basic {
+ return fs.underlying
+}
+
+// Capabilities implements the Capable interface.
+func (fs *ChrootHelper) Capabilities() billy.Capability {
+ return billy.Capabilities(fs.underlying)
+}
+
+type file struct {
+ billy.File
+ name string
+}
+
+func newFile(fs billy.Filesystem, f billy.File, filename string) billy.File {
+ filename = fs.Join(fs.Root(), filename)
+ filename, _ = filepath.Rel(fs.Root(), filename)
+
+ return &file{
+ File: f,
+ name: filename,
+ }
+}
+
+func (f *file) Name() string {
+ return f.name
+}
--- /dev/null
+package polyfill
+
+import (
+ "os"
+ "path/filepath"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+// Polyfill is a helper that implements all missing method from billy.Filesystem.
+type Polyfill struct {
+ billy.Basic
+ c capabilities
+}
+
+type capabilities struct{ tempfile, dir, symlink, chroot bool }
+
+// New creates a new filesystem wrapping up 'fs' the intercepts all the calls
+// made and errors if fs doesn't implement any of the billy interfaces.
+func New(fs billy.Basic) billy.Filesystem {
+ if original, ok := fs.(billy.Filesystem); ok {
+ return original
+ }
+
+ h := &Polyfill{Basic: fs}
+
+ _, h.c.tempfile = h.Basic.(billy.TempFile)
+ _, h.c.dir = h.Basic.(billy.Dir)
+ _, h.c.symlink = h.Basic.(billy.Symlink)
+ _, h.c.chroot = h.Basic.(billy.Chroot)
+ return h
+}
+
+func (h *Polyfill) TempFile(dir, prefix string) (billy.File, error) {
+ if !h.c.tempfile {
+ return nil, billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.TempFile).TempFile(dir, prefix)
+}
+
+func (h *Polyfill) ReadDir(path string) ([]os.FileInfo, error) {
+ if !h.c.dir {
+ return nil, billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.Dir).ReadDir(path)
+}
+
+func (h *Polyfill) MkdirAll(filename string, perm os.FileMode) error {
+ if !h.c.dir {
+ return billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.Dir).MkdirAll(filename, perm)
+}
+
+func (h *Polyfill) Symlink(target, link string) error {
+ if !h.c.symlink {
+ return billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.Symlink).Symlink(target, link)
+}
+
+func (h *Polyfill) Readlink(link string) (string, error) {
+ if !h.c.symlink {
+ return "", billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.Symlink).Readlink(link)
+}
+
+func (h *Polyfill) Lstat(path string) (os.FileInfo, error) {
+ if !h.c.symlink {
+ return nil, billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.Symlink).Lstat(path)
+}
+
+func (h *Polyfill) Chroot(path string) (billy.Filesystem, error) {
+ if !h.c.chroot {
+ return nil, billy.ErrNotSupported
+ }
+
+ return h.Basic.(billy.Chroot).Chroot(path)
+}
+
+func (h *Polyfill) Root() string {
+ if !h.c.chroot {
+ return string(filepath.Separator)
+ }
+
+ return h.Basic.(billy.Chroot).Root()
+}
+
+func (h *Polyfill) Underlying() billy.Basic {
+ return h.Basic
+}
+
+// Capabilities implements the Capable interface.
+func (h *Polyfill) Capabilities() billy.Capability {
+ return billy.Capabilities(h.Basic)
+}
--- /dev/null
+// Package osfs provides a billy filesystem for the OS.
+package osfs // import "gopkg.in/src-d/go-billy.v4/osfs"
+
+import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-billy.v4/helper/chroot"
+)
+
+const (
+ defaultDirectoryMode = 0755
+ defaultCreateMode = 0666
+)
+
+// OS is a filesystem based on the os filesystem.
+type OS struct{}
+
+// New returns a new OS filesystem.
+func New(baseDir string) billy.Filesystem {
+ return chroot.New(&OS{}, baseDir)
+}
+
+func (fs *OS) Create(filename string) (billy.File, error) {
+ return fs.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, defaultCreateMode)
+}
+
+func (fs *OS) OpenFile(filename string, flag int, perm os.FileMode) (billy.File, error) {
+ if flag&os.O_CREATE != 0 {
+ if err := fs.createDir(filename); err != nil {
+ return nil, err
+ }
+ }
+
+ f, err := os.OpenFile(filename, flag, perm)
+ if err != nil {
+ return nil, err
+ }
+ return &file{File: f}, err
+}
+
+func (fs *OS) createDir(fullpath string) error {
+ dir := filepath.Dir(fullpath)
+ if dir != "." {
+ if err := os.MkdirAll(dir, defaultDirectoryMode); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (fs *OS) ReadDir(path string) ([]os.FileInfo, error) {
+ l, err := ioutil.ReadDir(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var s = make([]os.FileInfo, len(l))
+ for i, f := range l {
+ s[i] = f
+ }
+
+ return s, nil
+}
+
+func (fs *OS) Rename(from, to string) error {
+ if err := fs.createDir(to); err != nil {
+ return err
+ }
+
+ return os.Rename(from, to)
+}
+
+func (fs *OS) MkdirAll(path string, perm os.FileMode) error {
+ return os.MkdirAll(path, defaultDirectoryMode)
+}
+
+func (fs *OS) Open(filename string) (billy.File, error) {
+ return fs.OpenFile(filename, os.O_RDONLY, 0)
+}
+
+func (fs *OS) Stat(filename string) (os.FileInfo, error) {
+ return os.Stat(filename)
+}
+
+func (fs *OS) Remove(filename string) error {
+ return os.Remove(filename)
+}
+
+func (fs *OS) TempFile(dir, prefix string) (billy.File, error) {
+ if err := fs.createDir(dir + string(os.PathSeparator)); err != nil {
+ return nil, err
+ }
+
+ f, err := ioutil.TempFile(dir, prefix)
+ if err != nil {
+ return nil, err
+ }
+ return &file{File: f}, nil
+}
+
+func (fs *OS) Join(elem ...string) string {
+ return filepath.Join(elem...)
+}
+
+func (fs *OS) RemoveAll(path string) error {
+ return os.RemoveAll(filepath.Clean(path))
+}
+
+func (fs *OS) Lstat(filename string) (os.FileInfo, error) {
+ return os.Lstat(filepath.Clean(filename))
+}
+
+func (fs *OS) Symlink(target, link string) error {
+ if err := fs.createDir(link); err != nil {
+ return err
+ }
+
+ return os.Symlink(target, link)
+}
+
+func (fs *OS) Readlink(link string) (string, error) {
+ return os.Readlink(link)
+}
+
+// Capabilities implements the Capable interface.
+func (fs *OS) Capabilities() billy.Capability {
+ return billy.DefaultCapabilities
+}
+
+// file is a wrapper for an os.File which adds support for file locking.
+type file struct {
+ *os.File
+ m sync.Mutex
+}
--- /dev/null
+// +build !windows
+
+package osfs
+
+import (
+ "syscall"
+)
+
+func (f *file) Lock() error {
+ f.m.Lock()
+ defer f.m.Unlock()
+
+ return syscall.Flock(int(f.File.Fd()), syscall.LOCK_EX)
+}
+
+func (f *file) Unlock() error {
+ f.m.Lock()
+ defer f.m.Unlock()
+
+ return syscall.Flock(int(f.File.Fd()), syscall.LOCK_UN)
+}
--- /dev/null
+// +build windows
+
+package osfs
+
+import (
+ "os"
+ "runtime"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type fileInfo struct {
+ os.FileInfo
+ name string
+}
+
+func (fi *fileInfo) Name() string {
+ return fi.name
+}
+
+var (
+ kernel32DLL = windows.NewLazySystemDLL("kernel32.dll")
+ lockFileExProc = kernel32DLL.NewProc("LockFileEx")
+ unlockFileProc = kernel32DLL.NewProc("UnlockFile")
+)
+
+const (
+ lockfileExclusiveLock = 0x2
+)
+
+func (f *file) Lock() error {
+ f.m.Lock()
+ defer f.m.Unlock()
+
+ var overlapped windows.Overlapped
+ // err is always non-nil as per sys/windows semantics.
+ ret, _, err := lockFileExProc.Call(f.File.Fd(), lockfileExclusiveLock, 0, 0xFFFFFFFF, 0,
+ uintptr(unsafe.Pointer(&overlapped)))
+ runtime.KeepAlive(&overlapped)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func (f *file) Unlock() error {
+ f.m.Lock()
+ defer f.m.Unlock()
+
+ // err is always non-nil as per sys/windows semantics.
+ ret, _, err := unlockFileProc.Call(f.File.Fd(), 0, 0, 0xFFFFFFFF, 0)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
--- /dev/null
+package util
+
+import (
+ "path/filepath"
+ "sort"
+ "strings"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+// Glob returns the names of all files matching pattern or nil
+// if there is no matching file. The syntax of patterns is the same
+// as in Match. The pattern may describe hierarchical names such as
+// /usr/*/bin/ed (assuming the Separator is '/').
+//
+// Glob ignores file system errors such as I/O errors reading directories.
+// The only possible returned error is ErrBadPattern, when pattern
+// is malformed.
+//
+// Function originally from https://golang.org/src/path/filepath/match_test.go
+func Glob(fs billy.Filesystem, pattern string) (matches []string, err error) {
+ if !hasMeta(pattern) {
+ if _, err = fs.Lstat(pattern); err != nil {
+ return nil, nil
+ }
+ return []string{pattern}, nil
+ }
+
+ dir, file := filepath.Split(pattern)
+ // Prevent infinite recursion. See issue 15879.
+ if dir == pattern {
+ return nil, filepath.ErrBadPattern
+ }
+
+ var m []string
+ m, err = Glob(fs, cleanGlobPath(dir))
+ if err != nil {
+ return
+ }
+ for _, d := range m {
+ matches, err = glob(fs, d, file, matches)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+// cleanGlobPath prepares path for glob matching.
+func cleanGlobPath(path string) string {
+ switch path {
+ case "":
+ return "."
+ case string(filepath.Separator):
+ // do nothing to the path
+ return path
+ default:
+ return path[0 : len(path)-1] // chop off trailing separator
+ }
+}
+
+// glob searches for files matching pattern in the directory dir
+// and appends them to matches. If the directory cannot be
+// opened, it returns the existing matches. New matches are
+// added in lexicographical order.
+func glob(fs billy.Filesystem, dir, pattern string, matches []string) (m []string, e error) {
+ m = matches
+ fi, err := fs.Stat(dir)
+ if err != nil {
+ return
+ }
+
+ if !fi.IsDir() {
+ return
+ }
+
+ names, _ := readdirnames(fs, dir)
+ sort.Strings(names)
+
+ for _, n := range names {
+ matched, err := filepath.Match(pattern, n)
+ if err != nil {
+ return m, err
+ }
+ if matched {
+ m = append(m, filepath.Join(dir, n))
+ }
+ }
+ return
+}
+
+// hasMeta reports whether path contains any of the magic characters
+// recognized by Match.
+func hasMeta(path string) bool {
+ // TODO(niemeyer): Should other magic characters be added here?
+ return strings.ContainsAny(path, "*?[")
+}
+
+func readdirnames(fs billy.Filesystem, dir string) ([]string, error) {
+ files, err := fs.ReadDir(dir)
+ if err != nil {
+ return nil, err
+ }
+
+ var names []string
+ for _, file := range files {
+ names = append(names, file.Name())
+ }
+
+ return names, nil
+}
--- /dev/null
+package util
+
+import (
+ "io"
+ "os"
+ "path/filepath"
+ "strconv"
+ "sync"
+ "time"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+// RemoveAll removes path and any children it contains. It removes everything it
+// can but returns the first error it encounters. If the path does not exist,
+// RemoveAll returns nil (no error).
+func RemoveAll(fs billy.Basic, path string) error {
+ fs, path = getUnderlyingAndPath(fs, path)
+
+ if r, ok := fs.(removerAll); ok {
+ return r.RemoveAll(path)
+ }
+
+ return removeAll(fs, path)
+}
+
+type removerAll interface {
+ RemoveAll(string) error
+}
+
+func removeAll(fs billy.Basic, path string) error {
+ // This implementation is adapted from os.RemoveAll.
+
+ // Simple case: if Remove works, we're done.
+ err := fs.Remove(path)
+ if err == nil || os.IsNotExist(err) {
+ return nil
+ }
+
+ // Otherwise, is this a directory we need to recurse into?
+ dir, serr := fs.Stat(path)
+ if serr != nil {
+ if os.IsNotExist(serr) {
+ return nil
+ }
+
+ return serr
+ }
+
+ if !dir.IsDir() {
+ // Not a directory; return the error from Remove.
+ return err
+ }
+
+ dirfs, ok := fs.(billy.Dir)
+ if !ok {
+ return billy.ErrNotSupported
+ }
+
+ // Directory.
+ fis, err := dirfs.ReadDir(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Race. It was deleted between the Lstat and Open.
+ // Return nil per RemoveAll's docs.
+ return nil
+ }
+
+ return err
+ }
+
+ // Remove contents & return first error.
+ err = nil
+ for _, fi := range fis {
+ cpath := fs.Join(path, fi.Name())
+ err1 := removeAll(fs, cpath)
+ if err == nil {
+ err = err1
+ }
+ }
+
+ // Remove directory.
+ err1 := fs.Remove(path)
+ if err1 == nil || os.IsNotExist(err1) {
+ return nil
+ }
+
+ if err == nil {
+ err = err1
+ }
+
+ return err
+
+}
+
+// WriteFile writes data to a file named by filename in the given filesystem.
+// If the file does not exist, WriteFile creates it with permissions perm;
+// otherwise WriteFile truncates it before writing.
+func WriteFile(fs billy.Basic, filename string, data []byte, perm os.FileMode) error {
+ f, err := fs.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
+ if err != nil {
+ return err
+ }
+
+ n, err := f.Write(data)
+ if err == nil && n < len(data) {
+ err = io.ErrShortWrite
+ }
+
+ if err1 := f.Close(); err == nil {
+ err = err1
+ }
+
+ return err
+}
+
+// Random number state.
+// We generate random temporary file names so that there's a good
+// chance the file doesn't exist yet - keeps the number of tries in
+// TempFile to a minimum.
+var rand uint32
+var randmu sync.Mutex
+
+func reseed() uint32 {
+ return uint32(time.Now().UnixNano() + int64(os.Getpid()))
+}
+
+func nextSuffix() string {
+ randmu.Lock()
+ r := rand
+ if r == 0 {
+ r = reseed()
+ }
+ r = r*1664525 + 1013904223 // constants from Numerical Recipes
+ rand = r
+ randmu.Unlock()
+ return strconv.Itoa(int(1e9 + r%1e9))[1:]
+}
+
+// TempFile creates a new temporary file in the directory dir with a name
+// beginning with prefix, opens the file for reading and writing, and returns
+// the resulting *os.File. If dir is the empty string, TempFile uses the default
+// directory for temporary files (see os.TempDir). Multiple programs calling
+// TempFile simultaneously will not choose the same file. The caller can use
+// f.Name() to find the pathname of the file. It is the caller's responsibility
+// to remove the file when no longer needed.
+func TempFile(fs billy.Basic, dir, prefix string) (f billy.File, err error) {
+ // This implementation is based on stdlib ioutil.TempFile.
+
+ if dir == "" {
+ dir = os.TempDir()
+ }
+
+ nconflict := 0
+ for i := 0; i < 10000; i++ {
+ name := filepath.Join(dir, prefix+nextSuffix())
+ f, err = fs.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600)
+ if os.IsExist(err) {
+ if nconflict++; nconflict > 10 {
+ randmu.Lock()
+ rand = reseed()
+ randmu.Unlock()
+ }
+ continue
+ }
+ break
+ }
+ return
+}
+
+// TempDir creates a new temporary directory in the directory dir
+// with a name beginning with prefix and returns the path of the
+// new directory. If dir is the empty string, TempDir uses the
+// default directory for temporary files (see os.TempDir).
+// Multiple programs calling TempDir simultaneously
+// will not choose the same directory. It is the caller's responsibility
+// to remove the directory when no longer needed.
+func TempDir(fs billy.Dir, dir, prefix string) (name string, err error) {
+ // This implementation is based on stdlib ioutil.TempDir
+
+ if dir == "" {
+ dir = os.TempDir()
+ }
+
+ nconflict := 0
+ for i := 0; i < 10000; i++ {
+ try := filepath.Join(dir, prefix+nextSuffix())
+ err = fs.MkdirAll(try, 0700)
+ if os.IsExist(err) {
+ if nconflict++; nconflict > 10 {
+ randmu.Lock()
+ rand = reseed()
+ randmu.Unlock()
+ }
+ continue
+ }
+ if os.IsNotExist(err) {
+ if _, err := os.Stat(dir); os.IsNotExist(err) {
+ return "", err
+ }
+ }
+ if err == nil {
+ name = try
+ }
+ break
+ }
+ return
+}
+
+type underlying interface {
+ Underlying() billy.Basic
+}
+
+func getUnderlyingAndPath(fs billy.Basic, path string) (billy.Basic, string) {
+ u, ok := fs.(underlying)
+ if !ok {
+ return fs, path
+ }
+ if ch, ok := fs.(billy.Chroot); ok {
+ path = fs.Join(ch.Root(), path)
+ }
+
+ return u.Underlying(), path
+}
--- /dev/null
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018 Sourced Technologies, S.L.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
--- /dev/null
+package git
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+ "unicode/utf8"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/utils/diff"
+)
+
+// BlameResult represents the result of a Blame operation.
+type BlameResult struct {
+ // Path is the path of the File that we're blaming.
+ Path string
+ // Rev (Revision) is the hash of the specified Commit used to generate this result.
+ Rev plumbing.Hash
+ // Lines contains every line with its authorship.
+ Lines []*Line
+}
+
+// Blame returns a BlameResult with the information about the last author of
+// each line from file `path` at commit `c`.
+func Blame(c *object.Commit, path string) (*BlameResult, error) {
+ // The file to blame is identified by the input arguments:
+ // commit and path. commit is a Commit object obtained from a Repository. Path
+ // represents a path to a specific file contained into the repository.
+ //
+ // Blaming a file is a two step process:
+ //
+ // 1. Create a linear history of the commits affecting a file. We use
+ // revlist.New for that.
+ //
+ // 2. Then build a graph with a node for every line in every file in
+ // the history of the file.
+ //
+ // Each node is assigned a commit: Start by the nodes in the first
+ // commit. Assign that commit as the creator of all its lines.
+ //
+ // Then jump to the nodes in the next commit, and calculate the diff
+ // between the two files. Newly created lines get
+ // assigned the new commit as its origin. Modified lines also get
+ // this new commit. Untouched lines retain the old commit.
+ //
+ // All this work is done in the assignOrigin function which holds all
+ // the internal relevant data in a "blame" struct, that is not
+ // exported.
+ //
+ // TODO: ways to improve the efficiency of this function:
+ // 1. Improve revlist
+ // 2. Improve how to traverse the history (example a backward traversal will
+ // be much more efficient)
+ //
+ // TODO: ways to improve the function in general:
+ // 1. Add memoization between revlist and assign.
+ // 2. It is using much more memory than needed, see the TODOs below.
+
+ b := new(blame)
+ b.fRev = c
+ b.path = path
+
+ // get all the file revisions
+ if err := b.fillRevs(); err != nil {
+ return nil, err
+ }
+
+ // calculate the line tracking graph and fill in
+ // file contents in data.
+ if err := b.fillGraphAndData(); err != nil {
+ return nil, err
+ }
+
+ file, err := b.fRev.File(b.path)
+ if err != nil {
+ return nil, err
+ }
+ finalLines, err := file.Lines()
+ if err != nil {
+ return nil, err
+ }
+
+ // Each node (line) holds the commit where it was introduced or
+ // last modified. To achieve that we use the FORWARD algorithm
+ // described in Zimmermann, et al. "Mining Version Archives for
+ // Co-changed Lines", in proceedings of the Mining Software
+ // Repositories workshop, Shanghai, May 22-23, 2006.
+ lines, err := newLines(finalLines, b.sliceGraph(len(b.graph)-1))
+ if err != nil {
+ return nil, err
+ }
+
+ return &BlameResult{
+ Path: path,
+ Rev: c.Hash,
+ Lines: lines,
+ }, nil
+}
+
+// Line values represent the contents and author of a line in BlamedResult values.
+type Line struct {
+ // Author is the email address of the last author that modified the line.
+ Author string
+ // Text is the original text of the line.
+ Text string
+ // Date is when the original text of the line was introduced
+ Date time.Time
+ // Hash is the commit hash that introduced the original line
+ Hash plumbing.Hash
+}
+
+func newLine(author, text string, date time.Time, hash plumbing.Hash) *Line {
+ return &Line{
+ Author: author,
+ Text: text,
+ Hash: hash,
+ Date: date,
+ }
+}
+
+func newLines(contents []string, commits []*object.Commit) ([]*Line, error) {
+ lcontents := len(contents)
+ lcommits := len(commits)
+
+ if lcontents != lcommits {
+ if lcontents == lcommits-1 && contents[lcontents-1] != "\n" {
+ contents = append(contents, "\n")
+ } else {
+ return nil, errors.New("contents and commits have different length")
+ }
+ }
+
+ result := make([]*Line, 0, lcontents)
+ for i := range contents {
+ result = append(result, newLine(
+ commits[i].Author.Email, contents[i],
+ commits[i].Author.When, commits[i].Hash,
+ ))
+ }
+
+ return result, nil
+}
+
+// this struct is internally used by the blame function to hold its
+// inputs, outputs and state.
+type blame struct {
+ // the path of the file to blame
+ path string
+ // the commit of the final revision of the file to blame
+ fRev *object.Commit
+ // the chain of revisions affecting the the file to blame
+ revs []*object.Commit
+ // the contents of the file across all its revisions
+ data []string
+ // the graph of the lines in the file across all the revisions
+ graph [][]*object.Commit
+}
+
+// calculate the history of a file "path", starting from commit "from", sorted by commit date.
+func (b *blame) fillRevs() error {
+ var err error
+
+ b.revs, err = references(b.fRev, b.path)
+ return err
+}
+
+// build graph of a file from its revision history
+func (b *blame) fillGraphAndData() error {
+ //TODO: not all commits are needed, only the current rev and the prev
+ b.graph = make([][]*object.Commit, len(b.revs))
+ b.data = make([]string, len(b.revs)) // file contents in all the revisions
+ // for every revision of the file, starting with the first
+ // one...
+ for i, rev := range b.revs {
+ // get the contents of the file
+ file, err := rev.File(b.path)
+ if err != nil {
+ return nil
+ }
+ b.data[i], err = file.Contents()
+ if err != nil {
+ return err
+ }
+ nLines := countLines(b.data[i])
+ // create a node for each line
+ b.graph[i] = make([]*object.Commit, nLines)
+ // assign a commit to each node
+ // if this is the first revision, then the node is assigned to
+ // this first commit.
+ if i == 0 {
+ for j := 0; j < nLines; j++ {
+ b.graph[i][j] = (*object.Commit)(b.revs[i])
+ }
+ } else {
+ // if this is not the first commit, then assign to the old
+ // commit or to the new one, depending on what the diff
+ // says.
+ b.assignOrigin(i, i-1)
+ }
+ }
+ return nil
+}
+
+// sliceGraph returns a slice of commits (one per line) for a particular
+// revision of a file (0=first revision).
+func (b *blame) sliceGraph(i int) []*object.Commit {
+ fVs := b.graph[i]
+ result := make([]*object.Commit, 0, len(fVs))
+ for _, v := range fVs {
+ c := object.Commit(*v)
+ result = append(result, &c)
+ }
+ return result
+}
+
+// Assigns origin to vertexes in current (c) rev from data in its previous (p)
+// revision
+func (b *blame) assignOrigin(c, p int) {
+ // assign origin based on diff info
+ hunks := diff.Do(b.data[p], b.data[c])
+ sl := -1 // source line
+ dl := -1 // destination line
+ for h := range hunks {
+ hLines := countLines(hunks[h].Text)
+ for hl := 0; hl < hLines; hl++ {
+ switch {
+ case hunks[h].Type == 0:
+ sl++
+ dl++
+ b.graph[c][dl] = b.graph[p][sl]
+ case hunks[h].Type == 1:
+ dl++
+ b.graph[c][dl] = (*object.Commit)(b.revs[c])
+ case hunks[h].Type == -1:
+ sl++
+ default:
+ panic("unreachable")
+ }
+ }
+ }
+}
+
+// GoString prints the results of a Blame using git-blame's style.
+func (b *blame) GoString() string {
+ var buf bytes.Buffer
+
+ file, err := b.fRev.File(b.path)
+ if err != nil {
+ panic("PrettyPrint: internal error in repo.Data")
+ }
+ contents, err := file.Contents()
+ if err != nil {
+ panic("PrettyPrint: internal error in repo.Data")
+ }
+
+ lines := strings.Split(contents, "\n")
+ // max line number length
+ mlnl := len(strconv.Itoa(len(lines)))
+ // max author length
+ mal := b.maxAuthorLength()
+ format := fmt.Sprintf("%%s (%%-%ds %%%dd) %%s\n",
+ mal, mlnl)
+
+ fVs := b.graph[len(b.graph)-1]
+ for ln, v := range fVs {
+ fmt.Fprintf(&buf, format, v.Hash.String()[:8],
+ prettyPrintAuthor(fVs[ln]), ln+1, lines[ln])
+ }
+ return buf.String()
+}
+
+// utility function to pretty print the author.
+func prettyPrintAuthor(c *object.Commit) string {
+ return fmt.Sprintf("%s %s", c.Author.Name, c.Author.When.Format("2006-01-02"))
+}
+
+// utility function to calculate the number of runes needed
+// to print the longest author name in the blame of a file.
+func (b *blame) maxAuthorLength() int {
+ memo := make(map[plumbing.Hash]struct{}, len(b.graph)-1)
+ fVs := b.graph[len(b.graph)-1]
+ m := 0
+ for ln := range fVs {
+ if _, ok := memo[fVs[ln].Hash]; ok {
+ continue
+ }
+ memo[fVs[ln].Hash] = struct{}{}
+ m = max(m, utf8.RuneCountInString(prettyPrintAuthor(fVs[ln])))
+ }
+ return m
+}
+
+func max(a, b int) int {
+ if a > b {
+ return a
+ }
+ return b
+}
--- /dev/null
+package git
+
+import "strings"
+
+const defaultDotGitPath = ".git"
+
+// countLines returns the number of lines in a string à la git, this is
+// The newline character is assumed to be '\n'. The empty string
+// contains 0 lines. If the last line of the string doesn't end with a
+// newline, it will still be considered a line.
+func countLines(s string) int {
+ if s == "" {
+ return 0
+ }
+
+ nEOL := strings.Count(s, "\n")
+ if strings.HasSuffix(s, "\n") {
+ return nEOL
+ }
+
+ return nEOL + 1
+}
--- /dev/null
+package config
+
+import (
+ "errors"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ format "gopkg.in/src-d/go-git.v4/plumbing/format/config"
+)
+
+var (
+ errBranchEmptyName = errors.New("branch config: empty name")
+ errBranchInvalidMerge = errors.New("branch config: invalid merge")
+)
+
+// Branch contains information on the
+// local branches and which remote to track
+type Branch struct {
+ // Name of branch
+ Name string
+ // Remote name of remote to track
+ Remote string
+ // Merge is the local refspec for the branch
+ Merge plumbing.ReferenceName
+
+ raw *format.Subsection
+}
+
+// Validate validates fields of branch
+func (b *Branch) Validate() error {
+ if b.Name == "" {
+ return errBranchEmptyName
+ }
+
+ if b.Merge != "" && !b.Merge.IsBranch() {
+ return errBranchInvalidMerge
+ }
+
+ return nil
+}
+
+func (b *Branch) marshal() *format.Subsection {
+ if b.raw == nil {
+ b.raw = &format.Subsection{}
+ }
+
+ b.raw.Name = b.Name
+
+ if b.Remote == "" {
+ b.raw.RemoveOption(remoteSection)
+ } else {
+ b.raw.SetOption(remoteSection, b.Remote)
+ }
+
+ if b.Merge == "" {
+ b.raw.RemoveOption(mergeKey)
+ } else {
+ b.raw.SetOption(mergeKey, string(b.Merge))
+ }
+
+ return b.raw
+}
+
+func (b *Branch) unmarshal(s *format.Subsection) error {
+ b.raw = s
+
+ b.Name = b.raw.Name
+ b.Remote = b.raw.Options.Get(remoteSection)
+ b.Merge = plumbing.ReferenceName(b.raw.Options.Get(mergeKey))
+
+ return b.Validate()
+}
--- /dev/null
+// Package config contains the abstraction of multiple config files
+package config
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "sort"
+ "strconv"
+
+ format "gopkg.in/src-d/go-git.v4/plumbing/format/config"
+)
+
+const (
+ // DefaultFetchRefSpec is the default refspec used for fetch.
+ DefaultFetchRefSpec = "+refs/heads/*:refs/remotes/%s/*"
+ // DefaultPushRefSpec is the default refspec used for push.
+ DefaultPushRefSpec = "refs/heads/*:refs/heads/*"
+)
+
+// ConfigStorer generic storage of Config object
+type ConfigStorer interface {
+ Config() (*Config, error)
+ SetConfig(*Config) error
+}
+
+var (
+ ErrInvalid = errors.New("config invalid key in remote or branch")
+ ErrRemoteConfigNotFound = errors.New("remote config not found")
+ ErrRemoteConfigEmptyURL = errors.New("remote config: empty URL")
+ ErrRemoteConfigEmptyName = errors.New("remote config: empty name")
+)
+
+// Config contains the repository configuration
+// ftp://www.kernel.org/pub/software/scm/git/docs/git-config.html#FILES
+type Config struct {
+ Core struct {
+ // IsBare if true this repository is assumed to be bare and has no
+ // working directory associated with it.
+ IsBare bool
+ // Worktree is the path to the root of the working tree.
+ Worktree string
+ // CommentChar is the character indicating the start of a
+ // comment for commands like commit and tag
+ CommentChar string
+ }
+
+ Pack struct {
+ // Window controls the size of the sliding window for delta
+ // compression. The default is 10. A value of 0 turns off
+ // delta compression entirely.
+ Window uint
+ }
+
+ // Remotes list of repository remotes, the key of the map is the name
+ // of the remote, should equal to RemoteConfig.Name.
+ Remotes map[string]*RemoteConfig
+ // Submodules list of repository submodules, the key of the map is the name
+ // of the submodule, should equal to Submodule.Name.
+ Submodules map[string]*Submodule
+ // Branches list of branches, the key is the branch name and should
+ // equal Branch.Name
+ Branches map[string]*Branch
+ // Raw contains the raw information of a config file. The main goal is
+ // preserve the parsed information from the original format, to avoid
+ // dropping unsupported fields.
+ Raw *format.Config
+}
+
+// NewConfig returns a new empty Config.
+func NewConfig() *Config {
+ config := &Config{
+ Remotes: make(map[string]*RemoteConfig),
+ Submodules: make(map[string]*Submodule),
+ Branches: make(map[string]*Branch),
+ Raw: format.New(),
+ }
+
+ config.Pack.Window = DefaultPackWindow
+
+ return config
+}
+
+// Validate validates the fields and sets the default values.
+func (c *Config) Validate() error {
+ for name, r := range c.Remotes {
+ if r.Name != name {
+ return ErrInvalid
+ }
+
+ if err := r.Validate(); err != nil {
+ return err
+ }
+ }
+
+ for name, b := range c.Branches {
+ if b.Name != name {
+ return ErrInvalid
+ }
+
+ if err := b.Validate(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+const (
+ remoteSection = "remote"
+ submoduleSection = "submodule"
+ branchSection = "branch"
+ coreSection = "core"
+ packSection = "pack"
+ fetchKey = "fetch"
+ urlKey = "url"
+ bareKey = "bare"
+ worktreeKey = "worktree"
+ commentCharKey = "commentChar"
+ windowKey = "window"
+ mergeKey = "merge"
+
+ // DefaultPackWindow holds the number of previous objects used to
+ // generate deltas. The value 10 is the same used by git command.
+ DefaultPackWindow = uint(10)
+)
+
+// Unmarshal parses a git-config file and stores it.
+func (c *Config) Unmarshal(b []byte) error {
+ r := bytes.NewBuffer(b)
+ d := format.NewDecoder(r)
+
+ c.Raw = format.New()
+ if err := d.Decode(c.Raw); err != nil {
+ return err
+ }
+
+ c.unmarshalCore()
+ if err := c.unmarshalPack(); err != nil {
+ return err
+ }
+ unmarshalSubmodules(c.Raw, c.Submodules)
+
+ if err := c.unmarshalBranches(); err != nil {
+ return err
+ }
+
+ return c.unmarshalRemotes()
+}
+
+func (c *Config) unmarshalCore() {
+ s := c.Raw.Section(coreSection)
+ if s.Options.Get(bareKey) == "true" {
+ c.Core.IsBare = true
+ }
+
+ c.Core.Worktree = s.Options.Get(worktreeKey)
+ c.Core.CommentChar = s.Options.Get(commentCharKey)
+}
+
+func (c *Config) unmarshalPack() error {
+ s := c.Raw.Section(packSection)
+ window := s.Options.Get(windowKey)
+ if window == "" {
+ c.Pack.Window = DefaultPackWindow
+ } else {
+ winUint, err := strconv.ParseUint(window, 10, 32)
+ if err != nil {
+ return err
+ }
+ c.Pack.Window = uint(winUint)
+ }
+ return nil
+}
+
+func (c *Config) unmarshalRemotes() error {
+ s := c.Raw.Section(remoteSection)
+ for _, sub := range s.Subsections {
+ r := &RemoteConfig{}
+ if err := r.unmarshal(sub); err != nil {
+ return err
+ }
+
+ c.Remotes[r.Name] = r
+ }
+
+ return nil
+}
+
+func unmarshalSubmodules(fc *format.Config, submodules map[string]*Submodule) {
+ s := fc.Section(submoduleSection)
+ for _, sub := range s.Subsections {
+ m := &Submodule{}
+ m.unmarshal(sub)
+
+ if m.Validate() == ErrModuleBadPath {
+ continue
+ }
+
+ submodules[m.Name] = m
+ }
+}
+
+func (c *Config) unmarshalBranches() error {
+ bs := c.Raw.Section(branchSection)
+ for _, sub := range bs.Subsections {
+ b := &Branch{}
+
+ if err := b.unmarshal(sub); err != nil {
+ return err
+ }
+
+ c.Branches[b.Name] = b
+ }
+ return nil
+}
+
+// Marshal returns Config encoded as a git-config file.
+func (c *Config) Marshal() ([]byte, error) {
+ c.marshalCore()
+ c.marshalPack()
+ c.marshalRemotes()
+ c.marshalSubmodules()
+ c.marshalBranches()
+
+ buf := bytes.NewBuffer(nil)
+ if err := format.NewEncoder(buf).Encode(c.Raw); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+func (c *Config) marshalCore() {
+ s := c.Raw.Section(coreSection)
+ s.SetOption(bareKey, fmt.Sprintf("%t", c.Core.IsBare))
+
+ if c.Core.Worktree != "" {
+ s.SetOption(worktreeKey, c.Core.Worktree)
+ }
+}
+
+func (c *Config) marshalPack() {
+ s := c.Raw.Section(packSection)
+ if c.Pack.Window != DefaultPackWindow {
+ s.SetOption(windowKey, fmt.Sprintf("%d", c.Pack.Window))
+ }
+}
+
+func (c *Config) marshalRemotes() {
+ s := c.Raw.Section(remoteSection)
+ newSubsections := make(format.Subsections, 0, len(c.Remotes))
+ added := make(map[string]bool)
+ for _, subsection := range s.Subsections {
+ if remote, ok := c.Remotes[subsection.Name]; ok {
+ newSubsections = append(newSubsections, remote.marshal())
+ added[subsection.Name] = true
+ }
+ }
+
+ remoteNames := make([]string, 0, len(c.Remotes))
+ for name := range c.Remotes {
+ remoteNames = append(remoteNames, name)
+ }
+
+ sort.Strings(remoteNames)
+
+ for _, name := range remoteNames {
+ if !added[name] {
+ newSubsections = append(newSubsections, c.Remotes[name].marshal())
+ }
+ }
+
+ s.Subsections = newSubsections
+}
+
+func (c *Config) marshalSubmodules() {
+ s := c.Raw.Section(submoduleSection)
+ s.Subsections = make(format.Subsections, len(c.Submodules))
+
+ var i int
+ for _, r := range c.Submodules {
+ section := r.marshal()
+ // the submodule section at config is a subset of the .gitmodule file
+ // we should remove the non-valid options for the config file.
+ section.RemoveOption(pathKey)
+ s.Subsections[i] = section
+ i++
+ }
+}
+
+func (c *Config) marshalBranches() {
+ s := c.Raw.Section(branchSection)
+ newSubsections := make(format.Subsections, 0, len(c.Branches))
+ added := make(map[string]bool)
+ for _, subsection := range s.Subsections {
+ if branch, ok := c.Branches[subsection.Name]; ok {
+ newSubsections = append(newSubsections, branch.marshal())
+ added[subsection.Name] = true
+ }
+ }
+
+ branchNames := make([]string, 0, len(c.Branches))
+ for name := range c.Branches {
+ branchNames = append(branchNames, name)
+ }
+
+ sort.Strings(branchNames)
+
+ for _, name := range branchNames {
+ if !added[name] {
+ newSubsections = append(newSubsections, c.Branches[name].marshal())
+ }
+ }
+
+ s.Subsections = newSubsections
+}
+
+// RemoteConfig contains the configuration for a given remote repository.
+type RemoteConfig struct {
+ // Name of the remote
+ Name string
+ // URLs the URLs of a remote repository. It must be non-empty. Fetch will
+ // always use the first URL, while push will use all of them.
+ URLs []string
+ // Fetch the default set of "refspec" for fetch operation
+ Fetch []RefSpec
+
+ // raw representation of the subsection, filled by marshal or unmarshal are
+ // called
+ raw *format.Subsection
+}
+
+// Validate validates the fields and sets the default values.
+func (c *RemoteConfig) Validate() error {
+ if c.Name == "" {
+ return ErrRemoteConfigEmptyName
+ }
+
+ if len(c.URLs) == 0 {
+ return ErrRemoteConfigEmptyURL
+ }
+
+ for _, r := range c.Fetch {
+ if err := r.Validate(); err != nil {
+ return err
+ }
+ }
+
+ if len(c.Fetch) == 0 {
+ c.Fetch = []RefSpec{RefSpec(fmt.Sprintf(DefaultFetchRefSpec, c.Name))}
+ }
+
+ return nil
+}
+
+func (c *RemoteConfig) unmarshal(s *format.Subsection) error {
+ c.raw = s
+
+ fetch := []RefSpec{}
+ for _, f := range c.raw.Options.GetAll(fetchKey) {
+ rs := RefSpec(f)
+ if err := rs.Validate(); err != nil {
+ return err
+ }
+
+ fetch = append(fetch, rs)
+ }
+
+ c.Name = c.raw.Name
+ c.URLs = append([]string(nil), c.raw.Options.GetAll(urlKey)...)
+ c.Fetch = fetch
+
+ return nil
+}
+
+func (c *RemoteConfig) marshal() *format.Subsection {
+ if c.raw == nil {
+ c.raw = &format.Subsection{}
+ }
+
+ c.raw.Name = c.Name
+ if len(c.URLs) == 0 {
+ c.raw.RemoveOption(urlKey)
+ } else {
+ c.raw.SetOption(urlKey, c.URLs...)
+ }
+
+ if len(c.Fetch) == 0 {
+ c.raw.RemoveOption(fetchKey)
+ } else {
+ var values []string
+ for _, rs := range c.Fetch {
+ values = append(values, rs.String())
+ }
+
+ c.raw.SetOption(fetchKey, values...)
+ }
+
+ return c.raw
+}
--- /dev/null
+package config
+
+import (
+ "bytes"
+ "errors"
+ "regexp"
+
+ format "gopkg.in/src-d/go-git.v4/plumbing/format/config"
+)
+
+var (
+ ErrModuleEmptyURL = errors.New("module config: empty URL")
+ ErrModuleEmptyPath = errors.New("module config: empty path")
+ ErrModuleBadPath = errors.New("submodule has an invalid path")
+)
+
+var (
+ // Matches module paths with dotdot ".." components.
+ dotdotPath = regexp.MustCompile(`(^|[/\\])\.\.([/\\]|$)`)
+)
+
+// Modules defines the submodules properties, represents a .gitmodules file
+// https://www.kernel.org/pub/software/scm/git/docs/gitmodules.html
+type Modules struct {
+ // Submodules is a map of submodules being the key the name of the submodule.
+ Submodules map[string]*Submodule
+
+ raw *format.Config
+}
+
+// NewModules returns a new empty Modules
+func NewModules() *Modules {
+ return &Modules{
+ Submodules: make(map[string]*Submodule),
+ raw: format.New(),
+ }
+}
+
+const (
+ pathKey = "path"
+ branchKey = "branch"
+)
+
+// Unmarshal parses a git-config file and stores it.
+func (m *Modules) Unmarshal(b []byte) error {
+ r := bytes.NewBuffer(b)
+ d := format.NewDecoder(r)
+
+ m.raw = format.New()
+ if err := d.Decode(m.raw); err != nil {
+ return err
+ }
+
+ unmarshalSubmodules(m.raw, m.Submodules)
+ return nil
+}
+
+// Marshal returns Modules encoded as a git-config file.
+func (m *Modules) Marshal() ([]byte, error) {
+ s := m.raw.Section(submoduleSection)
+ s.Subsections = make(format.Subsections, len(m.Submodules))
+
+ var i int
+ for _, r := range m.Submodules {
+ s.Subsections[i] = r.marshal()
+ i++
+ }
+
+ buf := bytes.NewBuffer(nil)
+ if err := format.NewEncoder(buf).Encode(m.raw); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+// Submodule defines a submodule.
+type Submodule struct {
+ // Name module name
+ Name string
+ // Path defines the path, relative to the top-level directory of the Git
+ // working tree.
+ Path string
+ // URL defines a URL from which the submodule repository can be cloned.
+ URL string
+ // Branch is a remote branch name for tracking updates in the upstream
+ // submodule. Optional value.
+ Branch string
+
+ // raw representation of the subsection, filled by marshal or unmarshal are
+ // called.
+ raw *format.Subsection
+}
+
+// Validate validates the fields and sets the default values.
+func (m *Submodule) Validate() error {
+ if m.Path == "" {
+ return ErrModuleEmptyPath
+ }
+
+ if m.URL == "" {
+ return ErrModuleEmptyURL
+ }
+
+ if dotdotPath.MatchString(m.Path) {
+ return ErrModuleBadPath
+ }
+
+ return nil
+}
+
+func (m *Submodule) unmarshal(s *format.Subsection) {
+ m.raw = s
+
+ m.Name = m.raw.Name
+ m.Path = m.raw.Option(pathKey)
+ m.URL = m.raw.Option(urlKey)
+ m.Branch = m.raw.Option(branchKey)
+}
+
+func (m *Submodule) marshal() *format.Subsection {
+ if m.raw == nil {
+ m.raw = &format.Subsection{}
+ }
+
+ m.raw.Name = m.Name
+ if m.raw.Name == "" {
+ m.raw.Name = m.Path
+ }
+
+ m.raw.SetOption(pathKey, m.Path)
+ m.raw.SetOption(urlKey, m.URL)
+
+ if m.Branch != "" {
+ m.raw.SetOption(branchKey, m.Branch)
+ }
+
+ return m.raw
+}
--- /dev/null
+package config
+
+import (
+ "errors"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+const (
+ refSpecWildcard = "*"
+ refSpecForce = "+"
+ refSpecSeparator = ":"
+)
+
+var (
+ ErrRefSpecMalformedSeparator = errors.New("malformed refspec, separators are wrong")
+ ErrRefSpecMalformedWildcard = errors.New("malformed refspec, mismatched number of wildcards")
+)
+
+// RefSpec is a mapping from local branches to remote references
+// The format of the refspec is an optional +, followed by <src>:<dst>, where
+// <src> is the pattern for references on the remote side and <dst> is where
+// those references will be written locally. The + tells Git to update the
+// reference even if it isn’t a fast-forward.
+// eg.: "+refs/heads/*:refs/remotes/origin/*"
+//
+// https://git-scm.com/book/es/v2/Git-Internals-The-Refspec
+type RefSpec string
+
+// Validate validates the RefSpec
+func (s RefSpec) Validate() error {
+ spec := string(s)
+ if strings.Count(spec, refSpecSeparator) != 1 {
+ return ErrRefSpecMalformedSeparator
+ }
+
+ sep := strings.Index(spec, refSpecSeparator)
+ if sep == len(spec)-1 {
+ return ErrRefSpecMalformedSeparator
+ }
+
+ ws := strings.Count(spec[0:sep], refSpecWildcard)
+ wd := strings.Count(spec[sep+1:], refSpecWildcard)
+ if ws == wd && ws < 2 && wd < 2 {
+ return nil
+ }
+
+ return ErrRefSpecMalformedWildcard
+}
+
+// IsForceUpdate returns if update is allowed in non fast-forward merges.
+func (s RefSpec) IsForceUpdate() bool {
+ return s[0] == refSpecForce[0]
+}
+
+// IsDelete returns true if the refspec indicates a delete (empty src).
+func (s RefSpec) IsDelete() bool {
+ return s[0] == refSpecSeparator[0]
+}
+
+// Src return the src side.
+func (s RefSpec) Src() string {
+ spec := string(s)
+
+ var start int
+ if s.IsForceUpdate() {
+ start = 1
+ } else {
+ start = 0
+ }
+ end := strings.Index(spec, refSpecSeparator)
+
+ return spec[start:end]
+}
+
+// Match match the given plumbing.ReferenceName against the source.
+func (s RefSpec) Match(n plumbing.ReferenceName) bool {
+ if !s.IsWildcard() {
+ return s.matchExact(n)
+ }
+
+ return s.matchGlob(n)
+}
+
+// IsWildcard returns true if the RefSpec contains a wildcard.
+func (s RefSpec) IsWildcard() bool {
+ return strings.Contains(string(s), refSpecWildcard)
+}
+
+func (s RefSpec) matchExact(n plumbing.ReferenceName) bool {
+ return s.Src() == n.String()
+}
+
+func (s RefSpec) matchGlob(n plumbing.ReferenceName) bool {
+ src := s.Src()
+ name := n.String()
+ wildcard := strings.Index(src, refSpecWildcard)
+
+ var prefix, suffix string
+ prefix = src[0:wildcard]
+ if len(src) < wildcard {
+ suffix = src[wildcard+1 : len(suffix)]
+ }
+
+ return len(name) > len(prefix)+len(suffix) &&
+ strings.HasPrefix(name, prefix) &&
+ strings.HasSuffix(name, suffix)
+}
+
+// Dst returns the destination for the given remote reference.
+func (s RefSpec) Dst(n plumbing.ReferenceName) plumbing.ReferenceName {
+ spec := string(s)
+ start := strings.Index(spec, refSpecSeparator) + 1
+ dst := spec[start:]
+ src := s.Src()
+
+ if !s.IsWildcard() {
+ return plumbing.ReferenceName(dst)
+ }
+
+ name := n.String()
+ ws := strings.Index(src, refSpecWildcard)
+ wd := strings.Index(dst, refSpecWildcard)
+ match := name[ws : len(name)-(len(src)-(ws+1))]
+
+ return plumbing.ReferenceName(dst[0:wd] + match + dst[wd+1:])
+}
+
+func (s RefSpec) String() string {
+ return string(s)
+}
+
+// MatchAny returns true if any of the RefSpec match with the given ReferenceName.
+func MatchAny(l []RefSpec, n plumbing.ReferenceName) bool {
+ for _, r := range l {
+ if r.Match(n) {
+ return true
+ }
+ }
+
+ return false
+}
--- /dev/null
+// A highly extensible git implementation in pure Go.
+//
+// go-git aims to reach the completeness of libgit2 or jgit, nowadays covers the
+// majority of the plumbing read operations and some of the main write
+// operations, but lacks the main porcelain operations such as merges.
+//
+// It is highly extensible, we have been following the open/close principle in
+// its design to facilitate extensions, mainly focusing the efforts on the
+// persistence of the objects.
+package git // import "gopkg.in/src-d/go-git.v4"
--- /dev/null
+// Package revision extracts git revision from string
+// More informations about revision : https://www.kernel.org/pub/software/scm/git/docs/gitrevisions.html
+package revision
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "regexp"
+ "strconv"
+ "time"
+)
+
+// ErrInvalidRevision is emitted if string doesn't match valid revision
+type ErrInvalidRevision struct {
+ s string
+}
+
+func (e *ErrInvalidRevision) Error() string {
+ return "Revision invalid : " + e.s
+}
+
+// Revisioner represents a revision component.
+// A revision is made of multiple revision components
+// obtained after parsing a revision string,
+// for instance revision "master~" will be converted in
+// two revision components Ref and TildePath
+type Revisioner interface {
+}
+
+// Ref represents a reference name : HEAD, master
+type Ref string
+
+// TildePath represents ~, ~{n}
+type TildePath struct {
+ Depth int
+}
+
+// CaretPath represents ^, ^{n}
+type CaretPath struct {
+ Depth int
+}
+
+// CaretReg represents ^{/foo bar}
+type CaretReg struct {
+ Regexp *regexp.Regexp
+ Negate bool
+}
+
+// CaretType represents ^{commit}
+type CaretType struct {
+ ObjectType string
+}
+
+// AtReflog represents @{n}
+type AtReflog struct {
+ Depth int
+}
+
+// AtCheckout represents @{-n}
+type AtCheckout struct {
+ Depth int
+}
+
+// AtUpstream represents @{upstream}, @{u}
+type AtUpstream struct {
+ BranchName string
+}
+
+// AtPush represents @{push}
+type AtPush struct {
+ BranchName string
+}
+
+// AtDate represents @{"2006-01-02T15:04:05Z"}
+type AtDate struct {
+ Date time.Time
+}
+
+// ColonReg represents :/foo bar
+type ColonReg struct {
+ Regexp *regexp.Regexp
+ Negate bool
+}
+
+// ColonPath represents :./<path> :<path>
+type ColonPath struct {
+ Path string
+}
+
+// ColonStagePath represents :<n>:/<path>
+type ColonStagePath struct {
+ Path string
+ Stage int
+}
+
+// Parser represents a parser
+// use to tokenize and transform to revisioner chunks
+// a given string
+type Parser struct {
+ s *scanner
+ currentParsedChar struct {
+ tok token
+ lit string
+ }
+ unreadLastChar bool
+}
+
+// NewParserFromString returns a new instance of parser from a string.
+func NewParserFromString(s string) *Parser {
+ return NewParser(bytes.NewBufferString(s))
+}
+
+// NewParser returns a new instance of parser.
+func NewParser(r io.Reader) *Parser {
+ return &Parser{s: newScanner(r)}
+}
+
+// scan returns the next token from the underlying scanner
+// or the last scanned token if an unscan was requested
+func (p *Parser) scan() (token, string, error) {
+ if p.unreadLastChar {
+ p.unreadLastChar = false
+ return p.currentParsedChar.tok, p.currentParsedChar.lit, nil
+ }
+
+ tok, lit, err := p.s.scan()
+
+ p.currentParsedChar.tok, p.currentParsedChar.lit = tok, lit
+
+ return tok, lit, err
+}
+
+// unscan pushes the previously read token back onto the buffer.
+func (p *Parser) unscan() { p.unreadLastChar = true }
+
+// Parse explode a revision string into revisioner chunks
+func (p *Parser) Parse() ([]Revisioner, error) {
+ var rev Revisioner
+ var revs []Revisioner
+ var tok token
+ var err error
+
+ for {
+ tok, _, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch tok {
+ case at:
+ rev, err = p.parseAt()
+ case tilde:
+ rev, err = p.parseTilde()
+ case caret:
+ rev, err = p.parseCaret()
+ case colon:
+ rev, err = p.parseColon()
+ case eof:
+ err = p.validateFullRevision(&revs)
+
+ if err != nil {
+ return []Revisioner{}, err
+ }
+
+ return revs, nil
+ default:
+ p.unscan()
+ rev, err = p.parseRef()
+ }
+
+ if err != nil {
+ return []Revisioner{}, err
+ }
+
+ revs = append(revs, rev)
+ }
+}
+
+// validateFullRevision ensures all revisioner chunks make a valid revision
+func (p *Parser) validateFullRevision(chunks *[]Revisioner) error {
+ var hasReference bool
+
+ for i, chunk := range *chunks {
+ switch chunk.(type) {
+ case Ref:
+ if i == 0 {
+ hasReference = true
+ } else {
+ return &ErrInvalidRevision{`reference must be defined once at the beginning`}
+ }
+ case AtDate:
+ if len(*chunks) == 1 || hasReference && len(*chunks) == 2 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`"@" statement is not valid, could be : <refname>@{<ISO-8601 date>}, @{<ISO-8601 date>}`}
+ case AtReflog:
+ if len(*chunks) == 1 || hasReference && len(*chunks) == 2 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`"@" statement is not valid, could be : <refname>@{<n>}, @{<n>}`}
+ case AtCheckout:
+ if len(*chunks) == 1 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`"@" statement is not valid, could be : @{-<n>}`}
+ case AtUpstream:
+ if len(*chunks) == 1 || hasReference && len(*chunks) == 2 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`"@" statement is not valid, could be : <refname>@{upstream}, @{upstream}, <refname>@{u}, @{u}`}
+ case AtPush:
+ if len(*chunks) == 1 || hasReference && len(*chunks) == 2 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`"@" statement is not valid, could be : <refname>@{push}, @{push}`}
+ case TildePath, CaretPath, CaretReg:
+ if !hasReference {
+ return &ErrInvalidRevision{`"~" or "^" statement must have a reference defined at the beginning`}
+ }
+ case ColonReg:
+ if len(*chunks) == 1 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`":" statement is not valid, could be : :/<regexp>`}
+ case ColonPath:
+ if i == len(*chunks)-1 && hasReference || len(*chunks) == 1 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`":" statement is not valid, could be : <revision>:<path>`}
+ case ColonStagePath:
+ if len(*chunks) == 1 {
+ return nil
+ }
+
+ return &ErrInvalidRevision{`":" statement is not valid, could be : :<n>:<path>`}
+ }
+ }
+
+ return nil
+}
+
+// parseAt extract @ statements
+func (p *Parser) parseAt() (Revisioner, error) {
+ var tok, nextTok token
+ var lit, nextLit string
+ var err error
+
+ tok, _, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ if tok != obrace {
+ p.unscan()
+
+ return Ref("HEAD"), nil
+ }
+
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ nextTok, nextLit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == word && (lit == "u" || lit == "upstream") && nextTok == cbrace:
+ return AtUpstream{}, nil
+ case tok == word && lit == "push" && nextTok == cbrace:
+ return AtPush{}, nil
+ case tok == number && nextTok == cbrace:
+ n, _ := strconv.Atoi(lit)
+
+ return AtReflog{n}, nil
+ case tok == minus && nextTok == number:
+ n, _ := strconv.Atoi(nextLit)
+
+ t, _, err := p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ if t != cbrace {
+ return nil, &ErrInvalidRevision{fmt.Sprintf(`missing "}" in @{-n} structure`)}
+ }
+
+ return AtCheckout{n}, nil
+ default:
+ p.unscan()
+
+ date := lit
+
+ for {
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == cbrace:
+ t, err := time.Parse("2006-01-02T15:04:05Z", date)
+
+ if err != nil {
+ return nil, &ErrInvalidRevision{fmt.Sprintf(`wrong date "%s" must fit ISO-8601 format : 2006-01-02T15:04:05Z`, date)}
+ }
+
+ return AtDate{t}, nil
+ default:
+ date += lit
+ }
+ }
+ }
+}
+
+// parseTilde extract ~ statements
+func (p *Parser) parseTilde() (Revisioner, error) {
+ var tok token
+ var lit string
+ var err error
+
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == number:
+ n, _ := strconv.Atoi(lit)
+
+ return TildePath{n}, nil
+ default:
+ p.unscan()
+ return TildePath{1}, nil
+ }
+}
+
+// parseCaret extract ^ statements
+func (p *Parser) parseCaret() (Revisioner, error) {
+ var tok token
+ var lit string
+ var err error
+
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == obrace:
+ r, err := p.parseCaretBraces()
+
+ if err != nil {
+ return nil, err
+ }
+
+ return r, nil
+ case tok == number:
+ n, _ := strconv.Atoi(lit)
+
+ if n > 2 {
+ return nil, &ErrInvalidRevision{fmt.Sprintf(`"%s" found must be 0, 1 or 2 after "^"`, lit)}
+ }
+
+ return CaretPath{n}, nil
+ default:
+ p.unscan()
+ return CaretPath{1}, nil
+ }
+}
+
+// parseCaretBraces extract ^{<data>} statements
+func (p *Parser) parseCaretBraces() (Revisioner, error) {
+ var tok, nextTok token
+ var lit, _ string
+ start := true
+ var re string
+ var negate bool
+ var err error
+
+ for {
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ nextTok, _, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == word && nextTok == cbrace && (lit == "commit" || lit == "tree" || lit == "blob" || lit == "tag" || lit == "object"):
+ return CaretType{lit}, nil
+ case re == "" && tok == cbrace:
+ return CaretType{"tag"}, nil
+ case re == "" && tok == emark && nextTok == emark:
+ re += lit
+ case re == "" && tok == emark && nextTok == minus:
+ negate = true
+ case re == "" && tok == emark:
+ return nil, &ErrInvalidRevision{fmt.Sprintf(`revision suffix brace component sequences starting with "/!" others than those defined are reserved`)}
+ case re == "" && tok == slash:
+ p.unscan()
+ case tok != slash && start:
+ return nil, &ErrInvalidRevision{fmt.Sprintf(`"%s" is not a valid revision suffix brace component`, lit)}
+ case tok != cbrace:
+ p.unscan()
+ re += lit
+ case tok == cbrace:
+ p.unscan()
+
+ reg, err := regexp.Compile(re)
+
+ if err != nil {
+ return CaretReg{}, &ErrInvalidRevision{fmt.Sprintf(`revision suffix brace component, %s`, err.Error())}
+ }
+
+ return CaretReg{reg, negate}, nil
+ }
+
+ start = false
+ }
+}
+
+// parseColon extract : statements
+func (p *Parser) parseColon() (Revisioner, error) {
+ var tok token
+ var err error
+
+ tok, _, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch tok {
+ case slash:
+ return p.parseColonSlash()
+ default:
+ p.unscan()
+ return p.parseColonDefault()
+ }
+}
+
+// parseColonSlash extract :/<data> statements
+func (p *Parser) parseColonSlash() (Revisioner, error) {
+ var tok, nextTok token
+ var lit string
+ var re string
+ var negate bool
+ var err error
+
+ for {
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ nextTok, _, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == emark && nextTok == emark:
+ re += lit
+ case re == "" && tok == emark && nextTok == minus:
+ negate = true
+ case re == "" && tok == emark:
+ return nil, &ErrInvalidRevision{fmt.Sprintf(`revision suffix brace component sequences starting with "/!" others than those defined are reserved`)}
+ case tok == eof:
+ p.unscan()
+ reg, err := regexp.Compile(re)
+
+ if err != nil {
+ return ColonReg{}, &ErrInvalidRevision{fmt.Sprintf(`revision suffix brace component, %s`, err.Error())}
+ }
+
+ return ColonReg{reg, negate}, nil
+ default:
+ p.unscan()
+ re += lit
+ }
+ }
+}
+
+// parseColonDefault extract :<data> statements
+func (p *Parser) parseColonDefault() (Revisioner, error) {
+ var tok token
+ var lit string
+ var path string
+ var stage int
+ var err error
+ var n = -1
+
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ nextTok, _, err := p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ if tok == number && nextTok == colon {
+ n, _ = strconv.Atoi(lit)
+ }
+
+ switch n {
+ case 0, 1, 2, 3:
+ stage = n
+ default:
+ path += lit
+ p.unscan()
+ }
+
+ for {
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case tok == eof && n == -1:
+ return ColonPath{path}, nil
+ case tok == eof:
+ return ColonStagePath{path, stage}, nil
+ default:
+ path += lit
+ }
+ }
+}
+
+// parseRef extract reference name
+func (p *Parser) parseRef() (Revisioner, error) {
+ var tok, prevTok token
+ var lit, buf string
+ var endOfRef bool
+ var err error
+
+ for {
+ tok, lit, err = p.scan()
+
+ if err != nil {
+ return nil, err
+ }
+
+ switch tok {
+ case eof, at, colon, tilde, caret:
+ endOfRef = true
+ }
+
+ err := p.checkRefFormat(tok, lit, prevTok, buf, endOfRef)
+
+ if err != nil {
+ return "", err
+ }
+
+ if endOfRef {
+ p.unscan()
+ return Ref(buf), nil
+ }
+
+ buf += lit
+ prevTok = tok
+ }
+}
+
+// checkRefFormat ensure reference name follow rules defined here :
+// https://git-scm.com/docs/git-check-ref-format
+func (p *Parser) checkRefFormat(token token, literal string, previousToken token, buffer string, endOfRef bool) error {
+ switch token {
+ case aslash, space, control, qmark, asterisk, obracket:
+ return &ErrInvalidRevision{fmt.Sprintf(`must not contains "%s"`, literal)}
+ }
+
+ switch {
+ case (token == dot || token == slash) && buffer == "":
+ return &ErrInvalidRevision{fmt.Sprintf(`must not start with "%s"`, literal)}
+ case previousToken == slash && endOfRef:
+ return &ErrInvalidRevision{`must not end with "/"`}
+ case previousToken == dot && endOfRef:
+ return &ErrInvalidRevision{`must not end with "."`}
+ case token == dot && previousToken == slash:
+ return &ErrInvalidRevision{`must not contains "/."`}
+ case previousToken == dot && token == dot:
+ return &ErrInvalidRevision{`must not contains ".."`}
+ case previousToken == slash && token == slash:
+ return &ErrInvalidRevision{`must not contains consecutively "/"`}
+ case (token == slash || endOfRef) && len(buffer) > 4 && buffer[len(buffer)-5:] == ".lock":
+ return &ErrInvalidRevision{"cannot end with .lock"}
+ }
+
+ return nil
+}
--- /dev/null
+package revision
+
+import (
+ "bufio"
+ "io"
+ "unicode"
+)
+
+// runeCategoryValidator takes a rune as input and
+// validates it belongs to a rune category
+type runeCategoryValidator func(r rune) bool
+
+// tokenizeExpression aggegates a series of runes matching check predicate into a single
+// string and provides given tokenType as token type
+func tokenizeExpression(ch rune, tokenType token, check runeCategoryValidator, r *bufio.Reader) (token, string, error) {
+ var data []rune
+ data = append(data, ch)
+
+ for {
+ c, _, err := r.ReadRune()
+
+ if c == zeroRune {
+ break
+ }
+
+ if err != nil {
+ return tokenError, "", err
+ }
+
+ if check(c) {
+ data = append(data, c)
+ } else {
+ err := r.UnreadRune()
+
+ if err != nil {
+ return tokenError, "", err
+ }
+
+ return tokenType, string(data), nil
+ }
+ }
+
+ return tokenType, string(data), nil
+}
+
+var zeroRune = rune(0)
+
+// scanner represents a lexical scanner.
+type scanner struct {
+ r *bufio.Reader
+}
+
+// newScanner returns a new instance of scanner.
+func newScanner(r io.Reader) *scanner {
+ return &scanner{r: bufio.NewReader(r)}
+}
+
+// Scan extracts tokens and their strings counterpart
+// from the reader
+func (s *scanner) scan() (token, string, error) {
+ ch, _, err := s.r.ReadRune()
+
+ if err != nil && err != io.EOF {
+ return tokenError, "", err
+ }
+
+ switch ch {
+ case zeroRune:
+ return eof, "", nil
+ case ':':
+ return colon, string(ch), nil
+ case '~':
+ return tilde, string(ch), nil
+ case '^':
+ return caret, string(ch), nil
+ case '.':
+ return dot, string(ch), nil
+ case '/':
+ return slash, string(ch), nil
+ case '{':
+ return obrace, string(ch), nil
+ case '}':
+ return cbrace, string(ch), nil
+ case '-':
+ return minus, string(ch), nil
+ case '@':
+ return at, string(ch), nil
+ case '\\':
+ return aslash, string(ch), nil
+ case '?':
+ return qmark, string(ch), nil
+ case '*':
+ return asterisk, string(ch), nil
+ case '[':
+ return obracket, string(ch), nil
+ case '!':
+ return emark, string(ch), nil
+ }
+
+ if unicode.IsSpace(ch) {
+ return space, string(ch), nil
+ }
+
+ if unicode.IsControl(ch) {
+ return control, string(ch), nil
+ }
+
+ if unicode.IsLetter(ch) {
+ return tokenizeExpression(ch, word, unicode.IsLetter, s.r)
+ }
+
+ if unicode.IsNumber(ch) {
+ return tokenizeExpression(ch, number, unicode.IsNumber, s.r)
+ }
+
+ return tokenError, string(ch), nil
+}
--- /dev/null
+package revision
+
+// token represents a entity extracted from string parsing
+type token int
+
+const (
+ eof token = iota
+
+ aslash
+ asterisk
+ at
+ caret
+ cbrace
+ colon
+ control
+ dot
+ emark
+ minus
+ number
+ obrace
+ obracket
+ qmark
+ slash
+ space
+ tilde
+ tokenError
+ word
+)
--- /dev/null
+package git
+
+import (
+ "fmt"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/storage"
+)
+
+type objectWalker struct {
+ Storer storage.Storer
+ // seen is the set of objects seen in the repo.
+ // seen map can become huge if walking over large
+ // repos. Thus using struct{} as the value type.
+ seen map[plumbing.Hash]struct{}
+}
+
+func newObjectWalker(s storage.Storer) *objectWalker {
+ return &objectWalker{s, map[plumbing.Hash]struct{}{}}
+}
+
+// walkAllRefs walks all (hash) refererences from the repo.
+func (p *objectWalker) walkAllRefs() error {
+ // Walk over all the references in the repo.
+ it, err := p.Storer.IterReferences()
+ if err != nil {
+ return err
+ }
+ defer it.Close()
+ err = it.ForEach(func(ref *plumbing.Reference) error {
+ // Exit this iteration early for non-hash references.
+ if ref.Type() != plumbing.HashReference {
+ return nil
+ }
+ return p.walkObjectTree(ref.Hash())
+ })
+ return err
+}
+
+func (p *objectWalker) isSeen(hash plumbing.Hash) bool {
+ _, seen := p.seen[hash]
+ return seen
+}
+
+func (p *objectWalker) add(hash plumbing.Hash) {
+ p.seen[hash] = struct{}{}
+}
+
+// walkObjectTree walks over all objects and remembers references
+// to them in the objectWalker. This is used instead of the revlist
+// walks because memory usage is tight with huge repos.
+func (p *objectWalker) walkObjectTree(hash plumbing.Hash) error {
+ // Check if we have already seen, and mark this object
+ if p.isSeen(hash) {
+ return nil
+ }
+ p.add(hash)
+ // Fetch the object.
+ obj, err := object.GetObject(p.Storer, hash)
+ if err != nil {
+ return fmt.Errorf("Getting object %s failed: %v", hash, err)
+ }
+ // Walk all children depending on object type.
+ switch obj := obj.(type) {
+ case *object.Commit:
+ err = p.walkObjectTree(obj.TreeHash)
+ if err != nil {
+ return err
+ }
+ for _, h := range obj.ParentHashes {
+ err = p.walkObjectTree(h)
+ if err != nil {
+ return err
+ }
+ }
+ case *object.Tree:
+ for i := range obj.Entries {
+ // Shortcut for blob objects:
+ // 'or' the lower bits of a mode and check that it
+ // it matches a filemode.Executable. The type information
+ // is in the higher bits, but this is the cleanest way
+ // to handle plain files with different modes.
+ // Other non-tree objects are somewhat rare, so they
+ // are not special-cased.
+ if obj.Entries[i].Mode|0755 == filemode.Executable {
+ p.add(obj.Entries[i].Hash)
+ continue
+ }
+ // Normal walk for sub-trees (and symlinks etc).
+ err = p.walkObjectTree(obj.Entries[i].Hash)
+ if err != nil {
+ return err
+ }
+ }
+ case *object.Tag:
+ return p.walkObjectTree(obj.Target)
+ default:
+ // Error out on unhandled object types.
+ return fmt.Errorf("Unknown object %X %s %T\n", obj.ID(), obj.Type(), obj)
+ }
+ return nil
+}
--- /dev/null
+package git
+
+import (
+ "errors"
+ "regexp"
+ "strings"
+
+ "golang.org/x/crypto/openpgp"
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/sideband"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+)
+
+// SubmoduleRescursivity defines how depth will affect any submodule recursive
+// operation.
+type SubmoduleRescursivity uint
+
+const (
+ // DefaultRemoteName name of the default Remote, just like git command.
+ DefaultRemoteName = "origin"
+
+ // NoRecurseSubmodules disables the recursion for a submodule operation.
+ NoRecurseSubmodules SubmoduleRescursivity = 0
+ // DefaultSubmoduleRecursionDepth allow recursion in a submodule operation.
+ DefaultSubmoduleRecursionDepth SubmoduleRescursivity = 10
+)
+
+var (
+ ErrMissingURL = errors.New("URL field is required")
+)
+
+// CloneOptions describes how a clone should be performed.
+type CloneOptions struct {
+ // The (possibly remote) repository URL to clone from.
+ URL string
+ // Auth credentials, if required, to use with the remote repository.
+ Auth transport.AuthMethod
+ // Name of the remote to be added, by default `origin`.
+ RemoteName string
+ // Remote branch to clone.
+ ReferenceName plumbing.ReferenceName
+ // Fetch only ReferenceName if true.
+ SingleBranch bool
+ // No checkout of HEAD after clone if true.
+ NoCheckout bool
+ // Limit fetching to the specified number of commits.
+ Depth int
+ // RecurseSubmodules after the clone is created, initialize all submodules
+ // within, using their default settings. This option is ignored if the
+ // cloned repository does not have a worktree.
+ RecurseSubmodules SubmoduleRescursivity
+ // Progress is where the human readable information sent by the server is
+ // stored, if nil nothing is stored and the capability (if supported)
+ // no-progress, is sent to the server to avoid send this information.
+ Progress sideband.Progress
+ // Tags describe how the tags will be fetched from the remote repository,
+ // by default is AllTags.
+ Tags TagMode
+}
+
+// Validate validates the fields and sets the default values.
+func (o *CloneOptions) Validate() error {
+ if o.URL == "" {
+ return ErrMissingURL
+ }
+
+ if o.RemoteName == "" {
+ o.RemoteName = DefaultRemoteName
+ }
+
+ if o.ReferenceName == "" {
+ o.ReferenceName = plumbing.HEAD
+ }
+
+ if o.Tags == InvalidTagMode {
+ o.Tags = AllTags
+ }
+
+ return nil
+}
+
+// PullOptions describes how a pull should be performed.
+type PullOptions struct {
+ // Name of the remote to be pulled. If empty, uses the default.
+ RemoteName string
+ // Remote branch to clone. If empty, uses HEAD.
+ ReferenceName plumbing.ReferenceName
+ // Fetch only ReferenceName if true.
+ SingleBranch bool
+ // Limit fetching to the specified number of commits.
+ Depth int
+ // Auth credentials, if required, to use with the remote repository.
+ Auth transport.AuthMethod
+ // RecurseSubmodules controls if new commits of all populated submodules
+ // should be fetched too.
+ RecurseSubmodules SubmoduleRescursivity
+ // Progress is where the human readable information sent by the server is
+ // stored, if nil nothing is stored and the capability (if supported)
+ // no-progress, is sent to the server to avoid send this information.
+ Progress sideband.Progress
+ // Force allows the pull to update a local branch even when the remote
+ // branch does not descend from it.
+ Force bool
+}
+
+// Validate validates the fields and sets the default values.
+func (o *PullOptions) Validate() error {
+ if o.RemoteName == "" {
+ o.RemoteName = DefaultRemoteName
+ }
+
+ if o.ReferenceName == "" {
+ o.ReferenceName = plumbing.HEAD
+ }
+
+ return nil
+}
+
+type TagMode int
+
+const (
+ InvalidTagMode TagMode = iota
+ // TagFollowing any tag that points into the histories being fetched is also
+ // fetched. TagFollowing requires a server with `include-tag` capability
+ // in order to fetch the annotated tags objects.
+ TagFollowing
+ // AllTags fetch all tags from the remote (i.e., fetch remote tags
+ // refs/tags/* into local tags with the same name)
+ AllTags
+ //NoTags fetch no tags from the remote at all
+ NoTags
+)
+
+// FetchOptions describes how a fetch should be performed
+type FetchOptions struct {
+ // Name of the remote to fetch from. Defaults to origin.
+ RemoteName string
+ RefSpecs []config.RefSpec
+ // Depth limit fetching to the specified number of commits from the tip of
+ // each remote branch history.
+ Depth int
+ // Auth credentials, if required, to use with the remote repository.
+ Auth transport.AuthMethod
+ // Progress is where the human readable information sent by the server is
+ // stored, if nil nothing is stored and the capability (if supported)
+ // no-progress, is sent to the server to avoid send this information.
+ Progress sideband.Progress
+ // Tags describe how the tags will be fetched from the remote repository,
+ // by default is TagFollowing.
+ Tags TagMode
+ // Force allows the fetch to update a local branch even when the remote
+ // branch does not descend from it.
+ Force bool
+}
+
+// Validate validates the fields and sets the default values.
+func (o *FetchOptions) Validate() error {
+ if o.RemoteName == "" {
+ o.RemoteName = DefaultRemoteName
+ }
+
+ if o.Tags == InvalidTagMode {
+ o.Tags = TagFollowing
+ }
+
+ for _, r := range o.RefSpecs {
+ if err := r.Validate(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// PushOptions describes how a push should be performed.
+type PushOptions struct {
+ // RemoteName is the name of the remote to be pushed to.
+ RemoteName string
+ // RefSpecs specify what destination ref to update with what source
+ // object. A refspec with empty src can be used to delete a reference.
+ RefSpecs []config.RefSpec
+ // Auth credentials, if required, to use with the remote repository.
+ Auth transport.AuthMethod
+ // Progress is where the human readable information sent by the server is
+ // stored, if nil nothing is stored.
+ Progress sideband.Progress
+}
+
+// Validate validates the fields and sets the default values.
+func (o *PushOptions) Validate() error {
+ if o.RemoteName == "" {
+ o.RemoteName = DefaultRemoteName
+ }
+
+ if len(o.RefSpecs) == 0 {
+ o.RefSpecs = []config.RefSpec{
+ config.RefSpec(config.DefaultPushRefSpec),
+ }
+ }
+
+ for _, r := range o.RefSpecs {
+ if err := r.Validate(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// SubmoduleUpdateOptions describes how a submodule update should be performed.
+type SubmoduleUpdateOptions struct {
+ // Init, if true initializes the submodules recorded in the index.
+ Init bool
+ // NoFetch tell to the update command to not fetch new objects from the
+ // remote site.
+ NoFetch bool
+ // RecurseSubmodules the update is performed not only in the submodules of
+ // the current repository but also in any nested submodules inside those
+ // submodules (and so on). Until the SubmoduleRescursivity is reached.
+ RecurseSubmodules SubmoduleRescursivity
+ // Auth credentials, if required, to use with the remote repository.
+ Auth transport.AuthMethod
+}
+
+var (
+ ErrBranchHashExclusive = errors.New("Branch and Hash are mutually exclusive")
+ ErrCreateRequiresBranch = errors.New("Branch is mandatory when Create is used")
+)
+
+// CheckoutOptions describes how a checkout 31operation should be performed.
+type CheckoutOptions struct {
+ // Hash is the hash of the commit to be checked out. If used, HEAD will be
+ // in detached mode. If Create is not used, Branch and Hash are mutually
+ // exclusive.
+ Hash plumbing.Hash
+ // Branch to be checked out, if Branch and Hash are empty is set to `master`.
+ Branch plumbing.ReferenceName
+ // Create a new branch named Branch and start it at Hash.
+ Create bool
+ // Force, if true when switching branches, proceed even if the index or the
+ // working tree differs from HEAD. This is used to throw away local changes
+ Force bool
+}
+
+// Validate validates the fields and sets the default values.
+func (o *CheckoutOptions) Validate() error {
+ if !o.Create && !o.Hash.IsZero() && o.Branch != "" {
+ return ErrBranchHashExclusive
+ }
+
+ if o.Create && o.Branch == "" {
+ return ErrCreateRequiresBranch
+ }
+
+ if o.Branch == "" {
+ o.Branch = plumbing.Master
+ }
+
+ return nil
+}
+
+// ResetMode defines the mode of a reset operation.
+type ResetMode int8
+
+const (
+ // MixedReset resets the index but not the working tree (i.e., the changed
+ // files are preserved but not marked for commit) and reports what has not
+ // been updated. This is the default action.
+ MixedReset ResetMode = iota
+ // HardReset resets the index and working tree. Any changes to tracked files
+ // in the working tree are discarded.
+ HardReset
+ // MergeReset resets the index and updates the files in the working tree
+ // that are different between Commit and HEAD, but keeps those which are
+ // different between the index and working tree (i.e. which have changes
+ // which have not been added).
+ //
+ // If a file that is different between Commit and the index has unstaged
+ // changes, reset is aborted.
+ MergeReset
+ // SoftReset does not touch the index file or the working tree at all (but
+ // resets the head to <commit>, just like all modes do). This leaves all
+ // your changed files "Changes to be committed", as git status would put it.
+ SoftReset
+)
+
+// ResetOptions describes how a reset operation should be performed.
+type ResetOptions struct {
+ // Commit, if commit is pressent set the current branch head (HEAD) to it.
+ Commit plumbing.Hash
+ // Mode, form resets the current branch head to Commit and possibly updates
+ // the index (resetting it to the tree of Commit) and the working tree
+ // depending on Mode. If empty MixedReset is used.
+ Mode ResetMode
+}
+
+// Validate validates the fields and sets the default values.
+func (o *ResetOptions) Validate(r *Repository) error {
+ if o.Commit == plumbing.ZeroHash {
+ ref, err := r.Head()
+ if err != nil {
+ return err
+ }
+
+ o.Commit = ref.Hash()
+ }
+
+ return nil
+}
+
+type LogOrder int8
+
+const (
+ LogOrderDefault LogOrder = iota
+ LogOrderDFS
+ LogOrderDFSPost
+ LogOrderBSF
+ LogOrderCommitterTime
+)
+
+// LogOptions describes how a log action should be performed.
+type LogOptions struct {
+ // When the From option is set the log will only contain commits
+ // reachable from it. If this option is not set, HEAD will be used as
+ // the default From.
+ From plumbing.Hash
+
+ // The default traversal algorithm is Depth-first search
+ // set Order=LogOrderCommitterTime for ordering by committer time (more compatible with `git log`)
+ // set Order=LogOrderBSF for Breadth-first search
+ Order LogOrder
+
+ // Show only those commits in which the specified file was inserted/updated.
+ // It is equivalent to running `git log -- <file-name>`.
+ FileName *string
+}
+
+var (
+ ErrMissingAuthor = errors.New("author field is required")
+)
+
+// CommitOptions describes how a commit operation should be performed.
+type CommitOptions struct {
+ // All automatically stage files that have been modified and deleted, but
+ // new files you have not told Git about are not affected.
+ All bool
+ // Author is the author's signature of the commit.
+ Author *object.Signature
+ // Committer is the committer's signature of the commit. If Committer is
+ // nil the Author signature is used.
+ Committer *object.Signature
+ // Parents are the parents commits for the new commit, by default when
+ // len(Parents) is zero, the hash of HEAD reference is used.
+ Parents []plumbing.Hash
+ // SignKey denotes a key to sign the commit with. A nil value here means the
+ // commit will not be signed. The private key must be present and already
+ // decrypted.
+ SignKey *openpgp.Entity
+}
+
+// Validate validates the fields and sets the default values.
+func (o *CommitOptions) Validate(r *Repository) error {
+ if o.Author == nil {
+ return ErrMissingAuthor
+ }
+
+ if o.Committer == nil {
+ o.Committer = o.Author
+ }
+
+ if len(o.Parents) == 0 {
+ head, err := r.Head()
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return err
+ }
+
+ if head != nil {
+ o.Parents = []plumbing.Hash{head.Hash()}
+ }
+ }
+
+ return nil
+}
+
+var (
+ ErrMissingName = errors.New("name field is required")
+ ErrMissingTagger = errors.New("tagger field is required")
+ ErrMissingMessage = errors.New("message field is required")
+)
+
+// CreateTagOptions describes how a tag object should be created.
+type CreateTagOptions struct {
+ // Tagger defines the signature of the tag creator.
+ Tagger *object.Signature
+ // Message defines the annotation of the tag. It is canonicalized during
+ // validation into the format expected by git - no leading whitespace and
+ // ending in a newline.
+ Message string
+ // SignKey denotes a key to sign the tag with. A nil value here means the tag
+ // will not be signed. The private key must be present and already decrypted.
+ SignKey *openpgp.Entity
+}
+
+// Validate validates the fields and sets the default values.
+func (o *CreateTagOptions) Validate(r *Repository, hash plumbing.Hash) error {
+ if o.Tagger == nil {
+ return ErrMissingTagger
+ }
+
+ if o.Message == "" {
+ return ErrMissingMessage
+ }
+
+ // Canonicalize the message into the expected message format.
+ o.Message = strings.TrimSpace(o.Message) + "\n"
+
+ return nil
+}
+
+// ListOptions describes how a remote list should be performed.
+type ListOptions struct {
+ // Auth credentials, if required, to use with the remote repository.
+ Auth transport.AuthMethod
+}
+
+// CleanOptions describes how a clean should be performed.
+type CleanOptions struct {
+ Dir bool
+}
+
+// GrepOptions describes how a grep should be performed.
+type GrepOptions struct {
+ // Patterns are compiled Regexp objects to be matched.
+ Patterns []*regexp.Regexp
+ // InvertMatch selects non-matching lines.
+ InvertMatch bool
+ // CommitHash is the hash of the commit from which worktree should be derived.
+ CommitHash plumbing.Hash
+ // ReferenceName is the branch or tag name from which worktree should be derived.
+ ReferenceName plumbing.ReferenceName
+ // PathSpecs are compiled Regexp objects of pathspec to use in the matching.
+ PathSpecs []*regexp.Regexp
+}
+
+var (
+ ErrHashOrReference = errors.New("ambiguous options, only one of CommitHash or ReferenceName can be passed")
+)
+
+// Validate validates the fields and sets the default values.
+func (o *GrepOptions) Validate(w *Worktree) error {
+ if !o.CommitHash.IsZero() && o.ReferenceName != "" {
+ return ErrHashOrReference
+ }
+
+ // If none of CommitHash and ReferenceName are provided, set commit hash of
+ // the repository's head.
+ if o.CommitHash.IsZero() && o.ReferenceName == "" {
+ ref, err := w.r.Head()
+ if err != nil {
+ return err
+ }
+ o.CommitHash = ref.Hash()
+ }
+
+ return nil
+}
+
+// PlainOpenOptions describes how opening a plain repository should be
+// performed.
+type PlainOpenOptions struct {
+ // DetectDotGit defines whether parent directories should be
+ // walked until a .git directory or file is found.
+ DetectDotGit bool
+}
+
+// Validate validates the fields and sets the default values.
+func (o *PlainOpenOptions) Validate() error { return nil }
--- /dev/null
+package cache
+
+import (
+ "container/list"
+ "sync"
+)
+
+// BufferLRU implements an object cache with an LRU eviction policy and a
+// maximum size (measured in object size).
+type BufferLRU struct {
+ MaxSize FileSize
+
+ actualSize FileSize
+ ll *list.List
+ cache map[int64]*list.Element
+ mut sync.Mutex
+}
+
+// NewBufferLRU creates a new BufferLRU with the given maximum size. The maximum
+// size will never be exceeded.
+func NewBufferLRU(maxSize FileSize) *BufferLRU {
+ return &BufferLRU{MaxSize: maxSize}
+}
+
+// NewBufferLRUDefault creates a new BufferLRU with the default cache size.
+func NewBufferLRUDefault() *BufferLRU {
+ return &BufferLRU{MaxSize: DefaultMaxSize}
+}
+
+type buffer struct {
+ Key int64
+ Slice []byte
+}
+
+// Put puts a buffer into the cache. If the buffer is already in the cache, it
+// will be marked as used. Otherwise, it will be inserted. A buffers might
+// be evicted to make room for the new one.
+func (c *BufferLRU) Put(key int64, slice []byte) {
+ c.mut.Lock()
+ defer c.mut.Unlock()
+
+ if c.cache == nil {
+ c.actualSize = 0
+ c.cache = make(map[int64]*list.Element, 1000)
+ c.ll = list.New()
+ }
+
+ bufSize := FileSize(len(slice))
+ if ee, ok := c.cache[key]; ok {
+ oldBuf := ee.Value.(buffer)
+ // in this case bufSize is a delta: new size - old size
+ bufSize -= FileSize(len(oldBuf.Slice))
+ c.ll.MoveToFront(ee)
+ ee.Value = buffer{key, slice}
+ } else {
+ if bufSize > c.MaxSize {
+ return
+ }
+ ee := c.ll.PushFront(buffer{key, slice})
+ c.cache[key] = ee
+ }
+
+ c.actualSize += bufSize
+ for c.actualSize > c.MaxSize {
+ last := c.ll.Back()
+ lastObj := last.Value.(buffer)
+ lastSize := FileSize(len(lastObj.Slice))
+
+ c.ll.Remove(last)
+ delete(c.cache, lastObj.Key)
+ c.actualSize -= lastSize
+ }
+}
+
+// Get returns a buffer by its key. It marks the buffer as used. If the buffer
+// is not in the cache, (nil, false) will be returned.
+func (c *BufferLRU) Get(key int64) ([]byte, bool) {
+ c.mut.Lock()
+ defer c.mut.Unlock()
+
+ ee, ok := c.cache[key]
+ if !ok {
+ return nil, false
+ }
+
+ c.ll.MoveToFront(ee)
+ return ee.Value.(buffer).Slice, true
+}
+
+// Clear the content of this buffer cache.
+func (c *BufferLRU) Clear() {
+ c.mut.Lock()
+ defer c.mut.Unlock()
+
+ c.ll = nil
+ c.cache = nil
+ c.actualSize = 0
+}
--- /dev/null
+package cache
+
+import "gopkg.in/src-d/go-git.v4/plumbing"
+
+const (
+ Byte FileSize = 1 << (iota * 10)
+ KiByte
+ MiByte
+ GiByte
+)
+
+type FileSize int64
+
+const DefaultMaxSize FileSize = 96 * MiByte
+
+// Object is an interface to a object cache.
+type Object interface {
+ // Put puts the given object into the cache. Whether this object will
+ // actually be put into the cache or not is implementation specific.
+ Put(o plumbing.EncodedObject)
+ // Get gets an object from the cache given its hash. The second return value
+ // is true if the object was returned, and false otherwise.
+ Get(k plumbing.Hash) (plumbing.EncodedObject, bool)
+ // Clear clears every object from the cache.
+ Clear()
+}
+
+// Buffer is an interface to a buffer cache.
+type Buffer interface {
+ // Put puts a buffer into the cache. If the buffer is already in the cache,
+ // it will be marked as used. Otherwise, it will be inserted. Buffer might
+ // be evicted to make room for the new one.
+ Put(key int64, slice []byte)
+ // Get returns a buffer by its key. It marks the buffer as used. If the
+ // buffer is not in the cache, (nil, false) will be returned.
+ Get(key int64) ([]byte, bool)
+ // Clear clears every object from the cache.
+ Clear()
+}
--- /dev/null
+package cache
+
+import (
+ "container/list"
+ "sync"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// ObjectLRU implements an object cache with an LRU eviction policy and a
+// maximum size (measured in object size).
+type ObjectLRU struct {
+ MaxSize FileSize
+
+ actualSize FileSize
+ ll *list.List
+ cache map[interface{}]*list.Element
+ mut sync.Mutex
+}
+
+// NewObjectLRU creates a new ObjectLRU with the given maximum size. The maximum
+// size will never be exceeded.
+func NewObjectLRU(maxSize FileSize) *ObjectLRU {
+ return &ObjectLRU{MaxSize: maxSize}
+}
+
+// NewObjectLRUDefault creates a new ObjectLRU with the default cache size.
+func NewObjectLRUDefault() *ObjectLRU {
+ return &ObjectLRU{MaxSize: DefaultMaxSize}
+}
+
+// Put puts an object into the cache. If the object is already in the cache, it
+// will be marked as used. Otherwise, it will be inserted. A single object might
+// be evicted to make room for the new object.
+func (c *ObjectLRU) Put(obj plumbing.EncodedObject) {
+ c.mut.Lock()
+ defer c.mut.Unlock()
+
+ if c.cache == nil {
+ c.actualSize = 0
+ c.cache = make(map[interface{}]*list.Element, 1000)
+ c.ll = list.New()
+ }
+
+ objSize := FileSize(obj.Size())
+ key := obj.Hash()
+ if ee, ok := c.cache[key]; ok {
+ oldObj := ee.Value.(plumbing.EncodedObject)
+ // in this case objSize is a delta: new size - old size
+ objSize -= FileSize(oldObj.Size())
+ c.ll.MoveToFront(ee)
+ ee.Value = obj
+ } else {
+ if objSize > c.MaxSize {
+ return
+ }
+ ee := c.ll.PushFront(obj)
+ c.cache[key] = ee
+ }
+
+ c.actualSize += objSize
+ for c.actualSize > c.MaxSize {
+ last := c.ll.Back()
+ lastObj := last.Value.(plumbing.EncodedObject)
+ lastSize := FileSize(lastObj.Size())
+
+ c.ll.Remove(last)
+ delete(c.cache, lastObj.Hash())
+ c.actualSize -= lastSize
+ }
+}
+
+// Get returns an object by its hash. It marks the object as used. If the object
+// is not in the cache, (nil, false) will be returned.
+func (c *ObjectLRU) Get(k plumbing.Hash) (plumbing.EncodedObject, bool) {
+ c.mut.Lock()
+ defer c.mut.Unlock()
+
+ ee, ok := c.cache[k]
+ if !ok {
+ return nil, false
+ }
+
+ c.ll.MoveToFront(ee)
+ return ee.Value.(plumbing.EncodedObject), true
+}
+
+// Clear the content of this object cache.
+func (c *ObjectLRU) Clear() {
+ c.mut.Lock()
+ defer c.mut.Unlock()
+
+ c.ll = nil
+ c.cache = nil
+ c.actualSize = 0
+}
--- /dev/null
+package plumbing
+
+import "fmt"
+
+type PermanentError struct {
+ Err error
+}
+
+func NewPermanentError(err error) *PermanentError {
+ if err == nil {
+ return nil
+ }
+
+ return &PermanentError{Err: err}
+}
+
+func (e *PermanentError) Error() string {
+ return fmt.Sprintf("permanent client error: %s", e.Err.Error())
+}
+
+type UnexpectedError struct {
+ Err error
+}
+
+func NewUnexpectedError(err error) *UnexpectedError {
+ if err == nil {
+ return nil
+ }
+
+ return &UnexpectedError{Err: err}
+}
+
+func (e *UnexpectedError) Error() string {
+ return fmt.Sprintf("unexpected client error: %s", e.Err.Error())
+}
--- /dev/null
+package filemode
+
+import (
+ "encoding/binary"
+ "fmt"
+ "os"
+ "strconv"
+)
+
+// A FileMode represents the kind of tree entries used by git. It
+// resembles regular file systems modes, although FileModes are
+// considerably simpler (there are not so many), and there are some,
+// like Submodule that has no file system equivalent.
+type FileMode uint32
+
+const (
+ // Empty is used as the FileMode of tree elements when comparing
+ // trees in the following situations:
+ //
+ // - the mode of tree elements before their creation. - the mode of
+ // tree elements after their deletion. - the mode of unmerged
+ // elements when checking the index.
+ //
+ // Empty has no file system equivalent. As Empty is the zero value
+ // of FileMode, it is also returned by New and
+ // NewFromOsNewFromOSFileMode along with an error, when they fail.
+ Empty FileMode = 0
+ // Dir represent a Directory.
+ Dir FileMode = 0040000
+ // Regular represent non-executable files. Please note this is not
+ // the same as golang regular files, which include executable files.
+ Regular FileMode = 0100644
+ // Deprecated represent non-executable files with the group writable
+ // bit set. This mode was supported by the first versions of git,
+ // but it has been deprecatred nowadays. This library uses them
+ // internally, so you can read old packfiles, but will treat them as
+ // Regulars when interfacing with the outside world. This is the
+ // standard git behaviuor.
+ Deprecated FileMode = 0100664
+ // Executable represents executable files.
+ Executable FileMode = 0100755
+ // Symlink represents symbolic links to files.
+ Symlink FileMode = 0120000
+ // Submodule represents git submodules. This mode has no file system
+ // equivalent.
+ Submodule FileMode = 0160000
+)
+
+// New takes the octal string representation of a FileMode and returns
+// the FileMode and a nil error. If the string can not be parsed to a
+// 32 bit unsigned octal number, it returns Empty and the parsing error.
+//
+// Example: "40000" means Dir, "100644" means Regular.
+//
+// Please note this function does not check if the returned FileMode
+// is valid in git or if it is malformed. For instance, "1" will
+// return the malformed FileMode(1) and a nil error.
+func New(s string) (FileMode, error) {
+ n, err := strconv.ParseUint(s, 8, 32)
+ if err != nil {
+ return Empty, err
+ }
+
+ return FileMode(n), nil
+}
+
+// NewFromOSFileMode returns the FileMode used by git to represent
+// the provided file system modes and a nil error on success. If the
+// file system mode cannot be mapped to any valid git mode (as with
+// sockets or named pipes), it will return Empty and an error.
+//
+// Note that some git modes cannot be generated from os.FileModes, like
+// Deprecated and Submodule; while Empty will be returned, along with an
+// error, only when the method fails.
+func NewFromOSFileMode(m os.FileMode) (FileMode, error) {
+ if m.IsRegular() {
+ if isSetTemporary(m) {
+ return Empty, fmt.Errorf("no equivalent git mode for %s", m)
+ }
+ if isSetCharDevice(m) {
+ return Empty, fmt.Errorf("no equivalent git mode for %s", m)
+ }
+ if isSetUserExecutable(m) {
+ return Executable, nil
+ }
+ return Regular, nil
+ }
+
+ if m.IsDir() {
+ return Dir, nil
+ }
+
+ if isSetSymLink(m) {
+ return Symlink, nil
+ }
+
+ return Empty, fmt.Errorf("no equivalent git mode for %s", m)
+}
+
+func isSetCharDevice(m os.FileMode) bool {
+ return m&os.ModeCharDevice != 0
+}
+
+func isSetTemporary(m os.FileMode) bool {
+ return m&os.ModeTemporary != 0
+}
+
+func isSetUserExecutable(m os.FileMode) bool {
+ return m&0100 != 0
+}
+
+func isSetSymLink(m os.FileMode) bool {
+ return m&os.ModeSymlink != 0
+}
+
+// Bytes return a slice of 4 bytes with the mode in little endian
+// encoding.
+func (m FileMode) Bytes() []byte {
+ ret := make([]byte, 4)
+ binary.LittleEndian.PutUint32(ret, uint32(m))
+ return ret[:]
+}
+
+// IsMalformed returns if the FileMode should not appear in a git packfile,
+// this is: Empty and any other mode not mentioned as a constant in this
+// package.
+func (m FileMode) IsMalformed() bool {
+ return m != Dir &&
+ m != Regular &&
+ m != Deprecated &&
+ m != Executable &&
+ m != Symlink &&
+ m != Submodule
+}
+
+// String returns the FileMode as a string in the standatd git format,
+// this is, an octal number padded with ceros to 7 digits. Malformed
+// modes are printed in that same format, for easier debugging.
+//
+// Example: Regular is "0100644", Empty is "0000000".
+func (m FileMode) String() string {
+ return fmt.Sprintf("%07o", uint32(m))
+}
+
+// IsRegular returns if the FileMode represents that of a regular file,
+// this is, either Regular or Deprecated. Please note that Executable
+// are not regular even though in the UNIX tradition, they usually are:
+// See the IsFile method.
+func (m FileMode) IsRegular() bool {
+ return m == Regular ||
+ m == Deprecated
+}
+
+// IsFile returns if the FileMode represents that of a file, this is,
+// Regular, Deprecated, Excutable or Link.
+func (m FileMode) IsFile() bool {
+ return m == Regular ||
+ m == Deprecated ||
+ m == Executable ||
+ m == Symlink
+}
+
+// ToOSFileMode returns the os.FileMode to be used when creating file
+// system elements with the given git mode and a nil error on success.
+//
+// When the provided mode cannot be mapped to a valid file system mode
+// (e.g. Submodule) it returns os.FileMode(0) and an error.
+//
+// The returned file mode does not take into account the umask.
+func (m FileMode) ToOSFileMode() (os.FileMode, error) {
+ switch m {
+ case Dir:
+ return os.ModePerm | os.ModeDir, nil
+ case Submodule:
+ return os.ModePerm | os.ModeDir, nil
+ case Regular:
+ return os.FileMode(0644), nil
+ // Deprecated is no longer allowed: treated as a Regular instead
+ case Deprecated:
+ return os.FileMode(0644), nil
+ case Executable:
+ return os.FileMode(0755), nil
+ case Symlink:
+ return os.ModePerm | os.ModeSymlink, nil
+ }
+
+ return os.FileMode(0), fmt.Errorf("malformed mode (%s)", m)
+}
--- /dev/null
+package config
+
+// New creates a new config instance.
+func New() *Config {
+ return &Config{}
+}
+
+// Config contains all the sections, comments and includes from a config file.
+type Config struct {
+ Comment *Comment
+ Sections Sections
+ Includes Includes
+}
+
+// Includes is a list of Includes in a config file.
+type Includes []*Include
+
+// Include is a reference to an included config file.
+type Include struct {
+ Path string
+ Config *Config
+}
+
+// Comment string without the prefix '#' or ';'.
+type Comment string
+
+const (
+ // NoSubsection token is passed to Config.Section and Config.SetSection to
+ // represent the absence of a section.
+ NoSubsection = ""
+)
+
+// Section returns a existing section with the given name or creates a new one.
+func (c *Config) Section(name string) *Section {
+ for i := len(c.Sections) - 1; i >= 0; i-- {
+ s := c.Sections[i]
+ if s.IsName(name) {
+ return s
+ }
+ }
+
+ s := &Section{Name: name}
+ c.Sections = append(c.Sections, s)
+ return s
+}
+
+// AddOption adds an option to a given section and subsection. Use the
+// NoSubsection constant for the subsection argument if no subsection is wanted.
+func (c *Config) AddOption(section string, subsection string, key string, value string) *Config {
+ if subsection == "" {
+ c.Section(section).AddOption(key, value)
+ } else {
+ c.Section(section).Subsection(subsection).AddOption(key, value)
+ }
+
+ return c
+}
+
+// SetOption sets an option to a given section and subsection. Use the
+// NoSubsection constant for the subsection argument if no subsection is wanted.
+func (c *Config) SetOption(section string, subsection string, key string, value string) *Config {
+ if subsection == "" {
+ c.Section(section).SetOption(key, value)
+ } else {
+ c.Section(section).Subsection(subsection).SetOption(key, value)
+ }
+
+ return c
+}
+
+// RemoveSection removes a section from a config file.
+func (c *Config) RemoveSection(name string) *Config {
+ result := Sections{}
+ for _, s := range c.Sections {
+ if !s.IsName(name) {
+ result = append(result, s)
+ }
+ }
+
+ c.Sections = result
+ return c
+}
+
+// RemoveSubsection remove s a subsection from a config file.
+func (c *Config) RemoveSubsection(section string, subsection string) *Config {
+ for _, s := range c.Sections {
+ if s.IsName(section) {
+ result := Subsections{}
+ for _, ss := range s.Subsections {
+ if !ss.IsName(subsection) {
+ result = append(result, ss)
+ }
+ }
+ s.Subsections = result
+ }
+ }
+
+ return c
+}
--- /dev/null
+package config
+
+import (
+ "io"
+
+ "github.com/src-d/gcfg"
+)
+
+// A Decoder reads and decodes config files from an input stream.
+type Decoder struct {
+ io.Reader
+}
+
+// NewDecoder returns a new decoder that reads from r.
+func NewDecoder(r io.Reader) *Decoder {
+ return &Decoder{r}
+}
+
+// Decode reads the whole config from its input and stores it in the
+// value pointed to by config.
+func (d *Decoder) Decode(config *Config) error {
+ cb := func(s string, ss string, k string, v string, bv bool) error {
+ if ss == "" && k == "" {
+ config.Section(s)
+ return nil
+ }
+
+ if ss != "" && k == "" {
+ config.Section(s).Subsection(ss)
+ return nil
+ }
+
+ config.AddOption(s, ss, k, v)
+ return nil
+ }
+ return gcfg.ReadWithCallback(d, cb)
+}
--- /dev/null
+// Package config implements encoding and decoding of git config files.
+//
+// Configuration File
+// ------------------
+//
+// The Git configuration file contains a number of variables that affect
+// the Git commands' behavior. The `.git/config` file in each repository
+// is used to store the configuration for that repository, and
+// `$HOME/.gitconfig` is used to store a per-user configuration as
+// fallback values for the `.git/config` file. The file `/etc/gitconfig`
+// can be used to store a system-wide default configuration.
+//
+// The configuration variables are used by both the Git plumbing
+// and the porcelains. The variables are divided into sections, wherein
+// the fully qualified variable name of the variable itself is the last
+// dot-separated segment and the section name is everything before the last
+// dot. The variable names are case-insensitive, allow only alphanumeric
+// characters and `-`, and must start with an alphabetic character. Some
+// variables may appear multiple times; we say then that the variable is
+// multivalued.
+//
+// Syntax
+// ~~~~~~
+//
+// The syntax is fairly flexible and permissive; whitespaces are mostly
+// ignored. The '#' and ';' characters begin comments to the end of line,
+// blank lines are ignored.
+//
+// The file consists of sections and variables. A section begins with
+// the name of the section in square brackets and continues until the next
+// section begins. Section names are case-insensitive. Only alphanumeric
+// characters, `-` and `.` are allowed in section names. Each variable
+// must belong to some section, which means that there must be a section
+// header before the first setting of a variable.
+//
+// Sections can be further divided into subsections. To begin a subsection
+// put its name in double quotes, separated by space from the section name,
+// in the section header, like in the example below:
+//
+// --------
+// [section "subsection"]
+//
+// --------
+//
+// Subsection names are case sensitive and can contain any characters except
+// newline (doublequote `"` and backslash can be included by escaping them
+// as `\"` and `\\`, respectively). Section headers cannot span multiple
+// lines. Variables may belong directly to a section or to a given subsection.
+// You can have `[section]` if you have `[section "subsection"]`, but you
+// don't need to.
+//
+// There is also a deprecated `[section.subsection]` syntax. With this
+// syntax, the subsection name is converted to lower-case and is also
+// compared case sensitively. These subsection names follow the same
+// restrictions as section names.
+//
+// All the other lines (and the remainder of the line after the section
+// header) are recognized as setting variables, in the form
+// 'name = value' (or just 'name', which is a short-hand to say that
+// the variable is the boolean "true").
+// The variable names are case-insensitive, allow only alphanumeric characters
+// and `-`, and must start with an alphabetic character.
+//
+// A line that defines a value can be continued to the next line by
+// ending it with a `\`; the backquote and the end-of-line are
+// stripped. Leading whitespaces after 'name =', the remainder of the
+// line after the first comment character '#' or ';', and trailing
+// whitespaces of the line are discarded unless they are enclosed in
+// double quotes. Internal whitespaces within the value are retained
+// verbatim.
+//
+// Inside double quotes, double quote `"` and backslash `\` characters
+// must be escaped: use `\"` for `"` and `\\` for `\`.
+//
+// The following escape sequences (beside `\"` and `\\`) are recognized:
+// `\n` for newline character (NL), `\t` for horizontal tabulation (HT, TAB)
+// and `\b` for backspace (BS). Other char escape sequences (including octal
+// escape sequences) are invalid.
+//
+// Includes
+// ~~~~~~~~
+//
+// You can include one config file from another by setting the special
+// `include.path` variable to the name of the file to be included. The
+// variable takes a pathname as its value, and is subject to tilde
+// expansion.
+//
+// The included file is expanded immediately, as if its contents had been
+// found at the location of the include directive. If the value of the
+// `include.path` variable is a relative path, the path is considered to be
+// relative to the configuration file in which the include directive was
+// found. See below for examples.
+//
+//
+// Example
+// ~~~~~~~
+//
+// # Core variables
+// [core]
+// ; Don't trust file modes
+// filemode = false
+//
+// # Our diff algorithm
+// [diff]
+// external = /usr/local/bin/diff-wrapper
+// renames = true
+//
+// [branch "devel"]
+// remote = origin
+// merge = refs/heads/devel
+//
+// # Proxy settings
+// [core]
+// gitProxy="ssh" for "kernel.org"
+// gitProxy=default-proxy ; for the rest
+//
+// [include]
+// path = /path/to/foo.inc ; include by absolute path
+// path = foo ; expand "foo" relative to the current file
+// path = ~/foo ; expand "foo" in your `$HOME` directory
+//
+package config
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "io"
+ "strings"
+)
+
+// An Encoder writes config files to an output stream.
+type Encoder struct {
+ w io.Writer
+}
+
+// NewEncoder returns a new encoder that writes to w.
+func NewEncoder(w io.Writer) *Encoder {
+ return &Encoder{w}
+}
+
+// Encode writes the config in git config format to the stream of the encoder.
+func (e *Encoder) Encode(cfg *Config) error {
+ for _, s := range cfg.Sections {
+ if err := e.encodeSection(s); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (e *Encoder) encodeSection(s *Section) error {
+ if len(s.Options) > 0 {
+ if err := e.printf("[%s]\n", s.Name); err != nil {
+ return err
+ }
+
+ if err := e.encodeOptions(s.Options); err != nil {
+ return err
+ }
+ }
+
+ for _, ss := range s.Subsections {
+ if err := e.encodeSubsection(s.Name, ss); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (e *Encoder) encodeSubsection(sectionName string, s *Subsection) error {
+ //TODO: escape
+ if err := e.printf("[%s \"%s\"]\n", sectionName, s.Name); err != nil {
+ return err
+ }
+
+ return e.encodeOptions(s.Options)
+}
+
+func (e *Encoder) encodeOptions(opts Options) error {
+ for _, o := range opts {
+ pattern := "\t%s = %s\n"
+ if strings.Contains(o.Value, "\\") {
+ pattern = "\t%s = %q\n"
+ }
+
+ if err := e.printf(pattern, o.Key, o.Value); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (e *Encoder) printf(msg string, args ...interface{}) error {
+ _, err := fmt.Fprintf(e.w, msg, args...)
+ return err
+}
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Option defines a key/value entity in a config file.
+type Option struct {
+ // Key preserving original caseness.
+ // Use IsKey instead to compare key regardless of caseness.
+ Key string
+ // Original value as string, could be not normalized.
+ Value string
+}
+
+type Options []*Option
+
+// IsKey returns true if the given key matches
+// this option's key in a case-insensitive comparison.
+func (o *Option) IsKey(key string) bool {
+ return strings.ToLower(o.Key) == strings.ToLower(key)
+}
+
+func (opts Options) GoString() string {
+ var strs []string
+ for _, opt := range opts {
+ strs = append(strs, fmt.Sprintf("%#v", opt))
+ }
+
+ return strings.Join(strs, ", ")
+}
+
+// Get gets the value for the given key if set,
+// otherwise it returns the empty string.
+//
+// Note that there is no difference
+//
+// This matches git behaviour since git v1.8.1-rc1,
+// if there are multiple definitions of a key, the
+// last one wins.
+//
+// See: http://article.gmane.org/gmane.linux.kernel/1407184
+//
+// In order to get all possible values for the same key,
+// use GetAll.
+func (opts Options) Get(key string) string {
+ for i := len(opts) - 1; i >= 0; i-- {
+ o := opts[i]
+ if o.IsKey(key) {
+ return o.Value
+ }
+ }
+ return ""
+}
+
+// GetAll returns all possible values for the same key.
+func (opts Options) GetAll(key string) []string {
+ result := []string{}
+ for _, o := range opts {
+ if o.IsKey(key) {
+ result = append(result, o.Value)
+ }
+ }
+ return result
+}
+
+func (opts Options) withoutOption(key string) Options {
+ result := Options{}
+ for _, o := range opts {
+ if !o.IsKey(key) {
+ result = append(result, o)
+ }
+ }
+ return result
+}
+
+func (opts Options) withAddedOption(key string, value string) Options {
+ return append(opts, &Option{key, value})
+}
+
+func (opts Options) withSettedOption(key string, values ...string) Options {
+ var result Options
+ var added []string
+ for _, o := range opts {
+ if !o.IsKey(key) {
+ result = append(result, o)
+ continue
+ }
+
+ if contains(values, o.Value) {
+ added = append(added, o.Value)
+ result = append(result, o)
+ continue
+ }
+ }
+
+ for _, value := range values {
+ if contains(added, value) {
+ continue
+ }
+
+ result = result.withAddedOption(key, value)
+ }
+
+ return result
+}
+
+func contains(haystack []string, needle string) bool {
+ for _, s := range haystack {
+ if s == needle {
+ return true
+ }
+ }
+
+ return false
+}
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Section is the representation of a section inside git configuration files.
+// Each Section contains Options that are used by both the Git plumbing
+// and the porcelains.
+// Sections can be further divided into subsections. To begin a subsection
+// put its name in double quotes, separated by space from the section name,
+// in the section header, like in the example below:
+//
+// [section "subsection"]
+//
+// All the other lines (and the remainder of the line after the section header)
+// are recognized as option variables, in the form "name = value" (or just name,
+// which is a short-hand to say that the variable is the boolean "true").
+// The variable names are case-insensitive, allow only alphanumeric characters
+// and -, and must start with an alphabetic character:
+//
+// [section "subsection1"]
+// option1 = value1
+// option2
+// [section "subsection2"]
+// option3 = value2
+//
+type Section struct {
+ Name string
+ Options Options
+ Subsections Subsections
+}
+
+type Subsection struct {
+ Name string
+ Options Options
+}
+
+type Sections []*Section
+
+func (s Sections) GoString() string {
+ var strs []string
+ for _, ss := range s {
+ strs = append(strs, fmt.Sprintf("%#v", ss))
+ }
+
+ return strings.Join(strs, ", ")
+}
+
+type Subsections []*Subsection
+
+func (s Subsections) GoString() string {
+ var strs []string
+ for _, ss := range s {
+ strs = append(strs, fmt.Sprintf("%#v", ss))
+ }
+
+ return strings.Join(strs, ", ")
+}
+
+// IsName checks if the name provided is equals to the Section name, case insensitive.
+func (s *Section) IsName(name string) bool {
+ return strings.ToLower(s.Name) == strings.ToLower(name)
+}
+
+// Option return the value for the specified key. Empty string is returned if
+// key does not exists.
+func (s *Section) Option(key string) string {
+ return s.Options.Get(key)
+}
+
+// AddOption adds a new Option to the Section. The updated Section is returned.
+func (s *Section) AddOption(key string, value string) *Section {
+ s.Options = s.Options.withAddedOption(key, value)
+ return s
+}
+
+// SetOption adds a new Option to the Section. If the option already exists, is replaced.
+// The updated Section is returned.
+func (s *Section) SetOption(key string, value string) *Section {
+ s.Options = s.Options.withSettedOption(key, value)
+ return s
+}
+
+// Remove an option with the specified key. The updated Section is returned.
+func (s *Section) RemoveOption(key string) *Section {
+ s.Options = s.Options.withoutOption(key)
+ return s
+}
+
+// Subsection returns a Subsection from the specified Section. If the
+// Subsection does not exists, new one is created and added to Section.
+func (s *Section) Subsection(name string) *Subsection {
+ for i := len(s.Subsections) - 1; i >= 0; i-- {
+ ss := s.Subsections[i]
+ if ss.IsName(name) {
+ return ss
+ }
+ }
+
+ ss := &Subsection{Name: name}
+ s.Subsections = append(s.Subsections, ss)
+ return ss
+}
+
+// HasSubsection checks if the Section has a Subsection with the specified name.
+func (s *Section) HasSubsection(name string) bool {
+ for _, ss := range s.Subsections {
+ if ss.IsName(name) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// IsName checks if the name of the subsection is exactly the specified name.
+func (s *Subsection) IsName(name string) bool {
+ return s.Name == name
+}
+
+// Option returns an option with the specified key. If the option does not exists,
+// empty spring will be returned.
+func (s *Subsection) Option(key string) string {
+ return s.Options.Get(key)
+}
+
+// AddOption adds a new Option to the Subsection. The updated Subsection is returned.
+func (s *Subsection) AddOption(key string, value string) *Subsection {
+ s.Options = s.Options.withAddedOption(key, value)
+ return s
+}
+
+// SetOption adds a new Option to the Subsection. If the option already exists, is replaced.
+// The updated Subsection is returned.
+func (s *Subsection) SetOption(key string, value ...string) *Subsection {
+ s.Options = s.Options.withSettedOption(key, value...)
+ return s
+}
+
+// RemoveOption removes the option with the specified key. The updated Subsection is returned.
+func (s *Subsection) RemoveOption(key string) *Subsection {
+ s.Options = s.Options.withoutOption(key)
+ return s
+}
--- /dev/null
+package diff
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+)
+
+// Operation defines the operation of a diff item.
+type Operation int
+
+const (
+ // Equal item represents a equals diff.
+ Equal Operation = iota
+ // Add item represents an insert diff.
+ Add
+ // Delete item represents a delete diff.
+ Delete
+)
+
+// Patch represents a collection of steps to transform several files.
+type Patch interface {
+ // FilePatches returns a slice of patches per file.
+ FilePatches() []FilePatch
+ // Message returns an optional message that can be at the top of the
+ // Patch representation.
+ Message() string
+}
+
+// FilePatch represents the necessary steps to transform one file to another.
+type FilePatch interface {
+ // IsBinary returns true if this patch is representing a binary file.
+ IsBinary() bool
+ // Files returns the from and to Files, with all the necessary metadata to
+ // about them. If the patch creates a new file, "from" will be nil.
+ // If the patch deletes a file, "to" will be nil.
+ Files() (from, to File)
+ // Chunks returns a slice of ordered changes to transform "from" File to
+ // "to" File. If the file is a binary one, Chunks will be empty.
+ Chunks() []Chunk
+}
+
+// File contains all the file metadata necessary to print some patch formats.
+type File interface {
+ // Hash returns the File Hash.
+ Hash() plumbing.Hash
+ // Mode returns the FileMode.
+ Mode() filemode.FileMode
+ // Path returns the complete Path to the file, including the filename.
+ Path() string
+}
+
+// Chunk represents a portion of a file transformation to another.
+type Chunk interface {
+ // Content contains the portion of the file.
+ Content() string
+ // Type contains the Operation to do with this Chunk.
+ Type() Operation
+}
--- /dev/null
+package diff
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+const (
+ diffInit = "diff --git a/%s b/%s\n"
+
+ chunkStart = "@@ -"
+ chunkMiddle = " +"
+ chunkEnd = " @@%s\n"
+ chunkCount = "%d,%d"
+
+ noFilePath = "/dev/null"
+ aDir = "a/"
+ bDir = "b/"
+
+ fPath = "--- %s\n"
+ tPath = "+++ %s\n"
+ binary = "Binary files %s and %s differ\n"
+
+ addLine = "+%s\n"
+ deleteLine = "-%s\n"
+ equalLine = " %s\n"
+
+ oldMode = "old mode %o\n"
+ newMode = "new mode %o\n"
+ deletedFileMode = "deleted file mode %o\n"
+ newFileMode = "new file mode %o\n"
+
+ renameFrom = "from"
+ renameTo = "to"
+ renameFileMode = "rename %s %s\n"
+
+ indexAndMode = "index %s..%s %o\n"
+ indexNoMode = "index %s..%s\n"
+
+ DefaultContextLines = 3
+)
+
+// UnifiedEncoder encodes an unified diff into the provided Writer.
+// There are some unsupported features:
+// - Similarity index for renames
+// - Sort hash representation
+type UnifiedEncoder struct {
+ io.Writer
+
+ // ctxLines is the count of unchanged lines that will appear
+ // surrounding a change.
+ ctxLines int
+
+ buf bytes.Buffer
+}
+
+func NewUnifiedEncoder(w io.Writer, ctxLines int) *UnifiedEncoder {
+ return &UnifiedEncoder{ctxLines: ctxLines, Writer: w}
+}
+
+func (e *UnifiedEncoder) Encode(patch Patch) error {
+ e.printMessage(patch.Message())
+
+ if err := e.encodeFilePatch(patch.FilePatches()); err != nil {
+ return err
+ }
+
+ _, err := e.buf.WriteTo(e)
+
+ return err
+}
+
+func (e *UnifiedEncoder) encodeFilePatch(filePatches []FilePatch) error {
+ for _, p := range filePatches {
+ f, t := p.Files()
+ if err := e.header(f, t, p.IsBinary()); err != nil {
+ return err
+ }
+
+ g := newHunksGenerator(p.Chunks(), e.ctxLines)
+ for _, c := range g.Generate() {
+ c.WriteTo(&e.buf)
+ }
+ }
+
+ return nil
+}
+
+func (e *UnifiedEncoder) printMessage(message string) {
+ isEmpty := message == ""
+ hasSuffix := strings.HasSuffix(message, "\n")
+ if !isEmpty && !hasSuffix {
+ message = message + "\n"
+ }
+
+ e.buf.WriteString(message)
+}
+
+func (e *UnifiedEncoder) header(from, to File, isBinary bool) error {
+ switch {
+ case from == nil && to == nil:
+ return nil
+ case from != nil && to != nil:
+ hashEquals := from.Hash() == to.Hash()
+
+ fmt.Fprintf(&e.buf, diffInit, from.Path(), to.Path())
+
+ if from.Mode() != to.Mode() {
+ fmt.Fprintf(&e.buf, oldMode+newMode, from.Mode(), to.Mode())
+ }
+
+ if from.Path() != to.Path() {
+ fmt.Fprintf(&e.buf,
+ renameFileMode+renameFileMode,
+ renameFrom, from.Path(), renameTo, to.Path())
+ }
+
+ if from.Mode() != to.Mode() && !hashEquals {
+ fmt.Fprintf(&e.buf, indexNoMode, from.Hash(), to.Hash())
+ } else if !hashEquals {
+ fmt.Fprintf(&e.buf, indexAndMode, from.Hash(), to.Hash(), from.Mode())
+ }
+
+ if !hashEquals {
+ e.pathLines(isBinary, aDir+from.Path(), bDir+to.Path())
+ }
+ case from == nil:
+ fmt.Fprintf(&e.buf, diffInit, to.Path(), to.Path())
+ fmt.Fprintf(&e.buf, newFileMode, to.Mode())
+ fmt.Fprintf(&e.buf, indexNoMode, plumbing.ZeroHash, to.Hash())
+ e.pathLines(isBinary, noFilePath, bDir+to.Path())
+ case to == nil:
+ fmt.Fprintf(&e.buf, diffInit, from.Path(), from.Path())
+ fmt.Fprintf(&e.buf, deletedFileMode, from.Mode())
+ fmt.Fprintf(&e.buf, indexNoMode, from.Hash(), plumbing.ZeroHash)
+ e.pathLines(isBinary, aDir+from.Path(), noFilePath)
+ }
+
+ return nil
+}
+
+func (e *UnifiedEncoder) pathLines(isBinary bool, fromPath, toPath string) {
+ format := fPath + tPath
+ if isBinary {
+ format = binary
+ }
+
+ fmt.Fprintf(&e.buf, format, fromPath, toPath)
+}
+
+type hunksGenerator struct {
+ fromLine, toLine int
+ ctxLines int
+ chunks []Chunk
+ current *hunk
+ hunks []*hunk
+ beforeContext, afterContext []string
+}
+
+func newHunksGenerator(chunks []Chunk, ctxLines int) *hunksGenerator {
+ return &hunksGenerator{
+ chunks: chunks,
+ ctxLines: ctxLines,
+ }
+}
+
+func (c *hunksGenerator) Generate() []*hunk {
+ for i, chunk := range c.chunks {
+ ls := splitLines(chunk.Content())
+ lsLen := len(ls)
+
+ switch chunk.Type() {
+ case Equal:
+ c.fromLine += lsLen
+ c.toLine += lsLen
+ c.processEqualsLines(ls, i)
+ case Delete:
+ if lsLen != 0 {
+ c.fromLine++
+ }
+
+ c.processHunk(i, chunk.Type())
+ c.fromLine += lsLen - 1
+ c.current.AddOp(chunk.Type(), ls...)
+ case Add:
+ if lsLen != 0 {
+ c.toLine++
+ }
+ c.processHunk(i, chunk.Type())
+ c.toLine += lsLen - 1
+ c.current.AddOp(chunk.Type(), ls...)
+ }
+
+ if i == len(c.chunks)-1 && c.current != nil {
+ c.hunks = append(c.hunks, c.current)
+ }
+ }
+
+ return c.hunks
+}
+
+func (c *hunksGenerator) processHunk(i int, op Operation) {
+ if c.current != nil {
+ return
+ }
+
+ var ctxPrefix string
+ linesBefore := len(c.beforeContext)
+ if linesBefore > c.ctxLines {
+ ctxPrefix = " " + c.beforeContext[linesBefore-c.ctxLines-1]
+ c.beforeContext = c.beforeContext[linesBefore-c.ctxLines:]
+ linesBefore = c.ctxLines
+ }
+
+ c.current = &hunk{ctxPrefix: ctxPrefix}
+ c.current.AddOp(Equal, c.beforeContext...)
+
+ switch op {
+ case Delete:
+ c.current.fromLine, c.current.toLine =
+ c.addLineNumbers(c.fromLine, c.toLine, linesBefore, i, Add)
+ case Add:
+ c.current.toLine, c.current.fromLine =
+ c.addLineNumbers(c.toLine, c.fromLine, linesBefore, i, Delete)
+ }
+
+ c.beforeContext = nil
+}
+
+// addLineNumbers obtains the line numbers in a new chunk
+func (c *hunksGenerator) addLineNumbers(la, lb int, linesBefore int, i int, op Operation) (cla, clb int) {
+ cla = la - linesBefore
+ // we need to search for a reference for the next diff
+ switch {
+ case linesBefore != 0 && c.ctxLines != 0:
+ if lb > c.ctxLines {
+ clb = lb - c.ctxLines + 1
+ } else {
+ clb = 1
+ }
+ case c.ctxLines == 0:
+ clb = lb
+ case i != len(c.chunks)-1:
+ next := c.chunks[i+1]
+ if next.Type() == op || next.Type() == Equal {
+ // this diff will be into this chunk
+ clb = lb + 1
+ }
+ }
+
+ return
+}
+
+func (c *hunksGenerator) processEqualsLines(ls []string, i int) {
+ if c.current == nil {
+ c.beforeContext = append(c.beforeContext, ls...)
+ return
+ }
+
+ c.afterContext = append(c.afterContext, ls...)
+ if len(c.afterContext) <= c.ctxLines*2 && i != len(c.chunks)-1 {
+ c.current.AddOp(Equal, c.afterContext...)
+ c.afterContext = nil
+ } else {
+ ctxLines := c.ctxLines
+ if ctxLines > len(c.afterContext) {
+ ctxLines = len(c.afterContext)
+ }
+ c.current.AddOp(Equal, c.afterContext[:ctxLines]...)
+ c.hunks = append(c.hunks, c.current)
+
+ c.current = nil
+ c.beforeContext = c.afterContext[ctxLines:]
+ c.afterContext = nil
+ }
+}
+
+func splitLines(s string) []string {
+ out := strings.Split(s, "\n")
+ if out[len(out)-1] == "" {
+ out = out[:len(out)-1]
+ }
+
+ return out
+}
+
+type hunk struct {
+ fromLine int
+ toLine int
+
+ fromCount int
+ toCount int
+
+ ctxPrefix string
+ ops []*op
+}
+
+func (c *hunk) WriteTo(buf *bytes.Buffer) {
+ buf.WriteString(chunkStart)
+
+ if c.fromCount == 1 {
+ fmt.Fprintf(buf, "%d", c.fromLine)
+ } else {
+ fmt.Fprintf(buf, chunkCount, c.fromLine, c.fromCount)
+ }
+
+ buf.WriteString(chunkMiddle)
+
+ if c.toCount == 1 {
+ fmt.Fprintf(buf, "%d", c.toLine)
+ } else {
+ fmt.Fprintf(buf, chunkCount, c.toLine, c.toCount)
+ }
+
+ fmt.Fprintf(buf, chunkEnd, c.ctxPrefix)
+
+ for _, d := range c.ops {
+ buf.WriteString(d.String())
+ }
+}
+
+func (c *hunk) AddOp(t Operation, s ...string) {
+ ls := len(s)
+ switch t {
+ case Add:
+ c.toCount += ls
+ case Delete:
+ c.fromCount += ls
+ case Equal:
+ c.toCount += ls
+ c.fromCount += ls
+ }
+
+ for _, l := range s {
+ c.ops = append(c.ops, &op{l, t})
+ }
+}
+
+type op struct {
+ text string
+ t Operation
+}
+
+func (o *op) String() string {
+ var prefix string
+ switch o.t {
+ case Add:
+ prefix = addLine
+ case Delete:
+ prefix = deleteLine
+ case Equal:
+ prefix = equalLine
+ }
+
+ return fmt.Sprintf(prefix, o.text)
+}
--- /dev/null
+package gitignore
+
+import (
+ "bytes"
+ "io/ioutil"
+ "os"
+ "os/user"
+ "strings"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/config"
+ gioutil "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+const (
+ commentPrefix = "#"
+ coreSection = "core"
+ eol = "\n"
+ excludesfile = "excludesfile"
+ gitDir = ".git"
+ gitignoreFile = ".gitignore"
+ gitconfigFile = ".gitconfig"
+ systemFile = "/etc/gitconfig"
+)
+
+// readIgnoreFile reads a specific git ignore file.
+func readIgnoreFile(fs billy.Filesystem, path []string, ignoreFile string) (ps []Pattern, err error) {
+ f, err := fs.Open(fs.Join(append(path, ignoreFile)...))
+ if err == nil {
+ defer f.Close()
+
+ if data, err := ioutil.ReadAll(f); err == nil {
+ for _, s := range strings.Split(string(data), eol) {
+ if !strings.HasPrefix(s, commentPrefix) && len(strings.TrimSpace(s)) > 0 {
+ ps = append(ps, ParsePattern(s, path))
+ }
+ }
+ }
+ } else if !os.IsNotExist(err) {
+ return nil, err
+ }
+
+ return
+}
+
+// ReadPatterns reads gitignore patterns recursively traversing through the directory
+// structure. The result is in the ascending order of priority (last higher).
+func ReadPatterns(fs billy.Filesystem, path []string) (ps []Pattern, err error) {
+ ps, _ = readIgnoreFile(fs, path, gitignoreFile)
+
+ var fis []os.FileInfo
+ fis, err = fs.ReadDir(fs.Join(path...))
+ if err != nil {
+ return
+ }
+
+ for _, fi := range fis {
+ if fi.IsDir() && fi.Name() != gitDir {
+ var subps []Pattern
+ subps, err = ReadPatterns(fs, append(path, fi.Name()))
+ if err != nil {
+ return
+ }
+
+ if len(subps) > 0 {
+ ps = append(ps, subps...)
+ }
+ }
+ }
+
+ return
+}
+
+func loadPatterns(fs billy.Filesystem, path string) (ps []Pattern, err error) {
+ f, err := fs.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ defer gioutil.CheckClose(f, &err)
+
+ b, err := ioutil.ReadAll(f)
+ if err != nil {
+ return
+ }
+
+ d := config.NewDecoder(bytes.NewBuffer(b))
+
+ raw := config.New()
+ if err = d.Decode(raw); err != nil {
+ return
+ }
+
+ s := raw.Section(coreSection)
+ efo := s.Options.Get(excludesfile)
+ if efo == "" {
+ return nil, nil
+ }
+
+ ps, err = readIgnoreFile(fs, nil, efo)
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+
+ return
+}
+
+// LoadGlobalPatterns loads gitignore patterns from from the gitignore file
+// declared in a user's ~/.gitconfig file. If the ~/.gitconfig file does not
+// exist the function will return nil. If the core.excludesfile property
+// is not declared, the function will return nil. If the file pointed to by
+// the core.excludesfile property does not exist, the function will return nil.
+//
+// The function assumes fs is rooted at the root filesystem.
+func LoadGlobalPatterns(fs billy.Filesystem) (ps []Pattern, err error) {
+ usr, err := user.Current()
+ if err != nil {
+ return
+ }
+
+ return loadPatterns(fs, fs.Join(usr.HomeDir, gitconfigFile))
+}
+
+// LoadSystemPatterns loads gitignore patterns from from the gitignore file
+// declared in a system's /etc/gitconfig file. If the ~/.gitconfig file does
+// not exist the function will return nil. If the core.excludesfile property
+// is not declared, the function will return nil. If the file pointed to by
+// the core.excludesfile property does not exist, the function will return nil.
+//
+// The function assumes fs is rooted at the root filesystem.
+func LoadSystemPatterns(fs billy.Filesystem) (ps []Pattern, err error) {
+ return loadPatterns(fs, systemFile)
+}
--- /dev/null
+// Package gitignore implements matching file system paths to gitignore patterns that
+// can be automatically read from a git repository tree in the order of definition
+// priorities. It support all pattern formats as specified in the original gitignore
+// documentation, copied below:
+//
+// Pattern format
+// ==============
+//
+// - A blank line matches no files, so it can serve as a separator for readability.
+//
+// - A line starting with # serves as a comment. Put a backslash ("\") in front of
+// the first hash for patterns that begin with a hash.
+//
+// - Trailing spaces are ignored unless they are quoted with backslash ("\").
+//
+// - An optional prefix "!" which negates the pattern; any matching file excluded
+// by a previous pattern will become included again. It is not possible to
+// re-include a file if a parent directory of that file is excluded.
+// Git doesn’t list excluded directories for performance reasons, so
+// any patterns on contained files have no effect, no matter where they are
+// defined. Put a backslash ("\") in front of the first "!" for patterns
+// that begin with a literal "!", for example, "\!important!.txt".
+//
+// - If the pattern ends with a slash, it is removed for the purpose of the
+// following description, but it would only find a match with a directory.
+// In other words, foo/ will match a directory foo and paths underneath it,
+// but will not match a regular file or a symbolic link foo (this is consistent
+// with the way how pathspec works in general in Git).
+//
+// - If the pattern does not contain a slash /, Git treats it as a shell glob
+// pattern and checks for a match against the pathname relative to the location
+// of the .gitignore file (relative to the toplevel of the work tree if not
+// from a .gitignore file).
+//
+// - Otherwise, Git treats the pattern as a shell glob suitable for consumption
+// by fnmatch(3) with the FNM_PATHNAME flag: wildcards in the pattern will
+// not match a / in the pathname. For example, "Documentation/*.html" matches
+// "Documentation/git.html" but not "Documentation/ppc/ppc.html" or
+// "tools/perf/Documentation/perf.html".
+//
+// - A leading slash matches the beginning of the pathname. For example,
+// "/*.c" matches "cat-file.c" but not "mozilla-sha1/sha1.c".
+//
+// Two consecutive asterisks ("**") in patterns matched against full pathname
+// may have special meaning:
+//
+// - A leading "**" followed by a slash means match in all directories.
+// For example, "**/foo" matches file or directory "foo" anywhere, the same as
+// pattern "foo". "**/foo/bar" matches file or directory "bar"
+// anywhere that is directly under directory "foo".
+//
+// - A trailing "/**" matches everything inside. For example, "abc/**" matches
+// all files inside directory "abc", relative to the location of the
+// .gitignore file, with infinite depth.
+//
+// - A slash followed by two consecutive asterisks then a slash matches
+// zero or more directories. For example, "a/**/b" matches "a/b", "a/x/b",
+// "a/x/y/b" and so on.
+//
+// - Other consecutive asterisks are considered invalid.
+//
+// Copyright and license
+// =====================
+//
+// Copyright (c) Oleg Sklyar, Silvertern and source{d}
+//
+// The package code was donated to source{d} to include, modify and develop
+// further as a part of the `go-git` project, release it on the license of
+// the whole project or delete it from the project.
+package gitignore
--- /dev/null
+package gitignore
+
+// Matcher defines a global multi-pattern matcher for gitignore patterns
+type Matcher interface {
+ // Match matches patterns in the order of priorities. As soon as an inclusion or
+ // exclusion is found, not further matching is performed.
+ Match(path []string, isDir bool) bool
+}
+
+// NewMatcher constructs a new global matcher. Patterns must be given in the order of
+// increasing priority. That is most generic settings files first, then the content of
+// the repo .gitignore, then content of .gitignore down the path or the repo and then
+// the content command line arguments.
+func NewMatcher(ps []Pattern) Matcher {
+ return &matcher{ps}
+}
+
+type matcher struct {
+ patterns []Pattern
+}
+
+func (m *matcher) Match(path []string, isDir bool) bool {
+ n := len(m.patterns)
+ for i := n - 1; i >= 0; i-- {
+ if match := m.patterns[i].Match(path, isDir); match > NoMatch {
+ return match == Exclude
+ }
+ }
+ return false
+}
--- /dev/null
+package gitignore
+
+import (
+ "path/filepath"
+ "strings"
+)
+
+// MatchResult defines outcomes of a match, no match, exclusion or inclusion.
+type MatchResult int
+
+const (
+ // NoMatch defines the no match outcome of a match check
+ NoMatch MatchResult = iota
+ // Exclude defines an exclusion of a file as a result of a match check
+ Exclude
+ // Include defines an explicit inclusion of a file as a result of a match check
+ Include
+)
+
+const (
+ inclusionPrefix = "!"
+ zeroToManyDirs = "**"
+ patternDirSep = "/"
+)
+
+// Pattern defines a single gitignore pattern.
+type Pattern interface {
+ // Match matches the given path to the pattern.
+ Match(path []string, isDir bool) MatchResult
+}
+
+type pattern struct {
+ domain []string
+ pattern []string
+ inclusion bool
+ dirOnly bool
+ isGlob bool
+}
+
+// ParsePattern parses a gitignore pattern string into the Pattern structure.
+func ParsePattern(p string, domain []string) Pattern {
+ res := pattern{domain: domain}
+
+ if strings.HasPrefix(p, inclusionPrefix) {
+ res.inclusion = true
+ p = p[1:]
+ }
+
+ if !strings.HasSuffix(p, "\\ ") {
+ p = strings.TrimRight(p, " ")
+ }
+
+ if strings.HasSuffix(p, patternDirSep) {
+ res.dirOnly = true
+ p = p[:len(p)-1]
+ }
+
+ if strings.Contains(p, patternDirSep) {
+ res.isGlob = true
+ }
+
+ res.pattern = strings.Split(p, patternDirSep)
+ return &res
+}
+
+func (p *pattern) Match(path []string, isDir bool) MatchResult {
+ if len(path) <= len(p.domain) {
+ return NoMatch
+ }
+ for i, e := range p.domain {
+ if path[i] != e {
+ return NoMatch
+ }
+ }
+
+ path = path[len(p.domain):]
+ if p.isGlob && !p.globMatch(path, isDir) {
+ return NoMatch
+ } else if !p.isGlob && !p.simpleNameMatch(path, isDir) {
+ return NoMatch
+ }
+
+ if p.inclusion {
+ return Include
+ } else {
+ return Exclude
+ }
+}
+
+func (p *pattern) simpleNameMatch(path []string, isDir bool) bool {
+ for i, name := range path {
+ if match, err := filepath.Match(p.pattern[0], name); err != nil {
+ return false
+ } else if !match {
+ continue
+ }
+ if p.dirOnly && !isDir && i == len(path)-1 {
+ return false
+ }
+ return true
+ }
+ return false
+}
+
+func (p *pattern) globMatch(path []string, isDir bool) bool {
+ matched := false
+ canTraverse := false
+ for i, pattern := range p.pattern {
+ if pattern == "" {
+ canTraverse = false
+ continue
+ }
+ if pattern == zeroToManyDirs {
+ if i == len(p.pattern)-1 {
+ break
+ }
+ canTraverse = true
+ continue
+ }
+ if strings.Contains(pattern, zeroToManyDirs) {
+ return false
+ }
+ if len(path) == 0 {
+ return false
+ }
+ if canTraverse {
+ canTraverse = false
+ for len(path) > 0 {
+ e := path[0]
+ path = path[1:]
+ if match, err := filepath.Match(pattern, e); err != nil {
+ return false
+ } else if match {
+ matched = true
+ break
+ } else if len(path) == 0 {
+ // if nothing left then fail
+ matched = false
+ }
+ }
+ } else {
+ if match, err := filepath.Match(pattern, path[0]); err != nil || !match {
+ return false
+ }
+ matched = true
+ path = path[1:]
+ }
+ }
+ if matched && p.dirOnly && !isDir && len(path) == 0 {
+ matched = false
+ }
+ return matched
+}
--- /dev/null
+package idxfile
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+var (
+ // ErrUnsupportedVersion is returned by Decode when the idx file version
+ // is not supported.
+ ErrUnsupportedVersion = errors.New("Unsuported version")
+ // ErrMalformedIdxFile is returned by Decode when the idx file is corrupted.
+ ErrMalformedIdxFile = errors.New("Malformed IDX file")
+)
+
+const (
+ fanout = 256
+ objectIDLength = 20
+)
+
+// Decoder reads and decodes idx files from an input stream.
+type Decoder struct {
+ *bufio.Reader
+}
+
+// NewDecoder builds a new idx stream decoder, that reads from r.
+func NewDecoder(r io.Reader) *Decoder {
+ return &Decoder{bufio.NewReader(r)}
+}
+
+// Decode reads from the stream and decode the content into the MemoryIndex struct.
+func (d *Decoder) Decode(idx *MemoryIndex) error {
+ if err := validateHeader(d); err != nil {
+ return err
+ }
+
+ flow := []func(*MemoryIndex, io.Reader) error{
+ readVersion,
+ readFanout,
+ readObjectNames,
+ readCRC32,
+ readOffsets,
+ readChecksums,
+ }
+
+ for _, f := range flow {
+ if err := f(idx, d); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func validateHeader(r io.Reader) error {
+ var h = make([]byte, 4)
+ if _, err := io.ReadFull(r, h); err != nil {
+ return err
+ }
+
+ if !bytes.Equal(h, idxHeader) {
+ return ErrMalformedIdxFile
+ }
+
+ return nil
+}
+
+func readVersion(idx *MemoryIndex, r io.Reader) error {
+ v, err := binary.ReadUint32(r)
+ if err != nil {
+ return err
+ }
+
+ if v > VersionSupported {
+ return ErrUnsupportedVersion
+ }
+
+ idx.Version = v
+ return nil
+}
+
+func readFanout(idx *MemoryIndex, r io.Reader) error {
+ for k := 0; k < fanout; k++ {
+ n, err := binary.ReadUint32(r)
+ if err != nil {
+ return err
+ }
+
+ idx.Fanout[k] = n
+ idx.FanoutMapping[k] = noMapping
+ }
+
+ return nil
+}
+
+func readObjectNames(idx *MemoryIndex, r io.Reader) error {
+ for k := 0; k < fanout; k++ {
+ var buckets uint32
+ if k == 0 {
+ buckets = idx.Fanout[k]
+ } else {
+ buckets = idx.Fanout[k] - idx.Fanout[k-1]
+ }
+
+ if buckets == 0 {
+ continue
+ }
+
+ if buckets < 0 {
+ return ErrMalformedIdxFile
+ }
+
+ idx.FanoutMapping[k] = len(idx.Names)
+
+ nameLen := int(buckets * objectIDLength)
+ bin := make([]byte, nameLen)
+ if _, err := io.ReadFull(r, bin); err != nil {
+ return err
+ }
+
+ idx.Names = append(idx.Names, bin)
+ idx.Offset32 = append(idx.Offset32, make([]byte, buckets*4))
+ idx.CRC32 = append(idx.CRC32, make([]byte, buckets*4))
+ }
+
+ return nil
+}
+
+func readCRC32(idx *MemoryIndex, r io.Reader) error {
+ for k := 0; k < fanout; k++ {
+ if pos := idx.FanoutMapping[k]; pos != noMapping {
+ if _, err := io.ReadFull(r, idx.CRC32[pos]); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func readOffsets(idx *MemoryIndex, r io.Reader) error {
+ var o64cnt int
+ for k := 0; k < fanout; k++ {
+ if pos := idx.FanoutMapping[k]; pos != noMapping {
+ if _, err := io.ReadFull(r, idx.Offset32[pos]); err != nil {
+ return err
+ }
+
+ for p := 0; p < len(idx.Offset32[pos]); p += 4 {
+ if idx.Offset32[pos][p]&(byte(1)<<7) > 0 {
+ o64cnt++
+ }
+ }
+ }
+ }
+
+ if o64cnt > 0 {
+ idx.Offset64 = make([]byte, o64cnt*8)
+ if _, err := io.ReadFull(r, idx.Offset64); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func readChecksums(idx *MemoryIndex, r io.Reader) error {
+ if _, err := io.ReadFull(r, idx.PackfileChecksum[:]); err != nil {
+ return err
+ }
+
+ if _, err := io.ReadFull(r, idx.IdxChecksum[:]); err != nil {
+ return err
+ }
+
+ return nil
+}
--- /dev/null
+// Package idxfile implements encoding and decoding of packfile idx files.
+//
+// == Original (version 1) pack-*.idx files have the following format:
+//
+// - The header consists of 256 4-byte network byte order
+// integers. N-th entry of this table records the number of
+// objects in the corresponding pack, the first byte of whose
+// object name is less than or equal to N. This is called the
+// 'first-level fan-out' table.
+//
+// - The header is followed by sorted 24-byte entries, one entry
+// per object in the pack. Each entry is:
+//
+// 4-byte network byte order integer, recording where the
+// object is stored in the packfile as the offset from the
+// beginning.
+//
+// 20-byte object name.
+//
+// - The file is concluded with a trailer:
+//
+// A copy of the 20-byte SHA1 checksum at the end of
+// corresponding packfile.
+//
+// 20-byte SHA1-checksum of all of the above.
+//
+// Pack Idx file:
+//
+// -- +--------------------------------+
+// fanout | fanout[0] = 2 (for example) |-.
+// table +--------------------------------+ |
+// | fanout[1] | |
+// +--------------------------------+ |
+// | fanout[2] | |
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
+// | fanout[255] = total objects |---.
+// -- +--------------------------------+ | |
+// main | offset | | |
+// index | object name 00XXXXXXXXXXXXXXXX | | |
+// tab +--------------------------------+ | |
+// | offset | | |
+// | object name 00XXXXXXXXXXXXXXXX | | |
+// +--------------------------------+<+ |
+// .-| offset | |
+// | | object name 01XXXXXXXXXXXXXXXX | |
+// | +--------------------------------+ |
+// | | offset | |
+// | | object name 01XXXXXXXXXXXXXXXX | |
+// | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
+// | | offset | |
+// | | object name FFXXXXXXXXXXXXXXXX | |
+// --| +--------------------------------+<--+
+// trailer | | packfile checksum |
+// | +--------------------------------+
+// | | idxfile checksum |
+// | +--------------------------------+
+// .---------.
+// |
+// Pack file entry: <+
+//
+// packed object header:
+// 1-byte size extension bit (MSB)
+// type (next 3 bit)
+// size0 (lower 4-bit)
+// n-byte sizeN (as long as MSB is set, each 7-bit)
+// size0..sizeN form 4+7+7+..+7 bit integer, size0
+// is the least significant part, and sizeN is the
+// most significant part.
+// packed object data:
+// If it is not DELTA, then deflated bytes (the size above
+// is the size before compression).
+// If it is REF_DELTA, then
+// 20-byte base object name SHA1 (the size above is the
+// size of the delta data that follows).
+// delta data, deflated.
+// If it is OFS_DELTA, then
+// n-byte offset (see below) interpreted as a negative
+// offset from the type-byte of the header of the
+// ofs-delta entry (the size above is the size of
+// the delta data that follows).
+// delta data, deflated.
+//
+// offset encoding:
+// n bytes with MSB set in all but the last one.
+// The offset is then the number constructed by
+// concatenating the lower 7 bit of each byte, and
+// for n >= 2 adding 2^7 + 2^14 + ... + 2^(7*(n-1))
+// to the result.
+//
+// == Version 2 pack-*.idx files support packs larger than 4 GiB, and
+// have some other reorganizations. They have the format:
+//
+// - A 4-byte magic number '\377tOc' which is an unreasonable
+// fanout[0] value.
+//
+// - A 4-byte version number (= 2)
+//
+// - A 256-entry fan-out table just like v1.
+//
+// - A table of sorted 20-byte SHA1 object names. These are
+// packed together without offset values to reduce the cache
+// footprint of the binary search for a specific object name.
+//
+// - A table of 4-byte CRC32 values of the packed object data.
+// This is new in v2 so compressed data can be copied directly
+// from pack to pack during repacking without undetected
+// data corruption.
+//
+// - A table of 4-byte offset values (in network byte order).
+// These are usually 31-bit pack file offsets, but large
+// offsets are encoded as an index into the next table with
+// the msbit set.
+//
+// - A table of 8-byte offset entries (empty for pack files less
+// than 2 GiB). Pack files are organized with heavily used
+// objects toward the front, so most object references should
+// not need to refer to this table.
+//
+// - The same trailer as a v1 pack file:
+//
+// A copy of the 20-byte SHA1 checksum at the end of
+// corresponding packfile.
+//
+// 20-byte SHA1-checksum of all of the above.
+//
+// Source:
+// https://www.kernel.org/pub/software/scm/git/docs/v1.7.5/technical/pack-format.txt
+package idxfile
--- /dev/null
+package idxfile
+
+import (
+ "crypto/sha1"
+ "hash"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+// Encoder writes MemoryIndex structs to an output stream.
+type Encoder struct {
+ io.Writer
+ hash hash.Hash
+}
+
+// NewEncoder returns a new stream encoder that writes to w.
+func NewEncoder(w io.Writer) *Encoder {
+ h := sha1.New()
+ mw := io.MultiWriter(w, h)
+ return &Encoder{mw, h}
+}
+
+// Encode encodes an MemoryIndex to the encoder writer.
+func (e *Encoder) Encode(idx *MemoryIndex) (int, error) {
+ flow := []func(*MemoryIndex) (int, error){
+ e.encodeHeader,
+ e.encodeFanout,
+ e.encodeHashes,
+ e.encodeCRC32,
+ e.encodeOffsets,
+ e.encodeChecksums,
+ }
+
+ sz := 0
+ for _, f := range flow {
+ i, err := f(idx)
+ sz += i
+
+ if err != nil {
+ return sz, err
+ }
+ }
+
+ return sz, nil
+}
+
+func (e *Encoder) encodeHeader(idx *MemoryIndex) (int, error) {
+ c, err := e.Write(idxHeader)
+ if err != nil {
+ return c, err
+ }
+
+ return c + 4, binary.WriteUint32(e, idx.Version)
+}
+
+func (e *Encoder) encodeFanout(idx *MemoryIndex) (int, error) {
+ for _, c := range idx.Fanout {
+ if err := binary.WriteUint32(e, c); err != nil {
+ return 0, err
+ }
+ }
+
+ return fanout * 4, nil
+}
+
+func (e *Encoder) encodeHashes(idx *MemoryIndex) (int, error) {
+ var size int
+ for k := 0; k < fanout; k++ {
+ pos := idx.FanoutMapping[k]
+ if pos == noMapping {
+ continue
+ }
+
+ n, err := e.Write(idx.Names[pos])
+ if err != nil {
+ return size, err
+ }
+ size += n
+ }
+ return size, nil
+}
+
+func (e *Encoder) encodeCRC32(idx *MemoryIndex) (int, error) {
+ var size int
+ for k := 0; k < fanout; k++ {
+ pos := idx.FanoutMapping[k]
+ if pos == noMapping {
+ continue
+ }
+
+ n, err := e.Write(idx.CRC32[pos])
+ if err != nil {
+ return size, err
+ }
+
+ size += n
+ }
+
+ return size, nil
+}
+
+func (e *Encoder) encodeOffsets(idx *MemoryIndex) (int, error) {
+ var size int
+ for k := 0; k < fanout; k++ {
+ pos := idx.FanoutMapping[k]
+ if pos == noMapping {
+ continue
+ }
+
+ n, err := e.Write(idx.Offset32[pos])
+ if err != nil {
+ return size, err
+ }
+
+ size += n
+ }
+
+ if len(idx.Offset64) > 0 {
+ n, err := e.Write(idx.Offset64)
+ if err != nil {
+ return size, err
+ }
+
+ size += n
+ }
+
+ return size, nil
+}
+
+func (e *Encoder) encodeChecksums(idx *MemoryIndex) (int, error) {
+ if _, err := e.Write(idx.PackfileChecksum[:]); err != nil {
+ return 0, err
+ }
+
+ copy(idx.IdxChecksum[:], e.hash.Sum(nil)[:20])
+ if _, err := e.Write(idx.IdxChecksum[:]); err != nil {
+ return 0, err
+ }
+
+ return 40, nil
+}
--- /dev/null
+package idxfile
+
+import (
+ "bytes"
+ "io"
+ "sort"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+const (
+ // VersionSupported is the only idx version supported.
+ VersionSupported = 2
+
+ noMapping = -1
+)
+
+var (
+ idxHeader = []byte{255, 't', 'O', 'c'}
+)
+
+// Index represents an index of a packfile.
+type Index interface {
+ // Contains checks whether the given hash is in the index.
+ Contains(h plumbing.Hash) (bool, error)
+ // FindOffset finds the offset in the packfile for the object with
+ // the given hash.
+ FindOffset(h plumbing.Hash) (int64, error)
+ // FindCRC32 finds the CRC32 of the object with the given hash.
+ FindCRC32(h plumbing.Hash) (uint32, error)
+ // FindHash finds the hash for the object with the given offset.
+ FindHash(o int64) (plumbing.Hash, error)
+ // Count returns the number of entries in the index.
+ Count() (int64, error)
+ // Entries returns an iterator to retrieve all index entries.
+ Entries() (EntryIter, error)
+ // EntriesByOffset returns an iterator to retrieve all index entries ordered
+ // by offset.
+ EntriesByOffset() (EntryIter, error)
+}
+
+// MemoryIndex is the in memory representation of an idx file.
+type MemoryIndex struct {
+ Version uint32
+ Fanout [256]uint32
+ // FanoutMapping maps the position in the fanout table to the position
+ // in the Names, Offset32 and CRC32 slices. This improves the memory
+ // usage by not needing an array with unnecessary empty slots.
+ FanoutMapping [256]int
+ Names [][]byte
+ Offset32 [][]byte
+ CRC32 [][]byte
+ Offset64 []byte
+ PackfileChecksum [20]byte
+ IdxChecksum [20]byte
+
+ offsetHash map[int64]plumbing.Hash
+}
+
+var _ Index = (*MemoryIndex)(nil)
+
+// NewMemoryIndex returns an instance of a new MemoryIndex.
+func NewMemoryIndex() *MemoryIndex {
+ return &MemoryIndex{}
+}
+
+func (idx *MemoryIndex) findHashIndex(h plumbing.Hash) (int, bool) {
+ k := idx.FanoutMapping[h[0]]
+ if k == noMapping {
+ return 0, false
+ }
+
+ if len(idx.Names) <= k {
+ return 0, false
+ }
+
+ data := idx.Names[k]
+ high := uint64(len(idx.Offset32[k])) >> 2
+ if high == 0 {
+ return 0, false
+ }
+
+ low := uint64(0)
+ for {
+ mid := (low + high) >> 1
+ offset := mid * objectIDLength
+
+ cmp := bytes.Compare(h[:], data[offset:offset+objectIDLength])
+ if cmp < 0 {
+ high = mid
+ } else if cmp == 0 {
+ return int(mid), true
+ } else {
+ low = mid + 1
+ }
+
+ if low >= high {
+ break
+ }
+ }
+
+ return 0, false
+}
+
+// Contains implements the Index interface.
+func (idx *MemoryIndex) Contains(h plumbing.Hash) (bool, error) {
+ _, ok := idx.findHashIndex(h)
+ return ok, nil
+}
+
+// FindOffset implements the Index interface.
+func (idx *MemoryIndex) FindOffset(h plumbing.Hash) (int64, error) {
+ if len(idx.FanoutMapping) <= int(h[0]) {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ k := idx.FanoutMapping[h[0]]
+ i, ok := idx.findHashIndex(h)
+ if !ok {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ return idx.getOffset(k, i)
+}
+
+const isO64Mask = uint64(1) << 31
+
+func (idx *MemoryIndex) getOffset(firstLevel, secondLevel int) (int64, error) {
+ offset := secondLevel << 2
+ buf := bytes.NewBuffer(idx.Offset32[firstLevel][offset : offset+4])
+ ofs, err := binary.ReadUint32(buf)
+ if err != nil {
+ return -1, err
+ }
+
+ if (uint64(ofs) & isO64Mask) != 0 {
+ offset := 8 * (uint64(ofs) & ^isO64Mask)
+ buf := bytes.NewBuffer(idx.Offset64[offset : offset+8])
+ n, err := binary.ReadUint64(buf)
+ if err != nil {
+ return -1, err
+ }
+
+ return int64(n), nil
+ }
+
+ return int64(ofs), nil
+}
+
+// FindCRC32 implements the Index interface.
+func (idx *MemoryIndex) FindCRC32(h plumbing.Hash) (uint32, error) {
+ k := idx.FanoutMapping[h[0]]
+ i, ok := idx.findHashIndex(h)
+ if !ok {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ return idx.getCRC32(k, i)
+}
+
+func (idx *MemoryIndex) getCRC32(firstLevel, secondLevel int) (uint32, error) {
+ offset := secondLevel << 2
+ buf := bytes.NewBuffer(idx.CRC32[firstLevel][offset : offset+4])
+ return binary.ReadUint32(buf)
+}
+
+// FindHash implements the Index interface.
+func (idx *MemoryIndex) FindHash(o int64) (plumbing.Hash, error) {
+ // Lazily generate the reverse offset/hash map if required.
+ if idx.offsetHash == nil {
+ if err := idx.genOffsetHash(); err != nil {
+ return plumbing.ZeroHash, err
+ }
+ }
+
+ hash, ok := idx.offsetHash[o]
+ if !ok {
+ return plumbing.ZeroHash, plumbing.ErrObjectNotFound
+ }
+
+ return hash, nil
+}
+
+// genOffsetHash generates the offset/hash mapping for reverse search.
+func (idx *MemoryIndex) genOffsetHash() error {
+ count, err := idx.Count()
+ if err != nil {
+ return err
+ }
+
+ idx.offsetHash = make(map[int64]plumbing.Hash, count)
+
+ iter, err := idx.Entries()
+ if err != nil {
+ return err
+ }
+
+ for {
+ entry, err := iter.Next()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+ return err
+ }
+
+ idx.offsetHash[int64(entry.Offset)] = entry.Hash
+ }
+}
+
+// Count implements the Index interface.
+func (idx *MemoryIndex) Count() (int64, error) {
+ return int64(idx.Fanout[fanout-1]), nil
+}
+
+// Entries implements the Index interface.
+func (idx *MemoryIndex) Entries() (EntryIter, error) {
+ return &idxfileEntryIter{idx, 0, 0, 0}, nil
+}
+
+// EntriesByOffset implements the Index interface.
+func (idx *MemoryIndex) EntriesByOffset() (EntryIter, error) {
+ count, err := idx.Count()
+ if err != nil {
+ return nil, err
+ }
+
+ iter := &idxfileEntryOffsetIter{
+ entries: make(entriesByOffset, count),
+ }
+
+ entries, err := idx.Entries()
+ if err != nil {
+ return nil, err
+ }
+
+ for pos := 0; int64(pos) < count; pos++ {
+ entry, err := entries.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ iter.entries[pos] = entry
+ }
+
+ sort.Sort(iter.entries)
+
+ return iter, nil
+}
+
+// EntryIter is an iterator that will return the entries in a packfile index.
+type EntryIter interface {
+ // Next returns the next entry in the packfile index.
+ Next() (*Entry, error)
+ // Close closes the iterator.
+ Close() error
+}
+
+type idxfileEntryIter struct {
+ idx *MemoryIndex
+ total int
+ firstLevel, secondLevel int
+}
+
+func (i *idxfileEntryIter) Next() (*Entry, error) {
+ for {
+ if i.firstLevel >= fanout {
+ return nil, io.EOF
+ }
+
+ if i.total >= int(i.idx.Fanout[i.firstLevel]) {
+ i.firstLevel++
+ i.secondLevel = 0
+ continue
+ }
+
+ entry := new(Entry)
+ ofs := i.secondLevel * objectIDLength
+ copy(entry.Hash[:], i.idx.Names[i.idx.FanoutMapping[i.firstLevel]][ofs:])
+
+ pos := i.idx.FanoutMapping[entry.Hash[0]]
+
+ offset, err := i.idx.getOffset(pos, i.secondLevel)
+ if err != nil {
+ return nil, err
+ }
+ entry.Offset = uint64(offset)
+
+ entry.CRC32, err = i.idx.getCRC32(pos, i.secondLevel)
+ if err != nil {
+ return nil, err
+ }
+
+ i.secondLevel++
+ i.total++
+
+ return entry, nil
+ }
+}
+
+func (i *idxfileEntryIter) Close() error {
+ i.firstLevel = fanout
+ return nil
+}
+
+// Entry is the in memory representation of an object entry in the idx file.
+type Entry struct {
+ Hash plumbing.Hash
+ CRC32 uint32
+ Offset uint64
+}
+
+type idxfileEntryOffsetIter struct {
+ entries entriesByOffset
+ pos int
+}
+
+func (i *idxfileEntryOffsetIter) Next() (*Entry, error) {
+ if i.pos >= len(i.entries) {
+ return nil, io.EOF
+ }
+
+ entry := i.entries[i.pos]
+ i.pos++
+
+ return entry, nil
+}
+
+func (i *idxfileEntryOffsetIter) Close() error {
+ i.pos = len(i.entries) + 1
+ return nil
+}
+
+type entriesByOffset []*Entry
+
+func (o entriesByOffset) Len() int {
+ return len(o)
+}
+
+func (o entriesByOffset) Less(i int, j int) bool {
+ return o[i].Offset < o[j].Offset
+}
+
+func (o entriesByOffset) Swap(i int, j int) {
+ o[i], o[j] = o[j], o[i]
+}
--- /dev/null
+package idxfile
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "sort"
+ "sync"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+// objects implements sort.Interface and uses hash as sorting key.
+type objects []Entry
+
+// Writer implements a packfile Observer interface and is used to generate
+// indexes.
+type Writer struct {
+ m sync.Mutex
+
+ count uint32
+ checksum plumbing.Hash
+ objects objects
+ offset64 uint32
+ finished bool
+ index *MemoryIndex
+ added map[plumbing.Hash]struct{}
+}
+
+// Index returns a previously created MemoryIndex or creates a new one if
+// needed.
+func (w *Writer) Index() (*MemoryIndex, error) {
+ w.m.Lock()
+ defer w.m.Unlock()
+
+ if w.index == nil {
+ return w.createIndex()
+ }
+
+ return w.index, nil
+}
+
+// Add appends new object data.
+func (w *Writer) Add(h plumbing.Hash, pos uint64, crc uint32) {
+ w.m.Lock()
+ defer w.m.Unlock()
+
+ if w.added == nil {
+ w.added = make(map[plumbing.Hash]struct{})
+ }
+
+ if _, ok := w.added[h]; !ok {
+ w.added[h] = struct{}{}
+ w.objects = append(w.objects, Entry{h, crc, pos})
+ }
+
+}
+
+func (w *Writer) Finished() bool {
+ return w.finished
+}
+
+// OnHeader implements packfile.Observer interface.
+func (w *Writer) OnHeader(count uint32) error {
+ w.count = count
+ w.objects = make(objects, 0, count)
+ return nil
+}
+
+// OnInflatedObjectHeader implements packfile.Observer interface.
+func (w *Writer) OnInflatedObjectHeader(t plumbing.ObjectType, objSize int64, pos int64) error {
+ return nil
+}
+
+// OnInflatedObjectContent implements packfile.Observer interface.
+func (w *Writer) OnInflatedObjectContent(h plumbing.Hash, pos int64, crc uint32, _ []byte) error {
+ w.Add(h, uint64(pos), crc)
+ return nil
+}
+
+// OnFooter implements packfile.Observer interface.
+func (w *Writer) OnFooter(h plumbing.Hash) error {
+ w.checksum = h
+ w.finished = true
+ _, err := w.createIndex()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// creatIndex returns a filled MemoryIndex with the information filled by
+// the observer callbacks.
+func (w *Writer) createIndex() (*MemoryIndex, error) {
+ if !w.finished {
+ return nil, fmt.Errorf("the index still hasn't finished building")
+ }
+
+ idx := new(MemoryIndex)
+ w.index = idx
+
+ sort.Sort(w.objects)
+
+ // unmap all fans by default
+ for i := range idx.FanoutMapping {
+ idx.FanoutMapping[i] = noMapping
+ }
+
+ buf := new(bytes.Buffer)
+
+ last := -1
+ bucket := -1
+ for i, o := range w.objects {
+ fan := o.Hash[0]
+
+ // fill the gaps between fans
+ for j := last + 1; j < int(fan); j++ {
+ idx.Fanout[j] = uint32(i)
+ }
+
+ // update the number of objects for this position
+ idx.Fanout[fan] = uint32(i + 1)
+
+ // we move from one bucket to another, update counters and allocate
+ // memory
+ if last != int(fan) {
+ bucket++
+ idx.FanoutMapping[fan] = bucket
+ last = int(fan)
+
+ idx.Names = append(idx.Names, make([]byte, 0))
+ idx.Offset32 = append(idx.Offset32, make([]byte, 0))
+ idx.CRC32 = append(idx.CRC32, make([]byte, 0))
+ }
+
+ idx.Names[bucket] = append(idx.Names[bucket], o.Hash[:]...)
+
+ offset := o.Offset
+ if offset > math.MaxInt32 {
+ offset = w.addOffset64(offset)
+ }
+
+ buf.Truncate(0)
+ binary.WriteUint32(buf, uint32(offset))
+ idx.Offset32[bucket] = append(idx.Offset32[bucket], buf.Bytes()...)
+
+ buf.Truncate(0)
+ binary.WriteUint32(buf, uint32(o.CRC32))
+ idx.CRC32[bucket] = append(idx.CRC32[bucket], buf.Bytes()...)
+ }
+
+ for j := last + 1; j < 256; j++ {
+ idx.Fanout[j] = uint32(len(w.objects))
+ }
+
+ idx.Version = VersionSupported
+ idx.PackfileChecksum = w.checksum
+
+ return idx, nil
+}
+
+func (w *Writer) addOffset64(pos uint64) uint64 {
+ buf := new(bytes.Buffer)
+ binary.WriteUint64(buf, pos)
+ w.index.Offset64 = append(w.index.Offset64, buf.Bytes()...)
+
+ index := uint64(w.offset64 | (1 << 31))
+ w.offset64++
+
+ return index
+}
+
+func (o objects) Len() int {
+ return len(o)
+}
+
+func (o objects) Less(i int, j int) bool {
+ cmp := bytes.Compare(o[i].Hash[:], o[j].Hash[:])
+ return cmp < 0
+}
+
+func (o objects) Swap(i int, j int) {
+ o[i], o[j] = o[j], o[i]
+}
--- /dev/null
+package index
+
+import (
+ "bytes"
+ "crypto/sha1"
+ "errors"
+ "hash"
+ "io"
+ "io/ioutil"
+ "strconv"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+var (
+ // DecodeVersionSupported is the range of supported index versions
+ DecodeVersionSupported = struct{ Min, Max uint32 }{Min: 2, Max: 4}
+
+ // ErrMalformedSignature is returned by Decode when the index header file is
+ // malformed
+ ErrMalformedSignature = errors.New("malformed index signature file")
+ // ErrInvalidChecksum is returned by Decode if the SHA1 hash mismatch with
+ // the read content
+ ErrInvalidChecksum = errors.New("invalid checksum")
+
+ errUnknownExtension = errors.New("unknown extension")
+)
+
+const (
+ entryHeaderLength = 62
+ entryExtended = 0x4000
+ entryValid = 0x8000
+ nameMask = 0xfff
+ intentToAddMask = 1 << 13
+ skipWorkTreeMask = 1 << 14
+)
+
+// A Decoder reads and decodes index files from an input stream.
+type Decoder struct {
+ r io.Reader
+ hash hash.Hash
+ lastEntry *Entry
+}
+
+// NewDecoder returns a new decoder that reads from r.
+func NewDecoder(r io.Reader) *Decoder {
+ h := sha1.New()
+ return &Decoder{
+ r: io.TeeReader(r, h),
+ hash: h,
+ }
+}
+
+// Decode reads the whole index object from its input and stores it in the
+// value pointed to by idx.
+func (d *Decoder) Decode(idx *Index) error {
+ var err error
+ idx.Version, err = validateHeader(d.r)
+ if err != nil {
+ return err
+ }
+
+ entryCount, err := binary.ReadUint32(d.r)
+ if err != nil {
+ return err
+ }
+
+ if err := d.readEntries(idx, int(entryCount)); err != nil {
+ return err
+ }
+
+ return d.readExtensions(idx)
+}
+
+func (d *Decoder) readEntries(idx *Index, count int) error {
+ for i := 0; i < count; i++ {
+ e, err := d.readEntry(idx)
+ if err != nil {
+ return err
+ }
+
+ d.lastEntry = e
+ idx.Entries = append(idx.Entries, e)
+ }
+
+ return nil
+}
+
+func (d *Decoder) readEntry(idx *Index) (*Entry, error) {
+ e := &Entry{}
+
+ var msec, mnsec, sec, nsec uint32
+ var flags uint16
+
+ flow := []interface{}{
+ &sec, &nsec,
+ &msec, &mnsec,
+ &e.Dev,
+ &e.Inode,
+ &e.Mode,
+ &e.UID,
+ &e.GID,
+ &e.Size,
+ &e.Hash,
+ &flags,
+ }
+
+ if err := binary.Read(d.r, flow...); err != nil {
+ return nil, err
+ }
+
+ read := entryHeaderLength
+
+ if sec != 0 || nsec != 0 {
+ e.CreatedAt = time.Unix(int64(sec), int64(nsec))
+ }
+
+ if msec != 0 || mnsec != 0 {
+ e.ModifiedAt = time.Unix(int64(msec), int64(mnsec))
+ }
+
+ e.Stage = Stage(flags>>12) & 0x3
+
+ if flags&entryExtended != 0 {
+ extended, err := binary.ReadUint16(d.r)
+ if err != nil {
+ return nil, err
+ }
+
+ read += 2
+ e.IntentToAdd = extended&intentToAddMask != 0
+ e.SkipWorktree = extended&skipWorkTreeMask != 0
+ }
+
+ if err := d.readEntryName(idx, e, flags); err != nil {
+ return nil, err
+ }
+
+ return e, d.padEntry(idx, e, read)
+}
+
+func (d *Decoder) readEntryName(idx *Index, e *Entry, flags uint16) error {
+ var name string
+ var err error
+
+ switch idx.Version {
+ case 2, 3:
+ len := flags & nameMask
+ name, err = d.doReadEntryName(len)
+ case 4:
+ name, err = d.doReadEntryNameV4()
+ default:
+ return ErrUnsupportedVersion
+ }
+
+ if err != nil {
+ return err
+ }
+
+ e.Name = name
+ return nil
+}
+
+func (d *Decoder) doReadEntryNameV4() (string, error) {
+ l, err := binary.ReadVariableWidthInt(d.r)
+ if err != nil {
+ return "", err
+ }
+
+ var base string
+ if d.lastEntry != nil {
+ base = d.lastEntry.Name[:len(d.lastEntry.Name)-int(l)]
+ }
+
+ name, err := binary.ReadUntil(d.r, '\x00')
+ if err != nil {
+ return "", err
+ }
+
+ return base + string(name), nil
+}
+
+func (d *Decoder) doReadEntryName(len uint16) (string, error) {
+ name := make([]byte, len)
+ if err := binary.Read(d.r, &name); err != nil {
+ return "", err
+ }
+
+ return string(name), nil
+}
+
+// Index entries are padded out to the next 8 byte alignment
+// for historical reasons related to how C Git read the files.
+func (d *Decoder) padEntry(idx *Index, e *Entry, read int) error {
+ if idx.Version == 4 {
+ return nil
+ }
+
+ entrySize := read + len(e.Name)
+ padLen := 8 - entrySize%8
+ _, err := io.CopyN(ioutil.Discard, d.r, int64(padLen))
+ return err
+}
+
+func (d *Decoder) readExtensions(idx *Index) error {
+ // TODO: support 'Split index' and 'Untracked cache' extensions, take in
+ // count that they are not supported by jgit or libgit
+
+ var expected []byte
+ var err error
+
+ var header [4]byte
+ for {
+ expected = d.hash.Sum(nil)
+
+ var n int
+ if n, err = io.ReadFull(d.r, header[:]); err != nil {
+ if n == 0 {
+ err = io.EOF
+ }
+
+ break
+ }
+
+ err = d.readExtension(idx, header[:])
+ if err != nil {
+ break
+ }
+ }
+
+ if err != errUnknownExtension {
+ return err
+ }
+
+ return d.readChecksum(expected, header)
+}
+
+func (d *Decoder) readExtension(idx *Index, header []byte) error {
+ switch {
+ case bytes.Equal(header, treeExtSignature):
+ r, err := d.getExtensionReader()
+ if err != nil {
+ return err
+ }
+
+ idx.Cache = &Tree{}
+ d := &treeExtensionDecoder{r}
+ if err := d.Decode(idx.Cache); err != nil {
+ return err
+ }
+ case bytes.Equal(header, resolveUndoExtSignature):
+ r, err := d.getExtensionReader()
+ if err != nil {
+ return err
+ }
+
+ idx.ResolveUndo = &ResolveUndo{}
+ d := &resolveUndoDecoder{r}
+ if err := d.Decode(idx.ResolveUndo); err != nil {
+ return err
+ }
+ case bytes.Equal(header, endOfIndexEntryExtSignature):
+ r, err := d.getExtensionReader()
+ if err != nil {
+ return err
+ }
+
+ idx.EndOfIndexEntry = &EndOfIndexEntry{}
+ d := &endOfIndexEntryDecoder{r}
+ if err := d.Decode(idx.EndOfIndexEntry); err != nil {
+ return err
+ }
+ default:
+ return errUnknownExtension
+ }
+
+ return nil
+}
+
+func (d *Decoder) getExtensionReader() (io.Reader, error) {
+ len, err := binary.ReadUint32(d.r)
+ if err != nil {
+ return nil, err
+ }
+
+ return &io.LimitedReader{R: d.r, N: int64(len)}, nil
+}
+
+func (d *Decoder) readChecksum(expected []byte, alreadyRead [4]byte) error {
+ var h plumbing.Hash
+ copy(h[:4], alreadyRead[:])
+
+ if err := binary.Read(d.r, h[4:]); err != nil {
+ return err
+ }
+
+ if !bytes.Equal(h[:], expected) {
+ return ErrInvalidChecksum
+ }
+
+ return nil
+}
+
+func validateHeader(r io.Reader) (version uint32, err error) {
+ var s = make([]byte, 4)
+ if _, err := io.ReadFull(r, s); err != nil {
+ return 0, err
+ }
+
+ if !bytes.Equal(s, indexSignature) {
+ return 0, ErrMalformedSignature
+ }
+
+ version, err = binary.ReadUint32(r)
+ if err != nil {
+ return 0, err
+ }
+
+ if version < DecodeVersionSupported.Min || version > DecodeVersionSupported.Max {
+ return 0, ErrUnsupportedVersion
+ }
+
+ return
+}
+
+type treeExtensionDecoder struct {
+ r io.Reader
+}
+
+func (d *treeExtensionDecoder) Decode(t *Tree) error {
+ for {
+ e, err := d.readEntry()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+
+ return err
+ }
+
+ if e == nil {
+ continue
+ }
+
+ t.Entries = append(t.Entries, *e)
+ }
+}
+
+func (d *treeExtensionDecoder) readEntry() (*TreeEntry, error) {
+ e := &TreeEntry{}
+
+ path, err := binary.ReadUntil(d.r, '\x00')
+ if err != nil {
+ return nil, err
+ }
+
+ e.Path = string(path)
+
+ count, err := binary.ReadUntil(d.r, ' ')
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := strconv.Atoi(string(count))
+ if err != nil {
+ return nil, err
+ }
+
+ // An entry can be in an invalidated state and is represented by having a
+ // negative number in the entry_count field.
+ if i == -1 {
+ return nil, nil
+ }
+
+ e.Entries = i
+ trees, err := binary.ReadUntil(d.r, '\n')
+ if err != nil {
+ return nil, err
+ }
+
+ i, err = strconv.Atoi(string(trees))
+ if err != nil {
+ return nil, err
+ }
+
+ e.Trees = i
+
+ if err := binary.Read(d.r, &e.Hash); err != nil {
+ return nil, err
+ }
+
+ return e, nil
+}
+
+type resolveUndoDecoder struct {
+ r io.Reader
+}
+
+func (d *resolveUndoDecoder) Decode(ru *ResolveUndo) error {
+ for {
+ e, err := d.readEntry()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+
+ return err
+ }
+
+ ru.Entries = append(ru.Entries, *e)
+ }
+}
+
+func (d *resolveUndoDecoder) readEntry() (*ResolveUndoEntry, error) {
+ e := &ResolveUndoEntry{
+ Stages: make(map[Stage]plumbing.Hash),
+ }
+
+ path, err := binary.ReadUntil(d.r, '\x00')
+ if err != nil {
+ return nil, err
+ }
+
+ e.Path = string(path)
+
+ for i := 0; i < 3; i++ {
+ if err := d.readStage(e, Stage(i+1)); err != nil {
+ return nil, err
+ }
+ }
+
+ for s := range e.Stages {
+ var hash plumbing.Hash
+ if err := binary.Read(d.r, hash[:]); err != nil {
+ return nil, err
+ }
+
+ e.Stages[s] = hash
+ }
+
+ return e, nil
+}
+
+func (d *resolveUndoDecoder) readStage(e *ResolveUndoEntry, s Stage) error {
+ ascii, err := binary.ReadUntil(d.r, '\x00')
+ if err != nil {
+ return err
+ }
+
+ stage, err := strconv.ParseInt(string(ascii), 8, 64)
+ if err != nil {
+ return err
+ }
+
+ if stage != 0 {
+ e.Stages[s] = plumbing.ZeroHash
+ }
+
+ return nil
+}
+
+type endOfIndexEntryDecoder struct {
+ r io.Reader
+}
+
+func (d *endOfIndexEntryDecoder) Decode(e *EndOfIndexEntry) error {
+ var err error
+ e.Offset, err = binary.ReadUint32(d.r)
+ if err != nil {
+ return err
+ }
+
+ return binary.Read(d.r, &e.Hash)
+}
--- /dev/null
+// Package index implements encoding and decoding of index format files.
+//
+// Git index format
+// ================
+//
+// == The Git index file has the following format
+//
+// All binary numbers are in network byte order. Version 2 is described
+// here unless stated otherwise.
+//
+// - A 12-byte header consisting of
+//
+// 4-byte signature:
+// The signature is { 'D', 'I', 'R', 'C' } (stands for "dircache")
+//
+// 4-byte version number:
+// The current supported versions are 2, 3 and 4.
+//
+// 32-bit number of index entries.
+//
+// - A number of sorted index entries (see below).
+//
+// - Extensions
+//
+// Extensions are identified by signature. Optional extensions can
+// be ignored if Git does not understand them.
+//
+// Git currently supports cached tree and resolve undo extensions.
+//
+// 4-byte extension signature. If the first byte is 'A'..'Z' the
+// extension is optional and can be ignored.
+//
+// 32-bit size of the extension
+//
+// Extension data
+//
+// - 160-bit SHA-1 over the content of the index file before this
+// checksum.
+//
+// == Index entry
+//
+// Index entries are sorted in ascending order on the name field,
+// interpreted as a string of unsigned bytes (i.e. memcmp() order, no
+// localization, no special casing of directory separator '/'). Entries
+// with the same name are sorted by their stage field.
+//
+// 32-bit ctime seconds, the last time a file's metadata changed
+// this is stat(2) data
+//
+// 32-bit ctime nanosecond fractions
+// this is stat(2) data
+//
+// 32-bit mtime seconds, the last time a file's data changed
+// this is stat(2) data
+//
+// 32-bit mtime nanosecond fractions
+// this is stat(2) data
+//
+// 32-bit dev
+// this is stat(2) data
+//
+// 32-bit ino
+// this is stat(2) data
+//
+// 32-bit mode, split into (high to low bits)
+//
+// 4-bit object type
+// valid values in binary are 1000 (regular file), 1010 (symbolic link)
+// and 1110 (gitlink)
+//
+// 3-bit unused
+//
+// 9-bit unix permission. Only 0755 and 0644 are valid for regular files.
+// Symbolic links and gitlinks have value 0 in this field.
+//
+// 32-bit uid
+// this is stat(2) data
+//
+// 32-bit gid
+// this is stat(2) data
+//
+// 32-bit file size
+// This is the on-disk size from stat(2), truncated to 32-bit.
+//
+// 160-bit SHA-1 for the represented object
+//
+// A 16-bit 'flags' field split into (high to low bits)
+//
+// 1-bit assume-valid flag
+//
+// 1-bit extended flag (must be zero in version 2)
+//
+// 2-bit stage (during merge)
+//
+// 12-bit name length if the length is less than 0xFFF; otherwise 0xFFF
+// is stored in this field.
+//
+// (Version 3 or later) A 16-bit field, only applicable if the
+// "extended flag" above is 1, split into (high to low bits).
+//
+// 1-bit reserved for future
+//
+// 1-bit skip-worktree flag (used by sparse checkout)
+//
+// 1-bit intent-to-add flag (used by "git add -N")
+//
+// 13-bit unused, must be zero
+//
+// Entry path name (variable length) relative to top level directory
+// (without leading slash). '/' is used as path separator. The special
+// path components ".", ".." and ".git" (without quotes) are disallowed.
+// Trailing slash is also disallowed.
+//
+// The exact encoding is undefined, but the '.' and '/' characters
+// are encoded in 7-bit ASCII and the encoding cannot contain a NUL
+// byte (iow, this is a UNIX pathname).
+//
+// (Version 4) In version 4, the entry path name is prefix-compressed
+// relative to the path name for the previous entry (the very first
+// entry is encoded as if the path name for the previous entry is an
+// empty string). At the beginning of an entry, an integer N in the
+// variable width encoding (the same encoding as the offset is encoded
+// for OFS_DELTA pack entries; see pack-format.txt) is stored, followed
+// by a NUL-terminated string S. Removing N bytes from the end of the
+// path name for the previous entry, and replacing it with the string S
+// yields the path name for this entry.
+//
+// 1-8 nul bytes as necessary to pad the entry to a multiple of eight bytes
+// while keeping the name NUL-terminated.
+//
+// (Version 4) In version 4, the padding after the pathname does not
+// exist.
+//
+// Interpretation of index entries in split index mode is completely
+// different. See below for details.
+//
+// == Extensions
+//
+// === Cached tree
+//
+// Cached tree extension contains pre-computed hashes for trees that can
+// be derived from the index. It helps speed up tree object generation
+// from index for a new commit.
+//
+// When a path is updated in index, the path must be invalidated and
+// removed from tree cache.
+//
+// The signature for this extension is { 'T', 'R', 'E', 'E' }.
+//
+// A series of entries fill the entire extension; each of which
+// consists of:
+//
+// - NUL-terminated path component (relative to its parent directory);
+//
+// - ASCII decimal number of entries in the index that is covered by the
+// tree this entry represents (entry_count);
+//
+// - A space (ASCII 32);
+//
+// - ASCII decimal number that represents the number of subtrees this
+// tree has;
+//
+// - A newline (ASCII 10); and
+//
+// - 160-bit object name for the object that would result from writing
+// this span of index as a tree.
+//
+// An entry can be in an invalidated state and is represented by having
+// a negative number in the entry_count field. In this case, there is no
+// object name and the next entry starts immediately after the newline.
+// When writing an invalid entry, -1 should always be used as entry_count.
+//
+// The entries are written out in the top-down, depth-first order. The
+// first entry represents the root level of the repository, followed by the
+// first subtree--let's call this A--of the root level (with its name
+// relative to the root level), followed by the first subtree of A (with
+// its name relative to A), ...
+//
+// === Resolve undo
+//
+// A conflict is represented in the index as a set of higher stage entries.
+// When a conflict is resolved (e.g. with "git add path"), these higher
+// stage entries will be removed and a stage-0 entry with proper resolution
+// is added.
+//
+// When these higher stage entries are removed, they are saved in the
+// resolve undo extension, so that conflicts can be recreated (e.g. with
+// "git checkout -m"), in case users want to redo a conflict resolution
+// from scratch.
+//
+// The signature for this extension is { 'R', 'E', 'U', 'C' }.
+//
+// A series of entries fill the entire extension; each of which
+// consists of:
+//
+// - NUL-terminated pathname the entry describes (relative to the root of
+// the repository, i.e. full pathname);
+//
+// - Three NUL-terminated ASCII octal numbers, entry mode of entries in
+// stage 1 to 3 (a missing stage is represented by "0" in this field);
+// and
+//
+// - At most three 160-bit object names of the entry in stages from 1 to 3
+// (nothing is written for a missing stage).
+//
+// === Split index
+//
+// In split index mode, the majority of index entries could be stored
+// in a separate file. This extension records the changes to be made on
+// top of that to produce the final index.
+//
+// The signature for this extension is { 'l', 'i', 'n', 'k' }.
+//
+// The extension consists of:
+//
+// - 160-bit SHA-1 of the shared index file. The shared index file path
+// is $GIT_DIR/sharedindex.<SHA-1>. If all 160 bits are zero, the
+// index does not require a shared index file.
+//
+// - An ewah-encoded delete bitmap, each bit represents an entry in the
+// shared index. If a bit is set, its corresponding entry in the
+// shared index will be removed from the final index. Note, because
+// a delete operation changes index entry positions, but we do need
+// original positions in replace phase, it's best to just mark
+// entries for removal, then do a mass deletion after replacement.
+//
+// - An ewah-encoded replace bitmap, each bit represents an entry in
+// the shared index. If a bit is set, its corresponding entry in the
+// shared index will be replaced with an entry in this index
+// file. All replaced entries are stored in sorted order in this
+// index. The first "1" bit in the replace bitmap corresponds to the
+// first index entry, the second "1" bit to the second entry and so
+// on. Replaced entries may have empty path names to save space.
+//
+// The remaining index entries after replaced ones will be added to the
+// final index. These added entries are also sorted by entry name then
+// stage.
+//
+// == Untracked cache
+//
+// Untracked cache saves the untracked file list and necessary data to
+// verify the cache. The signature for this extension is { 'U', 'N',
+// 'T', 'R' }.
+//
+// The extension starts with
+//
+// - A sequence of NUL-terminated strings, preceded by the size of the
+// sequence in variable width encoding. Each string describes the
+// environment where the cache can be used.
+//
+// - Stat data of $GIT_DIR/info/exclude. See "Index entry" section from
+// ctime field until "file size".
+//
+// - Stat data of plumbing.excludesfile
+//
+// - 32-bit dir_flags (see struct dir_struct)
+//
+// - 160-bit SHA-1 of $GIT_DIR/info/exclude. Null SHA-1 means the file
+// does not exist.
+//
+// - 160-bit SHA-1 of plumbing.excludesfile. Null SHA-1 means the file does
+// not exist.
+//
+// - NUL-terminated string of per-dir exclude file name. This usually
+// is ".gitignore".
+//
+// - The number of following directory blocks, variable width
+// encoding. If this number is zero, the extension ends here with a
+// following NUL.
+//
+// - A number of directory blocks in depth-first-search order, each
+// consists of
+//
+// - The number of untracked entries, variable width encoding.
+//
+// - The number of sub-directory blocks, variable width encoding.
+//
+// - The directory name terminated by NUL.
+//
+// - A number of untracked file/dir names terminated by NUL.
+//
+// The remaining data of each directory block is grouped by type:
+//
+// - An ewah bitmap, the n-th bit marks whether the n-th directory has
+// valid untracked cache entries.
+//
+// - An ewah bitmap, the n-th bit records "check-only" bit of
+// read_directory_recursive() for the n-th directory.
+//
+// - An ewah bitmap, the n-th bit indicates whether SHA-1 and stat data
+// is valid for the n-th directory and exists in the next data.
+//
+// - An array of stat data. The n-th data corresponds with the n-th
+// "one" bit in the previous ewah bitmap.
+//
+// - An array of SHA-1. The n-th SHA-1 corresponds with the n-th "one" bit
+// in the previous ewah bitmap.
+//
+// - One NUL.
+//
+// == File System Monitor cache
+//
+// The file system monitor cache tracks files for which the core.fsmonitor
+// hook has told us about changes. The signature for this extension is
+// { 'F', 'S', 'M', 'N' }.
+//
+// The extension starts with
+//
+// - 32-bit version number: the current supported version is 1.
+//
+// - 64-bit time: the extension data reflects all changes through the given
+// time which is stored as the nanoseconds elapsed since midnight,
+// January 1, 1970.
+//
+// - 32-bit bitmap size: the size of the CE_FSMONITOR_VALID bitmap.
+//
+// - An ewah bitmap, the n-th bit indicates whether the n-th index entry
+// is not CE_FSMONITOR_VALID.
+//
+// == End of Index Entry
+//
+// The End of Index Entry (EOIE) is used to locate the end of the variable
+// length index entries and the begining of the extensions. Code can take
+// advantage of this to quickly locate the index extensions without having
+// to parse through all of the index entries.
+//
+// Because it must be able to be loaded before the variable length cache
+// entries and other index extensions, this extension must be written last.
+// The signature for this extension is { 'E', 'O', 'I', 'E' }.
+//
+// The extension consists of:
+//
+// - 32-bit offset to the end of the index entries
+//
+// - 160-bit SHA-1 over the extension types and their sizes (but not
+// their contents). E.g. if we have "TREE" extension that is N-bytes
+// long, "REUC" extension that is M-bytes long, followed by "EOIE",
+// then the hash would be:
+//
+// SHA-1("TREE" + <binary representation of N> +
+// "REUC" + <binary representation of M>)
+//
+// == Index Entry Offset Table
+//
+// The Index Entry Offset Table (IEOT) is used to help address the CPU
+// cost of loading the index by enabling multi-threading the process of
+// converting cache entries from the on-disk format to the in-memory format.
+// The signature for this extension is { 'I', 'E', 'O', 'T' }.
+//
+// The extension consists of:
+//
+// - 32-bit version (currently 1)
+//
+// - A number of index offset entries each consisting of:
+//
+// - 32-bit offset from the begining of the file to the first cache entry
+// in this block of entries.
+//
+// - 32-bit count of cache entries in this blockpackage index
+package index
--- /dev/null
+package index
+
+import (
+ "bytes"
+ "crypto/sha1"
+ "errors"
+ "hash"
+ "io"
+ "sort"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+var (
+ // EncodeVersionSupported is the range of supported index versions
+ EncodeVersionSupported uint32 = 2
+
+ // ErrInvalidTimestamp is returned by Encode if a Index with a Entry with
+ // negative timestamp values
+ ErrInvalidTimestamp = errors.New("negative timestamps are not allowed")
+)
+
+// An Encoder writes an Index to an output stream.
+type Encoder struct {
+ w io.Writer
+ hash hash.Hash
+}
+
+// NewEncoder returns a new encoder that writes to w.
+func NewEncoder(w io.Writer) *Encoder {
+ h := sha1.New()
+ mw := io.MultiWriter(w, h)
+ return &Encoder{mw, h}
+}
+
+// Encode writes the Index to the stream of the encoder.
+func (e *Encoder) Encode(idx *Index) error {
+ // TODO: support versions v3 and v4
+ // TODO: support extensions
+ if idx.Version != EncodeVersionSupported {
+ return ErrUnsupportedVersion
+ }
+
+ if err := e.encodeHeader(idx); err != nil {
+ return err
+ }
+
+ if err := e.encodeEntries(idx); err != nil {
+ return err
+ }
+
+ return e.encodeFooter()
+}
+
+func (e *Encoder) encodeHeader(idx *Index) error {
+ return binary.Write(e.w,
+ indexSignature,
+ idx.Version,
+ uint32(len(idx.Entries)),
+ )
+}
+
+func (e *Encoder) encodeEntries(idx *Index) error {
+ sort.Sort(byName(idx.Entries))
+
+ for _, entry := range idx.Entries {
+ if err := e.encodeEntry(entry); err != nil {
+ return err
+ }
+
+ wrote := entryHeaderLength + len(entry.Name)
+ if err := e.padEntry(wrote); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (e *Encoder) encodeEntry(entry *Entry) error {
+ if entry.IntentToAdd || entry.SkipWorktree {
+ return ErrUnsupportedVersion
+ }
+
+ sec, nsec, err := e.timeToUint32(&entry.CreatedAt)
+ if err != nil {
+ return err
+ }
+
+ msec, mnsec, err := e.timeToUint32(&entry.ModifiedAt)
+ if err != nil {
+ return err
+ }
+
+ flags := uint16(entry.Stage&0x3) << 12
+ if l := len(entry.Name); l < nameMask {
+ flags |= uint16(l)
+ } else {
+ flags |= nameMask
+ }
+
+ flow := []interface{}{
+ sec, nsec,
+ msec, mnsec,
+ entry.Dev,
+ entry.Inode,
+ entry.Mode,
+ entry.UID,
+ entry.GID,
+ entry.Size,
+ entry.Hash[:],
+ flags,
+ }
+
+ if err := binary.Write(e.w, flow...); err != nil {
+ return err
+ }
+
+ return binary.Write(e.w, []byte(entry.Name))
+}
+
+func (e *Encoder) timeToUint32(t *time.Time) (uint32, uint32, error) {
+ if t.IsZero() {
+ return 0, 0, nil
+ }
+
+ if t.Unix() < 0 || t.UnixNano() < 0 {
+ return 0, 0, ErrInvalidTimestamp
+ }
+
+ return uint32(t.Unix()), uint32(t.Nanosecond()), nil
+}
+
+func (e *Encoder) padEntry(wrote int) error {
+ padLen := 8 - wrote%8
+
+ _, err := e.w.Write(bytes.Repeat([]byte{'\x00'}, padLen))
+ return err
+}
+
+func (e *Encoder) encodeFooter() error {
+ return binary.Write(e.w, e.hash.Sum(nil))
+}
+
+type byName []*Entry
+
+func (l byName) Len() int { return len(l) }
+func (l byName) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
+func (l byName) Less(i, j int) bool { return l[i].Name < l[j].Name }
--- /dev/null
+package index
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "path/filepath"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+)
+
+var (
+ // ErrUnsupportedVersion is returned by Decode when the index file version
+ // is not supported.
+ ErrUnsupportedVersion = errors.New("unsupported version")
+ // ErrEntryNotFound is returned by Index.Entry, if an entry is not found.
+ ErrEntryNotFound = errors.New("entry not found")
+
+ indexSignature = []byte{'D', 'I', 'R', 'C'}
+ treeExtSignature = []byte{'T', 'R', 'E', 'E'}
+ resolveUndoExtSignature = []byte{'R', 'E', 'U', 'C'}
+ endOfIndexEntryExtSignature = []byte{'E', 'O', 'I', 'E'}
+)
+
+// Stage during merge
+type Stage int
+
+const (
+ // Merged is the default stage, fully merged
+ Merged Stage = 1
+ // AncestorMode is the base revision
+ AncestorMode Stage = 1
+ // OurMode is the first tree revision, ours
+ OurMode Stage = 2
+ // TheirMode is the second tree revision, theirs
+ TheirMode Stage = 3
+)
+
+// Index contains the information about which objects are currently checked out
+// in the worktree, having information about the working files. Changes in
+// worktree are detected using this Index. The Index is also used during merges
+type Index struct {
+ // Version is index version
+ Version uint32
+ // Entries collection of entries represented by this Index. The order of
+ // this collection is not guaranteed
+ Entries []*Entry
+ // Cache represents the 'Cached tree' extension
+ Cache *Tree
+ // ResolveUndo represents the 'Resolve undo' extension
+ ResolveUndo *ResolveUndo
+ // EndOfIndexEntry represents the 'End of Index Entry' extension
+ EndOfIndexEntry *EndOfIndexEntry
+}
+
+// Add creates a new Entry and returns it. The caller should first check that
+// another entry with the same path does not exist.
+func (i *Index) Add(path string) *Entry {
+ e := &Entry{
+ Name: filepath.ToSlash(path),
+ }
+
+ i.Entries = append(i.Entries, e)
+ return e
+}
+
+// Entry returns the entry that match the given path, if any.
+func (i *Index) Entry(path string) (*Entry, error) {
+ path = filepath.ToSlash(path)
+ for _, e := range i.Entries {
+ if e.Name == path {
+ return e, nil
+ }
+ }
+
+ return nil, ErrEntryNotFound
+}
+
+// Remove remove the entry that match the give path and returns deleted entry.
+func (i *Index) Remove(path string) (*Entry, error) {
+ path = filepath.ToSlash(path)
+ for index, e := range i.Entries {
+ if e.Name == path {
+ i.Entries = append(i.Entries[:index], i.Entries[index+1:]...)
+ return e, nil
+ }
+ }
+
+ return nil, ErrEntryNotFound
+}
+
+// Glob returns the all entries matching pattern or nil if there is no matching
+// entry. The syntax of patterns is the same as in filepath.Glob.
+func (i *Index) Glob(pattern string) (matches []*Entry, err error) {
+ pattern = filepath.ToSlash(pattern)
+ for _, e := range i.Entries {
+ m, err := match(pattern, e.Name)
+ if err != nil {
+ return nil, err
+ }
+
+ if m {
+ matches = append(matches, e)
+ }
+ }
+
+ return
+}
+
+// String is equivalent to `git ls-files --stage --debug`
+func (i *Index) String() string {
+ buf := bytes.NewBuffer(nil)
+ for _, e := range i.Entries {
+ buf.WriteString(e.String())
+ }
+
+ return buf.String()
+}
+
+// Entry represents a single file (or stage of a file) in the cache. An entry
+// represents exactly one stage of a file. If a file path is unmerged then
+// multiple Entry instances may appear for the same path name.
+type Entry struct {
+ // Hash is the SHA1 of the represented file
+ Hash plumbing.Hash
+ // Name is the Entry path name relative to top level directory
+ Name string
+ // CreatedAt time when the tracked path was created
+ CreatedAt time.Time
+ // ModifiedAt time when the tracked path was changed
+ ModifiedAt time.Time
+ // Dev and Inode of the tracked path
+ Dev, Inode uint32
+ // Mode of the path
+ Mode filemode.FileMode
+ // UID and GID, userid and group id of the owner
+ UID, GID uint32
+ // Size is the length in bytes for regular files
+ Size uint32
+ // Stage on a merge is defines what stage is representing this entry
+ // https://git-scm.com/book/en/v2/Git-Tools-Advanced-Merging
+ Stage Stage
+ // SkipWorktree used in sparse checkouts
+ // https://git-scm.com/docs/git-read-tree#_sparse_checkout
+ SkipWorktree bool
+ // IntentToAdd record only the fact that the path will be added later
+ // https://git-scm.com/docs/git-add ("git add -N")
+ IntentToAdd bool
+}
+
+func (e Entry) String() string {
+ buf := bytes.NewBuffer(nil)
+
+ fmt.Fprintf(buf, "%06o %s %d\t%s\n", e.Mode, e.Hash, e.Stage, e.Name)
+ fmt.Fprintf(buf, " ctime: %d:%d\n", e.CreatedAt.Unix(), e.CreatedAt.Nanosecond())
+ fmt.Fprintf(buf, " mtime: %d:%d\n", e.ModifiedAt.Unix(), e.ModifiedAt.Nanosecond())
+ fmt.Fprintf(buf, " dev: %d\tino: %d\n", e.Dev, e.Inode)
+ fmt.Fprintf(buf, " uid: %d\tgid: %d\n", e.UID, e.GID)
+ fmt.Fprintf(buf, " size: %d\tflags: %x\n", e.Size, 0)
+
+ return buf.String()
+}
+
+// Tree contains pre-computed hashes for trees that can be derived from the
+// index. It helps speed up tree object generation from index for a new commit.
+type Tree struct {
+ Entries []TreeEntry
+}
+
+// TreeEntry entry of a cached Tree
+type TreeEntry struct {
+ // Path component (relative to its parent directory)
+ Path string
+ // Entries is the number of entries in the index that is covered by the tree
+ // this entry represents.
+ Entries int
+ // Trees is the number that represents the number of subtrees this tree has
+ Trees int
+ // Hash object name for the object that would result from writing this span
+ // of index as a tree.
+ Hash plumbing.Hash
+}
+
+// ResolveUndo is used when a conflict is resolved (e.g. with "git add path"),
+// these higher stage entries are removed and a stage-0 entry with proper
+// resolution is added. When these higher stage entries are removed, they are
+// saved in the resolve undo extension.
+type ResolveUndo struct {
+ Entries []ResolveUndoEntry
+}
+
+// ResolveUndoEntry contains the information about a conflict when is resolved
+type ResolveUndoEntry struct {
+ Path string
+ Stages map[Stage]plumbing.Hash
+}
+
+// EndOfIndexEntry is the End of Index Entry (EOIE) is used to locate the end of
+// the variable length index entries and the begining of the extensions. Code
+// can take advantage of this to quickly locate the index extensions without
+// having to parse through all of the index entries.
+//
+// Because it must be able to be loaded before the variable length cache
+// entries and other index extensions, this extension must be written last.
+type EndOfIndexEntry struct {
+ // Offset to the end of the index entries
+ Offset uint32
+ // Hash is a SHA-1 over the extension types and their sizes (but not
+ // their contents).
+ Hash plumbing.Hash
+}
--- /dev/null
+package index
+
+import (
+ "path/filepath"
+ "runtime"
+ "unicode/utf8"
+)
+
+// match is filepath.Match with support to match fullpath and not only filenames
+// code from:
+// https://github.com/golang/go/blob/39852bf4cce6927e01d0136c7843f65a801738cb/src/path/filepath/match.go#L44-L224
+func match(pattern, name string) (matched bool, err error) {
+Pattern:
+ for len(pattern) > 0 {
+ var star bool
+ var chunk string
+ star, chunk, pattern = scanChunk(pattern)
+
+ // Look for match at current position.
+ t, ok, err := matchChunk(chunk, name)
+ // if we're the last chunk, make sure we've exhausted the name
+ // otherwise we'll give a false result even if we could still match
+ // using the star
+ if ok && (len(t) == 0 || len(pattern) > 0) {
+ name = t
+ continue
+ }
+ if err != nil {
+ return false, err
+ }
+ if star {
+ // Look for match skipping i+1 bytes.
+ // Cannot skip /.
+ for i := 0; i < len(name); i++ {
+ t, ok, err := matchChunk(chunk, name[i+1:])
+ if ok {
+ // if we're the last chunk, make sure we exhausted the name
+ if len(pattern) == 0 && len(t) > 0 {
+ continue
+ }
+ name = t
+ continue Pattern
+ }
+ if err != nil {
+ return false, err
+ }
+ }
+ }
+ return false, nil
+ }
+ return len(name) == 0, nil
+}
+
+// scanChunk gets the next segment of pattern, which is a non-star string
+// possibly preceded by a star.
+func scanChunk(pattern string) (star bool, chunk, rest string) {
+ for len(pattern) > 0 && pattern[0] == '*' {
+ pattern = pattern[1:]
+ star = true
+ }
+ inrange := false
+ var i int
+Scan:
+ for i = 0; i < len(pattern); i++ {
+ switch pattern[i] {
+ case '\\':
+ if runtime.GOOS != "windows" {
+ // error check handled in matchChunk: bad pattern.
+ if i+1 < len(pattern) {
+ i++
+ }
+ }
+ case '[':
+ inrange = true
+ case ']':
+ inrange = false
+ case '*':
+ if !inrange {
+ break Scan
+ }
+ }
+ }
+ return star, pattern[0:i], pattern[i:]
+}
+
+// matchChunk checks whether chunk matches the beginning of s.
+// If so, it returns the remainder of s (after the match).
+// Chunk is all single-character operators: literals, char classes, and ?.
+func matchChunk(chunk, s string) (rest string, ok bool, err error) {
+ for len(chunk) > 0 {
+ if len(s) == 0 {
+ return
+ }
+ switch chunk[0] {
+ case '[':
+ // character class
+ r, n := utf8.DecodeRuneInString(s)
+ s = s[n:]
+ chunk = chunk[1:]
+ // We can't end right after '[', we're expecting at least
+ // a closing bracket and possibly a caret.
+ if len(chunk) == 0 {
+ err = filepath.ErrBadPattern
+ return
+ }
+ // possibly negated
+ negated := chunk[0] == '^'
+ if negated {
+ chunk = chunk[1:]
+ }
+ // parse all ranges
+ match := false
+ nrange := 0
+ for {
+ if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 {
+ chunk = chunk[1:]
+ break
+ }
+ var lo, hi rune
+ if lo, chunk, err = getEsc(chunk); err != nil {
+ return
+ }
+ hi = lo
+ if chunk[0] == '-' {
+ if hi, chunk, err = getEsc(chunk[1:]); err != nil {
+ return
+ }
+ }
+ if lo <= r && r <= hi {
+ match = true
+ }
+ nrange++
+ }
+ if match == negated {
+ return
+ }
+
+ case '?':
+ _, n := utf8.DecodeRuneInString(s)
+ s = s[n:]
+ chunk = chunk[1:]
+
+ case '\\':
+ if runtime.GOOS != "windows" {
+ chunk = chunk[1:]
+ if len(chunk) == 0 {
+ err = filepath.ErrBadPattern
+ return
+ }
+ }
+ fallthrough
+
+ default:
+ if chunk[0] != s[0] {
+ return
+ }
+ s = s[1:]
+ chunk = chunk[1:]
+ }
+ }
+ return s, true, nil
+}
+
+// getEsc gets a possibly-escaped character from chunk, for a character class.
+func getEsc(chunk string) (r rune, nchunk string, err error) {
+ if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' {
+ err = filepath.ErrBadPattern
+ return
+ }
+ if chunk[0] == '\\' && runtime.GOOS != "windows" {
+ chunk = chunk[1:]
+ if len(chunk) == 0 {
+ err = filepath.ErrBadPattern
+ return
+ }
+ }
+ r, n := utf8.DecodeRuneInString(chunk)
+ if r == utf8.RuneError && n == 1 {
+ err = filepath.ErrBadPattern
+ }
+ nchunk = chunk[n:]
+ if len(nchunk) == 0 {
+ err = filepath.ErrBadPattern
+ }
+ return
+}
--- /dev/null
+// Package objfile implements encoding and decoding of object files.
+package objfile
--- /dev/null
+package objfile
+
+import (
+ "compress/zlib"
+ "errors"
+ "io"
+ "strconv"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/packfile"
+)
+
+var (
+ ErrClosed = errors.New("objfile: already closed")
+ ErrHeader = errors.New("objfile: invalid header")
+ ErrNegativeSize = errors.New("objfile: negative object size")
+)
+
+// Reader reads and decodes compressed objfile data from a provided io.Reader.
+// Reader implements io.ReadCloser. Close should be called when finished with
+// the Reader. Close will not close the underlying io.Reader.
+type Reader struct {
+ multi io.Reader
+ zlib io.ReadCloser
+ hasher plumbing.Hasher
+}
+
+// NewReader returns a new Reader reading from r.
+func NewReader(r io.Reader) (*Reader, error) {
+ zlib, err := zlib.NewReader(r)
+ if err != nil {
+ return nil, packfile.ErrZLib.AddDetails(err.Error())
+ }
+
+ return &Reader{
+ zlib: zlib,
+ }, nil
+}
+
+// Header reads the type and the size of object, and prepares the reader for read
+func (r *Reader) Header() (t plumbing.ObjectType, size int64, err error) {
+ var raw []byte
+ raw, err = r.readUntil(' ')
+ if err != nil {
+ return
+ }
+
+ t, err = plumbing.ParseObjectType(string(raw))
+ if err != nil {
+ return
+ }
+
+ raw, err = r.readUntil(0)
+ if err != nil {
+ return
+ }
+
+ size, err = strconv.ParseInt(string(raw), 10, 64)
+ if err != nil {
+ err = ErrHeader
+ return
+ }
+
+ defer r.prepareForRead(t, size)
+ return
+}
+
+// readSlice reads one byte at a time from r until it encounters delim or an
+// error.
+func (r *Reader) readUntil(delim byte) ([]byte, error) {
+ var buf [1]byte
+ value := make([]byte, 0, 16)
+ for {
+ if n, err := r.zlib.Read(buf[:]); err != nil && (err != io.EOF || n == 0) {
+ if err == io.EOF {
+ return nil, ErrHeader
+ }
+ return nil, err
+ }
+
+ if buf[0] == delim {
+ return value, nil
+ }
+
+ value = append(value, buf[0])
+ }
+}
+
+func (r *Reader) prepareForRead(t plumbing.ObjectType, size int64) {
+ r.hasher = plumbing.NewHasher(t, size)
+ r.multi = io.TeeReader(r.zlib, r.hasher)
+}
+
+// Read reads len(p) bytes into p from the object data stream. It returns
+// the number of bytes read (0 <= n <= len(p)) and any error encountered. Even
+// if Read returns n < len(p), it may use all of p as scratch space during the
+// call.
+//
+// If Read encounters the end of the data stream it will return err == io.EOF,
+// either in the current call if n > 0 or in a subsequent call.
+func (r *Reader) Read(p []byte) (n int, err error) {
+ return r.multi.Read(p)
+}
+
+// Hash returns the hash of the object data stream that has been read so far.
+func (r *Reader) Hash() plumbing.Hash {
+ return r.hasher.Sum()
+}
+
+// Close releases any resources consumed by the Reader. Calling Close does not
+// close the wrapped io.Reader originally passed to NewReader.
+func (r *Reader) Close() error {
+ return r.zlib.Close()
+}
--- /dev/null
+package objfile
+
+import (
+ "compress/zlib"
+ "errors"
+ "io"
+ "strconv"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+var (
+ ErrOverflow = errors.New("objfile: declared data length exceeded (overflow)")
+)
+
+// Writer writes and encodes data in compressed objfile format to a provided
+// io.Writer. Close should be called when finished with the Writer. Close will
+// not close the underlying io.Writer.
+type Writer struct {
+ raw io.Writer
+ zlib io.WriteCloser
+ hasher plumbing.Hasher
+ multi io.Writer
+
+ closed bool
+ pending int64 // number of unwritten bytes
+}
+
+// NewWriter returns a new Writer writing to w.
+//
+// The returned Writer implements io.WriteCloser. Close should be called when
+// finished with the Writer. Close will not close the underlying io.Writer.
+func NewWriter(w io.Writer) *Writer {
+ return &Writer{
+ raw: w,
+ zlib: zlib.NewWriter(w),
+ }
+}
+
+// WriteHeader writes the type and the size and prepares to accept the object's
+// contents. If an invalid t is provided, plumbing.ErrInvalidType is returned. If a
+// negative size is provided, ErrNegativeSize is returned.
+func (w *Writer) WriteHeader(t plumbing.ObjectType, size int64) error {
+ if !t.Valid() {
+ return plumbing.ErrInvalidType
+ }
+ if size < 0 {
+ return ErrNegativeSize
+ }
+
+ b := t.Bytes()
+ b = append(b, ' ')
+ b = append(b, []byte(strconv.FormatInt(size, 10))...)
+ b = append(b, 0)
+
+ defer w.prepareForWrite(t, size)
+ _, err := w.zlib.Write(b)
+
+ return err
+}
+
+func (w *Writer) prepareForWrite(t plumbing.ObjectType, size int64) {
+ w.pending = size
+
+ w.hasher = plumbing.NewHasher(t, size)
+ w.multi = io.MultiWriter(w.zlib, w.hasher)
+}
+
+// Write writes the object's contents. Write returns the error ErrOverflow if
+// more than size bytes are written after WriteHeader.
+func (w *Writer) Write(p []byte) (n int, err error) {
+ if w.closed {
+ return 0, ErrClosed
+ }
+
+ overwrite := false
+ if int64(len(p)) > w.pending {
+ p = p[0:w.pending]
+ overwrite = true
+ }
+
+ n, err = w.multi.Write(p)
+ w.pending -= int64(n)
+ if err == nil && overwrite {
+ err = ErrOverflow
+ return
+ }
+
+ return
+}
+
+// Hash returns the hash of the object data stream that has been written so far.
+// It can be called before or after Close.
+func (w *Writer) Hash() plumbing.Hash {
+ return w.hasher.Sum() // Not yet closed, return hash of data written so far
+}
+
+// Close releases any resources consumed by the Writer.
+//
+// Calling Close does not close the wrapped io.Writer originally passed to
+// NewWriter.
+func (w *Writer) Close() error {
+ if err := w.zlib.Close(); err != nil {
+ return err
+ }
+
+ w.closed = true
+ return nil
+}
--- /dev/null
+package packfile
+
+import (
+ "bytes"
+ "io"
+ "sync"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+var signature = []byte{'P', 'A', 'C', 'K'}
+
+const (
+ // VersionSupported is the packfile version supported by this package
+ VersionSupported uint32 = 2
+
+ firstLengthBits = uint8(4) // the first byte into object header has 4 bits to store the length
+ lengthBits = uint8(7) // subsequent bytes has 7 bits to store the length
+ maskFirstLength = 15 // 0000 1111
+ maskContinue = 0x80 // 1000 0000
+ maskLength = uint8(127) // 0111 1111
+ maskType = uint8(112) // 0111 0000
+)
+
+// UpdateObjectStorage updates the storer with the objects in the given
+// packfile.
+func UpdateObjectStorage(s storer.Storer, packfile io.Reader) error {
+ if pw, ok := s.(storer.PackfileWriter); ok {
+ return WritePackfileToObjectStorage(pw, packfile)
+ }
+
+ p, err := NewParserWithStorage(NewScanner(packfile), s)
+ if err != nil {
+ return err
+ }
+
+ _, err = p.Parse()
+ return err
+}
+
+// WritePackfileToObjectStorage writes all the packfile objects into the given
+// object storage.
+func WritePackfileToObjectStorage(
+ sw storer.PackfileWriter,
+ packfile io.Reader,
+) (err error) {
+ w, err := sw.PackfileWriter()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(w, &err)
+ _, err = io.Copy(w, packfile)
+ return err
+}
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
--- /dev/null
+package packfile
+
+const blksz = 16
+const maxChainLength = 64
+
+// deltaIndex is a modified version of JGit's DeltaIndex adapted to our current
+// design.
+type deltaIndex struct {
+ table []int
+ entries []int
+ mask int
+}
+
+func (idx *deltaIndex) init(buf []byte) {
+ scanner := newDeltaIndexScanner(buf, len(buf))
+ idx.mask = scanner.mask
+ idx.table = scanner.table
+ idx.entries = make([]int, countEntries(scanner)+1)
+ idx.copyEntries(scanner)
+}
+
+// findMatch returns the offset of src where the block starting at tgtOffset
+// is and the length of the match. A length of 0 means there was no match. A
+// length of -1 means the src length is lower than the blksz and whatever
+// other positive length is the length of the match in bytes.
+func (idx *deltaIndex) findMatch(src, tgt []byte, tgtOffset int) (srcOffset, l int) {
+ if len(tgt) < tgtOffset+s {
+ return 0, len(tgt) - tgtOffset
+ }
+
+ if len(src) < blksz {
+ return 0, -1
+ }
+
+ if len(tgt) >= tgtOffset+s && len(src) >= blksz {
+ h := hashBlock(tgt, tgtOffset)
+ tIdx := h & idx.mask
+ eIdx := idx.table[tIdx]
+ if eIdx != 0 {
+ srcOffset = idx.entries[eIdx]
+ } else {
+ return
+ }
+
+ l = matchLength(src, tgt, tgtOffset, srcOffset)
+ }
+
+ return
+}
+
+func matchLength(src, tgt []byte, otgt, osrc int) (l int) {
+ lensrc := len(src)
+ lentgt := len(tgt)
+ for (osrc < lensrc && otgt < lentgt) && src[osrc] == tgt[otgt] {
+ l++
+ osrc++
+ otgt++
+ }
+ return
+}
+
+func countEntries(scan *deltaIndexScanner) (cnt int) {
+ // Figure out exactly how many entries we need. As we do the
+ // enumeration truncate any delta chains longer than what we
+ // are willing to scan during encode. This keeps the encode
+ // logic linear in the size of the input rather than quadratic.
+ for i := 0; i < len(scan.table); i++ {
+ h := scan.table[i]
+ if h == 0 {
+ continue
+ }
+
+ size := 0
+ for {
+ size++
+ if size == maxChainLength {
+ scan.next[h] = 0
+ break
+ }
+ h = scan.next[h]
+
+ if h == 0 {
+ break
+ }
+ }
+ cnt += size
+ }
+
+ return
+}
+
+func (idx *deltaIndex) copyEntries(scanner *deltaIndexScanner) {
+ // Rebuild the entries list from the scanner, positioning all
+ // blocks in the same hash chain next to each other. We can
+ // then later discard the next list, along with the scanner.
+ //
+ next := 1
+ for i := 0; i < len(idx.table); i++ {
+ h := idx.table[i]
+ if h == 0 {
+ continue
+ }
+
+ idx.table[i] = next
+ for {
+ idx.entries[next] = scanner.entries[h]
+ next++
+ h = scanner.next[h]
+
+ if h == 0 {
+ break
+ }
+ }
+ }
+}
+
+type deltaIndexScanner struct {
+ table []int
+ entries []int
+ next []int
+ mask int
+ count int
+}
+
+func newDeltaIndexScanner(buf []byte, size int) *deltaIndexScanner {
+ size -= size % blksz
+ worstCaseBlockCnt := size / blksz
+ if worstCaseBlockCnt < 1 {
+ return new(deltaIndexScanner)
+ }
+
+ tableSize := tableSize(worstCaseBlockCnt)
+ scanner := &deltaIndexScanner{
+ table: make([]int, tableSize),
+ mask: tableSize - 1,
+ entries: make([]int, worstCaseBlockCnt+1),
+ next: make([]int, worstCaseBlockCnt+1),
+ }
+
+ scanner.scan(buf, size)
+ return scanner
+}
+
+// slightly modified version of JGit's DeltaIndexScanner. We store the offset on the entries
+// instead of the entries and the key, so we avoid operations to retrieve the offset later, as
+// we don't use the key.
+// See: https://github.com/eclipse/jgit/blob/005e5feb4ecd08c4e4d141a38b9e7942accb3212/org.eclipse.jgit/src/org/eclipse/jgit/internal/storage/pack/DeltaIndexScanner.java
+func (s *deltaIndexScanner) scan(buf []byte, end int) {
+ lastHash := 0
+ ptr := end - blksz
+
+ for {
+ key := hashBlock(buf, ptr)
+ tIdx := key & s.mask
+ head := s.table[tIdx]
+ if head != 0 && lastHash == key {
+ s.entries[head] = ptr
+ } else {
+ s.count++
+ eIdx := s.count
+ s.entries[eIdx] = ptr
+ s.next[eIdx] = head
+ s.table[tIdx] = eIdx
+ }
+
+ lastHash = key
+ ptr -= blksz
+
+ if 0 > ptr {
+ break
+ }
+ }
+}
+
+func tableSize(worstCaseBlockCnt int) int {
+ shift := 32 - leadingZeros(uint32(worstCaseBlockCnt))
+ sz := 1 << uint(shift-1)
+ if sz < worstCaseBlockCnt {
+ sz <<= 1
+ }
+ return sz
+}
+
+// use https://golang.org/pkg/math/bits/#LeadingZeros32 in the future
+func leadingZeros(x uint32) (n int) {
+ if x >= 1<<16 {
+ x >>= 16
+ n = 16
+ }
+ if x >= 1<<8 {
+ x >>= 8
+ n += 8
+ }
+ n += int(len8tab[x])
+ return 32 - n
+}
+
+var len8tab = [256]uint8{
+ 0x00, 0x01, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
+ 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05,
+ 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+ 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+ 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
+ 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
+ 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
+ 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
+}
+
+func hashBlock(raw []byte, ptr int) int {
+ // The first 4 steps collapse out into a 4 byte big-endian decode,
+ // with a larger right shift as we combined shift lefts together.
+ //
+ hash := ((uint32(raw[ptr]) & 0xff) << 24) |
+ ((uint32(raw[ptr+1]) & 0xff) << 16) |
+ ((uint32(raw[ptr+2]) & 0xff) << 8) |
+ (uint32(raw[ptr+3]) & 0xff)
+ hash ^= T[hash>>31]
+
+ hash = ((hash << 8) | (uint32(raw[ptr+4]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+5]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+6]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+7]) & 0xff)) ^ T[hash>>23]
+
+ hash = ((hash << 8) | (uint32(raw[ptr+8]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+9]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+10]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+11]) & 0xff)) ^ T[hash>>23]
+
+ hash = ((hash << 8) | (uint32(raw[ptr+12]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+13]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+14]) & 0xff)) ^ T[hash>>23]
+ hash = ((hash << 8) | (uint32(raw[ptr+15]) & 0xff)) ^ T[hash>>23]
+
+ return int(hash)
+}
+
+var T = []uint32{0x00000000, 0xd4c6b32d, 0x7d4bd577,
+ 0xa98d665a, 0x2e5119c3, 0xfa97aaee, 0x531accb4, 0x87dc7f99,
+ 0x5ca23386, 0x886480ab, 0x21e9e6f1, 0xf52f55dc, 0x72f32a45,
+ 0xa6359968, 0x0fb8ff32, 0xdb7e4c1f, 0x6d82d421, 0xb944670c,
+ 0x10c90156, 0xc40fb27b, 0x43d3cde2, 0x97157ecf, 0x3e981895,
+ 0xea5eabb8, 0x3120e7a7, 0xe5e6548a, 0x4c6b32d0, 0x98ad81fd,
+ 0x1f71fe64, 0xcbb74d49, 0x623a2b13, 0xb6fc983e, 0x0fc31b6f,
+ 0xdb05a842, 0x7288ce18, 0xa64e7d35, 0x219202ac, 0xf554b181,
+ 0x5cd9d7db, 0x881f64f6, 0x536128e9, 0x87a79bc4, 0x2e2afd9e,
+ 0xfaec4eb3, 0x7d30312a, 0xa9f68207, 0x007be45d, 0xd4bd5770,
+ 0x6241cf4e, 0xb6877c63, 0x1f0a1a39, 0xcbcca914, 0x4c10d68d,
+ 0x98d665a0, 0x315b03fa, 0xe59db0d7, 0x3ee3fcc8, 0xea254fe5,
+ 0x43a829bf, 0x976e9a92, 0x10b2e50b, 0xc4745626, 0x6df9307c,
+ 0xb93f8351, 0x1f8636de, 0xcb4085f3, 0x62cde3a9, 0xb60b5084,
+ 0x31d72f1d, 0xe5119c30, 0x4c9cfa6a, 0x985a4947, 0x43240558,
+ 0x97e2b675, 0x3e6fd02f, 0xeaa96302, 0x6d751c9b, 0xb9b3afb6,
+ 0x103ec9ec, 0xc4f87ac1, 0x7204e2ff, 0xa6c251d2, 0x0f4f3788,
+ 0xdb8984a5, 0x5c55fb3c, 0x88934811, 0x211e2e4b, 0xf5d89d66,
+ 0x2ea6d179, 0xfa606254, 0x53ed040e, 0x872bb723, 0x00f7c8ba,
+ 0xd4317b97, 0x7dbc1dcd, 0xa97aaee0, 0x10452db1, 0xc4839e9c,
+ 0x6d0ef8c6, 0xb9c84beb, 0x3e143472, 0xead2875f, 0x435fe105,
+ 0x97995228, 0x4ce71e37, 0x9821ad1a, 0x31accb40, 0xe56a786d,
+ 0x62b607f4, 0xb670b4d9, 0x1ffdd283, 0xcb3b61ae, 0x7dc7f990,
+ 0xa9014abd, 0x008c2ce7, 0xd44a9fca, 0x5396e053, 0x8750537e,
+ 0x2edd3524, 0xfa1b8609, 0x2165ca16, 0xf5a3793b, 0x5c2e1f61,
+ 0x88e8ac4c, 0x0f34d3d5, 0xdbf260f8, 0x727f06a2, 0xa6b9b58f,
+ 0x3f0c6dbc, 0xebcade91, 0x4247b8cb, 0x96810be6, 0x115d747f,
+ 0xc59bc752, 0x6c16a108, 0xb8d01225, 0x63ae5e3a, 0xb768ed17,
+ 0x1ee58b4d, 0xca233860, 0x4dff47f9, 0x9939f4d4, 0x30b4928e,
+ 0xe47221a3, 0x528eb99d, 0x86480ab0, 0x2fc56cea, 0xfb03dfc7,
+ 0x7cdfa05e, 0xa8191373, 0x01947529, 0xd552c604, 0x0e2c8a1b,
+ 0xdaea3936, 0x73675f6c, 0xa7a1ec41, 0x207d93d8, 0xf4bb20f5,
+ 0x5d3646af, 0x89f0f582, 0x30cf76d3, 0xe409c5fe, 0x4d84a3a4,
+ 0x99421089, 0x1e9e6f10, 0xca58dc3d, 0x63d5ba67, 0xb713094a,
+ 0x6c6d4555, 0xb8abf678, 0x11269022, 0xc5e0230f, 0x423c5c96,
+ 0x96faefbb, 0x3f7789e1, 0xebb13acc, 0x5d4da2f2, 0x898b11df,
+ 0x20067785, 0xf4c0c4a8, 0x731cbb31, 0xa7da081c, 0x0e576e46,
+ 0xda91dd6b, 0x01ef9174, 0xd5292259, 0x7ca44403, 0xa862f72e,
+ 0x2fbe88b7, 0xfb783b9a, 0x52f55dc0, 0x8633eeed, 0x208a5b62,
+ 0xf44ce84f, 0x5dc18e15, 0x89073d38, 0x0edb42a1, 0xda1df18c,
+ 0x739097d6, 0xa75624fb, 0x7c2868e4, 0xa8eedbc9, 0x0163bd93,
+ 0xd5a50ebe, 0x52797127, 0x86bfc20a, 0x2f32a450, 0xfbf4177d,
+ 0x4d088f43, 0x99ce3c6e, 0x30435a34, 0xe485e919, 0x63599680,
+ 0xb79f25ad, 0x1e1243f7, 0xcad4f0da, 0x11aabcc5, 0xc56c0fe8,
+ 0x6ce169b2, 0xb827da9f, 0x3ffba506, 0xeb3d162b, 0x42b07071,
+ 0x9676c35c, 0x2f49400d, 0xfb8ff320, 0x5202957a, 0x86c42657,
+ 0x011859ce, 0xd5deeae3, 0x7c538cb9, 0xa8953f94, 0x73eb738b,
+ 0xa72dc0a6, 0x0ea0a6fc, 0xda6615d1, 0x5dba6a48, 0x897cd965,
+ 0x20f1bf3f, 0xf4370c12, 0x42cb942c, 0x960d2701, 0x3f80415b,
+ 0xeb46f276, 0x6c9a8def, 0xb85c3ec2, 0x11d15898, 0xc517ebb5,
+ 0x1e69a7aa, 0xcaaf1487, 0x632272dd, 0xb7e4c1f0, 0x3038be69,
+ 0xe4fe0d44, 0x4d736b1e, 0x99b5d833,
+}
--- /dev/null
+package packfile
+
+import (
+ "sort"
+ "sync"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+const (
+ // deltas based on deltas, how many steps we can do.
+ // 50 is the default value used in JGit
+ maxDepth = int64(50)
+)
+
+// applyDelta is the set of object types that we should apply deltas
+var applyDelta = map[plumbing.ObjectType]bool{
+ plumbing.BlobObject: true,
+ plumbing.TreeObject: true,
+}
+
+type deltaSelector struct {
+ storer storer.EncodedObjectStorer
+}
+
+func newDeltaSelector(s storer.EncodedObjectStorer) *deltaSelector {
+ return &deltaSelector{s}
+}
+
+// ObjectsToPack creates a list of ObjectToPack from the hashes
+// provided, creating deltas if it's suitable, using an specific
+// internal logic. `packWindow` specifies the size of the sliding
+// window used to compare objects for delta compression; 0 turns off
+// delta compression entirely.
+func (dw *deltaSelector) ObjectsToPack(
+ hashes []plumbing.Hash,
+ packWindow uint,
+) ([]*ObjectToPack, error) {
+ otp, err := dw.objectsToPack(hashes, packWindow)
+ if err != nil {
+ return nil, err
+ }
+
+ if packWindow == 0 {
+ return otp, nil
+ }
+
+ dw.sort(otp)
+
+ var objectGroups [][]*ObjectToPack
+ var prev *ObjectToPack
+ i := -1
+ for _, obj := range otp {
+ if prev == nil || prev.Type() != obj.Type() {
+ objectGroups = append(objectGroups, []*ObjectToPack{obj})
+ i++
+ prev = obj
+ } else {
+ objectGroups[i] = append(objectGroups[i], obj)
+ }
+ }
+
+ var wg sync.WaitGroup
+ var once sync.Once
+ for _, objs := range objectGroups {
+ objs := objs
+ wg.Add(1)
+ go func() {
+ if walkErr := dw.walk(objs, packWindow); walkErr != nil {
+ once.Do(func() {
+ err = walkErr
+ })
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ if err != nil {
+ return nil, err
+ }
+
+ return otp, nil
+}
+
+func (dw *deltaSelector) objectsToPack(
+ hashes []plumbing.Hash,
+ packWindow uint,
+) ([]*ObjectToPack, error) {
+ var objectsToPack []*ObjectToPack
+ for _, h := range hashes {
+ var o plumbing.EncodedObject
+ var err error
+ if packWindow == 0 {
+ o, err = dw.encodedObject(h)
+ } else {
+ o, err = dw.encodedDeltaObject(h)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ otp := newObjectToPack(o)
+ if _, ok := o.(plumbing.DeltaObject); ok {
+ otp.CleanOriginal()
+ }
+
+ objectsToPack = append(objectsToPack, otp)
+ }
+
+ if packWindow == 0 {
+ return objectsToPack, nil
+ }
+
+ if err := dw.fixAndBreakChains(objectsToPack); err != nil {
+ return nil, err
+ }
+
+ return objectsToPack, nil
+}
+
+func (dw *deltaSelector) encodedDeltaObject(h plumbing.Hash) (plumbing.EncodedObject, error) {
+ edos, ok := dw.storer.(storer.DeltaObjectStorer)
+ if !ok {
+ return dw.encodedObject(h)
+ }
+
+ return edos.DeltaObject(plumbing.AnyObject, h)
+}
+
+func (dw *deltaSelector) encodedObject(h plumbing.Hash) (plumbing.EncodedObject, error) {
+ return dw.storer.EncodedObject(plumbing.AnyObject, h)
+}
+
+func (dw *deltaSelector) fixAndBreakChains(objectsToPack []*ObjectToPack) error {
+ m := make(map[plumbing.Hash]*ObjectToPack, len(objectsToPack))
+ for _, otp := range objectsToPack {
+ m[otp.Hash()] = otp
+ }
+
+ for _, otp := range objectsToPack {
+ if err := dw.fixAndBreakChainsOne(m, otp); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (dw *deltaSelector) fixAndBreakChainsOne(objectsToPack map[plumbing.Hash]*ObjectToPack, otp *ObjectToPack) error {
+ if !otp.Object.Type().IsDelta() {
+ return nil
+ }
+
+ // Initial ObjectToPack instances might have a delta assigned to Object
+ // but no actual base initially. Once Base is assigned to a delta, it means
+ // we already fixed it.
+ if otp.Base != nil {
+ return nil
+ }
+
+ do, ok := otp.Object.(plumbing.DeltaObject)
+ if !ok {
+ // if this is not a DeltaObject, then we cannot retrieve its base,
+ // so we have to break the delta chain here.
+ return dw.undeltify(otp)
+ }
+
+ base, ok := objectsToPack[do.BaseHash()]
+ if !ok {
+ // The base of the delta is not in our list of objects to pack, so
+ // we break the chain.
+ return dw.undeltify(otp)
+ }
+
+ if err := dw.fixAndBreakChainsOne(objectsToPack, base); err != nil {
+ return err
+ }
+
+ otp.SetDelta(base, otp.Object)
+ return nil
+}
+
+func (dw *deltaSelector) restoreOriginal(otp *ObjectToPack) error {
+ if otp.Original != nil {
+ return nil
+ }
+
+ if !otp.Object.Type().IsDelta() {
+ return nil
+ }
+
+ obj, err := dw.encodedObject(otp.Hash())
+ if err != nil {
+ return err
+ }
+
+ otp.SetOriginal(obj)
+
+ return nil
+}
+
+// undeltify undeltifies an *ObjectToPack by retrieving the original object from
+// the storer and resetting it.
+func (dw *deltaSelector) undeltify(otp *ObjectToPack) error {
+ if err := dw.restoreOriginal(otp); err != nil {
+ return err
+ }
+
+ otp.Object = otp.Original
+ otp.Depth = 0
+ return nil
+}
+
+func (dw *deltaSelector) sort(objectsToPack []*ObjectToPack) {
+ sort.Sort(byTypeAndSize(objectsToPack))
+}
+
+func (dw *deltaSelector) walk(
+ objectsToPack []*ObjectToPack,
+ packWindow uint,
+) error {
+ indexMap := make(map[plumbing.Hash]*deltaIndex)
+ for i := 0; i < len(objectsToPack); i++ {
+ // Clean up the index map and reconstructed delta objects for anything
+ // outside our pack window, to save memory.
+ if i > int(packWindow) {
+ obj := objectsToPack[i-int(packWindow)]
+
+ delete(indexMap, obj.Hash())
+
+ if obj.IsDelta() {
+ obj.SaveOriginalMetadata()
+ obj.CleanOriginal()
+ }
+ }
+
+ target := objectsToPack[i]
+
+ // If we already have a delta, we don't try to find a new one for this
+ // object. This happens when a delta is set to be reused from an existing
+ // packfile.
+ if target.IsDelta() {
+ continue
+ }
+
+ // We only want to create deltas from specific types.
+ if !applyDelta[target.Type()] {
+ continue
+ }
+
+ for j := i - 1; j >= 0 && i-j < int(packWindow); j-- {
+ base := objectsToPack[j]
+ // Objects must use only the same type as their delta base.
+ // Since objectsToPack is sorted by type and size, once we find
+ // a different type, we know we won't find more of them.
+ if base.Type() != target.Type() {
+ break
+ }
+
+ if err := dw.tryToDeltify(indexMap, base, target); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (dw *deltaSelector) tryToDeltify(indexMap map[plumbing.Hash]*deltaIndex, base, target *ObjectToPack) error {
+ // Original object might not be present if we're reusing a delta, so we
+ // ensure it is restored.
+ if err := dw.restoreOriginal(target); err != nil {
+ return err
+ }
+
+ if err := dw.restoreOriginal(base); err != nil {
+ return err
+ }
+
+ // If the sizes are radically different, this is a bad pairing.
+ if target.Size() < base.Size()>>4 {
+ return nil
+ }
+
+ msz := dw.deltaSizeLimit(
+ target.Object.Size(),
+ base.Depth,
+ target.Depth,
+ target.IsDelta(),
+ )
+
+ // Nearly impossible to fit useful delta.
+ if msz <= 8 {
+ return nil
+ }
+
+ // If we have to insert a lot to make this work, find another.
+ if base.Size()-target.Size() > msz {
+ return nil
+ }
+
+ if _, ok := indexMap[base.Hash()]; !ok {
+ indexMap[base.Hash()] = new(deltaIndex)
+ }
+
+ // Now we can generate the delta using originals
+ delta, err := getDelta(indexMap[base.Hash()], base.Original, target.Original)
+ if err != nil {
+ return err
+ }
+
+ // if delta better than target
+ if delta.Size() < msz {
+ target.SetDelta(base, delta)
+ }
+
+ return nil
+}
+
+func (dw *deltaSelector) deltaSizeLimit(targetSize int64, baseDepth int,
+ targetDepth int, targetDelta bool) int64 {
+ if !targetDelta {
+ // Any delta should be no more than 50% of the original size
+ // (for text files deflate of whole form should shrink 50%).
+ n := targetSize >> 1
+
+ // Evenly distribute delta size limits over allowed depth.
+ // If src is non-delta (depth = 0), delta <= 50% of original.
+ // If src is almost at limit (9/10), delta <= 10% of original.
+ return n * (maxDepth - int64(baseDepth)) / maxDepth
+ }
+
+ // With a delta base chosen any new delta must be "better".
+ // Retain the distribution described above.
+ d := int64(targetDepth)
+ n := targetSize
+
+ // If target depth is bigger than maxDepth, this delta is not suitable to be used.
+ if d >= maxDepth {
+ return 0
+ }
+
+ // If src is whole (depth=0) and base is near limit (depth=9/10)
+ // any delta using src can be 10x larger and still be better.
+ //
+ // If src is near limit (depth=9/10) and base is whole (depth=0)
+ // a new delta dependent on src must be 1/10th the size.
+ return n * (maxDepth - int64(baseDepth)) / (maxDepth - d)
+}
+
+type byTypeAndSize []*ObjectToPack
+
+func (a byTypeAndSize) Len() int { return len(a) }
+
+func (a byTypeAndSize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+
+func (a byTypeAndSize) Less(i, j int) bool {
+ if a[i].Type() < a[j].Type() {
+ return false
+ }
+
+ if a[i].Type() > a[j].Type() {
+ return true
+ }
+
+ return a[i].Size() > a[j].Size()
+}
--- /dev/null
+package packfile
+
+import (
+ "bytes"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// See https://github.com/jelmer/dulwich/blob/master/dulwich/pack.py and
+// https://github.com/tarruda/node-git-core/blob/master/src/js/delta.js
+// for more info
+
+const (
+ // Standard chunk size used to generate fingerprints
+ s = 16
+
+ // https://github.com/git/git/blob/f7466e94375b3be27f229c78873f0acf8301c0a5/diff-delta.c#L428
+ // Max size of a copy operation (64KB)
+ maxCopySize = 64 * 1024
+)
+
+// GetDelta returns an EncodedObject of type OFSDeltaObject. Base and Target object,
+// will be loaded into memory to be able to create the delta object.
+// To generate target again, you will need the obtained object and "base" one.
+// Error will be returned if base or target object cannot be read.
+func GetDelta(base, target plumbing.EncodedObject) (plumbing.EncodedObject, error) {
+ return getDelta(new(deltaIndex), base, target)
+}
+
+func getDelta(index *deltaIndex, base, target plumbing.EncodedObject) (plumbing.EncodedObject, error) {
+ br, err := base.Reader()
+ if err != nil {
+ return nil, err
+ }
+ defer br.Close()
+ tr, err := target.Reader()
+ if err != nil {
+ return nil, err
+ }
+ defer tr.Close()
+
+ bb := bufPool.Get().(*bytes.Buffer)
+ bb.Reset()
+ defer bufPool.Put(bb)
+
+ _, err = bb.ReadFrom(br)
+ if err != nil {
+ return nil, err
+ }
+
+ tb := bufPool.Get().(*bytes.Buffer)
+ tb.Reset()
+ defer bufPool.Put(tb)
+
+ _, err = tb.ReadFrom(tr)
+ if err != nil {
+ return nil, err
+ }
+
+ db := diffDelta(index, bb.Bytes(), tb.Bytes())
+ delta := &plumbing.MemoryObject{}
+ _, err = delta.Write(db)
+ if err != nil {
+ return nil, err
+ }
+
+ delta.SetSize(int64(len(db)))
+ delta.SetType(plumbing.OFSDeltaObject)
+
+ return delta, nil
+}
+
+// DiffDelta returns the delta that transforms src into tgt.
+func DiffDelta(src, tgt []byte) []byte {
+ return diffDelta(new(deltaIndex), src, tgt)
+}
+
+func diffDelta(index *deltaIndex, src []byte, tgt []byte) []byte {
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ buf.Write(deltaEncodeSize(len(src)))
+ buf.Write(deltaEncodeSize(len(tgt)))
+
+ if len(index.entries) == 0 {
+ index.init(src)
+ }
+
+ ibuf := bufPool.Get().(*bytes.Buffer)
+ ibuf.Reset()
+ for i := 0; i < len(tgt); i++ {
+ offset, l := index.findMatch(src, tgt, i)
+
+ if l == 0 {
+ // couldn't find a match, just write the current byte and continue
+ ibuf.WriteByte(tgt[i])
+ } else if l < 0 {
+ // src is less than blksz, copy the rest of the target to avoid
+ // calls to findMatch
+ for ; i < len(tgt); i++ {
+ ibuf.WriteByte(tgt[i])
+ }
+ } else if l < s {
+ // remaining target is less than blksz, copy what's left of it
+ // and avoid calls to findMatch
+ for j := i; j < i+l; j++ {
+ ibuf.WriteByte(tgt[j])
+ }
+ i += l - 1
+ } else {
+ encodeInsertOperation(ibuf, buf)
+
+ rl := l
+ aOffset := offset
+ for rl > 0 {
+ if rl < maxCopySize {
+ buf.Write(encodeCopyOperation(aOffset, rl))
+ break
+ }
+
+ buf.Write(encodeCopyOperation(aOffset, maxCopySize))
+ rl -= maxCopySize
+ aOffset += maxCopySize
+ }
+
+ i += l - 1
+ }
+ }
+
+ encodeInsertOperation(ibuf, buf)
+ bytes := buf.Bytes()
+
+ bufPool.Put(buf)
+ bufPool.Put(ibuf)
+
+ return bytes
+}
+
+func encodeInsertOperation(ibuf, buf *bytes.Buffer) {
+ if ibuf.Len() == 0 {
+ return
+ }
+
+ b := ibuf.Bytes()
+ s := ibuf.Len()
+ o := 0
+ for {
+ if s <= 127 {
+ break
+ }
+ buf.WriteByte(byte(127))
+ buf.Write(b[o : o+127])
+ s -= 127
+ o += 127
+ }
+ buf.WriteByte(byte(s))
+ buf.Write(b[o : o+s])
+
+ ibuf.Reset()
+}
+
+func deltaEncodeSize(size int) []byte {
+ var ret []byte
+ c := size & 0x7f
+ size >>= 7
+ for {
+ if size == 0 {
+ break
+ }
+
+ ret = append(ret, byte(c|0x80))
+ c = size & 0x7f
+ size >>= 7
+ }
+ ret = append(ret, byte(c))
+
+ return ret
+}
+
+func encodeCopyOperation(offset, length int) []byte {
+ code := 0x80
+ var opcodes []byte
+
+ var i uint
+ for i = 0; i < 4; i++ {
+ f := 0xff << (i * 8)
+ if offset&f != 0 {
+ opcodes = append(opcodes, byte(offset&f>>(i*8)))
+ code |= 0x01 << i
+ }
+ }
+
+ for i = 0; i < 3; i++ {
+ f := 0xff << (i * 8)
+ if length&f != 0 {
+ opcodes = append(opcodes, byte(length&f>>(i*8)))
+ code |= 0x10 << i
+ }
+ }
+
+ return append([]byte{byte(code)}, opcodes...)
+}
--- /dev/null
+// Package packfile implements encoding and decoding of packfile format.
+//
+// == pack-*.pack files have the following format:
+//
+// - A header appears at the beginning and consists of the following:
+//
+// 4-byte signature:
+// The signature is: {'P', 'A', 'C', 'K'}
+//
+// 4-byte version number (network byte order):
+// GIT currently accepts version number 2 or 3 but
+// generates version 2 only.
+//
+// 4-byte number of objects contained in the pack (network byte order)
+//
+// Observation: we cannot have more than 4G versions ;-) and
+// more than 4G objects in a pack.
+//
+// - The header is followed by number of object entries, each of
+// which looks like this:
+//
+// (undeltified representation)
+// n-byte type and length (3-bit type, (n-1)*7+4-bit length)
+// compressed data
+//
+// (deltified representation)
+// n-byte type and length (3-bit type, (n-1)*7+4-bit length)
+// 20-byte base object name
+// compressed delta data
+//
+// Observation: length of each object is encoded in a variable
+// length format and is not constrained to 32-bit or anything.
+//
+// - The trailer records 20-byte SHA1 checksum of all of the above.
+//
+//
+// Source:
+// https://www.kernel.org/pub/software/scm/git/docs/v1.7.5/technical/pack-protocol.txt
+package packfile
--- /dev/null
+package packfile
+
+import (
+ "compress/zlib"
+ "crypto/sha1"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+)
+
+// Encoder gets the data from the storage and write it into the writer in PACK
+// format
+type Encoder struct {
+ selector *deltaSelector
+ w *offsetWriter
+ zw *zlib.Writer
+ hasher plumbing.Hasher
+
+ useRefDeltas bool
+}
+
+// NewEncoder creates a new packfile encoder using a specific Writer and
+// EncodedObjectStorer. By default deltas used to generate the packfile will be
+// OFSDeltaObject. To use Reference deltas, set useRefDeltas to true.
+func NewEncoder(w io.Writer, s storer.EncodedObjectStorer, useRefDeltas bool) *Encoder {
+ h := plumbing.Hasher{
+ Hash: sha1.New(),
+ }
+ mw := io.MultiWriter(w, h)
+ ow := newOffsetWriter(mw)
+ zw := zlib.NewWriter(mw)
+ return &Encoder{
+ selector: newDeltaSelector(s),
+ w: ow,
+ zw: zw,
+ hasher: h,
+ useRefDeltas: useRefDeltas,
+ }
+}
+
+// Encode creates a packfile containing all the objects referenced in
+// hashes and writes it to the writer in the Encoder. `packWindow`
+// specifies the size of the sliding window used to compare objects
+// for delta compression; 0 turns off delta compression entirely.
+func (e *Encoder) Encode(
+ hashes []plumbing.Hash,
+ packWindow uint,
+) (plumbing.Hash, error) {
+ objects, err := e.selector.ObjectsToPack(hashes, packWindow)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return e.encode(objects)
+}
+
+func (e *Encoder) encode(objects []*ObjectToPack) (plumbing.Hash, error) {
+ if err := e.head(len(objects)); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ for _, o := range objects {
+ if err := e.entry(o); err != nil {
+ return plumbing.ZeroHash, err
+ }
+ }
+
+ return e.footer()
+}
+
+func (e *Encoder) head(numEntries int) error {
+ return binary.Write(
+ e.w,
+ signature,
+ int32(VersionSupported),
+ int32(numEntries),
+ )
+}
+
+func (e *Encoder) entry(o *ObjectToPack) error {
+ if o.WantWrite() {
+ // A cycle exists in this delta chain. This should only occur if a
+ // selected object representation disappeared during writing
+ // (for example due to a concurrent repack) and a different base
+ // was chosen, forcing a cycle. Select something other than a
+ // delta, and write this object.
+ e.selector.restoreOriginal(o)
+ o.BackToOriginal()
+ }
+
+ if o.IsWritten() {
+ return nil
+ }
+
+ o.MarkWantWrite()
+
+ if err := e.writeBaseIfDelta(o); err != nil {
+ return err
+ }
+
+ // We need to check if we already write that object due a cyclic delta chain
+ if o.IsWritten() {
+ return nil
+ }
+
+ o.Offset = e.w.Offset()
+
+ if o.IsDelta() {
+ if err := e.writeDeltaHeader(o); err != nil {
+ return err
+ }
+ } else {
+ if err := e.entryHead(o.Type(), o.Size()); err != nil {
+ return err
+ }
+ }
+
+ e.zw.Reset(e.w)
+ or, err := o.Object.Reader()
+ if err != nil {
+ return err
+ }
+
+ _, err = io.Copy(e.zw, or)
+ if err != nil {
+ return err
+ }
+
+ return e.zw.Close()
+}
+
+func (e *Encoder) writeBaseIfDelta(o *ObjectToPack) error {
+ if o.IsDelta() && !o.Base.IsWritten() {
+ // We must write base first
+ return e.entry(o.Base)
+ }
+
+ return nil
+}
+
+func (e *Encoder) writeDeltaHeader(o *ObjectToPack) error {
+ // Write offset deltas by default
+ t := plumbing.OFSDeltaObject
+ if e.useRefDeltas {
+ t = plumbing.REFDeltaObject
+ }
+
+ if err := e.entryHead(t, o.Object.Size()); err != nil {
+ return err
+ }
+
+ if e.useRefDeltas {
+ return e.writeRefDeltaHeader(o.Base.Hash())
+ } else {
+ return e.writeOfsDeltaHeader(o)
+ }
+}
+
+func (e *Encoder) writeRefDeltaHeader(base plumbing.Hash) error {
+ return binary.Write(e.w, base)
+}
+
+func (e *Encoder) writeOfsDeltaHeader(o *ObjectToPack) error {
+ // for OFS_DELTA, offset of the base is interpreted as negative offset
+ // relative to the type-byte of the header of the ofs-delta entry.
+ relativeOffset := o.Offset - o.Base.Offset
+ if relativeOffset <= 0 {
+ return fmt.Errorf("bad offset for OFS_DELTA entry: %d", relativeOffset)
+ }
+
+ return binary.WriteVariableWidthInt(e.w, relativeOffset)
+}
+
+func (e *Encoder) entryHead(typeNum plumbing.ObjectType, size int64) error {
+ t := int64(typeNum)
+ header := []byte{}
+ c := (t << firstLengthBits) | (size & maskFirstLength)
+ size >>= firstLengthBits
+ for {
+ if size == 0 {
+ break
+ }
+ header = append(header, byte(c|maskContinue))
+ c = size & int64(maskLength)
+ size >>= lengthBits
+ }
+
+ header = append(header, byte(c))
+ _, err := e.w.Write(header)
+
+ return err
+}
+
+func (e *Encoder) footer() (plumbing.Hash, error) {
+ h := e.hasher.Sum()
+ return h, binary.Write(e.w, h)
+}
+
+type offsetWriter struct {
+ w io.Writer
+ offset int64
+}
+
+func newOffsetWriter(w io.Writer) *offsetWriter {
+ return &offsetWriter{w: w}
+}
+
+func (ow *offsetWriter) Write(p []byte) (n int, err error) {
+ n, err = ow.w.Write(p)
+ ow.offset += int64(n)
+ return n, err
+}
+
+func (ow *offsetWriter) Offset() int64 {
+ return ow.offset
+}
--- /dev/null
+package packfile
+
+import "fmt"
+
+// Error specifies errors returned during packfile parsing.
+type Error struct {
+ reason, details string
+}
+
+// NewError returns a new error.
+func NewError(reason string) *Error {
+ return &Error{reason: reason}
+}
+
+// Error returns a text representation of the error.
+func (e *Error) Error() string {
+ if e.details == "" {
+ return e.reason
+ }
+
+ return fmt.Sprintf("%s: %s", e.reason, e.details)
+}
+
+// AddDetails adds details to an error, with additional text.
+func (e *Error) AddDetails(format string, args ...interface{}) *Error {
+ return &Error{
+ reason: e.reason,
+ details: fmt.Sprintf(format, args...),
+ }
+}
--- /dev/null
+package packfile
+
+import (
+ "io"
+
+ billy "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/idxfile"
+)
+
+// FSObject is an object from the packfile on the filesystem.
+type FSObject struct {
+ hash plumbing.Hash
+ h *ObjectHeader
+ offset int64
+ size int64
+ typ plumbing.ObjectType
+ index idxfile.Index
+ fs billy.Filesystem
+ path string
+ cache cache.Object
+}
+
+// NewFSObject creates a new filesystem object.
+func NewFSObject(
+ hash plumbing.Hash,
+ finalType plumbing.ObjectType,
+ offset int64,
+ contentSize int64,
+ index idxfile.Index,
+ fs billy.Filesystem,
+ path string,
+ cache cache.Object,
+) *FSObject {
+ return &FSObject{
+ hash: hash,
+ offset: offset,
+ size: contentSize,
+ typ: finalType,
+ index: index,
+ fs: fs,
+ path: path,
+ cache: cache,
+ }
+}
+
+// Reader implements the plumbing.EncodedObject interface.
+func (o *FSObject) Reader() (io.ReadCloser, error) {
+ obj, ok := o.cache.Get(o.hash)
+ if ok {
+ reader, err := obj.Reader()
+ if err != nil {
+ return nil, err
+ }
+
+ return reader, nil
+ }
+
+ f, err := o.fs.Open(o.path)
+ if err != nil {
+ return nil, err
+ }
+
+ p := NewPackfileWithCache(o.index, nil, f, o.cache)
+ r, err := p.getObjectContent(o.offset)
+ if err != nil {
+ _ = f.Close()
+ return nil, err
+ }
+
+ if err := f.Close(); err != nil {
+ return nil, err
+ }
+
+ return r, nil
+}
+
+// SetSize implements the plumbing.EncodedObject interface. This method
+// is a noop.
+func (o *FSObject) SetSize(int64) {}
+
+// SetType implements the plumbing.EncodedObject interface. This method is
+// a noop.
+func (o *FSObject) SetType(plumbing.ObjectType) {}
+
+// Hash implements the plumbing.EncodedObject interface.
+func (o *FSObject) Hash() plumbing.Hash { return o.hash }
+
+// Size implements the plumbing.EncodedObject interface.
+func (o *FSObject) Size() int64 { return o.size }
+
+// Type implements the plumbing.EncodedObject interface.
+func (o *FSObject) Type() plumbing.ObjectType {
+ return o.typ
+}
+
+// Writer implements the plumbing.EncodedObject interface. This method always
+// returns a nil writer.
+func (o *FSObject) Writer() (io.WriteCloser, error) {
+ return nil, nil
+}
+
+type objectReader struct {
+ io.ReadCloser
+ f billy.File
+}
+
+func (r *objectReader) Close() error {
+ if err := r.ReadCloser.Close(); err != nil {
+ _ = r.f.Close()
+ return err
+ }
+
+ return r.f.Close()
+}
--- /dev/null
+package packfile
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// ObjectToPack is a representation of an object that is going to be into a
+// pack file.
+type ObjectToPack struct {
+ // The main object to pack, it could be any object, including deltas
+ Object plumbing.EncodedObject
+ // Base is the object that a delta is based on (it could be also another delta).
+ // If the main object is not a delta, Base will be null
+ Base *ObjectToPack
+ // Original is the object that we can generate applying the delta to
+ // Base, or the same object as Object in the case of a non-delta
+ // object.
+ Original plumbing.EncodedObject
+ // Depth is the amount of deltas needed to resolve to obtain Original
+ // (delta based on delta based on ...)
+ Depth int
+
+ // offset in pack when object has been already written, or 0 if it
+ // has not been written yet
+ Offset int64
+
+ // Information from the original object
+ resolvedOriginal bool
+ originalType plumbing.ObjectType
+ originalSize int64
+ originalHash plumbing.Hash
+}
+
+// newObjectToPack creates a correct ObjectToPack based on a non-delta object
+func newObjectToPack(o plumbing.EncodedObject) *ObjectToPack {
+ return &ObjectToPack{
+ Object: o,
+ Original: o,
+ }
+}
+
+// newDeltaObjectToPack creates a correct ObjectToPack for a delta object, based on
+// his base (could be another delta), the delta target (in this case called original),
+// and the delta Object itself
+func newDeltaObjectToPack(base *ObjectToPack, original, delta plumbing.EncodedObject) *ObjectToPack {
+ return &ObjectToPack{
+ Object: delta,
+ Base: base,
+ Original: original,
+ Depth: base.Depth + 1,
+ }
+}
+
+// BackToOriginal converts that ObjectToPack to a non-deltified object if it was one
+func (o *ObjectToPack) BackToOriginal() {
+ if o.IsDelta() && o.Original != nil {
+ o.Object = o.Original
+ o.Base = nil
+ o.Depth = 0
+ }
+}
+
+// IsWritten returns if that ObjectToPack was
+// already written into the packfile or not
+func (o *ObjectToPack) IsWritten() bool {
+ return o.Offset > 1
+}
+
+// MarkWantWrite marks this ObjectToPack as WantWrite
+// to avoid delta chain loops
+func (o *ObjectToPack) MarkWantWrite() {
+ o.Offset = 1
+}
+
+// WantWrite checks if this ObjectToPack was marked as WantWrite before
+func (o *ObjectToPack) WantWrite() bool {
+ return o.Offset == 1
+}
+
+// SetOriginal sets both Original and saves size, type and hash. If object
+// is nil Original is set but previous resolved values are kept
+func (o *ObjectToPack) SetOriginal(obj plumbing.EncodedObject) {
+ o.Original = obj
+ o.SaveOriginalMetadata()
+}
+
+// SaveOriginalMetadata saves size, type and hash of Original object
+func (o *ObjectToPack) SaveOriginalMetadata() {
+ if o.Original != nil {
+ o.originalSize = o.Original.Size()
+ o.originalType = o.Original.Type()
+ o.originalHash = o.Original.Hash()
+ o.resolvedOriginal = true
+ }
+}
+
+// CleanOriginal sets Original to nil
+func (o *ObjectToPack) CleanOriginal() {
+ o.Original = nil
+}
+
+func (o *ObjectToPack) Type() plumbing.ObjectType {
+ if o.Original != nil {
+ return o.Original.Type()
+ }
+
+ if o.resolvedOriginal {
+ return o.originalType
+ }
+
+ if o.Base != nil {
+ return o.Base.Type()
+ }
+
+ if o.Object != nil {
+ return o.Object.Type()
+ }
+
+ panic("cannot get type")
+}
+
+func (o *ObjectToPack) Hash() plumbing.Hash {
+ if o.Original != nil {
+ return o.Original.Hash()
+ }
+
+ if o.resolvedOriginal {
+ return o.originalHash
+ }
+
+ do, ok := o.Object.(plumbing.DeltaObject)
+ if ok {
+ return do.ActualHash()
+ }
+
+ panic("cannot get hash")
+}
+
+func (o *ObjectToPack) Size() int64 {
+ if o.Original != nil {
+ return o.Original.Size()
+ }
+
+ if o.resolvedOriginal {
+ return o.originalSize
+ }
+
+ do, ok := o.Object.(plumbing.DeltaObject)
+ if ok {
+ return do.ActualSize()
+ }
+
+ panic("cannot get ObjectToPack size")
+}
+
+func (o *ObjectToPack) IsDelta() bool {
+ return o.Base != nil
+}
+
+func (o *ObjectToPack) SetDelta(base *ObjectToPack, delta plumbing.EncodedObject) {
+ o.Object = delta
+ o.Base = base
+ o.Depth = base.Depth + 1
+}
--- /dev/null
+package packfile
+
+import (
+ "bytes"
+ "io"
+ "os"
+
+ billy "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/idxfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+var (
+ // ErrInvalidObject is returned by Decode when an invalid object is
+ // found in the packfile.
+ ErrInvalidObject = NewError("invalid git object")
+ // ErrZLib is returned by Decode when there was an error unzipping
+ // the packfile contents.
+ ErrZLib = NewError("zlib reading error")
+)
+
+// Packfile allows retrieving information from inside a packfile.
+type Packfile struct {
+ idxfile.Index
+ fs billy.Filesystem
+ file billy.File
+ s *Scanner
+ deltaBaseCache cache.Object
+ offsetToType map[int64]plumbing.ObjectType
+}
+
+// NewPackfileWithCache creates a new Packfile with the given object cache.
+// If the filesystem is provided, the packfile will return FSObjects, otherwise
+// it will return MemoryObjects.
+func NewPackfileWithCache(
+ index idxfile.Index,
+ fs billy.Filesystem,
+ file billy.File,
+ cache cache.Object,
+) *Packfile {
+ s := NewScanner(file)
+ return &Packfile{
+ index,
+ fs,
+ file,
+ s,
+ cache,
+ make(map[int64]plumbing.ObjectType),
+ }
+}
+
+// NewPackfile returns a packfile representation for the given packfile file
+// and packfile idx.
+// If the filesystem is provided, the packfile will return FSObjects, otherwise
+// it will return MemoryObjects.
+func NewPackfile(index idxfile.Index, fs billy.Filesystem, file billy.File) *Packfile {
+ return NewPackfileWithCache(index, fs, file, cache.NewObjectLRUDefault())
+}
+
+// Get retrieves the encoded object in the packfile with the given hash.
+func (p *Packfile) Get(h plumbing.Hash) (plumbing.EncodedObject, error) {
+ offset, err := p.FindOffset(h)
+ if err != nil {
+ return nil, err
+ }
+
+ return p.GetByOffset(offset)
+}
+
+// GetByOffset retrieves the encoded object from the packfile with the given
+// offset.
+func (p *Packfile) GetByOffset(o int64) (plumbing.EncodedObject, error) {
+ hash, err := p.FindHash(o)
+ if err == nil {
+ if obj, ok := p.deltaBaseCache.Get(hash); ok {
+ return obj, nil
+ }
+ }
+
+ if _, err := p.s.SeekFromStart(o); err != nil {
+ if err == io.EOF || isInvalid(err) {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ return nil, err
+ }
+
+ return p.nextObject()
+}
+
+// GetSizeByOffset retrieves the size of the encoded object from the
+// packfile with the given offset.
+func (p *Packfile) GetSizeByOffset(o int64) (size int64, err error) {
+ if _, err := p.s.SeekFromStart(o); err != nil {
+ if err == io.EOF || isInvalid(err) {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ return 0, err
+ }
+
+ h, err := p.nextObjectHeader()
+ if err != nil {
+ return 0, err
+ }
+ return h.Length, nil
+}
+
+func (p *Packfile) nextObjectHeader() (*ObjectHeader, error) {
+ h, err := p.s.NextObjectHeader()
+ p.s.pendingObject = nil
+ return h, err
+}
+
+func (p *Packfile) getObjectSize(h *ObjectHeader) (int64, error) {
+ switch h.Type {
+ case plumbing.CommitObject, plumbing.TreeObject, plumbing.BlobObject, plumbing.TagObject:
+ return h.Length, nil
+ case plumbing.REFDeltaObject, plumbing.OFSDeltaObject:
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer bufPool.Put(buf)
+
+ if _, _, err := p.s.NextObject(buf); err != nil {
+ return 0, err
+ }
+
+ delta := buf.Bytes()
+ _, delta = decodeLEB128(delta) // skip src size
+ sz, _ := decodeLEB128(delta)
+ return int64(sz), nil
+ default:
+ return 0, ErrInvalidObject.AddDetails("type %q", h.Type)
+ }
+}
+
+func (p *Packfile) getObjectType(h *ObjectHeader) (typ plumbing.ObjectType, err error) {
+ switch h.Type {
+ case plumbing.CommitObject, plumbing.TreeObject, plumbing.BlobObject, plumbing.TagObject:
+ return h.Type, nil
+ case plumbing.REFDeltaObject, plumbing.OFSDeltaObject:
+ var offset int64
+ if h.Type == plumbing.REFDeltaObject {
+ offset, err = p.FindOffset(h.Reference)
+ if err != nil {
+ return
+ }
+ } else {
+ offset = h.OffsetReference
+ }
+
+ if baseType, ok := p.offsetToType[offset]; ok {
+ typ = baseType
+ } else {
+ if _, err = p.s.SeekFromStart(offset); err != nil {
+ return
+ }
+
+ h, err = p.nextObjectHeader()
+ if err != nil {
+ return
+ }
+
+ typ, err = p.getObjectType(h)
+ if err != nil {
+ return
+ }
+ }
+ default:
+ err = ErrInvalidObject.AddDetails("type %q", h.Type)
+ }
+
+ return
+}
+
+func (p *Packfile) nextObject() (plumbing.EncodedObject, error) {
+ h, err := p.nextObjectHeader()
+ if err != nil {
+ if err == io.EOF || isInvalid(err) {
+ return nil, plumbing.ErrObjectNotFound
+ }
+ return nil, err
+ }
+
+ // If we have no filesystem, we will return a MemoryObject instead
+ // of an FSObject.
+ if p.fs == nil {
+ return p.getNextObject(h)
+ }
+
+ hash, err := p.FindHash(h.Offset)
+ if err != nil {
+ return nil, err
+ }
+
+ size, err := p.getObjectSize(h)
+ if err != nil {
+ return nil, err
+ }
+
+ typ, err := p.getObjectType(h)
+ if err != nil {
+ return nil, err
+ }
+
+ p.offsetToType[h.Offset] = typ
+
+ return NewFSObject(
+ hash,
+ typ,
+ h.Offset,
+ size,
+ p.Index,
+ p.fs,
+ p.file.Name(),
+ p.deltaBaseCache,
+ ), nil
+}
+
+func (p *Packfile) getObjectContent(offset int64) (io.ReadCloser, error) {
+ ref, err := p.FindHash(offset)
+ if err == nil {
+ obj, ok := p.cacheGet(ref)
+ if ok {
+ reader, err := obj.Reader()
+ if err != nil {
+ return nil, err
+ }
+
+ return reader, nil
+ }
+ }
+
+ if _, err := p.s.SeekFromStart(offset); err != nil {
+ return nil, err
+ }
+
+ h, err := p.nextObjectHeader()
+ if err != nil {
+ return nil, err
+ }
+
+ obj, err := p.getNextObject(h)
+ if err != nil {
+ return nil, err
+ }
+
+ return obj.Reader()
+}
+
+func (p *Packfile) getNextObject(h *ObjectHeader) (plumbing.EncodedObject, error) {
+ var obj = new(plumbing.MemoryObject)
+ obj.SetSize(h.Length)
+ obj.SetType(h.Type)
+
+ var err error
+ switch h.Type {
+ case plumbing.CommitObject, plumbing.TreeObject, plumbing.BlobObject, plumbing.TagObject:
+ err = p.fillRegularObjectContent(obj)
+ case plumbing.REFDeltaObject:
+ err = p.fillREFDeltaObjectContent(obj, h.Reference)
+ case plumbing.OFSDeltaObject:
+ err = p.fillOFSDeltaObjectContent(obj, h.OffsetReference)
+ default:
+ err = ErrInvalidObject.AddDetails("type %q", h.Type)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return obj, nil
+}
+
+func (p *Packfile) fillRegularObjectContent(obj plumbing.EncodedObject) error {
+ w, err := obj.Writer()
+ if err != nil {
+ return err
+ }
+
+ _, _, err = p.s.NextObject(w)
+ p.cachePut(obj)
+
+ return err
+}
+
+func (p *Packfile) fillREFDeltaObjectContent(obj plumbing.EncodedObject, ref plumbing.Hash) error {
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ _, _, err := p.s.NextObject(buf)
+ if err != nil {
+ return err
+ }
+
+ base, ok := p.cacheGet(ref)
+ if !ok {
+ base, err = p.Get(ref)
+ if err != nil {
+ return err
+ }
+ }
+
+ obj.SetType(base.Type())
+ err = ApplyDelta(obj, base, buf.Bytes())
+ p.cachePut(obj)
+ bufPool.Put(buf)
+
+ return err
+}
+
+func (p *Packfile) fillOFSDeltaObjectContent(obj plumbing.EncodedObject, offset int64) error {
+ buf := bytes.NewBuffer(nil)
+ _, _, err := p.s.NextObject(buf)
+ if err != nil {
+ return err
+ }
+
+ var base plumbing.EncodedObject
+ var ok bool
+ hash, err := p.FindHash(offset)
+ if err == nil {
+ base, ok = p.cacheGet(hash)
+ }
+
+ if !ok {
+ base, err = p.GetByOffset(offset)
+ if err != nil {
+ return err
+ }
+
+ p.cachePut(base)
+ }
+
+ obj.SetType(base.Type())
+ err = ApplyDelta(obj, base, buf.Bytes())
+ p.cachePut(obj)
+
+ return err
+}
+
+func (p *Packfile) cacheGet(h plumbing.Hash) (plumbing.EncodedObject, bool) {
+ if p.deltaBaseCache == nil {
+ return nil, false
+ }
+
+ return p.deltaBaseCache.Get(h)
+}
+
+func (p *Packfile) cachePut(obj plumbing.EncodedObject) {
+ if p.deltaBaseCache == nil {
+ return
+ }
+
+ p.deltaBaseCache.Put(obj)
+}
+
+// GetAll returns an iterator with all encoded objects in the packfile.
+// The iterator returned is not thread-safe, it should be used in the same
+// thread as the Packfile instance.
+func (p *Packfile) GetAll() (storer.EncodedObjectIter, error) {
+ return p.GetByType(plumbing.AnyObject)
+}
+
+// GetByType returns all the objects of the given type.
+func (p *Packfile) GetByType(typ plumbing.ObjectType) (storer.EncodedObjectIter, error) {
+ switch typ {
+ case plumbing.AnyObject,
+ plumbing.BlobObject,
+ plumbing.TreeObject,
+ plumbing.CommitObject,
+ plumbing.TagObject:
+ entries, err := p.EntriesByOffset()
+ if err != nil {
+ return nil, err
+ }
+
+ return &objectIter{
+ // Easiest way to provide an object decoder is just to pass a Packfile
+ // instance. To not mess with the seeks, it's a new instance with a
+ // different scanner but the same cache and offset to hash map for
+ // reusing as much cache as possible.
+ p: p,
+ iter: entries,
+ typ: typ,
+ }, nil
+ default:
+ return nil, plumbing.ErrInvalidType
+ }
+}
+
+// ID returns the ID of the packfile, which is the checksum at the end of it.
+func (p *Packfile) ID() (plumbing.Hash, error) {
+ prev, err := p.file.Seek(-20, io.SeekEnd)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ var hash plumbing.Hash
+ if _, err := io.ReadFull(p.file, hash[:]); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if _, err := p.file.Seek(prev, io.SeekStart); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return hash, nil
+}
+
+// Close the packfile and its resources.
+func (p *Packfile) Close() error {
+ closer, ok := p.file.(io.Closer)
+ if !ok {
+ return nil
+ }
+
+ return closer.Close()
+}
+
+type objectIter struct {
+ p *Packfile
+ typ plumbing.ObjectType
+ iter idxfile.EntryIter
+}
+
+func (i *objectIter) Next() (plumbing.EncodedObject, error) {
+ for {
+ e, err := i.iter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ obj, err := i.p.GetByOffset(int64(e.Offset))
+ if err != nil {
+ return nil, err
+ }
+
+ if i.typ == plumbing.AnyObject || obj.Type() == i.typ {
+ return obj, nil
+ }
+ }
+}
+
+func (i *objectIter) ForEach(f func(plumbing.EncodedObject) error) error {
+ for {
+ o, err := i.Next()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+ return err
+ }
+
+ if err := f(o); err != nil {
+ return err
+ }
+ }
+}
+
+func (i *objectIter) Close() {
+ i.iter.Close()
+}
+
+// isInvalid checks whether an error is an os.PathError with an os.ErrInvalid
+// error inside. It also checks for the windows error, which is different from
+// os.ErrInvalid.
+func isInvalid(err error) bool {
+ pe, ok := err.(*os.PathError)
+ if !ok {
+ return false
+ }
+
+ errstr := pe.Err.Error()
+ return errstr == errInvalidUnix || errstr == errInvalidWindows
+}
+
+// errInvalidWindows is the Windows equivalent to os.ErrInvalid
+const errInvalidWindows = "The parameter is incorrect."
+
+var errInvalidUnix = os.ErrInvalid.Error()
--- /dev/null
+package packfile
+
+import (
+ "bytes"
+ "errors"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+var (
+ // ErrReferenceDeltaNotFound is returned when the reference delta is not
+ // found.
+ ErrReferenceDeltaNotFound = errors.New("reference delta not found")
+
+ // ErrNotSeekableSource is returned when the source for the parser is not
+ // seekable and a storage was not provided, so it can't be parsed.
+ ErrNotSeekableSource = errors.New("parser source is not seekable and storage was not provided")
+
+ // ErrDeltaNotCached is returned when the delta could not be found in cache.
+ ErrDeltaNotCached = errors.New("delta could not be found in cache")
+)
+
+// Observer interface is implemented by index encoders.
+type Observer interface {
+ // OnHeader is called when a new packfile is opened.
+ OnHeader(count uint32) error
+ // OnInflatedObjectHeader is called for each object header read.
+ OnInflatedObjectHeader(t plumbing.ObjectType, objSize int64, pos int64) error
+ // OnInflatedObjectContent is called for each decoded object.
+ OnInflatedObjectContent(h plumbing.Hash, pos int64, crc uint32, content []byte) error
+ // OnFooter is called when decoding is done.
+ OnFooter(h plumbing.Hash) error
+}
+
+// Parser decodes a packfile and calls any observer associated to it. Is used
+// to generate indexes.
+type Parser struct {
+ storage storer.EncodedObjectStorer
+ scanner *Scanner
+ count uint32
+ oi []*objectInfo
+ oiByHash map[plumbing.Hash]*objectInfo
+ oiByOffset map[int64]*objectInfo
+ hashOffset map[plumbing.Hash]int64
+ checksum plumbing.Hash
+
+ cache *cache.BufferLRU
+ // delta content by offset, only used if source is not seekable
+ deltas map[int64][]byte
+
+ ob []Observer
+}
+
+// NewParser creates a new Parser. The Scanner source must be seekable.
+// If it's not, NewParserWithStorage should be used instead.
+func NewParser(scanner *Scanner, ob ...Observer) (*Parser, error) {
+ return NewParserWithStorage(scanner, nil, ob...)
+}
+
+// NewParserWithStorage creates a new Parser. The scanner source must either
+// be seekable or a storage must be provided.
+func NewParserWithStorage(
+ scanner *Scanner,
+ storage storer.EncodedObjectStorer,
+ ob ...Observer,
+) (*Parser, error) {
+ if !scanner.IsSeekable && storage == nil {
+ return nil, ErrNotSeekableSource
+ }
+
+ var deltas map[int64][]byte
+ if !scanner.IsSeekable {
+ deltas = make(map[int64][]byte)
+ }
+
+ return &Parser{
+ storage: storage,
+ scanner: scanner,
+ ob: ob,
+ count: 0,
+ cache: cache.NewBufferLRUDefault(),
+ deltas: deltas,
+ }, nil
+}
+
+func (p *Parser) forEachObserver(f func(o Observer) error) error {
+ for _, o := range p.ob {
+ if err := f(o); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (p *Parser) onHeader(count uint32) error {
+ return p.forEachObserver(func(o Observer) error {
+ return o.OnHeader(count)
+ })
+}
+
+func (p *Parser) onInflatedObjectHeader(
+ t plumbing.ObjectType,
+ objSize int64,
+ pos int64,
+) error {
+ return p.forEachObserver(func(o Observer) error {
+ return o.OnInflatedObjectHeader(t, objSize, pos)
+ })
+}
+
+func (p *Parser) onInflatedObjectContent(
+ h plumbing.Hash,
+ pos int64,
+ crc uint32,
+ content []byte,
+) error {
+ return p.forEachObserver(func(o Observer) error {
+ return o.OnInflatedObjectContent(h, pos, crc, content)
+ })
+}
+
+func (p *Parser) onFooter(h plumbing.Hash) error {
+ return p.forEachObserver(func(o Observer) error {
+ return o.OnFooter(h)
+ })
+}
+
+// Parse start decoding phase of the packfile.
+func (p *Parser) Parse() (plumbing.Hash, error) {
+ if err := p.init(); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if err := p.indexObjects(); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ var err error
+ p.checksum, err = p.scanner.Checksum()
+ if err != nil && err != io.EOF {
+ return plumbing.ZeroHash, err
+ }
+
+ if err := p.resolveDeltas(); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if err := p.onFooter(p.checksum); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return p.checksum, nil
+}
+
+func (p *Parser) init() error {
+ _, c, err := p.scanner.Header()
+ if err != nil {
+ return err
+ }
+
+ if err := p.onHeader(c); err != nil {
+ return err
+ }
+
+ p.count = c
+ p.oiByHash = make(map[plumbing.Hash]*objectInfo, p.count)
+ p.oiByOffset = make(map[int64]*objectInfo, p.count)
+ p.oi = make([]*objectInfo, p.count)
+
+ return nil
+}
+
+func (p *Parser) indexObjects() error {
+ buf := new(bytes.Buffer)
+
+ for i := uint32(0); i < p.count; i++ {
+ buf.Reset()
+
+ oh, err := p.scanner.NextObjectHeader()
+ if err != nil {
+ return err
+ }
+
+ delta := false
+ var ota *objectInfo
+ switch t := oh.Type; t {
+ case plumbing.OFSDeltaObject:
+ delta = true
+
+ parent, ok := p.oiByOffset[oh.OffsetReference]
+ if !ok {
+ return plumbing.ErrObjectNotFound
+ }
+
+ ota = newDeltaObject(oh.Offset, oh.Length, t, parent)
+ parent.Children = append(parent.Children, ota)
+ case plumbing.REFDeltaObject:
+ delta = true
+ parent, ok := p.oiByHash[oh.Reference]
+ if !ok {
+ // can't find referenced object in this pack file
+ // this must be a "thin" pack.
+ parent = &objectInfo{ //Placeholder parent
+ SHA1: oh.Reference,
+ ExternalRef: true, // mark as an external reference that must be resolved
+ Type: plumbing.AnyObject,
+ DiskType: plumbing.AnyObject,
+ }
+ p.oiByHash[oh.Reference] = parent
+ }
+ ota = newDeltaObject(oh.Offset, oh.Length, t, parent)
+ parent.Children = append(parent.Children, ota)
+
+ default:
+ ota = newBaseObject(oh.Offset, oh.Length, t)
+ }
+
+ _, crc, err := p.scanner.NextObject(buf)
+ if err != nil {
+ return err
+ }
+
+ ota.Crc32 = crc
+ ota.Length = oh.Length
+
+ data := buf.Bytes()
+ if !delta {
+ sha1, err := getSHA1(ota.Type, data)
+ if err != nil {
+ return err
+ }
+
+ ota.SHA1 = sha1
+ p.oiByHash[ota.SHA1] = ota
+ }
+
+ if p.storage != nil && !delta {
+ obj := new(plumbing.MemoryObject)
+ obj.SetSize(oh.Length)
+ obj.SetType(oh.Type)
+ if _, err := obj.Write(data); err != nil {
+ return err
+ }
+
+ if _, err := p.storage.SetEncodedObject(obj); err != nil {
+ return err
+ }
+ }
+
+ if delta && !p.scanner.IsSeekable {
+ p.deltas[oh.Offset] = make([]byte, len(data))
+ copy(p.deltas[oh.Offset], data)
+ }
+
+ p.oiByOffset[oh.Offset] = ota
+ p.oi[i] = ota
+ }
+
+ return nil
+}
+
+func (p *Parser) resolveDeltas() error {
+ for _, obj := range p.oi {
+ content, err := p.get(obj)
+ if err != nil {
+ return err
+ }
+
+ if err := p.onInflatedObjectHeader(obj.Type, obj.Length, obj.Offset); err != nil {
+ return err
+ }
+
+ if err := p.onInflatedObjectContent(obj.SHA1, obj.Offset, obj.Crc32, content); err != nil {
+ return err
+ }
+
+ if !obj.IsDelta() && len(obj.Children) > 0 {
+ for _, child := range obj.Children {
+ if _, err := p.resolveObject(child, content); err != nil {
+ return err
+ }
+ }
+
+ // Remove the delta from the cache.
+ if obj.DiskType.IsDelta() && !p.scanner.IsSeekable {
+ delete(p.deltas, obj.Offset)
+ }
+ }
+ }
+
+ return nil
+}
+
+func (p *Parser) get(o *objectInfo) (b []byte, err error) {
+ var ok bool
+ if !o.ExternalRef { // skip cache check for placeholder parents
+ b, ok = p.cache.Get(o.Offset)
+ }
+
+ // If it's not on the cache and is not a delta we can try to find it in the
+ // storage, if there's one. External refs must enter here.
+ if !ok && p.storage != nil && !o.Type.IsDelta() {
+ e, err := p.storage.EncodedObject(plumbing.AnyObject, o.SHA1)
+ if err != nil {
+ return nil, err
+ }
+ o.Type = e.Type()
+
+ r, err := e.Reader()
+ if err != nil {
+ return nil, err
+ }
+
+ b = make([]byte, e.Size())
+ if _, err = r.Read(b); err != nil {
+ return nil, err
+ }
+ }
+
+ if b != nil {
+ return b, nil
+ }
+
+ if o.ExternalRef {
+ // we were not able to resolve a ref in a thin pack
+ return nil, ErrReferenceDeltaNotFound
+ }
+
+ var data []byte
+ if o.DiskType.IsDelta() {
+ base, err := p.get(o.Parent)
+ if err != nil {
+ return nil, err
+ }
+
+ data, err = p.resolveObject(o, base)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ data, err = p.readData(o)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(o.Children) > 0 {
+ p.cache.Put(o.Offset, data)
+ }
+
+ return data, nil
+}
+
+func (p *Parser) resolveObject(
+ o *objectInfo,
+ base []byte,
+) ([]byte, error) {
+ if !o.DiskType.IsDelta() {
+ return nil, nil
+ }
+
+ data, err := p.readData(o)
+ if err != nil {
+ return nil, err
+ }
+
+ data, err = applyPatchBase(o, data, base)
+ if err != nil {
+ return nil, err
+ }
+
+ if p.storage != nil {
+ obj := new(plumbing.MemoryObject)
+ obj.SetSize(o.Size())
+ obj.SetType(o.Type)
+ if _, err := obj.Write(data); err != nil {
+ return nil, err
+ }
+
+ if _, err := p.storage.SetEncodedObject(obj); err != nil {
+ return nil, err
+ }
+ }
+
+ return data, nil
+}
+
+func (p *Parser) readData(o *objectInfo) ([]byte, error) {
+ if !p.scanner.IsSeekable && o.DiskType.IsDelta() {
+ data, ok := p.deltas[o.Offset]
+ if !ok {
+ return nil, ErrDeltaNotCached
+ }
+
+ return data, nil
+ }
+
+ if _, err := p.scanner.SeekFromStart(o.Offset); err != nil {
+ return nil, err
+ }
+
+ if _, err := p.scanner.NextObjectHeader(); err != nil {
+ return nil, err
+ }
+
+ buf := new(bytes.Buffer)
+ if _, _, err := p.scanner.NextObject(buf); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+func applyPatchBase(ota *objectInfo, data, base []byte) ([]byte, error) {
+ patched, err := PatchDelta(base, data)
+ if err != nil {
+ return nil, err
+ }
+
+ if ota.SHA1 == plumbing.ZeroHash {
+ ota.Type = ota.Parent.Type
+ sha1, err := getSHA1(ota.Type, patched)
+ if err != nil {
+ return nil, err
+ }
+
+ ota.SHA1 = sha1
+ ota.Length = int64(len(patched))
+ }
+
+ return patched, nil
+}
+
+func getSHA1(t plumbing.ObjectType, data []byte) (plumbing.Hash, error) {
+ hasher := plumbing.NewHasher(t, int64(len(data)))
+ if _, err := hasher.Write(data); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return hasher.Sum(), nil
+}
+
+type objectInfo struct {
+ Offset int64
+ Length int64
+ Type plumbing.ObjectType
+ DiskType plumbing.ObjectType
+ ExternalRef bool // indicates this is an external reference in a thin pack file
+
+ Crc32 uint32
+
+ Parent *objectInfo
+ Children []*objectInfo
+ SHA1 plumbing.Hash
+}
+
+func newBaseObject(offset, length int64, t plumbing.ObjectType) *objectInfo {
+ return newDeltaObject(offset, length, t, nil)
+}
+
+func newDeltaObject(
+ offset, length int64,
+ t plumbing.ObjectType,
+ parent *objectInfo,
+) *objectInfo {
+ obj := &objectInfo{
+ Offset: offset,
+ Length: length,
+ Type: t,
+ DiskType: t,
+ Crc32: 0,
+ Parent: parent,
+ }
+
+ return obj
+}
+
+func (o *objectInfo) IsDelta() bool {
+ return o.Type.IsDelta()
+}
+
+func (o *objectInfo) Size() int64 {
+ return o.Length
+}
--- /dev/null
+package packfile
+
+import (
+ "errors"
+ "io/ioutil"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// See https://github.com/git/git/blob/49fa3dc76179e04b0833542fa52d0f287a4955ac/delta.h
+// https://github.com/git/git/blob/c2c5f6b1e479f2c38e0e01345350620944e3527f/patch-delta.c,
+// and https://github.com/tarruda/node-git-core/blob/master/src/js/delta.js
+// for details about the delta format.
+
+const deltaSizeMin = 4
+
+// ApplyDelta writes to target the result of applying the modification deltas in delta to base.
+func ApplyDelta(target, base plumbing.EncodedObject, delta []byte) error {
+ r, err := base.Reader()
+ if err != nil {
+ return err
+ }
+
+ w, err := target.Writer()
+ if err != nil {
+ return err
+ }
+
+ src, err := ioutil.ReadAll(r)
+ if err != nil {
+ return err
+ }
+
+ dst, err := PatchDelta(src, delta)
+ if err != nil {
+ return err
+ }
+
+ target.SetSize(int64(len(dst)))
+
+ _, err = w.Write(dst)
+ return err
+}
+
+var (
+ ErrInvalidDelta = errors.New("invalid delta")
+ ErrDeltaCmd = errors.New("wrong delta command")
+)
+
+// PatchDelta returns the result of applying the modification deltas in delta to src.
+// An error will be returned if delta is corrupted (ErrDeltaLen) or an action command
+// is not copy from source or copy from delta (ErrDeltaCmd).
+func PatchDelta(src, delta []byte) ([]byte, error) {
+ if len(delta) < deltaSizeMin {
+ return nil, ErrInvalidDelta
+ }
+
+ srcSz, delta := decodeLEB128(delta)
+ if srcSz != uint(len(src)) {
+ return nil, ErrInvalidDelta
+ }
+
+ targetSz, delta := decodeLEB128(delta)
+ remainingTargetSz := targetSz
+
+ var cmd byte
+ dest := make([]byte, 0, targetSz)
+ for {
+ if len(delta) == 0 {
+ return nil, ErrInvalidDelta
+ }
+
+ cmd = delta[0]
+ delta = delta[1:]
+ if isCopyFromSrc(cmd) {
+ var offset, sz uint
+ var err error
+ offset, delta, err = decodeOffset(cmd, delta)
+ if err != nil {
+ return nil, err
+ }
+
+ sz, delta, err = decodeSize(cmd, delta)
+ if err != nil {
+ return nil, err
+ }
+
+ if invalidSize(sz, targetSz) ||
+ invalidOffsetSize(offset, sz, srcSz) {
+ break
+ }
+ dest = append(dest, src[offset:offset+sz]...)
+ remainingTargetSz -= sz
+ } else if isCopyFromDelta(cmd) {
+ sz := uint(cmd) // cmd is the size itself
+ if invalidSize(sz, targetSz) {
+ return nil, ErrInvalidDelta
+ }
+
+ if uint(len(delta)) < sz {
+ return nil, ErrInvalidDelta
+ }
+
+ dest = append(dest, delta[0:sz]...)
+ remainingTargetSz -= sz
+ delta = delta[sz:]
+ } else {
+ return nil, ErrDeltaCmd
+ }
+
+ if remainingTargetSz <= 0 {
+ break
+ }
+ }
+
+ return dest, nil
+}
+
+// Decodes a number encoded as an unsigned LEB128 at the start of some
+// binary data and returns the decoded number and the rest of the
+// stream.
+//
+// This must be called twice on the delta data buffer, first to get the
+// expected source buffer size, and again to get the target buffer size.
+func decodeLEB128(input []byte) (uint, []byte) {
+ var num, sz uint
+ var b byte
+ for {
+ b = input[sz]
+ num |= (uint(b) & payload) << (sz * 7) // concats 7 bits chunks
+ sz++
+
+ if uint(b)&continuation == 0 || sz == uint(len(input)) {
+ break
+ }
+ }
+
+ return num, input[sz:]
+}
+
+const (
+ payload = 0x7f // 0111 1111
+ continuation = 0x80 // 1000 0000
+)
+
+func isCopyFromSrc(cmd byte) bool {
+ return (cmd & 0x80) != 0
+}
+
+func isCopyFromDelta(cmd byte) bool {
+ return (cmd&0x80) == 0 && cmd != 0
+}
+
+func decodeOffset(cmd byte, delta []byte) (uint, []byte, error) {
+ var offset uint
+ if (cmd & 0x01) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ offset = uint(delta[0])
+ delta = delta[1:]
+ }
+ if (cmd & 0x02) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ offset |= uint(delta[0]) << 8
+ delta = delta[1:]
+ }
+ if (cmd & 0x04) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ offset |= uint(delta[0]) << 16
+ delta = delta[1:]
+ }
+ if (cmd & 0x08) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ offset |= uint(delta[0]) << 24
+ delta = delta[1:]
+ }
+
+ return offset, delta, nil
+}
+
+func decodeSize(cmd byte, delta []byte) (uint, []byte, error) {
+ var sz uint
+ if (cmd & 0x10) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ sz = uint(delta[0])
+ delta = delta[1:]
+ }
+ if (cmd & 0x20) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ sz |= uint(delta[0]) << 8
+ delta = delta[1:]
+ }
+ if (cmd & 0x40) != 0 {
+ if len(delta) == 0 {
+ return 0, nil, ErrInvalidDelta
+ }
+ sz |= uint(delta[0]) << 16
+ delta = delta[1:]
+ }
+ if sz == 0 {
+ sz = 0x10000
+ }
+
+ return sz, delta, nil
+}
+
+func invalidSize(sz, targetSz uint) bool {
+ return sz > targetSz
+}
+
+func invalidOffsetSize(offset, sz, srcSz uint) bool {
+ return sumOverflows(offset, sz) ||
+ offset+sz > srcSz
+}
+
+func sumOverflows(a, b uint) bool {
+ return a+b < a
+}
--- /dev/null
+package packfile
+
+import (
+ "bufio"
+ "bytes"
+ "compress/zlib"
+ "fmt"
+ "hash"
+ "hash/crc32"
+ "io"
+ stdioutil "io/ioutil"
+ "sync"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+var (
+ // ErrEmptyPackfile is returned by ReadHeader when no data is found in the packfile
+ ErrEmptyPackfile = NewError("empty packfile")
+ // ErrBadSignature is returned by ReadHeader when the signature in the packfile is incorrect.
+ ErrBadSignature = NewError("malformed pack file signature")
+ // ErrUnsupportedVersion is returned by ReadHeader when the packfile version is
+ // different than VersionSupported.
+ ErrUnsupportedVersion = NewError("unsupported packfile version")
+ // ErrSeekNotSupported returned if seek is not support
+ ErrSeekNotSupported = NewError("not seek support")
+)
+
+// ObjectHeader contains the information related to the object, this information
+// is collected from the previous bytes to the content of the object.
+type ObjectHeader struct {
+ Type plumbing.ObjectType
+ Offset int64
+ Length int64
+ Reference plumbing.Hash
+ OffsetReference int64
+}
+
+type Scanner struct {
+ r reader
+ zr readerResetter
+ crc hash.Hash32
+
+ // pendingObject is used to detect if an object has been read, or still
+ // is waiting to be read
+ pendingObject *ObjectHeader
+ version, objects uint32
+
+ // lsSeekable says if this scanner can do Seek or not, to have a Scanner
+ // seekable a r implementing io.Seeker is required
+ IsSeekable bool
+}
+
+// NewScanner returns a new Scanner based on a reader, if the given reader
+// implements io.ReadSeeker the Scanner will be also Seekable
+func NewScanner(r io.Reader) *Scanner {
+ seeker, ok := r.(io.ReadSeeker)
+ if !ok {
+ seeker = &trackableReader{Reader: r}
+ }
+
+ crc := crc32.NewIEEE()
+ return &Scanner{
+ r: newTeeReader(newByteReadSeeker(seeker), crc),
+ crc: crc,
+ IsSeekable: ok,
+ }
+}
+
+// Header reads the whole packfile header (signature, version and object count).
+// It returns the version and the object count and performs checks on the
+// validity of the signature and the version fields.
+func (s *Scanner) Header() (version, objects uint32, err error) {
+ if s.version != 0 {
+ return s.version, s.objects, nil
+ }
+
+ sig, err := s.readSignature()
+ if err != nil {
+ if err == io.EOF {
+ err = ErrEmptyPackfile
+ }
+
+ return
+ }
+
+ if !s.isValidSignature(sig) {
+ err = ErrBadSignature
+ return
+ }
+
+ version, err = s.readVersion()
+ s.version = version
+ if err != nil {
+ return
+ }
+
+ if !s.isSupportedVersion(version) {
+ err = ErrUnsupportedVersion.AddDetails("%d", version)
+ return
+ }
+
+ objects, err = s.readCount()
+ s.objects = objects
+ return
+}
+
+// readSignature reads an returns the signature field in the packfile.
+func (s *Scanner) readSignature() ([]byte, error) {
+ var sig = make([]byte, 4)
+ if _, err := io.ReadFull(s.r, sig); err != nil {
+ return []byte{}, err
+ }
+
+ return sig, nil
+}
+
+// isValidSignature returns if sig is a valid packfile signature.
+func (s *Scanner) isValidSignature(sig []byte) bool {
+ return bytes.Equal(sig, signature)
+}
+
+// readVersion reads and returns the version field of a packfile.
+func (s *Scanner) readVersion() (uint32, error) {
+ return binary.ReadUint32(s.r)
+}
+
+// isSupportedVersion returns whether version v is supported by the parser.
+// The current supported version is VersionSupported, defined above.
+func (s *Scanner) isSupportedVersion(v uint32) bool {
+ return v == VersionSupported
+}
+
+// readCount reads and returns the count of objects field of a packfile.
+func (s *Scanner) readCount() (uint32, error) {
+ return binary.ReadUint32(s.r)
+}
+
+// NextObjectHeader returns the ObjectHeader for the next object in the reader
+func (s *Scanner) NextObjectHeader() (*ObjectHeader, error) {
+ defer s.Flush()
+
+ if err := s.doPending(); err != nil {
+ return nil, err
+ }
+
+ s.crc.Reset()
+
+ h := &ObjectHeader{}
+ s.pendingObject = h
+
+ var err error
+ h.Offset, err = s.r.Seek(0, io.SeekCurrent)
+ if err != nil {
+ return nil, err
+ }
+
+ h.Type, h.Length, err = s.readObjectTypeAndLength()
+ if err != nil {
+ return nil, err
+ }
+
+ switch h.Type {
+ case plumbing.OFSDeltaObject:
+ no, err := binary.ReadVariableWidthInt(s.r)
+ if err != nil {
+ return nil, err
+ }
+
+ h.OffsetReference = h.Offset - no
+ case plumbing.REFDeltaObject:
+ var err error
+ h.Reference, err = binary.ReadHash(s.r)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return h, nil
+}
+
+func (s *Scanner) doPending() error {
+ if s.version == 0 {
+ var err error
+ s.version, s.objects, err = s.Header()
+ if err != nil {
+ return err
+ }
+ }
+
+ return s.discardObjectIfNeeded()
+}
+
+func (s *Scanner) discardObjectIfNeeded() error {
+ if s.pendingObject == nil {
+ return nil
+ }
+
+ h := s.pendingObject
+ n, _, err := s.NextObject(stdioutil.Discard)
+ if err != nil {
+ return err
+ }
+
+ if n != h.Length {
+ return fmt.Errorf(
+ "error discarding object, discarded %d, expected %d",
+ n, h.Length,
+ )
+ }
+
+ return nil
+}
+
+// ReadObjectTypeAndLength reads and returns the object type and the
+// length field from an object entry in a packfile.
+func (s *Scanner) readObjectTypeAndLength() (plumbing.ObjectType, int64, error) {
+ t, c, err := s.readType()
+ if err != nil {
+ return t, 0, err
+ }
+
+ l, err := s.readLength(c)
+
+ return t, l, err
+}
+
+func (s *Scanner) readType() (plumbing.ObjectType, byte, error) {
+ var c byte
+ var err error
+ if c, err = s.r.ReadByte(); err != nil {
+ return plumbing.ObjectType(0), 0, err
+ }
+
+ typ := parseType(c)
+
+ return typ, c, nil
+}
+
+func parseType(b byte) plumbing.ObjectType {
+ return plumbing.ObjectType((b & maskType) >> firstLengthBits)
+}
+
+// the length is codified in the last 4 bits of the first byte and in
+// the last 7 bits of subsequent bytes. Last byte has a 0 MSB.
+func (s *Scanner) readLength(first byte) (int64, error) {
+ length := int64(first & maskFirstLength)
+
+ c := first
+ shift := firstLengthBits
+ var err error
+ for c&maskContinue > 0 {
+ if c, err = s.r.ReadByte(); err != nil {
+ return 0, err
+ }
+
+ length += int64(c&maskLength) << shift
+ shift += lengthBits
+ }
+
+ return length, nil
+}
+
+// NextObject writes the content of the next object into the reader, returns
+// the number of bytes written, the CRC32 of the content and an error, if any
+func (s *Scanner) NextObject(w io.Writer) (written int64, crc32 uint32, err error) {
+ defer s.crc.Reset()
+
+ s.pendingObject = nil
+ written, err = s.copyObject(w)
+ s.Flush()
+ crc32 = s.crc.Sum32()
+ return
+}
+
+// ReadRegularObject reads and write a non-deltified object
+// from it zlib stream in an object entry in the packfile.
+func (s *Scanner) copyObject(w io.Writer) (n int64, err error) {
+ if s.zr == nil {
+ var zr io.ReadCloser
+ zr, err = zlib.NewReader(s.r)
+ if err != nil {
+ return 0, fmt.Errorf("zlib initialization error: %s", err)
+ }
+
+ s.zr = zr.(readerResetter)
+ } else {
+ if err = s.zr.Reset(s.r, nil); err != nil {
+ return 0, fmt.Errorf("zlib reset error: %s", err)
+ }
+ }
+
+ defer ioutil.CheckClose(s.zr, &err)
+ buf := byteSlicePool.Get().([]byte)
+ n, err = io.CopyBuffer(w, s.zr, buf)
+ byteSlicePool.Put(buf)
+ return
+}
+
+var byteSlicePool = sync.Pool{
+ New: func() interface{} {
+ return make([]byte, 32*1024)
+ },
+}
+
+// SeekFromStart sets a new offset from start, returns the old position before
+// the change.
+func (s *Scanner) SeekFromStart(offset int64) (previous int64, err error) {
+ // if seeking we assume that you are not interested on the header
+ if s.version == 0 {
+ s.version = VersionSupported
+ }
+
+ previous, err = s.r.Seek(0, io.SeekCurrent)
+ if err != nil {
+ return -1, err
+ }
+
+ _, err = s.r.Seek(offset, io.SeekStart)
+ return previous, err
+}
+
+// Checksum returns the checksum of the packfile
+func (s *Scanner) Checksum() (plumbing.Hash, error) {
+ err := s.discardObjectIfNeeded()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return binary.ReadHash(s.r)
+}
+
+// Close reads the reader until io.EOF
+func (s *Scanner) Close() error {
+ buf := byteSlicePool.Get().([]byte)
+ _, err := io.CopyBuffer(stdioutil.Discard, s.r, buf)
+ byteSlicePool.Put(buf)
+ return err
+}
+
+// Flush finishes writing the buffer to crc hasher in case we are using
+// a teeReader. Otherwise it is a no-op.
+func (s *Scanner) Flush() error {
+ tee, ok := s.r.(*teeReader)
+ if ok {
+ return tee.Flush()
+ }
+ return nil
+}
+
+type trackableReader struct {
+ count int64
+ io.Reader
+}
+
+// Read reads up to len(p) bytes into p.
+func (r *trackableReader) Read(p []byte) (n int, err error) {
+ n, err = r.Reader.Read(p)
+ r.count += int64(n)
+
+ return
+}
+
+// Seek only supports io.SeekCurrent, any other operation fails
+func (r *trackableReader) Seek(offset int64, whence int) (int64, error) {
+ if whence != io.SeekCurrent {
+ return -1, ErrSeekNotSupported
+ }
+
+ return r.count, nil
+}
+
+func newByteReadSeeker(r io.ReadSeeker) *bufferedSeeker {
+ return &bufferedSeeker{
+ r: r,
+ Reader: *bufio.NewReader(r),
+ }
+}
+
+type bufferedSeeker struct {
+ r io.ReadSeeker
+ bufio.Reader
+}
+
+func (r *bufferedSeeker) Seek(offset int64, whence int) (int64, error) {
+ if whence == io.SeekCurrent {
+ current, err := r.r.Seek(offset, whence)
+ if err != nil {
+ return current, err
+ }
+
+ return current - int64(r.Buffered()), nil
+ }
+
+ defer r.Reader.Reset(r.r)
+ return r.r.Seek(offset, whence)
+}
+
+type readerResetter interface {
+ io.ReadCloser
+ zlib.Resetter
+}
+
+type reader interface {
+ io.Reader
+ io.ByteReader
+ io.Seeker
+}
+
+type teeReader struct {
+ reader
+ w hash.Hash32
+ bufWriter *bufio.Writer
+}
+
+func newTeeReader(r reader, h hash.Hash32) *teeReader {
+ return &teeReader{
+ reader: r,
+ w: h,
+ bufWriter: bufio.NewWriter(h),
+ }
+}
+
+func (r *teeReader) Read(p []byte) (n int, err error) {
+ r.Flush()
+
+ n, err = r.reader.Read(p)
+ if n > 0 {
+ if n, err := r.w.Write(p[:n]); err != nil {
+ return n, err
+ }
+ }
+ return
+}
+
+func (r *teeReader) ReadByte() (b byte, err error) {
+ b, err = r.reader.ReadByte()
+ if err == nil {
+ return b, r.bufWriter.WriteByte(b)
+ }
+
+ return
+}
+
+func (r *teeReader) Flush() (err error) {
+ return r.bufWriter.Flush()
+}
--- /dev/null
+// Package pktline implements reading payloads form pkt-lines and encoding
+// pkt-lines from payloads.
+package pktline
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// An Encoder writes pkt-lines to an output stream.
+type Encoder struct {
+ w io.Writer
+}
+
+const (
+ // MaxPayloadSize is the maximum payload size of a pkt-line in bytes.
+ MaxPayloadSize = 65516
+
+ // For compatibility with canonical Git implementation, accept longer pkt-lines
+ OversizePayloadMax = 65520
+)
+
+var (
+ // FlushPkt are the contents of a flush-pkt pkt-line.
+ FlushPkt = []byte{'0', '0', '0', '0'}
+ // Flush is the payload to use with the Encode method to encode a flush-pkt.
+ Flush = []byte{}
+ // FlushString is the payload to use with the EncodeString method to encode a flush-pkt.
+ FlushString = ""
+ // ErrPayloadTooLong is returned by the Encode methods when any of the
+ // provided payloads is bigger than MaxPayloadSize.
+ ErrPayloadTooLong = errors.New("payload is too long")
+)
+
+// NewEncoder returns a new encoder that writes to w.
+func NewEncoder(w io.Writer) *Encoder {
+ return &Encoder{
+ w: w,
+ }
+}
+
+// Flush encodes a flush-pkt to the output stream.
+func (e *Encoder) Flush() error {
+ _, err := e.w.Write(FlushPkt)
+ return err
+}
+
+// Encode encodes a pkt-line with the payload specified and write it to
+// the output stream. If several payloads are specified, each of them
+// will get streamed in their own pkt-lines.
+func (e *Encoder) Encode(payloads ...[]byte) error {
+ for _, p := range payloads {
+ if err := e.encodeLine(p); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (e *Encoder) encodeLine(p []byte) error {
+ if len(p) > MaxPayloadSize {
+ return ErrPayloadTooLong
+ }
+
+ if bytes.Equal(p, Flush) {
+ return e.Flush()
+ }
+
+ n := len(p) + 4
+ if _, err := e.w.Write(asciiHex16(n)); err != nil {
+ return err
+ }
+ _, err := e.w.Write(p)
+ return err
+}
+
+// Returns the hexadecimal ascii representation of the 16 less
+// significant bits of n. The length of the returned slice will always
+// be 4. Example: if n is 1234 (0x4d2), the return value will be
+// []byte{'0', '4', 'd', '2'}.
+func asciiHex16(n int) []byte {
+ var ret [4]byte
+ ret[0] = byteToASCIIHex(byte(n & 0xf000 >> 12))
+ ret[1] = byteToASCIIHex(byte(n & 0x0f00 >> 8))
+ ret[2] = byteToASCIIHex(byte(n & 0x00f0 >> 4))
+ ret[3] = byteToASCIIHex(byte(n & 0x000f))
+
+ return ret[:]
+}
+
+// turns a byte into its hexadecimal ascii representation. Example:
+// from 11 (0xb) to 'b'.
+func byteToASCIIHex(n byte) byte {
+ if n < 10 {
+ return '0' + n
+ }
+
+ return 'a' - 10 + n
+}
+
+// EncodeString works similarly as Encode but payloads are specified as strings.
+func (e *Encoder) EncodeString(payloads ...string) error {
+ for _, p := range payloads {
+ if err := e.Encode([]byte(p)); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Encodef encodes a single pkt-line with the payload formatted as
+// the format specifier. The rest of the arguments will be used in
+// the format string.
+func (e *Encoder) Encodef(format string, a ...interface{}) error {
+ return e.EncodeString(
+ fmt.Sprintf(format, a...),
+ )
+}
--- /dev/null
+package pktline
+
+import (
+ "errors"
+ "io"
+)
+
+const (
+ lenSize = 4
+)
+
+// ErrInvalidPktLen is returned by Err() when an invalid pkt-len is found.
+var ErrInvalidPktLen = errors.New("invalid pkt-len found")
+
+// Scanner provides a convenient interface for reading the payloads of a
+// series of pkt-lines. It takes an io.Reader providing the source,
+// which then can be tokenized through repeated calls to the Scan
+// method.
+//
+// After each Scan call, the Bytes method will return the payload of the
+// corresponding pkt-line on a shared buffer, which will be 65516 bytes
+// or smaller. Flush pkt-lines are represented by empty byte slices.
+//
+// Scanning stops at EOF or the first I/O error.
+type Scanner struct {
+ r io.Reader // The reader provided by the client
+ err error // Sticky error
+ payload []byte // Last pkt-payload
+ len [lenSize]byte // Last pkt-len
+}
+
+// NewScanner returns a new Scanner to read from r.
+func NewScanner(r io.Reader) *Scanner {
+ return &Scanner{
+ r: r,
+ }
+}
+
+// Err returns the first error encountered by the Scanner.
+func (s *Scanner) Err() error {
+ return s.err
+}
+
+// Scan advances the Scanner to the next pkt-line, whose payload will
+// then be available through the Bytes method. Scanning stops at EOF
+// or the first I/O error. After Scan returns false, the Err method
+// will return any error that occurred during scanning, except that if
+// it was io.EOF, Err will return nil.
+func (s *Scanner) Scan() bool {
+ var l int
+ l, s.err = s.readPayloadLen()
+ if s.err == io.EOF {
+ s.err = nil
+ return false
+ }
+ if s.err != nil {
+ return false
+ }
+
+ if cap(s.payload) < l {
+ s.payload = make([]byte, 0, l)
+ }
+
+ if _, s.err = io.ReadFull(s.r, s.payload[:l]); s.err != nil {
+ return false
+ }
+ s.payload = s.payload[:l]
+
+ return true
+}
+
+// Bytes returns the most recent payload generated by a call to Scan.
+// The underlying array may point to data that will be overwritten by a
+// subsequent call to Scan. It does no allocation.
+func (s *Scanner) Bytes() []byte {
+ return s.payload
+}
+
+// Method readPayloadLen returns the payload length by reading the
+// pkt-len and subtracting the pkt-len size.
+func (s *Scanner) readPayloadLen() (int, error) {
+ if _, err := io.ReadFull(s.r, s.len[:]); err != nil {
+ if err == io.ErrUnexpectedEOF {
+ return 0, ErrInvalidPktLen
+ }
+
+ return 0, err
+ }
+
+ n, err := hexDecode(s.len)
+ if err != nil {
+ return 0, err
+ }
+
+ switch {
+ case n == 0:
+ return 0, nil
+ case n <= lenSize:
+ return 0, ErrInvalidPktLen
+ case n > OversizePayloadMax+lenSize:
+ return 0, ErrInvalidPktLen
+ default:
+ return n - lenSize, nil
+ }
+}
+
+// Turns the hexadecimal representation of a number in a byte slice into
+// a number. This function substitute strconv.ParseUint(string(buf), 16,
+// 16) and/or hex.Decode, to avoid generating new strings, thus helping the
+// GC.
+func hexDecode(buf [lenSize]byte) (int, error) {
+ var ret int
+ for i := 0; i < lenSize; i++ {
+ n, err := asciiHexToByte(buf[i])
+ if err != nil {
+ return 0, ErrInvalidPktLen
+ }
+ ret = 16*ret + int(n)
+ }
+ return ret, nil
+}
+
+// turns the hexadecimal ascii representation of a byte into its
+// numerical value. Example: from 'b' to 11 (0xb).
+func asciiHexToByte(b byte) (byte, error) {
+ switch {
+ case b >= '0' && b <= '9':
+ return b - '0', nil
+ case b >= 'a' && b <= 'f':
+ return b - 'a' + 10, nil
+ default:
+ return 0, ErrInvalidPktLen
+ }
+}
--- /dev/null
+package plumbing
+
+import (
+ "bytes"
+ "crypto/sha1"
+ "encoding/hex"
+ "hash"
+ "sort"
+ "strconv"
+)
+
+// Hash SHA1 hased content
+type Hash [20]byte
+
+// ZeroHash is Hash with value zero
+var ZeroHash Hash
+
+// ComputeHash compute the hash for a given ObjectType and content
+func ComputeHash(t ObjectType, content []byte) Hash {
+ h := NewHasher(t, int64(len(content)))
+ h.Write(content)
+ return h.Sum()
+}
+
+// NewHash return a new Hash from a hexadecimal hash representation
+func NewHash(s string) Hash {
+ b, _ := hex.DecodeString(s)
+
+ var h Hash
+ copy(h[:], b)
+
+ return h
+}
+
+func (h Hash) IsZero() bool {
+ var empty Hash
+ return h == empty
+}
+
+func (h Hash) String() string {
+ return hex.EncodeToString(h[:])
+}
+
+type Hasher struct {
+ hash.Hash
+}
+
+func NewHasher(t ObjectType, size int64) Hasher {
+ h := Hasher{sha1.New()}
+ h.Write(t.Bytes())
+ h.Write([]byte(" "))
+ h.Write([]byte(strconv.FormatInt(size, 10)))
+ h.Write([]byte{0})
+ return h
+}
+
+func (h Hasher) Sum() (hash Hash) {
+ copy(hash[:], h.Hash.Sum(nil))
+ return
+}
+
+// HashesSort sorts a slice of Hashes in increasing order.
+func HashesSort(a []Hash) {
+ sort.Sort(HashSlice(a))
+}
+
+// HashSlice attaches the methods of sort.Interface to []Hash, sorting in
+// increasing order.
+type HashSlice []Hash
+
+func (p HashSlice) Len() int { return len(p) }
+func (p HashSlice) Less(i, j int) bool { return bytes.Compare(p[i][:], p[j][:]) < 0 }
+func (p HashSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
--- /dev/null
+package plumbing
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+)
+
+// MemoryObject on memory Object implementation
+type MemoryObject struct {
+ t ObjectType
+ h Hash
+ cont []byte
+ sz int64
+}
+
+// Hash returns the object Hash, the hash is calculated on-the-fly the first
+// time it's called, in all subsequent calls the same Hash is returned even
+// if the type or the content have changed. The Hash is only generated if the
+// size of the content is exactly the object size.
+func (o *MemoryObject) Hash() Hash {
+ if o.h == ZeroHash && int64(len(o.cont)) == o.sz {
+ o.h = ComputeHash(o.t, o.cont)
+ }
+
+ return o.h
+}
+
+// Type return the ObjectType
+func (o *MemoryObject) Type() ObjectType { return o.t }
+
+// SetType sets the ObjectType
+func (o *MemoryObject) SetType(t ObjectType) { o.t = t }
+
+// Size return the size of the object
+func (o *MemoryObject) Size() int64 { return o.sz }
+
+// SetSize set the object size, a content of the given size should be written
+// afterwards
+func (o *MemoryObject) SetSize(s int64) { o.sz = s }
+
+// Reader returns a ObjectReader used to read the object's content.
+func (o *MemoryObject) Reader() (io.ReadCloser, error) {
+ return ioutil.NopCloser(bytes.NewBuffer(o.cont)), nil
+}
+
+// Writer returns a ObjectWriter used to write the object's content.
+func (o *MemoryObject) Writer() (io.WriteCloser, error) {
+ return o, nil
+}
+
+func (o *MemoryObject) Write(p []byte) (n int, err error) {
+ o.cont = append(o.cont, p...)
+ o.sz = int64(len(o.cont))
+
+ return len(p), nil
+}
+
+// Close releases any resources consumed by the object when it is acting as a
+// ObjectWriter.
+func (o *MemoryObject) Close() error { return nil }
--- /dev/null
+// package plumbing implement the core interfaces and structs used by go-git
+package plumbing
+
+import (
+ "errors"
+ "io"
+)
+
+var (
+ ErrObjectNotFound = errors.New("object not found")
+ // ErrInvalidType is returned when an invalid object type is provided.
+ ErrInvalidType = errors.New("invalid object type")
+)
+
+// Object is a generic representation of any git object
+type EncodedObject interface {
+ Hash() Hash
+ Type() ObjectType
+ SetType(ObjectType)
+ Size() int64
+ SetSize(int64)
+ Reader() (io.ReadCloser, error)
+ Writer() (io.WriteCloser, error)
+}
+
+// DeltaObject is an EncodedObject representing a delta.
+type DeltaObject interface {
+ EncodedObject
+ // BaseHash returns the hash of the object used as base for this delta.
+ BaseHash() Hash
+ // ActualHash returns the hash of the object after applying the delta.
+ ActualHash() Hash
+ // Size returns the size of the object after applying the delta.
+ ActualSize() int64
+}
+
+// ObjectType internal object type
+// Integer values from 0 to 7 map to those exposed by git.
+// AnyObject is used to represent any from 0 to 7.
+type ObjectType int8
+
+const (
+ InvalidObject ObjectType = 0
+ CommitObject ObjectType = 1
+ TreeObject ObjectType = 2
+ BlobObject ObjectType = 3
+ TagObject ObjectType = 4
+ // 5 reserved for future expansion
+ OFSDeltaObject ObjectType = 6
+ REFDeltaObject ObjectType = 7
+
+ AnyObject ObjectType = -127
+)
+
+func (t ObjectType) String() string {
+ switch t {
+ case CommitObject:
+ return "commit"
+ case TreeObject:
+ return "tree"
+ case BlobObject:
+ return "blob"
+ case TagObject:
+ return "tag"
+ case OFSDeltaObject:
+ return "ofs-delta"
+ case REFDeltaObject:
+ return "ref-delta"
+ case AnyObject:
+ return "any"
+ default:
+ return "unknown"
+ }
+}
+
+func (t ObjectType) Bytes() []byte {
+ return []byte(t.String())
+}
+
+// Valid returns true if t is a valid ObjectType.
+func (t ObjectType) Valid() bool {
+ return t >= CommitObject && t <= REFDeltaObject
+}
+
+// IsDelta returns true for any ObjectTyoe that represents a delta (i.e.
+// REFDeltaObject or OFSDeltaObject).
+func (t ObjectType) IsDelta() bool {
+ return t == REFDeltaObject || t == OFSDeltaObject
+}
+
+// ParseObjectType parses a string representation of ObjectType. It returns an
+// error on parse failure.
+func ParseObjectType(value string) (typ ObjectType, err error) {
+ switch value {
+ case "commit":
+ typ = CommitObject
+ case "tree":
+ typ = TreeObject
+ case "blob":
+ typ = BlobObject
+ case "tag":
+ typ = TagObject
+ case "ofs-delta":
+ typ = OFSDeltaObject
+ case "ref-delta":
+ typ = REFDeltaObject
+ default:
+ err = ErrInvalidType
+ }
+ return
+}
--- /dev/null
+package object
+
+import (
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// Blob is used to store arbitrary data - it is generally a file.
+type Blob struct {
+ // Hash of the blob.
+ Hash plumbing.Hash
+ // Size of the (uncompressed) blob.
+ Size int64
+
+ obj plumbing.EncodedObject
+}
+
+// GetBlob gets a blob from an object storer and decodes it.
+func GetBlob(s storer.EncodedObjectStorer, h plumbing.Hash) (*Blob, error) {
+ o, err := s.EncodedObject(plumbing.BlobObject, h)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeBlob(o)
+}
+
+// DecodeObject decodes an encoded object into a *Blob.
+func DecodeBlob(o plumbing.EncodedObject) (*Blob, error) {
+ b := &Blob{}
+ if err := b.Decode(o); err != nil {
+ return nil, err
+ }
+
+ return b, nil
+}
+
+// ID returns the object ID of the blob. The returned value will always match
+// the current value of Blob.Hash.
+//
+// ID is present to fulfill the Object interface.
+func (b *Blob) ID() plumbing.Hash {
+ return b.Hash
+}
+
+// Type returns the type of object. It always returns plumbing.BlobObject.
+//
+// Type is present to fulfill the Object interface.
+func (b *Blob) Type() plumbing.ObjectType {
+ return plumbing.BlobObject
+}
+
+// Decode transforms a plumbing.EncodedObject into a Blob struct.
+func (b *Blob) Decode(o plumbing.EncodedObject) error {
+ if o.Type() != plumbing.BlobObject {
+ return ErrUnsupportedObject
+ }
+
+ b.Hash = o.Hash()
+ b.Size = o.Size()
+ b.obj = o
+
+ return nil
+}
+
+// Encode transforms a Blob into a plumbing.EncodedObject.
+func (b *Blob) Encode(o plumbing.EncodedObject) (err error) {
+ o.SetType(plumbing.BlobObject)
+
+ w, err := o.Writer()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(w, &err)
+
+ r, err := b.Reader()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(r, &err)
+
+ _, err = io.Copy(w, r)
+ return err
+}
+
+// Reader returns a reader allow the access to the content of the blob
+func (b *Blob) Reader() (io.ReadCloser, error) {
+ return b.obj.Reader()
+}
+
+// BlobIter provides an iterator for a set of blobs.
+type BlobIter struct {
+ storer.EncodedObjectIter
+ s storer.EncodedObjectStorer
+}
+
+// NewBlobIter takes a storer.EncodedObjectStorer and a
+// storer.EncodedObjectIter and returns a *BlobIter that iterates over all
+// blobs contained in the storer.EncodedObjectIter.
+//
+// Any non-blob object returned by the storer.EncodedObjectIter is skipped.
+func NewBlobIter(s storer.EncodedObjectStorer, iter storer.EncodedObjectIter) *BlobIter {
+ return &BlobIter{iter, s}
+}
+
+// Next moves the iterator to the next blob and returns a pointer to it. If
+// there are no more blobs, it returns io.EOF.
+func (iter *BlobIter) Next() (*Blob, error) {
+ for {
+ obj, err := iter.EncodedObjectIter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ if obj.Type() != plumbing.BlobObject {
+ continue
+ }
+
+ return DecodeBlob(obj)
+ }
+}
+
+// ForEach call the cb function for each blob contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *BlobIter) ForEach(cb func(*Blob) error) error {
+ return iter.EncodedObjectIter.ForEach(func(obj plumbing.EncodedObject) error {
+ if obj.Type() != plumbing.BlobObject {
+ return nil
+ }
+
+ b, err := DecodeBlob(obj)
+ if err != nil {
+ return err
+ }
+
+ return cb(b)
+ })
+}
--- /dev/null
+package object
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie"
+)
+
+// Change values represent a detected change between two git trees. For
+// modifications, From is the original status of the node and To is its
+// final status. For insertions, From is the zero value and for
+// deletions To is the zero value.
+type Change struct {
+ From ChangeEntry
+ To ChangeEntry
+}
+
+var empty = ChangeEntry{}
+
+// Action returns the kind of action represented by the change, an
+// insertion, a deletion or a modification.
+func (c *Change) Action() (merkletrie.Action, error) {
+ if c.From == empty && c.To == empty {
+ return merkletrie.Action(0),
+ fmt.Errorf("malformed change: empty from and to")
+ }
+ if c.From == empty {
+ return merkletrie.Insert, nil
+ }
+ if c.To == empty {
+ return merkletrie.Delete, nil
+ }
+
+ return merkletrie.Modify, nil
+}
+
+// Files return the files before and after a change.
+// For insertions from will be nil. For deletions to will be nil.
+func (c *Change) Files() (from, to *File, err error) {
+ action, err := c.Action()
+ if err != nil {
+ return
+ }
+
+ if action == merkletrie.Insert || action == merkletrie.Modify {
+ to, err = c.To.Tree.TreeEntryFile(&c.To.TreeEntry)
+ if !c.To.TreeEntry.Mode.IsFile() {
+ return nil, nil, nil
+ }
+
+ if err != nil {
+ return
+ }
+ }
+
+ if action == merkletrie.Delete || action == merkletrie.Modify {
+ from, err = c.From.Tree.TreeEntryFile(&c.From.TreeEntry)
+ if !c.From.TreeEntry.Mode.IsFile() {
+ return nil, nil, nil
+ }
+
+ if err != nil {
+ return
+ }
+ }
+
+ return
+}
+
+func (c *Change) String() string {
+ action, err := c.Action()
+ if err != nil {
+ return fmt.Sprintf("malformed change")
+ }
+
+ return fmt.Sprintf("<Action: %s, Path: %s>", action, c.name())
+}
+
+// Patch returns a Patch with all the file changes in chunks. This
+// representation can be used to create several diff outputs.
+func (c *Change) Patch() (*Patch, error) {
+ return c.PatchContext(context.Background())
+}
+
+// Patch returns a Patch with all the file changes in chunks. This
+// representation can be used to create several diff outputs.
+// If context expires, an non-nil error will be returned
+// Provided context must be non-nil
+func (c *Change) PatchContext(ctx context.Context) (*Patch, error) {
+ return getPatchContext(ctx, "", c)
+}
+
+func (c *Change) name() string {
+ if c.From != empty {
+ return c.From.Name
+ }
+
+ return c.To.Name
+}
+
+// ChangeEntry values represent a node that has suffered a change.
+type ChangeEntry struct {
+ // Full path of the node using "/" as separator.
+ Name string
+ // Parent tree of the node that has changed.
+ Tree *Tree
+ // The entry of the node.
+ TreeEntry TreeEntry
+}
+
+// Changes represents a collection of changes between two git trees.
+// Implements sort.Interface lexicographically over the path of the
+// changed files.
+type Changes []*Change
+
+func (c Changes) Len() int {
+ return len(c)
+}
+
+func (c Changes) Swap(i, j int) {
+ c[i], c[j] = c[j], c[i]
+}
+
+func (c Changes) Less(i, j int) bool {
+ return strings.Compare(c[i].name(), c[j].name()) < 0
+}
+
+func (c Changes) String() string {
+ var buffer bytes.Buffer
+ buffer.WriteString("[")
+ comma := ""
+ for _, v := range c {
+ buffer.WriteString(comma)
+ buffer.WriteString(v.String())
+ comma = ", "
+ }
+ buffer.WriteString("]")
+
+ return buffer.String()
+}
+
+// Patch returns a Patch with all the changes in chunks. This
+// representation can be used to create several diff outputs.
+func (c Changes) Patch() (*Patch, error) {
+ return c.PatchContext(context.Background())
+}
+
+// Patch returns a Patch with all the changes in chunks. This
+// representation can be used to create several diff outputs.
+// If context expires, an non-nil error will be returned
+// Provided context must be non-nil
+func (c Changes) PatchContext(ctx context.Context) (*Patch, error) {
+ return getPatchContext(ctx, "", c...)
+}
--- /dev/null
+package object
+
+import (
+ "errors"
+ "fmt"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// The following functions transform changes types form the merkletrie
+// package to changes types from this package.
+
+func newChange(c merkletrie.Change) (*Change, error) {
+ ret := &Change{}
+
+ var err error
+ if ret.From, err = newChangeEntry(c.From); err != nil {
+ return nil, fmt.Errorf("From field: %s", err)
+ }
+
+ if ret.To, err = newChangeEntry(c.To); err != nil {
+ return nil, fmt.Errorf("To field: %s", err)
+ }
+
+ return ret, nil
+}
+
+func newChangeEntry(p noder.Path) (ChangeEntry, error) {
+ if p == nil {
+ return empty, nil
+ }
+
+ asTreeNoder, ok := p.Last().(*treeNoder)
+ if !ok {
+ return ChangeEntry{}, errors.New("cannot transform non-TreeNoders")
+ }
+
+ return ChangeEntry{
+ Name: p.String(),
+ Tree: asTreeNoder.parent,
+ TreeEntry: TreeEntry{
+ Name: asTreeNoder.name,
+ Mode: asTreeNoder.mode,
+ Hash: asTreeNoder.hash,
+ },
+ }, nil
+}
+
+func newChanges(src merkletrie.Changes) (Changes, error) {
+ ret := make(Changes, len(src))
+ var err error
+ for i, e := range src {
+ ret[i], err = newChange(e)
+ if err != nil {
+ return nil, fmt.Errorf("change #%d: %s", i, err)
+ }
+ }
+
+ return ret, nil
+}
--- /dev/null
+package object
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+
+ "golang.org/x/crypto/openpgp"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+const (
+ beginpgp string = "-----BEGIN PGP SIGNATURE-----"
+ endpgp string = "-----END PGP SIGNATURE-----"
+ headerpgp string = "gpgsig"
+)
+
+// Hash represents the hash of an object
+type Hash plumbing.Hash
+
+// Commit points to a single tree, marking it as what the project looked like
+// at a certain point in time. It contains meta-information about that point
+// in time, such as a timestamp, the author of the changes since the last
+// commit, a pointer to the previous commit(s), etc.
+// http://shafiulazam.com/gitbook/1_the_git_object_model.html
+type Commit struct {
+ // Hash of the commit object.
+ Hash plumbing.Hash
+ // Author is the original author of the commit.
+ Author Signature
+ // Committer is the one performing the commit, might be different from
+ // Author.
+ Committer Signature
+ // PGPSignature is the PGP signature of the commit.
+ PGPSignature string
+ // Message is the commit message, contains arbitrary text.
+ Message string
+ // TreeHash is the hash of the root tree of the commit.
+ TreeHash plumbing.Hash
+ // ParentHashes are the hashes of the parent commits of the commit.
+ ParentHashes []plumbing.Hash
+
+ s storer.EncodedObjectStorer
+}
+
+// GetCommit gets a commit from an object storer and decodes it.
+func GetCommit(s storer.EncodedObjectStorer, h plumbing.Hash) (*Commit, error) {
+ o, err := s.EncodedObject(plumbing.CommitObject, h)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeCommit(s, o)
+}
+
+// DecodeCommit decodes an encoded object into a *Commit and associates it to
+// the given object storer.
+func DecodeCommit(s storer.EncodedObjectStorer, o plumbing.EncodedObject) (*Commit, error) {
+ c := &Commit{s: s}
+ if err := c.Decode(o); err != nil {
+ return nil, err
+ }
+
+ return c, nil
+}
+
+// Tree returns the Tree from the commit.
+func (c *Commit) Tree() (*Tree, error) {
+ return GetTree(c.s, c.TreeHash)
+}
+
+// Patch returns the Patch between the actual commit and the provided one.
+// Error will be return if context expires. Provided context must be non-nil
+func (c *Commit) PatchContext(ctx context.Context, to *Commit) (*Patch, error) {
+ fromTree, err := c.Tree()
+ if err != nil {
+ return nil, err
+ }
+
+ toTree, err := to.Tree()
+ if err != nil {
+ return nil, err
+ }
+
+ return fromTree.PatchContext(ctx, toTree)
+}
+
+// Patch returns the Patch between the actual commit and the provided one.
+func (c *Commit) Patch(to *Commit) (*Patch, error) {
+ return c.PatchContext(context.Background(), to)
+}
+
+// Parents return a CommitIter to the parent Commits.
+func (c *Commit) Parents() CommitIter {
+ return NewCommitIter(c.s,
+ storer.NewEncodedObjectLookupIter(c.s, plumbing.CommitObject, c.ParentHashes),
+ )
+}
+
+// NumParents returns the number of parents in a commit.
+func (c *Commit) NumParents() int {
+ return len(c.ParentHashes)
+}
+
+var ErrParentNotFound = errors.New("commit parent not found")
+
+// Parent returns the ith parent of a commit.
+func (c *Commit) Parent(i int) (*Commit, error) {
+ if len(c.ParentHashes) == 0 || i > len(c.ParentHashes)-1 {
+ return nil, ErrParentNotFound
+ }
+
+ return GetCommit(c.s, c.ParentHashes[i])
+}
+
+// File returns the file with the specified "path" in the commit and a
+// nil error if the file exists. If the file does not exist, it returns
+// a nil file and the ErrFileNotFound error.
+func (c *Commit) File(path string) (*File, error) {
+ tree, err := c.Tree()
+ if err != nil {
+ return nil, err
+ }
+
+ return tree.File(path)
+}
+
+// Files returns a FileIter allowing to iterate over the Tree
+func (c *Commit) Files() (*FileIter, error) {
+ tree, err := c.Tree()
+ if err != nil {
+ return nil, err
+ }
+
+ return tree.Files(), nil
+}
+
+// ID returns the object ID of the commit. The returned value will always match
+// the current value of Commit.Hash.
+//
+// ID is present to fulfill the Object interface.
+func (c *Commit) ID() plumbing.Hash {
+ return c.Hash
+}
+
+// Type returns the type of object. It always returns plumbing.CommitObject.
+//
+// Type is present to fulfill the Object interface.
+func (c *Commit) Type() plumbing.ObjectType {
+ return plumbing.CommitObject
+}
+
+// Decode transforms a plumbing.EncodedObject into a Commit struct.
+func (c *Commit) Decode(o plumbing.EncodedObject) (err error) {
+ if o.Type() != plumbing.CommitObject {
+ return ErrUnsupportedObject
+ }
+
+ c.Hash = o.Hash()
+
+ reader, err := o.Reader()
+ if err != nil {
+ return err
+ }
+ defer ioutil.CheckClose(reader, &err)
+
+ r := bufio.NewReader(reader)
+
+ var message bool
+ var pgpsig bool
+ for {
+ line, err := r.ReadBytes('\n')
+ if err != nil && err != io.EOF {
+ return err
+ }
+
+ if pgpsig {
+ if len(line) > 0 && line[0] == ' ' {
+ line = bytes.TrimLeft(line, " ")
+ c.PGPSignature += string(line)
+ continue
+ } else {
+ pgpsig = false
+ }
+ }
+
+ if !message {
+ line = bytes.TrimSpace(line)
+ if len(line) == 0 {
+ message = true
+ continue
+ }
+
+ split := bytes.SplitN(line, []byte{' '}, 2)
+
+ var data []byte
+ if len(split) == 2 {
+ data = split[1]
+ }
+
+ switch string(split[0]) {
+ case "tree":
+ c.TreeHash = plumbing.NewHash(string(data))
+ case "parent":
+ c.ParentHashes = append(c.ParentHashes, plumbing.NewHash(string(data)))
+ case "author":
+ c.Author.Decode(data)
+ case "committer":
+ c.Committer.Decode(data)
+ case headerpgp:
+ c.PGPSignature += string(data) + "\n"
+ pgpsig = true
+ }
+ } else {
+ c.Message += string(line)
+ }
+
+ if err == io.EOF {
+ return nil
+ }
+ }
+}
+
+// Encode transforms a Commit into a plumbing.EncodedObject.
+func (b *Commit) Encode(o plumbing.EncodedObject) error {
+ return b.encode(o, true)
+}
+
+func (b *Commit) encode(o plumbing.EncodedObject, includeSig bool) (err error) {
+ o.SetType(plumbing.CommitObject)
+ w, err := o.Writer()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(w, &err)
+
+ if _, err = fmt.Fprintf(w, "tree %s\n", b.TreeHash.String()); err != nil {
+ return err
+ }
+
+ for _, parent := range b.ParentHashes {
+ if _, err = fmt.Fprintf(w, "parent %s\n", parent.String()); err != nil {
+ return err
+ }
+ }
+
+ if _, err = fmt.Fprint(w, "author "); err != nil {
+ return err
+ }
+
+ if err = b.Author.Encode(w); err != nil {
+ return err
+ }
+
+ if _, err = fmt.Fprint(w, "\ncommitter "); err != nil {
+ return err
+ }
+
+ if err = b.Committer.Encode(w); err != nil {
+ return err
+ }
+
+ if b.PGPSignature != "" && includeSig {
+ if _, err = fmt.Fprint(w, "\n"+headerpgp+" "); err != nil {
+ return err
+ }
+
+ // Split all the signature lines and re-write with a left padding and
+ // newline. Use join for this so it's clear that a newline should not be
+ // added after this section, as it will be added when the message is
+ // printed.
+ signature := strings.TrimSuffix(b.PGPSignature, "\n")
+ lines := strings.Split(signature, "\n")
+ if _, err = fmt.Fprint(w, strings.Join(lines, "\n ")); err != nil {
+ return err
+ }
+ }
+
+ if _, err = fmt.Fprintf(w, "\n\n%s", b.Message); err != nil {
+ return err
+ }
+
+ return err
+}
+
+// Stats shows the status of commit.
+func (c *Commit) Stats() (FileStats, error) {
+ // Get the previous commit.
+ ci := c.Parents()
+ parentCommit, err := ci.Next()
+ if err != nil {
+ if err == io.EOF {
+ emptyNoder := treeNoder{}
+ parentCommit = &Commit{
+ Hash: emptyNoder.hash,
+ // TreeHash: emptyNoder.parent.Hash,
+ s: c.s,
+ }
+ } else {
+ return nil, err
+ }
+ }
+
+ patch, err := parentCommit.Patch(c)
+ if err != nil {
+ return nil, err
+ }
+
+ return getFileStatsFromFilePatches(patch.FilePatches()), nil
+}
+
+func (c *Commit) String() string {
+ return fmt.Sprintf(
+ "%s %s\nAuthor: %s\nDate: %s\n\n%s\n",
+ plumbing.CommitObject, c.Hash, c.Author.String(),
+ c.Author.When.Format(DateFormat), indent(c.Message),
+ )
+}
+
+// Verify performs PGP verification of the commit with a provided armored
+// keyring and returns openpgp.Entity associated with verifying key on success.
+func (c *Commit) Verify(armoredKeyRing string) (*openpgp.Entity, error) {
+ keyRingReader := strings.NewReader(armoredKeyRing)
+ keyring, err := openpgp.ReadArmoredKeyRing(keyRingReader)
+ if err != nil {
+ return nil, err
+ }
+
+ // Extract signature.
+ signature := strings.NewReader(c.PGPSignature)
+
+ encoded := &plumbing.MemoryObject{}
+ // Encode commit components, excluding signature and get a reader object.
+ if err := c.encode(encoded, false); err != nil {
+ return nil, err
+ }
+ er, err := encoded.Reader()
+ if err != nil {
+ return nil, err
+ }
+
+ return openpgp.CheckArmoredDetachedSignature(keyring, er, signature)
+}
+
+func indent(t string) string {
+ var output []string
+ for _, line := range strings.Split(t, "\n") {
+ if len(line) != 0 {
+ line = " " + line
+ }
+
+ output = append(output, line)
+ }
+
+ return strings.Join(output, "\n")
+}
+
+// CommitIter is a generic closable interface for iterating over commits.
+type CommitIter interface {
+ Next() (*Commit, error)
+ ForEach(func(*Commit) error) error
+ Close()
+}
+
+// storerCommitIter provides an iterator from commits in an EncodedObjectStorer.
+type storerCommitIter struct {
+ storer.EncodedObjectIter
+ s storer.EncodedObjectStorer
+}
+
+// NewCommitIter takes a storer.EncodedObjectStorer and a
+// storer.EncodedObjectIter and returns a CommitIter that iterates over all
+// commits contained in the storer.EncodedObjectIter.
+//
+// Any non-commit object returned by the storer.EncodedObjectIter is skipped.
+func NewCommitIter(s storer.EncodedObjectStorer, iter storer.EncodedObjectIter) CommitIter {
+ return &storerCommitIter{iter, s}
+}
+
+// Next moves the iterator to the next commit and returns a pointer to it. If
+// there are no more commits, it returns io.EOF.
+func (iter *storerCommitIter) Next() (*Commit, error) {
+ obj, err := iter.EncodedObjectIter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeCommit(iter.s, obj)
+}
+
+// ForEach call the cb function for each commit contained on this iter until
+// an error appends or the end of the iter is reached. If ErrStop is sent
+// the iteration is stopped but no error is returned. The iterator is closed.
+func (iter *storerCommitIter) ForEach(cb func(*Commit) error) error {
+ return iter.EncodedObjectIter.ForEach(func(obj plumbing.EncodedObject) error {
+ c, err := DecodeCommit(iter.s, obj)
+ if err != nil {
+ return err
+ }
+
+ return cb(c)
+ })
+}
+
+func (iter *storerCommitIter) Close() {
+ iter.EncodedObjectIter.Close()
+}
--- /dev/null
+package object
+
+import (
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+type commitPreIterator struct {
+ seenExternal map[plumbing.Hash]bool
+ seen map[plumbing.Hash]bool
+ stack []CommitIter
+ start *Commit
+}
+
+// NewCommitPreorderIter returns a CommitIter that walks the commit history,
+// starting at the given commit and visiting its parents in pre-order.
+// The given callback will be called for each visited commit. Each commit will
+// be visited only once. If the callback returns an error, walking will stop
+// and will return the error. Other errors might be returned if the history
+// cannot be traversed (e.g. missing objects). Ignore allows to skip some
+// commits from being iterated.
+func NewCommitPreorderIter(
+ c *Commit,
+ seenExternal map[plumbing.Hash]bool,
+ ignore []plumbing.Hash,
+) CommitIter {
+ seen := make(map[plumbing.Hash]bool)
+ for _, h := range ignore {
+ seen[h] = true
+ }
+
+ return &commitPreIterator{
+ seenExternal: seenExternal,
+ seen: seen,
+ stack: make([]CommitIter, 0),
+ start: c,
+ }
+}
+
+func (w *commitPreIterator) Next() (*Commit, error) {
+ var c *Commit
+ for {
+ if w.start != nil {
+ c = w.start
+ w.start = nil
+ } else {
+ current := len(w.stack) - 1
+ if current < 0 {
+ return nil, io.EOF
+ }
+
+ var err error
+ c, err = w.stack[current].Next()
+ if err == io.EOF {
+ w.stack = w.stack[:current]
+ continue
+ }
+
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if w.seen[c.Hash] || w.seenExternal[c.Hash] {
+ continue
+ }
+
+ w.seen[c.Hash] = true
+
+ if c.NumParents() > 0 {
+ w.stack = append(w.stack, filteredParentIter(c, w.seen))
+ }
+
+ return c, nil
+ }
+}
+
+func filteredParentIter(c *Commit, seen map[plumbing.Hash]bool) CommitIter {
+ var hashes []plumbing.Hash
+ for _, h := range c.ParentHashes {
+ if !seen[h] {
+ hashes = append(hashes, h)
+ }
+ }
+
+ return NewCommitIter(c.s,
+ storer.NewEncodedObjectLookupIter(c.s, plumbing.CommitObject, hashes),
+ )
+}
+
+func (w *commitPreIterator) ForEach(cb func(*Commit) error) error {
+ for {
+ c, err := w.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+
+ err = cb(c)
+ if err == storer.ErrStop {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *commitPreIterator) Close() {}
+
+type commitPostIterator struct {
+ stack []*Commit
+ seen map[plumbing.Hash]bool
+}
+
+// NewCommitPostorderIter returns a CommitIter that walks the commit
+// history like WalkCommitHistory but in post-order. This means that after
+// walking a merge commit, the merged commit will be walked before the base
+// it was merged on. This can be useful if you wish to see the history in
+// chronological order. Ignore allows to skip some commits from being iterated.
+func NewCommitPostorderIter(c *Commit, ignore []plumbing.Hash) CommitIter {
+ seen := make(map[plumbing.Hash]bool)
+ for _, h := range ignore {
+ seen[h] = true
+ }
+
+ return &commitPostIterator{
+ stack: []*Commit{c},
+ seen: seen,
+ }
+}
+
+func (w *commitPostIterator) Next() (*Commit, error) {
+ for {
+ if len(w.stack) == 0 {
+ return nil, io.EOF
+ }
+
+ c := w.stack[len(w.stack)-1]
+ w.stack = w.stack[:len(w.stack)-1]
+
+ if w.seen[c.Hash] {
+ continue
+ }
+
+ w.seen[c.Hash] = true
+
+ return c, c.Parents().ForEach(func(p *Commit) error {
+ w.stack = append(w.stack, p)
+ return nil
+ })
+ }
+}
+
+func (w *commitPostIterator) ForEach(cb func(*Commit) error) error {
+ for {
+ c, err := w.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+
+ err = cb(c)
+ if err == storer.ErrStop {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *commitPostIterator) Close() {}
--- /dev/null
+package object
+
+import (
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+type bfsCommitIterator struct {
+ seenExternal map[plumbing.Hash]bool
+ seen map[plumbing.Hash]bool
+ queue []*Commit
+}
+
+// NewCommitIterBSF returns a CommitIter that walks the commit history,
+// starting at the given commit and visiting its parents in pre-order.
+// The given callback will be called for each visited commit. Each commit will
+// be visited only once. If the callback returns an error, walking will stop
+// and will return the error. Other errors might be returned if the history
+// cannot be traversed (e.g. missing objects). Ignore allows to skip some
+// commits from being iterated.
+func NewCommitIterBSF(
+ c *Commit,
+ seenExternal map[plumbing.Hash]bool,
+ ignore []plumbing.Hash,
+) CommitIter {
+ seen := make(map[plumbing.Hash]bool)
+ for _, h := range ignore {
+ seen[h] = true
+ }
+
+ return &bfsCommitIterator{
+ seenExternal: seenExternal,
+ seen: seen,
+ queue: []*Commit{c},
+ }
+}
+
+func (w *bfsCommitIterator) appendHash(store storer.EncodedObjectStorer, h plumbing.Hash) error {
+ if w.seen[h] || w.seenExternal[h] {
+ return nil
+ }
+ c, err := GetCommit(store, h)
+ if err != nil {
+ return err
+ }
+ w.queue = append(w.queue, c)
+ return nil
+}
+
+func (w *bfsCommitIterator) Next() (*Commit, error) {
+ var c *Commit
+ for {
+ if len(w.queue) == 0 {
+ return nil, io.EOF
+ }
+ c = w.queue[0]
+ w.queue = w.queue[1:]
+
+ if w.seen[c.Hash] || w.seenExternal[c.Hash] {
+ continue
+ }
+
+ w.seen[c.Hash] = true
+
+ for _, h := range c.ParentHashes {
+ err := w.appendHash(c.s, h)
+ if err != nil {
+ return nil, nil
+ }
+ }
+
+ return c, nil
+ }
+}
+
+func (w *bfsCommitIterator) ForEach(cb func(*Commit) error) error {
+ for {
+ c, err := w.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+
+ err = cb(c)
+ if err == storer.ErrStop {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *bfsCommitIterator) Close() {}
--- /dev/null
+package object
+
+import (
+ "io"
+
+ "github.com/emirpasic/gods/trees/binaryheap"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+type commitIteratorByCTime struct {
+ seenExternal map[plumbing.Hash]bool
+ seen map[plumbing.Hash]bool
+ heap *binaryheap.Heap
+}
+
+// NewCommitIterCTime returns a CommitIter that walks the commit history,
+// starting at the given commit and visiting its parents while preserving Committer Time order.
+// this appears to be the closest order to `git log`
+// The given callback will be called for each visited commit. Each commit will
+// be visited only once. If the callback returns an error, walking will stop
+// and will return the error. Other errors might be returned if the history
+// cannot be traversed (e.g. missing objects). Ignore allows to skip some
+// commits from being iterated.
+func NewCommitIterCTime(
+ c *Commit,
+ seenExternal map[plumbing.Hash]bool,
+ ignore []plumbing.Hash,
+) CommitIter {
+ seen := make(map[plumbing.Hash]bool)
+ for _, h := range ignore {
+ seen[h] = true
+ }
+
+ heap := binaryheap.NewWith(func(a, b interface{}) int {
+ if a.(*Commit).Committer.When.Before(b.(*Commit).Committer.When) {
+ return 1
+ }
+ return -1
+ })
+ heap.Push(c)
+
+ return &commitIteratorByCTime{
+ seenExternal: seenExternal,
+ seen: seen,
+ heap: heap,
+ }
+}
+
+func (w *commitIteratorByCTime) Next() (*Commit, error) {
+ var c *Commit
+ for {
+ cIn, ok := w.heap.Pop()
+ if !ok {
+ return nil, io.EOF
+ }
+ c = cIn.(*Commit)
+
+ if w.seen[c.Hash] || w.seenExternal[c.Hash] {
+ continue
+ }
+
+ w.seen[c.Hash] = true
+
+ for _, h := range c.ParentHashes {
+ if w.seen[h] || w.seenExternal[h] {
+ continue
+ }
+ pc, err := GetCommit(c.s, h)
+ if err != nil {
+ return nil, err
+ }
+ w.heap.Push(pc)
+ }
+
+ return c, nil
+ }
+}
+
+func (w *commitIteratorByCTime) ForEach(cb func(*Commit) error) error {
+ for {
+ c, err := w.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+
+ err = cb(c)
+ if err == storer.ErrStop {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *commitIteratorByCTime) Close() {}
--- /dev/null
+package object
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "io"
+)
+
+type commitFileIter struct {
+ fileName string
+ sourceIter CommitIter
+ currentCommit *Commit
+}
+
+// NewCommitFileIterFromIter returns a commit iterator which performs diffTree between
+// successive trees returned from the commit iterator from the argument. The purpose of this is
+// to find the commits that explain how the files that match the path came to be.
+func NewCommitFileIterFromIter(fileName string, commitIter CommitIter) CommitIter {
+ iterator := new(commitFileIter)
+ iterator.sourceIter = commitIter
+ iterator.fileName = fileName
+ return iterator
+}
+
+func (c *commitFileIter) Next() (*Commit, error) {
+ if c.currentCommit == nil {
+ var err error
+ c.currentCommit, err = c.sourceIter.Next()
+ if err != nil {
+ return nil, err
+ }
+ }
+ commit, commitErr := c.getNextFileCommit()
+
+ // Setting current-commit to nil to prevent unwanted states when errors are raised
+ if commitErr != nil {
+ c.currentCommit = nil
+ }
+ return commit, commitErr
+}
+
+func (c *commitFileIter) getNextFileCommit() (*Commit, error) {
+ for {
+ // Parent-commit can be nil if the current-commit is the initial commit
+ parentCommit, parentCommitErr := c.sourceIter.Next()
+ if parentCommitErr != nil {
+ // If the parent-commit is beyond the initial commit, keep it nil
+ if parentCommitErr != io.EOF {
+ return nil, parentCommitErr
+ }
+ parentCommit = nil
+ }
+
+ // Fetch the trees of the current and parent commits
+ currentTree, currTreeErr := c.currentCommit.Tree()
+ if currTreeErr != nil {
+ return nil, currTreeErr
+ }
+
+ var parentTree *Tree
+ if parentCommit != nil {
+ var parentTreeErr error
+ parentTree, parentTreeErr = parentCommit.Tree()
+ if parentTreeErr != nil {
+ return nil, parentTreeErr
+ }
+ }
+
+ // Find diff between current and parent trees
+ changes, diffErr := DiffTree(currentTree, parentTree)
+ if diffErr != nil {
+ return nil, diffErr
+ }
+
+ foundChangeForFile := false
+ for _, change := range changes {
+ if change.name() == c.fileName {
+ foundChangeForFile = true
+ break
+ }
+ }
+
+ // Storing the current-commit in-case a change is found, and
+ // Updating the current-commit for the next-iteration
+ prevCommit := c.currentCommit
+ c.currentCommit = parentCommit
+
+ if foundChangeForFile == true {
+ return prevCommit, nil
+ }
+
+ // If not matches found and if parent-commit is beyond the initial commit, then return with EOF
+ if parentCommit == nil {
+ return nil, io.EOF
+ }
+ }
+}
+
+func (c *commitFileIter) ForEach(cb func(*Commit) error) error {
+ for {
+ commit, nextErr := c.Next()
+ if nextErr != nil {
+ return nextErr
+ }
+ err := cb(commit)
+ if err == storer.ErrStop {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+func (c *commitFileIter) Close() {
+ c.sourceIter.Close()
+}
--- /dev/null
+package object
+
+import (
+ "bytes"
+ "context"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// DiffTree compares the content and mode of the blobs found via two
+// tree objects.
+func DiffTree(a, b *Tree) (Changes, error) {
+ return DiffTreeContext(context.Background(), a, b)
+}
+
+// DiffTree compares the content and mode of the blobs found via two
+// tree objects. Provided context must be non-nil.
+// An error will be return if context expires
+func DiffTreeContext(ctx context.Context, a, b *Tree) (Changes, error) {
+ from := NewTreeRootNode(a)
+ to := NewTreeRootNode(b)
+
+ hashEqual := func(a, b noder.Hasher) bool {
+ return bytes.Equal(a.Hash(), b.Hash())
+ }
+
+ merkletrieChanges, err := merkletrie.DiffTreeContext(ctx, from, to, hashEqual)
+ if err != nil {
+ if err == merkletrie.ErrCanceled {
+ return nil, ErrCanceled
+ }
+ return nil, err
+ }
+
+ return newChanges(merkletrieChanges)
+}
--- /dev/null
+package object
+
+import (
+ "bytes"
+ "io"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/binary"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// File represents git file objects.
+type File struct {
+ // Name is the path of the file. It might be relative to a tree,
+ // depending of the function that generates it.
+ Name string
+ // Mode is the file mode.
+ Mode filemode.FileMode
+ // Blob with the contents of the file.
+ Blob
+}
+
+// NewFile returns a File based on the given blob object
+func NewFile(name string, m filemode.FileMode, b *Blob) *File {
+ return &File{Name: name, Mode: m, Blob: *b}
+}
+
+// Contents returns the contents of a file as a string.
+func (f *File) Contents() (content string, err error) {
+ reader, err := f.Reader()
+ if err != nil {
+ return "", err
+ }
+ defer ioutil.CheckClose(reader, &err)
+
+ buf := new(bytes.Buffer)
+ if _, err := buf.ReadFrom(reader); err != nil {
+ return "", err
+ }
+
+ return buf.String(), nil
+}
+
+// IsBinary returns if the file is binary or not
+func (f *File) IsBinary() (bin bool, err error) {
+ reader, err := f.Reader()
+ if err != nil {
+ return false, err
+ }
+ defer ioutil.CheckClose(reader, &err)
+
+ return binary.IsBinary(reader)
+}
+
+// Lines returns a slice of lines from the contents of a file, stripping
+// all end of line characters. If the last line is empty (does not end
+// in an end of line), it is also stripped.
+func (f *File) Lines() ([]string, error) {
+ content, err := f.Contents()
+ if err != nil {
+ return nil, err
+ }
+
+ splits := strings.Split(content, "\n")
+ // remove the last line if it is empty
+ if splits[len(splits)-1] == "" {
+ return splits[:len(splits)-1], nil
+ }
+
+ return splits, nil
+}
+
+// FileIter provides an iterator for the files in a tree.
+type FileIter struct {
+ s storer.EncodedObjectStorer
+ w TreeWalker
+}
+
+// NewFileIter takes a storer.EncodedObjectStorer and a Tree and returns a
+// *FileIter that iterates over all files contained in the tree, recursively.
+func NewFileIter(s storer.EncodedObjectStorer, t *Tree) *FileIter {
+ return &FileIter{s: s, w: *NewTreeWalker(t, true, nil)}
+}
+
+// Next moves the iterator to the next file and returns a pointer to it. If
+// there are no more files, it returns io.EOF.
+func (iter *FileIter) Next() (*File, error) {
+ for {
+ name, entry, err := iter.w.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ if entry.Mode == filemode.Dir || entry.Mode == filemode.Submodule {
+ continue
+ }
+
+ blob, err := GetBlob(iter.s, entry.Hash)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewFile(name, entry.Mode, blob), nil
+ }
+}
+
+// ForEach call the cb function for each file contained in this iter until
+// an error happens or the end of the iter is reached. If plumbing.ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *FileIter) ForEach(cb func(*File) error) error {
+ defer iter.Close()
+
+ for {
+ f, err := iter.Next()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+
+ return err
+ }
+
+ if err := cb(f); err != nil {
+ if err == storer.ErrStop {
+ return nil
+ }
+
+ return err
+ }
+ }
+}
+
+func (iter *FileIter) Close() {
+ iter.w.Close()
+}
--- /dev/null
+// Package object contains implementations of all Git objects and utility
+// functions to work with them.
+package object
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+// ErrUnsupportedObject trigger when a non-supported object is being decoded.
+var ErrUnsupportedObject = errors.New("unsupported object type")
+
+// Object is a generic representation of any git object. It is implemented by
+// Commit, Tree, Blob, and Tag, and includes the functions that are common to
+// them.
+//
+// Object is returned when an object can be of any type. It is frequently used
+// with a type cast to acquire the specific type of object:
+//
+// func process(obj Object) {
+// switch o := obj.(type) {
+// case *Commit:
+// // o is a Commit
+// case *Tree:
+// // o is a Tree
+// case *Blob:
+// // o is a Blob
+// case *Tag:
+// // o is a Tag
+// }
+// }
+//
+// This interface is intentionally different from plumbing.EncodedObject, which
+// is a lower level interface used by storage implementations to read and write
+// objects in its encoded form.
+type Object interface {
+ ID() plumbing.Hash
+ Type() plumbing.ObjectType
+ Decode(plumbing.EncodedObject) error
+ Encode(plumbing.EncodedObject) error
+}
+
+// GetObject gets an object from an object storer and decodes it.
+func GetObject(s storer.EncodedObjectStorer, h plumbing.Hash) (Object, error) {
+ o, err := s.EncodedObject(plumbing.AnyObject, h)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeObject(s, o)
+}
+
+// DecodeObject decodes an encoded object into an Object and associates it to
+// the given object storer.
+func DecodeObject(s storer.EncodedObjectStorer, o plumbing.EncodedObject) (Object, error) {
+ switch o.Type() {
+ case plumbing.CommitObject:
+ return DecodeCommit(s, o)
+ case plumbing.TreeObject:
+ return DecodeTree(s, o)
+ case plumbing.BlobObject:
+ return DecodeBlob(o)
+ case plumbing.TagObject:
+ return DecodeTag(s, o)
+ default:
+ return nil, plumbing.ErrInvalidType
+ }
+}
+
+// DateFormat is the format being used in the original git implementation
+const DateFormat = "Mon Jan 02 15:04:05 2006 -0700"
+
+// Signature is used to identify who and when created a commit or tag.
+type Signature struct {
+ // Name represents a person name. It is an arbitrary string.
+ Name string
+ // Email is an email, but it cannot be assumed to be well-formed.
+ Email string
+ // When is the timestamp of the signature.
+ When time.Time
+}
+
+// Decode decodes a byte slice into a signature
+func (s *Signature) Decode(b []byte) {
+ open := bytes.LastIndexByte(b, '<')
+ close := bytes.LastIndexByte(b, '>')
+ if open == -1 || close == -1 {
+ return
+ }
+
+ if close < open {
+ return
+ }
+
+ s.Name = string(bytes.Trim(b[:open], " "))
+ s.Email = string(b[open+1 : close])
+
+ hasTime := close+2 < len(b)
+ if hasTime {
+ s.decodeTimeAndTimeZone(b[close+2:])
+ }
+}
+
+// Encode encodes a Signature into a writer.
+func (s *Signature) Encode(w io.Writer) error {
+ if _, err := fmt.Fprintf(w, "%s <%s> ", s.Name, s.Email); err != nil {
+ return err
+ }
+ if err := s.encodeTimeAndTimeZone(w); err != nil {
+ return err
+ }
+ return nil
+}
+
+var timeZoneLength = 5
+
+func (s *Signature) decodeTimeAndTimeZone(b []byte) {
+ space := bytes.IndexByte(b, ' ')
+ if space == -1 {
+ space = len(b)
+ }
+
+ ts, err := strconv.ParseInt(string(b[:space]), 10, 64)
+ if err != nil {
+ return
+ }
+
+ s.When = time.Unix(ts, 0).In(time.UTC)
+ var tzStart = space + 1
+ if tzStart >= len(b) || tzStart+timeZoneLength > len(b) {
+ return
+ }
+
+ // Include a dummy year in this time.Parse() call to avoid a bug in Go:
+ // https://github.com/golang/go/issues/19750
+ //
+ // Parsing the timezone with no other details causes the tl.Location() call
+ // below to return time.Local instead of the parsed zone in some cases
+ tl, err := time.Parse("2006 -0700", "1970 "+string(b[tzStart:tzStart+timeZoneLength]))
+ if err != nil {
+ return
+ }
+
+ s.When = s.When.In(tl.Location())
+}
+
+func (s *Signature) encodeTimeAndTimeZone(w io.Writer) error {
+ u := s.When.Unix()
+ if u < 0 {
+ u = 0
+ }
+ _, err := fmt.Fprintf(w, "%d %s", u, s.When.Format("-0700"))
+ return err
+}
+
+func (s *Signature) String() string {
+ return fmt.Sprintf("%s <%s>", s.Name, s.Email)
+}
+
+// ObjectIter provides an iterator for a set of objects.
+type ObjectIter struct {
+ storer.EncodedObjectIter
+ s storer.EncodedObjectStorer
+}
+
+// NewObjectIter takes a storer.EncodedObjectStorer and a
+// storer.EncodedObjectIter and returns an *ObjectIter that iterates over all
+// objects contained in the storer.EncodedObjectIter.
+func NewObjectIter(s storer.EncodedObjectStorer, iter storer.EncodedObjectIter) *ObjectIter {
+ return &ObjectIter{iter, s}
+}
+
+// Next moves the iterator to the next object and returns a pointer to it. If
+// there are no more objects, it returns io.EOF.
+func (iter *ObjectIter) Next() (Object, error) {
+ for {
+ obj, err := iter.EncodedObjectIter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ o, err := iter.toObject(obj)
+ if err == plumbing.ErrInvalidType {
+ continue
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return o, nil
+ }
+}
+
+// ForEach call the cb function for each object contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *ObjectIter) ForEach(cb func(Object) error) error {
+ return iter.EncodedObjectIter.ForEach(func(obj plumbing.EncodedObject) error {
+ o, err := iter.toObject(obj)
+ if err == plumbing.ErrInvalidType {
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+
+ return cb(o)
+ })
+}
+
+func (iter *ObjectIter) toObject(obj plumbing.EncodedObject) (Object, error) {
+ switch obj.Type() {
+ case plumbing.BlobObject:
+ blob := &Blob{}
+ return blob, blob.Decode(obj)
+ case plumbing.TreeObject:
+ tree := &Tree{s: iter.s}
+ return tree, tree.Decode(obj)
+ case plumbing.CommitObject:
+ commit := &Commit{}
+ return commit, commit.Decode(obj)
+ case plumbing.TagObject:
+ tag := &Tag{}
+ return tag, tag.Decode(obj)
+ default:
+ return nil, plumbing.ErrInvalidType
+ }
+}
--- /dev/null
+package object
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ fdiff "gopkg.in/src-d/go-git.v4/plumbing/format/diff"
+ "gopkg.in/src-d/go-git.v4/utils/diff"
+
+ dmp "github.com/sergi/go-diff/diffmatchpatch"
+)
+
+var (
+ ErrCanceled = errors.New("operation canceled")
+)
+
+func getPatch(message string, changes ...*Change) (*Patch, error) {
+ ctx := context.Background()
+ return getPatchContext(ctx, message, changes...)
+}
+
+func getPatchContext(ctx context.Context, message string, changes ...*Change) (*Patch, error) {
+ var filePatches []fdiff.FilePatch
+ for _, c := range changes {
+ select {
+ case <-ctx.Done():
+ return nil, ErrCanceled
+ default:
+ }
+
+ fp, err := filePatchWithContext(ctx, c)
+ if err != nil {
+ return nil, err
+ }
+
+ filePatches = append(filePatches, fp)
+ }
+
+ return &Patch{message, filePatches}, nil
+}
+
+func filePatchWithContext(ctx context.Context, c *Change) (fdiff.FilePatch, error) {
+ from, to, err := c.Files()
+ if err != nil {
+ return nil, err
+ }
+ fromContent, fIsBinary, err := fileContent(from)
+ if err != nil {
+ return nil, err
+ }
+
+ toContent, tIsBinary, err := fileContent(to)
+ if err != nil {
+ return nil, err
+ }
+
+ if fIsBinary || tIsBinary {
+ return &textFilePatch{from: c.From, to: c.To}, nil
+ }
+
+ diffs := diff.Do(fromContent, toContent)
+
+ var chunks []fdiff.Chunk
+ for _, d := range diffs {
+ select {
+ case <-ctx.Done():
+ return nil, ErrCanceled
+ default:
+ }
+
+ var op fdiff.Operation
+ switch d.Type {
+ case dmp.DiffEqual:
+ op = fdiff.Equal
+ case dmp.DiffDelete:
+ op = fdiff.Delete
+ case dmp.DiffInsert:
+ op = fdiff.Add
+ }
+
+ chunks = append(chunks, &textChunk{d.Text, op})
+ }
+
+ return &textFilePatch{
+ chunks: chunks,
+ from: c.From,
+ to: c.To,
+ }, nil
+
+}
+
+func filePatch(c *Change) (fdiff.FilePatch, error) {
+ return filePatchWithContext(context.Background(), c)
+}
+
+func fileContent(f *File) (content string, isBinary bool, err error) {
+ if f == nil {
+ return
+ }
+
+ isBinary, err = f.IsBinary()
+ if err != nil || isBinary {
+ return
+ }
+
+ content, err = f.Contents()
+
+ return
+}
+
+// textPatch is an implementation of fdiff.Patch interface
+type Patch struct {
+ message string
+ filePatches []fdiff.FilePatch
+}
+
+func (t *Patch) FilePatches() []fdiff.FilePatch {
+ return t.filePatches
+}
+
+func (t *Patch) Message() string {
+ return t.message
+}
+
+func (p *Patch) Encode(w io.Writer) error {
+ ue := fdiff.NewUnifiedEncoder(w, fdiff.DefaultContextLines)
+
+ return ue.Encode(p)
+}
+
+func (p *Patch) Stats() FileStats {
+ return getFileStatsFromFilePatches(p.FilePatches())
+}
+
+func (p *Patch) String() string {
+ buf := bytes.NewBuffer(nil)
+ err := p.Encode(buf)
+ if err != nil {
+ return fmt.Sprintf("malformed patch: %s", err.Error())
+ }
+
+ return buf.String()
+}
+
+// changeEntryWrapper is an implementation of fdiff.File interface
+type changeEntryWrapper struct {
+ ce ChangeEntry
+}
+
+func (f *changeEntryWrapper) Hash() plumbing.Hash {
+ if !f.ce.TreeEntry.Mode.IsFile() {
+ return plumbing.ZeroHash
+ }
+
+ return f.ce.TreeEntry.Hash
+}
+
+func (f *changeEntryWrapper) Mode() filemode.FileMode {
+ return f.ce.TreeEntry.Mode
+}
+func (f *changeEntryWrapper) Path() string {
+ if !f.ce.TreeEntry.Mode.IsFile() {
+ return ""
+ }
+
+ return f.ce.Name
+}
+
+func (f *changeEntryWrapper) Empty() bool {
+ return !f.ce.TreeEntry.Mode.IsFile()
+}
+
+// textFilePatch is an implementation of fdiff.FilePatch interface
+type textFilePatch struct {
+ chunks []fdiff.Chunk
+ from, to ChangeEntry
+}
+
+func (tf *textFilePatch) Files() (from fdiff.File, to fdiff.File) {
+ f := &changeEntryWrapper{tf.from}
+ t := &changeEntryWrapper{tf.to}
+
+ if !f.Empty() {
+ from = f
+ }
+
+ if !t.Empty() {
+ to = t
+ }
+
+ return
+}
+
+func (t *textFilePatch) IsBinary() bool {
+ return len(t.chunks) == 0
+}
+
+func (t *textFilePatch) Chunks() []fdiff.Chunk {
+ return t.chunks
+}
+
+// textChunk is an implementation of fdiff.Chunk interface
+type textChunk struct {
+ content string
+ op fdiff.Operation
+}
+
+func (t *textChunk) Content() string {
+ return t.content
+}
+
+func (t *textChunk) Type() fdiff.Operation {
+ return t.op
+}
+
+// FileStat stores the status of changes in content of a file.
+type FileStat struct {
+ Name string
+ Addition int
+ Deletion int
+}
+
+func (fs FileStat) String() string {
+ return printStat([]FileStat{fs})
+}
+
+// FileStats is a collection of FileStat.
+type FileStats []FileStat
+
+func (fileStats FileStats) String() string {
+ return printStat(fileStats)
+}
+
+func printStat(fileStats []FileStat) string {
+ padLength := float64(len(" "))
+ newlineLength := float64(len("\n"))
+ separatorLength := float64(len("|"))
+ // Soft line length limit. The text length calculation below excludes
+ // length of the change number. Adding that would take it closer to 80,
+ // but probably not more than 80, until it's a huge number.
+ lineLength := 72.0
+
+ // Get the longest filename and longest total change.
+ var longestLength float64
+ var longestTotalChange float64
+ for _, fs := range fileStats {
+ if int(longestLength) < len(fs.Name) {
+ longestLength = float64(len(fs.Name))
+ }
+ totalChange := fs.Addition + fs.Deletion
+ if int(longestTotalChange) < totalChange {
+ longestTotalChange = float64(totalChange)
+ }
+ }
+
+ // Parts of the output:
+ // <pad><filename><pad>|<pad><changeNumber><pad><+++/---><newline>
+ // example: " main.go | 10 +++++++--- "
+
+ // <pad><filename><pad>
+ leftTextLength := padLength + longestLength + padLength
+
+ // <pad><number><pad><+++++/-----><newline>
+ // Excluding number length here.
+ rightTextLength := padLength + padLength + newlineLength
+
+ totalTextArea := leftTextLength + separatorLength + rightTextLength
+ heightOfHistogram := lineLength - totalTextArea
+
+ // Scale the histogram.
+ var scaleFactor float64
+ if longestTotalChange > heightOfHistogram {
+ // Scale down to heightOfHistogram.
+ scaleFactor = float64(longestTotalChange / heightOfHistogram)
+ } else {
+ scaleFactor = 1.0
+ }
+
+ finalOutput := ""
+ for _, fs := range fileStats {
+ addn := float64(fs.Addition)
+ deln := float64(fs.Deletion)
+ adds := strings.Repeat("+", int(math.Floor(addn/scaleFactor)))
+ dels := strings.Repeat("-", int(math.Floor(deln/scaleFactor)))
+ finalOutput += fmt.Sprintf(" %s | %d %s%s\n", fs.Name, (fs.Addition + fs.Deletion), adds, dels)
+ }
+
+ return finalOutput
+}
+
+func getFileStatsFromFilePatches(filePatches []fdiff.FilePatch) FileStats {
+ var fileStats FileStats
+
+ for _, fp := range filePatches {
+ // ignore empty patches (binary files, submodule refs updates)
+ if len(fp.Chunks()) == 0 {
+ continue
+ }
+
+ cs := FileStat{}
+ from, to := fp.Files()
+ if from == nil {
+ // New File is created.
+ cs.Name = to.Path()
+ } else if to == nil {
+ // File is deleted.
+ cs.Name = from.Path()
+ } else if from.Path() != to.Path() {
+ // File is renamed. Not supported.
+ // cs.Name = fmt.Sprintf("%s => %s", from.Path(), to.Path())
+ } else {
+ cs.Name = from.Path()
+ }
+
+ for _, chunk := range fp.Chunks() {
+ switch chunk.Type() {
+ case fdiff.Add:
+ cs.Addition += strings.Count(chunk.Content(), "\n")
+ case fdiff.Delete:
+ cs.Deletion += strings.Count(chunk.Content(), "\n")
+ }
+ }
+
+ fileStats = append(fileStats, cs)
+ }
+
+ return fileStats
+}
--- /dev/null
+package object
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ stdioutil "io/ioutil"
+ "strings"
+
+ "golang.org/x/crypto/openpgp"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// Tag represents an annotated tag object. It points to a single git object of
+// any type, but tags typically are applied to commit or blob objects. It
+// provides a reference that associates the target with a tag name. It also
+// contains meta-information about the tag, including the tagger, tag date and
+// message.
+//
+// Note that this is not used for lightweight tags.
+//
+// https://git-scm.com/book/en/v2/Git-Internals-Git-References#Tags
+type Tag struct {
+ // Hash of the tag.
+ Hash plumbing.Hash
+ // Name of the tag.
+ Name string
+ // Tagger is the one who created the tag.
+ Tagger Signature
+ // Message is an arbitrary text message.
+ Message string
+ // PGPSignature is the PGP signature of the tag.
+ PGPSignature string
+ // TargetType is the object type of the target.
+ TargetType plumbing.ObjectType
+ // Target is the hash of the target object.
+ Target plumbing.Hash
+
+ s storer.EncodedObjectStorer
+}
+
+// GetTag gets a tag from an object storer and decodes it.
+func GetTag(s storer.EncodedObjectStorer, h plumbing.Hash) (*Tag, error) {
+ o, err := s.EncodedObject(plumbing.TagObject, h)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeTag(s, o)
+}
+
+// DecodeTag decodes an encoded object into a *Commit and associates it to the
+// given object storer.
+func DecodeTag(s storer.EncodedObjectStorer, o plumbing.EncodedObject) (*Tag, error) {
+ t := &Tag{s: s}
+ if err := t.Decode(o); err != nil {
+ return nil, err
+ }
+
+ return t, nil
+}
+
+// ID returns the object ID of the tag, not the object that the tag references.
+// The returned value will always match the current value of Tag.Hash.
+//
+// ID is present to fulfill the Object interface.
+func (t *Tag) ID() plumbing.Hash {
+ return t.Hash
+}
+
+// Type returns the type of object. It always returns plumbing.TagObject.
+//
+// Type is present to fulfill the Object interface.
+func (t *Tag) Type() plumbing.ObjectType {
+ return plumbing.TagObject
+}
+
+// Decode transforms a plumbing.EncodedObject into a Tag struct.
+func (t *Tag) Decode(o plumbing.EncodedObject) (err error) {
+ if o.Type() != plumbing.TagObject {
+ return ErrUnsupportedObject
+ }
+
+ t.Hash = o.Hash()
+
+ reader, err := o.Reader()
+ if err != nil {
+ return err
+ }
+ defer ioutil.CheckClose(reader, &err)
+
+ r := bufio.NewReader(reader)
+ for {
+ var line []byte
+ line, err = r.ReadBytes('\n')
+ if err != nil && err != io.EOF {
+ return err
+ }
+
+ line = bytes.TrimSpace(line)
+ if len(line) == 0 {
+ break // Start of message
+ }
+
+ split := bytes.SplitN(line, []byte{' '}, 2)
+ switch string(split[0]) {
+ case "object":
+ t.Target = plumbing.NewHash(string(split[1]))
+ case "type":
+ t.TargetType, err = plumbing.ParseObjectType(string(split[1]))
+ if err != nil {
+ return err
+ }
+ case "tag":
+ t.Name = string(split[1])
+ case "tagger":
+ t.Tagger.Decode(split[1])
+ }
+
+ if err == io.EOF {
+ return nil
+ }
+ }
+
+ data, err := stdioutil.ReadAll(r)
+ if err != nil {
+ return err
+ }
+
+ var pgpsig bool
+ // Check if data contains PGP signature.
+ if bytes.Contains(data, []byte(beginpgp)) {
+ // Split the lines at newline.
+ messageAndSig := bytes.Split(data, []byte("\n"))
+
+ for _, l := range messageAndSig {
+ if pgpsig {
+ if bytes.Contains(l, []byte(endpgp)) {
+ t.PGPSignature += endpgp + "\n"
+ pgpsig = false
+ } else {
+ t.PGPSignature += string(l) + "\n"
+ }
+ continue
+ }
+
+ // Check if it's the beginning of a PGP signature.
+ if bytes.Contains(l, []byte(beginpgp)) {
+ t.PGPSignature += beginpgp + "\n"
+ pgpsig = true
+ continue
+ }
+
+ t.Message += string(l) + "\n"
+ }
+ } else {
+ t.Message = string(data)
+ }
+
+ return nil
+}
+
+// Encode transforms a Tag into a plumbing.EncodedObject.
+func (t *Tag) Encode(o plumbing.EncodedObject) error {
+ return t.encode(o, true)
+}
+
+func (t *Tag) encode(o plumbing.EncodedObject, includeSig bool) (err error) {
+ o.SetType(plumbing.TagObject)
+ w, err := o.Writer()
+ if err != nil {
+ return err
+ }
+ defer ioutil.CheckClose(w, &err)
+
+ if _, err = fmt.Fprintf(w,
+ "object %s\ntype %s\ntag %s\ntagger ",
+ t.Target.String(), t.TargetType.Bytes(), t.Name); err != nil {
+ return err
+ }
+
+ if err = t.Tagger.Encode(w); err != nil {
+ return err
+ }
+
+ if _, err = fmt.Fprint(w, "\n\n"); err != nil {
+ return err
+ }
+
+ if _, err = fmt.Fprint(w, t.Message); err != nil {
+ return err
+ }
+
+ // Note that this is highly sensitive to what it sent along in the message.
+ // Message *always* needs to end with a newline, or else the message and the
+ // signature will be concatenated into a corrupt object. Since this is a
+ // lower-level method, we assume you know what you are doing and have already
+ // done the needful on the message in the caller.
+ if includeSig {
+ if _, err = fmt.Fprint(w, t.PGPSignature); err != nil {
+ return err
+ }
+ }
+
+ return err
+}
+
+// Commit returns the commit pointed to by the tag. If the tag points to a
+// different type of object ErrUnsupportedObject will be returned.
+func (t *Tag) Commit() (*Commit, error) {
+ if t.TargetType != plumbing.CommitObject {
+ return nil, ErrUnsupportedObject
+ }
+
+ o, err := t.s.EncodedObject(plumbing.CommitObject, t.Target)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeCommit(t.s, o)
+}
+
+// Tree returns the tree pointed to by the tag. If the tag points to a commit
+// object the tree of that commit will be returned. If the tag does not point
+// to a commit or tree object ErrUnsupportedObject will be returned.
+func (t *Tag) Tree() (*Tree, error) {
+ switch t.TargetType {
+ case plumbing.CommitObject:
+ c, err := t.Commit()
+ if err != nil {
+ return nil, err
+ }
+
+ return c.Tree()
+ case plumbing.TreeObject:
+ return GetTree(t.s, t.Target)
+ default:
+ return nil, ErrUnsupportedObject
+ }
+}
+
+// Blob returns the blob pointed to by the tag. If the tag points to a
+// different type of object ErrUnsupportedObject will be returned.
+func (t *Tag) Blob() (*Blob, error) {
+ if t.TargetType != plumbing.BlobObject {
+ return nil, ErrUnsupportedObject
+ }
+
+ return GetBlob(t.s, t.Target)
+}
+
+// Object returns the object pointed to by the tag.
+func (t *Tag) Object() (Object, error) {
+ o, err := t.s.EncodedObject(t.TargetType, t.Target)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeObject(t.s, o)
+}
+
+// String returns the meta information contained in the tag as a formatted
+// string.
+func (t *Tag) String() string {
+ obj, _ := t.Object()
+
+ return fmt.Sprintf(
+ "%s %s\nTagger: %s\nDate: %s\n\n%s\n%s",
+ plumbing.TagObject, t.Name, t.Tagger.String(), t.Tagger.When.Format(DateFormat),
+ t.Message, objectAsString(obj),
+ )
+}
+
+// Verify performs PGP verification of the tag with a provided armored
+// keyring and returns openpgp.Entity associated with verifying key on success.
+func (t *Tag) Verify(armoredKeyRing string) (*openpgp.Entity, error) {
+ keyRingReader := strings.NewReader(armoredKeyRing)
+ keyring, err := openpgp.ReadArmoredKeyRing(keyRingReader)
+ if err != nil {
+ return nil, err
+ }
+
+ // Extract signature.
+ signature := strings.NewReader(t.PGPSignature)
+
+ encoded := &plumbing.MemoryObject{}
+ // Encode tag components, excluding signature and get a reader object.
+ if err := t.encode(encoded, false); err != nil {
+ return nil, err
+ }
+ er, err := encoded.Reader()
+ if err != nil {
+ return nil, err
+ }
+
+ return openpgp.CheckArmoredDetachedSignature(keyring, er, signature)
+}
+
+// TagIter provides an iterator for a set of tags.
+type TagIter struct {
+ storer.EncodedObjectIter
+ s storer.EncodedObjectStorer
+}
+
+// NewTagIter takes a storer.EncodedObjectStorer and a
+// storer.EncodedObjectIter and returns a *TagIter that iterates over all
+// tags contained in the storer.EncodedObjectIter.
+//
+// Any non-tag object returned by the storer.EncodedObjectIter is skipped.
+func NewTagIter(s storer.EncodedObjectStorer, iter storer.EncodedObjectIter) *TagIter {
+ return &TagIter{iter, s}
+}
+
+// Next moves the iterator to the next tag and returns a pointer to it. If
+// there are no more tags, it returns io.EOF.
+func (iter *TagIter) Next() (*Tag, error) {
+ obj, err := iter.EncodedObjectIter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeTag(iter.s, obj)
+}
+
+// ForEach call the cb function for each tag contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *TagIter) ForEach(cb func(*Tag) error) error {
+ return iter.EncodedObjectIter.ForEach(func(obj plumbing.EncodedObject) error {
+ t, err := DecodeTag(iter.s, obj)
+ if err != nil {
+ return err
+ }
+
+ return cb(t)
+ })
+}
+
+func objectAsString(obj Object) string {
+ switch o := obj.(type) {
+ case *Commit:
+ return o.String()
+ default:
+ return ""
+ }
+}
--- /dev/null
+package object
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+const (
+ maxTreeDepth = 1024
+ startingStackSize = 8
+)
+
+// New errors defined by this package.
+var (
+ ErrMaxTreeDepth = errors.New("maximum tree depth exceeded")
+ ErrFileNotFound = errors.New("file not found")
+ ErrDirectoryNotFound = errors.New("directory not found")
+ ErrEntryNotFound = errors.New("entry not found")
+)
+
+// Tree is basically like a directory - it references a bunch of other trees
+// and/or blobs (i.e. files and sub-directories)
+type Tree struct {
+ Entries []TreeEntry
+ Hash plumbing.Hash
+
+ s storer.EncodedObjectStorer
+ m map[string]*TreeEntry
+ t map[string]*Tree // tree path cache
+}
+
+// GetTree gets a tree from an object storer and decodes it.
+func GetTree(s storer.EncodedObjectStorer, h plumbing.Hash) (*Tree, error) {
+ o, err := s.EncodedObject(plumbing.TreeObject, h)
+ if err != nil {
+ return nil, err
+ }
+
+ return DecodeTree(s, o)
+}
+
+// DecodeTree decodes an encoded object into a *Tree and associates it to the
+// given object storer.
+func DecodeTree(s storer.EncodedObjectStorer, o plumbing.EncodedObject) (*Tree, error) {
+ t := &Tree{s: s}
+ if err := t.Decode(o); err != nil {
+ return nil, err
+ }
+
+ return t, nil
+}
+
+// TreeEntry represents a file
+type TreeEntry struct {
+ Name string
+ Mode filemode.FileMode
+ Hash plumbing.Hash
+}
+
+// File returns the hash of the file identified by the `path` argument.
+// The path is interpreted as relative to the tree receiver.
+func (t *Tree) File(path string) (*File, error) {
+ e, err := t.FindEntry(path)
+ if err != nil {
+ return nil, ErrFileNotFound
+ }
+
+ blob, err := GetBlob(t.s, e.Hash)
+ if err != nil {
+ if err == plumbing.ErrObjectNotFound {
+ return nil, ErrFileNotFound
+ }
+ return nil, err
+ }
+
+ return NewFile(path, e.Mode, blob), nil
+}
+
+// Size returns the plaintext size of an object, without reading it
+// into memory.
+func (t *Tree) Size(path string) (int64, error) {
+ e, err := t.FindEntry(path)
+ if err != nil {
+ return 0, ErrEntryNotFound
+ }
+
+ return t.s.EncodedObjectSize(e.Hash)
+}
+
+// Tree returns the tree identified by the `path` argument.
+// The path is interpreted as relative to the tree receiver.
+func (t *Tree) Tree(path string) (*Tree, error) {
+ e, err := t.FindEntry(path)
+ if err != nil {
+ return nil, ErrDirectoryNotFound
+ }
+
+ tree, err := GetTree(t.s, e.Hash)
+ if err == plumbing.ErrObjectNotFound {
+ return nil, ErrDirectoryNotFound
+ }
+
+ return tree, err
+}
+
+// TreeEntryFile returns the *File for a given *TreeEntry.
+func (t *Tree) TreeEntryFile(e *TreeEntry) (*File, error) {
+ blob, err := GetBlob(t.s, e.Hash)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewFile(e.Name, e.Mode, blob), nil
+}
+
+// FindEntry search a TreeEntry in this tree or any subtree.
+func (t *Tree) FindEntry(path string) (*TreeEntry, error) {
+ if t.t == nil {
+ t.t = make(map[string]*Tree)
+ }
+
+ pathParts := strings.Split(path, "/")
+ startingTree := t
+ pathCurrent := ""
+
+ // search for the longest path in the tree path cache
+ for i := len(pathParts); i > 1; i-- {
+ path := filepath.Join(pathParts[:i]...)
+
+ tree, ok := t.t[path]
+ if ok {
+ startingTree = tree
+ pathParts = pathParts[i:]
+ pathCurrent = path
+
+ break
+ }
+ }
+
+ var tree *Tree
+ var err error
+ for tree = startingTree; len(pathParts) > 1; pathParts = pathParts[1:] {
+ if tree, err = tree.dir(pathParts[0]); err != nil {
+ return nil, err
+ }
+
+ pathCurrent = filepath.Join(pathCurrent, pathParts[0])
+ t.t[pathCurrent] = tree
+ }
+
+ return tree.entry(pathParts[0])
+}
+
+func (t *Tree) dir(baseName string) (*Tree, error) {
+ entry, err := t.entry(baseName)
+ if err != nil {
+ return nil, ErrDirectoryNotFound
+ }
+
+ obj, err := t.s.EncodedObject(plumbing.TreeObject, entry.Hash)
+ if err != nil {
+ return nil, err
+ }
+
+ tree := &Tree{s: t.s}
+ err = tree.Decode(obj)
+
+ return tree, err
+}
+
+func (t *Tree) entry(baseName string) (*TreeEntry, error) {
+ if t.m == nil {
+ t.buildMap()
+ }
+
+ entry, ok := t.m[baseName]
+ if !ok {
+ return nil, ErrEntryNotFound
+ }
+
+ return entry, nil
+}
+
+// Files returns a FileIter allowing to iterate over the Tree
+func (t *Tree) Files() *FileIter {
+ return NewFileIter(t.s, t)
+}
+
+// ID returns the object ID of the tree. The returned value will always match
+// the current value of Tree.Hash.
+//
+// ID is present to fulfill the Object interface.
+func (t *Tree) ID() plumbing.Hash {
+ return t.Hash
+}
+
+// Type returns the type of object. It always returns plumbing.TreeObject.
+func (t *Tree) Type() plumbing.ObjectType {
+ return plumbing.TreeObject
+}
+
+// Decode transform an plumbing.EncodedObject into a Tree struct
+func (t *Tree) Decode(o plumbing.EncodedObject) (err error) {
+ if o.Type() != plumbing.TreeObject {
+ return ErrUnsupportedObject
+ }
+
+ t.Hash = o.Hash()
+ if o.Size() == 0 {
+ return nil
+ }
+
+ t.Entries = nil
+ t.m = nil
+
+ reader, err := o.Reader()
+ if err != nil {
+ return err
+ }
+ defer ioutil.CheckClose(reader, &err)
+
+ r := bufio.NewReader(reader)
+ for {
+ str, err := r.ReadString(' ')
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+
+ return err
+ }
+ str = str[:len(str)-1] // strip last byte (' ')
+
+ mode, err := filemode.New(str)
+ if err != nil {
+ return err
+ }
+
+ name, err := r.ReadString(0)
+ if err != nil && err != io.EOF {
+ return err
+ }
+
+ var hash plumbing.Hash
+ if _, err = io.ReadFull(r, hash[:]); err != nil {
+ return err
+ }
+
+ baseName := name[:len(name)-1]
+ t.Entries = append(t.Entries, TreeEntry{
+ Hash: hash,
+ Mode: mode,
+ Name: baseName,
+ })
+ }
+
+ return nil
+}
+
+// Encode transforms a Tree into a plumbing.EncodedObject.
+func (t *Tree) Encode(o plumbing.EncodedObject) (err error) {
+ o.SetType(plumbing.TreeObject)
+ w, err := o.Writer()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(w, &err)
+ for _, entry := range t.Entries {
+ if _, err = fmt.Fprintf(w, "%o %s", entry.Mode, entry.Name); err != nil {
+ return err
+ }
+
+ if _, err = w.Write([]byte{0x00}); err != nil {
+ return err
+ }
+
+ if _, err = w.Write([]byte(entry.Hash[:])); err != nil {
+ return err
+ }
+ }
+
+ return err
+}
+
+func (t *Tree) buildMap() {
+ t.m = make(map[string]*TreeEntry)
+ for i := 0; i < len(t.Entries); i++ {
+ t.m[t.Entries[i].Name] = &t.Entries[i]
+ }
+}
+
+// Diff returns a list of changes between this tree and the provided one
+func (from *Tree) Diff(to *Tree) (Changes, error) {
+ return DiffTree(from, to)
+}
+
+// Diff returns a list of changes between this tree and the provided one
+// Error will be returned if context expires
+// Provided context must be non nil
+func (from *Tree) DiffContext(ctx context.Context, to *Tree) (Changes, error) {
+ return DiffTreeContext(ctx, from, to)
+}
+
+// Patch returns a slice of Patch objects with all the changes between trees
+// in chunks. This representation can be used to create several diff outputs.
+func (from *Tree) Patch(to *Tree) (*Patch, error) {
+ return from.PatchContext(context.Background(), to)
+}
+
+// Patch returns a slice of Patch objects with all the changes between trees
+// in chunks. This representation can be used to create several diff outputs.
+// If context expires, an error will be returned
+// Provided context must be non-nil
+func (from *Tree) PatchContext(ctx context.Context, to *Tree) (*Patch, error) {
+ changes, err := DiffTreeContext(ctx, from, to)
+ if err != nil {
+ return nil, err
+ }
+
+ return changes.PatchContext(ctx)
+}
+
+// treeEntryIter facilitates iterating through the TreeEntry objects in a Tree.
+type treeEntryIter struct {
+ t *Tree
+ pos int
+}
+
+func (iter *treeEntryIter) Next() (TreeEntry, error) {
+ if iter.pos >= len(iter.t.Entries) {
+ return TreeEntry{}, io.EOF
+ }
+ iter.pos++
+ return iter.t.Entries[iter.pos-1], nil
+}
+
+// TreeWalker provides a means of walking through all of the entries in a Tree.
+type TreeWalker struct {
+ stack []*treeEntryIter
+ base string
+ recursive bool
+ seen map[plumbing.Hash]bool
+
+ s storer.EncodedObjectStorer
+ t *Tree
+}
+
+// NewTreeWalker returns a new TreeWalker for the given tree.
+//
+// It is the caller's responsibility to call Close() when finished with the
+// tree walker.
+func NewTreeWalker(t *Tree, recursive bool, seen map[plumbing.Hash]bool) *TreeWalker {
+ stack := make([]*treeEntryIter, 0, startingStackSize)
+ stack = append(stack, &treeEntryIter{t, 0})
+
+ return &TreeWalker{
+ stack: stack,
+ recursive: recursive,
+ seen: seen,
+
+ s: t.s,
+ t: t,
+ }
+}
+
+// Next returns the next object from the tree. Objects are returned in order
+// and subtrees are included. After the last object has been returned further
+// calls to Next() will return io.EOF.
+//
+// In the current implementation any objects which cannot be found in the
+// underlying repository will be skipped automatically. It is possible that this
+// may change in future versions.
+func (w *TreeWalker) Next() (name string, entry TreeEntry, err error) {
+ var obj Object
+ for {
+ current := len(w.stack) - 1
+ if current < 0 {
+ // Nothing left on the stack so we're finished
+ err = io.EOF
+ return
+ }
+
+ if current > maxTreeDepth {
+ // We're probably following bad data or some self-referencing tree
+ err = ErrMaxTreeDepth
+ return
+ }
+
+ entry, err = w.stack[current].Next()
+ if err == io.EOF {
+ // Finished with the current tree, move back up to the parent
+ w.stack = w.stack[:current]
+ w.base, _ = path.Split(w.base)
+ w.base = path.Clean(w.base) // Remove trailing slash
+ continue
+ }
+
+ if err != nil {
+ return
+ }
+
+ if w.seen[entry.Hash] {
+ continue
+ }
+
+ if entry.Mode == filemode.Dir {
+ obj, err = GetTree(w.s, entry.Hash)
+ }
+
+ name = path.Join(w.base, entry.Name)
+
+ if err != nil {
+ err = io.EOF
+ return
+ }
+
+ break
+ }
+
+ if !w.recursive {
+ return
+ }
+
+ if t, ok := obj.(*Tree); ok {
+ w.stack = append(w.stack, &treeEntryIter{t, 0})
+ w.base = path.Join(w.base, entry.Name)
+ }
+
+ return
+}
+
+// Tree returns the tree that the tree walker most recently operated on.
+func (w *TreeWalker) Tree() *Tree {
+ current := len(w.stack) - 1
+ if w.stack[current].pos == 0 {
+ current--
+ }
+
+ if current < 0 {
+ return nil
+ }
+
+ return w.stack[current].t
+}
+
+// Close releases any resources used by the TreeWalker.
+func (w *TreeWalker) Close() {
+ w.stack = nil
+}
+
+// TreeIter provides an iterator for a set of trees.
+type TreeIter struct {
+ storer.EncodedObjectIter
+ s storer.EncodedObjectStorer
+}
+
+// NewTreeIter takes a storer.EncodedObjectStorer and a
+// storer.EncodedObjectIter and returns a *TreeIter that iterates over all
+// tree contained in the storer.EncodedObjectIter.
+//
+// Any non-tree object returned by the storer.EncodedObjectIter is skipped.
+func NewTreeIter(s storer.EncodedObjectStorer, iter storer.EncodedObjectIter) *TreeIter {
+ return &TreeIter{iter, s}
+}
+
+// Next moves the iterator to the next tree and returns a pointer to it. If
+// there are no more trees, it returns io.EOF.
+func (iter *TreeIter) Next() (*Tree, error) {
+ for {
+ obj, err := iter.EncodedObjectIter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ if obj.Type() != plumbing.TreeObject {
+ continue
+ }
+
+ return DecodeTree(iter.s, obj)
+ }
+}
+
+// ForEach call the cb function for each tree contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *TreeIter) ForEach(cb func(*Tree) error) error {
+ return iter.EncodedObjectIter.ForEach(func(obj plumbing.EncodedObject) error {
+ if obj.Type() != plumbing.TreeObject {
+ return nil
+ }
+
+ t, err := DecodeTree(iter.s, obj)
+ if err != nil {
+ return err
+ }
+
+ return cb(t)
+ })
+}
--- /dev/null
+package object
+
+import (
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// A treenoder is a helper type that wraps git trees into merkletrie
+// noders.
+//
+// As a merkletrie noder doesn't understand the concept of modes (e.g.
+// file permissions), the treenoder includes the mode of the git tree in
+// the hash, so changes in the modes will be detected as modifications
+// to the file contents by the merkletrie difftree algorithm. This is
+// consistent with how the "git diff-tree" command works.
+type treeNoder struct {
+ parent *Tree // the root node is its own parent
+ name string // empty string for the root node
+ mode filemode.FileMode
+ hash plumbing.Hash
+ children []noder.Noder // memoized
+}
+
+// NewTreeRootNode returns the root node of a Tree
+func NewTreeRootNode(t *Tree) noder.Noder {
+ if t == nil {
+ return &treeNoder{}
+ }
+
+ return &treeNoder{
+ parent: t,
+ name: "",
+ mode: filemode.Dir,
+ hash: t.Hash,
+ }
+}
+
+func (t *treeNoder) isRoot() bool {
+ return t.name == ""
+}
+
+func (t *treeNoder) String() string {
+ return "treeNoder <" + t.name + ">"
+}
+
+func (t *treeNoder) Hash() []byte {
+ if t.mode == filemode.Deprecated {
+ return append(t.hash[:], filemode.Regular.Bytes()...)
+ }
+ return append(t.hash[:], t.mode.Bytes()...)
+}
+
+func (t *treeNoder) Name() string {
+ return t.name
+}
+
+func (t *treeNoder) IsDir() bool {
+ return t.mode == filemode.Dir
+}
+
+// Children will return the children of a treenoder as treenoders,
+// building them from the children of the wrapped git tree.
+func (t *treeNoder) Children() ([]noder.Noder, error) {
+ if t.mode != filemode.Dir {
+ return noder.NoChildren, nil
+ }
+
+ // children are memoized for efficiency
+ if t.children != nil {
+ return t.children, nil
+ }
+
+ // the parent of the returned children will be ourself as a tree if
+ // we are a not the root treenoder. The root is special as it
+ // is is own parent.
+ parent := t.parent
+ if !t.isRoot() {
+ var err error
+ if parent, err = t.parent.Tree(t.name); err != nil {
+ return nil, err
+ }
+ }
+
+ return transformChildren(parent)
+}
+
+// Returns the children of a tree as treenoders.
+// Efficiency is key here.
+func transformChildren(t *Tree) ([]noder.Noder, error) {
+ var err error
+ var e TreeEntry
+
+ // there will be more tree entries than children in the tree,
+ // due to submodules and empty directories, but I think it is still
+ // worth it to pre-allocate the whole array now, even if sometimes
+ // is bigger than needed.
+ ret := make([]noder.Noder, 0, len(t.Entries))
+
+ walker := NewTreeWalker(t, false, nil) // don't recurse
+ // don't defer walker.Close() for efficiency reasons.
+ for {
+ _, e, err = walker.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ walker.Close()
+ return nil, err
+ }
+
+ ret = append(ret, &treeNoder{
+ parent: t,
+ name: e.Name,
+ mode: e.Mode,
+ hash: e.Hash,
+ })
+ }
+ walker.Close()
+
+ return ret, nil
+}
+
+// len(t.tree.Entries) != the number of elements walked by treewalker
+// for some reason because of empty directories, submodules, etc, so we
+// have to walk here.
+func (t *treeNoder) NumChildren() (int, error) {
+ children, err := t.Children()
+ if err != nil {
+ return 0, err
+ }
+
+ return len(children), nil
+}
--- /dev/null
+package packp
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/storage/memory"
+)
+
+// AdvRefs values represent the information transmitted on an
+// advertised-refs message. Values from this type are not zero-value
+// safe, use the New function instead.
+type AdvRefs struct {
+ // Prefix stores prefix payloads.
+ //
+ // When using this message over (smart) HTTP, you have to add a pktline
+ // before the whole thing with the following payload:
+ //
+ // '# service=$servicename" LF
+ //
+ // Moreover, some (all) git HTTP smart servers will send a flush-pkt
+ // just after the first pkt-line.
+ //
+ // To accommodate both situations, the Prefix field allow you to store
+ // any data you want to send before the actual pktlines. It will also
+ // be filled up with whatever is found on the line.
+ Prefix [][]byte
+ // Head stores the resolved HEAD reference if present.
+ // This can be present with git-upload-pack, not with git-receive-pack.
+ Head *plumbing.Hash
+ // Capabilities are the capabilities.
+ Capabilities *capability.List
+ // References are the hash references.
+ References map[string]plumbing.Hash
+ // Peeled are the peeled hash references.
+ Peeled map[string]plumbing.Hash
+ // Shallows are the shallow object ids.
+ Shallows []plumbing.Hash
+}
+
+// NewAdvRefs returns a pointer to a new AdvRefs value, ready to be used.
+func NewAdvRefs() *AdvRefs {
+ return &AdvRefs{
+ Prefix: [][]byte{},
+ Capabilities: capability.NewList(),
+ References: make(map[string]plumbing.Hash),
+ Peeled: make(map[string]plumbing.Hash),
+ Shallows: []plumbing.Hash{},
+ }
+}
+
+func (a *AdvRefs) AddReference(r *plumbing.Reference) error {
+ switch r.Type() {
+ case plumbing.SymbolicReference:
+ v := fmt.Sprintf("%s:%s", r.Name().String(), r.Target().String())
+ a.Capabilities.Add(capability.SymRef, v)
+ case plumbing.HashReference:
+ a.References[r.Name().String()] = r.Hash()
+ default:
+ return plumbing.ErrInvalidType
+ }
+
+ return nil
+}
+
+func (a *AdvRefs) AllReferences() (memory.ReferenceStorage, error) {
+ s := memory.ReferenceStorage{}
+ if err := a.addRefs(s); err != nil {
+ return s, plumbing.NewUnexpectedError(err)
+ }
+
+ return s, nil
+}
+
+func (a *AdvRefs) addRefs(s storer.ReferenceStorer) error {
+ for name, hash := range a.References {
+ ref := plumbing.NewReferenceFromStrings(name, hash.String())
+ if err := s.SetReference(ref); err != nil {
+ return err
+ }
+ }
+
+ if a.supportSymrefs() {
+ return a.addSymbolicRefs(s)
+ }
+
+ return a.resolveHead(s)
+}
+
+// If the server does not support symrefs capability,
+// we need to guess the reference where HEAD is pointing to.
+//
+// Git versions prior to 1.8.4.3 has an special procedure to get
+// the reference where is pointing to HEAD:
+// - Check if a reference called master exists. If exists and it
+// has the same hash as HEAD hash, we can say that HEAD is pointing to master
+// - If master does not exists or does not have the same hash as HEAD,
+// order references and check in that order if that reference has the same
+// hash than HEAD. If yes, set HEAD pointing to that branch hash
+// - If no reference is found, throw an error
+func (a *AdvRefs) resolveHead(s storer.ReferenceStorer) error {
+ if a.Head == nil {
+ return nil
+ }
+
+ ref, err := s.Reference(plumbing.ReferenceName(plumbing.Master))
+
+ // check first if HEAD is pointing to master
+ if err == nil {
+ ok, err := a.createHeadIfCorrectReference(ref, s)
+ if err != nil {
+ return err
+ }
+
+ if ok {
+ return nil
+ }
+ }
+
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return err
+ }
+
+ // From here we are trying to guess the branch that HEAD is pointing
+ refIter, err := s.IterReferences()
+ if err != nil {
+ return err
+ }
+
+ var refNames []string
+ err = refIter.ForEach(func(r *plumbing.Reference) error {
+ refNames = append(refNames, string(r.Name()))
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ sort.Strings(refNames)
+
+ var headSet bool
+ for _, refName := range refNames {
+ ref, err := s.Reference(plumbing.ReferenceName(refName))
+ if err != nil {
+ return err
+ }
+ ok, err := a.createHeadIfCorrectReference(ref, s)
+ if err != nil {
+ return err
+ }
+ if ok {
+ headSet = true
+ break
+ }
+ }
+
+ if !headSet {
+ return plumbing.ErrReferenceNotFound
+ }
+
+ return nil
+}
+
+func (a *AdvRefs) createHeadIfCorrectReference(
+ reference *plumbing.Reference,
+ s storer.ReferenceStorer) (bool, error) {
+ if reference.Hash() == *a.Head {
+ headRef := plumbing.NewSymbolicReference(plumbing.HEAD, reference.Name())
+ if err := s.SetReference(headRef); err != nil {
+ return false, err
+ }
+
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (a *AdvRefs) addSymbolicRefs(s storer.ReferenceStorer) error {
+ for _, symref := range a.Capabilities.Get(capability.SymRef) {
+ chunks := strings.Split(symref, ":")
+ if len(chunks) != 2 {
+ err := fmt.Errorf("bad number of `:` in symref value (%q)", symref)
+ return plumbing.NewUnexpectedError(err)
+ }
+ name := plumbing.ReferenceName(chunks[0])
+ target := plumbing.ReferenceName(chunks[1])
+ ref := plumbing.NewSymbolicReference(name, target)
+ if err := s.SetReference(ref); err != nil {
+ return nil
+ }
+ }
+
+ return nil
+}
+
+func (a *AdvRefs) supportSymrefs() bool {
+ return a.Capabilities.Supports(capability.SymRef)
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+// Decode reads the next advertised-refs message form its input and
+// stores it in the AdvRefs.
+func (a *AdvRefs) Decode(r io.Reader) error {
+ d := newAdvRefsDecoder(r)
+ return d.Decode(a)
+}
+
+type advRefsDecoder struct {
+ s *pktline.Scanner // a pkt-line scanner from the input stream
+ line []byte // current pkt-line contents, use parser.nextLine() to make it advance
+ nLine int // current pkt-line number for debugging, begins at 1
+ hash plumbing.Hash // last hash read
+ err error // sticky error, use the parser.error() method to fill this out
+ data *AdvRefs // parsed data is stored here
+}
+
+var (
+ // ErrEmptyAdvRefs is returned by Decode if it gets an empty advertised
+ // references message.
+ ErrEmptyAdvRefs = errors.New("empty advertised-ref message")
+ // ErrEmptyInput is returned by Decode if the input is empty.
+ ErrEmptyInput = errors.New("empty input")
+)
+
+func newAdvRefsDecoder(r io.Reader) *advRefsDecoder {
+ return &advRefsDecoder{
+ s: pktline.NewScanner(r),
+ }
+}
+
+func (d *advRefsDecoder) Decode(v *AdvRefs) error {
+ d.data = v
+
+ for state := decodePrefix; state != nil; {
+ state = state(d)
+ }
+
+ return d.err
+}
+
+type decoderStateFn func(*advRefsDecoder) decoderStateFn
+
+// fills out the parser stiky error
+func (d *advRefsDecoder) error(format string, a ...interface{}) {
+ msg := fmt.Sprintf(
+ "pkt-line %d: %s", d.nLine,
+ fmt.Sprintf(format, a...),
+ )
+
+ d.err = NewErrUnexpectedData(msg, d.line)
+}
+
+// Reads a new pkt-line from the scanner, makes its payload available as
+// p.line and increments p.nLine. A successful invocation returns true,
+// otherwise, false is returned and the sticky error is filled out
+// accordingly. Trims eols at the end of the payloads.
+func (d *advRefsDecoder) nextLine() bool {
+ d.nLine++
+
+ if !d.s.Scan() {
+ if d.err = d.s.Err(); d.err != nil {
+ return false
+ }
+
+ if d.nLine == 1 {
+ d.err = ErrEmptyInput
+ return false
+ }
+
+ d.error("EOF")
+ return false
+ }
+
+ d.line = d.s.Bytes()
+ d.line = bytes.TrimSuffix(d.line, eol)
+
+ return true
+}
+
+// The HTTP smart prefix is often followed by a flush-pkt.
+func decodePrefix(d *advRefsDecoder) decoderStateFn {
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ if !isPrefix(d.line) {
+ return decodeFirstHash
+ }
+
+ tmp := make([]byte, len(d.line))
+ copy(tmp, d.line)
+ d.data.Prefix = append(d.data.Prefix, tmp)
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ if !isFlush(d.line) {
+ return decodeFirstHash
+ }
+
+ d.data.Prefix = append(d.data.Prefix, pktline.Flush)
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ return decodeFirstHash
+}
+
+func isPrefix(payload []byte) bool {
+ return len(payload) > 0 && payload[0] == '#'
+}
+
+// If the first hash is zero, then a no-refs is coming. Otherwise, a
+// list-of-refs is coming, and the hash will be followed by the first
+// advertised ref.
+func decodeFirstHash(p *advRefsDecoder) decoderStateFn {
+ // If the repository is empty, we receive a flush here (HTTP).
+ if isFlush(p.line) {
+ p.err = ErrEmptyAdvRefs
+ return nil
+ }
+
+ if len(p.line) < hashSize {
+ p.error("cannot read hash, pkt-line too short")
+ return nil
+ }
+
+ if _, err := hex.Decode(p.hash[:], p.line[:hashSize]); err != nil {
+ p.error("invalid hash text: %s", err)
+ return nil
+ }
+
+ p.line = p.line[hashSize:]
+
+ if p.hash.IsZero() {
+ return decodeSkipNoRefs
+ }
+
+ return decodeFirstRef
+}
+
+// Skips SP "capabilities^{}" NUL
+func decodeSkipNoRefs(p *advRefsDecoder) decoderStateFn {
+ if len(p.line) < len(noHeadMark) {
+ p.error("too short zero-id ref")
+ return nil
+ }
+
+ if !bytes.HasPrefix(p.line, noHeadMark) {
+ p.error("malformed zero-id ref")
+ return nil
+ }
+
+ p.line = p.line[len(noHeadMark):]
+
+ return decodeCaps
+}
+
+// decode the refname, expects SP refname NULL
+func decodeFirstRef(l *advRefsDecoder) decoderStateFn {
+ if len(l.line) < 3 {
+ l.error("line too short after hash")
+ return nil
+ }
+
+ if !bytes.HasPrefix(l.line, sp) {
+ l.error("no space after hash")
+ return nil
+ }
+ l.line = l.line[1:]
+
+ chunks := bytes.SplitN(l.line, null, 2)
+ if len(chunks) < 2 {
+ l.error("NULL not found")
+ return nil
+ }
+ ref := chunks[0]
+ l.line = chunks[1]
+
+ if bytes.Equal(ref, []byte(head)) {
+ l.data.Head = &l.hash
+ } else {
+ l.data.References[string(ref)] = l.hash
+ }
+
+ return decodeCaps
+}
+
+func decodeCaps(p *advRefsDecoder) decoderStateFn {
+ if err := p.data.Capabilities.Decode(p.line); err != nil {
+ p.error("invalid capabilities: %s", err)
+ return nil
+ }
+
+ return decodeOtherRefs
+}
+
+// The refs are either tips (obj-id SP refname) or a peeled (obj-id SP refname^{}).
+// If there are no refs, then there might be a shallow or flush-ptk.
+func decodeOtherRefs(p *advRefsDecoder) decoderStateFn {
+ if ok := p.nextLine(); !ok {
+ return nil
+ }
+
+ if bytes.HasPrefix(p.line, shallow) {
+ return decodeShallow
+ }
+
+ if len(p.line) == 0 {
+ return nil
+ }
+
+ saveTo := p.data.References
+ if bytes.HasSuffix(p.line, peeled) {
+ p.line = bytes.TrimSuffix(p.line, peeled)
+ saveTo = p.data.Peeled
+ }
+
+ ref, hash, err := readRef(p.line)
+ if err != nil {
+ p.error("%s", err)
+ return nil
+ }
+ saveTo[ref] = hash
+
+ return decodeOtherRefs
+}
+
+// Reads a ref-name
+func readRef(data []byte) (string, plumbing.Hash, error) {
+ chunks := bytes.Split(data, sp)
+ switch {
+ case len(chunks) == 1:
+ return "", plumbing.ZeroHash, fmt.Errorf("malformed ref data: no space was found")
+ case len(chunks) > 2:
+ return "", plumbing.ZeroHash, fmt.Errorf("malformed ref data: more than one space found")
+ default:
+ return string(chunks[1]), plumbing.NewHash(string(chunks[0])), nil
+ }
+}
+
+// Keeps reading shallows until a flush-pkt is found
+func decodeShallow(p *advRefsDecoder) decoderStateFn {
+ if !bytes.HasPrefix(p.line, shallow) {
+ p.error("malformed shallow prefix, found %q... instead", p.line[:len(shallow)])
+ return nil
+ }
+ p.line = bytes.TrimPrefix(p.line, shallow)
+
+ if len(p.line) != hashSize {
+ p.error(fmt.Sprintf(
+ "malformed shallow hash: wrong length, expected 40 bytes, read %d bytes",
+ len(p.line)))
+ return nil
+ }
+
+ text := p.line[:hashSize]
+ var h plumbing.Hash
+ if _, err := hex.Decode(h[:], text); err != nil {
+ p.error("invalid hash text: %s", err)
+ return nil
+ }
+
+ p.data.Shallows = append(p.data.Shallows, h)
+
+ if ok := p.nextLine(); !ok {
+ return nil
+ }
+
+ if len(p.line) == 0 {
+ return nil // succesfull parse of the advertised-refs message
+ }
+
+ return decodeShallow
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sort"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+)
+
+// Encode writes the AdvRefs encoding to a writer.
+//
+// All the payloads will end with a newline character. Capabilities,
+// references and shallows are written in alphabetical order, except for
+// peeled references that always follow their corresponding references.
+func (a *AdvRefs) Encode(w io.Writer) error {
+ e := newAdvRefsEncoder(w)
+ return e.Encode(a)
+}
+
+type advRefsEncoder struct {
+ data *AdvRefs // data to encode
+ pe *pktline.Encoder // where to write the encoded data
+ firstRefName string // reference name to encode in the first pkt-line (HEAD if present)
+ firstRefHash plumbing.Hash // hash referenced to encode in the first pkt-line (HEAD if present)
+ sortedRefs []string // hash references to encode ordered by increasing order
+ err error // sticky error
+
+}
+
+func newAdvRefsEncoder(w io.Writer) *advRefsEncoder {
+ return &advRefsEncoder{
+ pe: pktline.NewEncoder(w),
+ }
+}
+
+func (e *advRefsEncoder) Encode(v *AdvRefs) error {
+ e.data = v
+ e.sortRefs()
+ e.setFirstRef()
+
+ for state := encodePrefix; state != nil; {
+ state = state(e)
+ }
+
+ return e.err
+}
+
+func (e *advRefsEncoder) sortRefs() {
+ if len(e.data.References) > 0 {
+ refs := make([]string, 0, len(e.data.References))
+ for refName := range e.data.References {
+ refs = append(refs, refName)
+ }
+
+ sort.Strings(refs)
+ e.sortedRefs = refs
+ }
+}
+
+func (e *advRefsEncoder) setFirstRef() {
+ if e.data.Head != nil {
+ e.firstRefName = head
+ e.firstRefHash = *e.data.Head
+ return
+ }
+
+ if len(e.sortedRefs) > 0 {
+ refName := e.sortedRefs[0]
+ e.firstRefName = refName
+ e.firstRefHash = e.data.References[refName]
+ }
+}
+
+type encoderStateFn func(*advRefsEncoder) encoderStateFn
+
+func encodePrefix(e *advRefsEncoder) encoderStateFn {
+ for _, p := range e.data.Prefix {
+ if bytes.Equal(p, pktline.Flush) {
+ if e.err = e.pe.Flush(); e.err != nil {
+ return nil
+ }
+ continue
+ }
+ if e.err = e.pe.Encodef("%s\n", string(p)); e.err != nil {
+ return nil
+ }
+ }
+
+ return encodeFirstLine
+}
+
+// Adds the first pkt-line payload: head hash, head ref and capabilities.
+// If HEAD ref is not found, the first reference ordered in increasing order will be used.
+// If there aren't HEAD neither refs, the first line will be "PKT-LINE(zero-id SP "capabilities^{}" NUL capability-list)".
+// See: https://github.com/git/git/blob/master/Documentation/technical/pack-protocol.txt
+// See: https://github.com/git/git/blob/master/Documentation/technical/protocol-common.txt
+func encodeFirstLine(e *advRefsEncoder) encoderStateFn {
+ const formatFirstLine = "%s %s\x00%s\n"
+ var firstLine string
+ capabilities := formatCaps(e.data.Capabilities)
+
+ if e.firstRefName == "" {
+ firstLine = fmt.Sprintf(formatFirstLine, plumbing.ZeroHash.String(), "capabilities^{}", capabilities)
+ } else {
+ firstLine = fmt.Sprintf(formatFirstLine, e.firstRefHash.String(), e.firstRefName, capabilities)
+
+ }
+
+ if e.err = e.pe.EncodeString(firstLine); e.err != nil {
+ return nil
+ }
+
+ return encodeRefs
+}
+
+func formatCaps(c *capability.List) string {
+ if c == nil {
+ return ""
+ }
+
+ return c.String()
+}
+
+// Adds the (sorted) refs: hash SP refname EOL
+// and their peeled refs if any.
+func encodeRefs(e *advRefsEncoder) encoderStateFn {
+ for _, r := range e.sortedRefs {
+ if r == e.firstRefName {
+ continue
+ }
+
+ hash := e.data.References[r]
+ if e.err = e.pe.Encodef("%s %s\n", hash.String(), r); e.err != nil {
+ return nil
+ }
+
+ if hash, ok := e.data.Peeled[r]; ok {
+ if e.err = e.pe.Encodef("%s %s^{}\n", hash.String(), r); e.err != nil {
+ return nil
+ }
+ }
+ }
+
+ return encodeShallow
+}
+
+// Adds the (sorted) shallows: "shallow" SP hash EOL
+func encodeShallow(e *advRefsEncoder) encoderStateFn {
+ sorted := sortShallows(e.data.Shallows)
+ for _, hash := range sorted {
+ if e.err = e.pe.Encodef("shallow %s\n", hash); e.err != nil {
+ return nil
+ }
+ }
+
+ return encodeFlush
+}
+
+func sortShallows(c []plumbing.Hash) []string {
+ ret := []string{}
+ for _, h := range c {
+ ret = append(ret, h.String())
+ }
+ sort.Strings(ret)
+
+ return ret
+}
+
+func encodeFlush(e *advRefsEncoder) encoderStateFn {
+ e.err = e.pe.Flush()
+ return nil
+}
--- /dev/null
+// Package capability defines the server and client capabilities.
+package capability
+
+// Capability describes a server or client capability.
+type Capability string
+
+func (n Capability) String() string {
+ return string(n)
+}
+
+const (
+ // MultiACK capability allows the server to return "ACK obj-id continue" as
+ // soon as it finds a commit that it can use as a common base, between the
+ // client's wants and the client's have set.
+ //
+ // By sending this early, the server can potentially head off the client
+ // from walking any further down that particular branch of the client's
+ // repository history. The client may still need to walk down other
+ // branches, sending have lines for those, until the server has a
+ // complete cut across the DAG, or the client has said "done".
+ //
+ // Without multi_ack, a client sends have lines in --date-order until
+ // the server has found a common base. That means the client will send
+ // have lines that are already known by the server to be common, because
+ // they overlap in time with another branch that the server hasn't found
+ // a common base on yet.
+ //
+ // For example suppose the client has commits in caps that the server
+ // doesn't and the server has commits in lower case that the client
+ // doesn't, as in the following diagram:
+ //
+ // +---- u ---------------------- x
+ // / +----- y
+ // / /
+ // a -- b -- c -- d -- E -- F
+ // \
+ // +--- Q -- R -- S
+ //
+ // If the client wants x,y and starts out by saying have F,S, the server
+ // doesn't know what F,S is. Eventually the client says "have d" and
+ // the server sends "ACK d continue" to let the client know to stop
+ // walking down that line (so don't send c-b-a), but it's not done yet,
+ // it needs a base for x. The client keeps going with S-R-Q, until a
+ // gets reached, at which point the server has a clear base and it all
+ // ends.
+ //
+ // Without multi_ack the client would have sent that c-b-a chain anyway,
+ // interleaved with S-R-Q.
+ MultiACK Capability = "multi_ack"
+ // MultiACKDetailed is an extension of multi_ack that permits client to
+ // better understand the server's in-memory state.
+ MultiACKDetailed Capability = "multi_ack_detailed"
+ // NoDone should only be used with the smart HTTP protocol. If
+ // multi_ack_detailed and no-done are both present, then the sender is
+ // free to immediately send a pack following its first "ACK obj-id ready"
+ // message.
+ //
+ // Without no-done in the smart HTTP protocol, the server session would
+ // end and the client has to make another trip to send "done" before
+ // the server can send the pack. no-done removes the last round and
+ // thus slightly reduces latency.
+ NoDone Capability = "no-done"
+ // ThinPack is one with deltas which reference base objects not
+ // contained within the pack (but are known to exist at the receiving
+ // end). This can reduce the network traffic significantly, but it
+ // requires the receiving end to know how to "thicken" these packs by
+ // adding the missing bases to the pack.
+ //
+ // The upload-pack server advertises 'thin-pack' when it can generate
+ // and send a thin pack. A client requests the 'thin-pack' capability
+ // when it understands how to "thicken" it, notifying the server that
+ // it can receive such a pack. A client MUST NOT request the
+ // 'thin-pack' capability if it cannot turn a thin pack into a
+ // self-contained pack.
+ //
+ // Receive-pack, on the other hand, is assumed by default to be able to
+ // handle thin packs, but can ask the client not to use the feature by
+ // advertising the 'no-thin' capability. A client MUST NOT send a thin
+ // pack if the server advertises the 'no-thin' capability.
+ //
+ // The reasons for this asymmetry are historical. The receive-pack
+ // program did not exist until after the invention of thin packs, so
+ // historically the reference implementation of receive-pack always
+ // understood thin packs. Adding 'no-thin' later allowed receive-pack
+ // to disable the feature in a backwards-compatible manner.
+ ThinPack Capability = "thin-pack"
+ // Sideband means that server can send, and client understand multiplexed
+ // progress reports and error info interleaved with the packfile itself.
+ //
+ // These two options are mutually exclusive. A modern client always
+ // favors Sideband64k.
+ //
+ // Either mode indicates that the packfile data will be streamed broken
+ // up into packets of up to either 1000 bytes in the case of 'side_band',
+ // or 65520 bytes in the case of 'side_band_64k'. Each packet is made up
+ // of a leading 4-byte pkt-line length of how much data is in the packet,
+ // followed by a 1-byte stream code, followed by the actual data.
+ //
+ // The stream code can be one of:
+ //
+ // 1 - pack data
+ // 2 - progress messages
+ // 3 - fatal error message just before stream aborts
+ //
+ // The "side-band-64k" capability came about as a way for newer clients
+ // that can handle much larger packets to request packets that are
+ // actually crammed nearly full, while maintaining backward compatibility
+ // for the older clients.
+ //
+ // Further, with side-band and its up to 1000-byte messages, it's actually
+ // 999 bytes of payload and 1 byte for the stream code. With side-band-64k,
+ // same deal, you have up to 65519 bytes of data and 1 byte for the stream
+ // code.
+ //
+ // The client MUST send only maximum of one of "side-band" and "side-
+ // band-64k". Server MUST diagnose it as an error if client requests
+ // both.
+ Sideband Capability = "side-band"
+ Sideband64k Capability = "side-band-64k"
+ // OFSDelta server can send, and client understand PACKv2 with delta
+ // referring to its base by position in pack rather than by an obj-id. That
+ // is, they can send/read OBJ_OFS_DELTA (aka type 6) in a packfile.
+ OFSDelta Capability = "ofs-delta"
+ // Agent the server may optionally send this capability to notify the client
+ // that the server is running version `X`. The client may optionally return
+ // its own agent string by responding with an `agent=Y` capability (but it
+ // MUST NOT do so if the server did not mention the agent capability). The
+ // `X` and `Y` strings may contain any printable ASCII characters except
+ // space (i.e., the byte range 32 < x < 127), and are typically of the form
+ // "package/version" (e.g., "git/1.8.3.1"). The agent strings are purely
+ // informative for statistics and debugging purposes, and MUST NOT be used
+ // to programmatically assume the presence or absence of particular features.
+ Agent Capability = "agent"
+ // Shallow capability adds "deepen", "shallow" and "unshallow" commands to
+ // the fetch-pack/upload-pack protocol so clients can request shallow
+ // clones.
+ Shallow Capability = "shallow"
+ // DeepenSince adds "deepen-since" command to fetch-pack/upload-pack
+ // protocol so the client can request shallow clones that are cut at a
+ // specific time, instead of depth. Internally it's equivalent of doing
+ // "rev-list --max-age=<timestamp>" on the server side. "deepen-since"
+ // cannot be used with "deepen".
+ DeepenSince Capability = "deepen-since"
+ // DeepenNot adds "deepen-not" command to fetch-pack/upload-pack
+ // protocol so the client can request shallow clones that are cut at a
+ // specific revision, instead of depth. Internally it's equivalent of
+ // doing "rev-list --not <rev>" on the server side. "deepen-not"
+ // cannot be used with "deepen", but can be used with "deepen-since".
+ DeepenNot Capability = "deepen-not"
+ // DeepenRelative if this capability is requested by the client, the
+ // semantics of "deepen" command is changed. The "depth" argument is the
+ // depth from the current shallow boundary, instead of the depth from
+ // remote refs.
+ DeepenRelative Capability = "deepen-relative"
+ // NoProgress the client was started with "git clone -q" or something, and
+ // doesn't want that side band 2. Basically the client just says "I do not
+ // wish to receive stream 2 on sideband, so do not send it to me, and if
+ // you did, I will drop it on the floor anyway". However, the sideband
+ // channel 3 is still used for error responses.
+ NoProgress Capability = "no-progress"
+ // IncludeTag capability is about sending annotated tags if we are
+ // sending objects they point to. If we pack an object to the client, and
+ // a tag object points exactly at that object, we pack the tag object too.
+ // In general this allows a client to get all new annotated tags when it
+ // fetches a branch, in a single network connection.
+ //
+ // Clients MAY always send include-tag, hardcoding it into a request when
+ // the server advertises this capability. The decision for a client to
+ // request include-tag only has to do with the client's desires for tag
+ // data, whether or not a server had advertised objects in the
+ // refs/tags/* namespace.
+ //
+ // Servers MUST pack the tags if their referrant is packed and the client
+ // has requested include-tags.
+ //
+ // Clients MUST be prepared for the case where a server has ignored
+ // include-tag and has not actually sent tags in the pack. In such
+ // cases the client SHOULD issue a subsequent fetch to acquire the tags
+ // that include-tag would have otherwise given the client.
+ //
+ // The server SHOULD send include-tag, if it supports it, regardless
+ // of whether or not there are tags available.
+ IncludeTag Capability = "include-tag"
+ // ReportStatus the receive-pack process can receive a 'report-status'
+ // capability, which tells it that the client wants a report of what
+ // happened after a packfile upload and reference update. If the pushing
+ // client requests this capability, after unpacking and updating references
+ // the server will respond with whether the packfile unpacked successfully
+ // and if each reference was updated successfully. If any of those were not
+ // successful, it will send back an error message. See pack-protocol.txt
+ // for example messages.
+ ReportStatus Capability = "report-status"
+ // DeleteRefs If the server sends back this capability, it means that
+ // it is capable of accepting a zero-id value as the target
+ // value of a reference update. It is not sent back by the client, it
+ // simply informs the client that it can be sent zero-id values
+ // to delete references
+ DeleteRefs Capability = "delete-refs"
+ // Quiet If the receive-pack server advertises this capability, it is
+ // capable of silencing human-readable progress output which otherwise may
+ // be shown when processing the received pack. A send-pack client should
+ // respond with the 'quiet' capability to suppress server-side progress
+ // reporting if the local progress reporting is also being suppressed
+ // (e.g., via `push -q`, or if stderr does not go to a tty).
+ Quiet Capability = "quiet"
+ // Atomic If the server sends this capability it is capable of accepting
+ // atomic pushes. If the pushing client requests this capability, the server
+ // will update the refs in one atomic transaction. Either all refs are
+ // updated or none.
+ Atomic Capability = "atomic"
+ // PushOptions If the server sends this capability it is able to accept
+ // push options after the update commands have been sent, but before the
+ // packfile is streamed. If the pushing client requests this capability,
+ // the server will pass the options to the pre- and post- receive hooks
+ // that process this push request.
+ PushOptions Capability = "push-options"
+ // AllowTipSHA1InWant if the upload-pack server advertises this capability,
+ // fetch-pack may send "want" lines with SHA-1s that exist at the server but
+ // are not advertised by upload-pack.
+ AllowTipSHA1InWant Capability = "allow-tip-sha1-in-want"
+ // AllowReachableSHA1InWant if the upload-pack server advertises this
+ // capability, fetch-pack may send "want" lines with SHA-1s that exist at
+ // the server but are not advertised by upload-pack.
+ AllowReachableSHA1InWant Capability = "allow-reachable-sha1-in-want"
+ // PushCert the receive-pack server that advertises this capability is
+ // willing to accept a signed push certificate, and asks the <nonce> to be
+ // included in the push certificate. A send-pack client MUST NOT
+ // send a push-cert packet unless the receive-pack server advertises
+ // this capability.
+ PushCert Capability = "push-cert"
+ // SymRef symbolic reference support for better negotiation.
+ SymRef Capability = "symref"
+)
+
+const DefaultAgent = "go-git/4.x"
+
+var known = map[Capability]bool{
+ MultiACK: true, MultiACKDetailed: true, NoDone: true, ThinPack: true,
+ Sideband: true, Sideband64k: true, OFSDelta: true, Agent: true,
+ Shallow: true, DeepenSince: true, DeepenNot: true, DeepenRelative: true,
+ NoProgress: true, IncludeTag: true, ReportStatus: true, DeleteRefs: true,
+ Quiet: true, Atomic: true, PushOptions: true, AllowTipSHA1InWant: true,
+ AllowReachableSHA1InWant: true, PushCert: true, SymRef: true,
+}
+
+var requiresArgument = map[Capability]bool{
+ Agent: true, PushCert: true, SymRef: true,
+}
+
+var multipleArgument = map[Capability]bool{
+ SymRef: true,
+}
--- /dev/null
+package capability
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "strings"
+)
+
+var (
+ // ErrArgumentsRequired is returned if no arguments are giving with a
+ // capability that requires arguments
+ ErrArgumentsRequired = errors.New("arguments required")
+ // ErrArguments is returned if arguments are given with a capabilities that
+ // not supports arguments
+ ErrArguments = errors.New("arguments not allowed")
+ // ErrEmtpyArgument is returned when an empty value is given
+ ErrEmtpyArgument = errors.New("empty argument")
+ // ErrMultipleArguments multiple argument given to a capabilities that not
+ // support it
+ ErrMultipleArguments = errors.New("multiple arguments not allowed")
+)
+
+// List represents a list of capabilities
+type List struct {
+ m map[Capability]*entry
+ sort []string
+}
+
+type entry struct {
+ Name Capability
+ Values []string
+}
+
+// NewList returns a new List of capabilities
+func NewList() *List {
+ return &List{
+ m: make(map[Capability]*entry),
+ }
+}
+
+// IsEmpty returns true if the List is empty
+func (l *List) IsEmpty() bool {
+ return len(l.sort) == 0
+}
+
+// Decode decodes list of capabilities from raw into the list
+func (l *List) Decode(raw []byte) error {
+ // git 1.x receive pack used to send a leading space on its
+ // git-receive-pack capabilities announcement. We just trim space to be
+ // tolerant to space changes in different versions.
+ raw = bytes.TrimSpace(raw)
+
+ if len(raw) == 0 {
+ return nil
+ }
+
+ for _, data := range bytes.Split(raw, []byte{' '}) {
+ pair := bytes.SplitN(data, []byte{'='}, 2)
+
+ c := Capability(pair[0])
+ if len(pair) == 1 {
+ if err := l.Add(c); err != nil {
+ return err
+ }
+
+ continue
+ }
+
+ if err := l.Add(c, string(pair[1])); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Get returns the values for a capability
+func (l *List) Get(capability Capability) []string {
+ if _, ok := l.m[capability]; !ok {
+ return nil
+ }
+
+ return l.m[capability].Values
+}
+
+// Set sets a capability removing the previous values
+func (l *List) Set(capability Capability, values ...string) error {
+ if _, ok := l.m[capability]; ok {
+ delete(l.m, capability)
+ }
+
+ return l.Add(capability, values...)
+}
+
+// Add adds a capability, values are optional
+func (l *List) Add(c Capability, values ...string) error {
+ if err := l.validate(c, values); err != nil {
+ return err
+ }
+
+ if !l.Supports(c) {
+ l.m[c] = &entry{Name: c}
+ l.sort = append(l.sort, c.String())
+ }
+
+ if len(values) == 0 {
+ return nil
+ }
+
+ if known[c] && !multipleArgument[c] && len(l.m[c].Values) > 0 {
+ return ErrMultipleArguments
+ }
+
+ l.m[c].Values = append(l.m[c].Values, values...)
+ return nil
+}
+
+func (l *List) validateNoEmptyArgs(values []string) error {
+ for _, v := range values {
+ if v == "" {
+ return ErrEmtpyArgument
+ }
+ }
+ return nil
+}
+
+func (l *List) validate(c Capability, values []string) error {
+ if !known[c] {
+ return l.validateNoEmptyArgs(values)
+ }
+ if requiresArgument[c] && len(values) == 0 {
+ return ErrArgumentsRequired
+ }
+
+ if !requiresArgument[c] && len(values) != 0 {
+ return ErrArguments
+ }
+
+ if !multipleArgument[c] && len(values) > 1 {
+ return ErrMultipleArguments
+ }
+ return l.validateNoEmptyArgs(values)
+}
+
+// Supports returns true if capability is present
+func (l *List) Supports(capability Capability) bool {
+ _, ok := l.m[capability]
+ return ok
+}
+
+// Delete deletes a capability from the List
+func (l *List) Delete(capability Capability) {
+ if !l.Supports(capability) {
+ return
+ }
+
+ delete(l.m, capability)
+ for i, c := range l.sort {
+ if c != string(capability) {
+ continue
+ }
+
+ l.sort = append(l.sort[:i], l.sort[i+1:]...)
+ return
+ }
+}
+
+// All returns a slice with all defined capabilities.
+func (l *List) All() []Capability {
+ var cs []Capability
+ for _, key := range l.sort {
+ cs = append(cs, Capability(key))
+ }
+
+ return cs
+}
+
+// String generates the capabilities strings, the capabilities are sorted in
+// insertion order
+func (l *List) String() string {
+ var o []string
+ for _, key := range l.sort {
+ cap := l.m[Capability(key)]
+ if len(cap.Values) == 0 {
+ o = append(o, key)
+ continue
+ }
+
+ for _, value := range cap.Values {
+ o = append(o, fmt.Sprintf("%s=%s", key, value))
+ }
+ }
+
+ return strings.Join(o, " ")
+}
--- /dev/null
+package packp
+
+import (
+ "fmt"
+)
+
+type stateFn func() stateFn
+
+const (
+ // common
+ hashSize = 40
+
+ // advrefs
+ head = "HEAD"
+ noHead = "capabilities^{}"
+)
+
+var (
+ // common
+ sp = []byte(" ")
+ eol = []byte("\n")
+ eq = []byte{'='}
+
+ // advertised-refs
+ null = []byte("\x00")
+ peeled = []byte("^{}")
+ noHeadMark = []byte(" capabilities^{}\x00")
+
+ // upload-request
+ want = []byte("want ")
+ shallow = []byte("shallow ")
+ deepen = []byte("deepen")
+ deepenCommits = []byte("deepen ")
+ deepenSince = []byte("deepen-since ")
+ deepenReference = []byte("deepen-not ")
+
+ // shallow-update
+ unshallow = []byte("unshallow ")
+
+ // server-response
+ ack = []byte("ACK")
+ nak = []byte("NAK")
+
+ // updreq
+ shallowNoSp = []byte("shallow")
+)
+
+func isFlush(payload []byte) bool {
+ return len(payload) == 0
+}
+
+// ErrUnexpectedData represents an unexpected data decoding a message
+type ErrUnexpectedData struct {
+ Msg string
+ Data []byte
+}
+
+// NewErrUnexpectedData returns a new ErrUnexpectedData containing the data and
+// the message given
+func NewErrUnexpectedData(msg string, data []byte) error {
+ return &ErrUnexpectedData{Msg: msg, Data: data}
+}
+
+func (err *ErrUnexpectedData) Error() string {
+ if len(err.Data) == 0 {
+ return err.Msg
+ }
+
+ return fmt.Sprintf("%s (%s)", err.Msg, err.Data)
+}
--- /dev/null
+package packp
+
+/*
+
+A nice way to trace the real data transmitted and received by git, use:
+
+GIT_TRACE_PACKET=true git ls-remote http://github.com/src-d/go-git
+GIT_TRACE_PACKET=true git clone http://github.com/src-d/go-git
+
+Here follows a copy of the current protocol specification at the time of
+this writing.
+
+(Please notice that most http git servers will add a flush-pkt after the
+first pkt-line when using HTTP smart.)
+
+
+Documentation Common to Pack and Http Protocols
+===============================================
+
+ABNF Notation
+-------------
+
+ABNF notation as described by RFC 5234 is used within the protocol documents,
+except the following replacement core rules are used:
+----
+ HEXDIG = DIGIT / "a" / "b" / "c" / "d" / "e" / "f"
+----
+
+We also define the following common rules:
+----
+ NUL = %x00
+ zero-id = 40*"0"
+ obj-id = 40*(HEXDIGIT)
+
+ refname = "HEAD"
+ refname /= "refs/" <see discussion below>
+----
+
+A refname is a hierarchical octet string beginning with "refs/" and
+not violating the 'git-check-ref-format' command's validation rules.
+More specifically, they:
+
+. They can include slash `/` for hierarchical (directory)
+ grouping, but no slash-separated component can begin with a
+ dot `.`.
+
+. They must contain at least one `/`. This enforces the presence of a
+ category like `heads/`, `tags/` etc. but the actual names are not
+ restricted.
+
+. They cannot have two consecutive dots `..` anywhere.
+
+. They cannot have ASCII control characters (i.e. bytes whose
+ values are lower than \040, or \177 `DEL`), space, tilde `~`,
+ caret `^`, colon `:`, question-mark `?`, asterisk `*`,
+ or open bracket `[` anywhere.
+
+. They cannot end with a slash `/` or a dot `.`.
+
+. They cannot end with the sequence `.lock`.
+
+. They cannot contain a sequence `@{`.
+
+. They cannot contain a `\\`.
+
+
+pkt-line Format
+---------------
+
+Much (but not all) of the payload is described around pkt-lines.
+
+A pkt-line is a variable length binary string. The first four bytes
+of the line, the pkt-len, indicates the total length of the line,
+in hexadecimal. The pkt-len includes the 4 bytes used to contain
+the length's hexadecimal representation.
+
+A pkt-line MAY contain binary data, so implementors MUST ensure
+pkt-line parsing/formatting routines are 8-bit clean.
+
+A non-binary line SHOULD BE terminated by an LF, which if present
+MUST be included in the total length. Receivers MUST treat pkt-lines
+with non-binary data the same whether or not they contain the trailing
+LF (stripping the LF if present, and not complaining when it is
+missing).
+
+The maximum length of a pkt-line's data component is 65516 bytes.
+Implementations MUST NOT send pkt-line whose length exceeds 65520
+(65516 bytes of payload + 4 bytes of length data).
+
+Implementations SHOULD NOT send an empty pkt-line ("0004").
+
+A pkt-line with a length field of 0 ("0000"), called a flush-pkt,
+is a special case and MUST be handled differently than an empty
+pkt-line ("0004").
+
+----
+ pkt-line = data-pkt / flush-pkt
+
+ data-pkt = pkt-len pkt-payload
+ pkt-len = 4*(HEXDIG)
+ pkt-payload = (pkt-len - 4)*(OCTET)
+
+ flush-pkt = "0000"
+----
+
+Examples (as C-style strings):
+
+----
+ pkt-line actual value
+ ---------------------------------
+ "0006a\n" "a\n"
+ "0005a" "a"
+ "000bfoobar\n" "foobar\n"
+ "0004" ""
+----
+
+Packfile transfer protocols
+===========================
+
+Git supports transferring data in packfiles over the ssh://, git://, http:// and
+file:// transports. There exist two sets of protocols, one for pushing
+data from a client to a server and another for fetching data from a
+server to a client. The three transports (ssh, git, file) use the same
+protocol to transfer data. http is documented in http-protocol.txt.
+
+The processes invoked in the canonical Git implementation are 'upload-pack'
+on the server side and 'fetch-pack' on the client side for fetching data;
+then 'receive-pack' on the server and 'send-pack' on the client for pushing
+data. The protocol functions to have a server tell a client what is
+currently on the server, then for the two to negotiate the smallest amount
+of data to send in order to fully update one or the other.
+
+pkt-line Format
+---------------
+
+The descriptions below build on the pkt-line format described in
+protocol-common.txt. When the grammar indicate `PKT-LINE(...)`, unless
+otherwise noted the usual pkt-line LF rules apply: the sender SHOULD
+include a LF, but the receiver MUST NOT complain if it is not present.
+
+Transports
+----------
+There are three transports over which the packfile protocol is
+initiated. The Git transport is a simple, unauthenticated server that
+takes the command (almost always 'upload-pack', though Git
+servers can be configured to be globally writable, in which 'receive-
+pack' initiation is also allowed) with which the client wishes to
+communicate and executes it and connects it to the requesting
+process.
+
+In the SSH transport, the client just runs the 'upload-pack'
+or 'receive-pack' process on the server over the SSH protocol and then
+communicates with that invoked process over the SSH connection.
+
+The file:// transport runs the 'upload-pack' or 'receive-pack'
+process locally and communicates with it over a pipe.
+
+Git Transport
+-------------
+
+The Git transport starts off by sending the command and repository
+on the wire using the pkt-line format, followed by a NUL byte and a
+hostname parameter, terminated by a NUL byte.
+
+ 0032git-upload-pack /project.git\0host=myserver.com\0
+
+--
+ git-proto-request = request-command SP pathname NUL [ host-parameter NUL ]
+ request-command = "git-upload-pack" / "git-receive-pack" /
+ "git-upload-archive" ; case sensitive
+ pathname = *( %x01-ff ) ; exclude NUL
+ host-parameter = "host=" hostname [ ":" port ]
+--
+
+Only host-parameter is allowed in the git-proto-request. Clients
+MUST NOT attempt to send additional parameters. It is used for the
+git-daemon name based virtual hosting. See --interpolated-path
+option to git daemon, with the %H/%CH format characters.
+
+Basically what the Git client is doing to connect to an 'upload-pack'
+process on the server side over the Git protocol is this:
+
+ $ echo -e -n \
+ "0039git-upload-pack /schacon/gitbook.git\0host=example.com\0" |
+ nc -v example.com 9418
+
+If the server refuses the request for some reasons, it could abort
+gracefully with an error message.
+
+----
+ error-line = PKT-LINE("ERR" SP explanation-text)
+----
+
+
+SSH Transport
+-------------
+
+Initiating the upload-pack or receive-pack processes over SSH is
+executing the binary on the server via SSH remote execution.
+It is basically equivalent to running this:
+
+ $ ssh git.example.com "git-upload-pack '/project.git'"
+
+For a server to support Git pushing and pulling for a given user over
+SSH, that user needs to be able to execute one or both of those
+commands via the SSH shell that they are provided on login. On some
+systems, that shell access is limited to only being able to run those
+two commands, or even just one of them.
+
+In an ssh:// format URI, it's absolute in the URI, so the '/' after
+the host name (or port number) is sent as an argument, which is then
+read by the remote git-upload-pack exactly as is, so it's effectively
+an absolute path in the remote filesystem.
+
+ git clone ssh://user@example.com/project.git
+ |
+ v
+ ssh user@example.com "git-upload-pack '/project.git'"
+
+In a "user@host:path" format URI, its relative to the user's home
+directory, because the Git client will run:
+
+ git clone user@example.com:project.git
+ |
+ v
+ ssh user@example.com "git-upload-pack 'project.git'"
+
+The exception is if a '~' is used, in which case
+we execute it without the leading '/'.
+
+ ssh://user@example.com/~alice/project.git,
+ |
+ v
+ ssh user@example.com "git-upload-pack '~alice/project.git'"
+
+A few things to remember here:
+
+- The "command name" is spelled with dash (e.g. git-upload-pack), but
+ this can be overridden by the client;
+
+- The repository path is always quoted with single quotes.
+
+Fetching Data From a Server
+---------------------------
+
+When one Git repository wants to get data that a second repository
+has, the first can 'fetch' from the second. This operation determines
+what data the server has that the client does not then streams that
+data down to the client in packfile format.
+
+
+Reference Discovery
+-------------------
+
+When the client initially connects the server will immediately respond
+with a listing of each reference it has (all branches and tags) along
+with the object name that each reference currently points to.
+
+ $ echo -e -n "0039git-upload-pack /schacon/gitbook.git\0host=example.com\0" |
+ nc -v example.com 9418
+ 00887217a7c7e582c46cec22a130adf4b9d7d950fba0 HEAD\0multi_ack thin-pack
+ side-band side-band-64k ofs-delta shallow no-progress include-tag
+ 00441d3fcd5ced445d1abc402225c0b8a1299641f497 refs/heads/integration
+ 003f7217a7c7e582c46cec22a130adf4b9d7d950fba0 refs/heads/master
+ 003cb88d2441cac0977faf98efc80305012112238d9d refs/tags/v0.9
+ 003c525128480b96c89e6418b1e40909bf6c5b2d580f refs/tags/v1.0
+ 003fe92df48743b7bc7d26bcaabfddde0a1e20cae47c refs/tags/v1.0^{}
+ 0000
+
+The returned response is a pkt-line stream describing each ref and
+its current value. The stream MUST be sorted by name according to
+the C locale ordering.
+
+If HEAD is a valid ref, HEAD MUST appear as the first advertised
+ref. If HEAD is not a valid ref, HEAD MUST NOT appear in the
+advertisement list at all, but other refs may still appear.
+
+The stream MUST include capability declarations behind a NUL on the
+first ref. The peeled value of a ref (that is "ref^{}") MUST be
+immediately after the ref itself, if presented. A conforming server
+MUST peel the ref if it's an annotated tag.
+
+----
+ advertised-refs = (no-refs / list-of-refs)
+ *shallow
+ flush-pkt
+
+ no-refs = PKT-LINE(zero-id SP "capabilities^{}"
+ NUL capability-list)
+
+ list-of-refs = first-ref *other-ref
+ first-ref = PKT-LINE(obj-id SP refname
+ NUL capability-list)
+
+ other-ref = PKT-LINE(other-tip / other-peeled)
+ other-tip = obj-id SP refname
+ other-peeled = obj-id SP refname "^{}"
+
+ shallow = PKT-LINE("shallow" SP obj-id)
+
+ capability-list = capability *(SP capability)
+ capability = 1*(LC_ALPHA / DIGIT / "-" / "_")
+ LC_ALPHA = %x61-7A
+----
+
+Server and client MUST use lowercase for obj-id, both MUST treat obj-id
+as case-insensitive.
+
+See protocol-capabilities.txt for a list of allowed server capabilities
+and descriptions.
+
+Packfile Negotiation
+--------------------
+After reference and capabilities discovery, the client can decide to
+terminate the connection by sending a flush-pkt, telling the server it can
+now gracefully terminate, and disconnect, when it does not need any pack
+data. This can happen with the ls-remote command, and also can happen when
+the client already is up-to-date.
+
+Otherwise, it enters the negotiation phase, where the client and
+server determine what the minimal packfile necessary for transport is,
+by telling the server what objects it wants, its shallow objects
+(if any), and the maximum commit depth it wants (if any). The client
+will also send a list of the capabilities it wants to be in effect,
+out of what the server said it could do with the first 'want' line.
+
+----
+ upload-request = want-list
+ *shallow-line
+ *1depth-request
+ flush-pkt
+
+ want-list = first-want
+ *additional-want
+
+ shallow-line = PKT-LINE("shallow" SP obj-id)
+
+ depth-request = PKT-LINE("deepen" SP depth) /
+ PKT-LINE("deepen-since" SP timestamp) /
+ PKT-LINE("deepen-not" SP ref)
+
+ first-want = PKT-LINE("want" SP obj-id SP capability-list)
+ additional-want = PKT-LINE("want" SP obj-id)
+
+ depth = 1*DIGIT
+----
+
+Clients MUST send all the obj-ids it wants from the reference
+discovery phase as 'want' lines. Clients MUST send at least one
+'want' command in the request body. Clients MUST NOT mention an
+obj-id in a 'want' command which did not appear in the response
+obtained through ref discovery.
+
+The client MUST write all obj-ids which it only has shallow copies
+of (meaning that it does not have the parents of a commit) as
+'shallow' lines so that the server is aware of the limitations of
+the client's history.
+
+The client now sends the maximum commit history depth it wants for
+this transaction, which is the number of commits it wants from the
+tip of the history, if any, as a 'deepen' line. A depth of 0 is the
+same as not making a depth request. The client does not want to receive
+any commits beyond this depth, nor does it want objects needed only to
+complete those commits. Commits whose parents are not received as a
+result are defined as shallow and marked as such in the server. This
+information is sent back to the client in the next step.
+
+Once all the 'want's and 'shallow's (and optional 'deepen') are
+transferred, clients MUST send a flush-pkt, to tell the server side
+that it is done sending the list.
+
+Otherwise, if the client sent a positive depth request, the server
+will determine which commits will and will not be shallow and
+send this information to the client. If the client did not request
+a positive depth, this step is skipped.
+
+----
+ shallow-update = *shallow-line
+ *unshallow-line
+ flush-pkt
+
+ shallow-line = PKT-LINE("shallow" SP obj-id)
+
+ unshallow-line = PKT-LINE("unshallow" SP obj-id)
+----
+
+If the client has requested a positive depth, the server will compute
+the set of commits which are no deeper than the desired depth. The set
+of commits start at the client's wants.
+
+The server writes 'shallow' lines for each
+commit whose parents will not be sent as a result. The server writes
+an 'unshallow' line for each commit which the client has indicated is
+shallow, but is no longer shallow at the currently requested depth
+(that is, its parents will now be sent). The server MUST NOT mark
+as unshallow anything which the client has not indicated was shallow.
+
+Now the client will send a list of the obj-ids it has using 'have'
+lines, so the server can make a packfile that only contains the objects
+that the client needs. In multi_ack mode, the canonical implementation
+will send up to 32 of these at a time, then will send a flush-pkt. The
+canonical implementation will skip ahead and send the next 32 immediately,
+so that there is always a block of 32 "in-flight on the wire" at a time.
+
+----
+ upload-haves = have-list
+ compute-end
+
+ have-list = *have-line
+ have-line = PKT-LINE("have" SP obj-id)
+ compute-end = flush-pkt / PKT-LINE("done")
+----
+
+If the server reads 'have' lines, it then will respond by ACKing any
+of the obj-ids the client said it had that the server also has. The
+server will ACK obj-ids differently depending on which ack mode is
+chosen by the client.
+
+In multi_ack mode:
+
+ * the server will respond with 'ACK obj-id continue' for any common
+ commits.
+
+ * once the server has found an acceptable common base commit and is
+ ready to make a packfile, it will blindly ACK all 'have' obj-ids
+ back to the client.
+
+ * the server will then send a 'NAK' and then wait for another response
+ from the client - either a 'done' or another list of 'have' lines.
+
+In multi_ack_detailed mode:
+
+ * the server will differentiate the ACKs where it is signaling
+ that it is ready to send data with 'ACK obj-id ready' lines, and
+ signals the identified common commits with 'ACK obj-id common' lines.
+
+Without either multi_ack or multi_ack_detailed:
+
+ * upload-pack sends "ACK obj-id" on the first common object it finds.
+ After that it says nothing until the client gives it a "done".
+
+ * upload-pack sends "NAK" on a flush-pkt if no common object
+ has been found yet. If one has been found, and thus an ACK
+ was already sent, it's silent on the flush-pkt.
+
+After the client has gotten enough ACK responses that it can determine
+that the server has enough information to send an efficient packfile
+(in the canonical implementation, this is determined when it has received
+enough ACKs that it can color everything left in the --date-order queue
+as common with the server, or the --date-order queue is empty), or the
+client determines that it wants to give up (in the canonical implementation,
+this is determined when the client sends 256 'have' lines without getting
+any of them ACKed by the server - meaning there is nothing in common and
+the server should just send all of its objects), then the client will send
+a 'done' command. The 'done' command signals to the server that the client
+is ready to receive its packfile data.
+
+However, the 256 limit *only* turns on in the canonical client
+implementation if we have received at least one "ACK %s continue"
+during a prior round. This helps to ensure that at least one common
+ancestor is found before we give up entirely.
+
+Once the 'done' line is read from the client, the server will either
+send a final 'ACK obj-id' or it will send a 'NAK'. 'obj-id' is the object
+name of the last commit determined to be common. The server only sends
+ACK after 'done' if there is at least one common base and multi_ack or
+multi_ack_detailed is enabled. The server always sends NAK after 'done'
+if there is no common base found.
+
+Then the server will start sending its packfile data.
+
+----
+ server-response = *ack_multi ack / nak
+ ack_multi = PKT-LINE("ACK" SP obj-id ack_status)
+ ack_status = "continue" / "common" / "ready"
+ ack = PKT-LINE("ACK" SP obj-id)
+ nak = PKT-LINE("NAK")
+----
+
+A simple clone may look like this (with no 'have' lines):
+
+----
+ C: 0054want 74730d410fcb6603ace96f1dc55ea6196122532d multi_ack \
+ side-band-64k ofs-delta\n
+ C: 0032want 7d1665144a3a975c05f1f43902ddaf084e784dbe\n
+ C: 0032want 5a3f6be755bbb7deae50065988cbfa1ffa9ab68a\n
+ C: 0032want 7e47fe2bd8d01d481f44d7af0531bd93d3b21c01\n
+ C: 0032want 74730d410fcb6603ace96f1dc55ea6196122532d\n
+ C: 0000
+ C: 0009done\n
+
+ S: 0008NAK\n
+ S: [PACKFILE]
+----
+
+An incremental update (fetch) response might look like this:
+
+----
+ C: 0054want 74730d410fcb6603ace96f1dc55ea6196122532d multi_ack \
+ side-band-64k ofs-delta\n
+ C: 0032want 7d1665144a3a975c05f1f43902ddaf084e784dbe\n
+ C: 0032want 5a3f6be755bbb7deae50065988cbfa1ffa9ab68a\n
+ C: 0000
+ C: 0032have 7e47fe2bd8d01d481f44d7af0531bd93d3b21c01\n
+ C: [30 more have lines]
+ C: 0032have 74730d410fcb6603ace96f1dc55ea6196122532d\n
+ C: 0000
+
+ S: 003aACK 7e47fe2bd8d01d481f44d7af0531bd93d3b21c01 continue\n
+ S: 003aACK 74730d410fcb6603ace96f1dc55ea6196122532d continue\n
+ S: 0008NAK\n
+
+ C: 0009done\n
+
+ S: 0031ACK 74730d410fcb6603ace96f1dc55ea6196122532d\n
+ S: [PACKFILE]
+----
+
+
+Packfile Data
+-------------
+
+Now that the client and server have finished negotiation about what
+the minimal amount of data that needs to be sent to the client is, the server
+will construct and send the required data in packfile format.
+
+See pack-format.txt for what the packfile itself actually looks like.
+
+If 'side-band' or 'side-band-64k' capabilities have been specified by
+the client, the server will send the packfile data multiplexed.
+
+Each packet starting with the packet-line length of the amount of data
+that follows, followed by a single byte specifying the sideband the
+following data is coming in on.
+
+In 'side-band' mode, it will send up to 999 data bytes plus 1 control
+code, for a total of up to 1000 bytes in a pkt-line. In 'side-band-64k'
+mode it will send up to 65519 data bytes plus 1 control code, for a
+total of up to 65520 bytes in a pkt-line.
+
+The sideband byte will be a '1', '2' or a '3'. Sideband '1' will contain
+packfile data, sideband '2' will be used for progress information that the
+client will generally print to stderr and sideband '3' is used for error
+information.
+
+If no 'side-band' capability was specified, the server will stream the
+entire packfile without multiplexing.
+
+
+Pushing Data To a Server
+------------------------
+
+Pushing data to a server will invoke the 'receive-pack' process on the
+server, which will allow the client to tell it which references it should
+update and then send all the data the server will need for those new
+references to be complete. Once all the data is received and validated,
+the server will then update its references to what the client specified.
+
+Authentication
+--------------
+
+The protocol itself contains no authentication mechanisms. That is to be
+handled by the transport, such as SSH, before the 'receive-pack' process is
+invoked. If 'receive-pack' is configured over the Git transport, those
+repositories will be writable by anyone who can access that port (9418) as
+that transport is unauthenticated.
+
+Reference Discovery
+-------------------
+
+The reference discovery phase is done nearly the same way as it is in the
+fetching protocol. Each reference obj-id and name on the server is sent
+in packet-line format to the client, followed by a flush-pkt. The only
+real difference is that the capability listing is different - the only
+possible values are 'report-status', 'delete-refs', 'ofs-delta' and
+'push-options'.
+
+Reference Update Request and Packfile Transfer
+----------------------------------------------
+
+Once the client knows what references the server is at, it can send a
+list of reference update requests. For each reference on the server
+that it wants to update, it sends a line listing the obj-id currently on
+the server, the obj-id the client would like to update it to and the name
+of the reference.
+
+This list is followed by a flush-pkt. Then the push options are transmitted
+one per packet followed by another flush-pkt. After that the packfile that
+should contain all the objects that the server will need to complete the new
+references will be sent.
+
+----
+ update-request = *shallow ( command-list | push-cert ) [packfile]
+
+ shallow = PKT-LINE("shallow" SP obj-id)
+
+ command-list = PKT-LINE(command NUL capability-list)
+ *PKT-LINE(command)
+ flush-pkt
+
+ command = create / delete / update
+ create = zero-id SP new-id SP name
+ delete = old-id SP zero-id SP name
+ update = old-id SP new-id SP name
+
+ old-id = obj-id
+ new-id = obj-id
+
+ push-cert = PKT-LINE("push-cert" NUL capability-list LF)
+ PKT-LINE("certificate version 0.1" LF)
+ PKT-LINE("pusher" SP ident LF)
+ PKT-LINE("pushee" SP url LF)
+ PKT-LINE("nonce" SP nonce LF)
+ PKT-LINE(LF)
+ *PKT-LINE(command LF)
+ *PKT-LINE(gpg-signature-lines LF)
+ PKT-LINE("push-cert-end" LF)
+
+ packfile = "PACK" 28*(OCTET)
+----
+
+If the receiving end does not support delete-refs, the sending end MUST
+NOT ask for delete command.
+
+If the receiving end does not support push-cert, the sending end
+MUST NOT send a push-cert command. When a push-cert command is
+sent, command-list MUST NOT be sent; the commands recorded in the
+push certificate is used instead.
+
+The packfile MUST NOT be sent if the only command used is 'delete'.
+
+A packfile MUST be sent if either create or update command is used,
+even if the server already has all the necessary objects. In this
+case the client MUST send an empty packfile. The only time this
+is likely to happen is if the client is creating
+a new branch or a tag that points to an existing obj-id.
+
+The server will receive the packfile, unpack it, then validate each
+reference that is being updated that it hasn't changed while the request
+was being processed (the obj-id is still the same as the old-id), and
+it will run any update hooks to make sure that the update is acceptable.
+If all of that is fine, the server will then update the references.
+
+Push Certificate
+----------------
+
+A push certificate begins with a set of header lines. After the
+header and an empty line, the protocol commands follow, one per
+line. Note that the trailing LF in push-cert PKT-LINEs is _not_
+optional; it must be present.
+
+Currently, the following header fields are defined:
+
+`pusher` ident::
+ Identify the GPG key in "Human Readable Name <email@address>"
+ format.
+
+`pushee` url::
+ The repository URL (anonymized, if the URL contains
+ authentication material) the user who ran `git push`
+ intended to push into.
+
+`nonce` nonce::
+ The 'nonce' string the receiving repository asked the
+ pushing user to include in the certificate, to prevent
+ replay attacks.
+
+The GPG signature lines are a detached signature for the contents
+recorded in the push certificate before the signature block begins.
+The detached signature is used to certify that the commands were
+given by the pusher, who must be the signer.
+
+Report Status
+-------------
+
+After receiving the pack data from the sender, the receiver sends a
+report if 'report-status' capability is in effect.
+It is a short listing of what happened in that update. It will first
+list the status of the packfile unpacking as either 'unpack ok' or
+'unpack [error]'. Then it will list the status for each of the references
+that it tried to update. Each line is either 'ok [refname]' if the
+update was successful, or 'ng [refname] [error]' if the update was not.
+
+----
+ report-status = unpack-status
+ 1*(command-status)
+ flush-pkt
+
+ unpack-status = PKT-LINE("unpack" SP unpack-result)
+ unpack-result = "ok" / error-msg
+
+ command-status = command-ok / command-fail
+ command-ok = PKT-LINE("ok" SP refname)
+ command-fail = PKT-LINE("ng" SP refname SP error-msg)
+
+ error-msg = 1*(OCTECT) ; where not "ok"
+----
+
+Updates can be unsuccessful for a number of reasons. The reference can have
+changed since the reference discovery phase was originally sent, meaning
+someone pushed in the meantime. The reference being pushed could be a
+non-fast-forward reference and the update hooks or configuration could be
+set to not allow that, etc. Also, some references can be updated while others
+can be rejected.
+
+An example client/server communication might look like this:
+
+----
+ S: 007c74730d410fcb6603ace96f1dc55ea6196122532d refs/heads/local\0report-status delete-refs ofs-delta\n
+ S: 003e7d1665144a3a975c05f1f43902ddaf084e784dbe refs/heads/debug\n
+ S: 003f74730d410fcb6603ace96f1dc55ea6196122532d refs/heads/master\n
+ S: 003f74730d410fcb6603ace96f1dc55ea6196122532d refs/heads/team\n
+ S: 0000
+
+ C: 003e7d1665144a3a975c05f1f43902ddaf084e784dbe 74730d410fcb6603ace96f1dc55ea6196122532d refs/heads/debug\n
+ C: 003e74730d410fcb6603ace96f1dc55ea6196122532d 5a3f6be755bbb7deae50065988cbfa1ffa9ab68a refs/heads/master\n
+ C: 0000
+ C: [PACKDATA]
+
+ S: 000eunpack ok\n
+ S: 0018ok refs/heads/debug\n
+ S: 002ang refs/heads/master non-fast-forward\n
+----
+*/
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+const (
+ ok = "ok"
+)
+
+// ReportStatus is a report status message, as used in the git-receive-pack
+// process whenever the 'report-status' capability is negotiated.
+type ReportStatus struct {
+ UnpackStatus string
+ CommandStatuses []*CommandStatus
+}
+
+// NewReportStatus creates a new ReportStatus message.
+func NewReportStatus() *ReportStatus {
+ return &ReportStatus{}
+}
+
+// Error returns the first error if any.
+func (s *ReportStatus) Error() error {
+ if s.UnpackStatus != ok {
+ return fmt.Errorf("unpack error: %s", s.UnpackStatus)
+ }
+
+ for _, s := range s.CommandStatuses {
+ if err := s.Error(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Encode writes the report status to a writer.
+func (s *ReportStatus) Encode(w io.Writer) error {
+ e := pktline.NewEncoder(w)
+ if err := e.Encodef("unpack %s\n", s.UnpackStatus); err != nil {
+ return err
+ }
+
+ for _, cs := range s.CommandStatuses {
+ if err := cs.encode(w); err != nil {
+ return err
+ }
+ }
+
+ return e.Flush()
+}
+
+// Decode reads from the given reader and decodes a report-status message. It
+// does not read more input than what is needed to fill the report status.
+func (s *ReportStatus) Decode(r io.Reader) error {
+ scan := pktline.NewScanner(r)
+ if err := s.scanFirstLine(scan); err != nil {
+ return err
+ }
+
+ if err := s.decodeReportStatus(scan.Bytes()); err != nil {
+ return err
+ }
+
+ flushed := false
+ for scan.Scan() {
+ b := scan.Bytes()
+ if isFlush(b) {
+ flushed = true
+ break
+ }
+
+ if err := s.decodeCommandStatus(b); err != nil {
+ return err
+ }
+ }
+
+ if !flushed {
+ return fmt.Errorf("missing flush")
+ }
+
+ return scan.Err()
+}
+
+func (s *ReportStatus) scanFirstLine(scan *pktline.Scanner) error {
+ if scan.Scan() {
+ return nil
+ }
+
+ if scan.Err() != nil {
+ return scan.Err()
+ }
+
+ return io.ErrUnexpectedEOF
+}
+
+func (s *ReportStatus) decodeReportStatus(b []byte) error {
+ if isFlush(b) {
+ return fmt.Errorf("premature flush")
+ }
+
+ b = bytes.TrimSuffix(b, eol)
+
+ line := string(b)
+ fields := strings.SplitN(line, " ", 2)
+ if len(fields) != 2 || fields[0] != "unpack" {
+ return fmt.Errorf("malformed unpack status: %s", line)
+ }
+
+ s.UnpackStatus = fields[1]
+ return nil
+}
+
+func (s *ReportStatus) decodeCommandStatus(b []byte) error {
+ b = bytes.TrimSuffix(b, eol)
+
+ line := string(b)
+ fields := strings.SplitN(line, " ", 3)
+ status := ok
+ if len(fields) == 3 && fields[0] == "ng" {
+ status = fields[2]
+ } else if len(fields) != 2 || fields[0] != "ok" {
+ return fmt.Errorf("malformed command status: %s", line)
+ }
+
+ cs := &CommandStatus{
+ ReferenceName: plumbing.ReferenceName(fields[1]),
+ Status: status,
+ }
+ s.CommandStatuses = append(s.CommandStatuses, cs)
+ return nil
+}
+
+// CommandStatus is the status of a reference in a report status.
+// See ReportStatus struct.
+type CommandStatus struct {
+ ReferenceName plumbing.ReferenceName
+ Status string
+}
+
+// Error returns the error, if any.
+func (s *CommandStatus) Error() error {
+ if s.Status == ok {
+ return nil
+ }
+
+ return fmt.Errorf("command error on %s: %s",
+ s.ReferenceName.String(), s.Status)
+}
+
+func (s *CommandStatus) encode(w io.Writer) error {
+ e := pktline.NewEncoder(w)
+ if s.Error() == nil {
+ return e.Encodef("ok %s\n", s.ReferenceName.String())
+ }
+
+ return e.Encodef("ng %s %s\n", s.ReferenceName.String(), s.Status)
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+const (
+ shallowLineLen = 48
+ unshallowLineLen = 50
+)
+
+type ShallowUpdate struct {
+ Shallows []plumbing.Hash
+ Unshallows []plumbing.Hash
+}
+
+func (r *ShallowUpdate) Decode(reader io.Reader) error {
+ s := pktline.NewScanner(reader)
+
+ for s.Scan() {
+ line := s.Bytes()
+ line = bytes.TrimSpace(line)
+
+ var err error
+ switch {
+ case bytes.HasPrefix(line, shallow):
+ err = r.decodeShallowLine(line)
+ case bytes.HasPrefix(line, unshallow):
+ err = r.decodeUnshallowLine(line)
+ case bytes.Equal(line, pktline.Flush):
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+ }
+
+ return s.Err()
+}
+
+func (r *ShallowUpdate) decodeShallowLine(line []byte) error {
+ hash, err := r.decodeLine(line, shallow, shallowLineLen)
+ if err != nil {
+ return err
+ }
+
+ r.Shallows = append(r.Shallows, hash)
+ return nil
+}
+
+func (r *ShallowUpdate) decodeUnshallowLine(line []byte) error {
+ hash, err := r.decodeLine(line, unshallow, unshallowLineLen)
+ if err != nil {
+ return err
+ }
+
+ r.Unshallows = append(r.Unshallows, hash)
+ return nil
+}
+
+func (r *ShallowUpdate) decodeLine(line, prefix []byte, expLen int) (plumbing.Hash, error) {
+ if len(line) != expLen {
+ return plumbing.ZeroHash, fmt.Errorf("malformed %s%q", prefix, line)
+ }
+
+ raw := string(line[expLen-40 : expLen])
+ return plumbing.NewHash(raw), nil
+}
+
+func (r *ShallowUpdate) Encode(w io.Writer) error {
+ e := pktline.NewEncoder(w)
+
+ for _, h := range r.Shallows {
+ if err := e.Encodef("%s%s\n", shallow, h.String()); err != nil {
+ return err
+ }
+ }
+
+ for _, h := range r.Unshallows {
+ if err := e.Encodef("%s%s\n", unshallow, h.String()); err != nil {
+ return err
+ }
+ }
+
+ return e.Flush()
+}
--- /dev/null
+package sideband
+
+// Type sideband type "side-band" or "side-band-64k"
+type Type int8
+
+const (
+ // Sideband legacy sideband type up to 1000-byte messages
+ Sideband Type = iota
+ // Sideband64k sideband type up to 65519-byte messages
+ Sideband64k Type = iota
+
+ // MaxPackedSize for Sideband type
+ MaxPackedSize = 1000
+ // MaxPackedSize64k for Sideband64k type
+ MaxPackedSize64k = 65520
+)
+
+// Channel sideband channel
+type Channel byte
+
+// WithPayload encode the payload as a message
+func (ch Channel) WithPayload(payload []byte) []byte {
+ return append([]byte{byte(ch)}, payload...)
+}
+
+const (
+ // PackData packfile content
+ PackData Channel = 1
+ // ProgressMessage progress messages
+ ProgressMessage Channel = 2
+ // ErrorMessage fatal error message just before stream aborts
+ ErrorMessage Channel = 3
+)
--- /dev/null
+package sideband
+
+import (
+ "errors"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+// ErrMaxPackedExceeded returned by Read, if the maximum packed size is exceeded
+var ErrMaxPackedExceeded = errors.New("max. packed size exceeded")
+
+// Progress where the progress information is stored
+type Progress interface {
+ io.Writer
+}
+
+// Demuxer demultiplexes the progress reports and error info interleaved with the
+// packfile itself.
+//
+// A sideband has three different channels the main one, called PackData, contains
+// the packfile data; the ErrorMessage channel, that contains server errors; and
+// the last one, ProgressMessage channel, containing information about the ongoing
+// task happening in the server (optional, can be suppressed sending NoProgress
+// or Quiet capabilities to the server)
+//
+// In order to demultiplex the data stream, method `Read` should be called to
+// retrieve the PackData channel, the incoming data from the ProgressMessage is
+// written at `Progress` (if any), if any message is retrieved from the
+// ErrorMessage channel an error is returned and we can assume that the
+// connection has been closed.
+type Demuxer struct {
+ t Type
+ r io.Reader
+ s *pktline.Scanner
+
+ max int
+ pending []byte
+
+ // Progress is where the progress messages are stored
+ Progress Progress
+}
+
+// NewDemuxer returns a new Demuxer for the given t and read from r
+func NewDemuxer(t Type, r io.Reader) *Demuxer {
+ max := MaxPackedSize64k
+ if t == Sideband {
+ max = MaxPackedSize
+ }
+
+ return &Demuxer{
+ t: t,
+ r: r,
+ max: max,
+ s: pktline.NewScanner(r),
+ }
+}
+
+// Read reads up to len(p) bytes from the PackData channel into p, an error can
+// be return if an error happens when reading or if a message is sent in the
+// ErrorMessage channel.
+//
+// When a ProgressMessage is read, is not copy to b, instead of this is written
+// to the Progress
+func (d *Demuxer) Read(b []byte) (n int, err error) {
+ var read, req int
+
+ req = len(b)
+ for read < req {
+ n, err := d.doRead(b[read:req])
+ read += n
+
+ if err != nil {
+ return read, err
+ }
+ }
+
+ return read, nil
+}
+
+func (d *Demuxer) doRead(b []byte) (int, error) {
+ read, err := d.nextPackData()
+ size := len(read)
+ wanted := len(b)
+
+ if size > wanted {
+ d.pending = read[wanted:]
+ }
+
+ if wanted > size {
+ wanted = size
+ }
+
+ size = copy(b, read[:wanted])
+ return size, err
+}
+
+func (d *Demuxer) nextPackData() ([]byte, error) {
+ content := d.getPending()
+ if len(content) != 0 {
+ return content, nil
+ }
+
+ if !d.s.Scan() {
+ if err := d.s.Err(); err != nil {
+ return nil, err
+ }
+
+ return nil, io.EOF
+ }
+
+ content = d.s.Bytes()
+
+ size := len(content)
+ if size == 0 {
+ return nil, nil
+ } else if size > d.max {
+ return nil, ErrMaxPackedExceeded
+ }
+
+ switch Channel(content[0]) {
+ case PackData:
+ return content[1:], nil
+ case ProgressMessage:
+ if d.Progress != nil {
+ _, err := d.Progress.Write(content[1:])
+ return nil, err
+ }
+ case ErrorMessage:
+ return nil, fmt.Errorf("unexpected error: %s", content[1:])
+ default:
+ return nil, fmt.Errorf("unknown channel %s", content)
+ }
+
+ return nil, nil
+}
+
+func (d *Demuxer) getPending() (b []byte) {
+ if len(d.pending) == 0 {
+ return nil
+ }
+
+ content := d.pending
+ d.pending = nil
+
+ return content
+}
--- /dev/null
+// Package sideband implements a sideband mutiplex/demultiplexer
+package sideband
+
+// If 'side-band' or 'side-band-64k' capabilities have been specified by
+// the client, the server will send the packfile data multiplexed.
+//
+// Either mode indicates that the packfile data will be streamed broken
+// up into packets of up to either 1000 bytes in the case of 'side_band',
+// or 65520 bytes in the case of 'side_band_64k'. Each packet is made up
+// of a leading 4-byte pkt-line length of how much data is in the packet,
+// followed by a 1-byte stream code, followed by the actual data.
+//
+// The stream code can be one of:
+//
+// 1 - pack data
+// 2 - progress messages
+// 3 - fatal error message just before stream aborts
+//
+// The "side-band-64k" capability came about as a way for newer clients
+// that can handle much larger packets to request packets that are
+// actually crammed nearly full, while maintaining backward compatibility
+// for the older clients.
+//
+// Further, with side-band and its up to 1000-byte messages, it's actually
+// 999 bytes of payload and 1 byte for the stream code. With side-band-64k,
+// same deal, you have up to 65519 bytes of data and 1 byte for the stream
+// code.
+//
+// The client MUST send only maximum of one of "side-band" and "side-
+// band-64k". Server MUST diagnose it as an error if client requests
+// both.
--- /dev/null
+package sideband
+
+import (
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+// Muxer multiplex the packfile along with the progress messages and the error
+// information. The multiplex is perform using pktline format.
+type Muxer struct {
+ max int
+ e *pktline.Encoder
+}
+
+const chLen = 1
+
+// NewMuxer returns a new Muxer for the given t that writes on w.
+//
+// If t is equal to `Sideband` the max pack size is set to MaxPackedSize, in any
+// other value is given, max pack is set to MaxPackedSize64k, that is the
+// maximum length of a line in pktline format.
+func NewMuxer(t Type, w io.Writer) *Muxer {
+ max := MaxPackedSize64k
+ if t == Sideband {
+ max = MaxPackedSize
+ }
+
+ return &Muxer{
+ max: max - chLen,
+ e: pktline.NewEncoder(w),
+ }
+}
+
+// Write writes p in the PackData channel
+func (m *Muxer) Write(p []byte) (int, error) {
+ return m.WriteChannel(PackData, p)
+}
+
+// WriteChannel writes p in the given channel. This method can be used with any
+// channel, but is recommend use it only for the ProgressMessage and
+// ErrorMessage channels and use Write for the PackData channel
+func (m *Muxer) WriteChannel(t Channel, p []byte) (int, error) {
+ wrote := 0
+ size := len(p)
+ for wrote < size {
+ n, err := m.doWrite(t, p[wrote:])
+ wrote += n
+
+ if err != nil {
+ return wrote, err
+ }
+ }
+
+ return wrote, nil
+}
+
+func (m *Muxer) doWrite(ch Channel, p []byte) (int, error) {
+ sz := len(p)
+ if sz > m.max {
+ sz = m.max
+ }
+
+ return sz, m.e.Encode(ch.WithPayload(p[:sz]))
+}
--- /dev/null
+package packp
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+const ackLineLen = 44
+
+// ServerResponse object acknowledgement from upload-pack service
+type ServerResponse struct {
+ ACKs []plumbing.Hash
+}
+
+// Decode decodes the response into the struct, isMultiACK should be true, if
+// the request was done with multi_ack or multi_ack_detailed capabilities.
+func (r *ServerResponse) Decode(reader *bufio.Reader, isMultiACK bool) error {
+ // TODO: implement support for multi_ack or multi_ack_detailed responses
+ if isMultiACK {
+ return errors.New("multi_ack and multi_ack_detailed are not supported")
+ }
+
+ s := pktline.NewScanner(reader)
+
+ for s.Scan() {
+ line := s.Bytes()
+
+ if err := r.decodeLine(line); err != nil {
+ return err
+ }
+
+ // we need to detect when the end of a response header and the beginning
+ // of a packfile header happened, some requests to the git daemon
+ // produces a duplicate ACK header even when multi_ack is not supported.
+ stop, err := r.stopReading(reader)
+ if err != nil {
+ return err
+ }
+
+ if stop {
+ break
+ }
+ }
+
+ return s.Err()
+}
+
+// stopReading detects when a valid command such as ACK or NAK is found to be
+// read in the buffer without moving the read pointer.
+func (r *ServerResponse) stopReading(reader *bufio.Reader) (bool, error) {
+ ahead, err := reader.Peek(7)
+ if err == io.EOF {
+ return true, nil
+ }
+
+ if err != nil {
+ return false, err
+ }
+
+ if len(ahead) > 4 && r.isValidCommand(ahead[0:3]) {
+ return false, nil
+ }
+
+ if len(ahead) == 7 && r.isValidCommand(ahead[4:]) {
+ return false, nil
+ }
+
+ return true, nil
+}
+
+func (r *ServerResponse) isValidCommand(b []byte) bool {
+ commands := [][]byte{ack, nak}
+ for _, c := range commands {
+ if bytes.Equal(b, c) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (r *ServerResponse) decodeLine(line []byte) error {
+ if len(line) == 0 {
+ return fmt.Errorf("unexpected flush")
+ }
+
+ if bytes.Equal(line[0:3], ack) {
+ return r.decodeACKLine(line)
+ }
+
+ if bytes.Equal(line[0:3], nak) {
+ return nil
+ }
+
+ return fmt.Errorf("unexpected content %q", string(line))
+}
+
+func (r *ServerResponse) decodeACKLine(line []byte) error {
+ if len(line) < ackLineLen {
+ return fmt.Errorf("malformed ACK %q", line)
+ }
+
+ sp := bytes.Index(line, []byte(" "))
+ h := plumbing.NewHash(string(line[sp+1 : sp+41]))
+ r.ACKs = append(r.ACKs, h)
+ return nil
+}
+
+// Encode encodes the ServerResponse into a writer.
+func (r *ServerResponse) Encode(w io.Writer) error {
+ if len(r.ACKs) > 1 {
+ return errors.New("multi_ack and multi_ack_detailed are not supported")
+ }
+
+ e := pktline.NewEncoder(w)
+ if len(r.ACKs) == 0 {
+ return e.Encodef("%s\n", nak)
+ }
+
+ return e.Encodef("%s %s\n", ack, r.ACKs[0].String())
+}
--- /dev/null
+package packp
+
+import (
+ "fmt"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+)
+
+// UploadRequest values represent the information transmitted on a
+// upload-request message. Values from this type are not zero-value
+// safe, use the New function instead.
+// This is a low level type, use UploadPackRequest instead.
+type UploadRequest struct {
+ Capabilities *capability.List
+ Wants []plumbing.Hash
+ Shallows []plumbing.Hash
+ Depth Depth
+}
+
+// Depth values stores the desired depth of the requested packfile: see
+// DepthCommit, DepthSince and DepthReference.
+type Depth interface {
+ isDepth()
+ IsZero() bool
+}
+
+// DepthCommits values stores the maximum number of requested commits in
+// the packfile. Zero means infinite. A negative value will have
+// undefined consequences.
+type DepthCommits int
+
+func (d DepthCommits) isDepth() {}
+
+func (d DepthCommits) IsZero() bool {
+ return d == 0
+}
+
+// DepthSince values requests only commits newer than the specified time.
+type DepthSince time.Time
+
+func (d DepthSince) isDepth() {}
+
+func (d DepthSince) IsZero() bool {
+ return time.Time(d).IsZero()
+}
+
+// DepthReference requests only commits not to found in the specified reference.
+type DepthReference string
+
+func (d DepthReference) isDepth() {}
+
+func (d DepthReference) IsZero() bool {
+ return string(d) == ""
+}
+
+// NewUploadRequest returns a pointer to a new UploadRequest value, ready to be
+// used. It has no capabilities, wants or shallows and an infinite depth. Please
+// note that to encode an upload-request it has to have at least one wanted hash.
+func NewUploadRequest() *UploadRequest {
+ return &UploadRequest{
+ Capabilities: capability.NewList(),
+ Wants: []plumbing.Hash{},
+ Shallows: []plumbing.Hash{},
+ Depth: DepthCommits(0),
+ }
+}
+
+// NewUploadRequestFromCapabilities returns a pointer to a new UploadRequest
+// value, the request capabilities are filled with the most optiomal ones, based
+// on the adv value (advertaised capabilities), the UploadRequest generated it
+// has no wants or shallows and an infinite depth.
+func NewUploadRequestFromCapabilities(adv *capability.List) *UploadRequest {
+ r := NewUploadRequest()
+
+ if adv.Supports(capability.MultiACKDetailed) {
+ r.Capabilities.Set(capability.MultiACKDetailed)
+ } else if adv.Supports(capability.MultiACK) {
+ r.Capabilities.Set(capability.MultiACK)
+ }
+
+ if adv.Supports(capability.Sideband64k) {
+ r.Capabilities.Set(capability.Sideband64k)
+ } else if adv.Supports(capability.Sideband) {
+ r.Capabilities.Set(capability.Sideband)
+ }
+
+ if adv.Supports(capability.ThinPack) {
+ r.Capabilities.Set(capability.ThinPack)
+ }
+
+ if adv.Supports(capability.OFSDelta) {
+ r.Capabilities.Set(capability.OFSDelta)
+ }
+
+ if adv.Supports(capability.Agent) {
+ r.Capabilities.Set(capability.Agent, capability.DefaultAgent)
+ }
+
+ return r
+}
+
+// Validate validates the content of UploadRequest, following the next rules:
+// - Wants MUST have at least one reference
+// - capability.Shallow MUST be present if Shallows is not empty
+// - is a non-zero DepthCommits is given capability.Shallow MUST be present
+// - is a DepthSince is given capability.Shallow MUST be present
+// - is a DepthReference is given capability.DeepenNot MUST be present
+// - MUST contain only maximum of one of capability.Sideband and capability.Sideband64k
+// - MUST contain only maximum of one of capability.MultiACK and capability.MultiACKDetailed
+func (r *UploadRequest) Validate() error {
+ if len(r.Wants) == 0 {
+ return fmt.Errorf("want can't be empty")
+ }
+
+ if err := r.validateRequiredCapabilities(); err != nil {
+ return err
+ }
+
+ if err := r.validateConflictCapabilities(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (r *UploadRequest) validateRequiredCapabilities() error {
+ msg := "missing capability %s"
+
+ if len(r.Shallows) != 0 && !r.Capabilities.Supports(capability.Shallow) {
+ return fmt.Errorf(msg, capability.Shallow)
+ }
+
+ switch r.Depth.(type) {
+ case DepthCommits:
+ if r.Depth != DepthCommits(0) {
+ if !r.Capabilities.Supports(capability.Shallow) {
+ return fmt.Errorf(msg, capability.Shallow)
+ }
+ }
+ case DepthSince:
+ if !r.Capabilities.Supports(capability.DeepenSince) {
+ return fmt.Errorf(msg, capability.DeepenSince)
+ }
+ case DepthReference:
+ if !r.Capabilities.Supports(capability.DeepenNot) {
+ return fmt.Errorf(msg, capability.DeepenNot)
+ }
+ }
+
+ return nil
+}
+
+func (r *UploadRequest) validateConflictCapabilities() error {
+ msg := "capabilities %s and %s are mutually exclusive"
+ if r.Capabilities.Supports(capability.Sideband) &&
+ r.Capabilities.Supports(capability.Sideband64k) {
+ return fmt.Errorf(msg, capability.Sideband, capability.Sideband64k)
+ }
+
+ if r.Capabilities.Supports(capability.MultiACK) &&
+ r.Capabilities.Supports(capability.MultiACKDetailed) {
+ return fmt.Errorf(msg, capability.MultiACK, capability.MultiACKDetailed)
+ }
+
+ return nil
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "strconv"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+// Decode reads the next upload-request form its input and
+// stores it in the UploadRequest.
+func (u *UploadRequest) Decode(r io.Reader) error {
+ d := newUlReqDecoder(r)
+ return d.Decode(u)
+}
+
+type ulReqDecoder struct {
+ s *pktline.Scanner // a pkt-line scanner from the input stream
+ line []byte // current pkt-line contents, use parser.nextLine() to make it advance
+ nLine int // current pkt-line number for debugging, begins at 1
+ err error // sticky error, use the parser.error() method to fill this out
+ data *UploadRequest // parsed data is stored here
+}
+
+func newUlReqDecoder(r io.Reader) *ulReqDecoder {
+ return &ulReqDecoder{
+ s: pktline.NewScanner(r),
+ }
+}
+
+func (d *ulReqDecoder) Decode(v *UploadRequest) error {
+ d.data = v
+
+ for state := d.decodeFirstWant; state != nil; {
+ state = state()
+ }
+
+ return d.err
+}
+
+// fills out the parser stiky error
+func (d *ulReqDecoder) error(format string, a ...interface{}) {
+ msg := fmt.Sprintf(
+ "pkt-line %d: %s", d.nLine,
+ fmt.Sprintf(format, a...),
+ )
+
+ d.err = NewErrUnexpectedData(msg, d.line)
+}
+
+// Reads a new pkt-line from the scanner, makes its payload available as
+// p.line and increments p.nLine. A successful invocation returns true,
+// otherwise, false is returned and the sticky error is filled out
+// accordingly. Trims eols at the end of the payloads.
+func (d *ulReqDecoder) nextLine() bool {
+ d.nLine++
+
+ if !d.s.Scan() {
+ if d.err = d.s.Err(); d.err != nil {
+ return false
+ }
+
+ d.error("EOF")
+ return false
+ }
+
+ d.line = d.s.Bytes()
+ d.line = bytes.TrimSuffix(d.line, eol)
+
+ return true
+}
+
+// Expected format: want <hash>[ capabilities]
+func (d *ulReqDecoder) decodeFirstWant() stateFn {
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ if !bytes.HasPrefix(d.line, want) {
+ d.error("missing 'want ' prefix")
+ return nil
+ }
+ d.line = bytes.TrimPrefix(d.line, want)
+
+ hash, ok := d.readHash()
+ if !ok {
+ return nil
+ }
+ d.data.Wants = append(d.data.Wants, hash)
+
+ return d.decodeCaps
+}
+
+func (d *ulReqDecoder) readHash() (plumbing.Hash, bool) {
+ if len(d.line) < hashSize {
+ d.err = fmt.Errorf("malformed hash: %v", d.line)
+ return plumbing.ZeroHash, false
+ }
+
+ var hash plumbing.Hash
+ if _, err := hex.Decode(hash[:], d.line[:hashSize]); err != nil {
+ d.error("invalid hash text: %s", err)
+ return plumbing.ZeroHash, false
+ }
+ d.line = d.line[hashSize:]
+
+ return hash, true
+}
+
+// Expected format: sp cap1 sp cap2 sp cap3...
+func (d *ulReqDecoder) decodeCaps() stateFn {
+ d.line = bytes.TrimPrefix(d.line, sp)
+ if err := d.data.Capabilities.Decode(d.line); err != nil {
+ d.error("invalid capabilities: %s", err)
+ }
+
+ return d.decodeOtherWants
+}
+
+// Expected format: want <hash>
+func (d *ulReqDecoder) decodeOtherWants() stateFn {
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ if bytes.HasPrefix(d.line, shallow) {
+ return d.decodeShallow
+ }
+
+ if bytes.HasPrefix(d.line, deepen) {
+ return d.decodeDeepen
+ }
+
+ if len(d.line) == 0 {
+ return nil
+ }
+
+ if !bytes.HasPrefix(d.line, want) {
+ d.error("unexpected payload while expecting a want: %q", d.line)
+ return nil
+ }
+ d.line = bytes.TrimPrefix(d.line, want)
+
+ hash, ok := d.readHash()
+ if !ok {
+ return nil
+ }
+ d.data.Wants = append(d.data.Wants, hash)
+
+ return d.decodeOtherWants
+}
+
+// Expected format: shallow <hash>
+func (d *ulReqDecoder) decodeShallow() stateFn {
+ if bytes.HasPrefix(d.line, deepen) {
+ return d.decodeDeepen
+ }
+
+ if len(d.line) == 0 {
+ return nil
+ }
+
+ if !bytes.HasPrefix(d.line, shallow) {
+ d.error("unexpected payload while expecting a shallow: %q", d.line)
+ return nil
+ }
+ d.line = bytes.TrimPrefix(d.line, shallow)
+
+ hash, ok := d.readHash()
+ if !ok {
+ return nil
+ }
+ d.data.Shallows = append(d.data.Shallows, hash)
+
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ return d.decodeShallow
+}
+
+// Expected format: deepen <n> / deepen-since <ul> / deepen-not <ref>
+func (d *ulReqDecoder) decodeDeepen() stateFn {
+ if bytes.HasPrefix(d.line, deepenCommits) {
+ return d.decodeDeepenCommits
+ }
+
+ if bytes.HasPrefix(d.line, deepenSince) {
+ return d.decodeDeepenSince
+ }
+
+ if bytes.HasPrefix(d.line, deepenReference) {
+ return d.decodeDeepenReference
+ }
+
+ if len(d.line) == 0 {
+ return nil
+ }
+
+ d.error("unexpected deepen specification: %q", d.line)
+ return nil
+}
+
+func (d *ulReqDecoder) decodeDeepenCommits() stateFn {
+ d.line = bytes.TrimPrefix(d.line, deepenCommits)
+
+ var n int
+ if n, d.err = strconv.Atoi(string(d.line)); d.err != nil {
+ return nil
+ }
+ if n < 0 {
+ d.err = fmt.Errorf("negative depth")
+ return nil
+ }
+ d.data.Depth = DepthCommits(n)
+
+ return d.decodeFlush
+}
+
+func (d *ulReqDecoder) decodeDeepenSince() stateFn {
+ d.line = bytes.TrimPrefix(d.line, deepenSince)
+
+ var secs int64
+ secs, d.err = strconv.ParseInt(string(d.line), 10, 64)
+ if d.err != nil {
+ return nil
+ }
+ t := time.Unix(secs, 0).UTC()
+ d.data.Depth = DepthSince(t)
+
+ return d.decodeFlush
+}
+
+func (d *ulReqDecoder) decodeDeepenReference() stateFn {
+ d.line = bytes.TrimPrefix(d.line, deepenReference)
+
+ d.data.Depth = DepthReference(string(d.line))
+
+ return d.decodeFlush
+}
+
+func (d *ulReqDecoder) decodeFlush() stateFn {
+ if ok := d.nextLine(); !ok {
+ return nil
+ }
+
+ if len(d.line) != 0 {
+ d.err = fmt.Errorf("unexpected payload while expecting a flush-pkt: %q", d.line)
+ }
+
+ return nil
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+// Encode writes the UlReq encoding of u to the stream.
+//
+// All the payloads will end with a newline character. Wants and
+// shallows are sorted alphabetically. A depth of 0 means no depth
+// request is sent.
+func (u *UploadRequest) Encode(w io.Writer) error {
+ e := newUlReqEncoder(w)
+ return e.Encode(u)
+}
+
+type ulReqEncoder struct {
+ pe *pktline.Encoder // where to write the encoded data
+ data *UploadRequest // the data to encode
+ err error // sticky error
+}
+
+func newUlReqEncoder(w io.Writer) *ulReqEncoder {
+ return &ulReqEncoder{
+ pe: pktline.NewEncoder(w),
+ }
+}
+
+func (e *ulReqEncoder) Encode(v *UploadRequest) error {
+ e.data = v
+
+ if len(v.Wants) == 0 {
+ return fmt.Errorf("empty wants provided")
+ }
+
+ plumbing.HashesSort(e.data.Wants)
+ for state := e.encodeFirstWant; state != nil; {
+ state = state()
+ }
+
+ return e.err
+}
+
+func (e *ulReqEncoder) encodeFirstWant() stateFn {
+ var err error
+ if e.data.Capabilities.IsEmpty() {
+ err = e.pe.Encodef("want %s\n", e.data.Wants[0])
+ } else {
+ err = e.pe.Encodef(
+ "want %s %s\n",
+ e.data.Wants[0],
+ e.data.Capabilities.String(),
+ )
+ }
+
+ if err != nil {
+ e.err = fmt.Errorf("encoding first want line: %s", err)
+ return nil
+ }
+
+ return e.encodeAditionalWants
+}
+
+func (e *ulReqEncoder) encodeAditionalWants() stateFn {
+ last := e.data.Wants[0]
+ for _, w := range e.data.Wants[1:] {
+ if bytes.Equal(last[:], w[:]) {
+ continue
+ }
+
+ if err := e.pe.Encodef("want %s\n", w); err != nil {
+ e.err = fmt.Errorf("encoding want %q: %s", w, err)
+ return nil
+ }
+
+ last = w
+ }
+
+ return e.encodeShallows
+}
+
+func (e *ulReqEncoder) encodeShallows() stateFn {
+ plumbing.HashesSort(e.data.Shallows)
+
+ var last plumbing.Hash
+ for _, s := range e.data.Shallows {
+ if bytes.Equal(last[:], s[:]) {
+ continue
+ }
+
+ if err := e.pe.Encodef("shallow %s\n", s); err != nil {
+ e.err = fmt.Errorf("encoding shallow %q: %s", s, err)
+ return nil
+ }
+
+ last = s
+ }
+
+ return e.encodeDepth
+}
+
+func (e *ulReqEncoder) encodeDepth() stateFn {
+ switch depth := e.data.Depth.(type) {
+ case DepthCommits:
+ if depth != 0 {
+ commits := int(depth)
+ if err := e.pe.Encodef("deepen %d\n", commits); err != nil {
+ e.err = fmt.Errorf("encoding depth %d: %s", depth, err)
+ return nil
+ }
+ }
+ case DepthSince:
+ when := time.Time(depth).UTC()
+ if err := e.pe.Encodef("deepen-since %d\n", when.Unix()); err != nil {
+ e.err = fmt.Errorf("encoding depth %s: %s", when, err)
+ return nil
+ }
+ case DepthReference:
+ reference := string(depth)
+ if err := e.pe.Encodef("deepen-not %s\n", reference); err != nil {
+ e.err = fmt.Errorf("encoding depth %s: %s", reference, err)
+ return nil
+ }
+ default:
+ e.err = fmt.Errorf("unsupported depth type")
+ return nil
+ }
+
+ return e.encodeFlush
+}
+
+func (e *ulReqEncoder) encodeFlush() stateFn {
+ if err := e.pe.Flush(); err != nil {
+ e.err = fmt.Errorf("encoding flush-pkt: %s", err)
+ return nil
+ }
+
+ return nil
+}
--- /dev/null
+package packp
+
+import (
+ "errors"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/sideband"
+)
+
+var (
+ ErrEmptyCommands = errors.New("commands cannot be empty")
+ ErrMalformedCommand = errors.New("malformed command")
+)
+
+// ReferenceUpdateRequest values represent reference upload requests.
+// Values from this type are not zero-value safe, use the New function instead.
+type ReferenceUpdateRequest struct {
+ Capabilities *capability.List
+ Commands []*Command
+ Shallow *plumbing.Hash
+ // Packfile contains an optional packfile reader.
+ Packfile io.ReadCloser
+
+ // Progress receives sideband progress messages from the server
+ Progress sideband.Progress
+}
+
+// New returns a pointer to a new ReferenceUpdateRequest value.
+func NewReferenceUpdateRequest() *ReferenceUpdateRequest {
+ return &ReferenceUpdateRequest{
+ // TODO: Add support for push-cert
+ Capabilities: capability.NewList(),
+ Commands: nil,
+ }
+}
+
+// NewReferenceUpdateRequestFromCapabilities returns a pointer to a new
+// ReferenceUpdateRequest value, the request capabilities are filled with the
+// most optimal ones, based on the adv value (advertised capabilities), the
+// ReferenceUpdateRequest contains no commands
+//
+// It does set the following capabilities:
+// - agent
+// - report-status
+// - ofs-delta
+// - ref-delta
+// - delete-refs
+// It leaves up to the user to add the following capabilities later:
+// - atomic
+// - ofs-delta
+// - side-band
+// - side-band-64k
+// - quiet
+// - push-cert
+func NewReferenceUpdateRequestFromCapabilities(adv *capability.List) *ReferenceUpdateRequest {
+ r := NewReferenceUpdateRequest()
+
+ if adv.Supports(capability.Agent) {
+ r.Capabilities.Set(capability.Agent, capability.DefaultAgent)
+ }
+
+ if adv.Supports(capability.ReportStatus) {
+ r.Capabilities.Set(capability.ReportStatus)
+ }
+
+ return r
+}
+
+func (r *ReferenceUpdateRequest) validate() error {
+ if len(r.Commands) == 0 {
+ return ErrEmptyCommands
+ }
+
+ for _, c := range r.Commands {
+ if err := c.validate(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+type Action string
+
+const (
+ Create Action = "create"
+ Update = "update"
+ Delete = "delete"
+ Invalid = "invalid"
+)
+
+type Command struct {
+ Name plumbing.ReferenceName
+ Old plumbing.Hash
+ New plumbing.Hash
+}
+
+func (c *Command) Action() Action {
+ if c.Old == plumbing.ZeroHash && c.New == plumbing.ZeroHash {
+ return Invalid
+ }
+
+ if c.Old == plumbing.ZeroHash {
+ return Create
+ }
+
+ if c.New == plumbing.ZeroHash {
+ return Delete
+ }
+
+ return Update
+}
+
+func (c *Command) validate() error {
+ if c.Action() == Invalid {
+ return ErrMalformedCommand
+ }
+
+ return nil
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+)
+
+var (
+ shallowLineLength = len(shallow) + hashSize
+ minCommandLength = hashSize*2 + 2 + 1
+ minCommandAndCapsLenth = minCommandLength + 1
+)
+
+var (
+ ErrEmpty = errors.New("empty update-request message")
+ errNoCommands = errors.New("unexpected EOF before any command")
+ errMissingCapabilitiesDelimiter = errors.New("capabilities delimiter not found")
+)
+
+func errMalformedRequest(reason string) error {
+ return fmt.Errorf("malformed request: %s", reason)
+}
+
+func errInvalidHashSize(got int) error {
+ return fmt.Errorf("invalid hash size: expected %d, got %d",
+ hashSize, got)
+}
+
+func errInvalidHash(err error) error {
+ return fmt.Errorf("invalid hash: %s", err.Error())
+}
+
+func errInvalidShallowLineLength(got int) error {
+ return errMalformedRequest(fmt.Sprintf(
+ "invalid shallow line length: expected %d, got %d",
+ shallowLineLength, got))
+}
+
+func errInvalidCommandCapabilitiesLineLength(got int) error {
+ return errMalformedRequest(fmt.Sprintf(
+ "invalid command and capabilities line length: expected at least %d, got %d",
+ minCommandAndCapsLenth, got))
+}
+
+func errInvalidCommandLineLength(got int) error {
+ return errMalformedRequest(fmt.Sprintf(
+ "invalid command line length: expected at least %d, got %d",
+ minCommandLength, got))
+}
+
+func errInvalidShallowObjId(err error) error {
+ return errMalformedRequest(
+ fmt.Sprintf("invalid shallow object id: %s", err.Error()))
+}
+
+func errInvalidOldObjId(err error) error {
+ return errMalformedRequest(
+ fmt.Sprintf("invalid old object id: %s", err.Error()))
+}
+
+func errInvalidNewObjId(err error) error {
+ return errMalformedRequest(
+ fmt.Sprintf("invalid new object id: %s", err.Error()))
+}
+
+func errMalformedCommand(err error) error {
+ return errMalformedRequest(fmt.Sprintf(
+ "malformed command: %s", err.Error()))
+}
+
+// Decode reads the next update-request message form the reader and wr
+func (req *ReferenceUpdateRequest) Decode(r io.Reader) error {
+ var rc io.ReadCloser
+ var ok bool
+ rc, ok = r.(io.ReadCloser)
+ if !ok {
+ rc = ioutil.NopCloser(r)
+ }
+
+ d := &updReqDecoder{r: rc, s: pktline.NewScanner(r)}
+ return d.Decode(req)
+}
+
+type updReqDecoder struct {
+ r io.ReadCloser
+ s *pktline.Scanner
+ req *ReferenceUpdateRequest
+}
+
+func (d *updReqDecoder) Decode(req *ReferenceUpdateRequest) error {
+ d.req = req
+ funcs := []func() error{
+ d.scanLine,
+ d.decodeShallow,
+ d.decodeCommandAndCapabilities,
+ d.decodeCommands,
+ d.setPackfile,
+ req.validate,
+ }
+
+ for _, f := range funcs {
+ if err := f(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *updReqDecoder) scanLine() error {
+ if ok := d.s.Scan(); !ok {
+ return d.scanErrorOr(ErrEmpty)
+ }
+
+ return nil
+}
+
+func (d *updReqDecoder) decodeShallow() error {
+ b := d.s.Bytes()
+
+ if !bytes.HasPrefix(b, shallowNoSp) {
+ return nil
+ }
+
+ if len(b) != shallowLineLength {
+ return errInvalidShallowLineLength(len(b))
+ }
+
+ h, err := parseHash(string(b[len(shallow):]))
+ if err != nil {
+ return errInvalidShallowObjId(err)
+ }
+
+ if ok := d.s.Scan(); !ok {
+ return d.scanErrorOr(errNoCommands)
+ }
+
+ d.req.Shallow = &h
+
+ return nil
+}
+
+func (d *updReqDecoder) decodeCommands() error {
+ for {
+ b := d.s.Bytes()
+ if bytes.Equal(b, pktline.Flush) {
+ return nil
+ }
+
+ c, err := parseCommand(b)
+ if err != nil {
+ return err
+ }
+
+ d.req.Commands = append(d.req.Commands, c)
+
+ if ok := d.s.Scan(); !ok {
+ return d.s.Err()
+ }
+ }
+}
+
+func (d *updReqDecoder) decodeCommandAndCapabilities() error {
+ b := d.s.Bytes()
+ i := bytes.IndexByte(b, 0)
+ if i == -1 {
+ return errMissingCapabilitiesDelimiter
+ }
+
+ if len(b) < minCommandAndCapsLenth {
+ return errInvalidCommandCapabilitiesLineLength(len(b))
+ }
+
+ cmd, err := parseCommand(b[:i])
+ if err != nil {
+ return err
+ }
+
+ d.req.Commands = append(d.req.Commands, cmd)
+
+ if err := d.req.Capabilities.Decode(b[i+1:]); err != nil {
+ return err
+ }
+
+ if err := d.scanLine(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (d *updReqDecoder) setPackfile() error {
+ d.req.Packfile = d.r
+
+ return nil
+}
+
+func parseCommand(b []byte) (*Command, error) {
+ if len(b) < minCommandLength {
+ return nil, errInvalidCommandLineLength(len(b))
+ }
+
+ var (
+ os, ns string
+ n plumbing.ReferenceName
+ )
+ if _, err := fmt.Sscanf(string(b), "%s %s %s", &os, &ns, &n); err != nil {
+ return nil, errMalformedCommand(err)
+ }
+
+ oh, err := parseHash(os)
+ if err != nil {
+ return nil, errInvalidOldObjId(err)
+ }
+
+ nh, err := parseHash(ns)
+ if err != nil {
+ return nil, errInvalidNewObjId(err)
+ }
+
+ return &Command{Old: oh, New: nh, Name: plumbing.ReferenceName(n)}, nil
+}
+
+func parseHash(s string) (plumbing.Hash, error) {
+ if len(s) != hashSize {
+ return plumbing.ZeroHash, errInvalidHashSize(len(s))
+ }
+
+ if _, err := hex.DecodeString(s); err != nil {
+ return plumbing.ZeroHash, errInvalidHash(err)
+ }
+
+ h := plumbing.NewHash(s)
+ return h, nil
+}
+
+func (d *updReqDecoder) scanErrorOr(origErr error) error {
+ if err := d.s.Err(); err != nil {
+ return err
+ }
+
+ return origErr
+}
--- /dev/null
+package packp
+
+import (
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+)
+
+var (
+ zeroHashString = plumbing.ZeroHash.String()
+)
+
+// Encode writes the ReferenceUpdateRequest encoding to the stream.
+func (r *ReferenceUpdateRequest) Encode(w io.Writer) error {
+ if err := r.validate(); err != nil {
+ return err
+ }
+
+ e := pktline.NewEncoder(w)
+
+ if err := r.encodeShallow(e, r.Shallow); err != nil {
+ return err
+ }
+
+ if err := r.encodeCommands(e, r.Commands, r.Capabilities); err != nil {
+ return err
+ }
+
+ if r.Packfile != nil {
+ if _, err := io.Copy(w, r.Packfile); err != nil {
+ return err
+ }
+
+ return r.Packfile.Close()
+ }
+
+ return nil
+}
+
+func (r *ReferenceUpdateRequest) encodeShallow(e *pktline.Encoder,
+ h *plumbing.Hash) error {
+
+ if h == nil {
+ return nil
+ }
+
+ objId := []byte(h.String())
+ return e.Encodef("%s%s", shallow, objId)
+}
+
+func (r *ReferenceUpdateRequest) encodeCommands(e *pktline.Encoder,
+ cmds []*Command, cap *capability.List) error {
+
+ if err := e.Encodef("%s\x00%s",
+ formatCommand(cmds[0]), cap.String()); err != nil {
+ return err
+ }
+
+ for _, cmd := range cmds[1:] {
+ if err := e.Encodef(formatCommand(cmd)); err != nil {
+ return err
+ }
+ }
+
+ return e.Flush()
+}
+
+func formatCommand(cmd *Command) string {
+ o := cmd.Old.String()
+ n := cmd.New.String()
+ return fmt.Sprintf("%s %s %s", o, n, cmd.Name)
+}
--- /dev/null
+package packp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+)
+
+// UploadPackRequest represents a upload-pack request.
+// Zero-value is not safe, use NewUploadPackRequest instead.
+type UploadPackRequest struct {
+ UploadRequest
+ UploadHaves
+}
+
+// NewUploadPackRequest creates a new UploadPackRequest and returns a pointer.
+func NewUploadPackRequest() *UploadPackRequest {
+ ur := NewUploadRequest()
+ return &UploadPackRequest{
+ UploadHaves: UploadHaves{},
+ UploadRequest: *ur,
+ }
+}
+
+// NewUploadPackRequestFromCapabilities creates a new UploadPackRequest and
+// returns a pointer. The request capabilities are filled with the most optiomal
+// ones, based on the adv value (advertaised capabilities), the UploadPackRequest
+// it has no wants, haves or shallows and an infinite depth
+func NewUploadPackRequestFromCapabilities(adv *capability.List) *UploadPackRequest {
+ ur := NewUploadRequestFromCapabilities(adv)
+ return &UploadPackRequest{
+ UploadHaves: UploadHaves{},
+ UploadRequest: *ur,
+ }
+}
+
+// IsEmpty a request if empty if Haves are contained in the Wants, or if Wants
+// length is zero
+func (r *UploadPackRequest) IsEmpty() bool {
+ return isSubset(r.Wants, r.Haves)
+}
+
+func isSubset(needle []plumbing.Hash, haystack []plumbing.Hash) bool {
+ for _, h := range needle {
+ found := false
+ for _, oh := range haystack {
+ if h == oh {
+ found = true
+ break
+ }
+ }
+
+ if !found {
+ return false
+ }
+ }
+
+ return true
+}
+
+// UploadHaves is a message to signal the references that a client has in a
+// upload-pack. Do not use this directly. Use UploadPackRequest request instead.
+type UploadHaves struct {
+ Haves []plumbing.Hash
+}
+
+// Encode encodes the UploadHaves into the Writer. If flush is true, a flush
+// command will be encoded at the end of the writer content.
+func (u *UploadHaves) Encode(w io.Writer, flush bool) error {
+ e := pktline.NewEncoder(w)
+
+ plumbing.HashesSort(u.Haves)
+
+ var last plumbing.Hash
+ for _, have := range u.Haves {
+ if bytes.Equal(last[:], have[:]) {
+ continue
+ }
+
+ if err := e.Encodef("have %s\n", have); err != nil {
+ return fmt.Errorf("sending haves for %q: %s", have, err)
+ }
+
+ last = have
+ }
+
+ if flush && len(u.Haves) != 0 {
+ if err := e.Flush(); err != nil {
+ return fmt.Errorf("sending flush-pkt after haves: %s", err)
+ }
+ }
+
+ return nil
+}
--- /dev/null
+package packp
+
+import (
+ "errors"
+ "io"
+
+ "bufio"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// ErrUploadPackResponseNotDecoded is returned if Read is called without
+// decoding first
+var ErrUploadPackResponseNotDecoded = errors.New("upload-pack-response should be decoded")
+
+// UploadPackResponse contains all the information responded by the upload-pack
+// service, the response implements io.ReadCloser that allows to read the
+// packfile directly from it.
+type UploadPackResponse struct {
+ ShallowUpdate
+ ServerResponse
+
+ r io.ReadCloser
+ isShallow bool
+ isMultiACK bool
+ isOk bool
+}
+
+// NewUploadPackResponse create a new UploadPackResponse instance, the request
+// being responded by the response is required.
+func NewUploadPackResponse(req *UploadPackRequest) *UploadPackResponse {
+ isShallow := !req.Depth.IsZero()
+ isMultiACK := req.Capabilities.Supports(capability.MultiACK) ||
+ req.Capabilities.Supports(capability.MultiACKDetailed)
+
+ return &UploadPackResponse{
+ isShallow: isShallow,
+ isMultiACK: isMultiACK,
+ }
+}
+
+// NewUploadPackResponseWithPackfile creates a new UploadPackResponse instance,
+// and sets its packfile reader.
+func NewUploadPackResponseWithPackfile(req *UploadPackRequest,
+ pf io.ReadCloser) *UploadPackResponse {
+
+ r := NewUploadPackResponse(req)
+ r.r = pf
+ return r
+}
+
+// Decode decodes all the responses sent by upload-pack service into the struct
+// and prepares it to read the packfile using the Read method
+func (r *UploadPackResponse) Decode(reader io.ReadCloser) error {
+ buf := bufio.NewReader(reader)
+
+ if r.isShallow {
+ if err := r.ShallowUpdate.Decode(buf); err != nil {
+ return err
+ }
+ }
+
+ if err := r.ServerResponse.Decode(buf, r.isMultiACK); err != nil {
+ return err
+ }
+
+ // now the reader is ready to read the packfile content
+ r.r = ioutil.NewReadCloser(buf, reader)
+
+ return nil
+}
+
+// Encode encodes an UploadPackResponse.
+func (r *UploadPackResponse) Encode(w io.Writer) (err error) {
+ if r.isShallow {
+ if err := r.ShallowUpdate.Encode(w); err != nil {
+ return err
+ }
+ }
+
+ if err := r.ServerResponse.Encode(w); err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(r.r, &err)
+ _, err = io.Copy(w, r.r)
+ return err
+}
+
+// Read reads the packfile data, if the request was done with any Sideband
+// capability the content read should be demultiplexed. If the methods wasn't
+// called before the ErrUploadPackResponseNotDecoded will be return
+func (r *UploadPackResponse) Read(p []byte) (int, error) {
+ if r.r == nil {
+ return 0, ErrUploadPackResponseNotDecoded
+ }
+
+ return r.r.Read(p)
+}
+
+// Close the underlying reader, if any
+func (r *UploadPackResponse) Close() error {
+ if r.r == nil {
+ return nil
+ }
+
+ return r.r.Close()
+}
--- /dev/null
+package plumbing
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+)
+
+const (
+ refPrefix = "refs/"
+ refHeadPrefix = refPrefix + "heads/"
+ refTagPrefix = refPrefix + "tags/"
+ refRemotePrefix = refPrefix + "remotes/"
+ refNotePrefix = refPrefix + "notes/"
+ symrefPrefix = "ref: "
+)
+
+// RefRevParseRules are a set of rules to parse references into short names.
+// These are the same rules as used by git in shorten_unambiguous_ref.
+// See: https://github.com/git/git/blob/e0aaa1b6532cfce93d87af9bc813fb2e7a7ce9d7/refs.c#L417
+var RefRevParseRules = []string{
+ "refs/%s",
+ "refs/tags/%s",
+ "refs/heads/%s",
+ "refs/remotes/%s",
+ "refs/remotes/%s/HEAD",
+}
+
+var (
+ ErrReferenceNotFound = errors.New("reference not found")
+)
+
+// ReferenceType reference type's
+type ReferenceType int8
+
+const (
+ InvalidReference ReferenceType = 0
+ HashReference ReferenceType = 1
+ SymbolicReference ReferenceType = 2
+)
+
+func (r ReferenceType) String() string {
+ switch r {
+ case InvalidReference:
+ return "invalid-reference"
+ case HashReference:
+ return "hash-reference"
+ case SymbolicReference:
+ return "symbolic-reference"
+ }
+
+ return ""
+}
+
+// ReferenceName reference name's
+type ReferenceName string
+
+// NewBranchReferenceName returns a reference name describing a branch based on
+// his short name.
+func NewBranchReferenceName(name string) ReferenceName {
+ return ReferenceName(refHeadPrefix + name)
+}
+
+// NewNoteReferenceName returns a reference name describing a note based on his
+// short name.
+func NewNoteReferenceName(name string) ReferenceName {
+ return ReferenceName(refNotePrefix + name)
+}
+
+// NewRemoteReferenceName returns a reference name describing a remote branch
+// based on his short name and the remote name.
+func NewRemoteReferenceName(remote, name string) ReferenceName {
+ return ReferenceName(refRemotePrefix + fmt.Sprintf("%s/%s", remote, name))
+}
+
+// NewRemoteHEADReferenceName returns a reference name describing a the HEAD
+// branch of a remote.
+func NewRemoteHEADReferenceName(remote string) ReferenceName {
+ return ReferenceName(refRemotePrefix + fmt.Sprintf("%s/%s", remote, HEAD))
+}
+
+// NewTagReferenceName returns a reference name describing a tag based on short
+// his name.
+func NewTagReferenceName(name string) ReferenceName {
+ return ReferenceName(refTagPrefix + name)
+}
+
+// IsBranch check if a reference is a branch
+func (r ReferenceName) IsBranch() bool {
+ return strings.HasPrefix(string(r), refHeadPrefix)
+}
+
+// IsNote check if a reference is a note
+func (r ReferenceName) IsNote() bool {
+ return strings.HasPrefix(string(r), refNotePrefix)
+}
+
+// IsRemote check if a reference is a remote
+func (r ReferenceName) IsRemote() bool {
+ return strings.HasPrefix(string(r), refRemotePrefix)
+}
+
+// IsTag check if a reference is a tag
+func (r ReferenceName) IsTag() bool {
+ return strings.HasPrefix(string(r), refTagPrefix)
+}
+
+func (r ReferenceName) String() string {
+ return string(r)
+}
+
+// Short returns the short name of a ReferenceName
+func (r ReferenceName) Short() string {
+ s := string(r)
+ res := s
+ for _, format := range RefRevParseRules {
+ _, err := fmt.Sscanf(s, format, &res)
+ if err == nil {
+ continue
+ }
+ }
+
+ return res
+}
+
+const (
+ HEAD ReferenceName = "HEAD"
+ Master ReferenceName = "refs/heads/master"
+)
+
+// Reference is a representation of git reference
+type Reference struct {
+ t ReferenceType
+ n ReferenceName
+ h Hash
+ target ReferenceName
+}
+
+// NewReferenceFromStrings creates a reference from name and target as string,
+// the resulting reference can be a SymbolicReference or a HashReference base
+// on the target provided
+func NewReferenceFromStrings(name, target string) *Reference {
+ n := ReferenceName(name)
+
+ if strings.HasPrefix(target, symrefPrefix) {
+ target := ReferenceName(target[len(symrefPrefix):])
+ return NewSymbolicReference(n, target)
+ }
+
+ return NewHashReference(n, NewHash(target))
+}
+
+// NewSymbolicReference creates a new SymbolicReference reference
+func NewSymbolicReference(n, target ReferenceName) *Reference {
+ return &Reference{
+ t: SymbolicReference,
+ n: n,
+ target: target,
+ }
+}
+
+// NewHashReference creates a new HashReference reference
+func NewHashReference(n ReferenceName, h Hash) *Reference {
+ return &Reference{
+ t: HashReference,
+ n: n,
+ h: h,
+ }
+}
+
+// Type return the type of a reference
+func (r *Reference) Type() ReferenceType {
+ return r.t
+}
+
+// Name return the name of a reference
+func (r *Reference) Name() ReferenceName {
+ return r.n
+}
+
+// Hash return the hash of a hash reference
+func (r *Reference) Hash() Hash {
+ return r.h
+}
+
+// Target return the target of a symbolic reference
+func (r *Reference) Target() ReferenceName {
+ return r.target
+}
+
+// Strings dump a reference as a [2]string
+func (r *Reference) Strings() [2]string {
+ var o [2]string
+ o[0] = r.Name().String()
+
+ switch r.Type() {
+ case HashReference:
+ o[1] = r.Hash().String()
+ case SymbolicReference:
+ o[1] = symrefPrefix + r.Target().String()
+ }
+
+ return o
+}
+
+func (r *Reference) String() string {
+ s := r.Strings()
+ return fmt.Sprintf("%s %s", s[1], s[0])
+}
--- /dev/null
+package plumbing
+
+// Revision represents a git revision
+// to get more details about git revisions
+// please check git manual page :
+// https://www.kernel.org/pub/software/scm/git/docs/gitrevisions.html
+type Revision string
+
+func (r Revision) String() string {
+ return string(r)
+}
--- /dev/null
+// Package revlist provides support to access the ancestors of commits, in a
+// similar way as the git-rev-list command.
+package revlist
+
+import (
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+// Objects applies a complementary set. It gets all the hashes from all
+// the reachable objects from the given objects. Ignore param are object hashes
+// that we want to ignore on the result. All that objects must be accessible
+// from the object storer.
+func Objects(
+ s storer.EncodedObjectStorer,
+ objs,
+ ignore []plumbing.Hash,
+) ([]plumbing.Hash, error) {
+ ignore, err := objects(s, ignore, nil, true)
+ if err != nil {
+ return nil, err
+ }
+
+ return objects(s, objs, ignore, false)
+}
+
+func objects(
+ s storer.EncodedObjectStorer,
+ objects,
+ ignore []plumbing.Hash,
+ allowMissingObjects bool,
+) ([]plumbing.Hash, error) {
+ seen := hashListToSet(ignore)
+ result := make(map[plumbing.Hash]bool)
+ visited := make(map[plumbing.Hash]bool)
+
+ walkerFunc := func(h plumbing.Hash) {
+ if !seen[h] {
+ result[h] = true
+ seen[h] = true
+ }
+ }
+
+ for _, h := range objects {
+ if err := processObject(s, h, seen, visited, ignore, walkerFunc); err != nil {
+ if allowMissingObjects && err == plumbing.ErrObjectNotFound {
+ continue
+ }
+
+ return nil, err
+ }
+ }
+
+ return hashSetToList(result), nil
+}
+
+// processObject obtains the object using the hash an process it depending of its type
+func processObject(
+ s storer.EncodedObjectStorer,
+ h plumbing.Hash,
+ seen map[plumbing.Hash]bool,
+ visited map[plumbing.Hash]bool,
+ ignore []plumbing.Hash,
+ walkerFunc func(h plumbing.Hash),
+) error {
+ if seen[h] {
+ return nil
+ }
+
+ o, err := s.EncodedObject(plumbing.AnyObject, h)
+ if err != nil {
+ return err
+ }
+
+ do, err := object.DecodeObject(s, o)
+ if err != nil {
+ return err
+ }
+
+ switch do := do.(type) {
+ case *object.Commit:
+ return reachableObjects(do, seen, visited, ignore, walkerFunc)
+ case *object.Tree:
+ return iterateCommitTrees(seen, do, walkerFunc)
+ case *object.Tag:
+ walkerFunc(do.Hash)
+ return processObject(s, do.Target, seen, visited, ignore, walkerFunc)
+ case *object.Blob:
+ walkerFunc(do.Hash)
+ default:
+ return fmt.Errorf("object type not valid: %s. "+
+ "Object reference: %s", o.Type(), o.Hash())
+ }
+
+ return nil
+}
+
+// reachableObjects returns, using the callback function, all the reachable
+// objects from the specified commit. To avoid to iterate over seen commits,
+// if a commit hash is into the 'seen' set, we will not iterate all his trees
+// and blobs objects.
+func reachableObjects(
+ commit *object.Commit,
+ seen map[plumbing.Hash]bool,
+ visited map[plumbing.Hash]bool,
+ ignore []plumbing.Hash,
+ cb func(h plumbing.Hash),
+) error {
+ i := object.NewCommitPreorderIter(commit, seen, ignore)
+ pending := make(map[plumbing.Hash]bool)
+ addPendingParents(pending, visited, commit)
+
+ for {
+ commit, err := i.Next()
+ if err == io.EOF {
+ break
+ }
+
+ if err != nil {
+ return err
+ }
+
+ if pending[commit.Hash] {
+ delete(pending, commit.Hash)
+ }
+
+ addPendingParents(pending, visited, commit)
+
+ if visited[commit.Hash] && len(pending) == 0 {
+ break
+ }
+
+ if seen[commit.Hash] {
+ continue
+ }
+
+ cb(commit.Hash)
+
+ tree, err := commit.Tree()
+ if err != nil {
+ return err
+ }
+
+ if err := iterateCommitTrees(seen, tree, cb); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func addPendingParents(pending, visited map[plumbing.Hash]bool, commit *object.Commit) {
+ for _, p := range commit.ParentHashes {
+ if !visited[p] {
+ pending[p] = true
+ }
+ }
+}
+
+// iterateCommitTrees iterate all reachable trees from the given commit
+func iterateCommitTrees(
+ seen map[plumbing.Hash]bool,
+ tree *object.Tree,
+ cb func(h plumbing.Hash),
+) error {
+ if seen[tree.Hash] {
+ return nil
+ }
+
+ cb(tree.Hash)
+
+ treeWalker := object.NewTreeWalker(tree, true, seen)
+
+ for {
+ _, e, err := treeWalker.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+
+ if e.Mode == filemode.Submodule {
+ continue
+ }
+
+ if seen[e.Hash] {
+ continue
+ }
+
+ cb(e.Hash)
+ }
+
+ return nil
+}
+
+func hashSetToList(hashes map[plumbing.Hash]bool) []plumbing.Hash {
+ var result []plumbing.Hash
+ for key := range hashes {
+ result = append(result, key)
+ }
+
+ return result
+}
+
+func hashListToSet(hashes []plumbing.Hash) map[plumbing.Hash]bool {
+ result := make(map[plumbing.Hash]bool)
+ for _, h := range hashes {
+ result[h] = true
+ }
+
+ return result
+}
--- /dev/null
+// Package storer defines the interfaces to store objects, references, etc.
+package storer
--- /dev/null
+package storer
+
+import "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+
+// IndexStorer generic storage of index.Index
+type IndexStorer interface {
+ SetIndex(*index.Index) error
+ Index() (*index.Index, error)
+}
--- /dev/null
+package storer
+
+import (
+ "errors"
+ "io"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+var (
+ //ErrStop is used to stop a ForEach function in an Iter
+ ErrStop = errors.New("stop iter")
+)
+
+// EncodedObjectStorer generic storage of objects
+type EncodedObjectStorer interface {
+ // NewEncodedObject returns a new plumbing.EncodedObject, the real type
+ // of the object can be a custom implementation or the default one,
+ // plumbing.MemoryObject.
+ NewEncodedObject() plumbing.EncodedObject
+ // SetEncodedObject saves an object into the storage, the object should
+ // be create with the NewEncodedObject, method, and file if the type is
+ // not supported.
+ SetEncodedObject(plumbing.EncodedObject) (plumbing.Hash, error)
+ // EncodedObject gets an object by hash with the given
+ // plumbing.ObjectType. Implementors should return
+ // (nil, plumbing.ErrObjectNotFound) if an object doesn't exist with
+ // both the given hash and object type.
+ //
+ // Valid plumbing.ObjectType values are CommitObject, BlobObject, TagObject,
+ // TreeObject and AnyObject. If plumbing.AnyObject is given, the object must
+ // be looked up regardless of its type.
+ EncodedObject(plumbing.ObjectType, plumbing.Hash) (plumbing.EncodedObject, error)
+ // IterObjects returns a custom EncodedObjectStorer over all the object
+ // on the storage.
+ //
+ // Valid plumbing.ObjectType values are CommitObject, BlobObject, TagObject,
+ IterEncodedObjects(plumbing.ObjectType) (EncodedObjectIter, error)
+ // HasEncodedObject returns ErrObjNotFound if the object doesn't
+ // exist. If the object does exist, it returns nil.
+ HasEncodedObject(plumbing.Hash) error
+ // EncodedObjectSize returns the plaintext size of the encoded object.
+ EncodedObjectSize(plumbing.Hash) (int64, error)
+}
+
+// DeltaObjectStorer is an EncodedObjectStorer that can return delta
+// objects.
+type DeltaObjectStorer interface {
+ // DeltaObject is the same as EncodedObject but without resolving deltas.
+ // Deltas will be returned as plumbing.DeltaObject instances.
+ DeltaObject(plumbing.ObjectType, plumbing.Hash) (plumbing.EncodedObject, error)
+}
+
+// Transactioner is a optional method for ObjectStorer, it enable transaction
+// base write and read operations in the storage
+type Transactioner interface {
+ // Begin starts a transaction.
+ Begin() Transaction
+}
+
+// LooseObjectStorer is an optional interface for managing "loose"
+// objects, i.e. those not in packfiles.
+type LooseObjectStorer interface {
+ // ForEachObjectHash iterates over all the (loose) object hashes
+ // in the repository without necessarily having to read those objects.
+ // Objects only inside pack files may be omitted.
+ // If ErrStop is sent the iteration is stop but no error is returned.
+ ForEachObjectHash(func(plumbing.Hash) error) error
+ // LooseObjectTime looks up the (m)time associated with the
+ // loose object (that is not in a pack file). Some
+ // implementations (e.g. without loose objects)
+ // always return an error.
+ LooseObjectTime(plumbing.Hash) (time.Time, error)
+ // DeleteLooseObject deletes a loose object if it exists.
+ DeleteLooseObject(plumbing.Hash) error
+}
+
+// PackedObjectStorer is an optional interface for managing objects in
+// packfiles.
+type PackedObjectStorer interface {
+ // ObjectPacks returns hashes of object packs if the underlying
+ // implementation has pack files.
+ ObjectPacks() ([]plumbing.Hash, error)
+ // DeleteOldObjectPackAndIndex deletes an object pack and the corresponding index file if they exist.
+ // Deletion is only performed if the pack is older than the supplied time (or the time is zero).
+ DeleteOldObjectPackAndIndex(plumbing.Hash, time.Time) error
+}
+
+// PackfileWriter is a optional method for ObjectStorer, it enable direct write
+// of packfile to the storage
+type PackfileWriter interface {
+ // PackfileWriter returns a writer for writing a packfile to the storage
+ //
+ // If the Storer not implements PackfileWriter the objects should be written
+ // using the Set method.
+ PackfileWriter() (io.WriteCloser, error)
+}
+
+// EncodedObjectIter is a generic closable interface for iterating over objects.
+type EncodedObjectIter interface {
+ Next() (plumbing.EncodedObject, error)
+ ForEach(func(plumbing.EncodedObject) error) error
+ Close()
+}
+
+// Transaction is an in-progress storage transaction. A transaction must end
+// with a call to Commit or Rollback.
+type Transaction interface {
+ SetEncodedObject(plumbing.EncodedObject) (plumbing.Hash, error)
+ EncodedObject(plumbing.ObjectType, plumbing.Hash) (plumbing.EncodedObject, error)
+ Commit() error
+ Rollback() error
+}
+
+// EncodedObjectLookupIter implements EncodedObjectIter. It iterates over a
+// series of object hashes and yields their associated objects by retrieving
+// each one from object storage. The retrievals are lazy and only occur when the
+// iterator moves forward with a call to Next().
+//
+// The EncodedObjectLookupIter must be closed with a call to Close() when it is
+// no longer needed.
+type EncodedObjectLookupIter struct {
+ storage EncodedObjectStorer
+ series []plumbing.Hash
+ t plumbing.ObjectType
+ pos int
+}
+
+// NewEncodedObjectLookupIter returns an object iterator given an object storage
+// and a slice of object hashes.
+func NewEncodedObjectLookupIter(
+ storage EncodedObjectStorer, t plumbing.ObjectType, series []plumbing.Hash) *EncodedObjectLookupIter {
+ return &EncodedObjectLookupIter{
+ storage: storage,
+ series: series,
+ t: t,
+ }
+}
+
+// Next returns the next object from the iterator. If the iterator has reached
+// the end it will return io.EOF as an error. If the object can't be found in
+// the object storage, it will return plumbing.ErrObjectNotFound as an error.
+// If the object is retreieved successfully error will be nil.
+func (iter *EncodedObjectLookupIter) Next() (plumbing.EncodedObject, error) {
+ if iter.pos >= len(iter.series) {
+ return nil, io.EOF
+ }
+
+ hash := iter.series[iter.pos]
+ obj, err := iter.storage.EncodedObject(iter.t, hash)
+ if err == nil {
+ iter.pos++
+ }
+
+ return obj, err
+}
+
+// ForEach call the cb function for each object contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *EncodedObjectLookupIter) ForEach(cb func(plumbing.EncodedObject) error) error {
+ return ForEachIterator(iter, cb)
+}
+
+// Close releases any resources used by the iterator.
+func (iter *EncodedObjectLookupIter) Close() {
+ iter.pos = len(iter.series)
+}
+
+// EncodedObjectSliceIter implements EncodedObjectIter. It iterates over a
+// series of objects stored in a slice and yields each one in turn when Next()
+// is called.
+//
+// The EncodedObjectSliceIter must be closed with a call to Close() when it is
+// no longer needed.
+type EncodedObjectSliceIter struct {
+ series []plumbing.EncodedObject
+}
+
+// NewEncodedObjectSliceIter returns an object iterator for the given slice of
+// objects.
+func NewEncodedObjectSliceIter(series []plumbing.EncodedObject) *EncodedObjectSliceIter {
+ return &EncodedObjectSliceIter{
+ series: series,
+ }
+}
+
+// Next returns the next object from the iterator. If the iterator has reached
+// the end it will return io.EOF as an error. If the object is retreieved
+// successfully error will be nil.
+func (iter *EncodedObjectSliceIter) Next() (plumbing.EncodedObject, error) {
+ if len(iter.series) == 0 {
+ return nil, io.EOF
+ }
+
+ obj := iter.series[0]
+ iter.series = iter.series[1:]
+
+ return obj, nil
+}
+
+// ForEach call the cb function for each object contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *EncodedObjectSliceIter) ForEach(cb func(plumbing.EncodedObject) error) error {
+ return ForEachIterator(iter, cb)
+}
+
+// Close releases any resources used by the iterator.
+func (iter *EncodedObjectSliceIter) Close() {
+ iter.series = []plumbing.EncodedObject{}
+}
+
+// MultiEncodedObjectIter implements EncodedObjectIter. It iterates over several
+// EncodedObjectIter,
+//
+// The MultiObjectIter must be closed with a call to Close() when it is no
+// longer needed.
+type MultiEncodedObjectIter struct {
+ iters []EncodedObjectIter
+}
+
+// NewMultiEncodedObjectIter returns an object iterator for the given slice of
+// objects.
+func NewMultiEncodedObjectIter(iters []EncodedObjectIter) EncodedObjectIter {
+ return &MultiEncodedObjectIter{iters: iters}
+}
+
+// Next returns the next object from the iterator, if one iterator reach io.EOF
+// is removed and the next one is used.
+func (iter *MultiEncodedObjectIter) Next() (plumbing.EncodedObject, error) {
+ if len(iter.iters) == 0 {
+ return nil, io.EOF
+ }
+
+ obj, err := iter.iters[0].Next()
+ if err == io.EOF {
+ iter.iters[0].Close()
+ iter.iters = iter.iters[1:]
+ return iter.Next()
+ }
+
+ return obj, err
+}
+
+// ForEach call the cb function for each object contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *MultiEncodedObjectIter) ForEach(cb func(plumbing.EncodedObject) error) error {
+ return ForEachIterator(iter, cb)
+}
+
+// Close releases any resources used by the iterator.
+func (iter *MultiEncodedObjectIter) Close() {
+ for _, i := range iter.iters {
+ i.Close()
+ }
+}
+
+type bareIterator interface {
+ Next() (plumbing.EncodedObject, error)
+ Close()
+}
+
+// ForEachIterator is a helper function to build iterators without need to
+// rewrite the same ForEach function each time.
+func ForEachIterator(iter bareIterator, cb func(plumbing.EncodedObject) error) error {
+ defer iter.Close()
+ for {
+ obj, err := iter.Next()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+
+ return err
+ }
+
+ if err := cb(obj); err != nil {
+ if err == ErrStop {
+ return nil
+ }
+
+ return err
+ }
+ }
+}
--- /dev/null
+package storer
+
+import (
+ "errors"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+const MaxResolveRecursion = 1024
+
+// ErrMaxResolveRecursion is returned by ResolveReference is MaxResolveRecursion
+// is exceeded
+var ErrMaxResolveRecursion = errors.New("max. recursion level reached")
+
+// ReferenceStorer is a generic storage of references.
+type ReferenceStorer interface {
+ SetReference(*plumbing.Reference) error
+ // CheckAndSetReference sets the reference `new`, but if `old` is
+ // not `nil`, it first checks that the current stored value for
+ // `old.Name()` matches the given reference value in `old`. If
+ // not, it returns an error and doesn't update `new`.
+ CheckAndSetReference(new, old *plumbing.Reference) error
+ Reference(plumbing.ReferenceName) (*plumbing.Reference, error)
+ IterReferences() (ReferenceIter, error)
+ RemoveReference(plumbing.ReferenceName) error
+ CountLooseRefs() (int, error)
+ PackRefs() error
+}
+
+// ReferenceIter is a generic closable interface for iterating over references.
+type ReferenceIter interface {
+ Next() (*plumbing.Reference, error)
+ ForEach(func(*plumbing.Reference) error) error
+ Close()
+}
+
+type referenceFilteredIter struct {
+ ff func(r *plumbing.Reference) bool
+ iter ReferenceIter
+}
+
+// NewReferenceFilteredIter returns a reference iterator for the given reference
+// Iterator. This iterator will iterate only references that accomplish the
+// provided function.
+func NewReferenceFilteredIter(
+ ff func(r *plumbing.Reference) bool, iter ReferenceIter) ReferenceIter {
+ return &referenceFilteredIter{ff, iter}
+}
+
+// Next returns the next reference from the iterator. If the iterator has reached
+// the end it will return io.EOF as an error.
+func (iter *referenceFilteredIter) Next() (*plumbing.Reference, error) {
+ for {
+ r, err := iter.iter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ if iter.ff(r) {
+ return r, nil
+ }
+
+ continue
+ }
+}
+
+// ForEach call the cb function for each reference contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stopped but no error is returned. The iterator is closed.
+func (iter *referenceFilteredIter) ForEach(cb func(*plumbing.Reference) error) error {
+ defer iter.Close()
+ for {
+ r, err := iter.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+
+ if err := cb(r); err != nil {
+ if err == ErrStop {
+ break
+ }
+
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Close releases any resources used by the iterator.
+func (iter *referenceFilteredIter) Close() {
+ iter.iter.Close()
+}
+
+// ReferenceSliceIter implements ReferenceIter. It iterates over a series of
+// references stored in a slice and yields each one in turn when Next() is
+// called.
+//
+// The ReferenceSliceIter must be closed with a call to Close() when it is no
+// longer needed.
+type ReferenceSliceIter struct {
+ series []*plumbing.Reference
+ pos int
+}
+
+// NewReferenceSliceIter returns a reference iterator for the given slice of
+// objects.
+func NewReferenceSliceIter(series []*plumbing.Reference) ReferenceIter {
+ return &ReferenceSliceIter{
+ series: series,
+ }
+}
+
+// Next returns the next reference from the iterator. If the iterator has
+// reached the end it will return io.EOF as an error.
+func (iter *ReferenceSliceIter) Next() (*plumbing.Reference, error) {
+ if iter.pos >= len(iter.series) {
+ return nil, io.EOF
+ }
+
+ obj := iter.series[iter.pos]
+ iter.pos++
+ return obj, nil
+}
+
+// ForEach call the cb function for each reference contained on this iter until
+// an error happens or the end of the iter is reached. If ErrStop is sent
+// the iteration is stop but no error is returned. The iterator is closed.
+func (iter *ReferenceSliceIter) ForEach(cb func(*plumbing.Reference) error) error {
+ defer iter.Close()
+ for _, r := range iter.series {
+ if err := cb(r); err != nil {
+ if err == ErrStop {
+ return nil
+ }
+
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Close releases any resources used by the iterator.
+func (iter *ReferenceSliceIter) Close() {
+ iter.pos = len(iter.series)
+}
+
+// ResolveReference resolves a SymbolicReference to a HashReference.
+func ResolveReference(s ReferenceStorer, n plumbing.ReferenceName) (*plumbing.Reference, error) {
+ r, err := s.Reference(n)
+ if err != nil || r == nil {
+ return r, err
+ }
+ return resolveReference(s, r, 0)
+}
+
+func resolveReference(s ReferenceStorer, r *plumbing.Reference, recursion int) (*plumbing.Reference, error) {
+ if r.Type() != plumbing.SymbolicReference {
+ return r, nil
+ }
+
+ if recursion > MaxResolveRecursion {
+ return nil, ErrMaxResolveRecursion
+ }
+
+ t, err := s.Reference(r.Target())
+ if err != nil {
+ return nil, err
+ }
+
+ recursion++
+ return resolveReference(s, t, recursion)
+}
--- /dev/null
+package storer
+
+import "gopkg.in/src-d/go-git.v4/plumbing"
+
+// ShallowStorer is a storage of references to shallow commits by hash,
+// meaning that these commits have missing parents because of a shallow fetch.
+type ShallowStorer interface {
+ SetShallow([]plumbing.Hash) error
+ Shallow() ([]plumbing.Hash, error)
+}
--- /dev/null
+package storer
+
+// Storer is a basic storer for encoded objects and references.
+type Storer interface {
+ EncodedObjectStorer
+ ReferenceStorer
+}
+
+// Initializer should be implemented by storers that require to perform any
+// operation when creating a new repository (i.e. git init).
+type Initializer interface {
+ // Init performs initialization of the storer and returns the error, if
+ // any.
+ Init() error
+}
--- /dev/null
+// Package client contains helper function to deal with the different client
+// protocols.
+package client
+
+import (
+ "fmt"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/file"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/git"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/http"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/ssh"
+)
+
+// Protocols are the protocols supported by default.
+var Protocols = map[string]transport.Transport{
+ "http": http.DefaultClient,
+ "https": http.DefaultClient,
+ "ssh": ssh.DefaultClient,
+ "git": git.DefaultClient,
+ "file": file.DefaultClient,
+}
+
+// InstallProtocol adds or modifies an existing protocol.
+func InstallProtocol(scheme string, c transport.Transport) {
+ if c == nil {
+ delete(Protocols, scheme)
+ return
+ }
+
+ Protocols[scheme] = c
+}
+
+// NewClient returns the appropriate client among of the set of known protocols:
+// http://, https://, ssh:// and file://.
+// See `InstallProtocol` to add or modify protocols.
+func NewClient(endpoint *transport.Endpoint) (transport.Transport, error) {
+ f, ok := Protocols[endpoint.Protocol]
+ if !ok {
+ return nil, fmt.Errorf("unsupported scheme %q", endpoint.Protocol)
+ }
+
+ if f == nil {
+ return nil, fmt.Errorf("malformed client for scheme %q, client is defined as nil", endpoint.Protocol)
+ }
+
+ return f, nil
+}
--- /dev/null
+// Package transport includes the implementation for different transport
+// protocols.
+//
+// `Client` can be used to fetch and send packfiles to a git server.
+// The `client` package provides higher level functions to instantiate the
+// appropriate `Client` based on the repository URL.
+//
+// go-git supports HTTP and SSH (see `Protocols`), but you can also install
+// your own protocols (see the `client` package).
+//
+// Each protocol has its own implementation of `Client`, but you should
+// generally not use them directly, use `client.NewClient` instead.
+package transport
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/url"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+)
+
+var (
+ ErrRepositoryNotFound = errors.New("repository not found")
+ ErrEmptyRemoteRepository = errors.New("remote repository is empty")
+ ErrAuthenticationRequired = errors.New("authentication required")
+ ErrAuthorizationFailed = errors.New("authorization failed")
+ ErrEmptyUploadPackRequest = errors.New("empty git-upload-pack given")
+ ErrInvalidAuthMethod = errors.New("invalid auth method")
+ ErrAlreadyConnected = errors.New("session already established")
+)
+
+const (
+ UploadPackServiceName = "git-upload-pack"
+ ReceivePackServiceName = "git-receive-pack"
+)
+
+// Transport can initiate git-upload-pack and git-receive-pack processes.
+// It is implemented both by the client and the server, making this a RPC.
+type Transport interface {
+ // NewUploadPackSession starts a git-upload-pack session for an endpoint.
+ NewUploadPackSession(*Endpoint, AuthMethod) (UploadPackSession, error)
+ // NewReceivePackSession starts a git-receive-pack session for an endpoint.
+ NewReceivePackSession(*Endpoint, AuthMethod) (ReceivePackSession, error)
+}
+
+type Session interface {
+ // AdvertisedReferences retrieves the advertised references for a
+ // repository.
+ // If the repository does not exist, returns ErrRepositoryNotFound.
+ // If the repository exists, but is empty, returns ErrEmptyRemoteRepository.
+ AdvertisedReferences() (*packp.AdvRefs, error)
+ io.Closer
+}
+
+type AuthMethod interface {
+ fmt.Stringer
+ Name() string
+}
+
+// UploadPackSession represents a git-upload-pack session.
+// A git-upload-pack session has two steps: reference discovery
+// (AdvertisedReferences) and uploading pack (UploadPack).
+type UploadPackSession interface {
+ Session
+ // UploadPack takes a git-upload-pack request and returns a response,
+ // including a packfile. Don't be confused by terminology, the client
+ // side of a git-upload-pack is called git-fetch-pack, although here
+ // the same interface is used to make it RPC-like.
+ UploadPack(context.Context, *packp.UploadPackRequest) (*packp.UploadPackResponse, error)
+}
+
+// ReceivePackSession represents a git-receive-pack session.
+// A git-receive-pack session has two steps: reference discovery
+// (AdvertisedReferences) and receiving pack (ReceivePack).
+// In that order.
+type ReceivePackSession interface {
+ Session
+ // ReceivePack sends an update references request and a packfile
+ // reader and returns a ReportStatus and error. Don't be confused by
+ // terminology, the client side of a git-receive-pack is called
+ // git-send-pack, although here the same interface is used to make it
+ // RPC-like.
+ ReceivePack(context.Context, *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error)
+}
+
+// Endpoint represents a Git URL in any supported protocol.
+type Endpoint struct {
+ // Protocol is the protocol of the endpoint (e.g. git, https, file).
+ Protocol string
+ // User is the user.
+ User string
+ // Password is the password.
+ Password string
+ // Host is the host.
+ Host string
+ // Port is the port to connect, if 0 the default port for the given protocol
+ // wil be used.
+ Port int
+ // Path is the repository path.
+ Path string
+}
+
+var defaultPorts = map[string]int{
+ "http": 80,
+ "https": 443,
+ "git": 9418,
+ "ssh": 22,
+}
+
+// String returns a string representation of the Git URL.
+func (u *Endpoint) String() string {
+ var buf bytes.Buffer
+ if u.Protocol != "" {
+ buf.WriteString(u.Protocol)
+ buf.WriteByte(':')
+ }
+
+ if u.Protocol != "" || u.Host != "" || u.User != "" || u.Password != "" {
+ buf.WriteString("//")
+
+ if u.User != "" || u.Password != "" {
+ buf.WriteString(url.PathEscape(u.User))
+ if u.Password != "" {
+ buf.WriteByte(':')
+ buf.WriteString(url.PathEscape(u.Password))
+ }
+
+ buf.WriteByte('@')
+ }
+
+ if u.Host != "" {
+ buf.WriteString(u.Host)
+
+ if u.Port != 0 {
+ port, ok := defaultPorts[strings.ToLower(u.Protocol)]
+ if !ok || ok && port != u.Port {
+ fmt.Fprintf(&buf, ":%d", u.Port)
+ }
+ }
+ }
+ }
+
+ if u.Path != "" && u.Path[0] != '/' && u.Host != "" {
+ buf.WriteByte('/')
+ }
+
+ buf.WriteString(u.Path)
+ return buf.String()
+}
+
+func NewEndpoint(endpoint string) (*Endpoint, error) {
+ if e, ok := parseSCPLike(endpoint); ok {
+ return e, nil
+ }
+
+ if e, ok := parseFile(endpoint); ok {
+ return e, nil
+ }
+
+ return parseURL(endpoint)
+}
+
+func parseURL(endpoint string) (*Endpoint, error) {
+ u, err := url.Parse(endpoint)
+ if err != nil {
+ return nil, err
+ }
+
+ if !u.IsAbs() {
+ return nil, plumbing.NewPermanentError(fmt.Errorf(
+ "invalid endpoint: %s", endpoint,
+ ))
+ }
+
+ var user, pass string
+ if u.User != nil {
+ user = u.User.Username()
+ pass, _ = u.User.Password()
+ }
+
+ return &Endpoint{
+ Protocol: u.Scheme,
+ User: user,
+ Password: pass,
+ Host: u.Hostname(),
+ Port: getPort(u),
+ Path: getPath(u),
+ }, nil
+}
+
+func getPort(u *url.URL) int {
+ p := u.Port()
+ if p == "" {
+ return 0
+ }
+
+ i, err := strconv.Atoi(p)
+ if err != nil {
+ return 0
+ }
+
+ return i
+}
+
+func getPath(u *url.URL) string {
+ var res string = u.Path
+ if u.RawQuery != "" {
+ res += "?" + u.RawQuery
+ }
+
+ if u.Fragment != "" {
+ res += "#" + u.Fragment
+ }
+
+ return res
+}
+
+var (
+ isSchemeRegExp = regexp.MustCompile(`^[^:]+://`)
+ scpLikeUrlRegExp = regexp.MustCompile(`^(?:(?P<user>[^@]+)@)?(?P<host>[^:\s]+):(?:(?P<port>[0-9]{1,5})/)?(?P<path>[^\\].*)$`)
+)
+
+func parseSCPLike(endpoint string) (*Endpoint, bool) {
+ if isSchemeRegExp.MatchString(endpoint) || !scpLikeUrlRegExp.MatchString(endpoint) {
+ return nil, false
+ }
+
+ m := scpLikeUrlRegExp.FindStringSubmatch(endpoint)
+
+ port, err := strconv.Atoi(m[3])
+ if err != nil {
+ port = 22
+ }
+
+ return &Endpoint{
+ Protocol: "ssh",
+ User: m[1],
+ Host: m[2],
+ Port: port,
+ Path: m[4],
+ }, true
+}
+
+func parseFile(endpoint string) (*Endpoint, bool) {
+ if isSchemeRegExp.MatchString(endpoint) {
+ return nil, false
+ }
+
+ path := endpoint
+ return &Endpoint{
+ Protocol: "file",
+ Path: path,
+ }, true
+}
+
+// UnsupportedCapabilities are the capabilities not supported by any client
+// implementation
+var UnsupportedCapabilities = []capability.Capability{
+ capability.MultiACK,
+ capability.MultiACKDetailed,
+ capability.ThinPack,
+}
+
+// FilterUnsupportedCapabilities it filter out all the UnsupportedCapabilities
+// from a capability.List, the intended usage is on the client implementation
+// to filter the capabilities from an AdvRefs message.
+func FilterUnsupportedCapabilities(list *capability.List) {
+ for _, c := range UnsupportedCapabilities {
+ list.Delete(c)
+ }
+}
--- /dev/null
+// Package file implements the file transport protocol.
+package file
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
+)
+
+// DefaultClient is the default local client.
+var DefaultClient = NewClient(
+ transport.UploadPackServiceName,
+ transport.ReceivePackServiceName,
+)
+
+type runner struct {
+ UploadPackBin string
+ ReceivePackBin string
+}
+
+// NewClient returns a new local client using the given git-upload-pack and
+// git-receive-pack binaries.
+func NewClient(uploadPackBin, receivePackBin string) transport.Transport {
+ return common.NewClient(&runner{
+ UploadPackBin: uploadPackBin,
+ ReceivePackBin: receivePackBin,
+ })
+}
+
+func prefixExecPath(cmd string) (string, error) {
+ // Use `git --exec-path` to find the exec path.
+ execCmd := exec.Command("git", "--exec-path")
+
+ stdout, err := execCmd.StdoutPipe()
+ if err != nil {
+ return "", err
+ }
+ stdoutBuf := bufio.NewReader(stdout)
+
+ err = execCmd.Start()
+ if err != nil {
+ return "", err
+ }
+
+ execPathBytes, isPrefix, err := stdoutBuf.ReadLine()
+ if err != nil {
+ return "", err
+ }
+ if isPrefix {
+ return "", errors.New("Couldn't read exec-path line all at once")
+ }
+
+ err = execCmd.Wait()
+ if err != nil {
+ return "", err
+ }
+ execPath := string(execPathBytes)
+ execPath = strings.TrimSpace(execPath)
+ cmd = filepath.Join(execPath, cmd)
+
+ // Make sure it actually exists.
+ _, err = exec.LookPath(cmd)
+ if err != nil {
+ return "", err
+ }
+ return cmd, nil
+}
+
+func (r *runner) Command(cmd string, ep *transport.Endpoint, auth transport.AuthMethod,
+) (common.Command, error) {
+
+ switch cmd {
+ case transport.UploadPackServiceName:
+ cmd = r.UploadPackBin
+ case transport.ReceivePackServiceName:
+ cmd = r.ReceivePackBin
+ }
+
+ _, err := exec.LookPath(cmd)
+ if err != nil {
+ if e, ok := err.(*exec.Error); ok && e.Err == exec.ErrNotFound {
+ cmd, err = prefixExecPath(cmd)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ return nil, err
+ }
+ }
+
+ return &command{cmd: exec.Command(cmd, ep.Path)}, nil
+}
+
+type command struct {
+ cmd *exec.Cmd
+ stderrCloser io.Closer
+ closed bool
+}
+
+func (c *command) Start() error {
+ return c.cmd.Start()
+}
+
+func (c *command) StderrPipe() (io.Reader, error) {
+ // Pipe returned by Command.StderrPipe has a race with Read + Command.Wait.
+ // We use an io.Pipe and close it after the command finishes.
+ r, w := io.Pipe()
+ c.cmd.Stderr = w
+ c.stderrCloser = r
+ return r, nil
+}
+
+func (c *command) StdinPipe() (io.WriteCloser, error) {
+ return c.cmd.StdinPipe()
+}
+
+func (c *command) StdoutPipe() (io.Reader, error) {
+ return c.cmd.StdoutPipe()
+}
+
+func (c *command) Kill() error {
+ c.cmd.Process.Kill()
+ return c.Close()
+}
+
+// Close waits for the command to exit.
+func (c *command) Close() error {
+ if c.closed {
+ return nil
+ }
+
+ defer func() {
+ c.closed = true
+ _ = c.stderrCloser.Close()
+
+ }()
+
+ err := c.cmd.Wait()
+ if _, ok := err.(*os.PathError); ok {
+ return nil
+ }
+
+ // When a repository does not exist, the command exits with code 128.
+ if _, ok := err.(*exec.ExitError); ok {
+ return nil
+ }
+
+ return err
+}
--- /dev/null
+package file
+
+import (
+ "fmt"
+ "os"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/server"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// ServeUploadPack serves a git-upload-pack request using standard output, input
+// and error. This is meant to be used when implementing a git-upload-pack
+// command.
+func ServeUploadPack(path string) error {
+ ep, err := transport.NewEndpoint(path)
+ if err != nil {
+ return err
+ }
+
+ // TODO: define and implement a server-side AuthMethod
+ s, err := server.DefaultServer.NewUploadPackSession(ep, nil)
+ if err != nil {
+ return fmt.Errorf("error creating session: %s", err)
+ }
+
+ return common.ServeUploadPack(srvCmd, s)
+}
+
+// ServeReceivePack serves a git-receive-pack request using standard output,
+// input and error. This is meant to be used when implementing a
+// git-receive-pack command.
+func ServeReceivePack(path string) error {
+ ep, err := transport.NewEndpoint(path)
+ if err != nil {
+ return err
+ }
+
+ // TODO: define and implement a server-side AuthMethod
+ s, err := server.DefaultServer.NewReceivePackSession(ep, nil)
+ if err != nil {
+ return fmt.Errorf("error creating session: %s", err)
+ }
+
+ return common.ServeReceivePack(srvCmd, s)
+}
+
+var srvCmd = common.ServerCommand{
+ Stdin: os.Stdin,
+ Stdout: ioutil.WriteNopCloser(os.Stdout),
+ Stderr: os.Stderr,
+}
--- /dev/null
+// Package git implements the git transport protocol.
+package git
+
+import (
+ "fmt"
+ "io"
+ "net"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// DefaultClient is the default git client.
+var DefaultClient = common.NewClient(&runner{})
+
+const DefaultPort = 9418
+
+type runner struct{}
+
+// Command returns a new Command for the given cmd in the given Endpoint
+func (r *runner) Command(cmd string, ep *transport.Endpoint, auth transport.AuthMethod) (common.Command, error) {
+ // auth not allowed since git protocol doesn't support authentication
+ if auth != nil {
+ return nil, transport.ErrInvalidAuthMethod
+ }
+ c := &command{command: cmd, endpoint: ep}
+ if err := c.connect(); err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+type command struct {
+ conn net.Conn
+ connected bool
+ command string
+ endpoint *transport.Endpoint
+}
+
+// Start executes the command sending the required message to the TCP connection
+func (c *command) Start() error {
+ cmd := endpointToCommand(c.command, c.endpoint)
+
+ e := pktline.NewEncoder(c.conn)
+ return e.Encode([]byte(cmd))
+}
+
+func (c *command) connect() error {
+ if c.connected {
+ return transport.ErrAlreadyConnected
+ }
+
+ var err error
+ c.conn, err = net.Dial("tcp", c.getHostWithPort())
+ if err != nil {
+ return err
+ }
+
+ c.connected = true
+ return nil
+}
+
+func (c *command) getHostWithPort() string {
+ host := c.endpoint.Host
+ port := c.endpoint.Port
+ if port <= 0 {
+ port = DefaultPort
+ }
+
+ return fmt.Sprintf("%s:%d", host, port)
+}
+
+// StderrPipe git protocol doesn't have any dedicated error channel
+func (c *command) StderrPipe() (io.Reader, error) {
+ return nil, nil
+}
+
+// StdinPipe return the underlying connection as WriteCloser, wrapped to prevent
+// call to the Close function from the connection, a command execution in git
+// protocol can't be closed or killed
+func (c *command) StdinPipe() (io.WriteCloser, error) {
+ return ioutil.WriteNopCloser(c.conn), nil
+}
+
+// StdoutPipe return the underlying connection as Reader
+func (c *command) StdoutPipe() (io.Reader, error) {
+ return c.conn, nil
+}
+
+func endpointToCommand(cmd string, ep *transport.Endpoint) string {
+ host := ep.Host
+ if ep.Port != DefaultPort {
+ host = fmt.Sprintf("%s:%d", ep.Host, ep.Port)
+ }
+
+ return fmt.Sprintf("%s %s%chost=%s%c", cmd, ep.Path, 0, host, 0)
+}
+
+// Close closes the TCP connection and connection.
+func (c *command) Close() error {
+ if !c.connected {
+ return nil
+ }
+
+ c.connected = false
+ return c.conn.Close()
+}
--- /dev/null
+// Package http implements the HTTP transport protocol.
+package http
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// it requires a bytes.Buffer, because we need to know the length
+func applyHeadersToRequest(req *http.Request, content *bytes.Buffer, host string, requestType string) {
+ req.Header.Add("User-Agent", "git/1.0")
+ req.Header.Add("Host", host) // host:port
+
+ if content == nil {
+ req.Header.Add("Accept", "*/*")
+ return
+ }
+
+ req.Header.Add("Accept", fmt.Sprintf("application/x-%s-result", requestType))
+ req.Header.Add("Content-Type", fmt.Sprintf("application/x-%s-request", requestType))
+ req.Header.Add("Content-Length", strconv.Itoa(content.Len()))
+}
+
+const infoRefsPath = "/info/refs"
+
+func advertisedReferences(s *session, serviceName string) (ref *packp.AdvRefs, err error) {
+ url := fmt.Sprintf(
+ "%s%s?service=%s",
+ s.endpoint.String(), infoRefsPath, serviceName,
+ )
+
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ s.ApplyAuthToRequest(req)
+ applyHeadersToRequest(req, nil, s.endpoint.Host, serviceName)
+ res, err := s.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ s.ModifyEndpointIfRedirect(res)
+ defer ioutil.CheckClose(res.Body, &err)
+
+ if err = NewErr(res); err != nil {
+ return nil, err
+ }
+
+ ar := packp.NewAdvRefs()
+ if err = ar.Decode(res.Body); err != nil {
+ if err == packp.ErrEmptyAdvRefs {
+ err = transport.ErrEmptyRemoteRepository
+ }
+
+ return nil, err
+ }
+
+ transport.FilterUnsupportedCapabilities(ar.Capabilities)
+ s.advRefs = ar
+
+ return ar, nil
+}
+
+type client struct {
+ c *http.Client
+}
+
+// DefaultClient is the default HTTP client, which uses `http.DefaultClient`.
+var DefaultClient = NewClient(nil)
+
+// NewClient creates a new client with a custom net/http client.
+// See `InstallProtocol` to install and override default http client.
+// Unless a properly initialized client is given, it will fall back into
+// `http.DefaultClient`.
+//
+// Note that for HTTP client cannot distinguist between private repositories and
+// unexistent repositories on GitHub. So it returns `ErrAuthorizationRequired`
+// for both.
+func NewClient(c *http.Client) transport.Transport {
+ if c == nil {
+ return &client{http.DefaultClient}
+ }
+
+ return &client{
+ c: c,
+ }
+}
+
+func (c *client) NewUploadPackSession(ep *transport.Endpoint, auth transport.AuthMethod) (
+ transport.UploadPackSession, error) {
+
+ return newUploadPackSession(c.c, ep, auth)
+}
+
+func (c *client) NewReceivePackSession(ep *transport.Endpoint, auth transport.AuthMethod) (
+ transport.ReceivePackSession, error) {
+
+ return newReceivePackSession(c.c, ep, auth)
+}
+
+type session struct {
+ auth AuthMethod
+ client *http.Client
+ endpoint *transport.Endpoint
+ advRefs *packp.AdvRefs
+}
+
+func newSession(c *http.Client, ep *transport.Endpoint, auth transport.AuthMethod) (*session, error) {
+ s := &session{
+ auth: basicAuthFromEndpoint(ep),
+ client: c,
+ endpoint: ep,
+ }
+ if auth != nil {
+ a, ok := auth.(AuthMethod)
+ if !ok {
+ return nil, transport.ErrInvalidAuthMethod
+ }
+
+ s.auth = a
+ }
+
+ return s, nil
+}
+
+func (s *session) ApplyAuthToRequest(req *http.Request) {
+ if s.auth == nil {
+ return
+ }
+
+ s.auth.setAuth(req)
+}
+
+func (s *session) ModifyEndpointIfRedirect(res *http.Response) {
+ if res.Request == nil {
+ return
+ }
+
+ r := res.Request
+ if !strings.HasSuffix(r.URL.Path, infoRefsPath) {
+ return
+ }
+
+ h, p, err := net.SplitHostPort(r.URL.Host)
+ if err != nil {
+ h = r.URL.Host
+ }
+ if p != "" {
+ port, err := strconv.Atoi(p)
+ if err == nil {
+ s.endpoint.Port = port
+ }
+ }
+ s.endpoint.Host = h
+
+ s.endpoint.Protocol = r.URL.Scheme
+ s.endpoint.Path = r.URL.Path[:len(r.URL.Path)-len(infoRefsPath)]
+}
+
+func (*session) Close() error {
+ return nil
+}
+
+// AuthMethod is concrete implementation of common.AuthMethod for HTTP services
+type AuthMethod interface {
+ transport.AuthMethod
+ setAuth(r *http.Request)
+}
+
+func basicAuthFromEndpoint(ep *transport.Endpoint) *BasicAuth {
+ u := ep.User
+ if u == "" {
+ return nil
+ }
+
+ return &BasicAuth{u, ep.Password}
+}
+
+// BasicAuth represent a HTTP basic auth
+type BasicAuth struct {
+ Username, Password string
+}
+
+func (a *BasicAuth) setAuth(r *http.Request) {
+ if a == nil {
+ return
+ }
+
+ r.SetBasicAuth(a.Username, a.Password)
+}
+
+// Name is name of the auth
+func (a *BasicAuth) Name() string {
+ return "http-basic-auth"
+}
+
+func (a *BasicAuth) String() string {
+ masked := "*******"
+ if a.Password == "" {
+ masked = "<empty>"
+ }
+
+ return fmt.Sprintf("%s - %s:%s", a.Name(), a.Username, masked)
+}
+
+// TokenAuth implements an http.AuthMethod that can be used with http transport
+// to authenticate with HTTP token authentication (also known as bearer
+// authentication).
+//
+// IMPORTANT: If you are looking to use OAuth tokens with popular servers (e.g.
+// GitHub, Bitbucket, GitLab) you should use BasicAuth instead. These servers
+// use basic HTTP authentication, with the OAuth token as user or password.
+// Check the documentation of your git server for details.
+type TokenAuth struct {
+ Token string
+}
+
+func (a *TokenAuth) setAuth(r *http.Request) {
+ if a == nil {
+ return
+ }
+ r.Header.Add("Authorization", fmt.Sprintf("Bearer %s", a.Token))
+}
+
+// Name is name of the auth
+func (a *TokenAuth) Name() string {
+ return "http-token-auth"
+}
+
+func (a *TokenAuth) String() string {
+ masked := "*******"
+ if a.Token == "" {
+ masked = "<empty>"
+ }
+ return fmt.Sprintf("%s - %s", a.Name(), masked)
+}
+
+// Err is a dedicated error to return errors based on status code
+type Err struct {
+ Response *http.Response
+}
+
+// NewErr returns a new Err based on a http response
+func NewErr(r *http.Response) error {
+ if r.StatusCode >= http.StatusOK && r.StatusCode < http.StatusMultipleChoices {
+ return nil
+ }
+
+ switch r.StatusCode {
+ case http.StatusUnauthorized:
+ return transport.ErrAuthenticationRequired
+ case http.StatusForbidden:
+ return transport.ErrAuthorizationFailed
+ case http.StatusNotFound:
+ return transport.ErrRepositoryNotFound
+ }
+
+ return plumbing.NewUnexpectedError(&Err{r})
+}
+
+// StatusCode returns the status code of the response
+func (e *Err) StatusCode() int {
+ return e.Response.StatusCode
+}
+
+func (e *Err) Error() string {
+ return fmt.Sprintf("unexpected requesting %q status code: %d",
+ e.Response.Request.URL, e.Response.StatusCode,
+ )
+}
--- /dev/null
+package http
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/sideband"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+type rpSession struct {
+ *session
+}
+
+func newReceivePackSession(c *http.Client, ep *transport.Endpoint, auth transport.AuthMethod) (transport.ReceivePackSession, error) {
+ s, err := newSession(c, ep, auth)
+ return &rpSession{s}, err
+}
+
+func (s *rpSession) AdvertisedReferences() (*packp.AdvRefs, error) {
+ return advertisedReferences(s.session, transport.ReceivePackServiceName)
+}
+
+func (s *rpSession) ReceivePack(ctx context.Context, req *packp.ReferenceUpdateRequest) (
+ *packp.ReportStatus, error) {
+ url := fmt.Sprintf(
+ "%s/%s",
+ s.endpoint.String(), transport.ReceivePackServiceName,
+ )
+
+ buf := bytes.NewBuffer(nil)
+ if err := req.Encode(buf); err != nil {
+ return nil, err
+ }
+
+ res, err := s.doRequest(ctx, http.MethodPost, url, buf)
+ if err != nil {
+ return nil, err
+ }
+
+ r, err := ioutil.NonEmptyReader(res.Body)
+ if err == ioutil.ErrEmptyReader {
+ return nil, nil
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ var d *sideband.Demuxer
+ if req.Capabilities.Supports(capability.Sideband64k) {
+ d = sideband.NewDemuxer(sideband.Sideband64k, r)
+ } else if req.Capabilities.Supports(capability.Sideband) {
+ d = sideband.NewDemuxer(sideband.Sideband, r)
+ }
+ if d != nil {
+ d.Progress = req.Progress
+ r = d
+ }
+
+ rc := ioutil.NewReadCloser(r, res.Body)
+
+ report := packp.NewReportStatus()
+ if err := report.Decode(rc); err != nil {
+ return nil, err
+ }
+
+ return report, report.Error()
+}
+
+func (s *rpSession) doRequest(
+ ctx context.Context, method, url string, content *bytes.Buffer,
+) (*http.Response, error) {
+
+ var body io.Reader
+ if content != nil {
+ body = content
+ }
+
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ return nil, plumbing.NewPermanentError(err)
+ }
+
+ applyHeadersToRequest(req, content, s.endpoint.Host, transport.ReceivePackServiceName)
+ s.ApplyAuthToRequest(req)
+
+ res, err := s.client.Do(req.WithContext(ctx))
+ if err != nil {
+ return nil, plumbing.NewUnexpectedError(err)
+ }
+
+ if err := NewErr(res); err != nil {
+ _ = res.Body.Close()
+ return nil, err
+ }
+
+ return res, nil
+}
--- /dev/null
+package http
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+type upSession struct {
+ *session
+}
+
+func newUploadPackSession(c *http.Client, ep *transport.Endpoint, auth transport.AuthMethod) (transport.UploadPackSession, error) {
+ s, err := newSession(c, ep, auth)
+ return &upSession{s}, err
+}
+
+func (s *upSession) AdvertisedReferences() (*packp.AdvRefs, error) {
+ return advertisedReferences(s.session, transport.UploadPackServiceName)
+}
+
+func (s *upSession) UploadPack(
+ ctx context.Context, req *packp.UploadPackRequest,
+) (*packp.UploadPackResponse, error) {
+
+ if req.IsEmpty() {
+ return nil, transport.ErrEmptyUploadPackRequest
+ }
+
+ if err := req.Validate(); err != nil {
+ return nil, err
+ }
+
+ url := fmt.Sprintf(
+ "%s/%s",
+ s.endpoint.String(), transport.UploadPackServiceName,
+ )
+
+ content, err := uploadPackRequestToReader(req)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := s.doRequest(ctx, http.MethodPost, url, content)
+ if err != nil {
+ return nil, err
+ }
+
+ r, err := ioutil.NonEmptyReader(res.Body)
+ if err != nil {
+ if err == ioutil.ErrEmptyReader || err == io.ErrUnexpectedEOF {
+ return nil, transport.ErrEmptyUploadPackRequest
+ }
+
+ return nil, err
+ }
+
+ rc := ioutil.NewReadCloser(r, res.Body)
+ return common.DecodeUploadPackResponse(rc, req)
+}
+
+// Close does nothing.
+func (s *upSession) Close() error {
+ return nil
+}
+
+func (s *upSession) doRequest(
+ ctx context.Context, method, url string, content *bytes.Buffer,
+) (*http.Response, error) {
+
+ var body io.Reader
+ if content != nil {
+ body = content
+ }
+
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ return nil, plumbing.NewPermanentError(err)
+ }
+
+ applyHeadersToRequest(req, content, s.endpoint.Host, transport.UploadPackServiceName)
+ s.ApplyAuthToRequest(req)
+
+ res, err := s.client.Do(req.WithContext(ctx))
+ if err != nil {
+ return nil, plumbing.NewUnexpectedError(err)
+ }
+
+ if err := NewErr(res); err != nil {
+ _ = res.Body.Close()
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func uploadPackRequestToReader(req *packp.UploadPackRequest) (*bytes.Buffer, error) {
+ buf := bytes.NewBuffer(nil)
+ e := pktline.NewEncoder(buf)
+
+ if err := req.UploadRequest.Encode(buf); err != nil {
+ return nil, fmt.Errorf("sending upload-req message: %s", err)
+ }
+
+ if err := req.UploadHaves.Encode(buf, false); err != nil {
+ return nil, fmt.Errorf("sending haves message: %s", err)
+ }
+
+ if err := e.EncodeString("done\n"); err != nil {
+ return nil, err
+ }
+
+ return buf, nil
+}
--- /dev/null
+// Package common implements the git pack protocol with a pluggable transport.
+// This is a low-level package to implement new transports. Use a concrete
+// implementation instead (e.g. http, file, ssh).
+//
+// A simple example of usage can be found in the file package.
+package common
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ stdioutil "io/ioutil"
+ "strings"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/pktline"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/sideband"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+const (
+ readErrorSecondsTimeout = 10
+)
+
+var (
+ ErrTimeoutExceeded = errors.New("timeout exceeded")
+)
+
+// Commander creates Command instances. This is the main entry point for
+// transport implementations.
+type Commander interface {
+ // Command creates a new Command for the given git command and
+ // endpoint. cmd can be git-upload-pack or git-receive-pack. An
+ // error should be returned if the endpoint is not supported or the
+ // command cannot be created (e.g. binary does not exist, connection
+ // cannot be established).
+ Command(cmd string, ep *transport.Endpoint, auth transport.AuthMethod) (Command, error)
+}
+
+// Command is used for a single command execution.
+// This interface is modeled after exec.Cmd and ssh.Session in the standard
+// library.
+type Command interface {
+ // StderrPipe returns a pipe that will be connected to the command's
+ // standard error when the command starts. It should not be called after
+ // Start.
+ StderrPipe() (io.Reader, error)
+ // StdinPipe returns a pipe that will be connected to the command's
+ // standard input when the command starts. It should not be called after
+ // Start. The pipe should be closed when no more input is expected.
+ StdinPipe() (io.WriteCloser, error)
+ // StdoutPipe returns a pipe that will be connected to the command's
+ // standard output when the command starts. It should not be called after
+ // Start.
+ StdoutPipe() (io.Reader, error)
+ // Start starts the specified command. It does not wait for it to
+ // complete.
+ Start() error
+ // Close closes the command and releases any resources used by it. It
+ // will block until the command exits.
+ Close() error
+}
+
+// CommandKiller expands the Command interface, enableing it for being killed.
+type CommandKiller interface {
+ // Kill and close the session whatever the state it is. It will block until
+ // the command is terminated.
+ Kill() error
+}
+
+type client struct {
+ cmdr Commander
+}
+
+// NewClient creates a new client using the given Commander.
+func NewClient(runner Commander) transport.Transport {
+ return &client{runner}
+}
+
+// NewUploadPackSession creates a new UploadPackSession.
+func (c *client) NewUploadPackSession(ep *transport.Endpoint, auth transport.AuthMethod) (
+ transport.UploadPackSession, error) {
+
+ return c.newSession(transport.UploadPackServiceName, ep, auth)
+}
+
+// NewReceivePackSession creates a new ReceivePackSession.
+func (c *client) NewReceivePackSession(ep *transport.Endpoint, auth transport.AuthMethod) (
+ transport.ReceivePackSession, error) {
+
+ return c.newSession(transport.ReceivePackServiceName, ep, auth)
+}
+
+type session struct {
+ Stdin io.WriteCloser
+ Stdout io.Reader
+ Command Command
+
+ isReceivePack bool
+ advRefs *packp.AdvRefs
+ packRun bool
+ finished bool
+ firstErrLine chan string
+}
+
+func (c *client) newSession(s string, ep *transport.Endpoint, auth transport.AuthMethod) (*session, error) {
+ cmd, err := c.cmdr.Command(s, ep, auth)
+ if err != nil {
+ return nil, err
+ }
+
+ stdin, err := cmd.StdinPipe()
+ if err != nil {
+ return nil, err
+ }
+
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+
+ stderr, err := cmd.StderrPipe()
+ if err != nil {
+ return nil, err
+ }
+
+ if err := cmd.Start(); err != nil {
+ return nil, err
+ }
+
+ return &session{
+ Stdin: stdin,
+ Stdout: stdout,
+ Command: cmd,
+ firstErrLine: c.listenFirstError(stderr),
+ isReceivePack: s == transport.ReceivePackServiceName,
+ }, nil
+}
+
+func (c *client) listenFirstError(r io.Reader) chan string {
+ if r == nil {
+ return nil
+ }
+
+ errLine := make(chan string, 1)
+ go func() {
+ s := bufio.NewScanner(r)
+ if s.Scan() {
+ errLine <- s.Text()
+ } else {
+ close(errLine)
+ }
+
+ _, _ = io.Copy(stdioutil.Discard, r)
+ }()
+
+ return errLine
+}
+
+// AdvertisedReferences retrieves the advertised references from the server.
+func (s *session) AdvertisedReferences() (*packp.AdvRefs, error) {
+ if s.advRefs != nil {
+ return s.advRefs, nil
+ }
+
+ ar := packp.NewAdvRefs()
+ if err := ar.Decode(s.Stdout); err != nil {
+ if err := s.handleAdvRefDecodeError(err); err != nil {
+ return nil, err
+ }
+ }
+
+ transport.FilterUnsupportedCapabilities(ar.Capabilities)
+ s.advRefs = ar
+ return ar, nil
+}
+
+func (s *session) handleAdvRefDecodeError(err error) error {
+ // If repository is not found, we get empty stdout and server writes an
+ // error to stderr.
+ if err == packp.ErrEmptyInput {
+ s.finished = true
+ if err := s.checkNotFoundError(); err != nil {
+ return err
+ }
+
+ return io.ErrUnexpectedEOF
+ }
+
+ // For empty (but existing) repositories, we get empty advertised-references
+ // message. But valid. That is, it includes at least a flush.
+ if err == packp.ErrEmptyAdvRefs {
+ // Empty repositories are valid for git-receive-pack.
+ if s.isReceivePack {
+ return nil
+ }
+
+ if err := s.finish(); err != nil {
+ return err
+ }
+
+ return transport.ErrEmptyRemoteRepository
+ }
+
+ // Some server sends the errors as normal content (git protocol), so when
+ // we try to decode it fails, we need to check the content of it, to detect
+ // not found errors
+ if uerr, ok := err.(*packp.ErrUnexpectedData); ok {
+ if isRepoNotFoundError(string(uerr.Data)) {
+ return transport.ErrRepositoryNotFound
+ }
+ }
+
+ return err
+}
+
+// UploadPack performs a request to the server to fetch a packfile. A reader is
+// returned with the packfile content. The reader must be closed after reading.
+func (s *session) UploadPack(ctx context.Context, req *packp.UploadPackRequest) (*packp.UploadPackResponse, error) {
+ if req.IsEmpty() {
+ return nil, transport.ErrEmptyUploadPackRequest
+ }
+
+ if err := req.Validate(); err != nil {
+ return nil, err
+ }
+
+ if _, err := s.AdvertisedReferences(); err != nil {
+ return nil, err
+ }
+
+ s.packRun = true
+
+ in := s.StdinContext(ctx)
+ out := s.StdoutContext(ctx)
+
+ if err := uploadPack(in, out, req); err != nil {
+ return nil, err
+ }
+
+ r, err := ioutil.NonEmptyReader(out)
+ if err == ioutil.ErrEmptyReader {
+ if c, ok := s.Stdout.(io.Closer); ok {
+ _ = c.Close()
+ }
+
+ return nil, transport.ErrEmptyUploadPackRequest
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ rc := ioutil.NewReadCloser(r, s)
+ return DecodeUploadPackResponse(rc, req)
+}
+
+func (s *session) StdinContext(ctx context.Context) io.WriteCloser {
+ return ioutil.NewWriteCloserOnError(
+ ioutil.NewContextWriteCloser(ctx, s.Stdin),
+ s.onError,
+ )
+}
+
+func (s *session) StdoutContext(ctx context.Context) io.Reader {
+ return ioutil.NewReaderOnError(
+ ioutil.NewContextReader(ctx, s.Stdout),
+ s.onError,
+ )
+}
+
+func (s *session) onError(err error) {
+ if k, ok := s.Command.(CommandKiller); ok {
+ _ = k.Kill()
+ }
+
+ _ = s.Close()
+}
+
+func (s *session) ReceivePack(ctx context.Context, req *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error) {
+ if _, err := s.AdvertisedReferences(); err != nil {
+ return nil, err
+ }
+
+ s.packRun = true
+
+ w := s.StdinContext(ctx)
+ if err := req.Encode(w); err != nil {
+ return nil, err
+ }
+
+ if err := w.Close(); err != nil {
+ return nil, err
+ }
+
+ if !req.Capabilities.Supports(capability.ReportStatus) {
+ // If we don't have report-status, we can only
+ // check return value error.
+ return nil, s.Command.Close()
+ }
+
+ r := s.StdoutContext(ctx)
+
+ var d *sideband.Demuxer
+ if req.Capabilities.Supports(capability.Sideband64k) {
+ d = sideband.NewDemuxer(sideband.Sideband64k, r)
+ } else if req.Capabilities.Supports(capability.Sideband) {
+ d = sideband.NewDemuxer(sideband.Sideband, r)
+ }
+ if d != nil {
+ d.Progress = req.Progress
+ r = d
+ }
+
+ report := packp.NewReportStatus()
+ if err := report.Decode(r); err != nil {
+ return nil, err
+ }
+
+ if err := report.Error(); err != nil {
+ defer s.Close()
+ return report, err
+ }
+
+ return report, s.Command.Close()
+}
+
+func (s *session) finish() error {
+ if s.finished {
+ return nil
+ }
+
+ s.finished = true
+
+ // If we did not run a upload/receive-pack, we close the connection
+ // gracefully by sending a flush packet to the server. If the server
+ // operates correctly, it will exit with status 0.
+ if !s.packRun {
+ _, err := s.Stdin.Write(pktline.FlushPkt)
+ return err
+ }
+
+ return nil
+}
+
+func (s *session) Close() (err error) {
+ err = s.finish()
+
+ defer ioutil.CheckClose(s.Command, &err)
+ return
+}
+
+func (s *session) checkNotFoundError() error {
+ t := time.NewTicker(time.Second * readErrorSecondsTimeout)
+ defer t.Stop()
+
+ select {
+ case <-t.C:
+ return ErrTimeoutExceeded
+ case line, ok := <-s.firstErrLine:
+ if !ok {
+ return nil
+ }
+
+ if isRepoNotFoundError(line) {
+ return transport.ErrRepositoryNotFound
+ }
+
+ return fmt.Errorf("unknown error: %s", line)
+ }
+}
+
+var (
+ githubRepoNotFoundErr = "ERROR: Repository not found."
+ bitbucketRepoNotFoundErr = "conq: repository does not exist."
+ localRepoNotFoundErr = "does not appear to be a git repository"
+ gitProtocolNotFoundErr = "ERR \n Repository not found."
+ gitProtocolNoSuchErr = "ERR no such repository"
+ gitProtocolAccessDeniedErr = "ERR access denied"
+ gogsAccessDeniedErr = "Gogs: Repository does not exist or you do not have access"
+)
+
+func isRepoNotFoundError(s string) bool {
+ if strings.HasPrefix(s, githubRepoNotFoundErr) {
+ return true
+ }
+
+ if strings.HasPrefix(s, bitbucketRepoNotFoundErr) {
+ return true
+ }
+
+ if strings.HasSuffix(s, localRepoNotFoundErr) {
+ return true
+ }
+
+ if strings.HasPrefix(s, gitProtocolNotFoundErr) {
+ return true
+ }
+
+ if strings.HasPrefix(s, gitProtocolNoSuchErr) {
+ return true
+ }
+
+ if strings.HasPrefix(s, gitProtocolAccessDeniedErr) {
+ return true
+ }
+
+ if strings.HasPrefix(s, gogsAccessDeniedErr) {
+ return true
+ }
+
+ return false
+}
+
+var (
+ nak = []byte("NAK")
+ eol = []byte("\n")
+)
+
+// uploadPack implements the git-upload-pack protocol.
+func uploadPack(w io.WriteCloser, r io.Reader, req *packp.UploadPackRequest) error {
+ // TODO support multi_ack mode
+ // TODO support multi_ack_detailed mode
+ // TODO support acks for common objects
+ // TODO build a proper state machine for all these processing options
+
+ if err := req.UploadRequest.Encode(w); err != nil {
+ return fmt.Errorf("sending upload-req message: %s", err)
+ }
+
+ if err := req.UploadHaves.Encode(w, true); err != nil {
+ return fmt.Errorf("sending haves message: %s", err)
+ }
+
+ if err := sendDone(w); err != nil {
+ return fmt.Errorf("sending done message: %s", err)
+ }
+
+ if err := w.Close(); err != nil {
+ return fmt.Errorf("closing input: %s", err)
+ }
+
+ return nil
+}
+
+func sendDone(w io.Writer) error {
+ e := pktline.NewEncoder(w)
+
+ return e.Encodef("done\n")
+}
+
+// DecodeUploadPackResponse decodes r into a new packp.UploadPackResponse
+func DecodeUploadPackResponse(r io.ReadCloser, req *packp.UploadPackRequest) (
+ *packp.UploadPackResponse, error,
+) {
+ res := packp.NewUploadPackResponse(req)
+ if err := res.Decode(r); err != nil {
+ return nil, fmt.Errorf("error decoding upload-pack response: %s", err)
+ }
+
+ return res, nil
+}
--- /dev/null
+package common
+
+import (
+ "context"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// ServerCommand is used for a single server command execution.
+type ServerCommand struct {
+ Stderr io.Writer
+ Stdout io.WriteCloser
+ Stdin io.Reader
+}
+
+func ServeUploadPack(cmd ServerCommand, s transport.UploadPackSession) (err error) {
+ ioutil.CheckClose(cmd.Stdout, &err)
+
+ ar, err := s.AdvertisedReferences()
+ if err != nil {
+ return err
+ }
+
+ if err := ar.Encode(cmd.Stdout); err != nil {
+ return err
+ }
+
+ req := packp.NewUploadPackRequest()
+ if err := req.Decode(cmd.Stdin); err != nil {
+ return err
+ }
+
+ var resp *packp.UploadPackResponse
+ resp, err = s.UploadPack(context.TODO(), req)
+ if err != nil {
+ return err
+ }
+
+ return resp.Encode(cmd.Stdout)
+}
+
+func ServeReceivePack(cmd ServerCommand, s transport.ReceivePackSession) error {
+ ar, err := s.AdvertisedReferences()
+ if err != nil {
+ return fmt.Errorf("internal error in advertised references: %s", err)
+ }
+
+ if err := ar.Encode(cmd.Stdout); err != nil {
+ return fmt.Errorf("error in advertised references encoding: %s", err)
+ }
+
+ req := packp.NewReferenceUpdateRequest()
+ if err := req.Decode(cmd.Stdin); err != nil {
+ return fmt.Errorf("error decoding: %s", err)
+ }
+
+ rs, err := s.ReceivePack(context.TODO(), req)
+ if rs != nil {
+ if err := rs.Encode(cmd.Stdout); err != nil {
+ return fmt.Errorf("error in encoding report status %s", err)
+ }
+ }
+
+ if err != nil {
+ return fmt.Errorf("error in receive pack: %s", err)
+ }
+
+ return nil
+}
--- /dev/null
+package server
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-billy.v4/osfs"
+)
+
+// DefaultLoader is a filesystem loader ignoring host and resolving paths to /.
+var DefaultLoader = NewFilesystemLoader(osfs.New(""))
+
+// Loader loads repository's storer.Storer based on an optional host and a path.
+type Loader interface {
+ // Load loads a storer.Storer given a transport.Endpoint.
+ // Returns transport.ErrRepositoryNotFound if the repository does not
+ // exist.
+ Load(ep *transport.Endpoint) (storer.Storer, error)
+}
+
+type fsLoader struct {
+ base billy.Filesystem
+}
+
+// NewFilesystemLoader creates a Loader that ignores host and resolves paths
+// with a given base filesystem.
+func NewFilesystemLoader(base billy.Filesystem) Loader {
+ return &fsLoader{base}
+}
+
+// Load looks up the endpoint's path in the base file system and returns a
+// storer for it. Returns transport.ErrRepositoryNotFound if a repository does
+// not exist in the given path.
+func (l *fsLoader) Load(ep *transport.Endpoint) (storer.Storer, error) {
+ fs, err := l.base.Chroot(ep.Path)
+ if err != nil {
+ return nil, err
+ }
+
+ if _, err := fs.Stat("config"); err != nil {
+ return nil, transport.ErrRepositoryNotFound
+ }
+
+ return filesystem.NewStorage(fs, cache.NewObjectLRUDefault()), nil
+}
+
+// MapLoader is a Loader that uses a lookup map of storer.Storer by
+// transport.Endpoint.
+type MapLoader map[string]storer.Storer
+
+// Load returns a storer.Storer for given a transport.Endpoint by looking it up
+// in the map. Returns transport.ErrRepositoryNotFound if the endpoint does not
+// exist.
+func (l MapLoader) Load(ep *transport.Endpoint) (storer.Storer, error) {
+ s, ok := l[ep.String()]
+ if !ok {
+ return nil, transport.ErrRepositoryNotFound
+ }
+
+ return s, nil
+}
--- /dev/null
+// Package server implements the git server protocol. For most use cases, the
+// transport-specific implementations should be used.
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/packfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/plumbing/revlist"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+var DefaultServer = NewServer(DefaultLoader)
+
+type server struct {
+ loader Loader
+ handler *handler
+}
+
+// NewServer returns a transport.Transport implementing a git server,
+// independent of transport. Each transport must wrap this.
+func NewServer(loader Loader) transport.Transport {
+ return &server{
+ loader,
+ &handler{asClient: false},
+ }
+}
+
+// NewClient returns a transport.Transport implementing a client with an
+// embedded server.
+func NewClient(loader Loader) transport.Transport {
+ return &server{
+ loader,
+ &handler{asClient: true},
+ }
+}
+
+func (s *server) NewUploadPackSession(ep *transport.Endpoint, auth transport.AuthMethod) (transport.UploadPackSession, error) {
+ sto, err := s.loader.Load(ep)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.handler.NewUploadPackSession(sto)
+}
+
+func (s *server) NewReceivePackSession(ep *transport.Endpoint, auth transport.AuthMethod) (transport.ReceivePackSession, error) {
+ sto, err := s.loader.Load(ep)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.handler.NewReceivePackSession(sto)
+}
+
+type handler struct {
+ asClient bool
+}
+
+func (h *handler) NewUploadPackSession(s storer.Storer) (transport.UploadPackSession, error) {
+ return &upSession{
+ session: session{storer: s, asClient: h.asClient},
+ }, nil
+}
+
+func (h *handler) NewReceivePackSession(s storer.Storer) (transport.ReceivePackSession, error) {
+ return &rpSession{
+ session: session{storer: s, asClient: h.asClient},
+ cmdStatus: map[plumbing.ReferenceName]error{},
+ }, nil
+}
+
+type session struct {
+ storer storer.Storer
+ caps *capability.List
+ asClient bool
+}
+
+func (s *session) Close() error {
+ return nil
+}
+
+func (s *session) SetAuth(transport.AuthMethod) error {
+ //TODO: deprecate
+ return nil
+}
+
+func (s *session) checkSupportedCapabilities(cl *capability.List) error {
+ for _, c := range cl.All() {
+ if !s.caps.Supports(c) {
+ return fmt.Errorf("unsupported capability: %s", c)
+ }
+ }
+
+ return nil
+}
+
+type upSession struct {
+ session
+}
+
+func (s *upSession) AdvertisedReferences() (*packp.AdvRefs, error) {
+ ar := packp.NewAdvRefs()
+
+ if err := s.setSupportedCapabilities(ar.Capabilities); err != nil {
+ return nil, err
+ }
+
+ s.caps = ar.Capabilities
+
+ if err := setReferences(s.storer, ar); err != nil {
+ return nil, err
+ }
+
+ if err := setHEAD(s.storer, ar); err != nil {
+ return nil, err
+ }
+
+ if s.asClient && len(ar.References) == 0 {
+ return nil, transport.ErrEmptyRemoteRepository
+ }
+
+ return ar, nil
+}
+
+func (s *upSession) UploadPack(ctx context.Context, req *packp.UploadPackRequest) (*packp.UploadPackResponse, error) {
+ if req.IsEmpty() {
+ return nil, transport.ErrEmptyUploadPackRequest
+ }
+
+ if err := req.Validate(); err != nil {
+ return nil, err
+ }
+
+ if s.caps == nil {
+ s.caps = capability.NewList()
+ if err := s.setSupportedCapabilities(s.caps); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := s.checkSupportedCapabilities(req.Capabilities); err != nil {
+ return nil, err
+ }
+
+ s.caps = req.Capabilities
+
+ if len(req.Shallows) > 0 {
+ return nil, fmt.Errorf("shallow not supported")
+ }
+
+ objs, err := s.objectsToUpload(req)
+ if err != nil {
+ return nil, err
+ }
+
+ pr, pw := io.Pipe()
+ e := packfile.NewEncoder(pw, s.storer, false)
+ go func() {
+ // TODO: plumb through a pack window.
+ _, err := e.Encode(objs, 10)
+ pw.CloseWithError(err)
+ }()
+
+ return packp.NewUploadPackResponseWithPackfile(req,
+ ioutil.NewContextReadCloser(ctx, pr),
+ ), nil
+}
+
+func (s *upSession) objectsToUpload(req *packp.UploadPackRequest) ([]plumbing.Hash, error) {
+ haves, err := revlist.Objects(s.storer, req.Haves, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return revlist.Objects(s.storer, req.Wants, haves)
+}
+
+func (*upSession) setSupportedCapabilities(c *capability.List) error {
+ if err := c.Set(capability.Agent, capability.DefaultAgent); err != nil {
+ return err
+ }
+
+ if err := c.Set(capability.OFSDelta); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+type rpSession struct {
+ session
+ cmdStatus map[plumbing.ReferenceName]error
+ firstErr error
+ unpackErr error
+}
+
+func (s *rpSession) AdvertisedReferences() (*packp.AdvRefs, error) {
+ ar := packp.NewAdvRefs()
+
+ if err := s.setSupportedCapabilities(ar.Capabilities); err != nil {
+ return nil, err
+ }
+
+ s.caps = ar.Capabilities
+
+ if err := setReferences(s.storer, ar); err != nil {
+ return nil, err
+ }
+
+ if err := setHEAD(s.storer, ar); err != nil {
+ return nil, err
+ }
+
+ return ar, nil
+}
+
+var (
+ ErrUpdateReference = errors.New("failed to update ref")
+)
+
+func (s *rpSession) ReceivePack(ctx context.Context, req *packp.ReferenceUpdateRequest) (*packp.ReportStatus, error) {
+ if s.caps == nil {
+ s.caps = capability.NewList()
+ if err := s.setSupportedCapabilities(s.caps); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := s.checkSupportedCapabilities(req.Capabilities); err != nil {
+ return nil, err
+ }
+
+ s.caps = req.Capabilities
+
+ //TODO: Implement 'atomic' update of references.
+
+ r := ioutil.NewContextReadCloser(ctx, req.Packfile)
+ if err := s.writePackfile(r); err != nil {
+ s.unpackErr = err
+ s.firstErr = err
+ return s.reportStatus(), err
+ }
+
+ s.updateReferences(req)
+ return s.reportStatus(), s.firstErr
+}
+
+func (s *rpSession) updateReferences(req *packp.ReferenceUpdateRequest) {
+ for _, cmd := range req.Commands {
+ exists, err := referenceExists(s.storer, cmd.Name)
+ if err != nil {
+ s.setStatus(cmd.Name, err)
+ continue
+ }
+
+ switch cmd.Action() {
+ case packp.Create:
+ if exists {
+ s.setStatus(cmd.Name, ErrUpdateReference)
+ continue
+ }
+
+ ref := plumbing.NewHashReference(cmd.Name, cmd.New)
+ err := s.storer.SetReference(ref)
+ s.setStatus(cmd.Name, err)
+ case packp.Delete:
+ if !exists {
+ s.setStatus(cmd.Name, ErrUpdateReference)
+ continue
+ }
+
+ err := s.storer.RemoveReference(cmd.Name)
+ s.setStatus(cmd.Name, err)
+ case packp.Update:
+ if !exists {
+ s.setStatus(cmd.Name, ErrUpdateReference)
+ continue
+ }
+
+ if err != nil {
+ s.setStatus(cmd.Name, err)
+ continue
+ }
+
+ ref := plumbing.NewHashReference(cmd.Name, cmd.New)
+ err := s.storer.SetReference(ref)
+ s.setStatus(cmd.Name, err)
+ }
+ }
+}
+
+func (s *rpSession) writePackfile(r io.ReadCloser) error {
+ if r == nil {
+ return nil
+ }
+
+ if err := packfile.UpdateObjectStorage(s.storer, r); err != nil {
+ _ = r.Close()
+ return err
+ }
+
+ return r.Close()
+}
+
+func (s *rpSession) setStatus(ref plumbing.ReferenceName, err error) {
+ s.cmdStatus[ref] = err
+ if s.firstErr == nil && err != nil {
+ s.firstErr = err
+ }
+}
+
+func (s *rpSession) reportStatus() *packp.ReportStatus {
+ if !s.caps.Supports(capability.ReportStatus) {
+ return nil
+ }
+
+ rs := packp.NewReportStatus()
+ rs.UnpackStatus = "ok"
+
+ if s.unpackErr != nil {
+ rs.UnpackStatus = s.unpackErr.Error()
+ }
+
+ if s.cmdStatus == nil {
+ return rs
+ }
+
+ for ref, err := range s.cmdStatus {
+ msg := "ok"
+ if err != nil {
+ msg = err.Error()
+ }
+ status := &packp.CommandStatus{
+ ReferenceName: ref,
+ Status: msg,
+ }
+ rs.CommandStatuses = append(rs.CommandStatuses, status)
+ }
+
+ return rs
+}
+
+func (*rpSession) setSupportedCapabilities(c *capability.List) error {
+ if err := c.Set(capability.Agent, capability.DefaultAgent); err != nil {
+ return err
+ }
+
+ if err := c.Set(capability.OFSDelta); err != nil {
+ return err
+ }
+
+ if err := c.Set(capability.DeleteRefs); err != nil {
+ return err
+ }
+
+ return c.Set(capability.ReportStatus)
+}
+
+func setHEAD(s storer.Storer, ar *packp.AdvRefs) error {
+ ref, err := s.Reference(plumbing.HEAD)
+ if err == plumbing.ErrReferenceNotFound {
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+
+ if ref.Type() == plumbing.SymbolicReference {
+ if err := ar.AddReference(ref); err != nil {
+ return nil
+ }
+
+ ref, err = storer.ResolveReference(s, ref.Target())
+ if err == plumbing.ErrReferenceNotFound {
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+ }
+
+ if ref.Type() != plumbing.HashReference {
+ return plumbing.ErrInvalidType
+ }
+
+ h := ref.Hash()
+ ar.Head = &h
+
+ return nil
+}
+
+func setReferences(s storer.Storer, ar *packp.AdvRefs) error {
+ //TODO: add peeled references.
+ iter, err := s.IterReferences()
+ if err != nil {
+ return err
+ }
+
+ return iter.ForEach(func(ref *plumbing.Reference) error {
+ if ref.Type() != plumbing.HashReference {
+ return nil
+ }
+
+ ar.References[ref.Name().String()] = ref.Hash()
+ return nil
+ })
+}
+
+func referenceExists(s storer.ReferenceStorer, n plumbing.ReferenceName) (bool, error) {
+ _, err := s.Reference(n)
+ if err == plumbing.ErrReferenceNotFound {
+ return false, nil
+ }
+
+ return err == nil, err
+}
--- /dev/null
+package ssh
+
+import (
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/user"
+ "path/filepath"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+
+ "github.com/mitchellh/go-homedir"
+ "github.com/xanzy/ssh-agent"
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/knownhosts"
+)
+
+const DefaultUsername = "git"
+
+// AuthMethod is the interface all auth methods for the ssh client
+// must implement. The clientConfig method returns the ssh client
+// configuration needed to establish an ssh connection.
+type AuthMethod interface {
+ transport.AuthMethod
+ // ClientConfig should return a valid ssh.ClientConfig to be used to create
+ // a connection to the SSH server.
+ ClientConfig() (*ssh.ClientConfig, error)
+}
+
+// The names of the AuthMethod implementations. To be returned by the
+// Name() method. Most git servers only allow PublicKeysName and
+// PublicKeysCallbackName.
+const (
+ KeyboardInteractiveName = "ssh-keyboard-interactive"
+ PasswordName = "ssh-password"
+ PasswordCallbackName = "ssh-password-callback"
+ PublicKeysName = "ssh-public-keys"
+ PublicKeysCallbackName = "ssh-public-key-callback"
+)
+
+// KeyboardInteractive implements AuthMethod by using a
+// prompt/response sequence controlled by the server.
+type KeyboardInteractive struct {
+ User string
+ Challenge ssh.KeyboardInteractiveChallenge
+ HostKeyCallbackHelper
+}
+
+func (a *KeyboardInteractive) Name() string {
+ return KeyboardInteractiveName
+}
+
+func (a *KeyboardInteractive) String() string {
+ return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
+}
+
+func (a *KeyboardInteractive) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
+ User: a.User,
+ Auth: []ssh.AuthMethod{
+ ssh.KeyboardInteractiveChallenge(a.Challenge),
+ },
+ })
+}
+
+// Password implements AuthMethod by using the given password.
+type Password struct {
+ User string
+ Password string
+ HostKeyCallbackHelper
+}
+
+func (a *Password) Name() string {
+ return PasswordName
+}
+
+func (a *Password) String() string {
+ return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
+}
+
+func (a *Password) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
+ User: a.User,
+ Auth: []ssh.AuthMethod{ssh.Password(a.Password)},
+ })
+}
+
+// PasswordCallback implements AuthMethod by using a callback
+// to fetch the password.
+type PasswordCallback struct {
+ User string
+ Callback func() (pass string, err error)
+ HostKeyCallbackHelper
+}
+
+func (a *PasswordCallback) Name() string {
+ return PasswordCallbackName
+}
+
+func (a *PasswordCallback) String() string {
+ return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
+}
+
+func (a *PasswordCallback) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
+ User: a.User,
+ Auth: []ssh.AuthMethod{ssh.PasswordCallback(a.Callback)},
+ })
+}
+
+// PublicKeys implements AuthMethod by using the given key pairs.
+type PublicKeys struct {
+ User string
+ Signer ssh.Signer
+ HostKeyCallbackHelper
+}
+
+// NewPublicKeys returns a PublicKeys from a PEM encoded private key. An
+// encryption password should be given if the pemBytes contains a password
+// encrypted PEM block otherwise password should be empty. It supports RSA
+// (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
+func NewPublicKeys(user string, pemBytes []byte, password string) (*PublicKeys, error) {
+ block, _ := pem.Decode(pemBytes)
+ if block == nil {
+ return nil, errors.New("invalid PEM data")
+ }
+ if x509.IsEncryptedPEMBlock(block) {
+ key, err := x509.DecryptPEMBlock(block, []byte(password))
+ if err != nil {
+ return nil, err
+ }
+
+ block = &pem.Block{Type: block.Type, Bytes: key}
+ pemBytes = pem.EncodeToMemory(block)
+ }
+
+ signer, err := ssh.ParsePrivateKey(pemBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ return &PublicKeys{User: user, Signer: signer}, nil
+}
+
+// NewPublicKeysFromFile returns a PublicKeys from a file containing a PEM
+// encoded private key. An encryption password should be given if the pemBytes
+// contains a password encrypted PEM block otherwise password should be empty.
+func NewPublicKeysFromFile(user, pemFile, password string) (*PublicKeys, error) {
+ bytes, err := ioutil.ReadFile(pemFile)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewPublicKeys(user, bytes, password)
+}
+
+func (a *PublicKeys) Name() string {
+ return PublicKeysName
+}
+
+func (a *PublicKeys) String() string {
+ return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
+}
+
+func (a *PublicKeys) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
+ User: a.User,
+ Auth: []ssh.AuthMethod{ssh.PublicKeys(a.Signer)},
+ })
+}
+
+func username() (string, error) {
+ var username string
+ if user, err := user.Current(); err == nil {
+ username = user.Username
+ } else {
+ username = os.Getenv("USER")
+ }
+
+ if username == "" {
+ return "", errors.New("failed to get username")
+ }
+
+ return username, nil
+}
+
+// PublicKeysCallback implements AuthMethod by asking a
+// ssh.agent.Agent to act as a signer.
+type PublicKeysCallback struct {
+ User string
+ Callback func() (signers []ssh.Signer, err error)
+ HostKeyCallbackHelper
+}
+
+// NewSSHAgentAuth returns a PublicKeysCallback based on a SSH agent, it opens
+// a pipe with the SSH agent and uses the pipe as the implementer of the public
+// key callback function.
+func NewSSHAgentAuth(u string) (*PublicKeysCallback, error) {
+ var err error
+ if u == "" {
+ u, err = username()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ a, _, err := sshagent.New()
+ if err != nil {
+ return nil, fmt.Errorf("error creating SSH agent: %q", err)
+ }
+
+ return &PublicKeysCallback{
+ User: u,
+ Callback: a.Signers,
+ }, nil
+}
+
+func (a *PublicKeysCallback) Name() string {
+ return PublicKeysCallbackName
+}
+
+func (a *PublicKeysCallback) String() string {
+ return fmt.Sprintf("user: %s, name: %s", a.User, a.Name())
+}
+
+func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) {
+ return a.SetHostKeyCallback(&ssh.ClientConfig{
+ User: a.User,
+ Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(a.Callback)},
+ })
+}
+
+// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a
+// known_hosts file. http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT
+//
+// If list of files is empty, then it will be read from the SSH_KNOWN_HOSTS
+// environment variable, example:
+// /home/foo/custom_known_hosts_file:/etc/custom_known/hosts_file
+//
+// If SSH_KNOWN_HOSTS is not set the following file locations will be used:
+// ~/.ssh/known_hosts
+// /etc/ssh/ssh_known_hosts
+func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) {
+ var err error
+
+ if len(files) == 0 {
+ if files, err = getDefaultKnownHostsFiles(); err != nil {
+ return nil, err
+ }
+ }
+
+ if files, err = filterKnownHostsFiles(files...); err != nil {
+ return nil, err
+ }
+
+ return knownhosts.New(files...)
+}
+
+func getDefaultKnownHostsFiles() ([]string, error) {
+ files := filepath.SplitList(os.Getenv("SSH_KNOWN_HOSTS"))
+ if len(files) != 0 {
+ return files, nil
+ }
+
+ homeDirPath, err := homedir.Dir()
+ if err != nil {
+ return nil, err
+ }
+
+ return []string{
+ filepath.Join(homeDirPath, "/.ssh/known_hosts"),
+ "/etc/ssh/ssh_known_hosts",
+ }, nil
+}
+
+func filterKnownHostsFiles(files ...string) ([]string, error) {
+ var out []string
+ for _, file := range files {
+ _, err := os.Stat(file)
+ if err == nil {
+ out = append(out, file)
+ continue
+ }
+
+ if !os.IsNotExist(err) {
+ return nil, err
+ }
+ }
+
+ if len(out) == 0 {
+ return nil, fmt.Errorf("unable to find any valid known_hosts file, set SSH_KNOWN_HOSTS env variable")
+ }
+
+ return out, nil
+}
+
+// HostKeyCallbackHelper is a helper that provides common functionality to
+// configure HostKeyCallback into a ssh.ClientConfig.
+type HostKeyCallbackHelper struct {
+ // HostKeyCallback is the function type used for verifying server keys.
+ // If nil default callback will be create using NewKnownHostsCallback
+ // without argument.
+ HostKeyCallback ssh.HostKeyCallback
+}
+
+// SetHostKeyCallback sets the field HostKeyCallback in the given cfg. If
+// HostKeyCallback is empty a default callback is created using
+// NewKnownHostsCallback.
+func (m *HostKeyCallbackHelper) SetHostKeyCallback(cfg *ssh.ClientConfig) (*ssh.ClientConfig, error) {
+ var err error
+ if m.HostKeyCallback == nil {
+ if m.HostKeyCallback, err = NewKnownHostsCallback(); err != nil {
+ return cfg, err
+ }
+ }
+
+ cfg.HostKeyCallback = m.HostKeyCallback
+ return cfg, nil
+}
--- /dev/null
+// Package ssh implements the SSH transport protocol.
+package ssh
+
+import (
+ "fmt"
+ "reflect"
+ "strconv"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
+
+ "github.com/kevinburke/ssh_config"
+ "golang.org/x/crypto/ssh"
+)
+
+// DefaultClient is the default SSH client.
+var DefaultClient = NewClient(nil)
+
+// DefaultSSHConfig is the reader used to access parameters stored in the
+// system's ssh_config files. If nil all the ssh_config are ignored.
+var DefaultSSHConfig sshConfig = ssh_config.DefaultUserSettings
+
+type sshConfig interface {
+ Get(alias, key string) string
+}
+
+// NewClient creates a new SSH client with an optional *ssh.ClientConfig.
+func NewClient(config *ssh.ClientConfig) transport.Transport {
+ return common.NewClient(&runner{config: config})
+}
+
+// DefaultAuthBuilder is the function used to create a default AuthMethod, when
+// the user doesn't provide any.
+var DefaultAuthBuilder = func(user string) (AuthMethod, error) {
+ return NewSSHAgentAuth(user)
+}
+
+const DefaultPort = 22
+
+type runner struct {
+ config *ssh.ClientConfig
+}
+
+func (r *runner) Command(cmd string, ep *transport.Endpoint, auth transport.AuthMethod) (common.Command, error) {
+ c := &command{command: cmd, endpoint: ep, config: r.config}
+ if auth != nil {
+ c.setAuth(auth)
+ }
+
+ if err := c.connect(); err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+type command struct {
+ *ssh.Session
+ connected bool
+ command string
+ endpoint *transport.Endpoint
+ client *ssh.Client
+ auth AuthMethod
+ config *ssh.ClientConfig
+}
+
+func (c *command) setAuth(auth transport.AuthMethod) error {
+ a, ok := auth.(AuthMethod)
+ if !ok {
+ return transport.ErrInvalidAuthMethod
+ }
+
+ c.auth = a
+ return nil
+}
+
+func (c *command) Start() error {
+ return c.Session.Start(endpointToCommand(c.command, c.endpoint))
+}
+
+// Close closes the SSH session and connection.
+func (c *command) Close() error {
+ if !c.connected {
+ return nil
+ }
+
+ c.connected = false
+
+ //XXX: If did read the full packfile, then the session might be already
+ // closed.
+ _ = c.Session.Close()
+
+ return c.client.Close()
+}
+
+// connect connects to the SSH server, unless a AuthMethod was set with
+// SetAuth method, by default uses an auth method based on PublicKeysCallback,
+// it connects to a SSH agent, using the address stored in the SSH_AUTH_SOCK
+// environment var.
+func (c *command) connect() error {
+ if c.connected {
+ return transport.ErrAlreadyConnected
+ }
+
+ if c.auth == nil {
+ if err := c.setAuthFromEndpoint(); err != nil {
+ return err
+ }
+ }
+
+ var err error
+ config, err := c.auth.ClientConfig()
+ if err != nil {
+ return err
+ }
+
+ overrideConfig(c.config, config)
+
+ c.client, err = ssh.Dial("tcp", c.getHostWithPort(), config)
+ if err != nil {
+ return err
+ }
+
+ c.Session, err = c.client.NewSession()
+ if err != nil {
+ _ = c.client.Close()
+ return err
+ }
+
+ c.connected = true
+ return nil
+}
+
+func (c *command) getHostWithPort() string {
+ if addr, found := c.doGetHostWithPortFromSSHConfig(); found {
+ return addr
+ }
+
+ host := c.endpoint.Host
+ port := c.endpoint.Port
+ if port <= 0 {
+ port = DefaultPort
+ }
+
+ return fmt.Sprintf("%s:%d", host, port)
+}
+
+func (c *command) doGetHostWithPortFromSSHConfig() (addr string, found bool) {
+ if DefaultSSHConfig == nil {
+ return
+ }
+
+ host := c.endpoint.Host
+ port := c.endpoint.Port
+
+ configHost := DefaultSSHConfig.Get(c.endpoint.Host, "Hostname")
+ if configHost != "" {
+ host = configHost
+ found = true
+ }
+
+ if !found {
+ return
+ }
+
+ configPort := DefaultSSHConfig.Get(c.endpoint.Host, "Port")
+ if configPort != "" {
+ if i, err := strconv.Atoi(configPort); err == nil {
+ port = i
+ }
+ }
+
+ addr = fmt.Sprintf("%s:%d", host, port)
+ return
+}
+
+func (c *command) setAuthFromEndpoint() error {
+ var err error
+ c.auth, err = DefaultAuthBuilder(c.endpoint.User)
+ return err
+}
+
+func endpointToCommand(cmd string, ep *transport.Endpoint) string {
+ return fmt.Sprintf("%s '%s'", cmd, ep.Path)
+}
+
+func overrideConfig(overrides *ssh.ClientConfig, c *ssh.ClientConfig) {
+ if overrides == nil {
+ return
+ }
+
+ t := reflect.TypeOf(*c)
+ vc := reflect.ValueOf(c).Elem()
+ vo := reflect.ValueOf(overrides).Elem()
+
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ vcf := vc.FieldByName(f.Name)
+ vof := vo.FieldByName(f.Name)
+ vcf.Set(vof)
+ }
+
+ *c = vc.Interface().(ssh.ClientConfig)
+}
--- /dev/null
+package git
+
+import (
+ "errors"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+type PruneHandler func(unreferencedObjectHash plumbing.Hash) error
+type PruneOptions struct {
+ // OnlyObjectsOlderThan if set to non-zero value
+ // selects only objects older than the time provided.
+ OnlyObjectsOlderThan time.Time
+ // Handler is called on matching objects
+ Handler PruneHandler
+}
+
+var ErrLooseObjectsNotSupported = errors.New("Loose objects not supported")
+
+// DeleteObject deletes an object from a repository.
+// The type conveniently matches PruneHandler.
+func (r *Repository) DeleteObject(hash plumbing.Hash) error {
+ los, ok := r.Storer.(storer.LooseObjectStorer)
+ if !ok {
+ return ErrLooseObjectsNotSupported
+ }
+
+ return los.DeleteLooseObject(hash)
+}
+
+func (r *Repository) Prune(opt PruneOptions) error {
+ los, ok := r.Storer.(storer.LooseObjectStorer)
+ if !ok {
+ return ErrLooseObjectsNotSupported
+ }
+
+ pw := newObjectWalker(r.Storer)
+ err := pw.walkAllRefs()
+ if err != nil {
+ return err
+ }
+ // Now walk all (loose) objects in storage.
+ return los.ForEachObjectHash(func(hash plumbing.Hash) error {
+ // Get out if we have seen this object.
+ if pw.isSeen(hash) {
+ return nil
+ }
+ // Otherwise it is a candidate for pruning.
+ // Check out for too new objects next.
+ if !opt.OnlyObjectsOlderThan.IsZero() {
+ // Errors here are non-fatal. The object may be e.g. packed.
+ // Or concurrently deleted. Skip such objects.
+ t, err := los.LooseObjectTime(hash)
+ if err != nil {
+ return nil
+ }
+ // Skip too new objects.
+ if !t.Before(opt.OnlyObjectsOlderThan) {
+ return nil
+ }
+ }
+ return opt.Handler(hash)
+ })
+}
--- /dev/null
+package git
+
+import (
+ "io"
+ "sort"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/utils/diff"
+
+ "github.com/sergi/go-diff/diffmatchpatch"
+)
+
+// References returns a slice of Commits for the file at "path", starting from
+// the commit provided that contains the file from the provided path. The last
+// commit into the returned slice is the commit where the file was created.
+// If the provided commit does not contains the specified path, a nil slice is
+// returned. The commits are sorted in commit order, newer to older.
+//
+// Caveats:
+//
+// - Moves and copies are not currently supported.
+//
+// - Cherry-picks are not detected unless there are no commits between them and
+// therefore can appear repeated in the list. (see git path-id for hints on how
+// to fix this).
+func references(c *object.Commit, path string) ([]*object.Commit, error) {
+ var result []*object.Commit
+ seen := make(map[plumbing.Hash]struct{})
+ if err := walkGraph(&result, &seen, c, path); err != nil {
+ return nil, err
+ }
+
+ // TODO result should be returned without ordering
+ sortCommits(result)
+
+ // for merges of identical cherry-picks
+ return removeComp(path, result, equivalent)
+}
+
+type commitSorterer struct {
+ l []*object.Commit
+}
+
+func (s commitSorterer) Len() int {
+ return len(s.l)
+}
+
+func (s commitSorterer) Less(i, j int) bool {
+ return s.l[i].Committer.When.Before(s.l[j].Committer.When) ||
+ s.l[i].Committer.When.Equal(s.l[j].Committer.When) &&
+ s.l[i].Author.When.Before(s.l[j].Author.When)
+}
+
+func (s commitSorterer) Swap(i, j int) {
+ s.l[i], s.l[j] = s.l[j], s.l[i]
+}
+
+// SortCommits sorts a commit list by commit date, from older to newer.
+func sortCommits(l []*object.Commit) {
+ s := &commitSorterer{l}
+ sort.Sort(s)
+}
+
+// Recursive traversal of the commit graph, generating a linear history of the
+// path.
+func walkGraph(result *[]*object.Commit, seen *map[plumbing.Hash]struct{}, current *object.Commit, path string) error {
+ // check and update seen
+ if _, ok := (*seen)[current.Hash]; ok {
+ return nil
+ }
+ (*seen)[current.Hash] = struct{}{}
+
+ // if the path is not in the current commit, stop searching.
+ if _, err := current.File(path); err != nil {
+ return nil
+ }
+
+ // optimization: don't traverse branches that does not
+ // contain the path.
+ parents, err := parentsContainingPath(path, current)
+ if err != nil {
+ return err
+ }
+ switch len(parents) {
+ // if the path is not found in any of its parents, the path was
+ // created by this commit; we must add it to the revisions list and
+ // stop searching. This includes the case when current is the
+ // initial commit.
+ case 0:
+ *result = append(*result, current)
+ return nil
+ case 1: // only one parent contains the path
+ // if the file contents has change, add the current commit
+ different, err := differentContents(path, current, parents)
+ if err != nil {
+ return err
+ }
+ if len(different) == 1 {
+ *result = append(*result, current)
+ }
+ // in any case, walk the parent
+ return walkGraph(result, seen, parents[0], path)
+ default: // more than one parent contains the path
+ // TODO: detect merges that had a conflict, because they must be
+ // included in the result here.
+ for _, p := range parents {
+ err := walkGraph(result, seen, p, path)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func parentsContainingPath(path string, c *object.Commit) ([]*object.Commit, error) {
+ // TODO: benchmark this method making git.object.Commit.parent public instead of using
+ // an iterator
+ var result []*object.Commit
+ iter := c.Parents()
+ for {
+ parent, err := iter.Next()
+ if err == io.EOF {
+ return result, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ if _, err := parent.File(path); err == nil {
+ result = append(result, parent)
+ }
+ }
+}
+
+// Returns an slice of the commits in "cs" that has the file "path", but with different
+// contents than what can be found in "c".
+func differentContents(path string, c *object.Commit, cs []*object.Commit) ([]*object.Commit, error) {
+ result := make([]*object.Commit, 0, len(cs))
+ h, found := blobHash(path, c)
+ if !found {
+ return nil, object.ErrFileNotFound
+ }
+ for _, cx := range cs {
+ if hx, found := blobHash(path, cx); found && h != hx {
+ result = append(result, cx)
+ }
+ }
+ return result, nil
+}
+
+// blobHash returns the hash of a path in a commit
+func blobHash(path string, commit *object.Commit) (hash plumbing.Hash, found bool) {
+ file, err := commit.File(path)
+ if err != nil {
+ var empty plumbing.Hash
+ return empty, found
+ }
+ return file.Hash, true
+}
+
+type contentsComparatorFn func(path string, a, b *object.Commit) (bool, error)
+
+// Returns a new slice of commits, with duplicates removed. Expects a
+// sorted commit list. Duplication is defined according to "comp". It
+// will always keep the first commit of a series of duplicated commits.
+func removeComp(path string, cs []*object.Commit, comp contentsComparatorFn) ([]*object.Commit, error) {
+ result := make([]*object.Commit, 0, len(cs))
+ if len(cs) == 0 {
+ return result, nil
+ }
+ result = append(result, cs[0])
+ for i := 1; i < len(cs); i++ {
+ equals, err := comp(path, cs[i], cs[i-1])
+ if err != nil {
+ return nil, err
+ }
+ if !equals {
+ result = append(result, cs[i])
+ }
+ }
+ return result, nil
+}
+
+// Equivalent commits are commits whose patch is the same.
+func equivalent(path string, a, b *object.Commit) (bool, error) {
+ numParentsA := a.NumParents()
+ numParentsB := b.NumParents()
+
+ // the first commit is not equivalent to anyone
+ // and "I think" merges can not be equivalent to anything
+ if numParentsA != 1 || numParentsB != 1 {
+ return false, nil
+ }
+
+ diffsA, err := patch(a, path)
+ if err != nil {
+ return false, err
+ }
+ diffsB, err := patch(b, path)
+ if err != nil {
+ return false, err
+ }
+
+ return sameDiffs(diffsA, diffsB), nil
+}
+
+func patch(c *object.Commit, path string) ([]diffmatchpatch.Diff, error) {
+ // get contents of the file in the commit
+ file, err := c.File(path)
+ if err != nil {
+ return nil, err
+ }
+ content, err := file.Contents()
+ if err != nil {
+ return nil, err
+ }
+
+ // get contents of the file in the first parent of the commit
+ var contentParent string
+ iter := c.Parents()
+ parent, err := iter.Next()
+ if err != nil {
+ return nil, err
+ }
+ file, err = parent.File(path)
+ if err != nil {
+ contentParent = ""
+ } else {
+ contentParent, err = file.Contents()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // compare the contents of parent and child
+ return diff.Do(content, contentParent), nil
+}
+
+func sameDiffs(a, b []diffmatchpatch.Diff) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if !sameDiff(a[i], b[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+func sameDiff(a, b diffmatchpatch.Diff) bool {
+ if a.Type != b.Type {
+ return false
+ }
+ switch a.Type {
+ case 0:
+ return countLines(a.Text) == countLines(b.Text)
+ case 1, -1:
+ return a.Text == b.Text
+ default:
+ panic("unreachable")
+ }
+}
--- /dev/null
+package git
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/packfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/capability"
+ "gopkg.in/src-d/go-git.v4/plumbing/protocol/packp/sideband"
+ "gopkg.in/src-d/go-git.v4/plumbing/revlist"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport"
+ "gopkg.in/src-d/go-git.v4/plumbing/transport/client"
+ "gopkg.in/src-d/go-git.v4/storage"
+ "gopkg.in/src-d/go-git.v4/storage/memory"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+var (
+ NoErrAlreadyUpToDate = errors.New("already up-to-date")
+ ErrDeleteRefNotSupported = errors.New("server does not support delete-refs")
+ ErrForceNeeded = errors.New("some refs were not updated")
+)
+
+const (
+ // This describes the maximum number of commits to walk when
+ // computing the haves to send to a server, for each ref in the
+ // repo containing this remote, when not using the multi-ack
+ // protocol. Setting this to 0 means there is no limit.
+ maxHavesToVisitPerRef = 100
+)
+
+// Remote represents a connection to a remote repository.
+type Remote struct {
+ c *config.RemoteConfig
+ s storage.Storer
+}
+
+func newRemote(s storage.Storer, c *config.RemoteConfig) *Remote {
+ return &Remote{s: s, c: c}
+}
+
+// Config returns the RemoteConfig object used to instantiate this Remote.
+func (r *Remote) Config() *config.RemoteConfig {
+ return r.c
+}
+
+func (r *Remote) String() string {
+ var fetch, push string
+ if len(r.c.URLs) > 0 {
+ fetch = r.c.URLs[0]
+ push = r.c.URLs[0]
+ }
+
+ return fmt.Sprintf("%s\t%s (fetch)\n%[1]s\t%[3]s (push)", r.c.Name, fetch, push)
+}
+
+// Push performs a push to the remote. Returns NoErrAlreadyUpToDate if the
+// remote was already up-to-date.
+func (r *Remote) Push(o *PushOptions) error {
+ return r.PushContext(context.Background(), o)
+}
+
+// PushContext performs a push to the remote. Returns NoErrAlreadyUpToDate if
+// the remote was already up-to-date.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (r *Remote) PushContext(ctx context.Context, o *PushOptions) (err error) {
+ if err := o.Validate(); err != nil {
+ return err
+ }
+
+ if o.RemoteName != r.c.Name {
+ return fmt.Errorf("remote names don't match: %s != %s", o.RemoteName, r.c.Name)
+ }
+
+ s, err := newSendPackSession(r.c.URLs[0], o.Auth)
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(s, &err)
+
+ ar, err := s.AdvertisedReferences()
+ if err != nil {
+ return err
+ }
+
+ remoteRefs, err := ar.AllReferences()
+ if err != nil {
+ return err
+ }
+
+ isDelete := false
+ allDelete := true
+ for _, rs := range o.RefSpecs {
+ if rs.IsDelete() {
+ isDelete = true
+ } else {
+ allDelete = false
+ }
+ if isDelete && !allDelete {
+ break
+ }
+ }
+
+ if isDelete && !ar.Capabilities.Supports(capability.DeleteRefs) {
+ return ErrDeleteRefNotSupported
+ }
+
+ localRefs, err := r.references()
+ if err != nil {
+ return err
+ }
+
+ req, err := r.newReferenceUpdateRequest(o, localRefs, remoteRefs, ar)
+ if err != nil {
+ return err
+ }
+
+ if len(req.Commands) == 0 {
+ return NoErrAlreadyUpToDate
+ }
+
+ objects := objectsToPush(req.Commands)
+
+ haves, err := referencesToHashes(remoteRefs)
+ if err != nil {
+ return err
+ }
+
+ stop, err := r.s.Shallow()
+ if err != nil {
+ return err
+ }
+
+ // if we have shallow we should include this as part of the objects that
+ // we are aware.
+ haves = append(haves, stop...)
+
+ var hashesToPush []plumbing.Hash
+ // Avoid the expensive revlist operation if we're only doing deletes.
+ if !allDelete {
+ hashesToPush, err = revlist.Objects(r.s, objects, haves)
+ if err != nil {
+ return err
+ }
+ }
+
+ rs, err := pushHashes(ctx, s, r.s, req, hashesToPush, r.useRefDeltas(ar))
+ if err != nil {
+ return err
+ }
+
+ if err = rs.Error(); err != nil {
+ return err
+ }
+
+ return r.updateRemoteReferenceStorage(req, rs)
+}
+
+func (r *Remote) useRefDeltas(ar *packp.AdvRefs) bool {
+ return !ar.Capabilities.Supports(capability.OFSDelta)
+}
+
+func (r *Remote) newReferenceUpdateRequest(
+ o *PushOptions,
+ localRefs []*plumbing.Reference,
+ remoteRefs storer.ReferenceStorer,
+ ar *packp.AdvRefs,
+) (*packp.ReferenceUpdateRequest, error) {
+ req := packp.NewReferenceUpdateRequestFromCapabilities(ar.Capabilities)
+
+ if o.Progress != nil {
+ req.Progress = o.Progress
+ if ar.Capabilities.Supports(capability.Sideband64k) {
+ req.Capabilities.Set(capability.Sideband64k)
+ } else if ar.Capabilities.Supports(capability.Sideband) {
+ req.Capabilities.Set(capability.Sideband)
+ }
+ }
+
+ if err := r.addReferencesToUpdate(o.RefSpecs, localRefs, remoteRefs, req); err != nil {
+ return nil, err
+ }
+
+ return req, nil
+}
+
+func (r *Remote) updateRemoteReferenceStorage(
+ req *packp.ReferenceUpdateRequest,
+ result *packp.ReportStatus,
+) error {
+
+ for _, spec := range r.c.Fetch {
+ for _, c := range req.Commands {
+ if !spec.Match(c.Name) {
+ continue
+ }
+
+ local := spec.Dst(c.Name)
+ ref := plumbing.NewHashReference(local, c.New)
+ switch c.Action() {
+ case packp.Create, packp.Update:
+ if err := r.s.SetReference(ref); err != nil {
+ return err
+ }
+ case packp.Delete:
+ if err := r.s.RemoveReference(local); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+// FetchContext fetches references along with the objects necessary to complete
+// their histories.
+//
+// Returns nil if the operation is successful, NoErrAlreadyUpToDate if there are
+// no changes to be fetched, or an error.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (r *Remote) FetchContext(ctx context.Context, o *FetchOptions) error {
+ _, err := r.fetch(ctx, o)
+ return err
+}
+
+// Fetch fetches references along with the objects necessary to complete their
+// histories.
+//
+// Returns nil if the operation is successful, NoErrAlreadyUpToDate if there are
+// no changes to be fetched, or an error.
+func (r *Remote) Fetch(o *FetchOptions) error {
+ return r.FetchContext(context.Background(), o)
+}
+
+func (r *Remote) fetch(ctx context.Context, o *FetchOptions) (sto storer.ReferenceStorer, err error) {
+ if o.RemoteName == "" {
+ o.RemoteName = r.c.Name
+ }
+
+ if err = o.Validate(); err != nil {
+ return nil, err
+ }
+
+ if len(o.RefSpecs) == 0 {
+ o.RefSpecs = r.c.Fetch
+ }
+
+ s, err := newUploadPackSession(r.c.URLs[0], o.Auth)
+ if err != nil {
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(s, &err)
+
+ ar, err := s.AdvertisedReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := r.newUploadPackRequest(o, ar)
+ if err != nil {
+ return nil, err
+ }
+
+ remoteRefs, err := ar.AllReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ localRefs, err := r.references()
+ if err != nil {
+ return nil, err
+ }
+
+ refs, err := calculateRefs(o.RefSpecs, remoteRefs, o.Tags)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Wants, err = getWants(r.s, refs)
+ if len(req.Wants) > 0 {
+ req.Haves, err = getHaves(localRefs, remoteRefs, r.s)
+ if err != nil {
+ return nil, err
+ }
+
+ if err = r.fetchPack(ctx, o, s, req); err != nil {
+ return nil, err
+ }
+ }
+
+ updated, err := r.updateLocalReferenceStorage(o.RefSpecs, refs, remoteRefs, o.Tags, o.Force)
+ if err != nil {
+ return nil, err
+ }
+
+ if !updated {
+ return remoteRefs, NoErrAlreadyUpToDate
+ }
+
+ return remoteRefs, nil
+}
+
+func newUploadPackSession(url string, auth transport.AuthMethod) (transport.UploadPackSession, error) {
+ c, ep, err := newClient(url)
+ if err != nil {
+ return nil, err
+ }
+
+ return c.NewUploadPackSession(ep, auth)
+}
+
+func newSendPackSession(url string, auth transport.AuthMethod) (transport.ReceivePackSession, error) {
+ c, ep, err := newClient(url)
+ if err != nil {
+ return nil, err
+ }
+
+ return c.NewReceivePackSession(ep, auth)
+}
+
+func newClient(url string) (transport.Transport, *transport.Endpoint, error) {
+ ep, err := transport.NewEndpoint(url)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ c, err := client.NewClient(ep)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return c, ep, err
+}
+
+func (r *Remote) fetchPack(ctx context.Context, o *FetchOptions, s transport.UploadPackSession,
+ req *packp.UploadPackRequest) (err error) {
+
+ reader, err := s.UploadPack(ctx, req)
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(reader, &err)
+
+ if err = r.updateShallow(o, reader); err != nil {
+ return err
+ }
+
+ if err = packfile.UpdateObjectStorage(r.s,
+ buildSidebandIfSupported(req.Capabilities, reader, o.Progress),
+ ); err != nil {
+ return err
+ }
+
+ return err
+}
+
+func (r *Remote) addReferencesToUpdate(
+ refspecs []config.RefSpec,
+ localRefs []*plumbing.Reference,
+ remoteRefs storer.ReferenceStorer,
+ req *packp.ReferenceUpdateRequest,
+) error {
+ // This references dictionary will be used to search references by name.
+ refsDict := make(map[string]*plumbing.Reference)
+ for _, ref := range localRefs {
+ refsDict[ref.Name().String()] = ref
+ }
+
+ for _, rs := range refspecs {
+ if rs.IsDelete() {
+ if err := r.deleteReferences(rs, remoteRefs, req); err != nil {
+ return err
+ }
+ } else {
+ err := r.addOrUpdateReferences(rs, localRefs, refsDict, remoteRefs, req)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (r *Remote) addOrUpdateReferences(
+ rs config.RefSpec,
+ localRefs []*plumbing.Reference,
+ refsDict map[string]*plumbing.Reference,
+ remoteRefs storer.ReferenceStorer,
+ req *packp.ReferenceUpdateRequest,
+) error {
+ // If it is not a wilcard refspec we can directly search for the reference
+ // in the references dictionary.
+ if !rs.IsWildcard() {
+ ref, ok := refsDict[rs.Src()]
+ if !ok {
+ return nil
+ }
+
+ return r.addReferenceIfRefSpecMatches(rs, remoteRefs, ref, req)
+ }
+
+ for _, ref := range localRefs {
+ err := r.addReferenceIfRefSpecMatches(rs, remoteRefs, ref, req)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (r *Remote) deleteReferences(rs config.RefSpec,
+ remoteRefs storer.ReferenceStorer, req *packp.ReferenceUpdateRequest) error {
+ iter, err := remoteRefs.IterReferences()
+ if err != nil {
+ return err
+ }
+
+ return iter.ForEach(func(ref *plumbing.Reference) error {
+ if ref.Type() != plumbing.HashReference {
+ return nil
+ }
+
+ if rs.Dst("") != ref.Name() {
+ return nil
+ }
+
+ cmd := &packp.Command{
+ Name: ref.Name(),
+ Old: ref.Hash(),
+ New: plumbing.ZeroHash,
+ }
+ req.Commands = append(req.Commands, cmd)
+ return nil
+ })
+}
+
+func (r *Remote) addReferenceIfRefSpecMatches(rs config.RefSpec,
+ remoteRefs storer.ReferenceStorer, localRef *plumbing.Reference,
+ req *packp.ReferenceUpdateRequest) error {
+
+ if localRef.Type() != plumbing.HashReference {
+ return nil
+ }
+
+ if !rs.Match(localRef.Name()) {
+ return nil
+ }
+
+ cmd := &packp.Command{
+ Name: rs.Dst(localRef.Name()),
+ Old: plumbing.ZeroHash,
+ New: localRef.Hash(),
+ }
+
+ remoteRef, err := remoteRefs.Reference(cmd.Name)
+ if err == nil {
+ if remoteRef.Type() != plumbing.HashReference {
+ //TODO: check actual git behavior here
+ return nil
+ }
+
+ cmd.Old = remoteRef.Hash()
+ } else if err != plumbing.ErrReferenceNotFound {
+ return err
+ }
+
+ if cmd.Old == cmd.New {
+ return nil
+ }
+
+ if !rs.IsForceUpdate() {
+ if err := checkFastForwardUpdate(r.s, remoteRefs, cmd); err != nil {
+ return err
+ }
+ }
+
+ req.Commands = append(req.Commands, cmd)
+ return nil
+}
+
+func (r *Remote) references() ([]*plumbing.Reference, error) {
+ var localRefs []*plumbing.Reference
+ iter, err := r.s.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ for {
+ ref, err := iter.Next()
+ if err == io.EOF {
+ break
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ localRefs = append(localRefs, ref)
+ }
+
+ return localRefs, nil
+}
+
+func getRemoteRefsFromStorer(remoteRefStorer storer.ReferenceStorer) (
+ map[plumbing.Hash]bool, error) {
+ remoteRefs := map[plumbing.Hash]bool{}
+ iter, err := remoteRefStorer.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+ err = iter.ForEach(func(ref *plumbing.Reference) error {
+ if ref.Type() != plumbing.HashReference {
+ return nil
+ }
+ remoteRefs[ref.Hash()] = true
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return remoteRefs, nil
+}
+
+// getHavesFromRef populates the given `haves` map with the given
+// reference, and up to `maxHavesToVisitPerRef` ancestor commits.
+func getHavesFromRef(
+ ref *plumbing.Reference,
+ remoteRefs map[plumbing.Hash]bool,
+ s storage.Storer,
+ haves map[plumbing.Hash]bool,
+) error {
+ h := ref.Hash()
+ if haves[h] {
+ return nil
+ }
+
+ // No need to load the commit if we know the remote already
+ // has this hash.
+ if remoteRefs[h] {
+ haves[h] = true
+ return nil
+ }
+
+ commit, err := object.GetCommit(s, h)
+ if err != nil {
+ // Ignore the error if this isn't a commit.
+ haves[ref.Hash()] = true
+ return nil
+ }
+
+ // Until go-git supports proper commit negotiation during an
+ // upload pack request, include up to `maxHavesToVisitPerRef`
+ // commits from the history of each ref.
+ walker := object.NewCommitPreorderIter(commit, haves, nil)
+ toVisit := maxHavesToVisitPerRef
+ return walker.ForEach(func(c *object.Commit) error {
+ haves[c.Hash] = true
+ toVisit--
+ // If toVisit starts out at 0 (indicating there is no
+ // max), then it will be negative here and we won't stop
+ // early.
+ if toVisit == 0 || remoteRefs[c.Hash] {
+ return storer.ErrStop
+ }
+ return nil
+ })
+}
+
+func getHaves(
+ localRefs []*plumbing.Reference,
+ remoteRefStorer storer.ReferenceStorer,
+ s storage.Storer,
+) ([]plumbing.Hash, error) {
+ haves := map[plumbing.Hash]bool{}
+
+ // Build a map of all the remote references, to avoid loading too
+ // many parent commits for references we know don't need to be
+ // transferred.
+ remoteRefs, err := getRemoteRefsFromStorer(remoteRefStorer)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ref := range localRefs {
+ if haves[ref.Hash()] {
+ continue
+ }
+
+ if ref.Type() != plumbing.HashReference {
+ continue
+ }
+
+ err = getHavesFromRef(ref, remoteRefs, s, haves)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ var result []plumbing.Hash
+ for h := range haves {
+ result = append(result, h)
+ }
+
+ return result, nil
+}
+
+const refspecAllTags = "+refs/tags/*:refs/tags/*"
+
+func calculateRefs(
+ spec []config.RefSpec,
+ remoteRefs storer.ReferenceStorer,
+ tagMode TagMode,
+) (memory.ReferenceStorage, error) {
+ if tagMode == AllTags {
+ spec = append(spec, refspecAllTags)
+ }
+
+ refs := make(memory.ReferenceStorage)
+ for _, s := range spec {
+ if err := doCalculateRefs(s, remoteRefs, refs); err != nil {
+ return nil, err
+ }
+ }
+
+ return refs, nil
+}
+
+func doCalculateRefs(
+ s config.RefSpec,
+ remoteRefs storer.ReferenceStorer,
+ refs memory.ReferenceStorage,
+) error {
+ iter, err := remoteRefs.IterReferences()
+ if err != nil {
+ return err
+ }
+
+ var matched bool
+ err = iter.ForEach(func(ref *plumbing.Reference) error {
+ if !s.Match(ref.Name()) {
+ return nil
+ }
+
+ if ref.Type() == plumbing.SymbolicReference {
+ target, err := storer.ResolveReference(remoteRefs, ref.Name())
+ if err != nil {
+ return err
+ }
+
+ ref = plumbing.NewHashReference(ref.Name(), target.Hash())
+ }
+
+ if ref.Type() != plumbing.HashReference {
+ return nil
+ }
+
+ matched = true
+ if err := refs.SetReference(ref); err != nil {
+ return err
+ }
+
+ if !s.IsWildcard() {
+ return storer.ErrStop
+ }
+
+ return nil
+ })
+
+ if !matched && !s.IsWildcard() {
+ return fmt.Errorf("couldn't find remote ref %q", s.Src())
+ }
+
+ return err
+}
+
+func getWants(localStorer storage.Storer, refs memory.ReferenceStorage) ([]plumbing.Hash, error) {
+ wants := map[plumbing.Hash]bool{}
+ for _, ref := range refs {
+ hash := ref.Hash()
+ exists, err := objectExists(localStorer, ref.Hash())
+ if err != nil {
+ return nil, err
+ }
+
+ if !exists {
+ wants[hash] = true
+ }
+ }
+
+ var result []plumbing.Hash
+ for h := range wants {
+ result = append(result, h)
+ }
+
+ return result, nil
+}
+
+func objectExists(s storer.EncodedObjectStorer, h plumbing.Hash) (bool, error) {
+ _, err := s.EncodedObject(plumbing.AnyObject, h)
+ if err == plumbing.ErrObjectNotFound {
+ return false, nil
+ }
+
+ return true, err
+}
+
+func checkFastForwardUpdate(s storer.EncodedObjectStorer, remoteRefs storer.ReferenceStorer, cmd *packp.Command) error {
+ if cmd.Old == plumbing.ZeroHash {
+ _, err := remoteRefs.Reference(cmd.Name)
+ if err == plumbing.ErrReferenceNotFound {
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+
+ return fmt.Errorf("non-fast-forward update: %s", cmd.Name.String())
+ }
+
+ ff, err := isFastForward(s, cmd.Old, cmd.New)
+ if err != nil {
+ return err
+ }
+
+ if !ff {
+ return fmt.Errorf("non-fast-forward update: %s", cmd.Name.String())
+ }
+
+ return nil
+}
+
+func isFastForward(s storer.EncodedObjectStorer, old, new plumbing.Hash) (bool, error) {
+ c, err := object.GetCommit(s, new)
+ if err != nil {
+ return false, err
+ }
+
+ found := false
+ iter := object.NewCommitPreorderIter(c, nil, nil)
+ err = iter.ForEach(func(c *object.Commit) error {
+ if c.Hash != old {
+ return nil
+ }
+
+ found = true
+ return storer.ErrStop
+ })
+ return found, err
+}
+
+func (r *Remote) newUploadPackRequest(o *FetchOptions,
+ ar *packp.AdvRefs) (*packp.UploadPackRequest, error) {
+
+ req := packp.NewUploadPackRequestFromCapabilities(ar.Capabilities)
+
+ if o.Depth != 0 {
+ req.Depth = packp.DepthCommits(o.Depth)
+ if err := req.Capabilities.Set(capability.Shallow); err != nil {
+ return nil, err
+ }
+ }
+
+ if o.Progress == nil && ar.Capabilities.Supports(capability.NoProgress) {
+ if err := req.Capabilities.Set(capability.NoProgress); err != nil {
+ return nil, err
+ }
+ }
+
+ isWildcard := true
+ for _, s := range o.RefSpecs {
+ if !s.IsWildcard() {
+ isWildcard = false
+ break
+ }
+ }
+
+ if isWildcard && o.Tags == TagFollowing && ar.Capabilities.Supports(capability.IncludeTag) {
+ if err := req.Capabilities.Set(capability.IncludeTag); err != nil {
+ return nil, err
+ }
+ }
+
+ return req, nil
+}
+
+func buildSidebandIfSupported(l *capability.List, reader io.Reader, p sideband.Progress) io.Reader {
+ var t sideband.Type
+
+ switch {
+ case l.Supports(capability.Sideband):
+ t = sideband.Sideband
+ case l.Supports(capability.Sideband64k):
+ t = sideband.Sideband64k
+ default:
+ return reader
+ }
+
+ d := sideband.NewDemuxer(t, reader)
+ d.Progress = p
+
+ return d
+}
+
+func (r *Remote) updateLocalReferenceStorage(
+ specs []config.RefSpec,
+ fetchedRefs, remoteRefs memory.ReferenceStorage,
+ tagMode TagMode,
+ force bool,
+) (updated bool, err error) {
+ isWildcard := true
+ forceNeeded := false
+
+ for _, spec := range specs {
+ if !spec.IsWildcard() {
+ isWildcard = false
+ }
+
+ for _, ref := range fetchedRefs {
+ if !spec.Match(ref.Name()) {
+ continue
+ }
+
+ if ref.Type() != plumbing.HashReference {
+ continue
+ }
+
+ localName := spec.Dst(ref.Name())
+ old, _ := storer.ResolveReference(r.s, localName)
+ new := plumbing.NewHashReference(localName, ref.Hash())
+
+ // If the ref exists locally as a branch and force is not specified,
+ // only update if the new ref is an ancestor of the old
+ if old != nil && old.Name().IsBranch() && !force && !spec.IsForceUpdate() {
+ ff, err := isFastForward(r.s, old.Hash(), new.Hash())
+ if err != nil {
+ return updated, err
+ }
+
+ if !ff {
+ forceNeeded = true
+ continue
+ }
+ }
+
+ refUpdated, err := checkAndUpdateReferenceStorerIfNeeded(r.s, new, old)
+ if err != nil {
+ return updated, err
+ }
+
+ if refUpdated {
+ updated = true
+ }
+ }
+ }
+
+ if tagMode == NoTags {
+ return updated, nil
+ }
+
+ tags := fetchedRefs
+ if isWildcard {
+ tags = remoteRefs
+ }
+ tagUpdated, err := r.buildFetchedTags(tags)
+ if err != nil {
+ return updated, err
+ }
+
+ if tagUpdated {
+ updated = true
+ }
+
+ if err == nil && forceNeeded {
+ err = ErrForceNeeded
+ }
+
+ return
+}
+
+func (r *Remote) buildFetchedTags(refs memory.ReferenceStorage) (updated bool, err error) {
+ for _, ref := range refs {
+ if !ref.Name().IsTag() {
+ continue
+ }
+
+ _, err := r.s.EncodedObject(plumbing.AnyObject, ref.Hash())
+ if err == plumbing.ErrObjectNotFound {
+ continue
+ }
+
+ if err != nil {
+ return false, err
+ }
+
+ refUpdated, err := updateReferenceStorerIfNeeded(r.s, ref)
+ if err != nil {
+ return updated, err
+ }
+
+ if refUpdated {
+ updated = true
+ }
+ }
+
+ return
+}
+
+// List the references on the remote repository.
+func (r *Remote) List(o *ListOptions) (rfs []*plumbing.Reference, err error) {
+ s, err := newUploadPackSession(r.c.URLs[0], o.Auth)
+ if err != nil {
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(s, &err)
+
+ ar, err := s.AdvertisedReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ allRefs, err := ar.AllReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ refs, err := allRefs.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ var resultRefs []*plumbing.Reference
+ refs.ForEach(func(ref *plumbing.Reference) error {
+ resultRefs = append(resultRefs, ref)
+ return nil
+ })
+
+ return resultRefs, nil
+}
+
+func objectsToPush(commands []*packp.Command) []plumbing.Hash {
+ var objects []plumbing.Hash
+ for _, cmd := range commands {
+ if cmd.New == plumbing.ZeroHash {
+ continue
+ }
+
+ objects = append(objects, cmd.New)
+ }
+ return objects
+}
+
+func referencesToHashes(refs storer.ReferenceStorer) ([]plumbing.Hash, error) {
+ iter, err := refs.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ var hs []plumbing.Hash
+ err = iter.ForEach(func(ref *plumbing.Reference) error {
+ if ref.Type() != plumbing.HashReference {
+ return nil
+ }
+
+ hs = append(hs, ref.Hash())
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return hs, nil
+}
+
+func pushHashes(
+ ctx context.Context,
+ sess transport.ReceivePackSession,
+ s storage.Storer,
+ req *packp.ReferenceUpdateRequest,
+ hs []plumbing.Hash,
+ useRefDeltas bool,
+) (*packp.ReportStatus, error) {
+
+ rd, wr := io.Pipe()
+ req.Packfile = rd
+ config, err := s.Config()
+ if err != nil {
+ return nil, err
+ }
+ done := make(chan error)
+ go func() {
+ e := packfile.NewEncoder(wr, s, useRefDeltas)
+ if _, err := e.Encode(hs, config.Pack.Window); err != nil {
+ done <- wr.CloseWithError(err)
+ return
+ }
+
+ done <- wr.Close()
+ }()
+
+ rs, err := sess.ReceivePack(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := <-done; err != nil {
+ return nil, err
+ }
+
+ return rs, nil
+}
+
+func (r *Remote) updateShallow(o *FetchOptions, resp *packp.UploadPackResponse) error {
+ if o.Depth == 0 || len(resp.Shallows) == 0 {
+ return nil
+ }
+
+ shallows, err := r.s.Shallow()
+ if err != nil {
+ return err
+ }
+
+outer:
+ for _, s := range resp.Shallows {
+ for _, oldS := range shallows {
+ if s == oldS {
+ continue outer
+ }
+ }
+ shallows = append(shallows, s)
+ }
+
+ return r.s.SetShallow(shallows)
+}
--- /dev/null
+package git
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ stdioutil "io/ioutil"
+ "os"
+ "path"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/openpgp"
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/internal/revision"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/packfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/storage"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-billy.v4/osfs"
+)
+
+// GitDirName this is a special folder where all the git stuff is.
+const GitDirName = ".git"
+
+var (
+ // ErrBranchExists an error stating the specified branch already exists
+ ErrBranchExists = errors.New("branch already exists")
+ // ErrBranchNotFound an error stating the specified branch does not exist
+ ErrBranchNotFound = errors.New("branch not found")
+ // ErrTagExists an error stating the specified tag already exists
+ ErrTagExists = errors.New("tag already exists")
+ // ErrTagNotFound an error stating the specified tag does not exist
+ ErrTagNotFound = errors.New("tag not found")
+
+ ErrInvalidReference = errors.New("invalid reference, should be a tag or a branch")
+ ErrRepositoryNotExists = errors.New("repository does not exist")
+ ErrRepositoryAlreadyExists = errors.New("repository already exists")
+ ErrRemoteNotFound = errors.New("remote not found")
+ ErrRemoteExists = errors.New("remote already exists")
+ ErrWorktreeNotProvided = errors.New("worktree should be provided")
+ ErrIsBareRepository = errors.New("worktree not available in a bare repository")
+ ErrUnableToResolveCommit = errors.New("unable to resolve commit")
+ ErrPackedObjectsNotSupported = errors.New("Packed objects not supported")
+)
+
+// Repository represents a git repository
+type Repository struct {
+ Storer storage.Storer
+
+ r map[string]*Remote
+ wt billy.Filesystem
+}
+
+// Init creates an empty git repository, based on the given Storer and worktree.
+// The worktree Filesystem is optional, if nil a bare repository is created. If
+// the given storer is not empty ErrRepositoryAlreadyExists is returned
+func Init(s storage.Storer, worktree billy.Filesystem) (*Repository, error) {
+ if err := initStorer(s); err != nil {
+ return nil, err
+ }
+
+ r := newRepository(s, worktree)
+ _, err := r.Reference(plumbing.HEAD, false)
+ switch err {
+ case plumbing.ErrReferenceNotFound:
+ case nil:
+ return nil, ErrRepositoryAlreadyExists
+ default:
+ return nil, err
+ }
+
+ h := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.Master)
+ if err := s.SetReference(h); err != nil {
+ return nil, err
+ }
+
+ if worktree == nil {
+ r.setIsBare(true)
+ return r, nil
+ }
+
+ return r, setWorktreeAndStoragePaths(r, worktree)
+}
+
+func initStorer(s storer.Storer) error {
+ i, ok := s.(storer.Initializer)
+ if !ok {
+ return nil
+ }
+
+ return i.Init()
+}
+
+func setWorktreeAndStoragePaths(r *Repository, worktree billy.Filesystem) error {
+ type fsBased interface {
+ Filesystem() billy.Filesystem
+ }
+
+ // .git file is only created if the storage is file based and the file
+ // system is osfs.OS
+ fs, isFSBased := r.Storer.(fsBased)
+ if !isFSBased {
+ return nil
+ }
+
+ if err := createDotGitFile(worktree, fs.Filesystem()); err != nil {
+ return err
+ }
+
+ return setConfigWorktree(r, worktree, fs.Filesystem())
+}
+
+func createDotGitFile(worktree, storage billy.Filesystem) error {
+ path, err := filepath.Rel(worktree.Root(), storage.Root())
+ if err != nil {
+ path = storage.Root()
+ }
+
+ if path == GitDirName {
+ // not needed, since the folder is the default place
+ return nil
+ }
+
+ f, err := worktree.Create(GitDirName)
+ if err != nil {
+ return err
+ }
+
+ defer f.Close()
+ _, err = fmt.Fprintf(f, "gitdir: %s\n", path)
+ return err
+}
+
+func setConfigWorktree(r *Repository, worktree, storage billy.Filesystem) error {
+ path, err := filepath.Rel(storage.Root(), worktree.Root())
+ if err != nil {
+ path = worktree.Root()
+ }
+
+ if path == ".." {
+ // not needed, since the folder is the default place
+ return nil
+ }
+
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ cfg.Core.Worktree = path
+ return r.Storer.SetConfig(cfg)
+}
+
+// Open opens a git repository using the given Storer and worktree filesystem,
+// if the given storer is complete empty ErrRepositoryNotExists is returned.
+// The worktree can be nil when the repository being opened is bare, if the
+// repository is a normal one (not bare) and worktree is nil the err
+// ErrWorktreeNotProvided is returned
+func Open(s storage.Storer, worktree billy.Filesystem) (*Repository, error) {
+ _, err := s.Reference(plumbing.HEAD)
+ if err == plumbing.ErrReferenceNotFound {
+ return nil, ErrRepositoryNotExists
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return newRepository(s, worktree), nil
+}
+
+// Clone a repository into the given Storer and worktree Filesystem with the
+// given options, if worktree is nil a bare repository is created. If the given
+// storer is not empty ErrRepositoryAlreadyExists is returned.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func Clone(s storage.Storer, worktree billy.Filesystem, o *CloneOptions) (*Repository, error) {
+ return CloneContext(context.Background(), s, worktree, o)
+}
+
+// CloneContext a repository into the given Storer and worktree Filesystem with
+// the given options, if worktree is nil a bare repository is created. If the
+// given storer is not empty ErrRepositoryAlreadyExists is returned.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func CloneContext(
+ ctx context.Context, s storage.Storer, worktree billy.Filesystem, o *CloneOptions,
+) (*Repository, error) {
+ r, err := Init(s, worktree)
+ if err != nil {
+ return nil, err
+ }
+
+ return r, r.clone(ctx, o)
+}
+
+// PlainInit create an empty git repository at the given path. isBare defines
+// if the repository will have worktree (non-bare) or not (bare), if the path
+// is not empty ErrRepositoryAlreadyExists is returned.
+func PlainInit(path string, isBare bool) (*Repository, error) {
+ var wt, dot billy.Filesystem
+
+ if isBare {
+ dot = osfs.New(path)
+ } else {
+ wt = osfs.New(path)
+ dot, _ = wt.Chroot(GitDirName)
+ }
+
+ s := filesystem.NewStorage(dot, cache.NewObjectLRUDefault())
+
+ return Init(s, wt)
+}
+
+// PlainOpen opens a git repository from the given path. It detects if the
+// repository is bare or a normal one. If the path doesn't contain a valid
+// repository ErrRepositoryNotExists is returned
+func PlainOpen(path string) (*Repository, error) {
+ return PlainOpenWithOptions(path, &PlainOpenOptions{})
+}
+
+// PlainOpenWithOptions opens a git repository from the given path with specific
+// options. See PlainOpen for more info.
+func PlainOpenWithOptions(path string, o *PlainOpenOptions) (*Repository, error) {
+ dot, wt, err := dotGitToOSFilesystems(path, o.DetectDotGit)
+ if err != nil {
+ return nil, err
+ }
+
+ if _, err := dot.Stat(""); err != nil {
+ if os.IsNotExist(err) {
+ return nil, ErrRepositoryNotExists
+ }
+
+ return nil, err
+ }
+
+ s := filesystem.NewStorage(dot, cache.NewObjectLRUDefault())
+
+ return Open(s, wt)
+}
+
+func dotGitToOSFilesystems(path string, detect bool) (dot, wt billy.Filesystem, err error) {
+ if path, err = filepath.Abs(path); err != nil {
+ return nil, nil, err
+ }
+ var fs billy.Filesystem
+ var fi os.FileInfo
+ for {
+ fs = osfs.New(path)
+ fi, err = fs.Stat(GitDirName)
+ if err == nil {
+ // no error; stop
+ break
+ }
+ if !os.IsNotExist(err) {
+ // unknown error; stop
+ return nil, nil, err
+ }
+ if detect {
+ // try its parent as long as we haven't reached
+ // the root dir
+ if dir := filepath.Dir(path); dir != path {
+ path = dir
+ continue
+ }
+ }
+ // not detecting via parent dirs and the dir does not exist;
+ // stop
+ return fs, nil, nil
+ }
+
+ if fi.IsDir() {
+ dot, err = fs.Chroot(GitDirName)
+ return dot, fs, err
+ }
+
+ dot, err = dotGitFileToOSFilesystem(path, fs)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return dot, fs, nil
+}
+
+func dotGitFileToOSFilesystem(path string, fs billy.Filesystem) (bfs billy.Filesystem, err error) {
+ f, err := fs.Open(GitDirName)
+ if err != nil {
+ return nil, err
+ }
+ defer ioutil.CheckClose(f, &err)
+
+ b, err := stdioutil.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ line := string(b)
+ const prefix = "gitdir: "
+ if !strings.HasPrefix(line, prefix) {
+ return nil, fmt.Errorf(".git file has no %s prefix", prefix)
+ }
+
+ gitdir := strings.Split(line[len(prefix):], "\n")[0]
+ gitdir = strings.TrimSpace(gitdir)
+ if filepath.IsAbs(gitdir) {
+ return osfs.New(gitdir), nil
+ }
+
+ return osfs.New(fs.Join(path, gitdir)), nil
+}
+
+// PlainClone a repository into the path with the given options, isBare defines
+// if the new repository will be bare or normal. If the path is not empty
+// ErrRepositoryAlreadyExists is returned.
+//
+// TODO(mcuadros): move isBare to CloneOptions in v5
+func PlainClone(path string, isBare bool, o *CloneOptions) (*Repository, error) {
+ return PlainCloneContext(context.Background(), path, isBare, o)
+}
+
+// PlainCloneContext a repository into the path with the given options, isBare
+// defines if the new repository will be bare or normal. If the path is not empty
+// ErrRepositoryAlreadyExists is returned.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+//
+// TODO(mcuadros): move isBare to CloneOptions in v5
+func PlainCloneContext(ctx context.Context, path string, isBare bool, o *CloneOptions) (*Repository, error) {
+ dirExists, err := checkExistsAndIsEmptyDir(path)
+ if err != nil {
+ return nil, err
+ }
+
+ r, err := PlainInit(path, isBare)
+ if err != nil {
+ return nil, err
+ }
+
+ err = r.clone(ctx, o)
+ if err != nil && err != ErrRepositoryAlreadyExists {
+ cleanUpDir(path, !dirExists)
+ }
+
+ return r, err
+}
+
+func newRepository(s storage.Storer, worktree billy.Filesystem) *Repository {
+ return &Repository{
+ Storer: s,
+ wt: worktree,
+ r: make(map[string]*Remote),
+ }
+}
+
+func checkExistsAndIsEmptyDir(path string) (exists bool, err error) {
+ fi, err := os.Stat(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+
+ return false, err
+ }
+
+ if !fi.IsDir() {
+ return false, fmt.Errorf("path is not a directory: %s", path)
+ }
+
+ f, err := os.Open(path)
+ if err != nil {
+ return false, err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ _, err = f.Readdirnames(1)
+ if err == io.EOF {
+ return true, nil
+ }
+
+ if err != nil {
+ return true, err
+ }
+
+ return true, fmt.Errorf("directory is not empty: %s", path)
+}
+
+func cleanUpDir(path string, all bool) error {
+ if all {
+ return os.RemoveAll(path)
+ }
+
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ names, err := f.Readdirnames(-1)
+ if err != nil {
+ return err
+ }
+
+ for _, name := range names {
+ if err := os.RemoveAll(filepath.Join(path, name)); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Config return the repository config
+func (r *Repository) Config() (*config.Config, error) {
+ return r.Storer.Config()
+}
+
+// Remote return a remote if exists
+func (r *Repository) Remote(name string) (*Remote, error) {
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return nil, err
+ }
+
+ c, ok := cfg.Remotes[name]
+ if !ok {
+ return nil, ErrRemoteNotFound
+ }
+
+ return newRemote(r.Storer, c), nil
+}
+
+// Remotes returns a list with all the remotes
+func (r *Repository) Remotes() ([]*Remote, error) {
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return nil, err
+ }
+
+ remotes := make([]*Remote, len(cfg.Remotes))
+
+ var i int
+ for _, c := range cfg.Remotes {
+ remotes[i] = newRemote(r.Storer, c)
+ i++
+ }
+
+ return remotes, nil
+}
+
+// CreateRemote creates a new remote
+func (r *Repository) CreateRemote(c *config.RemoteConfig) (*Remote, error) {
+ if err := c.Validate(); err != nil {
+ return nil, err
+ }
+
+ remote := newRemote(r.Storer, c)
+
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return nil, err
+ }
+
+ if _, ok := cfg.Remotes[c.Name]; ok {
+ return nil, ErrRemoteExists
+ }
+
+ cfg.Remotes[c.Name] = c
+ return remote, r.Storer.SetConfig(cfg)
+}
+
+// DeleteRemote delete a remote from the repository and delete the config
+func (r *Repository) DeleteRemote(name string) error {
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ if _, ok := cfg.Remotes[name]; !ok {
+ return ErrRemoteNotFound
+ }
+
+ delete(cfg.Remotes, name)
+ return r.Storer.SetConfig(cfg)
+}
+
+// Branch return a Branch if exists
+func (r *Repository) Branch(name string) (*config.Branch, error) {
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return nil, err
+ }
+
+ b, ok := cfg.Branches[name]
+ if !ok {
+ return nil, ErrBranchNotFound
+ }
+
+ return b, nil
+}
+
+// CreateBranch creates a new Branch
+func (r *Repository) CreateBranch(c *config.Branch) error {
+ if err := c.Validate(); err != nil {
+ return err
+ }
+
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ if _, ok := cfg.Branches[c.Name]; ok {
+ return ErrBranchExists
+ }
+
+ cfg.Branches[c.Name] = c
+ return r.Storer.SetConfig(cfg)
+}
+
+// DeleteBranch delete a Branch from the repository and delete the config
+func (r *Repository) DeleteBranch(name string) error {
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ if _, ok := cfg.Branches[name]; !ok {
+ return ErrBranchNotFound
+ }
+
+ delete(cfg.Branches, name)
+ return r.Storer.SetConfig(cfg)
+}
+
+// CreateTag creates a tag. If opts is included, the tag is an annotated tag,
+// otherwise a lightweight tag is created.
+func (r *Repository) CreateTag(name string, hash plumbing.Hash, opts *CreateTagOptions) (*plumbing.Reference, error) {
+ rname := plumbing.ReferenceName(path.Join("refs", "tags", name))
+
+ _, err := r.Storer.Reference(rname)
+ switch err {
+ case nil:
+ // Tag exists, this is an error
+ return nil, ErrTagExists
+ case plumbing.ErrReferenceNotFound:
+ // Tag missing, available for creation, pass this
+ default:
+ // Some other error
+ return nil, err
+ }
+
+ var target plumbing.Hash
+ if opts != nil {
+ target, err = r.createTagObject(name, hash, opts)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ target = hash
+ }
+
+ ref := plumbing.NewHashReference(rname, target)
+ if err = r.Storer.SetReference(ref); err != nil {
+ return nil, err
+ }
+
+ return ref, nil
+}
+
+func (r *Repository) createTagObject(name string, hash plumbing.Hash, opts *CreateTagOptions) (plumbing.Hash, error) {
+ if err := opts.Validate(r, hash); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ rawobj, err := object.GetObject(r.Storer, hash)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ tag := &object.Tag{
+ Name: name,
+ Tagger: *opts.Tagger,
+ Message: opts.Message,
+ TargetType: rawobj.Type(),
+ Target: hash,
+ }
+
+ if opts.SignKey != nil {
+ sig, err := r.buildTagSignature(tag, opts.SignKey)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ tag.PGPSignature = sig
+ }
+
+ obj := r.Storer.NewEncodedObject()
+ if err := tag.Encode(obj); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return r.Storer.SetEncodedObject(obj)
+}
+
+func (r *Repository) buildTagSignature(tag *object.Tag, signKey *openpgp.Entity) (string, error) {
+ encoded := &plumbing.MemoryObject{}
+ if err := tag.Encode(encoded); err != nil {
+ return "", err
+ }
+
+ rdr, err := encoded.Reader()
+ if err != nil {
+ return "", err
+ }
+
+ var b bytes.Buffer
+ if err := openpgp.ArmoredDetachSign(&b, signKey, rdr, nil); err != nil {
+ return "", err
+ }
+
+ return b.String(), nil
+}
+
+// Tag returns a tag from the repository.
+//
+// If you want to check to see if the tag is an annotated tag, you can call
+// TagObject on the hash of the reference in ForEach:
+//
+// ref, err := r.Tag("v0.1.0")
+// if err != nil {
+// // Handle error
+// }
+//
+// obj, err := r.TagObject(ref.Hash())
+// switch err {
+// case nil:
+// // Tag object present
+// case plumbing.ErrObjectNotFound:
+// // Not a tag object
+// default:
+// // Some other error
+// }
+//
+func (r *Repository) Tag(name string) (*plumbing.Reference, error) {
+ ref, err := r.Reference(plumbing.ReferenceName(path.Join("refs", "tags", name)), false)
+ if err != nil {
+ if err == plumbing.ErrReferenceNotFound {
+ // Return a friendly error for this one, versus just ReferenceNotFound.
+ return nil, ErrTagNotFound
+ }
+
+ return nil, err
+ }
+
+ return ref, nil
+}
+
+// DeleteTag deletes a tag from the repository.
+func (r *Repository) DeleteTag(name string) error {
+ _, err := r.Tag(name)
+ if err != nil {
+ return err
+ }
+
+ return r.Storer.RemoveReference(plumbing.ReferenceName(path.Join("refs", "tags", name)))
+}
+
+func (r *Repository) resolveToCommitHash(h plumbing.Hash) (plumbing.Hash, error) {
+ obj, err := r.Storer.EncodedObject(plumbing.AnyObject, h)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+ switch obj.Type() {
+ case plumbing.TagObject:
+ t, err := object.DecodeTag(r.Storer, obj)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+ return r.resolveToCommitHash(t.Target)
+ case plumbing.CommitObject:
+ return h, nil
+ default:
+ return plumbing.ZeroHash, ErrUnableToResolveCommit
+ }
+}
+
+// Clone clones a remote repository
+func (r *Repository) clone(ctx context.Context, o *CloneOptions) error {
+ if err := o.Validate(); err != nil {
+ return err
+ }
+
+ c := &config.RemoteConfig{
+ Name: o.RemoteName,
+ URLs: []string{o.URL},
+ Fetch: r.cloneRefSpec(o),
+ }
+
+ if _, err := r.CreateRemote(c); err != nil {
+ return err
+ }
+
+ ref, err := r.fetchAndUpdateReferences(ctx, &FetchOptions{
+ RefSpecs: c.Fetch,
+ Depth: o.Depth,
+ Auth: o.Auth,
+ Progress: o.Progress,
+ Tags: o.Tags,
+ RemoteName: o.RemoteName,
+ }, o.ReferenceName)
+ if err != nil {
+ return err
+ }
+
+ if r.wt != nil && !o.NoCheckout {
+ w, err := r.Worktree()
+ if err != nil {
+ return err
+ }
+
+ head, err := r.Head()
+ if err != nil {
+ return err
+ }
+
+ if err := w.Reset(&ResetOptions{
+ Mode: MergeReset,
+ Commit: head.Hash(),
+ }); err != nil {
+ return err
+ }
+
+ if o.RecurseSubmodules != NoRecurseSubmodules {
+ if err := w.updateSubmodules(&SubmoduleUpdateOptions{
+ RecurseSubmodules: o.RecurseSubmodules,
+ Auth: o.Auth,
+ }); err != nil {
+ return err
+ }
+ }
+ }
+
+ if err := r.updateRemoteConfigIfNeeded(o, c, ref); err != nil {
+ return err
+ }
+
+ if ref.Name().IsBranch() {
+ branchRef := ref.Name()
+ branchName := strings.Split(string(branchRef), "refs/heads/")[1]
+
+ b := &config.Branch{
+ Name: branchName,
+ Merge: branchRef,
+ }
+ if o.RemoteName == "" {
+ b.Remote = "origin"
+ } else {
+ b.Remote = o.RemoteName
+ }
+ if err := r.CreateBranch(b); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+const (
+ refspecTag = "+refs/tags/%s:refs/tags/%[1]s"
+ refspecSingleBranch = "+refs/heads/%s:refs/remotes/%s/%[1]s"
+ refspecSingleBranchHEAD = "+HEAD:refs/remotes/%s/HEAD"
+)
+
+func (r *Repository) cloneRefSpec(o *CloneOptions) []config.RefSpec {
+ switch {
+ case o.ReferenceName.IsTag():
+ return []config.RefSpec{
+ config.RefSpec(fmt.Sprintf(refspecTag, o.ReferenceName.Short())),
+ }
+ case o.SingleBranch && o.ReferenceName == plumbing.HEAD:
+ return []config.RefSpec{
+ config.RefSpec(fmt.Sprintf(refspecSingleBranchHEAD, o.RemoteName)),
+ config.RefSpec(fmt.Sprintf(refspecSingleBranch, plumbing.Master.Short(), o.RemoteName)),
+ }
+ case o.SingleBranch:
+ return []config.RefSpec{
+ config.RefSpec(fmt.Sprintf(refspecSingleBranch, o.ReferenceName.Short(), o.RemoteName)),
+ }
+ default:
+ return []config.RefSpec{
+ config.RefSpec(fmt.Sprintf(config.DefaultFetchRefSpec, o.RemoteName)),
+ }
+ }
+}
+
+func (r *Repository) setIsBare(isBare bool) error {
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ cfg.Core.IsBare = isBare
+ return r.Storer.SetConfig(cfg)
+}
+
+func (r *Repository) updateRemoteConfigIfNeeded(o *CloneOptions, c *config.RemoteConfig, head *plumbing.Reference) error {
+ if !o.SingleBranch {
+ return nil
+ }
+
+ c.Fetch = r.cloneRefSpec(o)
+
+ cfg, err := r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ cfg.Remotes[c.Name] = c
+ return r.Storer.SetConfig(cfg)
+}
+
+func (r *Repository) fetchAndUpdateReferences(
+ ctx context.Context, o *FetchOptions, ref plumbing.ReferenceName,
+) (*plumbing.Reference, error) {
+
+ if err := o.Validate(); err != nil {
+ return nil, err
+ }
+
+ remote, err := r.Remote(o.RemoteName)
+ if err != nil {
+ return nil, err
+ }
+
+ objsUpdated := true
+ remoteRefs, err := remote.fetch(ctx, o)
+ if err == NoErrAlreadyUpToDate {
+ objsUpdated = false
+ } else if err != nil {
+ return nil, err
+ }
+
+ resolvedRef, err := storer.ResolveReference(remoteRefs, ref)
+ if err != nil {
+ return nil, err
+ }
+
+ refsUpdated, err := r.updateReferences(remote.c.Fetch, resolvedRef)
+ if err != nil {
+ return nil, err
+ }
+
+ if !objsUpdated && !refsUpdated {
+ return nil, NoErrAlreadyUpToDate
+ }
+
+ return resolvedRef, nil
+}
+
+func (r *Repository) updateReferences(spec []config.RefSpec,
+ resolvedRef *plumbing.Reference) (updated bool, err error) {
+
+ if !resolvedRef.Name().IsBranch() {
+ // Detached HEAD mode
+ h, err := r.resolveToCommitHash(resolvedRef.Hash())
+ if err != nil {
+ return false, err
+ }
+ head := plumbing.NewHashReference(plumbing.HEAD, h)
+ return updateReferenceStorerIfNeeded(r.Storer, head)
+ }
+
+ refs := []*plumbing.Reference{
+ // Create local reference for the resolved ref
+ resolvedRef,
+ // Create local symbolic HEAD
+ plumbing.NewSymbolicReference(plumbing.HEAD, resolvedRef.Name()),
+ }
+
+ refs = append(refs, r.calculateRemoteHeadReference(spec, resolvedRef)...)
+
+ for _, ref := range refs {
+ u, err := updateReferenceStorerIfNeeded(r.Storer, ref)
+ if err != nil {
+ return updated, err
+ }
+
+ if u {
+ updated = true
+ }
+ }
+
+ return
+}
+
+func (r *Repository) calculateRemoteHeadReference(spec []config.RefSpec,
+ resolvedHead *plumbing.Reference) []*plumbing.Reference {
+
+ var refs []*plumbing.Reference
+
+ // Create resolved HEAD reference with remote prefix if it does not
+ // exist. This is needed when using single branch and HEAD.
+ for _, rs := range spec {
+ name := resolvedHead.Name()
+ if !rs.Match(name) {
+ continue
+ }
+
+ name = rs.Dst(name)
+ _, err := r.Storer.Reference(name)
+ if err == plumbing.ErrReferenceNotFound {
+ refs = append(refs, plumbing.NewHashReference(name, resolvedHead.Hash()))
+ }
+ }
+
+ return refs
+}
+
+func checkAndUpdateReferenceStorerIfNeeded(
+ s storer.ReferenceStorer, r, old *plumbing.Reference) (
+ updated bool, err error) {
+ p, err := s.Reference(r.Name())
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return false, err
+ }
+
+ // we use the string method to compare references, is the easiest way
+ if err == plumbing.ErrReferenceNotFound || r.String() != p.String() {
+ if err := s.CheckAndSetReference(r, old); err != nil {
+ return false, err
+ }
+
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func updateReferenceStorerIfNeeded(
+ s storer.ReferenceStorer, r *plumbing.Reference) (updated bool, err error) {
+ return checkAndUpdateReferenceStorerIfNeeded(s, r, nil)
+}
+
+// Fetch fetches references along with the objects necessary to complete
+// their histories, from the remote named as FetchOptions.RemoteName.
+//
+// Returns nil if the operation is successful, NoErrAlreadyUpToDate if there are
+// no changes to be fetched, or an error.
+func (r *Repository) Fetch(o *FetchOptions) error {
+ return r.FetchContext(context.Background(), o)
+}
+
+// FetchContext fetches references along with the objects necessary to complete
+// their histories, from the remote named as FetchOptions.RemoteName.
+//
+// Returns nil if the operation is successful, NoErrAlreadyUpToDate if there are
+// no changes to be fetched, or an error.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (r *Repository) FetchContext(ctx context.Context, o *FetchOptions) error {
+ if err := o.Validate(); err != nil {
+ return err
+ }
+
+ remote, err := r.Remote(o.RemoteName)
+ if err != nil {
+ return err
+ }
+
+ return remote.FetchContext(ctx, o)
+}
+
+// Push performs a push to the remote. Returns NoErrAlreadyUpToDate if
+// the remote was already up-to-date, from the remote named as
+// FetchOptions.RemoteName.
+func (r *Repository) Push(o *PushOptions) error {
+ return r.PushContext(context.Background(), o)
+}
+
+// PushContext performs a push to the remote. Returns NoErrAlreadyUpToDate if
+// the remote was already up-to-date, from the remote named as
+// FetchOptions.RemoteName.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (r *Repository) PushContext(ctx context.Context, o *PushOptions) error {
+ if err := o.Validate(); err != nil {
+ return err
+ }
+
+ remote, err := r.Remote(o.RemoteName)
+ if err != nil {
+ return err
+ }
+
+ return remote.PushContext(ctx, o)
+}
+
+// Log returns the commit history from the given LogOptions.
+func (r *Repository) Log(o *LogOptions) (object.CommitIter, error) {
+ h := o.From
+ if o.From == plumbing.ZeroHash {
+ head, err := r.Head()
+ if err != nil {
+ return nil, err
+ }
+
+ h = head.Hash()
+ }
+
+ commit, err := r.CommitObject(h)
+ if err != nil {
+ return nil, err
+ }
+
+ var commitIter object.CommitIter
+ switch o.Order {
+ case LogOrderDefault:
+ commitIter = object.NewCommitPreorderIter(commit, nil, nil)
+ case LogOrderDFS:
+ commitIter = object.NewCommitPreorderIter(commit, nil, nil)
+ case LogOrderDFSPost:
+ commitIter = object.NewCommitPostorderIter(commit, nil)
+ case LogOrderBSF:
+ commitIter = object.NewCommitIterBSF(commit, nil, nil)
+ case LogOrderCommitterTime:
+ commitIter = object.NewCommitIterCTime(commit, nil, nil)
+ default:
+ return nil, fmt.Errorf("invalid Order=%v", o.Order)
+ }
+
+ if o.FileName == nil {
+ return commitIter, nil
+ }
+ return object.NewCommitFileIterFromIter(*o.FileName, commitIter), nil
+}
+
+// Tags returns all the tag References in a repository.
+//
+// If you want to check to see if the tag is an annotated tag, you can call
+// TagObject on the hash Reference passed in through ForEach:
+//
+// iter, err := r.Tags()
+// if err != nil {
+// // Handle error
+// }
+//
+// if err := iter.ForEach(func (ref *plumbing.Reference) error {
+// obj, err := r.TagObject(ref.Hash())
+// switch err {
+// case nil:
+// // Tag object present
+// case plumbing.ErrObjectNotFound:
+// // Not a tag object
+// default:
+// // Some other error
+// return err
+// }
+// }); err != nil {
+// // Handle outer iterator error
+// }
+//
+func (r *Repository) Tags() (storer.ReferenceIter, error) {
+ refIter, err := r.Storer.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ return storer.NewReferenceFilteredIter(
+ func(r *plumbing.Reference) bool {
+ return r.Name().IsTag()
+ }, refIter), nil
+}
+
+// Branches returns all the References that are Branches.
+func (r *Repository) Branches() (storer.ReferenceIter, error) {
+ refIter, err := r.Storer.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ return storer.NewReferenceFilteredIter(
+ func(r *plumbing.Reference) bool {
+ return r.Name().IsBranch()
+ }, refIter), nil
+}
+
+// Notes returns all the References that are notes. For more information:
+// https://git-scm.com/docs/git-notes
+func (r *Repository) Notes() (storer.ReferenceIter, error) {
+ refIter, err := r.Storer.IterReferences()
+ if err != nil {
+ return nil, err
+ }
+
+ return storer.NewReferenceFilteredIter(
+ func(r *plumbing.Reference) bool {
+ return r.Name().IsNote()
+ }, refIter), nil
+}
+
+// TreeObject return a Tree with the given hash. If not found
+// plumbing.ErrObjectNotFound is returned
+func (r *Repository) TreeObject(h plumbing.Hash) (*object.Tree, error) {
+ return object.GetTree(r.Storer, h)
+}
+
+// TreeObjects returns an unsorted TreeIter with all the trees in the repository
+func (r *Repository) TreeObjects() (*object.TreeIter, error) {
+ iter, err := r.Storer.IterEncodedObjects(plumbing.TreeObject)
+ if err != nil {
+ return nil, err
+ }
+
+ return object.NewTreeIter(r.Storer, iter), nil
+}
+
+// CommitObject return a Commit with the given hash. If not found
+// plumbing.ErrObjectNotFound is returned.
+func (r *Repository) CommitObject(h plumbing.Hash) (*object.Commit, error) {
+ return object.GetCommit(r.Storer, h)
+}
+
+// CommitObjects returns an unsorted CommitIter with all the commits in the repository.
+func (r *Repository) CommitObjects() (object.CommitIter, error) {
+ iter, err := r.Storer.IterEncodedObjects(plumbing.CommitObject)
+ if err != nil {
+ return nil, err
+ }
+
+ return object.NewCommitIter(r.Storer, iter), nil
+}
+
+// BlobObject returns a Blob with the given hash. If not found
+// plumbing.ErrObjectNotFound is returned.
+func (r *Repository) BlobObject(h plumbing.Hash) (*object.Blob, error) {
+ return object.GetBlob(r.Storer, h)
+}
+
+// BlobObjects returns an unsorted BlobIter with all the blobs in the repository.
+func (r *Repository) BlobObjects() (*object.BlobIter, error) {
+ iter, err := r.Storer.IterEncodedObjects(plumbing.BlobObject)
+ if err != nil {
+ return nil, err
+ }
+
+ return object.NewBlobIter(r.Storer, iter), nil
+}
+
+// TagObject returns a Tag with the given hash. If not found
+// plumbing.ErrObjectNotFound is returned. This method only returns
+// annotated Tags, no lightweight Tags.
+func (r *Repository) TagObject(h plumbing.Hash) (*object.Tag, error) {
+ return object.GetTag(r.Storer, h)
+}
+
+// TagObjects returns a unsorted TagIter that can step through all of the annotated
+// tags in the repository.
+func (r *Repository) TagObjects() (*object.TagIter, error) {
+ iter, err := r.Storer.IterEncodedObjects(plumbing.TagObject)
+ if err != nil {
+ return nil, err
+ }
+
+ return object.NewTagIter(r.Storer, iter), nil
+}
+
+// Object returns an Object with the given hash. If not found
+// plumbing.ErrObjectNotFound is returned.
+func (r *Repository) Object(t plumbing.ObjectType, h plumbing.Hash) (object.Object, error) {
+ obj, err := r.Storer.EncodedObject(t, h)
+ if err != nil {
+ return nil, err
+ }
+
+ return object.DecodeObject(r.Storer, obj)
+}
+
+// Objects returns an unsorted ObjectIter with all the objects in the repository.
+func (r *Repository) Objects() (*object.ObjectIter, error) {
+ iter, err := r.Storer.IterEncodedObjects(plumbing.AnyObject)
+ if err != nil {
+ return nil, err
+ }
+
+ return object.NewObjectIter(r.Storer, iter), nil
+}
+
+// Head returns the reference where HEAD is pointing to.
+func (r *Repository) Head() (*plumbing.Reference, error) {
+ return storer.ResolveReference(r.Storer, plumbing.HEAD)
+}
+
+// Reference returns the reference for a given reference name. If resolved is
+// true, any symbolic reference will be resolved.
+func (r *Repository) Reference(name plumbing.ReferenceName, resolved bool) (
+ *plumbing.Reference, error) {
+
+ if resolved {
+ return storer.ResolveReference(r.Storer, name)
+ }
+
+ return r.Storer.Reference(name)
+}
+
+// References returns an unsorted ReferenceIter for all references.
+func (r *Repository) References() (storer.ReferenceIter, error) {
+ return r.Storer.IterReferences()
+}
+
+// Worktree returns a worktree based on the given fs, if nil the default
+// worktree will be used.
+func (r *Repository) Worktree() (*Worktree, error) {
+ if r.wt == nil {
+ return nil, ErrIsBareRepository
+ }
+
+ return &Worktree{r: r, Filesystem: r.wt}, nil
+}
+
+func countTrue(vals ...bool) int {
+ sum := 0
+ for _, v := range vals {
+ if v {
+ sum++
+ }
+ }
+ return sum
+}
+
+// ResolveRevision resolves revision to corresponding hash. It will always
+// resolve to a commit hash, not a tree or annotated tag.
+//
+// Implemented resolvers : HEAD, branch, tag, heads/branch, refs/heads/branch,
+// refs/tags/tag, refs/remotes/origin/branch, refs/remotes/origin/HEAD, tilde and caret (HEAD~1, master~^, tag~2, ref/heads/master~1, ...), selection by text (HEAD^{/fix nasty bug})
+func (r *Repository) ResolveRevision(rev plumbing.Revision) (*plumbing.Hash, error) {
+ p := revision.NewParserFromString(string(rev))
+
+ items, err := p.Parse()
+
+ if err != nil {
+ return nil, err
+ }
+
+ var commit *object.Commit
+
+ for _, item := range items {
+ switch item.(type) {
+ case revision.Ref:
+ revisionRef := item.(revision.Ref)
+ var ref *plumbing.Reference
+ var hashCommit, refCommit, tagCommit *object.Commit
+ var rErr, hErr, tErr error
+
+ for _, rule := range append([]string{"%s"}, plumbing.RefRevParseRules...) {
+ ref, err = storer.ResolveReference(r.Storer, plumbing.ReferenceName(fmt.Sprintf(rule, revisionRef)))
+
+ if err == nil {
+ break
+ }
+ }
+
+ if ref != nil {
+ tag, tObjErr := r.TagObject(ref.Hash())
+ if tObjErr != nil {
+ tErr = tObjErr
+ } else {
+ tagCommit, tErr = tag.Commit()
+ }
+ refCommit, rErr = r.CommitObject(ref.Hash())
+ } else {
+ rErr = plumbing.ErrReferenceNotFound
+ tErr = plumbing.ErrReferenceNotFound
+ }
+
+ maybeHash := plumbing.NewHash(string(revisionRef)).String() == string(revisionRef)
+ if maybeHash {
+ hashCommit, hErr = r.CommitObject(plumbing.NewHash(string(revisionRef)))
+ } else {
+ hErr = plumbing.ErrReferenceNotFound
+ }
+
+ isTag := tErr == nil
+ isCommit := rErr == nil
+ isHash := hErr == nil
+
+ switch {
+ case countTrue(isTag, isCommit, isHash) > 1:
+ return &plumbing.ZeroHash, fmt.Errorf(`refname "%s" is ambiguous`, revisionRef)
+ case isTag:
+ commit = tagCommit
+ case isCommit:
+ commit = refCommit
+ case isHash:
+ commit = hashCommit
+ default:
+ return &plumbing.ZeroHash, plumbing.ErrReferenceNotFound
+ }
+ case revision.CaretPath:
+ depth := item.(revision.CaretPath).Depth
+
+ if depth == 0 {
+ break
+ }
+
+ iter := commit.Parents()
+
+ c, err := iter.Next()
+
+ if err != nil {
+ return &plumbing.ZeroHash, err
+ }
+
+ if depth == 1 {
+ commit = c
+
+ break
+ }
+
+ c, err = iter.Next()
+
+ if err != nil {
+ return &plumbing.ZeroHash, err
+ }
+
+ commit = c
+ case revision.TildePath:
+ for i := 0; i < item.(revision.TildePath).Depth; i++ {
+ c, err := commit.Parents().Next()
+
+ if err != nil {
+ return &plumbing.ZeroHash, err
+ }
+
+ commit = c
+ }
+ case revision.CaretReg:
+ history := object.NewCommitPreorderIter(commit, nil, nil)
+
+ re := item.(revision.CaretReg).Regexp
+ negate := item.(revision.CaretReg).Negate
+
+ var c *object.Commit
+
+ err := history.ForEach(func(hc *object.Commit) error {
+ if !negate && re.MatchString(hc.Message) {
+ c = hc
+ return storer.ErrStop
+ }
+
+ if negate && !re.MatchString(hc.Message) {
+ c = hc
+ return storer.ErrStop
+ }
+
+ return nil
+ })
+ if err != nil {
+ return &plumbing.ZeroHash, err
+ }
+
+ if c == nil {
+ return &plumbing.ZeroHash, fmt.Errorf(`No commit message match regexp : "%s"`, re.String())
+ }
+
+ commit = c
+ }
+ }
+
+ return &commit.Hash, nil
+}
+
+type RepackConfig struct {
+ // UseRefDeltas configures whether packfile encoder will use reference deltas.
+ // By default OFSDeltaObject is used.
+ UseRefDeltas bool
+ // OnlyDeletePacksOlderThan if set to non-zero value
+ // selects only objects older than the time provided.
+ OnlyDeletePacksOlderThan time.Time
+}
+
+func (r *Repository) RepackObjects(cfg *RepackConfig) (err error) {
+ pos, ok := r.Storer.(storer.PackedObjectStorer)
+ if !ok {
+ return ErrPackedObjectsNotSupported
+ }
+
+ // Get the existing object packs.
+ hs, err := pos.ObjectPacks()
+ if err != nil {
+ return err
+ }
+
+ // Create a new pack.
+ nh, err := r.createNewObjectPack(cfg)
+ if err != nil {
+ return err
+ }
+
+ // Delete old packs.
+ for _, h := range hs {
+ // Skip if new hash is the same as an old one.
+ if h == nh {
+ continue
+ }
+ err = pos.DeleteOldObjectPackAndIndex(h, cfg.OnlyDeletePacksOlderThan)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// createNewObjectPack is a helper for RepackObjects taking care
+// of creating a new pack. It is used so the the PackfileWriter
+// deferred close has the right scope.
+func (r *Repository) createNewObjectPack(cfg *RepackConfig) (h plumbing.Hash, err error) {
+ ow := newObjectWalker(r.Storer)
+ err = ow.walkAllRefs()
+ if err != nil {
+ return h, err
+ }
+ objs := make([]plumbing.Hash, 0, len(ow.seen))
+ for h := range ow.seen {
+ objs = append(objs, h)
+ }
+ pfw, ok := r.Storer.(storer.PackfileWriter)
+ if !ok {
+ return h, fmt.Errorf("Repository storer is not a storer.PackfileWriter")
+ }
+ wc, err := pfw.PackfileWriter()
+ if err != nil {
+ return h, err
+ }
+ defer ioutil.CheckClose(wc, &err)
+ scfg, err := r.Storer.Config()
+ if err != nil {
+ return h, err
+ }
+ enc := packfile.NewEncoder(wc, r.Storer, cfg.UseRefDeltas)
+ h, err = enc.Encode(objs, scfg.Pack.Window)
+ if err != nil {
+ return h, err
+ }
+
+ // Delete the packed, loose objects.
+ if los, ok := r.Storer.(storer.LooseObjectStorer); ok {
+ err = los.ForEachObjectHash(func(hash plumbing.Hash) error {
+ if ow.isSeen(hash) {
+ err = los.DeleteLooseObject(hash)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ return h, err
+ }
+ }
+
+ return h, err
+}
--- /dev/null
+package git
+
+import (
+ "bytes"
+ "fmt"
+ "path/filepath"
+)
+
+// Status represents the current status of a Worktree.
+// The key of the map is the path of the file.
+type Status map[string]*FileStatus
+
+// File returns the FileStatus for a given path, if the FileStatus doesn't
+// exists a new FileStatus is added to the map using the path as key.
+func (s Status) File(path string) *FileStatus {
+ if _, ok := (s)[path]; !ok {
+ s[path] = &FileStatus{Worktree: Untracked, Staging: Untracked}
+ }
+
+ return s[path]
+}
+
+// IsUntracked checks if file for given path is 'Untracked'
+func (s Status) IsUntracked(path string) bool {
+ stat, ok := (s)[filepath.ToSlash(path)]
+ return ok && stat.Worktree == Untracked
+}
+
+// IsClean returns true if all the files are in Unmodified status.
+func (s Status) IsClean() bool {
+ for _, status := range s {
+ if status.Worktree != Unmodified || status.Staging != Unmodified {
+ return false
+ }
+ }
+
+ return true
+}
+
+func (s Status) String() string {
+ buf := bytes.NewBuffer(nil)
+ for path, status := range s {
+ if status.Staging == Unmodified && status.Worktree == Unmodified {
+ continue
+ }
+
+ if status.Staging == Renamed {
+ path = fmt.Sprintf("%s -> %s", path, status.Extra)
+ }
+
+ fmt.Fprintf(buf, "%c%c %s\n", status.Staging, status.Worktree, path)
+ }
+
+ return buf.String()
+}
+
+// FileStatus contains the status of a file in the worktree
+type FileStatus struct {
+ // Staging is the status of a file in the staging area
+ Staging StatusCode
+ // Worktree is the status of a file in the worktree
+ Worktree StatusCode
+ // Extra contains extra information, such as the previous name in a rename
+ Extra string
+}
+
+// StatusCode status code of a file in the Worktree
+type StatusCode byte
+
+const (
+ Unmodified StatusCode = ' '
+ Untracked StatusCode = '?'
+ Modified StatusCode = 'M'
+ Added StatusCode = 'A'
+ Deleted StatusCode = 'D'
+ Renamed StatusCode = 'R'
+ Copied StatusCode = 'C'
+ UpdatedButUnmerged StatusCode = 'U'
+)
--- /dev/null
+package filesystem
+
+import (
+ stdioutil "io/ioutil"
+ "os"
+
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+type ConfigStorage struct {
+ dir *dotgit.DotGit
+}
+
+func (c *ConfigStorage) Config() (conf *config.Config, err error) {
+ cfg := config.NewConfig()
+
+ f, err := c.dir.Config()
+ if err != nil {
+ if os.IsNotExist(err) {
+ return cfg, nil
+ }
+
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ b, err := stdioutil.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ if err = cfg.Unmarshal(b); err != nil {
+ return nil, err
+ }
+
+ return cfg, err
+}
+
+func (c *ConfigStorage) SetConfig(cfg *config.Config) (err error) {
+ if err = cfg.Validate(); err != nil {
+ return err
+ }
+
+ f, err := c.dir.ConfigWriter()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ b, err := cfg.Marshal()
+ if err != nil {
+ return err
+ }
+
+ _, err = f.Write(b)
+ return err
+}
--- /dev/null
+package filesystem
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+type deltaObject struct {
+ plumbing.EncodedObject
+ base plumbing.Hash
+ hash plumbing.Hash
+ size int64
+}
+
+func newDeltaObject(
+ obj plumbing.EncodedObject,
+ hash plumbing.Hash,
+ base plumbing.Hash,
+ size int64) plumbing.DeltaObject {
+ return &deltaObject{
+ EncodedObject: obj,
+ hash: hash,
+ base: base,
+ size: size,
+ }
+}
+
+func (o *deltaObject) BaseHash() plumbing.Hash {
+ return o.base
+}
+
+func (o *deltaObject) ActualSize() int64 {
+ return o.size
+}
+
+func (o *deltaObject) ActualHash() plumbing.Hash {
+ return o.hash
+}
--- /dev/null
+// https://github.com/git/git/blob/master/Documentation/gitrepository-layout.txt
+package dotgit
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ stdioutil "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "gopkg.in/src-d/go-billy.v4/osfs"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+const (
+ suffix = ".git"
+ packedRefsPath = "packed-refs"
+ configPath = "config"
+ indexPath = "index"
+ shallowPath = "shallow"
+ modulePath = "modules"
+ objectsPath = "objects"
+ packPath = "pack"
+ refsPath = "refs"
+
+ tmpPackedRefsPrefix = "._packed-refs"
+
+ packExt = ".pack"
+ idxExt = ".idx"
+)
+
+var (
+ // ErrNotFound is returned by New when the path is not found.
+ ErrNotFound = errors.New("path not found")
+ // ErrIdxNotFound is returned by Idxfile when the idx file is not found
+ ErrIdxNotFound = errors.New("idx file not found")
+ // ErrPackfileNotFound is returned by Packfile when the packfile is not found
+ ErrPackfileNotFound = errors.New("packfile not found")
+ // ErrConfigNotFound is returned by Config when the config is not found
+ ErrConfigNotFound = errors.New("config file not found")
+ // ErrPackedRefsDuplicatedRef is returned when a duplicated reference is
+ // found in the packed-ref file. This is usually the case for corrupted git
+ // repositories.
+ ErrPackedRefsDuplicatedRef = errors.New("duplicated ref found in packed-ref file")
+ // ErrPackedRefsBadFormat is returned when the packed-ref file corrupt.
+ ErrPackedRefsBadFormat = errors.New("malformed packed-ref")
+ // ErrSymRefTargetNotFound is returned when a symbolic reference is
+ // targeting a non-existing object. This usually means the repository
+ // is corrupt.
+ ErrSymRefTargetNotFound = errors.New("symbolic reference target not found")
+)
+
+// Options holds configuration for the storage.
+type Options struct {
+ // ExclusiveAccess means that the filesystem is not modified externally
+ // while the repo is open.
+ ExclusiveAccess bool
+ // KeepDescriptors makes the file descriptors to be reused but they will
+ // need to be manually closed calling Close().
+ KeepDescriptors bool
+}
+
+// The DotGit type represents a local git repository on disk. This
+// type is not zero-value-safe, use the New function to initialize it.
+type DotGit struct {
+ options Options
+ fs billy.Filesystem
+
+ // incoming object directory information
+ incomingChecked bool
+ incomingDirName string
+
+ objectList []plumbing.Hash
+ objectMap map[plumbing.Hash]struct{}
+ packList []plumbing.Hash
+ packMap map[plumbing.Hash]struct{}
+
+ files map[string]billy.File
+}
+
+// New returns a DotGit value ready to be used. The path argument must
+// be the absolute path of a git repository directory (e.g.
+// "/foo/bar/.git").
+func New(fs billy.Filesystem) *DotGit {
+ return NewWithOptions(fs, Options{})
+}
+
+// NewWithOptions sets non default configuration options.
+// See New for complete help.
+func NewWithOptions(fs billy.Filesystem, o Options) *DotGit {
+ return &DotGit{
+ options: o,
+ fs: fs,
+ }
+}
+
+// Initialize creates all the folder scaffolding.
+func (d *DotGit) Initialize() error {
+ mustExists := []string{
+ d.fs.Join("objects", "info"),
+ d.fs.Join("objects", "pack"),
+ d.fs.Join("refs", "heads"),
+ d.fs.Join("refs", "tags"),
+ }
+
+ for _, path := range mustExists {
+ _, err := d.fs.Stat(path)
+ if err == nil {
+ continue
+ }
+
+ if !os.IsNotExist(err) {
+ return err
+ }
+
+ if err := d.fs.MkdirAll(path, os.ModeDir|os.ModePerm); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Close closes all opened files.
+func (d *DotGit) Close() error {
+ var firstError error
+ if d.files != nil {
+ for _, f := range d.files {
+ err := f.Close()
+ if err != nil && firstError == nil {
+ firstError = err
+ continue
+ }
+ }
+
+ d.files = nil
+ }
+
+ if firstError != nil {
+ return firstError
+ }
+
+ return nil
+}
+
+// ConfigWriter returns a file pointer for write to the config file
+func (d *DotGit) ConfigWriter() (billy.File, error) {
+ return d.fs.Create(configPath)
+}
+
+// Config returns a file pointer for read to the config file
+func (d *DotGit) Config() (billy.File, error) {
+ return d.fs.Open(configPath)
+}
+
+// IndexWriter returns a file pointer for write to the index file
+func (d *DotGit) IndexWriter() (billy.File, error) {
+ return d.fs.Create(indexPath)
+}
+
+// Index returns a file pointer for read to the index file
+func (d *DotGit) Index() (billy.File, error) {
+ return d.fs.Open(indexPath)
+}
+
+// ShallowWriter returns a file pointer for write to the shallow file
+func (d *DotGit) ShallowWriter() (billy.File, error) {
+ return d.fs.Create(shallowPath)
+}
+
+// Shallow returns a file pointer for read to the shallow file
+func (d *DotGit) Shallow() (billy.File, error) {
+ f, err := d.fs.Open(shallowPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ return f, nil
+}
+
+// NewObjectPack return a writer for a new packfile, it saves the packfile to
+// disk and also generates and save the index for the given packfile.
+func (d *DotGit) NewObjectPack() (*PackWriter, error) {
+ d.cleanPackList()
+ return newPackWrite(d.fs)
+}
+
+// ObjectPacks returns the list of availables packfiles
+func (d *DotGit) ObjectPacks() ([]plumbing.Hash, error) {
+ if !d.options.ExclusiveAccess {
+ return d.objectPacks()
+ }
+
+ err := d.genPackList()
+ if err != nil {
+ return nil, err
+ }
+
+ return d.packList, nil
+}
+
+func (d *DotGit) objectPacks() ([]plumbing.Hash, error) {
+ packDir := d.fs.Join(objectsPath, packPath)
+ files, err := d.fs.ReadDir(packDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ var packs []plumbing.Hash
+ for _, f := range files {
+ if !strings.HasSuffix(f.Name(), packExt) {
+ continue
+ }
+
+ n := f.Name()
+ h := plumbing.NewHash(n[5 : len(n)-5]) //pack-(hash).pack
+ if h.IsZero() {
+ // Ignore files with badly-formatted names.
+ continue
+ }
+ packs = append(packs, h)
+ }
+
+ return packs, nil
+}
+
+func (d *DotGit) objectPackPath(hash plumbing.Hash, extension string) string {
+ return d.fs.Join(objectsPath, packPath, fmt.Sprintf("pack-%s.%s", hash.String(), extension))
+}
+
+func (d *DotGit) objectPackOpen(hash plumbing.Hash, extension string) (billy.File, error) {
+ if d.files == nil {
+ d.files = make(map[string]billy.File)
+ }
+
+ err := d.hasPack(hash)
+ if err != nil {
+ return nil, err
+ }
+
+ path := d.objectPackPath(hash, extension)
+ f, ok := d.files[path]
+ if ok {
+ return f, nil
+ }
+
+ pack, err := d.fs.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, ErrPackfileNotFound
+ }
+
+ return nil, err
+ }
+
+ if d.options.KeepDescriptors && extension == "pack" {
+ d.files[path] = pack
+ }
+
+ return pack, nil
+}
+
+// ObjectPack returns a fs.File of the given packfile
+func (d *DotGit) ObjectPack(hash plumbing.Hash) (billy.File, error) {
+ err := d.hasPack(hash)
+ if err != nil {
+ return nil, err
+ }
+
+ return d.objectPackOpen(hash, `pack`)
+}
+
+// ObjectPackIdx returns a fs.File of the index file for a given packfile
+func (d *DotGit) ObjectPackIdx(hash plumbing.Hash) (billy.File, error) {
+ err := d.hasPack(hash)
+ if err != nil {
+ return nil, err
+ }
+
+ return d.objectPackOpen(hash, `idx`)
+}
+
+func (d *DotGit) DeleteOldObjectPackAndIndex(hash plumbing.Hash, t time.Time) error {
+ d.cleanPackList()
+
+ path := d.objectPackPath(hash, `pack`)
+ if !t.IsZero() {
+ fi, err := d.fs.Stat(path)
+ if err != nil {
+ return err
+ }
+ // too new, skip deletion.
+ if !fi.ModTime().Before(t) {
+ return nil
+ }
+ }
+ err := d.fs.Remove(path)
+ if err != nil {
+ return err
+ }
+ return d.fs.Remove(d.objectPackPath(hash, `idx`))
+}
+
+// NewObject return a writer for a new object file.
+func (d *DotGit) NewObject() (*ObjectWriter, error) {
+ d.cleanObjectList()
+
+ return newObjectWriter(d.fs)
+}
+
+// Objects returns a slice with the hashes of objects found under the
+// .git/objects/ directory.
+func (d *DotGit) Objects() ([]plumbing.Hash, error) {
+ if d.options.ExclusiveAccess {
+ err := d.genObjectList()
+ if err != nil {
+ return nil, err
+ }
+
+ return d.objectList, nil
+ }
+
+ var objects []plumbing.Hash
+ err := d.ForEachObjectHash(func(hash plumbing.Hash) error {
+ objects = append(objects, hash)
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return objects, nil
+}
+
+// ForEachObjectHash iterates over the hashes of objects found under the
+// .git/objects/ directory and executes the provided function.
+func (d *DotGit) ForEachObjectHash(fun func(plumbing.Hash) error) error {
+ if !d.options.ExclusiveAccess {
+ return d.forEachObjectHash(fun)
+ }
+
+ err := d.genObjectList()
+ if err != nil {
+ return err
+ }
+
+ for _, h := range d.objectList {
+ err := fun(h)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *DotGit) forEachObjectHash(fun func(plumbing.Hash) error) error {
+ files, err := d.fs.ReadDir(objectsPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+
+ return err
+ }
+
+ for _, f := range files {
+ if f.IsDir() && len(f.Name()) == 2 && isHex(f.Name()) {
+ base := f.Name()
+ d, err := d.fs.ReadDir(d.fs.Join(objectsPath, base))
+ if err != nil {
+ return err
+ }
+
+ for _, o := range d {
+ h := plumbing.NewHash(base + o.Name())
+ if h.IsZero() {
+ // Ignore files with badly-formatted names.
+ continue
+ }
+ err = fun(h)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+func (d *DotGit) cleanObjectList() {
+ d.objectMap = nil
+ d.objectList = nil
+}
+
+func (d *DotGit) genObjectList() error {
+ if d.objectMap != nil {
+ return nil
+ }
+
+ d.objectMap = make(map[plumbing.Hash]struct{})
+ return d.forEachObjectHash(func(h plumbing.Hash) error {
+ d.objectList = append(d.objectList, h)
+ d.objectMap[h] = struct{}{}
+
+ return nil
+ })
+}
+
+func (d *DotGit) hasObject(h plumbing.Hash) error {
+ if !d.options.ExclusiveAccess {
+ return nil
+ }
+
+ err := d.genObjectList()
+ if err != nil {
+ return err
+ }
+
+ _, ok := d.objectMap[h]
+ if !ok {
+ return plumbing.ErrObjectNotFound
+ }
+
+ return nil
+}
+
+func (d *DotGit) cleanPackList() {
+ d.packMap = nil
+ d.packList = nil
+}
+
+func (d *DotGit) genPackList() error {
+ if d.packMap != nil {
+ return nil
+ }
+
+ op, err := d.objectPacks()
+ if err != nil {
+ return err
+ }
+
+ d.packMap = make(map[plumbing.Hash]struct{})
+ d.packList = nil
+
+ for _, h := range op {
+ d.packList = append(d.packList, h)
+ d.packMap[h] = struct{}{}
+ }
+
+ return nil
+}
+
+func (d *DotGit) hasPack(h plumbing.Hash) error {
+ if !d.options.ExclusiveAccess {
+ return nil
+ }
+
+ err := d.genPackList()
+ if err != nil {
+ return err
+ }
+
+ _, ok := d.packMap[h]
+ if !ok {
+ return ErrPackfileNotFound
+ }
+
+ return nil
+}
+
+func (d *DotGit) objectPath(h plumbing.Hash) string {
+ hash := h.String()
+ return d.fs.Join(objectsPath, hash[0:2], hash[2:40])
+}
+
+// incomingObjectPath is intended to add support for a git pre-receive hook
+// to be written it adds support for go-git to find objects in an "incoming"
+// directory, so that the library can be used to write a pre-receive hook
+// that deals with the incoming objects.
+//
+// More on git hooks found here : https://git-scm.com/docs/githooks
+// More on 'quarantine'/incoming directory here:
+// https://git-scm.com/docs/git-receive-pack
+func (d *DotGit) incomingObjectPath(h plumbing.Hash) string {
+ hString := h.String()
+
+ if d.incomingDirName == "" {
+ return d.fs.Join(objectsPath, hString[0:2], hString[2:40])
+ }
+
+ return d.fs.Join(objectsPath, d.incomingDirName, hString[0:2], hString[2:40])
+}
+
+// hasIncomingObjects searches for an incoming directory and keeps its name
+// so it doesn't have to be found each time an object is accessed.
+func (d *DotGit) hasIncomingObjects() bool {
+ if !d.incomingChecked {
+ directoryContents, err := d.fs.ReadDir(objectsPath)
+ if err == nil {
+ for _, file := range directoryContents {
+ if strings.HasPrefix(file.Name(), "incoming-") && file.IsDir() {
+ d.incomingDirName = file.Name()
+ }
+ }
+ }
+
+ d.incomingChecked = true
+ }
+
+ return d.incomingDirName != ""
+}
+
+// Object returns a fs.File pointing the object file, if exists
+func (d *DotGit) Object(h plumbing.Hash) (billy.File, error) {
+ err := d.hasObject(h)
+ if err != nil {
+ return nil, err
+ }
+
+ obj1, err1 := d.fs.Open(d.objectPath(h))
+ if os.IsNotExist(err1) && d.hasIncomingObjects() {
+ obj2, err2 := d.fs.Open(d.incomingObjectPath(h))
+ if err2 != nil {
+ return obj1, err1
+ }
+ return obj2, err2
+ }
+ return obj1, err1
+}
+
+// ObjectStat returns a os.FileInfo pointing the object file, if exists
+func (d *DotGit) ObjectStat(h plumbing.Hash) (os.FileInfo, error) {
+ err := d.hasObject(h)
+ if err != nil {
+ return nil, err
+ }
+
+ obj1, err1 := d.fs.Stat(d.objectPath(h))
+ if os.IsNotExist(err1) && d.hasIncomingObjects() {
+ obj2, err2 := d.fs.Stat(d.incomingObjectPath(h))
+ if err2 != nil {
+ return obj1, err1
+ }
+ return obj2, err2
+ }
+ return obj1, err1
+}
+
+// ObjectDelete removes the object file, if exists
+func (d *DotGit) ObjectDelete(h plumbing.Hash) error {
+ d.cleanObjectList()
+
+ err1 := d.fs.Remove(d.objectPath(h))
+ if os.IsNotExist(err1) && d.hasIncomingObjects() {
+ err2 := d.fs.Remove(d.incomingObjectPath(h))
+ if err2 != nil {
+ return err1
+ }
+ return err2
+ }
+ return err1
+}
+
+func (d *DotGit) readReferenceFrom(rd io.Reader, name string) (ref *plumbing.Reference, err error) {
+ b, err := stdioutil.ReadAll(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ line := strings.TrimSpace(string(b))
+ return plumbing.NewReferenceFromStrings(name, line), nil
+}
+
+func (d *DotGit) checkReferenceAndTruncate(f billy.File, old *plumbing.Reference) error {
+ if old == nil {
+ return nil
+ }
+ ref, err := d.readReferenceFrom(f, old.Name().String())
+ if err != nil {
+ return err
+ }
+ if ref.Hash() != old.Hash() {
+ return fmt.Errorf("reference has changed concurrently")
+ }
+ _, err = f.Seek(0, io.SeekStart)
+ if err != nil {
+ return err
+ }
+ return f.Truncate(0)
+}
+
+func (d *DotGit) SetRef(r, old *plumbing.Reference) error {
+ var content string
+ switch r.Type() {
+ case plumbing.SymbolicReference:
+ content = fmt.Sprintf("ref: %s\n", r.Target())
+ case plumbing.HashReference:
+ content = fmt.Sprintln(r.Hash().String())
+ }
+
+ fileName := r.Name().String()
+
+ return d.setRef(fileName, content, old)
+}
+
+// Refs scans the git directory collecting references, which it returns.
+// Symbolic references are resolved and included in the output.
+func (d *DotGit) Refs() ([]*plumbing.Reference, error) {
+ var refs []*plumbing.Reference
+ var seen = make(map[plumbing.ReferenceName]bool)
+ if err := d.addRefsFromRefDir(&refs, seen); err != nil {
+ return nil, err
+ }
+
+ if err := d.addRefsFromPackedRefs(&refs, seen); err != nil {
+ return nil, err
+ }
+
+ if err := d.addRefFromHEAD(&refs); err != nil {
+ return nil, err
+ }
+
+ return refs, nil
+}
+
+// Ref returns the reference for a given reference name.
+func (d *DotGit) Ref(name plumbing.ReferenceName) (*plumbing.Reference, error) {
+ ref, err := d.readReferenceFile(".", name.String())
+ if err == nil {
+ return ref, nil
+ }
+
+ return d.packedRef(name)
+}
+
+func (d *DotGit) findPackedRefsInFile(f billy.File) ([]*plumbing.Reference, error) {
+ s := bufio.NewScanner(f)
+ var refs []*plumbing.Reference
+ for s.Scan() {
+ ref, err := d.processLine(s.Text())
+ if err != nil {
+ return nil, err
+ }
+
+ if ref != nil {
+ refs = append(refs, ref)
+ }
+ }
+
+ return refs, s.Err()
+}
+
+func (d *DotGit) findPackedRefs() (r []*plumbing.Reference, err error) {
+ f, err := d.fs.Open(packedRefsPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+ return d.findPackedRefsInFile(f)
+}
+
+func (d *DotGit) packedRef(name plumbing.ReferenceName) (*plumbing.Reference, error) {
+ refs, err := d.findPackedRefs()
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ref := range refs {
+ if ref.Name() == name {
+ return ref, nil
+ }
+ }
+
+ return nil, plumbing.ErrReferenceNotFound
+}
+
+// RemoveRef removes a reference by name.
+func (d *DotGit) RemoveRef(name plumbing.ReferenceName) error {
+ path := d.fs.Join(".", name.String())
+ _, err := d.fs.Stat(path)
+ if err == nil {
+ err = d.fs.Remove(path)
+ // Drop down to remove it from the packed refs file, too.
+ }
+
+ if err != nil && !os.IsNotExist(err) {
+ return err
+ }
+
+ return d.rewritePackedRefsWithoutRef(name)
+}
+
+func (d *DotGit) addRefsFromPackedRefs(refs *[]*plumbing.Reference, seen map[plumbing.ReferenceName]bool) (err error) {
+ packedRefs, err := d.findPackedRefs()
+ if err != nil {
+ return err
+ }
+
+ for _, ref := range packedRefs {
+ if !seen[ref.Name()] {
+ *refs = append(*refs, ref)
+ seen[ref.Name()] = true
+ }
+ }
+ return nil
+}
+
+func (d *DotGit) addRefsFromPackedRefsFile(refs *[]*plumbing.Reference, f billy.File, seen map[plumbing.ReferenceName]bool) (err error) {
+ packedRefs, err := d.findPackedRefsInFile(f)
+ if err != nil {
+ return err
+ }
+
+ for _, ref := range packedRefs {
+ if !seen[ref.Name()] {
+ *refs = append(*refs, ref)
+ seen[ref.Name()] = true
+ }
+ }
+ return nil
+}
+
+func (d *DotGit) openAndLockPackedRefs(doCreate bool) (
+ pr billy.File, err error) {
+ var f billy.File
+ defer func() {
+ if err != nil && f != nil {
+ ioutil.CheckClose(f, &err)
+ }
+ }()
+
+ // File mode is retrieved from a constant defined in the target specific
+ // files (dotgit_rewrite_packed_refs_*). Some modes are not available
+ // in all filesystems.
+ openFlags := d.openAndLockPackedRefsMode()
+ if doCreate {
+ openFlags |= os.O_CREATE
+ }
+
+ // Keep trying to open and lock the file until we're sure the file
+ // didn't change between the open and the lock.
+ for {
+ f, err = d.fs.OpenFile(packedRefsPath, openFlags, 0600)
+ if err != nil {
+ if os.IsNotExist(err) && !doCreate {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+ fi, err := d.fs.Stat(packedRefsPath)
+ if err != nil {
+ return nil, err
+ }
+ mtime := fi.ModTime()
+
+ err = f.Lock()
+ if err != nil {
+ return nil, err
+ }
+
+ fi, err = d.fs.Stat(packedRefsPath)
+ if err != nil {
+ return nil, err
+ }
+ if mtime.Equal(fi.ModTime()) {
+ break
+ }
+ // The file has changed since we opened it. Close and retry.
+ err = f.Close()
+ if err != nil {
+ return nil, err
+ }
+ }
+ return f, nil
+}
+
+func (d *DotGit) rewritePackedRefsWithoutRef(name plumbing.ReferenceName) (err error) {
+ pr, err := d.openAndLockPackedRefs(false)
+ if err != nil {
+ return err
+ }
+ if pr == nil {
+ return nil
+ }
+ defer ioutil.CheckClose(pr, &err)
+
+ // Creating the temp file in the same directory as the target file
+ // improves our chances for rename operation to be atomic.
+ tmp, err := d.fs.TempFile("", tmpPackedRefsPrefix)
+ if err != nil {
+ return err
+ }
+ tmpName := tmp.Name()
+ defer func() {
+ ioutil.CheckClose(tmp, &err)
+ _ = d.fs.Remove(tmpName) // don't check err, we might have renamed it
+ }()
+
+ s := bufio.NewScanner(pr)
+ found := false
+ for s.Scan() {
+ line := s.Text()
+ ref, err := d.processLine(line)
+ if err != nil {
+ return err
+ }
+
+ if ref != nil && ref.Name() == name {
+ found = true
+ continue
+ }
+
+ if _, err := fmt.Fprintln(tmp, line); err != nil {
+ return err
+ }
+ }
+
+ if err := s.Err(); err != nil {
+ return err
+ }
+
+ if !found {
+ return nil
+ }
+
+ return d.rewritePackedRefsWhileLocked(tmp, pr)
+}
+
+// process lines from a packed-refs file
+func (d *DotGit) processLine(line string) (*plumbing.Reference, error) {
+ if len(line) == 0 {
+ return nil, nil
+ }
+
+ switch line[0] {
+ case '#': // comment - ignore
+ return nil, nil
+ case '^': // annotated tag commit of the previous line - ignore
+ return nil, nil
+ default:
+ ws := strings.Split(line, " ") // hash then ref
+ if len(ws) != 2 {
+ return nil, ErrPackedRefsBadFormat
+ }
+
+ return plumbing.NewReferenceFromStrings(ws[1], ws[0]), nil
+ }
+}
+
+func (d *DotGit) addRefsFromRefDir(refs *[]*plumbing.Reference, seen map[plumbing.ReferenceName]bool) error {
+ return d.walkReferencesTree(refs, []string{refsPath}, seen)
+}
+
+func (d *DotGit) walkReferencesTree(refs *[]*plumbing.Reference, relPath []string, seen map[plumbing.ReferenceName]bool) error {
+ files, err := d.fs.ReadDir(d.fs.Join(relPath...))
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+
+ return err
+ }
+
+ for _, f := range files {
+ newRelPath := append(append([]string(nil), relPath...), f.Name())
+ if f.IsDir() {
+ if err = d.walkReferencesTree(refs, newRelPath, seen); err != nil {
+ return err
+ }
+
+ continue
+ }
+
+ ref, err := d.readReferenceFile(".", strings.Join(newRelPath, "/"))
+ if err != nil {
+ return err
+ }
+
+ if ref != nil && !seen[ref.Name()] {
+ *refs = append(*refs, ref)
+ seen[ref.Name()] = true
+ }
+ }
+
+ return nil
+}
+
+func (d *DotGit) addRefFromHEAD(refs *[]*plumbing.Reference) error {
+ ref, err := d.readReferenceFile(".", "HEAD")
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+
+ return err
+ }
+
+ *refs = append(*refs, ref)
+ return nil
+}
+
+func (d *DotGit) readReferenceFile(path, name string) (ref *plumbing.Reference, err error) {
+ path = d.fs.Join(path, d.fs.Join(strings.Split(name, "/")...))
+ f, err := d.fs.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer ioutil.CheckClose(f, &err)
+
+ return d.readReferenceFrom(f, name)
+}
+
+func (d *DotGit) CountLooseRefs() (int, error) {
+ var refs []*plumbing.Reference
+ var seen = make(map[plumbing.ReferenceName]bool)
+ if err := d.addRefsFromRefDir(&refs, seen); err != nil {
+ return 0, err
+ }
+
+ return len(refs), nil
+}
+
+// PackRefs packs all loose refs into the packed-refs file.
+//
+// This implementation only works under the assumption that the view
+// of the file system won't be updated during this operation. This
+// strategy would not work on a general file system though, without
+// locking each loose reference and checking it again before deleting
+// the file, because otherwise an updated reference could sneak in and
+// then be deleted by the packed-refs process. Alternatively, every
+// ref update could also lock packed-refs, so only one lock is
+// required during ref-packing. But that would worsen performance in
+// the common case.
+//
+// TODO: add an "all" boolean like the `git pack-refs --all` flag.
+// When `all` is false, it would only pack refs that have already been
+// packed, plus all tags.
+func (d *DotGit) PackRefs() (err error) {
+ // Lock packed-refs, and create it if it doesn't exist yet.
+ f, err := d.openAndLockPackedRefs(true)
+ if err != nil {
+ return err
+ }
+ defer ioutil.CheckClose(f, &err)
+
+ // Gather all refs using addRefsFromRefDir and addRefsFromPackedRefs.
+ var refs []*plumbing.Reference
+ seen := make(map[plumbing.ReferenceName]bool)
+ if err = d.addRefsFromRefDir(&refs, seen); err != nil {
+ return err
+ }
+ if len(refs) == 0 {
+ // Nothing to do!
+ return nil
+ }
+ numLooseRefs := len(refs)
+ if err = d.addRefsFromPackedRefsFile(&refs, f, seen); err != nil {
+ return err
+ }
+
+ // Write them all to a new temp packed-refs file.
+ tmp, err := d.fs.TempFile("", tmpPackedRefsPrefix)
+ if err != nil {
+ return err
+ }
+ tmpName := tmp.Name()
+ defer func() {
+ ioutil.CheckClose(tmp, &err)
+ _ = d.fs.Remove(tmpName) // don't check err, we might have renamed it
+ }()
+
+ w := bufio.NewWriter(tmp)
+ for _, ref := range refs {
+ _, err = w.WriteString(ref.String() + "\n")
+ if err != nil {
+ return err
+ }
+ }
+ err = w.Flush()
+ if err != nil {
+ return err
+ }
+
+ // Rename the temp packed-refs file.
+ err = d.rewritePackedRefsWhileLocked(tmp, f)
+ if err != nil {
+ return err
+ }
+
+ // Delete all the loose refs, while still holding the packed-refs
+ // lock.
+ for _, ref := range refs[:numLooseRefs] {
+ path := d.fs.Join(".", ref.Name().String())
+ err = d.fs.Remove(path)
+ if err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Module return a billy.Filesystem pointing to the module folder
+func (d *DotGit) Module(name string) (billy.Filesystem, error) {
+ return d.fs.Chroot(d.fs.Join(modulePath, name))
+}
+
+// Alternates returns DotGit(s) based off paths in objects/info/alternates if
+// available. This can be used to checks if it's a shared repository.
+func (d *DotGit) Alternates() ([]*DotGit, error) {
+ altpath := d.fs.Join("objects", "info", "alternates")
+ f, err := d.fs.Open(altpath)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ var alternates []*DotGit
+
+ // Read alternate paths line-by-line and create DotGit objects.
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ path := scanner.Text()
+ if !filepath.IsAbs(path) {
+ // For relative paths, we can perform an internal conversion to
+ // slash so that they work cross-platform.
+ slashPath := filepath.ToSlash(path)
+ // If the path is not absolute, it must be relative to object
+ // database (.git/objects/info).
+ // https://www.kernel.org/pub/software/scm/git/docs/gitrepository-layout.html
+ // Hence, derive a path relative to DotGit's root.
+ // "../../../reponame/.git/" -> "../../reponame/.git"
+ // Remove the first ../
+ relpath := filepath.Join(strings.Split(slashPath, "/")[1:]...)
+ normalPath := filepath.FromSlash(relpath)
+ path = filepath.Join(d.fs.Root(), normalPath)
+ }
+ fs := osfs.New(filepath.Dir(path))
+ alternates = append(alternates, New(fs))
+ }
+
+ if err = scanner.Err(); err != nil {
+ return nil, err
+ }
+
+ return alternates, nil
+}
+
+// Fs returns the underlying filesystem of the DotGit folder.
+func (d *DotGit) Fs() billy.Filesystem {
+ return d.fs
+}
+
+func isHex(s string) bool {
+ for _, b := range []byte(s) {
+ if isNum(b) {
+ continue
+ }
+ if isHexAlpha(b) {
+ continue
+ }
+
+ return false
+ }
+
+ return true
+}
+
+func isNum(b byte) bool {
+ return b >= '0' && b <= '9'
+}
+
+func isHexAlpha(b byte) bool {
+ return b >= 'a' && b <= 'f' || b >= 'A' && b <= 'F'
+}
--- /dev/null
+package dotgit
+
+import (
+ "io"
+ "os"
+ "runtime"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+func (d *DotGit) openAndLockPackedRefsMode() int {
+ if billy.CapabilityCheck(d.fs, billy.ReadAndWriteCapability) {
+ return os.O_RDWR
+ }
+
+ return os.O_RDONLY
+}
+
+func (d *DotGit) rewritePackedRefsWhileLocked(
+ tmp billy.File, pr billy.File) error {
+ // Try plain rename. If we aren't using the bare Windows filesystem as the
+ // storage layer, we might be able to get away with a rename over a locked
+ // file.
+ err := d.fs.Rename(tmp.Name(), pr.Name())
+ if err == nil {
+ return nil
+ }
+
+ // If we are in a filesystem that does not support rename (e.g. sivafs)
+ // a full copy is done.
+ if err == billy.ErrNotSupported {
+ return d.copyNewFile(tmp, pr)
+ }
+
+ if runtime.GOOS != "windows" {
+ return err
+ }
+
+ // Otherwise, Windows doesn't let us rename over a locked file, so
+ // we have to do a straight copy. Unfortunately this could result
+ // in a partially-written file if the process fails before the
+ // copy completes.
+ return d.copyToExistingFile(tmp, pr)
+}
+
+func (d *DotGit) copyToExistingFile(tmp, pr billy.File) error {
+ _, err := pr.Seek(0, io.SeekStart)
+ if err != nil {
+ return err
+ }
+ err = pr.Truncate(0)
+ if err != nil {
+ return err
+ }
+ _, err = tmp.Seek(0, io.SeekStart)
+ if err != nil {
+ return err
+ }
+ _, err = io.Copy(pr, tmp)
+
+ return err
+}
+
+func (d *DotGit) copyNewFile(tmp billy.File, pr billy.File) (err error) {
+ prWrite, err := d.fs.Create(pr.Name())
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(prWrite, &err)
+
+ _, err = tmp.Seek(0, io.SeekStart)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.Copy(prWrite, tmp)
+
+ return err
+}
--- /dev/null
+// +build !norwfs
+
+package dotgit
+
+import (
+ "os"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+func (d *DotGit) setRef(fileName, content string, old *plumbing.Reference) (err error) {
+ // If we are not checking an old ref, just truncate the file.
+ mode := os.O_RDWR | os.O_CREATE
+ if old == nil {
+ mode |= os.O_TRUNC
+ }
+
+ f, err := d.fs.OpenFile(fileName, mode, 0666)
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ // Lock is unlocked by the deferred Close above. This is because Unlock
+ // does not imply a fsync and thus there would be a race between
+ // Unlock+Close and other concurrent writers. Adding Sync to go-billy
+ // could work, but this is better (and avoids superfluous syncs).
+ err = f.Lock()
+ if err != nil {
+ return err
+ }
+
+ // this is a no-op to call even when old is nil.
+ err = d.checkReferenceAndTruncate(f, old)
+ if err != nil {
+ return err
+ }
+
+ _, err = f.Write([]byte(content))
+ return err
+}
--- /dev/null
+// +build norwfs
+
+package dotgit
+
+import (
+ "fmt"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// There are some filesystems that don't support opening files in RDWD mode.
+// In these filesystems the standard SetRef function can not be used as i
+// reads the reference file to check that it's not modified before updating it.
+//
+// This version of the function writes the reference without extra checks
+// making it compatible with these simple filesystems. This is usually not
+// a problem as they should be accessed by only one process at a time.
+func (d *DotGit) setRef(fileName, content string, old *plumbing.Reference) error {
+ _, err := d.fs.Stat(fileName)
+ if err == nil && old != nil {
+ fRead, err := d.fs.Open(fileName)
+ if err != nil {
+ return err
+ }
+
+ ref, err := d.readReferenceFrom(fRead, old.Name().String())
+ fRead.Close()
+
+ if err != nil {
+ return err
+ }
+
+ if ref.Hash() != old.Hash() {
+ return fmt.Errorf("reference has changed concurrently")
+ }
+ }
+
+ f, err := d.fs.Create(fileName)
+ if err != nil {
+ return err
+ }
+
+ defer f.Close()
+
+ _, err = f.Write([]byte(content))
+ return err
+}
--- /dev/null
+package dotgit
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/idxfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/objfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/packfile"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+// PackWriter is a io.Writer that generates the packfile index simultaneously,
+// a packfile.Decoder is used with a file reader to read the file being written
+// this operation is synchronized with the write operations.
+// The packfile is written in a temp file, when Close is called this file
+// is renamed/moved (depends on the Filesystem implementation) to the final
+// location, if the PackWriter is not used, nothing is written
+type PackWriter struct {
+ Notify func(plumbing.Hash, *idxfile.Writer)
+
+ fs billy.Filesystem
+ fr, fw billy.File
+ synced *syncedReader
+ checksum plumbing.Hash
+ parser *packfile.Parser
+ writer *idxfile.Writer
+ result chan error
+}
+
+func newPackWrite(fs billy.Filesystem) (*PackWriter, error) {
+ fw, err := fs.TempFile(fs.Join(objectsPath, packPath), "tmp_pack_")
+ if err != nil {
+ return nil, err
+ }
+
+ fr, err := fs.Open(fw.Name())
+ if err != nil {
+ return nil, err
+ }
+
+ writer := &PackWriter{
+ fs: fs,
+ fw: fw,
+ fr: fr,
+ synced: newSyncedReader(fw, fr),
+ result: make(chan error),
+ }
+
+ go writer.buildIndex()
+ return writer, nil
+}
+
+func (w *PackWriter) buildIndex() {
+ s := packfile.NewScanner(w.synced)
+ w.writer = new(idxfile.Writer)
+ var err error
+ w.parser, err = packfile.NewParser(s, w.writer)
+ if err != nil {
+ w.result <- err
+ return
+ }
+
+ checksum, err := w.parser.Parse()
+ if err != nil {
+ w.result <- err
+ return
+ }
+
+ w.checksum = checksum
+ w.result <- err
+}
+
+// waitBuildIndex waits until buildIndex function finishes, this can terminate
+// with a packfile.ErrEmptyPackfile, this means that nothing was written so we
+// ignore the error
+func (w *PackWriter) waitBuildIndex() error {
+ err := <-w.result
+ if err == packfile.ErrEmptyPackfile {
+ return nil
+ }
+
+ return err
+}
+
+func (w *PackWriter) Write(p []byte) (int, error) {
+ return w.synced.Write(p)
+}
+
+// Close closes all the file descriptors and save the final packfile, if nothing
+// was written, the tempfiles are deleted without writing a packfile.
+func (w *PackWriter) Close() error {
+ defer func() {
+ if w.Notify != nil && w.writer != nil && w.writer.Finished() {
+ w.Notify(w.checksum, w.writer)
+ }
+
+ close(w.result)
+ }()
+
+ if err := w.synced.Close(); err != nil {
+ return err
+ }
+
+ if err := w.waitBuildIndex(); err != nil {
+ return err
+ }
+
+ if err := w.fr.Close(); err != nil {
+ return err
+ }
+
+ if err := w.fw.Close(); err != nil {
+ return err
+ }
+
+ if w.writer == nil || !w.writer.Finished() {
+ return w.clean()
+ }
+
+ return w.save()
+}
+
+func (w *PackWriter) clean() error {
+ return w.fs.Remove(w.fw.Name())
+}
+
+func (w *PackWriter) save() error {
+ base := w.fs.Join(objectsPath, packPath, fmt.Sprintf("pack-%s", w.checksum))
+ idx, err := w.fs.Create(fmt.Sprintf("%s.idx", base))
+ if err != nil {
+ return err
+ }
+
+ if err := w.encodeIdx(idx); err != nil {
+ return err
+ }
+
+ if err := idx.Close(); err != nil {
+ return err
+ }
+
+ return w.fs.Rename(w.fw.Name(), fmt.Sprintf("%s.pack", base))
+}
+
+func (w *PackWriter) encodeIdx(writer io.Writer) error {
+ idx, err := w.writer.Index()
+ if err != nil {
+ return err
+ }
+
+ e := idxfile.NewEncoder(writer)
+ _, err = e.Encode(idx)
+ return err
+}
+
+type syncedReader struct {
+ w io.Writer
+ r io.ReadSeeker
+
+ blocked, done uint32
+ written, read uint64
+ news chan bool
+}
+
+func newSyncedReader(w io.Writer, r io.ReadSeeker) *syncedReader {
+ return &syncedReader{
+ w: w,
+ r: r,
+ news: make(chan bool),
+ }
+}
+
+func (s *syncedReader) Write(p []byte) (n int, err error) {
+ defer func() {
+ written := atomic.AddUint64(&s.written, uint64(n))
+ read := atomic.LoadUint64(&s.read)
+ if written > read {
+ s.wake()
+ }
+ }()
+
+ n, err = s.w.Write(p)
+ return
+}
+
+func (s *syncedReader) Read(p []byte) (n int, err error) {
+ defer func() { atomic.AddUint64(&s.read, uint64(n)) }()
+
+ for {
+ s.sleep()
+ n, err = s.r.Read(p)
+ if err == io.EOF && !s.isDone() && n == 0 {
+ continue
+ }
+
+ break
+ }
+
+ return
+}
+
+func (s *syncedReader) isDone() bool {
+ return atomic.LoadUint32(&s.done) == 1
+}
+
+func (s *syncedReader) isBlocked() bool {
+ return atomic.LoadUint32(&s.blocked) == 1
+}
+
+func (s *syncedReader) wake() {
+ if s.isBlocked() {
+ atomic.StoreUint32(&s.blocked, 0)
+ s.news <- true
+ }
+}
+
+func (s *syncedReader) sleep() {
+ read := atomic.LoadUint64(&s.read)
+ written := atomic.LoadUint64(&s.written)
+ if read >= written {
+ atomic.StoreUint32(&s.blocked, 1)
+ <-s.news
+ }
+
+}
+
+func (s *syncedReader) Seek(offset int64, whence int) (int64, error) {
+ if whence == io.SeekCurrent {
+ return s.r.Seek(offset, whence)
+ }
+
+ p, err := s.r.Seek(offset, whence)
+ atomic.StoreUint64(&s.read, uint64(p))
+
+ return p, err
+}
+
+func (s *syncedReader) Close() error {
+ atomic.StoreUint32(&s.done, 1)
+ close(s.news)
+ return nil
+}
+
+type ObjectWriter struct {
+ objfile.Writer
+ fs billy.Filesystem
+ f billy.File
+}
+
+func newObjectWriter(fs billy.Filesystem) (*ObjectWriter, error) {
+ f, err := fs.TempFile(fs.Join(objectsPath, packPath), "tmp_obj_")
+ if err != nil {
+ return nil, err
+ }
+
+ return &ObjectWriter{
+ Writer: (*objfile.NewWriter(f)),
+ fs: fs,
+ f: f,
+ }, nil
+}
+
+func (w *ObjectWriter) Close() error {
+ if err := w.Writer.Close(); err != nil {
+ return err
+ }
+
+ if err := w.f.Close(); err != nil {
+ return err
+ }
+
+ return w.save()
+}
+
+func (w *ObjectWriter) save() error {
+ hash := w.Hash().String()
+ file := w.fs.Join(objectsPath, hash[0:2], hash[2:40])
+
+ return w.fs.Rename(w.f.Name(), file)
+}
--- /dev/null
+package filesystem
+
+import (
+ "os"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+type IndexStorage struct {
+ dir *dotgit.DotGit
+}
+
+func (s *IndexStorage) SetIndex(idx *index.Index) (err error) {
+ f, err := s.dir.IndexWriter()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ e := index.NewEncoder(f)
+ err = e.Encode(idx)
+ return err
+}
+
+func (s *IndexStorage) Index() (i *index.Index, err error) {
+ idx := &index.Index{
+ Version: 2,
+ }
+
+ f, err := s.dir.Index()
+ if err != nil {
+ if os.IsNotExist(err) {
+ return idx, nil
+ }
+
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ d := index.NewDecoder(f)
+ err = d.Decode(idx)
+ return idx, err
+}
--- /dev/null
+package filesystem
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/storage"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+)
+
+type ModuleStorage struct {
+ dir *dotgit.DotGit
+}
+
+func (s *ModuleStorage) Module(name string) (storage.Storer, error) {
+ fs, err := s.dir.Module(name)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewStorage(fs, cache.NewObjectLRUDefault()), nil
+}
--- /dev/null
+package filesystem
+
+import (
+ "io"
+ "os"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/idxfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/objfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/packfile"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+type ObjectStorage struct {
+ options Options
+
+ // deltaBaseCache is an object cache uses to cache delta's bases when
+ deltaBaseCache cache.Object
+
+ dir *dotgit.DotGit
+ index map[plumbing.Hash]idxfile.Index
+}
+
+// NewObjectStorage creates a new ObjectStorage with the given .git directory and cache.
+func NewObjectStorage(dir *dotgit.DotGit, cache cache.Object) *ObjectStorage {
+ return NewObjectStorageWithOptions(dir, cache, Options{})
+}
+
+// NewObjectStorageWithOptions creates a new ObjectStorage with the given .git directory, cache and extra options
+func NewObjectStorageWithOptions(dir *dotgit.DotGit, cache cache.Object, ops Options) *ObjectStorage {
+ return &ObjectStorage{
+ options: ops,
+ deltaBaseCache: cache,
+ dir: dir,
+ }
+}
+
+func (s *ObjectStorage) requireIndex() error {
+ if s.index != nil {
+ return nil
+ }
+
+ s.index = make(map[plumbing.Hash]idxfile.Index)
+ packs, err := s.dir.ObjectPacks()
+ if err != nil {
+ return err
+ }
+
+ for _, h := range packs {
+ if err := s.loadIdxFile(h); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Reindex indexes again all packfiles. Useful if git changed packfiles externally
+func (s *ObjectStorage) Reindex() {
+ s.index = nil
+}
+
+func (s *ObjectStorage) loadIdxFile(h plumbing.Hash) (err error) {
+ f, err := s.dir.ObjectPackIdx(h)
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ idxf := idxfile.NewMemoryIndex()
+ d := idxfile.NewDecoder(f)
+ if err = d.Decode(idxf); err != nil {
+ return err
+ }
+
+ s.index[h] = idxf
+ return err
+}
+
+func (s *ObjectStorage) NewEncodedObject() plumbing.EncodedObject {
+ return &plumbing.MemoryObject{}
+}
+
+func (s *ObjectStorage) PackfileWriter() (io.WriteCloser, error) {
+ if err := s.requireIndex(); err != nil {
+ return nil, err
+ }
+
+ w, err := s.dir.NewObjectPack()
+ if err != nil {
+ return nil, err
+ }
+
+ w.Notify = func(h plumbing.Hash, writer *idxfile.Writer) {
+ index, err := writer.Index()
+ if err == nil {
+ s.index[h] = index
+ }
+ }
+
+ return w, nil
+}
+
+// SetEncodedObject adds a new object to the storage.
+func (s *ObjectStorage) SetEncodedObject(o plumbing.EncodedObject) (h plumbing.Hash, err error) {
+ if o.Type() == plumbing.OFSDeltaObject || o.Type() == plumbing.REFDeltaObject {
+ return plumbing.ZeroHash, plumbing.ErrInvalidType
+ }
+
+ ow, err := s.dir.NewObject()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ defer ioutil.CheckClose(ow, &err)
+
+ or, err := o.Reader()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ defer ioutil.CheckClose(or, &err)
+
+ if err = ow.WriteHeader(o.Type(), o.Size()); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if _, err = io.Copy(ow, or); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return o.Hash(), err
+}
+
+// HasEncodedObject returns nil if the object exists, without actually
+// reading the object data from storage.
+func (s *ObjectStorage) HasEncodedObject(h plumbing.Hash) (err error) {
+ // Check unpacked objects
+ f, err := s.dir.Object(h)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ return err
+ }
+ // Fall through to check packed objects.
+ } else {
+ defer ioutil.CheckClose(f, &err)
+ return nil
+ }
+
+ // Check packed objects.
+ if err := s.requireIndex(); err != nil {
+ return err
+ }
+ _, _, offset := s.findObjectInPackfile(h)
+ if offset == -1 {
+ return plumbing.ErrObjectNotFound
+ }
+ return nil
+}
+
+func (s *ObjectStorage) encodedObjectSizeFromUnpacked(h plumbing.Hash) (
+ size int64, err error) {
+ f, err := s.dir.Object(h)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ return 0, err
+ }
+
+ r, err := objfile.NewReader(f)
+ if err != nil {
+ return 0, err
+ }
+ defer ioutil.CheckClose(r, &err)
+
+ _, size, err = r.Header()
+ return size, err
+}
+
+func (s *ObjectStorage) encodedObjectSizeFromPackfile(h plumbing.Hash) (
+ size int64, err error) {
+ if err := s.requireIndex(); err != nil {
+ return 0, err
+ }
+
+ pack, _, offset := s.findObjectInPackfile(h)
+ if offset == -1 {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ f, err := s.dir.ObjectPack(pack)
+ if err != nil {
+ return 0, err
+ }
+ defer ioutil.CheckClose(f, &err)
+
+ idx := s.index[pack]
+ hash, err := idx.FindHash(offset)
+ if err == nil {
+ obj, ok := s.deltaBaseCache.Get(hash)
+ if ok {
+ return obj.Size(), nil
+ }
+ } else if err != nil && err != plumbing.ErrObjectNotFound {
+ return 0, err
+ }
+
+ var p *packfile.Packfile
+ if s.deltaBaseCache != nil {
+ p = packfile.NewPackfileWithCache(idx, s.dir.Fs(), f, s.deltaBaseCache)
+ } else {
+ p = packfile.NewPackfile(idx, s.dir.Fs(), f)
+ }
+
+ return p.GetSizeByOffset(offset)
+}
+
+// EncodedObjectSize returns the plaintext size of the given object,
+// without actually reading the full object data from storage.
+func (s *ObjectStorage) EncodedObjectSize(h plumbing.Hash) (
+ size int64, err error) {
+ size, err = s.encodedObjectSizeFromUnpacked(h)
+ if err != nil && err != plumbing.ErrObjectNotFound {
+ return 0, err
+ } else if err == nil {
+ return size, nil
+ }
+
+ return s.encodedObjectSizeFromPackfile(h)
+}
+
+// EncodedObject returns the object with the given hash, by searching for it in
+// the packfile and the git object directories.
+func (s *ObjectStorage) EncodedObject(t plumbing.ObjectType, h plumbing.Hash) (plumbing.EncodedObject, error) {
+ obj, err := s.getFromUnpacked(h)
+ if err == plumbing.ErrObjectNotFound {
+ obj, err = s.getFromPackfile(h, false)
+ }
+
+ // If the error is still object not found, check if it's a shared object
+ // repository.
+ if err == plumbing.ErrObjectNotFound {
+ dotgits, e := s.dir.Alternates()
+ if e == nil {
+ // Create a new object storage with the DotGit(s) and check for the
+ // required hash object. Skip when not found.
+ for _, dg := range dotgits {
+ o := NewObjectStorage(dg, s.deltaBaseCache)
+ enobj, enerr := o.EncodedObject(t, h)
+ if enerr != nil {
+ continue
+ }
+ return enobj, nil
+ }
+ }
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ if plumbing.AnyObject != t && obj.Type() != t {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ return obj, nil
+}
+
+// DeltaObject returns the object with the given hash, by searching for
+// it in the packfile and the git object directories.
+func (s *ObjectStorage) DeltaObject(t plumbing.ObjectType,
+ h plumbing.Hash) (plumbing.EncodedObject, error) {
+ obj, err := s.getFromUnpacked(h)
+ if err == plumbing.ErrObjectNotFound {
+ obj, err = s.getFromPackfile(h, true)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ if plumbing.AnyObject != t && obj.Type() != t {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ return obj, nil
+}
+
+func (s *ObjectStorage) getFromUnpacked(h plumbing.Hash) (obj plumbing.EncodedObject, err error) {
+ f, err := s.dir.Object(h)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ obj = s.NewEncodedObject()
+ r, err := objfile.NewReader(f)
+ if err != nil {
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(r, &err)
+
+ t, size, err := r.Header()
+ if err != nil {
+ return nil, err
+ }
+
+ obj.SetType(t)
+ obj.SetSize(size)
+ w, err := obj.Writer()
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = io.Copy(w, r)
+ return obj, err
+}
+
+// Get returns the object with the given hash, by searching for it in
+// the packfile.
+func (s *ObjectStorage) getFromPackfile(h plumbing.Hash, canBeDelta bool) (
+ plumbing.EncodedObject, error) {
+
+ if err := s.requireIndex(); err != nil {
+ return nil, err
+ }
+
+ pack, hash, offset := s.findObjectInPackfile(h)
+ if offset == -1 {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ f, err := s.dir.ObjectPack(pack)
+ if err != nil {
+ return nil, err
+ }
+
+ if !s.options.KeepDescriptors {
+ defer ioutil.CheckClose(f, &err)
+ }
+
+ idx := s.index[pack]
+ if canBeDelta {
+ return s.decodeDeltaObjectAt(f, idx, offset, hash)
+ }
+
+ return s.decodeObjectAt(f, idx, offset)
+}
+
+func (s *ObjectStorage) decodeObjectAt(
+ f billy.File,
+ idx idxfile.Index,
+ offset int64,
+) (plumbing.EncodedObject, error) {
+ hash, err := idx.FindHash(offset)
+ if err == nil {
+ obj, ok := s.deltaBaseCache.Get(hash)
+ if ok {
+ return obj, nil
+ }
+ }
+
+ if err != nil && err != plumbing.ErrObjectNotFound {
+ return nil, err
+ }
+
+ var p *packfile.Packfile
+ if s.deltaBaseCache != nil {
+ p = packfile.NewPackfileWithCache(idx, s.dir.Fs(), f, s.deltaBaseCache)
+ } else {
+ p = packfile.NewPackfile(idx, s.dir.Fs(), f)
+ }
+
+ return p.GetByOffset(offset)
+}
+
+func (s *ObjectStorage) decodeDeltaObjectAt(
+ f billy.File,
+ idx idxfile.Index,
+ offset int64,
+ hash plumbing.Hash,
+) (plumbing.EncodedObject, error) {
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ return nil, err
+ }
+
+ p := packfile.NewScanner(f)
+ if _, err := p.SeekFromStart(offset); err != nil {
+ return nil, err
+ }
+
+ header, err := p.NextObjectHeader()
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ base plumbing.Hash
+ )
+
+ switch header.Type {
+ case plumbing.REFDeltaObject:
+ base = header.Reference
+ case plumbing.OFSDeltaObject:
+ base, err = idx.FindHash(header.OffsetReference)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ return s.decodeObjectAt(f, idx, offset)
+ }
+
+ obj := &plumbing.MemoryObject{}
+ obj.SetType(header.Type)
+ w, err := obj.Writer()
+ if err != nil {
+ return nil, err
+ }
+
+ if _, _, err := p.NextObject(w); err != nil {
+ return nil, err
+ }
+
+ return newDeltaObject(obj, hash, base, header.Length), nil
+}
+
+func (s *ObjectStorage) findObjectInPackfile(h plumbing.Hash) (plumbing.Hash, plumbing.Hash, int64) {
+ for packfile, index := range s.index {
+ offset, err := index.FindOffset(h)
+ if err == nil {
+ return packfile, h, offset
+ }
+ }
+
+ return plumbing.ZeroHash, plumbing.ZeroHash, -1
+}
+
+// IterEncodedObjects returns an iterator for all the objects in the packfile
+// with the given type.
+func (s *ObjectStorage) IterEncodedObjects(t plumbing.ObjectType) (storer.EncodedObjectIter, error) {
+ objects, err := s.dir.Objects()
+ if err != nil {
+ return nil, err
+ }
+
+ seen := make(map[plumbing.Hash]struct{})
+ var iters []storer.EncodedObjectIter
+ if len(objects) != 0 {
+ iters = append(iters, &objectsIter{s: s, t: t, h: objects})
+ seen = hashListAsMap(objects)
+ }
+
+ packi, err := s.buildPackfileIters(t, seen)
+ if err != nil {
+ return nil, err
+ }
+
+ iters = append(iters, packi)
+ return storer.NewMultiEncodedObjectIter(iters), nil
+}
+
+func (s *ObjectStorage) buildPackfileIters(
+ t plumbing.ObjectType,
+ seen map[plumbing.Hash]struct{},
+) (storer.EncodedObjectIter, error) {
+ if err := s.requireIndex(); err != nil {
+ return nil, err
+ }
+
+ packs, err := s.dir.ObjectPacks()
+ if err != nil {
+ return nil, err
+ }
+ return &lazyPackfilesIter{
+ hashes: packs,
+ open: func(h plumbing.Hash) (storer.EncodedObjectIter, error) {
+ pack, err := s.dir.ObjectPack(h)
+ if err != nil {
+ return nil, err
+ }
+ return newPackfileIter(
+ s.dir.Fs(), pack, t, seen, s.index[h],
+ s.deltaBaseCache, s.options.KeepDescriptors,
+ )
+ },
+ }, nil
+}
+
+// Close closes all opened files.
+func (s *ObjectStorage) Close() error {
+ return s.dir.Close()
+}
+
+type lazyPackfilesIter struct {
+ hashes []plumbing.Hash
+ open func(h plumbing.Hash) (storer.EncodedObjectIter, error)
+ cur storer.EncodedObjectIter
+}
+
+func (it *lazyPackfilesIter) Next() (plumbing.EncodedObject, error) {
+ for {
+ if it.cur == nil {
+ if len(it.hashes) == 0 {
+ return nil, io.EOF
+ }
+ h := it.hashes[0]
+ it.hashes = it.hashes[1:]
+
+ sub, err := it.open(h)
+ if err == io.EOF {
+ continue
+ } else if err != nil {
+ return nil, err
+ }
+ it.cur = sub
+ }
+ ob, err := it.cur.Next()
+ if err == io.EOF {
+ it.cur.Close()
+ it.cur = nil
+ continue
+ } else if err != nil {
+ return nil, err
+ }
+ return ob, nil
+ }
+}
+
+func (it *lazyPackfilesIter) ForEach(cb func(plumbing.EncodedObject) error) error {
+ return storer.ForEachIterator(it, cb)
+}
+
+func (it *lazyPackfilesIter) Close() {
+ if it.cur != nil {
+ it.cur.Close()
+ it.cur = nil
+ }
+ it.hashes = nil
+}
+
+type packfileIter struct {
+ pack billy.File
+ iter storer.EncodedObjectIter
+ seen map[plumbing.Hash]struct{}
+
+ // tells whether the pack file should be left open after iteration or not
+ keepPack bool
+}
+
+// NewPackfileIter returns a new EncodedObjectIter for the provided packfile
+// and object type. Packfile and index file will be closed after they're
+// used. If keepPack is true the packfile won't be closed after the iteration
+// finished.
+func NewPackfileIter(
+ fs billy.Filesystem,
+ f billy.File,
+ idxFile billy.File,
+ t plumbing.ObjectType,
+ keepPack bool,
+) (storer.EncodedObjectIter, error) {
+ idx := idxfile.NewMemoryIndex()
+ if err := idxfile.NewDecoder(idxFile).Decode(idx); err != nil {
+ return nil, err
+ }
+
+ if err := idxFile.Close(); err != nil {
+ return nil, err
+ }
+
+ seen := make(map[plumbing.Hash]struct{})
+ return newPackfileIter(fs, f, t, seen, idx, nil, keepPack)
+}
+
+func newPackfileIter(
+ fs billy.Filesystem,
+ f billy.File,
+ t plumbing.ObjectType,
+ seen map[plumbing.Hash]struct{},
+ index idxfile.Index,
+ cache cache.Object,
+ keepPack bool,
+) (storer.EncodedObjectIter, error) {
+ var p *packfile.Packfile
+ if cache != nil {
+ p = packfile.NewPackfileWithCache(index, fs, f, cache)
+ } else {
+ p = packfile.NewPackfile(index, fs, f)
+ }
+
+ iter, err := p.GetByType(t)
+ if err != nil {
+ return nil, err
+ }
+
+ return &packfileIter{
+ pack: f,
+ iter: iter,
+ seen: seen,
+ keepPack: keepPack,
+ }, nil
+}
+
+func (iter *packfileIter) Next() (plumbing.EncodedObject, error) {
+ for {
+ obj, err := iter.iter.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ if _, ok := iter.seen[obj.Hash()]; ok {
+ continue
+ }
+
+ return obj, nil
+ }
+}
+
+func (iter *packfileIter) ForEach(cb func(plumbing.EncodedObject) error) error {
+ for {
+ o, err := iter.Next()
+ if err != nil {
+ if err == io.EOF {
+ iter.Close()
+ return nil
+ }
+ return err
+ }
+
+ if err := cb(o); err != nil {
+ return err
+ }
+ }
+}
+
+func (iter *packfileIter) Close() {
+ iter.iter.Close()
+ if !iter.keepPack {
+ _ = iter.pack.Close()
+ }
+}
+
+type objectsIter struct {
+ s *ObjectStorage
+ t plumbing.ObjectType
+ h []plumbing.Hash
+}
+
+func (iter *objectsIter) Next() (plumbing.EncodedObject, error) {
+ if len(iter.h) == 0 {
+ return nil, io.EOF
+ }
+
+ obj, err := iter.s.getFromUnpacked(iter.h[0])
+ iter.h = iter.h[1:]
+
+ if err != nil {
+ return nil, err
+ }
+
+ if iter.t != plumbing.AnyObject && iter.t != obj.Type() {
+ return iter.Next()
+ }
+
+ return obj, err
+}
+
+func (iter *objectsIter) ForEach(cb func(plumbing.EncodedObject) error) error {
+ for {
+ o, err := iter.Next()
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+ return err
+ }
+
+ if err := cb(o); err != nil {
+ return err
+ }
+ }
+}
+
+func (iter *objectsIter) Close() {
+ iter.h = []plumbing.Hash{}
+}
+
+func hashListAsMap(l []plumbing.Hash) map[plumbing.Hash]struct{} {
+ m := make(map[plumbing.Hash]struct{}, len(l))
+ for _, h := range l {
+ m[h] = struct{}{}
+ }
+ return m
+}
+
+func (s *ObjectStorage) ForEachObjectHash(fun func(plumbing.Hash) error) error {
+ err := s.dir.ForEachObjectHash(fun)
+ if err == storer.ErrStop {
+ return nil
+ }
+ return err
+}
+
+func (s *ObjectStorage) LooseObjectTime(hash plumbing.Hash) (time.Time, error) {
+ fi, err := s.dir.ObjectStat(hash)
+ if err != nil {
+ return time.Time{}, err
+ }
+ return fi.ModTime(), nil
+}
+
+func (s *ObjectStorage) DeleteLooseObject(hash plumbing.Hash) error {
+ return s.dir.ObjectDelete(hash)
+}
+
+func (s *ObjectStorage) ObjectPacks() ([]plumbing.Hash, error) {
+ return s.dir.ObjectPacks()
+}
+
+func (s *ObjectStorage) DeleteOldObjectPackAndIndex(h plumbing.Hash, t time.Time) error {
+ return s.dir.DeleteOldObjectPackAndIndex(h, t)
+}
--- /dev/null
+package filesystem
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+)
+
+type ReferenceStorage struct {
+ dir *dotgit.DotGit
+}
+
+func (r *ReferenceStorage) SetReference(ref *plumbing.Reference) error {
+ return r.dir.SetRef(ref, nil)
+}
+
+func (r *ReferenceStorage) CheckAndSetReference(ref, old *plumbing.Reference) error {
+ return r.dir.SetRef(ref, old)
+}
+
+func (r *ReferenceStorage) Reference(n plumbing.ReferenceName) (*plumbing.Reference, error) {
+ return r.dir.Ref(n)
+}
+
+func (r *ReferenceStorage) IterReferences() (storer.ReferenceIter, error) {
+ refs, err := r.dir.Refs()
+ if err != nil {
+ return nil, err
+ }
+
+ return storer.NewReferenceSliceIter(refs), nil
+}
+
+func (r *ReferenceStorage) RemoveReference(n plumbing.ReferenceName) error {
+ return r.dir.RemoveRef(n)
+}
+
+func (r *ReferenceStorage) CountLooseRefs() (int, error) {
+ return r.dir.CountLooseRefs()
+}
+
+func (r *ReferenceStorage) PackRefs() error {
+ return r.dir.PackRefs()
+}
--- /dev/null
+package filesystem
+
+import (
+ "bufio"
+ "fmt"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+)
+
+// ShallowStorage where the shallow commits are stored, an internal to
+// manipulate the shallow file
+type ShallowStorage struct {
+ dir *dotgit.DotGit
+}
+
+// SetShallow save the shallows in the shallow file in the .git folder as one
+// commit per line represented by 40-byte hexadecimal object terminated by a
+// newline.
+func (s *ShallowStorage) SetShallow(commits []plumbing.Hash) error {
+ f, err := s.dir.ShallowWriter()
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+ for _, h := range commits {
+ if _, err := fmt.Fprintf(f, "%s\n", h); err != nil {
+ return err
+ }
+ }
+
+ return err
+}
+
+// Shallow return the shallow commits reading from shallo file from .git
+func (s *ShallowStorage) Shallow() ([]plumbing.Hash, error) {
+ f, err := s.dir.Shallow()
+ if f == nil || err != nil {
+ return nil, err
+ }
+
+ defer ioutil.CheckClose(f, &err)
+
+ var hash []plumbing.Hash
+
+ scn := bufio.NewScanner(f)
+ for scn.Scan() {
+ hash = append(hash, plumbing.NewHash(scn.Text()))
+ }
+
+ return hash, scn.Err()
+}
--- /dev/null
+// Package filesystem is a storage backend base on filesystems
+package filesystem
+
+import (
+ "gopkg.in/src-d/go-git.v4/plumbing/cache"
+ "gopkg.in/src-d/go-git.v4/storage/filesystem/dotgit"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+// Storage is an implementation of git.Storer that stores data on disk in the
+// standard git format (this is, the .git directory). Zero values of this type
+// are not safe to use, see the NewStorage function below.
+type Storage struct {
+ fs billy.Filesystem
+ dir *dotgit.DotGit
+
+ ObjectStorage
+ ReferenceStorage
+ IndexStorage
+ ShallowStorage
+ ConfigStorage
+ ModuleStorage
+}
+
+// Options holds configuration for the storage.
+type Options struct {
+ // ExclusiveAccess means that the filesystem is not modified externally
+ // while the repo is open.
+ ExclusiveAccess bool
+ // KeepDescriptors makes the file descriptors to be reused but they will
+ // need to be manually closed calling Close().
+ KeepDescriptors bool
+}
+
+// NewStorage returns a new Storage backed by a given `fs.Filesystem` and cache.
+func NewStorage(fs billy.Filesystem, cache cache.Object) *Storage {
+ return NewStorageWithOptions(fs, cache, Options{})
+}
+
+// NewStorageWithOptions returns a new Storage with extra options,
+// backed by a given `fs.Filesystem` and cache.
+func NewStorageWithOptions(fs billy.Filesystem, cache cache.Object, ops Options) *Storage {
+ dirOps := dotgit.Options{
+ ExclusiveAccess: ops.ExclusiveAccess,
+ KeepDescriptors: ops.KeepDescriptors,
+ }
+ dir := dotgit.NewWithOptions(fs, dirOps)
+
+ return &Storage{
+ fs: fs,
+ dir: dir,
+
+ ObjectStorage: ObjectStorage{
+ options: ops,
+ deltaBaseCache: cache,
+ dir: dir,
+ },
+ ReferenceStorage: ReferenceStorage{dir: dir},
+ IndexStorage: IndexStorage{dir: dir},
+ ShallowStorage: ShallowStorage{dir: dir},
+ ConfigStorage: ConfigStorage{dir: dir},
+ ModuleStorage: ModuleStorage{dir: dir},
+ }
+}
+
+// Filesystem returns the underlying filesystem
+func (s *Storage) Filesystem() billy.Filesystem {
+ return s.fs
+}
+
+// Init initializes .git directory
+func (s *Storage) Init() error {
+ return s.dir.Initialize()
+}
--- /dev/null
+// Package memory is a storage backend base on memory
+package memory
+
+import (
+ "fmt"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/storage"
+)
+
+var ErrUnsupportedObjectType = fmt.Errorf("unsupported object type")
+var ErrRefHasChanged = fmt.Errorf("reference has changed concurrently")
+
+// Storage is an implementation of git.Storer that stores data on memory, being
+// ephemeral. The use of this storage should be done in controlled envoriments,
+// since the representation in memory of some repository can fill the machine
+// memory. in the other hand this storage has the best performance.
+type Storage struct {
+ ConfigStorage
+ ObjectStorage
+ ShallowStorage
+ IndexStorage
+ ReferenceStorage
+ ModuleStorage
+}
+
+// NewStorage returns a new Storage base on memory
+func NewStorage() *Storage {
+ return &Storage{
+ ReferenceStorage: make(ReferenceStorage),
+ ConfigStorage: ConfigStorage{},
+ ShallowStorage: ShallowStorage{},
+ ObjectStorage: ObjectStorage{
+ Objects: make(map[plumbing.Hash]plumbing.EncodedObject),
+ Commits: make(map[plumbing.Hash]plumbing.EncodedObject),
+ Trees: make(map[plumbing.Hash]plumbing.EncodedObject),
+ Blobs: make(map[plumbing.Hash]plumbing.EncodedObject),
+ Tags: make(map[plumbing.Hash]plumbing.EncodedObject),
+ },
+ ModuleStorage: make(ModuleStorage),
+ }
+}
+
+type ConfigStorage struct {
+ config *config.Config
+}
+
+func (c *ConfigStorage) SetConfig(cfg *config.Config) error {
+ if err := cfg.Validate(); err != nil {
+ return err
+ }
+
+ c.config = cfg
+ return nil
+}
+
+func (c *ConfigStorage) Config() (*config.Config, error) {
+ if c.config == nil {
+ c.config = config.NewConfig()
+ }
+
+ return c.config, nil
+}
+
+type IndexStorage struct {
+ index *index.Index
+}
+
+func (c *IndexStorage) SetIndex(idx *index.Index) error {
+ c.index = idx
+ return nil
+}
+
+func (c *IndexStorage) Index() (*index.Index, error) {
+ if c.index == nil {
+ c.index = &index.Index{Version: 2}
+ }
+
+ return c.index, nil
+}
+
+type ObjectStorage struct {
+ Objects map[plumbing.Hash]plumbing.EncodedObject
+ Commits map[plumbing.Hash]plumbing.EncodedObject
+ Trees map[plumbing.Hash]plumbing.EncodedObject
+ Blobs map[plumbing.Hash]plumbing.EncodedObject
+ Tags map[plumbing.Hash]plumbing.EncodedObject
+}
+
+func (o *ObjectStorage) NewEncodedObject() plumbing.EncodedObject {
+ return &plumbing.MemoryObject{}
+}
+
+func (o *ObjectStorage) SetEncodedObject(obj plumbing.EncodedObject) (plumbing.Hash, error) {
+ h := obj.Hash()
+ o.Objects[h] = obj
+
+ switch obj.Type() {
+ case plumbing.CommitObject:
+ o.Commits[h] = o.Objects[h]
+ case plumbing.TreeObject:
+ o.Trees[h] = o.Objects[h]
+ case plumbing.BlobObject:
+ o.Blobs[h] = o.Objects[h]
+ case plumbing.TagObject:
+ o.Tags[h] = o.Objects[h]
+ default:
+ return h, ErrUnsupportedObjectType
+ }
+
+ return h, nil
+}
+
+func (o *ObjectStorage) HasEncodedObject(h plumbing.Hash) (err error) {
+ if _, ok := o.Objects[h]; !ok {
+ return plumbing.ErrObjectNotFound
+ }
+ return nil
+}
+
+func (o *ObjectStorage) EncodedObjectSize(h plumbing.Hash) (
+ size int64, err error) {
+ obj, ok := o.Objects[h]
+ if !ok {
+ return 0, plumbing.ErrObjectNotFound
+ }
+
+ return obj.Size(), nil
+}
+
+func (o *ObjectStorage) EncodedObject(t plumbing.ObjectType, h plumbing.Hash) (plumbing.EncodedObject, error) {
+ obj, ok := o.Objects[h]
+ if !ok || (plumbing.AnyObject != t && obj.Type() != t) {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ return obj, nil
+}
+
+func (o *ObjectStorage) IterEncodedObjects(t plumbing.ObjectType) (storer.EncodedObjectIter, error) {
+ var series []plumbing.EncodedObject
+ switch t {
+ case plumbing.AnyObject:
+ series = flattenObjectMap(o.Objects)
+ case plumbing.CommitObject:
+ series = flattenObjectMap(o.Commits)
+ case plumbing.TreeObject:
+ series = flattenObjectMap(o.Trees)
+ case plumbing.BlobObject:
+ series = flattenObjectMap(o.Blobs)
+ case plumbing.TagObject:
+ series = flattenObjectMap(o.Tags)
+ }
+
+ return storer.NewEncodedObjectSliceIter(series), nil
+}
+
+func flattenObjectMap(m map[plumbing.Hash]plumbing.EncodedObject) []plumbing.EncodedObject {
+ objects := make([]plumbing.EncodedObject, 0, len(m))
+ for _, obj := range m {
+ objects = append(objects, obj)
+ }
+ return objects
+}
+
+func (o *ObjectStorage) Begin() storer.Transaction {
+ return &TxObjectStorage{
+ Storage: o,
+ Objects: make(map[plumbing.Hash]plumbing.EncodedObject),
+ }
+}
+
+func (o *ObjectStorage) ForEachObjectHash(fun func(plumbing.Hash) error) error {
+ for h := range o.Objects {
+ err := fun(h)
+ if err != nil {
+ if err == storer.ErrStop {
+ return nil
+ }
+ return err
+ }
+ }
+ return nil
+}
+
+func (o *ObjectStorage) ObjectPacks() ([]plumbing.Hash, error) {
+ return nil, nil
+}
+func (o *ObjectStorage) DeleteOldObjectPackAndIndex(plumbing.Hash, time.Time) error {
+ return nil
+}
+
+var errNotSupported = fmt.Errorf("Not supported")
+
+func (s *ObjectStorage) LooseObjectTime(hash plumbing.Hash) (time.Time, error) {
+ return time.Time{}, errNotSupported
+}
+func (s *ObjectStorage) DeleteLooseObject(plumbing.Hash) error {
+ return errNotSupported
+}
+
+type TxObjectStorage struct {
+ Storage *ObjectStorage
+ Objects map[plumbing.Hash]plumbing.EncodedObject
+}
+
+func (tx *TxObjectStorage) SetEncodedObject(obj plumbing.EncodedObject) (plumbing.Hash, error) {
+ h := obj.Hash()
+ tx.Objects[h] = obj
+
+ return h, nil
+}
+
+func (tx *TxObjectStorage) EncodedObject(t plumbing.ObjectType, h plumbing.Hash) (plumbing.EncodedObject, error) {
+ obj, ok := tx.Objects[h]
+ if !ok || (plumbing.AnyObject != t && obj.Type() != t) {
+ return nil, plumbing.ErrObjectNotFound
+ }
+
+ return obj, nil
+}
+
+func (tx *TxObjectStorage) Commit() error {
+ for h, obj := range tx.Objects {
+ delete(tx.Objects, h)
+ if _, err := tx.Storage.SetEncodedObject(obj); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (tx *TxObjectStorage) Rollback() error {
+ tx.Objects = make(map[plumbing.Hash]plumbing.EncodedObject)
+ return nil
+}
+
+type ReferenceStorage map[plumbing.ReferenceName]*plumbing.Reference
+
+func (r ReferenceStorage) SetReference(ref *plumbing.Reference) error {
+ if ref != nil {
+ r[ref.Name()] = ref
+ }
+
+ return nil
+}
+
+func (r ReferenceStorage) CheckAndSetReference(ref, old *plumbing.Reference) error {
+ if ref == nil {
+ return nil
+ }
+
+ if old != nil {
+ tmp := r[ref.Name()]
+ if tmp != nil && tmp.Hash() != old.Hash() {
+ return ErrRefHasChanged
+ }
+ }
+ r[ref.Name()] = ref
+ return nil
+}
+
+func (r ReferenceStorage) Reference(n plumbing.ReferenceName) (*plumbing.Reference, error) {
+ ref, ok := r[n]
+ if !ok {
+ return nil, plumbing.ErrReferenceNotFound
+ }
+
+ return ref, nil
+}
+
+func (r ReferenceStorage) IterReferences() (storer.ReferenceIter, error) {
+ var refs []*plumbing.Reference
+ for _, ref := range r {
+ refs = append(refs, ref)
+ }
+
+ return storer.NewReferenceSliceIter(refs), nil
+}
+
+func (r ReferenceStorage) CountLooseRefs() (int, error) {
+ return len(r), nil
+}
+
+func (r ReferenceStorage) PackRefs() error {
+ return nil
+}
+
+func (r ReferenceStorage) RemoveReference(n plumbing.ReferenceName) error {
+ delete(r, n)
+ return nil
+}
+
+type ShallowStorage []plumbing.Hash
+
+func (s *ShallowStorage) SetShallow(commits []plumbing.Hash) error {
+ *s = commits
+ return nil
+}
+
+func (s ShallowStorage) Shallow() ([]plumbing.Hash, error) {
+ return s, nil
+}
+
+type ModuleStorage map[string]*Storage
+
+func (s ModuleStorage) Module(name string) (storage.Storer, error) {
+ if m, ok := s[name]; ok {
+ return m, nil
+ }
+
+ m := NewStorage()
+ s[name] = m
+
+ return m, nil
+}
--- /dev/null
+package storage
+
+import (
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+)
+
+// Storer is a generic storage of objects, references and any information
+// related to a particular repository. The package gopkg.in/src-d/go-git.v4/storage
+// contains two implementation a filesystem base implementation (such as `.git`)
+// and a memory implementations being ephemeral
+type Storer interface {
+ storer.EncodedObjectStorer
+ storer.ReferenceStorer
+ storer.ShallowStorer
+ storer.IndexStorer
+ config.ConfigStorer
+ ModuleStorer
+}
+
+// ModuleStorer allows interact with the modules' Storers
+type ModuleStorer interface {
+ // Module returns a Storer representing a submodule, if not exists returns a
+ // new empty Storer is returned
+ Module(name string) (Storer, error)
+}
--- /dev/null
+package git
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+)
+
+var (
+ ErrSubmoduleAlreadyInitialized = errors.New("submodule already initialized")
+ ErrSubmoduleNotInitialized = errors.New("submodule not initialized")
+)
+
+// Submodule a submodule allows you to keep another Git repository in a
+// subdirectory of your repository.
+type Submodule struct {
+ // initialized defines if a submodule was already initialized.
+ initialized bool
+
+ c *config.Submodule
+ w *Worktree
+}
+
+// Config returns the submodule config
+func (s *Submodule) Config() *config.Submodule {
+ return s.c
+}
+
+// Init initialize the submodule reading the recorded Entry in the index for
+// the given submodule
+func (s *Submodule) Init() error {
+ cfg, err := s.w.r.Storer.Config()
+ if err != nil {
+ return err
+ }
+
+ _, ok := cfg.Submodules[s.c.Name]
+ if ok {
+ return ErrSubmoduleAlreadyInitialized
+ }
+
+ s.initialized = true
+
+ cfg.Submodules[s.c.Name] = s.c
+ return s.w.r.Storer.SetConfig(cfg)
+}
+
+// Status returns the status of the submodule.
+func (s *Submodule) Status() (*SubmoduleStatus, error) {
+ idx, err := s.w.r.Storer.Index()
+ if err != nil {
+ return nil, err
+ }
+
+ return s.status(idx)
+}
+
+func (s *Submodule) status(idx *index.Index) (*SubmoduleStatus, error) {
+ status := &SubmoduleStatus{
+ Path: s.c.Path,
+ }
+
+ e, err := idx.Entry(s.c.Path)
+ if err != nil && err != index.ErrEntryNotFound {
+ return nil, err
+ }
+
+ if e != nil {
+ status.Expected = e.Hash
+ }
+
+ if !s.initialized {
+ return status, nil
+ }
+
+ r, err := s.Repository()
+ if err != nil {
+ return nil, err
+ }
+
+ head, err := r.Head()
+ if err == nil {
+ status.Current = head.Hash()
+ }
+
+ if err != nil && err == plumbing.ErrReferenceNotFound {
+ err = nil
+ }
+
+ return status, err
+}
+
+// Repository returns the Repository represented by this submodule
+func (s *Submodule) Repository() (*Repository, error) {
+ if !s.initialized {
+ return nil, ErrSubmoduleNotInitialized
+ }
+
+ storer, err := s.w.r.Storer.Module(s.c.Name)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = storer.Reference(plumbing.HEAD)
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return nil, err
+ }
+
+ var exists bool
+ if err == nil {
+ exists = true
+ }
+
+ var worktree billy.Filesystem
+ if worktree, err = s.w.Filesystem.Chroot(s.c.Path); err != nil {
+ return nil, err
+ }
+
+ if exists {
+ return Open(storer, worktree)
+ }
+
+ r, err := Init(storer, worktree)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = r.CreateRemote(&config.RemoteConfig{
+ Name: DefaultRemoteName,
+ URLs: []string{s.c.URL},
+ })
+
+ return r, err
+}
+
+// Update the registered submodule to match what the superproject expects, the
+// submodule should be initialized first calling the Init method or setting in
+// the options SubmoduleUpdateOptions.Init equals true
+func (s *Submodule) Update(o *SubmoduleUpdateOptions) error {
+ return s.UpdateContext(context.Background(), o)
+}
+
+// UpdateContext the registered submodule to match what the superproject
+// expects, the submodule should be initialized first calling the Init method or
+// setting in the options SubmoduleUpdateOptions.Init equals true.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (s *Submodule) UpdateContext(ctx context.Context, o *SubmoduleUpdateOptions) error {
+ return s.update(ctx, o, plumbing.ZeroHash)
+}
+
+func (s *Submodule) update(ctx context.Context, o *SubmoduleUpdateOptions, forceHash plumbing.Hash) error {
+ if !s.initialized && !o.Init {
+ return ErrSubmoduleNotInitialized
+ }
+
+ if !s.initialized && o.Init {
+ if err := s.Init(); err != nil {
+ return err
+ }
+ }
+
+ idx, err := s.w.r.Storer.Index()
+ if err != nil {
+ return err
+ }
+
+ hash := forceHash
+ if hash.IsZero() {
+ e, err := idx.Entry(s.c.Path)
+ if err != nil {
+ return err
+ }
+
+ hash = e.Hash
+ }
+
+ r, err := s.Repository()
+ if err != nil {
+ return err
+ }
+
+ if err := s.fetchAndCheckout(ctx, r, o, hash); err != nil {
+ return err
+ }
+
+ return s.doRecursiveUpdate(r, o)
+}
+
+func (s *Submodule) doRecursiveUpdate(r *Repository, o *SubmoduleUpdateOptions) error {
+ if o.RecurseSubmodules == NoRecurseSubmodules {
+ return nil
+ }
+
+ w, err := r.Worktree()
+ if err != nil {
+ return err
+ }
+
+ l, err := w.Submodules()
+ if err != nil {
+ return err
+ }
+
+ new := &SubmoduleUpdateOptions{}
+ *new = *o
+
+ new.RecurseSubmodules--
+ return l.Update(new)
+}
+
+func (s *Submodule) fetchAndCheckout(
+ ctx context.Context, r *Repository, o *SubmoduleUpdateOptions, hash plumbing.Hash,
+) error {
+ if !o.NoFetch {
+ err := r.FetchContext(ctx, &FetchOptions{Auth: o.Auth})
+ if err != nil && err != NoErrAlreadyUpToDate {
+ return err
+ }
+ }
+
+ w, err := r.Worktree()
+ if err != nil {
+ return err
+ }
+
+ if err := w.Checkout(&CheckoutOptions{Hash: hash}); err != nil {
+ return err
+ }
+
+ head := plumbing.NewHashReference(plumbing.HEAD, hash)
+ return r.Storer.SetReference(head)
+}
+
+// Submodules list of several submodules from the same repository.
+type Submodules []*Submodule
+
+// Init initializes the submodules in this list.
+func (s Submodules) Init() error {
+ for _, sub := range s {
+ if err := sub.Init(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Update updates all the submodules in this list.
+func (s Submodules) Update(o *SubmoduleUpdateOptions) error {
+ return s.UpdateContext(context.Background(), o)
+}
+
+// UpdateContext updates all the submodules in this list.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (s Submodules) UpdateContext(ctx context.Context, o *SubmoduleUpdateOptions) error {
+ for _, sub := range s {
+ if err := sub.UpdateContext(ctx, o); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Status returns the status of the submodules.
+func (s Submodules) Status() (SubmodulesStatus, error) {
+ var list SubmodulesStatus
+
+ var r *Repository
+ for _, sub := range s {
+ if r == nil {
+ r = sub.w.r
+ }
+
+ idx, err := r.Storer.Index()
+ if err != nil {
+ return nil, err
+ }
+
+ status, err := sub.status(idx)
+ if err != nil {
+ return nil, err
+ }
+
+ list = append(list, status)
+ }
+
+ return list, nil
+}
+
+// SubmodulesStatus contains the status for all submodiles in the worktree
+type SubmodulesStatus []*SubmoduleStatus
+
+// String is equivalent to `git submodule status`
+func (s SubmodulesStatus) String() string {
+ buf := bytes.NewBuffer(nil)
+ for _, sub := range s {
+ fmt.Fprintln(buf, sub)
+ }
+
+ return buf.String()
+}
+
+// SubmoduleStatus contains the status for a submodule in the worktree
+type SubmoduleStatus struct {
+ Path string
+ Current plumbing.Hash
+ Expected plumbing.Hash
+ Branch plumbing.ReferenceName
+}
+
+// IsClean is the HEAD of the submodule is equals to the expected commit
+func (s *SubmoduleStatus) IsClean() bool {
+ return s.Current == s.Expected
+}
+
+// String is equivalent to `git submodule status <submodule>`
+//
+// This will print the SHA-1 of the currently checked out commit for a
+// submodule, along with the submodule path and the output of git describe fo
+// the SHA-1. Each SHA-1 will be prefixed with - if the submodule is not
+// initialized, + if the currently checked out submodule commit does not match
+// the SHA-1 found in the index of the containing repository.
+func (s *SubmoduleStatus) String() string {
+ var extra string
+ var status = ' '
+
+ if s.Current.IsZero() {
+ status = '-'
+ } else if !s.IsClean() {
+ status = '+'
+ }
+
+ if len(s.Branch) != 0 {
+ extra = string(s.Branch[5:])
+ } else if !s.Current.IsZero() {
+ extra = s.Current.String()[:7]
+ }
+
+ if extra != "" {
+ extra = fmt.Sprintf(" (%s)", extra)
+ }
+
+ return fmt.Sprintf("%c%s %s%s", status, s.Expected, s.Path, extra)
+}
--- /dev/null
+// Package binary implements sintax-sugar functions on top of the standard
+// library binary package
+package binary
+
+import (
+ "bufio"
+ "encoding/binary"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+)
+
+// Read reads structured binary data from r into data. Bytes are read and
+// decoded in BigEndian order
+// https://golang.org/pkg/encoding/binary/#Read
+func Read(r io.Reader, data ...interface{}) error {
+ for _, v := range data {
+ if err := binary.Read(r, binary.BigEndian, v); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// ReadUntil reads from r untin delim is found
+func ReadUntil(r io.Reader, delim byte) ([]byte, error) {
+ var buf [1]byte
+ value := make([]byte, 0, 16)
+ for {
+ if _, err := io.ReadFull(r, buf[:]); err != nil {
+ if err == io.EOF {
+ return nil, err
+ }
+
+ return nil, err
+ }
+
+ if buf[0] == delim {
+ return value, nil
+ }
+
+ value = append(value, buf[0])
+ }
+}
+
+// ReadVariableWidthInt reads and returns an int in Git VLQ special format:
+//
+// Ordinary VLQ has some redundancies, example: the number 358 can be
+// encoded as the 2-octet VLQ 0x8166 or the 3-octet VLQ 0x808166 or the
+// 4-octet VLQ 0x80808166 and so forth.
+//
+// To avoid these redundancies, the VLQ format used in Git removes this
+// prepending redundancy and extends the representable range of shorter
+// VLQs by adding an offset to VLQs of 2 or more octets in such a way
+// that the lowest possible value for such an (N+1)-octet VLQ becomes
+// exactly one more than the maximum possible value for an N-octet VLQ.
+// In particular, since a 1-octet VLQ can store a maximum value of 127,
+// the minimum 2-octet VLQ (0x8000) is assigned the value 128 instead of
+// 0. Conversely, the maximum value of such a 2-octet VLQ (0xff7f) is
+// 16511 instead of just 16383. Similarly, the minimum 3-octet VLQ
+// (0x808000) has a value of 16512 instead of zero, which means
+// that the maximum 3-octet VLQ (0xffff7f) is 2113663 instead of
+// just 2097151. And so forth.
+//
+// This is how the offset is saved in C:
+//
+// dheader[pos] = ofs & 127;
+// while (ofs >>= 7)
+// dheader[--pos] = 128 | (--ofs & 127);
+//
+func ReadVariableWidthInt(r io.Reader) (int64, error) {
+ var c byte
+ if err := Read(r, &c); err != nil {
+ return 0, err
+ }
+
+ var v = int64(c & maskLength)
+ for c&maskContinue > 0 {
+ v++
+ if err := Read(r, &c); err != nil {
+ return 0, err
+ }
+
+ v = (v << lengthBits) + int64(c&maskLength)
+ }
+
+ return v, nil
+}
+
+const (
+ maskContinue = uint8(128) // 1000 000
+ maskLength = uint8(127) // 0111 1111
+ lengthBits = uint8(7) // subsequent bytes has 7 bits to store the length
+)
+
+// ReadUint64 reads 8 bytes and returns them as a BigEndian uint32
+func ReadUint64(r io.Reader) (uint64, error) {
+ var v uint64
+ if err := binary.Read(r, binary.BigEndian, &v); err != nil {
+ return 0, err
+ }
+
+ return v, nil
+}
+
+// ReadUint32 reads 4 bytes and returns them as a BigEndian uint32
+func ReadUint32(r io.Reader) (uint32, error) {
+ var v uint32
+ if err := binary.Read(r, binary.BigEndian, &v); err != nil {
+ return 0, err
+ }
+
+ return v, nil
+}
+
+// ReadUint16 reads 2 bytes and returns them as a BigEndian uint16
+func ReadUint16(r io.Reader) (uint16, error) {
+ var v uint16
+ if err := binary.Read(r, binary.BigEndian, &v); err != nil {
+ return 0, err
+ }
+
+ return v, nil
+}
+
+// ReadHash reads a plumbing.Hash from r
+func ReadHash(r io.Reader) (plumbing.Hash, error) {
+ var h plumbing.Hash
+ if err := binary.Read(r, binary.BigEndian, h[:]); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return h, nil
+}
+
+const sniffLen = 8000
+
+// IsBinary detects if data is a binary value based on:
+// http://git.kernel.org/cgit/git/git.git/tree/xdiff-interface.c?id=HEAD#n198
+func IsBinary(r io.Reader) (bool, error) {
+ reader := bufio.NewReader(r)
+ c := 0
+ for {
+ if c == sniffLen {
+ break
+ }
+
+ b, err := reader.ReadByte()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return false, err
+ }
+
+ if b == byte(0) {
+ return true, nil
+ }
+
+ c++
+ }
+
+ return false, nil
+}
--- /dev/null
+package binary
+
+import (
+ "encoding/binary"
+ "io"
+)
+
+// Write writes the binary representation of data into w, using BigEndian order
+// https://golang.org/pkg/encoding/binary/#Write
+func Write(w io.Writer, data ...interface{}) error {
+ for _, v := range data {
+ if err := binary.Write(w, binary.BigEndian, v); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func WriteVariableWidthInt(w io.Writer, n int64) error {
+ buf := []byte{byte(n & 0x7f)}
+ n >>= 7
+ for n != 0 {
+ n--
+ buf = append([]byte{0x80 | (byte(n & 0x7f))}, buf...)
+ n >>= 7
+ }
+
+ _, err := w.Write(buf)
+
+ return err
+}
+
+// WriteUint64 writes the binary representation of a uint64 into w, in BigEndian
+// order
+func WriteUint64(w io.Writer, value uint64) error {
+ return binary.Write(w, binary.BigEndian, value)
+}
+
+// WriteUint32 writes the binary representation of a uint32 into w, in BigEndian
+// order
+func WriteUint32(w io.Writer, value uint32) error {
+ return binary.Write(w, binary.BigEndian, value)
+}
+
+// WriteUint16 writes the binary representation of a uint16 into w, in BigEndian
+// order
+func WriteUint16(w io.Writer, value uint16) error {
+ return binary.Write(w, binary.BigEndian, value)
+}
--- /dev/null
+// Package diff implements line oriented diffs, similar to the ancient
+// Unix diff command.
+//
+// The current implementation is just a wrapper around Sergi's
+// go-diff/diffmatchpatch library, which is a go port of Neil
+// Fraser's google-diff-match-patch code
+package diff
+
+import (
+ "bytes"
+
+ "github.com/sergi/go-diff/diffmatchpatch"
+)
+
+// Do computes the (line oriented) modifications needed to turn the src
+// string into the dst string.
+func Do(src, dst string) (diffs []diffmatchpatch.Diff) {
+ dmp := diffmatchpatch.New()
+ wSrc, wDst, warray := dmp.DiffLinesToRunes(src, dst)
+ diffs = dmp.DiffMainRunes(wSrc, wDst, false)
+ diffs = dmp.DiffCharsToLines(diffs, warray)
+ return diffs
+}
+
+// Dst computes and returns the destination text.
+func Dst(diffs []diffmatchpatch.Diff) string {
+ var text bytes.Buffer
+ for _, d := range diffs {
+ if d.Type != diffmatchpatch.DiffDelete {
+ text.WriteString(d.Text)
+ }
+ }
+ return text.String()
+}
+
+// Src computes and returns the source text
+func Src(diffs []diffmatchpatch.Diff) string {
+ var text bytes.Buffer
+ for _, d := range diffs {
+ if d.Type != diffmatchpatch.DiffInsert {
+ text.WriteString(d.Text)
+ }
+ }
+ return text.String()
+}
--- /dev/null
+// Package ioutil implements some I/O utility functions.
+package ioutil
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "io"
+
+ "github.com/jbenet/go-context/io"
+)
+
+type readPeeker interface {
+ io.Reader
+ Peek(int) ([]byte, error)
+}
+
+var (
+ ErrEmptyReader = errors.New("reader is empty")
+)
+
+// NonEmptyReader takes a reader and returns it if it is not empty, or
+// `ErrEmptyReader` if it is empty. If there is an error when reading the first
+// byte of the given reader, it will be propagated.
+func NonEmptyReader(r io.Reader) (io.Reader, error) {
+ pr, ok := r.(readPeeker)
+ if !ok {
+ pr = bufio.NewReader(r)
+ }
+
+ _, err := pr.Peek(1)
+ if err == io.EOF {
+ return nil, ErrEmptyReader
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return pr, nil
+}
+
+type readCloser struct {
+ io.Reader
+ closer io.Closer
+}
+
+func (r *readCloser) Close() error {
+ return r.closer.Close()
+}
+
+// NewReadCloser creates an `io.ReadCloser` with the given `io.Reader` and
+// `io.Closer`.
+func NewReadCloser(r io.Reader, c io.Closer) io.ReadCloser {
+ return &readCloser{Reader: r, closer: c}
+}
+
+type writeCloser struct {
+ io.Writer
+ closer io.Closer
+}
+
+func (r *writeCloser) Close() error {
+ return r.closer.Close()
+}
+
+// NewWriteCloser creates an `io.WriteCloser` with the given `io.Writer` and
+// `io.Closer`.
+func NewWriteCloser(w io.Writer, c io.Closer) io.WriteCloser {
+ return &writeCloser{Writer: w, closer: c}
+}
+
+type writeNopCloser struct {
+ io.Writer
+}
+
+func (writeNopCloser) Close() error { return nil }
+
+// WriteNopCloser returns a WriteCloser with a no-op Close method wrapping
+// the provided Writer w.
+func WriteNopCloser(w io.Writer) io.WriteCloser {
+ return writeNopCloser{w}
+}
+
+// CheckClose calls Close on the given io.Closer. If the given *error points to
+// nil, it will be assigned the error returned by Close. Otherwise, any error
+// returned by Close will be ignored. CheckClose is usually called with defer.
+func CheckClose(c io.Closer, err *error) {
+ if cerr := c.Close(); cerr != nil && *err == nil {
+ *err = cerr
+ }
+}
+
+// NewContextWriter wraps a writer to make it respect given Context.
+// If there is a blocking write, the returned Writer will return whenever the
+// context is cancelled (the return values are n=0 and err=ctx.Err()).
+func NewContextWriter(ctx context.Context, w io.Writer) io.Writer {
+ return ctxio.NewWriter(ctx, w)
+}
+
+// NewContextReader wraps a reader to make it respect given Context.
+// If there is a blocking read, the returned Reader will return whenever the
+// context is cancelled (the return values are n=0 and err=ctx.Err()).
+func NewContextReader(ctx context.Context, r io.Reader) io.Reader {
+ return ctxio.NewReader(ctx, r)
+}
+
+// NewContextWriteCloser as NewContextWriter but with io.Closer interface.
+func NewContextWriteCloser(ctx context.Context, w io.WriteCloser) io.WriteCloser {
+ ctxw := ctxio.NewWriter(ctx, w)
+ return NewWriteCloser(ctxw, w)
+}
+
+// NewContextReadCloser as NewContextReader but with io.Closer interface.
+func NewContextReadCloser(ctx context.Context, r io.ReadCloser) io.ReadCloser {
+ ctxr := ctxio.NewReader(ctx, r)
+ return NewReadCloser(ctxr, r)
+}
+
+type readerOnError struct {
+ io.Reader
+ notify func(error)
+}
+
+// NewReaderOnError returns a io.Reader that call the notify function when an
+// unexpected (!io.EOF) error happens, after call Read function.
+func NewReaderOnError(r io.Reader, notify func(error)) io.Reader {
+ return &readerOnError{r, notify}
+}
+
+// NewReadCloserOnError returns a io.ReadCloser that call the notify function
+// when an unexpected (!io.EOF) error happens, after call Read function.
+func NewReadCloserOnError(r io.ReadCloser, notify func(error)) io.ReadCloser {
+ return NewReadCloser(NewReaderOnError(r, notify), r)
+}
+
+func (r *readerOnError) Read(buf []byte) (n int, err error) {
+ n, err = r.Reader.Read(buf)
+ if err != nil && err != io.EOF {
+ r.notify(err)
+ }
+
+ return
+}
+
+type writerOnError struct {
+ io.Writer
+ notify func(error)
+}
+
+// NewWriterOnError returns a io.Writer that call the notify function when an
+// unexpected (!io.EOF) error happens, after call Write function.
+func NewWriterOnError(w io.Writer, notify func(error)) io.Writer {
+ return &writerOnError{w, notify}
+}
+
+// NewWriteCloserOnError returns a io.WriteCloser that call the notify function
+//when an unexpected (!io.EOF) error happens, after call Write function.
+func NewWriteCloserOnError(w io.WriteCloser, notify func(error)) io.WriteCloser {
+ return NewWriteCloser(NewWriterOnError(w, notify), w)
+}
+
+func (r *writerOnError) Write(p []byte) (n int, err error) {
+ n, err = r.Writer.Write(p)
+ if err != nil && err != io.EOF {
+ r.notify(err)
+ }
+
+ return
+}
--- /dev/null
+package merkletrie
+
+import (
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// Action values represent the kind of things a Change can represent:
+// insertion, deletions or modifications of files.
+type Action int
+
+// The set of possible actions in a change.
+const (
+ _ Action = iota
+ Insert
+ Delete
+ Modify
+)
+
+// String returns the action as a human readable text.
+func (a Action) String() string {
+ switch a {
+ case Insert:
+ return "Insert"
+ case Delete:
+ return "Delete"
+ case Modify:
+ return "Modify"
+ default:
+ panic(fmt.Sprintf("unsupported action: %d", a))
+ }
+}
+
+// A Change value represent how a noder has change between to merkletries.
+type Change struct {
+ // The noder before the change or nil if it was inserted.
+ From noder.Path
+ // The noder after the change or nil if it was deleted.
+ To noder.Path
+}
+
+// Action is convenience method that returns what Action c represents.
+func (c *Change) Action() (Action, error) {
+ if c.From == nil && c.To == nil {
+ return Action(0), fmt.Errorf("malformed change: nil from and to")
+ }
+ if c.From == nil {
+ return Insert, nil
+ }
+ if c.To == nil {
+ return Delete, nil
+ }
+
+ return Modify, nil
+}
+
+// NewInsert returns a new Change representing the insertion of n.
+func NewInsert(n noder.Path) Change { return Change{To: n} }
+
+// NewDelete returns a new Change representing the deletion of n.
+func NewDelete(n noder.Path) Change { return Change{From: n} }
+
+// NewModify returns a new Change representing that a has been modified and
+// it is now b.
+func NewModify(a, b noder.Path) Change {
+ return Change{
+ From: a,
+ To: b,
+ }
+}
+
+// String returns a single change in human readable form, using the
+// format: '<' + action + space + path + '>'. The contents of the file
+// before or after the change are not included in this format.
+//
+// Example: inserting a file at the path a/b/c.txt will return "<Insert
+// a/b/c.txt>".
+func (c Change) String() string {
+ action, err := c.Action()
+ if err != nil {
+ panic(err)
+ }
+
+ var path string
+ if action == Delete {
+ path = c.From.String()
+ } else {
+ path = c.To.String()
+ }
+
+ return fmt.Sprintf("<%s %s>", action, path)
+}
+
+// Changes is a list of changes between to merkletries.
+type Changes []Change
+
+// NewChanges returns an empty list of changes.
+func NewChanges() Changes {
+ return Changes{}
+}
+
+// Add adds the change c to the list of changes.
+func (l *Changes) Add(c Change) {
+ *l = append(*l, c)
+}
+
+// AddRecursiveInsert adds the required changes to insert all the
+// file-like noders found in root, recursively.
+func (l *Changes) AddRecursiveInsert(root noder.Path) error {
+ return l.addRecursive(root, NewInsert)
+}
+
+// AddRecursiveDelete adds the required changes to delete all the
+// file-like noders found in root, recursively.
+func (l *Changes) AddRecursiveDelete(root noder.Path) error {
+ return l.addRecursive(root, NewDelete)
+}
+
+type noderToChangeFn func(noder.Path) Change // NewInsert or NewDelete
+
+func (l *Changes) addRecursive(root noder.Path, ctor noderToChangeFn) error {
+ if !root.IsDir() {
+ l.Add(ctor(root))
+ return nil
+ }
+
+ i, err := NewIterFromPath(root)
+ if err != nil {
+ return err
+ }
+
+ var current noder.Path
+ for {
+ if current, err = i.Step(); err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+ if current.IsDir() {
+ continue
+ }
+ l.Add(ctor(current))
+ }
+
+ return nil
+}
--- /dev/null
+package merkletrie
+
+// The focus of this difftree implementation is to save time by
+// skipping whole directories if their hash is the same in both
+// trees.
+//
+// The diff algorithm implemented here is based on the doubleiter
+// type defined in this same package; we will iterate over both
+// trees at the same time, while comparing the current noders in
+// each iterator. Depending on how they differ we will output the
+// corresponding chages and move the iterators further over both
+// trees.
+//
+// The table bellow show all the possible comparison results, along
+// with what changes should we produce and how to advance the
+// iterators.
+//
+// The table is implemented by the switches in this function,
+// diffTwoNodes, diffTwoNodesSameName and diffTwoDirs.
+//
+// Many Bothans died to bring us this information, make sure you
+// understand the table before modifying this code.
+
+// # Cases
+//
+// When comparing noders in both trees you will found yourself in
+// one of 169 possible cases, but if we ignore moves, we can
+// simplify a lot the search space into the following table:
+//
+// - "-": nothing, no file or directory
+// - a<>: an empty file named "a".
+// - a<1>: a file named "a", with "1" as its contents.
+// - a<2>: a file named "a", with "2" as its contents.
+// - a(): an empty dir named "a".
+// - a(...): a dir named "a", with some files and/or dirs inside (possibly
+// empty).
+// - a(;;;): a dir named "a", with some other files and/or dirs inside
+// (possibly empty), which different from the ones in "a(...)".
+//
+// \ to - a<> a<1> a<2> a() a(...) a(;;;)
+// from \
+// - 00 01 02 03 04 05 06
+// a<> 10 11 12 13 14 15 16
+// a<1> 20 21 22 23 24 25 26
+// a<2> 30 31 32 33 34 35 36
+// a() 40 41 42 43 44 45 46
+// a(...) 50 51 52 53 54 55 56
+// a(;;;) 60 61 62 63 64 65 66
+//
+// Every (from, to) combination in the table is a special case, but
+// some of them can be merged into some more general cases, for
+// instance 11 and 22 can be merged into the general case: both
+// noders are equal.
+//
+// Here is a full list of all the cases that are similar and how to
+// merge them together into more general cases. Each general case
+// is labeled with an uppercase letter for further reference, and it
+// is followed by the pseudocode of the checks you have to perfrom
+// on both noders to see if you are in such a case, the actions to
+// perform (i.e. what changes to output) and how to advance the
+// iterators of each tree to continue the comparison process.
+//
+// ## A. Impossible: 00
+//
+// ## B. Same thing on both sides: 11, 22, 33, 44, 55, 66
+// - check: `SameName() && SameHash()`
+// - action: do nothing.
+// - advance: `FromNext(); ToNext()`
+//
+// ### C. To was created: 01, 02, 03, 04, 05, 06
+// - check: `DifferentName() && ToBeforeFrom()`
+// - action: inserRecursively(to)
+// - advance: `ToNext()`
+//
+// ### D. From was deleted: 10, 20, 30, 40, 50, 60
+// - check: `DifferentName() && FromBeforeTo()`
+// - action: `DeleteRecursively(from)`
+// - advance: `FromNext()`
+//
+// ### E. Empty file to file with contents: 12, 13
+// - check: `SameName() && DifferentHash() && FromIsFile() &&
+// ToIsFile() && FromIsEmpty()`
+// - action: `modifyFile(from, to)`
+// - advance: `FromNext()` or `FromStep()`
+//
+// ### E'. file with contents to empty file: 21, 31
+// - check: `SameName() && DifferentHash() && FromIsFile() &&
+// ToIsFile() && ToIsEmpty()`
+// - action: `modifyFile(from, to)`
+// - advance: `FromNext()` or `FromStep()`
+//
+// ### F. empty file to empty dir with the same name: 14
+// - check: `SameName() && FromIsFile() && FromIsEmpty() &&
+// ToIsDir() && ToIsEmpty()`
+// - action: `DeleteFile(from); InsertEmptyDir(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### F'. empty dir to empty file of the same name: 41
+// - check: `SameName() && FromIsDir() && FromIsEmpty &&
+// ToIsFile() && ToIsEmpty()`
+// - action: `DeleteEmptyDir(from); InsertFile(to)`
+// - advance: `FromNext(); ToNext()` or step for any of them.
+//
+// ### G. empty file to non-empty dir of the same name: 15, 16
+// - check: `SameName() && FromIsFile() && ToIsDir() &&
+// FromIsEmpty() && ToIsNotEmpty()`
+// - action: `DeleteFile(from); InsertDirRecursively(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### G'. non-empty dir to empty file of the same name: 51, 61
+// - check: `SameName() && FromIsDir() && FromIsNotEmpty() &&
+// ToIsFile() && FromIsEmpty()`
+// - action: `DeleteDirRecursively(from); InsertFile(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### H. modify file contents: 23, 32
+// - check: `SameName() && FromIsFile() && ToIsFile() &&
+// FromIsNotEmpty() && ToIsNotEmpty()`
+// - action: `ModifyFile(from, to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### I. file with contents to empty dir: 24, 34
+// - check: `SameName() && DifferentHash() && FromIsFile() &&
+// FromIsNotEmpty() && ToIsDir() && ToIsEmpty()`
+// - action: `DeleteFile(from); InsertEmptyDir(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### I'. empty dir to file with contents: 42, 43
+// - check: `SameName() && DifferentHash() && FromIsDir() &&
+// FromIsEmpty() && ToIsFile() && ToIsEmpty()`
+// - action: `DeleteDir(from); InsertFile(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### J. file with contents to dir with contetns: 25, 26, 35, 36
+// - check: `SameName() && DifferentHash() && FromIsFile() &&
+// FromIsNotEmpty() && ToIsDir() && ToIsNotEmpty()`
+// - action: `DeleteFile(from); InsertDirRecursively(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### J'. dir with contetns to file with contents: 52, 62, 53, 63
+// - check: `SameName() && DifferentHash() && FromIsDir() &&
+// FromIsNotEmpty() && ToIsFile() && ToIsNotEmpty()`
+// - action: `DeleteDirRecursively(from); InsertFile(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### K. empty dir to dir with contents: 45, 46
+// - check: `SameName() && DifferentHash() && FromIsDir() &&
+// FromIsEmpty() && ToIsDir() && ToIsNotEmpty()`
+// - action: `InsertChildrenRecursively(to)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### K'. dir with contents to empty dir: 54, 64
+// - check: `SameName() && DifferentHash() && FromIsDir() &&
+// FromIsEmpty() && ToIsDir() && ToIsNotEmpty()`
+// - action: `DeleteChildrenRecursively(from)`
+// - advance: `FromNext(); ToNext()`
+//
+// ### L. dir with contents to dir with different contents: 56, 65
+// - check: `SameName() && DifferentHash() && FromIsDir() &&
+// FromIsNotEmpty() && ToIsDir() && ToIsNotEmpty()`
+// - action: nothing
+// - advance: `FromStep(); ToStep()`
+//
+//
+
+// All these cases can be further simplified by a truth table
+// reduction process, in which we gather similar checks together to
+// make the final code easier to read and understand.
+//
+// The first 6 columns are the outputs of the checks to perform on
+// both noders. I have labeled them 1 to 6, this is what they mean:
+//
+// 1: SameName()
+// 2: SameHash()
+// 3: FromIsDir()
+// 4: ToIsDir()
+// 5: FromIsEmpty()
+// 6: ToIsEmpty()
+//
+// The from and to columns are a fsnoder example of the elements
+// that you will find on each tree under the specified comparison
+// results (columns 1 to 6).
+//
+// The type column identifies the case we are into, from the list above.
+//
+// The type' column identifies the new set of reduced cases, using
+// lowercase letters, and they are explained after the table.
+//
+// The last column is the set of actions and advances for each case.
+//
+// "---" means impossible except in case of hash collision.
+//
+// advance meaning:
+// - NN: from.Next(); to.Next()
+// - SS: from.Step(); to.Step()
+//
+// 1 2 3 4 5 6 | from | to |type|type'|action ; advance
+// ------------+--------+--------+----+------------------------------------
+// 0 0 0 0 0 0 | | | | | if !SameName() {
+// . | | | | | if FromBeforeTo() {
+// . | | | D | d | delete(from); from.Next()
+// . | | | | | } else {
+// . | | | C | c | insert(to); to.Next()
+// . | | | | | }
+// 0 1 1 1 1 1 | | | | | }
+// 1 0 0 0 0 0 | a<1> | a<2> | H | e | modify(from, to); NN
+// 1 0 0 0 0 1 | a<1> | a<> | E' | e | modify(from, to); NN
+// 1 0 0 0 1 0 | a<> | a<1> | E | e | modify(from, to); NN
+// 1 0 0 0 1 1 | ---- | ---- | | e |
+// 1 0 0 1 0 0 | a<1> | a(...) | J | f | delete(from); insert(to); NN
+// 1 0 0 1 0 1 | a<1> | a() | I | f | delete(from); insert(to); NN
+// 1 0 0 1 1 0 | a<> | a(...) | G | f | delete(from); insert(to); NN
+// 1 0 0 1 1 1 | a<> | a() | F | f | delete(from); insert(to); NN
+// 1 0 1 0 0 0 | a(...) | a<1> | J' | f | delete(from); insert(to); NN
+// 1 0 1 0 0 1 | a(...) | a<> | G' | f | delete(from); insert(to); NN
+// 1 0 1 0 1 0 | a() | a<1> | I' | f | delete(from); insert(to); NN
+// 1 0 1 0 1 1 | a() | a<> | F' | f | delete(from); insert(to); NN
+// 1 0 1 1 0 0 | a(...) | a(;;;) | L | g | nothing; SS
+// 1 0 1 1 0 1 | a(...) | a() | K' | h | deleteChidren(from); NN
+// 1 0 1 1 1 0 | a() | a(...) | K | i | insertChildren(to); NN
+// 1 0 1 1 1 1 | ---- | ---- | | |
+// 1 1 0 0 0 0 | a<1> | a<1> | B | b | nothing; NN
+// 1 1 0 0 0 1 | ---- | ---- | | b |
+// 1 1 0 0 1 0 | ---- | ---- | | b |
+// 1 1 0 0 1 1 | a<> | a<> | B | b | nothing; NN
+// 1 1 0 1 0 0 | ---- | ---- | | b |
+// 1 1 0 1 0 1 | ---- | ---- | | b |
+// 1 1 0 1 1 0 | ---- | ---- | | b |
+// 1 1 0 1 1 1 | ---- | ---- | | b |
+// 1 1 1 0 0 0 | ---- | ---- | | b |
+// 1 1 1 0 0 1 | ---- | ---- | | b |
+// 1 1 1 0 1 0 | ---- | ---- | | b |
+// 1 1 1 0 1 1 | ---- | ---- | | b |
+// 1 1 1 1 0 0 | a(...) | a(...) | B | b | nothing; NN
+// 1 1 1 1 0 1 | ---- | ---- | | b |
+// 1 1 1 1 1 0 | ---- | ---- | | b |
+// 1 1 1 1 1 1 | a() | a() | B | b | nothing; NN
+//
+// c and d:
+// if !SameName()
+// d if FromBeforeTo()
+// c else
+// b: SameName) && sameHash()
+// e: SameName() && !sameHash() && BothAreFiles()
+// f: SameName() && !sameHash() && FileAndDir()
+// g: SameName() && !sameHash() && BothAreDirs() && NoneIsEmpty
+// i: SameName() && !sameHash() && BothAreDirs() && FromIsEmpty
+// h: else of i
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+var (
+ ErrCanceled = errors.New("operation canceled")
+)
+
+// DiffTree calculates the list of changes between two merkletries. It
+// uses the provided hashEqual callback to compare noders.
+func DiffTree(fromTree, toTree noder.Noder,
+ hashEqual noder.Equal) (Changes, error) {
+ return DiffTreeContext(context.Background(), fromTree, toTree, hashEqual)
+}
+
+// DiffTree calculates the list of changes between two merkletries. It
+// uses the provided hashEqual callback to compare noders.
+// Error will be returned if context expires
+// Provided context must be non nil
+func DiffTreeContext(ctx context.Context, fromTree, toTree noder.Noder,
+ hashEqual noder.Equal) (Changes, error) {
+ ret := NewChanges()
+
+ ii, err := newDoubleIter(fromTree, toTree, hashEqual)
+ if err != nil {
+ return nil, err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, ErrCanceled
+ default:
+ }
+
+ from := ii.from.current
+ to := ii.to.current
+
+ switch r := ii.remaining(); r {
+ case noMoreNoders:
+ return ret, nil
+ case onlyFromRemains:
+ if err = ret.AddRecursiveDelete(from); err != nil {
+ return nil, err
+ }
+ if err = ii.nextFrom(); err != nil {
+ return nil, err
+ }
+ case onlyToRemains:
+ if err = ret.AddRecursiveInsert(to); err != nil {
+ return nil, err
+ }
+ if err = ii.nextTo(); err != nil {
+ return nil, err
+ }
+ case bothHaveNodes:
+ if err = diffNodes(&ret, ii); err != nil {
+ return nil, err
+ }
+ default:
+ panic(fmt.Sprintf("unknown remaining value: %d", r))
+ }
+ }
+}
+
+func diffNodes(changes *Changes, ii *doubleIter) error {
+ from := ii.from.current
+ to := ii.to.current
+ var err error
+
+ // compare their full paths as strings
+ switch from.Compare(to) {
+ case -1:
+ if err = changes.AddRecursiveDelete(from); err != nil {
+ return err
+ }
+ if err = ii.nextFrom(); err != nil {
+ return err
+ }
+ case 1:
+ if err = changes.AddRecursiveInsert(to); err != nil {
+ return err
+ }
+ if err = ii.nextTo(); err != nil {
+ return err
+ }
+ default:
+ if err := diffNodesSameName(changes, ii); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func diffNodesSameName(changes *Changes, ii *doubleIter) error {
+ from := ii.from.current
+ to := ii.to.current
+
+ status, err := ii.compare()
+ if err != nil {
+ return err
+ }
+
+ switch {
+ case status.sameHash:
+ // do nothing
+ if err = ii.nextBoth(); err != nil {
+ return err
+ }
+ case status.bothAreFiles:
+ changes.Add(NewModify(from, to))
+ if err = ii.nextBoth(); err != nil {
+ return err
+ }
+ case status.fileAndDir:
+ if err = changes.AddRecursiveDelete(from); err != nil {
+ return err
+ }
+ if err = changes.AddRecursiveInsert(to); err != nil {
+ return err
+ }
+ if err = ii.nextBoth(); err != nil {
+ return err
+ }
+ case status.bothAreDirs:
+ if err = diffDirs(changes, ii); err != nil {
+ return err
+ }
+ default:
+ return fmt.Errorf("bad status from double iterator")
+ }
+
+ return nil
+}
+
+func diffDirs(changes *Changes, ii *doubleIter) error {
+ from := ii.from.current
+ to := ii.to.current
+
+ status, err := ii.compare()
+ if err != nil {
+ return err
+ }
+
+ switch {
+ case status.fromIsEmptyDir:
+ if err = changes.AddRecursiveInsert(to); err != nil {
+ return err
+ }
+ if err = ii.nextBoth(); err != nil {
+ return err
+ }
+ case status.toIsEmptyDir:
+ if err = changes.AddRecursiveDelete(from); err != nil {
+ return err
+ }
+ if err = ii.nextBoth(); err != nil {
+ return err
+ }
+ case !status.fromIsEmptyDir && !status.toIsEmptyDir:
+ // do nothing
+ if err = ii.stepBoth(); err != nil {
+ return err
+ }
+ default:
+ return fmt.Errorf("both dirs are empty but has different hash")
+ }
+
+ return nil
+}
--- /dev/null
+/*
+Package merkletrie provides support for n-ary trees that are at the same
+time Merkle trees and Radix trees (tries).
+
+Git trees are Radix n-ary trees in virtue of the names of their
+tree entries. At the same time, git trees are Merkle trees thanks to
+their hashes.
+
+This package defines Merkle tries as nodes that should have:
+
+- a hash: the Merkle part of the Merkle trie
+
+- a key: the Radix part of the Merkle trie
+
+The Merkle hash condition is not enforced by this package though. This
+means that the hash of a node doesn't have to take into account the hashes of
+their children, which is good for testing purposes.
+
+Nodes in the Merkle trie are abstracted by the Noder interface. The
+intended use is that git trees implements this interface, either
+directly or using a simple wrapper.
+
+This package provides an iterator for merkletries that can skip whole
+directory-like noders and an efficient merkletrie comparison algorithm.
+
+When comparing git trees, the simple approach of alphabetically sorting
+their elements and comparing the resulting lists is too slow as it
+depends linearly on the number of files in the trees: When a directory
+has lots of files but none of them has been modified, this approach is
+very expensive. We can do better by prunning whole directories that
+have not change, just by looking at their hashes. This package provides
+the tools to do exactly that.
+*/
+package merkletrie
--- /dev/null
+package merkletrie
+
+import (
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// A doubleIter is a convenience type to keep track of the current
+// noders in two merkletries that are going to be iterated in parallel.
+// It has methods for:
+//
+// - iterating over the merkletries, both at the same time or
+// individually: nextFrom, nextTo, nextBoth, stepBoth
+//
+// - checking if there are noders left in one or both of them with the
+// remaining method and its associated returned type.
+//
+// - comparing the current noders of both merkletries in several ways,
+// with the compare method and its associated returned type.
+type doubleIter struct {
+ from struct {
+ iter *Iter
+ current noder.Path // nil if no more nodes
+ }
+ to struct {
+ iter *Iter
+ current noder.Path // nil if no more nodes
+ }
+ hashEqual noder.Equal
+}
+
+// NewdoubleIter returns a new doubleIter for the merkletries "from" and
+// "to". The hashEqual callback function will be used by the doubleIter
+// to compare the hash of the noders in the merkletries. The doubleIter
+// will be initialized to the first elements in each merkletrie if any.
+func newDoubleIter(from, to noder.Noder, hashEqual noder.Equal) (
+ *doubleIter, error) {
+ var ii doubleIter
+ var err error
+
+ if ii.from.iter, err = NewIter(from); err != nil {
+ return nil, fmt.Errorf("from: %s", err)
+ }
+ if ii.from.current, err = ii.from.iter.Next(); turnEOFIntoNil(err) != nil {
+ return nil, fmt.Errorf("from: %s", err)
+ }
+
+ if ii.to.iter, err = NewIter(to); err != nil {
+ return nil, fmt.Errorf("to: %s", err)
+ }
+ if ii.to.current, err = ii.to.iter.Next(); turnEOFIntoNil(err) != nil {
+ return nil, fmt.Errorf("to: %s", err)
+ }
+
+ ii.hashEqual = hashEqual
+
+ return &ii, nil
+}
+
+func turnEOFIntoNil(e error) error {
+ if e != nil && e != io.EOF {
+ return e
+ }
+ return nil
+}
+
+// NextBoth makes d advance to the next noder in both merkletries. If
+// any of them is a directory, it skips its contents.
+func (d *doubleIter) nextBoth() error {
+ if err := d.nextFrom(); err != nil {
+ return err
+ }
+ if err := d.nextTo(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// NextFrom makes d advance to the next noder in the "from" merkletrie,
+// skipping its contents if it is a directory.
+func (d *doubleIter) nextFrom() (err error) {
+ d.from.current, err = d.from.iter.Next()
+ return turnEOFIntoNil(err)
+}
+
+// NextTo makes d advance to the next noder in the "to" merkletrie,
+// skipping its contents if it is a directory.
+func (d *doubleIter) nextTo() (err error) {
+ d.to.current, err = d.to.iter.Next()
+ return turnEOFIntoNil(err)
+}
+
+// StepBoth makes d advance to the next noder in both merkletries,
+// getting deeper into directories if that is the case.
+func (d *doubleIter) stepBoth() (err error) {
+ if d.from.current, err = d.from.iter.Step(); turnEOFIntoNil(err) != nil {
+ return err
+ }
+ if d.to.current, err = d.to.iter.Step(); turnEOFIntoNil(err) != nil {
+ return err
+ }
+ return nil
+}
+
+// Remaining returns if there are no more noders in the tree, if both
+// have noders or if one of them doesn't.
+func (d *doubleIter) remaining() remaining {
+ if d.from.current == nil && d.to.current == nil {
+ return noMoreNoders
+ }
+
+ if d.from.current == nil && d.to.current != nil {
+ return onlyToRemains
+ }
+
+ if d.from.current != nil && d.to.current == nil {
+ return onlyFromRemains
+ }
+
+ return bothHaveNodes
+}
+
+// Remaining values tells you whether both trees still have noders, or
+// only one of them or none of them.
+type remaining int
+
+const (
+ noMoreNoders remaining = iota
+ onlyToRemains
+ onlyFromRemains
+ bothHaveNodes
+)
+
+// Compare returns the comparison between the current elements in the
+// merkletries.
+func (d *doubleIter) compare() (s comparison, err error) {
+ s.sameHash = d.hashEqual(d.from.current, d.to.current)
+
+ fromIsDir := d.from.current.IsDir()
+ toIsDir := d.to.current.IsDir()
+
+ s.bothAreDirs = fromIsDir && toIsDir
+ s.bothAreFiles = !fromIsDir && !toIsDir
+ s.fileAndDir = !s.bothAreDirs && !s.bothAreFiles
+
+ fromNumChildren, err := d.from.current.NumChildren()
+ if err != nil {
+ return comparison{}, fmt.Errorf("from: %s", err)
+ }
+
+ toNumChildren, err := d.to.current.NumChildren()
+ if err != nil {
+ return comparison{}, fmt.Errorf("to: %s", err)
+ }
+
+ s.fromIsEmptyDir = fromIsDir && fromNumChildren == 0
+ s.toIsEmptyDir = toIsDir && toNumChildren == 0
+
+ return
+}
+
+// Answers to a lot of questions you can ask about how to noders are
+// equal or different.
+type comparison struct {
+ // the following are only valid if both nodes have the same name
+ // (i.e. nameComparison == 0)
+
+ // Do both nodes have the same hash?
+ sameHash bool
+ // Are both nodes files?
+ bothAreFiles bool
+
+ // the following are only valid if any of the noders are dirs,
+ // this is, if !bothAreFiles
+
+ // Is one a file and the other a dir?
+ fileAndDir bool
+ // Are both nodes dirs?
+ bothAreDirs bool
+ // Is the from node an empty dir?
+ fromIsEmptyDir bool
+ // Is the to Node an empty dir?
+ toIsEmptyDir bool
+}
--- /dev/null
+package filesystem
+
+import (
+ "io"
+ "os"
+ "path"
+
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+var ignore = map[string]bool{
+ ".git": true,
+}
+
+// The node represents a file or a directory in a billy.Filesystem. It
+// implements the interface noder.Noder of merkletrie package.
+//
+// This implementation implements a "standard" hash method being able to be
+// compared with any other noder.Noder implementation inside of go-git.
+type node struct {
+ fs billy.Filesystem
+ submodules map[string]plumbing.Hash
+
+ path string
+ hash []byte
+ children []noder.Noder
+ isDir bool
+}
+
+// NewRootNode returns the root node based on a given billy.Filesystem.
+//
+// In order to provide the submodule hash status, a map[string]plumbing.Hash
+// should be provided where the key is the path of the submodule and the commit
+// of the submodule HEAD
+func NewRootNode(
+ fs billy.Filesystem,
+ submodules map[string]plumbing.Hash,
+) noder.Noder {
+ return &node{fs: fs, submodules: submodules, isDir: true}
+}
+
+// Hash the hash of a filesystem is the result of concatenating the computed
+// plumbing.Hash of the file as a Blob and its plumbing.FileMode; that way the
+// difftree algorithm will detect changes in the contents of files and also in
+// their mode.
+//
+// The hash of a directory is always a 24-bytes slice of zero values
+func (n *node) Hash() []byte {
+ return n.hash
+}
+
+func (n *node) Name() string {
+ return path.Base(n.path)
+}
+
+func (n *node) IsDir() bool {
+ return n.isDir
+}
+
+func (n *node) Children() ([]noder.Noder, error) {
+ if err := n.calculateChildren(); err != nil {
+ return nil, err
+ }
+
+ return n.children, nil
+}
+
+func (n *node) NumChildren() (int, error) {
+ if err := n.calculateChildren(); err != nil {
+ return -1, err
+ }
+
+ return len(n.children), nil
+}
+
+func (n *node) calculateChildren() error {
+ if !n.IsDir() {
+ return nil
+ }
+
+ if len(n.children) != 0 {
+ return nil
+ }
+
+ files, err := n.fs.ReadDir(n.path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+
+ return nil
+ }
+
+ for _, file := range files {
+ if _, ok := ignore[file.Name()]; ok {
+ continue
+ }
+
+ c, err := n.newChildNode(file)
+ if err != nil {
+ return err
+ }
+
+ n.children = append(n.children, c)
+ }
+
+ return nil
+}
+
+func (n *node) newChildNode(file os.FileInfo) (*node, error) {
+ path := path.Join(n.path, file.Name())
+
+ hash, err := n.calculateHash(path, file)
+ if err != nil {
+ return nil, err
+ }
+
+ node := &node{
+ fs: n.fs,
+ submodules: n.submodules,
+
+ path: path,
+ hash: hash,
+ isDir: file.IsDir(),
+ }
+
+ if hash, isSubmodule := n.submodules[path]; isSubmodule {
+ node.hash = append(hash[:], filemode.Submodule.Bytes()...)
+ node.isDir = false
+ }
+
+ return node, nil
+}
+
+func (n *node) calculateHash(path string, file os.FileInfo) ([]byte, error) {
+ if file.IsDir() {
+ return make([]byte, 24), nil
+ }
+
+ var hash plumbing.Hash
+ var err error
+ if file.Mode()&os.ModeSymlink != 0 {
+ hash, err = n.doCalculateHashForSymlink(path, file)
+ } else {
+ hash, err = n.doCalculateHashForRegular(path, file)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ mode, err := filemode.NewFromOSFileMode(file.Mode())
+ if err != nil {
+ return nil, err
+ }
+
+ return append(hash[:], mode.Bytes()...), nil
+}
+
+func (n *node) doCalculateHashForRegular(path string, file os.FileInfo) (plumbing.Hash, error) {
+ f, err := n.fs.Open(path)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ defer f.Close()
+
+ h := plumbing.NewHasher(plumbing.BlobObject, file.Size())
+ if _, err := io.Copy(h, f); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return h.Sum(), nil
+}
+
+func (n *node) doCalculateHashForSymlink(path string, file os.FileInfo) (plumbing.Hash, error) {
+ target, err := n.fs.Readlink(path)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ h := plumbing.NewHasher(plumbing.BlobObject, file.Size())
+ if _, err := h.Write([]byte(target)); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return h.Sum(), nil
+}
+
+func (n *node) String() string {
+ return n.path
+}
--- /dev/null
+package index
+
+import (
+ "path"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// The node represents a index.Entry or a directory inferred from the path
+// of all entries. It implements the interface noder.Noder of merkletrie
+// package.
+//
+// This implementation implements a "standard" hash method being able to be
+// compared with any other noder.Noder implementation inside of go-git
+type node struct {
+ path string
+ entry *index.Entry
+ children []noder.Noder
+ isDir bool
+}
+
+// NewRootNode returns the root node of a computed tree from a index.Index,
+func NewRootNode(idx *index.Index) noder.Noder {
+ const rootNode = ""
+
+ m := map[string]*node{rootNode: {isDir: true}}
+
+ for _, e := range idx.Entries {
+ parts := strings.Split(e.Name, string("/"))
+
+ var fullpath string
+ for _, part := range parts {
+ parent := fullpath
+ fullpath = path.Join(fullpath, part)
+
+ if _, ok := m[fullpath]; ok {
+ continue
+ }
+
+ n := &node{path: fullpath}
+ if fullpath == e.Name {
+ n.entry = e
+ } else {
+ n.isDir = true
+ }
+
+ m[n.path] = n
+ m[parent].children = append(m[parent].children, n)
+ }
+ }
+
+ return m[rootNode]
+}
+
+func (n *node) String() string {
+ return n.path
+}
+
+// Hash the hash of a filesystem is a 24-byte slice, is the result of
+// concatenating the computed plumbing.Hash of the file as a Blob and its
+// plumbing.FileMode; that way the difftree algorithm will detect changes in the
+// contents of files and also in their mode.
+//
+// If the node is computed and not based on a index.Entry the hash is equals
+// to a 24-bytes slices of zero values.
+func (n *node) Hash() []byte {
+ if n.entry == nil {
+ return make([]byte, 24)
+ }
+
+ return append(n.entry.Hash[:], n.entry.Mode.Bytes()...)
+}
+
+func (n *node) Name() string {
+ return path.Base(n.path)
+}
+
+func (n *node) IsDir() bool {
+ return n.isDir
+}
+
+func (n *node) Children() ([]noder.Noder, error) {
+ return n.children, nil
+}
+
+func (n *node) NumChildren() (int, error) {
+ return len(n.children), nil
+}
--- /dev/null
+package frame
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// A Frame is a collection of siblings in a trie, sorted alphabetically
+// by name.
+type Frame struct {
+ // siblings, sorted in reverse alphabetical order by name
+ stack []noder.Noder
+}
+
+type byName []noder.Noder
+
+func (a byName) Len() int { return len(a) }
+func (a byName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a byName) Less(i, j int) bool {
+ return strings.Compare(a[i].Name(), a[j].Name()) < 0
+}
+
+// New returns a frame with the children of the provided node.
+func New(n noder.Noder) (*Frame, error) {
+ children, err := n.Children()
+ if err != nil {
+ return nil, err
+ }
+
+ sort.Sort(sort.Reverse(byName(children)))
+ return &Frame{
+ stack: children,
+ }, nil
+}
+
+// String returns the quoted names of the noders in the frame sorted in
+// alphabeticall order by name, surrounded by square brackets and
+// separated by comas.
+//
+// Examples:
+// []
+// ["a", "b"]
+func (f *Frame) String() string {
+ var buf bytes.Buffer
+ _ = buf.WriteByte('[')
+
+ sep := ""
+ for i := f.Len() - 1; i >= 0; i-- {
+ _, _ = buf.WriteString(sep)
+ sep = ", "
+ _, _ = buf.WriteString(fmt.Sprintf("%q", f.stack[i].Name()))
+ }
+
+ _ = buf.WriteByte(']')
+
+ return buf.String()
+}
+
+// First returns, but dont extract, the noder with the alphabetically
+// smaller name in the frame and true if the frame was not empy.
+// Otherwise it returns nil and false.
+func (f *Frame) First() (noder.Noder, bool) {
+ if f.Len() == 0 {
+ return nil, false
+ }
+
+ top := f.Len() - 1
+
+ return f.stack[top], true
+}
+
+// Drop extracts the noder with the alphabetically smaller name in the
+// frame or does nothing if the frame was empty.
+func (f *Frame) Drop() {
+ if f.Len() == 0 {
+ return
+ }
+
+ top := f.Len() - 1
+ f.stack[top] = nil
+ f.stack = f.stack[:top]
+}
+
+// Len returns the number of noders in the frame.
+func (f *Frame) Len() int {
+ return len(f.stack)
+}
--- /dev/null
+package merkletrie
+
+import (
+ "fmt"
+ "io"
+
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/internal/frame"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+// Iter is an iterator for merkletries (only the trie part of the
+// merkletrie is relevant here, it does not use the Hasher interface).
+//
+// The iteration is performed in depth-first pre-order. Entries at each
+// depth are traversed in (case-sensitive) alphabetical order.
+//
+// This is the kind of traversal you will expect when listing ordinary
+// files and directories recursively, for example:
+//
+// Trie Traversal order
+// ---- ---------------
+// .
+// / | \ c
+// / | \ d/
+// d c z ===> d/a
+// / \ d/b
+// b a z
+//
+//
+// This iterator is somewhat especial as you can chose to skip whole
+// "directories" when iterating:
+//
+// - The Step method will iterate normally.
+//
+// - the Next method will not descend deeper into the tree.
+//
+// For example, if the iterator is at `d/`, the Step method will return
+// `d/a` while the Next would have returned `z` instead (skipping `d/`
+// and its descendants). The name of the these two methods are based on
+// the well known "next" and "step" operations, quite common in
+// debuggers, like gdb.
+//
+// The paths returned by the iterator will be relative, if the iterator
+// was created from a single node, or absolute, if the iterator was
+// created from the path to the node (the path will be prefixed to all
+// returned paths).
+type Iter struct {
+ // Tells if the iteration has started.
+ hasStarted bool
+ // The top of this stack has the current node and its siblings. The
+ // rest of the stack keeps the ancestors of the current node and
+ // their corresponding siblings. The current element is always the
+ // top element of the top frame.
+ //
+ // When "step"ping into a node, its children are pushed as a new
+ // frame.
+ //
+ // When "next"ing pass a node, the current element is dropped by
+ // popping the top frame.
+ frameStack []*frame.Frame
+ // The base path used to turn the relative paths used internally by
+ // the iterator into absolute paths used by external applications.
+ // For relative iterator this will be nil.
+ base noder.Path
+}
+
+// NewIter returns a new relative iterator using the provider noder as
+// its unnamed root. When iterating, all returned paths will be
+// relative to node.
+func NewIter(n noder.Noder) (*Iter, error) {
+ return newIter(n, nil)
+}
+
+// NewIterFromPath returns a new absolute iterator from the noder at the
+// end of the path p. When iterating, all returned paths will be
+// absolute, using the root of the path p as their root.
+func NewIterFromPath(p noder.Path) (*Iter, error) {
+ return newIter(p, p) // Path implements Noder
+}
+
+func newIter(root noder.Noder, base noder.Path) (*Iter, error) {
+ ret := &Iter{
+ base: base,
+ }
+
+ if root == nil {
+ return ret, nil
+ }
+
+ frame, err := frame.New(root)
+ if err != nil {
+ return nil, err
+ }
+ ret.push(frame)
+
+ return ret, nil
+}
+
+func (iter *Iter) top() (*frame.Frame, bool) {
+ if len(iter.frameStack) == 0 {
+ return nil, false
+ }
+ top := len(iter.frameStack) - 1
+
+ return iter.frameStack[top], true
+}
+
+func (iter *Iter) push(f *frame.Frame) {
+ iter.frameStack = append(iter.frameStack, f)
+}
+
+const (
+ doDescend = true
+ dontDescend = false
+)
+
+// Next returns the path of the next node without descending deeper into
+// the trie and nil. If there are no more entries in the trie it
+// returns nil and io.EOF. In case of error, it will return nil and the
+// error.
+func (iter *Iter) Next() (noder.Path, error) {
+ return iter.advance(dontDescend)
+}
+
+// Step returns the path to the next node in the trie, descending deeper
+// into it if needed, and nil. If there are no more nodes in the trie,
+// it returns nil and io.EOF. In case of error, it will return nil and
+// the error.
+func (iter *Iter) Step() (noder.Path, error) {
+ return iter.advance(doDescend)
+}
+
+// Advances the iterator in the desired direction: descend or
+// dontDescend.
+//
+// Returns the new current element and a nil error on success. If there
+// are no more elements in the trie below the base, it returns nil, and
+// io.EOF. Returns nil and an error in case of errors.
+func (iter *Iter) advance(wantDescend bool) (noder.Path, error) {
+ current, err := iter.current()
+ if err != nil {
+ return nil, err
+ }
+
+ // The first time we just return the current node.
+ if !iter.hasStarted {
+ iter.hasStarted = true
+ return current, nil
+ }
+
+ // Advances means getting a next current node, either its first child or
+ // its next sibling, depending if we must descend or not.
+ numChildren, err := current.NumChildren()
+ if err != nil {
+ return nil, err
+ }
+
+ mustDescend := numChildren != 0 && wantDescend
+ if mustDescend {
+ // descend: add a new frame with the current's children.
+ frame, err := frame.New(current)
+ if err != nil {
+ return nil, err
+ }
+ iter.push(frame)
+ } else {
+ // don't descend: just drop the current node
+ iter.drop()
+ }
+
+ return iter.current()
+}
+
+// Returns the path to the current node, adding the base if there was
+// one, and a nil error. If there were no noders left, it returns nil
+// and io.EOF. If an error occurred, it returns nil and the error.
+func (iter *Iter) current() (noder.Path, error) {
+ if topFrame, ok := iter.top(); !ok {
+ return nil, io.EOF
+ } else if _, ok := topFrame.First(); !ok {
+ return nil, io.EOF
+ }
+
+ ret := make(noder.Path, 0, len(iter.base)+len(iter.frameStack))
+
+ // concat the base...
+ ret = append(ret, iter.base...)
+ // ... and the current node and all its ancestors
+ for i, f := range iter.frameStack {
+ t, ok := f.First()
+ if !ok {
+ panic(fmt.Sprintf("frame %d is empty", i))
+ }
+ ret = append(ret, t)
+ }
+
+ return ret, nil
+}
+
+// removes the current node if any, and all the frames that become empty as a
+// consequence of this action.
+func (iter *Iter) drop() {
+ frame, ok := iter.top()
+ if !ok {
+ return
+ }
+
+ frame.Drop()
+ // if the frame is empty, remove it and its parent, recursively
+ if frame.Len() == 0 {
+ top := len(iter.frameStack) - 1
+ iter.frameStack[top] = nil
+ iter.frameStack = iter.frameStack[:top]
+ iter.drop()
+ }
+}
--- /dev/null
+// Package noder provide an interface for defining nodes in a
+// merkletrie, their hashes and their paths (a noders and its
+// ancestors).
+//
+// The hasher interface is easy to implement naively by elements that
+// already have a hash, like git blobs and trees. More sophisticated
+// implementations can implement the Equal function in exotic ways
+// though: for instance, comparing the modification time of directories
+// in a filesystem.
+package noder
+
+import "fmt"
+
+// Hasher interface is implemented by types that can tell you
+// their hash.
+type Hasher interface {
+ Hash() []byte
+}
+
+// Equal functions take two hashers and return if they are equal.
+//
+// These functions are expected to be faster than reflect.Equal or
+// reflect.DeepEqual because they can compare just the hash of the
+// objects, instead of their contents, so they are expected to be O(1).
+type Equal func(a, b Hasher) bool
+
+// The Noder interface is implemented by the elements of a Merkle Trie.
+//
+// There are two types of elements in a Merkle Trie:
+//
+// - file-like nodes: they cannot have children.
+//
+// - directory-like nodes: they can have 0 or more children and their
+// hash is calculated by combining their children hashes.
+type Noder interface {
+ Hasher
+ fmt.Stringer // for testing purposes
+ // Name returns the name of an element (relative, not its full
+ // path).
+ Name() string
+ // IsDir returns true if the element is a directory-like node or
+ // false if it is a file-like node.
+ IsDir() bool
+ // Children returns the children of the element. Note that empty
+ // directory-like noders and file-like noders will both return
+ // NoChildren.
+ Children() ([]Noder, error)
+ // NumChildren returns the number of children this element has.
+ //
+ // This method is an optimization: the number of children is easily
+ // calculated as the length of the value returned by the Children
+ // method (above); yet, some implementations will be able to
+ // implement NumChildren in O(1) while Children is usually more
+ // complex.
+ NumChildren() (int, error)
+}
+
+// NoChildren represents the children of a noder without children.
+var NoChildren = []Noder{}
--- /dev/null
+package noder
+
+import (
+ "bytes"
+ "strings"
+
+ "golang.org/x/text/unicode/norm"
+)
+
+// Path values represent a noder and its ancestors. The root goes first
+// and the actual final noder the path is referring to will be the last.
+//
+// A path implements the Noder interface, redirecting all the interface
+// calls to its final noder.
+//
+// Paths build from an empty Noder slice are not valid paths and should
+// not be used.
+type Path []Noder
+
+// String returns the full path of the final noder as a string, using
+// "/" as the separator.
+func (p Path) String() string {
+ var buf bytes.Buffer
+ sep := ""
+ for _, e := range p {
+ _, _ = buf.WriteString(sep)
+ sep = "/"
+ _, _ = buf.WriteString(e.Name())
+ }
+
+ return buf.String()
+}
+
+// Last returns the final noder in the path.
+func (p Path) Last() Noder {
+ return p[len(p)-1]
+}
+
+// Hash returns the hash of the final noder of the path.
+func (p Path) Hash() []byte {
+ return p.Last().Hash()
+}
+
+// Name returns the name of the final noder of the path.
+func (p Path) Name() string {
+ return p.Last().Name()
+}
+
+// IsDir returns if the final noder of the path is a directory-like
+// noder.
+func (p Path) IsDir() bool {
+ return p.Last().IsDir()
+}
+
+// Children returns the children of the final noder in the path.
+func (p Path) Children() ([]Noder, error) {
+ return p.Last().Children()
+}
+
+// NumChildren returns the number of children the final noder of the
+// path has.
+func (p Path) NumChildren() (int, error) {
+ return p.Last().NumChildren()
+}
+
+// Compare returns -1, 0 or 1 if the path p is smaller, equal or bigger
+// than other, in "directory order"; for example:
+//
+// "a" < "b"
+// "a/b/c/d/z" < "b"
+// "a/b/a" > "a/b"
+func (p Path) Compare(other Path) int {
+ i := 0
+ for {
+ switch {
+ case len(other) == len(p) && i == len(p):
+ return 0
+ case i == len(other):
+ return 1
+ case i == len(p):
+ return -1
+ default:
+ form := norm.Form(norm.NFC)
+ this := form.String(p[i].Name())
+ that := form.String(other[i].Name())
+
+ cmp := strings.Compare(this, that)
+ if cmp != 0 {
+ return cmp
+ }
+ }
+ i++
+ }
+}
--- /dev/null
+package git
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ stdioutil "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "gopkg.in/src-d/go-git.v4/config"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/gitignore"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/plumbing/storer"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie"
+
+ "gopkg.in/src-d/go-billy.v4"
+ "gopkg.in/src-d/go-billy.v4/util"
+)
+
+var (
+ ErrWorktreeNotClean = errors.New("worktree is not clean")
+ ErrSubmoduleNotFound = errors.New("submodule not found")
+ ErrUnstagedChanges = errors.New("worktree contains unstaged changes")
+ ErrGitModulesSymlink = errors.New(gitmodulesFile + " is a symlink")
+)
+
+// Worktree represents a git worktree.
+type Worktree struct {
+ // Filesystem underlying filesystem.
+ Filesystem billy.Filesystem
+ // External excludes not found in the repository .gitignore
+ Excludes []gitignore.Pattern
+
+ r *Repository
+}
+
+// Pull incorporates changes from a remote repository into the current branch.
+// Returns nil if the operation is successful, NoErrAlreadyUpToDate if there are
+// no changes to be fetched, or an error.
+//
+// Pull only supports merges where the can be resolved as a fast-forward.
+func (w *Worktree) Pull(o *PullOptions) error {
+ return w.PullContext(context.Background(), o)
+}
+
+// PullContext incorporates changes from a remote repository into the current
+// branch. Returns nil if the operation is successful, NoErrAlreadyUpToDate if
+// there are no changes to be fetched, or an error.
+//
+// Pull only supports merges where the can be resolved as a fast-forward.
+//
+// The provided Context must be non-nil. If the context expires before the
+// operation is complete, an error is returned. The context only affects to the
+// transport operations.
+func (w *Worktree) PullContext(ctx context.Context, o *PullOptions) error {
+ if err := o.Validate(); err != nil {
+ return err
+ }
+
+ remote, err := w.r.Remote(o.RemoteName)
+ if err != nil {
+ return err
+ }
+
+ fetchHead, err := remote.fetch(ctx, &FetchOptions{
+ RemoteName: o.RemoteName,
+ Depth: o.Depth,
+ Auth: o.Auth,
+ Progress: o.Progress,
+ Force: o.Force,
+ })
+
+ updated := true
+ if err == NoErrAlreadyUpToDate {
+ updated = false
+ } else if err != nil {
+ return err
+ }
+
+ ref, err := storer.ResolveReference(fetchHead, o.ReferenceName)
+ if err != nil {
+ return err
+ }
+
+ head, err := w.r.Head()
+ if err == nil {
+ if !updated && head.Hash() == ref.Hash() {
+ return NoErrAlreadyUpToDate
+ }
+
+ ff, err := isFastForward(w.r.Storer, head.Hash(), ref.Hash())
+ if err != nil {
+ return err
+ }
+
+ if !ff {
+ return fmt.Errorf("non-fast-forward update")
+ }
+ }
+
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return err
+ }
+
+ if err := w.updateHEAD(ref.Hash()); err != nil {
+ return err
+ }
+
+ if err := w.Reset(&ResetOptions{
+ Mode: MergeReset,
+ Commit: ref.Hash(),
+ }); err != nil {
+ return err
+ }
+
+ if o.RecurseSubmodules != NoRecurseSubmodules {
+ return w.updateSubmodules(&SubmoduleUpdateOptions{
+ RecurseSubmodules: o.RecurseSubmodules,
+ Auth: o.Auth,
+ })
+ }
+
+ return nil
+}
+
+func (w *Worktree) updateSubmodules(o *SubmoduleUpdateOptions) error {
+ s, err := w.Submodules()
+ if err != nil {
+ return err
+ }
+ o.Init = true
+ return s.Update(o)
+}
+
+// Checkout switch branches or restore working tree files.
+func (w *Worktree) Checkout(opts *CheckoutOptions) error {
+ if err := opts.Validate(); err != nil {
+ return err
+ }
+
+ if opts.Create {
+ if err := w.createBranch(opts); err != nil {
+ return err
+ }
+ }
+
+ if !opts.Force {
+ unstaged, err := w.containsUnstagedChanges()
+ if err != nil {
+ return err
+ }
+
+ if unstaged {
+ return ErrUnstagedChanges
+ }
+ }
+
+ c, err := w.getCommitFromCheckoutOptions(opts)
+ if err != nil {
+ return err
+ }
+
+ ro := &ResetOptions{Commit: c, Mode: MergeReset}
+ if opts.Force {
+ ro.Mode = HardReset
+ }
+
+ if !opts.Hash.IsZero() && !opts.Create {
+ err = w.setHEADToCommit(opts.Hash)
+ } else {
+ err = w.setHEADToBranch(opts.Branch, c)
+ }
+
+ if err != nil {
+ return err
+ }
+
+ return w.Reset(ro)
+}
+func (w *Worktree) createBranch(opts *CheckoutOptions) error {
+ _, err := w.r.Storer.Reference(opts.Branch)
+ if err == nil {
+ return fmt.Errorf("a branch named %q already exists", opts.Branch)
+ }
+
+ if err != plumbing.ErrReferenceNotFound {
+ return err
+ }
+
+ if opts.Hash.IsZero() {
+ ref, err := w.r.Head()
+ if err != nil {
+ return err
+ }
+
+ opts.Hash = ref.Hash()
+ }
+
+ return w.r.Storer.SetReference(
+ plumbing.NewHashReference(opts.Branch, opts.Hash),
+ )
+}
+
+func (w *Worktree) getCommitFromCheckoutOptions(opts *CheckoutOptions) (plumbing.Hash, error) {
+ if !opts.Hash.IsZero() {
+ return opts.Hash, nil
+ }
+
+ b, err := w.r.Reference(opts.Branch, true)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if !b.Name().IsTag() {
+ return b.Hash(), nil
+ }
+
+ o, err := w.r.Object(plumbing.AnyObject, b.Hash())
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ switch o := o.(type) {
+ case *object.Tag:
+ if o.TargetType != plumbing.CommitObject {
+ return plumbing.ZeroHash, fmt.Errorf("unsupported tag object target %q", o.TargetType)
+ }
+
+ return o.Target, nil
+ case *object.Commit:
+ return o.Hash, nil
+ }
+
+ return plumbing.ZeroHash, fmt.Errorf("unsupported tag target %q", o.Type())
+}
+
+func (w *Worktree) setHEADToCommit(commit plumbing.Hash) error {
+ head := plumbing.NewHashReference(plumbing.HEAD, commit)
+ return w.r.Storer.SetReference(head)
+}
+
+func (w *Worktree) setHEADToBranch(branch plumbing.ReferenceName, commit plumbing.Hash) error {
+ target, err := w.r.Storer.Reference(branch)
+ if err != nil {
+ return err
+ }
+
+ var head *plumbing.Reference
+ if target.Name().IsBranch() {
+ head = plumbing.NewSymbolicReference(plumbing.HEAD, target.Name())
+ } else {
+ head = plumbing.NewHashReference(plumbing.HEAD, commit)
+ }
+
+ return w.r.Storer.SetReference(head)
+}
+
+// Reset the worktree to a specified state.
+func (w *Worktree) Reset(opts *ResetOptions) error {
+ if err := opts.Validate(w.r); err != nil {
+ return err
+ }
+
+ if opts.Mode == MergeReset {
+ unstaged, err := w.containsUnstagedChanges()
+ if err != nil {
+ return err
+ }
+
+ if unstaged {
+ return ErrUnstagedChanges
+ }
+ }
+
+ if err := w.setHEADCommit(opts.Commit); err != nil {
+ return err
+ }
+
+ if opts.Mode == SoftReset {
+ return nil
+ }
+
+ t, err := w.getTreeFromCommitHash(opts.Commit)
+ if err != nil {
+ return err
+ }
+
+ if opts.Mode == MixedReset || opts.Mode == MergeReset || opts.Mode == HardReset {
+ if err := w.resetIndex(t); err != nil {
+ return err
+ }
+ }
+
+ if opts.Mode == MergeReset || opts.Mode == HardReset {
+ if err := w.resetWorktree(t); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *Worktree) resetIndex(t *object.Tree) error {
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return err
+ }
+
+ changes, err := w.diffTreeWithStaging(t, true)
+ if err != nil {
+ return err
+ }
+
+ for _, ch := range changes {
+ a, err := ch.Action()
+ if err != nil {
+ return err
+ }
+
+ var name string
+ var e *object.TreeEntry
+
+ switch a {
+ case merkletrie.Modify, merkletrie.Insert:
+ name = ch.To.String()
+ e, err = t.FindEntry(name)
+ if err != nil {
+ return err
+ }
+ case merkletrie.Delete:
+ name = ch.From.String()
+ }
+
+ _, _ = idx.Remove(name)
+ if e == nil {
+ continue
+ }
+
+ idx.Entries = append(idx.Entries, &index.Entry{
+ Name: name,
+ Hash: e.Hash,
+ Mode: e.Mode,
+ })
+
+ }
+
+ return w.r.Storer.SetIndex(idx)
+}
+
+func (w *Worktree) resetWorktree(t *object.Tree) error {
+ changes, err := w.diffStagingWithWorktree(true)
+ if err != nil {
+ return err
+ }
+
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return err
+ }
+
+ for _, ch := range changes {
+ if err := w.checkoutChange(ch, t, idx); err != nil {
+ return err
+ }
+ }
+
+ return w.r.Storer.SetIndex(idx)
+}
+
+func (w *Worktree) checkoutChange(ch merkletrie.Change, t *object.Tree, idx *index.Index) error {
+ a, err := ch.Action()
+ if err != nil {
+ return err
+ }
+
+ var e *object.TreeEntry
+ var name string
+ var isSubmodule bool
+
+ switch a {
+ case merkletrie.Modify, merkletrie.Insert:
+ name = ch.To.String()
+ e, err = t.FindEntry(name)
+ if err != nil {
+ return err
+ }
+
+ isSubmodule = e.Mode == filemode.Submodule
+ case merkletrie.Delete:
+ return rmFileAndDirIfEmpty(w.Filesystem, ch.From.String())
+ }
+
+ if isSubmodule {
+ return w.checkoutChangeSubmodule(name, a, e, idx)
+ }
+
+ return w.checkoutChangeRegularFile(name, a, t, e, idx)
+}
+
+func (w *Worktree) containsUnstagedChanges() (bool, error) {
+ ch, err := w.diffStagingWithWorktree(false)
+ if err != nil {
+ return false, err
+ }
+
+ for _, c := range ch {
+ a, err := c.Action()
+ if err != nil {
+ return false, err
+ }
+
+ if a == merkletrie.Insert {
+ continue
+ }
+
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (w *Worktree) setHEADCommit(commit plumbing.Hash) error {
+ head, err := w.r.Reference(plumbing.HEAD, false)
+ if err != nil {
+ return err
+ }
+
+ if head.Type() == plumbing.HashReference {
+ head = plumbing.NewHashReference(plumbing.HEAD, commit)
+ return w.r.Storer.SetReference(head)
+ }
+
+ branch, err := w.r.Reference(head.Target(), false)
+ if err != nil {
+ return err
+ }
+
+ if !branch.Name().IsBranch() {
+ return fmt.Errorf("invalid HEAD target should be a branch, found %s", branch.Type())
+ }
+
+ branch = plumbing.NewHashReference(branch.Name(), commit)
+ return w.r.Storer.SetReference(branch)
+}
+
+func (w *Worktree) checkoutChangeSubmodule(name string,
+ a merkletrie.Action,
+ e *object.TreeEntry,
+ idx *index.Index,
+) error {
+ switch a {
+ case merkletrie.Modify:
+ sub, err := w.Submodule(name)
+ if err != nil {
+ return err
+ }
+
+ if !sub.initialized {
+ return nil
+ }
+
+ return w.addIndexFromTreeEntry(name, e, idx)
+ case merkletrie.Insert:
+ mode, err := e.Mode.ToOSFileMode()
+ if err != nil {
+ return err
+ }
+
+ if err := w.Filesystem.MkdirAll(name, mode); err != nil {
+ return err
+ }
+
+ return w.addIndexFromTreeEntry(name, e, idx)
+ }
+
+ return nil
+}
+
+func (w *Worktree) checkoutChangeRegularFile(name string,
+ a merkletrie.Action,
+ t *object.Tree,
+ e *object.TreeEntry,
+ idx *index.Index,
+) error {
+ switch a {
+ case merkletrie.Modify:
+ _, _ = idx.Remove(name)
+
+ // to apply perm changes the file is deleted, billy doesn't implement
+ // chmod
+ if err := w.Filesystem.Remove(name); err != nil {
+ return err
+ }
+
+ fallthrough
+ case merkletrie.Insert:
+ f, err := t.File(name)
+ if err != nil {
+ return err
+ }
+
+ if err := w.checkoutFile(f); err != nil {
+ return err
+ }
+
+ return w.addIndexFromFile(name, e.Hash, idx)
+ }
+
+ return nil
+}
+
+func (w *Worktree) checkoutFile(f *object.File) (err error) {
+ mode, err := f.Mode.ToOSFileMode()
+ if err != nil {
+ return
+ }
+
+ if mode&os.ModeSymlink != 0 {
+ return w.checkoutFileSymlink(f)
+ }
+
+ from, err := f.Reader()
+ if err != nil {
+ return
+ }
+
+ defer ioutil.CheckClose(from, &err)
+
+ to, err := w.Filesystem.OpenFile(f.Name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm())
+ if err != nil {
+ return
+ }
+
+ defer ioutil.CheckClose(to, &err)
+
+ _, err = io.Copy(to, from)
+ return
+}
+
+func (w *Worktree) checkoutFileSymlink(f *object.File) (err error) {
+ from, err := f.Reader()
+ if err != nil {
+ return
+ }
+
+ defer ioutil.CheckClose(from, &err)
+
+ bytes, err := stdioutil.ReadAll(from)
+ if err != nil {
+ return
+ }
+
+ err = w.Filesystem.Symlink(string(bytes), f.Name)
+
+ // On windows, this might fail.
+ // Follow Git on Windows behavior by writing the link as it is.
+ if err != nil && isSymlinkWindowsNonAdmin(err) {
+ mode, _ := f.Mode.ToOSFileMode()
+
+ to, err := w.Filesystem.OpenFile(f.Name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm())
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(to, &err)
+
+ _, err = to.Write(bytes)
+ return err
+ }
+ return
+}
+
+func (w *Worktree) addIndexFromTreeEntry(name string, f *object.TreeEntry, idx *index.Index) error {
+ _, _ = idx.Remove(name)
+ idx.Entries = append(idx.Entries, &index.Entry{
+ Hash: f.Hash,
+ Name: name,
+ Mode: filemode.Submodule,
+ })
+
+ return nil
+}
+
+func (w *Worktree) addIndexFromFile(name string, h plumbing.Hash, idx *index.Index) error {
+ _, _ = idx.Remove(name)
+ fi, err := w.Filesystem.Lstat(name)
+ if err != nil {
+ return err
+ }
+
+ mode, err := filemode.NewFromOSFileMode(fi.Mode())
+ if err != nil {
+ return err
+ }
+
+ e := &index.Entry{
+ Hash: h,
+ Name: name,
+ Mode: mode,
+ ModifiedAt: fi.ModTime(),
+ Size: uint32(fi.Size()),
+ }
+
+ // if the FileInfo.Sys() comes from os the ctime, dev, inode, uid and gid
+ // can be retrieved, otherwise this doesn't apply
+ if fillSystemInfo != nil {
+ fillSystemInfo(e, fi.Sys())
+ }
+
+ idx.Entries = append(idx.Entries, e)
+ return nil
+}
+
+func (w *Worktree) getTreeFromCommitHash(commit plumbing.Hash) (*object.Tree, error) {
+ c, err := w.r.CommitObject(commit)
+ if err != nil {
+ return nil, err
+ }
+
+ return c.Tree()
+}
+
+var fillSystemInfo func(e *index.Entry, sys interface{})
+
+const gitmodulesFile = ".gitmodules"
+
+// Submodule returns the submodule with the given name
+func (w *Worktree) Submodule(name string) (*Submodule, error) {
+ l, err := w.Submodules()
+ if err != nil {
+ return nil, err
+ }
+
+ for _, m := range l {
+ if m.Config().Name == name {
+ return m, nil
+ }
+ }
+
+ return nil, ErrSubmoduleNotFound
+}
+
+// Submodules returns all the available submodules
+func (w *Worktree) Submodules() (Submodules, error) {
+ l := make(Submodules, 0)
+ m, err := w.readGitmodulesFile()
+ if err != nil || m == nil {
+ return l, err
+ }
+
+ c, err := w.r.Config()
+ if err != nil {
+ return nil, err
+ }
+
+ for _, s := range m.Submodules {
+ l = append(l, w.newSubmodule(s, c.Submodules[s.Name]))
+ }
+
+ return l, nil
+}
+
+func (w *Worktree) newSubmodule(fromModules, fromConfig *config.Submodule) *Submodule {
+ m := &Submodule{w: w}
+ m.initialized = fromConfig != nil
+
+ if !m.initialized {
+ m.c = fromModules
+ return m
+ }
+
+ m.c = fromConfig
+ m.c.Path = fromModules.Path
+ return m
+}
+
+func (w *Worktree) isSymlink(path string) bool {
+ if s, err := w.Filesystem.Lstat(path); err == nil {
+ return s.Mode()&os.ModeSymlink != 0
+ }
+ return false
+}
+
+func (w *Worktree) readGitmodulesFile() (*config.Modules, error) {
+ if w.isSymlink(gitmodulesFile) {
+ return nil, ErrGitModulesSymlink
+ }
+
+ f, err := w.Filesystem.Open(gitmodulesFile)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ defer f.Close()
+ input, err := stdioutil.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ m := config.NewModules()
+ return m, m.Unmarshal(input)
+}
+
+// Clean the worktree by removing untracked files.
+// An empty dir could be removed - this is what `git clean -f -d .` does.
+func (w *Worktree) Clean(opts *CleanOptions) error {
+ s, err := w.Status()
+ if err != nil {
+ return err
+ }
+
+ root := ""
+ files, err := w.Filesystem.ReadDir(root)
+ if err != nil {
+ return err
+ }
+ return w.doClean(s, opts, root, files)
+}
+
+func (w *Worktree) doClean(status Status, opts *CleanOptions, dir string, files []os.FileInfo) error {
+ for _, fi := range files {
+ if fi.Name() == ".git" {
+ continue
+ }
+
+ // relative path under the root
+ path := filepath.Join(dir, fi.Name())
+ if fi.IsDir() {
+ if !opts.Dir {
+ continue
+ }
+
+ subfiles, err := w.Filesystem.ReadDir(path)
+ if err != nil {
+ return err
+ }
+ err = w.doClean(status, opts, path, subfiles)
+ if err != nil {
+ return err
+ }
+ } else {
+ if status.IsUntracked(path) {
+ if err := w.Filesystem.Remove(path); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ if opts.Dir {
+ return doCleanDirectories(w.Filesystem, dir)
+ }
+ return nil
+}
+
+// GrepResult is structure of a grep result.
+type GrepResult struct {
+ // FileName is the name of file which contains match.
+ FileName string
+ // LineNumber is the line number of a file at which a match was found.
+ LineNumber int
+ // Content is the content of the file at the matching line.
+ Content string
+ // TreeName is the name of the tree (reference name/commit hash) at
+ // which the match was performed.
+ TreeName string
+}
+
+func (gr GrepResult) String() string {
+ return fmt.Sprintf("%s:%s:%d:%s", gr.TreeName, gr.FileName, gr.LineNumber, gr.Content)
+}
+
+// Grep performs grep on a worktree.
+func (w *Worktree) Grep(opts *GrepOptions) ([]GrepResult, error) {
+ if err := opts.Validate(w); err != nil {
+ return nil, err
+ }
+
+ // Obtain commit hash from options (CommitHash or ReferenceName).
+ var commitHash plumbing.Hash
+ // treeName contains the value of TreeName in GrepResult.
+ var treeName string
+
+ if opts.ReferenceName != "" {
+ ref, err := w.r.Reference(opts.ReferenceName, true)
+ if err != nil {
+ return nil, err
+ }
+ commitHash = ref.Hash()
+ treeName = opts.ReferenceName.String()
+ } else if !opts.CommitHash.IsZero() {
+ commitHash = opts.CommitHash
+ treeName = opts.CommitHash.String()
+ }
+
+ // Obtain a tree from the commit hash and get a tracked files iterator from
+ // the tree.
+ tree, err := w.getTreeFromCommitHash(commitHash)
+ if err != nil {
+ return nil, err
+ }
+ fileiter := tree.Files()
+
+ return findMatchInFiles(fileiter, treeName, opts)
+}
+
+// findMatchInFiles takes a FileIter, worktree name and GrepOptions, and
+// returns a slice of GrepResult containing the result of regex pattern matching
+// in content of all the files.
+func findMatchInFiles(fileiter *object.FileIter, treeName string, opts *GrepOptions) ([]GrepResult, error) {
+ var results []GrepResult
+
+ err := fileiter.ForEach(func(file *object.File) error {
+ var fileInPathSpec bool
+
+ // When no pathspecs are provided, search all the files.
+ if len(opts.PathSpecs) == 0 {
+ fileInPathSpec = true
+ }
+
+ // Check if the file name matches with the pathspec. Break out of the
+ // loop once a match is found.
+ for _, pathSpec := range opts.PathSpecs {
+ if pathSpec != nil && pathSpec.MatchString(file.Name) {
+ fileInPathSpec = true
+ break
+ }
+ }
+
+ // If the file does not match with any of the pathspec, skip it.
+ if !fileInPathSpec {
+ return nil
+ }
+
+ grepResults, err := findMatchInFile(file, treeName, opts)
+ if err != nil {
+ return err
+ }
+ results = append(results, grepResults...)
+
+ return nil
+ })
+
+ return results, err
+}
+
+// findMatchInFile takes a single File, worktree name and GrepOptions,
+// and returns a slice of GrepResult containing the result of regex pattern
+// matching in the given file.
+func findMatchInFile(file *object.File, treeName string, opts *GrepOptions) ([]GrepResult, error) {
+ var grepResults []GrepResult
+
+ content, err := file.Contents()
+ if err != nil {
+ return grepResults, err
+ }
+
+ // Split the file content and parse line-by-line.
+ contentByLine := strings.Split(content, "\n")
+ for lineNum, cnt := range contentByLine {
+ addToResult := false
+
+ // Match the patterns and content. Break out of the loop once a
+ // match is found.
+ for _, pattern := range opts.Patterns {
+ if pattern != nil && pattern.MatchString(cnt) {
+ // Add to result only if invert match is not enabled.
+ if !opts.InvertMatch {
+ addToResult = true
+ break
+ }
+ } else if opts.InvertMatch {
+ // If matching fails, and invert match is enabled, add to
+ // results.
+ addToResult = true
+ break
+ }
+ }
+
+ if addToResult {
+ grepResults = append(grepResults, GrepResult{
+ FileName: file.Name,
+ LineNumber: lineNum + 1,
+ Content: cnt,
+ TreeName: treeName,
+ })
+ }
+ }
+
+ return grepResults, nil
+}
+
+func rmFileAndDirIfEmpty(fs billy.Filesystem, name string) error {
+ if err := util.RemoveAll(fs, name); err != nil {
+ return err
+ }
+
+ dir := filepath.Dir(name)
+ return doCleanDirectories(fs, dir)
+}
+
+// doCleanDirectories removes empty subdirs (without files)
+func doCleanDirectories(fs billy.Filesystem, dir string) error {
+ files, err := fs.ReadDir(dir)
+ if err != nil {
+ return err
+ }
+ if len(files) == 0 {
+ return fs.Remove(dir)
+ }
+ return nil
+}
--- /dev/null
+// +build darwin freebsd netbsd
+
+package git
+
+import (
+ "syscall"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+)
+
+func init() {
+ fillSystemInfo = func(e *index.Entry, sys interface{}) {
+ if os, ok := sys.(*syscall.Stat_t); ok {
+ e.CreatedAt = time.Unix(int64(os.Atimespec.Sec), int64(os.Atimespec.Nsec))
+ e.Dev = uint32(os.Dev)
+ e.Inode = uint32(os.Ino)
+ e.GID = os.Gid
+ e.UID = os.Uid
+ }
+ }
+}
+
+func isSymlinkWindowsNonAdmin(err error) bool {
+ return false
+}
--- /dev/null
+package git
+
+import (
+ "bytes"
+ "path"
+ "sort"
+ "strings"
+
+ "golang.org/x/crypto/openpgp"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/storage"
+
+ "gopkg.in/src-d/go-billy.v4"
+)
+
+// Commit stores the current contents of the index in a new commit along with
+// a log message from the user describing the changes.
+func (w *Worktree) Commit(msg string, opts *CommitOptions) (plumbing.Hash, error) {
+ if err := opts.Validate(w.r); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if opts.All {
+ if err := w.autoAddModifiedAndDeleted(); err != nil {
+ return plumbing.ZeroHash, err
+ }
+ }
+
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ h := &buildTreeHelper{
+ fs: w.Filesystem,
+ s: w.r.Storer,
+ }
+
+ tree, err := h.BuildTree(idx)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ commit, err := w.buildCommitObject(msg, opts, tree)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return commit, w.updateHEAD(commit)
+}
+
+func (w *Worktree) autoAddModifiedAndDeleted() error {
+ s, err := w.Status()
+ if err != nil {
+ return err
+ }
+
+ for path, fs := range s {
+ if fs.Worktree != Modified && fs.Worktree != Deleted {
+ continue
+ }
+
+ if _, err := w.Add(path); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *Worktree) updateHEAD(commit plumbing.Hash) error {
+ head, err := w.r.Storer.Reference(plumbing.HEAD)
+ if err != nil {
+ return err
+ }
+
+ name := plumbing.HEAD
+ if head.Type() != plumbing.HashReference {
+ name = head.Target()
+ }
+
+ ref := plumbing.NewHashReference(name, commit)
+ return w.r.Storer.SetReference(ref)
+}
+
+func (w *Worktree) buildCommitObject(msg string, opts *CommitOptions, tree plumbing.Hash) (plumbing.Hash, error) {
+ commit := &object.Commit{
+ Author: *opts.Author,
+ Committer: *opts.Committer,
+ Message: msg,
+ TreeHash: tree,
+ ParentHashes: opts.Parents,
+ }
+
+ if opts.SignKey != nil {
+ sig, err := w.buildCommitSignature(commit, opts.SignKey)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+ commit.PGPSignature = sig
+ }
+
+ obj := w.r.Storer.NewEncodedObject()
+ if err := commit.Encode(obj); err != nil {
+ return plumbing.ZeroHash, err
+ }
+ return w.r.Storer.SetEncodedObject(obj)
+}
+
+func (w *Worktree) buildCommitSignature(commit *object.Commit, signKey *openpgp.Entity) (string, error) {
+ encoded := &plumbing.MemoryObject{}
+ if err := commit.Encode(encoded); err != nil {
+ return "", err
+ }
+ r, err := encoded.Reader()
+ if err != nil {
+ return "", err
+ }
+ var b bytes.Buffer
+ if err := openpgp.ArmoredDetachSign(&b, signKey, r, nil); err != nil {
+ return "", err
+ }
+ return b.String(), nil
+}
+
+// buildTreeHelper converts a given index.Index file into multiple git objects
+// reading the blobs from the given filesystem and creating the trees from the
+// index structure. The created objects are pushed to a given Storer.
+type buildTreeHelper struct {
+ fs billy.Filesystem
+ s storage.Storer
+
+ trees map[string]*object.Tree
+ entries map[string]*object.TreeEntry
+}
+
+// BuildTree builds the tree objects and push its to the storer, the hash
+// of the root tree is returned.
+func (h *buildTreeHelper) BuildTree(idx *index.Index) (plumbing.Hash, error) {
+ const rootNode = ""
+ h.trees = map[string]*object.Tree{rootNode: {}}
+ h.entries = map[string]*object.TreeEntry{}
+
+ for _, e := range idx.Entries {
+ if err := h.commitIndexEntry(e); err != nil {
+ return plumbing.ZeroHash, err
+ }
+ }
+
+ return h.copyTreeToStorageRecursive(rootNode, h.trees[rootNode])
+}
+
+func (h *buildTreeHelper) commitIndexEntry(e *index.Entry) error {
+ parts := strings.Split(e.Name, "/")
+
+ var fullpath string
+ for _, part := range parts {
+ parent := fullpath
+ fullpath = path.Join(fullpath, part)
+
+ h.doBuildTree(e, parent, fullpath)
+ }
+
+ return nil
+}
+
+func (h *buildTreeHelper) doBuildTree(e *index.Entry, parent, fullpath string) {
+ if _, ok := h.trees[fullpath]; ok {
+ return
+ }
+
+ if _, ok := h.entries[fullpath]; ok {
+ return
+ }
+
+ te := object.TreeEntry{Name: path.Base(fullpath)}
+
+ if fullpath == e.Name {
+ te.Mode = e.Mode
+ te.Hash = e.Hash
+ } else {
+ te.Mode = filemode.Dir
+ h.trees[fullpath] = &object.Tree{}
+ }
+
+ h.trees[parent].Entries = append(h.trees[parent].Entries, te)
+}
+
+type sortableEntries []object.TreeEntry
+
+func (sortableEntries) sortName(te object.TreeEntry) string {
+ if te.Mode == filemode.Dir {
+ return te.Name + "/"
+ }
+ return te.Name
+}
+func (se sortableEntries) Len() int { return len(se) }
+func (se sortableEntries) Less(i int, j int) bool { return se.sortName(se[i]) < se.sortName(se[j]) }
+func (se sortableEntries) Swap(i int, j int) { se[i], se[j] = se[j], se[i] }
+
+func (h *buildTreeHelper) copyTreeToStorageRecursive(parent string, t *object.Tree) (plumbing.Hash, error) {
+ sort.Sort(sortableEntries(t.Entries))
+ for i, e := range t.Entries {
+ if e.Mode != filemode.Dir && !e.Hash.IsZero() {
+ continue
+ }
+
+ path := path.Join(parent, e.Name)
+
+ var err error
+ e.Hash, err = h.copyTreeToStorageRecursive(path, h.trees[path])
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ t.Entries[i] = e
+ }
+
+ o := h.s.NewEncodedObject()
+ if err := t.Encode(o); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return h.s.SetEncodedObject(o)
+}
--- /dev/null
+// +build linux
+
+package git
+
+import (
+ "syscall"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+)
+
+func init() {
+ fillSystemInfo = func(e *index.Entry, sys interface{}) {
+ if os, ok := sys.(*syscall.Stat_t); ok {
+ e.CreatedAt = time.Unix(int64(os.Ctim.Sec), int64(os.Ctim.Nsec))
+ e.Dev = uint32(os.Dev)
+ e.Inode = uint32(os.Ino)
+ e.GID = os.Gid
+ e.UID = os.Uid
+ }
+ }
+}
+
+func isSymlinkWindowsNonAdmin(err error) bool {
+ return false
+}
--- /dev/null
+package git
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "os"
+ "path"
+ "path/filepath"
+
+ "gopkg.in/src-d/go-billy.v4/util"
+ "gopkg.in/src-d/go-git.v4/plumbing"
+ "gopkg.in/src-d/go-git.v4/plumbing/filemode"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/gitignore"
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+ "gopkg.in/src-d/go-git.v4/plumbing/object"
+ "gopkg.in/src-d/go-git.v4/utils/ioutil"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/filesystem"
+ mindex "gopkg.in/src-d/go-git.v4/utils/merkletrie/index"
+ "gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
+)
+
+var (
+ // ErrDestinationExists in an Move operation means that the target exists on
+ // the worktree.
+ ErrDestinationExists = errors.New("destination exists")
+ // ErrGlobNoMatches in an AddGlob if the glob pattern does not match any
+ // files in the worktree.
+ ErrGlobNoMatches = errors.New("glob pattern did not match any files")
+)
+
+// Status returns the working tree status.
+func (w *Worktree) Status() (Status, error) {
+ var hash plumbing.Hash
+
+ ref, err := w.r.Head()
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return nil, err
+ }
+
+ if err == nil {
+ hash = ref.Hash()
+ }
+
+ return w.status(hash)
+}
+
+func (w *Worktree) status(commit plumbing.Hash) (Status, error) {
+ s := make(Status)
+
+ left, err := w.diffCommitWithStaging(commit, false)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ch := range left {
+ a, err := ch.Action()
+ if err != nil {
+ return nil, err
+ }
+
+ fs := s.File(nameFromAction(&ch))
+ fs.Worktree = Unmodified
+
+ switch a {
+ case merkletrie.Delete:
+ s.File(ch.From.String()).Staging = Deleted
+ case merkletrie.Insert:
+ s.File(ch.To.String()).Staging = Added
+ case merkletrie.Modify:
+ s.File(ch.To.String()).Staging = Modified
+ }
+ }
+
+ right, err := w.diffStagingWithWorktree(false)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ch := range right {
+ a, err := ch.Action()
+ if err != nil {
+ return nil, err
+ }
+
+ fs := s.File(nameFromAction(&ch))
+ if fs.Staging == Untracked {
+ fs.Staging = Unmodified
+ }
+
+ switch a {
+ case merkletrie.Delete:
+ fs.Worktree = Deleted
+ case merkletrie.Insert:
+ fs.Worktree = Untracked
+ fs.Staging = Untracked
+ case merkletrie.Modify:
+ fs.Worktree = Modified
+ }
+ }
+
+ return s, nil
+}
+
+func nameFromAction(ch *merkletrie.Change) string {
+ name := ch.To.String()
+ if name == "" {
+ return ch.From.String()
+ }
+
+ return name
+}
+
+func (w *Worktree) diffStagingWithWorktree(reverse bool) (merkletrie.Changes, error) {
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return nil, err
+ }
+
+ from := mindex.NewRootNode(idx)
+ submodules, err := w.getSubmodulesStatus()
+ if err != nil {
+ return nil, err
+ }
+
+ to := filesystem.NewRootNode(w.Filesystem, submodules)
+
+ var c merkletrie.Changes
+ if reverse {
+ c, err = merkletrie.DiffTree(to, from, diffTreeIsEquals)
+ } else {
+ c, err = merkletrie.DiffTree(from, to, diffTreeIsEquals)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return w.excludeIgnoredChanges(c), nil
+}
+
+func (w *Worktree) excludeIgnoredChanges(changes merkletrie.Changes) merkletrie.Changes {
+ patterns, err := gitignore.ReadPatterns(w.Filesystem, nil)
+ if err != nil || len(patterns) == 0 {
+ return changes
+ }
+
+ patterns = append(patterns, w.Excludes...)
+
+ m := gitignore.NewMatcher(patterns)
+
+ var res merkletrie.Changes
+ for _, ch := range changes {
+ var path []string
+ for _, n := range ch.To {
+ path = append(path, n.Name())
+ }
+ if len(path) == 0 {
+ for _, n := range ch.From {
+ path = append(path, n.Name())
+ }
+ }
+ if len(path) != 0 {
+ isDir := (len(ch.To) > 0 && ch.To.IsDir()) || (len(ch.From) > 0 && ch.From.IsDir())
+ if m.Match(path, isDir) {
+ continue
+ }
+ }
+ res = append(res, ch)
+ }
+ return res
+}
+
+func (w *Worktree) getSubmodulesStatus() (map[string]plumbing.Hash, error) {
+ o := map[string]plumbing.Hash{}
+
+ sub, err := w.Submodules()
+ if err != nil {
+ return nil, err
+ }
+
+ status, err := sub.Status()
+ if err != nil {
+ return nil, err
+ }
+
+ for _, s := range status {
+ if s.Current.IsZero() {
+ o[s.Path] = s.Expected
+ continue
+ }
+
+ o[s.Path] = s.Current
+ }
+
+ return o, nil
+}
+
+func (w *Worktree) diffCommitWithStaging(commit plumbing.Hash, reverse bool) (merkletrie.Changes, error) {
+ var t *object.Tree
+ if !commit.IsZero() {
+ c, err := w.r.CommitObject(commit)
+ if err != nil {
+ return nil, err
+ }
+
+ t, err = c.Tree()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return w.diffTreeWithStaging(t, reverse)
+}
+
+func (w *Worktree) diffTreeWithStaging(t *object.Tree, reverse bool) (merkletrie.Changes, error) {
+ var from noder.Noder
+ if t != nil {
+ from = object.NewTreeRootNode(t)
+ }
+
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return nil, err
+ }
+
+ to := mindex.NewRootNode(idx)
+
+ if reverse {
+ return merkletrie.DiffTree(to, from, diffTreeIsEquals)
+ }
+
+ return merkletrie.DiffTree(from, to, diffTreeIsEquals)
+}
+
+var emptyNoderHash = make([]byte, 24)
+
+// diffTreeIsEquals is a implementation of noder.Equals, used to compare
+// noder.Noder, it compare the content and the length of the hashes.
+//
+// Since some of the noder.Noder implementations doesn't compute a hash for
+// some directories, if any of the hashes is a 24-byte slice of zero values
+// the comparison is not done and the hashes are take as different.
+func diffTreeIsEquals(a, b noder.Hasher) bool {
+ hashA := a.Hash()
+ hashB := b.Hash()
+
+ if bytes.Equal(hashA, emptyNoderHash) || bytes.Equal(hashB, emptyNoderHash) {
+ return false
+ }
+
+ return bytes.Equal(hashA, hashB)
+}
+
+// Add adds the file contents of a file in the worktree to the index. if the
+// file is already staged in the index no error is returned. If a file deleted
+// from the Workspace is given, the file is removed from the index. If a
+// directory given, adds the files and all his sub-directories recursively in
+// the worktree to the index. If any of the files is already staged in the index
+// no error is returned. When path is a file, the blob.Hash is returned.
+func (w *Worktree) Add(path string) (plumbing.Hash, error) {
+ // TODO(mcuadros): remove plumbing.Hash from signature at v5.
+ s, err := w.Status()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ var h plumbing.Hash
+ var added bool
+
+ fi, err := w.Filesystem.Lstat(path)
+ if err != nil || !fi.IsDir() {
+ added, h, err = w.doAddFile(idx, s, path)
+ } else {
+ added, err = w.doAddDirectory(idx, s, path)
+ }
+
+ if err != nil {
+ return h, err
+ }
+
+ if !added {
+ return h, nil
+ }
+
+ return h, w.r.Storer.SetIndex(idx)
+}
+
+func (w *Worktree) doAddDirectory(idx *index.Index, s Status, directory string) (added bool, err error) {
+ files, err := w.Filesystem.ReadDir(directory)
+ if err != nil {
+ return false, err
+ }
+
+ for _, file := range files {
+ name := path.Join(directory, file.Name())
+
+ var a bool
+ if file.IsDir() {
+ if file.Name() == GitDirName {
+ // ignore special git directory
+ continue
+ }
+ a, err = w.doAddDirectory(idx, s, name)
+ } else {
+ a, _, err = w.doAddFile(idx, s, name)
+ }
+
+ if err != nil {
+ return
+ }
+
+ if !added && a {
+ added = true
+ }
+ }
+
+ return
+}
+
+// AddGlob adds all paths, matching pattern, to the index. If pattern matches a
+// directory path, all directory contents are added to the index recursively. No
+// error is returned if all matching paths are already staged in index.
+func (w *Worktree) AddGlob(pattern string) error {
+ files, err := util.Glob(w.Filesystem, pattern)
+ if err != nil {
+ return err
+ }
+
+ if len(files) == 0 {
+ return ErrGlobNoMatches
+ }
+
+ s, err := w.Status()
+ if err != nil {
+ return err
+ }
+
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return err
+ }
+
+ var saveIndex bool
+ for _, file := range files {
+ fi, err := w.Filesystem.Lstat(file)
+ if err != nil {
+ return err
+ }
+
+ var added bool
+ if fi.IsDir() {
+ added, err = w.doAddDirectory(idx, s, file)
+ } else {
+ added, _, err = w.doAddFile(idx, s, file)
+ }
+
+ if err != nil {
+ return err
+ }
+
+ if !saveIndex && added {
+ saveIndex = true
+ }
+ }
+
+ if saveIndex {
+ return w.r.Storer.SetIndex(idx)
+ }
+
+ return nil
+}
+
+// doAddFile create a new blob from path and update the index, added is true if
+// the file added is different from the index.
+func (w *Worktree) doAddFile(idx *index.Index, s Status, path string) (added bool, h plumbing.Hash, err error) {
+ if s.File(path).Worktree == Unmodified {
+ return false, h, nil
+ }
+
+ h, err = w.copyFileToStorage(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ added = true
+ h, err = w.deleteFromIndex(idx, path)
+ }
+
+ return
+ }
+
+ if err := w.addOrUpdateFileToIndex(idx, path, h); err != nil {
+ return false, h, err
+ }
+
+ return true, h, err
+}
+
+func (w *Worktree) copyFileToStorage(path string) (hash plumbing.Hash, err error) {
+ fi, err := w.Filesystem.Lstat(path)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ obj := w.r.Storer.NewEncodedObject()
+ obj.SetType(plumbing.BlobObject)
+ obj.SetSize(fi.Size())
+
+ writer, err := obj.Writer()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ defer ioutil.CheckClose(writer, &err)
+
+ if fi.Mode()&os.ModeSymlink != 0 {
+ err = w.fillEncodedObjectFromSymlink(writer, path, fi)
+ } else {
+ err = w.fillEncodedObjectFromFile(writer, path, fi)
+ }
+
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return w.r.Storer.SetEncodedObject(obj)
+}
+
+func (w *Worktree) fillEncodedObjectFromFile(dst io.Writer, path string, fi os.FileInfo) (err error) {
+ src, err := w.Filesystem.Open(path)
+ if err != nil {
+ return err
+ }
+
+ defer ioutil.CheckClose(src, &err)
+
+ if _, err := io.Copy(dst, src); err != nil {
+ return err
+ }
+
+ return err
+}
+
+func (w *Worktree) fillEncodedObjectFromSymlink(dst io.Writer, path string, fi os.FileInfo) error {
+ target, err := w.Filesystem.Readlink(path)
+ if err != nil {
+ return err
+ }
+
+ _, err = dst.Write([]byte(target))
+ return err
+}
+
+func (w *Worktree) addOrUpdateFileToIndex(idx *index.Index, filename string, h plumbing.Hash) error {
+ e, err := idx.Entry(filename)
+ if err != nil && err != index.ErrEntryNotFound {
+ return err
+ }
+
+ if err == index.ErrEntryNotFound {
+ return w.doAddFileToIndex(idx, filename, h)
+ }
+
+ return w.doUpdateFileToIndex(e, filename, h)
+}
+
+func (w *Worktree) doAddFileToIndex(idx *index.Index, filename string, h plumbing.Hash) error {
+ return w.doUpdateFileToIndex(idx.Add(filename), filename, h)
+}
+
+func (w *Worktree) doUpdateFileToIndex(e *index.Entry, filename string, h plumbing.Hash) error {
+ info, err := w.Filesystem.Lstat(filename)
+ if err != nil {
+ return err
+ }
+
+ e.Hash = h
+ e.ModifiedAt = info.ModTime()
+ e.Mode, err = filemode.NewFromOSFileMode(info.Mode())
+ if err != nil {
+ return err
+ }
+
+ if e.Mode.IsRegular() {
+ e.Size = uint32(info.Size())
+ }
+
+ fillSystemInfo(e, info.Sys())
+ return nil
+}
+
+// Remove removes files from the working tree and from the index.
+func (w *Worktree) Remove(path string) (plumbing.Hash, error) {
+ // TODO(mcuadros): remove plumbing.Hash from signature at v5.
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ var h plumbing.Hash
+
+ fi, err := w.Filesystem.Lstat(path)
+ if err != nil || !fi.IsDir() {
+ h, err = w.doRemoveFile(idx, path)
+ } else {
+ _, err = w.doRemoveDirectory(idx, path)
+ }
+ if err != nil {
+ return h, err
+ }
+
+ return h, w.r.Storer.SetIndex(idx)
+}
+
+func (w *Worktree) doRemoveDirectory(idx *index.Index, directory string) (removed bool, err error) {
+ files, err := w.Filesystem.ReadDir(directory)
+ if err != nil {
+ return false, err
+ }
+
+ for _, file := range files {
+ name := path.Join(directory, file.Name())
+
+ var r bool
+ if file.IsDir() {
+ r, err = w.doRemoveDirectory(idx, name)
+ } else {
+ _, err = w.doRemoveFile(idx, name)
+ if err == index.ErrEntryNotFound {
+ err = nil
+ }
+ }
+
+ if err != nil {
+ return
+ }
+
+ if !removed && r {
+ removed = true
+ }
+ }
+
+ err = w.removeEmptyDirectory(directory)
+ return
+}
+
+func (w *Worktree) removeEmptyDirectory(path string) error {
+ files, err := w.Filesystem.ReadDir(path)
+ if err != nil {
+ return err
+ }
+
+ if len(files) != 0 {
+ return nil
+ }
+
+ return w.Filesystem.Remove(path)
+}
+
+func (w *Worktree) doRemoveFile(idx *index.Index, path string) (plumbing.Hash, error) {
+ hash, err := w.deleteFromIndex(idx, path)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return hash, w.deleteFromFilesystem(path)
+}
+
+func (w *Worktree) deleteFromIndex(idx *index.Index, path string) (plumbing.Hash, error) {
+ e, err := idx.Remove(path)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ return e.Hash, nil
+}
+
+func (w *Worktree) deleteFromFilesystem(path string) error {
+ err := w.Filesystem.Remove(path)
+ if os.IsNotExist(err) {
+ return nil
+ }
+
+ return err
+}
+
+// RemoveGlob removes all paths, matching pattern, from the index. If pattern
+// matches a directory path, all directory contents are removed from the index
+// recursively.
+func (w *Worktree) RemoveGlob(pattern string) error {
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return err
+ }
+
+ entries, err := idx.Glob(pattern)
+ if err != nil {
+ return err
+ }
+
+ for _, e := range entries {
+ file := filepath.FromSlash(e.Name)
+ if _, err := w.Filesystem.Lstat(file); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+
+ if _, err := w.doRemoveFile(idx, file); err != nil {
+ return err
+ }
+
+ dir, _ := filepath.Split(file)
+ if err := w.removeEmptyDirectory(dir); err != nil {
+ return err
+ }
+ }
+
+ return w.r.Storer.SetIndex(idx)
+}
+
+// Move moves or rename a file in the worktree and the index, directories are
+// not supported.
+func (w *Worktree) Move(from, to string) (plumbing.Hash, error) {
+ // TODO(mcuadros): support directories and/or implement support for glob
+ if _, err := w.Filesystem.Lstat(from); err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if _, err := w.Filesystem.Lstat(to); err == nil {
+ return plumbing.ZeroHash, ErrDestinationExists
+ }
+
+ idx, err := w.r.Storer.Index()
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ hash, err := w.deleteFromIndex(idx, from)
+ if err != nil {
+ return plumbing.ZeroHash, err
+ }
+
+ if err := w.Filesystem.Rename(from, to); err != nil {
+ return hash, err
+ }
+
+ if err := w.addOrUpdateFileToIndex(idx, to, hash); err != nil {
+ return hash, err
+ }
+
+ return hash, w.r.Storer.SetIndex(idx)
+}
--- /dev/null
+// +build openbsd dragonfly solaris
+
+package git
+
+import (
+ "syscall"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+)
+
+func init() {
+ fillSystemInfo = func(e *index.Entry, sys interface{}) {
+ if os, ok := sys.(*syscall.Stat_t); ok {
+ e.CreatedAt = time.Unix(int64(os.Atim.Sec), int64(os.Atim.Nsec))
+ e.Dev = uint32(os.Dev)
+ e.Inode = uint32(os.Ino)
+ e.GID = os.Gid
+ e.UID = os.Uid
+ }
+ }
+}
+
+func isSymlinkWindowsNonAdmin(err error) bool {
+ return false
+}
--- /dev/null
+// +build windows
+
+package git
+
+import (
+ "os"
+ "syscall"
+ "time"
+
+ "gopkg.in/src-d/go-git.v4/plumbing/format/index"
+)
+
+func init() {
+ fillSystemInfo = func(e *index.Entry, sys interface{}) {
+ if os, ok := sys.(*syscall.Win32FileAttributeData); ok {
+ seconds := os.CreationTime.Nanoseconds() / 1000000000
+ nanoseconds := os.CreationTime.Nanoseconds() - seconds*1000000000
+ e.CreatedAt = time.Unix(seconds, nanoseconds)
+ }
+ }
+}
+
+func isSymlinkWindowsNonAdmin(err error) bool {
+ const ERROR_PRIVILEGE_NOT_HELD syscall.Errno = 1314
+
+ if err != nil {
+ if errLink, ok := err.(*os.LinkError); ok {
+ if errNo, ok := errLink.Err.(syscall.Errno); ok {
+ return errNo == ERROR_PRIVILEGE_NOT_HELD
+ }
+ }
+ }
+
+ return false
+}
--- /dev/null
+Copyright (c) 2016 Péter Surányi.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--- /dev/null
+// Package warnings implements error handling with non-fatal errors (warnings).
+//
+// A recurring pattern in Go programming is the following:
+//
+// func myfunc(params) error {
+// if err := doSomething(...); err != nil {
+// return err
+// }
+// if err := doSomethingElse(...); err != nil {
+// return err
+// }
+// if ok := doAnotherThing(...); !ok {
+// return errors.New("my error")
+// }
+// ...
+// return nil
+// }
+//
+// This pattern allows interrupting the flow on any received error. But what if
+// there are errors that should be noted but still not fatal, for which the flow
+// should not be interrupted? Implementing such logic at each if statement would
+// make the code complex and the flow much harder to follow.
+//
+// Package warnings provides the Collector type and a clean and simple pattern
+// for achieving such logic. The Collector takes care of deciding when to break
+// the flow and when to continue, collecting any non-fatal errors (warnings)
+// along the way. The only requirement is that fatal and non-fatal errors can be
+// distinguished programmatically; that is a function such as
+//
+// IsFatal(error) bool
+//
+// must be implemented. The following is an example of what the above snippet
+// could look like using the warnings package:
+//
+// import "gopkg.in/warnings.v0"
+//
+// func isFatal(err error) bool {
+// _, ok := err.(WarningType)
+// return !ok
+// }
+//
+// func myfunc(params) error {
+// c := warnings.NewCollector(isFatal)
+// c.FatalWithWarnings = true
+// if err := c.Collect(doSomething()); err != nil {
+// return err
+// }
+// if err := c.Collect(doSomethingElse(...)); err != nil {
+// return err
+// }
+// if ok := doAnotherThing(...); !ok {
+// if err := c.Collect(errors.New("my error")); err != nil {
+// return err
+// }
+// }
+// ...
+// return c.Done()
+// }
+//
+// For an example of a non-trivial code base using this library, see
+// gopkg.in/gcfg.v1
+//
+// Rules for using warnings
+//
+// - ensure that warnings are programmatically distinguishable from fatal
+// errors (i.e. implement an isFatal function and any necessary error types)
+// - ensure that there is a single Collector instance for a call of each
+// exported function
+// - ensure that all errors (fatal or warning) are fed through Collect
+// - ensure that every time an error is returned, it is one returned by a
+// Collector (from Collect or Done)
+// - ensure that Collect is never called after Done
+//
+// TODO
+//
+// - optionally limit the number of warnings (e.g. stop after 20 warnings) (?)
+// - consider interaction with contexts
+// - go vet-style invocations verifier
+// - semi-automatic code converter
+//
+package warnings // import "gopkg.in/warnings.v0"
+
+import (
+ "bytes"
+ "fmt"
+)
+
+// List holds a collection of warnings and optionally one fatal error.
+type List struct {
+ Warnings []error
+ Fatal error
+}
+
+// Error implements the error interface.
+func (l List) Error() string {
+ b := bytes.NewBuffer(nil)
+ if l.Fatal != nil {
+ fmt.Fprintln(b, "fatal:")
+ fmt.Fprintln(b, l.Fatal)
+ }
+ switch len(l.Warnings) {
+ case 0:
+ // nop
+ case 1:
+ fmt.Fprintln(b, "warning:")
+ default:
+ fmt.Fprintln(b, "warnings:")
+ }
+ for _, err := range l.Warnings {
+ fmt.Fprintln(b, err)
+ }
+ return b.String()
+}
+
+// A Collector collects errors up to the first fatal error.
+type Collector struct {
+ // IsFatal distinguishes between warnings and fatal errors.
+ IsFatal func(error) bool
+ // FatalWithWarnings set to true means that a fatal error is returned as
+ // a List together with all warnings so far. The default behavior is to
+ // only return the fatal error and discard any warnings that have been
+ // collected.
+ FatalWithWarnings bool
+
+ l List
+ done bool
+}
+
+// NewCollector returns a new Collector; it uses isFatal to distinguish between
+// warnings and fatal errors.
+func NewCollector(isFatal func(error) bool) *Collector {
+ return &Collector{IsFatal: isFatal}
+}
+
+// Collect collects a single error (warning or fatal). It returns nil if
+// collection can continue (only warnings so far), or otherwise the errors
+// collected. Collect mustn't be called after the first fatal error or after
+// Done has been called.
+func (c *Collector) Collect(err error) error {
+ if c.done {
+ panic("warnings.Collector already done")
+ }
+ if err == nil {
+ return nil
+ }
+ if c.IsFatal(err) {
+ c.done = true
+ c.l.Fatal = err
+ } else {
+ c.l.Warnings = append(c.l.Warnings, err)
+ }
+ if c.l.Fatal != nil {
+ return c.erorr()
+ }
+ return nil
+}
+
+// Done ends collection and returns the collected error(s).
+func (c *Collector) Done() error {
+ c.done = true
+ return c.erorr()
+}
+
+func (c *Collector) erorr() error {
+ if !c.FatalWithWarnings && c.l.Fatal != nil {
+ return c.l.Fatal
+ }
+ if c.l.Fatal == nil && len(c.l.Warnings) == 0 {
+ return nil
+ }
+ // Note that a single warning is also returned as a List. This is to make it
+ // easier to determine fatal-ness of the returned error.
+ return c.l
+}
+
+// FatalOnly returns the fatal error, if any, **in an error returned by a
+// Collector**. It returns nil if and only if err is nil or err is a List
+// with err.Fatal == nil.
+func FatalOnly(err error) error {
+ l, ok := err.(List)
+ if !ok {
+ return err
+ }
+ return l.Fatal
+}
+
+// WarningsOnly returns the warnings **in an error returned by a Collector**.
+func WarningsOnly(err error) []error {
+ l, ok := err.(List)
+ if !ok {
+ return nil
+ }
+ return l.Warnings
+}