-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala                      13
-rw-r--r--  python/pyspark/sql/functions.py                                                              80
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala   10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                 41
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala                   84
5 files changed, 199 insertions(+), 29 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7b2a451ca5..5e781a326d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.{Column, DataFrame, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
     val inputColNames = map(inputCols)
     val args = inputColNames.map { c =>
       schema(c).dataType match {
-        case DoubleType => UnresolvedAttribute(c)
-        case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
-        case _: NumericType | BooleanType =>
-          Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+        case DoubleType => dataset(c)
+        case _: VectorUDT => dataset(c)
+        case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
       }
     }
-    dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
+    dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol)))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
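
[Editor's note] The hunk above swaps hand-built catalyst expressions (UnresolvedAttribute, Alias, Cast, CreateStruct) for the public Column API plus the new struct function. A minimal sketch of the resulting pattern, assuming a hypothetical DataFrame df with a numeric column "x" and a boolean column "flag" (names are illustrative, not from the patch):

// Sketch only: mirrors the rewritten transform() above.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

def assembleAsStruct(df: DataFrame): DataFrame = {
  val args = Seq("x", "flag").map { c =>
    df.schema(c).dataType match {
      case DoubleType => df(c)                    // pass doubles through as-is
      case _: NumericType | BooleanType =>
        df(c).cast(DoubleType).as(s"${c}_double") // widen and alias everything else
    }
  }
  // Every element is a named expression (a column or an alias), so struct() accepts it.
  df.select(col("*"), struct(args: _*).as("assembled"))
}
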
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 241f821757..641220a264 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -24,13 +24,20 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import SparkContext
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 
-__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
+__all__ = [
+    'approxCountDistinct',
+    'countDistinct',
+    'monotonicallyIncreasingId',
+    'rand',
+    'randn',
+    'sparkPartitionId',
+    'udf']
 
 
 def _create_function(name, doc=""):
@@ -74,27 +81,21 @@ __all__ += _functions.keys()
 __all__.sort()
 
 
-def rand(seed=None):
-    """
-    Generate a random column with i.i.d. samples from U[0.0, 1.0].
-    """
-    sc = SparkContext._active_spark_context
-    if seed:
-        jc = sc._jvm.functions.rand(seed)
-    else:
-        jc = sc._jvm.functions.rand()
-    return Column(jc)
+def array(*cols):
+    """Creates a new array column.
 
+    :param cols: list of column names (string) or list of :class:`Column` expressions that have
+        the same data type.
 
-def randn(seed=None):
-    """
-    Generate a column with i.i.d. samples from the standard normal distribution.
+    >>> df.select(array('age', 'age').alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
+    >>> df.select(array([df.age, df.age]).alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
     """
     sc = SparkContext._active_spark_context
-    if seed:
-        jc = sc._jvm.functions.randn(seed)
-    else:
-        jc = sc._jvm.functions.randn()
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
     return Column(jc)
 
 
@@ -146,6 +147,28 @@ def monotonicallyIncreasingId():
     return Column(sc._jvm.functions.monotonicallyIncreasingId())
 
 
+def rand(seed=None):
+    """Generates a random column with i.i.d. samples from U[0.0, 1.0].
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """Generates a column with i.i.d. samples from the standard normal distribution.
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
 
@@ -158,6 +181,25 @@ def sparkPartitionId():
     return Column(sc._jvm.functions.sparkPartitionId())
 
 
+@ignore_unicode_prefix
+def struct(*cols):
+    """Creates a new struct column.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions
+        that are named or aliased.
+
+    >>> df.select(struct('age', 'name').alias("struct")).collect()
+    [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
+    >>> df.select(struct([df.age, df.name]).alias("struct")).collect()
+    [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
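
[Editor's note] The Python array/struct wrappers above are thin shims that forward to the JVM-side functions added in sql/core below; both the varargs form and the single-list form reduce to the same call. A usage sketch of the Scala equivalents of the doctests, assuming a SQLContext named sqlContext with its implicits imported (as the test suite below does):

// Sketch only: Scala-side counterparts of the Python doctests above.
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val people = Seq((2, "Alice"), (5, "Bob")).toDF("age", "name")

// Counterpart of df.select(array('age', 'age').alias("arr"))
people.select(array($"age", $"age").as("arr")).collect()

// Counterpart of df.select(struct('age', 'name').alias("struct")); the Python
// list form, struct([df.age, df.name]), is flattened by the wrapper before it
// reaches the same JVM functions.struct call.
people.select(struct("age", "name").as("struct")).collect()
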
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 2225621dba..c6217f07c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees
  * the layout of intermediate tuples, BindReferences should be run after all such transformations.
  */
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
-  extends Expression with trees.LeafNode[Expression] {
+  extends NamedExpression with trees.LeafNode[Expression] {
 
   type EvaluatedType = Any
 
   override def toString: String = s"input[$ordinal]"
 
   override def eval(input: Row): Any = input(ordinal)
+
+  override def name: String = s"i[$ordinal]"
+
+  override def toAttribute: Attribute = throw new UnsupportedOperationException
+
+  override def qualifiers: Seq[String] = throw new UnsupportedOperationException
+
+  override def exprId: ExprId = throw new UnsupportedOperationException
 }
 
 object BindReferences extends Logging {
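
[Editor's note] Why BoundReference becomes a NamedExpression: CreateStruct (used by the new struct function below) keeps Seq[NamedExpression] children, and BindReferences rewrites those children into BoundReference nodes at planning time, so the bound tree must still satisfy that child type. A hedged sketch against this commit's catalyst internals:

// Sketch only: exercises the types touched by this hunk.
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateStruct}
import org.apache.spark.sql.types.IntegerType

val ref = BoundReference(0, IntegerType, nullable = false)
ref.name                // "i[0]" -- doubles as the struct field name once bound
CreateStruct(Seq(ref))  // type-checks only because BoundReference is now named
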
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 242e64d3ff..7e283393d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -284,6 +284,23 @@ object functions {
   def abs(e: Column): Column = Abs(e.expr)
 
   /**
+   * Creates a new array column. The input columns must all have the same data type.
+   *
+   * @group normal_funcs
+   */
+  @scala.annotation.varargs
+  def array(cols: Column*): Column = CreateArray(cols.map(_.expr))
+
+  /**
+   * Creates a new array column. The input columns must all have the same data type.
+   *
+   * @group normal_funcs
+   */
+  def array(colName: String, colNames: String*): Column = {
+    array((colName +: colNames).map(col) : _*)
+  }
+
+  /**
    * Returns the first column that is not null.
    * {{{
    *   df.select(coalesce(df("a"), df("b")))
@@ -391,6 +408,28 @@ object functions {
   def sqrt(e: Column): Column = Sqrt(e.expr)
 
   /**
+   * Creates a new struct column. The input columns must be columns of a [[DataFrame]],
+   * or derived column expressions that are named (i.e. aliased).
+   *
+   * @group normal_funcs
+   */
+  @scala.annotation.varargs
+  def struct(cols: Column*): Column = {
+    require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
+      s"struct input columns must all be named or aliased ($cols)")
+    CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
+  }
+
+  /**
+   * Creates a new struct column that composes multiple input columns.
+   *
+   * @group normal_funcs
+   */
+  def struct(colName: String, colNames: String*): Column = {
+    struct((colName +: colNames).map(col) : _*)
+  }
+
+  /**
    * Converts a string expression to upper case.
    *
    * @group normal_funcs
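
[Editor's note] A sketch of the contract the require() above enforces: derived expressions must be aliased before they can become struct fields, while plain DataFrame columns already carry names. Assumes a SQLContext with implicits in scope; mirrors DataFrameFunctionsSuite below.

// Sketch only.
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val df = Seq((1, "str")).toDF("a", "b")

df.select(struct("a", "b"))                    // ok: real columns carry names
df.select(struct(($"a" * 2).as("c"), $"b"))    // ok: derived column is aliased
// struct($"a" * 2) throws IllegalArgumentException from the require() above,
// because a bare arithmetic expression is not a NamedExpression until aliased.
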
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
new file mode 100644
index 0000000000..ca03713ef4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for functions in [[org.apache.spark.sql.functions]].
+ */
+class DataFrameFunctionsSuite extends QueryTest {
+
+  test("array with column name") {
+    val df = Seq((0, 1)).toDF("a", "b")
+    val row = df.select(array("a", "b")).first()
+
+    val expectedType = ArrayType(IntegerType, containsNull = false)
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Seq[Int]](0) === Seq(0, 1))
+  }
+
+  test("array with column expression") {
+    val df = Seq((0, 1)).toDF("a", "b")
+    val row = df.select(array(col("a"), col("b") + col("b"))).first()
+
+    val expectedType = ArrayType(IntegerType, containsNull = false)
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
+  }
+
+  // Turn this on once we add a rule to the analyzer to throw a friendly exception
+  ignore("array: throw exception if putting columns of different types into an array") {
+    val df = Seq((0, "str")).toDF("a", "b")
+    intercept[AnalysisException] {
+      df.select(array("a", "b"))
+    }
+  }
+
+  test("struct with column name") {
+    val df = Seq((1, "str")).toDF("a", "b")
+    val row = df.select(struct("a", "b")).first()
+
+    val expectedType = StructType(Seq(
+      StructField("a", IntegerType, nullable = false),
+      StructField("b", StringType)
+    ))
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Row](0) === Row(1, "str"))
+  }
+
+  test("struct with column expression") {
+    val df = Seq((1, "str")).toDF("a", "b")
+    val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first()
+
+    val expectedType = StructType(Seq(
+      StructField("c", IntegerType, nullable = false),
+      StructField("b", StringType)
+    ))
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Row](0) === Row(2, "str"))
+  }
+
+  test("struct: must use named column expression") {
+    intercept[IllegalArgumentException] {
+      struct(col("a") * 2)
+    }
+  }
+}