aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/exp/sql
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/exp/sql')
-rw-r--r--libgo/go/exp/sql/convert_test.go13
-rw-r--r--libgo/go/exp/sql/driver/driver.go1
-rw-r--r--libgo/go/exp/sql/driver/types.go9
-rw-r--r--libgo/go/exp/sql/driver/types_test.go4
-rw-r--r--libgo/go/exp/sql/fakedb_test.go72
-rw-r--r--libgo/go/exp/sql/sql.go58
-rw-r--r--libgo/go/exp/sql/sql_test.go55
7 files changed, 182 insertions, 30 deletions
diff --git a/libgo/go/exp/sql/convert_test.go b/libgo/go/exp/sql/convert_test.go
index bed09ff..702ba43 100644
--- a/libgo/go/exp/sql/convert_test.go
+++ b/libgo/go/exp/sql/convert_test.go
@@ -8,8 +8,11 @@ import (
"fmt"
"reflect"
"testing"
+ "time"
)
+var someTime = time.Unix(123, 0)
+
type conversionTest struct {
s, d interface{} // source and destination
@@ -19,6 +22,7 @@ type conversionTest struct {
wantstr string
wantf32 float32
wantf64 float64
+ wanttime time.Time
wantbool bool // used if d is of type *bool
wanterr string
}
@@ -35,12 +39,14 @@ var (
scanbool bool
scanf32 float32
scanf64 float64
+ scantime time.Time
)
var conversionTests = []conversionTest{
// Exact conversions (destination pointer type matches source type)
{s: "foo", d: &scanstr, wantstr: "foo"},
{s: 123, d: &scanint, wantint: 123},
+ {s: someTime, d: &scantime, wanttime: someTime},
// To strings
{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
@@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 {
return *(ptr.(*float32))
}
+func timeValue(ptr interface{}) time.Time {
+ return *(ptr.(*time.Time))
+}
+
func TestConversions(t *testing.T) {
for n, ct := range conversionTests {
err := convertAssign(ct.d, ct.s)
@@ -138,6 +148,9 @@ func TestConversions(t *testing.T) {
if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
errf("want bool %v, got %v", ct.wantbool, *bp)
}
+ if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
+ errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
+ }
}
}
diff --git a/libgo/go/exp/sql/driver/driver.go b/libgo/go/exp/sql/driver/driver.go
index f0bcca2..0cd2562 100644
--- a/libgo/go/exp/sql/driver/driver.go
+++ b/libgo/go/exp/sql/driver/driver.go
@@ -16,6 +16,7 @@
// nil
// []byte
// string [*] everywhere except from Rows.Next.
+// time.Time
//
package driver
diff --git a/libgo/go/exp/sql/driver/types.go b/libgo/go/exp/sql/driver/types.go
index 086b529..d6ba641 100644
--- a/libgo/go/exp/sql/driver/types.go
+++ b/libgo/go/exp/sql/driver/types.go
@@ -8,6 +8,7 @@ import (
"fmt"
"reflect"
"strconv"
+ "time"
)
// ValueConverter is the interface providing the ConvertValue method.
@@ -39,7 +40,7 @@ type ValueConverter interface {
// 1 is true
// 0 is false,
// other integers are an error
-// - for strings and []byte, same rules as strconv.Atob
+// - for strings and []byte, same rules as strconv.ParseBool
// - all other types are an error
var Bool boolType
@@ -143,9 +144,10 @@ func (stringType) ConvertValue(v interface{}) (interface{}, error) {
// bool
// nil
// []byte
+// time.Time
// string
//
-// This is the ame list as IsScanSubsetType, with the addition of
+// This is the same list as IsScanSubsetType, with the addition of
// string.
func IsParameterSubsetType(v interface{}) bool {
if IsScanSubsetType(v) {
@@ -165,6 +167,7 @@ func IsParameterSubsetType(v interface{}) bool {
// bool
// nil
// []byte
+// time.Time
//
// This is the same list as IsParameterSubsetType, without string.
func IsScanSubsetType(v interface{}) bool {
@@ -172,7 +175,7 @@ func IsScanSubsetType(v interface{}) bool {
return true
}
switch v.(type) {
- case int64, float64, []byte, bool:
+ case int64, float64, []byte, bool, time.Time:
return true
}
return false
diff --git a/libgo/go/exp/sql/driver/types_test.go b/libgo/go/exp/sql/driver/types_test.go
index 4b049e2..966bc6b 100644
--- a/libgo/go/exp/sql/driver/types_test.go
+++ b/libgo/go/exp/sql/driver/types_test.go
@@ -7,6 +7,7 @@ package driver
import (
"reflect"
"testing"
+ "time"
)
type valueConverterTest struct {
@@ -16,6 +17,8 @@ type valueConverterTest struct {
err string
}
+var now = time.Now()
+
var valueConverterTests = []valueConverterTest{
{Bool, "true", true, ""},
{Bool, "True", true, ""},
@@ -33,6 +36,7 @@ var valueConverterTests = []valueConverterTest{
{Bool, uint16(0), false, ""},
{c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
{c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
+ {DefaultParameterConverter, now, now, ""},
}
func TestValueConverters(t *testing.T) {
diff --git a/libgo/go/exp/sql/fakedb_test.go b/libgo/go/exp/sql/fakedb_test.go
index 2474a86..70aa68c 100644
--- a/libgo/go/exp/sql/fakedb_test.go
+++ b/libgo/go/exp/sql/fakedb_test.go
@@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"sync"
+ "time"
"exp/sql/driver"
)
@@ -77,6 +78,17 @@ type fakeConn struct {
db *fakeDB // where to return ourselves to
currTx *fakeTx
+
+ // Stats for tests:
+ mu sync.Mutex
+ stmtsMade int
+ stmtsClosed int
+}
+
+func (c *fakeConn) incrStat(v *int) {
+ c.mu.Lock()
+ *v++
+ c.mu.Unlock()
}
type fakeTx struct {
@@ -110,25 +122,34 @@ func init() {
// Supports dsn forms:
// <dbname>
-// <dbname>;wipe
+// <dbname>;<opts> (no currently supported options)
func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
- d.mu.Lock()
- defer d.mu.Unlock()
- d.openCount++
- if d.dbs == nil {
- d.dbs = make(map[string]*fakeDB)
- }
parts := strings.Split(dsn, ";")
if len(parts) < 1 {
return nil, errors.New("fakedb: no database name")
}
name := parts[0]
+
+ db := d.getDB(name)
+
+ d.mu.Lock()
+ d.openCount++
+ d.mu.Unlock()
+ return &fakeConn{db: db}, nil
+}
+
+func (d *fakeDriver) getDB(name string) *fakeDB {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ if d.dbs == nil {
+ d.dbs = make(map[string]*fakeDB)
+ }
db, ok := d.dbs[name]
if !ok {
db = &fakeDB{name: name}
d.dbs[name] = db
}
- return &fakeConn{db: db}, nil
+ return db
}
func (db *fakeDB) wipe() {
@@ -200,7 +221,7 @@ func (c *fakeConn) Close() error {
func checkSubsetTypes(args []interface{}) error {
for n, arg := range args {
switch arg.(type) {
- case int64, float64, bool, nil, []byte, string:
+ case int64, float64, bool, nil, []byte, string, time.Time:
default:
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
}
@@ -297,6 +318,8 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
switch ctype {
case "string":
subsetVal = []byte(value)
+ case "blob":
+ subsetVal = []byte(value)
case "int32":
i, err := strconv.Atoi(value)
if err != nil {
@@ -327,6 +350,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
cmd := parts[0]
parts = parts[1:]
stmt := &fakeStmt{q: query, c: c, cmd: cmd}
+ c.incrStat(&c.stmtsMade)
switch cmd {
case "WIPE":
// Nothing
@@ -347,7 +371,10 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
}
func (s *fakeStmt) Close() error {
- s.closed = true
+ if !s.closed {
+ s.c.incrStat(&s.c.stmtsClosed)
+ s.closed = true
+ }
return nil
}
@@ -501,9 +528,19 @@ type rowsCursor struct {
pos int
rows []*row
closed bool
+
+ // a clone of slices to give out to clients, indexed by the
+ // the original slice's first byte address. we clone them
+ // just so we're able to corrupt them on close.
+ bytesClone map[*byte][]byte
}
func (rc *rowsCursor) Close() error {
+ if !rc.closed {
+ for _, bs := range rc.bytesClone {
+ bs[0] = 255 // first byte corrupted
+ }
+ }
rc.closed = true
return nil
}
@@ -528,6 +565,19 @@ func (rc *rowsCursor) Next(dest []interface{}) error {
// for ease of drivers, and to prevent drivers from
// messing up conversions or doing them differently.
dest[i] = v
+
+ if bs, ok := v.([]byte); ok {
+ if rc.bytesClone == nil {
+ rc.bytesClone = make(map[*byte][]byte)
+ }
+ clone, ok := rc.bytesClone[&bs[0]]
+ if !ok {
+ clone = make([]byte, len(bs))
+ copy(clone, bs)
+ rc.bytesClone[&bs[0]] = clone
+ }
+ dest[i] = clone
+ }
}
return nil
}
@@ -540,6 +590,8 @@ func converterForType(typ string) driver.ValueConverter {
return driver.Int32
case "string":
return driver.String
+ case "datetime":
+ return driver.DefaultParameterConverter
}
panic("invalid fakedb column type of " + typ)
}
diff --git a/libgo/go/exp/sql/sql.go b/libgo/go/exp/sql/sql.go
index 937982c..4e68c3e 100644
--- a/libgo/go/exp/sql/sql.go
+++ b/libgo/go/exp/sql/sql.go
@@ -243,8 +243,13 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
- defer stmt.Close()
- return stmt.Query(args...)
+ rows, err := stmt.Query(args...)
+ if err != nil {
+ stmt.Close()
+ return nil, err
+ }
+ rows.closeStmt = stmt
+ return rows, nil
}
// QueryRow executes a query that is expected to return at most one row.
@@ -549,8 +554,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
- if s.stickyErr != nil {
- return nil, nil, nil, s.stickyErr
+ if err = s.stickyErr; err != nil {
+ return
}
s.mu.Lock()
if s.closed {
@@ -706,9 +711,10 @@ type Rows struct {
releaseConn func()
rowsi driver.Rows
- closed bool
- lastcols []interface{}
- lasterr error
+ closed bool
+ lastcols []interface{}
+ lasterr error
+ closeStmt *Stmt // if non-nil, statement to Close on close
}
// Next prepares the next result row for reading with the Scan method.
@@ -726,6 +732,9 @@ func (rs *Rows) Next() bool {
rs.lastcols = make([]interface{}, len(rs.rowsi.Columns()))
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
+ if rs.lasterr == io.EOF {
+ rs.Close()
+ }
return rs.lasterr == nil
}
@@ -786,6 +795,9 @@ func (rs *Rows) Close() error {
rs.closed = true
err := rs.rowsi.Close()
rs.releaseConn()
+ if rs.closeStmt != nil {
+ rs.closeStmt.Close()
+ }
return err
}
@@ -800,10 +812,6 @@ type Row struct {
// pointed at by dest. If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns ErrNoRows.
-//
-// If dest contains pointers to []byte, the slices should not be
-// modified and should only be considered valid until the next call to
-// Next or Scan.
func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil {
return r.err
@@ -812,7 +820,33 @@ func (r *Row) Scan(dest ...interface{}) error {
if !r.rows.Next() {
return ErrNoRows
}
- return r.rows.Scan(dest...)
+ err := r.rows.Scan(dest...)
+ if err != nil {
+ return err
+ }
+
+ // TODO(bradfitz): for now we need to defensively clone all
+ // []byte that the driver returned, since we're about to close
+ // the Rows in our defer, when we return from this function.
+ // the contract with the driver.Next(...) interface is that it
+ // can return slices into read-only temporary memory that's
+ // only valid until the next Scan/Close. But the TODO is that
+ // for a lot of drivers, this copy will be unnecessary. We
+ // should provide an optional interface for drivers to
+ // implement to say, "don't worry, the []bytes that I return
+ // from Next will not be modified again." (for instance, if
+ // they were obtained from the network anyway) But for now we
+ // don't care.
+ for _, dp := range dest {
+ b, ok := dp.(*[]byte)
+ if !ok {
+ continue
+ }
+ clone := make([]byte, len(*b))
+ copy(clone, *b)
+ *b = clone
+ }
+ return nil
}
// A Result summarizes an executed SQL command.
diff --git a/libgo/go/exp/sql/sql_test.go b/libgo/go/exp/sql/sql_test.go
index 5307a23..3f98a8c 100644
--- a/libgo/go/exp/sql/sql_test.go
+++ b/libgo/go/exp/sql/sql_test.go
@@ -8,10 +8,15 @@ import (
"reflect"
"strings"
"testing"
+ "time"
)
+const fakeDBName = "foo"
+
+var chrisBirthday = time.Unix(123456789, 0)
+
func newTestDB(t *testing.T, name string) *DB {
- db, err := Open("test", "foo")
+ db, err := Open("test", fakeDBName)
if err != nil {
t.Fatalf("Open: %v", err)
}
@@ -19,10 +24,10 @@ func newTestDB(t *testing.T, name string) *DB {
t.Fatalf("exec wipe: %v", err)
}
if name == "people" {
- exec(t, db, "CREATE|people|name=string,age=int32,dead=bool")
- exec(t, db, "INSERT|people|name=Alice,age=?", 1)
- exec(t, db, "INSERT|people|name=Bob,age=?", 2)
- exec(t, db, "INSERT|people|name=Chris,age=?", 3)
+ exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
+ exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
+ exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
+ exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
}
return db
}
@@ -73,6 +78,12 @@ func TestQuery(t *testing.T) {
if !reflect.DeepEqual(got, want) {
t.Logf(" got: %#v\nwant: %#v", got, want)
}
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := len(db.freeConn); n != 1 {
+ t.Errorf("free conns after query hitting EOF = %d; want 1", n)
+ }
}
func TestRowsColumns(t *testing.T) {
@@ -97,12 +108,18 @@ func TestQueryRow(t *testing.T) {
defer closeDB(t, db)
var name string
var age int
+ var birthday time.Time
err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
}
+ err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
+ if err != nil || !birthday.Equal(chrisBirthday) {
+ t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
+ }
+
err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
if err != nil {
t.Fatalf("age QueryRow+Scan: %v", err)
@@ -124,6 +141,16 @@ func TestQueryRow(t *testing.T) {
if age != 1 {
t.Errorf("expected age 1, got %d", age)
}
+
+ var photo []byte
+ err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
+ if err != nil {
+ t.Fatalf("photo QueryRow+Scan: %v", err)
+ }
+ want := []byte("APHOTO")
+ if !reflect.DeepEqual(photo, want) {
+ t.Errorf("photo = %q; want %q", photo, want)
+ }
}
func TestStatementErrorAfterClose(t *testing.T) {
@@ -258,3 +285,21 @@ func TestIssue2542Deadlock(t *testing.T) {
}
}
}
+
+func TestQueryRowClosingStmt(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ var name string
+ var age int
+ err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(db.freeConn) != 1 {
+ t.Fatalf("expected 1 free conn")
+ }
+ fakeConn := db.freeConn[0].(*fakeConn)
+ if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
+ t.Logf("statement close mismatch: made %d, closed %d", made, closed)
+ }
+}