diff --git a/src/commands/branch.rs b/src/commands/branch.rs index e096839..867a103 100644 --- a/src/commands/branch.rs +++ b/src/commands/branch.rs @@ -4,6 +4,7 @@ use crate::{ head, models::{commit::Commit, object::Hash}, store, + utils::util, }; // branch error @@ -30,12 +31,12 @@ fn search_hash(commit_hash: Hash) -> Option { fn create_branch(branch_name: String, _base_commit: Hash) -> Result<(), BranchErr> { // 找到正确的base_commit_hash let base_commit = search_hash(_base_commit.clone()); - if base_commit.is_none() { + if base_commit.is_none() || util::check_object_type(base_commit.clone().unwrap()) != util::ObjectType::Commit { println!("fatal: 非法的 commit: '{}'", _base_commit); return Err(BranchErr::InvalidObject); } - let base_commit = Commit::load(&base_commit.unwrap()); // TODO 这里会直接panic,可以优化一下错误处理流程 + let base_commit = Commit::load(&base_commit.unwrap()); let exist_branchs = head::list_local_branches(); if exist_branchs.contains(&branch_name) { diff --git a/src/commands/switch.rs b/src/commands/switch.rs index fdc6c78..fd3fdd8 100644 --- a/src/commands/switch.rs +++ b/src/commands/switch.rs @@ -6,6 +6,7 @@ use crate::{ head::{self}, models::{commit::Commit, object::Hash}, store::Store, + utils::util, }; use super::{ @@ -51,10 +52,11 @@ fn switch_to(branch: String, detach: bool) -> Result<(), SwitchErr> { println!("切换到分支: '{}'", branch.green()) } else if detach { let commit = store.search(&branch); - if commit.is_none() { + if commit.is_none() || util::check_object_type(commit.clone().unwrap()) != util::ObjectType::Commit { println!("fatal: 非法的 commit: '{}'", branch); return Err(SwitchErr::InvalidObject); } + // 切到commit let commit = commit.unwrap(); switch_to_commit(None, commit.clone()); diff --git a/src/utils/util.rs b/src/utils/util.rs index 6e85b98..6b62da7 100644 --- a/src/utils/util.rs +++ b/src/utils/util.rs @@ -6,6 +6,8 @@ use std::{ path::{Path, PathBuf}, }; +use crate::models::{commit::Commit, object::Hash, tree::Tree}; + pub const ROOT_DIR: &str = ".mit"; pub const TEST_DIR: &str = "mit_test_storage"; // 执行测试的储存库 @@ -279,8 +281,34 @@ pub fn get_absolute_path(path: &Path) -> PathBuf { } } +#[derive(Debug, PartialEq)] +pub enum ObjectType { + Blob, + Tree, + Commit, + Invalid, +} +pub fn check_object_type(hash: Hash) -> ObjectType { + let path = get_storage_path().unwrap().join("objects").join(hash.to_string()); + if path.exists() { + let data = fs::read_to_string(path).unwrap(); + let result: Result = serde_json::from_str(&data); + if result.is_ok() { + return ObjectType::Commit; + } + let result: Result = serde_json::from_str(&data); + if result.is_ok() { + return ObjectType::Tree; + } + return ObjectType::Blob; + } + ObjectType::Invalid +} + #[cfg(test)] mod tests { + use crate::models::{blob::Blob, index::Index}; + use super::*; #[test] @@ -325,4 +353,17 @@ mod tests { Err(err) => println!("{}", err), } } + + #[test] + fn test_check_object_type() { + setup_test_with_clean_mit(); + assert_eq!(check_object_type("123".into()), ObjectType::Invalid); + ensure_test_file(Path::new("test.txt"), Some("test")); + let hash = Blob::new(get_working_dir().unwrap().join("test.txt").as_path()).get_hash(); + assert_eq!(check_object_type(hash), ObjectType::Blob); + let mut commit = Commit::new(&Index::new(), vec![], "test".to_string()); + assert_eq!(check_object_type(commit.get_tree_hash()), ObjectType::Tree); + commit.save(); + assert_eq!(check_object_type(commit.get_hash()), ObjectType::Commit); + } }