Golang使用grpc实现token拦截

发布时间 2023-11-06 16:33:21 作者: 朝阳1

上一篇简单介绍了 gRPC 的基本用法，参见《Golang简单使用grpc》。本篇在此基础上为服务端加入拦截器，对客户端携带的认证信息进行校验。

server

package main

import (
	"fmt"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"net"
	"rpc-demo/pb/pb" // 引入编译生成的包
)

// Address is the host:port the gRPC server listens on.
const Address = "127.0.0.1:50052"

// HelloService implements the Hello gRPC service from the generated pb
// package; main registers it with pb.RegisterHelloServer.
//
// It embeds the generated pb.HelloServer interface, which promotes that
// interface's methods so the type satisfies it even though only SayHello is
// written out here.
// NOTE(review): main constructs it as &HelloService{}, leaving the embedded
// interface nil, so calling any HelloServer method not overridden on this type
// would panic with a nil dereference rather than return codes.Unimplemented.
// If the generated package provides pb.UnimplementedHelloServer, embedding that
// is the usual safer choice — confirm against the generated code.
type HelloService struct {
	pb.HelloServer
}

// SayHello implements the Hello service's SayHello RPC. It replies with a
// greeting built from the name carried in the request and never fails.
func (h *HelloService) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloResponse, error) {
	greeting := fmt.Sprintf("Hello %s.", in.Name)
	return &pb.HelloResponse{Message: greeting}, nil
}

// Auth validates the credentials a client attached to the incoming request
// metadata. It returns nil only when both the "user" and "password" entries
// equal "admin". Otherwise it returns a gRPC status error with code
// Unauthenticated, which the server interceptor passes straight back to the
// client.
//
// NOTE(review): the credentials are hard-coded and compared with ordinary
// string equality. A real service would look them up and compare secrets in
// constant time (crypto/subtle).
func Auth(ctx context.Context) error {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		// Use a status error: a plain error would reach the client as
		// codes.Unknown instead of codes.Unauthenticated.
		return status.Error(codes.Unauthenticated, "missing credentials")
	}

	// md.Get lower-cases the key and returns nil when it is absent. The length
	// checks keep a present-but-empty entry from causing an index panic.
	var user, password string
	if vals := md.Get("user"); len(vals) > 0 {
		user = vals[0]
	}
	if vals := md.Get("password"); len(vals) > 0 {
		password = vals[0]
	}

	if user != "admin" || password != "admin" {
		return status.Error(codes.Unauthenticated, "客户端请求的token不合法")
	}
	return nil
}

// main starts the Hello gRPC server on Address. Every incoming RPC, unary or
// streaming, is first checked by Auth. Calls that fail the check are rejected
// with Auth's error before they reach HelloService.
func main() {
	listen, err := net.Listen("tcp", Address)
	if err != nil {
		grpclog.Fatalf("Failed to listen: %v", err)
	}

	// authInterceptor guards unary RPCs. It validates the incoming metadata and
	// only then hands the request to the real handler.
	var authInterceptor grpc.UnaryServerInterceptor = func(
		ctx context.Context,
		req interface{},
		info *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (interface{}, error) {
		if err := Auth(ctx); err != nil {
			return nil, err
		}
		return handler(ctx, req)
	}

	// authStreamInterceptor applies the same check to streaming RPCs. With only
	// a unary interceptor installed, streams would bypass authentication.
	var authStreamInterceptor grpc.StreamServerInterceptor = func(
		srv interface{},
		ss grpc.ServerStream,
		info *grpc.StreamServerInfo,
		handler grpc.StreamHandler,
	) error {
		if err := Auth(ss.Context()); err != nil {
			return err
		}
		return handler(srv, ss)
	}

	s := grpc.NewServer(
		grpc.UnaryInterceptor(authInterceptor),
		grpc.StreamInterceptor(authStreamInterceptor),
	)
	pb.RegisterHelloServer(s, &HelloService{})
	fmt.Println("Listen on " + Address)
	// Serve blocks until the server stops. A non-nil error means it could not
	// keep serving, so exit non-zero like the listen failure above.
	if err := s.Serve(listen); err != nil {
		grpclog.Fatalf("Failed to serve: %v", err)
	}
}

client

package main

import (
	"fmt"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/grpclog"
	pb "rpc-demo/pb/pb" // 引入proto包
)

// Address is the host:port of the gRPC server that the client dials.
const Address = "127.0.0.1:50052"

// PerRPCCredentials describes credentials that are attached to every RPC. It
// has the same method set as gRPC's credentials.PerRPCCredentials interface,
// which is the type grpc.WithPerRPCCredentials accepts and which
// *Authentication below implements.
// NOTE(review): nothing in the visible code refers to this local interface.
// grpc.WithPerRPCCredentials uses gRPC's own type, so this declaration appears
// to serve as documentation only.
type PerRPCCredentials interface {

	// GetRequestMetadata returns the information needed for authentication, as
	// key/value pairs to be sent with the request metadata.
	GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)

	// RequireTransportSecurity reports whether a secure (TLS) connection is
	// required. Production code normally enables it; it is disabled here only
	// to make testing easier.
	RequireTransportSecurity() bool
}
// Authentication holds the user name and password that the client sends with
// every RPC. It implements gRPC's per-RPC credentials contract.
type Authentication struct {
	User     string
	Password string
}

// GetRequestMetadata returns the "user" and "password" metadata entries to
// attach to an outgoing request. The context and URI arguments are ignored,
// and the error is always nil.
func (auth *Authentication) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
	pairs := make(map[string]string, 2)
	pairs["user"] = auth.User
	pairs["password"] = auth.Password
	return pairs, nil
}

// RequireTransportSecurity reports that these credentials do not require a
// secure connection. This lets the demo run over plaintext, at the cost of
// sending the password unencrypted.
func (auth *Authentication) RequireTransportSecurity() bool {
	const requireTLS = false
	return requireTLS
}

// main dials the Hello server at Address and attaches admin/admin credentials
// to every RPC via grpc.WithPerRPCCredentials. It then calls SayHello once and
// prints the reply, exiting through grpclog.Fatalln on any error.
func main() {
	creds := &Authentication{User: "admin", Password: "admin"}

	// The transport is plaintext (insecure credentials). Because
	// Authentication.RequireTransportSecurity returns false, gRPC allows the
	// per-RPC credentials to be sent over it. To watch the server reject the
	// call, dial without them:
	//   conn, err := grpc.Dial(Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithPerRPCCredentials(creds),
	}
	conn, err := grpc.Dial(Address, dialOpts...)
	if err != nil {
		grpclog.Fatalln(err)
	}
	defer conn.Close()

	client := pb.NewHelloClient(conn)
	// NOTE(review): context.Background() has no deadline, so this call can
	// block indefinitely if the server never answers. Consider
	// context.WithTimeout.
	reply, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "gRPC"})
	if err != nil {
		grpclog.Fatalln(err)
	}

	fmt.Println(reply.Message)
}

如果客户端传入的用户名和密码不是 admin/admin，服务端的拦截器就会拒绝该请求，并向客户端返回 Unauthenticated（未认证）错误。