From a0ae409cb1d614ceeefe7c7fbed576e332fba262 Mon Sep 17 00:00:00 2001 From: Alexander Menzhinsky Date: Mon, 1 May 2017 23:49:25 +0300 Subject: Add named parameters support --- sqlite.go | 228 +++++++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 175 insertions(+), 53 deletions(-) (limited to 'sqlite.go') diff --git a/sqlite.go b/sqlite.go index ed318ef..656be0c 100644 --- a/sqlite.go +++ b/sqlite.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//TODO Pinger (tags go1.8) - //go:generate go run generator.go // Package sqlite is an in-process implementation of a self-contained, @@ -66,6 +64,7 @@ import ( "github.com/cznic/sqlite/internal/bin" "github.com/cznic/virtual" "github.com/cznic/xc" + "golang.org/x/net/context" ) var ( @@ -101,6 +100,7 @@ var ( bindInt int bindInt64 int bindParameterCount int + bindParameterName int bindText int changes int closeV2 int @@ -118,6 +118,7 @@ var ( extendedResultCodes int finalize int free int + interrupt int lastInsertRowID int maloc int openV2 int @@ -159,6 +160,7 @@ func init() { {&bindInt, "sqlite3_bind_int"}, {&bindInt64, "sqlite3_bind_int64"}, {&bindParameterCount, "sqlite3_bind_parameter_count"}, + {&bindParameterName, "sqlite3_bind_parameter_name"}, {&bindText, "sqlite3_bind_text"}, {&changes, "sqlite3_changes"}, {&closeV2, "sqlite3_close_v2"}, @@ -176,6 +178,7 @@ func init() { {&extendedResultCodes, "sqlite3_extended_result_codes"}, {&finalize, "sqlite3_finalize"}, {&free, "sqlite3_free"}, + {&interrupt, "sqlite3_interrupt"}, {&lastInsertRowID, "sqlite3_last_insert_rowid"}, {&maloc, "sqlite3_malloc"}, {&openV2, "sqlite3_open_v2"}, @@ -599,21 +602,37 @@ func (s *stmt) NumInput() (n int) { } // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. -// -// Deprecated: Drivers should implement StmtExecContext instead (or -// additionally). -func (s *stmt) Exec(args []driver.Value) (r driver.Result, err error) { +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { + return s.exec(context.Background(), toNamedValues(args)) +} + +func (s *stmt) exec(ctx context.Context, args []namedValue) (r driver.Result, err error) { if trace { - defer func(args []driver.Value) { + defer func(args []namedValue) { tracer(s, "Exec(%v): (%v, %v)", args, r, err) }(args) } + + var pstmt uintptr + + donech := make(chan struct{}) + defer close(donech) + go func() { + select { + case <-ctx.Done(): + if pstmt != 0 { + s.interrupt(s.pdb()) + } + case <-donech: + } + }() + for psql := s.psql; readI8(psql) != 0; psql = readPtr(s.pzTail) { if err := s.prepareV2(psql); err != nil { return nil, err } - pstmt := readPtr(s.ppstmt) + pstmt = readPtr(s.ppstmt) if pstmt == 0 { continue } @@ -627,9 +646,8 @@ func (s *stmt) Exec(args []driver.Value) (r driver.Result, err error) { if err = s.bind(pstmt, n, args); err != nil { return nil, err } - - args = args[n:] } + rc, err := s.step(pstmt) if err != nil { s.finalize(pstmt) @@ -650,24 +668,39 @@ func (s *stmt) Exec(args []driver.Value) (r driver.Result, err error) { return newResult(s) } -// Query executes a query that may return rows, such as a SELECT. -// -// Deprecated: Drivers should implement StmtQueryContext instead (or -// additionally). -func (s *stmt) Query(args []driver.Value) (r driver.Rows, err error) { +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { + return s.query(context.Background(), toNamedValues(args)) +} + +func (s *stmt) query(ctx context.Context, args []namedValue) (r driver.Rows, err error) { if trace { - defer func(args []driver.Value) { + defer func(args []namedValue) { tracer(s, "Query(%v): (%v, %v)", args, r, err) }(args) } + + var pstmt uintptr var rowStmt uintptr var rc0 int + + donech := make(chan struct{}) + defer close(donech) + go func() { + select { + case <-ctx.Done(): + if pstmt != 0 { + s.interrupt(s.pdb()) + } + case <-donech: + } + }() + for psql := s.psql; readI8(psql) != 0; psql = readPtr(s.pzTail) { if err := s.prepareV2(psql); err != nil { return nil, err } - pstmt := readPtr(s.ppstmt) + pstmt = readPtr(s.ppstmt) if pstmt == 0 { continue } @@ -681,9 +714,8 @@ func (s *stmt) Query(args []driver.Value) (r driver.Rows, err error) { if err = s.bind(pstmt, n, args); err != nil { return nil, err } - - args = args[n:] } + rc, err := s.step(pstmt) if err != nil { s.finalize(pstmt) @@ -818,19 +850,46 @@ func (s *stmt) bindText(pstmt uintptr, idx1 int, value string) (err error) { return nil } -func (s *stmt) bind(pstmt uintptr, n int, args []driver.Value) error { - if len(args) < n { - return fmt.Errorf("missing arguments: got %v, expected %v", len(args), n) - } +func (s *stmt) bind(pstmt uintptr, n int, args []namedValue) error { + for i := 1; i <= n; i++ { + name, err := s.bindParameterName(pstmt, i) + if err != nil { + return err + } + + var v namedValue + for _, v = range args { + if name != "" { + // sqlite supports '$', '@' and ':' prefixes for string + // identifiers and '?' for numeric, so we cannot + // combine different prefixes with the same name + // because `database/sql` requires variable names + // to start with a letter + if name[1:] == v.Name[:] { + break + } + } else { + if v.Ordinal == i { + break + } + } + } + + if v.Ordinal == 0 { + if name != "" { + return fmt.Errorf("missing named argument %q", name[1:]) + } else { + return fmt.Errorf("missing argument with %d index", i) + } + } - for i, v := range args[:n] { - switch x := v.(type) { + switch x := v.Value.(type) { case int64: - if err := s.bindInt64(pstmt, i+1, x); err != nil { + if err := s.bindInt64(pstmt, i, x); err != nil { return err } case float64: - if err := s.bindDouble(pstmt, i+1, x); err != nil { + if err := s.bindDouble(pstmt, i, x); err != nil { return err } case bool: @@ -838,19 +897,19 @@ func (s *stmt) bind(pstmt uintptr, n int, args []driver.Value) error { if x { v = 1 } - if err := s.bindInt(pstmt, i+1, v); err != nil { + if err := s.bindInt(pstmt, i, v); err != nil { return err } case []byte: - if err := s.bindBlob(pstmt, i+1, x); err != nil { + if err := s.bindBlob(pstmt, i, x); err != nil { return err } case string: - if err := s.bindText(pstmt, i+1, x); err != nil { + if err := s.bindText(pstmt, i, x); err != nil { return err } case time.Time: - if err := s.bindText(pstmt, i+1, x.String()); err != nil { + if err := s.bindText(pstmt, i, x.String()); err != nil { return err } default: @@ -871,6 +930,18 @@ func (s *stmt) bindParameterCount(pstmt uintptr) (_ int, err error) { return int(r), err } +// const char *sqlite3_bind_parameter_name(sqlite3_stmt*, int); +func (s *stmt) bindParameterName(pstmt uintptr, i int) (string, error) { + var p uintptr + _, err := s.FFI1( + bindParameterName, + virtual.PtrResult{&p}, + virtual.Ptr(pstmt), + virtual.Int32(i), + ) + return virtual.GoString(p), err +} + // int sqlite3_finalize(sqlite3_stmt *pStmt); func (s *stmt) finalize(pstmt uintptr) error { var rc int32 @@ -932,7 +1003,7 @@ func (t *tx) String() string { return fmt.Sprintf("&%T@%p{conn: %p}", *t, t, t.c func newTx(c *conn) (*tx, error) { t := &tx{conn: c} - if err := t.exec("begin"); err != nil { + if err := t.exec(context.Background(), "begin"); err != nil { return nil, err } @@ -946,7 +1017,7 @@ func (t *tx) Commit() (err error) { tracer(t, "Commit(): %v", err) }() } - return t.exec("commit") + return t.exec(context.Background(), "commit") } // Rollback implements driver.Tx. @@ -956,7 +1027,7 @@ func (t *tx) Rollback() (err error) { tracer(t, "Rollback(): %v", err) }() } - return t.exec("rollback") + return t.exec(context.Background(), "rollback") } // int sqlite3_exec( @@ -966,7 +1037,7 @@ func (t *tx) Rollback() (err error) { // void *, /* 1st argument to callback */ // char **errmsg /* Error msg written here */ // ); -func (t *tx) exec(sql string) (err error) { +func (t *tx) exec(ctx context.Context, sql string) (err error) { psql, err := t.cString(sql) if err != nil { return err @@ -974,6 +1045,17 @@ func (t *tx) exec(sql string) (err error) { defer t.free(psql) + // TODO: use t.conn.ExecContext() instead + donech := make(chan struct{}) + defer close(donech) + go func() { + select { + case <-ctx.Done(): + t.interrupt(t.pdb()) + case <-donech: + } + }() + var rc int32 if _, err = t.FFI1( exec, @@ -1049,6 +1131,10 @@ func newConn(s *Driver, name string) (_ *conn, err error) { // Prepare returns a prepared statement, bound to this connection. func (c *conn) Prepare(query string) (s driver.Stmt, err error) { + return c.prepare(context.Background(), query) +} + +func (c *conn) prepare(ctx context.Context, query string) (s driver.Stmt, err error) { if trace { defer func() { tracer(c, "Prepare(%s): (%v, %v)", query, s, err) @@ -1072,13 +1158,21 @@ func (c *conn) Close() (err error) { return c.close() } -// Begin starts and returns a new transaction. -// -// Deprecated: Drivers should implement ConnBeginTx instead (or additionally). -func (c *conn) Begin() (t driver.Tx, err error) { +// Begin starts a transaction. +func (c *conn) Begin() (driver.Tx, error) { + return c.begin(context.Background(), txOptions{}) +} + +// copy of driver.TxOptions +type txOptions struct { + Isolation int // driver.IsolationLevel + ReadOnly bool +} + +func (c *conn) begin(ctx context.Context, opts txOptions) (t driver.Tx, err error) { if trace { defer func() { - tracer(c, "Begin(): (%v, %v)", t, err) + tracer(c, "BeginTx(): (%v, %v)", t, err) }() } return newTx(c) @@ -1090,16 +1184,18 @@ func (c *conn) Begin() (t driver.Tx, err error) { // prepare a query, execute the statement, and then close the statement. // // Exec may return ErrSkip. -// -// Deprecated: Drivers should implement ExecerContext instead (or -// additionally). -func (c *conn) Exec(query string, args []driver.Value) (r driver.Result, err error) { +func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.exec(context.Background(), query, toNamedValues(args)) +} + +func (c *conn) exec(ctx context.Context, query string, args []namedValue) (r driver.Result, err error) { if trace { defer func() { - tracer(c, "Exec(%s, %v): (%v, %v)", query, args, r, err) + tracer(c, "ExecContext(%s, %v): (%v, %v)", query, args, r, err) }() } - s, err := c.Prepare(query) + + s, err := c.PrepareContext(ctx, query) if err != nil { return nil, err } @@ -1110,7 +1206,23 @@ func (c *conn) Exec(query string, args []driver.Value) (r driver.Result, err err } }() - return s.Exec(args) + return s.(*stmt).exec(ctx, args) +} + +// copy of driver.NameValue +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + +// toNamedValues converts []driver.Value to []namedValue +func toNamedValues(vals []driver.Value) []namedValue { + args := make([]namedValue, 0, len(vals)) + for i, val := range vals { + args = append(args, namedValue{Value: val, Ordinal: i + 1}) + } + return args } // Queryer is an optional interface that may be implemented by a Conn. @@ -1119,16 +1231,17 @@ func (c *conn) Exec(query string, args []driver.Value) (r driver.Result, err err // prepare a query, execute the statement, and then close the statement. // // Query may return ErrSkip. -// -// Deprecated: Drivers should implement QueryerContext instead (or -// additionally). -func (c *conn) Query(query string, args []driver.Value) (r driver.Rows, err error) { +func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.query(context.Background(), query, toNamedValues(args)) +} + +func (c *conn) query(ctx context.Context, query string, args []namedValue) (r driver.Rows, err error) { if trace { defer func() { tracer(c, "Query(%s, %v): (%v, %v)", query, args, r, err) }() } - s, err := c.Prepare(query) + s, err := c.PrepareContext(ctx, query) if err != nil { return nil, err } @@ -1139,7 +1252,7 @@ func (c *conn) Query(query string, args []driver.Value) (r driver.Rows, err erro } }() - return s.Query(args) + return s.(*stmt).query(ctx, args) } func (c *conn) pdb() uintptr { return readPtr(c.ppdb) } @@ -1278,6 +1391,15 @@ func (c *conn) free(p uintptr) (err error) { return err } +// void sqlite3_interrupt(sqlite3*); +func (c *conn) interrupt(pdb uintptr) (err error) { + _, err = c.FFI0( + interrupt, + virtual.Ptr(pdb), + ) + return err +} + func (c *conn) close() (err error) { c.Lock() -- cgit v1.2.3-70-g09d2