From 680a48ffbb7d63f8c2ff171c87cd05cb7733ccd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C4=83t=C4=83lina=20=C5=A0egina?= Date: Mon, 8 Dec 2025 15:08:15 +0100 Subject: [PATCH] Refactor match compiling to accept user-defined logic. --- policy/compiler.go | 131 ++++++++++++++++++++++++++++----------------- 1 file changed, 81 insertions(+), 50 deletions(-) diff --git a/policy/compiler.go b/policy/compiler.go index bdf495a98..ad428142c 100644 --- a/policy/compiler.go +++ b/policy/compiler.go @@ -119,6 +119,16 @@ type CompiledMatch struct { nestedRule *CompiledRule } +// NewCompiledMatch creates a CompiledMatch. +func NewCompiledMatch(exprID int64, cond *cel.Ast, output *OutputValue, nestedRule *CompiledRule) *CompiledMatch { + return &CompiledMatch{ + exprID: exprID, + cond: cond, + output: output, + nestedRule: nestedRule, + } +} + // SourceID returns the source identifier associated with the compiled match. func (m *CompiledMatch) SourceID() int64 { return m.exprID @@ -163,6 +173,14 @@ type OutputValue struct { expr *cel.Ast } +// NewOutputValue creates an OutputValue. +func NewOutputValue(exprID int64, expr *cel.Ast) *OutputValue { + return &OutputValue{ + exprID: exprID, + expr: expr, + } +} + // SourceID returns the expression id associated with the output expression. func (o *OutputValue) SourceID() int64 { return o.exprID @@ -174,17 +192,28 @@ func (o *OutputValue) Expr() *cel.Ast { } // CompilerOption specifies a functional option to be applied to new RuleComposer instances. -type CompilerOption func(*compiler) error +type CompilerOption func(*Compiler) error // MaxNestedExpressions limits the number of variable and nested rule expressions during compilation. // // Defaults to 100 if not set. func MaxNestedExpressions(limit int) CompilerOption { - return func(c *compiler) error { + return func(c *Compiler) error { if limit <= 0 { return fmt.Errorf("nested expression limit must be non-negative, non-zero value: %d", limit) } - c.maxNestedExpressions = limit + c.MaxNestedExpressions = limit + return nil + } +} + +// CompileMatches is a function that compiles the match blocks within a rule. +type CompileMatches func(c *Compiler, r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) ([]*CompiledMatch, *cel.Issues) + +// WithMatchCompiler sets a custom CompileMatches function on the compiler. +func WithMatchCompiler(f CompileMatches) CompilerOption { + return func(c *Compiler) error { + c.compileMatches = f return nil } } @@ -207,11 +236,12 @@ func Compile(env *cel.Env, p *Policy, opts ...CompilerOption) (*cel.Ast, *cel.Is // match statements. The compiled rule defines an expression graph, which can be composed into a single // expression via the RuleComposer.Compose method. func CompileRule(env *cel.Env, p *Policy, opts ...CompilerOption) (*CompiledRule, *cel.Issues) { - c := &compiler{ + c := &Compiler{ env: env, info: p.SourceInfo(), src: p.Source(), - maxNestedExpressions: defaultMaxNestedExpressions, + compileMatches: defaultCompileMatches, + MaxNestedExpressions: defaultMaxNestedExpressions, } var err error errs := common.NewErrors(c.src) @@ -248,22 +278,28 @@ func CompileRule(env *cel.Env, p *Policy, opts ...CompilerOption) (*CompiledRule c.env = env } } - return c.compileRule(p.Rule(), c.env, iss) + return c.CompileRuleImpl(p.Rule(), p, c.env, iss) } -type compiler struct { +// Compiler holds the state of the policy compiler. +type Compiler struct { env *cel.Env info *ast.SourceInfo src *Source - maxNestedExpressions int - nestedCount int + compileMatches CompileMatches + MaxNestedExpressions int + NestedCount int } -func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) { +// CompileRuleImpl compiles a Rule into a CompiledRule. +// +// This is the main entry point for the compilation process. It compiles the variables and matches +// within the rule. +func (c *Compiler) CompileRuleImpl(r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) { compiledVars := make([]*CompiledVariable, len(r.Variables())) for i, v := range r.Variables() { - exprSrc := c.relSource(v.Expression()) + exprSrc := c.RelSource(v.Expression()) varAST, exprIss := ruleEnv.CompileSource(exprSrc) varName := v.Name().Value @@ -295,16 +331,36 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com compiledVars[i] = compiledVar // Increment the nesting count post-compile. - c.nestedCount++ - if c.nestedCount == c.maxNestedExpressions+1 { + c.NestedCount++ + if c.NestedCount == c.MaxNestedExpressions+1 { iss.ReportErrorAtID(compiledVar.SourceID(), "variable exceeds nested expression limit") } } // Compile the set of match conditions under the rule. + compiledMatches, iss := c.compileMatches(c, r, p, ruleEnv, iss) + + // Validate that all branches in the rule are reachable + rule := &CompiledRule{ + exprID: r.exprID, + id: r.id, + variables: compiledVars, + matches: compiledMatches, + } + + // Note: Consider supporting configurable policy validators that take the policy, rule, and issues + // Validate type agreement between the different match outputs + c.checkMatchOutputTypesAgree(rule, iss) + // Validate that all branches in the policy are reachable + c.checkUnreachableCode(rule, iss) + + return rule, iss +} + +func defaultCompileMatches(c *Compiler, r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) ([]*CompiledMatch, *cel.Issues) { compiledMatches := []*CompiledMatch{} for _, m := range r.Matches() { - condSrc := c.relSource(m.Condition()) + condSrc := c.RelSource(m.Condition()) condAST, condIss := ruleEnv.CompileSource(condSrc) iss = iss.Append(condIss) // This case cannot happen when the Policy object is parsed from yaml, but could happen @@ -315,54 +371,28 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com continue } if m.HasOutput() { - outSrc := c.relSource(m.Output()) + outSrc := c.RelSource(m.Output()) outAST, outIss := ruleEnv.CompileSource(outSrc) iss = iss.Append(outIss) - compiledMatches = append(compiledMatches, &CompiledMatch{ - exprID: m.exprID, - cond: condAST, - output: &OutputValue{ - exprID: m.Output().ID, - expr: outAST, - }, - }) + compiledMatches = append(compiledMatches, NewCompiledMatch(m.SourceID(), condAST, NewOutputValue(m.Output().ID, outAST), nil)) continue } if m.HasRule() { - nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss) + nestedRule, ruleIss := c.CompileRuleImpl(m.Rule(), p, ruleEnv, iss) iss = iss.Append(ruleIss) - compiledMatches = append(compiledMatches, &CompiledMatch{ - exprID: m.exprID, - cond: condAST, - nestedRule: nestedRule, - }) + compiledMatches = append(compiledMatches, NewCompiledMatch(m.SourceID(), condAST, nil, nestedRule)) // Increment the nesting count post-compile. - c.nestedCount++ - if c.nestedCount == c.maxNestedExpressions+1 { + c.NestedCount++ + if c.NestedCount == c.MaxNestedExpressions+1 { iss.ReportErrorAtID(nestedRule.SourceID(), "rule exceeds nested expression limit") } } } - - // Validate that all branches in the rule are reachable - rule := &CompiledRule{ - exprID: r.exprID, - id: r.id, - variables: compiledVars, - matches: compiledMatches, - } - - // Note: Consider supporting configurable policy validators that take the policy, rule, and issues - // Validate type agreement between the different match outputs - c.checkMatchOutputTypesAgree(rule, iss) - // Validate that all branches in the policy are reachable - c.checkUnreachableCode(rule, iss) - - return rule, iss + return compiledMatches, iss } -func (c *compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issues) { +func (c *Compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issues) { var outputType *cel.Type for _, m := range rule.Matches() { if outputType == nil { @@ -391,7 +421,7 @@ func (c *compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issue } } -func (c *compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) { +func (c *Compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) { ruleHasOptional := rule.HasOptionalOutput() compiledMatches := rule.Matches() matchCount := len(compiledMatches) @@ -411,7 +441,8 @@ func (c *compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) { } } -func (c *compiler) relSource(pstr ValueString) *RelativeSource { +// RelSource returns a RelativeSource for a given expression id. +func (c *Compiler) RelSource(pstr ValueString) *RelativeSource { line := 0 col := 1 if offset, found := c.info.GetOffsetRange(pstr.ID); found {