diff --git a/ProgrammingLanguageClassification.ipynb b/ProgrammingLanguageClassification.ipynb
new file mode 100644
index 0000000..f3939cd
--- /dev/null
+++ b/ProgrammingLanguageClassification.ipynb
@@ -0,0 +1,300 @@
+{
+ "metadata": {
+  "name": "",
+  "signature": "sha256:61d5af4ccb86ac8537d317c915e8379d0ed4a12643ccdc5816b46ddc9097c3da"
+ },
+ "nbformat": 3,
+ "nbformat_minor": 0,
+ "worksheets": [
+  {
+   "cells": [
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "import re\n",
+      "import numpy as np\n",
+      "from sklearn.metrics import (classification_report, f1_score, accuracy_score,\n",
+      "                             confusion_matrix)\n",
+      "import parser\n",
+      "import trainer\n",
+      "import predictor\n",
+      "from sklearn.ensemble import AdaBoostClassifier"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [],
+     "prompt_number": 1
+    },
+    {
+     "cell_type": "heading",
+     "level": 2,
+     "metadata": {},
+     "source": [
+      "Programming Language Identification"
+     ]
+    },
+    {
+     "cell_type": "raw",
+     "metadata": {},
+     "source": [
+      "First, we need to create and train our language classifier, and then test it to measure its accuracy. This script uses the Random Forest classifier provided by the scikit-learn toolkit."
+     ]
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "data, results = trainer.create_training_data()"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [],
+     "prompt_number": 2
+    },
+    {
+     "cell_type": "raw",
+     "metadata": {},
+     "source": [
+      "The create_training_data function reads the train_files folder and parses and scores each of the source files for use with our classifier. It also creates a list containing the correct answer for each element in the data array.\n\nNext we need to split our data into training and testing blocks."
+     ]
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "train_data, test_data, train_results, test_results = trainer.split_data(data, results, 0.2)"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [],
+     "prompt_number": 3
+    },
+    {
+     "cell_type": "raw",
+     "metadata": {},
+     "source": [
+      "Now that our data has been appropriately split, we can use the training portion to train our classifier."
+     ]
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "trained_forest = trainer.train_learner(train_data, train_results)"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [],
+     "prompt_number": 4
+    },
+    {
+     "cell_type": "raw",
+     "metadata": {},
+     "source": [
+      "Now that our random forest is trained, we need to run it against our test data to see how well it performs."
+     ]
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "trainer.test_learner(trained_forest, test_data, test_results)"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [
+      {
+       "output_type": "stream",
+       "stream": "stdout",
+       "text": [
+        "             precision    recall  f1-score   support\n",
+        "\n",
+        "    Clojure       1.00      1.00      1.00         6\n",
+        "    Haskell       1.00      0.50      0.67         2\n",
+        "       Java       1.00      1.00      1.00        39\n",
+        " JavaScript       1.00      1.00      1.00         9\n",
+        "      OCaml       1.00      1.00      1.00         4\n",
+        "        PHP       0.96      1.00      0.98        72\n",
+        "       Perl       1.00      0.91      0.95        23\n",
+        "     Python       0.94      1.00      0.97        15\n",
+        "       Ruby       1.00      0.93      0.96        29\n",
+        "      Scala       1.00      1.00      1.00        10\n",
+        "     Scheme       1.00      1.00      1.00         2\n",
+        "        TCL       0.83      1.00      0.91         5\n",
+        "\n",
+        "avg / total       0.98      0.98      0.98       216\n",
+        "\n",
+        "[[ 6  0  0  0  0  0  0  0  0  0  0  0]\n",
+        " [ 0  1  0  0  0  1  0  0  0  0  0  0]\n",
+        " [ 0  0 39  0  0  0  0  0  0  0  0  0]\n",
+        " [ 0  0  0  9  0  0  0  0  0  0  0  0]\n",
+        " [ 0  0  0  0  4  0  0  0  0  0  0  0]\n",
+        " [ 0  0  0  0  0 72  0  0  0  0  0  0]\n",
+        " [ 0  0  0  0  0  1 21  1  0  0  0  0]\n",
+        " [ 0  0  0  0  0  0  0 15  0  0  0  0]\n",
+        " [ 0  0  0  0  0  1  0  0 27  0  0  1]\n",
+        " [ 0  0  0  0  0  0  0  0  0 10  0  0]\n",
+        " [ 0  0  0  0  0  0  0  0  0  0  2  0]\n",
+        " [ 0  0  0  0  0  0  0  0  0  0  0  5]]\n",
+        "0.9761312978\n"
+       ]
+      }
+     ],
+     "prompt_number": 5
+    },
+    {
+     "cell_type": "raw",
+     "metadata": {},
+     "source": [
+      "We are getting roughly 98% accuracy against our test data (a weighted F1 of 0.976). TCL has the lowest identification success, but the training set for that language was very small.\n\nAfter training and testing, the classifier was retrained on the entire data set and saved to disk for later use. Next we will use our classifier to try to identify some other test data."
+     ]
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "result_list = []\n",
+      "with open(\"test.csv\") as result:\n",
+      "    results = result.readlines()\n",
+      "    for item in results:\n",
+      "        result_list.append(re.findall(r\"\\d+,(\\w+)\", item)[0])"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [],
+     "prompt_number": 6
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "result_list"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [
+      {
+       "metadata": {},
+       "output_type": "pyout",
+       "prompt_number": 7,
+       "text": [
+        "['Clojure',\n",
+        " 'Clojure',\n",
+        " 'Clojure',\n",
+        " 'Clojure',\n",
+        " 'Python',\n",
+        " 'Python',\n",
+        " 'Python',\n",
+        " 'Python',\n",
+        " 'JavaScript',\n",
+        " 'JavaScript',\n",
+        " 'JavaScript',\n",
+        " 'JavaScript',\n",
+        " 'Ruby',\n",
+        " 'Ruby',\n",
+        " 'Ruby',\n",
+        " 'Haskell',\n",
+        " 'Haskell',\n",
+        " 'Haskell',\n",
+        " 'Scheme',\n",
+        " 'Scheme',\n",
+        " 'Scheme',\n",
+        " 'Java',\n",
+        " 'Java',\n",
+        " 'Scala',\n",
+        " 'Scala',\n",
+        " 'TCL',\n",
+        " 'TCL',\n",
+        " 'PHP',\n",
+        " 'PHP',\n",
+        " 'PHP',\n",
+        " 'OCaml',\n",
+        " 'OCaml']"
+       ]
+      }
+     ],
+     "prompt_number": 7
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "predictions = []\n",
+      "classifier = predictor.load_classifier()\n",
+      "\n",
+      "for num in range(1, 33):\n",
+      "    data = predictor.prepare_file(\"test/{}\".format(num))\n",
+      "    predictions.append(predictor.test_file(classifier, data))\n",
+      "predictions = np.array(predictions)"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [],
+     "prompt_number": 8
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [
+      "print(classification_report(result_list, predictions))\n",
+      "print(confusion_matrix(result_list, predictions))\n",
+      "print(f1_score(result_list, predictions))"
+     ],
+     "language": "python",
+     "metadata": {},
+     "outputs": [
+      {
+       "output_type": "stream",
+       "stream": "stdout",
+       "text": [
+        "             precision    recall  f1-score   support\n",
+        "\n",
+        "    Clojure       0.80      1.00      0.89         4\n",
+        "    Haskell       1.00      1.00      1.00         3\n",
+        "       Java       1.00      1.00      1.00         2\n",
+        " JavaScript       1.00      0.75      0.86         4\n",
+        "      OCaml       1.00      0.50      0.67         2\n",
+        "        PHP       0.75      1.00      0.86         3\n",
+        "     Python       1.00      1.00      1.00         4\n",
+        "       Ruby       0.75      1.00      0.86         3\n",
+        "      Scala       1.00      1.00      1.00         2\n",
+        "     Scheme       1.00      1.00      1.00         3\n",
+        "        TCL       1.00      0.50      0.67         2\n",
+        "\n",
+        "avg / total       0.93      0.91      0.90        32\n",
+        "\n",
+        "[[4 0 0 0 0 0 0 0 0 0 0]\n",
+        " [0 3 0 0 0 0 0 0 0 0 0]\n",
+        " [0 0 2 0 0 0 0 0 0 0 0]\n",
+        " [1 0 0 3 0 0 0 0 0 0 0]\n",
+        " [0 0 0 0 1 0 0 1 0 0 0]\n",
+        " [0 0 0 0 0 3 0 0 0 0 0]\n",
+        " [0 0 0 0 0 0 4 0 0 0 0]\n",
+        " [0 0 0 0 0 0 0 3 0 0 0]\n",
+        " [0 0 0 0 0 0 0 0 2 0 0]\n",
+        " [0 0 0 0 0 0 0 0 0 3 0]\n",
+        " [0 0 0 0 0 1 0 0 0 0 1]]\n",
+        "0.899801587302\n"
+       ]
+      }
+     ],
+     "prompt_number": 9
+    },
+    {
+     "cell_type": "code",
+     "collapsed": false,
+     "input": [],
+     "language": "python",
+     "metadata": {},
+     "outputs": []
+    }
+   ],
+   "metadata": {}
+  }
+ ]
+}
\ No newline at end of file
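Note: the holdout above is a single 80/20 split, so the per-language scores for the rarer languages (Haskell, Scheme, TCL) rest on only a handful of files. A cross-validation pass over the same feature matrix gives a less split-dependent estimate; a minimal sketch, assuming the data/results pair returned by trainer.create_training_data() (trainer.py below already imports cross_val_score):

    import numpy as np
    from sklearn.cross_validation import cross_val_score
    from sklearn.ensemble import RandomForestClassifier

    import trainer

    # Same features and labels the notebook trains on.
    data, results = trainer.create_training_data()

    # Five folds average out the luck of any one split, which matters
    # most for the sparsely represented languages.
    forest = RandomForestClassifier(n_estimators=100, random_state=0)
    scores = cross_val_score(forest, np.array(data), np.array(results), cv=5)
    print("mean accuracy: {:.3f} (std {:.3f})".format(scores.mean(), scores.std()))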
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
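parser.py below reduces a source file to a fixed-length vector of numeric features: a comment style, a print style, and a series of symbol frequencies normalised by file length. A toy illustration of the normalised counters, using the helper names from the diff (assumes Python 3, where / is true division):

    import parser

    lines = ["(defn add [a b]", "  (+ a b))"]
    length = parser.count_characters(lines)

    # Each count_* helper divides a raw symbol count by the total number
    # of characters, so files of different sizes score comparably.
    print(parser.count_parentheses(lines, length))  # ~0.15, high for Lisp-family code
    print(parser.count_braces(lines, length))       # 0.0 for this snippet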
diff --git a/parser.py b/parser.py
new file mode 100644
index 0000000..f54ce4c
--- /dev/null
+++ b/parser.py
@@ -0,0 +1,421 @@
+import re
+import numpy as np
+
+
+"""The Parser module will provide methods for parsing information from a
+string representing a source code document, as well as a method for reading
+the file"""
+
+
+def read_file(file_path):
+    with open(file_path) as source_file:
+        source_text = source_file.read()
+    return source_text
+
+
+def split_into_lines(source_text):
+    return source_text.splitlines()
+
+
+def identify_comment_type(source_lines):
+    """Identifies the symbols used to denote comments in the supplied
+    source file"""
+    comment_types = [r"\w*/\*+", r"\w*\(\*+", r"{-", r"\w*;;+", r"\w*//+",
+                     r"\w*--+", r"\w*\#+"]
+
+    for line in source_lines:
+        for index, c_type in enumerate(comment_types):
+            if re.search(r"{} ?(\w+ )+".format(c_type), line):
+                return index
+    return -1
+
+
+def identify_comment_blocks(c_type, source_lines):
+    """Identifies block comments and creates a list of them for removal"""
+    block_start = [r"\w*/\*+", r"\w*\(\*+", r"\w*{-"]
+    block_stop = [r"\*+/", r"\*+\)", r"-}"]
+    in_block = False
+    to_remove = []
+
+    for line in source_lines:
+        if in_block and re.search(block_stop[c_type], line):
+            to_remove.append(line)
+            in_block = False
+        elif in_block:
+            to_remove.append(line)
+        elif re.match(block_start[c_type], line):
+            to_remove.append(line)
+            if not re.search(block_stop[c_type], line):
+                in_block = True
+    return to_remove
+
+
+def strip_comments(c_type, source_lines):
+    """Identifies and strips comments from the supplied source file"""
+    comment_types = [r"/\*+", r"\(\*+", r"{-", r";;+", r"//+", r"--+", r"\#+"]
+    items_to_remove = []
+
+    if c_type == 0 or c_type == 1:
+        items_to_remove.extend(identify_comment_blocks(c_type, source_lines))
+        if c_type == 0:
+            c_type = 4
+    if c_type >= 3:
+        for line in source_lines:
+            if re.match(comment_types[c_type], line):
+                items_to_remove.append(line)
+
+    stripped_source_lines = remove_items(items_to_remove, source_lines)
+    if c_type >= 3:
+        stripped_source_lines = strip_inline_comments(c_type,
+                                                      stripped_source_lines)
+    return stripped_source_lines
+
+
+def remove_items(items, source_lines):
+    """Removes the lines identified by the comment identifier"""
+    for item in items:
+        try:
+            source_lines.remove(item)
+        except ValueError:
+            continue
+    return source_lines
+
+
+def strip_inline_comments(c_type, source_lines):
+    """Attempts to find inline comments in the source file and removes them"""
+    comment_types = [r"/\*+", r"\(\*+", r"{-", r";;+", r"//+", r"--+", r"\#+"]
+    stripped_lines = []
+    for line in source_lines:
+        if re.search(comment_types[c_type], line):
+            index = re.search(comment_types[c_type], line).span()[0]
+            stripped_lines.append(line[:index])
+        else:
+            stripped_lines.append(line)
+    return stripped_lines
+
+"""Functions below are feature examinations of the source file"""
source_text.find("<-") > -1: + return 1 + else: + return 0 + + +def check_equals_arrow(source_list): + source_text = "\n".join(source_list) + if source_text.find("=>") > -1: + return 1 + else: + return 0 + + +def check_for_function(source_list): + source_text = "\n".join(source_list) + if source_text.find("function") > -1: + return 1 + else: + return 0 + + +def check_for_public(source_list): + source_text = "\n".join(source_list) + if source_text.find("public") > -1: + return 1 + else: + return 0 + + +def check_def_method(source_list): + source_text = "\n".join(source_list) + if source_text.find("defn") > -1: + return 1 + elif source_text.find("function") > -1: + return 2 + elif source_text.find("define") > -1: + return 3 + elif source_text.find("def") > -1: + return 4 + elif source_text.find("let") > -1: + return 5 + elif source_text.find("proc") > -1: + return 6 + else: + return -1 + + +def check_for_end(source_list): + source_text = "\n".join(source_list) + if source_text.find("end") > -1: + return 1 + else: + return 0 + + +def check_for_static(source_list): + source_text = "\n".join(source_list) + if source_text.find("static") > -1: + return 1 + else: + return 0 + + +def check_for_colon_word(source_list): + source_text = "\n".join(source_list) + if re.search(r":\w+", source_text): + return 1 + else: + return 0 + + +def check_for_type(source_list): + source_text = "\n".join(source_list) + if source_text.find("type") > -1: + return 1 + else: + return 0 + + +def count_bol_parentheses(source_list, length_of_source): + total_count = 0 + for item in source_list: + if re.match(r"\(", item): + total_count += 1 + return total_count / length_of_source + + +def check_for_val(source_list): + source_text = "\n".join(source_list) + if source_text.find("val") > -1: + return 1 + else: + return 0 + + +def check_for_where(source_list): + source_text = "\n".join(source_list) + if source_text.find("where") > -1: + return 1 + else: + return 0 + + +def check_for_module(source_list): + source_text = "\n".join(source_list) + if source_text.find("module") > -1: + return 1 + else: + return 0 + + +def count_arrows(source_list, length_of_source): + source_text = "\n".join(source_list) + total_count = source_text.count("<") + source_text.count(">") + return total_count / length_of_source + + +def check_assignment(source_list): + source_text = "\n".join(source_list) + if len(re.findall(r"\$\w+->\w+", source_text)): + return 1 + else: + return 0 + + +def count_carets(source_list, length_of_source): + source_text = "\n".join(source_list) + total_count = source_text.count("^") + return total_count / length_of_source + + +def parse_and_score(file_location): + scores = [] + source_text = read_file(file_location) + source_length = count_characters(source_text) + listed_source_text = split_into_lines(source_text) + + comment_style = identify_comment_type(listed_source_text) + listed_source_text = strip_comments(comment_style, listed_source_text) + print_style = identify_print_style(listed_source_text) + parentheses_proportion = count_parentheses(listed_source_text, + source_length) + curly_brace_proportion = count_braces(listed_source_text, source_length) + bracket_proportion = count_brackets(listed_source_text, source_length) + double_colon_proportion = count_double_colons(listed_source_text, + source_length) + semi_colon_proportion = count_semi_colons(listed_source_text, + source_length) + dollar_sign_proportion = count_dollar_signs(listed_source_text, + source_length) + question_mark_proportion = 
+    scores.append(comment_style)
+    scores.append(print_style)
+    scores.append(parentheses_proportion)
+    scores.append(curly_brace_proportion)
+    scores.append(bracket_proportion)
+    scores.append(double_colon_proportion)
+    scores.append(semi_colon_proportion)
+    scores.append(dollar_sign_proportion)
+    scores.append(question_mark_proportion)
+    scores.append(pipes_proportion)
+    scores.append(percent_sign_proportion)
+    scores.append(at_symbol_proportion)
+    scores.append(contains_double_plus_minus)
+    scores.append(double_period_proportion)
+    scores.append(contains_colon_equals)
+    scores.append(contains_dash_arrow)
+    scores.append(contains_reverse_dash_arrow)
+    scores.append(contains_equals_arrow)
+    scores.append(contains_word_function)
+    scores.append(contains_word_public)
+    scores.append(def_method)
+    scores.append(contains_end)
+    scores.append(contains_static)
+    scores.append(colon_word)
+    scores.append(contains_type)
+    scores.append(bol_parentheses)
+    scores.append(contains_val)
+    scores.append(contains_where)
+    scores.append(contains_module)
+    scores.append(arrows_proportion)
+    scores.append(php_assignment)
+    scores.append(caret_count)
+
+    return np.array(scores)
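The order of the scores.append calls above fixes the layout of the feature vector. Once a forest is trained, scikit-learn's feature_importances_ attribute shows which of these syntax cues carry the most signal; a sketch, assuming the pickled model that trainer.py writes (the feature_names list simply mirrors the append order and is only for display):

    import pickle

    feature_names = [
        "comment_style", "print_style", "parentheses", "braces", "brackets",
        "double_colons", "semi_colons", "dollar_signs", "question_marks",
        "pipes", "percent_signs", "at_symbols", "double_plus_minus",
        "double_period", "colon_equals", "dash_arrow", "reverse_dash_arrow",
        "equals_arrow", "word_function", "word_public", "def_method",
        "word_end", "word_static", "colon_word", "word_type",
        "bol_parentheses", "word_val", "word_where", "word_module",
        "arrows", "php_assignment", "carets",
    ]

    with open("random_forest.dat", "rb") as saved:
        forest = pickle.load(saved)

    # A fitted RandomForestClassifier exposes one importance per feature;
    # zipping pairs them with their names, highest signal first.
    ranked = sorted(zip(forest.feature_importances_, feature_names), reverse=True)
    for importance, name in ranked[:10]:
        print("{:<20} {:.3f}".format(name, importance))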
diff --git a/predictor.py b/predictor.py
new file mode 100644
index 0000000..41ef813
--- /dev/null
+++ b/predictor.py
@@ -0,0 +1,41 @@
+import os
+import re
+import sys
+import pickle
+import parser
+
+
+def load_classifier(test_file="random_forest.dat"):
+    """Load the classifier from disk"""
+    with open(test_file, "rb") as saved_classifier:
+        classifier = pickle.load(saved_classifier)
+    return classifier
+
+
+def prepare_file(file_location):
+    """Parse and score the file to be examined"""
+    return parser.parse_and_score(file_location)
+
+
+def test_file(classifier, test_data):
+    """Use the classifier to predict the programming language of the
+    supplied file"""
+    return classifier.predict(test_data)
+
+
+def get_probabilities(classifier, test_data):
+    """Returns the list of probabilities for each of the known programming
+    languages for the file being examined."""
+    return classifier.predict_proba(test_data).tolist()[0]
+
+
+if __name__ == '__main__':
+    classifier = load_classifier()
+    classes = classifier.classes_.tolist()
+    data = prepare_file(sys.argv[1])
+    results = test_file(classifier, data)
+    probabilities = get_probabilities(classifier, data)
+    print("Programming Language Identification Results:\n")
+    for index, item in enumerate(classes):
+        print("{}: {}".format(item, probabilities[index]))
+    print("\nBest Guess: {}".format(results[0]))
diff --git a/random_forest.dat b/random_forest.dat
new file mode 100644
index 0000000..74f0af9
Binary files /dev/null and b/random_forest.dat differ
diff --git a/test.csv b/test.csv
index adbf5dd..a633962 100644
--- a/test.csv
+++ b/test.csv
@@ -1,33 +1,32 @@
-Filename,Language
-1,clojure
-2,clojure
-3,clojure
-4,clojure
-5,python
-6,python
-7,python
-8,python
-9,javascript
-10,javascript
-11,javascript
-12,javascript
-13,ruby
-14,ruby
-15,ruby
-16,haskell
-17,haskell
-18,haskell
-19,scheme
-20,scheme
-21,scheme
-22,java
-23,java
-24,scala
-25,scala
-26,tcl
-27,tcl
-28,php
-29,php
-30,php
-31,ocaml
-32,ocaml
+1,Clojure
+2,Clojure
+3,Clojure
+4,Clojure
+5,Python
+6,Python
+7,Python
+8,Python
+9,JavaScript
+10,JavaScript
+11,JavaScript
+12,JavaScript
+13,Ruby
+14,Ruby
+15,Ruby
+16,Haskell
+17,Haskell
+18,Haskell
+19,Scheme
+20,Scheme
+21,Scheme
+22,Java
+23,Java
+24,Scala
+25,Scala
+26,TCL
+27,TCL
+28,PHP
+29,PHP
+30,PHP
+31,OCaml
+32,OCaml
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/parser_test.py b/test/parser_test.py
new file mode 100644
index 0000000..3c8d4cc
--- /dev/null
+++ b/test/parser_test.py
@@ -0,0 +1,83 @@
+import parser
+
+
+def test_read_file():
+    text = parser.read_file("test/1")
+    assert text
+
+
+def test_split_lines():
+    text = "a\nb\nc"
+    text_list = parser.split_into_lines(text)
+    assert text_list == ["a", "b", "c"]
+
+
+def test_identify_comment_type():
+    # Indices follow the comment_types list in parser.identify_comment_type:
+    # /* -> 0, (* -> 1, {- -> 2, ;; -> 3, // -> 4, -- -> 5, # -> 6.
+    text = ["/* Test comment"]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 0
+
+    text = ["(* Test comment *)"]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 1
+
+    text = ["{- Test comment -}"]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 2
+
+    text = [";; Test comment "]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 3
+
+    text = ["// Test comment "]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 4
+
+    text = ["-- Test comment "]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 5
+
+    text = ["# Test comment "]
+    comment_type = parser.identify_comment_type(text)
+    assert comment_type == 6
+
+
+def test_identify_comment_block():
+    text = ["codecode", "/* Comment", "Comment", "End */", "Code"]
+    to_remove = parser.identify_comment_blocks(0, text)
+    assert to_remove == ["/* Comment", "Comment", "End */"]
+
+    text = ["codecode", "(* Comment", "Comment", "End *)", "Code"]
+    to_remove = parser.identify_comment_blocks(1, text)
+    assert to_remove == ["(* Comment", "Comment", "End *)"]
+
+
+def test_strip_comments():
+    text = ["code", "code", "/* Comment", "comment", "comment */", "code",
+            "code", "// comment", "code", "// comment"]
+    new_text = parser.strip_comments(0, text)
+    assert new_text == ["code", "code", "code", "code", "code"]
+
+    text = ["code", "# comment", "#comment", "code"]
+    new_text = parser.strip_comments(6, text)
+    assert new_text == ["code", "code"]
+
+
+def test_strip_inline_comments():
+    text = ["code // comment"]
+    new_text = parser.strip_inline_comments(4, text)
+    assert new_text == ["code "]
+
+
+def test_identify_print_style():
+    text = ["puts 'this'"]
+    print_type = parser.identify_print_style(text)
+    assert print_type == 0
+
+    text = ["println 'this'"]
+    print_type = parser.identify_print_style(text)
+    assert print_type == 1
+
+    text = ["printf 'this'"]
+    print_type = parser.identify_print_style(text)
+    assert print_type == 2
+
+    text = ["print 'this'"]
+    print_type = parser.identify_print_style(text)
+    assert print_type == 3
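trainer.py below labels each training file by its extension, which create_filetype_dict looks up from train_files/extension_dict.txt using the pattern r"(\w+), (\w+)". A sketch of the round trip, with hypothetical dictionary entries (the real mappings ship with the repository):

    import trainer

    # Assuming extension_dict.txt contains lines such as "py, Python"
    # and "clj, Clojure" (hypothetical examples of the expected format).
    filetype_dict = trainer.create_filetype_dict()
    print(filetype_dict.get("py"))  # -> "Python", given the entry above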
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..e130170
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,87 @@
+import numpy as np
+import os
+import re
+from sklearn.cross_validation import cross_val_score, train_test_split
+from sklearn.metrics import (classification_report, f1_score, accuracy_score,
+                             confusion_matrix)
+from sklearn.ensemble import RandomForestClassifier
+import pickle
+import parser
+
+
+def create_training_data():
+    """Creates training data list and results list corresponding to the
+    data list for training the classifier."""
+    data_directory = "train_files/"
+
+    filetype_dict = create_filetype_dict()
+    filetype_list = list(set(value for key, value in filetype_dict.items()))
+    training_data = []
+    training_results = []
+
+    for file in os.listdir(data_directory):
+        fileext = re.findall(r"\w+\.?\w+?\.(\w+)", file)
+        if fileext:
+            if fileext[0] == "txt":
+                continue
+            filetype = filetype_dict[fileext[0]]
+            training_results.append(filetype)
+            training_data.append(parser.parse_and_score(data_directory + file))
+
+    return training_data, training_results
+
+
+def create_filetype_dict():
+    """Read in file containing all known file extensions and their language
+    and create a dictionary for lookup"""
+    with open("train_files/extension_dict.txt") as filetype:
+        filetype_data = filetype.readlines()
+    filetype_list = []
+    for line in filetype_data:
+        filetype_list.extend(re.findall(r"(\w+), (\w+)", line))
+    filetype_dict = {}
+    for filetypes in filetype_list:
+        key, value = filetypes
+        filetype_dict[key] = value
+    return filetype_dict
+
+
+def split_data(data, results, test_size):
+    """Splits the data into training and test data sets."""
+    train_data, test_data, train_results, test_results = train_test_split(
+        data, results, test_size=test_size, random_state=0)
+
+    return train_data, test_data, train_results, test_results
+
+
+def train_learner(train_data, train_results):
+    """Fit the classifier to the training data."""
+    learner = RandomForestClassifier(n_estimators=100, random_state=0)
+    learner.fit(train_data, train_results)
+    return learner
+
+
+def test_learner(learner, test_data, test_results):
+    """Test the classifier against the test data"""
+    prediction = learner.predict(test_data)
+    print(classification_report(test_results, prediction))
+    print(confusion_matrix(test_results, prediction))
+    print(f1_score(test_results, prediction))
+
+
+def export_forest(forest):
+    """Save the classifier as a pickle file for use in the predictor."""
+    with open("random_forest.dat", "wb") as file:
+        # Dump the classifier that was passed in, not a module-level name.
+        pickle.dump(forest, file)
+
+
+if __name__ == '__main__':
+    data, results = create_training_data()
+    train_data, test_data, train_results, test_results = split_data(data,
+                                                                    results,
+                                                                    0.2)
+    trained_forest = train_learner(train_data, train_results)
+    test_learner(trained_forest, test_data, test_results)
+
+    max_train = train_learner(data, results)
+    export_forest(max_train)
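With the modules in place, a single file can also be classified programmatically rather than through predictor.py's __main__ block. A minimal sketch (the file path and printed values are illustrative; real output depends on the pickled forest):

    import predictor

    classifier = predictor.load_classifier()
    data = predictor.prepare_file("test/5")  # any source file path works

    # predict() returns an array of labels; predict_proba gives one
    # probability per known language, in classifier.classes_ order.
    print(predictor.test_file(classifier, data))           # e.g. ['Python']
    print(predictor.get_probabilities(classifier, data))   # one float per class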