题目描述
Equations are given in the format A / B = k, where A and B are variables represented as strings, and k is a real number (floating point number). Given some queries, return the answers. If the answer does not exist, return -1.0.
样例
Example:
Given a / b = 2.0, b / c = 3.0.
queries are: a / c = ?, b / a = ?, a / e = ?, a / a = ?, x / x = ?.
return [6.0, 0.5, -1.0, 1.0, -1.0].
The input is: vector<vector<string>>& equations, vector<double>& values, vector<vector<string>>& queries, where each equation and each query is a two-element list [A, B], equations.size() == values.size(), and all values are positive. Together, equations and values represent the equations A / B = k. Return a vector<double> with one answer per query.
According to the example above:
equations = [ ["a", "b"], ["b", "c"] ],
values = [2.0, 3.0],
queries = [ ["a", "c"], ["b", "a"], ["a", "e"], ["a", "a"], ["x", "x"] ].
算法1
dfs
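题目很好:要计算queries中的每个查询,就用dfs沿路径a->b->c->d搜索,把沿途每一步的比值相乘即可求出a/d(例如a/c = a/b * b/c = 2.0 * 3.0 = 6.0)。
先根据等式建立一个map,记录每条有向边的比值:等式a/b=k记为边a->b的值k,同时记录反向边b->a的值1/k,这样两个方向都能直接查到。
再用邻接表start2end记录每个节点直接相连的所有节点,start2end[节点]即该节点的邻居列表。
最后在map(边的比值)和start2end(邻接关系)构成的图上,对每个查询做dfs求得结果;若查询中的变量未出现在任何等式中,或两者不连通,则返回-1.0。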
C++ 代码
class Solution {
public:
    /**
     * Evaluates each query C / D given equations of the form A / B = k.
     *
     * Builds an undirected weighted graph: equation a / b = k becomes edge
     * a -> b with weight k plus edge b -> a with weight 1 / k. A query
     * start / end is answered by a DFS from start that multiplies the edge
     * weights along the path it finds to end.
     *
     * @param equations two-element lists [A, B]; equations[i] pairs with values[i].
     * @param values    the value k of each equation (assumed positive).
     * @param queries   two-element lists [C, D] to evaluate as C / D.
     * @return one result per query; -1.0 when a variable appears in no
     *         equation or no chain of equations connects the two variables.
     */
    vector<double> calcEquation(vector<vector<string>>& equations, vector<double>& values, vector<vector<string>>& queries) {
        unordered_map<string, double> map;                // edgeKey(from, to) -> from / to
        unordered_map<string, vector<string>> start2end;  // adjacency list of the graph
        for (size_t i = 0; i < values.size(); ++i) {
            const string& start = equations[i][0];
            const string& end = equations[i][1];
            map[edgeKey(start, end)] = values[i];
            map[edgeKey(end, start)] = 1.0 / values[i];
            start2end[start].push_back(end);
            start2end[end].push_back(start);
        }

        vector<double> res;
        res.reserve(queries.size());
        for (const auto& query : queries) {
            const string& start = query[0];
            const string& end = query[1];
            // A variable that appears in no equation can never be reached; skip the search.
            if (!start2end.count(start) || !start2end.count(end)) {
                res.push_back(-1.0);
                continue;
            }
            unordered_set<string> visited;
            res.push_back(dfs(start, end, map, start2end, 1.0, visited));
        }
        return res;
    }

    /**
     * Depth-first search for a path from `start` to `end`.
     *
     * @param map       edge weights keyed by edgeKey(from, to).
     * @param start2end adjacency list of the graph.
     * @param value     product of edge weights accumulated on the path so far.
     * @param visited   nodes already expanded in this search; they stay marked,
     *                  so each node is expanded at most once.
     * @return `value` multiplied by the weights along the first path found to
     *         `end`, or -1.0 if `start` has no edges or `end` is unreachable.
     */
    double dfs(const string& start, const string& end, unordered_map<string, double>& map,
               unordered_map<string, vector<string>>& start2end, double value,
               unordered_set<string>& visited) {
        const auto it = start2end.find(start);
        if (it == start2end.end())
            return -1.0;
        if (start == end)
            return value;
        // Mark once and never un-mark. A reachability DFS already reaches every node
        // connected to the query's start, so erasing on backtrack only re-explores
        // nodes and makes the search exponential on dense graphs.
        visited.insert(start);
        // Iterating by reference is safe: the recursion only calls find()/at(),
        // which never insert into either map.
        for (const string& next : it->second) {
            if (visited.count(next))
                continue;
            const double tmp = dfs(next, end, map, start2end,
                                   value * map.at(edgeKey(start, next)), visited);
            if (tmp > -1.0)  // products of positive values are positive; -1.0 means "not found"
                return tmp;
        }
        return -1.0;
    }

private:
    // Unambiguous key for the directed edge from -> to. Plain concatenation
    // collides ("a" + "bc" == "ab" + "c"); prefixing the length of `from`
    // pins down where `from` ends and `to` begins.
    static string edgeKey(const string& from, const string& to) {
        return to_string(from.size()) + '#' + from + to;
    }
};