#include "dtls.h"
#include "static_task.h"
#include "wifi.h"
#include "freertos/FreeRTOS.h"
#include "freertos/FreeRTOSConfig.h"
#include "freertos/task.h"
#include "freertos/event_groups.h"
#include "freertos/timers.h"
#include "freertos/portmacro.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#include "mbedtls/certs.h"
#include "mbedtls/timing.h"
#include "mbedtls/platform.h"
#include "global_defs.h"
#include "sdkconfig.h"
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

enum {TIMER_PERIOD_MS = 10 / portTICK_PERIOD_MS};

/*
 * Private layout behind the opaque public dtls_static type; every public
 * function casts its dtls_static pointer to this struct.
 */
struct dtls_static_priv
{
    mbedtls_net_context server_fd;
    mbedtls_entropy_context entropy;
    mbedtls_ctr_drbg_context ctr_drbg;
    mbedtls_ssl_context ssl;
    mbedtls_ssl_config conf;
    mbedtls_x509_crt cacert;

    /* State shared between the FreeRTOS software timer callback and the
     * mbedtls timer callbacks (timing_set_delay / timing_get_delay). */
    struct dtls_time
    {
        TimerHandle_t handle;
        StaticTimer_t static_data;      /* storage for xTimerCreateStatic() */
        struct delay_priv
        {
            uint32_t ms;        /* delay length, in timer periods */
            /* 16-bit so delays longer than 255 periods (e.g. the 60 s maximum
             * handshake timeout) can be counted: an 8-bit counter wraps back
             * to 0 before reaching ms, so the delay would never finish. With
             * the 4-byte ms and 1-byte flag the struct still pads to 8 bytes
             * on the usual ABIs, so its size is unchanged. */
            uint16_t count;     /* periods elapsed since the delay was armed */
            bool finished;      /* latched once count has reached ms */
        } interm, fin;          /* intermediate and final delays */
    } timer;
};

_Static_assert (sizeof (dtls_static) == sizeof (struct dtls_static_priv),
                "dtls_static public size mismatch");

#undef mbedtls_printf
#define mbedtls_printf(...) printf(__VA_ARGS__)

/*
 * Account for one elapsed timer period on a single delay: keep counting while
 * fewer than dp->ms periods have passed, otherwise latch dp->finished.
 */
static void update_delay_priv(struct delay_priv *const dp)
{
    if (dp->count < dp->ms)
    {
        dp->count++;
        return;
    }

    dp->finished = true;
}

/*
 * FreeRTOS software-timer callback (auto-reload, period TIMER_PERIOD_MS
 * ticks, created in dtls_connect()). The timer ID is the struct dtls_time
 * to update; each expiry advances both the intermediate and final delays.
 */
static void dtls_timer_expired(const TimerHandle_t handle)
{
    struct dtls_time *owner;

    if (!handle)
    {
        mbedtls_printf("dtls_timer_expired: invalid timer handle\n");
        return;
    }

    owner = (struct dtls_time *)pvTimerGetTimerID(handle);

    if (!owner)
    {
        mbedtls_printf("Timer handle %p not associated to any dtls_time instance\n", (void *)handle);
        return;
    }

    update_delay_priv(&owner->interm);
    update_delay_priv(&owner->fin);
}

static void timing_set_delay(void *const data, const uint32_t int_ms, const uint32_t fin_ms)
{
    struct dtls_time *const dtlst = data;

    dtlst->interm.ms = int_ms / TIMER_PERIOD_MS;
    dtlst->fin.ms = fin_ms / TIMER_PERIOD_MS;

    if (fin_ms)
    {
        dtlst->interm.count = 0;
        dtlst->interm.finished = false;

        dtlst->fin.count = 0;
        dtlst->fin.finished = false;

        xTimerStart(dtlst->handle, 0);
    }
}

static int timing_get_delay(void *const data)
{
    struct dtls_time *const dtlst = (struct dtls_time *)data;

    if (!dtlst->fin.ms)
        return -1;
    else if (dtlst->fin.finished)
        return 2;
    else if (dtlst->interm.finished)
        return 1;

    return 0;
}

/*
 * mbedtls debug callback: writes "file:line: message" to the FILE * given as
 * ctx and flushes it. The debug level is ignored.
 */
static void my_debug(void *ctx, int level, const char *file, int line, const char *str)
{
    FILE *const out = ctx;

    (void)level;

    fprintf(out, "%s:%04d: %s", file, line, str);
    fflush(out);
}

/*
 * Release every mbedtls resource held by a session: the socket, CA chain,
 * SSL context and configuration, CTR-DRBG and entropy contexts.
 *
 * NOTE(review): the FreeRTOS timer created in dtls_connect() is neither
 * stopped nor deleted here, so if it is still active its callback keeps
 * updating dtls->timer after the session is closed, and a later
 * dtls_connect() re-creates it on the same static buffer. Confirm whether
 * that is intended.
 */
static void dtls_close_priv(struct dtls_static_priv *const dtls)
{
    mbedtls_net_free(&dtls->server_fd);
    mbedtls_x509_crt_free(&dtls->cacert);
    /* The SSL context is freed before the configuration it was set up with. */
    mbedtls_ssl_free(&dtls->ssl);
    mbedtls_ssl_config_free(&dtls->conf);
    mbedtls_ctr_drbg_free(&dtls->ctr_drbg);
    mbedtls_entropy_free(&dtls->entropy);
}

/*
 * Try to recover a failed session by resetting it with
 * mbedtls_ssl_session_reset(), up to MAX_RETRIES times.
 *
 * Returns 0 as soon as a reset succeeds. If every attempt fails, the session
 * resources are released with dtls_close_priv() and 1 is returned.
 */
static int dtls_retry(struct dtls_static_priv *const dtls)
{
    enum {MAX_RETRIES = 5};

    mbedtls_printf("Operation failed. Retrying...\n");

    for (unsigned attempt = 1; attempt <= MAX_RETRIES; attempt++)
    {
        const int ret = mbedtls_ssl_session_reset(&dtls->ssl);

        if (!ret)
            return 0;

        /* %u matches the unsigned counter (the original passed a size_t to %d). */
        mbedtls_printf("\tRetry %u failed, code: %d\n", attempt, ret);
    }

    mbedtls_printf("Connection lost after %d retries\n", MAX_RETRIES);
    dtls_close_priv(dtls);
    return 1;
}

/*
 * Open a DTLS client session to hostname:port over UDP.
 *
 * Steps:
 *   - Waits, polling, until wifi_is_connected() reports a link.
 *   - Initialises every mbedtls context stored in *dtls_public.
 *   - Seeds the CTR-DRBG, using the default eFuse MAC address as
 *     personalisation data.
 *   - Loads mbedtls_test_cas_pem as the CA chain and connects the UDP socket.
 *   - Configures the client and arms the FreeRTOS-backed retransmission timer.
 *   - Runs the handshake to completion.
 *
 * Returns 0 on success. On failure it returns the non-zero code of the step
 * that failed: an mbedtls error code, the value of
 * mbedtls_ssl_get_verify_result(), or -1 if the timer could not be created.
 * In that case the mbedtls contexts have already been released with
 * dtls_close_priv().
 */
int dtls_connect(const char *hostname, const char *port, dtls_static *const dtls_public)
{
    enum
    {
        WIFI_POLL_MS = 100,
        MIN_TIMEOUT_MS = 1000,
        MAX_TIMEOUT_MS = 60000
    };

    struct dtls_static_priv *const dtls = (struct dtls_static_priv *)dtls_public;
    bool timer_created = false;
    int ret;

    /* Sleep between polls instead of spinning, so other tasks keep running
     * while the link comes up. */
    while (!wifi_is_connected())
        vTaskDelay(pdMS_TO_TICKS(WIFI_POLL_MS));

    mbedtls_net_init(&dtls->server_fd);
    mbedtls_ssl_init(&dtls->ssl);
    mbedtls_ssl_config_init(&dtls->conf);
    mbedtls_x509_crt_init(&dtls->cacert);
    mbedtls_ctr_drbg_init(&dtls->ctr_drbg);
    mbedtls_entropy_init(&dtls->entropy);

    {
        uint8_t mac_addr[6];
        ESP_ERROR_CHECK(esp_efuse_mac_get_default(mac_addr));

        if ((ret = mbedtls_ctr_drbg_seed(&dtls->ctr_drbg, mbedtls_entropy_func,
                                         &dtls->entropy, mac_addr, sizeof mac_addr)))
        {
            mbedtls_printf( " failed\n  ! mbedtls_ctr_drbg_seed returned %d\n", ret );
            goto fail;
        }
    }

    if ((ret = mbedtls_x509_crt_parse(&dtls->cacert,
                                      (const unsigned char *)mbedtls_test_cas_pem,
                                      mbedtls_test_cas_pem_len)) < 0)
    {
        mbedtls_printf( " failed\n  !  mbedtls_x509_crt_parse returned -0x%x\n\n", -ret );
        goto fail;
    }

    ret = mbedtls_net_connect(&dtls->server_fd, hostname, port, MBEDTLS_NET_PROTO_UDP);

    switch (ret)
    {
        case 0:
            /* Connection successful. */
            break;

        case MBEDTLS_ERR_NET_UNKNOWN_HOST:
            mbedtls_printf("Unknown host %s\n", hostname);
            goto fail;

        case MBEDTLS_ERR_NET_CONNECT_FAILED:
            mbedtls_printf("Connection to %s:%s failed\n", hostname, port);
            goto fail;

        case MBEDTLS_ERR_NET_SOCKET_FAILED:
            /* Fall through. */
        default:
            mbedtls_printf("mbedtls_net_connect failed with error code -0x%x\n", -ret);
            goto fail;
    }

    if ((ret = mbedtls_ssl_config_defaults(&dtls->conf,
                                            MBEDTLS_SSL_IS_CLIENT,
                                            MBEDTLS_SSL_TRANSPORT_DATAGRAM,
                                            MBEDTLS_SSL_PRESET_DEFAULT)))
    {
        mbedtls_printf( " failed\n ! mbedtls_ssl_config_defaults returned %d\n\n", ret );
        goto fail;
    }

    /* Must come after mbedtls_ssl_config_defaults(), which resets the
     * handshake timeout range to the library defaults. */
    mbedtls_ssl_conf_handshake_timeout(&dtls->conf, MIN_TIMEOUT_MS, MAX_TIMEOUT_MS);

    /* NOTE(review): peer certificate verification is disabled here
     * (MBEDTLS_SSL_VERIFY_NONE). The mbedtls programs/ssl/dtls_client.c
     * example this setup is based on uses OPTIONAL for interop with its
     * hardcoded test CA chain. Production code should load a real CA chain
     * (not mbedtls_test_cas_pem) and use MBEDTLS_SSL_VERIFY_REQUIRED. */
    mbedtls_ssl_conf_authmode(&dtls->conf, MBEDTLS_SSL_VERIFY_NONE);
    mbedtls_ssl_conf_ca_chain( &dtls->conf, &dtls->cacert, NULL );
    mbedtls_ssl_conf_rng(&dtls->conf, mbedtls_ctr_drbg_random, &dtls->ctr_drbg);
    mbedtls_ssl_conf_dbg(&dtls->conf, my_debug, stdout);

    if ((ret = mbedtls_ssl_setup(&dtls->ssl, &dtls->conf)))
    {
        mbedtls_printf("mbedtls_ssl_setup() failed with code %d\n\n", ret);
        goto fail;
    }

    if ((ret = mbedtls_ssl_set_hostname(&dtls->ssl, hostname)))
    {
        mbedtls_printf( " failed\n ! mbedtls_ssl_set_hostname returned %d\n\n", ret);
        goto fail;
    }

    mbedtls_ssl_set_bio(&dtls->ssl, &dtls->server_fd, mbedtls_net_send, NULL, mbedtls_net_recv_timeout);

    /* Auto-reload timer whose ID is &dtls->timer; dtls_timer_expired() uses
     * that ID to find the delays to advance. */
    dtls->timer.handle = xTimerCreateStatic("DTLS", TIMER_PERIOD_MS, pdTRUE, &dtls->timer,
                                            dtls_timer_expired, &dtls->timer.static_data);

    if (!dtls->timer.handle)
    {
        mbedtls_printf("%s, %d: xTimerCreateStatic() failed\n", __FILE__, __LINE__);
        ret = -1;
        goto fail;
    }

    timer_created = true;

    mbedtls_ssl_set_timer_cb(&dtls->ssl, &dtls->timer, timing_set_delay, timing_get_delay);

    do {
        mbedtls_printf("ssl->state = %d\n", dtls->ssl.state);
        ret = mbedtls_ssl_handshake(&dtls->ssl);
    }
    while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
           ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);

    if (ret)
    {
        mbedtls_printf( " failed\n  ! mbedtls_ssl_handshake returned -0x%x\n\n", -ret );
        goto fail;
    }

    if ((ret = mbedtls_ssl_get_verify_result(&dtls->ssl)))
    {
        mbedtls_printf("mbedtls_ssl_get_verify_result returned %d\n\n", ret);
        goto fail;
    }

    return 0;

fail:
    /* Every path reaching this label has set ret to a non-zero value. Stop
     * the retransmission timer (if it was created) so its callback no longer
     * updates state belonging to this failed session, then release all
     * mbedtls contexts initialised above. */
    if (timer_created)
        xTimerStop(dtls->timer.handle, 0);

    dtls_close_priv(dtls);
    return ret;
}

/*
 * Send len bytes from buf over the DTLS session.
 *
 * Calls mbedtls_ssl_write() until all of buf has been accepted, resuming each
 * call where the previous one stopped. WANT_READ, WANT_WRITE and
 * CRYPTO_IN_PROGRESS mean "call again" and are repeated as-is. Any other
 * error is handed to dtls_retry().
 *
 * Returns 0 on success, or dtls_retry()'s non-zero result if the session
 * could not be recovered.
 */
int dtls_write(dtls_static *const dtls_public, const void *const buf, const size_t len)
{
    struct dtls_static_priv *const dtls = (struct dtls_static_priv *)dtls_public;
    const unsigned char *const data = buf;
    size_t bytes_written = 0;

    do
    {
        /* Offset into buf so a partial write is not re-sent from the start. */
        const int ret = mbedtls_ssl_write(&dtls->ssl, data + bytes_written,
                                          len - bytes_written);

        if (ret >= 0)
        {
            bytes_written += (size_t)ret;
        }
        else if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
                 ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
                 ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS)
        {
            /* Not an error: mbedtls asks for the same call to be repeated. */
            continue;
        }
        else
        {
            mbedtls_printf("mbedtls_ssl_write returned error code %d\n", ret);

            const int retry_ret = dtls_retry(dtls);
            if (retry_ret)
                return retry_ret;
        }
    } while (bytes_written < len);

    return 0;
}

/*
 * Read exactly len bytes from the DTLS session into buf.
 *
 * Calls mbedtls_ssl_read() repeatedly, appending each chunk after the bytes
 * already received, until len bytes have been stored. WANT_READ, WANT_WRITE
 * and CRYPTO_IN_PROGRESS mean "call again" and are repeated. Any other
 * non-positive result (errors, or 0) is handed to dtls_retry().
 *
 * Returns 0 once len bytes have been read (immediately when len is 0), or
 * dtls_retry()'s non-zero result if the session could not be recovered.
 */
int dtls_read(dtls_static *const dtls_public, void *const buf, const size_t len)
{
    struct dtls_static_priv *const dtls = (struct dtls_static_priv *)dtls_public;
    unsigned char *const data = buf;
    size_t read_bytes = 0;

    while (read_bytes < len)
    {
        /* Append after what has already been read instead of overwriting buf. */
        const int ret = mbedtls_ssl_read(&dtls->ssl, data + read_bytes, len - read_bytes);

        if (ret > 0)
        {
            read_bytes += (size_t)ret;
        }
        else if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
                 ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
                 ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS)
        {
            /* Not an error: mbedtls asks for the call to be repeated. */
            continue;
        }
        else
        {
            mbedtls_printf("mbedtls_ssl_read returned error code %d\n", ret);

            const int retry_ret = dtls_retry(dtls);
            if (retry_ret)
                return retry_ret;
        }
    }

    return 0;
}

/*
 * Public wrapper around dtls_close_priv(): frees every mbedtls resource held
 * by the session behind dtls_public.
 */
void dtls_close(dtls_static *const dtls_public)
{
    dtls_close_priv((struct dtls_static_priv *)dtls_public);
}
