about | summary | refs | log | tree | commit | diff
path: root/mllib/src
diff options
context:
space:
mode:
author: zero323 <zero323@users.noreply.github.com> 2017-03-20 10:58:30 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2017-03-20 10:58:30 -0700
commit: bec6b16c1900fe93def89cc5eb51cbef498196cb (patch)
tree: 8df7d1b0ccee69a81bdc461887ef163a339b82ba /mllib/src
parent: fc7554599a4b6e5c22aa35e7296b424a653a420b (diff)
download: spark-bec6b16c1900fe93def89cc5eb51cbef498196cb.tar.gz
spark-bec6b16c1900fe93def89cc5eb51cbef498196cb.tar.bz2
spark-bec6b16c1900fe93def89cc5eb51cbef498196cb.zip
[SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth
## What changes were proposed in this pull request?

Replaces `featuresCol` `Param` with `itemsCol`. See [SPARK-19899](https://issues.apache.org/jira/browse/SPARK-19899).

## How was this patch tested?

Manual tests. Existing unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17321 from zero323/SPARK-19899.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala      | 35
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14
2 files changed, 31 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index fa39dd954a..e2bc270b38 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.HasPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
FPGrowth => MLlibFPGrowth}
@@ -37,7 +37,20 @@ import org.apache.spark.sql.types._
/**
* Common params for FPGrowth and FPGrowthModel
*/
-private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
+private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
+
+ /**
+ * Items column name.
+ * Default: "items"
+ * @group param
+ */
+ @Since("2.2.0")
+ val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")
+ setDefault(itemsCol -> "items")
+
+ /** @group getParam */
+ @Since("2.2.0")
+ def getItemsCol: String = $(itemsCol)
/**
* Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
@@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre
*/
@Since("2.2.0")
protected def validateAndTransformSchema(schema: StructType): StructType = {
- val inputType = schema($(featuresCol)).dataType
+ val inputType = schema($(itemsCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
- SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
+ SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType)
}
}
@@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") (
/** @group setParam */
@Since("2.2.0")
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setItemsCol(value: String): this.type = set(itemsCol, value)
/** @group setParam */
@Since("2.2.0")
@@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") (
}
private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
- val data = dataset.select($(featuresCol))
- val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
+ val data = dataset.select($(itemsCol))
+ val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
if (isSet(numPartitions)) {
mllibFP.setNumPartitions($(numPartitions))
@@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") (
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
val schema = StructType(Seq(
- StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
+ StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
@@ -198,7 +211,7 @@ class FPGrowthModel private[ml] (
/** @group setParam */
@Since("2.2.0")
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setItemsCol(value: String): this.type = set(itemsCol, value)
/** @group setParam */
@Since("2.2.0")
@@ -235,7 +248,7 @@ class FPGrowthModel private[ml] (
.collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
- val dt = dataset.schema($(featuresCol)).dataType
+ val dt = dataset.schema($(itemsCol)).dataType
// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[_]) => {
if (items != null) {
@@ -249,7 +262,7 @@ class FPGrowthModel private[ml] (
} else {
Seq.empty
}}, dt)
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
}
@Since("2.2.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 910d4b07d1..4603a618d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("FPGrowth fit and transform with different data types") {
Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt =>
- val data = dataset.withColumn("features", col("features").cast(ArrayType(dt)))
+ val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
val model = new FPGrowth().setMinSupport(0.5).fit(data)
val generatedRules = model.setMinConfidence(0.5).associationRules
val expectedRules = spark.createDataFrame(Seq(
@@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
(0, Array("1", "2"), Array.emptyIntArray),
(0, Array("1", "2"), Array.emptyIntArray),
(0, Array("1", "3"), Array(2))
- )).toDF("id", "features", "prediction")
- .withColumn("features", col("features").cast(ArrayType(dt)))
+ )).toDF("id", "items", "prediction")
+ .withColumn("items", col("items").cast(ArrayType(dt)))
.withColumn("prediction", col("prediction").cast(ArrayType(dt)))
assert(expectedTransformed.collect().toSet.equals(
transformed.collect().toSet))
@@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
(1, Array("1", "2", "3", "5")),
(2, Array("1", "2", "3", "4")),
(3, null.asInstanceOf[Array[String]])
- )).toDF("id", "features")
+ )).toDF("id", "items")
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
val prediction = model.transform(df)
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
@@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
- ).map(Tuple1(_))).toDF("features")
+ ).map(Tuple1(_))).toDF("items")
val model = new FPGrowth().fit(dataset)
val prediction = model.transform(
- spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
+ spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
).first().getAs[Seq[String]]("prediction")
assert(prediction === Seq("3"))
@@ -127,7 +127,7 @@ object FPGrowthSuite {
(0, Array("1", "2")),
(0, Array("1", "2")),
(0, Array("1", "3"))
- )).toDF("id", "features")
+ )).toDF("id", "items")
}
/**