aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala13
2 files changed, 21 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 1b62e17ff4..b6ec7d3417 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
-
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType}
/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
return 0
}
}
+
+object RowOrdering {
+ def forSchema(dataTypes: Seq[DataType]): RowOrdering =
+ new RowOrdering(dataTypes.zipWithIndex.map {
+ case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ })
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 288c11f69f..fb4217a448 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -94,6 +94,9 @@ sealed trait Partitioning {
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean
+
+ /** Returns the expressions that are used to key the partitioning. */
+ def keyExpressions: Seq[Expression]
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case UnknownPartitioning(_) => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
case SinglePartition => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
/**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
+ override def keyExpressions: Seq[Expression] = expressions
+
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
+ override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}