Skip to content
26 changes: 19 additions & 7 deletions dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,33 @@ func (d *dataloader) Load(ogCtx context.Context, key Key) Thunk {
func (d *dataloader) LoadMany(ogCtx context.Context, keyArr ...Key) ThunkMany {
ctx, finish := d.tracer.LoadMany(ogCtx, keyArr)

if r, ok := d.cache.GetResultMap(ctx, keyArr...); ok {
d.logger.Logf("cache hit for: %d", keyArr)
d.strategy.LoadNoOp(ctx)
return func() ResultMap {
finish(r)
var cached, missed = ResultMap{}, []Key{}
for _, key := range keyArr {
if r, ok := d.cache.GetResult(ctx, key); ok {
d.logger.Logf("cache hit for: %d", key)
d.strategy.LoadNoOp(ctx)
cached[key.String()] = r
} else {
missed = append(missed, key)
}
}

return r
if len(missed) == 0 {
return func() ResultMap {
finish(cached)
return cached
}
}

thunkMany := d.strategy.LoadMany(ctx, keyArr...)
thunkMany := d.strategy.LoadMany(ctx, missed...)
return func() ResultMap {
cached := cached
result := thunkMany()
d.cache.SetResultMap(ctx, result)

for k, v := range cached {
result[k] = v
}
finish(result)

return result
Expand Down
24 changes: 14 additions & 10 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func getBatchFunction(cb func(), result dataloader.Result) dataloader.BatchFunct
return func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap {
cb()
m := dataloader.NewResultMap(1)
m.Set(keys.Keys()[0].(PrimaryKey).String(), result)
m.Set(keys.Keys()[0].(PrimaryKey), result)
return &m
}
}
Expand All @@ -51,8 +51,8 @@ func (c *mockCache) SetResult(ctx context.Context, key dataloader.Key, result da
}

func (c *mockCache) SetResultMap(ctx context.Context, resultMap dataloader.ResultMap) {
for _, k := range resultMap.Keys() {
c.r[k] = resultMap.GetValueForString(k)
for k, v := range resultMap {
c.r[k] = v
}
}

Expand All @@ -62,21 +62,25 @@ func (c *mockCache) GetResult(ctx context.Context, key dataloader.Key) (dataload
}

func (c *mockCache) GetResultMap(ctx context.Context, keys ...dataloader.Key) (dataloader.ResultMap, bool) {
var nok bool
result := dataloader.NewResultMap(len(keys))
for _, k := range keys {
r, ok := c.r[k.String()]
for _, key := range keys {
var k = key.String()
r, ok := c.r[k]
if !ok {
return dataloader.NewResultMap(len(keys)), false
nok = true
continue
}
result.Set(k.String(), r)
result[k] = r
}
return result, true
return result, !nok
}

func (c *mockCache) Delete(ctx context.Context, key dataloader.Key) bool {
_, ok := c.r[key.String()]
var k = key.String()
_, ok := c.r[k]
if ok {
delete(c.r, key.String())
delete(c.r, k)
return true
}
return false
Expand Down
43 changes: 34 additions & 9 deletions key.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ type Key interface {
Raw() interface{}
}

type StringKey string

func (k StringKey) String() string {
return string(k)
}

func (k StringKey) Raw() interface{} {
return k
}

// Keys wraps an array of keys and contains accessor methods
type Keys interface {
Append(...Key)
Expand All @@ -20,25 +30,26 @@ type Keys interface {
ClearAll()
// Keys returns a an array of unique results after calling Raw on each key
Keys() []interface{}
StringKeys() []string
IsEmpty() bool
}

type keys struct {
k []Key
keys []Key
}

// NewKeys returns a new instance of the Keys array with the provided capacity.
func NewKeys(capacity int) Keys {
return &keys{
k: make([]Key, 0, capacity),
keys: make([]Key, 0, capacity),
}
}

// NewKeysWith is a helper method for returning a new keys array which includes the
// the provided keys
func NewKeysWith(key ...Key) Keys {
return &keys{
k: key,
keys: key,
}
}

Expand All @@ -47,28 +58,28 @@ func NewKeysWith(key ...Key) Keys {
func (k *keys) Append(keys ...Key) {
for _, key := range keys {
if key != nil && key.Raw() != nil { // don't track nil keys
k.k = append(k.k, key)
k.keys = append(k.keys, key)
}
}
}

func (k *keys) Capacity() int {
return cap(k.k)
return cap(k.keys)
}

func (k *keys) Length() int {
return len(k.k)
return len(k.keys)
}

func (k *keys) ClearAll() {
k.k = make([]Key, 0, len(k.k))
k.keys = make([]Key, 0, len(k.keys))
}

func (k *keys) Keys() []interface{} {
result := make([]interface{}, 0, k.Length())
temp := make(map[Key]bool, k.Length())

for _, val := range k.k {
for _, val := range k.keys {
if _, ok := temp[val]; !ok {
temp[val] = true
result = append(result, val.Raw())
Expand All @@ -78,6 +89,20 @@ func (k *keys) Keys() []interface{} {
return result
}

func (k *keys) StringKeys() []string {
result := make([]string, 0, k.Length())
temp := make(map[Key]bool, k.Length())

for _, val := range k.keys {
if _, ok := temp[val]; !ok {
temp[val] = true
result = append(result, val.String())
}
}

return result
}

func (k *keys) IsEmpty() bool {
return len(k.k) == 0
return len(k.keys) == 0
}
38 changes: 13 additions & 25 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,46 @@ type Result struct {
}

// ResultMap maps each loaded elements Result against the elements unique identifier (Key)
type ResultMap interface {
Set(string, Result)
GetValue(Key) (Result, bool)
Length() int
// Keys returns a slice of all unique identifiers used in the containing map (keys)
Keys() []string
GetValueForString(string) Result
}

type resultMap struct {
r map[string]Result
}
type ResultMap map[string]Result

// NewResultMap returns a new instance of the result map with the provided capacity.
// Each value defaults to nil
func NewResultMap(capacity int) ResultMap {
r := make(map[string]Result, capacity)

return &resultMap{r: r}
return r
}

// ===================================== public methods =====================================

// Set adds the value to the to the result set.
func (r *resultMap) Set(identifier string, value Result) {
r.r[identifier] = value
func (r ResultMap) Set(identifier Key, value Result) {
r[identifier.String()] = value
}

// GetValue returns the value from the results for the provided key and true
// if the value was found, otherwise false.
func (r *resultMap) GetValue(key Key) (Result, bool) {
func (r ResultMap) GetValue(key Key) (Result, bool) {
if key == nil {
return Result{}, false
}

result, ok := r.r[key.String()]
result, ok := r[key.String()]
return result, ok
}

func (r *resultMap) GetValueForString(key string) Result {
func (r ResultMap) GetValueForString(key string) Result {
// No need to check ok, missing value from map[Any]interface{} is nil by default.
return r.r[key]
return r[key]
}

func (r *resultMap) Keys() []string {
res := make([]string, 0, len(r.r))
for k := range r.r {
func (r ResultMap) Keys() []string {
res := make([]string, 0, len(r))
for k, _ := range r {
res = append(res, k)
}
return res
}

func (r *resultMap) Length() int {
return len(r.r)
func (r ResultMap) Length() int {
return len(r)
}
4 changes: 2 additions & 2 deletions result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestEnsureOKForResult(t *testing.T) {
rmap := dataloader.NewResultMap(2)
key := PrimaryKey(1)
value := dataloader.Result{Result: 1, Err: nil}
rmap.Set(key.String(), value)
rmap.Set(key, value)

// invoke/assert
result, ok := rmap.GetValue(key)
Expand All @@ -27,7 +27,7 @@ func TestEnsureNotOKForResult(t *testing.T) {
key := PrimaryKey(1)
key2 := PrimaryKey(2)
value := dataloader.Result{Result: 1, Err: nil}
rmap.Set(key.String(), value)
rmap.Set(key, value)

// invoke/assert
result, ok := rmap.GetValue(key2)
Expand Down
4 changes: 2 additions & 2 deletions strategies/once/once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func getBatchFunction(cb func(), result dataloader.Result) dataloader.BatchFunct
return func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap {
cb()
m := dataloader.NewResultMap(1)
m.Set(keys.Keys()[0].(PrimaryKey).String(), result)
m.Set(keys.Keys()[0].(PrimaryKey), result)
return &m
}
}
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestKeyHandling(t *testing.T) {
for i := 0; i < keys.Length(); i++ {
key := keys.Keys()[i].(PrimaryKey)
if expectedResult[key] != "__skip__" {
m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil})
m.Set(key, dataloader.Result{Result: expectedResult[key], Err: nil})
}
}
return &m
Expand Down
2 changes: 1 addition & 1 deletion strategies/sozu/sozu.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func buildResultMap(keyArr []dataloader.Key, r dataloader.ResultMap) dataloader.

for _, k := range keyArr {
if val, ok := r.GetValue(k); ok {
results.Set(k.String(), val)
results.Set(k, val)
}
}

Expand Down
20 changes: 10 additions & 10 deletions strategies/sozu/sozu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func getBatchFunction(cb func(dataloader.Keys), result string) dataloader.BatchF
return func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap {
cb(keys)
m := dataloader.NewResultMap(1)
for _, k := range keys.Keys() {
key := k.(PrimaryKey).String()
for _, k := range keys.RawKeys() {
key := k.(PrimaryKey)
m.Set(
key,
dataloader.Result{
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestLoadTimeoutTriggered(t *testing.T) {
cb := func(keys dataloader.Keys) {
blockWG.Wait()
callCount += 1
k = keys.Keys()
k = keys.RawKeys()
close(closeChan)
wg.Done()
}
Expand Down Expand Up @@ -206,7 +206,7 @@ func TestLoadManyTimeoutTriggered(t *testing.T) {
cb := func(keys dataloader.Keys) {
blockWG.Wait()
callCount += 1
k = keys.Keys()
k = keys.RawKeys()
close(closeChan)
wg.Done()
}
Expand Down Expand Up @@ -300,7 +300,7 @@ func TestLoadTriggered(t *testing.T) {
cb := func(keys dataloader.Keys) {
blockWG.Wait()
callCount += 1
k = keys.Keys()
k = keys.RawKeys()
close(closeChan)
wg.Done()
}
Expand Down Expand Up @@ -381,7 +381,7 @@ func TestLoadManyTriggered(t *testing.T) {
cb := func(keys dataloader.Keys) {
blockWG.Wait()
callCount += 1
k = keys.Keys()
k = keys.RawKeys()
close(closeChan)
wg.Done()
}
Expand Down Expand Up @@ -458,7 +458,7 @@ func TestLoadBlocked(t *testing.T) {
expectedResult := "batch_on_timeout_load_many"
cb := func(keys dataloader.Keys) {
callCount += 1
k = keys.Keys()
k = keys.RawKeys()
close(closeChan)
}

Expand Down Expand Up @@ -521,7 +521,7 @@ func TestLoadManyBlocked(t *testing.T) {
expectedResult := "batch_on_timeout_load_many"
cb := func(keys dataloader.Keys) {
callCount += 1
k = keys.Keys()
k = keys.RawKeys()
close(closeChan)
}

Expand Down Expand Up @@ -666,9 +666,9 @@ func TestKeyHandling(t *testing.T) {
batch := func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap {
m := dataloader.NewResultMap(2)
for i := 0; i < keys.Length(); i++ {
key := keys.Keys()[i].(PrimaryKey)
key := keys.RawKeys()[i].(PrimaryKey)
if expectedResult[key] != "__skip__" {
m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil})
m.Set(key, dataloader.Result{Result: expectedResult[key], Err: nil})
}
}
return &m
Expand Down
Loading