-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala                              |  9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala              | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala               | 27
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala                  |  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                      |  5
6 files changed, 46 insertions(+), 35 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index edaa2afb79..333b42711e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -122,9 +122,7 @@ abstract class Predictor[
*/
protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
dataset.select($(labelCol), $(featuresCol))
- .map { case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
- }
+ .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }
}
}
@@ -171,7 +169,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
- dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
+ val predictUDF = udf { (features: Any) =>
+ predict(features.asInstanceOf[FeaturesType])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 14c285dbfc..85c097bc64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -102,15 +102,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
var outputData = dataset
var numColsOutput = 0
if (getRawPredictionCol != "") {
- outputData = outputData.withColumn(getRawPredictionCol,
- callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
+ val predictRawUDF = udf { (features: Any) =>
+ predictRaw(features.asInstanceOf[FeaturesType])
+ }
+ outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
numColsOutput += 1
}
if (getPredictionCol != "") {
val predUDF = if (getRawPredictionCol != "") {
- callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
+ udf(raw2prediction _).apply(col(getRawPredictionCol))
} else {
- callUDF(predict _, DoubleType, col(getFeaturesCol))
+ val predictUDF = udf { (features: Any) =>
+ predict(features.asInstanceOf[FeaturesType])
+ }
+ predictUDF(col(getFeaturesCol))
}
outputData = outputData.withColumn(getPredictionCol, predUDF)
numColsOutput += 1
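The udf(raw2prediction _) form above is worth noting: the trailing underscore eta-expands the method into a function value so udf can infer both argument and return types, and .apply(...) then binds it to a column in one expression. A hedged sketch of that idiom, where raw2prediction is a local placeholder rather than the real model method:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, udf}

    object EtaExpansionSketch {
      // Placeholder: picks the index of the largest raw score, as a
      // raw-score-to-label mapping typically would.
      def raw2prediction(raw: Seq[Double]): Double =
        raw.indexOf(raw.max).toDouble

      def addPrediction(df: DataFrame): DataFrame = {
        // `raw2prediction _` eta-expands the method to a Function1, letting
        // udf infer Seq[Double] => Double without an explicit DataType.
        val predUDF = udf(raw2prediction _)
        df.withColumn("prediction", predUDF(col("rawPrediction")))
      }
    }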
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index b657882f8a..ea757c5e40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -88,9 +88,9 @@ final class OneVsRestModel private[ml] (
// add an accumulator column to store predictions of all the models
val accColName = "mbc$acc" + UUID.randomUUID().toString
- val init: () => Map[Int, Double] = () => {Map()}
+ val initUDF = udf { () => Map[Int, Double]() }
val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
- val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
+ val newDataset = dataset.withColumn(accColName, initUDF())
// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -106,13 +106,12 @@ final class OneVsRestModel private[ml] (
// add temporary column to store intermediate scores and update
val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
- val update: (Map[Int, Double], Vector) => Map[Int, Double] =
- (predictions: Map[Int, Double], prediction: Vector) => {
- predictions + ((index, prediction(1)))
- }
- val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
+ val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
+ predictions + ((index, prediction(1)))
+ }
val transformedDataset = model.transform(df).select(columns : _*)
- val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
+ val updatedDataset = transformedDataset
+ .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
val newColumns = origCols ++ List(col(tmpColName))
// switch out the intermediate column with the accumulator column
@@ -124,13 +123,13 @@ final class OneVsRestModel private[ml] (
}
// output the index of the classifier with highest confidence as prediction
- val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
+ val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}
// output label and label metadata as prediction
- val labelUdf = callUDF(label, DoubleType, col(accColName))
- aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ aggregatedDataset
+ .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
.drop(accColName)
}
@@ -185,17 +184,15 @@ final class OneVsRest(override val uid: String)
// create k columns, one for each binary classifier.
val models = Range(0, numClasses).par.map { index =>
-
- val label: Double => Double = (label: Double) => {
+ val labelUDF = udf { (label: Double) =>
if (label.toInt == index) 1.0 else 0.0
}
// generate new label metadata for the binary problem.
// TODO: use when ... otherwise after SPARK-7321 is merged
- val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
- val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
+ val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val classifier = getClassifier
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
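Two details of the rewrite above generalize: a zero-argument udf such as initUDF gets its MapType(IntegerType, DoubleType) schema inferred from the Map[Int, Double] return type, and a udf body may close over driver-side values such as the classifier index. A simplified sketch of both, with a Double score standing in for the Vector of raw predictions used in the real code:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.udf

    object AccumulatorColumnSketch {
      // Zero-arg udf: the map schema is inferred, so the explicit mapType
      // argument that callUDF required is gone.
      val initUDF = udf { () => Map[Int, Double]() }

      // The closure captures `index` from the driver, so each binary model
      // stamps its own slot in the accumulator map.
      def updateColumn(index: Int, acc: Column, score: Column): Column = {
        val updateUDF = udf { (predictions: Map[Int, Double], s: Double) =>
          predictions + (index -> s)
        }
        updateUDF(acc, score)
      }
    }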
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 330ae2938f..38e8323726 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -98,26 +98,34 @@ private[spark] abstract class ProbabilisticClassificationModel[
var outputData = dataset
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
- outputData = outputData.withColumn(getRawPredictionCol,
- callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
+ val predictRawUDF = udf { (features: Any) =>
+ predictRaw(features.asInstanceOf[FeaturesType])
+ }
+ outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
numColsOutput += 1
}
if ($(probabilityCol).nonEmpty) {
val probUDF = if ($(rawPredictionCol).nonEmpty) {
- callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol)))
+ udf(raw2probability _).apply(col($(rawPredictionCol)))
} else {
- callUDF(predictProbability _, new VectorUDT, col($(featuresCol)))
+ val probabilityUDF = udf { (features: Any) =>
+ predictProbability(features.asInstanceOf[FeaturesType])
+ }
+ probabilityUDF(col($(featuresCol)))
}
outputData = outputData.withColumn($(probabilityCol), probUDF)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
val predUDF = if ($(rawPredictionCol).nonEmpty) {
- callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol)))
+ udf(raw2prediction _).apply(col($(rawPredictionCol)))
} else if ($(probabilityCol).nonEmpty) {
- callUDF(probability2prediction _, DoubleType, col($(probabilityCol)))
+ udf(probability2prediction _).apply(col($(probabilityCol)))
} else {
- callUDF(predict _, DoubleType, col($(featuresCol)))
+ val predictUDF = udf { (features: Any) =>
+ predict(features.asInstanceOf[FeaturesType])
+ }
+ predictUDF(col($(featuresCol)))
}
outputData = outputData.withColumn($(predictionCol), predUDF)
numColsOutput += 1
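The surrounding logic keeps a fallback chain: derive the prediction from the cheapest column already produced (rawPrediction, then probability, then the raw features). A condensed sketch of that selection, with placeholder conversion methods standing in for the model's own:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, udf}

    object PredictionFallbackSketch {
      // Placeholders for the model's conversion methods.
      def raw2prediction(raw: Seq[Double]): Double = raw.indexOf(raw.max).toDouble
      def probability2prediction(prob: Seq[Double]): Double = prob.indexOf(prob.max).toDouble
      def predict(features: Seq[Double]): Double = 0.0

      // Prefer columns that are already materialized, mirroring the
      // raw -> probability -> features order in the patch.
      def predictionColumn(df: DataFrame): Column =
        if (df.columns.contains("rawPrediction")) {
          udf(raw2prediction _).apply(col("rawPrediction"))
        } else if (df.columns.contains("probability")) {
          udf(probability2prediction _).apply(col("probability"))
        } else {
          udf(predict _).apply(col("features"))
        }
    }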
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index f4854a5e4b..c73bdccdef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -30,7 +30,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions.callUDF
+import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.collection.OpenHashSet
@@ -339,7 +339,8 @@ class VectorIndexerModel private[ml] (
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
val newField = prepOutputField(dataset.schema)
- val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol)))
+ val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
+ val newCol = transformUDF(dataset($(inputCol)))
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
}
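Here the udf both consumes and returns an mllib Vector; the VectorUDT schema comes from the @SQLUserDefinedType annotation on Vector, so the explicit new VectorUDT argument disappears as well. A small sketch, with an illustrative doubling transform in place of the model's transformFunc:

    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.sql.functions.udf

    object VectorUdfSketch {
      // Vector carries its own UDT, so udf can infer the VectorUDT schema
      // for both the argument and the return value.
      val scaleUDF = udf { (v: Vector) =>
        Vectors.dense(v.toArray.map(_ * 2.0))
      }
    }

The .as($(outputCol), newField.metadata) call is unchanged in spirit: it attaches the prepared attribute metadata to the new column regardless of how the udf was built.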
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 73bc6c9991..22c54e43c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -137,13 +137,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("SPARK-7158 collect and take return different results") {
import java.util.UUID
- import org.apache.spark.sql.types._
val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
// we expect the id to be materialized once
- def id: () => String = () => { UUID.randomUUID().toString() }
- val dfWithId = df.withColumn("id", callUDF(id, StringType))
+ val idUdf = udf(() => UUID.randomUUID().toString)
+ val dfWithId = df.withColumn("id", idUdf())
// Make a new DataFrame (actually the same reference to the old one)
val cached = dfWithId.cache()
// Trigger the cache
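The test hunk is truncated here; for context, the assertion being set up is that a non-deterministic udf column, once cached, yields the same values to collect() and take(). A hedged reconstruction of the shape of such a test, with illustrative names and assertions:

    import java.util.UUID
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.udf

    object CacheOnceSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("cache-once").getOrCreate()
        import spark.implicits._

        val idUdf = udf(() => UUID.randomUUID().toString)
        val dfWithId = Seq(1, 2, 3).toDF("index").withColumn("id", idUdf())

        // cache() materializes the random ids once; without it, collect()
        // and take() would re-evaluate the udf and disagree.
        val cached = dfWithId.cache()
        cached.count() // trigger the cache

        val collected = cached.collect().map(_.getString(1)).toSet
        val taken = cached.take(3).map(_.getString(1)).toSet
        assert(collected == taken)
        spark.stop()
      }
    }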