#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <err.h>
#include <errno.h>

#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <sys/un.h>
#include <signal.h>

/* Filesystem path the server's Unix-domain listening socket is bound to. */
#define SOCKET "/tmp/soc"
/* Size of main()'s receive buffer, i.e. the most bytes taken per read(2). */
#define BUFLEN 1

/*
 * Create a local-domain (Unix) stream socket, bind it to the filesystem
 * path `path` and start listening on it.
 *
 * Any existing file at `path` is unlinked first so that a stale socket
 * file does not make bind(2) fail with EADDRINUSE.
 *
 * Returns the listening socket descriptor.  Exits the process on any
 * error, including a path too long to fit in sockaddr_un.sun_path.
 */
int
sockbind(const char *path)
{
	struct sockaddr_un s;
	size_t pathlen;
	socklen_t len;
	int fd;

	/*
	 * sun_path must hold the path plus its terminating NUL.  strncpy()
	 * would silently truncate and leave it unterminated, after which
	 * SUN_LEN() (which relies on strlen) would read past the array.
	 */
	pathlen = strlen(path);
	if (pathlen >= sizeof(s.sun_path))
		errx(1, "socket path too long: %s", path);

	fd = socket(PF_LOCAL, SOCK_STREAM, 0);
	if (fd == -1)
		err(1, "socket");

	memset(&s, 0, sizeof(s));
	s.sun_family = AF_LOCAL;
	memcpy(s.sun_path, path, pathlen + 1);	/* includes the NUL */
	len = (socklen_t)SUN_LEN(&s);

	/* Failure is deliberately ignored (e.g. no previous socket file). */
	(void)unlink(s.sun_path);

	if (bind(fd, (struct sockaddr *)&s, len) == -1)
		err(1, "bind");

	if (listen(fd, 0) == -1)
		err(1, "listen");

	return (fd);
}

/*
 * Log and close the descriptor held in `client`, then mark the slot
 * empty by storing -1 in its fd (poll(2) skips negative descriptors).
 * The events/revents fields are left untouched.
 */
void
closeconn(struct pollfd *client)
{
	int victim = client->fd;

	printf("Closing connection %d\n", victim);
	close(victim);
	client->fd = -1;
}

/*
 * Accept one pending connection on the listening socket `serv` and store
 * it in `client`, set up to be polled for input.  The accepted socket is
 * also duplicated onto file descriptor 0, so the process's stdin reads
 * from the connection.
 *
 * Returns 0 on success.  If accept(2) fails, a warning is printed,
 * client->fd is left at -1 and -1 is returned.  Exits the process if the
 * socket cannot be duplicated onto stdin.
 */
int
newconn(struct pollfd *client, struct pollfd serv)
{
	struct sockaddr_un cli_sock;
	socklen_t cli_len;

	cli_len = sizeof(cli_sock);
	client->fd = accept(serv.fd, (struct sockaddr *)&cli_sock, &cli_len);
	if (client->fd == -1) {
		warn("accept");
		return -1;
	}
	printf("Accepted %d \n", client->fd);
	client->events = POLLIN | POLLHUP;
	client->revents = 0;

	/*
	 * dup2(2) closes fd 0 itself before reusing it, so no separate
	 * close(0) is needed (and a separate close would destroy the client
	 * if accept() had just returned fd 0).  fflush(stdin) is not used:
	 * flushing an input stream is undefined behaviour in ISO C.
	 *
	 * NOTE(review): fd 0 holds its own reference to the socket, so
	 * closing client->fd alone does not release the connection; it stays
	 * open until fd 0 is closed or replaced by the next dup2().
	 */
	if (dup2(client->fd, 0) == -1)
		err(1, "Couldn't dup stdin");

	printf("Redirecting stdin through the socket\n");

	return 0;
}

/*
 * Single-client relay: listen on the Unix-domain socket at SOCKET and copy
 * every byte received from the connected client to stdout.  While a client
 * is connected, further connection attempts stay pending in the backlog.
 *
 * conn[0] is the listening socket and conn[1] the client slot.  An empty
 * slot holds fd -1, which poll(2) ignores (its revents is set to 0), so
 * both entries can always be passed to poll(2).
 */
int
main(int argc, char *argv[])
{
	struct pollfd conn[2];	/* [0] listener, [1] client */
	char buf[BUFLEN];
	ssize_t count;
	int hasconn, conn_rdy;

	(void)argc;
	(void)argv;

	hasconn = 0;

	conn[0].fd = sockbind(SOCKET);	/* Server socket */
	conn[0].events = POLLIN | POLLPRI;
	conn[0].revents = 0;

	conn[1].fd = -1;		/* No client yet */
	conn[1].events = 0;
	conn[1].revents = 0;

	printf("Starting up server - fd=%d\n", conn[0].fd);

	for (;;) {
		/*
		 * Watch the listener only while idle: with a client connected,
		 * a pending connection would keep POLLIN asserted and make
		 * poll(2) return immediately on every iteration.
		 */
		conn[0].events = hasconn ? 0 : (POLLIN | POLLPRI);

		conn_rdy = poll(conn, 2, 1000);
		if (conn_rdy == -1) {
			if (errno == EINTR)
				continue;
			err(1, "poll");
		}
		if (conn_rdy == 0)
			continue;	/* Timed out; nothing ready. */

		if (!hasconn && (conn[0].revents & POLLIN)) {
			newconn(&conn[1], conn[0]);
			hasconn = (conn[1].fd >= 0);
		}

		/*
		 * POLLHUP/POLLERR can be reported without POLLIN; read(2) then
		 * returns 0 or -1 and the connection is torn down instead of
		 * being polled forever.
		 */
		if (hasconn &&
		    (conn[1].revents & (POLLIN | POLLHUP | POLLERR))) {
			count = read(conn[1].fd, buf, sizeof(buf));
			if (count > 0) {
				/* buf is not NUL-terminated: emit exactly count bytes. */
				fwrite(buf, 1, (size_t)count, stdout);
			} else if (count == 0) {
				/* Peer closed its end. */
				closeconn(&conn[1]);
				hasconn = 0;
			} else if (errno != EINTR) {
				/* A reset is a normal disconnect; report anything else. */
				if (errno != ECONNRESET)
					warn("Failed to read from client");
				closeconn(&conn[1]);
				hasconn = 0;
			}
		}
	}

	return 0;
}
