A
∗
A*
A∗算法的思路可看:
路径规划之
A
∗
A*
A∗ 算法
Introduction to the
A
∗
A*
A∗ Algorithm
提炼一下:
定义起点
s
s
s,终点
t
t
t,从起点开始的距离函数
g
(
x
)
g(x)
g(x) ,到终点的距离函数
h
1
(
x
)
h_{1}(x)
h1(x) ,
h
2
(
x
)
h_{2}(x)
h2(x),以及每个点的估价函数
f
(
x
)
=
g
(
x
)
+
h
1
(
x
)
f(x)=g(x)+h_{1}(x)
f(x)=g(x)+h1(x),其中
h
1
(
x
)
h_{1}(x)
h1(x)是我们定义的点
x
x
x到终点的预估代价函数,
h
2
(
x
)
h_{2}(x)
h2(x)是点
x
x
x到终点的实际代价函数
启发函数会影响
A
∗
A*
A∗算法的行为:
- 在极端情况下,当启发函数
h
1
(
x
)
h_{1}(x)
h1(x)始终为0,则将由
g
(
x
)
g(x)
g(x)决定节点的优先级,此时算法就退化成了Dijkstra算法
- 如果
h
1
(
x
)
h_{1}(x)
h1(x)始终小于等于节点
x
x
x到终点的代价
h
2
(
x
)
h_{2}(x)
h2(x),则
A
∗
A*
A∗算法保证一定能够找到最短路径。但是当
h
1
(
x
)
h_{1}(x)
h1(x)的值越小,算法将遍历越多的节点,也就导致算法越慢
- 如果
h
1
(
x
)
h_{1}(x)
h1(x)完全等于节点
x
x
x到终点的代价
h
2
(
x
)
h_{2}(x)
h2(x),则
A
∗
A*
A∗算法将找到最佳路径,并且速度很快。可惜的是,并非所有场景下都能做到这一点,因为在没有达到终点之前,我们很难确切算出距离终点还有多远
- 如果
h
1
(
x
)
h_{1}(x)
h1(x)的值比节点
x
x
x到终点的代价
h
2
(
x
)
h_{2}(x)
h2(x)要大,则
A
∗
A*
A∗算法不能保证找到最短路径,不过此时会很快
- 在另外一个极端情况下,如果
h
1
(
x
)
h_{1}(x)
h1(x)相较于
g
(
x
)
g(x)
g(x)大很多,则此时只有
h
1
(
x
)
h_{1}(x)
h1(x)产生效果,这也就变成了最佳优先搜索
所以通过调节启发函数,我们可以控制算法的速度和精确度,在一些情况,我们可能未必需要最短路径,而是希望能够尽快找到一个路径即可,这也是
A
∗
A*
A∗算法比较灵活的地方
例一:
八数码
思路:
我们知道,在平面上,坐标
(
x
1
,
y
1
)
(x_{1},y_{1})
(x1,y1)的
i
i
i点与坐标
(
x
2
,
y
2
)
(x_{2},y_{2})
(x2,y2)的
j
j
j点的曼哈顿距离为:
d
(
i
,
j
)
=
∣
x
1
−
x
2
∣
+
∣
y
2
−
y
1
∣
d(i,j)=|x_{1}-x_{2}|+|y_{2}-y_{1}|
d(i,j)=∣x1−x2∣+∣y2−y1∣,所以,用每个数和其最终位置的曼哈顿距离作为
h
1
(
x
)
h_{1}(x)
h1(x),因为其小于
h
2
(
x
)
h_{2}(x)
h2(x)且差别又不大,所以可以很好的优化
代码:
#include <iostream>
#include <cstring>
#include <queue>
#include <unordered_map>
#include <algorithm>
using namespace std;
typedef pair<int , string> PIS;
unordered_map<string , int> dist;
unordered_map<string , pair<string , char>> pre;
priority_queue<PIS , vector<PIS> , greater<PIS>> heap;
string ed = "12345678x";
int dx[4] = {-1 , 0 , 1 , 0} , dy[4] = {0 , 1 , 0 , -1};
char op[] = "urdl";
int f(string state){
int res = 0;
for(int i = 0 ; i < 9 ; i++){
if(state[i] != 'x'){
int t = state[i] - '1';
res += abs(t / 3 - i / 3) + abs(t % 3 - i % 3);
}
}
return res;
}
string bfs(string start){
heap.push({f(start) , start});
dist[start] = 0;
while(heap.size()){
auto t = heap.top();heap.pop();
string state = t.second;
int step = dist[state];
if(state == ed) break;
int k = state.find('x');
int x = k / 3 , y = k % 3;
string source = state;
for (int i = 0; i < 4; i ++ ){
int a = x + dx[i], b = y + dy[i];
if (a >= 0 && a < 3 && b >= 0 && b < 3){
swap(state[x * 3 + y], state[a * 3 + b]);
if (!dist.count(state) || dist[state] > step + 1){
dist[state] = step + 1;
pre[state] = {source, op[i]};
heap.push({dist[state] + f(state), state});
}
swap(state[x * 3 + y], state[a * 3 + b]);
}
}
}
string res;
while(ed != start){
res += pre[ed].second;
ed = pre[ed].first;
}
reverse(res.begin() , res.end());
return res;
}
int main(){
string start , seq;
for(int i = 0 ; i < 9 ; i++){
char c;
cin >> c;
start += c;
if(c != 'x') seq += c;
}
int cnt = 0;
for(int i = 0 ; i < 8 ; i ++)
for(int j = i + 1 ; j < 8 ; j++)
if(seq[i] > seq[j])
cnt++;
if(cnt % 2) puts("unsolvable");
else cout << bfs(start) << endl;
return 0;
}
发现可以优化,一个点其实只需进队列一次就可以了:
#include <iostream>
#include <cstring>
#include <queue>
#include <unordered_map>
#include <algorithm>
using namespace std;
typedef pair<int , string> PIS;
unordered_map<string , int> dist;
unordered_map<string , bool> st;
unordered_map<string , pair<string , char>> pre;
priority_queue<PIS , vector<PIS> , greater<PIS>> heap;
string ed = "12345678x";
int dx[4] = {-1 , 0 , 1 , 0} , dy[4] = {0 , 1 , 0 , -1};
char op[] = "urdl";
int f(string state){
int res = 0;
for(int i = 0 ; i < 9 ; i++){
if(state[i] != 'x'){
int t = state[i] - '1';
res += abs(t / 3 - i / 3) + abs(t % 3 - i % 3);
}
}
return res;
}
string bfs(string start){
heap.push({f(start) , start});
dist[start] = 0;
while(heap.size()){
auto t = heap.top();heap.pop();
string state = t.second;
if(st[state])continue;
st[state]=true;
int step = dist[state];
if(state == ed) break;
int k = state.find('x');
int x = k / 3 , y = k % 3;
string source = state;
for (int i = 0; i < 4; i ++ ){
int a = x + dx[i], b = y + dy[i];
if (a >= 0 && a < 3 && b >= 0 && b < 3){
swap(state[x * 3 + y], state[a * 3 + b]);
if (!dist.count(state) || dist[state] > step + 1){
dist[state] = step + 1;
pre[state] = {source, op[i]};
heap.push({dist[state] + f(state), state});
}
swap(state[x * 3 + y], state[a * 3 + b]);
}
}
}
string res;
while(ed != start){
res += pre[ed].second;
ed = pre[ed].first;
}
reverse(res.begin() , res.end());
return res;
}
int main(){
string start , seq;
for(int i = 0 ; i < 9 ; i++){
char c;
cin >> c;
start += c;
if(c != 'x') seq += c;
}
int cnt = 0;
for(int i = 0 ; i < 8 ; i ++)
for(int j = i + 1 ; j < 8 ; j++)
if(seq[i] > seq[j])
cnt++;
if(cnt % 2) puts("unsolvable");
else cout << bfs(start) << endl;
return 0;
}
例二:
k短路
思路:
首先,小根堆每次出堆且是终点第几次出堆就是第几短路(第
k
k
k次到达终点时的路径长度即为第
k
k
k短路的长度),然后设计代价函数
h
1
(
x
)
h_{1}(x)
h1(x):从
x
x
x点到终点的最短距离,其求法:建反向边跑dijkstra即可
代码:
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
typedef pair<int, PII> PIII;
const int N = 1010, M = 200010;
int n, m, S, T, K;
int h[N], rh[N], e[M], w[M], ne[M], idx;
int dist[N], cnt[N];
bool st[N];
void add(int h[], int a, int b, int c){
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
void dijkstra(){
priority_queue<PII, vector<PII>, greater<PII>> heap;
heap.push({0, T});
memset(dist, 0x3f, sizeof dist);
dist[T] = 0;
while (heap.size()){
auto t = heap.top(); heap.pop();
int ver = t.y;
if (st[ver]) continue;
st[ver] = true;
for (int i = rh[ver]; ~i; i = ne[i]){
int j = e[i];
if (dist[j] > dist[ver] + w[i]){
dist[j] = dist[ver] + w[i];
heap.push({dist[j], j});
}
}
}
}
int astar(){
priority_queue<PIII, vector<PIII>, greater<PIII>> heap;
heap.push({dist[S], {0, S}});
while (heap.size()){
auto t = heap.top(); heap.pop();
int ver = t.y.y, distance = t.y.x;
cnt[ver] ++ ;
if (cnt[T] == K) return distance;
for (int i = h[ver]; ~i; i = ne[i]){
int j = e[i];
if (cnt[j] < K)
heap.push({distance + w[i] + dist[j], {distance + w[i], j}});
}
}
return -1;
}
int main(){
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
memset(rh, -1, sizeof rh);
for (int i = 0; i < m; i ++ ){
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(h, a, b, c);
add(rh, b, a, c);
}
scanf("%d%d%d", &S, &T, &K);
if (S == T) K ++ ;
dijkstra();
printf("%d\n", astar());
return 0;
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)