Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

database/sql - 数据库操作

概述

database/sql 包提供了围绕 SQL(或类似 SQL)数据库的通用接口。

重要说明

  • ⚠️ 仅提供通用接口:不实现具体数据库驱动
  • ⚠️ 需要驱动:必须配合具体数据库驱动使用(如 github.com/go-sql-driver/mysql
  • 统一 API:所有数据库使用相同的接口

主要用途

  • 🔗 数据库连接管理:连接池、连接生命周期
  • 📝 执行 SQL 语句:查询、插入、更新、删除
  • 📊 处理结果集:行遍历、列扫描
  • 🔄 事务管理:事务开始、提交、回滚
  • 🛡️ 预处理语句:防止 SQL 注入

核心类型

1. DB - 数据库对象

type DB struct {
    // 包含过滤或未导出的字段
}

功能:表示数据库连接池,不是单个连接。

特点

  • 线程安全:多个 goroutine 可同时使用
  • 连接池:自动管理连接
  • 延迟连接:创建时不建立实际连接
  • ⚠️ 需要关闭:使用 defer db.Close()

创建方法

db, err := sql.Open("driver-name", "data-source-name")

重要方法

// 连接管理
func (db *DB) Close() error
func (db *DB) Ping() error
func (db *DB) SetMaxOpenConns(n int)
func (db *DB) SetMaxIdleConns(n int)
func (db *DB) SetConnMaxLifetime(d time.Duration)

// 查询
func (db *DB) Query(query string, args ...interface{}) (*Rows, error)
func (db *DB) QueryRow(query string, args ...interface{}) *Row
func (db *DB) Exec(query string, args ...interface{}) (Result, error)

// 预处理
func (db *DB) Prepare(query string) (*Stmt, error)

// 事务
func (db *DB) Begin() (*Tx, error)
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)

2. Rows - 结果集

type Rows struct {
    // 包含过滤或未导出的字段
}

功能:表示查询结果集。

特点

  • 延迟加载:数据按需读取
  • 需要关闭:使用 defer rows.Close()
  • ⚠️ 单向遍历:只能向前遍历

重要方法

// 遍历
func (rs *Rows) Next() bool
func (rs *Rows) Scan(dest ...interface{}) error
func (rs *Rows) Close() error

// 错误和统计
func (rs *Rows) Err() error
func (rs *Rows) Columns() ([]string, error)
func (rs *Rows) ColumnTypes() ([]*ColumnType, error)

// 游标
func (rs *Rows) NextResultSet() bool

3. Row - 单行结果

type Row struct {
    // 包含过滤或未导出的字段
}

功能:表示查询结果的单行。

重要方法

func (r *Row) Scan(dest ...interface{}) error
func (r *Row) Err() error

4. Stmt - 预处理语句

type Stmt struct {
    // 包含过滤或未导出的字段
}

功能:表示预处理的 SQL 语句。

特点

  • 防止 SQL 注入:参数化查询
  • 提高性能:重复执行相同语句
  • ⚠️ 需要关闭:使用 defer stmt.Close()

重要方法

func (s *Stmt) Exec(args ...interface{}) (Result, error)
func (s *Stmt) Query(args ...interface{}) (*Rows, error)
func (s *Stmt) QueryRow(args ...interface{}) *Row
func (s *Stmt) Close() error

5. Tx - 事务

type Tx struct {
    // 包含过滤或未导出的字段
}

功能:表示数据库事务。

特点

  • 原子性:所有操作成功或全部失败
  • ⚠️ 需要结束:必须 Commit 或 Rollback
  • ⚠️ 不能跨事务使用:Tx 上的 Stmt 仅在该事务中有效

重要方法

func (tx *Tx) Commit() error
func (tx *Tx) Rollback() error
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error)
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error)
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row
func (tx *Tx) Prepare(query string) (*Stmt, error)

6. Result - 执行结果

type Result interface {
    LastInsertId() (int64, error)
    RowsAffected() (int64, error)
}

功能:表示 SQL 执行结果。


7. ColumnType - 列类型信息

type ColumnType struct {
    // 包含过滤或未导出的字段
}

重要方法

func (ci *ColumnType) Name() string
func (ci *ColumnType) ScanType() reflect.Type
func (ci *ColumnType) DatabaseTypeName() string
func (ci *ColumnType) Length() (length int64, ok bool)
func (ci *ColumnType) Precision() (precision, scale int64, ok bool)
func (ci *ColumnType) Nullable() (nullable, ok bool)

数据库连接管理

示例 1:打开数据库连接

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql" // MySQL 驱动
)

func main() {
    // 1. 打开数据库连接
    // 格式:用户名:密码@协议 (地址)/数据库名?参数
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 2. 验证连接
    err = db.Ping()
    if err != nil {
        log.Fatal("连接失败:", err)
    }
    
    fmt.Println("✓ 数据库连接成功")
}

示例 2:配置连接池

package main

import (
    "database/sql"
    "fmt"
    "log"
    "time"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4&parseTime=True"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 配置连接池
    db.SetMaxOpenConns(25)     // 最大打开连接数
    db.SetMaxIdleConns(5)      // 最大空闲连接数
    db.SetConnMaxLifetime(5 * time.Minute) // 连接最大生命周期
    
    // 2. 验证连接
    err = db.Ping()
    if err != nil {
        log.Fatal(err)
    }
    
    fmt.Println("✓ 连接池配置成功")
    
    // 3. 查看连接池统计
    stats := db.Stats()
    fmt.Printf("最大打开连接数:%d\n", stats.MaxOpenConnections)
    fmt.Printf("当前打开连接数:%d\n", stats.OpenConnections)
    fmt.Printf("当前空闲连接数:%d\n", stats.Idle)
}

连接池参数说明

  • SetMaxOpenConns:最大打开连接数(默认无限制,推荐 25-100)
  • SetMaxIdleConns:最大空闲连接数(默认 2,推荐 5-10)
  • SetConnMaxLifetime:连接最大生命周期(防止连接老化,推荐 5-30 分钟)

示例 3:连接生命周期管理

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 配置连接池
    db.SetMaxOpenConns(10)
    db.SetMaxIdleConns(5)
    db.SetConnMaxLifetime(30 * time.Minute)
    
    // 模拟并发请求
    for i := 0; i < 5; i++ {
        go func(id int) {
            // 使用 context 控制查询超时
            ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
            defer cancel()
            
            var result int
            err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
            if err != nil {
                log.Printf("查询失败:%v", err)
                return
            }
            
            fmt.Printf("Goroutine %d: 查询成功\n", id)
            
            // 显示连接池状态
            stats := db.Stats()
            fmt.Printf("  打开连接:%d, 空闲连接:%d\n", stats.OpenConnections, stats.Idle)
        }(i)
    }
    
    // 等待一段时间
    time.Sleep(2 * time.Second)
    
    // 显示最终统计
    stats := db.Stats()
    fmt.Printf("\n最终统计:\n")
    fmt.Printf("  最大打开连接:%d\n", stats.MaxOpenConnections)
    fmt.Printf("  当前打开连接:%d\n", stats.OpenConnections)
    fmt.Printf("  当前空闲连接:%d\n", stats.Idle)
    fmt.Printf("  总请求数:%d\n", stats.Requests)
    
    time.Sleep(1 * time.Second)
}

基本查询操作

示例 4:查询单行数据(QueryRow)

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

type User struct {
    ID       int
    Name     string
    Email    string
    Age      int
}

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 查询单行
    var user User
    err = db.QueryRow("SELECT id, name, email, age FROM users WHERE id = ?", 1).
        Scan(&user.ID, &user.Name, &user.Email, &user.Age)
    
    if err == sql.ErrNoRows {
        fmt.Println("未找到记录")
        return
    }
    if err != nil {
        log.Fatal(err)
    }
    
    fmt.Printf("用户:%+v\n", user)
    
    // 2. 查询聚合值
    var count int
    err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("用户总数:%d\n", count)
    
    // 3. 查询可空字段(使用 sql.NullString)
    var nullableEmail sql.NullString
    err = db.QueryRow("SELECT email FROM users WHERE id = ?", 1).Scan(&nullableEmail)
    if err != nil {
        log.Fatal(err)
    }
    
    if nullableEmail.Valid {
        fmt.Printf("邮箱:%s\n", nullableEmail.String)
    } else {
        fmt.Println("邮箱:NULL")
    }
}

示例 5:查询多行数据(Query)

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

type User struct {
    ID       int
    Name     string
    Email    sql.NullString // 处理 NULL 值
    Age      int
}

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 基本查询
    rows, err := db.Query("SELECT id, name, email, age FROM users ORDER BY id")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close() // ⚠️ 必须关闭
    
    var users []User
    for rows.Next() {
        var user User
        err := rows.Scan(&user.ID, &user.Name, &user.Email, &user.Age)
        if err != nil {
            log.Fatal(err)
        }
        users = append(users, user)
    }
    
    // 检查遍历过程中的错误
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }
    
    fmt.Printf("找到 %d 个用户\n", len(users))
    for _, user := range users {
        fmt.Printf("  ID: %d, 姓名:%s", user.ID, user.Name)
        if user.Email.Valid {
            fmt.Printf(", 邮箱:%s", user.Email.String)
        }
        fmt.Printf(", 年龄:%d\n", user.Age)
    }
    
    // 2. 带条件查询
    rows, err = db.Query("SELECT id, name, age FROM users WHERE age > ? AND age < ?", 18, 30)
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
    
    fmt.Println("\n18-30 岁的用户:")
    for rows.Next() {
        var id, age int
        var name string
        rows.Scan(&id, &name, &age)
        fmt.Printf("  %s (%d 岁)\n", name, age)
    }
}

示例 6:处理 NULL 值

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

type UserProfile struct {
    ID         int
    Name       string
    Email      sql.NullString
    Phone      sql.NullString
    Age        sql.NullInt64
    Score      sql.NullFloat64
    Active     sql.NullBool
}

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    rows, err := db.Query("SELECT id, name, email, phone, age, score, active FROM users")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
    
    for rows.Next() {
        var profile UserProfile
        err := rows.Scan(
            &profile.ID,
            &profile.Name,
            &profile.Email,
            &profile.Phone,
            &profile.Age,
            &profile.Score,
            &profile.Active,
        )
        if err != nil {
            log.Fatal(err)
        }
        
        fmt.Printf("用户:%s\n", profile.Name)
        
        // 处理 NULL 值
        if profile.Email.Valid {
            fmt.Printf("  邮箱:%s\n", profile.Email.String)
        } else {
            fmt.Printf("  邮箱:未提供\n")
        }
        
        if profile.Phone.Valid {
            fmt.Printf("  电话:%s\n", profile.Phone.String)
        }
        
        if profile.Age.Valid {
            fmt.Printf("  年龄:%d\n", profile.Age.Int64)
        }
        
        if profile.Score.Valid {
            fmt.Printf("  分数:%.2f\n", profile.Score.Float64)
        }
        
        if profile.Active.Valid {
            fmt.Printf("  状态:%v\n", profile.Active.Bool)
        }
        
        fmt.Println()
    }
}

sql.Null 类型*:

  • sql.NullString:可空字符串
  • sql.NullInt64:可空整数
  • sql.NullFloat64:可空浮点数
  • sql.NullBool:可空布尔值
  • sql.NullTime:可空时间

数据修改操作

示例 7:插入数据

package main

import (
    "database/sql"
    "fmt"
    "log"
    "time"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 插入单条记录
    result, err := db.Exec(
        "INSERT INTO users (name, email, age, created_at) VALUES (?, ?, ?, ?)",
        "张三",
        "zhangsan@example.com",
        25,
        time.Now(),
    )
    if err != nil {
        log.Fatal(err)
    }
    
    // 获取插入的 ID
    id, err := result.LastInsertId()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("插入成功,ID: %d\n", id)
    
    // 获取影响的行数
    rows, err := result.RowsAffected()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("影响行数:%d\n", rows)
    
    // 2. 插入多条记录(批量插入)
    users := []struct {
        Name  string
        Email string
        Age   int
    }{
        {"李四", "lisi@example.com", 28},
        {"王五", "wangwu@example.com", 22},
        {"赵六", "zhaoliu@example.com", 30},
    }
    
    for _, user := range users {
        result, err := db.Exec(
            "INSERT INTO users (name, email, age, created_at) VALUES (?, ?, ?, ?)",
            user.Name, user.Email, user.Age, time.Now(),
        )
        if err != nil {
            log.Printf("插入失败:%v", err)
            continue
        }
        id, _ := result.LastInsertId()
        fmt.Printf("插入用户 %s, ID: %d\n", user.Name, id)
    }
}

示例 8:更新数据

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 更新单条记录
    result, err := db.Exec(
        "UPDATE users SET email = ?, age = ? WHERE id = ?",
        "newemail@example.com",
        26,
        1,
    )
    if err != nil {
        log.Fatal(err)
    }
    
    rows, err := result.RowsAffected()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("更新了 %d 条记录\n", rows)
    
    // 2. 条件更新
    result, err = db.Exec(
        "UPDATE users SET active = ? WHERE age > ?",
        true,
        18,
    )
    if err != nil {
        log.Fatal(err)
    }
    
    rows, err = result.RowsAffected()
    fmt.Printf("激活了 %d 个成年用户\n", rows)
    
    // 3. 使用 NULL 更新
    result, err = db.Exec(
        "UPDATE users SET email = NULL WHERE id = ?",
        2,
    )
    if err != nil {
        log.Fatal(err)
    }
    rows, _ = result.RowsAffected()
    fmt.Printf("清空了 %d 个用户的邮箱\n", rows)
}

示例 9:删除数据

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 删除单条记录
    result, err := db.Exec("DELETE FROM users WHERE id = ?", 1)
    if err != nil {
        log.Fatal(err)
    }
    
    rows, err := result.RowsAffected()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("删除了 %d 条记录\n", rows)
    
    // 2. 条件删除
    result, err = db.Exec("DELETE FROM users WHERE age < ?", 18)
    if err != nil {
        log.Fatal(err)
    }
    
    rows, err = result.RowsAffected()
    fmt.Printf("删除了 %d 个未成年用户\n", rows)
    
    // 3. 软删除(更新而不是真正删除)
    result, err = db.Exec(
        "UPDATE users SET deleted_at = NOW() WHERE id = ?",
        2,
    )
    if err != nil {
        log.Fatal(err)
    }
    rows, _ = result.RowsAffected()
    fmt.Printf("软删除了 %d 条记录\n", rows)
}

预处理语句

示例 10:使用预处理语句

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 准备预处理语句
    stmt, err := db.Prepare("SELECT id, name, email FROM users WHERE id = ?")
    if err != nil {
        log.Fatal(err)
    }
    defer stmt.Close() // ⚠️ 必须关闭
    
    // 2. 多次执行
    for i := 1; i <= 5; i++ {
        var id int
        var name, email string
        
        err := stmt.QueryRow(i).Scan(&id, &name, &email)
        if err == sql.ErrNoRows {
            fmt.Printf("ID %d: 未找到\n", i)
            continue
        }
        if err != nil {
            log.Fatal(err)
        }
        
        fmt.Printf("ID %d: %s (%s)\n", id, name, email)
    }
    
    // 3. 使用预处理语句更新
    updateStmt, err := db.Prepare("UPDATE users SET email = ? WHERE id = ?")
    if err != nil {
        log.Fatal(err)
    }
    defer updateStmt.Close()
    
    // 批量更新
    for i := 1; i <= 3; i++ {
        result, err := updateStmt.Exec(fmt.Sprintf("user%d@example.com", i), i)
        if err != nil {
            log.Printf("更新失败:%v", err)
            continue
        }
        rows, _ := result.RowsAffected()
        fmt.Printf("更新 ID %d: 影响 %d 行\n", i, rows)
    }
}

预处理语句的优势

  • 防止 SQL 注入:参数自动转义
  • 提高性能:数据库可以缓存执行计划
  • 代码清晰:SQL 与数据分离

示例 11:批量操作

package main

import (
    "database/sql"
    "fmt"
    "log"
    "strings"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 批量插入(单个 INSERT 语句)
    users := []struct {
        Name  string
        Email string
        Age   int
    }{
        {"用户 1", "user1@example.com", 20},
        {"用户 2", "user2@example.com", 21},
        {"用户 3", "user3@example.com", 22},
    }
    
    // 构建批量插入语句
    valueStrings := make([]string, 0, len(users))
    valueArgs := make([]interface{}, 0, len(users)*3)
    
    for _, user := range users {
        valueStrings = append(valueStrings, "(?, ?, ?)")
        valueArgs = append(valueArgs, user.Name, user.Email, user.Age)
    }
    
    query := fmt.Sprintf(
        "INSERT INTO users (name, email, age) VALUES %s",
        strings.Join(valueStrings, ","),
    )
    
    result, err := db.Exec(query, valueArgs...)
    if err != nil {
        log.Fatal(err)
    }
    
    rows, _ := result.RowsAffected()
    fmt.Printf("批量插入 %d 条记录\n", rows)
    
    // 2. 批量查询(使用 IN 子句)
    ids := []int{1, 2, 3, 4, 5}
    
    // 构建 IN 子句
    placeholders := make([]string, len(ids))
    args := make([]interface{}, len(ids))
    for i, id := range ids {
        placeholders[i] = "?"
        args[i] = id
    }
    
    query = fmt.Sprintf(
        "SELECT id, name FROM users WHERE id IN (%s)",
        strings.Join(placeholders, ","),
    )
    
    rows_result, err := db.Query(query, args...)
    if err != nil {
        log.Fatal(err)
    }
    defer rows_result.Close()
    
    fmt.Println("\n查询结果:")
    for rows_result.Next() {
        var id int
        var name string
        rows_result.Scan(&id, &name)
        fmt.Printf("  %d: %s\n", id, name)
    }
}

事务管理

示例 12:基本事务

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 开始事务
    tx, err := db.Begin()
    if err != nil {
        log.Fatal(err)
    }
    
    // 2. 使用 defer 确保回滚(如果忘记提交)
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            log.Printf("回滚失败:%v", err)
        }
    }()
    
    // 3. 在事务中执行操作
    // 插入用户
    result, err := tx.Exec(
        "INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
        "事务用户", "tx@example.com", 25,
    )
    if err != nil {
        log.Fatal("插入失败:", err)
    }
    
    userID, err := result.LastInsertId()
    if err != nil {
        log.Fatal(err)
    }
    
    // 插入用户资料
    _, err = tx.Exec(
        "INSERT INTO user_profiles (user_id, bio) VALUES (?, ?)",
        userID, "这是个人简介",
    )
    if err != nil {
        log.Fatal("插入资料失败:", err)
    }
    
    // 4. 提交事务
    err = tx.Commit()
    if err != nil {
        log.Fatal("提交失败:", err)
    }
    
    fmt.Printf("✓ 事务成功,用户 ID: %d\n", userID)
}

示例 13:事务回滚

package main

import (
    "database/sql"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

func transferMoney(db *sql.DB, fromID, toID int, amount int64) error {
    // 开始事务
    tx, err := db.Begin()
    if err != nil {
        return err
    }
    
    // 确保回滚(如果未提交)
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            log.Printf("回滚失败:%v", err)
        }
    }()
    
    // 1. 检查余额
    var balance int64
    err = tx.QueryRow("SELECT balance FROM accounts WHERE id = ?", fromID).
        Scan(&balance)
    if err != nil {
        return fmt.Errorf("查询余额失败:%v", err)
    }
    
    if balance < amount {
        return fmt.Errorf("余额不足")
    }
    
    // 2. 扣款
    _, err = tx.Exec(
        "UPDATE accounts SET balance = balance - ? WHERE id = ?",
        amount, fromID,
    )
    if err != nil {
        return fmt.Errorf("扣款失败:%v", err)
    }
    
    // 3. 收款
    _, err = tx.Exec(
        "UPDATE accounts SET balance = balance + ? WHERE id = ?",
        amount, toID,
    )
    if err != nil {
        return fmt.Errorf("收款失败:%v", err)
    }
    
    // 4. 插入交易记录
    _, err = tx.Exec(
        "INSERT INTO transactions (from_id, to_id, amount) VALUES (?, ?, ?)",
        fromID, toID, amount,
    )
    if err != nil {
        return fmt.Errorf("记录交易失败:%v", err)
    }
    
    // 5. 提交事务
    if err := tx.Commit(); err != nil {
        return fmt.Errorf("提交失败:%v", err)
    }
    
    return nil
}

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 测试转账
    err = transferMoney(db, 1, 2, 100)
    if err != nil {
        fmt.Printf("转账失败:%v\n", err)
    } else {
        fmt.Println("✓ 转账成功")
    }
}

示例 14:事务选项和隔离级别

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 使用事务选项
    tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
        Isolation: sql.LevelReadCommitted, // 读已提交
        ReadOnly:  false,
    })
    if err != nil {
        log.Fatal(err)
    }
    
    _, err = tx.Exec("UPDATE users SET active = ? WHERE id = ?", true, 1)
    if err != nil {
        tx.Rollback()
        log.Fatal(err)
    }
    
    err = tx.Commit()
    if err != nil {
        log.Fatal(err)
    }
    
    fmt.Println("✓ 使用隔离级别的事务成功")
    
    // 2. 带超时的上下文
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    
    tx2, err := db.BeginTx(ctx, nil)
    if err != nil {
        log.Fatal(err)
    }
    
    // 如果操作超过 5 秒,将自动取消
    _, err = tx2.Exec("SELECT SLEEP(10)")
    if err != nil {
        tx2.Rollback()
        fmt.Printf("操作超时或取消:%v\n", err)
    } else {
        tx2.Commit()
    }
    
    // 3. 只读事务
    tx3, err := db.BeginTx(context.Background(), &sql.TxOptions{
        ReadOnly: true,
    })
    if err != nil {
        log.Fatal(err)
    }
    
    // 只读事务不能执行写操作
    _, err = tx3.Exec("UPDATE users SET active = false")
    if err != nil {
        fmt.Printf("预期错误(只读事务):%v\n", err)
    }
    
    tx3.Rollback()
}

隔离级别

  • sql.LevelDefault:默认级别(由驱动决定)
  • sql.LevelReadUncommitted:读未提交(最低)
  • sql.LevelReadCommitted:读已提交
  • sql.LevelRepeatableRead:可重复读
  • sql.LevelSnapshot:快照隔离
  • sql.LevelSerializable:可串行化(最高)
  • sql.LevelWriteCommitted:写已提交

Context 支持

示例 15:使用 Context 控制查询

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 带超时的查询
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    
    var result int
    err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("查询结果:%d\n", result)
    
    // 2. 可取消的查询
    ctx2, cancel2 := context.WithCancel(context.Background())
    defer cancel2()
    
    // 模拟后台取消
    go func() {
        time.Sleep(3 * time.Second)
        fmt.Println("取消查询...")
        cancel2()
    }()
    
    rows, err := db.QueryContext(ctx2, "SELECT * FROM large_table")
    if err != nil {
        fmt.Printf("查询被取消或失败:%v\n", err)
    } else {
        defer rows.Close()
        fmt.Println("查询成功")
    }
    
    // 3. 带超时的执行
    ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel3()
    
    result2, err := db.ExecContext(ctx3, "UPDATE users SET active = ?", true)
    if err != nil {
        log.Fatal(err)
    }
    
    rows2, _ := result2.RowsAffected()
    fmt.Printf("更新了 %d 条记录\n", rows2)
    
    // 4. 带超时的预处理
    ctx4, cancel4 := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel4()
    
    stmt, err := db.PrepareContext(ctx4, "SELECT * FROM users WHERE id = ?")
    if err != nil {
        log.Fatal(err)
    }
    defer stmt.Close()
    
    var id int
    var name string
    err = stmt.QueryRowContext(ctx4, 1).Scan(&id, &name)
    if err != nil {
        log.Fatal(err)
    }
    
    fmt.Printf("用户:%s\n", name)
}

示例 16:Context 传播

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"
    _ "github.com/go-sql-driver/mysql"
)

// 在调用链中传递 context
func getUser(ctx context.Context, db *sql.DB, id int) (string, error) {
    var name string
    err := db.QueryRowContext(ctx, "SELECT name FROM users WHERE id = ?", id).
        Scan(&name)
    if err != nil {
        return "", err
    }
    return name, nil
}

func getUserProfile(ctx context.Context, db *sql.DB, id int) error {
    // 使用同一个 context
    name, err := getUser(ctx, db, id)
    if err != nil {
        return err
    }
    
    // 继续其他操作
    var email string
    err = db.QueryRowContext(ctx, "SELECT email FROM users WHERE id = ?", id).
        Scan(&email)
    if err != nil {
        return err
    }
    
    fmt.Printf("用户:%s, 邮箱:%s\n", name, email)
    return nil
}

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 创建带超时的 context
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    
    // 传递 context
    err = getUserProfile(ctx, db, 1)
    if err != nil {
        log.Fatal(err)
    }
    
    fmt.Println("✓ 操作完成")
}

错误处理

示例 17:常见错误处理

package main

import (
    "database/sql"
    "errors"
    "fmt"
    "log"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    dsn := "root:password@tcp(localhost:3306)/testdb?charset=utf8mb4"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    
    // 1. 处理无结果
    var name string
    err = db.QueryRow("SELECT name FROM users WHERE id = ?", 999).Scan(&name)
    if err == sql.ErrNoRows {
        fmt.Println("未找到记录")
    } else if err != nil {
        log.Fatal(err)
    }
    
    // 2. 处理连接错误
    err = db.Ping()
    if err != nil {
        log.Printf("数据库连接失败:%v", err)
        // 可以重试或返回错误
    }
    
    // 3. 处理唯一约束冲突
    _, err = db.Exec(
        "INSERT INTO users (name, email) VALUES (?, ?)",
        "测试", "duplicate@example.com",
    )
    if err != nil {
        // MySQL 错误码 1062: 唯一键冲突
        var mysqlErr interface{ Number() uint16 }
        if errors.As(err, &mysqlErr) && mysqlErr.Number() == 1062 {
            fmt.Println("邮箱已存在")
        } else {
            log.Fatal(err)
        }
    }
    
    // 4. 处理外键约束
    _, err = db.Exec(
        "DELETE FROM users WHERE id = ?",
        1,
    )
    if err != nil {
        // MySQL 错误码 1451: 外键约束失败
        var mysqlErr interface{ Number() uint16 }
        if errors.As(err, &mysqlErr) && mysqlErr.Number() == 1451 {
            fmt.Println("存在关联记录,无法删除")
        }
    }
    
    // 5. 检查连接是否关闭
    err = db.Ping()
    if err == sql.ErrConnDone {
        fmt.Println("连接已关闭")
    }
    
    // 6. 检查事务已完成
    tx, _ := db.Begin()
    tx.Commit()
    err = tx.Commit() // 重复提交
    if err == sql.ErrTxDone {
        fmt.Println("事务已完成")
    }
}

安全最佳实践

✅ 推荐做法

  1. 始终使用参数化查询

    // ✅ 正确:防止 SQL 注入
    db.Query("SELECT * FROM users WHERE id = ?", userID)
    
    // ❌ 错误:SQL 注入风险
    db.Query(fmt.Sprintf("SELECT * FROM users WHERE id = %d", userID))
    
  2. 始终关闭资源

    // ✅ 使用 defer
    rows, err := db.Query("SELECT ...")
    if err != nil {
        return err
    }
    defer rows.Close()
    
  3. 使用连接池

    // ✅ 配置连接池
    db.SetMaxOpenConns(25)
    db.SetMaxIdleConns(5)
    db.SetConnMaxLifetime(5 * time.Minute)
    
  4. 使用 Context 控制超时

    // ✅ 设置超时
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    db.QueryRowContext(ctx, "SELECT ...")
    
  5. 使用事务保证原子性

    // ✅ 使用事务
    tx, err := db.Begin()
    if err != nil {
        return err
    }
    defer tx.Rollback()
    // ... 执行操作
    tx.Commit()
    

❌ 不安全做法

  1. 不要拼接 SQL 字符串

    // ❌ SQL 注入风险
    query := fmt.Sprintf("SELECT * FROM users WHERE name = '%s'", userInput)
    
  2. 不要忘记关闭资源

    // ❌ 资源泄漏
    rows, _ := db.Query("SELECT ...")
    // 忘记 defer rows.Close()
    
  3. 不要忽略错误

    // ❌ 忽略错误
    db.Query("SELECT ...") // 不检查错误
    

总结

核心 API

// 连接管理
db, err := sql.Open(driver, dsn)
db.Close()
db.Ping()
db.SetMaxOpenConns(n)
db.SetMaxIdleConns(n)
db.SetConnMaxLifetime(d)

// 查询
rows, err := db.Query(query, args...)
row := db.QueryRow(query, args...)
result, err := db.Exec(query, args...)

// 预处理
stmt, err := db.Prepare(query)

// 事务
tx, err := db.Begin()
tx.Commit()
tx.Rollback()

// Context 支持
db.QueryContext(ctx, query, args...)
db.QueryRowContext(ctx, query, args...)
db.ExecContext(ctx, query, args...)

使用场景

场景推荐方法说明
查询多行Query返回 *Rows
查询单行QueryRow返回 *Row
执行操作Exec返回 Result
重复执行Prepare预处理语句
原子操作Begin事务
超时控制Context 方法带超时的操作

数据类型映射

Go 类型SQL 类型
int, int64INT, BIGINT
float64FLOAT, DOUBLE
stringVARCHAR, TEXT
boolBOOLEAN, TINYINT
time.TimeDATETIME, TIMESTAMP
sql.NullStringVARCHAR (NULL)
sql.NullInt64BIGINT (NULL)
sql.NullFloat64FLOAT (NULL)
sql.NullBoolBOOLEAN (NULL)
sql.NullTimeDATETIME (NULL)

参考资料


最后更新:2026-04-03
Go 版本:Go 1.23+