Skip to content

Commit 1b643a4

Browse files
committed
Tweaks
1 parent 1d60664 commit 1b643a4

File tree

5 files changed

+40
-32
lines changed

5 files changed

+40
-32
lines changed

bin/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ use std::path::PathBuf;
77
#[derive(Parser)]
88
#[command(version, about, long_about = None, arg_required_else_help = true)]
99
struct Args {
10-
/// Expression to evaluate
11-
#[arg(short, long, conflicts_with_all = ["input", "expr", "symbol_table"])]
10+
/// Expression to evaluate (can be provided as positional argument or with -e/--expression)
11+
#[arg(short, long, conflicts_with = "input")]
1212
expression: Option<String>,
1313

14-
/// Expression to evaluate (positional)
14+
/// Expression to evaluate (positional, alternative to -e/--expression)
1515
#[arg(conflicts_with_all = ["expression", "input", "symbol_table"])]
1616
expr: Option<String>,
1717

lib/src/lib.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ mod symbol;
2525
mod token;
2626
mod vm;
2727

28-
use std::{borrow::Cow, fs, path::PathBuf};
28+
use std::{borrow::Cow, fmt, fs, path::PathBuf};
2929

3030
// Public API
3131
pub use ir::IrBuilder;
@@ -40,6 +40,25 @@ pub use source::Source;
4040
pub use symbol::{SymTable, Symbol, SymbolError};
4141
pub use vm::{Vm, VmError};
4242

43+
/// A wrapper that formats errors with source code highlighting
44+
struct FormattedError {
45+
message: String,
46+
}
47+
48+
impl fmt::Display for FormattedError {
49+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50+
write!(f, "{}", self.message)
51+
}
52+
}
53+
54+
impl<T: SpanError> From<(&T, &Source<'_>)> for FormattedError {
55+
fn from((error, source): (&T, &Source<'_>)) -> Self {
56+
Self {
57+
message: format!("{}\n{}", error, source.highlight(&error.span())),
58+
}
59+
}
60+
}
61+
4362
#[derive(Debug)]
4463
enum EvalSource<'str> {
4564
Source(Cow<'str, Source<'str>>),
@@ -97,14 +116,14 @@ impl<'str> Eval<'str> {
97116
let mut parser = Parser::new(source);
98117
let mut ast: Expr = match parser
99118
.parse()
100-
.map_err(|err| Self::error_with_source(&err, source))?
119+
.map_err(|err| FormattedError::from((&err, source.as_ref())).to_string())?
101120
{
102121
Some(ast) => ast,
103122
None => return Ok(Program::default()),
104123
};
105124
Sema::new(table)
106125
.visit(&mut ast)
107-
.map_err(|err| Self::error_with_source(&err, source))?;
126+
.map_err(|err| FormattedError::from((&err, source.as_ref())).to_string())?;
108127
IrBuilder::new().build(&ast).map_err(|err| err.to_string())
109128
}
110129
EvalSource::File(path) => {
@@ -113,8 +132,4 @@ impl<'str> Eval<'str> {
113132
}
114133
}
115134
}
116-
117-
fn error_with_source<T: SpanError>(error: &T, source: &Source) -> String {
118-
format!("{}\n{}", error, source.highlight(&error.span()))
119-
}
120135
}

lib/src/program.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ use std::collections::HashMap;
77
use serde::{Deserialize, Serialize};
88
use thiserror::Error;
99

10+
/// Current version of the program format
11+
const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
12+
1013
/// Expression parsing and evaluation errors.
1114
#[derive(Error, Debug)]
1215
pub enum ProgramError {
@@ -87,7 +90,7 @@ enum BinaryInstr {
8790
impl<'sym> Program<'sym> {
8891
pub fn new() -> Self {
8992
Self {
90-
version: env!("CARGO_PKG_VERSION").to_string(),
93+
version: PROGRAM_VERSION.to_string(),
9194
code: Vec::new(),
9295
}
9396
}
@@ -169,10 +172,9 @@ impl<'sym> Program<'sym> {
169172
}
170173

171174
fn validate_version(version: &String) -> Result<(), ProgramError> {
172-
let current_version = env!("CARGO_PKG_VERSION");
173-
if version != current_version {
175+
if version != PROGRAM_VERSION {
174176
return Err(ProgramError::IncompatibleVersions(
175-
current_version.to_string(),
177+
PROGRAM_VERSION.to_string(),
176178
version.clone(),
177179
));
178180
}

lib/src/token.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,11 @@ impl<'src> Token<'src> {
4646
}
4747

4848
pub fn is_right_associative(&self) -> bool {
49-
match self {
50-
Token::Caret => true,
51-
_ => false,
52-
}
49+
matches!(self, Token::Caret)
5350
}
5451

5552
pub fn is_postfix_unary(&self) -> bool {
56-
match self {
57-
Token::Bang => true,
58-
_ => false,
59-
}
53+
matches!(self, Token::Bang)
6054
}
6155

6256
pub fn lexeme(&self) -> Cow<'src, str> {

lib/src/vm.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,14 @@ impl Vm {
197197
return Err(VmError::InvalidFactorial { value: n });
198198
}
199199

200-
// Calculate factorial using safe multiplication
200+
// Calculate factorial using safe multiplication with iterator
201201
let n_u64 = n.to_u64().unwrap();
202-
let mut result = Decimal::ONE;
203-
for i in 1..=n_u64 {
204-
result =
205-
result
206-
.checked_mul(Decimal::from(i))
207-
.ok_or_else(|| VmError::ArithmeticError {
208-
message: format!("Factorial calculation overflow at {}!", i),
209-
})?;
210-
}
202+
let result = (1..=n_u64).try_fold(Decimal::ONE, |acc, i| {
203+
acc.checked_mul(Decimal::from(i))
204+
.ok_or_else(|| VmError::ArithmeticError {
205+
message: format!("Factorial calculation overflow at {}!", i),
206+
})
207+
})?;
211208

212209
stack.push(result);
213210
Ok(())

0 commit comments

Comments
 (0)