local socket = require("socket")
local mime = require("mime")
require("sha1")

-- Local aliases for the SHA-1 and Base64 helpers used by the handshake below.
-- NOTE(review): `sha1` is read as a global here; presumably require("sha1")
-- above defines it — confirm against the sha1 module in use.
local sha1_binary = sha1.binary
local mime_b64 = mime.b64

--- XOR each byte of `data` with a repeating 4-byte key.
-- Byte i of the input (1-based) is combined with key byte ((i-1) % 4) + 1.
-- Applying the function twice with the same key yields the original data.
-- @param data string of bytes to transform
-- @param key  string whose first four bytes are used as the key
-- @return the transformed string, same length as `data`
local function xor_cipher32(data, key)
    local chunks = {}
    local pos = 0
    while pos < #data do
        -- `pos` is still zero-based here, so it selects the key byte directly.
        local key_byte = key:byte(pos % 4 + 1)
        pos = pos + 1
        chunks[pos] = string.char(data:byte(pos) ~ key_byte)
    end
    return table.concat(chunks)
end

-- Stream wrapper around one connected client socket.
local WebSocketStream = class()

--- Store the client socket for later reads and writes.
-- Raises an assertion error when `socket` is nil.
-- @param socket the connected client socket (shadows the `socket` module here)
function WebSocketStream:init(socket)
	assert(socket ~= nil)
	-- `select_recvt` is the one-element list handed to socket.select by canread().
	self.socket, self.select_recvt = socket, { socket }
end

-- Reads exactly `n` bytes belonging to a frame whose header was already read.
-- Returns the bytes, or nil plus a description when the read fails.
local function read_frame_bytes(sock, n)
	if n == 0 then
		-- Nothing to read for an empty payload; skip the socket call entirely.
		return ""
	end
	local bytes, err = sock:receive(n)
	if not bytes then
		return nil, "incomplete WebSocket frame: " .. tostring(err)
	end
	return bytes
end

-- Reads one masked frame from `sock` and returns its unmasked payload.
-- On failure returns nil plus an error string; the string is exactly the
-- error reported by receive (e.g. "timeout") when no header could be read.
-- Raises an assertion error when the frame's mask bit is not set.
local function read_frame(sock)
	local header, err = sock:receive(2)
	if not header then
		return nil, err
	end

	local b0t15 = string.unpack(">I2", header)
	assert((b0t15 & 0x80) > 0, "WebSocket frame is not masked") -- Check mask flag is set

	-- The low 7 bits hold the length, or 126/127 to announce a 16/64-bit
	-- extended length that follows the header.
	local plLen = b0t15 & 0x7F
	if plLen == 126 or plLen == 127 then
		local size = (plLen == 126) and 2 or 8
		local ext, ext_err = read_frame_bytes(sock, size)
		if not ext then
			return nil, ext_err
		end
		plLen = string.unpack(">I" .. size, ext)
	end

	local mask, mask_err = read_frame_bytes(sock, 4)
	if not mask then
		return nil, mask_err
	end

	local masked, payload_err = read_frame_bytes(sock, plLen)
	if not masked then
		return nil, payload_err
	end

	return xor_cipher32(masked, mask)
end

--- Receive the next application message from the client.
-- Payloads starting with "print:" are printed and payloads starting with
-- "warn:" are forwarded to objc.warning; neither is returned to the caller.
-- The frame's FIN bit and opcode are not inspected, so every frame is
-- treated as a complete message.
-- @param returnOnPrint when truthy, return nil right after handling a
--        "print:"/"warn:" payload instead of waiting for the next frame
-- @return the unmasked payload string, or nil when nothing could be read
--         (read errors other than a header "timeout" are reported through
--         objc.warning)
function WebSocketStream:recv(returnOnPrint)
	while true do
		local payload, err = read_frame(self.socket)
		if not payload then
			if err ~= "timeout" then
				objc.warning(err)
			end
			return nil
		end

		if payload:sub(1, 6) == "print:" then
			print(payload:sub(7))
		elseif payload:sub(1, 5) == "warn:" then
			objc.warning(payload:sub(6))
		else
			return payload
		end

		if returnOnPrint then
			return nil
		end
	end
end

--- Poll, without waiting, whether the client socket is reported as readable.
-- Calls socket.select with a timeout of 0 on the stream's socket list.
-- @return true when select lists at least one readable socket
function WebSocketStream:canread()
	local readable = (socket.select(self.select_recvt, nil, 0))
	return #readable > 0
end

--- Send `msg` to the client as a single unmasked text frame.
-- The whole frame (header plus payload) is assembled first and written with
-- one socket call, so a failure cannot leave a header without its payload
-- queued by this function.
-- @param msg the payload to send
-- @return whatever the socket's send returns, so callers can detect failures
--         (earlier versions returned nothing; ignoring the result still works)
function WebSocketStream:send(msg)
	local plLen = string.len(msg)

	-- 0x81: FIN bit set, opcode 1 — always a text message.
	local header
	if plLen < 126 then
		header = string.pack(">BB", 0x81, plLen)
	elseif plLen < 65536 then
		-- 126 announces a 16-bit extended length.
		header = string.pack(">BBI2", 0x81, 126, plLen)
	else
		-- 127 announces a 64-bit extended length.
		header = string.pack(">BBI8", 0x81, 127, plLen)
	end

	return self.socket:send(header .. msg)
end

-- Listening server that hands out WebSocketStream clients.
local WebSocketServer = class()

--- Bind a listening socket on 127.0.0.1 at `port`.
-- Raises an error carrying the bind failure reason when the port cannot be
-- bound, instead of failing later on a nil server.
-- @param port TCP port to listen on
function WebSocketServer:init(port)
	local server, err = socket.bind("127.0.0.1", port)
	if not server then
		error(string.format("Failed to bind WebSocket server to 127.0.0.1:%s: %s",
			tostring(port), tostring(err)))
	end
	server:settimeout(3) -- 3 Second timeout
	self.server = server
end

--- Accept one client and perform the WebSocket opening handshake.
-- Reads the HTTP upgrade request, answers with "101 Switching Protocols" and
-- the Sec-WebSocket-Accept value derived from the client's key.
-- Raises an error when no client connects, when the request cannot be read,
-- when the Sec-WebSocket-Key header is missing, or when the response cannot
-- be sent; the client socket is closed before a handshake error is raised.
-- @return a WebSocketStream wrapping the accepted client socket
function WebSocketServer:getClient()
	local client = self.server:accept()
	assert(client ~= nil, "Failed to connect to Javascript Process")

	-- Closes the accepted connection before reporting a handshake failure.
	local function fail(reason)
		client:close()
		error("WebSocket handshake failed: " .. tostring(reason))
	end

	-- Ignore first line 'GET / HTTP/1.1', but make sure it actually arrived.
	local line, err = client:receive('*l')
	if not line then
		fail(err)
	end

	-- Read WebSocket headers up to the blank line that ends the request.
	local headers = {}
	while true do
		line, err = client:receive('*l')
		if not line then
			fail(err)
		end
		if line == "" then
			break
		end
		-- Header names are case-insensitive, so store them lower-cased.
		-- Lines without a colon are skipped rather than indexing with nil.
		local name, value = line:match("^(.-):%s*(.-)%s*$")
		if name then
			headers[name:lower()] = value
		end
	end

	local client_key = headers["sec-websocket-key"]
	if not client_key then
		fail("missing Sec-WebSocket-Key header")
	end

	-- Accept value: Base64 of the SHA-1 of the client key plus the fixed GUID.
	local digest = sha1_binary(client_key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
	local accept = mime_b64(digest)

	local response = table.concat({
		"HTTP/1.1 101 Switching Protocols",
		"Upgrade: websocket",
		"Connection: Upgrade",
		"Sec-WebSocket-Accept: " .. accept,
		"\r\n"
	}, "\r\n")

	local sent, send_err = client:send(response)
	if not sent then
		fail(send_err)
	end
	--client:settimeout(1)

	-- Create the client stream
	return WebSocketStream(client)
end

-- The module's value is the WebSocketServer class; WebSocketStream instances
-- are only obtained through WebSocketServer:getClient().
return WebSocketServer