diff options
Diffstat (limited to 'vendor/github.com/lib/pq/conn.go')
-rw-r--r-- | vendor/github.com/lib/pq/conn.go | 444 |
1 files changed, 177 insertions, 267 deletions
diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index ca88dc8c67..df6b565f73 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -3,15 +3,12 @@ package pq import ( "bufio" "crypto/md5" - "crypto/tls" - "crypto/x509" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" - "io/ioutil" "net" "os" "os/user" @@ -30,22 +27,22 @@ var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") - errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -101,6 +98,15 @@ type conn struct { namei int scratch [512]byte txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + opts values + + // Cancellation key data for use with CancelRequest messages. + processID int + secretKey int parameterStatus parameterStatus @@ -125,9 +131,9 @@ type conn struct { } // Handle driver-side settings in parsed connection string. -func (c *conn) handleDriverSettings(o values) (err error) { +func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { - if value := o.Get(key); value != "" { + if value, ok := o[key]; ok { if value == "yes" { *val = true } else if value == "no" { @@ -139,32 +145,32 @@ func (c *conn) handleDriverSettings(o values) (err error) { return nil } - err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err } - err = boolSetting("binary_parameters", &c.binaryParameters) - if err != nil { - return err - } - return nil + return boolSetting("binary_parameters", &cn.binaryParameters) } -func (c *conn) handlePgpass(o values) { +func (cn *conn) handlePgpass(o values) { // if a password was supplied, do not process .pgpass - _, ok := o["password"] - if ok { + if _, ok := o["password"]; ok { return } filename := os.Getenv("PGPASSFILE") if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf - user, err := user.Current() - if err != nil { - return + // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 + userHome := os.Getenv("HOME") + if userHome == "" { + user, err := user.Current() + if err != nil { + return + } + userHome = user.HomeDir } - filename = filepath.Join(user.HomeDir, ".pgpass") + filename = filepath.Join(userHome, ".pgpass") } fileinfo, err := os.Stat(filename) if err != nil { @@ -181,11 +187,11 @@ func (c *conn) handlePgpass(o values) { } defer file.Close() scanner := bufio.NewScanner(io.Reader(file)) - hostname := o.Get("host") + hostname := o["host"] ntw, _ := network(o) - port := o.Get("port") - db := o.Get("dbname") - username := o.Get("user") + port := o["port"] + db := o["dbname"] + username := o["user"] // From: https://github.com/tg/pgpass/blob/master/reader.go getFields := func(s string) []string { fs := make([]string, 0, 5) @@ -224,10 +230,10 @@ func (c *conn) handlePgpass(o values) { } } -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b +func (cn *conn) writeBuf(b byte) *writeBuf { + cn.scratch[0] = b return &writeBuf{ - buf: c.scratch[:5], + buf: cn.scratch[:5], pos: 1, } } @@ -250,13 +256,13 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") + o["host"] = "localhost" + o["port"] = "5432" // N.B.: Extra float digits should be set to 3, but that breaks // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") + o["extra_float_digits"] = "2" for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) + o[k] = v } if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { @@ -271,9 +277,9 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback } } @@ -284,33 +290,35 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { return nil, errors.New("client_encoding must be absent or 'UTF8'") } - o.Set("client_encoding", "UTF8") + o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { + if datestyle, ok := o["datestyle"]; ok { if datestyle != "ISO, MDY" { panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)) } } else { - o.Set("datestyle", "ISO, MDY") + o["datestyle"] = "ISO, MDY" } // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. - if o.Get("user") == "" { + if _, ok := o["user"]; !ok { u, err := userCurrent() if err != nil { return nil, err - } else { - o.Set("user", u) } + o["user"] = u } - cn := &conn{} + cn := &conn{ + opts: o, + dialer: d, + } err = cn.handleDriverSettings(o) if err != nil { return nil, err @@ -326,7 +334,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { cn.startup(o) // reset the deadline, in case one was set (see dial) - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } return cn, err @@ -340,7 +348,7 @@ func dial(d Dialer, o values) (net.Conn, error) { } // Zero or not specified means wait indefinitely. - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) @@ -362,31 +370,18 @@ func dial(d Dialer, o values) (net.Conn, error) { } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", net.JoinHostPort(host, o.Get("port")) + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -457,7 +452,7 @@ func parseOpts(name string, o values) error { // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -492,7 +487,7 @@ func parseOpts(name string, o values) error { } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -511,13 +506,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) { } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -532,7 +531,14 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { return cn, nil } +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + func (cn *conn) Commit() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -568,6 +574,7 @@ func (cn *conn) Commit() (err error) { } func (cn *conn) Rollback() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -647,6 +654,12 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { cn: cn, } } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } res.done = true case 'Z': cn.processReadyForQuery(r) @@ -685,7 +698,7 @@ var emptyRows noRows var _ driver.Result = noRows{} func (noRows) LastInsertId() (int64, error) { - return 0, errNoLastInsertId + return 0, errNoLastInsertID } func (noRows) RowsAffected() (int64, error) { @@ -694,7 +707,7 @@ func (noRows) RowsAffected() (int64, error) { // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } @@ -706,8 +719,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c allBinary := true allText := true - for i, o := range colTyps { - switch o { + for i, t := range colTyps { + switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. @@ -718,6 +731,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c case oid.T_int4: fallthrough case oid.T_int2: + fallthrough + case oid.T_uuid: colFmts[i] = formatBinary allText = false @@ -797,7 +812,11 @@ func (cn *conn) Close() (err error) { } // Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, args) +} + +func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } @@ -821,16 +840,15 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil - } else { - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil } + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil } // Implement the optional "Execer" interface for one-shot queries @@ -857,17 +875,16 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err cn.postExecuteWorkaround() res, _, err = cn.readExecuteResponse("Execute") return res, err - } else { - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err } + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. + st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } func (cn *conn) send(m *writeBuf) { @@ -877,16 +894,9 @@ func (cn *conn) send(m *writeBuf) { } } -func (cn *conn) sendStartupPacket(m *writeBuf) { - // sanity check - if m.buf[0] != 0 { - panic("oops") - } - +func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) - if err != nil { - panic(err) - } + return err } // Send a message of type typ to the server on the other end of cn. The @@ -1000,45 +1010,17 @@ func (cn *conn) recv1() (t byte, r *readBuf) { } func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - // "require" is the default. - case "", "require": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - - // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: - // Note: For backwards compatibility with earlier versions of PostgreSQL, if a - // root CA file exists, the behavior of sslmode=require will be the same as - // that of verify-ca, meaning the server certificate is validated against the - // CA. Relying on this behavior is discouraged, and applications that need - // certificate validation should always use verify-ca or verify-full. - if _, err := os.Stat(o.Get("sslrootcert")); err == nil { - verifyCaOnly = true - } else { - o.Set("sslrootcert", "") - } - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": + upgrade := ssl(o) + if upgrade == nil { + // Nothing to do return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) } - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) - w := cn.writeBuf(0) w.int32(80877103) - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } b := cn.scratch[:1] _, err := io.ReadFull(cn.c, b) @@ -1050,114 +1032,7 @@ func (cn *conn) ssl(o values) { panic(ErrSSLNotSupported) } - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true - } - - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } + cn.c = upgrade(cn.c) } // isDriverSetting returns true iff a setting is purely for configuring the @@ -1206,12 +1081,15 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() switch t { case 'K': + cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': @@ -1231,7 +1109,7 @@ func (cn *conn) auth(r *readBuf, o values) { // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1245,7 +1123,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1267,10 +1145,10 @@ const formatText format = 0 const formatBinary format = 1 // One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} +var colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). -var colFmtDataAllText []byte = []byte{0, 0} +var colFmtDataAllText = []byte{0, 0} type stmt struct { cn *conn @@ -1278,7 +1156,7 @@ type stmt struct { colNames []string colFmts []format colFmtData []byte - colTyps []oid.Oid + colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1439,21 +1317,32 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { type rows struct { cn *conn + finish func() colNames []string - colTyps []oid.Oid + colTyps []fieldDesc colFmts []format done bool rb readBuf + result driver.Result + tag string } func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1464,6 +1353,17 @@ func (rs *rows) Columns() []string { return rs.colNames } +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF @@ -1481,6 +1381,9 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1504,7 +1407,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) } return case 'T': @@ -1610,7 +1513,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { cn.send(b) } -func (c *conn) processParameterStatus(r *readBuf) { +func (cn *conn) processParameterStatus(r *readBuf) { var err error param := r.string() @@ -1621,13 +1524,13 @@ func (c *conn) processParameterStatus(r *readBuf) { var minor int _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) if err == nil { - c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor } case "TimeZone": - c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) + cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { - c.parameterStatus.currentLocation = nil + cn.parameterStatus.currentLocation = nil } default: @@ -1635,8 +1538,8 @@ func (c *conn) processParameterStatus(r *readBuf) { } } -func (c *conn) processReadyForQuery(r *readBuf) { - c.txnStatus = transactionStatus(r.byte()) +func (cn *conn) processReadyForQuery(r *readBuf) { + cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() { @@ -1651,6 +1554,11 @@ func (cn *conn) readReadyForQuery() { } } +func (cn *conn) processBackendKeyData(r *readBuf) { + cn.processID = r.int32() + cn.secretKey = r.int32() +} + func (cn *conn) readParseResponse() { t, r := cn.recv1() switch t { @@ -1666,7 +1574,7 @@ func (cn *conn) readParseResponse() { } } -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { for { t, r := cn.recv1() switch t { @@ -1692,7 +1600,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { +func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { t, r := cn.recv1() switch t { case 'T': @@ -1788,31 +1696,33 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } } -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { +func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) colFmts = make([]format, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } return |