golang蒙特卡洛树算法实现五子棋AI

发布时间 2023-04-01 20:52:23作者: janbar

已经实现蒙特卡洛树算法的通用逻辑,只需要对应结构体实现相关接口就可以直接使用该算法。

优化算法主要优化GetActions生成下一步动作,要尽可能少,去掉无意义的动作。

以及优化ActionPolicy从众多动作挑选比较优秀的动作。对应五子棋就是执行该动作后当前局面评分最高。

package main

import (
	"fmt"
	"math"
	"math/rand"
	"strings"
	"time"
)

func main() {
	var (
		board = NewQuZiQi(15)
		x, y  int
	)

	board.Print()
	for board.IsTerminal() == 0 {
		board = Search(time.Second*10, board).(*WuZiQi)

		board.Print()
		if board.IsTerminal() == 1 {
			fmt.Println("电脑赢了")
			return
		}

		for {
			fmt.Print("轮到您执棋,请输入坐标: ")
			_, _ = fmt.Scanln(&x, &y)
			x--
			y--
			if x < 0 || y < 0 || x >= board.size || y >= board.size {
				fmt.Println("您输入的数据超出棋盘范围")
			} else if board.board[x][y] > 0 {
				fmt.Println("该位置已有棋子")
			} else {
				board.board[x][y] = 2
				board.player = 1 // 下一步该电脑下
				break
			}
		}

		board.Print()
		if board.IsTerminal() == 2 {
			fmt.Println("你赢了")
			return
		}
	}
}

// WuZiQi 五子棋游戏
type WuZiQi struct {
	size   int     // 棋盘大小
	board  [][]int // 棋盘状态
	player int     // 1: 电脑落子,2: 玩家落子
}

func NewQuZiQi(size int) *WuZiQi {
	w := &WuZiQi{
		size:   size,
		board:  make([][]int, size),
		player: 1,
	}
	for i := 0; i < size; i++ {
		w.board[i] = make([]int, size)
	}
	size /= 2
	// 默认中间落一个棋子
	// 0: 表示没有落子,1: 表示电脑,2: 表示玩家
	w.board[size][size] = 2
	return w
}

func (w *WuZiQi) Print() {
	var (
		str strings.Builder
		num = func(n int) {
			a, b := n/10, n%10
			if a > 0 {
				str.WriteByte(byte(a + '0'))
			} else {
				str.WriteByte(' ') // 1位数前面加空格
			}
			str.WriteByte(byte(b + '0'))
		}
	)
	str.WriteString("   ")
	for i := 1; i <= w.size; i++ {
		str.WriteByte(' ')
		num(i)
	}
	str.WriteByte('\n')
	for i := 0; i < w.size; i++ {
		str.WriteString("   ")
		for j := 0; j < w.size; j++ {
			str.WriteString(" __")
		}

		str.WriteByte('\n')
		num(i + 1)
		str.WriteByte(' ')

		for j := 0; j < w.size; j++ {
			str.WriteByte('|')
			switch w.board[i][j] {
			case 0:
				str.WriteByte(' ')
			case 1:
				str.WriteByte('O')
			case 2:
				str.WriteByte('X')
			}
			str.WriteByte(' ')
		}
		str.WriteString("|\n")
	}
	str.WriteString("   ")
	for i := 0; i < w.size; i++ {
		str.WriteString(" __")
	}
	fmt.Println(str.String())
}

func (w *WuZiQi) IsTerminal() int {
	full := -1 // 没有空位且都没赢
	for i := 0; i < w.size; i++ {
		for j := 0; j < w.size; j++ {
			if wc := w.board[i][j]; wc == 0 {
				full = 0 // 还有空位,没结束
			} else {
				// 向右
				cnt, x, y := 1, 0, j+1
				for ; y < w.size && w.board[i][y] == wc; y++ {
					cnt++
				}
				if cnt >= 5 {
					return wc
				}
				// 向下
				cnt, x = 1, i+1
				for ; x < w.size && w.board[x][j] == wc; x++ {
					cnt++
				}
				if cnt >= 5 {
					return wc
				}
				// 向右下
				cnt, x, y = 1, i+1, j+1
				for ; x < w.size && y < w.size && w.board[x][y] == wc; x, y = x+1, y+1 {
					cnt++
				}
				if cnt >= 5 {
					return wc
				}
				// 向左下
				cnt, x, y = 1, i+1, j-1
				for ; x < w.size && y >= 0 && w.board[x][y] == wc; x, y = x+1, y-1 {
					cnt++
				}
				if cnt >= 5 {
					return wc
				}
			}
		}
	}
	return full
}

func (w *WuZiQi) Result(state int) float64 {
	switch state {
	case -1:
		return 0 // 都没赢且没空位
	case 1:
		return -1 // 电脑赢了
	case 2:
		return +1 // 玩家赢了
	default:
		return 0 // 都没赢且有空位
	}
}

func (w *WuZiQi) GetActions() (res []any) {
	// todo 敌方上一步落子附近才是最优搜索范围
	//  某个落子必胜,则直接落子,如果某个落子让对手所有落子都必败则直接落子
	//    因此后续动作进一步缩小范围
	//  可以使用hash判断棋盘状态

	m := map[[2]int]struct{}{} // 用于去重
	for i := 0; i < w.size; i++ {
		for j := 0; j < w.size; j++ {
			if w.board[i][j] == 0 || w.board[i][j] == w.player {
				continue // 跳过空位和己方棋子
			}

			x0, x1, y0, y1 := i-2, i+2, j-2, j+2
			for ii := x0; ii < x1; ii++ {
				for jj := y0; jj < y1; jj++ {
					if ii >= 0 && jj >= 0 && ii < w.size && jj < w.size &&
						w.board[ii][jj] == 0 {

						p := [2]int{ii, jj}
						_, ok := m[p]
						if !ok {
							// 在棋子周围2格范围的空位加到结果中
							// 超过2格的空位落子的意义不大
							res = append(res, p)
							m[p] = struct{}{}
						}
					}
				}
			}
		}
	}
	return
}

func (w *WuZiQi) ActionPolicy(action []any) any {
	// 目前随机选一个动作,应该是好方案先选出来
	return action[rand.Intn(len(action))]
}

func (w *WuZiQi) Action(action any) TreeState {
	wn := &WuZiQi{
		size:   w.size,
		board:  make([][]int, w.size),
		player: 3 - w.player, // 切换电脑和玩家
	}
	for i := 0; i < w.size; i++ {
		wn.board[i] = make([]int, w.size)
		for j := 0; j < w.size; j++ {
			wn.board[i][j] = w.board[i][j]
		}
	}

	ac := action.([2]int) // 在该位置落子
	wn.board[ac[0]][ac[1]] = w.player
	return wn
}

// MonteCarloTree 下面是算法部分
// 你的对象只需要提供TreeState所有接口,就可以直接使用
// https://github.com/int8/monte-carlo-tree-search
// https://blog.csdn.net/masterhero666/article/details/126325506
type (
	TreeState interface {
		IsTerminal() int        // 0: 未结束,其他为自定义状态
		Result(int) float64     // 计算分数,传入IsTerminal结果
		GetActions() []any      // 获取所有合法动作, todo 考虑获取不到动作时如何处理
		ActionPolicy([]any) any // 按策略挑选一个动作
		Action(any) TreeState   // 执行动作生成子节点
	}

	McTreeNode struct {
		parent         *McTreeNode
		children       []*McTreeNode
		score          float64
		visitCount     float64
		untriedActions []any
		nodeState      TreeState
	}
)

func Search(simulate any, state TreeState, discount ...float64) TreeState {
	var (
		root = &McTreeNode{nodeState: state}
		leaf *McTreeNode
		dp   = 1.4 // 折扣参数默认值
	)
	if len(discount) > 0 {
		dp = discount[0]
	}

	var loop func() bool
	switch s := simulate.(type) {
	case int:
		loop = func() bool {
			s-- // 模拟指定次数后退出
			return s >= 0
		}
	case time.Duration:
		ts := time.Now().Add(s) // 超过指定时间后退出
		loop = func() bool { return time.Now().Before(ts) }
	case func() bool:
		loop = s // 或者由外部指定模拟结束方案
	default:
		panic(simulate)
	}

	for loop() {
		leaf = root.treePolicy(dp)

		result, curState := 0, leaf.nodeState
		for {
			if result = curState.IsTerminal(); result != 0 {
				break // 结束状态
			}

			// 根据该节点状态生成所有合法动作
			all := curState.GetActions()
			// 按照某种策略选出1个动作,不同于expand的顺序取出
			one := curState.ActionPolicy(all)
			// 执行该动作,重复该过程,直到结束
			curState = curState.Action(one)
		}

		// 根据结束状态计算结果,将该结果反向传播
		leaf.backPropagate(curState.Result(result))
	}
	return root.chooseBestChild(dp).nodeState // 选择最优子节点
}

func (cur *McTreeNode) chooseBestChild(c float64) *McTreeNode {
	var (
		idx        = 0
		maxValue   = -math.MaxFloat64
		childValue float64
	)
	for i, child := range cur.children {
		childValue = (child.score / child.visitCount) +
			c*math.Sqrt(math.Log(cur.visitCount)/child.visitCount)
		if childValue > maxValue {
			maxValue = childValue
			idx = i // 选择分值最高的子节点
		}
	}
	return cur.children[idx]
}

func (cur *McTreeNode) backPropagate(result float64) {
	nodeCursor := cur
	for nodeCursor.parent != nil {
		nodeCursor.score += result
		nodeCursor.visitCount++ // 反向传播,增加访问次数,更新分数
		nodeCursor = nodeCursor.parent
	}
	nodeCursor.visitCount++
}

func (cur *McTreeNode) expand() *McTreeNode {
	res := cur.untriedActions[0] // 返回1个未经尝试动作
	cur.untriedActions = cur.untriedActions[1:]

	child := &McTreeNode{
		parent:    cur, // 当前节点按顺序弹出1个动作,执行动作生成子节点
		nodeState: cur.nodeState.Action(res),
	}
	cur.children = append(cur.children, child)
	return child
}

func (cur *McTreeNode) treePolicy(discountParamC float64) *McTreeNode {
	nodeCursor := cur // 一直循环直到结束
	for nodeCursor.nodeState.IsTerminal() == 0 {
		if nodeCursor.untriedActions == nil {
			// 只会初始化1次,找出该节点所有动作
			nodeCursor.untriedActions = nodeCursor.nodeState.GetActions()
		}
		if len(nodeCursor.untriedActions) > 0 {
			return nodeCursor.expand() // 存在未处理动作则添加子节点
		}
		// 处理完动作,选择最好子节点继续往下处理
		nodeCursor = nodeCursor.chooseBestChild(discountParamC)
	}
	return nodeCursor
}