From d4ceb264dd95ebe74b50f1db792caeb670e5c4e3 Mon Sep 17 00:00:00 2001 From: Louis des Landes Date: Fri, 26 Feb 2021 10:49:44 +1030 Subject: [PATCH] Add Hinter interface to allow types to set TypeHint --- conn.go | 15 +++++++++++++++ executor.go | 9 +++++++-- go.mod | 5 ++++- stmt.go | 53 +++++++++++++++++++++++++++++++++++++++++++---------- 4 files changed, 69 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 2db07fd..28dfd6f 100644 --- a/conn.go +++ b/conn.go @@ -21,6 +21,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rdsdataservice" + "github.com/gofrs/uuid" ) type Conn struct { @@ -79,3 +80,17 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { return executeStatement(ctx, c.config, query, c.transactionID, args...) } + +// CheckNamedValue allows some types to get passed down to the +// executor so the TypeHint can be set +func (c *Conn) CheckNamedValue(v *driver.NamedValue) error { + switch v.Value.(type) { + case uuid.UUID: + return nil + } + + if _, ok := v.Value.(Hinter); ok { + return nil + } + return driver.ErrSkip +} diff --git a/executor.go b/executor.go index 3313afc..150dc20 100644 --- a/executor.go +++ b/executor.go @@ -122,9 +122,14 @@ func executeStatement(ctx context.Context, config *config, query, transactionID name = prefix + strconv.Itoa(arg.Ordinal) } + val, hint, err := asField(arg.Value) + if err != nil { + return nil, err + } param := rdsdataservice.SqlParameter{ - Name: aws.String(name), - Value: asField(arg.Value), + Name: aws.String(name), + Value: val, + TypeHint: hint, } input.Parameters = append(input.Parameters, ¶m) diff --git a/go.mod b/go.mod index 1200628..aa687ac 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/savaki/dapi go 1.14 -require github.com/aws/aws-sdk-go v1.33.13 +require ( + github.com/aws/aws-sdk-go v1.33.13 + github.com/gofrs/uuid v4.0.0+incompatible // indirect +) diff --git a/stmt.go b/stmt.go index 2485b60..d6ec667 100644 --- a/stmt.go +++ b/stmt.go @@ -17,11 +17,16 @@ package dapi import ( "context" "database/sql/driver" + "errors" + "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rdsdataservice" - "time" + "github.com/gofrs/uuid" ) +var ErrInvalidField = errors.New("invalid field") + type Stmt struct { ctx context.Context config *config @@ -58,7 +63,7 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { } func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - panic("implement me: QueryContext (Stmt)") + return executeStatement(ctx, s.config, s.query, "", args...) } func newStmt(ctx context.Context, config *config, query string) *Stmt { @@ -69,23 +74,51 @@ func newStmt(ctx context.Context, config *config, query string) *Stmt { } } -func asField(value driver.Value) *rdsdataservice.Field { +// If a type implements Hinter it can provide a type hint to the data-api directly +type Hinter interface { + TypeHint() string + driver.Valuer +} + +func asField(value driver.Value) (*rdsdataservice.Field, *string, error) { + var hint *string + if v, ok := value.(Hinter); ok { + hint = aws.String(v.TypeHint()) + } else { + switch value.(type) { + case time.Time: + hint = aws.String("TIMESTAMP") + case uuid.UUID: + hint = aws.String("UUID") + } + } + if v, ok := value.(driver.Valuer); ok { + var err error + value, err = v.Value() + if err != nil { + return nil, hint, err + } + } + switch v := value.(type) { case int64: - return &rdsdataservice.Field{LongValue: aws.Int64(v)} + return &rdsdataservice.Field{LongValue: aws.Int64(v)}, hint, nil case float64: - return &rdsdataservice.Field{DoubleValue: aws.Float64(v)} + return &rdsdataservice.Field{DoubleValue: aws.Float64(v)}, hint, nil case bool: - return &rdsdataservice.Field{BooleanValue: aws.Bool(v)} + return &rdsdataservice.Field{BooleanValue: aws.Bool(v)}, hint, nil case []byte: - return &rdsdataservice.Field{BlobValue: v} + return &rdsdataservice.Field{BlobValue: v}, hint, nil case string: - return &rdsdataservice.Field{StringValue: aws.String(v)} + return &rdsdataservice.Field{StringValue: aws.String(v)}, hint, nil case time.Time: s := v.Format("2006-01-02 15:04:05") - return &rdsdataservice.Field{StringValue: aws.String(s)} + return &rdsdataservice.Field{StringValue: aws.String(s)}, hint, nil default: - return &rdsdataservice.Field{IsNull: aws.Bool(true)} + if v == nil { + return &rdsdataservice.Field{IsNull: aws.Bool(true)}, hint, nil + } + return nil, hint, ErrInvalidField } }