Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -922,35 +922,6 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
elif to_type == ibis_dtypes.time:
return x_converted.time()

if to_type == ibis_dtypes.json:
if x.type() == ibis_dtypes.string:
return parse_json_in_safe(x) if op.safe else parse_json(x)
if x.type() == ibis_dtypes.bool:
x_bool = typing.cast(
ibis_types.StringValue,
bigframes.core.compile.ibis_types.cast_ibis_value(
x, ibis_dtypes.string, safe=op.safe
),
).lower()
return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool)
if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64):
x_str = bigframes.core.compile.ibis_types.cast_ibis_value(
x, ibis_dtypes.string, safe=op.safe
)
return parse_json_in_safe(x_str) if op.safe else parse_json(x_str)

if x.type() == ibis_dtypes.json:
if to_type == ibis_dtypes.int64:
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
if to_type == ibis_dtypes.float64:
return (
cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
)
if to_type == ibis_dtypes.bool:
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
if to_type == ibis_dtypes.string:
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)

# TODO: either inline this function, or push rest of this op into the function
return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)

Expand Down Expand Up @@ -1193,9 +1164,27 @@ def parse_json_op_impl(x: ibis_types.Value, op: ops.ParseJSON):
return parse_json(json_str=x)


@scalar_op_compiler.register_unary_op(ops.ToJSON)
def to_json_op_impl(json_obj: ibis_types.Value):
return to_json(json_obj=json_obj)
@scalar_op_compiler.register_unary_op(ops.ToJSON, pass_op=True)
def to_json_op_impl(x: ibis_types.Value, op: ops.ToJSON):
if x.type() == ibis_dtypes.string:
return parse_json_in_safe(x) if op.safe else parse_json(x)
return x.isnull().ifelse(ibis.null().cast(ibis_dtypes.json), to_json(x))


@scalar_op_compiler.register_unary_op(ops.JSONDecode, pass_op=True)
def json_decode_op_impl(x: ibis_types.Value, op: ops.JSONDecode):
to_type = bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(
op.to_type
)
if to_type == ibis_dtypes.int64:
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
if to_type == ibis_dtypes.float64:
return cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
if to_type == ibis_dtypes.bool:
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
if to_type == ibis_dtypes.string:
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)
raise TypeError(f"Cannot cast from JSON to type {to_type}")


@scalar_op_compiler.register_unary_op(ops.ToJSONString)
Expand Down
32 changes: 26 additions & 6 deletions packages/bigframes/bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,20 @@ class PolarsExpressionCompiler:
Should be extended to dispatch based on bigframes schema types.
"""

@functools.singledispatchmethod
_expr_types: dict[int, bigframes.dtypes.ExpressionType] = dataclasses.field(
default_factory=dict, init=False, compare=False
)

def compile_expression(self, expression: ex.Expression) -> pl.Expr:
res = self._compile_expression(expression)
self._expr_types[id(res)] = expression.output_type
return res

@functools.singledispatchmethod
def _compile_expression(self, expression: ex.Expression) -> pl.Expr:
raise NotImplementedError(f"Cannot compile expression: {expression}")

@compile_expression.register
@_compile_expression.register
def _(
self,
expression: ex.ScalarConstantExpression,
Expand All @@ -159,21 +168,21 @@ def _(

return pl.lit(value, _bigframes_dtype_to_polars_dtype(expression.dtype))

@compile_expression.register
@_compile_expression.register
def _(
self,
expression: ex.DerefOp,
) -> pl.Expr:
return pl.col(expression.id.sql)

@compile_expression.register
@_compile_expression.register
def _(
self,
expression: ex.ResolvedDerefOp,
) -> pl.Expr:
return pl.col(expression.id.sql)

@compile_expression.register
@_compile_expression.register
def _(
self,
expression: ex.OpExpression,
Expand Down Expand Up @@ -478,10 +487,21 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
)

@compile_op.register(json_ops.JSONDecode)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
def _(self, op: json_ops.JSONDecode, input: pl.Expr) -> pl.Expr:
assert isinstance(op, json_ops.JSONDecode)
return input.str.json_decode(_DTYPE_MAPPING[op.to_type])

@compile_op.register(json_ops.ToJSON)
def _(self, op: json_ops.ToJSON, input: pl.Expr) -> pl.Expr:
from_type = self._expr_types.get(id(input))
if from_type in (
bigframes.dtypes.STRING_DTYPE,
bigframes.dtypes.JSON_DTYPE,
):
return input
else:
return input.cast(pl.String())

@compile_op.register(arr_ops.ToArrayOp)
def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
return pl.concat_list(*inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
comparison_ops,
datetime_ops,
generic_ops,
json_ops,
numeric_ops,
string_ops,
)
Expand Down Expand Up @@ -412,9 +411,6 @@ def _coerce_comparables(
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
if arg.output_type == cast_op.to_type:
return arg

if arg.output_type == dtypes.JSON_DTYPE:
return json_ops.JSONDecode(cast_op.to_type).as_expr(arg)
if (
arg.output_type == dtypes.STRING_DTYPE
and cast_op.to_type == dtypes.DATETIME_DTYPE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
sg_expr = expr.expr

if to_type == dtypes.JSON_DTYPE:
return _cast_to_json(expr, op)

if from_type == dtypes.JSON_DTYPE:
return _cast_from_json(expr, op)

if to_type == dtypes.INT_DTYPE:
result = _cast_to_int(expr, op)
if result is not None:
Expand Down Expand Up @@ -251,35 +245,6 @@ def _(*values: TypedExpr) -> sge.Expression:


# Helper functions
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
from_type = expr.dtype
sg_expr = expr.expr

if from_type == dtypes.STRING_DTYPE:
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
return sge.func(func_name, sg_expr)
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
sg_expr = sge.Cast(this=sg_expr, to="STRING")
return sge.func("PARSE_JSON", sg_expr)
raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")


def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
to_type = op.to_type
sg_expr = expr.expr
func_name = ""
if to_type == dtypes.INT_DTYPE:
func_name = "INT64"
elif to_type == dtypes.FLOAT_DTYPE:
func_name = "FLOAT64"
elif to_type == dtypes.BOOL_DTYPE:
func_name = "BOOL"
elif to_type == dtypes.STRING_DTYPE:
func_name = "STRING"
if func_name:
func_name = "SAFE." + func_name if op.safe else func_name
return sge.func(func_name, sg_expr)
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")


def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import bigframes_vendored.sqlglot.expressions as sge

import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

Expand Down Expand Up @@ -69,9 +70,39 @@ def _(expr: TypedExpr) -> sge.Expression:
return sge.func("PARSE_JSON", expr.expr)


@register_unary_op(ops.ToJSON)
def _(expr: TypedExpr) -> sge.Expression:
return sge.func("TO_JSON", expr.expr)
@register_unary_op(ops.ToJSON, pass_op=True)
def _(expr: TypedExpr, op: ops.ToJSON) -> sge.Expression:
from_type = expr.dtype
sg_expr = expr.expr

# Parsing really should be a distinct operation from serialization, but
# this was the way things were intially launched.
if from_type == dtypes.STRING_DTYPE:
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
return sge.func(func_name, sg_expr)
else:
return sge.func(
"IF", sg_expr.is_(sge.Null()), sge.Null(), sge.func("TO_JSON", sg_expr)
)


@register_unary_op(ops.JSONDecode, pass_op=True)
def _(expr: TypedExpr, op: ops.JSONDecode) -> sge.Expression:
to_type = op.to_type
sg_expr = expr.expr
func_name = ""
if to_type == dtypes.INT_DTYPE:
func_name = "INT64"
elif to_type == dtypes.FLOAT_DTYPE:
func_name = "FLOAT64"
elif to_type == dtypes.BOOL_DTYPE:
func_name = "BOOL"
elif to_type == dtypes.STRING_DTYPE:
func_name = "STRING"
if func_name:
func_name = "SAFE." + func_name if op.safe else func_name
return sge.func(func_name, sg_expr)
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")


@register_unary_op(ops.ToJSONString)
Expand Down
38 changes: 31 additions & 7 deletions packages/bigframes/bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,17 +442,41 @@ def astype(
if errors not in ["raise", "null"]:
raise ValueError("Arg 'error' must be one of 'raise' or 'null'")

if isinstance(dtype, dict):
for col in dtype:
if col not in self.columns:
raise KeyError(
f"Only Column Names are allowed in dtypes dict. '{col}' is not in the columns."
)

safe_cast = errors == "null"

if isinstance(dtype, dict):
result = self.copy()
for col, to_type in dtype.items():
result[col] = result[col].astype(to_type)
return result
exprs: list[ex.Expression] = []
for col_id, col_label in zip(
self._block.value_columns, self._block.column_labels
):
from_type = self._block._column_type(col_id)

dtype = bigframes.dtypes.bigframes_type(dtype)
if isinstance(dtype, dict):
if col_label not in dtype:
exprs.append(ex.deref(col_id))
continue
to_type = bigframes.dtypes.bigframes_type(dtype[col_label])
else:
to_type = bigframes.dtypes.bigframes_type(dtype)

op: ops.UnaryOp
if to_type == bigframes.dtypes.JSON_DTYPE:
op = ops.ToJSON(safe=safe_cast)
elif from_type == bigframes.dtypes.JSON_DTYPE:
op = ops.JSONDecode(to_type=to_type, safe=safe_cast)
else:
op = ops.AsTypeOp(to_type=to_type, safe=safe_cast)

return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
exprs.append(op.as_expr(ex.deref(col_id)))

block = self._block.project_exprs(exprs, labels=self.columns, drop=True)
return DataFrame(block)

def _should_sql_have_index(self) -> bool:
"""Should the SQL we pass to BQML and other I/O include the index?"""
Expand Down
24 changes: 22 additions & 2 deletions packages/bigframes/bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,30 @@ def is_json_like(type_: ExpressionType) -> bool:
return type_ == JSON_DTYPE or type_ == STRING_DTYPE # Including JSON string


def is_json_encoding_type(type_: ExpressionType) -> bool:
def is_json_encoding_type(type_: ExpressionType, strict: bool = False) -> bool:
# Types can be converted into JSON.
# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_encodings
return type_ != GEO_DTYPE
if is_array_like(type_):
return is_json_encoding_type(get_array_inner_type(type_), strict=strict)
if is_struct_like(type_):
return all(
is_json_encoding_type(field_type, strict=strict)
for field_type in get_struct_fields(type_).values()
)

if strict:
# Strict are the types (mostly) defined by json spec, with no/minimal
# encoding/decoding involved. So no temporal types.
return type_ in (
INT_DTYPE,
FLOAT_DTYPE,
BOOL_DTYPE,
STRING_DTYPE,
JSON_DTYPE,
)
else:
# GoogleSQL implementation handles anything but GEO
return type_ != GEO_DTYPE


def is_numeric(type_: ExpressionType, include_bool: bool = True) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
)
from bigframes.operations.googlesql import GoogleSqlScalarOp
from bigframes.operations.json_ops import (
JSONDecode,
JSONExtract,
JSONExtractArray,
JSONExtractStringArray,
Expand Down Expand Up @@ -382,6 +383,7 @@
"FloorDtOp",
"IntegerLabelToDatetimeOp",
# JSON ops
"JSONDecode",
"JSONExtract",
"JSONExtractArray",
"JSONExtractStringArray",
Expand Down
Loading
Loading