* Update dep github.com/markbates/goth
* Update dep github.com/blevesearch/bleve
* Update dep golang.org/x/oauth2
* Fix github.com/blevesearch/bleve to c74e08f039
* Update dep golang.org/x/oauth2
@@ -90,7 +90,7 @@
  revision = "3a771d992973f24aa725d07868b467d1ddfceafb"

[[projects]]
  digest = "1:67351095005f164e748a5a21899d1403b03878cb2d40a7b0f742376e6eeda974"
  digest = "1:c10f35be6200b09e26da267ca80f837315093ecaba27e7a223071380efb9dd32"
  name = "github.com/blevesearch/bleve"
  packages = [
    ".",
@@ -135,7 +135,7 @@
    "search/searcher",
  ]
  pruneopts = "NUT"
  revision = "ff210fbc6d348ad67aa5754eaea11a463fcddafd"
  revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"

[[projects]]
  branch = "master"
@@ -557,7 +557,7 @@
  revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf"

[[projects]]
  digest = "1:23f75ae90fcc38dac6fad6881006ea7d0f2c78db5f9f81f3df558dc91460e61f"
  digest = "1:4b992ec853d0ea9bac3dcf09a64af61de1a392e6cb0eef2204c0c92f4ae6b911"
  name = "github.com/markbates/goth"
  packages = [
    ".",
@@ -572,8 +572,8 @@
    "providers/twitter",
  ]
  pruneopts = "NUT"
  revision = "f9c6649ab984d6ea71ef1e13b7b1cdffcf4592d3"
  version = "v1.46.1"
  revision = "bc6d8ddf751a745f37ca5567dbbfc4157bbf5da9"
  version = "v1.47.2"

[[projects]]
  digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5"
@@ -809,10 +809,11 @@
[[projects]]
  branch = "master"
  digest = "1:6d5ed712653ea5321fe3e3475ab2188cf362a4e0d31e9fd3acbd4dfbbca0d680"
  digest = "1:d0a0bdd2b64d981aa4e6a1ade90431d042cd7fa31b584e33d45e62cbfec43380"
  name = "golang.org/x/net"
  packages = [
    "context",
    "context/ctxhttp",
    "html",
    "html/atom",
    "html/charset",
@@ -821,14 +822,15 @@
  revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344"

[[projects]]
  digest = "1:8159a9cda4b8810aaaeb0d60e2fa68e2fd86d8af4ec8f5059830839e3c8d93d5"
  branch = "master"
  digest = "1:274a6321a5a9f185eeb3fab5d7d8397e0e9f57737490d749f562c7e205ffbc2e"
  name = "golang.org/x/oauth2"
  packages = [
    ".",
    "internal",
  ]
  pruneopts = "NUT"
  revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
  revision = "c453e0c757598fd055e170a3a359263c91e13153"

[[projects]]
  digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3"
@@ -14,6 +14,12 @@ ignored = ["google.golang.org/appengine*"]
  branch = "master"
  name = "code.gitea.io/sdk"

[[constraint]]
  # branch = "master"
  revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
  name = "github.com/blevesearch/bleve"
  # Not targeting v0.7.0 since "standard" is used only just after this tag
[[constraint]]
  revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
  name = "golang.org/x/crypto"
@@ -61,7 +67,7 @@ ignored = ["google.golang.org/appengine*"]

[[constraint]]
  name = "github.com/markbates/goth"
  version = "1.46.1"
  version = "1.47.2"

[[constraint]]
  branch = "master"
@@ -105,7 +111,7 @@ ignored = ["google.golang.org/appengine*"]
  source = "github.com/go-gitea/bolt"

[[override]]
  revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
  branch = "master"
  name = "golang.org/x/oauth2"

[[constraint]]
@@ -15,11 +15,12 @@
package bleve

import (
	"context"

	"github.com/blevesearch/bleve/document"
	"github.com/blevesearch/bleve/index"
	"github.com/blevesearch/bleve/index/store"
	"github.com/blevesearch/bleve/mapping"
	"golang.org/x/net/context"
)

// A Batch groups together multiple Index and Delete
@@ -100,8 +100,8 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
	// prepare new index snapshot
	newSnapshot := &IndexSnapshot{
		parent:   s,
		segment:  make([]*SegmentSnapshot, nsegs, nsegs+1),
		offsets:  make([]uint64, nsegs, nsegs+1),
		segment:  make([]*SegmentSnapshot, 0, nsegs+1),
		offsets:  make([]uint64, 0, nsegs+1),
		internal: make(map[string][]byte, len(s.root.internal)),
		epoch:    s.nextSnapshotEpoch,
		refs:     1,
@@ -124,24 +124,29 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
				return err
			}
		}
		newSnapshot.segment[i] = &SegmentSnapshot{
		newss := &SegmentSnapshot{
			id:         s.root.segment[i].id,
			segment:    s.root.segment[i].segment,
			cachedDocs: s.root.segment[i].cachedDocs,
		}
		s.root.segment[i].segment.AddRef()

		// apply new obsoletions
		if s.root.segment[i].deleted == nil {
			newSnapshot.segment[i].deleted = delta
			newss.deleted = delta
		} else {
			newSnapshot.segment[i].deleted = roaring.Or(s.root.segment[i].deleted, delta)
			newss.deleted = roaring.Or(s.root.segment[i].deleted, delta)
		}
		newSnapshot.offsets[i] = running
		running += s.root.segment[i].Count()

		// check for live size before copying
		if newss.LiveSize() > 0 {
			newSnapshot.segment = append(newSnapshot.segment, newss)
			s.root.segment[i].segment.AddRef()
			newSnapshot.offsets = append(newSnapshot.offsets, running)
			running += s.root.segment[i].Count()
		}
	}

	// append new segment, if any, to end of the new index snapshot
	if next.data != nil {
		newSegmentSnapshot := &SegmentSnapshot{
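To make the skip-empty-segments flow above concrete, here is a minimal self-contained sketch of the same pattern: building the parallel segment/offset slices with append and dropping any snapshot whose live count is zero. segSnap and its fields are stand-ins, not bleve types.

package main

import "fmt"

type segSnap struct {
	id    uint64
	live  uint64 // stand-in for LiveSize()
	count uint64 // stand-in for Count()
}

func main() {
	root := []segSnap{{1, 2, 5}, {2, 0, 3}, {3, 4, 4}}
	segs := make([]segSnap, 0, len(root)+1)
	offsets := make([]uint64, 0, len(root)+1)
	var running uint64
	for _, ss := range root {
		if ss.live == 0 {
			continue // fully-deleted segment: leave it out of the new snapshot
		}
		segs = append(segs, ss)
		offsets = append(offsets, running)
		running += ss.count
	}
	fmt.Println(len(segs), offsets) // 2 [0 5]
}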
@@ -193,6 +198,12 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
	// prepare new index snapshot
	currSize := len(s.root.segment)
	newSize := currSize + 1 - len(nextMerge.old)

	// empty segments deletion
	if nextMerge.new == nil {
		newSize--
	}

	newSnapshot := &IndexSnapshot{
		parent:  s,
		segment: make([]*SegmentSnapshot, 0, newSize),
@@ -210,7 +221,7 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
		segmentID := s.root.segment[i].id
		if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
			// this segment is going away, see if anything else was deleted since we started the merge
			if s.root.segment[i].deleted != nil {
			if segSnapAtMerge != nil && s.root.segment[i].deleted != nil {
				// assume all these deletes are new
				deletedSince := s.root.segment[i].deleted
				// if we already knew about some of them, remove
@@ -224,7 +235,13 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
					newSegmentDeleted.Add(uint32(newDocNum))
				}
			}
		} else {
			// remove this segment from the old segment map; whatever
			// segments are left behind in the old map after processing
			// the root segments form the obsolete segment set
			delete(nextMerge.old, segmentID)
		} else if s.root.segment[i].LiveSize() > 0 {
			// this segment is staying
			newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
				id: s.root.segment[i].id,
@@ -238,14 +255,35 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
		}
	}

	// put new segment at end
	newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
		id:         nextMerge.id,
		segment:    nextMerge.new, // take ownership for nextMerge.new's ref-count
		deleted:    newSegmentDeleted,
		cachedDocs: &cachedDocs{cache: nil},
	})
	newSnapshot.offsets = append(newSnapshot.offsets, running)
	// before the newMerge introduction, we need to reconcile the newly
	// merged segment with the current root segments, so apply the
	// obsoleted doc numbers of those segments to the newly merged segment
	for segID, ss := range nextMerge.old {
		obsoleted := ss.DocNumbersLive()
		if obsoleted != nil {
			obsoletedIter := obsoleted.Iterator()
			for obsoletedIter.HasNext() {
				oldDocNum := obsoletedIter.Next()
				newDocNum := nextMerge.oldNewDocNums[segID][oldDocNum]
				newSegmentDeleted.Add(uint32(newDocNum))
			}
		}
	}

	// in case all the docs in the newly merged segment got deleted
	// by the time we reach here, we can skip the introduction
	if nextMerge.new != nil &&
		nextMerge.new.Count() > newSegmentDeleted.GetCardinality() {
		// put new segment at end
		newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
			id:         nextMerge.id,
			segment:    nextMerge.new, // take ownership for nextMerge.new's ref-count
			deleted:    newSegmentDeleted,
			cachedDocs: &cachedDocs{cache: nil},
		})
		newSnapshot.offsets = append(newSnapshot.offsets, running)
	}

	newSnapshot.AddRef() // 1 ref for the nextMerge.notify response

	// swap in new segment
	rootPrev := s.root
@@ -257,7 +295,8 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
		_ = rootPrev.DecRef()
	}

	// notify merger we incorporated this
	// notify requester that we incorporated this
	nextMerge.notify <- newSnapshot
	close(nextMerge.notify)
}
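A small runnable sketch of the doc-number remapping applied above: doc numbers that are still live in an obsoleted old segment are translated through an oldNewDocNums table into deletions on the merged segment (all numbers here are made up):

package main

import (
	"fmt"

	"github.com/RoaringBitmap/roaring"
)

func main() {
	// doc 7 of the old segment is live but became obsolete while
	// the merge was running
	obsoleted := roaring.BitmapOf(7)

	// where the merge placed each of the old segment's docs in the
	// new segment (index = old doc number)
	oldNewDocNums := []uint64{40, 41, 42, 43, 44, 45, 46, 47}

	newSegmentDeleted := roaring.NewBitmap()
	it := obsoleted.Iterator()
	for it.HasNext() {
		oldDocNum := it.Next()
		newSegmentDeleted.Add(uint32(oldNewDocNums[oldDocNum]))
	}
	fmt.Println(newSegmentDeleted.Contains(47)) // true: old doc 7 -> new doc 47
}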
@@ -15,6 +15,9 @@
package scorch

import (
	"bytes"
	"encoding/json"
	"fmt"
	"os"
	"sync/atomic"
@@ -28,6 +31,13 @@ import (
func (s *Scorch) mergerLoop() {
	var lastEpochMergePlanned uint64

	mergePlannerOptions, err := s.parseMergePlannerOptions()
	if err != nil {
		s.fireAsyncError(fmt.Errorf("mergePlannerOption json parsing err: %v", err))
		s.asyncTasks.Done()
		return
	}
OUTER:
	for {
		select {
@@ -45,7 +55,7 @@ OUTER:
			startTime := time.Now()

			// lets get started
			err := s.planMergeAtSnapshot(ourSnapshot)
			err := s.planMergeAtSnapshot(ourSnapshot, mergePlannerOptions)
			if err != nil {
				s.fireAsyncError(fmt.Errorf("merging err: %v", err))
				_ = ourSnapshot.DecRef()
@@ -58,51 +68,49 @@ OUTER:
			_ = ourSnapshot.DecRef()

			// tell the persister we're waiting for changes
			// first make a notification chan
			notifyUs := make(notificationChan)
			// first make an epochWatcher chan
			ew := &epochWatcher{
				epoch:    lastEpochMergePlanned,
				notifyCh: make(notificationChan, 1),
			}

			// give it to the persister
			select {
			case <-s.closeCh:
				break OUTER
			case s.persisterNotifier <- notifyUs:
			}

			// check again
			s.rootLock.RLock()
			ourSnapshot = s.root
			ourSnapshot.AddRef()
			s.rootLock.RUnlock()

			if ourSnapshot.epoch != lastEpochMergePlanned {
				startTime := time.Now()

				// lets get started
				err := s.planMergeAtSnapshot(ourSnapshot)
				if err != nil {
					s.fireAsyncError(fmt.Errorf("merging err: %v", err))
					_ = ourSnapshot.DecRef()
					continue OUTER
				}
				lastEpochMergePlanned = ourSnapshot.epoch

				s.fireEvent(EventKindMergerProgress, time.Since(startTime))
			case s.persisterNotifier <- ew:
			}
			_ = ourSnapshot.DecRef()

			// now wait for it (but also detect close)
			// now wait for persister (but also detect close)
			select {
			case <-s.closeCh:
				break OUTER
			case <-notifyUs:
				// woken up, next loop should pick up work
			case <-ew.notifyCh:
			}
		}
	}
	s.asyncTasks.Done()
}
func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
func (s *Scorch) parseMergePlannerOptions() (*mergeplan.MergePlanOptions,
	error) {
	mergePlannerOptions := mergeplan.DefaultMergePlanOptions
	if v, ok := s.config["scorchMergePlanOptions"]; ok {
		b, err := json.Marshal(v)
		if err != nil {
			return &mergePlannerOptions, err
		}

		err = json.Unmarshal(b, &mergePlannerOptions)
		if err != nil {
			return &mergePlannerOptions, err
		}
	}
	return &mergePlannerOptions, nil
}
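A hedged sketch of how a caller might feed scorchMergePlanOptions through the config map. mergePlanOptions here is a stand-in struct mirroring only a couple of plausible mergeplan.MergePlanOptions fields, so treat the field names as assumptions; the JSON round-trip is the same one parseMergePlannerOptions performs:

package main

import (
	"encoding/json"
	"fmt"
)

// stand-in for mergeplan.MergePlanOptions; the real field set lives
// in the mergeplan package
type mergePlanOptions struct {
	MaxSegmentsPerTier   int
	SegmentsPerMergeTask int
}

func main() {
	config := map[string]interface{}{
		"scorchMergePlanOptions": map[string]interface{}{
			"MaxSegmentsPerTier":   8,
			"SegmentsPerMergeTask": 8,
		},
	}

	// stand-in defaults, then the marshal/unmarshal round-trip
	opts := mergePlanOptions{MaxSegmentsPerTier: 10, SegmentsPerMergeTask: 10}
	if v, ok := config["scorchMergePlanOptions"]; ok {
		b, _ := json.Marshal(v)      // config value -> JSON
		_ = json.Unmarshal(b, &opts) // JSON -> options struct
	}
	fmt.Printf("%+v\n", opts) // {MaxSegmentsPerTier:8 SegmentsPerMergeTask:8}
}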
func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot,
	options *mergeplan.MergePlanOptions) error {
	// build list of zap segments in this snapshot
	var onlyZapSnapshots []mergeplan.Segment
	for _, segmentSnapshot := range ourSnapshot.segment {
@@ -112,7 +120,7 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
	}

	// give this list to the planner
	resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, nil)
	resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, options)
	if err != nil {
		return fmt.Errorf("merge planning err: %v", err)
	}
@@ -122,8 +130,12 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
	}

	// process tasks in serial for now
	var notifications []notificationChan
	var notifications []chan *IndexSnapshot
	for _, task := range resultMergePlan.Tasks {
		if len(task.Segments) == 0 {
			continue
		}

		oldMap := make(map[uint64]*SegmentSnapshot)
		newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
		segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))
@@ -132,40 +144,51 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
			if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
				oldMap[segSnapshot.id] = segSnapshot
				if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
					segmentsToMerge = append(segmentsToMerge, zapSeg)
					docsToDrop = append(docsToDrop, segSnapshot.deleted)
					if segSnapshot.LiveSize() == 0 {
						oldMap[segSnapshot.id] = nil
					} else {
						segmentsToMerge = append(segmentsToMerge, zapSeg)
						docsToDrop = append(docsToDrop, segSnapshot.deleted)
					}
				}
			}
		}

		filename := zapFileName(newSegmentID)
		s.markIneligibleForRemoval(filename)
		path := s.path + string(os.PathSeparator) + filename
		newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, DefaultChunkFactor)
		if err != nil {
			s.unmarkIneligibleForRemoval(filename)
			return fmt.Errorf("merging failed: %v", err)
		}
		segment, err := zap.Open(path)
		if err != nil {
			s.unmarkIneligibleForRemoval(filename)
			return err
		var oldNewDocNums map[uint64][]uint64
		var segment segment.Segment
		if len(segmentsToMerge) > 0 {
			filename := zapFileName(newSegmentID)
			s.markIneligibleForRemoval(filename)
			path := s.path + string(os.PathSeparator) + filename
			newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, 1024)
			if err != nil {
				s.unmarkIneligibleForRemoval(filename)
				return fmt.Errorf("merging failed: %v", err)
			}
			segment, err = zap.Open(path)
			if err != nil {
				s.unmarkIneligibleForRemoval(filename)
				return err
			}
			oldNewDocNums = make(map[uint64][]uint64)
			for i, segNewDocNums := range newDocNums {
				oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
			}
		}

		sm := &segmentMerge{
			id:            newSegmentID,
			old:           oldMap,
			oldNewDocNums: make(map[uint64][]uint64),
			oldNewDocNums: oldNewDocNums,
			new:           segment,
			notify:        make(notificationChan),
			notify:        make(chan *IndexSnapshot, 1),
		}
		notifications = append(notifications, sm.notify)
		for i, segNewDocNums := range newDocNums {
			sm.oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
		}

		// give it to the introducer
		select {
		case <-s.closeCh:
			_ = segment.Close()
			return nil
		case s.merges <- sm:
		}
@@ -174,7 +197,10 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
		select {
		case <-s.closeCh:
			return nil
		case <-notification:
		case newSnapshot := <-notification:
			if newSnapshot != nil {
				_ = newSnapshot.DecRef()
			}
		}
	}
	return nil
@@ -185,5 +211,72 @@ type segmentMerge struct {
	old           map[uint64]*SegmentSnapshot
	oldNewDocNums map[uint64][]uint64
	new           segment.Segment
	notify        notificationChan
	notify        chan *IndexSnapshot
}
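The notify field's new contract — a buffered chan *IndexSnapshot that the introducer sends on and then closes — can be sketched with stand-in types like this; a nil receive means the channel was closed without an introduction:

package main

import "fmt"

// stand-ins for the real scorch types, not bleve's API
type indexSnapshot struct{ epoch uint64 }

type segmentMerge struct {
	notify chan *indexSnapshot
}

func main() {
	merges := make(chan *segmentMerge)

	// introducer side (sketch): receive a merge, reply with the new
	// snapshot on the buffered notify channel, then close it
	go func() {
		sm := <-merges
		sm.notify <- &indexSnapshot{epoch: 42}
		close(sm.notify)
	}()

	// merger side (sketch): buffer of 1 so the introducer never
	// blocks even if the merger has already gone away
	sm := &segmentMerge{notify: make(chan *indexSnapshot, 1)}
	merges <- sm
	if snap := <-sm.notify; snap != nil {
		fmt.Println("introduced at epoch", snap.epoch) // the real code DecRefs here
	}
}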
// perform a merging of the given SegmentBase instances into a new,
// persisted segment, and synchronously introduce that new segment
// into the root
func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot,
	sbs []*zap.SegmentBase, sbsDrops []*roaring.Bitmap, sbsIndexes []int,
	chunkFactor uint32) (uint64, *IndexSnapshot, uint64, error) {
	var br bytes.Buffer

	cr := zap.NewCountHashWriter(&br)

	newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset,
		docValueOffset, dictLocs, fieldsInv, fieldsMap, err :=
		zap.MergeToWriter(sbs, sbsDrops, chunkFactor, cr)
	if err != nil {
		return 0, nil, 0, err
	}

	sb, err := zap.InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
		fieldsMap, fieldsInv, numDocs, storedIndexOffset, fieldsIndexOffset,
		docValueOffset, dictLocs)
	if err != nil {
		return 0, nil, 0, err
	}

	newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
	filename := zapFileName(newSegmentID)
	path := s.path + string(os.PathSeparator) + filename
	err = zap.PersistSegmentBase(sb, path)
	if err != nil {
		return 0, nil, 0, err
	}

	segment, err := zap.Open(path)
	if err != nil {
		return 0, nil, 0, err
	}

	sm := &segmentMerge{
		id:            newSegmentID,
		old:           make(map[uint64]*SegmentSnapshot),
		oldNewDocNums: make(map[uint64][]uint64),
		new:           segment,
		notify:        make(chan *IndexSnapshot, 1),
	}

	for i, idx := range sbsIndexes {
		ss := snapshot.segment[idx]
		sm.old[ss.id] = ss
		sm.oldNewDocNums[ss.id] = newDocNums[i]
	}

	select { // send to introducer
	case <-s.closeCh:
		_ = segment.DecRef()
		return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
	case s.merges <- sm:
	}

	select { // wait for introduction to complete
	case <-s.closeCh:
		return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
	case newSnapshot := <-sm.notify:
		return numDocs, newSnapshot, newSegmentID, nil
	}
}
@@ -186,13 +186,13 @@ func plan(segmentsIn []Segment, o *MergePlanOptions) (*MergePlan, error) {
	// While we're over budget, keep looping, which might produce
	// another MergeTask.
	for len(eligibles) > budgetNumSegments {
	for len(eligibles) > 0 && (len(eligibles)+len(rv.Tasks)) > budgetNumSegments {
		// Track a current best roster as we examine and score
		// potential rosters of merges.
		var bestRoster []Segment
		var bestRosterScore float64 // Lower score is better.

		for startIdx := 0; startIdx < len(eligibles)-o.SegmentsPerMergeTask; startIdx++ {
		for startIdx := 0; startIdx < len(eligibles); startIdx++ {
			var roster []Segment
			var rosterLiveSize int64
@@ -34,22 +34,39 @@ import (
var DefaultChunkFactor uint32 = 1024

// Arbitrary number, need to make it configurable.
// Lower values (like 10) make the persister really slow and don't
// work well, since they create more files to persist in the next
// persist iteration and spike the number of FDs.
// The ideal value should let the persister proceed at an optimum
// pace so that the merger can skip many intermediate snapshots.
// This needs to be based on empirical data.
// TODO - may need to revisit this approach/value.
var epochDistance = uint64(5)
type notificationChan chan struct{}

func (s *Scorch) persisterLoop() {
	defer s.asyncTasks.Done()

	var notifyChs []notificationChan
	var lastPersistedEpoch uint64
	var persistWatchers []*epochWatcher
	var lastPersistedEpoch, lastMergedEpoch uint64
	var ew *epochWatcher
OUTER:
	for {
		select {
		case <-s.closeCh:
			break OUTER
		case notifyCh := <-s.persisterNotifier:
			notifyChs = append(notifyChs, notifyCh)
		case ew = <-s.persisterNotifier:
			persistWatchers = append(persistWatchers, ew)
		default:
		}

		if ew != nil && ew.epoch > lastMergedEpoch {
			lastMergedEpoch = ew.epoch
		}

		persistWatchers = s.pausePersisterForMergerCatchUp(lastPersistedEpoch,
			&lastMergedEpoch, persistWatchers)

		var ourSnapshot *IndexSnapshot
		var ourPersisted []chan error
@@ -81,10 +98,11 @@ OUTER:
			}
			lastPersistedEpoch = ourSnapshot.epoch
			for _, notifyCh := range notifyChs {
				close(notifyCh)
			for _, ew := range persistWatchers {
				close(ew.notifyCh)
			}
			notifyChs = nil
			persistWatchers = nil

		_ = ourSnapshot.DecRef()

		changed := false
@@ -120,27 +138,155 @@ OUTER:
			break OUTER
		case <-w.notifyCh:
			// woken up, next loop should pick up work
			continue OUTER
		case ew = <-s.persisterNotifier:
			// if the watchers are already caught up then let them wait,
			// else let them continue to do the catch up
			persistWatchers = append(persistWatchers, ew)
		}
	}
}
func notifyMergeWatchers(lastPersistedEpoch uint64,
	persistWatchers []*epochWatcher) []*epochWatcher {
	var watchersNext []*epochWatcher
	for _, w := range persistWatchers {
		if w.epoch < lastPersistedEpoch {
			close(w.notifyCh)
		} else {
			watchersNext = append(watchersNext, w)
		}
	}
	return watchersNext
}

func (s *Scorch) pausePersisterForMergerCatchUp(lastPersistedEpoch uint64, lastMergedEpoch *uint64,
	persistWatchers []*epochWatcher) []*epochWatcher {

	// first, let the watchers proceed if they lag behind
	persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)

OUTER:
	// check for a slow merger and wait until the merger catches up
	for lastPersistedEpoch > *lastMergedEpoch+epochDistance {
		select {
		case <-s.closeCh:
			break OUTER
		case ew := <-s.persisterNotifier:
			persistWatchers = append(persistWatchers, ew)
			*lastMergedEpoch = ew.epoch
		}

		// let the watchers proceed if they lag behind
		persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
	}

	return persistWatchers
}
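A toy, self-contained illustration of the watcher rule used by notifyMergeWatchers above: any watcher whose epoch already lags the last persisted epoch is woken by closing its channel, and the rest keep waiting. With epochDistance = 5, a persister at epoch 12 would additionally pause in pausePersisterForMergerCatchUp until the merger reports at least epoch 7.

package main

import "fmt"

type epochWatcher struct {
	epoch    uint64
	notifyCh chan struct{}
}

// same rule as notifyMergeWatchers: wake (close) every watcher whose
// epoch is already behind what the persister has persisted
func notify(lastPersisted uint64, ws []*epochWatcher) []*epochWatcher {
	var next []*epochWatcher
	for _, w := range ws {
		if w.epoch < lastPersisted {
			close(w.notifyCh)
		} else {
			next = append(next, w)
		}
	}
	return next
}

func main() {
	ws := []*epochWatcher{
		{epoch: 3, notifyCh: make(chan struct{})},
		{epoch: 9, notifyCh: make(chan struct{})},
	}
	ws = notify(7, ws) // the epoch-3 watcher is woken; the epoch-9 one keeps waiting
	fmt.Println("still waiting:", len(ws)) // still waiting: 1
}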
func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
	// start a write transaction
	tx, err := s.rootBolt.Begin(true)
	persisted, err := s.persistSnapshotMaybeMerge(snapshot)
	if err != nil {
		return err
	}
	// defer fsync of the rootbolt
	defer func() {
		if err == nil {
			err = s.rootBolt.Sync()
	if persisted {
		return nil
	}

	return s.persistSnapshotDirect(snapshot)
}
// DefaultMinSegmentsForInMemoryMerge represents the default number of
// in-memory zap segments that persistSnapshotMaybeMerge() needs to
// see in an IndexSnapshot before it decides to merge and persist
// those segments
var DefaultMinSegmentsForInMemoryMerge = 2

// persistSnapshotMaybeMerge examines the snapshot and might merge and
// persist the in-memory zap segments if there are enough of them
func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) (
	bool, error) {
	// collect the in-memory zap segments (SegmentBase instances)
	var sbs []*zap.SegmentBase
	var sbsDrops []*roaring.Bitmap
	var sbsIndexes []int

	for i, segmentSnapshot := range snapshot.segment {
		if sb, ok := segmentSnapshot.segment.(*zap.SegmentBase); ok {
			sbs = append(sbs, sb)
			sbsDrops = append(sbsDrops, segmentSnapshot.deleted)
			sbsIndexes = append(sbsIndexes, i)
		}
	}

	if len(sbs) < DefaultMinSegmentsForInMemoryMerge {
		return false, nil
	}

	_, newSnapshot, newSegmentID, err := s.mergeSegmentBases(
		snapshot, sbs, sbsDrops, sbsIndexes, DefaultChunkFactor)
	if err != nil {
		return false, err
	}
	if newSnapshot == nil {
		return false, nil
	}

	defer func() {
		_ = newSnapshot.DecRef()
	}()

	// defer commit/rollback transaction
	mergedSegmentIDs := map[uint64]struct{}{}
	for _, idx := range sbsIndexes {
		mergedSegmentIDs[snapshot.segment[idx].id] = struct{}{}
	}

	// construct a snapshot that's logically equivalent to the input
	// snapshot, but with merged segments replaced by the new segment
	equiv := &IndexSnapshot{
		parent:   snapshot.parent,
		segment:  make([]*SegmentSnapshot, 0, len(snapshot.segment)),
		internal: snapshot.internal,
		epoch:    snapshot.epoch,
	}

	// copy to the equiv the segments that weren't replaced
	for _, segment := range snapshot.segment {
		if _, wasMerged := mergedSegmentIDs[segment.id]; !wasMerged {
			equiv.segment = append(equiv.segment, segment)
		}
	}

	// append to the equiv the new segment
	for _, segment := range newSnapshot.segment {
		if segment.id == newSegmentID {
			equiv.segment = append(equiv.segment, &SegmentSnapshot{
				id:      newSegmentID,
				segment: segment.segment,
				deleted: nil, // nil since merging handled deletions
			})
			break
		}
	}

	err = s.persistSnapshotDirect(equiv)
	if err != nil {
		return false, err
	}

	return true, nil
}
func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) {
	// start a write transaction
	tx, err := s.rootBolt.Begin(true)
	if err != nil {
		return err
	}
	// defer rollback on error
	defer func() {
		if err == nil {
			err = tx.Commit()
		} else {
		if err != nil {
			_ = tx.Rollback()
		}
	}()
@@ -172,20 +318,20 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
	newSegmentPaths := make(map[uint64]string)

	// first ensure that each segment in this snapshot has been persisted
	for i, segmentSnapshot := range snapshot.segment {
		snapshotSegmentKey := segment.EncodeUvarintAscending(nil, uint64(i))
		snapshotSegmentBucket, err2 := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
		if err2 != nil {
			return err2
	for _, segmentSnapshot := range snapshot.segment {
		snapshotSegmentKey := segment.EncodeUvarintAscending(nil, segmentSnapshot.id)
		snapshotSegmentBucket, err := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
		if err != nil {
			return err
		}
		switch seg := segmentSnapshot.segment.(type) {
		case *zap.SegmentBase:
			// need to persist this to disk
			filename := zapFileName(segmentSnapshot.id)
			path := s.path + string(os.PathSeparator) + filename
			err2 := zap.PersistSegmentBase(seg, path)
			if err2 != nil {
				return fmt.Errorf("error persisting segment: %v", err2)
			err = zap.PersistSegmentBase(seg, path)
			if err != nil {
				return fmt.Errorf("error persisting segment: %v", err)
			}
			newSegmentPaths[segmentSnapshot.id] = path
			err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
@@ -218,19 +364,28 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
		}
	}

	// only alter the root if we actually persisted a segment
	// (sometimes its just a new snapshot, possibly with new internal values)
	// we need to swap in a new root only when we've persisted 1 or
	// more segments -- whereby the new root would have 1-for-1
	// replacements of in-memory segments with file-based segments
	//
	// other cases like updates to internal values only, and/or when
	// there are only deletions, are already covered and persisted by
	// the newly populated boltdb snapshotBucket above
	if len(newSegmentPaths) > 0 {
		// now try to open all the new snapshots
		newSegments := make(map[uint64]segment.Segment)
		defer func() {
			for _, s := range newSegments {
				if s != nil {
					// cleanup segments that were opened but not
					// swapped into the new root
					_ = s.Close()
				}
			}
		}()
		for segmentID, path := range newSegmentPaths {
			newSegments[segmentID], err = zap.Open(path)
			if err != nil {
				for _, s := range newSegments {
					if s != nil {
						_ = s.Close() // cleanup segments that were successfully opened
					}
				}
				return fmt.Errorf("error opening new segment at %s, %v", path, err)
			}
		}
@@ -255,6 +410,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
				cachedDocs: segmentSnapshot.cachedDocs,
			}
			newIndexSnapshot.segment[i] = newSegmentSnapshot
			delete(newSegments, segmentSnapshot.id)
			// update items persisted in case of a new segment snapshot
			atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
		} else {
@@ -266,9 +422,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
		for k, v := range s.root.internal {
			newIndexSnapshot.internal[k] = v
		}
		for _, filename := range filenames {
			delete(s.ineligibleForRemoval, filename)
		}
		rootPrev := s.root
		s.root = newIndexSnapshot
		s.rootLock.Unlock()
@@ -277,6 +431,24 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
		}
	}

	err = tx.Commit()
	if err != nil {
		return err
	}

	err = s.rootBolt.Sync()
	if err != nil {
		return err
	}

	// allow files to become eligible for removal after commit, such
	// as file segments from snapshots that came from the merger
	s.rootLock.Lock()
	for _, filename := range filenames {
		delete(s.ineligibleForRemoval, filename)
	}
	s.rootLock.Unlock()

	return nil
}
@@ -61,7 +61,7 @@ type Scorch struct {
	merges             chan *segmentMerge
	introducerNotifier chan *epochWatcher
	revertToSnapshots  chan *snapshotReversion
	persisterNotifier  chan notificationChan
	persisterNotifier  chan *epochWatcher
	rootBolt           *bolt.DB
	asyncTasks         sync.WaitGroup
@@ -114,6 +114,25 @@ func (s *Scorch) fireAsyncError(err error) {
}

func (s *Scorch) Open() error {
	err := s.openBolt()
	if err != nil {
		return err
	}

	s.asyncTasks.Add(1)
	go s.mainLoop()

	if !s.readOnly && s.path != "" {
		s.asyncTasks.Add(1)
		go s.persisterLoop()
		s.asyncTasks.Add(1)
		go s.mergerLoop()
	}

	return nil
}

func (s *Scorch) openBolt() error {
	var ok bool
	s.path, ok = s.config["path"].(string)
	if !ok {
@@ -136,6 +155,7 @@ func (s *Scorch) Open() error {
			}
		}
	}

	rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
	var err error
	if s.path != "" {
@@ -156,7 +176,7 @@ func (s *Scorch) Open() error {
	s.merges = make(chan *segmentMerge)
	s.introducerNotifier = make(chan *epochWatcher, 1)
	s.revertToSnapshots = make(chan *snapshotReversion)
	s.persisterNotifier = make(chan notificationChan)
	s.persisterNotifier = make(chan *epochWatcher, 1)

	if !s.readOnly && s.path != "" {
		err := s.removeOldZapFiles() // Before persister or merger create any new files.
@@ -166,16 +186,6 @@ func (s *Scorch) Open() error {
		}
	}

	s.asyncTasks.Add(1)
	go s.mainLoop()

	if !s.readOnly && s.path != "" {
		s.asyncTasks.Add(1)
		go s.persisterLoop()
		s.asyncTasks.Add(1)
		go s.mergerLoop()
	}

	return nil
}
@@ -310,17 +320,21 @@ func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
		introduction.persisted = make(chan error, 1)
	}

	// get read lock, to optimistically prepare obsoleted info
	// optimistically prepare obsoletes outside of rootLock
	s.rootLock.RLock()
	for _, seg := range s.root.segment {
	root := s.root
	root.AddRef()
	s.rootLock.RUnlock()

	for _, seg := range root.segment {
		delta, err := seg.segment.DocNumbers(ids)
		if err != nil {
			s.rootLock.RUnlock()
			return err
		}
		introduction.obsoletes[seg.id] = delta
	}
	s.rootLock.RUnlock()

	_ = root.DecRef()

	s.introductions <- introduction
@@ -95,6 +95,21 @@ func (s *Segment) initializeDict(results []*index.AnalysisResult) {
	var numTokenFrequencies int
	var totLocs int

	// initial scan for all fieldID's to sort them
	for _, result := range results {
		for _, field := range result.Document.CompositeFields {
			s.getOrDefineField(field.Name())
		}
		for _, field := range result.Document.Fields {
			s.getOrDefineField(field.Name())
		}
	}
	sort.Strings(s.FieldsInv[1:]) // keep _id as first field
	s.FieldsMap = make(map[string]uint16, len(s.FieldsInv))
	for fieldID, fieldName := range s.FieldsInv {
		s.FieldsMap[fieldName] = uint16(fieldID + 1)
	}

	processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
		for term, tf := range tfs {
			pidPlus1, exists := s.Dicts[fieldID][term]
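The sort.Strings(s.FieldsInv[1:]) trick above sorts every field name except the first, which this tiny runnable example makes explicit:

package main

import (
	"fmt"
	"sort"
)

func main() {
	fieldsInv := []string{"_id", "title", "body", "author"}
	sort.Strings(fieldsInv[1:]) // keep _id as the first field
	fmt.Println(fieldsInv)      // [_id author body title]
}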
@@ -76,6 +76,8 @@ type DictionaryIterator struct {
	prefix string
	end    string
	offset int

	dictEntry index.DictEntry // reused across Next()'s
}

// Next returns the next entry in the dictionary
@@ -95,8 +97,7 @@ func (d *DictionaryIterator) Next() (*index.DictEntry, error) {
	d.offset++
	postingID := d.d.segment.Dicts[d.d.fieldID][next]
	return &index.DictEntry{
		Term:  next,
		Count: d.d.segment.Postings[postingID-1].GetCardinality(),
	}, nil
	d.dictEntry.Term = next
	d.dictEntry.Count = d.d.segment.Postings[postingID-1].GetCardinality()
	return &d.dictEntry, nil
}
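The reuse of d.dictEntry across Next() calls trades an allocation per entry for a sharper contract: the returned entry is only valid until the next call. A stand-in sketch of the same pattern:

package main

import "fmt"

type dictEntry struct {
	Term  string
	Count uint64
}

// iterator that reuses one entry across Next() calls, like the
// DictionaryIterator above; callers must copy the entry if they
// need it to survive the next call
type dictIter struct {
	terms []string
	off   int
	entry dictEntry // reused across Next()'s
}

func (it *dictIter) Next() *dictEntry {
	if it.off >= len(it.terms) {
		return nil
	}
	it.entry.Term = it.terms[it.off]
	it.entry.Count = uint64(it.off + 1)
	it.off++
	return &it.entry
}

func main() {
	it := &dictIter{terms: []string{"apple", "banana"}}
	for e := it.Next(); e != nil; e = it.Next() {
		fmt.Println(e.Term, e.Count)
	}
}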
@@ -28,7 +28,7 @@ import (
	"github.com/golang/snappy"
)

const version uint32 = 2
const version uint32 = 3

const fieldNotUninverted = math.MaxUint64
@@ -187,79 +187,42 @@ func persistBase(memSegment *mem.Segment, cr *CountHashWriter, chunkFactor uint3
}

func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) {
	var curr int
	var metaBuf bytes.Buffer
	var data, compressed []byte

	metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)

	docNumOffsets := make(map[int]uint64, len(memSegment.Stored))

	for docNum, storedValues := range memSegment.Stored {
		if docNum != 0 {
			// reset buffer if necessary
			curr = 0
			metaBuf.Reset()
			data = data[:0]
			compressed = compressed[:0]
			curr = 0
		}

		metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)

		st := memSegment.StoredTypes[docNum]
		sp := memSegment.StoredPos[docNum]

		// encode fields in order
		for fieldID := range memSegment.FieldsInv {
			if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok {
				// has stored values for this field
				num := len(storedFieldValues)

				stf := st[uint16(fieldID)]
				spf := sp[uint16(fieldID)]

				// process each value
				for i := 0; i < num; i++ {
					// encode field
					_, err2 := metaEncoder.PutU64(uint64(fieldID))
					if err2 != nil {
						return 0, err2
					}
					// encode type
					_, err2 = metaEncoder.PutU64(uint64(stf[i]))
					if err2 != nil {
						return 0, err2
					}
					// encode start offset
					_, err2 = metaEncoder.PutU64(uint64(curr))
					if err2 != nil {
						return 0, err2
					}
					// end len
					_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
					if err2 != nil {
						return 0, err2
					}
					// encode number of array pos
					_, err2 = metaEncoder.PutU64(uint64(len(spf[i])))
					if err2 != nil {
						return 0, err2
					}
					// encode all array positions
					for _, pos := range spf[i] {
						_, err2 = metaEncoder.PutU64(pos)
						if err2 != nil {
							return 0, err2
						}
					}
					// append data
					data = append(data, storedFieldValues[i]...)
					// update curr
					curr += len(storedFieldValues[i])
				var err2 error
				curr, data, err2 = persistStoredFieldValues(fieldID,
					storedFieldValues, stf, spf, curr, metaEncoder, data)
				if err2 != nil {
					return 0, err2
				}
			}
		}

		metaEncoder.Close()
		metaEncoder.Close()

		metaBytes := metaBuf.Bytes()

		// compress the data
@@ -299,6 +262,51 @@ func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error)
	return rv, nil
}
func persistStoredFieldValues(fieldID int,
	storedFieldValues [][]byte, stf []byte, spf [][]uint64,
	curr int, metaEncoder *govarint.Base128Encoder, data []byte) (
	int, []byte, error) {
	for i := 0; i < len(storedFieldValues); i++ {
		// encode field
		_, err := metaEncoder.PutU64(uint64(fieldID))
		if err != nil {
			return 0, nil, err
		}
		// encode type
		_, err = metaEncoder.PutU64(uint64(stf[i]))
		if err != nil {
			return 0, nil, err
		}
		// encode start offset
		_, err = metaEncoder.PutU64(uint64(curr))
		if err != nil {
			return 0, nil, err
		}
		// end len
		_, err = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
		if err != nil {
			return 0, nil, err
		}
		// encode number of array pos
		_, err = metaEncoder.PutU64(uint64(len(spf[i])))
		if err != nil {
			return 0, nil, err
		}
		// encode all array positions
		for _, pos := range spf[i] {
			_, err = metaEncoder.PutU64(pos)
			if err != nil {
				return 0, nil, err
			}
		}

		data = append(data, storedFieldValues[i]...)
		curr += len(storedFieldValues[i])
	}

	return curr, data, nil
}
func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
	var freqOffsets, locOfffsets []uint64
	tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))
@@ -580,7 +588,7 @@ func persistDocValues(memSegment *mem.Segment, w *CountHashWriter,
		if err != nil {
			return nil, err
		}
		// resetting encoder for the next field
		// reseting encoder for the next field
		fdvEncoder.Reset()
	}
@@ -625,12 +633,21 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
		return nil, err
	}

	return InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
		memSegment.FieldsMap, memSegment.FieldsInv, numDocs,
		storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs)
}

func InitSegmentBase(mem []byte, memCRC uint32, chunkFactor uint32,
	fieldsMap map[string]uint16, fieldsInv []string, numDocs uint64,
	storedIndexOffset uint64, fieldsIndexOffset uint64, docValueOffset uint64,
	dictLocs []uint64) (*SegmentBase, error) {
	sb := &SegmentBase{
		mem:               br.Bytes(),
		memCRC:            cr.Sum32(),
		mem:               mem,
		memCRC:            memCRC,
		chunkFactor:       chunkFactor,
		fieldsMap:         memSegment.FieldsMap,
		fieldsInv:         memSegment.FieldsInv,
		fieldsMap:         fieldsMap,
		fieldsInv:         fieldsInv,
		numDocs:           numDocs,
		storedIndexOffset: storedIndexOffset,
		fieldsIndexOffset: fieldsIndexOffset,
@@ -639,7 +656,7 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
		fieldDvIterMap:    make(map[uint16]*docValueIterator),
	}

	err = sb.loadDvIterators()
	err := sb.loadDvIterators()
	if err != nil {
		return nil, err
	}
@@ -39,7 +39,7 @@ type chunkedContentCoder struct {
// MetaData represents the data information inside a
// chunk.
type MetaData struct {
	DocID    uint64 // docid of the data inside the chunk
	DocNum   uint64 // docNum of the data inside the chunk
	DocDvLoc uint64 // starting offset for a given docid
	DocDvLen uint64 // length of data inside the chunk for the given docid
}
@@ -52,7 +52,7 @@ func newChunkedContentCoder(chunkSize uint64,
	rv := &chunkedContentCoder{
		chunkSize: chunkSize,
		chunkLens: make([]uint64, total),
		chunkMeta: []MetaData{},
		chunkMeta: make([]MetaData, 0, total),
	}

	return rv
@@ -68,7 +68,7 @@ func (c *chunkedContentCoder) Reset() {
	for i := range c.chunkLens {
		c.chunkLens[i] = 0
	}
	c.chunkMeta = []MetaData{}
	c.chunkMeta = c.chunkMeta[:0]
}

// Close indicates you are done calling Add() this allows
@@ -88,7 +88,7 @@ func (c *chunkedContentCoder) flushContents() error {

	// write out the metaData slice
	for _, meta := range c.chunkMeta {
		_, err := writeUvarints(&c.chunkMetaBuf, meta.DocID, meta.DocDvLoc, meta.DocDvLen)
		_, err := writeUvarints(&c.chunkMetaBuf, meta.DocNum, meta.DocDvLoc, meta.DocDvLen)
		if err != nil {
			return err
		}
@@ -118,7 +118,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
		// clearing the chunk specific meta for next chunk
		c.chunkBuf.Reset()
		c.chunkMetaBuf.Reset()
		c.chunkMeta = []MetaData{}
		c.chunkMeta = c.chunkMeta[:0]
		c.currChunk = chunk
	}
@@ -130,7 +130,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
	}

	c.chunkMeta = append(c.chunkMeta, MetaData{
		DocID:    docNum,
		DocNum:   docNum,
		DocDvLoc: uint64(dvOffset),
		DocDvLen: uint64(dvSize),
	})
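The c.chunkMeta = c.chunkMeta[:0] change keeps the backing array across chunks instead of allocating a fresh slice each time; this toy shows the capacity being retained:

package main

import "fmt"

func main() {
	meta := make([]int, 0, 4)
	for chunk := 0; chunk < 3; chunk++ {
		meta = meta[:0] // reuse the backing array; no per-chunk allocation
		meta = append(meta, chunk, chunk+1)
		fmt.Println(meta, cap(meta)) // cap stays 4 across chunks
	}
}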
@@ -34,32 +34,47 @@ type Dictionary struct {

// PostingsList returns the postings list for the specified term
func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
	return d.postingsList([]byte(term), except)
	return d.postingsList([]byte(term), except, nil)
}

func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap) (*PostingsList, error) {
	rv := &PostingsList{
		sb:     d.sb,
		term:   term,
		except: except,
func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
	if d.fst == nil {
		return d.postingsListInit(rv, except), nil
	}

	if d.fst != nil {
		postingsOffset, exists, err := d.fst.Get(term)
		if err != nil {
			return nil, fmt.Errorf("vellum err: %v", err)
		}
		if exists {
			err = rv.read(postingsOffset, d)
			if err != nil {
				return nil, err
			}
		}
	postingsOffset, exists, err := d.fst.Get(term)
	if err != nil {
		return nil, fmt.Errorf("vellum err: %v", err)
	}
	if !exists {
		return d.postingsListInit(rv, except), nil
	}

	return d.postingsListFromOffset(postingsOffset, except, rv)
}

func (d *Dictionary) postingsListFromOffset(postingsOffset uint64, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
	rv = d.postingsListInit(rv, except)

	err := rv.read(postingsOffset, d)
	if err != nil {
		return nil, err
	}

	return rv, nil
}

func (d *Dictionary) postingsListInit(rv *PostingsList, except *roaring.Bitmap) *PostingsList {
	if rv == nil {
		rv = &PostingsList{}
	} else {
		*rv = PostingsList{} // clear the struct
	}
	rv.sb = d.sb
	rv.except = except
	return rv
}

// Iterator returns an iterator for this dictionary
func (d *Dictionary) Iterator() segment.DictionaryIterator {
	rv := &DictionaryIterator{
@@ -99,7 +99,7 @@ func (s *SegmentBase) loadFieldDocValueIterator(field string,
func (di *docValueIterator) loadDvChunk(chunkNumber,
	localDocNum uint64, s *SegmentBase) error {
	// advance to the chunk where the docValues
	// reside for the given docID
	// reside for the given docNum
	destChunkDataLoc := di.dvDataLoc
	for i := 0; i < int(chunkNumber); i++ {
		destChunkDataLoc += di.chunkLens[i]
@@ -116,7 +116,7 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
	offset := uint64(0)
	di.curChunkHeader = make([]MetaData, int(numDocs))
	for i := 0; i < int(numDocs); i++ {
		di.curChunkHeader[i].DocID, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
		di.curChunkHeader[i].DocNum, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
		offset += uint64(read)
		di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
		offset += uint64(read)
@@ -131,10 +131,10 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
	return nil
}

func (di *docValueIterator) visitDocValues(docID uint64,
func (di *docValueIterator) visitDocValues(docNum uint64,
	visitor index.DocumentFieldTermVisitor) error {
	// binary search the term locations for the docID
	start, length := di.getDocValueLocs(docID)
	// binary search the term locations for the docNum
	start, length := di.getDocValueLocs(docNum)
	if start == math.MaxUint64 || length == math.MaxUint64 {
		return nil
	}
@@ -144,7 +144,7 @@ func (di *docValueIterator) visitDocValues(docID uint64,
		return err
	}

	// pick the terms for the given docID
	// pick the terms for the given docNum
	uncompressed = uncompressed[start : start+length]
	for {
		i := bytes.Index(uncompressed, termSeparatorSplitSlice)
@@ -159,11 +159,11 @@ func (di *docValueIterator) visitDocValues(docID uint64,
	return nil
}

func (di *docValueIterator) getDocValueLocs(docID uint64) (uint64, uint64) {
func (di *docValueIterator) getDocValueLocs(docNum uint64) (uint64, uint64) {
	i := sort.Search(len(di.curChunkHeader), func(i int) bool {
		return di.curChunkHeader[i].DocID >= docID
		return di.curChunkHeader[i].DocNum >= docNum
	})
	if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocID == docID {
	if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocNum == docNum {
		return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
	}
	return math.MaxUint64, math.MaxUint64
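getDocValueLocs relies on sort.Search over the chunk header, which this small stand-in example demonstrates (metaData mirrors just the DocNum part of the header):

package main

import (
	"fmt"
	"sort"
)

type metaData struct {
	DocNum uint64
}

func main() {
	curChunkHeader := []metaData{{DocNum: 3}, {DocNum: 9}, {DocNum: 12}}
	docNum := uint64(9)
	// find the first entry whose DocNum is >= the target
	i := sort.Search(len(curChunkHeader), func(i int) bool {
		return curChunkHeader[i].DocNum >= docNum
	})
	found := i < len(curChunkHeader) && curChunkHeader[i].DocNum == docNum
	fmt.Println(i, found) // 1 true
}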
@@ -0,0 +1,124 @@
// Copyright (c) 2018 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package zap

import (
	"bytes"

	"github.com/couchbase/vellum"
)

// enumerator provides an ordered traversal of multiple vellum
// iterators.  Like JOIN of iterators, the enumerator produces a
// sequence of (key, iteratorIndex, value) tuples, sorted by key ASC,
// then iteratorIndex ASC, where the same key might be seen or
// repeated across multiple child iterators.
type enumerator struct {
	itrs   []vellum.Iterator
	currKs [][]byte
	currVs []uint64

	lowK    []byte
	lowIdxs []int
	lowCurr int
}

// newEnumerator returns a new enumerator over the vellum Iterators
func newEnumerator(itrs []vellum.Iterator) (*enumerator, error) {
	rv := &enumerator{
		itrs:    itrs,
		currKs:  make([][]byte, len(itrs)),
		currVs:  make([]uint64, len(itrs)),
		lowIdxs: make([]int, 0, len(itrs)),
	}
	for i, itr := range rv.itrs {
		rv.currKs[i], rv.currVs[i] = itr.Current()
	}
	rv.updateMatches()
	if rv.lowK == nil {
		return rv, vellum.ErrIteratorDone
	}
	return rv, nil
}

// updateMatches maintains the low key matches based on the currKs
func (m *enumerator) updateMatches() {
	m.lowK = nil
	m.lowIdxs = m.lowIdxs[:0]
	m.lowCurr = 0

	for i, key := range m.currKs {
		if key == nil {
			continue
		}

		cmp := bytes.Compare(key, m.lowK)
		if cmp < 0 || m.lowK == nil {
			// reached a new low
			m.lowK = key
			m.lowIdxs = m.lowIdxs[:0]
			m.lowIdxs = append(m.lowIdxs, i)
		} else if cmp == 0 {
			m.lowIdxs = append(m.lowIdxs, i)
		}
	}
}

// Current returns the enumerator's current key, iterator-index, and
// value.  If the enumerator is not pointing at a valid value (because
// Next returned an error previously), Current will return nil,0,0.
func (m *enumerator) Current() ([]byte, int, uint64) {
	var i int
	var v uint64
	if m.lowCurr < len(m.lowIdxs) {
		i = m.lowIdxs[m.lowCurr]
		v = m.currVs[i]
	}
	return m.lowK, i, v
}

// Next advances the enumerator to the next key/iterator/value result,
// else vellum.ErrIteratorDone is returned.
func (m *enumerator) Next() error {
	m.lowCurr += 1
	if m.lowCurr >= len(m.lowIdxs) {
		// move all the current low iterators forwards
		for _, vi := range m.lowIdxs {
			err := m.itrs[vi].Next()
			if err != nil && err != vellum.ErrIteratorDone {
				return err
			}
			m.currKs[vi], m.currVs[vi] = m.itrs[vi].Current()
		}
		m.updateMatches()
	}
	if m.lowK == nil {
		return vellum.ErrIteratorDone
	}
	return nil
}

// Close all the underlying Iterators.  The first error, if any, will
// be returned.
func (m *enumerator) Close() error {
	var rv error
	for _, itr := range m.itrs {
		err := itr.Close()
		if rv == nil {
			rv = err
		}
	}
	return rv
}
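As a usage sketch — it would have to live inside this zap package, since enumerator is unexported, and the vellum builder calls mirror how the merge code above obtains its FST iterators — two toy dictionaries {cat:1, dog:2} and {cat:3} enumerate as (cat,0,1), (cat,1,3), (dog,0,2): key ascending, then iterator index ascending.

package zap

import (
	"bytes"
	"fmt"

	"github.com/couchbase/vellum"
)

// buildTestFST builds an in-memory FST from pre-sorted keys
func buildTestFST(keys []string, vals []uint64) (*vellum.FST, error) {
	var buf bytes.Buffer
	b, err := vellum.New(&buf, nil)
	if err != nil {
		return nil, err
	}
	for i, k := range keys {
		if err = b.Insert([]byte(k), vals[i]); err != nil {
			return nil, err
		}
	}
	if err = b.Close(); err != nil {
		return nil, err
	}
	return vellum.Load(buf.Bytes())
}

func enumeratorExample() error {
	fstA, _ := buildTestFST([]string{"cat", "dog"}, []uint64{1, 2})
	fstB, _ := buildTestFST([]string{"cat"}, []uint64{3})
	itrA, _ := fstA.Iterator(nil, nil)
	itrB, _ := fstB.Iterator(nil, nil)

	e, err := newEnumerator([]vellum.Iterator{itrA, itrB})
	for err == nil {
		k, idx, v := e.Current()
		fmt.Printf("%s itr=%d val=%d\n", k, idx, v)
		err = e.Next() // vellum.ErrIteratorDone ends the loop
	}
	return e.Close()
}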
@@ -30,6 +30,8 @@ type chunkedIntCoder struct {
	encoder   *govarint.Base128Encoder
	chunkLens []uint64
	currChunk uint64

	buf []byte
}

// newChunkedIntCoder returns a new chunk int coder which packs data into
@@ -67,12 +69,8 @@ func (c *chunkedIntCoder) Add(docNum uint64, vals ...uint64) error {
		// starting a new chunk
		if c.encoder != nil {
			// close out last
			c.encoder.Close()
			encodingBytes := c.chunkBuf.Bytes()
			c.chunkLens[c.currChunk] = uint64(len(encodingBytes))
			c.final = append(c.final, encodingBytes...)
			c.Close()
			c.chunkBuf.Reset()
			c.encoder = govarint.NewU64Base128Encoder(&c.chunkBuf)
		}
		c.currChunk = chunk
	}
@@ -98,26 +96,25 @@ func (c *chunkedIntCoder) Close() {
// Write commits all the encoded chunked integers to the provided writer.
func (c *chunkedIntCoder) Write(w io.Writer) (int, error) {
	var tw int
	buf := make([]byte, binary.MaxVarintLen64)
	// write out the number of chunks
	bufNeeded := binary.MaxVarintLen64 * (1 + len(c.chunkLens))
	if len(c.buf) < bufNeeded {
		c.buf = make([]byte, bufNeeded)
	}
	buf := c.buf

	// write out the number of chunks & each chunkLen
	n := binary.PutUvarint(buf, uint64(len(c.chunkLens)))
	nw, err := w.Write(buf[:n])
	tw += nw
	for _, chunkLen := range c.chunkLens {
		n += binary.PutUvarint(buf[n:], uint64(chunkLen))
	}

	tw, err := w.Write(buf[:n])
	if err != nil {
		return tw, err
	}

	// write out the chunk lens
	for _, chunkLen := range c.chunkLens {
		n := binary.PutUvarint(buf, uint64(chunkLen))
		nw, err = w.Write(buf[:n])
		tw += nw
		if err != nil {
			return tw, err
		}
	}

	// write out the data
	nw, err = w.Write(c.final)
	nw, err := w.Write(c.final)
	tw += nw
	if err != nil {
		return tw, err
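The rewritten Write packs the chunk count and every chunk length into one buffer with successive binary.PutUvarint calls, then issues a single Write instead of one per value. Standalone, the idea looks like this:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	chunkLens := []uint64{7, 0, 13}

	// one buffer big enough for the count plus every length
	buf := make([]byte, binary.MaxVarintLen64*(1+len(chunkLens)))
	n := binary.PutUvarint(buf, uint64(len(chunkLens)))
	for _, l := range chunkLens {
		n += binary.PutUvarint(buf[n:], l)
	}

	var w bytes.Buffer
	_, _ = w.Write(buf[:n]) // single Write instead of one per value
	fmt.Println(w.Len(), "bytes written")
}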
@@ -21,6 +21,7 @@ import (
	"fmt"
	"math"
	"os"
	"sort"

	"github.com/RoaringBitmap/roaring"
	"github.com/Smerity/govarint"
@@ -28,6 +29,8 @@ import (
	"github.com/golang/snappy"
)

const docDropped = math.MaxUint64 // sentinel docNum to represent a deleted doc

// Merge takes a slice of zap segments and bit masks describing which
// documents may be dropped, and creates a new segment containing the
// remaining data.  This new segment is built at the specified path,
@@ -46,47 +49,26 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
		_ = os.Remove(path)
	}

	segmentBases := make([]*SegmentBase, len(segments))
	for segmenti, segment := range segments {
		segmentBases[segmenti] = &segment.SegmentBase
	}

	// buffer the output
	br := bufio.NewWriter(f)

	// wrap it for counting (tracking offsets)
	cr := NewCountHashWriter(br)

	fieldsInv := mergeFields(segments)
	fieldsMap := mapFields(fieldsInv)

	var newDocNums [][]uint64
	var storedIndexOffset uint64
	fieldDvLocsOffset := uint64(fieldNotUninverted)
	var dictLocs []uint64

	newSegDocCount := computeNewDocCount(segments, drops)
	if newSegDocCount > 0 {
		storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
			fieldsMap, fieldsInv, newSegDocCount, cr)
		if err != nil {
			cleanup()
			return nil, err
		}

		dictLocs, fieldDvLocsOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
			newDocNums, newSegDocCount, chunkFactor, cr)
		if err != nil {
			cleanup()
			return nil, err
		}
	} else {
		dictLocs = make([]uint64, len(fieldsInv))
	}

	fieldsIndexOffset, err := persistFields(fieldsInv, cr, dictLocs)
	newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, _, _, _, err :=
		MergeToWriter(segmentBases, drops, chunkFactor, cr)
	if err != nil {
		cleanup()
		return nil, err
	}

	err = persistFooter(newSegDocCount, storedIndexOffset,
		fieldsIndexOffset, fieldDvLocsOffset, chunkFactor, cr.Sum32(), cr)
	err = persistFooter(numDocs, storedIndexOffset, fieldsIndexOffset,
		docValueOffset, chunkFactor, cr.Sum32(), cr)
	if err != nil {
		cleanup()
		return nil, err
@@ -113,21 +95,59 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
	return newDocNums, nil
}

// mapFields takes the fieldsInv list and builds the map
func MergeToWriter(segments []*SegmentBase, drops []*roaring.Bitmap,
	chunkFactor uint32, cr *CountHashWriter) (
	newDocNums [][]uint64,
	numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset uint64,
	dictLocs []uint64, fieldsInv []string, fieldsMap map[string]uint16,
	err error) {
	docValueOffset = uint64(fieldNotUninverted)

	var fieldsSame bool
	fieldsSame, fieldsInv = mergeFields(segments)
	fieldsMap = mapFields(fieldsInv)

	numDocs = computeNewDocCount(segments, drops)
	if numDocs > 0 {
		storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
			fieldsMap, fieldsInv, fieldsSame, numDocs, cr)
		if err != nil {
			return nil, 0, 0, 0, 0, nil, nil, nil, err
		}

		dictLocs, docValueOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
			newDocNums, numDocs, chunkFactor, cr)
		if err != nil {
			return nil, 0, 0, 0, 0, nil, nil, nil, err
		}
	} else {
		dictLocs = make([]uint64, len(fieldsInv))
	}

	fieldsIndexOffset, err = persistFields(fieldsInv, cr, dictLocs)
	if err != nil {
		return nil, 0, 0, 0, 0, nil, nil, nil, err
	}

	return newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs, fieldsInv, fieldsMap, nil
}

// mapFields takes the fieldsInv list and returns a map of fieldName
// to fieldID+1
func mapFields(fields []string) map[string]uint16 {
	rv := make(map[string]uint16, len(fields))
	for i, fieldName := range fields {
		rv[fieldName] = uint16(i)
		rv[fieldName] = uint16(i) + 1
	}
	return rv
}
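The +1 in mapFields is presumably so that a map lookup's zero value can mean "unknown field" (an inference, not stated in this patch); a tiny runnable check:

package main

import "fmt"

func mapFields(fields []string) map[string]uint16 {
	rv := make(map[string]uint16, len(fields))
	for i, fieldName := range fields {
		rv[fieldName] = uint16(i) + 1 // leaves 0 free to mean "no such field"
	}
	return rv
}

func main() {
	m := mapFields([]string{"_id", "title"})
	fmt.Println(m)           // map[_id:1 title:2]
	fmt.Println(m["absent"]) // 0
}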
// computeNewDocCount determines how many documents will be in the newly | |||
// merged segment when obsoleted docs are dropped | |||
func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 { | |||
func computeNewDocCount(segments []*SegmentBase, drops []*roaring.Bitmap) uint64 { | |||
var newDocCount uint64 | |||
for segI, segment := range segments { | |||
newDocCount += segment.NumDocs() | |||
newDocCount += segment.numDocs | |||
if drops[segI] != nil { | |||
newDocCount -= drops[segI].GetCardinality() | |||
} | |||
@@ -135,8 +155,8 @@ func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 { | |||
return newDocCount | |||
} | |||
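
For instance, merging a 100-doc segment with 10 obsoleted docs and a 50-doc segment with a nil drop bitmap yields 140. The same arithmetic using the roaring API (the doc counts are hypothetical; the real code reads segment.numDocs):

drops0 := roaring.NewBitmap()
drops0.AddRange(0, 10) // docs 0..9 of the first segment are obsoleted
newDocCount := uint64(100) - drops0.GetCardinality() + uint64(50)
fmt.Println(newDocCount) // 140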
func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
fieldsInv []string, fieldsMap map[string]uint16, newDocNums [][]uint64, | |||
func persistMergedRest(segments []*SegmentBase, dropsIn []*roaring.Bitmap, | |||
fieldsInv []string, fieldsMap map[string]uint16, newDocNumsIn [][]uint64, | |||
newSegDocCount uint64, chunkFactor uint32, | |||
w *CountHashWriter) ([]uint64, uint64, error) { | |||
@@ -144,9 +164,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64) | |||
var bufLoc []uint64 | |||
var postings *PostingsList | |||
var postItr *PostingsIterator | |||
rv := make([]uint64, len(fieldsInv)) | |||
fieldDvLocs := make([]uint64, len(fieldsInv)) | |||
fieldDvLocsOffset := uint64(fieldNotUninverted) | |||
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||
// docTermMap is keyed by docNum, where the array impl provides | |||
// better memory usage behavior than a sparse-friendlier hashmap | |||
@@ -166,36 +191,31 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
return nil, 0, err | |||
} | |||
// collect FST iterators from all segments for this field | |||
// collect FST iterators from all active segments for this field | |||
var newDocNums [][]uint64 | |||
var drops []*roaring.Bitmap | |||
var dicts []*Dictionary | |||
var itrs []vellum.Iterator | |||
for _, segment := range segments { | |||
for segmentI, segment := range segments { | |||
dict, err2 := segment.dictionary(fieldName) | |||
if err2 != nil { | |||
return nil, 0, err2 | |||
} | |||
dicts = append(dicts, dict) | |||
if dict != nil && dict.fst != nil { | |||
itr, err2 := dict.fst.Iterator(nil, nil) | |||
if err2 != nil && err2 != vellum.ErrIteratorDone { | |||
return nil, 0, err2 | |||
} | |||
if itr != nil { | |||
newDocNums = append(newDocNums, newDocNumsIn[segmentI]) | |||
drops = append(drops, dropsIn[segmentI]) | |||
dicts = append(dicts, dict) | |||
itrs = append(itrs, itr) | |||
} | |||
} | |||
} | |||
// create merging iterator | |||
mergeItr, err := vellum.NewMergeIterator(itrs, func(postingOffsets []uint64) uint64 { | |||
// we don't actually use the merged value | |||
return 0 | |||
}) | |||
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1) | |||
if uint64(cap(docTermMap)) < newSegDocCount { | |||
docTermMap = make([][]byte, newSegDocCount) | |||
} else { | |||
@@ -205,70 +225,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
} | |||
} | |||
for err == nil { | |||
term, _ := mergeItr.Current() | |||
newRoaring := roaring.NewBitmap() | |||
newRoaringLocs := roaring.NewBitmap() | |||
tfEncoder.Reset() | |||
locEncoder.Reset() | |||
// now go back and get posting list for this term | |||
// but pass in the deleted docs for that segment | |||
for dictI, dict := range dicts { | |||
if dict == nil { | |||
continue | |||
} | |||
postings, err2 := dict.postingsList(term, drops[dictI]) | |||
if err2 != nil { | |||
return nil, 0, err2 | |||
} | |||
postItr := postings.Iterator() | |||
next, err2 := postItr.Next() | |||
for next != nil && err2 == nil { | |||
hitNewDocNum := newDocNums[dictI][next.Number()] | |||
if hitNewDocNum == docDropped { | |||
return nil, 0, fmt.Errorf("see hit with dropped doc num") | |||
} | |||
newRoaring.Add(uint32(hitNewDocNum)) | |||
// encode norm bits | |||
norm := next.Norm() | |||
normBits := math.Float32bits(float32(norm)) | |||
err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits)) | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
locs := next.Locations() | |||
if len(locs) > 0 { | |||
newRoaringLocs.Add(uint32(hitNewDocNum)) | |||
for _, loc := range locs { | |||
if cap(bufLoc) < 5+len(loc.ArrayPositions()) { | |||
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions())) | |||
} | |||
args := bufLoc[0:5] | |||
args[0] = uint64(fieldsMap[loc.Field()]) | |||
args[1] = loc.Pos() | |||
args[2] = loc.Start() | |||
args[3] = loc.End() | |||
args[4] = uint64(len(loc.ArrayPositions())) | |||
args = append(args, loc.ArrayPositions()...) | |||
err = locEncoder.Add(hitNewDocNum, args...) | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
} | |||
} | |||
var prevTerm []byte | |||
docTermMap[hitNewDocNum] = | |||
append(append(docTermMap[hitNewDocNum], term...), termSeparator) | |||
newRoaring := roaring.NewBitmap() | |||
newRoaringLocs := roaring.NewBitmap() | |||
next, err2 = postItr.Next() | |||
} | |||
if err2 != nil { | |||
return nil, 0, err2 | |||
} | |||
finishTerm := func(term []byte) error { | |||
if term == nil { | |||
return nil | |||
} | |||
tfEncoder.Close() | |||
@@ -277,59 +241,142 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
if newRoaring.GetCardinality() > 0 { | |||
// this field/term actually has hits in the new segment, let's write it down | |||
freqOffset := uint64(w.Count()) | |||
_, err = tfEncoder.Write(w) | |||
_, err := tfEncoder.Write(w) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
locOffset := uint64(w.Count()) | |||
_, err = locEncoder.Write(w) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
postingLocOffset := uint64(w.Count()) | |||
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
postingOffset := uint64(w.Count()) | |||
// write out the start of the term info | |||
buf := bufMaxVarintLen64 | |||
n := binary.PutUvarint(buf, freqOffset) | |||
_, err = w.Write(buf[:n]) | |||
n := binary.PutUvarint(bufMaxVarintLen64, freqOffset) | |||
_, err = w.Write(bufMaxVarintLen64[:n]) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
// write out the start of the loc info | |||
n = binary.PutUvarint(buf, locOffset) | |||
_, err = w.Write(buf[:n]) | |||
n = binary.PutUvarint(bufMaxVarintLen64, locOffset) | |||
_, err = w.Write(bufMaxVarintLen64[:n]) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
// write out the start of the loc posting list | |||
n = binary.PutUvarint(buf, postingLocOffset) | |||
_, err = w.Write(buf[:n]) | |||
// write out the start of the posting locs | |||
n = binary.PutUvarint(bufMaxVarintLen64, postingLocOffset) | |||
_, err = w.Write(bufMaxVarintLen64[:n]) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64) | |||
if err != nil { | |||
return nil, 0, err | |||
return err | |||
} | |||
err = newVellum.Insert(term, postingOffset) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
newRoaring = roaring.NewBitmap() | |||
newRoaringLocs = roaring.NewBitmap() | |||
tfEncoder.Reset() | |||
locEncoder.Reset() | |||
return nil | |||
} | |||
enumerator, err := newEnumerator(itrs) | |||
for err == nil { | |||
term, itrI, postingsOffset := enumerator.Current() | |||
if !bytes.Equal(prevTerm, term) { | |||
// if the term changed, write out the info collected | |||
// for the previous term | |||
err2 := finishTerm(prevTerm) | |||
if err2 != nil { | |||
return nil, 0, err2 | |||
} | |||
} | |||
var err2 error | |||
postings, err2 = dicts[itrI].postingsListFromOffset( | |||
postingsOffset, drops[itrI], postings) | |||
if err2 != nil { | |||
return nil, 0, err2 | |||
} | |||
newDocNumsI := newDocNums[itrI] | |||
postItr = postings.iterator(postItr) | |||
next, err2 := postItr.Next() | |||
for next != nil && err2 == nil { | |||
hitNewDocNum := newDocNumsI[next.Number()] | |||
if hitNewDocNum == docDropped { | |||
return nil, 0, fmt.Errorf("see hit with dropped doc num") | |||
} | |||
newRoaring.Add(uint32(hitNewDocNum)) | |||
// encode norm bits | |||
norm := next.Norm() | |||
normBits := math.Float32bits(float32(norm)) | |||
err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits)) | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
locs := next.Locations() | |||
if len(locs) > 0 { | |||
newRoaringLocs.Add(uint32(hitNewDocNum)) | |||
for _, loc := range locs { | |||
if cap(bufLoc) < 5+len(loc.ArrayPositions()) { | |||
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions())) | |||
} | |||
args := bufLoc[0:5] | |||
args[0] = uint64(fieldsMap[loc.Field()] - 1) | |||
args[1] = loc.Pos() | |||
args[2] = loc.Start() | |||
args[3] = loc.End() | |||
args[4] = uint64(len(loc.ArrayPositions())) | |||
args = append(args, loc.ArrayPositions()...) | |||
err = locEncoder.Add(hitNewDocNum, args...) | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
} | |||
} | |||
docTermMap[hitNewDocNum] = | |||
append(append(docTermMap[hitNewDocNum], term...), termSeparator) | |||
next, err2 = postItr.Next() | |||
} | |||
if err2 != nil { | |||
return nil, 0, err2 | |||
} | |||
err = mergeItr.Next() | |||
prevTerm = prevTerm[:0] // copy to prevTerm in case Next() reuses term mem | |||
prevTerm = append(prevTerm, term...) | |||
err = enumerator.Next() | |||
} | |||
if err != nil && err != vellum.ErrIteratorDone { | |||
return nil, 0, err | |||
} | |||
err = finishTerm(prevTerm) | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
dictOffset := uint64(w.Count()) | |||
err = newVellum.Close() | |||
@@ -378,7 +425,7 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
} | |||
} | |||
fieldDvLocsOffset = uint64(w.Count()) | |||
fieldDvLocsOffset := uint64(w.Count()) | |||
buf := bufMaxVarintLen64 | |||
for _, offset := range fieldDvLocs { | |||
@@ -392,10 +439,8 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, | |||
return rv, fieldDvLocsOffset, nil | |||
} | |||
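
The restructured loop above is a k-way merge with flush-on-key-change: the enumerator yields (term, iterator, offset) triples in sorted term order across all segments, postings are accumulated while the term repeats, and finishTerm flushes the collected state when the term changes and once more after the loop. Reduced to its shape, with zap's encoders elided:

// Generic sketch of the flush-on-key-change merge pattern used above.
func mergeSorted(keys [][]byte, flush func(key []byte) error) error {
	var prev []byte
	for _, k := range keys { // stands in for enumerator.Next()
		if prev != nil && !bytes.Equal(prev, k) {
			if err := flush(prev); err != nil {
				return err
			}
		}
		// ... accumulate state for key k here ...
		prev = append(prev[:0], k...) // copy; the iterator may reuse k's memory
	}
	if prev == nil {
		return nil
	}
	return flush(prev) // flush the final key
}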
const docDropped = math.MaxUint64 | |||
func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap, | |||
fieldsMap map[string]uint16, fieldsInv []string, newSegDocCount uint64, | |||
func mergeStoredAndRemap(segments []*SegmentBase, drops []*roaring.Bitmap, | |||
fieldsMap map[string]uint16, fieldsInv []string, fieldsSame bool, newSegDocCount uint64, | |||
w *CountHashWriter) (uint64, [][]uint64, error) { | |||
var rv [][]uint64 // The remapped or newDocNums for each segment. | |||
@@ -417,10 +462,30 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap, | |||
for segI, segment := range segments { | |||
segNewDocNums := make([]uint64, segment.numDocs) | |||
dropsI := drops[segI] | |||
// optimize when the field mapping is the same across all | |||
// segments and there are no deletions, via byte-copying | |||
// of stored docs bytes directly to the writer | |||
if fieldsSame && (dropsI == nil || dropsI.GetCardinality() == 0) { | |||
err := segment.copyStoredDocs(newDocNum, docNumOffsets, w) | |||
if err != nil { | |||
return 0, nil, err | |||
} | |||
for i := uint64(0); i < segment.numDocs; i++ { | |||
segNewDocNums[i] = newDocNum | |||
newDocNum++ | |||
} | |||
rv = append(rv, segNewDocNums) | |||
continue | |||
} | |||
// for each doc num | |||
for docNum := uint64(0); docNum < segment.numDocs; docNum++ { | |||
// TODO: roaring's API limits docNums to 32-bits? | |||
if drops[segI] != nil && drops[segI].Contains(uint32(docNum)) { | |||
if dropsI != nil && dropsI.Contains(uint32(docNum)) { | |||
segNewDocNums[docNum] = docDropped | |||
continue | |||
} | |||
@@ -439,7 +504,7 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap, | |||
poss[i] = poss[i][:0] | |||
} | |||
err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool { | |||
fieldID := int(fieldsMap[field]) | |||
fieldID := int(fieldsMap[field]) - 1 | |||
vals[fieldID] = append(vals[fieldID], value) | |||
typs[fieldID] = append(typs[fieldID], typ) | |||
poss[fieldID] = append(poss[fieldID], pos) | |||
@@ -453,47 +518,14 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap, | |||
for fieldID := range fieldsInv { | |||
storedFieldValues := vals[int(fieldID)] | |||
// has stored values for this field | |||
num := len(storedFieldValues) | |||
stf := typs[int(fieldID)] | |||
spf := poss[int(fieldID)] | |||
// process each value | |||
for i := 0; i < num; i++ { | |||
// encode field | |||
_, err2 := metaEncoder.PutU64(uint64(fieldID)) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
// encode type | |||
_, err2 = metaEncoder.PutU64(uint64(typs[int(fieldID)][i])) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
// encode start offset | |||
_, err2 = metaEncoder.PutU64(uint64(curr)) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
// end len | |||
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i]))) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
// encode number of array pos | |||
_, err2 = metaEncoder.PutU64(uint64(len(poss[int(fieldID)][i]))) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
// encode all array positions | |||
for j := 0; j < len(poss[int(fieldID)][i]); j++ { | |||
_, err2 = metaEncoder.PutU64(poss[int(fieldID)][i][j]) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
} | |||
// append data | |||
data = append(data, storedFieldValues[i]...) | |||
// update curr | |||
curr += len(storedFieldValues[i]) | |||
var err2 error | |||
curr, data, err2 = persistStoredFieldValues(fieldID, | |||
storedFieldValues, stf, spf, curr, metaEncoder, data) | |||
if err2 != nil { | |||
return 0, nil, err2 | |||
} | |||
} | |||
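
The inline loop that was removed and the persistStoredFieldValues helper that replaces it emit the same per-value metadata stream: for each stored value, uvarints for fieldID, type, start offset into the data slab, value length, and array-position count, then the positions themselves, with the raw value bytes appended to data. Hand-rolling one value under those assumptions:

// Sketch: meta encoding for a single stored value.
tmp := make([]byte, binary.MaxVarintLen64)
var meta []byte
put := func(v uint64) { n := binary.PutUvarint(tmp, v); meta = append(meta, tmp[:n]...) }
value := []byte("hello")
put(2)                  // fieldID
put(uint64('t'))        // type byte
put(0)                  // start offset of this value in the data slab
put(uint64(len(value))) // value length
put(0)                  // number of array positions (none here)
data := append([]byte(nil), value...) // value bytes appended to the data slab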
@@ -528,36 +560,87 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap, | |||
} | |||
// return value is the start of the stored index | |||
offset := uint64(w.Count()) | |||
storedIndexOffset := uint64(w.Count()) | |||
// now write out the stored doc index | |||
for docNum := range docNumOffsets { | |||
err := binary.Write(w, binary.BigEndian, docNumOffsets[docNum]) | |||
for _, docNumOffset := range docNumOffsets { | |||
err := binary.Write(w, binary.BigEndian, docNumOffset) | |||
if err != nil { | |||
return 0, nil, err | |||
} | |||
} | |||
return offset, rv, nil | |||
return storedIndexOffset, rv, nil | |||
} | |||
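
Each segNewDocNums slice maps a segment's old doc numbers to their new, contiguous numbers, with docDropped (math.MaxUint64) marking obsoleted docs; this is what persistMergedRest later indexes into. A small hypothetical remap for a 4-doc segment whose doc 1 was dropped, appended after 10 already-written docs:

segNewDocNums := []uint64{10, docDropped, 11, 12}
for oldNum, newNum := range segNewDocNums {
	if newNum == docDropped {
		continue // doc 1 does not exist in the merged segment
	}
	fmt.Printf("old doc %d -> new doc %d\n", oldNum, newNum)
}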
// mergeFields builds a unified list of fields used across all the input segments | |||
func mergeFields(segments []*Segment) []string { | |||
fieldsMap := map[string]struct{}{} | |||
// copyStoredDocs writes out a segment's stored doc info, optimized by | |||
// using a single Write() call for the entire set of bytes. The | |||
// newDocNumOffsets is filled with the new offsets for each doc. | |||
func (s *SegmentBase) copyStoredDocs(newDocNum uint64, newDocNumOffsets []uint64, | |||
w *CountHashWriter) error { | |||
if s.numDocs <= 0 { | |||
return nil | |||
} | |||
indexOffset0, storedOffset0, _, _, _ := | |||
s.getDocStoredOffsets(0) // the segment's first doc | |||
indexOffsetN, storedOffsetN, readN, metaLenN, dataLenN := | |||
s.getDocStoredOffsets(s.numDocs - 1) // the segment's last doc | |||
storedOffset0New := uint64(w.Count()) | |||
storedBytes := s.mem[storedOffset0 : storedOffsetN+readN+metaLenN+dataLenN] | |||
_, err := w.Write(storedBytes) | |||
if err != nil { | |||
return err | |||
} | |||
// remap the storedOffsets for the docs into new offsets relative | |||
// to storedOffset0New, filling the given newDocNumOffsets array | |||
for indexOffset := indexOffset0; indexOffset <= indexOffsetN; indexOffset += 8 { | |||
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8]) | |||
storedOffsetNew := storedOffset - storedOffset0 + storedOffset0New | |||
newDocNumOffsets[newDocNum] = storedOffsetNew | |||
newDocNum += 1 | |||
} | |||
return nil | |||
} | |||
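
Because the stored records are copied verbatim, the remap is a constant shift: every offset moves by storedOffset0New - storedOffset0. With hypothetical numbers, if the source segment's stored region began at byte 4096 and the writer stood at byte 1024 when the copy started:

storedOffset0, storedOffset0New := uint64(4096), uint64(1024)
oldOffset := uint64(4200) // some doc's stored offset in the source segment
newOffset := oldOffset - storedOffset0 + storedOffset0New
fmt.Println(newOffset) // 1128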
// mergeFields builds a unified list of fields used across all the | |||
// input segments, and computes whether the fields are the same across | |||
// segments (which depends on the fields being sorted in the same way | |||
// across segments) | |||
func mergeFields(segments []*SegmentBase) (bool, []string) { | |||
fieldsSame := true | |||
var segment0Fields []string | |||
if len(segments) > 0 { | |||
segment0Fields = segments[0].Fields() | |||
} | |||
fieldsExist := map[string]struct{}{} | |||
for _, segment := range segments { | |||
fields := segment.Fields() | |||
for _, field := range fields { | |||
fieldsMap[field] = struct{}{} | |||
for fieldi, field := range fields { | |||
fieldsExist[field] = struct{}{} | |||
if len(segment0Fields) != len(fields) || segment0Fields[fieldi] != field { | |||
fieldsSame = false | |||
} | |||
} | |||
} | |||
rv := make([]string, 0, len(fieldsMap)) | |||
rv := make([]string, 0, len(fieldsExist)) | |||
// ensure _id stays first | |||
rv = append(rv, "_id") | |||
for k := range fieldsMap { | |||
for k := range fieldsExist { | |||
if k != "_id" { | |||
rv = append(rv, k) | |||
} | |||
} | |||
return rv | |||
sort.Strings(rv[1:]) // leave _id as first | |||
return fieldsSame, rv | |||
} |
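
fieldsSame holds only when every segment reports an identical, identically ordered field list; that is what makes the copyStoredDocs fast path safe, since the field IDs baked into the stored meta remain valid. A standalone equivalent of the per-segment comparison:

func sameFields(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}
// sameFields([]string{"_id", "body"}, []string{"_id", "body"}) == true
// sameFields([]string{"_id", "body"}, []string{"body", "_id"}) == false (re-encode needed)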
@@ -28,21 +28,27 @@ import ( | |||
// PostingsList is an in-memory representation of a postings list | |||
type PostingsList struct { | |||
sb *SegmentBase | |||
term []byte | |||
postingsOffset uint64 | |||
freqOffset uint64 | |||
locOffset uint64 | |||
locBitmap *roaring.Bitmap | |||
postings *roaring.Bitmap | |||
except *roaring.Bitmap | |||
postingKey []byte | |||
} | |||
// Iterator returns an iterator for this postings list | |||
func (p *PostingsList) Iterator() segment.PostingsIterator { | |||
rv := &PostingsIterator{ | |||
postings: p, | |||
return p.iterator(nil) | |||
} | |||
func (p *PostingsList) iterator(rv *PostingsIterator) *PostingsIterator { | |||
if rv == nil { | |||
rv = &PostingsIterator{} | |||
} else { | |||
*rv = PostingsIterator{} // clear the struct | |||
} | |||
rv.postings = p | |||
if p.postings != nil { | |||
// prepare the freq chunk details | |||
var n uint64 |
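
The iterator(rv) form lets hot paths recycle one PostingsIterator across terms instead of allocating per term, the same pass-in-the-value-to-reuse idiom as the three-argument postingsList shown in the DocNumbers change below. The calling pattern, sketched around a hypothetical term loop:

var postings *PostingsList
var postItr *PostingsIterator
for _, term := range terms {
	var err error
	postings, err = dict.postingsList(term, nil, postings) // overwrites postings if non-nil
	if err != nil {
		return err
	}
	postItr = postings.iterator(postItr) // clears and reuses the same struct
	// ... consume postItr ...
}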
@@ -17,15 +17,27 @@ package zap | |||
import "encoding/binary" | |||
func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) { | |||
docStoredStartAddr := s.storedIndexOffset + (8 * docNum) | |||
docStoredStart := binary.BigEndian.Uint64(s.mem[docStoredStartAddr : docStoredStartAddr+8]) | |||
_, storedOffset, n, metaLen, dataLen := s.getDocStoredOffsets(docNum) | |||
meta := s.mem[storedOffset+n : storedOffset+n+metaLen] | |||
data := s.mem[storedOffset+n+metaLen : storedOffset+n+metaLen+dataLen] | |||
return meta, data | |||
} | |||
func (s *SegmentBase) getDocStoredOffsets(docNum uint64) ( | |||
uint64, uint64, uint64, uint64, uint64) { | |||
indexOffset := s.storedIndexOffset + (8 * docNum) | |||
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8]) | |||
var n uint64 | |||
metaLen, read := binary.Uvarint(s.mem[docStoredStart : docStoredStart+binary.MaxVarintLen64]) | |||
metaLen, read := binary.Uvarint(s.mem[storedOffset : storedOffset+binary.MaxVarintLen64]) | |||
n += uint64(read) | |||
var dataLen uint64 | |||
dataLen, read = binary.Uvarint(s.mem[docStoredStart+n : docStoredStart+n+binary.MaxVarintLen64]) | |||
dataLen, read := binary.Uvarint(s.mem[storedOffset+n : storedOffset+n+binary.MaxVarintLen64]) | |||
n += uint64(read) | |||
meta := s.mem[docStoredStart+n : docStoredStart+n+metaLen] | |||
data := s.mem[docStoredStart+n+metaLen : docStoredStart+n+metaLen+dataLen] | |||
return meta, data | |||
return indexOffset, storedOffset, n, metaLen, dataLen | |||
} |
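
Per document, the stored section holds an 8-byte big-endian index entry pointing at a record laid out as uvarint(metaLen) | uvarint(dataLen) | meta | data; getDocStoredOffsets returns the pieces needed to slice it. Decoding such a record by hand:

rec := []byte{3, 5, 0xA, 0xB, 0xC, 'h', 'e', 'l', 'l', 'o'}
metaLen, n1 := binary.Uvarint(rec)
dataLen, n2 := binary.Uvarint(rec[n1:])
n := n1 + n2
meta := rec[n : n+int(metaLen)]
data := rec[n+int(metaLen) : n+int(metaLen)+int(dataLen)]
fmt.Printf("%x %q\n", meta, data) // 0a0b0c "hello"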
@@ -343,8 +343,9 @@ func (s *SegmentBase) DocNumbers(ids []string) (*roaring.Bitmap, error) { | |||
return nil, err | |||
} | |||
var postings *PostingsList | |||
for _, id := range ids { | |||
postings, err := idDict.postingsList([]byte(id), nil) | |||
postings, err = idDict.postingsList([]byte(id), nil, postings) | |||
if err != nil { | |||
return nil, err | |||
} |
@@ -31,10 +31,9 @@ func (r *RollbackPoint) GetInternal(key []byte) []byte { | |||
return r.meta[string(key)] | |||
} | |||
// RollbackPoints returns an array of rollback points available | |||
// for the application to make a decision on where to rollback | |||
// to. A nil return value indicates that there are no available | |||
// rollback points. | |||
// RollbackPoints returns an array of rollback points available for | |||
// the application to rollback to, with more recent rollback points | |||
// (higher epochs) coming first. | |||
func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) { | |||
if s.rootBolt == nil { | |||
return nil, fmt.Errorf("RollbackPoints: root is nil") | |||
@@ -54,7 +53,7 @@ func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) { | |||
snapshots := tx.Bucket(boltSnapshotsBucket) | |||
if snapshots == nil { | |||
return nil, fmt.Errorf("RollbackPoints: no snapshots available") | |||
return nil, nil | |||
} | |||
rollbackPoints := []*RollbackPoint{} | |||
@@ -150,10 +149,7 @@ func (s *Scorch) Rollback(to *RollbackPoint) error { | |||
revert.snapshot = indexSnapshot | |||
revert.applied = make(chan error) | |||
if !s.unsafeBatch { | |||
revert.persisted = make(chan error) | |||
} | |||
revert.persisted = make(chan error) | |||
return nil | |||
}) | |||
@@ -173,9 +169,5 @@ func (s *Scorch) Rollback(to *RollbackPoint) error { | |||
return fmt.Errorf("Rollback: failed with err: %v", err) | |||
} | |||
if revert.persisted != nil { | |||
err = <-revert.persisted | |||
} | |||
return err | |||
return <-revert.persisted | |||
} |
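
With the new contract — nil, nil when no snapshots exist, and newest (highest epoch) first — rolling back to the most recent point looks roughly like this sketch against the API above:

points, err := s.RollbackPoints()
if err != nil {
	return err
}
if len(points) == 0 {
	return nil // nothing to roll back to
}
// points[0] has the highest epoch, i.e. the most recent snapshot
if err := s.Rollback(points[0]); err != nil {
	return fmt.Errorf("rollback failed: %v", err)
}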
@@ -837,6 +837,11 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) { | |||
docBackIndexRowErr = err | |||
return | |||
} | |||
defer func() { | |||
if cerr := kvreader.Close(); err == nil && cerr != nil { | |||
docBackIndexRowErr = cerr | |||
} | |||
}() | |||
for docID, doc := range batch.IndexOps { | |||
backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID)) | |||
@@ -847,12 +852,6 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) { | |||
docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow} | |||
} | |||
err = kvreader.Close() | |||
if err != nil { | |||
docBackIndexRowErr = err | |||
return | |||
} | |||
}() | |||
// wait for analysis result |
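
Moving the Close into a defer guarantees the reader is released on every exit path, including the early return inside the loop, while the err == nil && cerr != nil guard keeps the first error: a Close failure is only reported when nothing failed earlier. The general idiom:

// Idiom: close in a defer, surface Close's error only if nothing else failed.
func readAll(open func() (io.ReadCloser, error)) (data []byte, err error) {
	r, err := open()
	if err != nil {
		return nil, err
	}
	defer func() {
		if cerr := r.Close(); err == nil && cerr != nil {
			err = cerr
		}
	}()
	return ioutil.ReadAll(r)
}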
@@ -15,12 +15,11 @@ | |||
package bleve | |||
import ( | |||
"context" | |||
"sort" | |||
"sync" | |||
"time" | |||
"golang.org/x/net/context" | |||
"github.com/blevesearch/bleve/document" | |||
"github.com/blevesearch/bleve/index" | |||
"github.com/blevesearch/bleve/index/store" |
@@ -15,6 +15,7 @@ | |||
package bleve | |||
import ( | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"os" | |||
@@ -22,8 +23,6 @@ import ( | |||
"sync/atomic" | |||
"time" | |||
"golang.org/x/net/context" | |||
"github.com/blevesearch/bleve/document" | |||
"github.com/blevesearch/bleve/index" | |||
"github.com/blevesearch/bleve/index/store" |
@@ -15,11 +15,10 @@ | |||
package search | |||
import ( | |||
"context" | |||
"time" | |||
"github.com/blevesearch/bleve/index" | |||
"golang.org/x/net/context" | |||
) | |||
type Collector interface { |
@@ -15,11 +15,11 @@ | |||
package collector | |||
import ( | |||
"context" | |||
"time" | |||
"github.com/blevesearch/bleve/index" | |||
"github.com/blevesearch/bleve/search" | |||
"golang.org/x/net/context" | |||
) | |||
type collectorStore interface { |
@@ -1,10 +1,10 @@ | |||
package goth | |||
import ( | |||
"context" | |||
"fmt" | |||
"net/http" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2" | |||
) | |||
@@ -4,17 +4,18 @@ package facebook | |||
import ( | |||
"bytes" | |||
"crypto/hmac" | |||
"crypto/sha256" | |||
"encoding/hex" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"io/ioutil" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
"crypto/hmac" | |||
"crypto/sha256" | |||
"encoding/hex" | |||
"fmt" | |||
"github.com/markbates/goth" | |||
"golang.org/x/oauth2" | |||
) | |||
@@ -22,7 +23,7 @@ import ( | |||
const ( | |||
authURL string = "https://www.facebook.com/dialog/oauth" | |||
tokenURL string = "https://graph.facebook.com/oauth/access_token" | |||
endpointProfile string = "https://graph.facebook.com/me?fields=email,first_name,last_name,link,about,id,name,picture,location" | |||
endpointProfile string = "https://graph.facebook.com/me?fields=" | |||
) | |||
// New creates a new Facebook provider, and sets up important connection details. | |||
@@ -68,9 +69,9 @@ func (p *Provider) Debug(debug bool) {} | |||
// BeginAuth asks Facebook for an authentication end-point. | |||
func (p *Provider) BeginAuth(state string) (goth.Session, error) { | |||
url := p.config.AuthCodeURL(state) | |||
authUrl := p.config.AuthCodeURL(state) | |||
session := &Session{ | |||
AuthURL: url, | |||
AuthURL: authUrl, | |||
} | |||
return session, nil | |||
} | |||
@@ -96,7 +97,15 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { | |||
hash.Write([]byte(sess.AccessToken)) | |||
appsecretProof := hex.EncodeToString(hash.Sum(nil)) | |||
response, err := p.Client().Get(endpointProfile + "&access_token=" + url.QueryEscape(sess.AccessToken) + "&appsecret_proof=" + appsecretProof) | |||
reqUrl := fmt.Sprint( | |||
endpointProfile, | |||
strings.Join(p.config.Scopes, ","), | |||
"&access_token=", | |||
url.QueryEscape(sess.AccessToken), | |||
"&appsecret_proof=", | |||
appsecretProof, | |||
) | |||
response, err := p.Client().Get(reqUrl) | |||
if err != nil { | |||
return user, err | |||
} | |||
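
appsecret_proof is Facebook's call signing for server-side requests: an HMAC-SHA256 of the access token keyed with the app secret, sent alongside the token itself. The computation above in isolation (placeholder values):

secret := []byte("app-secret")
accessToken := "user-access-token"
mac := hmac.New(sha256.New, secret)
mac.Write([]byte(accessToken))
proof := hex.EncodeToString(mac.Sum(nil))
fmt.Println(proof) // appended as &appsecret_proof=...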
@@ -168,17 +177,31 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config { | |||
}, | |||
Scopes: []string{ | |||
"email", | |||
"first_name", | |||
"last_name", | |||
"link", | |||
"about", | |||
"id", | |||
"name", | |||
"picture", | |||
"location", | |||
}, | |||
} | |||
defaultScopes := map[string]struct{}{ | |||
"email": {}, | |||
} | |||
for _, scope := range scopes { | |||
if _, exists := defaultScopes[scope]; !exists { | |||
c.Scopes = append(c.Scopes, scope) | |||
// allows invoking a field modifier such as 'picture.type(large)' | |||
var found bool | |||
for _, sc := range scopes { | |||
sc := sc | |||
for i, defScope := range c.Scopes { | |||
if defScope == strings.Split(sc, ".")[0] { | |||
c.Scopes[i] = sc | |||
found = true | |||
} | |||
} | |||
if !found { | |||
c.Scopes = append(c.Scopes, sc) | |||
} | |||
found = false | |||
} | |||
return c |
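
The rewritten loop compares each requested scope against the defaults by its prefix before the first '.', so a request can refine a default field instead of duplicating it: passing "picture.type(large)" replaces the default "picture", while an unknown scope is appended. A standalone reproduction of that rule:

func mergeScopes(defaults, requested []string) []string {
	out := append([]string(nil), defaults...)
	for _, sc := range requested {
		found := false
		base := strings.Split(sc, ".")[0]
		for i, def := range out {
			if def == base {
				out[i] = sc // refine the default, e.g. picture -> picture.type(large)
				found = true
			}
		}
		if !found {
			out = append(out, sc)
		}
	}
	return out
}
// mergeScopes([]string{"email", "picture"}, []string{"picture.type(large)", "user_birthday"})
//   -> ["email", "picture.type(large)", "user_birthday"]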
@@ -0,0 +1,74 @@ | |||
// Copyright 2016 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// +build go1.7 | |||
// Package ctxhttp provides helper functions for performing context-aware HTTP requests. | |||
package ctxhttp // import "golang.org/x/net/context/ctxhttp" | |||
import ( | |||
"io" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
"golang.org/x/net/context" | |||
) | |||
// Do sends an HTTP request with the provided http.Client and returns | |||
// an HTTP response. | |||
// | |||
// If the client is nil, http.DefaultClient is used. | |||
// | |||
// The provided ctx must be non-nil. If it is canceled or times out, | |||
// ctx.Err() will be returned. | |||
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { | |||
if client == nil { | |||
client = http.DefaultClient | |||
} | |||
resp, err := client.Do(req.WithContext(ctx)) | |||
// If we got an error, and the context has been canceled, | |||
// the context's error is probably more useful. | |||
if err != nil { | |||
select { | |||
case <-ctx.Done(): | |||
err = ctx.Err() | |||
default: | |||
} | |||
} | |||
return resp, err | |||
} | |||
// Get issues a GET request via the Do function. | |||
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||
req, err := http.NewRequest("GET", url, nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return Do(ctx, client, req) | |||
} | |||
// Head issues a HEAD request via the Do function. | |||
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||
req, err := http.NewRequest("HEAD", url, nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return Do(ctx, client, req) | |||
} | |||
// Post issues a POST request via the Do function. | |||
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { | |||
req, err := http.NewRequest("POST", url, body) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.Header.Set("Content-Type", bodyType) | |||
return Do(ctx, client, req) | |||
} | |||
// PostForm issues a POST request via the Do function. | |||
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { | |||
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) | |||
} |
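
This vendored package is what the oauth2 token fetch below now routes through: the helpers tie a request's lifetime to a context and, when both fail, prefer the context's error over the transport's. Typical use (a nil client falls back to http.DefaultClient):

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
resp, err := ctxhttp.Get(ctx, nil, "https://example.com/")
if err != nil {
	return err // on timeout this is ctx.Err(), i.e. context.DeadlineExceeded
}
defer resp.Body.Close()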
@@ -0,0 +1,147 @@ | |||
// Copyright 2015 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// +build !go1.7 | |||
package ctxhttp // import "golang.org/x/net/context/ctxhttp" | |||
import ( | |||
"io" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
"golang.org/x/net/context" | |||
) | |||
func nop() {} | |||
var ( | |||
testHookContextDoneBeforeHeaders = nop | |||
testHookDoReturned = nop | |||
testHookDidBodyClose = nop | |||
) | |||
// Do sends an HTTP request with the provided http.Client and returns an HTTP response. | |||
// If the client is nil, http.DefaultClient is used. | |||
// If the context is canceled or times out, ctx.Err() will be returned. | |||
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { | |||
if client == nil { | |||
client = http.DefaultClient | |||
} | |||
// TODO(djd): Respect any existing value of req.Cancel. | |||
cancel := make(chan struct{}) | |||
req.Cancel = cancel | |||
type responseAndError struct { | |||
resp *http.Response | |||
err error | |||
} | |||
result := make(chan responseAndError, 1) | |||
// Make local copies of test hooks closed over by goroutines below. | |||
// Prevents data races in tests. | |||
testHookDoReturned := testHookDoReturned | |||
testHookDidBodyClose := testHookDidBodyClose | |||
go func() { | |||
resp, err := client.Do(req) | |||
testHookDoReturned() | |||
result <- responseAndError{resp, err} | |||
}() | |||
var resp *http.Response | |||
select { | |||
case <-ctx.Done(): | |||
testHookContextDoneBeforeHeaders() | |||
close(cancel) | |||
// Clean up after the goroutine calling client.Do: | |||
go func() { | |||
if r := <-result; r.resp != nil { | |||
testHookDidBodyClose() | |||
r.resp.Body.Close() | |||
} | |||
}() | |||
return nil, ctx.Err() | |||
case r := <-result: | |||
var err error | |||
resp, err = r.resp, r.err | |||
if err != nil { | |||
return resp, err | |||
} | |||
} | |||
c := make(chan struct{}) | |||
go func() { | |||
select { | |||
case <-ctx.Done(): | |||
close(cancel) | |||
case <-c: | |||
// The response's Body is closed. | |||
} | |||
}() | |||
resp.Body = ¬ifyingReader{resp.Body, c} | |||
return resp, nil | |||
} | |||
// Get issues a GET request via the Do function. | |||
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||
req, err := http.NewRequest("GET", url, nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return Do(ctx, client, req) | |||
} | |||
// Head issues a HEAD request via the Do function. | |||
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | |||
req, err := http.NewRequest("HEAD", url, nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return Do(ctx, client, req) | |||
} | |||
// Post issues a POST request via the Do function. | |||
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { | |||
req, err := http.NewRequest("POST", url, body) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.Header.Set("Content-Type", bodyType) | |||
return Do(ctx, client, req) | |||
} | |||
// PostForm issues a POST request via the Do function. | |||
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { | |||
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) | |||
} | |||
// notifyingReader is an io.ReadCloser that closes the notify channel after | |||
// Close is called or a Read fails on the underlying ReadCloser. | |||
type notifyingReader struct { | |||
io.ReadCloser | |||
notify chan<- struct{} | |||
} | |||
func (r *notifyingReader) Read(p []byte) (int, error) { | |||
n, err := r.ReadCloser.Read(p) | |||
if err != nil && r.notify != nil { | |||
close(r.notify) | |||
r.notify = nil | |||
} | |||
return n, err | |||
} | |||
func (r *notifyingReader) Close() error { | |||
err := r.ReadCloser.Close() | |||
if r.notify != nil { | |||
close(r.notify) | |||
r.notify = nil | |||
} | |||
return err | |||
} |
@@ -1,4 +1,4 @@ | |||
Copyright (c) 2009 The oauth2 Authors. All rights reserved. | |||
Copyright (c) 2009 The Go Authors. All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are |
@@ -1,25 +0,0 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// +build appengine | |||
// App Engine hooks. | |||
package oauth2 | |||
import ( | |||
"net/http" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2/internal" | |||
"google.golang.org/appengine/urlfetch" | |||
) | |||
func init() { | |||
internal.RegisterContextClientFunc(contextClientAppEngine) | |||
} | |||
func contextClientAppEngine(ctx context.Context) (*http.Client, error) { | |||
return urlfetch.Client(ctx), nil | |||
} |
@@ -0,0 +1,13 @@ | |||
// Copyright 2018 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// +build appengine | |||
package internal | |||
import "google.golang.org/appengine/urlfetch" | |||
func init() { | |||
appengineClientHook = urlfetch.Client | |||
} |
@@ -0,0 +1,6 @@ | |||
// Copyright 2017 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal |
@@ -2,18 +2,14 @@ | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal | |||
import ( | |||
"bufio" | |||
"crypto/rsa" | |||
"crypto/x509" | |||
"encoding/pem" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"strings" | |||
) | |||
// ParseKey converts the binary contents of a private key file | |||
@@ -30,7 +26,7 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) { | |||
if err != nil { | |||
parsedKey, err = x509.ParsePKCS1PrivateKey(key) | |||
if err != nil { | |||
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err) | |||
return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err) | |||
} | |||
} | |||
parsed, ok := parsedKey.(*rsa.PrivateKey) | |||
@@ -39,38 +35,3 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) { | |||
} | |||
return parsed, nil | |||
} | |||
func ParseINI(ini io.Reader) (map[string]map[string]string, error) { | |||
result := map[string]map[string]string{ | |||
"": map[string]string{}, // root section | |||
} | |||
scanner := bufio.NewScanner(ini) | |||
currentSection := "" | |||
for scanner.Scan() { | |||
line := strings.TrimSpace(scanner.Text()) | |||
if strings.HasPrefix(line, ";") { | |||
// comment. | |||
continue | |||
} | |||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { | |||
currentSection = strings.TrimSpace(line[1 : len(line)-1]) | |||
result[currentSection] = map[string]string{} | |||
continue | |||
} | |||
parts := strings.SplitN(line, "=", 2) | |||
if len(parts) == 2 && parts[0] != "" { | |||
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) | |||
} | |||
} | |||
if err := scanner.Err(); err != nil { | |||
return nil, fmt.Errorf("error scanning ini: %v", err) | |||
} | |||
return result, nil | |||
} | |||
func CondVal(v string) []string { | |||
if v == "" { | |||
return nil | |||
} | |||
return []string{v} | |||
} |
@@ -2,11 +2,12 @@ | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal | |||
import ( | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"io/ioutil" | |||
@@ -17,10 +18,10 @@ import ( | |||
"strings" | |||
"time" | |||
"golang.org/x/net/context" | |||
"golang.org/x/net/context/ctxhttp" | |||
) | |||
// Token represents the crendentials used to authorize | |||
// Token represents the credentials used to authorize | |||
// the requests to access protected resources on the OAuth 2.0 | |||
// provider's backend. | |||
// | |||
@@ -91,6 +92,7 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { | |||
var brokenAuthHeaderProviders = []string{ | |||
"https://accounts.google.com/", | |||
"https://api.codeswholesale.com/oauth/token", | |||
"https://api.dropbox.com/", | |||
"https://api.dropboxapi.com/", | |||
"https://api.instagram.com/", | |||
@@ -99,10 +101,16 @@ var brokenAuthHeaderProviders = []string{ | |||
"https://api.pushbullet.com/", | |||
"https://api.soundcloud.com/", | |||
"https://api.twitch.tv/", | |||
"https://id.twitch.tv/", | |||
"https://app.box.com/", | |||
"https://api.box.com/", | |||
"https://connect.stripe.com/", | |||
"https://login.mailchimp.com/", | |||
"https://login.microsoftonline.com/", | |||
"https://login.salesforce.com/", | |||
"https://login.windows.net", | |||
"https://login.live.com/", | |||
"https://login.live-int.com/", | |||
"https://oauth.sandbox.trainingpeaks.com/", | |||
"https://oauth.trainingpeaks.com/", | |||
"https://oauth.vk.com/", | |||
@@ -117,6 +125,24 @@ var brokenAuthHeaderProviders = []string{ | |||
"https://www.strava.com/oauth/", | |||
"https://www.wunderlist.com/oauth/", | |||
"https://api.patreon.com/", | |||
"https://sandbox.codeswholesale.com/oauth/token", | |||
"https://api.sipgate.com/v1/authorization/oauth", | |||
"https://api.medium.com/v1/tokens", | |||
"https://log.finalsurge.com/oauth/token", | |||
"https://multisport.todaysplan.com.au/rest/oauth/access_token", | |||
"https://whats.todaysplan.com.au/rest/oauth/access_token", | |||
"https://stackoverflow.com/oauth/access_token", | |||
"https://account.health.nokia.com", | |||
"https://accounts.zoho.com", | |||
} | |||
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints. | |||
var brokenAuthHeaderDomains = []string{ | |||
".auth0.com", | |||
".force.com", | |||
".myshopify.com", | |||
".okta.com", | |||
".oktapreview.com", | |||
} | |||
func RegisterBrokenAuthHeaderProvider(tokenURL string) { | |||
@@ -139,6 +165,14 @@ func providerAuthHeaderWorks(tokenURL string) bool { | |||
} | |||
} | |||
if u, err := url.Parse(tokenURL); err == nil { | |||
for _, s := range brokenAuthHeaderDomains { | |||
if strings.HasSuffix(u.Host, s) { | |||
return false | |||
} | |||
} | |||
} | |||
// Assume the provider implements the spec properly | |||
// otherwise. We can add more exceptions as they're | |||
// discovered. We will _not_ be adding configurable hooks | |||
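
The new brokenAuthHeaderDomains list covers providers whose token endpoints live on per-tenant hosts (one URL per Auth0/Okta/Shopify/Salesforce tenant), which an exact-URL list cannot enumerate; a suffix match on the parsed host handles them all. Illustration:

u, _ := url.Parse("https://mystore.myshopify.com/admin/oauth/access_token")
for _, s := range []string{".auth0.com", ".force.com", ".myshopify.com", ".okta.com", ".oktapreview.com"} {
	if strings.HasSuffix(u.Host, s) {
		fmt.Println("credentials go in the POST body, not the Authorization header")
	}
}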
@@ -147,14 +181,14 @@ func providerAuthHeaderWorks(tokenURL string) bool { | |||
} | |||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { | |||
hc, err := ContextClient(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
v.Set("client_id", clientID) | |||
bustedAuth := !providerAuthHeaderWorks(tokenURL) | |||
if bustedAuth && clientSecret != "" { | |||
v.Set("client_secret", clientSecret) | |||
if bustedAuth { | |||
if clientID != "" { | |||
v.Set("client_id", clientID) | |||
} | |||
if clientSecret != "" { | |||
v.Set("client_secret", clientSecret) | |||
} | |||
} | |||
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) | |||
if err != nil { | |||
@@ -162,9 +196,9 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, | |||
} | |||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | |||
if !bustedAuth { | |||
req.SetBasicAuth(clientID, clientSecret) | |||
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) | |||
} | |||
r, err := hc.Do(req) | |||
r, err := ctxhttp.Do(ctx, ContextClient(ctx), req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -174,7 +208,10 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, | |||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) | |||
} | |||
if code := r.StatusCode; code < 200 || code > 299 { | |||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) | |||
return nil, &RetrieveError{ | |||
Response: r, | |||
Body: body, | |||
} | |||
} | |||
var token *Token | |||
@@ -221,5 +258,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, | |||
if token.RefreshToken == "" { | |||
token.RefreshToken = v.Get("refresh_token") | |||
} | |||
if token.AccessToken == "" { | |||
return token, errors.New("oauth2: server response missing access_token") | |||
} | |||
return token, nil | |||
} | |||
type RetrieveError struct { | |||
Response *http.Response | |||
Body []byte | |||
} | |||
func (r *RetrieveError) Error() string { | |||
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) | |||
} |
@@ -2,13 +2,11 @@ | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal | |||
import ( | |||
"context" | |||
"net/http" | |||
"golang.org/x/net/context" | |||
) | |||
// HTTPClient is the context key to use with golang.org/x/net/context's | |||
@@ -20,50 +18,16 @@ var HTTPClient ContextKey | |||
// because nobody else can create a ContextKey, being unexported. | |||
type ContextKey struct{} | |||
// ContextClientFunc is a func which tries to return an *http.Client | |||
// given a Context value. If it returns an error, the search stops | |||
// with that error. If it returns (nil, nil), the search continues | |||
// down the list of registered funcs. | |||
type ContextClientFunc func(context.Context) (*http.Client, error) | |||
var contextClientFuncs []ContextClientFunc | |||
func RegisterContextClientFunc(fn ContextClientFunc) { | |||
contextClientFuncs = append(contextClientFuncs, fn) | |||
} | |||
var appengineClientHook func(context.Context) *http.Client | |||
func ContextClient(ctx context.Context) (*http.Client, error) { | |||
func ContextClient(ctx context.Context) *http.Client { | |||
if ctx != nil { | |||
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { | |||
return hc, nil | |||
} | |||
} | |||
for _, fn := range contextClientFuncs { | |||
c, err := fn(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if c != nil { | |||
return c, nil | |||
return hc | |||
} | |||
} | |||
return http.DefaultClient, nil | |||
} | |||
func ContextTransport(ctx context.Context) http.RoundTripper { | |||
hc, err := ContextClient(ctx) | |||
// This is a rare error case (somebody using nil on App Engine). | |||
if err != nil { | |||
return ErrorTransport{err} | |||
if appengineClientHook != nil { | |||
return appengineClientHook(ctx) | |||
} | |||
return hc.Transport | |||
} | |||
// ErrorTransport returns the specified error on RoundTrip. | |||
// This RoundTripper should be used in rare error cases where | |||
// error handling can be postponed to response handling time. | |||
type ErrorTransport struct{ Err error } | |||
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) { | |||
return nil, t.Err | |||
return http.DefaultClient | |||
} |
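
The HTTPClient context key remains the supported way to supply a custom client; what changed is that ContextClient can no longer fail, so the ContextClientFunc registry and ErrorTransport machinery disappear. Supplying a client through the public package looks like this sketch (conf, an *oauth2.Config, and code are assumed from the surrounding flow):

custom := &http.Client{Timeout: 10 * time.Second}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, custom)
tok, err := conf.Exchange(ctx, code) // the token request goes through custom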
@@ -3,19 +3,20 @@ | |||
// license that can be found in the LICENSE file. | |||
// Package oauth2 provides support for making | |||
// OAuth2 authorized and authenticated HTTP requests. | |||
// OAuth2 authorized and authenticated HTTP requests, | |||
// as specified in RFC 6749. | |||
// It can additionally grant authorization with Bearer JWT. | |||
package oauth2 // import "golang.org/x/oauth2" | |||
import ( | |||
"bytes" | |||
"context" | |||
"errors" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
"sync" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2/internal" | |||
) | |||
@@ -117,21 +118,30 @@ func SetAuthURLParam(key, value string) AuthCodeOption { | |||
// that asks for permissions for the required scopes explicitly. | |||
// | |||
// State is a token to protect the user from CSRF attacks. You must | |||
// always provide a non-zero string and validate that it matches the | |||
// always provide a non-empty string and validate that it matches the | |||
// state query parameter on your redirect callback. | |||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. | |||
// | |||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well | |||
// as ApprovalForce. | |||
// It can also be used to pass the PKCE challenge. | |||
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. | |||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { | |||
var buf bytes.Buffer | |||
buf.WriteString(c.Endpoint.AuthURL) | |||
v := url.Values{ | |||
"response_type": {"code"}, | |||
"client_id": {c.ClientID}, | |||
"redirect_uri": internal.CondVal(c.RedirectURL), | |||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), | |||
"state": internal.CondVal(state), | |||
} | |||
if c.RedirectURL != "" { | |||
v.Set("redirect_uri", c.RedirectURL) | |||
} | |||
if len(c.Scopes) > 0 { | |||
v.Set("scope", strings.Join(c.Scopes, " ")) | |||
} | |||
if state != "" { | |||
// TODO(light): Docs say never to omit state; don't allow empty. | |||
v.Set("state", state) | |||
} | |||
for _, opt := range opts { | |||
opt.setValue(v) | |||
@@ -157,12 +167,15 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { | |||
// The HTTP client to use is derived from the context. | |||
// If nil, http.DefaultClient is used. | |||
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { | |||
return retrieveToken(ctx, c, url.Values{ | |||
v := url.Values{ | |||
"grant_type": {"password"}, | |||
"username": {username}, | |||
"password": {password}, | |||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), | |||
}) | |||
} | |||
if len(c.Scopes) > 0 { | |||
v.Set("scope", strings.Join(c.Scopes, " ")) | |||
} | |||
return retrieveToken(ctx, c, v) | |||
} | |||
// Exchange converts an authorization code into a token. | |||
@@ -175,13 +188,21 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor | |||
// | |||
// The code will be in the *http.Request.FormValue("code"). Before | |||
// calling Exchange, be sure to validate FormValue("state"). | |||
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { | |||
return retrieveToken(ctx, c, url.Values{ | |||
"grant_type": {"authorization_code"}, | |||
"code": {code}, | |||
"redirect_uri": internal.CondVal(c.RedirectURL), | |||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), | |||
}) | |||
// | |||
// Opts may include the PKCE verifier code if previously used in AuthCodeURL. | |||
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. | |||
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) { | |||
v := url.Values{ | |||
"grant_type": {"authorization_code"}, | |||
"code": {code}, | |||
} | |||
if c.RedirectURL != "" { | |||
v.Set("redirect_uri", c.RedirectURL) | |||
} | |||
for _, opt := range opts { | |||
opt.setValue(v) | |||
} | |||
return retrieveToken(ctx, c, v) | |||
} | |||
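
With Exchange now accepting AuthCodeOption values, the PKCE handshake the docs mention can be done entirely via SetAuthURLParam: the challenge rides on the authorization URL, the verifier on the exchange. A sketch using the S256 method and the example verifier from RFC 7636 (real code must generate the verifier randomly; conf and code come from the surrounding flow):

verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
sum := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
authURL := conf.AuthCodeURL("state",
	oauth2.SetAuthURLParam("code_challenge", challenge),
	oauth2.SetAuthURLParam("code_challenge_method", "S256"))
// ... user authorizes at authURL; provider redirects back with ?code=... ...
tok, err := conf.Exchange(ctx, code,
	oauth2.SetAuthURLParam("code_verifier", verifier))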
// Client returns an HTTP client using the provided token. | |||
@@ -292,20 +313,20 @@ var HTTPClient internal.ContextKey | |||
// NewClient creates an *http.Client from a Context and TokenSource. | |||
// The returned client is not valid beyond the lifetime of the context. | |||
// | |||
// Note that if a custom *http.Client is provided via the Context it | |||
// is used only for token acquisition and is not used to configure the | |||
// *http.Client returned from NewClient. | |||
// | |||
// As a special case, if src is nil, a non-OAuth2 client is returned | |||
// using the provided context. This exists to support related OAuth2 | |||
// packages. | |||
func NewClient(ctx context.Context, src TokenSource) *http.Client { | |||
if src == nil { | |||
c, err := internal.ContextClient(ctx) | |||
if err != nil { | |||
return &http.Client{Transport: internal.ErrorTransport{Err: err}} | |||
} | |||
return c | |||
return internal.ContextClient(ctx) | |||
} | |||
return &http.Client{ | |||
Transport: &Transport{ | |||
Base: internal.ContextTransport(ctx), | |||
Base: internal.ContextClient(ctx).Transport, | |||
Source: ReuseTokenSource(nil, src), | |||
}, | |||
} |
@@ -5,13 +5,14 @@ | |||
package oauth2 | |||
import ( | |||
"context" | |||
"fmt" | |||
"net/http" | |||
"net/url" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2/internal" | |||
) | |||
@@ -20,7 +21,7 @@ import ( | |||
// expirations due to client-server time mismatches. | |||
const expiryDelta = 10 * time.Second | |||
// Token represents the crendentials used to authorize | |||
// Token represents the credentials used to authorize | |||
// the requests to access protected resources on the OAuth 2.0 | |||
// provider's backend. | |||
// | |||
@@ -123,7 +124,7 @@ func (t *Token) expired() bool { | |||
if t.Expiry.IsZero() { | |||
return false | |||
} | |||
return t.Expiry.Add(-expiryDelta).Before(time.Now()) | |||
return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now()) | |||
} | |||
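
Round(0) strips the monotonic clock reading that Go 1.9+ attaches to time.Now() values, forcing a wall-clock comparison; without it, an Expiry that crossed a serialization boundary or a wall-clock step (NTP adjustment, suspend/resume) could compare inconsistently against time.Now(). The strip itself:

t := time.Now()         // carries a monotonic reading, printed as "... m=+0.00..."
fmt.Println(t)
fmt.Println(t.Round(0)) // same wall time, monotonic reading removed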
// Valid reports whether t is non-nil, has an AccessToken, and is not expired. | |||
@@ -152,7 +153,23 @@ func tokenFromInternal(t *internal.Token) *Token { | |||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { | |||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) | |||
if err != nil { | |||
if rErr, ok := err.(*internal.RetrieveError); ok { | |||
return nil, (*RetrieveError)(rErr) | |||
} | |||
return nil, err | |||
} | |||
return tokenFromInternal(tk), nil | |||
} | |||
// RetrieveError is the error returned when the token endpoint returns a | |||
// non-2XX HTTP status code. | |||
type RetrieveError struct { | |||
Response *http.Response | |||
// Body is the body that was consumed by reading Response.Body. | |||
// It may be truncated. | |||
Body []byte | |||
} | |||
func (r *RetrieveError) Error() string { | |||
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) | |||
} |
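
Exporting RetrieveError lets callers branch on the HTTP status instead of parsing an error string, e.g. to distinguish a rejected grant from a transient server failure (sketch; conf and code come from the surrounding flow):

tok, err := conf.Exchange(ctx, code)
if err != nil {
	if rErr, ok := err.(*oauth2.RetrieveError); ok {
		log.Printf("token endpoint returned %d: %s", rErr.Response.StatusCode, rErr.Body)
		// e.g. 400 usually carries invalid_grant per RFC 6749 §5.2
	}
	return err
}
_ = tok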
@@ -31,9 +31,17 @@ type Transport struct { | |||
} | |||
// RoundTrip authorizes and authenticates the request with an | |||
// access token. If no token exists or token is expired, | |||
// tries to refresh/fetch a new token. | |||
// access token from Transport's Source. | |||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { | |||
reqBodyClosed := false | |||
if req.Body != nil { | |||
defer func() { | |||
if !reqBodyClosed { | |||
req.Body.Close() | |||
} | |||
}() | |||
} | |||
if t.Source == nil { | |||
return nil, errors.New("oauth2: Transport's Source is nil") | |||
} | |||
@@ -46,6 +54,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { | |||
token.SetAuthHeader(req2) | |||
t.setModReq(req, req2) | |||
res, err := t.base().RoundTrip(req2) | |||
// req.Body is assumed to have been closed by the base RoundTripper. | |||
reqBodyClosed = true | |||
if err != nil { | |||
t.setModReq(req, nil) | |||
return nil, err |
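
The reqBodyClosed bookkeeping implements http.RoundTripper's contract that the transport must close req.Body even when it errors out; previously the early returns (nil Source, failed Token()) leaked the body. The guarantee in isolation:

// Sketch: close req.Body exactly once across early error paths.
func roundTrip(req *http.Request, inner http.RoundTripper) (*http.Response, error) {
	closed := false
	if req.Body != nil {
		defer func() {
			if !closed {
				req.Body.Close()
			}
		}()
	}
	// any early return before inner.RoundTrip still closes req.Body
	resp, err := inner.RoundTrip(req)
	closed = true // the base transport is responsible for the body from here
	return resp, err
}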