-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala |  6 +++---
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala       |  2 +-
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala    |  2 +-
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala     |  3 +--
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala            | 17 +++++++++++++++++
5 files changed, 23 insertions(+), 7 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1132d8046d..c62e132f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
aggregatedDataset
- .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+ .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
}
@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
// TODO: use when ... otherwise after SPARK-7321 is merged
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
- val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
- val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+ val trainingDataset =
+ multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
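
The mllib changes all follow one pattern: metadata that was previously attached through Column.as(name, metadata) is now passed straight to the new withColumn overload added in DataFrame.scala below. A minimal before/after sketch, assuming a hypothetical DataFrame df and input column "raw" (note the overload is private[spark], so the new form only compiles inside Spark's own packages):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.MetadataBuilder

    // Hypothetical metadata value, for illustration only.
    val meta = new MetadataBuilder().putString("comment", "predicted label").build()

    // Before this patch: alias the column and attach metadata via Column.as.
    def oldStyle(df: DataFrame): DataFrame =
      df.withColumn("prediction", col("raw").as("prediction", meta))

    // After this patch: pass the metadata directly.
    def newStyle(df: DataFrame): DataFrame =
      df.withColumn("prediction", col("raw"), meta)
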
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index cfca494dcf..6fdf25b015 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
}
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ dataset.withColumn($(outputCol), newCol, newField.metadata)
}
private def prepOutputField(schema: StructType): StructField = {
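
For context on what newField.metadata holds here: prepOutputField (whose body is not shown in this hunk) builds a nominal attribute describing the buckets and converts it to metadata. A rough sketch of that kind of metadata, with hypothetical name and values:

    import org.apache.spark.ml.attribute.NominalAttribute

    // Hypothetical nominal attribute; the real code derives the values
    // from the configured split points.
    val bucketAttr = NominalAttribute.defaultAttr
      .withName("bucketedFeatures")
      .withValues("bucket0", "bucket1", "bucket2")
    val bucketMeta = bucketAttr.toMetadata()
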
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 6875aefe06..61b925c0fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
val newCol = transformUDF(dataset($(inputCol)))
- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ dataset.withColumn($(outputCol), newCol, newField.metadata)
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 772bebeff2..c5c2272270 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
case features: SparseVector => features.slice(inds)
}
}
- dataset.withColumn($(outputCol),
- slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+ dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
}
/** Get the feature indices in order: indices, names */
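
The point of carrying outputAttr.toMetadata() through withColumn is that downstream stages can recover the vector attributes from the transformed schema. A hypothetical round trip (output and the column name "features" are assumptions for illustration):

    import org.apache.spark.ml.attribute.AttributeGroup

    // Read the attribute group back out of the output column's StructField.
    val group = AttributeGroup.fromStructField(output.schema("features"))
    println(group.numAttributes) // e.g. Some(3) for a three-element slice
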
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index c466d9e6cb..cf75e64e88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1150,6 +1150,23 @@ class DataFrame private[sql](
}
/**
+ * Returns a new [[DataFrame]] by adding a column with metadata.
+ */
+ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
+ val resolver = sqlContext.analyzer.resolver
+ val replaced = schema.exists(f => resolver(f.name, colName))
+ if (replaced) {
+ val colNames = schema.map { field =>
+ val name = field.name
+ if (resolver(name, colName)) col.as(colName, metadata) else Column(name)
+ }
+ select(colNames : _*)
+ } else {
+ select(Column("*"), col.as(colName, metadata))
+ }
+ }
+
+ /**
* Returns a new [[DataFrame]] with a column renamed.
* This is a no-op if schema doesn't contain existingName.
* @group dfops
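
The new overload mirrors the existing two-argument withColumn: if colName resolves to an existing column, that column is replaced in place and the schema order is preserved; otherwise the column is appended via select. A semantics sketch with a hypothetical df holding columns a and b and a hypothetical metadata value meta (again, visible to Spark-internal callers only):

    // Replaces "a" in place; the resulting schema order stays [a, b].
    val replaced = df.withColumn("a", df("a") * 2, meta)

    // No column "c" exists, so it is appended; the schema becomes [a, b, c].
    val appended = df.withColumn("c", df("a") + df("b"), meta)
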