Go and GraphQL: Building a Flexible API
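Have you used REST APIs? Of course you have. But have you ever hit this scenario: the frontend needs a user's profile, that user's orders, and the product details of every order. With REST you might call /users/123, then /users/123/orders, and then /orders/456/products once for each order. This is the classic N+1 request problem: one request for the list, plus one more request per item in it.
GraphQL was designed to solve exactly this kind of problem. It lets the client specify precisely which data it needs and fetch all of it in a single request. In this article we'll build a full-featured GraphQL API with Go and gqlgen.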
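GraphQL Basics
Before writing any code, let's quickly go over GraphQL's core concepts:
- Schema: defines the shape of the data and the operations that can be executed
- Query: reads data (roughly REST's GET)
- Mutation: modifies data (roughly REST's POST/PUT/DELETE)
- Subscription: pushes data to the client in real time (usually over WebSocket)
- Resolver: the function that actually produces the value of a query, a mutation or a field
- Type: data type definitions (objects, scalars, enums, and so on)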
Project Setup and gqlgen Configuration
First, create the project and initialize a Go module:
mkdir go-graphql-demo
cd go-graphql-demo
go mod init github.com/yourusername/go-graphql-demo
Install gqlgen, a widely used schema-first GraphQL code generator for Go:
go get github.com/99designs/gqlgen
go install github.com/99designs/gqlgen@latest
Create the gqlgen configuration file, gqlgen.yml:
# gqlgen.yml
schema:
  - graph/*.graphqls
exec:
  filename: graph/generated.go
  package: graph
model:
  filename: graph/model/models_gen.go
  package: model
resolver:
  layout: follow-schema
  dir: graph
  package: graph
  filename_template: "{name}.resolvers.go"
autobind:
  - github.com/yourusername/go-graphql-demo/graph/model
models:
  ID:
    model:
      - github.com/99designs/gqlgen/graphql.ID
      - github.com/99designs/gqlgen/graphql.Int
      - github.com/99designs/gqlgen/graphql.Int64
      - github.com/99designs/gqlgen/graphql.Int32
  Int:
    model:
      - github.com/99designs/gqlgen/graphql.Int
      - github.com/99designs/gqlgen/graphql.Int64
      - github.com/99designs/gqlgen/graphql.Int32
Defining the GraphQL Schema
Create a schema.graphqls file in the graph directory:
# graph/schema.graphqls
scalar Time
scalar Upload
type User {
id: ID!
username: String!
email: String!
createdAt: Time!
posts: [Post!]!
followers: [User!]!
following: [User!]!
}
type Post {
id: ID!
title: String!
content: String!
author: User!
tags: [String!]!
likes: Int!
comments: [Comment!]!
createdAt: Time!
updatedAt: Time!
}
type Comment {
id: ID!
content: String!
author: User!
post: Post!
createdAt: Time!
}
type AuthPayload {
token: String!
user: User!
}
input CreateUserInput {
username: String!
email: String!
password: String!
}
input CreatePostInput {
title: String!
content: String!
tags: [String!]
}
input UpdatePostInput {
title: String
content: String
tags: [String!]
}
input CreateCommentInput {
postId: ID!
content: String!
}
type PostConnection {
edges: [PostEdge!]!
pageInfo: PageInfo!
totalCount: Int!
}
type PostEdge {
node: Post!
cursor: String!
}
type PageInfo {
hasNextPage: Boolean!
hasPreviousPage: Boolean!
startCursor: String
endCursor: String
}
type Query {
user(id: ID!): User
users: [User!]!
post(id: ID!): Post
posts(first: Int, after: String, last: Int, before: String): PostConnection!
me: User
}
type Mutation {
register(input: CreateUserInput!): AuthPayload!
login(email: String!, password: String!): AuthPayload!
createPost(input: CreatePostInput!): Post!
updatePost(id: ID!, input: UpdatePostInput!): Post!
deletePost(id: ID!): Boolean!
likePost(id: ID!): Post!
createComment(input: CreateCommentInput!): Comment!
followUser(userId: ID!): User!
unfollowUser(userId: ID!): User!
}
type Subscription {
postCreated: Post!
commentAdded(postId: ID!): Comment!
postLiked(postId: ID!): Post!
}
Run code generation:
go run github.com/99designs/gqlgen generate
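This generates graph/generated.go (the execution engine), graph/model/models_gen.go (models for types that are not bound to hand-written structs) and the resolver stubs. With layout: follow-schema, gqlgen writes one {name}.resolvers.go per schema file, so a single schema.graphqls yields a single schema.resolvers.go. The auth/post/user/subscription resolver files used below only line up with gqlgen's output if the schema is split into matching .graphqls files (for example with extend type Query and extend type Mutation); otherwise keep those methods in schema.resolvers.go.
Implementing the Data Model and Storage
Before implementing the resolvers, let's define the data models and a simple in-memory store. Because gqlgen.yml autobinds the graph/model package, gqlgen reuses these hand-written structs instead of generating its own; schema fields with no matching struct field (such as User.posts or Post.author) get dedicated field resolvers instead: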
// graph/model/models.go
package model
import "time"
type User struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"-"` // hidden from encoding/json; GraphQL never returns it because the schema has no password field
CreatedAt time.Time `json:"createdAt"`
}
type Post struct {
ID string `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
AuthorID string `json:"authorId"`
Tags []string `json:"tags"`
Likes int `json:"likes"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type Comment struct {
ID string `json:"id"`
Content string `json:"content"`
AuthorID string `json:"authorId"`
PostID string `json:"postId"`
CreatedAt time.Time `json:"createdAt"`
}
Next, the storage layer:
// internal/store/store.go
package store
import (
"context"
"errors"
"sort"
"sync"
"time"
"github.com/yourusername/go-graphql-demo/graph/model"
)
var (
ErrNotFound = errors.New("not found")
)
type Store struct {
mu sync.RWMutex
users map[string]*model.User
posts map[string]*model.Post
comments map[string]*model.Comment
// Relationship indexes
followers map[string][]string // userID -> []followerID
following map[string][]string // userID -> []followingID
postLikes map[string][]string // postID -> []userID
}
func New() *Store {
return &Store{
users: make(map[string]*model.User),
posts: make(map[string]*model.Post),
comments: make(map[string]*model.Comment),
followers: make(map[string][]string),
following: make(map[string][]string),
postLikes: make(map[string][]string),
}
}
func (s *Store) CreateUser(ctx context.Context, user *model.User) error {
s.mu.Lock()
defer s.mu.Unlock()
// Reject duplicate email addresses
for _, u := range s.users {
if u.Email == user.Email {
return errors.New("email already exists")
}
}
s.users[user.ID] = user
return nil
}
func (s *Store) GetUserByID(ctx context.Context, id string) (*model.User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
user, ok := s.users[id]
if !ok {
return nil, ErrNotFound
}
return user, nil
}
func (s *Store) GetUserByEmail(ctx context.Context, email string) (*model.User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, user := range s.users {
if user.Email == email {
return user, nil
}
}
return nil, ErrNotFound
}
func (s *Store) CreatePost(ctx context.Context, post *model.Post) error {
s.mu.Lock()
defer s.mu.Unlock()
s.posts[post.ID] = post
return nil
}
func (s *Store) GetPostByID(ctx context.Context, id string) (*model.Post, error) {
s.mu.RLock()
defer s.mu.RUnlock()
post, ok := s.posts[id]
if !ok {
return nil, ErrNotFound
}
return post, nil
}
func (s *Store) GetPostsByAuthor(ctx context.Context, authorID string) ([]*model.Post, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var posts []*model.Post
for _, post := range s.posts {
if post.AuthorID == authorID {
posts = append(posts, post)
}
}
return posts, nil
}
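func (s *Store) ListPosts(ctx context.Context, limit int, cursor *string) ([]*model.Post, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// Simplified: collect every post in memory; a real project should paginate in the database
posts := make([]*model.Post, 0, len(s.posts))
for _, post := range s.posts {
posts = append(posts, post)
}
// Newest first, with a deterministic order so cursors stay valid
sortPostsByCreatedAt(posts)
// The cursor is a post ID: skip everything up to and including that post.
// An unknown cursor falls back to the first page.
if cursor != nil {
for i, post := range posts {
if post.ID == *cursor {
posts = posts[i+1:]
break
}
}
}
if limit > 0 && len(posts) > limit {
posts = posts[:limit]
}
return posts, nil
}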
func (s *Store) UpdatePost(ctx context.Context, id string, input *model.UpdatePostInput) (*model.Post, error) {
s.mu.Lock()
defer s.mu.Unlock()
post, ok := s.posts[id]
if !ok {
return nil, ErrNotFound
}
if input.Title != nil {
post.Title = *input.Title
}
if input.Content != nil {
post.Content = *input.Content
}
if input.Tags != nil {
post.Tags = input.Tags
}
post.UpdatedAt = time.Now()
return post, nil
}
func (s *Store) DeletePost(ctx context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.posts[id]; !ok {
return ErrNotFound
}
delete(s.posts, id)
return nil
}
func (s *Store) CreateComment(ctx context.Context, comment *model.Comment) error {
s.mu.Lock()
defer s.mu.Unlock()
s.comments[comment.ID] = comment
return nil
}
func (s *Store) GetCommentsByPost(ctx context.Context, postID string) ([]*model.Comment, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var comments []*model.Comment
for _, comment := range s.comments {
if comment.PostID == postID {
comments = append(comments, comment)
}
}
return comments, nil
}
func (s *Store) LikePost(ctx context.Context, postID, userID string) error {
s.mu.Lock()
defer s.mu.Unlock()
post, ok := s.posts[postID]
if !ok {
return ErrNotFound
}
// Reject duplicate likes
for _, uid := range s.postLikes[postID] {
if uid == userID {
return errors.New("already liked")
}
}
s.postLikes[postID] = append(s.postLikes[postID], userID)
post.Likes++
return nil
}
func (s *Store) FollowUser(ctx context.Context, followerID, followingID string) error {
s.mu.Lock()
defer s.mu.Unlock()
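// Users cannot follow themselves
if followerID == followingID {
return errors.New("cannot follow yourself")
}
// The target user must exist
if _, ok := s.users[followingID]; !ok {
return ErrNotFound
}
// Reject duplicate follows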
for _, uid := range s.following[followerID] {
if uid == followingID {
return errors.New("already following")
}
}
s.following[followerID] = append(s.following[followerID], followingID)
s.followers[followingID] = append(s.followers[followingID], followerID)
return nil
}
func (s *Store) UnfollowUser(ctx context.Context, followerID, followingID string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Remove the "following" relationship
following := s.following[followerID]
for i, uid := range following {
if uid == followingID {
s.following[followerID] = append(following[:i], following[i+1:]...)
break
}
}
// Remove the "follower" relationship
followers := s.followers[followingID]
for i, uid := range followers {
if uid == followerID {
s.followers[followingID] = append(followers[:i], followers[i+1:]...)
break
}
}
return nil
}
func (s *Store) GetFollowers(ctx context.Context, userID string) ([]*model.User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var users []*model.User
for _, followerID := range s.followers[userID] {
if user, ok := s.users[followerID]; ok {
users = append(users, user)
}
}
return users, nil
}
func (s *Store) GetFollowing(ctx context.Context, userID string) ([]*model.User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var users []*model.User
for _, followingID := range s.following[userID] {
if user, ok := s.users[followingID]; ok {
users = append(users, user)
}
}
return users, nil
}
func sortPostsByCreatedAt(posts []*model.Post) {
// Newest first; ties are broken by ID so the order (and therefore the cursor) is deterministic
sort.Slice(posts, func(i, j int) bool {
if !posts[i].CreatedAt.Equal(posts[j].CreatedAt) {
return posts[i].CreatedAt.After(posts[j].CreatedAt)
}
return posts[i].ID > posts[j].ID
})
}
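The follower/following slices above turn every duplicate check and every unfollow into a linear scan plus slice surgery. A hypothetical alternative, sketched below, stores each relationship as a set (map[string]struct{}), so follow, unfollow and membership checks are O(1). The Relations type is illustrative only and is not part of the store used in the rest of the article:

```go
package main

import (
	"fmt"
	"sort"
	"sync"
)

// Relations is a hypothetical set-based follow graph.
type Relations struct {
	mu        sync.RWMutex
	following map[string]map[string]struct{} // follower -> set of followed users
	followers map[string]map[string]struct{} // followed user -> set of followers
}

func NewRelations() *Relations {
	return &Relations{
		following: make(map[string]map[string]struct{}),
		followers: make(map[string]map[string]struct{}),
	}
}

// addEdge inserts to into the set stored under from, creating the set on first use.
func addEdge(m map[string]map[string]struct{}, from, to string) {
	if m[from] == nil {
		m[from] = make(map[string]struct{})
	}
	m[from][to] = struct{}{}
}

// Follow records that follower follows target and reports false if it already did.
func (r *Relations) Follow(follower, target string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, ok := r.following[follower][target]; ok { // reading a nil inner map is safe
		return false
	}
	addEdge(r.following, follower, target)
	addEdge(r.followers, target, follower)
	return true
}

// Unfollow removes the edge in both directions; deleting a missing key is a no-op.
func (r *Relations) Unfollow(follower, target string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.following[follower], target)
	delete(r.followers[target], follower)
}

// Followers returns the follower IDs of userID, sorted because map order is random.
func (r *Relations) Followers(userID string) []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	ids := make([]string, 0, len(r.followers[userID]))
	for id := range r.followers[userID] {
		ids = append(ids, id)
	}
	sort.Strings(ids)
	return ids
}

func main() {
	rel := NewRelations()
	rel.Follow("bob", "alice")
	rel.Follow("carol", "alice")
	fmt.Println(rel.Followers("alice"))
	rel.Unfollow("bob", "alice")
	fmt.Println(rel.Followers("alice"))
}
```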
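Implementing the Resolvers
Now for the core resolver logic, starting with the authentication mutations. All resolvers reach their dependencies through the Resolver struct in graph/resolver.go; the code in this article assumes it has two fields, Store *store.Store and PubSub *Subscription (the event hub defined in the Subscription section). The second field must not be named Subscription: gqlgen generates a Subscription() method on Resolver, and Go does not allow a field and a method with the same name on one type.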
// graph/auth.resolvers.go
package graph
import (
"context"
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/yourusername/go-graphql-demo/graph/model"
)
const jwtSecret = "your-secret-key-change-in-production"
type Claims struct {
UserID string `json:"user_id"`
jwt.RegisteredClaims
}
func generateToken(userID string) (string, error) {
claims := &Claims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(72 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(jwtSecret))
}
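// ParseToken verifies a token's signature and expiry and returns its claims.
// It is exported because the HTTP auth middleware needs it.
func ParseToken(tokenString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) // only accept HS256 (prevents alg-confusion attacks)
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, errors.New("invalid token")
}
return claims, nil
}
// UserIDFromToken adapts ParseToken to the func(string) (string, error) shape expected by middleware.AuthMiddleware
func UserIDFromToken(tokenString string) (string, error) {
claims, err := ParseToken(tokenString)
if err != nil {
return "", err
}
return claims.UserID, nil
}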
func hashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
func checkPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// Register is the resolver for the register mutation.
func (r *mutationResolver) Register(ctx context.Context, input model.CreateUserInput) (*model.AuthPayload, error) {
hashedPassword, err := hashPassword(input.Password)
if err != nil {
return nil, err
}
user := &model.User{
ID: uuid.New().String(),
Username: input.Username,
Email: input.Email,
Password: hashedPassword,
CreatedAt: time.Now(),
}
if err := r.Store.CreateUser(ctx, user); err != nil {
return nil, err
}
token, err := generateToken(user.ID)
if err != nil {
return nil, err
}
return &model.AuthPayload{
Token: token,
User: user,
}, nil
}
// Login is the resolver for the login mutation.
func (r *mutationResolver) Login(ctx context.Context, email string, password string) (*model.AuthPayload, error) {
user, err := r.Store.GetUserByEmail(ctx, email)
if err != nil {
return nil, errors.New("invalid credentials")
}
if !checkPassword(password, user.Password) {
return nil, errors.New("invalid credentials")
}
token, err := generateToken(user.ID)
if err != nil {
return nil, err
}
return &model.AuthPayload{
Token: token,
User: user,
}, nil
}
Next, the CRUD resolvers for posts:
// graph/post.resolvers.go
package graph
import (
"context"
"errors"
"time"
"github.com/google/uuid"
"github.com/yourusername/go-graphql-demo/graph/model"
)
// CreatePost is the resolver for the createPost field.
func (r *mutationResolver) CreatePost(ctx context.Context, input model.CreatePostInput) (*model.Post, error) {
userID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
post := &model.Post{
ID: uuid.New().String(),
Title: input.Title,
Content: input.Content,
AuthorID: userID,
Tags: input.Tags,
Likes: 0,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := r.Store.CreatePost(ctx, post); err != nil {
return nil, err
}
// Publish the subscription event
r.PubSub.PostCreated(ctx, post)
return post, nil
}
// UpdatePost is the resolver for the updatePost field.
func (r *mutationResolver) UpdatePost(ctx context.Context, id string, input model.UpdatePostInput) (*model.Post, error) {
userID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
post, err := r.Store.GetPostByID(ctx, id)
if err != nil {
return nil, err
}
// Authorization check: only the author may edit the post
if post.AuthorID != userID {
return nil, errors.New("forbidden")
}
return r.Store.UpdatePost(ctx, id, &input)
}
// DeletePost is the resolver for the deletePost field.
func (r *mutationResolver) DeletePost(ctx context.Context, id string) (bool, error) {
userID, err := getUserIDFromContext(ctx)
if err != nil {
return false, errors.New("unauthorized")
}
post, err := r.Store.GetPostByID(ctx, id)
if err != nil {
return false, err
}
if post.AuthorID != userID {
return false, errors.New("forbidden")
}
if err := r.Store.DeletePost(ctx, id); err != nil {
return false, err
}
return true, nil
}
// LikePost is the resolver for the likePost field.
func (r *mutationResolver) LikePost(ctx context.Context, id string) (*model.Post, error) {
userID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
if err := r.Store.LikePost(ctx, id, userID); err != nil {
return nil, err
}
post, err := r.Store.GetPostByID(ctx, id)
if err != nil {
return nil, err
}
// Publish the subscription event
r.PubSub.PostLiked(ctx, id, post)
return post, nil
}
// Posts is the resolver for the posts field.
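func (r *queryResolver) Posts(ctx context.Context, first *int, after *string, last *int, before *string) (*model.PostConnection, error) {
// Backward pagination (last/before) is not implemented in this example
limit := 10
if first != nil && *first > 0 {
limit = *first
}
if limit > 100 {
limit = 100 // cap the page size so a client cannot request everything at once
}
// Fetch one extra post to find out whether a next page exists
posts, err := r.Store.ListPosts(ctx, limit+1, after)
if err != nil {
return nil, err
}
hasNextPage := len(posts) > limit
if hasNextPage {
posts = posts[:limit]
}
edges := make([]*model.PostEdge, 0, len(posts))
for _, post := range posts {
edges = append(edges, &model.PostEdge{
Node: post,
// The post ID doubles as the cursor
Cursor: post.ID,
})
}
var startCursor, endCursor *string
if len(edges) > 0 {
startCursor = &edges[0].Cursor
endCursor = &edges[len(edges)-1].Cursor
}
return &model.PostConnection{
Edges: edges,
PageInfo: &model.PageInfo{
HasNextPage: hasNextPage,
// Approximation: any page requested with a cursor is assumed to have a predecessor
HasPreviousPage: after != nil,
StartCursor: startCursor,
EndCursor: endCursor,
},
// Note: this is only the size of the current page; a real totalCount needs a separate count query
TotalCount: len(posts),
}, nil
}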
// Post is the resolver for the post field.
func (r *queryResolver) Post(ctx context.Context, id string) (*model.Post, error) {
return r.Store.GetPostByID(ctx, id)
}
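The Posts resolver relies on two tricks: ask for one item more than the page size to learn whether another page exists, and use the last item's ID as the after cursor for the next request. The same logic on a plain, already-sorted slice of IDs, as a self-contained sketch (paginate is a hypothetical helper, not part of gqlgen or the store):

```go
package main

import "fmt"

// page is one slice of results plus what a PageInfo needs.
type page struct {
	Items       []string
	EndCursor   string
	HasNextPage bool
}

// paginate returns up to first items that come after the cursor in sorted.
// An empty or unknown cursor starts from the beginning.
func paginate(sorted []string, first int, after string) page {
	if first <= 0 {
		first = 10 // same default as the Posts resolver
	}
	start := 0
	if after != "" {
		for i, id := range sorted {
			if id == after {
				start = i + 1
				break
			}
		}
	}
	rest := sorted[start:]
	// Having more than first items left is the in-memory equivalent of fetching limit+1 rows
	hasNext := len(rest) > first
	if hasNext {
		rest = rest[:first]
	}
	p := page{Items: rest, HasNextPage: hasNext}
	if len(rest) > 0 {
		p.EndCursor = rest[len(rest)-1]
	}
	return p
}

func main() {
	ids := []string{"p5", "p4", "p3", "p2", "p1"} // newest first, as the store sorts them
	cursor := ""
	for {
		p := paginate(ids, 2, cursor)
		fmt.Println(p.Items, p.HasNextPage)
		if !p.HasNextPage {
			break
		}
		cursor = p.EndCursor
	}
}
```

Production APIs often base64-encode such cursors so that clients treat them as opaque rather than relying on them being IDs.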
Resolvers for users and comments:
// graph/user.resolvers.go
package graph
import (
"context"
"errors"
"time"
"github.com/google/uuid"
"github.com/yourusername/go-graphql-demo/graph/model"
)
// Me is the resolver for the me field.
func (r *queryResolver) Me(ctx context.Context) (*model.User, error) {
userID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
return r.Store.GetUserByID(ctx, userID)
}
// User is the resolver for the user field.
func (r *queryResolver) User(ctx context.Context, id string) (*model.User, error) {
return r.Store.GetUserByID(ctx, id)
}
// Users is the resolver for the users field.
func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
// Not implemented in this example; a real project would return a paginated list here
return nil, errors.New("not implemented")
}
// FollowUser is the resolver for the followUser field.
func (r *mutationResolver) FollowUser(ctx context.Context, userID string) (*model.User, error) {
followerID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
if err := r.Store.FollowUser(ctx, followerID, userID); err != nil {
return nil, err
}
return r.Store.GetUserByID(ctx, userID)
}
// UnfollowUser is the resolver for the unfollowUser field.
func (r *mutationResolver) UnfollowUser(ctx context.Context, userID string) (*model.User, error) {
followerID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
if err := r.Store.UnfollowUser(ctx, followerID, userID); err != nil {
return nil, err
}
return r.Store.GetUserByID(ctx, userID)
}
// CreateComment is the resolver for the createComment field.
func (r *mutationResolver) CreateComment(ctx context.Context, input model.CreateCommentInput) (*model.Comment, error) {
userID, err := getUserIDFromContext(ctx)
if err != nil {
return nil, errors.New("unauthorized")
}
// Make sure the post exists
_, err = r.Store.GetPostByID(ctx, input.PostID)
if err != nil {
return nil, errors.New("post not found")
}
comment := &model.Comment{
ID: uuid.New().String(),
Content: input.Content,
AuthorID: userID,
PostID: input.PostID,
CreatedAt: time.Now(),
}
if err := r.Store.CreateComment(ctx, comment); err != nil {
return nil, err
}
// Publish the subscription event
r.PubSub.CommentAdded(ctx, input.PostID, comment)
return comment, nil
}
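Field-Level Resolvers and DataLoader
Next come the resolvers for nested fields, and with them the classic N+1 query problem: if you query 10 posts together with their authors, a naive implementation runs one query for the posts and then 10 more, one per author.
We use DataLoader (github.com/graph-gophers/dataloader) to batch these loads: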
// internal/dataloader/dataloader.go
package dataloader
import (
"context"
"net/http"
"time"
"github.com/graph-gophers/dataloader"
"github.com/yourusername/go-graphql-demo/graph/model"
"github.com/yourusername/go-graphql-demo/internal/store"
)
type ctxKey string
const (
userLoaderKey = ctxKey("userLoader")
postLoaderKey = ctxKey("postLoader")
postsLoaderKey = ctxKey("postsLoader")
)
// Loaders groups the data loaders used by the field resolvers
type Loaders struct {
UserLoader *dataloader.Loader
PostLoader *dataloader.Loader
PostsLoader *dataloader.Loader
}
// NewLoaders creates a fresh set of loaders (create one set per request)
func NewLoaders(store *store.Store) *Loaders {
return &Loaders{
UserLoader: newUserLoader(store),
PostLoader: newPostLoader(store),
PostsLoader: newPostsLoader(store),
}
}
// Note: these batch functions still call the store once per key. DataLoader already
// deduplicates and caches keys within a request; with a real database, replace each loop
// with one bulk query (WHERE id IN (...)) and return the results in key order.
func newUserLoader(store *store.Store) *dataloader.Loader {
return dataloader.NewBatchedLoader(func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
results := make([]*dataloader.Result, 0, len(keys))
for _, key := range keys {
user, err := store.GetUserByID(ctx, key.String())
results = append(results, &dataloader.Result{Data: user, Error: err})
}
return results
}, dataloader.WithBatchCapacity(100), dataloader.WithWait(5*time.Millisecond))
}
func newPostLoader(store *store.Store) *dataloader.Loader {
return dataloader.NewBatchedLoader(func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
results := make([]*dataloader.Result, 0, len(keys))
for _, key := range keys {
post, err := store.GetPostByID(ctx, key.String())
results = append(results, &dataloader.Result{Data: post, Error: err})
}
return results
}, dataloader.WithBatchCapacity(100), dataloader.WithWait(5*time.Millisecond))
}
func newPostsLoader(store *store.Store) *dataloader.Loader {
return dataloader.NewBatchedLoader(func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
results := make([]*dataloader.Result, 0, len(keys))
for _, key := range keys {
posts, err := store.GetPostsByAuthor(ctx, key.String())
results = append(results, &dataloader.Result{Data: posts, Error: err})
}
return results
}, dataloader.WithBatchCapacity(100), dataloader.WithWait(5*time.Millisecond))
}
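// Middleware injects the loaders into the request context. Pass a new Loaders for every
// request: loaders cache their results, so a shared set would serve stale data across requests.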
func Middleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), userLoaderKey, loaders.UserLoader)
ctx = context.WithValue(ctx, postLoaderKey, loaders.PostLoader)
ctx = context.WithValue(ctx, postsLoaderKey, loaders.PostsLoader)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// StringKey wraps the library's StringKey so resolvers that import this package can build keys
func StringKey(s string) dataloader.Key {
return dataloader.StringKey(s)
}
// GetUserLoader returns the UserLoader from the context
func GetUserLoader(ctx context.Context) *dataloader.Loader {
return ctx.Value(userLoaderKey).(*dataloader.Loader)
}
// GetPostLoader returns the PostLoader from the context
func GetPostLoader(ctx context.Context) *dataloader.Loader {
return ctx.Value(postLoaderKey).(*dataloader.Loader)
}
// GetPostsLoader returns the PostsLoader from the context
func GetPostsLoader(ctx context.Context) *dataloader.Loader {
return ctx.Value(postsLoaderKey).(*dataloader.Loader)
}
Now implement the field-level resolvers on top of the DataLoaders:
// graph/schema.resolvers.go
package graph
import (
"context"
"github.com/yourusername/go-graphql-demo/graph/model"
"github.com/yourusername/go-graphql-demo/internal/dataloader"
)
// Posts is the resolver for the posts field of User.
func (r *userResolver) Posts(ctx context.Context, obj *model.User) ([]*model.Post, error) {
loader := dataloader.GetPostsLoader(ctx)
thunk := loader.Load(ctx, dataloader.StringKey(obj.ID))
result, err := thunk()
if err != nil {
return nil, err
}
return result.([]*model.Post), nil
}
// Followers is the resolver for the followers field.
func (r *userResolver) Followers(ctx context.Context, obj *model.User) ([]*model.User, error) {
return r.Store.GetFollowers(ctx, obj.ID)
}
// Following is the resolver for the following field.
func (r *userResolver) Following(ctx context.Context, obj *model.User) ([]*model.User, error) {
return r.Store.GetFollowing(ctx, obj.ID)
}
// Author is the resolver for the author field of Post.
func (r *postResolver) Author(ctx context.Context, obj *model.Post) (*model.User, error) {
loader := dataloader.GetUserLoader(ctx)
thunk := loader.Load(ctx, dataloader.StringKey(obj.AuthorID))
result, err := thunk()
if err != nil {
return nil, err
}
return result.(*model.User), nil
}
// Comments is the resolver for the comments field.
func (r *postResolver) Comments(ctx context.Context, obj *model.Post) ([]*model.Comment, error) {
return r.Store.GetCommentsByPost(ctx, obj.ID)
}
// Author is the resolver for the author field of Comment.
func (r *commentResolver) Author(ctx context.Context, obj *model.Comment) (*model.User, error) {
loader := dataloader.GetUserLoader(ctx)
thunk := loader.Load(ctx, dataloader.StringKey(obj.AuthorID))
result, err := thunk()
if err != nil {
return nil, err
}
return result.(*model.User), nil
}
// Post is the resolver for the post field of Comment.
func (r *commentResolver) Post(ctx context.Context, obj *model.Comment) (*model.Post, error) {
loader := dataloader.GetPostLoader(ctx)
thunk := loader.Load(ctx, dataloader.StringKey(obj.PostID))
result, err := thunk()
if err != nil {
return nil, err
}
return result.(*model.Post), nil
}
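Implementing Subscriptions
GraphQL subscriptions let clients subscribe to real-time events. The events travel over WebSocket; the code below is a simple in-memory publish/subscribe hub that hands each event to subscribers through Go channels: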
// graph/subscription.resolvers.go
package graph
import (
"context"
"sync"
"github.com/yourusername/go-graphql-demo/graph/model"
)
type Subscription struct {
mu sync.RWMutex
postCreatedSubs []chan *model.Post
commentAddedSubs map[string][]chan *model.Comment
postLikedSubs map[string][]chan *model.Post
}
func NewSubscription() *Subscription {
return &Subscription{
commentAddedSubs: make(map[string][]chan *model.Comment),
postLikedSubs: make(map[string][]chan *model.Post),
}
}
// PostCreated publishes a "post created" event to all subscribers
func (s *Subscription) PostCreated(ctx context.Context, post *model.Post) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, ch := range s.postCreatedSubs {
select {
case ch <- post:
default:
// Channel full (slow subscriber): drop the event instead of blocking the publisher
}
}
}
// CommentAdded publishes a "comment added" event to the subscribers of that post
func (s *Subscription) CommentAdded(ctx context.Context, postID string, comment *model.Comment) {
s.mu.RLock()
defer s.mu.RUnlock()
if subs, ok := s.commentAddedSubs[postID]; ok {
for _, ch := range subs {
select {
case ch <- comment:
default:
}
}
}
}
// PostLiked publishes a "post liked" event to the subscribers of that post
func (s *Subscription) PostLiked(ctx context.Context, postID string, post *model.Post) {
s.mu.RLock()
defer s.mu.RUnlock()
if subs, ok := s.postLikedSubs[postID]; ok {
for _, ch := range subs {
select {
case ch <- post:
default:
}
}
}
}
// PostCreated is the resolver for the postCreated field.
func (r *subscriptionResolver) PostCreated(ctx context.Context) (<-chan *model.Post, error) {
ch := make(chan *model.Post, 1)
r.PubSub.mu.Lock()
r.PubSub.postCreatedSubs = append(r.PubSub.postCreatedSubs, ch)
r.PubSub.mu.Unlock()
go func() {
<-ctx.Done()
// Clean up: remove the channel from the subscriber list, then close it.
// Closing under the write lock is safe because publishers send while holding the read lock.
r.PubSub.mu.Lock()
defer r.PubSub.mu.Unlock()
for i, sub := range r.PubSub.postCreatedSubs {
if sub == ch {
r.PubSub.postCreatedSubs = append(
r.PubSub.postCreatedSubs[:i],
r.PubSub.postCreatedSubs[i+1:]...,
)
close(ch)
break
}
}
}()
return ch, nil
}
// CommentAdded is the resolver for the commentAdded field.
func (r *subscriptionResolver) CommentAdded(ctx context.Context, postID string) (<-chan *model.Comment, error) {
ch := make(chan *model.Comment, 1)
r.PubSub.mu.Lock()
r.PubSub.commentAddedSubs[postID] = append(r.PubSub.commentAddedSubs[postID], ch)
r.PubSub.mu.Unlock()
go func() {
<-ctx.Done()
r.PubSub.mu.Lock()
defer r.PubSub.mu.Unlock()
subs := r.PubSub.commentAddedSubs[postID]
for i, sub := range subs {
if sub == ch {
r.PubSub.commentAddedSubs[postID] = append(subs[:i], subs[i+1:]...)
close(ch)
break
}
}
}()
return ch, nil
}
// PostLiked is the resolver for the postLiked field.
func (r *subscriptionResolver) PostLiked(ctx context.Context, postID string) (<-chan *model.Post, error) {
ch := make(chan *model.Post, 1)
r.PubSub.mu.Lock()
r.PubSub.postLikedSubs[postID] = append(r.PubSub.postLikedSubs[postID], ch)
r.PubSub.mu.Unlock()
go func() {
<-ctx.Done()
r.PubSub.mu.Lock()
defer r.PubSub.mu.Unlock()
subs := r.PubSub.postLikedSubs[postID]
for i, sub := range subs {
if sub == ch {
r.PubSub.postLikedSubs[postID] = append(subs[:i], subs[i+1:]...)
close(ch)
break
}
}
}()
return ch, nil
}
Middleware and Authentication
Create a middleware that extracts the JWT and injects the user ID into the context:
// internal/middleware/auth.go
package middleware
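import (
"context"
"net/http"
"strings"
)
type contextKey string
// UserContextKey is the context key under which the authenticated user ID is stored.
// It is exported so tests can build an authenticated context directly.
const UserContextKey = contextKey("user")
// AuthMiddleware returns a middleware that reads the Bearer token and stores the user ID in the context.
// parseToken turns a token into a user ID (e.g. graph.UserIDFromToken). It is injected instead of
// importing the graph package, because graph already imports this package and Go forbids import cycles.
func AuthMiddleware(parseToken func(token string) (string, error)) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract the token from the Authorization header
parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
// No usable token: continue anonymously; each resolver decides whether auth is required
next.ServeHTTP(w, r)
return
}
userID, err := parseToken(parts[1])
if err != nil {
next.ServeHTTP(w, r)
return
}
// Inject the user ID into the context
ctx := context.WithValue(r.Context(), UserContextKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// GetUserID reads the user ID from the context
func GetUserID(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(UserContextKey).(string)
return userID, ok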
}
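HTTP authentication schemes are case-insensitive (RFC 7235), so a client sending bearer instead of Bearer is rejected by an exact string comparison. A small standard-library sketch of a more tolerant parser; bearerToken is a hypothetical helper that could replace the strings.SplitN logic in the middleware:

```go
package main

import (
	"fmt"
	"strings"
)

// bearerToken extracts the token from an Authorization header value.
// The scheme is compared case-insensitively and an empty token is rejected.
func bearerToken(header string) (string, bool) {
	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", false
	}
	token = strings.TrimSpace(token)
	return token, token != ""
}

func main() {
	for _, h := range []string{"Bearer abc.def.ghi", "bearer abc", "Basic dXNlcjpwYXNz", "Bearer ", ""} {
		tok, ok := bearerToken(h)
		fmt.Printf("%q -> %q %v\n", h, tok, ok)
	}
}
```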
A helper for reading the user ID from the context inside resolvers:
// graph/helpers.go
package graph
import (
"context"
"errors"
"github.com/yourusername/go-graphql-demo/internal/middleware"
)
func getUserIDFromContext(ctx context.Context) (string, error) {
userID, ok := middleware.GetUserID(ctx)
if !ok {
return "", errors.New("unauthorized")
}
return userID, nil
}
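Error Handling
Define a unified error-handling mechanism. Be aware that with the presenter below, any error that is not a *gqlerror.Error, including the plain errors.New("unauthorized") and errors.New("forbidden") values returned by the resolvers above, reaches the client as a generic INTERNAL_ERROR; resolvers should return NewError(...) whenever the client needs to see the real code and message.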
// graph/errors.go
package graph
import (
"context"
"errors"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// ErrorCode is a machine-readable code returned in extensions.code
type ErrorCode string
const (
ErrCodeNotFound ErrorCode = "NOT_FOUND"
ErrCodeUnauthorized ErrorCode = "UNAUTHORIZED"
ErrCodeForbidden ErrorCode = "FORBIDDEN"
ErrCodeValidation ErrorCode = "VALIDATION_ERROR"
ErrCodeInternal ErrorCode = "INTERNAL_ERROR"
)
// NewError creates a GraphQL error that carries an error code
func NewError(code ErrorCode, message string) *gqlerror.Error {
return &gqlerror.Error{
Message: message,
Extensions: map[string]interface{}{
"code": code,
},
}
}
// ErrorPresenter customizes how errors are presented to the client
func ErrorPresenter(ctx context.Context, err error) *gqlerror.Error {
// A *gqlerror.Error (e.g. created with NewError) is returned as-is,
// with the field path filled in when it is missing
var gqlErr *gqlerror.Error
if errors.As(err, &gqlErr) {
if gqlErr.Path == nil {
gqlErr.Path = graphql.GetPath(ctx)
}
return gqlErr
}
// Every other error becomes INTERNAL_ERROR so internal details are not leaked
return &gqlerror.Error{
Message: "An internal error occurred",
Path: graphql.GetPath(ctx),
Extensions: map[string]interface{}{
"code": ErrCodeInternal,
},
}
}
The Main Entry Point
Finally, wire all the components together:
// main.go
package main
import (
"log"
"net/http"
"os"
"time"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/lru"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/99designs/gqlgen/graphql/playground"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/websocket"
"github.com/vektah/gqlparser/v2/ast"
"github.com/yourusername/go-graphql-demo/graph"
"github.com/yourusername/go-graphql-demo/internal/dataloader"
authmw "github.com/yourusername/go-graphql-demo/internal/middleware"
storepkg "github.com/yourusername/go-graphql-demo/internal/store"
)
const defaultPort = "8080"
func main() {
port := os.Getenv("PORT")
if port == "" {
port = defaultPort
}
// Initialize the store
store := storepkg.New()
// Create the GraphQL server
srv := handler.New(graph.NewExecutableSchema(graph.Config{
Resolvers: &graph.Resolver{
Store: store,
PubSub: graph.NewSubscription(),
},
}))
// Configure the transports
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})
srv.AddTransport(transport.POST{})
// WebSocket support (for subscriptions)
srv.AddTransport(&transport.Websocket{
KeepAlivePingInterval: 10 * time.Second,
Upgrader: websocket.Upgrader{
// Accepts any origin: convenient in development, restrict it in production
CheckOrigin: func(r *http.Request) bool {
return true
},
},
})
// Cache parsed queries
srv.SetQueryCache(lru.New[*ast.QueryDocument](1000))
srv.Use(extension.Introspection{})
srv.Use(extension.AutomaticPersistedQuery{
Cache: lru.New[string](100),
})
// Custom error presentation
srv.SetErrorPresenter(graph.ErrorPresenter)
// Routing (chi's recommended order: RequestID and RealIP before Logger, Recoverer last)
router := chi.NewRouter()
router.Use(middleware.RequestID)
router.Use(middleware.RealIP)
router.Use(middleware.Logger)
router.Use(middleware.Recoverer)
// CORS comes from the separate github.com/go-chi/cors package; chi's middleware package has no CORS handler
router.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
// Tokens travel in the Authorization header, not cookies; browsers also reject "*" combined with credentials
AllowCredentials: false,
MaxAge: 300,
}))
// Authentication middleware
router.Use(authmw.AuthMiddleware(graph.UserIDFromToken))
// DataLoader middleware: build new loaders for every request so the cache is request-scoped
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dataloader.Middleware(dataloader.NewLoaders(store), next).ServeHTTP(w, r)
})
})
// GraphQL Playground
router.Handle("/", playground.Handler("GraphQL playground", "/query"))
// GraphQL endpoint
router.Handle("/query", srv)
log.Printf("connect to http://localhost:%s/ for GraphQL playground", port)
log.Fatal(http.ListenAndServe(":"+port, router))
}
Testing the GraphQL API
Write unit tests to verify the resolver logic:
// graph/resolvers_test.go
package graph_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yourusername/go-graphql-demo/graph"
"github.com/yourusername/go-graphql-demo/graph/model"
"github.com/yourusername/go-graphql-demo/internal/middleware"
"github.com/yourusername/go-graphql-demo/internal/store"
)
func TestRegister(t *testing.T) {
store := store.New()
resolver := &graph.Resolver{Store: store}
input := model.CreateUserInput{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
}
payload, err := resolver.Mutation().Register(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, payload)
assert.NotEmpty(t, payload.Token)
assert.Equal(t, "testuser", payload.User.Username)
assert.Equal(t, "test@example.com", payload.User.Email)
}
func TestCreatePost(t *testing.T) {
store := store.New()
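// CreatePost publishes an event, so the PubSub hub must be set
resolver := &graph.Resolver{Store: store, PubSub: graph.NewSubscription()}
// Register a user first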
userPayload, err := resolver.Mutation().Register(context.Background(), model.CreateUserInput{
Username: "author",
Email: "author@example.com",
Password: "password123",
})
require.NoError(t, err)
// Simulate an authenticated request context
ctx := context.WithValue(context.Background(), middleware.UserContextKey, userPayload.User.ID)
// Create a post
input := model.CreatePostInput{
Title: "My First Post",
Content: "Hello, GraphQL!",
Tags: []string{"go", "graphql"},
}
post, err := resolver.Mutation().CreatePost(ctx, input)
require.NoError(t, err)
require.NotNil(t, post)
assert.Equal(t, "My First Post", post.Title)
assert.Equal(t, "Hello, GraphQL!", post.Content)
assert.Equal(t, userPayload.User.ID, post.AuthorID)
}
func TestLikePost(t *testing.T) {
store := store.New()
resolver := &graph.Resolver{Store: store, PubSub: graph.NewSubscription()}
// Register a user and create a post
userPayload, err := resolver.Mutation().Register(context.Background(), model.CreateUserInput{
Username: "author",
Email: "author@example.com",
Password: "password123",
})
require.NoError(t, err)
ctx := context.WithValue(context.Background(), middleware.UserContextKey, userPayload.User.ID)
post, err := resolver.Mutation().CreatePost(ctx, model.CreatePostInput{
Title: "Test Post",
Content: "Content",
})
require.NoError(t, err)
// Like the post
likedPost, err := resolver.Mutation().LikePost(ctx, post.ID)
require.NoError(t, err)
assert.Equal(t, 1, likedPost.Likes)
// A second like by the same user must fail
_, err = resolver.Mutation().LikePost(ctx, post.ID)
assert.Error(t, err)
}
Integration tests with gqlgen's client package (github.com/99designs/gqlgen/client):
// graph/integration_test.go
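package graph_test
import (
"testing"
"github.com/99designs/gqlgen/client"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/stretchr/testify/require"
"github.com/yourusername/go-graphql-demo/graph"
"github.com/yourusername/go-graphql-demo/internal/middleware"
"github.com/yourusername/go-graphql-demo/internal/store"
)
func TestGraphQLIntegration(t *testing.T) {
store := store.New()
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{
Resolvers: &graph.Resolver{Store: store, PubSub: graph.NewSubscription()},
}))
// Wrap the server in the auth middleware so the Authorization header below is actually parsed
c := client.New(middleware.AuthMiddleware(graph.UserIDFromToken)(srv))
// Test registration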
var registerResp struct {
Register struct {
Token string
User struct {
ID string
Username string
}
}
}
c.MustPost(`
mutation {
register(input: {
username: "testuser"
email: "test@example.com"
password: "password123"
}) {
token
user {
id
username
}
}
}
`, &registerResp)
require.NotEmpty(t, registerResp.Register.Token)
require.Equal(t, "testuser", registerResp.Register.User.Username)
// Test creating a post (requires the token)
var createPostResp struct {
CreatePost struct {
ID string
Title string
}
}
c.MustPost(`
mutation {
createPost(input: {
title: "Test Post"
content: "This is a test"
}) {
id
title
}
}
`, &createPostResp, client.AddHeader("Authorization", "Bearer "+registerResp.Register.Token))
require.NotEmpty(t, createPostResp.CreatePost.ID)
require.Equal(t, "Test Post", createPostResp.CreatePost.Title)
}
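Summary
Congratulations! We've built a full-featured GraphQL API, including:
- Complete CRUD operations: user registration and login; creating, reading, updating and deleting posts
- Relationship queries: follows, followers, post authors
- Real-time subscriptions: new posts, new comments, like notifications
- Authentication: JWT-based identity verification
- Performance: DataLoader to tackle the N+1 problem
- Pagination: cursor-based pagination
- Error handling: a unified error format
- Test coverage: unit tests and integration tests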
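GraphQL's advantages over REST:
- Clients can specify exactly the fields they need
- One request can fetch several resources
- The strongly typed schema enables better documentation and tooling
- The API is easier to evolve without version numbers: new fields can be added and old ones marked @deprecated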
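But it also comes with challenges:
- A steeper learning curve
- Caching is harder than with REST, because most requests are POSTs to a single endpoint and plain HTTP caching does not apply
- N+1 problems need careful handling
- File uploads are less straightforward than in REST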
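In production you will also need to consider:
- Query complexity limits (to guard against malicious queries)
- Rate limiting
- Logging and monitoring
- Database integration (replacing the in-memory store)
- More complete authentication and authorization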
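For the first item, gqlgen ships a complexity-limit extension (for example srv.Use(extension.FixedComplexityLimit(200)), where 200 is an arbitrary budget) that scores a query's fields before executing it. To show the underlying idea, here is a deliberately naive, standard-library depth check that counts nested selection sets. It skips "..." string literals and # comments but does not expand fragments, so treat it as an illustration rather than a real guard; maxSelectionDepth and checkDepth are hypothetical names:

```go
package main

import "fmt"

// maxSelectionDepth returns the deepest nesting of { } in a GraphQL document,
// ignoring braces inside "..." strings and # comments. Fragments are not expanded.
func maxSelectionDepth(query string) int {
	depth, maxDepth := 0, 0
	inString, inComment := false, false
	for i := 0; i < len(query); i++ {
		c := query[i]
		switch {
		case inComment:
			if c == '\n' {
				inComment = false
			}
		case inString:
			if c == '\\' {
				i++ // skip the escaped character
			} else if c == '"' {
				inString = false
			}
		case c == '#':
			inComment = true
		case c == '"':
			inString = true
		case c == '{':
			depth++
			if depth > maxDepth {
				maxDepth = depth
			}
		case c == '}':
			depth--
		}
	}
	return maxDepth
}

// checkDepth rejects queries nested deeper than limit.
func checkDepth(query string, limit int) error {
	if d := maxSelectionDepth(query); d > limit {
		return fmt.Errorf("query depth %d exceeds limit %d", d, limit)
	}
	return nil
}

func main() {
	// followers of followers of followers: the kind of query depth limits exist for
	q := `{ user(id: "1") { followers { followers { followers { username } } } } }`
	fmt.Println(maxSelectionDepth(q), checkDepth(q, 4))
}
```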
I hope this article helps you get productive with Go and GraphQL. Next time we'll dig into tuning Go's garbage collector, so stay tuned!