
/*
Change log
----------
2015-07-17 : Testing release
2015-07-20 : Scoring function update
2015-07-26 : BUG FIX: Remove offset when loading other quakes.
2015-07-26 : Add (time in seconds from start) to algorithm for other quakes.
2015-07-28 : Make sure UTC timezone is used for SimpleDateFormat.
2015-07-31 : Fix hour issue: new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
2015-07-31 : Prevent tester from crashing when data not available
*/

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Locale;
import java.util.TimeZone;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

public class QuakeTester {
    // When true, printMessage() echoes diagnostic text to stdout.
    public static boolean debug = false;
    // When true, doExec() hands every loaded hour to the QuakeTrainer instead of scoring forecasts.
    public static boolean training = false;
    // Set by main() from the second command-line argument. doExec() appends "<seed>/" directly,
    // so the value is expected to end with a path separator already.
    public static String dataFolder;
    // Nanoseconds spent inside QuakePredictor.init()/forecast() during the latest doExec() call.
    private long time = 0;
    // Created lazily by doExec() in training mode; train() calls done() on it after the last seed.
    private QuakeTrainer trainer;

    /**
     * Writes a diagnostic line to standard output, but only while {@code debug} is enabled.
     *
     * @param s text to print
     */
    public void printMessage(String s) {
        if (!debug) {
            return;
        }
        System.out.println(s);
    }

    /** Plain holder for a latitude/longitude pair. */
    public class Coordinate {
        double latitude;
        double longitude;

        /**
         * @param lat value stored in {@code latitude}
         * @param lon value stored in {@code longitude}
         */
        public Coordinate(double lat, double lon) {
            this.latitude = lat;
            this.longitude = lon;
        }
    }

    /** Plain holder for one earthquake record: time, location, depth and magnitude. */
    public class Quake {
        int timeSecs;
        double depth;
        double magnitude;
        Coordinate coord;

        /**
         * @param tim value stored in {@code timeSecs}
         * @param loc value stored in {@code coord}
         * @param dep value stored in {@code depth}
         * @param mag value stored in {@code magnitude}
         */
        public Quake(int tim, Coordinate loc, double dep, double mag) {
            this.timeSecs = tim;
            this.coord = loc;
            this.depth = dep;
            this.magnitude = mag;
        }
    }

    public class DataSet {
        // Site coordinates, in the order the "Site" elements appear in the site XML file.
        Coordinate[] sites;
        // One entry per "Quake" element of the quakes XML file.
        Quake[] quakes;
        // "sample_rate" attribute of the first "Site" element; multiplied by 3600 when sizing hourly buffers.
        int sampleRate;
        // Lengths of sites, EMA and quakes respectively.
        int numOfSites;
        int numOfEMA;
        int numOfQuakes;

        // Buffers allocated by loadSiteInfo(); nothing else in this class reads or writes them.
        int[] rawData = null;
        byte[] result = null;
        // One value per "kp_hr" element; doExec() indexes it by hour.
        double[] EMA = null;

        // Ground-truth values filled in by readGTF() from the gtf.csv row matching the seed:
        // columns 1 and 2 (start time, earthquake time) ...
        String gtfStartTime, gtfEQtime;
        // ... columns 3, 4, 5 and 7, plus the whole seconds between start time and earthquake time.
        double gtfMagnitude, gtfLatitude, gtfLongitude, gtfDistToEQ, gtfEQSec;
        // gtfEQHour is gtfEQSec / 3600 truncated to int; gtfSite is column 6.
        int gtfEQHour, gtfSite;

        // XML parser objects created in the constructor and reused by the three load*() methods.
        DocumentBuilderFactory docBuilderFactory = null;
        DocumentBuilder docBuilder = null;

        /**
         * Parses the site description XML and fills {@code numOfSites}, {@code sampleRate}
         * and {@code sites}, then sizes the hourly buffers {@code rawData} and {@code result}.
         * The sample rate is taken from the first "Site" element only.
         *
         * @param sXmlFile path of the XML file containing the "Site" elements
         * @throws Exception if the file cannot be parsed or an attribute is not numeric
         */
        public void loadSiteInfo(String sXmlFile) throws Exception {
            Document xml = docBuilder.parse(new File(sXmlFile));
            xml.getDocumentElement().normalize();

            NodeList siteNodes = xml.getElementsByTagName("Site");
            numOfSites = siteNodes.getLength();
            printMessage("Total number of sites: " + numOfSites);

            sites = new Coordinate[numOfSites];

            // Only the first site carries the sample rate that is used for every site.
            if (numOfSites > 0) {
                Element first = (Element) siteNodes.item(0);
                sampleRate = Integer.parseInt(first.getAttribute("sample_rate"));
                printMessage("Sample Rate = " + sampleRate);
            }

            for (int idx = 0; idx < numOfSites; idx++) {
                Element node = (Element) siteNodes.item(idx);
                double lat = Double.parseDouble(node.getAttribute("latitude"));
                double lon = Double.parseDouble(node.getAttribute("longitude"));
                sites[idx] = new Coordinate(lat, lon);
                printMessage("site " + idx + ": " + lat + "," + lon);
            }

            // Buffers sized for one hour of samples.
            int samplesPerHour = 3600 * sampleRate;
            rawData = new int[numOfSites * 3 * samplesPerHour];
            result = new byte[samplesPerHour * 4];
        }

        /**
         * Parses the XML file and stores the numeric text of every "kp_hr" element,
         * in document order, into {@code EMA}; {@code numOfEMA} receives the element count.
         *
         * @param sXmlFile path of the XML file containing the "kp_hr" elements
         * @throws Exception if the file cannot be parsed or a value is not numeric
         */
        public void loadEarthMagneticActivity(String sXmlFile) throws Exception {
            Document xml = docBuilder.parse(new File(sXmlFile));
            xml.getDocumentElement().normalize();

            NodeList kpNodes = xml.getElementsByTagName("kp_hr");
            numOfEMA = kpNodes.getLength();
            printMessage("Total number of EM activities: " + numOfEMA);

            EMA = new double[numOfEMA];
            int idx = 0;
            while (idx < numOfEMA) {
                // The value is the text held by the element's first child node.
                String text = kpNodes.item(idx).getFirstChild().getNodeValue();
                EMA[idx] = Double.parseDouble(text);
                idx++;
            }
        }

        /**
         * Parses the XML file and builds one {@link Quake} per "Quake" element from its
         * "secs", "latitude", "longitude", "depth" and "magnitude" attributes.
         * Results go to {@code quakes}; {@code numOfQuakes} receives the element count.
         *
         * @param sXmlFile path of the XML file containing the "Quake" elements
         * @throws Exception if the file cannot be parsed or an attribute is not numeric
         */
        public void loadOtherQuakes(String sXmlFile) throws Exception {
            Document xml = docBuilder.parse(new File(sXmlFile));
            xml.getDocumentElement().normalize();

            NodeList quakeNodes = xml.getElementsByTagName("Quake");
            numOfQuakes = quakeNodes.getLength();
            printMessage("Total number of other quakes: " + numOfQuakes);

            quakes = new Quake[numOfQuakes];
            for (int idx = 0; idx < numOfQuakes; idx++) {
                Element node = (Element) quakeNodes.item(idx);
                // Arguments are evaluated left to right: secs, latitude, longitude, depth, magnitude.
                quakes[idx] = new Quake(
                        Integer.parseInt(node.getAttribute("secs")),
                        new Coordinate(
                                Double.parseDouble(node.getAttribute("latitude")),
                                Double.parseDouble(node.getAttribute("longitude"))),
                        Double.parseDouble(node.getAttribute("depth")),
                        Double.parseDouble(node.getAttribute("magnitude")));
            }
        }

        /**
         * Collects the loaded quakes whose {@code timeSecs} lies in
         * {@code [hour * 3600, (hour + 1) * 3600)}.
         *
         * @param hour zero-based hour index
         * @return flat array with five values per matching quake, in load order:
         *         latitude, longitude, depth, magnitude, timeSecs; empty if none match
         */
        public double[] getOtherQuakes(int hour) {
            final int windowStart = hour * 3600;
            final int windowEnd = (hour + 1) * 3600;

            // First pass: count matches so the output can be sized exactly.
            int matches = 0;
            for (int k = 0; k < numOfQuakes; k++) {
                int when = quakes[k].timeSecs;
                if (when < windowStart || when >= windowEnd) {
                    continue;
                }
                matches++;
            }

            // Second pass: pack the matching quakes.
            double[] packed = new double[matches * 5];
            int out = 0;
            for (int k = 0; k < numOfQuakes; k++) {
                Quake quake = quakes[k];
                if (quake.timeSecs < windowStart || quake.timeSecs >= windowEnd) {
                    continue;
                }
                packed[out++] = quake.coord.latitude;
                packed[out++] = quake.coord.longitude;
                packed[out++] = quake.depth;
                packed[out++] = quake.magnitude;
                packed[out++] = quake.timeSecs;
            }
            return packed;
        }

        public DataSet(String sFolder) throws Exception {
            docBuilderFactory = DocumentBuilderFactory.newInstance();
            docBuilder = docBuilderFactory.newDocumentBuilder();
            loadSiteInfo(sFolder + "SiteInfo.xml");
            loadEarthMagneticActivity(sFolder + "Kp.xml");
            loadOtherQuakes(sFolder + "Quakes.xml");
        }

        /**
         * Reads the ground-truth row for {@code seed} from "gtf.csv" in the current working
         * directory and fills the {@code gtf*} fields.
         *
         * <p>The first line of the file holds the number of data rows. Each data row is
         * comma separated: set id, start time, earthquake time, magnitude, latitude,
         * longitude, site, distance to the earthquake. Both times use the pattern
         * {@code yyyy-MM-dd HH:mm:ss} and are interpreted as UTC.
         *
         * <p>If no row matches, or the file holds fewer rows than its header announces,
         * the {@code gtf*} fields are left untouched.
         *
         * @param seed set id to look for in the first column
         * @throws Exception if the file cannot be read or a field cannot be parsed
         */
        public void readGTF(long seed) throws Exception {
            BufferedReader br = new BufferedReader(new FileReader("gtf.csv"));
            try {
                int numOfCases = Integer.parseInt(br.readLine());
                for (int i = 0; i < numOfCases; i++) {
                    String s = br.readLine();
                    if (s == null) {
                        // Fewer rows than the header claims: stop instead of failing on null.
                        break;
                    }
                    String[] token = s.split(",");
                    int setID = Integer.parseInt(token[0]);
                    if (setID != seed) {
                        continue;
                    }

                    gtfStartTime = token[1];
                    gtfEQtime = token[2];
                    gtfMagnitude = Double.parseDouble(token[3]);
                    gtfLatitude = Double.parseDouble(token[4]);
                    gtfLongitude = Double.parseDouble(token[5]);
                    gtfSite = Integer.parseInt(token[6]);
                    gtfDistToEQ = Double.parseDouble(token[7]);

                    // Calculate number of hours till EQ
                    SimpleDateFormat ft = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
                    ft.setTimeZone(TimeZone.getTimeZone("UTC"));
                    long startMSec = ft.parse(gtfStartTime).getTime();
                    long eqMSec = ft.parse(gtfEQtime).getTime();
                    // Integer division on purpose: whole seconds between start and quake.
                    gtfEQSec = (eqMSec - startMSec) / 1000;
                    gtfEQHour = (int) (gtfEQSec / (60 * 60));
                    printMessage("Quake happened at hour = " + gtfEQHour + " and second = " + gtfEQSec + " at row " + (gtfEQSec - 3600.0 * gtfEQHour) * sampleRate);
                    break;
                }
            } finally {
                // Close the reader even when a row fails to parse.
                br.close();
            }
        }

        /**
         * Stream-based decoder for one hourly data file, {@code sFolder + "test" + seed + "_" + h + ".bin"}.
         *
         * <p>The file is a sequence of variable-length little-endian records. The two low
         * bits of a record's first byte select its form; the remaining bits are the payload:
         * <ul>
         *   <li>1 - one byte, delta from the previous value, biased by 2^5</li>
         *   <li>2 - two bytes, delta biased by 2^13</li>
         *   <li>3 - three bytes, delta biased by 2^21</li>
         *   <li>0 - four bytes, absolute value stored as value + 1</li>
         * </ul>
         * The "previous value" starts at -1.
         *
         * @param sFolder folder prefix of the data file
         * @param h       hour index used in the file name
         * @param seed    data set id used in the file name
         * @return {@code numOfSites * 3 * 3600 * sampleRate} decoded values
         * @throws Exception if the file is missing or ends before all values are decoded
         */
        public int[] loadHourOld(String sFolder, int h, long seed) throws Exception {
            String fname = sFolder + "test" + seed + "_" + h + ".bin";
            printMessage("loading " + fname);
            int[] dt = new int[numOfSites * 3 * 3600 * sampleRate];

            DataInputStream din = new DataInputStream(
                    new BufferedInputStream(new FileInputStream(new File(fname))));
            try {
                int prev = -1;
                for (int i = 0; i < dt.length; i++) {
                    int b1 = din.readUnsignedByte();
                    int tag = b1 & 3;
                    if (tag == 1) {
                        dt[i] = prev + (b1 >> 2) - (1 << 5);
                    } else if (tag == 2) {
                        int b2 = din.readUnsignedByte();
                        dt[i] = prev + ((b1 + (b2 << 8)) >> 2) - (1 << 13);
                    } else if (tag == 3) {
                        int b2 = din.readUnsignedByte();
                        int b3 = din.readUnsignedByte();
                        dt[i] = prev + ((b1 + (b2 << 8) + (b3 << 16)) >> 2) - (1 << 21);
                    } else {
                        int b2 = din.readUnsignedByte();
                        int b3 = din.readUnsignedByte();
                        int b4 = din.readUnsignedByte();
                        // NOTE(review): arithmetic shift here, whereas loadHour() uses a logical
                        // shift for this case; they differ when b4 >= 128. Confirm which is intended.
                        dt[i] = ((b1 + (b2 << 8) + (b3 << 16) + (b4 << 24)) >> 2) - 1;
                    }
                    prev = dt[i];
                }
            } finally {
                // Release the file handle even when the file is truncated.
                din.close();
            }
            return dt;
        }

        // Biases subtracted from the payload of 1-, 2- and 3-byte delta records in loadHour().
        private static final int a5 = 1 << 5;
        private static final int a13 = 1 << 13;
        private static final int a21 = 1 << 21;

        /**
         * Decodes one hourly data file, {@code sFolder + "test" + seed + "_" + h + ".bin"},
         * through a read-only memory mapping.
         *
         * <p>The file is a sequence of variable-length little-endian records. The two low
         * bits of a record's first byte select its form; the remaining bits are the payload:
         * 1 = one-byte delta (bias 2^5), 2 = two-byte delta (bias 2^13),
         * 3 = three-byte delta (bias 2^21), 0 = four-byte absolute value stored as value + 1.
         * Deltas are relative to the previously decoded value, which starts at -1.
         *
         * @param sFolder folder prefix of the data file
         * @param h       hour index used in the file name
         * @param seed    data set id used in the file name
         * @return {@code numOfSites * 3 * 3600 * sampleRate} decoded values
         * @throws Exception if the file is missing, cannot be mapped, or ends before all
         *                   values are decoded
         */
        public int[] loadHour(String sFolder, int h, long seed) throws Exception {
            String fname = sFolder + "test" + seed + "_" + h + ".bin";
            printMessage("loading " + fname);
            int[] dt = new int[numOfSites * 3 * 3600 * sampleRate];

            FileInputStream fis = new FileInputStream(fname);
            try {
                FileChannel channel = fis.getChannel();
                MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());

                int prev = -1;
                for (int i = 0; i < dt.length; i++) {
                    int b1 = buffer.get() & 255;
                    int tag = b1 & 3;
                    if (tag == 1) {
                        prev += (b1 >> 2) - a5;
                    } else if (tag == 2) {
                        int word = b1 + ((buffer.get() & 255) << 8);
                        prev += (word >>> 2) - a13;
                    } else if (tag == 3) {
                        int word = b1 + ((buffer.get() & 255) << 8) + ((buffer.get() & 255) << 16);
                        prev += (word >> 2) - a21;
                    } else {
                        int word = b1 + ((buffer.get() & 255) << 8) + ((buffer.get() & 255) << 16)
                                + ((buffer.get() & 255) << 24);
                        // Logical shift: the top bit of the fourth byte is treated as data, not sign.
                        prev = (word >>> 2) - 1;
                    }
                    dt[i] = prev;
                }
            } finally {
                // Closing the stream also closes its channel; runs even if decoding fails
                // (for example on a truncated file), so the file handle is not leaked.
                fis.close();
            }
            return dt;
        }
    }

    // Data set loaded by the latest doExec() call; keeps its previous value (null on a
    // fresh tester) when the DataSet constructor throws.
    DataSet dset;
    // Predictor under test; doExec() replaces it on every call unless training is on.
    QuakePredictor qp;

    // Number of hourly slots per site that a forecast must contain.
    private static final int FORECAST_HOURS = 2160;
    // First hour whose forecast contributes to the score.
    private static final int FIRST_SCORED_HOUR = 768;

    /**
     * Runs one data set through the predictor (or the trainer when {@code training} is on).
     *
     * <p>Loads the data set and its ground truth, initialises the predictor, then walks
     * every hour before the ground-truth earthquake hour. Hours whose data file cannot be
     * loaded are skipped. In training mode each hour is handed to the trainer and nothing
     * is scored, so the result is 0.0. Otherwise the forecast for every hour from
     * {@code FIRST_SCORED_HOUR} on is scored and the hourly scores are summed.
     * Only time spent inside the predictor's init/forecast calls is added to {@code time}.
     *
     * @param seed data set id; data is read from {@code dataFolder + seed + "/"}
     * @return the accumulated score, or -1.0 if a forecast has the wrong length or any
     *         exception occurs
     * @throws Exception declared for callers; exceptions raised inside are caught and reported
     */
    public double doExec(long seed) throws Exception {
        try {
            final String seedFolder = dataFolder + seed + "/";

            // load ground truth data
            dset = new DataSet(seedFolder);
            dset.readGTF(seed);

            System.out.println(dset.gtfSite + " >> " + dset.gtfEQHour);

            // Flatten site coordinates to [lat0, lon0, lat1, lon1, ...].
            double[] sitesData = new double[dset.numOfSites * 2];
            for (int i = 0; i < dset.numOfSites; i++) {
                sitesData[i * 2] = dset.sites[i].latitude;
                sitesData[i * 2 + 1] = dset.sites[i].longitude;
            }

            if (qp == null || !training) qp = new QuakePredictor();

            // pass site info to algo
            time = 0;
            long t = System.nanoTime();
            qp.init(dset.sampleRate, dset.numOfSites, sitesData);
            time += System.nanoTime() - t;

            // pass hourly data to algorithm
            double score = 0.0;
            for (int h = 0; h < dset.gtfEQHour; h++) {
                int[] hourlyData;
                try {
                    hourlyData = dset.loadHour(seedFolder, h, seed);
                } catch (Exception e) {
                    // Missing or unreadable hours are skipped on purpose; report them in debug mode
                    // so that a real decoding problem does not go completely unnoticed.
                    printMessage("WARNING: Missing data for hour " + h + " (" + e + ")");
                    continue;
                }

                double[] otherQuakes = dset.getOtherQuakes(h);

                if (training) {
                    if (trainer == null) trainer = new QuakeTrainer();
                    trainer.train(dset.sampleRate, dset.numOfSites, dset.gtfSite, dset.gtfEQHour, h, hourlyData, dset.EMA[h], otherQuakes);
                    continue;
                }

                // read forecast matrix; cloned because scoring normalises it in place
                t = System.nanoTime();
                double[] ret = qp.forecast(h, hourlyData, dset.EMA[h], otherQuakes).clone();
                time += System.nanoTime() - t;

                if (ret.length != dset.numOfSites * FORECAST_HOURS) {
                    System.err.println("ERROR: The number of elements in your return is incorrect. You returned " + ret.length + " it should be " + (dset.numOfSites * (FORECAST_HOURS)));
                    return -1.0;
                }

                // score only from hour 768
                if (h >= FIRST_SCORED_HOUR) {
                    double hourScore = scoreHour(ret, h);
                    score += hourScore;
                    if (!training && h % 100 == 0) System.out.println("\t" + h + "\t" + hourScore + "\t" + score);
                }
            }
            return score;
        } catch (Exception e) {
            System.out.println("FAILURE: " + e.getMessage());
            e.printStackTrace();
        }
        return -1.0;
    }

    /**
     * Scores one hourly forecast against the ground truth held in {@code dset}.
     * Entries from index {@code h * numOfSites} onward are normalised in place to sum to 1
     * (skipped when the sum is within 1e-9 of zero), then the score is
     * {@code sizeofNN * (2 * G - ssv) - 1}, where G is the normalised value at the
     * ground-truth hour/site and ssv is the sum of squares of the normalised entries.
     */
    private double scoreHour(double[] forecast, int h) {
        final int first = h * dset.numOfSites;

        // normalize to sum of 1
        double sum = 0;
        for (int i = first; i < forecast.length; i++) {
            sum += forecast[i];
        }
        if (Math.abs(sum) > 1e-9) {
            for (int i = first; i < forecast.length; i++) {
                forecast[i] /= sum;
            }
        }

        int sizeofNN = dset.numOfSites * (FORECAST_HOURS - h);
        int gtfIndex = (dset.gtfEQHour) * dset.numOfSites + dset.gtfSite;
        double ssv = 0; // sum of squared values in NN
        for (int i = first; i < forecast.length; i++) {
            ssv += forecast[i] * forecast[i];
        }
        return (double) sizeofNN * (2.0 * forecast[gtfIndex] - ssv) - 1.0;
    }

    /**
     * Runs every training seed through {@link #doExec(long)} with {@code training} switched
     * on, then finalises the trainer. The flag is cleared again even if a run throws.
     *
     * @throws Exception propagated from doExec or the trainer
     */
    private static void train() throws Exception {
        training = true;
        try {
            QuakeTester tester = new QuakeTester();
            for (int seed : new int[] {6,9,11,23,67,78,105,113,115,116,129,142,143,149,151,156,166,171,180,183,199,202,205,215,233,237,286,299}) {
                System.out.println(seed + "\t" + new Date());
                tester.doExec(seed);
            }
            // The trainer is only created once an hour of data has loaded successfully;
            // with no data at all there is nothing to finalise.
            if (tester.trainer != null) {
                tester.trainer.done();
            } else {
                System.err.println("WARNING: no training data could be loaded; trainer was never created");
            }
        } finally {
            training = false;
        }
    }

    /**
     * Scores every test seed with a fresh tester and prints one tab-separated line per
     * seed: seed, number of sites, earthquake hour, predictor time in milliseconds, score.
     * Finishes with the total wall-clock time in milliseconds.
     *
     * @throws Exception propagated from doExec
     */
    private static void test() throws Exception {
        debug = false;
        long t = System.currentTimeMillis();
        for (int seed : new int[] {2,4,7,8,13,15,17,20,21,36,63,69,90,138,147,148,152,155,163,169,172,176,193,194,211,212,213,232,234,277,281,291}) {
            QuakeTester tester = new QuakeTester();
            double score = tester.doExec(seed);

            // doExec() reports load failures by returning -1.0 and may leave dset unset
            // (e.g. missing data folder); do not crash the whole run on that seed.
            String sites = "n/a";
            String eqHour = "n/a";
            if (tester.dset != null) {
                sites = String.valueOf(tester.dset.numOfSites);
                eqHour = String.valueOf(tester.dset.gtfEQHour);
            }
            System.out.println(seed + "\t" + sites + "\t" + eqHour + "\t" + tester.time / 1000000 + "\t" + score);
        }
        t = System.currentTimeMillis() - t;
        System.out.println("T = " + t);
    }

    /**
     * Command-line entry point. Expects exactly two arguments: a mode
     * ({@code train}, {@code convert} or {@code test}) and the data folder.
     * Prints a usage line to standard error for any other input.
     *
     * @param args command-line arguments
     * @throws Exception propagated from the selected mode
     */
    public static void main(String[] args) throws Exception {
        final String usage = "Usage: java QuakeTester [train|convert|test] data-folder";
        if (args.length != 2) {
            System.err.println(usage);
            return;
        }

        // The default locale and the data folder are set before the mode is inspected.
        Locale.setDefault(Locale.US);
        dataFolder = args[1];

        final String mode = args[0];
        if (mode.equals("train")) {
            train();
        } else if (mode.equals("convert")) {
            Converter.convert();
        } else if (mode.equals("test")) {
            test();
        } else {
            System.err.println(usage);
        }
    }
}