aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/database
diff options
context:
space:
mode:
authorIan Lance Taylor <iant@golang.org>2017-09-14 17:11:35 +0000
committerIan Lance Taylor <ian@gcc.gnu.org>2017-09-14 17:11:35 +0000
commitbc998d034f45d1828a8663b2eed928faf22a7d01 (patch)
tree8d262a22ca7318f4bcd64269fe8fe9e45bcf8d0f /libgo/go/database
parenta41a6142df74219f596e612d3a7775f68ca6e96f (diff)
downloadgcc-bc998d034f45d1828a8663b2eed928faf22a7d01.zip
gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.tar.gz
gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.tar.bz2
libgo: update to go1.9
Reviewed-on: https://go-review.googlesource.com/63753 From-SVN: r252767
Diffstat (limited to 'libgo/go/database')
-rw-r--r--libgo/go/database/sql/convert.go215
-rw-r--r--libgo/go/database/sql/convert_test.go18
-rw-r--r--libgo/go/database/sql/driver/driver.go30
-rw-r--r--libgo/go/database/sql/fakedb_test.go94
-rw-r--r--libgo/go/database/sql/sql.go684
-rw-r--r--libgo/go/database/sql/sql_test.go650
6 files changed, 1405 insertions, 286 deletions
diff --git a/libgo/go/database/sql/convert.go b/libgo/go/database/sql/convert.go
index ea2f377..4983181 100644
--- a/libgo/go/database/sql/convert.go
+++ b/libgo/go/database/sql/convert.go
@@ -12,6 +12,7 @@ import (
"fmt"
"reflect"
"strconv"
+ "sync"
"time"
"unicode"
"unicode/utf8"
@@ -37,86 +38,180 @@ func validateNamedValueName(name string) error {
return fmt.Errorf("name %q does not begin with a letter", name)
}
+func driverNumInput(ds *driverStmt) int {
+ ds.Lock()
+ defer ds.Unlock() // in case NumInput panics
+ return ds.si.NumInput()
+}
+
+// ccChecker wraps the driver.ColumnConverter and allows it to be used
+// as if it were a NamedValueChecker. If the driver ColumnConverter
+// is not present then the NamedValueChecker will return driver.ErrSkip.
+type ccChecker struct {
+ sync.Locker
+ cci driver.ColumnConverter
+ want int
+}
+
+func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
+ if c.cci == nil {
+ return driver.ErrSkip
+ }
+ // The column converter shouldn't be called on any index
+ // it isn't expecting. The final error will be thrown
+ // in the argument converter loop.
+ index := nv.Ordinal - 1
+ if c.want <= index {
+ return nil
+ }
+
+ // First, see if the value itself knows how to convert
+ // itself to a driver type. For example, a NullString
+ // struct changing into a string or nil.
+ if vr, ok := nv.Value.(driver.Valuer); ok {
+ sv, err := callValuerValue(vr)
+ if err != nil {
+ return err
+ }
+ if !driver.IsValue(sv) {
+ return fmt.Errorf("non-subset type %T returned from Value", sv)
+ }
+ nv.Value = sv
+ }
+
+ // Second, ask the column to sanity check itself. For
+ // example, drivers might use this to make sure that
+ // an int64 values being inserted into a 16-bit
+ // integer field is in range (before getting
+ // truncated), or that a nil can't go into a NOT NULL
+ // column before going across the network to get the
+ // same error.
+ var err error
+ arg := nv.Value
+ c.Lock()
+ nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
+ c.Unlock()
+ if err != nil {
+ return err
+ }
+ if !driver.IsValue(nv.Value) {
+ return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
+ }
+ return nil
+}
+
+// defaultCheckNamedValue wraps the default ColumnConverter to have the same
+// function signature as the CheckNamedValue in the driver.NamedValueChecker
+// interface.
+func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
+ nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
+ return err
+}
+
// driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args))
+
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ want := -1
+
var si driver.Stmt
+ var cc ccChecker
if ds != nil {
si = ds.si
+ want = driverNumInput(ds)
+ cc.Locker = ds.Locker
+ cc.want = want
}
- cc, ok := si.(driver.ColumnConverter)
- // Normal path, for a driver.Stmt that is not a ColumnConverter.
+ // Check all types of interfaces from the start.
+ // Drivers may opt to use the NamedValueChecker for special
+ // argument types, then return driver.ErrSkip to pass it along
+ // to the column converter.
+ nvc, ok := si.(driver.NamedValueChecker)
if !ok {
- for n, arg := range args {
- var err error
- nv := &nvargs[n]
- nv.Ordinal = n + 1
- if np, ok := arg.(NamedArg); ok {
- if err := validateNamedValueName(np.Name); err != nil {
- return nil, err
- }
- arg = np.Value
- nvargs[n].Name = np.Name
- }
- nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
-
- if err != nil {
- return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err)
- }
- }
- return nvargs, nil
+ nvc, ok = ci.(driver.NamedValueChecker)
+ }
+ cci, ok := si.(driver.ColumnConverter)
+ if ok {
+ cc.cci = cci
}
- // Let the Stmt convert its own arguments.
- for n, arg := range args {
+ // Loop through all the arguments, checking each one.
+ // If no error is returned simply increment the index
+ // and continue. However if driver.ErrRemoveArgument
+ // is returned the argument is not included in the query
+ // argument list.
+ var err error
+ var n int
+ for _, arg := range args {
nv := &nvargs[n]
- nv.Ordinal = n + 1
if np, ok := arg.(NamedArg); ok {
- if err := validateNamedValueName(np.Name); err != nil {
+ if err = validateNamedValueName(np.Name); err != nil {
return nil, err
}
arg = np.Value
nv.Name = np.Name
}
- // First, see if the value itself knows how to convert
- // itself to a driver type. For example, a NullString
- // struct changing into a string or nil.
- if vr, ok := arg.(driver.Valuer); ok {
- sv, err := callValuerValue(vr)
- if err != nil {
- return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err)
- }
- if !driver.IsValue(sv) {
- return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv)
- }
- arg = sv
+ nv.Ordinal = n + 1
+ nv.Value = arg
+
+ // Checking sequence has four routes:
+ // A: 1. Default
+ // B: 1. NamedValueChecker 2. Column Converter 3. Default
+ // C: 1. NamedValueChecker 3. Default
+ // D: 1. Column Converter 2. Default
+ //
+ // The only time a Column Converter is called is first
+ // or after NamedValueConverter. If first it is handled before
+ // the nextCheck label. Thus for repeats tries only when the
+ // NamedValueConverter is selected should the Column Converter
+ // be used in the retry.
+ checker := defaultCheckNamedValue
+ nextCC := false
+ switch {
+ case nvc != nil:
+ nextCC = cci != nil
+ checker = nvc.CheckNamedValue
+ case cci != nil:
+ checker = cc.CheckNamedValue
}
- // Second, ask the column to sanity check itself. For
- // example, drivers might use this to make sure that
- // an int64 values being inserted into a 16-bit
- // integer field is in range (before getting
- // truncated), or that a nil can't go into a NOT NULL
- // column before going across the network to get the
- // same error.
- var err error
- ds.Lock()
- nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg)
- ds.Unlock()
- if err != nil {
+ nextCheck:
+ err = checker(nv)
+ switch err {
+ case nil:
+ n++
+ continue
+ case driver.ErrRemoveArgument:
+ nvargs = nvargs[:len(nvargs)-1]
+ continue
+ case driver.ErrSkip:
+ if nextCC {
+ nextCC = false
+ checker = cc.CheckNamedValue
+ } else {
+ checker = defaultCheckNamedValue
+ }
+ goto nextCheck
+ default:
return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
}
- if !driver.IsValue(nv.Value) {
- return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T",
- describeNamedValue(nv), arg, nv.Value)
- }
+ }
+
+ // Check the length of arguments after convertion to allow for omitted
+ // arguments.
+ if want != -1 && len(nvargs) != want {
+ return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
}
return nvargs, nil
+
}
// convertAssign copies to dest the value in src, converting it if possible.
@@ -270,6 +365,11 @@ func convertAssign(dest, src interface{}) error {
return nil
}
+ // The following conversions use a string value as an intermediate representation
+ // to convert between various numeric types.
+ //
+ // This also allows scanning into user defined types such as "type Int int64".
+ // For symmetry, also check for string destination types.
switch dv.Kind() {
case reflect.Ptr:
if src == nil {
@@ -306,6 +406,15 @@ func convertAssign(dest, src interface{}) error {
}
dv.SetFloat(f64)
return nil
+ case reflect.String:
+ switch v := src.(type) {
+ case string:
+ dv.SetString(v)
+ return nil
+ case []byte:
+ dv.SetString(string(v))
+ return nil
+ }
}
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
diff --git a/libgo/go/database/sql/convert_test.go b/libgo/go/database/sql/convert_test.go
index 4dfab1f..cfe52d7 100644
--- a/libgo/go/database/sql/convert_test.go
+++ b/libgo/go/database/sql/convert_test.go
@@ -10,6 +10,7 @@ import (
"reflect"
"runtime"
"strings"
+ "sync"
"testing"
"time"
)
@@ -17,9 +18,11 @@ import (
var someTime = time.Unix(123, 0)
var answer int64 = 42
-type userDefined float64
-
-type userDefinedSlice []int
+type (
+ userDefined float64
+ userDefinedSlice []int
+ userDefinedString string
+)
type conversionTest struct {
s, d interface{} // source and destination
@@ -39,6 +42,7 @@ type conversionTest struct {
wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr
wantnil bool // if true, *d must be *int64(nil)
wantusrdef userDefined
+ wantusrstr userDefinedString
}
// Target variables for scanning into.
@@ -171,6 +175,7 @@ var conversionTests = []conversionTest{
{s: int64(123), d: new(userDefined), wantusrdef: 123},
{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
+ {s: "str", d: new(userDefinedString), wantusrstr: "str"},
// Other errors
{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
@@ -260,6 +265,9 @@ func TestConversions(t *testing.T) {
if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
}
+ if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
+ errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
+ }
}
}
@@ -461,8 +469,8 @@ func TestDriverArgs(t *testing.T) {
},
}
for i, tt := range tests {
- ds := new(driverStmt)
- got, err := driverArgs(ds, tt.args)
+ ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
+ got, err := driverArgs(nil, ds, tt.args)
if err != nil {
t.Errorf("test[%d]: %v", i, err)
continue
diff --git a/libgo/go/database/sql/driver/driver.go b/libgo/go/database/sql/driver/driver.go
index d66196f..0262ca2 100644
--- a/libgo/go/database/sql/driver/driver.go
+++ b/libgo/go/database/sql/driver/driver.go
@@ -262,9 +262,39 @@ type StmtQueryContext interface {
QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}
+// ErrRemoveArgument may be returned from NamedValueChecker to instruct the
+// sql package to not pass the argument to the driver query interface.
+// Return when accepting query specific options or structures that aren't
+// SQL query arguments.
+var ErrRemoveArgument = errors.New("driver: remove argument from query")
+
+// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides
+// the driver more control to handle Go and database types beyond the default
+// Values types allowed.
+//
+// The sql package checks for value checkers in the following order,
+// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker,
+// Stmt.ColumnConverter, DefaultParameterConverter.
+//
+// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in
+// the final query arguments. This may be used to pass special options to
+// the query itself.
+//
+// If ErrSkip is returned the column converter error checking
+// path is used for the argument. Drivers may wish to return ErrSkip after
+// they have exhausted their own special cases.
+type NamedValueChecker interface {
+ // CheckNamedValue is called before passing arguments to the driver
+ // and is called in place of any ColumnConverter. CheckNamedValue must do type
+ // validation and conversion as appropriate for the driver.
+ CheckNamedValue(*NamedValue) error
+}
+
// ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from
// any type to a driver Value.
+//
+// Deprecated: Drivers should implement NamedValueChecker.
type ColumnConverter interface {
// ColumnConverter returns a ValueConverter for the provided
// column index. If the type of a specific column isn't known
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go
index 4b15f5b..4dcd096 100644
--- a/libgo/go/database/sql/fakedb_test.go
+++ b/libgo/go/database/sql/fakedb_test.go
@@ -58,9 +58,10 @@ type fakeDriver struct {
type fakeDB struct {
name string
- mu sync.Mutex
- tables map[string]*table
- badConn bool
+ mu sync.Mutex
+ tables map[string]*table
+ badConn bool
+ allowAny bool
}
type table struct {
@@ -83,11 +84,20 @@ type row struct {
cols []interface{} // must be same size as its table colname + coltype
}
+type memToucher interface {
+ // touchMem reads & writes some memory, to help find data races.
+ touchMem()
+}
+
type fakeConn struct {
db *fakeDB // where to return ourselves to
currTx *fakeTx
+ // Every operation writes to line to enable the race detector
+ // check for data races.
+ line int64
+
// Stats for tests:
mu sync.Mutex
stmtsMade int
@@ -99,6 +109,10 @@ type fakeConn struct {
stickyBad bool
}
+func (c *fakeConn) touchMem() {
+ c.line++
+}
+
func (c *fakeConn) incrStat(v *int) {
c.mu.Lock()
*v++
@@ -116,6 +130,7 @@ type boundCol struct {
}
type fakeStmt struct {
+ memToucher
c *fakeConn
q string // just for debugging
@@ -298,6 +313,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
if c.currTx != nil {
return nil, errors.New("already in a transaction")
}
+ c.touchMem()
c.currTx = &fakeTx{c: c}
return c.currTx, nil
}
@@ -339,6 +355,7 @@ func (c *fakeConn) Close() (err error) {
drv.mu.Unlock()
}
}()
+ c.touchMem()
if c.currTx != nil {
return errors.New("can't close fakeConn; in a Transaction")
}
@@ -352,12 +369,14 @@ func (c *fakeConn) Close() (err error) {
return nil
}
-func checkSubsetTypes(args []driver.NamedValue) error {
+func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
for _, arg := range args {
switch arg.Value.(type) {
case int64, float64, bool, nil, []byte, string, time.Time:
default:
- return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
+ if !allowAny {
+ return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
+ }
}
}
return nil
@@ -373,7 +392,7 @@ func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
// implement this at all.
- err := checkSubsetTypes(args)
+ err := checkSubsetTypes(c.db.allowAny, args)
if err != nil {
return nil, err
}
@@ -390,7 +409,7 @@ func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
// implement this at all.
- err := checkSubsetTypes(args)
+ err := checkSubsetTypes(c.db.allowAny, args)
if err != nil {
return nil, err
}
@@ -524,13 +543,14 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
return nil, driver.ErrBadConn
}
+ c.touchMem()
var firstStmt, prev *fakeStmt
for _, query := range strings.Split(query, ";") {
parts := strings.Split(query, "|")
if len(parts) < 1 {
return nil, errf("empty query")
}
- stmt := &fakeStmt{q: query, c: c}
+ stmt := &fakeStmt{q: query, c: c, memToucher: c}
if firstStmt == nil {
firstStmt = stmt
}
@@ -612,6 +632,7 @@ func (s *fakeStmt) Close() error {
if s.c.db == nil {
panic("in fakeStmt.Close, conn's db is nil (already closed)")
}
+ s.touchMem()
if !s.closed {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
@@ -642,10 +663,11 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
return nil, driver.ErrBadConn
}
- err := checkSubsetTypes(args)
+ err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
return nil, err
}
+ s.touchMem()
if s.wait > 0 {
time.Sleep(s.wait)
@@ -753,11 +775,12 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
return nil, driver.ErrBadConn
}
- err := checkSubsetTypes(args)
+ err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
return nil, err
}
+ s.touchMem()
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")
@@ -853,11 +876,12 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
}
cursor := &rowsCursor{
- posRow: -1,
- rows: setMRows,
- cols: setColumns,
- colType: setColType,
- errPos: -1,
+ parentMem: s.c,
+ posRow: -1,
+ rows: setMRows,
+ cols: setColumns,
+ colType: setColType,
+ errPos: -1,
}
return cursor, nil
}
@@ -877,6 +901,7 @@ func (tx *fakeTx) Commit() error {
if hookCommitBadConn != nil && hookCommitBadConn() {
return driver.ErrBadConn
}
+ tx.c.touchMem()
return nil
}
@@ -888,16 +913,18 @@ func (tx *fakeTx) Rollback() error {
if hookRollbackBadConn != nil && hookRollbackBadConn() {
return driver.ErrBadConn
}
+ tx.c.touchMem()
return nil
}
type rowsCursor struct {
- cols [][]string
- colType [][]string
- posSet int
- posRow int
- rows [][]*row
- closed bool
+ parentMem memToucher
+ cols [][]string
+ colType [][]string
+ posSet int
+ posRow int
+ rows [][]*row
+ closed bool
// errPos and err are for making Next return early with error.
errPos int
@@ -907,6 +934,16 @@ type rowsCursor struct {
// the original slice's first byte address. we clone them
// just so we're able to corrupt them on close.
bytesClone map[*byte][]byte
+
+ // Every operation writes to line to enable the race detector
+ // check for data races.
+ // This is separate from the fakeConn.line to allow for drivers that
+ // can start multiple queries on the same transaction at the same time.
+ line int64
+}
+
+func (rc *rowsCursor) touchMem() {
+ rc.line++
}
func (rc *rowsCursor) Close() error {
@@ -915,6 +952,8 @@ func (rc *rowsCursor) Close() error {
bs[0] = 255 // first byte corrupted
}
}
+ rc.touchMem()
+ rc.parentMem.touchMem()
rc.closed = true
return nil
}
@@ -937,6 +976,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed {
return errors.New("fakedb: cursor is closed")
}
+ rc.touchMem()
rc.posRow++
if rc.posRow == rc.errPos {
return rc.err
@@ -970,10 +1010,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
}
func (rc *rowsCursor) HasNextResultSet() bool {
+ rc.touchMem()
return rc.posSet < len(rc.rows)-1
}
func (rc *rowsCursor) NextResultSet() error {
+ rc.touchMem()
if rc.HasNextResultSet() {
rc.posSet++
rc.posRow = -1
@@ -1004,6 +1046,12 @@ func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
return fmt.Sprintf("%v", v), nil
}
+type anyTypeConverter struct{}
+
+func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
+ return v, nil
+}
+
func converterForType(typ string) driver.ValueConverter {
switch typ {
case "bool":
@@ -1030,6 +1078,8 @@ func converterForType(typ string) driver.ValueConverter {
return driver.Null{Converter: driver.DefaultParameterConverter}
case "datetime":
return driver.DefaultParameterConverter
+ case "any":
+ return anyTypeConverter{}
}
panic("invalid fakedb column type of " + typ)
}
@@ -1056,6 +1106,8 @@ func colTypeToReflectType(typ string) reflect.Type {
return reflect.TypeOf(NullFloat64{})
case "datetime":
return reflect.TypeOf(time.Time{})
+ case "any":
+ return reflect.TypeOf(new(interface{})).Elem()
}
panic("invalid fakedb column type of " + typ)
}
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go
index f8a8844..c609fe4 100644
--- a/libgo/go/database/sql/sql.go
+++ b/libgo/go/database/sql/sql.go
@@ -278,6 +278,27 @@ type Scanner interface {
Scan(src interface{}) error
}
+// Out may be used to retrieve OUTPUT value parameters from stored procedures.
+//
+// Not all drivers and databases support OUTPUT value parameters.
+//
+// Example usage:
+//
+// var outArg string
+// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg}))
+type Out struct {
+ _Named_Fields_Required struct{}
+
+ // Dest is a pointer to the value that will be set to the result of the
+ // stored procedure's OUTPUT parameter.
+ Dest interface{}
+
+ // In is whether the parameter is an INOUT parameter. If so, the input value to the stored
+ // procedure is the dereferenced value of Dest's pointer, which is then replaced with
+ // the output value.
+ In bool
+}
+
// ErrNoRows is returned by Scan when QueryRow doesn't return a
// row. In such a case, QueryRow returns a placeholder *Row value that
// defers this error until a Scan.
@@ -372,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
+// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
+// the prepared statements in a pool.
+func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err != nil {
return nil, err
}
+ ds := &driverStmt{Locker: dc, si: si}
+
+ // No need to manage open statements if there is a single connection grabber.
+ if cg != nil {
+ return ds, nil
+ }
// Track each driverConn's open statements, so we can close them
// before closing the conn.
@@ -385,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS
if dc.openStmt == nil {
dc.openStmt = make(map[*driverStmt]bool)
}
- ds := &driverStmt{Locker: dc, si: si}
dc.openStmt[ds] = true
-
return ds, nil
}
@@ -583,6 +610,17 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return db, nil
}
+func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
+ var err error
+ if pinger, ok := dc.ci.(driver.Pinger); ok {
+ withLock(dc, func() {
+ err = pinger.Ping(ctx)
+ })
+ }
+ release(err)
+ return err
+}
+
// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (db *DB) PingContext(ctx context.Context) error {
@@ -602,11 +640,7 @@ func (db *DB) PingContext(ctx context.Context) error {
return err
}
- if pinger, ok := dc.ci.(driver.Pinger); ok {
- err = pinger.Ping(ctx)
- }
- db.putConn(dc, err)
- return err
+ return db.pingDC(ctx, dc, dc.releaseConn)
}
// Ping verifies a connection to the database is still alive,
@@ -975,9 +1009,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
db: db,
createdAt: nowFunc(),
ci: ci,
+ inUse: true,
}
db.addDepLocked(dc, dc)
- dc.inUse = true
db.mu.Unlock()
return dc, nil
}
@@ -1137,22 +1171,39 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
if err != nil {
return nil, err
}
+ return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
+}
+
+// prepareDC prepares a query on the driverConn and calls release before
+// returning. When cg == nil it implies that a connection pool is used, and
+// when cg != nil only a single driver connection is used.
+func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
var ds *driverStmt
+ var err error
+ defer func() {
+ release(err)
+ }()
withLock(dc, func() {
- ds, err = dc.prepareLocked(ctx, query)
+ ds, err = dc.prepareLocked(ctx, cg, query)
})
if err != nil {
- db.putConn(dc, err)
return nil, err
}
stmt := &Stmt{
- db: db,
- query: query,
- css: []connStmt{{dc, ds}},
- lastNumClosed: atomic.LoadUint64(&db.numClosed),
+ db: db,
+ query: query,
+ cg: cg,
+ cgds: ds,
+ }
+
+ // When cg == nil this statement will need to keep track of various
+ // connections they are prepared on and record the stmt dependency on
+ // the DB.
+ if cg == nil {
+ stmt.css = []connStmt{{dc, ds}}
+ stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
+ db.addDep(stmt, stmt)
}
- db.addDep(stmt, stmt)
- db.putConn(dc, nil)
return stmt, nil
}
@@ -1179,18 +1230,21 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
return db.ExecContext(context.Background(), query, args...)
}
-func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (Result, error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
+ return db.execDC(ctx, dc, dc.releaseConn, query, args)
+}
+
+func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []interface{}) (res Result, err error) {
defer func() {
- db.putConn(dc, err)
+ release(err)
}()
-
if execer, ok := dc.ci.(driver.Execer); ok {
var dargs []driver.NamedValue
- dargs, err = driverArgs(nil, args)
+ dargs, err = driverArgs(dc.ci, nil, args)
if err != nil {
return nil, err
}
@@ -1215,7 +1269,7 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
}
ds := &driverStmt{Locker: dc, si: si}
defer ds.Close()
- return resultFromStatement(ctx, ds, args...)
+ return resultFromStatement(ctx, dc.ci, ds, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
@@ -1242,19 +1296,21 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
}
func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
- ci, err := db.conn(ctx, strategy)
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- return db.queryConn(ctx, ci, ci.releaseConn, query, args)
+ return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
}
-// queryConn executes a query on the given connection.
+// queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function.
-func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+// The ctx context is from a query method and the txctx context is from an
+// optional transaction context.
+func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok {
- dargs, err := driverArgs(nil, args)
+ dargs, err := driverArgs(dc.ci, nil, args)
if err != nil {
releaseConn(err)
return nil, err
@@ -1275,7 +1331,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
releaseConn: releaseConn,
rowsi: rowsi,
}
- rows.initContextClose(ctx)
+ rows.initContextClose(ctx, txctx)
return rows, nil
}
}
@@ -1291,7 +1347,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
}
ds := &driverStmt{Locker: dc, si: si}
- rowsi, err := rowsiFromStatement(ctx, ds, args...)
+ rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
if err != nil {
ds.Close()
releaseConn(err)
@@ -1306,13 +1362,16 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
rowsi: rowsi,
closeStmt: ds,
}
- rows.initContextClose(ctx)
+ rows.initContextClose(ctx, txctx)
return rows, nil
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
@@ -1321,6 +1380,9 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interfa
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}
@@ -1361,12 +1423,17 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
if err != nil {
return nil, err
}
+ return db.beginDC(ctx, dc, dc.releaseConn, opts)
+}
+
+// beginDC starts a transaction. The provided dc must be valid and ready to use.
+func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
var txi driver.Tx
withLock(dc, func() {
txi, err = ctxDriverBegin(ctx, opts, dc.ci)
})
if err != nil {
- db.putConn(dc, err)
+ release(err)
return nil, err
}
@@ -1374,11 +1441,12 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
// The cancel function in Tx will be called after done is set to true.
ctx, cancel := context.WithCancel(ctx)
tx = &Tx{
- db: db,
- dc: dc,
- txi: txi,
- cancel: cancel,
- ctx: ctx,
+ db: db,
+ dc: dc,
+ releaseConn: release,
+ txi: txi,
+ cancel: cancel,
+ ctx: ctx,
}
go tx.awaitDone()
return tx, nil
@@ -1389,6 +1457,189 @@ func (db *DB) Driver() driver.Driver {
return db.driver
}
+// ErrConnDone is returned by any operation that is performed on a connection
+// that has already been committed or rolled back.
+var ErrConnDone = errors.New("database/sql: connection is already closed")
+
+// Conn returns a single connection by either opening a new connection
+// or returning an existing connection from the connection pool. Conn will
+// block until either a connection is returned or ctx is canceled.
+// Queries run on the same Conn will be run in the same database session.
+//
+// Every Conn must be returned to the database pool after use by
+// calling Conn.Close.
+func (db *DB) Conn(ctx context.Context) (*Conn, error) {
+ var dc *driverConn
+ var err error
+ for i := 0; i < maxBadConnRetries; i++ {
+ dc, err = db.conn(ctx, cachedOrNewConn)
+ if err != driver.ErrBadConn {
+ break
+ }
+ }
+ if err == driver.ErrBadConn {
+ dc, err = db.conn(ctx, cachedOrNewConn)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ conn := &Conn{
+ db: db,
+ dc: dc,
+ }
+ return conn, nil
+}
+
+type releaseConn func(error)
+
+// Conn represents a single database session rather a pool of database
+// sessions. Prefer running queries from DB unless there is a specific
+// need for a continuous single database session.
+//
+// A Conn must call Close to return the connection to the database pool
+// and may do so concurrently with a running query.
+//
+// After a call to Close, all operations on the
+// connection fail with ErrConnDone.
+type Conn struct {
+ db *DB
+
+ // closemu prevents the connection from closing while there
+ // is an active query. It is held for read during queries
+ // and exclusively during close.
+ closemu sync.RWMutex
+
+ // dc is owned until close, at which point
+ // it's returned to the connection pool.
+ dc *driverConn
+
+ // done transitions from 0 to 1 exactly once, on close.
+ // Once done, all operations fail with ErrConnDone.
+ // Use atomic operations on value when checking value.
+ done int32
+}
+
+func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
+ if atomic.LoadInt32(&c.done) != 0 {
+ return nil, nil, ErrConnDone
+ }
+ c.closemu.RLock()
+ return c.dc, c.closemuRUnlockCondReleaseConn, nil
+}
+
+// PingContext verifies the connection to the database is still alive.
+func (c *Conn) PingContext(ctx context.Context) error {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return err
+ }
+ return c.db.pingDC(ctx, dc, release)
+}
+
+// ExecContext executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.execDC(ctx, dc, release, query, args)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.queryDC(ctx, nil, dc, release, query, args)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := c.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
+// PrepareContext creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.prepareDC(ctx, dc, release, c, query)
+}
+
+// BeginTx starts a transaction.
+//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginTx is canceled.
+//
+// The provided TxOptions is optional and may be nil if defaults should be used.
+// If a non-default isolation level is used that the driver doesn't support,
+// an error will be returned.
+func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.beginDC(ctx, dc, release, opts)
+}
+
+// closemuRUnlockCondReleaseConn read unlocks closemu
+// as the sql operation is done with the dc.
+func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
+ c.closemu.RUnlock()
+ if err == driver.ErrBadConn {
+ c.close(err)
+ }
+}
+
+func (c *Conn) txCtx() context.Context {
+ return nil
+}
+
+func (c *Conn) close(err error) error {
+ if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
+ return ErrConnDone
+ }
+
+ // Lock around releasing the driver connection
+ // to ensure all queries have been stopped before doing so.
+ c.closemu.Lock()
+ defer c.closemu.Unlock()
+
+ c.dc.releaseConn(err)
+ c.dc = nil
+ c.db = nil
+ return err
+}
+
+// Close returns the connection to the connection pool.
+// All operations after a Close will return with ErrConnDone.
+// Close is safe to call concurrently with other operations and will
+// block until all other operations finish. It may be useful to first
+// cancel any used context and then call close directly after.
+func (c *Conn) Close() error {
+ return c.close(nil)
+}
+
// Tx is an in-progress database transaction.
//
// A transaction must end with a call to Commit or Rollback.
@@ -1412,6 +1663,10 @@ type Tx struct {
dc *driverConn
txi driver.Tx
+ // releaseConn is called once the Tx is closed to release
+ // any held driverConn back to the pool.
+ releaseConn func(error)
+
// done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
@@ -1425,7 +1680,7 @@ type Tx struct {
v []*Stmt
}
- // cancel is called after done transitions from false to true.
+ // cancel is called after done transitions from 0 to 1.
cancel func()
// ctx lives for the life of the transaction.
@@ -1457,11 +1712,12 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
// close returns the connection to the pool and
// must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) {
+ tx.cancel()
+
tx.closemu.Lock()
defer tx.closemu.Unlock()
- tx.db.putConn(tx.dc, err)
- tx.cancel()
+ tx.releaseConn(err)
tx.dc = nil
tx.txi = nil
}
@@ -1470,19 +1726,36 @@ func (tx *Tx) close(err error) {
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()
-func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
select {
default:
case <-ctx.Done():
- return nil, ctx.Err()
+ return nil, nil, ctx.Err()
}
+
+ // closeme.RLock must come before the check for isDone to prevent the Tx from
+ // closing while a query is executing.
+ tx.closemu.RLock()
if tx.isDone() {
- return nil, ErrTxDone
+ tx.closemu.RUnlock()
+ return nil, nil, ErrTxDone
}
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
- return tx.dc, nil
+ return tx.dc, tx.closemuRUnlockRelease, nil
+}
+
+func (tx *Tx) txCtx() context.Context {
+ return tx.ctx
+}
+
+// closemuRUnlockRelease is used as a func(error) method value in
+// ExecContext and QueryContext. Unlocking in the releaseConn keeps
+// the driver conn from being returned to the connection pool until
+// the Rows has been closed.
+func (tx *Tx) closemuRUnlockRelease(error) {
+ tx.closemu.RUnlock()
}
// Closes all Stmts prepared for this transaction.
@@ -1540,7 +1813,7 @@ func (tx *Tx) Rollback() error {
return tx.rollback(false)
}
-// Prepare creates a prepared statement for use within a transaction.
+// PrepareContext creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
@@ -1551,44 +1824,15 @@ func (tx *Tx) Rollback() error {
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
- tx.closemu.RLock()
- defer tx.closemu.RUnlock()
-
- // TODO(bradfitz): We could be more efficient here and either
- // provide a method to take an existing Stmt (created on
- // perhaps a different Conn), and re-create it on this Conn if
- // necessary. Or, better: keep a map in DB of query string to
- // Stmts, and have Stmt.Execute do the right thing and
- // re-prepare if the Conn in use doesn't have that prepared
- // statement. But we'll want to avoid caching the statement
- // in the case where we only call conn.Prepare implicitly
- // (such as in db.Exec or tx.Exec), but the caller package
- // can't be holding a reference to the returned statement.
- // Perhaps just looking at the reference count (by noting
- // Stmt.Close) would be enough. We might also want a finalizer
- // on Stmt to drop the reference count.
- dc, err := tx.grabConn(ctx)
+ dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
- var si driver.Stmt
- withLock(dc, func() {
- si, err = ctxDriverPrepare(ctx, dc.ci, query)
- })
+ stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
if err != nil {
return nil, err
}
-
- stmt := &Stmt{
- db: tx.db,
- tx: tx,
- txds: &driverStmt{
- Locker: dc,
- si: si,
- },
- query: query,
- }
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
@@ -1618,34 +1862,67 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
- tx.closemu.RLock()
- defer tx.closemu.RUnlock()
-
- // TODO(bradfitz): optimize this. Currently this re-prepares
- // each time. This is fine for now to illustrate the API but
- // we should really cache already-prepared statements
- // per-Conn. See also the big comment in Tx.Prepare.
+ dc, release, err := tx.grabConn(ctx)
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ defer release(nil)
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- dc, err := tx.grabConn(ctx)
- if err != nil {
- return &Stmt{stickyErr: err}
- }
var si driver.Stmt
- withLock(dc, func() {
- si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
- })
+ var parentStmt *Stmt
+ stmt.mu.Lock()
+ if stmt.closed || stmt.cg != nil {
+ // If the statement has been closed or already belongs to a
+ // transaction, we can't reuse it in this connection.
+ // Since tx.StmtContext should never need to be called with a
+ // Stmt already belonging to tx, we ignore this edge case and
+ // re-prepare the statement in this case. No need to add
+ // code-complexity for this.
+ stmt.mu.Unlock()
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
+ })
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ } else {
+ stmt.removeClosedStmtLocked()
+ // See if the statement has already been prepared on this connection,
+ // and reuse it if possible.
+ for _, v := range stmt.css {
+ if v.dc == dc {
+ si = v.ds.si
+ break
+ }
+ }
+
+ stmt.mu.Unlock()
+
+ if si == nil {
+ cs, err := stmt.prepareOnConnLocked(ctx, dc)
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ si = cs.si
+ }
+ parentStmt = stmt
+ }
+
txs := &Stmt{
db: tx.db,
- tx: tx,
- txds: &driverStmt{
+ cg: tx,
+ cgds: &driverStmt{
Locker: dc,
si: si,
},
- query: stmt.query,
- stickyErr: err,
+ parentStmt: parentStmt,
+ query: stmt.query,
+ }
+ if parentStmt != nil {
+ tx.db.addDep(parentStmt, txs)
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
@@ -1672,42 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
- tx.closemu.RLock()
- defer tx.closemu.RUnlock()
-
- dc, err := tx.grabConn(ctx)
+ dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
-
- if execer, ok := dc.ci.(driver.Execer); ok {
- dargs, err := driverArgs(nil, args)
- if err != nil {
- return nil, err
- }
- var resi driver.Result
- withLock(dc, func() {
- resi, err = ctxDriverExec(ctx, execer, query, dargs)
- })
- if err == nil {
- return driverResult{dc, resi}, nil
- }
- if err != driver.ErrSkip {
- return nil, err
- }
- }
-
- var si driver.Stmt
- withLock(dc, func() {
- si, err = ctxDriverPrepare(ctx, dc.ci, query)
- })
- if err != nil {
- return nil, err
- }
- ds := &driverStmt{Locker: dc, si: si}
- defer ds.Close()
-
- return resultFromStatement(ctx, ds, args...)
+ return tx.db.execDC(ctx, dc, release, query, args)
}
// Exec executes a query that doesn't return rows.
@@ -1718,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
- tx.closemu.RLock()
- defer tx.closemu.RUnlock()
-
- dc, err := tx.grabConn(ctx)
+ dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
- releaseConn := func(error) {}
- return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+
+ return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
}
// Query executes a query that returns rows, typically a SELECT.
@@ -1737,6 +1980,9 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
@@ -1745,6 +1991,9 @@ func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interfa
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
return tx.QueryRowContext(context.Background(), query, args...)
}
@@ -1755,6 +2004,24 @@ type connStmt struct {
ds *driverStmt
}
+// stmtConnGrabber represents a Tx or Conn that will return the underlying
+// driverConn and release function.
+type stmtConnGrabber interface {
+ // grabConn returns the driverConn and the associated release function
+ // that must be called when the operation completes.
+ grabConn(context.Context) (*driverConn, releaseConn, error)
+
+ // txCtx returns the transaction context if available.
+ // The returned context should be selected on along with
+ // any query context when awaiting a cancel.
+ txCtx() context.Context
+}
+
+var (
+ _ stmtConnGrabber = &Tx{}
+ _ stmtConnGrabber = &Conn{}
+)
+
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
type Stmt struct {
@@ -1765,17 +2032,29 @@ type Stmt struct {
closemu sync.RWMutex // held exclusively during close, for read otherwise.
- // If in a transaction, else both nil:
- tx *Tx
- txds *driverStmt
+ // If Stmt is prepared on a Tx or Conn then cg is present and will
+ // only ever grab a connection from cg.
+ // If cg is nil then the Stmt must grab an arbitrary connection
+ // from db and determine if it must prepare the stmt again by
+ // inspecting css.
+ cg stmtConnGrabber
+ cgds *driverStmt
+
+ // parentStmt is set when a transaction-specific statement
+ // is requested from an identical statement prepared on the same
+ // conn. parentStmt is used to track the dependency of this statement
+ // on its originating ("parent") statement so that parentStmt may
+ // be closed by the user without them having to know whether or not
+ // any transactions are still using it.
+ parentStmt *Stmt
mu sync.Mutex // protects the rest of the fields
closed bool
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
- // used if tx == nil and one is found that has idle
- // connections. If tx != nil, txsi is always used.
+ // used if cg == nil and one is found that has idle
+ // connections. If cg != nil, cgds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
@@ -1790,8 +2069,12 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
defer s.closemu.RUnlock()
var res Result
- for i := 0; i < maxBadConnRetries; i++ {
- _, releaseConn, ds, err := s.connStmt(ctx)
+ strategy := cachedOrNewConn
+ for i := 0; i < maxBadConnRetries+1; i++ {
+ if i == maxBadConnRetries {
+ strategy = alwaysNewConn
+ }
+ dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1799,7 +2082,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
return nil, err
}
- res, err = resultFromStatement(ctx, ds, args...)
+ res, err = resultFromStatement(ctx, dc.ci, ds, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
@@ -1814,23 +2097,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return s.ExecContext(context.Background(), args...)
}
-func driverNumInput(ds *driverStmt) int {
- ds.Lock()
- defer ds.Unlock() // in case NumInput panics
- return ds.si.NumInput()
-}
-
-func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
- want := driverNumInput(ds)
-
- // -1 means the driver doesn't know how to count the number of
- // placeholders, so we won't sanity check input here and instead let the
- // driver deal with errors.
- if want != -1 && len(args) != want {
- return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
- }
-
- dargs, err := driverArgs(ds, args)
+func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
+ dargs, err := driverArgs(ci, ds, args)
if err != nil {
return nil, err
}
@@ -1874,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -1885,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
return
}
- // In a transaction, we always use the connection that the
- // transaction was created on.
- if s.tx != nil {
+ // In a transaction or connection, we always use the connection that the
+ // the stmt was created on.
+ if s.cg != nil {
s.mu.Unlock()
- ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
+ dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
- releaseConn = func(error) {}
- return ci, releaseConn, s.txds, nil
+ return dc, releaseConn, s.cgds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
- dc, err := s.db.conn(ctx, cachedOrNewConn)
+ dc, err = s.db.conn(ctx, strategy)
if err != nil {
return nil, nil, nil, err
}
@@ -1916,18 +2183,28 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- ds, err = dc.prepareLocked(ctx, s.query)
+ ds, err = s.prepareOnConnLocked(ctx, dc)
})
if err != nil {
- s.db.putConn(dc, err)
+ dc.releaseConn(err)
return nil, nil, nil, err
}
+
+ return dc, dc.releaseConn, ds, nil
+}
+
+// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
+// open connStmt on the statement. It assumes the caller is holding the lock on dc.
+func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
+ si, err := dc.prepareLocked(ctx, s.cg, s.query)
+ if err != nil {
+ return nil, err
+ }
+ cs := connStmt{dc, si}
s.mu.Lock()
- cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
-
- return dc, dc.releaseConn, ds, nil
+ return cs.ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
@@ -1937,8 +2214,12 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
defer s.closemu.RUnlock()
var rowsi driver.Rows
- for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, ds, err := s.connStmt(ctx)
+ strategy := cachedOrNewConn
+ for i := 0; i < maxBadConnRetries+1; i++ {
+ if i == maxBadConnRetries {
+ strategy = alwaysNewConn
+ }
+ dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1946,7 +2227,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
return nil, err
}
- rowsi, err = rowsiFromStatement(ctx, ds, args...)
+ rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
@@ -1955,12 +2236,21 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
rowsi: rowsi,
// releaseConn set below
}
+ // addDep must be added before initContextClose or it could attempt
+ // to removeDep before it has been added.
s.db.addDep(s, rows)
+
+ // releaseConn must be set before initContextClose or it could
+ // release the connection before it is set.
rows.releaseConn = func(err error) {
releaseConn(err)
s.db.removeDep(s, rows)
}
- rows.initContextClose(ctx)
+ var txctx context.Context
+ if s.cg != nil {
+ txctx = s.cg.txCtx()
+ }
+ rows.initContextClose(ctx, txctx)
return rows, nil
}
@@ -1978,7 +2268,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
-func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
+func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
@@ -1991,7 +2281,7 @@ func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(ds, args)
+ dargs, err := driverArgs(ci, ds, args)
if err != nil {
return nil, err
}
@@ -2054,13 +2344,21 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
+ txds := s.cgds
+ s.cgds = nil
+
s.mu.Unlock()
- if s.tx != nil {
- return s.txds.Close()
+ if s.cg == nil {
+ return s.db.removeDep(s, s)
}
- return s.db.removeDep(s, s)
+ if s.parentStmt != nil {
+ // If parentStmt is set, we must not close s.txds since it's stored
+ // in the css array of the parentStmt.
+ return s.db.removeDep(s.parentStmt, s)
+ }
+ return txds.Close()
}
func (s *Stmt) finalClose() error {
@@ -2107,18 +2405,28 @@ type Rows struct {
lasterr error // non-nil only if closed is true
// lastcols is only used in Scan, Next, and NextResultSet which are expected
- // not not be called concurrently.
+ // not to be called concurrently.
lastcols []driver.Value
}
-func (rs *Rows) initContextClose(ctx context.Context) {
+func (rs *Rows) initContextClose(ctx, txctx context.Context) {
ctx, rs.cancel = context.WithCancel(ctx)
- go rs.awaitDone(ctx)
+ go rs.awaitDone(ctx, txctx)
}
-// awaitDone blocks until the rows are closed or the context canceled.
-func (rs *Rows) awaitDone(ctx context.Context) {
- <-ctx.Done()
+// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided
+// from the query context and is canceled when the query Rows is closed.
+// If the query was issued in a transaction, the transaction's context
+// is also provided in txctx to ensure Rows is closed if the Tx is closed.
+func (rs *Rows) awaitDone(ctx, txctx context.Context) {
+ var txctxDone <-chan struct{}
+ if txctx != nil {
+ txctxDone = txctx.Done()
+ }
+ select {
+ case <-ctx.Done():
+ case <-txctxDone:
+ }
rs.close(ctx.Err())
}
@@ -2407,7 +2715,7 @@ func (rs *Rows) Scan(dest ...interface{}) error {
}
// rowsCloseHook returns a function so tests may install the
-// hook throug a test only mutex.
+// hook through a test only mutex.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
// Close closes the Rows, preventing further enumeration. If Next is called
@@ -2431,7 +2739,9 @@ func (rs *Rows) close(err error) error {
rs.lasterr = err
}
- err = rs.rowsi.Close()
+ withLock(rs.dc, func() {
+ err = rs.rowsi.Close()
+ })
if fn := rowsCloseHook(); fn != nil {
fn(rs, &err)
}
diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go
index 381aafc..c935eb4 100644
--- a/libgo/go/database/sql/sql_test.go
+++ b/libgo/go/database/sql/sql_test.go
@@ -139,6 +139,7 @@ func closeDB(t testing.TB, db *DB) {
t.Errorf("Error closing fakeConn: %v", err)
}
})
+ db.mu.Lock()
for i, dc := range db.freeConn {
if n := len(dc.openStmt); n > 0 {
// Just a sanity check. This is legal in
@@ -149,6 +150,8 @@ func closeDB(t testing.TB, db *DB) {
t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
}
}
+ db.mu.Unlock()
+
err := db.Close()
if err != nil {
t.Fatalf("error closing DB: %v", err)
@@ -322,7 +325,7 @@ func TestQueryContext(t *testing.T) {
select {
case <-ctx.Done():
if err := ctx.Err(); err != context.Canceled {
- t.Fatalf("context err = %v; want context.Canceled", ctx.Err())
+ t.Fatalf("context err = %v; want context.Canceled", err)
}
default:
t.Fatalf("context err = nil; want context.Canceled")
@@ -413,7 +416,7 @@ func TestTxContextWait(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
- ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*15)
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond)
defer cancel()
tx, err := db.BeginTx(ctx, nil)
@@ -590,13 +593,13 @@ func TestPoolExhaustOnCancel(t *testing.T) {
saturate.Wait()
// Now cancel the request while it is waiting.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
for i := 0; i < max; i++ {
ctxReq, cancelReq := context.WithCancel(ctx)
go func() {
- time.Sleep(time.Millisecond * 100)
+ time.Sleep(100 * time.Millisecond)
cancelReq()
}()
err := db.PingContext(ctxReq)
@@ -874,7 +877,7 @@ func TestStatementClose(t *testing.T) {
msg string
}{
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
- {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+ {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
}
for _, test := range tests {
if err := test.stmt.Close(); err != want {
@@ -1024,6 +1027,196 @@ func TestTxStmt(t *testing.T) {
}
}
+func TestTxStmtPreparedOnce(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+
+ txs1 := tx.Stmt(stmt)
+ txs2 := tx.Stmt(stmt)
+
+ _, err = txs1.Exec("Go", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ txs1.Close()
+
+ _, err = txs2.Exec("Gopher", 8)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ txs2.Close()
+
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestTxStmtClosedRePrepares(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ err = stmt.Close()
+ if err != nil {
+ t.Fatalf("stmt.Close() = %v", err)
+ }
+ // tx.Stmt increments numPrepares because stmt is closed.
+ txs := tx.Stmt(stmt)
+ if txs.stickyErr != nil {
+ t.Fatal(txs.stickyErr)
+ }
+ if txs.parentStmt != nil {
+ t.Fatal("expected nil parentStmt")
+ }
+ _, err = txs.Exec(`Eric`, 82)
+ if err != nil {
+ t.Fatalf("txs.Exec = %v", err)
+ }
+
+ err = txs.Close()
+ if err != nil {
+ t.Fatalf("txs.Close = %v", err)
+ }
+
+ tx.Rollback()
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
+ t.Errorf("executed %d Prepare statements; want 2", prepares)
+ }
+}
+
+func TestParentStmtOutlivesTxStmt(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ // Make sure everything happens on the same connection.
+ db.SetMaxOpenConns(1)
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs := tx.Stmt(stmt)
+ if len(stmt.css) != 1 {
+ t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
+ }
+ err = txs.Close()
+ if err != nil {
+ t.Fatalf("txs.Close() = %v", err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback() = %v", err)
+ }
+ // txs must not be valid.
+ _, err = txs.Exec("Suzan", 30)
+ if err == nil {
+ t.Fatalf("txs.Exec(), expected err")
+ }
+ // Stmt must still be valid.
+ _, err = stmt.Exec("Janina", 25)
+ if err != nil {
+ t.Fatalf("stmt.Exec() = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+// Test that tx.Stmt called with a statement already
+// associated with tx as argument re-prepares the same
+// statement again.
+func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+ prepares0 := numPrepares(t, db)
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs1 := tx.Stmt(stmt)
+
+ // tx.Stmt(txs1) increments numPrepares because txs1 already
+ // belongs to a transaction (albeit the same transaction).
+ txs2 := tx.Stmt(txs1)
+ if txs2.stickyErr != nil {
+ t.Fatal(txs2.stickyErr)
+ }
+ if txs2.parentStmt != nil {
+ t.Fatal("expected nil parentStmt")
+ }
+ _, err = txs2.Exec(`Eric`, 82)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = txs1.Close()
+ if err != nil {
+ t.Fatalf("txs1.Close = %v", err)
+ }
+ err = txs2.Close()
+ if err != nil {
+ t.Fatalf("txs1.Close = %v", err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
+ t.Errorf("executed %d Prepare statements; want 2", prepares)
+ }
+}
+
// Issue: https://golang.org/issue/2784
// This test didn't fail before because we got lucky with the fakedb driver.
// It was failing, and now not, in github.com/bradfitz/go-sql-test
@@ -1108,6 +1301,69 @@ func TestTxErrBadConn(t *testing.T) {
}
}
+func TestConnQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ var name string
+ err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if name != "Chris" {
+ t.Fatalf("unexpected result, got %q want Chris", name)
+ }
+
+ err = conn.PingContext(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestConnTx(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ tx, err := conn.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ insertName, insertAge := "Nancy", 33
+ _, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var selectName string
+ err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if selectName != insertName {
+ t.Fatalf("got %q want %q", selectName, insertName)
+ }
+}
+
// Tests fix for issue 2542, that we release a lock when querying on
// a closed connection.
func TestIssue2542Deadlock(t *testing.T) {
@@ -1831,8 +2087,8 @@ func TestConnMaxLifetime(t *testing.T) {
}
// Expire first conn
- offset = time.Second * 11
- db.SetConnMaxLifetime(time.Second * 10)
+ offset = 11 * time.Second
+ db.SetConnMaxLifetime(10 * time.Second)
if err != nil {
t.Fatal(err)
}
@@ -2078,9 +2334,13 @@ func TestStmtCloseOrder(t *testing.T) {
// Test cases where there's more than maxBadConnRetries bad connections in the
// pool (issue 8834)
func TestManyErrBadConn(t *testing.T) {
- manyErrBadConnSetup := func() *DB {
+ manyErrBadConnSetup := func(first ...func(db *DB)) *DB {
db := newTestDB(t, "people")
+ for _, f := range first {
+ f(db)
+ }
+
nconn := maxBadConnRetries + 1
db.SetMaxIdleConns(nconn)
db.SetMaxOpenConns(nconn)
@@ -2148,6 +2408,128 @@ func TestManyErrBadConn(t *testing.T) {
if err = stmt.Close(); err != nil {
t.Fatal(err)
}
+
+ // Stmt.Exec
+ db = manyErrBadConnSetup(func(db *DB) {
+ stmt, err = db.Prepare("INSERT|people|name=Julia,age=19")
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ defer closeDB(t, db)
+ _, err = stmt.Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Stmt.Query
+ db = manyErrBadConnSetup(func(db *DB) {
+ stmt, err = db.Prepare("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ defer closeDB(t, db)
+ rows, err = stmt.Query()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ if err = stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Conn
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = conn.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Ping
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ err = db.PingContext(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// TestIssue20575 ensures the Rows from query does not block
+// closing a transaction. Ensure Rows is closed while closing a trasaction.
+func TestIssue20575(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Do not close Rows from QueryContext.
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatal(err)
+ }
+ select {
+ default:
+ case <-ctx.Done():
+ t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
+ }
+}
+
+// TestIssue20622 tests closing the transaction before rows is closed, requires
+// the race detector to fail.
+func TestIssue20622(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rows, err := tx.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ count := 0
+ for rows.Next() {
+ count++
+ var age int
+ var name string
+ if err := rows.Scan(&age, &name); err != nil {
+ t.Fatal("scan failed", err)
+ }
+
+ if count == 1 {
+ cancel()
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ rows.Close()
+ tx.Commit()
}
// golang.org/issue/5718
@@ -2751,7 +3133,7 @@ func TestIssue18429(t *testing.T) {
if err != nil {
return
}
- // This is expected to give a cancel error many, but not all the time.
+ // This is expected to give a cancel error most, but not all the time.
// Test failure will happen with a panic or other race condition being
// reported.
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
@@ -2766,6 +3148,46 @@ func TestIssue18429(t *testing.T) {
wg.Wait()
}
+// TestIssue20160 attempts to test a short context life on a stmt Query.
+func TestIssue20160(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx := context.Background()
+ sem := make(chan bool, 20)
+ var wg sync.WaitGroup
+
+ const milliWait = 30
+
+ stmt, err := db.PrepareContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stmt.Close()
+
+ for i := 0; i < 100; i++ {
+ sem <- true
+ wg.Add(1)
+ go func() {
+ defer func() {
+ <-sem
+ wg.Done()
+ }()
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
+ defer cancel()
+
+ // This is expected to give a cancel error most, but not all the time.
+ // Test failure will happen with a panic or other race condition being
+ // reported.
+ rows, _ := stmt.QueryContext(ctx)
+ if rows != nil {
+ rows.Close()
+ }
+ }()
+ }
+ wg.Wait()
+}
+
// TestIssue18719 closes the context right before use. The sql.driverConn
// will nil out the ci on close in a lock, but if another process uses it right after
// it will panic with on the nil ref.
@@ -2788,7 +3210,7 @@ func TestIssue18719(t *testing.T) {
// Wait for the context to cancel and tx to rollback.
for tx.isDone() == false {
- time.Sleep(time.Millisecond * 3)
+ time.Sleep(3 * time.Millisecond)
}
}
defer func() { hookTxGrabConn = nil }()
@@ -2807,19 +3229,64 @@ func TestIssue18719(t *testing.T) {
// canceled context.
cancel()
- waitForRowsClose(t, rows, 5*time.Second)
+}
+
+func TestIssue20647(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stmt.Close()
+
+ rows1, err := stmt.QueryContext(ctx)
+ if err != nil {
+ t.Fatal("rows1", err)
+ }
+ defer rows1.Close()
+
+ rows2, err := stmt.QueryContext(ctx)
+ if err != nil {
+ t.Fatal("rows2", err)
+ }
+ defer rows2.Close()
+
+ if rows1.dc != rows2.dc {
+ t.Fatal("stmt prepared on Conn does not use same connection")
+ }
}
func TestConcurrency(t *testing.T) {
- doConcurrentTest(t, new(concurrentDBQueryTest))
- doConcurrentTest(t, new(concurrentDBExecTest))
- doConcurrentTest(t, new(concurrentStmtQueryTest))
- doConcurrentTest(t, new(concurrentStmtExecTest))
- doConcurrentTest(t, new(concurrentTxQueryTest))
- doConcurrentTest(t, new(concurrentTxExecTest))
- doConcurrentTest(t, new(concurrentTxStmtQueryTest))
- doConcurrentTest(t, new(concurrentTxStmtExecTest))
- doConcurrentTest(t, new(concurrentRandomTest))
+ list := []struct {
+ name string
+ ct concurrentTest
+ }{
+ {"Query", new(concurrentDBQueryTest)},
+ {"Exec", new(concurrentDBExecTest)},
+ {"StmtQuery", new(concurrentStmtQueryTest)},
+ {"StmtExec", new(concurrentStmtExecTest)},
+ {"TxQuery", new(concurrentTxQueryTest)},
+ {"TxExec", new(concurrentTxExecTest)},
+ {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
+ {"TxStmtExec", new(concurrentTxStmtExecTest)},
+ {"Random", new(concurrentRandomTest)},
+ }
+ for _, item := range list {
+ t.Run(item.name, func(t *testing.T) {
+ doConcurrentTest(t, item.ct)
+ })
+ }
}
func TestConnectionLeak(t *testing.T) {
@@ -2874,6 +3341,131 @@ func TestConnectionLeak(t *testing.T) {
wg.Wait()
}
+type nvcDriver struct {
+ fakeDriver
+ skipNamedValueCheck bool
+}
+
+func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
+ c, err := d.fakeDriver.Open(dsn)
+ fc := c.(*fakeConn)
+ fc.db.allowAny = true
+ return &nvcConn{fc, d.skipNamedValueCheck}, err
+}
+
+type nvcConn struct {
+ *fakeConn
+ skipNamedValueCheck bool
+}
+
+type decimal struct {
+ value int
+}
+
+type doNotInclude struct{}
+
+var _ driver.NamedValueChecker = &nvcConn{}
+
+func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
+ if c.skipNamedValueCheck {
+ return driver.ErrSkip
+ }
+ switch v := nv.Value.(type) {
+ default:
+ return driver.ErrSkip
+ case Out:
+ switch ov := v.Dest.(type) {
+ default:
+ return errors.New("unkown NameValueCheck OUTPUT type")
+ case *string:
+ *ov = "from-server"
+ nv.Value = "OUT:*string"
+ }
+ return nil
+ case decimal, []int64:
+ return nil
+ case doNotInclude:
+ return driver.ErrRemoveArgument
+ }
+}
+
+func TestNamedValueChecker(t *testing.T) {
+ Register("NamedValueCheck", &nvcDriver{})
+ db, err := Open("NamedValueCheck", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+
+ o1 := ""
+ _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimal{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
+ if err != nil {
+ t.Fatal("exec insert", err)
+ }
+ var (
+ str1 string
+ dec1 decimal
+ arr1 []int64
+ )
+ err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
+ if err != nil {
+ t.Fatal("select", err)
+ }
+
+ list := []struct{ got, want interface{} }{
+ {o1, "from-server"},
+ {dec1, decimal{123}},
+ {str1, "hello"},
+ {arr1, []int64{42, 128, 707}},
+ }
+
+ for index, item := range list {
+ if !reflect.DeepEqual(item.got, item.want) {
+ t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
+ }
+ }
+}
+
+func TestNamedValueCheckerSkip(t *testing.T) {
+ Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
+ db, err := Open("NamedValueCheckSkip", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+
+ _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimal{123}))
+ if err == nil {
+ t.Fatalf("expected error with bad argument, got %v", err)
+ }
+}
+
// badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics.
type badConn struct{}
@@ -2965,6 +3557,24 @@ func TestPing(t *testing.T) {
}
}
+// Issue 18101.
+func TestTypedString(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ type Str string
+ var scanned Str
+
+ err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expected := Str("Alice")
+ if scanned != expected {
+ t.Errorf("expected %+v, got %+v", expected, scanned)
+ }
+}
+
func BenchmarkConcurrentDBExec(b *testing.B) {
b.ReportAllocs()
ct := new(concurrentDBExecTest)