diff options
author | Ian Lance Taylor <iant@golang.org> | 2017-09-14 17:11:35 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2017-09-14 17:11:35 +0000 |
commit | bc998d034f45d1828a8663b2eed928faf22a7d01 (patch) | |
tree | 8d262a22ca7318f4bcd64269fe8fe9e45bcf8d0f /libgo/go/database | |
parent | a41a6142df74219f596e612d3a7775f68ca6e96f (diff) | |
download | gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.zip gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.tar.gz gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.tar.bz2 |
libgo: update to go1.9
Reviewed-on: https://go-review.googlesource.com/63753
From-SVN: r252767
Diffstat (limited to 'libgo/go/database')
-rw-r--r-- | libgo/go/database/sql/convert.go | 215 | ||||
-rw-r--r-- | libgo/go/database/sql/convert_test.go | 18 | ||||
-rw-r--r-- | libgo/go/database/sql/driver/driver.go | 30 | ||||
-rw-r--r-- | libgo/go/database/sql/fakedb_test.go | 94 | ||||
-rw-r--r-- | libgo/go/database/sql/sql.go | 684 | ||||
-rw-r--r-- | libgo/go/database/sql/sql_test.go | 650 |
6 files changed, 1405 insertions, 286 deletions
diff --git a/libgo/go/database/sql/convert.go b/libgo/go/database/sql/convert.go index ea2f377..4983181 100644 --- a/libgo/go/database/sql/convert.go +++ b/libgo/go/database/sql/convert.go @@ -12,6 +12,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "unicode" "unicode/utf8" @@ -37,86 +38,180 @@ func validateNamedValueName(name string) error { return fmt.Errorf("name %q does not begin with a letter", name) } +func driverNumInput(ds *driverStmt) int { + ds.Lock() + defer ds.Unlock() // in case NumInput panics + return ds.si.NumInput() +} + +// ccChecker wraps the driver.ColumnConverter and allows it to be used +// as if it were a NamedValueChecker. If the driver ColumnConverter +// is not present then the NamedValueChecker will return driver.ErrSkip. +type ccChecker struct { + sync.Locker + cci driver.ColumnConverter + want int +} + +func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { + if c.cci == nil { + return driver.ErrSkip + } + // The column converter shouldn't be called on any index + // it isn't expecting. The final error will be thrown + // in the argument converter loop. + index := nv.Ordinal - 1 + if c.want <= index { + return nil + } + + // First, see if the value itself knows how to convert + // itself to a driver type. For example, a NullString + // struct changing into a string or nil. + if vr, ok := nv.Value.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return err + } + if !driver.IsValue(sv) { + return fmt.Errorf("non-subset type %T returned from Value", sv) + } + nv.Value = sv + } + + // Second, ask the column to sanity check itself. For + // example, drivers might use this to make sure that + // an int64 values being inserted into a 16-bit + // integer field is in range (before getting + // truncated), or that a nil can't go into a NOT NULL + // column before going across the network to get the + // same error. + var err error + arg := nv.Value + c.Lock() + nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) + c.Unlock() + if err != nil { + return err + } + if !driver.IsValue(nv.Value) { + return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value) + } + return nil +} + +// defaultCheckNamedValue wraps the default ColumnConverter to have the same +// function signature as the CheckNamedValue in the driver.NamedValueChecker +// interface. +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + // driverArgs converts arguments from callers of Stmt.Exec and // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. -func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { +func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { nvargs := make([]driver.NamedValue, len(args)) + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + want := -1 + var si driver.Stmt + var cc ccChecker if ds != nil { si = ds.si + want = driverNumInput(ds) + cc.Locker = ds.Locker + cc.want = want } - cc, ok := si.(driver.ColumnConverter) - // Normal path, for a driver.Stmt that is not a ColumnConverter. + // Check all types of interfaces from the start. + // Drivers may opt to use the NamedValueChecker for special + // argument types, then return driver.ErrSkip to pass it along + // to the column converter. + nvc, ok := si.(driver.NamedValueChecker) if !ok { - for n, arg := range args { - var err error - nv := &nvargs[n] - nv.Ordinal = n + 1 - if np, ok := arg.(NamedArg); ok { - if err := validateNamedValueName(np.Name); err != nil { - return nil, err - } - arg = np.Value - nvargs[n].Name = np.Name - } - nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg) - - if err != nil { - return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err) - } - } - return nvargs, nil + nvc, ok = ci.(driver.NamedValueChecker) + } + cci, ok := si.(driver.ColumnConverter) + if ok { + cc.cci = cci } - // Let the Stmt convert its own arguments. - for n, arg := range args { + // Loop through all the arguments, checking each one. + // If no error is returned simply increment the index + // and continue. However if driver.ErrRemoveArgument + // is returned the argument is not included in the query + // argument list. + var err error + var n int + for _, arg := range args { nv := &nvargs[n] - nv.Ordinal = n + 1 if np, ok := arg.(NamedArg); ok { - if err := validateNamedValueName(np.Name); err != nil { + if err = validateNamedValueName(np.Name); err != nil { return nil, err } arg = np.Value nv.Name = np.Name } - // First, see if the value itself knows how to convert - // itself to a driver type. For example, a NullString - // struct changing into a string or nil. - if vr, ok := arg.(driver.Valuer); ok { - sv, err := callValuerValue(vr) - if err != nil { - return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err) - } - if !driver.IsValue(sv) { - return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv) - } - arg = sv + nv.Ordinal = n + 1 + nv.Value = arg + + // Checking sequence has four routes: + // A: 1. Default + // B: 1. NamedValueChecker 2. Column Converter 3. Default + // C: 1. NamedValueChecker 3. Default + // D: 1. Column Converter 2. Default + // + // The only time a Column Converter is called is first + // or after NamedValueConverter. If first it is handled before + // the nextCheck label. Thus for repeats tries only when the + // NamedValueConverter is selected should the Column Converter + // be used in the retry. + checker := defaultCheckNamedValue + nextCC := false + switch { + case nvc != nil: + nextCC = cci != nil + checker = nvc.CheckNamedValue + case cci != nil: + checker = cc.CheckNamedValue } - // Second, ask the column to sanity check itself. For - // example, drivers might use this to make sure that - // an int64 values being inserted into a 16-bit - // integer field is in range (before getting - // truncated), or that a nil can't go into a NOT NULL - // column before going across the network to get the - // same error. - var err error - ds.Lock() - nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg) - ds.Unlock() - if err != nil { + nextCheck: + err = checker(nv) + switch err { + case nil: + n++ + continue + case driver.ErrRemoveArgument: + nvargs = nvargs[:len(nvargs)-1] + continue + case driver.ErrSkip: + if nextCC { + nextCC = false + checker = cc.CheckNamedValue + } else { + checker = defaultCheckNamedValue + } + goto nextCheck + default: return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) } - if !driver.IsValue(nv.Value) { - return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T", - describeNamedValue(nv), arg, nv.Value) - } + } + + // Check the length of arguments after convertion to allow for omitted + // arguments. + if want != -1 && len(nvargs) != want { + return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) } return nvargs, nil + } // convertAssign copies to dest the value in src, converting it if possible. @@ -270,6 +365,11 @@ func convertAssign(dest, src interface{}) error { return nil } + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. switch dv.Kind() { case reflect.Ptr: if src == nil { @@ -306,6 +406,15 @@ func convertAssign(dest, src interface{}) error { } dv.SetFloat(f64) return nil + case reflect.String: + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } } return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) diff --git a/libgo/go/database/sql/convert_test.go b/libgo/go/database/sql/convert_test.go index 4dfab1f..cfe52d7 100644 --- a/libgo/go/database/sql/convert_test.go +++ b/libgo/go/database/sql/convert_test.go @@ -10,6 +10,7 @@ import ( "reflect" "runtime" "strings" + "sync" "testing" "time" ) @@ -17,9 +18,11 @@ import ( var someTime = time.Unix(123, 0) var answer int64 = 42 -type userDefined float64 - -type userDefinedSlice []int +type ( + userDefined float64 + userDefinedSlice []int + userDefinedString string +) type conversionTest struct { s, d interface{} // source and destination @@ -39,6 +42,7 @@ type conversionTest struct { wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr wantnil bool // if true, *d must be *int64(nil) wantusrdef userDefined + wantusrstr userDefinedString } // Target variables for scanning into. @@ -171,6 +175,7 @@ var conversionTests = []conversionTest{ {s: int64(123), d: new(userDefined), wantusrdef: 123}, {s: "1.5", d: new(userDefined), wantusrdef: 1.5}, {s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`}, + {s: "str", d: new(userDefinedString), wantusrstr: "str"}, // Other errors {s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`}, @@ -260,6 +265,9 @@ func TestConversions(t *testing.T) { if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) { errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined)) } + if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) { + errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString)) + } } } @@ -461,8 +469,8 @@ func TestDriverArgs(t *testing.T) { }, } for i, tt := range tests { - ds := new(driverStmt) - got, err := driverArgs(ds, tt.args) + ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}} + got, err := driverArgs(nil, ds, tt.args) if err != nil { t.Errorf("test[%d]: %v", i, err) continue diff --git a/libgo/go/database/sql/driver/driver.go b/libgo/go/database/sql/driver/driver.go index d66196f..0262ca2 100644 --- a/libgo/go/database/sql/driver/driver.go +++ b/libgo/go/database/sql/driver/driver.go @@ -262,9 +262,39 @@ type StmtQueryContext interface { QueryContext(ctx context.Context, args []NamedValue) (Rows, error) } +// ErrRemoveArgument may be returned from NamedValueChecker to instruct the +// sql package to not pass the argument to the driver query interface. +// Return when accepting query specific options or structures that aren't +// SQL query arguments. +var ErrRemoveArgument = errors.New("driver: remove argument from query") + +// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides +// the driver more control to handle Go and database types beyond the default +// Values types allowed. +// +// The sql package checks for value checkers in the following order, +// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker, +// Stmt.ColumnConverter, DefaultParameterConverter. +// +// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in +// the final query arguments. This may be used to pass special options to +// the query itself. +// +// If ErrSkip is returned the column converter error checking +// path is used for the argument. Drivers may wish to return ErrSkip after +// they have exhausted their own special cases. +type NamedValueChecker interface { + // CheckNamedValue is called before passing arguments to the driver + // and is called in place of any ColumnConverter. CheckNamedValue must do type + // validation and conversion as appropriate for the driver. + CheckNamedValue(*NamedValue) error +} + // ColumnConverter may be optionally implemented by Stmt if the // statement is aware of its own columns' types and can convert from // any type to a driver Value. +// +// Deprecated: Drivers should implement NamedValueChecker. type ColumnConverter interface { // ColumnConverter returns a ValueConverter for the provided // column index. If the type of a specific column isn't known diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go index 4b15f5b..4dcd096 100644 --- a/libgo/go/database/sql/fakedb_test.go +++ b/libgo/go/database/sql/fakedb_test.go @@ -58,9 +58,10 @@ type fakeDriver struct { type fakeDB struct { name string - mu sync.Mutex - tables map[string]*table - badConn bool + mu sync.Mutex + tables map[string]*table + badConn bool + allowAny bool } type table struct { @@ -83,11 +84,20 @@ type row struct { cols []interface{} // must be same size as its table colname + coltype } +type memToucher interface { + // touchMem reads & writes some memory, to help find data races. + touchMem() +} + type fakeConn struct { db *fakeDB // where to return ourselves to currTx *fakeTx + // Every operation writes to line to enable the race detector + // check for data races. + line int64 + // Stats for tests: mu sync.Mutex stmtsMade int @@ -99,6 +109,10 @@ type fakeConn struct { stickyBad bool } +func (c *fakeConn) touchMem() { + c.line++ +} + func (c *fakeConn) incrStat(v *int) { c.mu.Lock() *v++ @@ -116,6 +130,7 @@ type boundCol struct { } type fakeStmt struct { + memToucher c *fakeConn q string // just for debugging @@ -298,6 +313,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) { if c.currTx != nil { return nil, errors.New("already in a transaction") } + c.touchMem() c.currTx = &fakeTx{c: c} return c.currTx, nil } @@ -339,6 +355,7 @@ func (c *fakeConn) Close() (err error) { drv.mu.Unlock() } }() + c.touchMem() if c.currTx != nil { return errors.New("can't close fakeConn; in a Transaction") } @@ -352,12 +369,14 @@ func (c *fakeConn) Close() (err error) { return nil } -func checkSubsetTypes(args []driver.NamedValue) error { +func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { for _, arg := range args { switch arg.Value.(type) { case int64, float64, bool, nil, []byte, string, time.Time: default: - return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + if !allowAny { + return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + } } } return nil @@ -373,7 +392,7 @@ func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver. // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't // implement this at all. - err := checkSubsetTypes(args) + err := checkSubsetTypes(c.db.allowAny, args) if err != nil { return nil, err } @@ -390,7 +409,7 @@ func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't // implement this at all. - err := checkSubsetTypes(args) + err := checkSubsetTypes(c.db.allowAny, args) if err != nil { return nil, err } @@ -524,13 +543,14 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm return nil, driver.ErrBadConn } + c.touchMem() var firstStmt, prev *fakeStmt for _, query := range strings.Split(query, ";") { parts := strings.Split(query, "|") if len(parts) < 1 { return nil, errf("empty query") } - stmt := &fakeStmt{q: query, c: c} + stmt := &fakeStmt{q: query, c: c, memToucher: c} if firstStmt == nil { firstStmt = stmt } @@ -612,6 +632,7 @@ func (s *fakeStmt) Close() error { if s.c.db == nil { panic("in fakeStmt.Close, conn's db is nil (already closed)") } + s.touchMem() if !s.closed { s.c.incrStat(&s.c.stmtsClosed) s.closed = true @@ -642,10 +663,11 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d return nil, driver.ErrBadConn } - err := checkSubsetTypes(args) + err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { return nil, err } + s.touchMem() if s.wait > 0 { time.Sleep(s.wait) @@ -753,11 +775,12 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( return nil, driver.ErrBadConn } - err := checkSubsetTypes(args) + err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { return nil, err } + s.touchMem() db := s.c.db if len(args) != s.placeholders { panic("error in pkg db; should only get here if size is correct") @@ -853,11 +876,12 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( } cursor := &rowsCursor{ - posRow: -1, - rows: setMRows, - cols: setColumns, - colType: setColType, - errPos: -1, + parentMem: s.c, + posRow: -1, + rows: setMRows, + cols: setColumns, + colType: setColType, + errPos: -1, } return cursor, nil } @@ -877,6 +901,7 @@ func (tx *fakeTx) Commit() error { if hookCommitBadConn != nil && hookCommitBadConn() { return driver.ErrBadConn } + tx.c.touchMem() return nil } @@ -888,16 +913,18 @@ func (tx *fakeTx) Rollback() error { if hookRollbackBadConn != nil && hookRollbackBadConn() { return driver.ErrBadConn } + tx.c.touchMem() return nil } type rowsCursor struct { - cols [][]string - colType [][]string - posSet int - posRow int - rows [][]*row - closed bool + parentMem memToucher + cols [][]string + colType [][]string + posSet int + posRow int + rows [][]*row + closed bool // errPos and err are for making Next return early with error. errPos int @@ -907,6 +934,16 @@ type rowsCursor struct { // the original slice's first byte address. we clone them // just so we're able to corrupt them on close. bytesClone map[*byte][]byte + + // Every operation writes to line to enable the race detector + // check for data races. + // This is separate from the fakeConn.line to allow for drivers that + // can start multiple queries on the same transaction at the same time. + line int64 +} + +func (rc *rowsCursor) touchMem() { + rc.line++ } func (rc *rowsCursor) Close() error { @@ -915,6 +952,8 @@ func (rc *rowsCursor) Close() error { bs[0] = 255 // first byte corrupted } } + rc.touchMem() + rc.parentMem.touchMem() rc.closed = true return nil } @@ -937,6 +976,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { if rc.closed { return errors.New("fakedb: cursor is closed") } + rc.touchMem() rc.posRow++ if rc.posRow == rc.errPos { return rc.err @@ -970,10 +1010,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { } func (rc *rowsCursor) HasNextResultSet() bool { + rc.touchMem() return rc.posSet < len(rc.rows)-1 } func (rc *rowsCursor) NextResultSet() error { + rc.touchMem() if rc.HasNextResultSet() { rc.posSet++ rc.posRow = -1 @@ -1004,6 +1046,12 @@ func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { return fmt.Sprintf("%v", v), nil } +type anyTypeConverter struct{} + +func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { + return v, nil +} + func converterForType(typ string) driver.ValueConverter { switch typ { case "bool": @@ -1030,6 +1078,8 @@ func converterForType(typ string) driver.ValueConverter { return driver.Null{Converter: driver.DefaultParameterConverter} case "datetime": return driver.DefaultParameterConverter + case "any": + return anyTypeConverter{} } panic("invalid fakedb column type of " + typ) } @@ -1056,6 +1106,8 @@ func colTypeToReflectType(typ string) reflect.Type { return reflect.TypeOf(NullFloat64{}) case "datetime": return reflect.TypeOf(time.Time{}) + case "any": + return reflect.TypeOf(new(interface{})).Elem() } panic("invalid fakedb column type of " + typ) } diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go index f8a8844..c609fe4 100644 --- a/libgo/go/database/sql/sql.go +++ b/libgo/go/database/sql/sql.go @@ -278,6 +278,27 @@ type Scanner interface { Scan(src interface{}) error } +// Out may be used to retrieve OUTPUT value parameters from stored procedures. +// +// Not all drivers and databases support OUTPUT value parameters. +// +// Example usage: +// +// var outArg string +// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg})) +type Out struct { + _Named_Fields_Required struct{} + + // Dest is a pointer to the value that will be set to the result of the + // stored procedure's OUTPUT parameter. + Dest interface{} + + // In is whether the parameter is an INOUT parameter. If so, the input value to the stored + // procedure is the dereferenced value of Dest's pointer, which is then replaced with + // the output value. + In bool +} + // ErrNoRows is returned by Scan when QueryRow doesn't return a // row. In such a case, QueryRow returns a placeholder *Row value that // defers this error until a Scan. @@ -372,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) { +// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of +// the prepared statements in a pool. +func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) { si, err := ctxDriverPrepare(ctx, dc.ci, query) if err != nil { return nil, err } + ds := &driverStmt{Locker: dc, si: si} + + // No need to manage open statements if there is a single connection grabber. + if cg != nil { + return ds, nil + } // Track each driverConn's open statements, so we can close them // before closing the conn. @@ -385,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS if dc.openStmt == nil { dc.openStmt = make(map[*driverStmt]bool) } - ds := &driverStmt{Locker: dc, si: si} dc.openStmt[ds] = true - return ds, nil } @@ -583,6 +610,17 @@ func Open(driverName, dataSourceName string) (*DB, error) { return db, nil } +func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error { + var err error + if pinger, ok := dc.ci.(driver.Pinger); ok { + withLock(dc, func() { + err = pinger.Ping(ctx) + }) + } + release(err) + return err +} + // PingContext verifies a connection to the database is still alive, // establishing a connection if necessary. func (db *DB) PingContext(ctx context.Context) error { @@ -602,11 +640,7 @@ func (db *DB) PingContext(ctx context.Context) error { return err } - if pinger, ok := dc.ci.(driver.Pinger); ok { - err = pinger.Ping(ctx) - } - db.putConn(dc, err) - return err + return db.pingDC(ctx, dc, dc.releaseConn) } // Ping verifies a connection to the database is still alive, @@ -975,9 +1009,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn db: db, createdAt: nowFunc(), ci: ci, + inUse: true, } db.addDepLocked(dc, dc) - dc.inUse = true db.mu.Unlock() return dc, nil } @@ -1137,22 +1171,39 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat if err != nil { return nil, err } + return db.prepareDC(ctx, dc, dc.releaseConn, nil, query) +} + +// prepareDC prepares a query on the driverConn and calls release before +// returning. When cg == nil it implies that a connection pool is used, and +// when cg != nil only a single driver connection is used. +func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) { var ds *driverStmt + var err error + defer func() { + release(err) + }() withLock(dc, func() { - ds, err = dc.prepareLocked(ctx, query) + ds, err = dc.prepareLocked(ctx, cg, query) }) if err != nil { - db.putConn(dc, err) return nil, err } stmt := &Stmt{ - db: db, - query: query, - css: []connStmt{{dc, ds}}, - lastNumClosed: atomic.LoadUint64(&db.numClosed), + db: db, + query: query, + cg: cg, + cgds: ds, + } + + // When cg == nil this statement will need to keep track of various + // connections they are prepared on and record the stmt dependency on + // the DB. + if cg == nil { + stmt.css = []connStmt{{dc, ds}} + stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed) + db.addDep(stmt, stmt) } - db.addDep(stmt, stmt) - db.putConn(dc, nil) return stmt, nil } @@ -1179,18 +1230,21 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { return db.ExecContext(context.Background(), query, args...) } -func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { +func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (Result, error) { dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } + return db.execDC(ctx, dc, dc.releaseConn, query, args) +} + +func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []interface{}) (res Result, err error) { defer func() { - db.putConn(dc, err) + release(err) }() - if execer, ok := dc.ci.(driver.Execer); ok { var dargs []driver.NamedValue - dargs, err = driverArgs(nil, args) + dargs, err = driverArgs(dc.ci, nil, args) if err != nil { return nil, err } @@ -1215,7 +1269,7 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate } ds := &driverStmt{Locker: dc, si: si} defer ds.Close() - return resultFromStatement(ctx, ds, args...) + return resultFromStatement(ctx, dc.ci, ds, args...) } // QueryContext executes a query that returns rows, typically a SELECT. @@ -1242,19 +1296,21 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { } func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { - ci, err := db.conn(ctx, strategy) + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } - return db.queryConn(ctx, ci, ci.releaseConn, query, args) + return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args) } -// queryConn executes a query on the given connection. +// queryDC executes a query on the given connection. // The connection gets released by the releaseConn function. -func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { +// The ctx context is from a query method and the txctx context is from an +// optional transaction context. +func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { if queryer, ok := dc.ci.(driver.Queryer); ok { - dargs, err := driverArgs(nil, args) + dargs, err := driverArgs(dc.ci, nil, args) if err != nil { releaseConn(err) return nil, err @@ -1275,7 +1331,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er releaseConn: releaseConn, rowsi: rowsi, } - rows.initContextClose(ctx) + rows.initContextClose(ctx, txctx) return rows, nil } } @@ -1291,7 +1347,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er } ds := &driverStmt{Locker: dc, si: si} - rowsi, err := rowsiFromStatement(ctx, ds, args...) + rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...) if err != nil { ds.Close() releaseConn(err) @@ -1306,13 +1362,16 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er rowsi: rowsi, closeStmt: ds, } - rows.initContextClose(ctx) + rows.initContextClose(ctx, txctx) return rows, nil } // QueryRowContext executes a query that is expected to return at most one row. // QueryRowContext always returns a non-nil value. Errors are deferred until // Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := db.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err} @@ -1321,6 +1380,9 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interfa // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. func (db *DB) QueryRow(query string, args ...interface{}) *Row { return db.QueryRowContext(context.Background(), query, args...) } @@ -1361,12 +1423,17 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra if err != nil { return nil, err } + return db.beginDC(ctx, dc, dc.releaseConn, opts) +} + +// beginDC starts a transaction. The provided dc must be valid and ready to use. +func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) { var txi driver.Tx withLock(dc, func() { txi, err = ctxDriverBegin(ctx, opts, dc.ci) }) if err != nil { - db.putConn(dc, err) + release(err) return nil, err } @@ -1374,11 +1441,12 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra // The cancel function in Tx will be called after done is set to true. ctx, cancel := context.WithCancel(ctx) tx = &Tx{ - db: db, - dc: dc, - txi: txi, - cancel: cancel, - ctx: ctx, + db: db, + dc: dc, + releaseConn: release, + txi: txi, + cancel: cancel, + ctx: ctx, } go tx.awaitDone() return tx, nil @@ -1389,6 +1457,189 @@ func (db *DB) Driver() driver.Driver { return db.driver } +// ErrConnDone is returned by any operation that is performed on a connection +// that has already been committed or rolled back. +var ErrConnDone = errors.New("database/sql: connection is already closed") + +// Conn returns a single connection by either opening a new connection +// or returning an existing connection from the connection pool. Conn will +// block until either a connection is returned or ctx is canceled. +// Queries run on the same Conn will be run in the same database session. +// +// Every Conn must be returned to the database pool after use by +// calling Conn.Close. +func (db *DB) Conn(ctx context.Context) (*Conn, error) { + var dc *driverConn + var err error + for i := 0; i < maxBadConnRetries; i++ { + dc, err = db.conn(ctx, cachedOrNewConn) + if err != driver.ErrBadConn { + break + } + } + if err == driver.ErrBadConn { + dc, err = db.conn(ctx, cachedOrNewConn) + } + if err != nil { + return nil, err + } + + conn := &Conn{ + db: db, + dc: dc, + } + return conn, nil +} + +type releaseConn func(error) + +// Conn represents a single database session rather a pool of database +// sessions. Prefer running queries from DB unless there is a specific +// need for a continuous single database session. +// +// A Conn must call Close to return the connection to the database pool +// and may do so concurrently with a running query. +// +// After a call to Close, all operations on the +// connection fail with ErrConnDone. +type Conn struct { + db *DB + + // closemu prevents the connection from closing while there + // is an active query. It is held for read during queries + // and exclusively during close. + closemu sync.RWMutex + + // dc is owned until close, at which point + // it's returned to the connection pool. + dc *driverConn + + // done transitions from 0 to 1 exactly once, on close. + // Once done, all operations fail with ErrConnDone. + // Use atomic operations on value when checking value. + done int32 +} + +func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) { + if atomic.LoadInt32(&c.done) != 0 { + return nil, nil, ErrConnDone + } + c.closemu.RLock() + return c.dc, c.closemuRUnlockCondReleaseConn, nil +} + +// PingContext verifies the connection to the database is still alive. +func (c *Conn) PingContext(ctx context.Context) error { + dc, release, err := c.grabConn(ctx) + if err != nil { + return err + } + return c.db.pingDC(ctx, dc, release) +} + +// ExecContext executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.execDC(ctx, dc, release, query, args) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.queryDC(ctx, nil, dc, release, query, args) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := c.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// PrepareContext creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.prepareDC(ctx, dc, release, c, query) +} + +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.beginDC(ctx, dc, release, opts) +} + +// closemuRUnlockCondReleaseConn read unlocks closemu +// as the sql operation is done with the dc. +func (c *Conn) closemuRUnlockCondReleaseConn(err error) { + c.closemu.RUnlock() + if err == driver.ErrBadConn { + c.close(err) + } +} + +func (c *Conn) txCtx() context.Context { + return nil +} + +func (c *Conn) close(err error) error { + if !atomic.CompareAndSwapInt32(&c.done, 0, 1) { + return ErrConnDone + } + + // Lock around releasing the driver connection + // to ensure all queries have been stopped before doing so. + c.closemu.Lock() + defer c.closemu.Unlock() + + c.dc.releaseConn(err) + c.dc = nil + c.db = nil + return err +} + +// Close returns the connection to the connection pool. +// All operations after a Close will return with ErrConnDone. +// Close is safe to call concurrently with other operations and will +// block until all other operations finish. It may be useful to first +// cancel any used context and then call close directly after. +func (c *Conn) Close() error { + return c.close(nil) +} + // Tx is an in-progress database transaction. // // A transaction must end with a call to Commit or Rollback. @@ -1412,6 +1663,10 @@ type Tx struct { dc *driverConn txi driver.Tx + // releaseConn is called once the Tx is closed to release + // any held driverConn back to the pool. + releaseConn func(error) + // done transitions from 0 to 1 exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. @@ -1425,7 +1680,7 @@ type Tx struct { v []*Stmt } - // cancel is called after done transitions from false to true. + // cancel is called after done transitions from 0 to 1. cancel func() // ctx lives for the life of the transaction. @@ -1457,11 +1712,12 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle // close returns the connection to the pool and // must only be called by Tx.rollback or Tx.Commit. func (tx *Tx) close(err error) { + tx.cancel() + tx.closemu.Lock() defer tx.closemu.Unlock() - tx.db.putConn(tx.dc, err) - tx.cancel() + tx.releaseConn(err) tx.dc = nil tx.txi = nil } @@ -1470,19 +1726,36 @@ func (tx *Tx) close(err error) { // a successful call to (*Tx).grabConn. For tests. var hookTxGrabConn func() -func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) { select { default: case <-ctx.Done(): - return nil, ctx.Err() + return nil, nil, ctx.Err() } + + // closeme.RLock must come before the check for isDone to prevent the Tx from + // closing while a query is executing. + tx.closemu.RLock() if tx.isDone() { - return nil, ErrTxDone + tx.closemu.RUnlock() + return nil, nil, ErrTxDone } if hookTxGrabConn != nil { // test hook hookTxGrabConn() } - return tx.dc, nil + return tx.dc, tx.closemuRUnlockRelease, nil +} + +func (tx *Tx) txCtx() context.Context { + return tx.ctx +} + +// closemuRUnlockRelease is used as a func(error) method value in +// ExecContext and QueryContext. Unlocking in the releaseConn keeps +// the driver conn from being returned to the connection pool until +// the Rows has been closed. +func (tx *Tx) closemuRUnlockRelease(error) { + tx.closemu.RUnlock() } // Closes all Stmts prepared for this transaction. @@ -1540,7 +1813,7 @@ func (tx *Tx) Rollback() error { return tx.rollback(false) } -// Prepare creates a prepared statement for use within a transaction. +// PrepareContext creates a prepared statement for use within a transaction. // // The returned statement operates within the transaction and will be closed // when the transaction has been committed or rolled back. @@ -1551,44 +1824,15 @@ func (tx *Tx) Rollback() error { // for the execution of the returned statement. The returned statement // will run in the transaction context. func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { - tx.closemu.RLock() - defer tx.closemu.RUnlock() - - // TODO(bradfitz): We could be more efficient here and either - // provide a method to take an existing Stmt (created on - // perhaps a different Conn), and re-create it on this Conn if - // necessary. Or, better: keep a map in DB of query string to - // Stmts, and have Stmt.Execute do the right thing and - // re-prepare if the Conn in use doesn't have that prepared - // statement. But we'll want to avoid caching the statement - // in the case where we only call conn.Prepare implicitly - // (such as in db.Exec or tx.Exec), but the caller package - // can't be holding a reference to the returned statement. - // Perhaps just looking at the reference count (by noting - // Stmt.Close) would be enough. We might also want a finalizer - // on Stmt to drop the reference count. - dc, err := tx.grabConn(ctx) + dc, release, err := tx.grabConn(ctx) if err != nil { return nil, err } - var si driver.Stmt - withLock(dc, func() { - si, err = ctxDriverPrepare(ctx, dc.ci, query) - }) + stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query) if err != nil { return nil, err } - - stmt := &Stmt{ - db: tx.db, - tx: tx, - txds: &driverStmt{ - Locker: dc, - si: si, - }, - query: query, - } tx.stmts.Lock() tx.stmts.v = append(tx.stmts.v, stmt) tx.stmts.Unlock() @@ -1618,34 +1862,67 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // The returned statement operates within the transaction and will be closed // when the transaction has been committed or rolled back. func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { - tx.closemu.RLock() - defer tx.closemu.RUnlock() - - // TODO(bradfitz): optimize this. Currently this re-prepares - // each time. This is fine for now to illustrate the API but - // we should really cache already-prepared statements - // per-Conn. See also the big comment in Tx.Prepare. + dc, release, err := tx.grabConn(ctx) + if err != nil { + return &Stmt{stickyErr: err} + } + defer release(nil) if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - dc, err := tx.grabConn(ctx) - if err != nil { - return &Stmt{stickyErr: err} - } var si driver.Stmt - withLock(dc, func() { - si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) - }) + var parentStmt *Stmt + stmt.mu.Lock() + if stmt.closed || stmt.cg != nil { + // If the statement has been closed or already belongs to a + // transaction, we can't reuse it in this connection. + // Since tx.StmtContext should never need to be called with a + // Stmt already belonging to tx, we ignore this edge case and + // re-prepare the statement in this case. No need to add + // code-complexity for this. + stmt.mu.Unlock() + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + } else { + stmt.removeClosedStmtLocked() + // See if the statement has already been prepared on this connection, + // and reuse it if possible. + for _, v := range stmt.css { + if v.dc == dc { + si = v.ds.si + break + } + } + + stmt.mu.Unlock() + + if si == nil { + cs, err := stmt.prepareOnConnLocked(ctx, dc) + if err != nil { + return &Stmt{stickyErr: err} + } + si = cs.si + } + parentStmt = stmt + } + txs := &Stmt{ db: tx.db, - tx: tx, - txds: &driverStmt{ + cg: tx, + cgds: &driverStmt{ Locker: dc, si: si, }, - query: stmt.query, - stickyErr: err, + parentStmt: parentStmt, + query: stmt.query, + } + if parentStmt != nil { + tx.db.addDep(parentStmt, txs) } tx.stmts.Lock() tx.stmts.v = append(tx.stmts.v, txs) @@ -1672,42 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { // ExecContext executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { - tx.closemu.RLock() - defer tx.closemu.RUnlock() - - dc, err := tx.grabConn(ctx) + dc, release, err := tx.grabConn(ctx) if err != nil { return nil, err } - - if execer, ok := dc.ci.(driver.Execer); ok { - dargs, err := driverArgs(nil, args) - if err != nil { - return nil, err - } - var resi driver.Result - withLock(dc, func() { - resi, err = ctxDriverExec(ctx, execer, query, dargs) - }) - if err == nil { - return driverResult{dc, resi}, nil - } - if err != driver.ErrSkip { - return nil, err - } - } - - var si driver.Stmt - withLock(dc, func() { - si, err = ctxDriverPrepare(ctx, dc.ci, query) - }) - if err != nil { - return nil, err - } - ds := &driverStmt{Locker: dc, si: si} - defer ds.Close() - - return resultFromStatement(ctx, ds, args...) + return tx.db.execDC(ctx, dc, release, query, args) } // Exec executes a query that doesn't return rows. @@ -1718,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { // QueryContext executes a query that returns rows, typically a SELECT. func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - tx.closemu.RLock() - defer tx.closemu.RUnlock() - - dc, err := tx.grabConn(ctx) + dc, release, err := tx.grabConn(ctx) if err != nil { return nil, err } - releaseConn := func(error) {} - return tx.db.queryConn(ctx, dc, releaseConn, query, args) + + return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args) } // Query executes a query that returns rows, typically a SELECT. @@ -1737,6 +1980,9 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { // QueryRowContext executes a query that is expected to return at most one row. // QueryRowContext always returns a non-nil value. Errors are deferred until // Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := tx.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err} @@ -1745,6 +1991,9 @@ func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interfa // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { return tx.QueryRowContext(context.Background(), query, args...) } @@ -1755,6 +2004,24 @@ type connStmt struct { ds *driverStmt } +// stmtConnGrabber represents a Tx or Conn that will return the underlying +// driverConn and release function. +type stmtConnGrabber interface { + // grabConn returns the driverConn and the associated release function + // that must be called when the operation completes. + grabConn(context.Context) (*driverConn, releaseConn, error) + + // txCtx returns the transaction context if available. + // The returned context should be selected on along with + // any query context when awaiting a cancel. + txCtx() context.Context +} + +var ( + _ stmtConnGrabber = &Tx{} + _ stmtConnGrabber = &Conn{} +) + // Stmt is a prepared statement. // A Stmt is safe for concurrent use by multiple goroutines. type Stmt struct { @@ -1765,17 +2032,29 @@ type Stmt struct { closemu sync.RWMutex // held exclusively during close, for read otherwise. - // If in a transaction, else both nil: - tx *Tx - txds *driverStmt + // If Stmt is prepared on a Tx or Conn then cg is present and will + // only ever grab a connection from cg. + // If cg is nil then the Stmt must grab an arbitrary connection + // from db and determine if it must prepare the stmt again by + // inspecting css. + cg stmtConnGrabber + cgds *driverStmt + + // parentStmt is set when a transaction-specific statement + // is requested from an identical statement prepared on the same + // conn. parentStmt is used to track the dependency of this statement + // on its originating ("parent") statement so that parentStmt may + // be closed by the user without them having to know whether or not + // any transactions are still using it. + parentStmt *Stmt mu sync.Mutex // protects the rest of the fields closed bool // css is a list of underlying driver statement interfaces // that are valid on particular connections. This is only - // used if tx == nil and one is found that has idle - // connections. If tx != nil, txsi is always used. + // used if cg == nil and one is found that has idle + // connections. If cg != nil, cgds is always used. css []connStmt // lastNumClosed is copied from db.numClosed when Stmt is created @@ -1790,8 +2069,12 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er defer s.closemu.RUnlock() var res Result - for i := 0; i < maxBadConnRetries; i++ { - _, releaseConn, ds, err := s.connStmt(ctx) + strategy := cachedOrNewConn + for i := 0; i < maxBadConnRetries+1; i++ { + if i == maxBadConnRetries { + strategy = alwaysNewConn + } + dc, releaseConn, ds, err := s.connStmt(ctx, strategy) if err != nil { if err == driver.ErrBadConn { continue @@ -1799,7 +2082,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er return nil, err } - res, err = resultFromStatement(ctx, ds, args...) + res, err = resultFromStatement(ctx, dc.ci, ds, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -1814,23 +2097,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return s.ExecContext(context.Background(), args...) } -func driverNumInput(ds *driverStmt) int { - ds.Lock() - defer ds.Unlock() // in case NumInput panics - return ds.si.NumInput() -} - -func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) { - want := driverNumInput(ds) - - // -1 means the driver doesn't know how to count the number of - // placeholders, so we won't sanity check input here and instead let the - // driver deal with errors. - if want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) - } - - dargs, err := driverArgs(ds, args) +func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) { + dargs, err := driverArgs(ci, ds, args) if err != nil { return nil, err } @@ -1874,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) { +func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) { if err = s.stickyErr; err != nil { return } @@ -1885,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e return } - // In a transaction, we always use the connection that the - // transaction was created on. - if s.tx != nil { + // In a transaction or connection, we always use the connection that the + // the stmt was created on. + if s.cg != nil { s.mu.Unlock() - ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection. + dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection. if err != nil { return } - releaseConn = func(error) {} - return ci, releaseConn, s.txds, nil + return dc, releaseConn, s.cgds, nil } s.removeClosedStmtLocked() s.mu.Unlock() - dc, err := s.db.conn(ctx, cachedOrNewConn) + dc, err = s.db.conn(ctx, strategy) if err != nil { return nil, nil, nil, err } @@ -1916,18 +2183,28 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e // No luck; we need to prepare the statement on this connection withLock(dc, func() { - ds, err = dc.prepareLocked(ctx, s.query) + ds, err = s.prepareOnConnLocked(ctx, dc) }) if err != nil { - s.db.putConn(dc, err) + dc.releaseConn(err) return nil, nil, nil, err } + + return dc, dc.releaseConn, ds, nil +} + +// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of +// open connStmt on the statement. It assumes the caller is holding the lock on dc. +func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) { + si, err := dc.prepareLocked(ctx, s.cg, s.query) + if err != nil { + return nil, err + } + cs := connStmt{dc, si} s.mu.Lock() - cs := connStmt{dc, ds} s.css = append(s.css, cs) s.mu.Unlock() - - return dc, dc.releaseConn, ds, nil + return cs.ds, nil } // QueryContext executes a prepared query statement with the given arguments @@ -1937,8 +2214,12 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er defer s.closemu.RUnlock() var rowsi driver.Rows - for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, ds, err := s.connStmt(ctx) + strategy := cachedOrNewConn + for i := 0; i < maxBadConnRetries+1; i++ { + if i == maxBadConnRetries { + strategy = alwaysNewConn + } + dc, releaseConn, ds, err := s.connStmt(ctx, strategy) if err != nil { if err == driver.ErrBadConn { continue @@ -1946,7 +2227,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return nil, err } - rowsi, err = rowsiFromStatement(ctx, ds, args...) + rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -1955,12 +2236,21 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er rowsi: rowsi, // releaseConn set below } + // addDep must be added before initContextClose or it could attempt + // to removeDep before it has been added. s.db.addDep(s, rows) + + // releaseConn must be set before initContextClose or it could + // release the connection before it is set. rows.releaseConn = func(err error) { releaseConn(err) s.db.removeDep(s, rows) } - rows.initContextClose(ctx) + var txctx context.Context + if s.cg != nil { + txctx = s.cg.txCtx() + } + rows.initContextClose(ctx, txctx) return rows, nil } @@ -1978,7 +2268,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } -func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { +func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) { var want int withLock(ds, func() { want = ds.si.NumInput() @@ -1991,7 +2281,7 @@ func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{} return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(ds, args) + dargs, err := driverArgs(ci, ds, args) if err != nil { return nil, err } @@ -2054,13 +2344,21 @@ func (s *Stmt) Close() error { return nil } s.closed = true + txds := s.cgds + s.cgds = nil + s.mu.Unlock() - if s.tx != nil { - return s.txds.Close() + if s.cg == nil { + return s.db.removeDep(s, s) } - return s.db.removeDep(s, s) + if s.parentStmt != nil { + // If parentStmt is set, we must not close s.txds since it's stored + // in the css array of the parentStmt. + return s.db.removeDep(s.parentStmt, s) + } + return txds.Close() } func (s *Stmt) finalClose() error { @@ -2107,18 +2405,28 @@ type Rows struct { lasterr error // non-nil only if closed is true // lastcols is only used in Scan, Next, and NextResultSet which are expected - // not not be called concurrently. + // not to be called concurrently. lastcols []driver.Value } -func (rs *Rows) initContextClose(ctx context.Context) { +func (rs *Rows) initContextClose(ctx, txctx context.Context) { ctx, rs.cancel = context.WithCancel(ctx) - go rs.awaitDone(ctx) + go rs.awaitDone(ctx, txctx) } -// awaitDone blocks until the rows are closed or the context canceled. -func (rs *Rows) awaitDone(ctx context.Context) { - <-ctx.Done() +// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided +// from the query context and is canceled when the query Rows is closed. +// If the query was issued in a transaction, the transaction's context +// is also provided in txctx to ensure Rows is closed if the Tx is closed. +func (rs *Rows) awaitDone(ctx, txctx context.Context) { + var txctxDone <-chan struct{} + if txctx != nil { + txctxDone = txctx.Done() + } + select { + case <-ctx.Done(): + case <-txctxDone: + } rs.close(ctx.Err()) } @@ -2407,7 +2715,7 @@ func (rs *Rows) Scan(dest ...interface{}) error { } // rowsCloseHook returns a function so tests may install the -// hook throug a test only mutex. +// hook through a test only mutex. var rowsCloseHook = func() func(*Rows, *error) { return nil } // Close closes the Rows, preventing further enumeration. If Next is called @@ -2431,7 +2739,9 @@ func (rs *Rows) close(err error) error { rs.lasterr = err } - err = rs.rowsi.Close() + withLock(rs.dc, func() { + err = rs.rowsi.Close() + }) if fn := rowsCloseHook(); fn != nil { fn(rs, &err) } diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go index 381aafc..c935eb4 100644 --- a/libgo/go/database/sql/sql_test.go +++ b/libgo/go/database/sql/sql_test.go @@ -139,6 +139,7 @@ func closeDB(t testing.TB, db *DB) { t.Errorf("Error closing fakeConn: %v", err) } }) + db.mu.Lock() for i, dc := range db.freeConn { if n := len(dc.openStmt); n > 0 { // Just a sanity check. This is legal in @@ -149,6 +150,8 @@ func closeDB(t testing.TB, db *DB) { t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n) } } + db.mu.Unlock() + err := db.Close() if err != nil { t.Fatalf("error closing DB: %v", err) @@ -322,7 +325,7 @@ func TestQueryContext(t *testing.T) { select { case <-ctx.Done(): if err := ctx.Err(); err != context.Canceled { - t.Fatalf("context err = %v; want context.Canceled", ctx.Err()) + t.Fatalf("context err = %v; want context.Canceled", err) } default: t.Fatalf("context err = nil; want context.Canceled") @@ -413,7 +416,7 @@ func TestTxContextWait(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*15) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) defer cancel() tx, err := db.BeginTx(ctx, nil) @@ -590,13 +593,13 @@ func TestPoolExhaustOnCancel(t *testing.T) { saturate.Wait() // Now cancel the request while it is waiting. - ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() for i := 0; i < max; i++ { ctxReq, cancelReq := context.WithCancel(ctx) go func() { - time.Sleep(time.Millisecond * 100) + time.Sleep(100 * time.Millisecond) cancelReq() }() err := db.PingContext(ctxReq) @@ -874,7 +877,7 @@ func TestStatementClose(t *testing.T) { msg string }{ {&Stmt{stickyErr: want}, "stickyErr not propagated"}, - {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, } for _, test := range tests { if err := test.stmt.Close(); err != want { @@ -1024,6 +1027,196 @@ func TestTxStmt(t *testing.T) { } } +func TestTxStmtPreparedOnce(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + + txs1 := tx.Stmt(stmt) + txs2 := tx.Stmt(stmt) + + _, err = txs1.Exec("Go", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + txs1.Close() + + _, err = txs2.Exec("Gopher", 8) + if err != nil { + t.Fatalf("Exec = %v", err) + } + txs2.Close() + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestTxStmtClosedRePrepares(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("stmt.Close() = %v", err) + } + // tx.Stmt increments numPrepares because stmt is closed. + txs := tx.Stmt(stmt) + if txs.stickyErr != nil { + t.Fatal(txs.stickyErr) + } + if txs.parentStmt != nil { + t.Fatal("expected nil parentStmt") + } + _, err = txs.Exec(`Eric`, 82) + if err != nil { + t.Fatalf("txs.Exec = %v", err) + } + + err = txs.Close() + if err != nil { + t.Fatalf("txs.Close = %v", err) + } + + tx.Rollback() + + if prepares := numPrepares(t, db) - prepares0; prepares != 2 { + t.Errorf("executed %d Prepare statements; want 2", prepares) + } +} + +func TestParentStmtOutlivesTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + // Make sure everything happens on the same connection. + db.SetMaxOpenConns(1) + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + if len(stmt.css) != 1 { + t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css)) + } + err = txs.Close() + if err != nil { + t.Fatalf("txs.Close() = %v", err) + } + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback() = %v", err) + } + // txs must not be valid. + _, err = txs.Exec("Suzan", 30) + if err == nil { + t.Fatalf("txs.Exec(), expected err") + } + // Stmt must still be valid. + _, err = stmt.Exec("Janina", 25) + if err != nil { + t.Fatalf("stmt.Exec() = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +// Test that tx.Stmt called with a statement already +// associated with tx as argument re-prepares the same +// statement again. +func TestTxStmtFromTxStmtRePrepares(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + prepares0 := numPrepares(t, db) + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs1 := tx.Stmt(stmt) + + // tx.Stmt(txs1) increments numPrepares because txs1 already + // belongs to a transaction (albeit the same transaction). + txs2 := tx.Stmt(txs1) + if txs2.stickyErr != nil { + t.Fatal(txs2.stickyErr) + } + if txs2.parentStmt != nil { + t.Fatal("expected nil parentStmt") + } + _, err = txs2.Exec(`Eric`, 82) + if err != nil { + t.Fatal(err) + } + + err = txs1.Close() + if err != nil { + t.Fatalf("txs1.Close = %v", err) + } + err = txs2.Close() + if err != nil { + t.Fatalf("txs1.Close = %v", err) + } + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 2 { + t.Errorf("executed %d Prepare statements; want 2", prepares) + } +} + // Issue: https://golang.org/issue/2784 // This test didn't fail before because we got lucky with the fakedb driver. // It was failing, and now not, in github.com/bradfitz/go-sql-test @@ -1108,6 +1301,69 @@ func TestTxErrBadConn(t *testing.T) { } } +func TestConnQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + var name string + err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name) + if err != nil { + t.Fatal(err) + } + if name != "Chris" { + t.Fatalf("unexpected result, got %q want Chris", name) + } + + err = conn.PingContext(ctx) + if err != nil { + t.Fatal(err) + } +} + +func TestConnTx(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + insertName, insertAge := "Nancy", 33 + _, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge) + if err != nil { + t.Fatal(err) + } + err = tx.Commit() + if err != nil { + t.Fatal(err) + } + + var selectName string + err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName) + if err != nil { + t.Fatal(err) + } + if selectName != insertName { + t.Fatalf("got %q want %q", selectName, insertName) + } +} + // Tests fix for issue 2542, that we release a lock when querying on // a closed connection. func TestIssue2542Deadlock(t *testing.T) { @@ -1831,8 +2087,8 @@ func TestConnMaxLifetime(t *testing.T) { } // Expire first conn - offset = time.Second * 11 - db.SetConnMaxLifetime(time.Second * 10) + offset = 11 * time.Second + db.SetConnMaxLifetime(10 * time.Second) if err != nil { t.Fatal(err) } @@ -2078,9 +2334,13 @@ func TestStmtCloseOrder(t *testing.T) { // Test cases where there's more than maxBadConnRetries bad connections in the // pool (issue 8834) func TestManyErrBadConn(t *testing.T) { - manyErrBadConnSetup := func() *DB { + manyErrBadConnSetup := func(first ...func(db *DB)) *DB { db := newTestDB(t, "people") + for _, f := range first { + f(db) + } + nconn := maxBadConnRetries + 1 db.SetMaxIdleConns(nconn) db.SetMaxOpenConns(nconn) @@ -2148,6 +2408,128 @@ func TestManyErrBadConn(t *testing.T) { if err = stmt.Close(); err != nil { t.Fatal(err) } + + // Stmt.Exec + db = manyErrBadConnSetup(func(db *DB) { + stmt, err = db.Prepare("INSERT|people|name=Julia,age=19") + if err != nil { + t.Fatal(err) + } + }) + defer closeDB(t, db) + _, err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + if err = stmt.Close(); err != nil { + t.Fatal(err) + } + + // Stmt.Query + db = manyErrBadConnSetup(func(db *DB) { + stmt, err = db.Prepare("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + }) + defer closeDB(t, db) + rows, err = stmt.Query() + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + if err = stmt.Close(); err != nil { + t.Fatal(err) + } + + // Conn + db = manyErrBadConnSetup() + defer closeDB(t, db) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + err = conn.Close() + if err != nil { + t.Fatal(err) + } + + // Ping + db = manyErrBadConnSetup() + defer closeDB(t, db) + err = db.PingContext(ctx) + if err != nil { + t.Fatal(err) + } +} + +// TestIssue20575 ensures the Rows from query does not block +// closing a transaction. Ensure Rows is closed while closing a trasaction. +func TestIssue20575(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err = tx.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + // Do not close Rows from QueryContext. + err = tx.Rollback() + if err != nil { + t.Fatal(err) + } + select { + default: + case <-ctx.Done(): + t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err()) + } +} + +// TestIssue20622 tests closing the transaction before rows is closed, requires +// the race detector to fail. +func TestIssue20622(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + rows, err := tx.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + + count := 0 + for rows.Next() { + count++ + var age int + var name string + if err := rows.Scan(&age, &name); err != nil { + t.Fatal("scan failed", err) + } + + if count == 1 { + cancel() + } + time.Sleep(100 * time.Millisecond) + } + rows.Close() + tx.Commit() } // golang.org/issue/5718 @@ -2751,7 +3133,7 @@ func TestIssue18429(t *testing.T) { if err != nil { return } - // This is expected to give a cancel error many, but not all the time. + // This is expected to give a cancel error most, but not all the time. // Test failure will happen with a panic or other race condition being // reported. rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") @@ -2766,6 +3148,46 @@ func TestIssue18429(t *testing.T) { wg.Wait() } +// TestIssue20160 attempts to test a short context life on a stmt Query. +func TestIssue20160(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx := context.Background() + sem := make(chan bool, 20) + var wg sync.WaitGroup + + const milliWait = 30 + + stmt, err := db.PrepareContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + for i := 0; i < 100; i++ { + sem <- true + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond) + defer cancel() + + // This is expected to give a cancel error most, but not all the time. + // Test failure will happen with a panic or other race condition being + // reported. + rows, _ := stmt.QueryContext(ctx) + if rows != nil { + rows.Close() + } + }() + } + wg.Wait() +} + // TestIssue18719 closes the context right before use. The sql.driverConn // will nil out the ci on close in a lock, but if another process uses it right after // it will panic with on the nil ref. @@ -2788,7 +3210,7 @@ func TestIssue18719(t *testing.T) { // Wait for the context to cancel and tx to rollback. for tx.isDone() == false { - time.Sleep(time.Millisecond * 3) + time.Sleep(3 * time.Millisecond) } } defer func() { hookTxGrabConn = nil }() @@ -2807,19 +3229,64 @@ func TestIssue18719(t *testing.T) { // canceled context. cancel() - waitForRowsClose(t, rows, 5*time.Second) +} + +func TestIssue20647(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows1, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatal("rows1", err) + } + defer rows1.Close() + + rows2, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatal("rows2", err) + } + defer rows2.Close() + + if rows1.dc != rows2.dc { + t.Fatal("stmt prepared on Conn does not use same connection") + } } func TestConcurrency(t *testing.T) { - doConcurrentTest(t, new(concurrentDBQueryTest)) - doConcurrentTest(t, new(concurrentDBExecTest)) - doConcurrentTest(t, new(concurrentStmtQueryTest)) - doConcurrentTest(t, new(concurrentStmtExecTest)) - doConcurrentTest(t, new(concurrentTxQueryTest)) - doConcurrentTest(t, new(concurrentTxExecTest)) - doConcurrentTest(t, new(concurrentTxStmtQueryTest)) - doConcurrentTest(t, new(concurrentTxStmtExecTest)) - doConcurrentTest(t, new(concurrentRandomTest)) + list := []struct { + name string + ct concurrentTest + }{ + {"Query", new(concurrentDBQueryTest)}, + {"Exec", new(concurrentDBExecTest)}, + {"StmtQuery", new(concurrentStmtQueryTest)}, + {"StmtExec", new(concurrentStmtExecTest)}, + {"TxQuery", new(concurrentTxQueryTest)}, + {"TxExec", new(concurrentTxExecTest)}, + {"TxStmtQuery", new(concurrentTxStmtQueryTest)}, + {"TxStmtExec", new(concurrentTxStmtExecTest)}, + {"Random", new(concurrentRandomTest)}, + } + for _, item := range list { + t.Run(item.name, func(t *testing.T) { + doConcurrentTest(t, item.ct) + }) + } } func TestConnectionLeak(t *testing.T) { @@ -2874,6 +3341,131 @@ func TestConnectionLeak(t *testing.T) { wg.Wait() } +type nvcDriver struct { + fakeDriver + skipNamedValueCheck bool +} + +func (d *nvcDriver) Open(dsn string) (driver.Conn, error) { + c, err := d.fakeDriver.Open(dsn) + fc := c.(*fakeConn) + fc.db.allowAny = true + return &nvcConn{fc, d.skipNamedValueCheck}, err +} + +type nvcConn struct { + *fakeConn + skipNamedValueCheck bool +} + +type decimal struct { + value int +} + +type doNotInclude struct{} + +var _ driver.NamedValueChecker = &nvcConn{} + +func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error { + if c.skipNamedValueCheck { + return driver.ErrSkip + } + switch v := nv.Value.(type) { + default: + return driver.ErrSkip + case Out: + switch ov := v.Dest.(type) { + default: + return errors.New("unkown NameValueCheck OUTPUT type") + case *string: + *ov = "from-server" + nv.Value = "OUT:*string" + } + return nil + case decimal, []int64: + return nil + case doNotInclude: + return driver.ErrRemoveArgument + } +} + +func TestNamedValueChecker(t *testing.T) { + Register("NamedValueCheck", &nvcDriver{}) + db, err := Open("NamedValueCheck", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = db.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any") + if err != nil { + t.Fatal("exec create", err) + } + + o1 := "" + _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimal{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{}) + if err != nil { + t.Fatal("exec insert", err) + } + var ( + str1 string + dec1 decimal + arr1 []int64 + ) + err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1) + if err != nil { + t.Fatal("select", err) + } + + list := []struct{ got, want interface{} }{ + {o1, "from-server"}, + {dec1, decimal{123}}, + {str1, "hello"}, + {arr1, []int64{42, 128, 707}}, + } + + for index, item := range list { + if !reflect.DeepEqual(item.got, item.want) { + t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index) + } + } +} + +func TestNamedValueCheckerSkip(t *testing.T) { + Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true}) + db, err := Open("NamedValueCheckSkip", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = db.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any") + if err != nil { + t.Fatal("exec create", err) + } + + _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimal{123})) + if err == nil { + t.Fatalf("expected error with bad argument, got %v", err) + } +} + // badConn implements a bad driver.Conn, for TestBadDriver. // The Exec method panics. type badConn struct{} @@ -2965,6 +3557,24 @@ func TestPing(t *testing.T) { } } +// Issue 18101. +func TestTypedString(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + type Str string + var scanned Str + + err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned) + if err != nil { + t.Fatal(err) + } + expected := Str("Alice") + if scanned != expected { + t.Errorf("expected %+v, got %+v", expected, scanned) + } +} + func BenchmarkConcurrentDBExec(b *testing.B) { b.ReportAllocs() ct := new(concurrentDBExecTest) |