--[[
Chess move-generation benchmark.

Positions are parsed from FEN strings, legal moves are enumerated with
64-bit bitboards (each stored as two 32-bit halves so only `bit32` is
required), and perft node counts are compared with published values.
]]

-- Load the benchmark harness. Under Roblox `script` exists and the harness is
-- a sibling ModuleScript; otherwise try the module name, then a relative path.
-- NOTE: pcall returns `ok, module`; used inside an `or` chain it would be
-- truncated to the boolean `ok`, so the results are unpacked explicitly.
local bench
if script then
    bench = require(script.Parent.bench_support)
else
    local ok, mod = pcall(require, "bench_support")
    if ok then
        bench = mod
    else
        bench = require("../bench_support")
    end
end

local RANKS = "12345678"
local FILES = "abcdefgh"
-- Piece indices 1..12: odd = white, even = black. (index - 1) >> 1 gives the
-- piece kind: 0 pawn, 1 rook, 2 knight, 3 bishop, 4 queen, 5 king.
local PieceSymbols = "PpRrNnBbQqKk"
local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"}
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

--
-- Lua 5.2 Compat
--

if not table.create then
    -- Return an array of n copies of v.
    function table.create(n, v)
        local result = {}
        for i = 1, n do
            result[i] = v
        end
        return result
    end
end

if not table.move then
    -- Copy a[from..to] into target starting at index `start`.
    function table.move(a, from, to, start, target)
        local dx = start - from
        for i = from, to do
            target[i + dx] = a[i]
        end
    end
end

--
-- Utils
--

-- Move encoding (one integer, read/written with bit32.extract/replace):
--   bits  0-5   destination square
--   bits  6-11  origin square
--   bit   14    castling flag
--   bits 15-18  promotion piece index (0 = none)
--   bits 20-23  moving piece index
--   bits 25-28  captured piece index (0 = none)
-- Bits 12 and 13 are not written anywhere in this file.

-- Convert an algebraic square name such as "e4" to a 0-based index (a1 = 0, h8 = 63).
local function square(s)
    return RANKS:find(s:sub(2, 2)) * 8 + FILES:find(s:sub(1, 1)) - 9
end

-- Convert a 0-based square index back to its algebraic name.
local function squareName(n)
    local file = n % 8
    local rank = (n - file) / 8
    return FILES:sub(file + 1, file + 1) .. RANKS:sub(rank + 1, rank + 1)
end

-- Human-readable move description, e.g. "P e2-e4", "Q d1xd7", "O-O".
local function moveName(v)
    local from = bit32.extract(v, 6, 6)
    local to = bit32.extract(v, 0, 6)
    local piece = bit32.extract(v, 20, 4)
    local captured = bit32.extract(v, 25, 4)
    local move = PieceSymbols:sub(piece, piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to)
    if bit32.extract(v, 14) == 1 then
        if to > from then
            return "O-O"
        else
            return "O-O-O"
        end
    end
    local promote = bit32.extract(v, 15, 4)
    if promote ~= 0 then
        move = move .. "=" .. PieceSymbols:sub(promote, promote)
    end
    return move
end

-- UCI long-algebraic move text, e.g. "e2e4" or "a7a8q".
local function ucimove(m)
    local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6))
    local promote = bit32.extract(m, 15, 4)
    if promote > 0 then
        mm = mm .. PieceSymbols:sub(promote, promote):lower()
    end
    return mm
end

local _utils = {squareName, moveName}

--
-- Bitboards
--
-- A Bitboard is an immutable set of squares: `l` holds squares 0-31
-- (ranks 1-4) and `h` holds squares 32-63 (ranks 5-8). Every operation
-- returns a new Bitboard.

local Bitboard = {}

-- Render the set as an 8x8 grid (rank 8 on top) followed by population info.
function Bitboard:toString()
    local out = {}
    local src = self.h
    for x = 7, 0, -1 do
        table.insert(out, RANKS:sub(x + 1, x + 1))
        table.insert(out, " ")
        local bit = bit32.lshift(1, (x % 4) * 8)
        for _ = 0, 7 do
            if bit32.band(src, bit) ~= 0 then
                table.insert(out, "x ")
            else
                table.insert(out, "- ")
            end
            bit = bit32.lshift(bit, 1)
        end
        -- Ranks 8..5 come from the high half; switch to the low half after rank 5.
        if x == 4 then
            src = self.l
        end
        table.insert(out, "\n")
    end
    table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
    table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h)
    return table.concat(out)
end

-- Construct a Bitboard from its low and high 32-bit halves.
function Bitboard.from(l, h)
    return setmetatable({l = l, h = h}, Bitboard)
end

Bitboard.zero = Bitboard.from(0, 0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)

local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)

local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}

-- RightMasks[n] covers the n right-most files (H, G, ...); LeftMasks[n] the
-- n left-most files (A, B, ...). Entries 2..8 are filled in by the loop that
-- follows the shift methods below.
local RightMasks = {FileH}
local LeftMasks = {FileA}

-- Number of set bits in a 32-bit integer (parallel bit-count).
local function popcnt32(i)
    i = i - bit32.band(bit32.rshift(i, 1), 0x55555555)
    i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i, 2), 0x33333333)
    return bit32.rshift(bit32.band(i + bit32.rshift(i, 4), 0x0F0F0F0F) * 0x01010101, 24)
end

-- Shift every square one rank up (toward rank 8).
function Bitboard:up()
    return self:lshift(8)
end

-- Shift every square one rank down (toward rank 1).
function Bitboard:down()
    return self:rshift(8)
end

-- Shift one file toward H, dropping squares already on file H.
function Bitboard:right()
    return self:band(FileH:inverse()):lshift(1)
end

-- Shift one file toward A, dropping squares already on file A.
function Bitboard:left()
    return self:band(FileA:inverse()):rshift(1)
end

-- Translate by (x, y): x < 0 moves toward file H, x > 0 toward file A,
-- y > 0 moves up ranks, y < 0 down. Squares pushed past a side edge are
-- masked off instead of wrapping to the next rank.
function Bitboard:move(x, y)
    local out = self
    if x < 0 then
        out = out:bandnot(RightMasks[-x]):lshift(-x)
    end
    if x > 0 then
        out = out:bandnot(LeftMasks[x]):rshift(x)
    end
    if y < 0 then
        out = out:rshift(-8 * y)
    end
    if y > 0 then
        out = out:lshift(8 * y)
    end
    return out
end

-- Number of squares in the set.
function Bitboard:popcnt()
    return popcnt32(self.l) + popcnt32(self.h)
end

function Bitboard:band(other)
    return Bitboard.from(bit32.band(self.l, other.l), bit32.band(self.h, other.h))
end

-- self AND NOT other.
function Bitboard:bandnot(other)
    return Bitboard.from(bit32.band(self.l, bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h)))
end

-- True when self and other share no square.
function Bitboard:bandempty(other)
    return bit32.band(self.l, other.l) == 0 and bit32.band(self.h, other.h) == 0
end

function Bitboard:bor(other)
    return Bitboard.from(bit32.bor(self.l, other.l), bit32.bor(self.h, other.h))
end

function Bitboard:bxor(other)
    return Bitboard.from(bit32.bxor(self.l, other.l), bit32.bxor(self.h, other.h))
end

-- Complement of the set.
function Bitboard:inverse()
    return Bitboard.from(bit32.bxor(self.l, 0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF))
end

function Bitboard:empty()
    return self.h == 0 and self.l == 0
end

-- ctz() returns the lowest set square (64 if empty); ctzafter(start) returns
-- the lowest set square strictly greater than `start` (64 if none).
if not bit32.countrz then
    local function ctz(v)
        if v == 0 then
            return 32
        end
        local offset = 0
        while bit32.extract(v, offset) == 0 do
            offset = offset + 1
        end
        return offset
    end

    function Bitboard:ctz()
        local result = ctz(self.l)
        if result == 32 then
            return ctz(self.h) + 32
        else
            return result
        end
    end

    function Bitboard:ctzafter(start)
        start = start + 1
        if start < 32 then
            for i = start, 31 do
                if bit32.extract(self.l, i) == 1 then
                    return i
                end
            end
        end
        for i = math.max(32, start), 63 do
            if bit32.extract(self.h, i - 32) == 1 then
                return i
            end
        end
        return 64
    end
else
    function Bitboard:ctz()
        local result = bit32.countrz(self.l)
        if result == 32 then
            return bit32.countrz(self.h) + 32
        else
            return result
        end
    end

    function Bitboard:ctzafter(start)
        local masked = self:band(Bitboard.full:lshift(start + 1))
        return masked:ctz()
    end
end

-- 64-bit left shift (toward higher squares).
function Bitboard:lshift(amt)
    assert(amt >= 0)
    if amt == 0 then
        return self
    end
    if amt > 31 then
        return Bitboard.from(0, bit32.lshift(self.l, amt - 32))
    end
    local l = bit32.lshift(self.l, amt)
    local h = bit32.bor(
        bit32.lshift(self.h, amt),
        bit32.extract(self.l, 32 - amt, amt)
    )
    return Bitboard.from(l, h)
end

-- 64-bit right shift (toward lower squares).
function Bitboard:rshift(amt)
    assert(amt >= 0)
    if amt == 0 then
        return self
    end
    -- Mirror lshift: a shift of 32 or more only moves bits from h into l.
    -- Without this branch bit32.extract below would be asked for a width > 32.
    if amt > 31 then
        return Bitboard.from(bit32.rshift(self.h, amt - 32), 0)
    end
    local h = bit32.rshift(self.h, amt)
    local l = bit32.bor(
        bit32.rshift(self.l, amt),
        bit32.lshift(bit32.extract(self.h, 0, amt), 32 - amt)
    )
    return Bitboard.from(l, h)
end

-- Bit (0 or 1) at square i.
function Bitboard:index(i)
    if i > 31 then
        return bit32.extract(self.h, i - 32)
    else
        return bit32.extract(self.l, i)
    end
end

-- Copy with square i set to v (0 or 1).
function Bitboard:set(i, v)
    if i > 31 then
        return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32))
    else
        return Bitboard.from(bit32.replace(self.l, v, i), self.h)
    end
end

-- Keep only square i.
function Bitboard:isolate(i)
    return self:band(Bitboard.some(i))
end

-- Bitboard containing just square idx.
function Bitboard.some(idx)
    return Bitboard.zero:set(idx, 1)
end

Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString

for i = 2, 8 do
    RightMasks[i] = RightMasks[i - 1]:rshift(1):bor(FileH)
    LeftMasks[i] = LeftMasks[i - 1]:lshift(1):bor(FileA)
end

--
-- Board
--
-- A Board is an array of 12 piece bitboards (indexed as in PieceSymbols)
-- plus cached unions (white, black, ocupied, unocupied), the en-passant
-- square (ep), castling rook squares (castle), side to move (1 = white,
-- 2 = black), halfmove clock (hm), fullmove number (moves) and a material
-- balance in centipawns (material).

local Board = {}

-- Empty board, white to move.
function Board.new()
    local boards = table.create(12, Bitboard.zero)
    boards.ocupied = Bitboard.zero
    boards.white = Bitboard.zero
    boards.black = Bitboard.zero
    boards.unocupied = Bitboard.full
    boards.ep = Bitboard.zero
    boards.castle = Bitboard.zero
    boards.toMove = 1
    boards.hm = 0
    boards.moves = 0
    boards.material = 0
    return setmetatable(boards, Board)
end

-- Parse a FEN string. Raises an error if the fields after the piece
-- placement do not match the expected layout.
function Board.fromFen(fen)
    local b = Board.new()
    local i = 0
    local rank = 7
    local file = 0
    -- Piece placement: stop at the first character that is not a rank
    -- separator, an empty-square count or a piece letter.
    while true do
        i = i + 1
        local p = fen:sub(i, i)
        if p == "" then
            -- End of string. PieceSymbols:find("") would return 1 and loop forever.
            break
        elseif p == '/' then
            rank = rank - 1
            file = 0
        elseif tonumber(p) ~= nil then
            file = file + tonumber(p)
        else
            -- Plain find: p must not be interpreted as a Lua pattern.
            local pidx = PieceSymbols:find(p, 1, true)
            if pidx == nil then
                break
            end
            b[pidx] = b[pidx]:set(rank * 8 + file, 1)
            file = file + 1
        end
    end

    local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
    if move == nil then
        error("invalid FEN fields: " .. fen:sub(i), 2)
    end

    b.toMove = move == 'w' and 1 or 2
    if ep ~= "-" then
        b.ep = Bitboard.some(square(ep))
    end
    -- Castling rights are stored as the home squares of the rooks involved.
    if castle ~= "-" then
        local oo = Bitboard.zero
        if castle:find("K") then oo = oo:set(7, 1) end
        if castle:find("Q") then oo = oo:set(0, 1) end
        if castle:find("k") then oo = oo:set(63, 1) end
        if castle:find("q") then oo = oo:set(56, 1) end
        b.castle = oo
    end
    b.hm = hm
    b.moves = m
    b:updateCache()
    return b
end

-- Piece index (1..12) occupying square idx, or 0 if empty.
function Board:index(idx)
    if self.white:index(idx) == 1 then
        for p = 1, 12, 2 do
            if self[p]:index(idx) == 1 then
                return p
            end
        end
    else
        for p = 2, 12, 2 do
            if self[p]:index(idx) == 1 then
                return p
            end
        end
    end
    return 0
end

-- Recompute the colour/occupancy unions and material balance from the piece boards.
function Board:updateCache()
    -- Rebuild from scratch so repeated calls never keep stale squares.
    local white, black = Bitboard.zero, Bitboard.zero
    for i = 1, 11, 2 do
        white = white:bor(self[i])
        black = black:bor(self[i + 1])
    end
    self.white = white
    self.black = black

    self.ocupied = self.black:bor(self.white)
    self.unocupied = self.ocupied:inverse()
    self.material =
        100 * self[1]:popcnt() - 100 * self[2]:popcnt() +
        500 * self[3]:popcnt() - 500 * self[4]:popcnt() +
        300 * self[5]:popcnt() - 300 * self[6]:popcnt() +
        300 * self[7]:popcnt() - 300 * self[8]:popcnt() +
        900 * self[9]:popcnt() - 900 * self[10]:popcnt()
end

-- Serialize the position back to a FEN string.
function Board:fen()
    local out = {}
    local s = 0
    local idx = 56
    for i = 0, 63 do
        if i % 8 == 0 and i > 0 then
            idx = idx - 16
            if s > 0 then
                table.insert(out, '' .. s)
                s = 0
            end
            table.insert(out, '/')
        end
        local p = self:index(idx)
        if p == 0 then
            s = s + 1
        else
            if s > 0 then
                table.insert(out, '' .. s)
                s = 0
            end
            table.insert(out, PieceSymbols:sub(p, p))
        end
        idx = idx + 1
    end
    if s > 0 then
        table.insert(out, '' .. s)
    end

    table.insert(out, self.toMove == 1 and ' w ' or ' b ')
    if self.castle:empty() then
        table.insert(out, '-')
    else
        if self.castle:index(7) == 1 then table.insert(out, 'K') end
        if self.castle:index(0) == 1 then table.insert(out, 'Q') end
        if self.castle:index(63) == 1 then table.insert(out, 'k') end
        if self.castle:index(56) == 1 then table.insert(out, 'q') end
    end

    table.insert(out, ' ')
    if self.ep:empty() then
        table.insert(out, '-')
    else
        table.insert(out, squareName(self.ep:ctz()))
    end

    table.insert(out, ' ' .. self.hm)
    table.insert(out, ' ' .. self.moves)

    return table.concat(out)
end

-- Destination squares for the piece on idx.
function Board:pmoves(idx)
    return self:generate(idx)
end

-- Destination squares for the piece on idx that are occupied.
function Board:pcaptures(idx)
    return self:generate(idx):band(self.ocupied)
end

local ROOK_SLIDES = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}}
local BISHOP_SLIDES = {{1, 1}, {-1, 1}, {1, -1}, {-1, -1}}
local QUEEN_SLIDES = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}, {1, 1}, {-1, 1}, {1, -1}, {-1, -1}}
local KNIGHT_MOVES = {{2, 1}, {2, -1}, {-2, 1}, {-2, -1}, {1, 2}, {1, -2}, {-1, 2}, {-1, -2}}

-- Pseudo-legal destination bitboard for the piece on square idx (empty if
-- the square is empty). Does not check whether the mover's king is left in check.
function Board:generate(idx)
    local piece = self:index(idx)
    if piece == 0 then
        return Bitboard.zero
    end

    local r = Bitboard.some(idx)
    local out = Bitboard.zero
    local kind = bit32.rshift(piece - 1, 1)
    local cancapture = piece % 2 == 1 and self.black or self.white

    if kind == 0 then
        -- Pawn: white (piece 1) moves up, black (piece 2) moves down.
        local d = -(piece * 2 - 3)
        local movetwo = piece == 1 and Rank3 or Rank6
        out = out:bor(r:move(0, d):band(self.unocupied))
        -- A single push that lands on rank 3/6 may continue to a double push.
        out = out:bor(out:band(movetwo):move(0, d):band(self.unocupied))
        local captures = r:move(0, d)
        captures = captures:right():bor(captures:left())
        if not captures:bandempty(self.ep) then
            out = out:bor(self.ep)
        end
        captures = captures:band(cancapture)
        out = out:bor(captures)
        return out
    elseif kind == 5 then
        -- King
        for x = -1, 1, 1 do
            for y = -1, 1, 1 do
                local w = r:move(x, y)
                if self.ocupied:bandempty(w) then
                    out = out:bor(w)
                elseif not cancapture:bandempty(w) then
                    out = out:bor(w)
                end
            end
        end
    elseif kind == 2 then
        -- Knight
        for _, j in ipairs(KNIGHT_MOVES) do
            local w = r:move(j[1], j[2])
            if self.ocupied:bandempty(w) then
                out = out:bor(w)
            elseif not cancapture:bandempty(w) then
                out = out:bor(w)
            end
        end
    else
        -- Sliders (Rook, Bishop, Queen): walk each ray until the edge or a blocker.
        local slides
        if kind == 1 then
            slides = ROOK_SLIDES
        elseif kind == 3 then
            slides = BISHOP_SLIDES
        else
            slides = QUEEN_SLIDES
        end
        for _, op in ipairs(slides) do
            local w = r
            for _ = 1, 7 do
                w = w:move(op[1], op[2])
                if w:empty() then
                    break
                end
                if self.ocupied:bandempty(w) then
                    out = out:bor(w)
                else
                    if not cancapture:bandempty(w) then
                        out = out:bor(w)
                    end
                    break
                end
            end
        end
    end
    return out
end

-- Render the board with Unicode pieces. Squares in the optional `mark`
-- bitboard are bracketed as "(x)".
function Board:toString(mark)
    local out = {}
    for x = 8, 1, -1 do
        table.insert(out, RANKS:sub(x, x) .. " ")
        for y = 1, 8 do
            local n = 8 * x + y - 9
            local i = self:index(n)
            if i == 0 then
                table.insert(out, '-')
            else
                -- PieceSymbols:sub(i, i) would give the ASCII letter instead.
                table.insert(out, UnicodePieces[i])
            end
            if mark ~= nil and mark:index(n) ~= 0 then
                table.insert(out, ')')
            elseif mark ~= nil and n < 63 and y < 8 and mark:index(n + 1) ~= 0 then
                table.insert(out, '(')
            else
                table.insert(out, ' ')
            end
        end
        table.insert(out, "\n")
    end
    table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
    table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material / 100) .. "\n")
    return table.concat(out)
end

-- All legal moves (encoded integers) for the side to move.
function Board:moveList()
    local tm = self.toMove == 1 and self.white or self.black
    local castle_rank = self.toMove == 1 and Rank1 or Rank8
    local out = {}

    -- Keep a pseudo-legal move only if it does not leave the mover in check.
    local function emit(id)
        if not self:applyMove(id):illegalyChecked() then
            table.insert(out, id)
        end
    end

    -- Castling: cr = this side's rooks that still hold a castling right.
    local cr = tm:band(self.castle):band(castle_rank)
    if not cr:empty() then
        local p = self.toMove == 1 and 11 or 12
        local tcolor = self.toMove == 1 and self.black or self.white
        local kidx = self[p]:ctz()
        local castle = bit32.replace(0, p, 20, 4)
        castle = bit32.replace(castle, kidx, 6, 6)
        castle = bit32.replace(castle, 1, 14)

        -- Queen side: files b-d empty; king square, c and d not threatened.
        local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
        local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
        if not cr:bandempty(FileA) and mustbeemptyl:bandempty(self.ocupied) and not self:isSquareThreatened(cantbethreatened, tcolor) then
            emit(bit32.replace(castle, kidx - 2, 0, 6))
        end

        -- King side: files f-g empty and, with the king square, not threatened.
        local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
        if not cr:bandempty(FileH) and mustbeemptyr:bandempty(self.ocupied) and not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor) then
            emit(bit32.replace(castle, kidx + 2, 0, 6))
        end
    end

    -- Ordinary moves; `while` (not repeat) so an empty side never reads square 64.
    local sq = tm:ctz()
    while sq < 64 do
        local p = self:index(sq)
        local moves = self:pmoves(sq)
        while not moves:empty() do
            local m = moves:ctz()
            moves = moves:set(m, 0)
            local id = bit32.replace(m, sq, 6, 6)
            id = bit32.replace(id, p, 20, 4)
            local mbb = Bitboard.some(m)
            if not self.ocupied:bandempty(mbb) then
                id = bit32.replace(id, self:index(m), 25, 4)
            end

            -- A pawn reaching the last rank yields one move per promotion
            -- piece (rook, knight, bishop, queen of the mover's colour).
            if p == 1 and m >= 8 * 7 then
                for i = 3, 9, 2 do
                    emit(bit32.replace(id, i, 15, 4))
                end
            elseif p == 2 and m < 8 then
                for i = 4, 10, 2 do
                    emit(bit32.replace(id, i, 15, 4))
                end
            else
                emit(id)
            end
        end
        sq = tm:ctzafter(sq)
    end

    return out
end

-- True if the side that just moved (not self.toMove) left its own king attacked.
function Board:illegalyChecked()
    local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")]
    return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black)
end

-- True if any piece of `color` has a generated move landing on a square in `target`.
-- NOTE(review): pawn moves come from Board:generate. That includes forward pushes,
-- and it includes diagonals only when they hold an enemy piece (or the ep square).
-- This matches pawn attacks for an occupied target such as a king. For the empty
-- squares tested during castling, however, a pawn push counts as a threat and a
-- pawn attack on an empty square does not. Confirm whether this is intended.
function Board:isSquareThreatened(target, color)
    local sq = color:ctz()
    while sq < 64 do
        local moves = self:pmoves(sq)
        if not moves:bandempty(target) then
            return true
        end
        sq = color:ctzafter(sq)
    end
    return false
end

-- Count leaf nodes of the legal move tree to the given depth.
function Board:perft(depth)
    if depth == 0 then
        return 1
    end
    if depth == 1 then
        return #self:moveList()
    end
    local result = 0
    for _, m in ipairs(self:moveList()) do
        -- Positions with no legal replies contribute 0 leaf nodes.
        result = result + self:applyMove(m):perft(depth - 1)
    end
    return result
end

-- Return a new Board with the encoded move played; self is left unchanged.
function Board:applyMove(move)
    local out = Board.new()
    table.move(self, 1, 12, 1, out)
    local from = bit32.extract(move, 6, 6)
    local to = bit32.extract(move, 0, 6)
    local promote = bit32.extract(move, 15, 4)
    local piece = self:index(from)
    local captured = self:index(to)
    local tom = Bitboard.some(to)
    local isCastle = bit32.extract(move, 14)

    -- Fullmove number increases after Black's move and is otherwise carried over.
    if piece % 2 == 0 then
        out.moves = self.moves + 1
    else
        out.moves = self.moves
    end

    -- Halfmove clock resets on any pawn move or capture.
    if captured ~= 0 or piece < 3 then
        out.hm = 0
    else
        out.hm = self.hm + 1
    end

    out.castle = self.castle
    out.toMove = self.toMove == 1 and 2 or 1

    if isCastle == 1 then
        local rank = piece == 11 and Rank1 or Rank8
        local colorOffset = piece - 11
        -- Only the corner rook on the king's own rank moves, not every rook on that file.
        local rookFrom = (from < to and FileH or FileA):band(rank)
        local rookTo = (from < to and FileF or FileD):band(rank)
        out[3 + colorOffset] = out[3 + colorOffset]:bandnot(rookFrom):bor(rookTo)
        out[piece] = (from < to and FileG or FileC):band(rank)
        out.castle = out.castle:bandnot(rank)
        out:updateCache()
        return out
    end

    if piece < 3 then
        local dist = math.abs(to - from)
        -- Pawn moved two squares: the skipped square becomes the ep square.
        if dist == 16 then
            out.ep = Bitboard.some(bit32.rshift(from + to, 1))
        end
        -- En-passant capture: remove the pawn behind the ep square.
        if not tom:bandempty(self.ep) then
            if piece == 1 then
                out[2] = out[2]:bandnot(self.ep:down())
            end
            if piece == 2 then
                out[1] = out[1]:bandnot(self.ep:up())
            end
        end
    end

    -- A rook leaving its home corner, or being captured there, loses that right.
    -- castle holds only corner squares, so clearing any other square is a no-op.
    if piece == 3 or piece == 4 then
        out.castle = out.castle:set(from, 0)
    end
    if captured == 3 or captured == 4 then
        out.castle = out.castle:set(to, 0)
    end
    -- Any king move forfeits both rights for that colour.
    if piece > 10 then
        local rank = piece == 11 and Rank1 or Rank8
        out.castle = out.castle:bandnot(rank)
    end

    out[piece] = out[piece]:set(from, 0)
    if promote == 0 then
        out[piece] = out[piece]:set(to, 1)
    else
        out[promote] = out[promote]:set(to, 1)
    end
    if captured ~= 0 then
        out[captured] = out[captured]:set(to, 0)
    end

    out:updateCache()
    return out
end

Board.__index = Board
Board.__tostring = Board.toString

--
-- Main
--

local failures = 0

-- Parse `fen`, check that it round-trips, and compare perft(ply) with `target`.
-- On a mismatch, print each root move's subtree count to help locate the error.
local function test(fen, ply, target)
    local b = Board.fromFen(fen)
    if b:fen() ~= fen then
        print("FEN MISMATCH", fen, b:fen())
        failures = failures + 1
        return
    end

    local found = b:perft(ply)
    if found ~= target then
        print(fen, "Found", found, "target", target)
        failures = failures + 1
        for _, v in ipairs(b:moveList()) do
            print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply - 1) or '1'))
        end
        --error("Test Failure")
    else
        print("OK", found, fen)
    end
end

-- From https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper

local testCases = {}
local function addTest(...)
    table.insert(testCases, {...})
end

addTest(StartingFen, 2, 400)
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48)
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191)
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264)
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44)
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46)

-- Benchmark body: run every registered perft test.
local function chess()
    for _, v in ipairs(testCases) do
        test(v[1], v[2], v[3])
    end
end

bench.runCode(chess, "chess")