/*
    linkmon.c - see usage() for details
    Copyright (C) 2009  Thomas Andrews

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#include <errno.h>
#include <getopt.h>
#include <limits.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <linux/ip.h>
#include <linux/icmp.h>

/* Program name: the syslog identity passed to openlog() and the name shown by usage(). */
#define PROGNAME "linkmon"
/* Defaults used when -t, -n and -i are not given on the command line. */
#define DEFAULT_TIMEOUT  5  /* Seconds to wait for each echo reply (-t) */
#define DEFAULT_FAILURES 5  /* Consecutive failed attempts before the command is executed (-n) */
#define DEFAULT_INTERVAL 5  /* Seconds to sleep between ping attempts (-i) */
#define FOREVER 1           /* Loop condition of main()'s ping loop: set to 0 to ping once and exit, for ease of debugging */

/*
 * Importance levels for d_printf(); a smaller number is more important.
 * The numeric values are what the -d option accepts, and a message is only
 * emitted when its level is <= cur_dbg_lvl.
 */
enum debuglevel {
    no_output = 0,  /* Silences everything; never passed to d_printf() itself */
    panic = 1,
    error = 2,
    warning = 3,
    info = 4,
    verbose = 5
};
typedef enum debuglevel dbglevel;

unsigned short in_cksum(unsigned short *, int);
void usage();
char *toip(char *);
void d_printf(dbglevel importance, const char *fmt, ...);
void bye(void);

dbglevel cur_dbg_lvl = error;
char *dst_addr;
char *packet;
char *buffer;

int main(int argc, char *argv[])
{
    struct iphdr *ip;
    struct iphdr *ip_reply;
    struct icmphdr *icmp;
    struct sockaddr_in connection;
    int sockfd;
    int optval = 1;
    int addrlen;
    int size;
    char *ptr;
    int res; 
    fd_set ready; 
    struct timeval wait; 
    unsigned short id;
    static int timeout = DEFAULT_TIMEOUT;
    static int interval = DEFAULT_INTERVAL;
    static int failures = DEFAULT_FAILURES;
    int errors = 0;
    int syslog_facility = LOG_DAEMON;
    int c;
    char *command = NULL;
    char *host = NULL;
    int waiting;

    atexit(bye);
    signal(SIGQUIT, exit);
    signal(SIGHUP, exit);
    signal(SIGINT, exit);

    // Check that I'm root
    if (getuid() != 0) {
        fprintf(stderr, "Error: root privelidges needed\n");
        usage();
        exit(EXIT_FAILURE);
    }

    while (1) {
        c = getopt(argc, argv, "c:h:i:n:t:d:l:");
        if (c == -1)
            break;
        switch (c) {
        case 'c':
            command = optarg;
            break;
        case 'h':
            host = optarg;
            break;
        case 'i':
            interval = strtoul(optarg, NULL, 10);
            break;
        case 'n':
            failures = strtoul(optarg, NULL, 10);
            break;
        case 't':
            timeout = strtoul(optarg, NULL, 10);
            break;
        case 'd':
            cur_dbg_lvl = strtoul(optarg, NULL, 10);
            break;
        case 'l':
            syslog_facility = strtoul(optarg, NULL, 10);
            break;
        default:
            usage();
            exit(EXIT_FAILURE);
            break;
        }
    }
    openlog(PROGNAME, LOG_PID|LOG_CONS, syslog_facility);
    syslog(LOG_INFO, "Starting.");

    if (!host) {
        d_printf(panic, "Error: host argument missing\n");
        usage();
        exit(EXIT_FAILURE);
    }
    if (!command) {
        d_printf(panic, "Error: command argument missing\n");
        usage();
        exit(EXIT_FAILURE);
    }

    // Get the IP of the destination host
    ptr = toip(host);
    dst_addr = malloc(1+strlen(ptr));
    strcpy(dst_addr,ptr);
    if (!dst_addr) {
        d_printf(panic, "Error: couldn't find address for %s\n", *(argv + 1));
        exit(EXIT_FAILURE);
    }

    d_printf(info, "Destination address: %s\n", dst_addr);

    int len = (sizeof(struct iphdr) + sizeof(struct icmphdr));
    packet = malloc(len);
    buffer = malloc(len);

    do {
        ip = (struct iphdr *) packet;
        icmp = (struct icmphdr *) (packet + sizeof(struct iphdr));
        memset(packet, 0, len);

        //  Set up the IP packet
        ip->ihl = 5;
        ip->version = 4;
        ip->tos = 0;
        ip->tot_len = len;
        ip->id = htons(0);
        ip->frag_off = 0;
        ip->ttl = 64;
        ip->protocol = IPPROTO_ICMP;
        ip->saddr = 0; // Let the kernel set it
        ip->daddr = inet_addr(dst_addr);
        ip->check = 0; // Let the kernel update it

        if ((sockfd = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP)) == -1) {
            perror("socket");
            exit(EXIT_FAILURE);
        }

        // Stop the kernel from automatically adding a default IP header to the packet
        setsockopt(sockfd, IPPROTO_IP, IP_HDRINCL, &optval, sizeof(int));

        // Create the ICMP part of the packet
        icmp->type = ICMP_ECHO;
        icmp->code = 0;
        id = random();
        icmp->un.echo.id = id;
        icmp->un.echo.sequence = 0;
        icmp->checksum = in_cksum((unsigned short *) icmp, sizeof(struct icmphdr));

        connection.sin_family = AF_INET;
        connection.sin_addr.s_addr = inet_addr(dst_addr);

        // Send the packet
        sendto(sockfd, packet, ip->tot_len, 0, (struct sockaddr *) &connection, sizeof(struct sockaddr));
        d_printf(info, "Sent %d byte packet to %s\n", ip->tot_len, dst_addr);

        FD_ZERO(&ready); 
        FD_SET((unsigned int)sockfd, &ready); 
        memset((char *)&wait, 0, sizeof(wait)); 
        wait.tv_sec = timeout;
        waiting = 1;

        while (waiting) {
            // Wait for a response, for the prescribed timeout period
            res = select(sockfd+1, &ready, NULL, NULL, &wait); 
            if (res == -1) {
                perror("select()");
                exit(EXIT_FAILURE);
            }
            else if (res) {
                // Get the response (which could be "Destination Host Unreachable")
                addrlen = sizeof(connection);
                if ((size = recvfrom(sockfd, buffer, len, 0, (struct sockaddr *) &connection, (socklen_t *) & addrlen)) == -1) {
                    perror("recv");
                } else {
                    struct in_addr ad;
                    ip_reply = (struct iphdr *) buffer;
                    ad.s_addr = ip_reply->saddr;
                    d_printf(info, "Received %d byte reply from %s:\n", size, inet_ntoa(ad));
                    d_printf(verbose, "ID: %d\n", ntohs(ip_reply->id));
                    icmp = (struct icmphdr *) (buffer + sizeof(struct iphdr));
                    d_printf(verbose, "ICMP.type: %d\n", icmp->type);
                    d_printf(verbose, "ICMP.code: %d\n", icmp->code);
                    d_printf(verbose, "ICMP.id: %d (%#x)\n", icmp->un.echo.id, icmp->un.echo.id);
                    if(icmp->un.echo.id != id) {
                        d_printf(verbose, "ICMP ID mismatch (expected %d (%#x))\n", id,id); // Other pings make this happen
                    }
                    if (icmp->type == ICMP_DEST_UNREACH) {
                        d_printf(verbose, "Destination unreachable\n");
                    }
                    if (icmp->type != ICMP_ECHOREPLY) {
                        d_printf(verbose, "ICMP type mismatch\n");  // Other pings make this happen
                    }

                    if (icmp->un.echo.id == id && icmp->type == ICMP_ECHOREPLY) {
                        d_printf(verbose, "ok\n");
                        waiting = 0;
                        errors = 0;
                    }
                }
            } else {
                // No response
                d_printf(warning, "No response\n");
                waiting = 0;
                errors++;
            }
        }
        close(sockfd);
        if (errors >= failures) {
            errors = 0;
            // Execute the command specified..
            d_printf(info, "Executing command \'%s\'\n", command);
            syslog(LOG_INFO, "Executing command \'%s\'", command);
            system(command);
        }
        sleep(interval);
    } while (FOREVER);

    exit(EXIT_SUCCESS);
}

/*
 * atexit() handler: record the shutdown in syslog, close the log, then
 * release the heap buffers main() allocated. The globals start out NULL and
 * free(NULL) does nothing, so this is safe even if main() exits early.
 */
void bye(void)
{
    char *owned[3];
    size_t k;

    syslog(LOG_INFO, "Exiting.");
    closelog();

    /* Same release order as before: address copy, send buffer, receive buffer. */
    owned[0] = dst_addr;
    owned[1] = packet;
    owned[2] = buffer;
    for (k = 0; k < sizeof(owned) / sizeof(owned[0]); k++) {
        free(owned[k]);
    }
}

/*
 * Print the help text to stdout: what the program does, the command-line
 * options, which ones are required, and the default values of the rest.
 */
void usage(void)
{
    fprintf(stdout, "\nThis program sends ICMP ECHO packets to the address of ");
    fprintf(stdout, "\nyour choice and then listens for ICMP REPLY packets.");
    fprintf(stdout, "\nIf there's no response after the specified period it will retry repeatedly");
    fprintf(stdout, "\nuntil the specified count is reached, after which it will execute");
    fprintf(stdout, "\nthe command specified. You can specify the period between pings");
    fprintf(stdout, "\nand also the number of failed attempts to make before running the command.");
    fprintf(stdout, "\nThe syslog 'facility' can be specified too.\n");
    fprintf(stdout, "\nIt doesn't daemonise - it is meant to be run from inittab.");
    fprintf(stdout, "\nIt has to be run with UID 0 (ie root privileges)\n");
    /* Closing '>' fixed: the last placeholder previously ended with ')'. */
    fprintf(stdout, "\nUsage: " PROGNAME " -h <HOST> -c <COMMAND> -i <INTERVAL> -n <FAILURES> -t <TIMEOUT> -d <DEBUGLEVEL> -l <SYSLOG FACILITY>\n\n");
    fprintf(stdout, "eg: " PROGNAME " -h 169.254.1.1 -c 'killall pppd'\n\n");
    fprintf(stdout, "    HOST and COMMAND must be provided\n");
    fprintf(stdout, "    COMMAND must be put into quotes if it has more than one argument\n");
    fprintf(stdout, "    INTERVAL, FAILURES, TIMEOUT, DEBUGLEVEL, and SYSLOG FACILITY are optional\n\n");
    fprintf(stdout, "    INTERVAL defaults to %d seconds between pings\n", DEFAULT_INTERVAL);
    fprintf(stdout, "    TIMEOUT defaults to %d seconds (awaiting response)\n", DEFAULT_TIMEOUT);
    fprintf(stdout, "    FAILURES defaults to %d attempts\n", DEFAULT_FAILURES);
    fprintf(stdout, "    DEBUGLEVEL defaults to %d (errors) 0=off 1=panic 2=error 3=warning 4=info 5=verbose \n", error);
    fprintf(stdout, "    SYSLOG FACILITY defaults to %d\n\n", LOG_DAEMON);
}

/*
 * Resolve 'address' (a DNS name or a dotted-quad string) to its first IPv4
 * address, formatted as a dotted quad.
 *
 * Returns a pointer to inet_ntoa()'s static buffer, which the next
 * inet_ntoa() call overwrites, so callers must copy the text to keep it.
 * Returns NULL if the name cannot be resolved or the result is not IPv4.
 */
char *toip(char *address)
{
    struct hostent *h;
    struct in_addr addr;

    h = gethostbyname(address);
    if (!h || h->h_addrtype != AF_INET || h->h_length != (int) sizeof(addr) || !h->h_addr_list[0])
        return NULL;

    // Copy rather than cast: h_addr_list entries are char * with no
    // alignment guarantee for struct in_addr.
    memcpy(&addr, h->h_addr_list[0], sizeof(addr));
    return inet_ntoa(addr);
}

/*
 * Internet checksum (RFC 1071) of 'len' bytes starting at 'addr': the
 * one's-complement of the one's-complement sum of the data taken as 16-bit
 * words, with a trailing odd byte padded with a zero byte.
 *
 * Words are read in host byte order and the result comes back in that same
 * order, so it can be stored directly into a header checksum field.
 * A len of zero or less yields 0xffff.
 */
unsigned short in_cksum(unsigned short *addr, int len)
{
    const unsigned char *p = (const unsigned char *) addr;
    unsigned long long sum = 0;   // at least 64 bits: the running total cannot overflow
    unsigned short word;
    int nleft = len;

    // Add sequential 16-bit words. memcpy sidesteps the strict-aliasing and
    // alignment problems of reading a caller's struct through unsigned short *.
    while (nleft > 1) {
        memcpy(&word, p, sizeof(word));
        sum += word;
        p += 2;
        nleft -= 2;
    }
    // Mop up an odd byte: it becomes the first (lowest-addressed) byte of a
    // zero-padded word.
    if (nleft == 1) {
        word = 0;
        memcpy(&word, p, 1);
        sum += word;
    }

    // Fold the carries above bit 15 back into the low 16 bits until none remain.
    while (sum >> 16)
        sum = (sum & 0xffff) + (sum >> 16);

    return (unsigned short) ~sum;   // Complement
}

/*
 * Debug/diagnostic output: when 'importance' is at or below cur_dbg_lvl,
 * send the formatted message to syslog (at LOG_INFO) and echo it to stdout.
 *
 * importance: level of this message (no_output is not meant to be passed).
 * fmt, ...:   printf-style format string and arguments.
 *
 * The syslog copy is formatted into a 500-byte buffer and truncated to fit;
 * the stdout copy is printed in full.
 */
void d_printf(dbglevel importance, const char *fmt, ...) {
    va_list ap;
    char buf[500];

    if (importance > cur_dbg_lvl)
        return;

    // Log to the syslog. The formatted text goes in as an argument, never as
    // the format: it can contain '%' from outside data (e.g. the command).
    va_start(ap, fmt);
    vsnprintf(buf, sizeof(buf), fmt, ap);
    va_end(ap);
    syslog(LOG_INFO, "%s", buf);

    // Log to the console
    va_start(ap, fmt);
    vprintf(fmt, ap);
    va_end(ap);
}
