TCO19 Round 3A Editorial

FamilySeatingArrangement

Let’s split the seating into two phases: first we seat the parents, then we seat the children. The parents can be seated without any restrictions. Let’s fix the number of tables that are left empty after the parents are seated. If there are x empty tables, each child has a choice of (x+1) tables to sit at: its own parent’s table or any of the x empty ones. This count is the same regardless of which family the child is from.

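Thus, if f(n, x) is the number of ways to seat n parents so that exactly x of the k tables stay empty, the answer is the sum of f(n, x) * (x+1)^(sum a) over all feasible x. Writing t = k - x for the number of occupied tables, t ranges from 1 to min(n, k).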

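f(n, x) can be computed in quadratic time using a dp solution, or in closed form from binomial coefficients and Stirling numbers of the second kind: with t = k - x occupied tables, f(n, x) = C(k, t) * t! * S(n, t). That is, choose which t tables are occupied, then distribute the n parents over them so that none of the t tables is empty. The code below uses the closed form.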
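The quadratic dp mentioned above is not spelled out in the reference solution, so here is a minimal sketch of one way it could look. It assumes tables are distinguishable and that only the parent-to-table assignment matters, which is what the closed form implies. The class and method names (`ParentSeatingDp`, `occupiedWays`, `countWaysDp`) are made up for illustration.

```java
class ParentSeatingDp {
    static final long MOD = 1_000_000_007L;

    // Returns g where g[t] = number of ways to seat n parents at k labeled
    // tables so that exactly t tables are occupied (so f(n, x) = g[k - x]).
    static long[] occupiedWays(int n, int k) {
        long[] g = new long[k + 1];
        g[0] = 1; // no parents seated, no table occupied
        for (int i = 0; i < n; i++) {
            long[] next = new long[k + 1];
            for (int t = 0; t <= k; t++) {
                if (g[t] == 0) continue;
                // the next parent joins one of the t occupied tables ...
                next[t] = (next[t] + g[t] * t) % MOD;
                // ... or opens one of the k - t empty tables
                if (t < k) next[t + 1] = (next[t + 1] + g[t] * (k - t)) % MOD;
            }
            g = next;
        }
        return g;
    }

    // Fast modular exponentiation.
    static long power(long b, long e) {
        long r = 1;
        b %= MOD;
        while (e > 0) {
            if ((e & 1) == 1) r = r * b % MOD;
            b = b * b % MOD;
            e >>= 1;
        }
        return r;
    }

    // Sum over the number t of occupied tables; each child then has
    // (k - t) + 1 choices.
    static long countWaysDp(int[] a, int k) {
        long s = 0;
        for (int x : a) s += x;
        long[] g = occupiedWays(a.length, k);
        long total = 0;
        for (int t = 1; t <= k; t++) {
            total = (total + g[t] * power(k - t + 1, s)) % MOD;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(countWaysDp(new int[]{1, 1}, 2));
    }
}
```

The table runs in O(n * k) time and should agree with C(k, t) * t! * S(n, t) for every t, which makes it a convenient cross-check for the closed-form solution.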

public class FamilySeatingArrangement {
    public int mod = 1000000007;
    public int exp(long b, long e) {
        long r = 1;
        while (e > 0) {
            if (e % 2 == 1) r = r * b % mod;
            b = b * b % mod;
            e >>= 1;
        }
        return (int)r;
    }
    public int countWays(int[] a, int k) {
        long nways = 0;
        int s = 0;
        for (int x : a) s += x;
        int n = a.length;
        long[] fact = new long[n+10];
        fact[0] = 1;
        for (int i = 1; i < fact.length; i++) fact[i] = fact[i-1] * i % mod;
        long[][] stir = new long[n+10][n+10];
        for (int i = 1; i < stir.length; i++) {
            stir[i][1] = 1;
            for (int j = 2; j < i; j++) {
                stir[i][j] = (stir[i-1][j-1] + j * stir[i-1][j]) % mod;
            }
            stir[i][i] = 1;
        }
        long[][] comb = new long[k+10][k+10];
        comb[0][0] = 1;
        for (int i = 1; i < comb.length; i++) {
            comb[i][0] = 1;
            for (int j = 1; j < i; j++) {
                comb[i][j] = (comb[i-1][j-1] + comb[i-1][j]) % mod;
            }
            comb[i][i] = 1;
        }

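        // Sum over the number of occupied tables; the remaining tables are empty.
        for (int takentables = 1; takentables <= a.length && takentables <= k; takentables++) {
            int empty = k - takentables;
            long cur = exp(empty+1, s); // each of the s children has (empty+1) choices
            // ways to seat the parents on exactly takentables tables: S(n,t) * C(k,t) * t!
            nways = (nways + cur * stir[n][takentables] % mod * comb[k][takentables] % mod * fact[takentables]) % mod;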
        }
        return (int)nways;
    }
}

WaitingForBusAgain

First, find the bus with the smallest interval and let mn be its index. Some bus is guaranteed to arrive within f[mn] minutes, so only that window matters. Bus i arrives within the first f[mn] minutes with probability f[mn]/f[i] (in particular, bus mn always does). If k buses arrive within that window, each of them is equally likely to be the first. Thus, the problem reduces to the following: item i has value i and appears independently with probability p_i = f[mn]/f[i]. What is the expected value of the average of the items that appear?

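This can be solved with dp. Let dp[i][j] describe the case where exactly j of the first i items appear: it stores the expected value of their average multiplied by the probability of that case, so the answer is the sum of dp[n][j] over all j. The code keeps only the current row, together with the probability of each j. You can see the code for more details on how to maintain this dp.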
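To make the reduced problem concrete, here is a small sketch that computes the expected average in two ways: with a probability-weighted dp and by brute force over all subsets. The names (`ExpectedAverageSketch`, `expectedAverageDp`, `expectedAverageBrute`) are made up for illustration. The sketch assumes items appear independently, and it treats the empty set as contributing nothing, which is safe in the bus setting because one item always appears.

```java
class ExpectedAverageSketch {
    // w[j] = E[average of appearing items * indicator(exactly j items appear)]
    // c[j] = P(exactly j items appear)
    static double expectedAverageDp(double[] value, double[] p) {
        int m = value.length;
        double[] w = new double[m + 1];
        double[] c = new double[m + 1];
        c[0] = 1.0;
        for (int i = 0; i < m; i++) {
            // go downwards so that w[j-1] and c[j-1] still hold the old values
            for (int j = m; j >= 1; j--) {
                // w[j-1] * (j-1) is the weighted sum of values; add the new
                // value for the probability mass c[j-1], then re-average over j
                w[j] = (1 - p[i]) * w[j]
                     + p[i] * (w[j - 1] * (j - 1) + c[j - 1] * value[i]) / j;
                c[j] = (1 - p[i]) * c[j] + p[i] * c[j - 1];
            }
            c[0] *= (1 - p[i]);
        }
        double total = 0;
        for (int j = 1; j <= m; j++) total += w[j];
        return total;
    }

    // Exponential-time reference: enumerate every non-empty subset.
    static double expectedAverageBrute(double[] value, double[] p) {
        int m = value.length;
        double total = 0;
        for (int mask = 1; mask < (1 << m); mask++) {
            double prob = 1, sum = 0;
            int cnt = 0;
            for (int i = 0; i < m; i++) {
                if (((mask >> i) & 1) != 0) {
                    prob *= p[i];
                    sum += value[i];
                    cnt++;
                } else {
                    prob *= 1 - p[i];
                }
            }
            total += prob * sum / cnt;
        }
        return total;
    }

    public static void main(String[] args) {
        double[] value = {0, 1};
        double[] p = {1.0, 0.5};
        System.out.println(expectedAverageDp(value, p));
        System.out.println(expectedAverageBrute(value, p));
    }
}
```

The key point is that the dp stores averages already multiplied by their probability, so the final answer is a plain sum over j rather than a weighted one.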

public class WaitingForBusAgain {
    public double expectedBus(int[] f) {
        int n = f.length;
        int mn = 0;
        for (int i = 1; i < n; i++) {
            if (f[i] < f[mn]) {
                mn = i;
            }
        }

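        // dp[j] = expected average index over outcomes with exactly j arriving buses,
        //         already multiplied by the probability of having exactly j such buses
        double[] dp = new double[n+1];
        dp[1] = mn; // bus mn always arrives within f[mn] minutes
        double[] prob = new double[n+1]; // prob[j] = probability that exactly j buses arrive in time
        prob[1] = 1;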

        for (int i = 0; i < n; i++) {
            if (i == mn) continue;
            double phit = f[mn] * 1.0 / f[i];

            double[] ndp = new double[n+1];
            double[] nprob = new double[n+1];
            for (int j = 1; j <= n; j++) {
                ndp[j] += (1 - phit) * dp[j];
                nprob[j] += (1 - phit) * prob[j];
                if(j+1 <= n) {
                    ndp[j+1] += phit * (dp[j] * j + prob[j] * i) / (j+1);
                    nprob[j+1] += phit * prob[j];
                }
            }
            dp = ndp;
            prob = nprob;
        }

        double res = 0;
        for (int i = 0; i <= n; i++) res += dp[i];
        return res;
    }
}

TwoLineRegions

First, fix two lines. They split the plane into four regions. Each of the other lines can then be put into one of four classes, based on which region it doesn’t touch. For each region, we can count the number of solutions in which it appears, which is 2^(number of lines in the corresponding class).

This gives an n^3 solution. To make it faster, you can use bitsets. Alternatively, there is an n^2 log n solution: first sort the lines by angle, then fix one line and sweep over the other lines in order of their intersection point with the fixed line (by x-coordinate). When processing the i-th of these lines, a binary indexed tree tells you how many of the already processed lines have higher or lower slope, which gives the number of lines in each class.
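The counting step of the sweep can be sketched in isolation. The sketch below assumes, based on how the reference code derives its four counters, that the classes correspond to (earlier or later in the sweep) combined with (lower or higher position in the angular order). The class name `SweepClassCounts` and the input array `order` (line ids listed in sweep order, forming a permutation of 0..m-1) are made up for illustration.

```java
class SweepClassCounts {
    // For each sweep step k with line id = order[k], returns four counts:
    //   [0] earlier in sweep, lower id    [1] earlier in sweep, higher id
    //   [2] later in sweep, lower id      [3] later in sweep, higher id
    static int[][] classCounts(int[] order) {
        int m = order.length;
        int[] tree = new int[m + 1]; // Fenwick tree, 1-indexed; id i lives at i + 1
        int[][] counts = new int[m][];
        for (int k = 0; k < m; k++) {
            int id = order[k];
            // prefix sum over ids 0..id-1: already processed lines with lower id
            int a1 = 0;
            for (int i = id; i > 0; i -= i & -i) a1 += tree[i];
            int a2 = k - a1;                    // processed so far, higher id
            int a3 = id - a1;                   // lower id, not yet processed
            int a4 = (m - 1) - a1 - a2 - a3;    // everything else
            counts[k] = new int[]{a1, a2, a3, a4};
            // mark this id as processed
            for (int i = id + 1; i <= m; i += i & -i) tree[i]++;
        }
        return counts;
    }

    // Quadratic reference used only to cross-check the Fenwick version.
    static int[][] classCountsBrute(int[] order) {
        int m = order.length;
        int[][] counts = new int[m][4];
        for (int k = 0; k < m; k++) {
            for (int j = 0; j < m; j++) {
                if (j == k) continue;
                int earlier = j < k ? 0 : 2;
                int higher = order[j] > order[k] ? 1 : 0;
                counts[k][earlier + higher]++;
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] order = {1, 0, 2};
        System.out.println(java.util.Arrays.deepToString(classCounts(order)));
    }
}
```

Each step costs O(log m), so one sweep is O(m log m) after sorting, which matches the n^2 log n total when every line is fixed in turn.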

import java.util.Arrays;
import java.util.Comparator;

public class TwoLineRegions {
    int mod = 1000000007;

    long gcd(long a, long b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    class Fraction implements Comparable<Fraction>{
        public long x, y;
        public Fraction (long _x, long _y) {
            if (_x < 0) {
                _x = -_x;
                _y = -_y;
            }
            long g = gcd(_x, Math.abs(_y));
            if (g != 0) {
                _x /= g;
                _y /= g;
            }
            this.x = _x;
            this.y = _y;
        }

        @Override
        public int hashCode() {
            return Long.hashCode(x*34029341L + y);
        }

        @Override
        public boolean equals(Object other) {
            if (!(other instanceof Fraction)) return false;
            return ((Fraction)other).x == x && ((Fraction)other).y == y;
        }

        @Override
        public int compareTo(Fraction other) {
            return Long.compare(this.x*other.y,this.y*other.x);
        }

    }

    class Line {
        public int a,b,c;

        public Line(int a, int b, int c) {
            this.a = a;
            this.b = b;
            this.c = c;
        }
    }

    class Intersection {
        public int id;
        public Fraction y;

        public Intersection(int id, Fraction y) {
            this.id = id;
            this.y = y;
        }
    }

    class BIT {
        public BIT(int N) {
            this.N = N;
            this.tree = new int[N+3];
        }
        private int[] tree;
        private int N;

        public int query(int K) {
            K += 2;
            int sum = 0;
            for (int i = K; i > 0; i -= (i & -i))
                sum += tree[i];
            return sum;
        }

        public void update(int K, int val) {
            K += 2;
            for (int i = K; i < tree.length; i += (i & -i))
                tree[i] += val;
        }
    }

    public int count(int[] a, int[] b, int[] c) {
        int n = a.length;
        Line[] ls = new Line[n];
        for (int i = 0; i < n; i++) {
            ls[i] = new Line(a[i], b[i], c[i]);
            if (ls[i].b < 0) {
                ls[i].a = -ls[i].a;
                ls[i].b = -ls[i].b;
                ls[i].c = -ls[i].c;
            }
        }
        // sort by slope; cast to long so the cross products cannot overflow int
        Arrays.sort(ls, (x,y) -> Long.compare((long)x.a*y.b,(long)y.a*x.b));
        int[] pow2 = new int[n+10];
        pow2[0] = 1;
        for (int i = 1; i < pow2.length; i++) {
            pow2[i] = pow2[i-1]*2 % mod;
        }
        long ans = 0;
        for (int i = 0; i < n; i++) {
            Intersection[] is = new Intersection[n-1];
            for (int k = 1; k < n; k++) {
                int j = (k+i)%n;
                // cast to long before multiplying so the products cannot overflow int
                long x1 = (long)ls[i].a * ls[j].b, y1 = (long)ls[i].c * ls[j].b;
                long x2 = (long)ls[j].a * ls[i].b, y2 = (long)ls[j].c * ls[i].b;
                Fraction g = new Fraction(x2-x1, y2-y1);
                is[k-1] = new Intersection(k-1, g);
            }
            Arrays.sort(is, Comparator.comparing(x -> x.y));
            BIT bit = new BIT(n);
            for (int k = 0; k < n-1; k++) {
                int a1 = bit.query(is[k].id);
                int a2 = k-a1;
                int a3 = is[k].id-a1;
                int a4 = n-2-a1-a2-a3;
                ans = (ans + pow2[a1] + pow2[a2] + pow2[a3] + pow2[a4]) % mod;
                bit.update(is[k].id, +1);
            }
        }
        return (int)(ans*((mod+1)/2) % mod);
    }
}