diff options
Diffstat (limited to 'sql/catalyst/src')
8 files changed, 91 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9a643465a9..b475abdce2 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -324,7 +324,7 @@ queryPrimary ; sortItem - : expression ordering=(ASC | DESC)? + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? ; querySpecification @@ -641,7 +641,8 @@ number nonReserved : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST + | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS @@ -729,6 +730,8 @@ UNBOUNDED: 'UNBOUNDED'; PRECEDING: 'PRECEDING'; FOLLOWING: 'FOLLOWING'; CURRENT: 'CURRENT'; +FIRST: 'FIRST'; +LAST: 'LAST'; ROW: 'ROW'; WITH: 'WITH'; VALUES: 'VALUES'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 18f814d6cd..92bf8e0536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -714,9 +714,9 @@ class Analyzer( case s @ Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction) + SortOrder(child.output(index - 1), direction, nullOrdering) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 6d8dc86282..af0a565f73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { - case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _) => + case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8549187a66..66e52ca68a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -109,8 +109,9 @@ package object dsl { def cast(to: DataType): Expression = Cast(expr, to) def asc: SortOrder = SortOrder(expr, Ascending) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast) def desc: SortOrder = SortOrder(expr, Descending) - + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index de779ed370..d015125bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -21,26 +21,43 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator -import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { def sql: String + def defaultNullOrdering: NullOrdering +} + +abstract sealed class NullOrdering { + def sql: String } case object Ascending extends SortDirection { override def sql: String = "ASC" + override def defaultNullOrdering: NullOrdering = NullsFirst } case object Descending extends SortDirection { override def sql: String = "DESC" + override def defaultNullOrdering: NullOrdering = NullsLast +} + +case object NullsFirst extends NullOrdering{ + override def sql: String = "NULLS FIRST" +} + +case object NullsLast extends NullOrdering{ + override def sql: String = "NULLS LAST" } /** * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) +case class SortOrder( + child: Expression, + direction: SortDirection, + nullOrdering: NullOrdering) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. */ @@ -57,12 +74,18 @@ case class SortOrder(child: Expression, direction: SortDirection) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - override def toString: String = s"$child ${direction.sql}" - override def sql: String = child.sql + " " + direction.sql + override def toString: String = s"$child ${direction.sql} ${nullOrdering.sql}" + override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql def isAscending: Boolean = direction == Ascending } +object SortOrder { + def apply(child: Expression, direction: SortDirection): SortOrder = { + new SortOrder(child, direction, direction.defaultNullOrdering) + } +} + /** * An expression to generate a 64-bit long prefix used in sorting. If the sort must operate over * null keys as well, this.nullValue can be used in place of emitted null prefixes in the sort. @@ -71,14 +94,35 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val nullValue = child.child.dataType match { case BooleanType | DateType | TimestampType | _: IntegralType => - Long.MinValue + if (nullAsSmallest) { + Long.MinValue + } else { + Long.MaxValue + } case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - Long.MinValue + if (nullAsSmallest) { + Long.MinValue + } else { + Long.MaxValue + } case _: DecimalType => - DoublePrefixComparator.computePrefix(Double.NegativeInfinity) - case _ => 0L + if (nullAsSmallest) { + DoublePrefixComparator.computePrefix(Double.NegativeInfinity) + } else { + DoublePrefixComparator.computePrefix(Double.NaN) + } + case _ => + if (nullAsSmallest) { + 0L + } else { + -1L + } } + private def nullAsSmallest: Boolean = (child.isAscending && child.nullOrdering == NullsFirst) || + (!child.isAscending && child.nullOrdering == NullsLast) + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -86,6 +130,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + val StringPrefixCmp = classOf[StringPrefixComparator].getName val prefixCode = child.child.dataType match { case BooleanType => s"$input ? 1L : 0L" @@ -95,7 +140,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { s"(long) $input" case FloatType | DoubleType => s"$DoublePrefixCmp.computePrefix((double)$input)" - case StringType => s"$input.getPrefix()" + case StringType => s"$StringPrefixCmp.computePrefix($input)" case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)" case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => if (dt.precision <= Decimal.MAX_LONG_DIGITS) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f4d35d232e..e7df95e114 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -63,7 +63,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR */ def genComparisons(ctx: CodegenContext, schema: StructType): String = { val ordering = schema.fields.map(_.dataType).zipWithIndex.map { - case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + case(dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) } genComparisons(ctx, ordering) } @@ -74,7 +74,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val comparisons = ordering.map { order => val eval = order.child.genCode(ctx) - val asc = order.direction == Ascending + val asc = order.isAscending val isNullA = ctx.freshName("isNullA") val primitiveA = ctx.freshName("primitiveA") val isNullB = ctx.freshName("isNullB") @@ -99,9 +99,17 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR if ($isNullA && $isNullB) { // Nothing } else if ($isNullA) { - return ${if (order.direction == Ascending) "-1" else "1"}; + return ${ + order.nullOrdering match { + case NullsFirst => "-1" + case NullsLast => "1" + }}; } else if ($isNullB) { - return ${if (order.direction == Ascending) "1" else "-1"}; + return ${ + order.nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + }}; } else { int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6112259fed..79d2052c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -39,9 +39,9 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow if (left == null && right == null) { // Both null, continue looking. } else if (left == null) { - return if (order.direction == Ascending) -1 else 1 + return if (order.nullOrdering == NullsFirst) -1 else 1 } else if (right == null) { - return if (order.direction == Ascending) 1 else -1 + return if (order.nullOrdering == NullsFirst) 1 else -1 } else { val comparison = order.dataType match { case dt: AtomicType if order.direction == Ascending => @@ -76,7 +76,7 @@ object InterpretedOrdering { */ def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { new InterpretedOrdering(dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bbbb14df88..69d68fa6f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1206,11 +1206,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[SortOrder]] expression. */ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { - if (ctx.DESC != null) { - SortOrder(expression(ctx.expression), Descending) + val direction = if (ctx.DESC != null) { + Descending } else { - SortOrder(expression(ctx.expression), Ascending) + Ascending } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering) } /** |