about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author: Reynold Xin <rxin@databricks.com>, 2015-05-01 12:49:02 -0700
committer: Xiangrui Meng <meng@databricks.com>, 2015-05-01 12:49:02 -0700
commit: 37537760d19eab878a5e1a48641cc49e6cb4b989 (patch)
tree: a37c553f7c27835399dd46a26fafd8dcc1613437 /mllib
parent: 16860327286bc08b4e2283d51b4c8fe024ba5006 (diff)
download: spark-37537760d19eab878a5e1a48641cc49e6cb4b989.tar.gz
spark-37537760d19eab878a5e1a48641cc49e6cb4b989.tar.bz2
spark-37537760d19eab878a5e1a48641cc49e6cb4b989.zip
[SPARK-7274] [SQL] Create Column expression for array/struct creation.
Author: Reynold Xin <rxin@databricks.com>

Closes #5802 from rxin/SPARK-7274 and squashes the following commits:

19aecaa [Reynold Xin] Fixed unicode tests.
bfc1538 [Reynold Xin] Export all Python functions.
2517b8c [Reynold Xin] Code review.
23da335 [Reynold Xin] Fixed Python bug.
132002e [Reynold Xin] Fixed tests.
56fce26 [Reynold Xin] Added Python support.
b0d591a [Reynold Xin] Fixed debug error.
86926a6 [Reynold Xin] Added test suite.
7dbb9ab [Reynold Xin] Ok one more.
470e2f5 [Reynold Xin] One more MLlib ...
e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 13
1 files changed, 5 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7b2a451ca5..5e781a326d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.{Column, DataFrame, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
val inputColNames = map(inputCols)
val args = inputColNames.map { c =>
schema(c).dataType match {
- case DoubleType => UnresolvedAttribute(c)
- case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
- case _: NumericType | BooleanType =>
- Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+ case DoubleType => dataset(c)
+ case _: VectorUDT => dataset(c)
+ case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
}
}
- dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
+ dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol)))
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {