diff options
author | Ian Lance Taylor <iant@golang.org> | 2017-01-14 00:05:42 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2017-01-14 00:05:42 +0000 |
commit | c2047754c300b68c05d65faa8dc2925fe67b71b4 (patch) | |
tree | e183ae81a1f48a02945cb6de463a70c5be1b06f6 /libgo/go/database/sql/sql.go | |
parent | 829afb8f05602bb31c9c597b24df7377fed4f059 (diff) | |
download | gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.zip gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.tar.gz gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.tar.bz2 |
libgo: update to Go 1.8 release candidate 1
Compiler changes:
* Change map assignment to use mapassign and assign value directly.
* Change string iteration to use decoderune, faster for ASCII strings.
* Change makeslice to take int, and use makeslice64 for larger values.
* Add new noverflow field to hmap struct used for maps.
Unresolved problems, to be fixed later:
* Commented out test in go/types/sizes_test.go that doesn't compile.
* Commented out reflect.TestStructOf test for padding after zero-sized field.
Reviewed-on: https://go-review.googlesource.com/35231
gotools/:
Updates for Go 1.8rc1.
* Makefile.am (go_cmd_go_files): Add bug.go.
(s-zdefaultcc): Write defaultPkgConfig.
* Makefile.in: Rebuild.
From-SVN: r244456
Diffstat (limited to 'libgo/go/database/sql/sql.go')
-rw-r--r-- | libgo/go/database/sql/sql.go | 916 |
1 files changed, 691 insertions, 225 deletions
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go index 09de1c3..0fa7c34 100644 --- a/libgo/go/database/sql/sql.go +++ b/libgo/go/database/sql/sql.go @@ -8,15 +8,20 @@ // The sql package must be used in conjunction with a database driver. // See https://golang.org/s/sqldrivers for a list of drivers. // -// For more usage examples, see the wiki page at +// Drivers that do not support context cancelation will not return until +// after the query is completed. +// +// For usage examples, see the wiki page at // https://golang.org/s/sqlwiki. package sql import ( + "context" "database/sql/driver" "errors" "fmt" "io" + "reflect" "runtime" "sort" "sync" @@ -66,6 +71,75 @@ func Drivers() []string { return list } +// A NamedArg is a named argument. NamedArg values may be used as +// arguments to Query or Exec and bind to the corresponding named +// parameter in the SQL statement. +// +// For a more concise way to create NamedArg values, see +// the Named function. +type NamedArg struct { + _Named_Fields_Required struct{} + + // Name is the name of the parameter placeholder. + // + // If empty, the ordinal position in the argument list will be + // used. + // + // Name must omit any symbol prefix. + Name string + + // Value is the value of the parameter. + // It may be assigned the same value types as the query + // arguments. + Value interface{} +} + +// Named provides a more concise way to create NamedArg values. +// +// Example usage: +// +// db.ExecContext(ctx, ` +// delete from Invoice +// where +// TimeCreated < @end +// and TimeCreated >= @start;`, +// sql.Named("start", startTime), +// sql.Named("end", endTime), +// ) +func Named(name string, value interface{}) NamedArg { + // This method exists because the go1compat promise + // doesn't guarantee that structs don't grow more fields, + // so unkeyed struct literals are a vet error. Thus, we don't + // want to allow sql.NamedArg{name, value}. + return NamedArg{Name: name, Value: value} +} + +// IsolationLevel is the transaction isolation level used in TxOptions. +type IsolationLevel int + +// Various isolation levels that drivers may support in BeginTx. +// If a driver does not support a given isolation level an error may be returned. +// +// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels. +const ( + LevelDefault IsolationLevel = iota + LevelReadUncommitted + LevelReadCommitted + LevelWriteCommitted + LevelRepeatableRead + LevelSnapshot + LevelSerializable + LevelLinearizable +) + +// TxOptions holds the transaction options to be used in DB.BeginTx. +type TxOptions struct { + // Isolation is the transaction isolation level. + // If zero, the driver or database's default level is used. + Isolation IsolationLevel + ReadOnly bool +} + // RawBytes is a byte slice that holds a reference to memory owned by // the database itself. After a Scan into a RawBytes, the slice is only // valid until the next call to Next, Scan, or Close. @@ -272,7 +346,7 @@ type driverConn struct { ci driver.Conn closed bool finalClosed bool // ci.Close has been called - openStmt map[driver.Stmt]bool + openStmt map[*driverStmt]bool // guarded by db.mu inUse bool @@ -284,10 +358,10 @@ func (dc *driverConn) releaseConn(err error) { dc.db.putConn(dc, err) } -func (dc *driverConn) removeOpenStmt(si driver.Stmt) { +func (dc *driverConn) removeOpenStmt(ds *driverStmt) { dc.Lock() defer dc.Unlock() - delete(dc.openStmt, si) + delete(dc.openStmt, ds) } func (dc *driverConn) expired(timeout time.Duration) bool { @@ -297,28 +371,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { - si, err := dc.ci.Prepare(query) - if err == nil { - // Track each driverConn's open statements, so we can close them - // before closing the conn. - // - // TODO(bradfitz): let drivers opt out of caring about - // stmt closes if the conn is about to close anyway? For now - // do the safe thing, in case stmts need to be closed. - // - // TODO(bradfitz): after Go 1.2, closing driver.Stmts - // should be moved to driverStmt, using unique - // *driverStmts everywhere (including from - // *Stmt.connStmt, instead of returning a - // driver.Stmt), using driverStmt as a pointer - // everywhere, and making it a finalCloser. - if dc.openStmt == nil { - dc.openStmt = make(map[driver.Stmt]bool) - } - dc.openStmt[si] = true +func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) { + si, err := ctxDriverPrepare(ctx, dc.ci, query) + if err != nil { + return nil, err } - return si, err + + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once. + if dc.openStmt == nil { + dc.openStmt = make(map[*driverStmt]bool) + } + ds := &driverStmt{Locker: dc, si: si} + dc.openStmt[ds] = true + + return ds, nil } // the dc.db's Mutex is held. @@ -350,17 +419,26 @@ func (dc *driverConn) Close() error { } func (dc *driverConn) finalClose() error { - dc.Lock() + var err error - for si := range dc.openStmt { - si.Close() + // Each *driverStmt has a lock to the dc. Copy the list out of the dc + // before calling close on each stmt. + var openStmt []*driverStmt + withLock(dc, func() { + openStmt = make([]*driverStmt, 0, len(dc.openStmt)) + for ds := range dc.openStmt { + openStmt = append(openStmt, ds) + } + dc.openStmt = nil + }) + for _, ds := range openStmt { + ds.Close() } - dc.openStmt = nil - - err := dc.ci.Close() - dc.ci = nil - dc.finalClosed = true - dc.Unlock() + withLock(dc, func() { + dc.finalClosed = true + err = dc.ci.Close() + dc.ci = nil + }) dc.db.mu.Lock() dc.db.numOpen-- @@ -377,12 +455,21 @@ func (dc *driverConn) finalClose() error { type driverStmt struct { sync.Locker // the *driverConn si driver.Stmt + closed bool + closeErr error // return value of previous Close call } +// Close ensures dirver.Stmt is only closed once any always returns the same +// result. func (ds *driverStmt) Close() error { ds.Lock() defer ds.Unlock() - return ds.si.Close() + if ds.closed { + return ds.closeErr + } + ds.closed = true + ds.closeErr = ds.si.Close() + return ds.closeErr } // depSet is a finalCloser's outstanding dependencies @@ -494,18 +581,36 @@ func Open(driverName, dataSourceName string) (*DB, error) { return db, nil } -// Ping verifies a connection to the database is still alive, +// PingContext verifies a connection to the database is still alive, // establishing a connection if necessary. -func (db *DB) Ping() error { - // TODO(bradfitz): give drivers an optional hook to implement - // this in a more efficient or more reliable way, if they - // have one. - dc, err := db.conn(cachedOrNewConn) +func (db *DB) PingContext(ctx context.Context) error { + var dc *driverConn + var err error + + for i := 0; i < maxBadConnRetries; i++ { + dc, err = db.conn(ctx, cachedOrNewConn) + if err != driver.ErrBadConn { + break + } + } + if err == driver.ErrBadConn { + dc, err = db.conn(ctx, alwaysNewConn) + } if err != nil { return err } - db.putConn(dc, nil) - return nil + + if pinger, ok := dc.ci.(driver.Pinger); ok { + err = pinger.Ping(ctx) + } + db.putConn(dc, err) + return err +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) Ping() error { + return db.PingContext(context.Background()) } // Close closes the database, releasing any open resources. @@ -777,12 +882,19 @@ type connRequest struct { var errDBClosed = errors.New("sql: database is closed") // conn returns a newly-opened or cached *driverConn. -func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { +func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() return nil, errDBClosed } + // Check if the context is expired. + select { + default: + case <-ctx.Done(): + db.mu.Unlock() + return nil, ctx.Err() + } lifetime := db.maxLifetime // Prefer a free connection, if possible. @@ -808,15 +920,21 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { req := make(chan connRequest, 1) db.connRequests = append(db.connRequests, req) db.mu.Unlock() - ret, ok := <-req - if !ok { - return nil, errDBClosed - } - if ret.err == nil && ret.conn.expired(lifetime) { - ret.conn.Close() - return nil, driver.ErrBadConn + + // Timeout the connection request with the context. + select { + case <-ctx.Done(): + return nil, ctx.Err() + case ret, ok := <-req: + if !ok { + return nil, errDBClosed + } + if ret.err == nil && ret.conn.expired(lifetime) { + ret.conn.Close() + return nil, driver.ErrBadConn + } + return ret.conn, ret.err } - return ret.conn, ret.err } db.numOpen++ // optimistically @@ -844,21 +962,22 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { // putConnHook is a hook for testing. var putConnHook func(*DB, *driverConn) -// noteUnusedDriverStatement notes that si is no longer used and should +// noteUnusedDriverStatement notes that ds is no longer used and should // be closed whenever possible (when c is next not in use), unless c is // already closed. -func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { +func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) { db.mu.Lock() defer db.mu.Unlock() if c.inUse { c.onPut = append(c.onPut, func() { - si.Close() + ds.Close() }) } else { c.Lock() - defer c.Unlock() - if !c.finalClosed { - si.Close() + fc := c.finalClosed + c.Unlock() + if !fc { + ds.Close() } } } @@ -952,40 +1071,53 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { // connection to be opened. const maxBadConnRetries = 2 -// Prepare creates a prepared statement for later queries or executions. +// PrepareContext creates a prepared statement for later queries or executions. // Multiple queries or executions may be run concurrently from the // returned statement. // The caller must call the statement's Close method // when the statement is no longer needed. -func (db *DB) Prepare(query string) (*Stmt, error) { +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt var err error for i := 0; i < maxBadConnRetries; i++ { - stmt, err = db.prepare(query, cachedOrNewConn) + stmt, err = db.prepare(ctx, query, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.prepare(query, alwaysNewConn) + return db.prepare(ctx, query, alwaysNewConn) } return stmt, err } -func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) { // TODO: check if db.driver supports an optional // driver.Preparer interface and call that instead, if so, // otherwise we make a prepared statement that's bound // to a connection, and to execute this prepared statement // we either need to use this connection (if it's free), else // get a new connection + re-prepare + execute on that one. - dc, err := db.conn(strategy) + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } - dc.Lock() - si, err := dc.prepareLocked(query) - dc.Unlock() + var ds *driverStmt + withLock(dc, func() { + ds, err = dc.prepareLocked(ctx, query) + }) if err != nil { db.putConn(dc, err) return nil, err @@ -993,7 +1125,7 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { stmt := &Stmt{ db: db, query: query, - css: []connStmt{{dc, si}}, + css: []connStmt{{dc, ds}}, lastNumClosed: atomic.LoadUint64(&db.numClosed), } db.addDep(stmt, stmt) @@ -1001,25 +1133,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { return stmt, nil } -// Exec executes a query without returning any rows. +// ExecContext executes a query without returning any rows. // The args are for any placeholder parameters in the query. -func (db *DB) Exec(query string, args ...interface{}) (Result, error) { +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { var res Result var err error for i := 0; i < maxBadConnRetries; i++ { - res, err = db.exec(query, args, cachedOrNewConn) + res, err = db.exec(ctx, query, args, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.exec(query, args, alwaysNewConn) + return db.exec(ctx, query, args, alwaysNewConn) } return res, err } -func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { - dc, err := db.conn(strategy) +// Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (db *DB) Exec(query string, args ...interface{}) (Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } @@ -1028,13 +1166,15 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) }() if execer, ok := dc.ci.(driver.Execer); ok { - dargs, err := driverArgs(nil, args) + var dargs []driver.NamedValue + dargs, err = driverArgs(nil, args) if err != nil { return nil, err } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() + var resi driver.Result + withLock(dc, func() { + resi, err = ctxDriverExec(ctx, execer, query, dargs) + }) if err != driver.ErrSkip { if err != nil { return nil, err @@ -1043,54 +1183,63 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) - return resultFromStatement(driverStmt{dc, si}, args...) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() + return resultFromStatement(ctx, ds, args...) } -// Query executes a query that returns rows, typically a SELECT. +// QueryContext executes a query that returns rows, typically a SELECT. // The args are for any placeholder parameters in the query. -func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { var rows *Rows var err error for i := 0; i < maxBadConnRetries; i++ { - rows, err = db.query(query, args, cachedOrNewConn) + rows, err = db.query(ctx, query, args, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.query(query, args, alwaysNewConn) + return db.query(ctx, query, args, alwaysNewConn) } return rows, err } -func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { - ci, err := db.conn(strategy) +// Query executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { + ci, err := db.conn(ctx, strategy) if err != nil { return nil, err } - return db.queryConn(ci, ci.releaseConn, query, args) + return db.queryConn(ctx, ci, ci.releaseConn, query, args) } // queryConn executes a query on the given connection. // The connection gets released by the releaseConn function. -func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { +func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { if queryer, ok := dc.ci.(driver.Queryer); ok { dargs, err := driverArgs(nil, args) if err != nil { releaseConn(err) return nil, err } - dc.Lock() - rowsi, err := queryer.Query(query, dargs) - dc.Unlock() + var rowsi driver.Rows + withLock(dc, func() { + rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs) + }) if err != driver.ErrSkip { if err != nil { releaseConn(err) @@ -1103,24 +1252,25 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a releaseConn: releaseConn, rowsi: rowsi, } + rows.initContextClose(ctx) return rows, nil } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + var err error + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { releaseConn(err) return nil, err } - ds := driverStmt{dc, si} - rowsi, err := rowsiFromStatement(ds, args...) + ds := &driverStmt{Locker: dc, si: si} + rowsi, err := rowsiFromStatement(ctx, ds, args...) if err != nil { - dc.Lock() - si.Close() - dc.Unlock() + ds.Close() releaseConn(err) return nil, err } @@ -1131,53 +1281,93 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a dc: dc, releaseConn: releaseConn, rowsi: rowsi, - closeStmt: si, + closeStmt: ds, } + rows.initContextClose(ctx) return rows, nil } +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (db *DB) QueryRow(query string, args ...interface{}) *Row { - rows, err := db.Query(query, args...) - return &Row{rows: rows, err: err} + return db.QueryRowContext(context.Background(), query, args...) } -// Begin starts a transaction. The isolation level is dependent on -// the driver. -func (db *DB) Begin() (*Tx, error) { +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { var tx *Tx var err error for i := 0; i < maxBadConnRetries; i++ { - tx, err = db.begin(cachedOrNewConn) + tx, err = db.begin(ctx, opts, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.begin(alwaysNewConn) + return db.begin(ctx, opts, alwaysNewConn) } return tx, err } -func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) { - dc, err := db.conn(strategy) +// Begin starts a transaction. The default isolation level is dependent on +// the driver. +func (db *DB) Begin() (*Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) { + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } - dc.Lock() - txi, err := dc.ci.Begin() - dc.Unlock() + var txi driver.Tx + withLock(dc, func() { + txi, err = ctxDriverBegin(ctx, opts, dc.ci) + }) if err != nil { db.putConn(dc, err) return nil, err } - return &Tx{ - db: db, - dc: dc, - txi: txi, - }, nil + + // Schedule the transaction to rollback when the context is cancelled. + // The cancel function in Tx will be called after done is set to true. + ctx, cancel := context.WithCancel(ctx) + tx = &Tx{ + db: db, + dc: dc, + txi: txi, + cancel: cancel, + ctx: ctx, + } + go func(tx *Tx) { + select { + case <-tx.ctx.Done(): + if !tx.isDone() { + // Discard and close the connection used to ensure the transaction + // is closed and the resources are released. + tx.rollback(true) + } + } + }(tx) + return tx, nil } // Driver returns the database's underlying driver. @@ -1203,10 +1393,11 @@ type Tx struct { dc *driverConn txi driver.Tx - // done transitions from false to true exactly once, on Commit + // done transitions from 0 to 1 exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. - done bool + // Use atomic operations on value when checking value. + done int32 // All Stmts prepared for this transaction. These will be closed after the // transaction has been committed or rolled back. @@ -1214,22 +1405,33 @@ type Tx struct { sync.Mutex v []*Stmt } + + // cancel is called after done transitions from false to true. + cancel func() + + // ctx lives for the life of the transaction. + ctx context.Context +} + +func (tx *Tx) isDone() bool { + return atomic.LoadInt32(&tx.done) != 0 } +// ErrTxDone is returned by any operation that is performed on a transaction +// that has already been committed or rolled back. var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") +// close returns the connection to the pool and +// must only be called by Tx.rollback or Tx.Commit. func (tx *Tx) close(err error) { - if tx.done { - panic("double close") // internal error - } - tx.done = true tx.db.putConn(tx.dc, err) + tx.cancel() tx.dc = nil tx.txi = nil } -func (tx *Tx) grabConn() (*driverConn, error) { - if tx.done { +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { + if tx.isDone() { return nil, ErrTxDone } return tx.dc, nil @@ -1238,20 +1440,26 @@ func (tx *Tx) grabConn() (*driverConn, error) { // Closes all Stmts prepared for this transaction. func (tx *Tx) closePrepared() { tx.stmts.Lock() + defer tx.stmts.Unlock() for _, stmt := range tx.stmts.v { stmt.Close() } - tx.stmts.Unlock() } // Commit commits the transaction. func (tx *Tx) Commit() error { - if tx.done { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } - tx.dc.Lock() - err := tx.txi.Commit() - tx.dc.Unlock() + select { + default: + case <-tx.ctx.Done(): + return tx.ctx.Err() + } + var err error + withLock(tx.dc, func() { + err = tx.txi.Commit() + }) if err != driver.ErrBadConn { tx.closePrepared() } @@ -1259,28 +1467,42 @@ func (tx *Tx) Commit() error { return err } -// Rollback aborts the transaction. -func (tx *Tx) Rollback() error { - if tx.done { +// rollback aborts the transaction and optionally forces the pool to discard +// the connection. +func (tx *Tx) rollback(discardConn bool) error { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } - tx.dc.Lock() - err := tx.txi.Rollback() - tx.dc.Unlock() + var err error + withLock(tx.dc, func() { + err = tx.txi.Rollback() + }) if err != driver.ErrBadConn { tx.closePrepared() } + if discardConn { + err = driver.ErrBadConn + } tx.close(err) return err } +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + return tx.rollback(false) +} + // Prepare creates a prepared statement for use within a transaction. // -// The returned statement operates within the transaction and can no longer -// be used once the transaction has been committed or rolled back. +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. // // To use an existing prepared statement on this transaction, see Tx.Stmt. -func (tx *Tx) Prepare(query string) (*Stmt, error) { +// +// The provided context will be used for the preparation of the context, not +// for the execution of the returned statement. The returned statement +// will run in the transaction context. +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { // TODO(bradfitz): We could be more efficient here and either // provide a method to take an existing Stmt (created on // perhaps a different Conn), and re-create it on this Conn if @@ -1294,14 +1516,15 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // Perhaps just looking at the reference count (by noting // Stmt.Close) would be enough. We might also want a finalizer // on Stmt to drop the reference count. - dc, err := tx.grabConn() + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { return nil, err } @@ -1309,7 +1532,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { stmt := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1321,7 +1544,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { return stmt, nil } -// Stmt returns a transaction-specific prepared statement from +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +// StmtContext returns a transaction-specific prepared statement from // an existing statement. // // Example: @@ -1329,11 +1562,11 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // ... // tx, err := db.Begin() // ... -// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) // -// The returned statement operates within the transaction and can no longer -// be used once the transaction has been committed or rolled back. -func (tx *Tx) Stmt(stmt *Stmt) *Stmt { +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { // TODO(bradfitz): optimize this. Currently this re-prepares // each time. This is fine for now to illustrate the API but // we should really cache already-prepared statements @@ -1342,17 +1575,18 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - dc, err := tx.grabConn() + dc, err := tx.grabConn(ctx) if err != nil { return &Stmt{stickyErr: err} } - dc.Lock() - si, err := dc.ci.Prepare(stmt.query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) txs := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1365,10 +1599,26 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { return txs } -// Exec executes a query that doesn't return rows. +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +// ExecContext executes a query that doesn't return rows. // For example: an INSERT and UPDATE. -func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - dc, err := tx.grabConn() +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } @@ -1378,9 +1628,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { if err != nil { return nil, err } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() + var resi driver.Result + withLock(dc, func() { + resi, err = ctxDriverExec(ctx, execer, query, dargs) + }) if err == nil { return driverResult{dc, resi}, nil } @@ -1389,39 +1640,59 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() - return resultFromStatement(driverStmt{dc, si}, args...) + return resultFromStatement(ctx, ds, args...) } -// Query executes a query that returns rows, typically a SELECT. -func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - dc, err := tx.grabConn() +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { + return tx.ExecContext(context.Background(), query, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } releaseConn := func(error) {} - return tx.db.queryConn(dc, releaseConn, query, args) + return tx.db.queryConn(ctx, dc, releaseConn, query, args) +} + +// Query executes a query that returns rows, typically a SELECT. +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} } // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - rows, err := tx.Query(query, args...) - return &Row{rows: rows, err: err} + return tx.QueryRowContext(context.Background(), query, args...) } // connStmt is a prepared statement on a particular connection. type connStmt struct { dc *driverConn - si driver.Stmt + ds *driverStmt } // Stmt is a prepared statement. @@ -1436,7 +1707,7 @@ type Stmt struct { // If in a transaction, else both nil: tx *Tx - txsi *driverStmt + txds *driverStmt mu sync.Mutex // protects the rest of the fields closed bool @@ -1452,15 +1723,15 @@ type Stmt struct { lastNumClosed uint64 } -// Exec executes a prepared statement with the given arguments and +// ExecContext executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. -func (s *Stmt) Exec(args ...interface{}) (Result, error) { +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) { s.closemu.RLock() defer s.closemu.RUnlock() var res Result for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt() + _, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1468,7 +1739,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, err } - res, err = resultFromStatement(driverStmt{dc, si}, args...) + res, err = resultFromStatement(ctx, ds, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -1477,13 +1748,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, driver.ErrBadConn } -func driverNumInput(ds driverStmt) int { +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) Exec(args ...interface{}) (Result, error) { + return s.ExecContext(context.Background(), args...) +} + +func driverNumInput(ds *driverStmt) int { ds.Lock() defer ds.Unlock() // in case NumInput panics return ds.si.NumInput() } -func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { +func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) { want := driverNumInput(ds) // -1 means the driver doesn't know how to count the number of @@ -1493,14 +1770,15 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } ds.Lock() defer ds.Unlock() - resi, err := ds.si.Exec(dargs) + + resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) if err != nil { return nil, err } @@ -1536,7 +1814,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) { if err = s.stickyErr; err != nil { return } @@ -1551,19 +1829,18 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St // transaction was created on. if s.tx != nil { s.mu.Unlock() - ci, err = s.tx.grabConn() // blocks, waiting for the connection. + ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection. if err != nil { return } releaseConn = func(error) {} - return ci, releaseConn, s.txsi.si, nil + return ci, releaseConn, s.txds, nil } s.removeClosedStmtLocked() s.mu.Unlock() - // TODO(bradfitz): or always wait for one? make configurable later? - dc, err := s.db.conn(cachedOrNewConn) + dc, err := s.db.conn(ctx, cachedOrNewConn) if err != nil { return nil, nil, nil, err } @@ -1572,36 +1849,36 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St for _, v := range s.css { if v.dc == dc { s.mu.Unlock() - return dc, dc.releaseConn, v.si, nil + return dc, dc.releaseConn, v.ds, nil } } s.mu.Unlock() // No luck; we need to prepare the statement on this connection - dc.Lock() - si, err = dc.prepareLocked(s.query) - dc.Unlock() + withLock(dc, func() { + ds, err = dc.prepareLocked(ctx, s.query) + }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err } s.mu.Lock() - cs := connStmt{dc, si} + cs := connStmt{dc, ds} s.css = append(s.css, cs) s.mu.Unlock() - return dc, dc.releaseConn, si, nil + return dc, dc.releaseConn, ds, nil } -// Query executes a prepared query statement with the given arguments +// QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. -func (s *Stmt) Query(args ...interface{}) (*Rows, error) { +func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { s.closemu.RLock() defer s.closemu.RUnlock() var rowsi driver.Rows for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt() + dc, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1609,7 +1886,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return nil, err } - rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...) + rowsi, err = rowsiFromStatement(ctx, ds, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -1618,6 +1895,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { rowsi: rowsi, // releaseConn set below } + rows.initContextClose(ctx) s.db.addDep(s, rows) rows.releaseConn = func(err error) { releaseConn(err) @@ -1634,10 +1912,17 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return nil, driver.ErrBadConn } -func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { - ds.Lock() - want := ds.si.NumInput() - ds.Unlock() +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { + var want int + withLock(ds, func() { + want = ds.si.NumInput() + }) // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the @@ -1646,21 +1931,22 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } ds.Lock() - rowsi, err := ds.si.Query(dargs) - ds.Unlock() + defer ds.Unlock() + + rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) if err != nil { return nil, err } return rowsi, nil } -// QueryRow executes a prepared query statement with the given arguments. +// QueryRowContext executes a prepared query statement with the given arguments. // If an error occurs during the execution of the statement, that error will // be returned by a call to Scan on the returned *Row, which is always non-nil. // If the query selects no rows, the *Row's Scan will return ErrNoRows. @@ -1670,15 +1956,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) // Example usage: // // var name string -// err := nameByUseridStmt.QueryRow(id).Scan(&name) -func (s *Stmt) QueryRow(args ...interface{}) *Row { - rows, err := s.Query(args...) +// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name) +func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { + rows, err := s.QueryContext(ctx, args...) if err != nil { return &Row{err: err} } return &Row{rows: rows} } +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +func (s *Stmt) QueryRow(args ...interface{}) *Row { + return s.QueryRowContext(context.Background(), args...) +} + // Close closes the statement. func (s *Stmt) Close() error { s.closemu.Lock() @@ -1693,13 +1994,11 @@ func (s *Stmt) Close() error { return nil } s.closed = true + s.mu.Unlock() if s.tx != nil { - err := s.txsi.Close() - s.mu.Unlock() - return err + return s.txds.Close() } - s.mu.Unlock() return s.db.removeDep(s, s) } @@ -1709,8 +2008,8 @@ func (s *Stmt) finalClose() error { defer s.mu.Unlock() if s.css != nil { for _, v := range s.css { - s.db.noteUnusedDriverStatement(v.dc, v.si) - v.dc.removeOpenStmt(v.si) + s.db.noteUnusedDriverStatement(v.dc, v.ds) + v.dc.removeOpenStmt(v.ds) } s.css = nil } @@ -1736,10 +2035,28 @@ type Rows struct { releaseConn func(error) rowsi driver.Rows - closed bool + // closed value is 1 when the Rows is closed. + // Use atomic operations on value when checking value. + closed int32 + ctxClose chan struct{} // closed when Rows is closed, may be null. lastcols []driver.Value lasterr error // non-nil only if closed is true - closeStmt driver.Stmt // if non-nil, statement to Close on close + closeStmt *driverStmt // if non-nil, statement to Close on close +} + +func (rs *Rows) initContextClose(ctx context.Context) { + if ctx.Done() == context.Background().Done() { + return + } + + rs.ctxClose = make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + rs.Close() + case <-rs.ctxClose: + } + }() } // Next prepares the next result row for reading with the Scan method. It @@ -1749,7 +2066,7 @@ type Rows struct { // // Every call to Scan, even the first one, must be preceded by a call to Next. func (rs *Rows) Next() bool { - if rs.closed { + if rs.isClosed() { return false } if rs.lastcols == nil { @@ -1757,6 +2074,47 @@ func (rs *Rows) Next() bool { } rs.lasterr = rs.rowsi.Next(rs.lastcols) if rs.lasterr != nil { + // Close the connection if there is a driver error. + if rs.lasterr != io.EOF { + rs.Close() + return false + } + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + rs.Close() + return false + } + // The driver is at the end of the current result set. + // Test to see if there is another result set after the current one. + // Only close Rows if there is no further result sets to read. + if !nextResultSet.HasNextResultSet() { + rs.Close() + } + return false + } + return true +} + +// NextResultSet prepares the next result set for reading. It returns true if +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it. The Err method should be consulted +// to distinguish between the two cases. +// +// After calling NextResultSet, the Next method should always be called before +// scanning. If there are further result sets they may not have rows in the result +// set. +func (rs *Rows) NextResultSet() bool { + if rs.isClosed() { + return false + } + rs.lastcols = nil + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + rs.Close() + return false + } + rs.lasterr = nextResultSet.NextResultSet() + if rs.lasterr != nil { rs.Close() return false } @@ -1776,7 +2134,7 @@ func (rs *Rows) Err() error { // Columns returns an error if the rows are closed, or if the rows // are from QueryRow and there was a deferred error. func (rs *Rows) Columns() ([]string, error) { - if rs.closed { + if rs.isClosed() { return nil, errors.New("sql: Rows are closed") } if rs.rowsi == nil { @@ -1785,6 +2143,107 @@ func (rs *Rows) Columns() ([]string, error) { return rs.rowsi.Columns(), nil } +// ColumnTypes returns column information such as column type, length, +// and nullable. Some information may not be available from some drivers. +func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { + if rs.isClosed() { + return nil, errors.New("sql: Rows are closed") + } + if rs.rowsi == nil { + return nil, errors.New("sql: no Rows available") + } + return rowsColumnInfoSetup(rs.rowsi), nil +} + +// ColumnType contains the name and type of a column. +type ColumnType struct { + name string + + hasNullable bool + hasLength bool + hasPrecisionScale bool + + nullable bool + length int64 + databaseType string + precision int64 + scale int64 + scanType reflect.Type +} + +// Name returns the name or alias of the column. +func (ci *ColumnType) Name() string { + return ci.name +} + +// Length returns the column type length for variable length column types such +// as text and binary field types. If the type length is unbounded the value will +// be math.MaxInt64 (any database limits will still apply). +// If the column type is not variable length, such as an int, or if not supported +// by the driver ok is false. +func (ci *ColumnType) Length() (length int64, ok bool) { + return ci.length, ci.hasLength +} + +// DecimalSize returns the scale and precision of a decimal type. +// If not applicable or if not supported ok is false. +func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) { + return ci.precision, ci.scale, ci.hasPrecisionScale +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +// If a driver does not support this property ScanType will return +// the type of an empty interface. +func (ci *ColumnType) ScanType() reflect.Type { + return ci.scanType +} + +// Nullable returns whether the column may be null. +// If a driver does not support this property ok will be false. +func (ci *ColumnType) Nullable() (nullable, ok bool) { + return ci.nullable, ci.hasNullable +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT". +func (ci *ColumnType) DatabaseTypeName() string { + return ci.databaseType +} + +func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { + names := rowsi.Columns() + + list := make([]*ColumnType, len(names)) + for i := range list { + ci := &ColumnType{ + name: names[i], + } + list[i] = ci + + if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok { + ci.scanType = prop.ColumnTypeScanType(i) + } else { + ci.scanType = reflect.TypeOf(new(interface{})).Elem() + } + if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok { + ci.databaseType = prop.ColumnTypeDatabaseTypeName(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok { + ci.length, ci.hasLength = prop.ColumnTypeLength(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok { + ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok { + ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i) + } + } + return list +} + // Scan copies the columns in the current row into the values pointed // at by dest. The number of values in dest must be the same as the // number of columns in Rows. @@ -1837,7 +2296,7 @@ func (rs *Rows) Columns() ([]string, error) { // For scanning into *bool, the source may be true, false, 1, 0, or // string inputs parseable by strconv.ParseBool. func (rs *Rows) Scan(dest ...interface{}) error { - if rs.closed { + if rs.isClosed() { return errors.New("sql: Rows are closed") } if rs.lastcols == nil { @@ -1857,14 +2316,21 @@ func (rs *Rows) Scan(dest ...interface{}) error { var rowsCloseHook func(*Rows, *error) -// Close closes the Rows, preventing further enumeration. If Next returns -// false, the Rows are closed automatically and it will suffice to check the +func (rs *Rows) isClosed() bool { + return atomic.LoadInt32(&rs.closed) != 0 +} + +// Close closes the Rows, preventing further enumeration. If Next is called +// and returns false and there are no further result sets, +// the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { - if rs.closed { + if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) { return nil } - rs.closed = true + if rs.ctxClose != nil { + close(rs.ctxClose) + } err := rs.rowsi.Close() if fn := rowsCloseHook; fn != nil { fn(rs, &err) |