diff options
author | Reynold Xin <rxin@databricks.com> | 2015-05-01 12:49:02 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-01 12:49:02 -0700 |
commit | 37537760d19eab878a5e1a48641cc49e6cb4b989 (patch) | |
tree | a37c553f7c27835399dd46a26fafd8dcc1613437 /mllib | |
parent | 16860327286bc08b4e2283d51b4c8fe024ba5006 (diff) | |
download | spark-37537760d19eab878a5e1a48641cc49e6cb4b989.tar.gz spark-37537760d19eab878a5e1a48641cc49e6cb4b989.tar.bz2 spark-37537760d19eab878a5e1a48641cc49e6cb4b989.zip |
[SPARK-7274] [SQL] Create Column expression for array/struct creation.
Author: Reynold Xin <rxin@databricks.com>
Closes #5802 from rxin/SPARK-7274 and squashes the following commits:
19aecaa [Reynold Xin] Fixed unicode tests.
bfc1538 [Reynold Xin] Export all Python functions.
2517b8c [Reynold Xin] Code review.
23da335 [Reynold Xin] Fixed Python bug.
132002e [Reynold Xin] Fixed tests.
56fce26 [Reynold Xin] Added Python support.
b0d591a [Reynold Xin] Fixed debug error.
86926a6 [Reynold Xin] Added test suite.
7dbb9ab [Reynold Xin] Ok one more.
470e2f5 [Reynold Xin] One more MLlib ...
e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7b2a451ca5..5e781a326d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val inputColNames = map(inputCols) val args = inputColNames.map { c => schema(c).dataType match { - case DoubleType => UnresolvedAttribute(c) - case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType | BooleanType => - Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { |