From 736c530ac70b31867445bbf0212e308467fb53b3 Mon Sep 17 00:00:00 2001 From: Jan Mercl <0xjnml@gmail.com> Date: Fri, 25 Sep 2020 15:12:39 +0200 Subject: implement sql.{RowsColumnTypeScanType,RowsColumnTypeDatabaseTypeName,RowsColumnTypeLength,RowsColumnTypeNullable,RowsColumnTypePrecisionScale}, fixes #30 --- all_test.go | 64 ++++++++++++++++++++++++++++++++ sqlite.go | 121 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 178 insertions(+), 7 deletions(-) diff --git a/all_test.go b/all_test.go index 08f565b..676fd8c 100644 --- a/all_test.go +++ b/all_test.go @@ -826,3 +826,67 @@ func TestIssue28(t *testing.T) { t.Fatalf("got %T(%[1]v), expected %T(%[2]v)", err, sql.ErrNoRows) } } + +// https://gitlab.com/cznic/sqlite/-/issues/30 +func TestIssue30(t *testing.T) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + + defer os.RemoveAll(tempDir) + + db, err := sql.Open("sqlite", filepath.Join(tempDir, "test.db")) + if err != nil { + t.Fatalf("test.db open fail: %v", err) + } + + defer db.Close() + + _, err = db.Query("CREATE TABLE IF NOT EXISTS `userinfo` (`uid` INTEGER PRIMARY KEY AUTOINCREMENT,`username` VARCHAR(64) NULL, `departname` VARCHAR(64) NULL, `created` DATE NULL);") + if err != nil { + t.Fatal(err) + } + + insertStatement := `INSERT INTO userinfo(username, departname, created) values("astaxie", "研发部门", "2012-12-09")` + _, err = db.Query(insertStatement) + if err != nil { + t.Fatal(err) + } + + rows2, err := db.Query("SELECT * FROM userinfo") + if err != nil { + t.Fatal(err) + } + + columnTypes, _ := rows2.ColumnTypes() + var b strings.Builder + for rows2.Next() { + for index, value := range columnTypes { + precision, scale, precisionOk := value.DecimalSize() + length, lengthOk := value.Length() + nullable, nullableOk := value.Nullable() + fmt.Fprintf(&b, "Col %d: DatabaseTypeName %q, DecimalSize %v %v %v, Length %v %v, Name %q, Nullable %v %v, ScanType %q\n", + index, + value.DatabaseTypeName(), + precision, scale, precisionOk, + length, lengthOk, + value.Name(), + nullable, nullableOk, + value.ScanType(), + ) + } + } + if err := rows2.Err(); err != nil { + t.Fatal(err) + } + + if g, e := b.String(), `Col 0: DatabaseTypeName "INTEGER", DecimalSize 0 0 false, Length 0 false, Name "uid", Nullable true true, ScanType "int64" +Col 1: DatabaseTypeName "VARCHAR(64)", DecimalSize 0 0 false, Length 9223372036854775807 true, Name "username", Nullable true true, ScanType "string" +Col 2: DatabaseTypeName "VARCHAR(64)", DecimalSize 0 0 false, Length 9223372036854775807 true, Name "departname", Nullable true true, ScanType "string" +Col 3: DatabaseTypeName "DATE", DecimalSize 0 0 false, Length 9223372036854775807 true, Name "created", Nullable true true, ScanType "string" +`; g != e { + t.Fatalf("---- got\n%s\n----expected\n%s", g, e) + } + t.Log(b.String()) +} diff --git a/sqlite.go b/sqlite.go index 0679018..0b2ea41 100644 --- a/sqlite.go +++ b/sqlite.go @@ -13,6 +13,9 @@ import ( "database/sql/driver" "fmt" "io" + "math" + "reflect" + "strings" "time" "unsafe" @@ -27,13 +30,18 @@ var ( //lint:ignore SA1019 TODO implement ExecerContext _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 TODO implement QueryerContext - _ driver.Queryer = (*conn)(nil) - _ driver.Result = (*result)(nil) - _ driver.Rows = noRows{} - _ driver.Rows = (*rows)(nil) - _ driver.Stmt = (*stmt)(nil) - _ driver.Tx = (*tx)(nil) - _ error = (*Error)(nil) + _ driver.Queryer = (*conn)(nil) + _ driver.Result = (*result)(nil) + _ driver.Rows = (*rows)(nil) + _ driver.Rows = noRows{} + _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil) + _ driver.RowsColumnTypeLength = (*rows)(nil) + _ driver.RowsColumnTypeNullable = (*rows)(nil) + _ driver.RowsColumnTypePrecisionScale = (*rows)(nil) + _ driver.RowsColumnTypeScanType = (*rows)(nil) + _ driver.Stmt = (*stmt)(nil) + _ driver.Tx = (*tx)(nil) + _ error = (*Error)(nil) ) const ( @@ -291,6 +299,100 @@ func (r *rows) Next(dest []driver.Value) (err error) { } } +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return +// the database system type name without the length. Type names should be +// uppercase. Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", +// "CHAR", "TEXT", "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", +// "JSONB", "XML", "TIMESTAMP". +func (r *rows) ColumnTypeDatabaseTypeName(index int) string { + return strings.ToUpper(r.c.columnDeclType(r.pstmt, index)) +} + +// RowsColumnTypeLength may be implemented by Rows. It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. If length is not limited +// other than system limits, it should return math.MaxInt64. The following are +// examples of returned values for various types: +// +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) { + t, err := r.c.columnType(r.pstmt, index) + if err != nil { + return 0, false + } + + switch t { + case sqlite3.SQLITE_INTEGER: + return 0, false + case sqlite3.SQLITE_FLOAT: + return 0, false + case sqlite3.SQLITE_TEXT: + return math.MaxInt64, true + case sqlite3.SQLITE_BLOB: + return math.MaxInt64, true + case sqlite3.SQLITE_NULL: + return 0, false + default: + return 0, false + } +} + +// RowsColumnTypeNullable may be implemented by Rows. The nullable value should +// be true if it is known the column may be null, or false if the column is +// known to be not nullable. If the column nullability is unknown, ok should be +// false. +func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { + return true, true +} + +// RowsColumnTypePrecisionScale may be implemented by Rows. It should return +// the precision and scale for decimal types. If not applicable, ok should be +// false. The following are examples of returned values for various types: +// +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) +func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return 0, 0, false +} + +// RowsColumnTypeScanType may be implemented by Rows. It should return the +// value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". +func (r *rows) ColumnTypeScanType(index int) reflect.Type { + t, err := r.c.columnType(r.pstmt, index) + if err != nil { + return reflect.TypeOf("") + } + + switch t { + case sqlite3.SQLITE_INTEGER: + switch strings.ToLower(r.c.columnDeclType(r.pstmt, index)) { + case "boolean": + return reflect.TypeOf(false) + case "date", "datetime", "time", "timestamp": + return reflect.TypeOf(time.Time{}) + default: + return reflect.TypeOf(int64(0)) + } + case sqlite3.SQLITE_FLOAT: + return reflect.TypeOf(float64(0)) + case sqlite3.SQLITE_TEXT: + return reflect.TypeOf("") + case sqlite3.SQLITE_BLOB: + return reflect.SliceOf(reflect.TypeOf([]byte{})) + case sqlite3.SQLITE_NULL: + return reflect.TypeOf(nil) + default: + return reflect.TypeOf("") + } +} + type stmt struct { c *conn psql uintptr @@ -664,6 +766,11 @@ func (c *conn) columnType(pstmt uintptr, iCol int) (_ int, err error) { return int(v), nil } +// const char *sqlite3_column_decltype(sqlite3_stmt*,int); +func (c *conn) columnDeclType(pstmt uintptr, iCol int) string { + return libc.GoString(sqlite3.Xsqlite3_column_decltype(c.tls, pstmt, int32(iCol))) +} + // const char *sqlite3_column_name(sqlite3_stmt*, int N); func (c *conn) columnName(pstmt uintptr, n int) (string, error) { p := sqlite3.Xsqlite3_column_name(c.tls, pstmt, int32(n)) -- cgit v1.2.3-70-g09d2