rsa: add sha256-rsa2048 algorithm

based on patch from andreas@oetken.name:

http://patchwork.ozlabs.org/patch/294318/
commit message:
I currently need support for rsa-sha256 signatures in u-boot and found out that
the code for signatures is not very generic. Thus adding of different
hash-algorithms for rsa-signatures is not easy to do without copy-pasting the
rsa-code. I attached a patch for how I think it could be better and included
support for rsa-sha256. This is a fast first shot.

additional work:
- removed checkpatch warnings
- removed compiler warnings
- rebased against current head

Signed-off-by: Heiko Schocher <hs@denx.de>
Cc: andreas@oetken.name
Cc: Simon Glass <sjg@chromium.org>
diff --git a/lib/rsa/Makefile b/lib/rsa/Makefile
index 164ab39..a5a96cb6 100644
--- a/lib/rsa/Makefile
+++ b/lib/rsa/Makefile
@@ -7,4 +7,4 @@
 # SPDX-License-Identifier:	GPL-2.0+
 #
 
-obj-$(CONFIG_FIT_SIGNATURE) += rsa-verify.o
+obj-$(CONFIG_FIT_SIGNATURE) += rsa-verify.o rsa-checksum.o
diff --git a/lib/rsa/rsa-checksum.c b/lib/rsa/rsa-checksum.c
new file mode 100644
index 0000000..e520e1c
--- /dev/null
+++ b/lib/rsa/rsa-checksum.c
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2013, Andreas Oetken.
+ *
+ * SPDX-License-Identifier:    GPL-2.0+
+ */
+
+#include <common.h>
+#include <fdtdec.h>
+#include <rsa.h>
+#include <sha1.h>
+#include <sha256.h>
+#include <asm/byteorder.h>
+#include <asm/errno.h>
+#include <asm/unaligned.h>
+
+#define RSA2048_BYTES 256
+
+/* PKCS 1.5 paddings as described in the RSA PKCS#1 v2.1 standard. */
+
+const uint8_t padding_sha256_rsa2048[RSA2048_BYTES - SHA256_SUM_LEN] = {
+0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x30, 0x31, 0x30,
+0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
+0x00, 0x04, 0x20	/* DER: OCTET STRING, 0x20-byte digest follows */
+};
+
+const uint8_t padding_sha1_rsa2048[RSA2048_BYTES - SHA1_SUM_LEN] = {
+	0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+	0xff, 0xff, 0xff, 0xff, 0x00, 0x30, 0x21, 0x30,	/* 00, DigestInfo */
+	0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a,	/* OID 1.3.14.3.2.26 */
+	0x05, 0x00, 0x04, 0x14	/* NULL; OCTET STRING, 0x14-byte digest */
+};
+
+/* SHA-1 over all regions, in order; checksum must hold the full digest. */
+void sha1_calculate(const struct image_region region[], int region_count,
+		    uint8_t *checksum)
+{
+	sha1_context ctx;
+	int i;	/* signed like region_count: a negative count hashes nothing */
+
+	sha1_starts(&ctx);
+	for (i = 0; i < region_count; i++)
+		sha1_update(&ctx, region[i].data, region[i].size);
+	sha1_finish(&ctx, checksum);
+}
+
+/* SHA-256 over all regions, in order; checksum must hold the full digest. */
+void sha256_calculate(const struct image_region region[], int region_count,
+		      uint8_t *checksum)
+{
+	sha256_context ctx;
+	int i;	/* signed like region_count: a negative count hashes nothing */
+
+	sha256_starts(&ctx);
+	for (i = 0; i < region_count; i++)
+		sha256_update(&ctx, region[i].data, region[i].size);
+	sha256_finish(&ctx, checksum);
+}
diff --git a/lib/rsa/rsa-sign.c b/lib/rsa/rsa-sign.c
index 549130e..0fe6e9f 100644
--- a/lib/rsa/rsa-sign.c
+++ b/lib/rsa/rsa-sign.c
@@ -159,8 +159,9 @@
 	EVP_cleanup();
 }
 
-static int rsa_sign_with_key(RSA *rsa, const struct image_region region[],
-		int region_count, uint8_t **sigp, uint *sig_size)
+static int rsa_sign_with_key(RSA *rsa, struct checksum_algo *checksum_algo,
+		const struct image_region region[], int region_count,
+		uint8_t **sigp, uint *sig_size)
 {
 	EVP_PKEY *key;
 	EVP_MD_CTX *context;
@@ -192,7 +193,7 @@
 		goto err_create;
 	}
 	EVP_MD_CTX_init(context);
-	if (!EVP_SignInit(context, EVP_sha1())) {
+	if (!EVP_SignInit(context, checksum_algo->calculate())) {
 		ret = rsa_err("Signer setup failed");
 		goto err_sign;
 	}
@@ -242,7 +243,8 @@
 	ret = rsa_get_priv_key(info->keydir, info->keyname, &rsa);
 	if (ret)
 		goto err_priv;
-	ret = rsa_sign_with_key(rsa, region, region_count, sigp, sig_len);
+	ret = rsa_sign_with_key(rsa, info->algo->checksum, region,
+				region_count, sigp, sig_len);
 	if (ret)
 		goto err_sign;
 
diff --git a/lib/rsa/rsa-verify.c b/lib/rsa/rsa-verify.c
index 02cc4e3..b3573a8 100644
--- a/lib/rsa/rsa-verify.c
+++ b/lib/rsa/rsa-verify.c
@@ -8,23 +8,11 @@
 #include <fdtdec.h>
 #include <rsa.h>
 #include <sha1.h>
+#include <sha256.h>
 #include <asm/byteorder.h>
 #include <asm/errno.h>
 #include <asm/unaligned.h>
 
-/**
- * struct rsa_public_key - holder for a public key
- *
- * An RSA public key consists of a modulus (typically called N), the inverse
- * and R^2, where R is 2^(# key bits).
- */
-struct rsa_public_key {
-	uint len;		/* Length of modulus[] in number of uint32_t */
-	uint32_t n0inv;		/* -1 / modulus[0] mod 2^32 */
-	uint32_t *modulus;	/* modulus as little endian array */
-	uint32_t *rr;		/* R^2 as little endian array */
-};
-
 #define UINT64_MULT32(v, multby)  (((uint64_t)(v)) * ((uint32_t)(multby)))
 
 #define RSA2048_BYTES	(2048 / 8)
@@ -36,39 +24,6 @@
 /* This is the maximum signature length that we support, in bits */
 #define RSA_MAX_SIG_BITS	2048
 
-static const uint8_t padding_sha1_rsa2048[RSA2048_BYTES - SHA1_SUM_LEN] = {
-	0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-	0xff, 0xff, 0xff, 0xff, 0x00, 0x30, 0x21, 0x30,
-	0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a,
-	0x05, 0x00, 0x04, 0x14
-};
-
 /**
  * subtract_modulus() - subtract modulus from the given value
  *
@@ -209,13 +164,14 @@
 }
 
 static int rsa_verify_key(const struct rsa_public_key *key, const uint8_t *sig,
-		const uint32_t sig_len, const uint8_t *hash)
+			  const uint32_t sig_len, const uint8_t *hash,
+			  struct checksum_algo *algo)
 {
 	const uint8_t *padding;
 	int pad_len;
 	int ret;
 
-	if (!key || !sig || !hash)
+	if (!key || !sig || !hash || !algo)
 		return -EIO;
 
 	if (sig_len != (key->len * sizeof(uint32_t))) {
@@ -223,6 +179,8 @@
 		return -EINVAL;
 	}
 
+	debug("Checksum algorithm: %s", algo->name);
+
 	/* Sanity check for stack size */
 	if (sig_len > RSA_MAX_SIG_BITS / 8) {
 		debug("Signature length %u exceeds maximum %d\n", sig_len,
@@ -238,9 +196,8 @@
 	if (ret)
 		return ret;
 
-	/* Determine padding to use depending on the signature type. */
-	padding = padding_sha1_rsa2048;
-	pad_len = RSA2048_BYTES - SHA1_SUM_LEN;
+	padding = algo->rsa_padding;
+	pad_len = RSA2048_BYTES - algo->checksum_len;
 
 	/* Check pkcs1.5 padding bytes. */
 	if (memcmp(buf, padding, pad_len)) {
@@ -309,7 +266,7 @@
 	}
 
 	debug("key length %d\n", key.len);
-	ret = rsa_verify_key(&key, sig, sig_len, hash);
+	ret = rsa_verify_key(&key, sig, sig_len, hash, info->algo->checksum);
 	if (ret) {
 		printf("%s: RSA failed to verify: %d\n", __func__, ret);
 		return ret;
@@ -323,12 +280,22 @@
 	       uint8_t *sig, uint sig_len)
 {
 	const void *blob = info->fdt_blob;
-	uint8_t hash[SHA1_SUM_LEN];
+	/* Reserve memory for maximum checksum-length */
+	uint8_t hash[RSA2048_BYTES];
 	int ndepth, noffset;
 	int sig_node, node;
 	char name[100];
-	sha1_context ctx;
-	int ret, i;
+	int ret;
+
+	/*
+	 * Verify that the checksum-length does not exceed the
+	 * rsa-signature-length
+	 */
+	if (info->algo->checksum->checksum_len > RSA2048_BYTES) {
+		debug("%s: invlaid checksum-algorithm %s for RSA2048\n",
+		      __func__, info->algo->checksum->name);
+		return -EINVAL;
+	}
 
 	sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME);
 	if (sig_node < 0) {
@@ -336,10 +303,8 @@
 		return -ENOENT;
 	}
 
-	sha1_starts(&ctx);
-	for (i = 0; i < region_count; i++)
-		sha1_update(&ctx, region[i].data, region[i].size);
-	sha1_finish(&ctx, hash);
+	/* Calculate checksum with checksum-algorithm */
+	info->algo->checksum->calculate(region, region_count, hash);
 
 	/* See if we must use a particular key */
 	if (info->required_keynode != -1) {