使用 Go 为启明防火墙添加黑名单

发布时间: 2023-08-18 14:16:14  作者: yangras
package main

import (
	"bufio"
	"bytes"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"
)

// cookie is the session cookie taken from the first cookie of the most recent
// sendPostRequest response; sendBanRequest attaches it to its requests.
var cookie *http.Cookie

// sendPostRequest POSTs data, encoded as a JSON object, to url and decodes the
// JSON object returned in the response body.
//
// Side effect: if the response sets at least one cookie, the first one is
// stored in the package-level cookie variable so later requests can reuse the
// session. A response without cookies leaves cookie unchanged (previously this
// indexed an empty slice and panicked).
//
// Certificate verification is disabled (InsecureSkipVerify). Presumably the
// appliance uses a self-signed certificate, so confirm this. It also removes
// man-in-the-middle protection, so only use this on a trusted network.
//
// Errors from marshaling, request creation, transport and JSON decoding are
// wrapped with %w so callers can inspect the underlying cause.
func sendPostRequest(url string, data map[string]string) (map[string]interface{}, error) {
	jsonData, err := json.Marshal(data)
	if err != nil {
		return nil, fmt.Errorf("JSON marshaling error: %w", err)
	}

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		},
		// Without a timeout an unresponsive firewall would hang the tool forever.
		Timeout: 30 * time.Second,
	}
	// A fresh Transport is created per call. Release its pooled connections on
	// return, otherwise every call leaves an idle keep-alive connection behind.
	defer client.CloseIdleConnections()

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("Request creation error: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("Request error: %w", err)
	}
	defer resp.Body.Close()

	// Save the session cookie (if the server sent one) for subsequent calls.
	if cookies := resp.Cookies(); len(cookies) > 0 {
		cookie = cookies[0]
	}

	var response map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return nil, fmt.Errorf("JSON decoding error (HTTP status %d): %w", resp.StatusCode, err)
	}
	return response, nil
}

// Login2Fw logs into the firewall login endpoint at url with the given
// credentials and returns the "authorization" token from the JSON reply.
//
// The reply must contain a boolean "result" field and, when result is true, a
// string "authorization" field. An error is returned when:
//   - the request or decoding fails (error from sendPostRequest),
//   - "result" is missing or not a bool,
//   - "result" is false (login rejected; the decoded reply is included), or
//   - "authorization" is missing or not a string.
//
// As a side effect of sendPostRequest, the reply's first cookie (if any) is
// stored for later requests.
func Login2Fw(url, user, pwd string) (string, error) {
	response, err := sendPostRequest(url, map[string]string{
		"user": user,
		"pwd":  pwd,
	})
	if err != nil {
		return "", err
	}

	// Distinguish "field absent" from "login refused". Previously both were
	// reported as a missing field, which hid wrong-credential failures.
	result, ok := response["result"].(bool)
	if !ok {
		return "", fmt.Errorf("Result field not found in response")
	}
	if !result {
		return "", fmt.Errorf("login rejected by firewall, response: %v", response)
	}

	authorization, ok := response["authorization"].(string)
	if !ok {
		return "", fmt.Errorf(" authorization field not found in response")
	}
	return authorization, nil
}


// EncodeString escapes input for use as a URL query component by delegating
// to url.QueryEscape (spaces become "+", reserved and non-ASCII bytes are
// percent-encoded).
func EncodeString(input string) string {
	return url.QueryEscape(input)
}


// sendBanRequest POSTs requestBody (encoded as JSON) to the blacklist endpoint
// url. It authenticates with the given Authorization header value and the
// stored session cookie, and returns the boolean "result" field of the JSON reply.
//
// It returns an error if marshaling, the request, or JSON decoding fails, or if
// the reply has no boolean "result" field. A false result with a nil error means
// the firewall answered but rejected the entry.
//
// Certificate verification is disabled (InsecureSkipVerify). This is
// presumably for a self-signed appliance certificate, so confirm it.
func sendBanRequest(url, authorization string, requestBody map[string]interface{}) (bool, error) {
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return false, fmt.Errorf("JSON marshaling error: %w", err)
	}

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		},
		// Bound each request so one stuck call cannot hang the whole ban loop.
		Timeout: 30 * time.Second,
	}
	// This function runs once per IP, each time with a new Transport. Close its
	// idle connections on return, or each call leaks a keep-alive connection.
	defer client.CloseIdleConnections()

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return false, fmt.Errorf("Request creation error: %w", err)
	}

	// NOTE(review): the body is JSON, but this Content-Type advertises form
	// encoding. It was presumably copied from a browser capture and is tolerated
	// by the firewall API, so confirm against the API before changing it.
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8")
	req.Header.Set("Authorization", authorization)
	req.Header.Set("Accept", "application/json, text/javascript, */*; q=0.01")
	req.Header.Set("Accept-Language", "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2")
	// Accept-Encoding is deliberately left unset. If the caller sets it, Go's
	// transport stops decompressing responses, so a gzip/deflate reply would
	// fail JSON decoding. Left unset, the transport requests gzip and
	// decompresses it transparently.
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/116.0")

	req.Header.Set("Sec-Fetch-Dest", "empty")
	req.Header.Set("Sec-Fetch-Mode", "cors")
	req.Header.Set("Sec-Fetch-Site", "same-origin")
	req.Header.Set("Te", "trailers")
	// AddCookie dereferences its argument and panics on nil. The cookie is nil
	// if no prior response has set one.
	if cookie != nil {
		req.AddCookie(cookie)
	}

	resp, err := client.Do(req)
	if err != nil {
		return false, fmt.Errorf("Request error: %w", err)
	}
	defer resp.Body.Close()

	var response map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return false, fmt.Errorf("JSON decoding error: %w", err)
	}

	result, ok := response["result"].(bool)
	if !ok {
		return false, fmt.Errorf("Result field not found in response")
	}
	return result, nil
}


func main() {
	url := "https://x.x.x.x/API/login"
	user := "username"
	pwd := "password"
	banurl:= "https://x.x.x.x/API/black_list"

	authorization, err := Login2Fw(url, user, pwd)
	if err != nil {
		fmt.Println("Error:", err,authorization)
		return
	}
    

    IpList := generateIpList()
	for _,ip :=range IpList{
		fmt.Println("准备封禁IP: ",ip)
		requestBody := map[string]interface{}{
			"ip_type": "2",
			"bl_ip_name": ip,
			"timeout": "0",
			"bl_grp_profile":"xxx", 
			"time_cfg_type":"0",
			"valid_time_type":"0"}
		
		success, err := sendBanRequest(banurl,authorization,requestBody)
		if err != nil {
			fmt.Println("Error:", err)
			return
		}
	
		if success {
			fmt.Println("封禁IP: ",ip,"成功")
		} else {
			fmt.Println("封禁IP: ",ip,"失败")
		}
	
	
	}
	

}





// generateIpList builds the list of IPs to ban.
//
// It reads InputList.txt, keeps the addresses accepted by
// filterAndProcessLines, then reads WhiteList.txt and removes the entries
// filtered out by filterWhiteList. The whitelist is read only after the input
// has been processed. If either file cannot be read, the error is printed and
// nil is returned.
func generateIpList() []string {
	const (
		inputFile     = "InputList.txt"
		whiteListFile = "WhiteList.txt"
	)

	rawInput, err := readLines(inputFile)
	if err != nil {
		fmt.Println("Error reading InputList.txt:", err)
		return nil
	}
	candidates := filterAndProcessLines(rawInput)

	allowed, err := readLines(whiteListFile)
	if err != nil {
		fmt.Println("Error reading WhiteList.txt:", err)
		return nil
	}

	targets := filterWhiteList(candidates, allowed)
	fmt.Println("获取到有效IP,准备封禁IP")
	return targets
}

// readLines returns the lines of the file at filePath, with line terminators
// ("\n" or "\r\n") removed as done by bufio.ScanLines. An empty file yields a
// nil slice. Open and read errors are returned unchanged, and any read error
// discards the lines collected so far.
func readLines(filePath string) ([]string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var lines []string
	scanner := bufio.NewScanner(file)
	// By default, bufio.Scanner fails with ErrTooLong on lines over 64 KiB.
	// That would abort the whole run on one long pasted log line, so allow
	// lines up to 1 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return lines, nil
}

// filterAndProcessLines extracts every IPv4-looking token from lines (via
// extractIPv4Parts). It returns, in first-seen order and without duplicates,
// the tokens that pass isIPv4 and whose last octet is not 0 or 255.
//
// The .0/.255 exclusion is a suffix heuristic, presumably meant to skip
// network and broadcast addresses of /24 networks. Other prefix lengths are
// not considered.
//
// Every extracted token is printed, including ones that are later dropped.
func filterAndProcessLines(lines []string) []string {
	var filteredLines []string
	seen := make(map[string]struct{})

	for _, line := range lines {
		for _, part := range extractIPv4Parts(line) {
			fmt.Println("已处理输入文本,获取到原始IP: ", part)
			if !isIPv4(part) {
				continue
			}
			if strings.HasSuffix(part, ".0") || strings.HasSuffix(part, ".255") {
				continue
			}
			// The same address can appear many times in the input. Keep it once
			// so the firewall receives only one ban request per IP.
			if _, dup := seen[part]; dup {
				continue
			}
			seen[part] = struct{}{}
			filteredLines = append(filteredLines, part)
		}
	}
	return filteredLines
}

// ipv4ExactPattern matches a whole string made of four dot-separated groups of
// 1-3 ASCII digits. It is compiled once at package scope rather than per call.
var ipv4ExactPattern = regexp.MustCompile(`^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$`)

// isIPv4 reports whether input is a dotted-quad IPv4 address. It must consist
// of exactly four decimal octets, each in 0-255. Leading zeros (e.g. "010")
// are accepted and read as decimal.
//
// The pattern alone would accept values like "999.1.1.1", which are not
// addresses, so each octet is range-checked as well.
func isIPv4(input string) bool {
	if !ipv4ExactPattern.MatchString(input) {
		return false
	}
	for _, octet := range strings.Split(input, ".") {
		n, err := strconv.Atoi(octet)
		if err != nil || n > 255 {
			return false
		}
	}
	return true
}

// filterWhiteList returns the elements of lines, in order, that do not appear
// in whiteList. If nothing remains, it returns nil.
//
// Whitelist entries are trimmed of surrounding whitespace, and blank entries
// are ignored, so a stray space or tab in WhiteList.txt cannot silently
// un-protect an address. lines are compared as-is.
//
// The whitelist is loaded into a set once, giving O(len(lines)+len(whiteList))
// instead of scanning the whole whitelist for every line.
func filterWhiteList(lines []string, whiteList []string) []string {
	allowed := make(map[string]struct{}, len(whiteList))
	for _, entry := range whiteList {
		if entry = strings.TrimSpace(entry); entry != "" {
			allowed[entry] = struct{}{}
		}
	}

	var filteredOutput []string
	for _, line := range lines {
		if _, whitelisted := allowed[line]; !whitelisted {
			filteredOutput = append(filteredOutput, line)
		}
	}
	return filteredOutput
}

// contains reports whether element is present in slice, using exact
// (case-sensitive) string equality. A nil or empty slice contains nothing.
func contains(slice []string, element string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == element
	}
	return found
}



// ipv4SearchPattern finds dotted-quad candidates (four groups of 1-3 ASCII
// digits) anywhere in a string. It is compiled once at package scope instead
// of once per input line.
var ipv4SearchPattern = regexp.MustCompile(`\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}`)

// extractIPv4Parts returns all non-overlapping, left-to-right substrings of
// input that look like dotted-quad IPv4 addresses. It returns nil when there
// are none. Octet ranges are not validated here (see isIPv4).
func extractIPv4Parts(input string) []string {
	return ipv4SearchPattern.FindAllString(input, -1)
}