Go中的依赖注入

本文翻译至这篇文章, 感谢原文作者Tin Rabzelj的精彩分享。

我已经编写了一个简单的处理依赖注入的go语言包(它在tinrab/kit这个项目中),这个包极其简单.

定义服务

首先得定义一个接口(interface),和至少一个实现它以及它的所有依赖的结构体(struct)。

数据库

SQLDatabase 接口将需要持有一个由Go官方包 database/sql 实现的数据库连接。真实的数据库将被一个私有的mySQLDatabase 结构体处理,它是通过mysql驱动连接到MySQL服务, 如下代码:

    package main
 
    import (
       "database/sql"

        _ "github.com/go-sql-driver/mysql"
        "github.com/tinrab/kit"
    )
 
    type SQLDatabase interface {
        kit.Dependency
        SQL() *sql.DB
    }

    type mySQLDatabase struct {
        address string
        conn    *sql.DB
    }

    func NewMySQLDatabase(address string) SQLDatabase {
        return &mySQLDatabase{
            address: address,
        }
    }

    func (db *mySQLDatabase) SQL() *sql.DB {
        return db.conn
    }

还需要实现kit.Dependency接口的Open和Close方法:

    func (db *mySQLDatabase) Open() error {
        conn, err := sql.Open("mysql", db.address)
        if err != nil {
            return err
        }

        db.conn = conn
        return nil
    }

    func (db *mySQLDatabase) Close() {
        db.conn.Close()
    }

用户仓库(UserRepository)

用户仓库用来管理这个应用的所有用户.

定义一个用户结构体:

    type User struct {
        ID uint64
        Name string
    }

定义一个用户仓库接口和Mysql实现的用户仓库结构体:

    package main

    import "github.com/tinrab/kit"

    type UserRepository interface {
        kit.Dependency
        GetUserByID(id uint64) (*User, error)
    }

    type mySQLUserRepository struct {
        Database SQLDatabase `inject:"database"`
    }

    func NewMySQLUserRepository() UserRepository {
        return &mySQLUserRepository{}
    }

    func (r *mySQLUserRepository) Open() error {
        return nil
    }

    func (r *mySQLUserRepository) Close() error {
    }

继续实现接口的其余方法. 注意 Database 属性上的 inject tag, database值的意义是 名为database的依赖将会被注入到这个属性上。

    func (r *mySQLUserRepository) GetUserByID(id uint64) (*User, error) {
        user := &User{}
        err := r.Database.SQL().QueryRow("SELECT * FROM users WHERE id = ?", id).
            Scan(&user.ID, &User.name)

        if err != nil {
            return nil, err
        }

        return user, nil
    }

文章仓库(PostRepository)

文章仓库和用户仓库非常相似

    type Post struct {
        ID uint64
        UserID uint64
        Title string
        Body string
    }

声明接口和结构体:

    package main 

    import "github.com/tinrab/kit"

    type PostRepository interface {
        kit.Dependency
        GetPostsByUser(userID uint64) []Post, error
    }

    type mySQLPostRepository struct {
        Database SQLDatabase `inject:"database"`
    }

    func NewMySQLPostRepository() PostRepository {
        return &mySQLPostRepository{}
    }

    func (r *mySQLPostRepository) Open() error {
        return nil
    }

    func (r *mySQLPostRepository) Close() error {
    }

方法GetPostsByUser通过用户的ID查询所属的文章:

    func (r *mySQLPostRepository) GetPostsByUser(userID uint64) ([]Post, error) {
        rows, err := r.Database.SQL().Query("SELECT * FROM posts WHERE user_id = ?", userID)
        if err != nil {
            return nil, err
        }
    
        var post Post
        var posts []Post
        for rows.Next() {
            err = rows.Scan(&post.ID, &post.UserID, &post.Title, &post.Body)
            if err != nil {
                return nil, err
            }
            posts = append(posts, post)
        }

        return posts, nil
    }

博客服务(Blog service)

博客服务使用上面实现的两个仓库来提供读取用户信息的接口:

    package main 

    import "github.com/tinrab/kit"

    type UserProfile struct {
        User User
        Posts []Post
    }

    type BlogService interface {
        kit.Dependency
        GetUserProfile(userID uint64) (*UserProfile, error)
    }

    type blogServiceImpl struct {
        UserRepository UserRepository `inject:"user.repository"`
        PostRepository PostRepository `inject:"post.repository"`
    }

    func NewBlogService() BlogService {
        return &blogServiceImpl{}
    }

    func (*blogServiceImpl) Open() error {
        return nil
    }

    func (*blogServiceImpl) Close() {
    }

如果正确的解析了依赖,所有的属性应该不包含空的值,下面实现 GetUserProfile 方法:

    func (s *blogServiceImpl) GetUserProfile(userID uint64) (*UserProfile, error) {
        user, err := s.UserRepository.GetUserByID(userID)
        if err != nil {
            return nil, err
        }

        posts, err := s.PostRepository.GetPostsByUser(userID)
        if err != nil {
            return nil, err
        }

        return &UserProfile{
            User:  *user,
            Posts: posts,
        }, nil
    }

解析依赖(Resolving dependencies)

为了注入所有的依赖,首先需要通过 Provide 方法建立起名称与依赖实例的映射关系,再调用 Resolve 方法:

    di := kit.NewDependencyInjection()

    di.Provide("database", NewMySQLDatabase("root:123456@tcp(127.0.0.1:3306)/blog"))
    di.Provide("user.repository", NewMySQLUserRepository())
    di.Provide("post.repository", NewMySQLPostRepository())
    di.Provide("blog.service", NewBlogService())

    if err := di.Resolve(); err != nil {
        log.Fatal(err)
    }

测试

依赖注入对于测试是非常友好的。

在这里,为了测试博客服务,用户和文章仓库会被mock。

写一个内存版的用户仓库实现

    package main
 
    import (
        "errors"
        "testing"
 
        "github.com/stretchr/testify/assert"
        "github.com/tinrab/kit"
    )

    type userRepositoryStub struct {
        users map[uint64]*User
    }

    func (r *userRepositoryStub) Open() error {
        r.users = map[uint64]*User{
           1:&User{ID: 1, Name: "User1"},
           2:&User{ID: 2, Name: "User2"},
           3:&User{ID: 3, Name: "User3"},
        }
        return nil
    }

    func (r *userRepositoryStub) Close() {
    }

    func (r *userRepositoryStub) GetUserByID(id uint64) (*User, error) {
        if user, ok := r.users[id]; ok {
            return user, nil
        }
        return nil, errors.New("User not found")
    }

同样的写一个内存版的文章仓库实现

    type postRepositoryStub struct {
       postsByUserID map[uint64][]Post
    }
 
    func (r *postRepositoryStub) Open() error {
        r.postsByUserID = map[uint64][]Post{
            1:[]Post{
                Post{ID: 1, UserID: 1, Title: "A", Body: "A"},
                Post{ID: 2, UserID: 1, Title: "B", Body: "B"},
            },
        }
        return nil
    }

    func (r *postRepositoryStub) Close() {
    }

    func (r *postRepositoryStub) GetPostsByUser(userID uint64) ([]Post, error) {
        if posts, ok := r.postsByUserID[userID]; ok {
            return posts, nil
        }
        return []Post{}, nil
    }

下面就是单元测试函数:

    package main
 
    import (
        "errors"
        "testing"
 
        "github.com/stretchr/testify/assert"
        "github.com/tinrab/kit"
    )

    func TestBlog(t *testing.T) {
        di := kit.NewDependencyInjection()
        
        di.Provide("database", NewMySQLDatabase("root:123456@tcp(127.0.0.1:3306)/blog"))
        di.Provide("user.repository", &userRepositoryStub{})
        di.Provide("post.repository", &postRepositoryStub{})
        di.Provide("blog.service", NewBlogService())

        if err := di.Resolve(); err != nil {
            t.Fatal(err)
        }

        blogService := di.Get("blog.service").(BlogService)
        profile, err := blogService.GetUserProfile(1)
        if err != nil {
            t.Fatal(err)
        }

        assert.Equal(t, "User1", profile.User.Name)
        assert.Equal(t, uint64(1), profile.Posts[0].UserID)
        assert.Equal(t, "A", profile.Posts[0].Title)
        assert.Equal(t, "A", profile.Posts[0].Body)
        assert.Equal(t, uint64(1), profile.Posts[1].UserID)
        assert.Equal(t, "B", profile.Posts[1].Title)
        assert.Equal(t, "B", profile.Posts[1].Body)
}

小结

  1. 通过对上面文章的翻译,使自己对go的接口的使用有了更深入的理解。接口可以隐藏实现,从而提高 系统的灵活性,能方便更新接口的实现方案,这要求我们在编写代码时尽量使用接口作为参数,或者结构体属性, 从而实现接口与实现的解耦。

  2. 对go语言的反射有了一定的认识,如果查看文章中使用的抵赖注入源码,我会看到它是利用reflect,结合结构体的tag机制来实现 依赖的自动注入的,如下:

    func (c *Container) inject(obj interface{}) {
        t := reflect.TypeOf(obj).Elem()
        for i := 0; i < t.NumField(); i++ {
            field := t.Field(i)
            inject := field.Tag.Get("inject")
            if inject == "" {
                continue
            }
            dependency := c.GetByName(inject)
            if dependency != nil {
                reflect.ValueOf(obj).Elem().Field(i).Set(reflect.ValueOf(dependency))
            }
        }
    }

上面的处理过程就是: 通过 reflect.TypeOf(obj).Elem() 获取到结构体类型信息,再迭代每一个结构体里面的属性,如果属性存在 inject tag, 那么获取到 reflect tag 的值,然后通过容器的 getByName 方法获取到实际的依赖,最后通过 reflect.ValueOf(obj).Elem().Field(i).Set(reflect.ValueOf(dependency)) 更新对应属性的值,从而完成依赖的注入。