import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.Deflater;

/**
 * Collects labelled feature rows from hourly sensor data, then trains a random forest on them
 * and saves it to {@code qfq.dat}.
 *
 * <p>Rows are stored column-major ({@code training[feature][row]}) with one class label per row
 * in {@code classif}. Capacity is fixed at {@code maxInput} rows.
 */
public class QuakeTrainer {
	private float[][] training = null;
	private byte[] classif = null;
	private static final int numClasses = QuakePredictor.numClasses;
	// Fixed row capacity. train() fails with ArrayIndexOutOfBoundsException once it is exceeded.
	private static final int maxInput = 2160 * 9 * 75;
	private AtomicInteger idxTraining = new AtomicInteger();
	// Per-site state handed to QuakePredictor.extractFeatures (sized for 9 sites).
	private int[][][] vals = new int[9][5][2160];

	/**
	 * Adds one labelled feature row per site for {@code hour}.
	 *
	 * <p>Data with sample rate 32 is first converted with {@code QuakePredictor.conv} and then
	 * treated as rate 48. A row is labelled {@code rem + 1}, where
	 * {@code rem = (eqHour - hour) / 24 / QuakePredictor.group}, when it belongs to
	 * {@code eqSite} and {@code rem < numClasses - 1}. Every other row is labelled 0.
	 * NOTE(review): if {@code hour} is more than one bucket past {@code eqHour}, {@code rem + 1}
	 * becomes negative, which later breaks class counting. Confirm callers never do this.
	 *
	 * @param sampleRate   sample rate of {@code data} (32 is converted to 48)
	 * @param numOfSites   number of sites to extract rows for
	 * @param eqSite       site index whose rows receive a non-zero label
	 * @param eqHour       hour used to compute the label bucket
	 * @param hour         current hour
	 * @param data         raw samples passed to feature extraction
	 * @param K            multiplied by 100 and rounded before being passed to feature extraction
	 * @param globalQuakes not used by this method
	 */
	public void train(int sampleRate, int numOfSites, int eqSite, int eqHour, int hour, int[] data, double K, double[] globalQuakes) {
		if (sampleRate == 32) {
			data = QuakePredictor.conv(data);
			sampleRate = 48;
		}
		if (hour % 100 == 0) System.out.println("\tTraining: " + hour);
		int k = (int) Math.round(K * 100);
		for (int site = 0; site < numOfSites; site++) {
			List<Float> features = QuakePredictor.extractFeatures(vals[site], hour, site, sampleRate, data, k);

			// Lazily size the column-major storage from the first feature vector.
			if (training == null) {
				training = new float[features.size()][maxInput];
				classif = new byte[maxInput];
			}
			int idx = idxTraining.getAndIncrement();
			int rem = (eqHour - hour) / 24 / QuakePredictor.group;
			classif[idx] = (byte) (site != eqSite || rem >= numClasses - 1 ? 0 : (rem + 1));
			for (int l = 0; l < features.size(); l++) {
				training[l][idx] = features.get(l);
			}
		}
	}

	/**
	 * Trains a 64-tree forest on the rows collected so far and saves it to {@code qfq.dat}.
	 *
	 * <p>When no rows were collected, this prints a message and writes nothing. Training or I/O
	 * failures are printed, not thrown.
	 */
	public void done() {
		int rows = idxTraining.get();
		if (training == null || rows == 0) {
			System.out.println("No training rows collected; model not saved");
			return;
		}
		try {
			RandomForestBuilder builder = new RandomForestBuilder();
			RandomForestPredictor predictor = builder.train(training, classif, numClasses, 64, rows);
			save(predictor, new File("qfq.dat"));
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Writes the forest in the binary layout read back by {@code FilePredictor.load}.
	 *
	 * <p>The layout is: the tree count, one root index per tree, then the node count. Each node
	 * then has a short split feature. A split node (feature >= 0) follows it with the left-child
	 * index and the split value. A leaf follows it with the weights of classes
	 * 1..numClasses-1. All values are big-endian ({@link DataOutputStream}). The stream is
	 * closed even if a write fails.
	 *
	 * @param predictor forest to write
	 * @param file      destination file (overwritten)
	 * @throws Exception if writing fails
	 */
	void save(RandomForestPredictor predictor, File file) throws Exception {
		try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
			out.writeInt(predictor.trees);
			for (int i = 0; i < predictor.trees; i++) {
				out.writeInt(predictor.roots[i]);
			}
			out.writeInt(predictor.free);
			for (int i = 0; i < predictor.free; i++) {
				int splitFeature = predictor.splitFeature[i];
				out.writeShort(splitFeature);
				if (splitFeature >= 0) {
					out.writeInt(predictor.nodeLeft[i]);
					out.writeFloat(predictor.splitValue[i]);
				} else {
					// Class 0 is not stored; load() reads back classes 1..numClasses-1 only.
					for (int j = 1; j < numClasses; j++) {
						out.writeFloat(predictor.classif[i * numClasses + j]);
					}
				}
			}
		}
	}
}

/**
 * Reads a forest written by {@code QuakeTrainer.save}.
 */
class FilePredictor {
	/**
	 * Loads a forest from {@code file}, optionally keeping only its first {@code maxTrees} trees.
	 *
	 * <p>The expected big-endian layout is: the tree count, one root index per tree, then the
	 * node count. Each node then has a short split feature. A split node (feature >= 0) follows
	 * it with the left-child index and the split value. A leaf follows it with the weights of
	 * classes 1..numClasses-1.
	 *
	 * @param file       model file
	 * @param maxTrees   maximum number of trees to keep; values <= 0 keep all trees
	 * @param numClasses number of classes; must match the value used when saving
	 * @return the loaded forest
	 * @throws java.io.EOFException if the file ends before all required nodes are read
	 * @throws Exception            on any other I/O failure
	 */
	public static RandomForestPredictor load(File file, int maxTrees, int numClasses) throws Exception {
		try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
			int trees = in.readInt();
			int[] t = new int[trees];
			for (int i = 0; i < trees; i++) {
				t[i] = in.readInt();
			}
			int nodes = in.readInt();
			if (maxTrees > 0 && trees > maxTrees) {
				trees = maxTrees;
				// The root of the first dropped tree marks where the kept nodes end. This assumes
				// trees are stored contiguously in root order, which is how save() writes them.
				// TODO confirm against RandomForestPredictor.add. The previous t[trees + 1]
				// overran t when the file held exactly maxTrees + 1 trees.
				nodes = t[trees];
			}
			RandomForestPredictor predictor = new RandomForestPredictor(trees, nodes, numClasses);
			System.arraycopy(t, 0, predictor.roots, 0, trees);

			// DataInputStream reads exactly the bytes requested and throws on a short file,
			// unlike a single read(byte[]) call, which may return fewer bytes.
			for (int i = 0; i < nodes; i++) {
				int splitFeature = predictor.splitFeature[i] = in.readShort();
				if (splitFeature >= 0) {
					predictor.nodeLeft[i] = in.readInt();
					predictor.splitValue[i] = in.readFloat();
				} else {
					for (int j = 1; j < numClasses; j++) {
						predictor.classif[i * numClasses + j] = in.readFloat();
					}
				}
			}
			return predictor;
		}
	}

}

/**
 * Grows the trees of a random forest in parallel.
 */
class RandomForestBuilder {
	/**
	 * Builds {@code maxTrees} classification trees and collects them in a forest.
	 *
	 * <p>Tree {@code i} is grown by worker thread {@code i % numThreads}, using
	 * {@code max(1, cores - 1)} threads. Only worker 0 prints progress.
	 *
	 * @param features   column-major feature matrix, {@code features[feature][row]}
	 * @param classif    class label per row
	 * @param numClasses number of classes
	 * @param maxTrees   number of trees to grow
	 * @param rows       number of leading rows to train on
	 * @return the forest, containing every tree once all workers finish successfully
	 * @throws IllegalStateException if any worker fails (its exception is the cause), or if the
	 *                               calling thread is interrupted while waiting (the interrupt
	 *                               flag is restored)
	 */
	public RandomForestPredictor train(float[][] features, byte[] classif, int numClasses, int maxTrees, int rows) {
		RandomForestPredictor rf = new RandomForestPredictor(maxTrees, numClasses);
		// Holds the first worker failure, so a partially built forest is never returned silently.
		AtomicReference<Throwable> failure = new AtomicReference<>();

		int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
		Thread[] threads = new Thread[numThreads];
		for (int t = 0; t < numThreads; t++) {
			final int first = t;
			threads[t] = new Thread() {
				@Override
				public void run() {
					try {
						for (int i = first; i < maxTrees; i += numThreads) {
							ClassificationNode root = new ClassificationTree(features, classif, numClasses, i, rows).getRoot();
							rf.add(root);
							if (first == 0) System.out.println("    " + i);
						}
					} catch (Throwable e) {
						failure.compareAndSet(null, e);
					}
				}
			};
			threads[t].start();
		}
		try {
			for (Thread thread : threads) {
				thread.join();
			}
		} catch (InterruptedException e) {
			// Workers may still be adding trees, so rf must not be handed out.
			Thread.currentThread().interrupt();
			throw new IllegalStateException("Interrupted while training random forest", e);
		}
		Throwable cause = failure.get();
		if (cause != null) {
			throw new IllegalStateException("Random forest training failed", cause);
		}
		return rf;
	}
}

/**
 * Re-encodes a saved forest ({@code qfq.dat}) into a compact, deflated, text-embeddable form.
 * It prints Java declarations ({@code blen}, {@code zlen} and the {@code data} strings) to
 * stdout.
 *
 * <p>The compact layout is little-endian: a 3-byte tree count, a 3-byte root per tree, then a
 * 3-byte node count. Each node then has one byte holding {@code splitFeature + 1}. A split node
 * follows it with a 3-byte {@code nodeLeft + 1} and a 4-byte split value. A leaf follows it with
 * {@code numClasses - 1} 4-byte class weights.
 * NOTE(review): the feature byte only round-trips split features below 255, and leaves only map
 * to 0 if they store -1. Confirm against the decoder.
 */
class Converter {
	private static final String dict = QuakePredictor.dict;
	private static final int numClasses = QuakePredictor.numClasses;

	/**
	 * Loads the first 14 trees of {@code qfq.dat}, packs them in the compact layout, deflates
	 * the result and prints it as base-{@code dict.length()} text. It also prints a summary of
	 * the sizes. Errors are printed, not thrown.
	 */
	static void convert() {
		try {
			RandomForestPredictor predictor = FilePredictor.load(new File("qfq.dat"), 14, QuakePredictor.numClasses);
			// Upper bound on the packed size (for numClasses >= 2); pos tracks the actual size.
			byte[] bytes = new byte[3 + 3 * predictor.trees + 3 + 4 * numClasses * predictor.free];
			int pos = 0;

			writeInt(bytes, predictor.trees, pos);
			pos += 3;
			for (int i = 0; i < predictor.trees; i++) {
				writeInt(bytes, predictor.roots[i], pos);
				pos += 3;
			}
			writeInt(bytes, predictor.free, pos);
			pos += 3;

			for (int i = 0; i < predictor.free; i++) {
				int splitFeature = predictor.splitFeature[i];
				bytes[pos++] = (byte) (splitFeature + 1);
				if (splitFeature >= 0) {
					writeInt(bytes, predictor.nodeLeft[i] + 1, pos);
					pos += 3;
					writeFloat(bytes, predictor.splitValue[i], pos);
					pos += 4;
				} else {
					for (int j = 1; j < numClasses; j++) {
						writeFloat(bytes, predictor.classif[i * numClasses + j], pos);
						pos += 4;
					}
				}
			}
			byte[] zip = zipBytes(bytes, pos);
			StringBuilder sb = new StringBuilder();
			sb.append("private static int blen = " + pos + ";\n");
			sb.append("private static int zlen = " + zip.length + ";\n");
			sb.append("private static String[] data = new String[]{\n\"");
			int div = dict.length();
			// Each 32-bit big-endian word becomes 5 digits in base div, least significant first.
			// This is lossless only if div^5 >= 2^32, i.e. dict has at least 85 characters.
			for (int i = 0; i < zip.length; i += 4) {
				long v = ((zip[i] & 255L) << 24) | ((zip[i + 1] & 255L) << 16) | ((zip[i + 2] & 255L) << 8) | (zip[i + 3] & 255L);
				for (int j = 0; j < 5; j++) {
					sb.append(dict.charAt((int) (v % div)));
					v /= div;
				}
				// Start a new string literal periodically, presumably to stay under Java's
				// per-constant size limit.
				if (i > 0 && i % 2000 == 0) sb.append("\",\n\"");
			}
			sb.append("\"};");
			System.out.println(sb);
			System.out.println(pos + "/" + bytes.length + "/" + zip.length + "/" + sb.length());
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Deflates {@code input[0, len)} at compression level 9.
	 *
	 * <p>The output buffer grows as needed, so incompressible input is never truncated. The
	 * native Deflater resources are always released.
	 *
	 * @param input buffer holding the data to compress
	 * @param len   number of leading bytes of {@code input} to compress
	 * @return the complete deflate stream, zero-padded to a multiple of 4 bytes
	 * @throws IOException declared for compatibility; not thrown by this implementation
	 */
	public static byte[] zipBytes(byte[] input, int len) throws IOException {
		Deflater deflater = new Deflater(9);
		try {
			deflater.setInput(input, 0, len);
			deflater.finish();
			byte[] out = new byte[Math.max(64, len)];
			int size = 0;
			// A single deflate() call may stop when the buffer is full; keep going until the
			// stream is complete.
			while (!deflater.finished()) {
				if (size == out.length) {
					out = Arrays.copyOf(out, out.length * 2);
				}
				size += deflater.deflate(out, size, out.length - size);
			}
			// Zero-pad to whole 32-bit words for the word-based text encoding in convert().
			return Arrays.copyOf(out, (size + 3) / 4 * 4);
		} finally {
			deflater.end();
		}
	}

	/**
	 * Writes the low 24 bits of {@code val} little-endian into {@code bytes[pos..pos+2]}.
	 * Higher bits are dropped.
	 */
	static void writeInt(byte[] bytes, int val, int pos) {
		bytes[pos] = (byte) (val & 255);
		bytes[pos + 1] = (byte) ((val >>> 8) & 255);
		bytes[pos + 2] = (byte) ((val >>> 16) & 255);
	}

	/**
	 * Writes the IEEE-754 bits of {@code val} little-endian into {@code bytes[pos..pos+3]}.
	 */
	static void writeFloat(byte[] bytes, float val, int pos) {
		int b = Float.floatToIntBits(val);
		bytes[pos] = (byte) (b & 255);
		bytes[pos + 1] = (byte) ((b >>> 8) & 255);
		bytes[pos + 2] = (byte) ((b >>> 16) & 255);
		bytes[pos + 3] = (byte) ((b >>> 24) & 255);
	}
}

/**
 * One classification tree grown on a bootstrap sample of the training rows.
 *
 * <p>{@code features} is indexed column-major, {@code features[feature][row]}, and
 * {@code classif[row]} is each row's class label. Nodes are expanded breadth-first. Each split
 * is the best candidate by reduction in weighted entropy, among randomly sampled
 * (feature, threshold) pairs. All randomness comes from a {@link Random} seeded with the tree
 * index, so the order of the {@code rnd} calls below determines the resulting tree.
 */
class ClassificationTree {
	// Candidate thresholds sampled per candidate feature.
	private static int SPLIT_STEPS = 8;
	// Nodes whose weighted total is <= this are not split. Each child of a split must reach at least this.
	private static int MIN_ROWS_PER_NODE = 8;
	// Nodes at this level or deeper become leaves (the root is level 1).
	private static int MAX_LEVEL = 16;
	// Splitting stops once this many nodes exist.
	private static final int MAX_NODES = 65536;
	// Spare slots absorb the pair of children appended when the count is just below MAX_NODES.
	// The array doubles as the breadth-first work queue.
	private final ClassificationNode[] nodes = new ClassificationNode[MAX_NODES + 2];
	private ClassificationNode root;
	private final Random rnd;

	/**
	 * Grows the tree.
	 *
	 * @param features   column-major feature matrix, {@code features[feature][row]}
	 * @param classif    class label per row. It is used directly as an index into count arrays,
	 *                   so it is assumed to lie in {@code [0, numClasses)}.
	 * @param numClasses number of classes
	 * @param idx        tree index; it seeds the random generator
	 * @param totRows    number of leading rows to sample from (must be > 0)
	 */
	ClassificationTree(float[][] features, byte[] classif, int numClasses, int idx, int totRows) {
		rnd = new Random(197209091220L + idx);
		int numFeatures = features.length;
		// Number of features tried per node: about 2 * sqrt(numFeatures).
		int featuresSteps = (int) (2 * Math.sqrt(numFeatures));
		// Bootstrap sample: draw totRows rows with replacement. weight[i] counts how often row i was drawn.
		int[] weight = new int[totRows];
		for (int i = 0; i < totRows; i++) {
			weight[rnd.nextInt(totRows)]++;
		}
		int numSel = 0;
		for (int i = 0; i < totRows; i++) {
			if (weight[i] > 0) numSel++;
		}
		// selRows lists the drawn rows. Each node owns the slice selRows[startRow..endRow].
		int[] selRows = new int[numSel];
		numSel = 0;
		for (int i = 0; i < totRows; i++) {
			if (weight[i] > 0) selRows[numSel++] = i;
		}
		int[] classifCount = new int[numClasses];
		for (int row : selRows) {
			classifCount[classif[row]] += weight[row];
		}

		// The root covers every selected row. Its weighted total is totRows because the weights sum to totRows.
		root = new ClassificationNode(classifCount, totRows, impurity(classifCount), 1, 0, numSel - 1);
		int nodeCnt = 0;
		nodes[nodeCnt++] = root;
		float[] prevSplitVal = new float[SPLIT_STEPS];
		int[] leftCount = new int[numClasses];
		int[] rightCount = new int[numClasses];
		// Breadth-first: process nodes in insertion order; children are appended to the end.
		for (int i = 0; i < nodeCnt && nodeCnt < MAX_NODES; i++) {
			ClassificationNode node = nodes[i];
			// Leaves: pure nodes, nodes at MAX_LEVEL, and nodes too small to split.
			if (node.isPure() || node.level >= MAX_LEVEL || node.total <= MIN_ROWS_PER_NODE) continue;
			double maxSplitGain = 0;
			float bestSplitVal = 0;
			int bestSplitFeature = -1;
			for (int j = 0; j < featuresSteps; j++) {
				int splitFeature = rnd.nextInt(numFeatures);
				float[] featuresSplitFeature = features[splitFeature];
				// Candidate thresholds are this feature's values at randomly chosen rows of the node.
				NEXT: for (int k = 0; k < SPLIT_STEPS; k++) {
					float splitVal = prevSplitVal[k] = featuresSplitFeature[randomNodeRow(selRows, node, rnd)];
					// Skip a threshold already evaluated for this feature.
					for (int l = 0; l < k; l++) {
						if (splitVal == prevSplitVal[l]) continue NEXT;
					}
					// Weighted class counts on each side. Rows with value < splitVal go left.
					Arrays.fill(leftCount, 0);
					Arrays.fill(rightCount, 0);
					for (int r = node.startRow; r <= node.endRow; r++) {
						int row = selRows[r];
						if (featuresSplitFeature[row] < splitVal) leftCount[classif[row]] += weight[row];
						else rightCount[classif[row]] += weight[row];
					}
					// Both children must keep at least MIN_ROWS_PER_NODE weighted rows.
					if (sum(leftCount) < MIN_ROWS_PER_NODE || sum(rightCount) < MIN_ROWS_PER_NODE) continue;
					double splitGain = node.impurity - impurity(leftCount, rightCount);
					if (splitGain > maxSplitGain) {
						maxSplitGain = splitGain;
						bestSplitFeature = splitFeature;
						bestSplitVal = splitVal;
					}
				}
			}
			if (bestSplitFeature >= 0) {
				int[] leftNode = new int[numClasses];
				int[] rightNode = new int[numClasses];

				int endLeft = node.endRow;
				float[] featuresSplitFeature = features[bestSplitFeature];
				// Partition the node's slice in place. A row going right is swapped to the shrinking
				// end of the slice, and r is decremented so the swapped-in row is examined too.
				for (int r = node.startRow; r <= endLeft; r++) {
					int row = selRows[r];
					int w = weight[row];
					if (featuresSplitFeature[row] < bestSplitVal) {
						leftNode[classif[row]] += w;
					} else {
						rightNode[classif[row]] += w;
						selRows[r--] = selRows[endLeft];
						selRows[endLeft--] = row;
					}
				}
				node.left = new ClassificationNode(leftNode, sum(leftNode), impurity(leftNode), node.level + 1, node.startRow, endLeft);
				node.right = new ClassificationNode(rightNode, sum(rightNode), impurity(rightNode), node.level + 1, endLeft + 1, node.endRow);
				nodes[nodeCnt++] = node.left;
				nodes[nodeCnt++] = node.right;
				node.splitVal = bestSplitVal;
				node.splitFeature = bestSplitFeature;
			}
		}
	}

	/** Returns the root node of the grown tree. */
	public ClassificationNode getRoot() {
		return root;
	}

	/**
	 * Entropy (natural log) of the class distribution {@code cnt}: -sum(p * ln p).
	 * Returns 0 when the total count is 0 or 1.
	 */
	private double impurity(int[] cnt) {
		int tot = sum(cnt);
		if (tot <= 1) return 0;
		double val = 0;
		int v = 0;
		double lt = Math.log(tot);
		for (int i = 0; i < cnt.length; i++) {
			if ((v = cnt[i]) > 0) val += v * (lt - Math.log(v));
		}
		return val / tot;
	}

	/**
	 * Count-weighted average of the entropies of the two sides of a split.
	 * The result is NaN if both sides are empty.
	 */
	private double impurity(int[] cnt1, int[] cnt2) {
		int tot1 = sum(cnt1);
		int tot2 = sum(cnt2);
		return (impurity(cnt1) * tot1 + impurity(cnt2) * tot2) / (tot1 + tot2);
	}

	/** Returns the sum of all entries of {@code cnt}. */
	private static int sum(int[] cnt) {
		int tot = 0;
		for (int v : cnt) {
			tot += v;
		}
		return tot;
	}

	/** Returns a row index drawn from the node's slice {@code rows[startRow..endRow]}. */
	private final int randomNodeRow(int[] rows, ClassificationNode node, Random rnd) {
		return rows[rnd.nextInt(node.endRow - node.startRow + 1) + node.startRow];
	}
}

class Random {
	private static final long mask0 = 0x80000000L;
	private static final long mask1 = 0x7fffffffL;
	private static final long[] mult = new long[] {0,0x9908b0dfL};
	private final long[] mt = new long[624];
	private int idx = 0;

	Random(long seed) {
		init(seed);
	}

	private void init(long seed) {
		mt[0] = seed & 0xffffffffl;
		for (int i = 1; i < 624; i++) {
			mt[i] = 1812433253l * (mt[i - 1] ^ (mt[i - 1] >>> 30)) + i;
			mt[i] &= 0xffffffffl;
		}
	}

	private void generate() {
		for (int i = 0; i < 227; i++) {
			long y = (mt[i] & mask0) | (mt[i + 1] & mask1);
			mt[i] = mt[i + 397] ^ (y >> 1) ^ mult[(int) (y & 1)];
		}
		for (int i = 227; i < 623; i++) {
			long y = (mt[i] & mask0) | (mt[i + 1] & mask1);
			mt[i] = mt[i - 227] ^ (y >> 1) ^ mult[(int) (y & 1)];
		}
		long y = (mt[623] & mask0) | (mt[0] & mask1);
		mt[623] = mt[396] ^ (y >> 1) ^ mult[(int) (y & 1)];
	}

	private long rand() {
		if (idx == 0) generate();
		long y = mt[idx];
		idx = (idx + 1) % 624;
		y ^= (y >> 11);
		y ^= (y << 7) & 0x9d2c5680l;
		y ^= (y << 15) & 0xefc60000l;
		return y ^ (y >> 18);
	}

	int nextInt(int n) {
		return (int) (rand() % n);
	}
}
