]> source.dussan.org Git - gitea.git/commitdiff
Request for public keys only if LDAP attribute is set (#5816)
authorLauris BH <lauris@nix.lv>
Wed, 23 Jan 2019 23:25:33 +0000 (01:25 +0200)
committerGitHub <noreply@github.com>
Wed, 23 Jan 2019 23:25:33 +0000 (01:25 +0200)
* Update go-ldap dependency

* Request for public keys only if attribute is set

13 files changed:
Gopkg.lock
modules/auth/ldap/ldap.go
vendor/gopkg.in/ldap.v2/LICENSE
vendor/gopkg.in/ldap.v2/atomic_value.go [new file with mode: 0644]
vendor/gopkg.in/ldap.v2/atomic_value_go13.go [new file with mode: 0644]
vendor/gopkg.in/ldap.v2/conn.go
vendor/gopkg.in/ldap.v2/control.go
vendor/gopkg.in/ldap.v2/debug.go
vendor/gopkg.in/ldap.v2/dn.go
vendor/gopkg.in/ldap.v2/error.go
vendor/gopkg.in/ldap.v2/filter.go
vendor/gopkg.in/ldap.v2/ldap.go
vendor/gopkg.in/ldap.v2/passwdmodify.go

index 5c2b54e3f9bf56ce905da93ac521a3800246f348..8f2a9d0ca7bdee3320debde580116792ca54bfdc 100644 (file)
   version = "v1.31.1"
 
 [[projects]]
-  digest = "1:01f4ac37c52bda6f7e1bd73680a99f88733c0408aaa159ecb1ba53a1ade9423c"
+  digest = "1:7e1c00b9959544fa1ccca7cf0407a5b29ac6d5201059c4fac6f599cb99bfd24d"
   name = "gopkg.in/ldap.v2"
   packages = ["."]
   pruneopts = "NUT"
-  revision = "d0a5ced67b4dc310b9158d63a2c6f9c5ec13f105"
-  version = "v2.4.1"
+  revision = "bb7a9ca6e4fbc2129e3db588a34bc970ffe811a9"
+  version = "v2.5.1"
 
 [[projects]]
   digest = "1:cfe1730a152ff033ad7d9c115d22e36b19eec6d5928c06146b9119be45d39dc0"
     "github.com/keybase/go-crypto/openpgp",
     "github.com/keybase/go-crypto/openpgp/armor",
     "github.com/keybase/go-crypto/openpgp/packet",
+    "github.com/klauspost/compress/gzip",
     "github.com/lafriks/xormstore",
     "github.com/lib/pq",
     "github.com/lunny/dingtalk_webhook",
index 010b4ea868dca4feb5bc87115dfbe541bea46e80..c68af25408057149cf87274a086be2777fe9a868 100644 (file)
@@ -247,11 +247,17 @@ func (ls *Source) SearchEntry(name, passwd string, directBind bool) *SearchResul
                return nil
        }
 
+       var isAttributeSSHPublicKeySet = len(strings.TrimSpace(ls.AttributeSSHPublicKey)) > 0
+
+       attribs := []string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail}
+       if isAttributeSSHPublicKeySet {
+               attribs = append(attribs, ls.AttributeSSHPublicKey)
+       }
+
        log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey, userFilter, userDN)
        search := ldap.NewSearchRequest(
                userDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter,
-               []string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey},
-               nil)
+               attribs, nil)
 
        sr, err := l.Search(search)
        if err != nil {
@@ -267,11 +273,15 @@ func (ls *Source) SearchEntry(name, passwd string, directBind bool) *SearchResul
                return nil
        }
 
+       var sshPublicKey []string
+
        username := sr.Entries[0].GetAttributeValue(ls.AttributeUsername)
        firstname := sr.Entries[0].GetAttributeValue(ls.AttributeName)
        surname := sr.Entries[0].GetAttributeValue(ls.AttributeSurname)
        mail := sr.Entries[0].GetAttributeValue(ls.AttributeMail)
-       sshPublicKey := sr.Entries[0].GetAttributeValues(ls.AttributeSSHPublicKey)
+       if isAttributeSSHPublicKeySet {
+               sshPublicKey = sr.Entries[0].GetAttributeValues(ls.AttributeSSHPublicKey)
+       }
        isAdmin := checkAdmin(l, ls, userDN)
 
        if !directBind && ls.AttributesInBind {
@@ -320,11 +330,17 @@ func (ls *Source) SearchEntries() []*SearchResult {
 
        userFilter := fmt.Sprintf(ls.Filter, "*")
 
+       var isAttributeSSHPublicKeySet = len(strings.TrimSpace(ls.AttributeSSHPublicKey)) > 0
+
+       attribs := []string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail}
+       if isAttributeSSHPublicKeySet {
+               attribs = append(attribs, ls.AttributeSSHPublicKey)
+       }
+
        log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey, userFilter, ls.UserBase)
        search := ldap.NewSearchRequest(
                ls.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter,
-               []string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey},
-               nil)
+               attribs, nil)
 
        var sr *ldap.SearchResult
        if ls.UsePagedSearch() {
@@ -341,12 +357,14 @@ func (ls *Source) SearchEntries() []*SearchResult {
 
        for i, v := range sr.Entries {
                result[i] = &SearchResult{
-                       Username:     v.GetAttributeValue(ls.AttributeUsername),
-                       Name:         v.GetAttributeValue(ls.AttributeName),
-                       Surname:      v.GetAttributeValue(ls.AttributeSurname),
-                       Mail:         v.GetAttributeValue(ls.AttributeMail),
-                       SSHPublicKey: v.GetAttributeValues(ls.AttributeSSHPublicKey),
-                       IsAdmin:      checkAdmin(l, ls, v.DN),
+                       Username: v.GetAttributeValue(ls.AttributeUsername),
+                       Name:     v.GetAttributeValue(ls.AttributeName),
+                       Surname:  v.GetAttributeValue(ls.AttributeSurname),
+                       Mail:     v.GetAttributeValue(ls.AttributeMail),
+                       IsAdmin:  checkAdmin(l, ls, v.DN),
+               }
+               if isAttributeSSHPublicKeySet {
+                       result[i].SSHPublicKey = v.GetAttributeValues(ls.AttributeSSHPublicKey)
                }
        }
 
index 74487567632c8f137ef3971b0f5912ca50bebcda..6c0ed4b3872714c1b344f11fd5ae16896c0f8b24 100644 (file)
@@ -1,27 +1,22 @@
-Copyright (c) 2012 The Go Authors. All rights reserved.
+The MIT License (MIT)
 
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are
-met:
+Copyright (c) 2011-2015 Michael Mitton (mmitton@gmail.com)
+Portions copyright (c) 2015-2016 go-ldap Authors
 
-   * Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-   * Redistributions in binary form must reproduce the above
-copyright notice, this list of conditions and the following disclaimer
-in the documentation and/or other materials provided with the
-distribution.
-   * Neither the name of Google Inc. nor the names of its
-contributors may be used to endorse or promote products derived from
-this software without specific prior written permission.
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
 
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/vendor/gopkg.in/ldap.v2/atomic_value.go b/vendor/gopkg.in/ldap.v2/atomic_value.go
new file mode 100644 (file)
index 0000000..bccf757
--- /dev/null
@@ -0,0 +1,13 @@
+// +build go1.4
+
+package ldap
+
+import (
+       "sync/atomic"
+)
+
+// For compilers that support it, we just use the underlying sync/atomic.Value
+// type.
+type atomicValue struct {
+       atomic.Value
+}
diff --git a/vendor/gopkg.in/ldap.v2/atomic_value_go13.go b/vendor/gopkg.in/ldap.v2/atomic_value_go13.go
new file mode 100644 (file)
index 0000000..04920bb
--- /dev/null
@@ -0,0 +1,28 @@
+// +build !go1.4
+
+package ldap
+
+import (
+       "sync"
+)
+
+// This is a helper type that emulates the use of the "sync/atomic.Value"
+// struct that's available in Go 1.4 and up.
+type atomicValue struct {
+       value interface{}
+       lock  sync.RWMutex
+}
+
+func (av *atomicValue) Store(val interface{}) {
+       av.lock.Lock()
+       av.value = val
+       av.lock.Unlock()
+}
+
+func (av *atomicValue) Load() interface{} {
+       av.lock.RLock()
+       ret := av.value
+       av.lock.RUnlock()
+
+       return ret
+}
index b5bd99adb5e2a0ce53f29d795d4d629a6358fd65..eb28eb4726aa56736dcc9f772ac984dfc71b6a39 100644 (file)
@@ -11,6 +11,7 @@ import (
        "log"
        "net"
        "sync"
+       "sync/atomic"
        "time"
 
        "gopkg.in/asn1-ber.v1"
@@ -82,20 +83,18 @@ const (
 type Conn struct {
        conn                net.Conn
        isTLS               bool
-       isClosing           bool
-       closeErr            error
+       closing             uint32
+       closeErr            atomicValue
        isStartingTLS       bool
        Debug               debugging
-       chanConfirm         chan bool
+       chanConfirm         chan struct{}
        messageContexts     map[int64]*messageContext
        chanMessage         chan *messagePacket
        chanMessageID       chan int64
-       wgSender            sync.WaitGroup
        wgClose             sync.WaitGroup
-       once                sync.Once
        outstandingRequests uint
        messageMutex        sync.Mutex
-       requestTimeout      time.Duration
+       requestTimeout      int64
 }
 
 var _ Client = &Conn{}
@@ -142,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
 func NewConn(conn net.Conn, isTLS bool) *Conn {
        return &Conn{
                conn:            conn,
-               chanConfirm:     make(chan bool),
+               chanConfirm:     make(chan struct{}),
                chanMessageID:   make(chan int64),
                chanMessage:     make(chan *messagePacket, 10),
                messageContexts: map[int64]*messageContext{},
@@ -158,12 +157,22 @@ func (l *Conn) Start() {
        l.wgClose.Add(1)
 }
 
+// isClosing returns whether or not we're currently closing.
+func (l *Conn) isClosing() bool {
+       return atomic.LoadUint32(&l.closing) == 1
+}
+
+// setClosing sets the closing value to true
+func (l *Conn) setClosing() bool {
+       return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
+}
+
 // Close closes the connection.
 func (l *Conn) Close() {
-       l.once.Do(func() {
-               l.isClosing = true
-               l.wgSender.Wait()
+       l.messageMutex.Lock()
+       defer l.messageMutex.Unlock()
 
+       if l.setClosing() {
                l.Debug.Printf("Sending quit message and waiting for confirmation")
                l.chanMessage <- &messagePacket{Op: MessageQuit}
                <-l.chanConfirm
@@ -171,27 +180,25 @@ func (l *Conn) Close() {
 
                l.Debug.Printf("Closing network connection")
                if err := l.conn.Close(); err != nil {
-                       log.Print(err)
+                       log.Println(err)
                }
 
                l.wgClose.Done()
-       })
+       }
        l.wgClose.Wait()
 }
 
 // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
 func (l *Conn) SetTimeout(timeout time.Duration) {
        if timeout > 0 {
-               l.requestTimeout = timeout
+               atomic.StoreInt64(&l.requestTimeout, int64(timeout))
        }
 }
 
 // Returns the next available messageID
 func (l *Conn) nextMessageID() int64 {
-       if l.chanMessageID != nil {
-               if messageID, ok := <-l.chanMessageID; ok {
-                       return messageID
-               }
+       if messageID, ok := <-l.chanMessageID; ok {
+               return messageID
        }
        return 0
 }
@@ -258,7 +265,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
 }
 
 func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
-       if l.isClosing {
+       if l.isClosing() {
                return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
        }
        l.messageMutex.Lock()
@@ -297,7 +304,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
 func (l *Conn) finishMessage(msgCtx *messageContext) {
        close(msgCtx.done)
 
-       if l.isClosing {
+       if l.isClosing() {
                return
        }
 
@@ -316,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
 }
 
 func (l *Conn) sendProcessMessage(message *messagePacket) bool {
-       if l.isClosing {
+       l.messageMutex.Lock()
+       defer l.messageMutex.Unlock()
+       if l.isClosing() {
                return false
        }
-       l.wgSender.Add(1)
        l.chanMessage <- message
-       l.wgSender.Done()
        return true
 }
 
@@ -333,15 +340,14 @@ func (l *Conn) processMessages() {
                for messageID, msgCtx := range l.messageContexts {
                        // If we are closing due to an error, inform anyone who
                        // is waiting about the error.
-                       if l.isClosing && l.closeErr != nil {
-                               msgCtx.sendResponse(&PacketResponse{Error: l.closeErr})
+                       if l.isClosing() && l.closeErr.Load() != nil {
+                               msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
                        }
                        l.Debug.Printf("Closing channel for MessageID %d", messageID)
                        close(msgCtx.responses)
                        delete(l.messageContexts, messageID)
                }
                close(l.chanMessageID)
-               l.chanConfirm <- true
                close(l.chanConfirm)
        }()
 
@@ -350,11 +356,7 @@ func (l *Conn) processMessages() {
                select {
                case l.chanMessageID <- messageID:
                        messageID++
-               case message, ok := <-l.chanMessage:
-                       if !ok {
-                               l.Debug.Printf("Shutting down - message channel is closed")
-                               return
-                       }
+               case message := <-l.chanMessage:
                        switch message.Op {
                        case MessageQuit:
                                l.Debug.Printf("Shutting down - quit message received")
@@ -377,14 +379,15 @@ func (l *Conn) processMessages() {
                                l.messageContexts[message.MessageID] = message.Context
 
                                // Add timeout if defined
-                               if l.requestTimeout > 0 {
+                               requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
+                               if requestTimeout > 0 {
                                        go func() {
                                                defer func() {
                                                        if err := recover(); err != nil {
                                                                log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
                                                        }
                                                }()
-                                               time.Sleep(l.requestTimeout)
+                                               time.Sleep(requestTimeout)
                                                timeoutMessage := &messagePacket{
                                                        Op:        MessageTimeout,
                                                        MessageID: message.MessageID,
@@ -397,7 +400,7 @@ func (l *Conn) processMessages() {
                                if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
                                        msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
                                } else {
-                                       log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing)
+                                       log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing())
                                        ber.PrintPacket(message.Packet)
                                }
                        case MessageTimeout:
@@ -439,8 +442,8 @@ func (l *Conn) reader() {
                packet, err := ber.ReadPacket(l.conn)
                if err != nil {
                        // A read error is expected here if we are closing the connection...
-                       if !l.isClosing {
-                               l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err)
+                       if !l.isClosing() {
+                               l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
                                l.Debug.Printf("reader error: %s", err.Error())
                        }
                        return
index 5c62118d46e03a2c76c1a7b700f9960ad5f63645..342f325ca61445fbc7c61b4047c6af029b27d8e5 100644 (file)
@@ -334,18 +334,18 @@ func DecodeControl(packet *ber.Packet) Control {
                for _, child := range sequence.Children {
                        if child.Tag == 0 {
                                //Warning
-                               child := child.Children[0]
-                               packet := ber.DecodePacket(child.Data.Bytes())
+                               warningPacket := child.Children[0]
+                               packet := ber.DecodePacket(warningPacket.Data.Bytes())
                                val, ok := packet.Value.(int64)
                                if ok {
-                                       if child.Tag == 0 {
+                                       if warningPacket.Tag == 0 {
                                                //timeBeforeExpiration
                                                c.Expire = val
-                                               child.Value = c.Expire
-                                       } else if child.Tag == 1 {
+                                               warningPacket.Value = c.Expire
+                                       } else if warningPacket.Tag == 1 {
                                                //graceAuthNsRemaining
                                                c.Grace = val
-                                               child.Value = c.Grace
+                                               warningPacket.Value = c.Grace
                                        }
                                }
                        } else if child.Tag == 1 {
index b8a7ecbff1f2815045ba0cc2aeec33361d1fdd67..7279fc2518229ef7c8b1b82f03adc4e0b659bf91 100644 (file)
@@ -6,7 +6,7 @@ import (
        "gopkg.in/asn1-ber.v1"
 )
 
-// debbuging type
+// debugging type
 //     - has a Printf method to write the debug output
 type debugging bool
 
index cc70c894c206551a739bf823c71f3c05946caff5..34e9023af936e86781300806cf5975677b231817 100644 (file)
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 //
-// File contains DN parsing functionallity
+// File contains DN parsing functionality
 //
 // https://tools.ietf.org/html/rfc4514
 //
@@ -52,7 +52,7 @@ import (
        "fmt"
        "strings"
 
-       ber "gopkg.in/asn1-ber.v1"
+       "gopkg.in/asn1-ber.v1"
 )
 
 // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
@@ -83,9 +83,19 @@ func ParseDN(str string) (*DN, error) {
        attribute := new(AttributeTypeAndValue)
        escaping := false
 
+       unescapedTrailingSpaces := 0
+       stringFromBuffer := func() string {
+               s := buffer.String()
+               s = s[0 : len(s)-unescapedTrailingSpaces]
+               buffer.Reset()
+               unescapedTrailingSpaces = 0
+               return s
+       }
+
        for i := 0; i < len(str); i++ {
                char := str[i]
                if escaping {
+                       unescapedTrailingSpaces = 0
                        escaping = false
                        switch char {
                        case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
@@ -107,10 +117,10 @@ func ParseDN(str string) (*DN, error) {
                        buffer.WriteByte(dst[0])
                        i++
                } else if char == '\\' {
+                       unescapedTrailingSpaces = 0
                        escaping = true
                } else if char == '=' {
-                       attribute.Type = buffer.String()
-                       buffer.Reset()
+                       attribute.Type = stringFromBuffer()
                        // Special case: If the first character in the value is # the
                        // following data is BER encoded so we can just fast forward
                        // and decode.
@@ -133,7 +143,10 @@ func ParseDN(str string) (*DN, error) {
                        }
                } else if char == ',' || char == '+' {
                        // We're done with this RDN or value, push it
-                       attribute.Value = buffer.String()
+                       if len(attribute.Type) == 0 {
+                               return nil, errors.New("incomplete type, value pair")
+                       }
+                       attribute.Value = stringFromBuffer()
                        rdn.Attributes = append(rdn.Attributes, attribute)
                        attribute = new(AttributeTypeAndValue)
                        if char == ',' {
@@ -141,8 +154,17 @@ func ParseDN(str string) (*DN, error) {
                                rdn = new(RelativeDN)
                                rdn.Attributes = make([]*AttributeTypeAndValue, 0)
                        }
-                       buffer.Reset()
+               } else if char == ' ' && buffer.Len() == 0 {
+                       // ignore unescaped leading spaces
+                       continue
                } else {
+                       if char == ' ' {
+                               // Track unescaped spaces in case they are trailing and we need to remove them
+                               unescapedTrailingSpaces++
+                       } else {
+                               // Reset if we see a non-space char
+                               unescapedTrailingSpaces = 0
+                       }
                        buffer.WriteByte(char)
                }
        }
@@ -150,9 +172,76 @@ func ParseDN(str string) (*DN, error) {
                if len(attribute.Type) == 0 {
                        return nil, errors.New("DN ended with incomplete type, value pair")
                }
-               attribute.Value = buffer.String()
+               attribute.Value = stringFromBuffer()
                rdn.Attributes = append(rdn.Attributes, attribute)
                dn.RDNs = append(dn.RDNs, rdn)
        }
        return dn, nil
 }
+
+// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
+// Returns true if they have the same number of relative distinguished names
+// and corresponding relative distinguished names (by position) are the same.
+func (d *DN) Equal(other *DN) bool {
+       if len(d.RDNs) != len(other.RDNs) {
+               return false
+       }
+       for i := range d.RDNs {
+               if !d.RDNs[i].Equal(other.RDNs[i]) {
+                       return false
+               }
+       }
+       return true
+}
+
+// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
+// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com"
+// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com"
+// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com"
+func (d *DN) AncestorOf(other *DN) bool {
+       if len(d.RDNs) >= len(other.RDNs) {
+               return false
+       }
+       // Take the last `len(d.RDNs)` RDNs from the other DN to compare against
+       otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
+       for i := range d.RDNs {
+               if !d.RDNs[i].Equal(otherRDNs[i]) {
+                       return false
+               }
+       }
+       return true
+}
+
+// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
+// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues
+// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type.
+// The order of attributes is not significant.
+// Case of attribute types is not significant.
+func (r *RelativeDN) Equal(other *RelativeDN) bool {
+       if len(r.Attributes) != len(other.Attributes) {
+               return false
+       }
+       return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes)
+}
+
+func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
+       for _, attr := range attrs {
+               found := false
+               for _, myattr := range r.Attributes {
+                       if myattr.Equal(attr) {
+                               found = true
+                               break
+                       }
+               }
+               if !found {
+                       return false
+               }
+       }
+       return true
+}
+
+// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
+// Case of the attribute type is not significant
+func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
+       return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
+}
index ff697873ddd12ead9cc59740aa2814c4cfa5e68f..4cccb537fdd87fd6919f2ac139164fc613df896a 100644 (file)
@@ -97,6 +97,13 @@ var LDAPResultCodeMap = map[uint8]string{
        LDAPResultObjectClassModsProhibited:    "Object Class Mods Prohibited",
        LDAPResultAffectsMultipleDSAs:          "Affects Multiple DSAs",
        LDAPResultOther:                        "Other",
+
+       ErrorNetwork:            "Network Error",
+       ErrorFilterCompile:      "Filter Compile Error",
+       ErrorFilterDecompile:    "Filter Decompile Error",
+       ErrorDebugging:          "Debugging Error",
+       ErrorUnexpectedMessage:  "Unexpected Message",
+       ErrorUnexpectedResponse: "Unexpected Response",
 }
 
 func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
index 7eae310f1677808cf24380496b0e0317bbab1958..3858a2865c01634efe438cb4373cf74508b7e7a4 100644 (file)
@@ -82,7 +82,10 @@ func CompileFilter(filter string) (*ber.Packet, error) {
        if err != nil {
                return nil, err
        }
-       if pos != len(filter) {
+       switch {
+       case pos > len(filter):
+               return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
+       case pos < len(filter):
                return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
        }
        return packet, nil
index 90018be83f5d271aec5e17ccc7ac58a5ddffb573..496924756976ea6a2592b31b8c8f73e054a438d5 100644 (file)
@@ -9,7 +9,7 @@ import (
        "io/ioutil"
        "os"
 
-       ber "gopkg.in/asn1-ber.v1"
+       "gopkg.in/asn1-ber.v1"
 )
 
 // LDAP Application Codes
@@ -153,16 +153,47 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) {
 func addControlDescriptions(packet *ber.Packet) {
        packet.Description = "Controls"
        for _, child := range packet.Children {
+               var value *ber.Packet
+               controlType := ""
                child.Description = "Control"
-               child.Children[0].Description = "Control Type (" + ControlTypeMap[child.Children[0].Value.(string)] + ")"
-               value := child.Children[1]
-               if len(child.Children) == 3 {
+               switch len(child.Children) {
+               case 0:
+                       // at least one child is required for control type
+                       continue
+
+               case 1:
+                       // just type, no criticality or value
+                       controlType = child.Children[0].Value.(string)
+                       child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
+
+               case 2:
+                       controlType = child.Children[0].Value.(string)
+                       child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
+                       // Children[1] could be criticality or value (both are optional)
+                       // duck-type on whether this is a boolean
+                       if _, ok := child.Children[1].Value.(bool); ok {
+                               child.Children[1].Description = "Criticality"
+                       } else {
+                               child.Children[1].Description = "Control Value"
+                               value = child.Children[1]
+                       }
+
+               case 3:
+                       // criticality and value present
+                       controlType = child.Children[0].Value.(string)
+                       child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
                        child.Children[1].Description = "Criticality"
+                       child.Children[2].Description = "Control Value"
                        value = child.Children[2]
-               }
-               value.Description = "Control Value"
 
-               switch child.Children[0].Value.(string) {
+               default:
+                       // more than 3 children is invalid
+                       continue
+               }
+               if value == nil {
+                       continue
+               }
+               switch controlType {
                case ControlTypePaging:
                        value.Description += " (Paging)"
                        if value.Value != nil {
@@ -188,18 +219,18 @@ func addControlDescriptions(packet *ber.Packet) {
                        for _, child := range sequence.Children {
                                if child.Tag == 0 {
                                        //Warning
-                                       child := child.Children[0]
-                                       packet := ber.DecodePacket(child.Data.Bytes())
+                                       warningPacket := child.Children[0]
+                                       packet := ber.DecodePacket(warningPacket.Data.Bytes())
                                        val, ok := packet.Value.(int64)
                                        if ok {
-                                               if child.Tag == 0 {
+                                               if warningPacket.Tag == 0 {
                                                        //timeBeforeExpiration
                                                        value.Description += " (TimeBeforeExpiration)"
-                                                       child.Value = val
-                                               } else if child.Tag == 1 {
+                                                       warningPacket.Value = val
+                                               } else if warningPacket.Tag == 1 {
                                                        //graceAuthNsRemaining
                                                        value.Description += " (GraceAuthNsRemaining)"
-                                                       child.Value = val
+                                                       warningPacket.Value = val
                                                }
                                        }
                                } else if child.Tag == 1 {
index 26110ccf4a5d832f8687cdf84a31f56d6585a914..7d8246fd1895f3356bf49d5744311afd754fd042 100644 (file)
@@ -135,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
        extendedResponse := packet.Children[1]
        for _, child := range extendedResponse.Children {
                if child.Tag == 11 {
-                       passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes())
-                       if len(passwordModifyReponseValue.Children) == 1 {
-                               if passwordModifyReponseValue.Children[0].Tag == 0 {
-                                       result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes())
+                       passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
+                       if len(passwordModifyResponseValue.Children) == 1 {
+                               if passwordModifyResponseValue.Children[0].Tag == 0 {
+                                       result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
                                }
                        }
                }