Go 1.18 泛型:期待十年的类型参数

深入理解 Go 1.18 引入的泛型特性,学习类型参数、类型约束和类型集合的使用

Go 1.18 泛型:期待十年的类型参数

2022 年 3 月,Go 1.18 正式发布,这是 Go 语言诞生以来最重要的一次更新——泛型(Generics)终于来了!

从 2010 年社区开始讨论泛型,到 2022 年正式落地,整整 12 年的时间。这期间,Go 团队尝试了无数种方案,最终选择了"类型参数"(Type Parameters)这个设计。

本文将带你深入理解 Go 泛型的设计哲学、语法细节和最佳实践。

为什么需要泛型?

在没有泛型的时代,我们写过太多重复的代码:

// 为每种类型都要写一遍
func MinInt(a, b int) int {
    if a < b {
        return a
    }
    return b
}

func MinFloat64(a, b float64) float64 {
    if a < b {
        return a
    }
    return b
}

func MinString(a, b string) string {
    if a < b {
        return a
    }
    return b
}

或者使用 interface{} 牺牲类型安全:

// 失去类型安全
func Min(a, b interface{}) interface{} {
    // 需要类型断言,运行时才能发现错误
    switch a := a.(type) {
    case int:
        if a < b.(int) {
            return a
        }
        return b
    // ... 其他类型
    }
}

泛型让我们能够写出既通用又类型安全的代码。

泛型基础语法

类型参数

package main

import "fmt"

// T 是类型参数,~int | ~float64 | ~string 是类型约束
func Min[T ~int | ~float64 | ~string](a, b T) T {
    if a < b {
        return a
    }
    return b
}

func main() {
    // 类型推断,不需要显式指定类型参数
    fmt.Println(Min(1, 2))           // 1
    fmt.Println(Min(3.14, 2.71))     // 2.71
    fmt.Println(Min("apple", "banana")) // apple
    
    // 显式指定类型参数
    fmt.Println(Min[int](1, 2))
}

类型约束

类型约束定义了类型参数可以接受的类型集合:

// 精确类型约束
type ExactInt interface {
    int
}

// 近似类型约束(使用 ~)
type ApproxInt interface {
    ~int  // 包括 int 和所有底层类型为 int 的类型
}

// 组合约束
type Number interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64 |
    ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
    ~float32 | ~float64
}

// 使用约束
func Sum[T Number](numbers []T) T {
    var sum T
    for _, n := range numbers {
        sum += n
    }
    return sum
}

// 自定义整数类型
type MyInt int

func main() {
    // MyInt 的底层类型是 int,满足 ~int 约束
    nums := []MyInt{1, 2, 3, 4, 5}
    fmt.Println(Sum(nums)) // 15
}

标准库中的约束

Go 在 golang.org/x/exp/constraints 包中提供了常用约束:

package main

import (
    "fmt"
    "golang.org/x/exp/constraints"
)

// 使用标准约束
func Min[T constraints.Ordered](a, b T) T {
    if a < b {
        return a
    }
    return b
}

// constraints 包提供的约束:
// - Signed:有符号整数
// - Unsigned:无符号整数
// - Integer:所有整数
// - Float:浮点数
// - Complex:复数
// - Ordered:可排序类型(整数、浮点数、字符串)

func Abs[T constraints.Signed | constraints.Float](n T) T {
    if n < 0 {
        return -n
    }
    return n
}

func main() {
    fmt.Println(Min(1, 2))
    fmt.Println(Min(3.14, 2.71))
    fmt.Println(Min("a", "b"))
    
    fmt.Println(Abs(-42))
    fmt.Println(Abs(-3.14))
}

泛型函数实战

1. Map、Filter、Reduce

package main

import "fmt"

// Map 对切片中的每个元素应用函数
func Map[T any, U any](s []T, f func(T) U) []U {
    result := make([]U, len(s))
    for i, v := range s {
        result[i] = f(v)
    }
    return result
}

// Filter 过滤切片中的元素
func Filter[T any](s []T, f func(T) bool) []T {
    var result []T
    for _, v := range s {
        if f(v) {
            result = append(result, v)
        }
    }
    return result
}

// Reduce 将切片归约为单个值
func Reduce[T any, U any](s []T, init U, f func(U, T) U) U {
    result := init
    for _, v := range s {
        result = f(result, v)
    }
    return result
}

func main() {
    numbers := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    
    // Map:每个数平方
    squares := Map(numbers, func(n int) int {
        return n * n
    })
    fmt.Println("Squares:", squares)
    
    // Filter:过滤偶数
    evens := Filter(numbers, func(n int) bool {
        return n%2 == 0
    })
    fmt.Println("Evens:", evens)
    
    // Reduce:求和
    sum := Reduce(numbers, 0, func(acc, n int) int {
        return acc + n
    })
    fmt.Println("Sum:", sum)
    
    // 链式调用
    result := Reduce(
        Map(
            Filter(numbers, func(n int) bool { return n%2 == 0 }),
            func(n int) int { return n * n },
        ),
        0,
        func(acc, n int) int { return acc + n },
    )
    fmt.Println("Sum of even squares:", result) // 220
}

2. 泛型容器

package main

import "fmt"

// Stack 泛型栈
type Stack[T any] struct {
    items []T
}

func NewStack[T any]() *Stack[T] {
    return &Stack[T]{
        items: make([]T, 0),
    }
}

func (s *Stack[T]) Push(item T) {
    s.items = append(s.items, item)
}

func (s *Stack[T]) Pop() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    item := s.items[len(s.items)-1]
    s.items = s.items[:len(s.items)-1]
    return item, true
}

func (s *Stack[T]) Peek() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    return s.items[len(s.items)-1], true
}

func (s *Stack[T]) Size() int {
    return len(s.items)
}

// Set 泛型集合
type Set[T comparable] struct {
    items map[T]struct{}
}

func NewSet[T comparable]() *Set[T] {
    return &Set[T]{
        items: make(map[T]struct{}),
    }
}

func (s *Set[T]) Add(item T) {
    s.items[item] = struct{}{}
}

func (s *Set[T]) Remove(item T) {
    delete(s.items, item)
}

func (s *Set[T]) Contains(item T) bool {
    _, ok := s.items[item]
    return ok
}

func (s *Set[T]) Size() int {
    return len(s.items)
}

// Union 并集
func (s *Set[T]) Union(other *Set[T]) *Set[T] {
    result := NewSet[T]()
    for item := range s.items {
        result.Add(item)
    }
    for item := range other.items {
        result.Add(item)
    }
    return result
}

// Intersect 交集
func (s *Set[T]) Intersect(other *Set[T]) *Set[T] {
    result := NewSet[T]()
    for item := range s.items {
        if other.Contains(item) {
            result.Add(item)
        }
    }
    return result
}

func main() {
    // 使用泛型栈
    intStack := NewStack[int]()
    intStack.Push(1)
    intStack.Push(2)
    intStack.Push(3)
    
    for intStack.Size() > 0 {
        item, _ := intStack.Pop()
        fmt.Printf("Popped: %d\n", item)
    }
    
    // 使用泛型集合
    set1 := NewSet[string]()
    set1.Add("apple")
    set1.Add("banana")
    set1.Add("cherry")
    
    set2 := NewSet[string]()
    set2.Add("banana")
    set2.Add("date")
    set2.Add("elderberry")
    
    union := set1.Union(set2)
    intersect := set1.Intersect(set2)
    
    fmt.Printf("Union size: %d\n", union.Size())
    fmt.Printf("Intersect contains banana: %v\n", intersect.Contains("banana"))
}

3. 泛型 Channel 工具

package main

import (
    "context"
    "fmt"
    "time"
)

// MapChannel 对 channel 中的每个元素应用函数
func MapChannel[T any, U any](ctx context.Context, in <-chan T, f func(T) U) <-chan U {
    out := make(chan U)
    go func() {
        defer close(out)
        for {
            select {
            case <-ctx.Done():
                return
            case v, ok := <-in:
                if !ok {
                    return
                }
                out <- f(v)
            }
        }
    }()
    return out
}

// FilterChannel 过滤 channel 中的元素
func FilterChannel[T any](ctx context.Context, in <-chan T, f func(T) bool) <-chan T {
    out := make(chan T)
    go func() {
        defer close(out)
        for {
            select {
            case <-ctx.Done():
                return
            case v, ok := <-in:
                if !ok {
                    return
                }
                if f(v) {
                    out <- v
                }
            }
        }
    }()
    return out
}

// MergeChannels 合并多个 channel
func MergeChannels[T any](ctx context.Context, channels ...<-chan T) <-chan T {
    out := make(chan T)
    
    for _, ch := range channels {
        go func(c <-chan T) {
            for {
                select {
                case <-ctx.Done():
                    return
                case v, ok := <-c:
                    if !ok {
                        return
                    }
                    out <- v
                }
            }
        }(ch)
    }
    
    return out
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    
    // 创建输入 channel
    in := make(chan int)
    go func() {
        defer close(in)
        for i := 1; i <= 10; i++ {
            in <- i
        }
    }()
    
    // 过滤偶数并平方
    filtered := FilterChannel(ctx, in, func(n int) bool {
        return n%2 == 0
    })
    
    squared := MapChannel(ctx, filtered, func(n int) int {
        return n * n
    })
    
    // 消费结果
    for result := range squared {
        fmt.Printf("Result: %d\n", result)
    }
}

泛型的限制

1. 不能用于方法类型参数

type Container[T any] struct {
    value T
}

// ❌ 错误:方法不能有额外的类型参数
func (c *Container[T]) Transform[U any](f func(T) U) *Container[U] {
    return &Container[U]{value: f(c.value)}
}

// ✅ 正确:使用泛型函数
func Transform[T any, U any](c *Container[T], f func(T) U) *Container[U] {
    return &Container[U]{value: f(c.value)}
}

2. 不能嵌入类型参数

// ❌ 错误
type Wrapper[T any] struct {
    T  // 不能嵌入类型参数
}

3. 类型断言不支持类型参数

func Process[T any](v interface{}) {
    // ❌ 错误
    // if t, ok := v.(T); ok { ... }
    
    // ✅ 正确:使用 reflect
}

何时使用泛型?

适合使用泛型的场景

  1. 数据结构:栈、队列、集合、树等
  2. 算法:排序、搜索、图算法等
  3. 函数式工具:Map、Filter、Reduce
  4. 处理切片的通用函数

不适合使用泛型的场景

  1. 业务逻辑:通常类型是确定的
  2. 可以用接口解决的问题:如果只需要行为,用接口
  3. 只有一个实现的代码:不要为了泛型而泛型
// ❌ 不好:只有一个实现,却用了泛型
type UserService[T any] struct {
    db *sql.DB
}

// ✅ 好:直接使用具体类型
type UserService struct {
    db *sql.DB
}

性能考虑

泛型在编译时实例化,性能与手写具体类型版本相同:

// 编译器会生成两个版本:
// Min[int]
// Min[float64]
func Min[T constraints.Ordered](a, b T) T {
    if a < b {
        return a
    }
    return b
}

基准测试显示,泛型版本与手写版本性能相当。

总结

Go 泛型的设计哲学是简单实用

  1. 类型参数:使用 [T constraint] 语法
  2. 类型约束:使用接口定义,支持 ~ 近似类型
  3. 类型推断:大多数情况下不需要显式指定类型参数
  4. 编译时实例化:性能与手写版本相同

记住几个最佳实践:

  • 优先使用标准库提供的约束
  • 不要过度使用泛型
  • 接口和泛型是互补的,不是替代关系
  • 泛型让代码更通用,但也更复杂

Go 泛型的到来,让这门语言更加完善。但记住 Rob Pike 的话:"Less is more",不要为了炫技而使用泛型。

继续阅读

探索更多技术文章

浏览归档,发现更多关于系统设计、工具链和工程实践的内容。

全部文章 返回首页