#!/usr/bin/env python3

# tello_rc.py
# 2022-04-30
# by Gernot Walzl

# This script interacts with the DJI Ryze Tello drone.
# The drone can be controlled via keyboard and/or gamepad.
# The stream of the drone's camera is shown in a window.
#
# The firewall needs to allow incoming UDP traffic on the following ports:
# -A INPUT -p udp --dport 8889 -j ACCEPT
# -A INPUT -p udp --dport 8890 -j ACCEPT
# -A INPUT -p udp --dport 11111 -j ACCEPT
#
# The method TelloPygameWindow._handle_keydown defines the keyboard bindings;
# the gamepad bindings are in _handle_joybuttondown and _handle_joyaxismotion.
#
# This script is based on the Tello SDK:
# https://dl-cdn.ryzerobotics.com/downloads/Tello/Tello%20SDK%202.0%20User%20Guide.pdf
# https://dl-cdn.ryzerobotics.com/downloads/tello/20180222/Tello3.py
# https://github.com/damiafuentes/DJITelloPy
#
# Probably, it would be better to send the same commands
# as the official Tello App:
# https://bitbucket.org/PingguSoft/pytello/src/master/

from datetime import datetime
import logging
import socket
import threading
import av
import cv2
import numpy as np
import pygame


class StateReceiverThread(threading.Thread):
    """Background thread that receives and parses the drone's state packets.

    Each UDP datagram is decoded as UTF-8 and parsed as ``key:value`` fields
    separated by ``;``. The most recent value seen for every key is kept.

    :param port: local UDP port the state packets arrive on (default 8890).
    :param poll_interval: receive timeout in seconds; bounds how long stop()
        takes to be noticed when no packets arrive.
    """

    def __init__(self, port=8890, poll_interval=0.5):
        threading.Thread.__init__(self)
        self._running = False
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._socket.bind(('', port))
        # Without a timeout recvfrom() blocks forever once the drone stops
        # sending, so stop() would never be observed and the (non-daemon)
        # thread would keep the interpreter alive.
        self._socket.settimeout(poll_interval)
        self._state = dict()
        # _state is written by this thread and read by the UI thread.
        self._lock = threading.Lock()

    @staticmethod
    def _parse_fields(state):
        """Parse ``"k1:v1;k2:v2;..."`` into a dict of string values.

        Fields without ':' are skipped. Only the first ':' separates key from
        value, so a value containing ':' cannot raise and kill the thread.
        """
        fields = {}
        for field in state.split(';'):
            key, sep, value = field.partition(':')
            if sep:
                fields[key] = value
        return fields

    def _parse_state(self, state):
        """Merge the fields of one state string into the stored state."""
        fields = self._parse_fields(state)
        with self._lock:
            self._state.update(fields)

    def run(self):
        self._running = True
        try:
            while self._running:
                try:
                    data, _ = self._socket.recvfrom(1024)
                except socket.timeout:
                    # No packet within poll_interval: re-check _running.
                    continue
                try:
                    str_state = data.decode(encoding='utf-8').strip()
                except UnicodeDecodeError:
                    logging.warning("Unable to decode state.")
                    continue
                self._parse_state(str_state)
        finally:
            self._socket.close()

    def stop(self):
        """Ask the thread to finish; it exits within about poll_interval."""
        self._running = False

    def get_state(self):
        """Return a snapshot copy of the latest state (key -> string value).

        A copy is returned so callers can iterate it while this thread keeps
        adding keys.
        """
        with self._lock:
            return dict(self._state)


class StreamReceiverThread(threading.Thread):
    """Background thread that decodes the drone's H.264 video stream.

    Only the most recently decoded frame is kept, as an RGB numpy array
    produced by PyAV's ``to_ndarray(format='rgb24')``.

    :param url: PyAV input URL (default: UDP on all interfaces, port 11111).
    """

    def __init__(self, url='udp://0.0.0.0:11111'):
        # Daemon thread: decoding blocks inside PyAV until the next frame
        # arrives, so stop() cannot interrupt a stream that went silent
        # (streamoff, lost link, failed connect). A non-daemon thread stuck
        # there would prevent the interpreter from exiting.
        threading.Thread.__init__(self, daemon=True)
        self._url = url
        self._running = False
        self._frame_curr = None

    def run(self):
        self._running = True
        with av.open(self._url, format='h264') as container:
            for frame in container.decode(video=0):
                self._frame_curr = frame.to_ndarray(format='rgb24')
                # stop() is only noticed after a frame has been decoded.
                if not self._running:
                    break

    def stop(self):
        """Ask the thread to finish after the next decoded frame."""
        self._running = False

    def get_frame(self):
        """Return the latest decoded frame, or None if none arrived yet."""
        return self._frame_curr


class Tello:
    """Client for the Tello SDK's UDP text protocol.

    Commands are sent from local port 8889 to 192.168.10.1:8889. State
    telemetry and the video stream are received by background threads.
    """

    def __init__(self):
        # Set first so __del__ works on a partially constructed object
        # (e.g. when binding port 8889 fails because it is in use).
        self._connected = False
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._socket.bind(('', 8889))
        self._address_tello = ('192.168.10.1', 8889)
        self._thread_state = StateReceiverThread()
        self._thread_stream = StreamReceiverThread()

    def __del__(self):
        if getattr(self, '_connected', False):
            self.disconnect()

    @staticmethod
    def _clamp_velocity(value, limit=100):
        """Truncate value to int and clamp it to [-limit, limit]."""
        return max(-limit, min(limit, int(value)))

    def _send_command(self, command):
        """Send a command without waiting for a response."""
        data_send = command.encode(encoding='utf-8')
        self._socket.sendto(data_send, self._address_tello)

    def _read_response(self, command, timeout=0.2):
        """Send a command and wait up to ``timeout`` seconds for the reply.

        :return: the stripped reply text, or None on timeout or when the
            reply is not valid UTF-8.
        """
        response = None
        self._socket.settimeout(timeout)
        try:
            self._send_command(command)
            data_recv, _ = self._socket.recvfrom(1024)
            response = data_recv.decode(encoding='utf-8').strip()
        except socket.timeout:
            logging.warning("No response received.")
        except UnicodeDecodeError:
            logging.warning("Unable to decode response.")
        finally:
            # Restore blocking mode even if an unexpected error propagates.
            self._socket.settimeout(None)
        return response

    def connect(self):
        """Enter SDK mode, start the receivers and turn the video stream on.

        :return: True on success, False if already connected or the drone
            did not answer "ok". Receiver threads can only be started once,
            so a failed connect() cannot be retried on the same instance.
        """
        if self._connected:
            return False
        if self._read_response("command") != "ok":
            return False
        self._thread_state.start()
        self._thread_stream.start()
        if self._read_response("streamon") != "ok":
            # Don't leave the receivers running for a failed connection.
            self._thread_stream.stop()
            self._thread_state.stop()
            return False
        self._connected = True
        return True

    def disconnect(self):
        """Stop motion, stop the receiver threads and turn the stream off."""
        self.send_rc(0, 0, 0, 0)
        self._thread_stream.stop()
        self._thread_state.stop()
        for thread in (self._thread_state, self._thread_stream):
            # join() raises RuntimeError on a thread that was never started.
            if thread.is_alive():
                thread.join(1.0)
        self._send_command("streamoff")
        self._connected = False

    def takeoff(self):
        """Return True if the drone answered "ok" within 20 seconds."""
        return self._read_response("takeoff", 20.0) == "ok"

    def land(self):
        """Return True if the drone answered "ok" within 20 seconds."""
        return self._read_response("land", 20.0) == "ok"

    def send_rc(self, velo_right, velo_forward, velo_up, velo_clockwise):
        """Send an rc command (no response is awaited).

        Each velocity is truncated to int and clamped to -100..100, the
        range the Tello SDK accepts for rc arguments.
        """
        cmd = "rc {} {} {} {}".format(
            self._clamp_velocity(velo_right),
            self._clamp_velocity(velo_forward),
            self._clamp_velocity(velo_up),
            self._clamp_velocity(velo_clockwise)
        )
        self._send_command(cmd)

    def read_wifi_snr(self):
        """Return the reply to "wifi?" as text, or None."""
        return self._read_response("wifi?")

    def get_state(self):
        return self._thread_state.get_state()

    def get_frame(self):
        return self._thread_stream.get_frame()


class TelloPygameWindow:
    """Pygame window that shows the Tello camera stream and telemetry and
    turns keyboard/gamepad input into rc commands for the drone.

    :param tello: the Tello instance to control (main() passes one that is
        already connected).
    """

    def __init__(self, tello):
        pygame.init()
        pygame.display.set_caption("Tello")
        self._size = (960, 720)
        self._screen = pygame.display.set_mode(self._size)
        self._running = False
        self._tello = tello
        # Never started; lets is_alive() be queried before the first
        # takeoff/land thread exists.
        self._thread_tello = threading.Thread()
        self._speed_keys = 50
        self._rc_velo_right = 0
        self._rc_velo_forward = 0
        self._rc_velo_up = 0
        self._rc_velo_clockwise = 0
        self._font = pygame.font.Font(pygame.font.get_default_font(), 16)
        self._text_busy = ''
        self._joysticks = []
        self._init_joysticks()

    def __del__(self):
        pygame.quit()

    def _init_joysticks(self):
        """Open every connected joystick/gamepad."""
        pygame.joystick.init()
        for cnt in range(pygame.joystick.get_count()):
            joystick = pygame.joystick.Joystick(cnt)
            logging.info("Joystick detected: {}".format(joystick.get_name()))
            joystick.init()
            # Keep a reference: pygame shuts a Joystick down when the object
            # is deallocated, after which it no longer delivers events.
            self._joysticks.append(joystick)

    def _toggle_fullscreen(self):
        if self._screen.get_flags() & pygame.FULLSCREEN:
            self._screen = pygame.display.set_mode(self._size)
        else:
            self._screen = pygame.display.set_mode(
                self._size,
                pygame.FULLSCREEN
            )

    def _record_image(self):
        """Save the current camera frame as a timestamped PNG file."""
        # Fetch once so the None check and the conversion see the same frame.
        frame = self._tello.get_frame()
        if frame is None:
            logging.warning("No frame available, image not recorded.")
            return
        filename = "tello_{}.png".format(
            datetime.now().strftime("%Y-%m-%d_%H%M%S")
        )
        logging.info("Recording image to {}".format(filename))
        image = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        cv2.imwrite(filename, image)

    def _start_takeoff(self):
        """Start takeoff in a worker thread unless a command is already running."""
        if not self._thread_tello.is_alive():
            logging.info("TAKING OFF")
            self._text_busy = "TAKING OFF"
            self._thread_tello = threading.Thread(target=self._tello.takeoff)
            self._thread_tello.start()

    def _start_land(self):
        """Start landing in a worker thread unless a command is already running."""
        if not self._thread_tello.is_alive():
            logging.info("LANDING")
            self._text_busy = "LANDING"
            self._thread_tello = threading.Thread(target=self._tello.land)
            self._thread_tello.start()

    def _handle_keydown(self, key):
        """Keyboard bindings: arrows move, W/S up/down, A/D rotate,
        T takeoff, L land, R record image, F fullscreen, ESC quit."""
        if key == pygame.K_ESCAPE:
            self._running = False
        elif key == pygame.K_UP:
            self._rc_velo_forward = self._speed_keys
        elif key == pygame.K_DOWN:
            self._rc_velo_forward = -self._speed_keys
        elif key == pygame.K_LEFT:
            self._rc_velo_right = -self._speed_keys
        elif key == pygame.K_RIGHT:
            self._rc_velo_right = self._speed_keys
        elif key == pygame.K_w:
            self._rc_velo_up = self._speed_keys
        elif key == pygame.K_s:
            self._rc_velo_up = -self._speed_keys
        elif key == pygame.K_a:
            self._rc_velo_clockwise = -self._speed_keys
        elif key == pygame.K_d:
            self._rc_velo_clockwise = self._speed_keys
        elif key == pygame.K_t:
            self._start_takeoff()
        elif key == pygame.K_l:
            self._start_land()
        elif key == pygame.K_r:
            self._record_image()
        elif key == pygame.K_f:
            self._toggle_fullscreen()

    def _handle_keyup(self, key):
        """Reset the axis controlled by a released movement key to 0."""
        if key == pygame.K_UP:
            self._rc_velo_forward = 0
        elif key == pygame.K_DOWN:
            self._rc_velo_forward = 0
        elif key == pygame.K_LEFT:
            self._rc_velo_right = 0
        elif key == pygame.K_RIGHT:
            self._rc_velo_right = 0
        elif key == pygame.K_w:
            self._rc_velo_up = 0
        elif key == pygame.K_s:
            self._rc_velo_up = 0
        elif key == pygame.K_a:
            self._rc_velo_clockwise = 0
        elif key == pygame.K_d:
            self._rc_velo_clockwise = 0

    def _handle_joybuttondown(self, button):
        if button == 1:    # PS4 red circle
            self._start_land()
        elif button == 2:  # PS4 green triangle
            self._start_takeoff()
        elif button == 5:  # PS4 R1
            self._record_image()

    def _handle_joyaxismotion(self, axis, value):
        """Map a gamepad axis value to an rc velocity scaled by 100
        (the up/down axes are negated)."""
        if axis == 0:      # PS4 left left/right
            self._rc_velo_clockwise = int(100.0 * value)
        elif axis == 1:    # PS4 left up/down
            self._rc_velo_up = int(-100.0 * value)
        elif axis == 3:    # PS4 right left/right
            self._rc_velo_right = int(100.0 * value)
        elif axis == 4:    # PS4 right up/down
            self._rc_velo_forward = int(-100.0 * value)

    def _determine_color_perc(self, value):
        """Green above 66, yellow above 33, else red; white if not an int."""
        color = (255, 255, 255)
        try:
            i_val = int(value)
            if i_val > 66:
                color = (0, 255, 0)
            elif i_val > 33:
                color = (255, 255, 0)
            else:
                color = (255, 0, 0)
        except (TypeError, ValueError):
            pass
        return color

    def _determine_color_temp(self, value):
        """Red above 80, otherwise (or if not an int) white."""
        color = (255, 255, 255)
        try:
            if int(value) > 80:
                color = (255, 0, 0)
        except (TypeError, ValueError):
            pass
        return color

    def _render_state(self, state):
        """Draw one line per state field in the top-left corner."""
        pos_y = 8
        # Iterate a snapshot: the state dict may be updated concurrently by
        # the receiver thread, and adding keys during iteration would raise.
        for key, val in dict(state).items():
            text = "{}: {}".format(key, val)
            color = (255, 255, 255)
            if key == 'bat':
                color = self._determine_color_perc(val)
            elif key in ('templ', 'temph'):
                color = self._determine_color_temp(val)
            surf_text = self._font.render(text, True, color)
            self._screen.blit(surf_text, (8, pos_y))
            pos_y += 20

    def _render_wifi(self, wifi_snr):
        """Draw the WiFi SNR in the top-right corner."""
        width, _ = self._size
        text = "WIFI {}".format(wifi_snr)
        color = self._determine_color_perc(wifi_snr)
        surf_text = self._font.render(text, True, color)
        rect_text = surf_text.get_rect(topright=(width - 8, 8))
        self._screen.blit(surf_text, rect_text)

    def _render_control(self, text):
        """Draw text centered near the bottom edge."""
        width, height = self._size
        surf_text = self._font.render(text, True, (255, 255, 255))
        rect_text = surf_text.get_rect(center=(width // 2, height - 16))
        self._screen.blit(surf_text, rect_text)

    def mainloop(self):
        """Run the event/render loop (at most 30 fps) until ESC or window close.

        While a takeoff/land thread is running no rc commands are sent and
        the busy text is shown instead of the rc values.
        """
        wifi_event = pygame.USEREVENT + 1
        clock = pygame.time.Clock()
        # Refresh the WiFi SNR once per second.
        pygame.time.set_timer(wifi_event, 1000)
        wifi_snr = self._tello.read_wifi_snr()
        self._running = True
        while self._running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self._running = False
                elif event.type == pygame.KEYDOWN:
                    self._handle_keydown(event.key)
                elif event.type == pygame.KEYUP:
                    self._handle_keyup(event.key)
                elif event.type == pygame.JOYBUTTONDOWN:
                    self._handle_joybuttondown(event.button)
                elif event.type == pygame.JOYAXISMOTION:
                    self._handle_joyaxismotion(event.axis, event.value)
                elif event.type == wifi_event:
                    # Skip while takeoff/land waits for its own response on
                    # the shared command socket.
                    if not self._thread_tello.is_alive():
                        wifi_snr = self._tello.read_wifi_snr()
            frame = self._tello.get_frame()
            if frame is None:
                self._screen.fill((0, 0, 0))
            else:
                # surfarray expects (width, height, ...), the frame is
                # (height, width, ...).
                frame = np.swapaxes(frame, 0, 1)
                surface = pygame.surfarray.make_surface(frame)
                self._screen.blit(surface, (0, 0))
            state = self._tello.get_state()
            self._render_state(state)
            self._render_wifi(wifi_snr)
            if self._thread_tello.is_alive():
                self._render_control(self._text_busy)
            else:
                self._render_control("rc {} {} {} {}".format(
                    self._rc_velo_right,
                    self._rc_velo_forward,
                    self._rc_velo_up,
                    self._rc_velo_clockwise
                ))
                self._tello.send_rc(
                    self._rc_velo_right,
                    self._rc_velo_forward,
                    self._rc_velo_up,
                    self._rc_velo_clockwise
                )
            pygame.display.update()
            clock.tick(30)
        # Cancel the WiFi timer so it does not outlive the loop.
        pygame.time.set_timer(wifi_event, 0)


def main():
    """Connect to the drone, run the pygame controller, then disconnect."""
    logging.basicConfig(
        format="%(asctime)s %(levelname)s: %(message)s",
        level=logging.INFO
    )
    tello = Tello()
    if not tello.connect():
        logging.error("Unable to connect.")
        return
    logging.info("Connection established.")
    try:
        window = TelloPygameWindow(tello)
        window.mainloop()
    finally:
        # Runs even if the UI raises: disconnect() sends a zero rc command,
        # stops the receiver threads and sends "streamoff".
        tello.disconnect()


# Run the controller only when executed as a script, not when imported.
if __name__ == '__main__':
    main()