aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2015-08-19 15:43:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-19 15:43:15 -0700
commit56a37b01fd07f4f1a8cb4e07b55e1a02cf23a5f7 (patch)
treed9c9bf104ceedee879727274aa78e9f6f6eb9f52
parent5c749c82cb3caa5a41fd3fd49c32ab23c6f738da (diff)
downloadspark-56a37b01fd07f4f1a8cb4e07b55e1a02cf23a5f7.tar.gz
spark-56a37b01fd07f4f1a8cb4e07b55e1a02cf23a5f7.tar.bz2
spark-56a37b01fd07f4f1a8cb4e07b55e1a02cf23a5f7.zip
[SPARK-9895] User Guide for RFormula Feature Transformer
mengxr Author: Eric Liang <ekl@databricks.com> Closes #8293 from ericl/docs-2. (cherry picked from commit 8e0a072f78b4902d5f7ccc6b15232ed202a117f9) Signed-off-by: Xiangrui Meng <meng@databricks.com>
-rw-r--r--docs/ml-features.md108
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala4
2 files changed, 110 insertions, 2 deletions
diff --git a/docs/ml-features.md b/docs/ml-features.md
index d0e8eeb7a7..6309db97be 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -1477,3 +1477,111 @@ print(output.select("features", "clicked").first())
</div>
</div>
+## RFormula
+
+`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula.
+
+**Examples**
+
+Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`:
+
+~~~
+id | country | hour | clicked
+---|---------|------|---------
+ 7 | "US" | 18 | 1.0
+ 8 | "CA" | 12 | 0.0
+ 9 | "NZ" | 15 | 0.0
+~~~
+
+If we use `RFormula` with a formula string of `clicked ~ country + hour`, which indicates that we want to
+predict `clicked` based on `country` and `hour`, after transformation we should get the following DataFrame:
+
+~~~
+id | country | hour | clicked | features | label
+---|---------|------|---------|------------------|-------
+ 7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0
+ 8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0
+ 9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0
+~~~
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+[`RFormula`](api/scala/index.html#org.apache.spark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns.
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.RFormula
+
+val dataset = sqlContext.createDataFrame(Seq(
+ (7, "US", 18, 1.0),
+ (8, "CA", 12, 0.0),
+ (9, "NZ", 15, 0.0)
+)).toDF("id", "country", "hour", "clicked")
+val formula = new RFormula()
+ .setFormula("clicked ~ country + hour")
+ .setFeaturesCol("features")
+ .setLabelCol("label")
+val output = formula.fit(dataset).transform(dataset)
+output.select("features", "label").show()
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+[`RFormula`](api/java/org/apache/spark/ml/feature/RFormula.html) takes an R formula string, and optional parameters for the names of its output columns.
+
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.RFormula;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("country", StringType, false),
+ createStructField("hour", IntegerType, false),
+ createStructField("clicked", DoubleType, false)
+});
+JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
+ RowFactory.create(7, "US", 18, 1.0),
+ RowFactory.create(8, "CA", 12, 0.0),
+ RowFactory.create(9, "NZ", 15, 0.0)
+));
+DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+
+RFormula formula = new RFormula()
+ .setFormula("clicked ~ country + hour")
+ .setFeaturesCol("features")
+ .setLabelCol("label");
+
+DataFrame output = formula.fit(dataset).transform(dataset);
+output.select("features", "label").show();
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+[`RFormula`](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns.
+
+{% highlight python %}
+from pyspark.ml.feature import RFormula
+
+dataset = sqlContext.createDataFrame(
+ [(7, "US", 18, 1.0),
+ (8, "CA", 12, 0.0),
+ (9, "NZ", 15, 0.0)],
+ ["id", "country", "hour", "clicked"])
+formula = RFormula(
+ formula="clicked ~ country + hour",
+ featuresCol="features",
+ labelCol="label")
+output = formula.fit(dataset).transform(dataset)
+output.select("features", "label").show()
+{% endhighlight %}
+</div>
+</div>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index a752dacd72..a7fa504442 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -42,8 +42,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
/**
* :: Experimental ::
* Implements the transforms required for fitting a dataset against an R model formula. Currently
- * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
- * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the
+ * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
*/
@Experimental
class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {