diff --git a/capectracer/capec_specs_dict.py b/capectracer/capec_specs_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..12aedcd32cc9e237ee3d814b48a13ee82007b41a --- /dev/null +++ b/capectracer/capec_specs_dict.py @@ -0,0 +1,221 @@ +import xmltodict +import json +import re + +from tokens_dict import TokensDict + +from gensim.models.phrases import Phrases, ENGLISH_CONNECTOR_WORDS, Phraser + +class CapecSpecsDict(TokensDict): + def __init__(self, capec_file, training_corpus_file): + super().__init__() + + self.__capec_names = [] + self.__capec_descriptions = [] + self.__capec_execution_flows = [] + self.__capec_mitigations = [] + self.__capec_tokens = [] + self.__capec_tokens_lemm = [] + self.__system_spec_tokens = [] + self.__system_spec_tokens_lemm = [] + self.__training_tokens = [] + # self.__bigram_model = None + # self.__trigram_model = None + + self.parse_capecs(capec_file) + self.parse_training_corpus(training_corpus_file) + + def get_capec_names(self): + return self.__capec_names + + def get_capec_descs(self): + return self.__capec_descriptions + + def get_capec_execution_flows(self): + return self.__capec_execution_flows + + def get_capec_mitigations(self): + return self.__capec_mitigations + + def get_capec_tokens(self): + return self.__capec_tokens + + def get_capec_tokens_lemm(self): + return self.__capec_tokens_lemm + + def get_system_spec_tokens(self): + return self.__system_spec_tokens + + def get_system_spec_tokens_lemm(self): + return self.__system_spec_tokens_lemm + + def get_training_tokens(self): + return self.__training_tokens + + def get_training_word_doc_count(self): + word_document_count = {} # Dictionary to store word-document counts + + # Iterate over each document in the corpus + for document in self.__training_tokens: + unique_words_in_doc = set(document) + + # Update the word-document count for each unique word in the document + for word in unique_words_in_doc: + # Increment the count for the word if it exists, otherwise initialize it to 1 + word_document_count[word] = word_document_count.get(word, 0) + 1 + + sorted_counts = sorted(word_document_count.items(), key=lambda x: (-x[1], x[0])) + + return sorted_counts + + def parse_capecs(self, capec_file): + capec_descriptions = [] + capec_names = [] + capec_execution_flows = [] + capec_mitigations = [] + + with open(capec_file, 'r') as xml_file: + xml_data = xml_file.read() + xml_data = re.sub(r"<xhtml:(.*?)>", "", xml_data) + xml_data = re.sub(r"</xhtml:(.*?)>", "", xml_data) + capec_dict = xmltodict.parse(xml_data) + + attack_patterns = capec_dict['Attack_Pattern_Catalog']['Attack_Patterns']['Attack_Pattern'] + + for attack_pattern in attack_patterns: + if (attack_pattern['@Abstraction'] == 'Standard' or \ + attack_pattern['@Abstraction'] == 'Detailed') and \ + attack_pattern['@Status'] != 'Obsolete' and \ + attack_pattern['@Status'] !='Deprecated' and \ + "Execution_Flow" in attack_pattern: + + capec_names.append(attack_pattern['@Name']) + + description = "" + description = attack_pattern["Description"] + + if "Extended_Description" in attack_pattern: + description = description + " " + attack_pattern["Extended_Description"] + + capec_descriptions.append(description) + + attack_steps_list = [] + attack_steps = attack_pattern["Execution_Flow"]["Attack_Step"] + + if isinstance(attack_steps, list): + for idx, attack_step in enumerate(attack_steps): + cleaned_attack_step = "Step " + str(idx + 1) + ". 
" + \ + super().clean_text(attack_step["Description"]) + attack_steps_list.append(cleaned_attack_step) + else: + cleaned_attack_step = "Step 1. " + super().clean_text(attack_steps["Description"]) + attack_steps_list.append(cleaned_attack_step) + + capec_execution_flows.append(attack_steps_list) + + mitigations_list = [] + + if "Mitigations" in attack_pattern: + mitigations = attack_pattern["Mitigations"]["Mitigation"] + + if isinstance(mitigations, list): + for mitigation in mitigations: + cleaned_mitigation = super().clean_text(mitigation) + mitigations_list.append(cleaned_mitigation) + else: + cleaned_mitigation = super().clean_text(mitigations) + mitigations_list.append(cleaned_mitigation) + + capec_mitigations.append(mitigations_list) + + capec_descriptions = super().clean_texts(capec_descriptions) + + self.__capec_names = capec_names + self.__capec_descriptions = capec_descriptions + self.__capec_execution_flows = capec_execution_flows + self.__capec_mitigations = capec_mitigations + + capec_tokens = super().tokenize_text(capec_descriptions) + capec_tokens = super().lowercase_tokens(capec_tokens) + capec_tokens = super().remove_numbers(capec_tokens) + + capec_tokens_lemm = super().lemmatize_tokens(capec_tokens) + + self.__capec_tokens = capec_tokens + self.__capec_tokens_lemm = capec_tokens_lemm + + # bigram = Phrases(training_tokens, min_count=7, threshold=18, + # connector_words=ENGLISH_CONNECTOR_WORDS) + # self.__bigram_model = Phraser(bigram) + + # trigram = Phrases(bigram[training_tokens], min_count=16, threshold=26, + # connector_words=ENGLISH_CONNECTOR_WORDS) + # self.__trigram_model = Phraser(trigram) + + # # Apply the bigram model to each document in the training tokens corpus + # training_tokens = [self.__bigram_model[doc] for doc in training_tokens] + # # Apply the trigram model to the resulting bigram training tokens corpus + # training_tokens = [self.__trigram_model[doc] for doc in training_tokens] + + def parse_training_corpus(self, training_corpus_file): + unduped_training_tokens = [] + + # Open the JSON file + with open(training_corpus_file, 'r') as file: + # Load the JSON data + data = json.load(file) + + for source in ["MS-Bulletin", "Metasploit", "NVD"]: + for _, value in data[source].items(): + corpora_tokens = [] + + for word_label in value: + word = word_label[0] + + if word.isalnum(): + corpora_tokens.append(word) + + unduped_training_tokens.append(corpora_tokens) + + unduped_training_tokens = super().lowercase_tokens(unduped_training_tokens) + unduped_training_tokens = super().remove_numbers(unduped_training_tokens) + + unduped_training_tokens += self.__capec_tokens + + training_tokens_dict = {} + + # Iterate through each array in the outer array + for index, token_set in enumerate(unduped_training_tokens): + # Sort the elements of the array to ignore the order + sorted_token_set = sorted(token_set) + + # Convert the sorted array to a tuple to make it hashable + token_set_tuple = tuple(sorted_token_set) + + # Compute the hash value of the sorted array + token_set_hash = hash(token_set_tuple) + + # Check if the hash value already exists in the dictionary + if token_set_hash not in training_tokens_dict: + # If the hash value doesn't exist, add the sorted array to the dictionary + training_tokens_dict[token_set_hash] = index + + training_tokens = [unduped_training_tokens[index] for index in list(training_tokens_dict.values())] + + self.__training_tokens = training_tokens + + def parse_system_specs(self, spec_file): + with open(spec_file, 'r') as file: + # Read the 
contents of the file + file_contents = file.read() + + system_spec_texts = super().clean_texts([file_contents]) + + system_spec_tokens = super().tokenize_text(system_spec_texts) + system_spec_tokens = super().lowercase_tokens(system_spec_tokens) + system_spec_tokens = super().remove_numbers(system_spec_tokens) + + system_spec_tokens_lemm = super().lemmatize_tokens(system_spec_tokens) + + self.__system_spec_tokens = system_spec_tokens[0] + self.__system_spec_tokens_lemm = system_spec_tokens_lemm[0] \ No newline at end of file diff --git a/capectracer/capec_tracer.py b/capectracer/capec_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f862936a3d64e86ab6d2d5cd7b85000b2821a9 --- /dev/null +++ b/capectracer/capec_tracer.py @@ -0,0 +1,140 @@ +import requests +import os + +from capec_specs_dict import CapecSpecsDict + +from gensim.corpora import Dictionary +from gensim.models import LdaModel +from gensim.matutils import hellinger +from gensim import similarities +from gensim.models import TfidfModel +from sklearn.utils import shuffle + +def get_capec_file(abs_path): + # Send a GET request to the URL + response = requests.get("https://capec.mitre.org/data/xml/capec_latest.xml") + + # Check if the request was successful (status code 200) + if response.status_code == 200: + # Open the file in binary write mode and write the content + with open(abs_path + "capec_latest.xml", 'wb') as file: + file.write(response.content) + + return response + +def trace_capecs(abs_path): + try: + tokens_dicti = CapecSpecsDict(abs_path + "capec_latest.xml", abs_path + 'full_corpus.json') + + capec_names = tokens_dicti.get_capec_names() + capec_descs = tokens_dicti.get_capec_descs() + capec_attack_steps = tokens_dicti.get_capec_execution_flows() + capec_mitigations = tokens_dicti.get_capec_mitigations() + capec_tokens = tokens_dicti.get_capec_tokens() + capec_tokens_lemm = tokens_dicti.get_capec_tokens_lemm() + + tokens_dicti.parse_system_specs(abs_path + "system_specs.txt") + spec_tokens = tokens_dicti.get_system_spec_tokens() + spec_tokens_lemm = tokens_dicti.get_system_spec_tokens_lemm() + + # Create a dictionary from the corpus documents + capec_lemm_dictionary = Dictionary(capec_tokens_lemm) + + # Create a bag-of-words corpus from the corpus documents + capec_lemm_corpus = [capec_lemm_dictionary.doc2bow(doc) for doc in capec_tokens_lemm] + + # Create a TF-IDF model + tfidf_model = TfidfModel(capec_lemm_corpus, id2word=capec_lemm_dictionary) + + # Convert the checked document to TF-IDF representation + spec_lemm_bow = capec_lemm_dictionary.doc2bow(spec_tokens_lemm) + spec_tfidf = tfidf_model[spec_lemm_bow] + + # Create a similarity index + index = similarities.MatrixSimilarity(tfidf_model[capec_lemm_corpus], num_features=len(capec_lemm_dictionary)) + + # Calculate similarity scores + similarity_scores = index[spec_tfidf] + + # Combine similarity scores with document indices + doc_similarity_pairs = list(enumerate(similarity_scores)) + + # Sort documents by similarity score + sorted_doc_similarity_pairs = sorted(doc_similarity_pairs, key=lambda x: x[1], reverse=True) + + # Print the top N most similar documents + # for idx, score in sorted_doc_similarity_pairs: + # print(score) + # print(capec_descs[idx]) + + if sorted_doc_similarity_pairs[1][1] > 0: + training_tokens = tokens_dicti.get_training_tokens() + training_tokens = shuffle(training_tokens) + + training_dictionary = Dictionary(training_tokens) + training_dictionary.filter_extremes(no_below=2, no_above=0.3) + + training_corpus = 
[training_dictionary.doc2bow(text) for text in training_tokens] + + lda = LdaModel(corpus=training_corpus, + num_topics=18, + id2word=training_dictionary, + passes=12, + alpha="auto", + eta="auto", + iterations=6, + update_every=1, + decay=0.745786579640169, + offset=3.20092862358715 + ) + + system_specs_bow = training_dictionary.doc2bow(spec_tokens) + system_spec_topic_dist = lda[system_specs_bow] + + difference_scores = [] + + for index, capec_token_set in enumerate(capec_tokens): + capec_desc_bow = training_dictionary.doc2bow(capec_token_set) + capec_topic_dist = lda[capec_desc_bow] + + score = [index, hellinger(capec_topic_dist, system_spec_topic_dist)] + + difference_scores.append(score) + + # Sort the arrays based on the value of the second index in each array + difference_scores = sorted(difference_scores, key=lambda x: x[1], reverse=False) + + with open(abs_path + 'traced_capecs.txt', 'w') as output: + for score in difference_scores: + output.write(f'Confidence score: {100 - int(100 * score[1])}%\n') + output.write(f'Name: {capec_names[score[0]]}\n') + output.write(f'Description:\n') + output.write(f'{capec_descs[score[0]]}\n') + output.write(f'Attack Steps:\n') + for attack_step in capec_attack_steps[score[0]]: + output.write(f'{attack_step}\n') + if capec_mitigations[score[0]]: + output.write(f"Mitigations:\n") + for mitigation in capec_mitigations[score[0]]: + output.write(f'{mitigation}\n') + output.write('\n') + else: + with open(abs_path + 'traced_capecs.txt', 'w') as output: + output.write("No attack patterns were able to be identified with the provided system specifications. " + + "Please provide the system specifications with additional details.\n") + except Exception as e: + with open(abs_path + 'traced_capecs.txt', 'w') as output: + output.write(str(e)) + +if __name__ == "__main__": + abs_path = os.path.abspath(__file__) + last_slash_index = abs_path.rfind("/") + abs_path = abs_path[:last_slash_index + 1] + + response = get_capec_file(abs_path) + + if response.status_code == 200: + trace_capecs(abs_path) + else: + with open(abs_path + 'traced_capecs.txt', 'w') as output: + output.write(f"Failed to download the list of CAPECs from MITRE. 
Status code: {response.status_code}.") diff --git a/capectracer/full_corpus.json b/capectracer/full_corpus.json new file mode 100755 index 0000000000000000000000000000000000000000..b7c038819c6f10f173b44b6a01e3da0c5c8b1086 Binary files /dev/null and b/capectracer/full_corpus.json differ diff --git a/capectracer/requirements_capec_tracer.txt b/capectracer/requirements_capec_tracer.txt new file mode 100644 index 0000000000000000000000000000000000000000..205667098ded7fb1ef0699bc71262be4b0416c0f --- /dev/null +++ b/capectracer/requirements_capec_tracer.txt @@ -0,0 +1,5 @@ +scikit-learn==1.0.2 +gensim==4.1.2 +spacy==3.7.2 +xmltodict==0.13.0 +nltk==3.7 \ No newline at end of file diff --git a/capectracer/requirements_model_training.txt b/capectracer/requirements_model_training.txt new file mode 100644 index 0000000000000000000000000000000000000000..46d9d49fecc66a8f9a5fa22fc0e5bdf77c171f82 --- /dev/null +++ b/capectracer/requirements_model_training.txt @@ -0,0 +1,4 @@ +tqdm==4.64.0 +numpy==1.21.5 +bayesian-optimization==1.4.3 +pandas==1.4.2 \ No newline at end of file diff --git a/capectracer/tokens_dict.py b/capectracer/tokens_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..a2aa8ebeda84f2b7eca0d2260cea8d0ac86dd2aa --- /dev/null +++ b/capectracer/tokens_dict.py @@ -0,0 +1,83 @@ +import re +import nltk + +from nltk.stem import WordNetLemmatizer +from nltk.tokenize import word_tokenize +import spacy + +class TokensDict: + def __init__(self): + # Initialize NLTK + nltk.download('punkt') + nltk.download('wordnet') + nltk.download('averaged_perceptron_tagger') + nltk.download('omw-1.4') + + def tokenize_text(self, texts): + tokens_list = [] + + for text in texts: + # Remove punctuation using regular expression + prepro_text = re.sub(r'[^a-zA-Z0-9\s/-]', '', text) + prepro_text = re.sub(r'[-/]', ' ', prepro_text) + # Tokenization using NLTK + tokens = word_tokenize(prepro_text) + + tokens_list.append(tokens) + + return tokens_list + + def remove_numbers(self, tokens_list): + filtered_tokens_list = [[token for token in token_set if not token.isnumeric()] \ + for token_set in tokens_list] + + return filtered_tokens_list + + def lemmatize_tokens(self, tokens_list): + en = spacy.load('en_core_web_sm') + stop_words = en.Defaults.stop_words + + # Initialize WordNet Lemmatizer + lemmatizer = WordNetLemmatizer() + lemmatized_token_list = [] + + for token_set in tokens_list: + lemmatized_tokens = [] + tagged_tokens = nltk.pos_tag(token_set) + + for token, pos in tagged_tokens: + if token not in stop_words: + if pos.startswith("N"): + lemmatized_tokens.append(lemmatizer.lemmatize(token, "n")) + elif pos.startswith("V"): + lemmatized_tokens.append(lemmatizer.lemmatize(token, "v")) + elif pos.startswith('J'): + lemmatized_tokens.append(lemmatizer.lemmatize(token, "a")) + elif pos.startswith('R'): + lemmatized_tokens.append(lemmatizer.lemmatize(token, "r")) + else: + lemmatized_tokens.append(token) + + lemmatized_token_list.append(lemmatized_tokens) + + return lemmatized_token_list + + def lowercase_tokens(self, tokens_list): + return [[token.lower() for token in token_set] for token_set in tokens_list] + + def clean_text(self, text): + cleaned_text = re.sub(r'\n', '', text) + cleaned_text = re.sub(r'\s+', ' ', cleaned_text) + cleaned_text = re.sub(r'^\[[^\]]*\]\s*', '', cleaned_text) + + return cleaned_text + + def clean_texts(self, texts): + cleaned_texts = [] + + for text in texts: + cleaned_text = self.clean_text(text) + + cleaned_texts.append(cleaned_text) + + return cleaned_texts \ 
No newline at end of file diff --git a/capectracer/training_lda_model.py b/capectracer/training_lda_model.py new file mode 100644 index 0000000000000000000000000000000000000000..29e4d95b0e011463948bd04e5885d7b417d53e13 --- /dev/null +++ b/capectracer/training_lda_model.py @@ -0,0 +1,206 @@ +import tqdm +import logging +import numpy as np +import os.path +import requests + +from capec_specs_dict import CapecSpecsDict + +from gensim.corpora import Dictionary +from gensim.models import LdaModel +from gensim.models.callbacks import CoherenceMetric, PerplexityMetric, ConvergenceMetric +from bayes_opt import BayesianOptimization +from pandas import DataFrame +from bayes_opt.logger import JSONLogger +from bayes_opt.event import Events +from sklearn.model_selection import KFold +from functools import partial + +# Train model +def train_lda(training_tokens, dictionary, num_topics, passes, iterations, update_every, + training_logs, training_results_csv, eval_every, decay, offset): + model_dict = {'Topics': [], + 'Passes': [], + 'Iterations': [], + 'Update_Every': [], + 'Decays': [], + 'Offsets': [], + 'Avg_Coherences': [], + 'Avg_Perplexities': [], + 'Avg_Convergences': [], + 'Dev_Coherences': [], + 'Dev_Perplexities': [], + 'Dev_Convergences': [], + 'Epochs': []} + + # Initialize k-fold cross-validation + k = 4 + kf = KFold(n_splits=k, shuffle=True, random_state=1) + + # Initialize a list to store evaluation metrics for each fold + coherence_scores = [] + perplexity_scores = [] + convergence_scores = [] + + k_index = 1 + + print("Training with", num_topics, "topics,", passes, + "max passes,", iterations, "max iterations,", + update_every, "update every,", decay, "decay,", offset, "offset:", sep=" ") + + pbar = tqdm.trange(k) + + # Iterate over each fold + for train_index, test_index in kf.split(training_tokens): + train_tokens = [training_tokens[i] for i in train_index] + val_tokens = [training_tokens[i] for i in test_index] + + train_corpus = [dictionary.doc2bow(doc) for doc in train_tokens] + val_corpus = [dictionary.doc2bow(doc) for doc in val_tokens] + + # Configure logging + logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG, + filename=training_logs) + + logging.info( + "Training with %s topics, %s max passes, %s max iterations, %s update every, %s decay, %s offset at %s fold:", + num_topics, passes, iterations, update_every, decay, offset, k_index) + + coherence_metric = CoherenceMetric( + texts=train_tokens, coherence='c_npmi', window_size=10) + perplexity_metric = PerplexityMetric(corpus=val_corpus) + convergence_metric = ConvergenceMetric(distance="hellinger") + + lda_model = LdaModel(corpus=train_corpus, + num_topics=num_topics, + id2word=dictionary, + passes=passes, + alpha="auto", + eta="auto", + random_state=1, + iterations=iterations, + update_every=update_every, + decay=decay, + offset=offset, + eval_every=eval_every, + callbacks=[perplexity_metric, coherence_metric, convergence_metric], + chunksize=2000) + + logging.shutdown() + + coherence_scores.append(lda_model.metrics['Coherence']) + perplexity_scores.append(lda_model.metrics['Perplexity']) + convergence_scores.append(lda_model.metrics['Convergence']) + + k_index += 1 + + pbar.update(1) + + pbar.close() + + for epoch, _ in enumerate(coherence_scores[0]): + model_dict['Topics'].append(num_topics) + model_dict['Passes'].append(passes) + model_dict['Iterations'].append(iterations) + model_dict['Update_Every'].append(update_every) + model_dict['Decays'].append(decay) + 
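        # Each epoch becomes one row of the results table: the remaining hyperparameter
        # column and the epoch index are appended just below, and the fold-averaged
        # metric columns are filled in after the loop.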
model_dict['Offsets'].append(offset) + model_dict['Epochs'].append(epoch + 1) + + coherence_score_averages = np.mean(coherence_scores, axis=0) + perplexity_score_averages = np.mean(perplexity_scores, axis=0) + convergence_score_averages = np.mean(convergence_scores, axis=0) + + # Compute the absolute deviation of each element from its corresponding column mean + coherence_absolute_deviations = np.abs(coherence_scores - coherence_score_averages) + perplexity_absolute_deviations = np.abs(perplexity_scores - perplexity_score_averages) + convergence_absolute_deviations = np.abs(convergence_scores - convergence_score_averages) + + # Compute the average of these absolute deviations for each column + coherece_absolute_average_deviations = np.mean(coherence_absolute_deviations, axis=0) + perplexity_absolute_average_deviations = np.mean(perplexity_absolute_deviations, axis=0) + convergence_absolute_average_deviations = np.mean(convergence_absolute_deviations, axis=0) + + model_dict['Avg_Coherences'] = coherence_score_averages + model_dict['Avg_Perplexities'] = perplexity_score_averages + model_dict['Avg_Convergences'] = convergence_score_averages + model_dict['Dev_Coherences'] = coherece_absolute_average_deviations + model_dict['Dev_Perplexities'] = perplexity_absolute_average_deviations + model_dict['Dev_Convergences'] = convergence_absolute_average_deviations + + model_df = DataFrame.from_dict(model_dict) + + if os.path.isfile(training_results_csv): + model_df.to_csv(training_results_csv, mode='a', header=False) + else: + model_df.to_csv(training_results_csv, mode='w', header=True) + + return coherence_score_averages[len(coherence_score_averages) - 1] + # return -perplexity_score_averages[len(perplexity_score_averages) - 1] + +def train_lda_wrapper(training_tokens, dictionary, num_topics, passes, iterations, update_every, + training_logs, training_results_csv, eval_every, decay=0.5, offset=1): + num_topics = int(num_topics) + passes = int(passes) + iterations = int(iterations) + update_every = int(update_every) + + return train_lda(training_tokens, dictionary, num_topics, passes, iterations, update_every, + training_logs, training_results_csv, eval_every, decay, offset) + +def get_capec_file(): + # Send a GET request to the URL + response = requests.get("https://capec.mitre.org/data/xml/capec_latest.xml") + + # Check if the request was successful (status code 200) + if response.status_code == 200: + # Open the file in binary write mode and write the content + with open("capec_latest.xml", 'wb') as file: + file.write(response.content) + + return response + +if __name__ == "__main__": + response = get_capec_file() + + if response.status_code == 200: + tokens_dicti = CapecSpecsDict("capec_latest.xml", 'full_corpus.json') + + training_tokens = tokens_dicti.get_training_tokens() + + # Create dictionary and train/test corpus + dictionary = Dictionary(training_tokens) + dictionary.filter_extremes(no_below=2, no_above=0.3) + + train_lda_partial = partial(train_lda_wrapper, + training_tokens=training_tokens, + dictionary=dictionary, + training_logs="lda_training4.log", + training_results_csv="lda_training_results4.csv", + eval_every=10 + ) + + # train_lda_partial(num_topics=181, + # passes=20, iterations=len(tokens_dicti.get_capec_names()), update_every=1, decay=0.7, offset=10) + + optimizer = BayesianOptimization( + f=train_lda_partial, + pbounds={'num_topics': (2, 30), + 'passes': (10, 20), + 'iterations': (1, 100), + 'update_every': (1, 1), + 'decay': (0.5, 1), + 'offset': (1, 16)}, + verbose=2, + 
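        # Fixed seed so repeated optimization runs suggest the same initial points; the
        # objective being maximized is the final-epoch cross-fold coherence returned by
        # train_lda (via train_lda_wrapper).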
random_state=1, + ) + + optimizer.set_gp_params(alpha=1e-3) + + logger = JSONLogger(path="./lda_optimization4_log") + optimizer.subscribe(Events.OPTIMIZATION_STEP, logger) + + optimizer.maximize(init_points=10, n_iter=50) + print(optimizer.max) + else: + print(f"Failed to download the list of CAPECs from MITRE. Status code: {response.status_code}.") \ No newline at end of file diff --git a/src/main/java/ai/AIAttackPatternTree1.java b/src/main/java/ai/AIAttackPatternTree1.java new file mode 100644 index 0000000000000000000000000000000000000000..ec8751ebaa4aff86a1981af4092f21cf8950919b --- /dev/null +++ b/src/main/java/ai/AIAttackPatternTree1.java @@ -0,0 +1,849 @@ +package ai; + +import attacktrees.*; +import myutil.TraceManager; + +import org.json.JSONArray; +import org.json.JSONObject; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Class AIUseCaseDiagram + * <p> + * Creation: 19/03/2024 + * + * @author Alan Birchler De Allende + */ +public class AIAttackPatternTree1 extends AIInteract { + private static final String KNOWLEDGE_ON_JSON_FOR_ROOT = "When you are asked to identify the root attack of " + + "an attack pattern and how it can be used to exploit a provided system specification, " + + "return the root attack formatted as JSON like so: " + + "{\"rootattack\": {\"name\": \"NameOfRootAttack\", \"description\": \"" + + "The description of the root attack and how it can be used to exploit the " + + "system specifications.\"}} " + + "# Respect: All words in the \"name\" of the root attack must be conjoined together. " + + "# Respect: There must be no more than forty characters in the \"name\" of the root attack. " + + "# Respect: For each word in the \"name\" of the root attack, its first letter must be capitalized. " + + "# Respect: Include what the root attack is and how it can be used to exploit the system " + + "specification in the same \"description\" key." + + "# Respect: All words in the \"description\" key must be separated with spaces."; + + private static final String KNOWLEDGE_ON_JSON_FOR_ATTACKS = "When you are asked to identify the attack nodes " + + "that an attacker needs to complete to successfully achieve the root attack, " + + "return them as a JSON specification formatted as follows: " + + "{\"attacknodes\": [{\"name\": \"NameOfAttackNode\", \"description\": \"" + + "The description of the attack node and how it brings an attacker closer to the root attack.\"} ...]} " + + "# Respect: All words in the \"name\" of each attack node must be conjoined together. " + + "# Respect: There must be no more than forty characters in the \"name\" of each attack node. " + + "# Respect: For each word in each attack node's \"name\", its first letter must be capitalized. " + + "# Respect: Include what the attack node is and how it is used by an attacker for getting closer " + + "to the root attack in the same \"description\" key. " + + "# Respect: All words in the \"description\" key must be separated with spaces."; + + private static final String KNOWLEDGE_ON_JSON_FOR_ATT_CONNS = "When you are asked to identify connections between the root " + + "attack and the attack nodes, return them as a JSON specification formatted as follows: " + + "{\"attackconnections\": [{\"parentattack\": \"NameOfRootAttack or NameOfAttackNode\", " + + "\"connectiontype\": \"the connection type\", " + + "\"childrenattacks\": [\"NameOfAttackNode\" ...]} ...]} " + + "# Respect: There must be at least two attack nodes in the \"childrenattacks\" array. 
" + + "# Respect: The \"childrenattacks\" array must not contain the root attack. " + + "# Note: There are four types of connections: \"OR\", \"XOR\", \"AND\", and \"SEQUENCE\". " + + "# Note: An \"OR\" connection represents the scenario that among all of the children attacks, " + + "an attacker only needs one of the children to proceed to the parent attack. " + + "# Note: A \"XOR\" connection represents the scenario that among all of the children attacks, " + + "an attacker only needs one and only one of the children to proceed to the parent attack. " + + "# Note: An \"AND\" connection represents the scenario that among all of the children objects, " + + "an attacker needs all children simultaneously to proceed to the parent attack. " + + "# Note: A \"SEQUENCE\" connection represents the scenario that an attacker needs each of the children " + + "objects sequentially to proceed to the parent attack. The first indexed child of a \"SEQUENCE\" " + + "connection is the child that an attacker needs first while the last indexed child is the child that " + + "an attacker needs last. " + + "# Respect: The \"connectiontype\" must only be \"OR\", \"XOR\", \"AND\", or \"SEQUENCE\"."; + + private static final String KNOWLEDGE_ON_JSON_FOR_MITIGATIONS = "When you are asked to identify the mitigations " + + "that prevent an attacker from performing an attack node, " + + "return them as a JSON specification formatted as follows: " + + "{\"mitigations\": [{\"name\": \"NameOfMitigation\", \"description\": \"" + + "The description of the mitigation and how it prevents an attacker from completing " + + "its associated attack node.\"} ...]} " + + "# Respect: All words in the \"name\" of each mitigation must be conjoined together. " + + "# Respect: There must be no more than forty characters in the \"name\" of each mitigation. " + + "# Respect: For each word in each mitigation's \"name\", its first letter must be capitalized. " + + "# Respect: Include what the mitigation is and how it is used to prevent an attacker from " + + "completing its associated attack node in the same \"description\" key. " + + "# Respect: If there are no mitigations that can be applied to the provided attack nodes, return the " + + "following JSON: {\"mitigations\": \"No mitigations were able to be identified.\"} " + + "# Respect: All words in the \"description\" key must be separated with spaces. "; + + private static final String KNOWLEDGE_ON_JSON_FOR_MITI_PAIRS = "When you are asked to identify the " + + "mitigation that can be applied to an attack node to prevent " + + "an attacker from performing the attack node, " + + "return them as a JSON specification formatted as follows: " + + "{\"mitigationpairings\": [{\"mitigation\": \"NameOfMitigation\", \"attacknode\": \"NameOfAttackNode\"} ...]} " + + "# Respect: The value of \"mitigation\" should be only one of the names of the given mitigations. " + + "# Respect: The value of \"attacknode\" should be only one of the names of the given attack nodes. 
" + + "# Respect: The value of \"attacknode\" should not be the name of the root attack."; + + private static final String[] KNOWLEDGE_STAGES = { + KNOWLEDGE_ON_JSON_FOR_ROOT, + KNOWLEDGE_ON_JSON_FOR_ATTACKS, + KNOWLEDGE_ON_JSON_FOR_ATT_CONNS +// KNOWLEDGE_ON_JSON_FOR_MITIGATIONS, +// KNOWLEDGE_ON_JSON_FOR_MITI_PAIRS + }; + + private final String[] QUESTION_IDENTIFY_ATD = {"From the provided system specification and using the specified " + + "JSON format, identify a possible root attack that an attacker could potentially exploit on the provided " + + "system specification. Do respect the JSON format, and " + + "provide only JSON (no explanation before or after).\n", + + "From the provided system specification and root attack and " + + "using the specified JSON format, identify the attack nodes that an attacker needs to " + + "complete for achieving the root attack. Do respect the JSON format, and " + + "provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, root attack, and attack nodes and " + + "using the specified JSON format, identify the connections that illustrate in what order an " + + "attacker needs to complete the attack nodes to achieve the root attack. " + + "Do respect the JSON format, and provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, and attack nodes " + + "and using the specified JSON format, " + + "identify possible mitigations, if there are any, that could prevent an attacker from completing " + + "an attack node. Do respect the JSON format, and provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, attack nodes, and mitigations " + + "and using the specified JSON format, " + + "identify what mitigation can be applied to which attack node such that the mitigation prevents " + + "an attacker from performing the attack node. 
" + + "Do respect the JSON format, and provide only JSON (no explanation before or after).\n" + }; + + private String rootAttackData; + private String attackNodeData; + private String attConnections; + private String mitigations; + private String mitiPairs; + private AttackTree at; + + public AIAttackPatternTree1(AIChatData _chatData) { + super(_chatData); + + at = new AttackTree("", null); + } + + public AttackTree getATDiagram() { + return at; + } + + public void internalRequest() { + at = new AttackTree("", null); + int stage = 0; + String lastQuestion = chatData.lastQuestion.trim(); + String[] data = lastQuestion.split("\n\n"); + String systemSpec = data[0]; + String attackPattern = data[1]; + boolean apContainsMitigations = attackPattern.contains("Mitigations:"); + + String json = ""; + + String questionT = QUESTION_IDENTIFY_ATD[stage]; + + initKnowledge(); + makeKnowledge(stage, systemSpec, attackPattern); + + boolean done = false; + int cpt = 0; + + // actors, use cases and connections + while (!done && cpt < 40) { + cpt++; + boolean ok = makeQuestion(questionT); + + if (!ok) { + done = true; + TraceManager.addDev("Make question failed"); + } + + ArrayList<String> errors = null; + + try { + TraceManager.addDev("\n\nMaking specification from " + chatData.lastAnswer + "\n\n"); + json = extractJSON(); + + if (stage == 0) { + rootAttackData = ""; + + errors = new ArrayList<>(); + rootAttackData = checkRootAttack(json, errors); + TraceManager.addDev("Identified root attack - " + rootAttackData); + + if (rootAttackData.isEmpty()) { + errors.add("You must provide the root attack of " + + "the given attack pattern and how it can be used to " + + "exploit the provided system specification. Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 1) { + attackNodeData = ""; + + errors = new ArrayList<>(); + attackNodeData = checkAttackNodes(json, errors); + TraceManager.addDev("Identified attack nodes: " + attackNodeData); + + if (attackNodeData.isEmpty()) { + errors.add("You must provide the attack nodes showing how an " + + "attacker uses these nodes to achieve the root attack. " + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 2) { + attConnections = ""; + + errors = new ArrayList<>(); + attConnections = checkAttackConns(json, errors); + TraceManager.addDev("Identified attack connections: " + attConnections); + + if (attConnections.isEmpty()) { + errors.add("You must provide the connections showing in what order" + + "an attacker needs to complete the attack nodes to achieve the " + + "root attack. Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 3) { + mitigations = ""; + + errors = new ArrayList<>(); + mitigations = checkMitigations(json, errors, apContainsMitigations); + + TraceManager.addDev("Identified mitigations: " + mitigations); + + if (mitigations.isEmpty()) { + errors.add("You must provide mitigations showing how they prevent " + + "an attacker from performing an attack node. 
" + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 4) { + mitiPairs = ""; + + errors = new ArrayList<>(); + mitiPairs = checkMitiPairs(json, errors); + + TraceManager.addDev("Identified mitigation pairings: " + mitiPairs); + + if (mitiPairs.isEmpty()) { + errors.add("You must associate the provided mitigations with a provided attack node " + + "such that the mitigation prevents an attacker from performing the attack node. " + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + } catch (org.json.JSONException e) { + TraceManager.addDev("Invalid JSON spec: " + extractJSON() + " because " + e.getMessage() + ": INJECTING ERROR"); + errors = new ArrayList<>(); + errors.add("There is an error in your JSON: " + e.getMessage() + ". Probably the JSON spec was incomplete. Do correct it. I need " + + "the full specification at once."); + } + + if ((errors != null) && (!errors.isEmpty())) { + questionT = "Your answer was as follows: " + json + "\n\nYet, it was not correct because of the following errors:"; + // Updating knowledge + for (String s : errors) { + questionT += "\n- " + s; + } + + initKnowledge(); + makeKnowledge(stage, systemSpec, attackPattern); + } else { + stage++; + + if (stage == KNOWLEDGE_STAGES.length) { + done = true; + } else { + initKnowledge(); + makeKnowledge(stage, systemSpec, attackPattern); + questionT = QUESTION_IDENTIFY_ATD[stage]; + + if (stage == 1) { + questionT += "\nThe root attack data is in the following JSON:\n" + rootAttackData.trim() + "\n"; + } + else if (stage == 2) { + questionT += "\nThe root attack data is in the following JSON:\n" + rootAttackData.trim() + "\n"; + questionT += "\nThe data of the attack nodes are in the following JSON:\n" + + attackNodeData.trim() + "\n"; + } + else if (stage == 3) { + questionT += "\nThe data of the attack nodes are in the following JSON:\n" + attackNodeData.trim(); + } + else if (stage == 4) { + if (mitigations.toLowerCase().contains("no mitigations were able to be identified")) { + done = true; + } + else { + questionT += "\nThe data of the attack nodes are in the following JSON:\n" + attackNodeData.trim(); + questionT += "\nThe data of the mitigations are in the following JSON:\n" + mitigations.trim(); + } + } + } + } + + waitIfConditionTrue(!done && cpt < 20); + + cpt++; + } + + TraceManager.addDev("Reached end of AIAttackPatternTree internal request cpt=" + cpt); + } + + public Object applyAnswer(Object input) { return at; } + + private void initKnowledge() { + chatData.aiinterface.clearKnowledge(); + } + + private void makeKnowledge(int stage, String _spec, String _attackPattern) { + String [] know = KNOWLEDGE_STAGES[stage].split("#"); + for(String s: know) { + TraceManager.addDev("\nKnowledge added: " + s); + chatData.aiinterface.addKnowledge(s, "ok"); + } + + if (_spec != null) { + TraceManager.addDev("\nKnowledge added: " + _spec); + chatData.aiinterface.addKnowledge("The system specification is: " + _spec, "ok"); + } + + if (_attackPattern != null) { + TraceManager.addDev("\nKnowledge added: " + _attackPattern); + chatData.aiinterface.addKnowledge("The attack pattern is: " + _attackPattern, "ok"); + } + } + + private String checkRootAttack(String _spec, Collection<String> _errors) throws org.json.JSONException { + if (_spec == null) { + _errors.add("No \"rootattack\" object in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = 
_spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONObject rootAttackJSON = mainObject.getJSONObject("rootattack"); + + if (rootAttackJSON == null) { + TraceManager.addDev("No \"rootattack\" array in json"); + _errors.add("No \"rootattack\" array in json"); + + return ""; + } + else { + _errors.addAll(checkAttackSyntax(rootAttackJSON)); + } + + if (_errors.isEmpty()) { + Attack rootAttack = new Attack("", null); + rootAttack.setRoot(true); + rootAttack.setName(rootAttackJSON.getString("name")); + rootAttack.setDescription(rootAttackJSON.getString("description")); + at.addAttack(rootAttack); + + return _spec; + } + else { + return ""; + } + } + + private String checkAttackNodes(String _spec, Collection<String> _errors) throws org.json.JSONException { + if (_spec == null) { + _errors.add("No \"attacknodes\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONArray attackNodesJSON = mainObject.getJSONArray("attacknodes"); + + if (attackNodesJSON == null) { + TraceManager.addDev("No \"attacknodes\" array in json"); + _errors.add("No \"attacknodes\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < attackNodesJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkAttackSyntax(attackNodesJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { + for (int i = 0; i < attackNodesJSON.length(); i++) { + JSONObject attackNode = attackNodesJSON.getJSONObject(i); + + Attack attack = new Attack("", null); + attack.setName(attackNode.getString("name")); + attack.setDescription(attackNode.getString("description")); + at.addAttack(attack); + } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkAttackSyntax(JSONObject attack) { + ArrayList<String> errors = new ArrayList<>(); + + if (attack.length() != 2) { + errors.add("The attack object should only contain \"name\" and \"description\""); + + return errors; + } + + String name = attack.getString("name"); + String description = attack.getString("description"); + + if (name == null) { + errors.add("Attack has no name"); + } + else if (!name.matches("[a-zA-Z0-9]+")) { + errors.add(name + " must only contain alphanumeric characters"); + } + else if (name.length() > 40) { + errors.add(name + " must only contain forty characters max"); + } + else if (description == null) { + errors.add("Attack has no description"); + } + else if (!description.contains(" ")) { + errors.add("The words in \"description\" must be separated with spaces."); + } + + return errors; + } + + private String checkAttackConns(String _spec, Collection<String> _errors) throws org.json.JSONException { + if (_spec == null) { + _errors.add("No \"attackconnections\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = 
_spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONArray attackConnsJSON = mainObject.getJSONArray("attackconnections"); + + if (attackConnsJSON == null) { + TraceManager.addDev("No \"attackconnections\" array in json"); + _errors.add("No \"attackconnections\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < attackConnsJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkAttConnSyntax(attackConnsJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { + AttackNode connection = null; + + for (int i = 0; i < attackConnsJSON.length(); i++) { + JSONObject attackConn = attackConnsJSON.getJSONObject(i); + String parentAttackName = attackConn.getString("parentattack"); + String connectionType = attackConn.getString("connectiontype"); + JSONArray childrenAttackNames = attackConn.getJSONArray("childrenattacks"); + List<Attack> attackList = at.getAttacks(); + + switch(connectionType) { + case "OR": + connection = new ORNode("", null); + break; + case "XOR": + connection = new XORNode("", null); + break; + case "AND": + connection = new ANDNode("", null); + break; + default: + connection = new SequenceNode("", null); + break; + } + + Attack parentAttack = findAttack(parentAttackName, attackList, true); + connection.setResultingAttack(parentAttack); + int attackValue = 1; + + for (Object childAttackName : childrenAttackNames) { + String childAttackNameString = (String) childAttackName; + Attack childAttack = findAttack(childAttackNameString, attackList, false); + + connection.addInputAttack(childAttack, attackValue); + + attackValue++; + } + + at.addNode(connection); + } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkAttConnSyntax(JSONObject attackConn) { + ArrayList<String> errors = new ArrayList<>(); + + if (attackConn.length() != 3) { + errors.add("The attack connection object should only contain " + + "\"parentattack\", \"connectiontype\", and \"childrenattacks\"."); + + return errors; + } + + String parentAttack = attackConn.getString("parentattack"); + String connectionType = attackConn.getString("connectiontype"); + JSONArray childrenAttacks = attackConn.getJSONArray("childrenattacks"); + + if (parentAttack == null) { + errors.add("Connection has no parentattack"); + return errors; + } + + Attack foundParentAttack = findAttack(parentAttack, at.getAttacks(), true); + + if (foundParentAttack == null) { + errors.add(parentAttack + " is not the name of any provided root attack " + + "or attack node. Ensure that \"parentattack\" is the name of either the provided root" + + "attack or one of the attack nodes."); + } + + if (connectionType == null) { + errors.add("Connection has no connectiontype"); + } + else if (!connectionType.equals("OR") && + !connectionType.equals("XOR") && + !connectionType.equals("AND") && + !connectionType.equals("SEQUENCE")) { + errors.add("Ensure that connectiontype is one of the following types: \"OR\", \"XOR\", \"AND\", " + + "or \"SEQUENCE\""); + } + + if (childrenAttacks == null) { + errors.add("Connection has no childrenattacks"); + return errors; + } + else if (childrenAttacks.length() <= 1) { + errors.add("Connection has only one child attack. 
Please ensure that all connections have at least two " + + "children attacks."); + return errors; + } + + for (Object childAttack : childrenAttacks) { + String childAttackString = (String) childAttack; + Attack foundChildAttack = findAttack(childAttackString, at.getAttacks(), false); + + if (foundChildAttack == null) { + errors.add(childAttackString + " is not the name of any provided attack node. " + + "Ensure that all values in the \"childrenattacks\" array are the name of an attack node."); + } + } + + return errors; + } + + private Attack findAttack(String attackNodeName, List<Attack> attackNodeList, boolean checkRoot) { + Attack attack = null; + int i; + + if (!checkRoot) { + i = 1; + } + else { + i = 0; + } + + while (i < attackNodeList.size()) { + attack = attackNodeList.get(i); + + if (attackNodeName.equals(attack.getName())) { + return attack; + } + + i++; + } + + return null; + } + + private String checkMitigations(String _spec, Collection<String> _errors, boolean containsMiti) + throws org.json.JSONException { + String mitigationList = ""; + + if (_spec == null) { + _errors.add("No \"mitigations\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + Object mitigations = mainObject.get("mitigations"); + + if (mitigations instanceof String && + mitigations.equals("No mitigations were able to be identified.")) { + + if (!containsMiti) { + return (String) mitigations; + } + else { + _errors.add("You specified that there were no mitigations able to be identified. 
Yet, the " + + "provided attack pattern contains a \"Mitigations:\" section."); + + return ""; + } + } + + JSONArray mitigationsJSON = (JSONArray) mitigations; + + if (mitigationsJSON == null) { + TraceManager.addDev("No \"mitigations\" array in json"); + _errors.add("No \"mitigations\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < mitigationsJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkMitigationSyntax(mitigationsJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { +// for (int i = 0; i < mitigationsJSON.length(); i++) { +// JSONObject attackStep = mitigationsJSON.getJSONObject(i); +// +// Attack attack = new Attack("", null); +// attack.setName(attackStep.getString("name")); +// attack.setDescription(attackStep.getString("description")); +// at.addAttack(attack); +// +// mitigations += attack.toString(); +// } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkMitigationSyntax(JSONObject mitigation) { + ArrayList<String> errors = new ArrayList<>(); + + if (mitigation.length() != 2) { + errors.add("The mitigation object should only contain \"name\" and \"description\""); + + return errors; + } + + String name = mitigation.getString("name"); + String description = mitigation.getString("description"); +// String attackNodeName = mitigation.getString("attacknode"); + + if (name == null) { + errors.add("Mitigation has no name"); + } + else if (!name.matches("[a-zA-Z0-9]+")) { + errors.add(name + " must only contain alphanumeric characters"); + } + else if (name.length() > 40) { + errors.add(name + " must only contain forty characters max"); + } + else if (description == null) { + errors.add("Mitigation has no description"); + } + else if (!description.contains(" ")) { + errors.add("The words in \"description\" must be separated with spaces."); + } +// else if (attackNodeName == null) { +// errors.add("Mitigation has no associated attack node"); +// } +// else if (attackNodeName.equals(at.getAttacks().get(0).getName())) { +// errors.add(attackNodeName + " is the name of the root attack. " + +// "Do not associate mitigations with the root attack"); +// } + +// ArrayList<Attack> attackNodeList = at.getAttacks(); +// ArrayList<String> differentAttackNames = getAllDiffAttackNodeNames(attackNodeName, attackNodeList); +// if (differentAttackNames.size() == at.getAttacks().size() - 1) { +// StringBuilder builder = new StringBuilder( +// attackNodeName + " in \"attacknode\" is not the name of any of the provided attack nodes. 
" + +// "The value of \"attacknode\" can only be one of the following names: "); +// +// for (int i = 0; i < differentAttackNames.size(); i++) { +// if (i != differentAttackNames.size() - 1) { +// builder.append("\"").append(differentAttackNames.get(i)).append("\"").append(", "); +// } +// else { +// builder.append("or ").append("\"").append(differentAttackNames.get(i)).append("\""); +// } +// } + +// errors.add(builder.toString()); +// } + + return errors; + } + + private String checkMitiPairs(String _spec, Collection<String> _errors) + throws org.json.JSONException { + String mitigationList = ""; + + if (_spec == null) { + _errors.add("No \"mitigationpairings\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONArray mitigationPairingsJSON = mainObject.getJSONArray("mitigationpairings"); + + if (mitigationPairingsJSON == null) { + TraceManager.addDev("No \"mitigationpairings\" array in json"); + _errors.add("No \"mitigationpairings\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < mitigationPairingsJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkMitiPairingSyntax(mitigationPairingsJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { +// for (int i = 0; i < mitigationsJSON.length(); i++) { +// JSONObject attackStep = mitigationsJSON.getJSONObject(i); +// +// Attack attack = new Attack("", null); +// attack.setName(attackStep.getString("name")); +// attack.setDescription(attackStep.getString("description")); +// at.addAttack(attack); +// +// mitigations += attack.toString(); +// } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkMitiPairingSyntax(JSONObject mitigation) { + ArrayList<String> errors = new ArrayList<>(); + + if (mitigation.length() != 2) { + errors.add("The mitigation pairing object should only contain \"mitigation\" and \"attacknode\""); + + return errors; + } + + String mitigationName = mitigation.getString("mitigation"); + String attackNodeName = mitigation.getString("attacknode"); + + if (mitigationName == null) { + errors.add("Pairing has no mitigation"); + } + else if (attackNodeName == null) { + errors.add("Pairing has no attack node"); + } + + Attack foundAttackNode = findAttack(attackNodeName, at.getAttacks(), false); + + if (foundAttackNode == null) { + errors.add(attackNodeName + " is not the name of any provided attack node. 
" + + "Ensure that the value of \"attacknode\" is the name of an attack node."); + } + + return errors; + } +} diff --git a/src/main/java/ai/AIAttackPatternTree2.java b/src/main/java/ai/AIAttackPatternTree2.java new file mode 100644 index 0000000000000000000000000000000000000000..ec841602f241c1dc19bdb3102ad8f98c469e9cb1 --- /dev/null +++ b/src/main/java/ai/AIAttackPatternTree2.java @@ -0,0 +1,850 @@ +package ai; + +import attacktrees.*; +import myutil.TraceManager; +import org.json.JSONArray; +import org.json.JSONObject; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Class AIUseCaseDiagram + * <p> + * Creation: 19/03/2024 + * + * @author Alan Birchler De Allende + */ +public class AIAttackPatternTree2 extends AIInteract { + private static final String KNOWLEDGE_ON_JSON_FOR_ROOT = "When you are asked to identify the root attack of " + + "an attack pattern and how it can be used to exploit a provided system specification, " + + "return the root attack formatted as JSON like so: " + + "{\"rootattack\": {\"name\": \"NameOfRootAttack\", \"description\": \"" + + "The description of the root attack and how it can be used to exploit the " + + "system specifications.\"}} " + + "# Respect: All words in the \"name\" of the root attack must be conjoined together. " + + "# Respect: There must be no more than forty characters in the \"name\" of the root attack. " + + "# Respect: For each word in the \"name\" of the root attack, its first letter must be capitalized. " + + "# Respect: Include what the root attack is and how it can be used to exploit the system " + + "specification in the same \"description\" key." + + "# Respect: All words in the \"description\" key must be separated with spaces."; + + private static final String KNOWLEDGE_ON_JSON_FOR_ATTACKS = "When you are asked to identify the attack nodes " + + "that an attacker needs to complete to successfully achieve the root attack, " + + "return them as a JSON specification formatted as follows: " + + "{\"attacknodes\": [{\"name\": \"NameOfAttackNode\", \"description\": \"" + + "The description of the attack node and how it brings an attacker closer to the root attack.\"} ...]} " + + "# Respect: All words in the \"name\" of each attack node must be conjoined together. " + + "# Respect: There must be no more than forty characters in the \"name\" of each attack node. " + + "# Respect: For each word in each attack node's \"name\", its first letter must be capitalized. " + + "# Respect: Include what the attack node is and how it is used by an attacker for getting closer " + + "to the root attack in the same \"description\" key. " + + "# Respect: All words in the \"description\" key must be separated with spaces."; + + private static final String KNOWLEDGE_ON_JSON_FOR_ATT_CONNS = "When you are asked to identify connections between the root " + + "attack and the attack nodes, return them as a JSON specification formatted as follows: " + + "{\"attackconnections\": [{\"parentattack\": \"NameOfRootAttack or NameOfAttackNode\", " + + "\"connectiontype\": \"the connection type\", " + + "\"childrenattacks\": [\"NameOfAttackNode\" ...]} ...]} " + + "# Respect: There must be at least two attack nodes in the \"childrenattacks\" array. " + + "# Respect: The \"childrenattacks\" array must not contain the root attack. " + + "# Note: There are four types of connections: \"OR\", \"XOR\", \"AND\", and \"SEQUENCE\". 
" + + "# Note: An \"OR\" connection represents the scenario that among all of the children attacks, " + + "an attacker only needs one of the children to proceed to the parent attack. " + + "# Note: A \"XOR\" connection represents the scenario that among all of the children attacks, " + + "an attacker only needs one and only one of the children to proceed to the parent attack. " + + "# Note: An \"AND\" connection represents the scenario that among all of the children objects, " + + "an attacker needs all children simultaneously to proceed to the parent attack. " + + "# Note: A \"SEQUENCE\" connection represents the scenario that an attacker needs each of the children " + + "objects sequentially to proceed to the parent attack. The first indexed child of a \"SEQUENCE\" " + + "connection is the child that an attacker needs first while the last indexed child is the child that " + + "an attacker needs last. " + + "# Respect: The \"connectiontype\" must only be \"OR\", \"XOR\", \"AND\", or \"SEQUENCE\"."; + + private static final String KNOWLEDGE_ON_JSON_FOR_MITIGATIONS = "When you are asked to identify the mitigations " + + "that prevent an attacker from performing an attack node, " + + "return them as a JSON specification formatted as follows: " + + "{\"mitigations\": [{\"name\": \"NameOfMitigation\", \"description\": \"" + + "The description of the mitigation and how it prevents an attacker from completing " + + "its associated attack node.\"} ...]} " + + "# Respect: All words in the \"name\" of each mitigation must be conjoined together. " + + "# Respect: There must be no more than forty characters in the \"name\" of each mitigation. " + + "# Respect: For each word in each mitigation's \"name\", its first letter must be capitalized. " + + "# Respect: Include what the mitigation is and how it is used to prevent an attacker from " + + "completing its associated attack node in the same \"description\" key. " + + "# Respect: If there are no mitigations that can be applied to the provided attack nodes, return the " + + "following JSON: {\"mitigations\": \"No mitigations were able to be identified.\"} " + + "# Respect: All words in the \"description\" key must be separated with spaces. "; + + private static final String KNOWLEDGE_ON_JSON_FOR_MITI_PAIRS = "When you are asked to identify the " + + "mitigation that can be applied to an attack node to prevent " + + "an attacker from performing the attack node, " + + "return them as a JSON specification formatted as follows: " + + "{\"mitigationpairings\": [{\"mitigation\": \"NameOfMitigation\", \"attacknode\": \"NameOfAttackNode\"} ...]} " + + "# Respect: The value of \"mitigation\" should be only one of the names of the given mitigations. " + + "# Respect: The value of \"attacknode\" should be only one of the names of the given attack nodes. " + + "# Respect: The value of \"attacknode\" should not be the name of the root attack."; + + private static final String[] KNOWLEDGE_STAGES = { + KNOWLEDGE_ON_JSON_FOR_ROOT, + KNOWLEDGE_ON_JSON_FOR_ATTACKS, + KNOWLEDGE_ON_JSON_FOR_ATT_CONNS +// KNOWLEDGE_ON_JSON_FOR_MITIGATIONS, +// KNOWLEDGE_ON_JSON_FOR_MITI_PAIRS + }; + + private final String[] QUESTION_IDENTIFY_ATD = {"From the provided system specification and attack pattern " + + "and using the specified JSON format, identify what " + + "the root attack is of the attack pattern and how it can be used to exploit the system specification. 
" + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, and root attack and " + + "using the specified JSON format, identify the attack nodes that an attacker needs to " + + "complete for achieving the root attack. Do respect the JSON format, and " + + "provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, root attack, and attack nodes and " + + "using the specified JSON format, identify the connections that illustrate in what order an " + + "attacker needs to complete the attack nodes to achieve the root attack. " + + "Do respect the JSON format, and provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, and attack nodes " + + "and using the specified JSON format, " + + "identify possible mitigations, if there are any, that could prevent an attacker from completing " + + "an attack node. Do respect the JSON format, and provide only JSON (no explanation before or after).\n", + + "From the provided system specification, attack pattern, attack nodes, and mitigations " + + "and using the specified JSON format, " + + "identify what mitigation can be applied to which attack node such that the mitigation prevents " + + "an attacker from performing the attack node. " + + "Do respect the JSON format, and provide only JSON (no explanation before or after).\n" + }; + + private String rootAttackData; + private String attackNodeData; + private String attConnections; + private String mitigations; + private String mitiPairs; + private AttackTree at; + + public AIAttackPatternTree2(AIChatData _chatData) { + super(_chatData); + + at = new AttackTree("", null); + } + + public AttackTree getATDiagram() { + return at; + } + + public void internalRequest() { + at = new AttackTree("", null); + int stage = 0; + String lastQuestion = chatData.lastQuestion.trim(); + String[] data = lastQuestion.split("\n\n"); + String systemSpec = data[0]; + String attackPattern = data[1]; + boolean apContainsMitigations = attackPattern.contains("Mitigations:"); + + String json = ""; + + String questionT = QUESTION_IDENTIFY_ATD[stage]; + + initKnowledge(); + makeKnowledge(stage, systemSpec, attackPattern); + + boolean done = false; + int cpt = 0; + + // actors, use cases and connections + while (!done && cpt < 40) { + cpt++; + boolean ok = makeQuestion(questionT); + + if (!ok) { + done = true; + TraceManager.addDev("Make question failed"); + } + + ArrayList<String> errors = null; + + try { + TraceManager.addDev("\n\nMaking specification from " + chatData.lastAnswer + "\n\n"); + json = extractJSON(); + + if (stage == 0) { + rootAttackData = ""; + + errors = new ArrayList<>(); + rootAttackData = checkRootAttack(json, errors); + TraceManager.addDev("Identified root attack - " + rootAttackData); + + if (rootAttackData.isEmpty()) { + errors.add("You must provide the root attack of " + + "the given attack pattern and how it can be used to " + + "exploit the provided system specification. 
Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 1) { + attackNodeData = ""; + + errors = new ArrayList<>(); + attackNodeData = checkAttackNodes(json, errors); + TraceManager.addDev("Identified attack nodes: " + attackNodeData); + + if (attackNodeData.isEmpty()) { + errors.add("You must provide the attack nodes showing how an " + + "attacker uses these nodes to achieve the root attack. " + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 2) { + attConnections = ""; + + errors = new ArrayList<>(); + attConnections = checkAttackConns(json, errors); + TraceManager.addDev("Identified attack connections: " + attConnections); + + if (attConnections.isEmpty()) { + errors.add("You must provide the connections showing in what order" + + "an attacker needs to complete the attack nodes to achieve the " + + "root attack. Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 3) { + mitigations = ""; + + errors = new ArrayList<>(); + mitigations = checkMitigations(json, errors, apContainsMitigations); + + TraceManager.addDev("Identified mitigations: " + mitigations); + + if (mitigations.isEmpty()) { + errors.add("You must provide mitigations showing how they prevent " + + "an attacker from performing an attack node. " + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + else if (stage == 4) { + mitiPairs = ""; + + errors = new ArrayList<>(); + mitiPairs = checkMitiPairs(json, errors); + + TraceManager.addDev("Identified mitigation pairings: " + mitiPairs); + + if (mitiPairs.isEmpty()) { + errors.add("You must associate the provided mitigations with a provided attack node " + + "such that the mitigation prevents an attacker from performing the attack node. " + + "Do respect the JSON format, and " + + "provide only JSON (no explanation before or after)."); + } + } + } catch (org.json.JSONException e) { + TraceManager.addDev("Invalid JSON spec: " + extractJSON() + " because " + e.getMessage() + ": INJECTING ERROR"); + errors = new ArrayList<>(); + errors.add("There is an error in your JSON: " + e.getMessage() + ". Probably the JSON spec was incomplete. Do correct it. 
I need " + + "the full specification at once."); + } + + if ((errors != null) && (!errors.isEmpty())) { + questionT = "Your answer was as follows: " + json + "\n\nYet, it was not correct because of the following errors:"; + // Updating knowledge + for (String s : errors) { + questionT += "\n- " + s; + } + + initKnowledge(); + makeKnowledge(stage, systemSpec, attackPattern); + } else { + stage++; + + if (stage == KNOWLEDGE_STAGES.length) { + done = true; + } else { + initKnowledge(); + makeKnowledge(stage, systemSpec, attackPattern); + questionT = QUESTION_IDENTIFY_ATD[stage]; + + if (stage == 1) { + questionT += "\nThe root attack data is in the following JSON:\n" + rootAttackData.trim() + "\n"; + } + else if (stage == 2) { + questionT += "\nThe root attack data is in the following JSON:\n" + rootAttackData.trim() + "\n"; + questionT += "\nThe data of the attack nodes are in the following JSON:\n" + + attackNodeData.trim() + "\n"; + } + else if (stage == 3) { + questionT += "\nThe data of the attack nodes are in the following JSON:\n" + attackNodeData.trim(); + } + else if (stage == 4) { + if (mitigations.toLowerCase().contains("no mitigations were able to be identified")) { + done = true; + } + else { + questionT += "\nThe data of the attack nodes are in the following JSON:\n" + attackNodeData.trim(); + questionT += "\nThe data of the mitigations are in the following JSON:\n" + mitigations.trim(); + } + } + } + } + + waitIfConditionTrue(!done && cpt < 20); + + cpt++; + } + + TraceManager.addDev("Reached end of AIAttackPatternTree internal request cpt=" + cpt); + } + + public Object applyAnswer(Object input) { return at; } + + private void initKnowledge() { + chatData.aiinterface.clearKnowledge(); + } + + private void makeKnowledge(int stage, String _spec, String _attackPattern) { + String [] know = KNOWLEDGE_STAGES[stage].split("#"); + for(String s: know) { + TraceManager.addDev("\nKnowledge added: " + s); + chatData.aiinterface.addKnowledge(s, "ok"); + } + + if (_spec != null) { + TraceManager.addDev("\nKnowledge added: " + _spec); + chatData.aiinterface.addKnowledge("The system specification is: " + _spec, "ok"); + } + + if (_attackPattern != null) { + TraceManager.addDev("\nKnowledge added: " + _attackPattern); + chatData.aiinterface.addKnowledge("The attack pattern is: " + _attackPattern, "ok"); + } + } + + private String checkRootAttack(String _spec, Collection<String> _errors) throws org.json.JSONException { + if (_spec == null) { + _errors.add("No \"rootattack\" object in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONObject rootAttackJSON = mainObject.getJSONObject("rootattack"); + + if (rootAttackJSON == null) { + TraceManager.addDev("No \"rootattack\" array in json"); + _errors.add("No \"rootattack\" array in json"); + + return ""; + } + else { + _errors.addAll(checkAttackSyntax(rootAttackJSON)); + } + + if (_errors.isEmpty()) { + Attack rootAttack = new Attack("", null); + rootAttack.setRoot(true); + rootAttack.setName(rootAttackJSON.getString("name")); + rootAttack.setDescription(rootAttackJSON.getString("description")); + at.addAttack(rootAttack); + + return _spec; + } + else { + return ""; + } + } + + private String checkAttackNodes(String _spec, 
Collection<String> _errors) throws org.json.JSONException { + if (_spec == null) { + _errors.add("No \"attacknodes\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONArray attackNodesJSON = mainObject.getJSONArray("attacknodes"); + + if (attackNodesJSON == null) { + TraceManager.addDev("No \"attacknodes\" array in json"); + _errors.add("No \"attacknodes\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < attackNodesJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkAttackSyntax(attackNodesJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { + for (int i = 0; i < attackNodesJSON.length(); i++) { + JSONObject attackNode = attackNodesJSON.getJSONObject(i); + + Attack attack = new Attack("", null); + attack.setName(attackNode.getString("name")); + attack.setDescription(attackNode.getString("description")); + at.addAttack(attack); + } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkAttackSyntax(JSONObject attack) { + ArrayList<String> errors = new ArrayList<>(); + + if (attack.length() != 2) { + errors.add("The attack object should only contain \"name\" and \"description\""); + + return errors; + } + + String name = attack.getString("name"); + String description = attack.getString("description"); + + if (name == null) { + errors.add("Attack has no name"); + } + else if (!name.matches("[a-zA-Z0-9]+")) { + errors.add(name + " must only contain alphanumeric characters"); + } + else if (name.length() > 40) { + errors.add(name + " must only contain forty characters max"); + } + else if (description == null) { + errors.add("Attack has no description"); + } + else if (!description.contains(" ")) { + errors.add("The words in \"description\" must be separated with spaces."); + } + + return errors; + } + + private String checkAttackConns(String _spec, Collection<String> _errors) throws org.json.JSONException { + if (_spec == null) { + _errors.add("No \"attackconnections\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONArray attackConnsJSON = mainObject.getJSONArray("attackconnections"); + + if (attackConnsJSON == null) { + TraceManager.addDev("No \"attackconnections\" array in json"); + _errors.add("No \"attackconnections\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < attackConnsJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkAttConnSyntax(attackConnsJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { + AttackNode connection = null; + + for (int i = 0; i < attackConnsJSON.length(); i++) { + JSONObject attackConn = attackConnsJSON.getJSONObject(i); + String parentAttackName = attackConn.getString("parentattack"); + String connectionType = attackConn.getString("connectiontype"); + JSONArray childrenAttackNames = attackConn.getJSONArray("childrenattacks"); + List<Attack> attackList = 
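/*
 * Illustrative example (hypothetical attack names) of one "attackconnections" entry and how the
 * code below maps it, assuming the format requested by KNOWLEDGE_ON_JSON_FOR_ATT_CONNS:
 *
 *   {"parentattack": "GainSystemAccess",
 *    "connectiontype": "SEQUENCE",
 *    "childrenattacks": ["ScanForOpenPorts", "ExploitWeakCredentials"]}
 *
 * "OR", "XOR" and "AND" map to ORNode, XORNode and ANDNode; any other value falls through to
 * SequenceNode, which is safe because checkAttConnSyntax only lets "SEQUENCE" reach this point.
 * Children are attached with addInputAttack(child, attackValue), where attackValue counts up
 * from 1, so the order of "childrenattacks" in the JSON is preserved as the sequence order.
 */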
at.getAttacks(); + + switch(connectionType) { + case "OR": + connection = new ORNode("", null); + break; + case "XOR": + connection = new XORNode("", null); + break; + case "AND": + connection = new ANDNode("", null); + break; + default: + connection = new SequenceNode("", null); + break; + } + + Attack parentAttack = findAttack(parentAttackName, attackList, true); + connection.setResultingAttack(parentAttack); + int attackValue = 1; + + for (Object childAttackName : childrenAttackNames) { + String childAttackNameString = (String) childAttackName; + Attack childAttack = findAttack(childAttackNameString, attackList, false); + + connection.addInputAttack(childAttack, attackValue); + + attackValue++; + } + + at.addNode(connection); + } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkAttConnSyntax(JSONObject attackConn) { + ArrayList<String> errors = new ArrayList<>(); + + if (attackConn.length() != 3) { + errors.add("The attack connection object should only contain " + + "\"parentattack\", \"connectiontype\", and \"childrenattacks\"."); + + return errors; + } + + String parentAttack = attackConn.getString("parentattack"); + String connectionType = attackConn.getString("connectiontype"); + JSONArray childrenAttacks = attackConn.getJSONArray("childrenattacks"); + + if (parentAttack == null) { + errors.add("Connection has no parentattack"); + return errors; + } + + Attack foundParentAttack = findAttack(parentAttack, at.getAttacks(), true); + + if (foundParentAttack == null) { + errors.add(parentAttack + " is not the name of any provided root attack " + + "or attack node. Ensure that \"parentattack\" is the name of either the provided root" + + "attack or one of the attack nodes."); + } + + if (connectionType == null) { + errors.add("Connection has no connectiontype"); + } + else if (!connectionType.equals("OR") && + !connectionType.equals("XOR") && + !connectionType.equals("AND") && + !connectionType.equals("SEQUENCE")) { + errors.add("Ensure that connectiontype is one of the following types: \"OR\", \"XOR\", \"AND\", " + + "or \"SEQUENCE\""); + } + + if (childrenAttacks == null) { + errors.add("Connection has no childrenattacks"); + return errors; + } + else if (childrenAttacks.length() <= 1) { + errors.add("Connection has only one child attack. Please ensure that all connections have at least two " + + "children attacks."); + return errors; + } + + for (Object childAttack : childrenAttacks) { + String childAttackString = (String) childAttack; + Attack foundChildAttack = findAttack(childAttackString, at.getAttacks(), false); + + if (foundChildAttack == null) { + errors.add(childAttackString + " is not the name of any provided attack node. 
" + + "Ensure that all values in the \"childrenattacks\" array are the name of an attack node."); + } + } + + return errors; + } + + private Attack findAttack(String attackNodeName, List<Attack> attackNodeList, boolean checkRoot) { + Attack attack = null; + int i; + + if (!checkRoot) { + i = 1; + } + else { + i = 0; + } + + while (i < attackNodeList.size()) { + attack = attackNodeList.get(i); + + if (attackNodeName.equals(attack.getName())) { + return attack; + } + + i++; + } + + return null; + } + + private String checkMitigations(String _spec, Collection<String> _errors, boolean containsMiti) + throws org.json.JSONException { + String mitigationList = ""; + + if (_spec == null) { + _errors.add("No \"mitigations\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + Object mitigations = mainObject.get("mitigations"); + + if (mitigations instanceof String && + mitigations.equals("No mitigations were able to be identified.")) { + + if (!containsMiti) { + return (String) mitigations; + } + else { + _errors.add("You specified that there were no mitigations able to be identified. Yet, the " + + "provided attack pattern contains a \"Mitigations:\" section."); + + return ""; + } + } + + JSONArray mitigationsJSON = (JSONArray) mitigations; + + if (mitigationsJSON == null) { + TraceManager.addDev("No \"mitigations\" array in json"); + _errors.add("No \"mitigations\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < mitigationsJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkMitigationSyntax(mitigationsJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { +// for (int i = 0; i < mitigationsJSON.length(); i++) { +// JSONObject attackStep = mitigationsJSON.getJSONObject(i); +// +// Attack attack = new Attack("", null); +// attack.setName(attackStep.getString("name")); +// attack.setDescription(attackStep.getString("description")); +// at.addAttack(attack); +// +// mitigations += attack.toString(); +// } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkMitigationSyntax(JSONObject mitigation) { + ArrayList<String> errors = new ArrayList<>(); + + if (mitigation.length() != 2) { + errors.add("The mitigation object should only contain \"name\" and \"description\""); + + return errors; + } + + String name = mitigation.getString("name"); + String description = mitigation.getString("description"); +// String attackNodeName = mitigation.getString("attacknode"); + + if (name == null) { + errors.add("Mitigation has no name"); + } + else if (!name.matches("[a-zA-Z0-9]+")) { + errors.add(name + " must only contain alphanumeric characters"); + } + else if (name.length() > 40) { + errors.add(name + " must only contain forty characters max"); + } + else if (description == null) { + errors.add("Mitigation has no description"); + } + else if (!description.contains(" ")) { + errors.add("The words in \"description\" must be separated with spaces."); + } +// else if (attackNodeName == null) { +// errors.add("Mitigation has no associated attack node"); +// } +// else if (attackNodeName.equals(at.getAttacks().get(0).getName())) { +// errors.add(attackNodeName + " is the name of the root attack. 
" + +// "Do not associate mitigations with the root attack"); +// } + +// ArrayList<Attack> attackNodeList = at.getAttacks(); +// ArrayList<String> differentAttackNames = getAllDiffAttackNodeNames(attackNodeName, attackNodeList); +// if (differentAttackNames.size() == at.getAttacks().size() - 1) { +// StringBuilder builder = new StringBuilder( +// attackNodeName + " in \"attacknode\" is not the name of any of the provided attack nodes. " + +// "The value of \"attacknode\" can only be one of the following names: "); +// +// for (int i = 0; i < differentAttackNames.size(); i++) { +// if (i != differentAttackNames.size() - 1) { +// builder.append("\"").append(differentAttackNames.get(i)).append("\"").append(", "); +// } +// else { +// builder.append("or ").append("\"").append(differentAttackNames.get(i)).append("\""); +// } +// } + +// errors.add(builder.toString()); +// } + + return errors; + } + + private String checkMitiPairs(String _spec, Collection<String> _errors) + throws org.json.JSONException { + String mitigationList = ""; + + if (_spec == null) { + _errors.add("No \"mitigationpairings\" array in json"); + + return ""; + } + + int indexStart = _spec.indexOf('{'); + int indexStop = _spec.lastIndexOf('}'); + + if ((indexStart == -1) || (indexStop == -1) || (indexStart > indexStop)) { + _errors.add("Invalid JSON object (start or stop)"); + + return ""; + } + + String json = _spec.substring(indexStart, indexStop + 1); + + JSONObject mainObject = new JSONObject(json); + JSONArray mitigationPairingsJSON = mainObject.getJSONArray("mitigationpairings"); + + if (mitigationPairingsJSON == null) { + TraceManager.addDev("No \"mitigationpairings\" array in json"); + _errors.add("No \"mitigationpairings\" array in json"); + + return ""; + } + else { + int i = 0; + + while (i < mitigationPairingsJSON.length() && _errors.isEmpty()) + { + _errors.addAll(checkMitiPairingSyntax(mitigationPairingsJSON.getJSONObject(i))); + i++; + } + } + + if (_errors.isEmpty()) { +// for (int i = 0; i < mitigationsJSON.length(); i++) { +// JSONObject attackStep = mitigationsJSON.getJSONObject(i); +// +// Attack attack = new Attack("", null); +// attack.setName(attackStep.getString("name")); +// attack.setDescription(attackStep.getString("description")); +// at.addAttack(attack); +// +// mitigations += attack.toString(); +// } + + return _spec; + } + else { + return ""; + } + } + + private ArrayList<String> checkMitiPairingSyntax(JSONObject mitigation) { + ArrayList<String> errors = new ArrayList<>(); + + if (mitigation.length() != 2) { + errors.add("The mitigation pairing object should only contain \"mitigation\" and \"attacknode\""); + + return errors; + } + + String mitigationName = mitigation.getString("mitigation"); + String attackNodeName = mitigation.getString("attacknode"); + + if (mitigationName == null) { + errors.add("Pairing has no mitigation"); + } + else if (attackNodeName == null) { + errors.add("Pairing has no attack node"); + } + + Attack foundAttackNode = findAttack(attackNodeName, at.getAttacks(), false); + + if (foundAttackNode == null) { + errors.add(attackNodeName + " is not the name of any provided attack node. 
" + + "Ensure that the value of \"attacknode\" is the name of an attack node."); + } + + return errors; + } +} diff --git a/src/main/java/ai/CAPECTracer.java b/src/main/java/ai/CAPECTracer.java index 28f547f7cda71408ff1291b7512369bc9172deff..95753adf877e9672d72d835ac0de662f18e0bb77 100644 --- a/src/main/java/ai/CAPECTracer.java +++ b/src/main/java/ai/CAPECTracer.java @@ -6,7 +6,6 @@ import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.Objects; /** * Class AIUseCaseDiagram @@ -26,15 +25,12 @@ public class CAPECTracer extends AIInteract { @Override public void internalRequest() { chatData.feedback.addToChat(QUESTIONTRACECAPECS, true); - String systemSpec = chatData.lastQuestion.trim(); - Path capecTracerFolder = Paths.get(Objects.requireNonNull( - this.getClass().getResource("../capec_tracer")).getPath()); - - writeToFile(capecTracerFolder + "/system_specs.txt", systemSpec); + Path projectPath = Paths.get("../capectracer").normalize().toAbsolutePath(); + writeToFile(projectPath + "/system_specs.txt", systemSpec); - String results = runCapecTracer(capecTracerFolder.toString()); + String results = runCapecTracer(projectPath.toString()); chatData.feedback.addToChat(results, false); } diff --git a/src/main/java/attacktrees/Attack.java b/src/main/java/attacktrees/Attack.java index cd319717423203b4e7217d5c6518a434298b161b..5898e0b8082ad33ff922cb6cfd090d7e82331c0c 100644 --- a/src/main/java/attacktrees/Attack.java +++ b/src/main/java/attacktrees/Attack.java @@ -39,7 +39,6 @@ package attacktrees; -import java.awt.*; import java.util.ArrayList; @@ -67,6 +66,7 @@ public class Attack extends AttackElement { private int attackCost; private int attackExperience; + private String description; public Attack(String _name, Object _referenceObject) { super(_name, _referenceObject); @@ -113,6 +113,9 @@ public class Attack extends AttackElement { return attackExperience; } + public String getDescription() { return description; } + + public void setDescription(String description) { this.description = description; } public boolean isLeaf() { return (originNode == null); @@ -153,4 +156,14 @@ public class Attack extends AttackElement { } + public String toString(boolean showIfRoot) { + String data = "Attack name: " + super.getName() + " Description: " + this.description; + + if (showIfRoot) { + data += " Is root: " + isRoot; + } + + data += "\n"; + return data; + } } diff --git a/src/main/java/ui/window/JFrameAI.java b/src/main/java/ui/window/JFrameAI.java index 1955d1b06503796a1291ae697edec251b83d8875..4570dcf21d4e211ecf7e82eb5efa2ce687ef26a0 100644 --- a/src/main/java/ui/window/JFrameAI.java +++ b/src/main/java/ui/window/JFrameAI.java @@ -89,12 +89,18 @@ public class JFrameAI extends JFrame implements ActionListener { "DEPRECATED - Identify state machines - Select a block diagram. Additionally, you can provide a system specification", "Identify state machines and attributes - Select a block diagram. 
Additionally, you can provide a system specification", "A(I)MULET - Select a block diagram first", - "Capec tracer - Identify the possible attack patterns that an attacker could use to exploit your system specifications."}; + "Capec tracer - Identify the possible attack patterns that an attacker could use to exploit your system specifications.", + "Attack Tree Generator, Pipeline 1 - Creates an Attack Tree diagram that models a given attack pattern using " + + "a provided system specification.", + "Attack Tree Generator, Pipeline 2 - Creates an Attack Tree diagram that models a given attack pattern using " + + "a provided system specification and an attack pattern." + }; private static String[] AIInteractClass = {"AIChat", "AIReqIdent", "AIReqClassification", "AIUseCaseDiagram", "AIDesignPropertyIdentification", "AIBlock", "AIBlockConnAttrib", "AIBlockConnAttribWithSlicing", "AISoftwareBlock", "AIStateMachine", "AIStateMachinesAndAttributes", "AIAmulet", - "CAPECTracer"}; + "CAPECTracer", "AIAttackPatternTree1", "AIAttackPatternTree2" + }; private static String[] INFOS = {"Chat on any topic you like", "Identify requirements from the specification of a system", "Classify " + "requirements from a requirement diagram", "Identify use cases and actors from a system specification", @@ -114,7 +120,12 @@ public class JFrameAI extends JFrame implements ActionListener { "Formalize mutations to be performed on a block diagram", "Identify the possible attack patterns that an attacker could use to exploit your system specifications. " + "Each identified attack pattern will have a confidence score of TTool's estimation on how related an attack pattern is " + - "to the provided system specifications."}; + "to the provided system specifications.", + "Using a provided system specification, create an Attack Tree diagram that models the steps that an attacker " + + "would need to take to exploit an identified root attack on the given system specification.", + "Using a provided system specification and an attack pattern, create an Attack Tree diagram that models the steps " + + "that an attacker would need to take to exploit the given attack pattern on the given system specifications." + }; protected JComboBox<String> listOfPossibleActions; protected JComboBox<String> listOfPossibleModels;