Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rdsdataservice"
"github.com/gofrs/uuid"
)

type Conn struct {
Expand Down Expand Up @@ -79,3 +80,17 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return executeStatement(ctx, c.config, query, c.transactionID, args...)
}

// CheckNamedValue allows some types to get passed down to the
// executor so the TypeHint can be set
func (c *Conn) CheckNamedValue(v *driver.NamedValue) error {
switch v.Value.(type) {
case uuid.UUID:
return nil
}

if _, ok := v.Value.(Hinter); ok {
return nil
}
return driver.ErrSkip
}
9 changes: 7 additions & 2 deletions executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,14 @@ func executeStatement(ctx context.Context, config *config, query, transactionID
name = prefix + strconv.Itoa(arg.Ordinal)
}

val, hint, err := asField(arg.Value)
if err != nil {
return nil, err
}
param := rdsdataservice.SqlParameter{
Name: aws.String(name),
Value: asField(arg.Value),
Name: aws.String(name),
Value: val,
TypeHint: hint,
}

input.Parameters = append(input.Parameters, &param)
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ module github.com/savaki/dapi

go 1.14

require github.com/aws/aws-sdk-go v1.33.13
require (
github.com/aws/aws-sdk-go v1.33.13
github.com/gofrs/uuid v4.0.0+incompatible // indirect
)
53 changes: 43 additions & 10 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ package dapi
import (
"context"
"database/sql/driver"
"errors"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rdsdataservice"
"time"
"github.com/gofrs/uuid"
)

var ErrInvalidField = errors.New("invalid field")

type Stmt struct {
ctx context.Context
config *config
Expand Down Expand Up @@ -58,7 +63,7 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
}

func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
panic("implement me: QueryContext (Stmt)")
return executeStatement(ctx, s.config, s.query, "", args...)
}

func newStmt(ctx context.Context, config *config, query string) *Stmt {
Expand All @@ -69,23 +74,51 @@ func newStmt(ctx context.Context, config *config, query string) *Stmt {
}
}

func asField(value driver.Value) *rdsdataservice.Field {
// If a type implements Hinter it can provide a type hint to the data-api directly
type Hinter interface {
TypeHint() string
driver.Valuer
}

func asField(value driver.Value) (*rdsdataservice.Field, *string, error) {
var hint *string
if v, ok := value.(Hinter); ok {
hint = aws.String(v.TypeHint())
} else {
switch value.(type) {
case time.Time:
hint = aws.String("TIMESTAMP")
case uuid.UUID:
hint = aws.String("UUID")
}
}
if v, ok := value.(driver.Valuer); ok {
var err error
value, err = v.Value()
if err != nil {
return nil, hint, err
}
}

switch v := value.(type) {
case int64:
return &rdsdataservice.Field{LongValue: aws.Int64(v)}
return &rdsdataservice.Field{LongValue: aws.Int64(v)}, hint, nil
case float64:
return &rdsdataservice.Field{DoubleValue: aws.Float64(v)}
return &rdsdataservice.Field{DoubleValue: aws.Float64(v)}, hint, nil
case bool:
return &rdsdataservice.Field{BooleanValue: aws.Bool(v)}
return &rdsdataservice.Field{BooleanValue: aws.Bool(v)}, hint, nil
case []byte:
return &rdsdataservice.Field{BlobValue: v}
return &rdsdataservice.Field{BlobValue: v}, hint, nil
case string:
return &rdsdataservice.Field{StringValue: aws.String(v)}
return &rdsdataservice.Field{StringValue: aws.String(v)}, hint, nil
case time.Time:
s := v.Format("2006-01-02 15:04:05")
return &rdsdataservice.Field{StringValue: aws.String(s)}
return &rdsdataservice.Field{StringValue: aws.String(s)}, hint, nil
default:
return &rdsdataservice.Field{IsNull: aws.Bool(true)}
if v == nil {
return &rdsdataservice.Field{IsNull: aws.Bool(true)}, hint, nil
}
return nil, hint, ErrInvalidField
}
}

Expand Down