diff options
Diffstat (limited to 'mllib/src/main')
4 files changed, 30 insertions, 19 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 3ac6c77669..5219680be2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -214,7 +214,7 @@ class RFormulaModel private[feature]( override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) - if (hasLabelCol(withFeatures)) { + if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) { withFeatures } else if (schema.exists(_.name == resolvedFormula.label)) { val nullable = schema(resolvedFormula.label).dataType match { @@ -236,7 +236,7 @@ class RFormulaModel private[feature]( private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label - if (hasLabelCol(dataset.schema)) { + if (labelName.isEmpty || hasLabelCol(dataset.schema)) { dataset.toDF } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 4079b387e1..cf52710ab8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -63,6 +63,9 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept) } + /** Whether this formula specifies fitting with response variable. */ + def hasLabel: Boolean = label.value.nonEmpty + /** Whether this formula specifies fitting with an intercept term. 
*/ def hasIntercept: Boolean = { var intercept = true @@ -159,6 +162,10 @@ private[ml] object RFormulaParser extends RegexParsers { private val columnRef: Parser[ColumnRef] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } + private val empty: Parser[ColumnRef] = "" ^^ { case a => ColumnRef("") } + + private val label: Parser[ColumnRef] = columnRef | empty + private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot } private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":") @@ -174,7 +181,7 @@ private[ml] object RFormulaParser extends RegexParsers { } private val formula: Parser[ParsedRFormula] = - (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + (label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { case Success(result, _) => result diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index f67760d3ca..4d4c303fc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.clustering.{KMeans, KMeansModel} -import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -65,28 +65,32 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] { def fit( data: DataFrame, - k: Double, - maxIter: Double, - initMode: String, - columns: Array[String]): KMeansWrapper = { + formula: String, + k: Int, + maxIter: Int, + initMode: String): KMeansWrapper = { + + val rFormulaModel = new RFormula() + 
.setFormula(formula) + .setFeaturesCol("features") + .fit(data) - val assembler = new VectorAssembler() - .setInputCols(columns) - .setOutputCol("features") + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) val kMeans = new KMeans() - .setK(k.toInt) - .setMaxIter(maxIter.toInt) + .setK(k) + .setMaxIter(maxIter) .setInitMode(initMode) val pipeline = new Pipeline() - .setStages(Array(assembler, kMeans)) + .setStages(Array(rFormulaModel, kMeans)) .fit(data) val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel] - val attrs = AttributeGroup.fromStructField( - kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol)) - val features: Array[String] = attrs.attributes.get.map(_.name.get) val size: Array[Long] = kMeansModel.summary.clusterSizes new KMeansWrapper(pipeline, features, size) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 9c0757941e..568c160ee5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.ml.util.MLReader /** - * This is the Scala stub of SparkR ml.load. It will dispatch the call to corresponding + * This is the Scala stub of SparkR read.ml. It will dispatch the call to corresponding * model wrapper loading function according the class name extracted from rMetadata of the path. 
*/ private[r] object RWrappers extends MLReader[Object] { @@ -45,7 +45,7 @@ private[r] object RWrappers extends MLReader[Object] { case "org.apache.spark.ml.r.KMeansWrapper" => KMeansWrapper.load(path) case _ => - throw new SparkException(s"SparkR ml.load does not support load $className") + throw new SparkException(s"SparkR read.ml does not support load $className") } } }