Go 1.18 泛型:期待十年的类型参数
2022 年 3 月,Go 1.18 正式发布,这是 Go 语言诞生以来最重要的一次更新——泛型(Generics)终于来了!
从 2010 年社区开始讨论泛型,到 2022 年正式落地,整整 12 年的时间。这期间,Go 团队尝试了无数种方案,最终选择了"类型参数"(Type Parameters)这个设计。
本文将带你深入理解 Go 泛型的设计哲学、语法细节和最佳实践。
为什么需要泛型?
在没有泛型的时代,我们写过太多重复的代码:
// 为每种类型都要写一遍
func MinInt(a, b int) int {
if a < b {
return a
}
return b
}
func MinFloat64(a, b float64) float64 {
if a < b {
return a
}
return b
}
func MinString(a, b string) string {
if a < b {
return a
}
return b
}
或者使用 interface{} 牺牲类型安全:
// 失去类型安全
func Min(a, b interface{}) interface{} {
// 需要类型断言,运行时才能发现错误
switch a := a.(type) {
case int:
if a < b.(int) {
return a
}
return b
// ... 其他类型
}
}
泛型让我们能够写出既通用又类型安全的代码。
泛型基础语法
类型参数
package main
import "fmt"
// T 是类型参数,~int | ~float64 | ~string 是类型约束
func Min[T ~int | ~float64 | ~string](a, b T) T {
if a < b {
return a
}
return b
}
func main() {
// 类型推断,不需要显式指定类型参数
fmt.Println(Min(1, 2)) // 1
fmt.Println(Min(3.14, 2.71)) // 2.71
fmt.Println(Min("apple", "banana")) // apple
// 显式指定类型参数
fmt.Println(Min[int](1, 2))
}
类型约束
类型约束定义了类型参数可以接受的类型集合:
// 精确类型约束
type ExactInt interface {
int
}
// 近似类型约束(使用 ~)
type ApproxInt interface {
~int // 包括 int 和所有底层类型为 int 的类型
}
// 组合约束
type Number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~float32 | ~float64
}
// 使用约束
func Sum[T Number](numbers []T) T {
var sum T
for _, n := range numbers {
sum += n
}
return sum
}
// 自定义整数类型
type MyInt int
func main() {
// MyInt 的底层类型是 int,满足 ~int 约束
nums := []MyInt{1, 2, 3, 4, 5}
fmt.Println(Sum(nums)) // 15
}
标准库中的约束
Go 在 golang.org/x/exp/constraints 包中提供了常用约束:
package main
import (
"fmt"
"golang.org/x/exp/constraints"
)
// 使用标准约束
func Min[T constraints.Ordered](a, b T) T {
if a < b {
return a
}
return b
}
// constraints 包提供的约束:
// - Signed:有符号整数
// - Unsigned:无符号整数
// - Integer:所有整数
// - Float:浮点数
// - Complex:复数
// - Ordered:可排序类型(整数、浮点数、字符串)
func Abs[T constraints.Signed | constraints.Float](n T) T {
if n < 0 {
return -n
}
return n
}
func main() {
fmt.Println(Min(1, 2))
fmt.Println(Min(3.14, 2.71))
fmt.Println(Min("a", "b"))
fmt.Println(Abs(-42))
fmt.Println(Abs(-3.14))
}
泛型函数实战
1. Map、Filter、Reduce
package main
import "fmt"
// Map 对切片中的每个元素应用函数
func Map[T any, U any](s []T, f func(T) U) []U {
result := make([]U, len(s))
for i, v := range s {
result[i] = f(v)
}
return result
}
// Filter 过滤切片中的元素
func Filter[T any](s []T, f func(T) bool) []T {
var result []T
for _, v := range s {
if f(v) {
result = append(result, v)
}
}
return result
}
// Reduce 将切片归约为单个值
func Reduce[T any, U any](s []T, init U, f func(U, T) U) U {
result := init
for _, v := range s {
result = f(result, v)
}
return result
}
func main() {
numbers := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
// Map:每个数平方
squares := Map(numbers, func(n int) int {
return n * n
})
fmt.Println("Squares:", squares)
// Filter:过滤偶数
evens := Filter(numbers, func(n int) bool {
return n%2 == 0
})
fmt.Println("Evens:", evens)
// Reduce:求和
sum := Reduce(numbers, 0, func(acc, n int) int {
return acc + n
})
fmt.Println("Sum:", sum)
// 链式调用
result := Reduce(
Map(
Filter(numbers, func(n int) bool { return n%2 == 0 }),
func(n int) int { return n * n },
),
0,
func(acc, n int) int { return acc + n },
)
fmt.Println("Sum of even squares:", result) // 220
}
2. 泛型容器
package main
import "fmt"
// Stack 泛型栈
type Stack[T any] struct {
items []T
}
func NewStack[T any]() *Stack[T] {
return &Stack[T]{
items: make([]T, 0),
}
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
item := s.items[len(s.items)-1]
s.items = s.items[:len(s.items)-1]
return item, true
}
func (s *Stack[T]) Peek() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
return s.items[len(s.items)-1], true
}
func (s *Stack[T]) Size() int {
return len(s.items)
}
// Set 泛型集合
type Set[T comparable] struct {
items map[T]struct{}
}
func NewSet[T comparable]() *Set[T] {
return &Set[T]{
items: make(map[T]struct{}),
}
}
func (s *Set[T]) Add(item T) {
s.items[item] = struct{}{}
}
func (s *Set[T]) Remove(item T) {
delete(s.items, item)
}
func (s *Set[T]) Contains(item T) bool {
_, ok := s.items[item]
return ok
}
func (s *Set[T]) Size() int {
return len(s.items)
}
// Union 并集
func (s *Set[T]) Union(other *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range s.items {
result.Add(item)
}
for item := range other.items {
result.Add(item)
}
return result
}
// Intersect 交集
func (s *Set[T]) Intersect(other *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range s.items {
if other.Contains(item) {
result.Add(item)
}
}
return result
}
func main() {
// 使用泛型栈
intStack := NewStack[int]()
intStack.Push(1)
intStack.Push(2)
intStack.Push(3)
for intStack.Size() > 0 {
item, _ := intStack.Pop()
fmt.Printf("Popped: %d\n", item)
}
// 使用泛型集合
set1 := NewSet[string]()
set1.Add("apple")
set1.Add("banana")
set1.Add("cherry")
set2 := NewSet[string]()
set2.Add("banana")
set2.Add("date")
set2.Add("elderberry")
union := set1.Union(set2)
intersect := set1.Intersect(set2)
fmt.Printf("Union size: %d\n", union.Size())
fmt.Printf("Intersect contains banana: %v\n", intersect.Contains("banana"))
}
3. 泛型 Channel 工具
package main
import (
"context"
"fmt"
"time"
)
// MapChannel 对 channel 中的每个元素应用函数
func MapChannel[T any, U any](ctx context.Context, in <-chan T, f func(T) U) <-chan U {
out := make(chan U)
go func() {
defer close(out)
for {
select {
case <-ctx.Done():
return
case v, ok := <-in:
if !ok {
return
}
out <- f(v)
}
}
}()
return out
}
// FilterChannel 过滤 channel 中的元素
func FilterChannel[T any](ctx context.Context, in <-chan T, f func(T) bool) <-chan T {
out := make(chan T)
go func() {
defer close(out)
for {
select {
case <-ctx.Done():
return
case v, ok := <-in:
if !ok {
return
}
if f(v) {
out <- v
}
}
}
}()
return out
}
// MergeChannels 合并多个 channel
func MergeChannels[T any](ctx context.Context, channels ...<-chan T) <-chan T {
out := make(chan T)
for _, ch := range channels {
go func(c <-chan T) {
for {
select {
case <-ctx.Done():
return
case v, ok := <-c:
if !ok {
return
}
out <- v
}
}
}(ch)
}
return out
}
func main() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// 创建输入 channel
in := make(chan int)
go func() {
defer close(in)
for i := 1; i <= 10; i++ {
in <- i
}
}()
// 过滤偶数并平方
filtered := FilterChannel(ctx, in, func(n int) bool {
return n%2 == 0
})
squared := MapChannel(ctx, filtered, func(n int) int {
return n * n
})
// 消费结果
for result := range squared {
fmt.Printf("Result: %d\n", result)
}
}
泛型的限制
1. 不能用于方法类型参数
type Container[T any] struct {
value T
}
// ❌ 错误:方法不能有额外的类型参数
func (c *Container[T]) Transform[U any](f func(T) U) *Container[U] {
return &Container[U]{value: f(c.value)}
}
// ✅ 正确:使用泛型函数
func Transform[T any, U any](c *Container[T], f func(T) U) *Container[U] {
return &Container[U]{value: f(c.value)}
}
2. 不能嵌入类型参数
// ❌ 错误
type Wrapper[T any] struct {
T // 不能嵌入类型参数
}
3. 类型断言不支持类型参数
func Process[T any](v interface{}) {
// ❌ 错误
// if t, ok := v.(T); ok { ... }
// ✅ 正确:使用 reflect
}
何时使用泛型?
适合使用泛型的场景
- 数据结构:栈、队列、集合、树等
- 算法:排序、搜索、图算法等
- 函数式工具:Map、Filter、Reduce
- 处理切片的通用函数
不适合使用泛型的场景
- 业务逻辑:通常类型是确定的
- 可以用接口解决的问题:如果只需要行为,用接口
- 只有一个实现的代码:不要为了泛型而泛型
// ❌ 不好:只有一个实现,却用了泛型
type UserService[T any] struct {
db *sql.DB
}
// ✅ 好:直接使用具体类型
type UserService struct {
db *sql.DB
}
性能考虑
泛型在编译时实例化,性能与手写具体类型版本相同:
// 编译器会生成两个版本:
// Min[int]
// Min[float64]
func Min[T constraints.Ordered](a, b T) T {
if a < b {
return a
}
return b
}
基准测试显示,泛型版本与手写版本性能相当。
总结
Go 泛型的设计哲学是简单实用:
- 类型参数:使用
[T constraint]语法 - 类型约束:使用接口定义,支持
~近似类型 - 类型推断:大多数情况下不需要显式指定类型参数
- 编译时实例化:性能与手写版本相同
记住几个最佳实践:
- 优先使用标准库提供的约束
- 不要过度使用泛型
- 接口和泛型是互补的,不是替代关系
- 泛型让代码更通用,但也更复杂
Go 泛型的到来,让这门语言更加完善。但记住 Rob Pike 的话:"Less is more",不要为了炫技而使用泛型。
继续阅读
探索更多技术文章
浏览归档,发现更多关于系统设计、工具链和工程实践的内容。