Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ type Ast struct {
impl *celast.AST
}

// NewAst creates a new Ast value from a source and its native representation.
func NewAst(source Source, impl *celast.AST) *Ast {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not use this method, but rather pass the desired source to the optimizer to use as the basis for source metadata merging since that seems to be the desired effect. More details below.

return &Ast{
source: source,
impl: impl,
}
}

// NativeRep converts the AST to a Go-native representation.
func (ast *Ast) NativeRep() *celast.AST {
if ast == nil {
Expand Down
60 changes: 50 additions & 10 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,27 @@ import (
// Note: source position information is best-effort and likely wrong, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
optimizers []ASTOptimizer
optimizers []ASTOptimizer
mergeSourceInfo bool
}

// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
return &StaticOptimizer{
optimizers: optimizers,
optimizers: optimizers,
mergeSourceInfo: false,
}
}

// NewStaticOptimizerWithSourceInfoMerging creates a StaticOptimizer with a sequence of
// ASTOptimizer's to be applied to a checked expression. The source info of the optimized AST will
// be merged with the source info of the input AST which is useful when merging multiple expressions
// defined in the same file. Used only for policy composition.
func NewStaticOptimizerWithSourceInfoMerging(optimizers ...ASTOptimizer) *StaticOptimizer {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding a new constructor, prefer making source info merging a functional option by introducing the following changes:

type OptimizerOption func(*StaticOptimizer) (*StaticOptimizer, error)

func NewStaticOptimizer(options ...any) (*StaticOptimizer, error) {
   so := &StaticOptimizer{}
   var err error
   for _, opt := range options {
     switch v := opt.(type) {
       case ASTOptimizer:
          so.optimizers = append(so.optimizers, v)
       case OptimizerOption:
          so, err = v(so)
          if err != nil {
             return nil, err
          }
        default:
          return nil, fmt.Errorf("unsupported option: %v", v)
     }
   }
   return so, nil
}

return &StaticOptimizer{
optimizers: optimizers,
mergeSourceInfo: true,
}
}

Expand All @@ -55,9 +68,10 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
issues := NewIssues(common.NewErrors(a.Source()))
baseFac := ast.NewExprFactory()
exprFac := &optimizerExprFactory{
idGenerator: ids,
fac: baseFac,
sourceInfo: optimized.SourceInfo(),
idGenerator: ids,
fac: baseFac,
sourceInfo: optimized.SourceInfo(),
mergeSourceInfo: opt.mergeSourceInfo,
}
ctx := &OptimizerContext{
optimizerExprFactory: exprFac,
Expand All @@ -75,7 +89,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
freshIDGen := newIDGenerator(0)
info := optimized.SourceInfo()
expr := optimized.Expr()
normalizeIDs(freshIDGen.renumberStable, expr, info)
normalizeIDs(freshIDGen.renumberStable, expr, info, exprFac.mergeSourceInfo)
cleanupMacroRefs(expr, info)

// Recheck the updated expression for any possible type-agreement or validation errors.
Expand All @@ -96,10 +110,30 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
}, nil
}

func updateOffsetRanges(idGen ast.IDGenerator, info *ast.SourceInfo) {
newRanges := make(map[int64]ast.OffsetRange)
sortedOldIDs := []int64{}
for oldID := range info.OffsetRanges() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer copy(dst, src) to the for-loop

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to https://pkg.go.dev/builtin#copy it only works with slices, but in this case I am trying to copy the keys from a map.

sortedOldIDs = append(sortedOldIDs, oldID)
}
sort.Slice(sortedOldIDs, func(i, j int) bool { return sortedOldIDs[i] < sortedOldIDs[j] })
for _, oldID := range sortedOldIDs {
offsetRange, _ := info.GetOffsetRange(oldID)
newRanges[idGen(oldID)] = offsetRange
info.ClearOffsetRange(oldID)
}
for newID, offsetRange := range newRanges {
info.SetOffsetRange(newID, offsetRange)
}
}

// normalizeIDs ensures that the metadata present with an AST is reset in a manner such
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo, mergeSourceInfo bool) {
optimized.RenumberIDs(idGen)
if mergeSourceInfo {
updateOffsetRanges(idGen, info)
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider moving updateOffsetRanges into the ast.SourceInfo object as RenumberIDs(idGen) for consistency with the other renumbering code. It should then be easier to unit test that a copy of an AST can be renumbered in a way that's consistent with your expectations.

if len(info.MacroCalls()) == 0 {
return
}
Expand Down Expand Up @@ -229,8 +263,9 @@ type ASTOptimizer interface {

type optimizerExprFactory struct {
*idGenerator
fac ast.ExprFactory
sourceInfo *ast.SourceInfo
fac ast.ExprFactory
sourceInfo *ast.SourceInfo
mergeSourceInfo bool
}

// NewAST creates an AST from the current expression using the tracked source info which
Expand All @@ -249,7 +284,7 @@ func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo)
defer func() { opt.seed = idGen.nextID() }()
copyExpr := opt.fac.CopyExpr(a.Expr())
copyInfo := ast.CopySourceInfo(a.SourceInfo())
normalizeIDs(idGen.renumberStable, copyExpr, copyInfo)
normalizeIDs(idGen.renumberStable, copyExpr, copyInfo, opt.mergeSourceInfo)
return copyExpr, copyInfo
}

Expand All @@ -260,6 +295,11 @@ func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
for macroID, call := range copyInfo.MacroCalls() {
opt.SetMacroCall(macroID, call)
}
if opt.mergeSourceInfo {
for id, offset := range copyInfo.OffsetRanges() {
opt.sourceInfo.SetOffsetRange(id, offset)
}
}
return copyExpr
}

Expand Down
3 changes: 3 additions & 0 deletions policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ go_test(
size = "small",
srcs = [
"compiler_test.go",
"composer_test.go",
"config_test.go",
"helper_test.go",
"parser_test.go",
Expand All @@ -72,6 +73,8 @@ go_test(
"//test/proto3pb:go_default_library",
"@in_yaml_go_yaml_v3//:go_default_library",
"@com_github_google_go_cmp//cmp:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
],
)

Expand Down
57 changes: 47 additions & 10 deletions policy/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,7 @@ type compiler struct {
func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) {
compiledVars := make([]*CompiledVariable, len(r.Variables()))
for i, v := range r.Variables() {
exprSrc := c.relSource(v.Expression())
varAST, exprIss := ruleEnv.CompileSource(exprSrc)
varAST, exprIss := c.compileValueString(ruleEnv, v.Expression())
varName := v.Name().Value

// Determine the variable type. If the expression is an error then record the error and
Expand Down Expand Up @@ -304,8 +303,7 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
// Compile the set of match conditions under the rule.
compiledMatches := []*CompiledMatch{}
for _, m := range r.Matches() {
condSrc := c.relSource(m.Condition())
condAST, condIss := ruleEnv.CompileSource(condSrc)
condAST, condIss := c.compileValueString(ruleEnv, m.Condition())
iss = iss.Append(condIss)
// This case cannot happen when the Policy object is parsed from yaml, but could happen
// with a non-YAML generation of the Policy object.
Expand All @@ -315,8 +313,7 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
continue
}
if m.HasOutput() {
outSrc := c.relSource(m.Output())
outAST, outIss := ruleEnv.CompileSource(outSrc)
outAST, outIss := c.compileValueString(ruleEnv, m.Output())
iss = iss.Append(outIss)
compiledMatches = append(compiledMatches, &CompiledMatch{
exprID: m.exprID,
Expand Down Expand Up @@ -411,16 +408,56 @@ func (c *compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) {
}
}

func (c *compiler) relSource(pstr ValueString) *RelativeSource {
func (c *compiler) compileValueString(env *cel.Env, vs ValueString) (*cel.Ast, *cel.Issues) {
line := 0
col := 1
if offset, found := c.info.GetOffsetRange(pstr.ID); found {
if loc, found := c.src.OffsetLocation(offset.Start); found {
offset, found := c.info.GetOffsetRange(vs.ID)
if found {
if loc, offsetFound := c.src.OffsetLocation(offset.Start); offsetFound {
line = loc.Line()
col = loc.Column()
}
}
return c.src.Relative(pstr.Value, line, col)
relSource := c.src.Relative(vs.Value, line, col)
ast, iss := env.CompileSource(relSource)
info := ast.NativeRep().SourceInfo()
var keepIDs map[int64]bool
if !found {
// Remove the newly created offset ranges if the ValueString is not associated with a source
// position. This currently happens because the parser creates a synthetic "true" condition if a
// match doesn't have one.
keepIDs = make(map[int64]bool)
} else {
// Remove offset ranges for ids without a corresponding AST node. This can happen because the
// checker deletes some nodes while rewriting the AST. For example the Select operand is deleted
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we should be able to do this cleanup within the parser and the checker rather than here. Could we explore that approach instead? Maybe as a separate PR so we can remove the cleanup here after that one is submitted?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in #1258.

// when a variable reference is replaced with a Ident expression.
keepIDs = getAstIds(ast)
}
for id := range info.OffsetRanges() {
if !keepIDs[id] {
info.ClearOffsetRange(id)
}
}
return ast, iss
}

// idVisitor is a visitor that accumulates the IDs of the nodes it is shown. The ids field is
// used as a set: an ID is present when it maps to true.
type idVisitor struct {
ids map[int64]bool
}

// VisitExpr adds the ID of the visited expression node to the visitor's ID set.
func (v *idVisitor) VisitExpr(e ast.Expr) {
	id := e.ID()
	v.ids[id] = true
}

// VisitEntryExpr adds the ID of the visited entry node to the visitor's ID set.
func (v *idVisitor) VisitEntryExpr(e ast.EntryExpr) {
	id := e.ID()
	v.ids[id] = true
}

// getAstIds returns a set of AST node IDs
func getAstIds(celAst *cel.Ast) map[int64]bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer to uppercase acronyms, e.g. getASTIDs. There's some debt around the old names, but we shouldn't propagate it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer operating on the ast.AST rather than the cel.Ast.

visitor := &idVisitor{ids: make(map[int64]bool)}
ast.PostOrderVisit(celAst.NativeRep().Expr(), visitor)
return visitor.ids
}

const (
Expand Down
82 changes: 82 additions & 0 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"
"testing"

"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"

"github.com/google/cel-go/cel"
Expand Down Expand Up @@ -73,6 +74,8 @@ func TestRuleComposerUnnest(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("Compose(rule) failed: %v", iss.Err())
}
policy := parsePolicy(t, tc.name, []ParserOption{})
verifySourceInfoCoverage(t, policy, ast)
unparsed, err := cel.AstToString(ast)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
Expand Down Expand Up @@ -507,6 +510,7 @@ func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption,
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
ast, iss := Compile(env, policy, compilerOpts...)
verifySourceInfoCoverage(t, policy, ast)
return env, ast, iss
}

Expand All @@ -516,3 +520,81 @@ func normalize(s string) string {
strings.ReplaceAll(s, " ", ""), "\n", ""),
"\t", "")
}

func verifySourceInfoCoverage(t testing.TB, policy *Policy, ast *cel.Ast) {
t.Helper()
info := ast.SourceInfo()
if info == nil {
return
}

exprLines := exprLinesFromPolicy(policy)
coveredLines := make(map[int]bool)
ids := getAstIds(ast)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should just expose this function on the ast.AST instead ... no harm in doing it that way.

for id, offset := range info.GetPositions() {
// Check that each position in the SourceInfo corresponds to a valid AST node.
if !ids[id] {
t.Errorf("id %d not found in AST", id)
}
loc, found := ast.Source().OffsetLocation(offset)
if found {
coveredLines[loc.Line()] = true
} else {
t.Errorf("invalid source location for offset %d", offset)
}
}
// Verify that each source line inside an expression is covered by at least one node in the
// AST.
for line := range exprLines {
if !coveredLines[line] {
t.Errorf("Line %d expected to be covered by SourceInfo, but was not", line)
}
}

if t.Failed() {
checked, err := cel.AstToCheckedExpr(ast)
if err != nil {
t.Logf("cel.AstToCheckedExpr() failed: %v", err)
} else {
t.Logf("AST:\n%s", prototext.Format(checked.GetExpr()))
}
}
}

// exprLinesFromPolicy walks every rule in the policy (including rules nested under matches) and
// returns the set of source line numbers occupied by variable expressions, match conditions, and
// match outputs. Expressions whose ID has no recorded offset range, whose start offset cannot be
// resolved to a location, or which contain a triple-quoted string literal contribute no lines.
func exprLinesFromPolicy(policy *Policy) map[int]bool {
	lines := make(map[int]bool)

	// markLines flags each line spanned by a single expression value.
	markLines := func(vs ValueString) {
		offsetRange, ok := policy.SourceInfo().GetOffsetRange(vs.ID)
		if !ok {
			return
		}
		start, ok := policy.Source().OffsetLocation(offsetRange.Start)
		if !ok {
			return
		}
		// Multiline string literals span several lines, but the SourceInfo only records the
		// start position of the expression, so such expressions are excluded from the check.
		if strings.Contains(vs.Value, "'''") || strings.Contains(vs.Value, "\"\"\"") {
			return
		}
		last := start.Line() + strings.Count(vs.Value, "\n")
		for line := start.Line(); line <= last; line++ {
			lines[line] = true
		}
	}

	// walk visits a rule's variables and matches, recursing into nested rules.
	var walk func(r *Rule)
	walk = func(r *Rule) {
		for _, variable := range r.Variables() {
			markLines(variable.Expression())
		}
		for _, match := range r.Matches() {
			markLines(match.Condition())
			if match.HasOutput() {
				markLines(match.Output())
			}
			if match.HasRule() {
				walk(match.Rule())
			}
		}
	}

	walk(policy.Rule())
	return lines
}
24 changes: 22 additions & 2 deletions policy/composer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,18 @@ type RuleComposer struct {

// Compose stitches together a set of expressions within a CompiledRule into a single CEL ast.
func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) {
ruleRoot, _ := c.env.Compile("true")
// dummyExpr is a placeholder expression used as the root of the AST before optimization. Because
// StaticOptimizer copies the source info from it, we must set the correct source info here.
source := recoverOriginalSource(r)
sourceInfo := ast.NewSourceInfo(source)
dummyExpr := ast.NewExprFactory().NewLiteral(1, types.True)
ruleRoot := cel.NewAst(source, ast.NewAST(dummyExpr, sourceInfo))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than modify the existing ruleRoot approach, consider packaging up the original source as an option to the static optimizer since it appears there's already a need to introduce awareness of such capabilities from this change.

source := recoverOriginalSource(r)
opt, err := cel.NewStaticOptimizer(composer, cel.OptimizeWithSource(source))
if err != nil {
    // add the error to the issues list and return
}
opt.Optimize(c.env, ruleRoot)


composer := &ruleComposerImpl{
rule: r,
varIndices: []varIndex{},
}
opt := cel.NewStaticOptimizer(composer)
opt := cel.NewStaticOptimizerWithSourceInfoMerging(composer)
ast, iss := opt.Optimize(c.env, ruleRoot)
if iss.Err() != nil {
return nil, iss
Expand All @@ -89,6 +95,20 @@ func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) {
return opt.Optimize(c.env, ast)
}

// recoverOriginalSource returns the source backing the first match of the rule: the source of the
// match's output expression when an output is present, otherwise the source of its condition. If
// that source is a *RelativeSource, the Source it wraps is returned instead.
//
// The first match is read by index, so the rule is assumed to have at least one match.
func recoverOriginalSource(r *CompiledRule) cel.Source {
	first := r.Matches()[0]

	var src cel.Source
	if out := first.Output(); out == nil {
		src = first.Condition().Source()
	} else {
		src = out.Expr().Source()
	}

	rel, isRelative := src.(*RelativeSource)
	if !isRelative {
		return src
	}
	return rel.Source
}

type varIndex struct {
index int
indexVar string
Expand Down
Loading