#include <linux/etherdevice.h>
#include <linux/export.h>
#include <linux/string.h>
#include <crypto/algapi.h>
#include "iwl-drv.h"
#include "iwl-wapi-sms4.h"

/* 16-bit pattern repeated across the PN to seed the unicast tx PN */
#define SMS4_INIT_UNICAST_PN_FRAG 0x5c36
/* WAPI header inserted after the 802.11 header: keyidx, reserved, PN */
#define SMS4_HEADER_LEN (SMS4_BLOCK_SIZE + 2)
/* length of the MIC appended to (and encrypted with) the frame body */
#define SMS4_MIC_LEN 16

#ifdef CPTCFG_IWLWIFI_SW_WAPI_SMS4_DBG
#define WAPI_DBG_LEVEL KERN_ERR
#define WAPI_DBG_PREFIX "WAPI-DBG "

/*
 * Dump @buf as one contiguous lowercase hex string, tagged with the call
 * site (@caller) and a short description (@msg).
 *
 * The hex text is staged in ctx->str, which must be able to hold
 * 2 * @buf_len + 1 bytes.
 * NOTE(review): the write is not bounded by the size of ctx->str and
 * callers pass whole frames (skb->len) - confirm the buffer is sized for
 * the largest frame.
 */
static void iwl_wapi_sms4_printk_buf(struct iwl_wapi_sms4_ctx *ctx,
				     const char *caller, const char *msg,
				     const u8 *buf, size_t buf_len)
{
	size_t i;

	/* emit two hex digits per byte directly (strcat here was O(n^2)) */
	for (i = 0; i < buf_len; i++) {
		ctx->str[2 * i] = hex_asc_hi(buf[i]);
		ctx->str[2 * i + 1] = hex_asc_lo(buf[i]);
	}
	ctx->str[2 * buf_len] = '\0';

	/* buf_len is a size_t: %zu, not %d */
	printk(WAPI_DBG_LEVEL WAPI_DBG_PREFIX "%s - %s (size %zu):\n%s\n",
	       caller, msg, buf_len, ctx->str);
}

/* Log a plain "@caller - @msg" trace line. */
static void iwl_wapi_sms4_printk_str(const char *caller, const char *msg)
{
	printk(WAPI_DBG_LEVEL WAPI_DBG_PREFIX "%s - %s\n", caller, msg);
}

/*
 * Log @msg followed by the validity of both unicast/multicast key slots
 * and the current key indices. Reads ctx without taking ctx_lock itself.
 * Every line is newline-terminated so it is flushed as its own record.
 */
static void iwl_wapi_sms4_printk_ctx(struct iwl_wapi_sms4_ctx *ctx,
				     const char *msg)
{
	printk(WAPI_DBG_LEVEL WAPI_DBG_PREFIX "%s\n", msg);
	printk(WAPI_DBG_LEVEL WAPI_DBG_PREFIX
	       "uc_key[0].is_valid=%d mc_key[0].is_valid=%d\n",
	       ctx->uc_key[0].is_valid, ctx->mc_key[0].is_valid);
	printk(WAPI_DBG_LEVEL WAPI_DBG_PREFIX
	       "uc_key[1].is_valid=%d mc_key[1].is_valid=%d\n",
	       ctx->uc_key[1].is_valid, ctx->mc_key[1].is_valid);
	printk(WAPI_DBG_LEVEL WAPI_DBG_PREFIX
	       "curr_uc_key_idx=%d curr_mc_key_idx=%d\n",
	       ctx->curr_uc_key_idx, ctx->curr_mc_key_idx);
}

#else /* defined CPTCFG_IWLWIFI_SW_WAPI_SMS4_DBG */
/* statement-shaped no-ops: arguments are not evaluated */
#define iwl_wapi_sms4_printk_buf(ctx, caller, msg, buf, buf_len) \
	do { } while (0)
#define iwl_wapi_sms4_printk_str(caller, msg) do { } while (0)
#define iwl_wapi_sms4_printk_ctx(ctx, msg) do { } while (0)
#endif /* CPTCFG_IWLWIFI_SW_WAPI_SMS4_DBG */

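/* Unicast PN Policy:
 * 1. Always keep even values in 'next to tx'
 * 2. Upon Tx - increment unicast 'next to tx' (by 2) after protecting
 *    each frame
 * 3. Upon Rx - verify receiving odd number larger than proper 'last rx-ed'
 *    (note: the rx path currently enforces only "larger"; the parity of a
 *    received PN is not checked)
 */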

/*
 * Seed a unicast PN with its initial value: the 16-bit pattern
 * SMS4_INIT_UNICAST_PN_FRAG repeated over all SMS4_BLOCK_SIZE bytes,
 * low byte first in each pair.
 */
static void iwl_wapi_sms4_init_uc_pn(u8 *pn)
{
	const u8 lo = SMS4_INIT_UNICAST_PN_FRAG & 0xFF;
	const u8 hi = (SMS4_INIT_UNICAST_PN_FRAG >> 8) & 0xFF;
	u8 *p;

	for (p = pn; p < pn + SMS4_BLOCK_SIZE; p += 2) {
		p[0] = lo;
		p[1] = hi;
	}
}

/*
 * Advance a PN of SMS4_BLOCK_SIZE bytes by 2. pn[0] is the least
 * significant byte, so the carry propagates towards higher indices.
 * Stepping by 2 keeps the parity of the PN unchanged.
 *
 * A carry is detected by the byte having wrapped below the amount just
 * added. This is correct for any value of pn[0]; the former "== 0" test
 * missed the carry when an odd pn[0] wrapped (0xFF + 2 -> 0x01). A carry
 * out of the most significant byte is dropped (the PN wraps to zero).
 */
static void iwl_wapi_sms4_inc_pn(u8 *pn)
{
	unsigned int addend = 2;
	int i;

	for (i = 0; i < SMS4_BLOCK_SIZE; i++) {
		pn[i] += addend;
		/* no wrap-around in this byte: nothing left to carry */
		if (pn[i] >= addend)
			break;
		addend = 1;
	}
}

/*
 * Return true iff @presumed_high is strictly greater than @presumed_low.
 * Both are SMS4_BLOCK_SIZE-byte PNs whose most significant byte is the
 * last one. Equal PNs compare as not higher.
 */
static bool iwl_wapi_sms4_verify_pn_higher(u8 *presumed_high, u8 *presumed_low)
{
	int idx = SMS4_BLOCK_SIZE;

	/* walk down from the most significant byte to the first difference */
	while (idx-- > 0) {
		if (presumed_high[idx] != presumed_low[idx])
			return presumed_high[idx] > presumed_low[idx];
	}

	return false;
}

/*
 * Reset a WAPI SMS4 context.
 *
 * The routine:
 *  - initialises ctx_lock;
 *  - selects index 0 as the current unicast and multicast key;
 *  - zeroes the next-tx PN and both last-rx PNs;
 *  - marks all four key slots invalid.
 */
void iwl_wapi_sms4_init_ctx(struct iwl_wapi_sms4_ctx *ctx)
{
	int slot;

	iwl_wapi_sms4_printk_str("init", "in");
	spin_lock_init(&ctx->ctx_lock);

	ctx->curr_uc_key_idx = 0;
	ctx->curr_mc_key_idx = 0;

	memset(ctx->next_pn, 0, SMS4_BLOCK_SIZE);
	memset(ctx->last_uc_pn, 0, SMS4_BLOCK_SIZE);
	memset(ctx->last_mc_pn, 0, SMS4_BLOCK_SIZE);

	/* two key indices (0 and 1) per direction */
	for (slot = 0; slot < 2; slot++) {
		ctx->uc_key[slot].is_valid = false;
		ctx->mc_key[slot].is_valid = false;
	}

	iwl_wapi_sms4_printk_str("init", "out");
}
IWL_EXPORT_SYMBOL(iwl_wapi_sms4_init_ctx);

void iwl_wapi_sms4_report_supported_cipher_suites(struct ieee80211_hw *hw)
{
	static const u32 wapi_sms4_cipher_suites[] = {
		/* keep WEP first, it may be removed */
		WLAN_CIPHER_SUITE_WEP40,
		WLAN_CIPHER_SUITE_WEP104,
		WLAN_CIPHER_SUITE_TKIP,
		WLAN_CIPHER_SUITE_CCMP,
		WLAN_CIPHER_SUITE_SMS4,

		/* keep last -- depends on hw flags! */
		WLAN_CIPHER_SUITE_AES_CMAC
	};

	iwl_wapi_sms4_printk_str("report", "in");
	hw->wiphy->cipher_suites = wapi_sms4_cipher_suites;
	hw->wiphy->n_cipher_suites = ARRAY_SIZE(wapi_sms4_cipher_suites);
	if (!(hw->flags & IEEE80211_HW_MFP_CAPABLE))
		hw->wiphy->n_cipher_suites--;
	iwl_wapi_sms4_printk_str("report", "out");
}
IWL_EXPORT_SYMBOL(iwl_wapi_sms4_report_supported_cipher_suites);

/*
 * WAPI protection counts as active only when at least one unicast key
 * slot AND at least one multicast key slot hold a valid key. The slots
 * are sampled under ctx_lock.
 */
static bool iwl_wapi_sms4_is_active(struct iwl_wapi_sms4_ctx *ctx)
{
	unsigned long irq_flags;
	bool have_uc, have_mc;

	spin_lock_irqsave(&ctx->ctx_lock, irq_flags);
	have_uc = ctx->uc_key[0].is_valid || ctx->uc_key[1].is_valid;
	have_mc = ctx->mc_key[0].is_valid || ctx->mc_key[1].is_valid;
	spin_unlock_irqrestore(&ctx->ctx_lock, irq_flags);

	return have_uc && have_mc;
}

int iwl_wapi_sms4_set_key(struct ieee80211_key_conf *keyconf,
			  struct iwl_wapi_sms4_ctx *ctx)
{
	struct iwl_wapi_sms4_keyconf *sms4_keyconf;
	unsigned long flags;

	iwl_wapi_sms4_printk_str("set_key", "in");

	if (keyconf->keylen != 2 * IWL_WAPI_SMS4_KEY_LEN ||
	    ((u8)(keyconf->keyidx)) > 1) {
		printk(KERN_ERR "WAPI - Bad key material. len = %d  idx = %d\n",
		       keyconf->keylen, keyconf->keyidx);
		return -EINVAL;
	}

	spin_lock_irqsave(&ctx->ctx_lock, flags);

	if (keyconf->flags & IEEE80211_KEY_FLAG_PAIRWISE) {
		sms4_keyconf = &ctx->uc_key[keyconf->keyidx];
		/* set curr key idx only in first key installation */
		if (!ctx->uc_key[0].is_valid &&
		    !ctx->uc_key[1].is_valid) {
			memset(ctx->last_uc_pn, 0, SMS4_BLOCK_SIZE);
			iwl_wapi_sms4_init_uc_pn(ctx->next_pn);
			iwl_wapi_sms4_inc_pn(ctx->next_pn);
			ctx->curr_uc_key_idx = keyconf->keyidx;
		}
	} else {
		sms4_keyconf = &ctx->mc_key[keyconf->keyidx];
		/* set curr key idx only in first key installation */
		if (!ctx->mc_key[0].is_valid &&
		    !ctx->mc_key[1].is_valid) {
			memset(ctx->last_mc_pn, 0, SMS4_BLOCK_SIZE);
			ctx->curr_mc_key_idx = keyconf->keyidx;
		}
	}

	memcpy(sms4_keyconf->enc_key, keyconf->key, IWL_WAPI_SMS4_KEY_LEN);
	iwl_wapi_sms4_printk_buf(ctx, "set_key", "encryption key",
				 sms4_keyconf->enc_key,
				 IWL_WAPI_SMS4_KEY_LEN);
	sms4_set_key(&sms4_keyconf->enc_ctx, sms4_keyconf->enc_key);
	memcpy(sms4_keyconf->cons_key, keyconf->key + IWL_WAPI_SMS4_KEY_LEN,
	       IWL_WAPI_SMS4_KEY_LEN);
	iwl_wapi_sms4_printk_buf(ctx, "set_key", "consistency key",
				 sms4_keyconf->cons_key,
				 IWL_WAPI_SMS4_KEY_LEN);
	sms4_set_key(&sms4_keyconf->cons_ctx, sms4_keyconf->cons_key);
	sms4_keyconf->is_valid = true;

	iwl_wapi_sms4_printk_ctx(ctx, __func__);

	spin_unlock_irqrestore(&ctx->ctx_lock, flags);

	iwl_wapi_sms4_printk_str("set_key", "out");

	return 0;
}
IWL_EXPORT_SYMBOL(iwl_wapi_sms4_set_key);

int iwl_wapi_sms4_clear_key(struct ieee80211_key_conf *keyconf,
			    struct iwl_wapi_sms4_ctx *ctx)
{
	struct iwl_wapi_sms4_keyconf *sms4_keyconf;
	unsigned long flags;

	iwl_wapi_sms4_printk_str("clear_key", "in");

	if (keyconf->cipher != WLAN_CIPHER_SUITE_SMS4)
		return 1;

	if (((u8)(keyconf->keyidx)) > 1) {
		printk(KERN_ERR "WAPI - Bad key material. idx = %d\n",
		       keyconf->keyidx);
		return -EINVAL;
	}

	spin_lock_irqsave(&ctx->ctx_lock, flags);

	if (keyconf->flags & IEEE80211_KEY_FLAG_PAIRWISE)
		sms4_keyconf = &ctx->uc_key[keyconf->keyidx];
	else
		sms4_keyconf = &ctx->mc_key[keyconf->keyidx];

	sms4_keyconf->is_valid = false;

	spin_unlock_irqrestore(&ctx->ctx_lock, flags);

	iwl_wapi_sms4_printk_ctx(ctx, __func__);

	iwl_wapi_sms4_printk_str("clear_key", "out");

	return 0;
}
IWL_EXPORT_SYMBOL(iwl_wapi_sms4_clear_key);

/*
 * Build the WAPI integrity-check data (ICD) for the 802.11 frame in @skb
 * into @icd, and store the number of bytes written in *@p_icd_len.
 *
 * Part 1 is a pseudo header, in this order:
 *  - frame control, masked;
 *  - addr1, addr2;
 *  - sequence control, masked;
 *  - addr3, then addr4 (zeros when the frame has no addr4);
 *  - QoS control, for QoS data frames only;
 *  - @keyidx and a reserved zero byte;
 *  - the big-endian body length.
 * QoS data frames get 14 extra zero bytes of padding.
 *
 * Part 2 is the frame body following the 802.11 header.
 *
 * NOTE(review): @icd must hold the body plus up to 48 bytes of pseudo
 * header; the caller's buffer size is not checked here.
 */
static void iwl_wapi_sms4_assemble_icd(const struct sk_buff *skb, s8 keyidx,
				       u8 *icd, size_t *p_icd_len)
{
	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
	__le16 fc = hdr->frame_control;
	unsigned int hdrlen = ieee80211_hdrlen(fc);
	u16 data_len = skb->len - hdrlen;
	__be16 data_len_be = cpu_to_be16(data_len);
	bool is_qos = ieee80211_is_data_qos(fc);
	u8 *cur = icd;

	/* PART 1 */

	/*
	 * Frame control, masked as required by WAPI spec:
	 *  - clear subtype bits 4-6;
	 *  - clear retry/pwrmgt/moredata;
	 *  - set protected.
	 */
	memcpy(cur, &fc, 2);
	cur[0] &= 0x8F;
	cur[1] &= 0xC7;
	cur[1] |= 0x40;
	cur += 2;

	memcpy(cur, hdr->addr1, ETH_ALEN);
	cur += ETH_ALEN;
	memcpy(cur, hdr->addr2, ETH_ALEN);
	cur += ETH_ALEN;

	/* sequence control: keep only the low nibble (fragment number) */
	memcpy(cur, &hdr->seq_ctrl, 2);
	cur[0] &= 0x0F;
	cur[1] = 0;
	cur += 2;

	memcpy(cur, hdr->addr3, ETH_ALEN);
	cur += ETH_ALEN;

	if (ieee80211_has_a4(fc))
		memcpy(cur, hdr->addr4, ETH_ALEN);
	else
		memset(cur, 0, ETH_ALEN);
	cur += ETH_ALEN;

	if (is_qos) {
		memcpy(cur, ieee80211_get_qos_ctl(hdr), 2);
		cur += 2;
	}

	*cur++ = keyidx;
	*cur++ = 0; /* 'reserved' field */

	memcpy(cur, &data_len_be, 2);
	cur += 2;

	if (is_qos) {
		/* zero-pad part1 with extra 14 bytes */
		memset(cur, 0, 14);
		cur += 14;
	}

	/* PART 2: the frame body */
	memcpy(cur, skb->data + hdrlen, data_len);
	cur += data_len;

	*p_icd_len = cur - icd;
}

static int _iwl_wapi_sms4_tx(struct sk_buff *skb, struct iwl_wapi_sms4_ctx *ctx)
{
	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
	s8 curr_uc_keyidx = ctx->curr_uc_key_idx;
	u8 *pos;
	size_t hdr_len = ieee80211_hdrlen(hdr->frame_control);
	u8 mic[SMS4_MIC_LEN];
	int ret;
	unsigned long flags;

	spin_lock_irqsave(&ctx->ctx_lock, flags);

	if (!ctx->uc_key[curr_uc_keyidx].is_valid) {
		ret = -EINVAL;
		goto unlock_ctx;
	}

	iwl_wapi_sms4_printk_buf(ctx, "tx", "plain packet",
				 skb->data, skb->len);

	hdr->frame_control |= cpu_to_le16(IEEE80211_FCTL_PROTECTED);
	iwl_wapi_sms4_assemble_icd(skb, curr_uc_keyidx,
				   ctx->icd, &ctx->icd_len);
	iwl_wapi_sms4_printk_buf(ctx, "tx", "icd", ctx->icd, ctx->icd_len);
	sms4_calculate_mic_by_cbc_mac(&ctx->uc_key[curr_uc_keyidx].cons_ctx,
				      ctx->next_pn, ctx->icd, ctx->icd_len,
				      mic);

	iwl_wapi_sms4_printk_buf(ctx, "tx", "mic is", mic, SMS4_MIC_LEN);

	if (skb_tailroom(skb) < SMS4_MIC_LEN ||
	    skb_headroom(skb) < SMS4_HEADER_LEN) {
		ret = pskb_expand_head(
			skb, SMS4_HEADER_LEN, SMS4_MIC_LEN, GFP_ATOMIC);
		if (ret)
			goto unlock_ctx;
	}

	pos = skb_put(skb, SMS4_MIC_LEN);
	memcpy(pos, mic, SMS4_MIC_LEN);
	sms4_crypt_buf_by_ofb(&ctx->uc_key[curr_uc_keyidx].enc_ctx,
			      ctx->next_pn, skb->data + hdr_len,
			      skb->len - hdr_len);

	pos = skb_push(skb, SMS4_HEADER_LEN);
	memmove(pos, pos + SMS4_HEADER_LEN, hdr_len);
	pos += hdr_len;
	pos[0] = curr_uc_keyidx;
	pos[1] = 0; /* reserved field */
	memcpy(pos + 2, ctx->next_pn, SMS4_BLOCK_SIZE);

	iwl_wapi_sms4_printk_buf(ctx, "tx", "crypted packet",
				 skb->data, skb->len);

	/* update pn for next tx */
	iwl_wapi_sms4_inc_pn(ctx->next_pn);

	ret = 0;

unlock_ctx:
	spin_unlock_irqrestore(&ctx->ctx_lock, flags);

	return ret;
}

/*
 * Software WAPI transmit hook.
 *
 * A data frame with a non-empty body is SMS4-protected once both a
 * unicast and a multicast key are installed. Frames whose skb->protocol
 * is ETH_P_WAPI, and all other frames, are left untouched.
 *
 * Returns 0 when the frame was protected or left as-is. Returns a
 * negative error, after logging it, when the frame should be dropped.
 */
int iwl_wapi_sms4_tx(struct sk_buff *skb, struct iwl_wapi_sms4_ctx *ctx)
{
	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
	__le16 fc = hdr->frame_control;
	int err = 0;

	/*
	 * Compare instead of subtracting: both operands are unsigned, so
	 * "len - hdrlen > 0" held for any len != hdrlen, even len < hdrlen.
	 */
	if (ieee80211_is_data(fc) && skb->protocol != htons(ETH_P_WAPI) &&
	    skb->len > ieee80211_hdrlen(fc) &&
	    iwl_wapi_sms4_is_active(ctx)) {
		err = _iwl_wapi_sms4_tx(skb, ctx);
		if (err == -EINVAL)
			printk(KERN_ERR "WAPI - Dropping tx - invalid key\n");
		else if (err)
			printk(KERN_ERR "WAPI - Dropping tx - err code: %d\n",
			       err);
	}

	iwl_wapi_sms4_printk_ctx(ctx, __func__);

	return err;
}
IWL_EXPORT_SYMBOL(iwl_wapi_sms4_tx);

static int _iwl_wapi_sms4_rx(struct sk_buff *skb, struct iwl_wapi_sms4_ctx *ctx)
{
	struct ieee80211_hdr *hdr;
	size_t hdr_len;
	u8 pn[SMS4_BLOCK_SIZE];
	u8 keyidx;
	u8 expected_mic[SMS4_MIC_LEN], calculated_mic[SMS4_MIC_LEN];
	bool is_multicast;
	struct iwl_wapi_sms4_keyconf *sms4_keyconf;
	struct ieee80211_rx_status *status;
	int ret;
	unsigned long flags;

	if (skb_linearize(skb))
		return -ENOMEM;

	spin_lock_irqsave(&ctx->ctx_lock, flags);

	iwl_wapi_sms4_printk_buf(ctx, "rx", "crypted packet",
				 skb->data, skb->len);

	status = IEEE80211_SKB_RXCB(skb);
	hdr = (struct ieee80211_hdr *)skb->data;
	hdr_len = ieee80211_hdrlen(hdr->frame_control);
	is_multicast = is_multicast_ether_addr(ieee80211_get_DA(hdr));

	/* check key index */
	keyidx = skb->data[hdr_len];
	if (keyidx > 1) {
		ret = -EINVAL;
		goto unlock_ctx;
	}
	sms4_keyconf = is_multicast
		? &ctx->mc_key[keyidx] : &ctx->uc_key[keyidx];
	if (!sms4_keyconf->is_valid) {
		ret = -EINVAL;
		goto unlock_ctx;
	}

	/* set PN */
	memcpy(pn, skb->data + hdr_len + 2, SMS4_BLOCK_SIZE);

	/* decrypt and set MIC*/
	sms4_crypt_buf_by_ofb(&sms4_keyconf->enc_ctx, pn,
			      skb->data + hdr_len + SMS4_HEADER_LEN,
			      skb->len - hdr_len - SMS4_HEADER_LEN);
	status->flag |= RX_FLAG_DECRYPTED;
	memcpy(expected_mic, skb->data + skb->len - SMS4_MIC_LEN, SMS4_MIC_LEN);

	/* strip header and footer */
	memmove(skb->data + SMS4_HEADER_LEN, skb->data, hdr_len);
	skb_pull(skb, SMS4_HEADER_LEN);
	status->flag |= RX_FLAG_IV_STRIPPED;
	skb_trim(skb, skb->len - SMS4_MIC_LEN);
	status->flag |= RX_FLAG_MMIC_STRIPPED;

	/* calculate MIC */
	iwl_wapi_sms4_assemble_icd(skb, keyidx, ctx->icd, &ctx->icd_len);
	iwl_wapi_sms4_printk_buf(ctx, "rx", "icd", ctx->icd, ctx->icd_len);
	sms4_calculate_mic_by_cbc_mac(&sms4_keyconf->cons_ctx, pn,
				      ctx->icd, ctx->icd_len, calculated_mic);
	if (memcmp(expected_mic, calculated_mic, SMS4_MIC_LEN) != 0) {
		iwl_wapi_sms4_printk_buf(ctx, "rx", "expected mic is",
					 expected_mic, SMS4_MIC_LEN);
		iwl_wapi_sms4_printk_buf(ctx, "rx", "calculated mic is",
					 calculated_mic, SMS4_MIC_LEN);
		ret = -EIO;
		goto unlock_ctx;
	}
	iwl_wapi_sms4_printk_buf(ctx, "rx", "verified mic is",
				 calculated_mic, SMS4_MIC_LEN);

	/* update keyidx, pn and frame control */
	if (is_multicast) {
		if (keyidx == ctx->curr_mc_key_idx) {
			if (!iwl_wapi_sms4_verify_pn_higher(pn,
							    ctx->last_mc_pn)) {
				ret = -EPERM;
				goto unlock_ctx;
			}
		} else { /* rekeying */
			ctx->mc_key[ctx->curr_mc_key_idx].is_valid = false;
			ctx->curr_mc_key_idx = keyidx;
		}
		memcpy(ctx->last_mc_pn, pn, SMS4_BLOCK_SIZE);
	} else {
		if (keyidx == ctx->curr_uc_key_idx) {
			if (!iwl_wapi_sms4_verify_pn_higher(pn,
							    ctx->last_uc_pn)) {
				ret = -EPERM;
				goto unlock_ctx;
			}
		} else { /* rekeying */
			ctx->uc_key[ctx->curr_uc_key_idx].is_valid = false;
			ctx->curr_uc_key_idx = keyidx;
		}
		memcpy(ctx->last_uc_pn, pn, SMS4_BLOCK_SIZE);
	}

	hdr = (struct ieee80211_hdr *)skb->data;
	hdr->frame_control &= ~cpu_to_le16(IEEE80211_FCTL_PROTECTED);
	iwl_wapi_sms4_printk_buf(ctx, "rx", "plain packet",
				 skb->data, skb->len);
	ret = 0;

unlock_ctx:
	spin_unlock_irqrestore(&ctx->ctx_lock, flags);

	return ret;
}

int iwl_wapi_sms4_rx(struct sk_buff *skb, struct iwl_wapi_sms4_ctx *ctx)
{
	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
	__le16 fc = hdr->frame_control;
	int err = 0;

	if (ieee80211_is_data(fc) && ieee80211_has_protected(fc) &&
	    iwl_wapi_sms4_is_active(ctx)) {
		err = _iwl_wapi_sms4_rx(skb, ctx);
		switch (err) {
		case 0:
			break;
		case -ENOMEM:
			printk(KERN_ERR "WAPI - Dropping rx - unable to linearize\n");
			break;
		case -EINVAL:
			printk(KERN_ERR "WAPI - Dropping rx - invalid key\n");
			break;
		case -EPERM:
			printk(KERN_ERR "WAPI - Dropping rx - bad PN\n");
			break;
		case -EIO:
			printk(KERN_ERR "WAPI - Dropping rx - bad MIC\n");
			break;
		default:
			printk(KERN_ERR "WAPI - Dropping rx unexpectedly (%d)\n",
			       err);
			break;
		}
	}

	return err;
}
IWL_EXPORT_SYMBOL(iwl_wapi_sms4_rx);

