Loading src/main/java/edu/bu/LanguageCorrection/Corrector.java +41 −26 Original line number Diff line number Diff line Loading @@ -8,13 +8,17 @@ import java.nio.file.Paths; import java.util.zip.Inflater; import java.io.ByteArrayOutputStream; import java.io.FileInputStream; import java.util.List; import java.util.Map; import java.util.HashMap; import java.util.ArrayList; public class Corrector { private TrieNode trieRoot; private TrieNode detector; public Corrector() { trieRoot = loadFile("metadata.ser"); detector = loadFile("metadata.ser"); } private TrieNode loadFile(String filePath) { Loading Loading @@ -51,36 +55,47 @@ public class Corrector { } } public String correct(String input) { StringBuilder correctedSentence = new StringBuilder(); String[] words = input.split("\\s+"); public String correct(String inputSentence) { // Divide sentence into words String[] words = inputSentence.split(" "); for (int i = 0; i < words.length - 2; i++) { String trigram = words[i] + " " + words[i + 1] + " " + words[i + 2]; if (trieRoot.probability(trigram) == 0) { correctedSentence.append(suggestCorrection(words[i], words[i + 1], words[i + 2])).append(" "); } else { correctedSentence.append(words[i]).append(" "); correctedSentence.append(words[i + 1]).append(" "); correctedSentence.append(words[i + 2]).append(" "); List<String> sentences = generateSentences(words); float bestScore = Float.MAX_VALUE; String result = ""; for (String sentence : sentences) { float score = detector.perplexity(sentence); // System.out.println(sentence + " | Score: " + score); if (score < bestScore) { bestScore = score; result = sentence; } } return result; } public static List<String> generateSentences(String[] words) { List<String> results = new ArrayList<>(); boolean[] used = new boolean[words.length]; backtrack(results, words, new ArrayList<>(), used); return results; } return correctedSentence.toString().trim(); private static void backtrack(List<String> results, String[] words, List<String> current, boolean[] used) { if (current.size() >= Math.ceil(words.length * 3.0 / 4.0) && current.size() <= words.length) { results.add(String.join(" ", current)); } private String suggestCorrection(String word1, String word2, String word3) { // Calculate perplexity of trigram, bigram, and unigram using trie float trigramPerplexity = trieRoot.perplexity(word1 + " " + word2 + " " + word3); float bigramPerplexity = trieRoot.perplexity(word1 + " " + word2); float unigramPerplexity = trieRoot.perplexity(word1); for (int i = 0; i < words.length; i++) { if (used[i]) continue; // Skip used words if (trigramPerplexity <= bigramPerplexity && trigramPerplexity <= unigramPerplexity) { return word1; } else if (bigramPerplexity <= unigramPerplexity) { return word2; } else { return word3; used[i] = true; current.add(words[i]); backtrack(results, words, current, used); current.remove(current.size() - 1); used[i] = false; } } Loading src/main/java/edu/bu/LanguageCorrection/TrieNode.java +1 −1 Original line number Diff line number Diff line Loading @@ -32,7 +32,7 @@ public class TrieNode implements Serializable, Cloneable { } TrieNode current = this; TrieNode past = this; System.out.println("Phrase: " + phrase); // System.out.println("Phrase: " + phrase); for (String word : words) { past = current; current = current.children.get(word); Loading Loading
src/main/java/edu/bu/LanguageCorrection/Corrector.java +41 −26 Original line number Diff line number Diff line Loading @@ -8,13 +8,17 @@ import java.nio.file.Paths; import java.util.zip.Inflater; import java.io.ByteArrayOutputStream; import java.io.FileInputStream; import java.util.List; import java.util.Map; import java.util.HashMap; import java.util.ArrayList; public class Corrector { private TrieNode trieRoot; private TrieNode detector; public Corrector() { trieRoot = loadFile("metadata.ser"); detector = loadFile("metadata.ser"); } private TrieNode loadFile(String filePath) { Loading Loading @@ -51,36 +55,47 @@ public class Corrector { } } public String correct(String input) { StringBuilder correctedSentence = new StringBuilder(); String[] words = input.split("\\s+"); public String correct(String inputSentence) { // Divide sentence into words String[] words = inputSentence.split(" "); for (int i = 0; i < words.length - 2; i++) { String trigram = words[i] + " " + words[i + 1] + " " + words[i + 2]; if (trieRoot.probability(trigram) == 0) { correctedSentence.append(suggestCorrection(words[i], words[i + 1], words[i + 2])).append(" "); } else { correctedSentence.append(words[i]).append(" "); correctedSentence.append(words[i + 1]).append(" "); correctedSentence.append(words[i + 2]).append(" "); List<String> sentences = generateSentences(words); float bestScore = Float.MAX_VALUE; String result = ""; for (String sentence : sentences) { float score = detector.perplexity(sentence); // System.out.println(sentence + " | Score: " + score); if (score < bestScore) { bestScore = score; result = sentence; } } return result; } public static List<String> generateSentences(String[] words) { List<String> results = new ArrayList<>(); boolean[] used = new boolean[words.length]; backtrack(results, words, new ArrayList<>(), used); return results; } return correctedSentence.toString().trim(); private static void backtrack(List<String> results, String[] words, List<String> current, boolean[] used) { if (current.size() >= Math.ceil(words.length * 3.0 / 4.0) && current.size() <= words.length) { results.add(String.join(" ", current)); } private String suggestCorrection(String word1, String word2, String word3) { // Calculate perplexity of trigram, bigram, and unigram using trie float trigramPerplexity = trieRoot.perplexity(word1 + " " + word2 + " " + word3); float bigramPerplexity = trieRoot.perplexity(word1 + " " + word2); float unigramPerplexity = trieRoot.perplexity(word1); for (int i = 0; i < words.length; i++) { if (used[i]) continue; // Skip used words if (trigramPerplexity <= bigramPerplexity && trigramPerplexity <= unigramPerplexity) { return word1; } else if (bigramPerplexity <= unigramPerplexity) { return word2; } else { return word3; used[i] = true; current.add(words[i]); backtrack(results, words, current, used); current.remove(current.size() - 1); used[i] = false; } } Loading
src/main/java/edu/bu/LanguageCorrection/TrieNode.java +1 −1 Original line number Diff line number Diff line Loading @@ -32,7 +32,7 @@ public class TrieNode implements Serializable, Cloneable { } TrieNode current = this; TrieNode past = this; System.out.println("Phrase: " + phrase); // System.out.println("Phrase: " + phrase); for (String word : words) { past = current; current = current.children.get(word); Loading