diff options
Diffstat (limited to 'libgo/go/exp/sql')
-rw-r--r-- | libgo/go/exp/sql/convert_test.go | 13 | ||||
-rw-r--r-- | libgo/go/exp/sql/driver/driver.go | 1 | ||||
-rw-r--r-- | libgo/go/exp/sql/driver/types.go | 9 | ||||
-rw-r--r-- | libgo/go/exp/sql/driver/types_test.go | 4 | ||||
-rw-r--r-- | libgo/go/exp/sql/fakedb_test.go | 72 | ||||
-rw-r--r-- | libgo/go/exp/sql/sql.go | 58 | ||||
-rw-r--r-- | libgo/go/exp/sql/sql_test.go | 55 |
7 files changed, 182 insertions, 30 deletions
diff --git a/libgo/go/exp/sql/convert_test.go b/libgo/go/exp/sql/convert_test.go index bed09ff..702ba43 100644 --- a/libgo/go/exp/sql/convert_test.go +++ b/libgo/go/exp/sql/convert_test.go @@ -8,8 +8,11 @@ import ( "fmt" "reflect" "testing" + "time" ) +var someTime = time.Unix(123, 0) + type conversionTest struct { s, d interface{} // source and destination @@ -19,6 +22,7 @@ type conversionTest struct { wantstr string wantf32 float32 wantf64 float64 + wanttime time.Time wantbool bool // used if d is of type *bool wanterr string } @@ -35,12 +39,14 @@ var ( scanbool bool scanf32 float32 scanf64 float64 + scantime time.Time ) var conversionTests = []conversionTest{ // Exact conversions (destination pointer type matches source type) {s: "foo", d: &scanstr, wantstr: "foo"}, {s: 123, d: &scanint, wantint: 123}, + {s: someTime, d: &scantime, wanttime: someTime}, // To strings {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, @@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 { return *(ptr.(*float32)) } +func timeValue(ptr interface{}) time.Time { + return *(ptr.(*time.Time)) +} + func TestConversions(t *testing.T) { for n, ct := range conversionTests { err := convertAssign(ct.d, ct.s) @@ -138,6 +148,9 @@ func TestConversions(t *testing.T) { if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { errf("want bool %v, got %v", ct.wantbool, *bp) } + if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { + errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) + } } } diff --git a/libgo/go/exp/sql/driver/driver.go b/libgo/go/exp/sql/driver/driver.go index f0bcca2..0cd2562 100644 --- a/libgo/go/exp/sql/driver/driver.go +++ b/libgo/go/exp/sql/driver/driver.go @@ -16,6 +16,7 @@ // nil // []byte // string [*] everywhere except from Rows.Next. +// time.Time // package driver diff --git a/libgo/go/exp/sql/driver/types.go b/libgo/go/exp/sql/driver/types.go index 086b529..d6ba641 100644 --- a/libgo/go/exp/sql/driver/types.go +++ b/libgo/go/exp/sql/driver/types.go @@ -8,6 +8,7 @@ import ( "fmt" "reflect" "strconv" + "time" ) // ValueConverter is the interface providing the ConvertValue method. @@ -39,7 +40,7 @@ type ValueConverter interface { // 1 is true // 0 is false, // other integers are an error -// - for strings and []byte, same rules as strconv.Atob +// - for strings and []byte, same rules as strconv.ParseBool // - all other types are an error var Bool boolType @@ -143,9 +144,10 @@ func (stringType) ConvertValue(v interface{}) (interface{}, error) { // bool // nil // []byte +// time.Time // string // -// This is the ame list as IsScanSubsetType, with the addition of +// This is the same list as IsScanSubsetType, with the addition of // string. func IsParameterSubsetType(v interface{}) bool { if IsScanSubsetType(v) { @@ -165,6 +167,7 @@ func IsParameterSubsetType(v interface{}) bool { // bool // nil // []byte +// time.Time // // This is the same list as IsParameterSubsetType, without string. func IsScanSubsetType(v interface{}) bool { @@ -172,7 +175,7 @@ func IsScanSubsetType(v interface{}) bool { return true } switch v.(type) { - case int64, float64, []byte, bool: + case int64, float64, []byte, bool, time.Time: return true } return false diff --git a/libgo/go/exp/sql/driver/types_test.go b/libgo/go/exp/sql/driver/types_test.go index 4b049e2..966bc6b 100644 --- a/libgo/go/exp/sql/driver/types_test.go +++ b/libgo/go/exp/sql/driver/types_test.go @@ -7,6 +7,7 @@ package driver import ( "reflect" "testing" + "time" ) type valueConverterTest struct { @@ -16,6 +17,8 @@ type valueConverterTest struct { err string } +var now = time.Now() + var valueConverterTests = []valueConverterTest{ {Bool, "true", true, ""}, {Bool, "True", true, ""}, @@ -33,6 +36,7 @@ var valueConverterTests = []valueConverterTest{ {Bool, uint16(0), false, ""}, {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, + {DefaultParameterConverter, now, now, ""}, } func TestValueConverters(t *testing.T) { diff --git a/libgo/go/exp/sql/fakedb_test.go b/libgo/go/exp/sql/fakedb_test.go index 2474a86..70aa68c 100644 --- a/libgo/go/exp/sql/fakedb_test.go +++ b/libgo/go/exp/sql/fakedb_test.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "time" "exp/sql/driver" ) @@ -77,6 +78,17 @@ type fakeConn struct { db *fakeDB // where to return ourselves to currTx *fakeTx + + // Stats for tests: + mu sync.Mutex + stmtsMade int + stmtsClosed int +} + +func (c *fakeConn) incrStat(v *int) { + c.mu.Lock() + *v++ + c.mu.Unlock() } type fakeTx struct { @@ -110,25 +122,34 @@ func init() { // Supports dsn forms: // <dbname> -// <dbname>;wipe +// <dbname>;<opts> (no currently supported options) func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { - d.mu.Lock() - defer d.mu.Unlock() - d.openCount++ - if d.dbs == nil { - d.dbs = make(map[string]*fakeDB) - } parts := strings.Split(dsn, ";") if len(parts) < 1 { return nil, errors.New("fakedb: no database name") } name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + return &fakeConn{db: db}, nil +} + +func (d *fakeDriver) getDB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } db, ok := d.dbs[name] if !ok { db = &fakeDB{name: name} d.dbs[name] = db } - return &fakeConn{db: db}, nil + return db } func (db *fakeDB) wipe() { @@ -200,7 +221,7 @@ func (c *fakeConn) Close() error { func checkSubsetTypes(args []interface{}) error { for n, arg := range args { switch arg.(type) { - case int64, float64, bool, nil, []byte, string: + case int64, float64, bool, nil, []byte, string, time.Time: default: return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) } @@ -297,6 +318,8 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e switch ctype { case "string": subsetVal = []byte(value) + case "blob": + subsetVal = []byte(value) case "int32": i, err := strconv.Atoi(value) if err != nil { @@ -327,6 +350,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { cmd := parts[0] parts = parts[1:] stmt := &fakeStmt{q: query, c: c, cmd: cmd} + c.incrStat(&c.stmtsMade) switch cmd { case "WIPE": // Nothing @@ -347,7 +371,10 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { } func (s *fakeStmt) Close() error { - s.closed = true + if !s.closed { + s.c.incrStat(&s.c.stmtsClosed) + s.closed = true + } return nil } @@ -501,9 +528,19 @@ type rowsCursor struct { pos int rows []*row closed bool + + // a clone of slices to give out to clients, indexed by the + // the original slice's first byte address. we clone them + // just so we're able to corrupt them on close. + bytesClone map[*byte][]byte } func (rc *rowsCursor) Close() error { + if !rc.closed { + for _, bs := range rc.bytesClone { + bs[0] = 255 // first byte corrupted + } + } rc.closed = true return nil } @@ -528,6 +565,19 @@ func (rc *rowsCursor) Next(dest []interface{}) error { // for ease of drivers, and to prevent drivers from // messing up conversions or doing them differently. dest[i] = v + + if bs, ok := v.([]byte); ok { + if rc.bytesClone == nil { + rc.bytesClone = make(map[*byte][]byte) + } + clone, ok := rc.bytesClone[&bs[0]] + if !ok { + clone = make([]byte, len(bs)) + copy(clone, bs) + rc.bytesClone[&bs[0]] = clone + } + dest[i] = clone + } } return nil } @@ -540,6 +590,8 @@ func converterForType(typ string) driver.ValueConverter { return driver.Int32 case "string": return driver.String + case "datetime": + return driver.DefaultParameterConverter } panic("invalid fakedb column type of " + typ) } diff --git a/libgo/go/exp/sql/sql.go b/libgo/go/exp/sql/sql.go index 937982c..4e68c3e 100644 --- a/libgo/go/exp/sql/sql.go +++ b/libgo/go/exp/sql/sql.go @@ -243,8 +243,13 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { if err != nil { return nil, err } - defer stmt.Close() - return stmt.Query(args...) + rows, err := stmt.Query(args...) + if err != nil { + stmt.Close() + return nil, err + } + rows.closeStmt = stmt + return rows, nil } // QueryRow executes a query that is expected to return at most one row. @@ -549,8 +554,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { // statement, a function to call to release the connection, and a // statement bound to that connection. func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { - if s.stickyErr != nil { - return nil, nil, nil, s.stickyErr + if err = s.stickyErr; err != nil { + return } s.mu.Lock() if s.closed { @@ -706,9 +711,10 @@ type Rows struct { releaseConn func() rowsi driver.Rows - closed bool - lastcols []interface{} - lasterr error + closed bool + lastcols []interface{} + lasterr error + closeStmt *Stmt // if non-nil, statement to Close on close } // Next prepares the next result row for reading with the Scan method. @@ -726,6 +732,9 @@ func (rs *Rows) Next() bool { rs.lastcols = make([]interface{}, len(rs.rowsi.Columns())) } rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr == io.EOF { + rs.Close() + } return rs.lasterr == nil } @@ -786,6 +795,9 @@ func (rs *Rows) Close() error { rs.closed = true err := rs.rowsi.Close() rs.releaseConn() + if rs.closeStmt != nil { + rs.closeStmt.Close() + } return err } @@ -800,10 +812,6 @@ type Row struct { // pointed at by dest. If more than one row matches the query, // Scan uses the first row and discards the rest. If no row matches // the query, Scan returns ErrNoRows. -// -// If dest contains pointers to []byte, the slices should not be -// modified and should only be considered valid until the next call to -// Next or Scan. func (r *Row) Scan(dest ...interface{}) error { if r.err != nil { return r.err @@ -812,7 +820,33 @@ func (r *Row) Scan(dest ...interface{}) error { if !r.rows.Next() { return ErrNoRows } - return r.rows.Scan(dest...) + err := r.rows.Scan(dest...) + if err != nil { + return err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned, since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + for _, dp := range dest { + b, ok := dp.(*[]byte) + if !ok { + continue + } + clone := make([]byte, len(*b)) + copy(clone, *b) + *b = clone + } + return nil } // A Result summarizes an executed SQL command. diff --git a/libgo/go/exp/sql/sql_test.go b/libgo/go/exp/sql/sql_test.go index 5307a23..3f98a8c 100644 --- a/libgo/go/exp/sql/sql_test.go +++ b/libgo/go/exp/sql/sql_test.go @@ -8,10 +8,15 @@ import ( "reflect" "strings" "testing" + "time" ) +const fakeDBName = "foo" + +var chrisBirthday = time.Unix(123456789, 0) + func newTestDB(t *testing.T, name string) *DB { - db, err := Open("test", "foo") + db, err := Open("test", fakeDBName) if err != nil { t.Fatalf("Open: %v", err) } @@ -19,10 +24,10 @@ func newTestDB(t *testing.T, name string) *DB { t.Fatalf("exec wipe: %v", err) } if name == "people" { - exec(t, db, "CREATE|people|name=string,age=int32,dead=bool") - exec(t, db, "INSERT|people|name=Alice,age=?", 1) - exec(t, db, "INSERT|people|name=Bob,age=?", 2) - exec(t, db, "INSERT|people|name=Chris,age=?", 3) + exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") + exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1) + exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2) + exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) } return db } @@ -73,6 +78,12 @@ func TestQuery(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Logf(" got: %#v\nwant: %#v", got, want) } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := len(db.freeConn); n != 1 { + t.Errorf("free conns after query hitting EOF = %d; want 1", n) + } } func TestRowsColumns(t *testing.T) { @@ -97,12 +108,18 @@ func TestQueryRow(t *testing.T) { defer closeDB(t, db) var name string var age int + var birthday time.Time err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age) if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") { t.Errorf("expected error from wrong number of arguments; actually got: %v", err) } + err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday) + if err != nil || !birthday.Equal(chrisBirthday) { + t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday) + } + err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name) if err != nil { t.Fatalf("age QueryRow+Scan: %v", err) @@ -124,6 +141,16 @@ func TestQueryRow(t *testing.T) { if age != 1 { t.Errorf("expected age 1, got %d", age) } + + var photo []byte + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err != nil { + t.Fatalf("photo QueryRow+Scan: %v", err) + } + want := []byte("APHOTO") + if !reflect.DeepEqual(photo, want) { + t.Errorf("photo = %q; want %q", photo, want) + } } func TestStatementErrorAfterClose(t *testing.T) { @@ -258,3 +285,21 @@ func TestIssue2542Deadlock(t *testing.T) { } } } + +func TestQueryRowClosingStmt(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name) + if err != nil { + t.Fatal(err) + } + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 free conn") + } + fakeConn := db.freeConn[0].(*fakeConn) + if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { + t.Logf("statement close mismatch: made %d, closed %d", made, closed) + } +} |