超能饮料问题

今天讨论超能饮料的问题。这里是题目的来源。这个超能饮料的题目,颇有些炒作的嫌疑。据说这道题是程序员百万年薪招聘的题目之一。今天也是无意中从同学那里知道,这里就简单给出思路,然后给出Python语言的实现版本。

这里简单说一下超能饮料问题的大意。现在,已知有若干种原料,它们以C开头,比如:C1、C2.......一种原料转化为另一种原料需要使用定制的机器,而这也是成本主要来源(原料成本不计)。机器标记以M开头,如:M1、M2......

现在我们有一系列的输入(命令行参数形式,给出文件名,以文件形式读入),文件的每一行都代表特定的设备型号。每一行的格式如下:

<machine name>  <orig compound>  <new compound>  <price>(分别对应的中文意思是设备名称、源化合物、新化合物、价格)

输入文件示例:

M1  C1  C2  277317
M2  C2  C1  26247
M3  C1  C3  478726
M4  C3  C1  930382
M5  C2  C3  370287
M6  C3  C2  112344

题目的要求是,求无论能买到那种原料,都保证能生产出其他原料,并且保证花费的代价最小。

要求返回值的格式为:第一行,最优总价格。第二行,机器列表(去掉开头M),比如上例的输出示例:

617317
2 3 6

读完这条题目,我们可以发现,这题目其实是一条典型的图论题。对于某一个状态Ci,我们可以通过Mk转移至Cj,权重为Wm。于是,根据题意,我们实际上要求的是这样一条或者多条环路,使得它们包含所有的原料,并且,环路上所有的权重和最小。

在此基础上,我们必须满足两点。

  1. 这条环路必须包含所有的原料。如果不能满足包含全部原料,就不能保证“无论能买到哪种原料,都保证能生产出其他原料”的条件。
  2. 对于重复出现的状态转移,只需计算一遍。

综上所述,问题要解决的在已知图中求一条或多条环路,并且使这些环路中包含所有原料。

首先,我们需要表示图。我们使用邻接表表示:

class Node(object):
    '''表示图中的结点'''
    def __init__(self, name):
        # 原料名
        self.name = name

        # 原料相连的边,字典
        # key: Node类
        # value: Link类
        self.links = {}

    def __str__(self):
        return self.name

    def add_link(self, link, node):
        self.links[node] = link

    def nodes(self):
        return self.links.keys

class Link(object):
    def __init__(self, machine, price, start=None, end=None):
        self.machine = machine # 机器名
        self.price = price # 价格
        self.start = start # 开始节点
        self.end = end # 结束节点

我们再写一个工厂类来获取Node,并且来解析输入文件。

import threading
class NodeFactory(object):
    '''Node工厂类'''
    def __init__(self, nodes=None):
        if nodes is None:
            self.nodes = []
        else:
            self.nodes = nodes

        # 锁
        self.lock = threading.Lock()

    def get(self, node_name):
        '''
        这里涉及同步问题,设置锁,确保任意时刻只有一个方法访问self.nodes
        '''
        self.lock.acquire() # 获取锁
        try:
            r_node = [node for node in self.nodes if node.name==node_name]
            length = len(r_node)
            assert(length <=1 and length >=0)

            if length == 1:
                return r_node[0]

            node = Node(node_name)
            self.nodes.append(node)

            return node
        finally:
            self.lock.release()
        
    def parse(self, iter):
        '''iter是个可迭代对象'''
        for line in iter:
            eles = [l for l in line.split(" ") if len(l)>0]
            assert(len(eles) == 4)

            link = Link(eles[0], int(eles[3]))
            from_node = self.get(eles[1])
            to_node = self.get(eles[2])
            link.start = from_node
            link.end = to_node
            from_node.add_link(link, to_node)

    def parse_file(self, path):
        '''解析文件,path为绝对路径'''
        _parse_file = file(path)
        self.parse(_parse_file.readlines())

下面到了算法的核心部分。要得到要求环路,自然想到的是图的深度遍历算法(利用深度遍历算法很容易检测到环路),对于某个节点,如果在之前路径没有出现,就压入列表;否则,就得到一条回路。

def get_result(node_fac):
    '''
    得到结果,参数node_fac为NodeFactory类的实例
    返回值0: 机器
    返回值1: 最小代价价格
    '''
    loops = [] # 存放所有的回路

    start_node = node_fac.nodes[0] # 开始节点
    length = len(node_fac.nodes) # 节点总个数
    trav_list = [] # 存放已经遍历过的节点
    links = [] # 存放当前遍历的路径

    def trav_node(node):
        # 嵌套递归函数,深度遍历所有节点。
        trav_list.append(node)
        for n, lnk in node.links.items():
            links.append(lnk)
            if n not in trav_list: # 如果这个节点还没有遍历到
                trav_node(n) #递归调用
            else:
                loop_links = sorted(links[trav_list.index(n):])
                # 返回这个环路
                if loop_links not in loops:
                    loops.append(loop_links)
            links.remove(lnk)
        trav_list.remove(node)

    trav_node(start_node)

    def check_loop(aloop): 
        # 检查若干个回路中是否包含有全部的节点。
        nodes = []
        for loop in aloop:
            for l in loop:
                if l.start not in nodes:
                    nodes.append(l.start)
            if len(nodes) == length:
                return True
        return False

    def merge_loop(aloop): 
        # 合并回路中重复的路径
        merge_links = []
        for loop in aloop:
            for l in loop:
                if l not in merge_links:
                    merge_links.append(l)
        return merge_links
        
    loop_stack = [] # 临时存放符合条件的回路
    results = [] # 存放所有符合条件的回路
    for loop in loops:
        loop_stack.append(loop)
        
        if check_loop(loop_stack): 
            # 如果一个环路已经包含全部节点,就不需添加其他环路。
            results.append(merge_loop(loop_stack))
        else:
            for l in loops[loops.index(loop)+1:]:
                loop_stack.append(l)
                if check_loop(loop_stack):
                    results.append(merge_loop(loop_stack))
                    loop_stack.remove(l)
        loop_stack = []

    prices = -1 # 价格
    machines = [] # 机器

    for r in results:
        r_links = []
        for lnk in r:
            if lnk not in r_links:
                r_links.append(lnk)
        p = sum([l.price for l in r_links]) # 计算每种可能所需要的代价
        
        if prices == -1 or prices > p:
            prices = p
            machines = [l.machine for l in r_links]
                    
    return machines, prices 

最后我们要得到输出的格式,这么写:

if __name__ == "__main__":
    node_fac = NodeFactory()

    file_name = raw_input(u"Please input the file path: ")
    try:
        node_fac.parse_file(file_name)
    except:
        parse_list = ['M1  C1  C2  277317',
                      'M2  C2  C1  26247',
                      'M3  C1  C3  478726',
                      'M4  C3  C1  930382',
                      'M5  C2  C3  370287',
                      'M6  C3  C2  112344']
        node_fac.parse(parse_list)
        
    result = get_result(node_fac)
    result_machines = sorted([int(m[1:]) for m in result[0]])
    print result[1]
    print " ".join([str(m) for m in result_machines])

最后讨论一下时间复杂度,由于每个边和顶点都要遍历,所以时间复杂度为O(n+e)(n为顶点数,e为边数)。但在本例中,边数实际上是O(n^2),所以时间复杂度O(n^2)。所以完成算法,当顶点数变多的时候,需要较大的代价。

赞这篇文章

分享到

1个评论

给作者留言

关于作者

残阳似血(@秦续业),程序猿一枚,把梦想揣进口袋的挨踢工作者。现加入阿里云,研究僧毕业于上海交通大学软件学院ADC实验室。熟悉分布式数据分析(DataFrame并行化框架)、基于图模型的分布式数据库和并行计算、Dpark/Spark以及Python web开发(Django、tornado)等。

博客分类

点击排行

标签云

扫描访问

主题

残阳似血的微博