package main

import (
	"bufio"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
)

// Protocol constants for SOCKS5 (RFC 1928) and its username/password
// sub-negotiation (RFC 1929).
const (
	// socks5Version is the VER byte carried by every SOCKS5 message.
	socks5Version = 0x05
	// subnegotiation is the VER byte of the username/password sub-negotiation.
	subnegotiation = 0x01

	// Request CMD values. Only connectCommand is handled by connect.
	connectCommand      = 0x01
	bindCommand         = 0x02
	udpAssociateCommand = 0x03

	// ATYP values describing the encoding of DST.ADDR / BND.ADDR.
	ipv4AddressType   = 0x01
	domainAddressType = 0x03
	ipv6AddressType   = 0x04
)

// main listens for TCP connections on port 8080 (all interfaces) and serves
// each accepted client on its own goroutine via process. It panics if the
// listener cannot be created; a failed Accept is logged and the loop continues.
func main() {
	listener, listenErr := net.Listen("tcp", ":8080")
	if listenErr != nil {
		panic(listenErr)
	}
	for {
		conn, acceptErr := listener.Accept()
		if acceptErr == nil {
			go process(conn)
			continue
		}
		log.Printf("Error accepting client connection: %v", acceptErr)
	}
}

// process serves a single SOCKS5 client connection and closes it when done.
//
// One bufio.Reader is shared by every protocol phase: bytes it buffers while
// reading one phase must remain available to the next, so no phase may wrap
// conn in a fresh reader. Failures are logged with the client's address and
// the name of the phase that failed.
func process(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn)

	// Phase 1: method negotiation.
	if err := auth(reader, conn); err != nil {
		log.Printf("client %v auth failed:%v", conn.RemoteAddr(), err)
		return
	}

	// Username/password sub-negotiation (authenticate) is currently disabled.
	// Enabling it requires auth to select method 0x02 instead of 0x00, and
	// authenticate must be given this same reader rather than a new one.

	// Phase 2: request handling and relaying.
	if err := connect(reader, conn); err != nil {
		log.Printf("client %v connect failed:%v", conn.RemoteAddr(), err)
	}
}

// auth performs the SOCKS5 method-selection handshake.
//
// It reads the client greeting from reader and writes the selection reply to
// conn:
//
//	+----+----------+----------+      +----+--------+
//	|VER | NMETHODS | METHODS  |  ->  |VER | METHOD |
//	+----+----------+----------+      +----+--------+
//	| 1  |    1     | 1 to 255 |      | 1  |   1    |
//	+----+----------+----------+      +----+--------+
//
// Only "no authentication required" (0x00) is supported. If the client offers
// it, it is selected and nil is returned; otherwise the reply carries 0xFF
// ("no acceptable methods") and an error is returned. Read, version and write
// failures are returned as errors (read/write errors are wrapped).
func auth(reader *bufio.Reader, conn net.Conn) error {
	ver, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("Error reading ver: %w", err)
	}
	if ver != socks5Version {
		return fmt.Errorf("Unsupported socks5 version: %v", ver)
	}
	methodSize, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("Error reading method size: %w", err)
	}
	methods := make([]byte, methodSize)
	// A single Read may return fewer than methodSize bytes; leftover method
	// bytes would otherwise be misparsed as the start of the next phase.
	if _, err := io.ReadFull(reader, methods); err != nil {
		return fmt.Errorf("Error reading method: %w", err)
	}

	const noAuth, noAcceptable = 0x00, 0xFF
	selected := byte(noAcceptable)
	for _, m := range methods {
		if m == noAuth {
			selected = noAuth
			break
		}
	}

	if _, err := conn.Write([]byte{socks5Version, selected}); err != nil {
		return fmt.Errorf("Error writing method response: %w", err)
	}
	if selected == noAcceptable {
		return errors.New("no acceptable authentication method offered")
	}
	return nil
}

// authenticate performs the username/password sub-negotiation and accepts only
// the hard-coded credentials "admin" / "123456".
//
// The complete request is read from reader before the credentials are
// checked. A single status reply is then written to conn: 0x00 on success,
// 0x01 on a credential mismatch. A mismatch returns an error naming the field
// that failed; read, version and write failures are returned as errors too.
func authenticate(reader *bufio.Reader, conn net.Conn) error {
	// Request:
	//  +----+------+----------+------+----------+
	//  |VER | ULEN |  UNAME   | PLEN |  PASSWD  |
	//  +----+------+----------+------+----------+
	//  | 1  |  1   | 1 to 255 |  1   | 1 to 255 |
	//  +----+------+----------+------+----------+
	ver, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("read ver failed:%w", err)
	}
	if ver != subnegotiation {
		return fmt.Errorf("not supported ver:%v", ver)
	}
	ulen, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("read ulen failed:%w", err)
	}
	uname := make([]byte, ulen)
	if _, err := io.ReadFull(reader, uname); err != nil {
		return fmt.Errorf("read uname failed:%w", err)
	}
	plen, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("read plen failed:%w", err)
	}
	passwd := make([]byte, plen)
	if _, err := io.ReadFull(reader, passwd); err != nil {
		return fmt.Errorf("read passwd failed:%w", err)
	}

	// There is no underlying error to wrap on a mismatch, so these are plain
	// errors (wrapping a nil err with %w renders as "%!w(<nil>)").
	var authErr error
	switch {
	case string(uname) != "admin":
		authErr = errors.New("auth uname failed")
	case string(passwd) != "123456":
		authErr = errors.New("auth passwd failed")
	}

	// Response:
	//  +----+--------+
	//  |VER | STATUS |
	//  +----+--------+
	//  | 1  |   1    |
	//  +----+--------+
	status := byte(0x00)
	if authErr != nil {
		status = 0x01
	}
	if _, err := conn.Write([]byte{subnegotiation, status}); err != nil {
		return fmt.Errorf("write failed:%w", err)
	}
	return authErr
}

// connect handles the SOCKS5 request phase for a CONNECT command.
//
// It reads the request from reader, dials the destination over TCP, writes the
// reply to conn, and then relays bytes in both directions until both have
// finished. reader must be the same buffered reader used for the earlier
// phases: any bytes it already holds beyond the request (data a client sends
// right after the request) are forwarded to the destination, not dropped.
//
// For an unsupported command or address type, or a failed dial, a failure
// reply is sent to the client (best effort) before the error is returned.
// A version mismatch or a read error returns an error without a reply.
func connect(reader *bufio.Reader, conn net.Conn) error {
	// Request:
	// +----+-----+-------+------+----------+----------+
	// |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
	// +----+-----+-------+------+----------+----------+
	// | 1  |  1  | X'00' |  1   | Variable |    2     |
	// +----+-----+-------+------+----------+----------+
	// VER      protocol version, 0x05 for SOCKS5
	// CMD      0x01 = CONNECT (the only command handled here)
	// RSV      reserved, 0x00
	// ATYP     type of DST.ADDR: 0x01 IPv4 (4 bytes), 0x03 domain name
	//          (1 length byte followed by the name), 0x04 IPv6 (16 bytes)
	// DST.ADDR destination address, encoded according to ATYP
	// DST.PORT destination port, 2 bytes, big-endian
	const (
		repSucceeded               = 0x00
		repGeneralFailure          = 0x01
		repCommandNotSupported     = 0x07
		repAddressTypeNotSupported = 0x08
	)

	// io.ReadFull everywhere: a single Read on a bufio.Reader may return
	// fewer bytes than requested when the client's data arrives in pieces.
	header := make([]byte, 4)
	if _, err := io.ReadFull(reader, header); err != nil {
		return fmt.Errorf("read header failed:%w", err)
	}
	ver, cmd, atyp := header[0], header[1], header[3]
	if ver != socks5Version {
		return fmt.Errorf("Unsupported socks5 ver: %v", ver)
	}
	if cmd != connectCommand {
		_ = writeConnectReply(conn, repCommandNotSupported) // best effort; already failing
		return fmt.Errorf("not supported cmd:%v", cmd)
	}

	var addr string
	switch atyp {
	case ipv4AddressType:
		ip := make([]byte, net.IPv4len)
		if _, err := io.ReadFull(reader, ip); err != nil {
			return fmt.Errorf("read ipv4 address failed:%w", err)
		}
		addr = net.IP(ip).String()
	case ipv6AddressType:
		ip := make([]byte, net.IPv6len)
		if _, err := io.ReadFull(reader, ip); err != nil {
			return fmt.Errorf("read ipv6 address failed:%w", err)
		}
		addr = net.IP(ip).String()
	case domainAddressType:
		hostSize, err := reader.ReadByte()
		if err != nil {
			return fmt.Errorf("read hostSize failed:%w", err)
		}
		host := make([]byte, hostSize)
		if _, err := io.ReadFull(reader, host); err != nil {
			return fmt.Errorf("read host failed:%w", err)
		}
		addr = string(host)
	default:
		_ = writeConnectReply(conn, repAddressTypeNotSupported) // best effort; already failing
		return errors.New("invalid atyp")
	}

	portBuf := make([]byte, 2)
	if _, err := io.ReadFull(reader, portBuf); err != nil {
		return fmt.Errorf("read port failed:%w", err)
	}
	port := binary.BigEndian.Uint16(portBuf)

	// JoinHostPort brackets IPv6 literals ("[::1]:80"); "%s:%d" would not.
	dest, err := net.Dial("tcp", net.JoinHostPort(addr, fmt.Sprint(port)))
	if err != nil {
		_ = writeConnectReply(conn, repGeneralFailure) // best effort; already failing
		return fmt.Errorf("connect to %s:%d failed:%w", addr, port, err)
	}
	defer dest.Close()
	log.Println("dial", addr, port)

	if err := writeConnectReply(conn, repSucceeded); err != nil {
		return fmt.Errorf("write failed: %w", err)
	}

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		// Read through reader, not conn, so bytes already buffered are sent.
		relayHalf(dest, reader, conn)
	}()
	go func() {
		defer wg.Done()
		relayHalf(conn, dest, dest)
	}()
	wg.Wait()
	return nil
}

// writeConnectReply writes a SOCKS5 reply with REP code rep and a zero IPv4
// bound address and port (0.0.0.0:0):
//
//	+----+-----+-------+------+----------+----------+
//	|VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
//	+----+-----+-------+------+----------+----------+
//	| 1  |  1  | X'00' |  1   | Variable |    2     |
//	+----+-----+-------+------+----------+----------+
func writeConnectReply(conn net.Conn, rep byte) error {
	_, err := conn.Write([]byte{socks5Version, rep, 0x00, ipv4AddressType, 0, 0, 0, 0, 0, 0})
	return err
}

// relayHalf copies src into dst for one direction of the relay.
//
// On a clean EOF, if dst supports CloseWrite, only dst's write side is closed:
// the peer sees EOF while the opposite direction keeps flowing. On a copy
// error, or when dst cannot be half-closed, both dst and srcConn are closed.
// That unblocks the goroutine copying the other direction so it cannot leak.
func relayHalf(dst net.Conn, src io.Reader, srcConn net.Conn) {
	_, err := io.Copy(dst, src)
	if cw, ok := dst.(interface{ CloseWrite() error }); ok && err == nil {
		_ = cw.CloseWrite()
		return
	}
	_ = dst.Close()
	_ = srcConn.Close()
}