package self_test

import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	mrand "math/rand/v2"
	"net"
	"runtime"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/synctest"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/testutils/events"
	"github.com/quic-go/quic-go/testutils/simnet"

	"github.com/stretchr/testify/require"
)

// dropTestProtocolClientSpeaksFirst dials ln over clientConn and writes data on
// a client-initiated unidirectional stream. It then checks that the server
// receives exactly data and that the client-side write succeeded.
//
// The server connection is closed before returning. The returned client
// connection is closed by a deferred call, so it is already closed when the
// caller receives it.
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	conn, err := quic.Dial(
		ctx,
		clientConn,
		ln.Addr(),
		clientConf,
		getQuicConfig(&quic.Config{
			MaxIdleTimeout:          timeout,
			HandshakeIdleTimeout:    timeout,
			DisablePathMTUDiscovery: true,
		}),
	)
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")

	str, err := conn.OpenUniStream()
	require.NoError(t, err)
	// The write runs in a goroutine so the server side can accept and read
	// concurrently. The channel is buffered, so the send never blocks even if
	// the test fails before the result is read.
	errChan := make(chan error, 1)
	go func() {
		defer str.Close()
		_, err := str.Write(data)
		errChan <- err
	}()

	serverConn, err := ln.Accept(ctx)
	require.NoError(t, err)
	serverStr, err := serverConn.AcceptUniStream(ctx)
	require.NoError(t, err)
	b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout})
	require.NoError(t, err)
	require.Equal(t, data, b)

	// The goroutine sends the write result before closing the stream, and the
	// reader only sees EOF after that close. So the result is normally ready
	// here. Check it rather than silently discarding a write error.
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(timeout):
		t.Fatal("client stream write did not complete")
	}
	serverConn.CloseWithError(0, "")

	return conn
}

// dropTestProtocolServerSpeaksFirst dials ln over clientConn. The server then
// writes data on a server-initiated unidirectional stream, and the client
// checks that it reads exactly data.
//
// The client side runs in a goroutine. It accepts the stream, reads it to EOF,
// compares the bytes and closes the client connection. The function waits for
// both the verification result and the client connection's context to be done
// before returning the (closed) client connection. The server connection is
// closed on return.
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	conn, err := quic.Dial(
		ctx,
		clientConn,
		ln.Addr(),
		clientConf,
		getQuicConfig(&quic.Config{
			MaxIdleTimeout:          timeout,
			HandshakeIdleTimeout:    timeout,
			DisablePathMTUDiscovery: true,
		}),
	)
	require.NoError(t, err)

	// errChan carries at most one error. On success it is closed without a
	// value, so a receive yields nil.
	errChan := make(chan error, 1)
	go func() {
		defer close(errChan)
		defer conn.CloseWithError(0, "")
		str, err := conn.AcceptUniStream(ctx)
		if err != nil {
			errChan <- err
			return
		}
		b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout})
		if err != nil {
			errChan <- err
			return
		}
		if !bytes.Equal(b, data) {
			errChan <- fmt.Errorf("data mismatch: %x != %x", b, data)
			return
		}
	}()

	serverConn, err := ln.Accept(ctx)
	require.NoError(t, err)
	// Close the server connection explicitly, like the sibling helpers do. If
	// the client's CONNECTION_CLOSE is lost, it would otherwise linger until
	// the transport is shut down.
	defer serverConn.CloseWithError(0, "")
	serverStr, err := serverConn.OpenUniStream()
	require.NoError(t, err)
	_, err = serverStr.Write(data)
	require.NoError(t, err)
	require.NoError(t, serverStr.Close())

	// Wait for the client goroutine to finish receiving and verifying the data.
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(timeout):
		t.Fatal("client did not receive the server's data")
	}

	// The goroutine closes the client connection on exit. Wait for that close
	// to take effect.
	select {
	case <-conn.Context().Done():
	case <-time.After(timeout):
		t.Fatal("client connection not closed")
	}

	return conn
}

// dropTestProtocolNobodySpeaks only exercises the handshake. The client dials,
// the server accepts, and neither side opens a stream, so the data argument is
// unused. The accepted server connection is closed immediately. The client
// connection is closed by a deferred call, so the caller receives it already
// closed.
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, _ []byte) *quic.Conn {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	quicConf := getQuicConfig(&quic.Config{
		MaxIdleTimeout:          timeout,
		HandshakeIdleTimeout:    timeout,
		DisablePathMTUDiscovery: true,
	})
	conn, err := quic.Dial(ctx, clientConn, ln.Addr(), clientConf, quicConf)
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")

	accepted, err := ln.Accept(ctx)
	require.NoError(t, err)
	accepted.CloseWithError(0, "")

	return conn
}

// dropCallbackDropNthPacket returns a drop callback that numbers packets
// separately per direction (toward the client and toward the server), starting
// at 1. It drops a packet when its per-direction number is listed in ns and its
// direction matches dir (any direction matches when dir is directionBoth).
// Counters advance for every packet in either direction, whether or not that
// direction is being filtered. Packets in any other direction are never dropped
// or counted.
func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
	var toClient, toServer atomic.Int32
	return func(d direction, _ simnet.Packet) bool {
		var counter *atomic.Int32
		switch d {
		case directionToClient:
			counter = &toClient
		case directionToServer:
			counter = &toServer
		default:
			return false
		}
		nth := int(counter.Add(1))
		if dir != directionBoth && dir != d {
			return false
		}
		return slices.Contains(ns, nth)
	}
}

// dropCallbackDropOneThird returns a drop callback that drops packets
// travelling in direction dir (any direction when dir is directionBoth) with
// probability 1/3. Packets in other directions are always delivered. At most
// maxSequentiallyDropped consecutive packets are dropped per direction. Once
// that limit is exceeded, the packet is let through and the counter resets.
func dropCallbackDropOneThird(dir direction) func(direction, simnet.Packet) bool {
	const maxSequentiallyDropped = 10
	var mx sync.Mutex
	var toClient, toServer int
	return func(d direction, _ simnet.Packet) bool {
		// Honor the configured direction, matching dropCallbackDropNthPacket.
		// Ignoring it would make the per-direction subtests all drop packets
		// both ways.
		if dir != directionBoth && d != dir {
			return false
		}
		drop := mrand.IntN(3) == 0

		mx.Lock()
		defer mx.Unlock()
		// never drop more than 10 consecutive packets
		if d == directionToClient || d == directionBoth {
			if drop {
				toClient++
				if toClient > maxSequentiallyDropped {
					drop = false
				}
			}
			if !drop {
				toClient = 0
			}
		}
		if d == directionToServer || d == directionBoth {
			if drop {
				toServer++
				if toServer > maxSequentiallyDropped {
					drop = false
				}
			}
			if !drop {
				toServer = 0
			}
		}
		return drop
	}
}

// TestHandshakeWithPacketLoss runs handshakes over a simulated network
// (RTT 20ms) that drops packets in a given direction. The drop pattern is
// either the 1st packet, the first 3 packets, or roughly one third of all
// packets. Every pattern/direction pair is run for each TLS/QUIC configuration
// below (Retry, post-quantum key exchange, long certificate chain) and for
// each application protocol (client speaks first, server speaks first,
// nobody speaks).
func TestHandshakeWithPacketLoss(t *testing.T) {
	data := GeneratePRData(5000)
	const timeout = 2 * time.Minute
	const rtt = 20 * time.Millisecond

	type dropPattern string

	const (
		dropPatternDrop1stPacket         dropPattern = "drop 1st packet"
		dropPatternDropFirst3Packets     dropPattern = "drop first 3 packets"
		dropPatternDropOneThirdOfPackets dropPattern = "drop 1/3 of packets"
	)

	type testConfig struct {
		postQuantum   bool
		longCertChain bool
		doRetry       bool
	}

	for _, dir := range []direction{directionToClient, directionToServer, directionBoth} {
		for _, pattern := range []dropPattern{
			dropPatternDrop1stPacket,
			dropPatternDropFirst3Packets,
			dropPatternDropOneThirdOfPackets,
		} {
			t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) {
				for _, conf := range []testConfig{
					{postQuantum: false, longCertChain: false, doRetry: true},
					{postQuantum: false, longCertChain: false, doRetry: false},
					{postQuantum: false, longCertChain: true, doRetry: false},
					{postQuantum: true, longCertChain: false, doRetry: false},
					{postQuantum: true, longCertChain: true, doRetry: false},
				} {
					for _, test := range []struct {
						name string
						fn   func(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn
					}{
						{"client speaks first", dropTestProtocolClientSpeaksFirst},
						{"server speaks first", dropTestProtocolServerSpeaksFirst},
						{"nobody speaks", dropTestProtocolNobodySpeaks},
					} {
						// Encode every config field in the subtest name. Using only
						// doRetry gives four configs the same name, which the testing
						// package only distinguishes by opaque "#01"-style suffixes.
						subtestName := fmt.Sprintf(
							"retry: %t, post-quantum: %t, long cert chain: %t/%s",
							conf.doRetry, conf.postQuantum, conf.longCertChain, test.name,
						)
						t.Run(subtestName, func(t *testing.T) {
							synctest.Test(t, func(t *testing.T) {
								clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
								serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
								var fn func(direction, simnet.Packet) bool
								switch pattern {
								case dropPatternDrop1stPacket:
									fn = dropCallbackDropNthPacket(dir, 1)
								case dropPatternDropFirst3Packets:
									fn = dropCallbackDropNthPacket(dir, 1, 2, 3)
								case dropPatternDropOneThirdOfPackets:
									fn = dropCallbackDropOneThird(dir)
								}
								var numDropped atomic.Int32
								n := &simnet.Simnet{
									Router: &directionAwareDroppingRouter{
										ClientAddr: clientAddr,
										ServerAddr: serverAddr,
										Drop: func(d direction, p simnet.Packet) bool {
											drop := fn(d, p)
											if drop {
												numDropped.Add(1)
											}
											return drop
										},
									},
								}
								settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2}
								clientConn := n.NewEndpoint(clientAddr, settings)
								defer clientConn.Close()
								serverConn := n.NewEndpoint(serverAddr, settings)
								defer serverConn.Close()
								require.NoError(t, n.Start())
								defer n.Close()

								var tlsConf *tls.Config
								if conf.longCertChain {
									tlsConf = getTLSConfigWithLongCertChain()
								} else {
									tlsConf = getTLSConfig()
								}
								clientConf := getTLSClientConfig()
								if !conf.postQuantum {
									clientConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
								}

								tr := &quic.Transport{
									Conn:                serverConn,
									VerifySourceAddress: func(net.Addr) bool { return conf.doRetry },
								}
								defer tr.Close()

								ln, err := tr.Listen(
									tlsConf,
									getQuicConfig(&quic.Config{
										MaxIdleTimeout:          timeout,
										HandshakeIdleTimeout:    timeout,
										DisablePathMTUDiscovery: true,
									}),
								)
								require.NoError(t, err)
								defer ln.Close()

								conn := test.fn(t, ln, clientConn, clientConf, timeout, data)
								// The negotiated curve is only checked when not running on go1.24.
								if !strings.HasPrefix(runtime.Version(), "go1.24") {
									curveID := getCurveID(conn.ConnectionState().TLS)
									if conf.postQuantum {
										require.Equal(t, tls.X25519MLKEM768, curveID)
									} else {
										require.Equal(t, tls.CurveP384, curveID)
									}
								}

								// The deterministic patterns must have dropped something.
								// The random one may, by chance, drop nothing.
								if pattern != dropPatternDropOneThirdOfPackets {
									require.NotZero(t, numDropped.Load())
								}
								t.Logf("dropped %d packets", numDropped.Load())
							})
						})
					}
				}
			})
		}
	}
}

// TestHandshakePacketBuffering drops the first Initial packet sent toward the
// client on a simulated network with a 20ms RTT, then completes a connection
// and a small unidirectional stream transfer.
//
// It asserts that:
//   - the server records no PacketBuffered events;
//   - the client records at least one PacketBuffered event;
//   - every buffered client event belongs to a datagram in which the client
//     also received a Handshake packet;
//   - the server receives the stream data intact;
//   - the server's smoothed RTT equals the simulated RTT.
//
// NOTE(review): presumably the client buffers Handshake packets that arrive
// before the (dropped, then retransmitted) Initial. Confirm against the
// buffering logic.
func TestHandshakePacketBuffering(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const rtt = 20 * time.Millisecond

		clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
		serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}

		// Drop exactly one packet: the first one toward the client that
		// contains an Initial packet.
		var initialDropped atomic.Bool
		dropFirstInitialToClient := func(d direction, p simnet.Packet) bool {
			switch {
			case initialDropped.Load():
				return false
			case d != directionToClient || !containsPacketType(p.Data, protocol.PacketTypeInitial):
				return false
			}
			initialDropped.Store(true)
			return true
		}
		network := &simnet.Simnet{
			Router: &directionAwareDroppingRouter{
				ClientAddr: clientAddr,
				ServerAddr: serverAddr,
				Drop:       dropFirstInitialToClient,
			},
		}
		link := simnet.NodeBiDiLinkSettings{Latency: rtt / 2}
		clientConn := network.NewEndpoint(clientAddr, link)
		defer clientConn.Close()
		serverConn := network.NewEndpoint(serverAddr, link)
		defer serverConn.Close()
		require.NoError(t, network.Start())
		defer network.Close()

		var serverRecorder events.Recorder
		ln, err := quic.Listen(
			serverConn,
			getTLSConfig(),
			getQuicConfig(&quic.Config{Tracer: newTracer(&serverRecorder)}),
		)
		require.NoError(t, err)
		defer ln.Close()

		var clientRecorder events.Recorder
		conn, err := quic.Dial(
			context.Background(),
			clientConn,
			ln.Addr(),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{Tracer: newTracer(&clientRecorder)}),
		)
		require.NoError(t, err)
		defer conn.CloseWithError(0, "")

		payload := []byte("foobar")
		uniStr, err := conn.OpenUniStream()
		require.NoError(t, err)
		_, err = uniStr.Write(payload)
		require.NoError(t, err)
		require.NoError(t, uniStr.Close())

		require.Empty(t, serverRecorder.Events(qlog.PacketBuffered{}))
		bufferedEvents := clientRecorder.Events(qlog.PacketBuffered{})
		t.Logf("buffered packets: %d", len(bufferedEvents))
		require.NotEmpty(t, bufferedEvents)

		// Collect, per datagram, the types of the packets the client received.
		typesByDatagram := make(map[qlog.DatagramID][]qlog.PacketType)
		for _, ev := range clientRecorder.Events(qlog.PacketReceived{}) {
			received := ev.(qlog.PacketReceived)
			typesByDatagram[received.DatagramID] = append(typesByDatagram[received.DatagramID], received.Header.PacketType)
		}
		// Each buffered event's datagram must also have delivered a Handshake packet.
		for _, ev := range bufferedEvents {
			datagram := ev.(qlog.PacketBuffered).DatagramID
			require.Contains(t, typesByDatagram, datagram)
			require.Contains(t, typesByDatagram[datagram], qlog.PacketTypeHandshake)
		}

		accepted, err := ln.Accept(context.Background())
		require.NoError(t, err)
		defer accepted.CloseWithError(0, "")
		serverStr, err := accepted.AcceptUniStream(context.Background())
		require.NoError(t, err)
		got, err := io.ReadAll(serverStr)
		require.NoError(t, err)
		require.Equal(t, payload, got)
		require.Equal(t, rtt, accepted.ConnectionStats().SmoothedRTT)
	})
}
