1 package main 2 3 import ( 4 "bufio" 5 "crypto/sha256" 6 "encoding/json" 7 "errors" 8 "flag" 9 "fmt" 10 "net" 11 "net/http" 12 "os" 13 "os/signal" 14 "strconv" 15 "sync" 16 17 "golang.org/x/net/websocket" 18 ) 19 20 type hash [sha256.Size]byte 21 22 func (h *hash) MarshalJSON() ([]byte, error) { 23 r := make([]byte, (sha256.Size<<1)+2) 24 r[0] = '"' 25 r[(sha256.Size<<1)+1] = '"' 26 for n, b := range *h { 27 if t := b >> 4; t > 9 { 28 r[(n<<1)+1] = 'A' - 10 + t 29 } else { 30 r[(n<<1)+1] = '0' + t 31 } 32 if t := b & 15; t > 9 { 33 r[(n<<1)+2] = 'A' - 10 + t 34 } else { 35 r[(n<<1)+2] = '0' + t 36 } 37 } 38 return r, nil 39 } 40 41 var ErrInvalidPasswordHash = errors.New("invalid password hash") 42 43 func (h *hash) UnmarshalJSON(data []byte) error { 44 if len(data) != sha256.Size<<1+2 || data[0] != '"' || data[sha256.Size<<1+1] != '"' { 45 return ErrInvalidPasswordHash 46 } 47 for n, b := range data[1 : sha256.Size<<1+1] { 48 var v byte 49 if b >= '0' && b <= '9' { 50 v = b - '0' 51 } else if b >= 'A' && b <= 'F' { 52 v = b - 'A' + 10 53 } else if b >= 'a' && b <= 'f' { 54 v = b - 'a' + 10 55 } else { 56 return ErrInvalidPasswordHash 57 } 58 if n&1 == 0 { 59 (*h)[n>>1] = v << 4 60 } else { 61 (*h)[n>>1] |= v 62 } 63 } 64 return nil 65 } 66 67 var ( 68 configFile string 69 config Config 70 ) 71 72 type Config struct { 73 Port uint16 74 Username string 75 Password hash 76 77 mu sync.RWMutex 78 Servers servers 79 } 80 81 func saveConfig() error { 82 f, err := os.Create(configFile) 83 if err != nil { 84 return fmt.Errorf("error creating new config file: %w", err) 85 } 86 if err = json.NewEncoder(f).Encode(&config); err != nil { 87 return fmt.Errorf("error writing config file: %w", err) 88 } 89 if err = f.Close(); err != nil { 90 return fmt.Errorf("error closing config file: %w", err) 91 } 92 return nil 93 } 94 95 var unauthorised = []byte(`<html> 96 <head> 97 <title>Unauthorised</title> 98 </head> 99 <body> 100 <h1>Not Authorised</h1> 101 </body> 102 </html> 103 `) 104 105 func (c *Config) ServeHTTP(w http.ResponseWriter, r *http.Request) { 106 if u, p, ok := r.BasicAuth(); ok && u == c.Username && sha256.Sum256([]byte(p)) == c.Password { 107 switch r.URL.Path { 108 case "/": 109 index.ServeHTTP(w, r) 110 case "/socket": 111 websocket.Handler(NewConn).ServeHTTP(w, r) 112 default: 113 http.NotFound(w, r) 114 } 115 return 116 } 117 w.Header().Set("WWW-Authenticate", "Basic realm=\"Enter Credentials\"") 118 w.WriteHeader(http.StatusUnauthorized) 119 w.Write(unauthorised) 120 } 121 122 func main() { 123 if err := run(); err != nil { 124 fmt.Fprintln(os.Stderr, err) 125 os.Exit(1) 126 } 127 } 128 129 func run() error { 130 var define bool 131 flag.StringVar(&configFile, "c", "", "config file") 132 flag.BoolVar(&define, "d", false, "define settings for config file") 133 flag.Parse() 134 if configFile == "" { 135 return errors.New("no config file specified") 136 } 137 if define { 138 return defineConfig() 139 } 140 f, err := os.Open(configFile) 141 if err != nil { 142 return fmt.Errorf("error while opening config file: %w", err) 143 } 144 if err := json.NewDecoder(f).Decode(&config); err != nil { 145 return fmt.Errorf("error while decoding config file: %w", err) 146 } 147 f.Close() 148 l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: int(config.Port)}) 149 if err != nil { 150 return fmt.Errorf("error opening management interface port: %w", err) 151 } 152 if config.Servers == nil { 153 config.Servers = make(servers) 154 } 155 config.Servers.Init() 156 s := http.Server{ 157 Handler: &config, 158 } 159 go func() { 160 if err := s.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { 161 fmt.Fprintln(os.Stderr, err) 162 } 163 }() 164 sc := make(chan os.Signal, 1) 165 signal.Notify(sc, os.Interrupt) 166 <-sc 167 signal.Stop(sc) 168 close(sc) 169 s.Close() 170 ShutdownRPC() 171 config.Servers.Shutdown() 172 return nil 173 } 174 175 func defineConfig() error { 176 f, err := os.Open(configFile) 177 if err == nil { 178 json.NewDecoder(f).Decode(&config) 179 f.Close() 180 } else if !os.IsNotExist(err) { 181 return fmt.Errorf("error opening config file: %w", err) 182 } 183 r := bufio.NewReader(os.Stdin) 184 var skipPort, skipCredentials bool 185 if config.Port != 0 { 186 if err := getInput(r, fmt.Sprintf("Do you want to set a new management port (%d)? Y/N: ", config.Port), func(ans string) bool { 187 switch ans { 188 case "Y", "y": 189 case "N", "n": 190 skipPort = true 191 default: 192 return false 193 } 194 return true 195 }); err != nil { 196 return err 197 } 198 } 199 if !skipPort { 200 if err := getInput(r, "Please enter a port number for the management console (1-65535): ", func(ans string) bool { 201 p, err := strconv.ParseUint(ans, 10, 16) 202 if err != nil { 203 return false 204 } 205 config.Port = uint16(p) 206 return true 207 }); err != nil { 208 return err 209 } 210 } 211 if config.Username != "" { 212 if err := getInput(r, "Do you want to set new management credentials? Y/N: ", func(ans string) bool { 213 switch ans { 214 case "Y", "y": 215 case "N", "n": 216 skipCredentials = true 217 default: 218 return false 219 } 220 return true 221 }); err != nil { 222 return err 223 } 224 } 225 if !skipCredentials { 226 if err := getInput(r, "Username: ", func(ans string) bool { 227 if ans == "" { 228 return false 229 } 230 config.Username = ans 231 return true 232 }); err != nil { 233 return err 234 } 235 if err := getInput(r, "Password: ", func(ans string) bool { 236 if ans == "" { 237 return false 238 } 239 config.Password = sha256.Sum256([]byte(ans)) 240 return true 241 }); err != nil { 242 return err 243 } 244 } 245 if !skipPort || !skipCredentials { 246 return saveConfig() 247 } 248 return nil 249 } 250 251 func getInput(r *bufio.Reader, question string, checkFn func(string) bool) error { 252 for { 253 fmt.Print(question) 254 ans, err := r.ReadString('\n') 255 if err != nil { 256 return fmt.Errorf("error returned when reading stdin: %w", err) 257 } 258 if checkFn(ans[:len(ans)-1]) { 259 return nil 260 } 261 fmt.Println("\nDid not understand response") 262 } 263 } 264