######################################################################################
#
#   HW Debugger for BW board bringup
#
#   author: Ananda (ANN)
#
######################################################################################

# imports
import argparse
import datetime
import msvcrt
import os
import socket
import subprocess
import sys
import time

import serial
import serial.tools.list_ports as list_ports

# tftpy is optional: without it the FirmwareUpgrade command is disabled.
try:
    import tftpy
    # NOTE(review): SOCK_TIMEOUT is not referenced elsewhere in this script;
    # kept so a tftpy build without it still disables firmware update as before.
    from tftpy.TftpShared import SOCK_TIMEOUT
except ImportError:
    print("tftpy not install. Firmware update not available")
    firmwareUpdate = 0
else:
    # Only flag firmware update as available once every import succeeded.
    firmwareUpdate = 1

from bw_pack_request import pack_request

# TFTP upload progress state shared with tftpy_hook:
#   tftp_file_size      - size in bytes of the file being uploaded
#   tftp_total_transfer - DAT payload bytes counted so far
#   tftp_prog           - last progress percentage that was printed
tftp_file_size = tftp_total_transfer = tftp_prog = 0

# Prints progress in steps of 10 %
def tftpy_hook( arg ):
    global tftpy_file_size
    global tftp_total_transfer
    global tftp_prog
    if type(arg) is tftpy.TftpPacketTypes.TftpPacketDAT:
        tftp_total_transfer = tftp_total_transfer + len(arg.decode().data)
        perc = int( 100 * tftp_total_transfer / tftp_file_size )
        if( perc % 10 == 0 and tftp_prog != perc):
            print("Transfer: " + str(perc) + "%" )
            tftp_prog = perc


def flash_mcu(flashtool, name):
    """Program the MCU with a .hex image using an external flash tool.

    :param flashtool: ``"JLINK"`` (runs ``JLink`` with a generated command
        file) or ``"STLINK"`` (runs ``ST-LINK_CLI.exe``), matched
        case-insensitively. ``None`` flashes nothing; any other value prints
        a usage message.
    :param name: path of the .hex image; ``None`` selects
        ``BW_CDA_Firmware.hex``. Names not ending in ``.hex`` are rejected.

    The tool's console output is shown directly. A tool that cannot be
    started, or that exits with a non-zero status, is reported instead of
    raising.
    """
    if name is None:
        name = 'BW_CDA_Firmware.hex'

    if not name.lower().endswith('.hex'):
        print("Must be a .hex file")
        return

    if flashtool is None:
        return

    tool = flashtool.upper()
    if tool == "JLINK":
        # Command script executed by the JLink call below.
        with open("commandFile.jlink", "w") as cmd_file:
            cmd_file.write('r\n')
            cmd_file.write('loadfile ' + name + '\n')
            cmd_file.write('r\n')
            cmd_file.write('exit\n')
        cmd = ["JLink", "-device", "STM32F767ZI", "-if", "SWD",
               "-speed", "4000", "commandFile.jlink"]
    elif tool == "STLINK":
        cmd = ["ST-LINK_CLI.exe", "-ME", "-P", name,
               "-V", "while_programming", "-Rst"]
    else:
        print("FLASHTOOL must be STLINK or JLINK. Firmware file should be called BW_CDA_Firmware.hex")
        return

    # Flush so the separator appears before the tool's own output.
    print("------------------------------------", flush=True)
    # An argument list (no shell) passes file names containing spaces or
    # shell metacharacters to the tool unchanged.
    try:
        result = subprocess.run(cmd)
    except OSError as err:
        # e.g. the tool is not installed or not on PATH
        print("Could not run {}: {}".format(cmd[0], err))
    else:
        if result.returncode != 0:
            print("{} exited with code {}".format(cmd[0], result.returncode))
    print("------------------------------------")
    time.sleep(0.1)

# Argument parser
parser = argparse.ArgumentParser(description='HW debug interface program')
parser.add_argument('--port', default='COM99',
                    help='serial port used when --ip is not given (default: %(default)s)')
parser.add_argument('--baud', type=int, default=921600,
                    help='serial baud rate (default: %(default)s)')
parser.add_argument('--ip',
                    help='connect over TCP to this address (port 5000) instead of the serial port')
parser.add_argument('--flashtool',
                    help='flash BW_CDA_Firmware.hex with JLINK or STLINK before connecting')
args = parser.parse_args()

# Optionally program the default firmware image before the connection is opened.
if args.flashtool is not None:
    flash_mcu(args.flashtool, None)


# Transport selection: switched to 1 once the TCP socket is connected.
use_ip = 0

# State of the periodic ("Start") command.
start = 0               # non-zero while periodic sending is active
toggle_time = 1000      # period between periodic sends [ms]
cur_time = time.time()  # time of the last periodic send (or of 'Start')

# Delay per main-loop iteration [s].
sleep_time = 0.001

# Open Socket and use that for communication
if args.ip is not None:
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Allow up to 2 s for the TCP connect; the port is fixed at 5000.
        s.settimeout(2)
        s.connect((args.ip, 5000))
    except OSError:
        print("Opening socket failed. check ip: {}".format(args.ip))
        s.close()
        sys.exit(1)

    # Short receive timeout so the main loop keeps polling the keyboard.
    s.settimeout(0.01)
    use_ip = 1

    print("Connection open {}".format(s))

    print("Type 'exit' to quit")
    print("------------------------------------")

# Open serial and use that if IP is not selected
else:
    # Open COM port for UART Connection
    COMPort = args.port
    try:
        serialPort = serial.Serial(port=COMPort, baudrate=args.baud)
    except (serial.SerialException, ValueError):
        print("Error opening {} port".format(COMPort))
        print("Available ports:")
        print("------------------------------------")
        # Take one snapshot of the port list so the printed indices and the
        # chosen entry always refer to the same list.
        ports = list(list_ports.comports())
        for index, port in enumerate(ports):
            if port[2] != 'n/a':
                print("[{}]  ".format(index) + str(port))
        exit_choice = len(ports)
        print("[{}]  Exit".format(exit_choice))
        print("------------------------------------")
        user_select = input("Select port (0-{}): ".format(exit_choice))

        if user_select == str(exit_choice):
            sys.exit()
        try:
            selection = int(user_select)
        except ValueError:
            selection = -1
        # Negative numbers are rejected too; they would index from the end.
        if not 0 <= selection < exit_choice:
            print("Invalid selection")
            sys.exit(1)
        COMPort = str(ports[selection][0])
        try:
            serialPort = serial.Serial(port=COMPort, baudrate=args.baud)
        except (serial.SerialException, ValueError) as err:
            print("Error opening {} port: {}".format(COMPort, err))
            sys.exit(1)

    print("Connection open on port {}".format(COMPort))

    print("Type 'exit' to quit")
    print("------------------------------------")



def _timestamp():
    """Return the current local time formatted as HH:MM:SS.ffffff."""
    return datetime.datetime.now().strftime("%H:%M:%S.%f")


def _print_received(data):
    """Print bytes received from the link with a timestamp prefix.

    The data is decoded as ASCII; data that is not valid ASCII is reported
    instead of printed. ``end=''`` leaves line breaks to the received text
    (serial data is read line by line).
    """
    try:
        text = data.decode("Ascii")
    except UnicodeDecodeError:
        print("Unable to print. None-Ascii recieved")
        return
    print("[" + _timestamp() + "] - " + text, end='')


def _send_command(command):
    """Echo *command* and write it to the active link (TCP socket or serial port).

    A failed write is reported instead of aborting the session.
    """
    print("Sending: " + command)
    payload = command.encode()
    try:
        if use_ip:
            # sendall() keeps writing until every byte is sent; send() may stop short.
            s.sendall(payload)
        else:
            serialPort.write(payload)
    except OSError as err:
        print("Send failed: {}".format(err))


def _skip_tokens(text, tokens):
    """Return *text* with the leading whitespace-separated *tokens* removed.

    Whitespace after the last removed token is kept, e.g.
    ``_skip_tokens("Start 500 cmd", ["Start", "500"]) == " cmd"``.
    """
    for token in tokens:
        text = text.lstrip()[len(token):]
    return text


def _firmware_upgrade(host):
    """Upload BW_CDA_Firmware_crc.bin to *host* over TFTP as update.bin.

    Progress is printed through tftpy_hook. Returns True when the upload
    finished and False when the file is missing or the transfer failed.
    """
    global tftp_file_size, tftp_total_transfer, tftp_prog
    print("Starting firmware update over TFTP for " + host)
    try:
        client = tftpy.TftpClient(host, 69)
        tftp_file_size = os.path.getsize('BW_CDA_Firmware_crc.bin')
        client.upload('update.bin', 'BW_CDA_Firmware_crc.bin', timeout=20, packethook=tftpy_hook)
    except (tftpy.TftpShared.TftpException, OSError) as err:
        print("Firmware upload failed: {}".format(err))
        return False
    finally:
        # Reset tftpy_hook's progress counters, also after a failed upload.
        tftp_total_transfer = 0
        tftp_prog = 0
    print("Sending complete")
    return True


try:
    while True:
        # --- Receive: print whatever arrived since the last iteration ---
        if use_ip:
            try:
                msg = s.recv(200)
            except socket.timeout:
                # Nothing arrived within the socket timeout.
                msg = None
            except OSError as err:
                print("Connection error: {}".format(err))
                break
            if msg == b'':
                # An empty read means the peer closed the connection; looping
                # on it would print empty timestamped lines forever.
                print("Connection closed by remote host")
                break
            if msg is not None:
                _print_received(msg)
        else:
            try:
                # Only read a line once data is waiting in the serial buffer.
                line = serialPort.readline() if serialPort.in_waiting > 0 else None
            except serial.SerialException as err:
                print("Serial port error: {}".format(err))
                break
            if line is not None:
                _print_received(line)

        # --- Keyboard: read and dispatch a command line ---
        if msvcrt.kbhit():
            request = input("Command: ")
            request_items = request.split()

            # Empty enter
            if not request_items:
                continue

            command_name = request_items[0]

            # Exit program
            if command_name in ('Exit', 'exit'):
                break

            # Start <period_ms> <command>: send <command> periodically
            if command_name == 'Start':
                try:
                    period = int(request_items[1])
                except (IndexError, ValueError):
                    print("Usage: Start <period_ms> <command>")
                    continue
                new_command = pack_request(_skip_tokens(request, request_items[:2]))
                if new_command is None:
                    # Periodic sending would fail on a missing command.
                    print("Invalid periodic command, not started")
                    continue
                # Periods below 100 ms are raised to 100 ms.
                toggle_time = max(period, 100)
                start_command = new_command
                start = 1
                cur_time = time.time()

            elif command_name == 'FirmwareUpgrade' and len(request_items) == 2:
                if not firmwareUpdate:
                    print("tftpy not installed. Cannot update over tftp")
                    continue
                if _firmware_upgrade(request_items[1]):
                    time.sleep(1)
                    command = pack_request('Upgrade')
                    if command is not None:
                        _send_command(command)

            # Stop periodic
            elif command_name == 'Stop':
                start = 0
                continue

            # Flash <tool> [file.hex]
            elif command_name == 'Flash' and len(request_items) >= 2:
                name = request_items[2] if len(request_items) == 3 else None
                flash_mcu(request_items[1], name)
                continue

            # Handle single commands
            else:
                command = pack_request(request)
                if command is not None:
                    _send_command(command)

        # --- Periodic send while 'Start' is active ---
        if start and (time.time() - cur_time) * 1000 > toggle_time:
            _send_command(start_command)
            cur_time = time.time()

        time.sleep(sleep_time)
except KeyboardInterrupt:
    # Ctrl+C ends the session like 'exit', so the script still reaches its
    # connection-closing code instead of aborting with a traceback.
    print()

# Release whichever transport this session used: the TCP socket when
# use_ip is set, otherwise the serial port.
print("Closing connection")
(s if use_ip else serialPort).close()