* Update dep github.com/markbates/goth
* Update dep github.com/blevesearch/bleve
* Update dep golang.org/x/oauth2
* Fix github.com/blevesearch/bleve to c74e08f039
* Update dep golang.org/x/oauth2
revision = "3a771d992973f24aa725d07868b467d1ddfceafb" | revision = "3a771d992973f24aa725d07868b467d1ddfceafb" | ||||
[[projects]] | [[projects]] | ||||
digest = "1:67351095005f164e748a5a21899d1403b03878cb2d40a7b0f742376e6eeda974" | |||||
digest = "1:c10f35be6200b09e26da267ca80f837315093ecaba27e7a223071380efb9dd32" | |||||
name = "github.com/blevesearch/bleve" | name = "github.com/blevesearch/bleve" | ||||
packages = [ | packages = [ | ||||
".", | ".", | ||||
"search/searcher", | "search/searcher", | ||||
] | ] | ||||
pruneopts = "NUT" | pruneopts = "NUT" | ||||
revision = "ff210fbc6d348ad67aa5754eaea11a463fcddafd" | |||||
revision = "c74e08f039e56cef576e4336382b2a2d12d9e026" | |||||
[[projects]] | [[projects]] | ||||
branch = "master" | branch = "master" | ||||
revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf" | revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf" | ||||
[[projects]] | [[projects]] | ||||
digest = "1:23f75ae90fcc38dac6fad6881006ea7d0f2c78db5f9f81f3df558dc91460e61f" | |||||
digest = "1:4b992ec853d0ea9bac3dcf09a64af61de1a392e6cb0eef2204c0c92f4ae6b911" | |||||
name = "github.com/markbates/goth" | name = "github.com/markbates/goth" | ||||
packages = [ | packages = [ | ||||
".", | ".", | ||||
"providers/twitter", | "providers/twitter", | ||||
] | ] | ||||
pruneopts = "NUT" | pruneopts = "NUT" | ||||
revision = "f9c6649ab984d6ea71ef1e13b7b1cdffcf4592d3" | |||||
version = "v1.46.1" | |||||
revision = "bc6d8ddf751a745f37ca5567dbbfc4157bbf5da9" | |||||
version = "v1.47.2" | |||||
[[projects]] | [[projects]] | ||||
digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5" | digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5" | ||||
[[projects]] | [[projects]] | ||||
branch = "master" | branch = "master" | ||||
digest = "1:6d5ed712653ea5321fe3e3475ab2188cf362a4e0d31e9fd3acbd4dfbbca0d680" | |||||
digest = "1:d0a0bdd2b64d981aa4e6a1ade90431d042cd7fa31b584e33d45e62cbfec43380" | |||||
name = "golang.org/x/net" | name = "golang.org/x/net" | ||||
packages = [ | packages = [ | ||||
"context", | "context", | ||||
"context/ctxhttp", | |||||
"html", | "html", | ||||
"html/atom", | "html/atom", | ||||
"html/charset", | "html/charset", | ||||
revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344" | revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344" | ||||
[[projects]] | [[projects]] | ||||
digest = "1:8159a9cda4b8810aaaeb0d60e2fa68e2fd86d8af4ec8f5059830839e3c8d93d5" | |||||
branch = "master" | |||||
digest = "1:274a6321a5a9f185eeb3fab5d7d8397e0e9f57737490d749f562c7e205ffbc2e" | |||||
name = "golang.org/x/oauth2" | name = "golang.org/x/oauth2" | ||||
packages = [ | packages = [ | ||||
".", | ".", | ||||
"internal", | "internal", | ||||
] | ] | ||||
pruneopts = "NUT" | pruneopts = "NUT" | ||||
revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061" | |||||
revision = "c453e0c757598fd055e170a3a359263c91e13153" | |||||
[[projects]] | [[projects]] | ||||
digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3" | digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3" |
branch = "master" | branch = "master" | ||||
name = "code.gitea.io/sdk" | name = "code.gitea.io/sdk" | ||||
[[constraint]] | |||||
# branch = "master" | |||||
revision = "c74e08f039e56cef576e4336382b2a2d12d9e026" | |||||
name = "github.com/blevesearch/bleve" | |||||
#Not targetting v0.7.0 since standard where use only just after this tag | |||||
[[constraint]] | [[constraint]] | ||||
revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e" | revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e" | ||||
name = "golang.org/x/crypto" | name = "golang.org/x/crypto" | ||||
[[constraint]] | [[constraint]] | ||||
name = "github.com/markbates/goth" | name = "github.com/markbates/goth" | ||||
version = "1.46.1" | |||||
version = "1.47.2" | |||||
[[constraint]] | [[constraint]] | ||||
branch = "master" | branch = "master" | ||||
source = "github.com/go-gitea/bolt" | source = "github.com/go-gitea/bolt" | ||||
[[override]] | [[override]] | ||||
revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061" | |||||
branch = "master" | |||||
name = "golang.org/x/oauth2" | name = "golang.org/x/oauth2" | ||||
[[constraint]] | [[constraint]] |
package bleve

import (
+    "context"

    "github.com/blevesearch/bleve/document"
    "github.com/blevesearch/bleve/index"
    "github.com/blevesearch/bleve/index/store"
    "github.com/blevesearch/bleve/mapping"
-    "golang.org/x/net/context"
)

// A Batch groups together multiple Index and Delete
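This hunk swaps bleve's import of golang.org/x/net/context for the standard library context package. Since Go 1.7 the x/net/context package is a thin alias for the standard one, so the types are interchangeable and the external dependency can be dropped. A minimal sketch of the stdlib API that replaces it (hypothetical caller, not part of the diff):

package main

import (
    "context"
    "fmt"
    "time"
)

// withTimeout demonstrates the stdlib context package doing exactly
// what golang.org/x/net/context used to provide here.
func withTimeout() {
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()

    select {
    case <-time.After(time.Second):
        fmt.Println("work finished")
    case <-ctx.Done():
        fmt.Println("gave up:", ctx.Err()) // context.DeadlineExceeded
    }
}

func main() { withTimeout() }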
// prepare new index snapshot
newSnapshot := &IndexSnapshot{
    parent:   s,
-    segment:  make([]*SegmentSnapshot, nsegs, nsegs+1),
-    offsets:  make([]uint64, nsegs, nsegs+1),
+    segment:  make([]*SegmentSnapshot, 0, nsegs+1),
+    offsets:  make([]uint64, 0, nsegs+1),
    internal: make(map[string][]byte, len(s.root.internal)),
    epoch:    s.nextSnapshotEpoch,
    refs:     1,

        return err
    }
}
-newSnapshot.segment[i] = &SegmentSnapshot{
+newss := &SegmentSnapshot{
    id:         s.root.segment[i].id,
    segment:    s.root.segment[i].segment,
    cachedDocs: s.root.segment[i].cachedDocs,
}
-s.root.segment[i].segment.AddRef()

// apply new obsoletions
if s.root.segment[i].deleted == nil {
-    newSnapshot.segment[i].deleted = delta
+    newss.deleted = delta
} else {
-    newSnapshot.segment[i].deleted = roaring.Or(s.root.segment[i].deleted, delta)
+    newss.deleted = roaring.Or(s.root.segment[i].deleted, delta)
}
-newSnapshot.offsets[i] = running
-running += s.root.segment[i].Count()
+// check for live size before copying
+if newss.LiveSize() > 0 {
+    newSnapshot.segment = append(newSnapshot.segment, newss)
+    s.root.segment[i].segment.AddRef()
+    newSnapshot.offsets = append(newSnapshot.offsets, running)
+    running += s.root.segment[i].Count()
+}
}

// append new segment, if any, to end of the new index snapshot
if next.data != nil {
    newSegmentSnapshot := &SegmentSnapshot{
// prepare new index snapshot
currSize := len(s.root.segment)
newSize := currSize + 1 - len(nextMerge.old)
+
+// empty segments deletion
+if nextMerge.new == nil {
+    newSize--
+}
newSnapshot := &IndexSnapshot{
    parent:  s,
    segment: make([]*SegmentSnapshot, 0, newSize),

segmentID := s.root.segment[i].id
if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
    // this segment is going away, see if anything else was deleted since we started the merge
-    if s.root.segment[i].deleted != nil {
+    if segSnapAtMerge != nil && s.root.segment[i].deleted != nil {
        // assume all these deletes are new
        deletedSince := s.root.segment[i].deleted
        // if we already knew about some of them, remove

            newSegmentDeleted.Add(uint32(newDocNum))
        }
    }
-} else {
+    // clean up the old segment map to figure out the
+    // obsolete segments wrt root in meantime; whatever
+    // segments are left behind in the old map after processing
+    // the root segments are the obsolete segment set
+    delete(nextMerge.old, segmentID)
+} else if s.root.segment[i].LiveSize() > 0 {
    // this segment is staying
    newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
        id: s.root.segment[i].id,
    }
}

-// put new segment at end
-newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
-    id:         nextMerge.id,
-    segment:    nextMerge.new, // take ownership for nextMerge.new's ref-count
-    deleted:    newSegmentDeleted,
-    cachedDocs: &cachedDocs{cache: nil},
-})
-newSnapshot.offsets = append(newSnapshot.offsets, running)
+// before the newMerge introduction, need to clean the newly
+// merged segment wrt the current root segments, hence
+// applying the obsolete segment contents to the newly merged segment
+for segID, ss := range nextMerge.old {
+    obsoleted := ss.DocNumbersLive()
+    if obsoleted != nil {
+        obsoletedIter := obsoleted.Iterator()
+        for obsoletedIter.HasNext() {
+            oldDocNum := obsoletedIter.Next()
+            newDocNum := nextMerge.oldNewDocNums[segID][oldDocNum]
+            newSegmentDeleted.Add(uint32(newDocNum))
+        }
+    }
+}
+// If all the docs in the newly merged segment have been
+// deleted by the time we reach here, we can skip the introduction.
+if nextMerge.new != nil &&
+    nextMerge.new.Count() > newSegmentDeleted.GetCardinality() {
+    // put new segment at end
+    newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
+        id:         nextMerge.id,
+        segment:    nextMerge.new, // take ownership for nextMerge.new's ref-count
+        deleted:    newSegmentDeleted,
+        cachedDocs: &cachedDocs{cache: nil},
+    })
+    newSnapshot.offsets = append(newSnapshot.offsets, running)
+}
+newSnapshot.AddRef() // 1 ref for the nextMerge.notify response

// swap in new segment
rootPrev := s.root

_ = rootPrev.DecRef()
}

-// notify merger we incorporated this
+// notify requester that we incorporated this
+nextMerge.notify <- newSnapshot
close(nextMerge.notify)
}
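Several hunks above hinge on SegmentSnapshot.LiveSize() to drop segments whose documents have all been deleted. The method itself is not part of this diff; the following is a minimal, self-contained sketch of what its uses here imply (field and method names inferred from usage, so the real bleve implementation may differ):

package main

import (
    "fmt"

    "github.com/RoaringBitmap/roaring"
)

// SegmentSnapshot reduced to what LiveSize needs (sketch only).
type SegmentSnapshot struct {
    count   uint64          // total docs in the underlying segment (stand-in)
    deleted *roaring.Bitmap // docs deleted since the segment was introduced
}

func (s *SegmentSnapshot) FullSize() int64 { return int64(s.count) }

// LiveSize subtracts deletions, so a fully-deleted segment reports 0
// and is skipped when the introducer builds the next snapshot.
func (s *SegmentSnapshot) LiveSize() int64 {
    if s.deleted == nil {
        return s.FullSize()
    }
    return s.FullSize() - int64(s.deleted.GetCardinality())
}

func main() {
    del := roaring.New()
    del.AddRange(0, 10)
    ss := &SegmentSnapshot{count: 10, deleted: del}
    fmt.Println(ss.LiveSize()) // 0 -> the introducer drops this segment
}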
package scorch

import (
+    "bytes"
+    "encoding/json"
    "fmt"
    "os"
    "sync/atomic"

func (s *Scorch) mergerLoop() {
    var lastEpochMergePlanned uint64
+    mergePlannerOptions, err := s.parseMergePlannerOptions()
+    if err != nil {
+        s.fireAsyncError(fmt.Errorf("mergePlannerOption json parsing err: %v", err))
+        s.asyncTasks.Done()
+        return
+    }
OUTER:
    for {
        select {

            startTime := time.Now()

            // lets get started
-            err := s.planMergeAtSnapshot(ourSnapshot)
+            err := s.planMergeAtSnapshot(ourSnapshot, mergePlannerOptions)
            if err != nil {
                s.fireAsyncError(fmt.Errorf("merging err: %v", err))
                _ = ourSnapshot.DecRef()

            _ = ourSnapshot.DecRef()

            // tell the persister we're waiting for changes
-            // first make a notification chan
-            notifyUs := make(notificationChan)
+            // first make an epochWatcher chan
+            ew := &epochWatcher{
+                epoch:    lastEpochMergePlanned,
+                notifyCh: make(notificationChan, 1),
+            }

            // give it to the persister
            select {
            case <-s.closeCh:
                break OUTER
-            case s.persisterNotifier <- notifyUs:
-            }
-
-            // check again
-            s.rootLock.RLock()
-            ourSnapshot = s.root
-            ourSnapshot.AddRef()
-            s.rootLock.RUnlock()
-
-            if ourSnapshot.epoch != lastEpochMergePlanned {
-                startTime := time.Now()
-
-                // lets get started
-                err := s.planMergeAtSnapshot(ourSnapshot)
-                if err != nil {
-                    s.fireAsyncError(fmt.Errorf("merging err: %v", err))
-                    _ = ourSnapshot.DecRef()
-                    continue OUTER
-                }
-                lastEpochMergePlanned = ourSnapshot.epoch
-
-                s.fireEvent(EventKindMergerProgress, time.Since(startTime))
+            case s.persisterNotifier <- ew:
            }
+            _ = ourSnapshot.DecRef()

-            // now wait for it (but also detect close)
+            // now wait for persister (but also detect close)
            select {
            case <-s.closeCh:
                break OUTER
-            case <-notifyUs:
-                // woken up, next loop should pick up work
+            case <-ew.notifyCh:
            }
        }
    }
    s.asyncTasks.Done()
}

-func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
+func (s *Scorch) parseMergePlannerOptions() (*mergeplan.MergePlanOptions,
+    error) {
+    mergePlannerOptions := mergeplan.DefaultMergePlanOptions
+    if v, ok := s.config["scorchMergePlanOptions"]; ok {
+        b, err := json.Marshal(v)
+        if err != nil {
+            return &mergePlannerOptions, err
+        }
+
+        err = json.Unmarshal(b, &mergePlannerOptions)
+        if err != nil {
+            return &mergePlannerOptions, err
+        }
+    }
+    return &mergePlannerOptions, nil
+}
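The new parseMergePlannerOptions reads merge-planner tuning from the index config by round-tripping it through encoding/json onto the package defaults. A hedged usage sketch follows: the "scorchMergePlanOptions" key is taken from this diff, but the field names inside are assumptions about mergeplan.MergePlanOptions, which is not shown here.

package main

// exampleConfig sketches how scorchMergePlanOptions reaches
// parseMergePlannerOptions at index-open time.
func exampleConfig() map[string]interface{} {
    return map[string]interface{}{
        "path": "/tmp/example.bleve", // hypothetical index path
        // parseMergePlannerOptions marshals this value to JSON and
        // unmarshals it onto mergeplan.DefaultMergePlanOptions, so
        // any knob left out keeps its default.
        "scorchMergePlanOptions": map[string]interface{}{
            "MaxSegmentsPerTier":   4, // assumed field name
            "SegmentsPerMergeTask": 4, // assumed field name
        },
    }
}

func main() { _ = exampleConfig() }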
+func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot,
+    options *mergeplan.MergePlanOptions) error {
    // build list of zap segments in this snapshot
    var onlyZapSnapshots []mergeplan.Segment
    for _, segmentSnapshot := range ourSnapshot.segment {
    }

    // give this list to the planner
-    resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, nil)
+    resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, options)
    if err != nil {
        return fmt.Errorf("merge planning err: %v", err)
    }
    }

    // process tasks in serial for now
-    var notifications []notificationChan
+    var notifications []chan *IndexSnapshot
    for _, task := range resultMergePlan.Tasks {
+        if len(task.Segments) == 0 {
+            continue
+        }
+
        oldMap := make(map[uint64]*SegmentSnapshot)
        newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
        segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))

        if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
            oldMap[segSnapshot.id] = segSnapshot
            if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
-                segmentsToMerge = append(segmentsToMerge, zapSeg)
-                docsToDrop = append(docsToDrop, segSnapshot.deleted)
+                if segSnapshot.LiveSize() == 0 {
+                    oldMap[segSnapshot.id] = nil
+                } else {
+                    segmentsToMerge = append(segmentsToMerge, zapSeg)
+                    docsToDrop = append(docsToDrop, segSnapshot.deleted)
+                }
            }
        }
        }

-        filename := zapFileName(newSegmentID)
-        s.markIneligibleForRemoval(filename)
-        path := s.path + string(os.PathSeparator) + filename
-        newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, DefaultChunkFactor)
-        if err != nil {
-            s.unmarkIneligibleForRemoval(filename)
-            return fmt.Errorf("merging failed: %v", err)
-        }
-        segment, err := zap.Open(path)
-        if err != nil {
-            s.unmarkIneligibleForRemoval(filename)
-            return err
-        }
+        var oldNewDocNums map[uint64][]uint64
+        var segment segment.Segment
+        if len(segmentsToMerge) > 0 {
+            filename := zapFileName(newSegmentID)
+            s.markIneligibleForRemoval(filename)
+            path := s.path + string(os.PathSeparator) + filename
+            newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, 1024)
+            if err != nil {
+                s.unmarkIneligibleForRemoval(filename)
+                return fmt.Errorf("merging failed: %v", err)
+            }
+            segment, err = zap.Open(path)
+            if err != nil {
+                s.unmarkIneligibleForRemoval(filename)
+                return err
+            }
+            oldNewDocNums = make(map[uint64][]uint64)
+            for i, segNewDocNums := range newDocNums {
+                oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
+            }
        }

        sm := &segmentMerge{
            id:            newSegmentID,
            old:           oldMap,
-            oldNewDocNums: make(map[uint64][]uint64),
+            oldNewDocNums: oldNewDocNums,
            new:           segment,
-            notify:        make(notificationChan),
+            notify:        make(chan *IndexSnapshot, 1),
        }
        notifications = append(notifications, sm.notify)
-        for i, segNewDocNums := range newDocNums {
-            sm.oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
-        }

        // give it to the introducer
        select {
        case <-s.closeCh:
-            _ = segment.Close()
            return nil
        case s.merges <- sm:
        }

    select {
    case <-s.closeCh:
        return nil
-    case <-notification:
+    case newSnapshot := <-notification:
+        if newSnapshot != nil {
+            _ = newSnapshot.DecRef()
+        }
    }
    }
    return nil

    old           map[uint64]*SegmentSnapshot
    oldNewDocNums map[uint64][]uint64
    new           segment.Segment
-    notify        notificationChan
+    notify        chan *IndexSnapshot
}

+// perform a merging of the given SegmentBase instances into a new,
+// persisted segment, and synchronously introduce that new segment
+// into the root
+func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot,
+    sbs []*zap.SegmentBase, sbsDrops []*roaring.Bitmap, sbsIndexes []int,
+    chunkFactor uint32) (uint64, *IndexSnapshot, uint64, error) {
+    var br bytes.Buffer
+    cr := zap.NewCountHashWriter(&br)
+
+    newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset,
+        docValueOffset, dictLocs, fieldsInv, fieldsMap, err :=
+        zap.MergeToWriter(sbs, sbsDrops, chunkFactor, cr)
+    if err != nil {
+        return 0, nil, 0, err
+    }
+
+    sb, err := zap.InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
+        fieldsMap, fieldsInv, numDocs, storedIndexOffset, fieldsIndexOffset,
+        docValueOffset, dictLocs)
+    if err != nil {
+        return 0, nil, 0, err
+    }
+
+    newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
+    filename := zapFileName(newSegmentID)
+    path := s.path + string(os.PathSeparator) + filename
+    err = zap.PersistSegmentBase(sb, path)
+    if err != nil {
+        return 0, nil, 0, err
+    }
+
+    segment, err := zap.Open(path)
+    if err != nil {
+        return 0, nil, 0, err
+    }
+
+    sm := &segmentMerge{
+        id:            newSegmentID,
+        old:           make(map[uint64]*SegmentSnapshot),
+        oldNewDocNums: make(map[uint64][]uint64),
+        new:           segment,
+        notify:        make(chan *IndexSnapshot, 1),
+    }
+
+    for i, idx := range sbsIndexes {
+        ss := snapshot.segment[idx]
+        sm.old[ss.id] = ss
+        sm.oldNewDocNums[ss.id] = newDocNums[i]
+    }
+
+    select { // send to introducer
+    case <-s.closeCh:
+        _ = segment.DecRef()
+        return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
+    case s.merges <- sm:
+    }
+
+    select { // wait for introduction to complete
+    case <-s.closeCh:
+        return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
+    case newSnapshot := <-sm.notify:
+        return numDocs, newSnapshot, newSegmentID, nil
+    }
}
// While we're over budget, keep looping, which might produce
// another MergeTask.
-for len(eligibles) > budgetNumSegments {
+for len(eligibles) > 0 && (len(eligibles)+len(rv.Tasks)) > budgetNumSegments {
    // Track a current best roster as we examine and score
    // potential rosters of merges.
    var bestRoster []Segment
    var bestRosterScore float64 // Lower score is better.

-    for startIdx := 0; startIdx < len(eligibles)-o.SegmentsPerMergeTask; startIdx++ {
+    for startIdx := 0; startIdx < len(eligibles); startIdx++ {
        var roster []Segment
        var rosterLiveSize int64
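A worked example for the new loop condition (names from the hunk above): each planned task will itself produce one merged segment, so after planning, the index will hold len(eligibles) + len(rv.Tasks) segments. With budgetNumSegments = 4, two eligible segments left, and three tasks already planned, the old test (2 > 4) stopped even though 2 + 3 = 5 segments would remain over budget; the new test (2 + 3 > 4) keeps planning, while the len(eligibles) > 0 guard prevents looping once nothing is left to merge. The inner-loop change similarly lets candidate rosters start anywhere in eligibles instead of stopping SegmentsPerMergeTask entries early.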
var DefaultChunkFactor uint32 = 1024

+// Arbitrary number, need to make it configurable.
+// Lower values like 10/making persister really slow
+// doesn't work well as it is creating more files to
+// persist for in next persist iteration and spikes the # FDs.
+// Ideal value should let persister also proceed at
+// an optimum pace so that the merger can skip
+// many intermediate snapshots.
+// This needs to be based on empirical data.
+// TODO - may need to revisit this approach/value.
+var epochDistance = uint64(5)

type notificationChan chan struct{}

func (s *Scorch) persisterLoop() {
    defer s.asyncTasks.Done()

-    var notifyChs []notificationChan
-    var lastPersistedEpoch uint64
+    var persistWatchers []*epochWatcher
+    var lastPersistedEpoch, lastMergedEpoch uint64
+    var ew *epochWatcher
OUTER:
    for {
        select {
        case <-s.closeCh:
            break OUTER
-        case notifyCh := <-s.persisterNotifier:
-            notifyChs = append(notifyChs, notifyCh)
+        case ew = <-s.persisterNotifier:
+            persistWatchers = append(persistWatchers, ew)
        default:
        }

+        if ew != nil && ew.epoch > lastMergedEpoch {
+            lastMergedEpoch = ew.epoch
+        }
+
+        persistWatchers = s.pausePersisterForMergerCatchUp(lastPersistedEpoch,
+            &lastMergedEpoch, persistWatchers)
+
        var ourSnapshot *IndexSnapshot
        var ourPersisted []chan error
        }

        lastPersistedEpoch = ourSnapshot.epoch
-        for _, notifyCh := range notifyChs {
-            close(notifyCh)
+        for _, ew := range persistWatchers {
+            close(ew.notifyCh)
        }
-        notifyChs = nil
+        persistWatchers = nil

        _ = ourSnapshot.DecRef()
        changed := false

            break OUTER
        case <-w.notifyCh:
            // woken up, next loop should pick up work
+            continue OUTER
+        case ew = <-s.persisterNotifier:
+            // if the watchers are already caught up then let them wait,
+            // else let them continue to do the catch up
+            persistWatchers = append(persistWatchers, ew)
        }
    }
}

+func notifyMergeWatchers(lastPersistedEpoch uint64,
+    persistWatchers []*epochWatcher) []*epochWatcher {
+    var watchersNext []*epochWatcher
+    for _, w := range persistWatchers {
+        if w.epoch < lastPersistedEpoch {
+            close(w.notifyCh)
+        } else {
+            watchersNext = append(watchersNext, w)
+        }
+    }
+    return watchersNext
+}
+
+func (s *Scorch) pausePersisterForMergerCatchUp(lastPersistedEpoch uint64, lastMergedEpoch *uint64,
+    persistWatchers []*epochWatcher) []*epochWatcher {
+    // first, let the watchers proceed if they lag behind
+    persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
+
+OUTER:
+    // check for a slow merger and wait until the merger catches up
+    for lastPersistedEpoch > *lastMergedEpoch+epochDistance {
+        select {
+        case <-s.closeCh:
+            break OUTER
+        case ew := <-s.persisterNotifier:
+            persistWatchers = append(persistWatchers, ew)
+            *lastMergedEpoch = ew.epoch
+        }
+
+        // let the watchers proceed if they lag behind
+        persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
+    }
+
+    return persistWatchers
+}
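The merger/persister handshake above replaces a bare notification channel with an epochWatcher that carries the epoch the merger has planned through, letting the persister both wake lagging watchers and pause itself when it runs more than epochDistance epochs ahead. A minimal, self-contained sketch of the pattern (the epochWatcher definition is not in this hunk; its fields are inferred from usage, and the loops are heavily simplified):

package main

import "fmt"

type notificationChan chan struct{}

// epochWatcher: fields inferred from how the diff uses it.
type epochWatcher struct {
    epoch    uint64
    notifyCh notificationChan
}

func main() {
    persisterNotifier := make(chan *epochWatcher, 1)
    done := make(chan struct{})

    // "merger": reports the epoch it has planned through, then sleeps
    // until the persister signals it has persisted past that epoch.
    go func() {
        defer close(done)
        ew := &epochWatcher{epoch: 7, notifyCh: make(notificationChan, 1)}
        persisterNotifier <- ew
        <-ew.notifyCh
        fmt.Println("merger woken: persisted epoch passed", ew.epoch)
    }()

    // "persister": after persisting an epoch, wakes any watcher whose
    // epoch now lags it (the role of notifyMergeWatchers above); the
    // real loop also pauses while
    // lastPersistedEpoch > lastMergedEpoch+epochDistance.
    lastPersistedEpoch := uint64(9)
    ew := <-persisterNotifier
    if ew.epoch < lastPersistedEpoch {
        close(ew.notifyCh)
    }
    <-done
}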
func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
-    // start a write transaction
-    tx, err := s.rootBolt.Begin(true)
+    persisted, err := s.persistSnapshotMaybeMerge(snapshot)
    if err != nil {
        return err
    }
-    // defer fsync of the rootbolt
-    defer func() {
-        if err == nil {
-            err = s.rootBolt.Sync()
-        }
-    }()
-    // defer commit/rollback transaction
+    if persisted {
+        return nil
+    }
+    return s.persistSnapshotDirect(snapshot)
+}

+// DefaultMinSegmentsForInMemoryMerge represents the default number of
+// in-memory zap segments that persistSnapshotMaybeMerge() needs to
+// see in an IndexSnapshot before it decides to merge and persist
+// those segments
+var DefaultMinSegmentsForInMemoryMerge = 2

+// persistSnapshotMaybeMerge examines the snapshot and might merge and
+// persist the in-memory zap segments if there are enough of them
+func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) (
+    bool, error) {
+    // collect the in-memory zap segments (SegmentBase instances)
+    var sbs []*zap.SegmentBase
+    var sbsDrops []*roaring.Bitmap
+    var sbsIndexes []int
+    for i, segmentSnapshot := range snapshot.segment {
+        if sb, ok := segmentSnapshot.segment.(*zap.SegmentBase); ok {
+            sbs = append(sbs, sb)
+            sbsDrops = append(sbsDrops, segmentSnapshot.deleted)
+            sbsIndexes = append(sbsIndexes, i)
+        }
+    }
+
+    if len(sbs) < DefaultMinSegmentsForInMemoryMerge {
+        return false, nil
+    }
+
+    _, newSnapshot, newSegmentID, err := s.mergeSegmentBases(
+        snapshot, sbs, sbsDrops, sbsIndexes, DefaultChunkFactor)
+    if err != nil {
+        return false, err
+    }
+    if newSnapshot == nil {
+        return false, nil
+    }
+    defer func() {
+        _ = newSnapshot.DecRef()
+    }()
+
+    mergedSegmentIDs := map[uint64]struct{}{}
+    for _, idx := range sbsIndexes {
+        mergedSegmentIDs[snapshot.segment[idx].id] = struct{}{}
+    }
+
+    // construct a snapshot that's logically equivalent to the input
+    // snapshot, but with merged segments replaced by the new segment
+    equiv := &IndexSnapshot{
+        parent:   snapshot.parent,
+        segment:  make([]*SegmentSnapshot, 0, len(snapshot.segment)),
+        internal: snapshot.internal,
+        epoch:    snapshot.epoch,
+    }
+
+    // copy to the equiv the segments that weren't replaced
+    for _, segment := range snapshot.segment {
+        if _, wasMerged := mergedSegmentIDs[segment.id]; !wasMerged {
+            equiv.segment = append(equiv.segment, segment)
+        }
+    }
+
+    // append to the equiv the new segment
+    for _, segment := range newSnapshot.segment {
+        if segment.id == newSegmentID {
+            equiv.segment = append(equiv.segment, &SegmentSnapshot{
+                id:      newSegmentID,
+                segment: segment.segment,
+                deleted: nil, // nil since merging handled deletions
+            })
+            break
+        }
+    }
+
+    err = s.persistSnapshotDirect(equiv)
+    if err != nil {
+        return false, err
+    }
+
+    return true, nil
+}
+func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) {
+    // start a write transaction
+    tx, err := s.rootBolt.Begin(true)
+    if err != nil {
+        return err
+    }
+    // defer rollback on error
    defer func() {
-        if err == nil {
-            err = tx.Commit()
-        } else {
+        if err != nil {
            _ = tx.Rollback()
        }
    }()

    newSegmentPaths := make(map[uint64]string)

    // first ensure that each segment in this snapshot has been persisted
-    for i, segmentSnapshot := range snapshot.segment {
-        snapshotSegmentKey := segment.EncodeUvarintAscending(nil, uint64(i))
-        snapshotSegmentBucket, err2 := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
-        if err2 != nil {
-            return err2
+    for _, segmentSnapshot := range snapshot.segment {
+        snapshotSegmentKey := segment.EncodeUvarintAscending(nil, segmentSnapshot.id)
+        snapshotSegmentBucket, err := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
+        if err != nil {
+            return err
        }
        switch seg := segmentSnapshot.segment.(type) {
        case *zap.SegmentBase:
            // need to persist this to disk
            filename := zapFileName(segmentSnapshot.id)
            path := s.path + string(os.PathSeparator) + filename
-            err2 := zap.PersistSegmentBase(seg, path)
-            if err2 != nil {
-                return fmt.Errorf("error persisting segment: %v", err2)
+            err = zap.PersistSegmentBase(seg, path)
+            if err != nil {
+                return fmt.Errorf("error persisting segment: %v", err)
            }
            newSegmentPaths[segmentSnapshot.id] = path
            err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
        }
    }

-    // only alter the root if we actually persisted a segment
-    // (sometimes its just a new snapshot, possibly with new internal values)
+    // we need to swap in a new root only when we've persisted 1 or
+    // more segments -- whereby the new root would have 1-for-1
+    // replacements of in-memory segments with file-based segments
+    //
+    // other cases like updates to internal values only, and/or when
+    // there are only deletions, are already covered and persisted by
+    // the newly populated boltdb snapshotBucket above
    if len(newSegmentPaths) > 0 {
        // now try to open all the new snapshots
        newSegments := make(map[uint64]segment.Segment)
+        defer func() {
+            for _, s := range newSegments {
+                if s != nil {
+                    // cleanup segments that were opened but not
+                    // swapped into the new root
+                    _ = s.Close()
+                }
+            }
+        }()
        for segmentID, path := range newSegmentPaths {
            newSegments[segmentID], err = zap.Open(path)
            if err != nil {
-                for _, s := range newSegments {
-                    if s != nil {
-                        _ = s.Close() // cleanup segments that were successfully opened
-                    }
-                }
                return fmt.Errorf("error opening new segment at %s, %v", path, err)
            }
        }

            cachedDocs: segmentSnapshot.cachedDocs,
        }
        newIndexSnapshot.segment[i] = newSegmentSnapshot
+        delete(newSegments, segmentSnapshot.id)
        // update items persisted in case of a new segment snapshot
        atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
    } else {

        for k, v := range s.root.internal {
            newIndexSnapshot.internal[k] = v
        }
-        for _, filename := range filenames {
-            delete(s.ineligibleForRemoval, filename)
-        }
        rootPrev := s.root
        s.root = newIndexSnapshot
        s.rootLock.Unlock()
        }
    }

+    err = tx.Commit()
+    if err != nil {
+        return err
+    }
+
+    err = s.rootBolt.Sync()
+    if err != nil {
+        return err
+    }
+
+    // allow files to become eligible for removal after commit, such
+    // as file segments from snapshots that came from the merger
+    s.rootLock.Lock()
+    for _, filename := range filenames {
+        delete(s.ineligibleForRemoval, filename)
+    }
+    s.rootLock.Unlock()

    return nil
}
    merges             chan *segmentMerge
    introducerNotifier chan *epochWatcher
    revertToSnapshots  chan *snapshotReversion
-    persisterNotifier  chan notificationChan
+    persisterNotifier  chan *epochWatcher
    rootBolt           *bolt.DB
    asyncTasks         sync.WaitGroup
}

func (s *Scorch) Open() error {
+    err := s.openBolt()
+    if err != nil {
+        return err
+    }
+
+    s.asyncTasks.Add(1)
+    go s.mainLoop()
+
+    if !s.readOnly && s.path != "" {
+        s.asyncTasks.Add(1)
+        go s.persisterLoop()
+        s.asyncTasks.Add(1)
+        go s.mergerLoop()
+    }
+
+    return nil
+}
+
+func (s *Scorch) openBolt() error {
    var ok bool
    s.path, ok = s.config["path"].(string)
    if !ok {
        }
    }
    }

    rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
    var err error
    if s.path != "" {

    s.merges = make(chan *segmentMerge)
    s.introducerNotifier = make(chan *epochWatcher, 1)
    s.revertToSnapshots = make(chan *snapshotReversion)
-    s.persisterNotifier = make(chan notificationChan)
+    s.persisterNotifier = make(chan *epochWatcher, 1)

    if !s.readOnly && s.path != "" {
        err := s.removeOldZapFiles() // Before persister or merger create any new files.
    }
    }

-    s.asyncTasks.Add(1)
-    go s.mainLoop()
-
-    if !s.readOnly && s.path != "" {
-        s.asyncTasks.Add(1)
-        go s.persisterLoop()
-        s.asyncTasks.Add(1)
-        go s.mergerLoop()
-    }
-
    return nil
}
    introduction.persisted = make(chan error, 1)
}

-// get read lock, to optimistically prepare obsoleted info
+// optimistically prepare obsoletes outside of rootLock
s.rootLock.RLock()
-for _, seg := range s.root.segment {
+root := s.root
+root.AddRef()
+s.rootLock.RUnlock()
+
+for _, seg := range root.segment {
    delta, err := seg.segment.DocNumbers(ids)
    if err != nil {
-        s.rootLock.RUnlock()
        return err
    }
    introduction.obsoletes[seg.id] = delta
}
-s.rootLock.RUnlock()
+_ = root.DecRef()

s.introductions <- introduction
var numTokenFrequencies int
var totLocs int

+// initial scan for all fieldIDs to sort them
+for _, result := range results {
+    for _, field := range result.Document.CompositeFields {
+        s.getOrDefineField(field.Name())
+    }
+    for _, field := range result.Document.Fields {
+        s.getOrDefineField(field.Name())
+    }
+}
+sort.Strings(s.FieldsInv[1:]) // keep _id as first field
+s.FieldsMap = make(map[string]uint16, len(s.FieldsInv))
+for fieldID, fieldName := range s.FieldsInv {
+    s.FieldsMap[fieldName] = uint16(fieldID + 1)
+}
+
processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
    for term, tf := range tfs {
        pidPlus1, exists := s.Dicts[fieldID][term]
    prefix string
    end    string
    offset int
+
+    dictEntry index.DictEntry // reused across Next()'s
}

// Next returns the next entry in the dictionary

d.offset++
postingID := d.d.segment.Dicts[d.d.fieldID][next]
-return &index.DictEntry{
-    Term:  next,
-    Count: d.d.segment.Postings[postingID-1].GetCardinality(),
-}, nil
+d.dictEntry.Term = next
+d.dictEntry.Count = d.d.segment.Postings[postingID-1].GetCardinality()
+return &d.dictEntry, nil
}
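This hunk trades a fresh index.DictEntry allocation on every Next() call for a single entry struct embedded in the iterator. A small, generic sketch of the pattern and its caveat (illustrative only, not bleve code):

package main

import "fmt"

type Entry struct{ Term string }

type iter struct {
    terms []string
    i     int
    entry Entry // reused across Next() calls: zero allocations per step
}

// Next returns a pointer to the iterator's own entry, so callers that
// retain entries across calls must copy them first.
func (it *iter) Next() *Entry {
    if it.i >= len(it.terms) {
        return nil
    }
    it.entry.Term = it.terms[it.i]
    it.i++
    return &it.entry
}

func main() {
    it := &iter{terms: []string{"a", "b"}}
    for e := it.Next(); e != nil; e = it.Next() {
        fmt.Println(e.Term) // copy e.Term (or *e) if it must outlive the loop
    }
}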
"github.com/golang/snappy" | "github.com/golang/snappy" | ||||
) | ) | ||||
const version uint32 = 2 | |||||
const version uint32 = 3 | |||||
const fieldNotUninverted = math.MaxUint64 | const fieldNotUninverted = math.MaxUint64 | ||||
} | } | ||||
func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) { | func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) { | ||||
var curr int | var curr int | ||||
var metaBuf bytes.Buffer | var metaBuf bytes.Buffer | ||||
var data, compressed []byte | var data, compressed []byte | ||||
metaEncoder := govarint.NewU64Base128Encoder(&metaBuf) | |||||
docNumOffsets := make(map[int]uint64, len(memSegment.Stored)) | docNumOffsets := make(map[int]uint64, len(memSegment.Stored)) | ||||
for docNum, storedValues := range memSegment.Stored { | for docNum, storedValues := range memSegment.Stored { | ||||
if docNum != 0 { | if docNum != 0 { | ||||
// reset buffer if necessary | // reset buffer if necessary | ||||
curr = 0 | |||||
metaBuf.Reset() | metaBuf.Reset() | ||||
data = data[:0] | data = data[:0] | ||||
compressed = compressed[:0] | compressed = compressed[:0] | ||||
curr = 0 | |||||
} | } | ||||
metaEncoder := govarint.NewU64Base128Encoder(&metaBuf) | |||||
st := memSegment.StoredTypes[docNum] | st := memSegment.StoredTypes[docNum] | ||||
sp := memSegment.StoredPos[docNum] | sp := memSegment.StoredPos[docNum] | ||||
// encode fields in order | // encode fields in order | ||||
for fieldID := range memSegment.FieldsInv { | for fieldID := range memSegment.FieldsInv { | ||||
if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok { | if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok { | ||||
// has stored values for this field | |||||
num := len(storedFieldValues) | |||||
stf := st[uint16(fieldID)] | stf := st[uint16(fieldID)] | ||||
spf := sp[uint16(fieldID)] | spf := sp[uint16(fieldID)] | ||||
// process each value | |||||
for i := 0; i < num; i++ { | |||||
// encode field | |||||
_, err2 := metaEncoder.PutU64(uint64(fieldID)) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | |||||
// encode type | |||||
_, err2 = metaEncoder.PutU64(uint64(stf[i])) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | |||||
// encode start offset | |||||
_, err2 = metaEncoder.PutU64(uint64(curr)) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | |||||
// end len | |||||
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i]))) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | |||||
// encode number of array pos | |||||
_, err2 = metaEncoder.PutU64(uint64(len(spf[i]))) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | |||||
// encode all array positions | |||||
for _, pos := range spf[i] { | |||||
_, err2 = metaEncoder.PutU64(pos) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | |||||
} | |||||
// append data | |||||
data = append(data, storedFieldValues[i]...) | |||||
// update curr | |||||
curr += len(storedFieldValues[i]) | |||||
var err2 error | |||||
curr, data, err2 = persistStoredFieldValues(fieldID, | |||||
storedFieldValues, stf, spf, curr, metaEncoder, data) | |||||
if err2 != nil { | |||||
return 0, err2 | |||||
} | } | ||||
} | } | ||||
} | } | ||||
metaEncoder.Close() | |||||
metaEncoder.Close() | |||||
metaBytes := metaBuf.Bytes() | metaBytes := metaBuf.Bytes() | ||||
// compress the data | // compress the data | ||||
return rv, nil | return rv, nil | ||||
} | } | ||||
func persistStoredFieldValues(fieldID int, | |||||
storedFieldValues [][]byte, stf []byte, spf [][]uint64, | |||||
curr int, metaEncoder *govarint.Base128Encoder, data []byte) ( | |||||
int, []byte, error) { | |||||
for i := 0; i < len(storedFieldValues); i++ { | |||||
// encode field | |||||
_, err := metaEncoder.PutU64(uint64(fieldID)) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
// encode type | |||||
_, err = metaEncoder.PutU64(uint64(stf[i])) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
// encode start offset | |||||
_, err = metaEncoder.PutU64(uint64(curr)) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
// end len | |||||
_, err = metaEncoder.PutU64(uint64(len(storedFieldValues[i]))) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
// encode number of array pos | |||||
_, err = metaEncoder.PutU64(uint64(len(spf[i]))) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
// encode all array positions | |||||
for _, pos := range spf[i] { | |||||
_, err = metaEncoder.PutU64(pos) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
} | |||||
data = append(data, storedFieldValues[i]...) | |||||
curr += len(storedFieldValues[i]) | |||||
} | |||||
return curr, data, nil | |||||
} | |||||
func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
    var freqOffsets, locOfffsets []uint64
    tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))

    if err != nil {
        return nil, err
    }

    // resetting encoder for the next field
    fdvEncoder.Reset()
}

    return nil, err
}

+    return InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
+        memSegment.FieldsMap, memSegment.FieldsInv, numDocs,
+        storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs)
+}
+
+func InitSegmentBase(mem []byte, memCRC uint32, chunkFactor uint32,
+    fieldsMap map[string]uint16, fieldsInv []string, numDocs uint64,
+    storedIndexOffset uint64, fieldsIndexOffset uint64, docValueOffset uint64,
+    dictLocs []uint64) (*SegmentBase, error) {
    sb := &SegmentBase{
-        mem:               br.Bytes(),
-        memCRC:            cr.Sum32(),
+        mem:               mem,
+        memCRC:            memCRC,
        chunkFactor:       chunkFactor,
-        fieldsMap:         memSegment.FieldsMap,
-        fieldsInv:         memSegment.FieldsInv,
+        fieldsMap:         fieldsMap,
+        fieldsInv:         fieldsInv,
        numDocs:           numDocs,
        storedIndexOffset: storedIndexOffset,
        fieldsIndexOffset: fieldsIndexOffset,
        fieldDvIterMap:    make(map[uint16]*docValueIterator),
    }

-    err = sb.loadDvIterators()
+    err := sb.loadDvIterators()
    if err != nil {
        return nil, err
    }
// MetaData represents the data information inside a
// chunk.
type MetaData struct {
-    DocID    uint64 // docid of the data inside the chunk
+    DocNum   uint64 // docNum of the data inside the chunk
    DocDvLoc uint64 // starting offset for a given docid
    DocDvLen uint64 // length of data inside the chunk for the given docid
}

rv := &chunkedContentCoder{
    chunkSize: chunkSize,
    chunkLens: make([]uint64, total),
-    chunkMeta: []MetaData{},
+    chunkMeta: make([]MetaData, 0, total),
}
return rv

for i := range c.chunkLens {
    c.chunkLens[i] = 0
}
-c.chunkMeta = []MetaData{}
+c.chunkMeta = c.chunkMeta[:0]
}

// Close indicates you are done calling Add() this allows

// write out the metaData slice
for _, meta := range c.chunkMeta {
-    _, err := writeUvarints(&c.chunkMetaBuf, meta.DocID, meta.DocDvLoc, meta.DocDvLen)
+    _, err := writeUvarints(&c.chunkMetaBuf, meta.DocNum, meta.DocDvLoc, meta.DocDvLen)
    if err != nil {
        return err
    }

// clearing the chunk specific meta for next chunk
c.chunkBuf.Reset()
c.chunkMetaBuf.Reset()
-c.chunkMeta = []MetaData{}
+c.chunkMeta = c.chunkMeta[:0]
c.currChunk = chunk
}
}

c.chunkMeta = append(c.chunkMeta, MetaData{
-    DocID:    docNum,
+    DocNum:   docNum,
    DocDvLoc: uint64(dvOffset),
    DocDvLen: uint64(dvSize),
})
// PostingsList returns the postings list for the specified term
func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
-    return d.postingsList([]byte(term), except)
+    return d.postingsList([]byte(term), except, nil)
}

-func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap) (*PostingsList, error) {
-    rv := &PostingsList{
-        sb:     d.sb,
-        term:   term,
-        except: except,
-    }
-
-    if d.fst != nil {
-        postingsOffset, exists, err := d.fst.Get(term)
-        if err != nil {
-            return nil, fmt.Errorf("vellum err: %v", err)
-        }
-        if exists {
-            err = rv.read(postingsOffset, d)
-            if err != nil {
-                return nil, err
-            }
-        }
-    }
+func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
+    if d.fst == nil {
+        return d.postingsListInit(rv, except), nil
+    }
+
+    postingsOffset, exists, err := d.fst.Get(term)
+    if err != nil {
+        return nil, fmt.Errorf("vellum err: %v", err)
+    }
+    if !exists {
+        return d.postingsListInit(rv, except), nil
+    }
+
+    return d.postingsListFromOffset(postingsOffset, except, rv)
+}
+
+func (d *Dictionary) postingsListFromOffset(postingsOffset uint64, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
+    rv = d.postingsListInit(rv, except)
+
+    err := rv.read(postingsOffset, d)
+    if err != nil {
+        return nil, err
+    }

    return rv, nil
}

+func (d *Dictionary) postingsListInit(rv *PostingsList, except *roaring.Bitmap) *PostingsList {
+    if rv == nil {
+        rv = &PostingsList{}
+    } else {
+        *rv = PostingsList{} // clear the struct
+    }
+    rv.sb = d.sb
+    rv.except = except
+    return rv
+}

// Iterator returns an iterator for this dictionary
func (d *Dictionary) Iterator() segment.DictionaryIterator {
    rv := &DictionaryIterator{
func (di *docValueIterator) loadDvChunk(chunkNumber,
    localDocNum uint64, s *SegmentBase) error {
    // advance to the chunk where the docValues
-    // reside for the given docID
+    // reside for the given docNum
    destChunkDataLoc := di.dvDataLoc
    for i := 0; i < int(chunkNumber); i++ {
        destChunkDataLoc += di.chunkLens[i]

    offset := uint64(0)
    di.curChunkHeader = make([]MetaData, int(numDocs))
    for i := 0; i < int(numDocs); i++ {
-        di.curChunkHeader[i].DocID, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
+        di.curChunkHeader[i].DocNum, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
        offset += uint64(read)
        di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
        offset += uint64(read)

    return nil
}

-func (di *docValueIterator) visitDocValues(docID uint64,
+func (di *docValueIterator) visitDocValues(docNum uint64,
    visitor index.DocumentFieldTermVisitor) error {
-    // binary search the term locations for the docID
-    start, length := di.getDocValueLocs(docID)
+    // binary search the term locations for the docNum
+    start, length := di.getDocValueLocs(docNum)
    if start == math.MaxUint64 || length == math.MaxUint64 {
        return nil
    }

        return err
    }

-    // pick the terms for the given docID
+    // pick the terms for the given docNum
    uncompressed = uncompressed[start : start+length]
    for {
        i := bytes.Index(uncompressed, termSeparatorSplitSlice)

    return nil
}

-func (di *docValueIterator) getDocValueLocs(docID uint64) (uint64, uint64) {
+func (di *docValueIterator) getDocValueLocs(docNum uint64) (uint64, uint64) {
    i := sort.Search(len(di.curChunkHeader), func(i int) bool {
-        return di.curChunkHeader[i].DocID >= docID
+        return di.curChunkHeader[i].DocNum >= docNum
    })
-    if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocID == docID {
+    if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocNum == docNum {
        return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
    }
    return math.MaxUint64, math.MaxUint64
// Copyright (c) 2018 Couchbase, Inc. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
package zap | |||||
import ( | |||||
"bytes" | |||||
"github.com/couchbase/vellum" | |||||
) | |||||
// enumerator provides an ordered traversal of multiple vellum | |||||
// iterators. Like JOIN of iterators, the enumerator produces a | |||||
// sequence of (key, iteratorIndex, value) tuples, sorted by key ASC, | |||||
// then iteratorIndex ASC, where the same key might be seen or | |||||
// repeated across multiple child iterators. | |||||
type enumerator struct { | |||||
itrs []vellum.Iterator | |||||
currKs [][]byte | |||||
currVs []uint64 | |||||
lowK []byte | |||||
lowIdxs []int | |||||
lowCurr int | |||||
} | |||||
// newEnumerator returns a new enumerator over the vellum Iterators | |||||
func newEnumerator(itrs []vellum.Iterator) (*enumerator, error) { | |||||
rv := &enumerator{ | |||||
itrs: itrs, | |||||
currKs: make([][]byte, len(itrs)), | |||||
currVs: make([]uint64, len(itrs)), | |||||
lowIdxs: make([]int, 0, len(itrs)), | |||||
} | |||||
for i, itr := range rv.itrs { | |||||
rv.currKs[i], rv.currVs[i] = itr.Current() | |||||
} | |||||
rv.updateMatches() | |||||
if rv.lowK == nil { | |||||
return rv, vellum.ErrIteratorDone | |||||
} | |||||
return rv, nil | |||||
} | |||||
// updateMatches maintains the low key matches based on the currKs | |||||
func (m *enumerator) updateMatches() { | |||||
m.lowK = nil | |||||
m.lowIdxs = m.lowIdxs[:0] | |||||
m.lowCurr = 0 | |||||
for i, key := range m.currKs { | |||||
if key == nil { | |||||
continue | |||||
} | |||||
cmp := bytes.Compare(key, m.lowK) | |||||
if cmp < 0 || m.lowK == nil { | |||||
// reached a new low | |||||
m.lowK = key | |||||
m.lowIdxs = m.lowIdxs[:0] | |||||
m.lowIdxs = append(m.lowIdxs, i) | |||||
} else if cmp == 0 { | |||||
m.lowIdxs = append(m.lowIdxs, i) | |||||
} | |||||
} | |||||
} | |||||
// Current returns the enumerator's current key, iterator-index, and | |||||
// value. If the enumerator is not pointing at a valid value (because | |||||
// Next returned an error previously), Current will return nil,0,0. | |||||
func (m *enumerator) Current() ([]byte, int, uint64) { | |||||
var i int | |||||
var v uint64 | |||||
if m.lowCurr < len(m.lowIdxs) { | |||||
i = m.lowIdxs[m.lowCurr] | |||||
v = m.currVs[i] | |||||
} | |||||
return m.lowK, i, v | |||||
} | |||||
// Next advances the enumerator to the next key/iterator/value result, | |||||
// else vellum.ErrIteratorDone is returned. | |||||
func (m *enumerator) Next() error { | |||||
m.lowCurr += 1 | |||||
if m.lowCurr >= len(m.lowIdxs) { | |||||
// move all the current low iterators forwards | |||||
for _, vi := range m.lowIdxs { | |||||
err := m.itrs[vi].Next() | |||||
if err != nil && err != vellum.ErrIteratorDone { | |||||
return err | |||||
} | |||||
m.currKs[vi], m.currVs[vi] = m.itrs[vi].Current() | |||||
} | |||||
m.updateMatches() | |||||
} | |||||
if m.lowK == nil { | |||||
return vellum.ErrIteratorDone | |||||
} | |||||
return nil | |||||
} | |||||
// Close all the underlying Iterators. The first error, if any, will | |||||
// be returned. | |||||
func (m *enumerator) Close() error { | |||||
var rv error | |||||
for _, itr := range m.itrs { | |||||
err := itr.Close() | |||||
if rv == nil { | |||||
rv = err | |||||
} | |||||
} | |||||
return rv | |||||
} |
encoder *govarint.Base128Encoder | encoder *govarint.Base128Encoder | ||||
chunkLens []uint64 | chunkLens []uint64 | ||||
currChunk uint64 | currChunk uint64 | ||||
buf []byte | |||||
} | } | ||||
// newChunkedIntCoder returns a new chunk int coder which packs data into | // newChunkedIntCoder returns a new chunk int coder which packs data into | ||||
// starting a new chunk | // starting a new chunk | ||||
if c.encoder != nil { | if c.encoder != nil { | ||||
// close out last | // close out last | ||||
c.encoder.Close() | |||||
encodingBytes := c.chunkBuf.Bytes() | |||||
c.chunkLens[c.currChunk] = uint64(len(encodingBytes)) | |||||
c.final = append(c.final, encodingBytes...) | |||||
c.Close() | |||||
c.chunkBuf.Reset() | c.chunkBuf.Reset() | ||||
c.encoder = govarint.NewU64Base128Encoder(&c.chunkBuf) | |||||
} | } | ||||
c.currChunk = chunk | c.currChunk = chunk | ||||
} | } | ||||
// Write commits all the encoded chunked integers to the provided writer. | // Write commits all the encoded chunked integers to the provided writer. | ||||
func (c *chunkedIntCoder) Write(w io.Writer) (int, error) { | func (c *chunkedIntCoder) Write(w io.Writer) (int, error) { | ||||
var tw int | |||||
buf := make([]byte, binary.MaxVarintLen64) | |||||
// write out the number of chunks | |||||
bufNeeded := binary.MaxVarintLen64 * (1 + len(c.chunkLens)) | |||||
if len(c.buf) < bufNeeded { | |||||
c.buf = make([]byte, bufNeeded) | |||||
} | |||||
buf := c.buf | |||||
// write out the number of chunks & each chunkLen | |||||
n := binary.PutUvarint(buf, uint64(len(c.chunkLens))) | n := binary.PutUvarint(buf, uint64(len(c.chunkLens))) | ||||
nw, err := w.Write(buf[:n]) | |||||
tw += nw | |||||
for _, chunkLen := range c.chunkLens { | |||||
n += binary.PutUvarint(buf[n:], uint64(chunkLen)) | |||||
} | |||||
tw, err := w.Write(buf[:n]) | |||||
if err != nil { | if err != nil { | ||||
return tw, err | return tw, err | ||||
} | } | ||||
// write out the chunk lens | |||||
for _, chunkLen := range c.chunkLens { | |||||
n := binary.PutUvarint(buf, uint64(chunkLen)) | |||||
nw, err = w.Write(buf[:n]) | |||||
tw += nw | |||||
if err != nil { | |||||
return tw, err | |||||
} | |||||
} | |||||
// write out the data | // write out the data | ||||
nw, err = w.Write(c.final) | |||||
nw, err := w.Write(c.final) | |||||
tw += nw | tw += nw | ||||
if err != nil { | if err != nil { | ||||
return tw, err | return tw, err |
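The Write rewrite above replaces one w.Write call per varint with a single reusable buffer sized for the worst case (binary.MaxVarintLen64 bytes per value) and one Write for the whole header. A self-contained sketch of that packing:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// writeUvarints packs a count header plus each value into one buffer and
// issues a single Write, mirroring the buffer-reuse change in the diff.
func writeUvarints(w *bytes.Buffer, vals []uint64) (int, error) {
	buf := make([]byte, binary.MaxVarintLen64*(1+len(vals)))
	n := binary.PutUvarint(buf, uint64(len(vals)))
	for _, v := range vals {
		n += binary.PutUvarint(buf[n:], v)
	}
	return w.Write(buf[:n])
}

func main() {
	var b bytes.Buffer
	n, _ := writeUvarints(&b, []uint64{7, 300, 5})
	fmt.Println(n, b.Bytes()) // count=3, then the varints for 7, 300, 5
}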
"fmt" | "fmt" | ||||
"math" | "math" | ||||
"os" | "os" | ||||
"sort" | |||||
"github.com/RoaringBitmap/roaring" | "github.com/RoaringBitmap/roaring" | ||||
"github.com/Smerity/govarint" | "github.com/Smerity/govarint" | ||||
"github.com/golang/snappy" | "github.com/golang/snappy" | ||||
) | ) | ||||
const docDropped = math.MaxUint64 // sentinel docNum to represent a deleted doc | |||||
// Merge takes a slice of zap segments and bit masks describing which | // Merge takes a slice of zap segments and bit masks describing which | ||||
// documents may be dropped, and creates a new segment containing the | // documents may be dropped, and creates a new segment containing the | ||||
// remaining data. This new segment is built at the specified path, | // remaining data. This new segment is built at the specified path, | ||||
_ = os.Remove(path) | _ = os.Remove(path) | ||||
} | } | ||||
segmentBases := make([]*SegmentBase, len(segments)) | |||||
for segmenti, segment := range segments { | |||||
segmentBases[segmenti] = &segment.SegmentBase | |||||
} | |||||
// buffer the output | // buffer the output | ||||
br := bufio.NewWriter(f) | br := bufio.NewWriter(f) | ||||
// wrap it for counting (tracking offsets) | // wrap it for counting (tracking offsets) | ||||
cr := NewCountHashWriter(br) | cr := NewCountHashWriter(br) | ||||
fieldsInv := mergeFields(segments) | |||||
fieldsMap := mapFields(fieldsInv) | |||||
var newDocNums [][]uint64 | |||||
var storedIndexOffset uint64 | |||||
fieldDvLocsOffset := uint64(fieldNotUninverted) | |||||
var dictLocs []uint64 | |||||
newSegDocCount := computeNewDocCount(segments, drops) | |||||
if newSegDocCount > 0 { | |||||
storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops, | |||||
fieldsMap, fieldsInv, newSegDocCount, cr) | |||||
if err != nil { | |||||
cleanup() | |||||
return nil, err | |||||
} | |||||
dictLocs, fieldDvLocsOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap, | |||||
newDocNums, newSegDocCount, chunkFactor, cr) | |||||
if err != nil { | |||||
cleanup() | |||||
return nil, err | |||||
} | |||||
} else { | |||||
dictLocs = make([]uint64, len(fieldsInv)) | |||||
} | |||||
fieldsIndexOffset, err := persistFields(fieldsInv, cr, dictLocs) | |||||
newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, _, _, _, err := | |||||
MergeToWriter(segmentBases, drops, chunkFactor, cr) | |||||
if err != nil { | if err != nil { | ||||
cleanup() | cleanup() | ||||
return nil, err | return nil, err | ||||
} | } | ||||
err = persistFooter(newSegDocCount, storedIndexOffset, | |||||
fieldsIndexOffset, fieldDvLocsOffset, chunkFactor, cr.Sum32(), cr) | |||||
err = persistFooter(numDocs, storedIndexOffset, fieldsIndexOffset, | |||||
docValueOffset, chunkFactor, cr.Sum32(), cr) | |||||
if err != nil { | if err != nil { | ||||
cleanup() | cleanup() | ||||
return nil, err | return nil, err | ||||
return newDocNums, nil | return newDocNums, nil | ||||
} | } | ||||
// mapFields takes the fieldsInv list and builds the map | |||||
func MergeToWriter(segments []*SegmentBase, drops []*roaring.Bitmap, | |||||
chunkFactor uint32, cr *CountHashWriter) ( | |||||
newDocNums [][]uint64, | |||||
numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset uint64, | |||||
dictLocs []uint64, fieldsInv []string, fieldsMap map[string]uint16, | |||||
err error) { | |||||
docValueOffset = uint64(fieldNotUninverted) | |||||
var fieldsSame bool | |||||
fieldsSame, fieldsInv = mergeFields(segments) | |||||
fieldsMap = mapFields(fieldsInv) | |||||
numDocs = computeNewDocCount(segments, drops) | |||||
if numDocs > 0 { | |||||
storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops, | |||||
fieldsMap, fieldsInv, fieldsSame, numDocs, cr) | |||||
if err != nil { | |||||
return nil, 0, 0, 0, 0, nil, nil, nil, err | |||||
} | |||||
dictLocs, docValueOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap, | |||||
newDocNums, numDocs, chunkFactor, cr) | |||||
if err != nil { | |||||
return nil, 0, 0, 0, 0, nil, nil, nil, err | |||||
} | |||||
} else { | |||||
dictLocs = make([]uint64, len(fieldsInv)) | |||||
} | |||||
fieldsIndexOffset, err = persistFields(fieldsInv, cr, dictLocs) | |||||
if err != nil { | |||||
return nil, 0, 0, 0, 0, nil, nil, nil, err | |||||
} | |||||
return newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs, fieldsInv, fieldsMap, nil | |||||
} | |||||
// mapFields takes the fieldsInv list and returns a map of fieldName | |||||
// to fieldID+1 | |||||
func mapFields(fields []string) map[string]uint16 { | func mapFields(fields []string) map[string]uint16 { | ||||
rv := make(map[string]uint16, len(fields)) | rv := make(map[string]uint16, len(fields)) | ||||
for i, fieldName := range fields { | for i, fieldName := range fields { | ||||
rv[fieldName] = uint16(i) | |||||
rv[fieldName] = uint16(i) + 1 | |||||
} | } | ||||
return rv | return rv | ||||
} | } | ||||
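mapFields now stores fieldID+1 rather than fieldID, so the zero value a Go map returns for a missing key can no longer be confused with field 0; lookup sites in this diff correspondingly subtract 1 (see the fieldsMap[loc.Field()] - 1 and fieldsMap[field] - 1 hunks below). A minimal sketch of the convention:

package main

import "fmt"

// Storing fieldID+1 makes the map's zero value (0) mean "unknown field",
// which cannot collide with a real field: field 0 is stored as 1.
func mapFields(fields []string) map[string]uint16 {
	rv := make(map[string]uint16, len(fields))
	for i, name := range fields {
		rv[name] = uint16(i) + 1
	}
	return rv
}

func main() {
	m := mapFields([]string{"_id", "title"})
	fmt.Println(m["title"] - 1)    // 1, the real field ID
	fmt.Println(m["missing"] == 0) // true: detectable before subtracting
}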
// computeNewDocCount determines how many documents will be in the newly | // computeNewDocCount determines how many documents will be in the newly | ||||
// merged segment when obsoleted docs are dropped | // merged segment when obsoleted docs are dropped | ||||
func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 { | |||||
func computeNewDocCount(segments []*SegmentBase, drops []*roaring.Bitmap) uint64 { | |||||
var newDocCount uint64 | var newDocCount uint64 | ||||
for segI, segment := range segments { | for segI, segment := range segments { | ||||
newDocCount += segment.NumDocs() | |||||
newDocCount += segment.numDocs | |||||
if drops[segI] != nil { | if drops[segI] != nil { | ||||
newDocCount -= drops[segI].GetCardinality() | newDocCount -= drops[segI].GetCardinality() | ||||
} | } | ||||
return newDocCount | return newDocCount | ||||
} | } | ||||
func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||||
fieldsInv []string, fieldsMap map[string]uint16, newDocNums [][]uint64, | |||||
func persistMergedRest(segments []*SegmentBase, dropsIn []*roaring.Bitmap, | |||||
fieldsInv []string, fieldsMap map[string]uint16, newDocNumsIn [][]uint64, | |||||
newSegDocCount uint64, chunkFactor uint32, | newSegDocCount uint64, chunkFactor uint32, | ||||
w *CountHashWriter) ([]uint64, uint64, error) { | w *CountHashWriter) ([]uint64, uint64, error) { | ||||
var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64) | var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64) | ||||
var bufLoc []uint64 | var bufLoc []uint64 | ||||
var postings *PostingsList | |||||
var postItr *PostingsIterator | |||||
rv := make([]uint64, len(fieldsInv)) | rv := make([]uint64, len(fieldsInv)) | ||||
fieldDvLocs := make([]uint64, len(fieldsInv)) | fieldDvLocs := make([]uint64, len(fieldsInv)) | ||||
fieldDvLocsOffset := uint64(fieldNotUninverted) | |||||
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||||
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||||
// docTermMap is keyed by docNum, where the array impl provides | // docTermMap is keyed by docNum, where the array impl provides | ||||
// better memory usage behavior than a sparse-friendlier hashmap | // better memory usage behavior than a sparse-friendlier hashmap | ||||
return nil, 0, err | return nil, 0, err | ||||
} | } | ||||
// collect FST iterators from all segments for this field | |||||
// collect FST iterators from all active segments for this field | |||||
var newDocNums [][]uint64 | |||||
var drops []*roaring.Bitmap | |||||
var dicts []*Dictionary | var dicts []*Dictionary | ||||
var itrs []vellum.Iterator | var itrs []vellum.Iterator | ||||
for _, segment := range segments { | |||||
for segmentI, segment := range segments { | |||||
dict, err2 := segment.dictionary(fieldName) | dict, err2 := segment.dictionary(fieldName) | ||||
if err2 != nil { | if err2 != nil { | ||||
return nil, 0, err2 | return nil, 0, err2 | ||||
} | } | ||||
dicts = append(dicts, dict) | |||||
if dict != nil && dict.fst != nil { | if dict != nil && dict.fst != nil { | ||||
itr, err2 := dict.fst.Iterator(nil, nil) | itr, err2 := dict.fst.Iterator(nil, nil) | ||||
if err2 != nil && err2 != vellum.ErrIteratorDone { | if err2 != nil && err2 != vellum.ErrIteratorDone { | ||||
return nil, 0, err2 | return nil, 0, err2 | ||||
} | } | ||||
if itr != nil { | if itr != nil { | ||||
newDocNums = append(newDocNums, newDocNumsIn[segmentI]) | |||||
drops = append(drops, dropsIn[segmentI]) | |||||
dicts = append(dicts, dict) | |||||
itrs = append(itrs, itr) | itrs = append(itrs, itr) | ||||
} | } | ||||
} | } | ||||
} | } | ||||
// create merging iterator | |||||
mergeItr, err := vellum.NewMergeIterator(itrs, func(postingOffsets []uint64) uint64 { | |||||
// we don't actually use the merged value | |||||
return 0 | |||||
}) | |||||
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||||
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||||
if uint64(cap(docTermMap)) < newSegDocCount { | if uint64(cap(docTermMap)) < newSegDocCount { | ||||
docTermMap = make([][]byte, newSegDocCount) | docTermMap = make([][]byte, newSegDocCount) | ||||
} else { | } else { | ||||
} | } | ||||
} | } | ||||
for err == nil { | |||||
term, _ := mergeItr.Current() | |||||
newRoaring := roaring.NewBitmap() | |||||
newRoaringLocs := roaring.NewBitmap() | |||||
tfEncoder.Reset() | |||||
locEncoder.Reset() | |||||
// now go back and get posting list for this term | |||||
// but pass in the deleted docs for that segment | |||||
for dictI, dict := range dicts { | |||||
if dict == nil { | |||||
continue | |||||
} | |||||
postings, err2 := dict.postingsList(term, drops[dictI]) | |||||
if err2 != nil { | |||||
return nil, 0, err2 | |||||
} | |||||
postItr := postings.Iterator() | |||||
next, err2 := postItr.Next() | |||||
for next != nil && err2 == nil { | |||||
hitNewDocNum := newDocNums[dictI][next.Number()] | |||||
if hitNewDocNum == docDropped { | |||||
return nil, 0, fmt.Errorf("see hit with dropped doc num") | |||||
} | |||||
newRoaring.Add(uint32(hitNewDocNum)) | |||||
// encode norm bits | |||||
norm := next.Norm() | |||||
normBits := math.Float32bits(float32(norm)) | |||||
err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits)) | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
locs := next.Locations() | |||||
if len(locs) > 0 { | |||||
newRoaringLocs.Add(uint32(hitNewDocNum)) | |||||
for _, loc := range locs { | |||||
if cap(bufLoc) < 5+len(loc.ArrayPositions()) { | |||||
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions())) | |||||
} | |||||
args := bufLoc[0:5] | |||||
args[0] = uint64(fieldsMap[loc.Field()]) | |||||
args[1] = loc.Pos() | |||||
args[2] = loc.Start() | |||||
args[3] = loc.End() | |||||
args[4] = uint64(len(loc.ArrayPositions())) | |||||
args = append(args, loc.ArrayPositions()...) | |||||
err = locEncoder.Add(hitNewDocNum, args...) | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
} | |||||
} | |||||
var prevTerm []byte | |||||
docTermMap[hitNewDocNum] = | |||||
append(append(docTermMap[hitNewDocNum], term...), termSeparator) | |||||
newRoaring := roaring.NewBitmap() | |||||
newRoaringLocs := roaring.NewBitmap() | |||||
next, err2 = postItr.Next() | |||||
} | |||||
if err2 != nil { | |||||
return nil, 0, err2 | |||||
} | |||||
finishTerm := func(term []byte) error { | |||||
if term == nil { | |||||
return nil | |||||
} | } | ||||
tfEncoder.Close() | tfEncoder.Close() | ||||
if newRoaring.GetCardinality() > 0 { | if newRoaring.GetCardinality() > 0 { | ||||
// this field/term actually has hits in the new segment, let's write it down | // this field/term actually has hits in the new segment, let's write it down | ||||
freqOffset := uint64(w.Count()) | freqOffset := uint64(w.Count()) | ||||
_, err = tfEncoder.Write(w) | |||||
_, err := tfEncoder.Write(w) | |||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
locOffset := uint64(w.Count()) | locOffset := uint64(w.Count()) | ||||
_, err = locEncoder.Write(w) | _, err = locEncoder.Write(w) | ||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
postingLocOffset := uint64(w.Count()) | postingLocOffset := uint64(w.Count()) | ||||
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64) | _, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64) | ||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
postingOffset := uint64(w.Count()) | postingOffset := uint64(w.Count()) | ||||
// write out the start of the term info | // write out the start of the term info | ||||
buf := bufMaxVarintLen64 | |||||
n := binary.PutUvarint(buf, freqOffset) | |||||
_, err = w.Write(buf[:n]) | |||||
n := binary.PutUvarint(bufMaxVarintLen64, freqOffset) | |||||
_, err = w.Write(bufMaxVarintLen64[:n]) | |||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
// write out the start of the loc info | // write out the start of the loc info | ||||
n = binary.PutUvarint(buf, locOffset) | |||||
_, err = w.Write(buf[:n]) | |||||
n = binary.PutUvarint(bufMaxVarintLen64, locOffset) | |||||
_, err = w.Write(bufMaxVarintLen64[:n]) | |||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
// write out the start of the loc posting list | |||||
n = binary.PutUvarint(buf, postingLocOffset) | |||||
_, err = w.Write(buf[:n]) | |||||
// write out the start of the posting locs | |||||
n = binary.PutUvarint(bufMaxVarintLen64, postingLocOffset) | |||||
_, err = w.Write(bufMaxVarintLen64[:n]) | |||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64) | _, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64) | ||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | |||||
return err | |||||
} | } | ||||
err = newVellum.Insert(term, postingOffset) | err = newVellum.Insert(term, postingOffset) | ||||
if err != nil { | |||||
return err | |||||
} | |||||
} | |||||
newRoaring = roaring.NewBitmap() | |||||
newRoaringLocs = roaring.NewBitmap() | |||||
tfEncoder.Reset() | |||||
locEncoder.Reset() | |||||
return nil | |||||
} | |||||
enumerator, err := newEnumerator(itrs) | |||||
for err == nil { | |||||
term, itrI, postingsOffset := enumerator.Current() | |||||
if !bytes.Equal(prevTerm, term) { | |||||
// if the term changed, write out the info collected | |||||
// for the previous term | |||||
err2 := finishTerm(prevTerm) | |||||
if err2 != nil { | |||||
return nil, 0, err2 | |||||
} | |||||
} | |||||
var err2 error | |||||
postings, err2 = dicts[itrI].postingsListFromOffset( | |||||
postingsOffset, drops[itrI], postings) | |||||
if err2 != nil { | |||||
return nil, 0, err2 | |||||
} | |||||
newDocNumsI := newDocNums[itrI] | |||||
postItr = postings.iterator(postItr) | |||||
next, err2 := postItr.Next() | |||||
for next != nil && err2 == nil { | |||||
hitNewDocNum := newDocNumsI[next.Number()] | |||||
if hitNewDocNum == docDropped { | |||||
return nil, 0, fmt.Errorf("see hit with dropped doc num") | |||||
} | |||||
newRoaring.Add(uint32(hitNewDocNum)) | |||||
// encode norm bits | |||||
norm := next.Norm() | |||||
normBits := math.Float32bits(float32(norm)) | |||||
err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits)) | |||||
if err != nil { | if err != nil { | ||||
return nil, 0, err | return nil, 0, err | ||||
} | } | ||||
locs := next.Locations() | |||||
if len(locs) > 0 { | |||||
newRoaringLocs.Add(uint32(hitNewDocNum)) | |||||
for _, loc := range locs { | |||||
if cap(bufLoc) < 5+len(loc.ArrayPositions()) { | |||||
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions())) | |||||
} | |||||
args := bufLoc[0:5] | |||||
args[0] = uint64(fieldsMap[loc.Field()] - 1) | |||||
args[1] = loc.Pos() | |||||
args[2] = loc.Start() | |||||
args[3] = loc.End() | |||||
args[4] = uint64(len(loc.ArrayPositions())) | |||||
args = append(args, loc.ArrayPositions()...) | |||||
err = locEncoder.Add(hitNewDocNum, args...) | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
} | |||||
} | |||||
docTermMap[hitNewDocNum] = | |||||
append(append(docTermMap[hitNewDocNum], term...), termSeparator) | |||||
next, err2 = postItr.Next() | |||||
} | |||||
if err2 != nil { | |||||
return nil, 0, err2 | |||||
} | } | ||||
err = mergeItr.Next() | |||||
prevTerm = prevTerm[:0] // copy to prevTerm in case Next() reuses term mem | |||||
prevTerm = append(prevTerm, term...) | |||||
err = enumerator.Next() | |||||
} | } | ||||
if err != nil && err != vellum.ErrIteratorDone { | if err != nil && err != vellum.ErrIteratorDone { | ||||
return nil, 0, err | return nil, 0, err | ||||
} | } | ||||
err = finishTerm(prevTerm) | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
dictOffset := uint64(w.Count()) | dictOffset := uint64(w.Count()) | ||||
err = newVellum.Close() | err = newVellum.Close() | ||||
} | } | ||||
} | } | ||||
fieldDvLocsOffset = uint64(w.Count()) | |||||
fieldDvLocsOffset := uint64(w.Count()) | |||||
buf := bufMaxVarintLen64 | buf := bufMaxVarintLen64 | ||||
for _, offset := range fieldDvLocs { | for _, offset := range fieldDvLocs { | ||||
return rv, fieldDvLocsOffset, nil | return rv, fieldDvLocsOffset, nil | ||||
} | } | ||||
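One detail worth noting in the loop above: prevTerm = prevTerm[:0] followed by append(prevTerm, term...) copies the term bytes because, per the inline comment, Next() may reuse the key's backing memory. A tiny sketch of why the defensive copy matters:

package main

import "fmt"

// Iterators frequently reuse the backing array of the key they hand out, so
// remembering "the previous term" requires a defensive copy, as the diff does
// with prevTerm = append(prevTerm[:0], term...).
func main() {
	buf := []byte("alpha") // pretend the iterator owns this buffer
	var prevTerm []byte
	prevTerm = append(prevTerm[:0], buf...) // copy, don't alias

	copy(buf, "bravo")            // iterator advances, overwriting its buffer
	fmt.Println(string(prevTerm)) // still "alpha" thanks to the copy
}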
const docDropped = math.MaxUint64 | |||||
func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap, | |||||
fieldsMap map[string]uint16, fieldsInv []string, newSegDocCount uint64, | |||||
func mergeStoredAndRemap(segments []*SegmentBase, drops []*roaring.Bitmap, | |||||
fieldsMap map[string]uint16, fieldsInv []string, fieldsSame bool, newSegDocCount uint64, | |||||
w *CountHashWriter) (uint64, [][]uint64, error) { | w *CountHashWriter) (uint64, [][]uint64, error) { | ||||
var rv [][]uint64 // The remapped or newDocNums for each segment. | var rv [][]uint64 // The remapped or newDocNums for each segment. | ||||
for segI, segment := range segments { | for segI, segment := range segments { | ||||
segNewDocNums := make([]uint64, segment.numDocs) | segNewDocNums := make([]uint64, segment.numDocs) | ||||
dropsI := drops[segI] | |||||
// optimize when the field mapping is the same across all | |||||
// segments and there are no deletions, via byte-copying | |||||
// of stored docs bytes directly to the writer | |||||
if fieldsSame && (dropsI == nil || dropsI.GetCardinality() == 0) { | |||||
err := segment.copyStoredDocs(newDocNum, docNumOffsets, w) | |||||
if err != nil { | |||||
return 0, nil, err | |||||
} | |||||
for i := uint64(0); i < segment.numDocs; i++ { | |||||
segNewDocNums[i] = newDocNum | |||||
newDocNum++ | |||||
} | |||||
rv = append(rv, segNewDocNums) | |||||
continue | |||||
} | |||||
// for each doc num | // for each doc num | ||||
for docNum := uint64(0); docNum < segment.numDocs; docNum++ { | for docNum := uint64(0); docNum < segment.numDocs; docNum++ { | ||||
// TODO: roaring's API limits docNums to 32-bits? | // TODO: roaring's API limits docNums to 32-bits? | ||||
if drops[segI] != nil && drops[segI].Contains(uint32(docNum)) { | |||||
if dropsI != nil && dropsI.Contains(uint32(docNum)) { | |||||
segNewDocNums[docNum] = docDropped | segNewDocNums[docNum] = docDropped | ||||
continue | continue | ||||
} | } | ||||
poss[i] = poss[i][:0] | poss[i] = poss[i][:0] | ||||
} | } | ||||
err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool { | err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool { | ||||
fieldID := int(fieldsMap[field]) | |||||
fieldID := int(fieldsMap[field]) - 1 | |||||
vals[fieldID] = append(vals[fieldID], value) | vals[fieldID] = append(vals[fieldID], value) | ||||
typs[fieldID] = append(typs[fieldID], typ) | typs[fieldID] = append(typs[fieldID], typ) | ||||
poss[fieldID] = append(poss[fieldID], pos) | poss[fieldID] = append(poss[fieldID], pos) | ||||
for fieldID := range fieldsInv { | for fieldID := range fieldsInv { | ||||
storedFieldValues := vals[int(fieldID)] | storedFieldValues := vals[int(fieldID)] | ||||
// has stored values for this field | |||||
num := len(storedFieldValues) | |||||
stf := typs[int(fieldID)] | |||||
spf := poss[int(fieldID)] | |||||
// process each value | |||||
for i := 0; i < num; i++ { | |||||
// encode field | |||||
_, err2 := metaEncoder.PutU64(uint64(fieldID)) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | |||||
// encode type | |||||
_, err2 = metaEncoder.PutU64(uint64(typs[int(fieldID)][i])) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | |||||
// encode start offset | |||||
_, err2 = metaEncoder.PutU64(uint64(curr)) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | |||||
// end len | |||||
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i]))) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | |||||
// encode number of array pos | |||||
_, err2 = metaEncoder.PutU64(uint64(len(poss[int(fieldID)][i]))) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | |||||
// encode all array positions | |||||
for j := 0; j < len(poss[int(fieldID)][i]); j++ { | |||||
_, err2 = metaEncoder.PutU64(poss[int(fieldID)][i][j]) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | |||||
} | |||||
// append data | |||||
data = append(data, storedFieldValues[i]...) | |||||
// update curr | |||||
curr += len(storedFieldValues[i]) | |||||
var err2 error | |||||
curr, data, err2 = persistStoredFieldValues(fieldID, | |||||
storedFieldValues, stf, spf, curr, metaEncoder, data) | |||||
if err2 != nil { | |||||
return 0, nil, err2 | |||||
} | } | ||||
} | } | ||||
} | } | ||||
// return value is the start of the stored index | // return value is the start of the stored index | ||||
offset := uint64(w.Count()) | |||||
storedIndexOffset := uint64(w.Count()) | |||||
// now write out the stored doc index | // now write out the stored doc index | ||||
for docNum := range docNumOffsets { | |||||
err := binary.Write(w, binary.BigEndian, docNumOffsets[docNum]) | |||||
for _, docNumOffset := range docNumOffsets { | |||||
err := binary.Write(w, binary.BigEndian, docNumOffset) | |||||
if err != nil { | if err != nil { | ||||
return 0, nil, err | return 0, nil, err | ||||
} | } | ||||
} | } | ||||
return offset, rv, nil | |||||
return storedIndexOffset, rv, nil | |||||
} | } | ||||
// mergeFields builds a unified list of fields used across all the input segments | |||||
func mergeFields(segments []*Segment) []string { | |||||
fieldsMap := map[string]struct{}{} | |||||
// copyStoredDocs writes out a segment's stored doc info, optimized by | |||||
// using a single Write() call for the entire set of bytes. The | |||||
// newDocNumOffsets is filled with the new offsets for each doc. | |||||
func (s *SegmentBase) copyStoredDocs(newDocNum uint64, newDocNumOffsets []uint64, | |||||
w *CountHashWriter) error { | |||||
if s.numDocs <= 0 { | |||||
return nil | |||||
} | |||||
indexOffset0, storedOffset0, _, _, _ := | |||||
s.getDocStoredOffsets(0) // the segment's first doc | |||||
indexOffsetN, storedOffsetN, readN, metaLenN, dataLenN := | |||||
s.getDocStoredOffsets(s.numDocs - 1) // the segment's last doc | |||||
storedOffset0New := uint64(w.Count()) | |||||
storedBytes := s.mem[storedOffset0 : storedOffsetN+readN+metaLenN+dataLenN] | |||||
_, err := w.Write(storedBytes) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
// remap the storedOffset's for the docs into new offsets relative | |||||
// to storedOffset0New, filling the given docNumOffsetsOut array | |||||
for indexOffset := indexOffset0; indexOffset <= indexOffsetN; indexOffset += 8 { | |||||
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8]) | |||||
storedOffsetNew := storedOffset - storedOffset0 + storedOffset0New | |||||
newDocNumOffsets[newDocNum] = storedOffsetNew | |||||
newDocNum += 1 | |||||
} | |||||
return nil | |||||
} | |||||
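copyStoredDocs can splice a whole segment's stored-docs region into the new file with one Write precisely because every per-doc offset then differs from its old value by a constant: the distance between the old and new region starts. A sketch of that remapping arithmetic:

package main

import "fmt"

// When a stored-docs region is byte-copied verbatim, each per-doc offset
// stays valid after shifting by (newBase - oldBase), which is all the
// remapping loop in copyStoredDocs does.
func remapOffsets(oldOffsets []uint64, oldBase, newBase uint64) []uint64 {
	rv := make([]uint64, len(oldOffsets))
	for i, off := range oldOffsets {
		rv[i] = off - oldBase + newBase
	}
	return rv
}

func main() {
	fmt.Println(remapOffsets([]uint64{100, 140, 190}, 100, 5000)) // [5000 5040 5090]
}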
// mergeFields builds a unified list of fields used across all the | |||||
// input segments, and computes whether the fields are the same across | |||||
// segments (which requires the fields to be sorted the same way |||||
// across segments) |||||
func mergeFields(segments []*SegmentBase) (bool, []string) { | |||||
fieldsSame := true | |||||
var segment0Fields []string | |||||
if len(segments) > 0 { | |||||
segment0Fields = segments[0].Fields() | |||||
} | |||||
fieldsExist := map[string]struct{}{} | |||||
for _, segment := range segments { | for _, segment := range segments { | ||||
fields := segment.Fields() | fields := segment.Fields() | ||||
for _, field := range fields { | |||||
fieldsMap[field] = struct{}{} | |||||
for fieldi, field := range fields { | |||||
fieldsExist[field] = struct{}{} | |||||
if len(segment0Fields) != len(fields) || segment0Fields[fieldi] != field { | |||||
fieldsSame = false | |||||
} | |||||
} | } | ||||
} | } | ||||
rv := make([]string, 0, len(fieldsMap)) | |||||
rv := make([]string, 0, len(fieldsExist)) | |||||
// ensure _id stays first | // ensure _id stays first | ||||
rv = append(rv, "_id") | rv = append(rv, "_id") | ||||
for k := range fieldsMap { | |||||
for k := range fieldsExist { | |||||
if k != "_id" { | if k != "_id" { | ||||
rv = append(rv, k) | rv = append(rv, k) | ||||
} | } | ||||
} | } | ||||
return rv | |||||
sort.Strings(rv[1:]) // leave _id as first | |||||
return fieldsSame, rv | |||||
} | } |
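The added sort.Strings(rv[1:]) is what makes the fieldsSame comparison meaningful: sorting a sub-slice in place leaves _id pinned at index 0 while giving every segment the same deterministic field order. For instance:

package main

import (
	"fmt"
	"sort"
)

// sort.Strings on rv[1:] sorts only the tail of the slice in place, so the
// first element is untouched while the rest become deterministically ordered.
func main() {
	rv := []string{"_id", "title", "body", "author"}
	sort.Strings(rv[1:])
	fmt.Println(rv) // [_id author body title]
}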
// PostingsList is an in-memory representation of a postings list | // PostingsList is an in-memory representation of a postings list | ||||
type PostingsList struct { | type PostingsList struct { | ||||
sb *SegmentBase | sb *SegmentBase | ||||
term []byte | |||||
postingsOffset uint64 | postingsOffset uint64 | ||||
freqOffset uint64 | freqOffset uint64 | ||||
locOffset uint64 | locOffset uint64 | ||||
locBitmap *roaring.Bitmap | locBitmap *roaring.Bitmap | ||||
postings *roaring.Bitmap | postings *roaring.Bitmap | ||||
except *roaring.Bitmap | except *roaring.Bitmap | ||||
postingKey []byte | |||||
} | } | ||||
// Iterator returns an iterator for this postings list | // Iterator returns an iterator for this postings list | ||||
func (p *PostingsList) Iterator() segment.PostingsIterator { | func (p *PostingsList) Iterator() segment.PostingsIterator { | ||||
rv := &PostingsIterator{ | |||||
postings: p, | |||||
return p.iterator(nil) | |||||
} | |||||
func (p *PostingsList) iterator(rv *PostingsIterator) *PostingsIterator { | |||||
if rv == nil { | |||||
rv = &PostingsIterator{} | |||||
} else { | |||||
*rv = PostingsIterator{} // clear the struct | |||||
} | } | ||||
rv.postings = p | |||||
if p.postings != nil { | if p.postings != nil { | ||||
// prepare the freq chunk details | // prepare the freq chunk details | ||||
var n uint64 | var n uint64 |
import "encoding/binary" | import "encoding/binary" | ||||
func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) { | func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) { | ||||
docStoredStartAddr := s.storedIndexOffset + (8 * docNum) | |||||
docStoredStart := binary.BigEndian.Uint64(s.mem[docStoredStartAddr : docStoredStartAddr+8]) | |||||
_, storedOffset, n, metaLen, dataLen := s.getDocStoredOffsets(docNum) | |||||
meta := s.mem[storedOffset+n : storedOffset+n+metaLen] | |||||
data := s.mem[storedOffset+n+metaLen : storedOffset+n+metaLen+dataLen] | |||||
return meta, data | |||||
} | |||||
func (s *SegmentBase) getDocStoredOffsets(docNum uint64) ( | |||||
uint64, uint64, uint64, uint64, uint64) { | |||||
indexOffset := s.storedIndexOffset + (8 * docNum) | |||||
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8]) | |||||
var n uint64 | var n uint64 | ||||
metaLen, read := binary.Uvarint(s.mem[docStoredStart : docStoredStart+binary.MaxVarintLen64]) | |||||
metaLen, read := binary.Uvarint(s.mem[storedOffset : storedOffset+binary.MaxVarintLen64]) | |||||
n += uint64(read) | n += uint64(read) | ||||
var dataLen uint64 | |||||
dataLen, read = binary.Uvarint(s.mem[docStoredStart+n : docStoredStart+n+binary.MaxVarintLen64]) | |||||
dataLen, read := binary.Uvarint(s.mem[storedOffset+n : storedOffset+n+binary.MaxVarintLen64]) | |||||
n += uint64(read) | n += uint64(read) | ||||
meta := s.mem[docStoredStart+n : docStoredStart+n+metaLen] | |||||
data := s.mem[docStoredStart+n+metaLen : docStoredStart+n+metaLen+dataLen] | |||||
return meta, data | |||||
return indexOffset, storedOffset, n, metaLen, dataLen | |||||
} | } |
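The refactor above splits offset discovery out of getDocStoredMetaAndCompressed so copyStoredDocs can reuse it. The on-disk shape it walks: an 8-byte big-endian pointer in the stored index, then two uvarints (meta length, data length) at the pointed-to location. A standalone sketch that encodes and decodes one such entry:

package main

import (
	"encoding/binary"
	"fmt"
)

// decodeStoredEntry walks the same layout as getDocStoredOffsets: an 8-byte
// big-endian pointer in the index, then metaLen and dataLen uvarints at the
// stored offset; n is how many header bytes the two uvarints consumed.
func decodeStoredEntry(mem []byte, indexOffset uint64) (storedOffset, n, metaLen, dataLen uint64) {
	storedOffset = binary.BigEndian.Uint64(mem[indexOffset : indexOffset+8])
	var read int
	metaLen, read = binary.Uvarint(mem[storedOffset:])
	n += uint64(read)
	dataLen, read = binary.Uvarint(mem[storedOffset+n:])
	n += uint64(read)
	return storedOffset, n, metaLen, dataLen
}

func main() {
	mem := make([]byte, 32)
	binary.BigEndian.PutUint64(mem[0:8], 8) // index entry 0 points at offset 8
	h := binary.PutUvarint(mem[8:], 3)      // metaLen = 3
	binary.PutUvarint(mem[8+h:], 5)         // dataLen = 5
	fmt.Println(decodeStoredEntry(mem, 0))  // 8 2 3 5
}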
return nil, err | return nil, err | ||||
} | } | ||||
var postings *PostingsList | |||||
for _, id := range ids { | for _, id := range ids { | ||||
postings, err := idDict.postingsList([]byte(id), nil) | |||||
postings, err = idDict.postingsList([]byte(id), nil, postings) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } |
return r.meta[string(key)] | return r.meta[string(key)] | ||||
} | } | ||||
// RollbackPoints returns an array of rollback points available | |||||
// for the application to make a decision on where to rollback | |||||
// to. A nil return value indicates that there are no available | |||||
// rollback points. | |||||
// RollbackPoints returns an array of rollback points available for | |||||
// the application to rollback to, with more recent rollback points | |||||
// (higher epochs) coming first. | |||||
func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) { | func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) { | ||||
if s.rootBolt == nil { | if s.rootBolt == nil { | ||||
return nil, fmt.Errorf("RollbackPoints: root is nil") | return nil, fmt.Errorf("RollbackPoints: root is nil") | ||||
snapshots := tx.Bucket(boltSnapshotsBucket) | snapshots := tx.Bucket(boltSnapshotsBucket) | ||||
if snapshots == nil { | if snapshots == nil { | ||||
return nil, fmt.Errorf("RollbackPoints: no snapshots available") | |||||
return nil, nil | |||||
} | } | ||||
rollbackPoints := []*RollbackPoint{} | rollbackPoints := []*RollbackPoint{} | ||||
revert.snapshot = indexSnapshot | revert.snapshot = indexSnapshot | ||||
revert.applied = make(chan error) | revert.applied = make(chan error) | ||||
if !s.unsafeBatch { | |||||
revert.persisted = make(chan error) | |||||
} | |||||
revert.persisted = make(chan error) | |||||
return nil | return nil | ||||
}) | }) | ||||
return fmt.Errorf("Rollback: failed with err: %v", err) | return fmt.Errorf("Rollback: failed with err: %v", err) | ||||
} | } | ||||
if revert.persisted != nil { | |||||
err = <-revert.persisted | |||||
} | |||||
return err | |||||
return <-revert.persisted | |||||
} | } |
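The rollback change drops the unsafeBatch branch: the persisted channel is now always created and Rollback simply blocks on it, so every caller observes the persistence outcome. A minimal sketch of that always-acknowledge pattern, with the actual snapshot work elided:

package main

import "fmt"

// The requester allocates the channel and blocks on it; the worker sends
// exactly one result. With no conditional nil channel, no branching is
// needed on the receive side.
func main() {
	persisted := make(chan error)
	go func() {
		// ... apply and persist the rollback snapshot (elided) ...
		persisted <- nil // report the outcome exactly once
	}()
	if err := <-persisted; err != nil {
		fmt.Println("rollback failed:", err)
	} else {
		fmt.Println("rollback persisted")
	}
}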
docBackIndexRowErr = err | docBackIndexRowErr = err | ||||
return | return | ||||
} | } | ||||
defer func() { | |||||
if cerr := kvreader.Close(); err == nil && cerr != nil { | |||||
docBackIndexRowErr = cerr | |||||
} | |||||
}() | |||||
for docID, doc := range batch.IndexOps { | for docID, doc := range batch.IndexOps { | ||||
backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID)) | backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID)) | ||||
docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow} | docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow} | ||||
} | } | ||||
err = kvreader.Close() | |||||
if err != nil { | |||||
docBackIndexRowErr = err | |||||
return | |||||
} | |||||
}() | }() | ||||
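Moving kvreader.Close into a defer guarantees the reader is released on every exit path (the loop above can no longer leak it on an early return), while the err == nil && cerr != nil guard keeps a Close failure from masking an earlier, more specific error. The idiom in isolation, with a hypothetical reader type:

package main

import (
	"errors"
	"fmt"
)

type reader struct{}

func (reader) Close() error { return errors.New("close failed") }

// work closes r on every exit path, and only surfaces the Close error when
// no earlier error would otherwise be hidden by it.
func work() (err error) {
	r := reader{}
	defer func() {
		if cerr := r.Close(); err == nil && cerr != nil {
			err = cerr
		}
	}()
	// ... use r on the happy path (elided) ...
	return nil
}

func main() { fmt.Println(work()) } // close failed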
// wait for analysis result | // wait for analysis result |
package bleve | package bleve | ||||
import ( | import ( | ||||
"context" | |||||
"sort" | "sort" | ||||
"sync" | "sync" | ||||
"time" | "time" | ||||
"golang.org/x/net/context" | |||||
"github.com/blevesearch/bleve/document" | "github.com/blevesearch/bleve/document" | ||||
"github.com/blevesearch/bleve/index" | "github.com/blevesearch/bleve/index" | ||||
"github.com/blevesearch/bleve/index/store" | "github.com/blevesearch/bleve/index/store" |
package bleve | package bleve | ||||
import ( | import ( | ||||
"context" | |||||
"encoding/json" | "encoding/json" | ||||
"fmt" | "fmt" | ||||
"os" | "os" | ||||
"sync/atomic" | "sync/atomic" | ||||
"time" | "time" | ||||
"golang.org/x/net/context" | |||||
"github.com/blevesearch/bleve/document" | "github.com/blevesearch/bleve/document" | ||||
"github.com/blevesearch/bleve/index" | "github.com/blevesearch/bleve/index" | ||||
"github.com/blevesearch/bleve/index/store" | "github.com/blevesearch/bleve/index/store" |
package search | package search | ||||
import ( | import ( | ||||
"context" | |||||
"time" | "time" | ||||
"github.com/blevesearch/bleve/index" | "github.com/blevesearch/bleve/index" | ||||
"golang.org/x/net/context" | |||||
) | ) | ||||
type Collector interface { | type Collector interface { |
package collector | package collector | ||||
import ( | import ( | ||||
"context" | |||||
"time" | "time" | ||||
"github.com/blevesearch/bleve/index" | "github.com/blevesearch/bleve/index" | ||||
"github.com/blevesearch/bleve/search" | "github.com/blevesearch/bleve/search" | ||||
"golang.org/x/net/context" | |||||
) | ) | ||||
type collectorStore interface { | type collectorStore interface { |
package goth | package goth | ||||
import ( | import ( | ||||
"context" | |||||
"fmt" | "fmt" | ||||
"net/http" | "net/http" | ||||
"golang.org/x/net/context" | |||||
"golang.org/x/oauth2" | "golang.org/x/oauth2" | ||||
) | ) | ||||
import ( | import ( | ||||
"bytes" | "bytes" | ||||
"crypto/hmac" | |||||
"crypto/sha256" | |||||
"encoding/hex" | |||||
"encoding/json" | "encoding/json" | ||||
"errors" | "errors" | ||||
"fmt" | |||||
"io" | "io" | ||||
"io/ioutil" | "io/ioutil" | ||||
"net/http" | "net/http" | ||||
"net/url" | "net/url" | ||||
"strings" | |||||
"crypto/hmac" | |||||
"crypto/sha256" | |||||
"encoding/hex" | |||||
"fmt" | |||||
"github.com/markbates/goth" | "github.com/markbates/goth" | ||||
"golang.org/x/oauth2" | "golang.org/x/oauth2" | ||||
) | ) | ||||
const ( | const ( | ||||
authURL string = "https://www.facebook.com/dialog/oauth" | authURL string = "https://www.facebook.com/dialog/oauth" | ||||
tokenURL string = "https://graph.facebook.com/oauth/access_token" | tokenURL string = "https://graph.facebook.com/oauth/access_token" | ||||
endpointProfile string = "https://graph.facebook.com/me?fields=email,first_name,last_name,link,about,id,name,picture,location" | |||||
endpointProfile string = "https://graph.facebook.com/me?fields=" | |||||
) | ) | ||||
// New creates a new Facebook provider, and sets up important connection details. | // New creates a new Facebook provider, and sets up important connection details. | ||||
// BeginAuth asks Facebook for an authentication end-point. | // BeginAuth asks Facebook for an authentication end-point. | ||||
func (p *Provider) BeginAuth(state string) (goth.Session, error) { | func (p *Provider) BeginAuth(state string) (goth.Session, error) { | ||||
url := p.config.AuthCodeURL(state) | |||||
authUrl := p.config.AuthCodeURL(state) | |||||
session := &Session{ | session := &Session{ | ||||
AuthURL: url, | |||||
AuthURL: authUrl, | |||||
} | } | ||||
return session, nil | return session, nil | ||||
} | } | ||||
hash.Write([]byte(sess.AccessToken)) | hash.Write([]byte(sess.AccessToken)) | ||||
appsecretProof := hex.EncodeToString(hash.Sum(nil)) | appsecretProof := hex.EncodeToString(hash.Sum(nil)) | ||||
response, err := p.Client().Get(endpointProfile + "&access_token=" + url.QueryEscape(sess.AccessToken) + "&appsecret_proof=" + appsecretProof) | |||||
reqUrl := fmt.Sprint( | |||||
endpointProfile, | |||||
strings.Join(p.config.Scopes, ","), | |||||
"&access_token=", | |||||
url.QueryEscape(sess.AccessToken), | |||||
"&appsecret_proof=", | |||||
appsecretProof, | |||||
) | |||||
response, err := p.Client().Get(reqUrl) | |||||
if err != nil { | if err != nil { | ||||
return user, err | return user, err | ||||
} | } | ||||
}, | }, | ||||
Scopes: []string{ | Scopes: []string{ | ||||
"email", | "email", | ||||
"first_name", | |||||
"last_name", | |||||
"link", | |||||
"about", | |||||
"id", | |||||
"name", | |||||
"picture", | |||||
"location", | |||||
}, | }, | ||||
} | } | ||||
defaultScopes := map[string]struct{}{ | |||||
"email": {}, | |||||
} | |||||
for _, scope := range scopes { | |||||
if _, exists := defaultScopes[scope]; !exists { | |||||
c.Scopes = append(c.Scopes, scope) | |||||
// lets a requested scope carry a field modifier, e.g. 'picture.type(large)' |||||
var found bool | |||||
for _, sc := range scopes { | |||||
sc := sc | |||||
for i, defScope := range c.Scopes { | |||||
if defScope == strings.Split(sc, ".")[0] { | |||||
c.Scopes[i] = sc | |||||
found = true | |||||
} | |||||
} | |||||
if !found { | |||||
c.Scopes = append(c.Scopes, sc) | |||||
} | } | ||||
found = false | |||||
} | } | ||||
return c | return c |
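Two things changed in the Facebook provider: the profile endpoint's field list is now derived from the configured scopes (a scope such as picture.type(large) replaces its plain picture default), and the request still sends appsecret_proof, an HMAC-SHA256 of the access token keyed by the app secret. A sketch of the resulting URL construction, assuming those two behaviors:

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/url"
	"strings"
)

// profileURL sketches what the diff's reqUrl construction produces: the
// configured scopes double as the Graph API field list, and appsecret_proof
// is hex(HMAC-SHA256(accessToken, key=appSecret)).
func profileURL(scopes []string, accessToken, appSecret string) string {
	mac := hmac.New(sha256.New, []byte(appSecret))
	mac.Write([]byte(accessToken))
	proof := hex.EncodeToString(mac.Sum(nil))
	return "https://graph.facebook.com/me?fields=" + strings.Join(scopes, ",") +
		"&access_token=" + url.QueryEscape(accessToken) +
		"&appsecret_proof=" + proof
}

func main() {
	fmt.Println(profileURL([]string{"email", "name", "picture.type(large)"}, "tok", "secret"))
}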
// Copyright 2016 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
// +build go1.7 | |||||
// Package ctxhttp provides helper functions for performing context-aware HTTP requests. | |||||
package ctxhttp // import "golang.org/x/net/context/ctxhttp" | |||||
import ( | |||||
"io" | |||||
"net/http" | |||||
"net/url" | |||||
"strings" | |||||
"golang.org/x/net/context" | |||||
) | |||||
// Do sends an HTTP request with the provided http.Client and returns | |||||
// an HTTP response. | |||||
// | |||||
// If the client is nil, http.DefaultClient is used. | |||||
// | |||||
// The provided ctx must be non-nil. If it is canceled or times out, | |||||
// ctx.Err() will be returned. | |||||
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { | |||||
if client == nil { | |||||
client = http.DefaultClient | |||||
} | |||||
resp, err := client.Do(req.WithContext(ctx)) | |||||
// If we got an error, and the context has been canceled, | |||||
// the context's error is probably more useful. | |||||
if err != nil { | |||||
select { | |||||
case <-ctx.Done(): | |||||
err = ctx.Err() | |||||
default: | |||||
} | |||||
} | |||||
return resp, err | |||||
} | |||||
// Get issues a GET request via the Do function. | |||||
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||||
req, err := http.NewRequest("GET", url, nil) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return Do(ctx, client, req) | |||||
} | |||||
// Head issues a HEAD request via the Do function. | |||||
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||||
req, err := http.NewRequest("HEAD", url, nil) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return Do(ctx, client, req) | |||||
} | |||||
// Post issues a POST request via the Do function. | |||||
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { | |||||
req, err := http.NewRequest("POST", url, body) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
req.Header.Set("Content-Type", bodyType) | |||||
return Do(ctx, client, req) | |||||
} | |||||
// PostForm issues a POST request via the Do function. | |||||
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { | |||||
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) | |||||
} |
// Copyright 2015 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
// +build !go1.7 | |||||
package ctxhttp // import "golang.org/x/net/context/ctxhttp" | |||||
import ( | |||||
"io" | |||||
"net/http" | |||||
"net/url" | |||||
"strings" | |||||
"golang.org/x/net/context" | |||||
) | |||||
func nop() {} | |||||
var ( | |||||
testHookContextDoneBeforeHeaders = nop | |||||
testHookDoReturned = nop | |||||
testHookDidBodyClose = nop | |||||
) | |||||
// Do sends an HTTP request with the provided http.Client and returns an HTTP response. | |||||
// If the client is nil, http.DefaultClient is used. | |||||
// If the context is canceled or times out, ctx.Err() will be returned. | |||||
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { | |||||
if client == nil { | |||||
client = http.DefaultClient | |||||
} | |||||
// TODO(djd): Respect any existing value of req.Cancel. | |||||
cancel := make(chan struct{}) | |||||
req.Cancel = cancel | |||||
type responseAndError struct { | |||||
resp *http.Response | |||||
err error | |||||
} | |||||
result := make(chan responseAndError, 1) | |||||
// Make local copies of test hooks closed over by goroutines below. | |||||
// Prevents data races in tests. | |||||
testHookDoReturned := testHookDoReturned | |||||
testHookDidBodyClose := testHookDidBodyClose | |||||
go func() { | |||||
resp, err := client.Do(req) | |||||
testHookDoReturned() | |||||
result <- responseAndError{resp, err} | |||||
}() | |||||
var resp *http.Response | |||||
select { | |||||
case <-ctx.Done(): | |||||
testHookContextDoneBeforeHeaders() | |||||
close(cancel) | |||||
// Clean up after the goroutine calling client.Do: | |||||
go func() { | |||||
if r := <-result; r.resp != nil { | |||||
testHookDidBodyClose() | |||||
r.resp.Body.Close() | |||||
} | |||||
}() | |||||
return nil, ctx.Err() | |||||
case r := <-result: | |||||
var err error | |||||
resp, err = r.resp, r.err | |||||
if err != nil { | |||||
return resp, err | |||||
} | |||||
} | |||||
c := make(chan struct{}) | |||||
go func() { | |||||
select { | |||||
case <-ctx.Done(): | |||||
close(cancel) | |||||
case <-c: | |||||
// The response's Body is closed. | |||||
} | |||||
}() | |||||
resp.Body = ¬ifyingReader{resp.Body, c} | |||||
return resp, nil | |||||
} | |||||
// Get issues a GET request via the Do function. | |||||
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||||
req, err := http.NewRequest("GET", url, nil) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return Do(ctx, client, req) | |||||
} | |||||
// Head issues a HEAD request via the Do function. | |||||
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||||
req, err := http.NewRequest("HEAD", url, nil) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return Do(ctx, client, req) | |||||
} | |||||
// Post issues a POST request via the Do function. | |||||
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { | |||||
req, err := http.NewRequest("POST", url, body) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
req.Header.Set("Content-Type", bodyType) | |||||
return Do(ctx, client, req) | |||||
} | |||||
// PostForm issues a POST request via the Do function. | |||||
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { | |||||
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) | |||||
} | |||||
// notifyingReader is an io.ReadCloser that closes the notify channel after | |||||
// Close is called or a Read fails on the underlying ReadCloser. | |||||
type notifyingReader struct { | |||||
io.ReadCloser | |||||
notify chan<- struct{} | |||||
} | |||||
func (r *notifyingReader) Read(p []byte) (int, error) { | |||||
n, err := r.ReadCloser.Read(p) | |||||
if err != nil && r.notify != nil { | |||||
close(r.notify) | |||||
r.notify = nil | |||||
} | |||||
return n, err | |||||
} | |||||
func (r *notifyingReader) Close() error { | |||||
err := r.ReadCloser.Close() | |||||
if r.notify != nil { | |||||
close(r.notify) | |||||
r.notify = nil | |||||
} | |||||
return err | |||||
} |
Copyright (c) 2009 The oauth2 Authors. All rights reserved. | |||||
Copyright (c) 2009 The Go Authors. All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | Redistribution and use in source and binary forms, with or without | ||||
modification, are permitted provided that the following conditions are | modification, are permitted provided that the following conditions are |
// Copyright 2014 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
// +build appengine | |||||
// App Engine hooks. | |||||
package oauth2 | |||||
import ( | |||||
"net/http" | |||||
"golang.org/x/net/context" | |||||
"golang.org/x/oauth2/internal" | |||||
"google.golang.org/appengine/urlfetch" | |||||
) | |||||
func init() { | |||||
internal.RegisterContextClientFunc(contextClientAppEngine) | |||||
} | |||||
func contextClientAppEngine(ctx context.Context) (*http.Client, error) { | |||||
return urlfetch.Client(ctx), nil | |||||
} |
// Copyright 2018 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
// +build appengine | |||||
package internal | |||||
import "google.golang.org/appengine/urlfetch" | |||||
func init() { | |||||
appengineClientHook = urlfetch.Client | |||||
} |
// Copyright 2017 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
// Package internal contains support packages for oauth2 package. | |||||
package internal |
// Use of this source code is governed by a BSD-style | // Use of this source code is governed by a BSD-style | ||||
// license that can be found in the LICENSE file. | // license that can be found in the LICENSE file. | ||||
// Package internal contains support packages for oauth2 package. | |||||
package internal | package internal | ||||
import ( | import ( | ||||
"bufio" | |||||
"crypto/rsa" | "crypto/rsa" | ||||
"crypto/x509" | "crypto/x509" | ||||
"encoding/pem" | "encoding/pem" | ||||
"errors" | "errors" | ||||
"fmt" | "fmt" | ||||
"io" | |||||
"strings" | |||||
) | ) | ||||
// ParseKey converts the binary contents of a private key file | // ParseKey converts the binary contents of a private key file | ||||
if err != nil { | if err != nil { | ||||
parsedKey, err = x509.ParsePKCS1PrivateKey(key) | parsedKey, err = x509.ParsePKCS1PrivateKey(key) | ||||
if err != nil { | if err != nil { | ||||
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err) | |||||
return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err) | |||||
} | } | ||||
} | } | ||||
parsed, ok := parsedKey.(*rsa.PrivateKey) | parsed, ok := parsedKey.(*rsa.PrivateKey) | ||||
} | } | ||||
return parsed, nil | return parsed, nil | ||||
} | } | ||||
func ParseINI(ini io.Reader) (map[string]map[string]string, error) { | |||||
result := map[string]map[string]string{ | |||||
"": map[string]string{}, // root section | |||||
} | |||||
scanner := bufio.NewScanner(ini) | |||||
currentSection := "" | |||||
for scanner.Scan() { | |||||
line := strings.TrimSpace(scanner.Text()) | |||||
if strings.HasPrefix(line, ";") { | |||||
// comment. | |||||
continue | |||||
} | |||||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { | |||||
currentSection = strings.TrimSpace(line[1 : len(line)-1]) | |||||
result[currentSection] = map[string]string{} | |||||
continue | |||||
} | |||||
parts := strings.SplitN(line, "=", 2) | |||||
if len(parts) == 2 && parts[0] != "" { | |||||
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) | |||||
} | |||||
} | |||||
if err := scanner.Err(); err != nil { | |||||
return nil, fmt.Errorf("error scanning ini: %v", err) | |||||
} | |||||
return result, nil | |||||
} | |||||
func CondVal(v string) []string { | |||||
if v == "" { | |||||
return nil | |||||
} | |||||
return []string{v} | |||||
} |
// Use of this source code is governed by a BSD-style | // Use of this source code is governed by a BSD-style | ||||
// license that can be found in the LICENSE file. | // license that can be found in the LICENSE file. | ||||
// Package internal contains support packages for oauth2 package. | |||||
package internal | package internal | ||||
import ( | import ( | ||||
"context" | |||||
"encoding/json" | "encoding/json" | ||||
"errors" | |||||
"fmt" | "fmt" | ||||
"io" | "io" | ||||
"io/ioutil" | "io/ioutil" | ||||
"strings" | "strings" | ||||
"time" | "time" | ||||
"golang.org/x/net/context" | |||||
"golang.org/x/net/context/ctxhttp" | |||||
) | ) | ||||
// Token represents the crendentials used to authorize | |||||
// Token represents the credentials used to authorize | |||||
// the requests to access protected resources on the OAuth 2.0 | // the requests to access protected resources on the OAuth 2.0 | ||||
// provider's backend. | // provider's backend. | ||||
// | // | ||||
var brokenAuthHeaderProviders = []string{ | var brokenAuthHeaderProviders = []string{ | ||||
"https://accounts.google.com/", | "https://accounts.google.com/", | ||||
"https://api.codeswholesale.com/oauth/token", | |||||
"https://api.dropbox.com/", | "https://api.dropbox.com/", | ||||
"https://api.dropboxapi.com/", | "https://api.dropboxapi.com/", | ||||
"https://api.instagram.com/", | "https://api.instagram.com/", | ||||
"https://api.pushbullet.com/", | "https://api.pushbullet.com/", | ||||
"https://api.soundcloud.com/", | "https://api.soundcloud.com/", | ||||
"https://api.twitch.tv/", | "https://api.twitch.tv/", | ||||
"https://id.twitch.tv/", | |||||
"https://app.box.com/", | "https://app.box.com/", | ||||
"https://api.box.com/", | |||||
"https://connect.stripe.com/", | "https://connect.stripe.com/", | ||||
"https://login.mailchimp.com/", | |||||
"https://login.microsoftonline.com/", | "https://login.microsoftonline.com/", | ||||
"https://login.salesforce.com/", | "https://login.salesforce.com/", | ||||
"https://login.windows.net", | |||||
"https://login.live.com/", | |||||
"https://login.live-int.com/", | |||||
"https://oauth.sandbox.trainingpeaks.com/", | "https://oauth.sandbox.trainingpeaks.com/", | ||||
"https://oauth.trainingpeaks.com/", | "https://oauth.trainingpeaks.com/", | ||||
"https://oauth.vk.com/", | "https://oauth.vk.com/", | ||||
"https://www.strava.com/oauth/", | "https://www.strava.com/oauth/", | ||||
"https://www.wunderlist.com/oauth/", | "https://www.wunderlist.com/oauth/", | ||||
"https://api.patreon.com/", | "https://api.patreon.com/", | ||||
"https://sandbox.codeswholesale.com/oauth/token", | |||||
"https://api.sipgate.com/v1/authorization/oauth", | |||||
"https://api.medium.com/v1/tokens", | |||||
"https://log.finalsurge.com/oauth/token", | |||||
"https://multisport.todaysplan.com.au/rest/oauth/access_token", | |||||
"https://whats.todaysplan.com.au/rest/oauth/access_token", | |||||
"https://stackoverflow.com/oauth/access_token", | |||||
"https://account.health.nokia.com", | |||||
"https://accounts.zoho.com", | |||||
} | |||||
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints. | |||||
var brokenAuthHeaderDomains = []string{ | |||||
".auth0.com", | |||||
".force.com", | |||||
".myshopify.com", | |||||
".okta.com", | |||||
".oktapreview.com", | |||||
} | } | ||||
 func RegisterBrokenAuthHeaderProvider(tokenURL string) {

         }
     }
+    if u, err := url.Parse(tokenURL); err == nil {
+        for _, s := range brokenAuthHeaderDomains {
+            if strings.HasSuffix(u.Host, s) {
+                return false
+            }
+        }
+    }

     // Assume the provider implements the spec properly
     // otherwise. We can add more exceptions as they're
     // discovered. We will _not_ be adding configurable hooks

 }
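A minimal usage sketch of the registration path above, assuming a hypothetical in-house token endpoint; RegisterBrokenAuthHeaderProvider is re-exported by the public oauth2 package:

package main

import "golang.org/x/oauth2"

func main() {
	// Endpoints registered here are matched by prefix in
	// providerAuthHeaderWorks; RetrieveToken then sends client_id and
	// client_secret in the POST body instead of the Basic auth header.
	oauth2.RegisterBrokenAuthHeaderProvider("https://sso.example.com/oauth/token")
}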
 func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
-    hc, err := ContextClient(ctx)
-    if err != nil {
-        return nil, err
-    }
-    v.Set("client_id", clientID)
     bustedAuth := !providerAuthHeaderWorks(tokenURL)
-    if bustedAuth && clientSecret != "" {
-        v.Set("client_secret", clientSecret)
+    if bustedAuth {
+        if clientID != "" {
+            v.Set("client_id", clientID)
+        }
+        if clientSecret != "" {
+            v.Set("client_secret", clientSecret)
+        }
     }
     req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
     if err != nil {
         return nil, err
     }
     req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
     if !bustedAuth {
-        req.SetBasicAuth(clientID, clientSecret)
+        req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
     }
-    r, err := hc.Do(req)
+    r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
     if err != nil {
         return nil, err
     }
     defer r.Body.Close()
     body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
     if err != nil {
         return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
     }
     if code := r.StatusCode; code < 200 || code > 299 {
-        return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
+        return nil, &RetrieveError{
+            Response: r,
+            Body:     body,
+        }
     }
     var token *Token

     if token.RefreshToken == "" {
         token.RefreshToken = v.Get("refresh_token")
     }
+    if token.AccessToken == "" {
+        return token, errors.New("oauth2: server response missing access_token")
+    }
     return token, nil
 }
+type RetrieveError struct {
+    Response *http.Response
+    Body     []byte
+}
+
+func (r *RetrieveError) Error() string {
+    return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+}
vendor/golang.org/x/oauth2/internal/transport.go:

 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.

+// Package internal contains support packages for oauth2 package.
 package internal

 import (
+    "context"
     "net/http"
-
-    "golang.org/x/net/context"
 )
 // HTTPClient is the context key to use with golang.org/x/net/context's
 // because nobody else can create a ContextKey, being unexported.
 type ContextKey struct{}

-// ContextClientFunc is a func which tries to return an *http.Client
-// given a Context value. If it returns an error, the search stops
-// with that error. If it returns (nil, nil), the search continues
-// down the list of registered funcs.
-type ContextClientFunc func(context.Context) (*http.Client, error)
-var contextClientFuncs []ContextClientFunc
-func RegisterContextClientFunc(fn ContextClientFunc) {
-    contextClientFuncs = append(contextClientFuncs, fn)
-}
+var appengineClientHook func(context.Context) *http.Client

-func ContextClient(ctx context.Context) (*http.Client, error) {
+func ContextClient(ctx context.Context) *http.Client {
     if ctx != nil {
         if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
-            return hc, nil
-        }
-    }
-    for _, fn := range contextClientFuncs {
-        c, err := fn(ctx)
-        if err != nil {
-            return nil, err
-        }
-        if c != nil {
-            return c, nil
+            return hc
         }
     }
-    return http.DefaultClient, nil
-}
-func ContextTransport(ctx context.Context) http.RoundTripper {
-    hc, err := ContextClient(ctx)
-    // This is a rare error case (somebody using nil on App Engine).
-    if err != nil {
-        return ErrorTransport{err}
+    if appengineClientHook != nil {
+        return appengineClientHook(ctx)
     }
-    return hc.Transport
-}
-// ErrorTransport returns the specified error on RoundTrip.
-// This RoundTripper should be used in rare error cases where
-// error handling can be postponed to response handling time.
-type ErrorTransport struct{ Err error }
-func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) {
-    return nil, t.Err
+    return http.DefaultClient
 }
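With RegisterContextClientFunc and ErrorTransport gone, the remaining way to supply a custom *http.Client is the context key; a sketch against the public API, with a placeholder endpoint and code:

package main

import (
	"context"
	"net/http"
	"time"

	"golang.org/x/oauth2"
)

func main() {
	// ContextClient finds this client under the oauth2.HTTPClient key and
	// uses it for the token request; otherwise http.DefaultClient is used.
	hc := &http.Client{Timeout: 10 * time.Second}
	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, hc)

	conf := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: "https://sso.example.com/token"}}
	_, _ = conf.Exchange(ctx, "authorization-code") // placeholder code
}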
vendor/golang.org/x/oauth2/oauth2.go:

 // license that can be found in the LICENSE file.

 // Package oauth2 provides support for making
-// OAuth2 authorized and authenticated HTTP requests.
+// OAuth2 authorized and authenticated HTTP requests,
+// as specified in RFC 6749.
 // It can additionally grant authorization with Bearer JWT.
 package oauth2 // import "golang.org/x/oauth2"

 import (
     "bytes"
+    "context"
     "errors"
     "net/http"
     "net/url"
     "strings"
     "sync"

-    "golang.org/x/net/context"
     "golang.org/x/oauth2/internal"
 )

 // that asks for permissions for the required scopes explicitly.
 //
 // State is a token to protect the user from CSRF attacks. You must
-// always provide a non-zero string and validate that it matches the
+// always provide a non-empty string and validate that it matches
 // the state query parameter on your redirect callback.
 // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
 //
 // Opts may include AccessTypeOnline or AccessTypeOffline, as well
 // as ApprovalForce.
+// It can also be used to pass the PKCE challenge.
+// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
 func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
     var buf bytes.Buffer
     buf.WriteString(c.Endpoint.AuthURL)
     v := url.Values{
         "response_type": {"code"},
         "client_id":     {c.ClientID},
-        "redirect_uri":  internal.CondVal(c.RedirectURL),
-        "scope":         internal.CondVal(strings.Join(c.Scopes, " ")),
-        "state":         internal.CondVal(state),
+    }
+    if c.RedirectURL != "" {
+        v.Set("redirect_uri", c.RedirectURL)
+    }
+    if len(c.Scopes) > 0 {
+        v.Set("scope", strings.Join(c.Scopes, " "))
+    }
+    if state != "" {
+        // TODO(light): Docs say never to omit state; don't allow empty.
+        v.Set("state", state)
     }
     for _, opt := range opts {
         opt.setValue(v)
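A sketch of passing the PKCE challenge mentioned in the new doc comment; SetAuthURLParam is the existing AuthCodeOption constructor, and the endpoint, state, and challenge values are placeholders:

package main

import (
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{
		ClientID:    "client-id",
		RedirectURL: "https://app.example.com/callback",
		Endpoint:    oauth2.Endpoint{AuthURL: "https://sso.example.com/authorize"},
	}
	// PKCE (RFC 7636): the S256 hash of the verifier travels with the
	// authorization request as extra query parameters.
	fmt.Println(conf.AuthCodeURL("random-state",
		oauth2.SetAuthURLParam("code_challenge", "hashed-verifier"),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
	))
}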
 // The HTTP client to use is derived from the context.
 // If nil, http.DefaultClient is used.
 func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
-    return retrieveToken(ctx, c, url.Values{
+    v := url.Values{
         "grant_type": {"password"},
         "username":   {username},
         "password":   {password},
-        "scope":      internal.CondVal(strings.Join(c.Scopes, " ")),
-    })
+    }
+    if len(c.Scopes) > 0 {
+        v.Set("scope", strings.Join(c.Scopes, " "))
+    }
+    return retrieveToken(ctx, c, v)
 }
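A minimal sketch of the password grant above, with placeholder credentials and endpoint; note the scope form value is only set when Scopes is non-empty, per the guard in the hunk:

package main

import (
	"context"
	"log"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{
		ClientID: "client-id",
		Endpoint: oauth2.Endpoint{TokenURL: "https://sso.example.com/token"},
		Scopes:   []string{"read"},
	}
	tok, err := conf.PasswordCredentialsToken(context.Background(), "user", "secret")
	if err != nil {
		log.Fatal(err)
	}
	log.Println(tok.AccessToken)
}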
 // Exchange converts an authorization code into a token.
 //
 // The code will be in the *http.Request.FormValue("code"). Before
 // calling Exchange, be sure to validate FormValue("state").
-func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
-    return retrieveToken(ctx, c, url.Values{
-        "grant_type":   {"authorization_code"},
-        "code":         {code},
-        "redirect_uri": internal.CondVal(c.RedirectURL),
-        "scope":        internal.CondVal(strings.Join(c.Scopes, " ")),
-    })
+//
+// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
+// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
+func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
+    v := url.Values{
+        "grant_type": {"authorization_code"},
+        "code":       {code},
+    }
+    if c.RedirectURL != "" {
+        v.Set("redirect_uri", c.RedirectURL)
+    }
+    for _, opt := range opts {
+        opt.setValue(v)
+    }
+    return retrieveToken(ctx, c, v)
 }
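The variadic opts let the PKCE verifier ride along at exchange time; a sketch with placeholder code, verifier, and endpoint values:

package main

import (
	"context"
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{
		ClientID: "client-id",
		Endpoint: oauth2.Endpoint{TokenURL: "https://sso.example.com/token"},
	}
	// The plain code_verifier generated before AuthCodeURL is replayed
	// here; the server hashes it and compares it to the earlier challenge.
	tok, err := conf.Exchange(context.Background(), "authorization-code",
		oauth2.SetAuthURLParam("code_verifier", "plain-verifier"))
	fmt.Println(tok, err)
}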
 // Client returns an HTTP client using the provided token.

 // NewClient creates an *http.Client from a Context and TokenSource.
 // The returned client is not valid beyond the lifetime of the context.
 //
+// Note that if a custom *http.Client is provided via the Context it
+// is used only for token acquisition and is not used to configure the
+// *http.Client returned from NewClient.
+//
 // As a special case, if src is nil, a non-OAuth2 client is returned
 // using the provided context. This exists to support related OAuth2
 // packages.
 func NewClient(ctx context.Context, src TokenSource) *http.Client {
     if src == nil {
-        c, err := internal.ContextClient(ctx)
-        if err != nil {
-            return &http.Client{Transport: internal.ErrorTransport{Err: err}}
-        }
-        return c
+        return internal.ContextClient(ctx)
     }
     return &http.Client{
         Transport: &Transport{
-            Base:   internal.ContextTransport(ctx),
+            Base:   internal.ContextClient(ctx).Transport,
             Source: ReuseTokenSource(nil, src),
         },
     }
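A sketch illustrating the new doc note: the context-supplied client is consulted only for token acquisition, while the returned client wraps its own oauth2.Transport. Token and URLs below are placeholders:

package main

import (
	"context"
	"net/http"
	"time"

	"golang.org/x/oauth2"
)

func main() {
	hc := &http.Client{Timeout: 5 * time.Second} // used for token refreshes only
	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, hc)

	conf := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: "https://sso.example.com/token"}}
	client := conf.Client(ctx, &oauth2.Token{AccessToken: "static-token"})

	// client adds "Authorization: Bearer static-token" to outgoing requests.
	_, _ = client.Get("https://api.example.com/resource")
}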
vendor/golang.org/x/oauth2/token.go:

 package oauth2

 import (
+    "context"
+    "fmt"
     "net/http"
     "net/url"
     "strconv"
     "strings"
     "time"

-    "golang.org/x/net/context"
     "golang.org/x/oauth2/internal"
 )

 // expirations due to client-server time mismatches.
 const expiryDelta = 10 * time.Second

-// Token represents the crendentials used to authorize
+// Token represents the credentials used to authorize
 // the requests to access protected resources on the OAuth 2.0
 // provider's backend.
 //

     if t.Expiry.IsZero() {
         return false
     }
-    return t.Expiry.Add(-expiryDelta).Before(time.Now())
+    return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now())
 }
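The new Round(0) call strips Go's monotonic clock reading before the comparison; a standalone sketch of the distinction (illustrative only, not part of the vendored code):

package main

import (
	"fmt"
	"time"
)

func main() {
	// time.Now() carries a monotonic reading; an Expiry that round-tripped
	// through JSON has wall-clock time only. Round(0) strips the monotonic
	// part so the expiry check is a pure wall-clock comparison.
	now := time.Now()
	fmt.Println(now)          // includes an "m=+..." monotonic suffix
	fmt.Println(now.Round(0)) // monotonic reading stripped
}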
 // Valid reports whether t is non-nil, has an AccessToken, and is not expired.

 func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
     tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
     if err != nil {
+        if rErr, ok := err.(*internal.RetrieveError); ok {
+            return nil, (*RetrieveError)(rErr)
+        }
         return nil, err
     }
     return tokenFromInternal(tk), nil
 }

+// RetrieveError is the error returned when the token endpoint returns a
+// non-2XX HTTP status code.
+type RetrieveError struct {
+    Response *http.Response
+    // Body is the body that was consumed by reading Response.Body.
+    // It may be truncated.
+    Body []byte
+}
+
+func (r *RetrieveError) Error() string {
+    return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+}
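Since the internal error is converted to the exported type above, callers can now inspect failed token responses; a sketch with a placeholder endpoint and code:

package main

import (
	"context"
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: "https://sso.example.com/token"}}
	_, err := conf.Exchange(context.Background(), "bad-code")
	// Non-2xx token responses surface as *oauth2.RetrieveError, carrying
	// the raw *http.Response and the (possibly truncated) body.
	if rErr, ok := err.(*oauth2.RetrieveError); ok {
		fmt.Println(rErr.Response.StatusCode, string(rErr.Body))
	} else if err != nil {
		fmt.Println("transport error:", err)
	}
}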
vendor/golang.org/x/oauth2/transport.go:

 }

 // RoundTrip authorizes and authenticates the request with an
-// access token. If no token exists or token is expired,
-// tries to refresh/fetch a new token.
+// access token from Transport's Source.
 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+    reqBodyClosed := false
+    if req.Body != nil {
+        defer func() {
+            if !reqBodyClosed {
+                req.Body.Close()
+            }
+        }()
+    }
     if t.Source == nil {
         return nil, errors.New("oauth2: Transport's Source is nil")
     }

     token.SetAuthHeader(req2)
     t.setModReq(req, req2)
     res, err := t.base().RoundTrip(req2)
+    // req.Body is assumed to have been closed by the base RoundTripper.
+    reqBodyClosed = true
     if err != nil {
         t.setModReq(req, nil)
         return nil, err