From 7e84c86aa2f50e66c5be10f10895d70d5af510aa Mon Sep 17 00:00:00 2001 From: AJ Roetker Date: Wed, 5 Mar 2025 13:16:43 -0800 Subject: [PATCH 1/3] Random cleanups --- expr/builtins/aggregations.go | 15 ++++---- expr/builtins/filter.go | 20 +++++------ expr/builtins/json.go | 38 +++++++++++--------- expr/builtins/list_map.go | 67 ++++++++++++++++------------------- 4 files changed, 67 insertions(+), 73 deletions(-) diff --git a/expr/builtins/aggregations.go b/expr/builtins/aggregations.go index 828f07d6..07110bcd 100644 --- a/expr/builtins/aggregations.go +++ b/expr/builtins/aggregations.go @@ -11,9 +11,8 @@ import ( // Avg average of values. Note, this function DOES NOT persist state doesn't aggregate // across multiple calls. That would be responsibility of write context. // -// avg(1,2,3) => 2.0, true -// avg("hello") => math.NaN, false -// +// avg(1,2,3) => 2.0, true +// avg("hello") => math.NaN, false type Avg struct{} // Type is NumberType @@ -68,9 +67,8 @@ func avgEval(ctx expr.EvalContext, vals []value.Value) (value.Value, bool) { // Sum function to add values. Note, this function DOES NOT persist state doesn't aggregate // across multiple calls. That would be responsibility of write context. // -// sum(1, 2, 3) => 6 -// sum(1, "horse", 3) => nan, false -// +// sum(1, 2, 3) => 6 +// sum(1, "horse", 3) => nan, false type Sum struct{} // Type is number @@ -133,9 +131,8 @@ func sumEval(ctx expr.EvalContext, vals []value.Value) (value.Value, bool) { // and in general is a horrible, horrible function that needs to be replaced // with occurrences of value, ignores the value and ensures it is non null // -// count(anyvalue) => 1, true -// count(not_number) => -- 0, false -// +// count(anyvalue) => 1, true +// count(not_number) => -- 0, false type Count struct{} // Type is Integer diff --git a/expr/builtins/filter.go b/expr/builtins/filter.go index 32a317e5..8761fef9 100644 --- a/expr/builtins/filter.go +++ b/expr/builtins/filter.go @@ -12,8 +12,7 @@ import ( // OneOf choose the first non-nil, non-zero, non-false fields // -// oneof(nil, 0, "hello") => 'hello' -// +// oneof(nil, 0, "hello") => 'hello' type OneOf struct{} // Type unknown @@ -65,16 +64,15 @@ func FiltersFromArgs(filterVals []value.Value) []string { // // Filter a map of values by key to remove certain keys // -// filter(match("topic_"),key_to_filter, key2_to_filter) => {"goodkey": 22}, true +// filter(match("topic_"),key_to_filter, key2_to_filter) => {"goodkey": 22}, true // // Filter out VALUES (not keys) from a list of []string{} for a specific value // -// filter(split("apples,oranges",","),"ora*") => ["apples"], true +// filter(split("apples,oranges",","),"ora*") => ["apples"], true // // Filter out values for single strings // -// filter("apples","app*") => []string{}, true -// +// filter("apples","app*") => []string{}, true type Filter struct{} // Type unknown @@ -187,16 +185,15 @@ func FilterEval(ctx expr.EvalContext, vals []value.Value) (value.Value, bool) { // // Filter a map of values by key to only keep certain keys // -// filtermatch(match("topic_"),key_to_filter, key2_to_filter) => {"goodkey": 22}, true +// filtermatch(match("topic_"),key_to_filter, key2_to_filter) => {"goodkey": 22}, true // // Filter in VALUES (not keys) from a list of []string{} for a specific value // -// filtermatch(split("apples,oranges",","),"ora*") => ["oranges"], true +// filtermatch(split("apples,oranges",","),"ora*") => ["oranges"], true // // Filter in values for single strings // -// filtermatch("apples","app*") => []string{"apple"}, true -// +// filtermatch("apples","app*") => []string{"apple"}, true type FilterMatch struct{} // Type Unknown @@ -272,6 +269,9 @@ func FilterMatchEval(ctx expr.EvalContext, vals []value.Value) (value.Value, boo lv := make([]string, 0, val.Len()) for _, slv := range val.SliceValue() { + switch slv.Type() { + case value.StringType: + } sv := slv.ToString() filteredIn := false for _, filter := range filters { diff --git a/expr/builtins/json.go b/expr/builtins/json.go index 76766ea9..1808ed14 100644 --- a/expr/builtins/json.go +++ b/expr/builtins/json.go @@ -15,10 +15,9 @@ var _ = u.EMPTY // JsonPath jmespath json parser http://jmespath.org/ // -// json_field = `[{"name":"n1","ct":8,"b":true, "tags":["a","b"]},{"name":"n2","ct":10,"b": false, "tags":["a","b"]}]` -// -// json.jmespath(json_field, "[?name == 'n1'].name | [0]") => "n1" +// json_field = `[{"name":"n1","ct":8,"b":true, "tags":["a","b"]},{"name":"n2","ct":10,"b": false, "tags":["a","b"]}]` // +// json.jmespath(json_field, "[?name == 'n1'].name | [0]") => "n1" type JsonPath struct{} func (m *JsonPath) Type() value.ValueType { return value.UnknownType } @@ -34,33 +33,38 @@ func (m *JsonPath) Validate(n *expr.FuncNode) (expr.EvaluatorFunc, error) { default: return nil, fmt.Errorf("expected a string expression for jmespath got %T", jn) } - - parser := jmespath.NewParser() - _, err := parser.Parse(jsonPathExpr) + jmes, err := jmespath.Compile(jsonPathExpr) if err != nil { - // if syntaxError, ok := err.(jmespath.SyntaxError); ok { - // u.Warnf("%s\n%s\n", syntaxError, syntaxError.HighlightLocation()) - // } return nil, err } - return jsonPathEval(jsonPathExpr), nil + return jsonPathEval(jmes), nil } -func jsonPathEval(expression string) expr.EvaluatorFunc { +func jsonPathEval(jmes *jmespath.JMESPath) expr.EvaluatorFunc { return func(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { if args[0] == nil || args[0].Err() || args[0].Nil() { return nil, false } - - val := args[0].ToString() - - // Validate that this is valid json? + a := args[0] + var val []byte + var err error + switch { + case a.Type().IsMap() || a.Type().IsSlice(): + // TODO (ajr): need to recursively change value.Value to interface{} and extract the values + // this is a bit of a hack to do that + val, err = json.Marshal(a.Value()) + if err != nil { + return nil, false + } + default: + val = []byte(args[0].ToString()) + } var data interface{} - if err := json.Unmarshal([]byte(val), &data); err != nil { + if err := json.Unmarshal(val, &data); err != nil { return nil, false } - result, err := jmespath.Search(expression, data) + result, err := jmes.Search(data) if err != nil { return nil, false } diff --git a/expr/builtins/list_map.go b/expr/builtins/list_map.go index 2b1a0940..7f975921 100644 --- a/expr/builtins/list_map.go +++ b/expr/builtins/list_map.go @@ -15,9 +15,8 @@ var _ = u.EMPTY // len length of array types // -// len([1,2,3]) => 3, true -// len(not_a_field) => -- NilInt, false -// +// len([1,2,3]) => 3, true +// len(not_a_field) => -- NilInt, false type Length struct{} // Type is IntType @@ -63,13 +62,12 @@ func lenEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { // ArrayIndex array.index choose the nth element of an array // -// // given context input of -// "items" = [1,2,3] -// -// array.index(items, 1) => 1, true -// array.index(items, 5) => nil, false -// array.index(items, -1) => 3, true +// // given context input of +// "items" = [1,2,3] // +// array.index(items, 1) => 1, true +// array.index(items, 5) => nil, false +// array.index(items, -1) => 3, true type ArrayIndex struct{} // Type unknown - returns single value from SliceValue array @@ -108,12 +106,12 @@ func arrayIndexEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool // array.slice slice element m -> n of a slice. First arg must be a slice. // -// // given context of -// "items" = [1,2,3,4,5] +// // given context of +// "items" = [1,2,3,4,5] // -// array.slice(items, 1, 3) => [2,3], true -// array.slice(items, 2) => [3,4,5], true -// array.slice(items, -2) => [4,5], true +// array.slice(items, 1, 3) => [2,3], true +// array.slice(items, 2) => [3,4,5], true +// array.slice(items, -2) => [4,5], true type ArraySlice struct{} // Type Unknown for Array Slice @@ -208,8 +206,7 @@ func arraySliceEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool // Map Create a map from two values. If the right side value is nil // then does not evaluate. // -// map(left, right) => map[string]value{left:right} -// +// map(left, right) => map[string]value{left:right} type MapFunc struct{} // Type is MapValueType @@ -236,9 +233,8 @@ func mapEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { // MapTime() Create a map[string]time of each key // -// maptime(field) => map[string]time{field_value:message_timestamp} -// maptime(field, timestamp) => map[string]time{field_value:timestamp} -// +// maptime(field) => map[string]time{field_value:message_timestamp} +// maptime(field, timestamp) => map[string]time{field_value:timestamp} type MapTime struct{} // Type MapTime @@ -282,13 +278,13 @@ func mapTimeEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { // - May pass as many match strings as you want. // - Must match on Prefix of key. // -// given input context of: -// {"score_value":24,"event_click":true, "tag_apple": "apple", "label_orange": "orange"} +// given input context of: +// {"score_value":24,"event_click":true, "tag_apple": "apple", "label_orange": "orange"} // -// match("score_") => {"value":24} -// match("amount_") => false -// match("event_") => {"click":true} -// match("label_","tag_") => {"apple":"apple","orange":"orange"} +// match("score_") => {"value":24} +// match("amount_") => false +// match("event_") => {"click":true} +// match("label_","tag_") => {"apple":"apple","orange":"orange"} type Match struct{} // Type is MapValueType @@ -328,11 +324,10 @@ func matchEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { // MapKeys: Take a map and extract array of keys // -// //given input: -// {"tag.1":"news","tag.2":"sports"} -// -// mapkeys(match("tag.")) => []string{"news","sports"} +// //given input: +// {"tag.1":"news","tag.2":"sports"} // +// mapkeys(match("tag.")) => []string{"news","sports"} type MapKeys struct{} // Type []string aka strings @@ -368,11 +363,10 @@ func mapKeysEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { // MapValues: Take a map and extract array of values // -// // given input: -// {"tag.1":"news","tag.2":"sports"} -// -// mapvalue(match("tag.")) => []string{"1","2"} +// // given input: +// {"tag.1":"news","tag.2":"sports"} // +// mapvalue(match("tag.")) => []string{"1","2"} type MapValues struct{} // Type strings aka []string @@ -409,11 +403,10 @@ func mapValuesEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) // MapInvert: Take a map and invert key/values // -// // given input: -// tags = {"1":"news","2":"sports"} -// -// mapinvert(tags) => map[string]string{"news":"1","sports":"2"} +// // given input: +// tags = {"1":"news","2":"sports"} // +// mapinvert(tags) => map[string]string{"news":"1","sports":"2"} type MapInvert struct{} // Type MapValue From 3016bfd37e6df33b556b9c6c0a7db38649f028f4 Mon Sep 17 00:00:00 2001 From: AJ Roetker Date: Wed, 5 Mar 2025 15:26:50 -0800 Subject: [PATCH 2/3] Add filterql vm optimized --- filterqlvm/compiler/compiler.go | 965 ++++++++++++++++++++++++++++++++ filterqlvm/filterqlvm.go | 140 +++++ filterqlvm/filterqlvm_test.go | 251 +++++++++ 3 files changed, 1356 insertions(+) create mode 100644 filterqlvm/compiler/compiler.go create mode 100644 filterqlvm/filterqlvm.go create mode 100644 filterqlvm/filterqlvm_test.go diff --git a/filterqlvm/compiler/compiler.go b/filterqlvm/compiler/compiler.go new file mode 100644 index 00000000..e138d0ed --- /dev/null +++ b/filterqlvm/compiler/compiler.go @@ -0,0 +1,965 @@ +package compiler + +import ( + "fmt" + "hash/fnv" + "strings" + "sync" + + "github.com/mb0/glob" + + "github.com/lytics/qlbridge/expr" + "github.com/lytics/qlbridge/lex" + "github.com/lytics/qlbridge/value" + "github.com/lytics/qlbridge/vm" +) + +// CompiledExpr represents a compiled expression that can be evaluated directly +type CompiledExpr struct { + // Original AST node that was compiled + Node expr.Node + // Direct evaluation function + EvalFunc func(ctx expr.EvalContext) (value.Value, bool) +} + +// DirectCompiler generates optimized Go functions directly in memory +type DirectCompiler struct { + cache map[uint64]*CompiledExpr + cacheLock sync.RWMutex +} + +// NewDirectCompiler creates a new direct compiler +func NewDirectCompiler() *DirectCompiler { + return &DirectCompiler{ + cache: make(map[uint64]*CompiledExpr), + } +} + +// Compile compiles an expression node into a direct evaluation function +func (c *DirectCompiler) Compile(node expr.Node) (*CompiledExpr, error) { + // Generate a hash for the node to use as cache key + hash := hashNode(node) + + // Check cache first + c.cacheLock.RLock() + if compiled, ok := c.cache[hash]; ok { + c.cacheLock.RUnlock() + return compiled, nil + } + c.cacheLock.RUnlock() + + // Create the compiled expression + compiled, err := c.compileToFunc(node) + if err != nil { + return nil, err + } + + // Cache the result + c.cacheLock.Lock() + c.cache[hash] = compiled + c.cacheLock.Unlock() + + return compiled, nil +} + +// compileToFunc creates a direct evaluation function for a node +func (c *DirectCompiler) compileToFunc(node expr.Node) (*CompiledExpr, error) { + switch n := node.(type) { + case *expr.BinaryNode: + return c.compileBinary(n) + case *expr.BooleanNode: + return c.compileBoolean(n) + case *expr.UnaryNode: + return c.compileUnary(n) + case *expr.IdentityNode: + return c.compileIdentity(n) + case *expr.NumberNode: + return c.compileNumber(n) + case *expr.StringNode: + return c.compileString(n) + case *expr.FuncNode: + return c.compileFunc(n) + case *expr.TriNode: + return c.compileTernary(n) + case *expr.ArrayNode: + return c.compileArray(n) + case *expr.IncludeNode: + return c.compileInclude(n) + case *expr.NullNode: + return &CompiledExpr{ + Node: n, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return value.NewNilValue(), true + }, + }, nil + case *expr.ValueNode: + return c.compileValue(n) + default: + return nil, fmt.Errorf("unsupported node type: %T", node) + } +} + +// compileBinary creates a direct function for binary operations +func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, error) { + // Compile left and right operands + leftExpr, err := c.compileToFunc(node.Args[0]) + if err != nil { + return nil, err + } + + rightExpr, err := c.compileToFunc(node.Args[1]) + if err != nil { + return nil, err + } + + // Create a direct function based on the operator + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + // Get left and right values + left, leftOk := leftExpr.EvalFunc(ctx) + if !leftOk { + return nil, false + } + + // Short-circuit for logical operators + switch node.Operator.T { + case lex.TokenLogicAnd, lex.TokenAnd: + // If left is false, no need to evaluate right + if leftBool, ok := left.(value.BoolValue); ok && !leftBool.Val() { + return value.NewBoolValue(false), true + } + case lex.TokenLogicOr, lex.TokenOr: + // If left is true, no need to evaluate right + if leftBool, ok := left.(value.BoolValue); ok && leftBool.Val() { + return value.NewBoolValue(true), true + } + } + + right, rightOk := rightExpr.EvalFunc(ctx) + if !rightOk { + return nil, false + } + + // Handle different operators + switch node.Operator.T { + case lex.TokenEqual, lex.TokenEqualEqual: + eq, err := value.Equal(left, right) + if err != nil { + return value.NewBoolValue(false), true + } + return value.NewBoolValue(eq), true + + case lex.TokenNE: + eq, err := value.Equal(left, right) + if err != nil { + return value.NewBoolValue(true), true + } + return value.NewBoolValue(!eq), true + + case lex.TokenGT: + // Handle different value types + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(lv.Val() > rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() > float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(float64(lv.Val()) > rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() > rv.Val()), true + } + case value.StringValue: + if rv, ok := right.(value.StringValue); ok { + return value.NewBoolValue(lv.Val() > rv.Val()), true + } + case value.TimeValue: + if rv, ok := right.(value.TimeValue); ok { + return value.NewBoolValue(lv.Val().Unix() > rv.Val().Unix()), true + } + } + // Try converting to strings + return value.NewBoolValue(left.ToString() > right.ToString()), true + + case lex.TokenGE: + // Similar to GT but with >= comparison + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(lv.Val() >= rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() >= float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(float64(lv.Val()) >= rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() >= rv.Val()), true + } + case value.StringValue: + if rv, ok := right.(value.StringValue); ok { + return value.NewBoolValue(lv.Val() >= rv.Val()), true + } + case value.TimeValue: + if rv, ok := right.(value.TimeValue); ok { + return value.NewBoolValue(lv.Val().Unix() >= rv.Val().Unix()), true + } + } + return value.NewBoolValue(left.ToString() >= right.ToString()), true + + case lex.TokenLT: + // Similar to GT but with < comparison + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(lv.Val() < rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() < float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(float64(lv.Val()) < rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() < rv.Val()), true + } + case value.StringValue: + if rv, ok := right.(value.StringValue); ok { + return value.NewBoolValue(lv.Val() < rv.Val()), true + } + case value.TimeValue: + if rv, ok := right.(value.TimeValue); ok { + return value.NewBoolValue(lv.Val().Unix() < rv.Val().Unix()), true + } + } + return value.NewBoolValue(left.ToString() < right.ToString()), true + + case lex.TokenLE: + // Similar to GT but with <= comparison + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(lv.Val() <= rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() <= float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewBoolValue(float64(lv.Val()) <= rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewBoolValue(lv.Val() <= rv.Val()), true + } + case value.StringValue: + if rv, ok := right.(value.StringValue); ok { + return value.NewBoolValue(lv.Val() <= rv.Val()), true + } + case value.TimeValue: + if rv, ok := right.(value.TimeValue); ok { + return value.NewBoolValue(lv.Val().Unix() <= rv.Val().Unix()), true + } + } + return value.NewBoolValue(left.ToString() <= right.ToString()), true + + case lex.TokenPlus: + // Addition or string concatenation + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewNumberValue(lv.Val() + rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewNumberValue(lv.Val() + float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewNumberValue(float64(lv.Val()) + rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewIntValue(lv.Val() + rv.Val()), true + } + } + // Fallback to string concatenation + return value.NewStringValue(left.ToString() + right.ToString()), true + + case lex.TokenMinus: + // Subtraction + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewNumberValue(lv.Val() - rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewNumberValue(lv.Val() - float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewNumberValue(float64(lv.Val()) - rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewIntValue(lv.Val() - rv.Val()), true + } + } + return nil, false + + case lex.TokenMultiply, lex.TokenStar: + // Multiplication + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewNumberValue(lv.Val() * rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewNumberValue(lv.Val() * float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + return value.NewNumberValue(float64(lv.Val()) * rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + return value.NewIntValue(lv.Val() * rv.Val()), true + } + } + return nil, false + + case lex.TokenDivide: + // Division + switch lv := left.(type) { + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + if rv.Val() == 0 { + return nil, false // Division by zero + } + return value.NewNumberValue(lv.Val() / rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + if rv.Val() == 0 { + return nil, false // Division by zero + } + return value.NewNumberValue(lv.Val() / float64(rv.Val())), true + } + case value.IntValue: + if rv, ok := right.(value.NumberValue); ok { + if rv.Val() == 0 { + return nil, false // Division by zero + } + return value.NewNumberValue(float64(lv.Val()) / rv.Val()), true + } else if rv, ok := right.(value.IntValue); ok { + if rv.Val() == 0 { + return nil, false // Division by zero + } + return value.NewIntValue(lv.Val() / rv.Val()), true + } + } + return nil, false + + case lex.TokenModulus: + // Modulus (remainder) + switch lv := left.(type) { + case value.IntValue: + if rv, ok := right.(value.IntValue); ok { + if rv.Val() == 0 { + return nil, false // Modulus by zero + } + return value.NewIntValue(lv.Val() % rv.Val()), true + } + case value.NumberValue: + if rv, ok := right.(value.NumberValue); ok { + if rv.Val() == 0 { + return nil, false // Modulus by zero + } + return value.NewNumberValue(float64(int64(lv.Val()) % int64(rv.Val()))), true + } else if rv, ok := right.(value.IntValue); ok { + if rv.Val() == 0 { + return nil, false // Modulus by zero + } + return value.NewNumberValue(float64(int64(lv.Val()) % rv.Val())), true + } + } + return nil, false + + case lex.TokenLogicAnd, lex.TokenAnd: + // Logical AND + leftBool, ok := left.(value.BoolValue) + if !ok { + return value.NewBoolValue(false), true + } + + rightBool, ok := right.(value.BoolValue) + if !ok { + return value.NewBoolValue(false), true + } + + return value.NewBoolValue(leftBool.Val() && rightBool.Val()), true + + case lex.TokenLogicOr, lex.TokenOr: + // Logical OR + leftBool, ok := left.(value.BoolValue) + if !ok { + return value.NewBoolValue(false), true + } + + rightBool, ok := right.(value.BoolValue) + if !ok { + return value.NewBoolValue(false), true + } + + return value.NewBoolValue(leftBool.Val() || rightBool.Val()), true + + case lex.TokenContains: + // Contains operation (string or slice) + switch lv := left.(type) { + case value.StringValue: + if rv, ok := right.(value.StringValue); ok { + return value.NewBoolValue(strings.Contains(lv.Val(), rv.Val())), true + } + return value.NewBoolValue(strings.Contains(lv.Val(), right.ToString())), true + + case value.Slice: + // Check if right side is in left slice + for _, item := range lv.SliceValue() { + if eq, err := value.Equal(item, right); err == nil && eq { + return value.NewBoolValue(true), true + } + } + return value.NewBoolValue(false), true + } + return value.NewBoolValue(false), true + + case lex.TokenLike: + // LIKE pattern matching + leftStr, ok := value.ValueToString(left) + if !ok { + return value.NewBoolValue(false), true + } + + rightStr, ok := value.ValueToString(right) + if !ok { + return value.NewBoolValue(false), true + } + + // Convert SQL LIKE pattern to glob pattern + pattern := strings.Replace(rightStr, "%", "*", -1) + + // Use glob matching (to be imported) + match, err := glob.Match(pattern, leftStr) + if err != nil { + return value.NewBoolValue(false), true + } + + return value.NewBoolValue(match), true + + case lex.TokenIN, lex.TokenIntersects: + // IN or INTERSECTS operation + switch rv := right.(type) { + case value.Slice: + // Check if left side is in right slice + for _, item := range rv.SliceValue() { + if eq, err := value.Equal(left, item); err == nil && eq { + return value.NewBoolValue(true), true + } + } + return value.NewBoolValue(false), true + + case value.Map: + // Check if left key exists in map + leftStr, ok := value.ValueToString(left) + if !ok { + return value.NewBoolValue(false), true + } + + _, exists := rv.Get(leftStr) + return value.NewBoolValue(exists), true + } + return value.NewBoolValue(false), true + } + + return nil, false + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileBoolean creates a direct function for boolean operations (AND/OR) +func (c *DirectCompiler) compileBoolean(node *expr.BooleanNode) (*CompiledExpr, error) { + // Compile all arguments + argExprs := make([]*CompiledExpr, len(node.Args)) + for i, arg := range node.Args { + compiled, err := c.compileToFunc(arg) + if err != nil { + return nil, err + } + argExprs[i] = compiled + } + + // Create a direct function based on the operator + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + switch node.Operator.T { + case lex.TokenLogicAnd, lex.TokenAnd: + // Short-circuit AND: if any arg is false, return false + for _, argExpr := range argExprs { + result, ok := argExpr.EvalFunc(ctx) + if !ok { + // If we can't evaluate an argument, for AND we return false + if node.Negated() { + return value.NewBoolValue(true), true + } + return value.NewBoolValue(false), true + } + + if boolVal, ok := result.(value.BoolValue); ok { + if !boolVal.Val() { + // Short-circuit: found a false value + if node.Negated() { + return value.NewBoolValue(true), true + } + return value.NewBoolValue(false), true + } + } else { + // Non-boolean value in AND, treat as false + if node.Negated() { + return value.NewBoolValue(true), true + } + return value.NewBoolValue(false), true + } + } + + // All arguments were true + if node.Negated() { + return value.NewBoolValue(false), true + } + return value.NewBoolValue(true), true + + case lex.TokenLogicOr, lex.TokenOr: + // Short-circuit OR: if any arg is true, return true + for _, argExpr := range argExprs { + result, ok := argExpr.EvalFunc(ctx) + if !ok { + continue // Try the next argument + } + + if boolVal, ok := result.(value.BoolValue); ok { + if boolVal.Val() { + // Short-circuit: found a true value + if node.Negated() { + return value.NewBoolValue(false), true + } + return value.NewBoolValue(true), true + } + } + } + + // No argument was true + if node.Negated() { + return value.NewBoolValue(true), true + } + return value.NewBoolValue(false), true + } + + // Unsupported operator + return nil, false + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileUnary creates a direct function for unary operations +func (c *DirectCompiler) compileUnary(node *expr.UnaryNode) (*CompiledExpr, error) { + // Compile the argument + argExpr, err := c.compileToFunc(node.Arg) + if err != nil { + return nil, err + } + + // Create a direct function based on the operator + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + switch node.Operator.T { + case lex.TokenNegate: + result, ok := argExpr.EvalFunc(ctx) + if !ok { + return value.NewBoolValue(false), true + } + + if boolVal, ok := result.(value.BoolValue); ok { + return value.NewBoolValue(!boolVal.Val()), true + } + + return value.NewBoolValue(false), true + + case lex.TokenMinus: + result, ok := argExpr.EvalFunc(ctx) + if !ok { + return nil, false + } + + if numVal, ok := result.(value.NumberValue); ok { + return value.NewNumberValue(-numVal.Val()), true + } else if intVal, ok := result.(value.IntValue); ok { + return value.NewIntValue(-intVal.Val()), true + } + + return nil, false + + case lex.TokenExists: + result, ok := argExpr.EvalFunc(ctx) + if !ok { + return value.NewBoolValue(false), true + } + + if result == nil || result.Nil() { + return value.NewBoolValue(false), true + } + + return value.NewBoolValue(true), true + } + + return nil, false + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileIdentity creates a direct function for identity (variable) references +func (c *DirectCompiler) compileIdentity(node *expr.IdentityNode) (*CompiledExpr, error) { + // Boolean identities (true/false) + if node.IsBooleanIdentity() { + boolValue := node.Bool() + return &CompiledExpr{ + Node: node, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return value.NewBoolValue(boolValue), true + }, + }, nil + } + + // Create a direct lookup function + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + if node.HasLeftRight() { + return ctx.Get(node.OriginalText()) + } + return ctx.Get(node.Text) + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileNumber creates a direct function for numeric literals +func (c *DirectCompiler) compileNumber(node *expr.NumberNode) (*CompiledExpr, error) { + if node.IsInt { + intValue := node.Int64 + return &CompiledExpr{ + Node: node, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return value.NewIntValue(intValue), true + }, + }, nil + } + + floatValue := node.Float64 + return &CompiledExpr{ + Node: node, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return value.NewNumberValue(floatValue), true + }, + }, nil +} + +// compileString creates a direct function for string literals +func (c *DirectCompiler) compileString(node *expr.StringNode) (*CompiledExpr, error) { + strValue := node.Text + return &CompiledExpr{ + Node: node, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return value.NewStringValue(strValue), true + }, + }, nil +} + +// compileFunc creates a direct function for function calls +func (c *DirectCompiler) compileFunc(node *expr.FuncNode) (*CompiledExpr, error) { + // Compile all arguments + argExprs := make([]*CompiledExpr, len(node.Args)) + for i, arg := range node.Args { + compiled, err := c.compileToFunc(arg) + if err != nil { + return nil, err + } + argExprs[i] = compiled + } + + // Get the function implementation + // funcName := strings.ToLower(node.Name) + + // Create a direct function + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + // Evaluate all arguments + args := make([]value.Value, len(argExprs)) + for i, argExpr := range argExprs { + arg, ok := argExpr.EvalFunc(ctx) + if !ok { + arg = value.NewNilValue() + } + args[i] = arg + } + + // Execute function + return node.Eval(ctx, args) + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileTernary creates a direct function for ternary operations (BETWEEN) +func (c *DirectCompiler) compileTernary(node *expr.TriNode) (*CompiledExpr, error) { + // Only BETWEEN is currently supported + if node.Operator.T != lex.TokenBetween { + return nil, fmt.Errorf("unsupported ternary operator: %v", node.Operator.T) + } + + // Compile all arguments + valueExpr, err := c.compileToFunc(node.Args[0]) + if err != nil { + return nil, err + } + + lowerExpr, err := c.compileToFunc(node.Args[1]) + if err != nil { + return nil, err + } + + upperExpr, err := c.compileToFunc(node.Args[2]) + if err != nil { + return nil, err + } + + // Create a direct function + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + // Evaluate all arguments + val, valueOk := valueExpr.EvalFunc(ctx) + if !valueOk { + return nil, false + } + + lower, lowerOk := lowerExpr.EvalFunc(ctx) + if !lowerOk { + return nil, false + } + + upper, upperOk := upperExpr.EvalFunc(ctx) + if !upperOk { + return nil, false + } + + // Compare based on value types + switch v := val.(type) { + case value.NumberValue: + lowerVal, lowerOk := value.ValueToFloat64(lower) + upperVal, upperOk := value.ValueToFloat64(upper) + + if lowerOk && upperOk { + result := v.Val() > lowerVal && v.Val() < upperVal + if node.Negated() { + return value.NewBoolValue(!result), true + } + return value.NewBoolValue(result), true + } + + case value.IntValue: + lowerVal, lowerOk := value.ValueToInt64(lower) + upperVal, upperOk := value.ValueToInt64(upper) + + if lowerOk && upperOk { + result := v.Val() > lowerVal && v.Val() < upperVal + if node.Negated() { + return value.NewBoolValue(!result), true + } + return value.NewBoolValue(result), true + } + + case value.TimeValue: + lowerTime, lowerOk := value.ValueToTime(lower) + upperTime, upperOk := value.ValueToTime(upper) + + if lowerOk && upperOk { + result := v.Val().After(lowerTime) && v.Val().Before(upperTime) + if node.Negated() { + return value.NewBoolValue(!result), true + } + return value.NewBoolValue(result), true + } + } + + // Fallback to string comparison + valueStr := val.ToString() + lowerStr := lower.ToString() + upperStr := upper.ToString() + + result := valueStr > lowerStr && valueStr < upperStr + if node.Negated() { + return value.NewBoolValue(!result), true + } + return value.NewBoolValue(result), true + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileArray creates a direct function for array nodes +func (c *DirectCompiler) compileArray(node *expr.ArrayNode) (*CompiledExpr, error) { + // Compile all elements + elemExprs := make([]*CompiledExpr, len(node.Args)) + for i, arg := range node.Args { + compiled, err := c.compileToFunc(arg) + if err != nil { + return nil, err + } + elemExprs[i] = compiled + } + + // Create a direct function + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + // Evaluate all elements + elems := make([]value.Value, len(elemExprs)) + for i, elemExpr := range elemExprs { + elem, ok := elemExpr.EvalFunc(ctx) + if !ok { + elem = value.NewNilValue() + } + elems[i] = elem + } + + // Create slice value + return value.NewSliceValues(elems), true + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// compileInclude creates a direct function for include nodes +func (c *DirectCompiler) compileInclude(node *expr.IncludeNode) (*CompiledExpr, error) { + includeID := node.Identity.Text + + // Create a direct function + evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { + // Check for IncludeCache + if cacheCtx, hasCacheCtx := ctx.(expr.IncludeCacheContextV2); hasCacheCtx { + matches, err := cacheCtx.GetOrSet(includeID, func() (bool, error) { + return evaluateInclude(ctx, node) + }) + + if err != nil { + if node.Negated() { + return value.NewBoolValue(true), true + } + return nil, false + } + + if node.Negated() { + return value.NewBoolValue(!matches), true + } + return value.NewBoolValue(matches), true + } + + // Non-cached evaluation + matches, err := evaluateInclude(ctx, node) + if err != nil { + if node.Negated() { + return value.NewBoolValue(true), true + } + return nil, false + } + + if node.Negated() { + return value.NewBoolValue(!matches), true + } + return value.NewBoolValue(matches), true + } + + return &CompiledExpr{ + Node: node, + EvalFunc: evalFunc, + }, nil +} + +// evaluateInclude evaluates an include node +func evaluateInclude(ctx expr.EvalContext, node *expr.IncludeNode) (bool, error) { + includer, ok := ctx.(expr.EvalIncludeContext) + if !ok { + return false, fmt.Errorf("no inclusion context") + } + + // Resolve the included expression + includedExpr := node.ExprNode + if includedExpr == nil { + var err error + includedExpr, err = includer.Include(node.Identity.Text) + if err != nil || includedExpr == nil { + return false, fmt.Errorf("walking include: %w", err) + } + + // Save for future use + node.ExprNode = includedExpr + } + + // Check for wildcard includes + if idNode, ok := includedExpr.(*expr.IdentityNode); ok { + if idNode.Text == "*" || idNode.Text == "match_all" { + return true, nil + } + } + + // Evaluate the included expression + result, ok := vm.Eval(ctx, includedExpr) + if !ok { + return false, fmt.Errorf("evaluating expression") + } + + // Convert to boolean + if boolVal, ok := result.(value.BoolValue); ok { + return boolVal.Val(), nil + } + + return false, nil +} + +// compileValue creates a direct function for value nodes +func (c *DirectCompiler) compileValue(node *expr.ValueNode) (*CompiledExpr, error) { + if node.Value == nil { + return &CompiledExpr{ + Node: node, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return nil, false + }, + }, nil + } + + // Store value for direct return + val := node.Value + + return &CompiledExpr{ + Node: node, + EvalFunc: func(ctx expr.EvalContext) (value.Value, bool) { + return val, true + }, + }, nil +} + +// hashNode creates a hash of an expression node +func hashNode(node expr.Node) uint64 { + h := fnv.New64() + h.Write([]byte(node.String())) + return h.Sum64() +} diff --git a/filterqlvm/filterqlvm.go b/filterqlvm/filterqlvm.go new file mode 100644 index 00000000..d9ef78f7 --- /dev/null +++ b/filterqlvm/filterqlvm.go @@ -0,0 +1,140 @@ +package filterqlvm + +import ( + "github.com/lytics/qlbridge/expr" + "github.com/lytics/qlbridge/filterqlvm/compiler" + "github.com/lytics/qlbridge/rel" + "github.com/lytics/qlbridge/value" + "github.com/lytics/qlbridge/vm" +) + +// OptimizedVM is a high-performance virtual machine for evaluating FilterQL expressions +// It compiles expressions directly to optimized Go functions without any AST walking +type OptimizedVM struct { + compiler *compiler.DirectCompiler +} + +// NewOptimizedVM creates a new optimized VM +func NewOptimizedVM() *OptimizedVM { + return &OptimizedVM{ + compiler: compiler.NewDirectCompiler(), + } +} + +// CompileFilter compiles a FilterQL statement +func (vm *OptimizedVM) CompileFilter(filter *rel.FilterStatement) (*compiler.CompiledExpr, error) { + return vm.compiler.Compile(filter.Filter) +} + +// CompileNode compiles any expression node +func (vm *OptimizedVM) CompileNode(node expr.Node) (*compiler.CompiledExpr, error) { + return vm.compiler.Compile(node) +} + +// EvalFilter evaluates a FilterQL statement against a context +func (vm *OptimizedVM) EvalFilter(filter *rel.FilterStatement, ctx expr.EvalContext) (bool, bool) { + // Special case for match_all + if filter.Filter == nil { + return true, true + } + + switch n := filter.Filter.(type) { + case *expr.IdentityNode: + if n.Text == "*" || n.Text == "match_all" { + return true, true + } + } + + // Compile the filter (cached internally) + compiled, err := vm.CompileFilter(filter) + if err != nil { + // Fall back to the standard evaluator if compilation fails + return vm.Matches(ctx, filter) + } + + // Run the compiled version + result, ok := compiled.EvalFunc(ctx) + if !ok { + return false, false + } + + // Convert to bool + if bv, isBool := result.(value.BoolValue); isBool { + return bv.Val(), true + } + + return false, false +} + +// EvalNode evaluates any expression node against a context +func (ovm *OptimizedVM) EvalNode(node expr.Node, ctx expr.EvalContext) (value.Value, bool) { + // Compile the node (cached internally) + compiled, err := ovm.CompileNode(node) + if err != nil { + // Fall back to the standard evaluator if compilation fails + return vm.Eval(ctx, node) + } + + // Run the compiled version + return compiled.EvalFunc(ctx) +} + +// Matches evaluates a FilterQL statement to determine if the context matches +func (ovm *OptimizedVM) Matches(ctx expr.EvalContext, stmt *rel.FilterStatement) (bool, bool) { + return ovm.EvalFilter(stmt, ctx) +} + +// MatchesInc evaluates a FilterQL statement with an includer +func (ovm *OptimizedVM) MatchesInc(inc expr.Includer, cr expr.EvalContext, stmt *rel.FilterStatement) (bool, bool) { + evalCtx, ok := cr.(expr.EvalIncludeContext) + if !ok { + evalCtx = &filterql{EvalContext: cr, Includer: inc} + } + + return ovm.EvalFilter(stmt, evalCtx) +} + +// EvalFilterSelect evaluates a FilterSelect statement +func (ovm *OptimizedVM) EvalFilterSelect(sel *rel.FilterSelect, writeContext expr.ContextWriter, readContext expr.EvalContext) (bool, bool) { + ctx, ok := readContext.(expr.EvalIncludeContext) + if !ok { + ctx = &expr.IncludeContext{ContextReader: readContext} + } + + // Check filter condition + if sel.FilterStatement != nil { + matches, ok := ovm.Matches(ctx, sel.FilterStatement) + if !ok || !matches { + return false, ok + } + } + + // Process columns + for _, col := range sel.Columns { + // Check column guard + if col.Guard != nil { + guardResult, ok := ovm.EvalNode(col.Guard, readContext) + if !ok { + continue + } + + if guardBool, ok := guardResult.(value.BoolValue); ok && !guardBool.Val() { + continue // Skip this column + } + } + + // Evaluate column expression + v, ok := ovm.EvalNode(col.Expr, readContext) + if ok { + writeContext.Put(col, readContext, v) + } + } + + return true, true +} + +// filterql implementation for handling includes +type filterql struct { + expr.EvalContext + expr.Includer +} diff --git a/filterqlvm/filterqlvm_test.go b/filterqlvm/filterqlvm_test.go new file mode 100644 index 00000000..a1c0193b --- /dev/null +++ b/filterqlvm/filterqlvm_test.go @@ -0,0 +1,251 @@ +package filterqlvm + +import ( + "fmt" + "testing" + "time" + + "github.com/lytics/qlbridge/expr" + "github.com/lytics/qlbridge/rel" + "github.com/lytics/qlbridge/value" + "github.com/lytics/qlbridge/vm" +) + +// testContext implements expr.EvalContext for testing +type testContext struct { + data map[string]value.Value +} + +func (tc *testContext) Get(key string) (value.Value, bool) { + v, ok := tc.data[key] + left, right, hasNamespacing := expr.LeftRight(key) + + //u.Debugf("left:%q right:%q key=%v", left, right, key) + if hasNamespacing { + f, ok := tc.data[left] + if !ok { + return nil, false + } + mapf, ok := f.(value.Map) + if !ok { + return nil, false + } + v, ok := mapf.Get(right) + return v, ok + } + + if f, ok := tc.data[right]; ok { + return f, true + } + + return v, ok +} + +func (tc *testContext) Row() map[string]value.Value { + return tc.data +} + +func (tc *testContext) Ts() time.Time { + return time.Now() +} + +// Basic context with simple string and number fields +func newBasicContext() *testContext { + return &testContext{ + data: map[string]value.Value{ + "name": value.NewStringValue("John Doe"), + "age": value.NewIntValue(30), + "score": value.NewNumberValue(92.5), + "active": value.NewBoolValue(true), + "created": value.NewTimeValue(time.Date(2020, 1, 15, 0, 0, 0, 0, time.UTC)), + "tags": value.NewStringsValue([]string{"developer", "golang", "database"}), + "metadata": value.NewMapValue(map[string]interface{}{"level": 5, "type": "user"}), + }, + } +} + +// Complex context with nested data and multiple value types +func newComplexContext() *testContext { + tags := value.NewStringsValue([]string{ + "mobile", "web", "desktop", "backend", "frontend", + "developer", "tester", "designer", "manager", "admin", + }) + + metrics := make(map[string]interface{}) + for i := 0; i < 20; i++ { + metrics[fmt.Sprintf("metric_%d", i)] = float64(i * 10) + } + + return &testContext{ + data: map[string]value.Value{ + "user_id": value.NewIntValue(12345), + "username": value.NewStringValue("johndoe"), + "email": value.NewStringValue("john.doe@example.com"), + "first_name": value.NewStringValue("John"), + "last_name": value.NewStringValue("Doe"), + "age": value.NewIntValue(35), + "score": value.NewNumberValue(92.5), + "active": value.NewBoolValue(true), + "verified": value.NewBoolValue(true), + "premium": value.NewBoolValue(false), + "created_at": value.NewTimeValue(time.Date(2020, 1, 15, 0, 0, 0, 0, time.UTC)), + "updated_at": value.NewTimeValue(time.Date(2023, 6, 10, 0, 0, 0, 0, time.UTC)), + "last_login": value.NewTimeValue(time.Now().Add(-24 * time.Hour)), + "login_count": value.NewIntValue(150), + "tags": tags, + "roles": value.NewStringsValue([]string{"user", "editor"}), + "preferences": value.NewMapValue(map[string]interface{}{"theme": "dark", "notifications": true, "language": "en"}), + "address": value.NewMapValue(map[string]interface{}{"city": "New York", "country": "USA", "zip": "10001"}), + "metrics": value.NewMapValue(metrics), + "subscription": value.NewMapValue(map[string]interface{}{"plan": "pro", "amount": 99.99, "currency": "USD", "active": true}), + "devices": value.NewSliceValues([]value.Value{ + value.NewMapValue(map[string]interface{}{"type": "mobile", "os": "iOS", "last_used": time.Now().Add(-2 * 24 * time.Hour)}), + value.NewMapValue(map[string]interface{}{"type": "desktop", "os": "Windows", "last_used": time.Now().Add(-12 * time.Hour)}), + }), + }, + } +} + +// Define benchmark patterns with varying complexity +var benchmarkPatterns = []struct { + name string + filter string + complex bool // whether to use complex context +}{ + {"Simple equality", `name = "John Doe"`, false}, + {"Numeric comparison", `age > 25`, false}, + {"String LIKE", `name LIKE "John%"`, false}, + {"IN operator", `"golang" IN tags`, false}, + {"Simple AND", `name = "John Doe" AND age > 25`, false}, + {"Simple OR", `score > 95 OR active = true`, false}, + {"Nested logic", `(age > 20 OR score > 95) AND active = true`, false}, + + // More complex patterns for complex context + {"Complex string comparison", `first_name = "John" AND last_name = "Doe"`, true}, + {"Multiple numeric comparisons", `age >= 30 AND score > 90 AND login_count > 100`, true}, + {"Map field access", `preferences.theme = "dark"`, true}, + {"Multiple IN checks", `AND ("user" IN rolesm, "admin" NOT IN roles)`, true}, + {"Complex AND OR", `AND (OR (premium = false, score > 90), AND(verified = true, active = true))`, true}, + {"Date comparisons", `created_at < updated_at AND last_login > created_at`, true}, + {"Complex string operations", `AND (email LIKE "%@example.com" , username LIKE "john%" )`, true}, + {"Deep nested conditions", `OR (AND (age > 30, score > 90) , AND (login_count > 100 , OR (premium = true, active = true)))`, true}, + {"Field exists checks", `AND (EXISTS subscription , subscription.active = true)`, true}, + {"Mathematical operations", `AND(age * 2 > 50 , score / 10 > 9)`, true}, + + // Very complex pattern combining many operators + {"Very complex pattern", + `AND ( + AND (first_name = "John", last_name = "Doe"), + OR (age > 30, score >= 90) , + OR ("user" IN roles, "admin" IN roles), + OR (preferences.theme = "dark", preferences.language = "en"), + (created_at < updated_at), + AND (subscription.plan = "pro", subscription.amount < 100), + (email LIKE "%@example.com") + )`, + true}, +} + +// Benchmarks for Standard VM +func BenchmarkStandardVM(b *testing.B) { + for _, pattern := range benchmarkPatterns { + b.Run(pattern.name, func(b *testing.B) { + filter, err := rel.ParseFilterQL("FILTER " + pattern.filter) + if err != nil { + b.Fatalf("Failed to parse filter: %v", err) + } + + var ctx *testContext + if pattern.complex { + ctx = newComplexContext() + } else { + ctx = newBasicContext() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + vm.Matches(ctx, filter) + } + }) + } +} + +// Benchmarks for Optimized VM +func BenchmarkOptimizedVM(b *testing.B) { + optimizedVM := NewOptimizedVM() + + for _, pattern := range benchmarkPatterns { + b.Run(pattern.name, func(b *testing.B) { + filter, err := rel.ParseFilterQL("FILTER " + pattern.filter) + if err != nil { + b.Fatalf("Failed to parse filter: %v", err) + } + + var ctx *testContext + if pattern.complex { + ctx = newComplexContext() + } else { + ctx = newBasicContext() + } + + // Pre-compile to ensure we're measuring execution time, not compilation + _, err = optimizedVM.CompileFilter(filter) + if err != nil { + b.Fatalf("Failed to compile filter: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + optimizedVM.Matches(ctx, filter) + } + }) + } +} + +// Verify that both VMs produce consistent results +func TestVMConsistency(t *testing.T) { + standardVM := &FilterQLVM{} // Wrapper for the standard VM + optimizedVM := NewOptimizedVM() + + for _, pattern := range benchmarkPatterns { + t.Run(pattern.name, func(t *testing.T) { + filter, err := rel.ParseFilterQL("FILTER " + pattern.filter) + if err != nil { + t.Fatalf("Failed to parse filter: %v", err) + } + + var ctx *testContext + if pattern.complex { + ctx = newComplexContext() + } else { + ctx = newBasicContext() + } + + // Run with standard VM + standardMatches, standardOk := standardVM.Matches(ctx, filter) + + // Run with optimized VM + optimizedMatches, optimizedOk := optimizedVM.Matches(ctx, filter) + + // Compare results + if standardOk != optimizedOk { + t.Errorf("VM ok status differs: standard=%v, optimized=%v", standardOk, optimizedOk) + } + + if standardMatches != optimizedMatches { + t.Errorf("VM match results differ: standard=%v, optimized=%v", standardMatches, optimizedMatches) + } + }) + } +} + +// FilterQLVM is a wrapper for the standard VM +type FilterQLVM struct{} + +func (ovm *FilterQLVM) Matches(ctx expr.EvalContext, stmt *rel.FilterStatement) (bool, bool) { + return vm.Matches(ctx, stmt) +} + +func (vm *FilterQLVM) MatchesInc(inc expr.Includer, cr expr.EvalContext, stmt *rel.FilterStatement) (bool, bool) { + return vm.MatchesInc(inc, cr, stmt) +} From 74ce072241d9ea8bb5345e74cd6ef594d8aac7ed Mon Sep 17 00:00:00 2001 From: AJ Roetker Date: Wed, 5 Mar 2025 21:37:26 -0800 Subject: [PATCH 3/3] Optimize the optimized vm more --- filterqlvm/compiler/compiler.go | 179 +++++++++++++++++++++++++++++--- filterqlvm/filterqlvm.go | 2 +- filterqlvm/filterqlvm_test.go | 11 +- vm/filterqlvm_test.go | 30 ++++-- 4 files changed, 188 insertions(+), 34 deletions(-) diff --git a/filterqlvm/compiler/compiler.go b/filterqlvm/compiler/compiler.go index e138d0ed..05c14497 100644 --- a/filterqlvm/compiler/compiler.go +++ b/filterqlvm/compiler/compiler.go @@ -3,6 +3,7 @@ package compiler import ( "fmt" "hash/fnv" + "strconv" "strings" "sync" @@ -10,6 +11,7 @@ import ( "github.com/lytics/qlbridge/expr" "github.com/lytics/qlbridge/lex" + "github.com/lytics/qlbridge/rel" "github.com/lytics/qlbridge/value" "github.com/lytics/qlbridge/vm" ) @@ -35,6 +37,32 @@ func NewDirectCompiler() *DirectCompiler { } } +func (c *DirectCompiler) CompileFilter(node *rel.FilterStatement) (*CompiledExpr, error) { + // Generate a hash for the node to use as cache key + hash := hashFilter(node) + + // Check cache first + c.cacheLock.RLock() + if compiled, ok := c.cache[hash]; ok { + c.cacheLock.RUnlock() + return compiled, nil + } + c.cacheLock.RUnlock() + + // Create the compiled expression + compiled, err := c.compileToFunc(node.Filter) + if err != nil { + return nil, err + } + + // Cache the result + c.cacheLock.Lock() + c.cache[hash] = compiled + c.cacheLock.Unlock() + + return compiled, nil +} + // Compile compiles an expression node into a direct evaluation function func (c *DirectCompiler) Compile(node expr.Node) (*CompiledExpr, error) { // Generate a hash for the node to use as cache key @@ -116,29 +144,69 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er evalFunc := func(ctx expr.EvalContext) (value.Value, bool) { // Get left and right values left, leftOk := leftExpr.EvalFunc(ctx) - if !leftOk { - return nil, false - } // Short-circuit for logical operators switch node.Operator.T { case lex.TokenLogicAnd, lex.TokenAnd: + if !leftOk { + return nil, false + } // If left is false, no need to evaluate right if leftBool, ok := left.(value.BoolValue); ok && !leftBool.Val() { return value.NewBoolValue(false), true } case lex.TokenLogicOr, lex.TokenOr: - // If left is true, no need to evaluate right - if leftBool, ok := left.(value.BoolValue); ok && leftBool.Val() { - return value.NewBoolValue(true), true + if leftOk { + // If left is true, no need to evaluate right + if leftBool, ok := left.(value.BoolValue); ok && leftBool.Val() { + return value.NewBoolValue(true), true + } } } right, rightOk := rightExpr.EvalFunc(ctx) - if !rightOk { + // If we could not evaluate either we can shortcut + if !leftOk && !rightOk { + switch node.Operator.T { + case lex.TokenLogicOr, lex.TokenOr: + return value.NewBoolValue(false), true + case lex.TokenEqualEqual, lex.TokenEqual: + // We don't alllow nil == nil here bc we have a NilValue type + // that we would use for that + return value.NewBoolValue(false), true + case lex.TokenNE: + return value.NewBoolValue(false), true + case lex.TokenGT, lex.TokenGE, lex.TokenLT, lex.TokenLE, lex.TokenLike: + return value.NewBoolValue(false), true + } return nil, false } + // Else if we can only evaluate right + if !leftOk { + switch node.Operator.T { + case lex.TokenIntersects, lex.TokenIN, lex.TokenContains, lex.TokenLike: + return value.NewBoolValue(false), true + } + } + + // Else if we can only evaluate one, we can short circuit as well + if !leftOk || !rightOk { + switch node.Operator.T { + case lex.TokenAnd, lex.TokenLogicAnd: + return value.NewBoolValue(false), true + case lex.TokenEqualEqual, lex.TokenEqual: + return value.NewBoolValue(false), true + case lex.TokenNE: + // they are technically not equal? + return value.NewBoolValue(true), true + case lex.TokenIN, lex.TokenIntersects: + return value.NewBoolValue(false), true + case lex.TokenGT, lex.TokenGE, lex.TokenLT, lex.TokenLE, lex.TokenLike: + return value.NewBoolValue(false), true + } + } + // Handle different operators switch node.Operator.T { case lex.TokenEqual, lex.TokenEqualEqual: @@ -163,6 +231,10 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() > rv.Val()), true } else if rv, ok := right.(value.IntValue); ok { return value.NewBoolValue(lv.Val() > float64(rv.Val())), true + } else if rv, ok := right.(value.StringValue); ok { + if rf, err := strconv.ParseFloat(rv.Val(), 64); err == nil { + return value.NewBoolValue(lv.Val() > rf), true + } } case value.IntValue: if rv, ok := right.(value.NumberValue); ok { @@ -171,13 +243,22 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() > rv.Val()), true } case value.StringValue: + if rv, ok := right.(value.TimeValue); ok { + leftTime, ok := value.ValueToTime(left) + if !ok { + return value.BoolValueFalse, false + } + return value.NewBoolValue(rv.Val().Unix() > leftTime.Unix()), true + } if rv, ok := right.(value.StringValue); ok { return value.NewBoolValue(lv.Val() > rv.Val()), true } case value.TimeValue: - if rv, ok := right.(value.TimeValue); ok { - return value.NewBoolValue(lv.Val().Unix() > rv.Val().Unix()), true + rightTime, ok := value.ValueToTime(right) + if !ok { + return value.BoolValueFalse, false } + return value.NewBoolValue(lv.Val().Unix() > rightTime.Unix()), true } // Try converting to strings return value.NewBoolValue(left.ToString() > right.ToString()), true @@ -190,6 +271,10 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() >= rv.Val()), true } else if rv, ok := right.(value.IntValue); ok { return value.NewBoolValue(lv.Val() >= float64(rv.Val())), true + } else if rv, ok := right.(value.StringValue); ok { + if rf, err := strconv.ParseFloat(rv.Val(), 64); err == nil { + return value.NewBoolValue(lv.Val() >= rf), true + } } case value.IntValue: if rv, ok := right.(value.NumberValue); ok { @@ -198,13 +283,23 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() >= rv.Val()), true } case value.StringValue: + // TODO (need this to work for all the operators) + if rv, ok := right.(value.TimeValue); ok { + leftTime, ok := value.ValueToTime(left) + if !ok { + return value.BoolValueFalse, false + } + return value.NewBoolValue(rv.Val().Unix() >= leftTime.Unix()), true + } if rv, ok := right.(value.StringValue); ok { return value.NewBoolValue(lv.Val() >= rv.Val()), true } case value.TimeValue: - if rv, ok := right.(value.TimeValue); ok { - return value.NewBoolValue(lv.Val().Unix() >= rv.Val().Unix()), true + rightTime, ok := value.ValueToTime(right) + if !ok { + return value.BoolValueFalse, false } + return value.NewBoolValue(lv.Val().Unix() >= rightTime.Unix()), true } return value.NewBoolValue(left.ToString() >= right.ToString()), true @@ -216,6 +311,10 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() < rv.Val()), true } else if rv, ok := right.(value.IntValue); ok { return value.NewBoolValue(lv.Val() < float64(rv.Val())), true + } else if rv, ok := right.(value.StringValue); ok { + if rf, err := strconv.ParseFloat(rv.Val(), 64); err == nil { + return value.NewBoolValue(lv.Val() < rf), true + } } case value.IntValue: if rv, ok := right.(value.NumberValue); ok { @@ -224,13 +323,22 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() < rv.Val()), true } case value.StringValue: + if rv, ok := right.(value.TimeValue); ok { + leftTime, ok := value.ValueToTime(left) + if !ok { + return value.BoolValueFalse, false + } + return value.NewBoolValue(rv.Val().Unix() < leftTime.Unix()), true + } if rv, ok := right.(value.StringValue); ok { return value.NewBoolValue(lv.Val() < rv.Val()), true } case value.TimeValue: - if rv, ok := right.(value.TimeValue); ok { - return value.NewBoolValue(lv.Val().Unix() < rv.Val().Unix()), true + rightTime, ok := value.ValueToTime(right) + if !ok { + return value.BoolValueFalse, false } + return value.NewBoolValue(lv.Val().Unix() < rightTime.Unix()), true } return value.NewBoolValue(left.ToString() < right.ToString()), true @@ -242,21 +350,36 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er return value.NewBoolValue(lv.Val() <= rv.Val()), true } else if rv, ok := right.(value.IntValue); ok { return value.NewBoolValue(lv.Val() <= float64(rv.Val())), true + } else if rv, ok := right.(value.StringValue); ok { + if rf, err := strconv.ParseFloat(rv.Val(), 64); err == nil { + return value.NewBoolValue(lv.Val() <= rf), true + } } case value.IntValue: + // TODO (Add parsing of strings to ints) + // TODO (How to handle string floats?) if rv, ok := right.(value.NumberValue); ok { return value.NewBoolValue(float64(lv.Val()) <= rv.Val()), true } else if rv, ok := right.(value.IntValue); ok { return value.NewBoolValue(lv.Val() <= rv.Val()), true } case value.StringValue: + if rv, ok := right.(value.TimeValue); ok { + leftTime, ok := value.ValueToTime(left) + if !ok { + return value.BoolValueFalse, false + } + return value.NewBoolValue(rv.Val().Unix() <= leftTime.Unix()), true + } if rv, ok := right.(value.StringValue); ok { return value.NewBoolValue(lv.Val() <= rv.Val()), true } case value.TimeValue: - if rv, ok := right.(value.TimeValue); ok { - return value.NewBoolValue(lv.Val().Unix() <= rv.Val().Unix()), true + rightTime, ok := value.ValueToTime(right) + if !ok { + return value.BoolValueFalse, false } + return value.NewBoolValue(lv.Val().Unix() <= rightTime.Unix()), true } return value.NewBoolValue(left.ToString() <= right.ToString()), true @@ -445,6 +568,26 @@ func (c *DirectCompiler) compileBinary(node *expr.BinaryNode) (*CompiledExpr, er // IN or INTERSECTS operation switch rv := right.(type) { case value.Slice: + if lv, ok := left.(value.Slice); ok { + // Check if any left item is in right slice + for _, item := range lv.SliceValue() { + for _, rightItem := range rv.SliceValue() { + if eq, err := value.Equal(item, rightItem); err == nil && eq { + return value.NewBoolValue(true), true + } + } + } + return value.NewBoolValue(false), true + } + if lv, ok := left.(value.Map); ok { + // Check if any left item is in right slice + for _, item := range rv.SliceValue() { + if _, exists := lv.Get(item.ToString()); exists { + return value.NewBoolValue(exists), true + } + } + return value.NewBoolValue(false), true + } // Check if left side is in right slice for _, item := range rv.SliceValue() { if eq, err := value.Equal(left, item); err == nil && eq { @@ -963,3 +1106,9 @@ func hashNode(node expr.Node) uint64 { h.Write([]byte(node.String())) return h.Sum64() } + +func hashFilter(filter *rel.FilterStatement) uint64 { + h := fnv.New64() + h.Write([]byte(filter.String())) + return h.Sum64() +} diff --git a/filterqlvm/filterqlvm.go b/filterqlvm/filterqlvm.go index d9ef78f7..d9aa2e3b 100644 --- a/filterqlvm/filterqlvm.go +++ b/filterqlvm/filterqlvm.go @@ -23,7 +23,7 @@ func NewOptimizedVM() *OptimizedVM { // CompileFilter compiles a FilterQL statement func (vm *OptimizedVM) CompileFilter(filter *rel.FilterStatement) (*compiler.CompiledExpr, error) { - return vm.compiler.Compile(filter.Filter) + return vm.compiler.CompileFilter(filter) } // CompileNode compiles any expression node diff --git a/filterqlvm/filterqlvm_test.go b/filterqlvm/filterqlvm_test.go index a1c0193b..5b9a0310 100644 --- a/filterqlvm/filterqlvm_test.go +++ b/filterqlvm/filterqlvm_test.go @@ -20,7 +20,6 @@ func (tc *testContext) Get(key string) (value.Value, bool) { v, ok := tc.data[key] left, right, hasNamespacing := expr.LeftRight(key) - //u.Debugf("left:%q right:%q key=%v", left, right, key) if hasNamespacing { f, ok := tc.data[left] if !ok { @@ -175,12 +174,11 @@ func BenchmarkOptimizedVM(b *testing.B) { optimizedVM := NewOptimizedVM() for _, pattern := range benchmarkPatterns { + filter, err := rel.ParseFilterQL("FILTER " + pattern.filter) + if err != nil { + b.Fatalf("Failed to parse filter: %v", err) + } b.Run(pattern.name, func(b *testing.B) { - filter, err := rel.ParseFilterQL("FILTER " + pattern.filter) - if err != nil { - b.Fatalf("Failed to parse filter: %v", err) - } - var ctx *testContext if pattern.complex { ctx = newComplexContext() @@ -193,7 +191,6 @@ func BenchmarkOptimizedVM(b *testing.B) { if err != nil { b.Fatalf("Failed to compile filter: %v", err) } - b.ResetTimer() for i := 0; i < b.N; i++ { optimizedVM.Matches(ctx, filter) diff --git a/vm/filterqlvm_test.go b/vm/filterqlvm_test.go index e2dfb18f..54c36d33 100644 --- a/vm/filterqlvm_test.go +++ b/vm/filterqlvm_test.go @@ -10,9 +10,11 @@ import ( "github.com/araddon/dateparse" u "github.com/araddon/gou" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/lytics/qlbridge/datasource" "github.com/lytics/qlbridge/expr" + "github.com/lytics/qlbridge/filterqlvm" "github.com/lytics/qlbridge/rel" "github.com/lytics/qlbridge/value" "github.com/lytics/qlbridge/vm" @@ -126,6 +128,7 @@ func TestFilterQlVm(t *testing.T) { `FILTER NOT ( Created > "now-1d") `, // Date Math (negated) `FILTER NOT ( FakeDate > "now-1d") `, // Date Math (negated, missing field) `FILTER Updated > "now-2h"`, // Date Math + `FILTER Updated < "now-30m"`, // Date Math `FILTER transactions < "now-1h"`, // Date Compare with []time.Time `FILTER FirstEvent.signedup < "now-2h"`, // Date Math on map[string]time `FILTER FirstEvent.signedup == "12/18/2015"`, // Date equality on map[string]time @@ -140,11 +143,11 @@ func TestFilterQlVm(t *testing.T) { `FILTER lastevent NOT IN ("not-gonna-happen")`, `FILTER *`, // match all `FILTER OR ( - name == "Rey" -- false + name == "Rey" -- false INCLUDE match_all_include )`, `FILTER OR ( - name == "Rey" -- false + name == "Rey" -- false INCLUDE is_yoda_true )`, `FILTER OR ( @@ -163,13 +166,14 @@ func TestFilterQlVm(t *testing.T) { // Coerce strings to numbers when appropriate `FILTER AND (zip == "5", BankAmount > "50")`, `FILTER bankamount > "9.4"`, + `FILTER bankamount < "100001"`, `FILTER AND (zip == 5, "Yoda" == name, OR ( city IN ( "Portland, OR", "New York, NY", "Peoria, IL" ) ) )`, `FILTER OR ( - EXISTS q, - AND ( - zip > 0, - OR ( zip > 10000, zip < 100 ) - ), + EXISTS q, + AND ( + zip > 0, + OR ( zip > 10000, zip < 100 ) + ), NOT ( name == "Yoda" ) )`, `FILTER hits.foo > 1.5`, `FILTER hits.foo > "1.5"`, @@ -182,12 +186,16 @@ func TestFilterQlVm(t *testing.T) { //u.Debugf("len hits: %v", len(hitsx)) //expr.Trace = true + optimizedVM := filterqlvm.NewOptimizedVM() for _, q := range hits { fs, err := rel.ParseFilterQL(q) assert.Equal(t, nil, err) match, ok := vm.Matches(incctx, fs) assert.True(t, ok, "should be ok matching on query %q: %v", q, ok) assert.True(t, match, q) + match, ok = optimizedVM.Matches(incctx, fs) + require.True(t, ok, "should be ok matching on query %q: %v", q, ok) + require.True(t, match, q) match, ok = vm.MatchesExpr(incctx, fs.Filter) assert.True(t, ok, "should be ok matching on query %q: %v", q, ok) assert.True(t, match, q) @@ -229,23 +237,23 @@ func TestFilterQlVm(t *testing.T) { SELECT name , zip IF zip > 2 - FROM mycontext + FROM mycontext FILTER name == "Yoda"`, map[string]interface{}{"name": "Yoda", "zip": 5}}, {` SELECT name , zip IF zip > 200 - FROM mycontext + FROM mycontext FILTER name == "Yoda"`, map[string]interface{}{"name": "Yoda"}}, {` SELECT name IF name < true - FROM mycontext + FROM mycontext FILTER name == "Yoda"`, nil}, {` SELECT name IF zip + 5 - FROM mycontext + FROM mycontext FILTER name == "Yoda"`, nil}, } for _, test := range filterSelects {