diff --git a/codes/rust/chapter_heap/heap.rs b/codes/rust/chapter_heap/heap.rs index ef4f35008..8bb0c6ba7 100644 --- a/codes/rust/chapter_heap/heap.rs +++ b/codes/rust/chapter_heap/heap.rs @@ -6,18 +6,28 @@ use hello_algo_rust::include::print_util; -use std::collections::BinaryHeap; +use std::{cmp::Reverse, collections::BinaryHeap}; -fn test_push(heap: &mut BinaryHeap, val: i32, flag: i32) { - heap.push(flag * val); // 元素入堆 +fn test_push_max(heap: &mut BinaryHeap, val: i32) { + heap.push(val); // 元素入堆 println!("\n元素 {} 入堆后", val); - print_util::print_heap(heap.iter().map(|&val| flag * val).collect()); + print_util::print_heap(heap.iter().map(|&val| val).collect()); +} +fn test_push_min(heap: &mut BinaryHeap>, val: i32) { + heap.push(Reverse(val)); // 元素入堆 + println!("\n元素 {} 入堆后", val); + print_util::print_heap(heap.iter().map(|&val| val.0).collect()); } -fn test_pop(heap: &mut BinaryHeap, flag: i32) { +fn test_pop_max(heap: &mut BinaryHeap) { let val = heap.pop().unwrap(); - println!("\n堆顶元素 {} 出堆后", flag * val); - print_util::print_heap(heap.iter().map(|&val| flag * val).collect()); + println!("\n堆顶元素 {} 出堆后", val); + print_util::print_heap(heap.iter().map(|&val| val).collect()); +} +fn test_pop_min(heap: &mut BinaryHeap>) { + let val = heap.pop().unwrap().0; + println!("\n堆顶元素 {} 出堆后", val); + print_util::print_heap(heap.iter().map(|&val| val.0).collect()); } /* Driver Code */ @@ -26,31 +36,29 @@ fn main() { // 初始化小顶堆 #[allow(unused_assignments)] let mut min_heap = BinaryHeap::new(); - // Rust 的 BinaryHeap 是大顶堆,当入队时将元素值乘以 -1 将其反转,当出队时将元素值乘以 -1 将其还原 - let min_heap_flag = -1; + // Rust 的 BinaryHeap 是大顶堆,小顶堆一般会“套上”Reverse // 初始化大顶堆 let mut max_heap = BinaryHeap::new(); - let max_heap_flag = 1; println!("\n以下测试样例为大顶堆"); /* 元素入堆 */ - test_push(&mut max_heap, 1, max_heap_flag); - test_push(&mut max_heap, 3, max_heap_flag); - test_push(&mut max_heap, 2, max_heap_flag); - test_push(&mut max_heap, 5, max_heap_flag); - test_push(&mut max_heap, 4, max_heap_flag); + test_push_max(&mut max_heap, 1); + test_push_max(&mut max_heap, 3); + test_push_max(&mut max_heap, 2); + test_push_max(&mut max_heap, 5); + test_push_max(&mut max_heap, 4); /* 获取堆顶元素 */ - let peek = max_heap.peek().unwrap() * max_heap_flag; + let peek = max_heap.peek().unwrap(); println!("\n堆顶元素为 {}", peek); /* 堆顶元素出堆 */ - test_pop(&mut max_heap, max_heap_flag); - test_pop(&mut max_heap, max_heap_flag); - test_pop(&mut max_heap, max_heap_flag); - test_pop(&mut max_heap, max_heap_flag); - test_pop(&mut max_heap, max_heap_flag); + test_pop_max(&mut max_heap); + test_pop_max(&mut max_heap); + test_pop_max(&mut max_heap); + test_pop_max(&mut max_heap); + test_pop_max(&mut max_heap); /* 获取堆大小 */ let size = max_heap.len(); @@ -65,9 +73,9 @@ fn main() { min_heap = BinaryHeap::from( vec![1, 3, 2, 5, 4] .into_iter() - .map(|val| min_heap_flag * val) - .collect::>(), + .map(|val| Reverse(val)) + .collect::>>(), ); println!("\n输入列表并建立小顶堆后"); - print_util::print_heap(min_heap.iter().map(|&val| min_heap_flag * val).collect()); + print_util::print_heap(min_heap.iter().map(|&val| val.0).collect()); }