Diffstat (limited to 'mllib/src')
6 files changed, 50 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 3ac6c77669..5219680be2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -214,7 +214,7 @@ class RFormulaModel private[feature](
   override def transformSchema(schema: StructType): StructType = {
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
-    if (hasLabelCol(withFeatures)) {
+    if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
       withFeatures
     } else if (schema.exists(_.name == resolvedFormula.label)) {
       val nullable = schema(resolvedFormula.label).dataType match {
@@ -236,7 +236,7 @@ class RFormulaModel private[feature](

   private def transformLabel(dataset: Dataset[_]): DataFrame = {
     val labelName = resolvedFormula.label
-    if (hasLabelCol(dataset.schema)) {
+    if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
       dataset.toDF
     } else if (dataset.schema.exists(_.name == labelName)) {
       dataset.schema(labelName).dataType match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 4079b387e1..cf52710ab8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -63,6 +63,9 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
     ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
   }

+  /** Whether this formula specifies fitting with response variable. */
+  def hasLabel: Boolean = label.value.nonEmpty
+
   /** Whether this formula specifies fitting with an intercept term. */
   def hasIntercept: Boolean = {
     var intercept = true
@@ -159,6 +162,10 @@ private[ml] object RFormulaParser extends RegexParsers {
   private val columnRef: Parser[ColumnRef] =
     "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }

+  private val empty: Parser[ColumnRef] = "" ^^ { case a => ColumnRef("") }
+
+  private val label: Parser[ColumnRef] = columnRef | empty
+
   private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }

   private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":")
@@ -174,7 +181,7 @@ private[ml] object RFormulaParser extends RegexParsers {
   }

   private val formula: Parser[ParsedRFormula] =
-    (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+    (label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }

   def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
     case Success(result, _) => result
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index f67760d3ca..4d4c303fc8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}

@@ -65,28 +65,32 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {

   def fit(
       data: DataFrame,
-      k: Double,
-      maxIter: Double,
-      initMode: String,
-      columns: Array[String]): KMeansWrapper = {
+      formula: String,
+      k: Int,
+      maxIter: Int,
+      initMode: String): KMeansWrapper = {
+
+    val rFormulaModel = new RFormula()
+      .setFormula(formula)
+      .setFeaturesCol("features")
+      .fit(data)

-    val assembler = new VectorAssembler()
-      .setInputCols(columns)
-      .setOutputCol("features")
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)

     val kMeans = new KMeans()
-      .setK(k.toInt)
-      .setMaxIter(maxIter.toInt)
+      .setK(k)
+      .setMaxIter(maxIter)
       .setInitMode(initMode)

     val pipeline = new Pipeline()
-      .setStages(Array(assembler, kMeans))
+      .setStages(Array(rFormulaModel, kMeans))
       .fit(data)

     val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
-    val attrs = AttributeGroup.fromStructField(
-      kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
-    val features: Array[String] = attrs.attributes.get.map(_.name.get)
     val size: Array[Long] = kMeansModel.summary.clusterSizes

     new KMeansWrapper(pipeline, features, size)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 9c0757941e..568c160ee5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.ml.util.MLReader

 /**
- * This is the Scala stub of SparkR ml.load. It will dispatch the call to corresponding
+ * This is the Scala stub of SparkR read.ml. It will dispatch the call to corresponding
  * model wrapper loading function according the class name extracted from rMetadata of the path.
  */
 private[r] object RWrappers extends MLReader[Object] {
@@ -45,7 +45,7 @@ private[r] object RWrappers extends MLReader[Object] {
       case "org.apache.spark.ml.r.KMeansWrapper" =>
         KMeansWrapper.load(path)
       case _ =>
-        throw new SparkException(s"SparkR ml.load does not support load $className")
+        throw new SparkException(s"SparkR read.ml does not support load $className")
     }
   }
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 66b2ceacb0..5f1d5987e8 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -72,7 +72,7 @@ public class JavaStatisticsSuite implements Serializable {
     Double corr1 = Statistics.corr(x, y);
     Double corr2 = Statistics.corr(x, y, "pearson");
     // Check default method
-    assertEquals(corr1, corr2);
+    assertEquals(corr1, corr2, 1e-5);
   }

   @Test
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index e1b269b5b6..f8476953d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row

 class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
@@ -89,6 +90,24 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(resultSchema.toString == model.transform(original).schema.toString)
   }

+  test("allow empty label") {
+    val original = sqlContext.createDataFrame(
+      Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
+    ).toDF("id", "a", "b")
+    val formula = new RFormula().setFormula("~ a + b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val resultSchema = model.transformSchema(original.schema)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
+        (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
+        (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)))
+      ).toDF("id", "a", "b", "features")
+    assert(result.schema.toString == resultSchema.toString)
+    assert(result.collect() === expected.collect())
+  }
+
   test("encodes string terms") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original = sqlContext.createDataFrame(
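
For context on what the RFormula change enables: a formula with no response variable, such as "~ a + b", now fits and transforms without creating a label column. A minimal sketch of that behavior, mirroring the new "allow empty label" test above; it assumes a session of this Spark vintage where sqlContext is in scope, and is not part of the patch:

    import org.apache.spark.ml.feature.RFormula

    // Feature-only formula: no response variable is named before "~".
    val df = sqlContext.createDataFrame(
      Seq((1, 2.0, 3.0), (4, 5.0, 6.0))
    ).toDF("id", "a", "b")

    val model = new RFormula().setFormula("~ a + b").fit(df)

    // transformLabel is a no-op because resolvedFormula.label is empty,
    // so the output gains only a "features" column and no "label" column.
    val out = model.transform(df)
    assert(out.schema.fieldNames.contains("features"))
    assert(!out.schema.fieldNames.contains("label"))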
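The parser change is the crux: label now matches either a column reference or the empty string, so a formula may begin directly with "~". A standalone sketch of the same combinator pattern, using scala.util.parsing.combinator directly rather than Spark's private parser; all names here are illustrative, not Spark's:

    import scala.util.parsing.combinator.RegexParsers

    // Mini formula parser: alternation with an always-succeeding
    // empty-string parser makes the response variable optional.
    object FormulaSketch extends RegexParsers {
      case class Parsed(label: String, terms: List[String])

      private val ref: Parser[String] = "[a-zA-Z][a-zA-Z0-9._]*".r
      private val empty: Parser[String] = "" ^^ { _ => "" }
      private val label: Parser[String] = ref | empty
      private val terms: Parser[List[String]] = rep1sep(ref, "+")

      val formula: Parser[Parsed] =
        (label ~ "~" ~ terms) ^^ { case l ~ "~" ~ t => Parsed(l, t) }

      def main(args: Array[String]): Unit = {
        println(parseAll(formula, "y ~ a + b")) // label = "y"
        println(parseAll(formula, "~ a + b"))   // label = "" (no response)
      }
    }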
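The KMeansWrapper rewrite also moves feature-name discovery from the caller-supplied columns array to the ML attribute metadata that RFormula attaches to its output column. A sketch of that lookup in isolation, under the same sqlContext assumption as above:

    import org.apache.spark.ml.attribute.AttributeGroup
    import org.apache.spark.ml.feature.RFormula

    val data = sqlContext.createDataFrame(
      Seq((1.0, 2.0), (3.0, 4.0))
    ).toDF("a", "b")

    val rFormulaModel = new RFormula()
      .setFormula("~ a + b")
      .setFeaturesCol("features")
      .fit(data)

    // RFormula records one ML attribute per assembled input column, so the
    // original column names are recoverable from the transformed schema.
    val schema = rFormulaModel.transform(data).schema
    val featureNames: Array[String] =
      AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
        .attributes.get.map(_.name.get)
    // featureNames == Array("a", "b")

This is why the wrapper no longer needs a columns parameter: the R-side formula alone determines both the assembled vector and the names reported back to SparkR.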