diff options
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/dsn.go')
-rw-r--r-- | vendor/github.com/go-sql-driver/mysql/dsn.go | 159 |
1 files changed, 111 insertions, 48 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index ac00dcedd5..be014babe3 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -10,11 +10,13 @@ package mysql import ( "bytes" + "crypto/rsa" "crypto/tls" "errors" "fmt" "net" "net/url" + "sort" "strconv" "strings" "time" @@ -27,7 +29,9 @@ var ( errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") ) -// Config is a configuration parsed from a DSN string +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. type Config struct { User string // Username Passwd string // Password (requires User) @@ -38,6 +42,8 @@ type Config struct { Collation string // Connection collation Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name tls *tls.Config // TLS configuration Timeout time.Duration // Dial timeout @@ -53,7 +59,54 @@ type Config struct { InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time - Strict bool // Return warnings as errors + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + } +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + + return nil } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { } } - if cfg.AllowNativePasswords { + if !cfg.AllowNativePasswords { if hasParam { - buf.WriteString("&allowNativePasswords=true") + buf.WriteString("&allowNativePasswords=false") } else { hasParam = true - buf.WriteString("?allowNativePasswords=true") + buf.WriteString("?allowNativePasswords=false") } } @@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.ReadTimeout.String()) } - if cfg.Strict { + if cfg.RejectReadOnly { if hasParam { - buf.WriteString("&strict=true") + buf.WriteString("&rejectReadOnly=true") } else { hasParam = true - buf.WriteString("?strict=true") + buf.WriteString("?rejectReadOnly=true") } } + if len(cfg.ServerPubKey) > 0 { + if hasParam { + buf.WriteString("&serverPubKey=") + } else { + hasParam = true + buf.WriteString("?serverPubKey=") + } + buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + } + if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } - if cfg.MaxAllowedPacket > 0 { + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { if hasParam { buf.WriteString("&maxAllowedPacket=") } else { @@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { // other params if cfg.Params != nil { - for param, value := range cfg.Params { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + sort.Strings(params) + for _, param := range params { if hasParam { buf.WriteByte('&') } else { @@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(param) buf.WriteByte('=') - buf.WriteString(url.QueryEscape(value)) + buf.WriteString(url.QueryEscape(cfg.Params[param])) } } @@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { // ParseDSN parses the DSN string to a Config func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = &Config{ - Loc: time.UTC, - Collation: defaultCollation, - } + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { return nil, errInvalidDSNNoSlash } - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { - return nil, errInvalidDSNUnsafeCollation - } - - // Set default network if empty - if cfg.Net == "" { - cfg.Net = "tcp" + if err = cfg.normalize(); err != nil { + return nil, err } - - // Set default address if empty - if cfg.Addr == "" { - switch cfg.Net { - case "tcp": - cfg.Addr = "127.0.0.1:3306" - case "unix": - cfg.Addr = "/tmp/mysql.sock" - default: - return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") - } - - } - return } @@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool @@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } - // Strict mode - case "strict": + // Reject read-only connections + case "rejectReadOnly": var isBool bool - cfg.Strict, isBool = readBool(value) + cfg.RejectReadOnly, isBool = readBool(value) if !isBool { return errors.New("invalid bool value: " + value) } + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + + if pubKey := getServerPubKey(name); pubKey != nil { + cfg.ServerPubKey = name + cfg.pubKey = pubKey + } else { + return errors.New("invalid value / unknown server pub key name: " + name) + } + + // Strict mode + case "strict": + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + // Dial Timeout case "timeout": cfg.Timeout, err = time.ParseDuration(value) @@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { return fmt.Errorf("invalid value for TLS config name: %v", err) } - if tlsConfig, ok := tlsConfigRegister[name]; ok { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - + if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { cfg.TLSConfig = name cfg.tls = tlsConfig } else { @@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} |