diff options
Diffstat (limited to 'vendor/gopkg.in/ldap.v2/conn.go')
-rw-r--r-- | vendor/gopkg.in/ldap.v2/conn.go | 73 |
1 files changed, 38 insertions, 35 deletions
diff --git a/vendor/gopkg.in/ldap.v2/conn.go b/vendor/gopkg.in/ldap.v2/conn.go index b5bd99adb5..eb28eb4726 100644 --- a/vendor/gopkg.in/ldap.v2/conn.go +++ b/vendor/gopkg.in/ldap.v2/conn.go @@ -11,6 +11,7 @@ import ( "log" "net" "sync" + "sync/atomic" "time" "gopkg.in/asn1-ber.v1" @@ -82,20 +83,18 @@ const ( type Conn struct { conn net.Conn isTLS bool - isClosing bool - closeErr error + closing uint32 + closeErr atomicValue isStartingTLS bool Debug debugging - chanConfirm chan bool + chanConfirm chan struct{} messageContexts map[int64]*messageContext chanMessage chan *messagePacket chanMessageID chan int64 - wgSender sync.WaitGroup wgClose sync.WaitGroup - once sync.Once outstandingRequests uint messageMutex sync.Mutex - requestTimeout time.Duration + requestTimeout int64 } var _ Client = &Conn{} @@ -142,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { func NewConn(conn net.Conn, isTLS bool) *Conn { return &Conn{ conn: conn, - chanConfirm: make(chan bool), + chanConfirm: make(chan struct{}), chanMessageID: make(chan int64), chanMessage: make(chan *messagePacket, 10), messageContexts: map[int64]*messageContext{}, @@ -158,12 +157,22 @@ func (l *Conn) Start() { l.wgClose.Add(1) } +// isClosing returns whether or not we're currently closing. +func (l *Conn) isClosing() bool { + return atomic.LoadUint32(&l.closing) == 1 +} + +// setClosing sets the closing value to true +func (l *Conn) setClosing() bool { + return atomic.CompareAndSwapUint32(&l.closing, 0, 1) +} + // Close closes the connection. func (l *Conn) Close() { - l.once.Do(func() { - l.isClosing = true - l.wgSender.Wait() + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + if l.setClosing() { l.Debug.Printf("Sending quit message and waiting for confirmation") l.chanMessage <- &messagePacket{Op: MessageQuit} <-l.chanConfirm @@ -171,27 +180,25 @@ func (l *Conn) Close() { l.Debug.Printf("Closing network connection") if err := l.conn.Close(); err != nil { - log.Print(err) + log.Println(err) } l.wgClose.Done() - }) + } l.wgClose.Wait() } // SetTimeout sets the time after a request is sent that a MessageTimeout triggers func (l *Conn) SetTimeout(timeout time.Duration) { if timeout > 0 { - l.requestTimeout = timeout + atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } } // Returns the next available messageID func (l *Conn) nextMessageID() int64 { - if l.chanMessageID != nil { - if messageID, ok := <-l.chanMessageID; ok { - return messageID - } + if messageID, ok := <-l.chanMessageID; ok { + return messageID } return 0 } @@ -258,7 +265,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { } func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { - if l.isClosing { + if l.isClosing() { return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) } l.messageMutex.Lock() @@ -297,7 +304,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) func (l *Conn) finishMessage(msgCtx *messageContext) { close(msgCtx.done) - if l.isClosing { + if l.isClosing() { return } @@ -316,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) { } func (l *Conn) sendProcessMessage(message *messagePacket) bool { - if l.isClosing { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + if l.isClosing() { return false } - l.wgSender.Add(1) l.chanMessage <- message - l.wgSender.Done() return true } @@ -333,15 +340,14 @@ func (l *Conn) processMessages() { for messageID, msgCtx := range l.messageContexts { // If we are closing due to an error, inform anyone who // is waiting about the error. - if l.isClosing && l.closeErr != nil { - msgCtx.sendResponse(&PacketResponse{Error: l.closeErr}) + if l.isClosing() && l.closeErr.Load() != nil { + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}) } l.Debug.Printf("Closing channel for MessageID %d", messageID) close(msgCtx.responses) delete(l.messageContexts, messageID) } close(l.chanMessageID) - l.chanConfirm <- true close(l.chanConfirm) }() @@ -350,11 +356,7 @@ func (l *Conn) processMessages() { select { case l.chanMessageID <- messageID: messageID++ - case message, ok := <-l.chanMessage: - if !ok { - l.Debug.Printf("Shutting down - message channel is closed") - return - } + case message := <-l.chanMessage: switch message.Op { case MessageQuit: l.Debug.Printf("Shutting down - quit message received") @@ -377,14 +379,15 @@ func (l *Conn) processMessages() { l.messageContexts[message.MessageID] = message.Context // Add timeout if defined - if l.requestTimeout > 0 { + requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) + if requestTimeout > 0 { go func() { defer func() { if err := recover(); err != nil { log.Printf("ldap: recovered panic in RequestTimeout: %v", err) } }() - time.Sleep(l.requestTimeout) + time.Sleep(requestTimeout) timeoutMessage := &messagePacket{ Op: MessageTimeout, MessageID: message.MessageID, @@ -397,7 +400,7 @@ func (l *Conn) processMessages() { if msgCtx, ok := l.messageContexts[message.MessageID]; ok { msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) } else { - log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing) + log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing()) ber.PrintPacket(message.Packet) } case MessageTimeout: @@ -439,8 +442,8 @@ func (l *Conn) reader() { packet, err := ber.ReadPacket(l.conn) if err != nil { // A read error is expected here if we are closing the connection... - if !l.isClosing { - l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err) + if !l.isClosing() { + l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) l.Debug.Printf("reader error: %s", err.Error()) } return |