aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-08-14 12:00:01 -0700
committerReynold Xin <rxin@databricks.com>2015-08-14 12:00:01 -0700
commit34d610be854d2a975d9c1e232d87433b85add6fd (patch)
tree5bbf9882166496d4fe16cf8f6af4087c9e150f4b /mllib/src
parenta7317ccdc20d001e5b7f5277b0535923468bfbc6 (diff)
downloadspark-34d610be854d2a975d9c1e232d87433b85add6fd.tar.gz
spark-34d610be854d2a975d9c1e232d87433b85add6fd.tar.bz2
spark-34d610be854d2a975d9c1e232d87433b85add6fd.zip
[SPARK-9929] [SQL] support metadata in withColumn
in MLlib sometimes we need to set metadata for the new column, thus we will alias the new column with metadata before call `withColumn` and in `withColumn` we alias this clolumn again. Here I overloaded `withColumn` to allow user set metadata, just like what we did for `Column.as`. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8159 from cloud-fan/withColumn.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala3
4 files changed, 6 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1132d8046d..c62e132f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
aggregatedDataset
- .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+ .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
}
@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
// TODO: use when ... otherwise after SPARK-7321 is merged
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
- val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
- val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+ val trainingDataset =
+ multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index cfca494dcf..6fdf25b015 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
}
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ dataset.withColumn($(outputCol), newCol, newField.metadata)
}
private def prepOutputField(schema: StructType): StructField = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 6875aefe06..61b925c0fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
val newCol = transformUDF(dataset($(inputCol)))
- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ dataset.withColumn($(outputCol), newCol, newField.metadata)
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 772bebeff2..c5c2272270 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
case features: SparseVector => features.slice(inds)
}
}
- dataset.withColumn($(outputCol),
- slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+ dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
}
/** Get the feature indices in order: indices, names */