局部放大画图

发布时间：2023-11-18 21:37:39　作者：SWHsz

这是学姐修改好的版本，先存一下以便以后使用。主要用途：训练到最后阶段时，各条曲线末端的差异很难看清，因此用内嵌子图把这一局部区域放大显示。

import random
import csv
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import numpy as np
def zone_and_linked(ax, axins, zone_left, zone_right, x, y, linked='bottom',
                    x_ratio=0.05, y_ratio=0.05):
    """Zoom an inset axes onto part of the main plot and connect the two.

    The inset's limits are set to the index window ``[zone_left, zone_right]``
    (both ends inclusive) plus some padding. The same window is outlined in
    black on the main axes, and two ConnectionPatch lines join the inset to
    that outline.

    ax:         main axes, e.g. ``fig, ax = plt.subplots(1, 1)``.
    axins:      inset axes, e.g. ``axins = ax.inset_axes((0.4, 0.1, 0.4, 0.3))``.
    zone_left:  index (into ``x`` and into every series of ``y``) of the left
                end of the zoomed region.
    zone_right: index of the right end of the zoomed region (inclusive).
    x:          x coordinates of the data points.
    y:          list of y series; each series is indexed the same way as ``x``.
    linked:     side of the outlined region that the connector lines attach
                to: one of 'bottom', 'top', 'left', 'right'. The inset is
                presumably placed on that side of the region.
    x_ratio:    horizontal padding, as a fraction of the zoomed x-span.
    y_ratio:    vertical padding, as a fraction of the zoomed y-span.

    Raises ValueError if ``linked`` is not one of the four supported sides.
    """
    x_span = x[zone_right] - x[zone_left]
    xlim_left = x[zone_left] - x_span * x_ratio
    xlim_right = x[zone_right] + x_span * x_ratio

    # zone_right is inclusive on the x-axis, so its sample must also count
    # towards the y-range; an exclusive slice could clip that last point.
    y_data = np.hstack([yi[zone_left:zone_right + 1] for yi in y])
    y_min = np.min(y_data)
    y_max = np.max(y_data)
    y_span = y_max - y_min
    ylim_bottom = y_min - y_span * y_ratio
    ylim_top = y_max + y_span * y_ratio

    # For each side: two (inset point, main-axes point) pairs. Inset points
    # are corners of the inset view; main-axes points are corners of the
    # outline drawn below.
    links = {
        'bottom': (((xlim_left, ylim_top), (xlim_left, ylim_bottom)),
                   ((xlim_right, ylim_top), (xlim_right, ylim_bottom))),
        'top': (((xlim_left, ylim_bottom), (xlim_left, ylim_top)),
                ((xlim_right, ylim_bottom), (xlim_right, ylim_top))),
        'left': (((xlim_right, ylim_top), (xlim_left, ylim_top)),
                 ((xlim_right, ylim_bottom), (xlim_left, ylim_bottom))),
        'right': (((xlim_left, ylim_top), (xlim_right, ylim_top)),
                  ((xlim_left, ylim_bottom), (xlim_right, ylim_bottom))),
    }
    # Validate before touching either axes so a bad argument leaves the
    # figure unchanged (previously this surfaced as an UnboundLocalError).
    if linked not in links:
        raise ValueError("linked must be 'bottom', 'top', 'left' or 'right', "
                         "got {!r}".format(linked))

    axins.set_xlim(xlim_left, xlim_right)
    axins.set_ylim(ylim_bottom, ylim_top)

    # Outline the zoomed region on the main axes as a closed black rectangle.
    ax.plot([xlim_left, xlim_right, xlim_right, xlim_left, xlim_left],
            [ylim_bottom, ylim_bottom, ylim_top, ylim_top, ylim_bottom],
            "black")

    for xy_inset, xy_main in links[linked]:
        con = ConnectionPatch(xyA=xy_inset, xyB=xy_main, coordsA="data",
                              coordsB="data", axesA=axins, axesB=ax)
        axins.add_artist(con)
# Load the per-round test accuracy of the four runs stored under
# ./mnist/0 .. ./mnist/3; the first column of every CSV row is one value.
data = []
for run_idx in range(4):
    path = './mnist/' + str(run_idx) + '/acc_test.csv'
    # `with` guarantees the handle is closed; newline='' is how the csv
    # module expects files passed to csv.reader to be opened.
    with open(path, newline='') as f:
        # Skip blank rows (e.g. written by csv.writer on Windows without
        # newline='') instead of crashing on row[0].
        data.append([float(row[0]) for row in csv.reader(f) if row])

plt.rc('font', family='Times New Roman')

# X-axis tick positions (every 25 rounds) and the round index of each point.
x = list(range(0, 201, 25))
columns = list(range(0, 200))  # communication rounds; assumes 200 rows per CSV
font1 = {'family': 'Times New Roman',
         'weight': 'normal',
         'size': 14,
         }

# Style of each run, in the same order as `data`: (colour, legend label, marker).
series_styles = [
    ("#7DABCF", '20', 'o'),
    ("#AAB083", '15', 'x'),
    ("#FBC1AD", '25', 'v'),
    ("#ABC1AD", '30', 's'),
]

# Create the single figure that is drawn on and saved. Calling
# plt.figure(figsize=...) followed by plt.subplots() would leave an empty
# extra figure behind and drop the requested size.
fig, ax = plt.subplots(figsize=(10, 4))

# Zoomed inset, positioned in axes-fraction coordinates (x0, y0, width, height).
axins1 = ax.inset_axes((0.5, 0.45, 0.3, 0.3))

# Draw every curve on both the main axes and the inset; the inset only shows
# the zoomed window once its limits are set below.
for target in (ax, axins1):
    for series, (color, label, marker) in zip(data, series_styles):
        target.plot(columns, series, color=color, label=label, marker=marker,
                    markersize=3, markerfacecolor='none')

# Zoom window as indices into `columns`, both ends inclusive.
zone_left = 180
zone_right = 199
# Set the inset limits, outline the region on the main axes and link the two.
zone_and_linked(ax, axins1, zone_left, zone_right, columns, data, 'bottom')


# Re-apply the inset window with tighter padding than zone_and_linked used
# above, then link the inset's bottom corners to the top of the zoomed data
# on the main axes.
# NOTE(review): zone_and_linked() above already added two connectors for its
# own 0.05-padded window, whose inset-side ends now fall outside the tighter
# limits set here; together with the two below, four connectors are drawn.
# Presumably only one pair is wanted -- confirm on the rendered figure.
x1 = columns
y1 = data
x_ratio = 0.02  # horizontal padding, as a fraction of the zoomed x-span
y_ratio = 0.02  # vertical padding, in absolute y units (not a fraction)
x_span = x1[zone_right] - x1[zone_left]
xlim_left = x1[zone_left] - x_span * x_ratio
xlim_right = x1[zone_right] + x_span * x_ratio

# zone_right is inclusive on the x-axis, so include its sample in the
# y-range as well; an exclusive slice could clip the last point.
y_data = np.hstack([yi[zone_left:zone_right + 1] for yi in y1])
ylim_bottom = np.min(y_data) - y_ratio
ylim_top = np.max(y_data) + y_ratio

axins1.set_xlim(xlim_left, xlim_right)
axins1.set_ylim(ylim_bottom, ylim_top)

# Inset bottom corners (inset data coords) -> points at the maximum of the
# zoomed data, at the same x, on the main axes (main data coords).
link_y = np.max(y_data)
for xy_inset, xy_main in (((xlim_left, ylim_bottom), (xlim_left, link_y)),
                          ((xlim_right, ylim_bottom), (xlim_right, link_y))):
    con = ConnectionPatch(xyA=xy_inset, xyB=xy_main, coordsA="data",
                          coordsB="data", axesA=axins1, axesB=ax)
    axins1.add_artist(con)

# Style the main axes explicitly through `ax`/`fig` instead of pyplot's
# "current axes" state, then save and show.
ax.set_aspect(100)  # 1 unit of y is drawn as long as 100 units of x
ax.tick_params(axis='both', labelsize=14)
ax.set_ylim(0, 1)
ax.set_yticks([round(i, 2) for i in np.linspace(0, 1, 11)])

ax.set_xticks(x)  # tick positions every 25 rounds
ax.set_ylabel(r"Accuracy", fontsize=14)
ax.set_xlabel(r"Communication rounds", fontsize=14)
ax.set_title('')
ax.legend(prop=font1)
ax.grid(axis='both', linestyle='--', linewidth=0.5)

fig.savefig("acc_cluster.pdf")
plt.show()