From eb6abd415743b1a864285606ed9743056e58e637 Mon Sep 17 00:00:00 2001 From: abenea Date: Fri, 16 Jan 2026 16:20:35 +0100 Subject: [PATCH 1/4] Clean up unused source info after checker rewrites the AST. --- checker/checker.go | 9 +++++++++ checker/checker_test.go | 13 +++++++++++++ common/ast/ast.go | 20 ++++++++++++++++++++ common/debug/debug.go | 15 +++++++++++++++ 4 files changed, 57 insertions(+) diff --git a/checker/checker.go b/checker/checker.go index 3bb61c19c..280150537 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -67,6 +67,15 @@ func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.E for id, t := range c.TypeMap() { c.SetType(id, substitute(c.mappings, t, true)) } + // Remove source info for IDs without a corresponding AST node. This can happen because + // check() deletes some nodes while rewriting the AST. For example the Select operand is + // deleted when a variable reference is replaced with a Ident expression. + ids := c.AST.IDs() + for id := range c.AST.SourceInfo().OffsetRanges() { + if !ids[id] { + c.AST.SourceInfo().ClearOffsetRange(id) + } + } return c.AST, errs } diff --git a/checker/checker_test.go b/checker/checker_test.go index 025de9615..c6133db39 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/debug" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" @@ -2553,6 +2554,18 @@ func TestCheck(t *testing.T) { t.Errorf("Expected error not thrown: %s", tc.err) } + astIDs := cAst.IDs() + missingIDs := []int64{} + for id := range cAst.SourceInfo().OffsetRanges() { + if !astIDs[id] { + missingIDs = append(missingIDs, id) + } + } + if len(missingIDs) > 0 { + t.Errorf("SourceInfo has offset range for IDs %v, but no such nodes exists in AST: %s", + missingIDs, debug.ToDebugStringWithIDs(cAst.Expr())) + } + actual := cAst.GetType(pAst.Expr().ID()) if tc.err == "" { if actual == nil || !actual.IsEquivalentType(tc.outType) { diff --git a/common/ast/ast.go b/common/ast/ast.go index 62c09cfc6..aa2884c63 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -160,6 +160,16 @@ func MaxID(a *AST) int64 { return visitor.maxID + 1 } +// IDs returns the set of AST node IDs, including macro calls. +func (a *AST) IDs() map[int64]bool { + visitor := make(idVisitor) + PostOrderVisit(a.Expr(), visitor) + for _, call := range a.SourceInfo().MacroCalls() { + PostOrderVisit(call, visitor) + } + return visitor +} + // Heights computes the heights of all AST expressions and returns a map from expression id to height. func Heights(a *AST) map[int64]int { visitor := make(heightVisitor) @@ -533,3 +543,13 @@ func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int { } return max } + +type idVisitor map[int64]bool + +func (v idVisitor) VisitExpr(e Expr) { + v[e.ID()] = true +} + +func (v idVisitor) VisitEntryExpr(e EntryExpr) { + v[e.ID()] = true +} diff --git a/common/debug/debug.go b/common/debug/debug.go index 75f5f0d63..fbc847f0c 100644 --- a/common/debug/debug.go +++ b/common/debug/debug.go @@ -312,3 +312,18 @@ func (w *debugWriter) removeIndent() { func (w *debugWriter) String() string { return w.buffer.String() } + +type idAdorner struct{} + +func (a *idAdorner) GetMetadata(elem any) string { + e, isExpr := elem.(ast.Expr) + if !isExpr { + return "" + } + return fmt.Sprintf("@id:%d ", e.ID()) +} + +// ToDebugStringWithIDs returns a string representation with AST node IDs. +func ToDebugStringWithIDs(e ast.Expr) string { + return ToAdornedDebugString(e, &idAdorner{}) +} From 7b7c04df74f05f73fa73158be3e35c9f80e68275 Mon Sep 17 00:00:00 2001 From: Andrei Benea Date: Tue, 6 Jan 2026 17:31:08 +0100 Subject: [PATCH 2/4] Preserve source information during policy composition. The old implementation outputs the wrong source information for policy files derived from the dummy AST created by RuleComposer.Compose. This change fixes that by preserving the correct source and merging the offset ranges of the rule match expressions inserted by the composer optimizer into the final AST. --- cel/env.go | 8 ++++ cel/optimizer.go | 60 +++++++++++++++++++++++++----- policy/BUILD.bazel | 3 ++ policy/compiler_test.go | 82 +++++++++++++++++++++++++++++++++++++++++ policy/composer.go | 24 +++++++++++- policy/composer_test.go | 77 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 242 insertions(+), 12 deletions(-) create mode 100644 policy/composer_test.go diff --git a/cel/env.go b/cel/env.go index 58819e872..8d220932c 100644 --- a/cel/env.go +++ b/cel/env.go @@ -48,6 +48,14 @@ type Ast struct { impl *celast.AST } +// NewAst creates a new Ast value from a source and its native representation. +func NewAst(source Source, impl *celast.AST) *Ast { + return &Ast{ + source: source, + impl: impl, + } +} + // NativeRep converts the AST to a Go-native representation. func (ast *Ast) NativeRep() *celast.AST { if ast == nil { diff --git a/cel/optimizer.go b/cel/optimizer.go index 9a2a97a64..61757b981 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -32,14 +32,27 @@ import ( // Note: source position information is best-effort and likely wrong, but optimized expressions // should be suitable for calls to parser.Unparse. type StaticOptimizer struct { - optimizers []ASTOptimizer + optimizers []ASTOptimizer + mergeSourceInfo bool } // NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied // to a checked expression. func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { return &StaticOptimizer{ - optimizers: optimizers, + optimizers: optimizers, + mergeSourceInfo: false, + } +} + +// NewStaticOptimizerWithSourceInfoMerging creates a StaticOptimizer with a sequence of +// ASTOptimizer's to be applied to a checked expression. The source info of the optimized AST will +// be merged with the source info of the input AST which is useful when merging multiple expressions +// defined in the same file. Used only for policy composition. +func NewStaticOptimizerWithSourceInfoMerging(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + mergeSourceInfo: true, } } @@ -55,9 +68,10 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { issues := NewIssues(common.NewErrors(a.Source())) baseFac := ast.NewExprFactory() exprFac := &optimizerExprFactory{ - idGenerator: ids, - fac: baseFac, - sourceInfo: optimized.SourceInfo(), + idGenerator: ids, + fac: baseFac, + sourceInfo: optimized.SourceInfo(), + mergeSourceInfo: opt.mergeSourceInfo, } ctx := &OptimizerContext{ optimizerExprFactory: exprFac, @@ -75,7 +89,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { freshIDGen := newIDGenerator(0) info := optimized.SourceInfo() expr := optimized.Expr() - normalizeIDs(freshIDGen.renumberStable, expr, info) + normalizeIDs(freshIDGen.renumberStable, expr, info, exprFac.mergeSourceInfo) cleanupMacroRefs(expr, info) // Recheck the updated expression for any possible type-agreement or validation errors. @@ -96,10 +110,30 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { }, nil } +func updateOffsetRanges(idGen ast.IDGenerator, info *ast.SourceInfo) { + newRanges := make(map[int64]ast.OffsetRange) + sortedOldIDs := []int64{} + for oldID := range info.OffsetRanges() { + sortedOldIDs = append(sortedOldIDs, oldID) + } + sort.Slice(sortedOldIDs, func(i, j int) bool { return sortedOldIDs[i] < sortedOldIDs[j] }) + for _, oldID := range sortedOldIDs { + offsetRange, _ := info.GetOffsetRange(oldID) + newRanges[idGen(oldID)] = offsetRange + info.ClearOffsetRange(oldID) + } + for newID, offsetRange := range newRanges { + info.SetOffsetRange(newID, offsetRange) + } +} + // normalizeIDs ensures that the metadata present with an AST is reset in a manner such // that the ids within the expression correspond to the ids within macros. -func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) { +func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo, mergeSourceInfo bool) { optimized.RenumberIDs(idGen) + if mergeSourceInfo { + updateOffsetRanges(idGen, info) + } if len(info.MacroCalls()) == 0 { return } @@ -229,8 +263,9 @@ type ASTOptimizer interface { type optimizerExprFactory struct { *idGenerator - fac ast.ExprFactory - sourceInfo *ast.SourceInfo + fac ast.ExprFactory + sourceInfo *ast.SourceInfo + mergeSourceInfo bool } // NewAST creates an AST from the current expression using the tracked source info which @@ -249,7 +284,7 @@ func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) defer func() { opt.seed = idGen.nextID() }() copyExpr := opt.fac.CopyExpr(a.Expr()) copyInfo := ast.CopySourceInfo(a.SourceInfo()) - normalizeIDs(idGen.renumberStable, copyExpr, copyInfo) + normalizeIDs(idGen.renumberStable, copyExpr, copyInfo, opt.mergeSourceInfo) return copyExpr, copyInfo } @@ -260,6 +295,11 @@ func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr { for macroID, call := range copyInfo.MacroCalls() { opt.SetMacroCall(macroID, call) } + if opt.mergeSourceInfo { + for id, offset := range copyInfo.OffsetRanges() { + opt.sourceInfo.SetOffsetRange(id, offset) + } + } return copyExpr } diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 97650ff92..6b6ba5b93 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -57,6 +57,7 @@ go_test( size = "small", srcs = [ "compiler_test.go", + "composer_test.go", "config_test.go", "helper_test.go", "parser_test.go", @@ -72,6 +73,8 @@ go_test( "//test/proto3pb:go_default_library", "@in_yaml_go_yaml_v3//:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", ], ) diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 4df6928ee..5ad4d801d 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -20,6 +20,7 @@ import ( "strings" "testing" + "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "github.com/google/cel-go/cel" @@ -73,6 +74,8 @@ func TestRuleComposerUnnest(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compose(rule) failed: %v", iss.Err()) } + policy := parsePolicy(t, tc.name, []ParserOption{}) + verifySourceInfoCoverage(t, policy, ast) unparsed, err := cel.AstToString(ast) if err != nil { t.Fatalf("cel.AstToString() failed: %v", err) @@ -507,6 +510,7 @@ func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption, t.Fatalf("env.Extend() with config options %v, failed: %v", config, err) } ast, iss := Compile(env, policy, compilerOpts...) + verifySourceInfoCoverage(t, policy, ast) return env, ast, iss } @@ -516,3 +520,81 @@ func normalize(s string) string { strings.ReplaceAll(s, " ", ""), "\n", ""), "\t", "") } + +func verifySourceInfoCoverage(t testing.TB, policy *Policy, ast *cel.Ast) { + t.Helper() + info := ast.SourceInfo() + if info == nil { + return + } + + exprLines := exprLinesFromPolicy(policy) + coveredLines := make(map[int]bool) + ids := ast.NativeRep().IDs() + for id, offset := range info.GetPositions() { + // Check that each position in the SourceInfo corresponds to a valid AST node. + if !ids[id] { + t.Errorf("id %d not found in AST", id) + } + loc, found := ast.Source().OffsetLocation(offset) + if found { + coveredLines[loc.Line()] = true + } else { + t.Errorf("invalid source location for offset %d", offset) + } + } + // Verify that each source line inside an expression is covered by the at least one node in the + // AST. + for line := range exprLines { + if !coveredLines[line] { + t.Errorf("Line %d expected to be covered by SourceInfo, but was not", line) + } + } + + if t.Failed() { + checked, err := cel.AstToCheckedExpr(ast) + if err != nil { + t.Logf("cel.AstToCheckedExpr() failed: %v", err) + } else { + t.Logf("AST:\n%s", prototext.Format(checked.GetExpr())) + } + } +} + +// exprLinesFromPolicy returns a set of line numbers within a policy where expressions (variables, +// conditions, etc.) are defined. +func exprLinesFromPolicy(policy *Policy) map[int]bool { + lines := make(map[int]bool) + addExpectedLines := func(vs ValueString) { + if offset, found := policy.SourceInfo().GetOffsetRange(vs.ID); found { + startLoc, foundStart := policy.Source().OffsetLocation(offset.Start) + // Multiline strings can span multiple lines, but the SourceInfo will only contain the start + // position of the expression. So just skip the check if the expression contains a multiline + // string literal. + hasMultiline := strings.Contains(vs.Value, "'''") || strings.Contains(vs.Value, "\"\"\"") + if foundStart && !hasMultiline { + numLines := strings.Count(vs.Value, "\n") + for i := 0; i <= numLines; i++ { + lines[startLoc.Line()+i] = true + } + } + } + } + var traverseRule func(r *Rule) + traverseRule = func(r *Rule) { + for _, v := range r.Variables() { + addExpectedLines(v.Expression()) + } + for _, m := range r.Matches() { + addExpectedLines(m.Condition()) + if m.HasOutput() { + addExpectedLines(m.Output()) + } + if m.HasRule() { + traverseRule(m.Rule()) + } + } + } + traverseRule(policy.Rule()) + return lines +} diff --git a/policy/composer.go b/policy/composer.go index 762472487..eedc2f65b 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -71,12 +71,18 @@ type RuleComposer struct { // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { - ruleRoot, _ := c.env.Compile("true") + // dummyExpr is a placeholder expression used as the root of the AST before optimization. Because + // StaticOptimizer copies the source info from it, we must set the correct source info here. + source := recoverOriginalSource(r) + sourceInfo := ast.NewSourceInfo(source) + dummyExpr := ast.NewExprFactory().NewLiteral(1, types.True) + ruleRoot := cel.NewAst(source, ast.NewAST(dummyExpr, sourceInfo)) + composer := &ruleComposerImpl{ rule: r, varIndices: []varIndex{}, } - opt := cel.NewStaticOptimizer(composer) + opt := cel.NewStaticOptimizerWithSourceInfoMerging(composer) ast, iss := opt.Optimize(c.env, ruleRoot) if iss.Err() != nil { return nil, iss @@ -89,6 +95,20 @@ func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { return opt.Optimize(c.env, ast) } +func recoverOriginalSource(r *CompiledRule) cel.Source { + match := r.Matches()[0] + var source cel.Source + if match.Output() != nil { + source = match.Output().Expr().Source() + } else { + source = match.Condition().Source() + } + if rel, ok := source.(*RelativeSource); ok { + source = rel.Source + } + return source +} + type varIndex struct { index int indexVar string diff --git a/policy/composer_test.go b/policy/composer_test.go new file mode 100644 index 000000000..14cc3c772 --- /dev/null +++ b/policy/composer_test.go @@ -0,0 +1,77 @@ +package policy + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/ext" + "google.golang.org/protobuf/encoding/prototext" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +type positionChecker struct { + t *testing.T + si *exprpb.SourceInfo +} + +func (v *positionChecker) VisitExpr(e ast.Expr) { + if e.AsCall() != nil && e.AsCall().FunctionName() == "_?_:_" { + // Ignore ternary operator as it's a synthetic node. + return + } + if _, found := v.si.Positions[e.ID()]; !found { + pbExpr, _ := ast.ExprToProto(e) + v.t.Errorf("No position found for expression ID: %d, %s", e.ID(), prototext.Format(pbExpr)) + } +} + +func (v *positionChecker) VisitEntryExpr(e ast.EntryExpr) { + if _, found := v.si.Positions[e.ID()]; !found { + v.t.Errorf("No position found for expression ID: %d", e.ID()) + } +} + +func TestCompose_SourceInfo(t *testing.T) { + policyYAML := `name: test_policy +rule: + match: + - condition: "2 == 1" + output: "'hi'" + - output: "'hello' + ' world'" +` + src := StringSource(policyYAML, "test_policy.yaml") + parser, err := NewParser() + if err != nil { + t.Fatalf("NewParser() failed: %v", err) + } + policy, iss := parser.Parse(src) + if iss.Err() != nil { + t.Fatalf("parser.Parse() failed: %v", iss.Err()) + } + + env, err := cel.NewEnv(cel.OptionalTypes(), ext.Bindings()) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + compiledRule, iss := CompileRule(env, policy) + if iss.Err() != nil { + t.Fatalf("CompileRule() failed: %v", iss.Err()) + } + composer, err := NewRuleComposer(env) + if err != nil { + t.Fatalf("NewRuleComposer() failed: %v", err) + } + compAST, iss := composer.Compose(compiledRule) + if iss.Err() != nil { + t.Fatalf("composer.Compose() failed: %v", iss.Err()) + } + + si := compAST.SourceInfo() + if si.Location != "test_policy.yaml" { + t.Errorf("SourceInfo.Location got %q, wanted test_policy.yaml", si.Location) + } + checker := &positionChecker{t: t, si: si} + ast.PostOrderVisit(compAST.NativeRep().Expr(), checker) +} From 1e61c526250e940937a308baafeb8ed1ce1fc08b Mon Sep 17 00:00:00 2001 From: abenea Date: Fri, 16 Jan 2026 22:17:17 +0100 Subject: [PATCH 3/4] Introduce optimizer options. --- cel/env.go | 8 ------- cel/folding_test.go | 25 ++++++++++++++++---- cel/inlining_test.go | 20 ++++++++++++---- cel/optimizer.go | 55 +++++++++++++++++++++++++++++-------------- cel/optimizer_test.go | 15 +++++++++--- ext/sets_test.go | 5 +++- policy/composer.go | 26 +++++++++++++------- 7 files changed, 107 insertions(+), 47 deletions(-) diff --git a/cel/env.go b/cel/env.go index 8d220932c..58819e872 100644 --- a/cel/env.go +++ b/cel/env.go @@ -48,14 +48,6 @@ type Ast struct { impl *celast.AST } -// NewAst creates a new Ast value from a source and its native representation. -func NewAst(source Source, impl *celast.AST) *Ast { - return &Ast{ - source: source, - impl: impl, - } -} - // NativeRep converts the AST to a Go-native representation. func (ast *Ast) NativeRep() *celast.AST { if ast == nil { diff --git a/cel/folding_test.go b/cel/folding_test.go index aa671badc..08d5f5177 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -366,7 +366,10 @@ func TestConstantFoldingOptimizer(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -441,7 +444,10 @@ func TestConstantFoldingCallsWithSideEffects(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if tc.error != "" { if iss.Err() == nil { @@ -508,7 +514,10 @@ func TestConstantFoldingOptimizerMacroElimination(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -570,7 +579,10 @@ func TestConstantFoldingOptimizerWithLimit(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -828,7 +840,10 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt := NewStaticOptimizer(folder) + opt, err := NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) diff --git a/cel/inlining_test.go b/cel/inlining_test.go index 856eaafcb..4dec1074a 100644 --- a/cel/inlining_test.go +++ b/cel/inlining_test.go @@ -220,7 +220,10 @@ func TestInliningOptimizer(t *testing.T) { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -236,7 +239,10 @@ func TestInliningOptimizer(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt = cel.NewStaticOptimizer(folder) + opt, err = cel.NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss = opt.Optimize(e, optimized) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -727,7 +733,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -743,7 +752,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) { if err != nil { t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) } - opt = cel.NewStaticOptimizer(folder) + opt, err = cel.NewStaticOptimizer(folder) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optimized, iss = opt.Optimize(e, optimized) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) diff --git a/cel/optimizer.go b/cel/optimizer.go index 61757b981..111d74e67 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -15,6 +15,7 @@ package cel import ( + "fmt" "sort" "github.com/google/cel-go/common" @@ -32,27 +33,38 @@ import ( // Note: source position information is best-effort and likely wrong, but optimized expressions // should be suitable for calls to parser.Unparse. type StaticOptimizer struct { - optimizers []ASTOptimizer - mergeSourceInfo bool + optimizers []ASTOptimizer + source *Source } +type OptimizerOption func(*StaticOptimizer) (*StaticOptimizer, error) + // NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied // to a checked expression. -func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { - return &StaticOptimizer{ - optimizers: optimizers, - mergeSourceInfo: false, +func NewStaticOptimizer(options ...any) (*StaticOptimizer, error) { + so := &StaticOptimizer{} + var err error + for _, opt := range options { + switch v := opt.(type) { + case ASTOptimizer: + so.optimizers = append(so.optimizers, v) + case OptimizerOption: + so, err = v(so) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported option: %v", v) + } } + return so, nil } -// NewStaticOptimizerWithSourceInfoMerging creates a StaticOptimizer with a sequence of -// ASTOptimizer's to be applied to a checked expression. The source info of the optimized AST will -// be merged with the source info of the input AST which is useful when merging multiple expressions -// defined in the same file. Used only for policy composition. -func NewStaticOptimizerWithSourceInfoMerging(optimizers ...ASTOptimizer) *StaticOptimizer { - return &StaticOptimizer{ - optimizers: optimizers, - mergeSourceInfo: true, +// OptimizeWithSource overrides the source used by the optimizer. +func OptimizeWithSource(source Source) OptimizerOption { + return func(so *StaticOptimizer) (*StaticOptimizer, error) { + so.source = &source + return so, nil } } @@ -62,16 +74,23 @@ func NewStaticOptimizerWithSourceInfoMerging(optimizers ...ASTOptimizer) *Static func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // Make a copy of the AST to be optimized. optimized := ast.Copy(a.NativeRep()) + if opt.source != nil { + optimized = ast.NewAST(optimized.Expr(), ast.NewSourceInfo(*opt.source)) + } ids := newIDGenerator(ast.MaxID(a.NativeRep())) + source := a.Source() + if opt.source != nil { + source = *opt.source + } // Create the optimizer context, could be pooled in the future. - issues := NewIssues(common.NewErrors(a.Source())) + issues := NewIssues(common.NewErrors(source)) baseFac := ast.NewExprFactory() exprFac := &optimizerExprFactory{ idGenerator: ids, fac: baseFac, sourceInfo: optimized.SourceInfo(), - mergeSourceInfo: opt.mergeSourceInfo, + mergeSourceInfo: opt.source != nil, } ctx := &OptimizerContext{ optimizerExprFactory: exprFac, @@ -94,7 +113,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // Recheck the updated expression for any possible type-agreement or validation errors. parsed := &Ast{ - source: a.Source(), + source: source, impl: ast.NewAST(expr, info)} checked, iss := ctx.Check(parsed) if iss.Err() != nil { @@ -105,7 +124,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // Return the optimized result. return &Ast{ - source: a.Source(), + source: source, impl: optimized, }, nil } diff --git a/cel/optimizer_test.go b/cel/optimizer_test.go index f406e7c75..07a007c72 100644 --- a/cel/optimizer_test.go +++ b/cel/optimizer_test.go @@ -45,7 +45,10 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()}) + opt, err := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()}) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optAST, iss := opt.Optimize(e, exprAST) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -186,7 +189,10 @@ func TestStaticOptimizerNewAST(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compile(%q) failed: %v", tc, iss.Err()) } - opt := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optAST, iss := opt.Optimize(e, exprAST) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -204,7 +210,10 @@ func TestStaticOptimizerNewAST(t *testing.T) { func TestStaticOptimizerNilAST(t *testing.T) { env := optimizerEnv(t) - opt := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } optAST, iss := opt.Optimize(env, nil) if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") { t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss) diff --git a/ext/sets_test.go b/ext/sets_test.go index 74163db62..4ea409b67 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -430,7 +430,10 @@ func TestSetsMembershipRewriter(t *testing.T) { if err != nil { t.Fatalf("NewSetMembershipOptimizer() failed with error: %v", err) } - opt := cel.NewStaticOptimizer(setsOpt) + opt, err := cel.NewStaticOptimizer(setsOpt) + if err != nil { + t.Fatalf("cel.NewStaticOptimizer() failed with error: %v", err) + } optAST, iss := opt.Optimize(env, a) if iss.Err() != nil { t.Fatalf("opt.Optimize() failed: %v", iss.Err()) diff --git a/policy/composer.go b/policy/composer.go index eedc2f65b..300de8210 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" @@ -71,18 +72,22 @@ type RuleComposer struct { // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { - // dummyExpr is a placeholder expression used as the root of the AST before optimization. Because - // StaticOptimizer copies the source info from it, we must set the correct source info here. - source := recoverOriginalSource(r) - sourceInfo := ast.NewSourceInfo(source) - dummyExpr := ast.NewExprFactory().NewLiteral(1, types.True) - ruleRoot := cel.NewAst(source, ast.NewAST(dummyExpr, sourceInfo)) + ruleRoot, _ := c.env.Compile("true") composer := &ruleComposerImpl{ rule: r, varIndices: []varIndex{}, } - opt := cel.NewStaticOptimizerWithSourceInfoMerging(composer) + // ruleRoot is a placeholder expression used as the root of the AST before optimization. Because + // StaticOptimizer would normally copy the source info from it, we must use OptimizeWithSource + // to override the correct source. + source := recoverOriginalSource(r) + opt, err := cel.NewStaticOptimizer(composer, cel.OptimizeWithSource(source)) + if err != nil { + errs := common.NewErrors(source) + errs.ReportErrorString(common.NoLocation, err.Error()) + return nil, cel.NewIssues(errs) + } ast, iss := opt.Optimize(c.env, ruleRoot) if iss.Err() != nil { return nil, iss @@ -91,7 +96,12 @@ func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { varIndices: []varIndex{}, exprUnnestHeight: c.exprUnnestHeight, } - opt = cel.NewStaticOptimizer(unnester) + opt, err = cel.NewStaticOptimizer(unnester) + if err != nil { + errs := common.NewErrors(source) + errs.ReportErrorString(common.NoLocation, err.Error()) + return nil, cel.NewIssues(errs) + } return opt.Optimize(c.env, ast) } From c687daf25a806f8eda8e915344883201b590092f Mon Sep 17 00:00:00 2001 From: abenea Date: Sat, 17 Jan 2026 00:36:55 +0100 Subject: [PATCH 4/4] Move SourceInfo renumbering logic to a member function. --- cel/optimizer.go | 20 ++------------------ common/ast/ast.go | 19 +++++++++++++++++++ common/ast/ast_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 18 deletions(-) diff --git a/cel/optimizer.go b/cel/optimizer.go index 111d74e67..3ffeb4281 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -129,30 +129,14 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { }, nil } -func updateOffsetRanges(idGen ast.IDGenerator, info *ast.SourceInfo) { - newRanges := make(map[int64]ast.OffsetRange) - sortedOldIDs := []int64{} - for oldID := range info.OffsetRanges() { - sortedOldIDs = append(sortedOldIDs, oldID) - } - sort.Slice(sortedOldIDs, func(i, j int) bool { return sortedOldIDs[i] < sortedOldIDs[j] }) - for _, oldID := range sortedOldIDs { - offsetRange, _ := info.GetOffsetRange(oldID) - newRanges[idGen(oldID)] = offsetRange - info.ClearOffsetRange(oldID) - } - for newID, offsetRange := range newRanges { - info.SetOffsetRange(newID, offsetRange) - } -} - // normalizeIDs ensures that the metadata present with an AST is reset in a manner such // that the ids within the expression correspond to the ids within macros. func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo, mergeSourceInfo bool) { optimized.RenumberIDs(idGen) if mergeSourceInfo { - updateOffsetRanges(idGen, info) + info.RenumberIDs(idGen) } + if len(info.MacroCalls()) == 0 { return } diff --git a/common/ast/ast.go b/common/ast/ast.go index aa2884c63..d10d4eca7 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -16,6 +16,8 @@ package ast import ( + "slices" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -242,6 +244,23 @@ type SourceInfo struct { macroCalls map[int64]Expr } +// RenumberIDs performs an in-place update of the expression IDs within the SourceInfo. +func (s *SourceInfo) RenumberIDs(idGen IDGenerator) { + if s == nil { + return + } + oldIDs := []int64{} + for id := range s.offsetRanges { + oldIDs = append(oldIDs, id) + } + slices.Sort(oldIDs) + newRanges := make(map[int64]OffsetRange) + for _, id := range oldIDs { + newRanges[idGen(id)] = s.offsetRanges[id] + } + s.offsetRanges = newRanges +} + // SyntaxVersion returns the syntax version associated with the text expression. func (s *SourceInfo) SyntaxVersion() string { if s == nil { diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index 7a1c6a141..0a86c3936 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -16,6 +16,7 @@ package ast_test import ( "fmt" + "maps" "reflect" "testing" @@ -388,3 +389,40 @@ func (src *mockSource) OffsetLocation(offset int32) (common.Location, bool) { } return src.Source.OffsetLocation(offset) } + +func TestSourceInfoRenumberIDs(t *testing.T) { + info := ast.NewSourceInfo(nil) + for old := int64(1); old <= 5; old++ { + info.SetOffsetRange(old, ast.OffsetRange{Start: int32(old), Stop: int32(old) + 1}) + } + original := make(map[int64]ast.OffsetRange) + maps.Copy(original, info.OffsetRanges()) + + // Verify the renumbering is stable. + var next int64 = 101 + idMap := make(map[int64]int64) + idGen := func(old int64) int64 { + if _, found := idMap[old]; !found { + idMap[old] = next + next = next + 1 + } + return idMap[old] + } + info.RenumberIDs(idGen) + + if len(info.OffsetRanges()) != 5 { + t.Errorf("got %d offset ranges, wanted 5", len(info.OffsetRanges())) + } + + for old := int64(1); old <= 5; old++ { + want := original[old] + new := old + 100 + got, found := info.GetOffsetRange(new) + if !found { + t.Errorf("offset range for ID %d not found", new) + } + if got != want { + t.Errorf("offset range for ID %d incorrect; got %v, want %v", new, got, want) + } + } +}