众所周知,ThreadLocal 可以实现线程之间的变量隔离。在登录时可以把当前用户放入 ThreadLocal,之后再从中获取当前登录用户。下面是一个使用示例。
用户实体类(使用的是 MyBatis-Plus 注解 `@TableName`、`@TableLogic`,而非 JPA):
/**
 * User entity mapped to the {@code sys_user} table (MyBatis-Plus {@code @TableName}).
 *
 * <p>Lombok {@code @Data} generates getters, setters, {@code toString}, {@code equals} and
 * {@code hashCode}; {@code callSuper = false} means fields inherited from {@code SuperEntity}
 * are not considered by {@code equals}/{@code hashCode}.
 */
@Data
@EqualsAndHashCode(callSuper = false)
@TableName("sys_user")
public class SysUser extends SuperEntity {
private static final long serialVersionUID = -5886012896705137070L;
private String username;
// NOTE(review): this field is part of the Lombok-generated toString(); avoid logging SysUser
// instances. Presumably it stores a hash rather than plain text -- confirm.
private String password;
private String nickname;
private String headImgUrl;
private String mobile;
// NOTE(review): the meaning of each numeric value is not visible here -- confirm.
private Integer sex;
private Boolean enabled;
private String type;
private String openId;
// Logical-delete flag (MyBatis-Plus @TableLogic). For a primitive boolean named "isDel" Lombok
// generates isDel()/setDel(), i.e. the bean property name is "del", not "isDel".
@TableLogic
private boolean isDel;
}
ThreadLocal 工具类:
/**
 * Holds the currently logged-in {@link SysUser}, one value per thread.
 *
 * <p>The value stays bound to the thread until {@link #clear()} is called, so callers that
 * set it on a pooled thread are responsible for clearing it afterwards.
 */
public class LoginUserContextHolder {
    /** Per-thread storage for the current user. */
    private static final ThreadLocal<SysUser> CURRENT_USER = new TransmittableThreadLocal<>();

    /**
     * Returns the user bound to the current thread.
     *
     * @return the bound user, or {@code null} if none has been set
     */
    public static SysUser getUser() {
        return CURRENT_USER.get();
    }

    /**
     * Binds {@code user} to the current thread, replacing any previous value.
     *
     * @param user the user to bind
     */
    public static void setUser(SysUser user) {
        CURRENT_USER.set(user);
    }

    /** Removes whatever user is bound to the current thread. */
    public static void clear() {
        CURRENT_USER.remove();
    }
}
将用户放入 ThreadLocal 中:
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.common.exceptions.UnapprovedClientAuthenticationException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.TokenStore;
import javax.servlet.http.HttpServletRequest;
import java.nio.charset.StandardCharsets;
import java.util.*;
public class AuthUtils {
/**
 * Validates the access token carried by {@code request} and binds the resulting user to the
 * current thread.
 *
 * @param request incoming HTTP request; the token is located via {@code extractToken}
 * @return the authenticated user
 * @throws InvalidTokenException if the token is unknown, expired, or has no stored
 *     authentication
 */
public static SysUser checkAccessToken(HttpServletRequest request) {
    return checkAccessToken(extractToken(request));
}
/**
 * Validates a raw access-token value against the {@link TokenStore} bean and binds the
 * resulting user to the current thread via {@code setContext}.
 *
 * <p>An expired token is removed from the store before the exception is thrown.
 *
 * @param accessTokenValue raw token value; {@code null} or blank is rejected up front
 * @return the authenticated user
 * @throws InvalidTokenException if the token is missing, unknown, expired, or has no stored
 *     authentication
 */
public static SysUser checkAccessToken(String accessTokenValue) {
    // Reject a missing token before touching the store: some TokenStore implementations
    // (e.g. InMemoryTokenStore, backed by a ConcurrentHashMap) throw NullPointerException on a
    // null key instead of returning null, which would bypass the InvalidTokenException contract.
    if (accessTokenValue == null || accessTokenValue.trim().isEmpty()) {
        throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
    }
    TokenStore tokenStore = SpringUtil.getBean(TokenStore.class);
    OAuth2AccessToken accessToken = tokenStore.readAccessToken(accessTokenValue);
    if (accessToken == null || accessToken.getValue() == null) {
        throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
    }
    if (accessToken.isExpired()) {
        // Evict the stale token so later lookups fail fast.
        tokenStore.removeAccessToken(accessToken);
        throw new InvalidTokenException("Access token expired: " + accessTokenValue);
    }
    OAuth2Authentication authentication = tokenStore.readAuthentication(accessToken);
    if (authentication == null) {
        throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
    }
    return setContext(authentication);
}
/**
 * Publishes {@code authentication} to Spring Security's {@link SecurityContextHolder}, then
 * binds the user derived from it to the per-thread {@link LoginUserContextHolder}.
 *
 * <p>Neither holder is cleared by this method.
 *
 * @param authentication the authenticated principal to publish
 * @return the user derived from {@code authentication}
 */
public static SysUser setContext(Authentication authentication) {
    SecurityContextHolder.getContext().setAuthentication(authentication);
    SysUser currentUser = getUser(authentication);
    LoginUserContextHolder.setUser(currentUser);
    return currentUser;
}
/**
 * Extracts {@code clientId:clientSecret} from the HTTP Basic {@code Authorization} header.
 *
 * <p>The scheme prefix is matched case-insensitively, because RFC 7235 defines
 * authentication schemes as case-insensitive ("basic ..." is as valid as "Basic ...").
 *
 * @param request incoming HTTP request
 * @return a two-element array {@code [clientId, clientSecret]}
 * @throws UnapprovedClientAuthenticationException if the header is missing or not Basic
 * @throws RuntimeException if the Basic credentials cannot be decoded (see
 *     {@code extractHeaderClient})
 */
public static String[] extractClient(HttpServletRequest request) {
    String header = request.getHeader("Authorization");
    // regionMatches returns false (rather than throwing) when the header is shorter than the prefix.
    if (header == null || !header.regionMatches(true, 0, BASIC_, 0, BASIC_.length())) {
        throw new UnapprovedClientAuthenticationException("请求头中client信息为空");
    }
    return extractHeaderClient(header);
}
/**
 * Decodes the credentials of an HTTP Basic {@code Authorization} header value into
 * {@code [clientId, clientSecret]}.
 *
 * <p>Per RFC 7617 the user-id cannot contain a colon but the password may. The decoded string
 * is therefore split on the first colon only, so the secret may contain {@code ':'} and may
 * be empty.
 *
 * @param header the full header value, including the Basic scheme prefix
 * @return a two-element array {@code [clientId, clientSecret]}
 * @throws IllegalArgumentException if the header is {@code null} or shorter than the scheme
 *     prefix, the credentials are not valid Base64, or the decoded value contains no colon
 */
public static String[] extractHeaderClient(String header) {
    if (header == null || header.length() < BASIC_.length()) {
        throw new IllegalArgumentException("Invalid basic authentication token");
    }
    byte[] base64Client = header.substring(BASIC_.length()).getBytes(StandardCharsets.UTF_8);
    byte[] decoded;
    try {
        decoded = Base64.getDecoder().decode(base64Client);
    } catch (IllegalArgumentException e) {
        // Keep the decoder's cause but report it with the same message as other malformed input.
        throw new IllegalArgumentException("Invalid basic authentication token", e);
    }
    String clientStr = new String(decoded, StandardCharsets.UTF_8);
    // limit 2: split on the first ':' only, and keep a trailing empty secret.
    String[] clientArr = clientStr.split(":", 2);
    if (clientArr.length != 2) {
        throw new IllegalArgumentException("Invalid basic authentication token");
    }
    return clientArr;
}
获取当前登录人:
/**
 * Returns the current login user's name, read from {@code LoginUserContextHolder}.
 *
 * @return {@code "auth2:"} followed by the bound user's username
 * @throws IllegalStateException if no user is bound to the current thread (e.g. the token
 *     check did not run on this thread)
 */
@GetMapping("/test/auth2")
public String auth() {
    // Fail with a descriptive error instead of an anonymous NullPointerException.
    return java.util.Optional.ofNullable(LoginUserContextHolder.getUser())
            .map(user -> "auth2:" + user.getUsername())
            .orElseThrow(() -> new IllegalStateException("No login user bound to the current thread"));
}
WebSocket 鉴权:
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.websocket.server.ServerEndpointConfig;
public class WcAuthConfigurator extends ServerEndpointConfig.Configurator {
//checkOrigin:校验token
@Override
public boolean checkOrigin(String originHeaderValue) {
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
try {
//检查token有效性
AuthUtils.checkAccessToken(servletRequestAttributes.getRequest());
} catch (Exception e) {
log.error("WebSocket-auth-error", e);
return false;
}
return super.checkOrigin(originHeaderValue);
}
}
AuthUtils.checkAccessToken 方法内部最终调用了 ThreadLocal 的 set 方法(见 setContext 中的 LoginUserContextHolder.setUser)。
使用 WcAuthConfigurator:
@ServerEndpoint 注解的作用是把当前类定义为一个 WebSocket 服务器端。注解的值是用户连接时访问的终端 URL,客户端通过这个 URL 连接到 WebSocket 服务器端。这里把 configurator 属性设置为刚才编写的配置类 WcAuthConfigurator:
/**
 * Test WebSocket endpoint at {@code /websocket/test}. Its handshakes go through
 * {@link WcAuthConfigurator}.
 */
@Slf4j
@Component
@ServerEndpoint(value = "/websocket/test", configurator = WcAuthConfigurator.class)
public class TestWebSocketController {

    /** Acknowledgement text sent to the client as soon as the connection opens. */
    private static final String OPEN_ACK = "TestWebSocketController-ok";

    /**
     * Sends {@link #OPEN_ACK} to the newly connected client.
     *
     * @param session the opened WebSocket session
     * @throws IOException if the message cannot be sent
     */
    @OnOpen
    public void onOpen(Session session) throws IOException {
        session.getBasicRemote().sendText(OPEN_ACK);
    }
}