Commit 580a5f7b authored by Manuel Segimon
Browse files

Update corrector to use trie

parent 9f304279
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@ import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.zip.Inflater;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
@@ -50,6 +49,7 @@ public class Checker {
        System.out.println("\"phrases\": " + mapToJson(phraseScores));
        System.out.println("}");
    }

    private static String mapToJson(Map<String, Float> map) {
        StringBuilder jsonBuilder = new StringBuilder("{");
        for (Map.Entry<String, Float> entry : map.entrySet()) {
@@ -60,6 +60,7 @@ public class Checker {

        return jsonBuilder.toString();
    }

    private static byte[] decompress(byte[] compressedData) {
        Inflater decompressor = new Inflater();
        decompressor.setInput(compressedData);
@@ -93,6 +94,7 @@ public class Checker {
            return new TrieNode();
        }
    }

    public static void main(String[] args) {
        if (args.length > 1 && "--file".equals(args[0])) { // check syntax
            String path = args[1];
+42 −81
Original line number Diff line number Diff line
package edu.bu.LanguageCorrection;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import java.util.zip.Inflater;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;

public class Corrector {
    private Map<String, Double> trigramProbabilities;
    private Map<String, Double> bigramProbabilities;
    private Map<String, Double> unigramProbabilities;

    private static final double BACKOFF_PENALTY = 0.1;
public class Corrector {
    private TrieNode trieRoot;

    public Corrector() {
        trigramProbabilities = new HashMap<>();
        bigramProbabilities = new HashMap<>();
        unigramProbabilities = new HashMap<>();
        loadBrown();
    }

    private void loadBrown() {
        try (BufferedReader br = new BufferedReader(new FileReader("Checker/brown.txt"))) {
            String line;
            Map<String, Integer> bigramCounts = new HashMap<>();
            Map<String, Integer> trigramCounts = new HashMap<>();
            Map<String, Integer> unigramCounts = new HashMap<>();

            while ((line = br.readLine()) != null) {
                String[] words = line.split("\\s+");
                for (String word : words) {
                    String lowerCaseWord = word.toLowerCase();
                    unigramCounts.put(lowerCaseWord, unigramCounts.getOrDefault(lowerCaseWord, 0) + 1);
        trieRoot = loadFile("metadata.ser");
    }

                if (words.length < 3) continue; // Skip lines with less than 3 words

                for (int i = 0; i < words.length - 1; i++) {
                    String bigram = words[i].toLowerCase() + " " + words[i + 1].toLowerCase();
                    bigramCounts.put(bigram, bigramCounts.getOrDefault(bigram, 0) + 1);
                }

                for (int i = 0; i < words.length - 2; i++) {
                    String trigram = words[i].toLowerCase() + " " +
                            words[i + 1].toLowerCase() + " " +
                            words[i + 2].toLowerCase();
                    trigramCounts.put(trigram, trigramCounts.getOrDefault(trigram, 0) + 1);
    /**
     * Loads a trie from the compressed metadata file at {@code filePath}.
     *
     * <p>The whole file is read into memory, passed through {@code decompress},
     * and the resulting bytes are handed to {@code TrieNode.deserialize}. A
     * confirmation line is printed to standard output on success.
     *
     * @param filePath path of the compressed metadata file to read
     * @return the populated trie, or a fresh {@code TrieNode} when the file
     *         cannot be opened or read (the error is reported on standard error)
     */
    private TrieNode loadFile(String filePath) {
        TrieNode loaded = new TrieNode();
        try (FileInputStream in = new FileInputStream(filePath)) {
            loaded.deserialize(decompress(in.readAllBytes()));
            System.out.println("Metadata loaded successfully.");
            return loaded;
        } catch (IOException readFailure) {
            System.err.println("Error reading metadata from file: " + readFailure.getMessage());
        }
        // Reached only after an I/O failure: hand back a new, unpopulated node.
        return new TrieNode();
    }

            int totalUnigrams = unigramCounts.values().stream().mapToInt(Integer::intValue).sum();
            int totalBigrams = bigramCounts.values().stream().mapToInt(Integer::intValue).sum();
            int totalTrigrams = trigramCounts.values().stream().mapToInt(Integer::intValue).sum();
    private static byte[] decompress(byte[] compressedData) {
        Inflater decompressor = new Inflater();
        decompressor.setInput(compressedData);

            for (Map.Entry<String, Integer> entry : unigramCounts.entrySet()) {
                unigramProbabilities.put(entry.getKey(), (double) entry.getValue() / totalUnigrams);
            }
        ByteArrayOutputStream bos = new ByteArrayOutputStream(compressedData.length);

            for (Map.Entry<String, Integer> entry : bigramCounts.entrySet()) {
                bigramProbabilities.put(entry.getKey(), (double) entry.getValue() / totalBigrams);
            }

            for (Map.Entry<String, Integer> entry : trigramCounts.entrySet()) {
                trigramProbabilities.put(entry.getKey(), (double) entry.getValue() / totalTrigrams);
        byte[] buf = new byte[1024];
        try {
            while (!decompressor.finished()) {
                int count = decompressor.inflate(buf);
                bos.write(buf, 0, count);
            }

        } catch (IOException e) {
            e.printStackTrace();
            decompressor.end();
            return bos.toByteArray();
        } catch (Exception e) {
            System.err.println("Error decompressing data: " + e.getMessage());
            return new byte[0];
        }
    }

@@ -76,11 +56,8 @@ public class Corrector {
        String[] words = input.split("\\s+");

        for (int i = 0; i < words.length - 2; i++) {
            String trigram = words[i].toLowerCase() + " " +
                    words[i + 1].toLowerCase() + " " +
                    words[i + 2].toLowerCase();

            if (!trigramProbabilities.containsKey(trigram)) {
            String trigram = words[i] + " " + words[i + 1] + " " + words[i + 2];
            if (trieRoot.probability(trigram) == 0) {
                correctedSentence.append(suggestCorrection(words[i], words[i + 1], words[i + 2])).append(" ");
            } else {
                correctedSentence.append(words[i]).append(" ");
@@ -93,10 +70,10 @@ public class Corrector {
    }

    private String suggestCorrection(String word1, String word2, String word3) {
        // Trigram, Bigram, and Unigram perplexities
        double trigramPerplexity = calculatePerplexity(trigramProbabilities, word1, word2, word3);
        double bigramPerplexity = calculatePerplexity(bigramProbabilities, word1, word2, "");
        double unigramPerplexity = calculatePerplexity(unigramProbabilities, word1, "", "");
        // Calculate perplexity of trigram, bigram, and unigram using trie
        float trigramPerplexity = trieRoot.perplexity(word1 + " " + word2 + " " + word3);
        float bigramPerplexity = trieRoot.perplexity(word1 + " " + word2);
        float unigramPerplexity = trieRoot.perplexity(word1);

        if (trigramPerplexity <= bigramPerplexity && trigramPerplexity <= unigramPerplexity) {
            return word1;
@@ -107,24 +84,8 @@ public class Corrector {
        }
    }

    /**
     * Computes a perplexity-style score (the reciprocal of a probability) for an
     * n-gram, backing off to shorter n-grams when the requested one is unknown.
     *
     * <p>The lookup key is built from the non-empty words only, all lower-cased
     * and separated by single spaces, so that it has the same shape as the keys
     * stored in the probability tables. Callers pass {@code ""} for the words
     * they do not want to include (e.g. {@code word3 == ""} for a bigram).
     *
     * <p>Back-off order when the direct lookup yields zero:
     * <ol>
     *   <li>the bigram {@code word1 word2} from {@code bigramProbabilities},
     *       scaled by {@code BACKOFF_PENALTY} (skipped when {@code word2} is empty);</li>
     *   <li>the unigram {@code word1} from {@code unigramProbabilities},
     *       scaled by {@code BACKOFF_PENALTY} twice.</li>
     * </ol>
     *
     * @param probabilities table in which the full n-gram is looked up first
     * @param word1 first word of the n-gram
     * @param word2 second word, or {@code ""} if absent
     * @param word3 third word, or {@code ""} if absent
     * @return {@code 1.0 / probability}; {@code Double.POSITIVE_INFINITY} when no
     *         level of the back-off chain has a non-zero probability
     */
    private double calculatePerplexity(Map<String, Double> probabilities, String word1, String word2, String word3) {
        String first = word1.toLowerCase();
        String second = word2.toLowerCase();
        // Previously the third word was used verbatim and could never match a lower-cased key.
        String third = word3.toLowerCase();

        // Append only the words that were supplied so no trailing separators end up in the key.
        StringBuilder key = new StringBuilder(first);
        if (!second.isEmpty()) {
            key.append(' ').append(second);
        }
        if (!third.isEmpty()) {
            key.append(' ').append(third);
        }
        double probability = probabilities.getOrDefault(key.toString(), 0.0);

        // Back off to the bigram table (not the table passed in, which may hold trigrams only).
        if (probability == 0.0 && !second.isEmpty()) {
            probability = bigramProbabilities.getOrDefault(first + " " + second, 0.0) * BACKOFF_PENALTY;
        }

        // Final back-off: unigram probability, penalised once per level skipped.
        if (probability == 0.0) {
            probability = unigramProbabilities.getOrDefault(first, 0.0) * BACKOFF_PENALTY * BACKOFF_PENALTY;
        }

        // Floating-point division: a zero probability yields positive infinity, not an exception.
        return 1.0 / probability;
    }

    public static void main(String[] args) {
        if (args.length > 1 && "--file".equals(args[0])) { // check syntax
        if (args.length > 1 && "--file".equals(args[0])) {
            String path = args[1];
            try {
                String content = new String(Files.readAllBytes(Paths.get(path)));