/*
 * Copyright (c) 2019 Daniel Wilkins <tekk@linuxmail.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#define _DEFAULT_SOURCE
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>

#include <arpa/inet.h>

#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
#include <grp.h>
#include <netdb.h>
#include <pwd.h>
#include <signal.h>
#include <unistd.h>

#include <ctype.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>


// Number of elements in an actual array (a pointer gives a meaningless result),
// as an int so it compares cleanly with int loop counters.
#define LENGTH(x) ((int)(sizeof(x) / sizeof((x)[0])))

/*
 * A trivial, single file gopher server.
 * Notable limitation: POSIX filesystems don't record file types, so there's no accurate way to know what kind of
 * file something is. Directories are served as menus; regular files are typed heuristically by matching their names
 * against the extension table (`mappings`) in config.h, and anything that matches no entry is served as text.
 */


/*
 * Log verbosity levels, most severe first. emit()/emit_syscall() print an
 * entry when its level is <= LOGLEVEL (set in config.h). ERR is the fatal
 * level: reporting it exits the process.
 */
typedef enum {
    ERR,    // fatal: the process exits after reporting
    WARN,
    INFO,
    DEBUG   // most verbose
} loglevel;

/*
 * One accepted client: the socket returned by accept() and the peer address
 * it filled in.
 * NOTE(review): struct sockaddr is too small to hold an IPv6 peer address
 * (struct sockaddr_in6), so accept() truncates it here; struct
 * sockaddr_storage would fit either family.
 */
struct connection {
    int fd;                 // connected client socket; handle_connection() closes it
    struct sockaddr addr;   // peer address (possibly truncated, see above)
    socklen_t addrlen;      // size of addr; value-result argument to accept()
};

/*
 * One entry of the `mappings` table from config.h: files whose names end in
 * `extension` are served with gopher item type `filetype`
 * (see extension_matches()).
 */
struct mapping {
    char filetype;       // gopher item type character, e.g. '0', '1' or '9'
    char extension[16];  // filename suffix; passed to strlen/strcmp, so it must be NUL-terminated (max 15 chars)
};

#include "config.h"

// Forward declaration: list_directory() calls this before its definition further down.
char extension_matches(const char *path);

// Spit out a log entry. Maybe we'll add logfile support later but honestly you could just do tgif 2>...
//
// The message is written to stderr when its level is <= LOGLEVEL (config.h).
// An ERR entry always terminates the process with status 1, even if LOGLEVEL
// is set to suppress the message: callers treat emit(ERR, ...) as a call that
// never returns and carry on as though it had exited.
void emit(loglevel l, const char *message) {
    if (l <= LOGLEVEL) {
        fprintf(stderr, "%s\n", message);
    }
    if (l == ERR) {
        exit(1);
    }
}

// Emit a log entry relating to a syscall: the message followed by the text
// for the current errno, written to stderr when the level is <= LOGLEVEL.
// An ERR entry always terminates the process, regardless of LOGLEVEL. The
// exit status is the errno value, or 1 when errno is 0, so a fatal error
// is never reported as success (e.g. getpwnam() finding no such user
// leaves errno at 0).
void emit_syscall(loglevel l, const char *message) {
    // Snapshot errno first: fprintf() is allowed to change it even on success.
    int saved_errno = errno;

    if (l <= LOGLEVEL) {
        fprintf(stderr, "%s: %s\n", message, strerror(saved_errno));
    }
    if (l == ERR) {
        exit(saved_errno != 0 ? saved_errno : EXIT_FAILURE);
    }
}

/*
 * Drop all of our privileges: supplementary groups, then group, then user.
 * Do this in its own function to try and keep gopher_user out of memory
 * in case someone's stupid enough to run this as a real user.
 *
 * The order matters: once setuid() succeeds we no longer have the privilege
 * to change groups, so the group changes must come first. Any failure is
 * fatal (ERR), since carrying on would mean serving clients as root.
 */
void drop_root(struct passwd *gopher_user) {
    if (!gopher_user) {
        emit_syscall(ERR, "Unable to look up user " GOPHER_USER);
    }

    // setgid() alone leaves root's supplementary groups attached to the
    // process; replace them with just the gopher user's primary group.
    gid_t gid = gopher_user->pw_gid;
    if (setgroups(1, &gid) < 0) {
        emit_syscall(ERR, "Unable to drop supplementary groups");
    }

    if (setgid(gid) < 0) {
        emit_syscall(ERR, "Unable to change group to " GOPHER_GROUP);
    }

    if (setuid(gopher_user->pw_uid) < 0) {
        emit_syscall(ERR, "Unable to change user to " GOPHER_USER);
    }
}

/*
 * Send a text file to the client, fixing up the characters the gopher
 * protocol treats specially:
 *   - every tab is expanded to 8 spaces (tabs separate the fields of a
 *     gopher menu line);
 *   - a period at the start of any line is doubled, so that a line holding a
 *     lone "." cannot be mistaken for the end-of-transfer marker the caller
 *     sends after the file.
 * The worst case is a chunk made entirely of tabs, so the send buffer is
 * 8 times the size of the read buffer.
 *
 * out_fd: connected client socket; left open.
 * path:   file to send. Open/read/write failures are logged at INFO and end
 *         the transfer early; the file descriptor is always closed.
 */
void list_text_file(int out_fd, const char *path) {
    enum { CHUNK = 512, TAB_WIDTH = 8 };
    char buffer[CHUNK];
    char send_buffer[TAB_WIDTH * CHUNK];

    int fd = open(path, O_RDONLY);
    if (fd < 0) {
        emit_syscall(INFO, "Failed to open file");
        return;
    }

    // Whether the next byte begins a line. This must survive across read()
    // chunks, and the file itself begins at the start of a line.
    bool at_line_start = true;
    ssize_t read_return;

    // read() returns 0 at end of file and -1 on error; either ends the loop.
    while ((read_return = read(fd, buffer, sizeof(buffer))) > 0) {
        size_t send_position = 0;

        for (ssize_t i = 0; i < read_return; i++) {
            char c = buffer[i];

            if (c == '\t') {
                memset(send_buffer + send_position, ' ', TAB_WIDTH);
                send_position += TAB_WIDTH;
            } else if (c == '.' && at_line_start) {
                send_buffer[send_position++] = '.';
                send_buffer[send_position++] = '.';
            } else {
                send_buffer[send_position++] = c;
            }
            at_line_start = (c == '\n');
        }

        // Stop if the client went away; there is no point reading the rest of the file.
        if (write(out_fd, send_buffer, send_position) != (ssize_t)send_position) {
            emit_syscall(INFO, "Failed to write to client");
            break;
        }
    }

    if (read_return < 0) {
        emit_syscall(INFO, "Failed to read from file");
    }

    close(fd);
}

/*
 * Copy a file to the client byte-for-byte, with none of list_text_file()'s
 * tab/period rewriting. Used both for binary downloads and for hand-written
 * menus (here.map); sending any end-of-transfer marker is the caller's job.
 *
 * out_fd: connected client socket; left open.
 * path:   file to send. Open/read/write failures are logged at INFO and end
 *         the transfer early; the file descriptor is always closed.
 */
void list_binary_file(int out_fd, const char *path) {
    char buffer[512];

    int fd = open(path, O_RDONLY);
    if (fd < 0) {
        emit_syscall(INFO, "Failed to open file");
        return;
    }

    ssize_t read_return;

    // read() returns 0 at end of file and -1 on error; either ends the loop.
    while ((read_return = read(fd, buffer, sizeof(buffer))) > 0) {
        if (write(out_fd, buffer, (size_t)read_return) != read_return) {
            emit_syscall(INFO, "Failed to write to client");
            break;
        }
    }

    if (read_return < 0) {
        emit_syscall(INFO, "Failed to read from file");
    }

    close(fd);
}

void list_directory(int out_fd, const char *path) {
    char mappath[1024] = {0};
    snprintf(mappath, 1023, "%s/here.map", path);
    int mapfd = open(mappath, O_RDONLY);
    if (mapfd != -1) {
        close(mapfd);
        list_binary_file(out_fd, mappath);
        write(out_fd, "\r\n", 2);
    } else {
        DIR *d = opendir(path);
        // Impose a hard path size limit of 1023(+ null).
        char entpath[1024] = {};
        struct dirent *ent;
        char type;

        emit(INFO, "Received request for directory at");
        emit(INFO, path);
    
        if (!d) {
            emit_syscall(INFO, "Unable to open directory");
            return;
        }

        strncpy(entpath, path, 1022);
        entpath[strlen(entpath)] = '/';
        entpath[strlen(entpath)] = '\0';
        // The way that this call works is that you *must* check errno after it returns NULL, so you have to set it to 0 first.
        errno = 0;
    
        while ((ent = readdir(d))) {
            // Don't show hidden files.
            if (ent->d_name[0] == '.') {
                continue;
            }
        
            // Give the client the gopher type of the entry.
            switch (ent->d_type) {
            case DT_DIR:
                write(out_fd, "1", 1);
                break;
            case DT_REG:
                type = extension_matches(ent->d_name);
                write(out_fd, &type, 1);
                break;
            default:
                continue;
            }

            write(out_fd, ent->d_name, strlen(ent->d_name));
            write(out_fd, "\t", 1);
            write(out_fd, entpath+(entpath[0] == '/' ? 1 : 0), strlen(entpath) - (entpath[0] == '/' ? 1 : 0));
            write(out_fd, ent->d_name, strlen(ent->d_name));
            write(out_fd, "\t" GOPHER_HOST "\t" GOPHER_PORT "\r\n",  strlen("\t" GOPHER_HOST "\t" GOPHER_PORT "\r\n"));
        }

        if (errno) {
            emit_syscall(WARN, "readdir call failed");
        }
    }
}

/*
 * Map a file name to a gopher item type by its suffix, using the `mappings`
 * table from config.h. The first entry whose extension is a proper suffix of
 * path wins; a name matching no entry is served as text ('0').
 */
char extension_matches(const char *path) {
    size_t path_len = strlen(path);

    for (int i = 0; i < LENGTH(mappings); i++) {
        const char *ext = mappings[i].extension;
        size_t ext_len = strlen(ext);

        // Can't match the extension if the name is too short to have it.
        if (path_len > ext_len && strcmp(path + path_len - ext_len, ext) == 0) {
            return mappings[i].filetype;
        }
    }

    return '0';
}

void list(int fd, char *path) {
    struct stat s;
    
    emit(DEBUG, path);
    if (stat(path, &s) < 0) {
        char debugstring[128] = {0};
        snprintf(debugstring, 127, "Unable to stat file %s", path);
        emit_syscall(INFO, debugstring);
    }

    if (s.st_mode & S_IFDIR) {
        list_directory(fd, path);
        write(fd, ".\r\n", 3);

    } else {

        switch (extension_matches(path)) {
        case '9':
            list_binary_file(fd, path);
	    close(fd);
            break;
        case '1':
            list_binary_file(fd, path);
            write(fd, "\r\n.\r\n", 5);
            break;
        default:
           list_text_file(fd, path);
           write(fd, "\r\n.\r\n", 5);
        }
    }
}

void handle_connection(struct connection *conn) {
    char selector[1024] = {0};

    emit(DEBUG, "Connection opened");
    if (read(conn->fd, selector, 1022)) {
        int len = strlen(selector);
        printf("%d\n", len);

        for (int i = 0; i < len; i++) {
            printf("%x", (unsigned char)selector[i]);
        }
        printf("\n");
        

        emit(DEBUG, selector);
        
        if (len > 2 && selector[len-2] == '\r' && selector[len-1] == '\n') {
            selector[len-2] = '\0';
            list(conn->fd, selector);
        } else if (len == 2 && selector[len-2] == '\r' && selector[len-1] == '\n') {
            list(conn->fd, "/");
            
        }

    }
    close(conn->fd);
}

/*
 * Set up the listening socket inside the GOPHER_DIR chroot, drop root, then
 * serve clients one at a time, forever.
 */
int main(void) {
    struct addrinfo *result;

    //                                      ipv4 and v6               tcp                      listen
    struct addrinfo addhint = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM, .ai_flags = AI_PASSIVE};

    // A client hanging up mid-transfer would otherwise kill the whole server
    // with SIGPIPE on the next write(); ignored, that write just fails (EPIPE).
    signal(SIGPIPE, SIG_IGN);

    // We need to have access to the information we need from /etc/passwd before we chroot
    struct passwd *gopher_user = getpwnam(GOPHER_USER);

    if (chroot(GOPHER_DIR) < 0) {
        emit_syscall(ERR, "Unable to chroot to " GOPHER_DIR);
    }

    // chroot() doesn't move the working directory. Until we chdir into the
    // jail, relative selectors would resolve outside it, so do it right away
    // (while still root) and treat failure as fatal.
    if (chdir("/") < 0) {
        emit_syscall(ERR, "Unable to chdir to / inside " GOPHER_DIR);
    }

    // We're not connecting to an address, we're listening on GOPHER_PORT.
    // getaddrinfo() reports failure through its return value (an EAI_* code,
    // whose sign is platform-dependent), not through errno.
    int gai_status = getaddrinfo(NULL, GOPHER_PORT, &addhint, &result);
    if (gai_status != 0) {
        char message[256];
        snprintf(message, sizeof(message), "Getaddrinfo failed: %s", gai_strerror(gai_status));
        emit(ERR, message);
    }

    int listener = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
    if (listener < 0) {
        emit_syscall(ERR, "Failed to create socket");
    }

    int opt = 1;
    if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
        emit_syscall(WARN, "Unable to set SO_REUSEADDR");
    }

    if (bind(listener, result->ai_addr, result->ai_addrlen) < 0) {
        emit_syscall(ERR, "Failed to bind");
    }
    freeaddrinfo(result);

    // Okay, we only needed root for chroot and binding to a low port. Drop it.
    drop_root(gopher_user);

    if (listen(listener, 128)) {
        emit_syscall(ERR, "Failed to listen");
    }

    while (true) {
        struct connection conn = {0};

        // accept() reads addrlen as the capacity of conn.addr and then
        // overwrites it with the peer address length, so it must be set first.
        conn.addrlen = sizeof(conn.addr);
        conn.fd = accept(listener, &conn.addr, &conn.addrlen);

        if (conn.fd < 0) {
            emit_syscall(WARN, "Accept failed");
            continue;
        }

        handle_connection(&conn);
    }

    // never reached
    return 0;
}
