Commit 12c4e43c authored by Manuel Segimon
Browse files

Refactor perplexity calculation in TrieNode.java

parent 25931dac
Loading
Loading
Loading
Loading
+32 −3
Original line number Diff line number Diff line
@@ -37,6 +37,13 @@ public class TrieNode implements Serializable, Cloneable {
        return (float) current.count / past.childCounts;
    }

    /**
     * Returns the average count per child of this node.
     *
     * <p>The value is {@code childCounts} divided by the size of {@code children}.
     * When this node has no children the method returns {@code 1} rather than
     * performing a division by zero.
     *
     * @return {@code childCounts / children.size()} as a float, or {@code 1} when
     *         {@code children} is empty
     */
    private float getAverageChildCount() {
        final int childTotal = this.children.size();
        // The int literal 1 is promoted to float by the conditional expression.
        return childTotal == 0 ? 1 : (float) this.childCounts / childTotal;
    }

    public float perplexity(String phrase) {
        TrieNode current = this;
        TrieNode past = this;
@@ -49,10 +56,32 @@ public class TrieNode implements Serializable, Cloneable {
            past = current;
            current = current.children.get(word);
            if (current == null) {
                float alpha = 2;
                return alpha * perplexity(phrase.replaceFirst(words[0] + " ", ""));
                float alpha = (float) 100 / words.length;
                return alpha + perplexity(phrase.replaceFirst(words[0] + " ", ""), words.length - 1);
            }
            logProb += Math.log((float) current.count / past.getAverageChildCount());
        }
        float perplexity = (float) Math.pow(2, -logProb);
        //System.out.println("Perplexity of phrase (" + phrase + ") : " + perplexity);
        return perplexity;
    }

    private float perplexity(String phrase, int wordCount) {
        TrieNode current = this;
        TrieNode past = this;
        float logProb = 0;
        String[] words = phrase.split(" ");
        if (words.length == 1) {
            return 0;
        }
        for (String word : words) {
            past = current;
            current = current.children.get(word);
            if (current == null) {
                float alpha = (float) 100 / wordCount;
                return alpha + perplexity(phrase.replaceFirst(words[0] + " ", ""), wordCount - 1);
            }
            logProb += Math.log((float) current.count / past.childCounts);
            logProb += Math.log((float) current.count / past.getAverageChildCount());
        }
        float perplexity = (float) Math.pow(2, -logProb);
        // System.out.println("Perplexity of phrase (" + phrase + ") : " + perplexity);