目录

pgsql-34 Go语言连接PostgreSQL

34 - Go语言连接PostgreSQL

1. 📖 概述

本章介绍如何在Go语言中连接和操作PostgreSQL数据库,包括标准库database/sql、pgx原生API、流行的ORM库GORM以及连接池管理等。

2. 🚀 快速开始

2.1 安装驱动

# pgx - 推荐的PostgreSQL驱动(纯Go实现)
go get github.com/jackc/pgx/v5
go get github.com/jackc/pgx/v5/stdlib

# pq - 老牌驱动(已进入维护模式,新项目推荐使用pgx)
go get github.com/lib/pq

# GORM - 流行的ORM
go get -u gorm.io/gorm
go get -u gorm.io/driver/postgres

2.2 基本连接

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"

    _ "github.com/jackc/pgx/v5/stdlib"  // pgx驱动
)

// main opens a database handle through the pgx database/sql driver and
// confirms the server is reachable with a ping before reporting success.
func main() {
    // Connection URL: user, password, host, port and database name.
    const dsn = "postgres://username:password@localhost:5432/dbname?sslmode=disable"

    // sql.Open may only validate its arguments; it need not connect yet.
    db, err := sql.Open("pgx", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // PingContext forces an actual round trip to the server.
    if err := db.PingContext(context.Background()); err != nil {
        log.Fatal(err)
    }

    fmt.Println("Successfully connected to PostgreSQL!")
}

3. 📊 使用database/sql

3.1 CRUD操作

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/jackc/pgx/v5/stdlib"
)

// User holds one row of the users table: the id, username, email and
// created_at columns defined in createTable, in that order.
type User struct {
    ID        int
    Username  string
    Email     string
    CreatedAt time.Time
}

// main runs the full CRUD walkthrough against a local database: create the
// table, seed two users, read them back, then update and delete by id.
func main() {
    db, err := sql.Open("pgx", "postgres://postgres:password@localhost:5432/mydb?sslmode=disable")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    ctx := context.Background()

    // Schema first, then seed data.
    createTable(ctx, db)
    for _, seed := range []struct{ name, mail string }{
        {"alice", "alice@example.com"},
        {"bob", "bob@example.com"},
    } {
        insertUser(ctx, db, seed.name, seed.mail)
    }

    // Read everything back, then a single row by primary key.
    fmt.Println("All users:", queryUsers(ctx, db))

    if user, err := queryUserByID(ctx, db, 1); err != nil {
        log.Println(err)
    } else {
        fmt.Printf("User: %+v\n", user)
    }

    // Mutations by id (on a fresh table, id 1 is alice and id 2 is bob).
    updateUser(ctx, db, 1, "alice_updated@example.com")
    deleteUser(ctx, db, 2)
}

// createTable creates the users table unless it already exists.
// Any error terminates the program through log.Fatal.
func createTable(ctx context.Context, db *sql.DB) {
    const ddl = `
    CREATE TABLE IF NOT EXISTS users (
        id SERIAL PRIMARY KEY,
        username VARCHAR(50) UNIQUE NOT NULL,
        email VARCHAR(100) UNIQUE NOT NULL,
        created_at TIMESTAMPTZ DEFAULT NOW()
    )`

    if _, err := db.ExecContext(ctx, ddl); err != nil {
        log.Fatal(err)
    }
    fmt.Println("Table created successfully")
}

// insertUser adds one row to users and prints the id the database assigned.
// Failures (for example a UNIQUE violation on username or email) are logged,
// not fatal.
func insertUser(ctx context.Context, db *sql.DB, username, email string) {
    const stmt = `INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`

    var newID int
    if err := db.QueryRowContext(ctx, stmt, username, email).Scan(&newID); err != nil {
        log.Println("Insert error:", err)
        return
    }
    fmt.Printf("Inserted user with ID: %d\n", newID)
}

// queryUsers returns every user ordered by id. A row that fails to scan is
// logged and skipped; a query or iteration error terminates the program.
func queryUsers(ctx context.Context, db *sql.DB) []User {
    rows, err := db.QueryContext(ctx, `SELECT id, username, email, created_at FROM users ORDER BY id`)
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()

    var out []User
    for rows.Next() {
        var u User
        if scanErr := rows.Scan(&u.ID, &u.Username, &u.Email, &u.CreatedAt); scanErr != nil {
            log.Println("Scan error:", scanErr)
            continue
        }
        out = append(out, u)
    }

    // Next returning false can mean an error, not just the end of the rows.
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }
    return out
}

// queryUserByID loads one user by primary key. Any Scan error is returned
// unchanged (database/sql reports a missing row as sql.ErrNoRows).
func queryUserByID(ctx context.Context, db *sql.DB, id int) (*User, error) {
    row := db.QueryRowContext(ctx,
        `SELECT id, username, email, created_at FROM users WHERE id = $1`, id)

    u := new(User)
    if err := row.Scan(&u.ID, &u.Username, &u.Email, &u.CreatedAt); err != nil {
        return nil, err
    }
    return u, nil
}

// updateUser sets the email of the user with the given id and prints how
// many rows changed. Errors are logged rather than returned.
func updateUser(ctx context.Context, db *sql.DB, id int, email string) {
    query := `UPDATE users SET email = $1 WHERE id = $2`

    result, err := db.ExecContext(ctx, query, email, id)
    if err != nil {
        log.Println("Update error:", err)
        return
    }

    // RowsAffected can fail (not every driver supports it); report that
    // instead of printing a misleading count of 0.
    rowsAffected, err := result.RowsAffected()
    if err != nil {
        log.Println("Update error (rows affected):", err)
        return
    }
    fmt.Printf("Updated %d row(s)\n", rowsAffected)
}

// deleteUser removes the user with the given id and prints how many rows
// were deleted. Errors are logged rather than returned.
func deleteUser(ctx context.Context, db *sql.DB, id int) {
    query := `DELETE FROM users WHERE id = $1`

    result, err := db.ExecContext(ctx, query, id)
    if err != nil {
        log.Println("Delete error:", err)
        return
    }

    // RowsAffected can fail (not every driver supports it); report that
    // instead of printing a misleading count of 0.
    rowsAffected, err := result.RowsAffected()
    if err != nil {
        log.Println("Delete error (rows affected):", err)
        return
    }
    fmt.Printf("Deleted %d row(s)\n", rowsAffected)
}

3.2 事务处理

// transferMoney moves amount from fromUserID's account to toUserID's account
// and appends a row to transfer_logs, all inside one transaction.
//
// Nothing is committed, and an error is returned, when:
//   - amount is not positive (a negative amount would reverse the transfer),
//   - fromUserID == toUserID,
//   - the sender has no account row or its balance is below amount,
//   - the receiver has no account row,
//   - any statement or the commit fails (wrapped with the failing step).
//
// NOTE(review): float64 cannot represent most decimal amounts exactly; real
// money handling should use NUMERIC in the database and integer cents or a
// decimal type in Go.
func transferMoney(ctx context.Context, db *sql.DB, fromUserID, toUserID int, amount float64) error {
    if amount <= 0 {
        return fmt.Errorf("transfer: amount must be positive, got %v", amount)
    }
    if fromUserID == toUserID {
        return fmt.Errorf("transfer: source and destination account are both %d", fromUserID)
    }

    tx, err := db.BeginTx(ctx, nil)
    if err != nil {
        return fmt.Errorf("transfer: begin: %w", err)
    }
    // Discards the partial transfer on every early return; after a
    // successful Commit this Rollback is a no-op, so its error is ignored.
    defer tx.Rollback()

    // execCount runs one statement in the transaction and reports how many
    // rows it touched, so a WHERE clause that matched nothing is detected
    // instead of being committed as a silent no-op.
    execCount := func(query string, args ...any) (int64, error) {
        res, err := tx.ExecContext(ctx, query, args...)
        if err != nil {
            return 0, err
        }
        return res.RowsAffected()
    }

    // Debit. The balance check is part of the same UPDATE rather than a
    // separate read followed by a write.
    n, err := execCount(
        `UPDATE accounts SET balance = balance - $1 WHERE user_id = $2 AND balance >= $1`,
        amount, fromUserID,
    )
    if err != nil {
        return fmt.Errorf("transfer: debit account %d: %w", fromUserID, err)
    }
    if n == 0 {
        return fmt.Errorf("transfer: account %d does not exist or has insufficient balance", fromUserID)
    }

    // Credit.
    n, err = execCount(
        `UPDATE accounts SET balance = balance + $1 WHERE user_id = $2`,
        amount, toUserID,
    )
    if err != nil {
        return fmt.Errorf("transfer: credit account %d: %w", toUserID, err)
    }
    if n == 0 {
        return fmt.Errorf("transfer: account %d does not exist", toUserID)
    }

    // Audit trail.
    if _, err := execCount(
        `INSERT INTO transfer_logs (from_user_id, to_user_id, amount) VALUES ($1, $2, $3)`,
        fromUserID, toUserID, amount,
    ); err != nil {
        return fmt.Errorf("transfer: write log: %w", err)
    }

    if err := tx.Commit(); err != nil {
        return fmt.Errorf("transfer: commit: %w", err)
    }
    return nil
}

// main shows a caller opening a database handle and invoking transferMoney,
// then reporting the outcome. Adjust the DSN for your environment.
func main() {
    ctx := context.Background()

    db, err := sql.Open("pgx", "postgres://postgres:password@localhost:5432/mydb?sslmode=disable")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    if err := transferMoney(ctx, db, 1, 2, 100.50); err != nil {
        log.Printf("Transfer failed: %v\n", err)
    } else {
        log.Println("Transfer successful")
    }
}

3.3 预处理语句

// batchInsertUsers inserts users one by one through a single prepared
// statement. It is best-effort: a failing row is logged and the remaining
// rows are still attempted. It returns a non-nil error when the statement
// cannot be prepared, when ctx is done, or when any row failed, so callers
// can tell a partial import from a complete one.
//
// No transaction is opened here, so rows inserted before a failure are not
// rolled back.
func batchInsertUsers(ctx context.Context, db *sql.DB, users []User) error {
    stmt, err := db.PrepareContext(ctx,
        `INSERT INTO users (username, email) VALUES ($1, $2)`,
    )
    if err != nil {
        return err
    }
    defer stmt.Close()

    failed := 0
    for _, user := range users {
        // Once ctx is done every further Exec would fail the same way;
        // stop instead of logging one identical error per remaining row.
        if err := ctx.Err(); err != nil {
            return fmt.Errorf("batch insert interrupted after %d failure(s): %w", failed, err)
        }
        if _, err := stmt.ExecContext(ctx, user.Username, user.Email); err != nil {
            log.Printf("Failed to insert user %s: %v\n", user.Username, err)
            failed++
        }
    }

    if failed > 0 {
        return fmt.Errorf("batch insert: %d of %d users failed", failed, len(users))
    }
    return nil
}

4. 🔧 使用pgx原生API

pgx原生API提供了database/sql所不具备的PostgreSQL专属功能,例如将多条语句打包发送的批量操作(Batch)和基于COPY协议的高速导入。

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgxpool"
)

// main demonstrates the native pgx API on a connection pool: a single-row
// query, a multi-row query, and a batch of INSERTs.
func main() {
    ctx := context.Background()

    // Build a connection pool from the URL.
    connStr := "postgres://postgres:password@localhost:5432/mydb"
    pool, err := pgxpool.New(ctx, connStr)
    if err != nil {
        log.Fatal(err)
    }
    defer pool.Close()

    // Single row.
    var username string
    var email string
    err = pool.QueryRow(ctx, "SELECT username, email FROM users WHERE id = $1", 1).Scan(&username, &email)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("User: %s <%s>\n", username, email)

    // Multiple rows.
    rows, err := pool.Query(ctx, "SELECT id, username, email FROM users")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()

    for rows.Next() {
        var id int
        var username, email string
        if err := rows.Scan(&id, &username, &email); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("%d: %s <%s>\n", id, username, email)
    }
    // Next returns false both at the end of the result set and on an error;
    // only Err tells the two apart.
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }

    // Batch insert: queued statements are sent together to avoid a network
    // round trip per statement.
    batch := &pgx.Batch{}
    batch.Queue("INSERT INTO users (username, email) VALUES ($1, $2)", "user1", "user1@example.com")
    batch.Queue("INSERT INTO users (username, email) VALUES ($1, $2)", "user2", "user2@example.com")
    batch.Queue("INSERT INTO users (username, email) VALUES ($1, $2)", "user3", "user3@example.com")

    batchResults := pool.SendBatch(ctx, batch)

    // Read exactly one result per queued statement, in queue order.
    for i := 0; i < batch.Len(); i++ {
        if _, err := batchResults.Exec(); err != nil {
            log.Printf("Batch insert %d failed: %v\n", i, err)
        }
    }
    // Close releases the connection and returns any error hit while closing.
    if err := batchResults.Close(); err != nil {
        log.Printf("Batch close failed: %v\n", err)
    }
}

4.1 COPY命令(高性能批量插入)

// bulkInsert loads users into the users table with PostgreSQL's COPY
// protocol via pgx CopyFrom, which is typically much faster than one INSERT
// per row. Only username and email are sent; columns not listed are filled
// in by the database. The copied-row count is discarded.
func bulkInsert(ctx context.Context, pool *pgxpool.Pool, users []User) error {
    columns := []string{"username", "email"}
    source := pgx.CopyFromSlice(len(users), func(i int) ([]interface{}, error) {
        u := users[i]
        return []interface{}{u.Username, u.Email}, nil
    })

    if _, err := pool.CopyFrom(ctx, pgx.Identifier{"users"}, columns, source); err != nil {
        return err
    }
    return nil
}

5. 🎯 使用GORM

GORM是Go生态中最流行的ORM库之一。

package main

import (
    "fmt"
    "log"
    "time"

    "gorm.io/driver/postgres"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"
)

// Models.
//
// User is the GORM model for the users table. Field order and tags feed
// AutoMigrate, so they define the schema.
type User struct {
    ID        uint      `gorm:"primaryKey"`
    Username  string    `gorm:"uniqueIndex;not null"`
    Email     string    `gorm:"uniqueIndex;not null"`
    Age       int
    CreatedAt time.Time
    UpdatedAt time.Time
    DeletedAt gorm.DeletedAt `gorm:"index"`  // soft delete: Delete marks the row; Unscoped().Delete removes it
    Orders    []Order         // one-to-many: orders whose UserID equals this ID
}

// Order is the GORM model for the orders table. UserID links each order to
// its User (the User.Orders relation); Status defaults to 'pending'.
type Order struct {
    ID          uint      `gorm:"primaryKey"`
    UserID      uint      `gorm:"not null;index"`
    TotalAmount float64   `gorm:"type:decimal(10,2)"`
    Status      string    `gorm:"type:varchar(20);default:'pending'"`
    CreatedAt   time.Time
    UpdatedAt   time.Time
}

// main walks through basic GORM usage: connect, migrate, create, query,
// update, delete and eager-load an association. Every step checks the
// Error field of the *gorm.DB it returns.
func main() {
    // Connect; TimeZone sets the session time zone.
    dsn := "host=localhost user=postgres password=password dbname=mydb port=5432 sslmode=disable TimeZone=Asia/Shanghai"
    db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
        Logger: logger.Default.LogMode(logger.Info), // log every SQL statement
    })
    if err != nil {
        log.Fatal(err)
    }

    // Create/alter tables to match the models; nothing below works without it.
    if err := db.AutoMigrate(&User{}, &Order{}); err != nil {
        log.Fatal(err)
    }

    // Create one user; GORM fills user.ID from the database.
    user := User{
        Username: "alice",
        Email:    "alice@example.com",
        Age:      25,
    }
    if err := db.Create(&user).Error; err != nil {
        log.Println(err)
    } else {
        fmt.Printf("Created user with ID: %d\n", user.ID)
    }

    // Batch create.
    users := []User{
        {Username: "bob", Email: "bob@example.com", Age: 30},
        {Username: "charlie", Email: "charlie@example.com", Age: 28},
    }
    if err := db.Create(&users).Error; err != nil {
        log.Println(err)
    }

    // Look up by primary key, using the ID Create returned rather than a
    // hard-coded 1 (IDs are not reused after the hard delete below).
    var foundUser User
    if err := db.First(&foundUser, user.ID).Error; err != nil {
        log.Println(err)
    } else {
        fmt.Printf("Found user: %+v\n", foundUser)
    }

    // Conditional query.
    var adults []User
    if err := db.Where("age >= ?", 18).Find(&adults).Error; err != nil {
        log.Println(err)
    }
    fmt.Printf("Found %d adults\n", len(adults))

    // Update one column of one record.
    if err := db.Model(&user).Update("email", "alice_new@example.com").Error; err != nil {
        log.Println(err)
    }

    // Batch update of every matching row. User has no status column, so
    // this uses Order.Status, which exists.
    if err := db.Model(&Order{}).Where("status = ?", "pending").Update("status", "processing").Error; err != nil {
        log.Println(err)
    }

    // Soft delete: DeletedAt is set and normal queries skip the row.
    if err := db.Delete(&user).Error; err != nil {
        log.Println(err)
    }

    // Permanent delete: Unscoped bypasses the soft-delete handling.
    if err := db.Unscoped().Delete(&user).Error; err != nil {
        log.Println(err)
    }

    // Eager-load the Orders association. alice was just deleted, so load
    // the first batch-created user instead.
    var userWithOrders User
    if err := db.Preload("Orders").First(&userWithOrders, users[0].ID).Error; err != nil {
        log.Println(err)
    } else {
        fmt.Printf("User has %d orders\n", len(userWithOrders.Orders))
    }
}

5.1 GORM高级功能

// createUserWithOrder creates a user and one order for that user inside a
// single db.Transaction call; the first insert error is returned from the
// callback, which makes the transaction roll back.
func createUserWithOrder(db *gorm.DB, username, email string, amount float64) error {
    return db.Transaction(func(tx *gorm.DB) error {
        newUser := User{Username: username, Email: email}
        if err := tx.Create(&newUser).Error; err != nil {
            return err
        }
        // newUser.ID was filled in by the insert above.
        return tx.Create(&Order{UserID: newUser.ID, TotalAmount: amount}).Error
    })
}

// advancedQuery demonstrates chained conditions, raw SQL, counting and a
// JOIN with aggregation. Results only land in local variables; any query
// error is logged.
func advancedQuery(db *gorm.DB) {
    var users []User

    // Chained query through the User model; GORM's soft-delete handling
    // excludes rows whose DeletedAt is set.
    if err := db.Where("age > ?", 20).
        Where("email LIKE ?", "%@gmail.com").
        Order("created_at DESC").
        Limit(10).
        Find(&users).Error; err != nil {
        log.Println(err)
    }

    // Raw SQL is sent as written, without the soft-delete filter, so it is
    // spelled out here to keep deleted users out of the result.
    if err := db.Raw("SELECT * FROM users WHERE age > ? AND email LIKE ? AND deleted_at IS NULL", 20, "%@gmail.com").
        Scan(&users).Error; err != nil {
        log.Println(err)
    }

    // Aggregate: count matching users.
    var count int64
    if err := db.Model(&User{}).Where("age > ?", 20).Count(&count).Error; err != nil {
        log.Println(err)
    }

    // JOIN + GROUP BY: per-user order count and total, users with orders only.
    type Result struct {
        Username    string
        Email       string
        OrderCount  int
        TotalAmount float64
    }

    var results []Result
    if err := db.Model(&User{}).
        Select("users.username, users.email, COUNT(orders.id) as order_count, SUM(orders.total_amount) as total_amount").
        Joins("LEFT JOIN orders ON orders.user_id = users.id").
        Group("users.id, users.username, users.email").
        Having("COUNT(orders.id) > ?", 0).
        Scan(&results).Error; err != nil {
        log.Println(err)
    }
}

6. 🔗 连接池配置

import (
    "context"
    "database/sql"
    "time"

    "github.com/jackc/pgx/v5/pgxpool"
    _ "github.com/jackc/pgx/v5/stdlib" // registers the "pgx" driver used by sql.Open
)

// setupDB opens a database/sql handle for the pgx driver and sizes its
// connection pool. sql.Open need not connect, so callers should Ping before
// relying on the handle.
//
// The idle timeout is kept shorter than the lifetime: a connection idle for
// N minutes is at least N minutes old, so an idle limit at or above the
// lifetime limit could never take effect.
func setupDB(connStr string) (*sql.DB, error) {
    db, err := sql.Open("pgx", connStr)
    if err != nil {
        return nil, err
    }

    // Upper bound on open connections (in use + idle).
    db.SetMaxOpenConns(25)

    // Idle connections retained for reuse.
    db.SetMaxIdleConns(5)

    // Connections older than this are not reused.
    db.SetConnMaxLifetime(5 * time.Minute)

    // Idle connections are closed after this, which must be below the lifetime.
    db.SetConnMaxIdleTime(2 * time.Minute)

    return db, nil
}

// setupPgxPool builds a native pgx connection pool from connStr with
// explicit size and recycling limits.
//
// MaxConnIdleTime is kept below MaxConnLifetime: a connection idle for N
// minutes is at least N minutes old, so an idle limit at or above the
// lifetime limit could never take effect.
func setupPgxPool(connStr string) (*pgxpool.Pool, error) {
    config, err := pgxpool.ParseConfig(connStr)
    if err != nil {
        return nil, err
    }

    config.MaxConns = 25                     // upper bound on pool size
    config.MinConns = 5                      // minimum pool size
    config.MaxConnLifetime = 5 * time.Minute // recycle connections older than this
    config.MaxConnIdleTime = 2 * time.Minute // close idle connections before they reach the lifetime

    pool, err := pgxpool.NewWithConfig(context.Background(), config)
    if err != nil {
        return nil, err
    }

    return pool, nil
}

7. 📦 项目结构示例

myapp/
├── cmd/
│   └── server/
│       └── main.go
├── internal/
│   ├── database/
│   │   └── postgres.go      # 数据库连接
│   ├── models/
│   │   ├── user.go          # 用户模型
│   │   └── order.go         # 订单模型
│   ├── repository/
│   │   ├── user_repo.go     # 用户数据访问层
│   │   └── order_repo.go    # 订单数据访问层
│   ├── service/
│   │   ├── user_service.go  # 用户业务逻辑
│   │   └── order_service.go # 订单业务逻辑
│   └── handler/
│       ├── user_handler.go  # 用户HTTP处理器
│       └── order_handler.go # 订单HTTP处理器
├── migrations/              # 数据库迁移文件
│   ├── 001_create_users.sql
│   └── 002_create_orders.sql
├── go.mod
└── go.sum

7.1 repository层示例

// internal/repository/user_repo.go
package repository

import (
    "context"
    "database/sql"
    "myapp/internal/models"
)

// UserRepository is the data-access layer for the users table; every
// method runs its SQL through db.
type UserRepository struct {
    db *sql.DB
}

// NewUserRepository returns a UserRepository that issues its queries
// through db.
func NewUserRepository(db *sql.DB) *UserRepository {
    repo := new(UserRepository)
    repo.db = db
    return repo
}

// Create inserts user's username, email and password hash, then copies the
// database-generated id and created_at back into user.
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
    const insertSQL = `
        INSERT INTO users (username, email, password_hash)
        VALUES ($1, $2, $3)
        RETURNING id, created_at
    `

    row := r.db.QueryRowContext(ctx, insertSQL, user.Username, user.Email, user.PasswordHash)
    return row.Scan(&user.ID, &user.CreatedAt)
}

// GetByID loads the user with the given id. password_hash is not selected,
// so PasswordHash stays at its zero value. Scan errors (sql.ErrNoRows for a
// missing id) are returned unchanged.
func (r *UserRepository) GetByID(ctx context.Context, id int) (*models.User, error) {
    const selectSQL = `
        SELECT id, username, email, created_at, updated_at
        FROM users
        WHERE id = $1
    `

    var u models.User
    if err := r.db.QueryRowContext(ctx, selectSQL, id).Scan(
        &u.ID, &u.Username, &u.Email, &u.CreatedAt, &u.UpdatedAt,
    ); err != nil {
        return nil, err
    }
    return &u, nil
}

// GetByEmail loads the user with the given email, including PasswordHash
// (presumably for credential checks — confirm with callers). updated_at is
// not selected, so UpdatedAt stays at its zero value. Scan errors
// (sql.ErrNoRows for an unknown email) are returned unchanged.
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*models.User, error) {
    const selectSQL = `
        SELECT id, username, email, password_hash, created_at
        FROM users
        WHERE email = $1
    `

    var u models.User
    if err := r.db.QueryRowContext(ctx, selectSQL, email).Scan(
        &u.ID, &u.Username, &u.Email, &u.PasswordHash, &u.CreatedAt,
    ); err != nil {
        return nil, err
    }
    return &u, nil
}

// Update writes user's username and email to the row with user.ID and sets
// updated_at to NOW(), copying the new updated_at back into user.UpdatedAt.
// If no row has that id it returns sql.ErrNoRows — the same error GetByID
// reports for a missing user — instead of silently succeeding.
func (r *UserRepository) Update(ctx context.Context, user *models.User) error {
    query := `
        UPDATE users
        SET username = $1, email = $2, updated_at = NOW()
        WHERE id = $3
        RETURNING updated_at
    `

    // With RETURNING, Scan yields sql.ErrNoRows when the WHERE matched nothing.
    return r.db.QueryRowContext(ctx, query, user.Username, user.Email, user.ID).
        Scan(&user.UpdatedAt)
}

// Delete removes the user with the given id. An id that matches no row is
// not treated as an error.
func (r *UserRepository) Delete(ctx context.Context, id int) error {
    if _, err := r.db.ExecContext(ctx, `DELETE FROM users WHERE id = $1`, id); err != nil {
        return err
    }
    return nil
}

8. 🎯 最佳实践

8.1 使用context传递超时

// getUserWithTimeout loads id, username and email for one user, allowing
// the query at most 3 seconds. On any error (deadline exceeded,
// sql.ErrNoRows, ...) it returns a nil *User so callers cannot mistake a
// half-filled struct for a result.
func getUserWithTimeout(db *sql.DB, id int) (*User, error) {
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel() // release the timer even when the query finishes early

    var user User
    err := db.QueryRowContext(ctx,
        "SELECT id, username, email FROM users WHERE id = $1", id,
    ).Scan(&user.ID, &user.Username, &user.Email)
    if err != nil {
        return nil, err
    }

    return &user, nil
}

8.2 使用占位符防止SQL注入

// Good: the value is sent as a bind parameter, never spliced into SQL text.
username := "alice"
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE username = $1", username)
if err != nil {
    log.Fatal(err)
}
defer rows.Close() // unclosed Rows keep their connection checked out of the pool

// Bad: string formatting lets crafted input rewrite the query (SQL injection!)
// db.QueryContext(ctx, fmt.Sprintf("SELECT * FROM users WHERE username = '%s'", username))

8.3 处理NULL值

import "database/sql"

// User variant whose phone and age columns may be NULL: the sql.Null* types
// record in Valid whether the column held a value.
type User struct {
    ID       int
    Username string
    Email    string
    Phone    sql.NullString  // may be NULL; read Phone.String only when Phone.Valid
    Age      sql.NullInt64   // may be NULL; read Age.Int64 only when Age.Valid
}

// 查询
var user User
db.QueryRow("SELECT id, username, email, phone, age FROM users WHERE id = $1", 1).Scan(
    &user.ID, &user.Username, &user.Email, &user.Phone, &user.Age,
)

// 使用
if user.Phone.Valid {
    fmt.Println("Phone:", user.Phone.String)
} else {
    fmt.Println("Phone is NULL")
}

8.4 使用数据库迁移工具

# 安装golang-migrate
go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest

# 创建迁移文件
migrate create -ext sql -dir migrations -seq create_users_table

# 执行迁移
migrate -path migrations -database "postgres://postgres:password@localhost:5432/mydb?sslmode=disable" up

# 回滚最近一次迁移(不带数字的 down 会回滚全部迁移)
migrate -path migrations -database "postgres://postgres:password@localhost:5432/mydb?sslmode=disable" down 1

9. 📚 下一步


关键要点:

  • 使用pgx驱动性能更好
  • 使用连接池管理连接
  • 始终使用占位符防止SQL注入
  • 使用context控制超时
  • repository模式分离数据访问逻辑
  • GORM适合快速开发,pgx适合高性能