Python 使用Matplotlib绘制可拖动的折线

发布时间: 2023-04-15 17:03:48  作者: 酱_油

Python 使用Matplotlib绘制可拖动的折线

效果图: 

既可以直接拖动曲线上的点进行调整，也可以拖动右侧的滑动条（slider）进行调整。

 

 

代码如下:

import matplotlib.animation as animation
from matplotlib.widgets import Slider, Button
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
import scipy.interpolate as inter
import numpy as np

def func(x):
    """Default control-point values: all zeros, with the same shape and dtype as *x*."""
    return np.zeros_like(x)

def load_cache_weight(cache_file):
    """Load saved point weights from *cache_file* into the global ``yvals`` in place.

    The file is the JSON object written by ``save_p``: ``'YYYYMMDD'`` date
    strings mapped to weights, in plot order. Values are copied into ``yvals``
    by position; the date keys themselves are not checked.

    A missing file is not an error. On the first run nothing has been saved
    yet, so ``yvals`` is left unchanged. Entries beyond ``len(yvals)`` are
    ignored. If the cache has fewer entries, the remaining points keep their
    current values.

    Args:
        cache_file: path of the JSON cache to read.
    """
    import json

    global yvals
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            d = json.load(f)
    except FileNotFoundError:
        # Nothing cached yet: keep the defaults computed by the caller.
        return
    # zip() stops at the shorter sequence, so a cache saved with a different N
    # can no longer raise IndexError.
    for i, value in zip(range(len(yvals)), d.values()):
        yvals[i] = value

# ---- user input config ----
N = 30  # number of control points (one date and one slider per point)
st = pd.to_datetime('20230414')  # date of the first control point
# The first assignment is immediately overridden; presumably a manual toggle
# (keep only the `None` line to start from func(x) without reading a cache).
cache_file = None
cache_file = './tmp/saved_weight_2.json'  # read at start-up and on Reset
save_file  = './tmp/saved_weight_2.json'  # written by the Save button

# x positions of the control points the spline is fitted through.
# NOTE(review): linspace(1, N+1, N) spaces the points N/(N-1) apart rather
# than 1 -- confirm whether integer positions 1..N were intended.
xmin = 1 
xmax = N+1
x = np.linspace(xmin,xmax,N)

# Initial control-point values: func(x), overwritten from the cache if set.
yvals = func(x)
if cache_file is not None:
    load_cache_weight(cache_file)
# Interpolating spline through (x, yvals); rebuilt whenever a point changes.
spline = inter.InterpolatedUnivariateSpline (x, yvals)

# Figure margins: keep the right-hand 20% of the figure free for the slider column.
mpl.rcParams.update({
    'figure.subplot.left': 0.1,
    'figure.subplot.right': 0.8,
})

# A single axes holds the draggable curve.
fig, axes = plt.subplots(nrows=1, ncols=1, sharex=True, figsize=(16, 5))
ax1 = axes

pind = None   # index of the control point currently being dragged, or None
epsilon = 5   # pick tolerance for grabbing a point, in display pixels

def update(val):
    """Slider callback: copy every slider value into ``yvals`` and refresh the plot.

    Args:
        val: new value of the slider that fired (ignored; all sliders are re-read).
    """
    global yvals
    global spline
    yvals[:N] = [sliders[k].val for k in range(N)]
    l.set_ydata(yvals)
    fitted = inter.InterpolatedUnivariateSpline(x, yvals)
    spline = fitted
    m.set_ydata(fitted(X))
    # Schedule a redraw for when the GUI is idle instead of drawing immediately.
    fig.canvas.draw_idle()

def reset(event):
    """Reset-button callback: reload the starting weights and sync every widget.

    ``yvals`` is rebuilt from ``func(x)``. When ``cache_file`` is set,
    ``load_cache_weight`` then overwrites it in place. The resulting values
    are pushed into the sliders, the spline is refitted, and the markers and
    spline curve are redrawn.

    Args:
        event: the matplotlib button-click event (unused).
    """
    global yvals
    global spline
    yvals = func(x)
    if cache_file is not None:
        load_cache_weight(cache_file)
    # Every slider change fires update(), which copies *all* current slider
    # values back into yvals. Working from a private copy of the targets stops
    # sliders not yet visited from clobbering the reloaded values mid-loop.
    # (Calling Slider.reset() here would instead restore the start-up valinit
    # values and silently discard what was just loaded from the cache.)
    targets = yvals.copy()
    for slider, value in zip(sliders, targets):
        slider.set_val(value)
    yvals[:] = targets  # also correct if no slider callback ran
    spline = inter.InterpolatedUnivariateSpline(x, yvals)
    l.set_ydata(yvals)
    m.set_ydata(spline(X))
    # redraw canvas while idle
    fig.canvas.draw_idle()

def save_p(event):
    """Save-button callback: print the current weights and write them to ``save_file``.

    Weights are keyed by date (``'YYYYMMDD'`` strings from ``datelst``) in
    plot order, which is the layout ``load_cache_weight`` reads back. Nothing
    is written when ``save_file`` is ``None``.

    Args:
        event: the matplotlib button-click event (unused).
    """
    import json
    import os

    # float() turns numpy scalars into plain floats so json can always serialize them.
    r = {date.strftime('%Y%m%d'): float(val) for date, val in zip(datelst, yvals)}
    print(r)
    if save_file is not None:
        # Create the target directory (e.g. './tmp') on first save instead of
        # failing with FileNotFoundError.
        save_dir = os.path.dirname(save_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_file, 'w', encoding='utf-8') as f:
            json.dump(r, f)

def button_press_callback(event):
    """On a left-click inside any axes, remember which control point (if any) was hit.

    Stores the result of ``get_ind_under_point`` in the global ``pind``
    (``None`` when no point is close enough). Other clicks leave ``pind`` alone.
    """
    global pind
    is_left_click_in_axes = event.inaxes is not None and event.button == 1
    if is_left_click_in_axes:
        pind = get_ind_under_point(event)

def button_release_callback(event):
    """Stop dragging: clear the global ``pind`` when the left button is released."""
    global pind
    if event.button == 1:
        pind = None

def get_ind_under_point(event):
    """Return the index of the control point under the mouse, or ``None``.

    Each point ``(x[i], yvals[i])`` is mapped to display (pixel) coordinates
    through ``ax1.transData``. The point nearest to ``(event.x, event.y)`` is
    chosen, taking the lowest index on ties. It is accepted only if it is
    strictly closer than ``epsilon`` pixels.

    Args:
        event: a matplotlib mouse event; only ``event.x`` and ``event.y``
            (display coordinates) are used.

    Returns:
        The index of the picked point, or ``None`` if no point is within
        ``epsilon`` pixels.
    """
    # Compare in pixels so the pick tolerance is independent of axis scaling.
    pts_px = ax1.transData.transform(np.column_stack((x, yvals)))
    dist = np.hypot(pts_px[:, 0] - event.x, pts_px[:, 1] - event.y)
    ind = np.argmin(dist)  # first occurrence of the minimum on ties
    if dist[ind] >= epsilon:
        return None
    return ind

def motion_notify_callback(event):
    """Drag the active control point vertically to follow the mouse.

    Does nothing unless three conditions hold: a point is being dragged
    (``pind`` is set), the cursor is over the main plot ``ax1``, and the left
    button is held. The new value is clipped to [-1, 1] and pushed through the
    matching slider, whose ``on_changed`` callback refreshes the curves.
    """
    global yvals
    if pind is None:
        return
    # Only accept positions over the main plot. Over a slider or button axes,
    # event.ydata is in *that* axes' data coordinates and would corrupt the point.
    if event.inaxes is not ax1:
        return
    if event.button != 1:
        return

    yvals[pind] = np.clip(event.ydata, -1, 1)

    # Update the curve through the slider (its on_changed callback), then redraw.
    sliders[pind].set_val(yvals[pind])
    fig.canvas.draw_idle()

############################
# One calendar date per control point: st, st + 1 day, ..., st + (N - 1) days.
ed = st + pd.Timedelta(N - 1, unit='D')
datelst = pd.date_range(start=st, end=ed, freq='D')

# ---- main plot: baseline, draggable control points and fitted spline ----
###########################

X = np.arange(0, xmax + 1, 0.1)  # dense x grid on which func and the spline are drawn
ax1.plot(X, func(X), 'k--', label='original')
# Control points: markers only (no connecting line); update() moves them.
l, = ax1.plot(x, yvals, color='k', linestyle='none', marker='o', markersize=8)
m, = ax1.plot(X, spline(X), 'r-', label='spline')

ax1.set_yscale('linear')
# Derive the x-range from the data instead of hard-coding 32, so every point
# stays visible when N changes (xmax + 1 == 32 for the default N = 30).
ax1.set_xlim(0, xmax + 1)
ax1.set_ylim(-1.05, 1.05)  # slight margin around the [-1, 1] slider range
ax1.set_xlabel('dt')
ax1.set_ylabel('p')
ax1.grid(True)
ax1.yaxis.grid(True, which='minor', linestyle='--')
ax1.legend(loc=2, prop={'size': 8})

# Slider column to the right of the plot: one slider per control point,
# labelled "month/day" and stacked downwards 0.03 figure-heights apart.
sliders = []
for row in range(N):
    stamp = datelst[row]
    slider_ax = plt.axes([0.84, 0.95 - (row * 0.03), 0.12, 0.02])
    label = '{}/{}'.format(stamp.month, stamp.day)
    slider = Slider(slider_ax, label, -1, 1, valinit=yvals[row])
    # Moving any slider refreshes yvals and the spline via update().
    slider.on_changed(update)
    sliders.append(slider)

# Reset / Save buttons share one row directly below the last slider.
button_row_y = 0.95 - ((N) * 0.03)

axres = plt.axes([0.84, button_row_y, 0.06, 0.02])
bres = Button(axres, 'Reset')
bres.on_clicked(reset)

# `axres` is reused for the Save button's axes. bres/bres2 stay as
# module-level names because matplotlib widgets need a live reference
# to keep responding.
axres = plt.axes([0.84 + 0.08, button_row_y, 0.06, 0.02])
bres2 = Button(axres, 'Save')
bres2.on_clicked(save_p)

# Mouse handlers that let the user drag control points directly on the plot.
for event_name, handler in (
    ('button_press_event', button_press_callback),
    ('button_release_event', button_release_callback),
    ('motion_notify_event', motion_notify_callback),
):
    fig.canvas.mpl_connect(event_name, handler)

plt.show()