source.dussan.org Git - gitea.git/commitdiff
Remove x/net/context vendor by using std package (#5202)
author    Antoine GIRARD <sapk@users.noreply.github.com>
          Sat, 10 Nov 2018 23:55:36 +0000 (00:55 +0100)
committer techknowlogick <hello@techknowlogick.com>
          Sat, 10 Nov 2018 23:55:36 +0000 (18:55 -0500)
* Update dep github.com/markbates/goth

* Update dep github.com/blevesearch/bleve

* Update dep golang.org/x/oauth2

* Fix github.com/blevesearch/bleve to c74e08f039e56cef576e4336382b2a2d12d9e026

* Update dep golang.org/x/oauth2
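
Background for this commit: Go 1.7 promoted the context package into the standard library, and golang.org/x/net/context has since become a thin forwarding layer over it, so the vendored copy can be dropped by swapping the import. A minimal sketch of the migration applied across the vendored code below (illustrative names only; call sites are otherwise unchanged):

    package example

    import (
        "context" // previously: "golang.org/x/net/context"
        "time"
    )

    // fetch is unchanged apart from the import path: Context,
    // WithTimeout, Done and Err all behave identically.
    func fetch(parent context.Context) error {
        ctx, cancel := context.WithTimeout(parent, 5*time.Second)
        defer cancel() // release the timer promptly
        <-ctx.Done()   // wait for cancellation or deadline
        return ctx.Err()
    }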

40 files changed:
Gopkg.lock
Gopkg.toml
vendor/github.com/blevesearch/bleve/index.go
vendor/github.com/blevesearch/bleve/index/scorch/introducer.go
vendor/github.com/blevesearch/bleve/index/scorch/merge.go
vendor/github.com/blevesearch/bleve/index/scorch/mergeplan/merge_plan.go
vendor/github.com/blevesearch/bleve/index/scorch/persister.go
vendor/github.com/blevesearch/bleve/index/scorch/scorch.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/mem/build.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/mem/dict.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/build.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/contentcoder.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/dict.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/docvalues.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/enumerator.go [new file with mode: 0644]
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/intcoder.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/merge.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/posting.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/read.go
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/segment.go
vendor/github.com/blevesearch/bleve/index/scorch/snapshot_rollback.go
vendor/github.com/blevesearch/bleve/index/upsidedown/upsidedown.go
vendor/github.com/blevesearch/bleve/index_alias_impl.go
vendor/github.com/blevesearch/bleve/index_impl.go
vendor/github.com/blevesearch/bleve/search/collector.go
vendor/github.com/blevesearch/bleve/search/collector/topn.go
vendor/github.com/markbates/goth/provider.go
vendor/github.com/markbates/goth/providers/facebook/facebook.go
vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go [new file with mode: 0644]
vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go [new file with mode: 0644]
vendor/golang.org/x/oauth2/LICENSE
vendor/golang.org/x/oauth2/client_appengine.go [deleted file]
vendor/golang.org/x/oauth2/internal/client_appengine.go [new file with mode: 0644]
vendor/golang.org/x/oauth2/internal/doc.go [new file with mode: 0644]
vendor/golang.org/x/oauth2/internal/oauth2.go
vendor/golang.org/x/oauth2/internal/token.go
vendor/golang.org/x/oauth2/internal/transport.go
vendor/golang.org/x/oauth2/oauth2.go
vendor/golang.org/x/oauth2/token.go
vendor/golang.org/x/oauth2/transport.go

index 1a2b4b5f5bc605d47999a08b4c8faddb320e58c2..cbc089fead46e9cb49bba92f33690d3139eb5db8 100644 (file)
@@ -90,7 +90,7 @@
   revision = "3a771d992973f24aa725d07868b467d1ddfceafb"
 
 [[projects]]
-  digest = "1:67351095005f164e748a5a21899d1403b03878cb2d40a7b0f742376e6eeda974"
+  digest = "1:c10f35be6200b09e26da267ca80f837315093ecaba27e7a223071380efb9dd32"
   name = "github.com/blevesearch/bleve"
   packages = [
     ".",
     "search/searcher",
   ]
   pruneopts = "NUT"
-  revision = "ff210fbc6d348ad67aa5754eaea11a463fcddafd"
+  revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
 
 [[projects]]
   branch = "master"
   revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf"
 
 [[projects]]
-  digest = "1:23f75ae90fcc38dac6fad6881006ea7d0f2c78db5f9f81f3df558dc91460e61f"
+  digest = "1:4b992ec853d0ea9bac3dcf09a64af61de1a392e6cb0eef2204c0c92f4ae6b911"
   name = "github.com/markbates/goth"
   packages = [
     ".",
     "providers/twitter",
   ]
   pruneopts = "NUT"
-  revision = "f9c6649ab984d6ea71ef1e13b7b1cdffcf4592d3"
-  version = "v1.46.1"
+  revision = "bc6d8ddf751a745f37ca5567dbbfc4157bbf5da9"
+  version = "v1.47.2"
 
 [[projects]]
   digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5"
 
 [[projects]]
   branch = "master"
-  digest = "1:6d5ed712653ea5321fe3e3475ab2188cf362a4e0d31e9fd3acbd4dfbbca0d680"
+  digest = "1:d0a0bdd2b64d981aa4e6a1ade90431d042cd7fa31b584e33d45e62cbfec43380"
   name = "golang.org/x/net"
   packages = [
     "context",
+    "context/ctxhttp",
     "html",
     "html/atom",
     "html/charset",
   revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344"
 
 [[projects]]
-  digest = "1:8159a9cda4b8810aaaeb0d60e2fa68e2fd86d8af4ec8f5059830839e3c8d93d5"
+  branch = "master"
+  digest = "1:274a6321a5a9f185eeb3fab5d7d8397e0e9f57737490d749f562c7e205ffbc2e"
   name = "golang.org/x/oauth2"
   packages = [
     ".",
     "internal",
   ]
   pruneopts = "NUT"
-  revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
+  revision = "c453e0c757598fd055e170a3a359263c91e13153"
 
 [[projects]]
   digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3"
index 6338263bcc5a935c613622295c2d451f9ffc4eb2..2633d8b1dd1c226e2e4318189d5f392b1b57f0a6 100644 (file)
@@ -14,6 +14,12 @@ ignored = ["google.golang.org/appengine*"]
   branch = "master"
   name = "code.gitea.io/sdk"
 
+[[constraint]]
+#  branch = "master"
+  revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
+  name = "github.com/blevesearch/bleve"
+# Not targeting v0.7.0, since the standard context package is used only just after this tag
+
 [[constraint]]
   revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
   name = "golang.org/x/crypto"
@@ -61,7 +67,7 @@ ignored = ["google.golang.org/appengine*"]
 
 [[constraint]]
   name = "github.com/markbates/goth"
-  version = "1.46.1"
+  version = "1.47.2"
 
 [[constraint]]
   branch = "master"
@@ -105,7 +111,7 @@ ignored = ["google.golang.org/appengine*"]
   source = "github.com/go-gitea/bolt"
 
 [[override]]
-  revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
+  branch = "master"
   name = "golang.org/x/oauth2"
 
 [[constraint]]
index e85652d967e2fad2775eddf377abb6d6b2b1219f..ea7b3832ac78dfa62b24405b837ee64447c347a2 100644 (file)
 package bleve
 
 import (
+       "context"
+
        "github.com/blevesearch/bleve/document"
        "github.com/blevesearch/bleve/index"
        "github.com/blevesearch/bleve/index/store"
        "github.com/blevesearch/bleve/mapping"
-       "golang.org/x/net/context"
 )
 
 // A Batch groups together multiple Index and Delete
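
With the vendored bleve now importing the standard context package, callers can drive searches with ordinary std-library context values. A hedged usage sketch, assuming bleve's SearchInContext method and a pre-existing index at a hypothetical path:

    package main

    import (
        "context"
        "log"
        "time"

        "github.com/blevesearch/bleve"
    )

    func main() {
        index, err := bleve.Open("example.bleve") // hypothetical existing index
        if err != nil {
            log.Fatal(err)
        }
        defer index.Close()

        // a std-library context now flows straight into bleve
        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        defer cancel()

        req := bleve.NewSearchRequest(bleve.NewMatchQuery("hello"))
        res, err := index.SearchInContext(ctx, req)
        if err != nil {
            log.Fatal(err)
        }
        log.Println(res)
    }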
index 4499fa41bd42ba47b4fd935df15fe5df83da3c47..1a7d656ca7b24126340919e0dba8e3af9c57d00e 100644 (file)
@@ -100,8 +100,8 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
        // prepare new index snapshot
        newSnapshot := &IndexSnapshot{
                parent:   s,
-               segment:  make([]*SegmentSnapshot, nsegs, nsegs+1),
-               offsets:  make([]uint64, nsegs, nsegs+1),
+               segment:  make([]*SegmentSnapshot, 0, nsegs+1),
+               offsets:  make([]uint64, 0, nsegs+1),
                internal: make(map[string][]byte, len(s.root.internal)),
                epoch:    s.nextSnapshotEpoch,
                refs:     1,
@@ -124,24 +124,29 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
                                return err
                        }
                }
-               newSnapshot.segment[i] = &SegmentSnapshot{
+
+               newss := &SegmentSnapshot{
                        id:         s.root.segment[i].id,
                        segment:    s.root.segment[i].segment,
                        cachedDocs: s.root.segment[i].cachedDocs,
                }
-               s.root.segment[i].segment.AddRef()
 
                // apply new obsoletions
                if s.root.segment[i].deleted == nil {
-                       newSnapshot.segment[i].deleted = delta
+                       newss.deleted = delta
                } else {
-                       newSnapshot.segment[i].deleted = roaring.Or(s.root.segment[i].deleted, delta)
+                       newss.deleted = roaring.Or(s.root.segment[i].deleted, delta)
                }
 
-               newSnapshot.offsets[i] = running
-               running += s.root.segment[i].Count()
-
+               // check for live size before copying
+               if newss.LiveSize() > 0 {
+                       newSnapshot.segment = append(newSnapshot.segment, newss)
+                       s.root.segment[i].segment.AddRef()
+                       newSnapshot.offsets = append(newSnapshot.offsets, running)
+                       running += s.root.segment[i].Count()
+               }
        }
+
        // append new segment, if any, to end of the new index snapshot
        if next.data != nil {
                newSegmentSnapshot := &SegmentSnapshot{
@@ -193,6 +198,12 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
        // prepare new index snapshot
        currSize := len(s.root.segment)
        newSize := currSize + 1 - len(nextMerge.old)
+
+       // empty segments deletion
+       if nextMerge.new == nil {
+               newSize--
+       }
+
        newSnapshot := &IndexSnapshot{
                parent:   s,
                segment:  make([]*SegmentSnapshot, 0, newSize),
@@ -210,7 +221,7 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
                segmentID := s.root.segment[i].id
                if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
                        // this segment is going away, see if anything else was deleted since we started the merge
-                       if s.root.segment[i].deleted != nil {
+                       if segSnapAtMerge != nil && s.root.segment[i].deleted != nil {
                                // assume all these deletes are new
                                deletedSince := s.root.segment[i].deleted
                                // if we already knew about some of them, remove
@@ -224,7 +235,13 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
                                        newSegmentDeleted.Add(uint32(newDocNum))
                                }
                        }
-               } else {
+                       // remove this segment from the old map; whatever
+                       // segments are left behind in the old map after
+                       // processing the root segments form the obsolete
+                       // segment set
+                       delete(nextMerge.old, segmentID)
+
+               } else if s.root.segment[i].LiveSize() > 0 {
                        // this segment is staying
                        newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
                                id:         s.root.segment[i].id,
@@ -238,14 +255,35 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
                }
        }
 
-       // put new segment at end
-       newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
-               id:         nextMerge.id,
-               segment:    nextMerge.new, // take ownership for nextMerge.new's ref-count
-               deleted:    newSegmentDeleted,
-               cachedDocs: &cachedDocs{cache: nil},
-       })
-       newSnapshot.offsets = append(newSnapshot.offsets, running)
+       // before introducing the newly merged segment, reconcile it with
+       // the current root segments by applying the deletions from the
+       // obsolete segments to the newly merged segment
+       for segID, ss := range nextMerge.old {
+               obsoleted := ss.DocNumbersLive()
+               if obsoleted != nil {
+                       obsoletedIter := obsoleted.Iterator()
+                       for obsoletedIter.HasNext() {
+                               oldDocNum := obsoletedIter.Next()
+                               newDocNum := nextMerge.oldNewDocNums[segID][oldDocNum]
+                               newSegmentDeleted.Add(uint32(newDocNum))
+                       }
+               }
+       }
+       // in case all the docs in the newly merged segment were deleted
+       // by the time we reach here, we can skip the introduction
+       if nextMerge.new != nil &&
+               nextMerge.new.Count() > newSegmentDeleted.GetCardinality() {
+               // put new segment at end
+               newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
+                       id:         nextMerge.id,
+                       segment:    nextMerge.new, // take ownership for nextMerge.new's ref-count
+                       deleted:    newSegmentDeleted,
+                       cachedDocs: &cachedDocs{cache: nil},
+               })
+               newSnapshot.offsets = append(newSnapshot.offsets, running)
+       }
+
+       newSnapshot.AddRef() // 1 ref for the nextMerge.notify response
 
        // swap in new segment
        rootPrev := s.root
@@ -257,7 +295,8 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
                _ = rootPrev.DecRef()
        }
 
-       // notify merger we incorporated this
+       // notify requester that we incorporated this
+       nextMerge.notify <- newSnapshot
        close(nextMerge.notify)
 }
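
The notify channel now carries the resulting IndexSnapshot back to the requester before being closed, replacing the old close-only signal. A minimal sketch of this send-then-close handshake, with hypothetical simplified types standing in for IndexSnapshot and segmentMerge:

    package example

    type snapshot struct{ refs int }

    type mergeRequest struct {
        notify chan *snapshot // buffered with capacity 1
    }

    // introduce hands the result back: one ref is taken for the
    // receiver, the buffered send never blocks, and the close still
    // signals completion after the send (result assumed non-nil here).
    func introduce(req *mergeRequest, result *snapshot) {
        result.refs++        // ref owned by whoever reads notify
        req.notify <- result // never blocks: channel has capacity 1
        close(req.notify)
    }

    // awaitMerge consumes the snapshot and releases its ref when done.
    func awaitMerge(req *mergeRequest) {
        if s := <-req.notify; s != nil {
            // ... use the snapshot ...
            s.refs-- // the real code calls DecRef here
        }
    }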
 
index 5ded29b5a367fc0c355bb8c9f13aa7b94f7dad2d..ad756588a62054bf294e9b646be315a782dfee3c 100644 (file)
@@ -15,6 +15,9 @@
 package scorch
 
 import (
+       "bytes"
+       "encoding/json"
+
        "fmt"
        "os"
        "sync/atomic"
@@ -28,6 +31,13 @@ import (
 
 func (s *Scorch) mergerLoop() {
        var lastEpochMergePlanned uint64
+       mergePlannerOptions, err := s.parseMergePlannerOptions()
+       if err != nil {
+               s.fireAsyncError(fmt.Errorf("mergePlannerOption json parsing err: %v", err))
+               s.asyncTasks.Done()
+               return
+       }
+
 OUTER:
        for {
                select {
@@ -45,7 +55,7 @@ OUTER:
                                startTime := time.Now()
 
                                // lets get started
-                               err := s.planMergeAtSnapshot(ourSnapshot)
+                               err := s.planMergeAtSnapshot(ourSnapshot, mergePlannerOptions)
                                if err != nil {
                                        s.fireAsyncError(fmt.Errorf("merging err: %v", err))
                                        _ = ourSnapshot.DecRef()
@@ -58,51 +68,49 @@ OUTER:
                        _ = ourSnapshot.DecRef()
 
                        // tell the persister we're waiting for changes
-                       // first make a notification chan
-                       notifyUs := make(notificationChan)
+                       // first make an epochWatcher chan
+                       ew := &epochWatcher{
+                               epoch:    lastEpochMergePlanned,
+                               notifyCh: make(notificationChan, 1),
+                       }
 
                        // give it to the persister
                        select {
                        case <-s.closeCh:
                                break OUTER
-                       case s.persisterNotifier <- notifyUs:
-                       }
-
-                       // check again
-                       s.rootLock.RLock()
-                       ourSnapshot = s.root
-                       ourSnapshot.AddRef()
-                       s.rootLock.RUnlock()
-
-                       if ourSnapshot.epoch != lastEpochMergePlanned {
-                               startTime := time.Now()
-
-                               // lets get started
-                               err := s.planMergeAtSnapshot(ourSnapshot)
-                               if err != nil {
-                                       s.fireAsyncError(fmt.Errorf("merging err: %v", err))
-                                       _ = ourSnapshot.DecRef()
-                                       continue OUTER
-                               }
-                               lastEpochMergePlanned = ourSnapshot.epoch
-
-                               s.fireEvent(EventKindMergerProgress, time.Since(startTime))
+                       case s.persisterNotifier <- ew:
                        }
-                       _ = ourSnapshot.DecRef()
 
-                       // now wait for it (but also detect close)
+                       // now wait for persister (but also detect close)
                        select {
                        case <-s.closeCh:
                                break OUTER
-                       case <-notifyUs:
-                               // woken up, next loop should pick up work
+                       case <-ew.notifyCh:
                        }
                }
        }
        s.asyncTasks.Done()
 }
 
-func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
+func (s *Scorch) parseMergePlannerOptions() (*mergeplan.MergePlanOptions,
+       error) {
+       mergePlannerOptions := mergeplan.DefaultMergePlanOptions
+       if v, ok := s.config["scorchMergePlanOptions"]; ok {
+               b, err := json.Marshal(v)
+               if err != nil {
+                       return &mergePlannerOptions, err
+               }
+
+               err = json.Unmarshal(b, &mergePlannerOptions)
+               if err != nil {
+                       return &mergePlannerOptions, err
+               }
+       }
+       return &mergePlannerOptions, nil
+}
+
+func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot,
+       options *mergeplan.MergePlanOptions) error {
        // build list of zap segments in this snapshot
        var onlyZapSnapshots []mergeplan.Segment
        for _, segmentSnapshot := range ourSnapshot.segment {
@@ -112,7 +120,7 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
        }
 
        // give this list to the planner
-       resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, nil)
+       resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, options)
        if err != nil {
                return fmt.Errorf("merge planning err: %v", err)
        }
@@ -122,8 +130,12 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
        }
 
        // process tasks in serial for now
-       var notifications []notificationChan
+       var notifications []chan *IndexSnapshot
        for _, task := range resultMergePlan.Tasks {
+               if len(task.Segments) == 0 {
+                       continue
+               }
+
                oldMap := make(map[uint64]*SegmentSnapshot)
                newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
                segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))
@@ -132,40 +144,51 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
                        if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
                                oldMap[segSnapshot.id] = segSnapshot
                                if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
-                                       segmentsToMerge = append(segmentsToMerge, zapSeg)
-                                       docsToDrop = append(docsToDrop, segSnapshot.deleted)
+                                       if segSnapshot.LiveSize() == 0 {
+                                               oldMap[segSnapshot.id] = nil
+                                       } else {
+                                               segmentsToMerge = append(segmentsToMerge, zapSeg)
+                                               docsToDrop = append(docsToDrop, segSnapshot.deleted)
+                                       }
                                }
                        }
                }
 
-               filename := zapFileName(newSegmentID)
-               s.markIneligibleForRemoval(filename)
-               path := s.path + string(os.PathSeparator) + filename
-               newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, DefaultChunkFactor)
-               if err != nil {
-                       s.unmarkIneligibleForRemoval(filename)
-                       return fmt.Errorf("merging failed: %v", err)
-               }
-               segment, err := zap.Open(path)
-               if err != nil {
-                       s.unmarkIneligibleForRemoval(filename)
-                       return err
+               var oldNewDocNums map[uint64][]uint64
+               var segment segment.Segment
+               if len(segmentsToMerge) > 0 {
+                       filename := zapFileName(newSegmentID)
+                       s.markIneligibleForRemoval(filename)
+                       path := s.path + string(os.PathSeparator) + filename
+                       newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, 1024)
+                       if err != nil {
+                               s.unmarkIneligibleForRemoval(filename)
+                               return fmt.Errorf("merging failed: %v", err)
+                       }
+                       segment, err = zap.Open(path)
+                       if err != nil {
+                               s.unmarkIneligibleForRemoval(filename)
+                               return err
+                       }
+                       oldNewDocNums = make(map[uint64][]uint64)
+                       for i, segNewDocNums := range newDocNums {
+                               oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
+                       }
                }
+
                sm := &segmentMerge{
                        id:            newSegmentID,
                        old:           oldMap,
-                       oldNewDocNums: make(map[uint64][]uint64),
+                       oldNewDocNums: oldNewDocNums,
                        new:           segment,
-                       notify:        make(notificationChan),
+                       notify:        make(chan *IndexSnapshot, 1),
                }
                notifications = append(notifications, sm.notify)
-               for i, segNewDocNums := range newDocNums {
-                       sm.oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
-               }
 
                // give it to the introducer
                select {
                case <-s.closeCh:
+                       _ = segment.Close()
                        return nil
                case s.merges <- sm:
                }
@@ -174,7 +197,10 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
                select {
                case <-s.closeCh:
                        return nil
-               case <-notification:
+               case newSnapshot := <-notification:
+                       if newSnapshot != nil {
+                               _ = newSnapshot.DecRef()
+                       }
                }
        }
        return nil
@@ -185,5 +211,72 @@ type segmentMerge struct {
        old           map[uint64]*SegmentSnapshot
        oldNewDocNums map[uint64][]uint64
        new           segment.Segment
-       notify        notificationChan
+       notify        chan *IndexSnapshot
+}
+
+// perform a merging of the given SegmentBase instances into a new,
+// persisted segment, and synchronously introduce that new segment
+// into the root
+func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot,
+       sbs []*zap.SegmentBase, sbsDrops []*roaring.Bitmap, sbsIndexes []int,
+       chunkFactor uint32) (uint64, *IndexSnapshot, uint64, error) {
+       var br bytes.Buffer
+
+       cr := zap.NewCountHashWriter(&br)
+
+       newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset,
+               docValueOffset, dictLocs, fieldsInv, fieldsMap, err :=
+               zap.MergeToWriter(sbs, sbsDrops, chunkFactor, cr)
+       if err != nil {
+               return 0, nil, 0, err
+       }
+
+       sb, err := zap.InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
+               fieldsMap, fieldsInv, numDocs, storedIndexOffset, fieldsIndexOffset,
+               docValueOffset, dictLocs)
+       if err != nil {
+               return 0, nil, 0, err
+       }
+
+       newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
+
+       filename := zapFileName(newSegmentID)
+       path := s.path + string(os.PathSeparator) + filename
+       err = zap.PersistSegmentBase(sb, path)
+       if err != nil {
+               return 0, nil, 0, err
+       }
+
+       segment, err := zap.Open(path)
+       if err != nil {
+               return 0, nil, 0, err
+       }
+
+       sm := &segmentMerge{
+               id:            newSegmentID,
+               old:           make(map[uint64]*SegmentSnapshot),
+               oldNewDocNums: make(map[uint64][]uint64),
+               new:           segment,
+               notify:        make(chan *IndexSnapshot, 1),
+       }
+
+       for i, idx := range sbsIndexes {
+               ss := snapshot.segment[idx]
+               sm.old[ss.id] = ss
+               sm.oldNewDocNums[ss.id] = newDocNums[i]
+       }
+
+       select { // send to introducer
+       case <-s.closeCh:
+               _ = segment.DecRef()
+               return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
+       case s.merges <- sm:
+       }
+
+       select { // wait for introduction to complete
+       case <-s.closeCh:
+               return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
+       case newSnapshot := <-sm.notify:
+               return numDocs, newSnapshot, newSegmentID, nil
+       }
 }
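
parseMergePlannerOptions round-trips whatever sits under the "scorchMergePlanOptions" config key through JSON into a mergeplan.MergePlanOptions. A hedged sketch of how a caller might supply it when opening a scorch-backed index; SegmentsPerMergeTask appears in this diff, while MaxSegmentsPerTier is an assumed field name:

    package main

    import (
        "log"

        "github.com/blevesearch/bleve"
        "github.com/blevesearch/bleve/index/scorch"
    )

    func main() {
        // any JSON-marshalable value works here, since
        // parseMergePlannerOptions re-marshals it before
        // unmarshaling into MergePlanOptions
        config := map[string]interface{}{
            "scorchMergePlanOptions": map[string]interface{}{
                "MaxSegmentsPerTier":   3,  // assumed field name
                "SegmentsPerMergeTask": 10, // field name seen in this diff
            },
        }
        index, err := bleve.NewUsing("example.bleve", bleve.NewIndexMapping(),
            scorch.Name, scorch.Name, config)
        if err != nil {
            log.Fatal(err)
        }
        defer index.Close()
    }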
index 0afc3ce5c673a41645b4903b0413c567c46c0d98..62f643f431f8b706a15e58dd0e8030fbcd7eccbe 100644 (file)
@@ -186,13 +186,13 @@ func plan(segmentsIn []Segment, o *MergePlanOptions) (*MergePlan, error) {
 
        // While we’re over budget, keep looping, which might produce
        // another MergeTask.
-       for len(eligibles) > budgetNumSegments {
+       for len(eligibles) > 0 && (len(eligibles)+len(rv.Tasks)) > budgetNumSegments {
                // Track a current best roster as we examine and score
                // potential rosters of merges.
                var bestRoster []Segment
                var bestRosterScore float64 // Lower score is better.
 
-               for startIdx := 0; startIdx < len(eligibles)-o.SegmentsPerMergeTask; startIdx++ {
+               for startIdx := 0; startIdx < len(eligibles); startIdx++ {
                        var roster []Segment
                        var rosterLiveSize int64
 
index cdcee37c2ef630ce4b130de4660373f1a7092893..c21bb1439450f1c51fd351518ea1e90627dbefc0 100644 (file)
@@ -34,22 +34,39 @@ import (
 
 var DefaultChunkFactor uint32 = 1024
 
+// Arbitrary number; it needs to be made configurable.
+// Lower values (like 10), which make the persister really slow,
+// don't work well, as they create more files to persist in the
+// next persist iteration and spike the number of FDs.
+// The ideal value should let the persister proceed at an
+// optimum pace so that the merger can skip many
+// intermediate snapshots.
+// This needs to be based on empirical data.
+// TODO - may need to revisit this approach/value.
+var epochDistance = uint64(5)
+
 type notificationChan chan struct{}
 
 func (s *Scorch) persisterLoop() {
        defer s.asyncTasks.Done()
 
-       var notifyChs []notificationChan
-       var lastPersistedEpoch uint64
+       var persistWatchers []*epochWatcher
+       var lastPersistedEpoch, lastMergedEpoch uint64
+       var ew *epochWatcher
 OUTER:
        for {
                select {
                case <-s.closeCh:
                        break OUTER
-               case notifyCh := <-s.persisterNotifier:
-                       notifyChs = append(notifyChs, notifyCh)
+               case ew = <-s.persisterNotifier:
+                       persistWatchers = append(persistWatchers, ew)
                default:
                }
+               if ew != nil && ew.epoch > lastMergedEpoch {
+                       lastMergedEpoch = ew.epoch
+               }
+               persistWatchers = s.pausePersisterForMergerCatchUp(lastPersistedEpoch,
+                       &lastMergedEpoch, persistWatchers)
 
                var ourSnapshot *IndexSnapshot
                var ourPersisted []chan error
@@ -81,10 +98,11 @@ OUTER:
                        }
 
                        lastPersistedEpoch = ourSnapshot.epoch
-                       for _, notifyCh := range notifyChs {
-                               close(notifyCh)
+                       for _, ew := range persistWatchers {
+                               close(ew.notifyCh)
                        }
-                       notifyChs = nil
+
+                       persistWatchers = nil
                        _ = ourSnapshot.DecRef()
 
                        changed := false
@@ -120,27 +138,155 @@ OUTER:
                        break OUTER
                case <-w.notifyCh:
                        // woken up, next loop should pick up work
+                       continue OUTER
+               case ew = <-s.persisterNotifier:
+                       // if the watchers are already caught up then let them wait,
+                       // else let them continue to catch up
+                       persistWatchers = append(persistWatchers, ew)
+               }
+       }
+}
+
+func notifyMergeWatchers(lastPersistedEpoch uint64,
+       persistWatchers []*epochWatcher) []*epochWatcher {
+       var watchersNext []*epochWatcher
+       for _, w := range persistWatchers {
+               if w.epoch < lastPersistedEpoch {
+                       close(w.notifyCh)
+               } else {
+                       watchersNext = append(watchersNext, w)
                }
        }
+       return watchersNext
+}
+
+func (s *Scorch) pausePersisterForMergerCatchUp(lastPersistedEpoch uint64, lastMergedEpoch *uint64,
+       persistWatchers []*epochWatcher) []*epochWatcher {
+
+       // first, let the watchers proceed if they lag behind
+       persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
+
+OUTER:
+       // check for a slow merger and wait until the merger catches up
+       for lastPersistedEpoch > *lastMergedEpoch+epochDistance {
+
+               select {
+               case <-s.closeCh:
+                       break OUTER
+               case ew := <-s.persisterNotifier:
+                       persistWatchers = append(persistWatchers, ew)
+                       *lastMergedEpoch = ew.epoch
+               }
+
+               // let the watchers proceed if they lag behind
+               persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
+       }
+
+       return persistWatchers
 }
 
 func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
-       // start a write transaction
-       tx, err := s.rootBolt.Begin(true)
+       persisted, err := s.persistSnapshotMaybeMerge(snapshot)
        if err != nil {
                return err
        }
-       // defer fsync of the rootbolt
-       defer func() {
-               if err == nil {
-                       err = s.rootBolt.Sync()
+       if persisted {
+               return nil
+       }
+
+       return s.persistSnapshotDirect(snapshot)
+}
+
+// DefaultMinSegmentsForInMemoryMerge represents the default number of
+// in-memory zap segments that persistSnapshotMaybeMerge() needs to
+// see in an IndexSnapshot before it decides to merge and persist
+// those segments
+var DefaultMinSegmentsForInMemoryMerge = 2
+
+// persistSnapshotMaybeMerge examines the snapshot and might merge and
+// persist the in-memory zap segments if there are enough of them
+func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) (
+       bool, error) {
+       // collect the in-memory zap segments (SegmentBase instances)
+       var sbs []*zap.SegmentBase
+       var sbsDrops []*roaring.Bitmap
+       var sbsIndexes []int
+
+       for i, segmentSnapshot := range snapshot.segment {
+               if sb, ok := segmentSnapshot.segment.(*zap.SegmentBase); ok {
+                       sbs = append(sbs, sb)
+                       sbsDrops = append(sbsDrops, segmentSnapshot.deleted)
+                       sbsIndexes = append(sbsIndexes, i)
                }
+       }
+
+       if len(sbs) < DefaultMinSegmentsForInMemoryMerge {
+               return false, nil
+       }
+
+       _, newSnapshot, newSegmentID, err := s.mergeSegmentBases(
+               snapshot, sbs, sbsDrops, sbsIndexes, DefaultChunkFactor)
+       if err != nil {
+               return false, err
+       }
+       if newSnapshot == nil {
+               return false, nil
+       }
+
+       defer func() {
+               _ = newSnapshot.DecRef()
        }()
-       // defer commit/rollback transaction
+
+       mergedSegmentIDs := map[uint64]struct{}{}
+       for _, idx := range sbsIndexes {
+               mergedSegmentIDs[snapshot.segment[idx].id] = struct{}{}
+       }
+
+       // construct a snapshot that's logically equivalent to the input
+       // snapshot, but with merged segments replaced by the new segment
+       equiv := &IndexSnapshot{
+               parent:   snapshot.parent,
+               segment:  make([]*SegmentSnapshot, 0, len(snapshot.segment)),
+               internal: snapshot.internal,
+               epoch:    snapshot.epoch,
+       }
+
+       // copy to the equiv the segments that weren't replaced
+       for _, segment := range snapshot.segment {
+               if _, wasMerged := mergedSegmentIDs[segment.id]; !wasMerged {
+                       equiv.segment = append(equiv.segment, segment)
+               }
+       }
+
+       // append to the equiv the new segment
+       for _, segment := range newSnapshot.segment {
+               if segment.id == newSegmentID {
+                       equiv.segment = append(equiv.segment, &SegmentSnapshot{
+                               id:      newSegmentID,
+                               segment: segment.segment,
+                               deleted: nil, // nil since merging handled deletions
+                       })
+                       break
+               }
+       }
+
+       err = s.persistSnapshotDirect(equiv)
+       if err != nil {
+               return false, err
+       }
+
+       return true, nil
+}
+
+func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) {
+       // start a write transaction
+       tx, err := s.rootBolt.Begin(true)
+       if err != nil {
+               return err
+       }
+       // defer rollback on error
        defer func() {
-               if err == nil {
-                       err = tx.Commit()
-               } else {
+               if err != nil {
                        _ = tx.Rollback()
                }
        }()
@@ -172,20 +318,20 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
        newSegmentPaths := make(map[uint64]string)
 
        // first ensure that each segment in this snapshot has been persisted
-       for i, segmentSnapshot := range snapshot.segment {
-               snapshotSegmentKey := segment.EncodeUvarintAscending(nil, uint64(i))
-               snapshotSegmentBucket, err2 := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
-               if err2 != nil {
-                       return err2
+       for _, segmentSnapshot := range snapshot.segment {
+               snapshotSegmentKey := segment.EncodeUvarintAscending(nil, segmentSnapshot.id)
+               snapshotSegmentBucket, err := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
+               if err != nil {
+                       return err
                }
                switch seg := segmentSnapshot.segment.(type) {
                case *zap.SegmentBase:
                        // need to persist this to disk
                        filename := zapFileName(segmentSnapshot.id)
                        path := s.path + string(os.PathSeparator) + filename
-                       err2 := zap.PersistSegmentBase(seg, path)
-                       if err2 != nil {
-                               return fmt.Errorf("error persisting segment: %v", err2)
+                       err = zap.PersistSegmentBase(seg, path)
+                       if err != nil {
+                               return fmt.Errorf("error persisting segment: %v", err)
                        }
                        newSegmentPaths[segmentSnapshot.id] = path
                        err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
@@ -218,19 +364,28 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
                }
        }
 
-       // only alter the root if we actually persisted a segment
-       // (sometimes its just a new snapshot, possibly with new internal values)
+       // we need to swap in a new root only when we've persisted 1 or
+       // more segments -- whereby the new root would have 1-for-1
+       // replacements of in-memory segments with file-based segments
+       //
+       // other cases like updates to internal values only, and/or when
+       // there are only deletions, are already covered and persisted by
+       // the newly populated boltdb snapshotBucket above
        if len(newSegmentPaths) > 0 {
                // now try to open all the new snapshots
                newSegments := make(map[uint64]segment.Segment)
+               defer func() {
+                       for _, s := range newSegments {
+                               if s != nil {
+                                       // cleanup segments that were opened but not
+                                       // swapped into the new root
+                                       _ = s.Close()
+                               }
+                       }
+               }()
                for segmentID, path := range newSegmentPaths {
                        newSegments[segmentID], err = zap.Open(path)
                        if err != nil {
-                               for _, s := range newSegments {
-                                       if s != nil {
-                                               _ = s.Close() // cleanup segments that were successfully opened
-                                       }
-                               }
                                return fmt.Errorf("error opening new segment at %s, %v", path, err)
                        }
                }
@@ -255,6 +410,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
                                        cachedDocs: segmentSnapshot.cachedDocs,
                                }
                                newIndexSnapshot.segment[i] = newSegmentSnapshot
+                               delete(newSegments, segmentSnapshot.id)
                // update items persisted in case of a new segment snapshot
                                atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
                        } else {
@@ -266,9 +422,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
                for k, v := range s.root.internal {
                        newIndexSnapshot.internal[k] = v
                }
-               for _, filename := range filenames {
-                       delete(s.ineligibleForRemoval, filename)
-               }
+
                rootPrev := s.root
                s.root = newIndexSnapshot
                s.rootLock.Unlock()
@@ -277,6 +431,24 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
                }
        }
 
+       err = tx.Commit()
+       if err != nil {
+               return err
+       }
+
+       err = s.rootBolt.Sync()
+       if err != nil {
+               return err
+       }
+
+       // allow files to become eligible for removal after commit, such
+       // as file segments from snapshots that came from the merger
+       s.rootLock.Lock()
+       for _, filename := range filenames {
+               delete(s.ineligibleForRemoval, filename)
+       }
+       s.rootLock.Unlock()
+
        return nil
 }
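
The persister now paces itself against the merger: when lastPersistedEpoch runs more than epochDistance ahead of lastMergedEpoch, it blocks on persisterNotifier until the merger reports a newer epoch. A condensed sketch of that back-pressure loop, with hypothetical simplified names:

    package example

    type epochWatcher struct {
        epoch    uint64
        notifyCh chan struct{}
    }

    // pauseForMerger blocks while the persister is too far ahead of
    // the merger, draining watcher registrations to track merger
    // progress (close handling and watcher wake-ups omitted).
    func pauseForMerger(lastPersisted uint64, lastMerged *uint64,
        notifier <-chan *epochWatcher, closeCh <-chan struct{},
        epochDistance uint64) {
        for lastPersisted > *lastMerged+epochDistance {
            select {
            case <-closeCh:
                return
            case ew := <-notifier:
                *lastMerged = ew.epoch // merger has caught up this far
            }
        }
    }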
 
index 311077653aa6698b8f6026a5c1bad6381bbcda0e..f539313d1c15bfcc388fed9b11033adfd890bfe8 100644 (file)
@@ -61,7 +61,7 @@ type Scorch struct {
        merges             chan *segmentMerge
        introducerNotifier chan *epochWatcher
        revertToSnapshots  chan *snapshotReversion
-       persisterNotifier  chan notificationChan
+       persisterNotifier  chan *epochWatcher
        rootBolt           *bolt.DB
        asyncTasks         sync.WaitGroup
 
@@ -114,6 +114,25 @@ func (s *Scorch) fireAsyncError(err error) {
 }
 
 func (s *Scorch) Open() error {
+       err := s.openBolt()
+       if err != nil {
+               return err
+       }
+
+       s.asyncTasks.Add(1)
+       go s.mainLoop()
+
+       if !s.readOnly && s.path != "" {
+               s.asyncTasks.Add(1)
+               go s.persisterLoop()
+               s.asyncTasks.Add(1)
+               go s.mergerLoop()
+       }
+
+       return nil
+}
+
+func (s *Scorch) openBolt() error {
        var ok bool
        s.path, ok = s.config["path"].(string)
        if !ok {
@@ -136,6 +155,7 @@ func (s *Scorch) Open() error {
                        }
                }
        }
+
        rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
        var err error
        if s.path != "" {
@@ -156,7 +176,7 @@ func (s *Scorch) Open() error {
        s.merges = make(chan *segmentMerge)
        s.introducerNotifier = make(chan *epochWatcher, 1)
        s.revertToSnapshots = make(chan *snapshotReversion)
-       s.persisterNotifier = make(chan notificationChan)
+       s.persisterNotifier = make(chan *epochWatcher, 1)
 
        if !s.readOnly && s.path != "" {
                err := s.removeOldZapFiles() // Before persister or merger create any new files.
@@ -166,16 +186,6 @@ func (s *Scorch) Open() error {
                }
        }
 
-       s.asyncTasks.Add(1)
-       go s.mainLoop()
-
-       if !s.readOnly && s.path != "" {
-               s.asyncTasks.Add(1)
-               go s.persisterLoop()
-               s.asyncTasks.Add(1)
-               go s.mergerLoop()
-       }
-
        return nil
 }
 
@@ -310,17 +320,21 @@ func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
                introduction.persisted = make(chan error, 1)
        }
 
-       // get read lock, to optimistically prepare obsoleted info
+       // optimistically prepare obsoletes outside of rootLock
        s.rootLock.RLock()
-       for _, seg := range s.root.segment {
+       root := s.root
+       root.AddRef()
+       s.rootLock.RUnlock()
+
+       for _, seg := range root.segment {
                delta, err := seg.segment.DocNumbers(ids)
                if err != nil {
-                       s.rootLock.RUnlock()
                        return err
                }
                introduction.obsoletes[seg.id] = delta
        }
-       s.rootLock.RUnlock()
+
+       _ = root.DecRef()
 
        s.introductions <- introduction
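
The reworked prepareSegment illustrates a common refcount pattern: hold the lock only long enough to pin a reference to the current root, then do the slow DocNumbers work without the lock. A minimal generic sketch of that pattern, under hypothetical names:

    package example

    import "sync"

    type refCounted interface {
        AddRef()
        DecRef() error
    }

    type holder struct {
        mu   sync.RWMutex
        root refCounted
    }

    // withRoot pins the current root under a brief read lock, then
    // runs the (possibly slow) work lock-free, releasing the
    // reference on every exit path.
    func (h *holder) withRoot(work func(refCounted) error) error {
        h.mu.RLock()
        root := h.root
        root.AddRef()
        h.mu.RUnlock()

        defer func() { _ = root.DecRef() }()
        return work(root)
    }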
 
index d3344ce301f41178785a31c2f4876d03357b663d..57d60dc8908fe70c8bea76ab6549fdf11906a1c4 100644 (file)
@@ -95,6 +95,21 @@ func (s *Segment) initializeDict(results []*index.AnalysisResult) {
        var numTokenFrequencies int
        var totLocs int
 
+       // initial scan over all fieldIDs to sort them
+       for _, result := range results {
+               for _, field := range result.Document.CompositeFields {
+                       s.getOrDefineField(field.Name())
+               }
+               for _, field := range result.Document.Fields {
+                       s.getOrDefineField(field.Name())
+               }
+       }
+       sort.Strings(s.FieldsInv[1:]) // keep _id as first field
+       s.FieldsMap = make(map[string]uint16, len(s.FieldsInv))
+       for fieldID, fieldName := range s.FieldsInv {
+               s.FieldsMap[fieldName] = uint16(fieldID + 1)
+       }
+
        processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
                for term, tf := range tfs {
                        pidPlus1, exists := s.Dicts[fieldID][term]
index 939c287e9849594665def1be9d83149700e98853..cf92ef71f6e99b1da00e4bbff8f31a78ff85894c 100644 (file)
@@ -76,6 +76,8 @@ type DictionaryIterator struct {
        prefix string
        end    string
        offset int
+
+       dictEntry index.DictEntry // reused across Next()'s
 }
 
 // Next returns the next entry in the dictionary
@@ -95,8 +97,7 @@ func (d *DictionaryIterator) Next() (*index.DictEntry, error) {
 
        d.offset++
        postingID := d.d.segment.Dicts[d.d.fieldID][next]
-       return &index.DictEntry{
-               Term:  next,
-               Count: d.d.segment.Postings[postingID-1].GetCardinality(),
-       }, nil
+       d.dictEntry.Term = next
+       d.dictEntry.Count = d.d.segment.Postings[postingID-1].GetCardinality()
+       return &d.dictEntry, nil
 }
index 58f9faeaf6b394a5bf9176f92f2c34d96ae2c3c4..72357ae7d7e9eafdf82bec6fd5fbf2ced3b54810 100644 (file)
@@ -28,7 +28,7 @@ import (
        "github.com/golang/snappy"
 )
 
-const version uint32 = 2
+const version uint32 = 3
 
 const fieldNotUninverted = math.MaxUint64
 
@@ -187,79 +187,42 @@ func persistBase(memSegment *mem.Segment, cr *CountHashWriter, chunkFactor uint3
 }
 
 func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) {
-
        var curr int
        var metaBuf bytes.Buffer
        var data, compressed []byte
 
+       metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
+
        docNumOffsets := make(map[int]uint64, len(memSegment.Stored))
 
        for docNum, storedValues := range memSegment.Stored {
                if docNum != 0 {
                        // reset buffer if necessary
+                       curr = 0
                        metaBuf.Reset()
                        data = data[:0]
                        compressed = compressed[:0]
-                       curr = 0
                }
 
-               metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
-
                st := memSegment.StoredTypes[docNum]
                sp := memSegment.StoredPos[docNum]
 
                // encode fields in order
                for fieldID := range memSegment.FieldsInv {
                        if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok {
-                               // has stored values for this field
-                               num := len(storedFieldValues)
-
                                stf := st[uint16(fieldID)]
                                spf := sp[uint16(fieldID)]
 
-                               // process each value
-                               for i := 0; i < num; i++ {
-                                       // encode field
-                                       _, err2 := metaEncoder.PutU64(uint64(fieldID))
-                                       if err2 != nil {
-                                               return 0, err2
-                                       }
-                                       // encode type
-                                       _, err2 = metaEncoder.PutU64(uint64(stf[i]))
-                                       if err2 != nil {
-                                               return 0, err2
-                                       }
-                                       // encode start offset
-                                       _, err2 = metaEncoder.PutU64(uint64(curr))
-                                       if err2 != nil {
-                                               return 0, err2
-                                       }
-                                       // end len
-                                       _, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
-                                       if err2 != nil {
-                                               return 0, err2
-                                       }
-                                       // encode number of array pos
-                                       _, err2 = metaEncoder.PutU64(uint64(len(spf[i])))
-                                       if err2 != nil {
-                                               return 0, err2
-                                       }
-                                       // encode all array positions
-                                       for _, pos := range spf[i] {
-                                               _, err2 = metaEncoder.PutU64(pos)
-                                               if err2 != nil {
-                                                       return 0, err2
-                                               }
-                                       }
-                                       // append data
-                                       data = append(data, storedFieldValues[i]...)
-                                       // update curr
-                                       curr += len(storedFieldValues[i])
+                               var err2 error
+                               curr, data, err2 = persistStoredFieldValues(fieldID,
+                                       storedFieldValues, stf, spf, curr, metaEncoder, data)
+                               if err2 != nil {
+                                       return 0, err2
                                }
                        }
                }
-               metaEncoder.Close()
 
+               metaEncoder.Close()
                metaBytes := metaBuf.Bytes()
 
                // compress the data
@@ -299,6 +262,51 @@ func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error)
        return rv, nil
 }
 
+func persistStoredFieldValues(fieldID int,
+       storedFieldValues [][]byte, stf []byte, spf [][]uint64,
+       curr int, metaEncoder *govarint.Base128Encoder, data []byte) (
+       int, []byte, error) {
+       for i := 0; i < len(storedFieldValues); i++ {
+               // encode field
+               _, err := metaEncoder.PutU64(uint64(fieldID))
+               if err != nil {
+                       return 0, nil, err
+               }
+               // encode type
+               _, err = metaEncoder.PutU64(uint64(stf[i]))
+               if err != nil {
+                       return 0, nil, err
+               }
+               // encode start offset
+               _, err = metaEncoder.PutU64(uint64(curr))
+               if err != nil {
+                       return 0, nil, err
+               }
+               // end len
+               _, err = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
+               if err != nil {
+                       return 0, nil, err
+               }
+               // encode number of array pos
+               _, err = metaEncoder.PutU64(uint64(len(spf[i])))
+               if err != nil {
+                       return 0, nil, err
+               }
+               // encode all array positions
+               for _, pos := range spf[i] {
+                       _, err = metaEncoder.PutU64(pos)
+                       if err != nil {
+                               return 0, nil, err
+                       }
+               }
+
+               data = append(data, storedFieldValues[i]...)
+               curr += len(storedFieldValues[i])
+       }
+
+       return curr, data, nil
+}
+
 func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
        var freqOffsets, locOfffsets []uint64
        tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))
@@ -580,7 +588,7 @@ func persistDocValues(memSegment *mem.Segment, w *CountHashWriter,
                if err != nil {
                        return nil, err
                }
-               // resetting encoder for the next field
+               // reseting encoder for the next field
                fdvEncoder.Reset()
        }
 
@@ -625,12 +633,21 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
                return nil, err
        }
 
+       return InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
+               memSegment.FieldsMap, memSegment.FieldsInv, numDocs,
+               storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs)
+}
+
+func InitSegmentBase(mem []byte, memCRC uint32, chunkFactor uint32,
+       fieldsMap map[string]uint16, fieldsInv []string, numDocs uint64,
+       storedIndexOffset uint64, fieldsIndexOffset uint64, docValueOffset uint64,
+       dictLocs []uint64) (*SegmentBase, error) {
        sb := &SegmentBase{
-               mem:               br.Bytes(),
-               memCRC:            cr.Sum32(),
+               mem:               mem,
+               memCRC:            memCRC,
                chunkFactor:       chunkFactor,
-               fieldsMap:         memSegment.FieldsMap,
-               fieldsInv:         memSegment.FieldsInv,
+               fieldsMap:         fieldsMap,
+               fieldsInv:         fieldsInv,
                numDocs:           numDocs,
                storedIndexOffset: storedIndexOffset,
                fieldsIndexOffset: fieldsIndexOffset,
@@ -639,7 +656,7 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
                fieldDvIterMap:    make(map[uint16]*docValueIterator),
        }
 
-       err = sb.loadDvIterators()
+       err := sb.loadDvIterators()
        if err != nil {
                return nil, err
        }
index b03940497fbf9fb753fada6656a5ccbcccbe6672..83457146ecaef765aa627dbf88823a1d875c7be0 100644 (file)
@@ -39,7 +39,7 @@ type chunkedContentCoder struct {
 // MetaData represents the data information inside a
 // chunk.
 type MetaData struct {
-       DocID    uint64 // docid of the data inside the chunk
+       DocNum   uint64 // docNum of the data inside the chunk
        DocDvLoc uint64 // starting offset for a given docid
        DocDvLen uint64 // length of data inside the chunk for the given docid
 }
@@ -52,7 +52,7 @@ func newChunkedContentCoder(chunkSize uint64,
        rv := &chunkedContentCoder{
                chunkSize: chunkSize,
                chunkLens: make([]uint64, total),
-               chunkMeta: []MetaData{},
+               chunkMeta: make([]MetaData, 0, total),
        }
 
        return rv
@@ -68,7 +68,7 @@ func (c *chunkedContentCoder) Reset() {
        for i := range c.chunkLens {
                c.chunkLens[i] = 0
        }
-       c.chunkMeta = []MetaData{}
+       c.chunkMeta = c.chunkMeta[:0]
 }
 
 // Close indicates you are done calling Add() this allows
@@ -88,7 +88,7 @@ func (c *chunkedContentCoder) flushContents() error {
 
        // write out the metaData slice
        for _, meta := range c.chunkMeta {
-               _, err := writeUvarints(&c.chunkMetaBuf, meta.DocID, meta.DocDvLoc, meta.DocDvLen)
+               _, err := writeUvarints(&c.chunkMetaBuf, meta.DocNum, meta.DocDvLoc, meta.DocDvLen)
                if err != nil {
                        return err
                }
@@ -118,7 +118,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
                // clearing the chunk specific meta for next chunk
                c.chunkBuf.Reset()
                c.chunkMetaBuf.Reset()
-               c.chunkMeta = []MetaData{}
+               c.chunkMeta = c.chunkMeta[:0]
                c.currChunk = chunk
        }
 
@@ -130,7 +130,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
        }
 
        c.chunkMeta = append(c.chunkMeta, MetaData{
-               DocID:    docNum,
+               DocNum:   docNum,
                DocDvLoc: uint64(dvOffset),
                DocDvLen: uint64(dvSize),
        })
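
The chunkMeta changes swap a fresh allocation per chunk for slice truncation, which keeps the backing array warm across chunks. A tiny sketch of the idiom, with hypothetical names:

    package example

    type metaData struct{ docNum, dvLoc, dvLen uint64 }

    // encodeChunks reuses one buffer across iterations: buf[:0]
    // resets the length but keeps the capacity, so append stops
    // allocating once the buffer has grown to the largest chunk seen.
    func encodeChunks(chunks [][]metaData) {
        var buf []metaData
        for _, chunk := range chunks {
            buf = buf[:0] // same idiom as c.chunkMeta = c.chunkMeta[:0]
            buf = append(buf, chunk...)
            _ = buf // ... flush buf to the writer here ...
        }
    }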
index 0f5145fba87c42dca3cc28e3fd41d8590ad3012d..e5d7126866db6feced1cfb8107246aa91d3a5c70 100644 (file)
@@ -34,32 +34,47 @@ type Dictionary struct {
 
 // PostingsList returns the postings list for the specified term
 func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
-       return d.postingsList([]byte(term), except)
+       return d.postingsList([]byte(term), except, nil)
 }
 
-func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap) (*PostingsList, error) {
-       rv := &PostingsList{
-               sb:     d.sb,
-               term:   term,
-               except: except,
+func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
+       if d.fst == nil {
+               return d.postingsListInit(rv, except), nil
        }
 
-       if d.fst != nil {
-               postingsOffset, exists, err := d.fst.Get(term)
-               if err != nil {
-                       return nil, fmt.Errorf("vellum err: %v", err)
-               }
-               if exists {
-                       err = rv.read(postingsOffset, d)
-                       if err != nil {
-                               return nil, err
-                       }
-               }
+       postingsOffset, exists, err := d.fst.Get(term)
+       if err != nil {
+               return nil, fmt.Errorf("vellum err: %v", err)
+       }
+       if !exists {
+               return d.postingsListInit(rv, except), nil
+       }
+
+       return d.postingsListFromOffset(postingsOffset, except, rv)
+}
+
+func (d *Dictionary) postingsListFromOffset(postingsOffset uint64, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
+       rv = d.postingsListInit(rv, except)
+
+       err := rv.read(postingsOffset, d)
+       if err != nil {
+               return nil, err
        }
 
        return rv, nil
 }
 
+func (d *Dictionary) postingsListInit(rv *PostingsList, except *roaring.Bitmap) *PostingsList {
+       if rv == nil {
+               rv = &PostingsList{}
+       } else {
+               *rv = PostingsList{} // clear the struct
+       }
+       rv.sb = d.sb
+       rv.except = except
+       return rv
+}
+
 // Iterator returns an iterator for this dictionary
 func (d *Dictionary) Iterator() segment.DictionaryIterator {
        rv := &DictionaryIterator{
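
The dict.go changes thread an optional *PostingsList through postingsList so a hot loop (the merger) can reuse one allocation, with postingsListInit clearing it via a single struct-literal assignment. A sketch of the same reuse idiom, using a hypothetical postings type rather than the real one:

package main

import "fmt"

type postings struct {
	offset uint64
	except []uint64
}

// initPostings mirrors the shape of postingsListInit: reuse rv when the
// caller provides one, otherwise allocate, and clear stale fields either way.
func initPostings(rv *postings, except []uint64) *postings {
	if rv == nil {
		rv = &postings{}
	} else {
		*rv = postings{} // one assignment clears every field
	}
	rv.except = except
	return rv
}

func main() {
	var rv *postings
	for i := 0; i < 3; i++ {
		rv = initPostings(rv, nil) // allocates once, then reuses
		fmt.Printf("%p\n", rv)     // same address after the first pass
	}
}
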
index fb5b348a5b6709b8a164f075fba65a313fb726fb..0514bd307c3bd91b4aa82653e2cc630d9916fe2e 100644 (file)
@@ -99,7 +99,7 @@ func (s *SegmentBase) loadFieldDocValueIterator(field string,
 func (di *docValueIterator) loadDvChunk(chunkNumber,
        localDocNum uint64, s *SegmentBase) error {
        // advance to the chunk where the docValues
-       // reside for the given docID
+       // reside for the given docNum
        destChunkDataLoc := di.dvDataLoc
        for i := 0; i < int(chunkNumber); i++ {
                destChunkDataLoc += di.chunkLens[i]
@@ -116,7 +116,7 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
        offset := uint64(0)
        di.curChunkHeader = make([]MetaData, int(numDocs))
        for i := 0; i < int(numDocs); i++ {
-               di.curChunkHeader[i].DocID, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
+               di.curChunkHeader[i].DocNum, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
                offset += uint64(read)
                di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
                offset += uint64(read)
@@ -131,10 +131,10 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
        return nil
 }
 
-func (di *docValueIterator) visitDocValues(docID uint64,
+func (di *docValueIterator) visitDocValues(docNum uint64,
        visitor index.DocumentFieldTermVisitor) error {
-       // binary search the term locations for the docID
-       start, length := di.getDocValueLocs(docID)
+       // binary search the term locations for the docNum
+       start, length := di.getDocValueLocs(docNum)
        if start == math.MaxUint64 || length == math.MaxUint64 {
                return nil
        }
@@ -144,7 +144,7 @@ func (di *docValueIterator) visitDocValues(docID uint64,
                return err
        }
 
-       // pick the terms for the given docID
+       // pick the terms for the given docNum
        uncompressed = uncompressed[start : start+length]
        for {
                i := bytes.Index(uncompressed, termSeparatorSplitSlice)
@@ -159,11 +159,11 @@ func (di *docValueIterator) visitDocValues(docID uint64,
        return nil
 }
 
-func (di *docValueIterator) getDocValueLocs(docID uint64) (uint64, uint64) {
+func (di *docValueIterator) getDocValueLocs(docNum uint64) (uint64, uint64) {
        i := sort.Search(len(di.curChunkHeader), func(i int) bool {
-               return di.curChunkHeader[i].DocID >= docID
+               return di.curChunkHeader[i].DocNum >= docNum
        })
-       if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocID == docID {
+       if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocNum == docNum {
                return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
        }
        return math.MaxUint64, math.MaxUint64
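
getDocValueLocs above binary-searches the chunk header, which is sorted by DocNum, and uses math.MaxUint64 as the "not found" sentinel. A runnable sketch of that sort.Search shape with illustrative data:

package main

import (
	"fmt"
	"math"
	"sort"
)

type chunkMeta struct{ docNum, dvLoc, dvLen uint64 }

// getDocValueLocs mirrors the lookup above: binary search a header
// sorted by docNum, with math.MaxUint64 sentinels signalling a miss.
func getDocValueLocs(header []chunkMeta, docNum uint64) (uint64, uint64) {
	i := sort.Search(len(header), func(i int) bool {
		return header[i].docNum >= docNum
	})
	if i < len(header) && header[i].docNum == docNum {
		return header[i].dvLoc, header[i].dvLen
	}
	return math.MaxUint64, math.MaxUint64
}

func main() {
	header := []chunkMeta{{1, 0, 10}, {4, 10, 7}, {9, 17, 3}}
	fmt.Println(getDocValueLocs(header, 4)) // 10 7
	fmt.Println(getDocValueLocs(header, 5)) // sentinels: doc has no values here
}
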
diff --git a/vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/enumerator.go b/vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/enumerator.go
new file mode 100644 (file)
index 0000000..3c708dd
--- /dev/null
@@ -0,0 +1,124 @@
+//  Copyright (c) 2018 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//             http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package zap
+
+import (
+       "bytes"
+
+       "github.com/couchbase/vellum"
+)
+
+// enumerator provides an ordered traversal of multiple vellum
+// iterators.  Like a JOIN of iterators, the enumerator produces a
+// sequence of (key, iteratorIndex, value) tuples, sorted by key ASC,
+// then iteratorIndex ASC, where the same key may repeat across
+// multiple child iterators.
+type enumerator struct {
+       itrs   []vellum.Iterator
+       currKs [][]byte
+       currVs []uint64
+
+       lowK    []byte
+       lowIdxs []int
+       lowCurr int
+}
+
+// newEnumerator returns a new enumerator over the vellum Iterators
+func newEnumerator(itrs []vellum.Iterator) (*enumerator, error) {
+       rv := &enumerator{
+               itrs:    itrs,
+               currKs:  make([][]byte, len(itrs)),
+               currVs:  make([]uint64, len(itrs)),
+               lowIdxs: make([]int, 0, len(itrs)),
+       }
+       for i, itr := range rv.itrs {
+               rv.currKs[i], rv.currVs[i] = itr.Current()
+       }
+       rv.updateMatches()
+       if rv.lowK == nil {
+               return rv, vellum.ErrIteratorDone
+       }
+       return rv, nil
+}
+
+// updateMatches maintains the low key matches based on the currKs
+func (m *enumerator) updateMatches() {
+       m.lowK = nil
+       m.lowIdxs = m.lowIdxs[:0]
+       m.lowCurr = 0
+
+       for i, key := range m.currKs {
+               if key == nil {
+                       continue
+               }
+
+               cmp := bytes.Compare(key, m.lowK)
+               if cmp < 0 || m.lowK == nil {
+                       // reached a new low
+                       m.lowK = key
+                       m.lowIdxs = m.lowIdxs[:0]
+                       m.lowIdxs = append(m.lowIdxs, i)
+               } else if cmp == 0 {
+                       m.lowIdxs = append(m.lowIdxs, i)
+               }
+       }
+}
+
+// Current returns the enumerator's current key, iterator-index, and
+// value.  If the enumerator is not pointing at a valid value (because
+// Next returned an error previously), Current will return nil, 0, 0.
+func (m *enumerator) Current() ([]byte, int, uint64) {
+       var i int
+       var v uint64
+       if m.lowCurr < len(m.lowIdxs) {
+               i = m.lowIdxs[m.lowCurr]
+               v = m.currVs[i]
+       }
+       return m.lowK, i, v
+}
+
+// Next advances the enumerator to the next key/iterator/value result,
+// returning vellum.ErrIteratorDone once the enumeration is exhausted.
+func (m *enumerator) Next() error {
+       m.lowCurr += 1
+       if m.lowCurr >= len(m.lowIdxs) {
+               // move all the current low iterators forwards
+               for _, vi := range m.lowIdxs {
+                       err := m.itrs[vi].Next()
+                       if err != nil && err != vellum.ErrIteratorDone {
+                               return err
+                       }
+                       m.currKs[vi], m.currVs[vi] = m.itrs[vi].Current()
+               }
+               m.updateMatches()
+       }
+       if m.lowK == nil {
+               return vellum.ErrIteratorDone
+       }
+       return nil
+}
+
+// Close closes all the underlying Iterators.  The first error, if
+// any, is returned.
+func (m *enumerator) Close() error {
+       var rv error
+       for _, itr := range m.itrs {
+               err := itr.Close()
+               if rv == nil {
+                       rv = err
+               }
+       }
+       return rv
+}
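
A quick consumer's view of the new file: the merge loop below drives Current/Next until vellum.ErrIteratorDone. Since newEnumerator is unexported, this sketch pretends to live in package zap; the two tiny in-memory FSTs and their keys are illustrative only, built with the real vellum Builder/Load API:

package zap

import (
	"bytes"
	"fmt"

	"github.com/couchbase/vellum"
)

// exampleEnumerator walks two small FSTs and prints every tuple.
func exampleEnumerator() error {
	build := func(keys []string, vals []uint64) (*vellum.FST, error) {
		var buf bytes.Buffer
		b, err := vellum.New(&buf, nil)
		if err != nil {
			return nil, err
		}
		for i, k := range keys { // keys must be inserted in sorted order
			if err = b.Insert([]byte(k), vals[i]); err != nil {
				return nil, err
			}
		}
		if err = b.Close(); err != nil {
			return nil, err
		}
		return vellum.Load(buf.Bytes())
	}

	fstA, err := build([]string{"apple", "cat"}, []uint64{10, 20})
	if err != nil {
		return err
	}
	fstB, err := build([]string{"cat", "dog"}, []uint64{30, 40})
	if err != nil {
		return err
	}

	itrA, err := fstA.Iterator(nil, nil)
	if err != nil {
		return err
	}
	itrB, err := fstB.Iterator(nil, nil)
	if err != nil {
		return err
	}

	e, err := newEnumerator([]vellum.Iterator{itrA, itrB})
	for err == nil {
		k, idx, v := e.Current()
		fmt.Printf("%s itr=%d val=%d\n", k, idx, v) // "cat" is emitted once per child iterator
		err = e.Next()
	}
	if err != vellum.ErrIteratorDone {
		return err
	}
	return e.Close()
}
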
index e9f295023bc6565d006db43e411ed98f83d4efa5..b505fec94e9b5d5fae19ed05835cd27c2fe98eb7 100644 (file)
@@ -30,6 +30,8 @@ type chunkedIntCoder struct {
        encoder   *govarint.Base128Encoder
        chunkLens []uint64
        currChunk uint64
+
+       buf []byte
 }
 
 // newChunkedIntCoder returns a new chunk int coder which packs data into
@@ -67,12 +69,8 @@ func (c *chunkedIntCoder) Add(docNum uint64, vals ...uint64) error {
                // starting a new chunk
                if c.encoder != nil {
                        // close out last
-                       c.encoder.Close()
-                       encodingBytes := c.chunkBuf.Bytes()
-                       c.chunkLens[c.currChunk] = uint64(len(encodingBytes))
-                       c.final = append(c.final, encodingBytes...)
+                       c.Close()
                        c.chunkBuf.Reset()
-                       c.encoder = govarint.NewU64Base128Encoder(&c.chunkBuf)
                }
                c.currChunk = chunk
        }
@@ -98,26 +96,25 @@ func (c *chunkedIntCoder) Close() {
 
 // Write commits all the encoded chunked integers to the provided writer.
 func (c *chunkedIntCoder) Write(w io.Writer) (int, error) {
-       var tw int
-       buf := make([]byte, binary.MaxVarintLen64)
-       // write out the number of chunks
+       bufNeeded := binary.MaxVarintLen64 * (1 + len(c.chunkLens))
+       if len(c.buf) < bufNeeded {
+               c.buf = make([]byte, bufNeeded)
+       }
+       buf := c.buf
+
+       // write out the number of chunks & each chunkLen
        n := binary.PutUvarint(buf, uint64(len(c.chunkLens)))
-       nw, err := w.Write(buf[:n])
-       tw += nw
+       for _, chunkLen := range c.chunkLens {
+               n += binary.PutUvarint(buf[n:], uint64(chunkLen))
+       }
+
+       tw, err := w.Write(buf[:n])
        if err != nil {
                return tw, err
        }
-       // write out the chunk lens
-       for _, chunkLen := range c.chunkLens {
-               n := binary.PutUvarint(buf, uint64(chunkLen))
-               nw, err = w.Write(buf[:n])
-               tw += nw
-               if err != nil {
-                       return tw, err
-               }
-       }
+
        // write out the data
-       nw, err = w.Write(c.final)
+       nw, err := w.Write(c.final)
        tw += nw
        if err != nil {
                return tw, err
index cc348d72072d1d6fd1e55448e5ce866da8ddf161..ae8c5b197b0f8b180a85c0cf60724ab7cf1210ef 100644 (file)
@@ -21,6 +21,7 @@ import (
        "fmt"
        "math"
        "os"
+       "sort"
 
        "github.com/RoaringBitmap/roaring"
        "github.com/Smerity/govarint"
@@ -28,6 +29,8 @@ import (
        "github.com/golang/snappy"
 )
 
+const docDropped = math.MaxUint64 // sentinel docNum to represent a deleted doc
+
 // Merge takes a slice of zap segments and bit masks describing which
 // documents may be dropped, and creates a new segment containing the
 // remaining data.  This new segment is built at the specified path,
@@ -46,47 +49,26 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
                _ = os.Remove(path)
        }
 
+       segmentBases := make([]*SegmentBase, len(segments))
+       for segmenti, segment := range segments {
+               segmentBases[segmenti] = &segment.SegmentBase
+       }
+
        // buffer the output
        br := bufio.NewWriter(f)
 
        // wrap it for counting (tracking offsets)
        cr := NewCountHashWriter(br)
 
-       fieldsInv := mergeFields(segments)
-       fieldsMap := mapFields(fieldsInv)
-
-       var newDocNums [][]uint64
-       var storedIndexOffset uint64
-       fieldDvLocsOffset := uint64(fieldNotUninverted)
-       var dictLocs []uint64
-
-       newSegDocCount := computeNewDocCount(segments, drops)
-       if newSegDocCount > 0 {
-               storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
-                       fieldsMap, fieldsInv, newSegDocCount, cr)
-               if err != nil {
-                       cleanup()
-                       return nil, err
-               }
-
-               dictLocs, fieldDvLocsOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
-                       newDocNums, newSegDocCount, chunkFactor, cr)
-               if err != nil {
-                       cleanup()
-                       return nil, err
-               }
-       } else {
-               dictLocs = make([]uint64, len(fieldsInv))
-       }
-
-       fieldsIndexOffset, err := persistFields(fieldsInv, cr, dictLocs)
+       newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, _, _, _, err :=
+               MergeToWriter(segmentBases, drops, chunkFactor, cr)
        if err != nil {
                cleanup()
                return nil, err
        }
 
-       err = persistFooter(newSegDocCount, storedIndexOffset,
-               fieldsIndexOffset, fieldDvLocsOffset, chunkFactor, cr.Sum32(), cr)
+       err = persistFooter(numDocs, storedIndexOffset, fieldsIndexOffset,
+               docValueOffset, chunkFactor, cr.Sum32(), cr)
        if err != nil {
                cleanup()
                return nil, err
@@ -113,21 +95,59 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
        return newDocNums, nil
 }
 
-// mapFields takes the fieldsInv list and builds the map
+func MergeToWriter(segments []*SegmentBase, drops []*roaring.Bitmap,
+       chunkFactor uint32, cr *CountHashWriter) (
+       newDocNums [][]uint64,
+       numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset uint64,
+       dictLocs []uint64, fieldsInv []string, fieldsMap map[string]uint16,
+       err error) {
+       docValueOffset = uint64(fieldNotUninverted)
+
+       var fieldsSame bool
+       fieldsSame, fieldsInv = mergeFields(segments)
+       fieldsMap = mapFields(fieldsInv)
+
+       numDocs = computeNewDocCount(segments, drops)
+       if numDocs > 0 {
+               storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
+                       fieldsMap, fieldsInv, fieldsSame, numDocs, cr)
+               if err != nil {
+                       return nil, 0, 0, 0, 0, nil, nil, nil, err
+               }
+
+               dictLocs, docValueOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
+                       newDocNums, numDocs, chunkFactor, cr)
+               if err != nil {
+                       return nil, 0, 0, 0, 0, nil, nil, nil, err
+               }
+       } else {
+               dictLocs = make([]uint64, len(fieldsInv))
+       }
+
+       fieldsIndexOffset, err = persistFields(fieldsInv, cr, dictLocs)
+       if err != nil {
+               return nil, 0, 0, 0, 0, nil, nil, nil, err
+       }
+
+       return newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs, fieldsInv, fieldsMap, nil
+}
+
+// mapFields takes the fieldsInv list and returns a map of fieldName
+// to fieldID+1
 func mapFields(fields []string) map[string]uint16 {
        rv := make(map[string]uint16, len(fields))
        for i, fieldName := range fields {
-               rv[fieldName] = uint16(i)
+               rv[fieldName] = uint16(i) + 1
        }
        return rv
 }
 
 // computeNewDocCount determines how many documents will be in the newly
 // merged segment when obsoleted docs are dropped
-func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 {
+func computeNewDocCount(segments []*SegmentBase, drops []*roaring.Bitmap) uint64 {
        var newDocCount uint64
        for segI, segment := range segments {
-               newDocCount += segment.NumDocs()
+               newDocCount += segment.numDocs
                if drops[segI] != nil {
                        newDocCount -= drops[segI].GetCardinality()
                }
@@ -135,8 +155,8 @@ func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 {
        return newDocCount
 }
 
-func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
-       fieldsInv []string, fieldsMap map[string]uint16, newDocNums [][]uint64,
+func persistMergedRest(segments []*SegmentBase, dropsIn []*roaring.Bitmap,
+       fieldsInv []string, fieldsMap map[string]uint16, newDocNumsIn [][]uint64,
        newSegDocCount uint64, chunkFactor uint32,
        w *CountHashWriter) ([]uint64, uint64, error) {
 
@@ -144,9 +164,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
        var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64)
        var bufLoc []uint64
 
+       var postings *PostingsList
+       var postItr *PostingsIterator
+
        rv := make([]uint64, len(fieldsInv))
        fieldDvLocs := make([]uint64, len(fieldsInv))
-       fieldDvLocsOffset := uint64(fieldNotUninverted)
+
+       tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
+       locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
 
        // docTermMap is keyed by docNum, where the array impl provides
        // better memory usage behavior than a sparse-friendlier hashmap
@@ -166,36 +191,31 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
                        return nil, 0, err
                }
 
-               // collect FST iterators from all segments for this field
+               // collect FST iterators from all active segments for this field
+               var newDocNums [][]uint64
+               var drops []*roaring.Bitmap
                var dicts []*Dictionary
                var itrs []vellum.Iterator
-               for _, segment := range segments {
+
+               for segmentI, segment := range segments {
                        dict, err2 := segment.dictionary(fieldName)
                        if err2 != nil {
                                return nil, 0, err2
                        }
-                       dicts = append(dicts, dict)
-
                        if dict != nil && dict.fst != nil {
                                itr, err2 := dict.fst.Iterator(nil, nil)
                                if err2 != nil && err2 != vellum.ErrIteratorDone {
                                        return nil, 0, err2
                                }
                                if itr != nil {
+                                       newDocNums = append(newDocNums, newDocNumsIn[segmentI])
+                                       drops = append(drops, dropsIn[segmentI])
+                                       dicts = append(dicts, dict)
                                        itrs = append(itrs, itr)
                                }
                        }
                }
 
-               // create merging iterator
-               mergeItr, err := vellum.NewMergeIterator(itrs, func(postingOffsets []uint64) uint64 {
-                       // we don't actually use the merged value
-                       return 0
-               })
-
-               tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
-               locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
-
                if uint64(cap(docTermMap)) < newSegDocCount {
                        docTermMap = make([][]byte, newSegDocCount)
                } else {
@@ -205,70 +225,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
                        }
                }
 
-               for err == nil {
-                       term, _ := mergeItr.Current()
-
-                       newRoaring := roaring.NewBitmap()
-                       newRoaringLocs := roaring.NewBitmap()
-
-                       tfEncoder.Reset()
-                       locEncoder.Reset()
-
-                       // now go back and get posting list for this term
-                       // but pass in the deleted docs for that segment
-                       for dictI, dict := range dicts {
-                               if dict == nil {
-                                       continue
-                               }
-                               postings, err2 := dict.postingsList(term, drops[dictI])
-                               if err2 != nil {
-                                       return nil, 0, err2
-                               }
-
-                               postItr := postings.Iterator()
-                               next, err2 := postItr.Next()
-                               for next != nil && err2 == nil {
-                                       hitNewDocNum := newDocNums[dictI][next.Number()]
-                                       if hitNewDocNum == docDropped {
-                                               return nil, 0, fmt.Errorf("see hit with dropped doc num")
-                                       }
-                                       newRoaring.Add(uint32(hitNewDocNum))
-                                       // encode norm bits
-                                       norm := next.Norm()
-                                       normBits := math.Float32bits(float32(norm))
-                                       err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits))
-                                       if err != nil {
-                                               return nil, 0, err
-                                       }
-                                       locs := next.Locations()
-                                       if len(locs) > 0 {
-                                               newRoaringLocs.Add(uint32(hitNewDocNum))
-                                               for _, loc := range locs {
-                                                       if cap(bufLoc) < 5+len(loc.ArrayPositions()) {
-                                                               bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
-                                                       }
-                                                       args := bufLoc[0:5]
-                                                       args[0] = uint64(fieldsMap[loc.Field()])
-                                                       args[1] = loc.Pos()
-                                                       args[2] = loc.Start()
-                                                       args[3] = loc.End()
-                                                       args[4] = uint64(len(loc.ArrayPositions()))
-                                                       args = append(args, loc.ArrayPositions()...)
-                                                       err = locEncoder.Add(hitNewDocNum, args...)
-                                                       if err != nil {
-                                                               return nil, 0, err
-                                                       }
-                                               }
-                                       }
+               var prevTerm []byte
 
-                                       docTermMap[hitNewDocNum] =
-                                               append(append(docTermMap[hitNewDocNum], term...), termSeparator)
+               newRoaring := roaring.NewBitmap()
+               newRoaringLocs := roaring.NewBitmap()
 
-                                       next, err2 = postItr.Next()
-                               }
-                               if err2 != nil {
-                                       return nil, 0, err2
-                               }
+               finishTerm := func(term []byte) error {
+                       if term == nil {
+                               return nil
                        }
 
                        tfEncoder.Close()
@@ -277,59 +241,142 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
                        if newRoaring.GetCardinality() > 0 {
                                // this field/term actually has hits in the new segment, lets write it down
                                freqOffset := uint64(w.Count())
-                               _, err = tfEncoder.Write(w)
+                               _, err := tfEncoder.Write(w)
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
                                locOffset := uint64(w.Count())
                                _, err = locEncoder.Write(w)
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
                                postingLocOffset := uint64(w.Count())
                                _, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64)
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
                                postingOffset := uint64(w.Count())
+
                                // write out the start of the term info
-                               buf := bufMaxVarintLen64
-                               n := binary.PutUvarint(buf, freqOffset)
-                               _, err = w.Write(buf[:n])
+                               n := binary.PutUvarint(bufMaxVarintLen64, freqOffset)
+                               _, err = w.Write(bufMaxVarintLen64[:n])
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
-
                                // write out the start of the loc info
-                               n = binary.PutUvarint(buf, locOffset)
-                               _, err = w.Write(buf[:n])
+                               n = binary.PutUvarint(bufMaxVarintLen64, locOffset)
+                               _, err = w.Write(bufMaxVarintLen64[:n])
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
-
-                               // write out the start of the loc posting list
-                               n = binary.PutUvarint(buf, postingLocOffset)
-                               _, err = w.Write(buf[:n])
+                               // write out the start of the posting locs
+                               n = binary.PutUvarint(bufMaxVarintLen64, postingLocOffset)
+                               _, err = w.Write(bufMaxVarintLen64[:n])
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
                                _, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64)
                                if err != nil {
-                                       return nil, 0, err
+                                       return err
                                }
 
                                err = newVellum.Insert(term, postingOffset)
+                               if err != nil {
+                                       return err
+                               }
+                       }
+
+                       newRoaring = roaring.NewBitmap()
+                       newRoaringLocs = roaring.NewBitmap()
+
+                       tfEncoder.Reset()
+                       locEncoder.Reset()
+
+                       return nil
+               }
+
+               enumerator, err := newEnumerator(itrs)
+
+               for err == nil {
+                       term, itrI, postingsOffset := enumerator.Current()
+
+                       if !bytes.Equal(prevTerm, term) {
+                               // if the term changed, write out the info collected
+                               // for the previous term
+                               err2 := finishTerm(prevTerm)
+                               if err2 != nil {
+                                       return nil, 0, err2
+                               }
+                       }
+
+                       var err2 error
+                       postings, err2 = dicts[itrI].postingsListFromOffset(
+                               postingsOffset, drops[itrI], postings)
+                       if err2 != nil {
+                               return nil, 0, err2
+                       }
+
+                       newDocNumsI := newDocNums[itrI]
+
+                       postItr = postings.iterator(postItr)
+                       next, err2 := postItr.Next()
+                       for next != nil && err2 == nil {
+                               hitNewDocNum := newDocNumsI[next.Number()]
+                               if hitNewDocNum == docDropped {
+                                       return nil, 0, fmt.Errorf("see hit with dropped doc num")
+                               }
+                               newRoaring.Add(uint32(hitNewDocNum))
+                               // encode norm bits
+                               norm := next.Norm()
+                               normBits := math.Float32bits(float32(norm))
+                               err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits))
                                if err != nil {
                                        return nil, 0, err
                                }
+                               locs := next.Locations()
+                               if len(locs) > 0 {
+                                       newRoaringLocs.Add(uint32(hitNewDocNum))
+                                       for _, loc := range locs {
+                                               if cap(bufLoc) < 5+len(loc.ArrayPositions()) {
+                                                       bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
+                                               }
+                                               args := bufLoc[0:5]
+                                               args[0] = uint64(fieldsMap[loc.Field()] - 1)
+                                               args[1] = loc.Pos()
+                                               args[2] = loc.Start()
+                                               args[3] = loc.End()
+                                               args[4] = uint64(len(loc.ArrayPositions()))
+                                               args = append(args, loc.ArrayPositions()...)
+                                               err = locEncoder.Add(hitNewDocNum, args...)
+                                               if err != nil {
+                                                       return nil, 0, err
+                                               }
+                                       }
+                               }
+
+                               docTermMap[hitNewDocNum] =
+                                       append(append(docTermMap[hitNewDocNum], term...), termSeparator)
+
+                               next, err2 = postItr.Next()
+                       }
+                       if err2 != nil {
+                               return nil, 0, err2
                        }
 
-                       err = mergeItr.Next()
+                       prevTerm = prevTerm[:0] // copy to prevTerm in case Next() reuses term mem
+                       prevTerm = append(prevTerm, term...)
+
+                       err = enumerator.Next()
                }
                if err != nil && err != vellum.ErrIteratorDone {
                        return nil, 0, err
                }
 
+               err = finishTerm(prevTerm)
+               if err != nil {
+                       return nil, 0, err
+               }
+
                dictOffset := uint64(w.Count())
 
                err = newVellum.Close()
@@ -378,7 +425,7 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
                }
        }
 
-       fieldDvLocsOffset = uint64(w.Count())
+       fieldDvLocsOffset := uint64(w.Count())
 
        buf := bufMaxVarintLen64
        for _, offset := range fieldDvLocs {
@@ -392,10 +439,8 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
        return rv, fieldDvLocsOffset, nil
 }
 
-const docDropped = math.MaxUint64
-
-func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
-       fieldsMap map[string]uint16, fieldsInv []string, newSegDocCount uint64,
+func mergeStoredAndRemap(segments []*SegmentBase, drops []*roaring.Bitmap,
+       fieldsMap map[string]uint16, fieldsInv []string, fieldsSame bool, newSegDocCount uint64,
        w *CountHashWriter) (uint64, [][]uint64, error) {
        var rv [][]uint64 // The remapped or newDocNums for each segment.
 
@@ -417,10 +462,30 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
        for segI, segment := range segments {
                segNewDocNums := make([]uint64, segment.numDocs)
 
+               dropsI := drops[segI]
+
+               // optimize when the field mapping is the same across all
+               // segments and there are no deletions, via byte-copying
+               // of stored docs bytes directly to the writer
+               if fieldsSame && (dropsI == nil || dropsI.GetCardinality() == 0) {
+                       err := segment.copyStoredDocs(newDocNum, docNumOffsets, w)
+                       if err != nil {
+                               return 0, nil, err
+                       }
+
+                       for i := uint64(0); i < segment.numDocs; i++ {
+                               segNewDocNums[i] = newDocNum
+                               newDocNum++
+                       }
+                       rv = append(rv, segNewDocNums)
+
+                       continue
+               }
+
                // for each doc num
                for docNum := uint64(0); docNum < segment.numDocs; docNum++ {
                        // TODO: roaring's API limits docNums to 32-bits?
-                       if drops[segI] != nil && drops[segI].Contains(uint32(docNum)) {
+                       if dropsI != nil && dropsI.Contains(uint32(docNum)) {
                                segNewDocNums[docNum] = docDropped
                                continue
                        }
@@ -439,7 +504,7 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
                                poss[i] = poss[i][:0]
                        }
                        err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool {
-                               fieldID := int(fieldsMap[field])
+                               fieldID := int(fieldsMap[field]) - 1
                                vals[fieldID] = append(vals[fieldID], value)
                                typs[fieldID] = append(typs[fieldID], typ)
                                poss[fieldID] = append(poss[fieldID], pos)
@@ -453,47 +518,14 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
                        for fieldID := range fieldsInv {
                                storedFieldValues := vals[int(fieldID)]
 
-                               // has stored values for this field
-                               num := len(storedFieldValues)
+                               stf := typs[int(fieldID)]
+                               spf := poss[int(fieldID)]
 
-                               // process each value
-                               for i := 0; i < num; i++ {
-                                       // encode field
-                                       _, err2 := metaEncoder.PutU64(uint64(fieldID))
-                                       if err2 != nil {
-                                               return 0, nil, err2
-                                       }
-                                       // encode type
-                                       _, err2 = metaEncoder.PutU64(uint64(typs[int(fieldID)][i]))
-                                       if err2 != nil {
-                                               return 0, nil, err2
-                                       }
-                                       // encode start offset
-                                       _, err2 = metaEncoder.PutU64(uint64(curr))
-                                       if err2 != nil {
-                                               return 0, nil, err2
-                                       }
-                                       // end len
-                                       _, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
-                                       if err2 != nil {
-                                               return 0, nil, err2
-                                       }
-                                       // encode number of array pos
-                                       _, err2 = metaEncoder.PutU64(uint64(len(poss[int(fieldID)][i])))
-                                       if err2 != nil {
-                                               return 0, nil, err2
-                                       }
-                                       // encode all array positions
-                                       for j := 0; j < len(poss[int(fieldID)][i]); j++ {
-                                               _, err2 = metaEncoder.PutU64(poss[int(fieldID)][i][j])
-                                               if err2 != nil {
-                                                       return 0, nil, err2
-                                               }
-                                       }
-                                       // append data
-                                       data = append(data, storedFieldValues[i]...)
-                                       // update curr
-                                       curr += len(storedFieldValues[i])
+                               var err2 error
+                               curr, data, err2 = persistStoredFieldValues(fieldID,
+                                       storedFieldValues, stf, spf, curr, metaEncoder, data)
+                               if err2 != nil {
+                                       return 0, nil, err2
                                }
                        }
 
@@ -528,36 +560,87 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
        }
 
        // return value is the start of the stored index
-       offset := uint64(w.Count())
+       storedIndexOffset := uint64(w.Count())
 
        // now write out the stored doc index
-       for docNum := range docNumOffsets {
-               err := binary.Write(w, binary.BigEndian, docNumOffsets[docNum])
+       for _, docNumOffset := range docNumOffsets {
+               err := binary.Write(w, binary.BigEndian, docNumOffset)
                if err != nil {
                        return 0, nil, err
                }
        }
 
-       return offset, rv, nil
+       return storedIndexOffset, rv, nil
 }
 
-// mergeFields builds a unified list of fields used across all the input segments
-func mergeFields(segments []*Segment) []string {
-       fieldsMap := map[string]struct{}{}
+// copyStoredDocs writes out a segment's stored doc info, optimized by
+// using a single Write() call for the entire set of bytes.  The
+// newDocNumOffsets slice is filled with the new offset for each doc.
+func (s *SegmentBase) copyStoredDocs(newDocNum uint64, newDocNumOffsets []uint64,
+       w *CountHashWriter) error {
+       if s.numDocs <= 0 {
+               return nil
+       }
+
+       indexOffset0, storedOffset0, _, _, _ :=
+               s.getDocStoredOffsets(0) // the segment's first doc
+
+       indexOffsetN, storedOffsetN, readN, metaLenN, dataLenN :=
+               s.getDocStoredOffsets(s.numDocs - 1) // the segment's last doc
+
+       storedOffset0New := uint64(w.Count())
+
+       storedBytes := s.mem[storedOffset0 : storedOffsetN+readN+metaLenN+dataLenN]
+       _, err := w.Write(storedBytes)
+       if err != nil {
+               return err
+       }
+
+       // remap the storedOffsets for the docs into new offsets relative
+       // to storedOffset0New, filling the given newDocNumOffsets array
+       for indexOffset := indexOffset0; indexOffset <= indexOffsetN; indexOffset += 8 {
+               storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
+               storedOffsetNew := storedOffset - storedOffset0 + storedOffset0New
+               newDocNumOffsets[newDocNum] = storedOffsetNew
+               newDocNum += 1
+       }
+
+       return nil
+}
+
+// mergeFields builds a unified list of fields used across all the
+// input segments, and computes whether the fields are the same across
+// segments (which requires the fields to be sorted identically
+// across segments).
+func mergeFields(segments []*SegmentBase) (bool, []string) {
+       fieldsSame := true
+
+       var segment0Fields []string
+       if len(segments) > 0 {
+               segment0Fields = segments[0].Fields()
+       }
+
+       fieldsExist := map[string]struct{}{}
        for _, segment := range segments {
                fields := segment.Fields()
-               for _, field := range fields {
-                       fieldsMap[field] = struct{}{}
+               for fieldi, field := range fields {
+                       fieldsExist[field] = struct{}{}
+                       if len(segment0Fields) != len(fields) || segment0Fields[fieldi] != field {
+                               fieldsSame = false
+                       }
                }
        }
 
-       rv := make([]string, 0, len(fieldsMap))
+       rv := make([]string, 0, len(fieldsExist))
        // ensure _id stays first
        rv = append(rv, "_id")
-       for k := range fieldsMap {
+       for k := range fieldsExist {
                if k != "_id" {
                        rv = append(rv, k)
                }
        }
-       return rv
+
+       sort.Strings(rv[1:]) // leave _id as first
+
+       return fieldsSame, rv
 }
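
Note the subtle pairing in merge.go: mapFields now stores fieldID+1, so a map miss (Go's zero value, 0) can never collide with a real field, and every consumer subtracts 1, as in `fieldsMap[loc.Field()] - 1` above. A tiny sketch of the sentinel, with illustrative field names:

package main

import "fmt"

func mapFields(fields []string) map[string]uint16 {
	rv := make(map[string]uint16, len(fields))
	for i, name := range fields {
		rv[name] = uint16(i) + 1 // reserve 0 to mean "unknown field"
	}
	return rv
}

func main() {
	m := mapFields([]string{"_id", "title"})
	fmt.Println("title fieldID:", m["title"]-1) // subtract 1 to decode
	if m["nope"] == 0 {
		fmt.Println("nope is not a known field") // zero value doubles as the miss sentinel
	}
}
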
index 67e08d1ae3ba23e0eea72820b94fa00c88cd0c90..d504885d05c7e7aba9866694def2e0f286748109 100644 (file)
@@ -28,21 +28,27 @@ import (
 // PostingsList is an in-memory representation of a postings list
 type PostingsList struct {
        sb             *SegmentBase
-       term           []byte
        postingsOffset uint64
        freqOffset     uint64
        locOffset      uint64
        locBitmap      *roaring.Bitmap
        postings       *roaring.Bitmap
        except         *roaring.Bitmap
-       postingKey     []byte
 }
 
 // Iterator returns an iterator for this postings list
 func (p *PostingsList) Iterator() segment.PostingsIterator {
-       rv := &PostingsIterator{
-               postings: p,
+       return p.iterator(nil)
+}
+
+func (p *PostingsList) iterator(rv *PostingsIterator) *PostingsIterator {
+       if rv == nil {
+               rv = &PostingsIterator{}
+       } else {
+               *rv = PostingsIterator{} // clear the struct
        }
+       rv.postings = p
+
        if p.postings != nil {
                // prepare the freq chunk details
                var n uint64
index 0c5b9e17fae0d44f4ffc4cb43d4508e3d2bec90f..e47d4c6abdcd14dc8742cf177fac12f7e14dfdd5 100644 (file)
@@ -17,15 +17,27 @@ package zap
 import "encoding/binary"
 
 func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) {
-       docStoredStartAddr := s.storedIndexOffset + (8 * docNum)
-       docStoredStart := binary.BigEndian.Uint64(s.mem[docStoredStartAddr : docStoredStartAddr+8])
+       _, storedOffset, n, metaLen, dataLen := s.getDocStoredOffsets(docNum)
+
+       meta := s.mem[storedOffset+n : storedOffset+n+metaLen]
+       data := s.mem[storedOffset+n+metaLen : storedOffset+n+metaLen+dataLen]
+
+       return meta, data
+}
+
+func (s *SegmentBase) getDocStoredOffsets(docNum uint64) (
+       uint64, uint64, uint64, uint64, uint64) {
+       indexOffset := s.storedIndexOffset + (8 * docNum)
+
+       storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
+
        var n uint64
-       metaLen, read := binary.Uvarint(s.mem[docStoredStart : docStoredStart+binary.MaxVarintLen64])
+
+       metaLen, read := binary.Uvarint(s.mem[storedOffset : storedOffset+binary.MaxVarintLen64])
        n += uint64(read)
-       var dataLen uint64
-       dataLen, read = binary.Uvarint(s.mem[docStoredStart+n : docStoredStart+n+binary.MaxVarintLen64])
+
+       dataLen, read := binary.Uvarint(s.mem[storedOffset+n : storedOffset+n+binary.MaxVarintLen64])
        n += uint64(read)
-       meta := s.mem[docStoredStart+n : docStoredStart+n+metaLen]
-       data := s.mem[docStoredStart+n+metaLen : docStoredStart+n+metaLen+dataLen]
-       return meta, data
+
+       return indexOffset, storedOffset, n, metaLen, dataLen
 }
index 94268cacebaf6df13b1b9372610f48b4bd194249..40c0af2741b3dfc8d585f58087d41085ffe8f515 100644 (file)
@@ -343,8 +343,9 @@ func (s *SegmentBase) DocNumbers(ids []string) (*roaring.Bitmap, error) {
                        return nil, err
                }
 
+               var postings *PostingsList
                for _, id := range ids {
-                       postings, err := idDict.postingsList([]byte(id), nil)
+                       postings, err = idDict.postingsList([]byte(id), nil, postings)
                        if err != nil {
                                return nil, err
                        }
index 43c3ba9f1ebde83ab7892978c9a71b2f2eb5b548..247003311e750ea98a16518bfe172a5d0ae42a0f 100644 (file)
@@ -31,10 +31,9 @@ func (r *RollbackPoint) GetInternal(key []byte) []byte {
        return r.meta[string(key)]
 }
 
-// RollbackPoints returns an array of rollback points available
-// for the application to make a decision on where to rollback
-// to. A nil return value indicates that there are no available
-// rollback points.
+// RollbackPoints returns an array of rollback points available for
+// the application to roll back to, with more recent rollback points
+// (higher epochs) coming first.
 func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
        if s.rootBolt == nil {
                return nil, fmt.Errorf("RollbackPoints: root is nil")
@@ -54,7 +53,7 @@ func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
 
        snapshots := tx.Bucket(boltSnapshotsBucket)
        if snapshots == nil {
-               return nil, fmt.Errorf("RollbackPoints: no snapshots available")
+               return nil, nil
        }
 
        rollbackPoints := []*RollbackPoint{}
@@ -150,10 +149,7 @@ func (s *Scorch) Rollback(to *RollbackPoint) error {
 
                revert.snapshot = indexSnapshot
                revert.applied = make(chan error)
-
-               if !s.unsafeBatch {
-                       revert.persisted = make(chan error)
-               }
+               revert.persisted = make(chan error)
 
                return nil
        })
@@ -173,9 +169,5 @@ func (s *Scorch) Rollback(to *RollbackPoint) error {
                return fmt.Errorf("Rollback: failed with err: %v", err)
        }
 
-       if revert.persisted != nil {
-               err = <-revert.persisted
-       }
-
-       return err
+       return <-revert.persisted
 }
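
After this change, an empty snapshots bucket yields (nil, nil) rather than an error, so callers distinguish "nothing to roll back to" from a real failure. A hedged sketch of the intended call pattern, written as if inside package scorch with an already-open index:

package scorch

// rollbackToNewest sketches the API use after this change: a nil, nil
// return from RollbackPoints simply means there is nothing to roll back to.
func rollbackToNewest(s *Scorch) error {
	points, err := s.RollbackPoints()
	if err != nil {
		return err
	}
	if len(points) == 0 {
		return nil // no available rollback points is no longer an error
	}
	// points[0] is the most recent (highest epoch) rollback point
	return s.Rollback(points[0])
}
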
index 1243375b769c5f446400621e49df93d91d449b14..70e6e457f6df26a59a6c36d0b7ff3fdb329fe560 100644 (file)
@@ -837,6 +837,11 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) {
                        docBackIndexRowErr = err
                        return
                }
+               defer func() {
+                       if cerr := kvreader.Close(); err == nil && cerr != nil {
+                               docBackIndexRowErr = cerr
+                       }
+               }()
 
                for docID, doc := range batch.IndexOps {
                        backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID))
@@ -847,12 +852,6 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) {
 
                        docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow}
                }
-
-               err = kvreader.Close()
-               if err != nil {
-                       docBackIndexRowErr = err
-                       return
-               }
        }()
 
        // wait for analysis result
index 9e9a3594ff06328809566c0e918fcc26859ada9b..f678a059b7f70190567f50331c76fdbb59a73195 100644 (file)
 package bleve
 
 import (
+       "context"
        "sort"
        "sync"
        "time"
 
-       "golang.org/x/net/context"
-
        "github.com/blevesearch/bleve/document"
        "github.com/blevesearch/bleve/index"
        "github.com/blevesearch/bleve/index/store"
index 799b582a0600aeedddf199d47747e6b9777c6fb2..caea1b8e04e2e73173b0d4285612a5a9e2cf724d 100644 (file)
@@ -15,6 +15,7 @@
 package bleve
 
 import (
+       "context"
        "encoding/json"
        "fmt"
        "os"
@@ -22,8 +23,6 @@ import (
        "sync/atomic"
        "time"
 
-       "golang.org/x/net/context"
-
        "github.com/blevesearch/bleve/document"
        "github.com/blevesearch/bleve/index"
        "github.com/blevesearch/bleve/index/store"
index cba4829d4642087541b640f28b74f1ef0ffd0af8..0d163a9d9d5e4792ec26bfa6040f2ec263d789ed 100644 (file)
 package search
 
 import (
+       "context"
        "time"
 
        "github.com/blevesearch/bleve/index"
-
-       "golang.org/x/net/context"
 )
 
 type Collector interface {
index 2c7c6752df514671c4ed722aff23f60f93970846..388370e7e7041cec77ccb273e506718014e1c4e9 100644 (file)
 package collector
 
 import (
+       "context"
        "time"
 
        "github.com/blevesearch/bleve/index"
        "github.com/blevesearch/bleve/search"
-       "golang.org/x/net/context"
 )
 
 type collectorStore interface {
index 58d0d60bbf7bb06665bae399d376390ae0047b9f..294679d2aabc278c88cbfd20f9ad6e60d7edae3e 100644 (file)
@@ -1,10 +1,10 @@
 package goth
 
 import (
+       "context"
        "fmt"
        "net/http"
 
-       "golang.org/x/net/context"
        "golang.org/x/oauth2"
 )
 
index 266bbe22081650d0a063ebff75bc72c1fcc75a44..5c80ca747b57046355d4245703836e32c4cc6bce 100644 (file)
@@ -4,17 +4,18 @@ package facebook
 
 import (
        "bytes"
+       "crypto/hmac"
+       "crypto/sha256"
+       "encoding/hex"
        "encoding/json"
        "errors"
+       "fmt"
        "io"
        "io/ioutil"
        "net/http"
        "net/url"
+       "strings"
 
-       "crypto/hmac"
-       "crypto/sha256"
-       "encoding/hex"
-       "fmt"
        "github.com/markbates/goth"
        "golang.org/x/oauth2"
 )
@@ -22,7 +23,7 @@ import (
 const (
        authURL         string = "https://www.facebook.com/dialog/oauth"
        tokenURL        string = "https://graph.facebook.com/oauth/access_token"
-       endpointProfile string = "https://graph.facebook.com/me?fields=email,first_name,last_name,link,about,id,name,picture,location"
+       endpointProfile string = "https://graph.facebook.com/me?fields="
 )
 
 // New creates a new Facebook provider, and sets up important connection details.
@@ -68,9 +69,9 @@ func (p *Provider) Debug(debug bool) {}
 
 // BeginAuth asks Facebook for an authentication end-point.
 func (p *Provider) BeginAuth(state string) (goth.Session, error) {
-       url := p.config.AuthCodeURL(state)
+       authUrl := p.config.AuthCodeURL(state)
        session := &Session{
-               AuthURL: url,
+               AuthURL: authUrl,
        }
        return session, nil
 }
@@ -96,7 +97,15 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
        hash.Write([]byte(sess.AccessToken))
        appsecretProof := hex.EncodeToString(hash.Sum(nil))
 
-       response, err := p.Client().Get(endpointProfile + "&access_token=" + url.QueryEscape(sess.AccessToken) + "&appsecret_proof=" + appsecretProof)
+       reqUrl := fmt.Sprint(
+               endpointProfile,
+               strings.Join(p.config.Scopes, ","),
+               "&access_token=",
+               url.QueryEscape(sess.AccessToken),
+               "&appsecret_proof=",
+               appsecretProof,
+       )
+       response, err := p.Client().Get(reqUrl)
        if err != nil {
                return user, err
        }
@@ -168,17 +177,31 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config {
                },
                Scopes: []string{
                        "email",
+                       "first_name",
+                       "last_name",
+                       "link",
+                       "about",
+                       "id",
+                       "name",
+                       "picture",
+                       "location",
                },
        }
 
-       defaultScopes := map[string]struct{}{
-               "email": {},
-       }
-
-       for _, scope := range scopes {
-               if _, exists := defaultScopes[scope]; !exists {
-                       c.Scopes = append(c.Scopes, scope)
+               // allows a field modifier to be requested, e.g. 'picture.type(large)'
+       var found bool
+       for _, sc := range scopes {
+               sc := sc
+               for i, defScope := range c.Scopes {
+                       if defScope == strings.Split(sc, ".")[0] {
+                               c.Scopes[i] = sc
+                               found = true
+                       }
+               }
+               if !found {
+                       c.Scopes = append(c.Scopes, sc)
                }
+               found = false
        }
 
        return c
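
With endpointProfile no longer hardwiring the field list, the requested fields come from config.Scopes, and a caller-supplied scope such as picture.type(large) replaces the matching default (picture) instead of being appended. A sketch of that merge rule in isolation (function name and sample scopes are illustrative):

package main

import (
	"fmt"
	"strings"
)

func mergeScopes(defaults, requested []string) []string {
	out := append([]string(nil), defaults...)
	for _, sc := range requested {
		base := strings.Split(sc, ".")[0] // 'picture.type(large)' -> 'picture'
		replaced := false
		for i, def := range out {
			if def == base {
				out[i] = sc // swap in the modified field
				replaced = true
			}
		}
		if !replaced {
			out = append(out, sc)
		}
	}
	return out
}

func main() {
	defaults := []string{"email", "name", "picture"}
	fmt.Println(mergeScopes(defaults, []string{"picture.type(large)", "birthday"}))
	// [email name picture.type(large) birthday]
}
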
diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go
new file mode 100644 (file)
index 0000000..606cf1f
--- /dev/null
@@ -0,0 +1,74 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.7
+
+// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
+package ctxhttp // import "golang.org/x/net/context/ctxhttp"
+
+import (
+       "io"
+       "net/http"
+       "net/url"
+       "strings"
+
+       "golang.org/x/net/context"
+)
+
+// Do sends an HTTP request with the provided http.Client and returns
+// an HTTP response.
+//
+// If the client is nil, http.DefaultClient is used.
+//
+// The provided ctx must be non-nil. If it is canceled or times out,
+// ctx.Err() will be returned.
+func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
+       if client == nil {
+               client = http.DefaultClient
+       }
+       resp, err := client.Do(req.WithContext(ctx))
+       // If we got an error, and the context has been canceled,
+       // the context's error is probably more useful.
+       if err != nil {
+               select {
+               case <-ctx.Done():
+                       err = ctx.Err()
+               default:
+               }
+       }
+       return resp, err
+}
+
+// Get issues a GET request via the Do function.
+func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+       req, err := http.NewRequest("GET", url, nil)
+       if err != nil {
+               return nil, err
+       }
+       return Do(ctx, client, req)
+}
+
+// Head issues a HEAD request via the Do function.
+func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+       req, err := http.NewRequest("HEAD", url, nil)
+       if err != nil {
+               return nil, err
+       }
+       return Do(ctx, client, req)
+}
+
+// Post issues a POST request via the Do function.
+func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
+       req, err := http.NewRequest("POST", url, body)
+       if err != nil {
+               return nil, err
+       }
+       req.Header.Set("Content-Type", bodyType)
+       return Do(ctx, client, req)
+}
+
+// PostForm issues a POST request via the Do function.
+func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
+       return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
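
On go1.7 and later these helpers reduce to req.WithContext plus the ctx.Err() preference in Do; since x/net/context.Context and the standard context.Context are interchangeable interfaces there, a plain standard-library context works. A minimal usage sketch (URL hypothetical):

    package main

    import (
            "context"
            "fmt"
            "time"

            "golang.org/x/net/context/ctxhttp"
    )

    func main() {
            ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
            defer cancel()

            // A nil client falls back to http.DefaultClient inside Do.
            resp, err := ctxhttp.Get(ctx, nil, "https://example.com/")
            if err != nil {
                    fmt.Println("request failed:", err) // ctx.Err() once the deadline hits
                    return
            }
            defer resp.Body.Close()
            fmt.Println(resp.Status)
    }
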
diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go
new file mode 100644 (file)
index 0000000..926870c
--- /dev/null
@@ -0,0 +1,147 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.7
+
+package ctxhttp // import "golang.org/x/net/context/ctxhttp"
+
+import (
+       "io"
+       "net/http"
+       "net/url"
+       "strings"
+
+       "golang.org/x/net/context"
+)
+
+func nop() {}
+
+var (
+       testHookContextDoneBeforeHeaders = nop
+       testHookDoReturned               = nop
+       testHookDidBodyClose             = nop
+)
+
+// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
+// If the client is nil, http.DefaultClient is used.
+// If the context is canceled or times out, ctx.Err() will be returned.
+func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
+       if client == nil {
+               client = http.DefaultClient
+       }
+
+       // TODO(djd): Respect any existing value of req.Cancel.
+       cancel := make(chan struct{})
+       req.Cancel = cancel
+
+       type responseAndError struct {
+               resp *http.Response
+               err  error
+       }
+       result := make(chan responseAndError, 1)
+
+       // Make local copies of test hooks closed over by goroutines below.
+       // Prevents data races in tests.
+       testHookDoReturned := testHookDoReturned
+       testHookDidBodyClose := testHookDidBodyClose
+
+       go func() {
+               resp, err := client.Do(req)
+               testHookDoReturned()
+               result <- responseAndError{resp, err}
+       }()
+
+       var resp *http.Response
+
+       select {
+       case <-ctx.Done():
+               testHookContextDoneBeforeHeaders()
+               close(cancel)
+               // Clean up after the goroutine calling client.Do:
+               go func() {
+                       if r := <-result; r.resp != nil {
+                               testHookDidBodyClose()
+                               r.resp.Body.Close()
+                       }
+               }()
+               return nil, ctx.Err()
+       case r := <-result:
+               var err error
+               resp, err = r.resp, r.err
+               if err != nil {
+                       return resp, err
+               }
+       }
+
+       c := make(chan struct{})
+       go func() {
+               select {
+               case <-ctx.Done():
+                       close(cancel)
+               case <-c:
+                       // The response's Body is closed.
+               }
+       }()
+       resp.Body = &notifyingReader{resp.Body, c}
+
+       return resp, nil
+}
+
+// Get issues a GET request via the Do function.
+func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+       req, err := http.NewRequest("GET", url, nil)
+       if err != nil {
+               return nil, err
+       }
+       return Do(ctx, client, req)
+}
+
+// Head issues a HEAD request via the Do function.
+func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+       req, err := http.NewRequest("HEAD", url, nil)
+       if err != nil {
+               return nil, err
+       }
+       return Do(ctx, client, req)
+}
+
+// Post issues a POST request via the Do function.
+func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
+       req, err := http.NewRequest("POST", url, body)
+       if err != nil {
+               return nil, err
+       }
+       req.Header.Set("Content-Type", bodyType)
+       return Do(ctx, client, req)
+}
+
+// PostForm issues a POST request via the Do function.
+func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
+       return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
+
+// notifyingReader is an io.ReadCloser that closes the notify channel after
+// Close is called or a Read fails on the underlying ReadCloser.
+type notifyingReader struct {
+       io.ReadCloser
+       notify chan<- struct{}
+}
+
+func (r *notifyingReader) Read(p []byte) (int, error) {
+       n, err := r.ReadCloser.Read(p)
+       if err != nil && r.notify != nil {
+               close(r.notify)
+               r.notify = nil
+       }
+       return n, err
+}
+
+func (r *notifyingReader) Close() error {
+       err := r.ReadCloser.Close()
+       if r.notify != nil {
+               close(r.notify)
+               r.notify = nil
+       }
+       return err
+}
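
The pre-go1.7 variant above cannot rely on Request.WithContext, so Do emulates cancellation by hand: client.Do runs in a goroutine, closing the req.Cancel channel aborts the transport once the context ends, and the response body is wrapped in notifyingReader so the watcher goroutine exits as soon as the body is closed or a read fails. The matching "+build go1.7" / "+build !go1.7" tags ensure exactly one of the two implementations compiles into any given build.
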
index d02f24fd5288303238e195e2520ff64a1f3ea597..6a66aea5eafe0ca6a688840c47219556c552488e 100644 (file)
@@ -1,4 +1,4 @@
-Copyright (c) 2009 The oauth2 Authors. All rights reserved.
+Copyright (c) 2009 The Go Authors. All rights reserved.
 
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
diff --git a/vendor/golang.org/x/oauth2/client_appengine.go b/vendor/golang.org/x/oauth2/client_appengine.go
deleted file mode 100644 (file)
index 8962c49..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build appengine
-
-// App Engine hooks.
-
-package oauth2
-
-import (
-       "net/http"
-
-       "golang.org/x/net/context"
-       "golang.org/x/oauth2/internal"
-       "google.golang.org/appengine/urlfetch"
-)
-
-func init() {
-       internal.RegisterContextClientFunc(contextClientAppEngine)
-}
-
-func contextClientAppEngine(ctx context.Context) (*http.Client, error) {
-       return urlfetch.Client(ctx), nil
-}
diff --git a/vendor/golang.org/x/oauth2/internal/client_appengine.go b/vendor/golang.org/x/oauth2/internal/client_appengine.go
new file mode 100644 (file)
index 0000000..7434871
--- /dev/null
@@ -0,0 +1,13 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build appengine
+
+package internal
+
+import "google.golang.org/appengine/urlfetch"
+
+func init() {
+       appengineClientHook = urlfetch.Client
+}
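
Replacing the old ContextClientFunc registry with a single appengineClientHook assignment keeps the App Engine dependency behind the build tag: non-appengine builds never import urlfetch, and ContextClient (simplified in internal/transport.go below) only has to test one nilable package-level variable instead of iterating registered funcs.
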
diff --git a/vendor/golang.org/x/oauth2/internal/doc.go b/vendor/golang.org/x/oauth2/internal/doc.go
new file mode 100644 (file)
index 0000000..03265e8
--- /dev/null
@@ -0,0 +1,6 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package internal contains support packages for oauth2 package.
+package internal
index fbe1028d64e52ba0cdec9007944932df79a2e8d7..c0ab196cf461cf3eb159d732e29f41a880d5a125 100644 (file)
@@ -2,18 +2,14 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package internal contains support packages for oauth2 package.
 package internal
 
 import (
-       "bufio"
        "crypto/rsa"
        "crypto/x509"
        "encoding/pem"
        "errors"
        "fmt"
-       "io"
-       "strings"
 )
 
 // ParseKey converts the binary contents of a private key file
@@ -30,7 +26,7 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
        if err != nil {
                parsedKey, err = x509.ParsePKCS1PrivateKey(key)
                if err != nil {
-                       return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err)
+                       return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err)
                }
        }
        parsed, ok := parsedKey.(*rsa.PrivateKey)
@@ -39,38 +35,3 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
        }
        return parsed, nil
 }
-
-func ParseINI(ini io.Reader) (map[string]map[string]string, error) {
-       result := map[string]map[string]string{
-               "": map[string]string{}, // root section
-       }
-       scanner := bufio.NewScanner(ini)
-       currentSection := ""
-       for scanner.Scan() {
-               line := strings.TrimSpace(scanner.Text())
-               if strings.HasPrefix(line, ";") {
-                       // comment.
-                       continue
-               }
-               if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
-                       currentSection = strings.TrimSpace(line[1 : len(line)-1])
-                       result[currentSection] = map[string]string{}
-                       continue
-               }
-               parts := strings.SplitN(line, "=", 2)
-               if len(parts) == 2 && parts[0] != "" {
-                       result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
-               }
-       }
-       if err := scanner.Err(); err != nil {
-               return nil, fmt.Errorf("error scanning ini: %v", err)
-       }
-       return result, nil
-}
-
-func CondVal(v string) []string {
-       if v == "" {
-               return nil
-       }
-       return []string{v}
-}
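
With ParseINI and CondVal removed, ParseKey is the surviving helper in this file; it unwraps an optional PEM block, then tries PKCS8 and falls back to PKCS1 DER (the hunk above also fixes the "PKSC1" typo in its error message). Because the package is internal and not importable outside golang.org/x/oauth2, here is a self-contained sketch that mirrors the same decode order using only the standard library:

    package main

    import (
            "crypto/rand"
            "crypto/rsa"
            "crypto/x509"
            "encoding/pem"
            "fmt"
    )

    // parseKey mirrors internal.ParseKey: accept a PEM block or bare PKCS8/PKCS1 DER.
    func parseKey(key []byte) (*rsa.PrivateKey, error) {
            if block, _ := pem.Decode(key); block != nil {
                    key = block.Bytes
            }
            parsed, err := x509.ParsePKCS8PrivateKey(key)
            if err != nil {
                    return x509.ParsePKCS1PrivateKey(key) // plain PKCS1 fallback
            }
            rsaKey, ok := parsed.(*rsa.PrivateKey)
            if !ok {
                    return nil, fmt.Errorf("private key is not an RSA key")
            }
            return rsaKey, nil
    }

    func main() {
            // Round-trip a freshly generated key through bare PKCS1 DER.
            k, _ := rsa.GenerateKey(rand.Reader, 2048)
            got, err := parseKey(x509.MarshalPKCS1PrivateKey(k))
            if err != nil {
                    fmt.Println("parse failed:", err)
                    return
            }
            fmt.Println("parsed RSA key,", got.N.BitLen(), "bits")
    }
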
index 18328a0dcf2edeee851a8ad7a79a83408a54da97..5ab17b9a5f7417d78f9c9ee7e8c742695dbc20eb 100644 (file)
@@ -2,11 +2,12 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package internal contains support packages for oauth2 package.
 package internal
 
 import (
+       "context"
        "encoding/json"
+       "errors"
        "fmt"
        "io"
        "io/ioutil"
@@ -17,10 +18,10 @@ import (
        "strings"
        "time"
 
-       "golang.org/x/net/context"
+       "golang.org/x/net/context/ctxhttp"
 )
 
-// Token represents the crendentials used to authorize
+// Token represents the credentials used to authorize
 // the requests to access protected resources on the OAuth 2.0
 // provider's backend.
 //
@@ -91,6 +92,7 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
 
 var brokenAuthHeaderProviders = []string{
        "https://accounts.google.com/",
+       "https://api.codeswholesale.com/oauth/token",
        "https://api.dropbox.com/",
        "https://api.dropboxapi.com/",
        "https://api.instagram.com/",
@@ -99,10 +101,16 @@ var brokenAuthHeaderProviders = []string{
        "https://api.pushbullet.com/",
        "https://api.soundcloud.com/",
        "https://api.twitch.tv/",
+       "https://id.twitch.tv/",
        "https://app.box.com/",
+       "https://api.box.com/",
        "https://connect.stripe.com/",
+       "https://login.mailchimp.com/",
        "https://login.microsoftonline.com/",
        "https://login.salesforce.com/",
+       "https://login.windows.net",
+       "https://login.live.com/",
+       "https://login.live-int.com/",
        "https://oauth.sandbox.trainingpeaks.com/",
        "https://oauth.trainingpeaks.com/",
        "https://oauth.vk.com/",
@@ -117,6 +125,24 @@ var brokenAuthHeaderProviders = []string{
        "https://www.strava.com/oauth/",
        "https://www.wunderlist.com/oauth/",
        "https://api.patreon.com/",
+       "https://sandbox.codeswholesale.com/oauth/token",
+       "https://api.sipgate.com/v1/authorization/oauth",
+       "https://api.medium.com/v1/tokens",
+       "https://log.finalsurge.com/oauth/token",
+       "https://multisport.todaysplan.com.au/rest/oauth/access_token",
+       "https://whats.todaysplan.com.au/rest/oauth/access_token",
+       "https://stackoverflow.com/oauth/access_token",
+       "https://account.health.nokia.com",
+       "https://accounts.zoho.com",
+}
+
+// brokenAuthHeaderDomains lists broken providers with dynamically issued (per-tenant) endpoints, matched by host suffix.
+var brokenAuthHeaderDomains = []string{
+       ".auth0.com",
+       ".force.com",
+       ".myshopify.com",
+       ".okta.com",
+       ".oktapreview.com",
 }
 
 func RegisterBrokenAuthHeaderProvider(tokenURL string) {
@@ -139,6 +165,14 @@ func providerAuthHeaderWorks(tokenURL string) bool {
                }
        }
 
+       if u, err := url.Parse(tokenURL); err == nil {
+               for _, s := range brokenAuthHeaderDomains {
+                       if strings.HasSuffix(u.Host, s) {
+                               return false
+                       }
+               }
+       }
+
        // Assume the provider implements the spec properly
        // otherwise. We can add more exceptions as they're
        // discovered. We will _not_ be adding configurable hooks
@@ -147,14 +181,14 @@ func providerAuthHeaderWorks(tokenURL string) bool {
 }
 
 func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
-       hc, err := ContextClient(ctx)
-       if err != nil {
-               return nil, err
-       }
-       v.Set("client_id", clientID)
        bustedAuth := !providerAuthHeaderWorks(tokenURL)
-       if bustedAuth && clientSecret != "" {
-               v.Set("client_secret", clientSecret)
+       if bustedAuth {
+               if clientID != "" {
+                       v.Set("client_id", clientID)
+               }
+               if clientSecret != "" {
+                       v.Set("client_secret", clientSecret)
+               }
        }
        req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
        if err != nil {
@@ -162,9 +196,9 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
        }
        req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
        if !bustedAuth {
-               req.SetBasicAuth(clientID, clientSecret)
+               req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
        }
-       r, err := hc.Do(req)
+       r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
        if err != nil {
                return nil, err
        }
@@ -174,7 +208,10 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
                return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
        }
        if code := r.StatusCode; code < 200 || code > 299 {
-               return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
+               return nil, &RetrieveError{
+                       Response: r,
+                       Body:     body,
+               }
        }
 
        var token *Token
@@ -221,5 +258,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
        if token.RefreshToken == "" {
                token.RefreshToken = v.Get("refresh_token")
        }
+       if token.AccessToken == "" {
+               return token, errors.New("oauth2: server response missing access_token")
+       }
        return token, nil
 }
+
+type RetrieveError struct {
+       Response *http.Response
+       Body     []byte
+}
+
+func (r *RetrieveError) Error() string {
+       return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+}
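
Exact token URLs in brokenAuthHeaderProviders are matched by prefix, while the new brokenAuthHeaderDomains entries catch per-tenant hosts by suffix (any *.okta.com endpoint, for instance). Providers discovered later can still be registered at runtime through the exported wrapper in the oauth2 package; a sketch with a hypothetical endpoint:

    package main

    import "golang.org/x/oauth2"

    func main() {
            // Hypothetical endpoint that rejects HTTP Basic credentials: after this
            // call, RetrieveToken sends client_id/client_secret in the POST body instead.
            oauth2.RegisterBrokenAuthHeaderProvider("https://auth.example.com/oauth/token")
    }
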
index f1f173e345db0c2582b7491d9639aa95ceb3cb4c..572074a637dd6fbf13571900bac00289871d9dcc 100644 (file)
@@ -2,13 +2,11 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package internal contains support packages for oauth2 package.
 package internal
 
 import (
+       "context"
        "net/http"
-
-       "golang.org/x/net/context"
 )
 
 // HTTPClient is the context key to use with golang.org/x/net/context's
@@ -20,50 +18,16 @@ var HTTPClient ContextKey
 // because nobody else can create a ContextKey, being unexported.
 type ContextKey struct{}
 
-// ContextClientFunc is a func which tries to return an *http.Client
-// given a Context value. If it returns an error, the search stops
-// with that error.  If it returns (nil, nil), the search continues
-// down the list of registered funcs.
-type ContextClientFunc func(context.Context) (*http.Client, error)
-
-var contextClientFuncs []ContextClientFunc
-
-func RegisterContextClientFunc(fn ContextClientFunc) {
-       contextClientFuncs = append(contextClientFuncs, fn)
-}
+var appengineClientHook func(context.Context) *http.Client
 
-func ContextClient(ctx context.Context) (*http.Client, error) {
+func ContextClient(ctx context.Context) *http.Client {
        if ctx != nil {
                if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
-                       return hc, nil
-               }
-       }
-       for _, fn := range contextClientFuncs {
-               c, err := fn(ctx)
-               if err != nil {
-                       return nil, err
-               }
-               if c != nil {
-                       return c, nil
+                       return hc
                }
        }
-       return http.DefaultClient, nil
-}
-
-func ContextTransport(ctx context.Context) http.RoundTripper {
-       hc, err := ContextClient(ctx)
-       // This is a rare error case (somebody using nil on App Engine).
-       if err != nil {
-               return ErrorTransport{err}
+       if appengineClientHook != nil {
+               return appengineClientHook(ctx)
        }
-       return hc.Transport
-}
-
-// ErrorTransport returns the specified error on RoundTrip.
-// This RoundTripper should be used in rare error cases where
-// error handling can be postponed to response handling time.
-type ErrorTransport struct{ Err error }
-
-func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) {
-       return nil, t.Err
+       return http.DefaultClient
 }
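
ContextClient now performs exactly two lookups: the HTTPClient context key, then the App Engine hook. Supplying a custom client for token-endpoint traffic therefore still goes through the context key, which the oauth2 package re-exports; a sketch:

    package main

    import (
            "context"
            "net/http"
            "time"

            "golang.org/x/oauth2"
    )

    func main() {
            custom := &http.Client{Timeout: 10 * time.Second}
            // oauth2.HTTPClient re-exports internal.HTTPClient; ContextClient finds
            // this client and uses it for the token-endpoint round trip.
            ctx := context.WithValue(context.Background(), oauth2.HTTPClient, custom)
            _ = ctx // pass ctx to conf.Exchange, conf.PasswordCredentialsToken, etc.
    }
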
index 7b06bfe1ef148cee7ccf727b8c277f5f3b5c00e7..0a3c1e16325277b1d61c1de660c687946d786b27 100644 (file)
@@ -3,19 +3,20 @@
 // license that can be found in the LICENSE file.
 
 // Package oauth2 provides support for making
-// OAuth2 authorized and authenticated HTTP requests.
+// OAuth2 authorized and authenticated HTTP requests,
+// as specified in RFC 6749.
 // It can additionally grant authorization with Bearer JWT.
 package oauth2 // import "golang.org/x/oauth2"
 
 import (
        "bytes"
+       "context"
        "errors"
        "net/http"
        "net/url"
        "strings"
        "sync"
 
-       "golang.org/x/net/context"
        "golang.org/x/oauth2/internal"
 )
 
@@ -117,21 +118,30 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
 // that asks for permissions for the required scopes explicitly.
 //
 // State is a token to protect the user from CSRF attacks. You must
-// always provide a non-zero string and validate that it matches the
+// always provide a non-empty string and validate that it matches
 // the state query parameter on your redirect callback.
 // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
 //
 // Opts may include AccessTypeOnline or AccessTypeOffline, as well
 // as ApprovalForce.
+// It can also be used to pass the PKCE challenge.
+// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
 func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
        var buf bytes.Buffer
        buf.WriteString(c.Endpoint.AuthURL)
        v := url.Values{
                "response_type": {"code"},
                "client_id":     {c.ClientID},
-               "redirect_uri":  internal.CondVal(c.RedirectURL),
-               "scope":         internal.CondVal(strings.Join(c.Scopes, " ")),
-               "state":         internal.CondVal(state),
+       }
+       if c.RedirectURL != "" {
+               v.Set("redirect_uri", c.RedirectURL)
+       }
+       if len(c.Scopes) > 0 {
+               v.Set("scope", strings.Join(c.Scopes, " "))
+       }
+       if state != "" {
+               // TODO(light): Docs say never to omit state; don't allow empty.
+               v.Set("state", state)
        }
        for _, opt := range opts {
                opt.setValue(v)
@@ -157,12 +167,15 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
 // The HTTP client to use is derived from the context.
 // If nil, http.DefaultClient is used.
 func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
-       return retrieveToken(ctx, c, url.Values{
+       v := url.Values{
                "grant_type": {"password"},
                "username":   {username},
                "password":   {password},
-               "scope":      internal.CondVal(strings.Join(c.Scopes, " ")),
-       })
+       }
+       if len(c.Scopes) > 0 {
+               v.Set("scope", strings.Join(c.Scopes, " "))
+       }
+       return retrieveToken(ctx, c, v)
 }
 
 // Exchange converts an authorization code into a token.
@@ -175,13 +188,21 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
 //
 // The code will be in the *http.Request.FormValue("code"). Before
 // calling Exchange, be sure to validate FormValue("state").
-func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
-       return retrieveToken(ctx, c, url.Values{
-               "grant_type":   {"authorization_code"},
-               "code":         {code},
-               "redirect_uri": internal.CondVal(c.RedirectURL),
-               "scope":        internal.CondVal(strings.Join(c.Scopes, " ")),
-       })
+//
+// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
+// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
+func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
+       v := url.Values{
+               "grant_type": {"authorization_code"},
+               "code":       {code},
+       }
+       if c.RedirectURL != "" {
+               v.Set("redirect_uri", c.RedirectURL)
+       }
+       for _, opt := range opts {
+               opt.setValue(v)
+       }
+       return retrieveToken(ctx, c, v)
 }
 
 // Client returns an HTTP client using the provided token.
@@ -292,20 +313,20 @@ var HTTPClient internal.ContextKey
 // NewClient creates an *http.Client from a Context and TokenSource.
 // The returned client is not valid beyond the lifetime of the context.
 //
+// Note that if a custom *http.Client is provided via the Context it
+// is used only for token acquisition and is not used to configure the
+// *http.Client returned from NewClient.
+//
 // As a special case, if src is nil, a non-OAuth2 client is returned
 // using the provided context. This exists to support related OAuth2
 // packages.
 func NewClient(ctx context.Context, src TokenSource) *http.Client {
        if src == nil {
-               c, err := internal.ContextClient(ctx)
-               if err != nil {
-                       return &http.Client{Transport: internal.ErrorTransport{Err: err}}
-               }
-               return c
+               return internal.ContextClient(ctx)
        }
        return &http.Client{
                Transport: &Transport{
-                       Base:   internal.ContextTransport(ctx),
+                       Base:   internal.ContextClient(ctx).Transport,
                        Source: ReuseTokenSource(nil, src),
                },
        }
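
With Exchange now variadic, both halves of a PKCE flow can be expressed with SetAuthURLParam: the challenge goes into AuthCodeURL and the verifier into Exchange. A sketch with hypothetical endpoint and credential values, using the plain challenge method for brevity (real clients should prefer S256):

    package main

    import (
            "context"
            "fmt"

            "golang.org/x/oauth2"
    )

    func main() {
            conf := &oauth2.Config{
                    ClientID: "client-id",
                    Endpoint: oauth2.Endpoint{
                            AuthURL:  "https://auth.example.com/authorize", // hypothetical
                            TokenURL: "https://auth.example.com/token",     // hypothetical
                    },
                    RedirectURL: "https://app.example.com/callback",
            }
            // With the plain method, the challenge equals the verifier.
            verifier := "s3cret-verifier-string"
            url := conf.AuthCodeURL("state-token",
                    oauth2.SetAuthURLParam("code_challenge", verifier),
                    oauth2.SetAuthURLParam("code_challenge_method", "plain"))
            fmt.Println("visit:", url)

            // After the redirect delivers ?code=...:
            tok, err := conf.Exchange(context.Background(), "received-code",
                    oauth2.SetAuthURLParam("code_verifier", verifier))
            fmt.Println(tok, err)
    }
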
index 7a3167f15b04d4305bcfa8d9aba3d67039df2daf..9be1ae537376ed3c6333781070864ac6164ff64e 100644 (file)
@@ -5,13 +5,14 @@
 package oauth2
 
 import (
+       "context"
+       "fmt"
        "net/http"
        "net/url"
        "strconv"
        "strings"
        "time"
 
-       "golang.org/x/net/context"
        "golang.org/x/oauth2/internal"
 )
 
@@ -20,7 +21,7 @@ import (
 // expirations due to client-server time mismatches.
 const expiryDelta = 10 * time.Second
 
-// Token represents the crendentials used to authorize
+// Token represents the credentials used to authorize
 // the requests to access protected resources on the OAuth 2.0
 // provider's backend.
 //
@@ -123,7 +124,7 @@ func (t *Token) expired() bool {
        if t.Expiry.IsZero() {
                return false
        }
-       return t.Expiry.Add(-expiryDelta).Before(time.Now())
+       return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now())
 }
 
 // Valid reports whether t is non-nil, has an AccessToken, and is not expired.
@@ -152,7 +153,23 @@ func tokenFromInternal(t *internal.Token) *Token {
 func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
        tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
        if err != nil {
+               if rErr, ok := err.(*internal.RetrieveError); ok {
+                       return nil, (*RetrieveError)(rErr)
+               }
                return nil, err
        }
        return tokenFromInternal(tk), nil
 }
+
+// RetrieveError is the error returned when the token endpoint returns a
+// non-2XX HTTP status code.
+type RetrieveError struct {
+       Response *http.Response
+       // Body is the body that was consumed by reading Response.Body.
+       // It may be truncated.
+       Body []byte
+}
+
+func (r *RetrieveError) Error() string {
+       return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+}
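
Callers can now detect a non-2XX token response with a type assertion instead of parsing the old formatted error string; a sketch with a hypothetical token endpoint:

    package main

    import (
            "context"
            "fmt"

            "golang.org/x/oauth2"
    )

    func describe(conf *oauth2.Config, code string) {
            _, err := conf.Exchange(context.Background(), code)
            if rErr, ok := err.(*oauth2.RetrieveError); ok {
                    // Status and raw (possibly truncated) body from the token endpoint.
                    fmt.Println(rErr.Response.StatusCode, string(rErr.Body))
                    return
            }
            if err != nil {
                    fmt.Println("transport or decode error:", err)
            }
    }

    func main() {
            conf := &oauth2.Config{Endpoint: oauth2.Endpoint{
                    TokenURL: "https://auth.example.com/token", // hypothetical
            }}
            describe(conf, "bad-code")
    }
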
index 92ac7e2531f452bd146ae1dc6a2ed0d90040b8d1..aa0d34f1e0eaf5b4ab9f35244ebe3503bd34e689 100644 (file)
@@ -31,9 +31,17 @@ type Transport struct {
 }
 
 // RoundTrip authorizes and authenticates the request with an
-// access token. If no token exists or token is expired,
-// tries to refresh/fetch a new token.
+// access token from Transport's Source.
 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+       reqBodyClosed := false
+       if req.Body != nil {
+               defer func() {
+                       if !reqBodyClosed {
+                               req.Body.Close()
+                       }
+               }()
+       }
+
        if t.Source == nil {
                return nil, errors.New("oauth2: Transport's Source is nil")
        }
@@ -46,6 +54,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
        token.SetAuthHeader(req2)
        t.setModReq(req, req2)
        res, err := t.base().RoundTrip(req2)
+
+       // req.Body is assumed to have been closed by the base RoundTripper.
+       reqBodyClosed = true
+
        if err != nil {
                t.setModReq(req, nil)
                return nil, err