Understanding the Secure Hash Algorithm 2: Part 1
A handy reference for understanding one of the most used cryptographic hash algorithms: SHA-2.

This is the first in a three-part series of articles that break down the SHA-2 algorithm. In part one, we'll go over the initial definitions that are integral to understanding the algorithm and write sample code in C. In part two, we'll compute the hash given the initial values we generated in part one. And then in part three, we'll take another look at the program, compare it against the de facto implementations and see how we can optimise it to run faster.

SHA-2

SHA-2 (Secure Hash Algorithm 2) is used trillions of times a day to compute cryptographic hashes and is essential to the security of much of the web. That being said, I'm not exactly sure how it works, and I thought it would be a good exercise to sit down and break it into digestible chunks.

My reference is the PDF published on the National Institute of Standards and Technology site here.

Here are a few concepts that need to be internalised before moving forward.

Definitions

One way function
A function that is easy to compute but hard to reverse.
Ability to process arbitrary length inputs
A hash function can split up messages into smaller chunks and then operate on them sequentially.
Merkle-Damgård construction
A way to build a cryptographic hash function for arbitrary-length inputs out of a fixed-size compression function, such that the hash function retains the collision-resistant properties of the compression function.
Time to brute-force a hash function
For a hash function with an L-bit digest, finding a message that matches a given digest by brute force takes on the order of 2^L evaluations. This is known as a pre-image attack.
Avalanche effect
A small change in the input modifies the output drastically.
Modular arithmetic
Wrap-around arithmetic -- used when telling time on a 12-hour clock (modulo 12).
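
To make the wrap-around idea concrete, here is a minimal sketch in C. The helper name add_mod32 is my own and not part of any standard. It relies on the fact that unsigned arithmetic in C wraps around, so adding two 32-bit unsigned values is already addition modulo 2^32.

```c
#include <stdint.h>

/* Hypothetical helper: addition modulo 2^32.
 * Unsigned overflow in C wraps around, so no explicit
 * modulo operation is needed. */
uint32_t add_mod32(uint32_t a, uint32_t b) {
  return (uint32_t) (a + b);
}
```

For example, add_mod32(0xFFFFFFFF, 1) wraps around to 0, much like 11 o'clock plus one hour wraps around on a clock face.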

Functions used

The following functions operate on 32-bit words.

  • RIGHT (n) ROTATE = (x right shift n) OR (x left shift (w - n)), where x is a w-bit word and n is an integer with 0 <= n < w
  • RIGHT (n) SHIFT = x right shift n, where x is a w-bit word and n is an integer with 0 <= n < w
  • Addition is modulo 2^32.
  • (1): Ch(x, y, z) = (x AND y) XOR (NOT x AND z)
  • (2): Maj(x, y, z) = (x AND y) XOR (x AND z) XOR (y AND z)
  • (3): Sum(x, 0, 256) = RIGHT (2) ROTATE (x) XOR RIGHT (13) ROTATE (x) XOR RIGHT (22) ROTATE (x)
  • (4): Sum(x, 1, 256) = RIGHT (6) ROTATE (x) XOR RIGHT (11) ROTATE (x) XOR RIGHT (25) ROTATE (x)
  • (5): σ(x, 0, 256) = RIGHT (7) ROTATE (x) XOR RIGHT (18) ROTATE (x) XOR RIGHT (3) SHIFT (x)
  • (6): σ(x, 1, 256) = RIGHT (17) ROTATE (x) XOR RIGHT (19) ROTATE (x) XOR RIGHT (10) SHIFT (x)
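
The six functions translate almost directly into C. The following is a minimal sketch for 32-bit words; the function names (rotr, shr, ch, maj, big_sigma0, big_sigma1, small_sigma0, small_sigma1) are my own choices, and the rotation and shift amounts are the SHA-256 ones from the NIST standard.

```c
#include <stdint.h>

/* Rotate x right by n bits. Only valid for 0 < n < 32,
 * because shifting a 32-bit value by 32 is undefined in C. */
static inline uint32_t rotr(uint32_t x, unsigned n) {
  return (x >> n) | (x << (32 - n));
}

/* Plain right shift: the vacated high bits are filled with zeroes. */
static inline uint32_t shr(uint32_t x, unsigned n) {
  return x >> n;
}

/* (1) Choose: each bit of x picks the matching bit of y (if 1) or z (if 0). */
static inline uint32_t ch(uint32_t x, uint32_t y, uint32_t z) {
  return (x & y) ^ (~x & z);
}

/* (2) Majority: each output bit is the majority vote of the three input bits. */
static inline uint32_t maj(uint32_t x, uint32_t y, uint32_t z) {
  return (x & y) ^ (x & z) ^ (y & z);
}

/* (3) and (4): the "big sigma" functions. */
static inline uint32_t big_sigma0(uint32_t x) {
  return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22);
}

static inline uint32_t big_sigma1(uint32_t x) {
  return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25);
}

/* (5) and (6): the "small sigma" functions. */
static inline uint32_t small_sigma0(uint32_t x) {
  return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3);
}

static inline uint32_t small_sigma1(uint32_t x) {
  return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10);
}
```

Feeding in a word with a single bit set is a quick sanity check: small_sigma0(1) rotates that bit to positions 25 and 14 and shifts it out entirely in the third term, giving 0x02004000.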

Initialisation

First, we begin by calculating all the initial values. There are two sets: (1) the first 32 bits of the fractional parts of the square roots of the first 8 primes, which form the initial hash values; and (2) the first 32 bits of the fractional parts of the cube roots of the first 64 primes, which form the round constants. I have generated them here.
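
As a rough illustration of what "the first 32 bits of the fractional part" means, here is a minimal sketch. It is a simpler alternative to the bit-level approach in the full program at the end of the post, and the helper name frac32 is my own. It assumes the input is non-negative and small (as the roots of the first 64 primes are), so that a double still carries at least 32 accurate fractional bits.

```c
#include <stdint.h>

/* Hypothetical helper: first 32 bits of the fractional part of x.
 * Assumes 0 <= x < 2^20 or so, which leaves a double with
 * at least 32 fractional bits of precision. */
uint32_t frac32(double x) {
  double frac = x - (double) (uint64_t) x;  /* drop the integer part */
  return (uint32_t) (frac * 4294967296.0);  /* scale by 2^32 and truncate */
}
```

Passing in the square root of 2 (for instance frac32(sqrt(2.0)) with math.h included) yields 0x6a09e667, the first initial hash value, and the cube root of 2 yields 0x428a2f98, the first round constant.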

Preparing the message

Next, we prepare the message for processing. For this tutorial, I'll be using "SNSD" as the initial message. This message is four ASCII characters long, so its length l is 32 bits (0x20). We append a single '1' bit to the end (step 1) -- this marks the end of the message within the padded message. Then we add (step 2) k zero bits, where k is the smallest non-negative number that makes the following equation true: l + 1 + k = 448 mod 512. Solving for k in this message gives us k = 448 - (32 + 1) = 415 zero bits. After that, we add the message length in bits, as a 64-bit big-endian number, in the last 64 bits of the block (step 3).

S     N     S     D     1 bit  415 zero bits  Length bits
0x53  0x4E  0x53  0x44  0x1    415 * 0x0      0x0000000000000020
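
The number of zero bits can be computed for any message length, not just this one. Here is a minimal sketch; the helper name padding_zero_bits is my own, and l is the message length in bits.

```c
#include <stdint.h>

/* Hypothetical helper: the smallest k >= 0 such that
 * (l + 1 + k) mod 512 == 448, where l is the message length in bits.
 * Adding 512 before subtracting keeps the intermediate value non-negative. */
uint32_t padding_zero_bits(uint64_t l) {
  return (uint32_t) ((448 + 512 - (l + 1) % 512) % 512);
}
```

A four-character ASCII message is 32 bits long, so padding_zero_bits(32) gives 415. A 448-bit message is the awkward case: the '1' bit no longer leaves room for the length, so the padding spills into a second block and padding_zero_bits(448) gives 511.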

Reading in the message

Next, we split the padded message into blocks so we can process it. This step is simple for this short example; the code at the end of the post handles longer messages. We'll refer to the N-th 512-bit block as M(N) and to the W-th 32-bit word within that block as M(N, W). Therefore, for this message we have:

  • Word M(1, 1) = 0x534E5344
  • Word M(1, 2) = 0x80000000
  • Word M(1, 3) = 0x00000000
  • [....words of zeroes.....]
  • Word M(1, 16) = 0x00000020
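
For a message this short, the padding and the split into words fit in a few lines. The following is a minimal sketch that only handles messages shorter than 56 bytes (those that fit in a single block); the helper name pad_single_block is my own, and it is not the approach the full program below takes for longer messages.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: pad a message shorter than 56 bytes into one
 * 512-bit block and store it as sixteen big-endian 32-bit words.
 * Returns 0 on success, -1 if the message does not fit in one block. */
int pad_single_block(const char *msg, uint32_t words[16]) {
  uint8_t bytes[64] = {0};
  size_t len = strlen(msg);
  int i;

  if (len > 55) return -1;

  memcpy(bytes, msg, len);
  bytes[len] = 0x80;  /* the '1' bit followed by seven zero bits */

  /* Message length in bits, big-endian, in the last 8 bytes. */
  uint64_t bit_len = (uint64_t) len * 8;
  for (i = 0; i < 8; i++) {
    bytes[63 - i] = (uint8_t) (bit_len >> (8 * i));
  }

  /* Pack each group of four bytes into one big-endian word. */
  for (i = 0; i < 16; i++) {
    words[i] = ((uint32_t) bytes[4 * i]     << 24)
             | ((uint32_t) bytes[4 * i + 1] << 16)
             | ((uint32_t) bytes[4 * i + 2] << 8)
             |  (uint32_t) bytes[4 * i + 3];
  }

  return 0;
}
```

Calling it with "SNSD" fills the first word with 0x534E5344, the second with 0x80000000 and the last with 0x00000020, with zeroes in between.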

Putting it all together

The following (hastily put together, if I might add) C program will calculate the initial values that I showed earlier and the padded message. Look for part 2 where I will cover the hashing function.

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <inttypes.h>

#define PRIME_NUMBERS 2, 3, 5, 7, 11, 13, 17, 19, 23, 29,  \
31, 37, 41, 43, 47, 53, 59, 61, 67, 71, \
73, 79, 83, 89, 97, 101, 103, 107, 109, 113, \
127, 131, 137, 139, 149, 151, 157, 163, 167, 173, \
179, 181, 191, 193, 197, 199, 211, 223, 227, 229, \
233, 239, 241, 251, 257, 263, 269, 271, 277, 281, \
283, 293, 307, 311

struct m_prime {
  unsigned int pos;
  uint64_t sqrt_frac, cbrt_frac;
  unsigned long prime;
  double sqrt, cbrt;
};

struct message_block {
  uint32_t block[16];
  struct message_block * next;
  uint8_t position;
};

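struct message_block * new_message_block(void) {
  // calloc zeroes the block, so every word and position start at 0.
  struct message_block * m = calloc(1, sizeof *m);
  if (m == NULL) {
    perror("calloc");
    exit(EXIT_FAILURE);
  }
  m->next = NULL;
  return m;
}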

void free_message_block(struct message_block * m) {
  if (m == NULL) return;
  free_message_block(m->next);
  free(m);
}

struct message_block * pad_message(char * input) {
  size_t len = 0;
  uint32_t * blk;
  uint16_t remainder;
  char ch;
  struct message_block *head, *p;
  head = new_message_block();

  p = head;

  // Determine how many message blocks we need.
  // We need [data][1 bit][zero padding][length]
  // [data] maximum length 447 bits
  // [1 bit] length = 1 bit
  // [zero padding] maximum length 447 bits
  // [length] 64 bits

  
  while (*(input + len) != '\0') {
    len++; // we have one character
    blk = &p->block[p->position];
    ch = *(input + len - 1);

    // Pack the bytes big-endian: the first byte of each word lands in the top 8 bits.
    // Going through unsigned char avoids sign-extending bytes >= 0x80.
    *blk |= (uint32_t) (unsigned char) ch << (8 * (3 - (len - 1) % 4));

    if (p->position == 15 && (len % 4) == 0) { // At the end of the block
      p->next = new_message_block();
      p = p->next;
    } else { 
      if (len % 4 == 0) p->position++;
    }
  }

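  // Append the single '1' bit directly after the last message byte.
  p->block[p->position] |= (1u << (32 - 8 * (len % 4) - 1));

  // If fewer than 64 bits remain after the '1' bit, the length goes in a new block.
  remainder = (len * 8) % 512;
  if (remainder > 447) { // Get a new block
    p->next = new_message_block();
    p = p->next;
  }

  // Store the message length in bits (not bytes) as a 64-bit big-endian value.
  p->block[14] = (uint32_t) (((uint64_t) len * 8) >> 32);
  p->block[15] = (uint32_t) ((uint64_t) len * 8);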

  return head;
}

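uint64_t calculate_fractional(double floating) {
  /* IEEE-754 double-precision
   * S: 1 bit, sign
   * E: 11 bits, exponent (biased by 0x3ff)
   * M: 52 bits, mantissa (plus an implicit leading 1)
   *
   * Assumes floating >= 1, so the unbiased exponent is not negative.
   */

  uint64_t f, exponent, mantissa, fractional;
  /* Read the bits through a union; casting the pointer to uint64_t *
   * would break C's strict-aliasing rules. */
  union { double d; uint64_t u; } bits;

  bits.d = floating;
  f = bits.u;
  exponent = ((f >> 52) & 0x7ff) - 0x3ff;
  mantissa = f & 0x000fffffffffffff;
  /* Shifting left by the exponent pushes the integer bits above bit 51.
   * The mask then keeps the first 32 bits of the fractional part. */
  fractional = ((mantissa << exponent) & 0x000ffffffff00000) >> 20;

  return fractional;
}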

struct m_prime * init_prime_list(unsigned int num) {
  /* Initialise variables. */
  unsigned int x;
  const unsigned int primes[]= {PRIME_NUMBERS};
  struct m_prime *p, *mp = malloc(num * sizeof *mp);

  if (mp == NULL) {
    perror("malloc");
    exit(EXIT_FAILURE);
  }

  p = mp;

  for (x = 0; x < num; x++) {
    p->prime = primes[x]; 
    p->sqrt = sqrt(primes[x]);
    p->cbrt = cbrt(primes[x]);
    p->sqrt_frac = calculate_fractional(p->sqrt);
    p->cbrt_frac = calculate_fractional(p->cbrt);

    p++;
  }

  return mp;
}

void print_prime_list(struct m_prime * p, unsigned int num) {
  unsigned int x;

  printf("| Numb | Prme |   Sqrt |   Cbrt |  Sqrt frac |   Cbrt frac |\n");
  printf("|------|------|--------|--------|------------|-------------|\n");

  for (x = 0; x < num; x++) {
    /* PRIx64 (from <inttypes.h>) is the portable format for uint64_t. */
    printf("| %4u | %4lu | %2.4f | %2.4f | %#.8" PRIx64 " | %#.8" PRIx64 "  |\n", 
        x + 1,
        p->prime,
        p->sqrt,
        p->cbrt,
        p->sqrt_frac,
        p->cbrt_frac
    ); 

    p++;
  }
}

int main(int argc, char *argv[]) {
  if (argc > 1) {
    struct m_prime * p;
    struct message_block *b, *h;
    unsigned int i, size = 64,  block_num = 0;

    p = init_prime_list(size); 
    print_prime_list(p, size);

    h = pad_message(argv[1]);
    b = h;

    do {
      for(i = 0; i < 16; i++) {
        printf("Block %2u.%2u: %08" PRIx32 "\n", block_num, i + 1, b->block[i]);
      }
      b = b->next;
      block_num++;
    } while (b != NULL);

    free_message_block(h);
    free(p);
  } else {
    fprintf(stderr, "Please enter a string.\n");
    return EXIT_FAILURE;
  }

  return EXIT_SUCCESS;
}