Diffstat (limited to 'mllib')
5 files changed, 20 insertions, 36 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 29cd981078..6eb7ea639c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types._
 
@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
     dataset.select($"*", callUDF(
-      this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
+      this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 101f6c8114..d82360dcce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
 
@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val score: Vector => Double = (v) => {
+    val scoreFunction: Vector => Double = (v) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
     }
     val t = map(threshold)
-    val predict: Double => Double = (score) => {
-      if (score > t) 1.0 else 0.0
-    }
-    dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
+    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+    dataset
+      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c456beb65d..78a48561dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
     val scale: (Vector) => Vector = (v) => {
       scaler.transform(v)
     }
-    dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
+    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
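The common thread in the three hunks above is column resolution: inputs are no longer wrapped with the Column(name) constructor but resolved with col(name), or bound to a specific DataFrame with dataset(name), and the UDF output is appended alongside $"*". A minimal sketch of that pattern, assuming the same Spark 1.3-era org.apache.spark.sql.api.scala.dsl._ import used in these files; the function and column names here are illustrative, not from the patch:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.api.scala.dsl._

    // Keep every existing column ($"*") and append the UDF result under outputCol.
    def appendScaled(dataset: DataFrame, inputCol: String, outputCol: String): DataFrame = {
      val scale: Double => Double = _ * 2.0  // stand-in for a real transform function
      dataset.select($"*", callUDF(scale, col(inputCol)).as(outputCol))
    }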
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 738b1844b5..474d4731ec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    import dataset.sqlContext._
-    import org.apache.spark.ml.recommendation.ALSModel.Factor
+    import dataset.sqlContext.createDataFrame
     val map = this.paramMap ++ paramMap
-    // TODO: Add DSL to simplify the code here.
-    val instanceTable = s"instance_$uid"
-    val userTable = s"user_$uid"
-    val itemTable = s"item_$uid"
-    val instances = dataset.as(instanceTable)
-    val users = userFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(userTable)
-    val items = itemFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(itemTable)
+    val users = userFactors.toDataFrame("id", "features")
+    val items = itemFactors.toDataFrame("id", "features")
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,13 +123,14 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
-      .as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
-    instances
-      .join(users, Column(map(userCol)) === $"$userTable.id", "left")
-      .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+    val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    dataset
+      .join(users, dataset(map(userCol)) === users("id"), "left")
+      .join(items, dataset(map(itemCol)) === items("id"), "left")
       .select(outputColumns: _*)
+    // TODO: Just use a dataset("*")
+    // .select(dataset("*"), prediction)
   }
 
   override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
@@ -147,10 +138,6 @@ class ALSModel private[ml] (
   }
 }
 
-private object ALSModel {
-  /** Case class to convert factors to [[DataFrame]]s */
-  private case class Factor(id: Int, features: Seq[Float])
-}
 
 /**
  * Alternating Least Squares (ALS) matrix factorization.
@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
     val ratings = dataset
-      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
       .map { row =>
         new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
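With the Factor case class and the per-model table aliases gone, the factor RDDs become DataFrames directly and every column reference is anchored to the DataFrame that owns it. A sketch of the resulting join shape, assuming a dataset with a user id column named "user" and a userFactors: RDD[(Int, Array[Float])] in scope; both names are illustrative:

    import dataset.sqlContext.createDataFrame  // implicit conversion enabling toDataFrame

    val users = userFactors.toDataFrame("id", "features")
    // A left join keeps every input row; unmatched rows carry null features,
    // which the predict UDF in the hunk above maps to Float.NaN.
    val joined = dataset.join(users, dataset("user") === users("id"), "left")
    // dataset("*") is not supported yet (see the TODO), so re-select each input column.
    val outputColumns = dataset.schema.fieldNames.map(f => dataset(f)) :+ users("features")
    joined.select(outputColumns: _*)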
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 31c33f1bf6..567a8a6c03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
 /**
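The Vectors.scala hunk reflects Row graduating to the public org.apache.spark.sql package while GenericMutableRow stays internal to Catalyst. A tiny sketch of how the two interact in UDT-style serialization at this commit, assuming only the 1.3-era classes named in the hunk; the values are illustrative:

    import org.apache.spark.sql.Row                                     // public API
    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow  // internal

    // Build a mutable row internally, expose it through the public Row trait.
    val mutable = new GenericMutableRow(1)
    mutable.update(0, 3.0)
    val row: Row = mutable
    val x = row.getDouble(0)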