手把手教你实现一个 SQL 审核插件

文摘   2024-08-16 13:18   上海  

本文是投稿文章(欢迎大家投稿),作者张松鹤,来自敦煌网的高级 DBA,作者会在本文介绍 SQL 审核插件的关键技术点。下面是正文。

最新在做 SQL 审核的功能,在我们的管理平台中集成了 goInception[1] 审核工具,利用它对 DDL、DML 进行自动化审核校验,这也是大多数人的选择。在开发过程中有这样一个场景,我们期望研发人员在编辑 DML 工单时,不允许出现 DDL 语句;反之,在 DDL 工单中不允许出现 DML 语句,同时还希望禁用一些高危操作。

虽然 goInception 能够审核 DML、DDL 语句,但它无法区分操作类型,不能对工单信息进行有效约束。最初想到的解决办法是对 SQL 文本进行正则匹配,通过提取关键字来校验。这种方法的优点是实现简单,缺点是复杂的 SQL 正则表达式总有匹配不到的情况,而且大文本的正则匹配也不会很快。经调研,决定采用另一种方式,基于 AST 语法树来实现,下面记录一下实现过程:(SQL 文本 AST 解析比较复杂,我们直接使用 tidb 的 SQL 解析器,语法上兼容 MySQL)。

参考文档:  Parse a text SQL into an AST tree[2]

To convert a SQL text to an AST tree, you need to:
Use the parser.New() function to instantiate a parser, and
Invoke the method Parse(sql, charset, collation) on the parser.
Now you get the AST tree root of a SQL statement. It is time to extract the column names by traverse.

根据文档说明,需先实例化 SQL 解析器,接着使用 Pasre() 方法解析 SQL 文本以生成 AST tree,最后操作 AST 树,提取所需数据。 

初始化 SQL 解析器

p := parser.New()
// stmtNodes AST rootNode     
stmtNodes, _, err := p.Parse(rawSQL, """")       
if err != nil {
    log.Fatalf("Parse error: %v\n", err)        
}

操作 SQL AST

我们的需求是从 AST 中提取 SQL 语法关键字,不需要字段、备注等信息。此处需要实现 ast.Visitor[3] 接口,从 ast 中提取希望得到的数据。

type Visitor interface {
    Enter(n Node) (node Node, skipChildren bool)
    Leave(n Node) (node Node, ok bool)
}

实现 ast.Visitor 接口

// 定义 visitor 类型,存储 DDL、DML 操作关键字、和 SQL 高危操作关键字
type visitor struct {        
    KeywordsDDL map[string]struct{}  
    KeywordsDML map[string]struct{}    
    Blacklist   map[string]struct{} 
}

func newVisitor() *visitor {
    return &visitor{
        KeywordsDDL: make(map[string]struct{}),
        KeywordsDML: make(map[string]struct{}),
        Blacklist:   make(map[string]struct{}),
    }
}

// 主要在 Enter 方法中实现一些自定义的操作需求
// implemented Enter 提取 AST 节点信息。skipChildren 是否跳过 AST 子节点
func (v *visitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
    // TODO
    // 1. 提取 DML 语法关键字
    // 2. 提取 DDL 语法关键字
    // 3. 提取高危操作关键字
    // 4. 自定义审核检验的规则,如  select 语句中必须包含 limit;  update 必须有 where;
    return n, false
}

// implemented Leave method
func (v *visitor) Leave(n ast.Node) (node ast.Node, ok bool) {
    return n, true
}

遍历 AST tree 提取关键字

// 使用 Visitor 模式遍历 AST
func extractKeywords(rootNode ast.Node) *visitor {
    vis := newVisitor()
    if _, ok := rootNode.Accept(vis); !ok {
        log.Fatalf("extractKeywords is failed")
    }
    return vis
}

通过执行 extractKeywords() 既可实现从 SQL 文本到 AST 语法关键字的获取。为下一步实现操作约束功能提供条件。

实现的效果

基于以上逻辑,实现了基于 AST 的 SQL 操作约束。

完整的 demo 如下:

package main

import (
    "bufio"
    "flag"
    "github.com/pingcap/parser"
    "github.com/pingcap/parser/ast"
    _ "github.com/pingcap/tidb/types/parser_driver"
    "io"
    "log"
    "os"
)

var None = struct{}{}

type visitor struct {
    KeywordsDDL map[string]struct{}
    KeywordsDML map[string]struct{}
    Blacklist   map[string]struct{}
}

func newVisitor() *visitor {
    return &visitor{
        KeywordsDDL: make(map[string]struct{}),
        KeywordsDML: make(map[string]struct{}),
        Blacklist:   make(map[string]struct{}),
    }
}

// Enter 提取 AST 节点信息。skipChildren 是否跳过 AST 子节点
func (v *visitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
    switch x := n.(type) {
    // DML
    case *ast.SelectStmt:
        v.KeywordsDML["SELECT"] = None
        // 子树
        if x.Where != nil {
            v.KeywordsDML["WHERE"] = None
        }
        if x.Limit == nil {
            // TODO: 添加 LIMIT 控制策略
        }
        if x.OrderBy != nil {
            v.KeywordsDML["ORDER BY"] = None
        }
    case *ast.DeleteStmt:
        v.KeywordsDML["DELETE"] = None
    case *ast.InsertStmt:
        v.KeywordsDML["INSERT"] = None
    case *ast.UpdateStmt:
        v.KeywordsDML["UPDATE"] = None
    // DDL
    case *ast.AlterTableStmt:
        v.KeywordsDDL["ALTER TABLE"] = None
    case *ast.CreateIndexStmt:
        v.KeywordsDDL["CREATE INDEX"] = None
    case *ast.CreateTableStmt:
        v.KeywordsDDL["CREATE TABLE"] = None
    // 违规检测
    case *ast.DropTableStmt, *ast.DropIndexStmt, *ast.DropDatabaseStmt, *ast.DropUserStmt:
        v.Blacklist["DROP"] = None
    }
    return n, false
}

func (v *visitor) Leave(n ast.Node) (node ast.Node, ok bool) {
    return n, true
}

func extractKeywords(rootNode ast.Node) *visitor {
    // 使用 Visitor 模式遍历 AST
    vis := newVisitor()
    if _, ok := rootNode.Accept(vis); !ok {
        log.Fatalf("extractKeywords is failed")
    }
    return vis
}

func executeCheck(rawSQL string) *visitor {
    p := parser.New()

    stmtNodes, _, err := p.Parse(rawSQL, """")
    if err != nil {
        log.Fatalf("Parse error: %v\n", err)
    }

    v := newVisitor()
    for _, stmtNode := range stmtNodes {
        keywords := extractKeywords(stmtNode)
        for key, val := range keywords.KeywordsDML {
            v.KeywordsDML[key] = val
        }
        for key, val := range keywords.KeywordsDDL {
            v.KeywordsDDL[key] = val
        }
        for key, val := range keywords.Blacklist {
            v.Blacklist[key] = val
        }
    }
    return v
}

// ReadePlaintext 从文件中读取 SQL
func ReadePlaintext(filename string) (line []byte, error error) {
    file, err := os.OpenFile(filename, os.O_RDONLY, 0666)
    if err != nil {
        log.Fatalln("Open file error:", err)
    }
    defer file.Close()
    reader := bufio.NewReader(file)
    return io.ReadAll(reader)
}

func main() {
    var action, rawSQL, filename string
    flag.StringVar(&action, "t""dml""DML or DDL")
    flag.StringVar(&action, "type""dml""DML or DDL")
    flag.StringVar(&rawSQL, "s""""SQL")
    flag.StringVar(&rawSQL, "sql""""Input your SQL statement")
    flag.StringVar(&filename, "f""""Filename lite")
    flag.StringVar(&filename, "filename""""Filename")
    flag.Parse()

    if filename != "" {
        newRawSQL, _ := ReadePlaintext(filename)
        rawSQL = string(newRawSQL)
    }

    // SQL 合规性检测
    v := executeCheck(rawSQL)
    switch action {
    case "dml":
        if len(v.KeywordsDDL) != 0 || len(v.Blacklist) != 0 {
            log.Fatalf("DDL or Drop is not allowed\n. %#v, %#v", v.KeywordsDDL, v.Blacklist)
        }
    case "ddl":
        if len(v.KeywordsDML) != 0 || len(v.Blacklist) != 0 {
            log.Fatalf("DML or Drop is not allowed\n. %#v, %#v", v.KeywordsDML, v.Blacklist)
        }
    }
}

可用场景

  • 可集成到数据库管理平台
  • goInception 已停止更新,如果有需要可以进一步丰富功能,替换 goInception
  • 离线环境中对 SQL 语法的正确性、合规性进行校验的场景。减少人工审核,提升 SQL 审核效率
  • nl2sql 场景的精细化操作控制

本公众号主理人秦晓辉,开源项目 Open-Falcon、Nightingale 创始人,目前在可观测性领域创业,如果贵司想要构建监控/可观测性体系,欢迎联系我们,我们公司产品介绍:https://flashcat.cloud/

参考资料
[1] 

goInception: https://hanchuanchuan.github.io/goInception/zh/

[2] 

Parse a text SQL into an AST tree: https://github.com/pingcap/tidb/blob/master/pkg/parser/docs/quickstart.md

[3] 

ast.Visitor: https://pkg.go.dev/github.com/pingcap/tidb/pkg/parser/ast#Visitor


DBA札记
dba 数据库 知识科普 踩坑指南 经验分享 原理解读
 最新文章