Files
phyatk10/main.cpp

132 lines
2.8 KiB
C++

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <set>
#include "types.hpp"
#include "aes.hpp"
// Number of correct/faulty ciphertext pairs read from the .dat input files.
constexpr int no_traces{10'000};
using namespace std;
// Reads a single byte from `stream`.
// If no byte can be read (end of data, or the stream is already in a failed
// state) the stream's failbit is set and '\0' is returned, so callers must
// check the stream state to distinguish a real zero byte from a failure.
char read_char(std::istream &stream)
{
    // istream::get(char&) leaves its argument unmodified when it fails, so
    // initialize it to avoid returning an indeterminate value.
    char buf = '\0';
    stream.get(buf);
    return buf;
}
// Loads no_traces (ciphertext, faulty ciphertext) pairs.
// Each trace is 16 bytes read sequentially from "ciphertexts.dat" and 16 bytes
// from "faultytexts.dat", both opened in binary mode; traces are stored back to
// back. If either file cannot be opened or holds fewer than no_traces * 16
// bytes, "IO ERROR" is reported and the process exits with status 1.
traces_vector_t load_traces()
{
    traces_vector_t traces(no_traces, trace());
    //ifstream plaintexts("plaintexts.dat", ios::binary);
    ifstream ciphertexts("ciphertexts.dat", ios::binary);
    ifstream faultytexts("faultytexts.dat", ios::binary);
    for (int trace_no = 0; trace_no < no_traces; trace_no++)
    {
        trace &current_trace = traces[trace_no];
        for (int i = 0; i < 16; i++)
        {
            //current_trace.plaintext[i] = read_char(plaintexts);
            current_trace.ciphertext[i] = read_char(ciphertexts);
            current_trace.faultytext[i] = read_char(faultytexts);
        }
    }
    // Both streams must be checked. operator! reports failbit/badbit, which is
    // set both when a file failed to open and when a read ran past the end of
    // a file that is too short.
    if (!ciphertexts || !faultytexts)
    {
        cerr << "IO ERROR" << endl;
        exit(1);
    }
    return traces;
}
int main(int argc, char **argv)
{
traces_vector_t traces = load_traces();
aes_col key;
set<aes_col> candidates;
for (int i = 0;i < no_traces;i++)
{
trace current_master = traces[i];
AES::r_shift_rows(current_master.ciphertext);
AES::r_shift_rows(current_master.faultytext);
cout << i;
cout.flush();
set<aes_col> round_candidates;
for (int k0 = 0;k0 < 256;k0++)
{
key.column[0] = k0;
for (int k1 = 0;k1 < 256;k1++)
{
key.column[1] = k1;
for (int k2 = 0;k2 < 256;k2++)
{
key.column[2] = k2;
for (int k3 = 0;k3 < 256;k3++)
{
key.column[3] = k3;
if (i > 0 && round_candidates.find(key) == round_candidates.end())
continue;
trace current = current_master;
aes_col diff;
for (int j = 0;j < 4;j++)
{
current.ciphertext[j] = AES::r_sbox[current.ciphertext[j] ^ key.column[j]];
current.faultytext[j] = AES::r_sbox[current.faultytext[j] ^ key.column[j]];
diff.column[j] = current.ciphertext[j] ^ current.faultytext[j];
}
AES::r_mix_column(diff.column);
int zero_bytes = 0;
for (int j = 0;j < 4;j++)
{
if (diff.column[i] == 0)
zero_bytes++;
}
if (zero_bytes >= 3)
{
round_candidates.insert(key);
}
}
}
}
}
if (i == 0)
{
candidates = round_candidates;
}
else
{
set<aes_col> intersection;
set_intersection(candidates.begin(), candidates.end(), round_candidates.begin(), round_candidates.end(), std::inserter(intersection, intersection.begin()));
candidates = intersection;
}
if (candidates.size() == 1)
{
cout << endl;
break;
}
else if (candidates.size() == 0)
{
cout << " -> " << candidates.size() << endl;
cout << "Error!" << endl;
break;
}
else
{
cout << " -> " << candidates.size() << endl;
}
}
cout << "Key: ";
key = *candidates.begin();
for (int i = 0; i < 4;i++)
{
cout << key.column[i] << " ";
}
cout << endl;
return 0;
}