123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375 |
- /* ====================================================================
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==================================================================== */
-
- /*
- * Notes:
- * Duplicate x values don't work most of the time because of the way the
- * math library handles multiple regression.
- * The math library currently fails when the number of x variables is >=
- * the sample size (see https://github.com/Hipparchus-Math/hipparchus/issues/13).
- */
-
- package org.apache.poi.ss.formula.functions;
-
- import java.util.Arrays;
-
- import org.apache.commons.math3.linear.SingularMatrixException;
- import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
- import org.apache.poi.ss.formula.CacheAreaEval;
- import org.apache.poi.ss.formula.eval.AreaEval;
- import org.apache.poi.ss.formula.eval.BoolEval;
- import org.apache.poi.ss.formula.eval.ErrorEval;
- import org.apache.poi.ss.formula.eval.EvaluationException;
- import org.apache.poi.ss.formula.eval.MissingArgEval;
- import org.apache.poi.ss.formula.eval.NotImplementedException;
- import org.apache.poi.ss.formula.eval.NumberEval;
- import org.apache.poi.ss.formula.eval.NumericValueEval;
- import org.apache.poi.ss.formula.eval.RefEval;
- import org.apache.poi.ss.formula.eval.ValueEval;
-
-
- /**
- * Implementation for the Excel function TREND<p>
- *
- * Syntax:<br>
- * TREND(known_y's, known_x's, new_x's, constant)
- * <table border="0" cellpadding="1" cellspacing="0" summary="Parameter descriptions">
- * <tr><th>known_y's, known_x's, new_x's</th><td>typically area references, possibly cell references or scalar values</td></tr>
- * <tr><th>constant</th><td><b>TRUE</b> or <b>FALSE</b>:
- * determines whether the regression line should include an intercept term</td></tr>
- * </table><br>
- * If <b>known_x's</b> is not given, it is assumed to be the default array {1, 2, 3, ...}
- * of the same size as <b>known_y's</b>.<br>
- * If <b>new_x's</b> is not given, it is assumed to be the same as <b>known_x's</b><br>
- * If <b>constant</b> is omitted, it is assumed to be <b>TRUE</b>
- * </p>
- */
-
- public final class Trend implements Function {
- MatrixFunction.MutableValueCollector collector = new MatrixFunction.MutableValueCollector(false, false);
- private static final class TrendResults {
- public double[] vals;
- public int resultWidth;
- public int resultHeight;
-
- public TrendResults(double[] vals, int resultWidth, int resultHeight) {
- this.vals = vals;
- this.resultWidth = resultWidth;
- this.resultHeight = resultHeight;
- }
- }
-
- public ValueEval evaluate(ValueEval[] args, int srcRowIndex, int srcColumnIndex) {
- if (args.length < 1 || args.length > 4) {
- return ErrorEval.VALUE_INVALID;
- }
- try {
- TrendResults tr = getNewY(args);
- ValueEval[] vals = new ValueEval[tr.vals.length];
- for (int i = 0; i < tr.vals.length; i++) {
- vals[i] = new NumberEval(tr.vals[i]);
- }
- if (tr.vals.length == 1) {
- return vals[0];
- }
- return new CacheAreaEval(srcRowIndex, srcColumnIndex, srcRowIndex + tr.resultHeight - 1, srcColumnIndex + tr.resultWidth - 1, vals);
- } catch (EvaluationException e) {
- return e.getErrorEval();
- }
- }
-
- private static double[][] evalToArray(ValueEval arg) throws EvaluationException {
- double[][] ar;
- ValueEval eval;
- if (arg instanceof MissingArgEval) {
- return new double[0][0];
- }
- if (arg instanceof RefEval) {
- RefEval re = (RefEval) arg;
- if (re.getNumberOfSheets() > 1) {
- throw new EvaluationException(ErrorEval.VALUE_INVALID);
- }
- eval = re.getInnerValueEval(re.getFirstSheetIndex());
- } else {
- eval = arg;
- }
- if (eval == null) {
- throw new RuntimeException("Parameter may not be null.");
- }
-
- if (eval instanceof AreaEval) {
- AreaEval ae = (AreaEval) eval;
- int w = ae.getWidth();
- int h = ae.getHeight();
- ar = new double[h][w];
- for (int i = 0; i < h; i++) {
- for (int j = 0; j < w; j++) {
- ValueEval ve = ae.getRelativeValue(i, j);
- if (!(ve instanceof NumericValueEval)) {
- throw new EvaluationException(ErrorEval.VALUE_INVALID);
- }
- ar[i][j] = ((NumericValueEval)ve).getNumberValue();
- }
- }
- } else if (eval instanceof NumericValueEval) {
- ar = new double[1][1];
- ar[0][0] = ((NumericValueEval)eval).getNumberValue();
- } else {
- throw new EvaluationException(ErrorEval.VALUE_INVALID);
- }
-
- return ar;
- }
-
- private static double[][] getDefaultArrayOneD(int w) {
- double[][] array = new double[w][1];
- for (int i = 0; i < w; i++) {
- array[i][0] = i + 1.;
- }
- return array;
- }
-
- private static double[] flattenArray(double[][] twoD) {
- if (twoD.length < 1) {
- return new double[0];
- }
- double[] oneD = new double[twoD.length * twoD[0].length];
- for (int i = 0; i < twoD.length; i++) {
- System.arraycopy(twoD[i], 0, oneD, i * twoD[0].length + 0, twoD[0].length);
- }
- return oneD;
- }
-
- private static double[][] flattenArrayToRow(double[][] twoD) {
- if (twoD.length < 1) {
- return new double[0][0];
- }
- double[][] oneD = new double[twoD.length * twoD[0].length][1];
- for (int i = 0; i < twoD.length; i++) {
- for (int j = 0; j < twoD[0].length; j++) {
- oneD[i * twoD[0].length + j][0] = twoD[i][j];
- }
- }
- return oneD;
- }
-
- private static double[][] switchRowsColumns(double[][] array) {
- double[][] newArray = new double[array[0].length][array.length];
- for (int i = 0; i < array.length; i++) {
- for (int j = 0; j < array[0].length; j++) {
- newArray[j][i] = array[i][j];
- }
- }
- return newArray;
- }
-
- /**
- * Check if all columns in a matrix contain the same values.
- * Return true if the number of distinct values in each column is 1.
- *
- * @param matrix column-oriented matrix. A Row matrix should be transposed to column .
- * @return true if all columns contain the same value
- */
- private static boolean isAllColumnsSame(double[][] matrix){
- if(matrix.length == 0) return false;
-
- boolean[] cols = new boolean[matrix[0].length];
- for (int j = 0; j < matrix[0].length; j++) {
- double prev = Double.NaN;
- for (int i = 0; i < matrix.length; i++) {
- double v = matrix[i][j];
- if(i > 0 && v != prev) {
- cols[j] = true;
- break;
- }
- prev = v;
- }
- }
- boolean allEquals = true;
- for (boolean x : cols) {
- if(x) {
- allEquals = false;
- break;
- }
- }
- return allEquals;
-
- }
-
- private static TrendResults getNewY(ValueEval[] args) throws EvaluationException {
- double[][] xOrig;
- double[][] x;
- double[][] yOrig;
- double[] y;
- double[][] newXOrig;
- double[][] newX;
- double[][] resultSize;
- boolean passThroughOrigin = false;
- switch (args.length) {
- case 1:
- yOrig = evalToArray(args[0]);
- xOrig = new double[0][0];
- newXOrig = new double[0][0];
- break;
- case 2:
- yOrig = evalToArray(args[0]);
- xOrig = evalToArray(args[1]);
- newXOrig = new double[0][0];
- break;
- case 3:
- yOrig = evalToArray(args[0]);
- xOrig = evalToArray(args[1]);
- newXOrig = evalToArray(args[2]);
- break;
- case 4:
- yOrig = evalToArray(args[0]);
- xOrig = evalToArray(args[1]);
- newXOrig = evalToArray(args[2]);
- if (!(args[3] instanceof BoolEval)) {
- throw new EvaluationException(ErrorEval.VALUE_INVALID);
- }
- // The argument in Excel is false when it *should* pass through the origin.
- passThroughOrigin = !((BoolEval)args[3]).getBooleanValue();
- break;
- default:
- throw new EvaluationException(ErrorEval.VALUE_INVALID);
- }
-
- if (yOrig.length < 1) {
- throw new EvaluationException(ErrorEval.VALUE_INVALID);
- }
- y = flattenArray(yOrig);
- newX = newXOrig;
-
- if (newXOrig.length > 0) {
- resultSize = newXOrig;
- } else {
- resultSize = new double[1][1];
- }
-
- if (y.length == 1) {
- /* See comment at top of file
- if (xOrig.length > 0 && !(xOrig.length == 1 || xOrig[0].length == 1)) {
- throw new EvaluationException(ErrorEval.REF_INVALID);
- } else if (xOrig.length < 1) {
- x = new double[1][1];
- x[0][0] = 1;
- } else {
- x = new double[1][];
- x[0] = flattenArray(xOrig);
- if (newXOrig.length < 1) {
- resultSize = xOrig;
- }
- }*/
- throw new NotImplementedException("Sample size too small");
- } else if (yOrig.length == 1 || yOrig[0].length == 1) {
- if (xOrig.length < 1) {
- x = getDefaultArrayOneD(y.length);
- if (newXOrig.length < 1) {
- resultSize = yOrig;
- }
- } else {
- x = xOrig;
- if (xOrig[0].length > 1 && yOrig.length == 1) {
- x = switchRowsColumns(x);
- }
- if (newXOrig.length < 1) {
- resultSize = xOrig;
- }
- }
- if (newXOrig.length > 0 && (x.length == 1 || x[0].length == 1)) {
- newX = flattenArrayToRow(newXOrig);
- }
- } else {
- if (xOrig.length < 1) {
- x = getDefaultArrayOneD(y.length);
- if (newXOrig.length < 1) {
- resultSize = yOrig;
- }
- } else {
- x = flattenArrayToRow(xOrig);
- if (newXOrig.length < 1) {
- resultSize = xOrig;
- }
- }
- if (newXOrig.length > 0) {
- newX = flattenArrayToRow(newXOrig);
- }
- if (y.length != x.length || yOrig.length != xOrig.length) {
- throw new EvaluationException(ErrorEval.REF_INVALID);
- }
- }
-
- if (newXOrig.length < 1) {
- newX = x;
- } else if (newXOrig.length == 1 && newXOrig[0].length > 1 && xOrig.length > 1 && xOrig[0].length == 1) {
- newX = switchRowsColumns(newXOrig);
- }
-
- if (newX[0].length != x[0].length) {
- throw new EvaluationException(ErrorEval.REF_INVALID);
- }
-
- if (x[0].length >= x.length) {
- /* See comment at top of file */
- throw new NotImplementedException("Sample size too small");
- }
-
- int resultHeight = resultSize.length;
- int resultWidth = resultSize[0].length;
-
- if(isAllColumnsSame(x)){
- double[] result = new double[newX.length];
- double avg = Arrays.stream(y).average().orElse(0);
- for(int i = 0; i < result.length; i++) result[i] = avg;
- return new TrendResults(result, resultWidth, resultHeight);
- }
-
- OLSMultipleLinearRegression reg = new OLSMultipleLinearRegression();
- if (passThroughOrigin) {
- reg.setNoIntercept(true);
- }
-
- try {
- reg.newSampleData(y, x);
- } catch (IllegalArgumentException e) {
- throw new EvaluationException(ErrorEval.REF_INVALID);
- }
- double[] par;
- try {
- par = reg.estimateRegressionParameters();
- } catch (SingularMatrixException e) {
- throw new NotImplementedException("Singular matrix in input");
- }
-
- double[] result = new double[newX.length];
- for (int i = 0; i < newX.length; i++) {
- result[i] = 0;
- if (passThroughOrigin) {
- for (int j = 0; j < par.length; j++) {
- result[i] += par[j] * newX[i][j];
- }
- } else {
- result[i] = par[0];
- for (int j = 1; j < par.length; j++) {
- result[i] += par[j] * newX[i][j - 1];
- }
- }
- }
- return new TrendResults(result, resultWidth, resultHeight);
- }
- }
|