Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

database/sql/driver - 数据库驱动接口

概述

database/sql/driver 包定义了数据库驱动需要实现的接口。

重要说明

  • ⚠️ 驱动开发者使用:普通应用开发者不需要直接使用
  • ⚠️ 底层接口:配合 database/sql 包使用
  • 统一标准:所有数据库驱动实现相同的接口

主要用途

  • 🛠️ 编写数据库驱动:实现特定数据库的驱动
  • 🔍 理解驱动原理:了解 database/sql 的工作机制
  • 🔧 自定义驱动:为特殊数据库或协议实现驱动

与 database/sql 的关系

  • database/sql:面向应用开发者的通用接口
  • database/sql/driver:面向驱动开发者的底层接口
  • 驱动实现 driver 接口,sql 包调用这些接口

核心接口

1. Driver - 驱动接口

type Driver interface {
    Open(name string) (Conn, error)
}

功能:数据库驱动的根接口。

方法说明

  • Open(name string):打开数据库连接
    • 参数 name:数据源名称(DSN)
    • 返回 Conn:数据库连接
    • 返回 error:错误信息

实现示例

type mysqlDriver struct{}

func (d *mysqlDriver) Open(name string) (driver.Conn, error) {
    // 解析 DSN,创建连接
    return &mysqlConn{dsn: name}, nil
}

2. Conn - 连接接口

type Conn interface {
    Prepare(query string) (Stmt, error)
    Close() error
    Begin() (Tx, error)
}

功能:表示数据库连接。

方法说明

  • Prepare(query string):准备预处理语句
  • Close() error:关闭连接
  • Begin() error:开始事务

扩展接口(Go 1.8+):

// 支持 Context
type ConnBeginTx interface {
    BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
}

// 支持预处理语句缓存
type ConnPrepareContext interface {
    PrepareContext(ctx context.Context, query string) (Stmt, error)
}

// 支持 Ping
type Pinger interface {
    Ping(ctx context.Context) error
}

3. Stmt - 预处理语句接口

type Stmt interface {
    Close() error
    NumInput() int
    Exec(args []Value) (Result, error)
    Query(args []Value) (Rows, error)
}

功能:表示预处理的 SQL 语句。

方法说明

  • Close() error:关闭语句
  • NumInput() int:返回参数数量(-1 表示未知)
  • Exec(args []Value):执行语句(INSERT/UPDATE/DELETE)
  • Query(args []Value):执行查询(SELECT)

扩展接口(Go 1.8+):

// 支持 Context
type StmtExecContext interface {
    ExecContext(ctx context.Context, args []NamedValue) (Result, error)
}

type StmtQueryContext interface {
    QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}

// 支持列信息
type RowsColumnTypeDatabaseTypeName interface {
    ColumnTypeDatabaseTypeName(index int) string
}

type RowsColumnTypeLength interface {
    ColumnTypeLength(index int) (length int64, ok bool)
}

type RowsColumnTypeNullable interface {
    ColumnTypeNullable(index int) (nullable, ok bool)
}

type RowsColumnTypePrecisionScale interface {
    ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
}

type RowsColumnTypeScanType interface {
    ColumnTypeScanType(index int) reflect.Type
}

4. Rows - 结果集接口

type Rows interface {
    Columns() []string
    Close() error
    Next(dest []Value) error
}

功能:表示查询结果集。

方法说明

  • Columns() []string:返回列名
  • Close() error:关闭结果集
  • Next(dest []Value) error:移动到下一行并填充数据

扩展接口(Go 1.8+):

// 支持可扫描的结果集
type RowsNextResultSet interface {
    HasNextResultSet() bool
    NextResultSet() error
}

// 支持最后插入 ID
type RowsLastInsertId interface {
    LastInsertId() (int64, error)
}

// 支持受影响行数
type RowsAffected interface {
    RowsAffected() (int64, error)
}

5. Tx - 事务接口

type Tx interface {
    Commit() error
    Rollback() error
}

功能:表示数据库事务。

方法说明

  • Commit() error:提交事务
  • Rollback() error:回滚事务

6. Result - 结果接口

type Result interface {
    LastInsertId() (int64, error)
    RowsAffected() (int64, error)
}

功能:表示执行结果。

方法说明

  • LastInsertId(): 返回最后插入的 ID
  • RowsAffected(): 返回受影响的行数

7. Value - 值类型

type Value interface{}

功能:表示数据库值。

允许的类型

  • []byte(用于二进制数据)
  • bool
  • float64
  • int64
  • string
  • time.Time(用于日期时间)
  • driver.Value(用于自定义类型)

8. NamedValue - 命名参数

type NamedValue struct {
    Name    string // 参数名称
    Ordinal int    // 参数位置(从 1 开始)
    Value   Value  // 参数值
}

功能:表示命名参数。

使用场景

-- 位置参数
SELECT * FROM users WHERE id = ?

-- 命名参数
SELECT * FROM users WHERE id = @id

9. TxOptions - 事务选项

type TxOptions struct {
    Isolation IsolationLevel
    ReadOnly  bool
}

字段说明

  • Isolation:事务隔离级别
  • ReadOnly:是否只读事务

10. IsolationLevel - 隔离级别

type IsolationLevel int

const (
    LevelDefault IsolationLevel = iota
    LevelReadUncommitted
    LevelReadCommitted
    LevelWriteCommitted
    LevelRepeatableRead
    LevelSnapshot
    LevelSerializable
    LevelLinearizable
)

隔离级别说明

  • LevelDefault:默认级别(由数据库决定)
  • LevelReadUncommitted:读未提交(最低)
  • LevelReadCommitted:读已提交
  • LevelWriteCommitted:写已提交
  • LevelRepeatableRead:可重复读
  • LevelSnapshot:快照隔离
  • LevelSerializable:可串行化(最高)
  • LevelLinearizable:线性化

实现数据库驱动

示例 1:最小化驱动实现

package main

import (
    "database/sql/driver"
    "fmt"
    "log"
)

// 1. 实现 Driver 接口
type simpleDriver struct{}

func (d *simpleDriver) Open(name string) (driver.Conn, error) {
    fmt.Printf("打开连接:%s\n", name)
    return &simpleConn{name: name}, nil
}

// 2. 实现 Conn 接口
type simpleConn struct {
    name   string
    closed bool
}

func (c *simpleConn) Prepare(query string) (driver.Stmt, error) {
    fmt.Printf("准备语句:%s\n", query)
    return &simpleStmt{query: query}, nil
}

func (c *simpleConn) Close() error {
    c.closed = true
    fmt.Println("关闭连接")
    return nil
}

func (c *simpleConn) Begin() (driver.Tx, error) {
    fmt.Println("开始事务")
    return &simpleTx{}, nil
}

// 3. 实现 Stmt 接口
type simpleStmt struct {
    query string
}

func (s *simpleStmt) Close() error {
    fmt.Println("关闭语句")
    return nil
}

func (s *simpleStmt) NumInput() int {
    // 返回 -1 表示不检查参数数量
    return -1
}

func (s *simpleStmt) Exec(args []driver.Value) (driver.Result, error) {
    fmt.Printf("执行:%s, 参数:%v\n", s.query, args)
    return &simpleResult{}, nil
}

func (s *simpleStmt) Query(args []driver.Value) (driver.Rows, error) {
    fmt.Printf("查询:%s, 参数:%v\n", s.query, args)
    return &simpleRows{
        columns: []string{"id", "name"},
        data: [][]driver.Value{
            {int64(1), "Alice"},
            {int64(2), "Bob"},
        },
    }, nil
}

// 4. 实现 Rows 接口
type simpleRows struct {
    columns []string
    data    [][]driver.Value
    pos     int
}

func (r *simpleRows) Columns() []string {
    return r.columns
}

func (r *simpleRows) Close() error {
    fmt.Println("关闭结果集")
    return nil
}

func (r *simpleRows) Next(dest []driver.Value) error {
    if r.pos >= len(r.data) {
        return driver.ErrNoRows
    }
    
    row := r.data[r.pos]
    r.pos++
    
    // 复制数据到目标
    for i, v := range row {
        dest[i] = v
    }
    
    return nil
}

// 5. 实现 Tx 接口
type simpleTx struct{}

func (t *simpleTx) Commit() error {
    fmt.Println("提交事务")
    return nil
}

func (t *simpleTx) Rollback() error {
    fmt.Println("回滚事务")
    return nil
}

// 6. 实现 Result 接口
type simpleResult struct{}

func (r *simpleResult) LastInsertId() (int64, error) {
    return 1, nil
}

func (r *simpleResult) RowsAffected() (int64, error) {
    return 1, nil
}

// 7. 注册驱动
func init() {
    sql.Register("simple", &simpleDriver{})
}

func main() {
    // 使用自定义驱动
    db, err := sql.Open("simple", "test-db")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 查询
    rows, err := db.Query("SELECT * FROM users")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
    
    for rows.Next() {
        var id int
        var name string
        rows.Scan(&id, &name)
        fmt.Printf("用户:%d, %s\n", id, name)
    }
}

示例 2:支持 Context 的驱动

package main

import (
    "context"
    "database/sql/driver"
    "fmt"
    "time"
)

// 实现支持 Context 的连接
type contextConn struct {
    name   string
    closed bool
}

// 实现 ConnBeginTx 接口
func (c *contextConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    fmt.Printf("开始事务(隔离级别:%v, 只读:%v)\n", opts.Isolation, opts.ReadOnly)
    
    // 检查 context 是否已取消
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    default:
    }
    
    return &contextTx{ctx: ctx}, nil
}

// 实现 ConnPrepareContext 接口
func (c *contextConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
    fmt.Printf("准备语句:%s\n", query)
    
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    default:
        return &contextStmt{ctx: ctx, query: query}, nil
    }
}

// 实现 Pinger 接口
func (c *contextConn) Ping(ctx context.Context) error {
    fmt.Println("Ping 数据库")
    
    // 模拟网络延迟
    select {
    case <-ctx.Done():
        return ctx.Err()
    case <-time.After(100 * time.Millisecond):
        return nil
    }
}

// 实现 Tx 接口
type contextTx struct {
    ctx context.Context
}

func (t *contextTx) Commit() error {
    select {
    case <-t.ctx.Done():
        return t.ctx.Err()
    default:
        fmt.Println("提交事务")
        return nil
    }
}

func (t *contextTx) Rollback() error {
    fmt.Println("回滚事务")
    return nil
}

// 实现 Stmt 接口
type contextStmt struct {
    ctx   context.Context
    query string
}

func (s *contextStmt) Close() error {
    return nil
}

func (s *contextStmt) NumInput() int {
    return -1
}

// 实现 StmtExecContext 接口
func (s *contextStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    case <-s.ctx.Done():
        return nil, s.ctx.Err()
    default:
        fmt.Printf("执行:%s\n", s.query)
        return &simpleResult{}, nil
    }
}

// 实现 StmtQueryContext 接口
func (s *contextStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    case <-s.ctx.Done():
        return nil, s.ctx.Err()
    default:
        fmt.Printf("查询:%s\n", s.query)
        return &contextRows{ctx: ctx}, nil
    }
}

// 实现 Rows 接口
type contextRows struct {
    ctx context.Context
    pos int
}

func (r *contextRows) Columns() []string {
    return []string{"id", "name"}
}

func (r *contextRows) Close() error {
    return nil
}

func (r *contextRows) Next(dest []driver.Value) error {
    select {
    case <-r.ctx.Done():
        return r.ctx.Err()
    default:
        if r.pos >= 2 {
            return driver.ErrNoRows
        }
        
        dest[0] = int64(r.pos + 1)
        dest[1] = fmt.Sprintf("User%d", r.pos+1)
        r.pos++
        return nil
    }
}

示例 3:支持命名参数

package main

import (
    "database/sql/driver"
    "fmt"
    "regexp"
    "strings"
)

// 实现支持命名参数的语句
type namedStmt struct {
    query string
}

func (s *namedStmt) Close() error {
    return nil
}

func (s *namedStmt) NumInput() int {
    return -1
}

func (s *namedStmt) Exec(args []driver.Value) (driver.Result, error) {
    return s.ExecContext(context.Background(), s.valuesToNamedValues(args))
}

func (s *namedStmt) Query(args []driver.Value) (driver.Rows, error) {
    return s.QueryContext(context.Background(), s.valuesToNamedValues(args))
}

// 实现 StmtExecContext 接口
func (s *namedStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
    // 将命名参数转换为位置参数
    query, convertedArgs := s.convertNamedToPositional(args)
    fmt.Printf("执行:%s, 参数:%v\n", query, convertedArgs)
    return &simpleResult{}, nil
}

// 实现 StmtQueryContext 接口
func (s *namedStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    // 将命名参数转换为位置参数
    query, convertedArgs := s.convertNamedToPositional(args)
    fmt.Printf("查询:%s, 参数:%v\n", query, convertedArgs)
    return &simpleRows{}, nil
}

// 转换命名参数为位置参数
func (s *namedStmt) convertNamedToPositional(args []driver.NamedValue) (string, []driver.Value) {
    query := s.query
    convertedArgs := make([]driver.Value, 0, len(args))
    
    // 创建参数映射
    paramMap := make(map[string]int)
    for i, arg := range args {
        if arg.Name != "" {
            paramMap[arg.Name] = i
        }
    }
    
    // 替换命名参数
    re := regexp.MustCompile(`@(\w+)`)
    query = re.ReplaceAllStringFunc(query, func(match string) string {
        paramName := strings.TrimPrefix(match, "@")
        if idx, ok := paramMap[paramName]; ok {
            convertedArgs = append(convertedArgs, args[idx].Value)
            return "?"
        }
        return match
    })
    
    return query, convertedArgs
}

func (s *namedStmt) valuesToNamedValues(args []driver.Value) []driver.NamedValue {
    namedArgs := make([]driver.NamedValue, len(args))
    for i, arg := range args {
        namedArgs[i] = driver.NamedValue{
            Ordinal: i + 1,
            Value:   arg,
        }
    }
    return namedArgs
}

示例 4:实现完整的 MySQL 风格驱动

package main

import (
    "context"
    "database/sql"
    "database/sql/driver"
    "encoding/binary"
    "fmt"
    "net"
    "strconv"
    "strings"
    "time"
)

// MySQL 驱动
type mysqlDriver struct{}

func (d *mysqlDriver) Open(name string) (driver.Conn, error) {
    // 解析 DSN
    cfg, err := parseDSN(name)
    if err != nil {
        return nil, err
    }
    
    // 建立网络连接
    conn, err := net.Dial("tcp", cfg.Addr)
    if err != nil {
        return nil, err
    }
    
    return &mysqlConn{
        conn:   conn,
        cfg:    cfg,
        closed: false,
    }, nil
}

// 解析 DSN
func parseDSN(dsn string) (*mysqlConfig, error) {
    // 格式:user:pass@tcp(host:port)/db?params
    cfg := &mysqlConfig{
        User: "root",
        Addr: "localhost:3306",
    }
    
    // 简单解析(实际实现需要更复杂)
    parts := strings.Split(dsn, "@")
    if len(parts) >= 2 {
        authParts := strings.Split(parts[0], ":")
        if len(authParts) == 2 {
            cfg.User = authParts[0]
            cfg.Passwd = authParts[1]
        }
        
        addrParts := strings.Split(parts[1], "/")
        if len(addrParts) >= 2 {
            cfg.Addr = strings.Trim(addrParts[0], "()")
            cfg.DBName = addrParts[1]
        }
    }
    
    return cfg, nil
}

type mysqlConfig struct {
    User   string
    Passwd string
    Addr   string
    DBName string
}

// MySQL 连接
type mysqlConn struct {
    conn   net.Conn
    cfg    *mysqlConfig
    closed bool
}

func (c *mysqlConn) Prepare(query string) (driver.Stmt, error) {
    if c.closed {
        return nil, driver.ErrBadConn
    }
    
    return &mysqlStmt{conn: c, query: query}, nil
}

func (c *mysqlConn) Close() error {
    if c.closed {
        return nil
    }
    
    c.closed = true
    return c.conn.Close()
}

func (c *mysqlConn) Begin() (driver.Tx, error) {
    if c.closed {
        return nil, driver.ErrBadConn
    }
    
    // 执行 BEGIN 命令
    _, err := c.conn.Write([]byte("BEGIN\n"))
    if err != nil {
        return nil, err
    }
    
    return &mysqlTx{conn: c}, nil
}

func (c *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    if c.closed {
        return nil, driver.ErrBadConn
    }
    
    // 检查 context
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    default:
    }
    
    // 设置隔离级别
    if opts.Isolation != driver.LevelDefault {
        isolationSQL := fmt.Sprintf("SET TRANSACTION ISOLATION LEVEL %s\n", 
            isolationLevelToString(opts.Isolation))
        c.conn.Write([]byte(isolationSQL))
    }
    
    return c.Begin()
}

func (c *mysqlConn) Ping(ctx context.Context) error {
    if c.closed {
        return driver.ErrBadConn
    }
    
    select {
    case <-ctx.Done():
        return ctx.Err()
    default:
        // 发送 ping 命令
        c.conn.Write([]byte("SELECT 1\n"))
        return nil
    }
}

// MySQL 语句
type mysqlStmt struct {
    conn  *mysqlConn
    query string
}

func (s *mysqlStmt) Close() error {
    return nil
}

func (s *mysqlStmt) NumInput() int {
    return -1
}

func (s *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
    return s.ExecContext(context.Background(), valuesToNamedValues(args))
}

func (s *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
    if s.conn.closed {
        return nil, driver.ErrBadConn
    }
    
    // 构建 SQL
    query := s.buildQuery(args)
    
    // 发送查询
    _, err := s.conn.conn.Write([]byte(query + "\n"))
    if err != nil {
        return nil, err
    }
    
    // 读取响应(简化)
    return &mysqlResult{}, nil
}

func (s *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
    return s.QueryContext(context.Background(), valuesToNamedValues(args))
}

func (s *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    if s.conn.closed {
        return nil, driver.ErrBadConn
    }
    
    // 构建 SQL
    query := s.buildQuery(args)
    
    // 发送查询
    _, err := s.conn.conn.Write([]byte(query + "\n"))
    if err != nil {
        return nil, err
    }
    
    // 读取结果(简化)
    return &mysqlRows{conn: s.conn}, nil
}

func (s *mysqlStmt) buildQuery(args []driver.NamedValue) string {
    // 替换参数
    query := s.query
    for _, arg := range args {
        query = strings.Replace(query, "?", formatValue(arg.Value), 1)
    }
    return query
}

func formatValue(v driver.Value) string {
    switch v := v.(type) {
    case string:
        return "'" + strings.Replace(v, "'", "''", -1) + "'"
    case int64:
        return strconv.FormatInt(v, 10)
    case float64:
        return strconv.FormatFloat(v, 'g', -1, 64)
    case []byte:
        return fmt.Sprintf("0x%x", v)
    case time.Time:
        return fmt.Sprintf("'%s'", v.Format("2006-01-02 15:04:05"))
    default:
        return "NULL"
    }
}

// MySQL 事务
type mysqlTx struct {
    conn *mysqlConn
}

func (t *mysqlTx) Commit() error {
    if t.conn.closed {
        return driver.ErrBadConn
    }
    
    _, err := t.conn.conn.Write([]byte("COMMIT\n"))
    return err
}

func (t *mysqlTx) Rollback() error {
    if t.conn.closed {
        return driver.ErrBadConn
    }
    
    _, err := t.conn.conn.Write([]byte("ROLLBACK\n"))
    return err
}

// MySQL 结果
type mysqlResult struct{}

func (r *mysqlResult) LastInsertId() (int64, error) {
    return 0, nil
}

func (r *mysqlResult) RowsAffected() (int64, error) {
    return 1, nil
}

// MySQL 结果集
type mysqlRows struct {
    conn    *mysqlConn
    columns []string
    pos     int
}

func (r *mysqlRows) Columns() []string {
    return r.columns
}

func (r *mysqlRows) Close() error {
    return nil
}

func (r *mysqlRows) Next(dest []driver.Value) error {
    // 简化实现
    return driver.ErrNoRows
}

// 辅助函数
func valuesToNamedValues(args []driver.Value) []driver.NamedValue {
    namedArgs := make([]driver.NamedValue, len(args))
    for i, arg := range args {
        namedArgs[i] = driver.NamedValue{
            Ordinal: i + 1,
            Value:   arg,
        }
    }
    return namedArgs
}

func isolationLevelToString(level driver.IsolationLevel) string {
    switch level {
    case driver.LevelReadUncommitted:
        return "READ UNCOMMITTED"
    case driver.LevelReadCommitted:
        return "READ COMMITTED"
    case driver.LevelRepeatableRead:
        return "REPEATABLE READ"
    case driver.LevelSerializable:
        return "SERIALIZABLE"
    default:
        return "REPEATABLE READ"
    }
}

// 注册驱动
func init() {
    sql.Register("mysql-custom", &mysqlDriver{})
}

驱动注册和使用

示例 5:注册和初始化驱动

package main

import (
    "database/sql"
    "database/sql/driver"
    "fmt"
    "log"
    "sync"
)

// 自定义驱动
type customDriver struct {
    mu sync.Mutex
}

func (d *customDriver) Open(name string) (driver.Conn, error) {
    d.mu.Lock()
    defer d.mu.Unlock()
    return &customConn{name: name}, nil
}

type customConn struct {
    name string
}

func (c *customConn) Prepare(query string) (driver.Stmt, error) {
    return &customStmt{query: query}, nil
}

func (c *customConn) Close() error {
    return nil
}

func (c *customConn) Begin() (driver.Tx, error) {
    return &customTx{}, nil
}

type customStmt struct {
    query string
}

func (s *customStmt) Close() error {
    return nil
}

func (s *customStmt) NumInput() int {
    return -1
}

func (s *customStmt) Exec(args []driver.Value) (driver.Result, error) {
    fmt.Printf("执行:%s\n", s.query)
    return &customResult{}, nil
}

func (s *customStmt) Query(args []driver.Value) (driver.Rows, error) {
    fmt.Printf("查询:%s\n", s.query)
    return &customRows{}, nil
}

type customTx struct{}

func (t *customTx) Commit() error {
    return nil
}

func (t *customTx) Rollback() error {
    return nil
}

type customResult struct{}

func (r *customResult) LastInsertId() (int64, error) {
    return 1, nil
}

func (r *customResult) RowsAffected() (int64, error) {
    return 1, nil
}

type customRows struct{}

func (r *customRows) Columns() []string {
    return []string{"id", "name"}
}

func (r *customRows) Close() error {
    return nil
}

func (r *customRows) Next(dest []driver.Value) error {
    return driver.ErrNoRows
}

func main() {
    // 1. 注册驱动
    sql.Register("custom", &customDriver{})
    
    // 2. 打开数据库
    db, err := sql.Open("custom", "test-db")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 3. 使用数据库
    rows, err := db.Query("SELECT * FROM users")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
    
    fmt.Println("驱动使用成功")
}

最佳实践

✅ 推荐做法

  1. 始终实现 Context 接口

    // ✅ 实现这些接口以支持 Context
    type Conn interface {
        driver.Conn
        driver.ConnBeginTx
        driver.ConnPrepareContext
        driver.Pinger
    }
    
  2. 检查连接状态

    func (c *conn) Prepare(query string) (driver.Stmt, error) {
        if c.closed {
            return nil, driver.ErrBadConn
        }
        // ...
    }
    
  3. 正确处理 Context 取消

    func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
        select {
        case <-ctx.Done():
            return nil, ctx.Err()
        default:
            // 继续执行
        }
    }
    
  4. 实现所有可选接口

    // ✅ 实现所有相关接口以获得完整功能
    type Rows interface {
        driver.Rows
        driver.RowsNextResultSet
        driver.RowsColumnTypeDatabaseTypeName
        driver.RowsColumnTypeLength
        driver.RowsColumnTypeNullable
        driver.RowsColumnTypePrecisionScale
        driver.RowsColumnTypeScanType
    }
    

❌ 不安全做法

  1. 不要忽略 Context

    // ❌ 不支持 Context
    func (s *stmt) Query(args []driver.Value) (driver.Rows, error)
    
    // ✅ 支持 Context
    func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)
    
  2. 不要返回无效的连接

    // ❌ 返回已关闭的连接
    if c.closed {
        return c, nil // 错误!
    }
    
    // ✅ 返回错误
    if c.closed {
        return nil, driver.ErrBadConn
    }
    

总结

核心接口

// 必须实现的核心接口
Driver                          // 驱动入口
Conn                            // 数据库连接
Stmt                            // 预处理语句
Rows                            // 结果集
Tx                              // 事务
Result                          // 执行结果

// 可选实现的扩展接口(Go 1.8+)
ConnBeginTx                     // 支持事务选项
ConnPrepareContext              // 支持 Context 预处理
Pinger                          // 支持 Ping
StmtExecContext                 // 支持 Context 执行
StmtQueryContext                // 支持 Context 查询
RowsNextResultSet               // 支持多结果集
RowsColumnType*                 // 支持列类型信息

接口关系

Driver
  └─> Open() → Conn
                ├─> Prepare() → Stmt
                │                ├─> Query() → Rows
                │                └─> Exec() → Result
                ├─> Begin() → Tx
                │                ├─> Commit()
                │                └─> Rollback()
                └─> Close()

数据类型映射

Go 类型driver.ValueSQL 类型
int64int64INT, BIGINT
float64float64FLOAT, DOUBLE
stringstringVARCHAR, TEXT
boolboolBOOLEAN
time.Timetime.TimeDATETIME, TIMESTAMP
[]byte[]byteBLOB, BINARY

实现检查清单

  • 实现 Driver 接口
  • 实现 Conn 接口(包括扩展接口)
  • 实现 Stmt 接口(包括扩展接口)
  • 实现 Rows 接口(包括扩展接口)
  • 实现 Tx 接口
  • 实现 Result 接口
  • 注册驱动(sql.Register
  • 实现 Context 支持
  • 处理连接状态检查
  • 处理错误和超时

参考资料


最后更新:2026-04-03
Go 版本:Go 1.23+