Skip to content

Commit

Permalink
reformat datafusion file properly
Browse files: browse the repository at this point in the history
  • Loading branch information
Venkat Allam committed Dec 26, 2024
1 parent dcb63ac commit 759a4c4
Showing 1 changed file with 18 additions and 36 deletions.
54 changes: 18 additions & 36 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -91,8 +91,7 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_decimal():
return self.cast(
sg.exp.convert(str(value)),
dt.Decimal(precision=dtype.precision or 38,
scale=dtype.scale or 9),
dt.Decimal(precision=dtype.precision or 38, scale=dtype.scale or 9),
)
elif dtype.is_numeric():
if isinstance(value, float):
Expand Down Expand Up @@ -132,8 +131,7 @@ def visit_Cast(self, op, *, arg, to):
if to.is_interval():
unit = to.unit.name.lower()
return sg.cast(
self.f.concat(self.cast(arg, dt.string),
f" {unit}"), "interval"
self.f.concat(self.cast(arg, dt.string), f" {unit}"), "interval"
)
if to.is_timestamp():
return self._to_timestamp(arg, to)
Expand Down Expand Up @@ -216,16 +214,14 @@ def visit_LPad(self, op, *, arg, length, pad):
return self.if_(
length <= self.f.length(arg),
arg,
self.f.concat(self.f.repeat(
pad, length - self.f.length(arg)), arg),
self.f.concat(self.f.repeat(pad, length - self.f.length(arg)), arg),
)

def visit_RPad(self, op, *, arg, length, pad):
return self.if_(
length <= self.f.length(arg),
arg,
self.f.concat(arg, self.f.repeat(
pad, length - self.f.length(arg))),
self.f.concat(arg, self.f.repeat(pad, length - self.f.length(arg))),
)

def visit_ExtractFragment(self, op, *, arg):
Expand Down Expand Up @@ -355,8 +351,7 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct
)
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(
this=cond, expression=where)
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_Covariance(self, op, *, left, right, how, where):
Expand Down Expand Up @@ -412,14 +407,11 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
def visit_DateFromYMD(self, op, *, year, month, day):
return self.cast(
self.f.concat(
self.f.lpad(
self.cast(self.cast(year, dt.int64), dt.string), 4, "0"),
self.f.lpad(self.cast(self.cast(year, dt.int64), dt.string), 4, "0"),
"-",
self.f.lpad(
self.cast(self.cast(month, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(month, dt.int64), dt.string), 2, "0"),
"-",
self.f.lpad(
self.cast(self.cast(day, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(day, dt.int64), dt.string), 2, "0"),
),
dt.date,
)
Expand All @@ -429,23 +421,17 @@ def visit_TimestampFromYMDHMS(
):
return self.f.to_timestamp_micros(
self.f.concat(
self.f.lpad(
self.cast(self.cast(year, dt.int64), dt.string), 4, "0"),
self.f.lpad(self.cast(self.cast(year, dt.int64), dt.string), 4, "0"),
"-",
self.f.lpad(
self.cast(self.cast(month, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(month, dt.int64), dt.string), 2, "0"),
"-",
self.f.lpad(
self.cast(self.cast(day, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(day, dt.int64), dt.string), 2, "0"),
"T",
self.f.lpad(
self.cast(self.cast(hours, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(hours, dt.int64), dt.string), 2, "0"),
":",
self.f.lpad(
self.cast(self.cast(minutes, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(minutes, dt.int64), dt.string), 2, "0"),
":",
self.f.lpad(
self.cast(self.cast(seconds, dt.int64), dt.string), 2, "0"),
self.f.lpad(self.cast(self.cast(seconds, dt.int64), dt.string), 2, "0"),
".000000Z",
)
)
Expand All @@ -459,22 +445,19 @@ def visit_ArrayIndex(self, op, *, arg, index):
def visit_StringConcat(self, op, *, arg):
any_args_null = (a.is_(NULL) for a in arg)
return self.if_(
sg.or_(*any_args_null), self.cast(NULL,
dt.string), self.f.concat(*arg)
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(
this=cond, expression=where)
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(
this=cond, expression=where)
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
Expand Down Expand Up @@ -517,8 +500,7 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
# datafusion lower cases all column names internally unless quoted so
# quoted=True is required here for correctness
by_names_quoted = tuple(
sg.column(key, table=getattr(
value, "table", None), quoted=quoted)
sg.column(key, table=getattr(value, "table", None), quoted=quoted)
for key, value in groups.items()
)
selections = by_names_quoted + metrics
Expand Down

0 comments on commit 759a4c4

Please sign in to comment.