]> source.dussan.org Git - gitea.git/commitdiff
upgrade go-sql-driver/mysql to fix invalid connection error (#5748)
authorLunny Xiao <xiaolunwen@gmail.com>
Thu, 17 Jan 2019 06:07:23 +0000 (14:07 +0800)
committerzeripath <art27@cantab.net>
Thu, 17 Jan 2019 06:07:23 +0000 (06:07 +0000)
should fix #5736

13 files changed:
Gopkg.lock
Gopkg.toml
vendor/github.com/go-sql-driver/mysql/AUTHORS
vendor/github.com/go-sql-driver/mysql/auth.go
vendor/github.com/go-sql-driver/mysql/buffer.go
vendor/github.com/go-sql-driver/mysql/connection.go
vendor/github.com/go-sql-driver/mysql/connection_go18.go [deleted file]
vendor/github.com/go-sql-driver/mysql/driver.go
vendor/github.com/go-sql-driver/mysql/dsn.go
vendor/github.com/go-sql-driver/mysql/packets.go
vendor/github.com/go-sql-driver/mysql/utils.go
vendor/github.com/go-sql-driver/mysql/utils_go17.go [deleted file]
vendor/github.com/go-sql-driver/mysql/utils_go18.go [deleted file]

index 17e4397b15f68e064fb872e8faa8fe355f7f0677..5c2b54e3f9bf56ce905da93ac521a3800246f348 100644 (file)
   revision = "a77f45a7ce909c0ff14b28279fa1a2b674acb70f"
 
 [[projects]]
-  digest = "1:747c1fcb10f8f6734551465ab73c6ed9c551aa6e66250fb6683d1624f554546a"
+  digest = "1:dce58f88343bd78f4d32dd9601aab4fa5d9994fd2cafa185c51bbd858851cdf9"
   name = "github.com/go-sql-driver/mysql"
   packages = ["."]
   pruneopts = "NUT"
-  revision = "d523deb1b23d913de5bdada721a6071e71283618"
+  revision = "c45f530f8e7fe40f4687eaa50d0c8c5f1b66f9e0"
 
 [[projects]]
   digest = "1:06d21295033f211588d0ad7ff391cc1b27e72b60cb6d4b7db0d70cffae4cf228"
index 2eb81803a17eae8913fd1a97d78aa85c37ad8648..51f2b2cabe74c116533c77824d997deace568e4b 100644 (file)
@@ -46,7 +46,7 @@ ignored = ["google.golang.org/appengine*"]
 
 [[override]]
   name = "github.com/go-sql-driver/mysql"
-  revision = "d523deb1b23d913de5bdada721a6071e71283618"
+  revision = "c45f530f8e7fe40f4687eaa50d0c8c5f1b66f9e0"
 
 [[override]]
   name = "github.com/mattn/go-sqlite3"
index 73ff68fbcf2233d7142e76981fc3d1621c33f7f7..5ce4f7eca1d1baa232ec91f16b6073b1a2465496 100644 (file)
@@ -35,6 +35,7 @@ Hanno Braun <mail at hannobraun.com>
 Henri Yandell <flamefew at gmail.com>
 Hirotaka Yamamoto <ymmt2005 at gmail.com>
 ICHINOSE Shogo <shogo82148 at gmail.com>
+Ilia Cimpoes <ichimpoesh at gmail.com>
 INADA Naoki <songofacandy at gmail.com>
 Jacek Szwec <szwec.jacek at gmail.com>
 James Harr <james.harr at gmail.com>
@@ -72,6 +73,9 @@ Shuode Li <elemount at qq.com>
 Soroush Pour <me at soroushjp.com>
 Stan Putrya <root.vagner at gmail.com>
 Stanley Gunawan <gunawan.stanley at gmail.com>
+Steven Hartland <steven.hartland at multiplay.co.uk>
+Thomas Wodarek <wodarekwebpage at gmail.com>
+Tom Jenkinson <tom at tjenkinson.me>
 Xiangyu Hu <xiangyu.hu at outlook.com>
 Xiaobing Jiang <s7v7nislands at gmail.com>
 Xiuming Chen <cc at cxm.cc>
@@ -87,3 +91,4 @@ Keybase Inc.
 Percona LLC
 Pivotal Inc.
 Stripe Inc.
+Multiplay Ltd.
index 0b59f52ee7e41f1299bd02bc5f0cd7215b247d54..fec7040d4a2d2fb7d0f91555763f07ebc8c384ae 100644 (file)
@@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
        if err != nil {
                return err
        }
-       return mc.writeAuthSwitchPacket(enc, false)
+       return mc.writeAuthSwitchPacket(enc)
 }
 
-func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
+func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
        switch plugin {
        case "caching_sha2_password":
                authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
-               return authResp, (authResp == nil), nil
+               return authResp, nil
 
        case "mysql_old_password":
                if !mc.cfg.AllowOldPasswords {
-                       return nil, false, ErrOldPassword
+                       return nil, ErrOldPassword
                }
                // Note: there are edge cases where this should work but doesn't;
                // this is currently "wontfix":
                // https://github.com/go-sql-driver/mysql/issues/184
-               authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
-               return authResp, true, nil
+               authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
+               return authResp, nil
 
        case "mysql_clear_password":
                if !mc.cfg.AllowCleartextPasswords {
-                       return nil, false, ErrCleartextPassword
+                       return nil, ErrCleartextPassword
                }
                // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
                // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
-               return []byte(mc.cfg.Passwd), true, nil
+               return append([]byte(mc.cfg.Passwd), 0), nil
 
        case "mysql_native_password":
                if !mc.cfg.AllowNativePasswords {
-                       return nil, false, ErrNativePassword
+                       return nil, ErrNativePassword
                }
                // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
                // Native password authentication only need and will need 20-byte challenge.
                authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
-               return authResp, false, nil
+               return authResp, nil
 
        case "sha256_password":
                if len(mc.cfg.Passwd) == 0 {
-                       return nil, true, nil
+                       return []byte{0}, nil
                }
                if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
                        // write cleartext auth packet
-                       return []byte(mc.cfg.Passwd), true, nil
+                       return append([]byte(mc.cfg.Passwd), 0), nil
                }
 
                pubKey := mc.cfg.pubKey
                if pubKey == nil {
                        // request public key from server
-                       return []byte{1}, false, nil
+                       return []byte{1}, nil
                }
 
                // encrypted password
                enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
-               return enc, false, err
+               return enc, err
 
        default:
                errLog.Print("unknown auth plugin:", plugin)
-               return nil, false, ErrUnknownPlugin
+               return nil, ErrUnknownPlugin
        }
 }
 
@@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
 
                plugin = newPlugin
 
-               authResp, addNUL, err := mc.auth(authData, plugin)
+               authResp, err := mc.auth(authData, plugin)
                if err != nil {
                        return err
                }
-               if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
+               if err = mc.writeAuthSwitchPacket(authResp); err != nil {
                        return err
                }
 
@@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
                        case cachingSha2PasswordPerformFullAuthentication:
                                if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
                                        // write cleartext auth packet
-                                       err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
+                                       err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
                                        if err != nil {
                                                return err
                                        }
@@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
                                        pubKey := mc.cfg.pubKey
                                        if pubKey == nil {
                                                // request public key from server
-                                               data := mc.buf.takeSmallBuffer(4 + 1)
+                                               data, err := mc.buf.takeSmallBuffer(4 + 1)
+                                               if err != nil {
+                                                       return err
+                                               }
                                                data[4] = cachingSha2PasswordRequestPublicKey
                                                mc.writePacket(data)
 
                                                // parse public key
-                                               data, err := mc.readPacket()
-                                               if err != nil {
+                                               if data, err = mc.readPacket(); err != nil {
                                                        return err
                                                }
 
index eb4748bf448d65ee672198b651a4cc149ba8701d..19486bd6f6d59add73db202580efd618e2e90f0b 100644 (file)
@@ -22,17 +22,17 @@ const defaultBufSize = 4096
 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
 type buffer struct {
-       buf     []byte
+       buf     []byte // buf is a byte buffer who's length and capacity are equal.
        nc      net.Conn
        idx     int
        length  int
        timeout time.Duration
 }
 
+// newBuffer allocates and returns a new buffer.
 func newBuffer(nc net.Conn) buffer {
-       var b [defaultBufSize]byte
        return buffer{
-               buf: b[:],
+               buf: make([]byte, defaultBufSize),
                nc:  nc,
        }
 }
@@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
        return b.buf[offset:b.idx], nil
 }
 
-// returns a buffer with the requested size.
+// takeBuffer returns a buffer with the requested size.
 // If possible, a slice from the existing buffer is returned.
 // Otherwise a bigger buffer is made.
 // Only one buffer (total) can be used at a time.
-func (b *buffer) takeBuffer(length int) []byte {
+func (b *buffer) takeBuffer(length int) ([]byte, error) {
        if b.length > 0 {
-               return nil
+               return nil, ErrBusyBuffer
        }
 
        // test (cheap) general case first
-       if length <= defaultBufSize || length <= cap(b.buf) {
-               return b.buf[:length]
+       if length <= cap(b.buf) {
+               return b.buf[:length], nil
        }
 
        if length < maxPacketSize {
                b.buf = make([]byte, length)
-               return b.buf
+               return b.buf, nil
        }
-       return make([]byte, length)
+
+       // buffer is larger than we want to store.
+       return make([]byte, length), nil
 }
 
-// shortcut which can be used if the requested buffer is guaranteed to be
-// smaller than defaultBufSize
+// takeSmallBuffer is shortcut which can be used if length is
+// known to be smaller than defaultBufSize.
 // Only one buffer (total) can be used at a time.
-func (b *buffer) takeSmallBuffer(length int) []byte {
+func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
        if b.length > 0 {
-               return nil
+               return nil, ErrBusyBuffer
        }
-       return b.buf[:length]
+       return b.buf[:length], nil
 }
 
 // takeCompleteBuffer returns the complete existing buffer.
 // This can be used if the necessary buffer size is unknown.
+// cap and len of the returned buffer will be equal.
 // Only one buffer (total) can be used at a time.
-func (b *buffer) takeCompleteBuffer() []byte {
+func (b *buffer) takeCompleteBuffer() ([]byte, error) {
+       if b.length > 0 {
+               return nil, ErrBusyBuffer
+       }
+       return b.buf, nil
+}
+
+// store stores buf, an updated buffer, if its suitable to do so.
+func (b *buffer) store(buf []byte) error {
        if b.length > 0 {
-               return nil
+               return ErrBusyBuffer
+       } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
+               b.buf = buf[:cap(buf)]
        }
-       return b.buf
+       return nil
 }
index e57061412bc2d12afb4fc073e89fab41158cdb52..fc4ec7597d90c4ba9174b2612db62db2f7b223c0 100644 (file)
@@ -9,6 +9,8 @@
 package mysql
 
 import (
+       "context"
+       "database/sql"
        "database/sql/driver"
        "io"
        "net"
@@ -17,16 +19,6 @@ import (
        "time"
 )
 
-// a copy of context.Context for Go 1.7 and earlier
-type mysqlContext interface {
-       Done() <-chan struct{}
-       Err() error
-
-       // defined in context.Context, but not used in this driver:
-       // Deadline() (deadline time.Time, ok bool)
-       // Value(key interface{}) interface{}
-}
-
 type mysqlConn struct {
        buf              buffer
        netConn          net.Conn
@@ -43,7 +35,7 @@ type mysqlConn struct {
 
        // for context support (Go 1.8+)
        watching bool
-       watcher  chan<- mysqlContext
+       watcher  chan<- context.Context
        closech  chan struct{}
        finished chan<- struct{}
        canceled atomicError // set non-nil if conn is canceled
@@ -190,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
                return "", driver.ErrSkip
        }
 
-       buf := mc.buf.takeCompleteBuffer()
-       if buf == nil {
+       buf, err := mc.buf.takeCompleteBuffer()
+       if err != nil {
                // can not take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return "", ErrInvalidConn
        }
        buf = buf[:0]
@@ -459,3 +451,193 @@ func (mc *mysqlConn) finish() {
        case <-mc.closech:
        }
 }
+
+// Ping implements driver.Pinger interface
+func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
+       if mc.closed.IsSet() {
+               errLog.Print(ErrInvalidConn)
+               return driver.ErrBadConn
+       }
+
+       if err = mc.watchCancel(ctx); err != nil {
+               return
+       }
+       defer mc.finish()
+
+       if err = mc.writeCommandPacket(comPing); err != nil {
+               return mc.markBadConn(err)
+       }
+
+       return mc.readResultOK()
+}
+
+// BeginTx implements driver.ConnBeginTx interface
+func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+       defer mc.finish()
+
+       if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
+               level, err := mapIsolationLevel(opts.Isolation)
+               if err != nil {
+                       return nil, err
+               }
+               err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       return mc.begin(opts.ReadOnly)
+}
+
+func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+
+       rows, err := mc.query(query, dargs)
+       if err != nil {
+               mc.finish()
+               return nil, err
+       }
+       rows.finish = mc.finish
+       return rows, err
+}
+
+func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+       defer mc.finish()
+
+       return mc.Exec(query, dargs)
+}
+
+func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+
+       stmt, err := mc.Prepare(query)
+       mc.finish()
+       if err != nil {
+               return nil, err
+       }
+
+       select {
+       default:
+       case <-ctx.Done():
+               stmt.Close()
+               return nil, ctx.Err()
+       }
+       return stmt, nil
+}
+
+func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := stmt.mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+
+       rows, err := stmt.query(dargs)
+       if err != nil {
+               stmt.mc.finish()
+               return nil, err
+       }
+       rows.finish = stmt.mc.finish
+       return rows, err
+}
+
+func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := stmt.mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+       defer stmt.mc.finish()
+
+       return stmt.Exec(dargs)
+}
+
+func (mc *mysqlConn) watchCancel(ctx context.Context) error {
+       if mc.watching {
+               // Reach here if canceled,
+               // so the connection is already invalid
+               mc.cleanup()
+               return nil
+       }
+       // When ctx is already cancelled, don't watch it.
+       if err := ctx.Err(); err != nil {
+               return err
+       }
+       // When ctx is not cancellable, don't watch it.
+       if ctx.Done() == nil {
+               return nil
+       }
+       // When watcher is not alive, can't watch it.
+       if mc.watcher == nil {
+               return nil
+       }
+
+       mc.watching = true
+       mc.watcher <- ctx
+       return nil
+}
+
+func (mc *mysqlConn) startWatcher() {
+       watcher := make(chan context.Context, 1)
+       mc.watcher = watcher
+       finished := make(chan struct{})
+       mc.finished = finished
+       go func() {
+               for {
+                       var ctx context.Context
+                       select {
+                       case ctx = <-watcher:
+                       case <-mc.closech:
+                               return
+                       }
+
+                       select {
+                       case <-ctx.Done():
+                               mc.cancel(ctx.Err())
+                       case <-finished:
+                       case <-mc.closech:
+                               return
+                       }
+               }
+       }()
+}
+
+func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
+       nv.Value, err = converter{}.ConvertValue(nv.Value)
+       return
+}
+
+// ResetSession implements driver.SessionResetter.
+// (From Go 1.10)
+func (mc *mysqlConn) ResetSession(ctx context.Context) error {
+       if mc.closed.IsSet() {
+               return driver.ErrBadConn
+       }
+       return nil
+}
diff --git a/vendor/github.com/go-sql-driver/mysql/connection_go18.go b/vendor/github.com/go-sql-driver/mysql/connection_go18.go
deleted file mode 100644 (file)
index 62796bf..0000000
+++ /dev/null
@@ -1,208 +0,0 @@
-// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
-//
-// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this file,
-// You can obtain one at http://mozilla.org/MPL/2.0/.
-
-// +build go1.8
-
-package mysql
-
-import (
-       "context"
-       "database/sql"
-       "database/sql/driver"
-)
-
-// Ping implements driver.Pinger interface
-func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
-       if mc.closed.IsSet() {
-               errLog.Print(ErrInvalidConn)
-               return driver.ErrBadConn
-       }
-
-       if err = mc.watchCancel(ctx); err != nil {
-               return
-       }
-       defer mc.finish()
-
-       if err = mc.writeCommandPacket(comPing); err != nil {
-               return
-       }
-
-       return mc.readResultOK()
-}
-
-// BeginTx implements driver.ConnBeginTx interface
-func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
-       if err := mc.watchCancel(ctx); err != nil {
-               return nil, err
-       }
-       defer mc.finish()
-
-       if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
-               level, err := mapIsolationLevel(opts.Isolation)
-               if err != nil {
-                       return nil, err
-               }
-               err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
-               if err != nil {
-                       return nil, err
-               }
-       }
-
-       return mc.begin(opts.ReadOnly)
-}
-
-func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
-       dargs, err := namedValueToValue(args)
-       if err != nil {
-               return nil, err
-       }
-
-       if err := mc.watchCancel(ctx); err != nil {
-               return nil, err
-       }
-
-       rows, err := mc.query(query, dargs)
-       if err != nil {
-               mc.finish()
-               return nil, err
-       }
-       rows.finish = mc.finish
-       return rows, err
-}
-
-func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
-       dargs, err := namedValueToValue(args)
-       if err != nil {
-               return nil, err
-       }
-
-       if err := mc.watchCancel(ctx); err != nil {
-               return nil, err
-       }
-       defer mc.finish()
-
-       return mc.Exec(query, dargs)
-}
-
-func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
-       if err := mc.watchCancel(ctx); err != nil {
-               return nil, err
-       }
-
-       stmt, err := mc.Prepare(query)
-       mc.finish()
-       if err != nil {
-               return nil, err
-       }
-
-       select {
-       default:
-       case <-ctx.Done():
-               stmt.Close()
-               return nil, ctx.Err()
-       }
-       return stmt, nil
-}
-
-func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
-       dargs, err := namedValueToValue(args)
-       if err != nil {
-               return nil, err
-       }
-
-       if err := stmt.mc.watchCancel(ctx); err != nil {
-               return nil, err
-       }
-
-       rows, err := stmt.query(dargs)
-       if err != nil {
-               stmt.mc.finish()
-               return nil, err
-       }
-       rows.finish = stmt.mc.finish
-       return rows, err
-}
-
-func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
-       dargs, err := namedValueToValue(args)
-       if err != nil {
-               return nil, err
-       }
-
-       if err := stmt.mc.watchCancel(ctx); err != nil {
-               return nil, err
-       }
-       defer stmt.mc.finish()
-
-       return stmt.Exec(dargs)
-}
-
-func (mc *mysqlConn) watchCancel(ctx context.Context) error {
-       if mc.watching {
-               // Reach here if canceled,
-               // so the connection is already invalid
-               mc.cleanup()
-               return nil
-       }
-       if ctx.Done() == nil {
-               return nil
-       }
-
-       mc.watching = true
-       select {
-       default:
-       case <-ctx.Done():
-               return ctx.Err()
-       }
-       if mc.watcher == nil {
-               return nil
-       }
-
-       mc.watcher <- ctx
-
-       return nil
-}
-
-func (mc *mysqlConn) startWatcher() {
-       watcher := make(chan mysqlContext, 1)
-       mc.watcher = watcher
-       finished := make(chan struct{})
-       mc.finished = finished
-       go func() {
-               for {
-                       var ctx mysqlContext
-                       select {
-                       case ctx = <-watcher:
-                       case <-mc.closech:
-                               return
-                       }
-
-                       select {
-                       case <-ctx.Done():
-                               mc.cancel(ctx.Err())
-                       case <-finished:
-                       case <-mc.closech:
-                               return
-                       }
-               }
-       }()
-}
-
-func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
-       nv.Value, err = converter{}.ConvertValue(nv.Value)
-       return
-}
-
-// ResetSession implements driver.SessionResetter.
-// (From Go 1.10)
-func (mc *mysqlConn) ResetSession(ctx context.Context) error {
-       if mc.closed.IsSet() {
-               return driver.ErrBadConn
-       }
-       return nil
-}
index 1a75a16ecf0a204529987713aeddff8249073b08..9f4967087f5185e4456be6a63f37c993689cc54a 100644 (file)
@@ -23,11 +23,6 @@ import (
        "sync"
 )
 
-// watcher interface is used for context support (From Go 1.8)
-type watcher interface {
-       startWatcher()
-}
-
 // MySQLDriver is exported to make the driver directly accessible.
 // In general the driver is used via the database/sql package.
 type MySQLDriver struct{}
@@ -55,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
 
 // Open new Connection.
 // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
-// the DSN string is formated
+// the DSN string is formatted
 func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
        var err error
 
@@ -82,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
                mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
        }
        if err != nil {
+               if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
+                       errLog.Print("net.Error from Dial()': ", nerr.Error())
+                       return nil, driver.ErrBadConn
+               }
                return nil, err
        }
 
@@ -96,9 +95,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
        }
 
        // Call startWatcher for context support (From Go 1.8)
-       if s, ok := interface{}(mc).(watcher); ok {
-               s.startWatcher()
-       }
+       mc.startWatcher()
 
        mc.buf = newBuffer(mc.netConn)
 
@@ -112,20 +109,23 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
                mc.cleanup()
                return nil, err
        }
+       if plugin == "" {
+               plugin = defaultAuthPlugin
+       }
 
        // Send Client Authentication Packet
-       authResp, addNUL, err := mc.auth(authData, plugin)
+       authResp, err := mc.auth(authData, plugin)
        if err != nil {
                // try the default auth plugin, if using the requested plugin failed
                errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
                plugin = defaultAuthPlugin
-               authResp, addNUL, err = mc.auth(authData, plugin)
+               authResp, err = mc.auth(authData, plugin)
                if err != nil {
                        mc.cleanup()
                        return nil, err
                }
        }
-       if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
+       if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
                mc.cleanup()
                return nil, err
        }
index be014babe3335703c74b93c4f0d3e74f6534bb5e..b9134722eb0fcc227956a98ccf63598fba2ee1c8 100644 (file)
@@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
                                } else {
                                        cfg.TLSConfig = "false"
                                }
-                       } else if vl := strings.ToLower(value); vl == "skip-verify" {
+                       } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
                                cfg.TLSConfig = vl
                                cfg.tls = &tls.Config{InsecureSkipVerify: true}
                        } else {
index d873a97b2feac3a89bdb824b3ecf7f218bb48bf8..5e0853767d5a5afbf5ee4230ea9c64de52d0aa5e 100644 (file)
@@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
                mc.sequence++
 
                // packets with length 0 terminate a previous packet which is a
-               // multiple of (2^24)−1 bytes long
+               // multiple of (2^24)-1 bytes long
                if pktLen == 0 {
                        // there was no previous packet
                        if prevData == nil {
@@ -154,15 +154,15 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 
 // Handshake Initialization Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
-       data, err := mc.readPacket()
+func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
+       data, err = mc.readPacket()
        if err != nil {
                // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
                // in connection initialization we don't risk retrying non-idempotent actions.
                if err == ErrInvalidConn {
                        return nil, "", driver.ErrBadConn
                }
-               return nil, "", err
+               return
        }
 
        if data[0] == iERR {
@@ -194,11 +194,14 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
                return nil, "", ErrOldProtocol
        }
        if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
-               return nil, "", ErrNoTLS
+               if mc.cfg.TLSConfig == "preferred" {
+                       mc.cfg.tls = nil
+               } else {
+                       return nil, "", ErrNoTLS
+               }
        }
        pos += 2
 
-       plugin := ""
        if len(data) > pos {
                // character set [1 byte]
                // status flags [2 bytes]
@@ -236,8 +239,6 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
                return b[:], plugin, nil
        }
 
-       plugin = defaultAuthPlugin
-
        // make a memory safe copy of the cipher slice
        var b [8]byte
        copy(b[:], authData)
@@ -246,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
 
 // Client Authentication Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
+func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
        // Adjust client flags based on server support
        clientFlags := clientProtocol41 |
                clientSecureConn |
@@ -272,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
 
        // encode length of the auth plugin data
        var authRespLEIBuf [9]byte
-       authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
+       authRespLen := len(authResp)
+       authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
        if len(authRespLEI) > 1 {
                // if the length can not be written in 1 byte, it must be written as a
                // length encoded integer
@@ -280,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
        }
 
        pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
-       if addNUL {
-               pktLen++
-       }
 
        // To specify a db name
        if n := len(mc.cfg.DBName); n > 0 {
@@ -291,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
        }
 
        // Calculate packet length and get buffer with that size
-       data := mc.buf.takeSmallBuffer(pktLen + 4)
-       if data == nil {
+       data, err := mc.buf.takeSmallBuffer(pktLen + 4)
+       if err != nil {
                // cannot take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return errBadConnNoWrite
        }
 
@@ -353,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
        // Auth Data [length encoded integer]
        pos += copy(data[pos:], authRespLEI)
        pos += copy(data[pos:], authResp)
-       if addNUL {
-               data[pos] = 0x00
-               pos++
-       }
 
        // Databasename [null terminated string]
        if len(mc.cfg.DBName) > 0 {
@@ -367,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
 
        pos += copy(data[pos:], plugin)
        data[pos] = 0x00
+       pos++
 
        // Send Auth packet
-       return mc.writePacket(data)
+       return mc.writePacket(data[:pos])
 }
 
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
+func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
        pktLen := 4 + len(authData)
-       if addNUL {
-               pktLen++
-       }
-       data := mc.buf.takeSmallBuffer(pktLen)
-       if data == nil {
+       data, err := mc.buf.takeSmallBuffer(pktLen)
+       if err != nil {
                // cannot take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return errBadConnNoWrite
        }
 
        // Add the auth data [EOF]
        copy(data[4:], authData)
-       if addNUL {
-               data[pktLen-1] = 0x00
-       }
-
        return mc.writePacket(data)
 }
 
@@ -402,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
        // Reset Packet Sequence
        mc.sequence = 0
 
-       data := mc.buf.takeSmallBuffer(4 + 1)
-       if data == nil {
+       data, err := mc.buf.takeSmallBuffer(4 + 1)
+       if err != nil {
                // cannot take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return errBadConnNoWrite
        }
 
@@ -421,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
        mc.sequence = 0
 
        pktLen := 1 + len(arg)
-       data := mc.buf.takeBuffer(pktLen + 4)
-       if data == nil {
+       data, err := mc.buf.takeBuffer(pktLen + 4)
+       if err != nil {
                // cannot take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return errBadConnNoWrite
        }
 
@@ -442,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
        // Reset Packet Sequence
        mc.sequence = 0
 
-       data := mc.buf.takeSmallBuffer(4 + 1 + 4)
-       if data == nil {
+       data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
+       if err != nil {
                // cannot take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return errBadConnNoWrite
        }
 
@@ -482,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
                return data[1:], "", err
 
        case iEOF:
-               if len(data) < 1 {
+               if len(data) == 1 {
                        // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
                        return nil, "mysql_old_password", nil
                }
@@ -898,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
        const minPktLen = 4 + 1 + 4 + 1 + 4
        mc := stmt.mc
 
-       // Determine threshould dynamically to avoid packet size shortage.
+       // Determine threshold dynamically to avoid packet size shortage.
        longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
        if longDataSize < 64 {
                longDataSize = 64
@@ -908,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
        mc.sequence = 0
 
        var data []byte
+       var err error
 
        if len(args) == 0 {
-               data = mc.buf.takeBuffer(minPktLen)
+               data, err = mc.buf.takeBuffer(minPktLen)
        } else {
-               data = mc.buf.takeCompleteBuffer()
+               data, err = mc.buf.takeCompleteBuffer()
+               // In this case the len(data) == cap(data) which is used to optimise the flow below.
        }
-       if data == nil {
+       if err != nil {
                // cannot take the buffer. Something must be wrong with the connection
-               errLog.Print(ErrBusyBuffer)
+               errLog.Print(err)
                return errBadConnNoWrite
        }
 
@@ -942,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
                pos := minPktLen
 
                var nullMask []byte
-               if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
+               if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
                        // buffer has to be extended but we don't know by how much so
                        // we depend on append after all data with known sizes fit.
                        // We stop at that because we deal with a lot of columns here
@@ -951,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
                        copy(tmp[:pos], data[:pos])
                        data = tmp
                        nullMask = data[pos : pos+maskLen]
+                       // No need to clean nullMask as make ensures that.
                        pos += maskLen
                } else {
                        nullMask = data[pos : pos+maskLen]
-                       for i := 0; i < maskLen; i++ {
+                       for i := range nullMask {
                                nullMask[i] = 0
                        }
                        pos += maskLen
@@ -1091,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
                // In that case we must build the data packet with the new values buffer
                if valuesCap != cap(paramValues) {
                        data = append(data[:pos], paramValues...)
-                       mc.buf.buf = data
+                       if err = mc.buf.store(data); err != nil {
+                               errLog.Print(err)
+                               return errBadConnNoWrite
+                       }
                }
 
                pos += len(paramValues)
@@ -1261,7 +1256,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
                                                rows.rs.columns[i].decimals,
                                        )
                                }
-                               dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
+                               dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
                        case rows.mc.parseTime:
                                dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
                        default:
@@ -1281,7 +1276,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
                                                )
                                        }
                                }
-                               dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
+                               dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
                        }
 
                        if err == nil {
index 84d595b6ba6e71d99762af1ce02c5da82da4d783..cb3650bb9b8db8dd51f360fd70aa1dc136d741cd 100644 (file)
@@ -10,10 +10,13 @@ package mysql
 
 import (
        "crypto/tls"
+       "database/sql"
        "database/sql/driver"
        "encoding/binary"
+       "errors"
        "fmt"
        "io"
+       "strconv"
        "strings"
        "sync"
        "sync/atomic"
@@ -79,7 +82,7 @@ func DeregisterTLSConfig(key string) {
 func getTLSConfigClone(key string) (config *tls.Config) {
        tlsConfigLock.RLock()
        if v, ok := tlsConfigRegistry[key]; ok {
-               config = cloneTLSConfig(v)
+               config = v.Clone()
        }
        tlsConfigLock.RUnlock()
        return
@@ -227,141 +230,156 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
 const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
 const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
 
-func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) {
-       // length expects the deterministic length of the zero value,
-       // negative time and 100+ hours are automatically added if needed
-       if len(src) == 0 {
-               if justTime {
-                       return zeroDateTime[11 : 11+length], nil
-               }
-               return zeroDateTime[:length], nil
+func appendMicrosecs(dst, src []byte, decimals int) []byte {
+       if decimals <= 0 {
+               return dst
        }
-       var dst []byte          // return value
-       var pt, p1, p2, p3 byte // current digit pair
-       var zOffs byte          // offset of value in zeroDateTime
-       if justTime {
-               switch length {
-               case
-                       8,                      // time (can be up to 10 when negative and 100+ hours)
-                       10, 11, 12, 13, 14, 15: // time with fractional seconds
-               default:
-                       return nil, fmt.Errorf("illegal TIME length %d", length)
-               }
-               switch len(src) {
-               case 8, 12:
-               default:
-                       return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
-               }
-               // +2 to enable negative time and 100+ hours
-               dst = make([]byte, 0, length+2)
-               if src[0] == 1 {
-                       dst = append(dst, '-')
-               }
-               if src[1] != 0 {
-                       hour := uint16(src[1])*24 + uint16(src[5])
-                       pt = byte(hour / 100)
-                       p1 = byte(hour - 100*uint16(pt))
-                       dst = append(dst, digits01[pt])
-               } else {
-                       p1 = src[5]
-               }
-               zOffs = 11
-               src = src[6:]
-       } else {
-               switch length {
-               case 10, 19, 21, 22, 23, 24, 25, 26:
-               default:
-                       t := "DATE"
-                       if length > 10 {
-                               t += "TIME"
-                       }
-                       return nil, fmt.Errorf("illegal %s length %d", t, length)
-               }
-               switch len(src) {
-               case 4, 7, 11:
-               default:
-                       t := "DATE"
-                       if length > 10 {
-                               t += "TIME"
-                       }
-                       return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
-               }
-               dst = make([]byte, 0, length)
-               // start with the date
-               year := binary.LittleEndian.Uint16(src[:2])
-               pt = byte(year / 100)
-               p1 = byte(year - 100*uint16(pt))
-               p2, p3 = src[2], src[3]
-               dst = append(dst,
-                       digits10[pt], digits01[pt],
-                       digits10[p1], digits01[p1], '-',
-                       digits10[p2], digits01[p2], '-',
-                       digits10[p3], digits01[p3],
-               )
-               if length == 10 {
-                       return dst, nil
-               }
-               if len(src) == 4 {
-                       return append(dst, zeroDateTime[10:length]...), nil
-               }
-               dst = append(dst, ' ')
-               p1 = src[4] // hour
-               src = src[5:]
-       }
-       // p1 is 2-digit hour, src is after hour
-       p2, p3 = src[0], src[1]
-       dst = append(dst,
-               digits10[p1], digits01[p1], ':',
-               digits10[p2], digits01[p2], ':',
-               digits10[p3], digits01[p3],
-       )
-       if length <= byte(len(dst)) {
-               return dst, nil
-       }
-       src = src[2:]
        if len(src) == 0 {
-               return append(dst, zeroDateTime[19:zOffs+length]...), nil
+               return append(dst, ".000000"[:decimals+1]...)
        }
+
        microsecs := binary.LittleEndian.Uint32(src[:4])
-       p1 = byte(microsecs / 10000)
+       p1 := byte(microsecs / 10000)
        microsecs -= 10000 * uint32(p1)
-       p2 = byte(microsecs / 100)
+       p2 := byte(microsecs / 100)
        microsecs -= 100 * uint32(p2)
-       p3 = byte(microsecs)
-       switch decimals := zOffs + length - 20; decimals {
+       p3 := byte(microsecs)
+
+       switch decimals {
        default:
                return append(dst, '.',
                        digits10[p1], digits01[p1],
                        digits10[p2], digits01[p2],
                        digits10[p3], digits01[p3],
-               ), nil
+               )
        case 1:
                return append(dst, '.',
                        digits10[p1],
-               ), nil
+               )
        case 2:
                return append(dst, '.',
                        digits10[p1], digits01[p1],
-               ), nil
+               )
        case 3:
                return append(dst, '.',
                        digits10[p1], digits01[p1],
                        digits10[p2],
-               ), nil
+               )
        case 4:
                return append(dst, '.',
                        digits10[p1], digits01[p1],
                        digits10[p2], digits01[p2],
-               ), nil
+               )
        case 5:
                return append(dst, '.',
                        digits10[p1], digits01[p1],
                        digits10[p2], digits01[p2],
                        digits10[p3],
-               ), nil
+               )
        }
 }
 
+func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
+       // length expects the deterministic length of the zero value,
+       // negative time and 100+ hours are automatically added if needed
+       if len(src) == 0 {
+               return zeroDateTime[:length], nil
+       }
+       var dst []byte      // return value
+       var p1, p2, p3 byte // current digit pair
+
+       switch length {
+       case 10, 19, 21, 22, 23, 24, 25, 26:
+       default:
+               t := "DATE"
+               if length > 10 {
+                       t += "TIME"
+               }
+               return nil, fmt.Errorf("illegal %s length %d", t, length)
+       }
+       switch len(src) {
+       case 4, 7, 11:
+       default:
+               t := "DATE"
+               if length > 10 {
+                       t += "TIME"
+               }
+               return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
+       }
+       dst = make([]byte, 0, length)
+       // start with the date
+       year := binary.LittleEndian.Uint16(src[:2])
+       pt := year / 100
+       p1 = byte(year - 100*uint16(pt))
+       p2, p3 = src[2], src[3]
+       dst = append(dst,
+               digits10[pt], digits01[pt],
+               digits10[p1], digits01[p1], '-',
+               digits10[p2], digits01[p2], '-',
+               digits10[p3], digits01[p3],
+       )
+       if length == 10 {
+               return dst, nil
+       }
+       if len(src) == 4 {
+               return append(dst, zeroDateTime[10:length]...), nil
+       }
+       dst = append(dst, ' ')
+       p1 = src[4] // hour
+       src = src[5:]
+
+       // p1 is 2-digit hour, src is after hour
+       p2, p3 = src[0], src[1]
+       dst = append(dst,
+               digits10[p1], digits01[p1], ':',
+               digits10[p2], digits01[p2], ':',
+               digits10[p3], digits01[p3],
+       )
+       return appendMicrosecs(dst, src[2:], int(length)-20), nil
+}
+
+func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
+       // length expects the deterministic length of the zero value,
+       // negative time and 100+ hours are automatically added if needed
+       if len(src) == 0 {
+               return zeroDateTime[11 : 11+length], nil
+       }
+       var dst []byte // return value
+
+       switch length {
+       case
+               8,                      // time (can be up to 10 when negative and 100+ hours)
+               10, 11, 12, 13, 14, 15: // time with fractional seconds
+       default:
+               return nil, fmt.Errorf("illegal TIME length %d", length)
+       }
+       switch len(src) {
+       case 8, 12:
+       default:
+               return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
+       }
+       // +2 to enable negative time and 100+ hours
+       dst = make([]byte, 0, length+2)
+       if src[0] == 1 {
+               dst = append(dst, '-')
+       }
+       days := binary.LittleEndian.Uint32(src[1:5])
+       hours := int64(days)*24 + int64(src[5])
+
+       if hours >= 100 {
+               dst = strconv.AppendInt(dst, hours, 10)
+       } else {
+               dst = append(dst, digits10[hours], digits01[hours])
+       }
+
+       min, sec := src[6], src[7]
+       dst = append(dst, ':',
+               digits10[min], digits01[min], ':',
+               digits10[sec], digits01[sec],
+       )
+       return appendMicrosecs(dst, src[8:], int(length)-9), nil
+}
+
 /******************************************************************************
 *                       Convert from and to bytes                             *
 ******************************************************************************/
@@ -708,3 +726,30 @@ func (ae *atomicError) Value() error {
        }
        return nil
 }
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+       dargs := make([]driver.Value, len(named))
+       for n, param := range named {
+               if len(param.Name) > 0 {
+                       // TODO: support the use of Named Parameters #561
+                       return nil, errors.New("mysql: driver does not support the use of Named Parameters")
+               }
+               dargs[n] = param.Value
+       }
+       return dargs, nil
+}
+
+func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
+       switch sql.IsolationLevel(level) {
+       case sql.LevelRepeatableRead:
+               return "REPEATABLE READ", nil
+       case sql.LevelReadCommitted:
+               return "READ COMMITTED", nil
+       case sql.LevelReadUncommitted:
+               return "READ UNCOMMITTED", nil
+       case sql.LevelSerializable:
+               return "SERIALIZABLE", nil
+       default:
+               return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
+       }
+}
diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go17.go b/vendor/github.com/go-sql-driver/mysql/utils_go17.go
deleted file mode 100644 (file)
index f595634..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
-//
-// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this file,
-// You can obtain one at http://mozilla.org/MPL/2.0/.
-
-// +build go1.7
-// +build !go1.8
-
-package mysql
-
-import "crypto/tls"
-
-func cloneTLSConfig(c *tls.Config) *tls.Config {
-       return &tls.Config{
-               Rand:                        c.Rand,
-               Time:                        c.Time,
-               Certificates:                c.Certificates,
-               NameToCertificate:           c.NameToCertificate,
-               GetCertificate:              c.GetCertificate,
-               RootCAs:                     c.RootCAs,
-               NextProtos:                  c.NextProtos,
-               ServerName:                  c.ServerName,
-               ClientAuth:                  c.ClientAuth,
-               ClientCAs:                   c.ClientCAs,
-               InsecureSkipVerify:          c.InsecureSkipVerify,
-               CipherSuites:                c.CipherSuites,
-               PreferServerCipherSuites:    c.PreferServerCipherSuites,
-               SessionTicketsDisabled:      c.SessionTicketsDisabled,
-               SessionTicketKey:            c.SessionTicketKey,
-               ClientSessionCache:          c.ClientSessionCache,
-               MinVersion:                  c.MinVersion,
-               MaxVersion:                  c.MaxVersion,
-               CurvePreferences:            c.CurvePreferences,
-               DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
-               Renegotiation:               c.Renegotiation,
-       }
-}
diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go18.go b/vendor/github.com/go-sql-driver/mysql/utils_go18.go
deleted file mode 100644 (file)
index c35c2a6..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
-//
-// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this file,
-// You can obtain one at http://mozilla.org/MPL/2.0/.
-
-// +build go1.8
-
-package mysql
-
-import (
-       "crypto/tls"
-       "database/sql"
-       "database/sql/driver"
-       "errors"
-       "fmt"
-)
-
-func cloneTLSConfig(c *tls.Config) *tls.Config {
-       return c.Clone()
-}
-
-func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
-       dargs := make([]driver.Value, len(named))
-       for n, param := range named {
-               if len(param.Name) > 0 {
-                       // TODO: support the use of Named Parameters #561
-                       return nil, errors.New("mysql: driver does not support the use of Named Parameters")
-               }
-               dargs[n] = param.Value
-       }
-       return dargs, nil
-}
-
-func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
-       switch sql.IsolationLevel(level) {
-       case sql.LevelRepeatableRead:
-               return "REPEATABLE READ", nil
-       case sql.LevelReadCommitted:
-               return "READ COMMITTED", nil
-       case sql.LevelReadUncommitted:
-               return "READ UNCOMMITTED", nil
-       case sql.LevelSerializable:
-               return "SERIALIZABLE", nil
-       default:
-               return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
-       }
-}