
Fasttext Classification with Keras in Python

In this tutorial, I am going to perform fasttext classification of texts in the 20 Newsgroups dataset, using Keras in Python to build the model. I will visualize the dataset, train the model and evaluate its performance.

Fasttext was developed by Facebook and is available as an open-source project on GitHub. It is a shallow neural network model for text classification that supports both supervised and unsupervised learning. Text classification is the task of assigning texts to two or more categories.

Dataset and Libraries

I am using the 20 Newsgroups dataset (download it) in this tutorial. Download 20news-bydate.tar.gz; this dataset is sorted by date and divided into a training set and a test set. Unpack the file to a folder (20news_bydate); the files are divided into folders where the folder name represents the name of a category. I have used the following libraries: os, re, string, numpy, nltk, pickle, contextlib, matplotlib, scikit-learn and keras.
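Before running anything else, it can be worth checking that the archive unpacked correctly. A minimal sketch, assuming the same folder as in the modules below (the path is an assumption, adjust it to wherever you unpacked the archive):

# Sanity check: load the unpacked training set with scikit-learn
# (the path is an assumption, change it to your own folder)
import sklearn.datasets

train = sklearn.datasets.load_files(
    'C:\\DATA\\Python-data\\20news_bydate\\20news-bydate-train',
    load_content=True, encoding='latin1')

print('Articles: ' + str(len(train.data)))            # Should print 11314
print('Categories: ' + str(len(train.target_names)))  # Should print 20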

Visualize Dataset

The code to visualize the dataset is included in the training module. We mainly want to see the balance of the training set, since a balanced dataset is important for classification algorithms. The dataset is not perfectly balanced: the most frequent category (rec.sport.hockey) has 600 articles and the least frequent category (talk.religion.misc) has 377 articles. A classifier that always predicts the most frequent category would only reach an accuracy of about 5.3 % (600 * 100 / 11314), so our model needs to do better than this baseline to be useful (see the short sketch after the listing below).

--- Information ---
Number of articles: 11314
Number of categories: 20

--- Class distribution ---
alt.atheism: 480
comp.graphics: 584
comp.os.ms-windows.misc: 591
comp.sys.ibm.pc.hardware: 590
comp.sys.mac.hardware: 578
comp.windows.x: 593
misc.forsale: 585
rec.autos: 594
rec.motorcycles: 598
rec.sport.baseball: 597
rec.sport.hockey: 600
sci.crypt: 595
sci.electronics: 591
sci.med: 594
sci.space: 593
soc.religion.christian: 599
talk.politics.guns: 546
talk.politics.mideast: 564
talk.politics.misc: 465
talk.religion.misc: 377

20 Newsgroups, balance of the dataset
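The baseline figure above can also be computed directly from the loaded dataset. A minimal sketch, assuming the train object loaded with sklearn.datasets.load_files as in the training module:

# Baseline: accuracy of a classifier that always predicts the most frequent category
import numpy as np

counts = np.bincount(train.target)      # Number of articles per category
baseline = counts.max() / counts.sum()  # 600 / 11314
print('Baseline accuracy: {0:.1f} %'.format(baseline * 100))  # About 5.3 %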

Common Module

I have created a common module (common.py) with the configuration, functions to preprocess data and a function to build the fasttext model. The preprocessing function removes headers, footers, quotes, punctuation and digits from each article in the dataset. I am also using a stemmer to stem each word in each article; this step takes some time and you may want to comment it out to speed things up. You can use a lemmatizer instead of a stemmer if you want, but you might need to download the WordNet data for it first. This module also includes two helper functions to create n-gram features (a small example of how they work follows the code).

# Import libraries
import re
import string
import keras
import keras.preprocessing
import contextlib
import nltk.stem

# Download NLTK data for the WordNetLemmatizer (uncomment if needed)
# nltk.download()

# Variables
QUOTES = re.compile(r'(writes in|writes:|wrote:|says:|said:|^In article|^Quoted from|^\||^>)')

# Configuration
class Configuration:

    # Initializes the class
    def __init__(self):
        self.ngram_range = 2
        self.num_words = 20000 # Size of vocabulary, max number of words in a document
        self.max_length = 1000 # The maximum number of words in any document
        self.num_classes = 20
        self.batch_size = 32
        self.embedding_dims = 50
        self.epochs = 40 # 140 so far

# Preprocess data
def preprocess_data(data):

    # Create a stemmer/lemmatizer
    stemmer = nltk.stem.SnowballStemmer('english')
    #lemmer = nltk.stem.WordNetLemmatizer()

    for i in range(len(data)):

        # Remove header
        _, _, data[i] = data[i].partition('\n\n')

        # Remove footer
        lines = data[i].strip().split('\n')
        for line_num in range(len(lines) - 1, -1, -1):
            line = lines[line_num]
            if line.strip().strip('-') == '':
                break
        if line_num > 0:
            data[i] = '\n'.join(lines[:line_num])

        # Remove quotes
        data[i] = '\n'.join([line for line in data[i].split('\n') if not QUOTES.search(line)])

        # Remove punctuation (!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~)
        data[i] = data[i].translate(str.maketrans('', '', string.punctuation))

        # Remove digits
        data[i] = re.sub(r'\d', '', data[i])

        # Stem words
        data[i] = ' '.join([stemmer.stem(word) for word in data[i].split()])
        #data[i] = ' '.join([lemmer.lemmatize(word) for word in data[i].split()])

    # Return data
    return data

# Create n-gram set, extract a set of n-grams from a list of integers
def create_ngram_set(input_list, ngram_value=2):
    return set(zip(*[input_list[i:] for i in range(ngram_value)]))

# Add n-grams, augment the input list of lists (sequences) by appending n-gram values
def add_ngram(sequences, token_indice, ngram_range=2):
    new_sequences = []
    for input_list in sequences:
        new_list = input_list[:]
        for ngram_value in range(2, ngram_range + 1):
            for i in range(len(new_list) - ngram_value + 1):
                ngram = tuple(new_list[i:i + ngram_value])
                if ngram in token_indice:
                    new_list.append(token_indice[ngram])
        new_sequences.append(new_list)
    return new_sequences

# Get a fasttext model
def fasttext(config:Configuration):

    # Create an input layer for the padded sequences of word indices
    input = keras.layers.Input(shape=(config.max_length,), dtype='float32', name='input_layer')

    # Create output layers
    output = keras.layers.Embedding(config.num_words, config.embedding_dims, input_length=config.max_length, name='embedding_layer')(input) # Maps vocabulary indices into embedding_dims dimensions
    output = keras.layers.GlobalAveragePooling1D(name='gapl')(output) # Averages the embeddings of all words in the document
    output = keras.layers.Dense(config.num_classes, activation='softmax', name='output_layer')(output) # Projects to a dense output layer with softmax

    # Create a model from the input layer and output layers
    model = keras.models.Model(inputs=input, outputs=output, name='fasttext')

    # Compile the model
    model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.Adam(lr=0.01, clipnorm=0.001), metrics=['accuracy'])

    # Save model summary to file
    with open('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\model-summary.txt', 'w') as file:
        with contextlib.redirect_stdout(file):
            model.summary()

    # Return the model
    return model
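To make the two n-gram helpers less abstract, here is a toy example of what they do (a sketch only, not part of the module; it assumes common.py is importable under the same path as in the modules below):

# Toy example of the n-gram helpers in common.py
import annytab.fasttext.common as common

sequence = [1, 4, 9, 4, 1]  # A document as a sequence of word indices

# Extract all bigrams from the sequence (printed order of the set may vary)
print(common.create_ngram_set(sequence, ngram_value=2))
# {(1, 4), (4, 9), (9, 4), (4, 1)}

# Map two of the bigrams to indices above the unigram vocabulary (20000 words)
token_indice = {(1, 4): 20001, (9, 4): 20002}

# Append the indices of the known bigrams to the end of the sequence
print(common.add_ngram([sequence], token_indice, ngram_range=2))
# [[1, 4, 9, 4, 1, 20001, 20002]]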

Training

The training module is used to load the training dataset, visualize it, train the model and plot the training loss and accuracy. The classes, the tokenizer and the model are saved to disk after each training session, so training can be resumed from the saved model. Output from a run is shown below the code.

# Import libraries
import os
import pickle
import keras
import keras.preprocessing
import sklearn.datasets
import numpy as np
import matplotlib.pyplot as plt
import annytab.fasttext.common as common

# Visualize dataset
def visualize_dataset(ds:object, num_classes:int):

    # Print dataset information
    print('\n--- Information ---')
    print('Number of articles: ' + str(len(ds.data)))
    print('Number of categories: ' + str(len(ds.target_names)))

    # Count the number of articles in each category
    plot_X = np.arange(num_classes, dtype=np.int16)
    plot_Y = np.zeros(num_classes)
    for i in range(len(ds.data)):
        plot_Y[ds.target[i]] += 1
    print('\n--- Class distribution ---')
    for i in range(len(plot_X)):
        print('{0}: {1:.0f}'.format(ds.target_names[plot_X[i]], plot_Y[i]))

    # Plot the balance of the dataset
    figure = plt.figure(figsize=(16, 10))
    figure.suptitle('Balance of dataset', fontsize=16)
    plt.bar(plot_X, plot_Y, align='center', color='rgbkymc')
    plt.xticks(plot_X, ds.target_names, rotation=25, horizontalalignment='right')
    #plt.show()
    plt.savefig('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\20-newsgroups-balance.png')

# Train and evaluate a model
def train_and_evaluate(train:object, config:common.Configuration):

    # Create a dictionary with classes (maps index to name)
    classes = {}
    for i in range(len(train.target_names)):
        classes[i] = train.target_names[i]

    # Save classes to file
    with open('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\classes.pkl', 'wb') as file:
        pickle.dump(classes, file)
    print('Saved classes to disk!')

    # The Tokenizer vectorizes a text corpus by turning each text into a sequence of integers
    tokenizer = keras.preprocessing.text.Tokenizer(num_words=config.num_words)

    # Update the internal vocabulary based on a list of texts
    tokenizer.fit_on_texts(train.data)

    # Save tokenizer to disk
    with open('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\tokenizer.pkl', 'wb') as file:
        pickle.dump(tokenizer, file)
    print('Saved tokenizer to disk!')

    # Transform each text to a sequence of integers
    train.data = tokenizer.texts_to_sequences(train.data)

    # Convert the class vector (integers) to a binary class matrix: categorical_crossentropy expects targets
    # to be binary matrices (1s and 0s) of shape (samples, classes)
    train.target = keras.utils.to_categorical(train.target, num_classes=config.num_classes, dtype='int32')

    # Add n-gram features
    if config.ngram_range > 1:

        # Create a set of unique n-grams from the training set
        ngram_set = set()
        for input_list in train.data:
            for i in range(2, config.ngram_range + 1):
                set_of_ngram = common.create_ngram_set(input_list, ngram_value=i)
                ngram_set.update(set_of_ngram)

        # Dictionary mapping each n-gram token to a unique integer, the integer values are greater than num_words in order to avoid collisions with existing features
        start_index = config.num_words + 1
        token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}
        indice_token = {token_indice[k]: k for k in token_indice}

        # num_words is now the highest integer that can be found in the dataset
        config.num_words = np.max(list(indice_token.keys())) + 1

        # Augment the training data with n-gram features
        train.data = common.add_ngram(train.data, token_indice, config.ngram_range)

    # Pad sequences to the same length
    train.data = keras.preprocessing.sequence.pad_sequences(train.data, maxlen=config.max_length)

    # Get a model: load a saved model if one exists, otherwise build a new one
    if os.path.isfile('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\model.h5'):
        model = keras.models.load_model('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\model.h5')
    else:
        model = common.fasttext(config)

    # Start training
    history = model.fit(train.data, train.target, batch_size=config.batch_size, epochs=config.epochs, verbose=1)

    # Save model to disk
    model.save('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\model.h5')
    print('Training completed, saved model to disk!')

    # Plot training loss
    plt.figure(figsize=(12, 8))
    plt.plot(history.history['loss'], marker='.', label='train')
    plt.title('Loss')
    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='best')
    plt.savefig('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\loss-plot.png')

    # Plot training accuracy
    plt.figure(figsize=(12, 8))
    plt.plot(history.history['accuracy'], marker='.', label='train')
    plt.title('Accuracy')
    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(loc='best')
    plt.savefig('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\accuracy-plot.png')

# The main entry point for this module
def main():

    # Load text files with categories as subfolder names
    # Individual samples are assumed to be files stored in a two-level folder structure
    # The folder names are used as supervised signal label names, the individual file names are not important
    train = sklearn.datasets.load_files('C:\\DATA\\Python-data\\20news_bydate\\20news-bydate-train', shuffle=False, load_content=True, encoding='latin1')

    # Get a configuration
    config = common.Configuration()

    # Visualize dataset
    #visualize_dataset(train, config.num_classes)

    # Preprocess data
    train.data = common.preprocess_data(train.data)

    # Print cleaned data
    #print(train.data[0])

    # Print an empty row
    print()

    # Start training
    train_and_evaluate(train, config)

# Tell python to run the main method
if __name__ == "__main__": main()
10784/11314 [===========================>..] - ETA: 9s - loss: 0.1440 - accuracy: 0.9651
10816/11314 [===========================>..] - ETA: 8s - loss: 0.1436 - accuracy: 0.9652
10848/11314 [===========================>..] - ETA: 7s - loss: 0.1435 - accuracy: 0.9652
10880/11314 [===========================>..] - ETA: 7s - loss: 0.1437 - accuracy: 0.9652
10912/11314 [===========================>..] - ETA: 6s - loss: 0.1438 - accuracy: 0.9652
10944/11314 [============================>.] - ETA: 6s - loss: 0.1439 - accuracy: 0.9652
10976/11314 [============================>.] - ETA: 5s - loss: 0.1443 - accuracy: 0.9650
11008/11314 [============================>.] - ETA: 5s - loss: 0.1446 - accuracy: 0.9649
11040/11314 [============================>.] - ETA: 4s - loss: 0.1442 - accuracy: 0.9650
11072/11314 [============================>.] - ETA: 4s - loss: 0.1449 - accuracy: 0.9649
11104/11314 [============================>.] - ETA: 3s - loss: 0.1445 - accuracy: 0.9650
11136/11314 [============================>.] - ETA: 3s - loss: 0.1455 - accuracy: 0.9647
11168/11314 [============================>.] - ETA: 2s - loss: 0.1451 - accuracy: 0.9648
11200/11314 [============================>.] - ETA: 1s - loss: 0.1447 - accuracy: 0.9649
11232/11314 [============================>.] - ETA: 1s - loss: 0.1444 - accuracy: 0.9650
11264/11314 [============================>.] - ETA: 0s - loss: 0.1441 - accuracy: 0.9651
11296/11314 [============================>.] - ETA: 0s - loss: 0.1447 - accuracy: 0.9650
11314/11314 [==============================] - 194s 17ms/step - loss: 0.1445 - accuracy: 0.9651
Training completed, saved model to disk!

Evaluation

Model performance is evaluated on the test dataset; the model has been trained for about 140 epochs in total. The accuracy on the test dataset is much lower than the accuracy reported during training, which indicates that the model is overfitted (it has memorized the training set rather than learned to generalize). A short sketch of how to monitor this gap during training follows below; after that comes the evaluation module, with the output from an evaluation run shown below the code.
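A minimal sketch of watching the gap already during training, assuming the same model, config, train.data and train.target as in the training module; validation_split holds out the last part of the training data and reports validation loss and accuracy after each epoch:

# Hold out 10 % of the training data to monitor the generalization gap during training
# (sketch only; model, config, train.data and train.target come from the training module)
history = model.fit(train.data, train.target,
                    batch_size=config.batch_size,
                    epochs=config.epochs,
                    validation_split=0.1,
                    verbose=1)

# Compare training accuracy and validation accuracy per epoch
print(history.history['accuracy'])
print(history.history['val_accuracy'])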

# Import libraries
import pickle
import keras
import keras.preprocessing
import numpy as np
import sklearn.datasets
import sklearn.metrics
import annytab.fasttext.common as common

# Test and evaluate a model
def test_and_evaluate(ds:object, config:common.Configuration):

    # Load the model, classes and tokenizer
    model = keras.models.load_model('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\model.h5')
    with open('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\classes.pkl', 'rb') as file:
        classes = pickle.load(file)
    with open('C:\\DATA\\Python-data\\20news_bydate\\fasttext\\tokenizer.pkl', 'rb') as file:
        tokenizer = pickle.load(file)

    # Transform each text to a sequence of integers
    ds.data = tokenizer.texts_to_sequences(ds.data)

    # Add n-gram features
    if config.ngram_range > 1:

        # Create a set of unique n-grams from the dataset
        # Note: these n-gram indices are rebuilt from the test set and do not necessarily
        # match the indices used during training, reusing the training token_indice would be more consistent
        ngram_set = set()
        for input_list in ds.data:
            for i in range(2, config.ngram_range + 1):
                set_of_ngram = common.create_ngram_set(input_list, ngram_value=i)
                ngram_set.update(set_of_ngram)

        # Dictionary mapping each n-gram token to a unique integer, the integer values are greater than num_words in order to avoid collisions with existing features
        start_index = config.num_words + 1
        token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}
        indice_token = {token_indice[k]: k for k in token_indice}

        # Augment the data with n-gram features
        ds.data = common.add_ngram(ds.data, token_indice, config.ngram_range)

    # Pad sequences to the same length
    ds.data = keras.preprocessing.sequence.pad_sequences(ds.data, maxlen=config.max_length)

    # Make predictions
    predictions = model.predict(ds.data)

    # Print results
    print('\n-- Results --')
    accuracy = sklearn.metrics.accuracy_score(ds.target, np.argmax(predictions, axis=1))
    print('Accuracy: {0:.2f} %'.format(accuracy * 100.0))
    print('Classification Report:')
    print(sklearn.metrics.classification_report(ds.target, np.argmax(predictions, axis=1), target_names=list(classes.values())))
    print('\n-- Samples --')
    for i in range(20):
        print('{0} --- Predicted: {1}, Actual: {2}'.format('CORRECT' if np.argmax(predictions[i]) == ds.target[i] else 'INCORRECT', classes.get(np.argmax(predictions[i])), classes.get(ds.target[i])))
    print()

# The main entry point for this module
def main():

    # Load the test dataset (shuffle it to get different samples each time)
    test = sklearn.datasets.load_files('C:\\DATA\\Python-data\\20news_bydate\\20news-bydate-test', shuffle=True, load_content=True, encoding='latin1')

    # Preprocess data
    test.data = common.preprocess_data(test.data)

    # Get a configuration
    config = common.Configuration()

    # Test and evaluate
    test_and_evaluate(test, config)

# Tell python to run the main method
if __name__ == "__main__": main()
-- Results --
Accuracy: 51.95 %
Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.41      0.42      0.42       319
           comp.graphics       0.51      0.57      0.54       389
 comp.os.ms-windows.misc       0.60      0.51      0.55       394
comp.sys.ibm.pc.hardware       0.68      0.45      0.54       392
   comp.sys.mac.hardware       0.62      0.56      0.59       385
          comp.windows.x       0.81      0.49      0.61       395
            misc.forsale       0.79      0.65      0.71       390
               rec.autos       0.61      0.61      0.61       396
         rec.motorcycles       0.54      0.65      0.59       398
      rec.sport.baseball       0.19      0.93      0.32       397
        rec.sport.hockey       0.89      0.61      0.73       399
               sci.crypt       0.82      0.47      0.60       396
         sci.electronics       0.60      0.35      0.44       393
                 sci.med       0.73      0.49      0.59       396
               sci.space       0.77      0.53      0.63       394
  soc.religion.christian       0.73      0.50      0.59       398
      talk.politics.guns       0.49      0.46      0.47       364
   talk.politics.mideast       0.89      0.43      0.58       376
      talk.politics.misc       0.36      0.37      0.36       310
      talk.religion.misc       0.37      0.15      0.22       251

                accuracy                           0.52      7532
               macro avg       0.62      0.51      0.53      7532
            weighted avg       0.63      0.52      0.54      7532
-- Samples --
CORRECT --- Predicted: rec.sport.hockey, Actual: rec.sport.hockey
CORRECT --- Predicted: talk.politics.guns, Actual: talk.politics.guns
INCORRECT --- Predicted: rec.sport.baseball, Actual: sci.space
CORRECT --- Predicted: talk.politics.misc, Actual: talk.politics.misc
INCORRECT --- Predicted: rec.sport.baseball, Actual: comp.windows.x
INCORRECT --- Predicted: talk.politics.misc, Actual: rec.autos
CORRECT --- Predicted: comp.os.ms-windows.misc, Actual: comp.os.ms-windows.misc
CORRECT --- Predicted: rec.autos, Actual: rec.autos
INCORRECT --- Predicted: comp.sys.mac.hardware, Actual: comp.graphics
CORRECT --- Predicted: comp.graphics, Actual: comp.graphics
CORRECT --- Predicted: comp.graphics, Actual: comp.graphics
CORRECT --- Predicted: comp.graphics, Actual: comp.graphics
INCORRECT --- Predicted: rec.sport.baseball, Actual: sci.med
CORRECT --- Predicted: soc.religion.christian, Actual: soc.religion.christian
INCORRECT --- Predicted: comp.os.ms-windows.misc, Actual: comp.sys.ibm.pc.hardware
CORRECT --- Predicted: sci.space, Actual: sci.space
INCORRECT --- Predicted: rec.sport.baseball, Actual: comp.sys.mac.hardware
CORRECT --- Predicted: rec.sport.hockey, Actual: rec.sport.hockey
CORRECT --- Predicted: soc.religion.christian, Actual: soc.religion.christian
CORRECT --- Predicted: comp.os.ms-windows.misc, Actual: comp.os.ms-windows.misc