超能饮料问题
今天讨论超能饮料的问题。这里是题目的来源。这个超能饮料的题目,颇有些炒作的嫌疑。据说这道题是程序员百万年薪招聘的题目之一。今天也是无意中从同学那里知道,这里就简单给出思路,然后给出Python语言的实现版本。
这里简单说一下超能饮料问题的大意。现在,已知有若干种原料,它们以C开头,比如:C1、C2.......一种原料转化为另一种原料需要使用定制的机器,而这也是成本主要来源(原料成本不计)。机器标记以M开头,如:M1、M2......
现在我们有一系列的输入(命令行参数形式,给出文件名,以文件形式读入),文件的每一行都代表特定的设备型号。每一行的格式如下:
<machine name> <orig compound> <new compound> <price>(分别对应的中文意思是设备名称、源化合物、新化合物、价格)
输入文件示例:
M1 C1 C2 277317
M2 C2 C1 26247
M3 C1 C3 478726
M4 C3 C1 930382
M5 C2 C3 370287
M6 C3 C2 112344
题目的要求是,求无论能买到那种原料,都保证能生产出其他原料,并且保证花费的代价最小。
要求返回值的格式为:第一行,最优总价格。第二行,机器列表(去掉开头M),比如上例的输出示例:
617317
2 3 6
读完这条题目,我们可以发现,这题目其实是一条典型的图论题。对于某一个状态Ci,我们可以通过Mk转移至Cj,权重为Wm。于是,根据题意,我们实际上要求的是这样一条或者多条环路,使得它们包含所有的原料,并且,环路上所有的权重和最小。
在此基础上,我们必须满足两点。
- 这条环路必须包含所有的原料。如果不能满足包含全部原料,就不能保证“无论能买到哪种原料,都保证能生产出其他原料”的条件。
- 对于重复出现的状态转移,只需计算一遍。
综上所述,问题要解决的在已知图中求一条或多条环路,并且使这些环路中包含所有原料。
首先,我们需要表示图。我们使用邻接表表示:
class Node(object): '''表示图中的结点''' def __init__(self, name): # 原料名 self.name = name # 原料相连的边,字典 # key: Node类 # value: Link类 self.links = {} def __str__(self): return self.name def add_link(self, link, node): self.links[node] = link def nodes(self): return self.links.keys class Link(object): def __init__(self, machine, price, start=None, end=None): self.machine = machine # 机器名 self.price = price # 价格 self.start = start # 开始节点 self.end = end # 结束节点
我们再写一个工厂类来获取Node,并且来解析输入文件。
import threading class NodeFactory(object): '''Node工厂类''' def __init__(self, nodes=None): if nodes is None: self.nodes = [] else: self.nodes = nodes # 锁 self.lock = threading.Lock() def get(self, node_name): ''' 这里涉及同步问题,设置锁,确保任意时刻只有一个方法访问self.nodes ''' self.lock.acquire() # 获取锁 try: r_node = [node for node in self.nodes if node.name==node_name] length = len(r_node) assert(length <=1 and length >=0) if length == 1: return r_node[0] node = Node(node_name) self.nodes.append(node) return node finally: self.lock.release() def parse(self, iter): '''iter是个可迭代对象''' for line in iter: eles = [l for l in line.split(" ") if len(l)>0] assert(len(eles) == 4) link = Link(eles[0], int(eles[3])) from_node = self.get(eles[1]) to_node = self.get(eles[2]) link.start = from_node link.end = to_node from_node.add_link(link, to_node) def parse_file(self, path): '''解析文件,path为绝对路径''' _parse_file = file(path) self.parse(_parse_file.readlines())
下面到了算法的核心部分。要得到要求环路,自然想到的是图的深度遍历算法(利用深度遍历算法很容易检测到环路),对于某个节点,如果在之前路径没有出现,就压入列表;否则,就得到一条回路。
def get_result(node_fac): ''' 得到结果,参数node_fac为NodeFactory类的实例 返回值0: 机器 返回值1: 最小代价价格 ''' loops = [] # 存放所有的回路 start_node = node_fac.nodes[0] # 开始节点 length = len(node_fac.nodes) # 节点总个数 trav_list = [] # 存放已经遍历过的节点 links = [] # 存放当前遍历的路径 def trav_node(node): # 嵌套递归函数,深度遍历所有节点。 trav_list.append(node) for n, lnk in node.links.items(): links.append(lnk) if n not in trav_list: # 如果这个节点还没有遍历到 trav_node(n) #递归调用 else: loop_links = sorted(links[trav_list.index(n):]) # 返回这个环路 if loop_links not in loops: loops.append(loop_links) links.remove(lnk) trav_list.remove(node) trav_node(start_node) def check_loop(aloop): # 检查若干个回路中是否包含有全部的节点。 nodes = [] for loop in aloop: for l in loop: if l.start not in nodes: nodes.append(l.start) if len(nodes) == length: return True return False def merge_loop(aloop): # 合并回路中重复的路径 merge_links = [] for loop in aloop: for l in loop: if l not in merge_links: merge_links.append(l) return merge_links loop_stack = [] # 临时存放符合条件的回路 results = [] # 存放所有符合条件的回路 for loop in loops: loop_stack.append(loop) if check_loop(loop_stack): # 如果一个环路已经包含全部节点,就不需添加其他环路。 results.append(merge_loop(loop_stack)) else: for l in loops[loops.index(loop)+1:]: loop_stack.append(l) if check_loop(loop_stack): results.append(merge_loop(loop_stack)) loop_stack.remove(l) loop_stack = [] prices = -1 # 价格 machines = [] # 机器 for r in results: r_links = [] for lnk in r: if lnk not in r_links: r_links.append(lnk) p = sum([l.price for l in r_links]) # 计算每种可能所需要的代价 if prices == -1 or prices > p: prices = p machines = [l.machine for l in r_links] return machines, prices
最后我们要得到输出的格式,这么写:
if __name__ == "__main__": node_fac = NodeFactory() file_name = raw_input(u"Please input the file path: ") try: node_fac.parse_file(file_name) except: parse_list = ['M1 C1 C2 277317', 'M2 C2 C1 26247', 'M3 C1 C3 478726', 'M4 C3 C1 930382', 'M5 C2 C3 370287', 'M6 C3 C2 112344'] node_fac.parse(parse_list) result = get_result(node_fac) result_machines = sorted([int(m[1:]) for m in result[0]]) print result[1] print " ".join([str(m) for m in result_machines])
最后讨论一下时间复杂度,由于每个边和顶点都要遍历,所以时间复杂度为O(n+e)(n为顶点数,e为边数)。但在本例中,边数实际上是O(n^2),所以时间复杂度O(n^2)。所以完成算法,当顶点数变多的时候,需要较大的代价。
http://download.csdn.net/detail/yige2002/6516167
这些函数非常适合在程序中使用。在函数中使用它之前,必须先定义对象。