/*
 *  snc - Simple Netcat v0.0
 *  
 *  Another clone of the popular netcat utility.
 *  Supports IPv4 and IPv6.
 *  Written in about half an hour based on my own needs.
 *  Use at your own risk.
 *
 *  Cicuttin Matteo (C) 2007 - matteo.cicuttin@gmail.com
 *  Released under BSD license.
 *
 *  To compile:
 *  gcc -O2 -Wall -pedantic -o snc snc.c
 *
 *  History:
 *  v0.0 (26/06/2007) -		Created snc
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <signal.h>
#include <errno.h>
#include <netdb.h>

/* Option bits collected by main() from the command line. */
#define CONN_V4			0x01	/* -4: IPv4 allowed */
#define CONN_V6			0x02	/* -6: IPv6 allowed */
#define LISTEN_MODE		0x04	/* -l: accept one connection instead of connecting */
#define PORT_SPECIFIED	0x08	/* -p was given */

/*
 * Test whether a flag word permits IPv4 / IPv6.  The argument is fully
 * parenthesized so that an expression such as OK_V6(a | b) tests the whole
 * value rather than binding '&' to the last operand only.
 */
#define OK_V4(x)		((x) & CONN_V4)
#define OK_V6(x)		((x) & CONN_V6)

#define BUFLEN		16384	/* size of the copy buffer used by xfer_data() */
#define BACKLOG		10		/* queue length passed to listen() */

/* Prototypes for helpers defined further down in this file. */
int	get_listening_socket_v4(unsigned short port);
int	get_listening_socket_v6(unsigned short port);
void	die(char *reason);

/*
 * usage - print command-line help to stderr and terminate with exit
 * status -1.  Never returns.
 *
 * progname: name to show in the synopsis (normally argv[0]).
 */
void
usage(char *progname)
{
	fprintf(stderr, "\n%s [-l46] -p <port> [hostname]\n\n", progname);
	fprintf(stderr, "-p <port>:      Specifies port\n");
	fprintf(stderr, "-l:             Listen mode\n");
	/* Only listen mode falls back to IPv4; client mode tries both families
	 * when neither -4 nor -6 is given. */
	fprintf(stderr, "-4:             Use IPv4 (listen mode default)\n");
	fprintf(stderr, "-6:             Use IPv6\n");
	fprintf(stderr, "hostname:       Host to connect to in client mode\n\n");
	fprintf(stderr, "In listen mode -4 and -6 are mutually exclusive\n");
	fprintf(stderr, "Without -4 or -6, client mode tries both protocols\n");
	exit(-1);
}

/*
 * die - report a failed system/library call and terminate.
 *
 * reason: label printed before the errno description (typically the name
 *         of the call that failed).  Exits with status -1; never returns.
 */
void
die(char *reason)
{
	perror(reason);		/* "<reason>: <errno text>" on stderr */
	exit(-1);
}

/*
 * get_listening_socket_v4 - create a TCP socket bound to INADDR_ANY:port
 * and put it into the listening state with a queue of BACKLOG.
 *
 * SO_REUSEADDR is enabled (typically so a restart can rebind a port that
 * still has connections in TIME_WAIT).  Every failure terminates the
 * program through die(), so the returned value is always a valid
 * descriptor.
 */
int
get_listening_socket_v4(unsigned short port)
{
	int s, on = 1;
	struct sockaddr_in sa;

	s = socket(AF_INET, SOCK_STREAM, 0);
	if (s == -1)
		die("socket");

	if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1)
		die("setsockopt");

	/* Clear the whole structure so sin_zero (and sin_len on BSD-derived
	 * systems) do not carry stack garbage into bind(). */
	memset(&sa, 0, sizeof(sa));
	sa.sin_family = AF_INET;
	sa.sin_port = htons(port);
	/* s_addr is a 32-bit field: convert with htonl, not htons. */
	sa.sin_addr.s_addr = htonl(INADDR_ANY);

	if (bind(s, (struct sockaddr *) &sa, sizeof(sa)) == -1)
		die("bind");

	if (listen(s, BACKLOG) == -1)
		die("listen");

	return s;
}

/*
 * get_listening_socket_v6 - create a TCP socket bound to [::]:port and put
 * it into the listening state with a queue of BACKLOG.
 *
 * SO_REUSEADDR is enabled, and IPV6_V6ONLY is set so the socket handles
 * IPv6 only and does not take IPv4-mapped connections.  Every failure
 * terminates the program through die(), so the returned value is always a
 * valid descriptor.
 */
int
get_listening_socket_v6(unsigned short port)
{
	int s, on = 1;
	struct sockaddr_in6 sa;

	s = socket(AF_INET6, SOCK_STREAM, 0);
	if (s == -1)
		die("socket");

	if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1)
		die("setsockopt");

	if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1)
		die("setsockopt");

	/* Clear the whole structure: sin6_flowinfo and sin6_scope_id must be
	 * zero rather than stack garbage when passed to bind(). */
	memset(&sa, 0, sizeof(sa));
	sa.sin6_family = AF_INET6;
	sa.sin6_port = htons(port);
	sa.sin6_addr = in6addr_any;

	if (bind(s, (struct sockaddr *) &sa, sizeof(sa)) == -1)
		die("bind");

	if (listen(s, BACKLOG) == -1)
		die("listen");

	return s;
}

/*
 * itoa - write the decimal representation of n into str, NUL-terminated.
 *
 * The caller must supply a buffer large enough for an optional '-', every
 * decimal digit of an int, and the terminating NUL (12 bytes for a 32-bit
 * int).
 */
void
itoa(int n, char *str)
{
	/* Take the magnitude in unsigned arithmetic so that negating INT_MIN
	 * does not overflow (signed overflow is undefined behavior). */
	unsigned int mag = (n < 0) ? 0u - (unsigned int) n : (unsigned int) n;
	size_t len = 0, i, j;
	char temp;

	/* Digits come out least-significant first. */
	do
	{
		str[len++] = (char) ('0' + mag % 10u);
		mag /= 10u;
	} while (mag > 0u);

	if (n < 0)
		str[len++] = '-';
	str[len] = '\0';

	/* Reverse in place; len >= 1 here, so len - 1 cannot wrap. */
	for (i = 0, j = len - 1; i < j; i++, j--)
	{
		temp = str[i];
		str[i] = str[j];
		str[j] = temp;
	}
}

/*
 * connect_host - open a TCP connection to host:port.
 *
 * host:     name or numeric address passed to getaddrinfo().
 * port:     destination port.
 * okprotos: bit set of allowed families (CONN_V4 / CONN_V6, tested with
 *           OK_V4 / OK_V6); resolved addresses of other families are skipped.
 *
 * Addresses are tried in getaddrinfo() order and the first successful
 * connect() wins; a "Connected to ..." line is then printed to stderr.
 * Returns the connected socket, or -1 after printing a diagnostic to stderr.
 */
int connect_host(char *host, unsigned short port, char okprotos)
{
	struct addrinfo hints, *res, *res0 = NULL;
	int ai_ret;
	int s = -1;
	int saved_errno = 0;
	char *cause = NULL;
	char *proto = NULL;
	char c_port[sizeof("65535")];	/* largest port number plus NUL */

	memset(&hints, 0, sizeof(hints));
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_family = AF_UNSPEC;

	snprintf(c_port, sizeof(c_port), "%u", (unsigned int) port);

	ai_ret = getaddrinfo(host, c_port, &hints, &res0);
	if (ai_ret != 0)
	{
		/* getaddrinfo reports errors through its return value (not errno)
		 * and leaves res0 unset, so there is nothing to free here. */
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(ai_ret));
		return -1;
	}

	/* find an entry that matches our expectations */
	for (res = res0; res; res = res->ai_next)
	{
		switch (res->ai_family)
		{
			case AF_INET:
				if (!OK_V4(okprotos)) continue;
				proto = "IPv4";
			break;

			case AF_INET6:
				if (!OK_V6(okprotos)) continue;
				proto = "IPv6";
			break;

			default:
				continue;
		}

		s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (s < 0)
		{
			cause = "socket";
			saved_errno = errno;
			continue;
		}

		if (connect(s, res->ai_addr, res->ai_addrlen) < 0)
		{
			cause = "connect";
			saved_errno = errno;	/* close() below may overwrite errno */
			close(s);
			s = -1;
			continue;
		}

		break;
	} /* for loop */

	freeaddrinfo(res0);

	if (s < 0)
	{
		/* cause stays NULL when no address of an allowed family was found. */
		if (cause == NULL)
			fprintf(stderr, "Cannot find a valid address\n");
		else
			fprintf(stderr, "%s: %s\n", cause, strerror(saved_errno));
		return -1;
	}

	fprintf(stderr, "Connected to %s via %s\n", host, proto);
	return s;
}

/*
 * xfer_write_all - write exactly len bytes of buf to fd, retrying on short
 * writes and EINTR.  Returns 0 on success, -1 on error (errno from write).
 */
static int
xfer_write_all(int fd, const char *buf, size_t len)
{
	ssize_t n;

	while (len > 0)
	{
		n = write(fd, buf, len);
		if (n == -1)
		{
			if (errno == EINTR)
				continue;
			return -1;
		}
		buf += n;
		len -= (size_t) n;
	}
	return 0;
}

/*
 * xfer_data - copy everything readable from srcfd to dstfd until srcfd
 * reports end-of-file, then print the byte count to stderr.
 *
 * Read or write errors (other than EINTR) terminate the program via die().
 */
void
xfer_data(int srcfd, int dstfd)
{
	char buf[BUFLEN];
	ssize_t read_bytes;
	unsigned long total_bytes = 0;

	for (;;)
	{
		read_bytes = read(srcfd, buf, sizeof(buf));
		if (read_bytes == -1)
		{
			if (errno == EINTR)
				continue;	/* interrupted by a signal: just retry */
			die("read");
		}
		if (read_bytes == 0)
			break;		/* EOF */

		total_bytes += (unsigned long) read_bytes;
		/* write() may accept fewer bytes than requested; that is not an
		 * error, so keep writing until the whole chunk is out. */
		if (xfer_write_all(dstfd, buf, (size_t) read_bytes) == -1)
			die("write");
	}

	fprintf(stderr, "%lu bytes transferred\n", total_bytes);
}

/*
 * listen_mode - accept a single TCP connection on port and copy everything
 * it sends to stdout.
 *
 * IPv6 is used only when flags requests IPv6 without IPv4; otherwise
 * (IPv4 requested, or nothing requested) an IPv4 socket is used.
 * Both sockets are closed before returning.
 */
void listen_mode(unsigned short port, int flags)
{
	int s, srcfd;
	int dstfd = STDOUT_FILENO;
	struct sockaddr_storage remote;	/* large enough for either family */
	socklen_t addrlen = sizeof(remote);

	if (OK_V6(flags) && !OK_V4(flags))
		s = get_listening_socket_v6(port);
	else	/* default is IPv4 */
		s = get_listening_socket_v4(port);

	if (s == -1)
		exit(-1);

	srcfd = accept(s, (struct sockaddr *) &remote, &addrlen);
	if (srcfd == -1)
		die("accept");

	/* Only one peer is ever served, so stop listening before the transfer. */
	close(s);

	xfer_data(srcfd, dstfd);
	close(srcfd);
}

/*
 * main - parse the command line and run listen mode or client mode.
 *
 *   snc [-l46] -p <port> [hostname]
 *
 * Listen mode (-l): accept one connection and copy its data to stdout.
 * Client mode:      connect to hostname and copy stdin to the connection;
 *                   without -4/-6 both address families are allowed.
 */
int
main(int argc, char **argv)
{
	int ch, flags = 0, port = 0;
	char *end;
	long val;

	/* parse arguments */
	while ((ch = getopt(argc, argv, "lp:46")) != -1)
	{
		switch(ch)
		{
			case '4':
				flags |= CONN_V4;
				break;

			case '6':
				flags |= CONN_V6;
				break;

			case 'l':
				flags |= LISTEN_MODE;
				break;

			case 'p':
				/* strtol instead of atoi: reject empty strings, trailing
				 * junk and out-of-range values instead of silently using 0. */
				errno = 0;
				val = strtol(optarg, &end, 10);
				if (end == optarg || *end != '\0' || errno == ERANGE ||
				    val < 1 || val > 65535)
				{
					fprintf(stderr, "Invalid port\n");
					exit(-1);
				}
				port = (int) val;
				flags |= PORT_SPECIFIED;
				break;

			default:
				usage(argv[0]);	/* does not return */
		}
	}

	/* verify arguments */
	if ( !(flags & LISTEN_MODE) && !argv[optind] )
	{
		fprintf(stderr, "Remote host not specified\n");
		usage(argv[0]);
	}

	if ( !(flags & PORT_SPECIFIED) )
	{
		fprintf(stderr, "Port not specified\n");
		usage(argv[0]);
	}

	if ( (flags & LISTEN_MODE) && (flags & CONN_V4) && (flags & CONN_V6) )
	{
		fprintf(stderr, "In listen mode only one protocol is allowed\n");
		usage(argv[0]);
	}

	/* start the thing */
	if (flags & LISTEN_MODE)
	{
		listen_mode(port, flags);
	}
	else
	{
		int srcfd = STDIN_FILENO, dstfd;

		/* No family requested: let connect_host try both. */
		if ( !(OK_V4(flags)) && !(OK_V6(flags)) )
			flags |= (CONN_V4 | CONN_V6);

		dstfd = connect_host(argv[optind], port,
		    (char) (flags & (CONN_V4 | CONN_V6)));
		if (dstfd == -1)
			exit(-1);

		xfer_data(srcfd, dstfd);
	}

	return 0;
}



