aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--all_test.go64
-rw-r--r--sqlite.go121
2 files changed, 178 insertions, 7 deletions
diff --git a/all_test.go b/all_test.go
index 08f565b..676fd8c 100644
--- a/all_test.go
+++ b/all_test.go
@@ -826,3 +826,67 @@ func TestIssue28(t *testing.T) {
t.Fatalf("got %T(%[1]v), expected %T(%[2]v)", err, sql.ErrNoRows)
}
}
+
+// https://gitlab.com/cznic/sqlite/-/issues/30
+func TestIssue30(t *testing.T) {
+ tempDir, err := ioutil.TempDir("", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ defer os.RemoveAll(tempDir)
+
+ db, err := sql.Open("sqlite", filepath.Join(tempDir, "test.db"))
+ if err != nil {
+ t.Fatalf("test.db open fail: %v", err)
+ }
+
+ defer db.Close()
+
+ _, err = db.Query("CREATE TABLE IF NOT EXISTS `userinfo` (`uid` INTEGER PRIMARY KEY AUTOINCREMENT,`username` VARCHAR(64) NULL, `departname` VARCHAR(64) NULL, `created` DATE NULL);")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ insertStatement := `INSERT INTO userinfo(username, departname, created) values("astaxie", "研发部门", "2012-12-09")`
+ _, err = db.Query(insertStatement)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rows2, err := db.Query("SELECT * FROM userinfo")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ columnTypes, _ := rows2.ColumnTypes()
+ var b strings.Builder
+ for rows2.Next() {
+ for index, value := range columnTypes {
+ precision, scale, precisionOk := value.DecimalSize()
+ length, lengthOk := value.Length()
+ nullable, nullableOk := value.Nullable()
+ fmt.Fprintf(&b, "Col %d: DatabaseTypeName %q, DecimalSize %v %v %v, Length %v %v, Name %q, Nullable %v %v, ScanType %q\n",
+ index,
+ value.DatabaseTypeName(),
+ precision, scale, precisionOk,
+ length, lengthOk,
+ value.Name(),
+ nullable, nullableOk,
+ value.ScanType(),
+ )
+ }
+ }
+ if err := rows2.Err(); err != nil {
+ t.Fatal(err)
+ }
+
+ if g, e := b.String(), `Col 0: DatabaseTypeName "INTEGER", DecimalSize 0 0 false, Length 0 false, Name "uid", Nullable true true, ScanType "int64"
+Col 1: DatabaseTypeName "VARCHAR(64)", DecimalSize 0 0 false, Length 9223372036854775807 true, Name "username", Nullable true true, ScanType "string"
+Col 2: DatabaseTypeName "VARCHAR(64)", DecimalSize 0 0 false, Length 9223372036854775807 true, Name "departname", Nullable true true, ScanType "string"
+Col 3: DatabaseTypeName "DATE", DecimalSize 0 0 false, Length 9223372036854775807 true, Name "created", Nullable true true, ScanType "string"
+`; g != e {
+ t.Fatalf("---- got\n%s\n----expected\n%s", g, e)
+ }
+ t.Log(b.String())
+}
diff --git a/sqlite.go b/sqlite.go
index 0679018..0b2ea41 100644
--- a/sqlite.go
+++ b/sqlite.go
@@ -13,6 +13,9 @@ import (
"database/sql/driver"
"fmt"
"io"
+ "math"
+ "reflect"
+ "strings"
"time"
"unsafe"
@@ -27,13 +30,18 @@ var (
//lint:ignore SA1019 TODO implement ExecerContext
_ driver.Execer = (*conn)(nil)
//lint:ignore SA1019 TODO implement QueryerContext
- _ driver.Queryer = (*conn)(nil)
- _ driver.Result = (*result)(nil)
- _ driver.Rows = noRows{}
- _ driver.Rows = (*rows)(nil)
- _ driver.Stmt = (*stmt)(nil)
- _ driver.Tx = (*tx)(nil)
- _ error = (*Error)(nil)
+ _ driver.Queryer = (*conn)(nil)
+ _ driver.Result = (*result)(nil)
+ _ driver.Rows = (*rows)(nil)
+ _ driver.Rows = noRows{}
+ _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
+ _ driver.RowsColumnTypeLength = (*rows)(nil)
+ _ driver.RowsColumnTypeNullable = (*rows)(nil)
+ _ driver.RowsColumnTypePrecisionScale = (*rows)(nil)
+ _ driver.RowsColumnTypeScanType = (*rows)(nil)
+ _ driver.Stmt = (*stmt)(nil)
+ _ driver.Tx = (*tx)(nil)
+ _ error = (*Error)(nil)
)
const (
@@ -291,6 +299,100 @@ func (r *rows) Next(dest []driver.Value) (err error) {
}
}
+// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return
+// the database system type name without the length. Type names should be
+// uppercase. Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2",
+// "CHAR", "TEXT", "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT",
+// "JSONB", "XML", "TIMESTAMP".
+func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
+ return strings.ToUpper(r.c.columnDeclType(r.pstmt, index))
+}
+
+// RowsColumnTypeLength may be implemented by Rows. It should return the length
+// of the column type if the column is a variable length type. If the column is
+// not a variable length type ok should return false. If length is not limited
+// other than system limits, it should return math.MaxInt64. The following are
+// examples of returned values for various types:
+//
+// TEXT (math.MaxInt64, true)
+// varchar(10) (10, true)
+// nvarchar(10) (10, true)
+// decimal (0, false)
+// int (0, false)
+// bytea(30) (30, true)
+func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) {
+ t, err := r.c.columnType(r.pstmt, index)
+ if err != nil {
+ return 0, false
+ }
+
+ switch t {
+ case sqlite3.SQLITE_INTEGER:
+ return 0, false
+ case sqlite3.SQLITE_FLOAT:
+ return 0, false
+ case sqlite3.SQLITE_TEXT:
+ return math.MaxInt64, true
+ case sqlite3.SQLITE_BLOB:
+ return math.MaxInt64, true
+ case sqlite3.SQLITE_NULL:
+ return 0, false
+ default:
+ return 0, false
+ }
+}
+
+// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
+// be true if it is known the column may be null, or false if the column is
+// known to be not nullable. If the column nullability is unknown, ok should be
+// false.
+func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
+ return true, true
+}
+
+// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
+// the precision and scale for decimal types. If not applicable, ok should be
+// false. The following are examples of returned values for various types:
+//
+// decimal(38, 4) (38, 4, true)
+// int (0, 0, false)
+// decimal (math.MaxInt64, math.MaxInt64, true)
+func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
+ return 0, 0, false
+}
+
+// RowsColumnTypeScanType may be implemented by Rows. It should return the
+// value type that can be used to scan types into. For example, the database
+// column type "bigint" this should return "reflect.TypeOf(int64(0))".
+func (r *rows) ColumnTypeScanType(index int) reflect.Type {
+ t, err := r.c.columnType(r.pstmt, index)
+ if err != nil {
+ return reflect.TypeOf("")
+ }
+
+ switch t {
+ case sqlite3.SQLITE_INTEGER:
+ switch strings.ToLower(r.c.columnDeclType(r.pstmt, index)) {
+ case "boolean":
+ return reflect.TypeOf(false)
+ case "date", "datetime", "time", "timestamp":
+ return reflect.TypeOf(time.Time{})
+ default:
+ return reflect.TypeOf(int64(0))
+ }
+ case sqlite3.SQLITE_FLOAT:
+ return reflect.TypeOf(float64(0))
+ case sqlite3.SQLITE_TEXT:
+ return reflect.TypeOf("")
+ case sqlite3.SQLITE_BLOB:
+ return reflect.SliceOf(reflect.TypeOf([]byte{}))
+ case sqlite3.SQLITE_NULL:
+ return reflect.TypeOf(nil)
+ default:
+ return reflect.TypeOf("")
+ }
+}
+
type stmt struct {
c *conn
psql uintptr
@@ -664,6 +766,11 @@ func (c *conn) columnType(pstmt uintptr, iCol int) (_ int, err error) {
return int(v), nil
}
+// const char *sqlite3_column_decltype(sqlite3_stmt*,int);
+func (c *conn) columnDeclType(pstmt uintptr, iCol int) string {
+ return libc.GoString(sqlite3.Xsqlite3_column_decltype(c.tls, pstmt, int32(iCol)))
+}
+
// const char *sqlite3_column_name(sqlite3_stmt*, int N);
func (c *conn) columnName(pstmt uintptr, n int) (string, error) {
p := sqlite3.Xsqlite3_column_name(c.tls, pstmt, int32(n))