revision = "3a771d992973f24aa725d07868b467d1ddfceafb"
[[projects]]
- digest = "1:67351095005f164e748a5a21899d1403b03878cb2d40a7b0f742376e6eeda974"
+ digest = "1:c10f35be6200b09e26da267ca80f837315093ecaba27e7a223071380efb9dd32"
name = "github.com/blevesearch/bleve"
packages = [
".",
"search/searcher",
]
pruneopts = "NUT"
- revision = "ff210fbc6d348ad67aa5754eaea11a463fcddafd"
+ revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
[[projects]]
branch = "master"
revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf"
[[projects]]
- digest = "1:23f75ae90fcc38dac6fad6881006ea7d0f2c78db5f9f81f3df558dc91460e61f"
+ digest = "1:4b992ec853d0ea9bac3dcf09a64af61de1a392e6cb0eef2204c0c92f4ae6b911"
name = "github.com/markbates/goth"
packages = [
".",
"providers/twitter",
]
pruneopts = "NUT"
- revision = "f9c6649ab984d6ea71ef1e13b7b1cdffcf4592d3"
- version = "v1.46.1"
+ revision = "bc6d8ddf751a745f37ca5567dbbfc4157bbf5da9"
+ version = "v1.47.2"
[[projects]]
digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5"
[[projects]]
branch = "master"
- digest = "1:6d5ed712653ea5321fe3e3475ab2188cf362a4e0d31e9fd3acbd4dfbbca0d680"
+ digest = "1:d0a0bdd2b64d981aa4e6a1ade90431d042cd7fa31b584e33d45e62cbfec43380"
name = "golang.org/x/net"
packages = [
"context",
+ "context/ctxhttp",
"html",
"html/atom",
"html/charset",
revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344"
[[projects]]
- digest = "1:8159a9cda4b8810aaaeb0d60e2fa68e2fd86d8af4ec8f5059830839e3c8d93d5"
+ branch = "master"
+ digest = "1:274a6321a5a9f185eeb3fab5d7d8397e0e9f57737490d749f562c7e205ffbc2e"
name = "golang.org/x/oauth2"
packages = [
".",
"internal",
]
pruneopts = "NUT"
- revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
+ revision = "c453e0c757598fd055e170a3a359263c91e13153"
[[projects]]
digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3"
branch = "master"
name = "code.gitea.io/sdk"
+[[constraint]]
+# branch = "master"
+ revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
+ name = "github.com/blevesearch/bleve"
+# Not targeting v0.7.0 since the switch to the standard context package landed only just after this tag
+
[[constraint]]
revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
name = "golang.org/x/crypto"
[[constraint]]
name = "github.com/markbates/goth"
- version = "1.46.1"
+ version = "1.47.2"
[[constraint]]
branch = "master"
source = "github.com/go-gitea/bolt"
[[override]]
- revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
+ branch = "master"
name = "golang.org/x/oauth2"
[[constraint]]
package bleve
import (
+ "context"
+
"github.com/blevesearch/bleve/document"
"github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/index/store"
"github.com/blevesearch/bleve/mapping"
- "golang.org/x/net/context"
)
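
With the dependency moved to the standard library context package, callers can hand a stdlib context straight to context-aware entry points such as SearchInContext. A minimal caller-side sketch (the index handle, query string, and timeout are illustrative only):

import (
	"context"
	"time"

	"github.com/blevesearch/bleve"
)

// searchWithTimeout runs a match query against an already-opened
// index, cancelling the search if it exceeds two seconds.
func searchWithTimeout(idx bleve.Index, q string) (*bleve.SearchResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	req := bleve.NewSearchRequest(bleve.NewMatchQuery(q))
	return idx.SearchInContext(ctx, req)
}
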
// A Batch groups together multiple Index and Delete
// prepare new index snapshot
newSnapshot := &IndexSnapshot{
parent: s,
- segment: make([]*SegmentSnapshot, nsegs, nsegs+1),
- offsets: make([]uint64, nsegs, nsegs+1),
+ segment: make([]*SegmentSnapshot, 0, nsegs+1),
+ offsets: make([]uint64, 0, nsegs+1),
internal: make(map[string][]byte, len(s.root.internal)),
epoch: s.nextSnapshotEpoch,
refs: 1,
return err
}
}
- newSnapshot.segment[i] = &SegmentSnapshot{
+
+ newss := &SegmentSnapshot{
id: s.root.segment[i].id,
segment: s.root.segment[i].segment,
cachedDocs: s.root.segment[i].cachedDocs,
}
- s.root.segment[i].segment.AddRef()
// apply new obsoletions
if s.root.segment[i].deleted == nil {
- newSnapshot.segment[i].deleted = delta
+ newss.deleted = delta
} else {
- newSnapshot.segment[i].deleted = roaring.Or(s.root.segment[i].deleted, delta)
+ newss.deleted = roaring.Or(s.root.segment[i].deleted, delta)
}
- newSnapshot.offsets[i] = running
- running += s.root.segment[i].Count()
-
+ // check for live size before copying
+ if newss.LiveSize() > 0 {
+ newSnapshot.segment = append(newSnapshot.segment, newss)
+ s.root.segment[i].segment.AddRef()
+ newSnapshot.offsets = append(newSnapshot.offsets, running)
+ running += s.root.segment[i].Count()
+ }
}
+
// append new segment, if any, to end of the new index snapshot
if next.data != nil {
newSegmentSnapshot := &SegmentSnapshot{
// prepare new index snapshot
currSize := len(s.root.segment)
newSize := currSize + 1 - len(nextMerge.old)
+
+ // the merge may have produced no new segment (nothing live to merge)
+ if nextMerge.new == nil {
+ newSize--
+ }
+
newSnapshot := &IndexSnapshot{
parent: s,
segment: make([]*SegmentSnapshot, 0, newSize),
segmentID := s.root.segment[i].id
if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
// this segment is going away, see if anything else was deleted since we started the merge
- if s.root.segment[i].deleted != nil {
+ if segSnapAtMerge != nil && s.root.segment[i].deleted != nil {
// assume all these deletes are new
deletedSince := s.root.segment[i].deleted
// if we already knew about some of them, remove
newSegmentDeleted.Add(uint32(newDocNum))
}
}
- } else {
+ // remove this segment from the old map; whatever segments are
+ // left in the old map after walking the root segments are the
+ // ones that became obsolete while the merge was running
+ delete(nextMerge.old, segmentID)
+
+ } else if s.root.segment[i].LiveSize() > 0 {
// this segment is staying
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
id: s.root.segment[i].id,
}
}
- // put new segment at end
- newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
- id: nextMerge.id,
- segment: nextMerge.new, // take ownership for nextMerge.new's ref-count
- deleted: newSegmentDeleted,
- cachedDocs: &cachedDocs{cache: nil},
- })
- newSnapshot.offsets = append(newSnapshot.offsets, running)
+ // apply the live docs of any obsolete segments (those still left
+ // in the old map) as deletions against the newly merged segment
+ // before it is introduced
+ for segID, ss := range nextMerge.old {
+ obsoleted := ss.DocNumbersLive()
+ if obsoleted != nil {
+ obsoletedIter := obsoleted.Iterator()
+ for obsoletedIter.HasNext() {
+ oldDocNum := obsoletedIter.Next()
+ newDocNum := nextMerge.oldNewDocNums[segID][oldDocNum]
+ newSegmentDeleted.Add(uint32(newDocNum))
+ }
+ }
+ }
+ // if every doc in the newly merged segment has been deleted by
+ // the time we reach here, skip the introduction entirely
+ if nextMerge.new != nil &&
+ nextMerge.new.Count() > newSegmentDeleted.GetCardinality() {
+ // put new segment at end
+ newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
+ id: nextMerge.id,
+ segment: nextMerge.new, // take ownership for nextMerge.new's ref-count
+ deleted: newSegmentDeleted,
+ cachedDocs: &cachedDocs{cache: nil},
+ })
+ newSnapshot.offsets = append(newSnapshot.offsets, running)
+ }
+
+ newSnapshot.AddRef() // 1 ref for the nextMerge.notify response
// swap in new segment
rootPrev := s.root
_ = rootPrev.DecRef()
}
- // notify merger we incorporated this
+ // notify requester that we incorporated this
+ nextMerge.notify <- newSnapshot
close(nextMerge.notify)
}
package scorch
import (
+ "bytes"
+ "encoding/json"
+
"fmt"
"os"
"sync/atomic"
func (s *Scorch) mergerLoop() {
var lastEpochMergePlanned uint64
+ mergePlannerOptions, err := s.parseMergePlannerOptions()
+ if err != nil {
+ s.fireAsyncError(fmt.Errorf("mergePlannerOption json parsing err: %v", err))
+ s.asyncTasks.Done()
+ return
+ }
+
OUTER:
for {
select {
startTime := time.Now()
// lets get started
- err := s.planMergeAtSnapshot(ourSnapshot)
+ err := s.planMergeAtSnapshot(ourSnapshot, mergePlannerOptions)
if err != nil {
s.fireAsyncError(fmt.Errorf("merging err: %v", err))
_ = ourSnapshot.DecRef()
_ = ourSnapshot.DecRef()
// tell the persister we're waiting for changes
- // first make a notification chan
- notifyUs := make(notificationChan)
+ // first make an epochWatcher with a notification chan
+ ew := &epochWatcher{
+ epoch: lastEpochMergePlanned,
+ notifyCh: make(notificationChan, 1),
+ }
// give it to the persister
select {
case <-s.closeCh:
break OUTER
- case s.persisterNotifier <- notifyUs:
- }
-
- // check again
- s.rootLock.RLock()
- ourSnapshot = s.root
- ourSnapshot.AddRef()
- s.rootLock.RUnlock()
-
- if ourSnapshot.epoch != lastEpochMergePlanned {
- startTime := time.Now()
-
- // lets get started
- err := s.planMergeAtSnapshot(ourSnapshot)
- if err != nil {
- s.fireAsyncError(fmt.Errorf("merging err: %v", err))
- _ = ourSnapshot.DecRef()
- continue OUTER
- }
- lastEpochMergePlanned = ourSnapshot.epoch
-
- s.fireEvent(EventKindMergerProgress, time.Since(startTime))
+ case s.persisterNotifier <- ew:
}
- _ = ourSnapshot.DecRef()
- // now wait for it (but also detect close)
+ // now wait for persister (but also detect close)
select {
case <-s.closeCh:
break OUTER
- case <-notifyUs:
- // woken up, next loop should pick up work
+ case <-ew.notifyCh:
}
}
}
s.asyncTasks.Done()
}
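
The epochWatcher handed to the persister above is not defined in this hunk; judging from its use here (an epoch plus a buffered notification channel that the persister closes), its definition elsewhere in the package presumably amounts to something like:

type epochWatcher struct {
	epoch    uint64           // epoch the watcher has already planned merges up to
	notifyCh notificationChan // closed by the persister once it moves past that epoch
}
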
-func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
+func (s *Scorch) parseMergePlannerOptions() (*mergeplan.MergePlanOptions,
+ error) {
+ mergePlannerOptions := mergeplan.DefaultMergePlanOptions
+ if v, ok := s.config["scorchMergePlanOptions"]; ok {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return &mergePlannerOptions, err
+ }
+
+ err = json.Unmarshal(b, &mergePlannerOptions)
+ if err != nil {
+ return &mergePlannerOptions, err
+ }
+ }
+ return &mergePlannerOptions, nil
+}
+
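
Because the options are round-tripped through JSON, they can be supplied as a plain map under the "scorchMergePlanOptions" config key. A hypothetical caller-side sketch, assuming the usual bleve setup where the kvconfig map becomes the scorch config; the key names mirror the exported fields of mergeplan.MergePlanOptions and the values are illustrative only:

import (
	"github.com/blevesearch/bleve"
	"github.com/blevesearch/bleve/index/scorch"
)

func openTunedIndex(path string) (bleve.Index, error) {
	kvConfig := map[string]interface{}{
		"scorchMergePlanOptions": map[string]interface{}{
			"MaxSegmentsPerTier":   4, // illustrative values, not recommendations
			"SegmentsPerMergeTask": 4,
		},
	}
	return bleve.NewUsing(path, bleve.NewIndexMapping(),
		scorch.Name, scorch.Name, kvConfig)
}
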
+func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot,
+ options *mergeplan.MergePlanOptions) error {
// build list of zap segments in this snapshot
var onlyZapSnapshots []mergeplan.Segment
for _, segmentSnapshot := range ourSnapshot.segment {
}
// give this list to the planner
- resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, nil)
+ resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, options)
if err != nil {
return fmt.Errorf("merge planning err: %v", err)
}
}
// process tasks in serial for now
- var notifications []notificationChan
+ var notifications []chan *IndexSnapshot
for _, task := range resultMergePlan.Tasks {
+ if len(task.Segments) == 0 {
+ continue
+ }
+
oldMap := make(map[uint64]*SegmentSnapshot)
newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))
if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
oldMap[segSnapshot.id] = segSnapshot
if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
- segmentsToMerge = append(segmentsToMerge, zapSeg)
- docsToDrop = append(docsToDrop, segSnapshot.deleted)
+ if segSnapshot.LiveSize() == 0 {
+ oldMap[segSnapshot.id] = nil
+ } else {
+ segmentsToMerge = append(segmentsToMerge, zapSeg)
+ docsToDrop = append(docsToDrop, segSnapshot.deleted)
+ }
}
}
}
- filename := zapFileName(newSegmentID)
- s.markIneligibleForRemoval(filename)
- path := s.path + string(os.PathSeparator) + filename
- newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, DefaultChunkFactor)
- if err != nil {
- s.unmarkIneligibleForRemoval(filename)
- return fmt.Errorf("merging failed: %v", err)
- }
- segment, err := zap.Open(path)
- if err != nil {
- s.unmarkIneligibleForRemoval(filename)
- return err
+ var oldNewDocNums map[uint64][]uint64
+ var segment segment.Segment
+ if len(segmentsToMerge) > 0 {
+ filename := zapFileName(newSegmentID)
+ s.markIneligibleForRemoval(filename)
+ path := s.path + string(os.PathSeparator) + filename
+ newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, 1024)
+ if err != nil {
+ s.unmarkIneligibleForRemoval(filename)
+ return fmt.Errorf("merging failed: %v", err)
+ }
+ segment, err = zap.Open(path)
+ if err != nil {
+ s.unmarkIneligibleForRemoval(filename)
+ return err
+ }
+ oldNewDocNums = make(map[uint64][]uint64)
+ for i, segNewDocNums := range newDocNums {
+ oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
+ }
}
+
sm := &segmentMerge{
id: newSegmentID,
old: oldMap,
- oldNewDocNums: make(map[uint64][]uint64),
+ oldNewDocNums: oldNewDocNums,
new: segment,
- notify: make(notificationChan),
+ notify: make(chan *IndexSnapshot, 1),
}
notifications = append(notifications, sm.notify)
- for i, segNewDocNums := range newDocNums {
- sm.oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
- }
// give it to the introducer
select {
case <-s.closeCh:
+ _ = segment.Close()
return nil
case s.merges <- sm:
}
select {
case <-s.closeCh:
return nil
- case <-notification:
+ case newSnapshot := <-notification:
+ if newSnapshot != nil {
+ _ = newSnapshot.DecRef()
+ }
}
}
return nil
old map[uint64]*SegmentSnapshot
oldNewDocNums map[uint64][]uint64
new segment.Segment
- notify notificationChan
+ notify chan *IndexSnapshot
+}
+
+// perform a merging of the given SegmentBase instances into a new,
+// persisted segment, and synchronously introduce that new segment
+// into the root
+func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot,
+ sbs []*zap.SegmentBase, sbsDrops []*roaring.Bitmap, sbsIndexes []int,
+ chunkFactor uint32) (uint64, *IndexSnapshot, uint64, error) {
+ var br bytes.Buffer
+
+ cr := zap.NewCountHashWriter(&br)
+
+ newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset,
+ docValueOffset, dictLocs, fieldsInv, fieldsMap, err :=
+ zap.MergeToWriter(sbs, sbsDrops, chunkFactor, cr)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+
+ sb, err := zap.InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
+ fieldsMap, fieldsInv, numDocs, storedIndexOffset, fieldsIndexOffset,
+ docValueOffset, dictLocs)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+
+ newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
+
+ filename := zapFileName(newSegmentID)
+ path := s.path + string(os.PathSeparator) + filename
+ err = zap.PersistSegmentBase(sb, path)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+
+ segment, err := zap.Open(path)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+
+ sm := &segmentMerge{
+ id: newSegmentID,
+ old: make(map[uint64]*SegmentSnapshot),
+ oldNewDocNums: make(map[uint64][]uint64),
+ new: segment,
+ notify: make(chan *IndexSnapshot, 1),
+ }
+
+ for i, idx := range sbsIndexes {
+ ss := snapshot.segment[idx]
+ sm.old[ss.id] = ss
+ sm.oldNewDocNums[ss.id] = newDocNums[i]
+ }
+
+ select { // send to introducer
+ case <-s.closeCh:
+ _ = segment.DecRef()
+ return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
+ case s.merges <- sm:
+ }
+
+ select { // wait for introduction to complete
+ case <-s.closeCh:
+ return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
+ case newSnapshot := <-sm.notify:
+ return numDocs, newSnapshot, newSegmentID, nil
+ }
}
// While we’re over budget, keep looping, which might produce
// another MergeTask.
- for len(eligibles) > budgetNumSegments {
+ for len(eligibles) > 0 && (len(eligibles)+len(rv.Tasks)) > budgetNumSegments {
// Track a current best roster as we examine and score
// potential rosters of merges.
var bestRoster []Segment
var bestRosterScore float64 // Lower score is better.
- for startIdx := 0; startIdx < len(eligibles)-o.SegmentsPerMergeTask; startIdx++ {
+ for startIdx := 0; startIdx < len(eligibles); startIdx++ {
var roster []Segment
var rosterLiveSize int64
var DefaultChunkFactor uint32 = 1024
+// Arbitrary number; it needs to be made configurable.
+// Lower values (like 10) that end up making the persister really
+// slow don't work well: they leave more files to persist in the
+// next persist iteration and spike the number of open FDs.
+// The ideal value should let the persister proceed at an optimum
+// pace so that the merger can skip many intermediate snapshots.
+// This needs to be based on empirical data.
+// TODO - may need to revisit this approach/value.
+var epochDistance = uint64(5)
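
To make the threshold concrete: with epochDistance = 5, a persister that has persisted up to epoch 12 keeps waiting inside pausePersisterForMergerCatchUp (below) while the merger's last planned epoch is 6 or lower, since 12 > 6+5, and it resumes as soon as the merger reports an epoch within 5 of the persisted one.
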
+
type notificationChan chan struct{}
func (s *Scorch) persisterLoop() {
defer s.asyncTasks.Done()
- var notifyChs []notificationChan
- var lastPersistedEpoch uint64
+ var persistWatchers []*epochWatcher
+ var lastPersistedEpoch, lastMergedEpoch uint64
+ var ew *epochWatcher
OUTER:
for {
select {
case <-s.closeCh:
break OUTER
- case notifyCh := <-s.persisterNotifier:
- notifyChs = append(notifyChs, notifyCh)
+ case ew = <-s.persisterNotifier:
+ persistWatchers = append(persistWatchers, ew)
default:
}
+ if ew != nil && ew.epoch > lastMergedEpoch {
+ lastMergedEpoch = ew.epoch
+ }
+ persistWatchers = s.pausePersisterForMergerCatchUp(lastPersistedEpoch,
+ &lastMergedEpoch, persistWatchers)
var ourSnapshot *IndexSnapshot
var ourPersisted []chan error
}
lastPersistedEpoch = ourSnapshot.epoch
- for _, notifyCh := range notifyChs {
- close(notifyCh)
+ for _, ew := range persistWatchers {
+ close(ew.notifyCh)
}
- notifyChs = nil
+
+ persistWatchers = nil
_ = ourSnapshot.DecRef()
changed := false
break OUTER
case <-w.notifyCh:
// woken up, next loop should pick up work
+ continue OUTER
+ case ew = <-s.persisterNotifier:
+ // if the watchers are already caught up, let them wait;
+ // otherwise let them continue to catch up
+ persistWatchers = append(persistWatchers, ew)
+ }
+ }
+}
+
+func notifyMergeWatchers(lastPersistedEpoch uint64,
+ persistWatchers []*epochWatcher) []*epochWatcher {
+ var watchersNext []*epochWatcher
+ for _, w := range persistWatchers {
+ if w.epoch < lastPersistedEpoch {
+ close(w.notifyCh)
+ } else {
+ watchersNext = append(watchersNext, w)
}
}
+ return watchersNext
+}
+
+func (s *Scorch) pausePersisterForMergerCatchUp(lastPersistedEpoch uint64, lastMergedEpoch *uint64,
+ persistWatchers []*epochWatcher) []*epochWatcher {
+
+ // first, let the watchers proceed if they lag behind
+ persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
+
+OUTER:
+ // if the merger is lagging too far behind, wait until it catches up
+ for lastPersistedEpoch > *lastMergedEpoch+epochDistance {
+
+ select {
+ case <-s.closeCh:
+ break OUTER
+ case ew := <-s.persisterNotifier:
+ persistWatchers = append(persistWatchers, ew)
+ *lastMergedEpoch = ew.epoch
+ }
+
+ // let the watchers proceed if they lag behind
+ persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
+ }
+
+ return persistWatchers
}
func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
- // start a write transaction
- tx, err := s.rootBolt.Begin(true)
+ persisted, err := s.persistSnapshotMaybeMerge(snapshot)
if err != nil {
return err
}
- // defer fsync of the rootbolt
- defer func() {
- if err == nil {
- err = s.rootBolt.Sync()
+ if persisted {
+ return nil
+ }
+
+ return s.persistSnapshotDirect(snapshot)
+}
+
+// DefaultMinSegmentsForInMemoryMerge represents the default number of
+// in-memory zap segments that persistSnapshotMaybeMerge() needs to
+// see in an IndexSnapshot before it decides to merge and persist
+// those segments
+var DefaultMinSegmentsForInMemoryMerge = 2
+
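
Since DefaultMinSegmentsForInMemoryMerge is an exported package variable, an embedder could raise it before opening any indexes so that persists fall back to persistSnapshotDirect more often. A purely illustrative tweak, not part of this change:

import "github.com/blevesearch/bleve/index/scorch"

func init() {
	// only merge-and-persist in-memory zap segments once at least 4
	// of them accumulate in a snapshot (the package default is 2)
	scorch.DefaultMinSegmentsForInMemoryMerge = 4
}
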
+// persistSnapshotMaybeMerge examines the snapshot and might merge and
+// persist the in-memory zap segments if there are enough of them
+func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) (
+ bool, error) {
+ // collect the in-memory zap segments (SegmentBase instances)
+ var sbs []*zap.SegmentBase
+ var sbsDrops []*roaring.Bitmap
+ var sbsIndexes []int
+
+ for i, segmentSnapshot := range snapshot.segment {
+ if sb, ok := segmentSnapshot.segment.(*zap.SegmentBase); ok {
+ sbs = append(sbs, sb)
+ sbsDrops = append(sbsDrops, segmentSnapshot.deleted)
+ sbsIndexes = append(sbsIndexes, i)
}
+ }
+
+ if len(sbs) < DefaultMinSegmentsForInMemoryMerge {
+ return false, nil
+ }
+
+ _, newSnapshot, newSegmentID, err := s.mergeSegmentBases(
+ snapshot, sbs, sbsDrops, sbsIndexes, DefaultChunkFactor)
+ if err != nil {
+ return false, err
+ }
+ if newSnapshot == nil {
+ return false, nil
+ }
+
+ defer func() {
+ _ = newSnapshot.DecRef()
}()
- // defer commit/rollback transaction
+
+ mergedSegmentIDs := map[uint64]struct{}{}
+ for _, idx := range sbsIndexes {
+ mergedSegmentIDs[snapshot.segment[idx].id] = struct{}{}
+ }
+
+ // construct a snapshot that's logically equivalent to the input
+ // snapshot, but with merged segments replaced by the new segment
+ equiv := &IndexSnapshot{
+ parent: snapshot.parent,
+ segment: make([]*SegmentSnapshot, 0, len(snapshot.segment)),
+ internal: snapshot.internal,
+ epoch: snapshot.epoch,
+ }
+
+ // copy the segments that weren't replaced into the equiv
+ for _, segment := range snapshot.segment {
+ if _, wasMerged := mergedSegmentIDs[segment.id]; !wasMerged {
+ equiv.segment = append(equiv.segment, segment)
+ }
+ }
+
+ // append the new segment to the equiv
+ for _, segment := range newSnapshot.segment {
+ if segment.id == newSegmentID {
+ equiv.segment = append(equiv.segment, &SegmentSnapshot{
+ id: newSegmentID,
+ segment: segment.segment,
+ deleted: nil, // nil since merging handled deletions
+ })
+ break
+ }
+ }
+
+ err = s.persistSnapshotDirect(equiv)
+ if err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) {
+ // start a write transaction
+ tx, err := s.rootBolt.Begin(true)
+ if err != nil {
+ return err
+ }
+ // defer rollback on error
defer func() {
- if err == nil {
- err = tx.Commit()
- } else {
+ if err != nil {
_ = tx.Rollback()
}
}()
newSegmentPaths := make(map[uint64]string)
// first ensure that each segment in this snapshot has been persisted
- for i, segmentSnapshot := range snapshot.segment {
- snapshotSegmentKey := segment.EncodeUvarintAscending(nil, uint64(i))
- snapshotSegmentBucket, err2 := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
- if err2 != nil {
- return err2
+ for _, segmentSnapshot := range snapshot.segment {
+ snapshotSegmentKey := segment.EncodeUvarintAscending(nil, segmentSnapshot.id)
+ snapshotSegmentBucket, err := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
+ if err != nil {
+ return err
}
switch seg := segmentSnapshot.segment.(type) {
case *zap.SegmentBase:
// need to persist this to disk
filename := zapFileName(segmentSnapshot.id)
path := s.path + string(os.PathSeparator) + filename
- err2 := zap.PersistSegmentBase(seg, path)
- if err2 != nil {
- return fmt.Errorf("error persisting segment: %v", err2)
+ err = zap.PersistSegmentBase(seg, path)
+ if err != nil {
+ return fmt.Errorf("error persisting segment: %v", err)
}
newSegmentPaths[segmentSnapshot.id] = path
err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
}
}
- // only alter the root if we actually persisted a segment
- // (sometimes its just a new snapshot, possibly with new internal values)
+ // we need to swap in a new root only when we've persisted 1 or
+ // more segments -- whereby the new root would have 1-for-1
+ // replacements of in-memory segments with file-based segments
+ //
+ // other cases like updates to internal values only, and/or when
+ // there are only deletions, are already covered and persisted by
+ // the newly populated boltdb snapshotBucket above
if len(newSegmentPaths) > 0 {
// now try to open all the new snapshots
newSegments := make(map[uint64]segment.Segment)
+ defer func() {
+ for _, s := range newSegments {
+ if s != nil {
+ // cleanup segments that were opened but not
+ // swapped into the new root
+ _ = s.Close()
+ }
+ }
+ }()
for segmentID, path := range newSegmentPaths {
newSegments[segmentID], err = zap.Open(path)
if err != nil {
- for _, s := range newSegments {
- if s != nil {
- _ = s.Close() // cleanup segments that were successfully opened
- }
- }
return fmt.Errorf("error opening new segment at %s, %v", path, err)
}
}
cachedDocs: segmentSnapshot.cachedDocs,
}
newIndexSnapshot.segment[i] = newSegmentSnapshot
+ delete(newSegments, segmentSnapshot.id)
// update items persisted in case of a new segment snapshot
atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
} else {
for k, v := range s.root.internal {
newIndexSnapshot.internal[k] = v
}
- for _, filename := range filenames {
- delete(s.ineligibleForRemoval, filename)
- }
+
rootPrev := s.root
s.root = newIndexSnapshot
s.rootLock.Unlock()
}
}
+ err = tx.Commit()
+ if err != nil {
+ return err
+ }
+
+ err = s.rootBolt.Sync()
+ if err != nil {
+ return err
+ }
+
+ // allow files to become eligible for removal after commit, such
+ // as file segments from snapshots that came from the merger
+ s.rootLock.Lock()
+ for _, filename := range filenames {
+ delete(s.ineligibleForRemoval, filename)
+ }
+ s.rootLock.Unlock()
+
return nil
}
merges chan *segmentMerge
introducerNotifier chan *epochWatcher
revertToSnapshots chan *snapshotReversion
- persisterNotifier chan notificationChan
+ persisterNotifier chan *epochWatcher
rootBolt *bolt.DB
asyncTasks sync.WaitGroup
}
func (s *Scorch) Open() error {
+ err := s.openBolt()
+ if err != nil {
+ return err
+ }
+
+ s.asyncTasks.Add(1)
+ go s.mainLoop()
+
+ if !s.readOnly && s.path != "" {
+ s.asyncTasks.Add(1)
+ go s.persisterLoop()
+ s.asyncTasks.Add(1)
+ go s.mergerLoop()
+ }
+
+ return nil
+}
+
+func (s *Scorch) openBolt() error {
var ok bool
s.path, ok = s.config["path"].(string)
if !ok {
}
}
}
+
rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
var err error
if s.path != "" {
s.merges = make(chan *segmentMerge)
s.introducerNotifier = make(chan *epochWatcher, 1)
s.revertToSnapshots = make(chan *snapshotReversion)
- s.persisterNotifier = make(chan notificationChan)
+ s.persisterNotifier = make(chan *epochWatcher, 1)
if !s.readOnly && s.path != "" {
err := s.removeOldZapFiles() // Before persister or merger create any new files.
}
}
- s.asyncTasks.Add(1)
- go s.mainLoop()
-
- if !s.readOnly && s.path != "" {
- s.asyncTasks.Add(1)
- go s.persisterLoop()
- s.asyncTasks.Add(1)
- go s.mergerLoop()
- }
-
return nil
}
introduction.persisted = make(chan error, 1)
}
- // get read lock, to optimistically prepare obsoleted info
+ // optimistically prepare obsoletes outside of rootLock
s.rootLock.RLock()
- for _, seg := range s.root.segment {
+ root := s.root
+ root.AddRef()
+ s.rootLock.RUnlock()
+
+ for _, seg := range root.segment {
delta, err := seg.segment.DocNumbers(ids)
if err != nil {
- s.rootLock.RUnlock()
return err
}
introduction.obsoletes[seg.id] = delta
}
- s.rootLock.RUnlock()
+
+ _ = root.DecRef()
s.introductions <- introduction
var numTokenFrequencies int
var totLocs int
+ // initial scan to define all the fields up front so they can be sorted
+ for _, result := range results {
+ for _, field := range result.Document.CompositeFields {
+ s.getOrDefineField(field.Name())
+ }
+ for _, field := range result.Document.Fields {
+ s.getOrDefineField(field.Name())
+ }
+ }
+ sort.Strings(s.FieldsInv[1:]) // keep _id as first field
+ s.FieldsMap = make(map[string]uint16, len(s.FieldsInv))
+ for fieldID, fieldName := range s.FieldsInv {
+ s.FieldsMap[fieldName] = uint16(fieldID + 1)
+ }
+
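
For example, documents carrying the fields _id, title, and body produce FieldsInv = ["_id", "body", "title"] (sorted, with _id pinned first) and FieldsMap = {"_id": 1, "body": 2, "title": 3}. The +1 offset presumably lets a zero result from a map lookup signal an unknown field, which is why code that consumes such maps (for instance the zap merge path further down) subtracts 1 to recover the slice index.
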
processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
for term, tf := range tfs {
pidPlus1, exists := s.Dicts[fieldID][term]
prefix string
end string
offset int
+
+ dictEntry index.DictEntry // reused across Next()'s
}
// Next returns the next entry in the dictionary
d.offset++
postingID := d.d.segment.Dicts[d.d.fieldID][next]
- return &index.DictEntry{
- Term: next,
- Count: d.d.segment.Postings[postingID-1].GetCardinality(),
- }, nil
+ d.dictEntry.Term = next
+ d.dictEntry.Count = d.d.segment.Postings[postingID-1].GetCardinality()
+ return &d.dictEntry, nil
}
"github.com/golang/snappy"
)
-const version uint32 = 2
+const version uint32 = 3
const fieldNotUninverted = math.MaxUint64
}
func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) {
-
var curr int
var metaBuf bytes.Buffer
var data, compressed []byte
+ metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
+
docNumOffsets := make(map[int]uint64, len(memSegment.Stored))
for docNum, storedValues := range memSegment.Stored {
if docNum != 0 {
// reset buffer if necessary
+ curr = 0
metaBuf.Reset()
data = data[:0]
compressed = compressed[:0]
- curr = 0
}
- metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
-
st := memSegment.StoredTypes[docNum]
sp := memSegment.StoredPos[docNum]
// encode fields in order
for fieldID := range memSegment.FieldsInv {
if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok {
- // has stored values for this field
- num := len(storedFieldValues)
-
stf := st[uint16(fieldID)]
spf := sp[uint16(fieldID)]
- // process each value
- for i := 0; i < num; i++ {
- // encode field
- _, err2 := metaEncoder.PutU64(uint64(fieldID))
- if err2 != nil {
- return 0, err2
- }
- // encode type
- _, err2 = metaEncoder.PutU64(uint64(stf[i]))
- if err2 != nil {
- return 0, err2
- }
- // encode start offset
- _, err2 = metaEncoder.PutU64(uint64(curr))
- if err2 != nil {
- return 0, err2
- }
- // end len
- _, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
- if err2 != nil {
- return 0, err2
- }
- // encode number of array pos
- _, err2 = metaEncoder.PutU64(uint64(len(spf[i])))
- if err2 != nil {
- return 0, err2
- }
- // encode all array positions
- for _, pos := range spf[i] {
- _, err2 = metaEncoder.PutU64(pos)
- if err2 != nil {
- return 0, err2
- }
- }
- // append data
- data = append(data, storedFieldValues[i]...)
- // update curr
- curr += len(storedFieldValues[i])
+ var err2 error
+ curr, data, err2 = persistStoredFieldValues(fieldID,
+ storedFieldValues, stf, spf, curr, metaEncoder, data)
+ if err2 != nil {
+ return 0, err2
}
}
}
- metaEncoder.Close()
+ metaEncoder.Close()
metaBytes := metaBuf.Bytes()
// compress the data
return rv, nil
}
+func persistStoredFieldValues(fieldID int,
+ storedFieldValues [][]byte, stf []byte, spf [][]uint64,
+ curr int, metaEncoder *govarint.Base128Encoder, data []byte) (
+ int, []byte, error) {
+ for i := 0; i < len(storedFieldValues); i++ {
+ // encode field
+ _, err := metaEncoder.PutU64(uint64(fieldID))
+ if err != nil {
+ return 0, nil, err
+ }
+ // encode type
+ _, err = metaEncoder.PutU64(uint64(stf[i]))
+ if err != nil {
+ return 0, nil, err
+ }
+ // encode start offset
+ _, err = metaEncoder.PutU64(uint64(curr))
+ if err != nil {
+ return 0, nil, err
+ }
+ // end len
+ _, err = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
+ if err != nil {
+ return 0, nil, err
+ }
+ // encode number of array pos
+ _, err = metaEncoder.PutU64(uint64(len(spf[i])))
+ if err != nil {
+ return 0, nil, err
+ }
+ // encode all array positions
+ for _, pos := range spf[i] {
+ _, err = metaEncoder.PutU64(pos)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ data = append(data, storedFieldValues[i]...)
+ curr += len(storedFieldValues[i])
+ }
+
+ return curr, data, nil
+}
+
func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
var freqOffsets, locOfffsets []uint64
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))
if err != nil {
return nil, err
}
- // resetting encoder for the next field
+ // resetting encoder for the next field
fdvEncoder.Reset()
}
return nil, err
}
+ return InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
+ memSegment.FieldsMap, memSegment.FieldsInv, numDocs,
+ storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs)
+}
+
+func InitSegmentBase(mem []byte, memCRC uint32, chunkFactor uint32,
+ fieldsMap map[string]uint16, fieldsInv []string, numDocs uint64,
+ storedIndexOffset uint64, fieldsIndexOffset uint64, docValueOffset uint64,
+ dictLocs []uint64) (*SegmentBase, error) {
sb := &SegmentBase{
- mem: br.Bytes(),
- memCRC: cr.Sum32(),
+ mem: mem,
+ memCRC: memCRC,
chunkFactor: chunkFactor,
- fieldsMap: memSegment.FieldsMap,
- fieldsInv: memSegment.FieldsInv,
+ fieldsMap: fieldsMap,
+ fieldsInv: fieldsInv,
numDocs: numDocs,
storedIndexOffset: storedIndexOffset,
fieldsIndexOffset: fieldsIndexOffset,
fieldDvIterMap: make(map[uint16]*docValueIterator),
}
- err = sb.loadDvIterators()
+ err := sb.loadDvIterators()
if err != nil {
return nil, err
}
// MetaData represents the data information inside a
// chunk.
type MetaData struct {
- DocID uint64 // docid of the data inside the chunk
+ DocNum uint64 // docNum of the data inside the chunk
DocDvLoc uint64 // starting offset for a given docid
DocDvLen uint64 // length of data inside the chunk for the given docid
}
rv := &chunkedContentCoder{
chunkSize: chunkSize,
chunkLens: make([]uint64, total),
- chunkMeta: []MetaData{},
+ chunkMeta: make([]MetaData, 0, total),
}
return rv
for i := range c.chunkLens {
c.chunkLens[i] = 0
}
- c.chunkMeta = []MetaData{}
+ c.chunkMeta = c.chunkMeta[:0]
}
// Close indicates you are done calling Add() this allows
// write out the metaData slice
for _, meta := range c.chunkMeta {
- _, err := writeUvarints(&c.chunkMetaBuf, meta.DocID, meta.DocDvLoc, meta.DocDvLen)
+ _, err := writeUvarints(&c.chunkMetaBuf, meta.DocNum, meta.DocDvLoc, meta.DocDvLen)
if err != nil {
return err
}
// clearing the chunk specific meta for next chunk
c.chunkBuf.Reset()
c.chunkMetaBuf.Reset()
- c.chunkMeta = []MetaData{}
+ c.chunkMeta = c.chunkMeta[:0]
c.currChunk = chunk
}
}
c.chunkMeta = append(c.chunkMeta, MetaData{
- DocID: docNum,
+ DocNum: docNum,
DocDvLoc: uint64(dvOffset),
DocDvLen: uint64(dvSize),
})
// PostingsList returns the postings list for the specified term
func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
- return d.postingsList([]byte(term), except)
+ return d.postingsList([]byte(term), except, nil)
}
-func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap) (*PostingsList, error) {
- rv := &PostingsList{
- sb: d.sb,
- term: term,
- except: except,
+func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
+ if d.fst == nil {
+ return d.postingsListInit(rv, except), nil
}
- if d.fst != nil {
- postingsOffset, exists, err := d.fst.Get(term)
- if err != nil {
- return nil, fmt.Errorf("vellum err: %v", err)
- }
- if exists {
- err = rv.read(postingsOffset, d)
- if err != nil {
- return nil, err
- }
- }
+ postingsOffset, exists, err := d.fst.Get(term)
+ if err != nil {
+ return nil, fmt.Errorf("vellum err: %v", err)
+ }
+ if !exists {
+ return d.postingsListInit(rv, except), nil
+ }
+
+ return d.postingsListFromOffset(postingsOffset, except, rv)
+}
+
+func (d *Dictionary) postingsListFromOffset(postingsOffset uint64, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
+ rv = d.postingsListInit(rv, except)
+
+ err := rv.read(postingsOffset, d)
+ if err != nil {
+ return nil, err
}
return rv, nil
}
+func (d *Dictionary) postingsListInit(rv *PostingsList, except *roaring.Bitmap) *PostingsList {
+ if rv == nil {
+ rv = &PostingsList{}
+ } else {
+ *rv = PostingsList{} // clear the struct
+ }
+ rv.sb = d.sb
+ rv.except = except
+ return rv
+}
+
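
The extra rv parameter lets hot paths recycle one PostingsList across many lookups instead of allocating per term, which is how persistMergedRest in the zap merge code drives postingsListFromOffset. An in-package sketch of the intended reuse pattern (the dictionary, terms, and visit callback are placeholders):

func visitTerms(dict *Dictionary, terms []string, visit func(*PostingsList) error) error {
	var postings *PostingsList // recycled across iterations
	for _, term := range terms {
		var err error
		postings, err = dict.postingsList([]byte(term), nil, postings)
		if err != nil {
			return err
		}
		// the reused value must be fully consumed before the next lookup
		if err = visit(postings); err != nil {
			return err
		}
	}
	return nil
}
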
// Iterator returns an iterator for this dictionary
func (d *Dictionary) Iterator() segment.DictionaryIterator {
rv := &DictionaryIterator{
func (di *docValueIterator) loadDvChunk(chunkNumber,
localDocNum uint64, s *SegmentBase) error {
// advance to the chunk where the docValues
- // reside for the given docID
+ // reside for the given docNum
destChunkDataLoc := di.dvDataLoc
for i := 0; i < int(chunkNumber); i++ {
destChunkDataLoc += di.chunkLens[i]
offset := uint64(0)
di.curChunkHeader = make([]MetaData, int(numDocs))
for i := 0; i < int(numDocs); i++ {
- di.curChunkHeader[i].DocID, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
+ di.curChunkHeader[i].DocNum, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
offset += uint64(read)
di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
offset += uint64(read)
return nil
}
-func (di *docValueIterator) visitDocValues(docID uint64,
+func (di *docValueIterator) visitDocValues(docNum uint64,
visitor index.DocumentFieldTermVisitor) error {
- // binary search the term locations for the docID
- start, length := di.getDocValueLocs(docID)
+ // binary search the term locations for the docNum
+ start, length := di.getDocValueLocs(docNum)
if start == math.MaxUint64 || length == math.MaxUint64 {
return nil
}
return err
}
- // pick the terms for the given docID
+ // pick the terms for the given docNum
uncompressed = uncompressed[start : start+length]
for {
i := bytes.Index(uncompressed, termSeparatorSplitSlice)
return nil
}
-func (di *docValueIterator) getDocValueLocs(docID uint64) (uint64, uint64) {
+func (di *docValueIterator) getDocValueLocs(docNum uint64) (uint64, uint64) {
i := sort.Search(len(di.curChunkHeader), func(i int) bool {
- return di.curChunkHeader[i].DocID >= docID
+ return di.curChunkHeader[i].DocNum >= docNum
})
- if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocID == docID {
+ if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocNum == docNum {
return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
}
return math.MaxUint64, math.MaxUint64
--- /dev/null
+// Copyright (c) 2018 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package zap
+
+import (
+ "bytes"
+
+ "github.com/couchbase/vellum"
+)
+
+// enumerator provides an ordered traversal of multiple vellum
+// iterators. Like JOIN of iterators, the enumerator produces a
+// sequence of (key, iteratorIndex, value) tuples, sorted by key ASC,
+// then iteratorIndex ASC, where the same key might be seen or
+// repeated across multiple child iterators.
+type enumerator struct {
+ itrs []vellum.Iterator
+ currKs [][]byte
+ currVs []uint64
+
+ lowK []byte
+ lowIdxs []int
+ lowCurr int
+}
+
+// newEnumerator returns a new enumerator over the vellum Iterators
+func newEnumerator(itrs []vellum.Iterator) (*enumerator, error) {
+ rv := &enumerator{
+ itrs: itrs,
+ currKs: make([][]byte, len(itrs)),
+ currVs: make([]uint64, len(itrs)),
+ lowIdxs: make([]int, 0, len(itrs)),
+ }
+ for i, itr := range rv.itrs {
+ rv.currKs[i], rv.currVs[i] = itr.Current()
+ }
+ rv.updateMatches()
+ if rv.lowK == nil {
+ return rv, vellum.ErrIteratorDone
+ }
+ return rv, nil
+}
+
+// updateMatches maintains the low key matches based on the currKs
+func (m *enumerator) updateMatches() {
+ m.lowK = nil
+ m.lowIdxs = m.lowIdxs[:0]
+ m.lowCurr = 0
+
+ for i, key := range m.currKs {
+ if key == nil {
+ continue
+ }
+
+ cmp := bytes.Compare(key, m.lowK)
+ if cmp < 0 || m.lowK == nil {
+ // reached a new low
+ m.lowK = key
+ m.lowIdxs = m.lowIdxs[:0]
+ m.lowIdxs = append(m.lowIdxs, i)
+ } else if cmp == 0 {
+ m.lowIdxs = append(m.lowIdxs, i)
+ }
+ }
+}
+
+// Current returns the enumerator's current key, iterator-index, and
+// value. If the enumerator is not pointing at a valid value (because
+// Next returned an error previously), Current will return nil,0,0.
+func (m *enumerator) Current() ([]byte, int, uint64) {
+ var i int
+ var v uint64
+ if m.lowCurr < len(m.lowIdxs) {
+ i = m.lowIdxs[m.lowCurr]
+ v = m.currVs[i]
+ }
+ return m.lowK, i, v
+}
+
+// Next advances the enumerator to the next key/iterator/value result,
+// returning vellum.ErrIteratorDone once all iterators are exhausted.
+func (m *enumerator) Next() error {
+ m.lowCurr += 1
+ if m.lowCurr >= len(m.lowIdxs) {
+ // move all the current low iterators forwards
+ for _, vi := range m.lowIdxs {
+ err := m.itrs[vi].Next()
+ if err != nil && err != vellum.ErrIteratorDone {
+ return err
+ }
+ m.currKs[vi], m.currVs[vi] = m.itrs[vi].Current()
+ }
+ m.updateMatches()
+ }
+ if m.lowK == nil {
+ return vellum.ErrIteratorDone
+ }
+ return nil
+}
+
+// Close all the underlying Iterators. The first error, if any, will
+// be returned.
+func (m *enumerator) Close() error {
+ var rv error
+ for _, itr := range m.itrs {
+ err := itr.Close()
+ if rv == nil {
+ rv = err
+ }
+ }
+ return rv
+}
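
An in-package sketch of how the enumerator is meant to be driven, mirroring its use in persistMergedRest (itrs would be the per-segment vellum iterators for one field; the visit callback is a placeholder):

func walkField(itrs []vellum.Iterator, visit func(key []byte, itrI int, val uint64)) error {
	en, err := newEnumerator(itrs)
	for err == nil {
		k, itrI, v := en.Current()
		// a key present in several iterators is surfaced once per
		// source iterator, in ascending iterator order
		visit(k, itrI, v)
		err = en.Next()
	}
	closeErr := en.Close()
	if err != vellum.ErrIteratorDone {
		return err
	}
	return closeErr
}
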
encoder *govarint.Base128Encoder
chunkLens []uint64
currChunk uint64
+
+ buf []byte
}
// newChunkedIntCoder returns a new chunk int coder which packs data into
// starting a new chunk
if c.encoder != nil {
// close out last
- c.encoder.Close()
- encodingBytes := c.chunkBuf.Bytes()
- c.chunkLens[c.currChunk] = uint64(len(encodingBytes))
- c.final = append(c.final, encodingBytes...)
+ c.Close()
c.chunkBuf.Reset()
- c.encoder = govarint.NewU64Base128Encoder(&c.chunkBuf)
}
c.currChunk = chunk
}
// Write commits all the encoded chunked integers to the provided writer.
func (c *chunkedIntCoder) Write(w io.Writer) (int, error) {
- var tw int
- buf := make([]byte, binary.MaxVarintLen64)
- // write out the number of chunks
+ bufNeeded := binary.MaxVarintLen64 * (1 + len(c.chunkLens))
+ if len(c.buf) < bufNeeded {
+ c.buf = make([]byte, bufNeeded)
+ }
+ buf := c.buf
+
+ // write out the number of chunks & each chunkLen
n := binary.PutUvarint(buf, uint64(len(c.chunkLens)))
- nw, err := w.Write(buf[:n])
- tw += nw
+ for _, chunkLen := range c.chunkLens {
+ n += binary.PutUvarint(buf[n:], uint64(chunkLen))
+ }
+
+ tw, err := w.Write(buf[:n])
if err != nil {
return tw, err
}
- // write out the chunk lens
- for _, chunkLen := range c.chunkLens {
- n := binary.PutUvarint(buf, uint64(chunkLen))
- nw, err = w.Write(buf[:n])
- tw += nw
- if err != nil {
- return tw, err
- }
- }
+
// write out the data
- nw, err = w.Write(c.final)
+ nw, err := w.Write(c.final)
tw += nw
if err != nil {
return tw, err
"fmt"
"math"
"os"
+ "sort"
"github.com/RoaringBitmap/roaring"
"github.com/Smerity/govarint"
"github.com/golang/snappy"
)
+const docDropped = math.MaxUint64 // sentinel docNum to represent a deleted doc
+
// Merge takes a slice of zap segments and bit masks describing which
// documents may be dropped, and creates a new segment containing the
// remaining data. This new segment is built at the specified path,
_ = os.Remove(path)
}
+ segmentBases := make([]*SegmentBase, len(segments))
+ for segmenti, segment := range segments {
+ segmentBases[segmenti] = &segment.SegmentBase
+ }
+
// buffer the output
br := bufio.NewWriter(f)
// wrap it for counting (tracking offsets)
cr := NewCountHashWriter(br)
- fieldsInv := mergeFields(segments)
- fieldsMap := mapFields(fieldsInv)
-
- var newDocNums [][]uint64
- var storedIndexOffset uint64
- fieldDvLocsOffset := uint64(fieldNotUninverted)
- var dictLocs []uint64
-
- newSegDocCount := computeNewDocCount(segments, drops)
- if newSegDocCount > 0 {
- storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
- fieldsMap, fieldsInv, newSegDocCount, cr)
- if err != nil {
- cleanup()
- return nil, err
- }
-
- dictLocs, fieldDvLocsOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
- newDocNums, newSegDocCount, chunkFactor, cr)
- if err != nil {
- cleanup()
- return nil, err
- }
- } else {
- dictLocs = make([]uint64, len(fieldsInv))
- }
-
- fieldsIndexOffset, err := persistFields(fieldsInv, cr, dictLocs)
+ newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, _, _, _, err :=
+ MergeToWriter(segmentBases, drops, chunkFactor, cr)
if err != nil {
cleanup()
return nil, err
}
- err = persistFooter(newSegDocCount, storedIndexOffset,
- fieldsIndexOffset, fieldDvLocsOffset, chunkFactor, cr.Sum32(), cr)
+ err = persistFooter(numDocs, storedIndexOffset, fieldsIndexOffset,
+ docValueOffset, chunkFactor, cr.Sum32(), cr)
if err != nil {
cleanup()
return nil, err
return newDocNums, nil
}
-// mapFields takes the fieldsInv list and builds the map
+func MergeToWriter(segments []*SegmentBase, drops []*roaring.Bitmap,
+ chunkFactor uint32, cr *CountHashWriter) (
+ newDocNums [][]uint64,
+ numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset uint64,
+ dictLocs []uint64, fieldsInv []string, fieldsMap map[string]uint16,
+ err error) {
+ docValueOffset = uint64(fieldNotUninverted)
+
+ var fieldsSame bool
+ fieldsSame, fieldsInv = mergeFields(segments)
+ fieldsMap = mapFields(fieldsInv)
+
+ numDocs = computeNewDocCount(segments, drops)
+ if numDocs > 0 {
+ storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
+ fieldsMap, fieldsInv, fieldsSame, numDocs, cr)
+ if err != nil {
+ return nil, 0, 0, 0, 0, nil, nil, nil, err
+ }
+
+ dictLocs, docValueOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
+ newDocNums, numDocs, chunkFactor, cr)
+ if err != nil {
+ return nil, 0, 0, 0, 0, nil, nil, nil, err
+ }
+ } else {
+ dictLocs = make([]uint64, len(fieldsInv))
+ }
+
+ fieldsIndexOffset, err = persistFields(fieldsInv, cr, dictLocs)
+ if err != nil {
+ return nil, 0, 0, 0, 0, nil, nil, nil, err
+ }
+
+ return newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs, fieldsInv, fieldsMap, nil
+}
+
+// mapFields takes the fieldsInv list and returns a map of fieldName
+// to fieldID+1
func mapFields(fields []string) map[string]uint16 {
rv := make(map[string]uint16, len(fields))
for i, fieldName := range fields {
- rv[fieldName] = uint16(i)
+ rv[fieldName] = uint16(i) + 1
}
return rv
}
// computeNewDocCount determines how many documents will be in the newly
// merged segment when obsoleted docs are dropped
-func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 {
+func computeNewDocCount(segments []*SegmentBase, drops []*roaring.Bitmap) uint64 {
var newDocCount uint64
for segI, segment := range segments {
- newDocCount += segment.NumDocs()
+ newDocCount += segment.numDocs
if drops[segI] != nil {
newDocCount -= drops[segI].GetCardinality()
}
return newDocCount
}
-func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
- fieldsInv []string, fieldsMap map[string]uint16, newDocNums [][]uint64,
+func persistMergedRest(segments []*SegmentBase, dropsIn []*roaring.Bitmap,
+ fieldsInv []string, fieldsMap map[string]uint16, newDocNumsIn [][]uint64,
newSegDocCount uint64, chunkFactor uint32,
w *CountHashWriter) ([]uint64, uint64, error) {
var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64)
var bufLoc []uint64
+ var postings *PostingsList
+ var postItr *PostingsIterator
+
rv := make([]uint64, len(fieldsInv))
fieldDvLocs := make([]uint64, len(fieldsInv))
- fieldDvLocsOffset := uint64(fieldNotUninverted)
+
+ tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
+ locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
// docTermMap is keyed by docNum, where the array impl provides
// better memory usage behavior than a sparse-friendlier hashmap
return nil, 0, err
}
- // collect FST iterators from all segments for this field
+ // collect FST iterators from all active segments for this field
+ var newDocNums [][]uint64
+ var drops []*roaring.Bitmap
var dicts []*Dictionary
var itrs []vellum.Iterator
- for _, segment := range segments {
+
+ for segmentI, segment := range segments {
dict, err2 := segment.dictionary(fieldName)
if err2 != nil {
return nil, 0, err2
}
- dicts = append(dicts, dict)
-
if dict != nil && dict.fst != nil {
itr, err2 := dict.fst.Iterator(nil, nil)
if err2 != nil && err2 != vellum.ErrIteratorDone {
return nil, 0, err2
}
if itr != nil {
+ newDocNums = append(newDocNums, newDocNumsIn[segmentI])
+ drops = append(drops, dropsIn[segmentI])
+ dicts = append(dicts, dict)
itrs = append(itrs, itr)
}
}
}
- // create merging iterator
- mergeItr, err := vellum.NewMergeIterator(itrs, func(postingOffsets []uint64) uint64 {
- // we don't actually use the merged value
- return 0
- })
-
- tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
- locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
-
if uint64(cap(docTermMap)) < newSegDocCount {
docTermMap = make([][]byte, newSegDocCount)
} else {
}
}
- for err == nil {
- term, _ := mergeItr.Current()
-
- newRoaring := roaring.NewBitmap()
- newRoaringLocs := roaring.NewBitmap()
-
- tfEncoder.Reset()
- locEncoder.Reset()
-
- // now go back and get posting list for this term
- // but pass in the deleted docs for that segment
- for dictI, dict := range dicts {
- if dict == nil {
- continue
- }
- postings, err2 := dict.postingsList(term, drops[dictI])
- if err2 != nil {
- return nil, 0, err2
- }
-
- postItr := postings.Iterator()
- next, err2 := postItr.Next()
- for next != nil && err2 == nil {
- hitNewDocNum := newDocNums[dictI][next.Number()]
- if hitNewDocNum == docDropped {
- return nil, 0, fmt.Errorf("see hit with dropped doc num")
- }
- newRoaring.Add(uint32(hitNewDocNum))
- // encode norm bits
- norm := next.Norm()
- normBits := math.Float32bits(float32(norm))
- err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits))
- if err != nil {
- return nil, 0, err
- }
- locs := next.Locations()
- if len(locs) > 0 {
- newRoaringLocs.Add(uint32(hitNewDocNum))
- for _, loc := range locs {
- if cap(bufLoc) < 5+len(loc.ArrayPositions()) {
- bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
- }
- args := bufLoc[0:5]
- args[0] = uint64(fieldsMap[loc.Field()])
- args[1] = loc.Pos()
- args[2] = loc.Start()
- args[3] = loc.End()
- args[4] = uint64(len(loc.ArrayPositions()))
- args = append(args, loc.ArrayPositions()...)
- err = locEncoder.Add(hitNewDocNum, args...)
- if err != nil {
- return nil, 0, err
- }
- }
- }
+ var prevTerm []byte
- docTermMap[hitNewDocNum] =
- append(append(docTermMap[hitNewDocNum], term...), termSeparator)
+ newRoaring := roaring.NewBitmap()
+ newRoaringLocs := roaring.NewBitmap()
- next, err2 = postItr.Next()
- }
- if err2 != nil {
- return nil, 0, err2
- }
+ finishTerm := func(term []byte) error {
+ if term == nil {
+ return nil
}
tfEncoder.Close()
if newRoaring.GetCardinality() > 0 {
// this field/term actually has hits in the new segment, lets write it down
freqOffset := uint64(w.Count())
- _, err = tfEncoder.Write(w)
+ _, err := tfEncoder.Write(w)
if err != nil {
- return nil, 0, err
+ return err
}
locOffset := uint64(w.Count())
_, err = locEncoder.Write(w)
if err != nil {
- return nil, 0, err
+ return err
}
postingLocOffset := uint64(w.Count())
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64)
if err != nil {
- return nil, 0, err
+ return err
}
postingOffset := uint64(w.Count())
+
// write out the start of the term info
- buf := bufMaxVarintLen64
- n := binary.PutUvarint(buf, freqOffset)
- _, err = w.Write(buf[:n])
+ n := binary.PutUvarint(bufMaxVarintLen64, freqOffset)
+ _, err = w.Write(bufMaxVarintLen64[:n])
if err != nil {
- return nil, 0, err
+ return err
}
-
// write out the start of the loc info
- n = binary.PutUvarint(buf, locOffset)
- _, err = w.Write(buf[:n])
+ n = binary.PutUvarint(bufMaxVarintLen64, locOffset)
+ _, err = w.Write(bufMaxVarintLen64[:n])
if err != nil {
- return nil, 0, err
+ return err
}
-
- // write out the start of the loc posting list
- n = binary.PutUvarint(buf, postingLocOffset)
- _, err = w.Write(buf[:n])
+ // write out the start of the posting locs
+ n = binary.PutUvarint(bufMaxVarintLen64, postingLocOffset)
+ _, err = w.Write(bufMaxVarintLen64[:n])
if err != nil {
- return nil, 0, err
+ return err
}
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64)
if err != nil {
- return nil, 0, err
+ return err
}
err = newVellum.Insert(term, postingOffset)
+ if err != nil {
+ return err
+ }
+ }
+
+ newRoaring = roaring.NewBitmap()
+ newRoaringLocs = roaring.NewBitmap()
+
+ tfEncoder.Reset()
+ locEncoder.Reset()
+
+ return nil
+ }
+
+ enumerator, err := newEnumerator(itrs)
+
+ for err == nil {
+ term, itrI, postingsOffset := enumerator.Current()
+
+ if !bytes.Equal(prevTerm, term) {
+ // if the term changed, write out the info collected
+ // for the previous term
+ err2 := finishTerm(prevTerm)
+ if err2 != nil {
+ return nil, 0, err2
+ }
+ }
+
+ var err2 error
+ postings, err2 = dicts[itrI].postingsListFromOffset(
+ postingsOffset, drops[itrI], postings)
+ if err2 != nil {
+ return nil, 0, err2
+ }
+
+ newDocNumsI := newDocNums[itrI]
+
+ postItr = postings.iterator(postItr)
+ next, err2 := postItr.Next()
+ for next != nil && err2 == nil {
+ hitNewDocNum := newDocNumsI[next.Number()]
+ if hitNewDocNum == docDropped {
+ return nil, 0, fmt.Errorf("see hit with dropped doc num")
+ }
+ newRoaring.Add(uint32(hitNewDocNum))
+ // encode norm bits
+ norm := next.Norm()
+ normBits := math.Float32bits(float32(norm))
+ err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits))
if err != nil {
return nil, 0, err
}
+ locs := next.Locations()
+ if len(locs) > 0 {
+ newRoaringLocs.Add(uint32(hitNewDocNum))
+ for _, loc := range locs {
+ if cap(bufLoc) < 5+len(loc.ArrayPositions()) {
+ bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
+ }
+ args := bufLoc[0:5]
+ args[0] = uint64(fieldsMap[loc.Field()] - 1)
+ args[1] = loc.Pos()
+ args[2] = loc.Start()
+ args[3] = loc.End()
+ args[4] = uint64(len(loc.ArrayPositions()))
+ args = append(args, loc.ArrayPositions()...)
+ err = locEncoder.Add(hitNewDocNum, args...)
+ if err != nil {
+ return nil, 0, err
+ }
+ }
+ }
+
+ docTermMap[hitNewDocNum] =
+ append(append(docTermMap[hitNewDocNum], term...), termSeparator)
+
+ next, err2 = postItr.Next()
+ }
+ if err2 != nil {
+ return nil, 0, err2
}
- err = mergeItr.Next()
+ prevTerm = prevTerm[:0] // copy to prevTerm in case Next() reuses term mem
+ prevTerm = append(prevTerm, term...)
+
+ err = enumerator.Next()
}
if err != nil && err != vellum.ErrIteratorDone {
return nil, 0, err
}
+ err = finishTerm(prevTerm)
+ if err != nil {
+ return nil, 0, err
+ }
+
dictOffset := uint64(w.Count())
err = newVellum.Close()
}
}
- fieldDvLocsOffset = uint64(w.Count())
+ fieldDvLocsOffset := uint64(w.Count())
buf := bufMaxVarintLen64
for _, offset := range fieldDvLocs {
return rv, fieldDvLocsOffset, nil
}
-const docDropped = math.MaxUint64
-
-func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
- fieldsMap map[string]uint16, fieldsInv []string, newSegDocCount uint64,
+func mergeStoredAndRemap(segments []*SegmentBase, drops []*roaring.Bitmap,
+ fieldsMap map[string]uint16, fieldsInv []string, fieldsSame bool, newSegDocCount uint64,
w *CountHashWriter) (uint64, [][]uint64, error) {
var rv [][]uint64 // The remapped or newDocNums for each segment.
for segI, segment := range segments {
segNewDocNums := make([]uint64, segment.numDocs)
+ dropsI := drops[segI]
+
+ // optimize when the field mapping is the same across all
+ // segments and there are no deletions, via byte-copying
+ // of stored docs bytes directly to the writer
+ if fieldsSame && (dropsI == nil || dropsI.GetCardinality() == 0) {
+ err := segment.copyStoredDocs(newDocNum, docNumOffsets, w)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ for i := uint64(0); i < segment.numDocs; i++ {
+ segNewDocNums[i] = newDocNum
+ newDocNum++
+ }
+ rv = append(rv, segNewDocNums)
+
+ continue
+ }
+
// for each doc num
for docNum := uint64(0); docNum < segment.numDocs; docNum++ {
// TODO: roaring's API limits docNums to 32-bits?
- if drops[segI] != nil && drops[segI].Contains(uint32(docNum)) {
+ if dropsI != nil && dropsI.Contains(uint32(docNum)) {
segNewDocNums[docNum] = docDropped
continue
}
poss[i] = poss[i][:0]
}
err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool {
- fieldID := int(fieldsMap[field])
+ fieldID := int(fieldsMap[field]) - 1
vals[fieldID] = append(vals[fieldID], value)
typs[fieldID] = append(typs[fieldID], typ)
poss[fieldID] = append(poss[fieldID], pos)
for fieldID := range fieldsInv {
storedFieldValues := vals[int(fieldID)]
- // has stored values for this field
- num := len(storedFieldValues)
+ stf := typs[int(fieldID)]
+ spf := poss[int(fieldID)]
- // process each value
- for i := 0; i < num; i++ {
- // encode field
- _, err2 := metaEncoder.PutU64(uint64(fieldID))
- if err2 != nil {
- return 0, nil, err2
- }
- // encode type
- _, err2 = metaEncoder.PutU64(uint64(typs[int(fieldID)][i]))
- if err2 != nil {
- return 0, nil, err2
- }
- // encode start offset
- _, err2 = metaEncoder.PutU64(uint64(curr))
- if err2 != nil {
- return 0, nil, err2
- }
- // end len
- _, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
- if err2 != nil {
- return 0, nil, err2
- }
- // encode number of array pos
- _, err2 = metaEncoder.PutU64(uint64(len(poss[int(fieldID)][i])))
- if err2 != nil {
- return 0, nil, err2
- }
- // encode all array positions
- for j := 0; j < len(poss[int(fieldID)][i]); j++ {
- _, err2 = metaEncoder.PutU64(poss[int(fieldID)][i][j])
- if err2 != nil {
- return 0, nil, err2
- }
- }
- // append data
- data = append(data, storedFieldValues[i]...)
- // update curr
- curr += len(storedFieldValues[i])
+ var err2 error
+ curr, data, err2 = persistStoredFieldValues(fieldID,
+ storedFieldValues, stf, spf, curr, metaEncoder, data)
+ if err2 != nil {
+ return 0, nil, err2
}
}
}
// return value is the start of the stored index
- offset := uint64(w.Count())
+ storedIndexOffset := uint64(w.Count())
// now write out the stored doc index
- for docNum := range docNumOffsets {
- err := binary.Write(w, binary.BigEndian, docNumOffsets[docNum])
+ for _, docNumOffset := range docNumOffsets {
+ err := binary.Write(w, binary.BigEndian, docNumOffset)
if err != nil {
return 0, nil, err
}
}
- return offset, rv, nil
+ return storedIndexOffset, rv, nil
}
-// mergeFields builds a unified list of fields used across all the input segments
-func mergeFields(segments []*Segment) []string {
- fieldsMap := map[string]struct{}{}
+// copyStoredDocs writes out a segment's stored doc info, optimized by
+// using a single Write() call for the entire set of bytes. The
+// newDocNumOffsets is filled with the new offsets for each doc.
+func (s *SegmentBase) copyStoredDocs(newDocNum uint64, newDocNumOffsets []uint64,
+ w *CountHashWriter) error {
+ if s.numDocs <= 0 {
+ return nil
+ }
+
+ indexOffset0, storedOffset0, _, _, _ :=
+ s.getDocStoredOffsets(0) // the segment's first doc
+
+ indexOffsetN, storedOffsetN, readN, metaLenN, dataLenN :=
+ s.getDocStoredOffsets(s.numDocs - 1) // the segment's last doc
+
+ storedOffset0New := uint64(w.Count())
+
+ storedBytes := s.mem[storedOffset0 : storedOffsetN+readN+metaLenN+dataLenN]
+ _, err := w.Write(storedBytes)
+ if err != nil {
+ return err
+ }
+
+	// remap the storedOffsets for the docs into new offsets relative
+	// to storedOffset0New, filling the given newDocNumOffsets array
+ for indexOffset := indexOffset0; indexOffset <= indexOffsetN; indexOffset += 8 {
+ storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
+ storedOffsetNew := storedOffset - storedOffset0 + storedOffset0New
+ newDocNumOffsets[newDocNum] = storedOffsetNew
+ newDocNum += 1
+ }
+
+ return nil
+}
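The remapping above is plain offset translation: each stored offset keeps its distance from the segment's first stored doc and is rebased onto the writer's current position. A minimal sketch with hypothetical offsets (the numbers are made up, not taken from a real segment):

package main

import "fmt"

func main() {
	// Hypothetical stored-doc offsets for three docs in the source segment.
	oldOffsets := []uint64{100, 140, 190}

	storedOffset0 := oldOffsets[0]   // offset of the segment's first doc
	storedOffset0New := uint64(5000) // writer position where the bytes landed

	for _, off := range oldOffsets {
		// Same translation as copyStoredDocs: keep each doc's distance from
		// the first doc, rebased onto the new start.
		fmt.Println(off - storedOffset0 + storedOffset0New) // 5000, 5040, 5090
	}
}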
+
+// mergeFields builds a unified list of fields used across all the
+// input segments, and computes whether the fields are the same across
+// segments (which depends on the fields being sorted the same way
+// across segments)
+func mergeFields(segments []*SegmentBase) (bool, []string) {
+ fieldsSame := true
+
+ var segment0Fields []string
+ if len(segments) > 0 {
+ segment0Fields = segments[0].Fields()
+ }
+
+ fieldsExist := map[string]struct{}{}
for _, segment := range segments {
fields := segment.Fields()
- for _, field := range fields {
- fieldsMap[field] = struct{}{}
+ for fieldi, field := range fields {
+ fieldsExist[field] = struct{}{}
+ if len(segment0Fields) != len(fields) || segment0Fields[fieldi] != field {
+ fieldsSame = false
+ }
}
}
- rv := make([]string, 0, len(fieldsMap))
+ rv := make([]string, 0, len(fieldsExist))
// ensure _id stays first
rv = append(rv, "_id")
- for k := range fieldsMap {
+ for k := range fieldsExist {
if k != "_id" {
rv = append(rv, k)
}
}
- return rv
+
+ sort.Strings(rv[1:]) // leave _id as first
+
+ return fieldsSame, rv
}
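fieldsSame only ends up true when every segment reports an identical, identically ordered field list. A sketch of the same criterion over plain string slices (toy data, not the real SegmentBase API):

package main

import "fmt"

// sameFields mirrors the mergeFields criterion: same length and same order
// across every segment's field list.
func sameFields(segments [][]string) bool {
	if len(segments) == 0 {
		return true
	}
	first := segments[0]
	for _, fields := range segments {
		if len(fields) != len(first) {
			return false
		}
		for i, f := range fields {
			if first[i] != f {
				return false
			}
		}
	}
	return true
}

func main() {
	fmt.Println(sameFields([][]string{{"_id", "name"}, {"_id", "name"}})) // true
	fmt.Println(sameFields([][]string{{"_id", "name"}, {"_id", "desc"}})) // false
}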
// PostingsList is an in-memory representation of a postings list
type PostingsList struct {
sb *SegmentBase
- term []byte
postingsOffset uint64
freqOffset uint64
locOffset uint64
locBitmap *roaring.Bitmap
postings *roaring.Bitmap
except *roaring.Bitmap
- postingKey []byte
}
// Iterator returns an iterator for this postings list
func (p *PostingsList) Iterator() segment.PostingsIterator {
- rv := &PostingsIterator{
- postings: p,
+ return p.iterator(nil)
+}
+
+func (p *PostingsList) iterator(rv *PostingsIterator) *PostingsIterator {
+ if rv == nil {
+ rv = &PostingsIterator{}
+ } else {
+ *rv = PostingsIterator{} // clear the struct
}
+ rv.postings = p
+
if p.postings != nil {
// prepare the freq chunk details
var n uint64
import "encoding/binary"
func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) {
- docStoredStartAddr := s.storedIndexOffset + (8 * docNum)
- docStoredStart := binary.BigEndian.Uint64(s.mem[docStoredStartAddr : docStoredStartAddr+8])
+ _, storedOffset, n, metaLen, dataLen := s.getDocStoredOffsets(docNum)
+
+ meta := s.mem[storedOffset+n : storedOffset+n+metaLen]
+ data := s.mem[storedOffset+n+metaLen : storedOffset+n+metaLen+dataLen]
+
+ return meta, data
+}
+
+func (s *SegmentBase) getDocStoredOffsets(docNum uint64) (
+ uint64, uint64, uint64, uint64, uint64) {
+ indexOffset := s.storedIndexOffset + (8 * docNum)
+
+ storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
+
var n uint64
- metaLen, read := binary.Uvarint(s.mem[docStoredStart : docStoredStart+binary.MaxVarintLen64])
+
+ metaLen, read := binary.Uvarint(s.mem[storedOffset : storedOffset+binary.MaxVarintLen64])
n += uint64(read)
- var dataLen uint64
- dataLen, read = binary.Uvarint(s.mem[docStoredStart+n : docStoredStart+n+binary.MaxVarintLen64])
+
+ dataLen, read := binary.Uvarint(s.mem[storedOffset+n : storedOffset+n+binary.MaxVarintLen64])
n += uint64(read)
- meta := s.mem[docStoredStart+n : docStoredStart+n+metaLen]
- data := s.mem[docStoredStart+n+metaLen : docStoredStart+n+metaLen+dataLen]
- return meta, data
+
+ return indexOffset, storedOffset, n, metaLen, dataLen
}
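The layout this decodes is: one 8-byte big-endian index entry per doc pointing at its stored record, and each record starting with two uvarints (meta length, data length) followed by the meta and data bytes. A standalone sketch that builds and decodes one such record from a byte slice (hand-built here; in the real segment the bytes live in s.mem at storedOffset):

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	meta := []byte{0x01, 0x02}
	data := []byte("hello")

	// Build one stored record: uvarint(metaLen), uvarint(dataLen), meta, data.
	var rec []byte
	tmp := make([]byte, binary.MaxVarintLen64)
	rec = append(rec, tmp[:binary.PutUvarint(tmp, uint64(len(meta)))]...)
	rec = append(rec, tmp[:binary.PutUvarint(tmp, uint64(len(data)))]...)
	rec = append(rec, meta...)
	rec = append(rec, data...)

	// Decode it the way getDocStoredOffsets / getDocStoredMetaAndCompressed do.
	var n uint64
	metaLen, read := binary.Uvarint(rec)
	n += uint64(read)
	dataLen, read2 := binary.Uvarint(rec[n:])
	n += uint64(read2)

	fmt.Printf("meta=%v data=%q\n", rec[n:n+metaLen], rec[n+metaLen:n+metaLen+dataLen])
}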
return nil, err
}
+ var postings *PostingsList
for _, id := range ids {
- postings, err := idDict.postingsList([]byte(id), nil)
+ postings, err = idDict.postingsList([]byte(id), nil, postings)
if err != nil {
return nil, err
}
return r.meta[string(key)]
}
-// RollbackPoints returns an array of rollback points available
-// for the application to make a decision on where to rollback
-// to. A nil return value indicates that there are no available
-// rollback points.
+// RollbackPoints returns an array of rollback points available for
+// the application to rollback to, with more recent rollback points
+// (higher epochs) coming first.
func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
if s.rootBolt == nil {
return nil, fmt.Errorf("RollbackPoints: root is nil")
snapshots := tx.Bucket(boltSnapshotsBucket)
if snapshots == nil {
- return nil, fmt.Errorf("RollbackPoints: no snapshots available")
+ return nil, nil
}
rollbackPoints := []*RollbackPoint{}
revert.snapshot = indexSnapshot
revert.applied = make(chan error)
-
- if !s.unsafeBatch {
- revert.persisted = make(chan error)
- }
+ revert.persisted = make(chan error)
return nil
})
return fmt.Errorf("Rollback: failed with err: %v", err)
}
- if revert.persisted != nil {
- err = <-revert.persisted
- }
-
- return err
+ return <-revert.persisted
}
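With the persisted channel now always created, Rollback does not return until the revert has been persisted. A rough usage sketch against the scorch methods shown above, assuming you already hold a *scorch.Scorch handle (obtaining one from a bleve index is out of scope here):

package scorchutil

import (
	"log"

	"github.com/blevesearch/bleve/index/scorch"
)

// rollbackToLatest reverts the index to its most recent rollback point.
// RollbackPoints returns points with higher (more recent) epochs first.
func rollbackToLatest(s *scorch.Scorch) error {
	points, err := s.RollbackPoints()
	if err != nil {
		return err
	}
	if len(points) == 0 {
		log.Println("no rollback points available")
		return nil
	}
	return s.Rollback(points[0]) // points[0] is the most recent snapshot
}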
docBackIndexRowErr = err
return
}
+ defer func() {
+ if cerr := kvreader.Close(); err == nil && cerr != nil {
+ docBackIndexRowErr = cerr
+ }
+ }()
for docID, doc := range batch.IndexOps {
backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID))
docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow}
}
-
- err = kvreader.Close()
- if err != nil {
- docBackIndexRowErr = err
- return
- }
}()
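The deferred Close above only surfaces the close error when no earlier error was recorded, so the first failure wins. The same idiom in isolation, with a toy reader type rather than the KV store interface:

package main

import (
	"errors"
	"fmt"
)

type reader struct{ closeErr error }

func (r *reader) Close() error { return r.closeErr }

// readAll shows the "first error wins" close idiom used in the batch goroutine.
func readAll(r *reader) (err error) {
	defer func() {
		if cerr := r.Close(); err == nil && cerr != nil {
			err = cerr
		}
	}()
	// ... work that may set err ...
	return err
}

func main() {
	fmt.Println(readAll(&reader{closeErr: errors.New("close failed")}))
}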
// wait for analysis result
package bleve
import (
+ "context"
"sort"
"sync"
"time"
- "golang.org/x/net/context"
-
"github.com/blevesearch/bleve/document"
"github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/index/store"
package bleve
import (
+ "context"
"encoding/json"
"fmt"
"os"
"sync/atomic"
"time"
- "golang.org/x/net/context"
-
"github.com/blevesearch/bleve/document"
"github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/index/store"
package search
import (
+ "context"
"time"
"github.com/blevesearch/bleve/index"
-
- "golang.org/x/net/context"
)
type Collector interface {
package collector
import (
+ "context"
"time"
"github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/search"
- "golang.org/x/net/context"
)
type collectorStore interface {
package goth
import (
+ "context"
"fmt"
"net/http"
- "golang.org/x/net/context"
"golang.org/x/oauth2"
)
import (
"bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
"encoding/json"
"errors"
+ "fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
+ "strings"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "fmt"
"github.com/markbates/goth"
"golang.org/x/oauth2"
)
const (
authURL string = "https://www.facebook.com/dialog/oauth"
tokenURL string = "https://graph.facebook.com/oauth/access_token"
- endpointProfile string = "https://graph.facebook.com/me?fields=email,first_name,last_name,link,about,id,name,picture,location"
+ endpointProfile string = "https://graph.facebook.com/me?fields="
)
// New creates a new Facebook provider, and sets up important connection details.
// BeginAuth asks Facebook for an authentication end-point.
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
- url := p.config.AuthCodeURL(state)
+ authUrl := p.config.AuthCodeURL(state)
session := &Session{
- AuthURL: url,
+ AuthURL: authUrl,
}
return session, nil
}
hash.Write([]byte(sess.AccessToken))
appsecretProof := hex.EncodeToString(hash.Sum(nil))
- response, err := p.Client().Get(endpointProfile + "&access_token=" + url.QueryEscape(sess.AccessToken) + "&appsecret_proof=" + appsecretProof)
+ reqUrl := fmt.Sprint(
+ endpointProfile,
+ strings.Join(p.config.Scopes, ","),
+ "&access_token=",
+ url.QueryEscape(sess.AccessToken),
+ "&appsecret_proof=",
+ appsecretProof,
+ )
+ response, err := p.Client().Get(reqUrl)
if err != nil {
return user, err
}
},
Scopes: []string{
"email",
+ "first_name",
+ "last_name",
+ "link",
+ "about",
+ "id",
+ "name",
+ "picture",
+ "location",
},
}
- defaultScopes := map[string]struct{}{
- "email": {},
- }
-
- for _, scope := range scopes {
- if _, exists := defaultScopes[scope]; !exists {
- c.Scopes = append(c.Scopes, scope)
+	// allow a field modifier like 'picture.type(large)' to replace the matching default scope
+ var found bool
+ for _, sc := range scopes {
+ sc := sc
+ for i, defScope := range c.Scopes {
+ if defScope == strings.Split(sc, ".")[0] {
+ c.Scopes[i] = sc
+ found = true
+ }
+ }
+ if !found {
+ c.Scopes = append(c.Scopes, sc)
}
+ found = false
}
return c
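Under this change the default scopes already cover the profile fields, and a caller-supplied scope such as 'picture.type(large)' replaces the matching default rather than being appended. A hedged usage sketch with the goth Facebook provider (key, secret and callback URL are placeholders):

package main

import (
	"fmt"

	"github.com/markbates/goth/providers/facebook"
)

func main() {
	// "picture.type(large)" overrides the default "picture" scope, so the
	// profile request asks Facebook for the large picture variant.
	p := facebook.New("client-key", "client-secret",
		"https://example.com/auth/facebook/callback", "picture.type(large)")
	fmt.Println(p.Name()) // "facebook"
}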
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.7
+
+// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
+package ctxhttp // import "golang.org/x/net/context/ctxhttp"
+
+import (
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "golang.org/x/net/context"
+)
+
+// Do sends an HTTP request with the provided http.Client and returns
+// an HTTP response.
+//
+// If the client is nil, http.DefaultClient is used.
+//
+// The provided ctx must be non-nil. If it is canceled or times out,
+// ctx.Err() will be returned.
+func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
+ if client == nil {
+ client = http.DefaultClient
+ }
+ resp, err := client.Do(req.WithContext(ctx))
+ // If we got an error, and the context has been canceled,
+ // the context's error is probably more useful.
+ if err != nil {
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ default:
+ }
+ }
+ return resp, err
+}
+
+// Get issues a GET request via the Do function.
+func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return Do(ctx, client, req)
+}
+
+// Head issues a HEAD request via the Do function.
+func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+ req, err := http.NewRequest("HEAD", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return Do(ctx, client, req)
+}
+
+// Post issues a POST request via the Do function.
+func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
+ req, err := http.NewRequest("POST", url, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", bodyType)
+ return Do(ctx, client, req)
+}
+
+// PostForm issues a POST request via the Do function.
+func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
+ return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
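This vendored helper is what the oauth2 internals below use for cancellable token requests. A minimal direct usage sketch (the URL is a placeholder):

package main

import (
	"fmt"
	"time"

	"golang.org/x/net/context"
	"golang.org/x/net/context/ctxhttp"
)

func main() {
	// Give the request one second; the context's error is returned if it
	// fires before the response arrives.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	resp, err := ctxhttp.Get(ctx, nil, "https://example.com/") // nil => http.DefaultClient
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}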
--- /dev/null
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.7
+
+package ctxhttp // import "golang.org/x/net/context/ctxhttp"
+
+import (
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "golang.org/x/net/context"
+)
+
+func nop() {}
+
+var (
+ testHookContextDoneBeforeHeaders = nop
+ testHookDoReturned = nop
+ testHookDidBodyClose = nop
+)
+
+// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
+// If the client is nil, http.DefaultClient is used.
+// If the context is canceled or times out, ctx.Err() will be returned.
+func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
+ if client == nil {
+ client = http.DefaultClient
+ }
+
+ // TODO(djd): Respect any existing value of req.Cancel.
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ type responseAndError struct {
+ resp *http.Response
+ err error
+ }
+ result := make(chan responseAndError, 1)
+
+ // Make local copies of test hooks closed over by goroutines below.
+ // Prevents data races in tests.
+ testHookDoReturned := testHookDoReturned
+ testHookDidBodyClose := testHookDidBodyClose
+
+ go func() {
+ resp, err := client.Do(req)
+ testHookDoReturned()
+ result <- responseAndError{resp, err}
+ }()
+
+ var resp *http.Response
+
+ select {
+ case <-ctx.Done():
+ testHookContextDoneBeforeHeaders()
+ close(cancel)
+ // Clean up after the goroutine calling client.Do:
+ go func() {
+ if r := <-result; r.resp != nil {
+ testHookDidBodyClose()
+ r.resp.Body.Close()
+ }
+ }()
+ return nil, ctx.Err()
+ case r := <-result:
+ var err error
+ resp, err = r.resp, r.err
+ if err != nil {
+ return resp, err
+ }
+ }
+
+ c := make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ close(cancel)
+ case <-c:
+ // The response's Body is closed.
+ }
+ }()
+	resp.Body = &notifyingReader{resp.Body, c}
+
+ return resp, nil
+}
+
+// Get issues a GET request via the Do function.
+func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return Do(ctx, client, req)
+}
+
+// Head issues a HEAD request via the Do function.
+func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
+ req, err := http.NewRequest("HEAD", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return Do(ctx, client, req)
+}
+
+// Post issues a POST request via the Do function.
+func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
+ req, err := http.NewRequest("POST", url, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", bodyType)
+ return Do(ctx, client, req)
+}
+
+// PostForm issues a POST request via the Do function.
+func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
+ return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
+
+// notifyingReader is an io.ReadCloser that closes the notify channel after
+// Close is called or a Read fails on the underlying ReadCloser.
+type notifyingReader struct {
+ io.ReadCloser
+ notify chan<- struct{}
+}
+
+func (r *notifyingReader) Read(p []byte) (int, error) {
+ n, err := r.ReadCloser.Read(p)
+ if err != nil && r.notify != nil {
+ close(r.notify)
+ r.notify = nil
+ }
+ return n, err
+}
+
+func (r *notifyingReader) Close() error {
+ err := r.ReadCloser.Close()
+ if r.notify != nil {
+ close(r.notify)
+ r.notify = nil
+ }
+ return err
+}
-Copyright (c) 2009 The oauth2 Authors. All rights reserved.
+Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+++ /dev/null
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build appengine
-
-// App Engine hooks.
-
-package oauth2
-
-import (
- "net/http"
-
- "golang.org/x/net/context"
- "golang.org/x/oauth2/internal"
- "google.golang.org/appengine/urlfetch"
-)
-
-func init() {
- internal.RegisterContextClientFunc(contextClientAppEngine)
-}
-
-func contextClientAppEngine(ctx context.Context) (*http.Client, error) {
- return urlfetch.Client(ctx), nil
-}
--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build appengine
+
+package internal
+
+import "google.golang.org/appengine/urlfetch"
+
+func init() {
+ appengineClientHook = urlfetch.Client
+}
--- /dev/null
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package internal contains support packages for oauth2 package.
+package internal
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package internal contains support packages for oauth2 package.
package internal
import (
- "bufio"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
- "io"
- "strings"
)
// ParseKey converts the binary contents of a private key file
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
if err != nil {
- return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err)
+ return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err)
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
}
return parsed, nil
}
-
-func ParseINI(ini io.Reader) (map[string]map[string]string, error) {
- result := map[string]map[string]string{
- "": map[string]string{}, // root section
- }
- scanner := bufio.NewScanner(ini)
- currentSection := ""
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if strings.HasPrefix(line, ";") {
- // comment.
- continue
- }
- if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
- currentSection = strings.TrimSpace(line[1 : len(line)-1])
- result[currentSection] = map[string]string{}
- continue
- }
- parts := strings.SplitN(line, "=", 2)
- if len(parts) == 2 && parts[0] != "" {
- result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
- }
- }
- if err := scanner.Err(); err != nil {
- return nil, fmt.Errorf("error scanning ini: %v", err)
- }
- return result, nil
-}
-
-func CondVal(v string) []string {
- if v == "" {
- return nil
- }
- return []string{v}
-}
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package internal contains support packages for oauth2 package.
package internal
import (
+ "context"
"encoding/json"
+ "errors"
"fmt"
"io"
"io/ioutil"
"strings"
"time"
- "golang.org/x/net/context"
+ "golang.org/x/net/context/ctxhttp"
)
-// Token represents the crendentials used to authorize
+// Token represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
//
var brokenAuthHeaderProviders = []string{
"https://accounts.google.com/",
+ "https://api.codeswholesale.com/oauth/token",
"https://api.dropbox.com/",
"https://api.dropboxapi.com/",
"https://api.instagram.com/",
"https://api.pushbullet.com/",
"https://api.soundcloud.com/",
"https://api.twitch.tv/",
+ "https://id.twitch.tv/",
"https://app.box.com/",
+ "https://api.box.com/",
"https://connect.stripe.com/",
+ "https://login.mailchimp.com/",
"https://login.microsoftonline.com/",
"https://login.salesforce.com/",
+ "https://login.windows.net",
+ "https://login.live.com/",
+ "https://login.live-int.com/",
"https://oauth.sandbox.trainingpeaks.com/",
"https://oauth.trainingpeaks.com/",
"https://oauth.vk.com/",
"https://www.strava.com/oauth/",
"https://www.wunderlist.com/oauth/",
"https://api.patreon.com/",
+ "https://sandbox.codeswholesale.com/oauth/token",
+ "https://api.sipgate.com/v1/authorization/oauth",
+ "https://api.medium.com/v1/tokens",
+ "https://log.finalsurge.com/oauth/token",
+ "https://multisport.todaysplan.com.au/rest/oauth/access_token",
+ "https://whats.todaysplan.com.au/rest/oauth/access_token",
+ "https://stackoverflow.com/oauth/access_token",
+ "https://account.health.nokia.com",
+ "https://accounts.zoho.com",
+}
+
+// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
+var brokenAuthHeaderDomains = []string{
+ ".auth0.com",
+ ".force.com",
+ ".myshopify.com",
+ ".okta.com",
+ ".oktapreview.com",
}
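These domain entries are matched by host suffix, so any tenant-specific endpoint under them is treated as unable to handle HTTP Basic auth. A tiny sketch of the same suffix test (the tenant hostname is made up):

package main

import (
	"fmt"
	"net/url"
	"strings"
)

func main() {
	u, err := url.Parse("https://dev-1234.oktapreview.com/oauth2/v1/token")
	if err != nil {
		panic(err)
	}
	// Mirrors the check added to providerAuthHeaderWorks: match on host suffix.
	fmt.Println(strings.HasSuffix(u.Host, ".oktapreview.com")) // true
}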
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
}
}
+ if u, err := url.Parse(tokenURL); err == nil {
+ for _, s := range brokenAuthHeaderDomains {
+ if strings.HasSuffix(u.Host, s) {
+ return false
+ }
+ }
+ }
+
// Assume the provider implements the spec properly
// otherwise. We can add more exceptions as they're
// discovered. We will _not_ be adding configurable hooks
}
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
- hc, err := ContextClient(ctx)
- if err != nil {
- return nil, err
- }
- v.Set("client_id", clientID)
bustedAuth := !providerAuthHeaderWorks(tokenURL)
- if bustedAuth && clientSecret != "" {
- v.Set("client_secret", clientSecret)
+ if bustedAuth {
+ if clientID != "" {
+ v.Set("client_id", clientID)
+ }
+ if clientSecret != "" {
+ v.Set("client_secret", clientSecret)
+ }
}
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
if err != nil {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth {
- req.SetBasicAuth(clientID, clientSecret)
+ req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
}
- r, err := hc.Do(req)
+ r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
if err != nil {
return nil, err
}
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
- return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
+ return nil, &RetrieveError{
+ Response: r,
+ Body: body,
+ }
}
var token *Token
if token.RefreshToken == "" {
token.RefreshToken = v.Get("refresh_token")
}
+ if token.AccessToken == "" {
+ return token, errors.New("oauth2: server response missing access_token")
+ }
return token, nil
}
+
+type RetrieveError struct {
+ Response *http.Response
+ Body []byte
+}
+
+func (r *RetrieveError) Error() string {
+ return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+}
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package internal contains support packages for oauth2 package.
package internal
import (
+ "context"
"net/http"
-
- "golang.org/x/net/context"
)
// HTTPClient is the context key to use with golang.org/x/net/context's
// because nobody else can create a ContextKey, being unexported.
type ContextKey struct{}
-// ContextClientFunc is a func which tries to return an *http.Client
-// given a Context value. If it returns an error, the search stops
-// with that error. If it returns (nil, nil), the search continues
-// down the list of registered funcs.
-type ContextClientFunc func(context.Context) (*http.Client, error)
-
-var contextClientFuncs []ContextClientFunc
-
-func RegisterContextClientFunc(fn ContextClientFunc) {
- contextClientFuncs = append(contextClientFuncs, fn)
-}
+var appengineClientHook func(context.Context) *http.Client
-func ContextClient(ctx context.Context) (*http.Client, error) {
+func ContextClient(ctx context.Context) *http.Client {
if ctx != nil {
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
- return hc, nil
- }
- }
- for _, fn := range contextClientFuncs {
- c, err := fn(ctx)
- if err != nil {
- return nil, err
- }
- if c != nil {
- return c, nil
+ return hc
}
}
- return http.DefaultClient, nil
-}
-
-func ContextTransport(ctx context.Context) http.RoundTripper {
- hc, err := ContextClient(ctx)
- // This is a rare error case (somebody using nil on App Engine).
- if err != nil {
- return ErrorTransport{err}
+ if appengineClientHook != nil {
+ return appengineClientHook(ctx)
}
- return hc.Transport
-}
-
-// ErrorTransport returns the specified error on RoundTrip.
-// This RoundTripper should be used in rare error cases where
-// error handling can be postponed to response handling time.
-type ErrorTransport struct{ Err error }
-
-func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) {
- return nil, t.Err
+ return http.DefaultClient
}
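The only client sources left are the context value and, on App Engine, the urlfetch hook. A sketch of supplying a custom client through the public oauth2.HTTPClient context key (config values and the token URL are placeholders):

package main

import (
	"context"
	"net/http"
	"time"

	"golang.org/x/oauth2"
)

func main() {
	// Token requests made with this ctx use the custom client instead of
	// http.DefaultClient.
	custom := &http.Client{Timeout: 10 * time.Second}
	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, custom)

	conf := &oauth2.Config{
		ClientID:     "id",
		ClientSecret: "secret",
		Endpoint:     oauth2.Endpoint{TokenURL: "https://provider.example/token"},
	}
	_, _ = conf.PasswordCredentialsToken(ctx, "user", "pass") // uses custom client
}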
// license that can be found in the LICENSE file.
// Package oauth2 provides support for making
-// OAuth2 authorized and authenticated HTTP requests.
+// OAuth2 authorized and authenticated HTTP requests,
+// as specified in RFC 6749.
// It can additionally grant authorization with Bearer JWT.
package oauth2 // import "golang.org/x/oauth2"
import (
"bytes"
+ "context"
"errors"
"net/http"
"net/url"
"strings"
"sync"
- "golang.org/x/net/context"
"golang.org/x/oauth2/internal"
)
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
-// always provide a non-zero string and validate that it matches the
+// always provide a non-empty string and validate that it matches
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
//
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce.
+// It can also be used to pass the PKCE challenge.
+// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL)
v := url.Values{
"response_type": {"code"},
"client_id": {c.ClientID},
- "redirect_uri": internal.CondVal(c.RedirectURL),
- "scope": internal.CondVal(strings.Join(c.Scopes, " ")),
- "state": internal.CondVal(state),
+ }
+ if c.RedirectURL != "" {
+ v.Set("redirect_uri", c.RedirectURL)
+ }
+ if len(c.Scopes) > 0 {
+ v.Set("scope", strings.Join(c.Scopes, " "))
+ }
+ if state != "" {
+ // TODO(light): Docs say never to omit state; don't allow empty.
+ v.Set("state", state)
}
for _, opt := range opts {
opt.setValue(v)
// The HTTP client to use is derived from the context.
// If nil, http.DefaultClient is used.
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
- return retrieveToken(ctx, c, url.Values{
+ v := url.Values{
"grant_type": {"password"},
"username": {username},
"password": {password},
- "scope": internal.CondVal(strings.Join(c.Scopes, " ")),
- })
+ }
+ if len(c.Scopes) > 0 {
+ v.Set("scope", strings.Join(c.Scopes, " "))
+ }
+ return retrieveToken(ctx, c, v)
}
// Exchange converts an authorization code into a token.
//
// The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state").
-func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
- return retrieveToken(ctx, c, url.Values{
- "grant_type": {"authorization_code"},
- "code": {code},
- "redirect_uri": internal.CondVal(c.RedirectURL),
- "scope": internal.CondVal(strings.Join(c.Scopes, " ")),
- })
+//
+// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
+// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
+func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
+ v := url.Values{
+ "grant_type": {"authorization_code"},
+ "code": {code},
+ }
+ if c.RedirectURL != "" {
+ v.Set("redirect_uri", c.RedirectURL)
+ }
+ for _, opt := range opts {
+ opt.setValue(v)
+ }
+ return retrieveToken(ctx, c, v)
}
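The PKCE parameters mentioned in the comments are passed through the AuthCodeOption hooks rather than dedicated fields. A sketch of the three-legged flow with a verifier; the endpoints, state and verifier values are stand-ins, and verifier generation plus the HTTP handler wiring are omitted:

package main

import (
	"context"
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{
		ClientID:    "client-id",
		RedirectURL: "https://example.com/callback",
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://provider.example/authorize",
			TokenURL: "https://provider.example/token",
		},
	}

	// Step 1: send the user to the provider, attaching the PKCE challenge.
	url := conf.AuthCodeURL("random-state",
		oauth2.SetAuthURLParam("code_challenge", "hashed-verifier"),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"))
	fmt.Println("visit:", url)

	// Step 2: after the redirect, exchange the code, replaying the verifier.
	tok, err := conf.Exchange(context.Background(), "code-from-callback",
		oauth2.SetAuthURLParam("code_verifier", "plain-verifier"))
	if err != nil {
		fmt.Println("exchange failed:", err)
		return
	}
	fmt.Println("access token:", tok.AccessToken)
}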
// Client returns an HTTP client using the provided token.
// NewClient creates an *http.Client from a Context and TokenSource.
// The returned client is not valid beyond the lifetime of the context.
//
+// Note that if a custom *http.Client is provided via the Context it
+// is used only for token acquisition and is not used to configure the
+// *http.Client returned from NewClient.
+//
// As a special case, if src is nil, a non-OAuth2 client is returned
// using the provided context. This exists to support related OAuth2
// packages.
func NewClient(ctx context.Context, src TokenSource) *http.Client {
if src == nil {
- c, err := internal.ContextClient(ctx)
- if err != nil {
- return &http.Client{Transport: internal.ErrorTransport{Err: err}}
- }
- return c
+ return internal.ContextClient(ctx)
}
return &http.Client{
Transport: &Transport{
- Base: internal.ContextTransport(ctx),
+ Base: internal.ContextClient(ctx).Transport,
Source: ReuseTokenSource(nil, src),
},
}
package oauth2
import (
+ "context"
+ "fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
- "golang.org/x/net/context"
"golang.org/x/oauth2/internal"
)
// expirations due to client-server time mismatches.
const expiryDelta = 10 * time.Second
-// Token represents the crendentials used to authorize
+// Token represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
//
if t.Expiry.IsZero() {
return false
}
- return t.Expiry.Add(-expiryDelta).Before(time.Now())
+ return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now())
}
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
if err != nil {
+ if rErr, ok := err.(*internal.RetrieveError); ok {
+ return nil, (*RetrieveError)(rErr)
+ }
return nil, err
}
return tokenFromInternal(tk), nil
}
+
+// RetrieveError is the error returned when the token endpoint returns a
+// non-2XX HTTP status code.
+type RetrieveError struct {
+ Response *http.Response
+ // Body is the body that was consumed by reading Response.Body.
+ // It may be truncated.
+ Body []byte
+}
+
+func (r *RetrieveError) Error() string {
+ return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+}
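Callers can now distinguish a non-2xx token response from transport errors by type-asserting the returned error. A small sketch (the config, token URL and code are placeholders):

package main

import (
	"context"
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{
		ClientID: "id",
		Endpoint: oauth2.Endpoint{TokenURL: "https://provider.example/token"},
	}

	_, err := conf.Exchange(context.Background(), "bad-code")
	if rErr, ok := err.(*oauth2.RetrieveError); ok {
		// The provider answered with a non-2xx status; its body is preserved.
		fmt.Println("token endpoint said:", rErr.Response.Status, string(rErr.Body))
	} else if err != nil {
		fmt.Println("transport or other error:", err)
	}
}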
}
// RoundTrip authorizes and authenticates the request with an
-// access token. If no token exists or token is expired,
-// tries to refresh/fetch a new token.
+// access token from Transport's Source.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+ reqBodyClosed := false
+ if req.Body != nil {
+ defer func() {
+ if !reqBodyClosed {
+ req.Body.Close()
+ }
+ }()
+ }
+
if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
}
token.SetAuthHeader(req2)
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)
+
+ // req.Body is assumed to have been closed by the base RoundTripper.
+ reqBodyClosed = true
+
if err != nil {
t.setModReq(req, nil)
return nil, err