diff --git a/cel/folding_test.go b/cel/folding_test.go index aa671badc..08d5f5177 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -366,7 +366,10 @@ func TestConstantFoldingOptimizer(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -441,7 +444,10 @@ func TestConstantFoldingCallsWithSideEffects(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if tc.error != "" { if iss.Err() == nil { @@ -508,7 +514,10 @@ func TestConstantFoldingOptimizerMacroElimination(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -570,7 +579,10 @@ func TestConstantFoldingOptimizerWithLimit(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -828,7 +840,10 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) diff --git a/cel/inlining_test.go b/cel/inlining_test.go index 856eaafcb..4dec1074a 100644 --- a/cel/inlining_test.go +++ b/cel/inlining_test.go @@ -220,7 +220,10 @@ func TestInliningOptimizer(t *testing.T) { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -236,7 +239,10 @@ func TestInliningOptimizer(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt = cel.NewStaticOptimizer(folder) + opt, err = cel.NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss = opt.Optimize(e, optimized) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -727,7 +733,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -743,7 +752,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt = cel.NewStaticOptimizer(folder) + opt, err = cel.NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss = opt.Optimize(e, optimized) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) diff --git a/cel/optimizer.go b/cel/optimizer.go index 9a2a97a64..3ffeb4281 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -15,6 +15,7 @@ package cel import ( + "fmt" "sort" "github.com/google/cel-go/common" @@ -33,13 +34,37 @@ import ( // should be suitable for calls to parser.Unparse. type StaticOptimizer struct { optimizers []ASTOptimizer + source *Source } +type OptimizerOption func(*StaticOptimizer) (*StaticOptimizer, error) + // NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied // to a checked expression. -func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { - return &StaticOptimizer{ - optimizers: optimizers, +func NewStaticOptimizer(options ...any) (*StaticOptimizer, error) { + so := &StaticOptimizer{} + var err error + for _, opt := range options { + switch v := opt.(type) { + case ASTOptimizer: + so.optimizers = append(so.optimizers, v) + case OptimizerOption: + so, err = v(so) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported option: %v", v) + } + } + return so, nil +} + +// OptimizeWithSource overrides the source used by the optimizer. +func OptimizeWithSource(source Source) OptimizerOption { + return func(so *StaticOptimizer) (*StaticOptimizer, error) { + so.source = &source + return so, nil } } @@ -49,15 +74,23 @@ func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // Make a copy of the AST to be optimized. optimized := ast.Copy(a.NativeRep()) + if opt.source != nil { + optimized = ast.NewAST(optimized.Expr(), ast.NewSourceInfo(*opt.source)) + } ids := newIDGenerator(ast.MaxID(a.NativeRep())) + source := a.Source() + if opt.source != nil { + source = *opt.source + } // Create the optimizer context, could be pooled in the future. - issues := NewIssues(common.NewErrors(a.Source())) + issues := NewIssues(common.NewErrors(source)) baseFac := ast.NewExprFactory() exprFac := &optimizerExprFactory{ - idGenerator: ids, - fac: baseFac, - sourceInfo: optimized.SourceInfo(), + idGenerator: ids, + fac: baseFac, + sourceInfo: optimized.SourceInfo(), + mergeSourceInfo: opt.source != nil, } ctx := &OptimizerContext{ optimizerExprFactory: exprFac, @@ -75,12 +108,12 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { freshIDGen := newIDGenerator(0) info := optimized.SourceInfo() expr := optimized.Expr() - normalizeIDs(freshIDGen.renumberStable, expr, info) + normalizeIDs(freshIDGen.renumberStable, expr, info, exprFac.mergeSourceInfo) cleanupMacroRefs(expr, info) // Recheck the updated expression for any possible type-agreement or validation errors. parsed := &Ast{ - source: a.Source(), + source: source, impl: ast.NewAST(expr, info)} checked, iss := ctx.Check(parsed) if iss.Err() != nil { @@ -91,15 +124,19 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // Return the optimized result. return &Ast{ - source: a.Source(), + source: source, impl: optimized, }, nil } // normalizeIDs ensures that the metadata present with an AST is reset in a manner such // that the ids within the expression correspond to the ids within macros. -func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) { +func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo, mergeSourceInfo bool) { optimized.RenumberIDs(idGen) + if mergeSourceInfo { + info.RenumberIDs(idGen) + } + if len(info.MacroCalls()) == 0 { return } @@ -229,8 +266,9 @@ type ASTOptimizer interface { type optimizerExprFactory struct { *idGenerator - fac ast.ExprFactory - sourceInfo *ast.SourceInfo + fac ast.ExprFactory + sourceInfo *ast.SourceInfo + mergeSourceInfo bool } // NewAST creates an AST from the current expression using the tracked source info which @@ -249,7 +287,7 @@ func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) defer func() { opt.seed = idGen.nextID() }() copyExpr := opt.fac.CopyExpr(a.Expr()) copyInfo := ast.CopySourceInfo(a.SourceInfo()) - normalizeIDs(idGen.renumberStable, copyExpr, copyInfo) + normalizeIDs(idGen.renumberStable, copyExpr, copyInfo, opt.mergeSourceInfo) return copyExpr, copyInfo } @@ -260,6 +298,11 @@ func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr { for macroID, call := range copyInfo.MacroCalls() { opt.SetMacroCall(macroID, call) } + if opt.mergeSourceInfo { + for id, offset := range copyInfo.OffsetRanges() { + opt.sourceInfo.SetOffsetRange(id, offset) + } + } return copyExpr } diff --git a/cel/optimizer_test.go b/cel/optimizer_test.go index f406e7c75..07a007c72 100644 --- a/cel/optimizer_test.go +++ b/cel/optimizer_test.go @@ -45,7 +45,10 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()}) + opt, err := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()}) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optAST, iss := opt.Optimize(e, exprAST) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -186,7 +189,10 @@ func TestStaticOptimizerNewAST(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compile(%q) failed: %v", tc, iss.Err()) } - opt := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optAST, iss := opt.Optimize(e, exprAST) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -204,7 +210,10 @@ func TestStaticOptimizerNewAST(t *testing.T) { func TestStaticOptimizerNilAST(t *testing.T) { env := optimizerEnv(t) - opt := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optAST, iss := opt.Optimize(env, nil) if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") { t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss) diff --git a/checker/checker.go b/checker/checker.go index 3bb61c19c..280150537 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -67,6 +67,15 @@ func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.E for id, t := range c.TypeMap() { c.SetType(id, substitute(c.mappings, t, true)) } + // Remove source info for IDs without a corresponding AST node. This can happen because + // check() deletes some nodes while rewriting the AST. For example the Select operand is + // deleted when a variable reference is replaced with a Ident expression. + ids := c.AST.IDs() + for id := range c.AST.SourceInfo().OffsetRanges() { + if !ids[id] { + c.AST.SourceInfo().ClearOffsetRange(id) + } + } return c.AST, errs } diff --git a/checker/checker_test.go b/checker/checker_test.go index 025de9615..c6133db39 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/debug" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" @@ -2553,6 +2554,18 @@ func TestCheck(t *testing.T) { t.Errorf("Expected error not thrown: %s", tc.err) } + astIDs := cAst.IDs() + missingIDs := []int64{} + for id := range cAst.SourceInfo().OffsetRanges() { + if !astIDs[id] { + missingIDs = append(missingIDs, id) + } + } + if len(missingIDs) > 0 { + t.Errorf("SourceInfo has offset range for IDs %v, but no such nodes exists in AST: %s", + missingIDs, debug.ToDebugStringWithIDs(cAst.Expr())) + } + actual := cAst.GetType(pAst.Expr().ID()) if tc.err == "" { if actual == nil || !actual.IsEquivalentType(tc.outType) { diff --git a/common/ast/ast.go b/common/ast/ast.go index 62c09cfc6..d10d4eca7 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -16,6 +16,8 @@ package ast import ( + "slices" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -160,6 +162,16 @@ func MaxID(a *AST) int64 { return visitor.maxID + 1 } +// IDs returns the set of AST node IDs, including macro calls. +func (a *AST) IDs() map[int64]bool { + visitor := make(idVisitor) + PostOrderVisit(a.Expr(), visitor) + for _, call := range a.SourceInfo().MacroCalls() { + PostOrderVisit(call, visitor) + } + return visitor +} + // Heights computes the heights of all AST expressions and returns a map from expression id to height. func Heights(a *AST) map[int64]int { visitor := make(heightVisitor) @@ -232,6 +244,23 @@ type SourceInfo struct { macroCalls map[int64]Expr } +// RenumberIDs performs an in-place update of the expression IDs within the SourceInfo. +func (s *SourceInfo) RenumberIDs(idGen IDGenerator) { + if s == nil { + return + } + oldIDs := []int64{} + for id := range s.offsetRanges { + oldIDs = append(oldIDs, id) + } + slices.Sort(oldIDs) + newRanges := make(map[int64]OffsetRange) + for _, id := range oldIDs { + newRanges[idGen(id)] = s.offsetRanges[id] + } + s.offsetRanges = newRanges +} + // SyntaxVersion returns the syntax version associated with the text expression. func (s *SourceInfo) SyntaxVersion() string { if s == nil { @@ -533,3 +562,13 @@ func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int { } return max } + +type idVisitor map[int64]bool + +func (v idVisitor) VisitExpr(e Expr) { + v[e.ID()] = true +} + +func (v idVisitor) VisitEntryExpr(e EntryExpr) { + v[e.ID()] = true +} diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index 7a1c6a141..0a86c3936 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -16,6 +16,7 @@ package ast_test import ( "fmt" + "maps" "reflect" "testing" @@ -388,3 +389,40 @@ func (src *mockSource) OffsetLocation(offset int32) (common.Location, bool) { } return src.Source.OffsetLocation(offset) } + +func TestSourceInfoRenumberIDs(t *testing.T) { + info := ast.NewSourceInfo(nil) + for old := int64(1); old <= 5; old++ { + info.SetOffsetRange(old, ast.OffsetRange{Start: int32(old), Stop: int32(old) + 1}) + } + original := make(map[int64]ast.OffsetRange) + maps.Copy(original, info.OffsetRanges()) + + // Verify the renumbering is stable. + var next int64 = 101 + idMap := make(map[int64]int64) + idGen := func(old int64) int64 { + if _, found := idMap[old]; !found { + idMap[old] = next + next = next + 1 + } + return idMap[old] + } + info.RenumberIDs(idGen) + + if len(info.OffsetRanges()) != 5 { + t.Errorf("got %d offset ranges, wanted 5", len(info.OffsetRanges())) + } + + for old := int64(1); old <= 5; old++ { + want := original[old] + new := old + 100 + got, found := info.GetOffsetRange(new) + if !found { + t.Errorf("offset range for ID %d not found", new) + } + if got != want { + t.Errorf("offset range for ID %d incorrect; got %v, want %v", new, got, want) + } + } +} diff --git a/common/debug/debug.go b/common/debug/debug.go index 75f5f0d63..fbc847f0c 100644 --- a/common/debug/debug.go +++ b/common/debug/debug.go @@ -312,3 +312,18 @@ func (w *debugWriter) removeIndent() { func (w *debugWriter) String() string { return w.buffer.String() } + +type idAdorner struct{} + +func (a *idAdorner) GetMetadata(elem any) string { + e, isExpr := elem.(ast.Expr) + if !isExpr { + return "" + } + return fmt.Sprintf("@id:%d ", e.ID()) +} + +// ToDebugStringWithIDs returns a string representation with AST node IDs. +func ToDebugStringWithIDs(e ast.Expr) string { + return ToAdornedDebugString(e, &idAdorner{}) +} diff --git a/ext/sets_test.go b/ext/sets_test.go index 74163db62..4ea409b67 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -430,7 +430,10 @@ func TestSetsMembershipRewriter(t *testing.T) { if err != nil { t.Fatalf("NewSetMembershipOptimizer() failed with error: %v", err) } - opt := cel.NewStaticOptimizer(setsOpt) + opt, err := cel.NewStaticOptimizer(setsOpt) + if err != nil { + t.Fatalf("cel.NewStaticOptimizer() failed with error: %v", err) + } optAST, iss := opt.Optimize(env, a) if iss.Err() != nil { t.Fatalf("opt.Optimize() failed: %v", iss.Err()) diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 97650ff92..6b6ba5b93 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -57,6 +57,7 @@ go_test( size = "small", srcs = [ "compiler_test.go", + "composer_test.go", "config_test.go", "helper_test.go", "parser_test.go", @@ -72,6 +73,8 @@ go_test( "//test/proto3pb:go_default_library", "@in_yaml_go_yaml_v3//:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", ], ) diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 4df6928ee..5ad4d801d 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -20,6 +20,7 @@ import ( "strings" "testing" + "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "github.com/google/cel-go/cel" @@ -73,6 +74,8 @@ func TestRuleComposerUnnest(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compose(rule) failed: %v", iss.Err()) } + policy := parsePolicy(t, tc.name, []ParserOption{}) + verifySourceInfoCoverage(t, policy, ast) unparsed, err := cel.AstToString(ast) if err != nil { t.Fatalf("cel.AstToString() failed: %v", err) @@ -507,6 +510,7 @@ func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption, t.Fatalf("env.Extend() with config options %v, failed: %v", config, err) } ast, iss := Compile(env, policy, compilerOpts...) + verifySourceInfoCoverage(t, policy, ast) return env, ast, iss } @@ -516,3 +520,81 @@ func normalize(s string) string { strings.ReplaceAll(s, " ", ""), "\n", ""), "\t", "") } + +func verifySourceInfoCoverage(t testing.TB, policy *Policy, ast *cel.Ast) { + t.Helper() + info := ast.SourceInfo() + if info == nil { + return + } + + exprLines := exprLinesFromPolicy(policy) + coveredLines := make(map[int]bool) + ids := ast.NativeRep().IDs() + for id, offset := range info.GetPositions() { + // Check that each position in the SourceInfo corresponds to a valid AST node. + if !ids[id] { + t.Errorf("id %d not found in AST", id) + } + loc, found := ast.Source().OffsetLocation(offset) + if found { + coveredLines[loc.Line()] = true + } else { + t.Errorf("invalid source location for offset %d", offset) + } + } + // Verify that each source line inside an expression is covered by the at least one node in the + // AST. + for line := range exprLines { + if !coveredLines[line] { + t.Errorf("Line %d expected to be covered by SourceInfo, but was not", line) + } + } + + if t.Failed() { + checked, err := cel.AstToCheckedExpr(ast) + if err != nil { + t.Logf("cel.AstToCheckedExpr() failed: %v", err) + } else { + t.Logf("AST:\n%s", prototext.Format(checked.GetExpr())) + } + } +} + +// exprLinesFromPolicy returns a set of line numbers within a policy where expressions (variables, +// conditions, etc.) are defined. +func exprLinesFromPolicy(policy *Policy) map[int]bool { + lines := make(map[int]bool) + addExpectedLines := func(vs ValueString) { + if offset, found := policy.SourceInfo().GetOffsetRange(vs.ID); found { + startLoc, foundStart := policy.Source().OffsetLocation(offset.Start) + // Multiline strings can span multiple lines, but the SourceInfo will only contain the start + // position of the expression. So just skip the check if the expression contains a multiline + // string literal. + hasMultiline := strings.Contains(vs.Value, "'''") || strings.Contains(vs.Value, "\"\"\"") + if foundStart && !hasMultiline { + numLines := strings.Count(vs.Value, "\n") + for i := 0; i <= numLines; i++ { + lines[startLoc.Line()+i] = true + } + } + } + } + var traverseRule func(r *Rule) + traverseRule = func(r *Rule) { + for _, v := range r.Variables() { + addExpectedLines(v.Expression()) + } + for _, m := range r.Matches() { + addExpectedLines(m.Condition()) + if m.HasOutput() { + addExpectedLines(m.Output()) + } + if m.HasRule() { + traverseRule(m.Rule()) + } + } + } + traverseRule(policy.Rule()) + return lines +} diff --git a/policy/composer.go b/policy/composer.go index 762472487..300de8210 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" @@ -72,11 +73,21 @@ type RuleComposer struct { // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { ruleRoot, _ := c.env.Compile("true") + composer := &ruleComposerImpl{ rule: r, varIndices: []varIndex{}, } - opt := cel.NewStaticOptimizer(composer) + // ruleRoot is a placeholder expression used as the root of the AST before optimization. Because + // StaticOptimizer would normally copy the source info from it, we must use OptimizeWithSource + // to override the correct source. + source := recoverOriginalSource(r) + opt, err := cel.NewStaticOptimizer(composer, cel.OptimizeWithSource(source)) + if err != nil { + errs := common.NewErrors(source) + errs.ReportErrorString(common.NoLocation, err.Error()) + return nil, cel.NewIssues(errs) + } ast, iss := opt.Optimize(c.env, ruleRoot) if iss.Err() != nil { return nil, iss @@ -85,10 +96,29 @@ func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { varIndices: []varIndex{}, exprUnnestHeight: c.exprUnnestHeight, } - opt = cel.NewStaticOptimizer(unnester) + opt, err = cel.NewStaticOptimizer(unnester) + if err != nil { + errs := common.NewErrors(source) + errs.ReportErrorString(common.NoLocation, err.Error()) + return nil, cel.NewIssues(errs) + } return opt.Optimize(c.env, ast) } +func recoverOriginalSource(r *CompiledRule) cel.Source { + match := r.Matches()[0] + var source cel.Source + if match.Output() != nil { + source = match.Output().Expr().Source() + } else { + source = match.Condition().Source() + } + if rel, ok := source.(*RelativeSource); ok { + source = rel.Source + } + return source +} + type varIndex struct { index int indexVar string diff --git a/policy/composer_test.go b/policy/composer_test.go new file mode 100644 index 000000000..14cc3c772 --- /dev/null +++ b/policy/composer_test.go @@ -0,0 +1,77 @@ +package policy + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/ext" + "google.golang.org/protobuf/encoding/prototext" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +type positionChecker struct { + t *testing.T + si *exprpb.SourceInfo +} + +func (v *positionChecker) VisitExpr(e ast.Expr) { + if e.AsCall() != nil && e.AsCall().FunctionName() == "_?_:_" { + // Ignore ternary operator as it's a synthetic node. + return + } + if _, found := v.si.Positions[e.ID()]; !found { + pbExpr, _ := ast.ExprToProto(e) + v.t.Errorf("No position found for expression ID: %d, %s", e.ID(), prototext.Format(pbExpr)) + } +} + +func (v *positionChecker) VisitEntryExpr(e ast.EntryExpr) { + if _, found := v.si.Positions[e.ID()]; !found { + v.t.Errorf("No position found for expression ID: %d", e.ID()) + } +} + +func TestCompose_SourceInfo(t *testing.T) { + policyYAML := `name: test_policy +rule: + match: + - condition: "2 == 1" + output: "'hi'" + - output: "'hello' + ' world'" +` + src := StringSource(policyYAML, "test_policy.yaml") + parser, err := NewParser() + if err != nil { + t.Fatalf("NewParser() failed: %v", err) + } + policy, iss := parser.Parse(src) + if iss.Err() != nil { + t.Fatalf("parser.Parse() failed: %v", iss.Err()) + } + + env, err := cel.NewEnv(cel.OptionalTypes(), ext.Bindings()) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + compiledRule, iss := CompileRule(env, policy) + if iss.Err() != nil { + t.Fatalf("CompileRule() failed: %v", iss.Err()) + } + composer, err := NewRuleComposer(env) + if err != nil { + t.Fatalf("NewRuleComposer() failed: %v", err) + } + compAST, iss := composer.Compose(compiledRule) + if iss.Err() != nil { + t.Fatalf("composer.Compose() failed: %v", iss.Err()) + } + + si := compAST.SourceInfo() + if si.Location != "test_policy.yaml" { + t.Errorf("SourceInfo.Location got %q, wanted test_policy.yaml", si.Location) + } + checker := &positionChecker{t: t, si: si} + ast.PostOrderVisit(compAST.NativeRep().Expr(), checker) +}