#include #include #include #include #include "types.hpp" #include "aes.hpp" const int no_traces = 10'000; using namespace std; char read_char(istream &stream) { char buf; stream.get(buf); return buf; } traces_vector_t load_traces() { traces_vector_t traces(no_traces, trace()); //ifstream plaintexts("plaintexts.dat", ios::binary); ifstream ciphertexts("ciphertexts.dat", ios::binary); ifstream faultytexts("faultytexts.dat", ios::binary); for (int trace_no = 0;trace_no < no_traces;trace_no++) { trace ¤t_trace = traces[trace_no]; for (int i = 0;i < 16;i++) { //current_trace.plaintext[i] = read_char(plaintexts); current_trace.ciphertext[i] = read_char(ciphertexts); current_trace.faultytext[i] = read_char(faultytexts); } } if (!ciphertexts.good() || !ciphertexts.good()) { cout << "IO ERROR"; exit(1); } return traces; } int main(int argc, char **argv) { traces_vector_t traces = load_traces(); aes_col key; set candidates; for (int i = 0;i < no_traces;i++) { trace current_master = traces[i]; AES::r_shift_rows(current_master.ciphertext); AES::r_shift_rows(current_master.faultytext); cout << i; cout.flush(); set round_candidates; for (int k0 = 0;k0 < 256;k0++) { key.column[0] = k0; for (int k1 = 0;k1 < 256;k1++) { key.column[1] = k1; for (int k2 = 0;k2 < 256;k2++) { key.column[2] = k2; for (int k3 = 0;k3 < 256;k3++) { key.column[3] = k3; if (i > 0 && round_candidates.find(key) == round_candidates.end()) continue; trace current = current_master; aes_col diff; for (int j = 0;j < 4;j++) { current.ciphertext[j] = AES::r_sbox[current.ciphertext[j] ^ key.column[j]]; current.faultytext[j] = AES::r_sbox[current.faultytext[j] ^ key.column[j]]; diff.column[j] = current.ciphertext[j] ^ current.faultytext[j]; } AES::r_mix_column(diff.column); int zero_bytes = 0; for (int j = 0;j < 4;j++) { if (diff.column[i] == 0) zero_bytes++; } if (zero_bytes >= 3) { round_candidates.insert(key); } } } } } if (i == 0) { candidates = round_candidates; } else { set intersection; set_intersection(candidates.begin(), candidates.end(), round_candidates.begin(), round_candidates.end(), std::inserter(intersection, intersection.begin())); candidates = intersection; } if (candidates.size() == 1) { cout << endl; break; } else if (candidates.size() == 0) { cout << " -> " << candidates.size() << endl; cout << "Error!" << endl; break; } else { cout << " -> " << candidates.size() << endl; } } cout << "Key: "; key = *candidates.begin(); for (int i = 0; i < 4;i++) { cout << key.column[i] << " "; } cout << endl; return 0; }