diff options
-rw-r--r-- | sqlite.go | 228 | ||||
-rw-r--r-- | sqlite_go18.go | 72 | ||||
-rw-r--r-- | sqlite_go18_test.go | 48 |
3 files changed, 295 insertions, 53 deletions
@@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//TODO Pinger (tags go1.8) - //go:generate go run generator.go // Package sqlite is an in-process implementation of a self-contained, @@ -66,6 +64,7 @@ import ( "github.com/cznic/sqlite/internal/bin" "github.com/cznic/virtual" "github.com/cznic/xc" + "golang.org/x/net/context" ) var ( @@ -101,6 +100,7 @@ var ( bindInt int bindInt64 int bindParameterCount int + bindParameterName int bindText int changes int closeV2 int @@ -118,6 +118,7 @@ var ( extendedResultCodes int finalize int free int + interrupt int lastInsertRowID int maloc int openV2 int @@ -159,6 +160,7 @@ func init() { {&bindInt, "sqlite3_bind_int"}, {&bindInt64, "sqlite3_bind_int64"}, {&bindParameterCount, "sqlite3_bind_parameter_count"}, + {&bindParameterName, "sqlite3_bind_parameter_name"}, {&bindText, "sqlite3_bind_text"}, {&changes, "sqlite3_changes"}, {&closeV2, "sqlite3_close_v2"}, @@ -176,6 +178,7 @@ func init() { {&extendedResultCodes, "sqlite3_extended_result_codes"}, {&finalize, "sqlite3_finalize"}, {&free, "sqlite3_free"}, + {&interrupt, "sqlite3_interrupt"}, {&lastInsertRowID, "sqlite3_last_insert_rowid"}, {&maloc, "sqlite3_malloc"}, {&openV2, "sqlite3_open_v2"}, @@ -599,21 +602,37 @@ func (s *stmt) NumInput() (n int) { } // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. -// -// Deprecated: Drivers should implement StmtExecContext instead (or -// additionally). -func (s *stmt) Exec(args []driver.Value) (r driver.Result, err error) { +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { + return s.exec(context.Background(), toNamedValues(args)) +} + +func (s *stmt) exec(ctx context.Context, args []namedValue) (r driver.Result, err error) { if trace { - defer func(args []driver.Value) { + defer func(args []namedValue) { tracer(s, "Exec(%v): (%v, %v)", args, r, err) }(args) } + + var pstmt uintptr + + donech := make(chan struct{}) + defer close(donech) + go func() { + select { + case <-ctx.Done(): + if pstmt != 0 { + s.interrupt(s.pdb()) + } + case <-donech: + } + }() + for psql := s.psql; readI8(psql) != 0; psql = readPtr(s.pzTail) { if err := s.prepareV2(psql); err != nil { return nil, err } - pstmt := readPtr(s.ppstmt) + pstmt = readPtr(s.ppstmt) if pstmt == 0 { continue } @@ -627,9 +646,8 @@ func (s *stmt) Exec(args []driver.Value) (r driver.Result, err error) { if err = s.bind(pstmt, n, args); err != nil { return nil, err } - - args = args[n:] } + rc, err := s.step(pstmt) if err != nil { s.finalize(pstmt) @@ -650,24 +668,39 @@ func (s *stmt) Exec(args []driver.Value) (r driver.Result, err error) { return newResult(s) } -// Query executes a query that may return rows, such as a SELECT. -// -// Deprecated: Drivers should implement StmtQueryContext instead (or -// additionally). -func (s *stmt) Query(args []driver.Value) (r driver.Rows, err error) { +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { + return s.query(context.Background(), toNamedValues(args)) +} + +func (s *stmt) query(ctx context.Context, args []namedValue) (r driver.Rows, err error) { if trace { - defer func(args []driver.Value) { + defer func(args []namedValue) { tracer(s, "Query(%v): (%v, %v)", args, r, err) }(args) } + + var pstmt uintptr var rowStmt uintptr var rc0 int + + donech := make(chan struct{}) + defer close(donech) + go func() { + select { + case <-ctx.Done(): + if pstmt != 0 { + s.interrupt(s.pdb()) + } + case <-donech: + } + }() + for psql := s.psql; readI8(psql) != 0; psql = readPtr(s.pzTail) { if err := s.prepareV2(psql); err != nil { return nil, err } - pstmt := readPtr(s.ppstmt) + pstmt = readPtr(s.ppstmt) if pstmt == 0 { continue } @@ -681,9 +714,8 @@ func (s *stmt) Query(args []driver.Value) (r driver.Rows, err error) { if err = s.bind(pstmt, n, args); err != nil { return nil, err } - - args = args[n:] } + rc, err := s.step(pstmt) if err != nil { s.finalize(pstmt) @@ -818,19 +850,46 @@ func (s *stmt) bindText(pstmt uintptr, idx1 int, value string) (err error) { return nil } -func (s *stmt) bind(pstmt uintptr, n int, args []driver.Value) error { - if len(args) < n { - return fmt.Errorf("missing arguments: got %v, expected %v", len(args), n) - } +func (s *stmt) bind(pstmt uintptr, n int, args []namedValue) error { + for i := 1; i <= n; i++ { + name, err := s.bindParameterName(pstmt, i) + if err != nil { + return err + } + + var v namedValue + for _, v = range args { + if name != "" { + // sqlite supports '$', '@' and ':' prefixes for string + // identifiers and '?' for numeric, so we cannot + // combine different prefixes with the same name + // because `database/sql` requires variable names + // to start with a letter + if name[1:] == v.Name[:] { + break + } + } else { + if v.Ordinal == i { + break + } + } + } + + if v.Ordinal == 0 { + if name != "" { + return fmt.Errorf("missing named argument %q", name[1:]) + } else { + return fmt.Errorf("missing argument with %d index", i) + } + } - for i, v := range args[:n] { - switch x := v.(type) { + switch x := v.Value.(type) { case int64: - if err := s.bindInt64(pstmt, i+1, x); err != nil { + if err := s.bindInt64(pstmt, i, x); err != nil { return err } case float64: - if err := s.bindDouble(pstmt, i+1, x); err != nil { + if err := s.bindDouble(pstmt, i, x); err != nil { return err } case bool: @@ -838,19 +897,19 @@ func (s *stmt) bind(pstmt uintptr, n int, args []driver.Value) error { if x { v = 1 } - if err := s.bindInt(pstmt, i+1, v); err != nil { + if err := s.bindInt(pstmt, i, v); err != nil { return err } case []byte: - if err := s.bindBlob(pstmt, i+1, x); err != nil { + if err := s.bindBlob(pstmt, i, x); err != nil { return err } case string: - if err := s.bindText(pstmt, i+1, x); err != nil { + if err := s.bindText(pstmt, i, x); err != nil { return err } case time.Time: - if err := s.bindText(pstmt, i+1, x.String()); err != nil { + if err := s.bindText(pstmt, i, x.String()); err != nil { return err } default: @@ -871,6 +930,18 @@ func (s *stmt) bindParameterCount(pstmt uintptr) (_ int, err error) { return int(r), err } +// const char *sqlite3_bind_parameter_name(sqlite3_stmt*, int); +func (s *stmt) bindParameterName(pstmt uintptr, i int) (string, error) { + var p uintptr + _, err := s.FFI1( + bindParameterName, + virtual.PtrResult{&p}, + virtual.Ptr(pstmt), + virtual.Int32(i), + ) + return virtual.GoString(p), err +} + // int sqlite3_finalize(sqlite3_stmt *pStmt); func (s *stmt) finalize(pstmt uintptr) error { var rc int32 @@ -932,7 +1003,7 @@ func (t *tx) String() string { return fmt.Sprintf("&%T@%p{conn: %p}", *t, t, t.c func newTx(c *conn) (*tx, error) { t := &tx{conn: c} - if err := t.exec("begin"); err != nil { + if err := t.exec(context.Background(), "begin"); err != nil { return nil, err } @@ -946,7 +1017,7 @@ func (t *tx) Commit() (err error) { tracer(t, "Commit(): %v", err) }() } - return t.exec("commit") + return t.exec(context.Background(), "commit") } // Rollback implements driver.Tx. @@ -956,7 +1027,7 @@ func (t *tx) Rollback() (err error) { tracer(t, "Rollback(): %v", err) }() } - return t.exec("rollback") + return t.exec(context.Background(), "rollback") } // int sqlite3_exec( @@ -966,7 +1037,7 @@ func (t *tx) Rollback() (err error) { // void *, /* 1st argument to callback */ // char **errmsg /* Error msg written here */ // ); -func (t *tx) exec(sql string) (err error) { +func (t *tx) exec(ctx context.Context, sql string) (err error) { psql, err := t.cString(sql) if err != nil { return err @@ -974,6 +1045,17 @@ func (t *tx) exec(sql string) (err error) { defer t.free(psql) + // TODO: use t.conn.ExecContext() instead + donech := make(chan struct{}) + defer close(donech) + go func() { + select { + case <-ctx.Done(): + t.interrupt(t.pdb()) + case <-donech: + } + }() + var rc int32 if _, err = t.FFI1( exec, @@ -1049,6 +1131,10 @@ func newConn(s *Driver, name string) (_ *conn, err error) { // Prepare returns a prepared statement, bound to this connection. func (c *conn) Prepare(query string) (s driver.Stmt, err error) { + return c.prepare(context.Background(), query) +} + +func (c *conn) prepare(ctx context.Context, query string) (s driver.Stmt, err error) { if trace { defer func() { tracer(c, "Prepare(%s): (%v, %v)", query, s, err) @@ -1072,13 +1158,21 @@ func (c *conn) Close() (err error) { return c.close() } -// Begin starts and returns a new transaction. -// -// Deprecated: Drivers should implement ConnBeginTx instead (or additionally). -func (c *conn) Begin() (t driver.Tx, err error) { +// Begin starts a transaction. +func (c *conn) Begin() (driver.Tx, error) { + return c.begin(context.Background(), txOptions{}) +} + +// copy of driver.TxOptions +type txOptions struct { + Isolation int // driver.IsolationLevel + ReadOnly bool +} + +func (c *conn) begin(ctx context.Context, opts txOptions) (t driver.Tx, err error) { if trace { defer func() { - tracer(c, "Begin(): (%v, %v)", t, err) + tracer(c, "BeginTx(): (%v, %v)", t, err) }() } return newTx(c) @@ -1090,16 +1184,18 @@ func (c *conn) Begin() (t driver.Tx, err error) { // prepare a query, execute the statement, and then close the statement. // // Exec may return ErrSkip. -// -// Deprecated: Drivers should implement ExecerContext instead (or -// additionally). -func (c *conn) Exec(query string, args []driver.Value) (r driver.Result, err error) { +func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.exec(context.Background(), query, toNamedValues(args)) +} + +func (c *conn) exec(ctx context.Context, query string, args []namedValue) (r driver.Result, err error) { if trace { defer func() { - tracer(c, "Exec(%s, %v): (%v, %v)", query, args, r, err) + tracer(c, "ExecContext(%s, %v): (%v, %v)", query, args, r, err) }() } - s, err := c.Prepare(query) + + s, err := c.PrepareContext(ctx, query) if err != nil { return nil, err } @@ -1110,7 +1206,23 @@ func (c *conn) Exec(query string, args []driver.Value) (r driver.Result, err err } }() - return s.Exec(args) + return s.(*stmt).exec(ctx, args) +} + +// copy of driver.NameValue +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + +// toNamedValues converts []driver.Value to []namedValue +func toNamedValues(vals []driver.Value) []namedValue { + args := make([]namedValue, 0, len(vals)) + for i, val := range vals { + args = append(args, namedValue{Value: val, Ordinal: i + 1}) + } + return args } // Queryer is an optional interface that may be implemented by a Conn. @@ -1119,16 +1231,17 @@ func (c *conn) Exec(query string, args []driver.Value) (r driver.Result, err err // prepare a query, execute the statement, and then close the statement. // // Query may return ErrSkip. -// -// Deprecated: Drivers should implement QueryerContext instead (or -// additionally). -func (c *conn) Query(query string, args []driver.Value) (r driver.Rows, err error) { +func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.query(context.Background(), query, toNamedValues(args)) +} + +func (c *conn) query(ctx context.Context, query string, args []namedValue) (r driver.Rows, err error) { if trace { defer func() { tracer(c, "Query(%s, %v): (%v, %v)", query, args, r, err) }() } - s, err := c.Prepare(query) + s, err := c.PrepareContext(ctx, query) if err != nil { return nil, err } @@ -1139,7 +1252,7 @@ func (c *conn) Query(query string, args []driver.Value) (r driver.Rows, err erro } }() - return s.Query(args) + return s.(*stmt).query(ctx, args) } func (c *conn) pdb() uintptr { return readPtr(c.ppdb) } @@ -1278,6 +1391,15 @@ func (c *conn) free(p uintptr) (err error) { return err } +// void sqlite3_interrupt(sqlite3*); +func (c *conn) interrupt(pdb uintptr) (err error) { + _, err = c.FFI0( + interrupt, + virtual.Ptr(pdb), + ) + return err +} + func (c *conn) close() (err error) { c.Lock() diff --git a/sqlite_go18.go b/sqlite_go18.go new file mode 100644 index 0000000..4338dcc --- /dev/null +++ b/sqlite_go18.go @@ -0,0 +1,72 @@ +// Copyright 2017 The Sqlite Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//+build go1.8 + +package sqlite + +import ( + "context" + "database/sql/driver" + "errors" +) + +// Ping implements driver.Pinger +func (c *conn) Ping(ctx context.Context) error { + c.Lock() + defer c.Unlock() + + if c.ppdb == 0 { + return errors.New("db is closed") + } + + _, err := c.ExecContext(ctx, "select 1", nil) + return err +} + +// BeginTx implements driver.ConnBeginTx +func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return c.begin(ctx, txOptions{ + Isolation: int(opts.Isolation), + ReadOnly: opts.ReadOnly, + }) +} + +// PrepareContext implements driver.ConnPrepareContext +func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return c.prepare(ctx, query) +} + +// ExecContext implements driver.ExecerContext +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return c.exec(ctx, query, toNamedValues2(args)) +} + +// QueryContext implements driver.QueryerContext +func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return c.query(ctx, query, toNamedValues2(args)) +} + +// ExecContext implements driver.StmtExecContext +func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return s.exec(ctx, toNamedValues2(args)) +} + +// QueryContext implements driver.StmtQueryContext +func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return s.query(ctx, toNamedValues2(args)) +} + +// converts []driver.NamedValue to []namedValue +func toNamedValues2(vals []driver.NamedValue) []namedValue { + args := make([]namedValue, 0, len(vals)) + for _, val := range vals { + args = append(args, namedValue{ + Name: val.Name, + Ordinal: val.Ordinal, + Value: val.Value, + }) + } + return args +} diff --git a/sqlite_go18_test.go b/sqlite_go18_test.go new file mode 100644 index 0000000..d558488 --- /dev/null +++ b/sqlite_go18_test.go @@ -0,0 +1,48 @@ +// Copyright 2017 The Sqlite Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//+build go1.8 + +package sqlite + +import ( + "database/sql" + "os" + "reflect" + "testing" +) + +func TestNamedParameters(t *testing.T) { + dir, db := tempDB(t) + defer func() { + db.Close() + os.RemoveAll(dir) + }() + + _, err := db.Exec(` + create table t(s1 varchar(32), s2 varchar(32), s3 varchar(32), s4 varchar(32)); + insert into t values(?, @aa, $aa, @bb); + `, "1", sql.Named("aa", "one"), sql.Named("bb", "two")) + + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query("select * from t") + if err != nil { + t.Fatal(err) + } + + rec := make([]string, 4) + for rows.Next() { + if err := rows.Scan(&rec[0], &rec[1], &rec[2], &rec[3]); err != nil { + t.Fatal(err) + } + } + + w := []string{"1", "one", "one", "two"} + if !reflect.DeepEqual(rec, w) { + t.Fatal(rec, w) + } +} |