yats.git

ref: a118ec93f002f5634886df58c2ff048d83953684

server/tcp/tcpserver.go


  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
/**
 * Yats - yats
 *
 * This file is licensed under the Affero General Public License version 3 or
 * later. See the COPYING file.
 *
 * @author Paolo Lulli <kevwe.com>
 * @copyright Paolo Lulli 2024
 */

package tcp

import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
	"yats-server/config"
	"yats-server/tlv"
)

// tcpserver is a TCP server whose listener may be plain TCP or TLS.
// An accept loop reads client connections from listener and hands them over
// connection to a dispatcher loop. Both loops watch shutdown and are counted
// in wg.
type tcpserver struct {
	wg         sync.WaitGroup // counts the server's background goroutines
	listener   net.Listener   // source of incoming client connections
	shutdown   chan struct{}  // closed by Stop to tell the loops to exit
	connection chan net.Conn  // hand-off from the accept loop to the dispatcher
}

// createServer opens the listening socket described by cfg and returns a
// server that is ready to Start.
//
// When cfg.TlsActive is "true", the listener is a TLS listener. It uses the
// certificate/key pair loaded from cfg.TlsCertificate and cfg.TlsKeyFile.
// Otherwise it is a plain TCP listener. Either way it is bound to
// cfg.TcpAddress.
//
// Failures are returned to the caller instead of terminating the process.
// That covers both loading the key pair and binding the address.
func createServer(cfg config.Configuration) (*tcpserver, error) {
	var (
		listener net.Listener
		err      error
	)

	if cfg.TlsActive == "true" {
		// Assign to the outer err with '=' instead of declaring with ':='.
		// Otherwise a tls.Listen failure would land in a shadowed variable
		// and never reach the check below.
		var cert tls.Certificate
		cert, err = tls.LoadX509KeyPair(cfg.TlsCertificate, cfg.TlsKeyFile)
		if err != nil {
			return nil, fmt.Errorf("loading TLS key pair (cert %q, key %q): %w", cfg.TlsCertificate, cfg.TlsKeyFile, err)
		}
		tlsconfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: false}
		listener, err = tls.Listen("tcp", cfg.TcpAddress, &tlsconfig)
	} else {
		listener, err = net.Listen("tcp", cfg.TcpAddress)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to listen on address %s: %w", cfg.TcpAddress, err)
	}

	return &tcpserver{
		listener:   listener,
		shutdown:   make(chan struct{}),
		connection: make(chan net.Conn),
	}, nil
}

// acceptConnections accepts client connections on s.listener and passes each
// one to handleConnections over s.connection. It returns once s.shutdown is
// closed.
//
// Start registers it in s.wg, so it calls s.wg.Done on exit.
func (s *tcpserver) acceptConnections() {
	defer s.wg.Done()

	// Bounds for the back-off applied after consecutive Accept failures.
	// This keeps a persistent error (for example file-descriptor exhaustion)
	// from turning the loop into a busy spin.
	const (
		minRetryDelay = 5 * time.Millisecond
		maxRetryDelay = time.Second
	)
	var retryDelay time.Duration

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			// Stop closes s.shutdown before closing the listener. An Accept
			// error caused by shutdown is therefore caught here and is not
			// logged as a failure.
			select {
			case <-s.shutdown:
				return
			default:
			}

			if retryDelay == 0 {
				retryDelay = minRetryDelay
			} else {
				retryDelay *= 2
			}
			if retryDelay > maxRetryDelay {
				retryDelay = maxRetryDelay
			}
			log.Printf("tcpserver: accept: %v; retrying in %v", err, retryDelay)

			select {
			case <-s.shutdown:
				return
			case <-time.After(retryDelay):
			}
			continue
		}
		retryDelay = 0

		// handleConnections may already have returned because of shutdown.
		// Never block forever on the hand-off, or s.wg.Wait in Stop could
		// not complete. Close the connection nobody will serve.
		select {
		case s.connection <- conn:
		case <-s.shutdown:
			conn.Close()
			return
		}
	}
}

// handleConnections receives accepted connections from s.connection and
// serves each one in its own goroutine until s.shutdown is closed.
//
// Each per-connection goroutine is registered in s.wg. Stop waits on s.wg,
// so it actually waits (up to its timeout) for in-flight connections to
// finish.
func (s *tcpserver) handleConnections() {
	defer s.wg.Done()

	for {
		select {
		case <-s.shutdown:
			return
		case conn := <-s.connection:
			// This goroutine is itself still counted in s.wg, so the counter
			// is non-zero here and this Add cannot race with a Wait that
			// observes zero.
			s.wg.Add(1)
			go func() {
				defer s.wg.Done()
				s.handleConnection(conn)
			}()
		}
	}
}

// handleConnection serves a single client connection and closes it on return.
//
// For a TLS connection it first completes the handshake. If the client
// presented a certificate, it prints that certificate's subject CommonName.
// It then does a single Read of up to 1024 bytes and prints the bytes
// received. Finally it tries to decode them with tlv.DecodeTlv and prints the
// result or the decode error.
func (s *tcpserver) handleConnection(conn net.Conn) {
	defer conn.Close()

	if tlscon, ok := conn.(*tls.Conn); ok {
		// crypto/tls runs the handshake lazily on the first Read/Write.
		// ConnectionState is empty until the handshake is forced here.
		if err := tlscon.Handshake(); err != nil {
			fmt.Println(err)
			return
		}
		// PeerCertificates is empty unless the client actually sent a
		// certificate, so guard the index instead of panicking.
		state := tlscon.ConnectionState()
		if len(state.PeerCertificates) > 0 {
			log.Println("Server: client public key is:")
			clientCN := state.PeerCertificates[0].Subject.CommonName
			fmt.Printf("clientCN: %s\n", clientCN)
		}
	}

	// Read incoming data. Only the first n bytes are meaningful; the rest of
	// buf is zero padding. Per io.Reader, process any bytes read before
	// acting on the error.
	buf := make([]byte, 1024)
	n, err := conn.Read(buf)
	if n == 0 {
		if err != nil {
			fmt.Println(err)
		}
		return
	}
	data := buf[:n]

	fmt.Printf("Received: %q\n", data)
	decoded, value, err := tlv.DecodeTlv(data)
	if err != nil {
		fmt.Printf("tlv decode failed: %v\n", err)
		return
	}
	fmt.Printf("Decoded: tlv: %s with value: %x\n", decoded, value)
}

// Start launches the accept loop and the dispatch loop in background
// goroutines and returns immediately. Both are registered in s.wg before they
// start, so Stop's wait covers them.
func (s *tcpserver) Start() {
	loops := [...]func(){s.acceptConnections, s.handleConnections}
	s.wg.Add(len(loops))
	for _, loop := range loops {
		go loop()
	}
}

// Stop signals the server's goroutines to exit and closes the listener.
// It then waits up to one second for everything counted in s.wg to finish.
// On timeout it prints a message and returns anyway.
//
// Stop must be called at most once, because a second call would close
// s.shutdown again and panic.
func (s *tcpserver) Stop() {
	close(s.shutdown)
	s.listener.Close()

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		s.wg.Wait()
	}()

	deadline := time.NewTimer(time.Second)
	defer deadline.Stop()

	select {
	case <-finished:
	case <-deadline.C:
		fmt.Println("Timed out waiting for connections to finish.")
	}
}

// RunTcpServer creates a server from cfg, starts it, and blocks until the
// process receives SIGINT or SIGTERM. At that point it stops the server and
// returns.
//
// If the server cannot be created, the error is printed to stderr and the
// process exits with status 1.
func RunTcpServer(cfg config.Configuration) {
	s, err := createServer(cfg)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// Subscribe before starting the server. A signal that arrives while the
	// server is coming up then still takes the graceful shutdown path instead
	// of killing the process outright.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	s.Start()
	<-sigChan

	fmt.Println("Shutting down tcpserver...")
	s.Stop()
	fmt.Println("tcpserver stopped.")
}