#!/usr/bin/env python3

# dji_fpv_srt_telemetry.py
# 2023-01-12
# by Gernot Walzl

# The DJI FPV drone records SRT files to track GPS coordinates, altitudes, etc.
# This script calculates the drone's height, vertical speed, ground speed and
# heading and writes them as subtitles (SRT or WebVTT) that can be used as an
# overlay for the MP4 video file.
#
# To extract flight data at a specified start time, ffmpeg can be used:
# ffmpeg -ss 00:00:10 -i DJI_0100.SRT output.srt

import argparse
import datetime
import re
import pyproj


class DjiFpvSrtFileReader:
    """Read the per-frame telemetry that the DJI FPV drone writes to SRT files.

    Each subtitle entry is expected to contain a timestamp line
    ("YYYY-MM-DD HH:MM:SS.f") followed by a line of bracketed fields such as
    "[iso : 100] [latitude: 47.1]".  Every line with at least one such field
    becomes one frame record (a dict) tagged with the most recently seen
    timestamp (None if no timestamp line preceded it).
    """

    # Field names whose values are converted to int / float; all other
    # fields are kept as stripped strings.
    _INT_VARS = frozenset(['iso', 'fnum', 'ev', 'ct', 'focal_len'])
    _FLOAT_VARS = frozenset(['latitude', 'longitude', 'altitude'])

    def __init__(self):
        self._pattern_dt = re.compile(r'(\d+\-\d+\-\d+ \d+:\d+:\d+\.\d+)')
        # NOTE: the inner '[' is a literal, so everything between the outer
        # brackets is ONE character class: a field is '[', then a run of word
        # characters, spaces, '+', '*', ':', '[', '.', '-', '/', then ']'.
        self._pattern_var = re.compile(r'\[([\w+ *: *[a-zA-Z0-9.\-/]+)\]')
        self._frame_infos = []  # one dict per frame, in file order
        self._frame_time = 0.0  # mean seconds between consecutive frames

    def read(self, filename):
        """Parse *filename*, replacing any previously read frames.

        The file is decoded as UTF-8 regardless of the platform locale.
        Afterwards get_frame_time() returns the mean interval between the
        first and the last frame, or 0.0 when fewer than two frames were
        found (no interval can be measured).
        """
        self._frame_infos.clear()
        frame_datetime = None
        with open(filename, 'r', encoding='utf-8') as file_srt:
            for line in file_srt:
                str_dts = self._pattern_dt.findall(line)
                if str_dts:
                    frame_datetime = datetime.datetime.strptime(
                        str_dts[0], '%Y-%m-%d %H:%M:%S.%f')
                    continue
                str_vars = self._pattern_var.findall(line)
                if str_vars:
                    infos = {'datetime': frame_datetime}
                    for str_var in str_vars:
                        varname, value = self._parse_var(str_var)
                        infos[varname] = value
                    self._frame_infos.append(infos)
        self._frame_time = self._calc_frame_time()

    def _parse_var(self, str_var):
        """Split one "name : value" field into (name, converted value)."""
        # partition() splits on the first ':' only; the field regex allows
        # ':' inside values, which made split(':') raise ValueError.
        varname, _, value = str_var.partition(':')
        varname = varname.strip()
        str_value = value.strip()
        if varname in self._INT_VARS:
            return varname, int(str_value)
        if varname in self._FLOAT_VARS:
            return varname, float(str_value)
        return varname, str_value

    def _calc_frame_time(self):
        """Return the mean number of seconds between consecutive frames."""
        if len(self._frame_infos) < 2:
            # Avoids IndexError (no frames) / ZeroDivisionError (one frame).
            return 0.0
        span = (self._frame_infos[-1]['datetime'] -
                self._frame_infos[0]['datetime'])
        return span.total_seconds() / (len(self._frame_infos) - 1)

    def get(self, varname, step=1):
        """Return the value of *varname* for every *step*-th frame.

        Sampling starts at the first frame.  Raises KeyError if a sampled
        frame has no such field.
        """
        return [self._frame_infos[cnt][varname]
                for cnt in range(0, len(self._frame_infos), step)]

    def get_frame_time(self):
        """Return the mean frame interval in seconds computed by read()."""
        return self._frame_time


class TelemetryCalculator:
    """Derive flight telemetry from sampled GPS positions and altitudes.

    Inputs are sequences indexed by sample.  Rates (speeds, headings) are
    computed per interval between consecutive samples, so those result
    lists hold one element fewer than the inputs.
    """

    def __init__(self):
        # WGS84 ellipsoid used for geodesic azimuth/distance between fixes
        self._geod = pyproj.Geod(ellps='WGS84')
        self._ground_speeds = []
        self._directions = []
        self._vert_speeds = []
        self._heights = []

    def calc_geo(self, datetimes, latitudes, longitudes):
        """Compute heading and ground speed for each interval.

        Headings are the azimuths returned by Geod.inv, mapped from negative
        values into [0, 360).  When two consecutive fixes coincide (distance
        0.0) the previous heading is repeated (0.0 before the first move).
        Ground speed is the geodesic distance divided by elapsed seconds.
        """
        azimuths, _, distances = self._geod.inv(
            longitudes[:-1], latitudes[:-1], longitudes[1:], latitudes[1:])

        heading = 0.0
        headings = []
        for azimuth, distance in zip(azimuths, distances):
            if distance != 0.0:
                # a stationary interval has no meaningful azimuth
                heading = azimuth + 360.0 if azimuth < 0.0 else azimuth
            headings.append(heading)
        self._directions = headings

        speeds = []
        for idx, distance in enumerate(distances):
            elapsed = (datetimes[idx + 1] - datetimes[idx]).total_seconds()
            speeds.append(distance / elapsed)
        self._ground_speeds = speeds

    def calc_alt(self, datetimes, altitudes, alt_offset=None):
        """Compute heights and vertical speeds.

        height = (altitude - alt_offset) * 10.0, where alt_offset defaults
        to the lowest altitude given.  NOTE(review): the factor 10 presumably
        converts the drone's recorded altitude units to metres -- confirm.
        Vertical speed is the height difference per elapsed second.
        """
        offset = min(altitudes) if alt_offset is None else alt_offset
        self._heights = [(alt - offset) * 10.0 for alt in altitudes]
        self._vert_speeds = [
            (h_next - h_prev)
            / (datetimes[idx + 1] - datetimes[idx]).total_seconds()
            for idx, (h_prev, h_next) in enumerate(
                zip(self._heights, self._heights[1:]))]

    def get_ground_speeds(self):
        """Return the per-interval ground speeds from calc_geo()."""
        return self._ground_speeds

    def get_directions(self):
        """Return the per-interval headings in degrees from calc_geo()."""
        return self._directions

    def get_vert_speeds(self):
        """Return the per-interval vertical speeds from calc_alt()."""
        return self._vert_speeds

    def get_heights(self):
        """Return the per-sample heights from calc_alt()."""
        return self._heights


class SrtFileWriter:
    """Write subtitles as a SubRip (SRT) file whose cues all last equally long."""

    def __init__(self, subtitle_time):
        # duration of every cue in seconds
        self._subtitle_time = subtitle_time

    def _format_time(self, float_time):
        """Format *float_time* seconds as an SRT timestamp "HH:MM:SS,mmm".

        The value is rounded to whole milliseconds before it is split into
        fields; truncating the float remainder instead rendered e.g. 1.001 s
        (stored as 1.000999...) as "00:00:01,000".
        """
        total_ms = int(round(float_time * 1000))
        total_seconds, milliseconds = divmod(total_ms, 1000)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return '{:02d}:{:02d}:{:02d},{:03d}'.format(
            hours, minutes, seconds, milliseconds)

    def write(self, filename, subtitles):
        """Write *subtitles* to *filename*, one numbered cue per subtitle.

        Cue i (1-based) spans (i-1)*subtitle_time .. i*subtitle_time.  The
        file is written as UTF-8 so non-ASCII subtitle text (e.g. arrows or
        degree signs) does not depend on the platform's default encoding.
        """
        with open(filename, 'w', encoding='utf-8') as file_srt:
            time_start = 0.0
            for cnt, subtitle in enumerate(subtitles, start=1):
                time_end = cnt * self._subtitle_time
                file_srt.write("{}\n".format(cnt))
                file_srt.write("{} --> {}\n".format(
                    self._format_time(time_start),
                    self._format_time(time_end)))
                file_srt.write("{}\n".format(subtitle))
                file_srt.write("\n")
                time_start = time_end


class VttFileWriter:
    """Write subtitles as a WebVTT file whose cues all last equally long."""

    def __init__(self, subtitle_time):
        # duration of every cue in seconds
        self._subtitle_time = subtitle_time

    def _format_time(self, float_time):
        """Format *float_time* seconds as a WebVTT timestamp.

        Uses "mm:ss.ttt" below one hour and "hh:mm:ss.ttt" from one hour on,
        because WebVTT only allows minutes 00-59.  The value is rounded to
        whole milliseconds first; truncating the float remainder rendered
        e.g. 1.001 s as "00:01.000".
        """
        total_ms = int(round(float_time * 1000))
        total_seconds, milliseconds = divmod(total_ms, 1000)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        if hours:
            return '{:02d}:{:02d}:{:02d}.{:03d}'.format(
                hours, minutes, seconds, milliseconds)
        return '{:02d}:{:02d}.{:03d}'.format(minutes, seconds, milliseconds)

    def write(self, filename, subtitles):
        """Write *subtitles* to *filename*, one cue per subtitle.

        Cue i (1-based) spans (i-1)*subtitle_time .. i*subtitle_time.  The
        file is written as UTF-8, the encoding WebVTT requires.
        """
        with open(filename, 'w', encoding='utf-8') as file_vtt:
            file_vtt.write("WEBVTT\n")
            time_start = 0.0
            for cnt, subtitle in enumerate(subtitles, start=1):
                time_end = cnt * self._subtitle_time
                # blank line separates cues (and the header from cue 1)
                file_vtt.write("\n")
                file_vtt.write("{} --> {}\n".format(
                    self._format_time(time_start),
                    self._format_time(time_end)))
                file_vtt.write("{}\n".format(subtitle))
                time_start = time_end


def _format_telemetry(height, vert_speed, ground_speed, direction):
    """Return the overlay text for one telemetry sample.

    The sign of the vertical speed is shown as an arrow ('↓' when negative,
    '↑' otherwise) followed by its magnitude.
    """
    vert_speed_ud = '↓' if vert_speed < 0.0 else '↑'
    return "h={:03.0f}m {}{:04.1f}m/s    v={:04.1f}m/s {:03.0f}°".format(
        height, vert_speed_ud, abs(vert_speed), ground_speed, direction)


def main(inputfile, outputfile, alt_offset=None, step=6):
    """Convert a DJI FPV SRT telemetry file into an overlay subtitle file.

    Args:
        inputfile: path of the SRT file recorded by the drone.
        outputfile: path of the subtitle file to write; its extension
            ('.srt' or '.vtt', case-insensitive) selects the format.
        alt_offset: raw altitude subtracted before computing heights;
            defaults to the lowest sampled altitude.
        step: only every *step*-th frame is evaluated, so each subtitle
            covers *step* frames.

    Raises:
        ValueError: if the output extension is neither .srt nor .vtt.
    """
    # Choose the writer first so an unsupported extension fails fast with a
    # clear message (previously: AttributeError on None after all the work).
    # Case-insensitive, since DJI itself uses upper-case names like *.SRT.
    output_lower = outputfile.lower()
    if output_lower.endswith('.srt'):
        writer_class = SrtFileWriter
    elif output_lower.endswith('.vtt'):
        writer_class = VttFileWriter
    else:
        raise ValueError(
            "unsupported output file extension (expected .srt or .vtt): "
            + outputfile)

    srtfilereader = DjiFpvSrtFileReader()
    srtfilereader.read(inputfile)
    datetimes = srtfilereader.get("datetime", step)
    latitudes = srtfilereader.get("latitude", step)
    longitudes = srtfilereader.get("longitude", step)
    altitudes = srtfilereader.get("altitude", step)
    frame_time = srtfilereader.get_frame_time()

    telemetry = TelemetryCalculator()
    telemetry.calc_geo(datetimes, latitudes, longitudes)
    telemetry.calc_alt(datetimes, altitudes, alt_offset)

    # heights has one entry per sample, the rate lists one per interval;
    # zip() pairs each interval with the height at its start.
    subtitles = [
        _format_telemetry(height, vert_speed, ground_speed, direction)
        for height, vert_speed, ground_speed, direction in zip(
            telemetry.get_heights(), telemetry.get_vert_speeds(),
            telemetry.get_ground_speeds(), telemetry.get_directions())]

    writer_class(frame_time * step).write(outputfile, subtitles)


if __name__ == '__main__':
    # Command-line entry point: convert INPUTFILE (a DJI FPV SRT file) into
    # OUTPUTFILE; --altoffset overrides the altitude baseline used by main().
    cli = argparse.ArgumentParser()
    cli.add_argument("inputfile", type=str)
    cli.add_argument("outputfile", type=str)
    cli.add_argument("--altoffset", type=float)
    opts = cli.parse_args()
    main(opts.inputfile, opts.outputfile, alt_offset=opts.altoffset)