aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-03-16 14:18:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-16 14:18:35 -0700
commit6fc2b6541fd5ab73b289af5f7296fc602b5b4dce (patch)
treeec8da69765b849a72e0faf5914f25a6dbd4d21f6 /mllib/src/test/java/org/apache
parent3f06eb72ca0c3e5779a702c7c677229e0c480751 (diff)
downloadspark-6fc2b6541fd5ab73b289af5f7296fc602b5b4dce.tar.gz
spark-6fc2b6541fd5ab73b289af5f7296fc602b5b4dce.tar.bz2
spark-6fc2b6541fd5ab73b289af5f7296fc602b5b4dce.zip
[SPARK-11888][ML] Decision tree persistence in spark.ml
### What changes were proposed in this pull request? Made these MLReadable and MLWritable: DecisionTreeClassifier, DecisionTreeClassificationModel, DecisionTreeRegressor, DecisionTreeRegressionModel * The shared implementation is in treeModels.scala * I use case classes to create a DataFrame to save, and I use the Dataset API to parse loaded files. Other changes: * Made CategoricalSplit.numCategories public (to use in persistence) * Fixed a bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite, where it did not call the checkModelData function passed as an argument. This caused an error in LDASuite, which I fixed. ### How was this patch tested? Persistence is tested via unit tests. For each algorithm, there are 2 non-trivial trees (depth 2). One is built with continuous features, and one with categorical; this ensures that both types of splits are tested. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11581 from jkbradley/dt-io.
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java2
6 files changed, 6 insertions, 6 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 0d923dfeff..1f23682621 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -29,7 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index f470f4ada6..74841058a2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 9a63cef2a8..75061464e5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index a1575300a8..fa3b28ed4f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 9477e8d2bf..8413ea0e0a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index a90535d11a..b6f793f6de 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -28,7 +28,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;