Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 81 additions & 50 deletions policy/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ type CompiledMatch struct {
nestedRule *CompiledRule
}

// NewCompiledMatch creates a CompiledMatch.
func NewCompiledMatch(exprID int64, cond *cel.Ast, output *OutputValue, nestedRule *CompiledRule) *CompiledMatch {
return &CompiledMatch{
exprID: exprID,
cond: cond,
output: output,
nestedRule: nestedRule,
}
}

// SourceID returns the source identifier associated with the compiled match.
func (m *CompiledMatch) SourceID() int64 {
return m.exprID
Expand Down Expand Up @@ -163,6 +173,14 @@ type OutputValue struct {
expr *cel.Ast
}

// NewOutputValue creates an OutputValue.
func NewOutputValue(exprID int64, expr *cel.Ast) *OutputValue {
return &OutputValue{
exprID: exprID,
expr: expr,
}
}

// SourceID returns the expression id associated with the output expression.
func (o *OutputValue) SourceID() int64 {
return o.exprID
Expand All @@ -174,17 +192,28 @@ func (o *OutputValue) Expr() *cel.Ast {
}

// CompilerOption specifies a functional option to be applied to new RuleComposer instances.
type CompilerOption func(*compiler) error
type CompilerOption func(*Compiler) error

// MaxNestedExpressions limits the number of variable and nested rule expressions during compilation.
//
// Defaults to 100 if not set.
func MaxNestedExpressions(limit int) CompilerOption {
return func(c *compiler) error {
return func(c *Compiler) error {
if limit <= 0 {
return fmt.Errorf("nested expression limit must be non-negative, non-zero value: %d", limit)
}
c.maxNestedExpressions = limit
c.MaxNestedExpressions = limit
return nil
}
}

// CompileMatches is a function that compiles the match blocks within a rule.
type CompileMatches func(c *Compiler, r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) ([]*CompiledMatch, *cel.Issues)

// WithMatchCompiler sets a custom CompileMatches function on the compiler.
func WithMatchCompiler(f CompileMatches) CompilerOption {
return func(c *Compiler) error {
c.compileMatches = f
return nil
}
}
Expand All @@ -207,11 +236,12 @@ func Compile(env *cel.Env, p *Policy, opts ...CompilerOption) (*cel.Ast, *cel.Is
// match statements. The compiled rule defines an expression graph, which can be composed into a single
// expression via the RuleComposer.Compose method.
func CompileRule(env *cel.Env, p *Policy, opts ...CompilerOption) (*CompiledRule, *cel.Issues) {
c := &compiler{
c := &Compiler{
env: env,
info: p.SourceInfo(),
src: p.Source(),
maxNestedExpressions: defaultMaxNestedExpressions,
compileMatches: defaultCompileMatches,
MaxNestedExpressions: defaultMaxNestedExpressions,
}
var err error
errs := common.NewErrors(c.src)
Expand Down Expand Up @@ -248,22 +278,28 @@ func CompileRule(env *cel.Env, p *Policy, opts ...CompilerOption) (*CompiledRule
c.env = env
}
}
return c.compileRule(p.Rule(), c.env, iss)
return c.CompileRuleImpl(p.Rule(), p, c.env, iss)
}

type compiler struct {
// Compiler holds the state of the policy compiler.
type Compiler struct {
env *cel.Env
info *ast.SourceInfo
src *Source

maxNestedExpressions int
nestedCount int
compileMatches CompileMatches
MaxNestedExpressions int
NestedCount int
}

func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) {
// CompileRuleImpl compiles a Rule into a CompiledRule.
//
// This is the main entry point for the compilation process. It compiles the variables and matches
// within the rule.
func (c *Compiler) CompileRuleImpl(r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) {
compiledVars := make([]*CompiledVariable, len(r.Variables()))
for i, v := range r.Variables() {
exprSrc := c.relSource(v.Expression())
exprSrc := c.RelSource(v.Expression())
varAST, exprIss := ruleEnv.CompileSource(exprSrc)
varName := v.Name().Value

Expand Down Expand Up @@ -295,16 +331,36 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
compiledVars[i] = compiledVar

// Increment the nesting count post-compile.
c.nestedCount++
if c.nestedCount == c.maxNestedExpressions+1 {
c.NestedCount++
if c.NestedCount == c.MaxNestedExpressions+1 {
iss.ReportErrorAtID(compiledVar.SourceID(), "variable exceeds nested expression limit")
}
}

// Compile the set of match conditions under the rule.
compiledMatches, iss := c.compileMatches(c, r, p, ruleEnv, iss)

// Validate that all branches in the rule are reachable
rule := &CompiledRule{
exprID: r.exprID,
id: r.id,
variables: compiledVars,
matches: compiledMatches,
}

// Note: Consider supporting configurable policy validators that take the policy, rule, and issues
// Validate type agreement between the different match outputs
c.checkMatchOutputTypesAgree(rule, iss)
// Validate that all branches in the policy are reachable
c.checkUnreachableCode(rule, iss)

return rule, iss
}

func defaultCompileMatches(c *Compiler, r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) ([]*CompiledMatch, *cel.Issues) {
compiledMatches := []*CompiledMatch{}
for _, m := range r.Matches() {
condSrc := c.relSource(m.Condition())
condSrc := c.RelSource(m.Condition())
condAST, condIss := ruleEnv.CompileSource(condSrc)
iss = iss.Append(condIss)
// This case cannot happen when the Policy object is parsed from yaml, but could happen
Expand All @@ -315,54 +371,28 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
continue
}
if m.HasOutput() {
outSrc := c.relSource(m.Output())
outSrc := c.RelSource(m.Output())
outAST, outIss := ruleEnv.CompileSource(outSrc)
iss = iss.Append(outIss)
compiledMatches = append(compiledMatches, &CompiledMatch{
exprID: m.exprID,
cond: condAST,
output: &OutputValue{
exprID: m.Output().ID,
expr: outAST,
},
})
compiledMatches = append(compiledMatches, NewCompiledMatch(m.SourceID(), condAST, NewOutputValue(m.Output().ID, outAST), nil))
continue
}
if m.HasRule() {
nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss)
nestedRule, ruleIss := c.CompileRuleImpl(m.Rule(), p, ruleEnv, iss)
iss = iss.Append(ruleIss)
compiledMatches = append(compiledMatches, &CompiledMatch{
exprID: m.exprID,
cond: condAST,
nestedRule: nestedRule,
})
compiledMatches = append(compiledMatches, NewCompiledMatch(m.SourceID(), condAST, nil, nestedRule))

// Increment the nesting count post-compile.
c.nestedCount++
if c.nestedCount == c.maxNestedExpressions+1 {
c.NestedCount++
if c.NestedCount == c.MaxNestedExpressions+1 {
iss.ReportErrorAtID(nestedRule.SourceID(), "rule exceeds nested expression limit")
}
}
}

// Validate that all branches in the rule are reachable
rule := &CompiledRule{
exprID: r.exprID,
id: r.id,
variables: compiledVars,
matches: compiledMatches,
}

// Note: Consider supporting configurable policy validators that take the policy, rule, and issues
// Validate type agreement between the different match outputs
c.checkMatchOutputTypesAgree(rule, iss)
// Validate that all branches in the policy are reachable
c.checkUnreachableCode(rule, iss)

return rule, iss
return compiledMatches, iss
}

func (c *compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issues) {
func (c *Compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issues) {
var outputType *cel.Type
for _, m := range rule.Matches() {
if outputType == nil {
Expand Down Expand Up @@ -391,7 +421,7 @@ func (c *compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issue
}
}

func (c *compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) {
func (c *Compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) {
ruleHasOptional := rule.HasOptionalOutput()
compiledMatches := rule.Matches()
matchCount := len(compiledMatches)
Expand All @@ -411,7 +441,8 @@ func (c *compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) {
}
}

func (c *compiler) relSource(pstr ValueString) *RelativeSource {
// RelSource returns a RelativeSource for a given expression id.
func (c *Compiler) RelSource(pstr ValueString) *RelativeSource {
line := 0
col := 1
if offset, found := c.info.GetOffsetRange(pstr.ID); found {
Expand Down