关联规则挖掘算法–Apriori算法

一、Apriori算法

  1. 简介

关联规则分析是数据挖掘中最活跃的研究方法之一,目的是在一个数据集中找到各项之间的关联关系,而这种关系并没有在数据中直接体现出来。Apriori算法 关联规则学习的经典算法之一,是R.Agrawal和R.Srikartt于1944年提出的一种具有影响力的挖掘布尔关联规则挖掘频繁项集的算法。
  1. 基本原理

关联规则的一般定义如下:

(1)项集:定义表示一个项集。

(2)事务集:设任务相关的数据D是数据库事务的集合,即D是事务的集合;每个事务T是项的集合,其中 。例如表示一个事务。

(3)关联规则蕴含式:关联规则形如A=>B的蕴含式,,并且

(4)支持度s:D中包含A和 B 的事务数与总的事务数的比值。规则 A=>B 在数据集D中的支持度为s, 其中s 表示D中包含 (即同时包含A和B)的事务的百分率,即概率P(

(5)可信度 c :D中同时包含A和B的事务数与只包含A的事务数的比值。规则 A=>B 在数据集D中的可信度为c, 其中c表示D中包含A的事务中也包含B的百分率.即可用条件概率P(B|A)表示:

其中supportCount表示项集出现的频率,即包含项集的事务数目。

(6)最小支持度 :表示规则中的所有项在事务中出现的频度。

(7)最小可信度 : 表示规则中左边的项(集)的出现暗示着右边的项(集)出现的频度。

关联规则挖掘过程:

(1)根据最小支持度找到数据集D中的所有频繁项集。

(2)由频繁项集产生强关联规则。强关联规则必须满足最小支持度和最小置信度

频繁项集两条重要性质:

(1) 频繁项集的子项集必是频繁项集。

(2)非频繁项集的超集一定是非频繁的。

3.算法过程

Apriori算法使用一种称作逐层搜索的迭代方法,利用k项集用于探索(k+1)项集。主要过程如下:

(1)首先,通过扫描数据库,累积每个项的计数,并收集满足最小支持度的项,找出频繁1项集的集合,该集合称作L1。

(2)然后,L1用于找频繁2项集的集合L2,L2用于寻找L3,如此下去,直到不能再找到新的频繁k项集。找每个 Lk要求对数据库做一次完全扫描。

Apriori算法使用的是产生-测试策略来发现频繁项集。每次迭代后,新的项集由前一次迭代发现的频繁项集产生,然后对每个候选的支持度进行计数,并与最小支持度阈值进行比较。算法需要迭代的总次数是kmax+1,其中kmax是频繁项集的最大长度。

二、Apriori算法举例

  1. 案例说明和数据

找到以下数据的频繁项集,最小支持度阈值是 2。其中每一行数据代表一个事务。

“I1 I2 I5”
“I2 I4”
“I2 I3”
“I1 I2 T4”
“I1 I3”
“I2 I3”
“I1 I3”
“I1 I2 I3 I5”
“I1 I2 I3”
  1. 基于python代码实现

(1)代码

def get_first_item(dataset):
    """
    获取数据集元素
    :param dataset: 数据集(二维数组)
    :return: 元素数组(二维)
    """
    dataset_item = []  # 数据集元素

    for i in dataset:
        for j in i:
            if [j] not in dataset_item:
                dataset_item.append([j])

    dataset_item.sort()

    return dataset_item


def get_freq_item(dataset, cand_list, min_sup):
    """
    获取频繁项集及计数
    :param dataset: 数据集
    :param cand_list: 候选集
    :param min_sup: 最小支持度
    :return: 频繁项集及计数和非频繁项集
    """
    cand_list_num = {}  # 候选集计数
    freq_list = []  # 存储频繁项集
    freq_list_num = {}  # 频繁项集计数
    unfreq_list = []  # 存储非频繁项集

    for i in cand_list:
        for j in dataset:
            if set(i).issubset(set(j)):  # 判断是否为子集
                cand_list_num[tuple(i)] = cand_list_num.get(tuple(i), 0) + 1
                # list 和dict不能作为主键

    for i in cand_list_num:
        if cand_list_num[i] >= min_sup:
            freq_list.append(list(i))
            freq_list_num[i] = cand_list_num[i]
        else:
            unfreq_list.append(list(i))
    return freq_list, freq_list_num, unfreq_list


def get_candiate(freq_list, unfreq_list):
    """
    获得候选集
    :param freq_list: 上一频繁项集
    :param unfreq_list: 上一非频繁项集
    :return: 候选集
    """
    cand_list = []  # 候选集

    length = len(freq_list[0]) + 1  # 记录上一单个频繁项集的长度+1

    for i in freq_list:
        for j in freq_list:
            com_item = list(set(i) | set(j))  # 求和集
            com_item.sort()  # 排个序
            if len(com_item) == length and com_item not in cand_list:
                cand_list.append(com_item)

                # 根据频繁项集的子集是频繁的,去除组合后的不该出现的候选集
    for m in unfreq_list:
        for n in cand_list:
            if set(m).issubset(set(n)):
                cand_list.remove(n)

    return cand_list


def Apriori(dataset, min_sup=2):
    """
    Apriori算法
    :param dataset: 数据集
    :param min_sup: 最小支持度
    :return: 所有频繁集,频繁集计数
    """
    first_can_list = get_first_item(dataset)  # 候选1项集
    freq_list, first_can_list_num, unfreq_list = get_freq_item(dataset, first_can_list, min_sup)  # 频繁1项集和计数字典

    freq_all_list = [freq_list]  # 所有频繁项集
    freq_all_list_num = first_can_list_num  # 所有频繁项集计数

    while True:
        cand_list = get_candiate(freq_list, unfreq_list)
        freq_list, freq_list_num, unfreq_list = get_freq_item(dataset, cand_list, min_sup)

        freq_all_list.append(freq_list)  # 更新频繁集
        freq_all_list_num.update(freq_list_num)  # 更新频繁集计数字典
        if len(freq_list) <= 1:
            break
    freq_all_list = [i for i in freq_all_list if i != []]  # 去除最后可能出现的空列表
    return freq_all_list, freq_all_list_num


if __name__ == '__main__':
    dataset = [['I1', 'I2', 'I5'],
               ['I2', 'I4'],
               ['I2', 'I3'],
               ['I1', 'I2', 'I4'],
               ['I1', 'I3'],
               ['I2', 'I3'],
               ['I1', 'I3'],
               ['I1', 'I2', 'I3', 'I5'],
               ['I1', 'I2', 'I3']]  # 数据集

    freq_list, freq_num = Apriori(dataset, min_sup=2)
    num = 1
    for i in freq_list:
        print(f"频繁{num}项集:{i}")
        num += 1
    for k,v in freq_num.items():
        print(f"频繁项集{k} 的个数:{v}")

(2)运行结果

  1. 基于java代码实现

(1)代码

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
public class AprioriAlgorithm {
    private int minSup;
    private static List<String> data;
    private static List<Set<String>> dataSet;

    public static void main(String[] args) {

        long startTime = System.currentTimeMillis();
        AprioriAlgorithm apriori = new AprioriAlgorithm();

      
        data = apriori.buildData();

        //设置最小支持度
        apriori.setMinSup(2);
        //构造数据集
        data = apriori.buildData();

        //构造频繁1项集
        List<Set<String>> f1Set = apriori.findF1Item(data);
        apriori.printSet(f1Set, 1);
        List<Set<String>> result = f1Set;

        int i = 2;
        do{
            result = apriori.arioriGen(result);
            apriori.printSet(result, i);
            i++;
        }while(result.size() != 0);
        long endTime = System.currentTimeMillis();
        System.out.println("共用时: " + (endTime - startTime) + "ms");
    }
    public void setMinSup(int minSup){
        this.minSup = minSup;
    }

    /**
     * 构造原始数据集,可以为之提供参数,也可以不提供
     * 如果不提供参数,将按程序默认构造的数据集
     * 如果提供参数为文件名,则使用文件中的数据集
     */
    List<String> buildData(String...fileName){
        List<String> data = new ArrayList<String>();
        if(fileName.length != 0){
            File file = new File(fileName[0]);
            try{
                BufferedReader reader = new BufferedReader(new FileReader(file));
                String line;
                while( ( line = reader.readLine()) != null ){
                    data.add(line);
                }
            }catch (FileNotFoundException e){
                e.printStackTrace();
            }catch (IOException e){
                e.printStackTrace();
            }
        }else{
            data.add("I1 I2 I5");
            data.add("I2 I4");
            data.add("I2 I3");
            data.add("I1 I2 T4");
            data.add("I1 I3");
            data.add("I2 I3");
            data.add("I1 I3");
            data.add("I1 I2 I3 I5");
            data.add("I1 I2 I3");
        }

        dataSet = new ArrayList<Set<String>>();
        Set<String> dSet;
        for(String d : data){
            dSet = new TreeSet<String>();
            String[] dArr = d.split(" ");
            for(String str : dArr){
                dSet.add(str);
            }
            dataSet.add(dSet);
        }
        return data;
    }

    /**
     * 找出候选1项集
     * @param data
     * @return result
     */
    List<Set<String>> findF1Item(List<String> data){
        List<Set<String>> result = new ArrayList<Set<String>>();
        Map<String, Integer> dc = new HashMap<String, Integer>();
        for(String d : data){
            String[] items = d.split(" ");
            for(String item : items){
                if(dc.containsKey(item)) {
                    dc.put(item, dc.get(item)+1);
                }else{
                    dc.put(item, 1);
                }
            }
        }
        Set<String> itemKeys = dc.keySet();
        Set<String> tempKeys = new TreeSet<String>();
        for(String str : itemKeys){
            tempKeys.add(str);
        }

        for(String item : tempKeys){
            if(dc.get(item) >= minSup) {
                Set<String> f1Set = new TreeSet<String>();
                f1Set.add(item);
                result.add(f1Set);
            }
        }
        return result;
    }

    /**
     * 利用arioriGen方法由k-1项集生成k项集
     *@param preSet
     *@return
     *
     */
    List<Set<String>> arioriGen(List<Set<String>> preSet) {

        List<Set<String>> result = new ArrayList<Set<String>>();
        int preSetSize = preSet.size();

        for(int i = 0; i < preSetSize - 1; i++){
            for(int j = i + 1; j < preSetSize; j++ ){
                String[] strA1 = preSet.get(i).toArray(new String[0]);
                String[] strA2 = preSet.get(j).toArray(new String[0]);
                if(isCanLink(strA1, strA2)) {//判断两个k-1项集是否符合连接成K项集的条件
                    Set<String> set = new TreeSet<String>();
                    for(String str : strA1){
                        set.add(str);//将strA1加入set中连成前K-1项集
                    }
                    set.add((String) strA2[strA2.length-1]);//连接成K项集
                    //判断K项集是否需要剪切掉,如果不需要被cut掉,则加入到k项集的列表中
                    if(!isNeedCut(preSet, set)) {
                        result.add(set);
                    }
                }
            }
        }
        return checkSupport(result);//返回的都是频繁K项集
    }

    /**
     * 把set中的项集与数量集比较并进行计算,求出支持度大于要求的项集
     * @return
     */
    List<Set<String>> checkSupport(List<Set<String> > setList){

        List<Set<String>> result = new ArrayList<Set<String>>();
        boolean flag = true;
        int [] counter = new int[setList.size()];
        for(int i = 0; i < setList.size(); i++){

            for(Set<String> dSets : dataSet) {
                if(setList.get(i).size() > dSets.size()){
                    flag = true;
                }else{
                    for(String str : setList.get(i)){
                        if(!dSets.contains(str)){
                            flag = false;
                            break;
                        }
                    }
                    if(flag) {
                        counter[i] += 1;
                    } else{
                        flag = true;
                    }
                }
            }
        }

        for(int i = 0; i < setList.size(); i++){
            if (counter[i] >= minSup) {
                result.add(setList.get(i));
            }
        }
        return result;
    }

    /**
     * 判断两个项集能否执行连接操作
     * @param s1
     * @param s2
     * @return
     */
    boolean isCanLink(String [] s1, String[] s2){
        boolean flag = true;
        if(s1.length == s2.length) {
            for(int i = 0; i < s1.length - 1; i ++){
                if(!s1[i].equals(s2[i])){
                    flag = false;
                    break;
                }
            }
            if(s1[s1.length - 1].equals(s2[s2.length - 1])){
                flag = false;
            }
        }else{
            flag = true;
        }
        return flag;
    }

    /**
     * 判断set是否需要被cut
     *
     * @param setList
     * @param set
     * @return
     */
    boolean isNeedCut(List<Set<String>> setList, Set<String> set) {//setList指频繁K-1项集,set指候选K项集
        boolean flag = false;
        List<Set<String>> subSets = getSubset(set);//获得K项集的所有k-1项集
        for ( Set<String> subSet : subSets) {
            //判断当前的k-1项集set是否在频繁k-1项集中出现,如果出现,则不需要cut
            //若没有出现,则需要被cut
            if( !isContained(setList, subSet)){
                flag = true;
                break;
            }
        }
        return flag;
    }
    /**
     * 功能:判断k项集的某k-1项集是否包含在频繁k-1项集列表中
     *
     * @param setList
     * @param set
     * @return
     */
    boolean isContained(List<Set<String>> setList, Set<String> set){
        boolean flag = false;
        int position = 0;
        for( Set<String> s : setList  ) {
            String [] sArr = s.toArray(new String[0]);
            String [] setArr = set.toArray(new String[0]);
            for(int i = 0; i < sArr.length; i++) {
                if ( sArr[i].equals(setArr[i])){
                    //如果对应位置的元素相同,则position为当前位置的值
                    position = i;
                } else{
                    break;
                }
            }
            //如果position等于数组的长度,说明已经找到某个setList中的集合与
            //set集合相同了,退出循环,返回包含
            //否则,把position置为0进入下一个比较
            if ( position == sArr.length - 1) {
                flag = true;
                break;
            } else {
                flag = false;
                position = 0;
            }
        }
        return flag;
    }

    /**
     * 获得k项集的所有k-1项子集
     *
     * @param set
     * @return
     */
    List<Set<String>> getSubset(Set <String> set){

        List<Set<String>> result = new ArrayList<Set<String>>();
        String [] setArr = set.toArray(new String[0]);

        for( int i = 0; i < setArr.length; i++){
            Set<String> subSet = new TreeSet<String>();
            for(int j = 0; j < setArr.length; j++){
                if( i != j){
                    subSet.add((String) setArr[j]);
                }
            }
            result.add(subSet);
        }
        return result;
    }
    /**
     * 功能:打印频繁项集
     */
    void printSet(List<Set<String>> setList, int i){
        System.out.print("频繁" + i + "项集: 共" + setList.size() + "项: {");
        for(Set<String> set : setList) {
            System.out.print("[");
            for(String str : set) {
                System.out.print(str + " ");
            }
            System.out.print("], ");
        }
        System.out.println("}");
    }
}

(2)运行结果

三、总结

在日常生活中,利用关联分析能够帮助我们发现大型数据集之间的关联规则,描述数据之间的密切度,同时又有利于我们进行决策等。

参考

数据挖掘十大算法—— Apriori – 知乎 (zhihu.com)

Apriori算法详解 – 知乎 (zhihu.com)

数据挖掘入门系列教程(四点五)之Apriori算法 – 知乎 (zhihu.com)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年12月7日
下一篇 2023年12月7日

相关推荐