梯度求解(csp)

发布时间 2023-11-26 20:34:25作者: 小菜碟子

202309-3的一道题目

1 struct item
2 {
3     long long k;//常系数
4     map<int, int>mp;//存储每一项
5     item(long long coe, map<int, int>mp) :k(coe), mp(mp) {}//结构体构造函数
6 };

 


这里新学一个结构体构造函数,就可以返回该结构体

1 struct fomula
2 {
3     vector<item> value;//每一项
4     fomula(vector<item>vec) :value(vec) {}//有构造函数才能返回
5 };
  1 #include <iostream>
  2 #include <math.h>
  3 #include <vector>
  4 #include <map>
  5 #include<sstream>
  6 #include <stack>
  7 
  8 using namespace std;
  9 
 10 
 11 long long mod = 1e9 + 7;
 12 vector<long long>val;//存储求导时的变量值
 13 
 14 struct item
 15 {
 16     long long k;//常系数
 17     map<int, int>mp;//存储每一项
 18     item(long long coe, map<int, int>mp) :k(coe), mp(mp) {}//结构体构造函数
 19 };
 20 
 21 struct fomula
 22 {
 23     vector<item> value;//每一项
 24     fomula(vector<item>vec) :value(vec) {}//有构造函数才能返回
 25 };
 26 
 27 stack<fomula>st;//栈S
 28 
 29 long long convert(string str)//字符串转换成数字 
 30 {
 31     long long num = 0;
 32     for (int i = (str[0] == '-') ? 1 : 0; i < str.length(); i++) 
 33     {
 34         num *= 10;
 35         num += str[i] - '0';
 36     }
 37     return (str[0] == '-') ? -1 * num : num;
 38 }
 39 
 40 item item_mul(item A, item B)//乘积项乘法函数 
 41 {
 42     long long k = A.k * B.k;
 43     map<int, int>mp_c;
 44     map<int, int>::iterator it;//创建迭代器
 45     //更新A项中每个变量的指数
 46     for (it = A.mp.begin(); it != A.mp.end(); it++) 
 47     {
 48         mp_c[it->first] = A.mp[it->first] + B.mp[it->first];
 49         B.mp.erase(it->first);//同时删除,便于后续添加B项中指数
 50     }
 51     //添加B中的项
 52     for (it = B.mp.begin(); it != B.mp.end(); it++) 
 53         mp_c[it->first] = B.mp[it->first];
 54 
 55     return item(k, mp_c);
 56 }
 57 
 58 fomula fomula_mul(fomula A, fomula B)//多项式乘法 
 59 {
 60     vector<item>vec;
 61     for (int i = 0; i < A.value.size(); i++) 
 62         for (int j = 0; j < B.value.size(); j++) 
 63             vec.push_back(item_mul(A.value[i], B.value[j]));//A每一项都会×B中每一项
 64         
 65     return fomula(vec);
 66 }
 67 
 68 fomula fomula_add(fomula A, fomula B)//多项式加法 
 69 {
 70     for (int i = 0; i < B.value.size(); i++) {
 71         A.value.push_back(B.value[i]);
 72     }
 73     return A;
 74 }
 75 
 76 fomula fomula_sub(fomula A, fomula B)//多项式减法 
 77 {
 78     //每一项带符号再相加
 79     for (int i = 0; i < A.value.size(); i++) {
 80         A.value[i].k *= -1;
 81     }
 82     return fomula_add(B, A);
 83 }
 84 
 85 long long function(fomula A, int goal)//求导函数 对最终的fomula求导 
 86 {
 87     long long sum = 0, mul;
 88     for (int i = 0; i < A.value.size(); i++) 
 89     {
 90         item now = A.value[i];
 91         mul = 1;
 92         if (now.mp.find(goal) != now.mp.end()) 
 93         {//此乘积项含有要求导的变量才拥有计算价值
 94             mul = (now.k * now.mp[goal]) % mod;
 95             now.mp[goal]--;//求导过程
 96 
 97             //计算该项求导结果
 98             for (map<int, int>::iterator it = now.mp.begin(); it != now.mp.end(); it++) 
 99                 for (int k = 0; k < it->second; k++) 
100                     mul = (mul * val[it->first]) % mod;
101                 
102             sum = (sum + mul) % mod;
103         }
104     }
105     return sum;
106 }
107 
108 int main()
109 {
110     int n, m;
111     cin >> n >> m;//求解函数中所含自变量的个数和要求解的偏导数的个数
112     getchar();//清空缓存区
113 
114     string s, temp;
115     getline(cin,s);
116 
117     stringstream sin(s);
118     while (sin >> temp)
119     {
120         if (temp == "+" || temp == "-" || temp == "*") 
121         {//运算符
122             fomula A = st.top(); st.pop();//从栈中依次弹出两个formula
123             fomula B = st.top(); st.pop();
124             if (temp == "*") 
125                 st.push(fomula_mul(B, A));
126             else if (temp == "+") 
127                 st.push(fomula_add(B, A));
128             else 
129                 st.push(fomula_sub(A, B));//A B的顺序很重要 
130         }
131         else
132         {
133             map<int, int>mp;//下标 指数 
134             vector<item>vec;
135             if (temp[0] == 'x') 
136             {//自变量 
137                 mp[convert(temp.substr(1, temp.length() - 1))] = 1;
138                 vec.push_back(item(1, mp));//把乘积项包装成多项式 
139             }
140             else
141                 vec.push_back(item(convert(temp), mp));//把乘积项包装成多项式 
142             st.push(fomula(vec));
143         }
144     }
145     for (int i = 0; i < m; i++)
146     {
147         long long value;
148         for (int j = 0; j < n + 1; j++) 
149         {
150             cin >> value;
151             val.push_back(value);
152         }
153         long long ans = function(st.top(), val[0]);
154         cout << ((ans < 0) ? ans + mod : ans) << endl;//当计算整数n对M的模时,若n为负数,需要注意将结果调整至区间[0,M)内
155         val.clear();
156     }
157 }

重点在于,它把所有的多项式都整理成*和+的形式,并没有必要合并同类项

重点是你如何把问题抽象成对应的数据结构,然后进行求解