Commit 41f45a28 authored by Manuel  Segimon's avatar Manuel Segimon
Browse files

Refactor Corrector class to improve sentence generation and scoring

parent 91d547c9
Loading
Loading
Loading
Loading
+41 −26
Original line number Diff line number Diff line
@@ -8,13 +8,17 @@ import java.nio.file.Paths;
import java.util.zip.Inflater;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;


public class Corrector {
    private TrieNode trieRoot;
    private TrieNode detector;

    public Corrector() {
        trieRoot = loadFile("metadata.ser");
        detector = loadFile("metadata.ser");
    }

    private TrieNode loadFile(String filePath) {
@@ -51,36 +55,47 @@ public class Corrector {
        }
    }

    public String correct(String input) {
        StringBuilder correctedSentence = new StringBuilder();
        String[] words = input.split("\\s+");
    public String correct(String inputSentence) {
        // Divide sentence into words
        String[] words = inputSentence.split(" ");

        for (int i = 0; i < words.length - 2; i++) {
            String trigram = words[i] + " " + words[i + 1] + " " + words[i + 2];
            if (trieRoot.probability(trigram) == 0) {
                correctedSentence.append(suggestCorrection(words[i], words[i + 1], words[i + 2])).append(" ");
            } else {
                correctedSentence.append(words[i]).append(" ");
                correctedSentence.append(words[i + 1]).append(" ");
                correctedSentence.append(words[i + 2]).append(" ");
        List<String> sentences = generateSentences(words);
        float bestScore = Float.MAX_VALUE;
        String result = "";
        for (String sentence : sentences) {
            float score = detector.perplexity(sentence);
            // System.out.println(sentence + " | Score: " + score);

            if (score < bestScore) {
                bestScore = score;
                result = sentence;
            }
        }

        return result;
    }

    public static List<String> generateSentences(String[] words) {
        List<String> results = new ArrayList<>();
        boolean[] used = new boolean[words.length];
        backtrack(results, words, new ArrayList<>(), used);
        return results;
    }

        return correctedSentence.toString().trim();
    private static void backtrack(List<String> results, String[] words, List<String> current, boolean[] used) {
        if (current.size() >= Math.ceil(words.length * 3.0 / 4.0) && current.size() <= words.length) {
            results.add(String.join(" ", current));
        }

    private String suggestCorrection(String word1, String word2, String word3) {
        // Calculate perplexity of trigram, bigram, and unigram using trie
        float trigramPerplexity = trieRoot.perplexity(word1 + " " + word2 + " " + word3);
        float bigramPerplexity = trieRoot.perplexity(word1 + " " + word2);
        float unigramPerplexity = trieRoot.perplexity(word1);
        for (int i = 0; i < words.length; i++) {
            if (used[i])
                continue; // Skip used words

        if (trigramPerplexity <= bigramPerplexity && trigramPerplexity <= unigramPerplexity) {
            return word1;
        } else if (bigramPerplexity <= unigramPerplexity) {
            return word2;
        } else {
            return word3;
            used[i] = true;
            current.add(words[i]);
            backtrack(results, words, current, used);
            current.remove(current.size() - 1);
            used[i] = false;
        }
    }

+1 −1
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ public class TrieNode implements Serializable, Cloneable {
        }
        TrieNode current = this;
        TrieNode past = this;
        System.out.println("Phrase: " + phrase);
        // System.out.println("Phrase: " + phrase);
        for (String word : words) {
            past = current;
            current = current.children.get(word);