/*
 * netsniff-ng - the packet sniffing beast
 * Copyright 2009 - 2012 Daniel Borkmann.
 * Copyright 2009, 2010 Emmanuel Roullit.
 * Copyright 1990-1996 The Regents of the University of
 * California. All rights reserved. (3-clause BSD license)
 * Subject to the GPL, version 2.
 */

#include <stdint.h>
#include <stdio.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <fcntl.h>

#include "bpf.h"
#include "xmalloc.h"
#include "die.h"
#include "str.h"
#include "sysctl.h"

/* Load 16/32-bit big-endian (network order) values from the packet. */
#define EXTRACT_SHORT(packet)						\
		((uint16_t) ntohs(*(uint16_t *) (packet)))
#define EXTRACT_LONG(packet)						\
		(ntohl(*(uint32_t *) (packet)))

#ifndef BPF_MEMWORDS
# define BPF_MEMWORDS	16
#endif

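/* Shorthands for the class | size and class | operation combinations
 * used below, both as op_table[] indices and as switch labels.
 */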
#define BPF_LD_B	(BPF_LD   |    BPF_B)
#define BPF_LD_H	(BPF_LD   |    BPF_H)
#define BPF_LD_W	(BPF_LD   |    BPF_W)
#define BPF_LDX_B	(BPF_LDX  |    BPF_B)
#define BPF_LDX_W	(BPF_LDX  |    BPF_W)
#define BPF_JMP_JA	(BPF_JMP  |   BPF_JA)
#define BPF_JMP_JEQ	(BPF_JMP  |  BPF_JEQ)
#define BPF_JMP_JGT	(BPF_JMP  |  BPF_JGT)
#define BPF_JMP_JGE	(BPF_JMP  |  BPF_JGE)
#define BPF_JMP_JSET	(BPF_JMP  | BPF_JSET)
#define BPF_ALU_ADD	(BPF_ALU  |  BPF_ADD)
#define BPF_ALU_SUB	(BPF_ALU  |  BPF_SUB)
#define BPF_ALU_MUL	(BPF_ALU  |  BPF_MUL)
#define BPF_ALU_DIV	(BPF_ALU  |  BPF_DIV)
#define BPF_ALU_MOD	(BPF_ALU  |  BPF_MOD)
#define BPF_ALU_NEG	(BPF_ALU  |  BPF_NEG)
#define BPF_ALU_AND	(BPF_ALU  |  BPF_AND)
#define BPF_ALU_OR	(BPF_ALU  |   BPF_OR)
#define BPF_ALU_XOR	(BPF_ALU  |  BPF_XOR)
#define BPF_ALU_LSH	(BPF_ALU  |  BPF_LSH)
#define BPF_ALU_RSH	(BPF_ALU  |  BPF_RSH)
#define BPF_MISC_TAX	(BPF_MISC |  BPF_TAX)
#define BPF_MISC_TXA	(BPF_MISC |  BPF_TXA)

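/* Instruction mnemonics, indexed by (partial) opcode. */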
static const char *op_table[] = {
	[BPF_LD_B]	=	"ldb",
	[BPF_LD_H]	=	"ldh",
	[BPF_LD_W]	=	"ld",
	[BPF_LDX]	=	"ldx",
	[BPF_LDX_B]	=	"ldxb",
	[BPF_ST]	=	"st",
	[BPF_STX]	=	"stx",
	[BPF_JMP_JA]	=	"ja",
	[BPF_JMP_JEQ]	=	"jeq",
	[BPF_JMP_JGT]	=	"jgt",
	[BPF_JMP_JGE]	=	"jge",
	[BPF_JMP_JSET]	=	"jset",
	[BPF_ALU_ADD]	=	"add",
	[BPF_ALU_SUB]	=	"sub",
	[BPF_ALU_MUL]	=	"mul",
	[BPF_ALU_DIV]	=	"div",
	[BPF_ALU_MOD]	=	"mod",
	[BPF_ALU_NEG]	=	"neg",
	[BPF_ALU_AND]	=	"and",
	[BPF_ALU_OR]	=	"or",
	[BPF_ALU_XOR]	=	"xor",
	[BPF_ALU_LSH]	=	"lsh",
	[BPF_ALU_RSH]	=	"rsh",
	[BPF_RET]	=	"ret",
	[BPF_MISC_TAX]	=	"tax",
	[BPF_MISC_TXA]	=	"txa",
};

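/* Print all known BPF mnemonics, one per line. */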
void bpf_dump_op_table(void)
{
	size_t i;
	for (i = 0; i < array_size(op_table); ++i) {
		if (op_table[i])
			printf("%s\n", op_table[i]);
	}
}

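/* Map the Linux ancillary load offsets (SKF_AD_OFF + SKF_AD_*) used by
 * absolute loads to a short mnemonic; anything else is printed as a
 * plain packet offset.
 */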
static const char *bpf_dump_linux_k(uint32_t k)
{
	switch (k) {
	default:
		return "[%d]";
	case SKF_AD_OFF + SKF_AD_PROTOCOL:
		return "proto";
	case SKF_AD_OFF + SKF_AD_PKTTYPE:
		return "type";
	case SKF_AD_OFF + SKF_AD_IFINDEX:
		return "ifidx";
	case SKF_AD_OFF + SKF_AD_NLATTR:
		return "nla";
	case SKF_AD_OFF + SKF_AD_NLATTR_NEST:
		return "nlan";
	case SKF_AD_OFF + SKF_AD_MARK:
		return "mark";
	case SKF_AD_OFF + SKF_AD_QUEUE:
		return "queue";
	case SKF_AD_OFF + SKF_AD_HATYPE:
		return "hatype";
	case SKF_AD_OFF + SKF_AD_RXHASH:
		return "rxhash";
	case SKF_AD_OFF + SKF_AD_CPU:
		return "cpu";
	case SKF_AD_OFF + SKF_AD_VLAN_TAG:
		return "vlant";
	case SKF_AD_OFF + SKF_AD_VLAN_TAG_PRESENT:
		return "vlanp";
	case SKF_AD_OFF + SKF_AD_PAY_OFFSET:
		return "poff";
	}
}

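/* Disassemble the single instruction sitting at position n. Returns a
 * pointer to a static buffer that is overwritten on each call, so the
 * result must be consumed before the next invocation.
 */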
static char *__bpf_dump(const struct sock_filter bpf, int n)
{
	int v;
	const char *fmt, *op;
	static char image[256];
	char operand[64];

	v = bpf.k;
	switch (bpf.code) {
	default:
		op = "unimp";
		fmt = "0x%x";
		v = bpf.code;
		break;
	case BPF_RET | BPF_K:
		op = op_table[BPF_RET];
		fmt = "#0x%x";
		break;
	case BPF_RET | BPF_A:
		op = op_table[BPF_RET];
		fmt = "a";
		break;
	case BPF_RET | BPF_X:
		op = op_table[BPF_RET];
		fmt = "x";
		break;
	case BPF_LD_W | BPF_ABS:
		op = op_table[BPF_LD_W];
		fmt = bpf_dump_linux_k(bpf.k);
		break;
	case BPF_LD_H | BPF_ABS:
		op = op_table[BPF_LD_H];
		fmt = bpf_dump_linux_k(bpf.k);
		break;
	case BPF_LD_B | BPF_ABS:
		op = op_table[BPF_LD_B];
		fmt = bpf_dump_linux_k(bpf.k);
		break;
	case BPF_LD_W | BPF_LEN:
		op = op_table[BPF_LD_W];
		fmt = "#len";
		break;
	case BPF_LD_W | BPF_IND:
		op = op_table[BPF_LD_W];
		fmt = "[x + %d]";
		break;
	case BPF_LD_H | BPF_IND:
		op = op_table[BPF_LD_H];
		fmt = "[x + %d]";
		break;
	case BPF_LD_B | BPF_IND:
		op = op_table[BPF_LD_B];
		fmt = "[x + %d]";
		break;
	case BPF_LD | BPF_IMM:
		op = op_table[BPF_LD_W];
		fmt = "#0x%x";
		break;
	case BPF_LDX | BPF_IMM:
		op = op_table[BPF_LDX];
		fmt = "#0x%x";
		break;
	case BPF_LDX_B | BPF_MSH:
		op = op_table[BPF_LDX_B];
		fmt = "4*([%d]&0xf)";
		break;
	case BPF_LD | BPF_MEM:
		op = op_table[BPF_LD_W];
		fmt = "M[%d]";
		break;
	case BPF_LDX | BPF_MEM:
		op = op_table[BPF_LDX];
		fmt = "M[%d]";
		break;
	case BPF_ST:
		op = op_table[BPF_ST];
		fmt = "M[%d]";
		break;
	case BPF_STX:
		op = op_table[BPF_STX];
		fmt = "M[%d]";
		break;
	case BPF_JMP_JA:
		op = op_table[BPF_JMP_JA];
		fmt = "%d";
		v = n + 1 + bpf.k;
		break;
	case BPF_JMP_JGT | BPF_K:
		op = op_table[BPF_JMP_JGT];
		fmt = "#0x%x";
		break;
	case BPF_JMP_JGE | BPF_K:
		op = op_table[BPF_JMP_JGE];
		fmt = "#0x%x";
		break;
	case BPF_JMP_JEQ | BPF_K:
		op = op_table[BPF_JMP_JEQ];
		fmt = "#0x%x";
		break;
	case BPF_JMP_JSET | BPF_K:
		op = op_table[BPF_JMP_JSET];
		fmt = "#0x%x";
		break;
	case BPF_JMP_JGT | BPF_X:
		op = op_table[BPF_JMP_JGT];
		fmt = "x";
		break;
	case BPF_JMP_JGE | BPF_X:
		op = op_table[BPF_JMP_JGE];
		fmt = "x";
		break;
	case BPF_JMP_JEQ | BPF_X:
		op = op_table[BPF_JMP_JEQ];
		fmt = "x";
		break;
	case BPF_JMP_JSET | BPF_X:
		op = op_table[BPF_JMP_JSET];
		fmt = "x";
		break;
	case BPF_ALU_ADD | BPF_X:
		op = op_table[BPF_ALU_ADD];
		fmt = "x";
		break;
	case BPF_ALU_SUB | BPF_X:
		op = op_table[BPF_ALU_SUB];
		fmt = "x";
		break;
	case BPF_ALU_MUL | BPF_X:
		op = op_table[BPF_ALU_MUL];
		fmt = "x";
		break;
	case BPF_ALU_DIV | BPF_X:
		op = op_table[BPF_ALU_DIV];
		fmt = "x";
		break;
	case BPF_ALU_MOD | BPF_X:
		op = op_table[BPF_ALU_MOD];
		fmt = "x";
		break;
	case BPF_ALU_AND | BPF_X:
		op = op_table[BPF_ALU_AND];
		fmt = "x";
		break;
	case BPF_ALU_OR | BPF_X:
		op = op_table[BPF_ALU_OR];
		fmt = "x";
		break;
	case BPF_ALU_XOR | BPF_X:
		op = op_table[BPF_ALU_XOR];
		fmt = "x";
		break;
	case BPF_ALU_LSH | BPF_X:
		op = op_table[BPF_ALU_LSH];
		fmt = "x";
		break;
	case BPF_ALU_RSH | BPF_X:
		op = op_table[BPF_ALU_RSH];
		fmt = "x";
		break;
	case BPF_ALU_ADD | BPF_K:
		op = op_table[BPF_ALU_ADD];
		fmt = "#%d";
		break;
	case BPF_ALU_SUB | BPF_K:
		op = op_table[BPF_ALU_SUB];
		fmt = "#%d";
		break;
	case BPF_ALU_MUL | BPF_K:
		op = op_table[BPF_ALU_MUL];
		fmt = "#%d";
		break;
	case BPF_ALU_DIV | BPF_K:
		op = op_table[BPF_ALU_DIV];
		fmt = "#%d";
		break;
	case BPF_ALU_MOD | BPF_K:
		op = op_table[BPF_ALU_MOD];
		fmt = "#%d";
		break;
	case BPF_ALU_AND | BPF_K:
		op = op_table[BPF_ALU_AND];
		fmt = "#0x%x";
		break;
	case BPF_ALU_OR | BPF_K:
		op = op_table[BPF_ALU_OR];
		fmt = "#0x%x";
		break;
	case BPF_ALU_XOR | BPF_K:
		op = op_table[BPF_ALU_XOR];
		fmt = "#0x%x";
		break;
	case BPF_ALU_LSH | BPF_K:
		op = op_table[BPF_ALU_LSH];
		fmt = "#%d";
		break;
	case BPF_ALU_RSH | BPF_K:
		op = op_table[BPF_ALU_RSH];
		fmt = "#%d";
		break;
	case BPF_ALU_NEG:
		op = op_table[BPF_ALU_NEG];
		fmt = "";
		break;
	case BPF_MISC_TAX:
		op = op_table[BPF_MISC_TAX];
		fmt = "";
		break;
	case BPF_MISC_TXA:
		op = op_table[BPF_MISC_TXA];
		fmt = "";
		break;
	}

	slprintf_nocheck(operand, sizeof(operand), fmt, v);
	slprintf_nocheck(image, sizeof(image),
			 (BPF_CLASS(bpf.code) == BPF_JMP &&
			  BPF_OP(bpf.code) != BPF_JA) ?
			 " L%d: %s %s, L%d, L%d" : " L%d: %s %s",
			 n, op, operand, n + 1 + bpf.jt, n + 1 + bpf.jf);
	return image;
}

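/* Print a disassembly of the whole filter program, one line per
 * instruction.
 */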
void bpf_dump_all(struct sock_fprog *bpf)
{
	int i;

	for (i = 0; i < bpf->len; ++i)
		printf("%s\n", __bpf_dump(bpf->filter[i], i));
}

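/* Attach the filter program to a socket via SO_ATTACH_FILTER. A program
 * that starts with "ret #0xffffffff" accepts every packet anyway, so it
 * is not attached at all.
 */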
void bpf_attach_to_sock(int sock, struct sock_fprog *bpf)
{
	int ret;

	if (bpf->filter[0].code == BPF_RET &&
	    bpf->filter[0].k == 0xFFFFFFFF)
		return;

	ret = setsockopt(sock, SOL_SOCKET, SO_ATTACH_FILTER,
			 bpf, sizeof(*bpf));
	if (unlikely(ret < 0))
		panic("Cannot attach filter to socket!\n");
}

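/* Detach any attached filter from the socket. SO_DETACH_FILTER ignores
 * the option value; a dummy is passed to satisfy setsockopt().
 */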
void bpf_detach_from_sock(int sock)
{
	int ret, empty = 0;

	ret = setsockopt(sock, SOL_SOCKET, SO_DETACH_FILTER,
			 &empty, sizeof(empty));
	if (unlikely(ret < 0))
		panic("Cannot detach filter from socket!\n");
}

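/* Enable the kernel's BPF JIT compiler (net.core.bpf_jit_enable). */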
int enable_kernel_bpf_jit_compiler(void)
{
	return sysctl_set_int("net/core/bpf_jit_enable", 1);
}

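/* Sanity check a filter program: scratch memory accesses must stay
 * within BPF_MEMWORDS, constant division by zero is rejected, jump
 * targets must lie inside the program, and the last instruction must
 * be a return. Returns 1 if the program looks valid, 0 otherwise.
 */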
int __bpf_validate(const struct sock_fprog *bpf)
{
	uint32_t i, from;
	const struct sock_filter *p;

	if (!bpf)
		return 0;
	if (bpf->len < 1)
		return 0;

	for (i = 0; i < bpf->len; ++i) {
		p = &bpf->filter[i];
		switch (BPF_CLASS(p->code)) {
			/* Check that memory operations use valid addresses. */
		case BPF_LD:
		case BPF_LDX:
			switch (BPF_MODE(p->code)) {
			case BPF_IMM:
				break;
			case BPF_ABS:
			case BPF_IND:
			case BPF_MSH:
				/* There's no maximum packet data size
				 * in userland.  The runtime packet length
				 * check suffices.
				 */
				break;
			case BPF_MEM:
				if (p->k >= BPF_MEMWORDS)
					return 0;
				break;
			case BPF_LEN:
				break;
			default:
				return 0;
			}
			break;
		case BPF_ST:
		case BPF_STX:
			if (p->k >= BPF_MEMWORDS)
				return 0;
			break;
		case BPF_ALU:
			switch (BPF_OP(p->code)) {
			case BPF_ADD:
			case BPF_SUB:
			case BPF_MUL:
			case BPF_OR:
			case BPF_XOR:
			case BPF_AND:
			case BPF_LSH:
			case BPF_RSH:
			case BPF_NEG:
				break;
			case BPF_DIV:
			case BPF_MOD:
				/* Check for constant division by 0 (undefined
				 * for div and mod). BPF_SRC() selects between
				 * the K and X operand; BPF_RVAL() is only
				 * meaningful for return instructions.
				 */
				if (BPF_SRC(p->code) == BPF_K && p->k == 0)
					return 0;
				break;
			default:
				return 0;
			}
			break;
		case BPF_JMP:
			/* Check that jumps are within the code block,
			 * and that unconditional branches don't go
			 * backwards as a result of an overflow.
			 * Unconditional branches have a 32-bit offset,
			 * so they could overflow; we check to make
			 * sure they don't.  Conditional branches have
			 * an 8-bit offset, and the from address is <=
			 * BPF_MAXINSNS, and we assume that BPF_MAXINSNS
			 * is sufficiently small that adding 255 to it
			 * won't overflow.
			 *
			 * We know that len is <= BPF_MAXINSNS, and we
			 * assume that BPF_MAXINSNS is < the maximum size
			 * of a u_int, so that i + 1 doesn't overflow.
			 *
			 * For userland, we don't know that the from
			 * or len are <= BPF_MAXINSNS, but we know that
			 * from <= len, and, except on a 64-bit system,
			 * it's unlikely that len, if it truly reflects
			 * the size of the program we've been handed,
			 * will be anywhere near the maximum size of
			 * a u_int.  We also don't check for backward
			 * branches, as we currently support them in
			 * userland for the protochain operation.
			 */
			from = i + 1;
			switch (BPF_OP(p->code)) {
			case BPF_JA:
				if (from + p->k >= bpf->len)
					return 0;
				break;
			case BPF_JEQ:
			case BPF_JGT:
			case BPF_JGE:
			case BPF_JSET:
				if (from + p->jt >= bpf->len ||
				    from + p->jf >= bpf->len)
					return 0;
				break;
			default:
				return 0;
			}
			break;
		case BPF_RET:
			break;
		case BPF_MISC:
			break;
		}
	}

	return BPF_CLASS(bpf->filter[bpf->len - 1].code) == BPF_RET;
}

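/* Userspace interpreter for classic BPF, derived from the BSD
 * bpf_filter(). Returns the filter verdict, i.e. the number of bytes
 * to keep (0 means drop). Linux ancillary loads (SKF_AD_*) are not
 * supported here.
 */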
uint32_t bpf_run_filter(const struct sock_fprog *fcode, uint8_t *packet,
			size_t plen)
{
	/* XXX: caplen == len */
	uint32_t A, X;
	uint32_t k;
	struct sock_filter *bpf;
	uint32_t mem[BPF_MEMWORDS] = { 0, };

	if (fcode == NULL || fcode->filter == NULL || fcode->len == 0)
		return 0xFFFFFFFF;

	A = 0;
	X = 0;

	bpf = fcode->filter;
	--bpf;
	while (1) {
		++bpf;
		switch (bpf->code) {
		default:
			return 0;
		case BPF_RET | BPF_K:
			return (uint32_t) bpf->k;
		case BPF_RET | BPF_A:
			return (uint32_t) A;
		case BPF_LD_W | BPF_ABS:
			/* No Linux extensions supported here! */
			k = bpf->k;
			if (k + sizeof(int32_t) > plen)
				return 0;
			A = EXTRACT_LONG(&packet[k]);
			continue;
		case BPF_LD_H | BPF_ABS:
			/* No Linux extensions supported here! */
			k = bpf->k;
			if (k + sizeof(short) > plen)
				return 0;
			A = EXTRACT_SHORT(&packet[k]);
			continue;
		case BPF_LD_B | BPF_ABS:
			/* No Linux extensions supported here! */
			k = bpf->k;
			if (k >= plen)
				return 0;
			A = packet[k];
			continue;
		case BPF_LD_W | BPF_LEN:
			A = plen;
			continue;
		case BPF_LDX_W | BPF_LEN:
			X = plen;
			continue;
		case BPF_LD_W | BPF_IND:
			k = X + bpf->k;
			if (k + sizeof(int32_t) > plen)
				return 0;
			A = EXTRACT_LONG(&packet[k]);
			continue;
		case BPF_LD_H | BPF_IND:
			k = X + bpf->k;
			if (k + sizeof(short) > plen)
				return 0;
			A = EXTRACT_SHORT(&packet[k]);
			continue;
		case BPF_LD_B | BPF_IND:
			k = X + bpf->k;
			if (k >= plen)
				return 0;
			A = packet[k];
			continue;
		case BPF_LDX_B | BPF_MSH:
			k = bpf->k;
			if (k >= plen)
				return 0;
			X = (packet[bpf->k] & 0xf) << 2;
			continue;
		case BPF_LD | BPF_IMM:
			A = bpf->k;
			continue;
		case BPF_LDX | BPF_IMM:
			X = bpf->k;
			continue;
		case BPF_LD | BPF_MEM:
			A = mem[bpf->k];
			continue;
		case BPF_LDX | BPF_MEM:
			X = mem[bpf->k];
			continue;
		case BPF_ST:
			mem[bpf->k] = A;
			continue;
		case BPF_STX:
			mem[bpf->k] = X;
			continue;
		case BPF_JMP_JA:
			bpf += bpf->k;
			continue;
		case BPF_JMP_JGT | BPF_K:
			bpf += (A > bpf->k) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JGE | BPF_K:
			bpf += (A >= bpf->k) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JEQ | BPF_K:
			bpf += (A == bpf->k) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JSET | BPF_K:
			bpf += (A & bpf->k) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JGT | BPF_X:
			bpf += (A > X) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JGE | BPF_X:
			bpf += (A >= X) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JEQ | BPF_X:
			bpf += (A == X) ? bpf->jt : bpf->jf;
			continue;
		case BPF_JMP_JSET | BPF_X:
			bpf += (A & X) ? bpf->jt : bpf->jf;
			continue;
		case BPF_ALU_ADD | BPF_X:
			A += X;
			continue;
		case BPF_ALU_SUB | BPF_X:
			A -= X;
			continue;
		case BPF_ALU_MUL | BPF_X:
			A *= X;
			continue;
		case BPF_ALU_DIV | BPF_X:
			if (X == 0)
				return 0;
			A /= X;
			continue;
		case BPF_ALU_MOD | BPF_X:
			if (X == 0)
				return 0;
			A %= X;
			continue;
		case BPF_ALU_AND | BPF_X:
			A &= X;
			continue;
		case BPF_ALU_OR | BPF_X:
			A |= X;
			continue;
		case BPF_ALU_XOR | BPF_X:
			A ^= X;
			continue;
		case BPF_ALU_LSH | BPF_X:
			A <<= X;
			continue;
		case BPF_ALU_RSH | BPF_X:
			A >>= X;
			continue;
		case BPF_ALU_ADD | BPF_K:
			A += bpf->k;
			continue;
		case BPF_ALU_SUB | BPF_K:
			A -= bpf->k;
			continue;
		case BPF_ALU_MUL | BPF_K:
			A *= bpf->k;
			continue;
		case BPF_ALU_DIV | BPF_K:
			A /= bpf->k;
			continue;
		case BPF_ALU_MOD | BPF_K:
			A %= bpf->k;
			continue;
		case BPF_ALU_AND | BPF_K:
			A &= bpf->k;
			continue;
		case BPF_ALU_OR | BPF_K:
			A |= bpf->k;
			continue;
		case BPF_ALU_XOR | BPF_K:
			A ^= bpf->k;
			continue;
		case BPF_ALU_LSH | BPF_K:
			A <<= bpf->k;
			continue;
		case BPF_ALU_RSH | BPF_K:
			A >>= bpf->k;
			continue;
		case BPF_ALU_NEG:
			A = -A;
			continue;
		case BPF_MISC_TAX:
			X = A;
			continue;
		case BPF_MISC_TXA:
			A = X;
			continue;
		}
	}
}

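/* Load a filter program into *bpf. Without a rulefile a single
 * accept-all instruction is installed. Otherwise the file (or stdin
 * for "-") is parsed line by line for opcodes in the C array style
 * "{ 0x28, 0, 0, 0x0000000c },", e.g. as emitted by tcpdump -dd. If
 * the file cannot be opened, the string is treated as a filter
 * expression and handed to bpf_try_compile(). A program read from a
 * file is validated before being accepted.
 */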
void bpf_parse_rules(char *rulefile, struct sock_fprog *bpf, uint32_t link_type)
{
	int ret;
	char buff[256];
	struct sock_filter sf_single = { 0x06, 0, 0, 0xFFFFFFFF };
	FILE *fp;

	fmemset(bpf, 0, sizeof(*bpf));

	if (rulefile == NULL) {
		bpf->len = 1;
		bpf->filter = xmalloc(sizeof(sf_single));

		fmemcpy(&bpf->filter[0], &sf_single, sizeof(sf_single));
		return;
	}

	if (!strcmp(rulefile, "-"))
		fp = stdin;
	else
		fp = fopen(rulefile, "r");

	if (!fp) {
		bpf_try_compile(rulefile, bpf, link_type);
		return;
	}

	fmemset(buff, 0, sizeof(buff));
	while (fgets(buff, sizeof(buff), fp) != NULL) {
		buff[sizeof(buff) - 1] = 0;

		if (buff[0] != '{') {
			fmemset(buff, 0, sizeof(buff));
			continue;
		}

		/* Scan into correctly sized temporaries; the sock_filter
		 * members are only 16/8 bits wide.
		 */
		unsigned int code, jt, jf, k;

		ret = sscanf(buff, "{ 0x%x, %u, %u, 0x%08x },",
			     &code, &jt, &jf, &k);
		if (unlikely(ret != 4))
			panic("BPF syntax error!\n");

		fmemset(&sf_single, 0, sizeof(sf_single));
		sf_single.code = (uint16_t) code;
		sf_single.jt = (uint8_t) jt;
		sf_single.jf = (uint8_t) jf;
		sf_single.k = (uint32_t) k;

		bpf->len++;
		bpf->filter = xrealloc(bpf->filter,
				       bpf->len * sizeof(sf_single));

		fmemcpy(&bpf->filter[bpf->len - 1], &sf_single,
			sizeof(sf_single));
		fmemset(buff, 0, sizeof(buff));
	}

	if (fp != stdin)
		fclose(fp);

	if (unlikely(__bpf_validate(bpf) == 0))
		panic("This is not a valid BPF program!\n");
}
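
/*
 * Typical usage sketch (rulefile/expression, link type and socket setup
 * are assumed to exist elsewhere):
 *
 *	struct sock_fprog bpf_ops;
 *
 *	bpf_parse_rules(rulefile, &bpf_ops, link_type);
 *	bpf_dump_all(&bpf_ops);
 *	bpf_attach_to_sock(sock, &bpf_ops);
 *	...
 *	bpf_detach_from_sock(sock);
 */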