package infinitime

import (
	"archive/zip"
	"bytes"
	"encoding/binary"
	"encoding/json"
	"errors"
	"io"
	"io/fs"
	"io/ioutil"
	"os"
	"time"

	"github.com/muka/go-bluetooth/bluez"
	"github.com/muka/go-bluetooth/bluez/profile/gatt"
)

// UUIDs of the GATT characteristics used by the DFU process.
const (
	DFUCtrlPointChar = "00001531-1212-efde-1523-785feabcd123" // UUID of Control Point characteristic
	DFUPacketChar    = "00001532-1212-efde-1523-785feabcd123" // UUID of Packet characteristic
)

const (
	DFUSegmentSize     = 20 // Size of each firmware packet
	DFUPktRecvInterval = 10 // Amount of packets to send before checking for receipt
)

// Commands written to the control point characteristic, listed in the order
// Start sends them.
var (
	DFUCmdStart           = []byte{0x01, 0x04}
	DFUCmdRecvInitPkt     = []byte{0x02, 0x00}
	DFUCmdInitPktComplete = []byte{0x02, 0x01}
	// DFUCmdPktReceiptInterval requests a receipt notification every
	// DFUPktRecvInterval packets. The count is derived from the constant so the
	// interval requested from the device always matches how often stepSeven
	// waits for a receipt; the compiler rejects values that do not fit a byte.
	DFUCmdPktReceiptInterval = []byte{0x08, DFUPktRecvInterval}
	DFUCmdRecvFirmware       = []byte{0x03}
	DFUCmdValidate           = []byte{0x04}
	DFUCmdActivateReset      = []byte{0x05}
)

// Control point notifications that Start waits for before continuing. Each
// starts with 0x10, and the second byte matches the first byte of the command
// being answered; the trailing 0x01 is presumably a success status.
var (
	DFUResponseStart            = []byte{0x10, 0x01, 0x01}
	DFUResponseInitParams       = []byte{0x10, 0x02, 0x01}
	DFUResponseRecvFwImgSuccess = []byte{0x10, 0x03, 0x01}
	DFUResponseValidate         = []byte{0x10, 0x04, 0x01}
)

// DFUNotifPktRecvd prefixes the packet receipt notification; the bytes after
// it hold the byte count received by the device as a little-endian uint32.
var DFUNotifPktRecvd = []byte{0x11}

// Errors returned by the DFU process.
var (
	ErrDFUInvalidInput    = errors.New("input file invalid, must be a .bin file")
	ErrDFUTimeout         = errors.New("timed out waiting for response")
	ErrDFUNoFilesLoaded   = errors.New("no files are loaded")
	ErrDFUInvalidResponse = errors.New("invalid response returned")
	ErrDFUSizeMismatch    = errors.New("amount of bytes sent does not match amount received")
)

// btOptsCmd are the WriteValue options (write type "command") used for the
// final activate-and-reset write.
var btOptsCmd = map[string]interface{}{"type": "command"}

// DFUProgress is a snapshot of a firmware transfer, sent on the channel
// returned by DFU.Progress.
type DFUProgress struct {
	// Firmware bytes written to the packet characteristic so far
	Sent     int   `json:"sent"`
	// Byte count last reported by the device in a packet receipt notification
	Received int   `json:"recvd"`
	// Total size of the firmware image in bytes
	Total    int64 `json:"total"`
}

// DFU stores everything required for doing firmware upgrades: the files to
// send, the GATT characteristics used to send them, and transfer counters.
// Files are loaded with LoadFiles or LoadArchive before calling Start, and
// Reset prepares the value for the next upgrade.
type DFU struct {
	// Init packet (.dat); read in full and written to the device by stepFour
	initPacket    fs.File
	// Firmware image (.bin); streamed in DFUSegmentSize segments by stepSeven
	fwImage       fs.File
	// Control point property changes, buffered by bufferPropsCh; set by Start
	// and consumed by on
	ctrlRespCh    <-chan *bluez.PropertyChanged
	// Size of fwImage in bytes; sent to the device by stepTwo
	fwSize        int64
	// Firmware bytes written to the packet characteristic so far
	bytesSent     int
	// Byte count last reported by the device in a receipt notification
	bytesRecvd    int
	// Set by stepSeven once all fwSize bytes have been sent
	fwSendDone    bool
	// Progress updates; created by Progress or Start, closed when Start returns
	progress      chan DFUProgress
	// Control point characteristic; assigned outside this section (presumably
	// the one with UUID DFUCtrlPointChar — TODO confirm)
	ctrlPointChar *gatt.GattCharacteristic1
	// Packet characteristic; assigned outside this section (presumably the one
	// with UUID DFUPacketChar — TODO confirm)
	packetChar    *gatt.GattCharacteristic1
}

// LoadFiles loads an init packet (.dat) and firmware image (.bin) from the
// local filesystem and records the firmware size for the transfer.
//
// dfu is only updated when every step succeeds. On error, any file opened
// by this call is closed again so no descriptor is leaked, and the error
// from os.Open or Stat is returned unchanged.
func (dfu *DFU) LoadFiles(initPath, fwPath string) error {
	// Open init packet file
	initPktFl, err := os.Open(initPath)
	if err != nil {
		return err
	}

	// Open firmware image file
	fwImgFl, err := os.Open(fwPath)
	if err != nil {
		// Best-effort cleanup; the open error is the one worth reporting.
		_ = initPktFl.Close()
		return err
	}

	// Get firmware file size
	fwSize, err := getFlSize(fwImgFl)
	if err != nil {
		_ = initPktFl.Close()
		_ = fwImgFl.Close()
		return err
	}

	dfu.initPacket = initPktFl
	dfu.fwImage = fwImgFl
	dfu.fwSize = fwSize
	return nil
}

// archiveManifest is the part of an archive's manifest.json that LoadArchive
// reads: the names of the init packet and firmware image inside the archive.
type archiveManifest struct {
	Manifest struct {
		Application struct {
			// Path inside the archive of the firmware image (.bin)
			BinFile string `json:"bin_file"`
			// Path inside the archive of the init packet (.dat)
			DatFile string `json:"dat_file"`
		} `json:"application"`
	} `json:"manifest"`
}

// LoadArchive loads an init packet and firmware image from a zip archive,
// locating them through the manifest.json stored in the archive.
//
// Both entries read from the archive file, so it has to stay open after
// this call returns. Closing dfu.fwImage also closes the archive file.
// dfu is only updated when every step succeeds. On error, everything opened
// by this call (archive, manifest, entries) is closed again.
func (dfu *DFU) LoadArchive(archivePath string) error {
	// Open archive file
	archiveFl, err := os.Open(archivePath)
	if err != nil {
		return err
	}

	// Everything in opened is closed when this function returns. On success
	// ownership moves to dfu and opened is cleared first.
	opened := []io.Closer{archiveFl}
	defer func() {
		for _, c := range opened {
			// Best-effort cleanup; the load error is the one worth reporting.
			_ = c.Close()
		}
	}()

	// Get archive size
	archiveSize, err := getFlSize(archiveFl)
	if err != nil {
		return err
	}

	// Create zip reader from archive file
	zipReader, err := zip.NewReader(archiveFl, archiveSize)
	if err != nil {
		return err
	}

	// Open manifest.json from zip archive and decode it as JSON
	manifestFl, err := zipReader.Open("manifest.json")
	if err != nil {
		return err
	}
	var manifest archiveManifest
	err = json.NewDecoder(manifestFl).Decode(&manifest)
	// The manifest is not needed once decoded.
	_ = manifestFl.Close()
	if err != nil {
		return err
	}

	// Open init packet from zip archive
	initPktFl, err := zipReader.Open(manifest.Manifest.Application.DatFile)
	if err != nil {
		return err
	}
	opened = append(opened, initPktFl)

	// Open firmware image from zip archive
	fwImgFl, err := zipReader.Open(manifest.Manifest.Application.BinFile)
	if err != nil {
		return err
	}
	opened = append(opened, fwImgFl)

	// Get file size of firmware image
	fwSize, err := getFlSize(fwImgFl)
	if err != nil {
		return err
	}

	// Success: hand everything over to dfu instead of closing it.
	opened = nil
	dfu.initPacket = initPktFl
	dfu.fwImage = archiveEntry{File: fwImgFl, archive: archiveFl}
	dfu.fwSize = fwSize
	return nil
}

// archiveEntry is a file inside a zip archive that also owns the archive's
// underlying file, so that closing the entry releases the archive too.
type archiveEntry struct {
	fs.File
	archive io.Closer
}

// Close closes the entry and then the archive it was read from, returning
// the first error encountered.
func (e archiveEntry) Close() error {
	err := e.File.Close()
	if archErr := e.archive.Close(); err == nil {
		err = archErr
	}
	return err
}

// getFlSize reports the size in bytes that fl.Stat returns. If Stat fails,
// it returns 0 together with that error.
func getFlSize(fl fs.File) (int64, error) {
	info, statErr := fl.Stat()
	if statErr == nil {
		return info.Size(), nil
	}
	return 0, statErr
}

// Progress returns the channel on which transfer progress is published,
// creating it with a buffer of 5 on first use. Start closes this channel
// when it returns.
func (dfu *DFU) Progress() <-chan DFUProgress {
	if dfu.progress != nil {
		return dfu.progress
	}
	dfu.progress = make(chan DFUProgress, 5)
	return dfu.progress
}

// sendProgress publishes the current transfer counters on dfu.progress.
// The send blocks while the channel's buffer is full, so whoever starts a
// DFU should keep draining the channel returned by Progress.
func (dfu *DFU) sendProgress() {
	snapshot := DFUProgress{
		Total:    dfu.fwSize,
		Sent:     dfu.bytesSent,
		Received: dfu.bytesRecvd,
	}
	dfu.progress <- snapshot
}

// bufferPropsCh returns a channel with a buffer of 5 that receives every
// value sent on in, in order. A background goroutine does the forwarding;
// it closes the returned channel and exits once in is closed.
func bufferPropsCh(in chan *bluez.PropertyChanged) chan *bluez.PropertyChanged {
	buffered := make(chan *bluez.PropertyChanged, 5)
	forward := func() {
		for {
			prop, ok := <-in
			if !ok {
				close(buffered)
				return
			}
			buffered <- prop
		}
	}
	go forward()
	return buffered
}

// Start runs the DFU process. It enables notifications on the control point,
// watches its property changes, and then runs steps one through nine in
// order. Before steps three, five, eight and nine it waits for the matching
// control point response.
//
// It returns ErrDFUNoFilesLoaded if no files were loaded, otherwise the first
// error from any step. The progress channel is closed on every return path.
func (dfu *DFU) Start() error {
	if dfu.progress == nil {
		dfu.progress = make(chan DFUProgress, 5)
	}
	defer close(dfu.progress)

	if dfu.initPacket == nil || dfu.fwImage == nil {
		return ErrDFUNoFilesLoaded
	}

	if err := dfu.ctrlPointChar.StartNotify(); err != nil {
		return err
	}

	propsCh, err := dfu.ctrlPointChar.WatchProperties()
	if err != nil {
		return err
	}
	// Buffer property changes so responses are not missed while a write is
	// still in progress.
	dfu.ctrlRespCh = bufferPropsCh(propsCh)

	// whenResponse builds a step that waits for resp on the control point
	// and then runs next.
	whenResponse := func(resp []byte, next func() error) func() error {
		return func() error {
			return dfu.on(resp, func([]byte) error { return next() })
		}
	}

	steps := []func() error{
		dfu.stepOne,   // DFUCmdStart
		dfu.stepTwo,   // firmware size on the packet characteristic
		whenResponse(DFUResponseStart, dfu.stepThree), // 0x100101, then DFUCmdRecvInitPkt
		dfu.stepFour, // init packet, then DFUCmdInitPktComplete
		whenResponse(DFUResponseInitParams, dfu.stepFive), // 0x100201, then DFUCmdPktReceiptInterval
		dfu.stepSix,   // DFUCmdRecvFirmware
		dfu.stepSeven, // firmware image
		whenResponse(DFUResponseRecvFwImgSuccess, dfu.stepEight), // 0x100301, then DFUCmdValidate
		whenResponse(DFUResponseValidate, dfu.stepNine),          // 0x100401, then DFUCmdActivateReset
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}

// Reset reverts all values back to their defaults to prepare for the next
// DFU. Any loaded init packet and firmware image are closed before their
// references are dropped, so their file handles are not leaked.
func (dfu *DFU) Reset() {
	// Close errors are ignored: the files are being discarded and Reset has
	// no way to report a failure.
	if dfu.initPacket != nil {
		_ = dfu.initPacket.Close()
	}
	if dfu.fwImage != nil {
		_ = dfu.fwImage.Close()
	}

	dfu.bytesRecvd = 0
	dfu.bytesSent = 0
	dfu.fwSize = 0
	dfu.fwSendDone = false
	dfu.fwImage = nil
	dfu.initPacket = nil
	dfu.ctrlRespCh = nil
	dfu.progress = nil
}

// on waits for the next "Value" property change on the control point
// characteristic. If the value starts with cmd, it runs onCmdCb with the
// bytes after cmd and returns the callback's result.
//
// Nil changes and changes to other properties are skipped. It returns:
//   - ErrDFUInvalidResponse if the value is not a byte slice or does not
//     start with cmd,
//   - an error if the notification channel has been closed,
//   - ErrDFUTimeout if no value arrives within 50 seconds of the call.
func (dfu *DFU) on(cmd []byte, onCmdCb func(data []byte) error) error {
	// One timer bounds the whole wait. Re-arming it on every skipped property
	// would let unrelated changes postpone the timeout indefinitely.
	timeout := time.NewTimer(50 * time.Second)
	defer timeout.Stop()

	for {
		select {
		case propChanged, ok := <-dfu.ctrlRespCh:
			if !ok {
				// A closed channel yields nil values forever; report it
				// instead of dereferencing nil.
				return errors.New("control point notification channel closed")
			}
			// Skip nil changes and changes to other properties.
			if propChanged == nil || propChanged.Name != "Value" {
				continue
			}
			data, isBytes := propChanged.Value.([]byte)
			if !isBytes || !bytes.HasPrefix(data, cmd) {
				return ErrDFUInvalidResponse
			}
			// Return callback with data after command
			return onCmdCb(data[len(cmd):])
		case <-timeout.C:
			return ErrDFUTimeout
		}
	}
}

// stepOne writes DFUCmdStart to the control point characteristic.
func (dfu *DFU) stepOne() error {
	ctrl := dfu.ctrlPointChar
	return ctrl.WriteValue(DFUCmdStart, nil)
}

// stepTwo writes a 12-byte payload to the packet characteristic. The payload
// is eight zero bytes followed by the firmware size as a little-endian
// uint32.
func (dfu *DFU) stepTwo() error {
	payload := make([]byte, 12)
	binary.LittleEndian.PutUint32(payload[8:], uint32(dfu.fwSize))
	return dfu.packetChar.WriteValue(payload, nil)
}

// stepThree writes DFUCmdRecvInitPkt to the control point characteristic.
// Start runs it once DFUResponseStart has been received.
func (dfu *DFU) stepThree() error {
	ctrl := dfu.ctrlPointChar
	return ctrl.WriteValue(DFUCmdRecvInitPkt, nil)
}

// stepFour reads the whole init packet and writes it to the packet
// characteristic in a single write. It then writes DFUCmdInitPktComplete to
// the control point characteristic.
func (dfu *DFU) stepFour() error {
	initPkt, readErr := ioutil.ReadAll(dfu.initPacket)
	if readErr != nil {
		return readErr
	}
	if writeErr := dfu.packetChar.WriteValue(initPkt, nil); writeErr != nil {
		return writeErr
	}
	return dfu.ctrlPointChar.WriteValue(DFUCmdInitPktComplete, nil)
}

// stepFive writes DFUCmdPktReceiptInterval to the control point
// characteristic. Start runs it once DFUResponseInitParams has been received.
func (dfu *DFU) stepFive() error {
	ctrl := dfu.ctrlPointChar
	return ctrl.WriteValue(DFUCmdPktReceiptInterval, nil)
}

// stepSix writes DFUCmdRecvFirmware to the control point characteristic,
// just before the firmware image is streamed by stepSeven.
func (dfu *DFU) stepSix() error {
	ctrl := dfu.ctrlPointChar
	return ctrl.WriteValue(DFUCmdRecvFirmware, nil)
}

// stepSeven streams the firmware image to the packet characteristic in
// segments of at most DFUSegmentSize bytes.
//
// After every DFUPktRecvInterval segments it waits for a DFUNotifPktRecvd
// notification and checks that the byte count reported by the device
// matches dfu.bytesSent. It returns ErrDFUSizeMismatch if the counts differ
// and ErrDFUInvalidResponse if the notification is too short. Progress is
// published after each verified receipt. Once fwSize bytes have been sent,
// it publishes a final update and sets fwSendDone.
func (dfu *DFU) stepSeven() error {
	for !dfu.fwSendDone {
		for i := 0; i < DFUPktRecvInterval; i++ {
			amtLeft := dfu.fwSize - int64(dfu.bytesSent)
			// If no bytes left to send, end transfer
			if amtLeft == 0 {
				dfu.sendProgress()
				dfu.fwSendDone = true
				return nil
			}

			segLen := int64(DFUSegmentSize)
			if amtLeft < segLen {
				segLen = amtLeft
			}
			segment := make([]byte, segLen)
			// A single Read may return fewer bytes than requested, or return
			// the last bytes together with io.EOF. io.ReadFull fills the
			// whole segment, so no unread zero bytes are sent. It reports
			// io.ErrUnexpectedEOF if the image is shorter than fwSize.
			if _, err := io.ReadFull(dfu.fwImage, segment); err != nil {
				return err
			}
			// Write segment to packet characteristic
			if err := dfu.packetChar.WriteValue(segment, nil); err != nil {
				return err
			}
			dfu.bytesSent += len(segment)
		}

		// On 0x11, verify packet receipt size
		err := dfu.on(DFUNotifPktRecvd, func(data []byte) error {
			// The receipt carries a little-endian uint32; reject truncated
			// notifications instead of panicking in Uint32.
			if len(data) < 4 {
				return ErrDFUInvalidResponse
			}
			dfu.bytesRecvd = int(binary.LittleEndian.Uint32(data))
			if dfu.bytesRecvd != dfu.bytesSent {
				return ErrDFUSizeMismatch
			}
			dfu.sendProgress()
			return nil
		})
		if err != nil {
			return err
		}
	}
	return nil
}

// stepEight writes DFUCmdValidate to the control point characteristic. Start
// runs it once DFUResponseRecvFwImgSuccess has been received.
func (dfu *DFU) stepEight() error {
	ctrl := dfu.ctrlPointChar
	return ctrl.WriteValue(DFUCmdValidate, nil)
}

// stepNine writes DFUCmdActivateReset to the control point characteristic
// using the btOptsCmd write options. Start runs it once DFUResponseValidate
// has been received.
func (dfu *DFU) stepNine() error {
	cmd, opts := DFUCmdActivateReset, btOptsCmd
	return dfu.ctrlPointChar.WriteValue(cmd, opts)
}