aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/database/sql
diff options
context:
space:
mode:
authorIan Lance Taylor <iant@golang.org>2017-01-14 00:05:42 +0000
committerIan Lance Taylor <ian@gcc.gnu.org>2017-01-14 00:05:42 +0000
commitc2047754c300b68c05d65faa8dc2925fe67b71b4 (patch)
treee183ae81a1f48a02945cb6de463a70c5be1b06f6 /libgo/go/database/sql
parent829afb8f05602bb31c9c597b24df7377fed4f059 (diff)
downloadgcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.zip
gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.tar.gz
gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.tar.bz2
libgo: update to Go 1.8 release candidate 1
Compiler changes: * Change map assignment to use mapassign and assign value directly. * Change string iteration to use decoderune, faster for ASCII strings. * Change makeslice to take int, and use makeslice64 for larger values. * Add new noverflow field to hmap struct used for maps. Unresolved problems, to be fixed later: * Commented out test in go/types/sizes_test.go that doesn't compile. * Commented out reflect.TestStructOf test for padding after zero-sized field. Reviewed-on: https://go-review.googlesource.com/35231 gotools/: Updates for Go 1.8rc1. * Makefile.am (go_cmd_go_files): Add bug.go. (s-zdefaultcc): Write defaultPkgConfig. * Makefile.in: Rebuild. From-SVN: r244456
Diffstat (limited to 'libgo/go/database/sql')
-rw-r--r--libgo/go/database/sql/convert.go91
-rw-r--r--libgo/go/database/sql/convert_test.go83
-rw-r--r--libgo/go/database/sql/ctxutil.go163
-rw-r--r--libgo/go/database/sql/driver/driver.go195
-rw-r--r--libgo/go/database/sql/driver/types.go42
-rw-r--r--libgo/go/database/sql/driver/types_test.go16
-rw-r--r--libgo/go/database/sql/fakedb_test.go375
-rw-r--r--libgo/go/database/sql/sql.go916
-rw-r--r--libgo/go/database/sql/sql_test.go555
9 files changed, 2069 insertions, 367 deletions
diff --git a/libgo/go/database/sql/convert.go b/libgo/go/database/sql/convert.go
index 99aed23..ea2f377 100644
--- a/libgo/go/database/sql/convert.go
+++ b/libgo/go/database/sql/convert.go
@@ -13,16 +13,36 @@ import (
"reflect"
"strconv"
"time"
+ "unicode"
+ "unicode/utf8"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
+func describeNamedValue(nv *driver.NamedValue) string {
+ if len(nv.Name) == 0 {
+ return fmt.Sprintf("$%d", nv.Ordinal)
+ }
+ return fmt.Sprintf("with name %q", nv.Name)
+}
+
+func validateNamedValueName(name string) error {
+ if len(name) == 0 {
+ return nil
+ }
+ r, _ := utf8.DecodeRuneInString(name)
+ if unicode.IsLetter(r) {
+ return nil
+ }
+ return fmt.Errorf("name %q does not begin with a letter", name)
+}
+
// driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
- dargs := make([]driver.Value, len(args))
+func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+ nvargs := make([]driver.NamedValue, len(args))
var si driver.Stmt
if ds != nil {
si = ds.si
@@ -33,26 +53,45 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
if !ok {
for n, arg := range args {
var err error
- dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+ nv := &nvargs[n]
+ nv.Ordinal = n + 1
+ if np, ok := arg.(NamedArg); ok {
+ if err := validateNamedValueName(np.Name); err != nil {
+ return nil, err
+ }
+ arg = np.Value
+ nvargs[n].Name = np.Name
+ }
+ nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
+
if err != nil {
- return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
+ return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err)
}
}
- return dargs, nil
+ return nvargs, nil
}
// Let the Stmt convert its own arguments.
for n, arg := range args {
+ nv := &nvargs[n]
+ nv.Ordinal = n + 1
+ if np, ok := arg.(NamedArg); ok {
+ if err := validateNamedValueName(np.Name); err != nil {
+ return nil, err
+ }
+ arg = np.Value
+ nv.Name = np.Name
+ }
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
- if svi, ok := arg.(driver.Valuer); ok {
- sv, err := svi.Value()
+ if vr, ok := arg.(driver.Valuer); ok {
+ sv, err := callValuerValue(vr)
if err != nil {
- return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err)
+ return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err)
}
if !driver.IsValue(sv) {
- return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv)
+ return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv)
}
arg = sv
}
@@ -66,18 +105,18 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
// same error.
var err error
ds.Lock()
- dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
+ nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock()
if err != nil {
- return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
+ return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
}
- if !driver.IsValue(dargs[n]) {
- return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
- arg, dargs[n])
+ if !driver.IsValue(nv.Value) {
+ return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T",
+ describeNamedValue(nv), arg, nv.Value)
}
}
- return dargs, nil
+ return nvargs, nil
}
// convertAssign copies to dest the value in src, converting it if possible.
@@ -330,3 +369,25 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
}
return
}
+
+var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+// Issue 8415.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This function is mirrored in the database/sql/driver package.
+func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
diff --git a/libgo/go/database/sql/convert_test.go b/libgo/go/database/sql/convert_test.go
index ab81f2f..4dfab1f 100644
--- a/libgo/go/database/sql/convert_test.go
+++ b/libgo/go/database/sql/convert_test.go
@@ -9,6 +9,7 @@ import (
"fmt"
"reflect"
"runtime"
+ "strings"
"testing"
"time"
)
@@ -389,3 +390,85 @@ func TestUserDefinedBytes(t *testing.T) {
t.Fatal("userDefinedBytes got potentially dirty driver memory")
}
}
+
+type Valuer_V string
+
+func (v Valuer_V) Value() (driver.Value, error) {
+ return strings.ToUpper(string(v)), nil
+}
+
+type Valuer_P string
+
+func (p *Valuer_P) Value() (driver.Value, error) {
+ if p == nil {
+ return "nil-to-str", nil
+ }
+ return strings.ToUpper(string(*p)), nil
+}
+
+func TestDriverArgs(t *testing.T) {
+ var nilValuerVPtr *Valuer_V
+ var nilValuerPPtr *Valuer_P
+ var nilStrPtr *string
+ tests := []struct {
+ args []interface{}
+ want []driver.NamedValue
+ }{
+ 0: {
+ args: []interface{}{Valuer_V("foo")},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: "FOO",
+ },
+ },
+ },
+ 1: {
+ args: []interface{}{nilValuerVPtr},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: nil,
+ },
+ },
+ },
+ 2: {
+ args: []interface{}{nilValuerPPtr},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: "nil-to-str",
+ },
+ },
+ },
+ 3: {
+ args: []interface{}{"plain-str"},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: "plain-str",
+ },
+ },
+ },
+ 4: {
+ args: []interface{}{nilStrPtr},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: nil,
+ },
+ },
+ },
+ }
+ for i, tt := range tests {
+ ds := new(driverStmt)
+ got, err := driverArgs(ds, tt.args)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
+ }
+ }
+}
diff --git a/libgo/go/database/sql/ctxutil.go b/libgo/go/database/sql/ctxutil.go
new file mode 100644
index 0000000..1071446
--- /dev/null
+++ b/libgo/go/database/sql/ctxutil.go
@@ -0,0 +1,163 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+)
+
+func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
+ if ciCtx, is := ci.(driver.ConnPrepareContext); is {
+ return ciCtx.PrepareContext(ctx, query)
+ }
+ si, err := ci.Prepare(query)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ si.Close()
+ return nil, ctx.Err()
+ }
+ }
+ return si, err
+}
+
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if execerCtx, is := execer.(driver.ExecerContext); is {
+ return execerCtx.ExecContext(ctx, query, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ resi, err := execer.Exec(query, dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ return resi, ctx.Err()
+ }
+ }
+ return resi, err
+}
+
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if queryerCtx, is := queryer.(driver.QueryerContext); is {
+ ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
+ return ret, err
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ rowsi, err := queryer.Query(query, dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ rowsi.Close()
+ return nil, ctx.Err()
+ }
+ }
+ return rowsi, err
+}
+
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if siCtx, is := si.(driver.StmtExecContext); is {
+ return siCtx.ExecContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ resi, err := si.Exec(dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ return resi, ctx.Err()
+ }
+ }
+ return resi, err
+}
+
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if siCtx, is := si.(driver.StmtQueryContext); is {
+ return siCtx.QueryContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ rowsi, err := si.Query(dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ rowsi.Close()
+ return nil, ctx.Err()
+ }
+ }
+ return rowsi, err
+}
+
+var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
+
+func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
+ if ciCtx, is := ci.(driver.ConnBeginTx); is {
+ dopts := driver.TxOptions{}
+ if opts != nil {
+ dopts.Isolation = driver.IsolationLevel(opts.Isolation)
+ dopts.ReadOnly = opts.ReadOnly
+ }
+ return ciCtx.BeginTx(ctx, dopts)
+ }
+
+ if ctx.Done() == context.Background().Done() {
+ return ci.Begin()
+ }
+
+ if opts != nil {
+ // Check the transaction level. If the transaction level is non-default
+ // then return an error here as the BeginTx driver value is not supported.
+ if opts.Isolation != LevelDefault {
+ return nil, errors.New("sql: driver does not support non-default isolation level")
+ }
+
+ // If a read-only transaction is requested return an error as the
+ // BeginTx driver value is not supported.
+ if opts.ReadOnly {
+ return nil, errors.New("sql: driver does not support read-only transactions")
+ }
+ }
+
+ txi, err := ci.Begin()
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ txi.Rollback()
+ return nil, ctx.Err()
+ }
+ }
+ return txi, err
+}
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+ dargs := make([]driver.Value, len(named))
+ for n, param := range named {
+ if len(param.Name) > 0 {
+ return nil, errors.New("sql: driver does not support the use of Named Parameters")
+ }
+ dargs[n] = param.Value
+ }
+ return dargs, nil
+}
diff --git a/libgo/go/database/sql/driver/driver.go b/libgo/go/database/sql/driver/driver.go
index 4dba85a..d66196f 100644
--- a/libgo/go/database/sql/driver/driver.go
+++ b/libgo/go/database/sql/driver/driver.go
@@ -8,7 +8,11 @@
// Most code should use package sql.
package driver
-import "errors"
+import (
+ "context"
+ "errors"
+ "reflect"
+)
// Value is a value that drivers must be able to handle.
// It is either nil or an instance of one of these types:
@@ -21,6 +25,21 @@ import "errors"
// time.Time
type Value interface{}
+// NamedValue holds both the value name and value.
+type NamedValue struct {
+ // If the Name is not empty it should be used for the parameter identifier and
+ // not the ordinal position.
+ //
+ // Name will not have a symbol prefix.
+ Name string
+
+ // Ordinal position of the parameter starting from one and is always set.
+ Ordinal int
+
+ // Value is the parameter value.
+ Value Value
+}
+
// Driver is the interface that must be implemented by a database
// driver.
type Driver interface {
@@ -54,6 +73,17 @@ var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// you shouldn't return ErrBadConn.
var ErrBadConn = errors.New("driver: bad connection")
+// Pinger is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement Pinger, the sql package's DB.Ping and
+// DB.PingContext will check if there is at least one Conn available.
+//
+// If Conn.Ping returns ErrBadConn, DB.Ping and DB.PingContext will remove
+// the Conn from pool.
+type Pinger interface {
+ Ping(ctx context.Context) error
+}
+
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the sql package's DB.Exec will
@@ -61,10 +91,25 @@ var ErrBadConn = errors.New("driver: bad connection")
// statement.
//
// Exec may return ErrSkip.
+//
+// Deprecated: Drivers should implement ExecerContext instead (or additionally).
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
+// ExecerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement ExecerContext, the sql package's DB.Exec will
+// first prepare a query, execute the statement, and then close the
+// statement.
+//
+// ExecerContext may return ErrSkip.
+//
+// ExecerContext must honor the context timeout and return when the context is canceled.
+type ExecerContext interface {
+ ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
+}
+
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Queryer, the sql package's DB.Query will
@@ -72,10 +117,25 @@ type Execer interface {
// statement.
//
// Query may return ErrSkip.
+//
+// Deprecated: Drivers should implement QueryerContext instead (or additionally).
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
+// QueryerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement QueryerContext, the sql package's DB.Query will
+// first prepare a query, execute the statement, and then close the
+// statement.
+//
+// QueryerContext may return ErrSkip.
+//
+// QueryerContext must honor the context timeout and return when the context is canceled.
+type QueryerContext interface {
+ QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
+}
+
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
@@ -95,9 +155,50 @@ type Conn interface {
Close() error
// Begin starts and returns a new transaction.
+ //
+ // Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
Begin() (Tx, error)
}
+// ConnPrepareContext enhances the Conn interface with context.
+type ConnPrepareContext interface {
+ // PrepareContext returns a prepared statement, bound to this connection.
+ // context is for the preparation of the statement,
+ // it must not store the context within the statement itself.
+ PrepareContext(ctx context.Context, query string) (Stmt, error)
+}
+
+// IsolationLevel is the transaction isolation level stored in TxOptions.
+//
+// This type should be considered identical to sql.IsolationLevel along
+// with any values defined on it.
+type IsolationLevel int
+
+// TxOptions holds the transaction options.
+//
+// This type should be considered identical to sql.TxOptions.
+type TxOptions struct {
+ Isolation IsolationLevel
+ ReadOnly bool
+}
+
+// ConnBeginTx enhances the Conn interface with context and TxOptions.
+type ConnBeginTx interface {
+ // BeginTx starts and returns a new transaction.
+ // If the context is canceled by the user the sql package will
+ // call Tx.Rollback before discarding and closing the connection.
+ //
+ // This must check opts.Isolation to determine if there is a set
+ // isolation level. If the driver does not support a non-default
+ // level and one is set or if there is a non-default isolation level
+ // that is not supported, an error must be returned.
+ //
+ // This must also check opts.ReadOnly to determine if the read-only
+ // value is true to either set the read-only transaction property if supported
+ // or return an error if it is not supported.
+ BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
+}
+
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
@@ -132,13 +233,35 @@ type Stmt interface {
// Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
+ //
+ // Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a
// SELECT.
+ //
+ // Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error)
}
+// StmtExecContext enhances the Stmt interface by providing Exec with context.
+type StmtExecContext interface {
+ // ExecContext executes a query that doesn't return rows, such
+ // as an INSERT or UPDATE.
+ //
+ // ExecContext must honor the context timeout and return when it is canceled.
+ ExecContext(ctx context.Context, args []NamedValue) (Result, error)
+}
+
+// StmtQueryContext enhances the Stmt interface by providing Query with context.
+type StmtQueryContext interface {
+ // QueryContext executes a query that may return rows, such as a
+ // SELECT.
+ //
+ // QueryContext must honor the context timeout and return when it is canceled.
+ QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
+}
+
// ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from
// any type to a driver Value.
@@ -169,6 +292,76 @@ type Rows interface {
Next(dest []Value) error
}
+// RowsNextResultSet extends the Rows interface by providing a way to signal
+// the driver to advance to the next result set.
+type RowsNextResultSet interface {
+ Rows
+
+ // HasNextResultSet is called at the end of the current result set and
+ // reports whether there is another result set after the current one.
+ HasNextResultSet() bool
+
+ // NextResultSet advances the driver to the next result set even
+ // if there are remaining rows in the current result set.
+ //
+ // NextResultSet should return io.EOF when there are no more result sets.
+ NextResultSet() error
+}
+
+// RowsColumnTypeScanType may be implemented by Rows. It should return
+// the value type that can be used to scan types into. For example, the database
+// column type "bigint" this should return "reflect.TypeOf(int64(0))".
+type RowsColumnTypeScanType interface {
+ Rows
+ ColumnTypeScanType(index int) reflect.Type
+}
+
+// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
+// database system type name without the length. Type names should be uppercase.
+// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
+// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
+// "TIMESTAMP".
+type RowsColumnTypeDatabaseTypeName interface {
+ Rows
+ ColumnTypeDatabaseTypeName(index int) string
+}
+
+// RowsColumnTypeLength may be implemented by Rows. It should return the length
+// of the column type if the column is a variable length type. If the column is
+// not a variable length type ok should return false.
+// If length is not limited other than system limits, it should return math.MaxInt64.
+// The following are examples of returned values for various types:
+// TEXT (math.MaxInt64, true)
+// varchar(10) (10, true)
+// nvarchar(10) (10, true)
+// decimal (0, false)
+// int (0, false)
+// bytea(30) (30, true)
+type RowsColumnTypeLength interface {
+ Rows
+ ColumnTypeLength(index int) (length int64, ok bool)
+}
+
+// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
+// be true if it is known the column may be null, or false if the column is known
+// to be not nullable.
+// If the column nullability is unknown, ok should be false.
+type RowsColumnTypeNullable interface {
+ Rows
+ ColumnTypeNullable(index int) (nullable, ok bool)
+}
+
+// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
+// the precision and scale for decimal types. If not applicable, ok should be false.
+// The following are examples of returned values for various types:
+// decimal(38, 4) (38, 4, true)
+// int (0, 0, false)
+// decimal (math.MaxInt64, math.MaxInt64, true)
+type RowsColumnTypePrecisionScale interface {
+ Rows
+ ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
+}
+
// Tx is a transaction.
type Tx interface {
Commit() error
diff --git a/libgo/go/database/sql/driver/types.go b/libgo/go/database/sql/driver/types.go
index e480e70..8b3cb6c 100644
--- a/libgo/go/database/sql/driver/types.go
+++ b/libgo/go/database/sql/driver/types.go
@@ -198,9 +198,9 @@ func IsScanValue(v interface{}) bool {
// Value method is used to return a Value. As a fallback, the provided
// argument's underlying type is used to convert it to a Value:
// underlying integer types are converted to int64, floats to float64,
-// and strings to []byte. If the argument is a nil pointer,
-// ConvertValue returns a nil Value. If the argument is a non-nil
-// pointer, it is dereferenced and ConvertValue is called
+// bool, string, and []byte to themselves. If the argument is a nil
+// pointer, ConvertValue returns a nil Value. If the argument is a
+// non-nil pointer, it is dereferenced and ConvertValue is called
// recursively. Other types are an error.
var DefaultParameterConverter defaultConverter
@@ -208,13 +208,35 @@ type defaultConverter struct{}
var _ ValueConverter = defaultConverter{}
+var valuerReflectType = reflect.TypeOf((*Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+// Issue 8415.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This function is mirrored in the database/sql package.
+func callValuerValue(vr Valuer) (v Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
+
func (defaultConverter) ConvertValue(v interface{}) (Value, error) {
if IsValue(v) {
return v, nil
}
- if svi, ok := v.(Valuer); ok {
- sv, err := svi.Value()
+ if vr, ok := v.(Valuer); ok {
+ sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
@@ -245,6 +267,16 @@ func (defaultConverter) ConvertValue(v interface{}) (Value, error) {
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
+ case reflect.Bool:
+ return rv.Bool(), nil
+ case reflect.Slice:
+ ek := rv.Type().Elem().Kind()
+ if ek == reflect.Uint8 {
+ return rv.Bytes(), nil
+ }
+ return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
+ case reflect.String:
+ return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
diff --git a/libgo/go/database/sql/driver/types_test.go b/libgo/go/database/sql/driver/types_test.go
index 1ce0ff0..0379bf8 100644
--- a/libgo/go/database/sql/driver/types_test.go
+++ b/libgo/go/database/sql/driver/types_test.go
@@ -20,6 +20,16 @@ type valueConverterTest struct {
var now = time.Now()
var answer int64 = 42
+type (
+ i int64
+ f float64
+ b bool
+ bs []byte
+ s string
+ t time.Time
+ is []int
+)
+
var valueConverterTests = []valueConverterTest{
{Bool, "true", true, ""},
{Bool, "True", true, ""},
@@ -41,6 +51,12 @@ var valueConverterTests = []valueConverterTest{
{DefaultParameterConverter, (*int64)(nil), nil, ""},
{DefaultParameterConverter, &answer, answer, ""},
{DefaultParameterConverter, &now, now, ""},
+ {DefaultParameterConverter, i(9), int64(9), ""},
+ {DefaultParameterConverter, f(0.1), float64(0.1), ""},
+ {DefaultParameterConverter, b(true), true, ""},
+ {DefaultParameterConverter, bs{1}, []byte{1}, ""},
+ {DefaultParameterConverter, s("a"), "a", ""},
+ {DefaultParameterConverter, is{1}, nil, "unsupported type driver.is, a slice of int"},
}
func TestValueConverters(t *testing.T) {
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go
index 5b238bf..4b15f5b 100644
--- a/libgo/go/database/sql/fakedb_test.go
+++ b/libgo/go/database/sql/fakedb_test.go
@@ -5,11 +5,13 @@
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
+ "reflect"
"sort"
"strconv"
"strings"
@@ -32,10 +34,16 @@ var _ = log.Printf
// where types are: "string", [u]int{8,16,32,64}, "bool"
// INSERT|<tablename>|col=val,col2=val2,col3=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
+// Any of these can be proceeded by WAIT|<duration>|, to cause the
+// named method on fakeStmt to sleep for the specified duration.
+//
+// Multiple of these can be combined when separated with a semicolon.
+//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
@@ -101,6 +109,12 @@ type fakeTx struct {
c *fakeConn
}
+type boundCol struct {
+ Column string
+ Placeholder string
+ Ordinal int
+}
+
type fakeStmt struct {
c *fakeConn
q string // just for debugging
@@ -108,6 +122,9 @@ type fakeStmt struct {
cmd string
table string
panic string
+ wait time.Duration
+
+ next *fakeStmt // used for returning multiple results.
closed bool
@@ -116,7 +133,7 @@ type fakeStmt struct {
colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
placeholders int // used by INSERT/SELECT: number of ? params
- whereCol []string // used by SELECT (all placeholders)
+ whereCol []boundCol // used by SELECT (all placeholders)
placeholderConverter []driver.ValueConverter // used by INSERT
}
@@ -335,18 +352,23 @@ func (c *fakeConn) Close() (err error) {
return nil
}
-func checkSubsetTypes(args []driver.Value) error {
- for n, arg := range args {
- switch arg.(type) {
+func checkSubsetTypes(args []driver.NamedValue) error {
+ for _, arg := range args {
+ switch arg.Value.(type) {
case int64, float64, bool, nil, []byte, string, time.Time:
default:
- return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+ return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
}
}
return nil
}
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ // Ensure that ExecContext is called if available.
+ panic("ExecContext was not called.")
+}
+
+func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
@@ -359,6 +381,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error
}
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ // Ensure that ExecContext is called if available.
+ panic("QueryContext was not called.")
+}
+
+func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
@@ -377,12 +404,13 @@ func errf(msg string, args ...interface{}) error {
// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
// (note that where columns must always contain ? marks,
// just a limitation for fakedb)
-func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 3 {
stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
}
stmt.table = parts[0]
+
stmt.colName = strings.Split(parts[1], ",")
for n, colspec := range strings.Split(parts[2], ",") {
if colspec == "" {
@@ -399,19 +427,19 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
}
- if value != "?" {
+ if !strings.HasPrefix(value, "?") {
stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column)
}
- stmt.whereCol = append(stmt.whereCol, column)
stmt.placeholders++
+ stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
}
return stmt, nil
}
// parts are table|col=type,col2=type2
-func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
@@ -430,7 +458,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
// parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@@ -450,7 +478,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
stmt.colName = append(stmt.colName, column)
- if value != "?" {
+ if !strings.HasPrefix(value, "?") {
var subsetVal interface{}
// Convert to driver subset type
switch ctype {
@@ -473,7 +501,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
} else {
stmt.placeholders++
stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
- stmt.colValue = append(stmt.colValue, "?")
+ stmt.colValue = append(stmt.colValue, value)
}
}
return stmt, nil
@@ -483,6 +511,10 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
var hookPrepareBadConn func() bool
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+ panic("use PrepareContext")
+}
+
+func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
c.numPrepare++
if c.db == nil {
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
@@ -492,38 +524,72 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
return nil, driver.ErrBadConn
}
- parts := strings.Split(query, "|")
- if len(parts) < 1 {
- return nil, errf("empty query")
- }
- stmt := &fakeStmt{q: query, c: c}
- if len(parts) >= 3 && parts[0] == "PANIC" {
- stmt.panic = parts[1]
- parts = parts[2:]
- }
- cmd := parts[0]
- stmt.cmd = cmd
- parts = parts[1:]
+ var firstStmt, prev *fakeStmt
+ for _, query := range strings.Split(query, ";") {
+ parts := strings.Split(query, "|")
+ if len(parts) < 1 {
+ return nil, errf("empty query")
+ }
+ stmt := &fakeStmt{q: query, c: c}
+ if firstStmt == nil {
+ firstStmt = stmt
+ }
+ if len(parts) >= 3 {
+ switch parts[0] {
+ case "PANIC":
+ stmt.panic = parts[1]
+ parts = parts[2:]
+ case "WAIT":
+ wait, err := time.ParseDuration(parts[1])
+ if err != nil {
+ return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
+ }
+ parts = parts[2:]
+ stmt.wait = wait
+ }
+ }
+ cmd := parts[0]
+ stmt.cmd = cmd
+ parts = parts[1:]
+
+ if stmt.wait > 0 {
+ wait := time.NewTimer(stmt.wait)
+ select {
+ case <-wait.C:
+ case <-ctx.Done():
+ wait.Stop()
+ return nil, ctx.Err()
+ }
+ }
- c.incrStat(&c.stmtsMade)
- switch cmd {
- case "WIPE":
- // Nothing
- case "SELECT":
- return c.prepareSelect(stmt, parts)
- case "CREATE":
- return c.prepareCreate(stmt, parts)
- case "INSERT":
- return c.prepareInsert(stmt, parts)
- case "NOSERT":
- // Do all the prep-work like for an INSERT but don't actually insert the row.
- // Used for some of the concurrent tests.
- return c.prepareInsert(stmt, parts)
- default:
- stmt.Close()
- return nil, errf("unsupported command type %q", cmd)
+ c.incrStat(&c.stmtsMade)
+ var err error
+ switch cmd {
+ case "WIPE":
+ // Nothing
+ case "SELECT":
+ stmt, err = c.prepareSelect(stmt, parts)
+ case "CREATE":
+ stmt, err = c.prepareCreate(stmt, parts)
+ case "INSERT":
+ stmt, err = c.prepareInsert(stmt, parts)
+ case "NOSERT":
+ // Do all the prep-work like for an INSERT but don't actually insert the row.
+ // Used for some of the concurrent tests.
+ stmt, err = c.prepareInsert(stmt, parts)
+ default:
+ stmt.Close()
+ return nil, errf("unsupported command type %q", cmd)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if prev != nil {
+ prev.next = stmt
+ }
+ prev = stmt
}
- return stmt, nil
+ return firstStmt, nil
}
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
@@ -550,6 +616,9 @@ func (s *fakeStmt) Close() error {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
}
+ if s.next != nil {
+ s.next.Close()
+ }
return nil
}
@@ -559,6 +628,9 @@ var errClosed = errors.New("fakedb: statement has been closed")
var hookExecBadConn func() bool
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
+ panic("Using ExecContext")
+}
+func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s.panic == "Exec" {
panic(s.panic)
}
@@ -575,6 +647,16 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, err
}
+ if s.wait > 0 {
+ time.Sleep(s.wait)
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+
db := s.c.db
switch s.cmd {
case "WIPE":
@@ -599,7 +681,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
// When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table.
-func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
+func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")
@@ -625,8 +707,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
}
var val interface{}
- if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
- val = args[argPos]
+ if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
+ if strvalue == "?" {
+ val = args[argPos].Value
+ } else {
+ // Assign value from argument placeholder name.
+ for _, a := range args {
+ if a.Name == strvalue[1:] {
+ val = a.Value
+ break
+ }
+ }
+ }
argPos++
} else {
val = s.colValue[n]
@@ -646,6 +738,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
var hookQueryBadConn func() bool
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
+ panic("Use QueryContext")
+}
+
+func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s.panic == "Query" {
panic(s.panic)
}
@@ -667,65 +763,101 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("error in pkg db; should only get here if size is correct")
}
- db.mu.Lock()
- t, ok := db.table(s.table)
- db.mu.Unlock()
- if !ok {
- return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
- }
+ setMRows := make([][]*row, 0, 1)
+ setColumns := make([][]string, 0, 1)
+ setColType := make([][]string, 0, 1)
- if s.table == "magicquery" {
- if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
- if args[0] == "sleep" {
- time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
- }
+ for {
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
- }
-
- t.mu.Lock()
- defer t.mu.Unlock()
- colIdx := make(map[string]int) // select column name -> column index in table
- for _, name := range s.colName {
- idx := t.columnIndex(name)
- if idx == -1 {
- return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+ if s.table == "magicquery" {
+ if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
+ if args[0].Value == "sleep" {
+ time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
+ }
+ }
}
- colIdx[name] = idx
- }
- mrows := []*row{}
-rows:
- for _, trow := range t.rows {
- // Process the where clause, skipping non-match rows. This is lazy
- // and just uses fmt.Sprintf("%v") to test equality. Good enough
- // for test code.
- for widx, wcol := range s.whereCol {
- idx := t.columnIndex(wcol)
+ t.mu.Lock()
+
+ colIdx := make(map[string]int) // select column name -> column index in table
+ for _, name := range s.colName {
+ idx := t.columnIndex(name)
if idx == -1 {
- return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ t.mu.Unlock()
+ return nil, fmt.Errorf("fakedb: unknown column name %q", name)
}
- tcol := trow.cols[idx]
- if bs, ok := tcol.([]byte); ok {
- // lazy hack to avoid sprintf %v on a []byte
- tcol = string(bs)
+ colIdx[name] = idx
+ }
+
+ mrows := []*row{}
+ rows:
+ for _, trow := range t.rows {
+ // Process the where clause, skipping non-match rows. This is lazy
+ // and just uses fmt.Sprintf("%v") to test equality. Good enough
+ // for test code.
+ for _, wcol := range s.whereCol {
+ idx := t.columnIndex(wcol.Column)
+ if idx == -1 {
+ t.mu.Unlock()
+ return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ }
+ tcol := trow.cols[idx]
+ if bs, ok := tcol.([]byte); ok {
+ // lazy hack to avoid sprintf %v on a []byte
+ tcol = string(bs)
+ }
+ var argValue interface{}
+ if wcol.Placeholder == "?" {
+ argValue = args[wcol.Ordinal-1].Value
+ } else {
+ // Assign arg value from placeholder name.
+ for _, a := range args {
+ if a.Name == wcol.Placeholder[1:] {
+ argValue = a.Value
+ break
+ }
+ }
+ }
+ if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
+ continue rows
+ }
}
- if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
- continue rows
+ mrow := &row{cols: make([]interface{}, len(s.colName))}
+ for seli, name := range s.colName {
+ mrow.cols[seli] = trow.cols[colIdx[name]]
}
+ mrows = append(mrows, mrow)
}
- mrow := &row{cols: make([]interface{}, len(s.colName))}
- for seli, name := range s.colName {
- mrow.cols[seli] = trow.cols[colIdx[name]]
+
+ var colType []string
+ for _, column := range s.colName {
+ colType = append(colType, t.coltype[t.columnIndex(column)])
}
- mrows = append(mrows, mrow)
+
+ t.mu.Unlock()
+
+ setMRows = append(setMRows, mrows)
+ setColumns = append(setColumns, s.colName)
+ setColType = append(setColType, colType)
+
+ if s.next == nil {
+ break
+ }
+ s = s.next
}
cursor := &rowsCursor{
- pos: -1,
- rows: mrows,
- cols: s.colName,
- errPos: -1,
+ posRow: -1,
+ rows: setMRows,
+ cols: setColumns,
+ colType: setColType,
+ errPos: -1,
}
return cursor, nil
}
@@ -760,10 +892,12 @@ func (tx *fakeTx) Rollback() error {
}
type rowsCursor struct {
- cols []string
- pos int
- rows []*row
- closed bool
+ cols [][]string
+ colType [][]string
+ posSet int
+ posRow int
+ rows [][]*row
+ closed bool
// errPos and err are for making Next return early with error.
errPos int
@@ -786,7 +920,11 @@ func (rc *rowsCursor) Close() error {
}
func (rc *rowsCursor) Columns() []string {
- return rc.cols
+ return rc.cols[rc.posSet]
+}
+
+func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
+ return colTypeToReflectType(rc.colType[rc.posSet][index])
}
var rowsCursorNextHook func(dest []driver.Value) error
@@ -799,14 +937,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed {
return errors.New("fakedb: cursor is closed")
}
- rc.pos++
- if rc.pos == rc.errPos {
+ rc.posRow++
+ if rc.posRow == rc.errPos {
return rc.err
}
- if rc.pos >= len(rc.rows) {
+ if rc.posRow >= len(rc.rows[rc.posSet]) {
return io.EOF // per interface spec
}
- for i, v := range rc.rows[rc.pos].cols {
+ for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
// TODO(bradfitz): convert to subset types? naah, I
// think the subset types should only be input to
// driver, but the sql package should be able to handle
@@ -831,6 +969,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
return nil
}
+func (rc *rowsCursor) HasNextResultSet() bool {
+ return rc.posSet < len(rc.rows)-1
+}
+
+func (rc *rowsCursor) NextResultSet() error {
+ if rc.HasNextResultSet() {
+ rc.posSet++
+ rc.posRow = -1
+ return nil
+ }
+ return io.EOF // Per interface spec.
+}
+
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
@@ -882,3 +1033,29 @@ func converterForType(typ string) driver.ValueConverter {
}
panic("invalid fakedb column type of " + typ)
}
+
+func colTypeToReflectType(typ string) reflect.Type {
+ switch typ {
+ case "bool":
+ return reflect.TypeOf(false)
+ case "nullbool":
+ return reflect.TypeOf(NullBool{})
+ case "int32":
+ return reflect.TypeOf(int32(0))
+ case "string":
+ return reflect.TypeOf("")
+ case "nullstring":
+ return reflect.TypeOf(NullString{})
+ case "int64":
+ return reflect.TypeOf(int64(0))
+ case "nullint64":
+ return reflect.TypeOf(NullInt64{})
+ case "float64":
+ return reflect.TypeOf(float64(0))
+ case "nullfloat64":
+ return reflect.TypeOf(NullFloat64{})
+ case "datetime":
+ return reflect.TypeOf(time.Time{})
+ }
+ panic("invalid fakedb column type of " + typ)
+}
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go
index 09de1c3..0fa7c34 100644
--- a/libgo/go/database/sql/sql.go
+++ b/libgo/go/database/sql/sql.go
@@ -8,15 +8,20 @@
// The sql package must be used in conjunction with a database driver.
// See https://golang.org/s/sqldrivers for a list of drivers.
//
-// For more usage examples, see the wiki page at
+// Drivers that do not support context cancelation will not return until
+// after the query is completed.
+//
+// For usage examples, see the wiki page at
// https://golang.org/s/sqlwiki.
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
"io"
+ "reflect"
"runtime"
"sort"
"sync"
@@ -66,6 +71,75 @@ func Drivers() []string {
return list
}
+// A NamedArg is a named argument. NamedArg values may be used as
+// arguments to Query or Exec and bind to the corresponding named
+// parameter in the SQL statement.
+//
+// For a more concise way to create NamedArg values, see
+// the Named function.
+type NamedArg struct {
+ _Named_Fields_Required struct{}
+
+ // Name is the name of the parameter placeholder.
+ //
+ // If empty, the ordinal position in the argument list will be
+ // used.
+ //
+ // Name must omit any symbol prefix.
+ Name string
+
+ // Value is the value of the parameter.
+ // It may be assigned the same value types as the query
+ // arguments.
+ Value interface{}
+}
+
+// Named provides a more concise way to create NamedArg values.
+//
+// Example usage:
+//
+// db.ExecContext(ctx, `
+// delete from Invoice
+// where
+// TimeCreated < @end
+// and TimeCreated >= @start;`,
+// sql.Named("start", startTime),
+// sql.Named("end", endTime),
+// )
+func Named(name string, value interface{}) NamedArg {
+ // This method exists because the go1compat promise
+ // doesn't guarantee that structs don't grow more fields,
+ // so unkeyed struct literals are a vet error. Thus, we don't
+ // want to allow sql.NamedArg{name, value}.
+ return NamedArg{Name: name, Value: value}
+}
+
+// IsolationLevel is the transaction isolation level used in TxOptions.
+type IsolationLevel int
+
+// Various isolation levels that drivers may support in BeginTx.
+// If a driver does not support a given isolation level an error may be returned.
+//
+// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels.
+const (
+ LevelDefault IsolationLevel = iota
+ LevelReadUncommitted
+ LevelReadCommitted
+ LevelWriteCommitted
+ LevelRepeatableRead
+ LevelSnapshot
+ LevelSerializable
+ LevelLinearizable
+)
+
+// TxOptions holds the transaction options to be used in DB.BeginTx.
+type TxOptions struct {
+ // Isolation is the transaction isolation level.
+ // If zero, the driver or database's default level is used.
+ Isolation IsolationLevel
+ ReadOnly bool
+}
+
// RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a Scan into a RawBytes, the slice is only
// valid until the next call to Next, Scan, or Close.
@@ -272,7 +346,7 @@ type driverConn struct {
ci driver.Conn
closed bool
finalClosed bool // ci.Close has been called
- openStmt map[driver.Stmt]bool
+ openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
@@ -284,10 +358,10 @@ func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err)
}
-func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
+func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
- delete(dc.openStmt, si)
+ delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
@@ -297,28 +371,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
- si, err := dc.ci.Prepare(query)
- if err == nil {
- // Track each driverConn's open statements, so we can close them
- // before closing the conn.
- //
- // TODO(bradfitz): let drivers opt out of caring about
- // stmt closes if the conn is about to close anyway? For now
- // do the safe thing, in case stmts need to be closed.
- //
- // TODO(bradfitz): after Go 1.2, closing driver.Stmts
- // should be moved to driverStmt, using unique
- // *driverStmts everywhere (including from
- // *Stmt.connStmt, instead of returning a
- // driver.Stmt), using driverStmt as a pointer
- // everywhere, and making it a finalCloser.
- if dc.openStmt == nil {
- dc.openStmt = make(map[driver.Stmt]bool)
- }
- dc.openStmt[si] = true
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
+ si, err := ctxDriverPrepare(ctx, dc.ci, query)
+ if err != nil {
+ return nil, err
}
- return si, err
+
+ // Track each driverConn's open statements, so we can close them
+ // before closing the conn.
+ //
+ // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
+ if dc.openStmt == nil {
+ dc.openStmt = make(map[*driverStmt]bool)
+ }
+ ds := &driverStmt{Locker: dc, si: si}
+ dc.openStmt[ds] = true
+
+ return ds, nil
}
// the dc.db's Mutex is held.
@@ -350,17 +419,26 @@ func (dc *driverConn) Close() error {
}
func (dc *driverConn) finalClose() error {
- dc.Lock()
+ var err error
- for si := range dc.openStmt {
- si.Close()
+ // Each *driverStmt has a lock to the dc. Copy the list out of the dc
+ // before calling close on each stmt.
+ var openStmt []*driverStmt
+ withLock(dc, func() {
+ openStmt = make([]*driverStmt, 0, len(dc.openStmt))
+ for ds := range dc.openStmt {
+ openStmt = append(openStmt, ds)
+ }
+ dc.openStmt = nil
+ })
+ for _, ds := range openStmt {
+ ds.Close()
}
- dc.openStmt = nil
-
- err := dc.ci.Close()
- dc.ci = nil
- dc.finalClosed = true
- dc.Unlock()
+ withLock(dc, func() {
+ dc.finalClosed = true
+ err = dc.ci.Close()
+ dc.ci = nil
+ })
dc.db.mu.Lock()
dc.db.numOpen--
@@ -377,12 +455,21 @@ func (dc *driverConn) finalClose() error {
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
+ closed bool
+ closeErr error // return value of previous Close call
}
+// Close ensures dirver.Stmt is only closed once any always returns the same
+// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
- return ds.si.Close()
+ if ds.closed {
+ return ds.closeErr
+ }
+ ds.closed = true
+ ds.closeErr = ds.si.Close()
+ return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
@@ -494,18 +581,36 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return db, nil
}
-// Ping verifies a connection to the database is still alive,
+// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
-func (db *DB) Ping() error {
- // TODO(bradfitz): give drivers an optional hook to implement
- // this in a more efficient or more reliable way, if they
- // have one.
- dc, err := db.conn(cachedOrNewConn)
+func (db *DB) PingContext(ctx context.Context) error {
+ var dc *driverConn
+ var err error
+
+ for i := 0; i < maxBadConnRetries; i++ {
+ dc, err = db.conn(ctx, cachedOrNewConn)
+ if err != driver.ErrBadConn {
+ break
+ }
+ }
+ if err == driver.ErrBadConn {
+ dc, err = db.conn(ctx, alwaysNewConn)
+ }
if err != nil {
return err
}
- db.putConn(dc, nil)
- return nil
+
+ if pinger, ok := dc.ci.(driver.Pinger); ok {
+ err = pinger.Ping(ctx)
+ }
+ db.putConn(dc, err)
+ return err
+}
+
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) Ping() error {
+ return db.PingContext(context.Background())
}
// Close closes the database, releasing any open resources.
@@ -777,12 +882,19 @@ type connRequest struct {
var errDBClosed = errors.New("sql: database is closed")
// conn returns a newly-opened or cached *driverConn.
-func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
+ // Check if the context is expired.
+ select {
+ default:
+ case <-ctx.Done():
+ db.mu.Unlock()
+ return nil, ctx.Err()
+ }
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
@@ -808,15 +920,21 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
req := make(chan connRequest, 1)
db.connRequests = append(db.connRequests, req)
db.mu.Unlock()
- ret, ok := <-req
- if !ok {
- return nil, errDBClosed
- }
- if ret.err == nil && ret.conn.expired(lifetime) {
- ret.conn.Close()
- return nil, driver.ErrBadConn
+
+ // Timeout the connection request with the context.
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case ret, ok := <-req:
+ if !ok {
+ return nil, errDBClosed
+ }
+ if ret.err == nil && ret.conn.expired(lifetime) {
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
+ return ret.conn, ret.err
}
- return ret.conn, ret.err
}
db.numOpen++ // optimistically
@@ -844,21 +962,22 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
-// noteUnusedDriverStatement notes that si is no longer used and should
+// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
-func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
+func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
- si.Close()
+ ds.Close()
})
} else {
c.Lock()
- defer c.Unlock()
- if !c.finalClosed {
- si.Close()
+ fc := c.finalClosed
+ c.Unlock()
+ if !fc {
+ ds.Close()
}
}
}
@@ -952,40 +1071,53 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
// connection to be opened.
const maxBadConnRetries = 2
-// Prepare creates a prepared statement for later queries or executions.
+// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
-func (db *DB) Prepare(query string) (*Stmt, error) {
+//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
var err error
for i := 0; i < maxBadConnRetries; i++ {
- stmt, err = db.prepare(query, cachedOrNewConn)
+ stmt, err = db.prepare(ctx, query, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.prepare(query, alwaysNewConn)
+ return db.prepare(ctx, query, alwaysNewConn)
}
return stmt, err
}
-func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
+// Prepare creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+func (db *DB) Prepare(query string) (*Stmt, error) {
+ return db.PrepareContext(context.Background(), query)
+}
+
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
- dc, err := db.conn(strategy)
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- dc.Lock()
- si, err := dc.prepareLocked(query)
- dc.Unlock()
+ var ds *driverStmt
+ withLock(dc, func() {
+ ds, err = dc.prepareLocked(ctx, query)
+ })
if err != nil {
db.putConn(dc, err)
return nil, err
@@ -993,7 +1125,7 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
stmt := &Stmt{
db: db,
query: query,
- css: []connStmt{{dc, si}},
+ css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed),
}
db.addDep(stmt, stmt)
@@ -1001,25 +1133,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
return stmt, nil
}
-// Exec executes a query without returning any rows.
+// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
-func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
var res Result
var err error
for i := 0; i < maxBadConnRetries; i++ {
- res, err = db.exec(query, args, cachedOrNewConn)
+ res, err = db.exec(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.exec(query, args, alwaysNewConn)
+ return db.exec(ctx, query, args, alwaysNewConn)
}
return res, err
}
-func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
- dc, err := db.conn(strategy)
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+ return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
@@ -1028,13 +1166,15 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
}()
if execer, ok := dc.ci.(driver.Execer); ok {
- dargs, err := driverArgs(nil, args)
+ var dargs []driver.NamedValue
+ dargs, err = driverArgs(nil, args)
if err != nil {
return nil, err
}
- dc.Lock()
- resi, err := execer.Exec(query, dargs)
- dc.Unlock()
+ var resi driver.Result
+ withLock(dc, func() {
+ resi, err = ctxDriverExec(ctx, execer, query, dargs)
+ })
if err != driver.ErrSkip {
if err != nil {
return nil, err
@@ -1043,54 +1183,63 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
}
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
- return resultFromStatement(driverStmt{dc, si}, args...)
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
+ return resultFromStatement(ctx, ds, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
+// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
-func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
var rows *Rows
var err error
for i := 0; i < maxBadConnRetries; i++ {
- rows, err = db.query(query, args, cachedOrNewConn)
+ rows, err = db.query(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.query(query, args, alwaysNewConn)
+ return db.query(ctx, query, args, alwaysNewConn)
}
return rows, err
}
-func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
- ci, err := db.conn(strategy)
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+ return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+ ci, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- return db.queryConn(ci, ci.releaseConn, query, args)
+ return db.queryConn(ctx, ci, ci.releaseConn, query, args)
}
// queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function.
-func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
releaseConn(err)
return nil, err
}
- dc.Lock()
- rowsi, err := queryer.Query(query, dargs)
- dc.Unlock()
+ var rowsi driver.Rows
+ withLock(dc, func() {
+ rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
+ })
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
@@ -1103,24 +1252,25 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
releaseConn: releaseConn,
rowsi: rowsi,
}
+ rows.initContextClose(ctx)
return rows, nil
}
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ var err error
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
releaseConn(err)
return nil, err
}
- ds := driverStmt{dc, si}
- rowsi, err := rowsiFromStatement(ds, args...)
+ ds := &driverStmt{Locker: dc, si: si}
+ rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
- dc.Lock()
- si.Close()
- dc.Unlock()
+ ds.Close()
releaseConn(err)
return nil, err
}
@@ -1131,53 +1281,93 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
- closeStmt: si,
+ closeStmt: ds,
}
+ rows.initContextClose(ctx)
return rows, nil
}
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := db.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
- rows, err := db.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return db.QueryRowContext(context.Background(), query, args...)
}
-// Begin starts a transaction. The isolation level is dependent on
-// the driver.
-func (db *DB) Begin() (*Tx, error) {
+// BeginTx starts a transaction.
+//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginTx is canceled.
+//
+// The provided TxOptions is optional and may be nil if defaults should be used.
+// If a non-default isolation level is used that the driver doesn't support,
+// an error will be returned.
+func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
var tx *Tx
var err error
for i := 0; i < maxBadConnRetries; i++ {
- tx, err = db.begin(cachedOrNewConn)
+ tx, err = db.begin(ctx, opts, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.begin(alwaysNewConn)
+ return db.begin(ctx, opts, alwaysNewConn)
}
return tx, err
}
-func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
- dc, err := db.conn(strategy)
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+func (db *DB) Begin() (*Tx, error) {
+ return db.BeginTx(context.Background(), nil)
+}
+
+func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- dc.Lock()
- txi, err := dc.ci.Begin()
- dc.Unlock()
+ var txi driver.Tx
+ withLock(dc, func() {
+ txi, err = ctxDriverBegin(ctx, opts, dc.ci)
+ })
if err != nil {
db.putConn(dc, err)
return nil, err
}
- return &Tx{
- db: db,
- dc: dc,
- txi: txi,
- }, nil
+
+ // Schedule the transaction to rollback when the context is cancelled.
+ // The cancel function in Tx will be called after done is set to true.
+ ctx, cancel := context.WithCancel(ctx)
+ tx = &Tx{
+ db: db,
+ dc: dc,
+ txi: txi,
+ cancel: cancel,
+ ctx: ctx,
+ }
+ go func(tx *Tx) {
+ select {
+ case <-tx.ctx.Done():
+ if !tx.isDone() {
+ // Discard and close the connection used to ensure the transaction
+ // is closed and the resources are released.
+ tx.rollback(true)
+ }
+ }
+ }(tx)
+ return tx, nil
}
// Driver returns the database's underlying driver.
@@ -1203,10 +1393,11 @@ type Tx struct {
dc *driverConn
txi driver.Tx
- // done transitions from false to true exactly once, on Commit
+ // done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
- done bool
+ // Use atomic operations on value when checking value.
+ done int32
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
@@ -1214,22 +1405,33 @@ type Tx struct {
sync.Mutex
v []*Stmt
}
+
+ // cancel is called after done transitions from false to true.
+ cancel func()
+
+ // ctx lives for the life of the transaction.
+ ctx context.Context
+}
+
+func (tx *Tx) isDone() bool {
+ return atomic.LoadInt32(&tx.done) != 0
}
+// ErrTxDone is returned by any operation that is performed on a transaction
+// that has already been committed or rolled back.
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
+// close returns the connection to the pool and
+// must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) {
- if tx.done {
- panic("double close") // internal error
- }
- tx.done = true
tx.db.putConn(tx.dc, err)
+ tx.cancel()
tx.dc = nil
tx.txi = nil
}
-func (tx *Tx) grabConn() (*driverConn, error) {
- if tx.done {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
+ if tx.isDone() {
return nil, ErrTxDone
}
return tx.dc, nil
@@ -1238,20 +1440,26 @@ func (tx *Tx) grabConn() (*driverConn, error) {
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
+ defer tx.stmts.Unlock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
- tx.stmts.Unlock()
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if tx.done {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTxDone
}
- tx.dc.Lock()
- err := tx.txi.Commit()
- tx.dc.Unlock()
+ select {
+ default:
+ case <-tx.ctx.Done():
+ return tx.ctx.Err()
+ }
+ var err error
+ withLock(tx.dc, func() {
+ err = tx.txi.Commit()
+ })
if err != driver.ErrBadConn {
tx.closePrepared()
}
@@ -1259,28 +1467,42 @@ func (tx *Tx) Commit() error {
return err
}
-// Rollback aborts the transaction.
-func (tx *Tx) Rollback() error {
- if tx.done {
+// rollback aborts the transaction and optionally forces the pool to discard
+// the connection.
+func (tx *Tx) rollback(discardConn bool) error {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTxDone
}
- tx.dc.Lock()
- err := tx.txi.Rollback()
- tx.dc.Unlock()
+ var err error
+ withLock(tx.dc, func() {
+ err = tx.txi.Rollback()
+ })
if err != driver.ErrBadConn {
tx.closePrepared()
}
+ if discardConn {
+ err = driver.ErrBadConn
+ }
tx.close(err)
return err
}
+// Rollback aborts the transaction.
+func (tx *Tx) Rollback() error {
+ return tx.rollback(false)
+}
+
// Prepare creates a prepared statement for use within a transaction.
//
-// The returned statement operates within the transaction and can no longer
-// be used once the transaction has been committed or rolled back.
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
-func (tx *Tx) Prepare(query string) (*Stmt, error) {
+//
+// The provided context will be used for the preparation of the context, not
+// for the execution of the returned statement. The returned statement
+// will run in the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
@@ -1294,14 +1516,15 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
return nil, err
}
@@ -1309,7 +1532,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
stmt := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1321,7 +1544,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return stmt, nil
}
-// Stmt returns a transaction-specific prepared statement from
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+ return tx.PrepareContext(context.Background(), query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
@@ -1329,11 +1562,11 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// ...
// tx, err := db.Begin()
// ...
-// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
-// The returned statement operates within the transaction and can no longer
-// be used once the transaction has been committed or rolled back.
-func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
@@ -1342,17 +1575,18 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
- dc.Lock()
- si, err := dc.ci.Prepare(stmt.query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
+ })
txs := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1365,10 +1599,26 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return txs
}
-// Exec executes a query that doesn't return rows.
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+// ...
+// tx, err := db.Begin()
+// ...
+// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+ return tx.StmtContext(context.Background(), stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
- dc, err := tx.grabConn()
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
@@ -1378,9 +1628,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
if err != nil {
return nil, err
}
- dc.Lock()
- resi, err := execer.Exec(query, dargs)
- dc.Unlock()
+ var resi driver.Result
+ withLock(dc, func() {
+ resi, err = ctxDriverExec(ctx, execer, query, dargs)
+ })
if err == nil {
return driverResult{dc, resi}, nil
}
@@ -1389,39 +1640,59 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
}
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
- return resultFromStatement(driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, ds, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
-func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
- dc, err := tx.grabConn()
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+ return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
releaseConn := func(error) {}
- return tx.db.queryConn(dc, releaseConn, query, args)
+ return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
+ return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := tx.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
- rows, err := tx.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return tx.QueryRowContext(context.Background(), query, args...)
}
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
- si driver.Stmt
+ ds *driverStmt
}
// Stmt is a prepared statement.
@@ -1436,7 +1707,7 @@ type Stmt struct {
// If in a transaction, else both nil:
tx *Tx
- txsi *driverStmt
+ txds *driverStmt
mu sync.Mutex // protects the rest of the fields
closed bool
@@ -1452,15 +1723,15 @@ type Stmt struct {
lastNumClosed uint64
}
-// Exec executes a prepared statement with the given arguments and
+// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
-func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ _, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1468,7 +1739,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, err
}
- res, err = resultFromStatement(driverStmt{dc, si}, args...)
+ res, err = resultFromStatement(ctx, ds, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
@@ -1477,13 +1748,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, driver.ErrBadConn
}
-func driverNumInput(ds driverStmt) int {
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+ return s.ExecContext(context.Background(), args...)
+}
+
+func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
-func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
@@ -1493,14 +1770,15 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
ds.Lock()
defer ds.Unlock()
- resi, err := ds.si.Exec(dargs)
+
+ resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
@@ -1536,7 +1814,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -1551,19 +1829,18 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
// transaction was created on.
if s.tx != nil {
s.mu.Unlock()
- ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+ ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
releaseConn = func(error) {}
- return ci, releaseConn, s.txsi.si, nil
+ return ci, releaseConn, s.txds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
- // TODO(bradfitz): or always wait for one? make configurable later?
- dc, err := s.db.conn(cachedOrNewConn)
+ dc, err := s.db.conn(ctx, cachedOrNewConn)
if err != nil {
return nil, nil, nil, err
}
@@ -1572,36 +1849,36 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
- return dc, dc.releaseConn, v.si, nil
+ return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
- dc.Lock()
- si, err = dc.prepareLocked(s.query)
- dc.Unlock()
+ withLock(dc, func() {
+ ds, err = dc.prepareLocked(ctx, s.query)
+ })
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
s.mu.Lock()
- cs := connStmt{dc, si}
+ cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
- return dc, dc.releaseConn, si, nil
+ return dc, dc.releaseConn, ds, nil
}
-// Query executes a prepared query statement with the given arguments
+// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
-func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1609,7 +1886,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return nil, err
}
- rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+ rowsi, err = rowsiFromStatement(ctx, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
@@ -1618,6 +1895,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
rowsi: rowsi,
// releaseConn set below
}
+ rows.initContextClose(ctx)
s.db.addDep(s, rows)
rows.releaseConn = func(err error) {
releaseConn(err)
@@ -1634,10 +1912,17 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return nil, driver.ErrBadConn
}
-func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
- ds.Lock()
- want := ds.si.NumInput()
- ds.Unlock()
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+ return s.QueryContext(context.Background(), args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
+ var want int
+ withLock(ds, func() {
+ want = ds.si.NumInput()
+ })
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
@@ -1646,21 +1931,22 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
ds.Lock()
- rowsi, err := ds.si.Query(dargs)
- ds.Unlock()
+ defer ds.Unlock()
+
+ rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return rowsi, nil
}
-// QueryRow executes a prepared query statement with the given arguments.
+// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
@@ -1670,15 +1956,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
// Example usage:
//
// var name string
-// err := nameByUseridStmt.QueryRow(id).Scan(&name)
-func (s *Stmt) QueryRow(args ...interface{}) *Row {
- rows, err := s.Query(args...)
+// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
+ rows, err := s.QueryContext(ctx, args...)
if err != nil {
return &Row{err: err}
}
return &Row{rows: rows}
}
+// QueryRow executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+// var name string
+// err := nameByUseridStmt.QueryRow(id).Scan(&name)
+func (s *Stmt) QueryRow(args ...interface{}) *Row {
+ return s.QueryRowContext(context.Background(), args...)
+}
+
// Close closes the statement.
func (s *Stmt) Close() error {
s.closemu.Lock()
@@ -1693,13 +1994,11 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
+ s.mu.Unlock()
if s.tx != nil {
- err := s.txsi.Close()
- s.mu.Unlock()
- return err
+ return s.txds.Close()
}
- s.mu.Unlock()
return s.db.removeDep(s, s)
}
@@ -1709,8 +2008,8 @@ func (s *Stmt) finalClose() error {
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
- s.db.noteUnusedDriverStatement(v.dc, v.si)
- v.dc.removeOpenStmt(v.si)
+ s.db.noteUnusedDriverStatement(v.dc, v.ds)
+ v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
@@ -1736,10 +2035,28 @@ type Rows struct {
releaseConn func(error)
rowsi driver.Rows
- closed bool
+ // closed value is 1 when the Rows is closed.
+ // Use atomic operations on value when checking value.
+ closed int32
+ ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
- closeStmt driver.Stmt // if non-nil, statement to Close on close
+ closeStmt *driverStmt // if non-nil, statement to Close on close
+}
+
+func (rs *Rows) initContextClose(ctx context.Context) {
+ if ctx.Done() == context.Background().Done() {
+ return
+ }
+
+ rs.ctxClose = make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ rs.Close()
+ case <-rs.ctxClose:
+ }
+ }()
}
// Next prepares the next result row for reading with the Scan method. It
@@ -1749,7 +2066,7 @@ type Rows struct {
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *Rows) Next() bool {
- if rs.closed {
+ if rs.isClosed() {
return false
}
if rs.lastcols == nil {
@@ -1757,6 +2074,47 @@ func (rs *Rows) Next() bool {
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
+ // Close the connection if there is a driver error.
+ if rs.lasterr != io.EOF {
+ rs.Close()
+ return false
+ }
+ nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+ if !ok {
+ rs.Close()
+ return false
+ }
+ // The driver is at the end of the current result set.
+ // Test to see if there is another result set after the current one.
+ // Only close Rows if there is no further result sets to read.
+ if !nextResultSet.HasNextResultSet() {
+ rs.Close()
+ }
+ return false
+ }
+ return true
+}
+
+// NextResultSet prepares the next result set for reading. It returns true if
+// there is further result sets, or false if there is no further result set
+// or if there is an error advancing to it. The Err method should be consulted
+// to distinguish between the two cases.
+//
+// After calling NextResultSet, the Next method should always be called before
+// scanning. If there are further result sets they may not have rows in the result
+// set.
+func (rs *Rows) NextResultSet() bool {
+ if rs.isClosed() {
+ return false
+ }
+ rs.lastcols = nil
+ nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+ if !ok {
+ rs.Close()
+ return false
+ }
+ rs.lasterr = nextResultSet.NextResultSet()
+ if rs.lasterr != nil {
rs.Close()
return false
}
@@ -1776,7 +2134,7 @@ func (rs *Rows) Err() error {
// Columns returns an error if the rows are closed, or if the rows
// are from QueryRow and there was a deferred error.
func (rs *Rows) Columns() ([]string, error) {
- if rs.closed {
+ if rs.isClosed() {
return nil, errors.New("sql: Rows are closed")
}
if rs.rowsi == nil {
@@ -1785,6 +2143,107 @@ func (rs *Rows) Columns() ([]string, error) {
return rs.rowsi.Columns(), nil
}
+// ColumnTypes returns column information such as column type, length,
+// and nullable. Some information may not be available from some drivers.
+func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
+ if rs.isClosed() {
+ return nil, errors.New("sql: Rows are closed")
+ }
+ if rs.rowsi == nil {
+ return nil, errors.New("sql: no Rows available")
+ }
+ return rowsColumnInfoSetup(rs.rowsi), nil
+}
+
+// ColumnType contains the name and type of a column.
+type ColumnType struct {
+ name string
+
+ hasNullable bool
+ hasLength bool
+ hasPrecisionScale bool
+
+ nullable bool
+ length int64
+ databaseType string
+ precision int64
+ scale int64
+ scanType reflect.Type
+}
+
+// Name returns the name or alias of the column.
+func (ci *ColumnType) Name() string {
+ return ci.name
+}
+
+// Length returns the column type length for variable length column types such
+// as text and binary field types. If the type length is unbounded the value will
+// be math.MaxInt64 (any database limits will still apply).
+// If the column type is not variable length, such as an int, or if not supported
+// by the driver ok is false.
+func (ci *ColumnType) Length() (length int64, ok bool) {
+ return ci.length, ci.hasLength
+}
+
+// DecimalSize returns the scale and precision of a decimal type.
+// If not applicable or if not supported ok is false.
+func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
+ return ci.precision, ci.scale, ci.hasPrecisionScale
+}
+
+// ScanType returns a Go type suitable for scanning into using Rows.Scan.
+// If a driver does not support this property ScanType will return
+// the type of an empty interface.
+func (ci *ColumnType) ScanType() reflect.Type {
+ return ci.scanType
+}
+
+// Nullable returns whether the column may be null.
+// If a driver does not support this property ok will be false.
+func (ci *ColumnType) Nullable() (nullable, ok bool) {
+ return ci.nullable, ci.hasNullable
+}
+
+// DatabaseTypeName returns the database system name of the column type. If an empty
+// string is returned the driver type name is not supported.
+// Consult your driver documentation for a list of driver data types. Length specifiers
+// are not included.
+// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT".
+func (ci *ColumnType) DatabaseTypeName() string {
+ return ci.databaseType
+}
+
+func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+ names := rowsi.Columns()
+
+ list := make([]*ColumnType, len(names))
+ for i := range list {
+ ci := &ColumnType{
+ name: names[i],
+ }
+ list[i] = ci
+
+ if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
+ ci.scanType = prop.ColumnTypeScanType(i)
+ } else {
+ ci.scanType = reflect.TypeOf(new(interface{})).Elem()
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
+ ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
+ ci.length, ci.hasLength = prop.ColumnTypeLength(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
+ ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
+ ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
+ }
+ }
+ return list
+}
+
// Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the
// number of columns in Rows.
@@ -1837,7 +2296,7 @@ func (rs *Rows) Columns() ([]string, error) {
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
func (rs *Rows) Scan(dest ...interface{}) error {
- if rs.closed {
+ if rs.isClosed() {
return errors.New("sql: Rows are closed")
}
if rs.lastcols == nil {
@@ -1857,14 +2316,21 @@ func (rs *Rows) Scan(dest ...interface{}) error {
var rowsCloseHook func(*Rows, *error)
-// Close closes the Rows, preventing further enumeration. If Next returns
-// false, the Rows are closed automatically and it will suffice to check the
+func (rs *Rows) isClosed() bool {
+ return atomic.LoadInt32(&rs.closed) != 0
+}
+
+// Close closes the Rows, preventing further enumeration. If Next is called
+// and returns false and there are no further result sets,
+// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error {
- if rs.closed {
+ if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
return nil
}
- rs.closed = true
+ if rs.ctxClose != nil {
+ close(rs.ctxClose)
+ }
err := rs.rowsi.Close()
if fn := rowsCloseHook; fn != nil {
fn(rs, &err)
diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go
index 08df0c7..63e1292 100644
--- a/libgo/go/database/sql/sql_test.go
+++ b/libgo/go/database/sql/sql_test.go
@@ -5,6 +5,7 @@
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
@@ -23,6 +24,17 @@ func init() {
c *driverConn
}
freedFrom := make(map[dbConn]string)
+ var mu sync.Mutex
+ getFreedFrom := func(c dbConn) string {
+ mu.Lock()
+ defer mu.Unlock()
+ return freedFrom[c]
+ }
+ setFreedFrom := func(c dbConn, s string) {
+ mu.Lock()
+ defer mu.Unlock()
+ freedFrom[c] = s
+ }
putConnHook = func(db *DB, c *driverConn) {
idx := -1
for i, v := range db.freeConn {
@@ -35,10 +47,10 @@ func init() {
// print before panic, as panic may get lost due to conflicting panic
// (all goroutines asleep) elsewhere, since we might not unlock
// the mutex in freeConn here.
- println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
+ println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack())
panic("double free of conn.")
}
- freedFrom[dbConn{db, c}] = stack()
+ setFreedFrom(dbConn{db, c}, stack())
}
}
@@ -140,10 +152,7 @@ func closeDB(t testing.TB, db *DB) {
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
- db.mu.Lock()
- count := db.numOpen
- db.mu.Unlock()
- if count != 0 {
+ if count := db.numOpenConns(); count != 0 {
t.Fatalf("%d connections still open after closing DB", count)
}
}
@@ -182,6 +191,12 @@ func (db *DB) numFreeConns() int {
return len(db.freeConn)
}
+func (db *DB) numOpenConns() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return db.numOpen
+}
+
// clearAllConns closes all connections in db.
func (db *DB) clearAllConns(t *testing.T) {
db.SetMaxIdleConns(0)
@@ -260,6 +275,257 @@ func TestQuery(t *testing.T) {
}
}
+func TestQueryContext(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ index := 0
+ for rows.Next() {
+ if index == 2 {
+ cancel()
+ time.Sleep(10 * time.Millisecond)
+ }
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ if index == 2 {
+ break
+ }
+ t.Fatalf("Scan: %v", err)
+ }
+ if index == 2 && err == nil {
+ t.Fatal("expected an error on last scan")
+ }
+ got = append(got, r)
+ index++
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
+ deadline := time.Now().Add(waitFor)
+ for time.Now().Before(deadline) {
+ if fn() {
+ return true
+ }
+ time.Sleep(checkEvery)
+ }
+ return false
+}
+
+func TestQueryContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err := db.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ // Verify closed rows connection after error condition.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestTxContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err = tx.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ var numFree int
+ if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+ numFree = db.numFreeConns()
+ return numFree == 0
+ }) {
+ t.Fatalf("free conns after hitting EOF = %d; want 0", numFree)
+ }
+
+ // Ensure the dropped connection allows more connections to be made.
+ // Checked on DB Close.
+ waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+ return db.numOpenConns() == 0
+ })
+}
+
+func TestMultiResultSetQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row1 struct {
+ age int
+ name string
+ }
+ type row2 struct {
+ name string
+ }
+ got1 := []row1{}
+ for rows.Next() {
+ var r row1
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got1 = append(got1, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want1 := []row1{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ {age: 3, name: "Chris"},
+ }
+ if !reflect.DeepEqual(got1, want1) {
+ t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
+ }
+
+ if !rows.NextResultSet() {
+ t.Errorf("expected another result set")
+ }
+
+ got2 := []row2{}
+ for rows.Next() {
+ var r row2
+ err = rows.Scan(&r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got2 = append(got2, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want2 := []row2{
+ {name: "Alice"},
+ {name: "Bob"},
+ {name: "Chris"},
+ }
+ if !reflect.DeepEqual(got2, want2) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
+ }
+ if rows.NextResultSet() {
+ t.Errorf("expected no more result sets")
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestQueryNamedArg(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query(
+ // Ensure the name and age parameters only match on placeholder name, not position.
+ "SELECT|people|age,name|name=?name,age=?age",
+ Named("age", 2),
+ Named("name", "Bob"),
+ )
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ for rows.Next() {
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got = append(got, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -317,6 +583,56 @@ func TestRowsColumns(t *testing.T) {
}
}
+func TestRowsColumnTypes(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ tt, err := rows.ColumnTypes()
+ if err != nil {
+ t.Fatalf("ColumnTypes: %v", err)
+ }
+
+ types := make([]reflect.Type, len(tt))
+ for i, tp := range tt {
+ st := tp.ScanType()
+ if st == nil {
+ t.Errorf("scantype is null for column %q", tp.Name())
+ continue
+ }
+ types[i] = st
+ }
+ values := make([]interface{}, len(tt))
+ for i := range values {
+ values[i] = reflect.New(types[i]).Interface()
+ }
+ ct := 0
+ for rows.Next() {
+ err = rows.Scan(values...)
+ if err != nil {
+ t.Fatalf("failed to scan values in %v", err)
+ }
+ ct++
+ if ct == 0 {
+ if values[0].(string) != "Bob" {
+ t.Errorf("Expected Bob, got %v", values[0])
+ }
+ if values[1].(int) != 2 {
+ t.Errorf("Expected 2, got %v", values[1])
+ }
+ }
+ }
+ if ct != 3 {
+ t.Errorf("expected 3 rows, got %d", ct)
+ }
+
+ if err := rows.Close(); err != nil {
+ t.Errorf("error closing rows: %s", err)
+ }
+}
+
func TestQueryRow(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -367,6 +683,37 @@ func TestQueryRow(t *testing.T) {
}
}
+func TestTxRollbackCommitErr(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Errorf("expected nil error from Rollback; got %v", err)
+ }
+ err = tx.Commit()
+ if err != ErrTxDone {
+ t.Errorf("expected %q from Commit; got %q", ErrTxDone, err)
+ }
+
+ tx, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Errorf("expected nil error from Commit; got %v", err)
+ }
+ err = tx.Rollback()
+ if err != ErrTxDone {
+ t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err)
+ }
+}
+
func TestStatementErrorAfterClose(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -439,7 +786,7 @@ func TestStatementClose(t *testing.T) {
msg string
}{
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
- {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+ {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
}
for _, test := range tests {
if err := test.stmt.Close(); err != want {
@@ -513,8 +860,8 @@ func TestExec(t *testing.T) {
{[]interface{}{7, 9}, ""},
// Invalid conversions:
- {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"},
- {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"},
+ {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"},
+ {[]interface{}{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`},
// Wrong number of args:
{[]interface{}{}, "sql: expected 2 arguments, got 0"},
@@ -1159,17 +1506,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) {
db.SetMaxOpenConns(3)
- conn0, err := db.conn(cachedOrNewConn)
+ ctx := context.Background()
+
+ conn0, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
- conn1, err := db.conn(cachedOrNewConn)
+ conn1, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
- conn2, err := db.conn(cachedOrNewConn)
+ conn2, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
@@ -1203,7 +1552,11 @@ func TestPendingConnsAfterErr(t *testing.T) {
tryOpen = maxOpen*2 + 2
)
- db := newTestDB(t, "people")
+ // No queries will be run.
+ db, err := Open("test", fakeDBName)
+ if err != nil {
+ t.Fatalf("Open: %v", err)
+ }
defer closeDB(t, db)
defer func() {
for k, v := range db.lastPut {
@@ -1215,29 +1568,29 @@ func TestPendingConnsAfterErr(t *testing.T) {
db.SetMaxIdleConns(0)
errOffline := errors.New("db offline")
+
defer func() { setHookOpenErr(nil) }()
errs := make(chan error, tryOpen)
- unblock := make(chan struct{})
+ var opening sync.WaitGroup
+ opening.Add(tryOpen)
+
setHookOpenErr(func() error {
- <-unblock // block until all connections are in flight
+ // Wait for all connections to enqueue.
+ opening.Wait()
return errOffline
})
- var opening sync.WaitGroup
- opening.Add(tryOpen)
for i := 0; i < tryOpen; i++ {
go func() {
opening.Done() // signal one connection is in flight
- _, err := db.Exec("INSERT|people|name=Julia,age=19")
+ _, err := db.Exec("will never run")
errs <- err
}()
}
- opening.Wait() // wait for all workers to begin running
- time.Sleep(10 * time.Millisecond) // make extra sure all workers are blocked
- close(unblock) // let all workers proceed
+ opening.Wait() // wait for all workers to begin running
const timeout = 5 * time.Second
to := time.NewTimer(timeout)
@@ -1254,6 +1607,24 @@ func TestPendingConnsAfterErr(t *testing.T) {
t.Fatalf("orphaned connection request(s), still waiting after %v", timeout)
}
}
+
+ // Wait a reasonable time for the database to close all connections.
+ tick := time.NewTicker(3 * time.Millisecond)
+ defer tick.Stop()
+ for {
+ select {
+ case <-tick.C:
+ db.mu.Lock()
+ if db.numOpen == 0 {
+ db.mu.Unlock()
+ return
+ }
+ db.mu.Unlock()
+ case <-to.C:
+ // Closing the database will check for numOpen and fail the test.
+ return
+ }
+ }
}
func TestSingleOpenConn(t *testing.T) {
@@ -2236,6 +2607,54 @@ func TestIssue6081(t *testing.T) {
}
}
+// TestIssue18429 attempts to stress rolling back the transaction from a
+// context cancel while simultaneously calling Tx.Rollback. Rolling back from a
+// context happens concurrently so tx.rollback and tx.Commit must guard against
+// double entry.
+//
+// In the test, a context is canceled while the query is in process so
+// the internal rollback will run concurrently with the explicitly called
+// Tx.Rollback.
+func TestIssue18429(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx := context.Background()
+ sem := make(chan bool, 20)
+ var wg sync.WaitGroup
+
+ const milliWait = 30
+
+ for i := 0; i < 100; i++ {
+ sem <- true
+ wg.Add(1)
+ go func() {
+ defer func() {
+ <-sem
+ wg.Done()
+ }()
+ qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()
+
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ return
+ }
+ rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
+ if rows != nil {
+ rows.Close()
+ }
+ // This call will race with the context cancel rollback to complete
+ // if the rollback itself isn't guarded.
+ tx.Rollback()
+ }()
+ }
+ wg.Wait()
+ time.Sleep(milliWait * 3 * time.Millisecond)
+}
+
func TestConcurrency(t *testing.T) {
doConcurrentTest(t, new(concurrentDBQueryTest))
doConcurrentTest(t, new(concurrentDBExecTest))
@@ -2279,7 +2698,8 @@ func TestConnectionLeak(t *testing.T) {
go func() {
r, err := db.Query("SELECT|people|name|")
if err != nil {
- t.Fatal(err)
+ t.Error(err)
+ return
}
r.Close()
wg.Done()
@@ -2299,6 +2719,97 @@ func TestConnectionLeak(t *testing.T) {
wg.Wait()
}
+// badConn implements a bad driver.Conn, for TestBadDriver.
+// The Exec method panics.
+type badConn struct{}
+
+func (bc badConn) Prepare(query string) (driver.Stmt, error) {
+ return nil, errors.New("badConn Prepare")
+}
+
+func (bc badConn) Close() error {
+ return nil
+}
+
+func (bc badConn) Begin() (driver.Tx, error) {
+ return nil, errors.New("badConn Begin")
+}
+
+func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ panic("badConn.Exec")
+}
+
+// badDriver is a driver.Driver that uses badConn.
+type badDriver struct{}
+
+func (bd badDriver) Open(name string) (driver.Conn, error) {
+ return badConn{}, nil
+}
+
+// Issue 15901.
+func TestBadDriver(t *testing.T) {
+ Register("bad", badDriver{})
+ db, err := Open("bad", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("expected panic")
+ } else {
+ if want := "badConn.Exec"; r.(string) != want {
+ t.Errorf("panic was %v, expected %v", r, want)
+ }
+ }
+ }()
+ defer db.Close()
+ db.Exec("ignored")
+}
+
+type pingDriver struct {
+ fails bool
+}
+
+type pingConn struct {
+ badConn
+ driver *pingDriver
+}
+
+var pingError = errors.New("Ping failed")
+
+func (pc pingConn) Ping(ctx context.Context) error {
+ if pc.driver.fails {
+ return pingError
+ }
+ return nil
+}
+
+var _ driver.Pinger = pingConn{}
+
+func (pd *pingDriver) Open(name string) (driver.Conn, error) {
+ return pingConn{driver: pd}, nil
+}
+
+func TestPing(t *testing.T) {
+ driver := &pingDriver{}
+ Register("ping", driver)
+
+ db, err := Open("ping", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := db.Ping(); err != nil {
+ t.Errorf("err was %#v, expected nil", err)
+ return
+ }
+
+ driver.fails = true
+ if err := db.Ping(); err != pingError {
+ t.Errorf("err was %#v, expected pingError", err)
+ }
+}
+
func BenchmarkConcurrentDBExec(b *testing.B) {
b.ReportAllocs()
ct := new(concurrentDBExecTest)