diff --git a/dataloader.go b/dataloader.go index 2b6f97c..8a7fda7 100644 --- a/dataloader.go +++ b/dataloader.go @@ -153,21 +153,33 @@ func (d *dataloader) Load(ogCtx context.Context, key Key) Thunk { func (d *dataloader) LoadMany(ogCtx context.Context, keyArr ...Key) ThunkMany { ctx, finish := d.tracer.LoadMany(ogCtx, keyArr) - if r, ok := d.cache.GetResultMap(ctx, keyArr...); ok { - d.logger.Logf("cache hit for: %d", keyArr) - d.strategy.LoadNoOp(ctx) - return func() ResultMap { - finish(r) + var cached, missed = ResultMap{}, []Key{} + for _, key := range keyArr { + if r, ok := d.cache.GetResult(ctx, key); ok { + d.logger.Logf("cache hit for: %d", key) + d.strategy.LoadNoOp(ctx) + cached[key.String()] = r + } else { + missed = append(missed, key) + } + } - return r + if len(missed) == 0 { + return func() ResultMap { + finish(cached) + return cached } } - thunkMany := d.strategy.LoadMany(ctx, keyArr...) + thunkMany := d.strategy.LoadMany(ctx, missed...) return func() ResultMap { + cached := cached result := thunkMany() d.cache.SetResultMap(ctx, result) + for k, v := range cached { + result[k] = v + } finish(result) return result diff --git a/dataloader_test.go b/dataloader_test.go index 6ced464..9f491fb 100644 --- a/dataloader_test.go +++ b/dataloader_test.go @@ -32,7 +32,7 @@ func getBatchFunction(cb func(), result dataloader.Result) dataloader.BatchFunct return func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { cb() m := dataloader.NewResultMap(1) - m.Set(keys.Keys()[0].(PrimaryKey).String(), result) + m.Set(keys.Keys()[0].(PrimaryKey), result) return &m } } @@ -51,8 +51,8 @@ func (c *mockCache) SetResult(ctx context.Context, key dataloader.Key, result da } func (c *mockCache) SetResultMap(ctx context.Context, resultMap dataloader.ResultMap) { - for _, k := range resultMap.Keys() { - c.r[k] = resultMap.GetValueForString(k) + for k, v := range resultMap { + c.r[k] = v } } @@ -62,21 +62,25 @@ func (c *mockCache) GetResult(ctx context.Context, key dataloader.Key) (dataload } func (c *mockCache) GetResultMap(ctx context.Context, keys ...dataloader.Key) (dataloader.ResultMap, bool) { + var nok bool result := dataloader.NewResultMap(len(keys)) - for _, k := range keys { - r, ok := c.r[k.String()] + for _, key := range keys { + var k = key.String() + r, ok := c.r[k] if !ok { - return dataloader.NewResultMap(len(keys)), false + nok = true + continue } - result.Set(k.String(), r) + result[k] = r } - return result, true + return result, !nok } func (c *mockCache) Delete(ctx context.Context, key dataloader.Key) bool { - _, ok := c.r[key.String()] + var k = key.String() + _, ok := c.r[k] if ok { - delete(c.r, key.String()) + delete(c.r, k) return true } return false diff --git a/key.go b/key.go index d59e0ed..6c4678f 100644 --- a/key.go +++ b/key.go @@ -12,6 +12,16 @@ type Key interface { Raw() interface{} } +type StringKey string + +func (k StringKey) String() string { + return string(k) +} + +func (k StringKey) Raw() interface{} { + return k +} + // Keys wraps an array of keys and contains accessor methods type Keys interface { Append(...Key) @@ -20,17 +30,18 @@ type Keys interface { ClearAll() // Keys returns a an array of unique results after calling Raw on each key Keys() []interface{} + StringKeys() []string IsEmpty() bool } type keys struct { - k []Key + keys []Key } // NewKeys returns a new instance of the Keys array with the provided capacity. func NewKeys(capacity int) Keys { return &keys{ - k: make([]Key, 0, capacity), + keys: make([]Key, 0, capacity), } } @@ -38,7 +49,7 @@ func NewKeys(capacity int) Keys { // the provided keys func NewKeysWith(key ...Key) Keys { return &keys{ - k: key, + keys: key, } } @@ -47,28 +58,28 @@ func NewKeysWith(key ...Key) Keys { func (k *keys) Append(keys ...Key) { for _, key := range keys { if key != nil && key.Raw() != nil { // don't track nil keys - k.k = append(k.k, key) + k.keys = append(k.keys, key) } } } func (k *keys) Capacity() int { - return cap(k.k) + return cap(k.keys) } func (k *keys) Length() int { - return len(k.k) + return len(k.keys) } func (k *keys) ClearAll() { - k.k = make([]Key, 0, len(k.k)) + k.keys = make([]Key, 0, len(k.keys)) } func (k *keys) Keys() []interface{} { result := make([]interface{}, 0, k.Length()) temp := make(map[Key]bool, k.Length()) - for _, val := range k.k { + for _, val := range k.keys { if _, ok := temp[val]; !ok { temp[val] = true result = append(result, val.Raw()) @@ -78,6 +89,20 @@ func (k *keys) Keys() []interface{} { return result } +func (k *keys) StringKeys() []string { + result := make([]string, 0, k.Length()) + temp := make(map[Key]bool, k.Length()) + + for _, val := range k.keys { + if _, ok := temp[val]; !ok { + temp[val] = true + result = append(result, val.String()) + } + } + + return result +} + func (k *keys) IsEmpty() bool { - return len(k.k) == 0 + return len(k.keys) == 0 } diff --git a/result.go b/result.go index 252bd85..702d0dc 100644 --- a/result.go +++ b/result.go @@ -7,58 +7,46 @@ type Result struct { } // ResultMap maps each loaded elements Result against the elements unique identifier (Key) -type ResultMap interface { - Set(string, Result) - GetValue(Key) (Result, bool) - Length() int - // Keys returns a slice of all unique identifiers used in the containing map (keys) - Keys() []string - GetValueForString(string) Result -} - -type resultMap struct { - r map[string]Result -} +type ResultMap map[string]Result // NewResultMap returns a new instance of the result map with the provided capacity. // Each value defaults to nil func NewResultMap(capacity int) ResultMap { r := make(map[string]Result, capacity) - - return &resultMap{r: r} + return r } // ===================================== public methods ===================================== // Set adds the value to the to the result set. -func (r *resultMap) Set(identifier string, value Result) { - r.r[identifier] = value +func (r ResultMap) Set(identifier Key, value Result) { + r[identifier.String()] = value } // GetValue returns the value from the results for the provided key and true // if the value was found, otherwise false. -func (r *resultMap) GetValue(key Key) (Result, bool) { +func (r ResultMap) GetValue(key Key) (Result, bool) { if key == nil { return Result{}, false } - result, ok := r.r[key.String()] + result, ok := r[key.String()] return result, ok } -func (r *resultMap) GetValueForString(key string) Result { +func (r ResultMap) GetValueForString(key string) Result { // No need to check ok, missing value from map[Any]interface{} is nil by default. - return r.r[key] + return r[key] } -func (r *resultMap) Keys() []string { - res := make([]string, 0, len(r.r)) - for k := range r.r { +func (r ResultMap) Keys() []string { + res := make([]string, 0, len(r)) + for k, _ := range r { res = append(res, k) } return res } -func (r *resultMap) Length() int { - return len(r.r) +func (r ResultMap) Length() int { + return len(r) } diff --git a/result_test.go b/result_test.go index ccd92c1..490b5af 100644 --- a/result_test.go +++ b/result_test.go @@ -14,7 +14,7 @@ func TestEnsureOKForResult(t *testing.T) { rmap := dataloader.NewResultMap(2) key := PrimaryKey(1) value := dataloader.Result{Result: 1, Err: nil} - rmap.Set(key.String(), value) + rmap.Set(key, value) // invoke/assert result, ok := rmap.GetValue(key) @@ -27,7 +27,7 @@ func TestEnsureNotOKForResult(t *testing.T) { key := PrimaryKey(1) key2 := PrimaryKey(2) value := dataloader.Result{Result: 1, Err: nil} - rmap.Set(key.String(), value) + rmap.Set(key, value) // invoke/assert result, ok := rmap.GetValue(key2) diff --git a/strategies/once/once_test.go b/strategies/once/once_test.go index 230a354..6f8dd5f 100644 --- a/strategies/once/once_test.go +++ b/strategies/once/once_test.go @@ -35,7 +35,7 @@ func getBatchFunction(cb func(), result dataloader.Result) dataloader.BatchFunct return func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { cb() m := dataloader.NewResultMap(1) - m.Set(keys.Keys()[0].(PrimaryKey).String(), result) + m.Set(keys.Keys()[0].(PrimaryKey), result) return &m } } @@ -302,7 +302,7 @@ func TestKeyHandling(t *testing.T) { for i := 0; i < keys.Length(); i++ { key := keys.Keys()[i].(PrimaryKey) if expectedResult[key] != "__skip__" { - m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil}) + m.Set(key, dataloader.Result{Result: expectedResult[key], Err: nil}) } } return &m diff --git a/strategies/sozu/sozu.go b/strategies/sozu/sozu.go index 2111d14..5075928 100644 --- a/strategies/sozu/sozu.go +++ b/strategies/sozu/sozu.go @@ -317,7 +317,7 @@ func buildResultMap(keyArr []dataloader.Key, r dataloader.ResultMap) dataloader. for _, k := range keyArr { if val, ok := r.GetValue(k); ok { - results.Set(k.String(), val) + results.Set(k, val) } } diff --git a/strategies/sozu/sozu_test.go b/strategies/sozu/sozu_test.go index 183973c..d5e23bb 100644 --- a/strategies/sozu/sozu_test.go +++ b/strategies/sozu/sozu_test.go @@ -36,8 +36,8 @@ func getBatchFunction(cb func(dataloader.Keys), result string) dataloader.BatchF return func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { cb(keys) m := dataloader.NewResultMap(1) - for _, k := range keys.Keys() { - key := k.(PrimaryKey).String() + for _, k := range keys.RawKeys() { + key := k.(PrimaryKey) m.Set( key, dataloader.Result{ @@ -123,7 +123,7 @@ func TestLoadTimeoutTriggered(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) wg.Done() } @@ -206,7 +206,7 @@ func TestLoadManyTimeoutTriggered(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) wg.Done() } @@ -300,7 +300,7 @@ func TestLoadTriggered(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) wg.Done() } @@ -381,7 +381,7 @@ func TestLoadManyTriggered(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) wg.Done() } @@ -458,7 +458,7 @@ func TestLoadBlocked(t *testing.T) { expectedResult := "batch_on_timeout_load_many" cb := func(keys dataloader.Keys) { callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) } @@ -521,7 +521,7 @@ func TestLoadManyBlocked(t *testing.T) { expectedResult := "batch_on_timeout_load_many" cb := func(keys dataloader.Keys) { callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) } @@ -666,9 +666,9 @@ func TestKeyHandling(t *testing.T) { batch := func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { m := dataloader.NewResultMap(2) for i := 0; i < keys.Length(); i++ { - key := keys.Keys()[i].(PrimaryKey) + key := keys.RawKeys()[i].(PrimaryKey) if expectedResult[key] != "__skip__" { - m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil}) + m.Set(key, dataloader.Result{Result: expectedResult[key], Err: nil}) } } return &m diff --git a/strategies/standard/standard.go b/strategies/standard/standard.go index 08ecf4e..9410f42 100644 --- a/strategies/standard/standard.go +++ b/strategies/standard/standard.go @@ -2,6 +2,7 @@ package standard import ( "context" + "fmt" "sync" "time" @@ -277,9 +278,11 @@ func buildResultMap(keyArr []dataloader.Key, r dataloader.ResultMap) dataloader. for _, k := range keyArr { if val, ok := r.GetValue(k); ok { - results.Set(k.String(), val) + results.Set(k, val) } } + fmt.Printf("r: %+v\nresults: %+v\n", r, results) + return results } diff --git a/strategies/standard/standard_test.go b/strategies/standard/standard_test.go index 10b6971..8eb33ab 100644 --- a/strategies/standard/standard_test.go +++ b/strategies/standard/standard_test.go @@ -37,7 +37,7 @@ func getBatchFunction(cb func(dataloader.Keys), result string) dataloader.BatchF cb(keys) m := dataloader.NewResultMap(1) for _, k := range keys.Keys() { - key := k.(PrimaryKey).String() + key := k.(PrimaryKey) m.Set( key, dataloader.Result{ @@ -123,7 +123,7 @@ func TestLoadNoTimeout(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) wg.Done() } @@ -204,7 +204,7 @@ func TestLoadManyNoTimeout(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() close(closeChan) wg.Done() } @@ -291,7 +291,7 @@ func TestLoadTimeout(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() if callCount == 2 { close(closeChan) } @@ -389,7 +389,7 @@ func TestLoadManyTimeout(t *testing.T) { cb := func(keys dataloader.Keys) { blockWG.Wait() callCount += 1 - k = keys.Keys() + k = keys.RawKeys() if callCount == 2 { close(closeChan) } @@ -556,7 +556,7 @@ func TestKeyHandling(t *testing.T) { for i := 0; i < keys.Length(); i++ { key := keys.Keys()[i].(PrimaryKey) if expectedResult[key] != "__skip__" { - m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil}) + m.Set(key, dataloader.Result{Result: expectedResult[key], Err: nil}) } } return &m