Lọc thư rác bằng Java, sử dụng phân loại bayes

Lọc thư rác bằng Java, sử dụng phân loại bayes.

Ở phần trước mình đã giới thiệu các bạn về phương pháp phân loại đơn giản bằng Bayes (đọc lại tại đây).

Ở bài này chúng ta sẽ tiến hành cài đặt chương trình lọc thư rác bằng ngôn ngữ Java dựa trên phương pháp phân loại đơn giản bằng Bayes.

1. Dữ liệu ban đầu

Ban đầu mình sẽ chuẩn bị 1 tập huấn luyện gồm 10 file text đánh dấu là spam và 10 file text đánh dấu là không spam (nếu bạn có càng nhiều file càng tốt). Tỉ lệ ở đây là 50:50 tức là nhận 1 email tới thì khả năng nó là spam là 50%.

Sau mỗi lần kiểm tra được đó có phải là email spam không ta lại thêm nó vào tập huấn luyện và tỉ lệ sẽ khác dần đi. Ví dụ nhận được 10 email tiếp theo đều không phải là spam thì tỉ lệ spam/non-spam sẽ là khoảng 30:70 như thế 1 email mới sẽ có khả năng spam là 30%

Lọc thư rác bằng Java, sử dụng phân loại bayes

2. Chạy tập huấn luyện.

Mình sẽ duyệt từng file spam và tách thành các túi từ (mỗi túi từ 1 Set các từ phân biệt nhau). Sau khi duyệt xong 10 file ta sẽ được 1 List gồm 10 túi từ spam. Tương tự duyệt và tác các file non-spam ta cũng được 1 list gồm 10 túi từ non-spam.

Sau khi tách được 2 List túi từ ta ghi nó vào file result_training.dat để lúc thực hiện kiểm tra mail spam thì chỉ việc đọc từ file này ra chứ không cần chạy huấn luyên nữa.

public class Demo {

  // mảng chứa các túi từ của thư thường (non-spam)
  static ArrayList<Set<String>> listBagOfNonSpam = new ArrayList<>();
  // mảng chứa các túi từ của thư rác (spam)
  static ArrayList<Set<String>> listBagOfSpam = new ArrayList<>();

  // tinh xac xuat P(xi=x|nhan= nonspam)
  public static double pNonSpam(String x) { 
    double k = 0;
    for (int i = 0; i < listBagOfNonSpam.size(); i++) {
      // moi lan x xuat hien trong 1 thu thuong thi k++
      if (listBagOfNonSpam.get(i).contains(x))
        k++;
    }
    return (k + 1) / (listBagOfNonSpam.size() + 1);
    // P(xi|nhan= nonspam)= (k+1)/(sothuthuong+1);
    // trong do: k la so cac mail nonspam xuat hien xi
    // sothuthuong la so mail nonspam

  }

  // tinh xac xuat P(xi=x|nhan= spam)
  public static double pSpam(String x) {
    double k = 0;
    for (int i = 0; i < listBagOfSpam.size(); i++) {
      if (listBagOfSpam.get(i).contains(x))
        // moi lan x xuat hien trong 1 thu rac thi k++
        k++;
    }
    return (k + 1) / (listBagOfSpam.size() + 1);
    // P(xi|nhan= spam)= (k+1)/(sothurac+1);
    // trong do: k la so cac mail spam xuat hien xi
    // sothurac la so mail spam
  }

  @SuppressWarnings("unchecked")
  public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {

    // đọc dữ liệu huấn luyện từ trước ở trong file result_training.dat ra
    System.out.println("Bắt đầu load dữ liệu huấn luyện");
    ObjectInputStream inp = new ObjectInputStream(
        new FileInputStream(new File("data/_result_training/result_training.dat")));
    listBagOfSpam = (ArrayList<Set<String>>) inp.readObject();
    listBagOfNonSpam = (ArrayList<Set<String>>) inp.readObject();
    inp.close();
    System.out.println("Hoàn load dữ liệu huấn luyện");

    // đọc dữ liệu từ mail cần kiểm tra
    System.out.println("Đọc dữ liệu mail cần kiểm tra");
    File mailTesting = new File("data/test/test (1).txt");
    // Tiền xử lý mail cần kiểm tra

    String mailData = FileUtils.readFileToString(mailTesting, "UTF-16");
    Set<String> bagOfTest = RunTrainingData.toBagOfWord(mailData);

    System.out.println("Bắt đầu kiểm tra:");

    // xác xuất là thư thường. P(xi|non-spam)
    double C_NB1 = listBagOfNonSpam.size() / ((double) listBagOfNonSpam.size() + listBagOfSpam.size());
    // xác xuất là thư rác. P(xi|spam)
    double C_NB2 = listBagOfSpam.size() / ((double) listBagOfNonSpam.size() + listBagOfSpam.size());

    ArrayList<String> listStringTest = new ArrayList<>(bagOfTest);

    for (String strTest : listStringTest) {
      if (pNonSpam(strTest) != ((double) 1 / (listBagOfNonSpam.size() + 1))
          || pSpam(strTest) != ((double) 1 / (listBagOfSpam.size() + 1))) {

        System.out.println("P(x_i=" + strTest + "|nonspam)=  " + pNonSpam(strTest) + "        " + "P(x_i="
            + strTest + "|spam)=  " + pSpam(strTest));
        C_NB1 *= pNonSpam(strTest);
        C_NB2 *= pSpam(strTest);
      }
    }
    if (C_NB1 < C_NB2) {
      // Bổ sung thư vừa kiểm tra vào tập huấn luyện.
      listBagOfSpam.add(bagOfTest);
      System.out.println("Là thư rác");
    } else {
      listBagOfNonSpam.add(bagOfTest);
      System.out.println("Là thư thường");
    }

    // Lưu lại tập huấn luyện mới.
    ObjectOutputStream out = new ObjectOutputStream(
        new FileOutputStream(new File("data/_result_training/result_training.dat")));
    out.writeObject(listBagOfSpam);
    out.writeObject(listBagOfNonSpam);
    out.close();
    System.out.println("Kết thúc");
  }

}

(Lưu ý, mình dùng tư viện Apache Common IO để đọc file cho nhanh gọn, các bạn có thể xem lại về Common IO tại đây)

3. Tiến hành kiểm tra mail spam

Mình cũng sẽ tách mail cần kiểm tra thành 1 túi từ.

Áp dụng phương pháp phân loại bayes đơn giản, ta sẽ tính tỉ lệ của từng từ trong túi từ này có trong List túi từ spam và non-spam là bao nhiêu sau đó lấy tích của chúng và nhân với tỉ lệ spam:non-spam và so sánh 2 kết quả.

Ví dụ: mail cần kiểm tra có 100 từ, tỉ lệ của từng từ trong List túi từ spam nhân với nhau là A; tỉ lệ của từng từ trong túi từ non-spam là B. Tỉ lệ spam:non-spam là X:Y (ban đầu là 50:50 nhưng con số này thay đổi sau mỗi lần thêm mail kiểm tra vào tập huấn luyện)

Để biết mail mới có phải là spam hay không ta so sánh A.X với B.Y. Nếu A.X > B.Y thì mail mới là spam và ta thêm túi từ mới vào List túi từ spam ngược lại thì mail mới là non-spam, ta thêm túi từ mới vào List túi từ non-spam và lưu lại.

Ví dụ:

public class Demo {

  // mảng chứa các túi từ của thư thường (non-spam)
  static ArrayList<Set<String>> listBagOfNonSpam = new ArrayList<>();
  // mảng chứa các túi từ của thư rác (spam)
  static ArrayList<Set<String>> listBagOfSpam = new ArrayList<>();

  // tinh xac xuat P(xi=x|nhan= nonspam)
  public static double pNonSpam(String x) { 
    double k = 0;
    for (int i = 0; i < listBagOfNonSpam.size(); i++) {
      // moi lan x xuat hien trong 1 thu thuong thi k++
      if (listBagOfNonSpam.get(i).contains(x))
        k++;
    }
    return (k + 1) / (listBagOfNonSpam.size() + 1);
    // P(xi|nhan= nonspam)= (k+1)/(sothuthuong+1);
    // trong do: k la so cac mail nonspam xuat hien xi
    // sothuthuong la so mail nonspam

  }

  // tinh xac xuat P(xi=x|nhan= spam)
  public static double pSpam(String x) {
    double k = 0;
    for (int i = 0; i < listBagOfSpam.size(); i++) {
      if (listBagOfSpam.get(i).contains(x))
        // moi lan x xuat hien trong 1 thu rac thi k++
        k++;
    }
    return (k + 1) / (listBagOfSpam.size() + 1);
    // P(xi|nhan= spam)= (k+1)/(sothurac+1);
    // trong do: k la so cac mail spam xuat hien xi
    // sothurac la so mail spam
  }

  @SuppressWarnings("unchecked")
  public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {

    // đọc dữ liệu huấn luyện từ trước ở trong file result_training.dat ra
    System.out.println("Bắt đầu load dữ liệu huấn luyện");
    ObjectInputStream inp = new ObjectInputStream(
        new FileInputStream(new File("data/_result_training/result_training.dat")));
    listBagOfSpam = (ArrayList<Set<String>>) inp.readObject();
    listBagOfNonSpam = (ArrayList<Set<String>>) inp.readObject();
    inp.close();
    System.out.println("Hoàn load dữ liệu huấn luyện");

    // đọc dữ liệu từ mail cần kiểm tra
    System.out.println("Đọc dữ liệu mail cần kiểm tra");
    File mailTesting = new File("data/test/test (1).txt");
    // Tiền xử lý mail cần kiểm tra

    String mailData = FileUtils.readFileToString(mailTesting, "UTF-16");
    Set<String> bagOfTest = RunTrainingData.toBagOfWord(mailData);

    System.out.println("Bắt đầu kiểm tra:");

    // xác xuất là thư thường. P(xi|non-spam)
    double C_NB1 = listBagOfNonSpam.size() / ((double) listBagOfNonSpam.size() + listBagOfSpam.size());
    // xác xuất là thư rác. P(xi|spam)
    double C_NB2 = listBagOfSpam.size() / ((double) listBagOfNonSpam.size() + listBagOfSpam.size());

    ArrayList<String> listStringTest = new ArrayList<>(bagOfTest);

    for (String strTest : listStringTest) {
      if (pNonSpam(strTest) != ((double) 1 / (listBagOfNonSpam.size() + 1))
          || pSpam(strTest) != ((double) 1 / (listBagOfSpam.size() + 1))) {

        System.out.println("P(x_i=" + strTest + "|nonspam)=  " + pNonSpam(strTest) + "        " + "P(x_i="
            + strTest + "|spam)=  " + pSpam(strTest));
        C_NB1 *= pNonSpam(strTest);
        C_NB2 *= pSpam(strTest);
      }
    }
    if (C_NB1 < C_NB2) {
      // Bổ sung thư vừa kiểm tra vào tập huấn luyện.
      listBagOfSpam.add(bagOfTest);
      System.out.println("Là thư rác");
    } else {
      listBagOfNonSpam.add(bagOfTest);
      System.out.println("Là thư thường");
    }

    // Lưu lại tập huấn luyện mới.
    ObjectOutputStream out = new ObjectOutputStream(
        new FileOutputStream(new File("data/_result_training/result_training.dat")));
    out.writeObject(listBagOfSpam);
    out.writeObject(listBagOfNonSpam);
    out.close();
    System.out.println("Kết thúc");
  }

}

Kết quả:

Bắt đầu load dữ liệu huấn luyện
Hoàn load dữ liệu huấn luyện
Đọc dữ liệu mail cần kiểm tra
Bắt đầu kiểm tra:
P(x_i=thu|nonspam)=  0.2727272727272727        P(x_i=thu|spam)=  0.09090909090909091
P(x_i=hàng|nonspam)=  0.2727272727272727        P(x_i=hàng|spam)=  1.0
P(x_i=điểm|nonspam)=  0.18181818181818182        P(x_i=điểm|spam)=  0.36363636363636365
...
P(x_i=của|nonspam)=  1.0        P(x_i=của|spam)=  0.5454545454545454
P(x_i=tự|nonspam)=  0.45454545454545453        P(x_i=tự|spam)=  0.18181818181818182
Là thư rác
Kết thúc

 

Như vậy mình đã hướng dẫn các bạn hoàn thành chương trình lọc thư rác bằng Java, sử dụng phân loại bayes. Các bạn có thể thêm giao diện để nó chuyên nghiệp hơn  😎

Thực tế thì mail hiện tại không chỉ có text mà còn có hình ảnh, link, html nên phân loại bayes sẽ không còn chính xác nhiều, tuy nhiên từ bài toán này bạn có thể áp dụng vào phân loại comment, tin nhắn, phân loại văn bản…

Thanks các bạn đã theo dõi bài viết!
Các bạn có thể download source code đầy đủ Tại đây (link dự phòng).

stackjava.com