Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ java_library(
exports = ["//common/src/main/java/dev/cel/common/ast"],
)

java_library(
name = "cel_block",
visibility = ["//:internal"],
exports = ["//common/src/main/java/dev/cel/common/ast:cel_block"],
)

cel_android_library(
name = "ast_android",
exports = ["//common/src/main/java/dev/cel/common/ast:ast_android"],
Expand Down
14 changes: 14 additions & 0 deletions common/src/main/java/dev/cel/common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ java_library(
],
)

java_library(
name = "cel_block",
srcs = ["CelBlock.java"],
tags = [
],
deps = [
":ast",
"//common:cel_ast",
"//common/annotations",
"//common/navigation",
"@maven//:com_google_guava_guava",
],
)

java_library(
name = "expr_converter",
srcs = EXPR_CONVERTER_SOURCES,
Expand Down
144 changes: 144 additions & 0 deletions common/src/main/java/dev/cel/common/ast/CelBlock.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dev.cel.common.ast;

import static com.google.common.collect.ImmutableList.toImmutableList;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.annotations.Internal;
import dev.cel.common.navigation.CelNavigableExpr;
import java.util.Optional;

/**
* Represents a {@code cel.@block} expression.
*
* <p>CEL Block is used by the CSE (Common Subexpression Elimination) optimizer to hoist common
* subexpressions into an evaluated block.
*/
@Internal
public final class CelBlock {
public static final String FUNCTION_NAME = "cel.@block";
public static final String INDEX_PREFIX = "@index";

private final CelExpr blockExpr;

private CelBlock(CelExpr blockExpr) {
this.blockExpr = blockExpr;
}

public ImmutableList<CelExpr> indices() {
return blockExpr.call().args().get(0).list().elements();
}

public CelExpr result() {
return blockExpr.call().args().get(1);
}

public CelExpr expr() {
return blockExpr;
}

/**
* Extracts a {@link CelBlock} from the given AST.
*
* <p>Enforces the contract that {@code cel.@block} must only appear exactly once and at the root
* of the AST.
*
* @throws IllegalArgumentException if the block is malformed or its indices are invalid.
*/
public static Optional<CelBlock> extract(CelAbstractSyntaxTree ast) {
CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());

ImmutableList<CelExpr> allCelBlocks =
celNavigableExpr
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.callOrDefault().function().equals(FUNCTION_NAME))
.collect(toImmutableList());
if (allCelBlocks.isEmpty()) {
return Optional.empty();
}

Preconditions.checkArgument(
allCelBlocks.size() == 1,
"Expected 1 cel.block function to be present but found %s",
allCelBlocks.size());
Preconditions.checkArgument(
celNavigableExpr.expr().equals(allCelBlocks.get(0)),
"Expected cel.block to be present at root");

return Optional.of(fromExpr(allCelBlocks.get(0)));
}

/**
* Constructs a {@link CelBlock} from a {@link CelExpr}.
*
* @throws IllegalArgumentException if the expression is not a valid block.
*/
private static CelBlock fromExpr(CelExpr expr) {
Preconditions.checkArgument(
expr.exprKind().getKind() == CelExpr.ExprKind.Kind.CALL,
"Expected cel.@block to be a call expression");
Preconditions.checkArgument(
expr.call().function().equals(FUNCTION_NAME), "Expected function to be cel.@block");
Preconditions.checkArgument(
expr.call().args().size() == 2, "Expected exactly 2 arguments for cel.@block");
Preconditions.checkArgument(
expr.call().args().get(0).exprKind().getKind() == CelExpr.ExprKind.Kind.LIST,
"Expected first argument of cel.@block to be a list");

CelBlock block = new CelBlock(expr);

// Assert correctness on block indices used in subexpressions
ImmutableList<CelExpr> subexprs = block.indices();
for (int i = 0; i < subexprs.size(); i++) {
verifyBlockIndex(subexprs.get(i), i, expr);
}

// Assert correctness on block indices used in block result
CelExpr blockResult = block.result();
verifyBlockIndex(blockResult, subexprs.size(), expr);
boolean resultHasAtLeastOneBlockIndex =
CelNavigableExpr.fromExpr(blockResult)
.allNodes()
.map(CelNavigableExpr::expr)
.anyMatch(e -> e.identOrDefault().name().startsWith(INDEX_PREFIX));
Preconditions.checkArgument(
resultHasAtLeastOneBlockIndex,
"Expected at least one reference of index in cel.block result");

return block;
}

private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue, CelExpr rootBlock) {
boolean areAllIndicesValid =
CelNavigableExpr.fromExpr(celExpr)
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.identOrDefault().name().startsWith(INDEX_PREFIX))
.map(CelExpr::ident)
.allMatch(
blockIdent ->
Integer.parseInt(blockIdent.name().substring(INDEX_PREFIX.length()))
< maxIndexValue);
Preconditions.checkArgument(
areAllIndicesValid,
"Illegal block index found. The index value must be less than %s. Expr: %s",
maxIndexValue,
rootBlock);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ java_library(
"//common:mutable_ast",
"//common:mutable_source",
"//common/ast",
"//common/ast:cel_block",
"//common/ast:mutable_expr",
"//common/navigation",
"//common/navigation:common",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelBlock;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelComprehension;
Expand Down Expand Up @@ -238,64 +239,12 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
*/
@VisibleForTesting
static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());

ImmutableList<CelExpr> allCelBlocks =
celNavigableExpr
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.callOrDefault().function().equals(CEL_BLOCK_FUNCTION))
.collect(toImmutableList());
if (allCelBlocks.isEmpty()) {
CelBlock celBlock = CelBlock.extract(ast).orElse(null);
if (celBlock == null) {
return;
}

CelExpr celBlockExpr = allCelBlocks.get(0);
Verify.verify(
allCelBlocks.size() == 1,
"Expected 1 cel.block function to be present but found %s",
allCelBlocks.size());
Verify.verify(
celNavigableExpr.expr().equals(celBlockExpr), "Expected cel.block to be present at root");

// Assert correctness on block indices used in subexpressions
CelCall celBlockCall = celBlockExpr.call();
ImmutableList<CelExpr> subexprs = celBlockCall.args().get(0).list().elements();
for (int i = 0; i < subexprs.size(); i++) {
verifyBlockIndex(subexprs.get(i), i);
}

// Assert correctness on block indices used in block result
CelExpr blockResult = celBlockCall.args().get(1);
verifyBlockIndex(blockResult, subexprs.size());
boolean resultHasAtLeastOneBlockIndex =
CelNavigableExpr.fromExpr(blockResult)
.allNodes()
.map(CelNavigableExpr::expr)
.anyMatch(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX));
Verify.verify(
resultHasAtLeastOneBlockIndex,
"Expected at least one reference of index in cel.block result");

verifyNoInvalidScopedMangledVariables(celBlockExpr);
}

private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
boolean areAllIndicesValid =
CelNavigableExpr.fromExpr(celExpr)
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX))
.map(CelExpr::ident)
.allMatch(
blockIdent ->
Integer.parseInt(blockIdent.name().substring(BLOCK_INDEX_PREFIX.length()))
< maxIndexValue);
Verify.verify(
areAllIndicesValid,
"Illegal block index found. The index value must be less than %s. Expr: %s",
maxIndexValue,
celExpr);
verifyNoInvalidScopedMangledVariables(celBlock.expr());
}

private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
Expand Down
Loading
Loading